Add overloading for switches, see #348
All checks were successful
Build and Test with Maven / Build-and-test-with-Maven (push) Successful in 5m1s

This commit is contained in:
Daniel Holle 2024-10-01 17:28:20 +02:00
parent b7979ac7e7
commit 6ccf2a3df6
8 changed files with 161 additions and 14 deletions

View File

@ -0,0 +1,17 @@
import java.lang.Integer;
import java.lang.Double;
import java.lang.Number;
public record R(Number n) {}
public class SwitchOverload {
Number f(Double d) { return d * 2; }
Number f(Integer i) { return i * 5; }
public m(r) {
return switch(r) {
case R(o) -> f(o);
};
}
}

View File

@ -1294,7 +1294,6 @@ public class Codegen {
state.enterScope();
// This is the index to start the switch from
mv.visitInsn(ICONST_0);
if (aSwitch.isExpression())
state.pushSwitch();
// To be able to skip ahead to the next case

View File

@ -8,6 +8,7 @@ import org.antlr.v4.runtime.Token;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
public class RefType extends RefTypeOrTPHOrWildcardOrGeneric
@ -49,10 +50,7 @@ public class RefType extends RefTypeOrTPHOrWildcardOrGeneric
@Override
public int hashCode() {
int hash = 0;
hash += super.hashCode();
hash += this.name.hashCode();//Nur den Name hashen. Sorgt ¼r langsame, aber funktionierende HashMaps
return hash;
return this.name.hashCode();//Nur den Name hashen. Sorgt ¼r langsame, aber funktionierende HashMaps
}
public RefType(JavaClassName fullyQualifiedName, List<RefTypeOrTPHOrWildcardOrGeneric> parameter, Token offset) {
@ -83,6 +81,7 @@ public class RefType extends RefTypeOrTPHOrWildcardOrGeneric
public boolean equals(Object obj)
{
if(obj instanceof RefType){
if (!Objects.equals(this.name, ((RefType) obj).name)) return false;
boolean ret = true;
//if(!(super.equals(obj))) PL 2020-03-12 muss vll. einkommentiert werden

View File

@ -39,6 +39,10 @@ public class ASTToTargetAST {
public final JavaTXCompiler compiler;
public List<RefTypeOrTPHOrWildcardOrGeneric> findAllVariants(RefTypeOrTPHOrWildcardOrGeneric type) {
return javaGenerics().stream().map(generics -> generics.resolve(type)).distinct().toList();
}
public List<GenericsResult> txGenerics() {
return all.stream().map(generics -> new GenericsResult(generics.txGenerics)).toList();
}

View File

@ -134,8 +134,8 @@ public abstract class GenerateGenerics {
final Map<Method, Set<Pair>> familyOfMethods = new HashMap<>();
final Set<PairLT> simplifiedConstraints = new HashSet<>();
final Map<TPH, RefTypeOrTPHOrWildcardOrGeneric> concreteTypes = new HashMap<>();
final Map<TypePlaceholder, TypePlaceholder> equality = new HashMap<>();
Map<TPH, RefTypeOrTPHOrWildcardOrGeneric> concreteTypes = new HashMap<>();
Map<TypePlaceholder, TypePlaceholder> equality = new HashMap<>();
GenerateGenerics(ASTToTargetAST astToTargetAST, ResultSet constraints) {
this.astToTargetAST = astToTargetAST;
@ -154,6 +154,22 @@ public abstract class GenerateGenerics {
System.out.println("Simplified constraints: " + simplifiedConstraints);
}
/*public record GenericsState(Map<TPH, RefTypeOrTPHOrWildcardOrGeneric> concreteTypes, Map<TypePlaceholder, TypePlaceholder> equality) {}
public GenericsState store() {
return new GenericsState(new HashMap<>(concreteTypes), new HashMap<>(equality));
}
public void restore(GenericsState state) {
this.concreteTypes = state.concreteTypes;
this.equality = state.equality;
}
public void addOverlay(TypePlaceholder from, RefTypeOrTPHOrWildcardOrGeneric to) {
if (to instanceof TypePlaceholder t) equality.put(from, t);
else if (to instanceof RefType t) concreteTypes.put(new TPH(from), t);
}*/
Set<TPH> findTypeVariables(RefTypeOrTPHOrWildcardOrGeneric type) {
var result = new HashSet<TPH>();
if (type instanceof TypePlaceholder tph) {

View File

@ -2,23 +2,19 @@ package de.dhbwstuttgart.target.generate;
import de.dhbwstuttgart.exceptions.DebugException;
import de.dhbwstuttgart.exceptions.NotImplementedException;
import de.dhbwstuttgart.parser.NullToken;
import de.dhbwstuttgart.parser.SyntaxTreeGenerator.AssignToLocal;
import de.dhbwstuttgart.parser.scope.JavaClassName;
import de.dhbwstuttgart.syntaxtree.*;
import de.dhbwstuttgart.syntaxtree.factory.PrimitiveMethodsGenerator;
import de.dhbwstuttgart.syntaxtree.statement.*;
import de.dhbwstuttgart.syntaxtree.type.*;
import de.dhbwstuttgart.target.tree.MethodParameter;
import de.dhbwstuttgart.target.tree.TargetGeneric;
import de.dhbwstuttgart.target.tree.TargetMethod;
import de.dhbwstuttgart.target.tree.expression.*;
import de.dhbwstuttgart.target.tree.type.*;
import javax.swing.text.html.Option;
import java.lang.reflect.Modifier;
import java.util.*;
import java.util.stream.Stream;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
public class StatementToTargetExpression implements ASTVisitor {
@ -386,9 +382,94 @@ public class StatementToTargetExpression implements ASTVisitor {
result = new TargetTernary(converter.convert(ternary.getType()), converter.convert(ternary.cond), converter.convert(ternary.iftrue), converter.convert(ternary.iffalse));
}
record TypeVariants(RefTypeOrTPHOrWildcardOrGeneric in, List<RefTypeOrTPHOrWildcardOrGeneric> types) {}
private List<TypeVariants> extractAllPatterns(Pattern pattern) {
return switch (pattern) {
case GuardedPattern guarded -> extractAllPatterns(guarded.getNestedPattern());
case RecordPattern recordPattern -> recordPattern.getSubPattern().stream()
.map(this::extractAllPatterns)
.flatMap(List::stream).toList();
case FormalParameter param -> List.of(new TypeVariants(param.getType(), converter.findAllVariants(param.getType())));
default -> List.of();
};
}
record TypePair(RefTypeOrTPHOrWildcardOrGeneric in, RefTypeOrTPHOrWildcardOrGeneric out) {}
private void cartesianProduct(
List<TypeVariants> variants, int index,
List<RefTypeOrTPHOrWildcardOrGeneric> current,
List<List<RefTypeOrTPHOrWildcardOrGeneric>> result) {
if (index == variants.size()) {
result.add(new ArrayList<>(current));
return;
}
var currentSet = variants.get(index).types;
for (var element: currentSet) {
current.add(element);
cartesianProduct(variants, index + 1, current, result);
current.removeLast();
}
}
private List<List<TypePair>> cartesianProduct(List<TypeVariants> variants) {
var prod = new ArrayList<List<RefTypeOrTPHOrWildcardOrGeneric>>();
cartesianProduct(variants, 0, new ArrayList<>(), prod);
var res = new ArrayList<List<TypePair>>();
for (var list : prod) {
var l = new ArrayList<TypePair>();
for (var i = 0; i < list.size(); i++) {
l.add(new TypePair(variants.get(i).in, list.get(i)));
}
res.add(l);
}
return res;
}
@Override
public void visit(Switch switchStmt) {
var cases = switchStmt.getBlocks().stream().filter(s -> !s.isDefault()).map(converter::convert).toList();
var variants = converter.findAllVariants(switchStmt.getSwitch().getType());
var returns = converter.findAllVariants(switchStmt.getType());
var canBeOverloaded = variants.size() == 1 && returns.size() == 1;
var cases = switchStmt.getBlocks().stream().filter(s -> !s.isDefault()).map(case_ -> {
var overloads = new ArrayList<TargetSwitch.Case>();
if (canBeOverloaded) {
for (var label: case_.getLabels()) {
var product = cartesianProduct(extractAllPatterns(label.getPattern()));
for (var l : product) {
var oldGenerics = converter.generics;
// Set the generics to matching result set
for (var generics : converter.all) {
var java = generics.javaGenerics();
var equals = true;
for (var pair : l) {
if (!java.getType(pair.in).equals(pair.out)) {
equals = false; break;
}
}
if (equals) {
converter.generics = generics;
break;
}
}
overloads.add(converter.convert(case_));
converter.generics = oldGenerics;
}
}
} else {
overloads.add(converter.convert(case_));
}
return overloads;
}).flatMap(List::stream).toList();
TargetSwitch.Case default_ = null;
for (var block : switchStmt.getBlocks()) {

View File

@ -18,6 +18,19 @@ public record TargetMethod(int access, String name, TargetBlock block, Signature
public String getDescriptor() {
return TargetMethod.getDescriptor(returnType, parameters.stream().map(MethodParameter::pattern).map(TargetPattern::type).toArray(TargetType[]::new));
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Signature signature = (Signature) o;
return Objects.equals(parameters, signature.parameters);
}
@Override
public int hashCode() {
return Objects.hash(parameters);
}
}
public static String getDescriptor(TargetType returnType, TargetType... parameters) {

View File

@ -849,6 +849,24 @@ public class TestComplete {
assertEquals(m2.invoke(instance, 10), 10);
}
@Test
public void testOverloadSwitch() throws Exception {
var classFiles = generateClassFiles(new ByteArrayClassLoader(), "SwitchOverload.jav");
var clazz = classFiles.get("SwitchOverload");
var R = classFiles.get("R");
var rctor = R.getDeclaredConstructor(Number.class);
var instance = clazz.getDeclaredConstructor().newInstance();
var m = clazz.getDeclaredMethod("m", R);
var x = rctor.newInstance(10);
var d = rctor.newInstance(20.0);
assertEquals(m.invoke(instance, x), 50);
assertEquals(m.invoke(instance, d), 40.0);
}
@Test
public void testInterfaces() throws Exception {
var classFiles = generateClassFiles(new ByteArrayClassLoader(), "Interfaces.jav");