8239798: SSLSocket closes socket both socket endpoints on a SocketTimeoutException

Reviewed-by: xuelei
This commit is contained in:
Alexey Bakhtin 2020-03-11 19:14:08 +03:00 committed by Andrew Brygin
parent 6275aee690
commit 14e37ba3df
5 changed files with 128 additions and 113 deletions

View File

@ -436,6 +436,8 @@ public final class SSLSocketImpl
if (!conContext.isNegotiated) { if (!conContext.isNegotiated) {
readHandshakeRecord(); readHandshakeRecord();
} }
} catch (InterruptedIOException iioe) {
handleException(iioe);
} catch (IOException ioe) { } catch (IOException ioe) {
throw conContext.fatal(Alert.HANDSHAKE_FAILURE, throw conContext.fatal(Alert.HANDSHAKE_FAILURE,
"Couldn't kickstart handshaking", ioe); "Couldn't kickstart handshaking", ioe);
@ -1374,12 +1376,11 @@ public final class SSLSocketImpl
} }
} catch (SSLException ssle) { } catch (SSLException ssle) {
throw ssle; throw ssle;
} catch (InterruptedIOException iioe) {
// don't change exception in case of timeouts or interrupts
throw iioe;
} catch (IOException ioe) { } catch (IOException ioe) {
if (!(ioe instanceof SSLException)) { throw new SSLException("readHandshakeRecord", ioe);
throw new SSLException("readHandshakeRecord", ioe);
} else {
throw ioe;
}
} }
} }
@ -1440,6 +1441,9 @@ public final class SSLSocketImpl
} }
} catch (SSLException ssle) { } catch (SSLException ssle) {
throw ssle; throw ssle;
} catch (InterruptedIOException iioe) {
// don't change exception in case of timeouts or interrupts
throw iioe;
} catch (IOException ioe) { } catch (IOException ioe) {
if (!(ioe instanceof SSLException)) { if (!(ioe instanceof SSLException)) {
throw new SSLException("readApplicationRecord", ioe); throw new SSLException("readApplicationRecord", ioe);

View File

@ -1,5 +1,6 @@
/* /*
* Copyright (c) 1996, 2019, Oracle and/or its affiliates. All rights reserved. * Copyright (c) 1996, 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2020, Azul Systems, Inc. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
* *
* This code is free software; you can redistribute it and/or modify it * This code is free software; you can redistribute it and/or modify it
@ -26,6 +27,7 @@
package sun.security.ssl; package sun.security.ssl;
import java.io.EOFException; import java.io.EOFException;
import java.io.InterruptedIOException;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
@ -47,37 +49,31 @@ import sun.security.ssl.SSLCipher.SSLReadCipher;
final class SSLSocketInputRecord extends InputRecord implements SSLRecord { final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
private InputStream is = null; private InputStream is = null;
private OutputStream os = null; private OutputStream os = null;
private final byte[] temporary = new byte[1024]; private final byte[] header = new byte[headerSize];
private int headerOff = 0;
// Cache for incomplete record body.
private ByteBuffer recordBody = ByteBuffer.allocate(1024);
private boolean formatVerified = false; // SSLv2 ruled out? private boolean formatVerified = false; // SSLv2 ruled out?
// Cache for incomplete handshake messages. // Cache for incomplete handshake messages.
private ByteBuffer handshakeBuffer = null; private ByteBuffer handshakeBuffer = null;
private boolean hasHeader = false; // Had read the record header
SSLSocketInputRecord(HandshakeHash handshakeHash) { SSLSocketInputRecord(HandshakeHash handshakeHash) {
super(handshakeHash, SSLReadCipher.nullTlsReadCipher()); super(handshakeHash, SSLReadCipher.nullTlsReadCipher());
} }
@Override @Override
int bytesInCompletePacket() throws IOException { int bytesInCompletePacket() throws IOException {
if (!hasHeader) { // read header
// read exactly one record try {
try { readHeader();
int really = read(is, temporary, 0, headerSize); } catch (EOFException eofe) {
if (really < 0) { // The caller will handle EOF.
// EOF: peer shut down incorrectly return -1;
return -1;
}
} catch (EOFException eofe) {
// The caller will handle EOF.
return -1;
}
hasHeader = true;
} }
byte byteZero = temporary[0]; byte byteZero = header[0];
int len = 0; int len = 0;
/* /*
@ -93,9 +89,9 @@ final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
* Last sanity check that it's not a wild record * Last sanity check that it's not a wild record
*/ */
if (!ProtocolVersion.isNegotiable( if (!ProtocolVersion.isNegotiable(
temporary[1], temporary[2], false, false)) { header[1], header[2], false, false)) {
throw new SSLException("Unrecognized record version " + throw new SSLException("Unrecognized record version " +
ProtocolVersion.nameOf(temporary[1], temporary[2]) + ProtocolVersion.nameOf(header[1], header[2]) +
" , plaintext connection?"); " , plaintext connection?");
} }
@ -109,8 +105,8 @@ final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
/* /*
* One of the SSLv3/TLS message types. * One of the SSLv3/TLS message types.
*/ */
len = ((temporary[3] & 0xFF) << 8) + len = ((header[3] & 0xFF) << 8) +
(temporary[4] & 0xFF) + headerSize; (header[4] & 0xFF) + headerSize;
} else { } else {
/* /*
* Must be SSLv2 or something unknown. * Must be SSLv2 or something unknown.
@ -121,11 +117,11 @@ final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
*/ */
boolean isShort = ((byteZero & 0x80) != 0); boolean isShort = ((byteZero & 0x80) != 0);
if (isShort && ((temporary[2] == 1) || (temporary[2] == 4))) { if (isShort && ((header[2] == 1) || (header[2] == 4))) {
if (!ProtocolVersion.isNegotiable( if (!ProtocolVersion.isNegotiable(
temporary[3], temporary[4], false, false)) { header[3], header[4], false, false)) {
throw new SSLException("Unrecognized record version " + throw new SSLException("Unrecognized record version " +
ProtocolVersion.nameOf(temporary[3], temporary[4]) + ProtocolVersion.nameOf(header[3], header[4]) +
" , plaintext connection?"); " , plaintext connection?");
} }
@ -138,9 +134,9 @@ final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
// //
// int mask = (isShort ? 0x7F : 0x3F); // int mask = (isShort ? 0x7F : 0x3F);
// len = ((byteZero & mask) << 8) + // len = ((byteZero & mask) << 8) +
// (temporary[1] & 0xFF) + (isShort ? 2 : 3); // (header[1] & 0xFF) + (isShort ? 2 : 3);
// //
len = ((byteZero & 0x7F) << 8) + (temporary[1] & 0xFF) + 2; len = ((byteZero & 0x7F) << 8) + (header[1] & 0xFF) + 2;
} else { } else {
// Gobblygook! // Gobblygook!
throw new SSLException( throw new SSLException(
@ -160,34 +156,41 @@ final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
return null; return null;
} }
if (!hasHeader) { // read header
// read exactly one record readHeader();
int really = read(is, temporary, 0, headerSize);
if (really < 0) { Plaintext[] plaintext = null;
throw new EOFException("SSL peer shut down incorrectly"); boolean cleanInBuffer = true;
try {
if (!formatVerified) {
formatVerified = true;
/*
* The first record must either be a handshake record or an
* alert message. If it's not, it is either invalid or an
* SSLv2 message.
*/
if ((header[0] != ContentType.HANDSHAKE.id) &&
(header[0] != ContentType.ALERT.id)) {
plaintext = handleUnknownRecord();
}
} }
hasHeader = true;
}
Plaintext plaintext = null; // The record header should has consumed.
if (!formatVerified) { if (plaintext == null) {
formatVerified = true; plaintext = decodeInputRecord();
}
/* } catch(InterruptedIOException e) {
* The first record must either be a handshake record or an // do not clean header and recordBody in case of Socket Timeout
* alert message. If it's not, it is either invalid or an cleanInBuffer = false;
* SSLv2 message. throw e;
*/ } finally {
if ((temporary[0] != ContentType.HANDSHAKE.id) && if (cleanInBuffer) {
(temporary[0] != ContentType.ALERT.id)) { headerOff = 0;
hasHeader = false; recordBody.clear();
return handleUnknownRecord(temporary);
} }
} }
return plaintext;
// The record header should has consumed.
hasHeader = false;
return decodeInputRecord(temporary);
} }
@Override @Override
@ -200,9 +203,7 @@ final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
this.os = outputStream; this.os = outputStream;
} }
// Note that destination may be null private Plaintext[] decodeInputRecord() throws IOException, BadPaddingException {
private Plaintext[] decodeInputRecord(
byte[] header) throws IOException, BadPaddingException {
byte contentType = header[0]; // pos: 0 byte contentType = header[0]; // pos: 0
byte majorVersion = header[1]; // pos: 1 byte majorVersion = header[1]; // pos: 1
byte minorVersion = header[2]; // pos: 2 byte minorVersion = header[2]; // pos: 2
@ -227,30 +228,27 @@ final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
} }
// //
// Read a complete record. // Read a complete record and store in the recordBody
// recordBody is used to cache incoming record and restore in case of
// read operation timedout
// //
ByteBuffer destination = ByteBuffer.allocate(headerSize + contentLen); if (recordBody.position() == 0) {
int dstPos = destination.position(); if (recordBody.capacity() < contentLen) {
destination.put(temporary, 0, headerSize); recordBody = ByteBuffer.allocate(contentLen);
while (contentLen > 0) {
int howmuch = Math.min(temporary.length, contentLen);
int really = read(is, temporary, 0, howmuch);
if (really < 0) {
throw new EOFException("SSL peer shut down incorrectly");
} }
recordBody.limit(contentLen);
destination.put(temporary, 0, howmuch); } else {
contentLen -= howmuch; contentLen = recordBody.remaining();
} }
destination.flip(); readFully(contentLen);
destination.position(dstPos + headerSize); recordBody.flip();
if (SSLLogger.isOn && SSLLogger.isOn("record")) { if (SSLLogger.isOn && SSLLogger.isOn("record")) {
SSLLogger.fine( SSLLogger.fine(
"READ: " + "READ: " +
ProtocolVersion.nameOf(majorVersion, minorVersion) + ProtocolVersion.nameOf(majorVersion, minorVersion) +
" " + ContentType.nameOf(contentType) + ", length = " + " " + ContentType.nameOf(contentType) + ", length = " +
destination.remaining()); recordBody.remaining());
} }
// //
@ -259,7 +257,7 @@ final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
ByteBuffer fragment; ByteBuffer fragment;
try { try {
Plaintext plaintext = Plaintext plaintext =
readCipher.decrypt(contentType, destination, null); readCipher.decrypt(contentType, recordBody, null);
fragment = plaintext.fragment; fragment = plaintext.fragment;
contentType = plaintext.contentType; contentType = plaintext.contentType;
} catch (BadPaddingException bpe) { } catch (BadPaddingException bpe) {
@ -361,8 +359,7 @@ final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
}; };
} }
private Plaintext[] handleUnknownRecord( private Plaintext[] handleUnknownRecord() throws IOException, BadPaddingException {
byte[] header) throws IOException, BadPaddingException {
byte firstByte = header[0]; byte firstByte = header[0];
byte thirdByte = header[2]; byte thirdByte = header[2];
@ -404,32 +401,29 @@ final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
} }
int msgLen = ((header[0] & 0x7F) << 8) | (header[1] & 0xFF); int msgLen = ((header[0] & 0x7F) << 8) | (header[1] & 0xFF);
if (recordBody.position() == 0) {
ByteBuffer destination = ByteBuffer.allocate(headerSize + msgLen); if (recordBody.capacity() < (headerSize + msgLen)) {
destination.put(temporary, 0, headerSize); recordBody = ByteBuffer.allocate(headerSize + msgLen);
msgLen -= 3; // had read 3 bytes of content as header
while (msgLen > 0) {
int howmuch = Math.min(temporary.length, msgLen);
int really = read(is, temporary, 0, howmuch);
if (really < 0) {
throw new EOFException("SSL peer shut down incorrectly");
} }
recordBody.limit(headerSize + msgLen);
destination.put(temporary, 0, howmuch); recordBody.put(header, 0, headerSize);
msgLen -= howmuch; } else {
msgLen = recordBody.remaining();
} }
destination.flip(); msgLen -= 3; // had read 3 bytes of content as header
readFully(msgLen);
recordBody.flip();
/* /*
* If we can map this into a V3 ClientHello, read and * If we can map this into a V3 ClientHello, read and
* hash the rest of the V2 handshake, turn it into a * hash the rest of the V2 handshake, turn it into a
* V3 ClientHello message, and pass it up. * V3 ClientHello message, and pass it up.
*/ */
destination.position(2); // exclude the header recordBody.position(2); // exclude the header
handshakeHash.receive(destination); handshakeHash.receive(recordBody);
destination.position(0); recordBody.position(0);
ByteBuffer converted = convertToClientHello(destination); ByteBuffer converted = convertToClientHello(recordBody);
if (SSLLogger.isOn && SSLLogger.isOn("packet")) { if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
SSLLogger.fine( SSLLogger.fine(
@ -449,28 +443,42 @@ final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
} }
} }
// Read the exact bytes of data, otherwise, return -1. // Read the exact bytes of data, otherwise, throw IOException.
private static int read(InputStream is, private int readFully(int len) throws IOException {
byte[] buffer, int offset, int len) throws IOException { int end = len + recordBody.position();
int n = 0; int off = recordBody.position();
while (n < len) { try {
int readLen = is.read(buffer, offset + n, len - n); while (off < end) {
if (readLen < 0) { off += read(is, recordBody.array(), off, end - off);
if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
SSLLogger.fine("Raw read: EOF");
}
return -1;
} }
} finally {
recordBody.position(off);
}
return len;
}
// Read SSE record header, otherwise, throw IOException.
private int readHeader() throws IOException {
while (headerOff < headerSize) {
headerOff += read(is, header, headerOff, headerSize - headerOff);
}
return headerSize;
}
private static int read(InputStream is, byte[] buf, int off, int len) throws IOException {
int readLen = is.read(buf, off, len);
if (readLen < 0) {
if (SSLLogger.isOn && SSLLogger.isOn("packet")) { if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
ByteBuffer bb = ByteBuffer.wrap(buffer, offset + n, readLen); SSLLogger.fine("Raw read: EOF");
SSLLogger.fine("Raw read", bb);
} }
throw new EOFException("SSL peer shut down incorrectly");
n += readLen;
} }
return n; if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
ByteBuffer bb = ByteBuffer.wrap(buf, off, readLen);
SSLLogger.fine("Raw read", bb);
}
return readLen;
} }
// Try to use up the input stream without impact the performance too much. // Try to use up the input stream without impact the performance too much.

View File

@ -27,6 +27,7 @@ package sun.security.ssl;
import java.io.EOFException; import java.io.EOFException;
import java.io.IOException; import java.io.IOException;
import java.io.InterruptedIOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import javax.crypto.AEADBadTagException; import javax.crypto.AEADBadTagException;
import javax.crypto.BadPaddingException; import javax.crypto.BadPaddingException;
@ -136,6 +137,9 @@ interface SSLTransport {
} catch (EOFException eofe) { } catch (EOFException eofe) {
// rethrow EOFException, the call will handle it if neede. // rethrow EOFException, the call will handle it if neede.
throw eofe; throw eofe;
} catch (InterruptedIOException iioe) {
// don't close the Socket in case of timeouts or interrupts.
throw iioe;
} catch (IOException ioe) { } catch (IOException ioe) {
throw context.fatal(Alert.UNEXPECTED_MESSAGE, ioe); throw context.fatal(Alert.UNEXPECTED_MESSAGE, ioe);
} }

View File

@ -26,8 +26,7 @@
/* /*
* @test * @test
* @bug 4836493 * @bug 4836493 8239798
* @ignore need further evaluation
* @summary Socket timeouts for SSLSockets causes data corruption. * @summary Socket timeouts for SSLSockets causes data corruption.
* @run main/othervm ClientTimeout * @run main/othervm ClientTimeout
*/ */

View File

@ -36,7 +36,7 @@
import javax.net.ssl.*; import javax.net.ssl.*;
import java.io.*; import java.io.*;
import java.net.InetAddress; import java.net.*;
public class SSLExceptionForIOIssue implements SSLContextTemplate { public class SSLExceptionForIOIssue implements SSLContextTemplate {
@ -139,7 +139,7 @@ public class SSLExceptionForIOIssue implements SSLContextTemplate {
} catch (SSLProtocolException | SSLHandshakeException sslhe) { } catch (SSLProtocolException | SSLHandshakeException sslhe) {
clientException = sslhe; clientException = sslhe;
System.err.println("unexpected client exception: " + sslhe); System.err.println("unexpected client exception: " + sslhe);
} catch (SSLException ssle) { } catch (SSLException | SocketTimeoutException ssle) {
// the expected exception, ignore it // the expected exception, ignore it
System.err.println("expected client exception: " + ssle); System.err.println("expected client exception: " + ssle);
} catch (Exception e) { } catch (Exception e) {