8217429: WebSocket over authenticating proxy fails to send Upgrade headers

Reviewed-by: dfuchs, prappo
This commit is contained in:
Chris Hegarty 2019-01-28 13:51:16 +00:00
parent ef07b1b314
commit 46f4ab603b
8 changed files with 724 additions and 50 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2019, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -43,6 +43,7 @@ import java.net.http.HttpHeaders;
import java.net.http.HttpRequest;
import jdk.internal.net.http.common.HttpHeadersBuilder;
import jdk.internal.net.http.common.Utils;
import jdk.internal.net.http.websocket.OpeningHandshake;
import jdk.internal.net.http.websocket.WebSocketRequest;
import static jdk.internal.net.http.common.Utils.ALLOWED_HEADERS;
@ -157,7 +158,11 @@ public class HttpRequestImpl extends HttpRequest implements WebSocketRequest {
/** Returns a new instance suitable for authentication. */
public static HttpRequestImpl newInstanceForAuthentication(HttpRequestImpl other) {
return new HttpRequestImpl(other.uri(), other.method(), other);
HttpRequestImpl request = new HttpRequestImpl(other.uri(), other.method(), other);
if (request.isWebSocket()) {
Utils.setWebSocketUpgradeHeaders(request);
}
return request;
}
/**

View File

@ -263,6 +263,15 @@ public final class Utils {
: ! PROXY_AUTH_DISABLED_SCHEMES.isEmpty();
}
// WebSocket connection Upgrade headers
private static final String HEADER_CONNECTION = "Connection";
private static final String HEADER_UPGRADE = "Upgrade";
public static final void setWebSocketUpgradeHeaders(HttpRequestImpl request) {
request.setSystemHeader(HEADER_UPGRADE, "websocket");
request.setSystemHeader(HEADER_CONNECTION, "Upgrade");
}
public static IllegalArgumentException newIAE(String message, Object... args) {
return new IllegalArgumentException(format(message, args));
}

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2019, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -143,8 +143,7 @@ public class OpeningHandshake {
requestBuilder.version(Version.HTTP_1_1).GET();
request = requestBuilder.buildForWebSocket();
request.isWebSocket(true);
request.setSystemHeader(HEADER_UPGRADE, "websocket");
request.setSystemHeader(HEADER_CONNECTION, "Upgrade");
Utils.setWebSocketUpgradeHeaders(request);
request.setProxy(proxy);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2019, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -25,6 +25,9 @@ import java.net.*;
import java.io.*;
import java.util.*;
import java.security.*;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;
import static java.util.stream.Collectors.toList;
/**
* A minimal proxy server that supports CONNECT tunneling. It does not do
@ -37,6 +40,18 @@ public class ProxyServer extends Thread implements Closeable {
ServerSocket listener;
int port;
volatile boolean debug;
private final Credentials credentials; // may be null
private static class Credentials {
private final String name;
private final String password;
private Credentials(String name, String password) {
this.name = name;
this.password = password;
}
public String name() { return name; }
public String password() { return password; }
}
/**
* Create proxy on port (zero means don't care). Call getPort()
@ -46,19 +61,42 @@ public class ProxyServer extends Thread implements Closeable {
this(port, false);
}
public ProxyServer(Integer port, Boolean debug) throws IOException {
public ProxyServer(Integer port,
Boolean debug,
String username,
String password)
throws IOException
{
this(port, debug, new Credentials(username, password));
}
public ProxyServer(Integer port,
Boolean debug)
throws IOException
{
this(port, debug, null);
}
public ProxyServer(Integer port,
Boolean debug,
Credentials credentials)
throws IOException
{
this.debug = debug;
listener = new ServerSocket();
listener.setReuseAddress(false);
listener.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), port));
this.port = listener.getLocalPort();
this.credentials = credentials;
setName("ProxyListener");
setDaemon(true);
connections = new LinkedList<>();
start();
}
public ProxyServer(String s) { }
public ProxyServer(String s) {
credentials = null;
}
/**
* Returns the port number this proxy is listening on
@ -194,16 +232,69 @@ public class ProxyServer extends Thread implements Closeable {
return -1;
}
// Checks credentials in the request against those allowable by the proxy.
private boolean authorized(Credentials credentials,
List<String> requestHeaders) {
List<String> authorization = requestHeaders.stream()
.filter(n -> n.toLowerCase(Locale.US).startsWith("proxy-authorization"))
.collect(toList());
if (authorization.isEmpty())
return false;
if (authorization.size() != 1) {
throw new IllegalStateException("Authorization unexpected count:" + authorization);
}
String value = authorization.get(0).substring("proxy-authorization".length()).trim();
if (!value.startsWith(":"))
throw new IllegalStateException("Authorization malformed: " + value);
value = value.substring(1).trim();
if (!value.startsWith("Basic "))
throw new IllegalStateException("Authorization not Basic: " + value);
value = value.substring("Basic ".length());
String values = new String(Base64.getDecoder().decode(value), UTF_8);
int sep = values.indexOf(':');
if (sep < 1) {
throw new IllegalStateException("Authorization no colon: " + values);
}
String name = values.substring(0, sep);
String password = values.substring(sep + 1);
if (name.equals(credentials.name()) && password.equals(credentials.password()))
return true;
return false;
}
public void init() {
try {
byte[] buf = readHeaders(clientIn);
int p = findCRLF(buf);
if (p == -1) {
close();
return;
byte[] buf;
while (true) {
buf = readHeaders(clientIn);
if (findCRLF(buf) == -1) {
close();
return;
}
List<String> headers = asList(new String(buf, UTF_8).split("\r\n"));
// check authorization credentials, if required by the server
if (credentials != null && !authorized(credentials, headers)) {
String resp = "HTTP/1.1 407 Proxy Authentication Required\r\n" +
"Content-Length: 0\r\n" +
"Proxy-Authenticate: Basic realm=\"proxy realm\"\r\n\r\n";
clientOut.write(resp.getBytes(UTF_8));
} else {
break;
}
}
int p = findCRLF(buf);
String cmd = new String(buf, 0, p, "US-ASCII");
String[] params = cmd.split(" ");
if (params[0].equals("CONNECT")) {
doTunnel(params[1]);
} else {

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2016, 2019, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -46,13 +46,14 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static java.lang.String.format;
import static java.lang.System.err;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;
@ -92,12 +93,32 @@ public class DummyWebSocketServer implements Closeable {
private ByteBuffer read = ByteBuffer.allocate(16384);
private final CountDownLatch readReady = new CountDownLatch(1);
public DummyWebSocketServer() {
this(defaultMapping());
private static class Credentials {
private final String name;
private final String password;
private Credentials(String name, String password) {
this.name = name;
this.password = password;
}
public String name() { return name; }
public String password() { return password; }
}
public DummyWebSocketServer(Function<List<String>, List<String>> mapping) {
public DummyWebSocketServer() {
this(defaultMapping(), null, null);
}
public DummyWebSocketServer(String username, String password) {
this(defaultMapping(), username, password);
}
public DummyWebSocketServer(BiFunction<List<String>,Credentials,List<String>> mapping,
String username,
String password) {
requireNonNull(mapping);
Credentials credentials = username != null ?
new Credentials(username, password) : null;
thread = new Thread(() -> {
try {
while (!Thread.currentThread().isInterrupted()) {
@ -107,14 +128,23 @@ public class DummyWebSocketServer implements Closeable {
try {
channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
channel.configureBlocking(true);
StringBuilder request = new StringBuilder();
if (!readRequest(channel, request)) {
throw new IOException("Bad request:" + request);
while (true) {
StringBuilder request = new StringBuilder();
if (!readRequest(channel, request)) {
throw new IOException("Bad request:[" + request + "]");
}
List<String> strings = asList(request.toString().split("\r\n"));
List<String> response = mapping.apply(strings, credentials);
writeResponse(channel, response);
if (response.get(0).startsWith("HTTP/1.1 401")) {
err.println("Sent 401 Authentication response " + channel);
continue;
} else {
serve(channel);
break;
}
}
List<String> strings = asList(request.toString().split("\r\n"));
List<String> response = mapping.apply(strings);
writeResponse(channel, response);
serve(channel);
} catch (IOException e) {
err.println("Error in connection: " + channel + ", " + e);
} finally {
@ -125,7 +155,7 @@ public class DummyWebSocketServer implements Closeable {
}
} catch (ClosedByInterruptException ignored) {
} catch (Exception e) {
err.println(e);
e.printStackTrace(err);
} finally {
close(ssc);
err.println("Stopped at: " + getURI());
@ -256,8 +286,8 @@ public class DummyWebSocketServer implements Closeable {
}
}
private static Function<List<String>, List<String>> defaultMapping() {
return request -> {
private static BiFunction<List<String>,Credentials,List<String>> defaultMapping() {
return (request, credentials) -> {
List<String> response = new LinkedList<>();
Iterator<String> iterator = request.iterator();
if (!iterator.hasNext()) {
@ -309,14 +339,57 @@ public class DummyWebSocketServer implements Closeable {
sha1.update(x.getBytes(ISO_8859_1));
String v = Base64.getEncoder().encodeToString(sha1.digest());
response.add("Sec-WebSocket-Accept: " + v);
// check authorization credentials, if required by the server
if (credentials != null && !authorized(credentials, requestHeaders)) {
response.clear();
response.add("HTTP/1.1 401 Unauthorized");
response.add("Content-Length: 0");
response.add("WWW-Authenticate: Basic realm=\"dummy server realm\"");
}
return response;
};
}
// Checks credentials in the request against those allowable by the server.
private static boolean authorized(Credentials credentials,
Map<String,List<String>> requestHeaders) {
List<String> authorization = requestHeaders.get("Authorization");
if (authorization == null)
return false;
if (authorization.size() != 1) {
throw new IllegalStateException("Authorization unexpected count:" + authorization);
}
String header = authorization.get(0);
if (!header.startsWith("Basic "))
throw new IllegalStateException("Authorization not Basic: " + header);
header = header.substring("Basic ".length());
String values = new String(Base64.getDecoder().decode(header), UTF_8);
int sep = values.indexOf(':');
if (sep < 1) {
throw new IllegalStateException("Authorization not colon: " + values);
}
String name = values.substring(0, sep);
String password = values.substring(sep + 1);
if (name.equals(credentials.name()) && password.equals(credentials.password()))
return true;
return false;
}
protected static String expectHeader(Map<String, List<String>> headers,
String name,
String value) {
List<String> v = headers.get(name);
if (v == null) {
throw new IllegalStateException(
format("Expected '%s' header, not present in %s",
name, headers));
}
if (!v.contains(value)) {
throw new IllegalStateException(
format("Expected '%s: %s', actual: '%s: %s'",

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -79,16 +79,32 @@ public class Support {
}
public static DummyWebSocketServer serverWithCannedData(int... data) {
return serverWithCannedDataAndAuthentication(null, null, data);
}
public static DummyWebSocketServer serverWithCannedDataAndAuthentication(
String username,
String password,
int... data)
{
byte[] copy = new byte[data.length];
for (int i = 0; i < data.length; i++) {
copy[i] = (byte) data[i];
}
return serverWithCannedData(copy);
return serverWithCannedDataAndAuthentication(username, password, copy);
}
public static DummyWebSocketServer serverWithCannedData(byte... data) {
return serverWithCannedDataAndAuthentication(null, null, data);
}
public static DummyWebSocketServer serverWithCannedDataAndAuthentication(
String username,
String password,
byte... data)
{
byte[] copy = Arrays.copyOf(data, data.length);
return new DummyWebSocketServer() {
return new DummyWebSocketServer(username, password) {
@Override
protected void write(SocketChannel ch) throws IOException {
int off = 0; int n = 1; // 1 byte at a time

View File

@ -0,0 +1,309 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @bug 8217429
* @summary WebSocket proxy tunneling tests
* @compile DummyWebSocketServer.java ../ProxyServer.java
* @run testng/othervm
* -Djdk.http.auth.tunneling.disabledSchemes=
* WebSocketProxyTest
*/
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.Authenticator;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.PasswordAuthentication;
import java.net.ProxySelector;
import java.net.http.HttpResponse;
import java.net.http.WebSocket;
import java.net.http.WebSocketHandshakeException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import static java.net.http.HttpClient.newBuilder;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.testng.Assert.assertEquals;
import static org.testng.FileAssert.fail;
public class WebSocketProxyTest {
// Used to verify a proxy/websocket server requiring Authentication
private static final String USERNAME = "wally";
private static final String PASSWORD = "xyz987";
static class WSAuthenticator extends Authenticator {
@Override
protected PasswordAuthentication getPasswordAuthentication() {
return new PasswordAuthentication(USERNAME, PASSWORD.toCharArray());
}
}
static final Function<int[],DummyWebSocketServer> SERVER_WITH_CANNED_DATA =
new Function<>() {
@Override public DummyWebSocketServer apply(int[] data) {
return Support.serverWithCannedData(data); }
@Override public String toString() { return "SERVER_WITH_CANNED_DATA"; }
};
static final Function<int[],DummyWebSocketServer> AUTH_SERVER_WITH_CANNED_DATA =
new Function<>() {
@Override public DummyWebSocketServer apply(int[] data) {
return Support.serverWithCannedDataAndAuthentication(USERNAME, PASSWORD, data); }
@Override public String toString() { return "AUTH_SERVER_WITH_CANNED_DATA"; }
};
static final Supplier<ProxyServer> TUNNELING_PROXY_SERVER =
new Supplier<>() {
@Override public ProxyServer get() {
try { return new ProxyServer(0, true);}
catch(IOException e) { throw new UncheckedIOException(e); } }
@Override public String toString() { return "TUNNELING_PROXY_SERVER"; }
};
static final Supplier<ProxyServer> AUTH_TUNNELING_PROXY_SERVER =
new Supplier<>() {
@Override public ProxyServer get() {
try { return new ProxyServer(0, true, USERNAME, PASSWORD);}
catch(IOException e) { throw new UncheckedIOException(e); } }
@Override public String toString() { return "AUTH_TUNNELING_PROXY_SERVER"; }
};
@DataProvider(name = "servers")
public Object[][] servers() {
return new Object[][] {
{ SERVER_WITH_CANNED_DATA, TUNNELING_PROXY_SERVER },
{ SERVER_WITH_CANNED_DATA, AUTH_TUNNELING_PROXY_SERVER },
{ AUTH_SERVER_WITH_CANNED_DATA, TUNNELING_PROXY_SERVER },
};
}
@Test(dataProvider = "servers")
public void simpleAggregatingBinaryMessages
(Function<int[],DummyWebSocketServer> serverSupplier,
Supplier<ProxyServer> proxyServerSupplier)
throws IOException
{
List<byte[]> expected = List.of("hello", "chegar")
.stream()
.map(s -> s.getBytes(StandardCharsets.US_ASCII))
.collect(Collectors.toList());
int[] binary = new int[]{
0x82, 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, // hello
0x82, 0x06, 0x63, 0x68, 0x65, 0x67, 0x61, 0x72, // chegar
0x88, 0x00 // <CLOSE>
};
CompletableFuture<List<byte[]>> actual = new CompletableFuture<>();
try (var proxyServer = proxyServerSupplier.get();
var server = serverSupplier.apply(binary)) {
InetSocketAddress proxyAddress = new InetSocketAddress(
InetAddress.getLoopbackAddress(), proxyServer.getPort());
server.open();
WebSocket.Listener listener = new WebSocket.Listener() {
List<byte[]> collectedBytes = new ArrayList<>();
ByteBuffer buffer = ByteBuffer.allocate(1024);
@Override
public CompletionStage<?> onBinary(WebSocket webSocket,
ByteBuffer message,
boolean last) {
System.out.printf("onBinary(%s, %s)%n", message, last);
webSocket.request(1);
append(message);
if (last) {
buffer.flip();
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);
buffer.clear();
processWholeBinary(bytes);
}
return null;
}
private void append(ByteBuffer message) {
if (buffer.remaining() < message.remaining()) {
assert message.remaining() > 0;
int cap = (buffer.capacity() + message.remaining()) * 2;
ByteBuffer b = ByteBuffer.allocate(cap);
b.put(buffer.flip());
buffer = b;
}
buffer.put(message);
}
private void processWholeBinary(byte[] bytes) {
String stringBytes = new String(bytes, UTF_8);
System.out.println("processWholeBinary: " + stringBytes);
collectedBytes.add(bytes);
}
@Override
public CompletionStage<?> onClose(WebSocket webSocket,
int statusCode,
String reason) {
actual.complete(collectedBytes);
return null;
}
@Override
public void onError(WebSocket webSocket, Throwable error) {
actual.completeExceptionally(error);
}
};
var webSocket = newBuilder()
.proxy(ProxySelector.of(proxyAddress))
.authenticator(new WSAuthenticator())
.build().newWebSocketBuilder()
.buildAsync(server.getURI(), listener)
.join();
List<byte[]> a = actual.join();
assertEquals(a, expected);
}
}
// -- authentication specific tests
/*
* Ensures authentication succeeds when an Authenticator set on client builder.
*/
@Test
public void clientAuthenticate() throws IOException {
try (var proxyServer = AUTH_TUNNELING_PROXY_SERVER.get();
var server = new DummyWebSocketServer()){
server.open();
InetSocketAddress proxyAddress = new InetSocketAddress(
InetAddress.getLoopbackAddress(), proxyServer.getPort());
var webSocket = newBuilder()
.proxy(ProxySelector.of(proxyAddress))
.authenticator(new WSAuthenticator())
.build()
.newWebSocketBuilder()
.buildAsync(server.getURI(), new WebSocket.Listener() { })
.join();
}
}
/*
* Ensures authentication succeeds when an `Authorization` header is explicitly set.
*/
@Test
public void explicitAuthenticate() throws IOException {
try (var proxyServer = AUTH_TUNNELING_PROXY_SERVER.get();
var server = new DummyWebSocketServer()) {
server.open();
InetSocketAddress proxyAddress = new InetSocketAddress(
InetAddress.getLoopbackAddress(), proxyServer.getPort());
String hv = "Basic " + Base64.getEncoder().encodeToString(
(USERNAME + ":" + PASSWORD).getBytes(UTF_8));
var webSocket = newBuilder()
.proxy(ProxySelector.of(proxyAddress)).build()
.newWebSocketBuilder()
.header("Proxy-Authorization", hv)
.buildAsync(server.getURI(), new WebSocket.Listener() { })
.join();
}
}
/*
* Ensures authentication does not succeed when no authenticator is present.
*/
@Test
public void failNoAuthenticator() throws IOException {
try (var proxyServer = AUTH_TUNNELING_PROXY_SERVER.get();
var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
server.open();
InetSocketAddress proxyAddress = new InetSocketAddress(
InetAddress.getLoopbackAddress(), proxyServer.getPort());
CompletableFuture<WebSocket> cf = newBuilder()
.proxy(ProxySelector.of(proxyAddress)).build()
.newWebSocketBuilder()
.buildAsync(server.getURI(), new WebSocket.Listener() { });
try {
var webSocket = cf.join();
fail("Expected exception not thrown");
} catch (CompletionException expected) {
WebSocketHandshakeException e = (WebSocketHandshakeException)expected.getCause();
HttpResponse<?> response = e.getResponse();
assertEquals(response.statusCode(), 407);
}
}
}
/*
* Ensures authentication does not succeed when the authenticator presents
* unauthorized credentials.
*/
@Test
public void failBadCredentials() throws IOException {
try (var proxyServer = AUTH_TUNNELING_PROXY_SERVER.get();
var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
server.open();
InetSocketAddress proxyAddress = new InetSocketAddress(
InetAddress.getLoopbackAddress(), proxyServer.getPort());
Authenticator authenticator = new Authenticator() {
@Override protected PasswordAuthentication getPasswordAuthentication() {
return new PasswordAuthentication("BAD"+USERNAME, "".toCharArray());
}
};
CompletableFuture<WebSocket> cf = newBuilder()
.proxy(ProxySelector.of(proxyAddress))
.authenticator(authenticator)
.build()
.newWebSocketBuilder()
.buildAsync(server.getURI(), new WebSocket.Listener() { });
try {
var webSocket = cf.join();
fail("Expected exception not thrown");
} catch (CompletionException expected) {
System.out.println("caught expected exception:" + expected);
}
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -23,6 +23,7 @@
/*
* @test
* @bug 8217429
* @build DummyWebSocketServer
* @run testng/othervm
* WebSocketTest
@ -33,23 +34,32 @@ import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.io.IOException;
import java.net.Authenticator;
import java.net.PasswordAuthentication;
import java.net.http.HttpResponse;
import java.net.http.WebSocket;
import java.net.http.WebSocketHandshakeException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import static java.net.http.HttpClient.Builder.NO_PROXY;
import static java.net.http.HttpClient.newBuilder;
import static java.net.http.WebSocket.NORMAL_CLOSURE;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertThrows;
import static org.testng.Assert.fail;
public class WebSocketTest {
@ -68,8 +78,11 @@ public class WebSocketTest {
@AfterTest
public void cleanup() {
server.close();
webSocket.abort();
System.out.println("AFTER TEST");
if (server != null)
server.close();
if (webSocket != null)
webSocket.abort();
}
@Test
@ -134,6 +147,8 @@ public class WebSocketTest {
assertThrows(IAE, () -> webSocket.request(Long.MIN_VALUE));
assertThrows(IAE, () -> webSocket.request(-1));
assertThrows(IAE, () -> webSocket.request(0));
server.close();
}
@Test
@ -149,6 +164,7 @@ public class WebSocketTest {
// Pings & Pongs are fine
webSocket.sendPing(ByteBuffer.allocate(125)).join();
webSocket.sendPong(ByteBuffer.allocate(125)).join();
server.close();
}
@Test
@ -165,6 +181,7 @@ public class WebSocketTest {
// Pings & Pongs are fine
webSocket.sendPing(ByteBuffer.allocate(125)).join();
webSocket.sendPong(ByteBuffer.allocate(125)).join();
server.close();
}
@Test
@ -198,6 +215,8 @@ public class WebSocketTest {
assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(124)));
assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(1)));
assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(0)));
server.close();
}
@DataProvider(name = "sequence")
@ -318,6 +337,8 @@ public class WebSocketTest {
listener.invocations();
violation.complete(null); // won't affect if completed exceptionally
violation.join();
server.close();
}
@Test
@ -372,10 +393,48 @@ public class WebSocketTest {
assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(124)));
assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(1)));
assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(0)));
server.close();
}
@Test
public void simpleAggregatingBinaryMessages() throws IOException {
// Used to verify a server requiring Authentication
private static final String USERNAME = "chegar";
private static final String PASSWORD = "a1b2c3";
static class WSAuthenticator extends Authenticator {
@Override
protected PasswordAuthentication getPasswordAuthentication() {
return new PasswordAuthentication(USERNAME, PASSWORD.toCharArray());
}
}
static final Function<int[],DummyWebSocketServer> SERVER_WITH_CANNED_DATA =
new Function<>() {
@Override public DummyWebSocketServer apply(int[] data) {
return Support.serverWithCannedData(data); }
@Override public String toString() { return "SERVER_WITH_CANNED_DATA"; }
};
static final Function<int[],DummyWebSocketServer> AUTH_SERVER_WITH_CANNED_DATA =
new Function<>() {
@Override public DummyWebSocketServer apply(int[] data) {
return Support.serverWithCannedDataAndAuthentication(USERNAME, PASSWORD, data); }
@Override public String toString() { return "AUTH_SERVER_WITH_CANNED_DATA"; }
};
@DataProvider(name = "servers")
public Object[][] servers() {
return new Object[][] {
{ SERVER_WITH_CANNED_DATA },
{ AUTH_SERVER_WITH_CANNED_DATA },
};
}
@Test(dataProvider = "servers")
public void simpleAggregatingBinaryMessages
(Function<int[],DummyWebSocketServer> serverSupplier)
throws IOException
{
List<byte[]> expected = List.of("alpha", "beta", "gamma", "delta")
.stream()
.map(s -> s.getBytes(StandardCharsets.US_ASCII))
@ -399,7 +458,7 @@ public class WebSocketTest {
};
CompletableFuture<List<byte[]>> actual = new CompletableFuture<>();
server = Support.serverWithCannedData(binary);
server = serverSupplier.apply(binary);
server.open();
WebSocket.Listener listener = new WebSocket.Listener() {
@ -437,7 +496,7 @@ public class WebSocketTest {
}
private void processWholeBinary(byte[] bytes) {
String stringBytes = new String(bytes, StandardCharsets.UTF_8);
String stringBytes = new String(bytes, UTF_8);
System.out.println("processWholeBinary: " + stringBytes);
collectedBytes.add(bytes);
}
@ -456,17 +515,24 @@ public class WebSocketTest {
}
};
webSocket = newBuilder().proxy(NO_PROXY).build().newWebSocketBuilder()
webSocket = newBuilder()
.proxy(NO_PROXY)
.authenticator(new WSAuthenticator())
.build().newWebSocketBuilder()
.buildAsync(server.getURI(), listener)
.join();
List<byte[]> a = actual.join();
assertEquals(a, expected);
server.close();
}
@Test
public void simpleAggregatingTextMessages() throws IOException {
@Test(dataProvider = "servers")
public void simpleAggregatingTextMessages
(Function<int[],DummyWebSocketServer> serverSupplier)
throws IOException
{
List<String> expected = List.of("alpha", "beta", "gamma", "delta");
int[] binary = new int[]{
@ -488,7 +554,7 @@ public class WebSocketTest {
};
CompletableFuture<List<String>> actual = new CompletableFuture<>();
server = Support.serverWithCannedData(binary);
server = serverSupplier.apply(binary);
server.open();
WebSocket.Listener listener = new WebSocket.Listener() {
@ -530,21 +596,28 @@ public class WebSocketTest {
}
};
webSocket = newBuilder().proxy(NO_PROXY).build().newWebSocketBuilder()
webSocket = newBuilder()
.proxy(NO_PROXY)
.authenticator(new WSAuthenticator())
.build().newWebSocketBuilder()
.buildAsync(server.getURI(), listener)
.join();
List<String> a = actual.join();
assertEquals(a, expected);
server.close();
}
/*
* Exercises the scenario where requests for more messages are made prior to
* completing the returned CompletionStage instances.
*/
@Test
public void aggregatingTextMessages() throws IOException {
@Test(dataProvider = "servers")
public void aggregatingTextMessages
(Function<int[],DummyWebSocketServer> serverSupplier)
throws IOException
{
List<String> expected = List.of("alpha", "beta", "gamma", "delta");
int[] binary = new int[]{
@ -566,8 +639,7 @@ public class WebSocketTest {
};
CompletableFuture<List<String>> actual = new CompletableFuture<>();
server = Support.serverWithCannedData(binary);
server = serverSupplier.apply(binary);
server.open();
WebSocket.Listener listener = new WebSocket.Listener() {
@ -623,11 +695,111 @@ public class WebSocketTest {
}
};
webSocket = newBuilder().proxy(NO_PROXY).build().newWebSocketBuilder()
webSocket = newBuilder()
.proxy(NO_PROXY)
.authenticator(new WSAuthenticator())
.build().newWebSocketBuilder()
.buildAsync(server.getURI(), listener)
.join();
List<String> a = actual.join();
assertEquals(a, expected);
server.close();
}
// -- authentication specific tests
/*
* Ensures authentication succeeds when an Authenticator set on client builder.
*/
@Test
public void clientAuthenticate() throws IOException {
try (var server = new DummyWebSocketServer(USERNAME, PASSWORD)){
server.open();
var webSocket = newBuilder()
.proxy(NO_PROXY)
.authenticator(new WSAuthenticator())
.build()
.newWebSocketBuilder()
.buildAsync(server.getURI(), new WebSocket.Listener() { })
.join();
}
}
/*
* Ensures authentication succeeds when an `Authorization` header is explicitly set.
*/
@Test
public void explicitAuthenticate() throws IOException {
try (var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
server.open();
String hv = "Basic " + Base64.getEncoder().encodeToString(
(USERNAME + ":" + PASSWORD).getBytes(UTF_8));
var webSocket = newBuilder()
.proxy(NO_PROXY).build()
.newWebSocketBuilder()
.header("Authorization", hv)
.buildAsync(server.getURI(), new WebSocket.Listener() { })
.join();
}
}
/*
* Ensures authentication does not succeed when no authenticator is present.
*/
@Test
public void failNoAuthenticator() throws IOException {
try (var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
server.open();
CompletableFuture<WebSocket> cf = newBuilder()
.proxy(NO_PROXY).build()
.newWebSocketBuilder()
.buildAsync(server.getURI(), new WebSocket.Listener() { });
try {
var webSocket = cf.join();
fail("Expected exception not thrown");
} catch (CompletionException expected) {
WebSocketHandshakeException e = (WebSocketHandshakeException)expected.getCause();
HttpResponse<?> response = e.getResponse();
assertEquals(response.statusCode(), 401);
}
}
}
/*
* Ensures authentication does not succeed when the authenticator presents
* unauthorized credentials.
*/
@Test
public void failBadCredentials() throws IOException {
try (var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
server.open();
Authenticator authenticator = new Authenticator() {
@Override protected PasswordAuthentication getPasswordAuthentication() {
return new PasswordAuthentication("BAD"+USERNAME, "".toCharArray());
}
};
CompletableFuture<WebSocket> cf = newBuilder()
.proxy(NO_PROXY)
.authenticator(authenticator)
.build()
.newWebSocketBuilder()
.buildAsync(server.getURI(), new WebSocket.Listener() { });
try {
var webSocket = cf.join();
fail("Expected exception not thrown");
} catch (CompletionException expected) {
System.out.println("caught expected exception:" + expected);
}
}
}
}