diff --git a/swift/megatron/init.py b/swift/megatron/init.py
index 7f64559423..7b7ebc0f3e 100644
--- a/swift/megatron/init.py
+++ b/swift/megatron/init.py
@@ -745,12 +745,7 @@ def _apply_rotary_pos_emb_thd(
     Returns:
         Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
     """
-    if cp_group is not None:
-        cp_size = cp_group.size()
-    else:
-        cp_size = mpu.get_context_parallel_world_size()
-    cu_seqlens_for_batched = cu_seqlens // cp_size
-    use_batched_rope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item()
+    use_batched_rope = freqs.dim() >= 1 and freqs.shape[0] == t.shape[0]
     if not use_batched_rope:
         logger.warning_once('Using non-batched RoPE, which may affect performance.')
     kwargs = {'cp_group': cp_group} if mcore_013 else {}
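
For reference, here is a minimal standalone sketch of the new detection condition, outside of Megatron. The helper name `is_batched_rope` and the example shapes are illustrative, not part of the patch: the idea, per the diff, is that the batched RoPE path applies when `freqs` carries one frequency row per packed token, i.e. its leading dimension equals `t.shape[0]` in the THD layout, instead of deriving the expected length from `cu_seqlens` and the context-parallel world size.

```python
# Minimal sketch of the new batched-RoPE check (illustrative names/shapes).
import torch


def is_batched_rope(t: torch.Tensor, freqs: torch.Tensor) -> bool:
    # Batched RoPE expects one precomputed frequency row per packed token,
    # so the leading dim of `freqs` must match the token count t.shape[0].
    return freqs.dim() >= 1 and freqs.shape[0] == t.shape[0]


# Packed sequence of 10 tokens, 8 heads, head_dim 64; per-token freqs.
tokens = torch.randn(10, 8, 64)
freqs = torch.randn(10, 1, 1, 64)

assert is_batched_rope(tokens, freqs)          # per-token freqs -> batched path
assert not is_batched_rope(tokens, freqs[:6])  # length mismatch -> non-batched fallback
```

A side effect of comparing directly against `t.shape[0]` is that the check stays a plain Python bool in both branches, so no `.item()` call on a possibly non-tensor value is needed.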