8300808: Accelerate Base64 on x86 for AVX2

Reviewed-by: jbhateja, redestad, sviswanathan
Scott Gibbons 2023-02-15 09:26:10 +00:00 committed by Claes Redestad
parent 46bcc4901e
commit 33bec20710
7 changed files with 229 additions and 26 deletions


@ -1642,7 +1642,6 @@ address StubGenerator::generate_base64_encodeBlock()
// calculate length from offsets
__ movl(length, end_offset);
__ subl(length, start_offset);
__ cmpl(length, 0);
__ jcc(Assembler::lessEqual, L_exit);
// Code for 512-bit VBMI encoding. Encodes 48 input bytes into 64
@ -1685,8 +1684,7 @@ address StubGenerator::generate_base64_encodeBlock()
}
__ BIND(L_not512);
if (VM_Version::supports_avx2()
&& VM_Version::supports_avx512vlbw()) {
if (VM_Version::supports_avx2()) {
/*
** This AVX2 encoder is based off the paper at:
** https://dl.acm.org/doi/10.1145/3132709
@ -1703,15 +1701,17 @@ address StubGenerator::generate_base64_encodeBlock()
__ vmovdqu(xmm9, ExternalAddress(StubRoutines::x86::base64_avx2_shuffle_addr()), rax);
// 6-bit mask for 2nd and 4th (and multiples) 6-bit values
__ movl(rax, 0x0fc0fc00);
__ movdl(xmm8, rax);
__ vmovdqu(xmm1, ExternalAddress(StubRoutines::x86::base64_avx2_input_mask_addr()), rax);
__ evpbroadcastd(xmm8, rax, Assembler::AVX_256bit);
__ vpbroadcastd(xmm8, xmm8, Assembler::AVX_256bit);
// Multiplication constant for "shifting" right by 6 and 10
// bits
__ movl(rax, 0x04000040);
__ subl(length, 24);
__ evpbroadcastd(xmm7, rax, Assembler::AVX_256bit);
__ movdl(xmm7, rax);
__ vpbroadcastd(xmm7, xmm7, Assembler::AVX_256bit);
// For the first load, we mask off reading of the first 4
// bytes into the register. This is so we can get 4 3-byte
@ -1813,19 +1813,23 @@ address StubGenerator::generate_base64_encodeBlock()
// Load masking register for first and third (and multiples)
// 6-bit values.
__ movl(rax, 0x003f03f0);
__ evpbroadcastd(xmm6, rax, Assembler::AVX_256bit);
__ movdl(xmm6, rax);
__ vpbroadcastd(xmm6, xmm6, Assembler::AVX_256bit);
// Multiplication constant for "shifting" left by 4 and 8 bits
__ movl(rax, 0x01000010);
__ evpbroadcastd(xmm5, rax, Assembler::AVX_256bit);
__ movdl(xmm5, rax);
__ vpbroadcastd(xmm5, xmm5, Assembler::AVX_256bit);
// Isolate 6-bit chunks of interest
__ vpand(xmm0, xmm8, xmm1, Assembler::AVX_256bit);
// Load constants for encoding
__ movl(rax, 0x19191919);
__ evpbroadcastd(xmm3, rax, Assembler::AVX_256bit);
__ movdl(xmm3, rax);
__ vpbroadcastd(xmm3, xmm3, Assembler::AVX_256bit);
__ movl(rax, 0x33333333);
__ evpbroadcastd(xmm4, rax, Assembler::AVX_256bit);
__ movdl(xmm4, rax);
__ vpbroadcastd(xmm4, xmm4, Assembler::AVX_256bit);
// Shift output bytes 0 and 2 into proper lanes
__ vpmulhuw(xmm2, xmm0, xmm7, Assembler::AVX_256bit);
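Here vpmulhuw against the 0x04000040 constant behaves as a per-16-bit-word logical right shift by 6 and 10 bits, and the 0x003f03f0 / 0x01000010 pair plays the same role for the left shifts by 4 and 8 bits noted above, so each output byte ends up holding one 6-bit index. A minimal AVX2-intrinsics sketch of this field extraction (illustrative only, not the stub's code; it assumes the input has already been shuffled so each 32-bit lane holds the bytes [b1, b0, b2, b1] as in the cited paper):

#include <immintrin.h>

// Sketch: extract the four 6-bit Base64 indices from each 32-bit lane.
// Assumes 'in' was shuffled so each lane holds [b1, b0, b2, b1] (low to high).
static inline __m256i base64_extract_indices(__m256i in) {
    // Fields A (bits 10..15 of word 0) and C (bits 6..11 of word 1)
    __m256i ac = _mm256_and_si256(in, _mm256_set1_epi32(0x0fc0fc00));
    // Multiply-high by 2^6 / 2^10 lands A and C in bits 0..5 of each word
    ac = _mm256_mulhi_epu16(ac, _mm256_set1_epi32(0x04000040));
    // Fields B (bits 4..9 of word 0) and D (bits 0..5 of word 1)
    __m256i bd = _mm256_and_si256(in, _mm256_set1_epi32(0x003f03f0));
    // Multiply-low by 2^4 / 2^8 lands B and D in bits 8..13 of each word
    bd = _mm256_mullo_epi16(bd, _mm256_set1_epi32(0x01000010));
    // Each byte of the result is now one index: [A, B, C, D] per lane
    return _mm256_or_si256(ac, bd);
}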
@ -2133,6 +2137,80 @@ address StubGenerator::base64_vbmi_join_2_3_addr() {
return start;
}
address StubGenerator::base64_AVX2_decode_tables_addr() {
__ align64();
StubCodeMark mark(this, "StubRoutines", "AVX2_tables_base64");
address start = __ pc();
assert(((unsigned long long)start & 0x3f) == 0,
"Alignment problem (0x%08llx)", (unsigned long long)start);
__ emit_data(0x2f2f2f2f, relocInfo::none, 0);
__ emit_data(0x5f5f5f5f, relocInfo::none, 0); // for URL
__ emit_data(0xffffffff, relocInfo::none, 0);
__ emit_data(0xfcfcfcfc, relocInfo::none, 0); // for URL
// Permute table
__ emit_data64(0x0000000100000000, relocInfo::none);
__ emit_data64(0x0000000400000002, relocInfo::none);
__ emit_data64(0x0000000600000005, relocInfo::none);
__ emit_data64(0xffffffffffffffff, relocInfo::none);
// Shuffle table
__ emit_data64(0x090a040506000102, relocInfo::none);
__ emit_data64(0xffffffff0c0d0e08, relocInfo::none);
__ emit_data64(0x090a040506000102, relocInfo::none);
__ emit_data64(0xffffffff0c0d0e08, relocInfo::none);
// merge table
__ emit_data(0x01400140, relocInfo::none, 0);
// merge multiplier
__ emit_data(0x00011000, relocInfo::none, 0);
return start;
}
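The last two dwords are the constants for the merge step used by the decode loop below: vpmaddubsw against 0x01400140 combines each pair of adjacent 6-bit values into a 12-bit word, and vpmaddwd against 0x00011000 then combines the two words into the 24 decoded bits of the lane. A scalar sketch of that arithmetic for one dword holding the decoded values [a, b, c, d] (illustration only):

#include <cstdint>

// Sketch: merge four 6-bit values (one per byte, low to high) into 24 bits.
static inline uint32_t merge_sextets(uint32_t v) {
    uint32_t a =  v        & 0x3f;
    uint32_t b = (v >> 8)  & 0x3f;
    uint32_t c = (v >> 16) & 0x3f;
    uint32_t d = (v >> 24) & 0x3f;
    uint32_t w0 = a * 0x40 + b;      // vpmaddubsw with 0x01400140, word 0
    uint32_t w1 = c * 0x40 + d;      // vpmaddubsw with 0x01400140, word 1
    return w0 * 0x1000 + w1;         // vpmaddwd with 0x00011000
}

The permute and shuffle tables above then move those 24-bit groups into contiguous output byte lanes.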
address StubGenerator::base64_AVX2_decode_LUT_tables_addr() {
__ align64();
StubCodeMark mark(this, "StubRoutines", "AVX2_tables_URL_base64");
address start = __ pc();
assert(((unsigned long long)start & 0x3f) == 0,
"Alignment problem (0x%08llx)", (unsigned long long)start);
// lut_lo
__ emit_data64(0x1111111111111115, relocInfo::none);
__ emit_data64(0x1a1b1b1b1a131111, relocInfo::none);
__ emit_data64(0x1111111111111115, relocInfo::none);
__ emit_data64(0x1a1b1b1b1a131111, relocInfo::none);
// lut_roll
__ emit_data64(0xb9b9bfbf04131000, relocInfo::none);
__ emit_data64(0x0000000000000000, relocInfo::none);
__ emit_data64(0xb9b9bfbf04131000, relocInfo::none);
__ emit_data64(0x0000000000000000, relocInfo::none);
// lut_lo URL
__ emit_data64(0x1111111111111115, relocInfo::none);
__ emit_data64(0x1b1b1a1b1b131111, relocInfo::none);
__ emit_data64(0x1111111111111115, relocInfo::none);
__ emit_data64(0x1b1b1a1b1b131111, relocInfo::none);
// lut_roll URL
__ emit_data64(0xb9b9bfbf0411e000, relocInfo::none);
__ emit_data64(0x0000000000000000, relocInfo::none);
__ emit_data64(0xb9b9bfbf0411e000, relocInfo::none);
__ emit_data64(0x0000000000000000, relocInfo::none);
// lut_hi
__ emit_data64(0x0804080402011010, relocInfo::none);
__ emit_data64(0x1010101010101010, relocInfo::none);
__ emit_data64(0x0804080402011010, relocInfo::none);
__ emit_data64(0x1010101010101010, relocInfo::none);
return start;
}
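These tables drive the nibble-based classification from the same paper: lut_lo is indexed by a character's low nibble and lut_hi by its high nibble, and a non-zero AND of the two lookups flags an invalid character; lut_roll, indexed by the (adjusted) high nibble, gives the offset that turns a valid ASCII character into its 6-bit value. A scalar sketch for one byte of the standard alphabet (illustration only; '=' padding and the URL variant are omitted):

#include <cstdint>

static const uint8_t lut_lo[16]   = {0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
                                     0x11, 0x11, 0x13, 0x1a, 0x1b, 0x1b, 0x1b, 0x1a};
static const uint8_t lut_hi[16]   = {0x10, 0x10, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
                                     0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10};
static const int8_t  lut_roll[16] = {0, 16, 19, 4, -65, -65, -71, -71,
                                     0, 0, 0, 0, 0, 0, 0, 0};

// Returns the 6-bit value of 'ch', or -1 if 'ch' is not in the standard alphabet.
static int decode_one(uint8_t ch) {
    uint8_t hi = ch >> 4, lo = ch & 0x0f;
    if (lut_lo[lo] & lut_hi[hi])                 // vpshufb pair + vptest in the stub
        return -1;
    int idx = hi + (ch == 0x2f ? -1 : 0);        // '/' borrows the roll slot at index 1
    return (uint8_t)(ch + lut_roll[idx]);        // e.g. 'A' (0x41) + (-65) == 0
}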
address StubGenerator::base64_decoding_table_addr() {
StubCodeMark mark(this, "StubRoutines", "decoding_table_base64");
address start = __ pc();
@ -2289,7 +2367,7 @@ address StubGenerator::generate_base64_decodeBlock() {
Label L_process256, L_process64, L_process64Loop, L_exit, L_processdata, L_loadURL;
Label L_continue, L_finalBit, L_padding, L_donePadding, L_bruteForce;
Label L_forceLoop, L_bottomLoop, L_checkMIME, L_exit_no_vzero;
Label L_forceLoop, L_bottomLoop, L_checkMIME, L_exit_no_vzero, L_lastChunk;
// calculate length from offsets
__ movl(length, end_offset);
@ -2299,11 +2377,11 @@ address StubGenerator::generate_base64_decodeBlock() {
// If AVX512 VBMI not supported, just compile non-AVX code
if(VM_Version::supports_avx512_vbmi() &&
VM_Version::supports_avx512bw()) {
__ cmpl(length, 128); // 128-bytes is break-even for AVX-512
__ jcc(Assembler::lessEqual, L_bruteForce);
__ cmpl(length, 31); // 32-bytes is break-even for AVX-512
__ jcc(Assembler::lessEqual, L_lastChunk);
__ cmpl(isMIME, 0);
__ jcc(Assembler::notEqual, L_bruteForce);
__ jcc(Assembler::notEqual, L_lastChunk);
// Load lookup tables based on isURL
__ cmpl(isURL, 0);
@ -2554,6 +2632,89 @@ address StubGenerator::generate_base64_decodeBlock() {
__ BIND(L_bruteForce);
} // End of if(avx512_vbmi)
if (VM_Version::supports_avx2()) {
Label L_tailProc, L_topLoop, L_enterLoop;
__ cmpl(isMIME, 0);
__ jcc(Assembler::notEqual, L_lastChunk);
// Check for buffer too small (for algorithm)
__ subl(length, 0x2c);
__ jcc(Assembler::less, L_tailProc);
__ shll(isURL, 2);
// Algorithm adapted from https://arxiv.org/abs/1704.00605, "Faster Base64
// Encoding and Decoding using AVX2 Instructions". URL modifications added.
// Set up constants
__ lea(r13, ExternalAddress(StubRoutines::x86::base64_AVX2_decode_tables_addr()));
__ vpbroadcastd(xmm4, Address(r13, isURL, Address::times_1), Assembler::AVX_256bit); // 2F or 5F
__ vpbroadcastd(xmm10, Address(r13, isURL, Address::times_1, 0x08), Assembler::AVX_256bit); // -1 or -4
__ vmovdqu(xmm12, Address(r13, 0x10)); // permute
__ vmovdqu(xmm13, Address(r13, 0x30)); // shuffle
__ vpbroadcastd(xmm7, Address(r13, 0x50), Assembler::AVX_256bit); // merge
__ vpbroadcastd(xmm6, Address(r13, 0x54), Assembler::AVX_256bit); // merge mult
__ lea(r13, ExternalAddress(StubRoutines::x86::base64_AVX2_decode_LUT_tables_addr()));
__ shll(isURL, 4);
__ vmovdqu(xmm11, Address(r13, isURL, Address::times_1, 0x00)); // lut_lo
__ vmovdqu(xmm8, Address(r13, isURL, Address::times_1, 0x20)); // lut_roll
__ shrl(isURL, 6); // restore isURL
__ vmovdqu(xmm9, Address(r13, 0x80)); // lut_hi
__ jmp(L_enterLoop);
__ align32();
__ bind(L_topLoop);
// Add in the offset value (roll) to get 6-bit out values
__ vpaddb(xmm0, xmm0, xmm2, Assembler::AVX_256bit);
// Merge and permute the output bits into appropriate output byte lanes
__ vpmaddubsw(xmm0, xmm0, xmm7, Assembler::AVX_256bit);
__ vpmaddwd(xmm0, xmm0, xmm6, Assembler::AVX_256bit);
__ vpshufb(xmm0, xmm0, xmm13, Assembler::AVX_256bit);
__ vpermd(xmm0, xmm12, xmm0, Assembler::AVX_256bit);
// Store the output bytes
__ vmovdqu(Address(dest, dp, Address::times_1, 0), xmm0);
__ addptr(source, 0x20);
__ addptr(dest, 0x18);
__ subl(length, 0x20);
__ jcc(Assembler::less, L_tailProc);
__ bind(L_enterLoop);
// Load in encoded string (32 bytes)
__ vmovdqu(xmm2, Address(source, start_offset, Address::times_1, 0x0));
// Extract the high nibble for indexing into the lut tables. High 4 bits are don't care.
__ vpsrld(xmm1, xmm2, 0x4, Assembler::AVX_256bit);
__ vpand(xmm1, xmm4, xmm1, Assembler::AVX_256bit);
// Extract the low nibble. 5F/2F will isolate the low-order 4 bits. High 4 bits are don't care.
__ vpand(xmm3, xmm2, xmm4, Assembler::AVX_256bit);
// Check for special-case (0x2F or 0x5F (URL))
__ vpcmpeqb(xmm0, xmm4, xmm2, Assembler::AVX_256bit);
// Get the bitset based on the low nibble. vpshufb uses low-order 4 bits only.
__ vpshufb(xmm3, xmm11, xmm3, Assembler::AVX_256bit);
// Get the bit value of the high nibble
__ vpshufb(xmm5, xmm9, xmm1, Assembler::AVX_256bit);
// Make sure 2F / 5F shows as valid
__ vpandn(xmm3, xmm0, xmm3, Assembler::AVX_256bit);
// Make adjustment for roll index. For non-URL, this is a no-op,
// for URL, this adjusts by -4. This is to properly index the
// roll value for 2F / 5F.
__ vpand(xmm0, xmm0, xmm10, Assembler::AVX_256bit);
// If the and of the two is non-zero, we have an invalid input character
__ vptest(xmm3, xmm5);
// Extract the "roll" value - value to add to the input to get 6-bit out value
__ vpaddb(xmm0, xmm0, xmm1, Assembler::AVX_256bit); // Handle 2F / 5F
__ vpshufb(xmm0, xmm8, xmm0, Assembler::AVX_256bit);
__ jcc(Assembler::equal, L_topLoop); // Fall through on error
__ bind(L_tailProc);
__ addl(length, 0x2c);
__ vzeroupper();
}
// Use non-AVX code to decode 4-byte chunks into 3 bytes of output
// Register state (Linux):
@ -2584,6 +2745,8 @@ address StubGenerator::generate_base64_decodeBlock() {
const Register byte3 = WIN64_ONLY(r8) NOT_WIN64(rdx);
const Register byte4 = WIN64_ONLY(r10) NOT_WIN64(r9);
__ bind(L_lastChunk);
__ shrl(length, 2); // Multiple of 4 bytes only - length is # 4-byte chunks
__ cmpl(length, 0);
__ jcc(Assembler::lessEqual, L_exit_no_vzero);
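The tail handler entered at L_lastChunk works through whole 4-character chunks with scalar table lookups; roughly, per chunk (a simplified sketch that assumes table[] maps every valid character to 0..63 and ignores the invalid-input and '=' padding handling the real stub performs):

#include <cstddef>
#include <cstdint>

// Sketch of the scalar tail: decode 'chunks' groups of 4 characters into
// 3 output bytes each, returning the number of bytes written.
static size_t decode_tail(const uint8_t* src, size_t chunks, uint8_t* dst,
                          const uint8_t table[256]) {
    size_t dp = 0;
    for (size_t i = 0; i < chunks; i++, src += 4) {
        uint32_t bits = (uint32_t)table[src[0]] << 18 |
                        (uint32_t)table[src[1]] << 12 |
                        (uint32_t)table[src[2]] << 6  |
                        (uint32_t)table[src[3]];
        dst[dp++] = (uint8_t)(bits >> 16);
        dst[dp++] = (uint8_t)(bits >> 8);
        dst[dp++] = (uint8_t)bits;
    }
    return dp;
}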
@ -3829,12 +3992,12 @@ void StubGenerator::generate_all() {
}
if (UseBASE64Intrinsics) {
if(VM_Version::supports_avx2() &&
VM_Version::supports_avx512bw() &&
VM_Version::supports_avx512vl()) {
if(VM_Version::supports_avx2()) {
StubRoutines::x86::_avx2_shuffle_base64 = base64_avx2_shuffle_addr();
StubRoutines::x86::_avx2_input_mask_base64 = base64_avx2_input_mask_addr();
StubRoutines::x86::_avx2_lut_base64 = base64_avx2_lut_addr();
StubRoutines::x86::_avx2_decode_tables_base64 = base64_AVX2_decode_tables_addr();
StubRoutines::x86::_avx2_decode_lut_tables_base64 = base64_AVX2_decode_LUT_tables_addr();
}
StubRoutines::x86::_encoding_table_base64 = base64_encoding_table_addr();
if (VM_Version::supports_avx512_vbmi()) {


@ -441,6 +441,8 @@ class StubGenerator: public StubCodeGenerator {
address base64_vbmi_join_1_2_addr();
address base64_vbmi_join_2_3_addr();
address base64_decoding_table_addr();
address base64_AVX2_decode_tables_addr();
address base64_AVX2_decode_LUT_tables_addr();
// Code for generating Base64 decoding.
//


@ -71,6 +71,8 @@ address StubRoutines::x86::_shuffle_base64 = NULL;
address StubRoutines::x86::_avx2_shuffle_base64 = NULL;
address StubRoutines::x86::_avx2_input_mask_base64 = NULL;
address StubRoutines::x86::_avx2_lut_base64 = NULL;
address StubRoutines::x86::_avx2_decode_tables_base64 = NULL;
address StubRoutines::x86::_avx2_decode_lut_tables_base64 = NULL;
address StubRoutines::x86::_lookup_lo_base64 = NULL;
address StubRoutines::x86::_lookup_hi_base64 = NULL;
address StubRoutines::x86::_lookup_lo_base64url = NULL;


@ -185,6 +185,8 @@ class x86 {
static address _avx2_shuffle_base64;
static address _avx2_input_mask_base64;
static address _avx2_lut_base64;
static address _avx2_decode_tables_base64;
static address _avx2_decode_lut_tables_base64;
static address _lookup_lo_base64;
static address _lookup_hi_base64;
static address _lookup_lo_base64url;
@ -325,6 +327,8 @@ class x86 {
static address base64_vbmi_join_1_2_addr() { return _join_1_2_base64; }
static address base64_vbmi_join_2_3_addr() { return _join_2_3_base64; }
static address base64_decoding_table_addr() { return _decoding_table_base64; }
static address base64_AVX2_decode_tables_addr() { return _avx2_decode_tables_base64; }
static address base64_AVX2_decode_LUT_tables_addr() { return _avx2_decode_lut_tables_base64; }
#endif
static address pshuffle_byte_flip_mask_addr() { return _pshuffle_byte_flip_mask_addr; }
static address arrays_hashcode_powers_of_31() { return (address)_arrays_hashcode_powers_of_31; }


@ -1140,7 +1140,7 @@ void VM_Version::get_processor_features() {
}
// Base64 Intrinsics (Check the condition for which the intrinsic will be active)
if ((UseAVX > 2) && supports_avx512vl() && supports_avx512bw()) {
if (UseAVX >= 2) {
if (FLAG_IS_DEFAULT(UseBASE64Intrinsics)) {
UseBASE64Intrinsics = true;
}


@ -31,18 +31,18 @@ import java.util.Random;
import java.util.ArrayList;
import java.util.concurrent.TimeUnit;
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Thread)
@Warmup(iterations = 4, time = 2)
@Measurement(iterations = 4, time = 2)
@Fork(value = 3)
public class Base64Decode {
private Base64.Encoder encoder, mimeEncoder;
private Base64.Decoder decoder, mimeDecoder;
private ArrayList<byte[]> encoded, mimeEncoded, errorEncoded;
private byte[] decoded, mimeDecoded, errorDecoded;
private Base64.Encoder encoder, mimeEncoder, urlEncoder;
private Base64.Decoder decoder, mimeDecoder, urlDecoder;
private ArrayList<byte[]> encoded, mimeEncoded, urlEncoded, errorEncoded;
private byte[] decoded, mimeDecoded, urlDecoded, errorDecoded;
private static final int TESTSIZE = 1000;
@ -60,6 +60,9 @@ public class Base64Decode {
@Param({"144"})
private int errorIndex;
@Param({"0"})
private int addSpecial;
@Setup
public void setup() {
Random r = new Random(1123);
@ -74,6 +77,11 @@ public class Base64Decode {
mimeDecoder = Base64.getMimeDecoder();
mimeEncoded = new ArrayList<byte[]> ();
urlDecoded = new byte[maxNumBytes + 1];
urlEncoder = Base64.getUrlEncoder();
urlDecoder = Base64.getUrlDecoder();
urlEncoded = new ArrayList<byte[]> ();
errorDecoded = new byte[errorIndex + 100];
errorEncoded = new ArrayList<byte[]> ();
@ -83,6 +91,10 @@ public class Base64Decode {
byte[] dst = new byte[((srcLen + 2) / 3) * 4];
r.nextBytes(src);
encoder.encode(src, dst);
if(addSpecial != 0){
dst[0] = '/';
dst[1] = '+';
}
encoded.add(dst);
int mimeSrcLen = 1 + r.nextInt(maxNumBytes);
@ -92,13 +104,24 @@ public class Base64Decode {
mimeEncoder.encode(mimeSrc, mimeDst);
mimeEncoded.add(mimeDst);
int urlSrcLen = 1 + r.nextInt(maxNumBytes);
byte[] urlSrc = new byte[urlSrcLen];
byte[] urlDst = new byte[((urlSrcLen + 2) / 3) * 4];
r.nextBytes(urlSrc);
urlEncoder.encode(urlSrc, urlDst);
if(addSpecial != 0){
urlDst[0] = '_';
urlDst[1] = '-';
}
urlEncoded.add(urlDst);
int errorSrcLen = errorIndex + r.nextInt(100);
byte[] errorSrc = new byte[errorSrcLen];
byte[] errorDst = new byte[(errorSrcLen + 2) / 3 * 4];
r.nextBytes(errorSrc);
encoder.encode(errorSrc, errorDst);
errorEncoded.add(errorDst);
errorDst[errorIndex] = (byte) '?';
errorEncoded.add(errorDst);
}
}
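The new addSpecial parameter defaults to 0; when set non-zero (for example by passing -p addSpecial=1 to a standalone JMH run), the setup code above overwrites the first two encoded characters with '/' and '+' (or '_' and '-' for the URL data set), so the decoder's special-case handling of those characters is exercised by the decode benchmarks below.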
@ -120,6 +143,15 @@ public class Base64Decode {
}
}
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testBase64URLDecode(Blackhole bh) {
for (byte[] s : urlEncoded) {
urlDecoder.decode(s, urlDecoded);
bh.consume(urlDecoded);
}
}
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testBase64WithErrorInputsDecode (Blackhole bh) {


@ -31,8 +31,8 @@ import java.util.Random;
import java.util.ArrayList;
import java.util.concurrent.TimeUnit;
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Thread)
@Warmup(iterations = 4, time = 2)
@Measurement(iterations = 4, time = 2)