465 lines
17 KiB
Java
465 lines
17 KiB
Java
|
/*
|
||
|
* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
|
||
|
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||
|
*
|
||
|
* This code is free software; you can redistribute it and/or modify it
|
||
|
* under the terms of the GNU General Public License version 2 only, as
|
||
|
* published by the Free Software Foundation.
|
||
|
*
|
||
|
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||
|
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||
|
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||
|
* version 2 for more details (a copy is included in the LICENSE file that
|
||
|
* accompanied this code).
|
||
|
*
|
||
|
* You should have received a copy of the GNU General Public License version
|
||
|
* 2 along with this work; if not, write to the Free Software Foundation,
|
||
|
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||
|
*
|
||
|
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||
|
* or visit www.oracle.com if you need additional information or have any
|
||
|
* questions.
|
||
|
*
|
||
|
*/
|
||
|
|
||
|
import jdk.incubator.foreign.GroupLayout;
|
||
|
import jdk.incubator.foreign.MemoryAddress;
|
||
|
import jdk.incubator.foreign.MemoryLayout;
|
||
|
import jdk.incubator.foreign.MemorySegment;
|
||
|
import jdk.incubator.foreign.ValueLayout;
|
||
|
|
||
|
import java.lang.invoke.VarHandle;
|
||
|
import java.util.ArrayList;
|
||
|
import java.util.List;
|
||
|
import java.util.Stack;
|
||
|
import java.util.function.Consumer;
|
||
|
import java.util.stream.Collectors;
|
||
|
import java.util.stream.IntStream;
|
||
|
|
||
|
import org.testng.annotations.*;
|
||
|
|
||
|
import static jdk.incubator.foreign.CLinker.*;
|
||
|
import static org.testng.Assert.*;
|
||
|
|
||
|
public class CallGeneratorHelper extends NativeTestHelper {
|
||
|
|
||
|
static final int MAX_FIELDS = 3;
|
||
|
static final int MAX_PARAMS = 3;
|
||
|
static final int CHUNK_SIZE = 600;
|
||
|
|
||
|
static int functions = 0;
|
||
|
|
||
|
public static void assertStructEquals(MemorySegment actual, MemorySegment expected, MemoryLayout layout) {
|
||
|
assertEquals(actual.byteSize(), expected.byteSize());
|
||
|
GroupLayout g = (GroupLayout) layout;
|
||
|
for (MemoryLayout field : g.memberLayouts()) {
|
||
|
if (field instanceof ValueLayout) {
|
||
|
VarHandle vh = g.varHandle(vhCarrier(field), MemoryLayout.PathElement.groupElement(field.name().orElseThrow()));
|
||
|
assertEquals(vh.get(actual), vh.get(expected));
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
private static Class<?> vhCarrier(MemoryLayout layout) {
|
||
|
if (layout instanceof ValueLayout) {
|
||
|
if (isIntegral(layout)) {
|
||
|
if (layout.bitSize() == 64) {
|
||
|
return long.class;
|
||
|
}
|
||
|
return int.class;
|
||
|
} else if (layout.bitSize() == 32) {
|
||
|
return float.class;
|
||
|
}
|
||
|
return double.class;
|
||
|
} else {
|
||
|
throw new IllegalStateException("Unexpected layout: " + layout);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
enum Ret {
|
||
|
VOID,
|
||
|
NON_VOID
|
||
|
}
|
||
|
|
||
|
enum StructFieldType {
|
||
|
INT("int", C_INT),
|
||
|
FLOAT("float", C_FLOAT),
|
||
|
DOUBLE("double", C_DOUBLE),
|
||
|
POINTER("void*", C_POINTER);
|
||
|
|
||
|
final String typeStr;
|
||
|
final MemoryLayout layout;
|
||
|
|
||
|
StructFieldType(String typeStr, MemoryLayout layout) {
|
||
|
this.typeStr = typeStr;
|
||
|
this.layout = layout;
|
||
|
}
|
||
|
|
||
|
MemoryLayout layout() {
|
||
|
return layout;
|
||
|
}
|
||
|
|
||
|
@SuppressWarnings("unchecked")
|
||
|
static List<List<StructFieldType>>[] perms = new List[10];
|
||
|
|
||
|
static List<List<StructFieldType>> perms(int i) {
|
||
|
if (perms[i] == null) {
|
||
|
perms[i] = generateTest(i, values());
|
||
|
}
|
||
|
return perms[i];
|
||
|
}
|
||
|
}
|
||
|
|
||
|
enum ParamType {
|
||
|
INT("int", C_INT),
|
||
|
FLOAT("float", C_FLOAT),
|
||
|
DOUBLE("double", C_DOUBLE),
|
||
|
POINTER("void*", C_POINTER),
|
||
|
STRUCT("struct S", null);
|
||
|
|
||
|
private final String typeStr;
|
||
|
private final MemoryLayout layout;
|
||
|
|
||
|
ParamType(String typeStr, MemoryLayout layout) {
|
||
|
this.typeStr = typeStr;
|
||
|
this.layout = layout;
|
||
|
}
|
||
|
|
||
|
String type(List<StructFieldType> fields) {
|
||
|
return this == STRUCT ?
|
||
|
typeStr + "_" + sigCode(fields) :
|
||
|
typeStr;
|
||
|
}
|
||
|
|
||
|
MemoryLayout layout(List<StructFieldType> fields) {
|
||
|
if (this == STRUCT) {
|
||
|
long offset = 0L;
|
||
|
List<MemoryLayout> layouts = new ArrayList<>();
|
||
|
for (StructFieldType field : fields) {
|
||
|
MemoryLayout l = field.layout();
|
||
|
long padding = offset % l.bitSize();
|
||
|
if (padding != 0) {
|
||
|
layouts.add(MemoryLayout.ofPaddingBits(padding));
|
||
|
offset += padding;
|
||
|
}
|
||
|
layouts.add(l.withName("field" + offset));
|
||
|
offset += l.bitSize();
|
||
|
}
|
||
|
return MemoryLayout.ofStruct(layouts.toArray(new MemoryLayout[0]));
|
||
|
} else {
|
||
|
return layout;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
@SuppressWarnings("unchecked")
|
||
|
static List<List<ParamType>>[] perms = new List[10];
|
||
|
|
||
|
static List<List<ParamType>> perms(int i) {
|
||
|
if (perms[i] == null) {
|
||
|
perms[i] = generateTest(i, values());
|
||
|
}
|
||
|
return perms[i];
|
||
|
}
|
||
|
}
|
||
|
|
||
|
static <Z> List<List<Z>> generateTest(int i, Z[] elems) {
|
||
|
List<List<Z>> res = new ArrayList<>();
|
||
|
generateTest(i, new Stack<>(), elems, res);
|
||
|
return res;
|
||
|
}
|
||
|
|
||
|
static <Z> void generateTest(int i, Stack<Z> combo, Z[] elems, List<List<Z>> results) {
|
||
|
if (i == 0) {
|
||
|
results.add(new ArrayList<>(combo));
|
||
|
} else {
|
||
|
for (Z z : elems) {
|
||
|
combo.push(z);
|
||
|
generateTest(i - 1, combo, elems, results);
|
||
|
combo.pop();
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
@DataProvider(name = "functions")
|
||
|
public static Object[][] functions() {
|
||
|
List<Object[]> downcalls = new ArrayList<>();
|
||
|
for (Ret r : Ret.values()) {
|
||
|
for (int i = 0; i <= MAX_PARAMS; i++) {
|
||
|
if (r != Ret.VOID && i == 0) continue;
|
||
|
for (List<ParamType> ptypes : ParamType.perms(i)) {
|
||
|
String retCode = r == Ret.VOID ? "V" : ptypes.get(0).name().charAt(0) + "";
|
||
|
String sigCode = sigCode(ptypes);
|
||
|
if (ptypes.contains(ParamType.STRUCT)) {
|
||
|
for (int j = 1; j <= MAX_FIELDS; j++) {
|
||
|
for (List<StructFieldType> fields : StructFieldType.perms(j)) {
|
||
|
String structCode = sigCode(fields);
|
||
|
int fCode = functions++ / CHUNK_SIZE;
|
||
|
String fName = String.format("f%d_%s_%s_%s", fCode, retCode, sigCode, structCode);
|
||
|
downcalls.add(new Object[] { fName, r, ptypes, fields });
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
String structCode = sigCode(List.<StructFieldType>of());
|
||
|
int fCode = functions++ / CHUNK_SIZE;
|
||
|
String fName = String.format("f%d_%s_%s_%s", fCode, retCode, sigCode, structCode);
|
||
|
downcalls.add(new Object[] { fName, r, ptypes, List.of() });
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return downcalls.toArray(new Object[0][]);
|
||
|
}
|
||
|
|
||
|
static <Z extends Enum<Z>> String sigCode(List<Z> elems) {
|
||
|
return elems.stream().map(p -> p.name().charAt(0) + "").collect(Collectors.joining());
|
||
|
}
|
||
|
|
||
|
static void generateStructDecl(List<StructFieldType> fields) {
|
||
|
String structCode = sigCode(fields);
|
||
|
List<String> fieldDecls = new ArrayList<>();
|
||
|
for (int i = 0 ; i < fields.size() ; i++) {
|
||
|
fieldDecls.add(String.format("%s p%d;", fields.get(i).typeStr, i));
|
||
|
}
|
||
|
String res = String.format("struct S_%s { %s };", structCode,
|
||
|
fieldDecls.stream().collect(Collectors.joining(" ")));
|
||
|
System.out.println(res);
|
||
|
}
|
||
|
|
||
|
/* this can be used to generate the test header/implementation */
|
||
|
public static void main(String[] args) {
|
||
|
boolean header = args.length > 0 && args[0].equals("header");
|
||
|
boolean upcall = args.length > 1 && args[1].equals("upcall");
|
||
|
if (upcall) {
|
||
|
generateUpcalls(header);
|
||
|
} else {
|
||
|
generateDowncalls(header);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
static void generateDowncalls(boolean header) {
|
||
|
if (header) {
|
||
|
System.out.println(
|
||
|
"#ifdef _WIN64\n" +
|
||
|
"#define EXPORT __declspec(dllexport)\n" +
|
||
|
"#else\n" +
|
||
|
"#define EXPORT\n" +
|
||
|
"#endif\n"
|
||
|
);
|
||
|
|
||
|
for (int j = 1; j <= MAX_FIELDS; j++) {
|
||
|
for (List<StructFieldType> fields : StructFieldType.perms(j)) {
|
||
|
generateStructDecl(fields);
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
System.out.println(
|
||
|
"#include \"libh\"\n" +
|
||
|
"#ifdef __clang__\n" +
|
||
|
"#pragma clang optimize off\n" +
|
||
|
"#elif defined __GNUC__\n" +
|
||
|
"#pragma GCC optimize (\"O0\")\n" +
|
||
|
"#elif defined _MSC_BUILD\n" +
|
||
|
"#pragma optimize( \"\", off )\n" +
|
||
|
"#endif\n"
|
||
|
);
|
||
|
}
|
||
|
|
||
|
for (Object[] downcall : functions()) {
|
||
|
String fName = (String)downcall[0];
|
||
|
Ret r = (Ret)downcall[1];
|
||
|
@SuppressWarnings("unchecked")
|
||
|
List<ParamType> ptypes = (List<ParamType>)downcall[2];
|
||
|
@SuppressWarnings("unchecked")
|
||
|
List<StructFieldType> fields = (List<StructFieldType>)downcall[3];
|
||
|
generateDowncallFunction(fName, r, ptypes, fields, header);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
static void generateDowncallFunction(String fName, Ret ret, List<ParamType> params, List<StructFieldType> fields, boolean declOnly) {
|
||
|
String retType = ret == Ret.VOID ? "void" : params.get(0).type(fields);
|
||
|
List<String> paramDecls = new ArrayList<>();
|
||
|
for (int i = 0 ; i < params.size() ; i++) {
|
||
|
paramDecls.add(String.format("%s p%d", params.get(i).type(fields), i));
|
||
|
}
|
||
|
String sig = paramDecls.isEmpty() ?
|
||
|
"void" :
|
||
|
paramDecls.stream().collect(Collectors.joining(", "));
|
||
|
String body = ret == Ret.VOID ? "{ }" : "{ return p0; }";
|
||
|
String res = String.format("EXPORT %s f%s(%s) %s", retType, fName,
|
||
|
sig, declOnly ? ";" : body);
|
||
|
System.out.println(res);
|
||
|
}
|
||
|
|
||
|
static void generateUpcalls(boolean header) {
|
||
|
if (header) {
|
||
|
System.out.println(
|
||
|
"#ifdef _WIN64\n" +
|
||
|
"#define EXPORT __declspec(dllexport)\n" +
|
||
|
"#else\n" +
|
||
|
"#define EXPORT\n" +
|
||
|
"#endif\n"
|
||
|
);
|
||
|
|
||
|
for (int j = 1; j <= MAX_FIELDS; j++) {
|
||
|
for (List<StructFieldType> fields : StructFieldType.perms(j)) {
|
||
|
generateStructDecl(fields);
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
System.out.println(
|
||
|
"#include \"libh\"\n" +
|
||
|
"#ifdef __clang__\n" +
|
||
|
"#pragma clang optimize off\n" +
|
||
|
"#elif defined __GNUC__\n" +
|
||
|
"#pragma GCC optimize (\"O0\")\n" +
|
||
|
"#elif defined _MSC_BUILD\n" +
|
||
|
"#pragma optimize( \"\", off )\n" +
|
||
|
"#endif\n"
|
||
|
);
|
||
|
}
|
||
|
|
||
|
for (Object[] downcall : functions()) {
|
||
|
String fName = (String)downcall[0];
|
||
|
Ret r = (Ret)downcall[1];
|
||
|
@SuppressWarnings("unchecked")
|
||
|
List<ParamType> ptypes = (List<ParamType>)downcall[2];
|
||
|
@SuppressWarnings("unchecked")
|
||
|
List<StructFieldType> fields = (List<StructFieldType>)downcall[3];
|
||
|
generateUpcallFunction(fName, r, ptypes, fields, header);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
static void generateUpcallFunction(String fName, Ret ret, List<ParamType> params, List<StructFieldType> fields, boolean declOnly) {
|
||
|
String retType = ret == Ret.VOID ? "void" : params.get(0).type(fields);
|
||
|
List<String> paramDecls = new ArrayList<>();
|
||
|
for (int i = 0 ; i < params.size() ; i++) {
|
||
|
paramDecls.add(String.format("%s p%d", params.get(i).type(fields), i));
|
||
|
}
|
||
|
String paramNames = IntStream.range(0, params.size())
|
||
|
.mapToObj(i -> "p" + i)
|
||
|
.collect(Collectors.joining(","));
|
||
|
String sig = paramDecls.isEmpty() ?
|
||
|
"" :
|
||
|
paramDecls.stream().collect(Collectors.joining(", ")) + ", ";
|
||
|
String body = String.format(ret == Ret.VOID ? "{ cb(%s); }" : "{ return cb(%s); }", paramNames);
|
||
|
List<String> paramTypes = params.stream().map(p -> p.type(fields)).collect(Collectors.toList());
|
||
|
String cbSig = paramTypes.isEmpty() ?
|
||
|
"void" :
|
||
|
paramTypes.stream().collect(Collectors.joining(", "));
|
||
|
String cbParam = String.format("%s (*cb)(%s)",
|
||
|
retType, cbSig);
|
||
|
|
||
|
String res = String.format("EXPORT %s %s(%s %s) %s", retType, fName,
|
||
|
sig, cbParam, declOnly ? ";" : body);
|
||
|
System.out.println(res);
|
||
|
}
|
||
|
|
||
|
//helper methods
|
||
|
|
||
|
@SuppressWarnings("unchecked")
|
||
|
static Object makeArg(MemoryLayout layout, List<Consumer<Object>> checks, boolean check, List<MemorySegment> segments) throws ReflectiveOperationException {
|
||
|
if (layout instanceof GroupLayout) {
|
||
|
MemorySegment segment = MemorySegment.allocateNative(layout);
|
||
|
initStruct(segment, (GroupLayout)layout, checks, check, segments);
|
||
|
segments.add(segment);
|
||
|
return segment;
|
||
|
} else if (isPointer(layout)) {
|
||
|
MemorySegment segment = MemorySegment.allocateNative(1);
|
||
|
segments.add(segment);
|
||
|
if (check) {
|
||
|
checks.add(o -> {
|
||
|
try {
|
||
|
assertEquals((MemoryAddress)o, segment.address());
|
||
|
} catch (Throwable ex) {
|
||
|
throw new IllegalStateException(ex);
|
||
|
}
|
||
|
});
|
||
|
}
|
||
|
return segment.address();
|
||
|
} else if (layout instanceof ValueLayout) {
|
||
|
if (isIntegral(layout)) {
|
||
|
if (check) {
|
||
|
checks.add(o -> assertEquals(o, 42));
|
||
|
}
|
||
|
return 42;
|
||
|
} else if (layout.bitSize() == 32) {
|
||
|
if (check) {
|
||
|
checks.add(o -> assertEquals(o, 12f));
|
||
|
}
|
||
|
return 12f;
|
||
|
} else {
|
||
|
if (check) {
|
||
|
checks.add(o -> assertEquals(o, 24d));
|
||
|
}
|
||
|
return 24d;
|
||
|
}
|
||
|
} else {
|
||
|
throw new IllegalStateException("Unexpected layout: " + layout);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
static void initStruct(MemorySegment str, GroupLayout g, List<Consumer<Object>> checks, boolean check, List<MemorySegment> segments) throws ReflectiveOperationException {
|
||
|
for (MemoryLayout l : g.memberLayouts()) {
|
||
|
if (l.isPadding()) continue;
|
||
|
VarHandle accessor = g.varHandle(structFieldCarrier(l), MemoryLayout.PathElement.groupElement(l.name().get()));
|
||
|
List<Consumer<Object>> fieldsCheck = new ArrayList<>();
|
||
|
Object value = makeArg(l, fieldsCheck, check, segments);
|
||
|
if (isPointer(l)) {
|
||
|
value = ((MemoryAddress)value).toRawLongValue();
|
||
|
}
|
||
|
//set value
|
||
|
accessor.set(str, value);
|
||
|
//add check
|
||
|
if (check) {
|
||
|
assertTrue(fieldsCheck.size() == 1);
|
||
|
checks.add(o -> {
|
||
|
MemorySegment actual = (MemorySegment)o;
|
||
|
try {
|
||
|
if (isPointer(l)) {
|
||
|
fieldsCheck.get(0).accept(MemoryAddress.ofLong((long)accessor.get(actual)));
|
||
|
} else {
|
||
|
fieldsCheck.get(0).accept(accessor.get(actual));
|
||
|
}
|
||
|
} catch (Throwable ex) {
|
||
|
throw new IllegalStateException(ex);
|
||
|
}
|
||
|
});
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
static Class<?> structFieldCarrier(MemoryLayout layout) {
|
||
|
if (isPointer(layout)) {
|
||
|
return long.class;
|
||
|
} else if (layout instanceof ValueLayout) {
|
||
|
if (isIntegral(layout)) {
|
||
|
return int.class;
|
||
|
} else if (layout.bitSize() == 32) {
|
||
|
return float.class;
|
||
|
} else {
|
||
|
return double.class;
|
||
|
}
|
||
|
} else {
|
||
|
throw new IllegalStateException("Unexpected layout: " + layout);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
static Class<?> paramCarrier(MemoryLayout layout) {
|
||
|
if (layout instanceof GroupLayout) {
|
||
|
return MemorySegment.class;
|
||
|
} if (isPointer(layout)) {
|
||
|
return MemoryAddress.class;
|
||
|
} else if (layout instanceof ValueLayout) {
|
||
|
if (isIntegral(layout)) {
|
||
|
return int.class;
|
||
|
} else if (layout.bitSize() == 32) {
|
||
|
return float.class;
|
||
|
} else {
|
||
|
return double.class;
|
||
|
}
|
||
|
} else {
|
||
|
throw new IllegalStateException("Unexpected layout: " + layout);
|
||
|
}
|
||
|
}
|
||
|
}
|