Skip to content

Commit e9374ae

Browse files
committed
ZJIT: Handle caller_kwarg in direct send when all keyword params are required
1 parent 821b650 commit e9374ae

3 files changed

Lines changed: 185 additions & 14 deletions

File tree

zjit/src/hir.rs

Lines changed: 138 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,10 @@ pub enum SendFallbackReason {
615615
SendWithoutBlockNotOptimizedMethodTypeOptimized(OptimizedMethodType),
616616
SendWithoutBlockBopRedefined,
617617
SendWithoutBlockOperandsNotFixnum,
618+
SendWithoutBlockDirectKeywordMismatch,
619+
SendWithoutBlockDirectOptionalKeywords,
620+
SendWithoutBlockDirectKeywordCountMismatch,
621+
SendWithoutBlockDirectMissingKeyword,
618622
SendPolymorphic,
619623
SendMegamorphic,
620624
SendNoProfiles,
@@ -631,6 +635,8 @@ pub enum SendFallbackReason {
631635
/// The call has at least one feature on the caller or callee side that the optimizer does not
632636
/// support.
633637
ComplexArgPass,
638+
/// Caller has keyword arguments but callee doesn't expect them; need to convert to hash.
639+
UnexpectedKeywordArgs,
634640
/// Initial fallback reason for every instruction, which should be mutated to
635641
/// a more actionable reason when an attempt to specialize the instruction fails.
636642
Uncategorized(ruby_vminsn_type),
@@ -1520,7 +1526,20 @@ fn can_direct_send(function: &mut Function, block: BlockId, iseq: *const rb_iseq
15201526
use Counter::*;
15211527
if 0 != params.flags.has_rest() { count_failure(complex_arg_pass_param_rest) }
15221528
if 0 != params.flags.has_post() { count_failure(complex_arg_pass_param_post) }
1523-
if 0 != params.flags.has_kw() { count_failure(complex_arg_pass_param_kw) }
1529+
1530+
// We support required-only keywords, but not optional keywords yet
1531+
if 0 != params.flags.has_kw() {
1532+
let keyword = params.keyword;
1533+
if !keyword.is_null() {
1534+
let num = unsafe { (*keyword).num };
1535+
let required_num = unsafe { (*keyword).required_num };
1536+
// Only support required keywords for now (no optional keywords)
1537+
if num != required_num {
1538+
count_failure(complex_arg_pass_param_kw_opt)
1539+
}
1540+
}
1541+
}
1542+
15241543
if 0 != params.flags.has_kwrest() { count_failure(complex_arg_pass_param_kwrest) }
15251544
if 0 != params.flags.has_block() { count_failure(complex_arg_pass_param_block) }
15261545
if 0 != params.flags.forwardable() { count_failure(complex_arg_pass_param_forwardable) }
@@ -1530,12 +1549,16 @@ fn can_direct_send(function: &mut Function, block: BlockId, iseq: *const rb_iseq
15301549
return false;
15311550
}
15321551

1533-
// Because we exclude e.g. post parameters above, they are also excluded from the sum below.
1552+
// Check argument count against callee's parameters. Note that correctness for this calculation
1553+
// relies on rejecting features above.
15341554
let lead_num = params.lead_num;
15351555
let opt_num = params.opt_num;
1556+
let keyword = params.keyword;
1557+
let kw_req_num = if keyword.is_null() { 0 } else { unsafe { (*keyword).required_num } };
1558+
let req_num = lead_num + kw_req_num;
15361559
can_send = c_int::try_from(args.len())
15371560
.as_ref()
1538-
.map(|argc| (lead_num..=lead_num + opt_num).contains(argc))
1561+
.map(|argc| (req_num..=req_num + opt_num).contains(argc))
15391562
.unwrap_or(false);
15401563
if !can_send {
15411564
function.set_dynamic_send_reason(send_insn, ArgcParamMismatch);
@@ -2230,6 +2253,72 @@ impl Function {
22302253
}
22312254
}
22322255

2256+
/// Reorder keyword arguments to match the callee's expectation.
2257+
///
2258+
/// Returns Ok with reordered arguments if successful, or Err with the fallback reason if not.
2259+
fn reorder_keyword_arguments(
2260+
&self,
2261+
args: &[InsnId],
2262+
kwarg: *const rb_callinfo_kwarg,
2263+
iseq: IseqPtr,
2264+
) -> Result<Vec<InsnId>, SendFallbackReason> {
2265+
let callee_keyword = unsafe { get_iseq_body_param_keyword(iseq) };
2266+
if callee_keyword.is_null() {
2267+
// Caller is passing kwargs but callee doesn't expect them.
2268+
return Err(SendWithoutBlockDirectKeywordMismatch);
2269+
}
2270+
2271+
let caller_kw_count = unsafe { get_cikw_keyword_len(kwarg) } as usize;
2272+
let callee_kw_count = unsafe { (*callee_keyword).num } as usize;
2273+
let callee_kw_required = unsafe { (*callee_keyword).required_num } as usize;
2274+
let callee_kw_table = unsafe { (*callee_keyword).table };
2275+
2276+
// For now, only handle the case where all keywords are required.
2277+
if callee_kw_count != callee_kw_required {
2278+
return Err(SendWithoutBlockDirectOptionalKeywords);
2279+
}
2280+
if caller_kw_count != callee_kw_count {
2281+
return Err(SendWithoutBlockDirectKeywordCountMismatch);
2282+
}
2283+
2284+
// The keyword arguments are the last arguments in the args vector.
2285+
let kw_args_start = args.len() - caller_kw_count;
2286+
2287+
// Build a mapping from caller keywords to their positions.
2288+
let mut caller_kw_order: Vec<ID> = Vec::with_capacity(caller_kw_count);
2289+
for i in 0..caller_kw_count {
2290+
let sym = unsafe { get_cikw_keywords_idx(kwarg, i as i32) };
2291+
let id = unsafe { rb_sym2id(sym) };
2292+
caller_kw_order.push(id);
2293+
}
2294+
2295+
// Reorder keyword arguments to match callee expectation.
2296+
let mut reordered_kw_args: Vec<InsnId> = Vec::with_capacity(callee_kw_count);
2297+
for i in 0..callee_kw_count {
2298+
let expected_id = unsafe { *callee_kw_table.add(i) };
2299+
2300+
// Find where this keyword is in the caller's order
2301+
let mut found = false;
2302+
for (j, &caller_id) in caller_kw_order.iter().enumerate() {
2303+
if caller_id == expected_id {
2304+
reordered_kw_args.push(args[kw_args_start + j]);
2305+
found = true;
2306+
break;
2307+
}
2308+
}
2309+
2310+
if !found {
2311+
// Required keyword not provided by caller which will raise an ArgumentError.
2312+
return Err(SendWithoutBlockDirectMissingKeyword);
2313+
}
2314+
}
2315+
2316+
// Replace the keyword arguments with the reordered ones.
2317+
let mut processed_args = args[..kw_args_start].to_vec();
2318+
processed_args.extend(reordered_kw_args);
2319+
Ok(processed_args)
2320+
}
2321+
22332322
/// Resolve the receiver type for method dispatch optimization.
22342323
///
22352324
/// Takes the receiver's Type, receiver HIR instruction, and ISEQ instruction index.
@@ -2449,7 +2538,29 @@ impl Function {
24492538
if let Some(profiled_type) = profiled_type {
24502539
recv = self.push_insn(block, Insn::GuardType { val: recv, guard_type: Type::from_profiled_type(profiled_type), state });
24512540
}
2452-
let send_direct = self.push_insn(block, Insn::SendWithoutBlockDirect { recv, cd, cme, iseq, args, state });
2541+
2542+
// Check if caller is passing keywords but callee doesn't expect them.
2543+
let kwarg = unsafe { rb_vm_ci_kwarg(ci) };
2544+
if !kwarg.is_null() && !unsafe { rb_get_iseq_flags_has_kw(iseq) } {
2545+
// Caller has keywords but callee doesn't; Need to convert to hash.
2546+
self.set_dynamic_send_reason(insn_id, UnexpectedKeywordArgs);
2547+
self.push_insn_id(block, insn_id); continue;
2548+
}
2549+
2550+
// Handle keyword argument reordering if present.
2551+
let processed_args = if !kwarg.is_null() {
2552+
match self.reorder_keyword_arguments(&args, kwarg, iseq) {
2553+
Ok(reordered) => reordered,
2554+
Err(reason) => {
2555+
self.set_dynamic_send_reason(insn_id, reason);
2556+
self.push_insn_id(block, insn_id); continue;
2557+
}
2558+
}
2559+
} else {
2560+
args.clone()
2561+
};
2562+
2563+
let send_direct = self.push_insn(block, Insn::SendWithoutBlockDirect { recv, cd, cme, iseq, args: processed_args, state });
24532564
self.make_equal_to(insn_id, send_direct);
24542565
} else if def_type == VM_METHOD_TYPE_BMETHOD {
24552566
let procv = unsafe { rb_get_def_bmethod_proc((*cme).def) };
@@ -2484,7 +2595,29 @@ impl Function {
24842595
if let Some(profiled_type) = profiled_type {
24852596
recv = self.push_insn(block, Insn::GuardType { val: recv, guard_type: Type::from_profiled_type(profiled_type), state });
24862597
}
2487-
let send_direct = self.push_insn(block, Insn::SendWithoutBlockDirect { recv, cd, cme, iseq, args, state });
2598+
2599+
// Check if caller is passing keywords but callee doesn't expect them.
2600+
let kwarg = unsafe { rb_vm_ci_kwarg(ci) };
2601+
if !kwarg.is_null() && !unsafe { rb_get_iseq_flags_has_kw(iseq) } {
2602+
// Caller has keywords but callee doesn't; Need to convert to hash.
2603+
self.set_dynamic_send_reason(insn_id, UnexpectedKeywordArgs);
2604+
self.push_insn_id(block, insn_id); continue;
2605+
}
2606+
2607+
// Handle keyword argument reordering if present.
2608+
let processed_args = if !kwarg.is_null() {
2609+
match self.reorder_keyword_arguments(&args, kwarg, iseq) {
2610+
Ok(reordered) => reordered,
2611+
Err(reason) => {
2612+
self.set_dynamic_send_reason(insn_id, reason);
2613+
self.push_insn_id(block, insn_id); continue;
2614+
}
2615+
}
2616+
} else {
2617+
args.clone()
2618+
};
2619+
2620+
let send_direct = self.push_insn(block, Insn::SendWithoutBlockDirect { recv, cd, cme, iseq, args: processed_args, state });
24882621
self.make_equal_to(insn_id, send_direct);
24892622
} else if def_type == VM_METHOD_TYPE_IVAR && args.is_empty() {
24902623
self.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass, method: mid, cme }, state });
@@ -4851,7 +4984,6 @@ fn unhandled_call_type(flags: u32) -> Result<(), CallType> {
48514984

48524985
/// If a given call uses overly complex arguments, then we won't specialize.
48534986
fn unspecializable_call_type(flags: u32) -> bool {
4854-
((flags & VM_CALL_KWARG) != 0) ||
48554987
((flags & VM_CALL_ARGS_SPLAT) != 0) ||
48564988
((flags & VM_CALL_ARGS_BLOCKARG) != 0)
48574989
}

zjit/src/hir/opt_tests.rs

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2760,9 +2760,9 @@ mod hir_opt_tests {
27602760
}
27612761

27622762
#[test]
2763-
fn dont_specialize_call_to_iseq_with_kw() {
2763+
fn specialize_call_to_iseq_with_required_kw() {
27642764
eval("
2765-
def foo(a:) = 1
2765+
def foo(a:) = a * 2
27662766
def test = foo(a: 1)
27672767
test
27682768
test
@@ -2778,7 +2778,35 @@ mod hir_opt_tests {
27782778
Jump bb2(v4)
27792779
bb2(v6:BasicObject):
27802780
v11:Fixnum[1] = Const Value(1)
2781-
IncrCounter complex_arg_pass_caller_kwarg
2781+
PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010)
2782+
PatchPoint NoSingletonClass(Object@0x1000)
2783+
v20:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)]
2784+
v21:BasicObject = SendWithoutBlockDirect v20, :foo (0x1038), v11
2785+
CheckInterrupts
2786+
Return v21
2787+
");
2788+
}
2789+
2790+
#[test]
2791+
fn test_send_call_to_iseq_with_optional_kw() {
2792+
eval("
2793+
def foo(a: 1) = a
2794+
def test = foo(a: 2)
2795+
test
2796+
test
2797+
");
2798+
assert_snapshot!(hir_string("test"), @r"
2799+
fn test@<compiled>:3:
2800+
bb0():
2801+
EntryPoint interpreter
2802+
v1:BasicObject = LoadSelf
2803+
Jump bb2(v1)
2804+
bb1(v4:BasicObject):
2805+
EntryPoint JIT(0)
2806+
Jump bb2(v4)
2807+
bb2(v6:BasicObject):
2808+
v11:Fixnum[2] = Const Value(2)
2809+
IncrCounter complex_arg_pass_param_kw_opt
27822810
v13:BasicObject = SendWithoutBlock v6, :foo, v11
27832811
CheckInterrupts
27842812
Return v13
@@ -2804,7 +2832,7 @@ mod hir_opt_tests {
28042832
Jump bb2(v4)
28052833
bb2(v6:BasicObject):
28062834
v11:Fixnum[1] = Const Value(1)
2807-
IncrCounter complex_arg_pass_caller_kwarg
2835+
IncrCounter complex_arg_pass_param_kwrest
28082836
v13:BasicObject = SendWithoutBlock v6, :foo, v11
28092837
CheckInterrupts
28102838
Return v13
@@ -2829,7 +2857,7 @@ mod hir_opt_tests {
28292857
EntryPoint JIT(0)
28302858
Jump bb2(v4)
28312859
bb2(v6:BasicObject):
2832-
IncrCounter complex_arg_pass_param_kw
2860+
IncrCounter complex_arg_pass_param_kw_opt
28332861
v11:BasicObject = SendWithoutBlock v6, :foo
28342862
CheckInterrupts
28352863
Return v11
@@ -3115,7 +3143,7 @@ mod hir_opt_tests {
31153143
v13:NilClass = Const Value(nil)
31163144
PatchPoint MethodRedefined(Hash@0x1008, new@0x1009, cme:0x1010)
31173145
v46:HashExact = ObjectAllocClass Hash:VALUE(0x1008)
3118-
IncrCounter complex_arg_pass_param_kw
3146+
IncrCounter complex_arg_pass_param_kw_opt
31193147
IncrCounter complex_arg_pass_param_block
31203148
v20:BasicObject = SendWithoutBlock v46, :initialize
31213149
CheckInterrupts
@@ -8453,7 +8481,7 @@ mod hir_opt_tests {
84538481
bb2(v6:BasicObject):
84548482
v11:Fixnum[1] = Const Value(1)
84558483
IncrCounter complex_arg_pass_param_rest
8456-
IncrCounter complex_arg_pass_param_kw
8484+
IncrCounter complex_arg_pass_param_kw_opt
84578485
IncrCounter complex_arg_pass_param_kwrest
84588486
IncrCounter complex_arg_pass_param_block
84598487
v13:BasicObject = SendWithoutBlock v6, :fancy, v11

zjit/src/stats.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,10 @@ make_counters! {
221221
send_fallback_too_many_args_for_lir,
222222
send_fallback_send_without_block_bop_redefined,
223223
send_fallback_send_without_block_operands_not_fixnum,
224+
send_fallback_send_without_block_direct_keyword_mismatch,
225+
send_fallback_send_without_block_direct_optional_keywords,
226+
send_fallback_send_without_block_direct_keyword_count_mismatch,
227+
send_fallback_send_without_block_direct_missing_keyword,
224228
send_fallback_send_polymorphic,
225229
send_fallback_send_megamorphic,
226230
send_fallback_send_no_profiles,
@@ -230,6 +234,8 @@ make_counters! {
230234
// The call has at least one feature on the caller or callee side
231235
// that the optimizer does not support.
232236
send_fallback_one_or_more_complex_arg_pass,
237+
// Caller has keyword arguments but callee doesn't expect them.
238+
send_fallback_unexpected_keyword_args,
233239
send_fallback_bmethod_non_iseq_proc,
234240
send_fallback_obj_to_string_not_string,
235241
send_fallback_send_cfunc_variadic,
@@ -344,7 +350,7 @@ make_counters! {
344350
// Unsupported parameter features
345351
complex_arg_pass_param_rest,
346352
complex_arg_pass_param_post,
347-
complex_arg_pass_param_kw,
353+
complex_arg_pass_param_kw_opt,
348354
complex_arg_pass_param_kwrest,
349355
complex_arg_pass_param_block,
350356
complex_arg_pass_param_forwardable,
@@ -542,12 +548,17 @@ pub fn send_fallback_counter(reason: crate::hir::SendFallbackReason) -> Counter
542548
TooManyArgsForLir => send_fallback_too_many_args_for_lir,
543549
SendWithoutBlockBopRedefined => send_fallback_send_without_block_bop_redefined,
544550
SendWithoutBlockOperandsNotFixnum => send_fallback_send_without_block_operands_not_fixnum,
551+
SendWithoutBlockDirectKeywordMismatch => send_fallback_send_without_block_direct_keyword_mismatch,
552+
SendWithoutBlockDirectOptionalKeywords => send_fallback_send_without_block_direct_optional_keywords,
553+
SendWithoutBlockDirectKeywordCountMismatch=> send_fallback_send_without_block_direct_keyword_count_mismatch,
554+
SendWithoutBlockDirectMissingKeyword => send_fallback_send_without_block_direct_missing_keyword,
545555
SendPolymorphic => send_fallback_send_polymorphic,
546556
SendMegamorphic => send_fallback_send_megamorphic,
547557
SendNoProfiles => send_fallback_send_no_profiles,
548558
SendCfuncVariadic => send_fallback_send_cfunc_variadic,
549559
SendCfuncArrayVariadic => send_fallback_send_cfunc_array_variadic,
550560
ComplexArgPass => send_fallback_one_or_more_complex_arg_pass,
561+
UnexpectedKeywordArgs => send_fallback_unexpected_keyword_args,
551562
ArgcParamMismatch => send_fallback_argc_param_mismatch,
552563
BmethodNonIseqProc => send_fallback_bmethod_non_iseq_proc,
553564
SendNotOptimizedMethodType(_) => send_fallback_send_not_optimized_method_type,

0 commit comments

Comments
 (0)