欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

计算机视觉(python)tensorflow2.0框架下CNN做图像分类

程序员文章站 2023-11-01 14:24:28
本文将一个完整的tf2.0框架下使用CNN模型解决图像分类问题import globimport os import cv2import numpy as npimport randomimport tensorflow as tffrom tensorflow import kerasfrom tensorflow.keras import losses,layers,optimizersfrom tensorflow.keras.callbacks import EarlyStop...

本文将一个完整的tf2.0框架下使用CNN模型解决图像分类问题 喜欢记得关注我 点收藏不迷路 辛苦整理免费分享的

import glob
import os 
import cv2
import numpy as np
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import losses,layers,optimizers
from tensorflow.keras.callbacks import EarlyStopping


tf.random.set_seed(2222)
np.random.seed(2222)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')

def Data_Generation():
    X_data=[];Y_data=[]
    path_data=[];path_label=[]

    #path_file=os.getcwd()
    files=os.listdir('pokemon')
    
    for file in files:
        print(file)
        for path in glob.glob('pokemon/'+file+'/*.*'):
            if 'jpg' or 'png' or 'jpeg' in path:
                path_data.append(path)  
            
    

    random.shuffle (path_data)  #打乱数据
   
    for paths in path_data:
        if 'bulbasaur' in paths:
            path_label.append(0)
        elif 'charmander' in paths:
            path_label.append(1)
        elif 'mewtwo' in paths:
            path_label.append(2)
        elif 'pikachu' in paths:
            path_label.append(3)
        elif 'squirtle' in paths:
            path_label.append(4)
            
        img=cv2.imread(paths)
        img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
        img=cv2.resize(img,(224,224))
        X_data.append(img)
    
    L=len(path_data)
    Y_data=path_label
    X_data=np.array(X_data,dtype=float)
    Y_data=np.array(Y_data,dtype='uint8')
    X_train=X_data[0:int(L*0.6)]
    Y_train=Y_data[0:int(L*0.6)]
    X_valid=X_data[int(L*0.6):int(L*0.8)]
    Y_valid=Y_data[int(L*0.6):int(L*0.8)]
    X_test=X_data[int(L*0.8):]
    Y_test=Y_data[int(L*0.8):]
    return X_train,Y_train,X_valid,Y_valid,X_test,Y_test,L




def normalize(x):
    img_mean = tf.constant([0.485, 0.456, 0.406])
    img_std = tf.constant([0.229, 0.224, 0.225])
    x = (x - img_mean)/img_std
    return x

def preprocess(x,y):
    x=tf.image.resize(x,[244,244])
    x=tf.image.random_flip_left_right(x)
    x=tf.image.random_crop(x,[224,224,3])
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=5)

    return x,y
 


X_train,Y_train,X_valid,Y_valid,X_test,Y_test,L=Data_Generation()

batchsz=32
#print(shape(X_data), shape(Y_data))
train_db = tf.data.Dataset.from_tensor_slices((X_train,Y_train))
train_db = train_db.shuffle(10000).map(preprocess).batch(batchsz)

valid_db = tf.data.Dataset.from_tensor_slices((X_valid,Y_valid))
valid_db = valid_db.map(preprocess).batch(batchsz)

test_db = tf.data.Dataset.from_tensor_slices((X_test,Y_test))
test_db = test_db.map(preprocess).batch(batchsz)

net=keras.applications.DenseNet121(weights='imagenet',include_top=False,pooling='max')#这里使用了自带的DenseNet121网络 你也可以用keras.Sequential DIY模型
net.trainable=False
mynet=keras.Sequential([
    net,
    layers.Dense(1024,activation='relu'),
    layers.BatchNormalization(), #BN层 标准化数据
    layers.Dropout(rate=0.2),
    layers.Dense(5)])

mynet.build(input_shape=(4,224,224,3))
mynet.summary()

early_stopping=EarlyStopping(              #防止过拟合
    monitor='val_accuracy',
    min_delta=0.01,
    patience=3)


mynet.compile(optimizer=optimizers.Adam(lr=1e-3),
              loss=losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
history  = mynet.fit(train_db, validation_data=valid_db, validation_freq=1, epochs=50,
           callbacks=[early_stopping])

history = history.history

mynet.evaluate(test_db)
#训练结束以后保存mmodel文件到本地方便做图片分类的时候直接调用
#way1 保存model成 .pb 格式 方便各个平台(移动端等)的调用
tf.saved_model.save(mynet,'densenet')  
#way2 保存model成 .h5格式 里面包含了模型结构和训练好的模型参数
mynet.save('densenet.h5')

因为设置了随机种子seed 所以每次的train validation test 集都一样,如果结果不满意的话 代码里有很多超参数可以调 

程序运行完成后就可以在代码所在的文件夹里发现多了一个densenet文件夹 打开之后如下图,pb后缀的文件就是我们保存的模型文件

计算机视觉(python)tensorflow2.0框架下CNN做图像分类

也可以生成 .h5 文件 

                                                      计算机视觉(python)tensorflow2.0框架下CNN做图像分类

生成模型后就可以用在网上下载图片来测试了,这里我用了电影大侦探皮卡丘的一张图做测试  pika pika~

                                                                                   计算机视觉(python)tensorflow2.0框架下CNN做图像分类

import tensorflow as tf
from tensorflow import keras
import cv2

label=['bulbasaur','charmander','mewtwo','pikachu','squirtle' ]

network = keras.models.load_model('densenet.h5')
network.summary()

image=cv2.imread('test.jpeg')
img=image.copy()
img=cv2.resize(img,(224,224))
img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)

def normalize(x):
    img_mean = tf.constant([0.485, 0.456, 0.406])
    img_std = tf.constant([0.229, 0.224, 0.225])
    x = (x - img_mean)/img_std
    return x

def preprocess(x):
    x = tf.expand_dims(x,axis=0)
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    return x

img=preprocess(img)

#img= tf.cast(img, dtype=tf.uint8)

result=network(img)
result=tf.nn.softmax(result)

index=tf.argmax(result,axis=-1)
print(label[int(index)])


cv2.putText(image,label[int(index)],(166,54),cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, 1.2, (255,0,0),2)
cv2.imshow('img',image)
cv2.waitKey()
cv2.destroyAllWindows()

运行代码后,就可以看到如下的效果了,图片的上方多了所属类的标签

 

                                                                                             计算机视觉(python)tensorflow2.0框架下CNN做图像分类

本文地址:https://blog.csdn.net/Oscarouyangyafei/article/details/107289684