@@ -44,11 +44,14 @@ namespace infini::ops {
4444// gathered `[T, D]` tensors to ATB Rope. The `seqlen` input is a single
4545// int32 element equal to T (all tokens treated as one batch).
4646//
47- // Restrictions (implementation choices, not ATB API limits) :
47+ // Restrictions:
4848// - `rotary_dim` must equal `head_size` (full rotation only). ATB
4949// RopeParam supports `rotaryCoeff` values 2, 4, `head_size / 2`, and `head_size` per the
50- // CANN 8.5 ATB docs; this wrapper plumbs only `rotaryCoeff=2`.
51- // - `is_neox_style` must be true.
50+ // CANN 8.5 ATB docs. This wrapper plumbs:
51+ // * `rotaryCoeff=2` when `is_neox_style=true` (half split + cat)
52+ // * `rotaryCoeff=head_size` when `is_neox_style=false` (interleave)
53+ // Partial rotary (`rotary_dim < head_size`) is not supported by either
54+ // the aclnn or ATB fused APIs; callers must pad to `head_size` upstream.
5255template <>
5356class Operator <RotaryEmbedding, Device::Type::kAscend , 1 >
5457 : public RotaryEmbedding {
@@ -57,11 +60,10 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
5760 const Tensor cos_sin_cache, int64_t head_size, int64_t rotary_dim,
5861 bool is_neox_style, Tensor query_out, Tensor key_out)
5962 : RotaryEmbedding(positions, query, key, cos_sin_cache, head_size,
60- rotary_dim, is_neox_style, query_out, key_out) {
63+ rotary_dim, is_neox_style, query_out, key_out),
64+ is_neox_style_{is_neox_style} {
6165 assert (rotary_dim == head_size &&
6266 " ATB `RotaryEmbedding` requires rotary_dim == head_size" );
63- assert (is_neox_style &&
64- " ATB `RotaryEmbedding` requires neox style (rotaryCoeff=2)" );
6567
6668 const int64_t D = head_size_;
6769 const size_t elem_sz = cos_sin_cache.element_size ();
@@ -110,10 +112,14 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
110112 cos_out_cache_ = ascend::AclTensorCache ({T, D}, acl_dt_, cos_dev_);
111113 sin_out_cache_ = ascend::AclTensorCache ({T, D}, acl_dt_, sin_dev_);
112114
113- // Create the ATB Rope operation.
115+ // Create the ATB Rope operation. `rotaryCoeff` selects the rotation
116+ // pattern: 2 for neox (split-then-rotate halves), `head_size` for
117+ // interleave (pair-wise rotate adjacent elements).
114118 atb::infer::RopeParam param;
115- param.rotaryCoeff = 2 ; // Neox half-rotation.
116- param.cosFormat = 0 ; // Inference mode.
119+ param.rotaryCoeff = is_neox_style
120+ ? 2
121+ : static_cast <int32_t >(D);
122+ param.cosFormat = 0 ; // Inference mode.
117123 atb::Status s = atb::CreateOperation (param, &op_);
118124
119125 assert (s == atb::NO_ERROR && " atb::CreateOperation(Rope) failed" );
@@ -254,8 +260,16 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
254260 }
255261
256262 private:
257- // D2H copy cos_sin_cache, split into cos/sin, neox-expand, and upload to
258- // device. Called once at construction.
263+ // D2H copy cos_sin_cache, split into cos/sin, expand to `[max_seq_len, D]`
264+ // in the layout that ATB Rope expects for the chosen `rotaryCoeff`, and
265+ // upload to device. Called once at construction.
266+ //
267+ // For `rotaryCoeff=2` (neox): cos tensor holds the same `half_D` values
268+ // duplicated front/back — `[c0 .. c_{half-1}, c0 .. c_{half-1}]`.
269+ //
270+ // For `rotaryCoeff=head_size` (interleave): cos tensor holds each of the
271+ // `half_D` values repeated pair-wise —
272+ // `[c0, c0, c1, c1, .., c_{half-1}, c_{half-1}]`.
259273 void uploadCosSinCache (const Tensor cos_sin_cache) const {
260274 const int64_t D = head_size_;
261275 const int64_t half_D = D / 2 ;
@@ -277,16 +291,35 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
277291 const auto * s_src = cache_host.data () +
278292 static_cast <size_t >(p * D + half_D + j) * elem_sz;
279293
280- std::memcpy (cos_host.data () + static_cast <size_t >(p * D + j) * elem_sz,
281- c_src, elem_sz);
282- std::memcpy (
283- cos_host.data () + static_cast <size_t >(p * D + half_D + j) * elem_sz,
284- c_src, elem_sz);
285- std::memcpy (sin_host.data () + static_cast <size_t >(p * D + j) * elem_sz,
286- s_src, elem_sz);
287- std::memcpy (
288- sin_host.data () + static_cast <size_t >(p * D + half_D + j) * elem_sz,
289- s_src, elem_sz);
294+ if (is_neox_style_) {
295+ // Neox layout: [c_j ... , c_j ...] front/back duplication.
296+ std::memcpy (
297+ cos_host.data () + static_cast <size_t >(p * D + j) * elem_sz,
298+ c_src, elem_sz);
299+ std::memcpy (
300+ cos_host.data () + static_cast <size_t >(p * D + half_D + j) * elem_sz,
301+ c_src, elem_sz);
302+ std::memcpy (
303+ sin_host.data () + static_cast <size_t >(p * D + j) * elem_sz,
304+ s_src, elem_sz);
305+ std::memcpy (
306+ sin_host.data () + static_cast <size_t >(p * D + half_D + j) * elem_sz,
307+ s_src, elem_sz);
308+ } else {
309+ // Interleave layout: each value repeated pair-wise.
310+ std::memcpy (
311+ cos_host.data () + static_cast <size_t >(p * D + 2 * j) * elem_sz,
312+ c_src, elem_sz);
313+ std::memcpy (
314+ cos_host.data () + static_cast <size_t >(p * D + 2 * j + 1 ) * elem_sz,
315+ c_src, elem_sz);
316+ std::memcpy (
317+ sin_host.data () + static_cast <size_t >(p * D + 2 * j) * elem_sz,
318+ s_src, elem_sz);
319+ std::memcpy (
320+ sin_host.data () + static_cast <size_t >(p * D + 2 * j + 1 ) * elem_sz,
321+ s_src, elem_sz);
322+ }
290323 }
291324 }
292325
@@ -296,6 +329,8 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
296329 ACL_MEMCPY_HOST_TO_DEVICE );
297330 }
298331
332+ bool is_neox_style_;
333+
299334 atb::Operation* op_ = nullptr ;
300335
301336 // Expanded cos/sin tables on device (neox or interleave layout, per `is_neox_style_`): [max_seq_len, D].
0 commit comments