共计 6462 个字符,预计需要花费 17 分钟才能阅读完成。
前言
只有光头才能变强。文本已收录至我的 GitHub 仓库,欢迎 Star:https://github.com/ZhongFuCheng3y/3y
回顾前面:
从零开始学 TensorFlow【01- 搭建环境、HelloWorld 篇】
什么是 TensorFlow?
众所周知,要训练出一个模型,首先我们得有数据。我们第一个例子中,直接使用 dataset 的 api 去加载 mnist 的数据。(minst 的数据要么我们是提前下载好,放在对应的目录上,要么就根据他给的 url 直接从网上下载)。
一般来说,我们使用 TensorFlow 是从 TFRecord 文件中读取数据的。
TFRecord 文件格式是一种面向记录的简单二进制格式,很多 TensorFlow 应用采用此格式来训练数据
所以,这篇文章来聊聊怎么读取 TFRecord 文件的数据。
一、入门对数据集的数据进行读和写
首先,我们来体验一下怎么造一个 TFRecord 文件,怎么从 TFRecord 文件中读取数据,遍历 (消费) 这些数据。
1.1 造一个 TFRecord 文件
现在,我们还没有 TFRecord 文件,我们可以自己简单写一个:
def write_sample_to_tfrecord():
gmv_values = np.arange(10)
click_values = np.arange(10)
label_values = np.arange(10)
with tf.python_io.TFRecordWriter(“/Users/zhongfucheng/data/fashin/demo.tfrecord”, options=None) as writer:
for _ in range(10):
feature_internal = {
“gmv”: tf.train.Feature(float_list=tf.train.FloatList(value=[gmv_values[_]])),
“click”: tf.train.Feature(int64_list=tf.train.Int64List(value=[click_values[_]])),
“label”: tf.train.Feature(int64_list=tf.train.Int64List(value=[label_values[_]]))
}
features_extern = tf.train.Features(feature=feature_internal)
# 使用 tf.train.Example 将 features 编码数据封装成特定的 PB 协议格式
# example = tf.train.Example(features=tf.train.Features(feature=features_extern))
example = tf.train.Example(features=features_extern)
# 将 example 数据系列化为字符串
example_str = example.SerializeToString()
# 将系列化为字符串的 example 数据写入协议缓冲区
writer.write(example_str)
if __name__ == ‘__main__’:
write_sample_to_tfrecord()
我相信大家代码应该是能够看得懂的,其实就是分了几步:
生成 TFRecord Writer
tf.train.Feature 生成协议信息
使用 tf.train.Example 将 features 编码数据封装成特定的 PB 协议格式
将 example 数据系列化为字符串
将系列化为字符串的 example 数据写入协议缓冲区
参考资料:
https://zhuanlan.zhihu.com/p/31992460
ok,现在我们就有了一个 TFRecord 文件啦。
1.2 读取 TFRecord 文件
其实就是通过 tf.data.TFRecordDataset 这个 api 来读取到 TFRecord 文件,生成处 dataset 对象
对 dataset 进行处理(shape 处理,格式处理 … 等等)
使用迭代器对 dataset 进行消费(遍历)
demo 代码如下:
import tensorflow as tf
def read_tensorflow_tfrecord_files():
# 定义消费缓冲区协议的 parser, 作为 dataset.map()方法中传入的 lambda:
def _parse_function(single_sample):
features = {
“gmv”: tf.FixedLenFeature([1], tf.float32),
“click”: tf.FixedLenFeature([1], tf.int64), # ()或者 [] 没啥影响
“label”: tf.FixedLenFeature([1], tf.int64)
}
parsed_features = tf.parse_single_example(single_sample, features=features)
# 对 parsed 之后的值进行 cast.
gmv = tf.cast(parsed_features[“gmv”], tf.float64)
click = tf.cast(parsed_features[“click”], tf.float64)
label = tf.cast(parsed_features[“label”], tf.float64)
return gmv, click, label
# 开始定义 dataset 以及解析 tfrecord 格式
filenames = tf.placeholder(tf.string, shape=[None])
# 定义 dataset 和 一些列 trasformation method
dataset = tf.data.TFRecordDataset(filenames)
parsed_dataset = dataset.map(_parse_function) # 消费缓冲区需要定义在 dataset 的 map 函数中
batchd_dataset = parsed_dataset.batch(3)
# 创建 Iterator
sample_iter = batchd_dataset.make_initializable_iterator()
# 获取 next_sample
gmv, click, label = sample_iter.get_next()
training_filenames = [
“/Users/zhongfucheng/data/fashin/demo.tfrecord”]
with tf.Session() as session:
# 初始化带参数的 Iterator
session.run(sample_iter.initializer, feed_dict={filenames: training_filenames})
# 读取文件
print(session.run(gmv))
if __name__ == ‘__main__’:
read_tensorflow_tfrecord_files()
无意外的话,我们可以输出这样的结果:
[[0.]
[1.]
[2.]]
ok,现在我们已经大概知道怎么写一个 TFRecord 文件,以及怎么读取 TFRecord 文件的数据,并且消费这些数据了。
二、epoch 和 batchSize 术语解释
我在学习 TensorFlow 翻阅资料时,经常看到一些机器学习的术语,由于自己没啥机器学习的基础,所以很多时候看到一些专业名词就开始懵逼了。
2.1epoch
当一个完整的数据集通过了神经网络一次并且返回了一次,这个过程称为一个 epoch。
这可能使我们跟 dataset.repeat()方法联系起来,这个方法可以使当前数据集重复一遍。比如说,原有的数据集是 [1,2,3,4,5],如果我调用 dataset.repeat(2) 的话,那么我们的数据集就变成了[1,2,3,4,5],[1,2,3,4,5]
所以会有个说法:假设原先的数据是一个 epoch,使用 repeat(5)就可以将之变成 5 个 epoch
2.2batchSize
一般来说我们的数据集都是比较大的,无法一次性将整个数据集的数据喂进神经网络中,所以我们会将数据集分成好几个部分。每次喂多少条样本进神经网络,这个叫做 batchSize。
在 TensorFlow 也提供了方法给我们设置:dataset.batch(),在 API 中是这样介绍 batchSize 的:
representing the number of consecutive elements of this dataset to combine in a single batch
我们一般在每次训练之前,会将整个数据集的顺序打乱,提高我们模型训练的效果。这里我们用到的 api 是:dataset.shffle();
三、再来聊聊 dataset
我从官网的介绍中截了一个 dataset 的方法图(部分):
dataset 的功能主要有以下三种:
创建 dataset 实例
通过文件创建(比如 TFRecord)
通过内存创建
对数据集的数据进行变换
比如上面的 batch(),常见的 map(),flat_map(),zip(),repeat()等等
文档中一般都有给出例子,跑一下一般就知道对应的意思了。
创建迭代器,遍历数据集的数据
3.1 聊聊迭代器
迭代器可以分为四种:
单次。对数据集进行一次迭代,不支持参数化
可初始化迭代
使用前需要进行初始化,支持传入参数。面向的是同一个 DataSet
可重新初始化:同一个 Iterator 从不同的 DataSet 中读取数据
DataSet 的对象具有相同的结构,可以使用 tf.data.Iterator.from_structure 来进行初始化
问题:每次 Iterator 切换时,数据都从头开始打印了
可馈送(也是通过对象相同的结果来创建的迭代器)
可让您在两个数据集之间切换的可馈送迭代器
通过一个 string handler 来实现。
可馈送的 Iterator 在不同的 Iterator 切换的时候,可以做到不从头开始。
简单总结:
1、单次 Iterator,它最简单,但无法重用,无法处理数据集参数化的要求。
2、可以初始化的 Iterator,它可以满足 Dataset 重复加载数据,满足了参数化要求。
3、可重新初始化的 Iterator,它可以对接不同的 Dataset,也就是可以从不同的 Dataset 中读取数据。
4、可馈送的 Iterator,它可以通过 feeding 的方式,让程序在运行时候选择正确的 Iterator, 它和可重新初始化的 Iterator 不同的地方就是它的数据在不同的 Iterator 切换时,可以做到不重头开始读取数据。
string handler(可馈送的 Iterator)这种方式是最常使用的,我当时也写了一个 Demo 来使用了一下,代码如下:
def read_tensorflow_tfrecord_files():
# 开始定义 dataset 以及解析 tfrecord 格式.
train_filenames = tf.placeholder(tf.string, shape=[None])
vali_filenames = tf.placeholder(tf.string, shape=[None])
# 加载 train_dataset batch_inputs 这个方法每个人都不一样的,这个方法我就不给了。
train_dataset = batch_inputs([
train_filenames], batch_size=5, type=False,
num_epochs=2, num_preprocess_threads=3)
# 加载 validation_dataset batch_inputs 这个方法每个人都不一样的,这个方法我就不给了。
validation_dataset = batch_inputs([vali_filenames
], batch_size=5, type=False,
num_epochs=2, num_preprocess_threads=3)
# 创建出 string_handler()的迭代器(通过相同数据结构的 dataset 来构建)
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, train_dataset.output_types, train_dataset.output_shapes)
# 有了迭代器就可以调用 next 方法了。
itemid = iterator.get_next()
# 指定哪种具体的迭代器,有单次迭代的,有初始化的。
training_iterator = train_dataset.make_initializable_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
# 定义出 placeholder 的值
training_filenames = [
“/Users/zhongfucheng/tfrecord_test/data01aa”]
validation_filenames = [“/Users/zhongfucheng/tfrecord_validation/part-r-00766”]
with tf.Session() as sess:
# 初始化迭代器
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
for _ in range(2):
sess.run(training_iterator.initializer, feed_dict={train_filenames: training_filenames})
print(“this is training iterator —-“)
for _ in range(5):
print(sess.run(itemid, feed_dict={handle: training_handle}))
sess.run(validation_iterator.initializer,
feed_dict={vali_filenames: validation_filenames})
print(“this is validation iterator “)
for _ in range(5):
print(sess.run(itemid, feed_dict={vali_filenames: validation_filenames, handle: validation_handle}))
if __name__ == ‘__main__’:
read_tensorflow_tfrecord_files()
参考资料:
https://blog.csdn.net/briblue/article/details/80962728
3.2 dataset 参考资料
在翻阅资料时,发现写得不错的一些博客:
https://www.jianshu.com/p/91803a119f18
https://irvingzhang0512.github.io/2018/04/19/tensorflow-api-2/
http://www.feiguyunai.com/index.php/2017/12/25/pyhtonai-ml-dataprocess-datasetapi/
最后
乐于输出干货的 Java 技术公众号:Java3y。公众号内有 200 多篇原创技术文章、海量视频资源、精美脑图,不妨来关注一下!
下一篇文章打算讲讲如何理解 axis~
觉得我的文章写得不错,不妨点一下赞!