8328528: C2 should optimize long-typed parallel iv in an int counted loop

Reviewed-by: roland, chagedorn, thartmann
This commit is contained in:
Kangcheng Xu 2024-10-21 14:57:31 +00:00
parent 330f2b5a9c
commit 80ec552248
3 changed files with 493 additions and 38 deletions

View File

@ -3951,6 +3951,41 @@ bool PhaseIdealLoop::is_deleteable_safept(Node* sfpt) {
//---------------------------replace_parallel_iv-------------------------------
// Replace parallel induction variable (parallel to trip counter)
// This optimization looks for patterns similar to:
//
// int a = init2;
// for (int iv = init; iv < limit; iv += stride_con) {
// a += stride_con2;
// }
//
// and transforms it to:
//
// int iv2 = init2
// int iv = init
// loop:
// if (iv >= limit) goto exit
// iv += stride_con
// iv2 = init2 + (iv - init) * (stride_con2 / stride_con)
// goto loop
// exit:
// ...
//
// Such transformation introduces more optimization opportunities. In this
// particular example, the loop can be eliminated entirely given that
// `stride_con2 / stride_con` is exact (i.e., no remainder). Checks are in
// place to only perform this optimization if such a division is exact. This
// example will be transformed into its semantic equivalence:
//
// int iv2 = (iv * stride_con2 / stride_con) + (init2 - (init * stride_con2 / stride_con))
//
// which corresponds to the structure of transformed subgraph.
//
// However, if there is a mismatch between types of the loop and the parallel
// induction variable (e.g., a long-typed IV in an int-typed loop), type
// conversions are required:
//
// long iv2 = ((long) iv * stride_con2 / stride_con) + (init2 - ((long) init * stride_con2 / stride_con))
//
void PhaseIdealLoop::replace_parallel_iv(IdealLoopTree *loop) {
assert(loop->_head->is_CountedLoop(), "");
CountedLoopNode *cl = loop->_head->as_CountedLoop();
@ -3963,7 +3998,7 @@ void PhaseIdealLoop::replace_parallel_iv(IdealLoopTree *loop) {
}
Node *init = cl->init_trip();
Node *phi = cl->phi();
int stride_con = cl->stride_con();
jlong stride_con = cl->stride_con();
// Visit all children, looking for Phis
for (DUIterator i = cl->outs(); cl->has_out(i); i++) {
@ -3980,7 +4015,7 @@ void PhaseIdealLoop::replace_parallel_iv(IdealLoopTree *loop) {
incr2->req() != 3 ||
incr2->in(1)->uncast() != phi2 ||
incr2 == incr ||
incr2->Opcode() != Op_AddI ||
(incr2->Opcode() != Op_AddI && incr2->Opcode() != Op_AddL) ||
!incr2->in(2)->is_Con()) {
continue;
}
@ -3996,11 +4031,15 @@ void PhaseIdealLoop::replace_parallel_iv(IdealLoopTree *loop) {
// the trip-counter, so we need to convert all these to trip-counter
// expressions.
Node* init2 = phi2->in(LoopNode::EntryControl);
int stride_con2 = incr2->in(2)->get_int();
// Determine the basic type of the stride constant (and the iv being incremented).
BasicType stride_con2_bt = incr2->Opcode() == Op_AddI ? T_INT : T_LONG;
jlong stride_con2 = incr2->in(2)->get_integer_as_long(stride_con2_bt);
// The ratio of the two strides cannot be represented as an int
// if stride_con2 is min_int and stride_con is -1.
if (stride_con2 == min_jint && stride_con == -1) {
// if stride_con2 is min_jint (or min_jlong, respectively) and
// stride_con is -1.
if (stride_con2 == min_signed_integer(stride_con2_bt) && stride_con == -1) {
continue;
}
@ -4011,44 +4050,67 @@ void PhaseIdealLoop::replace_parallel_iv(IdealLoopTree *loop) {
// Instead we require 'stride_con2' to be a multiple of 'stride_con',
// where +/-1 is the common case, but other integer multiples are
// also easy to handle.
int ratio_con = stride_con2/stride_con;
jlong ratio_con = stride_con2 / stride_con;
if ((ratio_con * stride_con) == stride_con2) { // Check for exact
#ifndef PRODUCT
if (TraceLoopOpts) {
tty->print("Parallel IV: %d ", phi2->_idx);
loop->dump_head();
}
#endif
// Convert to using the trip counter. The parallel induction
// variable differs from the trip counter by a loop-invariant
// amount, the difference between their respective initial values.
// It is scaled by the 'ratio_con'.
Node* ratio = _igvn.intcon(ratio_con);
set_ctrl(ratio, C->root());
Node* ratio_init = new MulINode(init, ratio);
_igvn.register_new_node_with_optimizer(ratio_init, init);
set_early_ctrl(ratio_init, false);
Node* diff = new SubINode(init2, ratio_init);
_igvn.register_new_node_with_optimizer(diff, init2);
set_early_ctrl(diff, false);
Node* ratio_idx = new MulINode(phi, ratio);
_igvn.register_new_node_with_optimizer(ratio_idx, phi);
set_ctrl(ratio_idx, cl);
Node* add = new AddINode(ratio_idx, diff);
_igvn.register_new_node_with_optimizer(add);
set_ctrl(add, cl);
_igvn.replace_node( phi2, add );
// Sometimes an induction variable is unused
if (add->outcnt() == 0) {
_igvn.remove_dead_node(add);
}
--i; // deleted this phi; rescan starting with next position
continue;
if ((ratio_con * stride_con) != stride_con2) { // Check for exact (no remainder)
continue;
}
#ifndef PRODUCT
if (TraceLoopOpts) {
tty->print("Parallel IV: %d ", phi2->_idx);
loop->dump_head();
}
#endif
// Convert to using the trip counter. The parallel induction
// variable differs from the trip counter by a loop-invariant
// amount, the difference between their respective initial values.
// It is scaled by the 'ratio_con'.
Node* ratio = _igvn.integercon(ratio_con, stride_con2_bt);
set_ctrl(ratio, C->root());
Node* init_converted = insert_convert_node_if_needed(stride_con2_bt, init);
Node* phi_converted = insert_convert_node_if_needed(stride_con2_bt, phi);
Node* ratio_init = MulNode::make(init_converted, ratio, stride_con2_bt);
_igvn.register_new_node_with_optimizer(ratio_init, init_converted);
set_early_ctrl(ratio_init, false);
Node* diff = SubNode::make(init2, ratio_init, stride_con2_bt);
_igvn.register_new_node_with_optimizer(diff, init2);
set_early_ctrl(diff, false);
Node* ratio_idx = MulNode::make(phi_converted, ratio, stride_con2_bt);
_igvn.register_new_node_with_optimizer(ratio_idx, phi_converted);
set_ctrl(ratio_idx, cl);
Node* add = AddNode::make(ratio_idx, diff, stride_con2_bt);
_igvn.register_new_node_with_optimizer(add);
set_ctrl(add, cl);
_igvn.replace_node( phi2, add );
// Sometimes an induction variable is unused
if (add->outcnt() == 0) {
_igvn.remove_dead_node(add);
}
--i; // deleted this phi; rescan starting with next position
}
}
Node* PhaseIdealLoop::insert_convert_node_if_needed(BasicType target, Node* input) {
BasicType source = _igvn.type(input)->basic_type();
if (source == target) {
return input;
}
Node* converted = ConvertNode::create_convert(source, target, input);
_igvn.register_new_node_with_optimizer(converted, input);
set_early_ctrl(converted, false);
return converted;
}
void IdealLoopTree::remove_safepoints(PhaseIdealLoop* phase, bool keep_one) {
Node* keep = nullptr;
if (keep_one) {

View File

@ -1134,6 +1134,8 @@ private:
}
#endif
Node* insert_convert_node_if_needed(BasicType target, Node* input);
public:
Node* idom_no_update(Node* d) const {
return idom_no_update(d->_idx);

View File

@ -0,0 +1,391 @@
/*
* Copyright (c) 2024 Red Hat and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package compiler.loopopts.parallel_iv;
import compiler.lib.ir_framework.*;
import jdk.test.lib.Asserts;
import jdk.test.lib.Utils;
import java.util.Random;
/**
* @test
* @bug 8328528
* @summary test the long typed parallel iv replacing transformation for int counted loop
* @library /test/lib /
* @requires vm.compiler2.enabled
* @run driver compiler.loopopts.parallel_iv.TestParallelIvInIntCountedLoop
*/
public class TestParallelIvInIntCountedLoop {
private static final Random RNG = Utils.getRandomInstance();
// stride2 must be a multiple of stride and must not overflow for the optimization to work
private static final int STRIDE = RNG.nextInt(1, Integer.MAX_VALUE / 16);
private static final int STRIDE_2 = STRIDE * RNG.nextInt(1, 16);
public static void main(String[] args) {
TestFramework.runWithFlags(
"-XX:+IgnoreUnrecognizedVMOptions", // StressLongCountedLoop is only available in debug builds
"-XX:StressLongCountedLoop=0", // Don't convert int counted loops to long ones
"-XX:PerMethodTrapLimit=100" // allow slow-path loop limit checks
);
}
/*
* The IR framework can only test against static code, and the transformation relies on strides being constants to
* perform constant propagation. Therefore, we have no choice but repeating the same test case multiple times with
* different numbers.
*
* For good measures, randomly initialized static final stride and stride2 is also tested.
*/
// A controlled test making sure a simple non-counted loop can be found by the test framework.
@Test
@Arguments(values = { Argument.NUMBER_42 }) // otherwise a large number may take too long
@IR(counts = { IRNode.COUNTED_LOOP, ">=1" })
private static int testControlledSimpleLoop(int stop) {
int a = 0;
for (int i = 0; i < stop; i++) {
a += i; // cannot be extracted to multiplications
}
return a;
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static int testIntCountedLoopWithIntIV(int stop) {
int a = 0;
for (int i = 0; i < stop; i++) {
a += 1;
}
return a;
}
@Run(test = "testIntCountedLoopWithIntIV")
private static void runTestIntCountedLoopWithIntIv() {
int s = RNG.nextInt(0, Integer.MAX_VALUE);
Asserts.assertEQ(s, testIntCountedLoopWithIntIV(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static int testIntCountedLoopWithIntIVZero(int stop) {
int a = 0;
for (int i = 0; i < stop; i++) {
a += 0;
}
return a;
}
@Run(test = "testIntCountedLoopWithIntIVZero")
private static void runTestIntCountedLoopWithIntIVZero() {
int s = RNG.nextInt(0, Integer.MAX_VALUE);
Asserts.assertEQ(0, testIntCountedLoopWithIntIVZero(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static int testIntCountedLoopWithIntIVMax(int stop) {
int a = 0;
for (int i = 0; i < stop; i++) {
a += Integer.MAX_VALUE;
}
return a;
}
@Run(test = "testIntCountedLoopWithIntIVMax")
private static void runTestIntCountedLoopWithIntIVMax() {
int s = RNG.nextInt(0, Integer.MAX_VALUE);
Asserts.assertEQ(s * Integer.MAX_VALUE, testIntCountedLoopWithIntIVMax(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static int testIntCountedLoopWithIntIVMaxMinusOne(int stop) {
int a = 0;
for (int i = 0; i < stop; i++) {
a += Integer.MAX_VALUE - 1;
}
return a;
}
@Run(test = "testIntCountedLoopWithIntIVMaxMinusOne")
private static void runTestIntCountedLoopWithIntIVMaxMinusOne() {
int s = RNG.nextInt(0, Integer.MAX_VALUE);
Asserts.assertEQ(s * (Integer.MAX_VALUE - 1), testIntCountedLoopWithIntIVMaxMinusOne(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static int testIntCountedLoopWithIntIVMaxPlusOne(int stop) {
int a = 0;
for (int i = 0; i < stop; i++) {
a += Integer.MAX_VALUE + 1;
}
return a;
}
@Run(test = "testIntCountedLoopWithIntIVMaxPlusOne")
private static void runTestIntCountedLoopWithIntIVMaxPlusOne() {
int s = RNG.nextInt(0, Integer.MAX_VALUE);
Asserts.assertEQ(s * (Integer.MAX_VALUE + 1), testIntCountedLoopWithIntIVMaxPlusOne(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static int testIntCountedLoopWithIntIVWithStrideTwo(int stop) {
int a = 0;
for (int i = 0; i < stop; i += 2) {
a += 2; // this stride2 constant must be a multiple of the first stride (i += ...) for optimization
}
return a;
}
@Run(test = "testIntCountedLoopWithIntIVWithStrideTwo")
private static void runTestIntCountedLoopWithIntIVWithStrideTwo() {
// Since we can't easily determine expected values if loop variables overflow when incrementing, we make sure
// `stop` is less than (MAX_VALUE - stride).
int s = RNG.nextInt(0, Integer.MAX_VALUE - 2);
Asserts.assertEQ(Math.ceilDiv(s, 2) * 2, testIntCountedLoopWithIntIVWithStrideTwo(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static int testIntCountedLoopWithIntIVWithStrideMinusOne(int stop) {
int a = 0;
for (int i = stop; i > 0; i += -1) {
a += 1;
}
return a;
}
@Run(test = "testIntCountedLoopWithIntIVWithStrideMinusOne")
private static void runTestIntCountedLoopWithIntIVWithStrideMinusOne() {
int s = RNG.nextInt(0, Integer.MAX_VALUE);
Asserts.assertEQ(s, testIntCountedLoopWithIntIVWithStrideMinusOne(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static int testIntCountedLoopWithIntIVWithRandomStrides(int stop) {
int a = 0;
for (int i = 0; i < stop; i += STRIDE) {
a += STRIDE_2;
}
return a;
}
@Run(test = "testIntCountedLoopWithIntIVWithRandomStrides")
private static void runTestIntCountedLoopWithIntIVWithRandomStrides() {
// Make sure `stop` is less than (MAX_VALUE - stride) to avoid overflows.
int s = RNG.nextInt(0, Integer.MAX_VALUE - STRIDE);
Asserts.assertEQ(Math.ceilDiv(s, STRIDE) * STRIDE_2, testIntCountedLoopWithIntIVWithRandomStrides(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static int testIntCountedLoopWithIntIVWithRandomStridesAndInits(int init, int init2, int stop) {
int a = init;
for (int i = init2; i < stop; i += STRIDE) {
a += STRIDE_2;
}
return a;
}
@Run(test = "testIntCountedLoopWithIntIVWithRandomStridesAndInits")
private static void runTestIntCountedLoopWithIntIVWithRandomStridesAndInits() {
int s = RNG.nextInt(0, Integer.MAX_VALUE - STRIDE);
int init1 = RNG.nextInt();
int init2 = RNG.nextInt(Integer.MIN_VALUE + s + 1, s); // Limit bounds to avoid loop variables from overflowing.
Asserts.assertEQ(Math.ceilDiv((s - init2), STRIDE) * STRIDE_2 + init1,
testIntCountedLoopWithIntIVWithRandomStridesAndInits(init1, init2, s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static long testIntCountedLoopWithLongIV(int stop) {
long a = 0;
for (int i = 0; i < stop; i++) {
a += 1;
}
return a;
}
@Run(test = "testIntCountedLoopWithLongIV")
private static void runTestIntCountedLoopWithLongIV() {
int s = RNG.nextInt(0, Integer.MAX_VALUE);
Asserts.assertEQ((long) s, testIntCountedLoopWithLongIV(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static long testIntCountedLoopWithLongIVZero(int stop) {
long a = 0;
for (int i = 0; i < stop; i++) {
a += 0;
}
return a;
}
@Run(test = "testIntCountedLoopWithLongIVZero")
private static void runTestIntCountedLoopWithLongIVZero() {
int s = RNG.nextInt(0, Integer.MAX_VALUE);
Asserts.assertEQ((long) 0, testIntCountedLoopWithLongIVZero(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static long testIntCountedLoopWithLongIVMax(int stop) {
long a = 0;
for (int i = 0; i < stop; i++) {
a += Long.MAX_VALUE;
}
return a;
}
@Run(test = "testIntCountedLoopWithLongIVMax")
private static void runTestIntCountedLoopWithLongIVMax() {
int s = RNG.nextInt(0, Integer.MAX_VALUE);
Asserts.assertEQ((long) s * Long.MAX_VALUE, testIntCountedLoopWithLongIVMax(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static long testIntCountedLoopWithLongIVMaxMinusOne(int stop) {
long a = 0;
for (int i = 0; i < stop; i++) {
a += Long.MAX_VALUE - 1;
}
return a;
}
@Run(test = "testIntCountedLoopWithLongIVMaxMinusOne")
private static void runTestIntCountedLoopWithLongIVMaxMinusOne() {
int s = RNG.nextInt(0, Integer.MAX_VALUE);
Asserts.assertEQ((long) s * (Long.MAX_VALUE - 1L), testIntCountedLoopWithLongIVMaxMinusOne(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static long testIntCountedLoopWithLongIVMaxPlusOne(int stop) {
long a = 0;
for (int i = 0; i < stop; i++) {
a += Long.MAX_VALUE + 1;
}
return a;
}
@Run(test = "testIntCountedLoopWithLongIVMaxPlusOne")
private static void runTestIntCountedLoopWithLongIVMaxPlusOne() {
int s = RNG.nextInt(0, Integer.MAX_VALUE);
Asserts.assertEQ((long) s * (Long.MAX_VALUE + 1L), testIntCountedLoopWithLongIVMaxPlusOne(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static long testIntCountedLoopWithLongIVWithStrideTwo(int stop) {
long a = 0;
for (int i = 0; i < stop; i += 2) {
a += 2;
}
return a;
}
@Run(test = "testIntCountedLoopWithLongIVWithStrideTwo")
private static void runTestIntCountedLoopWithLongIVWithStrideTwo() {
int s = RNG.nextInt(0, Integer.MAX_VALUE - 2);
Asserts.assertEQ(Math.ceilDiv(s, 2L) * 2L, testIntCountedLoopWithLongIVWithStrideTwo(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static long testIntCountedLoopWithLongIVWithStrideMinusOne(int stop) {
long a = 0;
for (int i = stop; i > 0; i += -1) {
a += 1;
}
return a;
}
@Run(test = "testIntCountedLoopWithLongIVWithStrideMinusOne")
private static void runTestIntCountedLoopWithLongIVWithStrideMinusOne() {
int s = RNG.nextInt(0, Integer.MAX_VALUE);
Asserts.assertEQ((long) s, testIntCountedLoopWithLongIVWithStrideMinusOne(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static long testIntCountedLoopWithLongIVWithRandomStrides(int stop) {
long a = 0;
for (int i = 0; i < stop; i += STRIDE) {
a += STRIDE_2;
}
return a;
}
@Run(test = "testIntCountedLoopWithLongIVWithRandomStrides")
private static void runTestIntCountedLoopWithLongIVWithRandomStrides() {
int s = RNG.nextInt(0, Integer.MAX_VALUE - STRIDE);
Asserts.assertEQ(Math.ceilDiv(s, (long) STRIDE) * (long) STRIDE_2,
testIntCountedLoopWithLongIVWithRandomStrides(s));
}
@Test
@IR(failOn = { IRNode.COUNTED_LOOP })
private static long testIntCountedLoopWithLongIVWithRandomStridesAndInits(long init, int init2, int stop) {
long a = init;
for (int i = init2; i < stop; i += STRIDE) {
a += STRIDE_2;
}
return a;
}
@Run(test = "testIntCountedLoopWithLongIVWithRandomStridesAndInits")
private static void runTestIntCountedLoopWithLongIVWithRandomStridesAndInits() {
int s = RNG.nextInt(0, Integer.MAX_VALUE - STRIDE);
long init1 = RNG.nextLong();
int init2 = RNG.nextInt(Integer.MIN_VALUE + s + 1, s); // Limit bounds to avoid loop variables from overflowing.
Asserts.assertEQ(Math.ceilDiv(((long) s - init2), (long) STRIDE) * (long) STRIDE_2 + init1,
testIntCountedLoopWithLongIVWithRandomStridesAndInits(init1, init2, s));
}
}