8328608: Multiple NewSessionTicket support for TLS

Reviewed-by: djelinski
This commit is contained in:
Anthony Scarpino 2024-08-28 17:24:33 +00:00
parent 379f3db001
commit 0c2b175898
16 changed files with 1161 additions and 300 deletions

View File

@ -1139,9 +1139,11 @@ final class Finished {
//
// produce
if (SSLConfiguration.serverNewSessionTicketCount > 0) {
NewSessionTicket.t13PosthandshakeProducer.produce(shc);
}
}
}
private static void recordEvent(SSLSessionImpl session) {
TLSHandshakeEvent event = new TLSHandshakeEvent();

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -25,11 +25,11 @@
package sun.security.ssl;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Locale;
import javax.crypto.SecretKey;
import javax.net.ssl.SSLHandshakeException;
@ -118,11 +118,6 @@ final class NewSessionTicket {
this.ticket = Record.getBytes16(m);
}
@Override
public SSLHandshake handshakeType() {
return NEW_SESSION_TICKET;
}
@Override
public int messageLength() {
return 4 + // ticketLifetime
@ -221,11 +216,6 @@ final class NewSessionTicket {
this.extensions = new SSLExtensions(this, m, supportedExtensions);
}
@Override
public SSLHandshake handshakeType() {
return NEW_SESSION_TICKET;
}
int getTicketAgeAdd() {
return ticketAgeAdd;
}
@ -332,8 +322,7 @@ final class NewSessionTicket {
// Is this session resumable?
if (!hc.handshakeSession.isRejoinable()) {
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"No session ticket produced: " +
SSLLogger.fine("No session ticket produced: " +
"session is not resumable");
}
@ -351,8 +340,7 @@ final class NewSessionTicket {
if (pkemSpec == null ||
!pkemSpec.contains(PskKeyExchangeMode.PSK_DHE_KE)) {
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"No session ticket produced: " +
SSLLogger.fine("No session ticket produced: " +
"client does not support psk_dhe_ke");
}
@ -363,8 +351,7 @@ final class NewSessionTicket {
// using an allowable PSK exchange key mode.
if (!hc.handshakeSession.isPSKable()) {
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"No session ticket produced: " +
SSLLogger.fine("No session ticket produced: " +
"No session ticket allowed in this session");
}
@ -375,91 +362,48 @@ final class NewSessionTicket {
// get a new session ID
SSLSessionContextImpl sessionCache = (SSLSessionContextImpl)
hc.sslContext.engineGetServerSessionContext();
SessionId newId = new SessionId(true,
hc.sslContext.getSecureRandom());
SecretKey resumptionMasterSecret =
hc.handshakeSession.getResumptionMasterSecret();
if (resumptionMasterSecret == null) {
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"No session ticket produced: " +
"no resumption secret");
}
return null;
}
// construct the PSK and handshake message
BigInteger nonce = hc.handshakeSession.incrTicketNonceCounter();
byte[] nonceArr = nonce.toByteArray();
SecretKey psk = derivePreSharedKey(
hc.negotiatedCipherSuite.hashAlg,
resumptionMasterSecret, nonceArr);
int sessionTimeoutSeconds = sessionCache.getSessionTimeout();
if (sessionTimeoutSeconds > MAX_TICKET_LIFETIME) {
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"No session ticket produced: " +
"session timeout");
SSLLogger.fine("No session ticket produced: " +
"session timeout is too long");
}
return null;
}
NewSessionTicketMessage nstm = null;
SSLSessionImpl sessionCopy =
new SSLSessionImpl(hc.handshakeSession, newId);
sessionCopy.setPreSharedKey(psk);
sessionCopy.setPskIdentity(newId.getId());
// If a stateless ticket is allowed, attempt to make one
if (hc.statelessResumption &&
hc.handshakeSession.isStatelessable()) {
nstm = new T13NewSessionTicketMessage(hc,
sessionTimeoutSeconds,
hc.sslContext.getSecureRandom(),
nonceArr,
new SessionTicketSpec().encrypt(hc, sessionCopy));
// If ticket construction failed, switch to session cache
if (!nstm.isValid()) {
hc.statelessResumption = false;
} else {
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"Produced NewSessionTicket stateless " +
"post-handshake message", nstm);
// Send NewSessionTickets to the client based
if (SSLConfiguration.serverNewSessionTicketCount > 0) {
int i = 0;
NewSessionTicketMessage nstm;
while (i < SSLConfiguration.serverNewSessionTicketCount) {
nstm = generateNST(hc, sessionCache);
if (nstm == null) {
break;
}
}
}
// If a session cache ticket is being used, make one
if (!hc.statelessResumption ||
!hc.handshakeSession.isStatelessable()) {
nstm = new T13NewSessionTicketMessage(hc, sessionTimeoutSeconds,
hc.sslContext.getSecureRandom(), nonceArr,
newId.getId());
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"Produced NewSessionTicket post-handshake message",
nstm);
}
// create and cache the new session
// The new session must be a child of the existing session so
// they will be invalidated together, etc.
hc.handshakeSession.addChild(sessionCopy);
sessionCopy.setTicketAgeAdd(nstm.getTicketAgeAdd());
sessionCache.put(sessionCopy);
}
// Output the handshake message.
if (nstm != null) {
// should never be null
nstm.write(hc.handshakeOutput);
i++;
}
hc.handshakeOutput.flush();
}
/*
* With large NST counts, a client that quickly closes after
* TLS Finished completes can cause SocketExceptions such as:
* Windows servers read-side throwing SocketException:
* "An established connection was aborted by the software in
* your host machine", which relates to error WSAECONNABORTED.
* A SocketException caused by a "broken pipe" has been observed on
* other systems.
* These are very unlikely situations when client and server are on
* different machines.
*
* RFC 8446 does not put requirements when an NST needs to be
* sent, but it should be sent very soon after TLS Finished for
* clients that will quickly resume to create more sessions.
* TLS 1.3 is different from TLS 1.2, there is more data the client
* should be aware of
*/
// See note on TransportContext.needHandshakeFinishedStatus.
//
@ -470,7 +414,6 @@ final class NewSessionTicket {
if (hc.conContext.needHandshakeFinishedStatus) {
hc.conContext.needHandshakeFinishedStatus = false;
}
}
// clean the post handshake context
hc.conContext.finishPostHandshake();
@ -478,6 +421,71 @@ final class NewSessionTicket {
// The message has been delivered.
return null;
}
private NewSessionTicketMessage generateNST(HandshakeContext hc,
SSLSessionContextImpl sessionCache) throws IOException {
NewSessionTicketMessage nstm;
SessionId newId = new SessionId(true,
hc.sslContext.getSecureRandom());
// construct the PSK and handshake message
byte[] nonce = hc.handshakeSession.incrTicketNonceCounter();
SSLSessionImpl sessionCopy =
new SSLSessionImpl(hc.handshakeSession, newId);
sessionCopy.setPreSharedKey(derivePreSharedKey(
hc.negotiatedCipherSuite.hashAlg,
hc.handshakeSession.getResumptionMasterSecret(), nonce));
sessionCopy.setPskIdentity(newId.getId());
// If a stateless ticket is allowed, attempt to make one
if (hc.statelessResumption &&
hc.handshakeSession.isStatelessable()) {
nstm = new T13NewSessionTicketMessage(hc,
sessionCache.getSessionTimeout(),
hc.sslContext.getSecureRandom(),
nonce,
new SessionTicketSpec().encrypt(hc, sessionCopy));
// If ticket construction failed, switch to session cache
if (!nstm.isValid()) {
hc.statelessResumption = false;
} else {
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine("Produced NewSessionTicket stateless " +
"post-handshake message", nstm);
}
}
return nstm;
}
// If a session cache ticket is being used, make one
if (!hc.statelessResumption ||
!hc.handshakeSession.isStatelessable()) {
nstm = new T13NewSessionTicketMessage(hc,
sessionCache.getSessionTimeout(),
hc.sslContext.getSecureRandom(), nonce,
newId.getId());
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine("Produced NewSessionTicket " +
"post-handshake message", nstm);
}
// create and cache the new session
// The new session must be a child of the existing session so
// they will be invalidated together, etc.
hc.handshakeSession.addChild(sessionCopy);
sessionCopy.setTicketAgeAdd(nstm.getTicketAgeAdd());
sessionCache.put(sessionCopy);
return nstm;
}
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine("No NewSessionTicket created");
}
return null;
}
}
/**
@ -497,8 +505,9 @@ final class NewSessionTicket {
ServerHandshakeContext shc = (ServerHandshakeContext)context;
// Is this session resumable?
if (!shc.handshakeSession.isRejoinable()) {
// Are new tickets allowed? If so, is this session resumable?
if (SSLConfiguration.serverNewSessionTicketCount == 0 ||
!shc.handshakeSession.isRejoinable()) {
return null;
}
@ -578,7 +587,6 @@ final class NewSessionTicket {
"Discarding NewSessionTicket with lifetime " +
nstm.ticketLifetime, nstm);
}
sessionCache.remove(hc.handshakeSession.getSessionId());
return;
}
@ -619,13 +627,19 @@ final class NewSessionTicket {
sessionCopy.setPreSharedKey(psk);
sessionCopy.setTicketAgeAdd(nstm.getTicketAgeAdd());
sessionCopy.setPskIdentity(nstm.ticket);
sessionCache.put(sessionCopy);
sessionCache.put(sessionCopy, sessionCopy.isPSK());
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine("MultiNST PSK (Server): " +
Utilities.toHexString(Arrays.copyOf(nstm.ticket, 16)));
}
// clean the post handshake context
hc.conContext.finishPostHandshake();
}
}
/* TLS 1.2 spec does not specify multiple NST behavior.*/
private static final
class T12NewSessionTicketConsumer implements SSLConsumer {
// Prevent instantiation of this class.
@ -674,8 +688,7 @@ final class NewSessionTicket {
hc.handshakeSession.setPskIdentity(nstm.ticket);
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine("Consuming NewSessionTicket\n" +
nstm.toString());
SSLLogger.fine("Consuming NewSessionTicket\n" + nstm);
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -699,11 +699,13 @@ final class PreSharedKeyExtension {
//The session cannot be used again. Remove it from the cache.
SSLSessionContextImpl sessionCache = (SSLSessionContextImpl)
chc.sslContext.engineGetClientSessionContext();
sessionCache.remove(chc.resumingSession.getSessionId());
sessionCache.remove(chc.resumingSession.getSessionId(), true);
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"Found resumable session. Preparing PSK message.");
SSLLogger.fine(
"MultiNST PSK (Client): " + Utilities.toHexString(Arrays.copyOf(chc.pskIdentity, 16)));
}
List<PskIdentity> identities = new ArrayList<>();

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -121,6 +121,11 @@ final class SSLConfiguration implements Cloneable {
static final boolean enableDtlsResumeCookie = Utilities.getBooleanProperty(
"jdk.tls.enableDtlsResumeCookie", true);
// Number of NewSessionTickets that will be sent by the server.
static final int serverNewSessionTicketCount;
// Default for NewSessionTickets
static final int SERVER_NST_DEFAULT = 1;
// Is the extended_master_secret extension supported?
static {
boolean supportExtendedMasterSecret = Utilities.getBooleanProperty(
@ -191,6 +196,33 @@ final class SSLConfiguration implements Cloneable {
} else {
maxInboundServerCertChainLen = inboundServerLen;
}
/*
* jdk.tls.server.newSessionTicketCount system property
* Sets the number of NewSessionTickets sent to a TLS 1.3 resumption
* client. The value must be between 0 and 10. Default is defined by
* SERVER_NST_DEFAULT.
*/
Integer nstServerCount = GetIntegerAction.privilegedGetProperty(
"jdk.tls.server.newSessionTicketCount");
if (nstServerCount == null || nstServerCount < 0 ||
nstServerCount > 10) {
serverNewSessionTicketCount = SERVER_NST_DEFAULT;
if (nstServerCount != null && SSLLogger.isOn &&
SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"jdk.tls.server.newSessionTicketCount defaults to " +
SERVER_NST_DEFAULT + " as the property was not " +
"between 0 and 10");
}
} else {
serverNewSessionTicketCount = nstServerCount;
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"jdk.tls.server.newSessionTicketCount set to " +
serverNewSessionTicketCount);
}
}
}
SSLConfiguration(SSLContextImpl sslContext, boolean isClientMode) {

View File

@ -416,7 +416,8 @@ final class SSLEngineImpl extends SSLEngine implements SSLTransport {
HandshakeStatus currentHandshakeStatus) throws IOException {
// Don't bother to kickstart if handshaking is in progress, or if the
// connection is not duplex-open.
if ((conContext.handshakeContext == null) &&
if (SSLConfiguration.serverNewSessionTicketCount > 0 &&
conContext.handshakeContext == null &&
conContext.protocolVersion.useTLS13PlusSpec() &&
!conContext.isOutboundClosed() &&
!conContext.isInboundClosed() &&

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 1999, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 1999, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -63,6 +63,7 @@ import sun.security.util.Cache;
final class SSLSessionContextImpl implements SSLSessionContext {
private static final int DEFAULT_MAX_CACHE_SIZE = 20480;
private static final int DEFAULT_MAX_QUEUE_SIZE = 10;
// Default lifetime of a session. 24 hours
static final int DEFAULT_SESSION_TIMEOUT = 86400;
@ -87,14 +88,17 @@ final class SSLSessionContextImpl implements SSLSessionContext {
cacheLimit = getDefaults(server); // default cache size
// use soft reference
if (server) {
sessionCache = Cache.newSoftMemoryCache(cacheLimit, timeout);
sessionHostPortCache = Cache.newSoftMemoryCache(cacheLimit, timeout);
if (server) {
keyHashMap = new ConcurrentHashMap<>();
// Should be "randomly generated" according to RFC 5077,
// but doesn't necessarily has to be a true random number.
// but doesn't necessarily have to be a true random number.
currentKeyID = new Random(System.nanoTime()).nextInt();
} else {
sessionCache = Cache.newSoftMemoryCache(cacheLimit, timeout);
sessionHostPortCache = Cache.newSoftMemoryQueue(cacheLimit, timeout,
DEFAULT_MAX_QUEUE_SIZE);
keyHashMap = Map.of();
}
}
@ -277,12 +281,22 @@ final class SSLSessionContextImpl implements SSLSessionContext {
// time it created, which is a little longer than the expected. So
// please do check isTimedout() while getting entry from the cache.
void put(SSLSessionImpl s) {
put(s, false);
}
/**
* Put an entry in the cache
* @param s SSLSessionImpl entry to be stored
* @param canQueue True if multiple entries may exist under one
* session entry.
*/
void put(SSLSessionImpl s, boolean canQueue) {
sessionCache.put(s.getSessionId(), s);
// If no hostname/port info is available, don't add this one.
if ((s.getPeerHost() != null) && (s.getPeerPort() != -1)) {
sessionHostPortCache.put(
getKey(s.getPeerHost(), s.getPeerPort()), s);
getKey(s.getPeerHost(), s.getPeerPort()), s, canQueue);
}
s.setContext(this);
@ -290,13 +304,19 @@ final class SSLSessionContextImpl implements SSLSessionContext {
// package-private method, remove a cached SSLSession
void remove(SessionId key) {
remove(key, false);
}
void remove(SessionId key, boolean isClient) {
SSLSessionImpl s = sessionCache.get(key);
if (s != null) {
sessionCache.remove(key);
// A client keeps the cache entry for queued NST resumption.
if (!isClient) {
sessionHostPortCache.remove(
getKey(s.getPeerHost(), s.getPeerPort()));
}
}
}
private int getDefaults(boolean server) {
try {

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 1996, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 1996, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -38,7 +38,6 @@ import java.util.Arrays;
import java.util.Queue;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
@ -132,7 +131,11 @@ final class SSLSessionImpl extends ExtendedSSLSession {
private final List<SNIServerName> requestedServerNames;
// Counter used to create unique nonces in NewSessionTicket
private BigInteger ticketNonceCounter = BigInteger.ONE;
private byte ticketNonceCounter = 1;
// This boolean is true when a new set of NewSessionTickets are needed after
// the initial ones sent after the handshake.
boolean updateNST = false;
// The endpoint identification algorithm used to check certificates
// in this session.
@ -501,8 +504,13 @@ final class SSLSessionImpl extends ExtendedSSLSession {
buf.get(b);
this.preSharedKey = new SecretKeySpec(b, alg);
// Get identity len
i = buf.get();
if (i > 0) {
this.pskIdentity = new byte[buf.get()];
buf.get(pskIdentity);
} else {
this.pskIdentity = null;
}
break;
default:
throw new SSLException("Failed local certs of session.");
@ -715,14 +723,12 @@ final class SSLSessionImpl extends ExtendedSSLSession {
this.pskIdentity = pskIdentity;
}
BigInteger incrTicketNonceCounter() {
BigInteger result = ticketNonceCounter;
ticketNonceCounter = ticketNonceCounter.add(BigInteger.ONE);
return result;
byte[] incrTicketNonceCounter() {
return new byte[] {ticketNonceCounter++};
}
boolean isPSKable() {
return (ticketNonceCounter.compareTo(BigInteger.ZERO) > 0);
return (ticketNonceCounter > 0);
}
/**
@ -781,6 +787,10 @@ final class SSLSessionImpl extends ExtendedSSLSession {
return pskIdentity;
}
public boolean isPSK() {
return (pskIdentity != null && pskIdentity.length > 0);
}
void setPeerCertificates(X509Certificate[] peer) {
if (peerCerts == null) {
peerCerts = peer;
@ -1230,7 +1240,6 @@ final class SSLSessionImpl extends ExtendedSSLSession {
* sessions can be shared across different protection domains.
*/
private final ConcurrentHashMap<SecureKey, Object> boundValues;
boolean updateNST;
/**
* Assigns a session value. Session change events are given if

View File

@ -1321,7 +1321,6 @@ public final class SSLSocketImpl
}
// Check if NewSessionTicket PostHandshake message needs to be sent
if (conContext.conSession.updateNST) {
conContext.conSession.updateNST = false;
tryNewSessionTicket();
}
}
@ -1556,7 +1555,8 @@ public final class SSLSocketImpl
private void tryNewSessionTicket() throws IOException {
// Don't bother to kickstart if handshaking is in progress, or if the
// connection is not duplex-open.
if (!conContext.sslConfig.isClientMode &&
if (SSLConfiguration.serverNewSessionTicketCount > 0 &&
!conContext.sslConfig.isClientMode &&
conContext.protocolVersion.useTLS13PlusSpec() &&
conContext.handshakeContext == null &&
!conContext.isOutboundClosed() &&
@ -1565,6 +1565,7 @@ public final class SSLSocketImpl
if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
SSLLogger.finest("trigger new session ticket");
}
conContext.conSession.updateNST = false;
NewSessionTicket.t13PosthandshakeProducer.produce(
new PostHandshakeContext(conContext));
}

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2002, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2002, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -25,8 +25,12 @@
package sun.security.util;
import javax.net.ssl.SSLSession;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.SoftReference;
import java.util.*;
import java.lang.ref.*;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Abstract base class and factory for caches. A cache is a key-value mapping.
@ -90,6 +94,15 @@ public abstract class Cache<K,V> {
*/
public abstract void put(K key, V value);
/**
* Add V to the cache with the option to use a QueueCacheEntry if the
* cache is configured for it. If the cache is not configured for a queue,
* V will silently add the entry directly.
*/
public void put(K key, V value, boolean canQueue) {
put(key, value);
}
/**
* Get a value from the cache.
*/
@ -137,6 +150,11 @@ public abstract class Cache<K,V> {
return new MemoryCache<>(true, size, timeout);
}
public static <K,V> Cache<K,V> newSoftMemoryQueue(int size, int timeout,
int maxQueueSize) {
return new MemoryCache<>(true, size, timeout, maxQueueSize);
}
/**
* Return a new memory cache with the specified maximum size, unlimited
* lifetime for entries, with the values held by standard references.
@ -248,13 +266,12 @@ class NullCache<K,V> extends Cache<K,V> {
class MemoryCache<K,V> extends Cache<K,V> {
private static final float LOAD_FACTOR = 0.75f;
// XXXX
// Debugging
private static final boolean DEBUG = false;
private final Map<K, CacheEntry<K,V>> cacheMap;
private int maxSize;
final private int maxQueueSize;
private long lifetime;
private long nextExpirationTime = Long.MAX_VALUE;
@ -263,18 +280,25 @@ class MemoryCache<K,V> extends Cache<K,V> {
private final ReferenceQueue<V> queue;
public MemoryCache(boolean soft, int maxSize) {
this(soft, maxSize, 0);
this(soft, maxSize, 0, 0);
}
public MemoryCache(boolean soft, int maxSize, int lifetime) {
this.maxSize = maxSize;
this.lifetime = lifetime * 1000L;
if (soft)
this.queue = new ReferenceQueue<>();
else
this.queue = null;
this(soft, maxSize, lifetime, 0);
}
cacheMap = new LinkedHashMap<>(1, LOAD_FACTOR, true);
public MemoryCache(boolean soft, int maxSize, int lifetime, int qSize) {
this.maxSize = maxSize;
this.maxQueueSize = qSize;
this.lifetime = lifetime * 1000L;
if (soft) {
this.queue = new ReferenceQueue<>();
} else {
this.queue = null;
}
// LinkedHashMap is needed for its access order. 0.75f load factor is
// default.
cacheMap = new LinkedHashMap<>(1, 0.75f, true);
}
/**
@ -338,6 +362,10 @@ class MemoryCache<K,V> extends Cache<K,V> {
cnt++;
} else if (nextExpirationTime > entry.getExpirationTime()) {
nextExpirationTime = entry.getExpirationTime();
// If this is a queue, check for some expired entries
if (entry instanceof QueueCacheEntry<K,V> qe) {
qe.getQueue().removeIf(e -> !e.isValid(time));
}
}
}
if (DEBUG) {
@ -367,18 +395,60 @@ class MemoryCache<K,V> extends Cache<K,V> {
cacheMap.clear();
}
public synchronized void put(K key, V value) {
public void put(K key, V value) {
put(key, value, false);
}
/**
* This puts an entry into the cacheMap.
*
* If canQueue is true, V will be added using a QueueCacheEntry which
* is added to cacheMap. If false, V is added to the cacheMap directly.
* The caller must keep a consistent canQueue value, mixing them can
* result in a queue being replaced with a single entry.
*
* This method is synchronized to avoid multiple QueueCacheEntry
* overwriting the same key.
*
* @param key key to the cacheMap
* @param value value to be stored
* @param canQueue can the value be put into a QueueCacheEntry
*/
public synchronized void put(K key, V value, boolean canQueue) {
emptyQueue();
long expirationTime = (lifetime == 0) ? 0 :
System.currentTimeMillis() + lifetime;
long expirationTime =
(lifetime == 0) ? 0 : System.currentTimeMillis() + lifetime;
if (expirationTime < nextExpirationTime) {
nextExpirationTime = expirationTime;
}
CacheEntry<K,V> newEntry = newEntry(key, value, expirationTime, queue);
if (maxQueueSize == 0 || !canQueue) {
CacheEntry<K,V> oldEntry = cacheMap.put(key, newEntry);
if (oldEntry != null) {
oldEntry.invalidate();
return;
}
} else {
CacheEntry<K, V> entry = cacheMap.get(key);
switch (entry) {
case QueueCacheEntry<K, V> qe -> {
qe.putValue(newEntry);
if (DEBUG) {
System.out.println("QueueCacheEntry= " + qe);
final AtomicInteger i = new AtomicInteger(1);
qe.queue.stream().forEach(e ->
System.out.println(i.getAndIncrement() + "= " + e));
}
}
case null, default ->
cacheMap.put(key, new QueueCacheEntry<>(key, newEntry,
expirationTime, maxQueueSize));
}
if (DEBUG) {
System.out.println("Cache entry added: key=" +
key.toString() + ", class=" +
(entry != null ? entry.getClass().getName() : null));
}
}
if (maxSize > 0 && cacheMap.size() > maxSize) {
expungeExpiredEntries();
@ -401,25 +471,37 @@ class MemoryCache<K,V> extends Cache<K,V> {
if (entry == null) {
return null;
}
long time = (lifetime == 0) ? 0 : System.currentTimeMillis();
if (!entry.isValid(time)) {
if (lifetime > 0 && !entry.isValid(System.currentTimeMillis())) {
cacheMap.remove(key);
if (DEBUG) {
System.out.println("Ignoring expired entry");
}
cacheMap.remove(key);
return null;
}
// If the value is a queue, return a queue entry.
if (entry instanceof QueueCacheEntry<K, V> qe) {
V result = qe.getValue(lifetime);
if (qe.isEmpty()) {
removeImpl(key);
}
return result;
}
return entry.getValue();
}
public synchronized void remove(Object key) {
emptyQueue();
removeImpl(key);
}
private void removeImpl(Object key) {
CacheEntry<K,V> entry = cacheMap.remove(key);
if (entry != null) {
entry.invalidate();
}
}
public synchronized V pull(Object key) {
emptyQueue();
CacheEntry<K,V> entry = cacheMap.remove(key);
@ -550,8 +632,7 @@ class MemoryCache<K,V> extends Cache<K,V> {
}
}
private static class SoftCacheEntry<K,V>
extends SoftReference<V>
private static class SoftCacheEntry<K,V> extends SoftReference<V>
implements CacheEntry<K,V> {
private K key;
@ -589,6 +670,116 @@ class MemoryCache<K,V> extends Cache<K,V> {
key = null;
expirationTime = -1;
}
@Override
public String toString() {
if (get() instanceof SSLSession se)
return HexFormat.of().formatHex(se.getId());
return super.toString();
}
}
/**
* This CacheEntry<K,V> type allows multiple V entries to be stored in
* one key in the cacheMap.
*
* This implementation is need for TLS clients that receive multiple
* PSKs or NewSessionTickets for server resumption.
*/
private static class QueueCacheEntry<K,V> implements CacheEntry<K,V> {
// Limit the number of queue entries.
private final int MAXQUEUESIZE;
final boolean DEBUG = false;
private K key;
private long expirationTime;
final Queue<CacheEntry<K,V>> queue = new ConcurrentLinkedQueue<>();
QueueCacheEntry(K key, CacheEntry<K,V> entry, long expirationTime,
int maxSize) {
this.key = key;
this.expirationTime = expirationTime;
this.MAXQUEUESIZE = maxSize;
queue.add(entry);
}
public K getKey() {
return key;
}
public V getValue() {
return getValue(0);
}
public V getValue(long lifetime) {
long time = (lifetime == 0) ? 0 : System.currentTimeMillis();
do {
var entry = queue.poll();
if (entry == null) {
return null;
}
if (entry.isValid(time)) {
return entry.getValue();
}
entry.invalidate();
} while (!queue.isEmpty());
return null;
}
public long getExpirationTime() {
return expirationTime;
}
public void setExpirationTime(long time) {
expirationTime = time;
}
public void putValue(CacheEntry<K,V> entry) {
if (DEBUG) {
System.out.println("Added to queue (size=" + queue.size() +
"): " + entry.getKey().toString() + ", " + entry);
}
// Update the cache entry's expiration time to the latest entry.
// The getValue() calls will remove expired tickets.
expirationTime = entry.getExpirationTime();
// Limit the number of queue entries, removing the oldest.
if (queue.size() >= MAXQUEUESIZE) {
queue.remove();
}
queue.add(entry);
}
public boolean isValid(long currentTime) {
boolean valid = (currentTime <= expirationTime) && !queue.isEmpty();
if (!valid) {
invalidate();
}
return valid;
}
public boolean isValid() {
return isValid(System.currentTimeMillis());
}
public void invalidate() {
clear();
key = null;
expirationTime = -1;
}
public void clear() {
queue.forEach(CacheEntry::invalidate);
queue.clear();
}
public boolean isEmpty() {
return queue.isEmpty();
}
public Queue<CacheEntry<K,V>> getQueue() {
return queue;
}
}
}

View File

@ -39,10 +39,12 @@ public class CertMsgCheck {
build();
// Initial client session
TLSBase.Client client1 = new TLSBase.Client(true, false);
TLSBase.Client client = new TLSBase.Client(true, false);
client.connect();
server.getSession(client1).getSessionContext();
server.done();
// Close must be called to gather all the exceptions thrown
client.close();
server.close();
var eList = server.getExceptionList();
System.out.println("Exception list size is " + eList.size());

View File

@ -47,6 +47,7 @@ public class CheckSessionContext {
// Initial client session
TLSBase.Client client1 = new TLSBase.Client();
client1.connect();
if (server.getSession(client1).getSessionContext() == null) {
throw new Exception("Context was null. Handshake failure.");
} else {
@ -66,6 +67,7 @@ public class CheckSessionContext {
// Resume the client session
TLSBase.Client client2 = new TLSBase.Client();
client2.connect();
if (server.getSession(client2).getSessionContext() == null) {
throw new Exception("Context was null on resumption");
} else {
@ -73,6 +75,5 @@ public class CheckSessionContext {
}
server.close(client2);
client2.close();
server.done();
}
}

View File

@ -25,13 +25,16 @@ import javax.net.ssl.*;
import java.io.*;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.security.KeyStore;
import java.security.cert.PKIXBuilderParameters;
import java.security.cert.X509CertSelector;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
/**
* This is a base setup for creating a server and clients. All clients will
@ -39,7 +42,7 @@ import java.util.concurrent.ConcurrentHashMap;
* first. The idea is for the test code to be minimal as possible without
* this library class being complicated.
*
* Server.done() must be called or the server will never exit and hang the test.
* Server.close() must be called so the server will exit and end threading.
*
* After construction, reading and writing are allowed from either side,
* or a combination write/read from both sides for verifying text.
@ -52,24 +55,25 @@ import java.util.concurrent.ConcurrentHashMap;
*/
abstract public class TLSBase {
static String pathToStores = "../etc";
static String pathToStores = "javax/net/ssl/etc";
static String keyStoreFile = "keystore";
static String trustStoreFile = "truststore";
static String passwd = "passphrase";
static final String TESTROOT =
System.getProperty("test.root", "../../../..");
SSLContext sslContext;
// Server's port
static int serverPort;
// Name shown during read and write ops
String name;
public String name;
TLSBase() {
String keyFilename =
System.getProperty("test.src", "./") + "/" + pathToStores +
"/" + keyStoreFile;
String trustFilename =
System.getProperty("test.src", "./") + "/" + pathToStores +
"/" + trustStoreFile;
String keyFilename = TESTROOT + "/" + pathToStores + "/" + keyStoreFile;
String trustFilename = TESTROOT + "/" + pathToStores + "/" +
trustStoreFile;
System.setProperty("javax.net.ssl.keyStore", keyFilename);
System.setProperty("javax.net.ssl.keyStorePassword", passwd);
System.setProperty("javax.net.ssl.trustStore", trustFilename);
@ -78,26 +82,22 @@ abstract public class TLSBase {
// Base read operation
byte[] read(SSLSocket sock) throws Exception {
BufferedReader reader = new BufferedReader(
new InputStreamReader(sock.getInputStream()));
String s = reader.readLine();
System.err.println("(read) " + name + ": " + s);
return s.getBytes();
BufferedInputStream is = new BufferedInputStream(sock.getInputStream());
byte[] b = is.readNBytes(5);
System.err.println("(read) " + Thread.currentThread().getName() + ": " + new String(b));
return b;
}
// Base write operation
public void write(SSLSocket sock, byte[] data) throws Exception {
PrintWriter out = new PrintWriter(
new OutputStreamWriter(sock.getOutputStream()));
out.println(new String(data));
out.flush();
System.err.println("(write)" + name + ": " + new String(data));
sock.getOutputStream().write(data);
System.err.println("(write)" + Thread.currentThread().getName() + ": " + new String(data));
}
private static KeyManager[] getKeyManager(boolean empty) throws Exception {
FileInputStream fis = null;
if (!empty) {
fis = new FileInputStream(System.getProperty("test.src", "./") +
fis = new FileInputStream(System.getProperty("test.root", "./") +
"/" + pathToStores + "/" + keyStoreFile);
}
// Load the keystore
@ -113,7 +113,7 @@ abstract public class TLSBase {
private static TrustManager[] getTrustManager(boolean empty) throws Exception {
FileInputStream fis = null;
if (!empty) {
fis = new FileInputStream(System.getProperty("test.src", "./") +
fis = new FileInputStream(System.getProperty("test.root", "./") +
"/" + pathToStores + "/" + trustStoreFile);
}
// Load the keystore
@ -150,6 +150,11 @@ abstract public class TLSBase {
new ConcurrentHashMap<>();
Thread t;
List<Exception> exceptionList = new ArrayList<>();
ExecutorService threadPool = Executors.newFixedThreadPool(1,
r -> {
Thread t = Executors.defaultThreadFactory().newThread(r);
return t;
});
Server(ServerBuilder builder) {
super();
@ -160,8 +165,10 @@ abstract public class TLSBase {
TLSBase.getTrustManager(builder.tm), null);
fac = sslContext.getServerSocketFactory();
ssock = (SSLServerSocket) fac.createServerSocket(0);
ssock.setReuseAddress(true);
ssock.setNeedClientAuth(builder.clientauth);
serverPort = ssock.getLocalPort();
System.out.println("Server Port: " + serverPort);
} catch (Exception e) {
System.err.println("Failure during server initialization");
e.printStackTrace();
@ -171,117 +178,67 @@ abstract public class TLSBase {
t = new Thread(() -> {
try {
while (true) {
System.err.println("Server ready on port " +
serverPort);
SSLSocket c = (SSLSocket)ssock.accept();
clientMap.put(c.getPort(), c);
SSLSocket sock = (SSLSocket)ssock.accept();
threadPool.submit(new ServerThread(sock));
}
} catch (Exception ex) {
System.err.println("Server Down");
ex.printStackTrace();
} finally {
threadPool.close();
}
});
t.start();
}
class ServerThread extends Thread {
SSLSocket sock;
ServerThread(SSLSocket s) {
this.sock = s;
System.err.println("ServerThread("+sock.getPort()+")");
clientMap.put(sock.getPort(), sock);
}
public void run() {
try {
write(c, read(c));
write(sock, read(sock));
} catch (Exception e) {
System.out.println("Caught " + e.getMessage());
e.printStackTrace();
exceptionList.add(e);
}
}
} catch (Exception ex) {
System.err.println("Server Down");
ex.printStackTrace();
}
});
t.start();
}
Server() {
this(new ServerBuilder());
}
/**
* @param km - true for an empty key manager
* @param tm - true for an empty trust manager
*/
Server(boolean km, boolean tm) {
super();
name = "server";
public SSLSession getSession(Client client) throws Exception {
System.err.println("getSession("+client.getPort()+")");
SSLSocket clientSocket = clientMap.get(client.getPort());
if (clientSocket == null) {
throw new Exception("Server can't find client socket");
}
return clientSocket.getSession();
}
void close(Client client) {
try {
sslContext = SSLContext.getInstance("TLS");
sslContext.init(TLSBase.getKeyManager(km),
TLSBase.getTrustManager(tm), null);
fac = sslContext.getServerSocketFactory();
ssock = (SSLServerSocket) fac.createServerSocket(0);
ssock.setNeedClientAuth(true);
serverPort = ssock.getLocalPort();
System.err.println("close("+client.getPort()+")");
clientMap.remove(client.getPort()).close();
} catch (Exception e) {
System.err.println("Failure during server initialization");
e.printStackTrace();
;
}
// Thread to allow multiple clients to connect
t = new Thread(() -> {
}
void close() throws InterruptedException {
clientMap.values().stream().forEach(s -> {
try {
while (true) {
System.err.println("Server ready on port " +
serverPort);
SSLSocket c = (SSLSocket)ssock.accept();
clientMap.put(c.getPort(), c);
try {
write(c, read(c));
} catch (Exception e) {
System.out.println("Caught " + e.getMessage());
e.printStackTrace();
exceptionList.add(e);
}
}
} catch (Exception ex) {
System.err.println("Server Down");
ex.printStackTrace();
}
});
t.start();
}
// Exit test to quit the test. This must be called at the end of the
// test or the test will never end.
void done() {
try {
t.join(5000);
ssock.close();
} catch (Exception e) {
System.err.println(e.getMessage());
e.printStackTrace();
}
}
// Read from the client
byte[] read(Client client) throws Exception {
SSLSocket s = clientMap.get(Integer.valueOf(client.getPort()));
if (s == null) {
System.err.println("No socket found, port " + client.getPort());
}
return read(s);
}
// Write to the client
void write(Client client, byte[] data) throws Exception {
write(clientMap.get(client.getPort()), data);
}
// Server writes to the client, then reads from the client.
// Return true if the read & write data match, false if not.
boolean writeRead(Client client, String s) throws Exception{
write(client, s.getBytes());
return (Arrays.compare(s.getBytes(), client.read()) == 0);
}
// Get the SSLSession from the server side socket
SSLSession getSession(Client c) {
SSLSocket s = clientMap.get(Integer.valueOf(c.getPort()));
return s.getSession();
}
// Close client socket
void close(Client c) throws IOException {
SSLSocket s = clientMap.get(Integer.valueOf(c.getPort()));
s.close();
} catch (IOException e) {}
});
threadPool.awaitTermination(500, TimeUnit.MILLISECONDS);
}
List<Exception> getExceptionList() {
@ -312,11 +269,11 @@ abstract public class TLSBase {
}
}
/**
* Client side will establish a connection from the constructor and wait.
* Client side will establish a SSLContext instance.
* It must be run after the Server constructor is called.
*/
static class Client extends TLSBase {
SSLSocket sock;
public SSLSocket socket;
boolean km, tm;
Client() {
this(false, false);
@ -330,55 +287,66 @@ abstract public class TLSBase {
super();
this.km = km;
this.tm = tm;
connect();
}
// Connect to server. Maybe runnable in the future
public SSLSocket connect() {
try {
sslContext = SSLContext.getInstance("TLS");
sslContext.init(TLSBase.getKeyManager(km), TLSBase.getTrustManager(tm), null);
sock = (SSLSocket)sslContext.getSocketFactory().createSocket();
sock.connect(new InetSocketAddress(InetAddress.getLoopbackAddress(), serverPort));
System.err.println("Client connected using port " +
sock.getLocalPort());
name = "client(" + sock.toString() + ")";
write("Hello");
read();
socket = createSocket();
} catch (Exception ex) {
ex.printStackTrace();
}
return sock;
}
// Read from the client socket
byte[] read() throws Exception {
return read(sock);
Client(Client cl) {
sslContext = cl.sslContext;
socket = createSocket();
}
// Write to the client socket
void write(byte[] data) throws Exception {
write(sock, data);
public SSLSocket createSocket() {
try {
return (SSLSocket) sslContext.getSocketFactory().createSocket();
} catch (Exception ex) {
ex.printStackTrace();
}
void write(String s) throws Exception {
write(sock, s.getBytes());
return null;
}
// Client writes to the server, then reads from the server.
// Return true if the read & write data match, false if not.
boolean writeRead(Server server, String s) throws Exception {
write(s.getBytes());
return (Arrays.compare(s.getBytes(), server.read(this)) == 0);
public SSLSocket connect() {
try {
socket.connect(new InetSocketAddress(InetAddress.getLoopbackAddress(), serverPort));
System.err.println("Client (" + Thread.currentThread().getName() + ") connected using port " +
socket.getLocalPort() + " to " + socket.getPort());
writeRead();
} catch (Exception ex) {
ex.printStackTrace();
return null;
}
return socket;
}
// Get port from the socket
int getPort() {
return sock.getLocalPort();
public SSLSession getSession() {
return socket.getSession();
}
public void close() {
try {
socket.close();
} catch (Exception ex) {
ex.printStackTrace();
}
}
public int getPort() {
return socket.getLocalPort();
}
private SSLSocket writeRead() {
try {
write(socket, "Hello".getBytes(StandardCharsets.ISO_8859_1));
read(socket);
} catch (Exception ex) {
ex.printStackTrace();
}
return socket;
}
// Close socket
void close() throws IOException {
sock.close();
}
}
}

View File

@ -0,0 +1,173 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @library /test/lib
* @library /javax/net/ssl/templates
* @bug 8242008
* @summary Verifies multiple PSKs are used by JSSE
* @run main/othervm MultiNSTClient -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.newSessionTicketCount=1
* @run main/othervm MultiNSTClient -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.newSessionTicketCount=3
* @run main/othervm MultiNSTClient -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.newSessionTicketCount=10
* @run main/othervm MultiNSTClient -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true
* @run main/othervm MultiNSTClient -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.enableSessionTicketExtension=false -Djdk.tls.client.enableSessionTicketExtension=true
* @run main/othervm MultiNSTClient -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=false
* @run main/othervm MultiNSTClient -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.enableSessionTicketExtension=false -Djdk.tls.client.enableSessionTicketExtension=false
* @run main/othervm MultiNSTClient -Djdk.tls.client.protocols=TLSv1.2 -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true
*/
import jdk.test.lib.Utils;
import jdk.test.lib.process.OutputAnalyzer;
import jdk.test.lib.process.ProcessTools;
import javax.net.ssl.SSLSession;
import java.util.Arrays;
import java.util.HexFormat;
import java.util.List;
/**
* This test verifies that multiple NSTs and PSKs are sent by a JSSE server.
* Then JSSE client is able to store them all and resume the connection. It
* requires specific text in the TLS debugging to verify the success.
*/
public class MultiNSTClient {
static HexFormat hex = HexFormat.of();
public static void main(String[] args) throws Exception {
if (!args[0].equalsIgnoreCase("p")) {
StringBuilder sb = new StringBuilder();
Arrays.stream(args).forEach(a -> {
sb.append(a);
sb.append(" ");
});
String params = sb.toString();
System.setProperty("test.java.opts",
"-Dtest.src=" + System.getProperty("test.src") +
" -Dtest.jdk=" + System.getProperty("test.jdk") +
" -Dtest.root=" + System.getProperty("test.root") +
" -Djavax.net.debug=ssl,handshake " + params
);
boolean TLS13 = args[0].contains("1.3");
System.out.println("test.java.opts: " +
System.getProperty("test.java.opts"));
ProcessBuilder pb = ProcessTools.createTestJavaProcessBuilder(
Utils.addTestJavaOpts("MultiNSTClient", "p"));
OutputAnalyzer output = ProcessTools.executeProcess(pb);
System.out.println("I'm here");
boolean pass = true;
try {
List<String> list = output.stderrShouldContain("MultiNST PSK").
asLines().stream().filter(s ->
s.contains("MultiNST PSK")).toList();
List<String> serverPSK = list.stream().filter(s ->
s.contains("MultiNST PSK (Server)")).toList();
List<String> clientPSK = list.stream().filter(s ->
s.contains("MultiNST PSK (Client)")).toList();
System.out.println("found list: " + list.size());
System.out.println("found server: " + serverPSK.size());
serverPSK.stream().forEach(s -> System.out.println("\t" + s));
System.out.println("found client: " + clientPSK.size());
clientPSK.stream().forEach(s -> System.out.println("\t" + s));
for (int i = 0; i < 2; i++) {
String svr = serverPSK.getFirst();
String cli = clientPSK.getFirst();
if (svr.regionMatches(svr.length() - 16, cli, cli.length() - 16, 16)) {
System.out.println("entry " + (i + 1) + " match.");
} else {
System.out.println("entry " + (i + 1) + " server and client PSK didn't match:");
System.out.println(" server: " + svr);
System.out.println(" client: " + cli);
pass = false;
}
}
} catch (RuntimeException e) {
System.out.println("No MultiNST PSK found.");
pass = false;
}
if (TLS13) {
if (!pass) {
throw new Exception("Test failed: " + params);
}
} else {
if (pass) {
throw new Exception("Test failed: " + params);
}
}
System.out.println("Test Passed");
return;
}
TLSBase.Server server = new TLSBase.Server();
System.out.println("------ Start connection");
TLSBase.Client initial = new TLSBase.Client();
SSLSession initialSession = initial.connect().getSession();
System.out.println("id = " + hex.formatHex(initialSession.getId()));
System.out.println("session = " + initialSession);
System.out.println("------ getNewSession from original client");
TLSBase.Client resumClient = new TLSBase.Client(initial);
SSLSession resumption = resumClient.connect().getSession();
System.out.println("id = " + hex.formatHex(resumption.getId()));
System.out.println("session = " + resumption);
if (!initialSession.toString().equalsIgnoreCase(resumption.toString())) {
throw new Exception("Resumed session did not match");
}
System.out.println("------ Second getNewSession from original client");
TLSBase.Client resumClient2 = new TLSBase.Client(initial);
resumption = resumClient2.connect().getSession();
System.out.println("id = " + hex.formatHex(resumption.getId()));
System.out.println("session = " + resumption);
if (!initialSession.toString().equalsIgnoreCase(resumption.toString())) {
throw new Exception("Resumed session did not match");
}
System.out.println("------ New client connection");
TLSBase.Client newConnection = new TLSBase.Client();
SSLSession newSession = newConnection.connect().getSession();
System.out.println("id = " + hex.formatHex(newSession.getId()));
System.out.println("session = " + newSession);
if (initialSession.toString().equalsIgnoreCase(newSession.toString())) {
throw new Exception("new session is the same as the initial.");
}
System.out.println("------ Closing connections");
initial.close();
resumClient.close();
resumClient2.close();
newConnection.close();
server.close();
System.out.println("------ End");
System.exit(0);
}
}

View File

@ -0,0 +1,96 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @library /test/lib
* @library /javax/net/ssl/templates
* @bug 8242008
* @summary Verifies resumption fails with 0 NSTs and session creation off
* @run main/othervm MultiNSTNoSessionCreation -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.newSessionTicketCount=0
* @run main/othervm MultiNSTNoSessionCreation -Djdk.tls.client.protocols=TLSv1.2 -Djdk.tls.server.newSessionTicketCount=0
*/
import jdk.test.lib.Utils;
import jdk.test.lib.process.OutputAnalyzer;
import jdk.test.lib.process.ProcessTools;
import java.util.Arrays;
/**
* With no NSTs sent by the server, try to resume the session with
* setEnabledSessionCreation(false). The test should get an exception and
* fail to connect.
*/
public class MultiNSTNoSessionCreation {
public static void main(String[] args) throws Exception {
if (!args[0].equalsIgnoreCase("p")) {
StringBuilder sb = new StringBuilder();
Arrays.stream(args).forEach(a -> sb.append(a).append(" "));
String params = sb.toString();
System.setProperty("test.java.opts",
"-Dtest.src=" + System.getProperty("test.src") +
" -Dtest.jdk=" + System.getProperty("test.jdk") +
" -Dtest.root=" + System.getProperty("test.root") +
" -Djavax.net.debug=ssl,handshake " + params);
System.out.println("test.java.opts: " +
System.getProperty("test.java.opts"));
ProcessBuilder pb = ProcessTools.createTestJavaProcessBuilder(
Utils.addTestJavaOpts("MultiNSTNoSessionCreation", "p"));
OutputAnalyzer output = ProcessTools.executeProcess(pb);
try {
if (output.stderrContains(
"(PROTOCOL_VERSION): New session creation is disabled")) {
return;
}
} catch (RuntimeException e) {
throw new Exception("Error collecting data", e);
}
throw new Exception("Disabled creation msg not found");
}
TLSBase.Server server = new TLSBase.Server();
System.out.println("------ Initial connection");
TLSBase.Client initial = new TLSBase.Client();
initial.connect();
System.out.println(
"------ Resume client w/ setEnableSessionCreation set to false");
TLSBase.Client resumClient = new TLSBase.Client(initial);
resumClient.socket.setEnableSessionCreation(false);
resumClient.connect();
System.out.println("------ Closing connections");
initial.close();
resumClient.close();
server.close();
System.out.println("------ End");
System.exit(0);
}
}

View File

@ -0,0 +1,205 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @library /test/lib
* @library /javax/net/ssl/templates
* @bug 8242008
* @summary Verifies multiple PSKs are used by TLSv1.3
* @run main/othervm MultiNSTParallel 10 -Djdk.tls.client.protocols=TLSv1.3
*/
import jdk.test.lib.Utils;
import jdk.test.lib.process.OutputAnalyzer;
import jdk.test.lib.process.ProcessTools;
import javax.net.ssl.SSLSession;
import java.util.ArrayList;
import java.util.HexFormat;
import java.util.List;
import java.util.concurrent.CountDownLatch;
/**
* This test verifies that parallel resumption connections successfully get
* a PSK entry and not initiate a full handshake.
*
* Note: THe first argument after 'MultiNSTParallel' is the ticket count
* The test will set 'jdk.tls.server.NewSessionTicketCount` to that number and
* will start the same number of resumption client attempts. The ticket count
* must be the same or larger than resumption attempts otherwise the queue runs
* empty and the test will fail.
*
* Because this test runs parallel connections, the thread order finish is not
* guaranteed. Each client NST id is checked with all server NSTs ids until
* a match is found. When a match is found, it is removed from the list to
* verify no NST was used more than once.
*
* TLS 1.2 spec does not specify multiple NST behavior.
*/
public class MultiNSTParallel {
static HexFormat hex = HexFormat.of();
final static CountDownLatch wait = new CountDownLatch(1);
static class ClientThread extends Thread {
TLSBase.Client client;
ClientThread(TLSBase.Client c) {
client = c;
}
public void run() {
String name = Thread.currentThread().getName();
SSLSession r;
System.err.println("waiting " + Thread.currentThread().getName());
try {
wait.await();
r = new TLSBase.Client(client).connect().getSession();
} catch (Exception e) {
throw new RuntimeException(name + ": " +e);
}
StringBuffer sb = new StringBuffer(100);
sb.append("(").append(name).append(") id = ");
sb.append(hex.formatHex(r.getId()));
sb.append("\n(").append(name).append(") session = ").append(r);
if (!client.getSession().toString().equalsIgnoreCase(r.toString())) {
throw new RuntimeException("(" + name +
") Resumed session did not match");
}
}
}
static boolean pass = true;
public static void main(String[] args) throws Exception {
if (!args[0].equalsIgnoreCase("p")) {
int ticketCount = Integer.parseInt(args[0]);
StringBuilder sb = new StringBuilder();
for (int i = 1; i < args.length; i++) {
sb.append(" ").append(args[i]);
}
String params = sb.toString();
System.setProperty("test.java.opts",
"-Dtest.src=" + System.getProperty("test.src") +
" -Dtest.jdk=" + System.getProperty("test.jdk") +
" -Dtest.root=" + System.getProperty("test.root") +
" -Djavax.net.debug=ssl,handshake " +
" -Djdk.tls.server.newSessionTicketCount=" + ticketCount +
params);
boolean TLS13 = args[1].contains("1.3");
System.out.println("test.java.opts: " +
System.getProperty("test.java.opts"));
ProcessBuilder pb = ProcessTools.createTestJavaProcessBuilder(
Utils.addTestJavaOpts("MultiNSTParallel", "p"));
OutputAnalyzer output = ProcessTools.executeProcess(pb);
try {
List<String> list = output.stderrShouldContain("MultiNST PSK").
asLines().stream().filter(s ->
s.contains("MultiNST PSK")).toList();
List<String> sp = list.stream().filter(s ->
s.contains("MultiNST PSK (Server)")).toList();
List<String> serverPSK = new ArrayList<>(sp.stream().toList());
List<String> clientPSK = list.stream().filter(s ->
s.contains("MultiNST PSK (Client)")).toList();
System.out.println("found list: " + list.size());
System.out.println("found server: " + serverPSK.size());
serverPSK.stream().forEach(s -> System.out.println("\t" + s));
System.out.println("found client: " + clientPSK.size());
clientPSK.stream().forEach(s -> System.out.println("\t" + s));
// Must search all results as order is not guaranteed.
clientPSK.stream().forEach(cli -> {
for (int i = 0; i < serverPSK.size(); i++) {
String svr = serverPSK.get(i);
if (svr.regionMatches(svr.length() - 16, cli,
cli.length() - 16, 16)) {
System.out.println("entry " + (i + 1) + " match.");
serverPSK.remove(i);
return;
}
}
System.out.println("client entry (" + cli.substring(0, 16) +
") not found in server list");
pass = false;
});
} catch (RuntimeException e) {
System.out.println("Error looking at PSK results.");
throw new Exception(e);
}
if (TLS13) {
if (!pass) {
throw new Exception("Test failed: " + params);
}
} else {
if (pass) {
throw new Exception("Test failed: " + params);
}
}
System.out.println("Test Passed");
return;
}
int ticketCount = Integer.parseInt(
System.getProperty("jdk.tls.server.newSessionTicketCount"));
TLSBase.Server server = new TLSBase.Server();
System.out.println("------ Start connection");
TLSBase.Client initial = new TLSBase.Client();
SSLSession initialSession = initial.getSession();
System.out.println("id = " + hex.formatHex(initialSession.getId()));
System.out.println("session = " + initialSession);
System.out.println("------ getNewSession from original client");
ArrayList<Thread> slist = new ArrayList<>(ticketCount);
System.out.println("tx " + ticketCount);
for (int i = 0; ticketCount > i; i++) {
Thread t = new ClientThread(initial);
t.setName("Iteration " + i);
slist.add(t);
t.start();
}
wait.countDown();
for (Thread t : slist) {
t.join(1000);
System.err.println("released: " + t.getName());
}
System.out.println("------ Closing connections");
initial.close();
server.close();
System.out.println("------ End");
System.exit(0);
}
}

View File

@ -0,0 +1,145 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @library /test/lib
* @library /javax/net/ssl/templates
* @bug 8242008
* @summary Verifies sequence of used NST entries from the cache queue.
* @run main/othervm MultiNSTSequence -Djdk.tls.server.newSessionTicketCount=2
*/
import jdk.test.lib.Utils;
import jdk.test.lib.process.OutputAnalyzer;
import jdk.test.lib.process.ProcessTools;
import javax.net.ssl.SSLSession;
import java.util.Arrays;
import java.util.HexFormat;
import java.util.List;
/**
* This test verifies that multiple NSTs take the oldest PSK from the
* QueueCacheEntry stored in the TLS Session Cache.
*
* Note: Beyond 9 iterations the PSK id verification code becomes complicated
* with a QueueCacheEntry limit set to retain only the 10 newest entries.
*
* TLS 1.2 spec does not specify multiple NST behavior.
*/
public class MultiNSTSequence {
static HexFormat hex = HexFormat.of();
static final int ITERATIONS = 9;
public static void main(String[] args) throws Exception {
if (!args[0].equalsIgnoreCase("p")) {
StringBuilder sb = new StringBuilder();
Arrays.stream(args).forEach(a -> sb.append(a).append(" "));
String params = sb.toString();
System.setProperty("test.java.opts",
"-Dtest.src=" + System.getProperty("test.src") +
" -Dtest.jdk=" + System.getProperty("test.jdk") +
" -Dtest.root=" + System.getProperty("test.root") +
" -Djavax.net.debug=ssl,handshake " + params
);
System.out.println("test.java.opts: " +
System.getProperty("test.java.opts"));
ProcessBuilder pb = ProcessTools.createTestJavaProcessBuilder(
Utils.addTestJavaOpts("MultiNSTSequence", "p"));
OutputAnalyzer output = ProcessTools.executeProcess(pb);
boolean pass = true;
try {
List<String> list = output.stderrShouldContain("MultiNST PSK").
asLines().stream().filter(s ->
s.contains("MultiNST PSK")).toList();
List<String> serverPSK = list.stream().filter(s ->
s.contains("MultiNST PSK (Server)")).toList();
List<String> clientPSK = list.stream().filter(s ->
s.contains("MultiNST PSK (Client)")).toList();
System.out.println("found list: " + list.size());
System.out.println("found server: " + serverPSK.size());
serverPSK.stream().forEach(s -> System.out.println("\t" + s));
System.out.println("found client: " + clientPSK.size());
clientPSK.stream().forEach(s -> System.out.println("\t" + s));
int i;
for (i = 0; i < ITERATIONS; i++) {
String svr = serverPSK.get(i);
String cli = clientPSK.get(i);
if (svr.regionMatches(svr.length() - 16, cli, cli.length() - 16, 16)) {
System.out.println("entry " + (i + 1) + " match.");
} else {
System.out.println("entry " + (i + 1) + " server and client PSK didn't match:");
System.out.println(" server: " + svr);
System.out.println(" client: " + cli);
pass = false;
}
}
} catch (RuntimeException e) {
System.out.println("Server and Client PSK usage order is not" +
" the same.");
pass = false;
}
if (!pass) {
throw new Exception("Test failed: " + params);
}
System.out.println("Test Passed");
return;
}
TLSBase.Server server = new TLSBase.Server();
System.out.println("------ Initial connection");
TLSBase.Client initial = new TLSBase.Client();
SSLSession initialSession = initial.connect().getSession();
System.out.println("id = " + hex.formatHex(initialSession.getId()));
System.out.println("session = " + initialSession);
System.out.println("------ Resume client");
for (int i = 0; i < ITERATIONS; i++) {
SSLSession r = new TLSBase.Client(initial).connect().getSession();
StringBuilder sb = new StringBuilder(100);
sb.append("Iteration: ").append(i);
sb.append("\tid = ").append(hex.formatHex(r.getId()));
sb.append("\tsession = ").append(r);
System.out.println(sb);
if (!initialSession.toString().equalsIgnoreCase(r.toString())) {
throw new Exception("Resumed session did not match");
}
}
System.out.println("------ Closing connections");
initial.close();
server.close();
System.out.println("------ End");
System.exit(0);
}
}