Deep Learning: "Transfer Learning"
程序员文章站
2024-03-15 11:39:47
Code:
from keras.applications import ResNet50
from keras.models import Sequential
from keras.layers import Dense, Flatten, GlobalAveragePooling2D
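num_classes = 2  # number of target classes
# Path to a local "notop" weights file. It is not used below, because
# weights='imagenet' makes Keras load the pretrained ImageNet weights itself.
resnet_weights_path = 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'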
my_new_model = Sequential()
my_new_model.add(ResNet50(include_top=False, pooling='avg', weights='imagenet'))
my_new_model.add(Dense(num_classes, activation='softmax'))
# Freeze the first layer (the pretrained ResNet50 base); only the new Dense layer is trained.
my_new_model.layers[0].trainable = False
# Compile the Sequential model with SGD, categorical cross-entropy loss and an accuracy metric.
my_new_model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
from keras.applications.resnet50 import preprocess_input
from keras.preprocessing.image import ImageDataGenerator
image_size = 224
# The following line raised an error, so the generator is created without preprocessing:
# data_generator = ImageDataGenerator(preprocessing_function=preprocess_input)
data_generator = ImageDataGenerator()
train_generator = data_generator.flow_from_directory(
    directory='..\\Using Transfer Learning\\images\\train',
    target_size=(image_size, image_size),
    shuffle=True,
    batch_size=22,
    class_mode='categorical')
print('classes of train_generator:', train_generator.class_indices)
validation_generator = data_generator.flow_from_directory(
    directory='..\\Using Transfer Learning\\images\\val',
    target_size=(image_size, image_size),
    class_mode='categorical')
print('classes of validation_generator:', validation_generator.class_indices)
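# Train for one epoch of 4 batches, then evaluate on 6 validation batches.
# Note: fit_generator is deprecated in newer Keras releases, where model.fit accepts generators directly.
my_new_model.fit_generator(
    train_generator,
    epochs=1,
    steps_per_epoch=4,
    validation_data=validation_generator,
    validation_steps=6)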
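The `class_indices` dictionaries printed in the code above come from the directory layout: `flow_from_directory` treats each subdirectory as one class and numbers the classes in alphanumeric order of the subdirectory names. The following is a minimal sketch of that mapping using only the standard library. The helper name `infer_class_indices` and the folder names `rural` and `urban` are hypothetical, chosen for illustration, and this is not the Keras implementation itself.

```python
import os
import tempfile


def infer_class_indices(directory):
    """Map each subdirectory name to an integer label, in sorted order.

    Illustrative helper (hypothetical name) that mimics how class labels
    are derived from folder names; it is not the Keras implementation.
    """
    class_names = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: index for index, name in enumerate(class_names)}


# Build a throwaway train/ folder containing two hypothetical class folders.
with tempfile.TemporaryDirectory() as root:
    train_dir = os.path.join(root, "train")
    for class_name in ("urban", "rural"):
        os.makedirs(os.path.join(train_dir, class_name))
    class_indices = infer_class_indices(train_dir)

print(class_indices)  # {'rural': 0, 'urban': 1}
```

With `class_mode='categorical'`, each of these integer labels is turned into a one-hot vector, so two class folders give label vectors of length 2, matching the `Dense(num_classes, activation='softmax')` output layer.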
Result:
Environment configuration file:
https://pan.baidu.com/s/1fBzSbJekdorXo7ZRGSZzig tzvm
References:
https://www.kaggle.com/dansbecker/exercise-using-transfer-learning/notebook
https://www.kaggle.com/dansbecker/transfer-learning/notebook
https://keras-cn.readthedocs.io/en/latest/preprocessing/image/#imagedatagenerator
https://keras-cn.readthedocs.io/en/latest/models/sequential/#fit_generator