Skip to content

Commit f4fe2ff

Browse files
committed
ZJIT: Optimize send with a nil block to SendWithoutBlockDirect
A call like `foo(&block)` sets VM_CALL_ARGS_BLOCKARG, which previously forced a dynamic dispatch even when the block argument is nil and the callee is an ISEQ method that could otherwise be reduced to a SendDirect. Passing a nil block argument is equivalent to calling without a block, so when we can prove the block argument is nil we strip it and treat the call as blockless. This happens when the block argument is either: * statically known to be nil, or * profiled as monomorphically nil, in which case we emit a GuardBitEquals against nil (side exit reason BlockArgNotNil) before stripping it so a non-nil block at runtime takes a side exit. When the block argument can't be proven nil, we fall back to a dynamic Send and record the SendBlockArgNotNil fallback reason. The stripped-block frame state is used only for the callee frame for the direct send. The pre-call guards (the nil-block GuardBitEquals, the receiver GuardType, and the method-redefinition PatchPoint) keep using the original frame state, which still has the block argument on the stack so that it is there if we side exit.
1 parent a5f6050 commit f4fe2ff

4 files changed

Lines changed: 145 additions & 29 deletions

File tree

test/ruby/test_zjit.rb

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,56 @@ def test_float_arithmetic
452452
assert_compiles '-2', 'def test = (-2.9).to_i; test'
453453
end
454454

455+
def test_send_forwarded_block_arg_nil_then_non_nil
456+
# Regression test: when a forwarded &block arg is profiled as nil, the nil
457+
# block optimization must update the frame state to match the stripped args.
458+
# Otherwise the saved SP is off by one, causing a stack consistency error
459+
# when the guard side-exits for a non-nil block.
460+
assert_runs ':ok', <<~RUBY, call_threshold: 2
461+
def inner(callable = nil, &block)
462+
callable || block
463+
end
464+
465+
def outer(&block)
466+
inner(&block)
467+
end
468+
469+
100.times { outer }
470+
result = outer { |x| x }
471+
result.is_a?(Proc) ? :ok : :fail
472+
RUBY
473+
end
474+
475+
def test_send_forwarded_nil_block_arg_with_polymorphic_receiver
476+
# Regression test: the nil block optimization strips the block arg from the
477+
# frame state used to set up the callee frame, but the pre-call guards
478+
# (receiver GuardType, method-redefinition PatchPoint) must keep using the
479+
# original frame state that still has the block arg on the stack. Otherwise a
480+
# guard side-exit re-executes the send with a stack that is missing the block
481+
# arg slot, corrupting the pushed frame's EP (VM_ENV_FLAGS assertion failure).
482+
# A polymorphic receiver forces the receiver GuardType to side-exit.
483+
assert_runs ':ok', <<~RUBY, call_threshold: 2
484+
class Base
485+
def self.inner(model, name, &block)
486+
block ? block.call : model
487+
end
488+
def self.outer(model, name, &block)
489+
inner(model, name, &block)
490+
end
491+
end
492+
class A < Base; end
493+
class B < Base; end
494+
class C < Base; end
495+
class D < Base; end
496+
497+
1000.times do |i|
498+
klass = [A, B, C, D][i % 4]
499+
klass.outer(i, :n)
500+
end
501+
:ok
502+
RUBY
503+
end
504+
455505
private
456506

457507
# Assert that every method call in `test_script` can be compiled by ZJIT

zjit/src/hir.rs

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,7 @@ pub enum SideExitReason {
531531
UnhandledYARVInsn(u32),
532532
UnhandledCallType(CallType),
533533
UnhandledBlockArg,
534+
BlockArgNotNil,
534535
TooManyKeywordParameters,
535536
TooManyArgsForLir,
536537
FixnumAddOverflow,
@@ -694,6 +695,8 @@ pub enum SendFallbackReason {
694695
SendCfuncArrayVariadic,
695696
SendNotOptimizedMethodType(MethodType),
696697
SendNotOptimizedNeedPermission,
698+
/// The block argument is not nil, so we can't optimize to SendWithoutBlockDirect
699+
SendBlockArgNotNil,
697700
CCallWithFrameTooManyArgs,
698701
ObjToStringNotString,
699702
TooManyArgsForLir,
@@ -768,6 +771,7 @@ impl Display for SendFallbackReason {
768771
SendCfuncVariadic => write!(f, "Send: C function is variadic"),
769772
SendCfuncArrayVariadic => write!(f, "Send: C function expects array variadic"),
770773
SendNotOptimizedMethodType(method_type) => write!(f, "Send: unsupported method type {:?}", method_type),
774+
SendBlockArgNotNil => write!(f, "Send: block argument is not nil"),
771775
CCallWithFrameTooManyArgs => write!(f, "CCallWithFrame: too many arguments"),
772776
ObjToStringNotString => write!(f, "ObjToString: result is not a string"),
773777
TooManyArgsForLir => write!(f, "Too many arguments for LIR"),
@@ -2488,7 +2492,7 @@ fn can_direct_send(function: &mut Function, block: BlockId, iseq: *const rb_iseq
24882492
let params = unsafe { iseq.params() };
24892493

24902494
let callee_has_block_param = 0 != params.flags.has_block();
2491-
let caller_passes_block_arg = (unsafe { rb_vm_ci_flag(ci) } & VM_CALL_ARGS_BLOCKARG) != 0;
2495+
let caller_passes_block_arg = has_block && (unsafe { rb_vm_ci_flag(ci) } & VM_CALL_ARGS_BLOCKARG) != 0;
24922496

24932497
use Counter::*;
24942498
if 0 != params.flags.has_rest() { count_failure(complex_arg_pass_param_rest) }
@@ -3803,7 +3807,7 @@ impl Function {
38033807
Insn::Send { recv, block: None, args, state, cd, .. } if ruby_call_method_id(cd) == ID!(minusat) && args.is_empty() =>
38043808
self.try_rewrite_uminus(block, insn_id, recv, state),
38053809
ref send@Insn::Send { mut recv, cd, state, block: send_block, ref args, .. } => {
3806-
let has_block = send_block.is_some();
3810+
let mut has_block = send_block.is_some();
38073811
let (klass, profiled_type) = match self.resolve_receiver_type(recv, self.type_of(recv), state) {
38083812
ReceiverTypeResolution::StaticallyKnown { class } => (class, None),
38093813
ReceiverTypeResolution::Monomorphic { profiled_type }
@@ -3864,9 +3868,57 @@ impl Function {
38643868
def_type = unsafe { get_cme_def_type(cme) };
38653869
}
38663870

3871+
// Check if we can optimize `foo(&block)` where block is nil to a send without block.
3872+
// `state` keeps referring to the pre-send frame state (block arg still on the
3873+
// stack). Any guard that side-exits before the call re-executes the `send` in
3874+
// the interpreter, so it must reconstruct the stack with the block arg present.
3875+
// Only the direct-send frame setup uses `send_frame_state`, which has the nil
3876+
// block arg stripped from the stack.
3877+
let mut send_block = send_block;
3878+
let mut send_frame_state = state;
3879+
let mut args = args.to_vec();
3880+
let mut stripped_nil_block = false;
3881+
if send_block == Some(BlockHandler::BlockArg) && def_type == VM_METHOD_TYPE_ISEQ {
3882+
// The block arg is the last element in args
3883+
if let Some(&block_arg) = args.last() {
3884+
let statically_nil = self.is_a(block_arg, types::NilClass);
3885+
let profiled_nil = self.profiled_type_of_at(block_arg, state)
3886+
.map_or(false, |pt| pt.is_nil());
3887+
if statically_nil || profiled_nil {
3888+
if !statically_nil {
3889+
// Guard needed when relying on profiled type. Uses the original
3890+
// `state` so a side exit re-executes the send with the block
3891+
// arg still on the VM stack.
3892+
self.push_insn(block, Insn::GuardBitEquals {
3893+
val: block_arg,
3894+
expected: Const::Value(Qnil),
3895+
reason: SideExitReason::BlockArgNotNil,
3896+
state,
3897+
recompile: None,
3898+
});
3899+
}
3900+
// Strip nil block arg and treat as no block
3901+
args = args[..args.len() - 1].to_vec();
3902+
send_block = None;
3903+
has_block = false;
3904+
stripped_nil_block = true;
3905+
// Frame state for the direct send only: the block arg is removed
3906+
// from the stack so the callee frame is laid out correctly.
3907+
let new_state = self.frame_state(state).with_replaced_args(&args, args.len() + 1);
3908+
send_frame_state = self.push_insn(block, Insn::Snapshot { state: Box::new(new_state) });
3909+
} else {
3910+
// Can't prove block arg is nil
3911+
self.set_dynamic_send_reason(insn_id, SendBlockArgNotNil);
3912+
self.push_insn_id(block, insn_id); continue;
3913+
}
3914+
}
3915+
}
3916+
38673917
// If the call site info indicates that the `Function` has overly complex arguments, then do not optimize into a `SendDirect`.
38683918
// Optimized methods(`VM_METHOD_TYPE_OPTIMIZED`) and C methods handle their own argument constraints (e.g., kw_splat for Proc call).
3869-
if def_type != VM_METHOD_TYPE_OPTIMIZED && def_type != VM_METHOD_TYPE_CFUNC && unspecializable_call_type(flags) {
3919+
// Mask out ARGS_BLOCKARG only if we've already handled the nil block arg case above.
3920+
let flags_for_check = if stripped_nil_block { flags & !VM_CALL_ARGS_BLOCKARG } else { flags };
3921+
if def_type != VM_METHOD_TYPE_OPTIMIZED && def_type != VM_METHOD_TYPE_CFUNC && unspecializable_call_type(flags_for_check) {
38703922
self.count_complex_call_features(block, flags);
38713923
self.set_dynamic_send_reason(insn_id, ComplexArgPass);
38723924
self.push_insn_id(block, insn_id); continue;
@@ -3882,7 +3934,7 @@ impl Function {
38823934
}
38833935

38843936
// Check if the args are compatible before emitting any assmptions
3885-
let Ok((send_state, processed_args, kw_bits)) = self.prepare_direct_send_args(block, &args, ci, iseq, state)
3937+
let Ok((send_state, processed_args, kw_bits)) = self.prepare_direct_send_args(block, &args, ci, iseq, send_frame_state)
38863938
.inspect_err(|&reason| self.set_dynamic_send_reason(insn_id, reason)) else {
38873939
self.push_insn_id(block, insn_id); continue;
38883940
};
@@ -3922,7 +3974,7 @@ impl Function {
39223974
}
39233975

39243976
// Check if the args are compatible before emitting any assmptions
3925-
let Ok((send_state, processed_args, kw_bits)) = self.prepare_direct_send_args(block, &args, ci, iseq, state)
3977+
let Ok((send_state, processed_args, kw_bits)) = self.prepare_direct_send_args(block, &args, ci, iseq, send_frame_state)
39263978
.inspect_err(|&reason| self.set_dynamic_send_reason(insn_id, reason)) else {
39273979
self.push_insn_id(block, insn_id); continue;
39283980
};

zjit/src/hir/opt_tests.rs

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5207,7 +5207,7 @@ mod hir_opt_tests {
52075207
v26:TrueClass = GuardBitEquals v25, Value(true) recompile
52085208
Jump bb6(v24, v10)
52095209
bb6(v16:BasicObject, v17:BasicObject):
5210-
v29:BasicObject = Send v14, &block, :then, v16 # SendFallbackReason: Complex argument passing
5210+
v29:BasicObject = Send v14, &block, :then, v16 # SendFallbackReason: Send: block argument is not nil
52115211
CheckInterrupts
52125212
Return v29
52135213
");
@@ -5263,7 +5263,7 @@ mod hir_opt_tests {
52635263
v37:NilClass = Const Value(nil)
52645264
Jump bb8(v37, v13)
52655265
bb8(v27:BasicObject, v28:BasicObject):
5266-
v40:BasicObject = Send v25, &block, :then, v27 # SendFallbackReason: Complex argument passing
5266+
v40:BasicObject = Send v25, &block, :then, v27 # SendFallbackReason: Send: block argument is not nil
52675267
CheckInterrupts
52685268
Return v40
52695269
bb4(v45:BasicObject, v46:Falsy, v47:BasicObject):
@@ -5428,7 +5428,7 @@ mod hir_opt_tests {
54285428
v34:NilClass = Const Value(nil)
54295429
Jump bb6(v34, v10)
54305430
bb6(v16:BasicObject, v17:BasicObject):
5431-
v38:BasicObject = Send v14, &block, :then, v16 # SendFallbackReason: Complex argument passing
5431+
v38:BasicObject = Send v14, &block, :then, v16 # SendFallbackReason: Send: block argument is not nil
54325432
CheckInterrupts
54335433
Return v38
54345434
bb10():
@@ -5496,7 +5496,7 @@ mod hir_opt_tests {
54965496
v41:ObjectSubclass[BlockParamProxy] = Const Value(VALUE(0x1010))
54975497
Jump bb6(v41, v10)
54985498
bb6(v16:BasicObject, v17:BasicObject):
5499-
v45:BasicObject = Send v14, &block, :then, v16 # SendFallbackReason: Complex argument passing
5499+
v45:BasicObject = Send v14, &block, :then, v16 # SendFallbackReason: Send: block argument is not nil
55005500
CheckInterrupts
55015501
Return v45
55025502
bb13():
@@ -9519,7 +9519,7 @@ mod hir_opt_tests {
95199519
v26:ObjectSubclass[BlockParamProxy] = Const Value(VALUE(0x1008))
95209520
Jump bb6(v26, v10)
95219521
bb6(v16:BasicObject, v17:BasicObject):
9522-
v29:BasicObject = Send v14, &block, :map, v16 # SendFallbackReason: Complex argument passing
9522+
v29:BasicObject = Send v14, &block, :map, v16 # SendFallbackReason: Send: block argument is not nil
95239523
CheckInterrupts
95249524
Return v29
95259525
");
@@ -9559,9 +9559,12 @@ mod hir_opt_tests {
95599559
v26:NilClass = Const Value(nil)
95609560
Jump bb6(v26, v10)
95619561
bb6(v16:BasicObject, v17:BasicObject):
9562-
v29:BasicObject = Send v14, &block, :map, v16 # SendFallbackReason: Complex argument passing
9562+
v35:NilClass = GuardBitEquals v16, Value(nil)
9563+
PatchPoint NoSingletonClass(Array@0x1008)
9564+
PatchPoint MethodRedefined(Array@0x1008, map@0x1010, cme:0x1018)
9565+
v40:BasicObject = SendDirect v14, 0x1040, :map (0x1050)
95639566
CheckInterrupts
9564-
Return v29
9567+
Return v40
95659568
");
95669569
}
95679570

@@ -9600,7 +9603,7 @@ mod hir_opt_tests {
96009603
v21:ObjectSubclass[BlockParamProxy] = Const Value(VALUE(0x1008))
96019604
Jump bb6(v21)
96029605
bb6(v12:BasicObject):
9603-
v24:BasicObject = Send v10, &block, :map, v12 # SendFallbackReason: Complex argument passing
9606+
v24:BasicObject = Send v10, &block, :map, v12 # SendFallbackReason: Send: block argument is not nil
96049607
CheckInterrupts
96059608
Return v24
96069609
");
@@ -9768,7 +9771,7 @@ mod hir_opt_tests {
97689771
Jump bb3(v5, v6)
97699772
bb3(v8:BasicObject, v9:NilClass):
97709773
v13:StaticSymbol[:to_s] = Const Value(VALUE(0x1000))
9771-
v19:BasicObject = Send v8, &block, :foo, v13 # SendFallbackReason: Complex argument passing
9774+
v19:BasicObject = Send v8, &block, :foo, v13 # SendFallbackReason: Send: block argument is not nil
97729775
CheckInterrupts
97739776
Return v19
97749777
");
@@ -9799,9 +9802,11 @@ mod hir_opt_tests {
97999802
Jump bb3(v5, v6)
98009803
bb3(v8:BasicObject, v9:NilClass):
98019804
v13:NilClass = Const Value(nil)
9802-
v19:BasicObject = Send v8, &block, :foo, v13 # SendFallbackReason: Complex argument passing
9805+
PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010)
9806+
v27:ObjectSubclass[class_exact*:Object@VALUE(0x1000)] = GuardType v8, ObjectSubclass[class_exact*:Object@VALUE(0x1000)] recompile
9807+
v28:Fixnum[42] = Const Value(42)
98039808
CheckInterrupts
9804-
Return v19
9809+
Return v28
98059810
");
98069811
}
98079812

@@ -9830,19 +9835,24 @@ mod hir_opt_tests {
98309835
Jump bb3(v6, v7)
98319836
bb3(v9:BasicObject, v10:BasicObject):
98329837
v17:CPtr = GetEP 0
9833-
v18:CBool = IsBlockParamModified v17
9834-
IfTrue v18, bb4()
9835-
v23:CInt64 = LoadField v17, :_env_data_index_specval@0x1001
9836-
v24:CInt64[0] = GuardBitEquals v23, CInt64(0)
9837-
v25:NilClass = Const Value(nil)
9838-
Jump bb6(v25, v10)
9838+
v18:CUInt64 = LoadField v17, :VM_ENV_DATA_INDEX_FLAGS@0x1001
9839+
v19:CBool = IsBlockParamModified v18
9840+
CondBranch v19, bb4(), bb5()
98399841
bb4():
98409842
v21:BasicObject = LoadField v17, :block@0x1002
98419843
Jump bb6(v21, v21)
9844+
bb5():
9845+
v23:CInt64 = LoadField v17, :VM_ENV_DATA_INDEX_SPECVAL@0x1003
9846+
v24:CInt64[0] = GuardBitEquals v23, CInt64(0) recompile
9847+
v25:NilClass = Const Value(nil)
9848+
Jump bb6(v25, v10)
98429849
bb6(v15:BasicObject, v16:BasicObject):
9843-
v28:BasicObject = Send v9, &block, :foo, v15 # SendFallbackReason: Complex argument passing
9850+
v34:NilClass = GuardBitEquals v15, Value(nil)
9851+
PatchPoint MethodRedefined(Object@0x1008, foo@0x1010, cme:0x1018)
9852+
v37:ObjectSubclass[class_exact*:Object@VALUE(0x1008)] = GuardType v9, ObjectSubclass[class_exact*:Object@VALUE(0x1008)] recompile
9853+
v38:Fixnum[42] = Const Value(42)
98449854
CheckInterrupts
9845-
Return v28
9855+
Return v38
98469856
");
98479857
}
98489858

@@ -13203,7 +13213,7 @@ mod hir_opt_tests {
1320313213
Jump bb3(v4)
1320413214
bb3(v6:BasicObject):
1320513215
v11:StaticSymbol[:the_block] = Const Value(VALUE(0x1000))
13206-
v13:BasicObject = Send v6, &block, :callee, v11 # SendFallbackReason: Complex argument passing
13216+
v13:BasicObject = Send v6, &block, :callee, v11 # SendFallbackReason: Send: block argument is not nil
1320713217
CheckInterrupts
1320813218
Return v13
1320913219
");
@@ -13248,7 +13258,7 @@ mod hir_opt_tests {
1324813258
v26:ObjectSubclass[BlockParamProxy] = Const Value(VALUE(0x1008))
1324913259
Jump bb6(v26, v10)
1325013260
bb6(v16:BasicObject, v17:BasicObject):
13251-
v29:BasicObject = Send v14, &block, :map, v16 # SendFallbackReason: Complex argument passing
13261+
v29:BasicObject = Send v14, &block, :map, v16 # SendFallbackReason: Send: block argument is not nil
1325213262
CheckInterrupts
1325313263
Return v29
1325413264
");
@@ -17095,7 +17105,7 @@ mod hir_opt_tests {
1709517105
test(true, block)
1709617106
");
1709717107

17098-
assert_snapshot!(hir_string("test"), @r"
17108+
assert_snapshot!(hir_string("test"), @"
1709917109
fn test@<compiled>:7:
1710017110
bb1():
1710117111
EntryPoint interpreter
@@ -17118,7 +17128,7 @@ mod hir_opt_tests {
1711817128
bb5():
1711917129
v22:Truthy = RefineType v12, Truthy
1712017130
v26:Fixnum[42] = Const Value(42)
17121-
v29:BasicObject = Send v11, &block, :passthrough_recompile_blockarg, v26, v13 # SendFallbackReason: Complex argument passing
17131+
v29:BasicObject = Send v11, &block, :passthrough_recompile_blockarg, v26, v13 # SendFallbackReason: Send: block argument is not nil
1712217132
CheckInterrupts
1712317133
Return v29
1712417134
bb4(v34:BasicObject, v35:Falsy, v36:BasicObject):
@@ -19326,7 +19336,7 @@ mod hir_opt_tests {
1932619336
v47:ObjectSubclass[BlockParamProxy] = Const Value(VALUE(0x1050))
1932719337
Jump bb8(v47, v55)
1932819338
bb8(v37:BasicObject, v38:BasicObject):
19329-
v50:BasicObject = Send v27, &block, :inner, v10, v37 # SendFallbackReason: Complex argument passing
19339+
v50:BasicObject = Send v27, &block, :inner, v10, v37 # SendFallbackReason: Send: block argument is not nil
1933019340
CheckInterrupts
1933119341
PopInlineFrame
1933219342
PatchPoint NoEPEscape(test)

0 commit comments

Comments
 (0)