keras如何创建自己的generator
程序员文章站
2022-06-12 16:57:44
...
`keras.utils.Sequence` 是一个基类,其源码如下:
from abc import abstractmethod


class Sequence(object):
    """Base object for fitting to a sequence of data, such as a dataset.

    Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
    If you want to modify your dataset between epochs you may implement
    `on_epoch_end`. The method `__getitem__` should return a complete batch.

    # Notes

    `Sequence` are a safer way to do multiprocessing. This structure guarantees
    that the network will only train once on each sample per epoch which is not
    the case with generators.

    # Examples

    ```python
        from skimage.io import imread
        from skimage.transform import resize
        import numpy as np

        # Here, `x_set` is list of path to the images
        # and `y_set` are the associated classes.

        class CIFAR10Sequence(Sequence):

            def __init__(self, x_set, y_set, batch_size):
                self.x, self.y = x_set, y_set
                self.batch_size = batch_size

            def __len__(self):
                return int(np.ceil(len(self.x) / float(self.batch_size)))

            def __getitem__(self, idx):
                batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
                batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

                return np.array([
                    resize(imread(file_name), (200, 200))
                    for file_name in batch_x]), np.array(batch_y)
    ```
    """

    use_sequence_api = True

    # NOTE: this class does not use `abc.ABCMeta`, so `@abstractmethod` only
    # marks intent; forgetting to override a method is reported at call time
    # by the `NotImplementedError` below, not at instantiation.
    @abstractmethod
    def __getitem__(self, index):
        """Gets batch at position `index`.

        # Arguments
            index: position of the batch in the Sequence.

        # Returns
            A batch
        """
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        """Number of batch in the Sequence.

        # Returns
            The number of batches in the Sequence.
        """
        raise NotImplementedError

    def on_epoch_end(self):
        """Method called at the end of every epoch.
        """
        pass

    def __iter__(self):
        """Create a generator that iterates over the Sequence."""
        for i in range(len(self)):
            yield self[i]
所以我们需要在子类中实现它的 `__len__` 和 `__getitem__` 方法;如果需要在每个 epoch 之间修改数据(例如重新打乱顺序),还可以实现 `on_epoch_end`。
import keras
import numpy as np
import cv2
class DataGenerator(keras.utils.Sequence):
    """Keras `Sequence` that yields one-hot labelled image batches.

    Args:
        img_files: Sequence of image file paths.
        labels: Sequence of class labels, aligned index-by-index with
            ``img_files``.
        batch_size: Number of samples per batch.
        n_classes: Width of the one-hot label vectors.
        shuffle: If true, ``on_epoch_end`` shuffles the sample order.
        dim: Target image shape ``(height, width, channels)``.

    Raises:
        ValueError: if ``img_files``/``labels`` are missing or of different
            lengths, or if ``labels`` contains more than ``n_classes``
            distinct values.
    """

    def __init__(self, img_files=None, labels=None, batch_size=32, n_classes=11,
                 shuffle=True, dim=(224, 224, 3)):
        if img_files is None or labels is None:
            raise ValueError("img_files and labels are required")
        if len(img_files) != len(labels):
            raise ValueError(
                "img_files and labels must have the same length "
                "({} != {})".format(len(img_files), len(labels)))
        self.dim = dim
        self.batch_size = batch_size
        self.img_files = img_files
        self.labels = labels
        self.n_classes = n_classes
        self.shuffle = shuffle
        # Sorting the distinct labels makes the label -> class-id mapping
        # independent of the order in which samples were listed.
        self.cls2id = {item: i for i, item in enumerate(sorted(set(self.labels)))}
        if len(self.cls2id) > self.n_classes:
            raise ValueError(
                "found {} distinct labels but n_classes={}".format(
                    len(self.cls2id), self.n_classes))
        # Initialise self.indexes before the first batch is requested.
        self.on_epoch_end()
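`DataGenerator` 继承自 `keras.utils.Sequence`。`__init__` 函数主要负责保存 `img_files`、`labels`、`batch_size` 等参数,并建立"类别名 → 类别编号"的映射 `cls2id`;
其中 `self.on_epoch_end()` 用来初始化样本索引 `self.indexes`,供后续 `__getitem__` 按 batch 取数据;Keras 在每个 epoch 结束时也会调用它,因此 `shuffle=True` 时每个 epoch 的样本顺序都会被重新打乱。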
def on_epoch_end(self):
    """Rebuild the sample index order, shuffling it when ``self.shuffle`` is set.

    Called once from ``__init__`` so ``self.indexes`` exists before the first
    batch is requested, and again at the end of every epoch.
    """
    # One position per sample; __getitem__ slices this array per batch.
    self.indexes = np.arange(len(self.labels))
    if self.shuffle:
        np.random.shuffle(self.indexes)
`__getitem__` 根据 batch 序号 `index` 取出该 batch 对应的样本索引,并返回这一 batch 的数据 `(X, y)`:
def __getitem__(self, index):
    """Return batch number ``index`` as an ``(X, y)`` pair.

    The batch covers positions ``[index * batch_size, (index + 1) * batch_size)``
    of the (possibly shuffled) ``self.indexes`` array.
    """
    start = index * self.batch_size
    stop = start + self.batch_size
    batch = self.__data_generation(self.indexes[start:stop])
    X, y = batch
    return X, y
`__data_generation` 逐个读入该 batch 的图片并做预处理(调整尺寸、像素值除以 255 归一化),同时把标签转换为 one-hot 编码:
def __data_generation(self, list_IDs_temp):
    """Load and preprocess the samples at positions ``list_IDs_temp``.

    Args:
        list_IDs_temp: Positions into ``self.img_files`` / ``self.labels``.

    Returns:
        ``(X, y)``. ``X`` has shape ``(len(list_IDs_temp), *self.dim)`` with
        pixel values divided by 255. ``y`` is the one-hot encoding of the
        labels with ``self.n_classes`` columns.

    Raises:
        OSError: if an image file cannot be read by OpenCV.
    """
    height, width = self.dim[0], self.dim[1]
    # Size the arrays by the actual number of ids so that no uninitialised
    # np.empty rows can end up in a batch.
    X = np.empty((len(list_IDs_temp), *self.dim))
    y = np.empty(len(list_IDs_temp), dtype=int)
    for i, ID in enumerate(list_IDs_temp):
        path = self.img_files[ID]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError("Could not read image: {}".format(path))
        # cv2.resize expects the target size as (width, height); follow
        # self.dim instead of a hard-coded 224x224.
        img = cv2.resize(img.astype('float'), (width, height))
        X[i,] = img / 255
        y[i] = self.cls2id[self.labels[ID]]
    return X, keras.utils.to_categorical(y, num_classes=self.n_classes)
完整训练代码如下(其中包含前面未单独展示的 `__len__` 方法,它返回每个 epoch 的 batch 数):
import keras
import numpy as np
import cv2
class DataGenerator(keras.utils.Sequence):
    """Keras `Sequence` that loads, resizes and one-hot labels image batches.

    Each item is a batch ``(X, y)``. ``X`` has shape ``(batch_size, *dim)``
    with pixel values divided by 255. ``y`` is the one-hot encoding of the
    labels with ``n_classes`` columns.

    Args:
        img_files: Sequence of image file paths.
        labels: Sequence of class labels, aligned index-by-index with
            ``img_files``.
        batch_size: Number of samples per batch.
        n_classes: Width of the one-hot label vectors.
        shuffle: If true, the sample order is reshuffled in ``on_epoch_end``.
        dim: Target image shape ``(height, width, channels)``.

    Raises:
        ValueError: if ``img_files``/``labels`` are missing or of different
            lengths, or if ``labels`` contains more than ``n_classes``
            distinct values.
    """

    def __init__(self, img_files=None, labels=None, batch_size=32, n_classes=11,
                 shuffle=True, dim=(224, 224, 3)):
        if img_files is None or labels is None:
            raise ValueError("img_files and labels are required")
        if len(img_files) != len(labels):
            raise ValueError(
                "img_files and labels must have the same length "
                "({} != {})".format(len(img_files), len(labels)))
        self.dim = dim
        self.batch_size = batch_size
        self.img_files = img_files
        self.labels = labels
        self.n_classes = n_classes
        self.shuffle = shuffle
        # Sorting the distinct labels makes the label -> class-id mapping
        # independent of the order in which samples were listed.
        self.cls2id = {item: i for i, item in enumerate(sorted(set(self.labels)))}
        if len(self.cls2id) > self.n_classes:
            raise ValueError(
                "found {} distinct labels but n_classes={}".format(
                    len(self.cls2id), self.n_classes))
        # Initialise self.indexes before the first batch is requested.
        self.on_epoch_end()

    def __len__(self):
        """Number of full batches per epoch.

        Floor division: the trailing partial batch is dropped. With
        ``shuffle`` enabled, a different subset is dropped each epoch.
        """
        return len(self.img_files) // self.batch_size

    def on_epoch_end(self):
        """Rebuild the sample index order, shuffling it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.labels))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        """Return batch number ``index`` as an ``(X, y)`` pair."""
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        X, y = self.__data_generation(indexes)
        return X, y

    def __data_generation(self, list_IDs_temp):
        """Load and preprocess the samples at positions ``list_IDs_temp``.

        Raises:
            OSError: if an image file cannot be read by OpenCV.
        """
        height, width = self.dim[0], self.dim[1]
        # Size the arrays by the actual number of ids so that no uninitialised
        # np.empty rows can end up in a batch.
        X = np.empty((len(list_IDs_temp), *self.dim))
        y = np.empty(len(list_IDs_temp), dtype=int)
        for i, ID in enumerate(list_IDs_temp):
            path = self.img_files[ID]
            img = cv2.imread(path)
            if img is None:
                # cv2.imread reports failure by returning None instead of raising.
                raise OSError("Could not read image: {}".format(path))
            # cv2.resize expects the target size as (width, height); follow
            # self.dim instead of a hard-coded 224x224.
            # NOTE(review): OpenCV loads channels in BGR order; confirm that
            # this matches what the downstream model expects.
            img = cv2.resize(img.astype('float'), (width, height))
            X[i,] = img / 255
            y[i] = self.cls2id[self.labels[ID]]
        return X, keras.utils.to_categorical(y, num_classes=self.n_classes)
import os

# Collect (image path, class name) pairs. Each sub-directory of TRAIN_DIR is
# one class, and every file inside it is one sample of that class.
TRAIN_DIR = 'dog_classes/train/'
files = []
ids = []
for item in sorted(os.listdir(TRAIN_DIR)):
    class_dir = os.path.join(TRAIN_DIR, item)
    # Skip stray files (e.g. OS metadata) that are not class directories.
    if not os.path.isdir(class_dir):
        continue
    for dog in sorted(os.listdir(class_dir)):
        files.append(os.path.join(class_dir, dog))
        ids.append(item)
from keras.applications import ResNet50
from keras.models import Sequential
from keras.layers import Dense, Flatten, GlobalAveragePooling2D, Activation, Flatten, Dropout, BatchNormalization
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import h5py
# Build the model: a ResNet50 backbone initialised from local ImageNet weights,
# followed by a small classification head.
# NOTE(review): these two imports come from the private `tensorflow.python`
# namespace and are not used below. DataGenerator only divides pixels by 255
# and does not apply preprocess_input; confirm this matches the preprocessing
# the ImageNet weights expect.
from tensorflow.python.keras.applications.resnet50 import preprocess_input
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator

image_size = 224
bs = 32
resnet_weights_path = 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'

# One softmax unit per class present in the training data. This is also the
# one-hot width passed to the generator, so the two always agree.
num_classes = len(set(ids))

my_new_model = Sequential()
my_new_model.add(ResNet50(include_top=False, pooling='avg', weights=resnet_weights_path))
my_new_model.add(Dense(512))
my_new_model.add(Activation('relu'))
my_new_model.add(Dropout(0.5))
my_new_model.add(Dense(num_classes, activation='softmax'))
my_new_model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

# Batch generator over the collected image paths and class names.
cls_gen = DataGenerator(img_files=files, labels=ids, batch_size=bs,
                        n_classes=num_classes, dim=(image_size, image_size, 3))

# steps_per_epoch = len(cls_gen) makes each epoch one pass over the full
# batches, without requesting batch indices past the end of the Sequence.
# NOTE(review): newer tf.keras deprecates fit_generator in favour of fit;
# confirm against the installed Keras version.
my_new_model.fit_generator(generator=cls_gen,
                           steps_per_epoch=len(cls_gen),
                           epochs=50)
下一篇: 中文情感分析实例---WordNet