From 63493ed0f7ba8f836b1c4a210040fe1d69be9416 Mon Sep 17 00:00:00 2001 From: Daniel Holle Date: Fri, 19 Jul 2024 17:26:39 +0200 Subject: [PATCH] Make lambdas castable --- resources/bytecode/javFiles/LamRunnable.jav | 16 -- .../de/dhbwstuttgart/bytecode/Codegen.java | 230 ++++++++++++------ .../target/generate/ASTToTargetAST.java | 4 +- .../target/tree/type/TargetFunNType.java | 24 +- .../target/tree/type/TargetType.java | 24 +- src/test/java/TestComplete.java | 7 - src/test/java/targetast/TestCodegen.java | 2 +- 7 files changed, 193 insertions(+), 114 deletions(-) delete mode 100644 resources/bytecode/javFiles/LamRunnable.jav diff --git a/resources/bytecode/javFiles/LamRunnable.jav b/resources/bytecode/javFiles/LamRunnable.jav deleted file mode 100644 index d0da84cc..00000000 --- a/resources/bytecode/javFiles/LamRunnable.jav +++ /dev/null @@ -1,16 +0,0 @@ -import java.lang.Runnable; -import java.lang.System; -import java.lang.String; -import java.io.PrintStream; - -public class LamRunnable { - - public LamRunnable() { - Runnable lam = () -> { - System.out.println("lambda"); - }; - - lam.run(); - } -} - \ No newline at end of file diff --git a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java index 8df566b9..77d1460c 100644 --- a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java +++ b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java @@ -5,6 +5,8 @@ import de.dhbwstuttgart.exceptions.NotImplementedException; import de.dhbwstuttgart.parser.NullToken; import de.dhbwstuttgart.parser.scope.JavaClassName; import de.dhbwstuttgart.syntaxtree.ClassOrInterface; +import de.dhbwstuttgart.syntaxtree.Method; +import de.dhbwstuttgart.syntaxtree.Pattern; import de.dhbwstuttgart.syntaxtree.type.RefType; import de.dhbwstuttgart.target.generate.ASTToTargetAST; import de.dhbwstuttgart.target.generate.StatementToTargetExpression; @@ -31,15 +33,21 @@ public class Codegen { private final JavaTXCompiler compiler; private final ASTToTargetAST converter; + private class CustomClassWriter extends ClassWriter { + public CustomClassWriter() { + super(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS); + } + + @Override + protected ClassLoader getClassLoader() { + return compiler.getClassLoader(); + } + } + public Codegen(TargetStructure clazz, JavaTXCompiler compiler, ASTToTargetAST converter) { this.clazz = clazz; this.className = clazz.qualifiedName().getClassName(); - this.cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS) { - @Override - protected ClassLoader getClassLoader() { - return compiler.getClassLoader(); - } - }; + this.cw = new CustomClassWriter(); this.compiler = compiler; this.converter = converter; } @@ -82,8 +90,6 @@ public class Codegen { int localCounter; MethodVisitor mv; TargetType returnType; - // This is used to remember the type from lambda expressions - TargetType contextType; Stack breakStack = new Stack<>(); Stack switchResultValue = new Stack<>(); @@ -270,13 +276,40 @@ public class Codegen { mv.visitInsn(I2F); else if (dest.equals(TargetType.Double)) mv.visitInsn(I2D); + } else if (isFunctionalInterface(source) && isFunctionalInterface(dest) && + !(source instanceof TargetFunNType && dest instanceof TargetFunNType)) { + boxFunctionalInterface(state, source, dest); } else if (!(dest instanceof TargetGenericType)) { - boxPrimitive(state, source); + //boxPrimitive(state, source); mv.visitTypeInsn(CHECKCAST, dest.getInternalName()); unboxPrimitive(state, dest); } } + record TypePair(TargetType from, TargetType to) {} + private Map funWrapperClasses = new HashMap<>(); + + private void boxFunctionalInterface(State state, TargetType source, TargetType dest) { + var mv = state.mv; + var className = "FunWrapper$$" + + source.name().replaceAll("\\.", "\\$") + + "$_$" + + dest.name().replaceAll("\\.", "\\$"); + + funWrapperClasses.put(new TypePair(source, dest), className); + mv.visitTypeInsn(NEW, className); + mv.visitInsn(DUP_X1); + mv.visitInsn(SWAP); + mv.visitMethodInsn(INVOKESPECIAL, className, "", "(" + source.toDescriptor() + ")V", false); + } + + private boolean isFunctionalInterface(TargetType type) { + if (type instanceof TargetFunNType) return true; + if (type instanceof TargetRefType) + return compiler.getClass(new JavaClassName(type.name())).isFunctionalInterface(); + return false; + } + private TargetType largerType(TargetType left, TargetType right) { if (left.equals(TargetType.String) || right.equals(TargetType.String)) { return TargetType.String; @@ -727,41 +760,15 @@ public class Codegen { var mv = state.mv; String methodName = "apply"; - TargetMethod.Signature signature = null; - - if (!(state.contextType instanceof TargetFunNType ctx)) { - var intf = compiler.getClass(new JavaClassName(state.contextType.name())); - if (intf != null) { - var method = intf.getMethods().stream().filter(m -> Modifier.isAbstract(m.modifier)).findFirst().orElseThrow(); - methodName = method.getName(); - var methodParams = new ArrayList(); - for (var i = 0; i < lambda.signature().parameters().size(); i++) { - var param = lambda.signature().parameters().get(i); - var tpe = converter.convert(method.getParameterList().getParameterAt(i).getType()); - methodParams.add(param.withType(tpe)); - } - var retType = converter.convert(method.getReturnType()); - signature = new TargetMethod.Signature(Set.of(), methodParams, retType); - } - } - if (signature == null) { - signature = new TargetMethod.Signature(Set.of(), lambda.signature().parameters().stream().map(par -> par.withType(TargetType.Object)).toList(), TargetType.Object); - } - - signature = new TargetMethod.Signature( - signature.generics(), - signature.parameters().stream().map(par -> - par.withType(par.pattern().type() instanceof TargetGenericType ? TargetType.Object : par.pattern().type()) - ).toList(), - signature.returnType() instanceof TargetGenericType ? TargetType.Object : signature.returnType() - ); + TargetMethod.Signature signature = new TargetMethod.Signature(Set.of(), + lambda.signature().parameters().stream().map( + par -> par.withType(TargetType.Object)).toList(), + lambda.signature().returnType() != null ? TargetType.Object : null); var parameters = new ArrayList<>(lambda.captures()); parameters.addAll(signature.parameters()); var implSignature = new TargetMethod.Signature(Set.of(), parameters, lambda.signature().returnType()); - // Normalize - TargetMethod impl; if (lambdas.containsKey(lambda)) { impl = lambdas.get(lambda); @@ -782,7 +789,6 @@ public class Codegen { params.add(new TargetRefType(clazz.qualifiedName().getClassName())); params.addAll(lambda.captures().stream().map(mp -> mp.pattern().type()).toList()); - var descriptor = TargetMethod.getDescriptor(state.contextType, params.toArray(TargetType[]::new)); mv.visitVarInsn(ALOAD, 0); for (var index = 0; index < lambda.captures().size(); index++) { var capture = lambda.captures().get(index); @@ -792,9 +798,42 @@ public class Codegen { mv.visitTypeInsn(CHECKCAST, capture.pattern().type().getInternalName()); } + var descriptor = TargetMethod.getDescriptor(lambda.type(), params.toArray(TargetType[]::new)); mv.visitInvokeDynamicInsn(methodName, descriptor, bootstrap, Type.getType(signature.getSignature()), handle, Type.getType(signature.getDescriptor())); } + private int findReturnCode(TargetType returnType) { + if (returnType.equals(TargetType.boolean_) + || returnType.equals(TargetType.char_) + || returnType.equals(TargetType.int_) + || returnType.equals(TargetType.short_) + || returnType.equals(TargetType.byte_)) + return IRETURN; + else if (returnType.equals(TargetType.long_)) + return LRETURN; + else if (returnType.equals(TargetType.float_)) + return FRETURN; + else if (returnType.equals(TargetType.double_)) + return DRETURN; + return ARETURN; + } + + private int findLoadCode(TargetType loadType) { + if (loadType.equals(TargetType.boolean_) + || loadType.equals(TargetType.char_) + || loadType.equals(TargetType.int_) + || loadType.equals(TargetType.short_) + || loadType.equals(TargetType.byte_)) + return ILOAD; + else if (loadType.equals(TargetType.long_)) + return LLOAD; + else if (loadType.equals(TargetType.float_)) + return FLOAD; + else if (loadType.equals(TargetType.double_)) + return DLOAD; + return ALOAD; + } + private void generate(State state, TargetExpression expr) { var mv = state.mv; switch (expr) { @@ -819,10 +858,7 @@ public class Codegen { break; } case TargetCast cast: - var ctx = state.contextType; - state.contextType = cast.type(); generate(state, cast.expr()); - state.contextType = ctx; convertTo(state, cast.expr().type(), cast.type()); break; case TargetInstanceOf instanceOf: @@ -867,10 +903,7 @@ public class Codegen { case TargetAssign assign: { switch (assign.left()) { case TargetLocalVar localVar -> { - var ctype = state.contextType; - state.contextType = localVar.type(); generate(state, assign.right()); - state.contextType = ctype; convertTo(state, assign.right().type(), localVar.type()); boxPrimitive(state, localVar.type()); @@ -883,10 +916,7 @@ public class Codegen { if (!(dot.left() instanceof TargetThis && dot.isStatic())) generate(state, dot.left()); - var ctype = state.contextType; - state.contextType = fieldType; generate(state, assign.right()); - state.contextType = ctype; convertTo(state, assign.right().type(), fieldType); boxPrimitive(state, fieldType); @@ -1016,29 +1046,12 @@ public class Codegen { case TargetReturn ret: { if (ret.expression() != null && state.returnType != null) { if (state.returnType instanceof TargetPrimitiveType) { - var ctype = state.contextType; - state.contextType = state.returnType; generate(state, ret.expression()); - state.contextType = ctype; unboxPrimitive(state, state.returnType); - if (state.returnType.equals(TargetType.boolean_) - || state.returnType.equals(TargetType.char_) - || state.returnType.equals(TargetType.int_) - || state.returnType.equals(TargetType.short_) - || state.returnType.equals(TargetType.byte_)) - mv.visitInsn(IRETURN); - else if (state.returnType.equals(TargetType.long_)) - mv.visitInsn(LRETURN); - else if (state.returnType.equals(TargetType.float_)) - mv.visitInsn(FRETURN); - else if (state.returnType.equals(TargetType.double_)) - mv.visitInsn(DRETURN); + mv.visitInsn(findReturnCode(state.returnType)); } else { - var ctype = state.contextType; - state.contextType = state.returnType; generate(state, ret.expression()); - state.contextType = ctype; boxPrimitive(state, ret.expression().type()); convertTo(state, ret.expression().type(), state.returnType); mv.visitInsn(ARETURN); @@ -1089,12 +1102,10 @@ public class Codegen { for (var i = 0; i < call.args().size(); i++) { var e = call.args().get(i); var arg = call.parameterTypes().get(i); - var ctype = state.contextType; - state.contextType = arg; generate(state, e); + convertTo(state, e.type(), arg); if (!(arg instanceof TargetPrimitiveType)) boxPrimitive(state, e.type()); - state.contextType = ctype; } var descriptor = call.getDescriptor(); if (call.owner() instanceof TargetFunNType) // Decay FunN @@ -1597,6 +1608,85 @@ public class Codegen { if (clazz instanceof TargetRecord) generateRecordMethods(); + // Generate wrapper classes for function types + for (var pair : funWrapperClasses.keySet()) { + var className = funWrapperClasses.get(pair); + ClassWriter cw2 = new CustomClassWriter(); + cw2.visit(V1_8, ACC_PUBLIC, className, null, "java/lang/Object", new String[] { pair.to.getInternalName() }); + cw2.visitField(ACC_PRIVATE, "wrapped", pair.from.toDescriptor(), null, null).visitEnd(); + + // Generate constructor + var ctor = cw2.visitMethod(ACC_PUBLIC, "", "(" + pair.from.toDescriptor() + ")V", null, null); + ctor.visitVarInsn(ALOAD, 0); + ctor.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V", false); + ctor.visitVarInsn(ALOAD, 0); + ctor.visitVarInsn(ALOAD, 1); + ctor.visitFieldInsn(PUTFIELD, className, "wrapped", pair.from.toDescriptor()); + ctor.visitInsn(RETURN); + ctor.visitMaxs(0, 0); + ctor.visitEnd(); + + String methodName = "apply"; + String fromDescriptor = null; + TargetType fromReturn = null; + if (!(pair.from instanceof TargetFunNType funNType)) { + var fromClass = compiler.getClass(new JavaClassName(pair.from.name())); + var fromMethod = fromClass.getMethods().stream().filter(m -> (m.modifier & ACC_ABSTRACT) != 0).findFirst().orElseThrow(); + methodName = fromMethod.name; + + fromReturn = converter.convert(fromMethod.getReturnType()); + var fromParams = converter.convert(fromMethod.getParameterList(), converter.generics.javaGenerics()).stream().map(m -> m.pattern().type()).toArray(TargetType[]::new); + fromDescriptor = TargetMethod.getDescriptor(fromReturn, fromParams); + } else { + fromReturn = funNType.arity() > 1 ? TargetType.Object : null; + fromDescriptor = funNType.toMethodDescriptor(); + } + + var toClass = compiler.getClass(new JavaClassName(pair.to.name())); + var toMethod = toClass.getMethods().stream().filter(m -> (m.modifier & ACC_ABSTRACT) != 0).findFirst().orElseThrow(); + var toReturn = converter.convert(toMethod.getReturnType()); + var toParams = converter.convert(toMethod.getParameterList(), converter.generics.javaGenerics()).stream().map(m -> m.pattern().type()).toArray(TargetType[]::new); + var toDescriptor = TargetMethod.getDescriptor(toReturn, toParams); + + // Generate wrapper method + var mv = cw2.visitMethod(ACC_PUBLIC, toMethod.name, toDescriptor, null, null); + var state = new State(null, mv, 0); + + mv.visitVarInsn(ALOAD, 0); + mv.visitFieldInsn(GETFIELD, className, "wrapped", pair.from.toDescriptor()); + for (var i = 0; i < toParams.length; i++) { + var arg = toParams[i]; + mv.visitVarInsn(findLoadCode(arg), i + 1); + } + mv.visitMethodInsn(INVOKEINTERFACE, pair.from.getInternalName(), methodName, fromDescriptor, true); + if (fromReturn != null) { + if (toReturn instanceof TargetPrimitiveType) { + convertTo(state, fromReturn, TargetType.toWrapper(toReturn)); + } else convertTo(state, fromReturn, toReturn); + } + + if (toReturn != null) + mv.visitInsn(findReturnCode(toReturn)); + + else mv.visitInsn(RETURN); + mv.visitMaxs(0, 0); + mv.visitEnd(); + + cw2.visitEnd(); + var bytes = cw2.toByteArray(); + converter.auxiliaries.put(className, bytes); + + // TODO These class loading shenanigans happen in a few places, the tests load the classes individually. + // Instead we should just look at the folder. + try { + converter.classLoader.findClass(className); + } catch (ClassNotFoundException e) { + try { + converter.classLoader.loadClass(bytes); + } catch (LinkageError ignored) {} + } + } + cw.visitEnd(); return cw.toByteArray(); } diff --git a/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java b/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java index d6e73ba7..f80ec109 100644 --- a/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java +++ b/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java @@ -57,7 +57,7 @@ public class ASTToTargetAST { } - protected IByteArrayClassLoader classLoader; + public IByteArrayClassLoader classLoader; protected SourceFile sourceFile; public ASTToTargetAST(List resultSets) { @@ -483,7 +483,7 @@ public class ASTToTargetAST { if (gep.parameters.get(i) != null) filteredParams.add(newParams.get(i)); } - return TargetFunNType.fromParams(params, filteredParams); + return TargetFunNType.fromParams(params, filteredParams, params.size()); } private boolean isSubtype(TargetType test, TargetType other) { diff --git a/src/main/java/de/dhbwstuttgart/target/tree/type/TargetFunNType.java b/src/main/java/de/dhbwstuttgart/target/tree/type/TargetFunNType.java index f35f1043..1c998662 100644 --- a/src/main/java/de/dhbwstuttgart/target/tree/type/TargetFunNType.java +++ b/src/main/java/de/dhbwstuttgart/target/tree/type/TargetFunNType.java @@ -4,15 +4,29 @@ import de.dhbwstuttgart.bytecode.FunNGenerator; import java.util.List; -public record TargetFunNType(String name, List params) implements TargetSpecializedType { +public record TargetFunNType(String name, List params, int arity) implements TargetSpecializedType { - public static TargetFunNType fromParams(List params) { - return fromParams(params, params); + public static TargetFunNType fromParams(List params, int arity) { + return fromParams(params, params, arity); } - public static TargetFunNType fromParams(List params, List realParams) { + public static TargetFunNType fromParams(List params, List realParams, int arity) { var name = FunNGenerator.getSpecializedClassName(FunNGenerator.getArguments(params), FunNGenerator.getReturnType(params)); - return new TargetFunNType(name, realParams); + return new TargetFunNType(name, realParams, arity); + } + + public String toMethodDescriptor() { + var res = "("; + for (var i = 0; i < arity - 1; i++) { + res += "Ljava/lang/Object;"; + } + res += ")"; + if (arity > 0) { + res += "Ljava/lang/Object;"; + } else { + res += "V"; + } + return res; } @Override diff --git a/src/main/java/de/dhbwstuttgart/target/tree/type/TargetType.java b/src/main/java/de/dhbwstuttgart/target/tree/type/TargetType.java index 1cb57f90..f34e84ef 100644 --- a/src/main/java/de/dhbwstuttgart/target/tree/type/TargetType.java +++ b/src/main/java/de/dhbwstuttgart/target/tree/type/TargetType.java @@ -53,19 +53,17 @@ public sealed interface TargetType }; } - static TargetType toTargetType(Class clazz) { - if (clazz.isPrimitive()) { - if (clazz.equals(boolean.class)) return boolean_; - if (clazz.equals(char.class)) return char_; - if (clazz.equals(byte.class)) return byte_; - if (clazz.equals(short.class)) return short_; - if (clazz.equals(int.class)) return int_; - if (clazz.equals(long.class)) return long_; - if (clazz.equals(float.class)) return float_; - if (clazz.equals(double.class)) return double_; - } - if (clazz.equals(void.class)) return null; - return new TargetRefType(clazz.getName()); + static TargetType toWrapper(TargetType f) { + if (f.equals(boolean_)) return Boolean; + if (f.equals(char_)) return Char; + if (f.equals(byte_)) return Byte; + if (f.equals(short_)) return Short; + if (f.equals(int_)) return Integer; + if (f.equals(long_)) return Long; + if (f.equals(float_)) return Float; + if (f.equals(double_)) return Double; + + return f; } String toSignature(); diff --git a/src/test/java/TestComplete.java b/src/test/java/TestComplete.java index 5638e64a..c4599e02 100644 --- a/src/test/java/TestComplete.java +++ b/src/test/java/TestComplete.java @@ -860,13 +860,6 @@ public class TestComplete { var instance = clazz.getDeclaredConstructor().newInstance(); } - @Test - public void testLamRunnable() throws Exception { - var classFiles = generateClassFiles(new ByteArrayClassLoader(), "LamRunnable.jav"); - var clazz = classFiles.get("LamRunnable"); - var instance = clazz.getDeclaredConstructor().newInstance(); - } - @Test public void testAccess() throws Exception { var classFiles = generateClassFiles(new ByteArrayClassLoader(), "Access.jav"); diff --git a/src/test/java/targetast/TestCodegen.java b/src/test/java/targetast/TestCodegen.java index fb901776..e83c2f95 100644 --- a/src/test/java/targetast/TestCodegen.java +++ b/src/test/java/targetast/TestCodegen.java @@ -277,7 +277,7 @@ public class TestCodegen { public void testLambda() throws Exception { var classLoader = new ByteArrayClassLoader(); // var fun = classLoader.loadClass(Path.of(System.getProperty("user.dir"), "src/test/java/targetast/Fun1$$.class")); - var interfaceType = TargetFunNType.fromParams(List.of(TargetType.Integer)); + var interfaceType = TargetFunNType.fromParams(List.of(TargetType.Integer), 1); var targetClass = new TargetClass(Opcodes.ACC_PUBLIC, new JavaClassName("CGLambda")); targetClass.addConstructor(Opcodes.ACC_PUBLIC, List.of(), new TargetBlock(List.of(new TargetMethodCall(null, new TargetSuper(TargetType.Object), List.of(), TargetType.Object, "", false, false, false))));