Skip to content

Commit 7fd0e83

Browse files
committed
[DSL] Add TensorSpec for 1D
1 parent 2f14834 commit 7fd0e83

1 file changed

Lines changed: 28 additions & 5 deletions

File tree

quack/tensor_spec.py

Lines changed: 28 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,9 @@ class TensorSpec:
183183
static fields are preserved from the host-side template."""
184184

185185
dtype: Type[cutlass.Numeric]
186-
shape: Tuple[int, int]
186+
shape: Tuple[
187+
int, ...
188+
] # 2D for matmul operands; 1D supported for vector-with-stage aux operands
187189
stage: Optional[int] = None
188190
layout: LayoutEnum = LayoutEnum.ROW_MAJOR
189191
transposed: bool = False
@@ -333,6 +335,14 @@ def with_tma_atom(self, tma_atom, gmem) -> "TensorSpec":
333335
rather than TensorSpec args)."""
334336
return replace(self, tma_atom=tma_atom, gmem=gmem)
335337

338+
def with_gmem(self, gmem) -> "TensorSpec":
339+
"""Return a new spec with only a GMEM tensor attached.
340+
341+
Use for operands that should cross the kernel boundary as TensorSpecs but
342+
are copied with non-TMA helpers.
343+
"""
344+
return replace(self, gmem=gmem)
345+
336346
def with_smem(self, storage_or_tensor) -> "TensorSpec":
337347
"""Return a new spec with the SMEM tensor attached (call inside the kernel).
338348
Accepts either a SmemAllocator storage field (derives the tensor with this
@@ -355,13 +365,18 @@ def in_rmem(self) -> bool:
355365
return self.stage is None
356366

357367
@property
358-
def storage_shape(self) -> Tuple[int, int]:
368+
def storage_shape(self) -> Tuple[int, ...]:
369+
# 1D has nothing to transpose.
370+
if len(self.shape) == 1:
371+
return self.shape
359372
return (self.shape[1], self.shape[0]) if self.transposed else self.shape
360373

361374
def smem_layout(self, tiled_mma=None):
362375
"""Derive the SMEM layout for this operand.
363376
364377
Resolution order:
378+
0. 1D shape (vector-with-stage aux operand) → trivial `cute.make_layout`,
379+
no swizzling. Checked only after `smem_layout_override` (step 1), which still wins; bypasses the remaining matmul-side layout logic.
365380
1. `smem_layout_override` set → return it as-is.
366381
2. `smem_axis_pattern` set → axis-pattern helper (Mojo-style, role-free).
367382
3. `is_epi` → SM100 epi helper.
@@ -383,6 +398,9 @@ def smem_layout(self, tiled_mma=None):
383398
assert not self.in_rmem, "register tensor has no SMEM layout"
384399
if self.smem_layout_override is not None:
385400
return self.smem_layout_override
401+
if len(self.shape) == 1:
402+
# 1D vector-with-stage: no swizzling needed (small + naturally aligned).
403+
return cute.make_layout((self.shape[0], self.stage))
386404
if self.smem_axis_pattern is not None:
387405
if self.smem_axis_pattern == "K":
388406
return make_smem_layout_kmajor(
@@ -427,13 +445,16 @@ def with_axis_pattern(self, pattern: Literal["K", "MN"]) -> "TensorSpec":
427445
return replace(self, smem_axis_pattern=pattern)
428446

429447
def tma_copy_bytes(self) -> int:
430-
return cute.size_in_bytes(self.dtype, cute.select(self.smem_layout(), mode=[0, 1]))
448+
# 1D specs have a single non-stage mode; 2D specs have two.
449+
modes = [0] if len(self.shape) == 1 else [0, 1]
450+
return cute.size_in_bytes(self.dtype, cute.select(self.smem_layout(), mode=modes))
431451

432452
def make_tma_atom(self, op, gmem_tensor, num_multicast: int = 1):
453+
modes = [0] if len(self.shape) == 1 else [0, 1]
433454
return cpasync.make_tiled_tma_atom(
434455
op,
435456
gmem_tensor,
436-
cute.select(self.smem_layout(), mode=[0, 1]),
457+
cute.select(self.smem_layout(), mode=modes),
437458
self.storage_shape,
438459
num_multicast=num_multicast,
439460
)
@@ -458,7 +479,9 @@ def make_tma_pipeline(self, barrier_storage, producer, consumer, **kwargs):
458479
def get_smem_tensor(self, storage_field):
459480
"""Materialize the SMEM tensor backed by `storage_field` with this spec's layout."""
460481
layout = self.smem_layout()
461-
return storage_field.get_tensor(layout.outer, swizzle=layout.inner)
482+
if hasattr(layout, "outer"):
483+
return storage_field.get_tensor(layout.outer, swizzle=layout.inner)
484+
return storage_field.get_tensor(layout)
462485

463486
@property
464487
def smem_T(self) -> cute.Tensor:

0 commit comments

Comments
 (0)