Source code for pytorch_extra_mhirano.nn.attention

import math
import warnings
from typing import Optional
from typing import Tuple
from typing import Union

import torch
import torch.nn as nn

# Public API of this module: the names exported by a star-import.
__all__ = ["DotProductAttention", "SelfAttention", "SelfMultiheadAttention"]


class DotProductAttention(nn.Module):
    r"""Single-head (optionally scaled) dot-product attention.

    .. math::
        \mathrm{DotProductAttention}(Q, K, V) &= \mathrm{softmax}(qk^T) v \\
        q &= QW_1 + b_1 \\
        k &= KW_2 + b_2 \\
        v &= VW_3 + b_3

    Args:
        qdim: number of features of the query ``Q``.
        output_dim: number of features produced by the linear projections and
            therefore of the attention output. Default: ``None`` (use ``vdim``).
            Not used by the computation when ``transform`` is ``False``; the
            output then has ``vdim`` features.
        dropout: dropout probability applied to the projected ``q``, ``k`` and
            ``v`` (not to the attention weights). Default: 0.0.
        transform: if ``False`` no projection is applied, i.e. ``q = Q``,
            ``k = K``, ``v = V`` and ``fc_q``/``fc_k``/``fc_v`` are
            ``nn.Identity`` modules. Default: ``True``.
        bias: add a bias to the query projection. Default: ``True``.
        same_embd: share one projection for query, key and value
            (``W_1 = W_2 = W_3``, ``b_1 = b_2 = b_3``). Default: ``True``.
        add_bias_kv: add a bias to the key and value projections.
            Default: ``None`` (same as ``bias``).
        kdim: number of features of the key. Default: ``None`` (use ``qdim``).
        vdim: number of features of the value. Default: ``None`` (use ``qdim``).
        batch_first: if ``True`` the input and output tensors are
            ``(batch, seq, feature)``, otherwise ``(seq, batch, feature)``.
            Default: ``True``.
        scaled: if ``True`` the scores are divided by the square root of the
            query feature size (scaled dot-product attention). Default: ``False``.

    Raises:
        AssertionError: if ``same_embd`` is ``True`` while the dimensions or the
            bias flags differ, or if ``transform`` is ``False`` while ``qdim``
            and ``kdim`` differ (the dot product would be undefined).

    Examples::

        >>> attn = DotProductAttention(query_dim)
        >>> attn_output, attn_output_weights = attn(query, key, value)
    """

    def __init__(
        self,
        qdim: int,
        output_dim: Optional[int] = None,
        dropout: float = 0.0,
        transform: bool = True,
        bias: bool = True,
        same_embd: bool = True,
        add_bias_kv: Optional[bool] = None,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        batch_first: bool = True,
        scaled: bool = False,
    ) -> None:
        super().__init__()
        self.qdim: int = qdim
        self.transform: bool = transform
        self.bias: bool = bias
        self.same_embd: bool = same_embd
        self.kdim: int = kdim if kdim is not None else self.qdim
        self.vdim: int = vdim if vdim is not None else self.qdim
        self.output_dim: int = output_dim if output_dim is not None else self.vdim
        if self.same_embd and (self.qdim != self.kdim or self.qdim != self.vdim):
            raise AssertionError(
                "qdim, kdim, vdim should be the same dimensions if same_embd is True"
            )
        self.add_bias_kv: bool = add_bias_kv if add_bias_kv is not None else self.bias
        if self.same_embd and (self.bias != self.add_bias_kv):
            raise AssertionError(
                "bias and add_bias_kv should be the same if same_embd is True"
            )
        if not self.transform and self.qdim != self.kdim:
            # Without projections q and k are multiplied as given, so their
            # feature sizes have to match.
            raise AssertionError(
                "qdim and kdim should be the same dimensions if transform is False"
            )
        self.batch_first: bool = batch_first
        self.scaled: bool = scaled

        self.fc_q: nn.Module
        self.fc_k: nn.Module
        self.fc_v: nn.Module
        if not self.transform:
            # Honour ``transform=False``: parameter-free pass-through layers.
            self.fc_q = nn.Identity()
            self.fc_k = nn.Identity()
            self.fc_v = nn.Identity()
        else:
            self.fc_q = nn.Linear(self.qdim, self.output_dim, bias=bias)
            if self.same_embd:
                # One shared projection registered under three names.
                self.fc_k = self.fc_q
                self.fc_v = self.fc_k
            else:
                self.fc_k = nn.Linear(
                    self.kdim, self.output_dim, bias=self.add_bias_kv
                )
                self.fc_v = nn.Linear(
                    self.vdim, self.output_dim, bias=self.add_bias_kv
                )
        self.dropout: nn.Module = nn.Dropout(p=dropout)
        self.softmax: nn.Module = nn.Softmax(dim=2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        r"""Compute attention of ``query`` over the ``key``/``value`` pairs.

        Below, ``L`` is the target sequence length, ``S`` the source sequence
        length and ``N`` the batch size.

        Args:
            query: tensor of shape ``(L, E_q)`` for unbatched input,
                ``(N, L, E_q)`` when ``batch_first=True`` or ``(L, N, E_q)``
                when ``batch_first=False``.
            key: tensor of shape ``(S, E_k)``, ``(N, S, E_k)`` or
                ``(S, N, E_k)`` (same layout rules as ``query``).
            value: tensor of shape ``(S, E_v)``, ``(N, S, E_v)`` or
                ``(S, N, E_v)`` (same layout rules as ``query``).
            key_padding_mask: optional mask of shape ``(N, S)`` (``(S)`` for
                unbatched input). ``True`` / non-zero entries mark key
                positions that are ignored. Boolean and byte masks are
                supported. Experimental: a warning is emitted when it is used.
            need_weights: if ``True`` the attention weights are returned as the
                second element, otherwise ``None`` is returned in their place.
                Default: ``True``.
            attn_mask: optional mask of shape ``(L, S)`` (broadcast over the
                batch) or ``(N, L, S)``. For a boolean or byte mask,
                ``True`` / non-zero entries are positions that may not be
                attended to. Any other dtype is added to the attention scores
                before the softmax.

        Returns:
            A tuple ``(attn_output, attn_output_weights)``:

            - **attn_output** - shape ``(L, E)`` for unbatched input,
              ``(N, L, E)`` when ``batch_first=True`` or ``(L, N, E)`` when
              ``batch_first=False``.
            - **attn_output_weights** - shape ``(L, S)`` for unbatched input,
              otherwise ``(N, L, S)`` regardless of ``batch_first``; ``None``
              when ``need_weights`` is ``False``.

        Raises:
            AssertionError: if key and value disagree in batch size or sequence
                length, or if ``key_padding_mask`` does not match
                ``(N, S)``.

        .. note::
            ``batch_first`` is ignored for unbatched inputs.
        """
        if key_padding_mask is not None:
            warnings.warn(
                "'key_padding_mask' in 'DotProductAttention' is currently an experimental version. "
                "When you use this, please check if this is working correctly or not very carefully."
            )

        # Bring every input to the internal (batch, seq, feature) layout.
        is_batched = query.dim() != 2
        if not is_batched:
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            if key_padding_mask is not None and key_padding_mask.dim() == 1:
                key_padding_mask = key_padding_mask.unsqueeze(0)
        elif not self.batch_first:
            query = torch.transpose(query, 0, 1)
            key = torch.transpose(key, 0, 1)
            value = torch.transpose(value, 0, 1)

        # Unpacking also rejects inputs that are not 3-D at this point.
        bsz, _, _ = query.size()

        q = self.dropout(self.fc_q(query))
        k = self.dropout(self.fc_k(key))
        v = self.dropout(self.fc_v(value))

        # Only batch size and sequence length must agree; with
        # ``transform=False`` the feature sizes of k and v may differ.
        if k.shape[:2] != v.shape[:2]:
            raise AssertionError("The sizes of key and value should be the same.")
        src_len = k.size(1)

        if key_padding_mask is not None:
            if key_padding_mask.size(0) != bsz:
                raise AssertionError(
                    "The first dimension of key padding mask size must be the same as batch size"
                )
            if key_padding_mask.size(1) != src_len:
                raise AssertionError(
                    "The second dimension of key padding mask size must be the same as source length"
                )

        scores = torch.bmm(q, torch.transpose(k, 1, 2))
        if self.scaled:
            # Equals sqrt(output_dim) whenever the projections are applied.
            scores = scores / math.sqrt(q.size(-1))
        if attn_mask is not None:
            if attn_mask.dtype in (torch.bool, torch.uint8):
                # Binary/byte masks flag forbidden positions; adding them
                # would merely shift the score by one.
                scores = scores.masked_fill(attn_mask.to(torch.bool), float("-inf"))
            else:
                scores = scores + attn_mask
        if key_padding_mask is not None:
            # masked_fill expects a boolean mask; convert byte masks first.
            scores = scores.masked_fill(
                key_padding_mask.to(torch.bool).unsqueeze(1), float("-inf")
            )

        attn = self.softmax(scores)
        output = torch.bmm(attn, v)

        if not is_batched:
            output = output.squeeze(0)
            attn = attn.squeeze(0)
        elif not self.batch_first:
            output = torch.transpose(output, 0, 1)

        return output, (attn if need_weights else None)

    def generate_square_subsequent_mask(self, sz: int) -> torch.Tensor:
        r"""Generate a square causal mask of shape ``(sz, sz)``.

        Positions above the main diagonal (future positions) are filled with
        ``float('-inf')``; all other positions are ``0.0``. The result can be
        passed as a float ``attn_mask``.
        """
        return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)
class SelfAttention(nn.Module):
    """Self-attention module built on ``DotProductAttention``.

    The same ``inputs`` tensor is used as query, key and value. The wrapped
    attention is created with ``output_dim=qdim``.

    Args:
        qdim: number of features of the input (the query dimension).
        dropout: dropout probability forwarded to ``DotProductAttention``.
            Default: 0.0.
        transform: q = Q, k = K, v = V if it is False. Default: True
        bias: add bias as module parameter. Default: True.
        same_embd: W1 = W2 = W3, b1 = b2 = b3 if it is True. Default: True
        add_bias_kv: add a bias to the key and value projections.
            Default: ``None`` (forwarded unchanged).
        kdim: number of features of the key. Default: None.
        vdim: number of features of the value. Default: None.
        batch_first: If ``True``, then the input and output tensors are
            provided as (batch, seq, feature), otherwise as
            (seq, batch, feature). Default: ``True``.
        scaled: If ``True``, this performs as scaled dot product attention.
            Default: ``False``.
    """

    def __init__(
        self,
        qdim: int,
        dropout: float = 0.0,
        transform: bool = True,
        bias: bool = True,
        same_embd: bool = True,
        add_bias_kv: Optional[bool] = None,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        batch_first: bool = True,
        scaled: bool = False,
    ) -> None:
        super().__init__()
        self.attn = DotProductAttention(
            qdim=qdim,
            output_dim=qdim,
            dropout=dropout,
            transform=transform,
            bias=bias,
            same_embd=same_embd,
            add_bias_kv=add_bias_kv,
            kdim=kdim,
            vdim=vdim,
            batch_first=batch_first,
            scaled=scaled,
        )

    def forward(
        self,
        inputs: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend ``inputs`` to itself.

        Args:
            inputs: tensor used as query, key and value.
            key_padding_mask: forwarded to the wrapped attention module.
            attn_mask: forwarded to the wrapped attention module.
            need_weights: if ``False`` the second returned element is ``None``.
                Default: ``True``.

        Returns:
            The ``(attn_output, attn_output_weights)`` tuple produced by the
            wrapped attention module.
        """
        # Call the module itself (not ``.forward``) so that hooks registered
        # on the wrapped attention module are executed.
        return self.attn(
            inputs,
            inputs,
            inputs,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            need_weights=need_weights,
        )
class SelfMultiheadAttention(nn.Module):
    """Self-attention module built on ``torch.nn.MultiheadAttention``.

    The same ``inputs`` tensor is used as query, key and value.

    Args:
        embed_dim: dimension of the model, i.e., dimension of Q
        num_heads: Number of parallel attention heads. Note that ``embed_dim``
            will be split across ``num_heads`` (i.e. each head will have
            dimension ``embed_dim // num_heads``).
        dropout: dropout probability forwarded to
            ``torch.nn.MultiheadAttention``. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
            Default: ``None`` (treated as ``False``).
        add_zero_attn: If specified, adds a new batch of zeros to the key and
            value sequences at dim=1. Default: ``False``.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
            NOTE(review): because key and value are the ``inputs`` tensor
            itself, ``kdim``/``vdim`` presumably have to be ``None`` or equal
            to ``embed_dim`` - confirm against intended use.
        batch_first: If ``True``, then the input and output tensors are
            provided as (batch, seq, feature), otherwise as
            (seq, batch, feature). Default: ``True``.
        device: forwarded to ``torch.nn.MultiheadAttention``. Default: None.
        dtype: forwarded to ``torch.nn.MultiheadAttention``. Default: None.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv: Optional[bool] = None,
        add_zero_attn: bool = False,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        batch_first: bool = True,
        device: Optional[Union[torch.device, str]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            bias=bias,
            # The wrapped module expects a plain bool; ``None`` means "off".
            add_bias_kv=bool(add_bias_kv),
            add_zero_attn=add_zero_attn,
            kdim=kdim,
            vdim=vdim,
            batch_first=batch_first,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        inputs: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend ``inputs`` to itself.

        Args:
            inputs: tensor used as query, key and value.
            key_padding_mask: forwarded to ``torch.nn.MultiheadAttention``.
            need_weights: forwarded to ``torch.nn.MultiheadAttention``.
                Default: ``True``.
            attn_mask: forwarded to ``torch.nn.MultiheadAttention``.

        Returns:
            The ``(attn_output, attn_output_weights)`` tuple returned by the
            wrapped ``torch.nn.MultiheadAttention``.
        """
        # Call the module itself (not ``.forward``) so that hooks registered
        # on the wrapped attention module are executed.
        return self.attn(
            query=inputs,
            key=inputs,
            value=inputs,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
        )