diff --git a/src/hotspot/share/opto/superword.cpp b/src/hotspot/share/opto/superword.cpp index e71a6026209..5cd4341c42d 100644 --- a/src/hotspot/share/opto/superword.cpp +++ b/src/hotspot/share/opto/superword.cpp @@ -2794,11 +2794,7 @@ bool SuperWord::is_vector_use(Node* use, int u_idx) const { } if (VectorNode::is_muladds2i(use)) { - // MulAddS2I takes shorts and produces ints. - if (u_pk->size() * 2 != d_pk->size()) { - return false; - } - return true; + return _packset.is_muladds2i_pack_with_pack_inputs(u_pk); } if (u_pk->size() != d_pk->size()) { @@ -2815,6 +2811,59 @@ bool SuperWord::is_vector_use(Node* use, int u_idx) const { return true; } +// MulAddS2I takes 4 shorts and produces an int. We can reinterpret +// the 4 shorts as two ints: a = (a0, a1) and b = (b0, b1). +// +// Inputs: 1 2 3 4 +// Offsets: 0 0 1 1 +// v = MulAddS2I(a, b) = a0 * b0 + a1 * b1 +// +// But permutations are possible, because add and mul are commutative. For +// simplicity, the first input is always either a0 or a1. These are all +// the possible permutations: +// +// v = MulAddS2I(a, b) = a0 * b0 + a1 * b1 (case 1) +// v = MulAddS2I(a, b) = a0 * b0 + b1 * a1 (case 2) +// v = MulAddS2I(a, b) = a1 * b1 + a0 * b0 (case 3) +// v = MulAddS2I(a, b) = a1 * b1 + b0 * a0 (case 4) +// +// To vectorize, we expect (a0, a1) to be consecutive in one input pack, +// and (b0, b1) in the other input pack. Thus, both a and b are strided, +// with stride = 2. Further, a0 and b0 have offset 0, whereas a1 and b1 +// have offset 1. +bool PackSet::is_muladds2i_pack_with_pack_inputs(const Node_List* pack) const { + assert(VectorNode::is_muladds2i(pack->at(0)), "must be MulAddS2I"); + + bool pack1_has_offset_0 = (strided_pack_input_at_index_or_null(pack, 1, 2, 0) != nullptr); + Node_List* pack1 = strided_pack_input_at_index_or_null(pack, 1, 2, pack1_has_offset_0 ? 0 : 1); + Node_List* pack2 = strided_pack_input_at_index_or_null(pack, 2, 2, pack1_has_offset_0 ? 0 : 1); + Node_List* pack3 = strided_pack_input_at_index_or_null(pack, 3, 2, pack1_has_offset_0 ? 1 : 0); + Node_List* pack4 = strided_pack_input_at_index_or_null(pack, 4, 2, pack1_has_offset_0 ? 1 : 0); + + return pack1 != nullptr && + pack2 != nullptr && + pack3 != nullptr && + pack4 != nullptr && + ((pack1 == pack3 && pack2 == pack4) || // case 1 or 3 + (pack1 == pack4 && pack2 == pack3)); // case 2 or 4 +} + +Node_List* PackSet::strided_pack_input_at_index_or_null(const Node_List* pack, const int index, const int stride, const int offset) const { + Node* def0 = pack->at(0)->in(index); + + Node_List* pack_in = get_pack(def0); + if (pack_in == nullptr || pack->size() * stride != pack_in->size()) { + return nullptr; // size mismatch + } + + for (uint i = 1; i < pack->size(); i++) { + if (pack->at(i)->in(index) != pack_in->at(i * stride + offset)) { + return nullptr; // use-def mismatch + } + } + return pack_in; +} + // Check if the output type of def is compatible with the input type of use, i.e. if the // types have the same size. bool SuperWord::is_velt_basic_type_compatible_use_def(Node* use, Node* def) const { diff --git a/src/hotspot/share/opto/superword.hpp b/src/hotspot/share/opto/superword.hpp index c118b420117..fb91d014fae 100644 --- a/src/hotspot/share/opto/superword.hpp +++ b/src/hotspot/share/opto/superword.hpp @@ -362,8 +362,9 @@ public: } } + Node_List* strided_pack_input_at_index_or_null(const Node_List* pack, const int index, const int stride, const int offset) const; + bool is_muladds2i_pack_with_pack_inputs(const Node_List* pack) const; Node* same_inputs_at_index_or_null(const Node_List* pack, const int index) const; - VTransformBoolTest get_bool_test(const Node_List* bool_pack) const; private: diff --git a/test/hotspot/jtreg/compiler/loopopts/superword/TestMulAddS2I.java b/test/hotspot/jtreg/compiler/loopopts/superword/TestMulAddS2I.java index 4521d43804b..578d4ee8bdb 100644 --- a/test/hotspot/jtreg/compiler/loopopts/superword/TestMulAddS2I.java +++ b/test/hotspot/jtreg/compiler/loopopts/superword/TestMulAddS2I.java @@ -41,7 +41,6 @@ public class TestMulAddS2I { static short[] sArr1 = new short[RANGE]; static short[] sArr2 = new short[RANGE]; - static int[] ioutArr = new int[RANGE]; static final int[] GOLDEN_A; static final int[] GOLDEN_B; static final int[] GOLDEN_C; @@ -50,6 +49,10 @@ public class TestMulAddS2I { static final int[] GOLDEN_F; static final int[] GOLDEN_G; static final int[] GOLDEN_H; + static final int[] GOLDEN_I; + static final int[] GOLDEN_J; + static final int[] GOLDEN_K; + static final int[] GOLDEN_L; static { for (int i = 0; i < RANGE; i++) { @@ -58,12 +61,16 @@ public class TestMulAddS2I { } GOLDEN_A = testa(); GOLDEN_B = testb(); - GOLDEN_C = testc(); - GOLDEN_D = testd(); - GOLDEN_E = teste(); - GOLDEN_F = testf(); - GOLDEN_G = testg(); - GOLDEN_H = testh(); + GOLDEN_C = testc(new int[ITER]); + GOLDEN_D = testd(new int[ITER]); + GOLDEN_E = teste(new int[ITER]); + GOLDEN_F = testf(new int[ITER]); + GOLDEN_G = testg(new int[ITER]); + GOLDEN_H = testh(new int[ITER]); + GOLDEN_I = testi(new int[ITER]); + GOLDEN_J = testj(new int[ITER]); + GOLDEN_K = testk(new int[ITER]); + GOLDEN_L = testl(new int[ITER]); } @@ -72,17 +79,22 @@ public class TestMulAddS2I { TestFramework.runWithFlags("-XX:-AlignVector"); } - @Run(test = {"testa", "testb", "testc", "testd", "teste", "testf", "testg", "testh"}) + @Run(test = {"testa", "testb", "testc", "testd", "teste", "testf", "testg", "testh", + "testi", "testj", "testk", "testl"}) @Warmup(0) public static void run() { compare(testa(), GOLDEN_A, "testa"); compare(testb(), GOLDEN_B, "testb"); - compare(testc(), GOLDEN_C, "testc"); - compare(testd(), GOLDEN_D, "testd"); - compare(teste(), GOLDEN_E, "teste"); - compare(testf(), GOLDEN_F, "testf"); - compare(testg(), GOLDEN_G, "testg"); - compare(testh(), GOLDEN_H, "testh"); + compare(testc(new int[ITER]), GOLDEN_C, "testc"); + compare(testd(new int[ITER]), GOLDEN_D, "testd"); + compare(teste(new int[ITER]), GOLDEN_E, "teste"); + compare(testf(new int[ITER]), GOLDEN_F, "testf"); + compare(testg(new int[ITER]), GOLDEN_G, "testg"); + compare(testh(new int[ITER]), GOLDEN_H, "testh"); + compare(testi(new int[ITER]), GOLDEN_I, "testi"); + compare(testj(new int[ITER]), GOLDEN_J, "testj"); + compare(testk(new int[ITER]), GOLDEN_K, "testk"); + compare(testl(new int[ITER]), GOLDEN_L, "testl"); } public static void compare(int[] out, int[] golden, String name) { @@ -138,8 +150,7 @@ public class TestMulAddS2I { counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"}) @IR(applyIfCPUFeature = {"avx512_vnni", "true"}, counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"}) - public static int[] testc() { - int[] out = new int[ITER]; + public static int[] testc(int[] out) { for (int i = 0; i < ITER; i++) { out[i] += ((sArr1[2*i] * sArr2[2*i]) + (sArr1[2*i+1] * sArr2[2*i+1])); } @@ -155,8 +166,7 @@ public class TestMulAddS2I { counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"}) @IR(applyIfCPUFeature = {"avx512_vnni", "true"}, counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"}) - public static int[] testd() { - int[] out = ioutArr; + public static int[] testd(int[] out) { for (int i = 0; i < ITER-2; i+=2) { // Unrolled, with the same structure. out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1])); @@ -174,8 +184,7 @@ public class TestMulAddS2I { counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"}) @IR(applyIfCPUFeature = {"avx512_vnni", "true"}, counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"}) - public static int[] teste() { - int[] out = ioutArr; + public static int[] teste(int[] out) { for (int i = 0; i < ITER-2; i+=2) { // Unrolled, with some swaps. out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1])); @@ -193,8 +202,7 @@ public class TestMulAddS2I { counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"}) @IR(applyIfCPUFeature = {"avx512_vnni", "true"}, counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"}) - public static int[] testf() { - int[] out = ioutArr; + public static int[] testf(int[] out) { for (int i = 0; i < ITER-2; i+=2) { // Unrolled, with some swaps. out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1])); @@ -212,8 +220,7 @@ public class TestMulAddS2I { counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"}) @IR(applyIfCPUFeature = {"avx512_vnni", "true"}, counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"}) - public static int[] testg() { - int[] out = ioutArr; + public static int[] testg(int[] out) { for (int i = 0; i < ITER-2; i+=2) { // Unrolled, with some swaps. out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1])); @@ -231,8 +238,7 @@ public class TestMulAddS2I { counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"}) @IR(applyIfCPUFeature = {"avx512_vnni", "true"}, counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"}) - public static int[] testh() { - int[] out = ioutArr; + public static int[] testh(int[] out) { for (int i = 0; i < ITER-2; i+=2) { // Unrolled, with some swaps. out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1])); @@ -240,4 +246,57 @@ public class TestMulAddS2I { } return out; } + + @Test + @IR(counts = {IRNode.MUL_ADD_S2I, "> 0"}, + applyIfCPUFeatureOr = {"sse2", "true", "asimd", "true"}) + @IR(counts = {IRNode.MUL_ADD_VS2VI, "= 0"}) + public static int[] testi(int[] out) { + for (int i = 0; i < ITER-2; i+=2) { + // Unrolled, with some swaps that prevent vectorization. + out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1])); // ok + out[i+1] += ((sArr1[2*i+2] * sArr2[2*i+3]) + (sArr1[2*i+3] * sArr2[2*i+2])); // bad + } + return out; + } + + @Test + @IR(counts = {IRNode.MUL_ADD_S2I, "> 0"}, + applyIfCPUFeatureOr = {"sse2", "true", "asimd", "true"}) + @IR(counts = {IRNode.MUL_ADD_VS2VI, "= 0"}) + public static int[] testj(int[] out) { + for (int i = 0; i < ITER-2; i+=2) { + // Unrolled, with some swaps that prevent vectorization. + out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+1]) + (sArr1[2*i+1] * sArr2[2*i+0])); // bad + out[i+1] += ((sArr1[2*i+2] * sArr2[2*i+3]) + (sArr1[2*i+3] * sArr2[2*i+2])); // bad + } + return out; + } + + @Test + @IR(counts = {IRNode.MUL_ADD_S2I, "> 0"}, + applyIfCPUFeatureOr = {"sse2", "true", "asimd", "true"}) + @IR(counts = {IRNode.MUL_ADD_VS2VI, "= 0"}) + public static int[] testk(int[] out) { + for (int i = 0; i < ITER-2; i+=2) { + // Unrolled, with some swaps that prevent vectorization. + out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+1]) + (sArr1[2*i+1] * sArr2[2*i+0])); // bad + out[i+1] += ((sArr1[2*i+2] * sArr2[2*i+2]) + (sArr1[2*i+3] * sArr2[2*i+3])); // ok + } + return out; + } + + @Test + @IR(counts = {IRNode.MUL_ADD_S2I, "> 0"}, + applyIfCPUFeatureOr = {"sse2", "true", "asimd", "true"}) + @IR(counts = {IRNode.MUL_ADD_VS2VI, "= 0"}) + public static int[] testl(int[] out) { + for (int i = 0; i < ITER-2; i+=2) { + // Unrolled, with some swaps that prevent vectorization. + out[i+0] += ((sArr1[2*i+1] * sArr2[2*i+1]) + (sArr1[2*i+0] * sArr2[2*i+0])); // ok + out[i+1] += ((sArr1[2*i+2] * sArr2[2*i+3]) + (sArr1[2*i+3] * sArr2[2*i+2])); // bad + } + return out; + } + }