关于人工智能:MindSpore易点通精讲系列数据集加载之TFRecordDataset

37次阅读

共计 10521 个字符,预计需要花费 27 分钟才能阅读完成。

Dive Into MindSpore — TFRecordDataset For Dataset LoadMindSpore 易点通·精讲系列 – 数据集加载之 TFRecordDataset 本文开发环境 Ubuntu 20.04Python 3.8MindSpore 1.7.0 本文内容摘要背景介绍先看文档生成 TFRecord 数据加载本文总结本文参考 1. 背景介绍 TFRecord 格局是 TensorFlow 官网设计的一种数据格式。TFRecord 格局是一种用于存储二进制记录序列的简略格局,该格局可能更好的利用内存,外部蕴含多个 tf.train.Example,在一个 Examples 音讯体中蕴含一系列的 tf.train.feature 属性,而每一个 feature 是一个 key-value 的键值对,其中 key 是 string 类型,value 的取值有三种:bytes_list:能够存储 string 和 byte 两种数据类型 float_list:能够存储 float(float32) 和 double(float64) 两种数据类型 int64_list:能够存储 bool, enum, int32, uint32, int64, uint64 数据类型下面简略介绍了 TFRecord 的常识,上面咱们就要进入正题,来谈谈 MindSpore 中对 TFRecord 格局的反对。2. 先看文档老传统,先来看看官网对 API 的形容。

上面对主要参数做简略介绍:dataset_files — 数据集文件门路。schema — 读取模式策略,艰深来说就是要读取的 tfrecord 文件内的数据内容格局。能够通过 json 或者 Schema 传入。默认为 None 不指定。columns_list — 指定读取的具体数据列。默认全副读取。num_samples — 指定读取进去的样本数量。shuffle — 是否对数据进行打乱,可参考之前的文章解读。3. 生成 TFRecord 本文应用的是 THUCNews 数据集,如果须要将该数据集用于商业用途,请分割数据集作者。数据集启智社区下载地址因为下文须要用到 TFRecord 数据集来做加载,本节先来生成 TFRecord 数据集。对 TensorFlow 不理解的读者能够间接照搬代码即可。生成 TFRecord 代码如下:import codecs
import os
import re
import six
import tensorflow as tf

from collections import Counter

def _int64_feature(values):

"""Returns a TF-Feature of int64s.

Args:
    values: A scalar or list of values.

Returns:
    A TF-Feature.
"""
if not isinstance(values, (tuple, list)):
    values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def _float32_feature(values):

"""Returns a TF-Feature of float32s.

Args:
    values: A scalar or list of values.

Returns:
    A TF-Feature.
 """
if not isinstance(values, (tuple, list)):
    values = [values]
return tf.train.Feature(float_list=tf.train.FloatList(value=values))

def _bytes_feature(values):

"""Returns a TF-Feature of bytes.
Args:
    values: A scalar or list of values.

Returns:
    A TF-Feature
"""
if not isinstance(values, (tuple, list)):
    values = [values]
return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

def convert_to_feature(values):

"""Convert to TF-Feature based on the type of element in values.

Args:
    values: A scalar or list of values.

Returns:
    A TF-Feature.
"""
if not isinstance(values, (tuple, list)):
    values = [values]

if isinstance(values[0], int):
    return _int64_feature(values)
elif isinstance(values[0], float):
    return _float32_feature(values)
elif isinstance(values[0], bytes):
    return _bytes_feature(values)
else:
    raise ValueError("feature type {0} is not supported now !".format(type(values[0])))

def dict_to_example(dictionary):

"""Converts a dictionary of string->int to a tf.Example."""
features = {}
for k, v in six.iteritems(dictionary):
    features[k] = convert_to_feature(values=v)
return tf.train.Example(features=tf.train.Features(feature=features))

def get_txt_files(data_dir):

cls_txt_dict = {}
txt_file_list = []

# get files list and class files list.
sub_data_name_list = next(os.walk(data_dir))[1]
sub_data_name_list = sorted(sub_data_name_list)
for sub_data_name in sub_data_name_list:
    sub_data_dir = os.path.join(data_dir, sub_data_name)
    data_name_list = next(os.walk(sub_data_dir))[2]
    data_file_list = [os.path.join(sub_data_dir, data_name) for data_name in data_name_list]
    cls_txt_dict[sub_data_name] = data_file_list
    txt_file_list.extend(data_file_list)
    num_data_files = len(data_file_list)
    print("{}: {}".format(sub_data_name, num_data_files), flush=True)
num_txt_files = len(txt_file_list)
print("total: {}".format(num_txt_files), flush=True)

return cls_txt_dict, txt_file_list

def get_txt_data(txt_file):

with codecs.open(txt_file, "r", "UTF8") as fp:
    txt_content = fp.read()
txt_data = re.sub("\s+", " ", txt_content)

return txt_data

def build_vocab(txt_file_list, vocab_size=7000):

counter = Counter()
for txt_file in txt_file_list:
    txt_data = get_txt_data(txt_file)
    counter.update(txt_data)

num_vocab = len(counter)
if num_vocab < vocab_size - 1:
    real_vocab_size = num_vocab + 2
else:
    real_vocab_size = vocab_size

# pad_id is 0, unk_id is 1
vocab_dict = {word_freq[0]: ix + 1 for ix, word_freq in enumerate(counter.most_common(real_vocab_size - 2))}

print("real vocab size: {}".format(real_vocab_size), flush=True)
print("vocab dict:\n{}".format(vocab_dict), flush=True)

return vocab_dict

def make_tfrecords(

    data_dir, tfrecord_dir, vocab_size=7000, min_seq_length=10, max_seq_length=800,
    num_train=8, num_test=2, start_fid=0):
# get txt files
cls_txt_dict, txt_file_list = get_txt_files(data_dir=data_dir)
# map word to id
vocab_dict = build_vocab(txt_file_list=txt_file_list, vocab_size=vocab_size)
# map class to id
class_dict = {class_name: ix for ix, class_name in enumerate(cls_txt_dict.keys())}

train_writers = []
for fid in range(start_fid, num_train+start_fid):
    tfrecord_file = os.path.join(tfrecord_dir, "train_{:04d}.tfrecord".format(fid))
    writer = tf.io.TFRecordWriter(tfrecord_file)
    train_writers.append(writer)

test_writers = []
for fid in range(start_fid, num_test+start_fid):
    tfrecord_file = os.path.join(tfrecord_dir, "test_{:04d}.tfrecord".format(fid))
    writer = tf.io.TFRecordWriter(tfrecord_file)
    test_writers.append(writer)

pad_id = 0
unk_id = 1
num_samples = 0
num_train_samples = 0
num_test_samples = 0
for class_name, class_file_list in cls_txt_dict.items():
    class_id = class_dict[class_name]
    num_class_pass = 0
    for txt_file in class_file_list:
        txt_data = get_txt_data(txt_file=txt_file)
        txt_len = len(txt_data)
        if txt_len < min_seq_length:
            num_class_pass += 1
            continue
        if txt_len > max_seq_length:
            txt_data = txt_data[:max_seq_length]
            txt_len = max_seq_length
        word_ids = []
        for word in txt_data:
            word_id = vocab_dict.get(word, unk_id)
            word_ids.append(word_id)
        for _ in range(max_seq_length - txt_len):
            word_ids.append(pad_id)

        example = dict_to_example({"input": word_ids, "class": class_id})
        num_samples += 1
        if num_samples % 10 == 0:
            num_test_samples += 1
            writer_id = num_test_samples % num_test
            test_writers[writer_id].write(example.SerializeToString())
        else:
            num_train_samples += 1
            writer_id = num_train_samples % num_train
            train_writers[writer_id].write(example.SerializeToString())
    print("{} pass: {}".format(class_name, num_class_pass), flush=True)

for writer in train_writers:
    writer.close()
for writer in test_writers:
    writer.close()

print("num samples: {}".format(num_samples), flush=True)
print("num train samples: {}".format(num_train_samples), flush=True)
print("num test samples: {}".format(num_test_samples), flush=True)

def main():

data_dir = "{your_data_dir}"
tfrecord_dir = "{your_tfrecord_dir}"
make_tfrecords(data_dir=data_dir, tfrecord_dir=tfrecord_dir)

if name == “__main__”:

main()

复制将以上代码保留到文件 make_tfrecord.py,运行命令:留神:须要替换 data_dir 和 tfrecord_dir 为集体目录。python3 make_tfrecord.py
复制应用 tree 命令查看生成的 TFRecord 数据目录,输入内容如下:.
├── test_0000.tfrecord
├── test_0001.tfrecord
├── train_0000.tfrecord
├── train_0001.tfrecord
├── train_0002.tfrecord
├── train_0003.tfrecord
├── train_0004.tfrecord
├── train_0005.tfrecord
├── train_0006.tfrecord
└── train_0007.tfrecord

0 directories, 10 files
复制 4. 数据加载有了 3 中的 TFRecord 数据集,上面来介绍如何在 MindSpore 中应用该数据集。4.1 schema 应用 4.1.1 不指定 schema 首先来看看对于参数 schema 不指定,即采纳默认值的状况下,是否正确读取数据。代码如下:import os

from mindspore.common import dtype as mstype
from mindspore.dataset import Schema
from mindspore.dataset import TFRecordDataset

def get_tfrecord_files(tfrecord_dir, file_suffix=”tfrecord”, is_train=True):

if not os.path.exists(tfrecord_dir):
    raise ValueError("tfrecord directory: {} not exists!".format(tfrecord_dir))

if is_train:
    file_prefix = "train"
else:
    file_prefix = "test"

data_sources = []
for parent, _, filenames in os.walk(tfrecord_dir):
    for filename in filenames:
        if not filename.startswith(file_prefix):
            continue
        tmp_path = os.path.join(parent, filename)
        if tmp_path.endswith(file_suffix):
            data_sources.append(tmp_path)
return data_sources

def load_tfrecord(tfrecord_dir, tfrecord_json=None):

tfrecord_files = get_tfrecord_files(tfrecord_dir)
# print("tfrecord files:\n{}".format("\n".join(tfrecord_files)), flush=True)

dataset = TFRecordDataset(dataset_files=tfrecord_files, shuffle=False)

data_iter = dataset.create_dict_iterator()
for item in data_iter:
    print(item, flush=True)
    break

def main():

tfrecord_dir = "{your_tfrecord_dir}"
tfrecord_json = "{your_tfrecord_json_file}"
load_tfrecord(tfrecord_dir=tfrecord_dir, tfrecord_json=None)

if name == “__main__”:

main()

复制代码解读:get_tfrecord_files — 获取指定的 TFRecord 文件列表 load_tfrecord — 数据集加载将上述代码保留到文件 load_tfrecord_dataset.py,运行如下命令:python3 load_tfrecord_dataset.py
复制输入内容如下:能够看出能正确解析出之前保留在 TFRecord 内的数据,数据类型和数据维度解析正确。{‘class’: Tensor(shape=[1], dtype=Int64, value= [0]), ‘input’: Tensor(shape=[800], dtype=Int64, value= [1719, 636, 1063, 18,
……
135, 979, 1, 35, 166, 181, 90, 143])}
复制 4.1.2 应用 Schema 对象上面介绍,如何应用 mindspore.dataset.Schema 来指定读取模型策略。批改 load_tfrecord 代码如下:def load_tfrecord(tfrecord_dir, tfrecord_json=None):

tfrecord_files = get_tfrecord_files(tfrecord_dir)
# print("tfrecord files:\n{}".format("\n".join(tfrecord_files)), flush=True)

data_schema = Schema()
data_schema.add_column(name="input", de_type=mstype.int64, shape=[800])
data_schema.add_column(name="class", de_type=mstype.int64, shape=[1])

dataset = TFRecordDataset(dataset_files=tfrecord_files, schema=data_schema, shuffle=False)

data_iter = dataset.create_dict_iterator()
for item in data_iter:
    print(item, flush=True)
    break

复制代码解读:这里应用了 Schema 对象,并且指定了列名,列的数据类型和数据维度。保留并再次运行文件 load_tfrecord_dataset.py,输入内容如下:能够看出能正确解析出之前保留在 TFRecord 内的数据,数据类型和数据维度解析正确。{‘input’: Tensor(shape=[800], dtype=Int64, value= [1719, 636, 1063, 18, 742, 330, 385, 999, 837, 56, 529, 1000,
…..
135, 979, 1, 35, 166, 181, 90, 143]), ‘class’: Tensor(shape=[1], dtype=Int64, value= [0])}
复制 4.1.3 应用 JSON 文件上面介绍,如何应用 JSON 文件来指定读取模型策略。新建 tfrecord_sample.json 文件,在文件内写入如下内容:numRows — 数据列数 columns — 顺次为每列的列名、数据类型、数据维数、数据维度。{
“datasetType”: “TF”,
“numRows”: 2,
“columns”: {

"input": {
  "type": "int64",
  "rank": 1,
  "shape": [800]
},
"class" : {
  "type": "int64",
  "rank": 1,
  "shape": [1]
}

}
}
复制有了相应的 JSON 文件,上面来介绍如何应用该文件进行数据读取。批改 load_tfrecord 代码如下:def load_tfrecord(tfrecord_dir, tfrecord_json=None):

tfrecord_files = get_tfrecord_files(tfrecord_dir)
# print("tfrecord files:\n{}".format("\n".join(tfrecord_files)), flush=True)

dataset = TFRecordDataset(dataset_files=tfrecord_files, schema=tfrecord_json, shuffle=False)

data_iter = dataset.create_dict_iterator()
for item in data_iter:
    print(item, flush=True)
    break

复制同时批改 main 局部代码如下:load_tfrecord(tfrecord_dir=tfrecord_dir, tfrecord_json=tfrecord_json)
复制代码解读这里间接将 schema 参数指定为 JSON 的文件门路保留并再次运行文件 load_tfrecord_dataset.py,输入内容如下:{‘class’: Tensor(shape=[1], dtype=Int64, value= [0]), ‘input’: Tensor(shape=[800], dtype=Int64, value= [1719, 636, 1063, 18, ……
135, 979, 1, 35, 166, 181, 90, 143])}
复制 4.2 columns_list 应用在某些场景下,咱们可能只须要某(几)列的数据,而非全副数据,这时候就能够通过制订 columns_list 来进行数据加载。上面咱们只读取 class 列,来简略看看如何操作。在 4.1.2 根底上,批改 load_tfrecord 代码如下:def load_tfrecord(tfrecord_dir, tfrecord_json=None):

tfrecord_files = get_tfrecord_files(tfrecord_dir)
# print("tfrecord files:\n{}".format("\n".join(tfrecord_files)), flush=True)

data_schema = Schema()
data_schema.add_column(name="input", de_type=mstype.int64, shape=[800])
data_schema.add_column(name="class", de_type=mstype.int64, shape=[1])

dataset = TFRecordDataset(dataset_files=tfrecord_files, schema=data_schema, columns_list=["class"], shuffle=False)

data_iter = dataset.create_dict_iterator()
for item in data_iter:
    print(item, flush=True)
    break

复制保留并再次运行文件 load_tfrecord_dataset.py,输入内容如下:能够看到只读取了咱们指定的列,且数据加载正确。{‘class’: Tensor(shape=[1], dtype=Int64, value= [0])}
复制 5. 本文总结本文介绍了在 MindSpore 中如何加载 TFRecord 数据集,并重点介绍了 TFRecordDataset 中的 schema 和 columns_list 参数应用。6. 本文参考 THUCTC: 一个高效的中文文本分类工具包 THUCNews 数据集 TFRecordDataset API 本文为原创文章,版权归作者所有,未经受权不得转载!

正文完
 0