欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

神经网络----代码实现鸢尾花分类

程序员文章站 2024-03-19 16:51:40
...

运行环境:python3.7+tensorflow2.0

1.首先加载鸢尾花数据集,读入输入特征以及标签

直接采用load函数加载:

x_data = datasets.load_iris().data  # .data返回iris数据集所有输入特征
y_data = datasets.load_iris().target  # .target返回iris数据集所有标签

2.为保证准确性,对数据集进行打乱

 使用相同的seed,保证输入特征和标签一一对应

np.random.seed(116) 
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)

3. 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行

x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]

4.对x的数据类型进行,否则后面矩阵相乘时会因数据类型不一致报错,利用from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据,这里分为32组)

x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)

train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

5.定义神经网络可训练参数

w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))

6.训练函数

for epoch in range(epoch): 
    for step, (x_train, y_train) in enumerate(train_db): 
        with tf.GradientTape() as tape:  
            y = tf.matmul(x_train, w1) + b1  
            y = tf.nn.softmax(y) 
            y_ = tf.one_hot(y_train, depth=3)  
            loss = tf.reduce_mean(tf.square(y_ - y))  
            loss_all += loss.numpy()  
        grads = tape.gradient(loss, [w1, b1])

        w1.assign_sub(lr * grads[0])
        b1.assign_sub(lr * grads[1]) 

    print("Epoch {}, loss: {}".format(epoch, loss_all/4))
    train_loss_results.append(loss_all / 4) 
    loss_all = 0  

7.测试函数

 total_correct, total_number = 0, 0
    for x_test, y_test in test_db:
        y = tf.matmul(x_test, w1) + b1
        y = tf.nn.softmax(y)
        pred = tf.argmax(y, axis=1)  
        pred = tf.cast(pred, dtype=y_test.dtype)
        correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)
        correct = tf.reduce_sum(correct)
        total_correct += int(correct)
        total_number += x_test.shape[0]
    acc = total_correct / total_number
    test_acc.append(acc)
    print("Test_acc:", acc)
    print("--------------------------")

8.运行结果:

神经网络----代码实现鸢尾花分类

准确度曲线: 

神经网络----代码实现鸢尾花分类

损失函数曲线:

神经网络----代码实现鸢尾花分类

 

源码: