关于tensorflow:Tensorflow基本概念

5次阅读

共计 1572 个字符,预计需要花费 4 分钟才能阅读完成。

Tensorflow 基本概念

1.Tensor

Tensorflow 张量,是 Tensorflow 中最根底的概念,也是最次要的数据结构。它是一个 N 维数组。

2.Variable

Tensorflow 变量,个别用于示意图中的各计算参数,包含矩阵,向量等。它在图中有固定的地位。

3.placeholder

Tensorflow 占位符,用于示意输入输出数据的格局,容许传入指定的类型和形态的数据。

4.Session

Tensorflow 会话,在 Tensorflow 中是计算图的具体执行者,与图进行理论的交互。

5.Operation

Tensorflow 操作,是 Tensorflow 图中的节点,它的输出和输入都是Tensor。它的操作都是实现各种操作,包含算数操作、矩阵操作、神经网络构建操作等。

6.Queue

Tensorflow 队列,也是图中的一个节点,是一种有状态的节点。

7.QueueRunner

队列管理器 ,通常会应用 多个线程 来读取数据,而后应用 一个线程 来应用数据。应用队列管理器来治理这些读写队列的线程。

8.Coordinator

应用 QueueRunner 时,因为入队和出队由各自线程实现,且未进行同步通信,导致程序无奈失常完结的状况。为了实现线程之间的 同步,须要应用Coordinator

Tensorflow 程序步骤

(一)加载训练数据

1. 生成或导入样本数据集。
2. 归一化解决。
3. 划分样本数据集为 训练样本集 测试样本集

(二)构建训练模型

1. 初始化超参数
2. 初始化变量和占位符
3. 定义模型构造
4. 定义损失函数

(三)进行数据训练

1. 初始化模型
2. 加载数据进行训练

(四)评估和预测

1. 评估机器学习模型
2. 调优超参数
3. 预测后果

加载数据

在 Tensorflow 中加载数据的形式一共有三种:预加载数据、填充数据和从文件读取数据。

预加载数据

在 Tensorflow 中定义常量或变量来保留所有数据,例如:

a = tf.constant([1, 2])
b = tf.constant([3, 4])
x = tf.add(a, b)

因为常数会间接存储在数据流图数据结构中,在训练过程中,这个构造体可能会被复制屡次,从而导致内存的大量耗费。

填充数据

将数据填充到任意一个张量中。而后通过会话 run() 函数中的 feed_dict 参数进行获取数据:

数据量大 时,填充数据的形式也存在耗费内存的问题。

从 CVS 文件中读取数据

要存文件中读取数据,首先须要应用读取器将数据读取到队列中,而后从队列中获取数据进行解决:

1. 创立队列
2. 创立读取器获取数据
3. 解决数据

读取 TFRecords 数据

Tensorflow 针对解决 数据量微小 的利用场景进行了优化,定义了 TFRecords 格局。

采纳这种读取形式读取数据分为两个步骤:

1. 把样本数据转换为 TFRecords 二进制文件
2. 读取TFRecords 格局。

存储和加载模型

Tensorflow 中提供了 tf.train.Saver 类实现训练模型的保留和加载。

存储模型

在模型的设计和训练的过程中,会耗费大量的工夫。为了升高训练过程中意外状况产生造成的不良影响,所以会对训练过程中模型进行定期存储。(模型复用,节俭整体训练工夫)

saver = tf.train.Saver(max_to_keep, keep_checkpoint_event_n_hours)

存储的模型,会生成四个文件:

加载模型

为了保障意外中断的模型可能持续训练以及训练实现的模型加载在其余数据上间接应用,会对模型进行加载应用。

加载存储好的模型,包含了两个步骤:

1. 加载模型:

saver = tf.train.import_meta_graph("my_test_model-100.meta")

2. 加载训练参数:

saver.restore(sess, tf.train.latest_checkpoint('./'))
正文完
 0