17 min read · 1 hour ago
Modern AI models like GPT can understand context, remember long sentences, and generate coherent answers. But how do they actually know which words matter most? The secret behind this ability is a mechanism called attention.
Before Transformers, models like RNNs and LSTMs processed text one token at a time, passing information forward through hidden states. This approach worked for short sentences, but it came with major limitations, especially when dealing with long-range dependencies.
RNNs tend to forget important information that appears early in a sentence because they can’t directly revisit earlier hidden states during decoding. Everything is compressed into a single evolving vector, and as sequences grow longer, this vector simply can’t capture all the necessary context.
This leads to two core issues:
- Long-range context loss: The model’s “memory” gradually fades as information flows through time. During backpropagation through time (BPTT), gradients are multiplied together across many time steps, which often causes them to shrink toward zero. This vanishing gradient problem makes it extremely hard for RNNs to learn relationships between tokens that are many steps apart.
- Slow, sequential processing: RNNs must compute each step in order because every new hidden state depends on the previous one. This sequential dependency prevents parallelization across time steps, making training slow and inefficient, especially for long sequences.
Because of these limitations, RNNs work reasonably well for short sentences but break down on long texts where words far apart still influence each other. In 2017, Google researchers introduced a groundbreaking architecture, the Transformer, in the paper “Attention Is All You Need.” Instead of processing tokens one by one as RNNs do, Transformers look at the entire sequence simultaneously, decide which words are important, and compute everything in parallel. This made training dramatically faster and enabled models like GPT to capture long, complex relationships across text.
With the attention mechanism, the decoder can selectively focus on any token in the input. Not every token contributes equally to generating the next word; some tokens matter more than others. This importance is captured through attention weights.
The goal of self-attention is to compute a context vector for each token. This vector is a weighted combination of information from every token in the sentence, where the weights (the attention weights) reflect how much attention the token pays to each of the others.
For example, consider the sentence:
“The cat sat on the mat.”
(Here we treat each word as one token for simplicity.)
Self-attention asks: How much should the token “cat” pay attention to each token in the sequence “The”, “cat”, “sat”, “on”, “the”, “mat”?
Later, we’ll go through the exact calculations, but the intuition is simple: if “cat” finds “mat” highly relevant, the dot product between their vectors (specifically, the query of “cat” and the key of “mat”) will be high, leading to a stronger attention weight. This works because the dot product measures how strongly two vectors are aligned.
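To make the alignment intuition concrete, here is a tiny sketch with hand-picked 2-D vectors (hypothetical values, not the article's embeddings): the dot product is large and positive for vectors pointing the same way, zero for perpendicular ones, and negative for opposite ones.

```python
import torch

# Hand-picked 2-D vectors (hypothetical, for illustration only)
a = torch.tensor([1.0, 1.0])
same_direction = torch.tensor([2.0, 2.0])  # points the same way as a
orthogonal = torch.tensor([1.0, -1.0])     # at 90 degrees to a
opposite = torch.tensor([-1.0, -1.0])      # points the opposite way

aligned_score = torch.dot(a, same_direction)  # 1*2 + 1*2
orthogonal_score = torch.dot(a, orthogonal)   # 1*1 + 1*(-1)
opposite_score = torch.dot(a, opposite)       # 1*(-1) + 1*(-1)

print(aligned_score.item(), orthogonal_score.item(), opposite_score.item())  # 4.0 0.0 -2.0
```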
To compute context vectors, which are weighted sums of (linearly projected) input embeddings, we introduce several trainable weight matrices. These matrices are learned during training and allow the model to generate meaningful context representations for each token.
In the self-attention mechanism, we use three such matrices: Wq, Wk, and Wv. These are the Query, Key, and Value projection matrices. They start with small random values and are gradually updated through backpropagation as the model learns.
Step 1: Linear Transformation
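For every token, the input embedding is multiplied by these three weight matrices. Stacking all token embeddings as the rows of a matrix X, this is written as:

$$Q = XW_q, \qquad K = XW_k, \qquad V = XW_v$$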
Each multiplication produces a distinct Query, Key, and Value vector for every token in the sequence. In multi-head attention, each head has its own separate set of Wq, Wk, and Wv matrices.
Step 2: Why Q, K, and V? (The Intuition)
The Q, K, V terminology comes from search engines and databases:
- Query (Q): Represents what the current token is looking for. It’s like a search query you type into Google.
- Key (K): Represents what each token offers or what information it carries. It’s like the tags or metadata of items in a database. Each Query is compared with all Keys to measure relevance.
- Value (V): Contains the actual information/content of each token. If a Key matches a Query well, its Value contributes strongly to the final context vector.
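The analogy can be sketched in a few lines. Everything here is made up for illustration (the dict, the 2-D keys, and the values are hypothetical): a dictionary performs a hard lookup that returns exactly one value, while attention scores the query against every key and returns a softmax-weighted blend of all values.

```python
import torch

# A Python dict does a *hard* lookup: the query must match a key exactly
db = {"animal": 10.0, "place": 20.0, "action": 30.0}
print(db["place"])  # 20.0

# Attention does a *soft* lookup over vectors (hypothetical 2-D keys, 1-D values)
keys = torch.tensor([[1.0, 0.0],    # plays the role of "animal"
                     [0.0, 1.0],    # plays the role of "place"
                     [-1.0, 0.0]])  # plays the role of "action"
values = torch.tensor([[10.0], [20.0], [30.0]])

query = torch.tensor([0.5, 5.0])  # most similar to the "place" key

scores = keys @ query                   # one similarity score per key: [0.5, 5.0, -0.5]
weights = torch.softmax(scores, dim=0)  # non-negative, sums to 1
result = weights @ values               # weighted blend of all values
print(weights, result)  # nearly all weight on "place", so result is close to 20
```

Unlike the dict, the soft lookup never fails on an inexact match, and every value contributes a little, which is what lets gradients flow to all keys during training.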
For demonstration purposes, let’s use the sentence:
“The cat sat on the mat.”
If we treat each word as one token, we have six tokens in total.
To keep it simple, assume each token is represented by a 4-dimensional embedding vector. We can create a 6×4 matrix where each row corresponds to a token and assign random values as embeddings for all six words.
Now, let’s focus on the token “cat” and see how much attention it gives to all the other tokens. To do this, we create Query, Key, and Value vectors for every token using their respective weight matrices (Wq, Wk, Wv).
Using the input data, we will compute the Query, Key, and Value vectors by multiplying the input with the corresponding weight matrices: W_query, W_key, and W_value.
Once we have these vectors, we can compute the attention scores to understand how “cat” interacts with all other words in the sentence.
For now, let’s focus on the Query vector of the second token, ‘cat.’ We take this Query and compute its dot product with the Key vectors of all tokens, including ‘cat’ itself. This gives us the raw attention scores, which show how much ‘cat’ attends to each token in the sentence.
We calculate the attention scores for the second query by taking the dot product of *queries[1]* (index 1, since Python indexing starts at 0) with the transpose of the *keys* matrix.
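This process is repeated for every query vector in the sequence. Each query computes compatibility scores with all keys, building a complete attention score matrix that shows how every token attends to every other token. In matrix form, with Q and K from Step 1:

$$S = QK^\top, \qquad S_{ij} = q_i \cdot k_j$$

where $q_i$ is the query vector of token i and $k_j$ is the key vector of token j.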
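In code, the “repeat for every query” step is exactly what a single matrix product does. This sketch uses hypothetical random 6×4 queries and keys (not the article's values) and checks that row 1 of `queries @ keys.T` equals the six individual dot products for “cat”:

```python
import torch

torch.manual_seed(0)
# Hypothetical random queries and keys: 6 tokens, 4 dimensions each
queries = torch.rand(6, 4)
keys = torch.rand(6, 4)

# Scores for the second token ("cat", index 1): one dot product per key
scores_cat = torch.stack([torch.dot(queries[1], keys[j]) for j in range(6)])

# The same scores for every token at once, as a single matrix product
attn_scores = queries @ keys.T  # shape (6, 6); entry [i, j] = q_i · k_j

# Row 1 of the full matrix matches the per-key loop for "cat"
print(torch.allclose(attn_scores[1], scores_cat))  # True
```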
After computing the attention scores using the dot product between each token’s Query and Key vectors, we need to normalize these scores. The raw values (often called logits) can be any real number, positive or negative, so we can’t use them directly as weights.
To fix this, we apply the softmax function. Softmax converts the raw scores into a probability-like distribution where:
- all values are between 0 and 1
- all weights add up to 1
- higher scores get higher weights, lower scores get lower weights
These normalized values become the attention weights, which determine how much each token contributes to the final context vector.
Step-by-Step Intuition

1. Exponentiation (exp) makes all values positive. The exp() function turns every score into a positive number; even negative scores become small positive numbers. This matters because attention weights must be ≥ 0.
2. Bigger scores become MUCH bigger. Exponentials grow fast:
   • exp(4) ≈ 54.6 is far larger than exp(1) ≈ 2.72
   • exp(1) ≈ 2.72 is far larger than exp(−1) ≈ 0.37
   This amplifies differences, highlighting the most important tokens.
3. Divide by the sum to normalize. All exponentiated scores in a row are summed, and each one is divided by that total:

$$\text{softmax}(s_i) = \frac{e^{s_i}}{\sum_{j} e^{s_j}}$$

where the $s_j$ are the raw scores in one row. This ensures:
• All values are between 0 and 1
• All values sum to 1
So they act like probabilities, or attention weights.
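The three steps map directly onto a few lines of PyTorch. This is a minimal from-scratch sketch (the function name `softmax_from_scratch` and the example scores are hypothetical), checked against the built-in `torch.softmax`:

```python
import torch


def softmax_from_scratch(scores: torch.Tensor) -> torch.Tensor:
    """Softmax over the last dimension, following the steps above."""
    exp_scores = torch.exp(scores)  # steps 1-2: positive values, large scores amplified
    return exp_scores / exp_scores.sum(dim=-1, keepdim=True)  # step 3: divide by the sum


# Hypothetical raw scores, including a negative one
scores = torch.tensor([4.0, 1.0, -1.0])
weights = softmax_from_scratch(scores)
print(weights)        # the largest score gets by far the largest weight
print(weights.sum())  # 1 up to floating-point rounding
print(torch.allclose(weights, torch.softmax(scores, dim=-1)))  # True
```

Library implementations typically also subtract the row's maximum score before exponentiating. Softmax is unchanged when the same constant is added to every score, and the shift avoids overflow when scores are large.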
For example, when we apply softmax to the attention scores of the token ‘cat’ (its scores against every token in the sentence), the outputs are transformed into values between 0 and 1 that add up to 1. In other words, we’ve converted raw attention scores into a valid probability distribution.
In practice, one more adjustment is applied before the softmax, and it gives the method its name: scaled dot-product attention. We divide the raw attention scores by the square root of the key dimension d_k, and only then apply softmax. (Taking a square root is equivalent to raising the value to the power of 0.5.)
Why do we do this? Softmax is very sensitive to the magnitude of its inputs. If the attention scores are too large, the exponentials inside softmax grow rapidly, and the output distribution becomes extremely peaked: one token gets almost all the attention while the others get nearly none. When softmax saturates like this, its gradients become very small, which makes training slow and unstable.
The dot product of queries and keys naturally grows with the dimension: it is a sum of d_k element-wise products, so if the query and key components are independent with mean 0 and variance 1, the dot product has mean 0 and variance d_k. Dividing by √d_k brings the variance back to 1, keeping the attention scores in a manageable range and the softmax outputs smoother and more balanced.
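A quick empirical check of this argument, under the same simplifying assumption (hypothetical random queries and keys with independent standard-normal components): the variance of the raw dot product grows roughly in proportion to d_k, while the scaled version stays near 1. The last lines illustrate the softmax saturation described above, using made-up logits.

```python
import torch

torch.manual_seed(0)
num_samples = 10_000

for d_k in (4, 64, 512):
    # Hypothetical queries/keys with independent components of mean 0, variance 1
    q = torch.randn(num_samples, d_k)
    k = torch.randn(num_samples, d_k)
    dots = (q * k).sum(dim=-1)   # one raw dot product per sample
    scaled = dots / d_k ** 0.5   # scaled dot product
    print(f"d_k={d_k:4d}  var(raw)={dots.var().item():8.2f}  var(scaled)={scaled.var().item():.2f}")

# Large logits make softmax nearly one-hot; small ones keep it smooth
logits = torch.tensor([1.0, 2.0, 3.0])
print(torch.softmax(logits, dim=-1))       # fairly spread out
print(torch.softmax(logits * 10, dim=-1))  # almost all weight on the last entry
```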
Once we obtain the attention weights, we multiply them by the value vectors. This operation produces the final output of the attention mechanism, known as the context vector.
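Putting the steps together gives the scaled dot-product attention formula from “Attention Is All You Need”, where Q, K, and V are the query, key, and value matrices (one row per token), d_k is the key dimension, and softmax is applied to each row separately:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

Written per token, the context vector of token i is $z_i = \sum_j \alpha_{ij} v_j$, where $\alpha_{ij}$ are the attention weights in row i and $v_j$ is the value vector of token j.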
Now let’s implement everything in code. For simplicity, we’ll use a small 4-dimensional embedding vector for each token. This keeps the math easy to follow while still demonstrating how attention works under the hood.
```python
import torch

inputs = torch.tensor([
    [0.43, 0.15, 0.89, 0.34],  # The (x¹)
    [0.55, 0.87, 0.66, 0.45],  # Cat (x²)
    [0.57, 0.85, 0.64, 0.76],  # Sat (x³)
    [0.22, 0.58, 0.33, 0.12],  # on  (x⁴)
    [0.77, 0.25, 0.10, 0.87],  # the (x⁵)
    [0.05, 0.80, 0.55, 0.54],  # mat (x⁶)
])
```
Based on the input embeddings above, let’s now define the weight matrices for the query, key, and value projections.
```python
import torch

# Set seed for reproducibility
torch.manual_seed(123)

# Define dimensions
# Input: 6 tokens, each with 4 dimensions (6x4)
# We want the Query, Key, and Value weight matrices to be 4x4
d_in = 4   # Input dimension (each token has 4 features)
d_out = 4  # Output dimension (each projected vector has 4 features)

# Initialize weight matrices for the Query, Key, and Value projections
# Each weight matrix maps a 4D input vector to a 4D output vector
# requires_grad=False: the weights stay fixed in this demo (no training)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

print(f"Query weight matrix {tuple(W_query.shape)}:\n{W_query}")  # (4, 4)
print(f"Key weight matrix {tuple(W_key.shape)}:\n{W_key}")        # (4, 4)
print(f"Value weight matrix {tuple(W_value.shape)}:\n{W_value}")  # (4, 4)
```
Running this prints the three randomly initialized 4×4 weight matrices.
Now we can compute the Query, Key, and Value vectors by multiplying the input embeddings with the W_query, W_key, and W_value weight matrices.
```python
# Project every token embedding into query, key, and value space
queries = inputs @ W_query  # Shape: (6, 4)
keys = inputs @ W_key       # Shape: (6, 4)
values = inputs @ W_value   # Shape: (6, 4)

print(f"Queries {tuple(queries.shape)}:\n{queries}")
print(f"Keys {tuple(keys.shape)}:\n{keys}")
print(f"Values {tuple(values.shape)}:\n{values}")
```
This prints three 6×4 matrices: one Query, Key, and Value vector for each of the six tokens.
Now that we have the Query, Key, and Value matrices, we can compute the attention scores by taking the dot product of the Query matrix with the transpose of the Key matrix.
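```python
# Entry [i, j] is the dot product of token i's query with token j's key
attn_scores = queries @ keys.T  # Shape: (6, 6)
print(attn_scores)
```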
The result is a 6×6 matrix of attention scores: row i holds the scores of token i’s query against the keys of all six tokens.
Next, we convert the raw attention scores into attention weights. As before, we apply the softmax function, but this time we first scale the scores by dividing them by the square root of the key dimension. This scaling step stabilizes the values going into softmax. (Raising a value to the power of 0.5, as in the code below, is the same as taking its square root.)
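```python
dim_k = keys.shape[-1]  # key dimension d_k (4 here)
# Scale by sqrt(d_k), then apply softmax across each row (dim=-1)
attn_weights = torch.softmax(attn_scores / (dim_k ** 0.5), dim=-1)
print(attn_weights)
```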
Now we have the attention weights. Each weight is between 0 and 1, and the weights in each row sum to 1, exactly what we want: each row is a probability distribution over the six tokens.
Finally, we can compute the context vectors. To do this, we multiply the attention weights by the corresponding Value vectors and sum the results. This weighted sum ensures that tokens receiving higher attention contribute more to the final representation. In other words, each context vector is a combination of all Value vectors, weighted according to how relevant each token is to the current token. This is the core idea behind self-attention: capturing relationships between all tokens in a sequence to produce context-aware representations.
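```python
# Each row of context_vec is the attention-weighted sum of the value vectors
context_vec = attn_weights @ values  # Shape: (6, 4)
print(context_vec)
```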
This produces a 6×4 matrix: one 4-dimensional context vector for each token.
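As a sanity check on the weighted-sum interpretation, this sketch (with hypothetical random weights and values, not the ones computed above) builds the context vector for “cat” with an explicit loop and compares it with the corresponding row of the matrix product:

```python
import torch

torch.manual_seed(0)
# Hypothetical attention weights (each row sums to 1) and value vectors for 6 tokens
attn_weights = torch.softmax(torch.rand(6, 6), dim=-1)
values = torch.rand(6, 4)

# Context vector for "cat" (index 1), written out as an explicit weighted sum
context_cat = sum(attn_weights[1, j] * values[j] for j in range(6))

# The matrix product computes all six context vectors at once
context_vec = attn_weights @ values
print(torch.allclose(context_vec[1], context_cat))  # True
```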
Finally, we have the context vectors! The main goal of the attention mechanism is to compute these context vectors, which capture how each token relates to the others in the sequence. What we’ve done so far is a single, basic attention computation without causal masking, dropout, or multiple heads. Next, we’ll add causal masking, which autoregressive models like GPT need so that they cannot see future tokens, and dropout, which helps regularize training.
Causal attention
In this section, we’ll extend the Self-Attention mechanism we developed earlier by incorporating causal attention and dropout. This updated class will also serve as the foundation for implementing multi-head attention in the next section.
Why Causal Attention?
Causal attention, also known as masked attention, is a special form of self-attention. Unlike standard self-attention, which allows a token to consider the entire input sequence when computing attention scores, causal attention restricts the model to only look at the current and previous tokens in a sequence.
This restriction is crucial for autoregressive models like GPT, where the goal is to predict the next token without peeking into the future. To implement this, we mask out all future tokens for each position in the sequence. In practice, this means setting the attention weights for tokens that come after the current token to zero.
After masking, the remaining attention weights are normalized so that they sum to 1 in each row. This ensures that each token’s context vector is computed only from the current and previous tokens, preserving the natural order of the sequence while preventing information leakage from the future.
This is essential for autoregressive tasks like text generation, where the model predicts the next token based only on the tokens it has already seen.
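Zeroing out future weights and renormalizing, as described above, gives the same result as the common implementation trick of filling future scores with -inf before the softmax (used in the code later in this section), because exp(-inf) = 0. A small sketch with random hypothetical scores for six tokens checks this:

```python
import torch

torch.manual_seed(0)
num_tokens = 6
attn_scores = torch.rand(num_tokens, num_tokens)  # hypothetical raw scores

# True above the diagonal = "future" positions that must be hidden
future = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)

# Approach A: zero out future weights after softmax, then renormalize each row
weights_a = torch.softmax(attn_scores, dim=-1).masked_fill(future, 0.0)
weights_a = weights_a / weights_a.sum(dim=-1, keepdim=True)

# Approach B: set future scores to -inf before softmax
weights_b = torch.softmax(attn_scores.masked_fill(future, float("-inf")), dim=-1)

print(torch.allclose(weights_a, weights_b))  # True
print(weights_b)  # lower-triangular; each row sums to 1
```

Approach B avoids a second normalization pass, which is why implementations usually prefer it.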
Specifically, causal attention provides the following benefits:
- Autoregressive generation
  - Allows models to generate sequences one token at a time.
  - For each new token, only previously generated tokens are used as context.
- Preventing information leakage
  - Masks out future tokens so the model cannot ‘peek’ ahead.
  - Critical for tasks like language modeling, where the next word must be predicted solely from preceding words.
- Maintaining temporal structure
  - Preserves the natural sequential order of data.
  - Important for time-series forecasting or any task where the past influences the future, but not vice versa.
- Improving model generalization
  - Can help prevent the model from relying on spurious correlations in some applications, like vision-language tasks.
  - Forces the model to learn meaningful, contextually relevant relationships.
By combining causal attention with dropout, our updated Self-Attention class will be robust, generalizable, and ready for multi-head attention.
After applying causal attention and then the softmax function, the attention output for our example sentence becomes easier to interpret. For simplicity, we treat each word as a single token: "The", "cat", "sat", "on", "the", "mat".
With causal attention:
- The token "The" attends only to itself.
- The token "cat" attends to "The" and "cat".
- The token "sat" attends to "The", "cat", and "sat".
- This pattern continues for the remaining tokens.
This ensures that each token only considers previous and current tokens in the sequence.
Once we have these attention weights, the next step is to multiply them by the Value vectors to compute the final context vector for each token. Before that, we add dropout: a technique that randomly zeroes out some values in the attention weight matrix during training, which helps prevent the model from overfitting and becoming overly reliant on specific parts of the input.
Dropout
Dropout is a deep learning technique where randomly selected units in a hidden layer are temporarily ignored during training. This helps prevent overfitting and improves the model’s ability to generalize to new, unseen data.
Why We Use Dropout in Attention
After computing the attention weights with softmax, we apply dropout to the weights. Dropout is a regularization technique that randomly zeroes out a portion of the attention probabilities during training.
Here’s why this is important:
1. Prevents overfitting
   - Without dropout, the model might rely too heavily on a few tokens in the sequence.
   - Dropout discourages dependence on any single connection, pushing the model to learn more robust patterns.
2. Improves generalization
   - By randomly masking some attention weights, the model is less likely to memorize specific sequences.
   - This helps the model perform better on unseen data.
3. Applied after softmax
   - Dropout operates on the attention probabilities, not the raw scores.
   - This means that even after masking future tokens and normalizing the weights, some connections are randomly dropped to regularize learning.
4. Maintains expected value
   - During training, PyTorch’s nn.Dropout scales the surviving weights by 1 / (1 − p), where p is the drop probability. Each weight therefore keeps the same expected value, so each row of attention weights still sums to 1 on average, although any single dropped-out row generally does not. At evaluation time, dropout is disabled.
As you can see in the table, after computing the attention weights with causal attention (masking out future tokens), we apply dropout. This randomly zeroes out some attention values, which helps the model generalize better and reduces the risk of overfitting.
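To see this scaling in action, here is a minimal sketch using nn.Dropout with p = 0.5 on a hypothetical 6×6 causal attention matrix (uniform weights over the allowed positions; the numbers are illustrative, not taken from the earlier example):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)  # drop rate chosen only for illustration

# Hypothetical causal attention weights for 6 tokens (uniform over allowed positions)
allowed = torch.tril(torch.ones(6, 6))
attn_weights = allowed / allowed.sum(dim=-1, keepdim=True)

dropout.train()  # dropout is only active in training mode
dropped = dropout(attn_weights)
print(dropped)   # about half the entries are zeroed; survivors are doubled (1 / (1 - 0.5))

# Averaged over many random masks, each entry is close to the original weight
avg = torch.stack([dropout(attn_weights) for _ in range(10_000)]).mean(dim=0)
print(torch.allclose(avg, attn_weights, atol=0.05))

dropout.eval()   # in evaluation mode dropout does nothing
print(torch.equal(dropout(attn_weights), attn_weights))  # True
```

Remember to call `model.eval()` before inference; otherwise dropout keeps zeroing attention weights at prediction time.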
Now that we’ve covered the theory, let’s implement the full causal attention mechanism in code, including masking future tokens and applying dropout.
```python
import torch
import torch.nn as nn


class CausalAttention(nn.Module):
    """
    Implements causal (masked) self-attention with optional dropout.
    This ensures that each token only attends to itself and previous tokens.
    """

    def __init__(self, d_in, d_out, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Linear layers to compute Query, Key, and Value vectors
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Dropout applied to attention weights to prevent overfitting
        self.dropout = nn.Dropout(dropout)
        # Causal mask to prevent attention to future tokens
        # Upper triangular matrix with 1 above diagonal, 0 elsewhere
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """
        x: input tensor of shape (batch_size, num_tokens, d_in)
        returns: context vectors of shape (batch_size, num_tokens, d_out)
        """
        batch_size, num_tokens, d_in = x.shape
        # Compute Query, Key, and Value matrices
        queries = self.W_query(x)  # (batch_size, num_tokens, d_out)
        keys = self.W_key(x)       # (batch_size, num_tokens, d_out)
        values = self.W_value(x)   # (batch_size, num_tokens, d_out)
        # Compute raw attention scores
        # Shape: (batch_size, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(1, 2)
        # Apply causal mask to prevent attending to future tokens
        # Only consider tokens at or before the current position
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], float('-inf')
        )
        # Scale by sqrt of key dimension and apply softmax to get attention weights
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
        # Apply dropout to attention weights for regularization
        attn_weights = self.dropout(attn_weights)
        # Compute context vectors as weighted sum of Value vectors
        context_vec = attn_weights @ values
        return context_vec
```
Next, let’s dive into Multi-Head Attention, the attention mechanism used in GPT models. It is the core engine that enables them to capture complex relationships across tokens in a sequence.
Multi-Head Attention
Multi-Head Attention allows a model to focus on different types of information simultaneously, improving its ability to understand complex relationships in the data. Instead of a single attention mechanism, multiple attention ‘heads’ run in parallel, with each head learning to attend to different aspects of the input.
The outputs of these heads are then combined, providing a richer, more comprehensive representation than a single attention head could achieve. Each head captures unique relationships, allowing the model to recognize intricate patterns in the sequence.
Implementing Multi-Head Attention involves creating multiple instances of the self-attention mechanism, each with its own set of learned weights, and then merging their outputs. While this is computationally more intensive, it is a key factor that makes large language models so powerful for tasks requiring deep understanding of complex data.
```python
import torch
import torch.nn as nn


class MultiHeadAttentionWrapper(nn.Module):
    """
    Implements multi-head causal attention by combining multiple
    independent causal attention heads in parallel.
    Requires the CausalAttention class defined above.
    """

    def __init__(self, d_in, d_out, context_length, dropout=0.0, num_heads=8, qkv_bias=False):
        super().__init__()
        # Create multiple causal attention heads
        # Each head has its own set of weights (queries, keys, values)
        self.heads = nn.ModuleList([
            CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            for _ in range(num_heads)
        ])

    def forward(self, x):
        """
        x: input tensor of shape (batch_size, num_tokens, d_in)
        returns: concatenated outputs from all attention heads,
                 shape (batch_size, num_tokens, num_heads * d_out)
        """
        # Compute attention for each head and concatenate along the last dimension
        head_outputs = [head(x) for head in self.heads]
        return torch.cat(head_outputs, dim=-1)
```
Main idea: Run the attention mechanism multiple times in parallel, each with its own learned linear projections for queries, keys, and values, and then combine the results to form a unified, context-rich representation.
Instead of keeping two separate classes (CausalAttention and MultiHeadAttentionWrapper), we can combine their functionality into a single, more efficient MultiHeadAttention class.
In the previous MultiHeadAttentionWrapper, multiple attention heads were implemented by creating a list of independent CausalAttention objects, one for each head. Each head computed its own attention separately, and the outputs were then concatenated. While this works, it can be inefficient in terms of computation and memory.
The new MultiHeadAttention class integrates multi-head functionality directly. Rather than maintaining separate objects for each head, it splits the input embeddings into multiple heads by reshaping the projected Query, Key, and Value tensors. Each head computes attention in parallel, and the results are combined at the end to produce a single, unified output.
This approach is more memory- and compute-efficient, while preserving all the benefits of multi-head attention: capturing diverse relationships across different representation subspaces.
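The split-and-merge step is just a reshape. This small sketch (sizes chosen arbitrarily; `projected` stands in for the output of W_query) mirrors the view/transpose calls in the MultiHeadAttention class and checks that merging undoes splitting, so no information is lost:

```python
import torch

batch_size, num_tokens, d_out, num_heads = 2, 6, 8, 2  # hypothetical sizes
head_dim = d_out // num_heads

projected = torch.rand(batch_size, num_tokens, d_out)  # stand-in for a Q/K/V projection

# Split: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
heads = projected.view(batch_size, num_tokens, num_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 2, 6, 4])

# Head 1 sees a contiguous slice of the feature dimension
print(torch.equal(heads[:, 1], projected[..., head_dim:2 * head_dim]))  # True

# Merge: undo the transpose, then flatten the heads back into d_out
merged = heads.transpose(1, 2).contiguous().view(batch_size, num_tokens, d_out)
print(torch.equal(merged, projected))  # True
```

The `.contiguous()` call is needed because `transpose` returns a view with rearranged strides, and `view` requires compatible memory layout.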
Let’s take a look at the implementation of the MultiHeadAttention class.
```python
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """
    Implements multi-head causal (masked) attention in a single, efficient class.
    Splits the input into multiple heads, computes attention in parallel,
    and combines the results.
    """

    def __init__(self, d_in, d_out, context_length, dropout=0.0, num_heads=8, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Dimension per head
        # Linear projections for Query, Key, Value
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Linear layer to combine outputs from all heads
        self.out_proj = nn.Linear(d_out, d_out)
        # Dropout applied to attention weights
        self.dropout = nn.Dropout(dropout)
        # Causal mask to prevent attending to future tokens
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """
        x: input tensor of shape (batch_size, num_tokens, d_in)
        returns: context vectors of shape (batch_size, num_tokens, d_out)
        """
        batch_size, num_tokens, _ = x.shape
        # Project inputs to Query, Key, and Value
        queries = self.W_query(x)  # (b, num_tokens, d_out)
        keys = self.W_key(x)
        values = self.W_value(x)
        # Split into multiple heads: (b, num_tokens, num_heads, head_dim)
        queries = queries.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        values = values.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        # Transpose to (b, num_heads, num_tokens, head_dim) for attention computation
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        # Compute scaled dot-product attention
        attn_scores = queries @ keys.transpose(2, 3)  # (b, num_heads, num_tokens, num_tokens)
        # Apply causal mask (prevent attending to future tokens)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, float('-inf'))
        # Softmax + scale by sqrt of head dimension
        attn_weights = torch.softmax(attn_scores / (self.head_dim ** 0.5), dim=-1)
        # Apply dropout to attention weights
        attn_weights = self.dropout(attn_weights)
        # Compute context vectors for each head
        context = attn_weights @ values  # (b, num_heads, num_tokens, head_dim)
        # Transpose and merge heads: (b, num_tokens, num_heads * head_dim)
        context = context.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)
        # Optional final linear projection
        context = self.out_proj(context)
        return context
```
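Assuming PyTorch 2.0 or later (the function does not exist in older versions), the core computation inside MultiHeadAttention.forward can be cross-checked against `torch.nn.functional.scaled_dot_product_attention`, which with `is_causal=True` applies the same causal mask and, by default, the same 1/√head_dim scaling. This sketch uses random hypothetical tensors that are already split into heads and leaves out dropout:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, num_heads, num_tokens, head_dim = 2, 2, 6, 4  # hypothetical sizes
queries = torch.rand(b, num_heads, num_tokens, head_dim)
keys = torch.rand(b, num_heads, num_tokens, head_dim)
values = torch.rand(b, num_heads, num_tokens, head_dim)

# Manual causal scaled dot-product attention, as in MultiHeadAttention.forward (no dropout)
mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)
attn_scores = (queries @ keys.transpose(2, 3)).masked_fill(mask, float("-inf"))
attn_weights = torch.softmax(attn_scores / head_dim ** 0.5, dim=-1)
manual = attn_weights @ values

# PyTorch's built-in function; is_causal=True applies the same lower-triangular mask
fused = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)
print(torch.allclose(manual, fused, atol=1e-5))  # True
```

In practice, the built-in function can replace the manual score, mask, softmax, and weighted-sum lines, and it may dispatch to faster fused kernels on supported hardware.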
In this article, we’ve explored the core of GPT’s transformer architecture: self-attention and multi-head attention. We saw how attention allows the model to weigh the importance of each token in a sequence, how causal masking ensures predictions only depend on past tokens, and how multiple attention heads capture diverse patterns in the data.
Understanding these mechanisms is crucial because they are the building blocks behind the incredible capabilities of large language models. By breaking down these concepts step by step, from query, key, and value vectors to causal multi-head attention, we can better appreciate how models like GPT generate coherent, context-aware text.
In the next article, we’ll dive deeper into other fascinating aspects of AI, including positional encoding, transformer decoders, and how these models are trained at scale. Stay tuned as we continue to unravel the mechanics behind cutting-edge AI systems.