8308118: Avoid multiarray allocations in AESCrypt.makeSessionKey

Reviewed-by: xuelei
This commit is contained in:
Aleksey Shipilev 2023-05-19 06:53:50 +00:00
parent 97ade57fb2
commit 6765761075
2 changed files with 96 additions and 56 deletions

View File

@ -1061,37 +1061,6 @@ final class AESCrypt extends SymmetricCipher implements AESConstants {
this.K = sessionK[(decrypting? 1:0)]; this.K = sessionK[(decrypting? 1:0)];
} }
/**
* Expand an int[(ROUNDS+1)][4] into int[(ROUNDS+1)*4].
* For decryption round keys, need to rotate right by 4 ints.
* @param kr The round keys for encryption or decryption.
* @param decrypting True if 'kr' is for decryption and false otherwise.
*/
private static final int[] expandToSubKey(int[][] kr, boolean decrypting) {
int total = kr.length;
int[] expK = new int[total*4];
if (decrypting) {
// decrypting, rotate right by 4 ints
// i.e. i==0
for(int j=0; j<4; j++) {
expK[j] = kr[total-1][j];
}
for(int i=1; i<total; i++) {
for(int j=0; j<4; j++) {
expK[i*4 + j] = kr[i-1][j];
}
}
} else {
// encrypting, straight expansion
for(int i=0; i<total; i++) {
for(int j=0; j<4; j++) {
expK[i*4 + j] = kr[i][j];
}
}
}
return expK;
}
// check if the specified length (in bytes) is a valid keysize for AES // check if the specified length (in bytes) is a valid keysize for AES
static boolean isKeySizeValid(int len) { static boolean isKeySizeValid(int len) {
for (int aesKeysize : AES_KEYSIZES) { for (int aesKeysize : AES_KEYSIZES) {
@ -1361,12 +1330,13 @@ final class AESCrypt extends SymmetricCipher implements AESConstants {
k.length + " bytes"); k.length + " bytes");
} }
int ROUNDS = getRounds(k.length); final int BC = 4;
int ROUND_KEY_COUNT = (ROUNDS + 1) * 4;
int BC = 4; int ROUNDS = getRounds(k.length);
int[][] Ke = new int[ROUNDS + 1][4]; // encryption round keys int ROUND_KEY_COUNT = (ROUNDS + 1) * BC;
int[][] Kd = new int[ROUNDS + 1][4]; // decryption round keys
int[] Ke = new int[ROUND_KEY_COUNT]; // encryption round keys
int[] Kd = new int[ROUND_KEY_COUNT]; // decryption round keys
int KC = k.length/4; // keylen in 32-bit elements int KC = k.length/4; // keylen in 32-bit elements
@ -1384,8 +1354,8 @@ final class AESCrypt extends SymmetricCipher implements AESConstants {
// copy values into round key arrays // copy values into round key arrays
int t = 0; int t = 0;
for (j = 0; (j < KC) && (t < ROUND_KEY_COUNT); j++, t++) { for (j = 0; (j < KC) && (t < ROUND_KEY_COUNT); j++, t++) {
Ke[t / 4][t % 4] = tk[j]; Ke[t] = tk[j];
Kd[ROUNDS - (t / 4)][t % 4] = tk[j]; Kd[(ROUNDS - (t / BC))*BC + (t % BC)] = tk[j];
} }
int tt, rconpointer = 0; int tt, rconpointer = 0;
while (t < ROUND_KEY_COUNT) { while (t < ROUND_KEY_COUNT) {
@ -1409,32 +1379,35 @@ final class AESCrypt extends SymmetricCipher implements AESConstants {
} }
// copy values into round key arrays // copy values into round key arrays
for (j = 0; (j < KC) && (t < ROUND_KEY_COUNT); j++, t++) { for (j = 0; (j < KC) && (t < ROUND_KEY_COUNT); j++, t++) {
Ke[t / 4][t % 4] = tk[j]; Ke[t] = tk[j];
Kd[ROUNDS - (t / 4)][t % 4] = tk[j]; Kd[(ROUNDS - (t / BC))*BC + (t % BC)] = tk[j];
} }
} }
for (int r = 1; r < ROUNDS; r++) { for (int r = 1; r < ROUNDS; r++) {
// inverse MixColumn where needed // inverse MixColumn where needed
for (j = 0; j < BC; j++) { for (j = 0; j < BC; j++) {
tt = Kd[r][j]; int idx = r*BC + j;
Kd[r][j] = U1[(tt >>> 24) & 0xFF] ^ tt = Kd[idx];
U2[(tt >>> 16) & 0xFF] ^ Kd[idx] = U1[(tt >>> 24) & 0xFF] ^
U3[(tt >>> 8) & 0xFF] ^ U2[(tt >>> 16) & 0xFF] ^
U4[ tt & 0xFF]; U3[(tt >>> 8) & 0xFF] ^
U4[ tt & 0xFF];
} }
} }
// assemble the encryption (Ke) and decryption (Kd) round keys // For decryption round keys, need to rotate right by 4 ints.
// and expand them into arrays of ints. // Do that without allocating and zeroing the small buffer.
int[] expandedKe = expandToSubKey(Ke, false); // decrypting==false int KdTail_0 = Kd[Kd.length - 4];
int[] expandedKd = expandToSubKey(Kd, true); // decrypting==true int KdTail_1 = Kd[Kd.length - 3];
int KdTail_2 = Kd[Kd.length - 2];
int KdTail_3 = Kd[Kd.length - 1];
System.arraycopy(Kd, 0, Kd, 4, Kd.length - 4);
Kd[0] = KdTail_0;
Kd[1] = KdTail_1;
Kd[2] = KdTail_2;
Kd[3] = KdTail_3;
Arrays.fill(tk, 0); Arrays.fill(tk, 0);
for (int[] ia: Ke) {
Arrays.fill(ia, 0);
}
for (int[] ia: Kd) {
Arrays.fill(ia, 0);
}
ROUNDS_12 = (ROUNDS>=12); ROUNDS_12 = (ROUNDS>=12);
ROUNDS_14 = (ROUNDS==14); ROUNDS_14 = (ROUNDS==14);
limit = ROUNDS*4; limit = ROUNDS*4;
@ -1444,8 +1417,11 @@ final class AESCrypt extends SymmetricCipher implements AESConstants {
// erase the previous values in sessionK // erase the previous values in sessionK
Arrays.fill(sessionK[0], 0); Arrays.fill(sessionK[0], 0);
Arrays.fill(sessionK[1], 0); Arrays.fill(sessionK[1], 0);
} else {
sessionK = new int[2][];
} }
sessionK = new int[][] { expandedKe, expandedKd }; sessionK[0] = Ke;
sessionK[1] = Kd;
} }
/** /**

View File

@ -0,0 +1,64 @@
/*
* Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.javax.crypto;
import org.openjdk.jmh.annotations.*;
import javax.crypto.Cipher;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.util.Random;
import java.util.concurrent.TimeUnit;
@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Fork(value = 3, jvmArgsAppend = {"-Xms1g", "-Xmx1g"})
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Benchmark)
public class AESReinit {
private Cipher cipher;
private Random random;
byte[] key = new byte[16];
byte[] iv = new byte[16];
@Setup
public void prepare() throws Exception {
random = new Random();
cipher = Cipher.getInstance("AES/GCM/NoPadding");
key = new byte[16];
iv = new byte[16];
}
@Benchmark
public void test() throws Exception {
random.nextBytes(key);
random.nextBytes(iv);
SecretKeySpec secretKey = new SecretKeySpec(key, "AES");
GCMParameterSpec param = new GCMParameterSpec(128, iv);
cipher.init(Cipher.ENCRYPT_MODE, secretKey, param);
}
}