-
Notifications
You must be signed in to change notification settings - Fork 3
[Model] EN Comment Translation #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
pkuyihangji
wants to merge
1
commit into
Mixture-AI:main
Choose a base branch
from
pkuyihangji:yihang_model1
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,7 +45,7 @@ def __init__(self, dim: int, eps: float = 1e-6): | |
| Args: | ||
| dim (int): The dimension of the input tensor. | ||
| eps (float, optional): A small value added to the denominator for numerical stability. | ||
| Default is 1e-6. | ||
| Default is 1e-6. | ||
|
|
||
| Attributes: | ||
| eps (float): A small value added to the denominator for numerical stability. | ||
|
|
@@ -66,6 +66,8 @@ def _norm(self, x): | |
| torch.Tensor: The normalized tensor. | ||
|
|
||
| """ | ||
| # Divide each element of the tensor by its root mean square (RMS), | ||
| # adding eps to ensure the square root is positive. | ||
| return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | ||
|
|
||
| def forward(self, x): | ||
|
|
@@ -78,15 +80,18 @@ def forward(self, x): | |
| torch.Tensor: The output tensor after applying RMSNorm. | ||
|
|
||
| """ | ||
| # Perform RMSNorm normalization. | ||
| output = self._norm(x.float()).type_as(x) | ||
| # Multiply the result by the learned scaling factor to complete the RMSNorm. | ||
| return output * self.weight | ||
|
|
||
|
|
||
| def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): | ||
| """Precompute the frequency tensor for complex exponentials (cis) with given dimensions. | ||
|
|
||
| This function calculates a frequency tensor with complex exponentials using the given dimension | ||
| 'dim' and the end index 'end'. The 'theta' parameter scales the frequencies. | ||
| 'dim' and the end index 'end'. | ||
| The 'theta' parameter scales the frequencies. | ||
| The returned tensor contains complex values in complex64 data type. | ||
|
|
||
| Args: | ||
|
|
@@ -97,11 +102,30 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): | |
| Returns: | ||
| torch.Tensor: Precomputed frequency tensor with complex exponentials. | ||
| """ | ||
| # Calculate the basis for rotation angles used in RoPE for different dimensions based on theta | ||
| # and dim parameters. | ||
| # Specifically compute θ_i = theta^{-2i / dim} for i ∈ [0, dim / 2). | ||
| # [Shape] freqs: (dim / 2, ) | ||
| freqs = 1.0 / ( | ||
| theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) | ||
| ) | ||
| t = torch.arange(end, device=freqs.device) # type: ignore | ||
| freqs = torch.outer(t, freqs).float() # type: ignore | ||
|
|
||
| # Generate t based on the end index to sample rotation angles for all positions. | ||
| # [Shape] t: (end, ) | ||
| t = torch.arange(end, device=freqs.device) | ||
|
|
||
| # Sample rotation angles for all positions and corresponding dimensions using outer product operation. | ||
| # For each position pos_i in t and each set of dimension-specific angle bases θ_j in freqs,sample | ||
| # the rotation angle corresponding to position and dimension pos_i * θ_j. | ||
| # [Shape] freqs: (end, dim / 2) | ||
| freqs = torch.outer(t, freqs).float() | ||
|
|
||
| # Convert all rotation angles corresponding to dimensions for all positions into complex exponential | ||
| # frequencies. | ||
| # That is, for each rotation angle θ, calculate the corresponding frequency e^{iθ}. | ||
| # Refer to docs/CN/RoPE for description based on Euler's formula. | ||
| # [Shape] freqs_cis: (end, dim / 2) | ||
| # [Type] freqs_cis: complex64 | ||
| freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 | ||
| return freqs_cis | ||
|
|
||
|
|
@@ -123,9 +147,14 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): | |
| AssertionError: If the frequency tensor doesn't match the expected shape. | ||
| AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. | ||
| """ | ||
| # Number of dimensions in the target tensor x | ||
| ndim = x.ndim | ||
| # Meaningless exception check? | ||
| assert 0 <= 1 < ndim | ||
| # Check if the precomputed frequency tensor matches the length and feature dimensions of the | ||
| # target tensor x. | ||
| assert freqs_cis.shape == (x.shape[1], x.shape[-1]) | ||
| # Calculate the target shape for reshaping the frequency tensor (1, seq_len, 1, dim / 2). | ||
| shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] | ||
| return freqs_cis.view(*shape) | ||
|
|
||
|
|
@@ -148,19 +177,50 @@ def apply_rotary_emb( | |
| freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. | ||
|
|
||
| Returns: | ||
| Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with | ||
| rotary embeddings. | ||
| Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. | ||
| """ | ||
| # Reshape the query tensor xq into (batch_size, seq_len, n_heads, dim / 2, 2) and then represent | ||
| # it in complex form. | ||
| # This groups the elements of the tensor into pairs and represents them as complex numbers, each | ||
| # pair as a 2D vector. | ||
| # [Shape] xq_: (batch_size, seq_len, n_heads, dim / 2) | ||
| xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | ||
|
|
||
| # Perform the same operation on the key tensor xk. | ||
| # [Shape] xk_: (batch_size, seq_len, n_heads, dim / 2) | ||
| xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | ||
|
|
||
| # Reshape the frequency tensor freqs_cis for broadcasting. | ||
| # [Shape] freqs_cis: (seq_len, dim / 2) -> (1, seq_len, 1, dim / 2) | ||
| freqs_cis = reshape_for_broadcast(freqs_cis, xq_) | ||
|
|
||
| # Perform rotation on the tensor using complex multiplication and convert the corresponding | ||
| # results back to real numbers. | ||
| # Flatten the last two dimensions to restore the tensor to its original shape. | ||
| # This converts each grouped 2D tensor d = [d1, d2]^T into a complex number d1 + i d2. | ||
| # Use complex multiplication with the corresponding complex frequency e^{iθ}, then convert back | ||
| # to real numbers, | ||
| # achieving R(θ)d, which rotates each grouped 2D tensor by the corresponding angle, thereby implementing | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check the usage of "achieving". |
||
| # RoPE (Rotation Position Encoding). | ||
| # [Shape] xq_out: (batch_size, seq_len, n_heads, dim) | ||
| # [Shape] xk_out: (batch_size, seq_len, n_heads, dim) | ||
| # flatten operation instance (Shape changes): | ||
| # (batch_size, seq_len, n_heads, dim // 2, 2) -flatten(3)-> (batch_size, seq_len, n_heads, dim) | ||
| xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | ||
| xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | ||
| return xq_out.type_as(xq), xk_out.type_as(xk) | ||
|
|
||
|
|
||
| def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: | ||
| """torch.repeat_interleave(x, dim=2, repeats=n_rep).""" | ||
| """在 n_kv_heads 维度上重复扩展 key 或 query 张量. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TBD. |
||
|
|
||
| Args: | ||
| x (torch.Tensor): 输入张量, Shape: (batch_size, sequence_length, n_kv_heads, head_dim). | ||
| n_rep (int): 重复次数. | ||
|
|
||
| Returns: | ||
| torch.Tensor: 扩展后的张量. | ||
| """ | ||
| bs, slen, n_kv_heads, head_dim = x.shape | ||
| if n_rep == 1: | ||
| return x | ||
|
|
@@ -232,7 +292,7 @@ def __init__(self, args: ModelArgs): | |
| input_is_parallel=True, | ||
| init_method=lambda x: x, | ||
| ) | ||
|
|
||
| # key_cache_in_Attention | ||
| self.cache_k = torch.zeros( | ||
| ( | ||
| args.max_batch_size, | ||
|
|
@@ -241,6 +301,7 @@ def __init__(self, args: ModelArgs): | |
| self.head_dim, | ||
| ) | ||
| ).cuda() | ||
| # value_cache_in_Attention | ||
| self.cache_v = torch.zeros( | ||
| ( | ||
| args.max_batch_size, | ||
|
|
@@ -270,50 +331,61 @@ def forward( | |
|
|
||
| """ | ||
| bsz, seqlen, _ = x.shape | ||
| # Perform query, key, value transformations on input x. | ||
| xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) | ||
|
|
||
| # Adjust the sizes of query, key, and value to distribute features across multiple heads. | ||
| xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) | ||
| xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) | ||
| xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) | ||
|
|
||
| # Apply RoPE operation on query and key. | ||
| xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) | ||
|
|
||
| self.cache_k = self.cache_k.to(xq) | ||
| self.cache_v = self.cache_v.to(xq) | ||
|
|
||
| # Store the newly computed key and query into the corresponding cache based on the starting | ||
| # position and length of the new sequence. | ||
| self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk | ||
| self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv | ||
|
|
||
| # Retrieve keys and values needed for attention computation. | ||
| keys = self.cache_k[:bsz, : start_pos + seqlen] | ||
| values = self.cache_v[:bsz, : start_pos + seqlen] | ||
|
|
||
| # repeat k/v heads if n_kv_heads < n_heads | ||
| keys = repeat_kv( | ||
| keys, self.n_rep | ||
| ) # (bs, cache_len + seqlen, n_local_heads, head_dim) | ||
| values = repeat_kv( | ||
| values, self.n_rep | ||
| ) # (bs, cache_len + seqlen, n_local_heads, head_dim) | ||
|
|
||
| xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | ||
| keys = keys.transpose( | ||
| 1, 2 | ||
| ) # (bs, n_local_heads, cache_len + seqlen, head_dim) | ||
| values = values.transpose( | ||
| 1, 2 | ||
| ) # (bs, n_local_heads, cache_len + seqlen, head_dim) | ||
| # If n_kv_heads < n_heads, expand in the n_kv_heads dimension to match the sizes of query, key, | ||
| # and value. | ||
| # Q: Why would n_kv_heads be less than n_heads? | ||
| # A: This occurs when the transformer layer uses Grouped-Query Attention. | ||
| # Multiple query heads share a pair of key and value heads to reduce the KV-Cache, hence | ||
| # n_heads must be divisible by n_kv_heads. | ||
| # [Shape] keys: (batch_size, cache_len + seq_len, n_local_heads, head_dim) | ||
| keys = repeat_kv(keys, self.n_rep) | ||
| # [Shape] values: (batch_size, cache_len + seq_len, n_local_heads, head_dim) | ||
| values = repeat_kv(values, self.n_rep) | ||
|
|
||
| # [Shape] xq: (batch_size, n_local_heads, seq_len, head_dim) | ||
| xq = xq.transpose(1, 2) | ||
| # [Shape] keys: (batch_size, n_local_heads, cache_len + seq_len, head_dim) | ||
| keys = keys.transpose(1, 2) | ||
| # [Shape] values: (batch_size, n_local_heads, cache_len + seq_len, head_dim) | ||
| values = values.transpose(1, 2) | ||
| # Calculate attention scores. | ||
| scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt( | ||
| self.head_dim | ||
| ) | ||
| # If there is a mask, add it to attention scores to mask out parts of the scores. | ||
| if mask is not None: | ||
| scores = ( | ||
| scores + mask | ||
| ) # (bs, n_local_heads, seqlen, cache_len + seqlen) | ||
| scores = scores + mask | ||
| # Normalize attention scores using softmax. | ||
| scores = F.softmax(scores.float(), dim=-1).type_as(xq) | ||
| output = torch.matmul( | ||
| scores, values | ||
| ) # (bs, n_local_heads, seqlen, head_dim) | ||
| # Aggregate the values vector using attention scores. | ||
| # [Shape] output: (batch_size, n_local_heads, seq_len, head_dim) | ||
| output = torch.matmul(scores, values) | ||
| # [Shape] output: (batch_size, seq_len, dim) | ||
| output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) | ||
| # Apply a linear transformation to the output. | ||
| return self.wo(output) | ||
|
|
||
|
|
||
|
|
@@ -344,9 +416,10 @@ def __init__( | |
| """ | ||
| super().__init__() | ||
| hidden_dim = int(2 * hidden_dim / 3) | ||
| # custom dim factor multiplier | ||
| # Custom scaling parameter for hidden layer dimensions. | ||
| if ffn_dim_multiplier is not None: | ||
| hidden_dim = int(ffn_dim_multiplier * hidden_dim) | ||
| # Ensure hidden_dim is a multiple of multiple_of. | ||
| hidden_dim = multiple_of * ( | ||
| (hidden_dim + multiple_of - 1) // multiple_of | ||
| ) | ||
|
|
@@ -432,9 +505,14 @@ def forward( | |
| torch.Tensor: Output tensor after applying attention and feedforward layers. | ||
|
|
||
| """ | ||
| # Residual calculation for attention. | ||
| # [Shape] h: (batch_size, seq_len, dim) | ||
| h = x + self.attention( | ||
| self.attention_norm(x), start_pos, freqs_cis, mask | ||
| ) | ||
|
|
||
| # Residual calculation for FeedForward. | ||
| # [Shape] out: (batch_size, seq_len, dim) | ||
| out = h + self.feed_forward(self.ffn_norm(h)) | ||
| return out | ||
|
|
||
|
|
@@ -476,7 +554,7 @@ def __init__(self, params: ModelArgs): | |
| self.output = ColumnParallelLinear( | ||
| params.dim, params.vocab_size, bias=False, init_method=lambda x: x | ||
| ) | ||
|
|
||
| # Compute complex exponential frequency tensor for RoPE. | ||
| self.freqs_cis = precompute_freqs_cis( | ||
| # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the | ||
| # Llama 2 generation of models is 4096. Adding this multiplier instead of using 4096 | ||
|
|
@@ -498,28 +576,56 @@ def forward(self, tokens: torch.Tensor, start_pos: int): | |
|
|
||
| """ | ||
| _bsz, seqlen = tokens.shape | ||
| # Compute embeddings corresponding to the input token indices. | ||
| h = self.tok_embeddings(tokens) | ||
| self.freqs_cis = self.freqs_cis.to(h.device) | ||
| # Slice the complex exponential frequency tensor based on the starting position | ||
| # and sequence length. | ||
| freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] | ||
|
|
||
| mask = None | ||
| # If the sequence length is greater than 1, calculate the mask for Attention. | ||
| if seqlen > 1: | ||
| # First, create a [seq_len, seq_len] matrix where each element is set to -inf. | ||
| mask = torch.full( | ||
| (seqlen, seqlen), float("-inf"), device=tokens.device | ||
| ) | ||
| # Retain the upper triangular part of the matrix, excluding the main diagonal values, | ||
| # which are set to -inf, while all other elements are set to 0. | ||
| # Indicates each token can only attend to itself and previously processed tokens. | ||
|
|
||
| mask = torch.triu(mask, diagonal=1) | ||
|
|
||
| # When performing key-value caching, we compute the attention scores | ||
| # only for the new sequence. Thus, the matrix of scores is of size | ||
| # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for | ||
| # j > cache_len + i, since row i corresponds to token cache_len + i. | ||
| # Due to key-value caching, we only need to compute attention scores for the new sequence. | ||
| # Therefore, the desired size of attention scores is (seq_len, cache_len + seq_len). | ||
| # For the mask, elements (i, j) where j > cache_len + i need to be masked out. | ||
| # In practice, we prepend the previously computed [seq_len, seq_len] mask matrix with a | ||
| # zero-filled matrix of size (seq_len, start_pos). This part does not require any masking, | ||
| # so we concatenate a zero matrix of size (seq_len, start_pos) to the beginning of the mask. | ||
| # ┌─────────────────────────────────────────────────────────┐ | ||
| # │ > mask maxtirx visualization │ | ||
| # │ │ | ||
| # │ hstack │ | ||
| # │ ↓ │ | ||
| # │ ↙ [0][0][0][0][0][0][0][0] | [0][x][x][x][x] │ | ||
| # │ ↙ [0][0][0][0][0][0][0][0] | [0][0][x][x][x] │ | ||
| # │ seq_len(5) ← [0][0][0][0][0][0][0][0] | [0][0][0][x][x] │ | ||
| # │ ↖ [0][0][0][0][0][0][0][0] | [0][0][0][0][x] │ | ||
| # │ ↖ [0][0][0][0][0][0][0][0] | [0][0][0][0][0] │ | ||
| # │ ↑ ↘ seq_len ↙ │ | ||
| # │ cache_len(8) │ | ||
| # │ │ | ||
| # │ x: denote the -inf value (masked value) │ | ||
| # └─────────────────────────────────────────────────────────┘ | ||
| mask = torch.hstack( | ||
| [torch.zeros((seqlen, start_pos), device=tokens.device), mask] | ||
| ).type_as(h) | ||
|
|
||
| # Iterate through each layer in the network. | ||
| for layer in self.layers: | ||
| h = layer(h, start_pos, freqs_cis, mask) | ||
| # Perform normalization before entering the output layer. | ||
| h = self.norm(h) | ||
| # Obtain logits through the final output layer. | ||
| output = self.output(h).float() | ||
| return output | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe apply the RMSNorm?