/* * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. */ /** * @test * @bug 8218554 * @summary test that the handler can request a connection close. * @library /test/lib * @build jdk.test.lib.net.SimpleSSLContext * @run main/othervm HandlerConnectionClose */ import com.sun.net.httpserver.HttpExchange; import com.sun.net.httpserver.HttpHandler; import com.sun.net.httpserver.HttpServer; import com.sun.net.httpserver.HttpsConfigurator; import com.sun.net.httpserver.HttpsServer; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.HttpURLConnection; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Socket; import java.net.URI; import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Locale; import java.util.logging.Handler; import java.util.logging.Level; import java.util.logging.Logger; import java.util.logging.SimpleFormatter; import java.util.logging.StreamHandler; import jdk.test.lib.net.SimpleSSLContext; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSession; public class HandlerConnectionClose { static final int ONEK = 1024; static final long POST_SIZE = ONEK * 1L; SSLContext sslContext; Logger logger; void test(String[] args) throws Exception { HttpServer httpServer = startHttpServer("http"); try { testHttpURLConnection(httpServer, "http","/close/legacy/http/chunked"); testHttpURLConnection(httpServer, "http","/close/legacy/http/fixed"); testPlainSocket(httpServer, "http","/close/plain/http/chunked"); testPlainSocket(httpServer, "http","/close/plain/http/fixed"); } finally { httpServer.stop(0); } sslContext = new SimpleSSLContext().get(); HttpServer httpsServer = startHttpServer("https"); try { testHttpURLConnection(httpsServer, "https","/close/legacy/https/chunked"); testHttpURLConnection(httpsServer, "https","/close/legacy/https/fixed"); testPlainSocket(httpsServer, "https","/close/plain/https/chunked"); testPlainSocket(httpsServer, "https","/close/plain/https/fixed"); } finally{ httpsServer.stop(0); } } void testHttpURLConnection(HttpServer httpServer, String protocol, String path) throws Exception { int port = httpServer.getAddress().getPort(); String host = httpServer.getAddress().getHostString(); URL url = new URI(protocol, null, host, port, path, null, null).toURL(); HttpURLConnection uc = (HttpURLConnection) url.openConnection(); if ("https".equalsIgnoreCase(protocol)) { ((HttpsURLConnection)uc).setSSLSocketFactory(sslContext.getSocketFactory()); ((HttpsURLConnection)uc).setHostnameVerifier((String hostname, SSLSession session) -> true); } uc.setDoOutput(true); uc.setRequestMethod("POST"); uc.setFixedLengthStreamingMode(POST_SIZE); OutputStream os = uc.getOutputStream(); /* create a 1K byte array with data to POST */ byte[] ba = new byte[ONEK]; for (int i = 0; i < ONEK; i++) ba[i] = (byte) i; System.out.println("\n" + uc.getClass().getSimpleName() +": POST " + url + " HTTP/1.1"); long times = POST_SIZE / ONEK; for (int i = 0; i < times; i++) { os.write(ba); } os.close(); InputStream is = uc.getInputStream(); int read; long count = 0; while ((read = is.read(ba)) != -1) { for (int i = 0; i < read; i++) { byte expected = (byte) count++; if (ba[i] != expected) { throw new IOException("byte mismatch at " + (count - 1) + ": expected " + expected + " got " + ba[i]); } } } if (count != POST_SIZE) { throw new IOException("Unexpected length: " + count + " expected " + POST_SIZE); } is.close(); pass(); } void testPlainSocket(HttpServer httpServer, String protocol, String path) throws Exception { int port = httpServer.getAddress().getPort(); String host = httpServer.getAddress().getHostString(); URL url = new URI(protocol, null, host, port, path, null, null).toURL(); Socket socket; if ("https".equalsIgnoreCase(protocol)) { socket = sslContext.getSocketFactory().createSocket(host, port); } else { socket = new Socket(host, port); } try (Socket sock = socket) { OutputStream os = socket.getOutputStream(); // send request headers String request = new StringBuilder() .append("POST ").append(path).append(" HTTP/1.1").append("\r\n") .append("host: ").append(host).append(':').append(port).append("\r\n") .append("Content-Length: ").append(POST_SIZE).append("\r\n") .append("\r\n") .toString(); os.write(request.getBytes(StandardCharsets.US_ASCII)); /* create a 1K byte array with data to POST */ byte[] ba = new byte[ONEK]; for (int i = 0; i < ONEK; i++) ba[i] = (byte) i; // send request data long times = POST_SIZE / ONEK; for (int i = 0; i < times; i++) { os.write(ba); } os.flush(); InputStream is = socket.getInputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream(); // read all response headers int c; int crlf = 0; while ((c = is.read()) != -1) { if (c == '\r') continue; if (c == '\n') crlf++; else crlf = 0; bos.write(c); if (crlf == 2) break; } String responseHeadersStr = bos.toString(StandardCharsets.US_ASCII); List responseHeaders = List.of(responseHeadersStr.split("\n")); System.out.println("\nPOST " + url + " HTTP/1.1"); responseHeaders.stream().forEach(s -> System.out.println("[reply]\t" + s)); String statusLine = responseHeaders.get(0); if (!statusLine.startsWith("HTTP/1.1 200 ")) throw new IOException("Unexpected status: " + statusLine); String cl = responseHeaders.stream() .map(s -> s.toLowerCase(Locale.ROOT)) .filter(s -> s.startsWith("content-length: ")) .findFirst() .orElse(null); String te = responseHeaders.stream() .map(s -> s.toLowerCase(Locale.ROOT)) .filter(s -> s.startsWith("transfer-encoding: ")) .findFirst() .orElse(null); // check content-length and transfer-encoding are as expected int read = 0; long count = 0; if (path.endsWith("/fixed")) { if (!("content-length: " + POST_SIZE).equalsIgnoreCase(cl)) { throw new IOException("Unexpected Content-Length: [" + cl + "]"); } if (te != null) { throw new IOException("Unexpected Transfer-Encoding: [" + te + "]"); } // Got expected Content-Length: 1024 - read response data while ((read = is.read()) != -1) { int expected = (int) (count & 0xFF); if ((read & 0xFF) != expected) { throw new IOException("byte mismatch at " + (count - 1) + ": expected " + expected + " got " + read); } if (++count == POST_SIZE) break; } } else if (cl != null) { throw new IOException("Unexpected Content-Length: [" + cl + "]"); } else { if (!("transfer-encoding: chunked").equalsIgnoreCase(te)) { throw new IOException("Unexpected Transfer-Encoding: [" + te + "]"); } // This is a quick & dirty implementation of // chunk decoding - no trailers - no extensions StringBuilder chunks = new StringBuilder(); int cs = -1; while (cs != 0) { cs = 0; chunks.setLength(0); // read next chunk length while ((read = is.read()) != -1) { if (read == '\r') continue; if (read == '\n') break; chunks.append((char) read); } cs = Integer.parseInt(chunks.toString().trim(), 16); System.out.println("Got chunk length: " + cs); // If chunk size is 0, then we have read the last chunk. if (cs == 0) break; // Read the chunk data while (--cs >= 0) { read = is.read(); if (read == -1) break; // EOF int expected = (int) (count & 0xFF); if ((read & 0xFF) != expected) { throw new IOException("byte mismatch at " + (count - 1) + ": expected " + expected + " got " + read); } // This is cheating: we know the size :-) if (++count == POST_SIZE) break; } if (read == -1) { throw new IOException("Unexpected EOF after " + count + " data bytes"); } // read CRLF if ((read = is.read()) != '\r') { throw new IOException("Expected CR at " + count + "after chunk data - got " + read); } if ((read = is.read()) != '\n') { throw new IOException("Expected LF at " + count + "after chunk data - got " + read); } if (cs == 0 && count == POST_SIZE) { cs = -1; } if (cs != -1) { // count == POST_SIZE, but some chunk data still to be read? throw new IOException("Unexpected chunk size, " + cs + " bytes still to read after " + count + " data bytes received."); } } // Last CRLF? for (int i = 0; i < 2; i++) { if ((read = is.read()) == -1) break; } } if (count != POST_SIZE) { throw new IOException("Unexpected length: " + count + " expected " + POST_SIZE); } if (!sock.isClosed()) { try { // We send an end request to the server to verify that the // connection is closed. If the server has not closed the // connection, it will reply. If we receive a response, // we should fail... String endrequest = new StringBuilder() .append("GET ").append("/close/end").append(" HTTP/1.1").append("\r\n") .append("host: ").append(host).append(':').append(port).append("\r\n") .append("Content-Length: ").append(0).append("\r\n") .append("\r\n") .toString(); os.write(endrequest.getBytes(StandardCharsets.US_ASCII)); os.flush(); StringBuilder resp = new StringBuilder(); crlf = 0; // read all headers. // If the server closed the connection as expected // we should immediately read EOF while ((read = is.read()) != -1) { if (read == '\r') continue; if (read == '\n') crlf++; else crlf = 0; if (crlf == 2) break; resp.append((char) read); } List lines = List.of(resp.toString().split("\n")); if (read != -1 || resp.length() != 0) { System.err.println("Connection not closed!"); System.err.println("Got: "); lines.stream().forEach(s -> System.err.println("[end]\t" + s)); throw new AssertionError("EOF not received after " + count + " data bytes"); } if (read != -1) { throw new AssertionError("EOF was expected after " + count + " bytes, but got: " + read); } else { System.out.println("Got expected EOF (" + read + ")"); } } catch (IOException x) { // expected! all is well System.out.println("Socket closed as expected, got exception writing to it."); } } else { System.out.println("Socket closed as expected"); } pass(); } } /** * Http Server */ HttpServer startHttpServer(String protocol) throws IOException { if (debug) { logger = Logger.getLogger("com.sun.net.httpserver"); Handler outHandler = new StreamHandler(System.out, new SimpleFormatter()); outHandler.setLevel(Level.FINEST); logger.setLevel(Level.FINEST); logger.addHandler(outHandler); } InetSocketAddress serverAddress = new InetSocketAddress(InetAddress.getLoopbackAddress(), 0); HttpServer httpServer = null; if ("http".equalsIgnoreCase(protocol)) { httpServer = HttpServer.create(serverAddress, 0); } if ("https".equalsIgnoreCase(protocol)) { HttpsServer httpsServer = HttpsServer.create(serverAddress, 0); httpsServer.setHttpsConfigurator(new HttpsConfigurator(sslContext)); httpServer = httpsServer; } httpServer.createContext("/close/", new MyHandler(POST_SIZE)); System.out.println("Server created at: " + httpServer.getAddress()); httpServer.start(); return httpServer; } class MyHandler implements HttpHandler { static final int BUFFER_SIZE = 512; final long expected; MyHandler(long expected){ this.expected = expected; } @Override public void handle(HttpExchange t) throws IOException { System.out.println("Server: serving " + t.getRequestURI()); boolean chunked = t.getRequestURI().getPath().endsWith("/chunked"); boolean fixed = t.getRequestURI().getPath().endsWith("/fixed"); boolean end = t.getRequestURI().getPath().endsWith("/end"); long responseLength = fixed ? POST_SIZE : 0; responseLength = end ? -1 : responseLength; responseLength = chunked ? 0 : responseLength; if (!end) t.getResponseHeaders().add("connection", "CLose"); t.sendResponseHeaders(200, responseLength); if (!end) { OutputStream os = t.getResponseBody(); InputStream is = t.getRequestBody(); byte[] ba = new byte[BUFFER_SIZE]; int read; long count = 0L; while ((read = is.read(ba)) != -1) { count += read; os.write(ba, 0, read); } is.close(); check(count == expected, "Expected: " + expected + ", received " + count); debug("Received " + count + " bytes"); os.close(); } t.close(); } } //--------------------- Infrastructure --------------------------- boolean debug = true; volatile int passed = 0, failed = 0; void pass() {passed++;} void fail() {failed++; Thread.dumpStack();} void fail(String msg) {System.err.println(msg); fail();} void unexpected(Throwable t) {failed++; t.printStackTrace();} void check(boolean cond) {if (cond) pass(); else fail();} void check(boolean cond, String failMessage) {if (cond) pass(); else fail(failMessage);} void debug(String message) {if(debug) { System.out.println(message); } } public static void main(String[] args) throws Throwable { Class k = new Object(){}.getClass().getEnclosingClass(); try {k.getMethod("instanceMain",String[].class) .invoke( k.newInstance(), (Object) args);} catch (Throwable e) {throw e.getCause();}} public void instanceMain(String[] args) throws Throwable { try {test(args);} catch (Throwable t) {unexpected(t);} System.out.printf("%nPassed = %d, failed = %d%n%n", passed, failed); if (failed > 0) throw new AssertionError("Some tests failed");} }