Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 141 additions & 35 deletions llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, dim: int, eps: float = 1e-6):
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability.
Default is 1e-6.
Default is 1e-6.

Attributes:
eps (float): A small value added to the denominator for numerical stability.
Expand All @@ -66,6 +66,8 @@ def _norm(self, x):
torch.Tensor: The normalized tensor.

"""
# Divide each element of the tensor by its root mean square (RMS),
# adding eps to ensure the square root is positive.
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
Expand All @@ -78,15 +80,18 @@ def forward(self, x):
torch.Tensor: The output tensor after applying RMSNorm.

"""
# Perform RMSNorm normalization.
output = self._norm(x.float()).type_as(x)
# Multiply the result by the learned scaling factor to complete the RMSNorm.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe apply the RMSNorm?

return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

This function calculates a frequency tensor with complex exponentials using the given dimension
'dim' and the end index 'end'. The 'theta' parameter scales the frequencies.
'dim' and the end index 'end'.
The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.

Args:
Expand All @@ -97,11 +102,30 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials.
"""
# Calculate the basis for rotation angles used in RoPE for different dimensions based on theta
# and dim parameters.
# Specifically compute θ_i = theta^{-2i / dim} for i ∈ [0, dim / 2).
# [Shape] freqs: (dim / 2, )
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore

# Generate t based on the end index to sample rotation angles for all positions.
# [Shape] t: (end, )
t = torch.arange(end, device=freqs.device)

# Sample rotation angles for all positions and corresponding dimensions using outer product operation.
# For each position pos_i in t and each set of dimension-specific angle bases θ_j in freqs,sample
# the rotation angle corresponding to position and dimension pos_i * θ_j.
# [Shape] freqs: (end, dim / 2)
freqs = torch.outer(t, freqs).float()

# Convert all rotation angles corresponding to dimensions for all positions into complex exponential
# frequencies.
# That is, for each rotation angle θ, calculate the corresponding frequency e^{iθ}.
# Refer to docs/CN/RoPE for description based on Euler's formula.
# [Shape] freqs_cis: (end, dim / 2)
# [Type] freqs_cis: complex64
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis

Expand All @@ -123,9 +147,14 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
# Number of dimensions in the target tensor x
ndim = x.ndim
# Meaningless exception check?
assert 0 <= 1 < ndim
# Check if the precomputed frequency tensor matches the length and feature dimensions of the
# target tensor x.
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
# Calculate the target shape for reshaping the frequency tensor (1, seq_len, 1, dim / 2).
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)

Expand All @@ -148,19 +177,50 @@ def apply_rotary_emb(
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with
rotary embeddings.
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
# Reshape the query tensor xq into (batch_size, seq_len, n_heads, dim / 2, 2) and then represent
# it in complex form.
# This groups the elements of the tensor into pairs and represents them as complex numbers, each
# pair as a 2D vector.
# [Shape] xq_: (batch_size, seq_len, n_heads, dim / 2)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))

# Perform the same operation on the key tensor xk.
# [Shape] xk_: (batch_size, seq_len, n_heads, dim / 2)
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

# Reshape the frequency tensor freqs_cis for broadcasting.
# [Shape] freqs_cis: (seq_len, dim / 2) -> (1, seq_len, 1, dim / 2)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)

# Perform rotation on the tensor using complex multiplication and convert the corresponding
# results back to real numbers.
# Flatten the last two dimensions to restore the tensor to its original shape.
# This converts each grouped 2D tensor d = [d1, d2]^T into a complex number d1 + i d2.
# Use complex multiplication with the corresponding complex frequency e^{iθ}, then convert back
# to real numbers,
# achieving R(θ)d, which rotates each grouped 2D tensor by the corresponding angle, thereby implementing
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check the usage of "achieving".

# RoPE (Rotation Position Encoding).
# [Shape] xq_out: (batch_size, seq_len, n_heads, dim)
# [Shape] xk_out: (batch_size, seq_len, n_heads, dim)
# flatten operation instance (Shape changes):
# (batch_size, seq_len, n_heads, dim // 2, 2) -flatten(3)-> (batch_size, seq_len, n_heads, dim)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)."""
"""在 n_kv_heads 维度上重复扩展 key 或 query 张量.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBD.


Args:
x (torch.Tensor): 输入张量, Shape: (batch_size, sequence_length, n_kv_heads, head_dim).
n_rep (int): 重复次数.

Returns:
torch.Tensor: 扩展后的张量.
"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
Expand Down Expand Up @@ -232,7 +292,7 @@ def __init__(self, args: ModelArgs):
input_is_parallel=True,
init_method=lambda x: x,
)

# key_cache_in_Attention
self.cache_k = torch.zeros(
(
args.max_batch_size,
Expand All @@ -241,6 +301,7 @@ def __init__(self, args: ModelArgs):
self.head_dim,
)
).cuda()
# value_cache_in_Attention
self.cache_v = torch.zeros(
(
args.max_batch_size,
Expand Down Expand Up @@ -270,50 +331,61 @@ def forward(

"""
bsz, seqlen, _ = x.shape
# Perform query, key, value transformations on input x.
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

# Adjust the sizes of query, key, and value to distribute features across multiple heads.
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

# Apply RoPE operation on query and key.
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)

# Store the newly computed key and query into the corresponding cache based on the starting
# position and length of the new sequence.
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

# Retrieve keys and values needed for attention computation.
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]

# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(
keys, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(
values, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)

xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(
1, 2
) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(
1, 2
) # (bs, n_local_heads, cache_len + seqlen, head_dim)
# If n_kv_heads < n_heads, expand in the n_kv_heads dimension to match the sizes of query, key,
# and value.
# Q: Why would n_kv_heads be less than n_heads?
# A: This occurs when the transformer layer uses Grouped-Query Attention.
# Multiple query heads share a pair of key and value heads to reduce the KV-Cache, hence
# n_heads must be divisible by n_kv_heads.
# [Shape] keys: (batch_size, cache_len + seq_len, n_local_heads, head_dim)
keys = repeat_kv(keys, self.n_rep)
# [Shape] values: (batch_size, cache_len + seq_len, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep)

# [Shape] xq: (batch_size, n_local_heads, seq_len, head_dim)
xq = xq.transpose(1, 2)
# [Shape] keys: (batch_size, n_local_heads, cache_len + seq_len, head_dim)
keys = keys.transpose(1, 2)
# [Shape] values: (batch_size, n_local_heads, cache_len + seq_len, head_dim)
values = values.transpose(1, 2)
# Calculate attention scores.
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(
self.head_dim
)
# If there is a mask, add it to attention scores to mask out parts of the scores.
if mask is not None:
scores = (
scores + mask
) # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = scores + mask
# Normalize attention scores using softmax.
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(
scores, values
) # (bs, n_local_heads, seqlen, head_dim)
# Aggregate the values vector using attention scores.
# [Shape] output: (batch_size, n_local_heads, seq_len, head_dim)
output = torch.matmul(scores, values)
# [Shape] output: (batch_size, seq_len, dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
# Apply a linear transformation to the output.
return self.wo(output)


Expand Down Expand Up @@ -344,9 +416,10 @@ def __init__(
"""
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
# Custom scaling parameter for hidden layer dimensions.
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
# Ensure hidden_dim is a multiple of multiple_of.
hidden_dim = multiple_of * (
(hidden_dim + multiple_of - 1) // multiple_of
)
Expand Down Expand Up @@ -432,9 +505,14 @@ def forward(
torch.Tensor: Output tensor after applying attention and feedforward layers.

"""
# Residual calculation for attention.
# [Shape] h: (batch_size, seq_len, dim)
h = x + self.attention(
self.attention_norm(x), start_pos, freqs_cis, mask
)

# Residual calculation for FeedForward.
# [Shape] out: (batch_size, seq_len, dim)
out = h + self.feed_forward(self.ffn_norm(h))
return out

Expand Down Expand Up @@ -476,7 +554,7 @@ def __init__(self, params: ModelArgs):
self.output = ColumnParallelLinear(
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
)

# Compute complex exponential frequency tensor for RoPE.
self.freqs_cis = precompute_freqs_cis(
# Note that self.params.max_seq_len is multiplied by 2 because the token limit for the
# Llama 2 generation of models is 4096. Adding this multiplier instead of using 4096
Expand All @@ -498,28 +576,56 @@ def forward(self, tokens: torch.Tensor, start_pos: int):

"""
_bsz, seqlen = tokens.shape
# Compute embeddings corresponding to the input token indices.
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
# Slice the complex exponential frequency tensor based on the starting position
# and sequence length.
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

mask = None
# If the sequence length is greater than 1, calculate the mask for Attention.
if seqlen > 1:
# First, create a [seq_len, seq_len] matrix where each element is set to -inf.
mask = torch.full(
(seqlen, seqlen), float("-inf"), device=tokens.device
)
# Retain the upper triangular part of the matrix, excluding the main diagonal values,
# which are set to -inf, while all other elements are set to 0.
# Indicates each token can only attend to itself and previously processed tokens.

mask = torch.triu(mask, diagonal=1)

# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
# Due to key-value caching, we only need to compute attention scores for the new sequence.
# Therefore, the desired size of attention scores is (seq_len, cache_len + seq_len).
# For the mask, elements (i, j) where j > cache_len + i need to be masked out.
# In practice, we prepend the previously computed [seq_len, seq_len] mask matrix with a
# zero-filled matrix of size (seq_len, start_pos). This part does not require any masking,
# so we concatenate a zero matrix of size (seq_len, start_pos) to the beginning of the mask.
# ┌─────────────────────────────────────────────────────────┐
# │ > mask maxtirx visualization │
# │ │
# │ hstack │
# │ ↓ │
# │ ↙ [0][0][0][0][0][0][0][0] | [0][x][x][x][x] │
# │ ↙ [0][0][0][0][0][0][0][0] | [0][0][x][x][x] │
# │ seq_len(5) ← [0][0][0][0][0][0][0][0] | [0][0][0][x][x] │
# │ ↖ [0][0][0][0][0][0][0][0] | [0][0][0][0][x] │
# │ ↖ [0][0][0][0][0][0][0][0] | [0][0][0][0][0] │
# │ ↑ ↘ seq_len ↙ │
# │ cache_len(8) │
# │ │
# │ x: denote the -inf value (masked value) │
# └─────────────────────────────────────────────────────────┘
mask = torch.hstack(
[torch.zeros((seqlen, start_pos), device=tokens.device), mask]
).type_as(h)

# Iterate through each layer in the network.
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
# Perform normalization before entering the output layer.
h = self.norm(h)
# Obtain logits through the final output layer.
output = self.output(h).float()
return output