欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  移动技术

Tensorflow:Android调用Tensorflow Mobile版本API:基于Android的调用

程序员文章站 2022-04-30 19:51:37
对上一篇博客中代码略做修改,在训练完成之后进行模型导出操作 # y = x^2 + 1 import tensorflow as tf import numpy as...

对上一篇博客中代码略做修改,在训练完成之后进行模型导出操作

# y = x^2 + 1

import tensorflow as tf
import numpy as np
import random

def get_batch(size=128):
    xs = []
    ys = []
    for i in range(size):
        x = random.random() * 2
        y = x * x + 1
        xs.append(x)
        ys.append(y)
    return np.array(xs), np.array(ys)




X = tf.placeholder(tf.float32, [None,1], name='input')
Y = tf.placeholder(tf.float32, [None,1])
def my_dnn():
    x = tf.reshape(X, shape=[-1, 1])
    w1 = tf.Variable(tf.random_normal(shape=[1,256], mean=0.0,
                                      stddev=1))
    b1 = tf.Variable(tf.random_normal([256]))
    out1 = tf.nn.bias_add(tf.matmul(x,w1),b1)
    out1 = tf.nn.relu(out1)
    w2= tf.Variable(tf.random_normal(shape=[256,256]))
    b2 = tf.Variable(tf.random_normal([256]))
    out2= tf.nn.bias_add(tf.matmul(out1, w2),b2)
    out2 = tf.nn.relu(out2)
    w3 = tf.Variable(tf.random_normal(shape=[256, 256]))
    b3 = tf.Variable(tf.random_normal([256]))
    out3 = tf.nn.bias_add(tf.matmul(out2, w3),b3)
    out3 = tf.nn.relu(out3)
    w4 = tf.Variable(tf.random_normal(shape=[256, 1]))
    b4 = tf.Variable(tf.random_normal([1]))
    out4 = tf.nn.bias_add(tf.matmul(out3, w4), b4, name='output')


    return out4
def train():
    out = my_dnn()
    loss = tf.reduce_mean(tf.square(Y - out))
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        step = 0
        while True:
            batch_x, batch_y = get_batch(64)
            batch_x = batch_x.reshape([-1, 1])
            batch_y = batch_y.reshape([-1, 1])
            _, loss_ = sess.run([optimizer, loss], feed_dict={X:batch_x, Y:batch_y})
            print(loss_)
            if loss_ < 0.0001:
                saver.save(sess, "./1.model", global_step=step)
                break
            step += 1


# train()

def eval():
    out = my_dnn()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint('.'))
        for i in range(100):
            x = random.random() * 2
            x = np.array([x]).reshape([-1,1])
            y = sess.run(out, feed_dict={X:x})
            print("x=%.5f 正确的y=%.5f 预测的 y=%.5f" % (x, x*x + 1, y))
def exportModel():
    out = my_dnn()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # 恢复模型参数
        saver.restore(sess, tf.train.latest_checkpoint('.'))
        from tensorflow.python.framework.graph_util import convert_variables_to_constants
        output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])
        with tf.gfile.FastGFile('1.pb', mode='wb') as f:
            f.write(output_graph_def.SerializeToString())

if __name__ == '__main__':
    # 训练
    # train()

    # 评估
    # eval()

    # 导出模型
    exportModel()

新建一个Android项目
Tensorflow:Android调用Tensorflow Mobile版本API:基于Android的调用
导入tensorflow-mobile的库
可以选择导在线的库,在这里导入离线的库
Tensorflow:Android调用Tensorflow Mobile版本API:基于Android的调用
我添加1.6.0版本的,修改了gradle文件,完成了添加

Tensorflow:Android调用Tensorflow Mobile版本API:基于Android的调用

添加模型文件
Tensorflow:Android调用Tensorflow Mobile版本API:基于Android的调用

编写tensorflow mobile API的封装
Tensorflow:Android调用Tensorflow Mobile版本API:基于Android的调用

最后在Activity调就可以了
Tensorflow:Android调用Tensorflow Mobile版本API:基于Android的调用