欢迎您访问程序员文章站！本站旨在为大家分享程序员计算机编程知识！
您现在的位置是: 首页

MXNet Gluon 的 Trainer 同时接收多个网络的参数

程序员文章站 2024-03-14 20:20:29
...

参考链接：https://discuss.gluon.ai/t/topic/3087/4

下面的示例把两个网络通过 collect_params() 得到的参数合并成一个列表，传给 gluon.Trainer，从而用同一个 Trainer 同时训练这两个网络。

from mxnet import ndarray as nd
from mxnet import autograd
from mxnet import gluon

# Synthetic linear-regression data: y = 2*x0 - 3.4*x1 + 4.2 plus small
# Gaussian noise (std 0.01).
num_inputs = 2
num_examples = 1000

true_w = [2, -3.4]
true_b = 4.2

X = nd.random_normal(shape=(num_examples, num_inputs))
y = true_w[0] * X[:, 0] + true_w[1] * X[:, 1] + true_b
y += .01 * nd.random_normal(shape=y.shape)

batch_size = 10
dataset = gluon.data.ArrayDataset(X, y)
data_iter = gluon.data.DataLoader(dataset, batch_size, shuffle=True)

# First network: a single Dense layer with one output unit.
net = gluon.nn.Sequential()
net.add(gluon.nn.Dense(1))
net.initialize()

# Second network: an independent model with the same architecture.
net1 = gluon.nn.Sequential()
net1.add(gluon.nn.Dense(1))
net1.initialize()

square_loss = gluon.loss.L2Loss()

# Merge the parameters of both networks into one list so that a single
# Trainer updates all of them on every step.
lst = list(net.collect_params().values()) + list(net1.collect_params().values())
trainer = gluon.Trainer(lst, 'sgd', {'learning_rate': 0.1})

epochs = 5
for e in range(epochs):
    total_loss = 0
    total_loss1 = 0
    for data, label in data_iter:
        with autograd.record():
            output = net(data)
            loss = square_loss(output, label)
            output1 = net1(data)
            loss1 = square_loss(output1, label)
        # Backpropagate both losses in one call. The two networks share no
        # parameters, so each one's gradients come only from its own loss.
        autograd.backward([loss, loss1])
        # Normalize by the actual number of samples in this batch: the final
        # batch can be smaller than batch_size when num_examples is not a
        # multiple of it.
        trainer.step(data.shape[0])
        total_loss += nd.sum(loss).asscalar()
        total_loss1 += nd.sum(loss1).asscalar()
    print("Epoch %d, average loss: %f" % (e, total_loss/num_examples))
    print("Epoch %d, average loss1: %f" % (e, total_loss1/num_examples))

MXNet Gluon 的 Trainer 同时接收多个网络的参数