
TensorFlow implementation of the AReLU activation function (AReLU: Attention-based Rectified Linear Unit)
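
AReLU extends ReLU with two learnable scalars, alpha and beta: the positive part of the input is amplified by an attention factor derived from beta, while the negative part is scaled down by a clamped alpha. The variant implemented in the code below computes

    F(x) = (1 + sigmoid(beta)) * x        for x >= 0
    F(x) = clip(alpha, 0.01, 0.99) * x    for x <  0

Because alpha and beta are ordinary variables, the optimizer trains them together with the network weights, so the activation adapts its shape to the task.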

import tensorflow as tf


class TfModel(object):

    def __init__(self, max_len=10, vocab_size=500, embedding_size=100, alpha=0.90, beta=2.0):
        # vocab_size and embedding_size are accepted but unused in this toy
        # regression model; alpha and beta are the initial AReLU parameters.
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.alpha = alpha
        self.beta = beta

        self.build_model()

    def build_model(self):
        self.x = tf.placeholder(tf.float32, [None, self.max_len], name="x")
        self.y = tf.placeholder(tf.float32, [None, self.max_len], name="y")

        hidden_output = tf.layers.dense(self.x, units=10)
        # AReLU activation starts here: alpha scales the negative part and
        # beta (through 1 + sigmoid) scales the positive part; both are learnable.
        self.alpha_tensor = tf.Variable([self.alpha], dtype=tf.float32)
        self.beta_tensor = tf.Variable([self.beta], dtype=tf.float32)

        # Keep alpha in (0, 1). Note that tf.clip_by_value passes zero gradient
        # when the variable leaves the clipping range.
        self.alpha_tensor = tf.clip_by_value(self.alpha_tensor, clip_value_min=0.01, clip_value_max=0.99)
        self.beta_tensor = 1 + tf.nn.sigmoid(self.beta_tensor)
        # F(x) = (1 + sigmoid(beta)) * x for x >= 0, clipped alpha * x for x < 0
        hidden_action = tf.nn.relu(hidden_output) * self.beta_tensor - tf.nn.relu(-hidden_output) * self.alpha_tensor
        # AReLU activation ends here
        logits = tf.layers.dense(hidden_action, units=1)

        # tf.losses.mean_squared_error already returns the mean over the batch,
        # so no extra tf.reduce_mean is needed.
        self.loss = tf.losses.mean_squared_error(self.y, logits)

        optimizer = tf.train.AdamOptimizer()
        self.train_op = optimizer.minimize(self.loss)



model = TfModel(max_len=1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Toy regression data: y = 11x + 3.
    x = [[1], [2], [3]]
    y = [[14], [25], [36]]

    for i in range(1000):

        _, loss = sess.run([model.train_op, model.loss], {model.x: x, model.y: y})
        print("index: {}, loss: {}".format(i, loss))

        # Note: these are the transformed values (clipped alpha, 1 + sigmoid(beta)),
        # not the raw underlying variables.
        print(sess.run([model.alpha_tensor, model.beta_tensor]))
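
The snippet above uses the TensorFlow 1.x API (tf.placeholder, tf.layers, tf.Session). As a minimal sketch for TensorFlow 2.x, the same activation can be packaged as a custom Keras layer; the class name ARelu and its argument names are my own choices here, not part of any library:

import tensorflow as tf


class ARelu(tf.keras.layers.Layer):
    """AReLU as a Keras layer with learnable scalars alpha and beta."""

    def __init__(self, alpha=0.90, beta=2.0, **kwargs):
        super().__init__(**kwargs)
        self.alpha_init = alpha
        self.beta_init = beta

    def build(self, input_shape):
        # One scalar per layer, shared across all units, as in the TF1 code above.
        self.alpha = self.add_weight(
            name="alpha", shape=(1,),
            initializer=tf.constant_initializer(self.alpha_init))
        self.beta = self.add_weight(
            name="beta", shape=(1,),
            initializer=tf.constant_initializer(self.beta_init))

    def call(self, inputs):
        alpha = tf.clip_by_value(self.alpha, 0.01, 0.99)  # negative-part scale
        beta = 1.0 + tf.nn.sigmoid(self.beta)             # positive-part scale
        return tf.nn.relu(inputs) * beta - tf.nn.relu(-inputs) * alpha


# Same toy regression as above (y = 11x + 3).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10),
    ARelu(),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
model.fit([[1.0], [2.0], [3.0]], [[14.0], [25.0], [36.0]], epochs=1000, verbose=0)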


Original article: https://blog.csdn.net/u010626747/article/details/107062960