8302673: [SuperWord] MaxReduction and MinReduction should vectorize for int

Co-authored-by: Jatin Bhateja <jbhateja@openjdk.org>
Reviewed-by: epeter, kvn
This commit is contained in:
Roberto Castañeda Lozano 2023-06-05 07:08:33 +00:00
parent 22a9a86be0
commit 3fa776d66a
5 changed files with 414 additions and 182 deletions

View File

@ -1030,6 +1030,14 @@ const Type* XorLNode::Value(PhaseGVN* phase) const {
return AddNode::Value(phase);
}
Node* build_min_max_int(Node* a, Node* b, bool is_max) {
if (is_max) {
return new MaxINode(a, b);
} else {
return new MinINode(a, b);
}
}
Node* MaxNode::build_min_max(Node* a, Node* b, bool is_max, bool is_unsigned, const Type* t, PhaseGVN& gvn) {
bool is_int = gvn.type(a)->isa_int();
assert(is_int || gvn.type(a)->isa_long(), "int or long inputs");
@ -1044,13 +1052,7 @@ Node* MaxNode::build_min_max(Node* a, Node* b, bool is_max, bool is_unsigned, co
}
Node* res = nullptr;
if (is_int && !is_unsigned) {
Node* res_new = nullptr;
if (is_max) {
res_new = new MaxINode(a, b);
} else {
res_new = new MinINode(a, b);
}
res = gvn.transform(res_new);
res = gvn.transform(build_min_max_int(a, b, is_max));
assert(gvn.type(res)->is_int()->_lo >= t->is_int()->_lo && gvn.type(res)->is_int()->_hi <= t->is_int()->_hi, "type doesn't match");
} else {
Node* cmp = nullptr;
@ -1096,6 +1098,113 @@ Node* MaxNode::build_min_max_diff_with_zero(Node* a, Node* b, bool is_max, const
return res;
}
// Check if addition of an integer with type 't' and a constant 'c' can overflow.
static bool can_overflow(const TypeInt* t, jint c) {
jint t_lo = t->_lo;
jint t_hi = t->_hi;
return ((c < 0 && (java_add(t_lo, c) > t_lo)) ||
(c > 0 && (java_add(t_hi, c) < t_hi)));
}
// Let <x, x_off> = x_operands and <y, y_off> = y_operands.
// If x == y and neither add(x, x_off) nor add(y, y_off) overflow, return
// add(x, op(x_off, y_off)). Otherwise, return nullptr.
Node* MaxNode::extract_add(PhaseGVN* phase, ConstAddOperands x_operands, ConstAddOperands y_operands) {
Node* x = x_operands.first;
Node* y = y_operands.first;
int opcode = Opcode();
assert(opcode == Op_MaxI || opcode == Op_MinI, "Unexpected opcode");
const TypeInt* tx = phase->type(x)->isa_int();
jint x_off = x_operands.second;
jint y_off = y_operands.second;
if (x == y && tx != nullptr &&
!can_overflow(tx, x_off) &&
!can_overflow(tx, y_off)) {
jint c = opcode == Op_MinI ? MIN2(x_off, y_off) : MAX2(x_off, y_off);
return new AddINode(x, phase->intcon(c));
}
return nullptr;
}
// Try to cast n as an integer addition with a constant. Return:
// <x, C>, if n == add(x, C), where 'C' is a non-TOP constant;
// <nullptr, 0>, if n == add(x, C), where 'C' is a TOP constant; or
// <n, 0>, otherwise.
static ConstAddOperands as_add_with_constant(Node* n) {
if (n->Opcode() != Op_AddI) {
return ConstAddOperands(n, 0);
}
Node* x = n->in(1);
Node* c = n->in(2);
if (!c->is_Con()) {
return ConstAddOperands(n, 0);
}
const Type* c_type = c->bottom_type();
if (c_type == Type::TOP) {
return ConstAddOperands(nullptr, 0);
}
return ConstAddOperands(x, c_type->is_int()->get_con());
}
Node* MaxNode::IdealI(PhaseGVN* phase, bool can_reshape) {
int opcode = Opcode();
assert(opcode == Op_MinI || opcode == Op_MaxI, "Unexpected opcode");
// Try to transform the following pattern, in any of its four possible
// permutations induced by op's commutativity:
// op(op(add(inner, inner_off), inner_other), add(outer, outer_off))
// into
// op(add(inner, op(inner_off, outer_off)), inner_other),
// where:
// op is either MinI or MaxI, and
// inner == outer, and
// the additions cannot overflow.
for (uint inner_op_index = 1; inner_op_index <= 2; inner_op_index++) {
if (in(inner_op_index)->Opcode() != opcode) {
continue;
}
Node* outer_add = in(inner_op_index == 1 ? 2 : 1);
ConstAddOperands outer_add_operands = as_add_with_constant(outer_add);
if (outer_add_operands.first == nullptr) {
return nullptr; // outer_add has a TOP input, no need to continue.
}
// One operand is a MinI/MaxI and the other is an integer addition with
// constant. Test the operands of the inner MinI/MaxI.
for (uint inner_add_index = 1; inner_add_index <= 2; inner_add_index++) {
Node* inner_op = in(inner_op_index);
Node* inner_add = inner_op->in(inner_add_index);
ConstAddOperands inner_add_operands = as_add_with_constant(inner_add);
if (inner_add_operands.first == nullptr) {
return nullptr; // inner_add has a TOP input, no need to continue.
}
// Try to extract the inner add.
Node* add_extracted = extract_add(phase, inner_add_operands, outer_add_operands);
if (add_extracted == nullptr) {
continue;
}
Node* add_transformed = phase->transform(add_extracted);
Node* inner_other = inner_op->in(inner_add_index == 1 ? 2 : 1);
return build_min_max_int(add_transformed, inner_other, opcode == Op_MaxI);
}
}
// Try to transform
// op(add(x, x_off), add(y, y_off))
// into
// add(x, op(x_off, y_off)),
// where:
// op is either MinI or MaxI, and
// inner == outer, and
// the additions cannot overflow.
ConstAddOperands xC = as_add_with_constant(in(1));
ConstAddOperands yC = as_add_with_constant(in(2));
if (xC.first == nullptr || yC.first == nullptr) return nullptr;
return extract_add(phase, xC, yC);
}
// Ideal transformations for MaxINode
Node* MaxINode::Ideal(PhaseGVN* phase, bool can_reshape) {
return IdealI(phase, can_reshape);
}
//=============================================================================
//------------------------------add_ring---------------------------------------
// Supplied function returns the sum of the inputs.
@ -1107,174 +1216,12 @@ const Type *MaxINode::add_ring( const Type *t0, const Type *t1 ) const {
return TypeInt::make( MAX2(r0->_lo,r1->_lo), MAX2(r0->_hi,r1->_hi), MAX2(r0->_widen,r1->_widen) );
}
// Check if addition of an integer with type 't' and a constant 'c' can overflow
static bool can_overflow(const TypeInt* t, jint c) {
jint t_lo = t->_lo;
jint t_hi = t->_hi;
return ((c < 0 && (java_add(t_lo, c) > t_lo)) ||
(c > 0 && (java_add(t_hi, c) < t_hi)));
}
// Ideal transformations for MaxINode
Node* MaxINode::Ideal(PhaseGVN* phase, bool can_reshape) {
// Force a right-spline graph
Node* l = in(1);
Node* r = in(2);
// Transform MaxI1(MaxI2(a, b), c) into MaxI1(a, MaxI2(b, c))
// to force a right-spline graph for the rest of MaxINode::Ideal().
if (l->Opcode() == Op_MaxI) {
assert(l != l->in(1), "dead loop in MaxINode::Ideal");
r = phase->transform(new MaxINode(l->in(2), r));
l = l->in(1);
set_req_X(1, l, phase);
set_req_X(2, r, phase);
return this;
}
// Get left input & constant
Node* x = l;
jint x_off = 0;
if (x->Opcode() == Op_AddI && // Check for "x+c0" and collect constant
x->in(2)->is_Con()) {
const Type* t = x->in(2)->bottom_type();
if (t == Type::TOP) return nullptr; // No progress
x_off = t->is_int()->get_con();
x = x->in(1);
}
// Scan a right-spline-tree for MAXs
Node* y = r;
jint y_off = 0;
// Check final part of MAX tree
if (y->Opcode() == Op_AddI && // Check for "y+c1" and collect constant
y->in(2)->is_Con()) {
const Type* t = y->in(2)->bottom_type();
if (t == Type::TOP) return nullptr; // No progress
y_off = t->is_int()->get_con();
y = y->in(1);
}
if (x->_idx > y->_idx && r->Opcode() != Op_MaxI) {
swap_edges(1, 2);
return this;
}
const TypeInt* tx = phase->type(x)->isa_int();
if (r->Opcode() == Op_MaxI) {
assert(r != r->in(2), "dead loop in MaxINode::Ideal");
y = r->in(1);
// Check final part of MAX tree
if (y->Opcode() == Op_AddI &&// Check for "y+c1" and collect constant
y->in(2)->is_Con()) {
const Type* t = y->in(2)->bottom_type();
if (t == Type::TOP) return nullptr; // No progress
y_off = t->is_int()->get_con();
y = y->in(1);
}
if (x->_idx > y->_idx)
return new MaxINode(r->in(1), phase->transform(new MaxINode(l, r->in(2))));
// Transform MAX2(x + c0, MAX2(x + c1, z)) into MAX2(x + MAX2(c0, c1), z)
// if x == y and the additions can't overflow.
if (x == y && tx != nullptr &&
!can_overflow(tx, x_off) &&
!can_overflow(tx, y_off)) {
return new MaxINode(phase->transform(new AddINode(x, phase->intcon(MAX2(x_off, y_off)))), r->in(2));
}
} else {
// Transform MAX2(x + c0, y + c1) into x + MAX2(c0, c1)
// if x == y and the additions can't overflow.
if (x == y && tx != nullptr &&
!can_overflow(tx, x_off) &&
!can_overflow(tx, y_off)) {
return new AddINode(x, phase->intcon(MAX2(x_off, y_off)));
}
}
return nullptr;
}
//=============================================================================
//------------------------------Idealize---------------------------------------
// MINs show up in range-check loop limit calculations. Look for
// "MIN2(x+c0,MIN2(y,x+c1))". Pick the smaller constant: "MIN2(x+c0,y)"
Node *MinINode::Ideal(PhaseGVN *phase, bool can_reshape) {
Node *progress = nullptr;
// Force a right-spline graph
Node *l = in(1);
Node *r = in(2);
// Transform MinI1( MinI2(a,b), c) into MinI1( a, MinI2(b,c) )
// to force a right-spline graph for the rest of MinINode::Ideal().
if( l->Opcode() == Op_MinI ) {
assert( l != l->in(1), "dead loop in MinINode::Ideal" );
r = phase->transform(new MinINode(l->in(2),r));
l = l->in(1);
set_req_X(1, l, phase);
set_req_X(2, r, phase);
return this;
}
// Get left input & constant
Node *x = l;
jint x_off = 0;
if( x->Opcode() == Op_AddI && // Check for "x+c0" and collect constant
x->in(2)->is_Con() ) {
const Type *t = x->in(2)->bottom_type();
if( t == Type::TOP ) return nullptr; // No progress
x_off = t->is_int()->get_con();
x = x->in(1);
}
// Scan a right-spline-tree for MINs
Node *y = r;
jint y_off = 0;
// Check final part of MIN tree
if( y->Opcode() == Op_AddI && // Check for "y+c1" and collect constant
y->in(2)->is_Con() ) {
const Type *t = y->in(2)->bottom_type();
if( t == Type::TOP ) return nullptr; // No progress
y_off = t->is_int()->get_con();
y = y->in(1);
}
if( x->_idx > y->_idx && r->Opcode() != Op_MinI ) {
swap_edges(1, 2);
return this;
}
const TypeInt* tx = phase->type(x)->isa_int();
if( r->Opcode() == Op_MinI ) {
assert( r != r->in(2), "dead loop in MinINode::Ideal" );
y = r->in(1);
// Check final part of MIN tree
if( y->Opcode() == Op_AddI &&// Check for "y+c1" and collect constant
y->in(2)->is_Con() ) {
const Type *t = y->in(2)->bottom_type();
if( t == Type::TOP ) return nullptr; // No progress
y_off = t->is_int()->get_con();
y = y->in(1);
}
if( x->_idx > y->_idx )
return new MinINode(r->in(1),phase->transform(new MinINode(l,r->in(2))));
// Transform MIN2(x + c0, MIN2(x + c1, z)) into MIN2(x + MIN2(c0, c1), z)
// if x == y and the additions can't overflow.
if (x == y && tx != nullptr &&
!can_overflow(tx, x_off) &&
!can_overflow(tx, y_off)) {
return new MinINode(phase->transform(new AddINode(x, phase->intcon(MIN2(x_off, y_off)))), r->in(2));
}
} else {
// Transform MIN2(x + c0, y + c1) into x + MIN2(c0, c1)
// if x == y and the additions can't overflow.
if (x == y && tx != nullptr &&
!can_overflow(tx, x_off) &&
!can_overflow(tx, y_off)) {
return new AddINode(x,phase->intcon(MIN2(x_off,y_off)));
}
}
return nullptr;
Node* MinINode::Ideal(PhaseGVN* phase, bool can_reshape) {
return IdealI(phase, can_reshape);
}
//------------------------------add_ring---------------------------------------

View File

@ -28,10 +28,12 @@
#include "opto/node.hpp"
#include "opto/opcodes.hpp"
#include "opto/type.hpp"
#include "utilities/pair.hpp"
// Portions of code courtesy of Clifford Click
class PhaseTransform;
typedef const Pair<Node*, jint> ConstAddOperands;
//------------------------------AddNode----------------------------------------
// Classic Add functionality. This covers all the usual 'add' behaviors for
@ -252,12 +254,14 @@ class MaxNode : public AddNode {
private:
static Node* build_min_max(Node* a, Node* b, bool is_max, bool is_unsigned, const Type* t, PhaseGVN& gvn);
static Node* build_min_max_diff_with_zero(Node* a, Node* b, bool is_max, const Type* t, PhaseGVN& gvn);
Node* extract_add(PhaseGVN* phase, ConstAddOperands x_operands, ConstAddOperands y_operands);
public:
MaxNode( Node *in1, Node *in2 ) : AddNode(in1,in2) {}
virtual int Opcode() const = 0;
virtual int max_opcode() const = 0;
virtual int min_opcode() const = 0;
Node* IdealI(PhaseGVN* phase, bool can_reshape);
static Node* unsigned_max(Node* a, Node* b, const Type* t, PhaseGVN& gvn) {
return build_min_max(a, b, true, true, t, gvn);

View File

@ -1,4 +1,5 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, Arm Limited. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
@ -39,22 +40,38 @@ public class MaxMinINodeIdealizationTests {
TestFramework.run();
}
@Run(test = {"testMax1", "testMax2", "testMax3", "testMin1", "testMin2", "testMin3"})
public void runMethod() {
@Run(test = {"testMax1LL", "testMax1LR", "testMax1RL", "testMax1RR",
"testMax1LLNoInnerAdd", "testMax1LLNoInnerAdd2", "testMax1LLNoOuterAdd", "testMax1LLNoAdd",
"testMax2L", "testMax2R",
"testMax2LNoLeftAdd",
"testMax3",
"testMin1",
"testMin2",
"testMin3"})
public void runPositiveTests() {
int a = RunInfo.getRandom().nextInt();
int min = Integer.MIN_VALUE;
int max = Integer.MAX_VALUE;
assertResult(a);
assertResult(0);
assertResult(min);
assertResult(max);
assertPositiveResult(a);
assertPositiveResult(0);
assertPositiveResult(min);
assertPositiveResult(max);
}
@DontCompile
public void assertResult(int a) {
Asserts.assertEQ(Math.max(((a >> 1) + 100), Math.max(((a >> 1) + 150), 200)), testMax1(a));
Asserts.assertEQ(Math.max(((a >> 1) + 10), ((a >> 1) + 11)) , testMax2(a));
public void assertPositiveResult(int a) {
Asserts.assertEQ(Math.max(Math.max(((a >> 1) + 150), 200), ((a >> 1) + 100)), testMax1LL(a));
Asserts.assertEQ(testMax1LL(a) , testMax1LR(a));
Asserts.assertEQ(testMax1LL(a) , testMax1RL(a));
Asserts.assertEQ(testMax1LL(a) , testMax1RR(a));
Asserts.assertEQ(Math.max(Math.max((a >> 1), 200), (a >> 1) + 100) , testMax1LLNoInnerAdd(a));
Asserts.assertEQ(Math.max(Math.max((a >> 1), (a << 1)), (a >> 1) + 100) , testMax1LLNoInnerAdd2(a));
Asserts.assertEQ(Math.max(Math.max(((a >> 1) + 150), 200), a >> 1) , testMax1LLNoOuterAdd(a));
Asserts.assertEQ(Math.max(Math.max((a >> 1), 200), a >> 1) , testMax1LLNoAdd(a));
Asserts.assertEQ(Math.max(((a >> 1) + 10), ((a >> 1) + 11)) , testMax2L(a));
Asserts.assertEQ(testMax2L(a) , testMax2R(a));
Asserts.assertEQ(Math.max(a >> 1, ((a >> 1) + 11)) , testMax2LNoLeftAdd(a));
Asserts.assertEQ(Math.max(a, a) , testMax3(a));
Asserts.assertEQ(Math.min(((a >> 1) + 100), Math.min(((a >> 1) + 150), 200)), testMin1(a));
@ -72,10 +89,65 @@ public class MaxMinINodeIdealizationTests {
@IR(counts = {IRNode.MAX_I, "1",
IRNode.ADD , "1",
})
public int testMax1(int i) {
public int testMax1LL(int i) {
return Math.max(Math.max(((i >> 1) + 150), 200), ((i >> 1) + 100));
}
@Test
@IR(counts = {IRNode.MAX_I, "1",
IRNode.ADD , "1",
})
public int testMax1LR(int i) {
return Math.max(Math.max(200, ((i >> 1) + 150)), ((i >> 1) + 100));
}
@Test
@IR(counts = {IRNode.MAX_I, "1",
IRNode.ADD , "1",
})
public int testMax1RL(int i) {
return Math.max(((i >> 1) + 100), Math.max(((i >> 1) + 150), 200));
}
@Test
@IR(counts = {IRNode.MAX_I, "1",
IRNode.ADD , "1",
})
public int testMax1RR(int i) {
return Math.max(((i >> 1) + 100), Math.max(200, ((i >> 1) + 150)));
}
@Test
@IR(counts = {IRNode.MAX_I, "1",
IRNode.ADD , "1",
})
public int testMax1LLNoInnerAdd(int i) {
return Math.max(Math.max((i >> 1), 200), (i >> 1) + 100);
}
@Test
@IR(counts = {IRNode.MAX_I, "1",
IRNode.ADD , "1",
})
public int testMax1LLNoInnerAdd2(int i) {
return Math.max(Math.max((i >> 1), (i << 1)), (i >> 1) + 100);
}
@Test
@IR(counts = {IRNode.MAX_I, "1",
IRNode.ADD , "1",
})
public int testMax1LLNoOuterAdd(int i) {
return Math.max(Math.max(((i >> 1) + 150), 200), i >> 1);
}
@Test
@IR(failOn = {IRNode.ADD})
@IR(counts = {IRNode.MAX_I, "1"})
public int testMax1LLNoAdd(int i) {
return Math.max(Math.max((i >> 1), 200), i >> 1);
}
// Similarly, transform min(x + c0, min(y + c1, z)) to min(add(x, c2), z) if x == y, where c2 = MIN2(c0, c1).
@Test
@IR(counts = {IRNode.MIN_I, "1",
@ -91,10 +163,24 @@ public class MaxMinINodeIdealizationTests {
@Test
@IR(failOn = {IRNode.MAX_I})
@IR(counts = {IRNode.ADD, "1"})
public int testMax2(int i) {
public int testMax2L(int i) {
return Math.max((i >> 1) + 10, (i >> 1) + 11);
}
@Test
@IR(failOn = {IRNode.MAX_I})
@IR(counts = {IRNode.ADD, "1"})
public int testMax2R(int i) {
return Math.max((i >> 1) + 11, (i >> 1) + 10);
}
@Test
@IR(failOn = {IRNode.MAX_I})
@IR(counts = {IRNode.ADD, "1"})
public int testMax2LNoLeftAdd(int i) {
return Math.max(i >> 1, (i >> 1) + 11);
}
// Similarly, transform min(x + c0, y + c1) to add(x, c2) if x == y, where c2 = MIN2(c0, c1).
@Test
@IR(failOn = {IRNode.MIN_I})
@ -116,4 +202,76 @@ public class MaxMinINodeIdealizationTests {
public int testMin3(int i) {
return Math.min(i, i);
}
@Run(test = {"testTwoLevelsDifferentXY",
"testTwoLevelsNoLeftConstant",
"testTwoLevelsNoRightConstant",
"testDifferentXY",
"testNoLeftConstant",
"testNoRightConstant"})
public void runNegativeTests() {
int a = RunInfo.getRandom().nextInt();
int min = Integer.MIN_VALUE;
int max = Integer.MAX_VALUE;
assertNegativeResult(a);
assertNegativeResult(0);
assertNegativeResult(min);
assertNegativeResult(max);
testTwoLevelsDifferentXY(10);
testTwoLevelsNoLeftConstant(10, 42);
testTwoLevelsNoRightConstant(10, 42);
testDifferentXY(10);
testNoLeftConstant(10, 42);
testNoRightConstant(10, 42);
}
@DontCompile
public void assertNegativeResult(int a) {
Asserts.assertEQ(Math.max(Math.max(((a >> 1) + 150), 200), ((a >> 2) + 100)), testTwoLevelsDifferentXY(a));
Asserts.assertEQ(Math.max(Math.max(((a >> 1) + a*2), 200), ((a >> 1) + 100)), testTwoLevelsNoLeftConstant(a, a*2));
Asserts.assertEQ(Math.max(Math.max(((a >> 1) + 150), 200), ((a >> 1) + a*2)), testTwoLevelsNoRightConstant(a, a*2));
Asserts.assertEQ(Math.max((a >> 1) + 10, (a >> 2) + 11), testDifferentXY(a));
Asserts.assertEQ(Math.max((a >> 1) + a*2, (a >> 1) + 11), testNoLeftConstant(a, a*2));
Asserts.assertEQ(Math.max((a >> 1) + 10, (a >> 1) + a*2), testNoRightConstant(a, a*2));
}
@Test
@IR(counts = {IRNode.MAX_I, "2"})
public int testTwoLevelsDifferentXY(int i) {
return Math.max(Math.max(((i >> 1) + 150), 200), ((i >> 2) + 100));
}
@Test
@IR(counts = {IRNode.MAX_I, "2"})
public int testTwoLevelsNoLeftConstant(int i, int c0) {
return Math.max(Math.max(((i >> 1) + c0), 200), ((i >> 1) + 100));
}
@Test
@IR(counts = {IRNode.MAX_I, "2"})
public int testTwoLevelsNoRightConstant(int i, int c1) {
return Math.max(Math.max(((i >> 1) + 150), 200), ((i >> 1) + c1));
}
@Test
@IR(counts = {IRNode.MAX_I, "1"})
public int testDifferentXY(int i) {
return Math.max((i >> 1) + 10, (i >> 2) + 11);
}
@Test
@IR(counts = {IRNode.MAX_I, "1"})
public int testNoLeftConstant(int i, int c0) {
return Math.max((i >> 1) + c0, (i >> 1) + 11);
}
@Test
@IR(counts = {IRNode.MAX_I, "1"})
public int testNoRightConstant(int i, int c1) {
return Math.max((i >> 1) + 10, (i >> 1) + c1);
}
}

View File

@ -794,6 +794,16 @@ public class IRNode {
superWordNodes(MUL_REDUCTION_VL, "MulReductionVL");
}
public static final String MIN_REDUCTION_V = PREFIX + "MIN_REDUCTION_V" + POSTFIX;
static {
superWordNodes(MIN_REDUCTION_V, "MinReductionV");
}
public static final String MAX_REDUCTION_V = PREFIX + "MAX_REDUCTION_V" + POSTFIX;
static {
superWordNodes(MAX_REDUCTION_V, "MaxReductionV");
}
public static final String NEG_V = PREFIX + "NEG_V" + POSTFIX;
static {
beforeMatchingNameRegex(NEG_V, "NegV(F|D)");

View File

@ -0,0 +1,113 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/**
* @test
* @bug 8302673
* @summary [SuperWord] MaxReduction and MinReduction should vectorize for int
* @library /test/lib /
* @run driver compiler.loopopts.superword.MinMaxRed_Int
*/
package compiler.loopopts.superword;
import compiler.lib.ir_framework.*;
import java.util.Arrays;
import java.util.Random;
import jdk.test.lib.Utils;
public class MinMaxRed_Int {
private static final Random random = Utils.getRandomInstance();
public static void main(String[] args) throws Exception {
TestFramework framework = new TestFramework();
framework.addFlags("-XX:+IgnoreUnrecognizedVMOptions",
"-XX:LoopUnrollLimit=250",
"-XX:CompileThresholdScaling=0.1");
framework.start();
}
@Run(test = {"maxReductionImplement"},
mode = RunMode.STANDALONE)
public void runMaxTest() {
int[] a = new int[1024];
int[] b = new int[1024];
ReductionInit(a, b);
int res = 0;
for (int j = 0; j < 2000; j++) {
res = maxReductionImplement(a, b, res);
}
if (res == Arrays.stream(a).max().getAsInt()) {
System.out.println("Success");
} else {
throw new AssertionError("Failed");
}
}
@Run(test = {"minReductionImplement"},
mode = RunMode.STANDALONE)
public void runMinTest() {
int[] a = new int[1024];
int[] b = new int[1024];
ReductionInit(a, b);
int res = 1;
for (int j = 0; j < 2000; j++) {
res = minReductionImplement(a, b, res);
}
if (res == Arrays.stream(a).min().getAsInt()) {
System.out.println("Success");
} else {
throw new AssertionError("Failed");
}
}
public static void ReductionInit(int[] a, int[] b) {
for (int i = 0; i < a.length; i++) {
a[i] = random.nextInt();
b[i] = 1;
}
}
@Test
@IR(applyIf = {"SuperWordReductions", "true"},
applyIfCPUFeatureOr = { "sse4.1", "true" , "asimd" , "true"},
counts = {IRNode.MIN_REDUCTION_V, " > 0"})
public static int minReductionImplement(int[] a, int[] b, int res) {
for (int i = 0; i < a.length; i++) {
res = Math.min(res, a[i] * b[i]);
}
return res;
}
@Test
@IR(applyIf = {"SuperWordReductions", "true"},
applyIfCPUFeatureOr = { "sse4.1", "true" , "asimd" , "true"},
counts = {IRNode.MAX_REDUCTION_V, " > 0"})
public static int maxReductionImplement(int[] a, int[] b, int res) {
for (int i = 0; i < a.length; i++) {
res = Math.max(res, a[i] * b[i]);
}
return res;
}
}