keras多输入模型
程序员文章站
2022-05-26 19:30:33
...
双输入模型的构建
from keras.models import Model
from keras import layers
from keras import Input
text_vocabulary_size=10000
question_vocabulary_size=10000
answer_vocabulary_size=5000
#参考文本的输入
text_input=Input(shape=(None,),dtype='int32',name='text')
embedded_text=layers.Embedding(text_vocabulary_size,64)(text_input)
encoded_text=layers.LSTM(32)(embedded_text)
#问题的输入
question_input=Input(shape=(None,),dtype='int32',name='question')
embedded_question=layers.Embedding(question_vocabulary_size,32)(question_input)
encoded_question=layers.LSTM(16)(embedded_question)
concatenated=layers.concatenate([encoded_text,encoded_question],axis=-1)
answer=layers.Dense(answer_vocabulary_size,activation='softmax')(concatenated)
#模型的构建
model=Model([text_input,question_input],answer)
model.compile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['acc'])
生成测试数据
import numpy as np
num_samples=1000
max_length=100
text=np.random.randint(1,text_vocabulary_size,size=(num_samples,max_length))
question=np.random.randint(1,question_vocabulary_size,size=(num_samples))
answers=np.random.randint(answer_vocabulary_size,size=(num_samples))
answers=keras.utils.to_categorical(answers,answer_vocabulary_size)
两种训练数据方式(列表,字典)
model.fit([text,question],answers,epochs=10, batch_size=128)
model.fit({'text': text, 'question': question}, answers, epochs=10, batch_size=128)
上一篇: windows 下查看端口占用