
Deep learning: analyzing an error in neural machine translation with attention

程序员文章站 2022-06-12 07:51:13

In this exercise, after loading the model weights, running prediction on the examples raised an error.


The code is as follows:

EXAMPLES = ['3 May 1979', '5 April 09', '21th of August 2016', 'Tue 10 Jul 2007', 'Saturday May 9 2018', 'March 3 2001', 'March 3rd 2001', '1 March 2001']
for example in EXAMPLES:


    source = string_to_int(example, Tx, human_vocab)
    source = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), source))).swapaxes(0,1)
    prediction = model.predict([source, s0, c0])
    prediction = np.argmax(prediction, axis = -1)
    output = [inv_machine_vocab[int(i)] for i in prediction]


    print("source:", example)
    print("output:", ''.join(output))

The error output is as follows:

ValueError                                Traceback (most recent call last)
<ipython-input-31-5f0a9dfb7249> in <module>()
      4     source = string_to_int(example, Tx, human_vocab)
      5     source = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), source))).swapaxes(0,1)
----> 6     prediction = model.predict([source, s0, c0])
      7     prediction = np.argmax(prediction, axis = -1)
      8     output = [inv_machine_vocab[int(i)] for i in prediction]

E:\Python\lib\site-packages\keras\engine\training.py in predict(self, x, batch_size, verbose, steps)
   1815         x = _standardize_input_data(x, self._feed_input_names,
   1816                                     self._feed_input_shapes,
-> 1817                                     check_batch_axis=False)
   1818         if self.stateful:
   1819             if x[0].shape[0] > batch_size and x[0].shape[0] % batch_size != 0:

E:\Python\lib\site-packages\keras\engine\training.py in _standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    111                         ': expected ' + names[i] + ' to have ' +
    112                         str(len(shape)) + ' dimensions, but got array '
--> 113                         'with shape ' + str(data_shape))
    114                 if not check_batch_axis:
    115                     data_shape = data_shape[1:]

ValueError: Error when checking : expected input_1 to have 3 dimensions, but got array with shape (37, 30)
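The error message shows that the input has the wrong number of dimensions: input_1 expects a 3-dimensional array, but the array passed in is only 2-dimensional. In addition, the shape (37, 30) is itself wrong. It should be (30, 37), that is (Tx, len(human_vocab)), because the swapaxes(0,1) call exchanges the two axes. The code is therefore modified as follows (the two commented-out lines are the original code, and the lines right after them are the replacement):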

EXAMPLES = ['3 May 1979', '5 April 09', '21th of August 2016', 'Tue 10 Jul 2007', 'Saturday May 9 2018', 'March 3 2001', 'March 3rd 2001', '1 March 2001']
for example in EXAMPLES:
    
    source = string_to_int(example, Tx, human_vocab)
    # source = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), source))).swapaxes(0,1)
    # prediction = model.predict([source, s0, c0])
    source = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), source)))  # do not swap the axes; keep the shape (30, 37)
    ttt = np.expand_dims(source, axis=0)  # add a new axis at position 0 so the array matches the 3-dimensional input the model expects
    prediction = model.predict([ttt, s0, c0])
    prediction = np.argmax(prediction, axis = -1)
    output = [inv_machine_vocab[int(i)] for i in prediction]
    
    print("source:", example)
    print("output:", ''.join(output))
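
The shape reasoning above can be checked with NumPy alone, without loading the model. The following is only a minimal sketch under some assumptions: Tx = 30 and a vocabulary size of 37 are inferred from the shape in the error message, the index array is a made-up stand-in for the output of string_to_int, and np.eye indexing is used in place of to_categorical to produce the one-hot rows.

```python
import numpy as np

Tx = 30          # sequence length, inferred from the shape in the error message
vocab_size = 37  # stand-in for len(human_vocab), inferred the same way

# Hypothetical stand-in for the output of string_to_int: Tx integer indices.
indices = np.zeros(Tx, dtype=int)

# One-hot encode each index (np.eye indexing used here instead of to_categorical).
one_hot = np.eye(vocab_size)[indices]
print("one-hot:", one_hot.shape)

# What the original code did: swap the two axes. The result is still 2-D.
swapped = one_hot.swapaxes(0, 1)
print("after swapaxes:", swapped.shape)

# The fix: leave the axes alone and add a batch axis in front.
batched = np.expand_dims(one_hot, axis=0)
print("after expand_dims:", batched.shape)
```

Only the last array has three dimensions, with a leading batch axis of size 1, which is the number of dimensions the error message says input_1 expects.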