From cb7d0e22cc248caf16c59389378a5818d8f4feb9 Mon Sep 17 00:00:00 2001 From: Daniel Holle Date: Tue, 9 Apr 2024 14:58:43 +0200 Subject: [PATCH] Fix #314 --- resources/bytecode/javFiles/Bug314.jav | 13 ++++ .../de/dhbwstuttgart/bytecode/Codegen.java | 68 +++++++++++++------ .../de/dhbwstuttgart/core/JavaTXCompiler.java | 2 +- .../target/generate/ASTToTargetAST.java | 6 +- .../generate/StatementToTargetExpression.java | 6 +- .../target/tree/TargetMethod.java | 11 ++- .../expression/TargetLambdaExpression.java | 3 +- src/test/java/TestComplete.java | 12 ++++ src/test/java/targetast/TestCodegen.java | 19 +++--- 9 files changed, 102 insertions(+), 38 deletions(-) create mode 100644 resources/bytecode/javFiles/Bug314.jav diff --git a/resources/bytecode/javFiles/Bug314.jav b/resources/bytecode/javFiles/Bug314.jav new file mode 100644 index 00000000..7f20d6d2 --- /dev/null +++ b/resources/bytecode/javFiles/Bug314.jav @@ -0,0 +1,13 @@ +import java.lang.Integer; +import java.util.List; +import java.util.ArrayList; +import java.util.stream.Stream; +import java.util.function.Predicate; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class Bug314 { + public List convert(List in) { + return in.stream().filter(x -> x > 5).collect(Collectors.toList()); + } +} \ No newline at end of file diff --git a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java index fd5e737f..81e9eacb 100644 --- a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java +++ b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java @@ -6,6 +6,8 @@ import de.dhbwstuttgart.parser.NullToken; import de.dhbwstuttgart.parser.scope.JavaClassName; import de.dhbwstuttgart.syntaxtree.ClassOrInterface; import de.dhbwstuttgart.syntaxtree.type.RefType; +import de.dhbwstuttgart.target.generate.ASTToTargetAST; +import de.dhbwstuttgart.target.generate.StatementToTargetExpression; import de.dhbwstuttgart.target.tree.*; import de.dhbwstuttgart.target.tree.expression.*; import de.dhbwstuttgart.target.tree.type.*; @@ -14,6 +16,7 @@ import org.objectweb.asm.*; import java.lang.invoke.*; import java.lang.reflect.Modifier; import java.util.*; +import java.util.stream.IntStream; import static org.objectweb.asm.Opcodes.*; import static de.dhbwstuttgart.target.tree.expression.TargetBinaryOp.*; @@ -26,8 +29,9 @@ public class Codegen { private int lambdaCounter = 0; private final HashMap lambdas = new HashMap<>(); private final JavaTXCompiler compiler; + private final ASTToTargetAST converter; - public Codegen(TargetStructure clazz, JavaTXCompiler compiler) { + public Codegen(TargetStructure clazz, JavaTXCompiler compiler, ASTToTargetAST converter) { this.clazz = clazz; this.className = clazz.qualifiedName().getClassName(); this.cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS) { @@ -37,6 +41,7 @@ public class Codegen { } }; this.compiler = compiler; + this.converter = converter; } private record LocalVar(int index, String name, TargetType type) { @@ -721,15 +726,49 @@ public class Codegen { private void generateLambdaExpression(State state, TargetLambdaExpression lambda) { var mv = state.mv; + String methodName = "apply"; + TargetMethod.Signature signature = null; + + if (!(state.contextType instanceof TargetFunNType ctx)) { + var intf = compiler.getClass(new JavaClassName(state.contextType.name())); + if (intf != null) { + var method = intf.getMethods().stream().filter(m -> Modifier.isAbstract(m.modifier)).findFirst().orElseThrow(); + methodName = method.getName(); + var methodParams = new ArrayList(); + for (var i = 0; i < lambda.signature().parameters().size(); i++) { + var param = lambda.signature().parameters().get(i); + var tpe = converter.convert(method.getParameterList().getParameterAt(i).getType()); + methodParams.add(param.withType(tpe)); + } + var retType = converter.convert(method.getReturnType()); + signature = new TargetMethod.Signature(Set.of(), methodParams, retType); + } + } + if (signature == null) { + signature = new TargetMethod.Signature(Set.of(), lambda.signature().parameters().stream().map(par -> par.withType(TargetType.Object)).toList(), TargetType.Object); + } + + signature = new TargetMethod.Signature( + signature.generics(), + signature.parameters().stream().map(par -> + par.withType(par.pattern().type() instanceof TargetGenericType ? TargetType.Object : par.pattern().type()) + ).toList(), + signature.returnType() instanceof TargetGenericType ? TargetType.Object : signature.returnType() + ); + + var parameters = new ArrayList<>(lambda.captures()); + parameters.addAll(signature.parameters()); + var implSignature = new TargetMethod.Signature(Set.of(), parameters, lambda.signature().returnType()); + + // Normalize + TargetMethod impl; if (lambdas.containsKey(lambda)) { impl = lambdas.get(lambda); } else { var name = "lambda$" + lambdaCounter++; - var parameters = new ArrayList<>(lambda.captures()); - parameters.addAll(lambda.params().stream().map(param -> param.pattern().type() instanceof TargetGenericType ? param.withType(TargetType.Object) : param).toList()); - impl = new TargetMethod(0, name, lambda.block(), new TargetMethod.Signature(Set.of(), parameters, lambda.returnType() instanceof TargetGenericType ? TargetType.Object : lambda.returnType()), null); + impl = new TargetMethod(0, name, lambda.block(), implSignature, null); generateMethod(impl); lambdas.put(lambda, impl); } @@ -737,17 +776,7 @@ public class Codegen { var mt = MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, MethodType.class, MethodHandle.class, MethodType.class); var bootstrap = new Handle(H_INVOKESTATIC, "java/lang/invoke/LambdaMetafactory", "metafactory", mt.toMethodDescriptorString(), false); - var handle = new Handle(H_INVOKEVIRTUAL, clazz.getName(), impl.name(), impl.getDescriptor(), false); - - // TODO maybe make this a function? - var desugared = "("; - for (var param : lambda.params()) - desugared += "Ljava/lang/Object;"; - desugared += ")"; - if (lambda.returnType() != null) - desugared += "Ljava/lang/Object;"; - else - desugared += "V"; + var handle = new Handle(H_INVOKEVIRTUAL, clazz.getName(), impl.name(), implSignature.getDescriptor(), false); var params = new ArrayList(); params.add(new TargetRefType(clazz.qualifiedName().getClassName())); @@ -760,12 +789,7 @@ public class Codegen { mv.visitVarInsn(ALOAD, state.scope.get(pattern.name()).index); } - String methodName; - var intf = compiler.getClass(new JavaClassName(state.contextType.name())); - if (intf == null) methodName = "apply"; // TODO Weird fallback logic here - else methodName = intf.getMethods().stream().filter(m -> Modifier.isAbstract(m.modifier)).findFirst().orElseThrow().getName(); - - mv.visitInvokeDynamicInsn(methodName, descriptor, bootstrap, Type.getType(desugared), handle, Type.getType(TargetMethod.getDescriptor(impl.signature().returnType(), lambda.params().stream().map(mp -> mp.pattern().type()).toArray(TargetType[]::new)))); + mv.visitInvokeDynamicInsn(methodName, descriptor, bootstrap, Type.getType(signature.getSignature()), handle, Type.getType(signature.getDescriptor())); } private void generate(State state, TargetExpression expr) { @@ -876,6 +900,8 @@ public class Codegen { case TargetLocalVar localVar: { LocalVar local = state.scope.get(localVar.name()); mv.visitVarInsn(ALOAD, local.index()); + // This is a bit weird but sometimes the types don't match (see lambda expressions) + convertTo(state, local.type(), localVar.type()); unboxPrimitive(state, local.type()); break; } diff --git a/src/main/java/de/dhbwstuttgart/core/JavaTXCompiler.java b/src/main/java/de/dhbwstuttgart/core/JavaTXCompiler.java index b5300198..46522989 100644 --- a/src/main/java/de/dhbwstuttgart/core/JavaTXCompiler.java +++ b/src/main/java/de/dhbwstuttgart/core/JavaTXCompiler.java @@ -754,7 +754,7 @@ public class JavaTXCompiler { var converter = new ASTToTargetAST(this, typeInferenceResult, sf, classLoader); var generatedClasses = new HashMap(); for (var clazz : sf.getClasses()) { - var codegen = new Codegen(converter.convert(clazz), this); + var codegen = new Codegen(converter.convert(clazz), this, converter); var code = codegen.generate(); generatedClasses.put(clazz.getClassName(), code); converter.auxiliaries.forEach((name, source) -> { diff --git a/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java b/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java index a6f42d80..1c485d7e 100644 --- a/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java +++ b/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java @@ -32,7 +32,7 @@ public class ASTToTargetAST { public static RefType OBJECT = ASTFactory.createObjectType(); // TODO It would be better if I could call this directly but the hashcode seems to change protected List all; - protected Generics generics; + public Generics generics; final Map> userDefinedGenerics = new HashMap<>(); public final JavaTXCompiler compiler; @@ -173,7 +173,7 @@ public class ASTToTargetAST { else return new TargetClass(input.getModifiers(), input.getClassName(), convert(input.getSuperClass(), generics.javaGenerics), javaGenerics, txGenerics, superInterfaces, constructors, staticConstructor, fields, methods); } - private List convert(ParameterList input, GenerateGenerics generics) { + public List convert(ParameterList input, GenerateGenerics generics) { return input.getFormalparalist().stream().map(param -> new MethodParameter((TargetPattern) convert(param)) ).toList(); @@ -447,7 +447,7 @@ public class ASTToTargetAST { public Map auxiliaries = new HashMap<>(); - protected TargetType convert(RefTypeOrTPHOrWildcardOrGeneric input) { + public TargetType convert(RefTypeOrTPHOrWildcardOrGeneric input) { return convert(input, generics.javaGenerics); } diff --git a/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java b/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java index 7196b78c..c6ee9939 100644 --- a/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java +++ b/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java @@ -9,6 +9,8 @@ import de.dhbwstuttgart.syntaxtree.factory.PrimitiveMethodsGenerator; import de.dhbwstuttgart.syntaxtree.statement.*; import de.dhbwstuttgart.syntaxtree.type.*; import de.dhbwstuttgart.target.tree.MethodParameter; +import de.dhbwstuttgart.target.tree.TargetGeneric; +import de.dhbwstuttgart.target.tree.TargetMethod; import de.dhbwstuttgart.target.tree.expression.*; import de.dhbwstuttgart.target.tree.type.*; @@ -81,7 +83,9 @@ public class StatementToTargetExpression implements ASTVisitor { } // Don't look at lambda expressions }); - result = new TargetLambdaExpression(converter.convert(lambdaExpression.getType()), captures, parameters, converter.convert(lambdaExpression.getReturnType()), converter.convert(lambdaExpression.methodBody)); + TargetMethod.Signature signature = new TargetMethod.Signature(Set.of(), parameters, converter.convert(lambdaExpression.getReturnType()));; + var tpe = converter.convert(lambdaExpression.getType()); + result = new TargetLambdaExpression(tpe, captures, signature, converter.convert(lambdaExpression.methodBody)); } @Override diff --git a/src/main/java/de/dhbwstuttgart/target/tree/TargetMethod.java b/src/main/java/de/dhbwstuttgart/target/tree/TargetMethod.java index bbbe7290..5a8b9368 100644 --- a/src/main/java/de/dhbwstuttgart/target/tree/TargetMethod.java +++ b/src/main/java/de/dhbwstuttgart/target/tree/TargetMethod.java @@ -1,6 +1,7 @@ package de.dhbwstuttgart.target.tree; import de.dhbwstuttgart.target.tree.expression.TargetBlock; +import de.dhbwstuttgart.target.tree.expression.TargetPattern; import de.dhbwstuttgart.target.tree.type.TargetType; import org.objectweb.asm.Opcodes; @@ -8,7 +9,15 @@ import java.util.List; import java.util.Set; public record TargetMethod(int access, String name, TargetBlock block, Signature signature, Signature txSignature) { - public record Signature(Set generics, List parameters, TargetType returnType) { } + public record Signature(Set generics, List parameters, TargetType returnType) { + public String getSignature() { + return TargetMethod.getSignature(generics, parameters, returnType); + } + + public String getDescriptor() { + return TargetMethod.getDescriptor(returnType, parameters.stream().map(MethodParameter::pattern).map(TargetPattern::type).toArray(TargetType[]::new)); + } + } public static String getDescriptor(TargetType returnType, TargetType... parameters) { String ret = "("; diff --git a/src/main/java/de/dhbwstuttgart/target/tree/expression/TargetLambdaExpression.java b/src/main/java/de/dhbwstuttgart/target/tree/expression/TargetLambdaExpression.java index 718fb663..882bae48 100644 --- a/src/main/java/de/dhbwstuttgart/target/tree/expression/TargetLambdaExpression.java +++ b/src/main/java/de/dhbwstuttgart/target/tree/expression/TargetLambdaExpression.java @@ -2,9 +2,10 @@ package de.dhbwstuttgart.target.tree.expression; import de.dhbwstuttgart.target.tree.MethodParameter; import de.dhbwstuttgart.target.tree.TargetField; +import de.dhbwstuttgart.target.tree.TargetMethod; import de.dhbwstuttgart.target.tree.type.TargetType; import java.util.List; -public record TargetLambdaExpression(TargetType type, List captures, List params, TargetType returnType, TargetBlock block) implements TargetExpression { +public record TargetLambdaExpression(TargetType type, List captures, TargetMethod.Signature signature, TargetBlock block) implements TargetExpression { } diff --git a/src/test/java/TestComplete.java b/src/test/java/TestComplete.java index d98627b5..59bf694c 100644 --- a/src/test/java/TestComplete.java +++ b/src/test/java/TestComplete.java @@ -5,6 +5,7 @@ import org.junit.Test; import java.lang.reflect.*; import java.util.Arrays; +import java.util.List; import java.util.Vector; import targetast.TestCodegen; @@ -1064,4 +1065,15 @@ public class TestComplete { var instance = clazz.getDeclaredConstructor().newInstance(); clazz.getDeclaredMethod("main").invoke(instance); } + + @Test + public void testBug314() throws Exception { + var classFiles = generateClassFiles(new ByteArrayClassLoader(), "Bug314.jav"); + var clazz = classFiles.get("Bug314"); + var instance = clazz.getDeclaredConstructor().newInstance(); + + var list = List.of(3, 4, 6, 7, 8); + var res = clazz.getDeclaredMethod("convert", List.class).invoke(instance, list); + assertEquals(res, List.of(6, 7, 8)); + } } diff --git a/src/test/java/targetast/TestCodegen.java b/src/test/java/targetast/TestCodegen.java index cb9a8e02..7db04a8b 100644 --- a/src/test/java/targetast/TestCodegen.java +++ b/src/test/java/targetast/TestCodegen.java @@ -8,6 +8,7 @@ import de.dhbwstuttgart.parser.scope.JavaClassName; import de.dhbwstuttgart.target.generate.ASTToTargetAST; import de.dhbwstuttgart.target.tree.MethodParameter; import de.dhbwstuttgart.target.tree.TargetClass; +import de.dhbwstuttgart.target.tree.TargetMethod; import de.dhbwstuttgart.target.tree.TargetStructure; import de.dhbwstuttgart.target.tree.expression.*; import de.dhbwstuttgart.target.tree.type.TargetFunNType; @@ -23,10 +24,7 @@ import org.objectweb.asm.Opcodes; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; @@ -54,7 +52,7 @@ public class TestCodegen { result.putAll(classes.stream().map(cli -> { try { - return generateClass(converter.convert(cli), classLoader, compiler); + return generateClass(converter.convert(cli), classLoader, converter); } catch (IOException exception) { throw new RuntimeException(exception); } @@ -69,14 +67,14 @@ public class TestCodegen { } public static Class generateClass(TargetStructure clazz, IByteArrayClassLoader classLoader) throws IOException, ClassNotFoundException { - Codegen codegen = new Codegen(clazz, new JavaTXCompiler(List.of())); + Codegen codegen = new Codegen(clazz, new JavaTXCompiler(List.of()), null); var code = codegen.generate(); writeClassFile(clazz.qualifiedName().getClassName(), code); return classLoader.loadClass(code); } - public static Class generateClass(TargetStructure clazz, IByteArrayClassLoader classLoader, JavaTXCompiler compiler) throws IOException { - Codegen codegen = new Codegen(clazz, compiler); + public static Class generateClass(TargetStructure clazz, IByteArrayClassLoader classLoader, ASTToTargetAST converter) throws IOException { + Codegen codegen = new Codegen(clazz, converter.compiler, converter); var code = codegen.generate(); writeClassFile(clazz.qualifiedName().getClassName(), code); return classLoader.loadClass(code); @@ -93,7 +91,7 @@ public class TestCodegen { var result = classes.stream().map(cli -> { try { - return generateClass(converter.convert(cli), classLoader, compiler); + return generateClass(converter.convert(cli), classLoader, converter); } catch (IOException exception) { throw new RuntimeException(exception); } @@ -272,7 +270,8 @@ public class TestCodegen { var targetClass = new TargetClass(Opcodes.ACC_PUBLIC, new JavaClassName("CGLambda")); targetClass.addConstructor(Opcodes.ACC_PUBLIC, List.of(), new TargetBlock(List.of(new TargetMethodCall(null, new TargetSuper(TargetType.Object), List.of(), TargetType.Object, "", false, false, false)))); - targetClass.addMethod(Opcodes.ACC_PUBLIC, "lambda", List.of(), TargetType.Integer, new TargetBlock(List.of(new TargetVarDecl(interfaceType, "by2", new TargetLambdaExpression(interfaceType, List.of(), List.of(new MethodParameter(TargetType.Integer, "num")), TargetType.Integer, new TargetBlock(List.of(new TargetReturn(new TargetBinaryOp.Mul(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "num"), new TargetLiteral.IntLiteral(2))))))), new TargetReturn(new TargetCast(TargetType.Integer, new TargetMethodCall(TargetType.Object, TargetType.Object, List.of(TargetType.Object), new TargetLocalVar(interfaceType, "by2"), List.of(new TargetLiteral.IntLiteral(10)), interfaceType, "apply", false, true, false)))))); + var signature = new TargetMethod.Signature(Set.of(), List.of(new MethodParameter(TargetType.Integer, "num")), TargetType.Integer); + targetClass.addMethod(Opcodes.ACC_PUBLIC, "lambda", List.of(), TargetType.Integer, new TargetBlock(List.of(new TargetVarDecl(interfaceType, "by2", new TargetLambdaExpression(interfaceType, List.of(), signature, new TargetBlock(List.of(new TargetReturn(new TargetBinaryOp.Mul(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "num"), new TargetLiteral.IntLiteral(2))))))), new TargetReturn(new TargetCast(TargetType.Integer, new TargetMethodCall(TargetType.Object, TargetType.Object, List.of(TargetType.Object), new TargetLocalVar(interfaceType, "by2"), List.of(new TargetLiteral.IntLiteral(10)), interfaceType, "apply", false, true, false)))))); var clazz = generateClass(targetClass, classLoader); var instance = clazz.getConstructor().newInstance(); assertEquals(clazz.getDeclaredMethod("lambda").invoke(instance), 20);