/*
 * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

import com.sun.net.httpserver.HttpServer;
import com.sun.net.httpserver.HttpsConfigurator;
import com.sun.net.httpserver.HttpsServer;
import jdk.test.lib.net.IPSupport;
import jdk.test.lib.net.SimpleSSLContext;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import javax.net.ssl.SSLContext;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Predicate;

/*
 * @test
 * @summary Tests HttpClient usage when configured with a local address to bind
 *          to, when sending requests
 * @bug 8209137
 * @modules jdk.httpserver
 *          java.net.http/jdk.internal.net.http.common
 *          java.net.http/jdk.internal.net.http.frame
 *          java.base/sun.net.www.http
 *          java.net.http/jdk.internal.net.http.hpack
 *
 * @library /test/lib http2/server
 *
 * @build jdk.test.lib.net.SimpleSSLContext jdk.test.lib.net.IPSupport HttpServerAdapters
 *
 * @run testng/othervm
 *      -Djdk.httpclient.HttpClient.log=frames,ssl,requests,responses,errors
 *      -Djdk.internal.httpclient.debug=true
 *      HttpClientLocalAddrTest
 *
 * @run testng/othervm/java.security.policy=httpclient-localaddr-security.policy
 *      -Djdk.httpclient.HttpClient.log=frames,ssl,requests,responses,errors
 *      -Djdk.internal.httpclient.debug=true
 *      HttpClientLocalAddrTest
 *
 */
public class HttpClientLocalAddrTest implements HttpServerAdapters {

    private static SSLContext sslContext;
    private static HttpServerAdapters.HttpTestServer http1_1_Server;
    private static URI httpURI;
    private static HttpServerAdapters.HttpTestServer https_1_1_Server;
    private static URI httpsURI;
    private static HttpServerAdapters.HttpTestServer http2_Server;
    private static URI http2URI;
    private static HttpServerAdapters.HttpTestServer https2_Server;
    private static URI https2URI;

    // start various HTTP/HTTPS servers that will be invoked against in the tests
    @BeforeClass
    public static void beforeClass() throws Exception {
        sslContext = new SimpleSSLContext().get();
        Assert.assertNotNull(sslContext, "Unexpected null sslContext");

        HttpServerAdapters.HttpTestHandler handler = (exchange) -> {
            // the handler receives a request and sends back a 200 response with the
            // response body containing the raw IP address (in byte[] form) of the client from whom
            // the request was received
            var clientAddr = exchange.getRemoteAddress();
            System.out.println("Received a request from client address " + clientAddr);
            var responseContent = clientAddr.getAddress().getAddress();
            exchange.sendResponseHeaders(200, responseContent.length);
            try (var os = exchange.getResponseBody()) {
                // write out the client address as a response
                os.write(responseContent);
            }
            exchange.close();
        };

        // HTTP/1.1 - create servers with http and https
        final var sa = new InetSocketAddress(InetAddress.getLoopbackAddress(), 0);
        final int backlog = 0;
        http1_1_Server = HttpServerAdapters.HttpTestServer.of(HttpServer.create(sa, backlog));
        http1_1_Server.addHandler(handler, "/");
        http1_1_Server.start();
        System.out.println("Started HTTP v1.1 server at " + http1_1_Server.serverAuthority());
        httpURI = new URI("http://" + http1_1_Server.serverAuthority() + "/");

        final HttpsServer httpsServer = HttpsServer.create(sa, 0);
        httpsServer.setHttpsConfigurator(new HttpsConfigurator(sslContext));
        https_1_1_Server = HttpServerAdapters.HttpTestServer.of(httpsServer);
        https_1_1_Server.addHandler(handler, "/");
        https_1_1_Server.start();
        System.out.println("Started HTTPS v1.1 server at " + https_1_1_Server.serverAuthority());
        httpsURI = new URI("https://" + https_1_1_Server.serverAuthority() + "/");

        // HTTP/2 - create servers with http and https
        http2_Server = HttpServerAdapters.HttpTestServer.of(new Http2TestServer(sa.getHostString(), false, null));
        http2_Server.addHandler(handler, "/");
        http2_Server.start();
        System.out.println("Started HTTP v2 server at " + http2_Server.serverAuthority());
        http2URI = new URI("http://" + http2_Server.serverAuthority() + "/");

        https2_Server = HttpServerAdapters.HttpTestServer.of(new Http2TestServer(sa.getHostString(), true, sslContext));
        https2_Server.addHandler(handler, "/");
        https2_Server.start();
        System.out.println("Started HTTPS v2 server at " + https2_Server.serverAuthority());
        https2URI = new URI("https://" + https2_Server.serverAuthority() + "/");
    }

    // stop each of the started servers
    @AfterClass
    public static void afterClass() throws Exception {
        // stop each of the server and accumulate any exception
        // that might happen during stop and finally throw
        // the accumulated exception(s)
        var e = safeStop(http1_1_Server, null);
        e = safeStop(https_1_1_Server, e);
        e = safeStop(http2_Server, e);
        e = safeStop(https2_Server, e);
        // throw any exception that happened during stop
        if (e != null) {
            throw e;
        }
    }

    /**
     * Stops the server and returns (instead of throwing) any exception that might
     * have occurred during stop. If {@code prevException} is not null then any
     * exception during stop of the {@code server} will be added as a suppressed
     * exception to the {@code prevException} and the {@code prevException} will be
     * returned.
     */
    private static Exception safeStop(HttpServerAdapters.HttpTestServer server, Exception prevException) {
        if (server == null) {
            return null;
        }
        var serverAuthority = server.serverAuthority();
        try {
            server.stop();
        } catch (Exception e) {
            System.err.println("Failed to stop server " + serverAuthority);
            if (prevException == null) {
                return e;
            }
            prevException.addSuppressed(e);
            return prevException;
        }
        return prevException;
    }

    @DataProvider(name = "params")
    private Object[][] paramsProvider() throws Exception {
        final List<Object[]> testMethodParams = new ArrayList();
        final URI[] requestURIs = new URI[]{httpURI, httpsURI, http2URI, https2URI};
        final Predicate<URI> requiresSSLContext = (uri) -> uri.getScheme().equals("https");
        for (var requestURI : requestURIs) {
            final var configureClientSSL = requiresSSLContext.test(requestURI);
            // no localAddr set
            testMethodParams.add(new Object[]{
                    newBuilder(configureClientSSL).build(),
                    requestURI,
                    null
            });
            // null localAddr set
            testMethodParams.add(new Object[]{
                    newBuilder(configureClientSSL).localAddress(null).build(),
                    requestURI,
                    null
            });
            // localAddr set to loopback address
            final var loopbackAddr = InetAddress.getLoopbackAddress();
            testMethodParams.add(new Object[]{
                    newBuilder(configureClientSSL)
                            .localAddress(loopbackAddr)
                            .build(),
                    requestURI,
                    loopbackAddr
            });
            // anyAddress
            if (IPSupport.hasIPv6()) {
                // ipv6 wildcard
                final var localAddr = InetAddress.getByName("::");
                testMethodParams.add(new Object[]{
                        newBuilder(configureClientSSL)
                                .localAddress(localAddr)
                                .build(),
                        requestURI,
                        localAddr
                });
            }
            if (IPSupport.hasIPv4()) {
                // ipv4 wildcard
                final var localAddr = InetAddress.getByName("0.0.0.0");
                testMethodParams.add(new Object[]{
                        newBuilder(configureClientSSL)
                                .localAddress(localAddr)
                                .build(),
                        requestURI,
                        localAddr
                });
            }
        }
        return testMethodParams.stream().toArray(Object[][]::new);
    }

    private static HttpClient.Builder newBuilder(boolean configureClientSSL) {
        var builder = HttpClient.newBuilder();
        // don't let proxies interfere with the client addresses received on the
        // HTTP request, by the server side handler used in this test.
        builder.proxy(HttpClient.Builder.NO_PROXY);
        if (configureClientSSL) {
            builder.sslContext(sslContext);
        }
        return builder;
    }

    /**
     * Sends a GET request using the {@code client} and expects a 200 response.
     * The returned response body is then tested to see if the client address
     * seen by the server side handler is the same one as that is set on the
     * {@code client}
     */
    @Test(dataProvider = "params")
    public void testSend(HttpClient client, URI requestURI, InetAddress localAddress) throws Exception {
        System.out.println("Testing using a HTTP client " + client.version() + " with local address " + localAddress
                + " against request URI " + requestURI);
        // GET request
        var req = HttpRequest.newBuilder(requestURI).build();
        var resp = client.send(req, HttpResponse.BodyHandlers.ofByteArray());
        Assert.assertEquals(resp.statusCode(), 200, "Unexpected status code");
        // verify the address only if a specific one was set on the client
        if (localAddress != null && !localAddress.isAnyLocalAddress()) {
            Assert.assertEquals(resp.body(), localAddress.getAddress(),
                    "Unexpected client address seen by the server handler");
        }
    }

    /**
     * Sends a GET request using the {@code sendAsync} method on the {@code client} and
     * expects a 200 response. The returned response body is then tested to see if the client address
     * seen by the server side handler is the same one as that is set on the
     * {@code client}
     */
    @Test(dataProvider = "params")
    public void testSendAsync(HttpClient client, URI requestURI, InetAddress localAddress) throws Exception {
        System.out.println("Testing using a HTTP client " + client.version() + " with local address " + localAddress
                + " against request URI " + requestURI);
        // GET request
        var req = HttpRequest.newBuilder(requestURI).build();
        var cf = client.sendAsync(req,
                HttpResponse.BodyHandlers.ofByteArray());
        var resp = cf.get();
        Assert.assertEquals(resp.statusCode(), 200, "Unexpected status code");
        // verify the address only if a specific one was set on the client
        if (localAddress != null && !localAddress.isAnyLocalAddress()) {
            Assert.assertEquals(resp.body(), localAddress.getAddress(),
                    "Unexpected client address seen by the server handler");
        }
    }

    /**
     * Invokes the {@link #testSend(HttpClient)} and {@link #testSendAsync(HttpClient)}
     * tests, concurrently in multiple threads to verify that the correct local address
     * is used when multiple concurrent threads are involved in sending requests from
     * the {@code client}
     */
    @Test(dataProvider = "params")
    public void testMultiSendRequests(HttpClient client, URI requestURI, InetAddress localAddress) throws Exception {
        int numThreads = 4;
        ExecutorService executor = Executors.newFixedThreadPool(numThreads);
        List<Future<Void>> taskResults = new ArrayList<>();
        try {
            for (int i = 0; i < numThreads; i++) {
                final var currentIdx = i;
                var f = executor.submit(new Callable<Void>() {
                    @Override
                    public Void call() throws Exception {
                        // test some for send and some for sendAsync
                        if (currentIdx % 2 == 0) {
                            testSend(client, requestURI, localAddress);
                        } else {
                            testSendAsync(client, requestURI, localAddress);
                        }
                        return null;
                    }
                });
                taskResults.add(f);
            }
            // wait for results
            for (var r : taskResults) {
                r.get();
            }
        } finally {
            executor.shutdownNow();
        }
    }
}