关于tensorflow:我通过-tensorflow-预测了博客的粉丝数

10次阅读

共计 2942 个字符,预计需要花费 8 分钟才能阅读完成。

前言

因为最近接触了 tensorflow.js,出于试一下的心态,想通过线性回归预测一下博客的粉丝走向和数量,后果翻车了。尽管场景用错中央,然而整个实战办法用在身高体重等方面的预测还是有可行性,所以就记录下来了。

需要

依据某博客或论坛,抓取一下博主的拜访总量和粉丝总量,剖析其关联,训练数据,最初通过输出指定拜访数量预测吸粉总数。

Tensorflow.js

Tensorflow.js 是一个能够在浏览器或 Node 环境利用 JavaScript 语法运行深度学习。让前端就能够实现相似依据图片类型的含糊搜寻,语音辨认管制网页,图片的人像辨认等性能,既加重服务器训练压力,也爱护了用户隐衷 (在非凡场景下,不必将图片传到服务器后做人像标识)。

技术清单

  1. tensorflow.js
  2. parcel
  3. tfjs-vis

实战

实战是须要本地有 Node 环境,并且装置了 npm 等包管理工具,对于这些的装置这里就略过了。次要是我的项目的搭起,线性回归的编码以及运行后果。
. 我的项目搭建

(1). 创立我的项目目录和 package.json


{
  "name": "tensorflow-test",
  "version": "1.0.0",
  "description": "","main":"index.js","dependencies": {"@tensorflow-models/speech-commands":"^0.4.0","@tensorflow/tfjs":"^1.3.1","@tensorflow/tfjs-node":"^1.2.9","@tensorflow/tfjs-vis":"^1.2.0"},"devDependencies": {},"scripts": {"test":"echo \"Error: no test specified\" && exit 1"},"author":"",
  "license": "ISC",
  "browserslist": ["last 1 Chrome version"]
}

(2). 切换到当前目录,运行 npm install 进行装置
(3). 在当前目录下创立目录和运行文件。

(4). 装置 parcel,一个打包工具。npm install -g parcel-bundler

  1. 编码

(1). 页面须要有数据训练过程图和模型下载按钮。

<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title> 粉丝数量预测 </title>
</head>
<body>
  <button onclick="download()"> 保留模型 </button>
</body>
<script src="script.js"></script>
</html>

(2). 线性回归根本流程

(3). 编码

import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';

window.onload = async () => {

    // 浏览量 - 粉丝量
    const flows = [20333,25759,101190,86334,265252,1366198,166114,109979,371423,1291843,1239191,225711,1163189,882702,31415,678478,545108,1304729,73479,2515393,1714555,344847,3147811,1626033,3702785,377376,258472,312769,540292,616665,1207153,2577882,11564515,28231,328984,585611,595275];
    const fans = [0,494,6618,3411,12023,7791,65,7109,14014,11840,1202,266,7915,7503,2216,33265,284,34849,4188,41721,25384,1269,62207,20754,192980,28601,7645,1779,13112,10824,4612,548,2311,44,34,259,150];

    tfvis.render.scatterplot({name: 'csdn 浏览量和粉丝量关联'},
        {values: flows.map((x, i) => ({x,y:fans[i]}))},
        {xAxisDomain: [20333, 11600000],
            yAxisDomain: [0, 200000]
        }
    );

    // 对数据集进行归一化解决
    const inputs = tf.tensor(flows).sub(20333).div(11544182);
    const lables = tf.tensor(fans).div(192980);

    const model = tf.sequential();

    // 给模型增加层级和神经元
    //model.add(tf.layers.dense({unit: 1, inputShape: [1]}));
    model.add(tf.layers.dense({ units: 1, inputShape: [1] }));

    // 配置模型训练,设置损失计算函数 (均方差等),优化器的 SGD 配置
    model.compile({loss: tf.losses.meanSquaredError, optimizer: tf.train.sgd(0.1)});

    // 开始训练
    // await model.fit(
    //     inputs,
    //     lables,
    //     {
    //         batchSize:37,
 //            epochs:200,
    //         callbacks: tfvis.show.fitCallbacks(//                 {name: '训练过程'},
    //                 ['loss', 'val_loss', 'acc', 'val_acc'],
 //                      {callbacks: ['onEpochEnd'] }
    //         )
    //     }
    // );

    await model.fit(
        inputs,
        lables,
        {
            batchSize:37,
            epochs:200,
            callbacks: tfvis.show.fitCallbacks({ name: '训练过程'},
                ['loss']
            )
        }
    );

    // 模型预测,输出浏览量输入预测的粉丝数
    const output = model.predict(tf.tensor([165265]).sub(20333).div(11544182));
    //const output = model.predict(tf.tensor([180]).sub(150).div(20));

    alert('165265 预测粉丝数'+output.mul(192980).dataSync()[0]);


    // 保留模型
  window.download = async () => {await model.save('downloads://my-model');
  }


};

(4). 打包并运行 parcel tf_test/index.html

(5). 运行成果

正文完
 0