Lambda captures

This commit is contained in:
Victorious3 2022-08-08 14:50:43 +02:00
parent d87ea005b1
commit f36f981ca8
6 changed files with 66 additions and 17 deletions

View File

@ -1,5 +1,6 @@
package de.dhbwstuttgart.target.bytecode; package de.dhbwstuttgart.target.bytecode;
import de.dhbwstuttgart.syntaxtree.statement.Block;
import de.dhbwstuttgart.target.tree.*; import de.dhbwstuttgart.target.tree.*;
import de.dhbwstuttgart.target.tree.expression.*; import de.dhbwstuttgart.target.tree.expression.*;
import de.dhbwstuttgart.target.tree.type.*; import de.dhbwstuttgart.target.tree.type.*;
@ -9,10 +10,7 @@ import java.lang.invoke.CallSite;
import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType; import java.lang.invoke.MethodType;
import java.util.HashMap; import java.util.*;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import static org.objectweb.asm.Opcodes.*; import static org.objectweb.asm.Opcodes.*;
import static de.dhbwstuttgart.target.tree.expression.TargetBinaryOp.*; import static de.dhbwstuttgart.target.tree.expression.TargetBinaryOp.*;
@ -705,9 +703,12 @@ public class Codegen {
impl = lambdas.get(lambda); impl = lambdas.get(lambda);
} else { } else {
var name = "lambda$" + lambdaCounter++; var name = "lambda$" + lambdaCounter++;
var parameters = new ArrayList<>(lambda.captures());
parameters.addAll(lambda.params());
impl = new TargetMethod( impl = new TargetMethod(
0, name, Set.of(), 0, name, Set.of(),
lambda.params(), lambda.returnType(), lambda.block() parameters, lambda.returnType(), lambda.block()
); );
generateMethod(impl); generateMethod(impl);
lambdas.put(lambda, impl); lambdas.put(lambda, impl);
@ -732,9 +733,19 @@ public class Codegen {
desugared += "Ljava/lang/Object;"; desugared += "Ljava/lang/Object;";
else desugared += "V"; else desugared += "V";
var params = new ArrayList<TargetType>();
params.add(new TargetRefType(clazz.qualifiedName()));
params.addAll(lambda.captures().stream().map(MethodParameter::type).toList());
var descriptor = TargetMethod.getDescriptor(lambda.type(), params.toArray(TargetType[]::new));
mv.visitVarInsn(ALOAD, 0); mv.visitVarInsn(ALOAD, 0);
mv.visitInvokeDynamicInsn("apply", TargetMethod.getDescriptor(lambda.type(), new TargetRefType(clazz.qualifiedName())), for (var capture : lambda.captures())
bootstrap, Type.getType(desugared), handle, Type.getType(impl.getDescriptor())); mv.visitVarInsn(ALOAD, state.scope.get(capture.name()).index);
mv.visitInvokeDynamicInsn("apply", descriptor,
bootstrap, Type.getType(desugared), handle,
Type.getType(TargetMethod.getDescriptor(impl.returnType(), lambda.params().stream().map(MethodParameter::type).toArray(TargetType[]::new)))
);
} }
private void generate(State state, TargetExpression expr) { private void generate(State state, TargetExpression expr) {
@ -933,7 +944,7 @@ public class Codegen {
if (call.returnType() != null && !(call.returnType() instanceof TargetPrimitiveType)) { if (call.returnType() != null && !(call.returnType() instanceof TargetPrimitiveType)) {
if (!call.returnType().equals(call.type()) && !(call.type() instanceof TargetGenericType)) if (!call.returnType().equals(call.type()) && !(call.type() instanceof TargetGenericType))
mv.visitTypeInsn(CHECKCAST, call.type().getInternalName()); mv.visitTypeInsn(CHECKCAST, call.type().getInternalName());
else unboxPrimitive(state, call.type()); unboxPrimitive(state, call.type());
} }
break; break;
} }

View File

@ -15,9 +15,7 @@ import de.dhbwstuttgart.target.tree.type.TargetSpecializedType;
import de.dhbwstuttgart.target.tree.type.TargetType; import de.dhbwstuttgart.target.tree.type.TargetType;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.List; import java.util.*;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream; import java.util.stream.Stream;
import java.util.stream.StreamSupport; import java.util.stream.StreamSupport;
@ -41,9 +39,48 @@ public class StatementToTargetExpression implements StatementVisitor {
.stream(lambdaExpression.params.spliterator(), false) .stream(lambdaExpression.params.spliterator(), false)
.map(p -> new MethodParameter(converter.convert(p.getType()), p.getName())) .map(p -> new MethodParameter(converter.convert(p.getType()), p.getName()))
.toList(); .toList();
List<MethodParameter> captures = new ArrayList<>();
lambdaExpression.methodBody.accept(new TracingStatementVisitor() {
// TODO The same mechanism is implemented in Codegen, maybe use it from there?
final Stack<Set<String>> localVariables = new Stack<>();
{ localVariables.push(new HashSet<>()); }
boolean hasLocalVar(String name) {
for (var localVariables : this.localVariables) {
if (localVariables.contains(name)) return true;
}
return false;
}
@Override
public void visit(Block block) {
localVariables.push(new HashSet<>());
super.visit(block);
localVariables.pop();
}
@Override
public void visit(LocalVar localVar) {
super.visit(localVar);
var capture = new MethodParameter(converter.convert(localVar.getType()), localVar.name);
if (!hasLocalVar(localVar.name) && !parameters.contains(capture) && !captures.contains(capture))
captures.add(capture);
}
@Override
public void visit(LocalVarDecl varDecl) {
var localVariables = this.localVariables.peek();
localVariables.add(varDecl.getName());
}
@Override
public void visit(LambdaExpression lambda) {} // Don't look at lambda expressions
});
result = new TargetLambdaExpression( result = new TargetLambdaExpression(
new TargetFunNType(parameters.size(), parameters.stream().map(MethodParameter::type).toList()), new TargetFunNType(parameters.size(), parameters.stream().map(MethodParameter::type).toList()),
parameters, converter.convert(lambdaExpression.getReturnType()), converter.convert(lambdaExpression.methodBody) captures, parameters, converter.convert(lambdaExpression.getReturnType()), converter.convert(lambdaExpression.methodBody)
); );
} }

View File

@ -144,7 +144,7 @@ public abstract class TracingStatementVisitor implements StatementVisitor {
@Override @Override
public void visit(ExpressionReceiver expressionReceiver) { public void visit(ExpressionReceiver expressionReceiver) {
expressionReceiver.expr.accept(this);
} }
@Override @Override

View File

@ -1,9 +1,10 @@
package de.dhbwstuttgart.target.tree.expression; package de.dhbwstuttgart.target.tree.expression;
import de.dhbwstuttgart.target.tree.MethodParameter; import de.dhbwstuttgart.target.tree.MethodParameter;
import de.dhbwstuttgart.target.tree.TargetField;
import de.dhbwstuttgart.target.tree.type.TargetType; import de.dhbwstuttgart.target.tree.type.TargetType;
import java.util.List; import java.util.List;
public record TargetLambdaExpression(TargetType type, List<MethodParameter> params, TargetType returnType, TargetBlock block) implements TargetExpression { public record TargetLambdaExpression(TargetType type, List<MethodParameter> captures, List<MethodParameter> params, TargetType returnType, TargetBlock block) implements TargetExpression {
} }

View File

@ -330,7 +330,7 @@ public class TestCodegen {
targetClass.addMethod(Opcodes.ACC_PUBLIC, "lambda", List.of(), TargetType.Integer, targetClass.addMethod(Opcodes.ACC_PUBLIC, "lambda", List.of(), TargetType.Integer,
new TargetBlock(List.of( new TargetBlock(List.of(
new TargetVarDecl(interfaceType, "by2", new TargetVarDecl(interfaceType, "by2",
new TargetLambdaExpression(interfaceType, List.of(new MethodParameter(TargetType.Integer, "num")), TargetType.Integer, new TargetLambdaExpression(interfaceType, List.of(), List.of(new MethodParameter(TargetType.Integer, "num")), TargetType.Integer,
new TargetBlock(List.of( new TargetBlock(List.of(
new TargetReturn(new TargetBinaryOp.Mul(TargetType.Integer, new TargetReturn(new TargetBinaryOp.Mul(TargetType.Integer,
new TargetLocalVar(TargetType.Integer, "num"), new TargetLocalVar(TargetType.Integer, "num"),

View File

@ -25,8 +25,8 @@ public class TphTest {
@Test @Test
public void test1() throws Exception { public void test1() throws Exception {
var classFiles = TestCodegen.generateClassFiles("Tph7.jav", new ByteArrayClassLoader()); var classFiles = TestCodegen.generateClassFiles("Tph7.jav", new ByteArrayClassLoader());
classToTest = classFiles.get("Tph7"); var classToTest = classFiles.get("Tph7");
instanceOfClass = classToTest.getDeclaredConstructor().newInstance(); var instanceOfClass = classToTest.getDeclaredConstructor().newInstance();
//public <DZN, DZL, DZU extends DZN, DZM extends DZU> DZU m(DZL, DZM); //public <DZN, DZL, DZU extends DZN, DZM extends DZU> DZU m(DZL, DZM);
Method m = classToTest.getDeclaredMethod("m", Object.class, Object.class); Method m = classToTest.getDeclaredMethod("m", Object.class, Object.class);