欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

python运用深度学习生活垃圾分类,acc达到94以上

程序员文章站 2024-03-19 16:24:46
...

导入库

%matplotlib inline
import yaml
import sys,time
import string
import json
from tensorflow.python.keras.models import model_from_json
from tensorflow.python.keras.models import model_from_yaml
import pylab
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import PIL
import tensorflow as tf
import numpy as np
import os 
import scipy.io as io
from multiprocessing import Pool
from tensorflow.python.keras.callbacks import ModelCheckpoint,ReduceLROnPlateau,EarlyStopping
from tensorflow.python.keras.preprocessing import image
from PIL import Image
import matplotlib.image as mpimg
from tensorflow.python.keras.callbacks import TensorBoard
from tensorflow.python.keras import layers
from tensorflow.python.keras.wrappers.scikit_learn import KerasClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, KFold, cross_val_score
from tensorflow.python.keras.utils import np_utils
from tensorflow.python.keras.utils import to_categorical
from tensorflow.python.keras.initializers  import glorot_normal
from tensorflow.python.keras.applications import  resnet50
from tensorflow.python.keras.applications.xception  import Xception 
from tensorflow.python.keras.models import Model, Sequential,load_model
from tensorflow.python.keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, GlobalAveragePooling2D,Flatten,AveragePooling2D,add,BatchNormalization,Convolution2D,ZeroPadding2D,Reshape,Activation, Dense,Lambda,Conv2D,MaxPool2D, Flatten, Dropout,MaxPooling2D,Dense,concatenate,GlobalAveragePooling2D
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.optimizers import Adam, RMSprop,Adagrad,Nadam,SGD
from tensorflow.python.keras import backend, initializers,regularizers,callbacks,Input,optimizers

训练集和测试集目录

train_dir = '/media/design/0D7907FA0D7907FA/数据集3/garbage_classify_v2/data/train/'
test_dir = '/media/design/0D7907FA0D7907FA/数据集3/garbage_classify_v2/data/test/'

数据预处理

Keras 用一个称为数据生成器(data-generator)输入数据到神经网络中,它将在整个数据上循环。
我们有一个小型的训练集所以它通过对图像进行各种变换来人为地增加它的数量。我们使用内置的数据生成器,可以进行这些随机转换。这也被称为数据增强。

datagen_train = ImageDataGenerator(
      rescale=1./255  ,
      #rotation_range=180,
      #width_shift_range=0.1,
      #height_shift_range=0.1,
      #shear_range=0.1,
      #zoom_range=[0.9, 1.5], 
      #horizontal_flip=True,
      #vertical_flip=True,
      #fill_mode='nearest'
datagen_test = ImageDataGenerator(rescale=1./255)

现在我们创建数据生成器的实例,它将从硬盘中读取文件,缩放图像并返回随机的批。

generator_train = datagen_train.flow_from_directory(directory=train_dir,
                                                    target_size=(299,299),
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    class_mode='categorical',
                                                    save_to_dir=save_to_dir)
generator_test = datagen_test.flow_from_directory(directory=test_dir,
                                                  target_size=(299,299),
                                                  batch_size=batch_size,
                                                  class_mode='categorical',
                                                  shuffle=False)
cls_train = generator_train.classes
cls_test = generator_test.class_indices

generator_train.class_indices

画一些图像检测数据是否正确

python运用深度学习生活垃圾分类,acc达到94以上

建立模型: Xception-v3

python运用深度学习生活垃圾分类,acc达到94以上

训练后导入训练后的模型预测

python运用深度学习生活垃圾分类,acc达到94以上
python运用深度学习生活垃圾分类,acc达到94以上
python运用深度学习生活垃圾分类,acc达到94以上
python运用深度学习生活垃圾分类,acc达到94以上

测试集分类acc达到94.83%(q568897492)