|
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
5 | 5 | import ast |
| 6 | +import builtins as _builtins |
6 | 7 | import inspect |
7 | 8 | import itertools |
8 | 9 | import operator |
|
18 | 19 | from cuda.tile._ir.type import ClosureDefaultPlaceholder |
19 | 20 | from cuda.tile._passes.ast_util import ast_get_all_local_names |
20 | 21 | from cuda.tile._stub import static_eval, static_assert, static_iter |
| 22 | +import cuda.tile._stub as _ct_stub |
21 | 23 |
|
22 | 24 |
|
23 | 25 | @lru_cache |
@@ -248,8 +250,8 @@ def decorate(f): |
248 | 250 | _expr_handlers: Dict[Type[ast.AST], Callable] = {} |
249 | 251 |
|
250 | 252 |
|
251 | | -_KEYWORD_LIKE_FUNCS = (static_eval, static_assert, static_iter) |
252 | | -_KEYWORD_LIKE_FUNC_NAMES = ("static_eval", "static_assert", "static_iter") |
| 253 | +_KEYWORD_LIKE_FUNCS = (static_eval, static_assert, static_iter, _ct_stub.print, _builtins.print) |
| 254 | +_KEYWORD_LIKE_FUNC_NAMES = ("static_eval", "static_assert", "static_iter", "print", "print") |
253 | 255 |
|
254 | 256 |
|
255 | 257 | @_register(_expr_handlers, ast.Call) |
@@ -277,6 +279,8 @@ def _call_expr(call: ast.Call, ctx: _Context) -> hir.Value: |
277 | 279 | elif kwd_func == "static_iter": |
278 | 280 | raise TileSyntaxError("static_iter() is only allowed as iterable in a `for` loop," |
279 | 281 | " i.e. `for i in ct.static_iter(...)`") |
| 282 | + elif kwd_func == "print": |
| 283 | + return _handle_ct_print(call, ctx) |
280 | 284 | else: |
281 | 285 | raise TileSyntaxError(f"{kwd_func} is not expected here") |
282 | 286 | else: |
@@ -378,6 +382,131 @@ def _is_cuda_module(value: ast.expr, ctx: _Context) -> bool: |
378 | 382 | return ctx.frozen_globals.get(value.id) is cuda |
379 | 383 |
|
380 | 384 |
|
| 385 | +# ================================ |
| 386 | +# ct.print() helper functions |
| 387 | +# ================================ |
| 388 | + |
| 389 | +def _escape_format_str(s: str) -> str: |
| 390 | + """Escape a literal string for use in a C printf format (replace % with %%).""" |
| 391 | + return s.replace('%', '%%') |
| 392 | + |
| 393 | + |
| 394 | +def _python_spec_to_c_printf(py_spec: str, ctx: _Context) -> str: |
| 395 | + """Convert a Python format spec string to a C printf format specifier.""" |
| 396 | + import re |
| 397 | + m = re.fullmatch( |
| 398 | + r'(?P<align>[<>^])?' |
| 399 | + r'(?P<sign>[+ -])?' |
| 400 | + r'(?P<alt>\#)?' |
| 401 | + r'(?P<zero>0)?' |
| 402 | + r'(?P<width>[0-9]+)?' |
| 403 | + r'(?:\.(?P<precision>[0-9]+))?' |
| 404 | + r'(?P<type>[diouxXeEfFgGaA])?', |
| 405 | + py_spec) |
| 406 | + if m is None or m.group(0) != py_spec: |
| 407 | + raise ctx.syntax_error(f"ct.print(): unsupported format spec '{py_spec}'") |
| 408 | + align = m.group('align') |
| 409 | + sign = m.group('sign') |
| 410 | + alt = m.group('alt') |
| 411 | + zero = m.group('zero') |
| 412 | + width = m.group('width') or '' |
| 413 | + precision = ('.' + m.group('precision')) if m.group('precision') is not None else '' |
| 414 | + typ = m.group('type') or '' |
| 415 | + |
| 416 | + flags = '' |
| 417 | + if align == '<': |
| 418 | + flags += '-' |
| 419 | + if sign in ('+', ' '): |
| 420 | + flags += sign |
| 421 | + if alt: |
| 422 | + flags += '#' |
| 423 | + if zero and align != '<': |
| 424 | + flags += '0' |
| 425 | + |
| 426 | + return f'%{flags}{width}{precision}{typ}' |
| 427 | + |
| 428 | + |
| 429 | +def _extract_format_spec(spec_node, ctx: _Context): |
| 430 | + """Extract explicit format spec from a FormattedValue's format_spec. |
| 431 | + Returns a C printf format specifier (e.g. '%.2f') or None for type-inferred.""" |
| 432 | + if spec_node is None: |
| 433 | + return None |
| 434 | + if not isinstance(spec_node, ast.JoinedStr): |
| 435 | + raise ctx.syntax_error("ct.print(): internal error: unexpected format_spec node") |
| 436 | + if len(spec_node.values) == 0: |
| 437 | + return None |
| 438 | + if len(spec_node.values) != 1 or not isinstance(spec_node.values[0], ast.Constant): |
| 439 | + raise ctx.syntax_error( |
| 440 | + "ct.print() f-string: dynamic format specs (e.g. {x:{width}}) are not supported") |
| 441 | + py_spec = str(spec_node.values[0].value) |
| 442 | + return _python_spec_to_c_printf(py_spec, ctx) |
| 443 | + |
| 444 | + |
| 445 | +def _process_fstring(node: ast.JoinedStr, format_parts: list, tile_var_hirs: list, |
| 446 | + ctx: _Context) -> None: |
| 447 | + """Decompose a JoinedStr (f-string) into format template parts and HIR vars.""" |
| 448 | + for part in node.values: |
| 449 | + if isinstance(part, ast.Constant): |
| 450 | + format_parts.append(_escape_format_str(str(part.value))) |
| 451 | + elif isinstance(part, ast.FormattedValue): |
| 452 | + if part.conversion != -1: |
| 453 | + raise ctx.syntax_error( |
| 454 | + "ct.print() f-string: !r, !s, !a conversions are not supported") |
| 455 | + c_spec = _extract_format_spec(part.format_spec, ctx) |
| 456 | + if c_spec is not None: |
| 457 | + format_parts.append(c_spec) |
| 458 | + else: |
| 459 | + format_parts.append('\x01') |
| 460 | + tile_var_hirs.append(_expr(part.value, ctx)) |
| 461 | + else: |
| 462 | + raise ctx.syntax_error("ct.print(): unsupported f-string component") |
| 463 | + |
| 464 | + |
| 465 | +def _require_str_constant(node: ast.expr, ctx: _Context, kwarg_name: str) -> str: |
| 466 | + """Require a keyword argument to be a string constant at AST level.""" |
| 467 | + if not isinstance(node, ast.Constant) or not isinstance(node.value, str): |
| 468 | + raise ctx.syntax_error( |
| 469 | + f"ct.print(): keyword argument '{kwarg_name}' must be a string constant") |
| 470 | + return node.value |
| 471 | + |
| 472 | + |
| 473 | +def _handle_ct_print(call: ast.Call, ctx: _Context) -> hir.Value: |
| 474 | + """Handle ct.print() calls by decomposing f-strings and building HIR.""" |
| 475 | + sep = ' ' |
| 476 | + end = '\n' |
| 477 | + for kw in call.keywords: |
| 478 | + if kw.arg == 'sep': |
| 479 | + sep = _require_str_constant(kw.value, ctx, 'sep') |
| 480 | + elif kw.arg == 'end': |
| 481 | + end = _require_str_constant(kw.value, ctx, 'end') |
| 482 | + else: |
| 483 | + raise ctx.syntax_error( |
| 484 | + f"ct.print() got unexpected keyword argument '{kw.arg}'") |
| 485 | + |
| 486 | + format_parts = [] |
| 487 | + tile_var_hirs = [] |
| 488 | + first = True |
| 489 | + |
| 490 | + for arg_node in call.args: |
| 491 | + if not first: |
| 492 | + format_parts.append(_escape_format_str(sep)) |
| 493 | + first = False |
| 494 | + |
| 495 | + if isinstance(arg_node, ast.JoinedStr): |
| 496 | + _process_fstring(arg_node, format_parts, tile_var_hirs, ctx) |
| 497 | + elif isinstance(arg_node, ast.Constant) and isinstance(arg_node.value, str): |
| 498 | + format_parts.append(_escape_format_str(arg_node.value)) |
| 499 | + else: |
| 500 | + format_parts.append('\x01') |
| 501 | + tile_var_hirs.append(_expr(arg_node, ctx)) |
| 502 | + |
| 503 | + format_parts.append(_escape_format_str(end)) |
| 504 | + format_template = ''.join(format_parts) |
| 505 | + |
| 506 | + template_hir = ctx.call(hir.identity, (format_template,)) |
| 507 | + return ctx.call(_ct_stub.printf, (template_hir, *tile_var_hirs)) |
| 508 | + |
| 509 | + |
381 | 510 | @_register(_expr_handlers, ast.Name) |
382 | 511 | def _name_expr(name: ast.Name, ctx: Any) -> hir.Value: |
383 | 512 | if not isinstance(name.ctx, ast.Load): |
|
0 commit comments