jdk-24/test/jdk/javax/crypto/KEM/RSA_KEM.java
Weijun Wang 6b90b0519e 8297878: KEM: Implementation
Reviewed-by: ascarpino, mullan
2023-05-30 16:29:19 +00:00

502 lines
21 KiB
Java

/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @bug 8297878
* @summary RSA_KEM example
* @modules java.base/sun.security.jca
* java.base/sun.security.rsa
* java.base/sun.security.util
*/
import sun.security.jca.JCAUtil;
import sun.security.rsa.RSACore;
import sun.security.util.*;
import javax.crypto.*;
import javax.crypto.spec.*;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.InvalidParameterSpecException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
// This test implements RSA-KEM as described in RFC 5990. In this KEM, the
// sender configures the encapsulator with an RSAKEMParameterSpec object.
// This object is encoded as a byte array and included in the Encapsulated
// output. The receiver is then able to recover the same RSAKEMParameterSpec
// object from the encoding using an AlgorithmParameters implementation
// and use the object to configure the decapsulator.
public class RSA_KEM {
/**
 * Runs the RSA-KEM example end to end.
 *
 * <p>Phase 1 pushes every parameter spec through the "RSA-KEM"
 * AlgorithmParameters service (spec -> encoding -> spec) and compares the
 * string form of the result with the string form of the input.
 *
 * <p>Phase 2, for each RSA key size and each spec, plays both roles: the
 * sender encrypts a message under a fresh AES key, wraps that key with the
 * encapsulated secret, and the receiver recovers everything from the
 * encapsulation, the encoded parameters and the private key.
 *
 * @param args ignored
 * @throws Exception on any provider, cipher or encoding failure
 * @throws RuntimeException if a round trip produces a mismatch
 */
public static void main(String[] args) throws Exception {
    final Provider provider = new ProviderImpl();
    final RSAKEMParameterSpec[] kspecs = {
        RSAKEMParameterSpec.kdf1("SHA-256", "AES_128/KW/NoPadding"),
        RSAKEMParameterSpec.kdf1("SHA-512", "AES_256/KW/NoPadding"),
        RSAKEMParameterSpec.kdf2("SHA-256", "AES_128/KW/NoPadding"),
        RSAKEMParameterSpec.kdf2("SHA-512", "AES_256/KW/NoPadding"),
        RSAKEMParameterSpec.kdf3("SHA-256", new byte[10], "AES_128/KW/NoPadding"),
        RSAKEMParameterSpec.kdf3("SHA-256", new byte[0], "AES_128/KW/NoPadding"),
        RSAKEMParameterSpec.kdf3("SHA-512", new byte[0], "AES_128/KW/NoPadding"),
    };

    // Phase 1: parameter spec -> DER -> parameter spec.
    for (RSAKEMParameterSpec original : kspecs) {
        System.err.println("---------");
        System.err.println(original);

        AlgorithmParameters encoder = AlgorithmParameters.getInstance("RSA-KEM", provider);
        encoder.init(original);
        AlgorithmParameters decoder = AlgorithmParameters.getInstance("RSA-KEM", provider);
        decoder.init(encoder.getEncoded());

        String recovered = decoder
                .getParameterSpec(AlgorithmParameterSpec.class)
                .toString();
        if (!recovered.equals(original.toString())) {
            throw new RuntimeException(recovered);
        }
    }

    // Phase 2: full sender/receiver exchange.
    final byte[] msg = "hello".getBytes(StandardCharsets.UTF_8);
    final byte[] iv = new byte[16];
    final int[] keySizes = {1024, 2048};

    for (int size : keySizes) {
        KeyPairGenerator rsaGen = KeyPairGenerator.getInstance("RSA");
        rsaGen.initialize(size);
        KeyPair kp = rsaGen.generateKeyPair();

        for (RSAKEMParameterSpec kspec : kspecs) {
            String wrapAlg = kspec.getEncAlg();

            // --- sender ---
            SecretKey cek = KeyGenerator.getInstance("AES").generateKey();
            KEM senderKem = KEM.getInstance("RSA-KEM", provider);

            Cipher contentEnc = Cipher.getInstance("AES/CBC/PKCS5Padding");
            contentEnc.init(Cipher.ENCRYPT_MODE, cek, new IvParameterSpec(iv));
            byte[] ciphertext = contentEnc.doFinal(msg);

            KEM.Encapsulator e = senderKem.newEncapsulator(kp.getPublic(), kspec, null);
            KEM.Encapsulated enc = e.encapsulate(0, e.secretSize(), "AES");

            Cipher wrapper = Cipher.getInstance(wrapAlg);
            wrapper.init(Cipher.WRAP_MODE, enc.key());
            byte[] wrappedCek = wrapper.wrap(cek);

            // --- receiver: only sees enc.params(), enc.encapsulation(),
            //     wrappedCek and ciphertext ---
            AlgorithmParameters received = AlgorithmParameters.getInstance("RSA-KEM", provider);
            received.init(enc.params());

            KEM receiverKem = KEM.getInstance("RSA-KEM", provider);
            KEM.Decapsulator d = receiverKem.newDecapsulator(
                    kp.getPrivate(),
                    received.getParameterSpec(AlgorithmParameterSpec.class));
            SecretKey k = d.decapsulate(enc.encapsulation(), 0, d.secretSize(), "AES");

            Cipher unwrapper = Cipher.getInstance(wrapAlg);
            unwrapper.init(Cipher.UNWRAP_MODE, k);
            SecretKey recoveredCek =
                    (SecretKey) unwrapper.unwrap(wrappedCek, "AES", Cipher.SECRET_KEY);

            Cipher contentDec = Cipher.getInstance("AES/CBC/PKCS5Padding");
            contentDec.init(Cipher.DECRYPT_MODE, recoveredCek, new IvParameterSpec(iv));
            byte[] cleartext = contentDec.doFinal(ciphertext);

            boolean roundTripOk = Arrays.equals(cleartext, msg);
            if (!roundTripOk) {
                throw new RuntimeException();
            }

            System.out.printf("%4d %20s - %11d %11d %11d %11d %s\n",
                    size, kspec,
                    e.secretSize(), e.encapsulationSize(),
                    d.secretSize(), d.encapsulationSize(), k.getAlgorithm());
        }
    }
}
// Dotted OID string registered by ProviderImpl (bare and with an "OID."
// prefix) as an alias for the "RSA-KEM" services. The same value is held as
// an ObjectIdentifier in AlgorithmParametersImpl.id_rsa_kem.
static final String RSA_KEM = "1.2.840.113549.1.9.16.3.14";
// Dotted OID string with the same value as AlgorithmParametersImpl.id_kem_rsa,
// the identifier written into / checked in the encoded parameters. This
// String constant itself is not referenced elsewhere in the visible code.
static final String KEM_RSA = "1.0.18033.2.2.4";
/**
 * Minimal provider ("MYKEM", version "1") exposing the two services this
 * test needs under the algorithm name "RSA-KEM": the KEM itself and the
 * matching AlgorithmParameters. Both are also reachable through the
 * RSA_KEM OID, with and without the "OID." prefix.
 */
public static class ProviderImpl extends Provider {

    public ProviderImpl() {
        super("MYKEM", "1", "RSA-KEM");

        final List<String> oidAliases = List.of(RSA_KEM, "OID." + RSA_KEM);
        final Map<String, String> attributes =
                Map.of("SupportedKeyClasses", "java.security.interfaces.RSAKey");

        register("KEM", "RSA_KEM$KEMImpl", oidAliases, attributes);
        register("AlgorithmParameters", "RSA_KEM$AlgorithmParametersImpl",
                oidAliases, attributes);
    }

    /** Adds one "RSA-KEM" service of the given type, backed by {@code implClass}. */
    private void register(String type, String implClass,
            List<String> aliases, Map<String, String> attributes) {
        putService(new Service(this, type, "RSA-KEM", implClass, aliases, attributes));
    }
}
/**
 * AlgorithmParameters SPI for "RSA-KEM". It converts between an
 * {@link RSAKEMParameterSpec} and the DER encoding produced by
 * {@link #encode(RSAKEMParameterSpec)}.
 *
 * <p>Layout written by {@code encode} and expected by {@code decode}:
 * <pre>
 * SEQUENCE {
 *   SEQUENCE {
 *     OID id_kem_rsa,
 *     SEQUENCE {
 *       SEQUENCE { OID kdf, SEQUENCE { OID hash } },
 *       INTEGER kdfLen,
 *       OCTET STRING fixedInfo   -- only when fixedInfo is non-null
 *     }
 *   },
 *   SEQUENCE { OID encAlg }
 * }
 * </pre>
 */
public static class AlgorithmParametersImpl extends AlgorithmParametersSpi {

    /** Current parameters; stays null until one of the engineInit methods succeeds. */
    RSAKEMParameterSpec spec;

    /**
     * Initializes from a parameter spec.
     *
     * @throws InvalidParameterSpecException if {@code paramSpec} is not an
     *         {@code RSAKEMParameterSpec}
     */
    @Override
    protected void engineInit(AlgorithmParameterSpec paramSpec)
            throws InvalidParameterSpecException {
        if (paramSpec instanceof RSAKEMParameterSpec rspec) {
            spec = rspec;
        } else {
            throw new InvalidParameterSpecException();
        }
    }

    /**
     * Initializes from an encoding in the layout described on this class.
     *
     * @throws IOException if the encoding cannot be parsed or is inconsistent
     */
    @Override
    protected void engineInit(byte[] params) throws IOException {
        spec = decode(params);
    }

    /**
     * Same as {@link #engineInit(byte[])}; the {@code format} argument is
     * not consulted.
     */
    @Override
    protected void engineInit(byte[] params, String format) throws IOException {
        spec = decode(params);
    }

    /**
     * Returns the current spec cast to {@code paramSpec}.
     *
     * @throws InvalidParameterSpecException if {@code paramSpec} is not
     *         {@code RSAKEMParameterSpec} or one of its supertypes
     */
    @Override
    protected <T extends AlgorithmParameterSpec> T engineGetParameterSpec(
            Class<T> paramSpec) throws InvalidParameterSpecException {
        if (paramSpec.isAssignableFrom(RSAKEMParameterSpec.class)) {
            return paramSpec.cast(spec);
        } else {
            throw new InvalidParameterSpecException();
        }
    }

    /** Returns the encoding of the current spec. */
    @Override
    protected byte[] engineGetEncoded() {
        return encode(spec);
    }

    /** Same as {@link #engineGetEncoded()}; {@code format} is not consulted. */
    @Override
    protected byte[] engineGetEncoded(String format) {
        return encode(spec);
    }

    /** Returns the spec's string form, or {@code "<null>"} before initialization. */
    @Override
    protected String engineToString() {
        return spec == null ? "<null>" : spec.toString();
    }

    static final ObjectIdentifier id_rsa_kem;
    static final ObjectIdentifier id_kem_rsa;
    static final ObjectIdentifier id_kdf1;
    static final ObjectIdentifier id_kdf2;
    static final ObjectIdentifier id_kdf3;

    static {
        try {
            id_rsa_kem = ObjectIdentifier.of("1.2.840.113549.1.9.16.3.14");
            id_kem_rsa = ObjectIdentifier.of("1.0.18033.2.2.4");
            id_kdf1 = ObjectIdentifier.of("1.3.133.16.840.9.44.1.0"); // fake
            id_kdf2 = ObjectIdentifier.of("1.3.133.16.840.9.44.1.1");
            id_kdf3 = ObjectIdentifier.of("1.3.133.16.840.9.44.1.2");
        } catch (IOException e) {
            // The literals above are fixed, so a parse failure is a coding error.
            throw new AssertionError(e);
        }
    }

    /**
     * Encodes {@code spec} in the layout described on this class.
     *
     * @param spec the parameters to encode
     * @return the DER bytes
     * @throws UnsupportedOperationException if the KDF, hash or wrap
     *         algorithm name cannot be mapped to an OID
     */
    static byte[] encode(RSAKEMParameterSpec spec) {
        DerOutputStream kdf = new DerOutputStream()
                .write(DerValue.tag_Sequence, new DerOutputStream()
                        .putOID(oid4(spec.kdfAlg))
                        .write(DerValue.tag_Sequence, new DerOutputStream()
                                .putOID(oid4(spec.hashAlg))))
                .putInteger(spec.kdfLen());
        // The next line is not in RFC 5990
        if (spec.fixedInfo != null) {
            kdf.putOctetString(spec.fixedInfo);
        }
        return new DerOutputStream()
                .write(DerValue.tag_Sequence, new DerOutputStream()
                        .write(DerValue.tag_Sequence, new DerOutputStream()
                                .putOID(id_kem_rsa)
                                .write(DerValue.tag_Sequence, kdf))
                        .write(DerValue.tag_Sequence, new DerOutputStream()
                                .putOID(oid4(spec.encAlg)))).toByteArray();
    }

    /**
     * Parses an encoding produced by {@link #encode(RSAKEMParameterSpec)}.
     *
     * <p>Hash and wrap algorithm OIDs are translated to standard names via
     * {@code KnownOIDs}; an OID with no match is kept as its dotted string.
     * The encoded kdfLen must agree with the key size implied by the wrap
     * algorithm name.
     *
     * @param der the encoded parameters
     * @return the decoded parameters
     * @throws IOException if the structure is malformed, the KEM or KDF OID
     *         is not recognized, the wrap algorithm's key size cannot be
     *         determined, or kdfLen does not match it
     */
    static RSAKEMParameterSpec decode(byte[] der) throws IOException {
        String kdfAlg, encAlg, hashAlg;
        int kdfLen;
        byte[] fixedInfo;

        // d2: content of the outermost SEQUENCE.
        DerInputStream d2 = new DerValue(der).toDerInputStream();
        // d3: { id_kem_rsa, kdf parameters }.
        DerInputStream d3 = d2.getDerValue().toDerInputStream();
        if (!d3.getOID().equals(id_kem_rsa)) {
            throw new IOException("not id_kem_rsa");
        }
        // d4: { kdf AlgorithmIdentifier, kdfLen, [fixedInfo] }.
        DerInputStream d4 = d3.getDerValue().toDerInputStream();
        // d5: { kdf OID, { hash OID } }.
        DerInputStream d5 = d4.getDerValue().toDerInputStream();
        kdfLen = d4.getInteger();
        fixedInfo = d4.available() > 0 ? d4.getOctetString() : null;
        // NOTE(review): atEnd() is presumably the internal DerInputStream
        // check that rejects trailing bytes - confirm against sun.security.util.
        d4.atEnd();

        ObjectIdentifier kdfOid = d5.getOID();
        if (kdfOid.equals(id_kdf1)) {
            kdfAlg = "kdf1";
        } else if (kdfOid.equals(id_kdf2)) {
            kdfAlg = "kdf2";
        } else if (kdfOid.equals(id_kdf3)) {
            kdfAlg = "kdf3";
        } else {
            throw new IOException("unknown kdf");
        }

        DerInputStream d6 = d5.getDerValue().toDerInputStream();
        String hashOID = d6.getOID().toString();
        KnownOIDs k = KnownOIDs.findMatch(hashOID);
        hashAlg = k == null ? hashOID : k.stdName();
        d6.atEnd();
        d5.atEnd();
        d3.atEnd();

        // d7: { wrap algorithm OID }.
        DerInputStream d7 = d2.getDerValue().toDerInputStream();
        String encOID = d7.getOID().toString();
        KnownOIDs e = KnownOIDs.findMatch(encOID);
        encAlg = e == null ? encOID : e.stdName();
        d7.atEnd();
        d2.atEnd();

        // kdfLen(String) parses the key size out of the algorithm name and
        // fails with an unchecked exception for names it cannot handle (for
        // example an unmatched OID string). Report that as a decoding error
        // so callers of engineInit only ever see IOException.
        final int expectedLen;
        try {
            expectedLen = RSAKEMParameterSpec.kdfLen(encAlg);
        } catch (IllegalArgumentException | IndexOutOfBoundsException ex) {
            throw new IOException("unsupported encAlg: " + encAlg, ex);
        }
        if (kdfLen != expectedLen) {
            throw new IOException("kdfLen does not match encAlg");
        }
        return new RSAKEMParameterSpec(kdfAlg, hashAlg, fixedInfo, encAlg);
    }

    /**
     * Maps an algorithm name to its OID: the three local KDF names first,
     * anything else through {@code KnownOIDs}.
     *
     * @throws UnsupportedOperationException if no OID is known for {@code s}
     */
    static ObjectIdentifier oid4(String s) {
        return switch (s) {
            case "kdf1" -> id_kdf1;
            case "kdf2" -> id_kdf2;
            case "kdf3" -> id_kdf3;
            default -> {
                KnownOIDs k = KnownOIDs.findMatch(s);
                if (k == null) {
                    throw new UnsupportedOperationException("no OID known for " + s);
                }
                yield ObjectIdentifier.of(k);
            }
        };
    }
}
/**
 * Parameters for RSA-KEM: which KDF to run ("kdf1", "kdf2" or "kdf3"),
 * the hash it uses, optional fixed info for kdf3, and the key-wrap
 * algorithm the derived secret is meant for. Instances are created through
 * the {@code kdf1}/{@code kdf2}/{@code kdf3} factories and are immutable:
 * the fixed-info array is copied on the way in and on the way out.
 */
public static class RSAKEMParameterSpec implements AlgorithmParameterSpec {
    private final String kdfAlg;
    private final String hashAlg;
    private final byte[] fixedInfo; // may be null
    private final String encAlg;

    private RSAKEMParameterSpec(String kdfAlg, String hashAlg, byte[] fixedInfo, String encAlg) {
        this.hashAlg = hashAlg;
        this.kdfAlg = kdfAlg;
        this.fixedInfo = fixedInfo == null ? null : fixedInfo.clone();
        this.encAlg = encAlg;
    }

    /** Creates a "kdf1" spec with no fixed info. */
    public static RSAKEMParameterSpec kdf1(String hashAlg, String encAlg) {
        return new RSAKEMParameterSpec("kdf1", hashAlg, null, encAlg);
    }

    /** Creates a "kdf2" spec with no fixed info. */
    public static RSAKEMParameterSpec kdf2(String hashAlg, String encAlg) {
        return new RSAKEMParameterSpec("kdf2", hashAlg, null, encAlg);
    }

    /** Creates a "kdf3" spec; {@code fixedInfo} is copied. */
    public static RSAKEMParameterSpec kdf3(String hashAlg, byte[] fixedInfo, String encAlg) {
        return new RSAKEMParameterSpec("kdf3", hashAlg, fixedInfo, encAlg);
    }

    /**
     * Returns the number of key bytes the KDF must produce for this spec's
     * wrap algorithm.
     *
     * @throws IllegalArgumentException if the key size cannot be read from
     *         the wrap algorithm name
     */
    public int kdfLen() {
        return RSAKEMParameterSpec.kdfLen(encAlg);
    }

    /**
     * Derives the key length in bytes from a wrap algorithm name by reading
     * the three decimal digits at character positions 4..6 as a bit count,
     * e.g. {@code "AES_128/KW/NoPadding"} gives 16.
     *
     * @param encAlg the wrap algorithm name
     * @return the bit count found in the name, divided by 8
     * @throws NullPointerException if {@code encAlg} is null
     * @throws IllegalArgumentException if the name is shorter than 7
     *         characters or positions 4..6 are not a decimal number
     */
    public static int kdfLen(String encAlg) {
        Objects.requireNonNull(encAlg, "encAlg");
        // Without this check a short name surfaces as a bare
        // IndexOutOfBoundsException with no hint of what was being parsed.
        if (encAlg.length() < 7) {
            throw new IllegalArgumentException(
                    "cannot determine key size from algorithm name: " + encAlg);
        }
        try {
            return Integer.parseInt(encAlg, 4, 7, 10) / 8;
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(
                    "cannot determine key size from algorithm name: " + encAlg, e);
        }
    }

    /** Returns the hash algorithm name. */
    public String hashAlgorithm() {
        return hashAlg;
    }

    /** Returns the KDF name: "kdf1", "kdf2" or "kdf3". */
    public String kdfAlgorithm() {
        return kdfAlg;
    }

    /** Returns a copy of the fixed info, or null if none was supplied. */
    public byte[] fixedInfo() {
        return fixedInfo == null ? null : fixedInfo.clone();
    }

    /** Returns the wrap algorithm name. */
    public String getEncAlg() {
        return encAlg;
    }

    /** Returns {@code [kdf,hash,encAlg]}; the fixed info is not included. */
    @Override
    public String toString() {
        return String.format("[%s,%s,%s]", kdfAlg, hashAlg, encAlg);
    }
}
/**
 * KEM SPI for "RSA-KEM". The sender draws a random integer below the RSA
 * modulus, RSA-encrypts it to form the encapsulation, and runs the
 * configured KDF over its fixed-length byte form to obtain the shared
 * secret. The receiver RSA-decrypts the encapsulation and runs the same KDF.
 */
public static class KEMImpl implements KEMSpi {

    /**
     * Creates an encapsulator for an RSA public key.
     *
     * @param pk the receiver's public key; must be an {@code RSAPublicKey}
     * @param spec an {@code RSAKEMParameterSpec}, or null for the default
     * @param secureRandom randomness source, or null for the JCA default
     * @throws InvalidKeyException if {@code pk} is not an RSA public key
     * @throws InvalidAlgorithmParameterException if {@code spec} is non-null
     *         and not an {@code RSAKEMParameterSpec}
     */
    @Override
    public KEMSpi.EncapsulatorSpi engineNewEncapsulator(
            PublicKey pk, AlgorithmParameterSpec spec, SecureRandom secureRandom)
            throws InvalidAlgorithmParameterException, InvalidKeyException {
        if (!(pk instanceof RSAPublicKey rpk)) {
            throw new InvalidKeyException("Not an RSA key");
        }
        return Handler.newEncapsulator(spec, rpk, secureRandom);
    }

    /**
     * Creates a decapsulator for an RSA private key.
     *
     * @param sk the receiver's private key; must be an {@code RSAPrivateCrtKey}
     * @param spec an {@code RSAKEMParameterSpec}, or null for the default
     * @throws InvalidKeyException if {@code sk} is not an {@code RSAPrivateCrtKey}
     * @throws InvalidAlgorithmParameterException if {@code spec} is non-null
     *         and not an {@code RSAKEMParameterSpec}
     */
    @Override
    public KEMSpi.DecapsulatorSpi engineNewDecapsulator(
            PrivateKey sk, AlgorithmParameterSpec spec)
            throws InvalidAlgorithmParameterException, InvalidKeyException {
        if (!(sk instanceof RSAPrivateCrtKey rsk)) {
            throw new InvalidKeyException("Not an RSA key");
        }
        return Handler.newDecapsulator(spec, rsk);
    }

    /**
     * Implements both SPI roles. Exactly one of {@code rpk}/{@code rsk} is
     * set, depending on which factory method created the instance.
     */
    static class Handler implements KEMSpi.EncapsulatorSpi, KEMSpi.DecapsulatorSpi {
        private final RSAPublicKey rpk; // not null for encapsulator
        private final RSAPrivateKey rsk; // not null for decapsulator
        private final RSAKEMParameterSpec kspec; // not null
        private final SecureRandom sr; // not null for encapsulator

        /**
         * @param spec parameters; null selects kdf2 / SHA-256 /
         *        AES_256/KW/NoPadding
         * @throws InvalidAlgorithmParameterException if {@code spec} is
         *         non-null and not an {@code RSAKEMParameterSpec}
         */
        Handler(AlgorithmParameterSpec spec, RSAPublicKey rpk, RSAPrivateCrtKey rsk, SecureRandom sr)
                throws InvalidAlgorithmParameterException {
            this.rpk = rpk;
            this.rsk = rsk;
            this.sr = sr;
            if (spec != null) {
                if (spec instanceof RSAKEMParameterSpec rs) {
                    this.kspec = rs;
                } else {
                    throw new InvalidAlgorithmParameterException();
                }
            } else {
                this.kspec = RSAKEMParameterSpec
                        .kdf2("SHA-256", "AES_256/KW/NoPadding");
            }
        }

        /** Builds the sender side; a null {@code sr} is replaced by the JCA default. */
        static Handler newEncapsulator(AlgorithmParameterSpec spec, RSAPublicKey rpk, SecureRandom sr)
                throws InvalidAlgorithmParameterException {
            if (sr == null) {
                sr = JCAUtil.getDefSecureRandom();
            }
            return new Handler(spec, rpk, null, sr);
        }

        /** Builds the receiver side. */
        static Handler newDecapsulator(AlgorithmParameterSpec spec, RSAPrivateCrtKey rsk)
                throws InvalidAlgorithmParameterException {
            return new Handler(spec, null, rsk, null);
        }

        /**
         * Recovers the shared secret from an encapsulation.
         *
         * @param encapsulation the RSA ciphertext; its length must equal
         *        {@link #engineEncapsulationSize()}
         * @param from start (inclusive) of the slice of the derived secret
         * @param to end (exclusive) of the slice of the derived secret
         * @param algorithm algorithm name for the returned key
         * @return bytes {@code [from, to)} of the KDF output as a key
         * @throws DecapsulateException if the encapsulation has the wrong
         *         length or cannot be decrypted
         */
        @Override
        public SecretKey engineDecapsulate(byte[] encapsulation,
                int from, int to, String algorithm)
                throws DecapsulateException {
            Objects.checkFromToIndex(from, to, kspec.kdfLen());
            Objects.requireNonNull(algorithm, "null algorithm");
            Objects.requireNonNull(encapsulation, "null encapsulation");
            if (encapsulation.length != encapsulationBytes()) {
                throw new DecapsulateException("incorrect encapsulation size");
            }
            try {
                byte[] Z = RSACore.rsa(encapsulation, rsk, false);
                return new SecretKeySpec(kdf(Z), from, to - from, algorithm);
            } catch (BadPaddingException e) {
                throw new DecapsulateException("cannot decrypt", e);
            }
        }

        /**
         * Generates a fresh shared secret and its encapsulation.
         *
         * @param from start (inclusive) of the slice of the derived secret
         * @param to end (exclusive) of the slice of the derived secret
         * @param algorithm algorithm name for the returned key
         * @return the key, the RSA ciphertext, and the encoded parameters
         * @throws ProviderException if no random value below the modulus is
         *         found within the retry budget
         */
        @Override
        public KEM.Encapsulated engineEncapsulate(int from, int to, String algorithm) {
            Objects.checkFromToIndex(from, to, kspec.kdfLen());
            Objects.requireNonNull(algorithm, "null algorithm");
            BigInteger modulus = rpk.getModulus();
            int nLen = modulus.bitLength();
            int nSize = (nLen + 7) / 8;

            // Rejection sampling: draw nLen random bits until the value is
            // below the modulus, giving up after a bounded number of tries.
            BigInteger z;
            int tried = 0;
            while (true) {
                z = new BigInteger(nLen, sr);
                if (z.compareTo(modulus) < 0) {
                    break;
                }
                if (tried++ > 20) {
                    throw new ProviderException("Cannot get good random number");
                }
            }

            // Normalize to exactly nSize bytes: toByteArray() may carry an
            // extra leading sign byte or be shorter for small values.
            byte[] Z = z.toByteArray();
            if (Z.length > nSize) {
                Z = Arrays.copyOfRange(Z, Z.length - nSize, Z.length);
            } else if (Z.length < nSize) {
                byte[] tmp = new byte[nSize];
                System.arraycopy(Z, 0, tmp, nSize - Z.length, Z.length);
                Z = tmp;
            }

            byte[] c;
            try {
                c = RSACore.rsa(Z, rpk);
            } catch (BadPaddingException e) {
                throw new AssertionError(e);
            }
            return new KEM.Encapsulated(
                    new SecretKeySpec(kdf(Z), from, to - from, algorithm),
                    c, AlgorithmParametersImpl.encode(kspec));
        }

        /**
         * Runs the configured KDF over {@code input} and returns exactly
         * {@code kspec.kdfLen()} bytes.
         *
         * <p>Each round hashes the input together with a 4-byte big-endian
         * counter: kdf1 and kdf2 hash {@code input || counter} (counter
         * starting at 0 and 1 respectively); kdf3 hashes
         * {@code counter || input || fixedInfo} with the counter starting at
         * 1. Round outputs are concatenated and truncated to the requested
         * length.
         *
         * @throws ProviderException if the hash algorithm is unavailable or
         *         the KDF name is not one of kdf1/kdf2/kdf3
         */
        byte[] kdf(byte[] input) {
            String hashAlg = kspec.hashAlgorithm();
            MessageDigest md;
            try {
                md = MessageDigest.getInstance(hashAlg);
            } catch (NoSuchAlgorithmException e) {
                throw new ProviderException(e);
            }
            String kdfAlg = kspec.kdfAlgorithm();
            byte[] fixedInput = kspec.fixedInfo();
            int length = kspec.kdfLen();
            ByteArrayOutputStream bout = new ByteArrayOutputStream();
            int n = kdfAlg.equals("kdf1") ? 0 : 1;
            do {
                switch (kdfAlg) {
                    case "kdf1", "kdf2" -> {
                        md.update(input);
                        md.update(u32str(n));
                    }
                    case "kdf3" -> {
                        md.update(u32str(n));
                        md.update(input);
                        // fixedInfo is null when decoded parameters carried
                        // no OCTET STRING; treat that as empty instead of
                        // failing with a NullPointerException.
                        if (fixedInput != null) {
                            md.update(fixedInput);
                        }
                    }
                    default -> throw new ProviderException();
                }
                bout.writeBytes(md.digest());
                n++;
                // Stop as soon as enough bytes exist; the previous
                // "size > length" test ran one unnecessary round whenever
                // length was an exact multiple of the digest size.
            } while (bout.size() < length);
            byte[] result = bout.toByteArray();
            return result.length == length
                    ? result
                    : Arrays.copyOf(result, length);
        }

        /** Returns the length in bytes of the full derived secret. */
        @Override
        public int engineSecretSize() {
            return kspec.kdfLen();
        }

        /** Returns the length in bytes of an encapsulation for this key. */
        @Override
        public int engineEncapsulationSize() {
            return encapsulationBytes();
        }

        /**
         * Key size in bits rounded UP to whole bytes. This matches nSize in
         * engineEncapsulate; plain integer division would under-count for a
         * modulus whose bit length is not a multiple of 8.
         * NOTE(review): assumes RSACore.rsa returns a modulus-sized
         * (ceil(bits/8)) array - confirm against sun.security.rsa.RSACore.
         */
        private int encapsulationBytes() {
            return (KeyUtil.getKeySize(rsk == null ? rpk : rsk) + 7) / 8;
        }
    }
}
/**
 * Encodes {@code i} as four bytes, most significant byte first.
 *
 * @param i the value to encode
 * @return a new 4-byte array holding {@code i} in big-endian order
 */
static byte[] u32str(int i) {
    byte[] out = new byte[4];
    // Fill from the least significant byte backwards, shifting as we go.
    for (int pos = out.length - 1; pos >= 0; pos--) {
        out[pos] = (byte) i;
        i >>>= 8;
    }
    return out;
}
}