
Commit a79521b

SS-JIA and claude committed
Add LongRoPE support and fp64 RoPE precompute for Phi-3 / Phi-4 family
Summary:

Adds LongRoPE plumbing and an fp64 cos/sin precompute pass to hf_precompute_freqs_cis. Together these eliminate Phi-4 Mini decode-time n-gram repetition under both the XNNPACK and Vulkan delegates.

Phi-3 and Phi-4 family models use HF's "longrope" RoPE scaling, which multiplies cos/sin by an attention_factor (~1.19 for Phi-4 Mini) and divides inv_freq element-wise by a per-dimension short_factor (when seq_len <= original_max_position_embeddings) or long_factor. ET's hf_precompute_freqs_cis was vanilla RoPE -- missing both terms. At typical export configurations the dominant effect is the missing attention_factor, which leaves attention scores ~1.42x softer than the model was trained for. Compounded across 32 layers, this pushes Phi-4 Mini's narrow top-2 logit margins past their tipping point and triggers greedy-decode n-gram repetition; the same error explains prior on-device looping observed under both XNNPACK and Vulkan.

Adds LongRoPE plumbing through ModelArgs (short_factor, long_factor, original_max_position_embeddings, max_position_embeddings, rope_scaling_attention_factor) and into hf_precompute_freqs_cis, with attention_factor derived as sqrt(1 + log(scaling)/log(original_max)) when not explicitly set. The non-HF precompute_freqs_cis path is left vanilla; longrope models must set use_hf_rope=True (noted in Rope.__init__).

Also moves the cos/sin precompute to fp64, casting to fp32 once at the end. After LongRoPE corrects the 19% scale error, fp32 ULP-level rounding in the cos/sin tables becomes the next-largest contributor to logit drift -- load-bearing on Vulkan under sampling: with fp32 precompute, 1/2 T=0.5 trajectories collapsed into a 4-gram loop ("avoiding data and data biases") even with LongRoPE applied. The fp64 precompute runs once at construction (microseconds on a few-KB table); runtime tables remain fp32, so there is no inference-time cost.

Wires the LongRoPE fields into examples/models/phi_4_mini/config/config.json, sourced from HF's Phi-4 Mini config.

Test Plan:

Validated end-to-end on a Samsung Galaxy S25:
- Eager bf16 (host, 12 threads): 3/3 loop-free at T=0 greedy and T=0.5 sampling x 2 seeds.
- XNNPACK 8da4w-g32 on device: 3/3 loop-free, ~20.7 tok/s decode.
- Vulkan 8da4w-g32 on device: 3/3 loop-free, ~17.1 tok/s decode.

Reproduced across two distinct S25 units to confirm the result is not device-specific. Verified that omitting either fix regresses Vulkan sampling: LongRoPE alone leaves residual sampling loops; fp64 alone was previously known to be insufficient.

Co-Authored-By: Claude <noreply@anthropic.com>
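To make the numbers above concrete, here is a small worked sketch (plain Python, not part of the commit) of the attention_factor arithmetic, using the original_max_position_embeddings=4096 and max_position_embeddings=131072 values wired into the Phi-4 Mini config in this change:

import math

# Values from the Phi-4 Mini config added in this commit.
original_max_position_embeddings = 4096
max_position_embeddings = 131072

# attention_factor as derived when not explicitly set:
# sqrt(1 + log(scaling) / log(original_max)).
scaling = max_position_embeddings / original_max_position_embeddings  # 32.0
attention_factor = math.sqrt(
    1 + math.log(scaling) / math.log(original_max_position_embeddings)
)
print(f"attention_factor ~= {attention_factor:.4f}")              # ~1.1902

# Both cos and sin tables carry the factor, so each q.k dot product is
# scaled by attention_factor**2; dropping it softens scores by that square.
print(f"score softening when omitted ~= {attention_factor ** 2:.2f}x")  # ~1.42x

Since the factor enters through both the cos and sin tables, the missing term shows up in attention scores as its square, which is where the ~1.42x figure quoted above comes from.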
1 parent 42b25bb commit a79521b

3 files changed

Lines changed: 88 additions & 12 deletions

examples/models/llama/model_args.py

Lines changed: 8 additions & 0 deletions
@@ -123,6 +123,14 @@ class ModelArgs:
     use_scaled_rope: bool = False  # Use scaled RoPE, introduced in llama3.1.
     rope_scale_factor: int = 8
     high_freq_factor: int = 4
+    # LongRoPE (https://arxiv.org/abs/2402.13753) used by Phi-3 / Phi-4 family.
+    # Mirrors HF's rope_scaling.{short_factor,long_factor,attention_factor}
+    # plus original_max_position_embeddings / max_position_embeddings.
+    rope_scaling_short_factor: Optional[list] = None
+    rope_scaling_long_factor: Optional[list] = None
+    original_max_position_embeddings: Optional[int] = None
+    max_position_embeddings: Optional[int] = None
+    rope_scaling_attention_factor: Optional[float] = None
     # Additional Model Metadata needed at runtime
     bos_idx: int = 1
     eos_idx: int = 3
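For orientation, a minimal sketch (not part of the diff; the helper name and its existence are hypothetical) of how these ModelArgs fields line up with HF's rope_scaling block for a longrope model:

# Hypothetical helper: map an HF-style config dict onto the new ModelArgs
# fields. The HF-side names ("rope_scaling", "short_factor", "long_factor")
# follow HF's Phi-3 / Phi-4 configs; this helper is not part of this commit.
def longrope_args_from_hf(hf_config: dict) -> dict:
    rope_scaling = hf_config.get("rope_scaling") or {}
    if rope_scaling.get("type", rope_scaling.get("rope_type")) != "longrope":
        return {}
    return {
        "rope_scaling_short_factor": rope_scaling.get("short_factor"),
        "rope_scaling_long_factor": rope_scaling.get("long_factor"),
        "rope_scaling_attention_factor": rope_scaling.get("attention_factor"),
        "original_max_position_embeddings": hf_config.get("original_max_position_embeddings"),
        "max_position_embeddings": hf_config.get("max_position_embeddings"),
    }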

examples/models/llama/rope.py

Lines changed: 75 additions & 11 deletions
Original file	Diff
@@ -136,32 +136,80 @@ def forward(
 
 # Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
 # and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L242.
-# Current only support non-long rope.
+# Supports both vanilla HF RoPE and LongRoPE
+# (https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py
+# `_compute_longrope_parameters`), used by Phi-3 / Phi-4 family.
 def hf_precompute_freqs_cis(
     dim: int,
     end: int,
     theta: float,
     partial_rotary_factor: float = 1.0,
     device: Union[str, torch.device] = "cpu",
+    short_factor: Optional[list] = None,
+    long_factor: Optional[list] = None,
+    original_max_pos: Optional[int] = None,
+    max_position_embeddings: Optional[int] = None,
+    attention_factor: Optional[float] = None,
 ):
     # Partial rotary embeddings.
     dim = int(dim * partial_rotary_factor)
 
-    # Short factor scaling.
-    freqs = 1.0 / (
+    # Compute the RoPE table in fp64 to minimize ULP-level drift; cast to fp32
+    # once at the end. Phi-4 Mini's narrow decode-time logit margins make the
+    # exported model sensitive to 1-ULP differences in freqs_cos / freqs_sin
+    # under sampling, especially on the Vulkan delegate.
+    inv_freq = 1.0 / (
         theta
-        ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)
+        ** (
+            torch.arange(0, dim, 2, device=device, dtype=torch.int64).to(torch.float64)
+            / dim
+        )
     )
-    # TODO: support long factor scaling.
+
+    # LongRoPE: divide inv_freq element-wise by short_factor or long_factor.
+    # Selection mirrors HF: long_factor when seq_len > original_max_position_embeddings.
+    longrope_active = (short_factor is not None) or (long_factor is not None)
+    if longrope_active:
+        chosen = (
+            long_factor
+            if (original_max_pos is not None and end > original_max_pos)
+            else short_factor
+        )
+        if chosen is None:
+            # Fall back to whichever factor was provided.
+            chosen = short_factor if long_factor is None else long_factor
+        ext_factors = torch.tensor(chosen, dtype=torch.float64, device=device)
+        assert ext_factors.numel() == inv_freq.numel(), (
+            f"LongRoPE factor length {ext_factors.numel()} must equal dim/2 "
+            f"({inv_freq.numel()})"
+        )
+        inv_freq = inv_freq / ext_factors
+
+        # Derive attention_factor if not provided (matches HF's
+        # _compute_longrope_parameters default).
+        if attention_factor is None and original_max_pos is not None:
+            ref_max_pos = (
+                max_position_embeddings if max_position_embeddings is not None else end
+            )
+            scaling_factor = ref_max_pos / original_max_pos
+            if scaling_factor <= 1.0:
+                attention_factor = 1.0
+            else:
+                attention_factor = math.sqrt(
+                    1 + math.log(scaling_factor) / math.log(original_max_pos)
+                )
 
     # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
-    t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as(
-        freqs  # pyre-ignore
-    )
-    freqs = torch.outer(t, freqs).float()  # pyre-ignore
+    t = torch.arange(end, device=inv_freq.device, dtype=torch.int64).to(torch.float64)
+    freqs = torch.outer(t, inv_freq).to(torch.float64)  # pyre-ignore
     emb = torch.cat((freqs, freqs), dim=-1)
-    freqs_cos = torch.cos(emb)
-    freqs_sin = torch.sin(emb)
+    cos_tab = torch.cos(emb)
+    sin_tab = torch.sin(emb)
+    if attention_factor is not None and attention_factor != 1.0:
+        cos_tab = cos_tab * attention_factor
+        sin_tab = sin_tab * attention_factor
+    freqs_cos = cos_tab.to(torch.float32)
+    freqs_sin = sin_tab.to(torch.float32)
     return freqs_cos, freqs_sin
 
 
@@ -241,9 +289,25 @@ def __init__(self, params: ModelArgs):
                 hf_precompute_freqs_cis,
                 partial_rotary_factor=self.params.partial_rotary_factor,
                 device=getattr(self.params, "device", "cpu"),
+                short_factor=getattr(self.params, "rope_scaling_short_factor", None),
+                long_factor=getattr(self.params, "rope_scaling_long_factor", None),
+                original_max_pos=getattr(
+                    self.params, "original_max_position_embeddings", None
+                ),
+                max_position_embeddings=getattr(
+                    self.params, "max_position_embeddings", None
+                ),
+                attention_factor=getattr(
+                    self.params, "rope_scaling_attention_factor", None
+                ),
             )
             self.apply_rotary_emb = hf_apply_rotary_emb
         else:
+            # NOTE: precompute_freqs_cis (the non-HF path) does not implement
+            # LongRoPE today. Models using rope_scaling.type == "longrope" must
+            # set use_hf_rope=True. If a future model needs LongRoPE on the
+            # vanilla path, mirror the short_factor/long_factor/attention_factor
+            # plumbing from hf_precompute_freqs_cis.
             self.precompute_freqs_cis = partial(
                 precompute_freqs_cis,
                 use_scaled=self.params.use_scaled_rope,
examples/models/phi_4_mini/config/config.json

Lines changed: 5 additions & 1 deletion
@@ -11,5 +11,9 @@
   "vocab_size": 200064,
   "use_hf_rope": true,
   "partial_rotary_factor": 0.75,
-  "attention_qkv_bias": false
+  "attention_qkv_bias": false,
+  "original_max_position_embeddings": 4096,
+  "max_position_embeddings": 131072,
+  "rope_scaling_short_factor": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+  "rope_scaling_long_factor": [1.0, 1.118320672, 1.250641126, 1.398617824, 1.564103225, 1.74916897, 1.956131817, 2.187582649, 2.446418898, 2.735880826, 3.059592084, 3.421605075, 3.826451687, 4.279200023, 4.785517845, 5.351743533, 5.984965424, 6.693110555, 7.485043894, 8.370679318, 9.36110372, 10.4687158, 11.70738129, 13.09260651, 14.64173252, 16.37415215, 18.31155283, 20.47818807, 22.90118105, 25.61086418, 28.64115884, 32.03, 32.1, 32.13, 32.23, 32.6, 32.61, 32.64, 32.66, 32.7, 32.71, 32.93, 32.97, 33.28, 33.49, 33.5, 44.16, 47.77]
 }
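A small sanity-check sketch for this config (not part of the commit; head_dim=128 is assumed for Phi-4 Mini since the first lines of config.json are not shown in the diff): each LongRoPE factor list must have rotary_dim / 2 entries.

import json

with open("examples/models/phi_4_mini/config/config.json") as f:
    cfg = json.load(f)

head_dim = 128  # assumption for Phi-4 Mini; not visible in this diff
expected = int(head_dim * cfg["partial_rotary_factor"]) // 2  # 96 // 2 = 48

assert len(cfg["rope_scaling_short_factor"]) == expected
assert len(cfg["rope_scaling_long_factor"]) == expected
assert cfg["original_max_position_embeddings"] < cfg["max_position_embeddings"]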
