8087113: Websocket API and implementation

Reviewed-by: chegar
This commit is contained in:
Pavel Rappo 2016-05-09 23:33:09 +01:00
parent dd927b90d5
commit b962e07463
26 changed files with 5402 additions and 165 deletions

View File

@ -1,159 +0,0 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
*/
package java.net.http;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CharsetEncoder;
import java.nio.charset.CoderResult;
import static java.nio.charset.StandardCharsets.UTF_8;
// The purpose of this class is to separate charset-related tasks from the main
// WebSocket logic, simplifying where possible.
//
// * Coders hide the differences between coding and flushing stages on the
// API level
// * Verifier abstracts the way the verification is performed
// (spoiler: it's a decoding into a throw-away buffer)
//
// Coding methods throw exceptions instead of returning coding result denoting
// errors, since any kind of handling and recovery is not expected.
final class CharsetToolkit {
private CharsetToolkit() { }
static final class Verifier {
private final CharsetDecoder decoder = UTF_8.newDecoder();
// A buffer used to check validity of UTF-8 byte stream by decoding it.
// The contents of this buffer are never used.
// The size is arbitrary, though it should probably be chosen from the
// performance perspective since it affects the total number of calls to
// decoder.decode() and amount of work in each of these calls
private final CharBuffer blackHole = CharBuffer.allocate(1024);
void verify(ByteBuffer in, boolean endOfInput)
throws CharacterCodingException {
while (true) {
// Since decoder.flush() cannot produce an error, it's not
// helpful for verification. Therefore this step is skipped.
CoderResult r = decoder.decode(in, blackHole, endOfInput);
if (r.isOverflow()) {
blackHole.clear();
} else if (r.isUnderflow()) {
break;
} else if (r.isError()) {
r.throwException();
} else {
// Should not happen
throw new InternalError();
}
}
}
Verifier reset() {
decoder.reset();
return this;
}
}
static final class Encoder {
private final CharsetEncoder encoder = UTF_8.newEncoder();
private boolean coding = true;
CoderResult encode(CharBuffer in, ByteBuffer out, boolean endOfInput)
throws CharacterCodingException {
if (coding) {
CoderResult r = encoder.encode(in, out, endOfInput);
if (r.isOverflow()) {
return r;
} else if (r.isUnderflow()) {
if (endOfInput) {
coding = false;
} else {
return r;
}
} else if (r.isError()) {
r.throwException();
} else {
// Should not happen
throw new InternalError();
}
}
assert !coding;
return encoder.flush(out);
}
Encoder reset() {
coding = true;
encoder.reset();
return this;
}
}
static CharBuffer decode(ByteBuffer in) throws CharacterCodingException {
return UTF_8.newDecoder().decode(in);
}
static final class Decoder {
private final CharsetDecoder decoder = UTF_8.newDecoder();
private boolean coding = true; // Either coding or flushing
CoderResult decode(ByteBuffer in, CharBuffer out, boolean endOfInput)
throws CharacterCodingException {
if (coding) {
CoderResult r = decoder.decode(in, out, endOfInput);
if (r.isOverflow()) {
return r;
} else if (r.isUnderflow()) {
if (endOfInput) {
coding = false;
} else {
return r;
}
} else if (r.isError()) {
r.throwException();
} else {
// Should not happen
throw new InternalError();
}
}
assert !coding;
return decoder.flush(out);
}
Decoder reset() {
coding = true;
decoder.reset();
return this;
}
}
}

View File

@ -39,18 +39,22 @@ final class RawChannel implements ByteChannel, GatheringByteChannel {
private final HttpClientImpl client;
private final HttpConnection connection;
private volatile boolean closed;
private interface RawEvent {
/** must return the selector interest op flags OR'd. */
/**
* must return the selector interest op flags OR'd.
*/
int interestOps();
/** called when event occurs. */
/**
* called when event occurs.
*/
void handle();
}
interface NonBlockingEvent extends RawEvent { }
interface NonBlockingEvent extends RawEvent {
}
RawChannel(HttpClientImpl client, HttpConnection connection) {
this.client = client;
@ -127,12 +131,11 @@ final class RawChannel implements ByteChannel, GatheringByteChannel {
@Override
public boolean isOpen() {
return !closed;
return connection.isOpen();
}
@Override
public void close() throws IOException {
closed = true;
connection.close();
}

View File

@ -0,0 +1,390 @@
/*
* Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.io.IOException;
import java.net.ProtocolException;
import java.net.http.WSOpeningHandshake.Result;
import java.nio.ByteBuffer;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Stream;
import static java.lang.System.Logger.Level.ERROR;
import static java.lang.System.Logger.Level.WARNING;
import static java.net.http.WSUtils.logger;
import static java.util.Objects.requireNonNull;
/*
* A WebSocket client.
*
* Consists of two independent parts; a transmitter responsible for sending
* messages, and a receiver which notifies the listener of incoming messages.
*/
final class WS implements WebSocket {
private final String subprotocol;
private final RawChannel channel;
private final WSTransmitter transmitter;
private final WSReceiver receiver;
private final Listener listener;
private final Object stateLock = new Object();
private volatile State state = State.CONNECTED;
private final CompletableFuture<Void> whenClosed = new CompletableFuture<>();
static CompletableFuture<WebSocket> newInstanceAsync(WSBuilder b) {
CompletableFuture<Result> result = new WSOpeningHandshake(b).performAsync();
Listener listener = b.getListener();
Executor executor = b.getClient().executorService();
return result.thenApply(r -> {
WS ws = new WS(listener, r.subprotocol, r.channel, executor);
ws.start();
return ws;
});
}
private WS(Listener listener, String subprotocol, RawChannel channel,
Executor executor) {
this.listener = wrapListener(listener);
this.channel = channel;
this.subprotocol = subprotocol;
Consumer<Throwable> errorHandler = error -> {
if (error == null) {
throw new InternalError();
}
// If the channel is closed, we need to update the state, to denote
// there's no point in trying to continue using WebSocket
if (!channel.isOpen()) {
synchronized (stateLock) {
tryChangeState(State.ERROR);
}
}
};
transmitter = new WSTransmitter(executor, channel, errorHandler);
receiver = new WSReceiver(this.listener, this, executor, channel);
}
private void start() {
receiver.start();
}
@Override
public CompletableFuture<Void> sendText(ByteBuffer message, boolean isLast) {
throw new UnsupportedOperationException("Not implemented");
}
@Override
public CompletableFuture<Void> sendText(CharSequence message, boolean isLast) {
requireNonNull(message, "message");
synchronized (stateLock) {
checkState();
return transmitter.sendText(message, isLast);
}
}
@Override
public CompletableFuture<Void> sendText(Stream<? extends CharSequence> message) {
requireNonNull(message, "message");
synchronized (stateLock) {
checkState();
return transmitter.sendText(message);
}
}
@Override
public CompletableFuture<Void> sendBinary(ByteBuffer message, boolean isLast) {
requireNonNull(message, "message");
synchronized (stateLock) {
checkState();
return transmitter.sendBinary(message, isLast);
}
}
@Override
public CompletableFuture<Void> sendPing(ByteBuffer message) {
requireNonNull(message, "message");
synchronized (stateLock) {
checkState();
return transmitter.sendPing(message);
}
}
@Override
public CompletableFuture<Void> sendPong(ByteBuffer message) {
requireNonNull(message, "message");
synchronized (stateLock) {
checkState();
return transmitter.sendPong(message);
}
}
@Override
public CompletableFuture<Void> sendClose(CloseCode code, CharSequence reason) {
requireNonNull(code, "code");
requireNonNull(reason, "reason");
synchronized (stateLock) {
return doSendClose(() -> transmitter.sendClose(code, reason));
}
}
@Override
public CompletableFuture<Void> sendClose() {
synchronized (stateLock) {
return doSendClose(() -> transmitter.sendClose());
}
}
private CompletableFuture<Void> doSendClose(Supplier<CompletableFuture<Void>> s) {
checkState();
boolean closeChannel = false;
synchronized (stateLock) {
if (state == State.CLOSED_REMOTELY) {
closeChannel = tryChangeState(State.CLOSED);
} else {
tryChangeState(State.CLOSED_LOCALLY);
}
}
CompletableFuture<Void> sent = s.get();
if (closeChannel) {
sent.whenComplete((v, t) -> {
try {
channel.close();
} catch (IOException e) {
logger.log(ERROR, "Error transitioning to state " + State.CLOSED, e);
}
});
}
return sent;
}
@Override
public long request(long n) {
if (n < 0L) {
throw new IllegalArgumentException("The number must not be negative: " + n);
}
return receiver.request(n);
}
@Override
public String getSubprotocol() {
return subprotocol;
}
@Override
public boolean isClosed() {
return state.isTerminal();
}
@Override
public void abort() throws IOException {
synchronized (stateLock) {
tryChangeState(State.ABORTED);
}
channel.close();
}
@Override
public String toString() {
return super.toString() + "[" + state + "]";
}
private void checkState() {
if (state.isTerminal() || state == State.CLOSED_LOCALLY) {
throw new IllegalStateException("WebSocket is closed [" + state + "]");
}
}
/*
* Wraps the user's listener passed to the constructor into own listener to
* intercept transitions to terminal states (onClose and onError) and to act
* upon exceptions and values from the user's listener.
*/
private Listener wrapListener(Listener listener) {
return new Listener() {
// Listener's method MUST be invoked in a happen-before order
private final Object visibilityLock = new Object();
@Override
public void onOpen(WebSocket webSocket) {
synchronized (visibilityLock) {
listener.onOpen(webSocket);
}
}
@Override
public CompletionStage<?> onText(WebSocket webSocket, Text message,
MessagePart part) {
synchronized (visibilityLock) {
return listener.onText(webSocket, message, part);
}
}
@Override
public CompletionStage<?> onBinary(WebSocket webSocket, ByteBuffer message,
MessagePart part) {
synchronized (visibilityLock) {
return listener.onBinary(webSocket, message, part);
}
}
@Override
public CompletionStage<?> onPing(WebSocket webSocket, ByteBuffer message) {
synchronized (visibilityLock) {
return listener.onPing(webSocket, message);
}
}
@Override
public CompletionStage<?> onPong(WebSocket webSocket, ByteBuffer message) {
synchronized (visibilityLock) {
return listener.onPong(webSocket, message);
}
}
@Override
public void onClose(WebSocket webSocket, Optional<CloseCode> code, String reason) {
synchronized (stateLock) {
if (state == State.CLOSED_REMOTELY || state.isTerminal()) {
throw new InternalError("Unexpected onClose in state " + state);
} else if (state == State.CLOSED_LOCALLY) {
try {
channel.close();
} catch (IOException e) {
logger.log(ERROR, "Error transitioning to state " + State.CLOSED, e);
}
tryChangeState(State.CLOSED);
} else if (state == State.CONNECTED) {
tryChangeState(State.CLOSED_REMOTELY);
}
}
synchronized (visibilityLock) {
listener.onClose(webSocket, code, reason);
}
}
@Override
public void onError(WebSocket webSocket, Throwable error) {
// An error doesn't necessarily mean the connection must be
// closed automatically
if (!channel.isOpen()) {
synchronized (stateLock) {
tryChangeState(State.ERROR);
}
} else if (error instanceof ProtocolException
&& error.getCause() instanceof WSProtocolException) {
WSProtocolException cause = (WSProtocolException) error.getCause();
logger.log(WARNING, "Failing connection {0}, reason: ''{1}''",
webSocket, cause.getMessage());
CloseCode cc = cause.getCloseCode();
transmitter.sendClose(cc, "").whenComplete((v, t) -> {
synchronized (stateLock) {
tryChangeState(State.ERROR);
}
try {
channel.close();
} catch (IOException e) {
logger.log(ERROR, e);
}
});
}
synchronized (visibilityLock) {
listener.onError(webSocket, error);
}
}
};
}
private boolean tryChangeState(State newState) {
assert Thread.holdsLock(stateLock);
if (state.isTerminal()) {
return false;
}
state = newState;
if (newState.isTerminal()) {
whenClosed.complete(null);
}
return true;
}
CompletionStage<Void> whenClosed() {
return whenClosed;
}
/*
* WebSocket connection internal state.
*/
private enum State {
/*
* Initial WebSocket state. The WebSocket is connected (i.e. remains in
* this state) unless proven otherwise. For example, by reading or
* writing operations on the channel.
*/
CONNECTED,
/*
* A Close message has been received by the client. No more messages
* will be received.
*/
CLOSED_REMOTELY,
/*
* A Close message has been sent by the client. No more messages can be
* sent.
*/
CLOSED_LOCALLY,
/*
* Close messages has been both sent and received (closing handshake)
* and TCP connection closed. Closed _cleanly_ in terms of RFC 6455.
*/
CLOSED,
/*
* The connection has been aborted by the client. Closed not _cleanly_
* in terms of RFC 6455.
*/
ABORTED,
/*
* The connection has been terminated due to a protocol or I/O error.
* Might happen during sending or receiving.
*/
ERROR;
/*
* Returns `true` if this state is terminal. If WebSocket has transited
* to such a state, if remains in it forever.
*/
boolean isTerminal() {
return this == CLOSED || this == ABORTED || this == ERROR;
}
}
}

View File

@ -0,0 +1,175 @@
/*
* Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
final class WSBuilder implements WebSocket.Builder {
private static final Set<String> FORBIDDEN_HEADERS =
new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
static {
List<String> headers = List.of("Connection", "Upgrade",
"Sec-WebSocket-Accept", "Sec-WebSocket-Extensions",
"Sec-WebSocket-Key", "Sec-WebSocket-Protocol",
"Sec-WebSocket-Version");
FORBIDDEN_HEADERS.addAll(headers);
}
private final URI uri;
private final HttpClient client;
private final LinkedHashMap<String, List<String>> headers = new LinkedHashMap<>();
private final WebSocket.Listener listener;
private Collection<String> subprotocols = Collections.emptyList();
private long timeout;
private TimeUnit timeUnit;
WSBuilder(URI uri, HttpClient client, WebSocket.Listener listener) {
checkURI(requireNonNull(uri, "uri"));
requireNonNull(client, "client");
requireNonNull(listener, "listener");
this.uri = uri;
this.listener = listener;
this.client = client;
}
@Override
public WebSocket.Builder header(String name, String value) {
requireNonNull(name, "name");
requireNonNull(value, "value");
if (FORBIDDEN_HEADERS.contains(name)) {
throw new IllegalArgumentException(
format("Header '%s' is used in the WebSocket Protocol", name));
}
List<String> values = headers.computeIfAbsent(name, n -> new LinkedList<>());
values.add(value);
return this;
}
@Override
public WebSocket.Builder subprotocols(String mostPreferred, String... lesserPreferred) {
requireNonNull(mostPreferred, "mostPreferred");
requireNonNull(lesserPreferred, "lesserPreferred");
this.subprotocols = checkSubprotocols(mostPreferred, lesserPreferred);
return this;
}
@Override
public WebSocket.Builder connectTimeout(long timeout, TimeUnit unit) {
if (timeout < 0) {
throw new IllegalArgumentException("Negative timeout: " + timeout);
}
requireNonNull(unit, "unit");
this.timeout = timeout;
this.timeUnit = unit;
return this;
}
@Override
public CompletableFuture<WebSocket> buildAsync() {
return WS.newInstanceAsync(this);
}
private static URI checkURI(URI uri) {
String s = uri.getScheme();
if (!("ws".equalsIgnoreCase(s) || "wss".equalsIgnoreCase(s))) {
throw new IllegalArgumentException
("URI scheme not ws or wss (RFC 6455 3.): " + s);
}
String fragment = uri.getFragment();
if (fragment != null) {
throw new IllegalArgumentException(format
("Fragment not allowed in a WebSocket URI (RFC 6455 3.): '%s'",
fragment));
}
return uri;
}
URI getUri() { return uri; }
HttpClient getClient() { return client; }
Map<String, List<String>> getHeaders() {
LinkedHashMap<String, List<String>> copy = new LinkedHashMap<>(headers.size());
headers.forEach((name, values) -> copy.put(name, new LinkedList<>(values)));
return copy;
}
WebSocket.Listener getListener() { return listener; }
Collection<String> getSubprotocols() {
return new ArrayList<>(subprotocols);
}
long getTimeout() { return timeout; }
TimeUnit getTimeUnit() { return timeUnit; }
private static Collection<String> checkSubprotocols(String mostPreferred,
String... lesserPreferred) {
checkSubprotocolSyntax(mostPreferred, "mostPreferred");
LinkedHashSet<String> sp = new LinkedHashSet<>(1 + lesserPreferred.length);
sp.add(mostPreferred);
for (int i = 0; i < lesserPreferred.length; i++) {
String p = lesserPreferred[i];
String location = format("lesserPreferred[%s]", i);
requireNonNull(p, location);
checkSubprotocolSyntax(p, location);
if (!sp.add(p)) {
throw new IllegalArgumentException(format(
"Duplicate subprotocols (RFC 6455 4.1.): '%s'", p));
}
}
return sp;
}
private static void checkSubprotocolSyntax(String subprotocol, String location) {
if (subprotocol.isEmpty()) {
throw new IllegalArgumentException
("Subprotocol name is empty (RFC 6455 4.1.): " + location);
}
if (!subprotocol.chars().allMatch(c -> 0x21 <= c && c <= 0x7e)) {
throw new IllegalArgumentException
("Subprotocol name contains illegal characters (RFC 6455 4.1.): "
+ location);
}
}
}

View File

@ -0,0 +1,126 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
*/
package java.net.http;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CharsetEncoder;
import java.nio.charset.CoderResult;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import static java.lang.System.Logger.Level.WARNING;
import static java.net.http.WSUtils.EMPTY_BYTE_BUFFER;
import static java.net.http.WSUtils.logger;
import static java.nio.charset.StandardCharsets.UTF_8;
/*
* A collection of tools for UTF-8 coding.
*/
final class WSCharsetToolkit {
private WSCharsetToolkit() { }
static final class Encoder {
private final CharsetEncoder encoder = UTF_8.newEncoder();
ByteBuffer encode(CharBuffer in) throws CharacterCodingException {
return encoder.encode(in);
}
// TODO:
// ByteBuffer[] encode(CharBuffer in) throws CharacterCodingException {
// return encoder.encode(in);
// }
}
static CharBuffer decode(ByteBuffer in) throws CharacterCodingException {
return UTF_8.newDecoder().decode(in);
}
static final class Decoder {
private final CharsetDecoder decoder = StandardCharsets.UTF_8.newDecoder();
{
decoder.onMalformedInput(CodingErrorAction.REPORT);
decoder.onUnmappableCharacter(CodingErrorAction.REPORT);
}
private ByteBuffer leftovers = EMPTY_BYTE_BUFFER;
WSShared<CharBuffer> decode(WSShared<ByteBuffer> in, boolean endOfInput)
throws CharacterCodingException {
ByteBuffer b;
int rem = leftovers.remaining();
if (rem != 0) {
// TODO: We won't need this wasteful allocation & copying when
// JDK-8155222 has been resolved
b = ByteBuffer.allocate(rem + in.remaining());
b.put(leftovers).put(in.buffer()).flip();
} else {
b = in.buffer();
}
CharBuffer out = CharBuffer.allocate(b.remaining());
CoderResult r = decoder.decode(b, out, endOfInput);
if (r.isError()) {
r.throwException();
}
if (b.hasRemaining()) {
leftovers = ByteBuffer.allocate(b.remaining()).put(b).flip();
} else {
leftovers = EMPTY_BYTE_BUFFER;
}
// Since it's UTF-8, the assumption is leftovers.remaining() < 4
// (i.e. small). Otherwise a shared buffer should be used
if (!(leftovers.remaining() < 4)) {
logger.log(WARNING,
"The size of decoding leftovers is greater than expected: {0}",
leftovers.remaining());
}
b.position(b.limit()); // As if we always read to the end
in.dispose();
// Decoder promises that in the case of endOfInput == true:
// "...any remaining undecoded input will be treated as being
// malformed"
assert !(endOfInput && leftovers.hasRemaining()) : endOfInput + ", " + leftovers;
if (endOfInput) {
r = decoder.flush(out);
decoder.reset();
if (r.isOverflow()) {
// FIXME: for now I know flush() does nothing. But the
// implementation of UTF8 decoder might change. And if now
// flush() is a no-op, it is not guaranteed to remain so in
// the future
throw new InternalError("Not yet implemented");
}
}
out.flip();
return WSShared.wrap(out);
}
}
}

View File

@ -0,0 +1,30 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
interface WSDisposable {
void dispose();
}

View File

@ -0,0 +1,67 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
final class WSDisposableText implements WebSocket.Text, WSDisposable {
private final WSShared<CharBuffer> text;
WSDisposableText(WSShared<CharBuffer> text) {
this.text = text;
}
@Override
public int length() {
return text.buffer().length();
}
@Override
public char charAt(int index) {
return text.buffer().charAt(index);
}
@Override
public CharSequence subSequence(int start, int end) {
return text.buffer().subSequence(start, end);
}
@Override
public ByteBuffer asByteBuffer() {
throw new UnsupportedOperationException("To be removed from the API");
}
@Override
public String toString() {
return text.buffer().toString();
}
@Override
public void dispose() {
text.dispose();
}
}

View File

@ -0,0 +1,486 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.nio.ByteBuffer;
import static java.lang.String.format;
import static java.net.http.WSFrame.Opcode.ofCode;
import static java.net.http.WSUtils.dump;
/*
* A collection of utilities for reading, writing, and masking frames.
*/
final class WSFrame {
private WSFrame() { }
static final int MAX_HEADER_SIZE_BYTES = 2 + 8 + 4;
enum Opcode {
CONTINUATION (0x0),
TEXT (0x1),
BINARY (0x2),
NON_CONTROL_0x3(0x3),
NON_CONTROL_0x4(0x4),
NON_CONTROL_0x5(0x5),
NON_CONTROL_0x6(0x6),
NON_CONTROL_0x7(0x7),
CLOSE (0x8),
PING (0x9),
PONG (0xA),
CONTROL_0xB (0xB),
CONTROL_0xC (0xC),
CONTROL_0xD (0xD),
CONTROL_0xE (0xE),
CONTROL_0xF (0xF);
private static final Opcode[] opcodes;
static {
Opcode[] values = values();
opcodes = new Opcode[values.length];
for (Opcode c : values) {
assert opcodes[c.code] == null
: WSUtils.dump(c, c.code, opcodes[c.code]);
opcodes[c.code] = c;
}
}
private final byte code;
private final char shiftedCode;
private final String description;
Opcode(int code) {
this.code = (byte) code;
this.shiftedCode = (char) (code << 8);
this.description = format("%x (%s)", code, name());
}
boolean isControl() {
return (code & 0x8) != 0;
}
static Opcode ofCode(int code) {
return opcodes[code & 0xF];
}
@Override
public String toString() {
return description;
}
}
/*
* A utility to mask payload data.
*/
static final class Masker {
private final ByteBuffer acc = ByteBuffer.allocate(8);
private final int[] maskBytes = new int[4];
private int offset;
private long maskLong;
/*
* Sets up the mask.
*/
Masker mask(int value) {
acc.clear().putInt(value).putInt(value).flip();
for (int i = 0; i < maskBytes.length; i++) {
maskBytes[i] = acc.get(i);
}
offset = 0;
maskLong = acc.getLong(0);
return this;
}
/*
* Reads as many bytes as possible from the given input buffer, writing
* the resulting masked bytes to the given output buffer.
*
* src.remaining() <= dst.remaining() // TODO: do we need this restriction?
* 'src' and 'dst' can be the same ByteBuffer
*/
Masker applyMask(ByteBuffer src, ByteBuffer dst) {
if (src.remaining() > dst.remaining()) {
throw new IllegalArgumentException(dump(src, dst));
}
begin(src, dst);
loop(src, dst);
end(src, dst);
return this;
}
// Applying the remaining of the mask (strictly not more than 3 bytes)
// byte-wise
private void begin(ByteBuffer src, ByteBuffer dst) {
if (offset > 0) {
for (int i = src.position(), j = dst.position();
offset < 4 && i <= src.limit() - 1 && j <= dst.limit() - 1;
i++, j++, offset++) {
dst.put(j, (byte) (src.get(i) ^ maskBytes[offset]));
dst.position(j + 1);
src.position(i + 1);
}
offset &= 3;
}
}
private void loop(ByteBuffer src, ByteBuffer dst) {
int i = src.position();
int j = dst.position();
final int srcLim = src.limit() - 8;
final int dstLim = dst.limit() - 8;
for (; i <= srcLim && j <= dstLim; i += 8, j += 8) {
dst.putLong(j, (src.getLong(i) ^ maskLong));
}
if (i > src.limit()) {
src.position(i - 8);
} else {
src.position(i);
}
if (j > dst.limit()) {
dst.position(j - 8);
} else {
dst.position(j);
}
}
// Applying the mask to the remaining bytes byte-wise (don't make any
// assumptions on how many, hopefully not more than 7 for 64bit arch)
private void end(ByteBuffer src, ByteBuffer dst) {
for (int i = src.position(), j = dst.position();
i <= src.limit() - 1 && j <= dst.limit() - 1;
i++, j++, offset = (offset + 1) & 3) { // offset cycle through 0..3
dst.put(j, (byte) (src.get(i) ^ maskBytes[offset]));
src.position(i + 1);
dst.position(j + 1);
}
}
}
/*
* A builder of frame headers, capable of writing to a given buffer.
*
* The builder does not enforce any protocol-level rules, it simply writes
* a header structure to the buffer. The order of calls to intermediate
* methods is not significant.
*/
static final class HeaderBuilder {
private char firstChar;
private long payloadLen;
private int maskingKey;
private boolean mask;
HeaderBuilder fin(boolean value) {
if (value) {
firstChar |= 0b10000000_00000000;
} else {
firstChar &= ~0b10000000_00000000;
}
return this;
}
HeaderBuilder rsv1(boolean value) {
if (value) {
firstChar |= 0b01000000_00000000;
} else {
firstChar &= ~0b01000000_00000000;
}
return this;
}
HeaderBuilder rsv2(boolean value) {
if (value) {
firstChar |= 0b00100000_00000000;
} else {
firstChar &= ~0b00100000_00000000;
}
return this;
}
HeaderBuilder rsv3(boolean value) {
if (value) {
firstChar |= 0b00010000_00000000;
} else {
firstChar &= ~0b00010000_00000000;
}
return this;
}
HeaderBuilder opcode(Opcode value) {
firstChar = (char) ((firstChar & 0xF0FF) | value.shiftedCode);
return this;
}
HeaderBuilder payloadLen(long value) {
payloadLen = value;
firstChar &= 0b11111111_10000000; // Clear previous payload length leftovers
if (payloadLen < 126) {
firstChar |= payloadLen;
} else if (payloadLen < 65535) {
firstChar |= 126;
} else {
firstChar |= 127;
}
return this;
}
HeaderBuilder mask(int value) {
firstChar |= 0b00000000_10000000;
maskingKey = value;
mask = true;
return this;
}
HeaderBuilder noMask() {
firstChar &= ~0b00000000_10000000;
mask = false;
return this;
}
/*
* Writes the header to the given buffer.
*
* The buffer must have at least MAX_HEADER_SIZE_BYTES remaining. The
* buffer's position is incremented by the number of bytes written.
*/
void build(ByteBuffer buffer) {
buffer.putChar(firstChar);
if (payloadLen >= 126) {
if (payloadLen < 65535) {
buffer.putChar((char) payloadLen);
} else {
buffer.putLong(payloadLen);
}
}
if (mask) {
buffer.putInt(maskingKey);
}
}
}
/*
* A consumer of frame parts.
*
* Guaranteed to be called in the following order by the Frame.Reader:
*
* fin rsv1 rsv2 rsv3 opcode mask payloadLength maskingKey? payloadData+ endFrame
*/
interface Consumer {
void fin(boolean value);
void rsv1(boolean value);
void rsv2(boolean value);
void rsv3(boolean value);
void opcode(Opcode value);
void mask(boolean value);
void payloadLen(long value);
void maskingKey(int value);
/*
* Called when a part of the payload is ready to be consumed.
*
* Though may not yield a complete payload in a single invocation, i.e.
*
* data.remaining() < payloadLen
*
* the sum of `data.remaining()` passed to all invocations of this
* method will be equal to 'payloadLen', reported in
* `void payloadLen(long value)`
*
* No unmasking is done.
*/
void payloadData(WSShared<ByteBuffer> data, boolean isLast);
void endFrame(); // TODO: remove (payloadData(isLast=true)) should be enough
}
/*
* A Reader of Frames.
*
* No protocol-level rules are enforced, only frame structure.
*/
static final class Reader {
private static final int AWAITING_FIRST_BYTE = 1;
private static final int AWAITING_SECOND_BYTE = 2;
private static final int READING_16_LENGTH = 4;
private static final int READING_64_LENGTH = 8;
private static final int READING_MASK = 16;
private static final int READING_PAYLOAD = 32;
// A private buffer used to simplify multi-byte integers reading
private final ByteBuffer accumulator = ByteBuffer.allocate(8);
private int state = AWAITING_FIRST_BYTE;
private boolean mask;
private long payloadLength;
/*
* Reads at most one frame from the given buffer invoking the consumer's
* methods corresponding to the frame elements found.
*
* As much of the frame's payload, if any, is read. The buffers position
* is updated to reflect the number of bytes read.
*
* Throws WSProtocolException if the frame is malformed.
*/
void readFrame(WSShared<ByteBuffer> shared, Consumer consumer) {
ByteBuffer input = shared.buffer();
loop:
while (true) {
byte b;
switch (state) {
case AWAITING_FIRST_BYTE:
if (!input.hasRemaining()) {
break loop;
}
b = input.get();
consumer.fin( (b & 0b10000000) != 0);
consumer.rsv1((b & 0b01000000) != 0);
consumer.rsv2((b & 0b00100000) != 0);
consumer.rsv3((b & 0b00010000) != 0);
consumer.opcode(ofCode(b));
state = AWAITING_SECOND_BYTE;
continue loop;
case AWAITING_SECOND_BYTE:
if (!input.hasRemaining()) {
break loop;
}
b = input.get();
consumer.mask(mask = (b & 0b10000000) != 0);
byte p1 = (byte) (b & 0b01111111);
if (p1 < 126) {
assert p1 >= 0 : p1;
consumer.payloadLen(payloadLength = p1);
state = mask ? READING_MASK : READING_PAYLOAD;
} else if (p1 < 127) {
state = READING_16_LENGTH;
} else {
state = READING_64_LENGTH;
}
continue loop;
case READING_16_LENGTH:
if (!input.hasRemaining()) {
break loop;
}
b = input.get();
if (accumulator.put(b).position() < 2) {
continue loop;
}
payloadLength = accumulator.flip().getChar();
if (payloadLength < 126) {
throw notMinimalEncoding(payloadLength, 2);
}
consumer.payloadLen(payloadLength);
accumulator.clear();
state = mask ? READING_MASK : READING_PAYLOAD;
continue loop;
case READING_64_LENGTH:
if (!input.hasRemaining()) {
break loop;
}
b = input.get();
if (accumulator.put(b).position() < 8) {
continue loop;
}
payloadLength = accumulator.flip().getLong();
if (payloadLength < 0) {
throw negativePayload(payloadLength);
} else if (payloadLength < 65535) {
throw notMinimalEncoding(payloadLength, 8);
}
consumer.payloadLen(payloadLength);
accumulator.clear();
state = mask ? READING_MASK : READING_PAYLOAD;
continue loop;
case READING_MASK:
if (!input.hasRemaining()) {
break loop;
}
b = input.get();
if (accumulator.put(b).position() != 4) {
continue loop;
}
consumer.maskingKey(accumulator.flip().getInt());
accumulator.clear();
state = READING_PAYLOAD;
continue loop;
case READING_PAYLOAD:
// This state does not require any bytes to be available
// in the input buffer in order to proceed
boolean fullyRead;
int limit;
if (payloadLength <= input.remaining()) {
limit = input.position() + (int) payloadLength;
payloadLength = 0;
fullyRead = true;
} else {
limit = input.limit();
payloadLength -= input.remaining();
fullyRead = false;
}
// FIXME: consider a case where payloadLen != 0,
// but input.remaining() == 0
//
// There shouldn't be an invocation of payloadData with
// an empty buffer, as it would be an artifact of
// reading
consumer.payloadData(shared.share(input.position(), limit), fullyRead);
// Update the position manually, since reading the
// payload doesn't advance buffer's position
input.position(limit);
if (fullyRead) {
consumer.endFrame();
state = AWAITING_FIRST_BYTE;
}
break loop;
default:
throw new InternalError(String.valueOf(state));
}
}
}
private static WSProtocolException negativePayload(long payloadLength) {
return new WSProtocolException
("5.2.", format("Negative 64-bit payload length %s", payloadLength));
}
private static WSProtocolException notMinimalEncoding(long payloadLength, int numBytes) {
return new WSProtocolException
("5.2.", format("Payload length (%s) is not encoded with minimal number (%s) of bytes",
payloadLength, numBytes));
}
}
}

View File

@ -0,0 +1,289 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.net.http.WSFrame.Opcode;
import java.net.http.WebSocket.MessagePart;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.util.concurrent.atomic.AtomicInteger;
import static java.lang.String.format;
import static java.lang.System.Logger.Level.TRACE;
import static java.net.http.WSUtils.dump;
import static java.net.http.WSUtils.logger;
import static java.net.http.WebSocket.CloseCode.NOT_CONSISTENT;
import static java.net.http.WebSocket.CloseCode.of;
import static java.util.Objects.requireNonNull;
/*
* Consumes frame parts and notifies a message consumer, when there is
* sufficient data to produce a message, or part thereof.
*
* Data consumed but not yet translated is accumulated until it's sufficient to
* form a message.
*/
final class WSFrameConsumer implements WSFrame.Consumer {
private final AtomicInteger invocationOrder = new AtomicInteger();
private final WSMessageConsumer output;
private final WSCharsetToolkit.Decoder decoder = new WSCharsetToolkit.Decoder();
private boolean fin;
private Opcode opcode, originatingOpcode;
private MessagePart part = MessagePart.WHOLE;
private long payloadLen;
private WSShared<ByteBuffer> binaryData;
WSFrameConsumer(WSMessageConsumer output) {
this.output = requireNonNull(output);
}
@Override
public void fin(boolean value) {
assert invocationOrder.compareAndSet(0, 1) : dump(invocationOrder, value);
if (logger.isLoggable(TRACE)) {
// Checked for being loggable because of autoboxing of 'value'
logger.log(TRACE, "Reading fin: {0}", value);
}
fin = value;
}
@Override
public void rsv1(boolean value) {
assert invocationOrder.compareAndSet(1, 2) : dump(invocationOrder, value);
if (logger.isLoggable(TRACE)) {
logger.log(TRACE, "Reading rsv1: {0}", value);
}
if (value) {
throw new WSProtocolException("5.2.", "rsv1 bit is set unexpectedly");
}
}
@Override
public void rsv2(boolean value) {
assert invocationOrder.compareAndSet(2, 3) : dump(invocationOrder, value);
if (logger.isLoggable(TRACE)) {
logger.log(TRACE, "Reading rsv2: {0}", value);
}
if (value) {
throw new WSProtocolException("5.2.", "rsv2 bit is set unexpectedly");
}
}
@Override
public void rsv3(boolean value) {
assert invocationOrder.compareAndSet(3, 4) : dump(invocationOrder, value);
if (logger.isLoggable(TRACE)) {
logger.log(TRACE, "Reading rsv3: {0}", value);
}
if (value) {
throw new WSProtocolException("5.2.", "rsv3 bit is set unexpectedly");
}
}
@Override
public void opcode(Opcode v) {
assert invocationOrder.compareAndSet(4, 5) : dump(invocationOrder, v);
logger.log(TRACE, "Reading opcode: {0}", v);
if (v == Opcode.PING || v == Opcode.PONG || v == Opcode.CLOSE) {
if (!fin) {
throw new WSProtocolException("5.5.", "A fragmented control frame " + v);
}
opcode = v;
} else if (v == Opcode.TEXT || v == Opcode.BINARY) {
if (originatingOpcode != null) {
throw new WSProtocolException
("5.4.", format("An unexpected frame %s (fin=%s)", v, fin));
}
opcode = v;
if (!fin) {
originatingOpcode = v;
}
} else if (v == Opcode.CONTINUATION) {
if (originatingOpcode == null) {
throw new WSProtocolException
("5.4.", format("An unexpected frame %s (fin=%s)", v, fin));
}
opcode = v;
} else {
throw new WSProtocolException("5.2.", "An unknown opcode " + v);
}
}
@Override
public void mask(boolean value) {
assert invocationOrder.compareAndSet(5, 6) : dump(invocationOrder, value);
if (logger.isLoggable(TRACE)) {
logger.log(TRACE, "Reading mask: {0}", value);
}
if (value) {
throw new WSProtocolException
("5.1.", "Received a masked frame from the server");
}
}
@Override
public void payloadLen(long value) {
assert invocationOrder.compareAndSet(6, 7) : dump(invocationOrder, value);
if (logger.isLoggable(TRACE)) {
logger.log(TRACE, "Reading payloadLen: {0}", value);
}
if (opcode.isControl()) {
if (value > 125) {
throw new WSProtocolException
("5.5.", format("A control frame %s has a payload length of %s",
opcode, value));
}
assert Opcode.CLOSE.isControl();
if (opcode == Opcode.CLOSE && value == 1) {
throw new WSProtocolException
("5.5.1.", "A Close frame's status code is only 1 byte long");
}
}
payloadLen = value;
}
@Override
public void maskingKey(int value) {
assert false : dump(invocationOrder, value);
}
@Override
public void payloadData(WSShared<ByteBuffer> data, boolean isLast) {
assert invocationOrder.compareAndSet(7, isLast ? 8 : 7)
: dump(invocationOrder, data, isLast);
if (logger.isLoggable(TRACE)) {
logger.log(TRACE, "Reading payloadData: data={0}, isLast={1}", data, isLast);
}
if (opcode.isControl()) {
if (binaryData != null) {
binaryData.put(data);
data.dispose();
} else if (!isLast) {
// The first chunk of the message
int remaining = data.remaining();
// It shouldn't be 125, otherwise the next chunk will be of size
// 0, which is not what Reader promises to deliver (eager
// reading)
assert remaining < 125 : dump(remaining);
WSShared<ByteBuffer> b = WSShared.wrap(ByteBuffer.allocate(125)).put(data);
data.dispose();
binaryData = b; // Will be disposed by the user
} else {
// The only chunk; will be disposed by the user
binaryData = data.position(data.limit()); // FIXME: remove this hack
}
} else {
part = determinePart(isLast);
boolean text = opcode == Opcode.TEXT || originatingOpcode == Opcode.TEXT;
if (!text) {
output.onBinary(part, data);
} else {
boolean binaryNonEmpty = data.hasRemaining();
WSShared<CharBuffer> textData;
try {
textData = decoder.decode(data, part.isLast());
} catch (CharacterCodingException e) {
throw new WSProtocolException
("5.6.", "Invalid UTF-8 sequence in frame " + opcode, NOT_CONSISTENT, e);
}
if (!(binaryNonEmpty && !textData.hasRemaining())) {
// If there's a binary data, that result in no text, then we
// don't deliver anything
output.onText(part, new WSDisposableText(textData));
}
}
}
}
@Override
public void endFrame() {
assert invocationOrder.compareAndSet(8, 0) : dump(invocationOrder);
if (opcode.isControl()) {
binaryData.flip();
}
switch (opcode) {
case CLOSE:
WebSocket.CloseCode cc;
String reason;
if (payloadLen == 0) {
cc = null;
reason = "";
} else {
ByteBuffer b = binaryData.buffer();
int len = b.remaining();
assert 2 <= len && len <= 125 : dump(len, payloadLen);
try {
cc = of(b.getChar());
reason = WSCharsetToolkit.decode(b).toString();
} catch (IllegalArgumentException e) {
throw new WSProtocolException
("5.5.1", "Incorrect status code", e);
} catch (CharacterCodingException e) {
throw new WSProtocolException
("5.5.1", "Close reason is a malformed UTF-8 sequence", e);
}
}
binaryData.dispose(); // Manual dispose
output.onClose(cc, reason);
break;
case PING:
output.onPing(binaryData);
binaryData = null;
break;
case PONG:
output.onPong(binaryData);
binaryData = null;
break;
default:
assert opcode == Opcode.TEXT || opcode == Opcode.BINARY
|| opcode == Opcode.CONTINUATION : dump(opcode);
if (fin) {
// It is always the last chunk:
// either TEXT(FIN=TRUE)/BINARY(FIN=TRUE) or CONT(FIN=TRUE)
originatingOpcode = null;
}
break;
}
payloadLen = 0;
opcode = null;
}
private MessagePart determinePart(boolean isLast) {
boolean lastChunk = fin && isLast;
switch (part) {
case LAST:
case WHOLE:
return lastChunk ? MessagePart.WHOLE : MessagePart.FIRST;
case FIRST:
case PART:
return lastChunk ? MessagePart.LAST : MessagePart.PART;
default:
throw new InternalError(String.valueOf(part));
}
}
}

View File

@ -0,0 +1,42 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.net.http.WebSocket.CloseCode;
import java.net.http.WebSocket.MessagePart;
import java.nio.ByteBuffer;
interface WSMessageConsumer {
void onText(MessagePart part, WSDisposableText data);
void onBinary(MessagePart part, WSShared<ByteBuffer> data);
void onPing(WSShared<ByteBuffer> data);
void onPong(WSShared<ByteBuffer> data);
void onClose(CloseCode code, CharSequence reason);
}

View File

@ -0,0 +1,189 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.net.http.WSFrame.HeaderBuilder;
import java.net.http.WSFrame.Masker;
import java.net.http.WSOutgoingMessage.Binary;
import java.net.http.WSOutgoingMessage.Close;
import java.net.http.WSOutgoingMessage.Ping;
import java.net.http.WSOutgoingMessage.Pong;
import java.net.http.WSOutgoingMessage.StreamedText;
import java.net.http.WSOutgoingMessage.Text;
import java.net.http.WSOutgoingMessage.Visitor;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.security.SecureRandom;
import java.util.function.Consumer;
import static java.net.http.WSFrame.MAX_HEADER_SIZE_BYTES;
import static java.net.http.WSFrame.Opcode.BINARY;
import static java.net.http.WSFrame.Opcode.CLOSE;
import static java.net.http.WSFrame.Opcode.CONTINUATION;
import static java.net.http.WSFrame.Opcode.PING;
import static java.net.http.WSFrame.Opcode.PONG;
import static java.net.http.WSFrame.Opcode.TEXT;
import static java.util.Objects.requireNonNull;
/*
* A Sender of outgoing messages. Given a message,
*
* 1) constructs the frame
* 2) initiates the channel write
* 3) notifies when the message has been sent
*/
final class WSMessageSender {
private final Visitor frameBuilderVisitor;
private final Consumer<Throwable> completionEventConsumer;
private final WSWriter writer;
private final ByteBuffer[] buffers = new ByteBuffer[2];
WSMessageSender(RawChannel channel, Consumer<Throwable> completionEventConsumer) {
// Single reusable buffer that holds a header
this.buffers[0] = ByteBuffer.allocateDirect(MAX_HEADER_SIZE_BYTES);
this.frameBuilderVisitor = new FrameBuilderVisitor();
this.completionEventConsumer = completionEventConsumer;
this.writer = new WSWriter(channel, this.completionEventConsumer);
}
/*
* Tries to send the given message fully. Invoked once per message.
*/
boolean trySendFully(WSOutgoingMessage m) {
requireNonNull(m);
synchronized (this) {
try {
return sendNow(m);
} catch (Exception e) {
completionEventConsumer.accept(e);
return false;
}
}
}
private boolean sendNow(WSOutgoingMessage m) {
buffers[0].clear();
m.accept(frameBuilderVisitor);
buffers[0].flip();
return writer.tryWriteFully(buffers);
}
/*
* Builds and initiates a write of a frame, from a given message.
*/
class FrameBuilderVisitor implements Visitor {
private final SecureRandom random = new SecureRandom();
private final WSCharsetToolkit.Encoder encoder = new WSCharsetToolkit.Encoder();
private final Masker masker = new Masker();
private final HeaderBuilder headerBuilder = new HeaderBuilder();
private boolean previousIsLast = true;
@Override
public void visit(Text message) {
try {
buffers[1] = encoder.encode(CharBuffer.wrap(message.characters));
} catch (CharacterCodingException e) {
completionEventConsumer.accept(e);
return;
}
int mask = random.nextInt();
maskAndRewind(buffers[1], mask);
headerBuilder
.fin(message.isLast)
.opcode(previousIsLast ? TEXT : CONTINUATION)
.payloadLen(buffers[1].remaining())
.mask(mask)
.build(buffers[0]);
previousIsLast = message.isLast;
}
@Override
public void visit(StreamedText streamedText) {
throw new IllegalArgumentException("Not yet implemented");
}
@Override
public void visit(Binary message) {
buffers[1] = message.bytes;
int mask = random.nextInt();
maskAndRewind(buffers[1], mask);
headerBuilder
.fin(message.isLast)
.opcode(previousIsLast ? BINARY : CONTINUATION)
.payloadLen(message.bytes.remaining())
.mask(mask)
.build(buffers[0]);
previousIsLast = message.isLast;
}
@Override
public void visit(Ping message) {
buffers[1] = message.bytes;
int mask = random.nextInt();
maskAndRewind(buffers[1], mask);
headerBuilder
.fin(true)
.opcode(PING)
.payloadLen(message.bytes.remaining())
.mask(mask)
.build(buffers[0]);
}
@Override
public void visit(Pong message) {
buffers[1] = message.bytes;
int mask = random.nextInt();
maskAndRewind(buffers[1], mask);
headerBuilder
.fin(true)
.opcode(PONG)
.payloadLen(message.bytes.remaining())
.mask(mask)
.build(buffers[0]);
}
@Override
public void visit(Close message) {
buffers[1] = message.bytes;
int mask = random.nextInt();
maskAndRewind(buffers[1], mask);
headerBuilder
.fin(true)
.opcode(CLOSE)
.payloadLen(buffers[1].remaining())
.mask(mask)
.build(buffers[0]);
}
private void maskAndRewind(ByteBuffer b, int mask) {
int oldPos = b.position();
masker.mask(mask).applyMask(b, b);
b.position(oldPos);
}
}
}

View File

@ -0,0 +1,268 @@
/*
* Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import static java.lang.String.format;
import static java.lang.System.Logger.Level.TRACE;
import static java.net.http.WSUtils.logger;
import static java.net.http.WSUtils.webSocketSpecViolation;
final class WSOpeningHandshake {
private static final String HEADER_CONNECTION = "Connection";
private static final String HEADER_UPGRADE = "Upgrade";
private static final String HEADER_ACCEPT = "Sec-WebSocket-Accept";
private static final String HEADER_EXTENSIONS = "Sec-WebSocket-Extensions";
private static final String HEADER_KEY = "Sec-WebSocket-Key";
private static final String HEADER_PROTOCOL = "Sec-WebSocket-Protocol";
private static final String HEADER_VERSION = "Sec-WebSocket-Version";
private static final String VALUE_VERSION = "13"; // WebSocket's lucky number
private static final SecureRandom srandom = new SecureRandom();
private final MessageDigest sha1;
{
try {
sha1 = MessageDigest.getInstance("SHA-1");
} catch (NoSuchAlgorithmException e) {
// Shouldn't happen:
// SHA-1 must be available in every Java platform implementation
throw new InternalError("Minimum platform requirements are not met", e);
}
}
private final HttpRequest request;
private final Collection<String> subprotocols;
private final String nonce;
WSOpeningHandshake(WSBuilder b) {
URI httpURI = createHttpUri(b.getUri());
HttpRequest.Builder requestBuilder = b.getClient().request(httpURI);
if (b.getTimeUnit() != null) {
requestBuilder.timeout(b.getTimeUnit(), b.getTimeout());
}
Collection<String> s = b.getSubprotocols();
if (!s.isEmpty()) {
String p = s.stream().collect(Collectors.joining(", "));
requestBuilder.header(HEADER_PROTOCOL, p);
}
requestBuilder.header(HEADER_VERSION, VALUE_VERSION);
this.nonce = createNonce();
requestBuilder.header(HEADER_KEY, this.nonce);
this.request = requestBuilder.GET();
HttpRequestImpl r = (HttpRequestImpl) this.request;
r.isWebSocket(true);
r.setSystemHeader(HEADER_UPGRADE, "websocket");
r.setSystemHeader(HEADER_CONNECTION, "Upgrade");
this.subprotocols = s;
}
private URI createHttpUri(URI webSocketUri) {
// FIXME: check permission for WebSocket URI and translate it into http/https permission
logger.log(TRACE, "->createHttpUri(''{0}'')", webSocketUri);
String httpScheme = webSocketUri.getScheme().equalsIgnoreCase("ws")
? "http"
: "https";
try {
URI uri = new URI(httpScheme,
webSocketUri.getUserInfo(),
webSocketUri.getHost(),
webSocketUri.getPort(),
webSocketUri.getPath(),
webSocketUri.getQuery(),
null);
logger.log(TRACE, "<-createHttpUri: ''{0}''", uri);
return uri;
} catch (URISyntaxException e) {
// Shouldn't happen: URI invariant
throw new InternalError("Error translating WebSocket URI to HTTP URI", e);
}
}
CompletableFuture<Result> performAsync() {
// The whole dancing with thenCompose instead of thenApply is because
// WebSocketHandshakeException is a checked exception
return request.responseAsync()
.thenCompose(response -> {
try {
Result result = handleResponse(response);
return CompletableFuture.completedFuture(result);
} catch (WebSocketHandshakeException e) {
return CompletableFuture.failedFuture(e);
}
});
}
private Result handleResponse(HttpResponse response) throws WebSocketHandshakeException {
// By this point all redirects, authentications, etc. (if any) must have
// been done by the httpClient used by the WebSocket; so only 101 is
// expected
int statusCode = response.statusCode();
if (statusCode != 101) {
String m = webSocketSpecViolation("1.3.",
"Unable to complete handshake; HTTP response status code "
+ statusCode
);
throw new WebSocketHandshakeException(m, response);
}
HttpHeaders h = response.headers();
checkHeader(h, response, HEADER_UPGRADE, v -> v.equalsIgnoreCase("websocket"));
checkHeader(h, response, HEADER_CONNECTION, v -> v.equalsIgnoreCase("Upgrade"));
checkVersion(response, h);
checkAccept(response, h);
checkExtensions(response, h);
String subprotocol = checkAndReturnSubprotocol(response, h);
RawChannel channel = ((HttpResponseImpl) response).rawChannel();
return new Result(subprotocol, channel);
}
private void checkExtensions(HttpResponse response, HttpHeaders headers)
throws WebSocketHandshakeException {
List<String> ext = headers.allValues(HEADER_EXTENSIONS);
if (!ext.isEmpty()) {
String m = webSocketSpecViolation("4.1.",
"Server responded with extension(s) though none were requested "
+ Arrays.toString(ext.toArray())
);
throw new WebSocketHandshakeException(m, response);
}
}
private String checkAndReturnSubprotocol(HttpResponse response, HttpHeaders headers)
throws WebSocketHandshakeException {
assert response.statusCode() == 101 : response.statusCode();
List<String> sp = headers.allValues(HEADER_PROTOCOL);
int size = sp.size();
if (size == 0) {
// In this case the subprotocol requested (if any) by the client
// doesn't matter. If there is no such header in the response, then
// the server doesn't want to use any subprotocol
return null;
} else if (size > 1) {
// We don't know anything about toString implementation of this
// list, so let's create an array
String m = webSocketSpecViolation("4.1.",
"Server responded with multiple subprotocols: "
+ Arrays.toString(sp.toArray())
);
throw new WebSocketHandshakeException(m, response);
} else {
String selectedSubprotocol = sp.get(0);
if (this.subprotocols.contains(selectedSubprotocol)) {
return selectedSubprotocol;
} else {
String m = webSocketSpecViolation("4.1.",
format("Server responded with a subprotocol " +
"not among those requested: '%s'",
selectedSubprotocol));
throw new WebSocketHandshakeException(m, response);
}
}
}
private void checkAccept(HttpResponse response, HttpHeaders headers)
throws WebSocketHandshakeException {
assert response.statusCode() == 101 : response.statusCode();
String x = nonce + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
sha1.update(x.getBytes(StandardCharsets.ISO_8859_1));
String expected = Base64.getEncoder().encodeToString(sha1.digest());
checkHeader(headers, response, HEADER_ACCEPT, actual -> actual.trim().equals(expected));
}
private void checkVersion(HttpResponse response, HttpHeaders headers)
throws WebSocketHandshakeException {
assert response.statusCode() == 101 : response.statusCode();
List<String> versions = headers.allValues(HEADER_VERSION);
if (versions.isEmpty()) { // That's normal and expected
return;
}
String m = webSocketSpecViolation("4.4.",
"Server responded with version(s) "
+ Arrays.toString(versions.toArray()));
throw new WebSocketHandshakeException(m, response);
}
//
// Checks whether there's only one value for the header with the given name
// and the value satisfies the predicate.
//
private static void checkHeader(HttpHeaders headers,
HttpResponse response,
String headerName,
Predicate<? super String> valuePredicate)
throws WebSocketHandshakeException {
assert response.statusCode() == 101 : response.statusCode();
List<String> values = headers.allValues(headerName);
if (values.isEmpty()) {
String m = webSocketSpecViolation("4.1.",
format("Server response field '%s' is missing", headerName)
);
throw new WebSocketHandshakeException(m, response);
} else if (values.size() > 1) {
String m = webSocketSpecViolation("4.1.",
format("Server response field '%s' has multiple values", headerName)
);
throw new WebSocketHandshakeException(m, response);
}
if (!valuePredicate.test(values.get(0))) {
String m = webSocketSpecViolation("4.1.",
format("Server response field '%s' is incorrect", headerName)
);
throw new WebSocketHandshakeException(m, response);
}
}
private static String createNonce() {
byte[] bytes = new byte[16];
srandom.nextBytes(bytes);
return Base64.getEncoder().encodeToString(bytes);
}
static final class Result {
final String subprotocol;
final RawChannel channel;
private Result(String subprotocol, RawChannel channel) {
this.subprotocol = subprotocol;
this.channel = channel;
}
}
}

View File

@ -0,0 +1,164 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.nio.ByteBuffer;
import java.util.stream.Stream;
abstract class WSOutgoingMessage {
interface Visitor {
void visit(Text message);
void visit(StreamedText message);
void visit(Binary message);
void visit(Ping message);
void visit(Pong message);
void visit(Close message);
}
abstract void accept(Visitor visitor);
private WSOutgoingMessage() { }
static final class Text extends WSOutgoingMessage {
public final boolean isLast;
public final CharSequence characters;
Text(boolean isLast, CharSequence characters) {
this.isLast = isLast;
this.characters = characters;
}
@Override
void accept(Visitor visitor) {
visitor.visit(this);
}
@Override
public String toString() {
return WSUtils.toStringSimple(this) + "[isLast=" + isLast
+ ", characters=" + WSUtils.toString(characters) + "]";
}
}
static final class StreamedText extends WSOutgoingMessage {
public final Stream<? extends CharSequence> characters;
StreamedText(Stream<? extends CharSequence> characters) {
this.characters = characters;
}
@Override
void accept(Visitor visitor) {
visitor.visit(this);
}
@Override
public String toString() {
return WSUtils.toStringSimple(this) + "[characters=" + characters + "]";
}
}
static final class Binary extends WSOutgoingMessage {
public final boolean isLast;
public final ByteBuffer bytes;
Binary(boolean isLast, ByteBuffer bytes) {
this.isLast = isLast;
this.bytes = bytes;
}
@Override
void accept(Visitor visitor) {
visitor.visit(this);
}
@Override
public String toString() {
return WSUtils.toStringSimple(this) + "[isLast=" + isLast
+ ", bytes=" + WSUtils.toString(bytes) + "]";
}
}
static final class Ping extends WSOutgoingMessage {
public final ByteBuffer bytes;
Ping(ByteBuffer bytes) {
this.bytes = bytes;
}
@Override
void accept(Visitor visitor) {
visitor.visit(this);
}
@Override
public String toString() {
return WSUtils.toStringSimple(this) + "[" + WSUtils.toString(bytes) + "]";
}
}
static final class Pong extends WSOutgoingMessage {
public final ByteBuffer bytes;
Pong(ByteBuffer bytes) {
this.bytes = bytes;
}
@Override
void accept(Visitor visitor) {
visitor.visit(this);
}
@Override
public String toString() {
return WSUtils.toStringSimple(this) + "[" + WSUtils.toString(bytes) + "]";
}
}
static final class Close extends WSOutgoingMessage {
public final ByteBuffer bytes;
Close(ByteBuffer bytes) {
this.bytes = bytes;
}
@Override
void accept(Visitor visitor) {
visitor.visit(this);
}
@Override
public String toString() {
return WSUtils.toStringSimple(this) + "[" + WSUtils.toString(bytes) + "]";
}
}
}

View File

@ -0,0 +1,68 @@
package java.net.http;
import java.net.http.WebSocket.CloseCode;
import static java.net.http.WebSocket.CloseCode.PROTOCOL_ERROR;
import static java.util.Objects.requireNonNull;
//
// Special kind of exception closed from the outside world.
//
// Used as a "marker exception" for protocol issues in the incoming data, so the
// implementation could close the connection and specify an appropriate status
// code.
//
// A separate 'section' argument makes it more uncomfortable to be lazy and to
// leave a relevant spec reference empty :-) As a bonus all messages have the
// same style.
//
final class WSProtocolException extends RuntimeException {
private static final long serialVersionUID = 1L;
private final CloseCode closeCode;
private final String section;
WSProtocolException(String section, String detail) {
this(section, detail, PROTOCOL_ERROR);
}
WSProtocolException(String section, String detail, Throwable cause) {
this(section, detail, PROTOCOL_ERROR, cause);
}
private WSProtocolException(String section, String detail, CloseCode code) {
super(formatMessage(section, detail));
this.closeCode = requireNonNull(code);
this.section = section;
}
WSProtocolException(String section, String detail, CloseCode code,
Throwable cause) {
super(formatMessage(section, detail), cause);
this.closeCode = requireNonNull(code);
this.section = section;
}
private static String formatMessage(String section, String detail) {
if (requireNonNull(section).isEmpty()) {
throw new IllegalArgumentException();
}
if (requireNonNull(detail).isEmpty()) {
throw new IllegalArgumentException();
}
return WSUtils.webSocketSpecViolation(section, detail);
}
CloseCode getCloseCode() {
return closeCode;
}
public String getSection() {
return section;
}
@Override
public String toString() {
return super.toString() + "[" + closeCode + "]";
}
}

View File

@ -0,0 +1,275 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.ProtocolException;
import java.net.http.WebSocket.Listener;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.util.Optional;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import static java.lang.System.Logger.Level.ERROR;
import static java.net.http.WSUtils.EMPTY_BYTE_BUFFER;
import static java.net.http.WSUtils.logger;
/*
* Receives incoming data from the channel and converts it into a sequence of
* messages, which are then passed to the listener.
*/
final class WSReceiver {
private final Listener listener;
private final WebSocket webSocket;
private final Supplier<WSShared<ByteBuffer>> buffersSupplier =
new WSSharedPool<>(() -> ByteBuffer.allocateDirect(32768), 2);
private final RawChannel channel;
private final RawChannel.NonBlockingEvent channelEvent;
private final WSSignalHandler handler;
private final AtomicLong demand = new AtomicLong();
private final AtomicBoolean readable = new AtomicBoolean();
private boolean started;
private volatile boolean closed;
private final WSFrame.Reader reader = new WSFrame.Reader();
private final WSFrameConsumer frameConsumer;
private WSShared<ByteBuffer> buf = WSShared.wrap(EMPTY_BYTE_BUFFER);
private WSShared<ByteBuffer> data; // TODO: initialize with leftovers from the RawChannel
WSReceiver(Listener listener, WebSocket webSocket, Executor executor,
RawChannel channel) {
this.listener = listener;
this.webSocket = webSocket;
this.channel = channel;
handler = new WSSignalHandler(executor, this::react);
channelEvent = createChannelEvent();
this.frameConsumer = new WSFrameConsumer(new MessageConsumer());
}
private void react() {
synchronized (this) {
while (demand.get() > 0 && !closed) {
try {
if (data == null) {
if (!getData()) {
break;
}
}
reader.readFrame(data, frameConsumer);
if (!data.hasRemaining()) {
data.dispose();
data = null;
}
// In case of exception we don't need to clean any state,
// since it's the terminal condition anyway. Nothing will be
// retried.
} catch (WSProtocolException e) {
// Translate into ProtocolException
closeExceptionally(new ProtocolException().initCause(e));
} catch (Exception e) {
closeExceptionally(e);
}
}
}
}
long request(long n) {
long newDemand = demand.accumulateAndGet(n, (p, i) -> p + i < 0 ? Long.MAX_VALUE : p + i);
handler.signal();
assert newDemand >= 0 : newDemand;
return newDemand;
}
private boolean getData() throws IOException {
if (!readable.get()) {
return false;
}
if (!buf.hasRemaining()) {
buf.dispose();
buf = buffersSupplier.get();
assert buf.hasRemaining() : buf;
}
int oldPosition = buf.position();
int oldLimit = buf.limit();
int numRead = channel.read(buf.buffer());
if (numRead > 0) {
data = buf.share(oldPosition, oldPosition + numRead);
buf.select(buf.limit(), oldLimit); // Move window to the free region
return true;
} else if (numRead == 0) {
readable.set(false);
channel.registerEvent(channelEvent);
return false;
} else {
assert numRead < 0 : numRead;
throw new WSProtocolException
("7.2.1.", "Stream ended before a Close frame has been received");
}
}
void start() {
synchronized (this) {
if (started) {
throw new IllegalStateException("Already started");
}
started = true;
try {
channel.registerEvent(channelEvent);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
try {
listener.onOpen(webSocket);
} catch (Exception e) {
closeExceptionally(new RuntimeException("onOpen threw an exception", e));
}
}
}
private void close() { // TODO: move to WS.java
closed = true;
}
private void closeExceptionally(Throwable error) { // TODO: move to WS.java
close();
try {
listener.onError(webSocket, error);
} catch (Exception e) {
logger.log(ERROR, "onError threw an exception", e);
}
}
private final class MessageConsumer implements WSMessageConsumer {
@Override
public void onText(WebSocket.MessagePart part, WSDisposableText data) {
decrementDemand();
CompletionStage<?> cs;
try {
cs = listener.onText(webSocket, data, part);
} catch (Exception e) {
closeExceptionally(new RuntimeException("onText threw an exception", e));
return;
}
follow(cs, data, "onText");
}
@Override
public void onBinary(WebSocket.MessagePart part, WSShared<ByteBuffer> data) {
decrementDemand();
CompletionStage<?> cs;
try {
cs = listener.onBinary(webSocket, data.buffer(), part);
} catch (Exception e) {
closeExceptionally(new RuntimeException("onBinary threw an exception", e));
return;
}
follow(cs, data, "onBinary");
}
@Override
public void onPing(WSShared<ByteBuffer> data) {
decrementDemand();
CompletionStage<?> cs;
try {
cs = listener.onPing(webSocket, data.buffer());
} catch (Exception e) {
closeExceptionally(new RuntimeException("onPing threw an exception", e));
return;
}
follow(cs, data, "onPing");
}
@Override
public void onPong(WSShared<ByteBuffer> data) {
decrementDemand();
CompletionStage<?> cs;
try {
cs = listener.onPong(webSocket, data.buffer());
} catch (Exception e) {
closeExceptionally(new RuntimeException("onPong threw an exception", e));
return;
}
follow(cs, data, "onPong");
}
@Override
public void onClose(WebSocket.CloseCode code, CharSequence reason) {
decrementDemand();
try {
close();
listener.onClose(webSocket, Optional.ofNullable(code), reason.toString());
} catch (Exception e) {
logger.log(ERROR, "onClose threw an exception", e);
}
}
}
private void follow(CompletionStage<?> cs, WSDisposable d, String source) {
if (cs == null) {
d.dispose();
} else {
cs.whenComplete((whatever, error) -> {
if (error != null) {
String m = "CompletionStage returned by " + source + " completed exceptionally";
closeExceptionally(new RuntimeException(m, error));
}
d.dispose();
});
}
}
private void decrementDemand() {
long newDemand = demand.decrementAndGet();
assert newDemand >= 0 : newDemand;
}
private RawChannel.NonBlockingEvent createChannelEvent() {
return new RawChannel.NonBlockingEvent() {
@Override
public int interestOps() {
return SelectionKey.OP_READ;
}
@Override
public void handle() {
boolean wasNotReadable = readable.compareAndSet(false, true);
assert wasNotReadable;
handler.signal();
}
@Override
public String toString() {
return "Read readiness event [" + channel + "]";
}
};
}
}

View File

@ -0,0 +1,202 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
//
// +-----------+---------------+------------ ~ ------+
// | shared#1 | shared#2 | non-shared |
// +-----------+---------------+------------ ~ ------+
// | |
// |<------------------ shared0 ---------- ~ ----->|
//
//
// Objects of the type are not thread-safe. It's the responsibility of the
// client to access shared buffers safely between threads.
//
// It would be perfect if we could extend java.nio.Buffer, but it's not an
// option since Buffer and all its descendants have package-private
// constructors.
//
abstract class WSShared<T extends Buffer> implements WSDisposable {
protected final AtomicBoolean disposed = new AtomicBoolean();
protected final T buffer;
protected WSShared(T buffer) {
this.buffer = Objects.requireNonNull(buffer);
}
static <T extends Buffer> WSShared<T> wrap(T buffer) {
return new WSShared<>(buffer) {
@Override
WSShared<T> share(int pos, int limit) {
throw new UnsupportedOperationException();
}
};
}
// TODO: should be a terminal operation as after it returns the buffer might
// have escaped (we can't protect it any more)
public T buffer() {
checkDisposed();
return buffer;
}
abstract WSShared<T> share(final int pos, final int limit);
WSShared<T> select(final int pos, final int limit) {
checkRegion(pos, limit, buffer());
select(pos, limit, buffer());
return this;
}
@Override
public void dispose() {
if (!disposed.compareAndSet(false, true)) {
throw new IllegalStateException("Has been disposed previously");
}
}
int limit() {
return buffer().limit();
}
WSShared<T> limit(int newLimit) {
buffer().limit(newLimit);
return this;
}
int position() {
return buffer().position();
}
WSShared<T> position(int newPosition) {
buffer().position(newPosition);
return this;
}
int remaining() {
return buffer().remaining();
}
boolean hasRemaining() {
return buffer().hasRemaining();
}
WSShared<T> flip() {
buffer().flip();
return this;
}
WSShared<T> rewind() {
buffer().rewind();
return this;
}
WSShared<T> put(WSShared<? extends T> src) {
put(this.buffer(), src.buffer());
return this;
}
static void checkRegion(int position, int limit, Buffer buffer) {
if (position < 0 || position > buffer.capacity()) {
throw new IllegalArgumentException("position: " + position);
}
if (limit < 0 || limit > buffer.capacity()) {
throw new IllegalArgumentException("limit: " + limit);
}
if (limit < position) {
throw new IllegalArgumentException
("limit < position: limit=" + limit + ", position=" + position);
}
}
void select(int newPos, int newLim, Buffer buffer) {
int oldPos = buffer.position();
int oldLim = buffer.limit();
assert 0 <= oldPos && oldPos <= oldLim && oldLim <= buffer.capacity();
if (oldLim <= newPos) {
buffer().limit(newLim).position(newPos);
} else {
buffer.position(newPos).limit(newLim);
}
}
// The same as dst.put(src)
static <T extends Buffer> T put(T dst, T src) {
if (dst instanceof ByteBuffer) {
((ByteBuffer) dst).put((ByteBuffer) src);
} else if (dst instanceof CharBuffer) {
((CharBuffer) dst).put((CharBuffer) src);
} else {
// We don't work with buffers of other types
throw new IllegalArgumentException();
}
return dst;
}
// TODO: Remove when JDK-8150785 has been done
@SuppressWarnings("unchecked")
static <T extends Buffer> T slice(T buffer) {
if (buffer instanceof ByteBuffer) {
return (T) ((ByteBuffer) buffer).slice();
} else if (buffer instanceof CharBuffer) {
return (T) ((CharBuffer) buffer).slice();
} else {
// We don't work with buffers of other types
throw new IllegalArgumentException();
}
}
// TODO: Remove when JDK-8150785 has been done
@SuppressWarnings("unchecked")
static <T extends Buffer> T duplicate(T buffer) {
if (buffer instanceof ByteBuffer) {
return (T) ((ByteBuffer) buffer).duplicate();
} else if (buffer instanceof CharBuffer) {
return (T) ((CharBuffer) buffer).duplicate();
} else {
// We don't work with buffers of other types
throw new IllegalArgumentException();
}
}
@Override
public String toString() {
return super.toString() + "[" + WSUtils.toString(buffer()) + "]";
}
private void checkDisposed() {
if (disposed.get()) {
throw new IllegalStateException("Has been disposed previously");
}
}
}

View File

@ -0,0 +1,148 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.nio.Buffer;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import static java.lang.System.Logger.Level.TRACE;
import static java.net.http.WSShared.duplicate;
import static java.net.http.WSUtils.logger;
import static java.util.Objects.requireNonNull;
final class WSSharedPool<T extends Buffer> implements Supplier<WSShared<T>> {
private final Supplier<T> factory;
private final BlockingQueue<T> queue;
WSSharedPool(Supplier<T> factory, int maxPoolSize) {
this.factory = requireNonNull(factory);
this.queue = new LinkedBlockingQueue<>(maxPoolSize);
}
@Override
public Pooled get() {
T b = queue.poll();
if (b == null) {
logger.log(TRACE, "Pool {0} contains no free buffers", this);
b = requireNonNull(factory.get());
}
Pooled buf = new Pooled(new AtomicInteger(1), b, duplicate(b));
logger.log(TRACE, "Pool {0} created new buffer {1}", this, buf);
return buf;
}
private void put(Pooled b) {
assert b.disposed.get() && b.refCount.get() == 0
: WSUtils.dump(b.disposed, b.refCount, b);
b.shared.clear();
boolean accepted = queue.offer(b.getShared());
if (logger.isLoggable(TRACE)) {
if (accepted) {
logger.log(TRACE, "Pool {0} accepted {1}", this, b);
} else {
logger.log(TRACE, "Pool {0} discarded {1}", this, b);
}
}
}
@Override
public String toString() {
return super.toString() + "[queue.size=" + queue.size() + "]";
}
private final class Pooled extends WSShared<T> {
private final AtomicInteger refCount;
private final T shared;
private Pooled(AtomicInteger refCount, T shared, T region) {
super(region);
this.refCount = refCount;
this.shared = shared;
}
private T getShared() {
return shared;
}
@Override
@SuppressWarnings("unchecked")
public Pooled share(final int pos, final int limit) {
synchronized (this) {
T buffer = buffer();
checkRegion(pos, limit, buffer);
final int oldPos = buffer.position();
final int oldLimit = buffer.limit();
select(pos, limit, buffer);
T slice = WSShared.slice(buffer);
select(oldPos, oldLimit, buffer);
referenceAndGetCount();
Pooled buf = new Pooled(refCount, shared, slice);
logger.log(TRACE, "Shared {0} from {1}", buf, this);
return buf;
}
}
@Override
public void dispose() {
logger.log(TRACE, "Disposed {0}", this);
super.dispose();
if (dereferenceAndGetCount() == 0) {
WSSharedPool.this.put(this);
}
}
private int referenceAndGetCount() {
return refCount.updateAndGet(n -> {
if (n != Integer.MAX_VALUE) {
return n + 1;
} else {
throw new IllegalArgumentException
("Too many references: " + this);
}
});
}
private int dereferenceAndGetCount() {
return refCount.updateAndGet(n -> {
if (n > 0) {
return n - 1;
} else {
throw new InternalError();
}
});
}
@Override
public String toString() {
return WSUtils.toStringSimple(this) + "[" + WSUtils.toString(buffer)
+ "[refCount=" + refCount + ", disposed=" + disposed + "]]";
}
}
}

View File

@ -0,0 +1,137 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import static java.util.Objects.requireNonNull;
//
// The problem:
// ------------
// 1. For every invocation of 'signal()' there must be at least
// 1 invocation of 'handler.run()' that goes after
// 2. There must be no more than 1 thread running the 'handler.run()'
// at any given time
//
// For example, imagine each signal increments (+1) some number. Then the
// handler responds (eventually) the way that makes the number 0.
//
// For each signal there's a response. Several signals may be handled by a
// single response.
//
final class WSSignalHandler {
// In this state the task is neither submitted nor running.
// No one is handling signals. If a new signal has been received, the task
// has to be submitted to the executor in order to handle this signal.
private static final int DONE = 0;
// In this state the task is running.
// * If the signaller has found the task in this state it will try to change
// the state to RERUN in order to make the already running task to handle
// the new signal before exiting.
// * If the task has found itself in this state it will exit.
private static final int RUNNING = 1;
// A signal to the task, that it must rerun on the spot (without being
// resubmitted to the executor).
// If the task has found itself in this state it resets the state to
// RUNNING and repeats the pass.
private static final int RERUN = 2;
private final AtomicInteger state = new AtomicInteger(DONE);
private final Executor executor;
private final Runnable task;
WSSignalHandler(Executor executor, Runnable handler) {
this.executor = requireNonNull(executor);
requireNonNull(handler);
task = () -> {
while (!Thread.currentThread().isInterrupted()) {
try {
handler.run();
} catch (Exception e) {
// Sorry, the task won't be automatically retried;
// hope next signals (if any) will kick off the handling
state.set(DONE);
throw e;
}
int prev = state.getAndUpdate(s -> {
if (s == RUNNING) {
return DONE;
} else {
return RUNNING;
}
});
// Can't be DONE, since only the task itself may transit state
// into DONE (with one exception: RejectedExecution in signal();
// but in that case we couldn't be here at all)
assert prev == RUNNING || prev == RERUN;
if (prev == RUNNING) {
break;
}
}
};
}
// Invoked by outer code to signal
void signal() {
int prev = state.getAndUpdate(s -> {
switch (s) {
case RUNNING:
return RERUN;
case DONE:
return RUNNING;
case RERUN:
return RERUN;
default:
throw new InternalError(String.valueOf(s));
}
});
if (prev != DONE) {
// Nothing to do! piggybacking on previous signal
return;
}
try {
executor.execute(task);
} catch (RejectedExecutionException e) {
// Sorry some signal() invocations may have been accepted, but won't
// be done, since the 'task' couldn't be submitted
state.set(DONE);
throw e;
}
}
}

View File

@ -0,0 +1,176 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.net.http.WSOutgoingMessage.Binary;
import java.net.http.WSOutgoingMessage.Close;
import java.net.http.WSOutgoingMessage.Ping;
import java.net.http.WSOutgoingMessage.Pong;
import java.net.http.WSOutgoingMessage.StreamedText;
import java.net.http.WSOutgoingMessage.Text;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.CoderResult;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Consumer;
import java.util.stream.Stream;
import static java.lang.String.format;
import static java.net.http.Pair.pair;
/*
* Prepares outgoing messages for transmission. Verifies the WebSocket state,
* places the message on the outbound queue, and notifies the signal handler.
*/
final class WSTransmitter {
private final BlockingQueue<Pair<WSOutgoingMessage, CompletableFuture<Void>>>
backlog = new LinkedBlockingQueue<>();
private final WSMessageSender sender;
private final WSSignalHandler handler;
private boolean previousMessageSent = true;
private boolean canSendBinary = true;
private boolean canSendText = true;
WSTransmitter(Executor executor, RawChannel channel, Consumer<Throwable> errorHandler) {
this.handler = new WSSignalHandler(executor, this::handleSignal);
Consumer<Throwable> sendCompletion = (error) -> {
synchronized (this) {
if (error == null) {
previousMessageSent = true;
handler.signal();
} else {
errorHandler.accept(error);
backlog.forEach(p -> p.second.completeExceptionally(error));
backlog.clear();
}
}
};
this.sender = new WSMessageSender(channel, sendCompletion);
}
CompletableFuture<Void> sendText(CharSequence message, boolean isLast) {
checkAndUpdateText(isLast);
return acceptMessage(new Text(isLast, message));
}
CompletableFuture<Void> sendText(Stream<? extends CharSequence> message) {
checkAndUpdateText(true);
return acceptMessage(new StreamedText(message));
}
CompletableFuture<Void> sendBinary(ByteBuffer message, boolean isLast) {
checkAndUpdateBinary(isLast);
return acceptMessage(new Binary(isLast, message));
}
CompletableFuture<Void> sendPing(ByteBuffer message) {
checkSize(message.remaining(), 125);
return acceptMessage(new Ping(message));
}
CompletableFuture<Void> sendPong(ByteBuffer message) {
checkSize(message.remaining(), 125);
return acceptMessage(new Pong(message));
}
CompletableFuture<Void> sendClose(WebSocket.CloseCode code, CharSequence reason) {
return acceptMessage(createCloseMessage(code, reason));
}
CompletableFuture<Void> sendClose() {
return acceptMessage(new Close(ByteBuffer.allocate(0)));
}
private CompletableFuture<Void> acceptMessage(WSOutgoingMessage m) {
CompletableFuture<Void> cf = new CompletableFuture<>();
synchronized (this) {
backlog.offer(pair(m, cf));
}
handler.signal();
return cf;
}
/* Callback for pulling messages from the queue, and initiating the send. */
private void handleSignal() {
synchronized (this) {
while (!backlog.isEmpty() && previousMessageSent) {
previousMessageSent = false;
Pair<WSOutgoingMessage, CompletableFuture<Void>> p = backlog.peek();
boolean sent = sender.trySendFully(p.first);
if (sent) {
backlog.remove();
p.second.complete(null);
previousMessageSent = true;
}
}
}
}
private Close createCloseMessage(WebSocket.CloseCode code, CharSequence reason) {
// TODO: move to construction of CloseDetail (JDK-8155621)
ByteBuffer b = ByteBuffer.allocateDirect(125).putChar((char) code.getCode());
CoderResult result = StandardCharsets.UTF_8.newEncoder()
.encode(CharBuffer.wrap(reason), b, true);
if (result.isError()) {
try {
result.throwException();
} catch (CharacterCodingException e) {
throw new IllegalArgumentException("Reason is a malformed UTF-16 sequence", e);
}
} else if (result.isOverflow()) {
throw new IllegalArgumentException("Reason is too long");
}
return new Close(b.flip());
}
private void checkSize(int size, int maxSize) {
if (size > maxSize) {
throw new IllegalArgumentException(
format("The message is too long: %s;" +
" expected not longer than %s", size, maxSize)
);
}
}
private void checkAndUpdateText(boolean isLast) {
if (!canSendText) {
throw new IllegalStateException("Unexpected text message");
}
canSendBinary = isLast;
}
private void checkAndUpdateBinary(boolean isLast) {
if (!canSendBinary) {
throw new IllegalStateException("Unexpected binary message");
}
canSendText = isLast;
}
}

View File

@ -0,0 +1,75 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.Arrays;
final class WSUtils {
private WSUtils() { }
static final System.Logger logger = System.getLogger("java.net.http.WebSocket");
static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocate(0);
//
// Helps to trim long names (packages, nested/inner types) in logs/toString
//
static String toStringSimple(Object o) {
return o.getClass().getSimpleName() + "@" +
Integer.toHexString(System.identityHashCode(o));
}
//
// 1. It adds a number of remaining bytes;
// 2. Standard Buffer-type toString for CharBuffer (since it adheres to the
// contract of java.lang.CharSequence.toString() which is both not too
// useful and not too private)
//
static String toString(Buffer b) {
return toStringSimple(b)
+ "[pos=" + b.position()
+ " lim=" + b.limit()
+ " cap=" + b.capacity()
+ " rem=" + b.remaining() + "]";
}
static String toString(CharSequence s) {
return s == null
? "null"
: toStringSimple(s) + "[len=" + s.length() + "]";
}
static String dump(Object... objects) {
return Arrays.toString(objects);
}
static String webSocketSpecViolation(String section, String detail) {
return "RFC 6455 " + section + " " + detail;
}
}

View File

@ -0,0 +1,134 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.util.function.Consumer;
import static java.util.Objects.requireNonNull;
/*
* Writes ByteBuffer[] to the channel in a non-blocking, asynchronous fashion.
*
* A client attempts to write data by calling
*
* boolean tryWriteFully(ByteBuffer[] buffers)
*
* If the attempt was successful and all the data has been written, then the
* method returns `true`.
*
* If the data has been written partially, then the method returns `false`, and
* the writer (this object) attempts to complete the write asynchronously by
* calling, possibly more than once
*
* boolean tryCompleteWrite()
*
* in its own threads.
*
* When the write has been completed asynchronously, the callback is signalled
* with `null`.
*
* If an error occurs in any of these stages it will NOT be thrown from the
* method. Instead `false` will be returned and the exception will be signalled
* to the callback. This is done in order to handle all exceptions in a single
* place.
*/
final class WSWriter {
private final RawChannel channel;
private final RawChannel.NonBlockingEvent writeReadinessHandler;
private final Consumer<Throwable> completionCallback;
private ByteBuffer[] buffers;
private int offset;
WSWriter(RawChannel channel, Consumer<Throwable> completionCallback) {
this.channel = channel;
this.completionCallback = completionCallback;
this.writeReadinessHandler = createHandler();
}
boolean tryWriteFully(ByteBuffer[] buffers) {
synchronized (this) {
this.buffers = requireNonNull(buffers);
this.offset = 0;
}
return tryCompleteWrite();
}
private final boolean tryCompleteWrite() {
try {
return writeNow();
} catch (IOException e) {
completionCallback.accept(e);
return false;
}
}
private boolean writeNow() throws IOException {
synchronized (this) {
for (; offset != -1; offset = nextUnwrittenIndex(buffers, offset)) {
long bytesWritten = channel.write(buffers, offset, buffers.length - offset);
if (bytesWritten == 0) {
channel.registerEvent(writeReadinessHandler);
return false;
}
}
return true;
}
}
private static int nextUnwrittenIndex(ByteBuffer[] buffers, int offset) {
for (int i = offset; i < buffers.length; i++) {
if (buffers[i].hasRemaining()) {
return i;
}
}
return -1;
}
private RawChannel.NonBlockingEvent createHandler() {
return new RawChannel.NonBlockingEvent() {
@Override
public int interestOps() {
return SelectionKey.OP_WRITE;
}
@Override
public void handle() {
if (tryCompleteWrite()) {
completionCallback.accept(null);
}
}
@Override
public String toString() {
return "Write readiness event [" + channel + "]";
}
};
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,66 @@
/*
* Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.net.http;
/**
* An exception used to signal the opening handshake failed.
*
* @since 9
*/
public final class WebSocketHandshakeException extends Exception {
private static final long serialVersionUID = 1L;
private final transient HttpResponse response;
WebSocketHandshakeException(HttpResponse response) {
this(null, response);
}
WebSocketHandshakeException(String message, HttpResponse response) {
super(statusCodeOrFullMessage(message, response));
this.response = response;
}
/**
* // FIXME: terrible toString (+ not always status should be displayed I guess)
*/
private static String statusCodeOrFullMessage(String m, HttpResponse response) {
return (m == null || m.isEmpty())
? String.valueOf(response.statusCode())
: response.statusCode() + ": " + m;
}
/**
* Returns a HTTP response from the server.
*
* <p> The value may be unavailable ({@code null}) if this exception has
* been serialized and then read back in.
*
* @return server response
*/
public HttpResponse getResponse() {
return response;
}
}

View File

@ -33,6 +33,7 @@
* <li>{@link java.net.http.HttpClient}</li>
* <li>{@link java.net.http.HttpRequest}</li>
* <li>{@link java.net.http.HttpResponse}</li>
* <li>{@link java.net.http.WebSocket}</li>
* </ul>
*
* @since 9

View File

@ -0,0 +1,332 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
import org.testng.annotations.Test;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.WebSocket;
import java.net.http.WebSocket.CloseCode;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Stream;
/*
* @test
* @bug 8087113
* @build TestKit
* @run testng/othervm BasicWebSocketAPITest
*/
public class BasicWebSocketAPITest {
@Test
public void webSocket() throws Exception {
checkAndClose(
(ws) ->
TestKit.assertThrows(IllegalArgumentException.class,
"(?i).*\\bnegative\\b.*",
() -> ws.request(-1))
);
checkAndClose((ws) ->
TestKit.assertNotThrows(() -> ws.request(0))
);
checkAndClose((ws) ->
TestKit.assertNotThrows(() -> ws.request(1))
);
checkAndClose((ws) ->
TestKit.assertNotThrows(() -> ws.request(Long.MAX_VALUE))
);
checkAndClose((ws) ->
TestKit.assertNotThrows(ws::isClosed)
);
checkAndClose((ws) ->
TestKit.assertNotThrows(ws::getSubprotocol)
);
checkAndClose(
(ws) -> {
try {
ws.abort();
} catch (IOException ignored) { }
// No matter what happens during the first abort invocation,
// other invocations must return normally
TestKit.assertNotThrows(ws::abort);
TestKit.assertNotThrows(ws::abort);
}
);
checkAndClose(
(ws) ->
TestKit.assertThrows(NullPointerException.class,
"message",
() -> ws.sendBinary((byte[]) null, true))
);
checkAndClose(
(ws) ->
TestKit.assertThrows(NullPointerException.class,
"message",
() -> ws.sendBinary((ByteBuffer) null, true))
);
checkAndClose(
(ws) ->
TestKit.assertThrows(NullPointerException.class,
"message",
() -> ws.sendPing(null))
);
checkAndClose(
(ws) ->
TestKit.assertThrows(NullPointerException.class,
"message",
() -> ws.sendPong(null))
);
checkAndClose(
(ws) ->
TestKit.assertThrows(NullPointerException.class,
"message",
() -> ws.sendText((CharSequence) null, true))
);
checkAndClose(
(ws) ->
TestKit.assertThrows(NullPointerException.class,
"message",
() -> ws.sendText((CharSequence) null))
);
checkAndClose(
(ws) ->
TestKit.assertThrows(NullPointerException.class,
"message",
() -> ws.sendText((Stream<? extends CharSequence>) null))
);
checkAndClose(
(ws) ->
TestKit.assertThrows(NullPointerException.class,
"code",
() -> ws.sendClose(null, ""))
);
checkAndClose(
(ws) ->
TestKit.assertNotThrows(
() -> ws.sendClose(CloseCode.NORMAL_CLOSURE, ""))
);
checkAndClose(
(ws) ->
TestKit.assertThrows(NullPointerException.class,
"reason",
() -> ws.sendClose(CloseCode.NORMAL_CLOSURE, null))
);
checkAndClose(
(ws) ->
TestKit.assertThrows(NullPointerException.class,
"code|reason",
() -> ws.sendClose(null, null))
);
}
@Test
public void builder() {
URI ws = URI.create("ws://localhost:9001");
// FIXME: check all 24 cases:
// {null, ws, wss, incorrect} x {null, HttpClient.getDefault(), custom} x {null, listener}
//
// if (any null) or (any incorrect)
// NPE or IAE is thrown
// else
// builder is created
TestKit.assertThrows(NullPointerException.class,
"uri",
() -> WebSocket.newBuilder(null, defaultListener())
);
TestKit.assertThrows(NullPointerException.class,
"listener",
() -> WebSocket.newBuilder(ws, null)
);
URI uri = URI.create("ftp://localhost:9001");
TestKit.assertThrows(IllegalArgumentException.class,
"(?i).*\\buri\\b\\s+\\bscheme\\b.*",
() -> WebSocket.newBuilder(uri, defaultListener())
);
TestKit.assertNotThrows(
() -> WebSocket.newBuilder(ws, defaultListener())
);
URI uri1 = URI.create("wss://localhost:9001");
TestKit.assertNotThrows(
() -> WebSocket.newBuilder(uri1, defaultListener())
);
URI uri2 = URI.create("wss://localhost:9001#a");
TestKit.assertThrows(IllegalArgumentException.class,
"(?i).*\\bfragment\\b.*",
() -> WebSocket.newBuilder(uri2, HttpClient.getDefault(), defaultListener())
);
TestKit.assertThrows(NullPointerException.class,
"uri",
() -> WebSocket.newBuilder(null, HttpClient.getDefault(), defaultListener())
);
TestKit.assertThrows(NullPointerException.class,
"client",
() -> WebSocket.newBuilder(ws, null, defaultListener())
);
TestKit.assertThrows(NullPointerException.class,
"listener",
() -> WebSocket.newBuilder(ws, HttpClient.getDefault(), null)
);
// FIXME: check timeout works
// (i.e. it directly influences the time WebSocket waits for connection + opening handshake)
TestKit.assertNotThrows(
() -> WebSocket.newBuilder(ws, defaultListener()).connectTimeout(1, TimeUnit.SECONDS)
);
WebSocket.Builder builder = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(IllegalArgumentException.class,
"(?i).*\\bnegative\\b.*",
() -> builder.connectTimeout(-1, TimeUnit.SECONDS)
);
WebSocket.Builder builder1 = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(NullPointerException.class,
"unit",
() -> builder1.connectTimeout(1, null)
);
// FIXME: check these headers are actually received by the server
TestKit.assertNotThrows(
() -> WebSocket.newBuilder(ws, defaultListener()).header("a", "b")
);
TestKit.assertNotThrows(
() -> WebSocket.newBuilder(ws, defaultListener()).header("a", "b").header("a", "b")
);
// FIXME: check all 18 cases:
// {null, websocket(7), custom} x {null, custom}
WebSocket.Builder builder2 = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(NullPointerException.class,
"name",
() -> builder2.header(null, "b")
);
WebSocket.Builder builder3 = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(NullPointerException.class,
"value",
() -> builder3.header("a", null)
);
WebSocket.Builder builder4 = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(IllegalArgumentException.class,
() -> builder4.header("Sec-WebSocket-Accept", "")
);
WebSocket.Builder builder5 = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(IllegalArgumentException.class,
() -> builder5.header("Sec-WebSocket-Extensions", "")
);
WebSocket.Builder builder6 = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(IllegalArgumentException.class,
() -> builder6.header("Sec-WebSocket-Key", "")
);
WebSocket.Builder builder7 = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(IllegalArgumentException.class,
() -> builder7.header("Sec-WebSocket-Protocol", "")
);
WebSocket.Builder builder8 = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(IllegalArgumentException.class,
() -> builder8.header("Sec-WebSocket-Version", "")
);
WebSocket.Builder builder9 = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(IllegalArgumentException.class,
() -> builder9.header("Connection", "")
);
WebSocket.Builder builder10 = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(IllegalArgumentException.class,
() -> builder10.header("Upgrade", "")
);
// FIXME: check 3 cases (1 arg):
// {null, incorrect, custom}
// FIXME: check 12 cases (2 args):
// {null, incorrect, custom} x {(String) null, (String[]) null, incorrect, custom}
// FIXME: check 27 cases (3 args) (the interesting part in null inside var-arg):
// {null, incorrect, custom}^3
// FIXME: check the server receives them in the order listed
TestKit.assertThrows(NullPointerException.class,
"mostPreferred",
() -> WebSocket.newBuilder(ws, defaultListener()).subprotocols(null)
);
TestKit.assertThrows(NullPointerException.class,
"lesserPreferred",
() -> WebSocket.newBuilder(ws, defaultListener()).subprotocols("a", null)
);
TestKit.assertThrows(NullPointerException.class,
"lesserPreferred\\[0\\]",
() -> WebSocket.newBuilder(ws, defaultListener()).subprotocols("a", null, "b")
);
TestKit.assertThrows(NullPointerException.class,
"lesserPreferred\\[1\\]",
() -> WebSocket.newBuilder(ws, defaultListener()).subprotocols("a", "b", null)
);
TestKit.assertNotThrows(
() -> WebSocket.newBuilder(ws, defaultListener()).subprotocols("a")
);
TestKit.assertNotThrows(
() -> WebSocket.newBuilder(ws, defaultListener()).subprotocols("a", "b", "c")
);
WebSocket.Builder builder11 = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(IllegalArgumentException.class,
() -> builder11.subprotocols("")
);
WebSocket.Builder builder12 = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(IllegalArgumentException.class,
() -> builder12.subprotocols("a", "a")
);
WebSocket.Builder builder13 = WebSocket.newBuilder(ws, defaultListener());
TestKit.assertThrows(IllegalArgumentException.class,
() -> builder13.subprotocols("a" + ((char) 0x7f))
);
}
private static WebSocket.Listener defaultListener() {
return new WebSocket.Listener() { };
}
//
// Automatically closes everything after the check has been performed
//
private static void checkAndClose(Consumer<? super WebSocket> c) {
HandshakePhase HandshakePhase
= new HandshakePhase(new InetSocketAddress("127.0.0.1", 0));
URI serverURI = HandshakePhase.getURI();
CompletableFuture<SocketChannel> cfc = HandshakePhase.afterHandshake();
WebSocket.Builder b = WebSocket.newBuilder(serverURI, defaultListener());
CompletableFuture<WebSocket> cfw = b.buildAsync();
try {
WebSocket ws;
try {
ws = cfw.get();
} catch (Exception e) {
throw new RuntimeException(e);
}
c.accept(ws);
} finally {
try {
SocketChannel now = cfc.getNow(null);
if (now != null) {
now.close();
}
} catch (Throwable ignored) { }
}
}
}

View File

@ -0,0 +1,265 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
//
// Performs a simple opening handshake and yields the channel.
//
// Client Request:
//
// GET /chat HTTP/1.1
// Host: server.example.com
// Upgrade: websocket
// Connection: Upgrade
// Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
// Origin: http://example.com
// Sec-WebSocket-Protocol: chat, superchat
// Sec-WebSocket-Version: 13
//
//
// Server Response:
//
// HTTP/1.1 101 Switching Protocols
// Upgrade: websocket
// Connection: Upgrade
// Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
// Sec-WebSocket-Protocol: chat
//
final class HandshakePhase {
private final ServerSocketChannel ssc;
HandshakePhase(InetSocketAddress address) {
requireNonNull(address);
try {
ssc = ServerSocketChannel.open();
ssc.bind(address);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
//
// Returned CF completes normally after the handshake has been performed
//
CompletableFuture<SocketChannel> afterHandshake(
Function<List<String>, List<String>> mapping) {
return CompletableFuture.supplyAsync(
() -> {
SocketChannel socketChannel = accept();
try {
StringBuilder request = new StringBuilder();
if (!readRequest(socketChannel, request)) {
throw new IllegalStateException();
}
List<String> strings = Arrays.asList(
request.toString().split("\r\n")
);
List<String> response = mapping.apply(strings);
writeResponse(socketChannel, response);
return socketChannel;
} catch (Throwable t) {
try {
socketChannel.close();
} catch (IOException ignored) { }
throw t;
}
});
}
CompletableFuture<SocketChannel> afterHandshake() {
return afterHandshake((request) -> {
List<String> response = new LinkedList<>();
Iterator<String> iterator = request.iterator();
if (!iterator.hasNext()) {
throw new IllegalStateException("The request is empty");
}
if (!"GET / HTTP/1.1".equals(iterator.next())) {
throw new IllegalStateException
("Unexpected status line: " + request.get(0));
}
response.add("HTTP/1.1 101 Switching Protocols");
Map<String, String> requestHeaders = new HashMap<>();
while (iterator.hasNext()) {
String header = iterator.next();
String[] split = header.split(": ");
if (split.length != 2) {
throw new IllegalStateException
("Unexpected header: " + header
+ ", split=" + Arrays.toString(split));
}
if (requestHeaders.put(split[0], split[1]) != null) {
throw new IllegalStateException
("Duplicating headers: " + Arrays.toString(split));
}
}
if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
throw new IllegalStateException("Subprotocols are not expected");
}
if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
throw new IllegalStateException("Extensions are not expected");
}
expectHeader(requestHeaders, "Connection", "Upgrade");
response.add("Connection: Upgrade");
expectHeader(requestHeaders, "Upgrade", "websocket");
response.add("Upgrade: websocket");
expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
String key = requestHeaders.get("Sec-WebSocket-Key");
if (key == null) {
throw new IllegalStateException("Sec-WebSocket-Key is missing");
}
MessageDigest sha1 = null;
try {
sha1 = MessageDigest.getInstance("SHA-1");
} catch (NoSuchAlgorithmException e) {
throw new InternalError(e);
}
String x = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
sha1.update(x.getBytes(StandardCharsets.ISO_8859_1));
String v = Base64.getEncoder().encodeToString(sha1.digest());
response.add("Sec-WebSocket-Accept: " + v);
return response;
});
}
private String expectHeader(Map<String, String> headers,
String name,
String value) {
String v = headers.get(name);
if (!value.equals(v)) {
throw new IllegalStateException(
format("Expected '%s: %s', actual: '%s: %s'",
name, value, name, v)
);
}
return v;
}
URI getURI() {
InetSocketAddress a;
try {
a = (InetSocketAddress) ssc.getLocalAddress();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
return URI.create("ws://" + a.getHostName() + ":" + a.getPort());
}
private int read(SocketChannel socketChannel, ByteBuffer buffer) {
try {
int num = socketChannel.read(buffer);
if (num == -1) {
throw new IllegalStateException("Unexpected EOF");
}
assert socketChannel.isBlocking() && num > 0;
return num;
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
private SocketChannel accept() {
SocketChannel socketChannel = null;
try {
socketChannel = ssc.accept();
socketChannel.configureBlocking(true);
} catch (IOException e) {
if (socketChannel != null) {
try {
socketChannel.close();
} catch (IOException ignored) { }
}
throw new UncheckedIOException(e);
}
return socketChannel;
}
private boolean readRequest(SocketChannel socketChannel,
StringBuilder request) {
ByteBuffer buffer = ByteBuffer.allocateDirect(512);
read(socketChannel, buffer);
CharBuffer decoded;
buffer.flip();
try {
decoded =
StandardCharsets.ISO_8859_1.newDecoder().decode(buffer);
} catch (CharacterCodingException e) {
throw new UncheckedIOException(e);
}
request.append(decoded);
return Pattern.compile("\r\n\r\n").matcher(request).find();
}
private void writeResponse(SocketChannel socketChannel,
List<String> response) {
String s = response.stream().collect(Collectors.joining("\r\n"))
+ "\r\n\r\n";
ByteBuffer encoded;
try {
encoded =
StandardCharsets.ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
} catch (CharacterCodingException e) {
throw new UncheckedIOException(e);
}
write(socketChannel, encoded);
}
private void write(SocketChannel socketChannel, ByteBuffer buffer) {
try {
while (buffer.hasRemaining()) {
socketChannel.write(buffer);
}
} catch (IOException e) {
try {
socketChannel.close();
} catch (IOException ignored) { }
throw new UncheckedIOException(e);
}
}
}