diff --git a/src/java.base/share/classes/java/lang/invoke/InvokerBytecodeGenerator.java b/src/java.base/share/classes/java/lang/invoke/InvokerBytecodeGenerator.java index 42d8ad0b663..0e1cac19720 100644 --- a/src/java.base/share/classes/java/lang/invoke/InvokerBytecodeGenerator.java +++ b/src/java.base/share/classes/java/lang/invoke/InvokerBytecodeGenerator.java @@ -864,6 +864,12 @@ class InvokerBytecodeGenerator { onStack = emitTryFinally(i); i += 2; // jump to the end of the TF idiom continue; + case TABLE_SWITCH: + assert lambdaForm.isTableSwitch(i); + int numCases = (Integer) name.function.intrinsicData(); + onStack = emitTableSwitch(i, numCases); + i += 2; // jump to the end of the TS idiom + continue; case LOOP: assert lambdaForm.isLoop(i); onStack = emitLoop(i); @@ -1389,6 +1395,58 @@ class InvokerBytecodeGenerator { } } + private Name emitTableSwitch(int pos, int numCases) { + Name args = lambdaForm.names[pos]; + Name invoker = lambdaForm.names[pos + 1]; + Name result = lambdaForm.names[pos + 2]; + + Class<?> returnType = result.function.resolvedHandle().type().returnType(); + MethodType caseType = args.function.resolvedHandle().type() + .dropParameterTypes(0, 1) // drop collector + .changeReturnType(returnType); + String caseDescriptor = caseType.basicType().toMethodDescriptorString(); + + emitPushArgument(invoker, 2); // push cases + mv.visitFieldInsn(Opcodes.GETFIELD, "java/lang/invoke/MethodHandleImpl$CasesHolder", "cases", + "[Ljava/lang/invoke/MethodHandle;"); + int casesLocal = extendLocalsMap(new Class<?>[] { MethodHandle[].class }); + emitStoreInsn(L_TYPE, casesLocal); + + Label endLabel = new Label(); + Label defaultLabel = new Label(); + Label[] caseLabels = new Label[numCases]; + for (int i = 0; i < caseLabels.length; i++) { + caseLabels[i] = new Label(); + } + + emitPushArgument(invoker, 0); // push switch input + mv.visitTableSwitchInsn(0, numCases - 1, defaultLabel, caseLabels); + + mv.visitLabel(defaultLabel); + emitPushArgument(invoker, 1); // push default handle + emitPushArguments(args, 1); // again, skip collector + mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, MH, "invokeBasic", caseDescriptor, false); + mv.visitJumpInsn(Opcodes.GOTO, endLabel); + + for (int i = 0; i < numCases; i++) { + mv.visitLabel(caseLabels[i]); + // Load the particular case: + emitLoadInsn(L_TYPE, casesLocal); + emitIconstInsn(i); + mv.visitInsn(Opcodes.AALOAD); + + // invoke it: + emitPushArguments(args, 1); // again, skip collector + mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, MH, "invokeBasic", caseDescriptor, false); + + mv.visitJumpInsn(Opcodes.GOTO, endLabel); + } + + mv.visitLabel(endLabel); + + return result; + } + /** * Emit bytecode for the loop idiom. * <p> diff --git a/src/java.base/share/classes/java/lang/invoke/LambdaForm.java b/src/java.base/share/classes/java/lang/invoke/LambdaForm.java index 639edec90dd..d0480cd7e0f 100644 --- a/src/java.base/share/classes/java/lang/invoke/LambdaForm.java +++ b/src/java.base/share/classes/java/lang/invoke/LambdaForm.java @@ -314,6 +314,7 @@ class LambdaForm { GET_DOUBLE_VOLATILE("getDoubleVolatile"), PUT_DOUBLE_VOLATILE("putDoubleVolatile"), TRY_FINALLY("tryFinally"), + TABLE_SWITCH("tableSwitch"), COLLECT("collect"), COLLECTOR("collector"), CONVERT("convert"), @@ -707,6 +708,32 @@ class LambdaForm { return isMatchingIdiom(pos, "tryFinally", 2); } + /** + * Check if i-th name is a start of the tableSwitch idiom. + */ + boolean isTableSwitch(int pos) { + // tableSwitch idiom: + // t_{n}:L=MethodHandle.invokeBasic(...) // args + // t_{n+1}:L=MethodHandleImpl.tableSwitch(*, *, *, t_{n}) + // t_{n+2}:?=MethodHandle.invokeBasic(*, t_{n+1}) + if (pos + 2 >= names.length) return false; + + final int POS_COLLECT_ARGS = pos; + final int POS_TABLE_SWITCH = pos + 1; + final int POS_UNBOX_RESULT = pos + 2; + + Name collectArgs = names[POS_COLLECT_ARGS]; + Name tableSwitch = names[POS_TABLE_SWITCH]; + Name unboxResult = names[POS_UNBOX_RESULT]; + return tableSwitch.refersTo(MethodHandleImpl.class, "tableSwitch") && + collectArgs.isInvokeBasic() && + unboxResult.isInvokeBasic() && + tableSwitch.lastUseIndex(collectArgs) == 3 && // t_{n+1}:L=MethodHandleImpl.<invoker>(*, *, *, t_{n}); + lastUseIndex(collectArgs) == POS_TABLE_SWITCH && // t_{n} is local: used only in t_{n+1} + unboxResult.lastUseIndex(tableSwitch) == 1 && // t_{n+2}:?=MethodHandle.invokeBasic(*, t_{n+1}) + lastUseIndex(tableSwitch) == POS_UNBOX_RESULT; // t_{n+1} is local: used only in t_{n+2} + } + /** * Check if i-th name is a start of the loop idiom. */ @@ -1067,24 +1094,13 @@ class LambdaForm { final MemberName member; private @Stable MethodHandle resolvedHandle; @Stable MethodHandle invoker; - private final MethodHandleImpl.Intrinsic intrinsicName; NamedFunction(MethodHandle resolvedHandle) { - this(resolvedHandle.internalMemberName(), resolvedHandle, MethodHandleImpl.Intrinsic.NONE); - } - NamedFunction(MethodHandle resolvedHandle, MethodHandleImpl.Intrinsic intrinsic) { - this(resolvedHandle.internalMemberName(), resolvedHandle, intrinsic); + this(resolvedHandle.internalMemberName(), resolvedHandle); } NamedFunction(MemberName member, MethodHandle resolvedHandle) { - this(member, resolvedHandle, MethodHandleImpl.Intrinsic.NONE); - } - NamedFunction(MemberName member, MethodHandle resolvedHandle, MethodHandleImpl.Intrinsic intrinsic) { this.member = member; this.resolvedHandle = resolvedHandle; - this.intrinsicName = intrinsic; - assert(resolvedHandle == null || - resolvedHandle.intrinsicName() == MethodHandleImpl.Intrinsic.NONE || - resolvedHandle.intrinsicName() == intrinsic) : resolvedHandle.intrinsicName() + " != " + intrinsic; // The following assert is almost always correct, but will fail for corner cases, such as PrivateInvokeTest. //assert(!isInvokeBasic(member)); } @@ -1097,7 +1113,6 @@ class LambdaForm { // necessary to pass BigArityTest this.member = Invokers.invokeBasicMethod(basicInvokerType); } - this.intrinsicName = MethodHandleImpl.Intrinsic.NONE; assert(isInvokeBasic(member)); } @@ -1250,7 +1265,15 @@ class LambdaForm { } public MethodHandleImpl.Intrinsic intrinsicName() { - return intrinsicName; + return resolvedHandle != null + ? resolvedHandle.intrinsicName() + : MethodHandleImpl.Intrinsic.NONE; + } + + public Object intrinsicData() { + return resolvedHandle != null + ? resolvedHandle.intrinsicData() + : null; } } @@ -1732,15 +1755,15 @@ class LambdaForm { Name[] idNames = new Name[] { argument(0, L_TYPE), argument(1, type) }; idForm = new LambdaForm(2, idNames, 1, Kind.IDENTITY); idForm.compileToBytecode(); - idFun = new NamedFunction(idMem, SimpleMethodHandle.make(idMem.getInvocationType(), idForm), - MethodHandleImpl.Intrinsic.IDENTITY); + idFun = new NamedFunction(idMem, MethodHandleImpl.makeIntrinsic(SimpleMethodHandle.make(idMem.getInvocationType(), idForm), + MethodHandleImpl.Intrinsic.IDENTITY)); Object zeValue = Wrapper.forBasicType(btChar).zero(); Name[] zeNames = new Name[] { argument(0, L_TYPE), new Name(idFun, zeValue) }; zeForm = new LambdaForm(1, zeNames, 1, Kind.ZERO); zeForm.compileToBytecode(); - zeFun = new NamedFunction(zeMem, SimpleMethodHandle.make(zeMem.getInvocationType(), zeForm), - MethodHandleImpl.Intrinsic.ZERO); + zeFun = new NamedFunction(zeMem, MethodHandleImpl.makeIntrinsic(SimpleMethodHandle.make(zeMem.getInvocationType(), zeForm), + MethodHandleImpl.Intrinsic.ZERO)); } LF_zero[ord] = zeForm; diff --git a/src/java.base/share/classes/java/lang/invoke/LambdaFormEditor.java b/src/java.base/share/classes/java/lang/invoke/LambdaFormEditor.java index 88ccfab9de8..ae241759b20 100644 --- a/src/java.base/share/classes/java/lang/invoke/LambdaFormEditor.java +++ b/src/java.base/share/classes/java/lang/invoke/LambdaFormEditor.java @@ -38,6 +38,7 @@ import static java.lang.invoke.LambdaForm.*; import static java.lang.invoke.LambdaForm.BasicType.*; import static java.lang.invoke.MethodHandleImpl.Intrinsic; import static java.lang.invoke.MethodHandleImpl.NF_loop; +import static java.lang.invoke.MethodHandleImpl.makeIntrinsic; /** Transforms on LFs. * A lambda-form editor can derive new LFs from its base LF. @@ -619,7 +620,7 @@ class LambdaFormEditor { // adjust the arguments MethodHandle aload = MethodHandles.arrayElementGetter(erasedArrayType); for (int i = 0; i < arrayLength; i++) { - Name loadArgument = new Name(new NamedFunction(aload, Intrinsic.ARRAY_LOAD), spreadParam, i); + Name loadArgument = new Name(new NamedFunction(makeIntrinsic(aload, Intrinsic.ARRAY_LOAD)), spreadParam, i); buf.insertExpression(exprPos + i, loadArgument); buf.replaceParameterByCopy(pos + i, exprPos + i); } diff --git a/src/java.base/share/classes/java/lang/invoke/MethodHandle.java b/src/java.base/share/classes/java/lang/invoke/MethodHandle.java index 36864c23843..69453e24abc 100644 --- a/src/java.base/share/classes/java/lang/invoke/MethodHandle.java +++ b/src/java.base/share/classes/java/lang/invoke/MethodHandle.java @@ -1679,6 +1679,11 @@ assertEquals("[three, thee, tee]", asListFix.invoke((Object)argv).toString()); return MethodHandleImpl.Intrinsic.NONE; } + /*non-public*/ + Object intrinsicData() { + return null; + } + /*non-public*/ MethodHandle withInternalMemberName(MemberName member, boolean isInvokeSpecial) { if (member != null) { diff --git a/src/java.base/share/classes/java/lang/invoke/MethodHandleImpl.java b/src/java.base/share/classes/java/lang/invoke/MethodHandleImpl.java index 96fb8370cbd..846a55423d3 100644 --- a/src/java.base/share/classes/java/lang/invoke/MethodHandleImpl.java +++ b/src/java.base/share/classes/java/lang/invoke/MethodHandleImpl.java @@ -49,6 +49,8 @@ import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import java.util.stream.Stream; @@ -823,7 +825,9 @@ abstract class MethodHandleImpl { names[PROFILE] = new Name(getFunction(NF_profileBoolean), names[CALL_TEST], names[GET_COUNTERS]); } // call selectAlternative - names[SELECT_ALT] = new Name(new NamedFunction(getConstantHandle(MH_selectAlternative), Intrinsic.SELECT_ALTERNATIVE), names[TEST], names[GET_TARGET], names[GET_FALLBACK]); + names[SELECT_ALT] = new Name(new NamedFunction( + makeIntrinsic(getConstantHandle(MH_selectAlternative), Intrinsic.SELECT_ALTERNATIVE)), + names[TEST], names[GET_TARGET], names[GET_FALLBACK]); // call target or fallback invokeArgs[0] = names[SELECT_ALT]; @@ -894,7 +898,7 @@ abstract class MethodHandleImpl { Object[] args = new Object[invokeBasic.type().parameterCount()]; args[0] = names[GET_COLLECT_ARGS]; System.arraycopy(names, ARG_BASE, args, 1, ARG_LIMIT-ARG_BASE); - names[BOXED_ARGS] = new Name(new NamedFunction(invokeBasic, Intrinsic.GUARD_WITH_CATCH), args); + names[BOXED_ARGS] = new Name(new NamedFunction(makeIntrinsic(invokeBasic, Intrinsic.GUARD_WITH_CATCH)), args); // t_{i+1}:L=MethodHandleImpl.guardWithCatch(target:L,exType:L,catcher:L,t_{i}:L); Object[] gwcArgs = new Object[] {names[GET_TARGET], names[GET_CLASS], names[GET_CATCHER], names[BOXED_ARGS]}; @@ -1226,6 +1230,7 @@ abstract class MethodHandleImpl { SELECT_ALTERNATIVE, GUARD_WITH_CATCH, TRY_FINALLY, + TABLE_SWITCH, LOOP, ARRAY_LOAD, ARRAY_STORE, @@ -1240,11 +1245,17 @@ abstract class MethodHandleImpl { static final class IntrinsicMethodHandle extends DelegatingMethodHandle { private final MethodHandle target; private final Intrinsic intrinsicName; + private final Object intrinsicData; IntrinsicMethodHandle(MethodHandle target, Intrinsic intrinsicName) { + this(target, intrinsicName, null); + } + + IntrinsicMethodHandle(MethodHandle target, Intrinsic intrinsicName, Object intrinsicData) { super(target.type(), target); this.target = target; this.intrinsicName = intrinsicName; + this.intrinsicData = intrinsicData; } @Override @@ -1257,6 +1268,11 @@ abstract class MethodHandleImpl { return intrinsicName; } + @Override + Object intrinsicData() { + return intrinsicData; + } + @Override public MethodHandle asTypeUncached(MethodType newType) { // This MH is an alias for target, except for the intrinsic name @@ -1282,9 +1298,13 @@ abstract class MethodHandleImpl { } static MethodHandle makeIntrinsic(MethodHandle target, Intrinsic intrinsicName) { + return makeIntrinsic(target, intrinsicName, null); + } + + static MethodHandle makeIntrinsic(MethodHandle target, Intrinsic intrinsicName, Object intrinsicData) { if (intrinsicName == target.intrinsicName()) return target; - return new IntrinsicMethodHandle(target, intrinsicName); + return new IntrinsicMethodHandle(target, intrinsicName, intrinsicData); } static MethodHandle makeIntrinsic(MethodType type, LambdaForm form, Intrinsic intrinsicName) { @@ -1360,7 +1380,8 @@ abstract class MethodHandleImpl { NF_tryFinally = 3, NF_loop = 4, NF_profileBoolean = 5, - NF_LIMIT = 6; + NF_tableSwitch = 6, + NF_LIMIT = 7; private static final @Stable NamedFunction[] NFS = new NamedFunction[NF_LIMIT]; @@ -1394,6 +1415,9 @@ abstract class MethodHandleImpl { case NF_profileBoolean: return new NamedFunction(MethodHandleImpl.class .getDeclaredMethod("profileBoolean", boolean.class, int[].class)); + case NF_tableSwitch: + return new NamedFunction(MethodHandleImpl.class + .getDeclaredMethod("tableSwitch", int.class, MethodHandle.class, CasesHolder.class, Object[].class)); default: throw new InternalError("Undefined function: " + func); } @@ -1602,7 +1626,7 @@ abstract class MethodHandleImpl { Object[] args = new Object[invokeBasic.type().parameterCount()]; args[0] = names[GET_COLLECT_ARGS]; System.arraycopy(names, ARG_BASE, args, 1, ARG_LIMIT - ARG_BASE); - names[BOXED_ARGS] = new Name(new NamedFunction(invokeBasic, Intrinsic.LOOP), args); + names[BOXED_ARGS] = new Name(new NamedFunction(makeIntrinsic(invokeBasic, Intrinsic.LOOP)), args); // t_{i+1}:L=MethodHandleImpl.loop(localTypes:L,clauses:L,t_{i}:L); Object[] lArgs = @@ -1839,7 +1863,7 @@ abstract class MethodHandleImpl { Object[] args = new Object[invokeBasic.type().parameterCount()]; args[0] = names[GET_COLLECT_ARGS]; System.arraycopy(names, ARG_BASE, args, 1, ARG_LIMIT-ARG_BASE); - names[BOXED_ARGS] = new Name(new NamedFunction(invokeBasic, Intrinsic.TRY_FINALLY), args); + names[BOXED_ARGS] = new Name(new NamedFunction(makeIntrinsic(invokeBasic, Intrinsic.TRY_FINALLY)), args); // t_{i+1}:L=MethodHandleImpl.tryFinally(target:L,exType:L,catcher:L,t_{i}:L); Object[] tfArgs = new Object[] {names[GET_TARGET], names[GET_CLEANUP], names[BOXED_ARGS]}; @@ -1941,7 +1965,7 @@ abstract class MethodHandleImpl { storeNameCursor < STORE_ELEMENT_LIMIT; storeIndex++, storeNameCursor++, argCursor++){ - names[storeNameCursor] = new Name(new NamedFunction(storeFunc, Intrinsic.ARRAY_STORE), + names[storeNameCursor] = new Name(new NamedFunction(makeIntrinsic(storeFunc, Intrinsic.ARRAY_STORE)), names[CALL_NEW_ARRAY], storeIndex, names[argCursor]); } @@ -1952,6 +1976,141 @@ abstract class MethodHandleImpl { return lform; } + // use a wrapper because we need this array to be @Stable + static class CasesHolder { + @Stable + final MethodHandle[] cases; + + public CasesHolder(MethodHandle[] cases) { + this.cases = cases; + } + } + + static MethodHandle makeTableSwitch(MethodType type, MethodHandle defaultCase, MethodHandle[] caseActions) { + MethodType varargsType = type.changeReturnType(Object[].class); + MethodHandle collectArgs = varargsArray(type.parameterCount()).asType(varargsType); + + MethodHandle unboxResult = unboxResultHandle(type.returnType()); + + BoundMethodHandle.SpeciesData data = BoundMethodHandle.speciesData_LLLL(); + LambdaForm form = makeTableSwitchForm(type.basicType(), data, caseActions.length); + BoundMethodHandle mh; + CasesHolder caseHolder = new CasesHolder(caseActions); + try { + mh = (BoundMethodHandle) data.factory().invokeBasic(type, form, (Object) defaultCase, (Object) collectArgs, + (Object) unboxResult, (Object) caseHolder); + } catch (Throwable ex) { + throw uncaughtException(ex); + } + assert(mh.type() == type); + return mh; + } + + private static class TableSwitchCacheKey { + private static final Map<TableSwitchCacheKey, LambdaForm> CACHE = new ConcurrentHashMap<>(); + + private final MethodType basicType; + private final int numberOfCases; + + public TableSwitchCacheKey(MethodType basicType, int numberOfCases) { + this.basicType = basicType; + this.numberOfCases = numberOfCases; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TableSwitchCacheKey that = (TableSwitchCacheKey) o; + return numberOfCases == that.numberOfCases && Objects.equals(basicType, that.basicType); + } + @Override + public int hashCode() { + return Objects.hash(basicType, numberOfCases); + } + } + + private static LambdaForm makeTableSwitchForm(MethodType basicType, BoundMethodHandle.SpeciesData data, + int numCases) { + MethodType lambdaType = basicType.invokerType(); + + // We need to cache based on the basic type X number of cases, + // since the number of cases is used when generating bytecode. + // This also means that we can't use the cache in MethodTypeForm, + // which only uses the basic type as a key. + TableSwitchCacheKey key = new TableSwitchCacheKey(basicType, numCases); + LambdaForm lform = TableSwitchCacheKey.CACHE.get(key); + if (lform != null) { + return lform; + } + + final int THIS_MH = 0; + final int ARG_BASE = 1; // start of incoming arguments + final int ARG_LIMIT = ARG_BASE + basicType.parameterCount(); + final int ARG_SWITCH_ON = ARG_BASE; + assert ARG_SWITCH_ON < ARG_LIMIT; + + int nameCursor = ARG_LIMIT; + final int GET_COLLECT_ARGS = nameCursor++; + final int GET_DEFAULT_CASE = nameCursor++; + final int GET_UNBOX_RESULT = nameCursor++; + final int GET_CASES = nameCursor++; + final int BOXED_ARGS = nameCursor++; + final int TABLE_SWITCH = nameCursor++; + final int UNBOXED_RESULT = nameCursor++; + + int fieldCursor = 0; + final int FIELD_DEFAULT_CASE = fieldCursor++; + final int FIELD_COLLECT_ARGS = fieldCursor++; + final int FIELD_UNBOX_RESULT = fieldCursor++; + final int FIELD_CASES = fieldCursor++; + + Name[] names = arguments(nameCursor - ARG_LIMIT, lambdaType); + + names[THIS_MH] = names[THIS_MH].withConstraint(data); + names[GET_DEFAULT_CASE] = new Name(data.getterFunction(FIELD_DEFAULT_CASE), names[THIS_MH]); + names[GET_COLLECT_ARGS] = new Name(data.getterFunction(FIELD_COLLECT_ARGS), names[THIS_MH]); + names[GET_UNBOX_RESULT] = new Name(data.getterFunction(FIELD_UNBOX_RESULT), names[THIS_MH]); + names[GET_CASES] = new Name(data.getterFunction(FIELD_CASES), names[THIS_MH]); + + { + MethodType collectArgsType = basicType.changeReturnType(Object.class); + MethodHandle invokeBasic = MethodHandles.basicInvoker(collectArgsType); + Object[] args = new Object[invokeBasic.type().parameterCount()]; + args[0] = names[GET_COLLECT_ARGS]; + System.arraycopy(names, ARG_BASE, args, 1, ARG_LIMIT - ARG_BASE); + names[BOXED_ARGS] = new Name(new NamedFunction(makeIntrinsic(invokeBasic, Intrinsic.TABLE_SWITCH, numCases)), args); + } + + { + Object[] tfArgs = new Object[]{ + names[ARG_SWITCH_ON], names[GET_DEFAULT_CASE], names[GET_CASES], names[BOXED_ARGS]}; + names[TABLE_SWITCH] = new Name(getFunction(NF_tableSwitch), tfArgs); + } + + { + MethodHandle invokeBasic = MethodHandles.basicInvoker(MethodType.methodType(basicType.rtype(), Object.class)); + Object[] unboxArgs = new Object[]{names[GET_UNBOX_RESULT], names[TABLE_SWITCH]}; + names[UNBOXED_RESULT] = new Name(invokeBasic, unboxArgs); + } + + lform = new LambdaForm(lambdaType.parameterCount(), names, Kind.TABLE_SWITCH); + LambdaForm prev = TableSwitchCacheKey.CACHE.putIfAbsent(key, lform); + return prev != null ? prev : lform; + } + + @Hidden + static Object tableSwitch(int input, MethodHandle defaultCase, CasesHolder holder, Object[] args) throws Throwable { + MethodHandle[] caseActions = holder.cases; + MethodHandle selectedCase; + if (input < 0 || input >= caseActions.length) { + selectedCase = defaultCase; + } else { + selectedCase = caseActions[input]; + } + return selectedCase.invokeWithArguments(args); + } + // Indexes into constant method handles: static final int MH_cast = 0, diff --git a/src/java.base/share/classes/java/lang/invoke/MethodHandles.java b/src/java.base/share/classes/java/lang/invoke/MethodHandles.java index a460ae7f262..6e4f33affda 100644 --- a/src/java.base/share/classes/java/lang/invoke/MethodHandles.java +++ b/src/java.base/share/classes/java/lang/invoke/MethodHandles.java @@ -7751,4 +7751,90 @@ assertEquals("boojum", (String) catTrace.invokeExact("boo", "jum")); } } + /** + * Creates a table switch method handle, which can be used to switch over a set of target + * method handles, based on a given target index, called selector. + * <p> + * For a selector value of {@code n}, where {@code n} falls in the range {@code [0, N)}, + * and where {@code N} is the number of target method handles, the table switch method + * handle will invoke the n-th target method handle from the list of target method handles. + * <p> + * For a selector value that does not fall in the range {@code [0, N)}, the table switch + * method handle will invoke the given fallback method handle. + * <p> + * All method handles passed to this method must have the same type, with the additional + * requirement that the leading parameter be of type {@code int}. The leading parameter + * represents the selector. + * <p> + * Any trailing parameters present in the type will appear on the returned table switch + * method handle as well. Any arguments assigned to these parameters will be forwarded, + * together with the selector value, to the selected method handle when invoking it. + * + * @apiNote Example: + * The cases each drop the {@code selector} value they are given, and take an additional + * {@code String} argument, which is concatenated (using {@link String#concat(String)}) + * to a specific constant label string for each case: + * <blockquote><pre>{@code + * MethodHandles.Lookup lookup = MethodHandles.lookup(); + * MethodHandle caseMh = lookup.findVirtual(String.class, "concat", + * MethodType.methodType(String.class, String.class)); + * caseMh = MethodHandles.dropArguments(caseMh, 0, int.class); + * + * MethodHandle caseDefault = MethodHandles.insertArguments(caseMh, 1, "default: "); + * MethodHandle case0 = MethodHandles.insertArguments(caseMh, 1, "case 0: "); + * MethodHandle case1 = MethodHandles.insertArguments(caseMh, 1, "case 1: "); + * + * MethodHandle mhSwitch = MethodHandles.tableSwitch( + * caseDefault, + * case0, + * case1 + * ); + * + * assertEquals("default: data", (String) mhSwitch.invokeExact(-1, "data")); + * assertEquals("case 0: data", (String) mhSwitch.invokeExact(0, "data")); + * assertEquals("case 1: data", (String) mhSwitch.invokeExact(1, "data")); + * assertEquals("default: data", (String) mhSwitch.invokeExact(2, "data")); + * }</pre></blockquote> + * + * @param fallback the fallback method handle that is called when the selector is not + * within the range {@code [0, N)}. + * @param targets array of target method handles. + * @return the table switch method handle. + * @throws NullPointerException if {@code fallback}, the {@code targets} array, or any + * any of the elements of the {@code targets} array are + * {@code null}. + * @throws IllegalArgumentException if the {@code targets} array is empty, if the leading + * parameter of the fallback handle or any of the target + * handles is not {@code int}, or if the types of + * the fallback handle and all of target handles are + * not the same. + */ + public static MethodHandle tableSwitch(MethodHandle fallback, MethodHandle... targets) { + Objects.requireNonNull(fallback); + Objects.requireNonNull(targets); + targets = targets.clone(); + MethodType type = tableSwitchChecks(fallback, targets); + return MethodHandleImpl.makeTableSwitch(type, fallback, targets); + } + + private static MethodType tableSwitchChecks(MethodHandle defaultCase, MethodHandle[] caseActions) { + if (caseActions.length == 0) + throw new IllegalArgumentException("Not enough cases: " + Arrays.toString(caseActions)); + + MethodType expectedType = defaultCase.type(); + + if (!(expectedType.parameterCount() >= 1) || expectedType.parameterType(0) != int.class) + throw new IllegalArgumentException( + "Case actions must have int as leading parameter: " + Arrays.toString(caseActions)); + + for (MethodHandle mh : caseActions) { + Objects.requireNonNull(mh); + if (mh.type() != expectedType) + throw new IllegalArgumentException( + "Case actions must have the same type: " + Arrays.toString(caseActions)); + } + + return expectedType; + } + } diff --git a/test/jdk/java/lang/invoke/MethodHandles/TestTableSwitch.java b/test/jdk/java/lang/invoke/MethodHandles/TestTableSwitch.java new file mode 100644 index 00000000000..4ff1a99cc28 --- /dev/null +++ b/test/jdk/java/lang/invoke/MethodHandles/TestTableSwitch.java @@ -0,0 +1,234 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @run testng/othervm -Xverify:all TestTableSwitch + */ + +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import javax.management.ObjectName; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.util.ArrayList; +import java.util.List; +import java.util.function.IntConsumer; +import java.util.function.IntFunction; + +import static org.testng.Assert.assertEquals; + +public class TestTableSwitch { + + static final MethodHandle MH_IntConsumer_accept; + static final MethodHandle MH_check; + + static { + try { + MethodHandles.Lookup lookup = MethodHandles.lookup(); + MH_IntConsumer_accept = lookup.findVirtual(IntConsumer.class, "accept", + MethodType.methodType(void.class, int.class)); + MH_check = lookup.findStatic(TestTableSwitch.class, "check", + MethodType.methodType(void.class, List.class, Object[].class)); + } catch (ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } + } + + public static MethodHandle simpleTestCase(String value) { + return simpleTestCase(String.class, value); + } + + public static MethodHandle simpleTestCase(Class<?> type, Object value) { + return MethodHandles.dropArguments(MethodHandles.constant(type, value), 0, int.class); + } + + public static Object testValue(Class<?> type) { + if (type == String.class) { + return "X"; + } else if (type == byte.class) { + return (byte) 42; + } else if (type == short.class) { + return (short) 84; + } else if (type == char.class) { + return 'Y'; + } else if (type == int.class) { + return 168; + } else if (type == long.class) { + return 336L; + } else if (type == float.class) { + return 42F; + } else if (type == double.class) { + return 84D; + } else if (type == boolean.class) { + return true; + } + return null; + } + + static final Class<?>[] TEST_TYPES = { + Object.class, + String.class, + byte.class, + short.class, + char.class, + int.class, + long.class, + float.class, + double.class, + boolean.class + }; + + public static Object[] testArguments(int caseNum, List<Object> testValues) { + Object[] args = new Object[testValues.size() + 1]; + args[0] = caseNum; + int insertPos = 1; + for (Object testValue : testValues) { + args[insertPos++] = testValue; + } + return args; + } + + @DataProvider + public static Object[][] nonVoidCases() { + List<Object[]> tests = new ArrayList<>(); + + for (Class<?> returnType : TEST_TYPES) { + for (int numCases = 1; numCases < 5; numCases++) { + tests.add(new Object[] { returnType, numCases, List.of() }); + tests.add(new Object[] { returnType, numCases, List.of(TEST_TYPES) }); + } + } + + return tests.toArray(Object[][]::new); + } + + private static void check(List<Object> testValues, Object[] collectedValues) { + assertEquals(collectedValues, testValues.toArray()); + } + + @Test(dataProvider = "nonVoidCases") + public void testNonVoidHandles(Class<?> type, int numCases, List<Class<?>> additionalTypes) throws Throwable { + MethodHandle collector = MH_check; + List<Object> testArguments = new ArrayList<>(); + collector = MethodHandles.insertArguments(collector, 0, testArguments); + collector = collector.asCollector(Object[].class, additionalTypes.size()); + + Object defaultReturnValue = testValue(type); + MethodHandle defaultCase = simpleTestCase(type, defaultReturnValue); + defaultCase = MethodHandles.collectArguments(defaultCase, 1, collector); + Object[] returnValues = new Object[numCases]; + MethodHandle[] cases = new MethodHandle[numCases]; + for (int i = 0; i < cases.length; i++) { + Object returnValue = testValue(type); + returnValues[i] = returnValue; + MethodHandle theCase = simpleTestCase(type, returnValue); + theCase = MethodHandles.collectArguments(theCase, 1, collector); + cases[i] = theCase; + } + + MethodHandle mhSwitch = MethodHandles.tableSwitch( + defaultCase, + cases + ); + + for (Class<?> additionalType : additionalTypes) { + testArguments.add(testValue(additionalType)); + } + + assertEquals(mhSwitch.invokeWithArguments(testArguments(-1, testArguments)), defaultReturnValue); + + for (int i = 0; i < numCases; i++) { + assertEquals(mhSwitch.invokeWithArguments(testArguments(i, testArguments)), returnValues[i]); + } + + assertEquals(mhSwitch.invokeWithArguments(testArguments(numCases, testArguments)), defaultReturnValue); + } + + @Test + public void testVoidHandles() throws Throwable { + IntFunction<MethodHandle> makeTestCase = expectedIndex -> { + IntConsumer test = actualIndex -> assertEquals(actualIndex, expectedIndex); + return MH_IntConsumer_accept.bindTo(test); + }; + + MethodHandle mhSwitch = MethodHandles.tableSwitch( + /* default: */ makeTestCase.apply(-1), + /* case 0: */ makeTestCase.apply(0), + /* case 1: */ makeTestCase.apply(1), + /* case 2: */ makeTestCase.apply(2) + ); + + mhSwitch.invokeExact((int) -1); + mhSwitch.invokeExact((int) 0); + mhSwitch.invokeExact((int) 1); + mhSwitch.invokeExact((int) 2); + } + + @Test(expectedExceptions = NullPointerException.class) + public void testNullDefaultHandle() { + MethodHandles.tableSwitch(null, simpleTestCase("test")); + } + + @Test(expectedExceptions = NullPointerException.class) + public void testNullCases() { + MethodHandle[] cases = null; + MethodHandles.tableSwitch(simpleTestCase("default"), cases); + } + + @Test(expectedExceptions = NullPointerException.class) + public void testNullCase() { + MethodHandles.tableSwitch(simpleTestCase("default"), simpleTestCase("case"), null); + } + + @Test(expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = ".*Not enough cases.*") + public void testNotEnoughCases() { + MethodHandles.tableSwitch(simpleTestCase("default")); + } + + @Test(expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = ".*Case actions must have int as leading parameter.*") + public void testNotEnoughParameters() { + MethodHandle empty = MethodHandles.empty(MethodType.methodType(void.class)); + MethodHandles.tableSwitch(empty, empty, empty); + } + + @Test(expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = ".*Case actions must have int as leading parameter.*") + public void testNoLeadingIntParameter() { + MethodHandle empty = MethodHandles.empty(MethodType.methodType(void.class, double.class)); + MethodHandles.tableSwitch(empty, empty, empty); + } + + @Test(expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = ".*Case actions must have the same type.*") + public void testWrongCaseType() { + // doesn't return a String + MethodHandle wrongType = MethodHandles.empty(MethodType.methodType(void.class, int.class)); + MethodHandles.tableSwitch(simpleTestCase("default"), simpleTestCase("case"), wrongType); + } + +} diff --git a/test/micro/org/openjdk/bench/java/lang/invoke/MethodHandlesTableSwitchConstant.java b/test/micro/org/openjdk/bench/java/lang/invoke/MethodHandlesTableSwitchConstant.java new file mode 100644 index 00000000000..dca928ef61b --- /dev/null +++ b/test/micro/org/openjdk/bench/java/lang/invoke/MethodHandlesTableSwitchConstant.java @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package org.openjdk.bench.java.lang.invoke; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.lang.invoke.MutableCallSite; +import java.util.Random; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +@BenchmarkMode(Mode.AverageTime) +@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@State(org.openjdk.jmh.annotations.Scope.Thread) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Fork(3) +public class MethodHandlesTableSwitchConstant { + + // Switch combinator test for a single constant input index + + private static final MethodType callType = MethodType.methodType(int.class, int.class); + + private static final MutableCallSite cs = new MutableCallSite(callType); + private static final MethodHandle target = cs.dynamicInvoker(); + + private static final MutableCallSite csInput = new MutableCallSite(MethodType.methodType(int.class)); + private static final MethodHandle targetInput = csInput.dynamicInvoker(); + + private static final MethodHandle MH_SUBTRACT; + private static final MethodHandle MH_DEFAULT; + private static final MethodHandle MH_PAYLOAD; + + static { + try { + MH_SUBTRACT = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchConstant.class, "subtract", + MethodType.methodType(int.class, int.class, int.class)); + MH_DEFAULT = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchConstant.class, "defaultCase", + MethodType.methodType(int.class, int.class)); + MH_PAYLOAD = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchConstant.class, "payload", + MethodType.methodType(int.class, int.class, int.class)); + } catch (ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } + } + + // Using batch size since we really need a per-invocation setup + // but the measured code is too fast. Using JMH batch size doesn't work + // since there is no way to do a batch-level setup as well. + private static final int BATCH_SIZE = 1_000_000; + + @Param({ + "5", + "10", + "25" + }) + public int numCases; + + + @Param({ + "0", + "150" + }) + public int offset; + + @Setup(Level.Trial) + public void setupTrial() throws Throwable { + MethodHandle[] cases = IntStream.range(0, numCases) + .mapToObj(i -> MethodHandles.insertArguments(MH_PAYLOAD, 1, i)) + .toArray(MethodHandle[]::new); + MethodHandle switcher = MethodHandles.tableSwitch(MH_DEFAULT, cases); + if (offset != 0) { + switcher = MethodHandles.filterArguments(switcher, 0, MethodHandles.insertArguments(MH_SUBTRACT, 1, offset)); + } + cs.setTarget(switcher); + + int input = ThreadLocalRandom.current().nextInt(numCases) + offset; + csInput.setTarget(MethodHandles.constant(int.class, input)); + } + + private static int payload(int dropped, int constant) { + return constant; + } + + private static int subtract(int a, int b) { + return a - b; + } + + private static int defaultCase(int x) { + throw new IllegalStateException(); + } + + @Benchmark + public void testSwitch(Blackhole bh) throws Throwable { + for (int i = 0; i < BATCH_SIZE; i++) { + bh.consume((int) target.invokeExact((int) targetInput.invokeExact())); + } + } + +} diff --git a/test/micro/org/openjdk/bench/java/lang/invoke/MethodHandlesTableSwitchOpaqueSingle.java b/test/micro/org/openjdk/bench/java/lang/invoke/MethodHandlesTableSwitchOpaqueSingle.java new file mode 100644 index 00000000000..b0d45e471a9 --- /dev/null +++ b/test/micro/org/openjdk/bench/java/lang/invoke/MethodHandlesTableSwitchOpaqueSingle.java @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package org.openjdk.bench.java.lang.invoke; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.lang.invoke.MutableCallSite; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +@BenchmarkMode(Mode.AverageTime) +@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@State(org.openjdk.jmh.annotations.Scope.Thread) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Fork(3) +public class MethodHandlesTableSwitchOpaqueSingle { + + // Switch combinator test for a single input index, but opaquely fed in, so the JIT + // does not see it as a constant. + + private static final MethodType callType = MethodType.methodType(int.class, int.class); + + private static final MutableCallSite cs = new MutableCallSite(callType); + private static final MethodHandle target = cs.dynamicInvoker(); + + private static final MethodHandle MH_DEFAULT; + private static final MethodHandle MH_PAYLOAD; + + static { + try { + MH_DEFAULT = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchOpaqueSingle.class, "defaultCase", + MethodType.methodType(int.class, int.class)); + MH_PAYLOAD = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchOpaqueSingle.class, "payload", + MethodType.methodType(int.class, int.class, int.class)); + } catch (ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } + } + + // Using batch size since we really need a per-invocation setup + // but the measured code is too fast. Using JMH batch size doesn't work + // since there is no way to do a batch-level setup as well. + private static final int BATCH_SIZE = 1_000_000; + + @Param({ + "5", + "10", + "25" + }) + public int numCases; + + public int input; + + @Setup(Level.Trial) + public void setupTrial() throws Throwable { + MethodHandle[] cases = IntStream.range(0, numCases) + .mapToObj(i -> MethodHandles.insertArguments(MH_PAYLOAD, 1, i)) + .toArray(MethodHandle[]::new); + MethodHandle switcher = MethodHandles.tableSwitch(MH_DEFAULT, cases); + cs.setTarget(switcher); + + input = ThreadLocalRandom.current().nextInt(numCases); + } + + private static int payload(int dropped, int constant) { + return constant; + } + + private static int defaultCase(int x) { + throw new IllegalStateException(); + } + + @Benchmark + public void testSwitch(Blackhole bh) throws Throwable { + for (int i = 0; i < BATCH_SIZE; i++) { + bh.consume((int) target.invokeExact(input)); + } + } + +} diff --git a/test/micro/org/openjdk/bench/java/lang/invoke/MethodHandlesTableSwitchRandom.java b/test/micro/org/openjdk/bench/java/lang/invoke/MethodHandlesTableSwitchRandom.java new file mode 100644 index 00000000000..94c21c1cc92 --- /dev/null +++ b/test/micro/org/openjdk/bench/java/lang/invoke/MethodHandlesTableSwitchRandom.java @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package org.openjdk.bench.java.lang.invoke; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.lang.invoke.MutableCallSite; +import java.util.Arrays; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +@BenchmarkMode(Mode.AverageTime) +@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@State(org.openjdk.jmh.annotations.Scope.Thread) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Fork(3) +public class MethodHandlesTableSwitchRandom { + + // Switch combinator test for a random input index, testing several switch sizes + + private static final MethodType callType = MethodType.methodType(int.class, int.class); + + private static final MutableCallSite cs = new MutableCallSite(callType); + private static final MethodHandle target = cs.dynamicInvoker(); + + private static final MethodHandle MH_DEFAULT; + private static final MethodHandle MH_PAYLOAD; + + static { + try { + MH_DEFAULT = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchRandom.class, "defaultCase", + MethodType.methodType(int.class, int.class)); + MH_PAYLOAD = MethodHandles.lookup().findStatic(MethodHandlesTableSwitchRandom.class, "payload", + MethodType.methodType(int.class, int.class, int.class)); + } catch (ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } + } + + // Using batch size since we really need a per-invocation setup + // but the measured code is too fast. Using JMH batch size doesn't work + // since there is no way to do a batch-level setup as well. + private static final int BATCH_SIZE = 1_000_000; + + @Param({ + "5", + "10", + "25" + }) + public int numCases; + + @Param({ + "true", + "false" + }) + public boolean sorted; + + public int[] inputs; + + @Setup(Level.Trial) + public void setupTrial() throws Throwable { + MethodHandle[] cases = IntStream.range(0, numCases) + .mapToObj(i -> MethodHandles.insertArguments(MH_PAYLOAD, 1, i)) + .toArray(MethodHandle[]::new); + MethodHandle switcher = MethodHandles.tableSwitch(MH_DEFAULT, cases); + + cs.setTarget(switcher); + + inputs = new int[BATCH_SIZE]; + Random rand = new Random(0); + for (int i = 0; i < BATCH_SIZE; i++) { + inputs[i] = rand.nextInt(numCases); + } + + if (sorted) { + Arrays.sort(inputs); + } + } + + private static int payload(int dropped, int constant) { + return constant; + } + + private static int defaultCase(int x) { + throw new IllegalStateException(); + } + + @Benchmark + public void testSwitch(Blackhole bh) throws Throwable { + for (int i = 0; i < inputs.length; i++) { + bh.consume((int) target.invokeExact(inputs[i])); + } + } + +}