TFRecord是一种TensorFlow 的内定标准文件格式,其实质是二进制文件,遵循protocol buffer(PB)协议(protocol buffer 是google 的一种数据交换的格式,它独立于语言,独立于平台),其后缀一般为tfrecord。TFRecord文件方便复制和移动,能够很好的利用内存,无需单独标记文件,适用于大量数据的顺序读取,是TensorFlow“从文件里读取数据”的一种官方推荐方法。
TFRecord文件的生成
第一步,生成TFRecord Writer
writer = tf.python_io.TFRecordWriter(path, options=None)
path:TFRecord文件的存放路径;
option:TFRecordOptions对象,定义TFRecord文件保存的压缩格式;
有三种文件压缩格式可选,分别为TFRecordCompressionType.ZLIB、TFRecordCompressionType.GZIP以及TFRecordCompressionType.NONE,默认为最后一种,即不做任何压缩,定义方法如下:
option = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
第二步,tf.train.Feature生成protocol buffer格式
Example的protocol buffer格式的定义:
Example对象的里面有个Features类型的成员features。Features是一个map的列表,通过feature = example.features.feature 得到的feature是一个map,key是string,value是Feature类型。
内层feature是一个字典值,它是将某个类型列表(如name,shape等等)编码成特定的feature格式,而该字典键(Key)用于读取TFRecord文件时索引得到不同的数据,某个类型列表可能包含零个或多个值,列表类型一般有BytesList, FloatList, Int64List,通常用如下方法来生成某个列表类型再送给内层的tf.train.Feature编码:
tf.train.BytesList(value=[value]) # value转化为字符串(二进制)列表
tf.train.FloatList(value=[value]) # value转化为浮点型列表
tf.train.Int64List(value=[value]) # value转化为整型列表
其中,value是要保存的数据。
外层features再将内层字典编码:
features_extern = tf.train.Features(feature_internal)
第三步,使用tf.train.Example将features编码数据封装成特定的PB协议格式
example = tf.train.Example(features_extern)
第四步,将example数据序列化为字符串
example_str = example.SerializeToString()
第五步,将序列化为字符串的example数据写入协议缓冲区(PB)
writer.write(example_str)
writer.close()
writer.close()关闭TFRecordWriter,在写完数据到协议缓冲区后通常需要调用writer.close()主动关闭TFRecord文件操作接口。
完整代码如下所示:
import tensorflow as tf
def gen_tfrecord(input,output):
''' 借助TFRecordWriter 将信息写进 TFRecord 文件'''
writer = tf.python_io.TFRecordWriter(output)
# 读取图片并进行解码
image = tf.read_file(input)
image = tf.image.decode_jpeg(image)
with tf.Session() as sess:
image = sess.run(image)
shape = image.shape
# 将图片转换成 string。
image_data = image.tostring()
print(type(image))
print(len(image_data))
name = bytes("Liushishi", encoding='utf8')
print(type(name))
# 创建 Example 对象,并且将 Feature 一一对应填充进去。
example = tf.train.Example(features=tf.train.Features(feature={
'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
}
))
# 将 example 序列化成 string 类型,然后写入。
writer.write(example.SerializeToString())
writer.close()
原始演示图像:
代码说明:
1. 将图片解码,然后转化成 string 数据,然后填充进去。
2. Feature 的 value 是列表,所以要记得加 []。
3. example 需要调用 SerializetoString() 进行序列化后才能够写入。
上面的 Example 表示,要将一张图片信息写进 TFRecord 当中,而图片信息包含了图片的名字(name),图片的维度信息还有图片的数据,分别对应了 name、shape、data 等等3个 feature。
调用上述gen_tfrecord() 函数,可以获得输出文件tfrecord。
gen_tfrecord('./liushishi.jpg','./data/liushishi.tfrecord')
TFRecord 文件的读取
示例代码如下:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
def _parse_record(example_proto):
# 将数据反序列化为结构化的数据
features = {
'name': tf.FixedLenFeature((), tf.string),
'shape': tf.FixedLenFeature([3], tf.int64),
'data': tf.FixedLenFeature((), tf.string)}
parsed_features = tf.parse_single_example(example_proto, features=features)
return parsed_features
def gen_display_tfrecord(input_file):
# 用 dataset 读取 tfrecord 文件
dataset = tf.data.TFRecordDataset(input_file)
dataset = dataset.map(_parse_record)
iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
features = sess.run(iterator.get_next())
name = features['name']
name = name.decode()
img_data = features['data']
shape = features['shape']
print(type(shape))
print(len(img_data))
# 从 bytes 数组中加载图片原始数据,并重新 reshape.它的结果是 ndarray 数组
img_data = np.fromstring(img_data, dtype=np.uint8)
image_data = np.reshape(img_data, shape)
plt.figure()
#显示图片
plt.imshow(image_data)
plt.title(name)
plt.show()
print("done")
代码说明:
- 用 dataset 去读取 tfrecord 文件;
- 将数据反序列化为结构化的数据。在解析 example 的时候,用现成的 API函数 tf.parse_single_example;
- 用 np.fromstring() 方法就可以获取解析后的 string 数据,记得数据格式还原成 np.uint8;
- 因为将图片 shape 写进了 example 中,解析的时候必须指定维度,在这里是 [3] ,不然程序报错。
调用gen_display_tfrecord函数,验证输出结果。
gen_display_tfrecord('./data/liushishi.tfrecord')
参考链接:
你可能无法回避的 TFRecord 文件格式详细讲解
https://frank909.blog.csdn.net/article/details/80789608
tensorflow TFRecord文件的生成和读取方法
https://zhuanlan.zhihu.com/p/31992460