8335181: Incorrect handling of HTTP/2 GOAWAY frames in HttpClient

Reviewed-by: dfuchs
This commit is contained in:
Jaikiran Pai 2024-08-14 05:42:14 +00:00
parent f132b347e1
commit 720b44648b
12 changed files with 625 additions and 56 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -58,6 +58,10 @@ abstract class ExchangeImpl<T> {
final Exchange<T> exchange;
// this will be set to true only when the peer explicitly states (through a GOAWAY frame or
// a relevant error code in reset frame) that the corresponding stream (id) wasn't processed
private volatile boolean unprocessedByPeer;
ExchangeImpl(Exchange<T> e) {
// e == null means a http/2 pushed stream
this.exchange = e;
@ -265,4 +269,13 @@ abstract class ExchangeImpl<T> {
// Called when server returns non 100 response to
// an Expect-Continue
void expectContinueFailed(int rcode) { }
final boolean isUnprocessedByPeer() {
return this.unprocessedByPeer;
}
// Marks the exchange as unprocessed by the peer
final void markUnprocessedByPeer() {
this.unprocessedByPeer = true;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -47,6 +47,8 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;
@ -358,6 +360,7 @@ class Http2Connection {
private final String key; // for HttpClientImpl.connections map
private final FramesDecoder framesDecoder;
private final FramesEncoder framesEncoder = new FramesEncoder();
private final AtomicLong lastProcessedStreamInGoAway = new AtomicLong(-1);
/**
* Send Window controller for both connection and stream windows.
@ -725,7 +728,9 @@ class Http2Connection {
void close() {
if (markHalfClosedLocal()) {
if (connection.channel().isOpen()) {
// we send a GOAWAY frame only if the remote side hasn't already indicated
// the intention to close the connection by previously sending a GOAWAY of its own
if (connection.channel().isOpen() && !isMarked(closedState, HALF_CLOSED_REMOTE)) {
Log.logTrace("Closing HTTP/2 connection: to {0}", connection.address());
GoAwayFrame f = new GoAwayFrame(0,
ErrorFrame.NO_ERROR,
@ -1205,13 +1210,46 @@ class Http2Connection {
sendUnorderedFrame(frame);
}
private void handleGoAway(GoAwayFrame frame)
throws IOException
{
if (markHalfClosedLRemote()) {
shutdown(new IOException(
connection.channel().getLocalAddress()
+ ": GOAWAY received"));
private void handleGoAway(final GoAwayFrame frame) {
final long lastProcessedStream = frame.getLastStream();
assert lastProcessedStream >= 0 : "unexpected last stream id: "
+ lastProcessedStream + " in GOAWAY frame";
markHalfClosedRemote();
setFinalStream(); // don't allow any new streams on this connection
if (debug.on()) {
debug.log("processing incoming GOAWAY with last processed stream id:%s in frame %s",
lastProcessedStream, frame);
}
// see if this connection has previously received a GOAWAY from the peer and if yes
// then check if this new last processed stream id is lesser than the previous
// known last processed stream id. Only update the last processed stream id if the new
// one is lesser than the previous one.
long prevLastProcessed = lastProcessedStreamInGoAway.get();
while (prevLastProcessed == -1 || lastProcessedStream < prevLastProcessed) {
if (lastProcessedStreamInGoAway.compareAndSet(prevLastProcessed,
lastProcessedStream)) {
break;
}
prevLastProcessed = lastProcessedStreamInGoAway.get();
}
handlePeerUnprocessedStreams(lastProcessedStreamInGoAway.get());
}
private void handlePeerUnprocessedStreams(final long lastProcessedStream) {
final AtomicInteger numClosed = new AtomicInteger(); // atomic merely to allow usage within lambda
streams.forEach((id, exchange) -> {
if (id > lastProcessedStream) {
// any streams with an stream id higher than the last processed stream
// can be retried (on a new connection). we close the exchange as unprocessed
// to facilitate the retrying.
client2.client().theExecutor().ensureExecutedAsync(exchange::closeAsUnprocessed);
numClosed.incrementAndGet();
}
});
if (debug.on()) {
debug.log(numClosed.get() + " stream(s), with id greater than " + lastProcessedStream
+ ", will be closed as unprocessed");
}
}
@ -1745,7 +1783,7 @@ class Http2Connection {
return markClosedState(HALF_CLOSED_LOCAL);
}
private boolean markHalfClosedLRemote() {
private boolean markHalfClosedRemote() {
return markClosedState(HALF_CLOSED_REMOTE);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -90,7 +90,7 @@ class MultiExchange<T> implements Cancelable {
Exchange<T> exchange; // the current exchange
Exchange<T> previous;
volatile Throwable retryCause;
volatile boolean expiredOnce;
volatile boolean retriedOnce;
volatile HttpResponse<T> response;
// Maximum number of times a request will be retried/redirected
@ -469,7 +469,7 @@ class MultiExchange<T> implements Cancelable {
return exch.ignoreBody().handle((r,t) -> {
previousreq = currentreq;
currentreq = newrequest;
expiredOnce = false;
retriedOnce = false;
setExchange(new Exchange<>(currentreq, this, acc));
return responseAsyncImpl();
}).thenCompose(Function.identity());
@ -482,7 +482,7 @@ class MultiExchange<T> implements Cancelable {
return completedFuture(response);
}
// all exceptions thrown are handled here
CompletableFuture<Response> errorCF = getExceptionalCF(ex);
CompletableFuture<Response> errorCF = getExceptionalCF(ex, exch.exchImpl);
if (errorCF == null) {
return responseAsyncImpl();
} else {
@ -554,36 +554,39 @@ class MultiExchange<T> implements Cancelable {
* Takes a Throwable and returns a suitable CompletableFuture that is
* completed exceptionally, or null.
*/
private CompletableFuture<Response> getExceptionalCF(Throwable t) {
private CompletableFuture<Response> getExceptionalCF(Throwable t, ExchangeImpl<?> exchImpl) {
if ((t instanceof CompletionException) || (t instanceof ExecutionException)) {
if (t.getCause() != null) {
t = t.getCause();
}
}
final boolean retryAsUnprocessed = exchImpl != null && exchImpl.isUnprocessedByPeer();
if (cancelled && !requestCancelled() && t instanceof IOException) {
if (!(t instanceof HttpTimeoutException)) {
t = toTimeoutException((IOException)t);
}
} else if (retryOnFailure(t)) {
} else if (retryAsUnprocessed || retryOnFailure(t)) {
Throwable cause = retryCause(t);
if (!(t instanceof ConnectException)) {
// we may need to start a new connection, and if so
// we want to start with a fresh connect timeout again.
if (connectTimeout != null) connectTimeout.reset();
if (!canRetryRequest(currentreq)) {
return failedFuture(cause); // fails with original cause
if (!retryAsUnprocessed && !canRetryRequest(currentreq)) {
// a (peer) processed request which cannot be retried, fail with
// the original cause
return failedFuture(cause);
}
} // ConnectException: retry, but don't reset the connectTimeout.
// allow the retry mechanism to do its work
retryCause = cause;
if (!expiredOnce) {
if (!retriedOnce) {
if (debug.on()) {
debug.log(t.getClass().getSimpleName()
+ " (async): retrying due to: ", t);
+ " (async): retrying " + currentreq + " due to: ", t);
}
expiredOnce = true;
retriedOnce = true;
// The connection was abruptly closed.
// We return null to retry the same request a second time.
// The request filters have already been applied to the
@ -594,7 +597,7 @@ class MultiExchange<T> implements Cancelable {
} else {
if (debug.on()) {
debug.log(t.getClass().getSimpleName()
+ " (async): already retried once.", t);
+ " (async): already retried once " + currentreq, t);
}
t = cause;
}

View File

@ -641,20 +641,39 @@ class Stream<T> extends ExchangeImpl<T> {
stateLock.unlock();
}
try {
int error = frame.getErrorCode();
IOException e = new IOException("Received RST_STREAM: "
+ ErrorFrame.stringForCode(error));
if (errorRef.compareAndSet(null, e)) {
final int error = frame.getErrorCode();
// A REFUSED_STREAM error code implies that the stream wasn't processed by the
// peer and the client is free to retry the request afresh.
if (error == ErrorFrame.REFUSED_STREAM) {
// Here we arrange for the request to be retried. Note that we don't call
// closeAsUnprocessed() method here because the "closed" state is already set
// to true a few lines above and calling close() from within
// closeAsUnprocessed() will end up being a no-op. We instead do the additional
// bookkeeping here.
markUnprocessedByPeer();
errorRef.compareAndSet(null, new IOException("request not processed by peer"));
if (debug.on()) {
debug.log("request unprocessed by peer (REFUSED_STREAM) " + this.request);
}
} else {
final String reason = ErrorFrame.stringForCode(error);
final IOException failureCause = new IOException("Received RST_STREAM: " + reason);
if (debug.on()) {
debug.log(streamid + " received RST_STREAM with code: " + reason);
}
if (errorRef.compareAndSet(null, failureCause)) {
if (subscriber != null) {
subscriber.onError(e);
subscriber.onError(failureCause);
}
}
completeResponseExceptionally(e);
}
final Throwable failureCause = errorRef.get();
completeResponseExceptionally(failureCause);
if (!requestBodyCF.isDone()) {
requestBodyCF.completeExceptionally(errorRef.get()); // we may be sending the body..
requestBodyCF.completeExceptionally(failureCause); // we may be sending the body..
}
if (responseBodyCF != null) {
responseBodyCF.completeExceptionally(errorRef.get());
responseBodyCF.completeExceptionally(failureCause);
}
} finally {
connection.decrementStreamsCount(streamid);
@ -1663,7 +1682,35 @@ class Stream<T> extends ExchangeImpl<T> {
}
final String dbgString() {
return connection.dbgString() + "/Stream("+streamid+")";
final int id = streamid;
final String sid = id == 0 ? "?" : String.valueOf(id);
return connection.dbgString() + "/Stream(" + sid + ")";
}
/**
* An unprocessed exchange is one that hasn't been processed by a peer. The local end of the
* connection would be notified about such exchanges when it receives a GOAWAY frame with
* a stream id that tells which exchanges have been unprocessed.
* This method is called on such unprocessed exchanges and the implementation of this method
* will arrange for the request, corresponding to this exchange, to be retried afresh on a
* new connection.
*/
void closeAsUnprocessed() {
try {
// We arrange for the request to be retried on a new connection as allowed by the RFC-9113
markUnprocessedByPeer();
this.errorRef.compareAndSet(null, new IOException("request not processed by peer"));
if (debug.on()) {
debug.log("closing " + this.request + " as unprocessed by peer");
}
// close the exchange and complete the response CF exceptionally
close();
completeResponseExceptionally(this.errorRef.get());
} finally {
// decrementStreamsCount isn't really needed but we do it to make sure
// the log messages, where these counts/states get reported, show the accurate state.
connection.decrementStreamsCount(streamid);
}
}
private class HeadersConsumer extends ValidatingHeadersConsumer {

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -100,13 +100,16 @@ final class WindowController {
controllerLock.lock();
try {
Integer old = streams.remove(streamid);
// Odd stream numbers (client streams) should have been registered.
// A client initiated stream might be closed (as unprocessed, due to a
// GOAWAY received on the connection) even before the stream is
// registered with this WindowController instance (when sending out request headers).
// Thus, for client initiated streams, we don't enforce the presence of the
// stream in the registered "streams" map.
// Even stream numbers (server streams - aka Push Streams) should
// not be registered
final boolean isClientStream = (streamid & 0x1) == 1;
if (old == null && isClientStream) {
throw new InternalError("Expected entry for streamid: " + streamid);
} else if (old != null && !isClientStream) {
if (old != null && !isClientStream) {
throw new InternalError("Unexpected entry for streamid: " + streamid);
}
} finally {

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -57,7 +57,9 @@ public class GoAwayFrame extends ErrorFrame {
@Override
public String toString() {
return super.toString() + " Debugdata: " + new String(debugData, UTF_8);
return super.toString()
+ " lastStreamId=" + lastStream
+ ", Debugdata: " + new String(debugData, UTF_8);
}
public int getLastStream() {

View File

@ -0,0 +1,336 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.net.http.HttpResponse.BodyHandlers;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.ssl.SSLContext;
import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestExchange;
import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestHandler;
import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer;
import jdk.test.lib.net.SimpleSSLContext;
import jdk.test.lib.net.URIBuilder;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import static java.net.http.HttpClient.Version.HTTP_2;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.fail;
/*
* @test
* @bug 8335181
* @summary verify that the HttpClient correctly handles incoming GOAWAY frames and
* retries any unprocessed requests on a new connection
* @library /test/lib /test/jdk/java/net/httpclient/lib
* @build jdk.httpclient.test.lib.common.HttpServerAdapters
* jdk.test.lib.net.SimpleSSLContext
* @run junit H2GoAwayTest
*/
public class H2GoAwayTest {
private static final String REQ_PATH = "/test";
private static HttpTestServer server;
private static String REQ_URI_BASE;
private static SSLContext sslCtx;
@BeforeAll
static void beforeAll() throws Exception {
sslCtx = new SimpleSSLContext().get();
assertNotNull(sslCtx, "SSLContext couldn't be created");
server = HttpTestServer.create(HTTP_2, sslCtx);
server.addHandler(new Handler(), REQ_PATH);
server.start();
System.out.println("Server started at " + server.getAddress());
REQ_URI_BASE = URIBuilder.newBuilder().scheme("https")
.loopback()
.port(server.getAddress().getPort())
.path(REQ_PATH)
.build().toString();
}
@AfterAll
static void afterAll() {
if (server != null) {
System.out.println("Stopping server at " + server.getAddress());
server.stop();
}
}
/**
* Verifies that when several requests are sent using send() and the server
* connection is configured to send a GOAWAY after processing only a few requests, then
* the remaining requests are retried on a different connection
*/
@Test
public void testSequential() throws Exception {
final LimitedPerConnRequestApprover reqApprover = new LimitedPerConnRequestApprover();
server.setRequestApprover(reqApprover::allowNewRequest);
try (final HttpClient client = HttpClient.newBuilder().version(HTTP_2)
.sslContext(sslCtx).build()) {
final String[] reqMethods = {"HEAD", "GET", "POST"};
for (final String reqMethod : reqMethods) {
final int numReqs = LimitedPerConnRequestApprover.MAX_REQS_PER_CONN + 3;
final Set<String> connectionKeys = new LinkedHashSet<>();
for (int i = 1; i <= numReqs; i++) {
final URI reqURI = new URI(REQ_URI_BASE + "?seq&" + reqMethod + "=" + i);
final HttpRequest req = HttpRequest.newBuilder()
.uri(reqURI)
.method(reqMethod, HttpRequest.BodyPublishers.noBody())
.build();
System.out.println("initiating request " + req);
final HttpResponse<String> resp = client.send(req, BodyHandlers.ofString());
final String respBody = resp.body();
System.out.println("received response: " + respBody);
assertEquals(200, resp.statusCode(),
"unexpected status code for request " + resp.request());
// response body is the logical key of the connection on which the
// request was handled
connectionKeys.add(respBody);
}
System.out.println("connections involved in handling the requests: "
+ connectionKeys);
// all requests have finished, we now just do a basic check that
// more than one connection was involved in processing these requests
assertEquals(2, connectionKeys.size(),
"unexpected number of connections " + connectionKeys);
}
} finally {
server.setRequestApprover(null); // reset
}
}
/**
* Verifies that when a server responds with a GOAWAY and then never processes the new retried
* requests on a new connection too, then the application code receives the request failure.
* This tests the send() API of the HttpClient.
*/
@Test
public void testUnprocessedRaisesException() throws Exception {
try (final HttpClient client = HttpClient.newBuilder().version(HTTP_2)
.sslContext(sslCtx).build()) {
final Random random = new Random();
final String[] reqMethods = {"HEAD", "GET", "POST"};
for (final String reqMethod : reqMethods) {
final int maxAllowedReqs = 2;
final int numReqs = maxAllowedReqs + 3; // 3 more requests than max allowed
// configure the approver
final LimitedRequestApprover reqApprover = new LimitedRequestApprover(maxAllowedReqs);
server.setRequestApprover(reqApprover::allowNewRequest);
try {
int numSuccess = 0;
int numFailed = 0;
for (int i = 1; i <= numReqs; i++) {
final String reqQueryPart = "?sync&" + reqMethod + "=" + i;
final URI reqURI = new URI(REQ_URI_BASE + reqQueryPart);
final HttpRequest req = HttpRequest.newBuilder()
.uri(reqURI)
.method(reqMethod, HttpRequest.BodyPublishers.noBody())
.build();
System.out.println("initiating request " + req);
if (i <= maxAllowedReqs) {
// expected to successfully complete
numSuccess++;
final HttpResponse<String> resp = client.send(req, BodyHandlers.ofString());
final String respBody = resp.body();
System.out.println("received response: " + respBody);
assertEquals(200, resp.statusCode(),
"unexpected status code for request " + resp.request());
} else {
// expected to fail as unprocessed
try {
final HttpResponse<String> resp = client.send(req, BodyHandlers.ofString());
fail("Request was expected to fail as unprocessed,"
+ " but got response: " + resp.body() + ", status code: "
+ resp.statusCode());
} catch (IOException ioe) {
// verify it failed for the right reason
if (ioe.getMessage() == null
|| !ioe.getMessage().contains("request not processed by peer")) {
// propagate the original failure
throw ioe;
}
numFailed++; // failed due to right reason
System.out.println("received expected failure: " + ioe
+ ", for request " + reqURI);
}
}
}
// verify the correct number of requests succeeded/failed
assertEquals(maxAllowedReqs, numSuccess, "unexpected number of requests succeeded");
assertEquals((numReqs - maxAllowedReqs), numFailed, "unexpected number of requests failed");
} finally {
server.setRequestApprover(null); // reset
}
}
}
}
/**
* Verifies that when a server responds with a GOAWAY and then never processes the new retried
* requests on a new connection too, then the application code receives the request failure.
* This tests the sendAsync() API of the HttpClient.
*/
@Test
public void testUnprocessedRaisesExceptionAsync() throws Throwable {
try (final HttpClient client = HttpClient.newBuilder().version(HTTP_2)
.sslContext(sslCtx).build()) {
final Random random = new Random();
final String[] reqMethods = {"HEAD", "GET", "POST"};
for (final String reqMethod : reqMethods) {
final int maxAllowedReqs = 2;
final int numReqs = maxAllowedReqs + 3; // 3 more requests than max allowed
// configure the approver
final LimitedRequestApprover reqApprover = new LimitedRequestApprover(maxAllowedReqs);
server.setRequestApprover(reqApprover::allowNewRequest);
try {
final List<Future<HttpResponse<String>>> futures = new ArrayList<>();
for (int i = 1; i <= numReqs; i++) {
final URI reqURI = new URI(REQ_URI_BASE + "?async&" + reqMethod + "=" + i);
final HttpRequest req = HttpRequest.newBuilder()
.uri(reqURI)
.method(reqMethod, HttpRequest.BodyPublishers.noBody())
.build();
System.out.println("initiating request " + req);
final Future<HttpResponse<String>> f = client.sendAsync(req, BodyHandlers.ofString());
futures.add(f);
}
// wait for responses
int numFailed = 0;
int numSuccess = 0;
for (int i = 1; i <= numReqs; i++) {
final String reqQueryPart = "?async&" + reqMethod + "=" + i;
try {
System.out.println("waiting response of request "
+ REQ_URI_BASE + reqQueryPart);
final HttpResponse<String> resp = futures.get(i - 1).get();
numSuccess++;
final String respBody = resp.body();
System.out.println("request: " + resp.request()
+ ", received response: " + respBody);
assertEquals(200, resp.statusCode(),
"unexpected status code for request " + resp.request());
} catch (ExecutionException ee) {
final Throwable cause = ee.getCause();
if (!(cause instanceof IOException ioe)) {
throw cause;
}
// verify it failed for the right reason
if (ioe.getMessage() == null
|| !ioe.getMessage().contains("request not processed by peer")) {
// propagate the original failure
throw ioe;
}
numFailed++; // failed due to the right reason
System.out.println("received expected failure: " + ioe
+ ", for request " + REQ_URI_BASE + reqQueryPart);
}
}
// verify the correct number of requests succeeded/failed
assertEquals(maxAllowedReqs, numSuccess, "unexpected number of requests succeeded");
assertEquals((numReqs - maxAllowedReqs), numFailed, "unexpected number of requests failed");
} finally {
server.setRequestApprover(null); // reset
}
}
}
}
// only allows fixed number of requests, irrespective of which server connection handles
// it. requests that are rejected will either be sent a GOAWAY on the connection
// or a RST_FRAME with a REFUSED_STREAM on the stream
private static final class LimitedRequestApprover {
private final int maxAllowedReqs;
private final AtomicInteger numApproved = new AtomicInteger();
private LimitedRequestApprover(final int maxAllowedReqs) {
this.maxAllowedReqs = maxAllowedReqs;
}
public boolean allowNewRequest(final String serverConnKey) {
final int approved = numApproved.incrementAndGet();
return approved <= maxAllowedReqs;
}
}
// allows a certain number of requests per server connection.
// requests that are rejected will either be sent a GOAWAY on the connection
// or a RST_FRAME with a REFUSED_STREAM on the stream
private static final class LimitedPerConnRequestApprover {
private static final int MAX_REQS_PER_CONN = 6;
private final Map<String, AtomicInteger> numApproved =
new ConcurrentHashMap<>();
private final Map<String, AtomicInteger> numDisapproved =
new ConcurrentHashMap<>();
public boolean allowNewRequest(final String serverConnKey) {
final AtomicInteger approved = numApproved.computeIfAbsent(serverConnKey,
(k) -> new AtomicInteger());
int curr = approved.get();
while (curr < MAX_REQS_PER_CONN) {
if (approved.compareAndSet(curr, curr + 1)) {
return true; // new request allowed
}
curr = approved.get();
}
final AtomicInteger disapproved = numDisapproved.computeIfAbsent(serverConnKey,
(k) -> new AtomicInteger());
final int numUnprocessed = disapproved.incrementAndGet();
System.out.println(approved.get() + " processed, "
+ numUnprocessed + " unprocessed requests on connection " + serverConnKey);
return false;
}
}
private static final class Handler implements HttpTestHandler {
@Override
public void handle(final HttpTestExchange exchange) throws IOException {
final String connectionKey = exchange.getConnectionKey();
System.out.println("responding to request: " + exchange.getRequestURI()
+ " on connection " + connectionKey);
final byte[] response = connectionKey.getBytes(UTF_8);
exchange.sendResponseHeaders(200, response.length);
try (final OutputStream os = exchange.getResponseBody()) {
os.write(response);
}
}
}
}

View File

@ -58,6 +58,7 @@ import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.function.Predicate;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
@ -239,6 +240,7 @@ public interface HttpServerAdapters {
public abstract String getRequestMethod();
public abstract void close();
public abstract InetSocketAddress getRemoteAddress();
public abstract String getConnectionKey();
public void serverPush(URI uri, HttpHeaders headers, byte[] body) {
ByteArrayInputStream bais = new ByteArrayInputStream(body);
serverPush(uri, headers, bais);
@ -253,7 +255,7 @@ public interface HttpServerAdapters {
return new Http1TestExchange(exchange);
}
public static HttpTestExchange of(Http2TestExchange exchange) {
return new Http2TestExchangeImpl(exchange);
return new H2ExchangeImpl(exchange);
}
abstract void doFilter(Filter.Chain chain) throws IOException;
@ -306,15 +308,21 @@ public interface HttpServerAdapters {
public URI getRequestURI() { return exchange.getRequestURI(); }
@Override
public String getRequestMethod() { return exchange.getRequestMethod(); }
@Override
public String getConnectionKey() {
return exchange.getLocalAddress() + "->" + exchange.getRemoteAddress();
}
@Override
public String toString() {
return this.getClass().getSimpleName() + ": " + exchange.toString();
}
}
private static final class Http2TestExchangeImpl extends HttpTestExchange {
private static final class H2ExchangeImpl extends HttpTestExchange {
private final Http2TestExchange exchange;
Http2TestExchangeImpl(Http2TestExchange exch) {
H2ExchangeImpl(Http2TestExchange exch) {
this.exchange = exch;
}
@Override
@ -363,6 +371,11 @@ public interface HttpServerAdapters {
return exchange.getRemoteAddress();
}
@Override
public String getConnectionKey() {
return exchange.getConnectionKey();
}
@Override
public URI getRequestURI() { return exchange.getRequestURI(); }
@Override
@ -708,6 +721,7 @@ public interface HttpServerAdapters {
public abstract HttpTestContext addHandler(HttpTestHandler handler, String root);
public abstract InetSocketAddress getAddress();
public abstract Version getVersion();
public abstract void setRequestApprover(final Predicate<String> approver);
public String serverAuthority() {
InetSocketAddress address = getAddress();
@ -856,6 +870,11 @@ public interface HttpServerAdapters {
impl.getAddress().getPort());
}
public Version getVersion() { return Version.HTTP_1_1; }
@Override
public void setRequestApprover(final Predicate<String> approver) {
throw new UnsupportedOperationException("not supported");
}
}
private static class Http1TestContext extends HttpTestContext {
@ -907,6 +926,11 @@ public interface HttpServerAdapters {
impl.getAddress().getPort());
}
public Version getVersion() { return Version.HTTP_2; }
@Override
public void setRequestApprover(final Predicate<String> approver) {
this.impl.setRequestApprover(approver);
}
}
private static class Http2TestContext

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2017, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2017, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -71,4 +71,10 @@ public interface Http2TestExchange {
* It may also complete exceptionally
*/
CompletableFuture<Long> sendPing();
/**
* {@return the identification of the connection on which this exchange is being
* processed}
*/
String getConnectionKey();
}

View File

@ -220,6 +220,11 @@ public class Http2TestExchangeImpl implements Http2TestExchange {
}
}
@Override
public String getConnectionKey() {
return conn.connectionKey();
}
private boolean isHeadRequest() {
return HEAD.equalsIgnoreCase(getRequestMethod());
}

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -32,6 +32,8 @@ import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;
import javax.net.ServerSocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;
@ -59,6 +61,8 @@ public class Http2TestServer implements AutoCloseable {
final Set<Http2TestServerConnection> connections;
final Properties properties;
final String name;
// request approver which takes the server connection key as the input
private volatile Predicate<String> newRequestApprover;
private static ThreadFactory defaultThreadFac =
(Runnable r) -> {
@ -285,6 +289,14 @@ public class Http2TestServer implements AutoCloseable {
return serverName;
}
public void setRequestApprover(final Predicate<String> approver) {
this.newRequestApprover = approver;
}
Predicate<String> getRequestApprover() {
return this.newRequestApprover;
}
private synchronized void putConnection(InetSocketAddress addr, Http2TestServerConnection c) {
if (!stopping)
connections.add(c);

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -75,13 +75,18 @@ import java.util.Map;
import java.util.Optional;
import java.util.Properties;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiPredicate;
import java.util.function.Consumer;
import java.util.function.Predicate;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.nio.charset.StandardCharsets.UTF_8;
import static jdk.internal.net.http.frame.ErrorFrame.REFUSED_STREAM;
import static jdk.internal.net.http.frame.SettingsFrame.HEADER_TABLE_SIZE;
/**
@ -110,6 +115,10 @@ public class Http2TestServerConnection {
volatile boolean stopping;
volatile int nextPushStreamId = 2;
ConcurrentLinkedQueue<PingRequest> pings = new ConcurrentLinkedQueue<>();
// the max stream id of a processed H2 request. -1 implies none were processed.
private final AtomicInteger maxProcessedRequestStreamId = new AtomicInteger(-1);
// the stream id that was sent in a GOAWAY frame. -1 implies no GOAWAY frame was sent.
private final AtomicInteger goAwayRequestStreamId = new AtomicInteger(-1);
final static ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0);
final static byte[] EMPTY_BARRAY = new byte[0];
@ -234,11 +243,29 @@ public class Http2TestServerConnection {
return ping.response();
}
void goAway(int error) throws IOException {
int laststream = nextstream >= 3 ? nextstream - 2 : 1;
GoAwayFrame go = new GoAwayFrame(laststream, error);
outputQ.put(go);
private void sendGoAway(final int error) throws IOException {
int maxProcessedStreamId = maxProcessedRequestStreamId.get();
if (maxProcessedStreamId == -1) {
maxProcessedStreamId = 0;
}
boolean send = false;
int currentGoAwayReqStrmId = goAwayRequestStreamId.get();
// update the last processed stream id and send a goaway frame if the new last processed
// stream id is lesser than the last processed stream id sent in
// a previous goaway frame (if any)
while (currentGoAwayReqStrmId == -1 || maxProcessedStreamId < currentGoAwayReqStrmId) {
if (goAwayRequestStreamId.compareAndSet(currentGoAwayReqStrmId, maxProcessedStreamId)) {
send = true;
break;
}
currentGoAwayReqStrmId = goAwayRequestStreamId.get();
}
if (!send) {
return;
}
final GoAwayFrame frame = new GoAwayFrame(maxProcessedStreamId, error);
outputQ.put(frame);
System.err.println("Sending GOAWAY frame " + frame + " from server connection " + this);
}
/**
@ -331,8 +358,9 @@ public class Http2TestServerConnection {
q.orderlyClose();
});
try {
if (error != -1)
goAway(error);
if (error != -1) {
sendGoAway(error);
}
outputQ.orderlyClose();
socket.close();
} catch (Exception e) {
@ -612,6 +640,14 @@ public class Http2TestServerConnection {
path = path + "?" + uri.getRawQuery();
headersBuilder.setHeader(":path", path);
// skip processing the request if configured to do so
final String connKey = connectionKey();
if (!shouldProcessNewHTTPRequest(connKey)) {
System.err.println("Rejecting primordial stream 1 and sending GOAWAY" +
" on server connection " + connKey + ", for request: " + path);
sendGoAway(ErrorFrame.NO_ERROR);
return;
}
Queue q = new Queue(sentinel);
byte[] body = getRequestBody(request);
addHeaders(getHeaders(request.headers), headersBuilder);
@ -620,11 +656,24 @@ public class Http2TestServerConnection {
addRequestBodyToQueue(body, q);
streams.put(1, q);
maxProcessedRequestStreamId.set(1);
exec.submit(() -> {
handleRequest(headers, q, 1, true /*complete request has been read*/);
});
}
private boolean shouldProcessNewHTTPRequest(final String serverConnKey) {
final Predicate<String> approver = this.server.getRequestApprover();
if (approver == null) {
return true; // process the request
}
return approver.test(serverConnKey);
}
final String connectionKey() {
return this.server.getAddress() + "->" + this.socket.getRemoteSocketAddress();
}
// all other streams created here
@SuppressWarnings({"rawtypes","unchecked"})
void createStream(HeaderFrame frame) throws IOException {
@ -632,7 +681,7 @@ public class Http2TestServerConnection {
frames.add(frame);
int streamid = frame.streamid();
if (streamid != nextstream) {
throw new IOException("unexpected stream id");
throw new IOException("unexpected stream id: " + streamid);
}
nextstream += 2;
@ -663,12 +712,30 @@ public class Http2TestServerConnection {
throw new IOException("Unexpected Upgrade in headers:" + headers);
}
disallowedHeader = headers.firstValue("HTTP2-Settings");
if (disallowedHeader.isPresent())
if (disallowedHeader.isPresent()) {
throw new IOException("Unexpected HTTP2-Settings in headers:" + headers);
}
// skip processing the request if the server is configured to do so
final String connKey = connectionKey();
final String path = headers.firstValue(":path").orElse("");
if (!shouldProcessNewHTTPRequest(connKey)) {
System.err.println("Rejecting stream " + streamid
+ " and sending GOAWAY on server connection "
+ connKey + ", for request: " + path);
sendGoAway(ErrorFrame.NO_ERROR);
return;
}
Queue q = new Queue(sentinel);
streams.put(streamid, q);
// keep track of the largest request id that we have processed
int currentLargest = maxProcessedRequestStreamId.get();
while (streamid > currentLargest) {
if (maxProcessedRequestStreamId.compareAndSet(currentLargest, streamid)) {
break;
}
currentLargest = maxProcessedRequestStreamId.get();
}
exec.submit(() -> {
handleRequest(headers, q, streamid, endStreamReceived);
});
@ -763,6 +830,8 @@ public class Http2TestServerConnection {
while (!stopping) {
Http2Frame frame = readFrameImpl();
if (frame == null) {
System.err.println("EOF reached on connection " + connectionKey()
+ ", will no longer accept incoming frames");
closeIncoming();
return;
}
@ -786,6 +855,17 @@ public class Http2TestServerConnection {
// TODO: close connection
continue;
} else {
final int streamId = frame.streamid();
final int finalProcessedStreamId = goAwayRequestStreamId.get();
// if we already sent a goaway, then don't create new streams with
// higher stream ids.
if (finalProcessedStreamId != -1 && streamId > finalProcessedStreamId) {
System.err.println(connectionKey() + " resetting stream " + streamId
+ " as REFUSED_STREAM");
final ResetFrame rst = new ResetFrame(streamId, REFUSED_STREAM);
outputQ.put(rst);
continue;
}
createStream((HeadersFrame) frame);
}
} else {