四时宝库

程序员的知识宝库

走近深度学习,认识MoXing:数据输入教程

本文主要介绍MoXing将数据的输入定义在input_fn方法中,并在mox.run时注册该方法。

基本方法:

def input_fn(mode, **kwargs):

...

return input_0, input_1, ...

mox.run(..., input_fn=input_fn, ...)

输入参数:

· mode: 当前调用input_fn时的运行模式,需要用户在input_fn中做好判断使用相应的数据集和数据集增强、预处理方法。

· **kwargs: 扩展参数的预留位置。

返回值:

· tf.Tensor或tf.Tensor的list

input_fn中的返回值包含了2种情况:

1) auto_batch=True

当用户实现的input_fn的返回值input_i不包含batch_size维度时,在mox.run中用户需要添加参数:

mox.run(...

batch_size=32,

auto_batch=True,

...)

MoXing会自动将input_fn中的输入以batch为单位聚合,并将含有batch_size维度的Tensor输入到model_fn中。例(auto_batch的缺省值为True):

def input_fn(mode, **kwargs):

...

return image, label

def model_fn(inputs, mode, **kwargs):

images, labels = inputs

...

mox.run(...

batch_size=32,

...)

input_fn的返回值:image是一个[224, 224, 3]的Tensor,label是一个[1000]的Tensor。

model_fn的输入参数:images是一个[32, 224, 224, 3], labels是一个[32, 1000]的Tensor。

2) auto_batch=False

当auto_batch为False时,用户就需要自己在input_fn中将组织batch。注意:不论auto_batch的值是什么,mox.run中的batch_size都必须填写(用于计算运行时吞吐量)。例:

def input_fn(mode, **kwargs):

...

return images, labels

def model_fn(inputs, mode, **kwargs):

images, labels = inputs

...

mox.run(...

auto_batch=False,

batch_size=32,

...)

input_fn的返回值:images是一个[32, 224, 224, 3]的Tensor,label是一个[32, 1000]的Tensor。

model_fn的输入参数:images是一个[32, 224, 224, 3], labels是一个[32, 1000]的Tensor。

1 读取图像分类数据集Raw Data

基本使用方法:

def input_fn(mode, **kwargs):

meta = mox.ImageClassificationRawMetadata(base_dir='/export1/flowers/raw/split/train')

dataset = mox.ImageClassificationRawDataset(meta)

image, label = dataset.get(['image', 'label'])

# 将图片resize到相同大小并添加shape信息,或者还可以增加一些数据增强方法。

image = tf.expand_dims(image, 0)

image = tf.image.resize_bilinear(image, [224, 224])

image = tf.squeeze(image)

image.set_shape([224, 224, 3])

return image, label

数据集必须是如下目录结构的:

base_dir

|- label_0

|- 0_0.jpg

|- 0_1.jpg

|- 0_x.jpg

|- label_1

|- 1_0.jpg

|- 1_1.jpg

|- 1_y.jpg

|- label_m

|- m_0.jpg

|- m_1.jpg

|- m_z.jpg

|- labels.txt

其中label_0, label_1, ..., label_m代表(m+1)个分类,第i个分类的名称即为label_i。 labels.txt是一个label_index到label_string的映射,可以提供也可以不提供。labels.txt必须是如下内容:

0: label_0

1: label_1

...

m: label_m

也就是当模型输出的label值为i时(训练或预测),对应的label名称是label_i。

利用训练好的模型做预测服务时,发现正确率非常低。

当使用纯图像文件数据集时,如果labels.txt没有提供,存储数据集的文件系统对分类目录的排序顺序即为label的顺序,比如在用户存储的文件系统中数据集以以下顺序排列(也就是os.listdir得到的list中的顺序):

base_dir

|- label_0

|- label_1

|- label_10

|- label_11

|- label_2

...

则等效于labels.txt中写入内容:

0: label_0

1: label_1

2: label_10

3: label_11

4: label_2

...

但是有可能在预测服务的客户端中又以另一种完全不同的映射顺序将服务端返回的label_id值转换成label_string,导致预测结果不准确。为了防止这种情况的发生,最好提供labels.txt,用户能更好的掌握服务端返回值和实际预测结果的映射关系。

如果在input_fn中涉及多个数据集,如训练集、验证集等,使用mode将input_fn的返回值做分支判断,MoXing中使用常量mox.ModeKeys来定义模式,分别有:

训练态:mox.ModeKeys.TRAIN

验证态:mox.ModeKeys.EVAL

预测态:mox.ModeKeys.PREDICT

导出态: mox.ModeKeys.EXPORT

由MoXing内部使用,在阐述模型部分的章节说明。例:

def input_fn(mode, **kwargs):

if mode == mox.ModeKeys.TRAIN:

meta = mox.ImageClassificationRawMetadata(base_dir='/export1/flowers/raw/split/train')

else:

meta = mox.ImageClassificationRawMetadata(base_dir='/export1/flowers/raw/split/eval')

dataset = mox.ImageClassificationRawDataset(meta)

image, label = dataset.get(['image', 'label'])

...

return image, label

2 读取tfrecord

读取tfrecord文件和生成tfrecord文件的代码是相关的,tfrecord文件中以键值对的形式存放了数据。

例:考虑读取一个key值含有image和label的tfrecord,image和label都以字节流的形式储存于tfrecord文件中:

import tensorflow as tf

import moxing.tensorflow as mox

slim = tf.contrib.slim

keys_to_features = {

'image': tf.FixedLenFeature(shape=(), dtype=tf.string, default_value=None),

'label': tf.FixedLenFeature(shape=(), dtype=tf.string, default_value=None),

}

items_to_handlers = {

'image': slim.tfexample_decoder.Tensor('image'),

'label': slim.tfexample_decoder.Tensor('label'),

}

dataset = mox.get_tfrecord(dataset_dir='/xxx',

file_pattern='*.tfrecord',

keys_to_features=keys_to_features,

items_to_handlers=items_to_handlers)

image, label = dataset.get(['image', 'label'])

例:考虑读取一个key值含有image/encoded, image/format, image/class/label的tfrecord,并同时将image从字节流解码为像素值张量:

import tensorflow as tf

import moxing.tensorflow as mox

slim = tf.contrib.slim

keys_to_features = {

'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),

'image/format': tf.FixedLenFeature((), tf.string, default_value=''),

'image/class/label': tf.FixedLenFeature(

[1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)),

}

items_to_handlers = {

'image': slim.tfexample_decoder.Image(shape=[28, 28, 1], channels=1),

'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]),

}

dataset = mox.get_tfrecord(dataset_dir='/xxx’,

file_pattern='*.tfrecord',

keys_to_features=keys_to_features,

items_to_handlers=items_to_handlers)

image, label = dataset.get(['image', 'label'])

3 利用tf.data模块读取任意数据集

用户实现数据集类my_dataset,提供next()方法获取下一份数据,可以是一个batch的samples也可以是单个sample,用auto_batch来做控制。基本写法如下:

import tensorflow as tf

import moxing.tensorflow as tf

import my_dataset

def input_fn(run_mode, **kwargs):

def gen():

while True:

yield my_dataset.next()

ds = tf.data.Dataset.from_generator(

gen,

output_types=(tf.float32, tf.int64),

output_shapes=(tf.TensorShape([224, 224, 3]), tf.TensorShape([1000])))

return ds.make_one_shot_iterator().get_next()

在使用这种方法时,由于数据的产生顺序完全取决于用户实现的代码,MoXing无法保证数据的shuffle,所以用户必须确保自己提供的my_dataset.next()具有数据随机性。

4 数据增强

MoXing提供了部分的数据增强方法,这些数据增强方法都是和模型名称绑定,如:

data_augmentation_fn = mox.get_data_augmentation_fn(

name='resnet_v1_50', run_mode=mox.ModeKeys.TRAIN,

output_height=224, output_width=224)

image = data_augmentation_fn(image)

即获取一个resnet_v1_50模型在训练态时对应的数据增强方法。

用户也可以自定义数据增强方法:

def input_fn(mode, **kwargs):

...

image, label = dataset.get(['image', 'label'])

image = my_data_augmentation_fn(image)

return image, label

需要注意的是:从dataset.get()中获取的image如果没有shape信息,甚至每张图片的大小不一致,可能会导致后续的算子出现错误;所以推荐在对image操作之前,将image的size统一(当模型有batch_size维度时,要求输入数据的shape必须相同),并将shape信息进行补全。如:

def input_fn(mode, **kwargs):

...

image, label = dataset.get(['image', 'label'])

# 将image统一至[224, 224, 3]的大小并补全shape信息

image = tf.expand_dims(image, 0)

image = tf.image.resize_bilinear(image, [224, 224])

image = tf.squeeze(image)

image.set_shape([224, 224, 3])

# 调用自定义数据增强方法,如水平翻转

image = tf.image.flip_left_right(image)

return image, label

运行作业日志提示如下信息,并经过很长时间都没有反应。

INFO:tensorflow:Find tfrecord files. Using tfrecord files in this job.

INFO:tensorflow:Automatically extracting num_samples from tfrecord. If the dataset is large, it may take some time. You can also manually specify the num_samples to Dataset to save time.

这个现象的原因是用户使用的tfrecord文件作为数据集,MoXing在扫描tfrecord文件并抽取总样本数量的值,如果tfrecord文件所在位置是一个网络文件系统,而该文件系统的IO速度不高,很可能在这一步会停留很久。

解决办法:根据用户数据集的实际情况填写tfrecord文件的总样本数量。

可能涉及的API:

① mox.get_tfrecord

mox.get_tfrecord(..., num_samples=1000, ...)

② 所有BaseTFRecordMetadata类以及其子类:

BaseTFRecordMetadata(..., num_samples=1000, ...)

③ DLS服务中的预置模型库:

当使用的是未划分的单数据集时,即train或eval数据集,手动指定运行参数:samples_per_epoch,表示所选数据集中的总样本数量。

当使用的是划分好的数据集时,即train和eval数据集,手动指定运行参数:samples_per_epoch和samples_per_epoch_eval,分别表示所选train数据集和eval数据集中的总样本数量。

查看MoXing系列文章请关注我。

发表评论:

控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言
    友情链接