假设 emo 文件夹下有 1、2、3、4 等子文件夹，每个子文件夹代表一个类别。

import tensorflow as tf
from PIL import Image
from glob import glob
import os
import progressbar
import time


class TFRecord:
    """Write a folder-per-class JPEG dataset to a TFRecord file and read it back.

    Layout read by ``convert_img_folder``::

        path/
            <class_a>/*.jpg
            <class_b>/*.jpg
            ...

    Each immediate sub-directory of ``path`` is one class. Its integer label is
    the sub-directory's index in lexicographically sorted name order (so with
    folders named 1..4 the labels are 0..3; note '10' would sort before '2').
    """

    def __init__(self, path=None, tfrecord_file=None):
        """
        Args:
            path: root folder containing one sub-folder per class.
            tfrecord_file: path of the TFRecord file to write and read.
        """
        self.path = path
        self.tfrecord_file = tfrecord_file

    def _convert_image(self, idx, img_path, is_training=True):
        """Build a ``tf.train.Example`` for a single image file.

        The raw file bytes (still JPEG-encoded) are stored, not decoded pixels.

        Args:
            idx: integer class label.
            img_path: path of the image file; also stored as ``file_name``.
            is_training: when False, the placeholder label -1 is stored instead
                of ``idx``.

        Returns:
            A ``tf.train.Example`` with features ``file_name`` (bytes),
            ``img`` (bytes) and ``label`` (int64).
        """
        label = idx if is_training else -1

        with tf.gfile.GFile(img_path, 'rb') as fid:
            img_str = fid.read()

        feature_key_value_pair = {
            'file_name': tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[img_path.encode()])),
            'img': tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[img_str])),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(
                value=[label])),
        }
        feature = tf.train.Features(feature=feature_key_value_pair)
        return tf.train.Example(features=feature)

    def _list_samples(self):
        """Return ``(label, img_path)`` pairs for every ``*.jpg`` in each class folder.

        Class folders are sorted by name because ``os.listdir`` returns entries
        in arbitrary order, which would otherwise make the folder -> label
        mapping differ between runs/machines. Regular files directly under
        ``self.path`` are ignored.
        """
        folder_path = self.path
        categories = sorted(
            name for name in os.listdir(folder_path)
            if os.path.isdir(os.path.join(folder_path, name)))
        samples = []
        for idx, category in enumerate(categories):
            pattern = os.path.join(folder_path, category, '*.jpg')
            samples.extend((idx, img_path) for img_path in sorted(glob(pattern)))
        return samples

    def convert_img_folder(self):
        """Write every ``*.jpg`` in the class sub-folders of ``self.path`` to
        ``self.tfrecord_file``, showing a progress bar.

        Raises:
            ValueError: if no ``*.jpg`` images are found (the progress bar's
                percentage widget cannot handle a total of 0).
        """
        samples = self._list_samples()
        if not samples:
            raise ValueError(
                "no *.jpg images found in the class sub-folders of %r" % (self.path,))

        with tf.python_io.TFRecordWriter(self.tfrecord_file) as tfwrite:
            widgets = ['[INFO] write image to tfrecord: ', progressbar.Percentage(), " ",
                       progressbar.Bar(), " ", progressbar.ETA()]
            # maxval counts exactly the images written below, so the bar ends at 100%.
            pbar = progressbar.ProgressBar(maxval=len(samples), widgets=widgets).start()
            for written, (idx, img_path) in enumerate(samples, start=1):
                example = self._convert_image(idx, img_path)
                tfwrite.write(example.SerializeToString())
                pbar.update(written)
            pbar.finish()

    def _extract_fn(self, tfrecord):
        """Parse one serialized Example into ``(img, label, file_name)``.

        ``img`` is JPEG-decoded and resized to 224x224 with nearest-neighbour
        interpolation.
        """
        features = {
            'file_name': tf.FixedLenFeature([], tf.string),
            'img': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        }
        sample = tf.parse_single_example(tfrecord, features)
        # NOTE(review): no `channels` argument, so the channel count follows each
        # file; a mix of grayscale and RGB JPEGs cannot be batched together.
        img = tf.image.decode_jpeg(sample['img'])
        img = tf.image.resize_images(
            img, (224, 224), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        # A tuple is the structure tf.data maps to separate components.
        return img, sample['label'], sample['file_name']

    def extract_image(self, shuffle_size, batch_size):
        """Return a shuffled, batched ``tf.data.Dataset`` read from ``self.tfrecord_file``.

        Each element is ``(images, labels, file_names)`` for one batch.
        """
        dataset = tf.data.TFRecordDataset([self.tfrecord_file])
        # Shuffle the small serialized records before decoding so the shuffle
        # buffer does not hold `shuffle_size` decoded images.
        dataset = dataset.shuffle(shuffle_size)
        dataset = dataset.map(self._extract_fn)
        return dataset.batch(batch_size)


if __name__ == '__main__':
    # Iterating a tf.data.Dataset with a Python for-loop requires eager
    # execution, which TF 1.x does not enable by default.
    if not tf.executing_eagerly():
        tf.enable_eager_execution()

    t = TFRecord('/media/xia/Data/emo', '/media/xia/Data/emo.tfrecord')
    t.convert_img_folder()
    dataset = t.extract_image(100, 64)
    for data, label, _ in dataset:
        print(label)
        print(data.shape)