共计 3093 个字符，预计需要花费 8 分钟才能阅读完成。
假设 emo 文件夹下有 1、2、3、4 等子文件夹，每个子文件夹代表一个类别，其中存放该类别的 .jpg 图片。
1 import tensorflow as tf
2 from PIL import Image
3 from glob import glob
4 import os
5 import progressbar
6 import time
7
8
class TFRecord:
    """Write a folder of per-class JPEG sub-folders to a TFRecord file and read it back.

    Expected layout (one immediate sub-directory per class)::

        path/
            1/*.jpg
            2/*.jpg
            ...

    Args:
        path: Root folder whose immediate sub-directories are the classes.
        tfrecord_file: Path of the TFRecord file to write (convert_img_folder)
            and read (extract_image).
    """

    def __init__(self, path=None, tfrecord_file=None):
        self.path = path
        self.tfrecord_file = tfrecord_file

    @staticmethod
    def _list_labeled_images(folder_path):
        """Return ``(label, image_path)`` pairs for every ``*.jpg`` under *folder_path*.

        Only immediate sub-directories of *folder_path* are treated as classes;
        plain files in the root are ignored. Class directories are sorted by name
        so label ``i`` maps to the same folder on every machine, instead of
        depending on the arbitrary order ``os.listdir`` returns. Images within a
        class are sorted too, so the output order is deterministic.
        """
        class_dirs = sorted(
            name for name in os.listdir(folder_path)
            if os.path.isdir(os.path.join(folder_path, name))
        )
        labeled = []
        for label, name in enumerate(class_dirs):
            for img_path in sorted(glob(os.path.join(folder_path, name, '*.jpg'))):
                labeled.append((label, img_path))
        return labeled

    def _convert_image(self, idx, img_path, is_training=True):
        """Build a ``tf.train.Example`` holding one still-encoded image.

        Args:
            idx: Integer class label stored under ``'label'``.
            img_path: Path of the image file. Its raw bytes are stored as-is;
                decoding happens later in ``_extract_fn``.
            is_training: When False, ``-1`` is stored as the label instead of *idx*.

        Returns:
            A ``tf.train.Example`` with features ``file_name`` (bytes),
            ``img`` (bytes) and ``label`` (int64).
        """
        with tf.gfile.GFile(img_path, 'rb') as fid:
            img_str = fid.read()

        label = idx if is_training else -1

        feature_key_value_pair = {
            'file_name': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[img_path.encode()])),
            'img': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[img_str])),
            'label': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[label])),
        }
        return tf.train.Example(
            features=tf.train.Features(feature=feature_key_value_pair))

    def convert_img_folder(self):
        """Serialize every ``*.jpg`` under ``self.path`` into ``self.tfrecord_file``.

        Labels are the index of the image's class folder in name-sorted order
        (see ``_list_labeled_images``). A console progress bar tracks progress.
        """
        labeled_images = self._list_labeled_images(self.path)

        with tf.python_io.TFRecordWriter(self.tfrecord_file) as tfwrite:
            widgets = ['[INFO] write image to tfrecord:', progressbar.Percentage(), " ",
                       progressbar.Bar(), " ", progressbar.ETA()]
            # Size the bar from the exact list being written so it cannot over/under-run.
            pbar = progressbar.ProgressBar(maxval=len(labeled_images),
                                           widgets=widgets).start()

            for written, (label, img_path) in enumerate(labeled_images, start=1):
                example = self._convert_image(label, img_path)
                tfwrite.write(example.SerializeToString())
                pbar.update(written)

            pbar.finish()

    def _extract_fn(self, tfrecord):
        """Parse one serialized example into ``[img, label, file_name]``.

        The image is decoded as 3-channel JPEG, so grayscale files still batch
        together with colour ones. It is then resized to 224x224 with
        nearest-neighbour interpolation.
        """
        features = {
            'file_name': tf.FixedLenFeature([], tf.string),
            'img': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        }
        sample = tf.parse_single_example(tfrecord, features)
        img = tf.image.decode_jpeg(sample['img'], channels=3)
        img = tf.image.resize_images(
            img, (224, 224), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        return [img, sample['label'], sample['file_name']]

    def extract_image(self, shuffle_size, batch_size):
        """Return a shuffled, batched ``tf.data.Dataset`` of ``(img, label, file_name)``.

        Args:
            shuffle_size: Shuffle buffer size, counted in records.
            batch_size: Number of examples per batch.
        """
        dataset = tf.data.TFRecordDataset([self.tfrecord_file])
        # Shuffle the compact serialized records before decoding, so the buffer
        # does not hold decoded images. _extract_fn is 1:1, so the distribution
        # of orders is unchanged.
        dataset = dataset.shuffle(shuffle_size)
        dataset = dataset.map(self._extract_fn)
        return dataset.batch(batch_size)
96
97
98
99
if __name__ == '__main__':
    # Iterating a tf.data.Dataset with a Python for-loop requires eager execution
    # in TF 1.x, and it must be enabled before any TensorFlow op is created.
    tf.enable_eager_execution()

    t = TFRecord('/media/xia/Data/emo', '/media/xia/Data/emo.tfrecord')
    t.convert_img_folder()
    dataset = t.extract_image(100, 64)
    for data, label, _ in dataset:
        print(label)
        print(data.shape)
正文完