目录
什么是模型?
什么是神经网络?
训练模型
用 TensorFlow.js 训练模型
用 TensorFlow.js 进行预测
在 TensorFlow.js 中使用预训练的模型
导入 Keras 模型
为什么要用在浏览器中?
总结
首页 web前端 js教程 怎样用 TensorFlow.js 创建基本的 AI 模型?

怎样用 TensorFlow.js 创建基本的 AI 模型?

Nov 10, 2020 pm 05:54 PM
javascript tensorflow 前端

怎样用 TensorFlow.js 创建基本的 AI 模型?

在本文中我们来研究怎样用 TensorFlow.js 创建基本的 AI 模型,并使用更复杂的模型实现一些有趣的功能。我只是刚刚开始接触人工智能,尽管不需要深入的人工智能知识,但还是需要搞清楚一些概念才行。

什么是模型?

真实世界是很复杂的,我们需要对其进行简化才能理解,可以用通过模型来进行简化,这种模型有很多种:比如世界地图,或者图表等。

1.jpg

比如要建立一个用来表示房子出租价格与房屋面积关系的模型:首先要收集一些数据:

房间数量 价格
3 131000
3 125000
4 235000
4 265000
5 535000

然后,把这些数据显示在二维图形上,把每个参数(价格,房间数量)都做为 1 个维度:

2.gif

然后我们可以画一条线,并预测 更多房间的房屋出租价格。这种模型被称为线性回归,它是机器学习中最简单的模型之一。不过这个模型还不够好:

  1. 只有 5 个数据,所以不够可靠。
  2. 只有 2 个参数(价格,房间),但是还有更多可能会影响价格的因素:比如地区、装修情况等。

可以通过添加更多的数据来解决第一个问题,比如一百万个。对于第二个问题,可以添加更多维度。在二维图表中可以很容易理解数据并画一条线,在三维图中可以使用平面:

3.jpeg

但是当数据的维度是三维呢四维甚至是 1000000 维的时候,大脑就没有办法在图表上对其进行可视化了,但是可以在维度超过三维时通过数学来计算超平面,而神经网络就是为了解决这个问题而生的。

什么是神经网络?

要解什么是神经网络,需要知道什么是神经元。真正的神经元看上去是这样的:

4.gif

神经元由以下几部分组成:

  • 树突:这是数据的输入端。
  • 轴突:这是输出端。
  • 突触(未在图中表示):该结构允许一个神经元与另一个神经元之间进行通信。它负责在轴突的神经末梢和附近神经元的树突之间传递电信号。这些突触是学习的关键,因为它们会根据用途增减电活动。

机器学习中的神经元(简化):

5.jpg

  • Inputs(输入) :输入的参数。
  • Weights(权重) :像突触一样,用来通过调节神经元更好的建立线性回归。
  • Linear function(线性函数) :每个神经元就像一个线性回归函数,对于线性回归模型,只需要一个神经元够了。
  • Activation function(激活函数) :可以用一些激活函数来将输出从标量改为另一个非线性函数。常见的有 sigmoid、RELU 和 tanh。
  • Output(输出) :应用激活函数后的计算输出。

激活函数是非常有用的,神经网络的强大主要归功于它。假如没有任何激活功能,就不可能得到智能的神经元网络。因为尽管你的神经网络中有多个神经元,但神经网络的输出始终将是线性回归。所以需要一些机制来将各个线性回归变形为非线性的来解决非线性问题。通过激活函数可以将这些线性函数转换为非线性函数:

6.jpg

训练模型

正如 2D 线性回归的例子所描述的,只需要在图中画一条线就可以预测新数据了。尽管如此,“深度学习”的思想是让我们的神经网络学会画这条线。对于一条简单的线,可以用只有一个神经元的非常简单的神经网络即可,但是对于想要做更复杂事情的模型,例如对两组数据进行分类这种操作,需要通过“训练”使网络学习怎样得到下面的内容:

7.png

这个过程并不复杂,因为它是二维的。每个模型都用来描述一个世界,但是“训练”的概念在所有模型中都非常相似。第一步是绘制一条随机线,并在算法中通过迭代对其进行改进,每次迭代中过程中修正错误。这种优化算法名为 Gradient Descent(梯度下降)(有着相同概念的算法还有更复杂的 SGD 或 ADAM 等)。每种算法(线性回归,对数回归等)都有不同的成本函数来度量误差,成本函数会始终收敛于某个点。它可以是凸函数或凹函数,但是最终要收敛在 0% 误差的点上。我们的目标就是实现这一点。

8.png

当使用梯度下降算法时,先从其成本函数的某个随机点开始,但是我们不知道它究竟在什么地方!这就像你被蒙着眼睛丢在一座山上,想要下山的话必须一步一步地走到最低点。如果地形是不规则的(例如凹函数),则下降会更加复杂。

在这里不会深入解释“梯度下降”算法,只需要记住这是训练 AI 模型过程中最小化预测误差的优化算法就足够了。这种算法需要大量的时间和 GPU 进行矩阵乘法。通常在第一次执行时很难达到这个收敛点,因此需要修正一些超参数,例如学习率(learning rate)或添加正则化(regularization)。在梯度下降迭代之后,当误差接近 0% 时,会接近收敛点。这样就创建了模型,接下来就能够进行预测了。

9.gif

用 TensorFlow.js 训练模型

TensorFlow.js 提供了一种创建神经网络的简便方法。首先用 trainModel 方法创建一个 LinearModel 类。我们将使用顺序模型。顺序模型是其中一层的输出是下一层的输入的模型,即当模型拓扑是简单的层级结构,没有分支或跳过。在 trainModel 方法内部定义层(我们仅使用一层,因为它足以解决线性回归问题):

import * as tf from '@tensorflow/tfjs';

/**
* 线性模型类
*/
export default class LinearModel {
  /**
 * 训练模型
 */
  async trainModel(xs, ys){
    const layers = tf.layers.dense({
      units: 1, // 输出空间的纬度
      inputShape: [1], // 只有一个参数
    });
    const lossAndOptimizer = {
      loss: 'meanSquaredError',
      optimizer: 'sgd', // 随机梯度下降
    };

    this.linearModel = tf.sequential();
    this.linearModel.add(layers); // 添加一层
    this.linearModel.compile(lossAndOptimizer);

    // 开始模型训练
    await this.linearModel.fit(
      tf.tensor1d(xs),
      tf.tensor1d(ys),
    );
  }

  //...
}
登录后复制

使用这个类进行训练:

const model = new LinearModel()

// xs 与 ys 是 数组成员(x-axis 与 y-axis)
await model.trainModel(xs, ys)
登录后复制

训练结束后就可以开始预测了。

用 TensorFlow.js 进行预测

尽管在训练模型时需要事先定义一些超参数,但是进行一般的预测还是很容易的。通过下面的代码就够了:

import * as tf from '@tensorflow/tfjs';

export default class LinearModel {
  ... //前面训练模型的代码

  predict(value){
    return Array.from(
      this.linearModel
      .predict(tf.tensor2d([value], [1, 1]))
      .dataSync()
    )
  }
}
登录后复制

现在就可以预测了:

const prediction = model.predict(500) // 预测数字 500
console.log(prediction) // => 420.423
登录后复制

10.gif

在 TensorFlow.js 中使用预训练的模型

训练模型是最难的部分。首先对数据进行标准化来进行训练,还需要正确的设定所有超参数等等。对于咱们初学者,可以直接用那些预先训练好的模型。 TensorFlow.js 可以使用很多预训练的模型,还可以导入使用 TensorFlow 或 Keras 创建的外部模型。例如可以直接用 posenet 模型(实时人体姿态评估)做一些有意思的项目:

11.gif

这个 Demo 的代码:https://github.com/aralroca/posenet-d3

它用起来很容易:

import * as posenet from '@tensorflow-models/posenet'

// 设置一些常数
const imageScaleFactor = 0.5
const outputStride = 16
const flipHorizontal = true
const weight = 0.5

// 加载模型
const net = await posenet.load(weight)

// 进行预测
const poses = await net.estimateSinglePose(
  imageElement,
  imageScaleFactor,
  flipHorizontal,
  outputStride
)
登录后复制

这个 JSON 是 pose 变量:

{
  "score": 0.32371445304906,
  "keypoints": [
    {
      "position": {
        "y": 76.291801452637,
        "x": 253.36747741699
      },
      "part": "nose",
      "score": 0.99539834260941
    },
    {
      "position": {
        "y": 71.10383605957,
        "x": 253.54365539551
      },
      "part": "leftEye",
      "score": 0.98781454563141
    }
    // 后面还有: rightEye, leftEar, rightEar, leftShoulder, rightShoulder
    // leftElbow, rightElbow, leftWrist, rightWrist, leftHip, rightHip,
    // leftKnee, rightKnee, leftAnkle, rightAnkle...
  ]
}
登录后复制

从官方的 demo 可以看得到,用这个模型可以开发出很多有趣的项目。

怎样用 TensorFlow.js 创建基本的 AI 模型?

这个项目的源代码:https://github.com/aralroca/fishFollow-posenet-tfjs

导入 Keras 模型

可以把外部模型导入 TensorFlow.js。下面是一个用 Keras 模型(h5格式)进行数字识别的程序。首先要用 tfjs_converter 对模型的格式进行转换。

pip install tensorflowjs
登录后复制

使用转换器:

tensorflowjs_converter --input_format keras keras/cnn.h5 src/assets
登录后复制

最后,把模型导入到 JS 代码中:

// 载入模型
const model = await tf.loadModel('./assets/model.json')

// 准备图片
let img = tf.fromPixels(imageData, 1)
img = img.reshape([1, 28, 28, 1])
img = tf.cast(img, 'float32')

// 进行预测
const output = model.predict(img)
登录后复制

只需要几行代码行就完成了。当然还可以在代码中添加更多的逻辑来实现更多功能,例如可以把数字写在 canvas 上,然后得到其图像来进行预测。

13.gif

这个项目的源代码:https://github.com/aralroca/MNIST_React_TensorFlowJS

为什么要用在浏览器中?

由于设备的不同,在浏览器中训练模型时,效率可能很低。用 TensorFlow.js 利用 WebGL 在后台训练模型,比用 Python 版的 TensorFlow 慢 1.5 ~ 2倍。

但是在 TensorFlow.js 出现之前,没有能直接在浏览器中使用机器学习模型的 API,现在则可以在浏览器应用中离线训练和使用模型。而且预测速度更快,因为不需要向服务器发送请求。另一个好处是成本低,因为所有这些计算都是在客户端完成的。

总结

  • 模型是表示现实世界的一种简化方式,可以使用它来进行预测。
  • 可以用神经网络创建模型。
  • TensorFlow.js 是创建神经网络的简便工具。

英文原文地址:https://aralroca.com/blog/first-steps-with-tensorflowjs

作者:Aral Roca

更多编程相关知识,请访问:编程课程!!

以上是怎样用 TensorFlow.js 创建基本的 AI 模型?的详细内容。更多信息请关注PHP中文网其他相关文章!

本站声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover

AI Clothes Remover

用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool

Undress AI Tool

免费脱衣服图片

Clothoff.io

Clothoff.io

AI脱衣机

AI Hentai Generator

AI Hentai Generator

免费生成ai无尽的。

热门文章

R.E.P.O.能量晶体解释及其做什么(黄色晶体)
4 周前 By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳图形设置
4 周前 By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.如果您听不到任何人,如何修复音频
4 周前 By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.聊天命令以及如何使用它们
4 周前 By 尊渡假赌尊渡假赌尊渡假赌

热工具

记事本++7.3.1

记事本++7.3.1

好用且免费的代码编辑器

SublimeText3汉化版

SublimeText3汉化版

中文版,非常好用

禅工作室 13.0.1

禅工作室 13.0.1

功能强大的PHP集成开发环境

Dreamweaver CS6

Dreamweaver CS6

视觉化网页开发工具

SublimeText3 Mac版

SublimeText3 Mac版

神级代码编辑软件(SublimeText3)

WebSocket与JavaScript:实现实时监控系统的关键技术 WebSocket与JavaScript:实现实时监控系统的关键技术 Dec 17, 2023 pm 05:30 PM

WebSocket与JavaScript:实现实时监控系统的关键技术引言:随着互联网技术的快速发展,实时监控系统在各个领域中得到了广泛的应用。而实现实时监控的关键技术之一就是WebSocket与JavaScript的结合使用。本文将介绍WebSocket与JavaScript在实时监控系统中的应用,并给出代码示例,详细解释其实现原理。一、WebSocket技

PHP与Vue:完美搭档的前端开发利器 PHP与Vue:完美搭档的前端开发利器 Mar 16, 2024 pm 12:09 PM

PHP与Vue:完美搭档的前端开发利器在当今互联网高速发展的时代,前端开发变得愈发重要。随着用户对网站和应用的体验要求越来越高,前端开发人员需要使用更加高效和灵活的工具来创建响应式和交互式的界面。PHP和Vue.js作为前端开发领域的两个重要技术,搭配起来可以称得上是完美的利器。本文将探讨PHP和Vue的结合,以及详细的代码示例,帮助读者更好地理解和应用这两

前端面试官常问的问题 前端面试官常问的问题 Mar 19, 2024 pm 02:24 PM

在前端开发面试中,常见问题涵盖广泛,包括HTML/CSS基础、JavaScript基础、框架和库、项目经验、算法和数据结构、性能优化、跨域请求、前端工程化、设计模式以及新技术和趋势。面试官的问题旨在评估候选人的技术技能、项目经验以及对行业趋势的理解。因此,应试者应充分准备这些方面,以展现自己的能力和专业知识。

JavaScript和WebSocket:打造高效的实时天气预报系统 JavaScript和WebSocket:打造高效的实时天气预报系统 Dec 17, 2023 pm 05:13 PM

JavaScript和WebSocket:打造高效的实时天气预报系统引言:如今,天气预报的准确性对于日常生活以及决策制定具有重要意义。随着技术的发展,我们可以通过实时获取天气数据来提供更准确可靠的天气预报。在本文中,我们将学习如何使用JavaScript和WebSocket技术,来构建一个高效的实时天气预报系统。本文将通过具体的代码示例来展示实现的过程。We

简易JavaScript教程:获取HTTP状态码的方法 简易JavaScript教程:获取HTTP状态码的方法 Jan 05, 2024 pm 06:08 PM

JavaScript教程:如何获取HTTP状态码,需要具体代码示例前言:在Web开发中,经常会涉及到与服务器进行数据交互的场景。在与服务器进行通信时,我们经常需要获取返回的HTTP状态码来判断操作是否成功,根据不同的状态码来进行相应的处理。本篇文章将教你如何使用JavaScript获取HTTP状态码,并提供一些实用的代码示例。使用XMLHttpRequest

Django是前端还是后端?一探究竟! Django是前端还是后端?一探究竟! Jan 19, 2024 am 08:37 AM

Django是一个Python编写的web应用框架,它强调快速开发和干净方法。尽管Django是一个web框架,但是要回答Django是前端还是后端这个问题,需要深入理解前后端的概念。前端是指用户直接和交互的界面,后端是指服务器端的程序,他们通过HTTP协议进行数据的交互。在前端和后端分离的情况下,前后端程序可以独立开发,分别实现业务逻辑和交互效果,数据的交

Go语言前端技术探秘:前端开发新视野 Go语言前端技术探秘:前端开发新视野 Mar 28, 2024 pm 01:06 PM

Go语言作为一种快速、高效的编程语言,在后端开发领域广受欢迎。然而,很少有人将Go语言与前端开发联系起来。事实上,使用Go语言进行前端开发不仅可以提高效率,还能为开发者带来全新的视野。本文将探讨使用Go语言进行前端开发的可能性,并提供具体的代码示例,帮助读者更好地了解这一领域。在传统的前端开发中,通常会使用JavaScript、HTML和CSS来构建用户界面

Django:前端和后端开发都能搞定的神奇框架! Django:前端和后端开发都能搞定的神奇框架! Jan 19, 2024 am 08:52 AM

Django:前端和后端开发都能搞定的神奇框架!Django是一个高效、可扩展的Web应用程序框架。它能够支持多种Web开发模式,包括MVC和MTV,可以轻松地开发出高质量的Web应用程序。Django不仅支持后端开发,还能够快速构建出前端的界面,通过模板语言,实现灵活的视图展示。Django把前端开发和后端开发融合成了一种无缝的整合,让开发人员不必专门学习

See all articles