四时宝库

程序员的知识宝库

TensorFlow中TFRecord文件的生成和读取方法

TFRecord是一种TensorFlow 的内定标准文件格式,其实质是二进制文件,遵循protocol buffer(PB)协议(protocol buffer 是google 的一种数据交换的格式,它独立于语言,独立于平台),其后缀一般为tfrecord。TFRecord文件方便复制和移动,能够很好的利用内存,无需单独标记文件,适用于大量数据的顺序读取,是TensorFlow“从文件里读取数据”的一种官方推荐方法。

TFRecord文件的生成

第一步,生成TFRecord Writer

writer = tf.python_io.TFRecordWriter(path, options=None)

path:TFRecord文件的存放路径;

option:TFRecordOptions对象,定义TFRecord文件保存的压缩格式;

有三种文件压缩格式可选,分别为TFRecordCompressionType.ZLIB、TFRecordCompressionType.GZIP以及TFRecordCompressionType.NONE,默认为最后一种,即不做任何压缩,定义方法如下:

option = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)

第二步,tf.train.Feature生成protocol buffer格式

Example的protocol buffer格式的定义:

Example对象的里面有个Features类型的成员features。Features是一个map的列表,通过feature = example.features.feature 得到的feature是一个map,key是string,value是Feature类型。

内层feature是一个字典值,它是将某个类型列表(如name,shape等等)编码成特定的feature格式,而该字典键(Key)用于读取TFRecord文件时索引得到不同的数据,某个类型列表可能包含零个或多个值,列表类型一般有BytesList, FloatList, Int64List,通常用如下方法来生成某个列表类型再送给内层的tf.train.Feature编码:

tf.train.BytesList(value=[value]) # value转化为字符串(二进制)列表

tf.train.FloatList(value=[value]) # value转化为浮点型列表

tf.train.Int64List(value=[value]) # value转化为整型列表

其中,value是要保存的数据。

外层features再将内层字典编码:

features_extern = tf.train.Features(feature_internal)

第三步,使用tf.train.Example将features编码数据封装成特定的PB协议格式

example = tf.train.Example(features_extern)

第四步,将example数据序列化为字符串

example_str = example.SerializeToString()

第五步,将序列化为字符串的example数据写入协议缓冲区(PB)

writer.write(example_str)

writer.close()

writer.close()关闭TFRecordWriter,在写完数据到协议缓冲区后通常需要调用writer.close()主动关闭TFRecord文件操作接口。

完整代码如下所示:

import tensorflow as tf

def gen_tfrecord(input,output):

''' 借助TFRecordWriter 将信息写进 TFRecord 文件'''

writer = tf.python_io.TFRecordWriter(output)

# 读取图片并进行解码

image = tf.read_file(input)

image = tf.image.decode_jpeg(image)

with tf.Session() as sess:

image = sess.run(image)

shape = image.shape

# 将图片转换成 string。

image_data = image.tostring()

print(type(image))

print(len(image_data))

name = bytes("Liushishi", encoding='utf8')

print(type(name))

# 创建 Example 对象,并且将 Feature 一一对应填充进去。

example = tf.train.Example(features=tf.train.Features(feature={

'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),

'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),

'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))

}

))

# 将 example 序列化成 string 类型,然后写入。

writer.write(example.SerializeToString())

writer.close()

原始演示图像:

代码说明:

1. 将图片解码,然后转化成 string 数据,然后填充进去。

2. Feature 的 value 是列表,所以要记得加 []。

3. example 需要调用 SerializetoString() 进行序列化后才能够写入。

上面的 Example 表示,要将一张图片信息写进 TFRecord 当中,而图片信息包含了图片的名字(name),图片的维度信息还有图片的数据,分别对应了 name、shape、data 等等3个 feature。

调用上述gen_tfrecord() 函数,可以获得输出文件tfrecord。

gen_tfrecord('./liushishi.jpg','./data/liushishi.tfrecord')

TFRecord 文件的读取

示例代码如下:

import numpy as np

import matplotlib.pyplot as plt

import tensorflow as tf

def _parse_record(example_proto):

# 将数据反序列化为结构化的数据

features = {

'name': tf.FixedLenFeature((), tf.string),

'shape': tf.FixedLenFeature([3], tf.int64),

'data': tf.FixedLenFeature((), tf.string)}

parsed_features = tf.parse_single_example(example_proto, features=features)

return parsed_features

def gen_display_tfrecord(input_file):

# 用 dataset 读取 tfrecord 文件

dataset = tf.data.TFRecordDataset(input_file)

dataset = dataset.map(_parse_record)

iterator = dataset.make_one_shot_iterator()

with tf.Session() as sess:

features = sess.run(iterator.get_next())

name = features['name']

name = name.decode()

img_data = features['data']

shape = features['shape']

print(type(shape))

print(len(img_data))

# 从 bytes 数组中加载图片原始数据,并重新 reshape.它的结果是 ndarray 数组

img_data = np.fromstring(img_data, dtype=np.uint8)

image_data = np.reshape(img_data, shape)

plt.figure()

#显示图片

plt.imshow(image_data)

plt.title(name)

plt.show()

print("done")

代码说明:

  • 用 dataset 去读取 tfrecord 文件;
  • 将数据反序列化为结构化的数据。在解析 example 的时候,用现成的 API函数 tf.parse_single_example;
  • 用 np.fromstring() 方法就可以获取解析后的 string 数据,记得数据格式还原成 np.uint8;
  • 因为将图片 shape 写进了 example 中,解析的时候必须指定维度,在这里是 [3] ,不然程序报错。

调用gen_display_tfrecord函数,验证输出结果。

gen_display_tfrecord('./data/liushishi.tfrecord')

参考链接:

你可能无法回避的 TFRecord 文件格式详细讲解

https://frank909.blog.csdn.net/article/details/80789608

tensorflow TFRecord文件的生成和读取方法

https://zhuanlan.zhihu.com/p/31992460

发表评论:

控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言
    友情链接