Skip to content

Commit bd0b521

Browse files
committed
Move ct.print logic out of AST
Signed-off-by: Ziheng Deng <zihengd@nvidia.com>
1 parent 286fb81 commit bd0b521

File tree

9 files changed

+300
-180
lines changed

9 files changed

+300
-180
lines changed

src/cuda/tile/_ir/hir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ def if_else(cond, then_block, else_block, /): ...
241241
def loop(body, iterable, /): ... # infinite if `iterable` is None
242242
def static_foreach(body, items, /): ...
243243
def build_tuple(*items): ... # Makes a tuple (i.e. returns `items`)
244+
def build_formatted_string(format, *values): ... # Creates a FormattedStringTy value
244245
def unpack(iterable, expected_len, /): ...
245246
def identity(x): ... # Identity function (i.e. returns `x`)
246247
def store_var(name, value, /): ... # Store into a named variable

src/cuda/tile/_ir/ir.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,15 @@ def as_tuple(self) -> tuple["Var", ...]:
243243
return self.items
244244

245245

246+
@dataclass
247+
class FormattedStringValue(AggregateValue):
248+
format: "Any" # StringFormat from type.py
249+
values: tuple # tuple of Var
250+
251+
def as_tuple(self) -> tuple:
252+
return self.values
253+
254+
246255
@dataclass
247256
class RangeValue(AggregateValue):
248257
start: Var

src/cuda/tile/_ir/op_impl.py

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,15 @@ class PrintfValidator:
426426
specifiers = r"([diuoxXeEfFgGaAcspn])"
427427
pattern = re.compile("%" + flags + width + precision + length + specifiers)
428428

429+
@classmethod
430+
def infer_format(cls, dtype: DType) -> str:
431+
if is_boolean(dtype) or is_integral(dtype):
432+
return '%d'
433+
elif is_float(dtype) or is_restricted_float(dtype):
434+
return '%f'
435+
else:
436+
raise TileTypeError(f"print(): cannot infer format for dtype {dtype}")
437+
429438
@classmethod
430439
def validate_dtype(cls, dtype: DType, specifier: str) -> bool:
431440
if is_boolean(dtype) or is_integral(dtype):
@@ -435,8 +444,58 @@ def validate_dtype(cls, dtype: DType, specifier: str) -> bool:
435444
else:
436445
return False
437446

438-
# Placeholder emitted by ast2hir for type-inferred format specifiers
439-
_TYPE_INFER = '\x01'
447+
# Python format spec regex: [align][sign][alt][zero][width][.precision][type]
448+
_py_spec_pattern = re.compile(
449+
r'(?P<align>[<>^])?'
450+
r'(?P<sign>[+ -])?'
451+
r'(?P<alt>\#)?'
452+
r'(?P<zero>0)?'
453+
r'(?P<width>[0-9]+)?'
454+
r'(?:\.(?P<precision>[0-9]+))?'
455+
r'(?P<type>[diouxXeEfFgGaA])?'
456+
)
457+
458+
@staticmethod
459+
def escape_str(s: str) -> str:
460+
"""Escape a literal string for use in a C printf format (replace % with %%)."""
461+
return s.replace('%', '%%')
462+
463+
@classmethod
464+
def apply_python_spec(cls, py_spec: str, dtype: DType) -> str:
465+
"""Convert a Python format spec to a complete C printf specifier for the given dtype.
466+
467+
If py_spec omits the type character, it is inferred from dtype.
468+
If py_spec includes a type character, it is validated against dtype.
469+
Raises TileTypeError on type mismatch; ValueError on unrecognised spec syntax.
470+
"""
471+
m = cls._py_spec_pattern.fullmatch(py_spec)
472+
if m is None or m.group(0) != py_spec:
473+
raise ValueError(f"print(): unsupported format spec '{py_spec}'")
474+
475+
align = m.group('align')
476+
sign = m.group('sign')
477+
alt = m.group('alt')
478+
zero = m.group('zero')
479+
width = m.group('width') or ''
480+
precision = ('.' + m.group('precision')) if m.group('precision') is not None else ''
481+
typ = m.group('type')
482+
483+
flags = ''
484+
if align == '<':
485+
flags += '-'
486+
if sign in ('+', ' '):
487+
flags += sign
488+
if alt:
489+
flags += '#'
490+
if zero and align != '<':
491+
flags += '0'
492+
493+
if typ is None:
494+
typ = cls.infer_format(dtype)[1:] # inferred type char, e.g. 'd' from '%d'
495+
elif not cls.validate_dtype(dtype, typ):
496+
raise TileTypeError(
497+
f"print(): format spec '{py_spec}' is incompatible with dtype {dtype}")
498+
return f'%{flags}{width}{precision}{typ}'
440499

441500
@classmethod
442501
def parse_format(cls, format: str, arg_types: Tuple[Union[TileTy, DType], ...]) -> str:
@@ -445,21 +504,7 @@ def parse_format(cls, format: str, arg_types: Tuple[Union[TileTy, DType], ...])
445504
tokens = []
446505
while pos < len(format):
447506
ch = format[pos]
448-
if ch == cls._TYPE_INFER:
449-
tokens.append(format[last_pos:pos])
450-
if arg_idx >= len(arg_types):
451-
raise TileTypeError("Not enough arguments for format string")
452-
dtype = get_dtype(arg_types[arg_idx])
453-
if is_boolean(dtype) or is_integral(dtype):
454-
tokens.append('%d')
455-
elif is_float(dtype) or is_restricted_float(dtype):
456-
tokens.append('%f')
457-
else:
458-
raise TileTypeError(f"Cannot infer format for dtype {dtype}")
459-
arg_idx += 1
460-
pos += 1
461-
last_pos = pos
462-
elif ch == "%":
507+
if ch == "%":
463508
tokens.append(format[last_pos:pos])
464509
last_pos = pos
465510
# escape "%%"

src/cuda/tile/_ir/ops.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
4+
import builtins
45
import enum
56
import math
67
import operator
@@ -23,7 +24,8 @@
2324
add_operation, Builder,
2425
enter_nested_block, nested_block, PhiState, LoopVarState,
2526
TupleValue, make_aggregate, RangeValue, BoundMethodValue, ArrayValue, ConstantState,
26-
ListValue, TiledViewValue, ClosureValue, MemoryEffect, attribute, operand, BlockRestriction
27+
ListValue, TiledViewValue, ClosureValue, MemoryEffect, attribute, operand,
28+
BlockRestriction, FormattedStringValue,
2729
)
2830
from .type import PointerTy
2931
from . import hir
@@ -57,7 +59,8 @@
5759
PartitionViewTy, TupleTy, TileTy, NoneType, BoundMethodTy, ArrayTy,
5860
ListTy, make_tile_ty, SliceType, DTypeConstructor, RangeIterType, Type,
5961
NONE, ModuleTy, TypeTy, LooselyTypedScalar, DTypeSpec, StringTy, InvalidType,
60-
array_size_type, ClosureTy, LiveCapturedScope, TokenTy, TiledViewTy
62+
array_size_type, ClosureTy, LiveCapturedScope, TokenTy, TiledViewTy, FormattedStringTy,
63+
StringFormat, FormattedPiece,
6164
)
6265
from cuda.tile._datatype import (
6366
DType, is_integral, is_float, is_signed, is_boolean, is_restricted_float,
@@ -1318,6 +1321,39 @@ def build_tuple(items: tuple[Var, ...]) -> Var:
13181321
return res
13191322

13201323

1324+
@impl(hir.build_formatted_string)
1325+
def build_formatted_string_impl(format: StringFormat, values: tuple[Var, ...]) -> Var:
1326+
new_pieces = []
1327+
new_values = []
1328+
for piece in format.pieces:
1329+
if isinstance(piece, str):
1330+
new_pieces.append(piece)
1331+
else:
1332+
val_var = values[piece.value_idx]
1333+
val_ty = val_var.get_type()
1334+
if isinstance(val_ty, FormattedStringTy):
1335+
if piece.format_spec is not None:
1336+
raise TileTypeError(
1337+
"f-string: cannot apply format spec to a formatted string value",
1338+
val_var.loc)
1339+
inner_val = val_var.get_aggregate()
1340+
assert isinstance(inner_val, FormattedStringValue)
1341+
offset = len(new_values)
1342+
for inner_piece in val_ty.format.pieces:
1343+
if isinstance(inner_piece, str):
1344+
new_pieces.append(inner_piece)
1345+
else:
1346+
new_pieces.append(FormattedPiece(
1347+
offset + inner_piece.value_idx, inner_piece.format_spec))
1348+
new_values.extend(inner_val.values)
1349+
else:
1350+
new_pieces.append(FormattedPiece(len(new_values), piece.format_spec))
1351+
new_values.append(val_var)
1352+
new_fmt = StringFormat(tuple(new_pieces))
1353+
ty = FormattedStringTy(new_fmt, tuple(v.get_type() for v in new_values))
1354+
return make_aggregate(FormattedStringValue(new_fmt, tuple(new_values)), ty)
1355+
1356+
13211357
@impl(hir.unpack)
13221358
def unpack_impl(iterable: Var, expected_len: Var) -> Var:
13231359
ty = iterable.get_type()
@@ -3638,6 +3674,50 @@ def printf_impl(format: Var, args: Tuple[Var, ...]) -> None:
36383674
add_operation(TilePrintf, (), format=parsed_format, args=args)
36393675

36403676

3677+
@impl(ct.print)
3678+
@impl(builtins.print)
3679+
def print_impl(args: Tuple[Var, ...], sep: Var, end: Var) -> None:
3680+
sep_str = PrintfValidator.escape_str(require_constant_str(sep))
3681+
end_str = PrintfValidator.escape_str(require_constant_str(end))
3682+
3683+
format_parts = []
3684+
leaf_vars = []
3685+
first = True
3686+
3687+
for arg_var in args:
3688+
if not first:
3689+
format_parts.append(sep_str)
3690+
else:
3691+
first = False
3692+
3693+
arg_ty = arg_var.get_type()
3694+
if isinstance(arg_ty, FormattedStringTy):
3695+
fmt_val = arg_var.get_aggregate()
3696+
assert isinstance(fmt_val, FormattedStringValue)
3697+
for piece in arg_ty.format.pieces:
3698+
if isinstance(piece, str):
3699+
format_parts.append(PrintfValidator.escape_str(piece))
3700+
else:
3701+
value_ty = arg_ty.value_types[piece.value_idx]
3702+
dtype = get_dtype(value_ty)
3703+
if piece.format_spec is not None:
3704+
format_parts.append(PrintfValidator.apply_python_spec(
3705+
piece.format_spec, dtype))
3706+
else:
3707+
format_parts.append(PrintfValidator.infer_format(dtype))
3708+
leaf_vars.append(fmt_val.values[piece.value_idx])
3709+
elif isinstance(arg_ty, StringTy):
3710+
format_parts.append(PrintfValidator.escape_str(arg_ty.value))
3711+
else:
3712+
tile_ty = require_tile_type(arg_var)
3713+
format_parts.append(PrintfValidator.infer_format(get_dtype(tile_ty)))
3714+
leaf_vars.append(arg_var)
3715+
3716+
format_parts.append(end_str)
3717+
final_format = ''.join(format_parts)
3718+
add_operation(TilePrintf, (), format=final_format, args=tuple(leaf_vars))
3719+
3720+
36413721
@dataclass(eq=False)
36423722
class TileAssert(Operation, opcode="assert", memory_effect=MemoryEffect.STORE):
36433723
message: str = attribute()

src/cuda/tile/_ir/type.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,50 @@ def map(self, unwrap: Callable[[Type], Any]) -> Tuple[Any, ...]:
210210
return tuple(unwrap(t) for t in self.value_types)
211211

212212

213+
# ============== Formatted String Type ===============
214+
215+
@dataclass(frozen=True)
216+
class FormattedPiece:
217+
"""A single typed placeholder in a formatted string."""
218+
value_idx: int # index into FormattedStringTy.value_types
219+
format_spec: str | None # None = type-inferred; otherwise Python format spec (e.g. '.2f')
220+
221+
222+
@dataclass(frozen=True)
223+
class StringFormat:
224+
"""Immutable format template for a formatted string value."""
225+
pieces: tuple[str | FormattedPiece, ...]
226+
227+
228+
@dataclass(frozen=True)
229+
class FormattedStringTy(Type):
230+
format: "StringFormat"
231+
value_types: tuple
232+
233+
def is_aggregate(self) -> bool:
234+
return True
235+
236+
def aggregate_item_types(self) -> tuple:
237+
return self.value_types
238+
239+
def make_aggregate_value(self, items: tuple) -> "AggregateValue":
240+
from .ir import FormattedStringValue
241+
return FormattedStringValue(self.format, items)
242+
243+
def __str__(self):
244+
parts = []
245+
for piece in self.format.pieces:
246+
if isinstance(piece, str):
247+
parts.append(piece)
248+
else:
249+
ty = self.value_types[piece.value_idx]
250+
if piece.format_spec is not None:
251+
parts.append(f"{{<{ty}>:{piece.format_spec}}}")
252+
else:
253+
parts.append(f"{{<{ty}>}}")
254+
return 'FormattedString<"' + "".join(parts) + '">'
255+
256+
213257
def size_to_bytecode(s: Optional[int]) -> int:
214258
return bc.DYNAMIC_SHAPE if s is None else s
215259

src/cuda/tile/_ir/typing_support.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def wrapped(handler: TypeHandler):
116116
float: lambda x=0, /: None,
117117
int: lambda x=0, /: None,
118118
bool: lambda x=False, /: None,
119+
print: lambda *args, sep=' ', end='\n': None,
119120
}
120121

121122

0 commit comments

Comments
 (0)