<>前言
最近新建了一个conda环境,搞上了tensorflow 2.0
(Beat),,,TF2.0改变确实很多,比如删除了Session……这对于我等习惯了先建图——再Session执行的人来说,我现在方的雅痞……2.0如何以图形式运行我还没有一点头绪(刚发现了tf.compat里面有历史版本233)……所以还在瑟瑟发抖的使用新版TF强烈推荐的keras。
今天正准备用TF2.0小跑一个图像任务,首先就是数据的读入,然而这边数据集11G,所以打算整合进TFrecord,方便之后;
<>介绍
TFrecord是Tensorflow提供并推荐使用的一种统一一种二进制文件格式,用于存储数据,理论上它可以保存任何格式的信息。
type value
uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data
如上:整个文件由文件长度信息、长度校验码、数据、数据校验码组成。
TFRecord 的核心内容在于内部有一系列的 Example ,Example 是 protocolbuf 协议下的消息体。
比如我这边使用的Example是这样的:
exam = tf.train.Example ( features=tf.train.Features( feature={ 'name' : tf.
train.Feature(bytes_list=tf.train.BytesList (value=[splits[-1].encode('utf-8')])
), 'shape': tf.train.Feature(int64_list=tf.train.Int64List (value=[img.shape[0],
img.shape[1], img.shape[2]])), 'data' : tf.train.Feature(bytes_list=tf.train.
BytesList (value=[bytes(img.numpy())])) } ) )
可以看出,一个 Example 消息体包含了一个Features,而Features由诸多feature组成,其中每个feature 是一个 map,也就是
key-value 的键值对。其中,key 取值是 String 类型;而 value 是 Feature 类型的消息体,它的取值有 3 种:
* BytesList
* FloatList
* Int64List
需要注意的是,他们都是列表的形式。
<>如何创建TFrecord文件
从上面我们知道,TFRecord 内由一系列Example组成,每个Example可以代表一组数据。
Tensorflow 2.0 Beat 中,输出TFrecord的API为tf.io.TFRecordWriter (filename,
options=None), 其中第二个参数是用来控制文件的输出配置,一般不用管。第一个参数就是你要保存的文件名,调用该函数后,会返回一个Writer实例。
有了Writer,我们就可以不停的调用Writer.write
(example)来把我们的Examples输出到文件中,需要注意的是,该函数接受的是一个string,所以我们应该先把example序列化为string类型,即
Writer.write(example.SerializeToString())
当把所有的example输出到文件后,需要调用Writer.close()关闭文件。
例子:
writer = tf.io.TFRecordWriter (file_name) for item in file_list: # item =
.\\data\\xx(label)\\xxx.jpg splits = item.split ('\\') label = splits[2] img =
tf.io.read_file (item) img = tf.image.decode_jpeg (img) exam = tf.train.Example
( features=tf.train.Features( feature={ 'name' : tf.train.Feature(bytes_list=tf.
train.BytesList (value=[splits[-1].encode('utf-8')])), 'label': tf.train.Feature
(int64_list=tf.train.Int64List (value=[int(label)])), 'shape': tf.train.Feature(
int64_list=tf.train.Int64List (value=[img.shape[0], img.shape[1], img.shape[2]])
), 'data' : tf.train.Feature(bytes_list=tf.train.BytesList (value=[bytes(img.
numpy())])) } ) ) writer.write (exam.SerializeToString()) writer.close()
这里因为Tensorflow 2.0 默认使用的是Eager模式,所以img是一个 Eager Tensor,需要转为numpy。
<>如何读取TFrecord
老版本中,我们可以使用tf.TFrecordReader(),不过这个在2.0里我没找到,所以我们使用
tf.data.TFRecordDataset(filename),调用后我们会得到一个Dataset(tf.data.Dataset
),字面理解,这里面就存放着我们之前写入的所有Example。
还记得写入时,我们把每个example都进行了序列化么,所以我们要得到之前的example,还需要解析以下之前写入的序列化string。
tf.io.parse_single_example(example_proto, feature_description)函数可以解析单条example.
解释一下这个函数:
第一个参数就是要解析的string,重点在于第二个参数,他要我们指定解析出来的example的格式。为了能正确解析,这个要和我们写入时的example对应起来:
比如我们写入时example为:
exam = tf.train.Example ( features=tf.train.Features( feature={ 'name' : tf.
train.Feature(bytes_list=tf.train.BytesList (value=[splits[-1].encode('utf-8')])
), 'label': tf.train.Feature(int64_list=tf.train.Int64List (value=[int(label)]))
, 'shape': tf.train.Feature(int64_list=tf.train.Int64List (value=[img.shape[0],
img.shape[1], img.shape[2]])), 'data' : tf.train.Feature(bytes_list=tf.train.
BytesList (value=[bytes(img.numpy())])) } ) )
则我们需要指定的参数为:
feature_description = { 'name' : tf.io.FixedLenFeature([], tf.string,
default_value='Nan'), 'label': tf.io.FixedLenFeature([] , tf.int64,
default_value=-1), # 默认值自己定义 'shape': tf.io.FixedLenFeature([3], tf.int64),
'data' : tf.io.FixedLenFeature([], tf.string) }
可以看到其中每一条都和之前的example中的feature对应(feature_description 中
map的key可以不对应,比如name改成id还是没问题的)。
OK,我们目前解决了解析一条example,但是一个Dataset中的example那么多。没关系tensorflow的dataset提供了
Dataset.map(func),可以给定一个映射规则,将dataset中的所有条目按照该规则进行映射,其实和python的map函数差不多。
所以我们可以把我们的映射一条的函数呈递给Dataset.map(func),以解析所有的example。
reader = tf.data.TFRecordDataset(file_name) # 打开一个TFrecord feature_description
= { 'name' : tf.io.FixedLenFeature([], tf.string, default_value='Nan'), 'label':
tf.io.FixedLenFeature([] , tf.int64, default_value=-1), 'shape': tf.io.
FixedLenFeature([3], tf.int64), 'data' : tf.io.FixedLenFeature([], tf.string) }
def _parse_function (exam_proto): # 映射函数,用于解析一条example return tf.io.
parse_single_example(exam_proto, feature_description) reader = reader.map (
_parse_function)
读取的话,我们可以用for循环:
for row in reader.take(10): # 只取前10条 # for row in reader: # 枚举所有example print (
row['name']) print (np.frombuffer(row['data'].numpy(), dtype=np.uint8)) #
如果要恢复成3d数组,可reshape
不过我们还可以完出花样:
dataset中还提供了很多方法,比如batch,shuffle,repeat。。。更多的可以自行去官网摸索(不知何时,访问TF官网突然就啥都不用了)
我们就可以这样:
reader = tf.data.TFRecordDataset(file_name) feature_description = { 'name' : tf
.io.FixedLenFeature([], tf.string, default_value='Nan'), 'label': tf.io.
FixedLenFeature([] , tf.int64, default_value=-1), 'shape': tf.io.FixedLenFeature
([3], tf.int64), 'data' : tf.io.FixedLenFeature([], tf.string) } def
_parse_function (exam_proto): return tf.io.parse_single_example (exam_proto,
feature_description) reader = reader.repeat (1) # 读取数据的重复次数为:1次,这个相当于epoch
reader= reader.shuffle (buffer_size = 2000) # 在缓冲区中随机打乱数据 reader = reader.map (
_parse_function) # 解析数据 batch = reader.batch (batch_size = 10) #
每10条数据为一个batch,生成一个新的Dataset shape = [] batch_data_x, batch_data_y = np.array([]
), np.array([]) for item in batch.take(1): # 测试,只取1个batch shape = item['shape'][
0].numpy() for data in item['data']: # 一个item就是一个batch img_data = np.frombuffer(
data.numpy(), dtype=np.uint8) batch_data_x = np.append (batch_data_x, img_data)
for label in item ['label']: batch_data_y = np.append (batch_data_y, label.numpy
()) batch_data_x = batch_data_x.reshape ([-1, shape[0], shape[1], shape[2]])
print (batch_data_x.shape, batch_data_y.shape) # = (10, 480, 640, 3) (10,) #
我的图片数据时480*640*3的
可以很方便的读取出数据的各批次,还能随即等等。