// Source: jdk-24/test/jdk/java/foreign/NativeTestHelper.java

/*
* Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
import java.lang.foreign.AddressLayout;
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.GroupLayout;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.PaddingLayout;
import java.lang.foreign.SegmentAllocator;
import java.lang.foreign.SequenceLayout;
import java.lang.foreign.StructLayout;
import java.lang.foreign.SymbolLookup;
import java.lang.foreign.UnionLayout;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.VarHandle;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.UnaryOperator;
import java.util.random.RandomGenerator;
import static java.lang.foreign.MemoryLayout.PathElement.groupElement;
import static java.lang.foreign.MemoryLayout.PathElement.sequenceElement;
/**
 * Shared helpers for the {@code java.lang.foreign} tests.
 *
 * <p>The class provides:
 * <ul>
 *   <li>C type layout aliases taken from the native linker's canonical layouts;</li>
 *   <li>lookup and linking of native symbols;</li>
 *   <li>{@code malloc}/{@code free} wrappers;</li>
 *   <li>generation of random test values together with checks for them;</li>
 *   <li>an upcall stub that captures its arguments.</li>
 * </ul>
 */
public class NativeTestHelper {

    /** {@code true} when the {@code os.name} system property starts with "Windows". */
    public static final boolean IS_WINDOWS = System.getProperty("os.name").startsWith("Windows");

    // Handle to saver(); makeArgSaverCB binds its trailing arguments per stub.
    private static final MethodHandle MH_SAVER;

    // Random source for the overloads that take no RandomGenerator. The seed is printed
    // (and can be overridden via a system property) so failing runs can be reproduced.
    private static final RandomGenerator DEFAULT_RANDOM;

    static {
        int seed = Integer.getInteger("NativeTestHelper.DEFAULT_RANDOM.seed", ThreadLocalRandom.current().nextInt());
        System.out.println("NativeTestHelper::DEFAULT_RANDOM.seed = " + seed);
        System.out.println("Re-run with '-DNativeTestHelper.DEFAULT_RANDOM.seed=" + seed + "' to reproduce");
        DEFAULT_RANDOM = new Random(seed);
        try {
            MH_SAVER = MethodHandles.lookup().findStatic(NativeTestHelper.class, "saver",
                    MethodType.methodType(Object.class, Object[].class, List.class, AtomicReference.class, SegmentAllocator.class, int.class));
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    /**
     * Returns whether {@code layout} is a value layout with an integral carrier.
     *
     * @param layout the layout to inspect
     * @return {@code true} if the carrier is byte, char, short, int or long
     */
    public static boolean isIntegral(MemoryLayout layout) {
        return layout instanceof ValueLayout valueLayout && isIntegral(valueLayout.carrier());
    }

    static boolean isIntegral(Class<?> clazz) {
        return clazz == byte.class || clazz == char.class || clazz == short.class
                || clazz == int.class || clazz == long.class;
    }

    /**
     * Returns whether {@code layout} is a value layout whose carrier is {@link MemorySegment}.
     *
     * @param layout the layout to inspect
     * @return {@code true} for address layouts
     */
    public static boolean isPointer(MemoryLayout layout) {
        return layout instanceof ValueLayout valueLayout && valueLayout.carrier() == MemorySegment.class;
    }

    public static final Linker LINKER = Linker.nativeLinker();

    // the constants below are useful aliases for C types. The type/carrier association is only valid for 64-bit platforms.

    /**
     * The layout for the {@code bool} C type
     */
    public static final ValueLayout.OfBoolean C_BOOL = (ValueLayout.OfBoolean) LINKER.canonicalLayouts().get("bool");
    /**
     * The layout for the {@code char} C type
     */
    public static final ValueLayout.OfByte C_CHAR = (ValueLayout.OfByte) LINKER.canonicalLayouts().get("char");
    /**
     * The layout for the {@code short} C type
     */
    public static final ValueLayout.OfShort C_SHORT = (ValueLayout.OfShort) LINKER.canonicalLayouts().get("short");
    /**
     * The layout for the {@code int} C type
     */
    public static final ValueLayout.OfInt C_INT = (ValueLayout.OfInt) LINKER.canonicalLayouts().get("int");
    /**
     * The layout for the {@code long long} C type.
     */
    public static final ValueLayout.OfLong C_LONG_LONG = (ValueLayout.OfLong) LINKER.canonicalLayouts().get("long long");
    /**
     * The layout for the {@code float} C type
     */
    public static final ValueLayout.OfFloat C_FLOAT = (ValueLayout.OfFloat) LINKER.canonicalLayouts().get("float");
    /**
     * The layout for the {@code double} C type
     */
    public static final ValueLayout.OfDouble C_DOUBLE = (ValueLayout.OfDouble) LINKER.canonicalLayouts().get("double");
    /**
     * The {@code T*} native type. Its target layout is an unbounded {@code char} sequence,
     * so dereferenced pointers can be accessed without resizing.
     */
    public static final AddressLayout C_POINTER = ((AddressLayout) LINKER.canonicalLayouts().get("void*"))
            .withTargetLayout(MemoryLayout.sequenceLayout(Long.MAX_VALUE, C_CHAR));
    /**
     * The layout for the {@code size_t} C type
     */
    public static final ValueLayout C_SIZE_T = (ValueLayout) LINKER.canonicalLayouts().get("size_t");

    // Common layout shared by some tests
    // struct S_PDI { void* p0; double p1; int p2; };
    public static final MemoryLayout S_PDI_LAYOUT = switch ((int) ValueLayout.ADDRESS.byteSize()) {
        case 8 -> MemoryLayout.structLayout(
                C_POINTER.withName("p0"),
                C_DOUBLE.withName("p1"),
                C_INT.withName("p2"),
                MemoryLayout.paddingLayout(4));
        case 4 -> MemoryLayout.structLayout(
                C_POINTER.withName("p0"),
                C_DOUBLE.withName("p1"),
                C_INT.withName("p2"));
        default -> throw new UnsupportedOperationException("Unsupported address size");
    };

    private static final MethodHandle FREE = LINKER.downcallHandle(
            LINKER.defaultLookup().find("free").orElseThrow(), FunctionDescriptor.ofVoid(C_POINTER));

    // NOTE(review): size is passed as 'long long', which matches size_t only on 64-bit platforms.
    private static final MethodHandle MALLOC = LINKER.downcallHandle(
            LINKER.defaultLookup().find("malloc").orElseThrow(), FunctionDescriptor.of(C_POINTER, C_LONG_LONG));

    /**
     * Calls the native {@code free} on {@code address}.
     *
     * @param address the memory to release
     * @throws IllegalStateException wrapping anything thrown by the downcall
     */
    public static void freeMemory(MemorySegment address) {
        try {
            FREE.invokeExact(address);
        } catch (Throwable ex) {
            throw new IllegalStateException(ex);
        }
    }

    /**
     * Calls the native {@code malloc} with {@code size}.
     *
     * @param size number of bytes to request
     * @return the pointer returned by {@code malloc}
     * @throws IllegalStateException wrapping anything thrown by the downcall
     */
    public static MemorySegment allocateMemory(long size) {
        try {
            return (MemorySegment) MALLOC.invokeExact(size);
        } catch (Throwable ex) {
            throw new IllegalStateException(ex);
        }
    }

    /**
     * Looks up {@code name} in the symbols of libraries loaded by this class's loader.
     *
     * @param name the native symbol name
     * @return the symbol's address
     * @throws NoSuchElementException if the symbol is not found; the message names the symbol
     */
    public static MemorySegment findNativeOrThrow(String name) {
        return SymbolLookup.loaderLookup().find(name)
                .orElseThrow(() -> new NoSuchElementException("Native symbol not found: " + name));
    }

    /**
     * Creates a downcall handle for the loader-visible native symbol {@code symbol}.
     *
     * @param symbol  the native symbol name
     * @param desc    the function descriptor of the target
     * @param options linker options forwarded to {@link Linker#downcallHandle}
     * @return the downcall method handle
     */
    public static MethodHandle downcallHandle(String symbol, FunctionDescriptor desc, Linker.Option... options) {
        return LINKER.downcallHandle(findNativeOrThrow(symbol), desc, options);
    }

    /**
     * Creates an upcall stub for a static method, allocated in an automatic arena.
     *
     * @see #upcallStub(Class, String, FunctionDescriptor, Arena)
     */
    public static MemorySegment upcallStub(Class<?> holder, String name, FunctionDescriptor descriptor) {
        return upcallStub(holder, name, descriptor, Arena.ofAuto());
    }

    /**
     * Creates an upcall stub for the static method {@code holder.name}, whose type must equal
     * {@code descriptor.toMethodType()} and which must be accessible to this class's lookup.
     *
     * @param holder     class declaring the target method
     * @param name       name of the target method
     * @param descriptor function descriptor of the upcall
     * @param arena      arena that owns the stub
     * @return the upcall stub
     * @throws RuntimeException wrapping a failed method lookup
     */
    public static MemorySegment upcallStub(Class<?> holder, String name, FunctionDescriptor descriptor, Arena arena) {
        try {
            MethodHandle target = MethodHandles.lookup().findStatic(holder, name, descriptor.toMethodType());
            return LINKER.upcallStub(target, descriptor, arena);
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Generates one random test value per argument layout of {@code descriptor}, using the
     * default (seeded) random source.
     */
    public static TestValue[] genTestArgs(FunctionDescriptor descriptor, SegmentAllocator allocator) {
        return genTestArgs(DEFAULT_RANDOM, descriptor, allocator);
    }

    /**
     * Generates one random test value per argument layout of {@code descriptor}.
     *
     * @param random     source of randomness
     * @param descriptor descriptor whose argument layouts drive generation
     * @param allocator  allocator for struct/union/sequence values
     * @return the generated values, in argument order
     */
    public static TestValue[] genTestArgs(RandomGenerator random, FunctionDescriptor descriptor, SegmentAllocator allocator) {
        List<MemoryLayout> argLayouts = descriptor.argumentLayouts();
        TestValue[] result = new TestValue[argLayouts.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = genTestValue(random, argLayouts.get(i), allocator);
        }
        return result;
    }

    /**
     * A generated value paired with a check that validates an "actual" value against it.
     *
     * @param value the generated value (boxed primitive or {@link MemorySegment})
     * @param check consumer that throws {@link AssertionError} on mismatch
     */
    public record TestValue (Object value, Consumer<Object> check) {
        /** Runs the check against {@code actual}. */
        public void check(Object actual) { check.accept(actual); }
    }

    /**
     * Generates a random value for {@code layout} using the default (seeded) random source.
     */
    public static TestValue genTestValue(MemoryLayout layout, SegmentAllocator allocator) {
        return genTestValue(DEFAULT_RANDOM, layout, allocator);
    }

    /**
     * Generates a random value for {@code layout} together with a check for it.
     *
     * <p>Struct values get every named non-padding member initialized. Union values get one
     * randomly chosen named non-padding member initialized. Sequence values get every element
     * initialized.
     *
     * @param random    source of randomness
     * @param layout    layout describing the value to generate
     * @param allocator allocator for group and sequence values
     * @return the generated value and its check
     * @throws IllegalArgumentException if a union layout has no non-padding members
     * @throws IllegalStateException    if the layout kind is not supported
     */
    public static TestValue genTestValue(RandomGenerator random, MemoryLayout layout, SegmentAllocator allocator) {
        if (layout instanceof StructLayout struct) {
            MemorySegment segment = allocator.allocate(struct);
            List<Consumer<Object>> fieldChecks = new ArrayList<>();
            for (MemoryLayout fieldLayout : struct.memberLayouts()) {
                if (fieldLayout instanceof PaddingLayout) continue;
                MemoryLayout.PathElement fieldPath = groupElement(fieldLayout.name().orElseThrow());
                fieldChecks.add(initField(random, segment, struct, fieldLayout, fieldPath, allocator));
            }
            return new TestValue(segment, actual -> fieldChecks.forEach(check -> check.accept(actual)));
        } else if (layout instanceof UnionLayout union) {
            MemorySegment segment = allocator.allocate(union);
            List<MemoryLayout> filteredFields = union.memberLayouts().stream()
                    .filter(l -> !(l instanceof PaddingLayout))
                    .toList();
            if (filteredFields.isEmpty()) {
                throw new IllegalArgumentException("Union layout has no non-padding members: " + union);
            }
            int fieldIdx = random.nextInt(filteredFields.size());
            MemoryLayout fieldLayout = filteredFields.get(fieldIdx);
            MemoryLayout.PathElement fieldPath = groupElement(fieldLayout.name().orElseThrow());
            Consumer<Object> check = initField(random, segment, union, fieldLayout, fieldPath, allocator);
            return new TestValue(segment, check);
        } else if (layout instanceof SequenceLayout array) {
            MemorySegment segment = allocator.allocate(array);
            List<Consumer<Object>> elementChecks = new ArrayList<>();
            // elementCount() is a long; use a long index so the loop cannot overflow.
            for (long i = 0; i < array.elementCount(); i++) {
                elementChecks.add(initField(random, segment, array, array.elementLayout(), sequenceElement(i), allocator));
            }
            return new TestValue(segment, actual -> elementChecks.forEach(check -> check.accept(actual)));
        } else if (layout instanceof AddressLayout) {
            MemorySegment value = MemorySegment.ofAddress(random.nextLong());
            return new TestValue(value, actual -> assertEquals(actual, value));
        } else if (layout instanceof ValueLayout.OfBoolean) {
            // C_BOOL is exported by this class, so boolean values must be generatable too.
            boolean value = random.nextBoolean();
            return new TestValue(value, actual -> assertEquals(actual, value));
        } else if (layout instanceof ValueLayout.OfByte) {
            byte value = (byte) random.nextInt();
            return new TestValue(value, actual -> assertEquals(actual, value));
        } else if (layout instanceof ValueLayout.OfShort) {
            short value = (short) random.nextInt();
            return new TestValue(value, actual -> assertEquals(actual, value));
        } else if (layout instanceof ValueLayout.OfChar) {
            char value = (char) random.nextInt();
            return new TestValue(value, actual -> assertEquals(actual, value));
        } else if (layout instanceof ValueLayout.OfInt) {
            int value = random.nextInt();
            return new TestValue(value, actual -> assertEquals(actual, value));
        } else if (layout instanceof ValueLayout.OfLong) {
            long value = random.nextLong();
            return new TestValue(value, actual -> assertEquals(actual, value));
        } else if (layout instanceof ValueLayout.OfFloat) {
            float value = random.nextFloat();
            return new TestValue(value, actual -> assertEquals(actual, value));
        } else if (layout instanceof ValueLayout.OfDouble) {
            double value = random.nextDouble();
            return new TestValue(value, actual -> assertEquals(actual, value));
        }
        throw new IllegalStateException("Unexpected layout: " + layout);
    }

    /**
     * Generates a random value for the field at {@code fieldPath}, writes it into
     * {@code container}, and returns a check.
     *
     * <p>The check reads the same field back out of a segment laid out as
     * {@code containerLayout} and validates it. Group and sequence fields are copied into and
     * read back as slices. Value fields are written and read through a var handle at base
     * offset 0.
     */
    private static Consumer<Object> initField(RandomGenerator random, MemorySegment container, MemoryLayout containerLayout,
                                              MemoryLayout fieldLayout, MemoryLayout.PathElement fieldPath,
                                              SegmentAllocator allocator) {
        TestValue fieldValue = genTestValue(random, fieldLayout, allocator);
        Consumer<Object> fieldCheck = fieldValue.check();
        if (fieldLayout instanceof GroupLayout || fieldLayout instanceof SequenceLayout) {
            UnaryOperator<MemorySegment> slicer = slicer(containerLayout, fieldPath);
            MemorySegment slice = slicer.apply(container);
            slice.copyFrom((MemorySegment) fieldValue.value());
            return actual -> fieldCheck.accept(slicer.apply((MemorySegment) actual));
        } else {
            VarHandle accessor = containerLayout.varHandle(fieldPath);
            //set value
            accessor.set(container, 0L, fieldValue.value());
            return actual -> fieldCheck.accept(accessor.get((MemorySegment) actual, 0L));
        }
    }

    /**
     * Returns a function that slices a segment laid out as {@code containerLayout} down to
     * the element at {@code fieldPath}, at base offset 0.
     */
    private static UnaryOperator<MemorySegment> slicer(MemoryLayout containerLayout, MemoryLayout.PathElement fieldPath) {
        MethodHandle slicer = containerLayout.sliceHandle(fieldPath);
        return container -> {
            try {
                return (MemorySegment) slicer.invokeExact(container, 0L);
            } catch (Throwable e) {
                throw new IllegalStateException(e);
            }
        };
    }

    /**
     * Throws {@link AssertionError} unless {@code actual} and {@code expected} have the same
     * runtime class and are {@code equals}. Null values are compared by identity, so a missing
     * value fails with an assertion error rather than a {@link NullPointerException}.
     */
    private static void assertEquals(Object actual, Object expected) {
        if (actual == null || expected == null) {
            if (actual != expected) {
                throw new AssertionError("Not equal: " + actual + " != " + expected);
            }
            return;
        }
        if (actual.getClass() != expected.getClass()) {
            throw new AssertionError("Type mismatch: " + actual.getClass() + " != " + expected.getClass());
        }
        if (!actual.equals(expected)) {
            throw new AssertionError("Not equal: " + actual + " != " + expected);
        }
    }

    /**
     * Makes an upcall stub that, each time it is invoked, stores its argument array into
     * {@code capturedArgs}. Arguments whose layout is a {@link GroupLayout} are first copied
     * into memory allocated from {@code arena}, so they remain readable after the upcall returns.
     *
     * @param fd           function descriptor for the upcall stub
     * @param arena        arena owning the stub; also used to allocate copies of by-value group arguments
     * @param capturedArgs box that receives the (possibly copied) arguments of the latest call
     * @param retIdx       index of the argument to return from the upcall, or {@code -1} to return {@code null}
     * @return the upcall stub
     */
    public static MemorySegment makeArgSaverCB(FunctionDescriptor fd, Arena arena,
                                               AtomicReference<Object[]> capturedArgs, int retIdx) {
        MethodHandle target = MethodHandles.insertArguments(MH_SAVER, 1, fd.argumentLayouts(), capturedArgs, arena, retIdx);
        target = target.asCollector(Object[].class, fd.argumentLayouts().size());
        target = target.asType(fd.toMethodType());
        return LINKER.upcallStub(target, fd, arena);
    }

    // Target of MH_SAVER: copies group arguments via 'allocator', publishes the array to 'ref',
    // and returns o[retArg] (or null when retArg is -1). Signature must match MH_SAVER's lookup.
    private static Object saver(Object[] o, List<MemoryLayout> argLayouts, AtomicReference<Object[]> ref, SegmentAllocator allocator, int retArg) {
        for (int i = 0; i < o.length; i++) {
            if (argLayouts.get(i) instanceof GroupLayout gl) {
                MemorySegment ms = (MemorySegment) o[i];
                MemorySegment copy = allocator.allocate(gl);
                copy.copyFrom(ms);
                o[i] = copy;
            }
        }
        ref.set(o);
        return retArg != -1 ? o[retArg] : null;
    }
}