Lambda captures

This commit is contained in:
Victorious3 2022-08-08 14:50:43 +02:00
parent d87ea005b1
commit f36f981ca8
6 changed files with 66 additions and 17 deletions

View File

@ -1,5 +1,6 @@
package de.dhbwstuttgart.target.bytecode;
import de.dhbwstuttgart.syntaxtree.statement.Block;
import de.dhbwstuttgart.target.tree.*;
import de.dhbwstuttgart.target.tree.expression.*;
import de.dhbwstuttgart.target.tree.type.*;
@ -9,10 +10,7 @@ import java.lang.invoke.CallSite;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.*;
import static org.objectweb.asm.Opcodes.*;
import static de.dhbwstuttgart.target.tree.expression.TargetBinaryOp.*;
@ -705,9 +703,12 @@ public class Codegen {
impl = lambdas.get(lambda);
} else {
var name = "lambda$" + lambdaCounter++;
var parameters = new ArrayList<>(lambda.captures());
parameters.addAll(lambda.params());
impl = new TargetMethod(
0, name, Set.of(),
lambda.params(), lambda.returnType(), lambda.block()
parameters, lambda.returnType(), lambda.block()
);
generateMethod(impl);
lambdas.put(lambda, impl);
@ -732,9 +733,19 @@ public class Codegen {
desugared += "Ljava/lang/Object;";
else desugared += "V";
var params = new ArrayList<TargetType>();
params.add(new TargetRefType(clazz.qualifiedName()));
params.addAll(lambda.captures().stream().map(MethodParameter::type).toList());
var descriptor = TargetMethod.getDescriptor(lambda.type(), params.toArray(TargetType[]::new));
mv.visitVarInsn(ALOAD, 0);
mv.visitInvokeDynamicInsn("apply", TargetMethod.getDescriptor(lambda.type(), new TargetRefType(clazz.qualifiedName())),
bootstrap, Type.getType(desugared), handle, Type.getType(impl.getDescriptor()));
for (var capture : lambda.captures())
mv.visitVarInsn(ALOAD, state.scope.get(capture.name()).index);
mv.visitInvokeDynamicInsn("apply", descriptor,
bootstrap, Type.getType(desugared), handle,
Type.getType(TargetMethod.getDescriptor(impl.returnType(), lambda.params().stream().map(MethodParameter::type).toArray(TargetType[]::new)))
);
}
private void generate(State state, TargetExpression expr) {
@ -933,7 +944,7 @@ public class Codegen {
if (call.returnType() != null && !(call.returnType() instanceof TargetPrimitiveType)) {
if (!call.returnType().equals(call.type()) && !(call.type() instanceof TargetGenericType))
mv.visitTypeInsn(CHECKCAST, call.type().getInternalName());
else unboxPrimitive(state, call.type());
unboxPrimitive(state, call.type());
}
break;
}

View File

@ -15,9 +15,7 @@ import de.dhbwstuttgart.target.tree.type.TargetSpecializedType;
import de.dhbwstuttgart.target.tree.type.TargetType;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.*;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
@ -41,9 +39,48 @@ public class StatementToTargetExpression implements StatementVisitor {
.stream(lambdaExpression.params.spliterator(), false)
.map(p -> new MethodParameter(converter.convert(p.getType()), p.getName()))
.toList();
List<MethodParameter> captures = new ArrayList<>();
lambdaExpression.methodBody.accept(new TracingStatementVisitor() {
// TODO The same mechanism is implemented in Codegen, maybe use it from there?
final Stack<Set<String>> localVariables = new Stack<>();
{ localVariables.push(new HashSet<>()); }
boolean hasLocalVar(String name) {
for (var localVariables : this.localVariables) {
if (localVariables.contains(name)) return true;
}
return false;
}
@Override
public void visit(Block block) {
localVariables.push(new HashSet<>());
super.visit(block);
localVariables.pop();
}
@Override
public void visit(LocalVar localVar) {
super.visit(localVar);
var capture = new MethodParameter(converter.convert(localVar.getType()), localVar.name);
if (!hasLocalVar(localVar.name) && !parameters.contains(capture) && !captures.contains(capture))
captures.add(capture);
}
@Override
public void visit(LocalVarDecl varDecl) {
var localVariables = this.localVariables.peek();
localVariables.add(varDecl.getName());
}
@Override
public void visit(LambdaExpression lambda) {} // Don't look at lambda expressions
});
result = new TargetLambdaExpression(
new TargetFunNType(parameters.size(), parameters.stream().map(MethodParameter::type).toList()),
parameters, converter.convert(lambdaExpression.getReturnType()), converter.convert(lambdaExpression.methodBody)
captures, parameters, converter.convert(lambdaExpression.getReturnType()), converter.convert(lambdaExpression.methodBody)
);
}

View File

@ -144,7 +144,7 @@ public abstract class TracingStatementVisitor implements StatementVisitor {
@Override
public void visit(ExpressionReceiver expressionReceiver) {
expressionReceiver.expr.accept(this);
}
@Override

View File

@ -1,9 +1,10 @@
package de.dhbwstuttgart.target.tree.expression;
import de.dhbwstuttgart.target.tree.MethodParameter;
import de.dhbwstuttgart.target.tree.TargetField;
import de.dhbwstuttgart.target.tree.type.TargetType;
import java.util.List;
public record TargetLambdaExpression(TargetType type, List<MethodParameter> params, TargetType returnType, TargetBlock block) implements TargetExpression {
public record TargetLambdaExpression(TargetType type, List<MethodParameter> captures, List<MethodParameter> params, TargetType returnType, TargetBlock block) implements TargetExpression {
}

View File

@ -330,7 +330,7 @@ public class TestCodegen {
targetClass.addMethod(Opcodes.ACC_PUBLIC, "lambda", List.of(), TargetType.Integer,
new TargetBlock(List.of(
new TargetVarDecl(interfaceType, "by2",
new TargetLambdaExpression(interfaceType, List.of(new MethodParameter(TargetType.Integer, "num")), TargetType.Integer,
new TargetLambdaExpression(interfaceType, List.of(), List.of(new MethodParameter(TargetType.Integer, "num")), TargetType.Integer,
new TargetBlock(List.of(
new TargetReturn(new TargetBinaryOp.Mul(TargetType.Integer,
new TargetLocalVar(TargetType.Integer, "num"),

View File

@ -25,8 +25,8 @@ public class TphTest {
@Test
public void test1() throws Exception {
var classFiles = TestCodegen.generateClassFiles("Tph7.jav", new ByteArrayClassLoader());
classToTest = classFiles.get("Tph7");
instanceOfClass = classToTest.getDeclaredConstructor().newInstance();
var classToTest = classFiles.get("Tph7");
var instanceOfClass = classToTest.getDeclaredConstructor().newInstance();
//public <DZN, DZL, DZU extends DZN, DZM extends DZU> DZU m(DZL, DZM);
Method m = classToTest.getDeclaredMethod("m", Object.class, Object.class);