8313760: [REDO] Enhance AES performance

Co-authored-by: Andrew Haley <aph@openjdk.org>
Reviewed-by: adinn, aph, sviswanathan, rhalade, kvn, dlong
Christian Hagedorn 2023-08-16 07:21:04 +00:00
parent d46f0fb318
commit 49ddb19972
7 changed files with 106 additions and 37 deletions


@@ -2944,6 +2944,23 @@ class StubGenerator: public StubCodeGenerator {
     return start;
   }
 
+  // Big-endian 128-bit + 64-bit -> 128-bit addition.
+  // Inputs: 128-bits. in is preserved.
+  // The least-significant 64-bit word is in the upper dword of each vector.
+  // inc (the 64-bit increment) is preserved. Its lower dword must be zero.
+  // Output: result
+  void be_add_128_64(FloatRegister result, FloatRegister in,
+                     FloatRegister inc, FloatRegister tmp) {
+    assert_different_registers(result, tmp, inc);
+
+    __ addv(result, __ T2D, in, inc);      // Add inc to the least-significant dword of
+                                           // input
+    __ cm(__ HI, tmp, __ T2D, inc, result);// Check for result overflowing
+    __ ext(tmp, __ T16B, tmp, tmp, 0x08);  // Swap LSD of comparison result to MSD and
+                                           // MSD == 0 (must be!) to LSD
+    __ subv(result, __ T2D, result, tmp);  // Subtract -1 from MSD if there was an overflow
+  }
+
   // CTR AES crypt.
   // Arguments:
   //
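Note (illustration, not part of the patch): a minimal scalar C++ sketch of what be_add_128_64 computes, assuming the big-endian counter has already been byte-reversed into two native 64-bit words; the struct and function names are made up for this sketch.

    #include <cstdint>

    struct Counter128 { uint64_t hi; uint64_t lo; };   // most / least significant 64 bits

    // Add a 64-bit increment to the low word and carry into the high word:
    // addv does the low-word add, cm(HI)/ext detect and position the carry,
    // and subv of the all-ones mask adds 1 to the high word on overflow.
    static Counter128 be_add_128_64_model(Counter128 in, uint64_t inc) {
      Counter128 r;
      r.lo = in.lo + inc;
      r.hi = in.hi + ((r.lo < inc) ? 1 : 0);   // unsigned wrap-around => carry
      return r;
    }
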
@@ -3053,13 +3070,16 @@ class StubGenerator: public StubCodeGenerator {
     // Setup the counter
     __ movi(v4, __ T4S, 0);
     __ movi(v5, __ T4S, 1);
-    __ ins(v4, __ S, v5, 3, 3); // v4 contains { 0, 0, 0, 1 }
+    __ ins(v4, __ S, v5, 2, 2); // v4 contains { 0, 1 }
 
-    __ ld1(v0, __ T16B, counter); // Load the counter into v0
-    __ rev32(v16, __ T16B, v0);
-    __ addv(v16, __ T4S, v16, v4);
-    __ rev32(v16, __ T16B, v16);
-    __ st1(v16, __ T16B, counter); // Save the incremented counter back
+    // 128-bit big-endian increment
+    __ ld1(v0, __ T16B, counter);
+    __ rev64(v16, __ T16B, v0);
+    be_add_128_64(v16, v16, v4, /*tmp*/v5);
+    __ rev64(v16, __ T16B, v16);
+    __ st1(v16, __ T16B, counter);
+    // Previous counter value is in v0
+    // v4 contains { 0, 1 }
 
     {
       // We have fewer than bulk_width blocks of data left. Encrypt
@@ -3091,9 +3111,9 @@ class StubGenerator: public StubCodeGenerator {
 
       // Increment the counter, store it back
       __ orr(v0, __ T16B, v16, v16);
-      __ rev32(v16, __ T16B, v16);
-      __ addv(v16, __ T4S, v16, v4);
-      __ rev32(v16, __ T16B, v16);
+      __ rev64(v16, __ T16B, v16);
+      be_add_128_64(v16, v16, v4, /*tmp*/v5);
+      __ rev64(v16, __ T16B, v16);
       __ st1(v16, __ T16B, counter); // Save the incremented counter back
 
       __ b(inner_loop);
@@ -3141,7 +3161,7 @@ class StubGenerator: public StubCodeGenerator {
     // Keys should already be loaded into the correct registers
 
     __ ld1(v0, __ T16B, counter); // v0 contains the first counter
-    __ rev32(v16, __ T16B, v0); // v16 contains byte-reversed counter
+    __ rev64(v16, __ T16B, v0); // v16 contains byte-reversed counter
 
     // AES/CTR loop
     {
@@ -3151,12 +3171,12 @@ class StubGenerator: public StubCodeGenerator {
       // Setup the counters
       __ movi(v8, __ T4S, 0);
       __ movi(v9, __ T4S, 1);
-      __ ins(v8, __ S, v9, 3, 3); // v8 contains { 0, 0, 0, 1 }
+      __ ins(v8, __ S, v9, 2, 2); // v8 contains { 0, 1 }
 
       for (int i = 0; i < bulk_width; i++) {
         FloatRegister v0_ofs = as_FloatRegister(v0->encoding() + i);
-        __ rev32(v0_ofs, __ T16B, v16);
-        __ addv(v16, __ T4S, v16, v8);
+        __ rev64(v0_ofs, __ T16B, v16);
+        be_add_128_64(v16, v16, v8, /*tmp*/v9);
       }
 
       __ ld1(v8, v9, v10, v11, __ T16B, __ post(in, 4 * 16));
@@ -3186,7 +3206,7 @@ class StubGenerator: public StubCodeGenerator {
       }
 
       // Save the counter back where it goes
-      __ rev32(v16, __ T16B, v16);
+      __ rev64(v16, __ T16B, v16);
       __ st1(v16, __ T16B, counter);
 
       __ pop(saved_regs, sp);


@@ -4431,6 +4431,14 @@ void Assembler::evpcmpuw(KRegister kdst, XMMRegister nds, XMMRegister src, Compa
   emit_int24(0x3E, (0xC0 | encode), vcc);
 }
 
+void Assembler::evpcmpuq(KRegister kdst, XMMRegister nds, XMMRegister src, ComparisonPredicate vcc, int vector_len) {
+  assert(VM_Version::supports_avx512vl(), "");
+  InstructionAttr attributes(vector_len, /* rex_w */ true, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true);
+  attributes.set_is_evex_instruction();
+  int encode = vex_prefix_and_encode(kdst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_66, VEX_OPCODE_0F_3A, &attributes);
+  emit_int24(0x1E, (0xC0 | encode), vcc);
+}
+
 void Assembler::evpcmpuw(KRegister kdst, XMMRegister nds, Address src, ComparisonPredicate vcc, int vector_len) {
   assert(VM_Version::supports_avx512vlbw(), "");
   InstructionMark im(this);
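Note (usage sketch, not part of the patch): the new evpcmpuq overload follows the existing evpcmpuw pattern; the register choices below are illustrative only.

    // Set one bit of k1 per 64-bit element of zmm0 that is unsigned-less-than
    // the corresponding element of zmm1.
    __ evpcmpuq(k1, xmm0, xmm1, Assembler::lt, Assembler::AVX_512bit);
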


@@ -1806,6 +1806,8 @@ private:
   void evpcmpuw(KRegister kdst, XMMRegister nds, XMMRegister src, ComparisonPredicate vcc, int vector_len);
   void evpcmpuw(KRegister kdst, XMMRegister nds, Address src, ComparisonPredicate vcc, int vector_len);
 
+  void evpcmpuq(KRegister kdst, XMMRegister nds, XMMRegister src, ComparisonPredicate vcc, int vector_len);
+
   void pcmpeqw(XMMRegister dst, XMMRegister src);
   void vpcmpeqw(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len);
   void evpcmpeqw(KRegister kdst, XMMRegister nds, XMMRegister src, int vector_len);


@@ -9257,6 +9257,17 @@ void MacroAssembler::evpandq(XMMRegister dst, XMMRegister nds, AddressLiteral sr
   }
 }
 
+void MacroAssembler::evpaddq(XMMRegister dst, KRegister mask, XMMRegister nds, AddressLiteral src, bool merge, int vector_len, Register rscratch) {
+  assert(rscratch != noreg || always_reachable(src), "missing");
+
+  if (reachable(src)) {
+    Assembler::evpaddq(dst, mask, nds, as_Address(src), merge, vector_len);
+  } else {
+    lea(rscratch, src);
+    Assembler::evpaddq(dst, mask, nds, Address(rscratch, 0), merge, vector_len);
+  }
+}
+
 void MacroAssembler::evporq(XMMRegister dst, XMMRegister nds, AddressLiteral src, int vector_len, Register rscratch) {
   assert(rscratch != noreg || always_reachable(src), "missing");
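Note (usage sketch, not part of the patch): like the other AddressLiteral wrappers here, the new evpaddq overload uses the literal directly when it is RIP-reachable and otherwise materializes it through the scratch register; the operands below are illustrative, with mask_addr standing in for any 64-byte-aligned constant.

    // Merge-masked add of the 512-bit constant at mask_addr into zmm8 under k1;
    // r15 is clobbered only if the literal is not reachable.
    __ evpaddq(xmm8, k1, xmm8, ExternalAddress(mask_addr), /*merge*/true,
               Assembler::AVX_512bit, r15 /*rscratch*/);
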


@@ -1788,6 +1788,9 @@ public:
   using Assembler::evpandq;
   void evpandq(XMMRegister dst, XMMRegister nds, AddressLiteral src, int vector_len, Register rscratch = noreg);
 
+  using Assembler::evpaddq;
+  void evpaddq(XMMRegister dst, KRegister mask, XMMRegister nds, AddressLiteral src, bool merge, int vector_len, Register rscratch = noreg);
+
   using Assembler::evporq;
   void evporq(XMMRegister dst, XMMRegister nds, AddressLiteral src, int vector_len, Register rscratch = noreg);


@@ -364,7 +364,8 @@ class StubGenerator: public StubCodeGenerator {
   // Utility routine for increase 128bit counter (iv in CTR mode)
   void inc_counter(Register reg, XMMRegister xmmdst, int inc_delta, Label& next_block);
+  void ev_add128(XMMRegister xmmdst, XMMRegister xmmsrc1, XMMRegister xmmsrc2,
+                 int vector_len, KRegister ktmp, Register rscratch = noreg);
 
   void generate_aes_stubs();


@@ -121,6 +121,16 @@ static address counter_mask_linc32_addr() {
   return (address)COUNTER_MASK_LINC32;
 }
 
+ATTRIBUTE_ALIGNED(64) uint64_t COUNTER_MASK_ONES[] = {
+  0x0000000000000000UL, 0x0000000000000001UL,
+  0x0000000000000000UL, 0x0000000000000001UL,
+  0x0000000000000000UL, 0x0000000000000001UL,
+  0x0000000000000000UL, 0x0000000000000001UL,
+};
+static address counter_mask_ones_addr() {
+  return (address)COUNTER_MASK_ONES;
+}
+
 ATTRIBUTE_ALIGNED(64) static const uint64_t GHASH_POLYNOMIAL_REDUCTION[] = {
   0x00000001C2000000UL, 0xC200000000000000UL,
   0x00000001C2000000UL, 0xC200000000000000UL,
@@ -1623,6 +1633,17 @@ void StubGenerator::ev_load_key(XMMRegister xmmdst, Register key, int offset, Re
   __ evshufi64x2(xmmdst, xmmdst, xmmdst, 0x0, Assembler::AVX_512bit);
 }
 
+// Add 128-bit integers in xmmsrc1 to xmmsrc2, then place the result in xmmdst.
+// Clobber ktmp and rscratch.
+// Used by aesctr_encrypt.
+void StubGenerator::ev_add128(XMMRegister xmmdst, XMMRegister xmmsrc1, XMMRegister xmmsrc2,
+                              int vector_len, KRegister ktmp, Register rscratch) {
+  __ vpaddq(xmmdst, xmmsrc1, xmmsrc2, vector_len);
+  __ evpcmpuq(ktmp, xmmdst, xmmsrc2, __ lt, vector_len);
+  __ kshiftlbl(ktmp, ktmp, 1);
+  __ evpaddq(xmmdst, ktmp, xmmdst, ExternalAddress(counter_mask_ones_addr()), /*merge*/true,
+             vector_len, rscratch);
+}
+
 // AES-ECB Encrypt Operation
 void StubGenerator::aesecb_encrypt(Register src_addr, Register dest_addr, Register key, Register len) {
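Note (illustration, not HotSpot code): for the AVX_512bit case, ev_add128 treats each 128-bit lane as a {lo, hi} pair of qwords; below is a scalar C++ sketch of the carry handling that the vpaddq / evpcmpuq(lt) / kshiftlbl / masked-evpaddq sequence implements. The names are made up for this sketch.

    #include <cstdint>

    struct Lane128 { uint64_t lo; uint64_t hi; };   // one 128-bit lane, two qwords

    // Per-lane 128-bit add across a 512-bit register (4 lanes): plain qword
    // adds, then a carry from lo into hi whenever the low add wrapped around
    // (the bit evpcmpuq(lt) sets, shifted up by kshiftlbl and applied by the
    // masked add of COUNTER_MASK_ONES).
    static void ev_add128_model(Lane128 dst[4], const Lane128 a[4], const Lane128 b[4]) {
      for (int i = 0; i < 4; i++) {
        dst[i].lo = a[i].lo + b[i].lo;
        dst[i].hi = a[i].hi + b[i].hi;
        if (dst[i].lo < b[i].lo) {                  // unsigned wrap in the low qword
          dst[i].hi += 1;
        }
      }
    }
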
@@ -2046,7 +2067,6 @@ void StubGenerator::aesecb_decrypt(Register src_addr, Register dest_addr, Regist
 }
 
 // AES Counter Mode using VAES instructions
-
 void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Register key, Register counter,
                                    Register len_reg, Register used, Register used_addr, Register saved_encCounter_start) {
@@ -2104,14 +2124,17 @@ void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Regist
   // The counter is incremented after each block i.e. 16 bytes is processed;
   // each zmm register has 4 counter values as its MSB
   // the counters are incremented in parallel
-  __ vpaddd(xmm8, xmm8, ExternalAddress(counter_mask_linc0_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
-  __ vpaddd(xmm9, xmm8, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
-  __ vpaddd(xmm10, xmm9, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
-  __ vpaddd(xmm11, xmm10, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
-  __ vpaddd(xmm12, xmm11, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
-  __ vpaddd(xmm13, xmm12, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
-  __ vpaddd(xmm14, xmm13, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
-  __ vpaddd(xmm15, xmm14, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
+  __ evmovdquq(xmm19, ExternalAddress(counter_mask_linc0_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
+  ev_add128(xmm8, xmm8, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
+  __ evmovdquq(xmm19, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
+  ev_add128(xmm9, xmm8, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
+  ev_add128(xmm10, xmm9, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
+  ev_add128(xmm11, xmm10, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
+  ev_add128(xmm12, xmm11, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
+  ev_add128(xmm13, xmm12, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
+  ev_add128(xmm14, xmm13, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
+  ev_add128(xmm15, xmm14, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
+
   // load linc32 mask in zmm register.linc32 increments counter by 32
   __ evmovdquq(xmm19, ExternalAddress(counter_mask_linc32_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
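Note (illustration, not part of the patch): the intent of this setup is that zmm8..zmm15 end up holding the starting counter plus {0..3}, {4..7}, ..., {28..31}, four 128-bit counters per register. A scalar sketch of those values, assuming GCC/Clang's unsigned __int128 and an illustrative function name:

    // regs[j][lane] models the lane-th 128-bit counter held in zmm(8 + j).
    static void staggered_counters(unsigned __int128 start, unsigned __int128 regs[8][4]) {
      for (int j = 0; j < 8; j++) {
        for (int lane = 0; lane < 4; lane++) {
          regs[j][lane] = start + (unsigned)(4 * j + lane);   // zmm8: +0..+3, zmm9: +4..+7, ...
        }
      }
    }
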
@@ -2159,21 +2182,21 @@ void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Regist
   // This is followed by incrementing counter values in zmm8-zmm15.
   // Since we will be processing 32 blocks at a time, the counter is incremented by 32.
   roundEnc(xmm21, 7);
-  __ vpaddq(xmm8, xmm8, xmm19, Assembler::AVX_512bit);
+  ev_add128(xmm8, xmm8, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
   roundEnc(xmm22, 7);
-  __ vpaddq(xmm9, xmm9, xmm19, Assembler::AVX_512bit);
+  ev_add128(xmm9, xmm9, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
   roundEnc(xmm23, 7);
-  __ vpaddq(xmm10, xmm10, xmm19, Assembler::AVX_512bit);
+  ev_add128(xmm10, xmm10, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
   roundEnc(xmm24, 7);
-  __ vpaddq(xmm11, xmm11, xmm19, Assembler::AVX_512bit);
+  ev_add128(xmm11, xmm11, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
   roundEnc(xmm25, 7);
-  __ vpaddq(xmm12, xmm12, xmm19, Assembler::AVX_512bit);
+  ev_add128(xmm12, xmm12, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
   roundEnc(xmm26, 7);
-  __ vpaddq(xmm13, xmm13, xmm19, Assembler::AVX_512bit);
+  ev_add128(xmm13, xmm13, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
   roundEnc(xmm27, 7);
-  __ vpaddq(xmm14, xmm14, xmm19, Assembler::AVX_512bit);
+  ev_add128(xmm14, xmm14, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
   roundEnc(xmm28, 7);
-  __ vpaddq(xmm15, xmm15, xmm19, Assembler::AVX_512bit);
+  ev_add128(xmm15, xmm15, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
   roundEnc(xmm29, 7);
 
   __ cmpl(rounds, 52);
@@ -2251,8 +2274,8 @@ void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Regist
   __ vpshufb(xmm3, xmm11, xmm16, Assembler::AVX_512bit);
   __ evpxorq(xmm3, xmm3, xmm20, Assembler::AVX_512bit);
   // Increment counter values by 16
-  __ vpaddq(xmm8, xmm8, xmm19, Assembler::AVX_512bit);
-  __ vpaddq(xmm9, xmm9, xmm19, Assembler::AVX_512bit);
+  ev_add128(xmm8, xmm8, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
+  ev_add128(xmm9, xmm9, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
   // AES encode rounds
   roundEnc(xmm21, 3);
   roundEnc(xmm22, 3);
@@ -2319,7 +2342,7 @@ void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Regist
   __ vpshufb(xmm1, xmm9, xmm16, Assembler::AVX_512bit);
   __ evpxorq(xmm1, xmm1, xmm20, Assembler::AVX_512bit);
   // increment counter by 8
-  __ vpaddq(xmm8, xmm8, xmm19, Assembler::AVX_512bit);
+  ev_add128(xmm8, xmm8, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
   // AES encode
   roundEnc(xmm21, 1);
   roundEnc(xmm22, 1);
@@ -2376,8 +2399,9 @@ void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Regist
   // XOR counter with first roundkey
   __ vpshufb(xmm0, xmm8, xmm16, Assembler::AVX_512bit);
   __ evpxorq(xmm0, xmm0, xmm20, Assembler::AVX_512bit);
   // Increment counter
-  __ vpaddq(xmm8, xmm8, xmm19, Assembler::AVX_512bit);
+  ev_add128(xmm8, xmm8, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
+
   __ vaesenc(xmm0, xmm0, xmm21, Assembler::AVX_512bit);
   __ vaesenc(xmm0, xmm0, xmm22, Assembler::AVX_512bit);
   __ vaesenc(xmm0, xmm0, xmm23, Assembler::AVX_512bit);
@@ -2427,7 +2451,7 @@ void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Regist
   __ evpxorq(xmm0, xmm0, xmm20, Assembler::AVX_128bit);
   __ vaesenc(xmm0, xmm0, xmm21, Assembler::AVX_128bit);
   // Increment counter by 1
-  __ vpaddq(xmm8, xmm8, xmm19, Assembler::AVX_128bit);
+  ev_add128(xmm8, xmm8, xmm19, Assembler::AVX_128bit, /*ktmp*/k1, r15 /*rscratch*/);
   __ vaesenc(xmm0, xmm0, xmm22, Assembler::AVX_128bit);
   __ vaesenc(xmm0, xmm0, xmm23, Assembler::AVX_128bit);
   __ vaesenc(xmm0, xmm0, xmm24, Assembler::AVX_128bit);