【tf.keras.utils.Sequence】构建自己的数据集生成器

it2024-01-30  57

every blog every motto: You can do more than you think.

0. 前言

在训练模型时,我们往往不一次将数据全部加载进内存中,而是将数据分批次加载到内存中。


一种方法是用 while True 遍历数据,用yeid产生,具体可参考语义分割代码讲解部分另一种方法是本文即将讲解的tf.keras.utils.Sequence方法

1. 正文

1.1 基础用法

__ len __ 中返回的即1个epoch迭代的次数,即: 总样本数/ batch_size

__ getitem __ 根据len中的迭代次数,生成数据


注意: __ len __ ,__ getitem __ 必须要实现

""" 测试 __getitem__ """ import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf class Date(tf.keras.utils.Sequence): def __init__(self): print('初始化相关参数') def __len__(self): """ 此方法要实现,否则会报错 正常程序中返回1个epoch迭代的次数 :return: """ return 5 def __getitem__(self, index): """生成一个batch的数据""" print('index:', index) x_batch = ['x1', 'x2', 'x3', 'x4'] y_batch = ['y1', 'y2', 'y3', 'y4'] print('-'*20) return x_batch, y_batch # 实例化数据 date = Date() for batch_number, (x, y) in enumerate(date): print('正在进行第{} batch'.format(batch_number)) print('x_batch:', x) print('y_batcxh:', y)

结果:

1.2 扩展(2020.11.12 15:37增补)

可以在类中实现on_epoch_end方法,保证在每个epoch后打乱原有数据的顺序

1.2.1 训练样例:

测试代码,如下:

import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf import numpy as np print('tensorflow version: ', tf.__version__) class ZerosFirstEpochOnesAfter(tf.keras.utils.Sequence): def __init__(self): self.shuffle = True def __len__(self): return 2 def on_epoch_end(self): print('---------------on_epoch_end------------') # 打乱索引 # if self.shuffle: # print('==============================================================shuffle') # np.random.shuffle(self.indices) def __getitem__(self, item): return np.zeros((16, 1)), np.zeros((16,)) def main(): model = tf.keras.Sequential() model.add(tf.keras.layers.Dense(1, input_dim=1, activation="softmax")) model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'] ) model.fit(ZerosFirstEpochOnesAfter(), epochs=3, ) if __name__ == '__main__': main()

tensorflow 2.0:

tensorflow 2.1:

tesorflow 2.3: 由以上三个版本的训练结果,我们可以发现,

在2.0和2.1版本中,是没有进行on_epoch_end方法调用的,即没有实现on_epoch_end方法内注释部分的打乱顺序,这是tensorflow早期版本的一个bug,具体可参考文后第4个链接。在2.3版本中已得到改进

1.2.2 循环遍历:

1.2.2.1 原始版测试

循环遍历,如下所示:

import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf import numpy as np print('tensorflow version: ', tf.__version__) class Date(tf.keras.utils.Sequence): def __init__(self): print('初始化相关参数') self.lines = [1,2,3,4,5] self.shuffle = True def __len__(self): """ 此方法要实现,否则会报错 正常程序中返回1个epoch迭代的次数 :return: """ return 2 def on_epoch_end(self): print('=======================') if self.shuffle == True: print('------------一个epoch结束,打乱了顺序---') np.random.shuffle(self.lines) def __getitem__(self, index): """生成一个batch的数据""" print('index:', index) x_batch = ['x1', 'x2', 'x3', 'x4'] y_batch = ['y1', 'y2', 'y3', 'y4'] print('-' * 20) return x_batch, y_batch # 实例化数据 date = Date() for epoch in range(2): for batch_number, (x, y) in enumerate(date): print('正在进行第{} batch'.format(batch_number)) print('x_batch:', x) print('y_batcxh:', y) print('一个epoch结束=============================')

结果: 如上图所示,通过循环遍历这种方法仍然不能调用on_epoch_end,即无法打乱顺序

1.2.2.2 改进版

import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf import numpy as np print('tensorflow version: ', tf.__version__) class Date(tf.keras.utils.Sequence): def __init__(self): print('初始化相关参数') self.lines = [1,2,3,4,5] self.shuffle = True def __len__(self): """ 此方法要实现,否则会报错 正常程序中返回1个epoch迭代的次数 :return: """ return 2 def on_epoch_end(self): print('=======================') if self.shuffle == True: print('------------一个epoch结束,打乱了顺序---') np.random.shuffle(self.lines) def __getitem__(self, index): """生成一个batch的数据""" print('index:', index) x_batch = ['x1', 'x2', 'x3', 'x4'] y_batch = ['y1', 'y2', 'y3', 'y4'] print('-' * 20) return x_batch, y_batch # 实例化数据 date = Date() for epoch in range(2): print(date.lines) for batch_number, (x, y) in enumerate(date): print('正在进行第{} batch'.format(batch_number)) print('x_batch:', x) print('y_batcxh:', y) np.random.shuffle(date.lines) print('一个epoch结束=============================')

如下图所示,我们发现已经打乱了“样本”顺序,

参考文献

[1] https://blog.csdn.net/weixin_39190382/article/details/105808830 [2] https://blog.csdn.net/weixin_43198141/article/details/89926262 [3] https://blog.csdn.net/u011311291/article/details/80991330 [4] https://github.com/tensorflow/tensorflow/issues/35911 [5] https://colab.research.google.com/gist/bfs15/fd18263f788a071225c60cedaf126748/35911.ipynb

最新回复(0)