From f2d8786bcdaa85b7ac80cb714ae38923531ad0dd Mon Sep 17 00:00:00 2001 From: rwestrel Date: Thu, 18 Jun 2026 10:22:11 +0200 Subject: [PATCH] Backport 5803dd3e80f4fa6d4795960f88425d4b570ac050 --- src/hotspot/share/opto/callnode.hpp | 15 +- src/hotspot/share/opto/cfgnode.hpp | 1 + src/hotspot/share/opto/ifnode.cpp | 55 ++++ src/hotspot/share/opto/split_if.cpp | 1 + .../rangechecks/TestFoldedIfsWrongReexec.java | 299 ++++++++++++++++++ 5 files changed, 369 insertions(+), 2 deletions(-) create mode 100644 test/hotspot/jtreg/compiler/rangechecks/TestFoldedIfsWrongReexec.java diff --git a/src/hotspot/share/opto/callnode.hpp b/src/hotspot/share/opto/callnode.hpp index cde237e6e31..a44c31cf6b3 100644 --- a/src/hotspot/share/opto/callnode.hpp +++ b/src/hotspot/share/opto/callnode.hpp @@ -728,11 +728,14 @@ class CallJavaNode : public CallNode { // calls and optimized virtual calls, plus calls to wrappers for run-time // routines); generates static stub. class CallStaticJavaNode : public CallJavaNode { + // If this is an uncommon trap guarded by some condition, is it safe to change the condition to a narrower condition? + // See comment in PhaseIdealLoop::do_split_if() + bool _safe_for_fold_compare; virtual bool cmp( const Node &n ) const; virtual uint size_of() const; // Size is bigger public: CallStaticJavaNode(Compile* C, const TypeFunc* tf, address addr, ciMethod* method) - : CallJavaNode(tf, addr, method) { + : CallJavaNode(tf, addr, method), _safe_for_fold_compare(true) { init_class_id(Class_CallStaticJava); if (C->eliminate_boxing() && (method != nullptr) && method->is_boxing_method()) { init_flags(Flag_is_macro); @@ -740,7 +743,7 @@ class CallStaticJavaNode : public CallJavaNode { } } CallStaticJavaNode(const TypeFunc* tf, address addr, const char* name, const TypePtr* adr_type) - : CallJavaNode(tf, addr, nullptr) { + : CallJavaNode(tf, addr, nullptr), _safe_for_fold_compare(true) { init_class_id(Class_CallStaticJava); // This node calls a runtime stub, which often has narrow memory effects. _adr_type = adr_type; @@ -763,6 +766,14 @@ class CallStaticJavaNode : public CallJavaNode { virtual int Opcode() const; virtual Node* Ideal(PhaseGVN* phase, bool can_reshape); + void clear_safe_for_fold_compare() { + _safe_for_fold_compare = false; + } + + bool safe_for_fold_compare() const { + return _safe_for_fold_compare; + } + #ifndef PRODUCT virtual void dump_spec(outputStream *st) const; virtual void dump_compact_spec(outputStream *st) const; diff --git a/src/hotspot/share/opto/cfgnode.hpp b/src/hotspot/share/opto/cfgnode.hpp index 2fc09b94288..dd249f2e14e 100644 --- a/src/hotspot/share/opto/cfgnode.hpp +++ b/src/hotspot/share/opto/cfgnode.hpp @@ -438,6 +438,7 @@ class IfNode : public MultiBranchNode { #ifndef PRODUCT virtual void dump_spec(outputStream *st) const; #endif + void mark_projections_unsafe_for_fold_compare() const; }; class RangeCheckNode : public IfNode { diff --git a/src/hotspot/share/opto/ifnode.cpp b/src/hotspot/share/opto/ifnode.cpp index f63f4ae8002..021c7264352 100644 --- a/src/hotspot/share/opto/ifnode.cpp +++ b/src/hotspot/share/opto/ifnode.cpp @@ -843,6 +843,10 @@ bool IfNode::has_only_uncommon_traps(ProjNode* proj, ProjNode*& success, ProjNod return false; } + if (!dom_unc->safe_for_fold_compare()) { + return false; + } + // See merge_uncommon_traps: the reason of the uncommon trap // will be changed and the state of the dominating If will be // used. Checked that we didn't apply this transformation in a @@ -1601,6 +1605,57 @@ Node* IfNode::search_identical(int dist) { return prev_dom; } +void IfNode::mark_projections_unsafe_for_fold_compare() const { + // With the following code pattern + // + // if (some_condition) { + // v = 0; + // } else { + // v = 1; + // } // v is Phi(0, 1) + // if (v == 0) { + // uncommon_trap(); // reexecutes the "if (v == 0) {" above, captures v as stack argument to ifeq bytecode + // } + // if (some_other_condition) { + // uncommon_trap(); // reexecutes the "if (some_other_condition) {" + // } + // + // if the second if is split thru Phi, the result is: + // + // if (some_condition) { + // uncommon_trap(); // reexecutes the "if (v == 0) {" that was removed above, captures v = 0 as stack argument to ifeq bytecode + // } + // if (some_other_condition) { + // uncommon_trap(); // reexecutes the "if (some_other_condition) {" + // } + // + // some_condition and some_other_condition could be folded into + // a single new condition that is narrower than some_condition + // (done by IfNode::fold_compares(), for instance): + // + // if (combined_narrower_condition) { + // uncommon_trap(); // reexecutes the "if (v == 0) {" that was removed, captures v = 0 as stack argument to ifeq bytecode + // } + // + // Then combined_narrower_condition is true for some input value for + // which some_condition is false. When such an input value is used + // at runtime, the trap is taken which causes "if (v == 0) {" to be + // reexecuted with v = 0 even though some_condition is wrong, causing + // the wrong branch to be executed. + // + // Mark the uncommon trap nodes to prevent such a transformation + // from happening. + IfProjNode* true_projection = proj_out(1)->as_IfProj(); + IfProjNode* false_projection = proj_out(0)->as_IfProj(); + CallStaticJavaNode* unc = true_projection->is_uncommon_trap_proj(Deoptimization::Reason_none); + if (unc != nullptr) { + unc->clear_safe_for_fold_compare(); + } + unc = false_projection->is_uncommon_trap_proj(Deoptimization::Reason_none); + if (unc != nullptr) { + unc->clear_safe_for_fold_compare(); + } +} static int subsuming_bool_test_encode(Node*); diff --git a/src/hotspot/share/opto/split_if.cpp b/src/hotspot/share/opto/split_if.cpp index e90dfa1cc8e..034ac4ac574 100644 --- a/src/hotspot/share/opto/split_if.cpp +++ b/src/hotspot/share/opto/split_if.cpp @@ -598,6 +598,7 @@ void PhaseIdealLoop::do_split_if(Node* iff, RegionNode** new_false_region, Regio tty->print_cr("SplitIf"); } + iff->as_If()->mark_projections_unsafe_for_fold_compare(); C->set_major_progress(); RegionNode *region = iff->in(0)->as_Region(); Node *region_dom = idom(region); diff --git a/test/hotspot/jtreg/compiler/rangechecks/TestFoldedIfsWrongReexec.java b/test/hotspot/jtreg/compiler/rangechecks/TestFoldedIfsWrongReexec.java new file mode 100644 index 00000000000..9e77ceca0dd --- /dev/null +++ b/test/hotspot/jtreg/compiler/rangechecks/TestFoldedIfsWrongReexec.java @@ -0,0 +1,299 @@ +/* + * Copyright (c) 2026 IBM Corporation. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/** + * @test + * @bug 8376400 + * @summary C2: folding ifs may cause incorrect execution when trap is taken + * + * @run main/othervm -XX:-TieredCompilation -XX:-UseOnStackReplacement -XX:-BackgroundCompilation + * -XX:+UnlockDiagnosticVMOptions -XX:-OptimizeUnstableIf ${test.main.class} + * @run main ${test.main.class} + * + */ + +package compiler.rangechecks; + +public class TestFoldedIfsWrongReexec { + private static int taken1; + private static int taken2; + private static int taken3; + private static int taken4; + private static int taken5; + private static int taken6; + private static int taken7; + private static int MIN_VALUE = Integer.MIN_VALUE; + + public static void main(String[] args) { + for (int i = 0; i < 20_000; i++) { + test1(12); + if (taken1 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test1Helper1(16, 0); + test2(12); + if (taken2 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test2Helper1(16, 0); + test3(12); + if (taken3 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test3Helper1(16, 0); + test4(12, 1, 2); + if (taken4 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test4Helper1(16, 0, 1, 2); + test5(12); + if (taken5 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test5Helper1(16, 0); + test6(12, 1, 2); + if (taken6 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test6Helper1(16, 0, 1, 2); + test7(12); + if (taken7 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test7Helper1(16, 0); + test7Helper2(o1); + test7Helper2(a); + test7Helper2(b); + } + test1(0); + if (taken1 == 0) { + throw new RuntimeException("branch should have been taken"); + } + test2(0); + if (taken2 == 0) { + throw new RuntimeException("branch should have been taken"); + } + test3(0); + if (taken3 == 0) { + throw new RuntimeException("branch should have been taken"); + } + test4(0, 1, 2); + if (taken4 == 0) { + throw new RuntimeException("branch should have been taken"); + } + test5(0); + if (taken5 == 0) { + throw new RuntimeException("branch should have been taken"); + } + test6(0, 1, 2); + if (taken6 == 0) { + throw new RuntimeException("branch should have been taken"); + } + test7(0); + if (taken7 == 0) { + throw new RuntimeException("branch should have been taken"); + } + } + + private static void test1(int i) { + if (test1Helper1(i, 16) == 0) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken1++; + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static int test1Helper1(int i, int j) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return 0; + } + return 1; + } + + private static void test2(int i) { + if (test2Helper1(i, 16) == 42) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken2++; + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static int test2Helper1(int i, int j) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return 42; + } + return 0x42; + } + + private static void test3(int i) { + if (test3Helper1(i, 16) == 42L) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken3++; + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static long test3Helper1(int i, int j) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return 42L; + } + return 0x42L; + } + + private static void test4(int i, int x, int y) { + if (x == y) { + throw new RuntimeException("never taken"); + } + if (test4Helper1(i, 16, x, y) == y) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken4++; + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static int test4Helper1(int i, int j, int x, int y) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return y; + } + return x; + } + + static final Object o1 = new Object(); + static final Object o2 = new Object(); + + private static void test5(int i) { + if (test5Helper1(i, 16) == o1) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken5++; + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static Object test5Helper1(int i, int j) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return o1; + } + return o2; + } + + private static void test6(int i, int x, int y) { + if (x < y) { + if (test6Helper1(i, 16, x, y) < y) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken6++; + } + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static int test6Helper1(int i, int j, int x, int y) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return x; + } + return y; + } + + static final Object a = new A(); + static final Object b = new B(); + + private static void test7(int i) { + if (test7Helper2(test7Helper1(i, 16))) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken7++; + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static Object test7Helper1(int i, int j) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return a; + } + return b; + } + + private static boolean test7Helper2(Object o) { + return o instanceof A; + } + + private static class A { + } + + private static class B { + } +}