478 lines
18 KiB
Java
478 lines
18 KiB
Java
|
/*
 * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
|
||
|
|
||
|
import java.io.Closeable;
|
||
|
import java.io.IOException;
|
||
|
import java.io.ObjectInputStream;
|
||
|
import java.io.ObjectOutputStream;
|
||
|
import java.io.Serializable;
|
||
|
import java.net.ServerSocket;
|
||
|
import java.net.Socket;
|
||
|
import java.net.UnknownHostException;
|
||
|
import java.util.ArrayList;
|
||
|
import java.util.Arrays;
|
||
|
import java.util.HashMap;
|
||
|
import java.util.Map;
|
||
|
import java.util.StringJoiner;
|
||
|
import javax.security.auth.callback.Callback;
|
||
|
import javax.security.auth.callback.CallbackHandler;
|
||
|
import javax.security.auth.callback.NameCallback;
|
||
|
import javax.security.auth.callback.PasswordCallback;
|
||
|
import javax.security.auth.callback.UnsupportedCallbackException;
|
||
|
import javax.security.sasl.AuthorizeCallback;
|
||
|
import javax.security.sasl.RealmCallback;
|
||
|
import javax.security.sasl.RealmChoiceCallback;
|
||
|
import javax.security.sasl.Sasl;
|
||
|
import javax.security.sasl.SaslClient;
|
||
|
import javax.security.sasl.SaslException;
|
||
|
import javax.security.sasl.SaslServer;
|
||
|
|
||
|
/*
 * @test
 * @bug 8049814
 * @summary Java SASL server and client tests with CRAM-MD5 and
 *          DIGEST-MD5 mechanisms. The tests try different QOP values on
 *          client and server side.
 * @modules java.security.sasl/javax.security.sasl
 */
|
||
|
public class ClientServerTest {
|
||
|
|
||
|
private static final int DELAY = 100;
|
||
|
private static final String LOCALHOST = "localhost";
|
||
|
private static final String DIGEST_MD5 = "DIGEST-MD5";
|
||
|
private static final String CRAM_MD5 = "CRAM-MD5";
|
||
|
private static final String PROTOCOL = "saslservice";
|
||
|
private static final String USER_ID = "sasltester";
|
||
|
private static final String PASSWD = "password";
|
||
|
private static final String QOP_AUTH = "auth";
|
||
|
private static final String QOP_AUTH_CONF = "auth-conf";
|
||
|
private static final String QOP_AUTH_INT = "auth-int";
|
||
|
private static final String AUTHID_SASL_TESTER = "sasl_tester";
|
||
|
private static final ArrayList<String> SUPPORT_MECHS = new ArrayList<>();
|
||
|
|
||
|
static {
|
||
|
SUPPORT_MECHS.add(DIGEST_MD5);
|
||
|
SUPPORT_MECHS.add(CRAM_MD5);
|
||
|
}
|
||
|
|
||
|
public static void main(String[] args) throws Exception {
|
||
|
String[] allQops = { QOP_AUTH_CONF, QOP_AUTH_INT, QOP_AUTH };
|
||
|
String[] twoQops = { QOP_AUTH_INT, QOP_AUTH };
|
||
|
String[] authQop = { QOP_AUTH };
|
||
|
String[] authIntQop = { QOP_AUTH_INT };
|
||
|
String[] authConfQop = { QOP_AUTH_CONF };
|
||
|
String[] emptyQop = {};
|
||
|
|
||
|
boolean success = true;
|
||
|
|
||
|
success &= runTest("", CRAM_MD5, new String[] { QOP_AUTH },
|
||
|
new String[] { QOP_AUTH }, false);
|
||
|
success &= runTest("", DIGEST_MD5, new String[] { QOP_AUTH },
|
||
|
new String[] { QOP_AUTH }, false);
|
||
|
success &= runTest(AUTHID_SASL_TESTER, DIGEST_MD5,
|
||
|
new String[] { QOP_AUTH }, new String[] { QOP_AUTH }, false);
|
||
|
success &= runTest("", DIGEST_MD5, allQops, authQop, false);
|
||
|
success &= runTest("", DIGEST_MD5, allQops, authIntQop, false);
|
||
|
success &= runTest("", DIGEST_MD5, allQops, authConfQop, false);
|
||
|
success &= runTest("", DIGEST_MD5, twoQops, authQop, false);
|
||
|
success &= runTest("", DIGEST_MD5, twoQops, authIntQop, false);
|
||
|
success &= runTest("", DIGEST_MD5, twoQops, authConfQop, true);
|
||
|
success &= runTest("", DIGEST_MD5, authIntQop, authQop, true);
|
||
|
success &= runTest("", DIGEST_MD5, authConfQop, authQop, true);
|
||
|
success &= runTest("", DIGEST_MD5, authConfQop, emptyQop, true);
|
||
|
success &= runTest("", DIGEST_MD5, authIntQop, emptyQop, true);
|
||
|
success &= runTest("", DIGEST_MD5, authQop, emptyQop, true);
|
||
|
|
||
|
if (!success) {
|
||
|
throw new RuntimeException("At least one test case failed");
|
||
|
}
|
||
|
|
||
|
System.out.println("Test passed");
|
||
|
}
|
||
|
|
||
|
private static boolean runTest(String authId, String mech,
|
||
|
String[] clientQops, String[] serverQops, boolean expectException)
|
||
|
throws Exception {
|
||
|
|
||
|
System.out.println("AuthId:" + authId
|
||
|
+ " mechanism:" + mech
|
||
|
+ " clientQops: " + Arrays.toString(clientQops)
|
||
|
+ " serverQops: " + Arrays.toString(serverQops)
|
||
|
+ " expect exception:" + expectException);
|
||
|
|
||
|
try (Server server = Server.start(LOCALHOST, authId, serverQops)) {
|
||
|
new Client(LOCALHOST, server.getPort(), mech, authId, clientQops)
|
||
|
.run();
|
||
|
if (expectException) {
|
||
|
System.out.println("Expected exception not thrown");
|
||
|
return false;
|
||
|
}
|
||
|
} catch (SaslException e) {
|
||
|
if (!expectException) {
|
||
|
System.out.println("Unexpected exception: " + e);
|
||
|
return false;
|
||
|
}
|
||
|
System.out.println("Expected exception: " + e);
|
||
|
}
|
||
|
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
static enum SaslStatus {
|
||
|
SUCCESS, FAILURE, CONTINUE
|
||
|
}
|
||
|
|
||
|
static class Message implements Serializable {
|
||
|
|
||
|
private final SaslStatus status;
|
||
|
private final byte[] data;
|
||
|
|
||
|
public Message(SaslStatus status, byte[] data) {
|
||
|
this.status = status;
|
||
|
this.data = data;
|
||
|
}
|
||
|
|
||
|
public SaslStatus getStatus() {
|
||
|
return status;
|
||
|
}
|
||
|
|
||
|
public byte[] getData() {
|
||
|
return data;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
static class SaslPeer {
|
||
|
|
||
|
final String host;
|
||
|
final String mechanism;
|
||
|
final String qop;
|
||
|
final CallbackHandler callback;
|
||
|
|
||
|
SaslPeer(String host, String authId, String... qops) {
|
||
|
this(host, null, authId, qops);
|
||
|
}
|
||
|
|
||
|
SaslPeer(String host, String mechanism, String authId, String... qops) {
|
||
|
this.host = host;
|
||
|
this.mechanism = mechanism;
|
||
|
|
||
|
StringJoiner sj = new StringJoiner(",");
|
||
|
for (String q : qops) {
|
||
|
sj.add(q);
|
||
|
}
|
||
|
qop = sj.toString();
|
||
|
|
||
|
callback = new TestCallbackHandler(USER_ID, PASSWD, host, authId);
|
||
|
}
|
||
|
|
||
|
Message getMessage(Object ob) {
|
||
|
if (!(ob instanceof Message)) {
|
||
|
throw new RuntimeException("Expected an instance of Message");
|
||
|
}
|
||
|
return (Message) ob;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
static class Server extends SaslPeer implements Runnable, Closeable {
|
||
|
|
||
|
private volatile boolean ready = false;
|
||
|
private volatile ServerSocket ssocket;
|
||
|
|
||
|
static Server start(String host, String authId, String[] serverQops)
|
||
|
throws UnknownHostException {
|
||
|
Server server = new Server(host, authId, serverQops);
|
||
|
Thread thread = new Thread(server);
|
||
|
thread.setDaemon(true);
|
||
|
thread.start();
|
||
|
|
||
|
while (!server.ready) {
|
||
|
try {
|
||
|
Thread.sleep(DELAY);
|
||
|
} catch (InterruptedException e) {
|
||
|
throw new RuntimeException(e);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return server;
|
||
|
}
|
||
|
|
||
|
Server(String host, String authId, String... qops) {
|
||
|
super(host, authId, qops);
|
||
|
}
|
||
|
|
||
|
int getPort() {
|
||
|
return ssocket.getLocalPort();
|
||
|
}
|
||
|
|
||
|
private void processConnection(SaslEndpoint endpoint)
|
||
|
throws SaslException, IOException, ClassNotFoundException {
|
||
|
System.out.println("process connection");
|
||
|
endpoint.send(SUPPORT_MECHS);
|
||
|
Object o = endpoint.receive();
|
||
|
if (!(o instanceof String)) {
|
||
|
throw new RuntimeException("Received unexpected object: " + o);
|
||
|
}
|
||
|
String mech = (String) o;
|
||
|
SaslServer saslServer = createSaslServer(mech);
|
||
|
Message msg = getMessage(endpoint.receive());
|
||
|
while (!saslServer.isComplete()) {
|
||
|
byte[] data = processData(msg.getData(), endpoint,
|
||
|
saslServer);
|
||
|
if (saslServer.isComplete()) {
|
||
|
System.out.println("server is complete");
|
||
|
endpoint.send(new Message(SaslStatus.SUCCESS, data));
|
||
|
} else {
|
||
|
System.out.println("server continues");
|
||
|
endpoint.send(new Message(SaslStatus.CONTINUE, data));
|
||
|
msg = getMessage(endpoint.receive());
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
private byte[] processData(byte[] data, SaslEndpoint endpoint,
|
||
|
SaslServer server) throws SaslException, IOException {
|
||
|
try {
|
||
|
return server.evaluateResponse(data);
|
||
|
} catch (SaslException e) {
|
||
|
endpoint.send(new Message(SaslStatus.FAILURE, null));
|
||
|
System.out.println("Error while processing data");
|
||
|
throw e;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
private SaslServer createSaslServer(String mechanism)
|
||
|
throws SaslException {
|
||
|
Map<String, String> props = new HashMap<>();
|
||
|
props.put(Sasl.QOP, qop);
|
||
|
return Sasl.createSaslServer(mechanism, PROTOCOL, host, props,
|
||
|
callback);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void run() {
|
||
|
try (ServerSocket ss = new ServerSocket(0)) {
|
||
|
ssocket = ss;
|
||
|
System.out.println("server started on port " + getPort());
|
||
|
ready = true;
|
||
|
Socket socket = ss.accept();
|
||
|
try (SaslEndpoint endpoint = new SaslEndpoint(socket)) {
|
||
|
System.out.println("server accepted connection");
|
||
|
processConnection(endpoint);
|
||
|
}
|
||
|
} catch (Exception e) {
|
||
|
// ignore it for now, client will throw an exception
|
||
|
}
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void close() throws IOException {
|
||
|
if (!ssocket.isClosed()) {
|
||
|
ssocket.close();
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
static class Client extends SaslPeer {
|
||
|
|
||
|
private final int port;
|
||
|
|
||
|
Client(String host, int port, String mech, String authId,
|
||
|
String... qops) {
|
||
|
super(host, mech, authId, qops);
|
||
|
this.port = port;
|
||
|
}
|
||
|
|
||
|
public void run() throws Exception {
|
||
|
System.out.println("Host:" + host + " port: "
|
||
|
+ port);
|
||
|
try (SaslEndpoint endpoint = SaslEndpoint.create(host, port)) {
|
||
|
negotiateMechanism(endpoint);
|
||
|
SaslClient client = createSaslClient();
|
||
|
byte[] data = new byte[0];
|
||
|
if (client.hasInitialResponse()) {
|
||
|
data = client.evaluateChallenge(data);
|
||
|
}
|
||
|
endpoint.send(new Message(SaslStatus.CONTINUE, data));
|
||
|
Message msg = getMessage(endpoint.receive());
|
||
|
while (!client.isComplete()
|
||
|
&& msg.getStatus() != SaslStatus.FAILURE) {
|
||
|
switch (msg.getStatus()) {
|
||
|
case CONTINUE:
|
||
|
System.out.println("client continues");
|
||
|
data = client.evaluateChallenge(msg.getData());
|
||
|
endpoint.send(new Message(SaslStatus.CONTINUE,
|
||
|
data));
|
||
|
msg = getMessage(endpoint.receive());
|
||
|
break;
|
||
|
case SUCCESS:
|
||
|
System.out.println("client succeeded");
|
||
|
data = client.evaluateChallenge(msg.getData());
|
||
|
if (data != null) {
|
||
|
throw new SaslException("data should be null");
|
||
|
}
|
||
|
break;
|
||
|
default:
|
||
|
throw new RuntimeException("Wrong status:"
|
||
|
+ msg.getStatus());
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if (msg.getStatus() == SaslStatus.FAILURE) {
|
||
|
throw new RuntimeException("Status is FAILURE");
|
||
|
}
|
||
|
}
|
||
|
|
||
|
System.out.println("Done");
|
||
|
}
|
||
|
|
||
|
private SaslClient createSaslClient() throws SaslException {
|
||
|
Map<String, String> props = new HashMap<>();
|
||
|
props.put(Sasl.QOP, qop);
|
||
|
return Sasl.createSaslClient(new String[] {mechanism}, USER_ID,
|
||
|
PROTOCOL, host, props, callback);
|
||
|
}
|
||
|
|
||
|
private void negotiateMechanism(SaslEndpoint endpoint)
|
||
|
throws ClassNotFoundException, IOException {
|
||
|
Object o = endpoint.receive();
|
||
|
if (o instanceof ArrayList) {
|
||
|
ArrayList list = (ArrayList) o;
|
||
|
if (!list.contains(mechanism)) {
|
||
|
throw new RuntimeException(
|
||
|
"Server does not support specified mechanism:"
|
||
|
+ mechanism);
|
||
|
}
|
||
|
} else {
|
||
|
throw new RuntimeException(
|
||
|
"Expected an instance of ArrayList, but received " + o);
|
||
|
}
|
||
|
|
||
|
endpoint.send(mechanism);
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
static class SaslEndpoint implements AutoCloseable {
|
||
|
|
||
|
private final Socket socket;
|
||
|
private ObjectInputStream input;
|
||
|
private ObjectOutputStream output;
|
||
|
|
||
|
static SaslEndpoint create(String host, int port) throws IOException {
|
||
|
return new SaslEndpoint(new Socket(host, port));
|
||
|
}
|
||
|
|
||
|
SaslEndpoint(Socket socket) throws IOException {
|
||
|
this.socket = socket;
|
||
|
}
|
||
|
|
||
|
private ObjectInputStream getInput() throws IOException {
|
||
|
if (input == null && socket != null) {
|
||
|
input = new ObjectInputStream(socket.getInputStream());
|
||
|
}
|
||
|
return input;
|
||
|
}
|
||
|
|
||
|
private ObjectOutputStream getOutput() throws IOException {
|
||
|
if (output == null && socket != null) {
|
||
|
output = new ObjectOutputStream(socket.getOutputStream());
|
||
|
}
|
||
|
return output;
|
||
|
}
|
||
|
|
||
|
public Object receive() throws IOException, ClassNotFoundException {
|
||
|
return getInput().readObject();
|
||
|
}
|
||
|
|
||
|
public void send(Object obj) throws IOException {
|
||
|
getOutput().writeObject(obj);
|
||
|
getOutput().flush();
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void close() throws IOException {
|
||
|
if (socket != null && !socket.isClosed()) {
|
||
|
socket.close();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
static class TestCallbackHandler implements CallbackHandler {
|
||
|
|
||
|
private final String userId;
|
||
|
private final char[] passwd;
|
||
|
private final String realm;
|
||
|
private String authId;
|
||
|
|
||
|
TestCallbackHandler(String userId, String passwd, String realm,
|
||
|
String authId) {
|
||
|
this.userId = userId;
|
||
|
this.passwd = passwd.toCharArray();
|
||
|
this.realm = realm;
|
||
|
this.authId = authId;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void handle(Callback[] callbacks) throws IOException,
|
||
|
UnsupportedCallbackException {
|
||
|
for (Callback callback : callbacks) {
|
||
|
if (callback instanceof NameCallback) {
|
||
|
System.out.println("NameCallback");
|
||
|
((NameCallback) callback).setName(userId);
|
||
|
} else if (callback instanceof PasswordCallback) {
|
||
|
System.out.println("PasswordCallback");
|
||
|
((PasswordCallback) callback).setPassword(passwd);
|
||
|
} else if (callback instanceof RealmCallback) {
|
||
|
System.out.println("RealmCallback");
|
||
|
((RealmCallback) callback).setText(realm);
|
||
|
} else if (callback instanceof RealmChoiceCallback) {
|
||
|
System.out.println("RealmChoiceCallback");
|
||
|
RealmChoiceCallback choice = (RealmChoiceCallback) callback;
|
||
|
if (realm == null) {
|
||
|
choice.setSelectedIndex(choice.getDefaultChoice());
|
||
|
} else {
|
||
|
String[] choices = choice.getChoices();
|
||
|
for (int j = 0; j < choices.length; j++) {
|
||
|
if (realm.equals(choices[j])) {
|
||
|
choice.setSelectedIndex(j);
|
||
|
break;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
} else if (callback instanceof AuthorizeCallback) {
|
||
|
System.out.println("AuthorizeCallback");
|
||
|
((AuthorizeCallback) callback).setAuthorized(true);
|
||
|
if (authId == null || authId.trim().length() == 0) {
|
||
|
authId = userId;
|
||
|
}
|
||
|
((AuthorizeCallback) callback).setAuthorizedID(authId);
|
||
|
} else {
|
||
|
throw new UnsupportedCallbackException(callback);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
}
|