8328404: RISC-V: Fix potential crash in C2_MacroAssembler::arrays_equals

Reviewed-by: fyang
This commit is contained in:
Gui Cao 2024-03-25 01:18:50 +00:00 committed by Fei Yang
parent bc73963974
commit c7b9dc463a
3 changed files with 99 additions and 116 deletions

View File

@ -1536,21 +1536,20 @@ void C2_MacroAssembler::string_compare(Register str1, Register str2,
BLOCK_COMMENT("} string_compare"); BLOCK_COMMENT("} string_compare");
} }
void C2_MacroAssembler::arrays_equals(Register a1, Register a2, Register tmp3, void C2_MacroAssembler::arrays_equals(Register a1, Register a2,
Register tmp4, Register tmp5, Register tmp6, Register result, Register tmp1, Register tmp2, Register tmp3,
Register cnt1, int elem_size) { Register result, int elem_size) {
Label DONE, SAME, NEXT_DWORD, SHORT, TAIL, TAIL2, IS_TMP5_ZR; assert(elem_size == 1 || elem_size == 2, "must be char or byte");
Register tmp1 = t0; assert_different_registers(a1, a2, result, tmp1, tmp2, tmp3, t0);
Register tmp2 = t1;
Register cnt2 = tmp2; // cnt2 only used in array length compare int elem_per_word = wordSize/elem_size;
Register elem_per_word = tmp6;
int log_elem_size = exact_log2(elem_size); int log_elem_size = exact_log2(elem_size);
int length_offset = arrayOopDesc::length_offset_in_bytes(); int length_offset = arrayOopDesc::length_offset_in_bytes();
int base_offset = arrayOopDesc::base_offset_in_bytes(elem_size == 2 ? T_CHAR : T_BYTE); int base_offset = arrayOopDesc::base_offset_in_bytes(elem_size == 2 ? T_CHAR : T_BYTE);
assert(elem_size == 1 || elem_size == 2, "must be char or byte"); Register cnt1 = tmp3;
assert_different_registers(a1, a2, result, cnt1, t0, t1, tmp3, tmp4, tmp5, tmp6); Register cnt2 = tmp1; // cnt2 only used in array length compare
mv(elem_per_word, wordSize / elem_size); Label DONE, SAME, NEXT_WORD, SHORT, TAIL03, TAIL01;
BLOCK_COMMENT("arrays_equals {"); BLOCK_COMMENT("arrays_equals {");
@ -1558,71 +1557,84 @@ void C2_MacroAssembler::arrays_equals(Register a1, Register a2, Register tmp3,
beq(a1, a2, SAME); beq(a1, a2, SAME);
mv(result, false); mv(result, false);
// if (a1 == nullptr || a2 == nullptr)
// return false;
beqz(a1, DONE); beqz(a1, DONE);
beqz(a2, DONE); beqz(a2, DONE);
// if (a1.length != a2.length)
// return false;
lwu(cnt1, Address(a1, length_offset)); lwu(cnt1, Address(a1, length_offset));
lwu(cnt2, Address(a2, length_offset)); lwu(cnt2, Address(a2, length_offset));
bne(cnt2, cnt1, DONE); bne(cnt1, cnt2, DONE);
beqz(cnt1, SAME);
slli(tmp5, cnt1, 3 + log_elem_size); la(a1, Address(a1, base_offset));
sub(tmp5, zr, tmp5); la(a2, Address(a2, base_offset));
add(a1, a1, base_offset); // Check for short strings, i.e. smaller than wordSize.
add(a2, a2, base_offset); addi(cnt1, cnt1, -elem_per_word);
ld(tmp3, Address(a1, 0)); bltz(cnt1, SHORT);
ld(tmp4, Address(a2, 0));
ble(cnt1, elem_per_word, SHORT); // short or same
// Main 16 byte comparison loop with 2 exits // Main 8 byte comparison loop.
bind(NEXT_DWORD); { bind(NEXT_WORD); {
ld(tmp1, Address(a1, wordSize)); ld(tmp1, Address(a1));
ld(tmp2, Address(a2, wordSize)); ld(tmp2, Address(a2));
sub(cnt1, cnt1, 2 * wordSize / elem_size); addi(cnt1, cnt1, -elem_per_word);
blez(cnt1, TAIL); addi(a1, a1, wordSize);
bne(tmp3, tmp4, DONE); addi(a2, a2, wordSize);
ld(tmp3, Address(a1, 2 * wordSize));
ld(tmp4, Address(a2, 2 * wordSize));
add(a1, a1, 2 * wordSize);
add(a2, a2, 2 * wordSize);
ble(cnt1, elem_per_word, TAIL2);
} beq(tmp1, tmp2, NEXT_DWORD);
j(DONE);
bind(TAIL);
xorr(tmp4, tmp3, tmp4);
xorr(tmp2, tmp1, tmp2);
sll(tmp2, tmp2, tmp5);
orr(tmp5, tmp4, tmp2);
j(IS_TMP5_ZR);
bind(TAIL2);
bne(tmp1, tmp2, DONE); bne(tmp1, tmp2, DONE);
} bgez(cnt1, NEXT_WORD);
addi(tmp1, cnt1, elem_per_word);
beqz(tmp1, SAME);
bind(SHORT); bind(SHORT);
xorr(tmp4, tmp3, tmp4); test_bit(tmp1, cnt1, 2 - log_elem_size);
sll(tmp5, tmp4, tmp5); beqz(tmp1, TAIL03); // 0-7 bytes left.
{
lwu(tmp1, Address(a1));
lwu(tmp2, Address(a2));
addi(a1, a1, 4);
addi(a2, a2, 4);
bne(tmp1, tmp2, DONE);
}
bind(IS_TMP5_ZR); bind(TAIL03);
bnez(tmp5, DONE); test_bit(tmp1, cnt1, 1 - log_elem_size);
beqz(tmp1, TAIL01); // 0-3 bytes left.
{
lhu(tmp1, Address(a1));
lhu(tmp2, Address(a2));
addi(a1, a1, 2);
addi(a2, a2, 2);
bne(tmp1, tmp2, DONE);
}
bind(TAIL01);
if (elem_size == 1) { // Only needed when comparing byte arrays.
test_bit(tmp1, cnt1, 0);
beqz(tmp1, SAME); // 0-1 bytes left.
{
lbu(tmp1, Address(a1));
lbu(tmp2, Address(a2));
bne(tmp1, tmp2, DONE);
}
}
bind(SAME); bind(SAME);
mv(result, true); mv(result, true);
// That's it. // That's it.
bind(DONE); bind(DONE);
BLOCK_COMMENT("} array_equals"); BLOCK_COMMENT("} arrays_equals");
} }
// Compare Strings // Compare Strings
// For Strings we're passed the address of the first characters in a1 // For Strings we're passed the address of the first characters in a1 and a2
// and a2 and the length in cnt1. // and the length in cnt1. There are two implementations.
// There are two implementations. For arrays >= 8 bytes, all // For arrays >= 8 bytes, all comparisons (except for the tail) are performed
// comparisons (for hw supporting unaligned access: including the final one, // 8 bytes at a time. For the tail, we compare a halfword, then a short, and then a byte.
// which may overlap) are performed 8 bytes at a time. // For strings < 8 bytes, we compare a halfword, then a short, and then a byte.
// For strings < 8 bytes (and for tails of long strings when
// AvoidUnalignedAccesses is true), we compare a
// halfword, then a short, and then a byte.
void C2_MacroAssembler::string_equals(Register a1, Register a2, void C2_MacroAssembler::string_equals(Register a1, Register a2,
Register result, Register cnt1) Register result, Register cnt1)
@ -1635,39 +1647,24 @@ void C2_MacroAssembler::string_equals(Register a1, Register a2,
BLOCK_COMMENT("string_equals {"); BLOCK_COMMENT("string_equals {");
beqz(cnt1, SAME);
mv(result, false); mv(result, false);
// Check for short strings, i.e. smaller than wordSize. // Check for short strings, i.e. smaller than wordSize.
sub(cnt1, cnt1, wordSize); addi(cnt1, cnt1, -wordSize);
bltz(cnt1, SHORT); bltz(cnt1, SHORT);
// Main 8 byte comparison loop. // Main 8 byte comparison loop.
bind(NEXT_WORD); { bind(NEXT_WORD); {
ld(tmp1, Address(a1, 0)); ld(tmp1, Address(a1));
add(a1, a1, wordSize); ld(tmp2, Address(a2));
ld(tmp2, Address(a2, 0)); addi(cnt1, cnt1, -wordSize);
add(a2, a2, wordSize); addi(a1, a1, wordSize);
sub(cnt1, cnt1, wordSize); addi(a2, a2, wordSize);
bne(tmp1, tmp2, DONE); bne(tmp1, tmp2, DONE);
} bgez(cnt1, NEXT_WORD); } bgez(cnt1, NEXT_WORD);
if (!AvoidUnalignedAccesses) { addi(tmp1, cnt1, wordSize);
// Last longword. In the case where length == 4 we compare the
// same longword twice, but that's still faster than another
// conditional branch.
// cnt1 could be 0, -1, -2, -3, -4 for chars; -4 only happens when
// length == 4.
add(tmp1, a1, cnt1);
ld(tmp1, Address(tmp1, 0));
add(tmp2, a2, cnt1);
ld(tmp2, Address(tmp2, 0));
bne(tmp1, tmp2, DONE);
j(SAME);
} else {
add(tmp1, cnt1, wordSize);
beqz(tmp1, SAME); beqz(tmp1, SAME);
}
bind(SHORT); bind(SHORT);
Label TAIL03, TAIL01; Label TAIL03, TAIL01;
@ -1676,10 +1673,10 @@ void C2_MacroAssembler::string_equals(Register a1, Register a2,
test_bit(tmp1, cnt1, 2); test_bit(tmp1, cnt1, 2);
beqz(tmp1, TAIL03); beqz(tmp1, TAIL03);
{ {
lwu(tmp1, Address(a1, 0)); lwu(tmp1, Address(a1));
add(a1, a1, 4); lwu(tmp2, Address(a2));
lwu(tmp2, Address(a2, 0)); addi(a1, a1, 4);
add(a2, a2, 4); addi(a2, a2, 4);
bne(tmp1, tmp2, DONE); bne(tmp1, tmp2, DONE);
} }
@ -1688,10 +1685,10 @@ void C2_MacroAssembler::string_equals(Register a1, Register a2,
test_bit(tmp1, cnt1, 1); test_bit(tmp1, cnt1, 1);
beqz(tmp1, TAIL01); beqz(tmp1, TAIL01);
{ {
lhu(tmp1, Address(a1, 0)); lhu(tmp1, Address(a1));
add(a1, a1, 2); lhu(tmp2, Address(a2));
lhu(tmp2, Address(a2, 0)); addi(a1, a1, 2);
add(a2, a2, 2); addi(a2, a2, 2);
bne(tmp1, tmp2, DONE); bne(tmp1, tmp2, DONE);
} }
@ -1700,8 +1697,8 @@ void C2_MacroAssembler::string_equals(Register a1, Register a2,
test_bit(tmp1, cnt1, 0); test_bit(tmp1, cnt1, 0);
beqz(tmp1, SAME); beqz(tmp1, SAME);
{ {
lbu(tmp1, Address(a1, 0)); lbu(tmp1, Address(a1));
lbu(tmp2, Address(a2, 0)); lbu(tmp2, Address(a2));
bne(tmp1, tmp2, DONE); bne(tmp1, tmp2, DONE);
} }

View File

@ -79,16 +79,15 @@
int needle_con_cnt, Register result, int ae); int needle_con_cnt, Register result, int ae);
void arrays_equals(Register r1, Register r2, void arrays_equals(Register r1, Register r2,
Register tmp3, Register tmp4, Register tmp1, Register tmp2, Register tmp3,
Register tmp5, Register tmp6, Register result, int elem_size);
Register result, Register cnt1,
int elem_size);
void arrays_hashcode(Register ary, Register cnt, Register result, void arrays_hashcode(Register ary, Register cnt, Register result,
Register tmp1, Register tmp2, Register tmp1, Register tmp2,
Register tmp3, Register tmp4, Register tmp3, Register tmp4,
Register tmp5, Register tmp6, Register tmp5, Register tmp6,
BasicType eltype); BasicType eltype);
// helper function for arrays_hashcode // helper function for arrays_hashcode
int arrays_hashcode_elsize(BasicType eltype); int arrays_hashcode_elsize(BasicType eltype);
void arrays_hashcode_elload(Register dst, Address src, BasicType eltype); void arrays_hashcode_elload(Register dst, Address src, BasicType eltype);

View File

@ -3286,17 +3286,6 @@ operand iRegP_R15()
interface(REG_INTER); interface(REG_INTER);
%} %}
operand iRegP_R16()
%{
constraint(ALLOC_IN_RC(r16_reg));
match(RegP);
// match(iRegP);
match(iRegPNoSp);
op_cost(0);
format %{ %}
interface(REG_INTER);
%}
// Pointer 64 bit Register R28 only // Pointer 64 bit Register R28 only
operand iRegP_R28() operand iRegP_R28()
%{ %{
@ -10336,35 +10325,33 @@ instruct string_equalsL(iRegP_R11 str1, iRegP_R13 str2, iRegI_R14 cnt,
%} %}
instruct array_equalsB(iRegP_R11 ary1, iRegP_R12 ary2, iRegI_R10 result, instruct array_equalsB(iRegP_R11 ary1, iRegP_R12 ary2, iRegI_R10 result,
iRegP_R13 tmp1, iRegP_R14 tmp2, iRegP_R15 tmp3, iRegP_R13 tmp1, iRegP_R14 tmp2, iRegP_R15 tmp3)
iRegP_R16 tmp4, iRegP_R28 tmp5, rFlagsReg cr)
%{ %{
predicate(!UseRVV && ((AryEqNode*)n)->encoding() == StrIntrinsicNode::LL); predicate(!UseRVV && ((AryEqNode*)n)->encoding() == StrIntrinsicNode::LL);
match(Set result (AryEq ary1 ary2)); match(Set result (AryEq ary1 ary2));
effect(USE_KILL ary1, USE_KILL ary2, TEMP tmp1, TEMP tmp2, TEMP tmp3, TEMP tmp4, KILL tmp5, KILL cr); effect(USE_KILL ary1, USE_KILL ary2, TEMP tmp1, TEMP tmp2, TEMP tmp3);
format %{ "Array Equals $ary1, ary2 -> $result\t#@array_equalsB // KILL $tmp5" %} format %{ "Array Equals $ary1, $ary2 -> $result\t#@array_equalsB // KILL all" %}
ins_encode %{ ins_encode %{
__ arrays_equals($ary1$$Register, $ary2$$Register, __ arrays_equals($ary1$$Register, $ary2$$Register,
$tmp1$$Register, $tmp2$$Register, $tmp3$$Register, $tmp4$$Register, $tmp1$$Register, $tmp2$$Register, $tmp3$$Register,
$result$$Register, $tmp5$$Register, 1); $result$$Register, 1);
%} %}
ins_pipe(pipe_class_memory); ins_pipe(pipe_class_memory);
%} %}
instruct array_equalsC(iRegP_R11 ary1, iRegP_R12 ary2, iRegI_R10 result, instruct array_equalsC(iRegP_R11 ary1, iRegP_R12 ary2, iRegI_R10 result,
iRegP_R13 tmp1, iRegP_R14 tmp2, iRegP_R15 tmp3, iRegP_R13 tmp1, iRegP_R14 tmp2, iRegP_R15 tmp3)
iRegP_R16 tmp4, iRegP_R28 tmp5, rFlagsReg cr)
%{ %{
predicate(!UseRVV && ((AryEqNode*)n)->encoding() == StrIntrinsicNode::UU); predicate(!UseRVV && ((AryEqNode*)n)->encoding() == StrIntrinsicNode::UU);
match(Set result (AryEq ary1 ary2)); match(Set result (AryEq ary1 ary2));
effect(USE_KILL ary1, USE_KILL ary2, TEMP tmp1, TEMP tmp2, TEMP tmp3, TEMP tmp4, KILL tmp5, KILL cr); effect(USE_KILL ary1, USE_KILL ary2, TEMP tmp1, TEMP tmp2, TEMP tmp3);
format %{ "Array Equals $ary1, ary2 -> $result\t#@array_equalsC // KILL $tmp5" %} format %{ "Array Equals $ary1, $ary2 -> $result\t#@array_equalsC // KILL all" %}
ins_encode %{ ins_encode %{
__ arrays_equals($ary1$$Register, $ary2$$Register, __ arrays_equals($ary1$$Register, $ary2$$Register,
$tmp1$$Register, $tmp2$$Register, $tmp3$$Register, $tmp4$$Register, $tmp1$$Register, $tmp2$$Register, $tmp3$$Register,
$result$$Register, $tmp5$$Register, 2); $result$$Register, 2);
%} %}
ins_pipe(pipe_class_memory); ins_pipe(pipe_class_memory);
%} %}