8290714: Make com.sun.jndi.dns.DnsClient virtual threads friendly

Reviewed-by: dfuchs, jpai
This commit is contained in:
Aleksei Efimov 2022-11-10 19:20:33 +00:00
parent d6468be81f
commit 9ef7852be3
3 changed files with 163 additions and 112 deletions
src/jdk.naming.dns/share/classes/com/sun/jndi/dns
test/jdk/com/sun/jndi/dns/ConfigTests

@ -27,7 +27,6 @@ package com.sun.jndi.dns;
import java.io.IOException;
import java.net.DatagramSocket;
import java.net.ProtocolFamily;
import java.net.SocketException;
import java.net.InetSocketAddress;
import java.nio.channels.DatagramChannel;
import java.security.AccessController;
@ -35,7 +34,7 @@ import java.security.PrivilegedExceptionAction;
import java.util.Objects;
import java.util.Random;
class DNSDatagramSocketFactory {
class DNSDatagramChannelFactory {
static final int DEVIATION = 3;
static final int THRESHOLD = 6;
static final int BIT_DEVIATION = 2;
@ -120,17 +119,17 @@ class DNSDatagramSocketFactory {
final Random random;
final PortHistory history;
DNSDatagramSocketFactory() {
DNSDatagramChannelFactory() {
this(new Random());
}
DNSDatagramSocketFactory(Random random) {
DNSDatagramChannelFactory(Random random) {
this(Objects.requireNonNull(random), null, DEVIATION, THRESHOLD);
}
DNSDatagramSocketFactory(Random random,
ProtocolFamily family,
int deviation,
int threshold) {
DNSDatagramChannelFactory(Random random,
ProtocolFamily family,
int deviation,
int threshold) {
this.random = Objects.requireNonNull(random);
this.history = new PortHistory(HISTORY, random);
this.family = family;
@ -145,12 +144,13 @@ class DNSDatagramSocketFactory {
* port) then the underlying OS implementation is used. Otherwise, this
* method will allocate and bind a socket on a randomly selected ephemeral
* port in the dynamic range.
* @return A new DatagramSocket bound to a random port.
* @throws SocketException if the socket cannot be created.
*
* @return A new DatagramChannel bound to a random port.
* @throws IOException if the socket cannot be created.
*/
public synchronized DatagramSocket open() throws SocketException {
public synchronized DatagramChannel open() throws IOException {
int lastseen = lastport;
DatagramSocket s;
DatagramChannel s;
boolean thresholdCrossed = unsuitablePortCount > thresholdCount;
if (thresholdCrossed) {
@ -166,7 +166,7 @@ class DNSDatagramSocketFactory {
// Allocate an ephemeral port (port 0)
s = openDefault();
lastport = s.getLocalPort();
lastport = getLocalPort(s);
if (lastseen == 0) {
lastSystemAllocated = lastport;
history.offer(lastport);
@ -199,36 +199,27 @@ class DNSDatagramSocketFactory {
// Undecided... the new port was too close. Let's allocate a random
// port using our own algorithm
assert !thresholdCrossed;
DatagramSocket ss = openRandom();
DatagramChannel ss = openRandom();
if (ss == null) return s;
unsuitablePortCount++;
s.close();
return ss;
}
private DatagramSocket openDefault() throws SocketException {
if (family != null) {
try {
DatagramChannel c = DatagramChannel.open(family);
try {
DatagramSocket s = c.socket();
s.bind(null);
return s;
} catch (Throwable x) {
c.close();
throw x;
}
} catch (SocketException x) {
throw x;
} catch (IOException x) {
throw new SocketException(x.getMessage(), x);
}
private DatagramChannel openDefault() throws IOException {
DatagramChannel c = family != null ? DatagramChannel.open(family)
: DatagramChannel.open();
try {
c.bind(null);
return c;
} catch (Throwable x) {
c.close();
throw x;
}
return new DatagramSocket();
}
synchronized boolean isUsingNativePortRandomization() {
return unsuitablePortCount <= thresholdCount
return unsuitablePortCount <= thresholdCount
&& suitablePortCount > thresholdCount;
}
@ -246,7 +237,7 @@ class DNSDatagramSocketFactory {
&& Math.abs(port - lastport) > deviation;
}
private DatagramSocket openRandom() {
private DatagramChannel openRandom() throws IOException {
int maxtries = MAX_RANDOM_TRIES;
while (maxtries-- > 0) {
int port;
@ -265,29 +256,24 @@ class DNSDatagramSocketFactory {
// times - but that should be OK with MAX_RANDOM_TRIES = 5.
if (!suitable) continue;
DatagramChannel dc = (family != null)
? DatagramChannel.open(family)
: DatagramChannel.open();
try {
if (family != null) {
DatagramChannel c = DatagramChannel.open(family);
try {
DatagramSocket s = c.socket();
s.bind(new InetSocketAddress(port));
lastport = s.getLocalPort();
if (!recycled) history.add(port);
return s;
} catch (Throwable x) {
c.close();
throw x;
}
}
DatagramSocket s = new DatagramSocket(port);
lastport = s.getLocalPort();
dc.bind(new InetSocketAddress(port));
lastport = getLocalPort(dc);
if (!recycled) history.add(port);
return s;
return dc;
} catch (IOException x) {
dc.close();
// try again until maxtries == 0;
}
}
return null;
}
private static int getLocalPort(DatagramChannel dc) throws IOException {
return ((InetSocketAddress) dc.getLocalAddress()).getPort();
}
}

@ -27,19 +27,29 @@ package com.sun.jndi.dns;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.DatagramSocket;
import java.net.DatagramPacket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.PortUnreachableException;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedSelectorException;
import java.nio.channels.DatagramChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.security.SecureRandom;
import javax.naming.*;
import javax.naming.CommunicationException;
import javax.naming.ConfigurationException;
import javax.naming.NameNotFoundException;
import javax.naming.NamingException;
import javax.naming.OperationNotSupportedException;
import javax.naming.ServiceUnavailableException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.HashMap;
import java.util.concurrent.locks.ReentrantLock;
import sun.security.jca.JCAUtil;
@ -83,15 +93,19 @@ public class DnsClient {
private static final int DEFAULT_PORT = 53;
private static final int TRANSACTION_ID_BOUND = 0x10000;
private static final int MIN_TIMEOUT = 50; // msec after which there are no retries.
private static final SecureRandom random = JCAUtil.getSecureRandom();
private InetAddress[] servers;
private int[] serverPorts;
private int timeout; // initial timeout on UDP and TCP queries in ms
private int retries; // number of UDP retries
private final Object udpSocketLock = new Object();
private static final DNSDatagramSocketFactory factory =
new DNSDatagramSocketFactory(random);
private final ReentrantLock udpChannelLock = new ReentrantLock();
private final Selector udpChannelSelector;
private static final DNSDatagramChannelFactory factory =
new DNSDatagramChannelFactory(random);
// Requests sent
private Map<Integer, ResourceRecord> reqs;
@ -135,15 +149,24 @@ public class DnsClient {
throw ne;
}
}
try {
udpChannelSelector = Selector.open();
} catch (IOException e) {
NamingException ne = new ConfigurationException(
"Channel selector configuration error");
ne.setRootCause(e);
throw ne;
}
reqs = Collections.synchronizedMap(
new HashMap<Integer, ResourceRecord>());
resps = Collections.synchronizedMap(new HashMap<Integer, byte[]>());
}
DatagramSocket getDatagramSocket() throws NamingException {
DatagramChannel getDatagramChannel() throws NamingException {
try {
return factory.open();
} catch (java.net.SocketException e) {
} catch (IOException e) {
NamingException ne = new ConfigurationException();
ne.setRootCause(e);
throw ne;
@ -159,6 +182,10 @@ public class DnsClient {
private Object queuesLock = new Object();
public void close() {
try {
udpChannelSelector.close();
} catch (IOException ioException) {
}
synchronized (queuesLock) {
reqs.clear();
resps.clear();
@ -212,23 +239,9 @@ public class DnsClient {
dprint("SEND ID (" + (retry + 1) + "): " + xid);
}
byte[] msg = null;
msg = doUdpQuery(pkt, servers[i], serverPorts[i],
retry, xid);
//
// If the matching response is not got within the
// given timeout, check if the response was enqueued
// by some other thread, if not proceed with the next
// server or retry.
//
if (msg == null) {
if (resps.size() > 0) {
msg = lookupResponse(xid);
}
if (msg == null) { // try next server or retry
continue;
}
}
byte[] msg = doUdpQuery(pkt, servers[i], serverPorts[i],
retry, xid);
assert msg != null;
Header hdr = new Header(msg, msg.length);
if (auth && !hdr.authoritative) {
@ -294,6 +307,13 @@ public class DnsClient {
if (caughtException == null) {
caughtException = e;
}
} catch (ClosedSelectorException e) {
// ClosedSelectorException is thrown by blockingReceive if
// the datagram channel selector associated with DNS client
// is unexpectedly closed
var ce = new CommunicationException("DNS client closed");
ce.setRootCause(e);
throw ce;
} catch (NameNotFoundException e) {
// This is authoritative, so return immediately
throw e;
@ -402,22 +422,28 @@ public class DnsClient {
int port, int retry, int xid)
throws IOException, NamingException {
int minTimeout = 50; // msec after which there are no retries.
synchronized (udpSocketLock) {
try (DatagramSocket udpSocket = getDatagramSocket()) {
DatagramPacket opkt = new DatagramPacket(
pkt.getData(), pkt.length(), server, port);
DatagramPacket ipkt = new DatagramPacket(new byte[8000], 8000);
udpChannelLock.lock();
try {
try (DatagramChannel udpChannel = getDatagramChannel()) {
ByteBuffer opkt = ByteBuffer.wrap(pkt.getData(), 0, pkt.length());
byte[] data = new byte[8000];
ByteBuffer ipkt = ByteBuffer.wrap(data);
// Packets may only be sent to or received from this server address
udpSocket.connect(server, port);
InetSocketAddress target = new InetSocketAddress(server, port);
udpChannel.connect(target);
int pktTimeout = (timeout * (1 << retry));
udpSocket.send(opkt);
udpChannel.write(opkt);
// timeout remaining after successive 'receive()'
// timeout remaining after successive 'blockingReceive()'
int timeoutLeft = pktTimeout;
int cnt = 0;
boolean gotData = false;
do {
// prepare for retry
if (gotData) {
Arrays.fill(data, 0, ipkt.position(), (byte) 0);
ipkt.clear();
}
if (debug) {
cnt++;
dprint("Trying RECEIVE(" +
@ -425,22 +451,53 @@ public class DnsClient {
") for:" + xid + " sock-timeout:" +
timeoutLeft + " ms.");
}
udpSocket.setSoTimeout(timeoutLeft);
long start = System.currentTimeMillis();
udpSocket.receive(ipkt);
gotData = blockingReceive(udpChannel, ipkt, timeoutLeft);
long end = System.currentTimeMillis();
byte[] data = ipkt.getData();
if (isMatchResponse(data, xid)) {
assert gotData || ipkt.position() == 0;
if (gotData && isMatchResponse(data, xid)) {
return data;
} else if (resps.size() > 0) {
// If the matching response is not found, check if
// the response was enqueued by some other thread,
// if not continue
byte[] cachedMsg = lookupResponse(xid);
if (cachedMsg != null) { // found in cache
return cachedMsg;
}
}
timeoutLeft = pktTimeout - ((int) (end - start));
} while (timeoutLeft > minTimeout);
return null; // no matching packet received within the timeout
} while (timeoutLeft > MIN_TIMEOUT);
// no matching packets received within the timeout
throw new SocketTimeoutException();
}
} finally {
udpChannelLock.unlock();
}
}
boolean blockingReceive(DatagramChannel dc, ByteBuffer buffer, long timeout) throws IOException {
boolean dataReceived = false;
// The provided datagram channel will be used by the caller only to receive data after
// it is put to non-blocking mode
dc.configureBlocking(false);
var selectionKey = dc.register(udpChannelSelector, SelectionKey.OP_READ);
try {
udpChannelSelector.select(timeout);
var keys = udpChannelSelector.selectedKeys();
if (keys.contains(selectionKey) && selectionKey.isReadable()) {
dc.receive(buffer);
dataReceived = true;
}
keys.clear();
} finally {
selectionKey.cancel();
// Flush the canceled key out of the selected key set
udpChannelSelector.selectNow();
}
return dataReceived;
}
/*
* Sends a TCP query, and returns the first DNS message in the response.
*/
@ -629,7 +686,7 @@ public class DnsClient {
//
synchronized (queuesLock) {
if (reqs.containsKey(hdr.xid)) { // enqueue only the first response
resps.put(hdr.xid, pkt);
resps.put(hdr.xid, pkt.clone());
}
}

@ -1,5 +1,5 @@
/*
* Copyright (c) 2002, 2018, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2002, 2022, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -24,28 +24,26 @@
import javax.naming.CommunicationException;
import javax.naming.Context;
import javax.naming.directory.InitialDirContext;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketTimeoutException;
import java.time.Duration;
import java.time.Instant;
import jdk.test.lib.net.URIBuilder;
/*
* @test
* @bug 8200151 8265309
* @summary Tests that we can set the initial UDP timeout interval and the
* number of retries.
* @library ../lib/
* @library ../lib/ /test/lib
* @modules java.base/sun.security.util
* @run main Timeout
*/
public class Timeout extends DNSTestBase {
// Host 10.0.0.0 is a bit bucket, used here to simulate a DNS server that
// doesn't respond. 10.0.0.0 server shouldn't be reachable.
// Ping to this address should not give any reply
private static final String HOST = "10.0.0.0";
// Port 9 is a bit bucket, used here to simulate a DNS server that
// doesn't respond.
private static final int PORT = 9;
// initial timeout = 1/4 sec
private static final int TIMEOUT = 250;
// try 5 times per server
@ -67,18 +65,28 @@ public class Timeout extends DNSTestBase {
*/
@Override
public void runTest() throws Exception {
String allQuietUrl = "dns://" + HOST + ":" + PORT;
env().put(Context.PROVIDER_URL, allQuietUrl);
env().put("com.sun.jndi.dns.timeout.initial", String.valueOf(TIMEOUT));
env().put("com.sun.jndi.dns.timeout.retries", String.valueOf(RETRIES));
setContext(new InitialDirContext(env()));
// Create a DatagramSocket and bind it to the loopback address to simulate
// UDP DNS server that doesn't respond
try (DatagramSocket ds = new DatagramSocket(
new InetSocketAddress(InetAddress.getLoopbackAddress(), 0))) {
String allQuietUrl = URIBuilder.newBuilder()
.scheme("dns")
.loopback()
.port(ds.getLocalPort())
.build()
.toString();
env().put(Context.PROVIDER_URL, allQuietUrl);
env().put("com.sun.jndi.dns.timeout.initial", String.valueOf(TIMEOUT));
env().put("com.sun.jndi.dns.timeout.retries", String.valueOf(RETRIES));
setContext(new InitialDirContext(env()));
// Any request should fail after timeouts have expired.
startTime = Instant.now();
context().getAttributes("");
// Any request should fail after timeouts have expired.
startTime = Instant.now();
context().getAttributes("");
throw new RuntimeException(
"Failed: getAttributes succeeded unexpectedly");
throw new RuntimeException(
"Failed: getAttributes succeeded unexpectedly");
}
}
@Override