|
4 | 4 | #include "cinderx/Jit/hir/resolve_kwargs.h" |
5 | 5 |
|
6 | 6 | #include "cinderx/Jit/hir/hir.h" |
7 | | -#include "cinderx/Jit/hir/hir_c_api.h" |
8 | | -#include "cinderx/Jit/hir/hir_type_c.h" |
9 | | -#include "cinderx/Common/log.h" |
10 | | -#include "cinderx/Common/code.h" |
11 | | - |
12 | | -#include <Python.h> |
13 | | - |
14 | | -#include <vector> |
15 | 7 |
|
16 | 8 | namespace jit::hir { |
17 | 9 |
|
18 | | -static inline HirType to_hir(Type t) { return Type::toHirType(t); } |
19 | | - |
20 | | -namespace { |
21 | | - |
22 | | -// Try to resolve keyword arguments in a VectorCall to positional order. |
23 | | -// Returns true if the instruction was replaced. |
24 | | -bool resolveVectorCallKwargs(VectorCall* call) { |
25 | | - Register* target = call->func(); |
26 | | - |
27 | | - // Need a known function to get parameter names. |
28 | | - auto target_ty = target->type(); |
29 | | - HirType target_hir = to_hir(target_ty); |
30 | | - HirType tfunc_hir = to_hir(TFunc); |
31 | | - if (!hir_type_has_value_spec(&target_hir, tfunc_hir)) { |
32 | | - return false; |
33 | | - } |
34 | | - |
35 | | - PyFunctionObject* callee = (PyFunctionObject*)hir_type_object_spec(&target_hir); |
36 | | - PyCodeObject* code = (PyCodeObject*)callee->func_code; |
37 | | - |
38 | | - // Skip callees with **kwargs or *args — cannot resolve statically. |
39 | | - if (code->co_flags & (CO_VARKEYWORDS | CO_VARARGS)) { |
40 | | - return false; |
41 | | - } |
42 | | - |
43 | | - // Get the kwnames tuple from the last operand. |
44 | | - std::size_t total_operands = call->NumOperands(); |
45 | | - if (total_operands < 2) { |
46 | | - return false; // Need at least func + 1 arg |
47 | | - } |
48 | | - |
49 | | - Register* kwnames_reg = call->GetOperand(total_operands - 1); |
50 | | - auto kwnames_ty = kwnames_reg->type(); |
51 | | - HirType kwnames_hir = to_hir(kwnames_ty); |
52 | | - if (!hir_type_has_object_spec(&kwnames_hir)) { |
53 | | - return false; // kwnames not a known constant |
54 | | - } |
55 | | - |
56 | | - PyObject* kwnames_obj = hir_type_object_spec(&kwnames_hir); |
57 | | - if (!PyTuple_Check(kwnames_obj)) { |
58 | | - return false; |
59 | | - } |
60 | | - |
61 | | - Py_ssize_t n_kw = PyTuple_GET_SIZE(kwnames_obj); |
62 | | - if (n_kw == 0) { |
63 | | - return false; |
64 | | - } |
65 | | - |
66 | | - // Operand layout: [func, arg0, arg1, ..., argN-1, kwnames] |
67 | | - // Total args = total_operands - 2 (subtract func and kwnames) |
68 | | - std::size_t total_args = total_operands - 2; |
69 | | - std::size_t n_pos = total_args - static_cast<std::size_t>(n_kw); |
70 | | - |
71 | | - // Get parameter names from callee code object. |
72 | | - int co_argcount = code->co_argcount; |
73 | | - int co_kwonly = code->co_kwonlyargcount; |
74 | | - int total_params = co_argcount + co_kwonly; |
75 | | - |
76 | | - // The total args provided must not exceed the callee parameter count. |
77 | | - if (static_cast<int>(total_args) > total_params) { |
78 | | - return false; |
79 | | - } |
80 | | - |
81 | | - // Build mapping: for each callee param position, which call-site operand |
82 | | - // provides the value? |
83 | | - // First n_pos call-site args map to first n_pos callee params (positional). |
84 | | - // Remaining args are keyword — look up their position in co_varnames. |
85 | | - std::vector<Register*> reordered(total_args, nullptr); |
86 | | - |
87 | | - // Copy positional args (they are already in the right position). |
88 | | - for (std::size_t i = 0; i < n_pos; i++) { |
89 | | - reordered[i] = call->GetOperand(i + 1); // +1 to skip func |
90 | | - } |
91 | | - |
92 | | - // Map keyword args to callee parameter positions. |
93 | | - PyObject* varnames = PyCode_GetVarnames(code); |
94 | | - if (!PyTuple_Check(varnames)) { |
95 | | - return false; |
96 | | - } |
97 | | - |
98 | | - for (Py_ssize_t kw_idx = 0; kw_idx < n_kw; kw_idx++) { |
99 | | - PyObject* kwname = PyTuple_GET_ITEM(kwnames_obj, kw_idx); |
100 | | - Register* kw_arg = call->GetOperand(n_pos + kw_idx + 1); // +1 for func |
101 | | - |
102 | | - // Find this keyword name in the callee parameters. |
103 | | - bool found = false; |
104 | | - for (int param_idx = 0; param_idx < total_params; param_idx++) { |
105 | | - PyObject* param_name = PyTuple_GET_ITEM(varnames, param_idx); |
106 | | - int cmp = PyUnicode_Compare(kwname, param_name); |
107 | | - if (cmp == 0 && !PyErr_Occurred()) { |
108 | | - if (reordered[param_idx] != nullptr) { |
109 | | - // Duplicate argument — bail out, let Python handle the error. |
110 | | - return false; |
111 | | - } |
112 | | - reordered[param_idx] = kw_arg; |
113 | | - found = true; |
114 | | - break; |
115 | | - } |
116 | | - if (PyErr_Occurred()) { |
117 | | - PyErr_Clear(); |
118 | | - return false; |
119 | | - } |
120 | | - } |
121 | | - |
122 | | - if (!found) { |
123 | | - // Keyword arg does not match any callee parameter — bail out. |
124 | | - return false; |
125 | | - } |
126 | | - } |
127 | | - |
128 | | - // Verify all provided positions are filled (no gaps). |
129 | | - for (std::size_t i = 0; i < total_args; i++) { |
130 | | - if (reordered[i] == nullptr) { |
131 | | - return false; // Missing argument — let runtime handle it. |
132 | | - } |
133 | | - } |
134 | | - |
135 | | - // Create replacement VectorCall without kwargs. |
136 | | - // New operand count: func + total_args (no kwnames). |
137 | | - std::size_t new_num_operands = total_args + 1; |
138 | | - CallFlags new_flags = static_cast<CallFlags>(static_cast<uint32_t>(call->flags()) & ~static_cast<uint32_t>(CallFlags::KwArgs)); |
139 | | - |
140 | | - auto* new_call = static_cast<VectorCall*>(hir_c_create_vectorcall_reg(new_num_operands, call->output(), static_cast<uint32_t>(new_flags))); |
141 | | - new_call->SetOperand(0, target); // func |
142 | | - for (std::size_t i = 0; i < total_args; i++) { |
143 | | - new_call->SetOperand(i + 1, reordered[i]); |
144 | | - } |
145 | | - |
146 | | - // Copy frame state and bytecode offset. |
147 | | - if (auto* fs = call->frameState()) { |
148 | | - new_call->setFrameState(*fs); |
149 | | - } |
150 | | - new_call->setBytecodeOffset(call->bytecodeOffset()); |
151 | | - |
152 | | - call->ReplaceWith(*new_call); |
153 | | - return true; |
154 | | -} |
155 | | - |
156 | | -// Try to resolve keyword arguments in a CallMethod to positional order. |
157 | | -bool resolveCallMethodKwargs(CallMethod* call) { |
158 | | - Register* target = call->func(); |
159 | | - |
160 | | - auto target_ty2 = target->type(); |
161 | | - HirType target_hir2 = to_hir(target_ty2); |
162 | | - HirType tfunc_hir2 = to_hir(TFunc); |
163 | | - if (!hir_type_has_value_spec(&target_hir2, tfunc_hir2)) { |
164 | | - return false; |
165 | | - } |
166 | | - |
167 | | - PyFunctionObject* callee = (PyFunctionObject*)hir_type_object_spec(&target_hir2); |
168 | | - PyCodeObject* code = (PyCodeObject*)callee->func_code; |
169 | | - |
170 | | - if (code->co_flags & (CO_VARKEYWORDS | CO_VARARGS)) { |
171 | | - return false; |
172 | | - } |
173 | | - |
174 | | - // CallMethod operand layout: [func, self, arg0, ..., argN-1, kwnames] |
175 | | - std::size_t total_operands = call->NumOperands(); |
176 | | - if (total_operands < 3) { |
177 | | - return false; |
178 | | - } |
179 | | - |
180 | | - Register* kwnames_reg = call->GetOperand(total_operands - 1); |
181 | | - auto kwnames_ty2 = kwnames_reg->type(); |
182 | | - HirType kwnames_hir2 = to_hir(kwnames_ty2); |
183 | | - if (!hir_type_has_object_spec(&kwnames_hir2)) { |
184 | | - return false; |
185 | | - } |
186 | | - |
187 | | - PyObject* kwnames_obj = hir_type_object_spec(&kwnames_hir2); |
188 | | - if (!PyTuple_Check(kwnames_obj)) { |
189 | | - return false; |
190 | | - } |
191 | | - |
192 | | - Py_ssize_t n_kw = PyTuple_GET_SIZE(kwnames_obj); |
193 | | - if (n_kw == 0) { |
194 | | - return false; |
195 | | - } |
196 | | - |
197 | | - // Total args = total_operands - 3 (subtract func, self, kwnames) |
198 | | - std::size_t total_args = total_operands - 3; |
199 | | - std::size_t n_pos = total_args - static_cast<std::size_t>(n_kw); |
200 | | - |
201 | | - int co_argcount = code->co_argcount; |
202 | | - int co_kwonly = code->co_kwonlyargcount; |
203 | | - // For methods, first param is self which is separate in CallMethod. |
204 | | - // co_argcount includes self for bound methods, but CallMethod has self |
205 | | - // as a separate operand. Adjust: effective params = co_argcount - 1 + co_kwonly. |
206 | | - // Actually, for module-level functions called via CALL, self is NULL/unused. |
207 | | - // This needs careful handling — for now, use the args as-is. |
208 | | - int total_params = co_argcount + co_kwonly; |
209 | | - |
210 | | - if (static_cast<int>(total_args) > total_params) { |
211 | | - return false; |
212 | | - } |
213 | | - |
214 | | - std::vector<Register*> reordered(total_args, nullptr); |
215 | | - |
216 | | - for (std::size_t i = 0; i < n_pos; i++) { |
217 | | - reordered[i] = call->GetOperand(i + 2); // +2 to skip func and self |
218 | | - } |
219 | | - |
220 | | - PyObject* varnames = PyCode_GetVarnames(code); |
221 | | - if (!PyTuple_Check(varnames)) { |
222 | | - return false; |
223 | | - } |
224 | | - |
225 | | - for (Py_ssize_t kw_idx = 0; kw_idx < n_kw; kw_idx++) { |
226 | | - PyObject* kwname = PyTuple_GET_ITEM(kwnames_obj, kw_idx); |
227 | | - Register* kw_arg = call->GetOperand(n_pos + kw_idx + 2); |
228 | | - |
229 | | - bool found = false; |
230 | | - for (int param_idx = 0; param_idx < total_params; param_idx++) { |
231 | | - PyObject* param_name = PyTuple_GET_ITEM(varnames, param_idx); |
232 | | - int cmp = PyUnicode_Compare(kwname, param_name); |
233 | | - if (cmp == 0 && !PyErr_Occurred()) { |
234 | | - if (reordered[param_idx] != nullptr) { |
235 | | - return false; |
236 | | - } |
237 | | - reordered[param_idx] = kw_arg; |
238 | | - found = true; |
239 | | - break; |
240 | | - } |
241 | | - if (PyErr_Occurred()) { |
242 | | - PyErr_Clear(); |
243 | | - return false; |
244 | | - } |
245 | | - } |
246 | | - |
247 | | - if (!found) { |
248 | | - return false; |
249 | | - } |
250 | | - } |
251 | | - |
252 | | - for (std::size_t i = 0; i < total_args; i++) { |
253 | | - if (reordered[i] == nullptr) { |
254 | | - return false; |
255 | | - } |
256 | | - } |
257 | | - |
258 | | - // Create replacement CallMethod without kwargs. |
259 | | - std::size_t new_num_operands = total_args + 2; // func + self + args |
260 | | - CallFlags new_flags = static_cast<CallFlags>(static_cast<uint32_t>(call->flags()) & ~static_cast<uint32_t>(CallFlags::KwArgs)); |
261 | | - |
262 | | - auto* new_call = static_cast<CallMethod*>(hir_c_create_call_method_reg(new_num_operands, call->output(), static_cast<uint32_t>(new_flags))); |
263 | | - new_call->SetOperand(0, target); |
264 | | - new_call->SetOperand(1, call->self()); |
265 | | - for (std::size_t i = 0; i < total_args; i++) { |
266 | | - new_call->SetOperand(i + 2, reordered[i]); |
267 | | - } |
268 | | - |
269 | | - if (auto* fs = call->frameState()) { |
270 | | - new_call->setFrameState(*fs); |
271 | | - } |
272 | | - new_call->setBytecodeOffset(call->bytecodeOffset()); |
273 | | - |
274 | | - call->ReplaceWith(*new_call); |
275 | | - return true; |
276 | | -} |
277 | | - |
278 | | -} // namespace |
279 | | - |
280 | 10 | void ResolveKwargs::Run(Function& irfunc) { |
281 | | - int resolved = 0; |
282 | | - |
283 | | - for (auto& block : irfunc.cfg.blocks) { |
284 | | - // Collect instructions to process (can not modify while iterating). |
285 | | - std::vector<Instr*> kwargs_instrs; |
286 | | - for (auto& instr : block) { |
287 | | - if (instr.IsVectorCall()) { |
288 | | - auto* call = static_cast<VectorCall*>(&instr); |
289 | | - if (call->flags() & CallFlags::KwArgs) { |
290 | | - kwargs_instrs.push_back(&instr); |
291 | | - } |
292 | | - } else if (instr.IsCallMethod()) { |
293 | | - auto* call = static_cast<CallMethod*>(&instr); |
294 | | - if (call->flags() & CallFlags::KwArgs) { |
295 | | - kwargs_instrs.push_back(&instr); |
296 | | - } |
297 | | - } |
298 | | - } |
299 | | - |
300 | | - for (auto* instr : kwargs_instrs) { |
301 | | - if (instr->IsVectorCall()) { |
302 | | - if (resolveVectorCallKwargs(static_cast<VectorCall*>(instr))) { |
303 | | - resolved++; |
304 | | - } |
305 | | - } else if (instr->IsCallMethod()) { |
306 | | - if (resolveCallMethodKwargs(static_cast<CallMethod*>(instr))) { |
307 | | - resolved++; |
308 | | - } |
309 | | - } |
310 | | - } |
311 | | - } |
312 | | - |
313 | | - if (resolved > 0) { |
314 | | - JIT_DLOG("ResolveKwargs: resolved {} keyword calls to positional in {}", |
315 | | - resolved, irfunc.fullname); |
316 | | - } |
| 11 | + hir_resolve_kwargs_run(static_cast<void*>(&irfunc)); |
317 | 12 | } |
318 | 13 |
|
319 | 14 | } // namespace jit::hir |
320 | | - |
321 | | -extern "C" void hir_resolve_kwargs_run(HirFunction func) { |
322 | | - jit::hir::ResolveKwargs{}.Run( |
323 | | - *static_cast<jit::hir::Function*>(func)); |
324 | | -} |
0 commit comments