diff --git a/src/java.base/share/classes/jdk/internal/foreign/abi/AbstractLinker.java b/src/java.base/share/classes/jdk/internal/foreign/abi/AbstractLinker.java index b1847791751..dea56564c6b 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/abi/AbstractLinker.java +++ b/src/java.base/share/classes/jdk/internal/foreign/abi/AbstractLinker.java @@ -96,6 +96,7 @@ public abstract sealed class AbstractLinker implements Linker permits LinuxAArch FunctionDescriptor fd = linkRequest.descriptor(); MethodType type = fd.toMethodType(); MethodHandle handle = arrangeDowncall(type, fd, linkRequest.options()); + handle = SharedUtils.maybeCheckCaptureSegment(handle, linkRequest.options()); handle = SharedUtils.maybeInsertAllocator(fd, handle); return handle; }); diff --git a/src/java.base/share/classes/jdk/internal/foreign/abi/SharedUtils.java b/src/java.base/share/classes/jdk/internal/foreign/abi/SharedUtils.java index b9f1de7ed64..1e417245543 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/abi/SharedUtils.java +++ b/src/java.base/share/classes/jdk/internal/foreign/abi/SharedUtils.java @@ -78,6 +78,7 @@ public final class SharedUtils { private static final MethodHandle MH_BUFFER_COPY; private static final MethodHandle MH_REACHABILITY_FENCE; public static final MethodHandle MH_CHECK_SYMBOL; + private static final MethodHandle MH_CHECK_CAPTURE_SEGMENT; public static final AddressLayout C_POINTER = ADDRESS .withTargetLayout(MemoryLayout.sequenceLayout(JAVA_BYTE)); @@ -110,6 +111,8 @@ public final class SharedUtils { methodType(void.class, Object.class)); MH_CHECK_SYMBOL = lookup.findStatic(SharedUtils.class, "checkSymbol", methodType(void.class, MemorySegment.class)); + MH_CHECK_CAPTURE_SEGMENT = lookup.findStatic(SharedUtils.class, "checkCaptureSegment", + methodType(MemorySegment.class, MemorySegment.class)); } catch (ReflectiveOperationException e) { throw new BootstrapMethodError(e); } @@ -343,6 +346,23 @@ public final class SharedUtils { return handle; } + public static MethodHandle maybeCheckCaptureSegment(MethodHandle handle, LinkerOptions options) { + if (options.hasCapturedCallState()) { + // (, SegmentAllocator, , ...) -> ... + handle = MethodHandles.filterArguments(handle, 2, MH_CHECK_CAPTURE_SEGMENT); + } + return handle; + } + + @ForceInline + public static MemorySegment checkCaptureSegment(MemorySegment captureSegment) { + Objects.requireNonNull(captureSegment); + if (captureSegment.equals(MemorySegment.NULL)) { + throw new IllegalArgumentException("Capture segment is NULL: " + captureSegment); + } + return captureSegment.asSlice(0, CapturableState.LAYOUT); + } + @ForceInline public static void checkSymbol(MemorySegment symbol) { Objects.requireNonNull(symbol); diff --git a/src/java.base/share/classes/jdk/internal/foreign/abi/fallback/FallbackLinker.java b/src/java.base/share/classes/jdk/internal/foreign/abi/fallback/FallbackLinker.java index 541557b47e7..1adb6a6da9d 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/abi/fallback/FallbackLinker.java +++ b/src/java.base/share/classes/jdk/internal/foreign/abi/fallback/FallbackLinker.java @@ -49,7 +49,6 @@ import java.util.List; import java.util.function.Consumer; import static java.lang.foreign.ValueLayout.ADDRESS; -import static java.lang.foreign.ValueLayout.JAVA_LONG; import static java.lang.invoke.MethodHandles.foldArguments; public final class FallbackLinker extends AbstractLinker { @@ -161,7 +160,7 @@ public final class FallbackLinker extends AbstractLinker { MemorySegment capturedState = null; if (invData.capturedStateMask() != 0) { - capturedState = (MemorySegment) args[argStart++]; + capturedState = SharedUtils.checkCaptureSegment((MemorySegment) args[argStart++]); MemorySessionImpl capturedStateImpl = ((AbstractMemorySegmentImpl) capturedState).sessionImpl(); capturedStateImpl.acquire0(); acquiredSessions.add(capturedStateImpl); diff --git a/test/jdk/java/foreign/capturecallstate/TestCaptureCallState.java b/test/jdk/java/foreign/capturecallstate/TestCaptureCallState.java index 1e3d16be2f9..9e4d8b686f2 100644 --- a/test/jdk/java/foreign/capturecallstate/TestCaptureCallState.java +++ b/test/jdk/java/foreign/capturecallstate/TestCaptureCallState.java @@ -50,6 +50,7 @@ import static java.lang.foreign.ValueLayout.JAVA_DOUBLE; import static java.lang.foreign.ValueLayout.JAVA_INT; import static java.lang.foreign.ValueLayout.JAVA_LONG; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; public class TestCaptureCallState extends NativeTestHelper { @@ -85,6 +86,21 @@ public class TestCaptureCallState extends NativeTestHelper { } } + @Test(dataProvider = "invalidCaptureSegmentCases") + public void testInvalidCaptureSegment(MemorySegment captureSegment, + Class expectedExceptionType, String expectedExceptionMessage) { + Linker.Option stl = Linker.Option.captureCallState("errno"); + MethodHandle handle = downcallHandle("set_errno_V", FunctionDescriptor.ofVoid(C_INT), stl); + + try { + int testValue = 42; + handle.invoke(captureSegment, testValue); // should throw + } catch (Throwable t) { + assertTrue(expectedExceptionType.isInstance(t)); + assertTrue(t.getMessage().matches(expectedExceptionMessage)); + } + } + @DataProvider public static Object[][] cases() { List cases = new ArrayList<>(); @@ -128,4 +144,13 @@ public class TestCaptureCallState extends NativeTestHelper { return new SaveValuesCase("set_errno_" + name, FunctionDescriptor.of(layout, JAVA_INT), "errno", check); } + @DataProvider + public static Object[][] invalidCaptureSegmentCases() { + return new Object[][]{ + {Arena.ofAuto().allocate(1), IndexOutOfBoundsException.class, ".*Out of bound access on segment.*"}, + {MemorySegment.NULL, IllegalArgumentException.class, ".*Capture segment is NULL.*"}, + {Arena.ofAuto().allocate(Linker.Option.captureStateLayout().byteSize() + 3).asSlice(3), // misaligned + IllegalArgumentException.class, ".*Target offset incompatible with alignment constraints.*"}, + }; + } }