@@ -183,7 +183,9 @@ class TensorSpec:
183183 static fields are preserved from the host-side template."""
184184
185185 dtype : Type [cutlass .Numeric ]
186- shape : Tuple [int , int ]
186+ shape : Tuple [
187+ int , ...
188+ ] # 2D for matmul operands; 1D supported for vector-with-stage aux operands
187189 stage : Optional [int ] = None
188190 layout : LayoutEnum = LayoutEnum .ROW_MAJOR
189191 transposed : bool = False
@@ -333,6 +335,14 @@ def with_tma_atom(self, tma_atom, gmem) -> "TensorSpec":
333335 rather than TensorSpec args)."""
334336 return replace (self , tma_atom = tma_atom , gmem = gmem )
335337
338+ def with_gmem (self , gmem ) -> "TensorSpec" :
339+ """Return a new spec with only a GMEM tensor attached.
340+
341+ Use for operands that should cross the kernel boundary as TensorSpecs but
342+ are copied with non-TMA helpers.
343+ """
344+ return replace (self , gmem = gmem )
345+
336346 def with_smem (self , storage_or_tensor ) -> "TensorSpec" :
337347 """Return a new spec with the SMEM tensor attached (call inside the kernel).
338348 Accepts either a SmemAllocator storage field (derives the tensor with this
@@ -355,13 +365,18 @@ def in_rmem(self) -> bool:
355365 return self .stage is None
356366
357367 @property
358- def storage_shape (self ) -> Tuple [int , int ]:
368+ def storage_shape (self ) -> Tuple [int , ...]:
369+ # 1D has nothing to transpose.
370+ if len (self .shape ) == 1 :
371+ return self .shape
359372 return (self .shape [1 ], self .shape [0 ]) if self .transposed else self .shape
360373
361374 def smem_layout (self , tiled_mma = None ):
362375 """Derive the SMEM layout for this operand.
363376
364377 Resolution order:
378+ 0. 1D shape (vector-with-stage aux operand) → trivial `cute.make_layout`,
379+ no swizzling. Bypasses all matmul-side layout logic.
365380 1. `smem_layout_override` set → return it as-is.
366381 2. `smem_axis_pattern` set → axis-pattern helper (Mojo-style, role-free).
367382 3. `is_epi` → SM100 epi helper.
@@ -383,6 +398,9 @@ def smem_layout(self, tiled_mma=None):
383398 assert not self .in_rmem , "register tensor has no SMEM layout"
384399 if self .smem_layout_override is not None :
385400 return self .smem_layout_override
401+ if len (self .shape ) == 1 :
402+ # 1D vector-with-stage: no swizzling needed (small + naturally aligned).
403+ return cute .make_layout ((self .shape [0 ], self .stage ))
386404 if self .smem_axis_pattern is not None :
387405 if self .smem_axis_pattern == "K" :
388406 return make_smem_layout_kmajor (
@@ -427,13 +445,16 @@ def with_axis_pattern(self, pattern: Literal["K", "MN"]) -> "TensorSpec":
427445 return replace (self , smem_axis_pattern = pattern )
428446
429447 def tma_copy_bytes (self ) -> int :
430- return cute .size_in_bytes (self .dtype , cute .select (self .smem_layout (), mode = [0 , 1 ]))
448+ # 1D specs have a single non-stage mode; 2D specs have two.
449+ modes = [0 ] if len (self .shape ) == 1 else [0 , 1 ]
450+ return cute .size_in_bytes (self .dtype , cute .select (self .smem_layout (), mode = modes ))
431451
432452 def make_tma_atom (self , op , gmem_tensor , num_multicast : int = 1 ):
453+ modes = [0 ] if len (self .shape ) == 1 else [0 , 1 ]
433454 return cpasync .make_tiled_tma_atom (
434455 op ,
435456 gmem_tensor ,
436- cute .select (self .smem_layout (), mode = [ 0 , 1 ] ),
457+ cute .select (self .smem_layout (), mode = modes ),
437458 self .storage_shape ,
438459 num_multicast = num_multicast ,
439460 )
@@ -458,7 +479,9 @@ def make_tma_pipeline(self, barrier_storage, producer, consumer, **kwargs):
458479 def get_smem_tensor (self , storage_field ):
459480 """Materialize the SMEM tensor backed by `storage_field` with this spec's layout."""
460481 layout = self .smem_layout ()
461- return storage_field .get_tensor (layout .outer , swizzle = layout .inner )
482+ if hasattr (layout , "outer" ):
483+ return storage_field .get_tensor (layout .outer , swizzle = layout .inner )
484+ return storage_field .get_tensor (layout )
462485
463486 @property
464487 def smem_T (self ) -> cute .Tensor :
0 commit comments