/* * Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. */ // // SunJSSE does not support dynamic system properties, no way to re-use // system properties in samevm/agentvm mode. // /* * @test * @bug 8274736 8277970 * @summary Concurrent read/close of SSLSockets causes SSLSessions to be * invalidated unnecessarily * @library /javax/net/ssl/templates * @run main/othervm NoInvalidateSocketException TLSv1.3 * @run main/othervm NoInvalidateSocketException TLSv1.2 * @run main/othervm -Djdk.tls.client.enableSessionTicketExtension=false * NoInvalidateSocketException TLSv1.2 */ import java.io.*; import javax.net.ssl.*; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketException; import java.net.SocketTimeoutException; import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.concurrent.TimeUnit; public class NoInvalidateSocketException extends SSLSocketTemplate { private static final int ITERATIONS = 10; // This controls how long the main thread waits before closing the socket. // This may need tweaking for different environments to get the timing // right. private static final int CLOSE_DELAY = 10; private static SSLContext clientSSLCtx; private static SSLSocket theSSLSocket; private static SSLSession theSSLSession; private static InputStream theInputStream; private static String theSSLSocketHashCode; private static SSLSession lastSSLSession; private static final List serverCleanupList = new ArrayList<>(); private static String tlsVersion = null; private static int invalidSessCount = 0; private static volatile boolean readFromSocket = false; private static volatile boolean finished = false; public static void main(String[] args) throws Exception { if (System.getProperty("javax.net.debug") == null) { System.setProperty("javax.net.debug", "session"); } if (args != null && args.length >= 1) { tlsVersion = args[0]; } new NoInvalidateSocketException(true).run(); if (invalidSessCount > 0) { throw new RuntimeException("One or more sessions were improperly " + "invalidated."); } } public NoInvalidateSocketException(boolean sepSrvThread) { super(sepSrvThread); } @Override public boolean isCustomizedClientConnection() { return true; } @Override public void runClientApplication(int serverPort) { Thread.currentThread().setName("Main Client Thread"); // Create the SSLContext we'll use for client sockets for the // duration of the test. try { clientSSLCtx = createClientSSLContext(); } catch (Exception e) { throw new RuntimeException("Failed to create client ctx", e); } // Create the reader thread ReaderThread readerThread = new ReaderThread(); readerThread.setName("Client Reader Thread"); readerThread.start(); try { for (int i = 0; i < ITERATIONS; i++) { openSSLSocket(); doHandshake(); getInputStream(); getAndCompareSession(); // Perform the Close/Read MT collision readCloseMultiThreaded(); // Check to make sure that the initially negotiated session // remains intact. isSessionValid(); lastSSLSession = theSSLSession; // Insert a short gap between iterations Thread.sleep(1000); System.out.println(); } } catch (Exception e) { logToConsole("Unexpected Exception: " + e); } finally { // Tell the reader thread to finish finished = true; } } private void readCloseMultiThreaded() throws IOException, InterruptedException { // Tell the reader thread to start trying to read from this // socket readFromSocket = true; // Short pause to give the reader thread time to start // reading. if (CLOSE_DELAY > 0) { Thread.sleep(CLOSE_DELAY); } // The problem happens when the reader thread tries to read // from the socket while this thread is in the close() call closeSSLSocket(); // Pause to give the reader thread time to discover that the // socket is closed and throw a SocketException Thread.sleep(500); } private class ReaderThread extends Thread { public void run() { // This thread runs in a tight loop until // readFromSocket == true while (!finished) { if (readFromSocket) { int result = 0; try { // If the timing is just // right, this will throw a SocketException // and the SSLSession will be // invalidated. result = readFromSSLSocket(); } catch (Exception e) { logToConsole("Exception reading from SSLSocket@" + theSSLSocketHashCode + ": " + e); e.printStackTrace(System.out); // Stop trying to read from // the socket now readFromSocket = false; } if (result == -1) { logToConsole("Reached end of stream reading from " + "SSLSocket@" + theSSLSocketHashCode); // Stop trying to read from // the socket now readFromSocket = false; } } } } } private void openSSLSocket() throws IOException { theSSLSocket = (SSLSocket)clientSSLCtx.getSocketFactory(). createSocket(serverAddress, serverPort); if (tlsVersion != null) { theSSLSocket.setEnabledProtocols(new String[] { tlsVersion }); } theSSLSocketHashCode = String.format("%08x", theSSLSocket.hashCode()); logToConsole("Opened SSLSocket@" + theSSLSocketHashCode); } private void doHandshake() throws IOException { logToConsole("Started handshake on SSLSocket@" + theSSLSocketHashCode); theSSLSocket.startHandshake(); logToConsole("Finished handshake on SSLSocket@" + theSSLSocketHashCode); } private void getInputStream() throws IOException { theInputStream = theSSLSocket.getInputStream(); } private void getAndCompareSession() { theSSLSession = theSSLSocket.getSession(); // Have we opened a new session or re-used the last one? if (lastSSLSession == null || !theSSLSession.equals(lastSSLSession)) { logToConsole("*** OPENED NEW SESSION ***: " + theSSLSession); } else { logToConsole("*** RE-USING PREVIOUS SESSION ***: " + theSSLSession + ")"); } } private void closeSSLSocket() throws IOException { logToConsole("Closing SSLSocket@" + theSSLSocketHashCode); theSSLSocket.close(); logToConsole("Closed SSLSocket@" + theSSLSocketHashCode); } private int readFromSSLSocket() throws Exception { logToConsole("Started reading from SSLSocket@" + theSSLSocketHashCode); int result = theInputStream.read(); logToConsole("Finished reading from SSLSocket@" + theSSLSocketHashCode + ": result = " + result); return result; } private void isSessionValid() { // Is the session still valid? if (theSSLSession.isValid()) { logToConsole("*** " + theSSLSession + " IS VALID ***"); } else { logToConsole("*** " + theSSLSession + " IS INVALID ***"); invalidSessCount++; } } private static void logToConsole(String s) { System.out.println(System.nanoTime() + ": " + Thread.currentThread().getName() + ": " + s); } @Override public void doServerSide() throws Exception { Thread.currentThread().setName("Server Listener Thread"); SSLContext context = createServerSSLContext(); SSLServerSocketFactory sslssf = context.getServerSocketFactory(); InetAddress serverAddress = this.serverAddress; SSLServerSocket sslServerSocket = serverAddress == null ? (SSLServerSocket)sslssf.createServerSocket(serverPort) : (SSLServerSocket)sslssf.createServerSocket(); if (serverAddress != null) { sslServerSocket.bind(new InetSocketAddress(serverAddress, serverPort)); } configureServerSocket(sslServerSocket); serverPort = sslServerSocket.getLocalPort(); logToConsole("Listening on " + sslServerSocket.getLocalSocketAddress()); // Signal the client, the server is ready to accept connection. serverCondition.countDown(); // Try to accept a connection in 5 seconds. // We will do this in a loop until the client flips the // finished variable to true SSLSocket sslSocket; int timeoutCount = 0; try { do { try { sslSocket = (SSLSocket) sslServerSocket.accept(); timeoutCount = 0; // Reset the timeout counter; logToConsole("Accepted connection from " + sslSocket.getRemoteSocketAddress()); // Add the socket to the cleanup list so it can get // closed at the end of the test serverCleanupList.add(sslSocket); boolean clientIsReady = clientCondition.await(30L, TimeUnit.SECONDS); if (clientIsReady) { // Handle the connection in a new thread ServerHandlerThread sht = null; try { sht = new ServerHandlerThread(sslSocket); sht.start(); } finally { if (sht != null) { sht.join(); } } } } catch (SocketTimeoutException ste) { timeoutCount++; // If we are finished then we can return, otherwise // check if we've timed out too many times (an exception // case). One way or the other we will exit eventually. if (finished) { return; } else if (timeoutCount >= 3) { logToConsole("Server accept timeout exceeded"); throw ste; } } } while (!finished); } finally { sslServerSocket.close(); // run through the server cleanup list and close those sockets // as well. for (SSLSocket sock : serverCleanupList) { try { if (sock != null) { sock.close(); } } catch (IOException ioe) { // Swallow these close failures as the server itself // is shutting down anyway. } } } } @Override public void configureServerSocket(SSLServerSocket socket) { try { socket.setReuseAddress(true); socket.setSoTimeout(5000); } catch (SocketException se) { // Rethrow as unchecked to satisfy the override signature throw new RuntimeException(se); } } @Override public void runServerApplication(SSLSocket sslSocket) { Thread.currentThread().setName("Server Reader Thread"); SSLSocket sock = null; sock = sslSocket; try { BufferedReader is = new BufferedReader( new InputStreamReader(sock.getInputStream())); PrintWriter os = new PrintWriter(new BufferedWriter( new OutputStreamWriter(sock.getOutputStream()))); // Only handle a single burst of data char[] buf = new char[1024]; int dataRead = is.read(buf); logToConsole(String.format("Received: %d bytes of data\n", dataRead)); os.println("Received connection from client"); os.flush(); } catch (IOException ioe) { // Swallow these exceptions for this test } } private class ServerHandlerThread extends Thread { SSLSocket sock; ServerHandlerThread(SSLSocket socket) { this.sock = Objects.requireNonNull(socket, "Illegal null socket"); } @Override public void run() { try { runServerApplication(sock); } catch (Exception exc) { // Wrap inside an unchecked exception to satisfy Runnable throw new RuntimeException(exc); } } void close() { try { sock.close(); } catch (IOException e) { // swallow this exception } } } }