|
6 | 6 | from dataclasses import dataclass, field |
7 | 7 | import itertools |
8 | 8 | from collections import defaultdict |
9 | | -from typing import Sequence |
10 | | - |
11 | | -from typing_extensions import override |
12 | | - |
13 | | -from cuda.tile._datatype import DType |
14 | 9 | from cuda.tile._ir.ir import ( |
15 | 10 | Block as TileBlock, |
16 | 11 | Builder as TileBuilder, |
|
22 | 17 | attribute, |
23 | 18 | add_operation, |
24 | 19 | format_var, |
25 | | - AggregateValue, TypingHooks, |
| 20 | + AggregateValue, |
26 | 21 | ) |
27 | | -from cuda.tile._ir.type import TensorLikeTy, TileTy |
28 | 22 |
|
29 | 23 |
|
30 | 24 | class Builder: |
@@ -91,18 +85,13 @@ def to_string( |
91 | 85 | return f"{' ' * indent}^{self._name}({params}):\n{ops}" |
92 | 86 |
|
93 | 87 |
|
94 | | -class _LangTypingHooks(TypingHooks): |
95 | | - @override |
96 | | - def get_tensor_like_type(self, dtype: DType, shape: Sequence[int]) -> TensorLikeTy: |
97 | | - return TileTy(dtype, shape) |
98 | | - |
99 | | - |
100 | 88 | class IRContext(TileIRContext): |
101 | 89 | def __init__(self, log_ir_on_error: bool = True): |
| 90 | + from cuda.lang._ir.type import LangTypingHooks |
102 | 91 | self._block_names: dict[int, str] = {} |
103 | 92 | self._block_counter: dict[str, itertools.count] = defaultdict(itertools.count) |
104 | 93 | super().__init__(log_ir_on_error, tileiras_version=None, |
105 | | - typing_hooks=_LangTypingHooks()) |
| 94 | + typing_hooks=LangTypingHooks()) |
106 | 95 |
|
107 | 96 | def make_block(self, name: str, loc: Loc, params: tuple[Var, ...] = ()) -> Block: |
108 | 97 | block = Block(self, loc) |
|
0 commit comments