From 80ec552248470dda2d0d003be9315e9e39eb5276 Mon Sep 17 00:00:00 2001 From: Kangcheng Xu Date: Mon, 21 Oct 2024 14:57:31 +0000 Subject: [PATCH] 8328528: C2 should optimize long-typed parallel iv in an int counted loop Reviewed-by: roland, chagedorn, thartmann --- src/hotspot/share/opto/loopnode.cpp | 138 +++++-- src/hotspot/share/opto/loopnode.hpp | 2 + .../TestParallelIvInIntCountedLoop.java | 391 ++++++++++++++++++ 3 files changed, 493 insertions(+), 38 deletions(-) create mode 100644 test/hotspot/jtreg/compiler/loopopts/parallel_iv/TestParallelIvInIntCountedLoop.java diff --git a/src/hotspot/share/opto/loopnode.cpp b/src/hotspot/share/opto/loopnode.cpp index 90b15381f82..cadc6047397 100644 --- a/src/hotspot/share/opto/loopnode.cpp +++ b/src/hotspot/share/opto/loopnode.cpp @@ -3951,6 +3951,41 @@ bool PhaseIdealLoop::is_deleteable_safept(Node* sfpt) { //---------------------------replace_parallel_iv------------------------------- // Replace parallel induction variable (parallel to trip counter) +// This optimization looks for patterns similar to: +// +// int a = init2; +// for (int iv = init; iv < limit; iv += stride_con) { +// a += stride_con2; +// } +// +// and transforms it to: +// +// int iv2 = init2 +// int iv = init +// loop: +// if (iv >= limit) goto exit +// iv += stride_con +// iv2 = init2 + (iv - init) * (stride_con2 / stride_con) +// goto loop +// exit: +// ... +// +// Such a transformation introduces more optimization opportunities. In this +// particular example, the loop can be eliminated entirely given that +// `stride_con2 / stride_con` is exact (i.e., no remainder). Checks are in +// place to only perform this optimization if such a division is exact. This +// example will be transformed into its semantic equivalent: +// +// int iv2 = (iv * stride_con2 / stride_con) + (init2 - (init * stride_con2 / stride_con)) +// +// which corresponds to the structure of the transformed subgraph. 
+// +// However, if there is a mismatch between types of the loop and the parallel +// induction variable (e.g., a long-typed IV in an int-typed loop), type +// conversions are required: +// +// long iv2 = ((long) iv * stride_con2 / stride_con) + (init2 - ((long) init * stride_con2 / stride_con)) +// void PhaseIdealLoop::replace_parallel_iv(IdealLoopTree *loop) { assert(loop->_head->is_CountedLoop(), ""); CountedLoopNode *cl = loop->_head->as_CountedLoop(); @@ -3963,7 +3998,7 @@ void PhaseIdealLoop::replace_parallel_iv(IdealLoopTree *loop) { } Node *init = cl->init_trip(); Node *phi = cl->phi(); - int stride_con = cl->stride_con(); + jlong stride_con = cl->stride_con(); // Visit all children, looking for Phis for (DUIterator i = cl->outs(); cl->has_out(i); i++) { @@ -3980,7 +4015,7 @@ void PhaseIdealLoop::replace_parallel_iv(IdealLoopTree *loop) { incr2->req() != 3 || incr2->in(1)->uncast() != phi2 || incr2 == incr || - incr2->Opcode() != Op_AddI || + (incr2->Opcode() != Op_AddI && incr2->Opcode() != Op_AddL) || !incr2->in(2)->is_Con()) { continue; } @@ -3996,11 +4031,15 @@ void PhaseIdealLoop::replace_parallel_iv(IdealLoopTree *loop) { // the trip-counter, so we need to convert all these to trip-counter // expressions. Node* init2 = phi2->in(LoopNode::EntryControl); - int stride_con2 = incr2->in(2)->get_int(); + + // Determine the basic type of the stride constant (and the iv being incremented). + BasicType stride_con2_bt = incr2->Opcode() == Op_AddI ? T_INT : T_LONG; + jlong stride_con2 = incr2->in(2)->get_integer_as_long(stride_con2_bt); // The ratio of the two strides cannot be represented as an int - // if stride_con2 is min_int and stride_con is -1. - if (stride_con2 == min_jint && stride_con == -1) { + // if stride_con2 is min_jint (or min_jlong, respectively) and + // stride_con is -1. 
+ if (stride_con2 == min_signed_integer(stride_con2_bt) && stride_con == -1) { continue; } @@ -4011,44 +4050,67 @@ void PhaseIdealLoop::replace_parallel_iv(IdealLoopTree *loop) { // Instead we require 'stride_con2' to be a multiple of 'stride_con', // where +/-1 is the common case, but other integer multiples are // also easy to handle. - int ratio_con = stride_con2/stride_con; + jlong ratio_con = stride_con2 / stride_con; - if ((ratio_con * stride_con) == stride_con2) { // Check for exact -#ifndef PRODUCT - if (TraceLoopOpts) { - tty->print("Parallel IV: %d ", phi2->_idx); - loop->dump_head(); - } -#endif - // Convert to using the trip counter. The parallel induction - // variable differs from the trip counter by a loop-invariant - // amount, the difference between their respective initial values. - // It is scaled by the 'ratio_con'. - Node* ratio = _igvn.intcon(ratio_con); - set_ctrl(ratio, C->root()); - Node* ratio_init = new MulINode(init, ratio); - _igvn.register_new_node_with_optimizer(ratio_init, init); - set_early_ctrl(ratio_init, false); - Node* diff = new SubINode(init2, ratio_init); - _igvn.register_new_node_with_optimizer(diff, init2); - set_early_ctrl(diff, false); - Node* ratio_idx = new MulINode(phi, ratio); - _igvn.register_new_node_with_optimizer(ratio_idx, phi); - set_ctrl(ratio_idx, cl); - Node* add = new AddINode(ratio_idx, diff); - _igvn.register_new_node_with_optimizer(add); - set_ctrl(add, cl); - _igvn.replace_node( phi2, add ); - // Sometimes an induction variable is unused - if (add->outcnt() == 0) { - _igvn.remove_dead_node(add); - } - --i; // deleted this phi; rescan starting with next position - continue; + if ((ratio_con * stride_con) != stride_con2) { // Check for exact (no remainder) + continue; } + +#ifndef PRODUCT + if (TraceLoopOpts) { + tty->print("Parallel IV: %d ", phi2->_idx); + loop->dump_head(); + } +#endif + + // Convert to using the trip counter. 
The parallel induction + // variable differs from the trip counter by a loop-invariant + // amount, the difference between their respective initial values. + // It is scaled by the 'ratio_con'. + Node* ratio = _igvn.integercon(ratio_con, stride_con2_bt); + set_ctrl(ratio, C->root()); + + Node* init_converted = insert_convert_node_if_needed(stride_con2_bt, init); + Node* phi_converted = insert_convert_node_if_needed(stride_con2_bt, phi); + + Node* ratio_init = MulNode::make(init_converted, ratio, stride_con2_bt); + _igvn.register_new_node_with_optimizer(ratio_init, init_converted); + set_early_ctrl(ratio_init, false); + + Node* diff = SubNode::make(init2, ratio_init, stride_con2_bt); + _igvn.register_new_node_with_optimizer(diff, init2); + set_early_ctrl(diff, false); + + Node* ratio_idx = MulNode::make(phi_converted, ratio, stride_con2_bt); + _igvn.register_new_node_with_optimizer(ratio_idx, phi_converted); + set_ctrl(ratio_idx, cl); + + Node* add = AddNode::make(ratio_idx, diff, stride_con2_bt); + _igvn.register_new_node_with_optimizer(add); + set_ctrl(add, cl); + + _igvn.replace_node( phi2, add ); + // Sometimes an induction variable is unused + if (add->outcnt() == 0) { + _igvn.remove_dead_node(add); + } + --i; // deleted this phi; rescan starting with next position } } +Node* PhaseIdealLoop::insert_convert_node_if_needed(BasicType target, Node* input) { + BasicType source = _igvn.type(input)->basic_type(); + if (source == target) { + return input; + } + + Node* converted = ConvertNode::create_convert(source, target, input); + _igvn.register_new_node_with_optimizer(converted, input); + set_early_ctrl(converted, false); + + return converted; +} + void IdealLoopTree::remove_safepoints(PhaseIdealLoop* phase, bool keep_one) { Node* keep = nullptr; if (keep_one) { diff --git a/src/hotspot/share/opto/loopnode.hpp b/src/hotspot/share/opto/loopnode.hpp index ef27eb652f7..af32e2366e8 100644 --- a/src/hotspot/share/opto/loopnode.hpp +++ 
b/src/hotspot/share/opto/loopnode.hpp @@ -1134,6 +1134,8 @@ private: } #endif + Node* insert_convert_node_if_needed(BasicType target, Node* input); + public: Node* idom_no_update(Node* d) const { return idom_no_update(d->_idx); diff --git a/test/hotspot/jtreg/compiler/loopopts/parallel_iv/TestParallelIvInIntCountedLoop.java b/test/hotspot/jtreg/compiler/loopopts/parallel_iv/TestParallelIvInIntCountedLoop.java new file mode 100644 index 00000000000..95ba9e6e795 --- /dev/null +++ b/test/hotspot/jtreg/compiler/loopopts/parallel_iv/TestParallelIvInIntCountedLoop.java @@ -0,0 +1,391 @@ +/* + * Copyright (c) 2024 Red Hat and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ */ + +package compiler.loopopts.parallel_iv; + +import compiler.lib.ir_framework.*; +import jdk.test.lib.Asserts; +import jdk.test.lib.Utils; + +import java.util.Random; + +/** + * @test + * @bug 8328528 + * @summary test the long typed parallel iv replacing transformation for int counted loop + * @library /test/lib / + * @requires vm.compiler2.enabled + * @run driver compiler.loopopts.parallel_iv.TestParallelIvInIntCountedLoop + */ +public class TestParallelIvInIntCountedLoop { + private static final Random RNG = Utils.getRandomInstance(); + + // stride2 must be a multiple of stride and must not overflow for the optimization to work + private static final int STRIDE = RNG.nextInt(1, Integer.MAX_VALUE / 16); + private static final int STRIDE_2 = STRIDE * RNG.nextInt(1, 16); + + public static void main(String[] args) { + TestFramework.runWithFlags( + "-XX:+IgnoreUnrecognizedVMOptions", // StressLongCountedLoop is only available in debug builds + "-XX:StressLongCountedLoop=0", // Don't convert int counted loops to long ones + "-XX:PerMethodTrapLimit=100" // allow slow-path loop limit checks + ); + } + + /* + * The IR framework can only test against static code, and the transformation relies on strides being constants to + * perform constant propagation. Therefore, we have no choice but to repeat the same test case multiple times with + * different numbers. + * + * For good measure, randomly initialized static final stride and stride2 are also tested. + */ + + // A control test making sure a simple counted loop that cannot be optimized away is found by the test framework. 
+ @Test + @Arguments(values = { Argument.NUMBER_42 }) // otherwise a large number may take too long + @IR(counts = { IRNode.COUNTED_LOOP, ">=1" }) + private static int testControlledSimpleLoop(int stop) { + int a = 0; + for (int i = 0; i < stop; i++) { + a += i; // cannot be extracted to multiplications + } + + return a; + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static int testIntCountedLoopWithIntIV(int stop) { + int a = 0; + for (int i = 0; i < stop; i++) { + a += 1; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithIntIV") + private static void runTestIntCountedLoopWithIntIv() { + int s = RNG.nextInt(0, Integer.MAX_VALUE); + Asserts.assertEQ(s, testIntCountedLoopWithIntIV(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static int testIntCountedLoopWithIntIVZero(int stop) { + int a = 0; + for (int i = 0; i < stop; i++) { + a += 0; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithIntIVZero") + private static void runTestIntCountedLoopWithIntIVZero() { + int s = RNG.nextInt(0, Integer.MAX_VALUE); + Asserts.assertEQ(0, testIntCountedLoopWithIntIVZero(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static int testIntCountedLoopWithIntIVMax(int stop) { + int a = 0; + for (int i = 0; i < stop; i++) { + a += Integer.MAX_VALUE; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithIntIVMax") + private static void runTestIntCountedLoopWithIntIVMax() { + int s = RNG.nextInt(0, Integer.MAX_VALUE); + Asserts.assertEQ(s * Integer.MAX_VALUE, testIntCountedLoopWithIntIVMax(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static int testIntCountedLoopWithIntIVMaxMinusOne(int stop) { + int a = 0; + for (int i = 0; i < stop; i++) { + a += Integer.MAX_VALUE - 1; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithIntIVMaxMinusOne") + private static void runTestIntCountedLoopWithIntIVMaxMinusOne() { + int s = RNG.nextInt(0, Integer.MAX_VALUE); + 
Asserts.assertEQ(s * (Integer.MAX_VALUE - 1), testIntCountedLoopWithIntIVMaxMinusOne(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static int testIntCountedLoopWithIntIVMaxPlusOne(int stop) { + int a = 0; + for (int i = 0; i < stop; i++) { + a += Integer.MAX_VALUE + 1; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithIntIVMaxPlusOne") + private static void runTestIntCountedLoopWithIntIVMaxPlusOne() { + int s = RNG.nextInt(0, Integer.MAX_VALUE); + Asserts.assertEQ(s * (Integer.MAX_VALUE + 1), testIntCountedLoopWithIntIVMaxPlusOne(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static int testIntCountedLoopWithIntIVWithStrideTwo(int stop) { + int a = 0; + for (int i = 0; i < stop; i += 2) { + a += 2; // this stride2 constant must be a multiple of the first stride (i += ...) for optimization + } + + return a; + } + + @Run(test = "testIntCountedLoopWithIntIVWithStrideTwo") + private static void runTestIntCountedLoopWithIntIVWithStrideTwo() { + // Since we can't easily determine expected values if loop variables overflow when incrementing, we make sure + // `stop` is less than (MAX_VALUE - stride). 
+ int s = RNG.nextInt(0, Integer.MAX_VALUE - 2); + Asserts.assertEQ(Math.ceilDiv(s, 2) * 2, testIntCountedLoopWithIntIVWithStrideTwo(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static int testIntCountedLoopWithIntIVWithStrideMinusOne(int stop) { + int a = 0; + for (int i = stop; i > 0; i += -1) { + a += 1; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithIntIVWithStrideMinusOne") + private static void runTestIntCountedLoopWithIntIVWithStrideMinusOne() { + int s = RNG.nextInt(0, Integer.MAX_VALUE); + Asserts.assertEQ(s, testIntCountedLoopWithIntIVWithStrideMinusOne(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static int testIntCountedLoopWithIntIVWithRandomStrides(int stop) { + int a = 0; + for (int i = 0; i < stop; i += STRIDE) { + a += STRIDE_2; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithIntIVWithRandomStrides") + private static void runTestIntCountedLoopWithIntIVWithRandomStrides() { + // Make sure `stop` is less than (MAX_VALUE - stride) to avoid overflows. + int s = RNG.nextInt(0, Integer.MAX_VALUE - STRIDE); + Asserts.assertEQ(Math.ceilDiv(s, STRIDE) * STRIDE_2, testIntCountedLoopWithIntIVWithRandomStrides(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static int testIntCountedLoopWithIntIVWithRandomStridesAndInits(int init, int init2, int stop) { + int a = init; + for (int i = init2; i < stop; i += STRIDE) { + a += STRIDE_2; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithIntIVWithRandomStridesAndInits") + private static void runTestIntCountedLoopWithIntIVWithRandomStridesAndInits() { + int s = RNG.nextInt(0, Integer.MAX_VALUE - STRIDE); + int init1 = RNG.nextInt(); + int init2 = RNG.nextInt(Integer.MIN_VALUE + s + 1, s); // Limit bounds to avoid loop variables from overflowing. 
+ Asserts.assertEQ(Math.ceilDiv((s - init2), STRIDE) * STRIDE_2 + init1, + testIntCountedLoopWithIntIVWithRandomStridesAndInits(init1, init2, s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static long testIntCountedLoopWithLongIV(int stop) { + long a = 0; + for (int i = 0; i < stop; i++) { + a += 1; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithLongIV") + private static void runTestIntCountedLoopWithLongIV() { + int s = RNG.nextInt(0, Integer.MAX_VALUE); + Asserts.assertEQ((long) s, testIntCountedLoopWithLongIV(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static long testIntCountedLoopWithLongIVZero(int stop) { + long a = 0; + for (int i = 0; i < stop; i++) { + a += 0; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithLongIVZero") + private static void runTestIntCountedLoopWithLongIVZero() { + int s = RNG.nextInt(0, Integer.MAX_VALUE); + Asserts.assertEQ((long) 0, testIntCountedLoopWithLongIVZero(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static long testIntCountedLoopWithLongIVMax(int stop) { + long a = 0; + for (int i = 0; i < stop; i++) { + a += Long.MAX_VALUE; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithLongIVMax") + private static void runTestIntCountedLoopWithLongIVMax() { + int s = RNG.nextInt(0, Integer.MAX_VALUE); + Asserts.assertEQ((long) s * Long.MAX_VALUE, testIntCountedLoopWithLongIVMax(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static long testIntCountedLoopWithLongIVMaxMinusOne(int stop) { + long a = 0; + for (int i = 0; i < stop; i++) { + a += Long.MAX_VALUE - 1; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithLongIVMaxMinusOne") + private static void runTestIntCountedLoopWithLongIVMaxMinusOne() { + int s = RNG.nextInt(0, Integer.MAX_VALUE); + Asserts.assertEQ((long) s * (Long.MAX_VALUE - 1L), testIntCountedLoopWithLongIVMaxMinusOne(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + 
private static long testIntCountedLoopWithLongIVMaxPlusOne(int stop) { + long a = 0; + for (int i = 0; i < stop; i++) { + a += Long.MAX_VALUE + 1; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithLongIVMaxPlusOne") + private static void runTestIntCountedLoopWithLongIVMaxPlusOne() { + int s = RNG.nextInt(0, Integer.MAX_VALUE); + Asserts.assertEQ((long) s * (Long.MAX_VALUE + 1L), testIntCountedLoopWithLongIVMaxPlusOne(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static long testIntCountedLoopWithLongIVWithStrideTwo(int stop) { + long a = 0; + for (int i = 0; i < stop; i += 2) { + a += 2; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithLongIVWithStrideTwo") + private static void runTestIntCountedLoopWithLongIVWithStrideTwo() { + int s = RNG.nextInt(0, Integer.MAX_VALUE - 2); + Asserts.assertEQ(Math.ceilDiv(s, 2L) * 2L, testIntCountedLoopWithLongIVWithStrideTwo(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static long testIntCountedLoopWithLongIVWithStrideMinusOne(int stop) { + long a = 0; + for (int i = stop; i > 0; i += -1) { + a += 1; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithLongIVWithStrideMinusOne") + private static void runTestIntCountedLoopWithLongIVWithStrideMinusOne() { + int s = RNG.nextInt(0, Integer.MAX_VALUE); + Asserts.assertEQ((long) s, testIntCountedLoopWithLongIVWithStrideMinusOne(s)); + } + + @Test + @IR(failOn = { IRNode.COUNTED_LOOP }) + private static long testIntCountedLoopWithLongIVWithRandomStrides(int stop) { + long a = 0; + for (int i = 0; i < stop; i += STRIDE) { + a += STRIDE_2; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithLongIVWithRandomStrides") + private static void runTestIntCountedLoopWithLongIVWithRandomStrides() { + int s = RNG.nextInt(0, Integer.MAX_VALUE - STRIDE); + Asserts.assertEQ(Math.ceilDiv(s, (long) STRIDE) * (long) STRIDE_2, + testIntCountedLoopWithLongIVWithRandomStrides(s)); + } + + @Test + @IR(failOn = { 
IRNode.COUNTED_LOOP }) + private static long testIntCountedLoopWithLongIVWithRandomStridesAndInits(long init, int init2, int stop) { + long a = init; + for (int i = init2; i < stop; i += STRIDE) { + a += STRIDE_2; + } + + return a; + } + + @Run(test = "testIntCountedLoopWithLongIVWithRandomStridesAndInits") + private static void runTestIntCountedLoopWithLongIVWithRandomStridesAndInits() { + int s = RNG.nextInt(0, Integer.MAX_VALUE - STRIDE); + long init1 = RNG.nextLong(); + int init2 = RNG.nextInt(Integer.MIN_VALUE + s + 1, s); // Limit bounds to avoid loop variables from overflowing. + Asserts.assertEQ(Math.ceilDiv(((long) s - init2), (long) STRIDE) * (long) STRIDE_2 + init1, + testIntCountedLoopWithLongIVWithRandomStridesAndInits(init1, init2, s)); + } +}