Many Flavors of Attention

by David Oniani

Attention is the core computational block of the transformer. With the rise of text-generating Large Language Models (LLMs), "attention" these days usually refers to some implementation of self-attention. That said, there are many such implementations, and it can be difficult to follow the very active research in this direction. The purpose of this article is to cover several flavors of attention. It may be helpful to read the [Autoregressive Transformer][autoregressive-transformer] article first, as I assume familiarity with Multi-Head Attention (MHA). As usual, we will implement each flavor from scratch in Python, with PyTorch as the only dependency.

Before we start, here is a table summarizing the trade-offs among the attention flavors covered in this article:

| Attention Flavor | Quality (Model Accuracy) | Inference Speed | KV Cache Memory |
| ---------------- | ------------------------ | --------------- | --------------- |
| MHA              | Highest                  | Slowest         | Highest         |
| MQA              | Lowest                   | Fastest         | Lowest          |
| GQA              | Near-MHA                 | Medium–Fast     | Medium          |
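The KV cache column is where MQA and GQA pay off at inference time: the cache grows with the number of key/value heads, so sharing them across query heads shrinks it proportionally. A back-of-the-envelope sketch (the helper function and the 7B-scale numbers below are illustrative assumptions, not taken from any specific model):

```python
def kv_cache_bytes(n_layer: int, n_kv_head: int, head_size: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    """Per-sequence KV cache size: two tensors (K and V) per layer, each of
    shape (n_kv_head, seq_len, head_size), stored at bytes_per_elem (2 for fp16)."""
    return 2 * n_layer * n_kv_head * seq_len * head_size * bytes_per_elem

# Illustrative 7B-scale setup: 32 layers, 32 query heads of size 128, 4096-token context.
mha = kv_cache_bytes(n_layer=32, n_kv_head=32, head_size=128, seq_len=4096)  # one KV head per query head
gqa = kv_cache_bytes(n_layer=32, n_kv_head=8, head_size=128, seq_len=4096)   # 8 shared KV head groups
mqa = kv_cache_bytes(n_layer=32, n_kv_head=1, head_size=128, seq_len=4096)   # a single shared KV head
print(f"MHA: {mha / 2**30:.2f} GiB | GQA: {gqa / 2**30:.2f} GiB | MQA: {mqa / 2**30:.4f} GiB")
# MHA: 2.00 GiB | GQA: 0.50 GiB | MQA: 0.0625 GiB
```

Per sequence, GQA with 8 KV heads cuts the cache 4x relative to MHA, and MQA cuts it 32x, which is exactly the ordering in the table above.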

Multi-Head Attention (MHA)

MHA is the classic attention mechanism introduced in the Attention Is All You Need paper (Vaswani et al., 2017). This is the flavor of attention implemented in the [Autoregressive Transformer][autoregressive-transformer] article. It can be implemented as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionHead(nn.Module):
    """A single causal self-attention head."""

    def __init__(self, cfg: Config) -> None:
        """Initializes QKV projection and dropout, and cache causal mask to avoid recomputing it."""

        super().__init__()
        self.qkv = nn.Linear(cfg.n_embd, 3 * cfg.head_size, bias=cfg.bias)
        self.dropout = nn.Dropout(cfg.dropout)
        self.register_buffer("mask", torch.tril(torch.ones(cfg.block_size, cfg.block_size)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes masked single-head self-attention for the input tensor."""

        _, T, _ = x.shape  # (batch, time, channels)
        q, k, v = self.qkv(x).split(self.qkv.out_features // 3, dim=-1)

        attn_scores = q @ k.transpose(-2, -1)
        attn_scores = attn_scores * q.size(-1) ** -0.5  # Scale by 1/sqrt(head_size) so softmax does not saturate
        attn_scores = attn_scores.masked_fill(self.mask[:T, :T] == 0, float("-inf"))  # Mask future tokens

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ v

class MultiHeadAttention(nn.Module):
    """A Multi-Head Attention (MHA)."""

    def __init__(self, cfg: Config) -> None:
        """Initializes multi-head self-attention with output projection and dropout."""

        super().__init__()
        self.heads = nn.ModuleList([AttentionHead(cfg) for _ in range(cfg.n_head)])
        self.proj = nn.Linear(cfg.n_embd, cfg.n_embd)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes masked multi-head self-attention for the input tensor."""

        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
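As a quick sanity check, the per-head computation above (scaled scores, causal mask, softmax, with dropout disabled) is equivalent to PyTorch's built-in `F.scaled_dot_product_attention` with `is_causal=True`, which can also dispatch to fused kernels. A minimal self-contained sketch, assuming PyTorch 2.x:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, hs = 2, 8, 16  # batch, sequence length, head size
q, k, v = torch.randn(3, B, T, hs).unbind(0)

# Manual masked attention, mirroring AttentionHead.forward (per-head 1/sqrt(head_size) scale).
mask = torch.tril(torch.ones(T, T))
scores = (q @ k.transpose(-2, -1)) * hs**-0.5
scores = scores.masked_fill(mask == 0, float("-inf"))
manual = F.softmax(scores, dim=-1) @ v

# PyTorch's fused equivalent with an implicit causal mask.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(manual, fused, atol=1e-6))  # True
```

This equivalence is also why many production implementations drop the explicit mask buffer entirely and call the fused kernel directly.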