Skip to content

Commit 4d53ddb

Browse files
committed
ZJIT: Handle caller_kwarg in direct send when all keyword params are required
1 parent 1c0573c commit 4d53ddb

4 files changed

Lines changed: 200 additions & 16 deletions

File tree

zjit/src/codegen.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,13 @@ fn gen_send_without_block_direct(
13451345
let mut c_args = vec![recv];
13461346
c_args.extend(&args);
13471347

1348+
if unsafe { rb_get_iseq_flags_has_kw(iseq) } {
1349+
// Currently we only get to this point if all the accepted keyword args are required.
1350+
let unspecified_bits = 0;
1351+
// For each optional keyword that isn't passed we would `unspecified_bits |= (0x01 << idx)`.
1352+
c_args.push(VALUE::fixnum_from_usize(unspecified_bits).into());
1353+
}
1354+
13481355
let params = unsafe { iseq.params() };
13491356
let num_optionals_passed = if params.flags.has_opt() != 0 {
13501357
// See vm_call_iseq_setup_normal_opt_start in vm_inshelper.c

zjit/src/hir.rs

Lines changed: 146 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,10 @@ pub enum SendFallbackReason {
615615
SendWithoutBlockNotOptimizedMethodTypeOptimized(OptimizedMethodType),
616616
SendWithoutBlockBopRedefined,
617617
SendWithoutBlockOperandsNotFixnum,
618+
SendWithoutBlockDirectKeywordMismatch,
619+
SendWithoutBlockDirectOptionalKeywords,
620+
SendWithoutBlockDirectKeywordCountMismatch,
621+
SendWithoutBlockDirectMissingKeyword,
618622
SendPolymorphic,
619623
SendMegamorphic,
620624
SendNoProfiles,
@@ -631,6 +635,8 @@ pub enum SendFallbackReason {
631635
/// The call has at least one feature on the caller or callee side that the optimizer does not
632636
/// support.
633637
ComplexArgPass,
638+
/// Caller has keyword arguments but callee doesn't expect them; need to convert to hash.
639+
UnexpectedKeywordArgs,
634640
/// Initial fallback reason for every instruction, which should be mutated to
635641
/// a more actionable reason when an attempt to specialize the instruction fails.
636642
Uncategorized(ruby_vminsn_type),
@@ -1527,7 +1533,20 @@ fn can_direct_send(function: &mut Function, block: BlockId, iseq: *const rb_iseq
15271533
use Counter::*;
15281534
if 0 != params.flags.has_rest() { count_failure(complex_arg_pass_param_rest) }
15291535
if 0 != params.flags.has_post() { count_failure(complex_arg_pass_param_post) }
1530-
if 0 != params.flags.has_kw() { count_failure(complex_arg_pass_param_kw) }
1536+
1537+
// We support required-only keywords, but not optional keywords yet
1538+
if 0 != params.flags.has_kw() {
1539+
let keyword = params.keyword;
1540+
if !keyword.is_null() {
1541+
let num = unsafe { (*keyword).num };
1542+
let required_num = unsafe { (*keyword).required_num };
1543+
// Only support required keywords for now (no optional keywords)
1544+
if num != required_num {
1545+
count_failure(complex_arg_pass_param_kw_opt)
1546+
}
1547+
}
1548+
}
1549+
15311550
if 0 != params.flags.has_kwrest() { count_failure(complex_arg_pass_param_kwrest) }
15321551
if 0 != params.flags.has_block() { count_failure(complex_arg_pass_param_block) }
15331552
if 0 != params.flags.forwardable() { count_failure(complex_arg_pass_param_forwardable) }
@@ -1537,12 +1556,16 @@ fn can_direct_send(function: &mut Function, block: BlockId, iseq: *const rb_iseq
15371556
return false;
15381557
}
15391558

1540-
// Because we exclude e.g. post parameters above, they are also excluded from the sum below.
1559+
// Check argument count against callee's parameters. Note that correctness for this calculation
1560+
// relies on rejecting features above.
15411561
let lead_num = params.lead_num;
15421562
let opt_num = params.opt_num;
1563+
let keyword = params.keyword;
1564+
let kw_req_num = if keyword.is_null() { 0 } else { unsafe { (*keyword).required_num } };
1565+
let req_num = lead_num + kw_req_num;
15431566
can_send = c_int::try_from(args.len())
15441567
.as_ref()
1545-
.map(|argc| (lead_num..=lead_num + opt_num).contains(argc))
1568+
.map(|argc| (req_num..=req_num + opt_num).contains(argc))
15461569
.unwrap_or(false);
15471570
if !can_send {
15481571
function.set_dynamic_send_reason(send_insn, ArgcParamMismatch);
@@ -2250,6 +2273,72 @@ impl Function {
22502273
}
22512274
}
22522275

2276+
/// Reorder keyword arguments to match the callee's expectation.
2277+
///
2278+
/// Returns Ok with reordered arguments if successful, or Err with the fallback reason if not.
2279+
fn reorder_keyword_arguments(
2280+
&self,
2281+
args: &[InsnId],
2282+
kwarg: *const rb_callinfo_kwarg,
2283+
iseq: IseqPtr,
2284+
) -> Result<Vec<InsnId>, SendFallbackReason> {
2285+
let callee_keyword = unsafe { rb_get_iseq_body_param_keyword(iseq) };
2286+
if callee_keyword.is_null() {
2287+
// Caller is passing kwargs but callee doesn't expect them.
2288+
return Err(SendWithoutBlockDirectKeywordMismatch);
2289+
}
2290+
2291+
let caller_kw_count = unsafe { get_cikw_keyword_len(kwarg) } as usize;
2292+
let callee_kw_count = unsafe { (*callee_keyword).num } as usize;
2293+
let callee_kw_required = unsafe { (*callee_keyword).required_num } as usize;
2294+
let callee_kw_table = unsafe { (*callee_keyword).table };
2295+
2296+
// For now, only handle the case where all keywords are required.
2297+
if callee_kw_count != callee_kw_required {
2298+
return Err(SendWithoutBlockDirectOptionalKeywords);
2299+
}
2300+
if caller_kw_count != callee_kw_count {
2301+
return Err(SendWithoutBlockDirectKeywordCountMismatch);
2302+
}
2303+
2304+
// The keyword arguments are the last arguments in the args vector.
2305+
let kw_args_start = args.len() - caller_kw_count;
2306+
2307+
// Build a mapping from caller keywords to their positions.
2308+
let mut caller_kw_order: Vec<ID> = Vec::with_capacity(caller_kw_count);
2309+
for i in 0..caller_kw_count {
2310+
let sym = unsafe { get_cikw_keywords_idx(kwarg, i as i32) };
2311+
let id = unsafe { rb_sym2id(sym) };
2312+
caller_kw_order.push(id);
2313+
}
2314+
2315+
// Reorder keyword arguments to match callee expectation.
2316+
let mut reordered_kw_args: Vec<InsnId> = Vec::with_capacity(callee_kw_count);
2317+
for i in 0..callee_kw_count {
2318+
let expected_id = unsafe { *callee_kw_table.add(i) };
2319+
2320+
// Find where this keyword is in the caller's order
2321+
let mut found = false;
2322+
for (j, &caller_id) in caller_kw_order.iter().enumerate() {
2323+
if caller_id == expected_id {
2324+
reordered_kw_args.push(args[kw_args_start + j]);
2325+
found = true;
2326+
break;
2327+
}
2328+
}
2329+
2330+
if !found {
2331+
// Required keyword not provided by caller which will raise an ArgumentError.
2332+
return Err(SendWithoutBlockDirectMissingKeyword);
2333+
}
2334+
}
2335+
2336+
// Replace the keyword arguments with the reordered ones.
2337+
let mut processed_args = args[..kw_args_start].to_vec();
2338+
processed_args.extend(reordered_kw_args);
2339+
Ok(processed_args)
2340+
}
2341+
22532342
/// Resolve the receiver type for method dispatch optimization.
22542343
///
22552344
/// Takes the receiver's Type, receiver HIR instruction, and ISEQ instruction index.
@@ -2465,6 +2554,7 @@ impl Function {
24652554
cme = unsafe { rb_aliased_callable_method_entry(cme) };
24662555
def_type = unsafe { get_cme_def_type(cme) };
24672556
}
2557+
24682558
if def_type == VM_METHOD_TYPE_ISEQ {
24692559
// TODO(max): Allow non-iseq; cache cme
24702560
// Only specialize positional-positional calls
@@ -2480,7 +2570,29 @@ impl Function {
24802570
if let Some(profiled_type) = profiled_type {
24812571
recv = self.push_insn(block, Insn::GuardType { val: recv, guard_type: Type::from_profiled_type(profiled_type), state });
24822572
}
2483-
let send_direct = self.push_insn(block, Insn::SendWithoutBlockDirect { recv, cd, cme, iseq, args, state });
2573+
2574+
// Check if caller is passing keywords but callee doesn't expect them.
2575+
let kwarg = unsafe { rb_vm_ci_kwarg(ci) };
2576+
if !kwarg.is_null() && !unsafe { rb_get_iseq_flags_has_kw(iseq) } {
2577+
// Caller has keywords but callee doesn't; Need to convert to hash.
2578+
self.set_dynamic_send_reason(insn_id, UnexpectedKeywordArgs);
2579+
self.push_insn_id(block, insn_id); continue;
2580+
}
2581+
2582+
// Handle keyword argument reordering if present.
2583+
let processed_args = if !kwarg.is_null() {
2584+
match self.reorder_keyword_arguments(&args, kwarg, iseq) {
2585+
Ok(reordered) => reordered,
2586+
Err(reason) => {
2587+
self.set_dynamic_send_reason(insn_id, reason);
2588+
self.push_insn_id(block, insn_id); continue;
2589+
}
2590+
}
2591+
} else {
2592+
args.clone()
2593+
};
2594+
2595+
let send_direct = self.push_insn(block, Insn::SendWithoutBlockDirect { recv, cd, cme, iseq, args: processed_args, state });
24842596
self.make_equal_to(insn_id, send_direct);
24852597
} else if def_type == VM_METHOD_TYPE_BMETHOD {
24862598
let procv = unsafe { rb_get_def_bmethod_proc((*cme).def) };
@@ -2515,7 +2627,29 @@ impl Function {
25152627
if let Some(profiled_type) = profiled_type {
25162628
recv = self.push_insn(block, Insn::GuardType { val: recv, guard_type: Type::from_profiled_type(profiled_type), state });
25172629
}
2518-
let send_direct = self.push_insn(block, Insn::SendWithoutBlockDirect { recv, cd, cme, iseq, args, state });
2630+
2631+
// Check if caller is passing keywords but callee doesn't expect them.
2632+
let kwarg = unsafe { rb_vm_ci_kwarg(ci) };
2633+
if !kwarg.is_null() && !unsafe { rb_get_iseq_flags_has_kw(iseq) } {
2634+
// Caller has keywords but callee doesn't; Need to convert to hash.
2635+
self.set_dynamic_send_reason(insn_id, UnexpectedKeywordArgs);
2636+
self.push_insn_id(block, insn_id); continue;
2637+
}
2638+
2639+
// Handle keyword argument reordering if present.
2640+
let processed_args = if !kwarg.is_null() {
2641+
match self.reorder_keyword_arguments(&args, kwarg, iseq) {
2642+
Ok(reordered) => reordered,
2643+
Err(reason) => {
2644+
self.set_dynamic_send_reason(insn_id, reason);
2645+
self.push_insn_id(block, insn_id); continue;
2646+
}
2647+
}
2648+
} else {
2649+
args.clone()
2650+
};
2651+
2652+
let send_direct = self.push_insn(block, Insn::SendWithoutBlockDirect { recv, cd, cme, iseq, args: processed_args, state });
25192653
self.make_equal_to(insn_id, send_direct);
25202654
} else if def_type == VM_METHOD_TYPE_IVAR && args.is_empty() {
25212655
// Check if we're accessing ivars of a Class or Module object as they require single-ractor mode.
@@ -3026,7 +3160,7 @@ impl Function {
30263160

30273161
// When seeing &block argument, fall back to dynamic dispatch for now
30283162
// TODO: Support block forwarding
3029-
if unspecializable_call_type(ci_flags) {
3163+
if unspecializable_c_call_type(ci_flags) {
30303164
fun.count_complex_call_features(block, ci_flags);
30313165
fun.set_dynamic_send_reason(send_insn_id, ComplexArgPass);
30323166
return Err(());
@@ -4905,9 +5039,14 @@ fn unhandled_call_type(flags: u32) -> Result<(), CallType> {
49055039
Ok(())
49065040
}
49075041

5042+
/// If a given call to a c func uses overly complex arguments, then we won't specialize.
5043+
fn unspecializable_c_call_type(flags: u32) -> bool {
5044+
((flags & VM_CALL_KWARG) != 0) ||
5045+
unspecializable_call_type(flags)
5046+
}
5047+
49085048
/// If a given call uses overly complex arguments, then we won't specialize.
49095049
fn unspecializable_call_type(flags: u32) -> bool {
4910-
((flags & VM_CALL_KWARG) != 0) ||
49115050
((flags & VM_CALL_ARGS_SPLAT) != 0) ||
49125051
((flags & VM_CALL_ARGS_BLOCKARG) != 0)
49135052
}

zjit/src/hir/opt_tests.rs

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2760,9 +2760,9 @@ mod hir_opt_tests {
27602760
}
27612761

27622762
#[test]
2763-
fn dont_specialize_call_to_iseq_with_kw() {
2763+
fn specialize_call_to_iseq_with_required_kw() {
27642764
eval("
2765-
def foo(a:) = 1
2765+
def foo(a:) = a * 2
27662766
def test = foo(a: 1)
27672767
test
27682768
test
@@ -2778,7 +2778,35 @@ mod hir_opt_tests {
27782778
Jump bb2(v4)
27792779
bb2(v6:BasicObject):
27802780
v11:Fixnum[1] = Const Value(1)
2781-
IncrCounter complex_arg_pass_caller_kwarg
2781+
PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010)
2782+
PatchPoint NoSingletonClass(Object@0x1000)
2783+
v20:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)]
2784+
v21:BasicObject = SendWithoutBlockDirect v20, :foo (0x1038), v11
2785+
CheckInterrupts
2786+
Return v21
2787+
");
2788+
}
2789+
2790+
#[test]
2791+
fn test_send_call_to_iseq_with_optional_kw() {
2792+
eval("
2793+
def foo(a: 1) = a
2794+
def test = foo(a: 2)
2795+
test
2796+
test
2797+
");
2798+
assert_snapshot!(hir_string("test"), @r"
2799+
fn test@<compiled>:3:
2800+
bb0():
2801+
EntryPoint interpreter
2802+
v1:BasicObject = LoadSelf
2803+
Jump bb2(v1)
2804+
bb1(v4:BasicObject):
2805+
EntryPoint JIT(0)
2806+
Jump bb2(v4)
2807+
bb2(v6:BasicObject):
2808+
v11:Fixnum[2] = Const Value(2)
2809+
IncrCounter complex_arg_pass_param_kw_opt
27822810
v13:BasicObject = SendWithoutBlock v6, :foo, v11
27832811
CheckInterrupts
27842812
Return v13
@@ -2804,7 +2832,7 @@ mod hir_opt_tests {
28042832
Jump bb2(v4)
28052833
bb2(v6:BasicObject):
28062834
v11:Fixnum[1] = Const Value(1)
2807-
IncrCounter complex_arg_pass_caller_kwarg
2835+
IncrCounter complex_arg_pass_param_kwrest
28082836
v13:BasicObject = SendWithoutBlock v6, :foo, v11
28092837
CheckInterrupts
28102838
Return v13
@@ -2829,7 +2857,7 @@ mod hir_opt_tests {
28292857
EntryPoint JIT(0)
28302858
Jump bb2(v4)
28312859
bb2(v6:BasicObject):
2832-
IncrCounter complex_arg_pass_param_kw
2860+
IncrCounter complex_arg_pass_param_kw_opt
28332861
v11:BasicObject = SendWithoutBlock v6, :foo
28342862
CheckInterrupts
28352863
Return v11
@@ -2881,7 +2909,6 @@ mod hir_opt_tests {
28812909
v11:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000))
28822910
v12:StringExact = StringCopy v11
28832911
v14:Fixnum[1] = Const Value(1)
2884-
IncrCounter complex_arg_pass_caller_kwarg
28852912
v16:BasicObject = SendWithoutBlock v6, :sprintf, v12, v14
28862913
CheckInterrupts
28872914
Return v16
@@ -3179,7 +3206,7 @@ mod hir_opt_tests {
31793206
v13:NilClass = Const Value(nil)
31803207
PatchPoint MethodRedefined(Hash@0x1008, new@0x1009, cme:0x1010)
31813208
v46:HashExact = ObjectAllocClass Hash:VALUE(0x1008)
3182-
IncrCounter complex_arg_pass_param_kw
3209+
IncrCounter complex_arg_pass_param_kw_opt
31833210
IncrCounter complex_arg_pass_param_block
31843211
v20:BasicObject = SendWithoutBlock v46, :initialize
31853212
CheckInterrupts
@@ -8904,7 +8931,7 @@ mod hir_opt_tests {
89048931
bb2(v6:BasicObject):
89058932
v11:Fixnum[1] = Const Value(1)
89068933
IncrCounter complex_arg_pass_param_rest
8907-
IncrCounter complex_arg_pass_param_kw
8934+
IncrCounter complex_arg_pass_param_kw_opt
89088935
IncrCounter complex_arg_pass_param_kwrest
89098936
IncrCounter complex_arg_pass_param_block
89108937
v13:BasicObject = SendWithoutBlock v6, :fancy, v11

zjit/src/stats.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,10 @@ make_counters! {
221221
send_fallback_too_many_args_for_lir,
222222
send_fallback_send_without_block_bop_redefined,
223223
send_fallback_send_without_block_operands_not_fixnum,
224+
send_fallback_send_without_block_direct_keyword_mismatch,
225+
send_fallback_send_without_block_direct_optional_keywords,
226+
send_fallback_send_without_block_direct_keyword_count_mismatch,
227+
send_fallback_send_without_block_direct_missing_keyword,
224228
send_fallback_send_polymorphic,
225229
send_fallback_send_megamorphic,
226230
send_fallback_send_no_profiles,
@@ -230,6 +234,8 @@ make_counters! {
230234
// The call has at least one feature on the caller or callee side
231235
// that the optimizer does not support.
232236
send_fallback_one_or_more_complex_arg_pass,
237+
// Caller has keyword arguments but callee doesn't expect them.
238+
send_fallback_unexpected_keyword_args,
233239
send_fallback_bmethod_non_iseq_proc,
234240
send_fallback_obj_to_string_not_string,
235241
send_fallback_send_cfunc_variadic,
@@ -344,7 +350,7 @@ make_counters! {
344350
// Unsupported parameter features
345351
complex_arg_pass_param_rest,
346352
complex_arg_pass_param_post,
347-
complex_arg_pass_param_kw,
353+
complex_arg_pass_param_kw_opt,
348354
complex_arg_pass_param_kwrest,
349355
complex_arg_pass_param_block,
350356
complex_arg_pass_param_forwardable,
@@ -542,12 +548,17 @@ pub fn send_fallback_counter(reason: crate::hir::SendFallbackReason) -> Counter
542548
TooManyArgsForLir => send_fallback_too_many_args_for_lir,
543549
SendWithoutBlockBopRedefined => send_fallback_send_without_block_bop_redefined,
544550
SendWithoutBlockOperandsNotFixnum => send_fallback_send_without_block_operands_not_fixnum,
551+
SendWithoutBlockDirectKeywordMismatch => send_fallback_send_without_block_direct_keyword_mismatch,
552+
SendWithoutBlockDirectOptionalKeywords => send_fallback_send_without_block_direct_optional_keywords,
553+
SendWithoutBlockDirectKeywordCountMismatch=> send_fallback_send_without_block_direct_keyword_count_mismatch,
554+
SendWithoutBlockDirectMissingKeyword => send_fallback_send_without_block_direct_missing_keyword,
545555
SendPolymorphic => send_fallback_send_polymorphic,
546556
SendMegamorphic => send_fallback_send_megamorphic,
547557
SendNoProfiles => send_fallback_send_no_profiles,
548558
SendCfuncVariadic => send_fallback_send_cfunc_variadic,
549559
SendCfuncArrayVariadic => send_fallback_send_cfunc_array_variadic,
550560
ComplexArgPass => send_fallback_one_or_more_complex_arg_pass,
561+
UnexpectedKeywordArgs => send_fallback_unexpected_keyword_args,
551562
ArgcParamMismatch => send_fallback_argc_param_mismatch,
552563
BmethodNonIseqProc => send_fallback_bmethod_non_iseq_proc,
553564
SendNotOptimizedMethodType(_) => send_fallback_send_not_optimized_method_type,

0 commit comments

Comments
 (0)