8263087: Add a MethodHandle combinator that switches over a set of MethodHandles

Reviewed-by: redestad
This commit is contained in:
Jorn Vernee 2021-05-27 12:28:10 +00:00
parent 85f616522b
commit 3623abb7f6
10 changed files with 974 additions and 26 deletions

View File

@ -864,6 +864,12 @@ class InvokerBytecodeGenerator {
onStack = emitTryFinally(i);
i += 2; // jump to the end of the TF idiom
continue;
case TABLE_SWITCH:
assert lambdaForm.isTableSwitch(i);
int numCases = (Integer) name.function.intrinsicData();
onStack = emitTableSwitch(i, numCases);
i += 2; // jump to the end of the TS idiom
continue;
case LOOP:
assert lambdaForm.isLoop(i);
onStack = emitLoop(i);
@ -1389,6 +1395,58 @@ class InvokerBytecodeGenerator {
}
}
private Name emitTableSwitch(int pos, int numCases) {
Name args = lambdaForm.names[pos];
Name invoker = lambdaForm.names[pos + 1];
Name result = lambdaForm.names[pos + 2];
Class<?> returnType = result.function.resolvedHandle().type().returnType();
MethodType caseType = args.function.resolvedHandle().type()
.dropParameterTypes(0, 1) // drop collector
.changeReturnType(returnType);
String caseDescriptor = caseType.basicType().toMethodDescriptorString();
emitPushArgument(invoker, 2); // push cases
mv.visitFieldInsn(Opcodes.GETFIELD, "java/lang/invoke/MethodHandleImpl$CasesHolder", "cases",
"[Ljava/lang/invoke/MethodHandle;");
int casesLocal = extendLocalsMap(new Class<?>[] { MethodHandle[].class });
emitStoreInsn(L_TYPE, casesLocal);
Label endLabel = new Label();
Label defaultLabel = new Label();
Label[] caseLabels = new Label[numCases];
for (int i = 0; i < caseLabels.length; i++) {
caseLabels[i] = new Label();
}
emitPushArgument(invoker, 0); // push switch input
mv.visitTableSwitchInsn(0, numCases - 1, defaultLabel, caseLabels);
mv.visitLabel(defaultLabel);
emitPushArgument(invoker, 1); // push default handle
emitPushArguments(args, 1); // again, skip collector
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, MH, "invokeBasic", caseDescriptor, false);
mv.visitJumpInsn(Opcodes.GOTO, endLabel);
for (int i = 0; i < numCases; i++) {
mv.visitLabel(caseLabels[i]);
// Load the particular case:
emitLoadInsn(L_TYPE, casesLocal);
emitIconstInsn(i);
mv.visitInsn(Opcodes.AALOAD);
// invoke it:
emitPushArguments(args, 1); // again, skip collector
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, MH, "invokeBasic", caseDescriptor, false);
mv.visitJumpInsn(Opcodes.GOTO, endLabel);
}
mv.visitLabel(endLabel);
return result;
}
/**
* Emit bytecode for the loop idiom.
* <p>

View File

@ -314,6 +314,7 @@ class LambdaForm {
GET_DOUBLE_VOLATILE("getDoubleVolatile"),
PUT_DOUBLE_VOLATILE("putDoubleVolatile"),
TRY_FINALLY("tryFinally"),
TABLE_SWITCH("tableSwitch"),
COLLECT("collect"),
COLLECTOR("collector"),
CONVERT("convert"),
@ -707,6 +708,32 @@ class LambdaForm {
return isMatchingIdiom(pos, "tryFinally", 2);
}
/**
* Check if i-th name is a start of the tableSwitch idiom.
*/
boolean isTableSwitch(int pos) {
// tableSwitch idiom:
// t_{n}:L=MethodHandle.invokeBasic(...) // args
// t_{n+1}:L=MethodHandleImpl.tableSwitch(*, *, *, t_{n})
// t_{n+2}:?=MethodHandle.invokeBasic(*, t_{n+1})
if (pos + 2 >= names.length) return false;
final int POS_COLLECT_ARGS = pos;
final int POS_TABLE_SWITCH = pos + 1;
final int POS_UNBOX_RESULT = pos + 2;
Name collectArgs = names[POS_COLLECT_ARGS];
Name tableSwitch = names[POS_TABLE_SWITCH];
Name unboxResult = names[POS_UNBOX_RESULT];
return tableSwitch.refersTo(MethodHandleImpl.class, "tableSwitch") &&
collectArgs.isInvokeBasic() &&
unboxResult.isInvokeBasic() &&
tableSwitch.lastUseIndex(collectArgs) == 3 && // t_{n+1}:L=MethodHandleImpl.<invoker>(*, *, *, t_{n});
lastUseIndex(collectArgs) == POS_TABLE_SWITCH && // t_{n} is local: used only in t_{n+1}
unboxResult.lastUseIndex(tableSwitch) == 1 && // t_{n+2}:?=MethodHandle.invokeBasic(*, t_{n+1})
lastUseIndex(tableSwitch) == POS_UNBOX_RESULT; // t_{n+1} is local: used only in t_{n+2}
}
/**
* Check if i-th name is a start of the loop idiom.
*/
@ -1067,24 +1094,13 @@ class LambdaForm {
final MemberName member;
private @Stable MethodHandle resolvedHandle;
@Stable MethodHandle invoker;
private final MethodHandleImpl.Intrinsic intrinsicName;
NamedFunction(MethodHandle resolvedHandle) {
this(resolvedHandle.internalMemberName(), resolvedHandle, MethodHandleImpl.Intrinsic.NONE);
}
NamedFunction(MethodHandle resolvedHandle, MethodHandleImpl.Intrinsic intrinsic) {
this(resolvedHandle.internalMemberName(), resolvedHandle, intrinsic);
this(resolvedHandle.internalMemberName(), resolvedHandle);
}
NamedFunction(MemberName member, MethodHandle resolvedHandle) {
this(member, resolvedHandle, MethodHandleImpl.Intrinsic.NONE);
}
NamedFunction(MemberName member, MethodHandle resolvedHandle, MethodHandleImpl.Intrinsic intrinsic) {
this.member = member;
this.resolvedHandle = resolvedHandle;
this.intrinsicName = intrinsic;
assert(resolvedHandle == null ||
resolvedHandle.intrinsicName() == MethodHandleImpl.Intrinsic.NONE ||
resolvedHandle.intrinsicName() == intrinsic) : resolvedHandle.intrinsicName() + " != " + intrinsic;
// The following assert is almost always correct, but will fail for corner cases, such as PrivateInvokeTest.
//assert(!isInvokeBasic(member));
}
@ -1097,7 +1113,6 @@ class LambdaForm {
// necessary to pass BigArityTest
this.member = Invokers.invokeBasicMethod(basicInvokerType);
}
this.intrinsicName = MethodHandleImpl.Intrinsic.NONE;
assert(isInvokeBasic(member));
}
@ -1250,7 +1265,15 @@ class LambdaForm {
}
public MethodHandleImpl.Intrinsic intrinsicName() {
return intrinsicName;
return resolvedHandle != null
? resolvedHandle.intrinsicName()
: MethodHandleImpl.Intrinsic.NONE;
}
public Object intrinsicData() {
return resolvedHandle != null
? resolvedHandle.intrinsicData()
: null;
}
}
@ -1732,15 +1755,15 @@ class LambdaForm {
Name[] idNames = new Name[] { argument(0, L_TYPE), argument(1, type) };
idForm = new LambdaForm(2, idNames, 1, Kind.IDENTITY);
idForm.compileToBytecode();
idFun = new NamedFunction(idMem, SimpleMethodHandle.make(idMem.getInvocationType(), idForm),
MethodHandleImpl.Intrinsic.IDENTITY);
idFun = new NamedFunction(idMem, MethodHandleImpl.makeIntrinsic(SimpleMethodHandle.make(idMem.getInvocationType(), idForm),
MethodHandleImpl.Intrinsic.IDENTITY));
Object zeValue = Wrapper.forBasicType(btChar).zero();
Name[] zeNames = new Name[] { argument(0, L_TYPE), new Name(idFun, zeValue) };
zeForm = new LambdaForm(1, zeNames, 1, Kind.ZERO);
zeForm.compileToBytecode();
zeFun = new NamedFunction(zeMem, SimpleMethodHandle.make(zeMem.getInvocationType(), zeForm),
MethodHandleImpl.Intrinsic.ZERO);
zeFun = new NamedFunction(zeMem, MethodHandleImpl.makeIntrinsic(SimpleMethodHandle.make(zeMem.getInvocationType(), zeForm),
MethodHandleImpl.Intrinsic.ZERO));
}
LF_zero[ord] = zeForm;

View File

@ -38,6 +38,7 @@ import static java.lang.invoke.LambdaForm.*;
import static java.lang.invoke.LambdaForm.BasicType.*;
import static java.lang.invoke.MethodHandleImpl.Intrinsic;
import static java.lang.invoke.MethodHandleImpl.NF_loop;
import static java.lang.invoke.MethodHandleImpl.makeIntrinsic;
/** Transforms on LFs.
* A lambda-form editor can derive new LFs from its base LF.
@ -619,7 +620,7 @@ class LambdaFormEditor {
// adjust the arguments
MethodHandle aload = MethodHandles.arrayElementGetter(erasedArrayType);
for (int i = 0; i < arrayLength; i++) {
Name loadArgument = new Name(new NamedFunction(aload, Intrinsic.ARRAY_LOAD), spreadParam, i);
Name loadArgument = new Name(new NamedFunction(makeIntrinsic(aload, Intrinsic.ARRAY_LOAD)), spreadParam, i);
buf.insertExpression(exprPos + i, loadArgument);
buf.replaceParameterByCopy(pos + i, exprPos + i);
}

View File

@ -1679,6 +1679,11 @@ assertEquals("[three, thee, tee]", asListFix.invoke((Object)argv).toString());
return MethodHandleImpl.Intrinsic.NONE;
}
/*non-public*/
Object intrinsicData() {
return null;
}
/*non-public*/
MethodHandle withInternalMemberName(MemberName member, boolean isInvokeSpecial) {
if (member != null) {

View File

@ -49,6 +49,8 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Stream;
@ -823,7 +825,9 @@ abstract class MethodHandleImpl {
names[PROFILE] = new Name(getFunction(NF_profileBoolean), names[CALL_TEST], names[GET_COUNTERS]);
}
// call selectAlternative
names[SELECT_ALT] = new Name(new NamedFunction(getConstantHandle(MH_selectAlternative), Intrinsic.SELECT_ALTERNATIVE), names[TEST], names[GET_TARGET], names[GET_FALLBACK]);
names[SELECT_ALT] = new Name(new NamedFunction(
makeIntrinsic(getConstantHandle(MH_selectAlternative), Intrinsic.SELECT_ALTERNATIVE)),
names[TEST], names[GET_TARGET], names[GET_FALLBACK]);
// call target or fallback
invokeArgs[0] = names[SELECT_ALT];
@ -894,7 +898,7 @@ abstract class MethodHandleImpl {
Object[] args = new Object[invokeBasic.type().parameterCount()];
args[0] = names[GET_COLLECT_ARGS];
System.arraycopy(names, ARG_BASE, args, 1, ARG_LIMIT-ARG_BASE);
names[BOXED_ARGS] = new Name(new NamedFunction(invokeBasic, Intrinsic.GUARD_WITH_CATCH), args);
names[BOXED_ARGS] = new Name(new NamedFunction(makeIntrinsic(invokeBasic, Intrinsic.GUARD_WITH_CATCH)), args);
// t_{i+1}:L=MethodHandleImpl.guardWithCatch(target:L,exType:L,catcher:L,t_{i}:L);
Object[] gwcArgs = new Object[] {names[GET_TARGET], names[GET_CLASS], names[GET_CATCHER], names[BOXED_ARGS]};
@ -1226,6 +1230,7 @@ abstract class MethodHandleImpl {
SELECT_ALTERNATIVE,
GUARD_WITH_CATCH,
TRY_FINALLY,
TABLE_SWITCH,
LOOP,
ARRAY_LOAD,
ARRAY_STORE,
@ -1240,11 +1245,17 @@ abstract class MethodHandleImpl {
static final class IntrinsicMethodHandle extends DelegatingMethodHandle {
private final MethodHandle target;
private final Intrinsic intrinsicName;
private final Object intrinsicData;
IntrinsicMethodHandle(MethodHandle target, Intrinsic intrinsicName) {
this(target, intrinsicName, null);
}
IntrinsicMethodHandle(MethodHandle target, Intrinsic intrinsicName, Object intrinsicData) {
super(target.type(), target);
this.target = target;
this.intrinsicName = intrinsicName;
this.intrinsicData = intrinsicData;
}
@Override
@ -1257,6 +1268,11 @@ abstract class MethodHandleImpl {
return intrinsicName;
}
@Override
Object intrinsicData() {
return intrinsicData;
}
@Override
public MethodHandle asTypeUncached(MethodType newType) {
// This MH is an alias for target, except for the intrinsic name
@ -1282,9 +1298,13 @@ abstract class MethodHandleImpl {
}
static MethodHandle makeIntrinsic(MethodHandle target, Intrinsic intrinsicName) {
return makeIntrinsic(target, intrinsicName, null);
}
static MethodHandle makeIntrinsic(MethodHandle target, Intrinsic intrinsicName, Object intrinsicData) {
if (intrinsicName == target.intrinsicName())
return target;
return new IntrinsicMethodHandle(target, intrinsicName);
return new IntrinsicMethodHandle(target, intrinsicName, intrinsicData);
}
static MethodHandle makeIntrinsic(MethodType type, LambdaForm form, Intrinsic intrinsicName) {
@ -1360,7 +1380,8 @@ abstract class MethodHandleImpl {
NF_tryFinally = 3,
NF_loop = 4,
NF_profileBoolean = 5,
NF_LIMIT = 6;
NF_tableSwitch = 6,
NF_LIMIT = 7;
private static final @Stable NamedFunction[] NFS = new NamedFunction[NF_LIMIT];
@ -1394,6 +1415,9 @@ abstract class MethodHandleImpl {
case NF_profileBoolean:
return new NamedFunction(MethodHandleImpl.class
.getDeclaredMethod("profileBoolean", boolean.class, int[].class));
case NF_tableSwitch:
return new NamedFunction(MethodHandleImpl.class
.getDeclaredMethod("tableSwitch", int.class, MethodHandle.class, CasesHolder.class, Object[].class));
default:
throw new InternalError("Undefined function: " + func);
}
@ -1602,7 +1626,7 @@ abstract class MethodHandleImpl {
Object[] args = new Object[invokeBasic.type().parameterCount()];
args[0] = names[GET_COLLECT_ARGS];
System.arraycopy(names, ARG_BASE, args, 1, ARG_LIMIT - ARG_BASE);
names[BOXED_ARGS] = new Name(new NamedFunction(invokeBasic, Intrinsic.LOOP), args);
names[BOXED_ARGS] = new Name(new NamedFunction(makeIntrinsic(invokeBasic, Intrinsic.LOOP)), args);
// t_{i+1}:L=MethodHandleImpl.loop(localTypes:L,clauses:L,t_{i}:L);
Object[] lArgs =
@ -1839,7 +1863,7 @@ abstract class MethodHandleImpl {
Object[] args = new Object[invokeBasic.type().parameterCount()];
args[0] = names[GET_COLLECT_ARGS];
System.arraycopy(names, ARG_BASE, args, 1, ARG_LIMIT-ARG_BASE);
names[BOXED_ARGS] = new Name(new NamedFunction(invokeBasic, Intrinsic.TRY_FINALLY), args);
names[BOXED_ARGS] = new Name(new NamedFunction(makeIntrinsic(invokeBasic, Intrinsic.TRY_FINALLY)), args);
// t_{i+1}:L=MethodHandleImpl.tryFinally(target:L,exType:L,catcher:L,t_{i}:L);
Object[] tfArgs = new Object[] {names[GET_TARGET], names[GET_CLEANUP], names[BOXED_ARGS]};
@ -1941,7 +1965,7 @@ abstract class MethodHandleImpl {
storeNameCursor < STORE_ELEMENT_LIMIT;
storeIndex++, storeNameCursor++, argCursor++){
names[storeNameCursor] = new Name(new NamedFunction(storeFunc, Intrinsic.ARRAY_STORE),
names[storeNameCursor] = new Name(new NamedFunction(makeIntrinsic(storeFunc, Intrinsic.ARRAY_STORE)),
names[CALL_NEW_ARRAY], storeIndex, names[argCursor]);
}
@ -1952,6 +1976,141 @@ abstract class MethodHandleImpl {
return lform;
}
// use a wrapper because we need this array to be @Stable
static class CasesHolder {
@Stable
final MethodHandle[] cases;
public CasesHolder(MethodHandle[] cases) {
this.cases = cases;
}
}
static MethodHandle makeTableSwitch(MethodType type, MethodHandle defaultCase, MethodHandle[] caseActions) {
MethodType varargsType = type.changeReturnType(Object[].class);
MethodHandle collectArgs = varargsArray(type.parameterCount()).asType(varargsType);
MethodHandle unboxResult = unboxResultHandle(type.returnType());
BoundMethodHandle.SpeciesData data = BoundMethodHandle.speciesData_LLLL();
LambdaForm form = makeTableSwitchForm(type.basicType(), data, caseActions.length);
BoundMethodHandle mh;
CasesHolder caseHolder = new CasesHolder(caseActions);
try {
mh = (BoundMethodHandle) data.factory().invokeBasic(type, form, (Object) defaultCase, (Object) collectArgs,
(Object) unboxResult, (Object) caseHolder);
} catch (Throwable ex) {
throw uncaughtException(ex);
}
assert(mh.type() == type);
return mh;
}
private static class TableSwitchCacheKey {
private static final Map<TableSwitchCacheKey, LambdaForm> CACHE = new ConcurrentHashMap<>();
private final MethodType basicType;
private final int numberOfCases;
public TableSwitchCacheKey(MethodType basicType, int numberOfCases) {
this.basicType = basicType;
this.numberOfCases = numberOfCases;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TableSwitchCacheKey that = (TableSwitchCacheKey) o;
return numberOfCases == that.numberOfCases && Objects.equals(basicType, that.basicType);
}
@Override
public int hashCode() {
return Objects.hash(basicType, numberOfCases);
}
}
private static LambdaForm makeTableSwitchForm(MethodType basicType, BoundMethodHandle.SpeciesData data,
int numCases) {
MethodType lambdaType = basicType.invokerType();
// We need to cache based on the basic type X number of cases,
// since the number of cases is used when generating bytecode.
// This also means that we can't use the cache in MethodTypeForm,
// which only uses the basic type as a key.
TableSwitchCacheKey key = new TableSwitchCacheKey(basicType, numCases);
LambdaForm lform = TableSwitchCacheKey.CACHE.get(key);
if (lform != null) {
return lform;
}
final int THIS_MH = 0;
final int ARG_BASE = 1; // start of incoming arguments
final int ARG_LIMIT = ARG_BASE + basicType.parameterCount();
final int ARG_SWITCH_ON = ARG_BASE;
assert ARG_SWITCH_ON < ARG_LIMIT;
int nameCursor = ARG_LIMIT;
final int GET_COLLECT_ARGS = nameCursor++;
final int GET_DEFAULT_CASE = nameCursor++;
final int GET_UNBOX_RESULT = nameCursor++;
final int GET_CASES = nameCursor++;
final int BOXED_ARGS = nameCursor++;
final int TABLE_SWITCH = nameCursor++;
final int UNBOXED_RESULT = nameCursor++;
int fieldCursor = 0;
final int FIELD_DEFAULT_CASE = fieldCursor++;
final int FIELD_COLLECT_ARGS = fieldCursor++;
final int FIELD_UNBOX_RESULT = fieldCursor++;
final int FIELD_CASES = fieldCursor++;
Name[] names = arguments(nameCursor - ARG_LIMIT, lambdaType);
names[THIS_MH] = names[THIS_MH].withConstraint(data);
names[GET_DEFAULT_CASE] = new Name(data.getterFunction(FIELD_DEFAULT_CASE), names[THIS_MH]);
names[GET_COLLECT_ARGS] = new Name(data.getterFunction(FIELD_COLLECT_ARGS), names[THIS_MH]);
names[GET_UNBOX_RESULT] = new Name(data.getterFunction(FIELD_UNBOX_RESULT), names[THIS_MH]);
names[GET_CASES] = new Name(data.getterFunction(FIELD_CASES), names[THIS_MH]);
{
MethodType collectArgsType = basicType.changeReturnType(Object.class);
MethodHandle invokeBasic = MethodHandles.basicInvoker(collectArgsType);
Object[] args = new Object[invokeBasic.type().parameterCount()];
args[0] = names[GET_COLLECT_ARGS];
System.arraycopy(names, ARG_BASE, args, 1, ARG_LIMIT - ARG_BASE);
names[BOXED_ARGS] = new Name(new NamedFunction(makeIntrinsic(invokeBasic, Intrinsic.TABLE_SWITCH, numCases)), args);
}
{
Object[] tfArgs = new Object[]{
names[ARG_SWITCH_ON], names[GET_DEFAULT_CASE], names[GET_CASES], names[BOXED_ARGS]};
names[TABLE_SWITCH] = new Name(getFunction(NF_tableSwitch), tfArgs);
}
{
MethodHandle invokeBasic = MethodHandles.basicInvoker(MethodType.methodType(basicType.rtype(), Object.class));
Object[] unboxArgs = new Object[]{names[GET_UNBOX_RESULT], names[TABLE_SWITCH]};
names[UNBOXED_RESULT] = new Name(invokeBasic, unboxArgs);
}
lform = new LambdaForm(lambdaType.parameterCount(), names, Kind.TABLE_SWITCH);
LambdaForm prev = TableSwitchCacheKey.CACHE.putIfAbsent(key, lform);
return prev != null ? prev : lform;
}
@Hidden
static Object tableSwitch(int input, MethodHandle defaultCase, CasesHolder holder, Object[] args) throws Throwable {
MethodHandle[] caseActions = holder.cases;
MethodHandle selectedCase;
if (input < 0 || input >= caseActions.length) {
selectedCase = defaultCase;
} else {
selectedCase = caseActions[input];
}
return selectedCase.invokeWithArguments(args);
}
// Indexes into constant method handles:
static final int
MH_cast = 0,

View File

@ -7751,4 +7751,90 @@ assertEquals("boojum", (String) catTrace.invokeExact("boo", "jum"));
}
}
/**
* Creates a table switch method handle, which can be used to switch over a set of target
* method handles, based on a given target index, called selector.
* <p>
* For a selector value of {@code n}, where {@code n} falls in the range {@code [0, N)},
* and where {@code N} is the number of target method handles, the table switch method
* handle will invoke the n-th target method handle from the list of target method handles.
* <p>
* For a selector value that does not fall in the range {@code [0, N)}, the table switch
* method handle will invoke the given fallback method handle.
* <p>
* All method handles passed to this method must have the same type, with the additional
* requirement that the leading parameter be of type {@code int}. The leading parameter
* represents the selector.
* <p>
* Any trailing parameters present in the type will appear on the returned table switch
* method handle as well. Any arguments assigned to these parameters will be forwarded,
* together with the selector value, to the selected method handle when invoking it.
*
* @apiNote Example:
* The cases each drop the {@code selector} value they are given, and take an additional
* {@code String} argument, which is concatenated (using {@link String#concat(String)})
* to a specific constant label string for each case:
* <blockquote><pre>{@code
* MethodHandles.Lookup lookup = MethodHandles.lookup();
* MethodHandle caseMh = lookup.findVirtual(String.class, "concat",
* MethodType.methodType(String.class, String.class));
* caseMh = MethodHandles.dropArguments(caseMh, 0, int.class);
*
* MethodHandle caseDefault = MethodHandles.insertArguments(caseMh, 1, "default: ");
* MethodHandle case0 = MethodHandles.insertArguments(caseMh, 1, "case 0: ");
* MethodHandle case1 = MethodHandles.insertArguments(caseMh, 1, "case 1: ");
*
* MethodHandle mhSwitch = MethodHandles.tableSwitch(
* caseDefault,
* case0,
* case1
* );
*
* assertEquals("default: data", (String) mhSwitch.invokeExact(-1, "data"));
* assertEquals("case 0: data", (String) mhSwitch.invokeExact(0, "data"));
* assertEquals("case 1: data", (String) mhSwitch.invokeExact(1, "data"));
* assertEquals("default: data", (String) mhSwitch.invokeExact(2, "data"));
* }</pre></blockquote>
*
* @param fallback the fallback method handle that is called when the selector is not
* within the range {@code [0, N)}.
* @param targets array of target method handles.
* @return the table switch method handle.
* @throws NullPointerException if {@code fallback}, the {@code targets} array, or any
* any of the elements of the {@code targets} array are
* {@code null}.
* @throws IllegalArgumentException if the {@code targets} array is empty, if the leading
* parameter of the fallback handle or any of the target
* handles is not {@code int}, or if the types of
* the fallback handle and all of target handles are
* not the same.
*/
public static MethodHandle tableSwitch(MethodHandle fallback, MethodHandle... targets) {
Objects.requireNonNull(fallback);
Objects.requireNonNull(targets);
targets = targets.clone();
MethodType type = tableSwitchChecks(fallback, targets);
return MethodHandleImpl.makeTableSwitch(type, fallback, targets);
}
private static MethodType tableSwitchChecks(MethodHandle defaultCase, MethodHandle[] caseActions) {
if (caseActions.length == 0)
throw new IllegalArgumentException("Not enough cases: " + Arrays.toString(caseActions));
MethodType expectedType = defaultCase.type();
if (!(expectedType.parameterCount() >= 1) || expectedType.parameterType(0) != int.class)
throw new IllegalArgumentException(
"Case actions must have int as leading parameter: " + Arrays.toString(caseActions));
for (MethodHandle mh : caseActions) {
Objects.requireNonNull(mh);
if (mh.type() != expectedType)
throw new IllegalArgumentException(
"Case actions must have the same type: " + Arrays.toString(caseActions));
}
return expectedType;
}
}

View File

@ -0,0 +1,234 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @run testng/othervm -Xverify:all TestTableSwitch
*/
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import javax.management.ObjectName;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntConsumer;
import java.util.function.IntFunction;
import static org.testng.Assert.assertEquals;
public class TestTableSwitch {
static final MethodHandle MH_IntConsumer_accept;
static final MethodHandle MH_check;
static {
try {
MethodHandles.Lookup lookup = MethodHandles.lookup();
MH_IntConsumer_accept = lookup.findVirtual(IntConsumer.class, "accept",
MethodType.methodType(void.class, int.class));
MH_check = lookup.findStatic(TestTableSwitch.class, "check",
MethodType.methodType(void.class, List.class, Object[].class));
} catch (ReflectiveOperationException e) {
throw new ExceptionInInitializerError(e);
}
}
public static MethodHandle simpleTestCase(String value) {
return simpleTestCase(String.class, value);
}
public static MethodHandle simpleTestCase(Class<?> type, Object value) {
return MethodHandles.dropArguments(MethodHandles.constant(type, value), 0, int.class);
}
public static Object testValue(Class<?> type) {
if (type == String.class) {
return "X";
} else if (type == byte.class) {
return (byte) 42;
} else if (type == short.class) {
return (short) 84;
} else if (type == char.class) {
return 'Y';
} else if (type == int.class) {
return 168;
} else if (type == long.class) {
return 336L;
} else if (type == float.class) {
return 42F;
} else if (type == double.class) {
return 84D;
} else if (type == boolean.class) {
return true;
}
return null;
}
static final Class<?>[] TEST_TYPES = {
Object.class,
String.class,
byte.class,
short.class,
char.class,
int.class,
long.class,
float.class,
double.class,
boolean.class
};
public static Object[] testArguments(int caseNum, List<Object> testValues) {
Object[] args = new Object[testValues.size() + 1];
args[0] = caseNum;
int insertPos = 1;
for (Object testValue : testValues) {
args[insertPos++] = testValue;
}
return args;
}
@DataProvider
public static Object[][] nonVoidCases() {
List<Object[]> tests = new ArrayList<>();
for (Class<?> returnType : TEST_TYPES) {
for (int numCases = 1; numCases < 5; numCases++) {
tests.add(new Object[] { returnType, numCases, List.of() });
tests.add(new Object[] { returnType, numCases, List.of(TEST_TYPES) });
}
}
return tests.toArray(Object[][]::new);
}
private static void check(List<Object> testValues, Object[] collectedValues) {
assertEquals(collectedValues, testValues.toArray());
}
@Test(dataProvider = "nonVoidCases")
public void testNonVoidHandles(Class<?> type, int numCases, List<Class<?>> additionalTypes) throws Throwable {
MethodHandle collector = MH_check;
List<Object> testArguments = new ArrayList<>();
collector = MethodHandles.insertArguments(collector, 0, testArguments);
collector = collector.asCollector(Object[].class, additionalTypes.size());
Object defaultReturnValue = testValue(type);
MethodHandle defaultCase = simpleTestCase(type, defaultReturnValue);
defaultCase = MethodHandles.collectArguments(defaultCase, 1, collector);
Object[] returnValues = new Object[numCases];
MethodHandle[] cases = new MethodHandle[numCases];
for (int i = 0; i < cases.length; i++) {
Object returnValue = testValue(type);
returnValues[i] = returnValue;
MethodHandle theCase = simpleTestCase(type, returnValue);
theCase = MethodHandles.collectArguments(theCase, 1, collector);
cases[i] = theCase;
}
MethodHandle mhSwitch = MethodHandles.tableSwitch(
defaultCase,
cases
);
for (Class<?> additionalType : additionalTypes) {
testArguments.add(testValue(additionalType));
}
assertEquals(mhSwitch.invokeWithArguments(testArguments(-1, testArguments)), defaultReturnValue);
for (int i = 0; i < numCases; i++) {
assertEquals(mhSwitch.invokeWithArguments(testArguments(i, testArguments)), returnValues[i]);
}
assertEquals(mhSwitch.invokeWithArguments(testArguments(numCases, testArguments)), defaultReturnValue);
}
@Test
public void testVoidHandles() throws Throwable {
IntFunction<MethodHandle> makeTestCase = expectedIndex -> {
IntConsumer test = actualIndex -> assertEquals(actualIndex, expectedIndex);
return MH_IntConsumer_accept.bindTo(test);
};
MethodHandle mhSwitch = MethodHandles.tableSwitch(
/* default: */ makeTestCase.apply(-1),
/* case 0: */ makeTestCase.apply(0),
/* case 1: */ makeTestCase.apply(1),
/* case 2: */ makeTestCase.apply(2)
);
mhSwitch.invokeExact((int) -1);
mhSwitch.invokeExact((int) 0);
mhSwitch.invokeExact((int) 1);
mhSwitch.invokeExact((int) 2);
}
@Test(expectedExceptions = NullPointerException.class)
public void testNullDefaultHandle() {
MethodHandles.tableSwitch(null, simpleTestCase("test"));
}
@Test(expectedExceptions = NullPointerException.class)
public void testNullCases() {
MethodHandle[] cases = null;
MethodHandles.tableSwitch(simpleTestCase("default"), cases);
}
@Test(expectedExceptions = NullPointerException.class)
public void testNullCase() {
MethodHandles.tableSwitch(simpleTestCase("default"), simpleTestCase("case"), null);
}
@Test(expectedExceptions = IllegalArgumentException.class,
expectedExceptionsMessageRegExp = ".*Not enough cases.*")
public void testNotEnoughCases() {
MethodHandles.tableSwitch(simpleTestCase("default"));
}
@Test(expectedExceptions = IllegalArgumentException.class,
expectedExceptionsMessageRegExp = ".*Case actions must have int as leading parameter.*")
public void testNotEnoughParameters() {
MethodHandle empty = MethodHandles.empty(MethodType.methodType(void.class));
MethodHandles.tableSwitch(empty, empty, empty);
}
@Test(expectedExceptions = IllegalArgumentException.class,
expectedExceptionsMessageRegExp = ".*Case actions must have int as leading parameter.*")
public void testNoLeadingIntParameter() {
MethodHandle empty = MethodHandles.empty(MethodType.methodType(void.class, double.class));
MethodHandles.tableSwitch(empty, empty, empty);
}
@Test(expectedExceptions = IllegalArgumentException.class,
expectedExceptionsMessageRegExp = ".*Case actions must have the same type.*")
public void testWrongCaseType() {
// doesn't return a String
MethodHandle wrongType = MethodHandles.empty(MethodType.methodType(void.class, int.class));
MethodHandles.tableSwitch(simpleTestCase("default"), simpleTestCase("case"), wrongType);
}
}

View File

@ -0,0 +1,135 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.java.lang.invoke;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.MutableCallSite;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
@BenchmarkMode(Mode.AverageTime)
@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@State(org.openjdk.jmh.annotations.Scope.Thread)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(3)
public class MethodHandlesTableSwitchConstant {
// Switch combinator test for a single constant input index
private static final MethodType callType = MethodType.methodType(int.class, int.class);
private static final MutableCallSite cs = new MutableCallSite(callType);
private static final MethodHandle target = cs.dynamicInvoker();
private static final MutableCallSite csInput = new MutableCallSite(MethodType.methodType(int.class));
private static final MethodHandle targetInput = csInput.dynamicInvoker();
private static final MethodHandle MH_SUBTRACT;
private static final MethodHandle MH_DEFAULT;
private static final MethodHandle MH_PAYLOAD;
static {
try {
MH_SUBTRACT = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchConstant.class, "subtract",
MethodType.methodType(int.class, int.class, int.class));
MH_DEFAULT = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchConstant.class, "defaultCase",
MethodType.methodType(int.class, int.class));
MH_PAYLOAD = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchConstant.class, "payload",
MethodType.methodType(int.class, int.class, int.class));
} catch (ReflectiveOperationException e) {
throw new ExceptionInInitializerError(e);
}
}
// Using batch size since we really need a per-invocation setup
// but the measured code is too fast. Using JMH batch size doesn't work
// since there is no way to do a batch-level setup as well.
private static final int BATCH_SIZE = 1_000_000;
@Param({
"5",
"10",
"25"
})
public int numCases;
@Param({
"0",
"150"
})
public int offset;
@Setup(Level.Trial)
public void setupTrial() throws Throwable {
MethodHandle[] cases = IntStream.range(0, numCases)
.mapToObj(i -> MethodHandles.insertArguments(MH_PAYLOAD, 1, i))
.toArray(MethodHandle[]::new);
MethodHandle switcher = MethodHandles.tableSwitch(MH_DEFAULT, cases);
if (offset != 0) {
switcher = MethodHandles.filterArguments(switcher, 0, MethodHandles.insertArguments(MH_SUBTRACT, 1, offset));
}
cs.setTarget(switcher);
int input = ThreadLocalRandom.current().nextInt(numCases) + offset;
csInput.setTarget(MethodHandles.constant(int.class, input));
}
private static int payload(int dropped, int constant) {
return constant;
}
private static int subtract(int a, int b) {
return a - b;
}
private static int defaultCase(int x) {
throw new IllegalStateException();
}
@Benchmark
public void testSwitch(Blackhole bh) throws Throwable {
for (int i = 0; i < BATCH_SIZE; i++) {
bh.consume((int) target.invokeExact((int) targetInput.invokeExact()));
}
}
}

View File

@ -0,0 +1,116 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.java.lang.invoke;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.MutableCallSite;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
@BenchmarkMode(Mode.AverageTime)
@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@State(org.openjdk.jmh.annotations.Scope.Thread)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(3)
public class MethodHandlesTableSwitchOpaqueSingle {
// Switch combinator test for a single input index, but opaquely fed in, so the JIT
// does not see it as a constant.
private static final MethodType callType = MethodType.methodType(int.class, int.class);
private static final MutableCallSite cs = new MutableCallSite(callType);
private static final MethodHandle target = cs.dynamicInvoker();
private static final MethodHandle MH_DEFAULT;
private static final MethodHandle MH_PAYLOAD;
static {
try {
MH_DEFAULT = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchOpaqueSingle.class, "defaultCase",
MethodType.methodType(int.class, int.class));
MH_PAYLOAD = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchOpaqueSingle.class, "payload",
MethodType.methodType(int.class, int.class, int.class));
} catch (ReflectiveOperationException e) {
throw new ExceptionInInitializerError(e);
}
}
// Using batch size since we really need a per-invocation setup
// but the measured code is too fast. Using JMH batch size doesn't work
// since there is no way to do a batch-level setup as well.
private static final int BATCH_SIZE = 1_000_000;
@Param({
"5",
"10",
"25"
})
public int numCases;
public int input;
@Setup(Level.Trial)
public void setupTrial() throws Throwable {
MethodHandle[] cases = IntStream.range(0, numCases)
.mapToObj(i -> MethodHandles.insertArguments(MH_PAYLOAD, 1, i))
.toArray(MethodHandle[]::new);
MethodHandle switcher = MethodHandles.tableSwitch(MH_DEFAULT, cases);
cs.setTarget(switcher);
input = ThreadLocalRandom.current().nextInt(numCases);
}
private static int payload(int dropped, int constant) {
return constant;
}
private static int defaultCase(int x) {
throw new IllegalStateException();
}
@Benchmark
public void testSwitch(Blackhole bh) throws Throwable {
for (int i = 0; i < BATCH_SIZE; i++) {
bh.consume((int) target.invokeExact(input));
}
}
}

View File

@ -0,0 +1,131 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.java.lang.invoke;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.MutableCallSite;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
@BenchmarkMode(Mode.AverageTime)
@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@State(org.openjdk.jmh.annotations.Scope.Thread)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(3)
public class MethodHandlesTableSwitchRandom {
// Switch combinator test for a random input index, testing several switch sizes
private static final MethodType callType = MethodType.methodType(int.class, int.class);
private static final MutableCallSite cs = new MutableCallSite(callType);
private static final MethodHandle target = cs.dynamicInvoker();
private static final MethodHandle MH_DEFAULT;
private static final MethodHandle MH_PAYLOAD;
static {
try {
MH_DEFAULT = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchRandom.class, "defaultCase",
MethodType.methodType(int.class, int.class));
MH_PAYLOAD = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchRandom.class, "payload",
MethodType.methodType(int.class, int.class, int.class));
} catch (ReflectiveOperationException e) {
throw new ExceptionInInitializerError(e);
}
}
// Using batch size since we really need a per-invocation setup
// but the measured code is too fast. Using JMH batch size doesn't work
// since there is no way to do a batch-level setup as well.
private static final int BATCH_SIZE = 1_000_000;
@Param({
"5",
"10",
"25"
})
public int numCases;
@Param({
"true",
"false"
})
public boolean sorted;
public int[] inputs;
@Setup(Level.Trial)
public void setupTrial() throws Throwable {
MethodHandle[] cases = IntStream.range(0, numCases)
.mapToObj(i -> MethodHandles.insertArguments(MH_PAYLOAD, 1, i))
.toArray(MethodHandle[]::new);
MethodHandle switcher = MethodHandles.tableSwitch(MH_DEFAULT, cases);
cs.setTarget(switcher);
inputs = new int[BATCH_SIZE];
Random rand = new Random(0);
for (int i = 0; i < BATCH_SIZE; i++) {
inputs[i] = rand.nextInt(numCases);
}
if (sorted) {
Arrays.sort(inputs);
}
}
private static int payload(int dropped, int constant) {
return constant;
}
private static int defaultCase(int x) {
throw new IllegalStateException();
}
@Benchmark
public void testSwitch(Blackhole bh) throws Throwable {
for (int i = 0; i < inputs.length; i++) {
bh.consume((int) target.invokeExact(inputs[i]));
}
}
}