self.params的使用注意点
程序员文章站
2022-03-11 21:14:01
from mxnet.gluon import nnfrom mxnet import ndclass MyDense(nn.HybridBlock): def __init__(self, units, in_units, **kwargs): super().__init__(**kwargs) self.embedding = nn.Embedding(3, 5) self.weight = self.params.get('weight....
from mxnet.gluon import nn
from mxnet import nd


class MyDense(nn.HybridBlock):
    """Dense layer followed by ReLU, with parameters declared via ``self.params``.

    For a ``HybridBlock``, Gluon passes the parameters registered through
    ``self.params.get`` into ``hybrid_forward`` as keyword arguments
    (here ``weight`` and ``bias``), so they appear in the signature below.
    """

    def __init__(self, units, in_units, **kwargs):
        super().__init__(**kwargs)
        # Registered as a child block but not used in the forward computation.
        self.embedding = nn.Embedding(3, 5)
        self.weight = self.params.get('weight', shape=(in_units, units))
        self.bias = self.params.get('bias', shape=(units,))

    def hybrid_forward(self, F, x, weight, bias):
        # F is the operator namespace Gluon supplies (nd when run imperatively).
        linear = F.dot(x, weight) + bias
        return F.relu(linear)


dense = MyDense(units=3, in_units=5)
dense.initialize()
print(dense(nd.random.uniform(shape=(2, 5))))
print('hybrid_forward success!')


class MyDense(nn.Block):
    """The same layer as a plain ``Block``, deliberately left broken for the demo.

    ``Block.forward`` only receives the inputs passed at call time. It is NOT
    given the registered parameters, so ``dense(x)`` fails with a TypeError
    about the missing ``weight`` and ``bias`` arguments. The working form is
    ``def forward(self, x)``, which reads the values explicitly:
    ``nd.dot(x, self.weight.data(ctx=x.context)) + self.bias.data(ctx=x.context)``.
    """

    def __init__(self, units, in_units, **kwargs):
        super().__init__(**kwargs)
        # Registered as a child block but not used in the forward computation.
        self.embedding = nn.Embedding(3, 5)
        self.weight = self.params.get('weight', shape=(in_units, units))
        self.bias = self.params.get('bias', shape=(units,))

    def forward(self, x, weight, bias):
        linear = nd.dot(x, weight) + bias
        return nd.relu(linear)


dense = MyDense(units=3, in_units=5)
dense.initialize()
try:
    print(dense(nd.random.uniform(shape=(2, 5))))
except TypeError as e:
    # Expected: forward() is called with only x, so weight/bias are missing.
    print('some error:', e)
    print('forward fail!')
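结果(HybridBlock 的 hybrid_forward 会自动把通过 self.params.get 注册的参数 weight、bias 作为关键字参数传入;而普通 Block 的 forward 只接收调用时传入的输入 x,因此会报缺少 weight 和 bias 两个参数的错误。在 Block 中应改用 self.weight.data()、self.bias.data() 获取参数值):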
[[0. 0.14012612 0.0058622 ]
[0. 0.12333627 0.063691 ]]
<NDArray 2x3 @cpu(0)>
hybrid_forward success!
some error: forward() missing 2 required positional arguments: 'weight' and 'bias'
forward fail!
本文地址:https://blog.csdn.net/sinat_24395003/article/details/109644137
下一篇: java实现扫雷小游戏