Deep learning: fixing an error in neural machine translation with attention
程序员文章站
2022-06-12 07:51:13
In this exercise, after loading the model weights, running prediction on the examples raised an error.
Here is the code:
EXAMPLES = ['3 May 1979', '5 April 09', '21th of August 2016', 'Tue 10 Jul 2007', 'Saturday May 9 2018', 'March 3 2001', 'March 3rd 2001', '1 March 2001']
for example in EXAMPLES:
    source = string_to_int(example, Tx, human_vocab)
    source = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), source))).swapaxes(0,1)
    prediction = model.predict([source, s0, c0])
    prediction = np.argmax(prediction, axis = -1)
    output = [inv_machine_vocab[int(i)] for i in prediction]
    print("source:", example)
    print("output:", ''.join(output))
Here is the error output:
ValueError Traceback (most recent call last)
<ipython-input-31-5f0a9dfb7249> in <module>()
4 source = string_to_int(example, Tx, human_vocab)
5 source = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), source))).swapaxes(0,1)
----> 6 prediction = model.predict([source, s0, c0])
7 prediction = np.argmax(prediction, axis = -1)
8 output = [inv_machine_vocab[int(i)] for i in prediction]
E:\Python\lib\site-packages\keras\engine\training.py in predict(self, x, batch_size, verbose, steps)
1815 x = _standardize_input_data(x, self._feed_input_names,
1816 self._feed_input_shapes,
-> 1817 check_batch_axis=False)
1818 if self.stateful:
1819 if x[0].shape[0] > batch_size and x[0].shape[0] % batch_size != 0:
E:\Python\lib\site-packages\keras\engine\training.py in _standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
111 ': expected ' + names[i] + ' to have ' +
112 str(len(shape)) + ' dimensions, but got array '
--> 113 'with shape ' + str(data_shape))
114 if not check_batch_axis:
115 data_shape = data_shape[1:]
ValueError: Error when checking : expected input_1 to have 3 dimensions, but got array with shape (37, 30)
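The error shows that the input has the wrong number of dimensions: the model expects a 3-D array for input_1 but received a 2-D one. The shape (37, 30) is also transposed; each example should have shape (30, 37), that is, (Tx, len(human_vocab)). The fix is to drop the swapaxes call and add a batch dimension at axis 0 instead. In the revised code, the two original lines are commented out and replaced: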
EXAMPLES = ['3 May 1979', '5 April 09', '21th of August 2016', 'Tue 10 Jul 2007', 'Saturday May 9 2018', 'March 3 2001', 'March 3rd 2001', '1 March 2001']
for example in EXAMPLES:
    source = string_to_int(example, Tx, human_vocab)
    # source = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), source))).swapaxes(0,1)
    # prediction = model.predict([source, s0, c0])
    # Do not swap the axes here; keep the (Tx, len(human_vocab)) layout.
    source = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), source)))
    # Add a new dimension at axis=0 so the array matches the 3-D input the model expects.
    ttt = np.expand_dims(source, axis=0)
    prediction = model.predict([ttt, s0, c0])
    prediction = np.argmax(prediction, axis = -1)
    output = [inv_machine_vocab[int(i)] for i in prediction]
    print("source:", example)
    print("output:", ''.join(output))
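The shape reasoning above can be checked with NumPy alone, without loading the model. The following is a minimal sketch, not the notebook's own code: it assumes Tx = 30 and len(human_vocab) = 37 (the values implied by the error message), and `one_hot` is a hypothetical stand-in for `to_categorical` that turns a sequence of Tx indices into a (Tx, vocab_size) array.

```python
import numpy as np

Tx = 30          # assumed sequence length, taken from the error message
vocab_size = 37  # assumed len(human_vocab), taken from the error message

def one_hot(indices, num_classes):
    """Hypothetical stand-in for to_categorical: (Tx,) indices -> (Tx, num_classes)."""
    return np.eye(num_classes)[np.asarray(indices)]

# Placeholder for the output of string_to_int: Tx integer indices.
indices = np.zeros(Tx, dtype=int)

source = one_hot(indices, vocab_size)
swapped = source.swapaxes(0, 1)            # what the failing code passed to predict
batched = np.expand_dims(source, axis=0)   # what the fixed code passes to predict

print(source.shape)   # (30, 37)
print(swapped.shape)  # (37, 30): 2-D and transposed, as in the ValueError
print(batched.shape)  # (1, 30, 37): batch of one example, 3-D
```

Printing the shape of the array right before `model.predict` is a quick way to confirm that it has the three dimensions the error message asks for.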