神经网络----代码实现鸢尾花分类
程序员文章站
2024-03-19 16:51:40
...
运行环境:python3.7+tensorflow2.0
1.首先加载鸢尾花数据集,读入输入特征以及标签
直接采用load函数加载:
x_data = datasets.load_iris().data # .data返回iris数据集所有输入特征
y_data = datasets.load_iris().target # .target返回iris数据集所有标签
2.为保证准确性,对数据集进行打乱
使用相同的seed,保证输入特征和标签一一对应
np.random.seed(116)
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)
3. 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
4.对x的数据类型进行,否则后面矩阵相乘时会因数据类型不一致报错,利用from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据,这里分为32组)
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
5.定义神经网络可训练参数
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))
6.训练函数
for epoch in range(epoch):
for step, (x_train, y_train) in enumerate(train_db):
with tf.GradientTape() as tape:
y = tf.matmul(x_train, w1) + b1
y = tf.nn.softmax(y)
y_ = tf.one_hot(y_train, depth=3)
loss = tf.reduce_mean(tf.square(y_ - y))
loss_all += loss.numpy()
grads = tape.gradient(loss, [w1, b1])
w1.assign_sub(lr * grads[0])
b1.assign_sub(lr * grads[1])
print("Epoch {}, loss: {}".format(epoch, loss_all/4))
train_loss_results.append(loss_all / 4)
loss_all = 0
7.测试函数
total_correct, total_number = 0, 0
for x_test, y_test in test_db:
y = tf.matmul(x_test, w1) + b1
y = tf.nn.softmax(y)
pred = tf.argmax(y, axis=1)
pred = tf.cast(pred, dtype=y_test.dtype)
correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)
correct = tf.reduce_sum(correct)
total_correct += int(correct)
total_number += x_test.shape[0]
acc = total_correct / total_number
test_acc.append(acc)
print("Test_acc:", acc)
print("--------------------------")
8.运行结果:
准确度曲线:
损失函数曲线:
源码:
上一篇: C语言 贪吃蛇
推荐阅读
-
神经网络----代码实现鸢尾花分类
-
PHP 实现301转向代码 博客分类: PHP
-
PHP 实现301转向代码 博客分类: PHP
-
实现一个sizeof获取Java对象大小 博客分类: 代码琐记 Java对象大小内存sizeofhotspot
-
实现Maven自动下载源代码包并关联 博客分类: Maven Maven关联源代码查看源代码maven插件源代码
-
BP神经网络在肺癌分类中应用_附matlab代码 博客分类: 模式识别,人工智能 人工神经网络BPBP神经网络算法matlab
-
深度学习-BP神经网络(python3代码实现)
-
自定义MD5加盐加密方式代码实现 博客分类: java随笔 Securityjava
-
CI框架无限级分类+递归的实现代码
-
PHP从二维数组得到N层分类树的实现代码