diff --git a/resources/bytecode/javFiles/SwitchOverload.jav b/resources/bytecode/javFiles/SwitchOverload.jav new file mode 100644 index 00000000..0c8f0179 --- /dev/null +++ b/resources/bytecode/javFiles/SwitchOverload.jav @@ -0,0 +1,17 @@ +import java.lang.Integer; +import java.lang.Double; +import java.lang.Number; + +public record R(Number n) {} + +public class SwitchOverload { + + Number f(Double d) { return d * 2; } + Number f(Integer i) { return i * 5; } + + public m(r) { + return switch(r) { + case R(o) -> f(o); + }; + } +} \ No newline at end of file diff --git a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java index 513944dd..12ea9e3f 100644 --- a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java +++ b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java @@ -1294,8 +1294,7 @@ public class Codegen { state.enterScope(); // This is the index to start the switch from mv.visitInsn(ICONST_0); - if (aSwitch.isExpression()) - state.pushSwitch(); + state.pushSwitch(); // To be able to skip ahead to the next case var start = new Label(); diff --git a/src/main/java/de/dhbwstuttgart/syntaxtree/type/RefType.java b/src/main/java/de/dhbwstuttgart/syntaxtree/type/RefType.java index 73ca1751..f93dba21 100644 --- a/src/main/java/de/dhbwstuttgart/syntaxtree/type/RefType.java +++ b/src/main/java/de/dhbwstuttgart/syntaxtree/type/RefType.java @@ -8,6 +8,7 @@ import org.antlr.v4.runtime.Token; import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.Objects; public class RefType extends RefTypeOrTPHOrWildcardOrGeneric @@ -49,10 +50,7 @@ public class RefType extends RefTypeOrTPHOrWildcardOrGeneric @Override public int hashCode() { - int hash = 0; - hash += super.hashCode(); - hash += this.name.hashCode();//Nur den Name hashen. Sorgt für langsame, aber funktionierende HashMaps - return hash; + return this.name.hashCode();//Nur den Name hashen. Sorgt für langsame, aber funktionierende HashMaps } public RefType(JavaClassName fullyQualifiedName, List parameter, Token offset) { @@ -83,6 +81,7 @@ public class RefType extends RefTypeOrTPHOrWildcardOrGeneric public boolean equals(Object obj) { if(obj instanceof RefType){ + if (!Objects.equals(this.name, ((RefType) obj).name)) return false; boolean ret = true; //if(!(super.equals(obj))) PL 2020-03-12 muss vll. einkommentiert werden diff --git a/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java b/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java index 36fa146b..148b6b19 100644 --- a/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java +++ b/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java @@ -39,6 +39,10 @@ public class ASTToTargetAST { public final JavaTXCompiler compiler; + public List findAllVariants(RefTypeOrTPHOrWildcardOrGeneric type) { + return javaGenerics().stream().map(generics -> generics.resolve(type)).distinct().toList(); + } + public List txGenerics() { return all.stream().map(generics -> new GenericsResult(generics.txGenerics)).toList(); } diff --git a/src/main/java/de/dhbwstuttgart/target/generate/GenerateGenerics.java b/src/main/java/de/dhbwstuttgart/target/generate/GenerateGenerics.java index a0d7c5c7..49cdc775 100644 --- a/src/main/java/de/dhbwstuttgart/target/generate/GenerateGenerics.java +++ b/src/main/java/de/dhbwstuttgart/target/generate/GenerateGenerics.java @@ -134,8 +134,8 @@ public abstract class GenerateGenerics { final Map> familyOfMethods = new HashMap<>(); final Set simplifiedConstraints = new HashSet<>(); - final Map concreteTypes = new HashMap<>(); - final Map equality = new HashMap<>(); + Map concreteTypes = new HashMap<>(); + Map equality = new HashMap<>(); GenerateGenerics(ASTToTargetAST astToTargetAST, ResultSet constraints) { this.astToTargetAST = astToTargetAST; @@ -154,6 +154,22 @@ public abstract class GenerateGenerics { System.out.println("Simplified constraints: " + simplifiedConstraints); } + /*public record GenericsState(Map concreteTypes, Map equality) {} + + public GenericsState store() { + return new GenericsState(new HashMap<>(concreteTypes), new HashMap<>(equality)); + } + + public void restore(GenericsState state) { + this.concreteTypes = state.concreteTypes; + this.equality = state.equality; + } + + public void addOverlay(TypePlaceholder from, RefTypeOrTPHOrWildcardOrGeneric to) { + if (to instanceof TypePlaceholder t) equality.put(from, t); + else if (to instanceof RefType t) concreteTypes.put(new TPH(from), t); + }*/ + Set findTypeVariables(RefTypeOrTPHOrWildcardOrGeneric type) { var result = new HashSet(); if (type instanceof TypePlaceholder tph) { diff --git a/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java b/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java index b94c37cd..cedf3a85 100644 --- a/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java +++ b/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java @@ -2,23 +2,19 @@ package de.dhbwstuttgart.target.generate; import de.dhbwstuttgart.exceptions.DebugException; import de.dhbwstuttgart.exceptions.NotImplementedException; -import de.dhbwstuttgart.parser.NullToken; import de.dhbwstuttgart.parser.SyntaxTreeGenerator.AssignToLocal; import de.dhbwstuttgart.parser.scope.JavaClassName; import de.dhbwstuttgart.syntaxtree.*; -import de.dhbwstuttgart.syntaxtree.factory.PrimitiveMethodsGenerator; import de.dhbwstuttgart.syntaxtree.statement.*; import de.dhbwstuttgart.syntaxtree.type.*; import de.dhbwstuttgart.target.tree.MethodParameter; -import de.dhbwstuttgart.target.tree.TargetGeneric; import de.dhbwstuttgart.target.tree.TargetMethod; import de.dhbwstuttgart.target.tree.expression.*; import de.dhbwstuttgart.target.tree.type.*; -import javax.swing.text.html.Option; import java.lang.reflect.Modifier; import java.util.*; -import java.util.stream.Stream; +import java.util.stream.Collectors; import java.util.stream.StreamSupport; public class StatementToTargetExpression implements ASTVisitor { @@ -386,9 +382,94 @@ public class StatementToTargetExpression implements ASTVisitor { result = new TargetTernary(converter.convert(ternary.getType()), converter.convert(ternary.cond), converter.convert(ternary.iftrue), converter.convert(ternary.iffalse)); } + record TypeVariants(RefTypeOrTPHOrWildcardOrGeneric in, List types) {} + + private List extractAllPatterns(Pattern pattern) { + return switch (pattern) { + case GuardedPattern guarded -> extractAllPatterns(guarded.getNestedPattern()); + case RecordPattern recordPattern -> recordPattern.getSubPattern().stream() + .map(this::extractAllPatterns) + .flatMap(List::stream).toList(); + case FormalParameter param -> List.of(new TypeVariants(param.getType(), converter.findAllVariants(param.getType()))); + default -> List.of(); + }; + } + + record TypePair(RefTypeOrTPHOrWildcardOrGeneric in, RefTypeOrTPHOrWildcardOrGeneric out) {} + + private void cartesianProduct( + List variants, int index, + List current, + List> result) { + + if (index == variants.size()) { + result.add(new ArrayList<>(current)); + return; + } + var currentSet = variants.get(index).types; + for (var element: currentSet) { + current.add(element); + cartesianProduct(variants, index + 1, current, result); + current.removeLast(); + } + } + + private List> cartesianProduct(List variants) { + var prod = new ArrayList>(); + cartesianProduct(variants, 0, new ArrayList<>(), prod); + + var res = new ArrayList>(); + for (var list : prod) { + var l = new ArrayList(); + for (var i = 0; i < list.size(); i++) { + l.add(new TypePair(variants.get(i).in, list.get(i))); + } + res.add(l); + } + return res; + } + @Override public void visit(Switch switchStmt) { - var cases = switchStmt.getBlocks().stream().filter(s -> !s.isDefault()).map(converter::convert).toList(); + var variants = converter.findAllVariants(switchStmt.getSwitch().getType()); + var returns = converter.findAllVariants(switchStmt.getType()); + var canBeOverloaded = variants.size() == 1 && returns.size() == 1; + + var cases = switchStmt.getBlocks().stream().filter(s -> !s.isDefault()).map(case_ -> { + var overloads = new ArrayList(); + + if (canBeOverloaded) { + for (var label: case_.getLabels()) { + var product = cartesianProduct(extractAllPatterns(label.getPattern())); + + for (var l : product) { + var oldGenerics = converter.generics; + + // Set the generics to matching result set + for (var generics : converter.all) { + var java = generics.javaGenerics(); + var equals = true; + for (var pair : l) { + if (!java.getType(pair.in).equals(pair.out)) { + equals = false; break; + } + } + if (equals) { + converter.generics = generics; + break; + } + } + + overloads.add(converter.convert(case_)); + converter.generics = oldGenerics; + } + } + } else { + overloads.add(converter.convert(case_)); + } + + return overloads; + }).flatMap(List::stream).toList(); TargetSwitch.Case default_ = null; for (var block : switchStmt.getBlocks()) { diff --git a/src/main/java/de/dhbwstuttgart/target/tree/TargetMethod.java b/src/main/java/de/dhbwstuttgart/target/tree/TargetMethod.java index bdabc81e..2f3a689b 100644 --- a/src/main/java/de/dhbwstuttgart/target/tree/TargetMethod.java +++ b/src/main/java/de/dhbwstuttgart/target/tree/TargetMethod.java @@ -18,6 +18,19 @@ public record TargetMethod(int access, String name, TargetBlock block, Signature public String getDescriptor() { return TargetMethod.getDescriptor(returnType, parameters.stream().map(MethodParameter::pattern).map(TargetPattern::type).toArray(TargetType[]::new)); } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Signature signature = (Signature) o; + return Objects.equals(parameters, signature.parameters); + } + + @Override + public int hashCode() { + return Objects.hash(parameters); + } } public static String getDescriptor(TargetType returnType, TargetType... parameters) { diff --git a/src/test/java/TestComplete.java b/src/test/java/TestComplete.java index b6b26597..8f58d93d 100644 --- a/src/test/java/TestComplete.java +++ b/src/test/java/TestComplete.java @@ -849,6 +849,24 @@ public class TestComplete { assertEquals(m2.invoke(instance, 10), 10); } + @Test + public void testOverloadSwitch() throws Exception { + var classFiles = generateClassFiles(new ByteArrayClassLoader(), "SwitchOverload.jav"); + var clazz = classFiles.get("SwitchOverload"); + + var R = classFiles.get("R"); + var rctor = R.getDeclaredConstructor(Number.class); + + var instance = clazz.getDeclaredConstructor().newInstance(); + var m = clazz.getDeclaredMethod("m", R); + + var x = rctor.newInstance(10); + var d = rctor.newInstance(20.0); + + assertEquals(m.invoke(instance, x), 50); + assertEquals(m.invoke(instance, d), 40.0); + } + @Test public void testInterfaces() throws Exception { var classFiles = generateClassFiles(new ByteArrayClassLoader(), "Interfaces.jav");