8302908: RISC-V: Support masked vector arithmetic instructions for Vector API

Co-authored-by: zifeihan <caogui@iscas.ac.cn>
Reviewed-by: fyang, fjiang, yzhu
Authored by Dingli Zhang on 2023-04-26 02:24:49 +00:00, committed by Fei Yang
parent adf62febe6
commit 1c1a73f715
7 changed files with 993 additions and 138 deletions

View File

@@ -1488,6 +1488,16 @@ enum VectorMask {
#undef INSN
#define INSN(NAME, op, funct3, vm, funct6) \
  void NAME(VectorRegister Vd, VectorRegister Vs2, Register Rs1) { \
    patch_VArith(op, Vd, funct3, Rs1->raw_encoding(), Vs2, vm, funct6); \
  }
  // Vector Integer Merge Instructions
  INSN(vmerge_vxm, 0b1010111, 0b100, 0b0, 0b010111);
#undef INSN
#define INSN(NAME, op, funct3, funct6) \
  void NAME(VectorRegister Vd, VectorRegister Vs2, FloatRegister Rs1, VectorMask vm = unmasked) { \
    patch_VArith(op, Vd, funct3, Rs1->raw_encoding(), Vs2, vm, funct6); \
@@ -1542,6 +1552,17 @@ enum VectorMask {
#undef INSN
#define INSN(NAME, op, funct3, vm, funct6) \
  void NAME(VectorRegister Vd, VectorRegister Vs2, int32_t imm) { \
    guarantee(is_simm5(imm), "imm is invalid"); \
    patch_VArith(op, Vd, funct3, (uint32_t)(imm & 0x1f), Vs2, vm, funct6); \
  }
  // Vector Integer Merge Instructions
  INSN(vmerge_vim, 0b1010111, 0b011, 0b0, 0b010111);
#undef INSN
#define INSN(NAME, op, funct3, vm, funct6) \
  void NAME(VectorRegister Vd, VectorRegister Vs2, VectorRegister Vs1) { \
    patch_VArith(op, Vd, funct3, Vs1->raw_encoding(), Vs2, vm, funct6); \
@@ -1560,6 +1581,9 @@ enum VectorMask {
  INSN(vmnand_mm, 0b1010111, 0b010, 0b1, 0b011101);
  INSN(vmand_mm, 0b1010111, 0b010, 0b1, 0b011001);
  // Vector Integer Merge Instructions
  INSN(vmerge_vvm, 0b1010111, 0b000, 0b0, 0b010111);
#undef INSN
#define INSN(NAME, op, funct3, Vs2, vm, funct6) \
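
For reference, these INSN macros all funnel into patch_VArith, which packs the standard RVV OP-V instruction layout: funct6[31:26], vm[25], vs2[24:20], vs1/rs1/imm[19:15], funct3[14:12], vd[11:7], opcode[6:0]. A minimal standalone sketch of that bit packing — the encode_varith helper below is hypothetical, not HotSpot code:

#include <cstdint>
#include <cstdio>

// Hypothetical model of the RVV OP-V encoding that patch_VArith emits:
// funct6[31:26] vm[25] vs2[24:20] vs1/rs1/imm[19:15] funct3[14:12] vd[11:7] opcode[6:0]
static uint32_t encode_varith(uint32_t op, uint32_t vd, uint32_t funct3,
                              uint32_t vs1, uint32_t vs2, uint32_t vm, uint32_t funct6) {
  return (funct6 << 26) | (vm << 25) | (vs2 << 20) | (vs1 << 15) |
         (funct3 << 12) | (vd << 7) | op;
}

int main() {
  // vmerge.vvm v2, v4, v6, v0: funct6 = 0b010111, funct3 = OPIVV (0b000), vm = 0
  printf("%08x\n", encode_varith(0b1010111, 2, 0b000, 6, 4, 0, 0b010111));
  return 0;
}

Note that the vmerge variants hard-code vm to 0b0: per the RVV spec, a merge always takes its lane selector from v0, so the mask bit cannot be "unmasked".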

View File

@@ -1304,7 +1304,7 @@ void C2_MacroAssembler::enc_cmove(int cmpFlag, Register op1, Register op2, Regis
}
// Set dst to NaN if any NaN input.
void C2_MacroAssembler::minmax_FD(FloatRegister dst, FloatRegister src1, FloatRegister src2,
void C2_MacroAssembler::minmax_fp(FloatRegister dst, FloatRegister src1, FloatRegister src2,
                                  bool is_double, bool is_min) {
  assert_different_registers(dst, src1, src2);
@@ -1616,7 +1616,7 @@ void C2_MacroAssembler::string_indexof_char_v(Register str1, Register cnt1,
}
// Set dst to NaN if any NaN input.
void C2_MacroAssembler::minmax_FD_v(VectorRegister dst, VectorRegister src1, VectorRegister src2,
void C2_MacroAssembler::minmax_fp_v(VectorRegister dst, VectorRegister src1, VectorRegister src2,
                                    bool is_double, bool is_min, int length_in_bytes) {
  assert_different_registers(dst, src1, src2);
@@ -1632,7 +1632,7 @@ void C2_MacroAssembler::minmax_FD_v(VectorRegister dst, VectorRegister src1, Vec
}
// Set dst to NaN if any NaN input.
void C2_MacroAssembler::reduce_minmax_FD_v(FloatRegister dst,
void C2_MacroAssembler::reduce_minmax_fp_v(FloatRegister dst,
                                           FloatRegister src1, VectorRegister src2,
                                           VectorRegister tmp1, VectorRegister tmp2,
                                           bool is_double, bool is_min, int length_in_bytes) {
@@ -1722,3 +1722,64 @@ void C2_MacroAssembler::rvv_vsetvli(BasicType bt, int length_in_bytes, Register
    }
  }
}
void C2_MacroAssembler::compare_integral_v(VectorRegister vd, BasicType bt, int length_in_bytes,
                                           VectorRegister src1, VectorRegister src2, int cond, VectorMask vm) {
  assert(is_integral_type(bt), "unsupported element type");
  assert(vm == Assembler::v0_t ? vd != v0 : true, "should be different registers");
  rvv_vsetvli(bt, length_in_bytes);
  vmclr_m(vd);
  switch (cond) {
    case BoolTest::eq: vmseq_vv(vd, src1, src2, vm); break;
    case BoolTest::ne: vmsne_vv(vd, src1, src2, vm); break;
    case BoolTest::le: vmsle_vv(vd, src1, src2, vm); break;
    case BoolTest::ge: vmsge_vv(vd, src1, src2, vm); break;
    case BoolTest::lt: vmslt_vv(vd, src1, src2, vm); break;
    case BoolTest::gt: vmsgt_vv(vd, src1, src2, vm); break;
    default:
      assert(false, "unsupported compare condition");
      ShouldNotReachHere();
  }
}
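
A scalar model of what compare_integral_v produces per lane — a sketch only, with a hypothetical helper and one condition shown; the key point is that vd is cleared up front, so lanes masked off by v0 come out false rather than undefined:

#include <cstdint>
#include <vector>

// Scalar model of the masked lt case. dst[i] mirrors one bit of vd;
// v0_mask gates which lanes are actually compared.
std::vector<bool> compare_lt_masked(const std::vector<int64_t>& src1,
                                    const std::vector<int64_t>& src2,
                                    const std::vector<bool>& v0_mask) {
  std::vector<bool> dst(src1.size(), false);   // vmclr_m(vd): all lanes start false
  for (size_t i = 0; i < src1.size(); i++) {
    if (v0_mask[i]) {                          // vm == v0_t: inactive lanes stay cleared
      dst[i] = src1[i] < src2[i];              // vmslt_vv(vd, src1, src2, vm)
    }
  }
  return dst;
}
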
void C2_MacroAssembler::compare_floating_point_v(VectorRegister vd, BasicType bt, int length_in_bytes,
                                                 VectorRegister src1, VectorRegister src2,
                                                 VectorRegister tmp1, VectorRegister tmp2,
                                                 VectorRegister vmask, int cond, VectorMask vm) {
  assert(is_floating_point_type(bt), "unsupported element type");
  assert(vd != v0, "should be different registers");
  assert(vm == Assembler::v0_t ? vmask != v0 : true, "vmask should not be v0");
  rvv_vsetvli(bt, length_in_bytes);
  // Check vector elements of src1 and src2 for quiet or signaling NaN.
  vfclass_v(tmp1, src1);
  vfclass_v(tmp2, src2);
  vsrl_vi(tmp1, tmp1, 8);
  vsrl_vi(tmp2, tmp2, 8);
  vmseq_vx(tmp1, tmp1, zr);
  vmseq_vx(tmp2, tmp2, zr);
  if (vm == Assembler::v0_t) {
    vmand_mm(tmp2, tmp1, tmp2);
    if (cond == BoolTest::ne) {
      vmandn_mm(tmp1, vmask, tmp2);
    }
    vmand_mm(v0, vmask, tmp2);
  } else {
    vmand_mm(v0, tmp1, tmp2);
    if (cond == BoolTest::ne) {
      vmnot_m(tmp1, v0);
    }
  }
  vmclr_m(vd);
  switch (cond) {
    case BoolTest::eq: vmfeq_vv(vd, src1, src2, Assembler::v0_t); break;
    case BoolTest::ne: vmfne_vv(vd, src1, src2, Assembler::v0_t);
                       vmor_mm(vd, vd, tmp1); break;
    case BoolTest::le: vmfle_vv(vd, src1, src2, Assembler::v0_t); break;
    case BoolTest::ge: vmfge_vv(vd, src1, src2, Assembler::v0_t); break;
    case BoolTest::lt: vmflt_vv(vd, src1, src2, Assembler::v0_t); break;
    case BoolTest::gt: vmfgt_vv(vd, src1, src2, Assembler::v0_t); break;
    default:
      assert(false, "unsupported compare condition");
      ShouldNotReachHere();
  }
}
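
The vfclass_v/vsrl_vi/vmseq_vx prologue computes a "lane is not NaN" mask: vfclass sets only bits 8 or 9 for signaling/quiet NaN, so shifting right by 8 and comparing against zero is true exactly for non-NaN lanes. The compares then run under that mask in v0, giving IEEE-754 unordered semantics: NaN lanes come out false for eq/lt/le/gt/ge, while for ne the NaN mask kept in tmp1 is OR-ed back in so unordered lanes read true. A scalar model of the unmasked ne case — a sketch with a hypothetical helper:

#include <cmath>
#include <vector>

// Scalar model of compare_floating_point_v for cond == ne, unmasked.
// Lanes where either input is NaN must compare "not equal" (unordered),
// which the vector code gets by OR-ing the NaN mask (tmp1) into vd.
std::vector<bool> compare_ne(const std::vector<double>& src1,
                             const std::vector<double>& src2) {
  std::vector<bool> dst(src1.size(), false);                       // vmclr_m(vd)
  for (size_t i = 0; i < src1.size(); i++) {
    bool ordered = !std::isnan(src1[i]) && !std::isnan(src2[i]);   // vfclass/vsrl/vmseq + vmand
    bool ne      = ordered && (src1[i] != src2[i]);                // vmfne_vv under v0
    dst[i]       = ne || !ordered;                                 // vmor_mm(vd, vd, tmp1)
  }
  return dst;
}
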

View File

@@ -137,13 +137,15 @@
    vl1re8_v(v, t0);
  }
  void spill_copy_vector_stack_to_stack(int src_offset, int dst_offset, int vec_reg_size_in_bytes) {
    assert(vec_reg_size_in_bytes % 16 == 0, "unexpected vector reg size");
    unspill(v0, src_offset);
    spill(v0, dst_offset);
  void spill_copy_vector_stack_to_stack(int src_offset, int dst_offset, int vector_length_in_bytes) {
    assert(vector_length_in_bytes % 16 == 0, "unexpected vector reg size");
    for (int i = 0; i < vector_length_in_bytes / 8; i++) {
      unspill(t0, true, src_offset + (i * 8));
      spill(t0, true, dst_offset + (i * 8));
    }
  }
  void minmax_FD(FloatRegister dst,
  void minmax_fp(FloatRegister dst,
                 FloatRegister src1, FloatRegister src2,
                 bool is_double, bool is_min);
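
Note the stack-to-stack vector copy now goes through scalar t0 in 8-byte chunks instead of bouncing through v0 — presumably because this change reserves v0 as the mask register, so it may hold a live value while a spill copy runs.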
@@ -183,11 +185,11 @@
                              Register tmp1, Register tmp2,
                              bool isL);
  void minmax_FD_v(VectorRegister dst,
  void minmax_fp_v(VectorRegister dst,
                   VectorRegister src1, VectorRegister src2,
                   bool is_double, bool is_min, int length_in_bytes);
  void reduce_minmax_FD_v(FloatRegister dst,
  void reduce_minmax_fp_v(FloatRegister dst,
                          FloatRegister src1, VectorRegister src2,
                          VectorRegister tmp1, VectorRegister tmp2,
                          bool is_double, bool is_min, int length_in_bytes);
@@ -198,4 +200,34 @@
  void rvv_vsetvli(BasicType bt, int length_in_bytes, Register tmp = t0);
  void compare_integral_v(VectorRegister dst, BasicType bt, int length_in_bytes,
                          VectorRegister src1, VectorRegister src2, int cond, VectorMask vm = Assembler::unmasked);
  void compare_floating_point_v(VectorRegister dst, BasicType bt, int length_in_bytes,
                                VectorRegister src1, VectorRegister src2, VectorRegister tmp1, VectorRegister tmp2,
                                VectorRegister vmask, int cond, VectorMask vm = Assembler::unmasked);
  // In Matcher::scalable_predicate_reg_slots, we assume each predicate
  // register is one-eighth the size of a scalable vector register, i.e.
  // one mask bit per vector byte.
  void spill_vmask(VectorRegister v, int offset) {
    rvv_vsetvli(T_BYTE, MaxVectorSize >> 3);
    add(t0, sp, offset);
    vse8_v(v, t0);
  }
  void unspill_vmask(VectorRegister v, int offset) {
    rvv_vsetvli(T_BYTE, MaxVectorSize >> 3);
    add(t0, sp, offset);
    vle8_v(v, t0);
  }
  void spill_copy_vmask_stack_to_stack(int src_offset, int dst_offset, int vector_length_in_bytes) {
    assert(vector_length_in_bytes % 4 == 0, "unexpected vector mask reg size");
    for (int i = 0; i < vector_length_in_bytes / 4; i++) {
      unspill(t0, false, src_offset + (i * 4));
      spill(t0, false, dst_offset + (i * 4));
    }
  }
#endif // CPU_RISCV_C2_MACROASSEMBLER_RISCV_HPP
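
The sizing arithmetic behind these helpers: with one mask bit per vector byte, a mask register occupies MaxVectorSize >> 3 bytes on the stack, and the stack-to-stack copy moves it in 32-bit slots. A throwaway check of that arithmetic under assumed vector sizes — not HotSpot code:

#include <cassert>
#include <cstdio>
#include <initializer_list>

int main() {
  for (int max_vector_size : {32, 64, 128, 256}) {   // bytes per vector register
    int vmask_bytes = max_vector_size >> 3;          // MaxVectorSize >> 3
    assert(vmask_bytes % 4 == 0);                    // copied in 4-byte stack slots
    printf("VLEN = %4d bits -> vmask spill size = %2d bytes (%d slots)\n",
           max_vector_size * 8, vmask_bytes, vmask_bytes / 4);
  }
  return 0;
}
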

View File

@@ -1264,7 +1264,7 @@ public:
    vmnand_mm(vd, vs, vs);
  }
  inline void vncvt_x_x_w(VectorRegister vd, VectorRegister vs, VectorMask vm) {
  inline void vncvt_x_x_w(VectorRegister vd, VectorRegister vs, VectorMask vm = unmasked) {
    vnsrl_wx(vd, vs, x0, vm);
  }
@@ -1276,6 +1276,45 @@ public:
    vfsgnjn_vv(vd, vs, vs);
  }
  inline void vmsgt_vv(VectorRegister vd, VectorRegister vs2, VectorRegister vs1, VectorMask vm = unmasked) {
    vmslt_vv(vd, vs1, vs2, vm);
  }
  inline void vmsgtu_vv(VectorRegister vd, VectorRegister vs2, VectorRegister vs1, VectorMask vm = unmasked) {
    vmsltu_vv(vd, vs1, vs2, vm);
  }
  inline void vmsge_vv(VectorRegister vd, VectorRegister vs2, VectorRegister vs1, VectorMask vm = unmasked) {
    vmsle_vv(vd, vs1, vs2, vm);
  }
  inline void vmsgeu_vv(VectorRegister vd, VectorRegister vs2, VectorRegister vs1, VectorMask vm = unmasked) {
    vmsleu_vv(vd, vs1, vs2, vm);
  }
  inline void vmfgt_vv(VectorRegister vd, VectorRegister vs2, VectorRegister vs1, VectorMask vm = unmasked) {
    vmflt_vv(vd, vs1, vs2, vm);
  }
  inline void vmfge_vv(VectorRegister vd, VectorRegister vs2, VectorRegister vs1, VectorMask vm = unmasked) {
    vmfle_vv(vd, vs1, vs2, vm);
  }
  // Copy mask register
  inline void vmmv_m(VectorRegister vd, VectorRegister vs) {
    vmand_mm(vd, vs, vs);
  }
  // Clear mask register
  inline void vmclr_m(VectorRegister vd) {
    vmxor_mm(vd, vd, vd);
  }
  // Set mask register
  inline void vmset_m(VectorRegister vd) {
    vmxnor_mm(vd, vd, vd);
  }
  static const int zero_words_block_size;
  void cast_primitive_type(BasicType type, Register Rt) {
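
All of these helpers are operand- or identity-level aliases rather than new encodings: the greater-than flavors swap operands into the existing less-than instructions, and the mask move/clear/set idioms lean on x & x = x, x ^ x = 0, and ~(x ^ x) = all-ones. A trivial standalone check of those identities on a 64-bit "mask register" word:

#include <cassert>
#include <cstdint>

int main() {
  uint64_t m = 0xA5A5F00F12345678ull;
  assert((m ^ m) == 0);              // vmclr_m:  vmxor_mm(vd, vd, vd)
  assert(~(m ^ m) == ~0ull);         // vmset_m:  vmxnor_mm(vd, vd, vd)
  assert((m & m) == m);              // vmmv_m:   vmand_mm(vd, vs, vs)
  int64_t a = 3, b = 7;
  assert((a > b) == (b < a));        // vmsgt_vv: vmslt_vv with swapped operands
  assert((a >= b) == (b <= a));      // vmsge_vv: vmsle_vv with swapped operands
  return 0;
}
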

View File

@@ -149,7 +149,7 @@
  // Some microarchitectures have mask registers used on vectors
  static const bool has_predicated_vectors(void) {
    return false;
    return UseRVV;
  }
// true means we have fast l2f conversion

View File

@@ -830,7 +830,8 @@ reg_class double_reg(
F31, F31_H
);
// Class for all RVV vector registers
// Class for RVV vector registers
// Note: v0, v30 and v31 are used as mask registers.
reg_class vectora_reg(
V1, V1_H, V1_J, V1_K,
V2, V2_H, V2_J, V2_K,
@@ -860,9 +861,7 @@ reg_class vectora_reg(
V26, V26_H, V26_J, V26_K,
V27, V27_H, V27_J, V27_K,
V28, V28_H, V28_J, V28_K,
V29, V29_H, V29_J, V29_K,
V30, V30_H, V30_J, V30_K,
V31, V31_H, V31_J, V31_K
V29, V29_H, V29_J, V29_K
);
// Class for 64 bit register f0
@@ -912,6 +911,23 @@ reg_class v5_reg(
// class for condition codes
reg_class reg_flags(RFLAGS);
// Class for RVV v0 mask register
// https://github.com/riscv/riscv-v-spec/blob/master/v-spec.adoc#53-vector-masking
// The mask value used to control execution of a masked vector
// instruction is always supplied by vector register v0.
reg_class vmask_reg_v0 (
  V0
);
// Class for RVV mask registers
// We need two more vmask registers for the vector mask logical ops,
// so v30 and v31 are defined as mask registers as well.
reg_class vmask_reg (
  V0,
  V30,
  V31
);
%}
//----------DEFINITION BLOCK---------------------------------------------------
@@ -1522,7 +1538,7 @@ uint MachSpillCopyNode::implementation(CodeBuffer *cbuf, PhaseRegAlloc *ra_, boo
  assert(src_lo != OptoReg::Bad && dst_lo != OptoReg::Bad, "must move at least 1 register");
  if (src_hi != OptoReg::Bad) {
  if (src_hi != OptoReg::Bad && !bottom_type()->isa_vectmask()) {
    assert((src_lo & 1) == 0 && src_lo + 1 == src_hi &&
           (dst_lo & 1) == 0 && dst_lo + 1 == dst_hi,
           "expected aligned-adjacent pairs");
@@ -1558,6 +1574,25 @@ uint MachSpillCopyNode::implementation(CodeBuffer *cbuf, PhaseRegAlloc *ra_, boo
      } else {
        ShouldNotReachHere();
      }
    } else if (bottom_type()->isa_vectmask() && cbuf) {
      C2_MacroAssembler _masm(cbuf);
      int vmask_size_in_bytes = Matcher::scalable_predicate_reg_slots() * 32 / 8;
      if (src_lo_rc == rc_stack && dst_lo_rc == rc_stack) {
        // stack to stack
        __ spill_copy_vmask_stack_to_stack(src_offset, dst_offset,
                                           vmask_size_in_bytes);
      } else if (src_lo_rc == rc_vector && dst_lo_rc == rc_stack) {
        // vmask to stack
        __ spill_vmask(as_VectorRegister(Matcher::_regEncode[src_lo]), ra_->reg2offset(dst_lo));
      } else if (src_lo_rc == rc_stack && dst_lo_rc == rc_vector) {
        // stack to vmask
        __ unspill_vmask(as_VectorRegister(Matcher::_regEncode[dst_lo]), ra_->reg2offset(src_lo));
      } else if (src_lo_rc == rc_vector && dst_lo_rc == rc_vector) {
        // vmask to vmask
        __ vmv1r_v(as_VectorRegister(Matcher::_regEncode[dst_lo]), as_VectorRegister(Matcher::_regEncode[src_lo]));
      } else {
        ShouldNotReachHere();
      }
    }
  } else if (cbuf != NULL) {
    C2_MacroAssembler _masm(cbuf);
@@ -1642,7 +1677,7 @@ uint MachSpillCopyNode::implementation(CodeBuffer *cbuf, PhaseRegAlloc *ra_, boo
    } else {
      st->print("%s", Matcher::regName[dst_lo]);
    }
    if (bottom_type()->isa_vect() != NULL) {
    if (bottom_type()->isa_vect() && !bottom_type()->isa_vectmask()) {
      int vsize = 0;
      if (ideal_reg() == Op_VecA) {
        vsize = Matcher::scalable_vector_reg_size(T_BYTE) * 8;
@@ -1650,6 +1685,10 @@ uint MachSpillCopyNode::implementation(CodeBuffer *cbuf, PhaseRegAlloc *ra_, boo
      } else {
        ShouldNotReachHere();
      }
      st->print("\t# vector spill size = %d", vsize);
    } else if (ideal_reg() == Op_RegVectMask) {
      assert(Matcher::supports_scalable_vector(), "bad register type for spill");
      int vsize = Matcher::scalable_predicate_reg_slots() * 32;
      st->print("\t# vmask spill size = %d", vsize);
    } else {
      st->print("\t# spill size = %d", is64 ? 64 : 32);
    }
@@ -1863,7 +1902,59 @@ const bool Matcher::match_rule_supported_vector(int opcode, int vlen, BasicType
}
const bool Matcher::match_rule_supported_vector_masked(int opcode, int vlen, BasicType bt) {
  return false;
  if (!UseRVV) {
    return false;
  }
  switch (opcode) {
    case Op_AddVB:
    case Op_AddVS:
    case Op_AddVI:
    case Op_AddVL:
    case Op_AddVF:
    case Op_AddVD:
    case Op_SubVB:
    case Op_SubVS:
    case Op_SubVI:
    case Op_SubVL:
    case Op_SubVF:
    case Op_SubVD:
    case Op_MulVB:
    case Op_MulVS:
    case Op_MulVI:
    case Op_MulVL:
    case Op_MulVF:
    case Op_MulVD:
    case Op_DivVF:
    case Op_DivVD:
    case Op_VectorLoadMask:
    case Op_VectorMaskCmp:
    case Op_AndVMask:
    case Op_XorVMask:
    case Op_OrVMask:
    case Op_RShiftVB:
    case Op_RShiftVS:
    case Op_RShiftVI:
    case Op_RShiftVL:
    case Op_LShiftVB:
    case Op_LShiftVS:
    case Op_LShiftVI:
    case Op_LShiftVL:
    case Op_URShiftVB:
    case Op_URShiftVS:
    case Op_URShiftVI:
    case Op_URShiftVL:
    case Op_VectorBlend:
      break;
    case Op_LoadVector:
      opcode = Op_LoadVectorMasked;
      break;
    case Op_StoreVector:
      opcode = Op_StoreVectorMasked;
      break;
    default:
      return false;
  }
  return match_rule_supported_vector(opcode, vlen, bt);
}
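
The pattern here is a whitelist of opcodes that have masked match rules, with LoadVector/StoreVector rewritten to their *Masked counterparts before delegating to the unmasked match_rule_supported_vector check. A compressed sketch of that delegation shape — hypothetical enum and helpers, not the real opcode table:

#include <unordered_set>

// Hypothetical reduction of the dispatch above: remap a few opcodes to their
// masked forms, reject everything not whitelisted, then delegate.
enum Opcode { AddVI, LoadVector, LoadVectorMasked, StoreVector, StoreVectorMasked, OtherOp };

static bool base_rule_supported(Opcode op) {           // stand-in for match_rule_supported_vector
  static const std::unordered_set<int> ok = {AddVI, LoadVectorMasked, StoreVectorMasked};
  return ok.count(op) != 0;
}

bool masked_rule_supported(Opcode op) {
  switch (op) {
    case AddVI:        break;                          // maskable as-is
    case LoadVector:   op = LoadVectorMasked;  break;  // masked load has its own rule
    case StoreVector:  op = StoreVectorMasked; break;
    default:           return false;                   // not supported under a mask
  }
  return base_rule_supported(op);
}
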
const bool Matcher::vector_needs_partial_operations(Node* node, const TypeVect* vt) {
@@ -1875,11 +1966,11 @@ const bool Matcher::vector_needs_load_shuffle(BasicType elem_bt, int vlen) {
}
const RegMask* Matcher::predicate_reg_mask(void) {
  return NULL;
  return &_VMASK_REG_mask;
}
const TypeVectMask* Matcher::predicate_reg_type(const Type* elemTy, int length) {
  return NULL;
  return new TypeVectMask(elemTy, length);
}
// Vector calling convention not yet implemented.
@@ -3556,6 +3647,28 @@ operand vReg_V5()
  interface(REG_INTER);
%}
operand vRegMask()
%{
  constraint(ALLOC_IN_RC(vmask_reg));
  match(RegVectMask);
  match(vRegMask_V0);
  op_cost(0);
  format %{ %}
  interface(REG_INTER);
%}
// The mask value used to control execution of a masked
// vector instruction is always supplied by vector register v0.
operand vRegMask_V0()
%{
  constraint(ALLOC_IN_RC(vmask_reg_v0));
  match(RegVectMask);
  match(vRegMask);
  op_cost(0);
  format %{ %}
  interface(REG_INTER);
%}
// Java Thread Register
operand javaThread_RegP(iRegP reg)
%{
@@ -7271,7 +7384,7 @@ instruct maxF_reg_reg(fRegF dst, fRegF src1, fRegF src2, rFlagsReg cr) %{
  format %{ "maxF $dst, $src1, $src2" %}
  ins_encode %{
    __ minmax_FD(as_FloatRegister($dst$$reg),
    __ minmax_fp(as_FloatRegister($dst$$reg),
                 as_FloatRegister($src1$$reg), as_FloatRegister($src2$$reg),
                 false /* is_double */, false /* is_min */);
  %}
@@ -7287,7 +7400,7 @@ instruct minF_reg_reg(fRegF dst, fRegF src1, fRegF src2, rFlagsReg cr) %{
  format %{ "minF $dst, $src1, $src2" %}
  ins_encode %{
    __ minmax_FD(as_FloatRegister($dst$$reg),
    __ minmax_fp(as_FloatRegister($dst$$reg),
                 as_FloatRegister($src1$$reg), as_FloatRegister($src2$$reg),
                 false /* is_double */, true /* is_min */);
  %}
@@ -7303,7 +7416,7 @@ instruct maxD_reg_reg(fRegD dst, fRegD src1, fRegD src2, rFlagsReg cr) %{
  format %{ "maxD $dst, $src1, $src2" %}
  ins_encode %{
    __ minmax_FD(as_FloatRegister($dst$$reg),
    __ minmax_fp(as_FloatRegister($dst$$reg),
                 as_FloatRegister($src1$$reg), as_FloatRegister($src2$$reg),
                 true /* is_double */, false /* is_min */);
  %}
@@ -7319,7 +7432,7 @@ instruct minD_reg_reg(fRegD dst, fRegD src1, fRegD src2, rFlagsReg cr) %{
  format %{ "minD $dst, $src1, $src2" %}
  ins_encode %{
    __ minmax_FD(as_FloatRegister($dst$$reg),
    __ minmax_fp(as_FloatRegister($dst$$reg),
                 as_FloatRegister($src1$$reg), as_FloatRegister($src2$$reg),
                 true /* is_double */, true /* is_min */);
  %}

File diff suppressed because it is too large.