diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index 86818bb721a..ed661c75517 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -123,6 +123,14 @@ class ModelArgs:
     use_scaled_rope: bool = False  # Use scaled RoPE, introduced in llama3.1.
     rope_scale_factor: int = 8
     high_freq_factor: int = 4
+    # LongRoPE (https://arxiv.org/abs/2402.13753) used by Phi-3 / Phi-4 family.
+    # Mirrors HF's rope_scaling.{short_factor,long_factor,attention_factor}
+    # plus original_max_position_embeddings / max_position_embeddings.
+    rope_scaling_short_factor: Optional[list] = None
+    rope_scaling_long_factor: Optional[list] = None
+    original_max_position_embeddings: Optional[int] = None
+    max_position_embeddings: Optional[int] = None
+    rope_scaling_attention_factor: Optional[float] = None
     # Additional Model Metadata needed at runtime
     bos_idx: int = 1
     eos_idx: int = 3
diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index ea4e6b37243..7fc111e2c34 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -136,32 +136,80 @@ def forward(
 
 
 # Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
 # and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L242.
-# Current only support non-long rope.
+# Supports both vanilla HF RoPE and LongRoPE
+# (https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py
+# `_compute_longrope_parameters`), used by the Phi-3 / Phi-4 family.
 def hf_precompute_freqs_cis(
     dim: int,
     end: int,
     theta: float,
     partial_rotary_factor: float = 1.0,
     device: Union[str, torch.device] = "cpu",
+    short_factor: Optional[list] = None,
+    long_factor: Optional[list] = None,
+    original_max_pos: Optional[int] = None,
+    max_position_embeddings: Optional[int] = None,
+    attention_factor: Optional[float] = None,
 ):
     # Partial rotary embeddings.
     dim = int(dim * partial_rotary_factor)
-    # Short factor scaling.
-    freqs = 1.0 / (
+    # Compute the RoPE table in fp64 to minimize ULP-level drift; cast to fp32
+    # once at the end. Phi-4 Mini's narrow decode-time logit margins make the
+    # exported model sensitive to 1-ULP differences in freqs_cos / freqs_sin
+    # under sampling, especially on the Vulkan delegate.
+    inv_freq = 1.0 / (
         theta
-        ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)
+        ** (
+            torch.arange(0, dim, 2, device=device, dtype=torch.int64).to(torch.float64)
+            / dim
+        )
     )
-    # TODO: support long factor scaling.
+
+    # LongRoPE: divide inv_freq element-wise by short_factor or long_factor.
+    # Selection mirrors HF: long_factor when seq_len > original_max_position_embeddings.
+    longrope_active = (short_factor is not None) or (long_factor is not None)
+    if longrope_active:
+        chosen = (
+            long_factor
+            if (original_max_pos is not None and end > original_max_pos)
+            else short_factor
+        )
+        if chosen is None:
+            # Fall back to whichever factor was provided.
+            chosen = short_factor if long_factor is None else long_factor
+        ext_factors = torch.tensor(chosen, dtype=torch.float64, device=device)
+        assert ext_factors.numel() == inv_freq.numel(), (
+            f"LongRoPE factor length {ext_factors.numel()} must equal dim/2 "
+            f"({inv_freq.numel()})"
+        )
+        inv_freq = inv_freq / ext_factors
+
+        # Derive attention_factor if not provided (matches HF's
+        # _compute_longrope_parameters default).
+        if attention_factor is None and original_max_pos is not None:
+            ref_max_pos = (
+                max_position_embeddings if max_position_embeddings is not None else end
+            )
+            scaling_factor = ref_max_pos / original_max_pos
+            if scaling_factor <= 1.0:
+                attention_factor = 1.0
+            else:
+                attention_factor = math.sqrt(
+                    1 + math.log(scaling_factor) / math.log(original_max_pos)
+                )
 
     # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
-    t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as(
-        freqs  # pyre-ignore
-    )
-    freqs = torch.outer(t, freqs).float()  # pyre-ignore
+    t = torch.arange(end, device=inv_freq.device, dtype=torch.int64).to(torch.float64)
+    freqs = torch.outer(t, inv_freq).to(torch.float64)  # pyre-ignore
     emb = torch.cat((freqs, freqs), dim=-1)
-    freqs_cos = torch.cos(emb)
-    freqs_sin = torch.sin(emb)
+    cos_tab = torch.cos(emb)
+    sin_tab = torch.sin(emb)
+    if attention_factor is not None and attention_factor != 1.0:
+        cos_tab = cos_tab * attention_factor
+        sin_tab = sin_tab * attention_factor
+    freqs_cos = cos_tab.to(torch.float32)
+    freqs_sin = sin_tab.to(torch.float32)
     return freqs_cos, freqs_sin
 
 
@@ -241,9 +289,25 @@ def __init__(self, params: ModelArgs):
                 hf_precompute_freqs_cis,
                 partial_rotary_factor=self.params.partial_rotary_factor,
                 device=getattr(self.params, "device", "cpu"),
+                short_factor=getattr(self.params, "rope_scaling_short_factor", None),
+                long_factor=getattr(self.params, "rope_scaling_long_factor", None),
+                original_max_pos=getattr(
+                    self.params, "original_max_position_embeddings", None
+                ),
+                max_position_embeddings=getattr(
+                    self.params, "max_position_embeddings", None
+                ),
+                attention_factor=getattr(
+                    self.params, "rope_scaling_attention_factor", None
+                ),
             )
             self.apply_rotary_emb = hf_apply_rotary_emb
         else:
+            # NOTE: precompute_freqs_cis (the non-HF path) does not implement
+            # LongRoPE today. Models using rope_scaling.type == "longrope" must
+            # set use_hf_rope=True. If a future model needs LongRoPE on the
+            # vanilla path, mirror the short_factor/long_factor/attention_factor
+            # plumbing from hf_precompute_freqs_cis.
             self.precompute_freqs_cis = partial(
                 precompute_freqs_cis,
                 use_scaled=self.params.use_scaled_rope,
diff --git a/examples/models/phi_4_mini/config/config.json b/examples/models/phi_4_mini/config/config.json
index edce93e59fa..59f4440747c 100644
--- a/examples/models/phi_4_mini/config/config.json
+++ b/examples/models/phi_4_mini/config/config.json
@@ -11,5 +11,9 @@
   "vocab_size": 200064,
   "use_hf_rope": true,
   "partial_rotary_factor": 0.75,
-  "attention_qkv_bias": false
+  "attention_qkv_bias": false,
+  "original_max_position_embeddings": 4096,
+  "max_position_embeddings": 131072,
+  "rope_scaling_short_factor": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+  "rope_scaling_long_factor": [1.0, 1.118320672, 1.250641126, 1.398617824, 1.564103225, 1.74916897, 1.956131817, 2.187582649, 2.446418898, 2.735880826, 3.059592084, 3.421605075, 3.826451687, 4.279200023, 4.785517845, 5.351743533, 5.984965424, 6.693110555, 7.485043894, 8.370679318, 9.36110372, 10.4687158, 11.70738129, 13.09260651, 14.64173252, 16.37415215, 18.31155283, 20.47818807, 22.90118105, 25.61086418, 28.64115884, 32.03, 32.1, 32.13, 32.23, 32.6, 32.61, 32.64, 32.66, 32.7, 32.71, 32.93, 32.97, 33.28, 33.49, 33.5, 44.16, 47.77]
 }
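For reference, a standalone sketch (not part of the diff) of the arithmetic the new LongRoPE path performs with the Phi-4 Mini values above. The head_dim of 128 and rope theta of 10000.0 are assumptions taken from the upstream Phi-4 Mini checkpoint config, not from this diff; the long_factor list below is an all-ones stand-in to keep the snippet short, whereas the real 48-entry lists live in config.json.

# Hypothetical verification snippet; mirrors the math in hf_precompute_freqs_cis.
import math

import torch

head_dim = 128                 # assumed from the upstream Phi-4 Mini config
partial_rotary_factor = 0.75   # from config.json
theta = 10000.0                # assumed rope theta
original_max_pos = 4096        # original_max_position_embeddings
max_pos = 131072               # max_position_embeddings
long_factor = [1.0] * 48       # stand-in; the real values are in config.json

# Rotary dim is 128 * 0.75 = 96, so there are 48 base frequencies,
# matching the length of the short/long factor lists.
dim = int(head_dim * partial_rotary_factor)
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
inv_freq = inv_freq / torch.tensor(long_factor, dtype=torch.float64)

# scaling_factor = 131072 / 4096 = 32, so the derived
# attention_factor = sqrt(1 + ln(32) / ln(4096)) = sqrt(1 + 5/12) ~= 1.19.
scaling_factor = max_pos / original_max_pos
attention_factor = math.sqrt(1 + math.log(scaling_factor) / math.log(original_max_pos))
print(inv_freq.numel(), round(attention_factor, 2))  # -> 48 1.19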