From 89bbbdacd8f13bbcf081b9280a93a6a68089e706 Mon Sep 17 00:00:00 2001 From: Daniel Holle Date: Wed, 2 Oct 2024 15:09:19 +0200 Subject: [PATCH] Work on pattern matching in function headers --- .../bytecode/javFiles/OverloadPattern.jav | 10 +-- .../bytecode/javFiles/SwitchOverload.jav | 8 +- .../de/dhbwstuttgart/bytecode/Codegen.java | 42 +++++----- .../target/generate/ASTToTargetAST.java | 84 +++++++------------ 4 files changed, 63 insertions(+), 81 deletions(-) diff --git a/resources/bytecode/javFiles/OverloadPattern.jav b/resources/bytecode/javFiles/OverloadPattern.jav index e7f777de..16973a0a 100644 --- a/resources/bytecode/javFiles/OverloadPattern.jav +++ b/resources/bytecode/javFiles/OverloadPattern.jav @@ -2,16 +2,16 @@ import java.lang.Integer; import java.lang.Number; import java.lang.Float; -record Point(Number x, Number y) {} +public record Point(Number x, Number y) {} public class OverloadPattern { - public m(Point(Integer x, Integer y)) { - return x + y; + public m(Point(x, y), Point(z, a)) { + return x + y + z + a; } - public m(Point(Float x, Float y)) { + /*public m(Point(Float x, Float y)) { return x * y; - } + }*/ public m(Integer x) { return x; diff --git a/resources/bytecode/javFiles/SwitchOverload.jav b/resources/bytecode/javFiles/SwitchOverload.jav index 0c8f0179..10208867 100644 --- a/resources/bytecode/javFiles/SwitchOverload.jav +++ b/resources/bytecode/javFiles/SwitchOverload.jav @@ -9,9 +9,13 @@ public class SwitchOverload { Number f(Double d) { return d * 2; } Number f(Integer i) { return i * 5; } - public m(r) { + public m(r, x) { + x = x + x; return switch(r) { - case R(o) -> f(o); + case R(o) -> { + x = x + x; + yield f(o); + } }; } } \ No newline at end of file diff --git a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java index 12ea9e3f..2e884fcd 100644 --- a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java +++ b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java @@ -13,6 +13,7 @@ import de.dhbwstuttgart.target.generate.StatementToTargetExpression; import de.dhbwstuttgart.target.tree.*; import de.dhbwstuttgart.target.tree.expression.*; import de.dhbwstuttgart.target.tree.type.*; +import org.antlr.v4.codegen.Target; import org.objectweb.asm.*; import java.lang.invoke.*; @@ -1521,29 +1522,26 @@ public class Codegen { mv.visitEnd(); } - private int bindLocalVariables(State state, TargetPattern pattern, int offset, int field) { - if (pattern instanceof TargetComplexPattern cp) { - state.mv.visitVarInsn(ALOAD, offset); + private void bindLocalVariables(State state, TargetComplexPattern cp, int offset) { + state.mv.visitVarInsn(ALOAD, offset); - var clazz = findClass(new JavaClassName(cp.type().name())); - if (clazz == null) throw new CodeGenException("Class definition for '" + cp.type().name() + "' not found"); + var clazz = findClass(new JavaClassName(cp.type().name())); + if (clazz == null) throw new CodeGenException("Class definition for '" + cp.type().name() + "' not found"); - for (var i = 0; i < cp.subPatterns().size(); i++) { - var subPattern = cp.subPatterns().get(i); + for (var i = 0; i < cp.subPatterns().size(); i++) { + var subPattern = cp.subPatterns().get(i); - if (i < cp.subPatterns().size() - 1) - state.mv.visitInsn(DUP); + if (i < cp.subPatterns().size() - 1) + state.mv.visitInsn(DUP); - extractField(state, cp.type(), i, clazz); - state.mv.visitTypeInsn(CHECKCAST, subPattern.type().getInternalName()); - state.mv.visitVarInsn(ASTORE, offset); - offset = bindLocalVariables(state, subPattern, offset, i); + extractField(state, cp.type(), i, clazz); + state.mv.visitTypeInsn(CHECKCAST, subPattern.type().getInternalName()); + offset = state.createVariable(subPattern.name(), subPattern.type()).index; + state.mv.visitVarInsn(ASTORE, offset); + if (subPattern instanceof TargetComplexPattern cp2) { + bindLocalVariables(state, cp2, offset); } - } else if (pattern instanceof TargetTypePattern tp) { - offset++; - state.createVariable(tp.name(), tp.type()); - } else throw new NotImplementedException(); - return offset; + } } private void generateMethod(TargetMethod method) { @@ -1562,8 +1560,14 @@ public class Codegen { if (method.block() != null) { mv.visitCode(); var state = new State(method.signature().returnType(), mv, method.isStatic() ? 0 : 1); + var offset = 1; for (var param : method.signature().parameters()) { - bindLocalVariables(state, param.pattern(), 1, 0); + state.createVariable(param.pattern().name(), param.pattern().type()); + } + for (var param : method.signature().parameters()) { + if (param.pattern() instanceof TargetComplexPattern cp) + bindLocalVariables(state, cp, offset); + offset++; } generate(state, method.block()); if (method.signature().returnType() == null) diff --git a/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java b/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java index 148b6b19..c30551cf 100644 --- a/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java +++ b/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java @@ -17,6 +17,7 @@ import de.dhbwstuttgart.target.tree.expression.*; import de.dhbwstuttgart.target.tree.type.*; import de.dhbwstuttgart.typeinference.result.*; +import java.sql.Array; import java.util.*; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -206,7 +207,7 @@ public class ASTToTargetAST { var superInterfaces = input.getSuperInterfaces().stream().map(clazz -> convert(clazz, generics.javaGenerics)).toList(); var constructors = input.getConstructors().stream().map(constructor -> this.convert(input, constructor, finalFieldInitializer)).flatMap(List::stream).toList(); var fields = input.getFieldDecl().stream().map(this::convert).toList(); - var methods = groupOverloads(input, input.getMethods()).stream().flatMap(List::stream).toList(); + var methods = groupOverloads(input, input.getMethods()).stream().map(m -> generatePatternOverloads(input, m)).flatMap(List::stream).toList(); TargetMethod staticConstructor = null; if (input.getStaticInitializer().isPresent()) @@ -271,69 +272,42 @@ public class ASTToTargetAST { return result; } - private String encodeName(String name, ParameterList params) { + private String encodeName(String name, TargetMethod.Signature params) { var res = new StringBuilder(); res.append(name); res.append('$'); - for (var param : params.getFormalparalist()) { - if (param instanceof RecordPattern rp) { - res.append(FunNGenerator.encodeType(convert(param.getType()))); - for (var pattern : rp.getSubPattern()) { - res.append(FunNGenerator.encodeType(convert(pattern.getType()))); - } - } + for (var param : params.parameters()) { + encodeName(param.pattern(), res); } return res.toString(); } - private List convert(ClassOrInterface clazz, List overloadedMethods) { - if (overloadedMethods.size() == 1) { - return convert(clazz, overloadedMethods.getFirst()).stream().map(m -> m.method()).toList(); - } - var methods = new ArrayList(); - for (var method : overloadedMethods) { - var newMethod = new Method( - method.modifier, - method.name, - //encodeName(method.name, method.getParameterList()), - method.getReturnType(), - method.getParameterList(), - method.block, - method.getGenerics(), - method.getOffset() - ); - methods.add(newMethod); - } - - // TODO Record overloading - /*var template = overloadedMethods.get(0); - - var pParams = new ArrayList(); - var i = 0; - for (var par : template.getParameterList()) { - pParams.add(switch (par) { - case RecordPattern rp -> new RecordPattern(rp.getSubPattern(), "par" + i, rp.getType(), new NullToken()); - default -> par; - }); - i++; - } - var params = new ParameterList(pParams, new NullToken()); - - var statements = new ArrayList(); - statements.add(new Return(makeRecordSwitch(template.getReturnType(), params, res), new NullToken())); - var block = new Block(statements, new NullToken()); - var entryPoint = new Method(template.modifier, template.name, template.getReturnType(), params, block, template.getGenerics(), new NullToken()); - - res.add(entryPoint); // TODO*/ - var res = new ArrayList(); - for (var method : methods) { - var overloads = convert(clazz, method); - for (var m : overloads) { - var overload = m.method; - if (res.contains(overload)) throw new CodeGenException("Duplicate method found: " + overload.name() + " with signature " + overload.signature().getSignature()); - res.add(overload); + private void encodeName(TargetPattern pattern, StringBuilder res) { + if (pattern instanceof TargetComplexPattern cp) { + res.append(FunNGenerator.encodeType(cp.type())); + for (var pat : cp.subPatterns()) { + encodeName(pat, res); } + } else { + res.append(FunNGenerator.encodeType(pattern.type())); } + } + + private List generatePatternOverloads(ClassOrInterface clazz, List overloadedMethods) { + if (overloadedMethods.size() <= 1) return overloadedMethods; + // Check if we have a pattern as a parameter + var firstMethod = overloadedMethods.getFirst(); + if (firstMethod.signature().parameters().stream().noneMatch(mp -> mp.pattern() instanceof TargetComplexPattern)) return overloadedMethods; + // Rename existing methods + + var res = new ArrayList(); + for (var method : overloadedMethods) { + var name = encodeName(method.name(), method.signature()); + res.add(new TargetMethod(method.access(), name, method.block(), method.signature(), method.txSignature())); + } + + // Generate dispatch method + return res; }