This commit is contained in:
Daniel Holle 2024-04-09 14:58:43 +02:00
parent 0b7f07108f
commit cb7d0e22cc
9 changed files with 102 additions and 38 deletions

View File

@ -0,0 +1,13 @@
import java.lang.Integer;
import java.util.List;
import java.util.ArrayList;
import java.util.stream.Stream;
import java.util.function.Predicate;
import java.util.function.Function;
import java.util.stream.Collectors;
public class Bug314 {
public List<Integer> convert(List<Integer> in) {
return in.stream().filter(x -> x > 5).collect(Collectors.toList());
}
}

View File

@ -6,6 +6,8 @@ import de.dhbwstuttgart.parser.NullToken;
import de.dhbwstuttgart.parser.scope.JavaClassName; import de.dhbwstuttgart.parser.scope.JavaClassName;
import de.dhbwstuttgart.syntaxtree.ClassOrInterface; import de.dhbwstuttgart.syntaxtree.ClassOrInterface;
import de.dhbwstuttgart.syntaxtree.type.RefType; import de.dhbwstuttgart.syntaxtree.type.RefType;
import de.dhbwstuttgart.target.generate.ASTToTargetAST;
import de.dhbwstuttgart.target.generate.StatementToTargetExpression;
import de.dhbwstuttgart.target.tree.*; import de.dhbwstuttgart.target.tree.*;
import de.dhbwstuttgart.target.tree.expression.*; import de.dhbwstuttgart.target.tree.expression.*;
import de.dhbwstuttgart.target.tree.type.*; import de.dhbwstuttgart.target.tree.type.*;
@ -14,6 +16,7 @@ import org.objectweb.asm.*;
import java.lang.invoke.*; import java.lang.invoke.*;
import java.lang.reflect.Modifier; import java.lang.reflect.Modifier;
import java.util.*; import java.util.*;
import java.util.stream.IntStream;
import static org.objectweb.asm.Opcodes.*; import static org.objectweb.asm.Opcodes.*;
import static de.dhbwstuttgart.target.tree.expression.TargetBinaryOp.*; import static de.dhbwstuttgart.target.tree.expression.TargetBinaryOp.*;
@ -26,8 +29,9 @@ public class Codegen {
private int lambdaCounter = 0; private int lambdaCounter = 0;
private final HashMap<TargetLambdaExpression, TargetMethod> lambdas = new HashMap<>(); private final HashMap<TargetLambdaExpression, TargetMethod> lambdas = new HashMap<>();
private final JavaTXCompiler compiler; private final JavaTXCompiler compiler;
private final ASTToTargetAST converter;
public Codegen(TargetStructure clazz, JavaTXCompiler compiler) { public Codegen(TargetStructure clazz, JavaTXCompiler compiler, ASTToTargetAST converter) {
this.clazz = clazz; this.clazz = clazz;
this.className = clazz.qualifiedName().getClassName(); this.className = clazz.qualifiedName().getClassName();
this.cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS) { this.cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS) {
@ -37,6 +41,7 @@ public class Codegen {
} }
}; };
this.compiler = compiler; this.compiler = compiler;
this.converter = converter;
} }
private record LocalVar(int index, String name, TargetType type) { private record LocalVar(int index, String name, TargetType type) {
@ -721,15 +726,49 @@ public class Codegen {
private void generateLambdaExpression(State state, TargetLambdaExpression lambda) { private void generateLambdaExpression(State state, TargetLambdaExpression lambda) {
var mv = state.mv; var mv = state.mv;
String methodName = "apply";
TargetMethod.Signature signature = null;
if (!(state.contextType instanceof TargetFunNType ctx)) {
var intf = compiler.getClass(new JavaClassName(state.contextType.name()));
if (intf != null) {
var method = intf.getMethods().stream().filter(m -> Modifier.isAbstract(m.modifier)).findFirst().orElseThrow();
methodName = method.getName();
var methodParams = new ArrayList<MethodParameter>();
for (var i = 0; i < lambda.signature().parameters().size(); i++) {
var param = lambda.signature().parameters().get(i);
var tpe = converter.convert(method.getParameterList().getParameterAt(i).getType());
methodParams.add(param.withType(tpe));
}
var retType = converter.convert(method.getReturnType());
signature = new TargetMethod.Signature(Set.of(), methodParams, retType);
}
}
if (signature == null) {
signature = new TargetMethod.Signature(Set.of(), lambda.signature().parameters().stream().map(par -> par.withType(TargetType.Object)).toList(), TargetType.Object);
}
signature = new TargetMethod.Signature(
signature.generics(),
signature.parameters().stream().map(par ->
par.withType(par.pattern().type() instanceof TargetGenericType ? TargetType.Object : par.pattern().type())
).toList(),
signature.returnType() instanceof TargetGenericType ? TargetType.Object : signature.returnType()
);
var parameters = new ArrayList<>(lambda.captures());
parameters.addAll(signature.parameters());
var implSignature = new TargetMethod.Signature(Set.of(), parameters, lambda.signature().returnType());
// Normalize
TargetMethod impl; TargetMethod impl;
if (lambdas.containsKey(lambda)) { if (lambdas.containsKey(lambda)) {
impl = lambdas.get(lambda); impl = lambdas.get(lambda);
} else { } else {
var name = "lambda$" + lambdaCounter++; var name = "lambda$" + lambdaCounter++;
var parameters = new ArrayList<>(lambda.captures());
parameters.addAll(lambda.params().stream().map(param -> param.pattern().type() instanceof TargetGenericType ? param.withType(TargetType.Object) : param).toList());
impl = new TargetMethod(0, name, lambda.block(), new TargetMethod.Signature(Set.of(), parameters, lambda.returnType() instanceof TargetGenericType ? TargetType.Object : lambda.returnType()), null); impl = new TargetMethod(0, name, lambda.block(), implSignature, null);
generateMethod(impl); generateMethod(impl);
lambdas.put(lambda, impl); lambdas.put(lambda, impl);
} }
@ -737,17 +776,7 @@ public class Codegen {
var mt = MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, MethodType.class, MethodHandle.class, MethodType.class); var mt = MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, MethodType.class, MethodHandle.class, MethodType.class);
var bootstrap = new Handle(H_INVOKESTATIC, "java/lang/invoke/LambdaMetafactory", "metafactory", mt.toMethodDescriptorString(), false); var bootstrap = new Handle(H_INVOKESTATIC, "java/lang/invoke/LambdaMetafactory", "metafactory", mt.toMethodDescriptorString(), false);
var handle = new Handle(H_INVOKEVIRTUAL, clazz.getName(), impl.name(), impl.getDescriptor(), false); var handle = new Handle(H_INVOKEVIRTUAL, clazz.getName(), impl.name(), implSignature.getDescriptor(), false);
// TODO maybe make this a function?
var desugared = "(";
for (var param : lambda.params())
desugared += "Ljava/lang/Object;";
desugared += ")";
if (lambda.returnType() != null)
desugared += "Ljava/lang/Object;";
else
desugared += "V";
var params = new ArrayList<TargetType>(); var params = new ArrayList<TargetType>();
params.add(new TargetRefType(clazz.qualifiedName().getClassName())); params.add(new TargetRefType(clazz.qualifiedName().getClassName()));
@ -760,12 +789,7 @@ public class Codegen {
mv.visitVarInsn(ALOAD, state.scope.get(pattern.name()).index); mv.visitVarInsn(ALOAD, state.scope.get(pattern.name()).index);
} }
String methodName; mv.visitInvokeDynamicInsn(methodName, descriptor, bootstrap, Type.getType(signature.getSignature()), handle, Type.getType(signature.getDescriptor()));
var intf = compiler.getClass(new JavaClassName(state.contextType.name()));
if (intf == null) methodName = "apply"; // TODO Weird fallback logic here
else methodName = intf.getMethods().stream().filter(m -> Modifier.isAbstract(m.modifier)).findFirst().orElseThrow().getName();
mv.visitInvokeDynamicInsn(methodName, descriptor, bootstrap, Type.getType(desugared), handle, Type.getType(TargetMethod.getDescriptor(impl.signature().returnType(), lambda.params().stream().map(mp -> mp.pattern().type()).toArray(TargetType[]::new))));
} }
private void generate(State state, TargetExpression expr) { private void generate(State state, TargetExpression expr) {
@ -876,6 +900,8 @@ public class Codegen {
case TargetLocalVar localVar: { case TargetLocalVar localVar: {
LocalVar local = state.scope.get(localVar.name()); LocalVar local = state.scope.get(localVar.name());
mv.visitVarInsn(ALOAD, local.index()); mv.visitVarInsn(ALOAD, local.index());
// This is a bit weird but sometimes the types don't match (see lambda expressions)
convertTo(state, local.type(), localVar.type());
unboxPrimitive(state, local.type()); unboxPrimitive(state, local.type());
break; break;
} }

View File

@ -754,7 +754,7 @@ public class JavaTXCompiler {
var converter = new ASTToTargetAST(this, typeInferenceResult, sf, classLoader); var converter = new ASTToTargetAST(this, typeInferenceResult, sf, classLoader);
var generatedClasses = new HashMap<JavaClassName, byte[]>(); var generatedClasses = new HashMap<JavaClassName, byte[]>();
for (var clazz : sf.getClasses()) { for (var clazz : sf.getClasses()) {
var codegen = new Codegen(converter.convert(clazz), this); var codegen = new Codegen(converter.convert(clazz), this, converter);
var code = codegen.generate(); var code = codegen.generate();
generatedClasses.put(clazz.getClassName(), code); generatedClasses.put(clazz.getClassName(), code);
converter.auxiliaries.forEach((name, source) -> { converter.auxiliaries.forEach((name, source) -> {

View File

@ -32,7 +32,7 @@ public class ASTToTargetAST {
public static RefType OBJECT = ASTFactory.createObjectType(); // TODO It would be better if I could call this directly but the hashcode seems to change public static RefType OBJECT = ASTFactory.createObjectType(); // TODO It would be better if I could call this directly but the hashcode seems to change
protected List<Generics> all; protected List<Generics> all;
protected Generics generics; public Generics generics;
final Map<ClassOrInterface, Set<GenericTypeVar>> userDefinedGenerics = new HashMap<>(); final Map<ClassOrInterface, Set<GenericTypeVar>> userDefinedGenerics = new HashMap<>();
public final JavaTXCompiler compiler; public final JavaTXCompiler compiler;
@ -173,7 +173,7 @@ public class ASTToTargetAST {
else return new TargetClass(input.getModifiers(), input.getClassName(), convert(input.getSuperClass(), generics.javaGenerics), javaGenerics, txGenerics, superInterfaces, constructors, staticConstructor, fields, methods); else return new TargetClass(input.getModifiers(), input.getClassName(), convert(input.getSuperClass(), generics.javaGenerics), javaGenerics, txGenerics, superInterfaces, constructors, staticConstructor, fields, methods);
} }
private List<MethodParameter> convert(ParameterList input, GenerateGenerics generics) { public List<MethodParameter> convert(ParameterList input, GenerateGenerics generics) {
return input.getFormalparalist().stream().map(param -> return input.getFormalparalist().stream().map(param ->
new MethodParameter((TargetPattern) convert(param)) new MethodParameter((TargetPattern) convert(param))
).toList(); ).toList();
@ -447,7 +447,7 @@ public class ASTToTargetAST {
public Map<String, byte[]> auxiliaries = new HashMap<>(); public Map<String, byte[]> auxiliaries = new HashMap<>();
protected TargetType convert(RefTypeOrTPHOrWildcardOrGeneric input) { public TargetType convert(RefTypeOrTPHOrWildcardOrGeneric input) {
return convert(input, generics.javaGenerics); return convert(input, generics.javaGenerics);
} }

View File

@ -9,6 +9,8 @@ import de.dhbwstuttgart.syntaxtree.factory.PrimitiveMethodsGenerator;
import de.dhbwstuttgart.syntaxtree.statement.*; import de.dhbwstuttgart.syntaxtree.statement.*;
import de.dhbwstuttgart.syntaxtree.type.*; import de.dhbwstuttgart.syntaxtree.type.*;
import de.dhbwstuttgart.target.tree.MethodParameter; import de.dhbwstuttgart.target.tree.MethodParameter;
import de.dhbwstuttgart.target.tree.TargetGeneric;
import de.dhbwstuttgart.target.tree.TargetMethod;
import de.dhbwstuttgart.target.tree.expression.*; import de.dhbwstuttgart.target.tree.expression.*;
import de.dhbwstuttgart.target.tree.type.*; import de.dhbwstuttgart.target.tree.type.*;
@ -81,7 +83,9 @@ public class StatementToTargetExpression implements ASTVisitor {
} // Don't look at lambda expressions } // Don't look at lambda expressions
}); });
result = new TargetLambdaExpression(converter.convert(lambdaExpression.getType()), captures, parameters, converter.convert(lambdaExpression.getReturnType()), converter.convert(lambdaExpression.methodBody)); TargetMethod.Signature signature = new TargetMethod.Signature(Set.of(), parameters, converter.convert(lambdaExpression.getReturnType()));;
var tpe = converter.convert(lambdaExpression.getType());
result = new TargetLambdaExpression(tpe, captures, signature, converter.convert(lambdaExpression.methodBody));
} }
@Override @Override

View File

@ -1,6 +1,7 @@
package de.dhbwstuttgart.target.tree; package de.dhbwstuttgart.target.tree;
import de.dhbwstuttgart.target.tree.expression.TargetBlock; import de.dhbwstuttgart.target.tree.expression.TargetBlock;
import de.dhbwstuttgart.target.tree.expression.TargetPattern;
import de.dhbwstuttgart.target.tree.type.TargetType; import de.dhbwstuttgart.target.tree.type.TargetType;
import org.objectweb.asm.Opcodes; import org.objectweb.asm.Opcodes;
@ -8,7 +9,15 @@ import java.util.List;
import java.util.Set; import java.util.Set;
public record TargetMethod(int access, String name, TargetBlock block, Signature signature, Signature txSignature) { public record TargetMethod(int access, String name, TargetBlock block, Signature signature, Signature txSignature) {
public record Signature(Set<TargetGeneric> generics, List<MethodParameter> parameters, TargetType returnType) { } public record Signature(Set<TargetGeneric> generics, List<MethodParameter> parameters, TargetType returnType) {
public String getSignature() {
return TargetMethod.getSignature(generics, parameters, returnType);
}
public String getDescriptor() {
return TargetMethod.getDescriptor(returnType, parameters.stream().map(MethodParameter::pattern).map(TargetPattern::type).toArray(TargetType[]::new));
}
}
public static String getDescriptor(TargetType returnType, TargetType... parameters) { public static String getDescriptor(TargetType returnType, TargetType... parameters) {
String ret = "("; String ret = "(";

View File

@ -2,9 +2,10 @@ package de.dhbwstuttgart.target.tree.expression;
import de.dhbwstuttgart.target.tree.MethodParameter; import de.dhbwstuttgart.target.tree.MethodParameter;
import de.dhbwstuttgart.target.tree.TargetField; import de.dhbwstuttgart.target.tree.TargetField;
import de.dhbwstuttgart.target.tree.TargetMethod;
import de.dhbwstuttgart.target.tree.type.TargetType; import de.dhbwstuttgart.target.tree.type.TargetType;
import java.util.List; import java.util.List;
public record TargetLambdaExpression(TargetType type, List<MethodParameter> captures, List<MethodParameter> params, TargetType returnType, TargetBlock block) implements TargetExpression { public record TargetLambdaExpression(TargetType type, List<MethodParameter> captures, TargetMethod.Signature signature, TargetBlock block) implements TargetExpression {
} }

View File

@ -5,6 +5,7 @@ import org.junit.Test;
import java.lang.reflect.*; import java.lang.reflect.*;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.Vector; import java.util.Vector;
import targetast.TestCodegen; import targetast.TestCodegen;
@ -1064,4 +1065,15 @@ public class TestComplete {
var instance = clazz.getDeclaredConstructor().newInstance(); var instance = clazz.getDeclaredConstructor().newInstance();
clazz.getDeclaredMethod("main").invoke(instance); clazz.getDeclaredMethod("main").invoke(instance);
} }
@Test
public void testBug314() throws Exception {
var classFiles = generateClassFiles(new ByteArrayClassLoader(), "Bug314.jav");
var clazz = classFiles.get("Bug314");
var instance = clazz.getDeclaredConstructor().newInstance();
var list = List.of(3, 4, 6, 7, 8);
var res = clazz.getDeclaredMethod("convert", List.class).invoke(instance, list);
assertEquals(res, List.of(6, 7, 8));
}
} }

View File

@ -8,6 +8,7 @@ import de.dhbwstuttgart.parser.scope.JavaClassName;
import de.dhbwstuttgart.target.generate.ASTToTargetAST; import de.dhbwstuttgart.target.generate.ASTToTargetAST;
import de.dhbwstuttgart.target.tree.MethodParameter; import de.dhbwstuttgart.target.tree.MethodParameter;
import de.dhbwstuttgart.target.tree.TargetClass; import de.dhbwstuttgart.target.tree.TargetClass;
import de.dhbwstuttgart.target.tree.TargetMethod;
import de.dhbwstuttgart.target.tree.TargetStructure; import de.dhbwstuttgart.target.tree.TargetStructure;
import de.dhbwstuttgart.target.tree.expression.*; import de.dhbwstuttgart.target.tree.expression.*;
import de.dhbwstuttgart.target.tree.type.TargetFunNType; import de.dhbwstuttgart.target.tree.type.TargetFunNType;
@ -23,10 +24,7 @@ import org.objectweb.asm.Opcodes;
import java.io.IOException; import java.io.IOException;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.Arrays; import java.util.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -54,7 +52,7 @@ public class TestCodegen {
result.putAll(classes.stream().map(cli -> { result.putAll(classes.stream().map(cli -> {
try { try {
return generateClass(converter.convert(cli), classLoader, compiler); return generateClass(converter.convert(cli), classLoader, converter);
} catch (IOException exception) { } catch (IOException exception) {
throw new RuntimeException(exception); throw new RuntimeException(exception);
} }
@ -69,14 +67,14 @@ public class TestCodegen {
} }
public static Class<?> generateClass(TargetStructure clazz, IByteArrayClassLoader classLoader) throws IOException, ClassNotFoundException { public static Class<?> generateClass(TargetStructure clazz, IByteArrayClassLoader classLoader) throws IOException, ClassNotFoundException {
Codegen codegen = new Codegen(clazz, new JavaTXCompiler(List.of())); Codegen codegen = new Codegen(clazz, new JavaTXCompiler(List.of()), null);
var code = codegen.generate(); var code = codegen.generate();
writeClassFile(clazz.qualifiedName().getClassName(), code); writeClassFile(clazz.qualifiedName().getClassName(), code);
return classLoader.loadClass(code); return classLoader.loadClass(code);
} }
public static Class<?> generateClass(TargetStructure clazz, IByteArrayClassLoader classLoader, JavaTXCompiler compiler) throws IOException { public static Class<?> generateClass(TargetStructure clazz, IByteArrayClassLoader classLoader, ASTToTargetAST converter) throws IOException {
Codegen codegen = new Codegen(clazz, compiler); Codegen codegen = new Codegen(clazz, converter.compiler, converter);
var code = codegen.generate(); var code = codegen.generate();
writeClassFile(clazz.qualifiedName().getClassName(), code); writeClassFile(clazz.qualifiedName().getClassName(), code);
return classLoader.loadClass(code); return classLoader.loadClass(code);
@ -93,7 +91,7 @@ public class TestCodegen {
var result = classes.stream().map(cli -> { var result = classes.stream().map(cli -> {
try { try {
return generateClass(converter.convert(cli), classLoader, compiler); return generateClass(converter.convert(cli), classLoader, converter);
} catch (IOException exception) { } catch (IOException exception) {
throw new RuntimeException(exception); throw new RuntimeException(exception);
} }
@ -272,7 +270,8 @@ public class TestCodegen {
var targetClass = new TargetClass(Opcodes.ACC_PUBLIC, new JavaClassName("CGLambda")); var targetClass = new TargetClass(Opcodes.ACC_PUBLIC, new JavaClassName("CGLambda"));
targetClass.addConstructor(Opcodes.ACC_PUBLIC, List.of(), new TargetBlock(List.of(new TargetMethodCall(null, new TargetSuper(TargetType.Object), List.of(), TargetType.Object, "<init>", false, false, false)))); targetClass.addConstructor(Opcodes.ACC_PUBLIC, List.of(), new TargetBlock(List.of(new TargetMethodCall(null, new TargetSuper(TargetType.Object), List.of(), TargetType.Object, "<init>", false, false, false))));
targetClass.addMethod(Opcodes.ACC_PUBLIC, "lambda", List.of(), TargetType.Integer, new TargetBlock(List.of(new TargetVarDecl(interfaceType, "by2", new TargetLambdaExpression(interfaceType, List.of(), List.of(new MethodParameter(TargetType.Integer, "num")), TargetType.Integer, new TargetBlock(List.of(new TargetReturn(new TargetBinaryOp.Mul(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "num"), new TargetLiteral.IntLiteral(2))))))), new TargetReturn(new TargetCast(TargetType.Integer, new TargetMethodCall(TargetType.Object, TargetType.Object, List.of(TargetType.Object), new TargetLocalVar(interfaceType, "by2"), List.of(new TargetLiteral.IntLiteral(10)), interfaceType, "apply", false, true, false)))))); var signature = new TargetMethod.Signature(Set.of(), List.of(new MethodParameter(TargetType.Integer, "num")), TargetType.Integer);
targetClass.addMethod(Opcodes.ACC_PUBLIC, "lambda", List.of(), TargetType.Integer, new TargetBlock(List.of(new TargetVarDecl(interfaceType, "by2", new TargetLambdaExpression(interfaceType, List.of(), signature, new TargetBlock(List.of(new TargetReturn(new TargetBinaryOp.Mul(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "num"), new TargetLiteral.IntLiteral(2))))))), new TargetReturn(new TargetCast(TargetType.Integer, new TargetMethodCall(TargetType.Object, TargetType.Object, List.of(TargetType.Object), new TargetLocalVar(interfaceType, "by2"), List.of(new TargetLiteral.IntLiteral(10)), interfaceType, "apply", false, true, false))))));
var clazz = generateClass(targetClass, classLoader); var clazz = generateClass(targetClass, classLoader);
var instance = clazz.getConstructor().newInstance(); var instance = clazz.getConstructor().newInstance();
assertEquals(clazz.getDeclaredMethod("lambda").invoke(instance), 20); assertEquals(clazz.getDeclaredMethod("lambda").invoke(instance), 20);