Skip to content

Commit 4a27fb7

Browse files
committed
:Add ct.print to suport fstring and python-tyle print
Signed-off-by: Ziheng Deng <zihengd@nvidia.com>
1 parent fb2bdd0 commit 4a27fb7

File tree

7 files changed

+368
-21
lines changed

7 files changed

+368
-21
lines changed

changelog.d/ct-print.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Added `print()` and `ct.print()` to support python-style f-strings and positional arguments

docs/source/operations.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ Utility
189189
:nosignatures:
190190

191191
printf
192+
print
192193
assert_
193194

194195

src/cuda/tile/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@
126126
permute,
127127
pow,
128128
printf,
129+
print,
129130
prod,
130131
reduce,
131132
reshape,
@@ -267,6 +268,7 @@
267268
"permute",
268269
"pow",
269270
"printf",
271+
"print",
270272
"prod",
271273
"reduce",
272274
"reshape",

src/cuda/tile/_ir/op_impl.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,19 +428,36 @@ def validate_dtype(cls, dtype: DType, specifier: str) -> bool:
428428
else:
429429
return False
430430

431+
# Placeholder emitted by ast2hir for type-inferred format specifiers
432+
_TYPE_INFER = '\x01'
433+
431434
@classmethod
432435
def parse_format(cls, format: str, arg_types: Tuple[Union[TileTy, DType], ...]) -> str:
433436
last_pos = pos = 0
434437
arg_idx = 0
435438
tokens = []
436439
while pos < len(format):
437-
if format[pos] == "%":
440+
ch = format[pos]
441+
if ch == cls._TYPE_INFER:
442+
tokens.append(format[last_pos:pos])
443+
if arg_idx >= len(arg_types):
444+
raise TileTypeError("Not enough arguments for format string")
445+
dtype = get_dtype(arg_types[arg_idx])
446+
if is_boolean(dtype) or is_integral(dtype):
447+
tokens.append('%d')
448+
elif is_float(dtype) or is_restricted_float(dtype):
449+
tokens.append('%f')
450+
else:
451+
raise TileTypeError(f"Cannot infer format for dtype {dtype}")
452+
arg_idx += 1
453+
pos += 1
454+
last_pos = pos
455+
elif ch == "%":
438456
tokens.append(format[last_pos:pos])
439457
last_pos = pos
440458
# escape "%%"
441459
if (pos + 1 < len(format) and format[pos + 1] == "%"):
442460
pos += 2
443-
continue
444461
elif (m := cls.pattern.match(format, pos)):
445462
# get a format match
446463
_, _, _, _, sp = m.groups()
@@ -458,10 +475,10 @@ def parse_format(cls, format: str, arg_types: Tuple[Union[TileTy, DType], ...])
458475
pos = m.end()
459476
tokens.append(format[last_pos:pos])
460477
last_pos = pos
461-
continue
462478
else:
463479
raise TileTypeError("Invalid format string")
464-
pos += 1
480+
else:
481+
pos += 1
465482
tokens.append(format[last_pos:pos])
466483
if arg_idx < len(arg_types):
467484
raise TileTypeError("Too many arguments for format string")

src/cuda/tile/_passes/ast2hir.py

Lines changed: 131 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import ast
6+
import builtins as _builtins
67
import inspect
78
import itertools
89
import operator
@@ -18,6 +19,7 @@
1819
from cuda.tile._ir.type import ClosureDefaultPlaceholder
1920
from cuda.tile._passes.ast_util import ast_get_all_local_names
2021
from cuda.tile._stub import static_eval, static_assert, static_iter
22+
import cuda.tile._stub as _ct_stub
2123

2224

2325
@lru_cache
@@ -248,8 +250,8 @@ def decorate(f):
248250
_expr_handlers: Dict[Type[ast.AST], Callable] = {}
249251

250252

251-
_KEYWORD_LIKE_FUNCS = (static_eval, static_assert, static_iter)
252-
_KEYWORD_LIKE_FUNC_NAMES = ("static_eval", "static_assert", "static_iter")
253+
_KEYWORD_LIKE_FUNCS = (static_eval, static_assert, static_iter, _ct_stub.print, _builtins.print)
254+
_KEYWORD_LIKE_FUNC_NAMES = ("static_eval", "static_assert", "static_iter", "print", "print")
253255

254256

255257
@_register(_expr_handlers, ast.Call)
@@ -277,6 +279,8 @@ def _call_expr(call: ast.Call, ctx: _Context) -> hir.Value:
277279
elif kwd_func == "static_iter":
278280
raise TileSyntaxError("static_iter() is only allowed as iterable in a `for` loop,"
279281
" i.e. `for i in ct.static_iter(...)`")
282+
elif kwd_func == "print":
283+
return _handle_ct_print(call, ctx)
280284
else:
281285
raise TileSyntaxError(f"{kwd_func} is not expected here")
282286
else:
@@ -378,6 +382,131 @@ def _is_cuda_module(value: ast.expr, ctx: _Context) -> bool:
378382
return ctx.frozen_globals.get(value.id) is cuda
379383

380384

385+
# ================================
386+
# ct.print() helper functions
387+
# ================================
388+
389+
def _escape_format_str(s: str) -> str:
390+
"""Escape a literal string for use in a C printf format (replace % with %%)."""
391+
return s.replace('%', '%%')
392+
393+
394+
def _python_spec_to_c_printf(py_spec: str, ctx: _Context) -> str:
395+
"""Convert a Python format spec string to a C printf format specifier."""
396+
import re
397+
m = re.fullmatch(
398+
r'(?P<align>[<>^])?'
399+
r'(?P<sign>[+ -])?'
400+
r'(?P<alt>\#)?'
401+
r'(?P<zero>0)?'
402+
r'(?P<width>[0-9]+)?'
403+
r'(?:\.(?P<precision>[0-9]+))?'
404+
r'(?P<type>[diouxXeEfFgGaA])?',
405+
py_spec)
406+
if m is None or m.group(0) != py_spec:
407+
raise ctx.syntax_error(f"ct.print(): unsupported format spec '{py_spec}'")
408+
align = m.group('align')
409+
sign = m.group('sign')
410+
alt = m.group('alt')
411+
zero = m.group('zero')
412+
width = m.group('width') or ''
413+
precision = ('.' + m.group('precision')) if m.group('precision') is not None else ''
414+
typ = m.group('type') or ''
415+
416+
flags = ''
417+
if align == '<':
418+
flags += '-'
419+
if sign in ('+', ' '):
420+
flags += sign
421+
if alt:
422+
flags += '#'
423+
if zero and align != '<':
424+
flags += '0'
425+
426+
return f'%{flags}{width}{precision}{typ}'
427+
428+
429+
def _extract_format_spec(spec_node, ctx: _Context):
430+
"""Extract explicit format spec from a FormattedValue's format_spec.
431+
Returns a C printf format specifier (e.g. '%.2f') or None for type-inferred."""
432+
if spec_node is None:
433+
return None
434+
if not isinstance(spec_node, ast.JoinedStr):
435+
raise ctx.syntax_error("ct.print(): internal error: unexpected format_spec node")
436+
if len(spec_node.values) == 0:
437+
return None
438+
if len(spec_node.values) != 1 or not isinstance(spec_node.values[0], ast.Constant):
439+
raise ctx.syntax_error(
440+
"ct.print() f-string: dynamic format specs (e.g. {x:{width}}) are not supported")
441+
py_spec = str(spec_node.values[0].value)
442+
return _python_spec_to_c_printf(py_spec, ctx)
443+
444+
445+
def _process_fstring(node: ast.JoinedStr, format_parts: list, tile_var_hirs: list,
446+
ctx: _Context) -> None:
447+
"""Decompose a JoinedStr (f-string) into format template parts and HIR vars."""
448+
for part in node.values:
449+
if isinstance(part, ast.Constant):
450+
format_parts.append(_escape_format_str(str(part.value)))
451+
elif isinstance(part, ast.FormattedValue):
452+
if part.conversion != -1:
453+
raise ctx.syntax_error(
454+
"ct.print() f-string: !r, !s, !a conversions are not supported")
455+
c_spec = _extract_format_spec(part.format_spec, ctx)
456+
if c_spec is not None:
457+
format_parts.append(c_spec)
458+
else:
459+
format_parts.append('\x01')
460+
tile_var_hirs.append(_expr(part.value, ctx))
461+
else:
462+
raise ctx.syntax_error("ct.print(): unsupported f-string component")
463+
464+
465+
def _require_str_constant(node: ast.expr, ctx: _Context, kwarg_name: str) -> str:
466+
"""Require a keyword argument to be a string constant at AST level."""
467+
if not isinstance(node, ast.Constant) or not isinstance(node.value, str):
468+
raise ctx.syntax_error(
469+
f"ct.print(): keyword argument '{kwarg_name}' must be a string constant")
470+
return node.value
471+
472+
473+
def _handle_ct_print(call: ast.Call, ctx: _Context) -> hir.Value:
474+
"""Handle ct.print() calls by decomposing f-strings and building HIR."""
475+
sep = ' '
476+
end = '\n'
477+
for kw in call.keywords:
478+
if kw.arg == 'sep':
479+
sep = _require_str_constant(kw.value, ctx, 'sep')
480+
elif kw.arg == 'end':
481+
end = _require_str_constant(kw.value, ctx, 'end')
482+
else:
483+
raise ctx.syntax_error(
484+
f"ct.print() got unexpected keyword argument '{kw.arg}'")
485+
486+
format_parts = []
487+
tile_var_hirs = []
488+
first = True
489+
490+
for arg_node in call.args:
491+
if not first:
492+
format_parts.append(_escape_format_str(sep))
493+
first = False
494+
495+
if isinstance(arg_node, ast.JoinedStr):
496+
_process_fstring(arg_node, format_parts, tile_var_hirs, ctx)
497+
elif isinstance(arg_node, ast.Constant) and isinstance(arg_node.value, str):
498+
format_parts.append(_escape_format_str(arg_node.value))
499+
else:
500+
format_parts.append('\x01')
501+
tile_var_hirs.append(_expr(arg_node, ctx))
502+
503+
format_parts.append(_escape_format_str(end))
504+
format_template = ''.join(format_parts)
505+
506+
template_hir = ctx.call(hir.identity, (format_template,))
507+
return ctx.call(_ct_stub.printf, (template_hir, *tile_var_hirs))
508+
509+
381510
@_register(_expr_handlers, ast.Name)
382511
def _name_expr(name: ast.Name, ctx: Any) -> hir.Value:
383512
if not isinstance(name.ctx, ast.Load):

src/cuda/tile/_stub.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2035,6 +2035,38 @@ def my_print_kernel():
20352035
"""
20362036

20372037

2038+
@function
2039+
def print(*args, sep: str = ' ', end: str = '\n') -> None:
2040+
"""Print values at runtime from the device using Python-style syntax.
2041+
2042+
Supports Python f-strings and positional arguments similar to Python's
2043+
built-in ``print()`` function.
2044+
2045+
Args:
2046+
*args: Values to print. Each argument can be:
2047+
- A string literal or f-string
2048+
- A tile value (format inferred from dtype: int→``%d``, float→``%f``)
2049+
sep (str): Separator inserted between arguments (default: ``' '``)
2050+
end (str): String appended after the last argument (default: ``'\\n'``)
2051+
2052+
Examples:
2053+
2054+
>>> tile = ct.arange(4, dtype=ct.int32)
2055+
>>> ct.print(f"tile={tile}")
2056+
>>> ct.print(f"x={tile:.5f}", end='')
2057+
>>> ct.print("tile:", tile, sep='=')
2058+
2059+
Notes:
2060+
This operation has significant overhead, and should only be used
2061+
for debugging purposes.
2062+
2063+
F-string expressions must evaluate to tile values. Constant compile-time
2064+
values are supported as string-formatted segments.
2065+
2066+
Use ``opt_level=0`` to prevent block-level output interleaving.
2067+
"""
2068+
2069+
20382070
@function
20392071
def assert_(cond, /, message=None) -> None:
20402072
"""Assert that all elements of the given tile are True.

0 commit comments

Comments
 (0)