From f36f981ca8cea11c3d61262c8cbe2782208d6ddf Mon Sep 17 00:00:00 2001 From: Victorious3 Date: Mon, 8 Aug 2022 14:50:43 +0200 Subject: [PATCH] Lambda captures --- .../target/bytecode/Codegen.java | 27 +++++++---- .../generate/StatementToTargetExpression.java | 45 +++++++++++++++++-- .../generate/TracingStatementVisitor.java | 2 +- .../expression/TargetLambdaExpression.java | 3 +- src/test/java/targetast/TestCodegen.java | 2 +- src/test/java/targetast/TphTest.java | 4 +- 6 files changed, 66 insertions(+), 17 deletions(-) diff --git a/src/main/java/de/dhbwstuttgart/target/bytecode/Codegen.java b/src/main/java/de/dhbwstuttgart/target/bytecode/Codegen.java index 774a4f01..04bee492 100755 --- a/src/main/java/de/dhbwstuttgart/target/bytecode/Codegen.java +++ b/src/main/java/de/dhbwstuttgart/target/bytecode/Codegen.java @@ -1,5 +1,6 @@ package de.dhbwstuttgart.target.bytecode; +import de.dhbwstuttgart.syntaxtree.statement.Block; import de.dhbwstuttgart.target.tree.*; import de.dhbwstuttgart.target.tree.expression.*; import de.dhbwstuttgart.target.tree.type.*; @@ -9,10 +10,7 @@ import java.lang.invoke.CallSite; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; +import java.util.*; import static org.objectweb.asm.Opcodes.*; import static de.dhbwstuttgart.target.tree.expression.TargetBinaryOp.*; @@ -705,9 +703,12 @@ public class Codegen { impl = lambdas.get(lambda); } else { var name = "lambda$" + lambdaCounter++; + var parameters = new ArrayList<>(lambda.captures()); + parameters.addAll(lambda.params()); + impl = new TargetMethod( 0, name, Set.of(), - lambda.params(), lambda.returnType(), lambda.block() + parameters, lambda.returnType(), lambda.block() ); generateMethod(impl); lambdas.put(lambda, impl); @@ -732,9 +733,19 @@ public class Codegen { desugared += "Ljava/lang/Object;"; else desugared += "V"; + var params = new ArrayList(); + params.add(new TargetRefType(clazz.qualifiedName())); + params.addAll(lambda.captures().stream().map(MethodParameter::type).toList()); + + var descriptor = TargetMethod.getDescriptor(lambda.type(), params.toArray(TargetType[]::new)); mv.visitVarInsn(ALOAD, 0); - mv.visitInvokeDynamicInsn("apply", TargetMethod.getDescriptor(lambda.type(), new TargetRefType(clazz.qualifiedName())), - bootstrap, Type.getType(desugared), handle, Type.getType(impl.getDescriptor())); + for (var capture : lambda.captures()) + mv.visitVarInsn(ALOAD, state.scope.get(capture.name()).index); + + mv.visitInvokeDynamicInsn("apply", descriptor, + bootstrap, Type.getType(desugared), handle, + Type.getType(TargetMethod.getDescriptor(impl.returnType(), lambda.params().stream().map(MethodParameter::type).toArray(TargetType[]::new))) + ); } private void generate(State state, TargetExpression expr) { @@ -933,7 +944,7 @@ public class Codegen { if (call.returnType() != null && !(call.returnType() instanceof TargetPrimitiveType)) { if (!call.returnType().equals(call.type()) && !(call.type() instanceof TargetGenericType)) mv.visitTypeInsn(CHECKCAST, call.type().getInternalName()); - else unboxPrimitive(state, call.type()); + unboxPrimitive(state, call.type()); } break; } diff --git a/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java b/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java index 5a8a2235..8ca6e052 100644 --- a/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java +++ b/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java @@ -15,9 +15,7 @@ import de.dhbwstuttgart.target.tree.type.TargetSpecializedType; import de.dhbwstuttgart.target.tree.type.TargetType; import java.lang.reflect.Method; -import java.util.List; -import java.util.Objects; -import java.util.Optional; +import java.util.*; import java.util.stream.Stream; import java.util.stream.StreamSupport; @@ -41,9 +39,48 @@ public class StatementToTargetExpression implements StatementVisitor { .stream(lambdaExpression.params.spliterator(), false) .map(p -> new MethodParameter(converter.convert(p.getType()), p.getName())) .toList(); + + List captures = new ArrayList<>(); + lambdaExpression.methodBody.accept(new TracingStatementVisitor() { + // TODO The same mechanism is implemented in Codegen, maybe use it from there? + final Stack> localVariables = new Stack<>(); + { localVariables.push(new HashSet<>()); } + + boolean hasLocalVar(String name) { + for (var localVariables : this.localVariables) { + if (localVariables.contains(name)) return true; + } + return false; + } + + @Override + public void visit(Block block) { + localVariables.push(new HashSet<>()); + super.visit(block); + localVariables.pop(); + } + + @Override + public void visit(LocalVar localVar) { + super.visit(localVar); + var capture = new MethodParameter(converter.convert(localVar.getType()), localVar.name); + if (!hasLocalVar(localVar.name) && !parameters.contains(capture) && !captures.contains(capture)) + captures.add(capture); + } + + @Override + public void visit(LocalVarDecl varDecl) { + var localVariables = this.localVariables.peek(); + localVariables.add(varDecl.getName()); + } + + @Override + public void visit(LambdaExpression lambda) {} // Don't look at lambda expressions + }); + result = new TargetLambdaExpression( new TargetFunNType(parameters.size(), parameters.stream().map(MethodParameter::type).toList()), - parameters, converter.convert(lambdaExpression.getReturnType()), converter.convert(lambdaExpression.methodBody) + captures, parameters, converter.convert(lambdaExpression.getReturnType()), converter.convert(lambdaExpression.methodBody) ); } diff --git a/src/main/java/de/dhbwstuttgart/target/generate/TracingStatementVisitor.java b/src/main/java/de/dhbwstuttgart/target/generate/TracingStatementVisitor.java index 7a7352ad..6fc9c248 100644 --- a/src/main/java/de/dhbwstuttgart/target/generate/TracingStatementVisitor.java +++ b/src/main/java/de/dhbwstuttgart/target/generate/TracingStatementVisitor.java @@ -144,7 +144,7 @@ public abstract class TracingStatementVisitor implements StatementVisitor { @Override public void visit(ExpressionReceiver expressionReceiver) { - + expressionReceiver.expr.accept(this); } @Override diff --git a/src/main/java/de/dhbwstuttgart/target/tree/expression/TargetLambdaExpression.java b/src/main/java/de/dhbwstuttgart/target/tree/expression/TargetLambdaExpression.java index c0a5a811..718fb663 100644 --- a/src/main/java/de/dhbwstuttgart/target/tree/expression/TargetLambdaExpression.java +++ b/src/main/java/de/dhbwstuttgart/target/tree/expression/TargetLambdaExpression.java @@ -1,9 +1,10 @@ package de.dhbwstuttgart.target.tree.expression; import de.dhbwstuttgart.target.tree.MethodParameter; +import de.dhbwstuttgart.target.tree.TargetField; import de.dhbwstuttgart.target.tree.type.TargetType; import java.util.List; -public record TargetLambdaExpression(TargetType type, List params, TargetType returnType, TargetBlock block) implements TargetExpression { +public record TargetLambdaExpression(TargetType type, List captures, List params, TargetType returnType, TargetBlock block) implements TargetExpression { } diff --git a/src/test/java/targetast/TestCodegen.java b/src/test/java/targetast/TestCodegen.java index 165dff72..46c4c834 100644 --- a/src/test/java/targetast/TestCodegen.java +++ b/src/test/java/targetast/TestCodegen.java @@ -330,7 +330,7 @@ public class TestCodegen { targetClass.addMethod(Opcodes.ACC_PUBLIC, "lambda", List.of(), TargetType.Integer, new TargetBlock(List.of( new TargetVarDecl(interfaceType, "by2", - new TargetLambdaExpression(interfaceType, List.of(new MethodParameter(TargetType.Integer, "num")), TargetType.Integer, + new TargetLambdaExpression(interfaceType, List.of(), List.of(new MethodParameter(TargetType.Integer, "num")), TargetType.Integer, new TargetBlock(List.of( new TargetReturn(new TargetBinaryOp.Mul(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "num"), diff --git a/src/test/java/targetast/TphTest.java b/src/test/java/targetast/TphTest.java index a1a1a214..c3f29c7f 100644 --- a/src/test/java/targetast/TphTest.java +++ b/src/test/java/targetast/TphTest.java @@ -25,8 +25,8 @@ public class TphTest { @Test public void test1() throws Exception { var classFiles = TestCodegen.generateClassFiles("Tph7.jav", new ByteArrayClassLoader()); - classToTest = classFiles.get("Tph7"); - instanceOfClass = classToTest.getDeclaredConstructor().newInstance(); + var classToTest = classFiles.get("Tph7"); + var instanceOfClass = classToTest.getDeclaredConstructor().newInstance(); //public DZU m(DZL, DZM); Method m = classToTest.getDeclaredMethod("m", Object.class, Object.class);