jdk-24/test/jdk/sun/security/ssl/SSLSessionImpl/NoInvalidateSocketException.java

410 lines
14 KiB
Java
Raw Normal View History

/*
* Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
//
// SunJSSE does not support dynamic system properties, so there is no way
// to re-use system properties in samevm/agentvm mode.
//
/*
* @test
* @bug 8274736 8277970
* @summary Concurrent read/close of SSLSockets causes SSLSessions to be
* invalidated unnecessarily
* @library /javax/net/ssl/templates
* @run main/othervm NoInvalidateSocketException TLSv1.3
* @run main/othervm NoInvalidateSocketException TLSv1.2
* @run main/othervm -Djdk.tls.client.enableSessionTicketExtension=false
* NoInvalidateSocketException TLSv1.2
*/
import java.io.*;
import javax.net.ssl.*;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
public class NoInvalidateSocketException extends SSLSocketTemplate {
// Number of open / handshake / read-vs-close cycles the client performs
// in runClientApplication().
private static final int ITERATIONS = 10;
// This controls how long the main thread waits before closing the socket.
// The value is passed to Thread.sleep(), i.e. it is in milliseconds.
// This may need tweaking for different environments to get the timing
// right.
private static final int CLOSE_DELAY = 10;
// Client-side SSLContext; created once in runClientApplication() and used
// by openSSLSocket() for every client socket.
private static SSLContext clientSSLCtx;
// The client socket of the current iteration (set by openSSLSocket()).
private static SSLSocket theSSLSocket;
// The session obtained from theSSLSocket in getAndCompareSession().
private static SSLSession theSSLSession;
// Input stream of theSSLSocket (set by getInputStream()); the reader
// thread reads from it via readFromSSLSocket().
private static InputStream theInputStream;
// theSSLSocket.hashCode() formatted as 8 hex digits; only used to label
// log messages.
private static String theSSLSocketHashCode;
// The session recorded at the end of the previous iteration, or null
// before the first iteration completes.
private static SSLSession lastSSLSession;
// Every socket accepted by doServerSide(); closed when the server exits.
private static final List<SSLSocket> serverCleanupList = new ArrayList<>();
// Protocol name taken from the first command-line argument; when null,
// openSSLSocket() leaves the socket's enabled protocols untouched.
private static String tlsVersion = null;
// Incremented by isSessionValid() for each session found invalid after
// the read/close collision; main() fails the test if it is non-zero.
private static int invalidSessCount = 0;
// Set to true by the main client thread to make the reader thread start
// reading; reset to false by the reader thread on exception or EOF.
private static volatile boolean readFromSocket = false;
// Set to true when runClientApplication() exits; polled by the reader
// thread and by the server accept loop as their termination signal.
private static volatile boolean finished = false;
/**
 * Test entry point.
 *
 * <p>Enables "session" JSSE debug output unless javax.net.debug is already
 * set, records the optional protocol name from the first argument, runs
 * the test and fails if any session was counted as invalidated.
 *
 * @param args optional; args[0], when present, is the single protocol to
 *        enable on the client socket
 * @throws Exception whatever run() throws, or a RuntimeException when at
 *         least one session was invalidated
 */
public static void main(String[] args) throws Exception {
    // Do not override a debug setting chosen by whoever launched the test.
    final boolean debugPreset = System.getProperty("javax.net.debug") != null;
    if (!debugPreset) {
        System.setProperty("javax.net.debug", "session");
    }

    final boolean hasProtocolArg = args != null && args.length > 0;
    if (hasProtocolArg) {
        tlsVersion = args[0];
    }

    NoInvalidateSocketException test = new NoInvalidateSocketException(true);
    test.run();

    if (invalidSessCount <= 0) {
        return;
    }
    throw new RuntimeException(
            "One or more sessions were improperly invalidated.");
}
/**
 * Creates the test instance.
 *
 * @param sepSrvThread forwarded unchanged to the SSLSocketTemplate
 *        constructor; presumably selects whether the server side runs in
 *        its own thread (confirm against SSLSocketTemplate)
 */
public NoInvalidateSocketException(boolean sepSrvThread) {
super(sepSrvThread);
}
/**
 * Always returns true.
 *
 * NOTE(review): presumably this tells SSLSocketTemplate to hand control to
 * runClientApplication(int) instead of creating the client connection
 * itself; confirm against the template.
 */
@Override
public boolean isCustomizedClientConnection() {
return true;
}
/**
 * Runs the client side of the test on the calling thread.
 *
 * <p>Creates the client SSLContext, starts the reader thread and then
 * repeats ITERATIONS times: open a socket, handshake, fetch the input
 * stream, record the session, collide a read with a close, and check
 * that the session is still valid.
 *
 * <p>An unexpected failure is logged and then rethrown as a
 * RuntimeException (with the original exception as its cause) so that it
 * propagates to the caller instead of being silently dropped. The
 * finished flag is raised on every exit path, including a failure to
 * create the SSLContext, so the reader thread and the server accept loop
 * are always told to stop.
 *
 * @param serverPort not used directly; openSSLSocket() connects using the
 *        serverPort field
 * @throws RuntimeException if the SSLContext cannot be created or any
 *         iteration fails unexpectedly
 */
@Override
public void runClientApplication(int serverPort) {
    Thread.currentThread().setName("Main Client Thread");

    try {
        // Create the SSLContext we'll use for client sockets for the
        // duration of the test.
        try {
            clientSSLCtx = createClientSSLContext();
        } catch (Exception e) {
            throw new RuntimeException("Failed to create client ctx", e);
        }

        // Create the reader thread
        ReaderThread readerThread = new ReaderThread();
        readerThread.setName("Client Reader Thread");
        readerThread.start();

        try {
            for (int i = 0; i < ITERATIONS; i++) {
                openSSLSocket();
                doHandshake();
                getInputStream();
                getAndCompareSession();

                // Perform the Close/Read MT collision
                readCloseMultiThreaded();

                // Check to make sure that the initially negotiated session
                // remains intact.
                isSessionValid();

                lastSSLSession = theSSLSession;

                // Insert a short gap between iterations
                Thread.sleep(1000);
                System.out.println();
            }
        } catch (Exception e) {
            // An aborted run has not exercised every iteration, so it
            // must not be allowed to look like a pass.
            logToConsole("Unexpected Exception: " + e);
            throw new RuntimeException(
                    "Unexpected exception in client application", e);
        }
    } finally {
        // Tell the reader thread (and the server loop) to finish
        finished = true;
    }
}
/**
 * Provokes the read/close collision for the current socket.
 *
 * <p>Signals the reader thread to start reading, waits CLOSE_DELAY
 * milliseconds, closes the socket from this thread, and then waits a
 * further 500 ms.
 *
 * @throws IOException if closing the socket fails
 * @throws InterruptedException if this thread is interrupted while
 *         sleeping
 */
private void readCloseMultiThreaded() throws IOException,
        InterruptedException {
    // Release the reader thread onto the current socket.
    readFromSocket = true;

    // Give the reader a moment to get into its read. TimeUnit.sleep()
    // does nothing for a non-positive duration.
    TimeUnit.MILLISECONDS.sleep(CLOSE_DELAY);

    // The collision of interest is the reader being inside read() while
    // this thread is inside close().
    closeSSLSocket();

    // Leave time for the reader to notice the closed socket and report
    // its exception.
    TimeUnit.MILLISECONDS.sleep(500);
}
/**
 * Background thread that reads from the current client socket whenever
 * the main client thread raises readFromSocket.
 *
 * <p>The thread spins until finished becomes true. While readFromSocket
 * is set it calls readFromSSLSocket(); an exception or an end-of-stream
 * result (-1) is logged and clears readFromSocket so that no further
 * reads are attempted until the main client thread raises the flag again.
 */
private class ReaderThread extends Thread {
    @Override
    public void run() {
        // This thread runs in a tight loop until
        // readFromSocket == true
        while (!finished) {
            if (!readFromSocket) {
                // Nothing to read yet: keep polling, but tell the runtime
                // this is a spin-wait.
                Thread.onSpinWait();
                continue;
            }

            int result = 0;
            try {
                // If the timing is just right, this will throw a
                // SocketException and the SSLSession will be
                // invalidated.
                result = readFromSSLSocket();
            } catch (Exception e) {
                logToConsole("Exception reading from SSLSocket@" +
                        theSSLSocketHashCode + ": " + e);
                e.printStackTrace(System.out);

                // Stop trying to read from the socket now
                readFromSocket = false;
            }

            if (result == -1) {
                logToConsole("Reached end of stream reading from " +
                        "SSLSocket@" + theSSLSocketHashCode);

                // Stop trying to read from the socket now
                readFromSocket = false;
            }
        }
    }
}
/**
 * Opens a new client socket to serverAddress:serverPort using the shared
 * client SSLContext and stores it in theSSLSocket.
 *
 * <p>When a protocol name was supplied on the command line it becomes the
 * only enabled protocol. The socket's hash code, formatted as 8 hex
 * digits, is stored for use in log messages.
 *
 * @throws IOException if the socket cannot be created
 */
private void openSSLSocket() throws IOException {
    SSLSocketFactory socketFactory = clientSSLCtx.getSocketFactory();
    theSSLSocket =
            (SSLSocket) socketFactory.createSocket(serverAddress, serverPort);

    final String onlyProtocol = tlsVersion;
    if (onlyProtocol != null) {
        String[] enabled = { onlyProtocol };
        theSSLSocket.setEnabledProtocols(enabled);
    }

    int identity = theSSLSocket.hashCode();
    theSSLSocketHashCode = String.format("%08x", identity);
    logToConsole("Opened SSLSocket@" + theSSLSocketHashCode);
}
/**
 * Performs the TLS handshake on the current socket, logging before and
 * after.
 *
 * @throws IOException if startHandshake() fails
 */
private void doHandshake() throws IOException {
    final String socketLabel = "SSLSocket@" + theSSLSocketHashCode;

    logToConsole("Started handshake on " + socketLabel);
    theSSLSocket.startHandshake();
    logToConsole("Finished handshake on " + socketLabel);
}
// Caches the current socket's input stream in theInputStream, which is
// where readFromSSLSocket() reads from.
private void getInputStream() throws IOException {
theInputStream = theSSLSocket.getInputStream();
}
/**
 * Records the session of the current socket in theSSLSession and logs
 * whether it is a new session or equal to the one recorded from the
 * previous iteration.
 */
private void getAndCompareSession() {
    theSSLSession = theSSLSocket.getSession();

    // Have we opened a new session or re-used the last one?
    boolean reused = lastSSLSession != null &&
            theSSLSession.equals(lastSSLSession);
    if (reused) {
        // The message previously ended with an unmatched ")".
        logToConsole("*** RE-USING PREVIOUS SESSION ***: " +
                theSSLSession);
    } else {
        logToConsole("*** OPENED NEW SESSION ***: " +
                theSSLSession);
    }
}
/**
 * Closes the current socket, logging before and after the close() call.
 *
 * @throws IOException if close() fails
 */
private void closeSSLSocket() throws IOException {
    final String socketLabel = "SSLSocket@" + theSSLSocketHashCode;

    logToConsole("Closing " + socketLabel);
    theSSLSocket.close();
    logToConsole("Closed " + socketLabel);
}
/**
 * Reads a single byte from the cached input stream of the current
 * socket, logging before and after the read.
 *
 * @return the value returned by InputStream.read(): the byte read, or -1
 *         at end of stream
 * @throws Exception whatever the read throws
 */
private int readFromSSLSocket() throws Exception {
    final String socketLabel = "SSLSocket@" + theSSLSocketHashCode;
    logToConsole("Started reading from " + socketLabel);

    final int byteRead = theInputStream.read();

    logToConsole("Finished reading from SSLSocket@" +
            theSSLSocketHashCode + ": result = " + byteRead);
    return byteRead;
}
/**
 * Logs whether the recorded session is still valid and, when it is not,
 * increments invalidSessCount.
 */
private void isSessionValid() {
    final boolean stillValid = theSSLSession.isValid();
    final String verdict = stillValid ? " IS VALID ***" : " IS INVALID ***";

    logToConsole("*** " + theSSLSession + verdict);

    if (!stillValid) {
        invalidSessCount++;
    }
}
/**
 * Prints one line to System.out in the form
 * "<System.nanoTime()>: <current thread name>: <message>".
 *
 * @param s the message text
 */
private static void logToConsole(String s) {
    final String timestamp = Long.toString(System.nanoTime());
    final String threadName = Thread.currentThread().getName();
    System.out.println(String.join(": ", timestamp, threadName, s));
}
/**
 * Server side of the test: listens, accepts client connections one at a
 * time, and hands each to a ServerHandlerThread.
 *
 * <p>The listening socket is created unbound and then bound to
 * serverAddress when that address is non-null; otherwise it is created
 * directly on serverPort. After configuration the actual local port is
 * stored in serverPort and serverCondition is counted down.
 *
 * <p>The accept loop ends when finished is true. An accept timeout
 * returns immediately if finished is set; three consecutive timeouts
 * without a connection rethrow the SocketTimeoutException. On exit the
 * listening socket and every accepted socket are closed.
 *
 * @throws Exception on any setup failure, on interruption while waiting,
 *         or when the accept timeout limit is exceeded
 */
@Override
public void doServerSide() throws Exception {
    Thread.currentThread().setName("Server Listener Thread");

    final SSLContext serverCtx = createServerSSLContext();
    final SSLServerSocketFactory factory = serverCtx.getServerSocketFactory();
    final InetAddress bindAddress = this.serverAddress;

    final SSLServerSocket listener;
    if (bindAddress == null) {
        listener = (SSLServerSocket) factory.createServerSocket(serverPort);
    } else {
        listener = (SSLServerSocket) factory.createServerSocket();
        listener.bind(new InetSocketAddress(bindAddress, serverPort));
    }

    configureServerSocket(listener);
    serverPort = listener.getLocalPort();
    logToConsole("Listening on " + listener.getLocalSocketAddress());

    // Let the client know that connections can now be accepted.
    serverCondition.countDown();

    // Keep accepting until the client raises the finished flag. Each
    // accept() is bounded by the socket timeout set in
    // configureServerSocket().
    int consecutiveTimeouts = 0;
    try {
        while (true) {
            try {
                SSLSocket accepted = (SSLSocket) listener.accept();
                consecutiveTimeouts = 0;
                logToConsole("Accepted connection from " +
                        accepted.getRemoteSocketAddress());

                // Remember the socket so it is closed when the server
                // shuts down.
                serverCleanupList.add(accepted);

                if (clientCondition.await(30L, TimeUnit.SECONDS)) {
                    // Serve the connection on its own thread, but wait
                    // for it before accepting the next one.
                    ServerHandlerThread handler =
                            new ServerHandlerThread(accepted);
                    try {
                        handler.start();
                    } finally {
                        handler.join();
                    }
                }
            } catch (SocketTimeoutException ste) {
                consecutiveTimeouts++;
                // A timeout after the client is done is the normal way
                // out; too many timeouts in a row is an error.
                if (finished) {
                    return;
                }
                if (consecutiveTimeouts >= 3) {
                    logToConsole("Server accept timeout exceeded");
                    throw ste;
                }
            }

            if (finished) {
                break;
            }
        }
    } finally {
        listener.close();

        // Close every accepted socket as well.
        for (SSLSocket leftover : serverCleanupList) {
            if (leftover == null) {
                continue;
            }
            try {
                leftover.close();
            } catch (IOException ignored) {
                // The server is shutting down anyway, so a failed close
                // is of no consequence.
            }
        }
    }
}
/**
 * Enables address reuse on the listening socket and gives accept() a
 * 5000 ms timeout.
 *
 * @param socket the listening socket to configure
 * @throws RuntimeException wrapping any SocketException raised while
 *         applying the options
 */
@Override
public void configureServerSocket(SSLServerSocket socket) {
    final int acceptTimeoutMillis = 5000;
    try {
        socket.setReuseAddress(true);
        socket.setSoTimeout(acceptTimeoutMillis);
    } catch (SocketException cause) {
        // The overridden signature is kept free of checked exceptions,
        // so rethrow unchecked.
        throw new RuntimeException(cause);
    }
}
/**
 * Serves one accepted connection: performs a single read of up to 1024
 * characters, logs the count returned by the read, and writes one reply
 * line back to the peer.
 *
 * <p>I/O failures are deliberately ignored; they are not what this test
 * is checking.
 *
 * @param sslSocket the accepted connection to serve
 */
@Override
public void runServerApplication(SSLSocket sslSocket) {
    Thread.currentThread().setName("Server Reader Thread");

    try {
        InputStream rawIn = sslSocket.getInputStream();
        BufferedReader reader =
                new BufferedReader(new InputStreamReader(rawIn));

        OutputStream rawOut = sslSocket.getOutputStream();
        PrintWriter writer = new PrintWriter(
                new BufferedWriter(new OutputStreamWriter(rawOut)));

        // A single burst of data is all that is handled.
        char[] chunk = new char[1024];
        int count = reader.read(chunk);
        String report =
                String.format("Received: %d bytes of data\n", count);
        logToConsole(report);

        writer.println("Received connection from client");
        writer.flush();
    } catch (IOException ignored) {
        // Intentionally ignored for this test.
    }
}
/**
 * Thread that runs runServerApplication(SSLSocket) for one accepted
 * connection.
 */
private class ServerHandlerThread extends Thread {
    // The accepted connection served by this thread; never null.
    SSLSocket sock;

    /**
     * @param socket the accepted connection
     * @throws NullPointerException if socket is null
     */
    ServerHandlerThread(SSLSocket socket) {
        sock = Objects.requireNonNull(socket, "Illegal null socket");
    }

    /** Serves the connection; any exception is rethrown unchecked. */
    @Override
    public void run() {
        try {
            runServerApplication(sock);
        } catch (Exception failure) {
            // run() cannot declare checked exceptions, so wrap.
            throw new RuntimeException(failure);
        }
    }

    /** Closes the connection, ignoring any failure to do so. */
    void close() {
        try {
            sock.close();
        } catch (IOException ignored) {
            // Best-effort close; nothing useful can be done here.
        }
    }
}
}