
TensorFlow: How to extract attention_scores for plotting?

by 白鹭 - 2022-03-03

If you have a MultiHeadAttention layer in Keras, it can return the attention scores like this:

    x, attention_scores = MultiHeadAttention(1, 10, 10)(x, x, return_attention_scores=True)

How can I extract the attention scores from the network graph? I want to plot them.
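Before extracting anything, it helps to know what the layer hands back. `MultiHeadAttention(1, 10, 10)` means `num_heads=1, key_dim=10, value_dim=10`, and the layer needs both a query and a value tensor (for self-attention, the same `x` is passed twice). A minimal sketch, assuming TensorFlow 2.x with `tf.keras` and arbitrary toy sizes: the scores come back with shape `(batch, num_heads, query_length, key_length)`, and each row is a softmax over the key positions, so it sums to 1. Each `[sample, head]` slice is the matrix that gets drawn as a heatmap.

```python
import numpy as np
import tensorflow as tf

# Toy input: a batch of 2 sequences, 5 timesteps, 16 features (arbitrary sizes).
x = tf.random.normal((2, 5, 16))

mha = tf.keras.layers.MultiHeadAttention(num_heads=1, key_dim=10, value_dim=10)
# Self-attention: the same tensor is used as query and value.
out, attention_scores = mha(x, x, return_attention_scores=True)

scores = attention_scores.numpy()
print(out.shape)     # (2, 5, 16): projected back to the query's feature size
print(scores.shape)  # (2, 1, 5, 5): (batch, num_heads, query_len, key_len)
# Each row is a softmax over key positions, so it sums to 1.
print(np.allclose(scores.sum(axis=-1), 1.0, atol=1e-5))
```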

Reply from a uj5u.com community member:

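Option 1: If you want to plot the attention scores during training, you can create a Callback and pass the data to it; the callback can then be triggered, for example, at the end of every epoch. Here is an example that uses 2 attention heads and plots both of them after each epoch: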

import tensorflow as tf
import seaborn as sb
import matplotlib.pyplot as plt

class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self, attention_layer, data):
        super().__init__()
        self.attention_layer = attention_layer
        self.data = data

    def on_epoch_end(self, epoch, logs=None):
        test_targets, test_sources = self.data
        # Take one sample; scores have shape (batch, num_heads, target_len, source_len) = (1, 2, 8, 4)
        _, attention_scores = self.attention_layer(test_targets[:1], test_sources[:1], return_attention_scores=True)
        scores = attention_scores.numpy()
        # One color scale for both heads, so the single shared colorbar is valid for both plots
        vmin, vmax = scores.min(), scores.max()

        fig, axs = plt.subplots(ncols=3, gridspec_kw=dict(width_ratios=[5, 5, 0.2]))
        sb.heatmap(scores[0, 0, :, :], annot=True, cbar=False, vmin=vmin, vmax=vmax, ax=axs[0])
        sb.heatmap(scores[0, 1, :, :], annot=True, yticklabels=False, cbar=False, vmin=vmin, vmax=vmax, ax=axs[1])
        fig.colorbar(axs[1].collections[0], cax=axs[2])
        plt.show()

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.layers.Input(shape=[8, 16])
source = tf.keras.layers.Input(shape=[4, 16])
output_tensor, weights = layer(target, source,
                               return_attention_scores=True)
output = tf.keras.layers.Flatten()(output_tensor)
output = tf.keras.layers.Dense(1, activation='sigmoid')(output)

model = tf.keras.Model([target, source], output)
model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy())

attention_layer = model.layers[2]  # the MultiHeadAttention layer (the same object as `layer`)
samples = 5
train_targets = tf.random.normal((samples, 8, 16))
train_sources = tf.random.normal((samples, 4, 16))
test_targets = tf.random.normal((samples, 8, 16))
test_sources = tf.random.normal((samples, 4, 16))
y = tf.random.uniform((samples,), maxval=2, dtype=tf.int32)

model.fit([train_targets, train_sources], y, batch_size=2, epochs=2,
          callbacks=[CustomCallback(attention_layer, [test_targets, test_sources])])
Output:

Epoch 1/2
1/3 [=========>....................] - ETA: 2s - loss: 0.7142
[Figure: attention-score heatmaps for the two heads]
3/3 [==============================] - 3s 649ms/step - loss: 0.6992
Epoch 2/2
1/3 [=========>....................] - ETA: 0s - loss: 0.7265
[Figure: attention-score heatmaps for the two heads]
3/3 [==============================] - 1s 650ms/step - loss: 0.6863
<keras.callbacks.History at 0x7fcc839dc590>
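If the plots are not needed live (for example on a headless machine, where `plt.show()` inside a callback is not useful), the same callback mechanism can simply collect the scores each epoch so they can be plotted later. A minimal sketch, assuming TensorFlow 2.x; `ScoreRecorder` and its `history` attribute are hypothetical names, and the toy model mirrors the one above. The labels are cast to float here only to keep the loss input unambiguous.

```python
import tensorflow as tf

class ScoreRecorder(tf.keras.callbacks.Callback):
    """Hypothetical callback: stores the attention scores of fixed inputs after every epoch."""

    def __init__(self, attention_layer, inputs):
        super().__init__()
        self.attention_layer = attention_layer
        self.inputs = inputs  # (targets, sources) used for inspection
        self.history = []     # one NumPy array per epoch

    def on_epoch_end(self, epoch, logs=None):
        targets, sources = self.inputs
        _, scores = self.attention_layer(targets, sources, return_attention_scores=True)
        self.history.append(scores.numpy())

# Same toy model as above: 2 heads, target length 8, source length 4.
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.layers.Input(shape=[8, 16])
source = tf.keras.layers.Input(shape=[4, 16])
output_tensor = layer(target, source)
output = tf.keras.layers.Flatten()(output_tensor)
output = tf.keras.layers.Dense(1, activation='sigmoid')(output)
model = tf.keras.Model([target, source], output)
model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy())

samples = 5
train_targets = tf.random.normal((samples, 8, 16))
train_sources = tf.random.normal((samples, 4, 16))
y = tf.cast(tf.random.uniform((samples,), maxval=2, dtype=tf.int32), tf.float32)

# Record the scores of the first training sample after each epoch.
recorder = ScoreRecorder(layer, (train_targets[:1], train_sources[:1]))
model.fit([train_targets, train_sources], y, batch_size=2, epochs=2,
          verbose=0, callbacks=[recorder])

print(len(recorder.history))      # 2: one entry per epoch
print(recorder.history[0].shape)  # (1, 2, 8, 4)
```

Each entry in `recorder.history` can then be passed to the same seaborn plotting code, which also makes it easy to compare how a head's pattern changes between epochs.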

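Option 2: If you only want to plot the attention scores after training, you can pass some data through the model's attention layer and plot the scores: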

import tensorflow as tf
import seaborn as sb
import matplotlib.pyplot as plt

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.layers.Input(shape=[8, 16])
source = tf.keras.layers.Input(shape=[4, 16])
output_tensor, weights = layer(target, source,
                               return_attention_scores=True)
output = tf.keras.layers.Flatten()(output_tensor)
output = tf.keras.layers.Dense(1, activation='sigmoid')(output)

model = tf.keras.Model([target, source], output)
model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy())

samples = 5
train_targets = tf.random.normal((samples, 8, 16))
train_sources = tf.random.normal((samples, 4, 16))
test_targets = tf.random.normal((samples, 8, 16))
test_sources = tf.random.normal((samples, 4, 16))
y = tf.random.uniform((samples,), maxval=2, dtype=tf.int32)

model.fit([train_targets, train_sources], y, batch_size=2, epochs=2)

attention_layer = model.layers[2]  # the MultiHeadAttention layer (the same object as `layer`)

# Take one sample; scores have shape (batch, num_heads, target_len, source_len) = (1, 2, 8, 4)
_, attention_scores = attention_layer(test_targets[:1], test_sources[:1], return_attention_scores=True)
scores = attention_scores.numpy()
vmin, vmax = scores.min(), scores.max()  # shared color scale so one colorbar fits both heads

fig, axs = plt.subplots(ncols=3, gridspec_kw=dict(width_ratios=[5, 5, 0.2]))
sb.heatmap(scores[0, 0, :, :], annot=True, cbar=False, vmin=vmin, vmax=vmax, ax=axs[0])
sb.heatmap(scores[0, 1, :, :], annot=True, yticklabels=False, cbar=False, vmin=vmin, vmax=vmax, ax=axs[1])
fig.colorbar(axs[1].collections[0], cax=axs[2])
plt.show()
Output:

Epoch 1/2
3/3 [==============================] - 1s 7ms/step - loss: 0.6727
Epoch 2/2
3/3 [==============================] - 0s 6ms/step - loss: 0.6503

[Figure: attention-score heatmaps for the two heads]
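Because the attention weights are already a tensor in the functional graph (the `weights` returned by the layer call when the model is built), another way to pull them out after training is to build a second `tf.keras.Model` with the same inputs and `weights` as its output, then call `predict` on it. A minimal sketch, assuming TensorFlow 2.x; `score_model` and `plot_heads` are hypothetical names. The helper draws one heatmap per head for any number of heads with a single shared color scale. It uses matplotlib's `imshow` instead of seaborn and the non-interactive Agg backend so it runs without a display; in a notebook, drop the `matplotlib.use("Agg")` line and call `plt.show()` instead of `savefig`.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove in a notebook
import matplotlib.pyplot as plt
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.layers.Input(shape=[8, 16])
source = tf.keras.layers.Input(shape=[4, 16])
output_tensor, weights = layer(target, source, return_attention_scores=True)
output = tf.keras.layers.Flatten()(output_tensor)
output = tf.keras.layers.Dense(1, activation='sigmoid')(output)

model = tf.keras.Model([target, source], output)
# Second model over the same graph: same inputs, attention weights as the output.
score_model = tf.keras.Model([target, source], weights)
model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy())

samples = 5
train_targets = tf.random.normal((samples, 8, 16))
train_sources = tf.random.normal((samples, 4, 16))
y = tf.cast(tf.random.uniform((samples,), maxval=2, dtype=tf.int32), tf.float32)
model.fit([train_targets, train_sources], y, batch_size=2, epochs=2, verbose=0)

# The layers are shared, so score_model sees the trained weights (no compile needed for predict).
scores = score_model.predict([train_targets[:1], train_sources[:1]], verbose=0)
print(scores.shape)  # (1, 2, 8, 4): (batch, heads, target_len, source_len)

def plot_heads(sample_scores):
    """Draw one heatmap per head, all on one color scale, with a shared colorbar."""
    num_heads = sample_scores.shape[0]
    vmin, vmax = sample_scores.min(), sample_scores.max()
    fig, axs = plt.subplots(ncols=num_heads, figsize=(3 * num_heads, 4), squeeze=False)
    for h, ax in enumerate(axs[0]):
        im = ax.imshow(sample_scores[h], vmin=vmin, vmax=vmax, cmap="viridis")
        ax.set_title(f"head {h}")
        ax.set_xlabel("source position")
        ax.set_ylabel("target position")
    fig.colorbar(im, ax=axs[0].tolist())
    return fig

fig = plot_heads(scores[0])  # first (and only) sample
fig.savefig("attention_scores.png")
```

Compared with calling the layer directly, the sub-model approach also works with `predict` on large batches, since Keras handles the batching.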
