您现在的位置是: 首页

180223 "Merge" versus "merge", what is the difference?

程序员文章站 2022-06-02 22:49:35

Would someone explain the usage of concat_axis, dot_axis and output_shape in merge layer? #2626

180223 "Merge" versus "merge", what is the difference?

# Compare three Keras 1.2.0 `merge` modes (concat, dot, cos) on one pair of
# (1, 1, 3) inputs. Keras 2.x equivalents are listed in the comments below.
from keras.layers import merge  # functional merge helper; Keras 1.x only (removed in 2.0)
from keras.layers import Input
from keras.models import Model
import numpy as np

# One sample each, shaped (batch=1, 1, 3) to match Input(shape=(1, 3)).
input_a = np.reshape([1, 2, 3], (1, 1, 3))
input_b = np.reshape([4, 5, 6], (1, 1, 3))

a = Input(shape=(1, 3))
b = Input(shape=(1, 3))

# keras 1.2.0
concat = merge([a, b], mode='concat', concat_axis=2)  # join feature axes -> (None, 1, 6)
dot = merge([a, b], mode='dot', dot_axes=(1, 1))      # contract the size-1 axis -> (None, 3, 3)
cos = merge([a, b], mode='cos', dot_axes=2)           # cosine over the feature axis -> (None, 1, 1)

# keras 2.0.x (the `merge` helper is gone; use the per-operation functions)
# concat = keras.layers.concatenate([a, b], axis=2)
# dot = keras.layers.dot([a, b], axes=(1, 1))
# cos = keras.layers.dot([a, b], axes=2, normalize=True)  # there is no keras.layers.cos
# ...and build models with Model(inputs=[a, b], outputs=...)

model_concat = Model(input=[a, b], output=concat)
model_dot = Model(input=[a, b], output=dot)
model_cos = Model(input=[a, b], output=cos)

print(model_concat.predict([input_a, input_b]))
print(model_dot.predict([input_a, input_b]))
print(model_cos.predict([input_a, input_b]))

“Merge” versus “merge”, what is the difference?
180223 "Merge" versus "merge", what is the difference?

A second example: merging two inputs of different shapes, (4, 3) and (3, 3), with `concat` and `dot`.

# Code
# Concatenate and dot-merge a (1, 4, 3) tensor with a (1, 3, 3) tensor using
# the Keras 1.2.0 functional `merge` API.
from keras.layers import merge  # Keras 1.x only; `keras.layers.dot` does not exist before 2.0
from keras.layers import Input
from keras.models import Model
import numpy as np

input_a = np.reshape(np.arange(12), (-1, 4, 3))  # shape (1, 4, 3)
input_b = np.reshape(np.arange(9), (-1, 3, 3))   # shape (1, 3, 3)

print('data_a and data_b')
for i in [input_a, input_b]:
    print(i)

a = Input(shape=(4, 3))
b = Input(shape=(3, 3))

# keras 1.2.0
print('concat result')
# Stack along axis 1 (the rows): (None, 4, 3) + (None, 3, 3) -> (None, 7, 3).
# Keras 2.x: keras.layers.concatenate([a, b], axis=1)
concat = merge([a, b], mode='concat', concat_axis=1)
model_concat = Model(input=[a, b], output=concat)
print(model_concat.predict([input_a, input_b]))

print('dot result')
# Contract the last axis of both inputs:
#   out[k, i, j] = sum_f a[k, i, f] * b[k, j, f]  -> shape (None, 4, 3).
# Keras 2.x: keras.layers.dot([a, b], axes=(2, 2))
dot = merge([a, b], mode='dot', dot_axes=(2, 2))
model_dot = Model(input=[a, b], output=dot)
print(model_dot.predict([input_a, input_b]))
# result
data_a and data_b
[[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]
  [ 9 10 11]]]
[[[0 1 2]
  [3 4 5]
  [6 7 8]]]
concat result
[[[  0.   1.   2.]
  [  3.   4.   5.]
  [  6.   7.   8.]
  [  9.  10.  11.]
  [  0.   1.   2.]
  [  3.   4.   5.]
  [  6.   7.   8.]]]
dot result
[[[   5.   14.   23.]
  [  14.   50.   86.]
  [  23.   86.  149.]
  [  32.  122.  212.]]]

180223 "Merge" versus "merge", what is the difference?