diff --git a/src/hotspot/cpu/aarch64/aarch64.ad b/src/hotspot/cpu/aarch64/aarch64.ad
index f1dcc0443c6..81417c9509a 100644
--- a/src/hotspot/cpu/aarch64/aarch64.ad
+++ b/src/hotspot/cpu/aarch64/aarch64.ad
@@ -2713,11 +2713,17 @@ bool size_fits_all_mem_uses(AddPNode* addp, int shift) {
   return true;
 }
 
-bool can_combine_with_imm(Node* binary_node, Node* replicate_node) {
-  if (UseSVE == 0 || !VectorNode::is_invariant_vector(replicate_node)){
+// Binary src (Replicate con): 'n' is the binary vector op, 'm' its candidate constant input (per-opcode checks are only partly visible in this diff).
+bool is_valid_sve_arith_imm_pattern(Node* n, Node* m) {
+  if (n == NULL || m == NULL) {
     return false;
   }
-  Node* imm_node = replicate_node->in(1);
+
+  if (UseSVE == 0 || !VectorNode::is_invariant_vector(m)) {  // requires SVE and an invariant vector 'm'
+    return false;
+  }
+
+  Node* imm_node = m->in(1);  // input 1 of 'm'; rejected just below unless it is a constant
   if (!imm_node->is_Con()) {
     return false;
   }
@@ -2727,11 +2733,11 @@ bool can_combine_with_imm(Node* binary_node, Node* replicate_node) {
     return false;
   }
 
-  switch (binary_node->Opcode()) {
+  switch (n->Opcode()) {
   case Op_AndV:
   case Op_OrV:
   case Op_XorV: {
-    Assembler::SIMD_RegVariant T = Assembler::elemType_to_regVariant(Matcher::vector_element_basic_type(binary_node));
+    Assembler::SIMD_RegVariant T = Assembler::elemType_to_regVariant(Matcher::vector_element_basic_type(n));
     uint64_t value = t->isa_long() ? (uint64_t)imm_node->get_long() : (uint64_t)imm_node->get_int();
     return Assembler::operand_valid_for_sve_logical_immediate(Assembler::regVariant_to_elemBits(T), value);
   }
@@ -2747,22 +2753,24 @@ bool can_combine_with_imm(Node* binary_node, Node* replicate_node) {
   }
 }
 
-bool is_vector_arith_imm_pattern(Node* n, Node* m) {
+// (XorV src (Replicate m1))     -- vector XOR whose input 'm' is an all-ones vector
+// (XorVMask src (MaskAll m1))   -- mask XOR whose input 'm' is an all-ones mask
+bool is_vector_bitwise_not_pattern(Node* n, Node* m) {
   if (n != NULL && m != NULL) {
-    return can_combine_with_imm(n, m);
+    return (n->Opcode() == Op_XorV || n->Opcode() == Op_XorVMask) &&
+           VectorNode::is_all_ones_vector(m);
   }
   return false;
 }
 
 // Should the matcher clone input 'm' of node 'n'?
 bool Matcher::pd_clone_node(Node* n, Node* m, Matcher::MStack& mstack) {
-  // ShiftV src (ShiftCntV con)
-  // Binary src (Replicate con)
-  if (is_vshift_con_pattern(n, m) || is_vector_arith_imm_pattern(n, m)) {
+  if (is_vshift_con_pattern(n, m) ||            // ShiftV src (ShiftCntV con)
+      is_vector_bitwise_not_pattern(n, m) ||    // XorV/XorVMask src (all-ones input)
+      is_valid_sve_arith_imm_pattern(n, m)) {   // Binary src (Replicate con)
     mstack.push(m, Visit);
     return true;
   }
-
   return false;
 }
 
diff --git a/src/hotspot/share/opto/vectornode.cpp b/src/hotspot/share/opto/vectornode.cpp
index e5ac7379f3b..e39849a5d24 100644
--- a/src/hotspot/share/opto/vectornode.cpp
+++ b/src/hotspot/share/opto/vectornode.cpp
@@ -847,6 +847,7 @@ bool VectorNode::is_all_ones_vector(Node* n) {
   case Op_ReplicateS:
   case Op_ReplicateI:
   case Op_ReplicateL:
+  case Op_MaskAll:  // shares the is_con_M1(in(1)) check with the Replicate cases above
     return is_con_M1(n->in(1));
   default:
     return false;
diff --git a/test/hotspot/jtreg/compiler/vectorapi/AllBitsSetVectorMatchRuleTest.java b/test/hotspot/jtreg/compiler/vectorapi/AllBitsSetVectorMatchRuleTest.java
new file mode 100644
index 00000000000..1edb09f0194
--- /dev/null
+++ b/test/hotspot/jtreg/compiler/vectorapi/AllBitsSetVectorMatchRuleTest.java
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) 2022, Arm Limited. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package compiler.vectorapi;
+
+import compiler.lib.ir_framework.*;
+
+import java.util.Random;
+
+import jdk.incubator.vector.IntVector;
+import jdk.incubator.vector.LongVector;
+import jdk.incubator.vector.VectorMask;
+import jdk.incubator.vector.VectorOperators;
+import jdk.incubator.vector.VectorSpecies;
+
+import jdk.test.lib.Asserts;
+import jdk.test.lib.Utils;
+
+/**
+ * @test
+ * @bug 8287984
+ * @key randomness
+ * @library /test/lib /
+ * @requires vm.compiler2.enabled
+ * @requires vm.cpu.features ~= ".*simd.*" | vm.cpu.features ~= ".*sve.*"
+ * @summary AArch64: [vector] Make all bits set vector sharable for match rules
+ * @modules jdk.incubator.vector
+ *
+ * @run driver compiler.vectorapi.AllBitsSetVectorMatchRuleTest
+ */
+
+public class AllBitsSetVectorMatchRuleTest {
+    private static final VectorSpecies<Integer> I_SPECIES = IntVector.SPECIES_MAX;
+    private static final VectorSpecies<Long> L_SPECIES = LongVector.SPECIES_MAX;
+
+    private static final int LENGTH = 128;  // array size; one full vector/mask is loaded from offset 0, so it must cover the species length
+    private static final Random RD = Utils.getRandomInstance();
+
+    private static int[] ia;
+    private static int[] ib;
+    private static int[] ir;
+    private static boolean[] ma;
+    private static boolean[] mb;
+    private static boolean[] mc;
+    private static boolean[] mr;
+
+    static {
+        ia = new int[LENGTH];
+        ib = new int[LENGTH];
+        ir = new int[LENGTH];
+        ma = new boolean[LENGTH];
+        mb = new boolean[LENGTH];
+        mc = new boolean[LENGTH];
+        mr = new boolean[LENGTH];
+
+        for (int i = 0; i < LENGTH; i++) {  // random inputs; result arrays ir/mr are filled by the tests
+            ia[i] = RD.nextInt(25);
+            ib[i] = RD.nextInt(25);
+            ma[i] = RD.nextBoolean();
+            mb[i] = RD.nextBoolean();
+            mc[i] = RD.nextBoolean();
+        }
+    }
+
+    @Test
+    @Warmup(10000)
+    @IR(counts = { "bic", " >= 1" })  // IR rule: at least one "bic" match is required
+    public static void testAllBitsSetVector() {
+        IntVector av = IntVector.fromArray(I_SPECIES, ia, 0);
+        IntVector bv = IntVector.fromArray(I_SPECIES, ib, 0);
+        av.not().lanewise(VectorOperators.AND_NOT, bv).intoArray(ir, 0);  // (~a) & (~b)
+
+        // Verify every lane against the scalar computation (~a) & (~b)
+        for (int i = 0; i < I_SPECIES.length(); i++) {
+            Asserts.assertEquals((~ia[i]) & (~ib[i]), ir[i]);
+        }
+    }
+
+    @Test
+    @Warmup(10000)
+    @IR(counts = { "bic", " >= 1" })  // IR rule: at least one "bic" match is required
+    public static void testAllBitsSetMask() {
+        VectorMask<Long> avm = VectorMask.fromArray(L_SPECIES, ma, 0);
+        VectorMask<Long> bvm = VectorMask.fromArray(L_SPECIES, mb, 0);
+        VectorMask<Long> cvm = VectorMask.fromArray(L_SPECIES, mc, 0);
+        avm.andNot(bvm).andNot(cvm).intoArray(mr, 0);  // (a & !b) & !c
+
+        // Verify every lane against the scalar computation (a & !b) & !c
+        for (int i = 0; i < L_SPECIES.length(); i++) {
+            Asserts.assertEquals((ma[i] & (!mb[i])) & (!mc[i]), mr[i]);
+        }
+    }
+
+    public static void main(String[] args) {
+        TestFramework.runWithFlags("--add-modules=jdk.incubator.vector");
+    }
+}