8333840: C2 SuperWord: wrong result for MulAddS2I when inputs permuted

Reviewed-by: kvn, chagedorn
This commit is contained in:
Emanuel Peter 2024-06-13 18:11:36 +00:00
parent cff048c735
commit b09a45163c
3 changed files with 141 additions and 32 deletions

View File

@ -2794,11 +2794,7 @@ bool SuperWord::is_vector_use(Node* use, int u_idx) const {
} }
if (VectorNode::is_muladds2i(use)) { if (VectorNode::is_muladds2i(use)) {
// MulAddS2I takes shorts and produces ints. return _packset.is_muladds2i_pack_with_pack_inputs(u_pk);
if (u_pk->size() * 2 != d_pk->size()) {
return false;
}
return true;
} }
if (u_pk->size() != d_pk->size()) { if (u_pk->size() != d_pk->size()) {
@ -2815,6 +2811,59 @@ bool SuperWord::is_vector_use(Node* use, int u_idx) const {
return true; return true;
} }
// MulAddS2I takes 4 shorts and produces an int. We can reinterpret
// the 4 shorts as two ints: a = (a0, a1) and b = (b0, b1).
//
// Inputs: 1 2 3 4
// Offsets: 0 0 1 1
// v = MulAddS2I(a, b) = a0 * b0 + a1 * b1
//
// But permutations are possible, because add and mul are commutative. For
// simplicity, the first input is always either a0 or a1. These are all
// the possible permutations:
//
// v = MulAddS2I(a, b) = a0 * b0 + a1 * b1 (case 1)
// v = MulAddS2I(a, b) = a0 * b0 + b1 * a1 (case 2)
// v = MulAddS2I(a, b) = a1 * b1 + a0 * b0 (case 3)
// v = MulAddS2I(a, b) = a1 * b1 + b0 * a0 (case 4)
//
// To vectorize, we expect (a0, a1) to be consecutive in one input pack,
// and (b0, b1) in the other input pack. Thus, both a and b are strided,
// with stride = 2. Further, a0 and b0 have offset 0, whereas a1 and b1
// have offset 1.
bool PackSet::is_muladds2i_pack_with_pack_inputs(const Node_List* pack) const {
assert(VectorNode::is_muladds2i(pack->at(0)), "must be MulAddS2I");
bool pack1_has_offset_0 = (strided_pack_input_at_index_or_null(pack, 1, 2, 0) != nullptr);
Node_List* pack1 = strided_pack_input_at_index_or_null(pack, 1, 2, pack1_has_offset_0 ? 0 : 1);
Node_List* pack2 = strided_pack_input_at_index_or_null(pack, 2, 2, pack1_has_offset_0 ? 0 : 1);
Node_List* pack3 = strided_pack_input_at_index_or_null(pack, 3, 2, pack1_has_offset_0 ? 1 : 0);
Node_List* pack4 = strided_pack_input_at_index_or_null(pack, 4, 2, pack1_has_offset_0 ? 1 : 0);
return pack1 != nullptr &&
pack2 != nullptr &&
pack3 != nullptr &&
pack4 != nullptr &&
((pack1 == pack3 && pack2 == pack4) || // case 1 or 3
(pack1 == pack4 && pack2 == pack3)); // case 2 or 4
}
Node_List* PackSet::strided_pack_input_at_index_or_null(const Node_List* pack, const int index, const int stride, const int offset) const {
Node* def0 = pack->at(0)->in(index);
Node_List* pack_in = get_pack(def0);
if (pack_in == nullptr || pack->size() * stride != pack_in->size()) {
return nullptr; // size mismatch
}
for (uint i = 1; i < pack->size(); i++) {
if (pack->at(i)->in(index) != pack_in->at(i * stride + offset)) {
return nullptr; // use-def mismatch
}
}
return pack_in;
}
// Check if the output type of def is compatible with the input type of use, i.e. if the // Check if the output type of def is compatible with the input type of use, i.e. if the
// types have the same size. // types have the same size.
bool SuperWord::is_velt_basic_type_compatible_use_def(Node* use, Node* def) const { bool SuperWord::is_velt_basic_type_compatible_use_def(Node* use, Node* def) const {

View File

@ -362,8 +362,9 @@ public:
} }
} }
Node_List* strided_pack_input_at_index_or_null(const Node_List* pack, const int index, const int stride, const int offset) const;
bool is_muladds2i_pack_with_pack_inputs(const Node_List* pack) const;
Node* same_inputs_at_index_or_null(const Node_List* pack, const int index) const; Node* same_inputs_at_index_or_null(const Node_List* pack, const int index) const;
VTransformBoolTest get_bool_test(const Node_List* bool_pack) const; VTransformBoolTest get_bool_test(const Node_List* bool_pack) const;
private: private:

View File

@ -41,7 +41,6 @@ public class TestMulAddS2I {
static short[] sArr1 = new short[RANGE]; static short[] sArr1 = new short[RANGE];
static short[] sArr2 = new short[RANGE]; static short[] sArr2 = new short[RANGE];
static int[] ioutArr = new int[RANGE];
static final int[] GOLDEN_A; static final int[] GOLDEN_A;
static final int[] GOLDEN_B; static final int[] GOLDEN_B;
static final int[] GOLDEN_C; static final int[] GOLDEN_C;
@ -50,6 +49,10 @@ public class TestMulAddS2I {
static final int[] GOLDEN_F; static final int[] GOLDEN_F;
static final int[] GOLDEN_G; static final int[] GOLDEN_G;
static final int[] GOLDEN_H; static final int[] GOLDEN_H;
static final int[] GOLDEN_I;
static final int[] GOLDEN_J;
static final int[] GOLDEN_K;
static final int[] GOLDEN_L;
static { static {
for (int i = 0; i < RANGE; i++) { for (int i = 0; i < RANGE; i++) {
@ -58,12 +61,16 @@ public class TestMulAddS2I {
} }
GOLDEN_A = testa(); GOLDEN_A = testa();
GOLDEN_B = testb(); GOLDEN_B = testb();
GOLDEN_C = testc(); GOLDEN_C = testc(new int[ITER]);
GOLDEN_D = testd(); GOLDEN_D = testd(new int[ITER]);
GOLDEN_E = teste(); GOLDEN_E = teste(new int[ITER]);
GOLDEN_F = testf(); GOLDEN_F = testf(new int[ITER]);
GOLDEN_G = testg(); GOLDEN_G = testg(new int[ITER]);
GOLDEN_H = testh(); GOLDEN_H = testh(new int[ITER]);
GOLDEN_I = testi(new int[ITER]);
GOLDEN_J = testj(new int[ITER]);
GOLDEN_K = testk(new int[ITER]);
GOLDEN_L = testl(new int[ITER]);
} }
@ -72,17 +79,22 @@ public class TestMulAddS2I {
TestFramework.runWithFlags("-XX:-AlignVector"); TestFramework.runWithFlags("-XX:-AlignVector");
} }
@Run(test = {"testa", "testb", "testc", "testd", "teste", "testf", "testg", "testh"}) @Run(test = {"testa", "testb", "testc", "testd", "teste", "testf", "testg", "testh",
"testi", "testj", "testk", "testl"})
@Warmup(0) @Warmup(0)
public static void run() { public static void run() {
compare(testa(), GOLDEN_A, "testa"); compare(testa(), GOLDEN_A, "testa");
compare(testb(), GOLDEN_B, "testb"); compare(testb(), GOLDEN_B, "testb");
compare(testc(), GOLDEN_C, "testc"); compare(testc(new int[ITER]), GOLDEN_C, "testc");
compare(testd(), GOLDEN_D, "testd"); compare(testd(new int[ITER]), GOLDEN_D, "testd");
compare(teste(), GOLDEN_E, "teste"); compare(teste(new int[ITER]), GOLDEN_E, "teste");
compare(testf(), GOLDEN_F, "testf"); compare(testf(new int[ITER]), GOLDEN_F, "testf");
compare(testg(), GOLDEN_G, "testg"); compare(testg(new int[ITER]), GOLDEN_G, "testg");
compare(testh(), GOLDEN_H, "testh"); compare(testh(new int[ITER]), GOLDEN_H, "testh");
compare(testi(new int[ITER]), GOLDEN_I, "testi");
compare(testj(new int[ITER]), GOLDEN_J, "testj");
compare(testk(new int[ITER]), GOLDEN_K, "testk");
compare(testl(new int[ITER]), GOLDEN_L, "testl");
} }
public static void compare(int[] out, int[] golden, String name) { public static void compare(int[] out, int[] golden, String name) {
@ -138,8 +150,7 @@ public class TestMulAddS2I {
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"}) counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"})
@IR(applyIfCPUFeature = {"avx512_vnni", "true"}, @IR(applyIfCPUFeature = {"avx512_vnni", "true"},
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"}) counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"})
public static int[] testc() { public static int[] testc(int[] out) {
int[] out = new int[ITER];
for (int i = 0; i < ITER; i++) { for (int i = 0; i < ITER; i++) {
out[i] += ((sArr1[2*i] * sArr2[2*i]) + (sArr1[2*i+1] * sArr2[2*i+1])); out[i] += ((sArr1[2*i] * sArr2[2*i]) + (sArr1[2*i+1] * sArr2[2*i+1]));
} }
@ -155,8 +166,7 @@ public class TestMulAddS2I {
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"}) counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"})
@IR(applyIfCPUFeature = {"avx512_vnni", "true"}, @IR(applyIfCPUFeature = {"avx512_vnni", "true"},
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"}) counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"})
public static int[] testd() { public static int[] testd(int[] out) {
int[] out = ioutArr;
for (int i = 0; i < ITER-2; i+=2) { for (int i = 0; i < ITER-2; i+=2) {
// Unrolled, with the same structure. // Unrolled, with the same structure.
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1])); out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1]));
@ -174,8 +184,7 @@ public class TestMulAddS2I {
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"}) counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"})
@IR(applyIfCPUFeature = {"avx512_vnni", "true"}, @IR(applyIfCPUFeature = {"avx512_vnni", "true"},
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"}) counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"})
public static int[] teste() { public static int[] teste(int[] out) {
int[] out = ioutArr;
for (int i = 0; i < ITER-2; i+=2) { for (int i = 0; i < ITER-2; i+=2) {
// Unrolled, with some swaps. // Unrolled, with some swaps.
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1])); out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1]));
@ -193,8 +202,7 @@ public class TestMulAddS2I {
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"}) counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"})
@IR(applyIfCPUFeature = {"avx512_vnni", "true"}, @IR(applyIfCPUFeature = {"avx512_vnni", "true"},
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"}) counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"})
public static int[] testf() { public static int[] testf(int[] out) {
int[] out = ioutArr;
for (int i = 0; i < ITER-2; i+=2) { for (int i = 0; i < ITER-2; i+=2) {
// Unrolled, with some swaps. // Unrolled, with some swaps.
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1])); out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1]));
@ -212,8 +220,7 @@ public class TestMulAddS2I {
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"}) counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"})
@IR(applyIfCPUFeature = {"avx512_vnni", "true"}, @IR(applyIfCPUFeature = {"avx512_vnni", "true"},
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"}) counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"})
public static int[] testg() { public static int[] testg(int[] out) {
int[] out = ioutArr;
for (int i = 0; i < ITER-2; i+=2) { for (int i = 0; i < ITER-2; i+=2) {
// Unrolled, with some swaps. // Unrolled, with some swaps.
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1])); out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1]));
@ -231,8 +238,7 @@ public class TestMulAddS2I {
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"}) counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"})
@IR(applyIfCPUFeature = {"avx512_vnni", "true"}, @IR(applyIfCPUFeature = {"avx512_vnni", "true"},
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"}) counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"})
public static int[] testh() { public static int[] testh(int[] out) {
int[] out = ioutArr;
for (int i = 0; i < ITER-2; i+=2) { for (int i = 0; i < ITER-2; i+=2) {
// Unrolled, with some swaps. // Unrolled, with some swaps.
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1])); out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1]));
@ -240,4 +246,57 @@ public class TestMulAddS2I {
} }
return out; return out;
} }
@Test
@IR(counts = {IRNode.MUL_ADD_S2I, "> 0"},
applyIfCPUFeatureOr = {"sse2", "true", "asimd", "true"})
@IR(counts = {IRNode.MUL_ADD_VS2VI, "= 0"})
public static int[] testi(int[] out) {
for (int i = 0; i < ITER-2; i+=2) {
// Unrolled, with some swaps that prevent vectorization.
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1])); // ok
out[i+1] += ((sArr1[2*i+2] * sArr2[2*i+3]) + (sArr1[2*i+3] * sArr2[2*i+2])); // bad
}
return out;
}
@Test
@IR(counts = {IRNode.MUL_ADD_S2I, "> 0"},
applyIfCPUFeatureOr = {"sse2", "true", "asimd", "true"})
@IR(counts = {IRNode.MUL_ADD_VS2VI, "= 0"})
public static int[] testj(int[] out) {
for (int i = 0; i < ITER-2; i+=2) {
// Unrolled, with some swaps that prevent vectorization.
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+1]) + (sArr1[2*i+1] * sArr2[2*i+0])); // bad
out[i+1] += ((sArr1[2*i+2] * sArr2[2*i+3]) + (sArr1[2*i+3] * sArr2[2*i+2])); // bad
}
return out;
}
@Test
@IR(counts = {IRNode.MUL_ADD_S2I, "> 0"},
applyIfCPUFeatureOr = {"sse2", "true", "asimd", "true"})
@IR(counts = {IRNode.MUL_ADD_VS2VI, "= 0"})
public static int[] testk(int[] out) {
for (int i = 0; i < ITER-2; i+=2) {
// Unrolled, with some swaps that prevent vectorization.
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+1]) + (sArr1[2*i+1] * sArr2[2*i+0])); // bad
out[i+1] += ((sArr1[2*i+2] * sArr2[2*i+2]) + (sArr1[2*i+3] * sArr2[2*i+3])); // ok
}
return out;
}
@Test
@IR(counts = {IRNode.MUL_ADD_S2I, "> 0"},
applyIfCPUFeatureOr = {"sse2", "true", "asimd", "true"})
@IR(counts = {IRNode.MUL_ADD_VS2VI, "= 0"})
public static int[] testl(int[] out) {
for (int i = 0; i < ITER-2; i+=2) {
// Unrolled, with some swaps that prevent vectorization.
out[i+0] += ((sArr1[2*i+1] * sArr2[2*i+1]) + (sArr1[2*i+0] * sArr2[2*i+0])); // ok
out[i+1] += ((sArr1[2*i+2] * sArr2[2*i+3]) + (sArr1[2*i+3] * sArr2[2*i+2])); // bad
}
return out;
}
} }