欢迎您访问程序员文章站！本站旨在为大家分享程序员计算机编程知识。
您现在的位置是: 首页

keras多输入模型

程序员文章站 2022-05-26 19:30:33
...

keras多输入模型
双输入模型的构建

from keras.models import Model
from keras import layers
from keras import Input

# Vocabulary sizes for the reference text, the question, and the answer.
text_vocabulary_size = 10000
question_vocabulary_size = 10000
answer_vocabulary_size = 5000

# Branch 1: the reference text. A variable-length sequence of int32 token ids
# is embedded into 64-d vectors, then encoded by a 32-unit LSTM.
text_input = Input(shape=(None,), dtype='int32', name='text')
embedded_text = layers.Embedding(input_dim=text_vocabulary_size, output_dim=64)(text_input)
encoded_text = layers.LSTM(units=32)(embedded_text)

# Branch 2: the question. Same structure, with a 32-d embedding and a
# 16-unit LSTM.
question_input = Input(shape=(None,), dtype='int32', name='question')
embedded_question = layers.Embedding(input_dim=question_vocabulary_size, output_dim=32)(question_input)
encoded_question = layers.LSTM(units=16)(embedded_question)

# Join both encodings along the last axis (32 + 16 features), then predict a
# softmax distribution over the answer vocabulary.
concatenated = layers.concatenate([encoded_text, encoded_question], axis=-1)
answer = layers.Dense(units=answer_vocabulary_size, activation='softmax')(concatenated)

# Build the two-input model. The order of `inputs` is the order expected
# when training data is passed to fit() as a list.
model = Model(inputs=[text_input, question_input], outputs=answer)
model.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['acc'],
)

生成测试数据

import numpy as np
# The original called keras.utils.to_categorical without ever importing
# `keras` itself (only Model/layers/Input were imported), which raises NameError.
from keras.utils import to_categorical

num_samples = 1000
max_length = 100

# Random token ids for the reference text, shaped (num_samples, max_length).
# Ids start at 1, presumably to keep 0 free (e.g. for padding) — confirm intent.
text = np.random.randint(1, text_vocabulary_size, size=(num_samples, max_length))
# The question input is declared as Input(shape=(None,)), i.e. one token
# sequence per sample, so it must be 2-D like `text`. The original 1-D shape
# (num_samples,) did not match that declaration.
question = np.random.randint(1, question_vocabulary_size, size=(num_samples, max_length))
# One integer answer class per sample in [0, answer_vocabulary_size), one-hot
# encoded to match the softmax output and categorical_crossentropy loss.
answers = np.random.randint(answer_vocabulary_size, size=(num_samples,))
answers = to_categorical(answers, answer_vocabulary_size)

两种传入训练数据的方式（列表、字典）

# Way 1: a list of arrays, in the same order as the model's inputs
# (Model([text_input, question_input], ...)).
list_inputs = [text, question]
model.fit(list_inputs, answers, epochs=10, batch_size=128)

# Way 2: a dict keyed by the Input layer names ('text' and 'question').
# Note: this trains the same `model` object again, continuing from the
# weights left by the call above rather than starting over.
named_inputs = {'text': text, 'question': question}
model.fit(named_inputs, answers, epochs=10, batch_size=128)