python运用深度学习生活垃圾分类,acc达到94以上
程序员文章站
2024-03-19 16:24:46
...
python运用深度学习生活垃圾分类,acc达到94以上
导入库
%matplotlib inline
import yaml
import sys,time
import string
import json
from tensorflow.python.keras.models import model_from_json
from tensorflow.python.keras.models import model_from_yaml
import pylab
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import PIL
import tensorflow as tf
import numpy as np
import os
import scipy.io as io
from multiprocessing import Pool
from tensorflow.python.keras.callbacks import ModelCheckpoint,ReduceLROnPlateau,EarlyStopping
from tensorflow.python.keras.preprocessing import image
from PIL import Image
import matplotlib.image as mpimg
from tensorflow.python.keras.callbacks import TensorBoard
from tensorflow.python.keras import layers
from tensorflow.python.keras.wrappers.scikit_learn import KerasClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, KFold, cross_val_score
from tensorflow.python.keras.utils import np_utils
from tensorflow.python.keras.utils import to_categorical
from tensorflow.python.keras.initializers import glorot_normal
from tensorflow.python.keras.applications import resnet50
from tensorflow.python.keras.applications.xception import Xception
from tensorflow.python.keras.models import Model, Sequential,load_model
from tensorflow.python.keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, GlobalAveragePooling2D,Flatten,AveragePooling2D,add,BatchNormalization,Convolution2D,ZeroPadding2D,Reshape,Activation, Dense,Lambda,Conv2D,MaxPool2D, Flatten, Dropout,MaxPooling2D,Dense,concatenate,GlobalAveragePooling2D
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.optimizers import Adam, RMSprop,Adagrad,Nadam,SGD
from tensorflow.python.keras import backend, initializers,regularizers,callbacks,Input,optimizers
训练集和测试集目录
train_dir = '/media/design/0D7907FA0D7907FA/数据集3/garbage_classify_v2/data/train/'
test_dir = '/media/design/0D7907FA0D7907FA/数据集3/garbage_classify_v2/data/test/'
数据预处理
Keras 用一个称为数据生成器(data-generator)输入数据到神经网络中,它将在整个数据上循环。
我们有一个小型的训练集所以它通过对图像进行各种变换来人为地增加它的数量。我们使用内置的数据生成器,可以进行这些随机转换。这也被称为数据增强。
datagen_train = ImageDataGenerator(
rescale=1./255 ,
#rotation_range=180,
#width_shift_range=0.1,
#height_shift_range=0.1,
#shear_range=0.1,
#zoom_range=[0.9, 1.5],
#horizontal_flip=True,
#vertical_flip=True,
#fill_mode='nearest'
datagen_test = ImageDataGenerator(rescale=1./255)
现在我们创建数据生成器的实例,它将从硬盘中读取文件,缩放图像并返回随机的批。
generator_train = datagen_train.flow_from_directory(directory=train_dir,
target_size=(299,299),
batch_size=batch_size,
shuffle=True,
class_mode='categorical',
save_to_dir=save_to_dir)
generator_test = datagen_test.flow_from_directory(directory=test_dir,
target_size=(299,299),
batch_size=batch_size,
class_mode='categorical',
shuffle=False)
cls_train = generator_train.classes
cls_test = generator_test.class_indices
generator_train.class_indices
画一些图像检测数据是否正确
建立模型: Xception-v3
训练后导入训练后的模型预测
测试集分类acc达到94.83%(q568897492)
上一篇: 生气!面试官你过来,我给你手写一个SpringAop!
下一篇: 程序员彩虹屁指南