jdk-24/test/hotspot/jtreg/compiler/loopopts/superword/ReductionPerf.java
Emanuel Peter 06b0a5e038 8302652: [SuperWord] Reduction should happen after loop, when possible
Reviewed-by: kvn, pli, jbhateja, sviswanathan
2023-05-23 08:05:13 +00:00

591 lines
22 KiB
Java

/*
* Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @bug 8074981 8302652
* @summary Test SuperWord Reduction Perf.
* @requires vm.compiler2.enabled
* @requires vm.simpleArch == "x86" | vm.simpleArch == "x64" | vm.simpleArch == "aarch64" | vm.simpleArch == "riscv64"
* @library /test/lib /
* @run main/othervm -Xbatch -XX:LoopUnrollLimit=250
* -XX:CompileCommand=exclude,compiler.loopopts.superword.ReductionPerf::main
* compiler.loopopts.superword.ReductionPerf
*/
package compiler.loopopts.superword;
import java.util.Random;
import jdk.test.lib.Utils;
/**
 * Performance and correctness driver for scalar reduction loops over
 * {@code int}, {@code long}, {@code float} and {@code double} arrays.
 *
 * <p>For every (type, operation) pair, {@link #main} does the following:
 * <ol>
 *   <li>fills the three input arrays with fresh random values and draws a
 *       random start value for the accumulator ({@code init});</li>
 *   <li>runs the kernel once and keeps its result as the reference
 *       ("gold") value;</li>
 *   <li>runs the kernel {@code iter_warmup} more times and checks each
 *       result against the gold value ({@code verify});</li>
 *   <li>runs the kernel {@code iter_perf} times and prints the elapsed
 *       time in milliseconds.</li>
 * </ol>
 */
public class ReductionPerf {
    /** Number of elements in every input array, and the trip count of every kernel loop. */
    static final int RANGE = 8192;

    /** Random source for all inputs; obtained from the jtreg test library. */
    static Random rand = Utils.getRandomInstance();

    /** Conversion factor from {@link System#nanoTime()} units to milliseconds. */
    private static final long NANOS_PER_MILLI = 1_000_000L;

    /**
     * Runs every reduction kernel through the verify-then-measure sequence
     * described on the class and prints one line per kernel in the form
     * {@code "<type> <op> <elapsed ms>"}.
     *
     * <p>The {@code @run} line in this file's jtreg header passes
     * {@code -XX:CompileCommand=exclude} for this method, so it is never
     * JIT-compiled.
     * NOTE(review): presumably this keeps the kernels compiled as stand-alone
     * methods instead of being inlined into the timing loops (where their
     * unused results could be optimized away) - confirm before restructuring
     * this method into shared helpers.
     *
     * @param args ignored
     * @throws RuntimeException if a warm-up result differs from the gold value
     */
    public static void main(String[] args) {
        // The defaults keep the regression run short. For an actual
        // measurement, raise iter_warmup to 2_000 and iter_perf to 100_000.
        int iter_warmup = 100;
        int iter_perf = 1_000;

        double[] aDouble = new double[RANGE];
        double[] bDouble = new double[RANGE];
        double[] cDouble = new double[RANGE];

        float[] aFloat = new float[RANGE];
        float[] bFloat = new float[RANGE];
        float[] cFloat = new float[RANGE];

        int[] aInt = new int[RANGE];
        int[] bInt = new int[RANGE];
        int[] cInt = new int[RANGE];

        long[] aLong = new long[RANGE];
        long[] bLong = new long[RANGE];
        long[] cLong = new long[RANGE];

        // Timestamp from System.nanoTime(): unlike System.currentTimeMillis()
        // it is intended for measuring elapsed time and is not tied to the
        // wall clock.
        long start;

        // ---------------- int ----------------
        int startIntAdd = init(aInt, bInt, cInt);
        int goldIntAdd = testIntAdd(aInt, bInt, cInt, startIntAdd);
        for (int j = 0; j < iter_warmup; j++) {
            int total = testIntAdd(aInt, bInt, cInt, startIntAdd);
            verify("int add", total, goldIntAdd);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testIntAdd(aInt, bInt, cInt, startIntAdd);
        }
        System.out.println("int add " + elapsedMillis(start));

        int startIntMul = init(aInt, bInt, cInt);
        int goldIntMul = testIntMul(aInt, bInt, cInt, startIntMul);
        for (int j = 0; j < iter_warmup; j++) {
            int total = testIntMul(aInt, bInt, cInt, startIntMul);
            verify("int mul", total, goldIntMul);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testIntMul(aInt, bInt, cInt, startIntMul);
        }
        System.out.println("int mul " + elapsedMillis(start));

        int startIntMin = init(aInt, bInt, cInt);
        int goldIntMin = testIntMin(aInt, bInt, cInt, startIntMin);
        for (int j = 0; j < iter_warmup; j++) {
            int total = testIntMin(aInt, bInt, cInt, startIntMin);
            verify("int min", total, goldIntMin);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testIntMin(aInt, bInt, cInt, startIntMin);
        }
        System.out.println("int min " + elapsedMillis(start));

        int startIntMax = init(aInt, bInt, cInt);
        int goldIntMax = testIntMax(aInt, bInt, cInt, startIntMax);
        for (int j = 0; j < iter_warmup; j++) {
            int total = testIntMax(aInt, bInt, cInt, startIntMax);
            verify("int max", total, goldIntMax);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testIntMax(aInt, bInt, cInt, startIntMax);
        }
        System.out.println("int max " + elapsedMillis(start));

        int startIntAnd = init(aInt, bInt, cInt);
        int goldIntAnd = testIntAnd(aInt, bInt, cInt, startIntAnd);
        for (int j = 0; j < iter_warmup; j++) {
            int total = testIntAnd(aInt, bInt, cInt, startIntAnd);
            verify("int and", total, goldIntAnd);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testIntAnd(aInt, bInt, cInt, startIntAnd);
        }
        System.out.println("int and " + elapsedMillis(start));

        int startIntOr = init(aInt, bInt, cInt);
        int goldIntOr = testIntOr(aInt, bInt, cInt, startIntOr);
        for (int j = 0; j < iter_warmup; j++) {
            int total = testIntOr(aInt, bInt, cInt, startIntOr);
            verify("int or", total, goldIntOr);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testIntOr(aInt, bInt, cInt, startIntOr);
        }
        System.out.println("int or " + elapsedMillis(start));

        int startIntXor = init(aInt, bInt, cInt);
        int goldIntXor = testIntXor(aInt, bInt, cInt, startIntXor);
        for (int j = 0; j < iter_warmup; j++) {
            int total = testIntXor(aInt, bInt, cInt, startIntXor);
            verify("int xor", total, goldIntXor);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testIntXor(aInt, bInt, cInt, startIntXor);
        }
        System.out.println("int xor " + elapsedMillis(start));

        // ---------------- long ----------------
        long startLongAdd = init(aLong, bLong, cLong);
        long goldLongAdd = testLongAdd(aLong, bLong, cLong, startLongAdd);
        for (int j = 0; j < iter_warmup; j++) {
            long total = testLongAdd(aLong, bLong, cLong, startLongAdd);
            verify("long add", total, goldLongAdd);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testLongAdd(aLong, bLong, cLong, startLongAdd);
        }
        System.out.println("long add " + elapsedMillis(start));

        long startLongMul = init(aLong, bLong, cLong);
        long goldLongMul = testLongMul(aLong, bLong, cLong, startLongMul);
        for (int j = 0; j < iter_warmup; j++) {
            long total = testLongMul(aLong, bLong, cLong, startLongMul);
            verify("long mul", total, goldLongMul);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testLongMul(aLong, bLong, cLong, startLongMul);
        }
        System.out.println("long mul " + elapsedMillis(start));

        long startLongMin = init(aLong, bLong, cLong);
        long goldLongMin = testLongMin(aLong, bLong, cLong, startLongMin);
        for (int j = 0; j < iter_warmup; j++) {
            long total = testLongMin(aLong, bLong, cLong, startLongMin);
            verify("long min", total, goldLongMin);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testLongMin(aLong, bLong, cLong, startLongMin);
        }
        System.out.println("long min " + elapsedMillis(start));

        long startLongMax = init(aLong, bLong, cLong);
        long goldLongMax = testLongMax(aLong, bLong, cLong, startLongMax);
        for (int j = 0; j < iter_warmup; j++) {
            long total = testLongMax(aLong, bLong, cLong, startLongMax);
            verify("long max", total, goldLongMax);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testLongMax(aLong, bLong, cLong, startLongMax);
        }
        System.out.println("long max " + elapsedMillis(start));

        long startLongAnd = init(aLong, bLong, cLong);
        long goldLongAnd = testLongAnd(aLong, bLong, cLong, startLongAnd);
        for (int j = 0; j < iter_warmup; j++) {
            long total = testLongAnd(aLong, bLong, cLong, startLongAnd);
            verify("long and", total, goldLongAnd);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testLongAnd(aLong, bLong, cLong, startLongAnd);
        }
        System.out.println("long and " + elapsedMillis(start));

        long startLongOr = init(aLong, bLong, cLong);
        long goldLongOr = testLongOr(aLong, bLong, cLong, startLongOr);
        for (int j = 0; j < iter_warmup; j++) {
            long total = testLongOr(aLong, bLong, cLong, startLongOr);
            verify("long or", total, goldLongOr);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testLongOr(aLong, bLong, cLong, startLongOr);
        }
        System.out.println("long or " + elapsedMillis(start));

        long startLongXor = init(aLong, bLong, cLong);
        long goldLongXor = testLongXor(aLong, bLong, cLong, startLongXor);
        for (int j = 0; j < iter_warmup; j++) {
            long total = testLongXor(aLong, bLong, cLong, startLongXor);
            verify("long xor", total, goldLongXor);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testLongXor(aLong, bLong, cLong, startLongXor);
        }
        System.out.println("long xor " + elapsedMillis(start));

        // ---------------- float ----------------
        float startFloatAdd = init(aFloat, bFloat, cFloat);
        float goldFloatAdd = testFloatAdd(aFloat, bFloat, cFloat, startFloatAdd);
        for (int j = 0; j < iter_warmup; j++) {
            float total = testFloatAdd(aFloat, bFloat, cFloat, startFloatAdd);
            verify("float add", total, goldFloatAdd);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testFloatAdd(aFloat, bFloat, cFloat, startFloatAdd);
        }
        System.out.println("float add " + elapsedMillis(start));

        float startFloatMul = init(aFloat, bFloat, cFloat);
        float goldFloatMul = testFloatMul(aFloat, bFloat, cFloat, startFloatMul);
        for (int j = 0; j < iter_warmup; j++) {
            float total = testFloatMul(aFloat, bFloat, cFloat, startFloatMul);
            verify("float mul", total, goldFloatMul);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testFloatMul(aFloat, bFloat, cFloat, startFloatMul);
        }
        System.out.println("float mul " + elapsedMillis(start));

        float startFloatMin = init(aFloat, bFloat, cFloat);
        float goldFloatMin = testFloatMin(aFloat, bFloat, cFloat, startFloatMin);
        for (int j = 0; j < iter_warmup; j++) {
            float total = testFloatMin(aFloat, bFloat, cFloat, startFloatMin);
            verify("float min", total, goldFloatMin);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testFloatMin(aFloat, bFloat, cFloat, startFloatMin);
        }
        System.out.println("float min " + elapsedMillis(start));

        float startFloatMax = init(aFloat, bFloat, cFloat);
        float goldFloatMax = testFloatMax(aFloat, bFloat, cFloat, startFloatMax);
        for (int j = 0; j < iter_warmup; j++) {
            float total = testFloatMax(aFloat, bFloat, cFloat, startFloatMax);
            verify("float max", total, goldFloatMax);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testFloatMax(aFloat, bFloat, cFloat, startFloatMax);
        }
        System.out.println("float max " + elapsedMillis(start));

        // ---------------- double ----------------
        double startDoubleAdd = init(aDouble, bDouble, cDouble);
        double goldDoubleAdd = testDoubleAdd(aDouble, bDouble, cDouble, startDoubleAdd);
        for (int j = 0; j < iter_warmup; j++) {
            double total = testDoubleAdd(aDouble, bDouble, cDouble, startDoubleAdd);
            verify("double add", total, goldDoubleAdd);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testDoubleAdd(aDouble, bDouble, cDouble, startDoubleAdd);
        }
        System.out.println("double add " + elapsedMillis(start));

        double startDoubleMul = init(aDouble, bDouble, cDouble);
        double goldDoubleMul = testDoubleMul(aDouble, bDouble, cDouble, startDoubleMul);
        for (int j = 0; j < iter_warmup; j++) {
            double total = testDoubleMul(aDouble, bDouble, cDouble, startDoubleMul);
            verify("double mul", total, goldDoubleMul);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testDoubleMul(aDouble, bDouble, cDouble, startDoubleMul);
        }
        System.out.println("double mul " + elapsedMillis(start));

        double startDoubleMin = init(aDouble, bDouble, cDouble);
        double goldDoubleMin = testDoubleMin(aDouble, bDouble, cDouble, startDoubleMin);
        for (int j = 0; j < iter_warmup; j++) {
            double total = testDoubleMin(aDouble, bDouble, cDouble, startDoubleMin);
            verify("double min", total, goldDoubleMin);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testDoubleMin(aDouble, bDouble, cDouble, startDoubleMin);
        }
        System.out.println("double min " + elapsedMillis(start));

        double startDoubleMax = init(aDouble, bDouble, cDouble);
        double goldDoubleMax = testDoubleMax(aDouble, bDouble, cDouble, startDoubleMax);
        for (int j = 0; j < iter_warmup; j++) {
            double total = testDoubleMax(aDouble, bDouble, cDouble, startDoubleMax);
            verify("double max", total, goldDoubleMax);
        }
        start = System.nanoTime();
        for (int j = 0; j < iter_perf; j++) {
            testDoubleMax(aDouble, bDouble, cDouble, startDoubleMax);
        }
        System.out.println("double max " + elapsedMillis(start));
    }

    /**
     * Returns the whole milliseconds elapsed since {@code startNanos}.
     *
     * @param startNanos a value previously returned by {@link System#nanoTime()}
     * @return elapsed time in milliseconds, truncated
     */
    private static long elapsedMillis(long startNanos) {
        return (System.nanoTime() - startNanos) / NANOS_PER_MILLI;
    }

    // ------------------- Tests -------------------
    //
    // Every kernel below has the same shape: for each i in [0, RANGE) it
    // computes v = a[i]*b[i] + a[i]*c[i] + b[i]*c[i] and folds v into the
    // accumulator "total" with one reduction operator, then returns the
    // accumulator. Only the element type and the operator differ.
    //
    // NOTE(review): per the @summary tag these loops are what the test
    // measures; presumably their exact shape should be left alone so that
    // results stay comparable.

    /** Int kernel folding with {@code +=}; returns the final accumulator. */
    static int testIntAdd(int[] a, int[] b, int[] c, int total) {
        for (int i = 0; i < RANGE; i++) {
            int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total += v;
        }
        return total;
    }

    /** Int kernel folding with {@code *=}; returns the final accumulator. */
    static int testIntMul(int[] a, int[] b, int[] c, int total) {
        for (int i = 0; i < RANGE; i++) {
            int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total *= v;
        }
        return total;
    }

    /** Int kernel folding with {@link Math#min(int, int)}; returns the final accumulator. */
    static int testIntMin(int[] a, int[] b, int[] c, int total) {
        for (int i = 0; i < RANGE; i++) {
            int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total = Math.min(total, v);
        }
        return total;
    }

    /** Int kernel folding with {@link Math#max(int, int)}; returns the final accumulator. */
    static int testIntMax(int[] a, int[] b, int[] c, int total) {
        for (int i = 0; i < RANGE; i++) {
            int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total = Math.max(total, v);
        }
        return total;
    }

    /** Int kernel folding with {@code &=}; returns the final accumulator. */
    static int testIntAnd(int[] a, int[] b, int[] c, int total) {
        for (int i = 0; i < RANGE; i++) {
            int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total &= v;
        }
        return total;
    }

    /** Int kernel folding with {@code |=}; returns the final accumulator. */
    static int testIntOr(int[] a, int[] b, int[] c, int total) {
        for (int i = 0; i < RANGE; i++) {
            int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total |= v;
        }
        return total;
    }

    /** Int kernel folding with {@code ^=}; returns the final accumulator. */
    static int testIntXor(int[] a, int[] b, int[] c, int total) {
        for (int i = 0; i < RANGE; i++) {
            int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total ^= v;
        }
        return total;
    }

    /** Long kernel folding with {@code +=}; returns the final accumulator. */
    static long testLongAdd(long[] a, long[] b, long[] c, long total) {
        for (int i = 0; i < RANGE; i++) {
            long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total += v;
        }
        return total;
    }

    /** Long kernel folding with {@code *=}; returns the final accumulator. */
    static long testLongMul(long[] a, long[] b, long[] c, long total) {
        for (int i = 0; i < RANGE; i++) {
            long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total *= v;
        }
        return total;
    }

    /** Long kernel folding with {@link Math#min(long, long)}; returns the final accumulator. */
    static long testLongMin(long[] a, long[] b, long[] c, long total) {
        for (int i = 0; i < RANGE; i++) {
            long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total = Math.min(total, v);
        }
        return total;
    }

    /** Long kernel folding with {@link Math#max(long, long)}; returns the final accumulator. */
    static long testLongMax(long[] a, long[] b, long[] c, long total) {
        for (int i = 0; i < RANGE; i++) {
            long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total = Math.max(total, v);
        }
        return total;
    }

    /** Long kernel folding with {@code &=}; returns the final accumulator. */
    static long testLongAnd(long[] a, long[] b, long[] c, long total) {
        for (int i = 0; i < RANGE; i++) {
            long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total &= v;
        }
        return total;
    }

    /** Long kernel folding with {@code |=}; returns the final accumulator. */
    static long testLongOr(long[] a, long[] b, long[] c, long total) {
        for (int i = 0; i < RANGE; i++) {
            long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total |= v;
        }
        return total;
    }

    /** Long kernel folding with {@code ^=}; returns the final accumulator. */
    static long testLongXor(long[] a, long[] b, long[] c, long total) {
        for (int i = 0; i < RANGE; i++) {
            long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total ^= v;
        }
        return total;
    }

    /** Float kernel folding with {@code +=}; returns the final accumulator. */
    static float testFloatAdd(float[] a, float[] b, float[] c, float total) {
        for (int i = 0; i < RANGE; i++) {
            float v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total += v;
        }
        return total;
    }

    /** Float kernel folding with {@code *=}; returns the final accumulator. */
    static float testFloatMul(float[] a, float[] b, float[] c, float total) {
        for (int i = 0; i < RANGE; i++) {
            float v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total *= v;
        }
        return total;
    }

    /** Float kernel folding with {@link Math#min(float, float)}; returns the final accumulator. */
    static float testFloatMin(float[] a, float[] b, float[] c, float total) {
        for (int i = 0; i < RANGE; i++) {
            float v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total = Math.min(total, v);
        }
        return total;
    }

    /** Float kernel folding with {@link Math#max(float, float)}; returns the final accumulator. */
    static float testFloatMax(float[] a, float[] b, float[] c, float total) {
        for (int i = 0; i < RANGE; i++) {
            float v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total = Math.max(total, v);
        }
        return total;
    }

    /** Double kernel folding with {@code +=}; returns the final accumulator. */
    static double testDoubleAdd(double[] a, double[] b, double[] c, double total) {
        for (int i = 0; i < RANGE; i++) {
            double v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total += v;
        }
        return total;
    }

    /** Double kernel folding with {@code *=}; returns the final accumulator. */
    static double testDoubleMul(double[] a, double[] b, double[] c, double total) {
        for (int i = 0; i < RANGE; i++) {
            double v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total *= v;
        }
        return total;
    }

    /** Double kernel folding with {@link Math#min(double, double)}; returns the final accumulator. */
    static double testDoubleMin(double[] a, double[] b, double[] c, double total) {
        for (int i = 0; i < RANGE; i++) {
            double v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total = Math.min(total, v);
        }
        return total;
    }

    /** Double kernel folding with {@link Math#max(double, double)}; returns the final accumulator. */
    static double testDoubleMax(double[] a, double[] b, double[] c, double total) {
        for (int i = 0; i < RANGE; i++) {
            double v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]);
            total = Math.max(total, v);
        }
        return total;
    }

    // ------------------- Initialization -------------------
    //
    // Each init overload overwrites the first RANGE elements of a, b and c
    // with values drawn from "rand" (in the order a[j], b[j], c[j] for each
    // j) and then returns one more random value, which main uses as the
    // initial accumulator for the kernel.

    /**
     * Fills the three arrays with {@link Random#nextInt()} values.
     *
     * @return a further random int, used by the caller as the start value
     */
    static int init(int[] a, int[] b, int[] c) {
        for (int j = 0; j < RANGE; j++) {
            a[j] = rand.nextInt();
            b[j] = rand.nextInt();
            c[j] = rand.nextInt();
        }
        return rand.nextInt();
    }

    /**
     * Fills the three arrays with {@link Random#nextLong()} values.
     *
     * @return a further random long, used by the caller as the start value
     */
    static long init(long[] a, long[] b, long[] c) {
        for (int j = 0; j < RANGE; j++) {
            a[j] = rand.nextLong();
            b[j] = rand.nextLong();
            c[j] = rand.nextLong();
        }
        return rand.nextLong();
    }

    /**
     * Fills the three arrays with {@link Random#nextFloat()} values.
     *
     * @return a further random float, used by the caller as the start value
     */
    static float init(float[] a, float[] b, float[] c) {
        for (int j = 0; j < RANGE; j++) {
            a[j] = rand.nextFloat();
            b[j] = rand.nextFloat();
            c[j] = rand.nextFloat();
        }
        return rand.nextFloat();
    }

    /**
     * Fills the three arrays with {@link Random#nextDouble()} values.
     *
     * @return a further random double, used by the caller as the start value
     */
    static double init(double[] a, double[] b, double[] c) {
        for (int j = 0; j < RANGE; j++) {
            a[j] = rand.nextDouble();
            b[j] = rand.nextDouble();
            c[j] = rand.nextDouble();
        }
        return rand.nextDouble();
    }

    // ------------------- Verification -------------------

    /**
     * Checks a double result against its reference value.
     *
     * <p>Values match when they compare equal with {@code ==}, or when both
     * are NaN. The NaN case is handled explicitly because {@code NaN != NaN}
     * is always true, so a NaN reference could otherwise never be verified.
     *
     * @param context label of the kernel, included in the failure message
     * @param total   the result to check
     * @param gold    the reference result
     * @throws RuntimeException if the values do not match
     */
    static void verify(String context, double total, double gold) {
        if (total != gold && !(Double.isNaN(total) && Double.isNaN(gold))) {
            throw new RuntimeException("Wrong result for " + context + ": " + total + " != " + gold);
        }
    }

    /**
     * Checks a float result against its reference value.
     *
     * <p>Values match when they compare equal with {@code ==}, or when both
     * are NaN. The NaN case is handled explicitly because {@code NaN != NaN}
     * is always true, so a NaN reference could otherwise never be verified.
     *
     * @param context label of the kernel, included in the failure message
     * @param total   the result to check
     * @param gold    the reference result
     * @throws RuntimeException if the values do not match
     */
    static void verify(String context, float total, float gold) {
        if (total != gold && !(Float.isNaN(total) && Float.isNaN(gold))) {
            throw new RuntimeException("Wrong result for " + context + ": " + total + " != " + gold);
        }
    }

    /**
     * Checks an int result against its reference value.
     *
     * @param context label of the kernel, included in the failure message
     * @param total   the result to check
     * @param gold    the reference result
     * @throws RuntimeException if {@code total != gold}
     */
    static void verify(String context, int total, int gold) {
        if (total != gold) {
            throw new RuntimeException("Wrong result for " + context + ": " + total + " != " + gold);
        }
    }

    /**
     * Checks a long result against its reference value.
     *
     * @param context label of the kernel, included in the failure message
     * @param total   the result to check
     * @param gold    the reference result
     * @throws RuntimeException if {@code total != gold}
     */
    static void verify(String context, long total, long gold) {
        if (total != gold) {
            throw new RuntimeException("Wrong result for " + context + ": " + total + " != " + gold);
        }
    }
}