欢迎您访问程序员文章站，本站旨在为大家提供、分享程序员计算机编程知识!
您现在的位置是: 首页

TFLearn代码示例

程序员文章站 2024-03-14 14:29:04
...
import numpy as np
import tensorflow as tf
import tflearn
import tflearn.datasets.mnist as mnist
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression
from tflearn.layers.normalization import local_response_normalization
# Load MNIST with one-hot labels and reshape every flat sample into a
# 28x28 single-channel image, matching the input_data shape declared below.
x, y, x_test, y_test = mnist.load_data(one_hot=True)
x = x.reshape([-1, 28, 28, 1])
x_test = x_test.reshape([-1, 28, 28, 1])

# Network: two 5x5 conv layers (32 filters) -> 2x2 max-pool -> LRN
#          -> 5x5 conv (64 filters) -> 2x2 max-pool -> LRN
#          -> 1024-unit fully-connected layer -> dropout -> 10-way softmax.
CNN = input_data(shape=[None, 28, 28, 1], name='input')
CNN = conv_2d(CNN, 32, 5, activation='relu', regularizer='L2')
CNN = conv_2d(CNN, 32, 5, activation='relu', regularizer='L2')
CNN = max_pool_2d(CNN, 2)
CNN = local_response_normalization(CNN)
CNN = conv_2d(CNN, 64, 5, activation='relu', regularizer='L2')
CNN = max_pool_2d(CNN, 2)
CNN = local_response_normalization(CNN)
# NOTE(review): activation=None leaves this 1024-unit layer linear; confirm
# that is intended rather than e.g. 'relu'.
CNN = fully_connected(CNN, 1024, activation=None)
CNN = dropout(CNN, 0.5)
CNN = fully_connected(CNN, 10, activation='softmax')
# The layer names 'input' and 'target' are the keys of the feed dicts
# passed to fit/evaluate/predict.
CNN = regression(CNN, optimizer='adam', learning_rate=0.0001,
                 loss='categorical_crossentropy', name='target')

# Wrap the graph in a trainable model that logs to TensorBoard and writes
# checkpoints under the given paths.
model = tflearn.DNN(CNN, tensorboard_verbose=0,
                    tensorboard_dir='MNIST_tflearn_board/',
                    checkpoint_path='MNIST_tflearn_checkpoints/checkpoint')
# Train for 3 epochs, validating on the test split; snapshot_step=1000 asks
# TFLearn to snapshot (validate and save) every 1000 steps.
model.fit({'input': x}, {'target': y}, n_epoch=3,
          validation_set=({'input': x_test}, {'target': y_test}),
          snapshot_step=1000, show_metric=True, run_id='convnet_mnist')

# Evaluate the model on the test split once training has finished.
evaluation = model.evaluate({'input': x_test}, {'target': y_test})
print(evaluation)  # accuracy, e.g. [0.9852]

# Inspect the predictions: recompute accuracy by comparing the arg-max class
# of each predicted row with the arg-max of the matching one-hot label.
pred = model.predict({'input': x_test})
print((np.argmax(y_test, 1) == np.argmax(pred, 1)).mean())

#########################################################
# Read an image from a URL.
import urllib.request
url='https://timgsa.baidu.com/timg?image&quality=80&size=b9999_10000&sec=1567673429047&di=a81560189960acb86bd45ab6b3ff0821&imgtype=0&src=http%3A%2F%2Fa.hiphotos.baidu.com%2Fimage%2Fpic%2Fitem%2F0ff41bd5ad6eddc40189fc4133dbb6fd52663319.jpg'
# NOTE(review): this is a hot-linked image-proxy URL and may no longer resolve.
# The context manager closes the HTTP response once its bytes are read.
with urllib.request.urlopen(url) as response:
    im_as_string = response.read()

# The URL points at a .jpg, so decode it as JPEG with 3 (RGB) channels.
im = tf.image.decode_jpeg(im_as_string, channels=3)
# For a PNG, use this instead (tf.image.decode_image accepts either format):
# im = tf.image.decode_png(im_as_string, channels=3)

# To load images from the local disk, build a filename queue over the target
# directory and read each matching file whole.
filename_queue = tf.train.string_input_producer(
    tf.train.match_filenames_once('./images/*.jpg'))
img_reader = tf.WholeFileReader()
_, image_file = img_reader.read(filename_queue)
image = tf.image.decode_jpeg(image_file)
# NOTE: match_filenames_once creates a local variable and the filename queue
# is fed by queue runners, so evaluating `image` first needs
# tf.local_variables_initializer() and tf.train.start_queue_runners();
# only `im` is evaluated below.

with tf.Session() as sess:
    decoded = sess.run(im)
print(decoded.shape)  # (height, width, 3)


######################### Queue operations
import tensorflow as tf

# An InteractiveSession installs itself as the default session, which is what
# lets the bare `op.run()` / `tensor.eval()` calls below work.
sess = tf.InteractiveSession()

# Results are printed explicitly: outside a REPL, bare expressions are
# evaluated and silently discarded.

# FIFO queue holding at most 10 string elements.
queue1 = tf.FIFOQueue(capacity=10, dtypes=[tf.string])
enque_op = queue1.enqueue(["F"])
print(sess.run(queue1.size()))  # 0: the enqueue op has not been run yet
sess.run(enque_op)
print(sess.run(queue1.size()))  # 1

# Enqueue the remaining letters, running each op on the default session.
for letter in ['I', 'F', 'O']:
    enque_op = queue1.enqueue([letter])
    enque_op.run()
print(sess.run(queue1.size()))  # 4

# Each evaluation of a dequeue op removes one element, so one op evaluated
# four times drains the queue in FIFO order: F, I, F, O.
x = queue1.dequeue()
for _ in range(4):
    print(x.eval())




########################## Multithreading
import threading
import time

# Scalar op that produces a new normally-distributed sample each time it runs.
gen_random_normal = tf.random_normal(shape=())
# FIFO queue of at most 100 float32 scalar elements.
queue = tf.FIFOQueue(
    capacity=100,
    dtypes=[tf.float32],
    shapes=(),
)
# Every run of this op pushes one fresh sample onto `queue`.
enque = queue.enqueue(gen_random_normal)
def add():
    """Run the module-level `enque` op 10 times in the shared `sess`.

    Intended as a `threading.Thread` target: each call pushes 10 random
    samples onto `queue`.
    """
    for _ in range(10):
        sess.run(enque)

# Create a list of 10 worker threads, each running add().
threads = [threading.Thread(target=add, args=()) for _ in range(10)]
for t in threads:
    t.start()

# Sample the queue size while the workers are still enqueuing; the exact
# values depend on thread scheduling.
print(sess.run(queue.size()))
time.sleep(0.01)
print(sess.run(queue.size()))
time.sleep(0.01)
print(sess.run(queue.size()))

# Wait for every worker to finish, so the threads do not outlive the demo
# and the final size printed below is deterministic.
for t in threads:
    t.join()

# Pull 10 samples off the front of the queue with a single op.
x = queue.dequeue_many(10)
print(x.eval())
print(sess.run(queue.size()))
#############################

######################## tf.train.Coordinator: multithreading coordinator
# Build a fresh sample op, queue and enqueue op for the Coordinator demo.
gen_random_normal = tf.random_normal(shape=())
queue = tf.FIFOQueue(
    capacity=100,
    dtypes=[tf.float32],
    shapes=(),
)
enque = queue.enqueue(gen_random_normal)
def add(coord, i, stop_index=0):
    """Enqueue random samples via the module-level `enque` op until `coord` stops.

    Args:
        coord: the tf.train.Coordinator shared by all worker threads.
        i: this worker's index.
        stop_index: index of the worker that calls `coord.request_stop()`
            right after its first enqueue. It must be an index the caller
            actually hands out; otherwise no worker ever stops, and the
            threads block forever once the bounded queue is full. The default
            of 0 is the first worker started by a `range(...)`-based launcher.
    """
    while not coord.should_stop():
        sess.run(enque)
        if i == stop_index:
            # A single request_stop() makes should_stop() true for every
            # thread sharing `coord`, so all workers exit at their next check.
            coord.request_stop()
            
coord = tf.train.Coordinator()
threads = [threading.Thread(target=add, args=(coord, i)) for i in range(10)]
# Start the workers before joining: Coordinator.join() waits only on threads
# that are alive, so joining unstarted threads returns immediately.
for t in threads:
    t.start()
# Block until every worker has left its loop after a stop request.
coord.join(threads)

print(sess.run(queue.size()))

################### Using the built-in tf.train.QueueRunner
gen_random_normal = tf.random_normal(shape=())
# Shuffling queue of float32 samples; min_after_dequeue=1 keeps at least one
# element in the queue after each dequeue.
queue = tf.RandomShuffleQueue(
    capacity=100,
    dtypes=[tf.float32],
    min_after_dequeue=1,
)
enque_op = queue.enqueue(gen_random_normal)

# The QueueRunner creates one enqueuing thread per list entry (4 here).
qr = tf.train.QueueRunner(queue, [enque_op, enque_op, enque_op, enque_op])
coord = tf.train.Coordinator()
enque_threads = qr.create_threads(sess, coord=coord, start=True)
# Ask the runner's threads to stop right away, then wait for them.
coord.request_stop()
coord.join(enque_threads)

#############################################完整示例
相关标签: 代码示例