欢迎您访问程序员文章站!本站旨在为大家提供并分享程序员计算机编程知识。
您现在的位置是: 首页

keras如何创建自己的generator

程序员文章站 2022-06-12 16:57:44
...

keras.utils.Sequence 是一个基类

class Sequence(object):
    """Base class for feeding a model from an indexable collection of batches.

    Subclasses have to provide `__getitem__` (return the complete batch at a
    given position) and `__len__` (return how many batches exist). Override
    `on_epoch_end` when the data should change between epochs.
    # Notes
    A `Sequence` is the safer choice for multiprocessing: every sample is
    visited exactly once per epoch, which plain generators do not guarantee.
    # Examples
    ```python
        from skimage.io import imread
        from skimage.transform import resize
        import numpy as np
        # `x_set` is a list of image paths, `y_set` the matching classes.
        class CIFAR10Sequence(Sequence):
            def __init__(self, x_set, y_set, batch_size):
                self.x, self.y = x_set, y_set
                self.batch_size = batch_size
            def __len__(self):
                return int(np.ceil(len(self.x) / float(self.batch_size)))
            def __getitem__(self, idx):
                batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
                batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
                return np.array([
                    resize(imread(file_name), (200, 200))
                       for file_name in batch_x]), np.array(batch_y)
    ```
    """

    use_sequence_api = True

    @abstractmethod
    def __getitem__(self, index):
        """Return the batch stored at `index`.
        # Arguments
            index: position of the requested batch within the Sequence.
        # Returns
            A batch
        """
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        """Return how many batches the Sequence holds.
        # Returns
            The number of batches in the Sequence.
        """
        raise NotImplementedError

    def on_epoch_end(self):
        """Hook invoked after each epoch; does nothing by default."""

    def __iter__(self):
        """Yield every batch of the Sequence in positional order."""
        for position in range(len(self)):
            yield self[position]

所以我们需要实现它的两个方法:`__len__` 和 `__getitem__`。

import keras
import numpy as np
import cv2


class DataGenerator(keras.utils.Sequence):
    """Keras `Sequence` that serves batches built from image paths and labels."""

    def __init__(self, img_files=None, labels=None, batch_size=32, n_classes=11,
                 shuffle=True, dim=(224, 224, 3)):
        """Store the dataset description and prepare the first epoch's order.

        # Arguments
            img_files: sequence of image file paths.
            labels: sequence of class labels, one per entry of `img_files`.
            batch_size: number of samples per batch.
            n_classes: number of classes used for the one-hot encoding.
            shuffle: whether the sample order is reshuffled on every epoch.
            dim: shape of one image.
        # Raises
            ValueError: if `img_files` or `labels` is missing, or if their
                lengths differ.
        """
        # Run the base-class initializer; where the Keras base class defines
        # none, this falls through to `object.__init__` and does nothing.
        super().__init__()
        if img_files is None or labels is None:
            raise ValueError('img_files and labels must both be provided')
        if len(img_files) != len(labels):
            raise ValueError(
                'img_files and labels must have the same length, got {} and {}'
                .format(len(img_files), len(labels)))
        self.dim = dim
        self.batch_size = batch_size
        self.img_files = img_files
        self.labels = labels
        self.n_classes = n_classes
        self.shuffle = shuffle
        # Map every distinct label to an integer id; sorting keeps the mapping
        # stable from run to run.
        self.cls2id = {item: i for i, item in enumerate(sorted(set(self.labels)))}
        # Build the index order up front so it exists before the first batch
        # is requested.
        self.on_epoch_end()
        

从keras.utils.Sequence类继承,__init__函数主要负责初始化img_files,labels等

其中self.on_epoch_end()主要用来初始化index,用于generator后续索引

    def on_epoch_end(self):
        """Rebuild `self.indexes`, one entry per label, shuffled if requested."""
        order = np.arange(len(self.labels))
        if self.shuffle == True:
            # Shuffles in place, so `order` itself is permuted.
            np.random.shuffle(order)
        self.indexes = order

__getitem__用来得到每一个batch的数据 

    def __getitem__(self,index):
        """Return the `(X, y)` batch at position `index`.

        # Arguments
            index: position of the batch in the Sequence.
        # Returns
            The pair produced by `__data_generation` for this batch's slice
            of `self.indexes`.
        """
        start = index * self.batch_size
        stop = (index + 1) * self.batch_size
        batch_ids = self.indexes[start:stop]
        images, targets = self.__data_generation(batch_ids)
        return images, targets

__data_generation主要是迭代读入图片,对图片预处理

    def __data_generation(self,list_IDs_temp):
        """Load and preprocess the images of one batch and encode their labels.

        # Arguments
            list_IDs_temp: positions into `self.img_files` / `self.labels`
                of the samples that make up the batch.
        # Returns
            A tuple `(X, y)`. `X` holds the resized images divided by 255,
            one row per requested sample; `y` is the one-hot encoding of the
            class ids with `self.n_classes` columns.
        # Raises
            IOError: if an image file cannot be read.
        """
        n_samples = len(list_IDs_temp)
        # Size the buffers by the samples actually requested so a short batch
        # never carries uninitialised rows.
        X = np.empty((n_samples, *self.dim))
        y = np.empty(n_samples, dtype=int)
        # cv2.resize takes (width, height) while `dim` is (height, width,
        # channels), the layout of the arrays cv2.imread returns.
        target_size = (self.dim[1], self.dim[0])
        for i, ID in enumerate(list_IDs_temp):
            path = self.img_files[ID]
            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise IOError('could not read image: {}'.format(path))
            img = cv2.resize(img.astype('float'), target_size)
            X[i] = img / 255
            y[i] = self.cls2id[self.labels[ID]]
        return X, keras.utils.to_categorical(y, num_classes=self.n_classes)

完整训练代码如下

import keras
import numpy as np
import cv2


class DataGenerator(keras.utils.Sequence):
    """Keras `Sequence` that reads images from disk and serves labelled batches.

    Every batch is a tuple `(X, y)`: `X` holds the resized images divided by
    255 and `y` the one-hot encoded class ids.
    """

    def __init__(self, img_files=None, labels=None, batch_size=32, n_classes=11,
                 shuffle=True, dim=(224, 224, 3)):
        """Store the dataset description and prepare the first epoch's order.

        # Arguments
            img_files: sequence of image file paths.
            labels: sequence of class labels, one per entry of `img_files`.
            batch_size: number of samples per batch.
            n_classes: number of classes used for the one-hot encoding.
            shuffle: whether the sample order is reshuffled on every epoch.
            dim: shape of one image as (height, width, channels).
        # Raises
            ValueError: if `img_files` or `labels` is missing, or if their
                lengths differ.
        """
        # Run the base-class initializer; where the Keras base class defines
        # none, this falls through to `object.__init__` and does nothing.
        super().__init__()
        if img_files is None or labels is None:
            raise ValueError('img_files and labels must both be provided')
        if len(img_files) != len(labels):
            raise ValueError(
                'img_files and labels must have the same length, got {} and {}'
                .format(len(img_files), len(labels)))
        self.dim = dim
        self.batch_size = batch_size
        self.img_files = img_files
        self.labels = labels
        self.n_classes = n_classes
        self.shuffle = shuffle
        # Map every distinct label to an integer id; sorting keeps the mapping
        # stable from run to run.
        self.cls2id = {item: i for i, item in enumerate(sorted(set(self.labels)))}
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches; a trailing partial batch is dropped."""
        return len(self.img_files) // self.batch_size

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when `self.shuffle` is set."""
        self.indexes = np.arange(len(self.labels))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        """Return the `(X, y)` batch at position `index`.

        # Arguments
            index: position of the batch in the Sequence.
        """
        start = index * self.batch_size
        batch_ids = self.indexes[start:start + self.batch_size]
        return self.__data_generation(batch_ids)

    def __data_generation(self, list_IDs_temp):
        """Load and preprocess the images of one batch and encode their labels.

        # Arguments
            list_IDs_temp: positions into `self.img_files` / `self.labels`
                of the samples that make up the batch.
        # Returns
            A tuple `(X, y)` with one row per requested sample.
        # Raises
            IOError: if an image file cannot be read.
        """
        n_samples = len(list_IDs_temp)
        # Size the buffers by the samples actually requested so a short batch
        # never carries uninitialised rows.
        X = np.empty((n_samples, *self.dim))
        y = np.empty(n_samples, dtype=int)
        # cv2.resize takes (width, height) while `dim` is (height, width,
        # channels), the layout of the arrays cv2.imread returns.
        target_size = (self.dim[1], self.dim[0])
        for i, ID in enumerate(list_IDs_temp):
            path = self.img_files[ID]
            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise IOError('could not read image: {}'.format(path))
            img = cv2.resize(img.astype('float'), target_size)
            X[i] = img / 255
            y[i] = self.cls2id[self.labels[ID]]
        return X, keras.utils.to_categorical(y, num_classes=self.n_classes)


import os

# Collect every training image path together with its class name; the class
# name is the name of the sub-directory the image is stored in. The two lists
# are filled in lockstep so that files[k] always belongs to ids[k].
files = []
ids = []
for item in os.listdir('dog_classes/train/'):
    for dog in os.listdir('dog_classes/train/'+item):
        files.append('dog_classes/train/'+item+'/'+dog)
        ids.append(item)

from keras.applications import ResNet50
from keras.models import Sequential
from keras.layers import Dense, Flatten, GlobalAveragePooling2D, Activation, Flatten, Dropout, BatchNormalization
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import h5py
from tensorflow.python.keras.applications.resnet50 import preprocess_input
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator

image_size = 224
bs = 32

# Build the model once: a ResNet50 backbone without its classification head,
# initialised from the local ImageNet weights file, followed by a small dense
# classifier.
resnet_weights_path='resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
my_new_model = Sequential()
my_new_model.add(ResNet50(include_top=False, pooling='avg', weights=resnet_weights_path))
my_new_model.add(Dense(512))
my_new_model.add(Activation('relu'))
my_new_model.add(Dropout(0.5))
# The softmax width has to equal the number of classes the data generator
# one-hot encodes (11 by default).
my_new_model.add(Dense(11, activation='softmax'))


my_new_model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

# Wrap the collected image paths and class names in the batch generator. The
# class count passed here has to match the width of the model's softmax layer.
cls_gen = DataGenerator(img_files=files, labels=ids, batch_size=32, n_classes=11)

# Train once: 50 epochs of 40 batches each.
my_new_model.fit_generator(generator=cls_gen,
                           steps_per_epoch=40,
                           epochs=50)

 

相关标签: AI