|
网站内容均来自网络,本站只提供信息平台,如有侵权请联系删除,谢谢!
先把效果图给大家放上来
个人觉得效果还行。识别不太准确是因为这个 app学习图片的时间太短(电脑太卡)。
(笔者是 window10) 安装运行环境:
- npm install --global windows-build-tools
复制代码 (这个时间很漫长。。。)
- npm install @tensorflow/tfjs-node
复制代码 (这个时间很漫长。。。)
项目目录如下
train文件夹 index.js(入口文件)- const tf = require('@tensorflow/tfjs-node')
- const getData = require('./data')
- const TRAIN_DIR = '../垃圾分类/train'
- const OUTPUT_DIR = '../outputDir'
- const MOBILENET_URL = 'http://ai-sample.oss-cn-hangzhou.aliyuncs.com/pipcook/models/mobilenet/web_model/model.json'
- const main = async () => {
- // 加载数据
- const { ds, classes} = await getData(TRAIN_DIR, OUTPUT_DIR)
- // 定义模型
- const mobilenet = await tf.loadLayersModel(MOBILENET_URL)
- mobilenet.summary()
- // console.log(mobilenet.layers.map((l, i) => [l.name, i]))
- const model = tf.sequential()
- for (let i = 0; i <= 86; i += 1) {
- const layer = mobilenet.layers[i]
- layer.trainable = false
- model.add(layer)
- }
- model.add(tf.layers.flatten())
- model.add(tf.layers.dense({
- units: 10,
- activation: 'relu'
- }))
- model.add(tf.layers.dense({
- units: classes.length,
- activation: 'softmax'
- }))
- // 训练模型
- model.compile({
- loss: 'sparseCategoricalCrossentropy',
- optimizer: tf.train.adam(),
- metrics: ['acc']
- })
- await model.fitDataset(ds, { epochs: 20 })
- await model.save(`file://${process.cwd()}/${OUTPUT_DIR}`)
- }
- main()
复制代码 data.js(处理数据)- const fs = require('fs')
- const tf = require('@tensorflow/tfjs-node')
- const img2x = (imgPath) => {
- const buffer = fs.readFileSync(imgPath)
- return tf.tidy(() => {
- const imgTs = tf.node.decodeImage(new Uint8Array(buffer))
- const imgTsResized = tf.image.resizeBilinear(imgTs, [224, 224])
- return imgTsResized.toFloat().sub(255/2).div(255/2).reshape([1, 224, 224, 3])
- })
- }
- const getData = async (trainDir, outputDir) => {
- const classes = fs.readdirSync(trainDir)
- fs.writeFileSync(`${outputDir}/classes.json`, JSON.stringify(classes))
- const data = []
- classes.forEach((dir, dirIndex) => {
- fs.readdirSync(`${trainDir}/${dir}`)
- .filter(n => n.match(/jpg$/))
- .slice(0, 10)
- .forEach(filename => {
- console.log('读取', dir, filename)
- const imgPath = `${trainDir}/${dir}/${filename}`
- data.push({ imgPath, dirIndex })
- })
- })
- tf.util.shuffle(data)
- const ds = tf.data.generator(function* () {
- const count = data.length
- const batchSize = 32
- for (let start = 0; start < count; start += batchSize) {
- const end = Math.min(start + batchSize, count)
- yield tf.tidy(() => {
- const inputs = []
- const labels = []
- for (let j = start; j < end; j += 1) {
- const { imgPath, dirIndex } = data[j]
- const x = img2x(imgPath)
- inputs.push(x)
- labels.push(dirIndex)
- }
- const xs = tf.concat(inputs)
- const ys = tf.tensor(labels)
- return { xs, ys }
- })
- }
- })
- return {
- ds,
- classes
- }
- }
- module.exports = getData
复制代码 安装一些运行项目需要的插件
app 文件夹- import React, { PureComponent } from 'react'
- import { Button, Progress, Spin, Empty } from 'antd'
- import 'antd/dist/antd.css'
- import * as tf from '@tensorflow/tfjs'
- import { file2img, img2x } from './utils'
- import intro from './intro'
- const DATA_URL = 'http://127.0.0.1:8080/'
- class App extends PureComponent {
- state = {}
- async componentDidMount() {
- this.model = await tf.loadLayersModel(DATA_URL + '/model.json')
- // this.model.summary()
- this.CLASSES = await fetch(DATA_URL + '/classes.json').then(res => res.json())
- }
- predict = async (file) => {
- const img = await file2img(file)
- this.setState({
- imgSrc: img.src,
- isLoading: true
- })
- setTimeout(() => {
- const pred = tf.tidy(() => {
- const x = img2x(img)
- return this.model.predict(x)
- })
- const results = pred.arraySync()[0]
- .map((score, i) => ({score, label: this.CLASSES[i]}))
- .sort((a, b) => b.score - a.score)
- this.setState({
- results,
- isLoading: false
- })
- }, 0)
- }
- renderResult = (item) => {
- const finalScore = Math.round(item.score * 100)
- return (
- <tr key={item.label}>
- <td style={{ width: 80, padding: '5px 0' }}>{item.label}</td>
- <td>
- <Progress percent={finalScore} status={finalScore === 100 ? 'success' : 'normal'} />
- </td>
- </tr>
- )
- }
- render() {
- const { imgSrc, results, isLoading } = this.state
- const finalItem = results && {...results[0], ...intro[results[0].label]}
- return (
- <div style={{padding: 20}}>
- <span
- style={{ color: '#cccccc', textAlign: 'center', fontSize: 12, display: 'block' }}
- >识别可能不准确</span>
- <Button
- type="primary"
- size="large"
- style={{width: '100%'}}
- onClick={() => this.upload.click()}
- >
- 选择图片识别
- </Button>
- <input
- type="file"
- onChange={e => this.predict(e.target.files[0])}
- ref={el => {this.upload = el}}
- style={{ display: 'none' }}
- />
- {
- !results && !imgSrc && <Empty style={{ marginTop: 40 }} />
- }
- {imgSrc && <div style={{ marginTop: 20, textAlign: 'center' }}>
- <img src={imgSrc} style={{ maxWidth: '100%' }} />
- </div>}
- {finalItem && <div style={{marginTop: 20}}>识别结果: </div>}
- {finalItem && <div style={{display: 'flex', alignItems: 'flex-start', marginTop: 20}}>
- <img
- src={finalItem.icon}
- width={120}
- />
- <div>
- <h2 style={{color: finalItem.color}}>
- {finalItem.label}
- </h2>
- <div style={{color: finalItem.color}}>
- {finalItem.intro}
- </div>
- </div>
- </div>}
- {
- isLoading && <Spin size="large" style={{display: 'flex', justifyContent: 'center', alignItems: 'center', marginTop: 40 }} />
- }
- {results && <div style={{ marginTop: 20 }}>
- <table style={{width: '100%'}}>
- <tbody>
- <tr>
- <td>类别</td>
- <td>匹配度</td>
- </tr>
- {results.map(this.renderResult)}
- </tbody>
- </table>
- </div>}
- </div>
- )
- }
- }
- export default App
复制代码 index.html- <!DOCTYPE html>
- <html>
- <head>
- <title>垃圾分类</title>
- <meta name="viewport" content="width=device-width, inital-scale=1">
- </head>
- <body>
- <div id="app"></div>
- <script src="./index.js"></script>
- </body>
- </html>
复制代码 index.js- import React from 'react'
- import ReactDOM from 'react-dom'
- import App from './App'
- ReactDOM.render(<App />, document.querySelector('#app'))
复制代码 intro.js- export default {
- '可回收物': {
- icon: 'https://lajifenleiapp.com/static/svg/1_3F6BA8.svg',
- color: '#3f6ba8',
- intro: '是指在日常生活中或者为日常生活提供服务的活动中产生的,已经失去原有全部或者部分使用价值,回收后经过再加工可以成为生产原料或者经过整理可以再利用的物品,包括废纸类、塑料类、玻璃类、金属类、织物类等。'
- },
- '有害垃圾': {
- icon: 'https://lajifenleiapp.com/static/svg/2v_B43953.svg',
- color: '#b43953',
- intro: '是指生活垃圾中对人体健康或者自然环境造成直接或者潜在危害的物质,包括废充电电池、废扣式电池、废灯管、弃置药品、废杀虫剂(容器)、废油漆(容器)、废日用化学品、废水银产品、废旧电器以及电子产品等。'
- },
- '厨余垃圾': {
- icon: 'https://lajifenleiapp.com/static/svg/3v_48925B.svg',
- color: '#48925b',
- intro: '是指居民日常生活中产生的有机易腐垃圾,包括菜叶、剩菜、剩饭、果皮、蛋壳、茶渣、骨头等。'
- },
- '其他垃圾': {
- icon: 'https://lajifenleiapp.com/static/svg/4_89918B.svg',
- color: '#89918b',
- intro: '是指除可回收物、有害垃圾和厨余垃圾之外的,混杂、污染、难分类的其他生活垃圾。'
- }
- }
复制代码 utils.js- import * as tf from '@tensorflow/tfjs'
- export const file2img = async (f) => {
- return new Promise(reslove => {
- const reader = new FileReader()
- reader.readAsDataURL(f)
- reader.onload = (e) => {
- const img = document.createElement('img')
- img.src = e.target.result
- img.width = 224
- img.height = 224
- img.onload = () => { reslove(img) }
- }
- })
- }
- export function img2x(imgEl) {
- return tf.tidy(() => {
- return tf.browser.fromPixels(imgEl)
- .toFloat().sub(255/2).div(255/2)
- .reshape([1, 224, 224, 3])
- })
- }
复制代码 运行项目代码之前,我们需要先在 train 目录下运行,node index.js,生成 model.json 以供识别系统使用。之后需要在根目录下运行 hs outputDir --cors, 使得生成的 model.json 运行在 http 环境下,之后才可以运行 npm start ,不然项目是会报错的。
主要的代码就是上面这些。前面笔者也说了。自己对这方面完全不懂,所以也无法解说其中的代码。各位感兴趣就自己研究一下。代码地址奉上。
gitee.com/suiboyu/gar…
总结
到此这篇关于如何利用React实现图片识别App的文章就介绍到这了,更多相关React图片识别App内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |
本帖子中包含更多资源
您需要 登录 才可以下载或查看,没有账号?立即注册
x
免责声明
1. 本论坛所提供的信息均来自网络,本网站只提供平台服务,所有账号发表的言论与本网站无关。
2. 其他单位或个人在使用、转载或引用本文时,必须事先获得该帖子作者和本人的同意。
3. 本帖部分内容转载自其他媒体,但并不代表本人赞同其观点和对其真实性负责。
4. 如有侵权,请立即联系,本网站将及时删除相关内容。
|