MXNet's gluon Trainer can take the parameters of multiple networks
Reference: https://discuss.gluon.ai/t/topic/3087/4

A single gluon.Trainer can update the parameters of several networks at once: collect each network's parameters, concatenate them into one list, and pass that list to the Trainer. The example below trains two independent linear-regression networks on the same synthetic data with one Trainer.
from mxnet import ndarray as nd
from mxnet import autograd
from mxnet import gluon
# Generate a synthetic linear-regression dataset: y = X * true_w + true_b + noise
num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
X = nd.random_normal(shape=(num_examples, num_inputs))
y = true_w[0] * X[:, 0] + true_w[1] * X[:, 1] + true_b
y += .01 * nd.random_normal(shape=y.shape)
batch_size = 10
dataset = gluon.data.ArrayDataset(X, y)
data_iter = gluon.data.DataLoader(dataset, batch_size, shuffle=True)
# Define the first network
net = gluon.nn.Sequential()
net.add(gluon.nn.Dense(1))
net.initialize()
# Define the second network
net1 = gluon.nn.Sequential()
net1.add(gluon.nn.Dense(1))
net1.initialize()
square_loss = gluon.loss.L2Loss()
# Concatenate the parameters of both networks into one list for a single Trainer
lst = list(net.collect_params().values()) + list(net1.collect_params().values())
trainer = gluon.Trainer(lst, 'sgd', {'learning_rate': 0.1})
epochs = 5
for e in range(epochs):
    total_loss = 0
    total_loss1 = 0
    for data, label in data_iter:
        with autograd.record():
            output = net(data)
            loss = square_loss(output, label)
            output1 = net1(data)
            loss1 = square_loss(output1, label)
        # Backpropagate both losses in a single call
        autograd.backward([loss, loss1])
        # One step updates the parameters of both networks
        trainer.step(batch_size)
        total_loss += nd.sum(loss).asscalar()
        total_loss1 += nd.sum(loss1).asscalar()
    print("Epoch %d, average loss: %f" % (e, total_loss / num_examples))
    print("Epoch %d, average loss1: %f" % (e, total_loss1 / num_examples))
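An alternative to building a Python list is to merge the two ParameterDicts and hand the merged dict to the Trainer. A minimal sketch, assuming MXNet 1.x, where collect_params() returns a ParameterDict and its update() method copies the other network's parameters in (the names do not clash because each Sequential block gets its own prefix):

# Merge the ParameterDicts instead of concatenating lists (sketch).
params = net.collect_params()
params.update(net1.collect_params())
trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1})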