前言
       后面曾经通过采集拿到了图片,并且也手动对图片做了标注。接下来就要通过 Tensorflow.js 基于 mobileNet 训练模型,最初就能够实现在采集中对图片进行主动分类了。
       这种性能在利用场景里就比拟多了,比方图标素材站点,用户通过上传一个图标,零碎会主动匹配出类似的图标,还有二手平台,用户通过上传闲置物品图片,平台主动给出分类等,这些也都是后期对海量图片进行了标注训练而失去一个损失率极低的模型。上面就通过简答的代码实现一个小的动漫分类。
环境
Node
Http-Server
Parcel
Tensorflow
编码
- 训练模型
1.1. 创立我的项目,装置依赖包
npm install @tensorflow/tfjs –legacy-peer-deps
npm install @tensorflow/tfjs-node-gpu –legacy-peer-deps
1.2. 全局装置 Http-Server
npm install i http-server
1.3. 下载 mobileNet 模型文件 (网上有下载)
1.4. 根目录下启动 Http 服务 (开启跨域),用于 mobileNet 和训练后果的模型可拜访
http-server –cors -p 8080
1.5. 创立训练执行脚本 run.js
const tf = require('@tensorflow/tfjs-node-gpu');
const getData = require('./data');
const TRAIN_PATH = './ 动漫分类 /train';
const OUT_PUT = 'output';
const MOBILENET_URL = 'http://127.0.0.1:8080/data/mobilenet/web_model/model.json';
(async () => {const { ds, classes} = await getData(TRAIN_PATH, OUT_PUT);
console.log(ds, classes);
// 引入他人训练好的模型
const mobilenet = await tf.loadLayersModel(MOBILENET_URL);
// 查看模型构造
mobilenet.summary();
const model = tf.sequential();
// 截断模型,复用了 86 个层
for (let i = 0; i < 86; ++i) {const layer = mobilenet.layers[i];
layer.trainable = false;
model.add(layer);
}
// 降维,摊平数据
model.add(tf.layers.flatten());
// 设置全连贯层
model.add(tf.layers.dense({
units: 10,
activation: 'relu'// 设置激活函数,用于解决非线性问题
}));
model.add(tf.layers.dense({
units: classes.length,
activation: 'softmax'// 用于多分类问题
}));
// 设置损失函数,优化器
model.compile({
loss: 'sparseCategoricalCrossentropy',
optimizer: tf.train.adam(),
metrics:['acc']
});
// 训练模型
await model.fitDataset(ds, { epochs: 20});
// 保留模型
await model.save(`file://${process.cwd()}/${OUT_PUT}`);
})();
1.6. 创立图片与 Tensor 转换库 data.js
const fs = require('fs');
const tf = require("@tensorflow/tfjs-node-gpu");
const img2x = (imgPath) => {const buffer = fs.readFileSync(imgPath);
// 革除数据
return tf.tidy(() => {
// 把图片转成 tensor
const imgt = tf.node.decodeImage(new Uint8Array(buffer), 3);
// 调整图片大小
const imgResize = tf.image.resizeBilinear(imgt, [224, 224]);
// 归一化
return imgResize.toFloat().sub(255 / 2).div(255 / 2).reshape([1, 224, 224, 3]);
});
}
const getData = async (traindir, output) => {let classes = fs.readdirSync(traindir, 'utf-8');
fs.writeFileSync(`./${output}/classes.json`, JSON.stringify(classes));
const data = [];
classes.forEach((dir, dirIndex) => {fs.readdirSync(`${traindir}/${dir}`)
.filter(n => n.match(/jpg$/))
.slice(0, 1000)
.forEach(filename => {const imgPath = `${traindir}/${dir}/${filename}`;
data.push({imgPath, dirIndex});
});
});
console.log(data);
// 打乱训练程序,进步准确度
tf.util.shuffle(data);
const ds = tf.data.generator(function* () {
const count = data.length;
const batchSize = 32;
for (let start = 0; start < count; start += batchSize) {const end = Math.min(start + batchSize, count);
console.log('以后批次', start);
yield tf.tidy(() => {const inputs = [];
const labels = [];
for (let j = start; j < end; ++j) {const { imgPath, dirIndex} = data[j];
const x = img2x(imgPath);
inputs.push(x);
labels.push(dirIndex);
}
const xs = tf.concat(inputs);
const ys = tf.tensor(labels);
return {xs, ys};
});
}
});
return {ds, classes};
}
module.exports = getData;
1.7. 运行执行文件
node run.js
- 调用模型
2.1. 全局装置 parcel
npm install i parcel
2.2. 创立页面 index.html
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
<br>
2.3. 创立模型调用预测脚本 script.js
import * as tf from '@tensorflow/tfjs';
import {img2x, file2img} from './utils';
const MODEL_PATH = 'http://127.0.0.1:8080/t7';
const CLASSES = ["假面骑士","奥特曼","海贼王","火影忍者","龙珠"];
window.onload = async () => {const model = await tf.loadLayersModel(MODEL_PATH + '/output/model.json');
window.predict = async (file) => {const img = await file2img(file);
document.body.appendChild(img);
const pred = tf.tidy(() => {const x = img2x(img);
return model.predict(x);
});
const index = pred.argMax(1).dataSync()[0];
console.log(pred.argMax(1).dataSync());
let predictStr = "";
if (typeof CLASSES[index] == 'undefined') {predictStr = BRAND_CLASSES[index];
} else {predictStr = CLASSES[index];
}
setTimeout(() => {alert(` 预测后果:${predictStr}`);
}, 0);
};
};
2.4. 创立图片 tensor 格局转换库 utils.js
import * as tf from '@tensorflow/tfjs';
export function img2x(imgEl){return tf.tidy(() => {const input = tf.browser.fromPixels(imgEl)
.toFloat()
.sub(255 / 2)
.div(255 / 2)
.reshape([1, 224, 224, 3]);
return input;
});
}
export function file2img(f) {
return new Promise(resolve => {const reader = new FileReader();
reader.readAsDataURL(f);
reader.onload = (e) => {const img = document.createElement('img');
img.src = e.target.result;
img.width = 224;
img.height = 224;
img.onload = () => resolve(img);
};
});
}
2.5. 打包我的项目并运行
parcel index.html
2.6. 运行成果
留神
- 模型训练过程报错 Input to reshape is a tensor with 50176 values, but the requested shape has 150528
1.1. 起因
        张量 reshape 不对,理论输出元素个数与所需矩阵元素个数不统一,就是采集过去的图片有多种图片格式,而不同格局的通道不同 (jpg3 通道,png4 通道,灰色图片 1 通道),在将图片转换 tensor 时与代码里的张量形态不匹配。
.2. 解决办法
        一种办法是删除灰色或 png 图片,其二是批改代码 tf.node.decodeImage (new Uint8Array (buffer), 3)