-
Notifications
You must be signed in to change notification settings - Fork 66
Expand file tree
/
Copy pathprotocol.py
More file actions
118 lines (94 loc) · 4.31 KB
/
Copy pathprotocol.py
File metadata and controls
118 lines (94 loc) · 4.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors
import ctypes
from itertools import chain
from types import SimpleNamespace
from typing import List, Protocol, runtime_checkable
from .._mlir import ir
@runtime_checkable
class DslType(Protocol):
    """Structural interface for DSL values that decompose into IR values.

    An implementer can be flattened into a list of ``ir.Value`` objects and
    rebuilt from such a list. The module helpers ``extract_to_ir_values`` and
    ``construct_from_ir_values`` dispatch on these two hooks.
    """

    @classmethod
    def __construct_from_ir_values__(cls, values: List[ir.Value]) -> "DslType":
        """Build an instance of ``cls`` from the given slice of IR values."""

    def __extract_to_ir_values__(self) -> List[ir.Value]:
        """Return the IR values that make up this object, in order."""
@runtime_checkable
class JitArgument(Protocol):
    """Structural interface for objects that describe their own IR types and C pointers.

    ``get_ir_types`` and ``get_c_pointers`` in this module call these hooks
    directly when an object provides them.
    """

    def __get_ir_types__(self) -> List[ir.Type]:
        """Return the IR types this object contributes."""

    def __get_c_pointers__(self) -> List[ctypes.c_void_p]:
        """Return the C pointers this object contributes."""
@runtime_checkable
class Storable(Protocol):
    """Structural interface for types that can be read from / written through a pointer.

    Every hook is a classmethod, so it describes the type rather than an
    instance. The helpers ``dsl_size_of``, ``dsl_align_of``, ``peek_from_ptr``
    and ``poke_into_ptr`` in this module dispatch to these hooks.
    """

    @classmethod
    def __dsl_size_of__(cls) -> int:
        """Return the storage size of the type (units presumably bytes — confirm)."""

    @classmethod
    def __dsl_align_of__(cls) -> int:
        """Return the required alignment of the type (units presumably bytes — confirm)."""

    @classmethod
    def __peek_from_ptr__(cls, ptr: ir.Value):
        """Produce a value of this type read via the IR pointer ``ptr``."""

    @classmethod
    def __poke_into_ptr__(cls, ptr: ir.Value, value):
        """Write ``value`` via the IR pointer ``ptr``."""
def get_ir_types(obj) -> List[ir.Type]:
    """Return the flat list of IR types that ``obj`` contributes.

    Resolution order:

    * an ``ir.Value`` contributes its own ``.type``;
    * an object with ``__get_ir_types__`` reports its types directly;
    * an object with ``__extract_to_ir_values__`` contributes the types of
      the values it extracts;
    * ``SimpleNamespace`` attributes (in attribute order) and tuple/list
      elements are flattened recursively.

    Raises:
        TypeError: if ``obj`` (or a nested member) matches none of the above.
    """
    if isinstance(obj, ir.Value):
        return [obj.type]
    if hasattr(obj, "__get_ir_types__"):
        return obj.__get_ir_types__()
    if hasattr(obj, "__extract_to_ir_values__"):
        return [v.type for v in obj.__extract_to_ir_values__()]

    if isinstance(obj, SimpleNamespace):
        children = vars(obj).values()
    elif isinstance(obj, (tuple, list)):
        children = obj
    else:
        raise TypeError(f"Cannot derive IR types from {obj}")
    return [ty for child in children for ty in get_ir_types(child)]
def get_c_pointers(obj) -> List[ctypes.c_void_p]:
    """Return the flat list of C pointers that ``obj`` contributes.

    An object with ``__get_c_pointers__`` answers for itself. Otherwise
    ``SimpleNamespace`` attributes (in attribute order) and tuple/list
    elements are flattened recursively.

    Raises:
        TypeError: if ``obj`` (or a nested member) provides no C pointers.
    """
    if hasattr(obj, "__get_c_pointers__"):
        return obj.__get_c_pointers__()

    if isinstance(obj, SimpleNamespace):
        parts = vars(obj).values()
    elif isinstance(obj, (tuple, list)):
        parts = obj
    else:
        raise TypeError(f"Cannot derive C pointers from {obj}")

    pointers = []
    for part in parts:
        pointers.extend(get_c_pointers(part))
    return pointers
def extract_to_ir_values(obj) -> List[ir.Value]:
    """Flatten ``obj`` into the ordered list of ``ir.Value`` objects it holds.

    A bare ``ir.Value`` becomes a one-element list. An object with
    ``__extract_to_ir_values__`` supplies its own values. ``SimpleNamespace``
    attributes (in attribute order) and tuple/list elements are flattened
    recursively.

    Raises:
        TypeError: if ``obj`` (or a nested member) cannot be flattened.
    """
    if isinstance(obj, ir.Value):
        return [obj]
    if hasattr(obj, "__extract_to_ir_values__"):
        return obj.__extract_to_ir_values__()

    if isinstance(obj, SimpleNamespace):
        members = vars(obj).values()
    elif isinstance(obj, (tuple, list)):
        members = obj
    else:
        raise TypeError(f"Cannot extract IR values from {obj}")
    return list(chain.from_iterable(map(extract_to_ir_values, members)))
def _infer_dsl_type(value):
    """Derive the ``dsl_type`` used to rebuild a ``SimpleNamespace`` member.

    Follows the ``type(value)`` convention for namespace members. The
    exception is plain ``tuple``/``list`` values: they become a container of
    per-element types (recursively), which is the shape the tuple/list path
    of ``construct_from_ir_values`` expects for ``dsl_type``. Tuple/list
    subclasses keep plain ``type(value)``.
    """
    if type(value) in (tuple, list):
        return type(value)(_infer_dsl_type(v) for v in value)
    return type(value)


def construct_from_ir_values(dsl_type, args, values: List[ir.Value]) -> DslType:
    """Rebuild a DSL-level value from a flat list of IR values.

    ``args`` is a prototype object. ``get_ir_types(member)`` decides how many
    entries of ``values`` each component consumes. ``dsl_type`` names what
    to reconstruct.

    Dispatch order:

    * ``args`` is a ``SimpleNamespace``: each attribute is rebuilt
      recursively, and a new ``SimpleNamespace`` is returned.
    * ``dsl_type`` defines ``__construct_from_ir_values__``: delegate to it
      with all of ``values``.
    * ``dsl_type`` is a tuple/list of per-element types: rebuild element-wise
      against the matching elements of ``args``, and return a container of
      ``type(dsl_type)``.
    * ``dsl_type`` is ``ir.Value`` or a subclass: return the single IR value
      as-is. This mirrors ``extract_to_ir_values``, which flattens an
      ``ir.Value`` ``v`` to ``[v]``.

    Raises:
        ValueError: if ``len(values)`` does not match what ``args`` needs
            (namespace, tuple/list and ``ir.Value`` cases), or if ``dsl_type``
            and ``args`` differ in length (from ``zip(strict=True)``).
        TypeError: if no reconstruction rule applies.
    """
    if isinstance(args, SimpleNamespace):
        rebuilt = {}
        cursor = 0
        for name, value in vars(args).items():
            n = len(get_ir_types(value))
            rebuilt[name] = construct_from_ir_values(
                _infer_dsl_type(value), value, values[cursor : cursor + n]
            )
            cursor += n
        if cursor != len(values):
            raise ValueError(f"SimpleNamespace expected {cursor} ir.Values, got {len(values)}")
        return SimpleNamespace(**rebuilt)

    if hasattr(dsl_type, "__construct_from_ir_values__"):
        return dsl_type.__construct_from_ir_values__(values)

    if isinstance(dsl_type, (tuple, list)):
        elems = []
        cursor = 0
        for ty, arg in zip(dsl_type, args, strict=True):
            n = len(get_ir_types(arg))
            elems.append(construct_from_ir_values(ty, arg, values[cursor : cursor + n]))
            cursor += n
        # Same consistency check as the SimpleNamespace path: leftover or
        # missing values mean ``values`` does not correspond to ``args``.
        if cursor != len(values):
            raise ValueError(
                f"{type(dsl_type).__name__} expected {cursor} ir.Values, got {len(values)}"
            )
        return type(dsl_type)(elems)

    # Raw IR values (e.g. namespace members whose type(...) is an ir.Value
    # subclass) have no construction hook; they round-trip unchanged.
    if isinstance(dsl_type, type) and issubclass(dsl_type, ir.Value):
        if len(values) != 1:
            raise ValueError(f"ir.Value expected 1 ir.Value, got {len(values)}")
        return values[0]

    raise TypeError(f"Cannot construct DSL value for {dsl_type}")
def dsl_size_of(dsl_type) -> int:
    """Return the storage size reported by ``dsl_type.__dsl_size_of__()``.

    Raises:
        TypeError: if ``dsl_type`` has no ``__dsl_size_of__`` hook.
    """
    if not hasattr(dsl_type, "__dsl_size_of__"):
        raise TypeError(f"type {dsl_type} does not implement the Storable protocol")
    return dsl_type.__dsl_size_of__()
def dsl_align_of(dsl_type) -> int:
    """Return the alignment reported by ``dsl_type.__dsl_align_of__()``.

    Raises:
        TypeError: if ``dsl_type`` has no ``__dsl_align_of__`` hook.
    """
    if not hasattr(dsl_type, "__dsl_align_of__"):
        raise TypeError(f"type {dsl_type} does not implement the Storable protocol")
    return dsl_type.__dsl_align_of__()
def peek_from_ptr(dsl_type, ptr: ir.Value):
    """Delegate to ``dsl_type.__peek_from_ptr__(ptr)`` and return its result.

    Raises:
        TypeError: if ``dsl_type`` has no ``__peek_from_ptr__`` hook.
    """
    if not hasattr(dsl_type, "__peek_from_ptr__"):
        raise TypeError(f"type {dsl_type} does not implement the Storable protocol")
    return dsl_type.__peek_from_ptr__(ptr)
def poke_into_ptr(dsl_type, ptr: ir.Value, value):
    """Delegate to ``dsl_type.__poke_into_ptr__(ptr, value)`` and return its result.

    Raises:
        TypeError: if ``dsl_type`` has no ``__poke_into_ptr__`` hook.
    """
    if not hasattr(dsl_type, "__poke_into_ptr__"):
        raise TypeError(f"type {dsl_type} does not implement the Storable protocol")
    return dsl_type.__poke_into_ptr__(ptr, value)