diff --git a/resources/bytecode/javFiles/SwitchString.jav b/resources/bytecode/javFiles/SwitchString.jav new file mode 100644 index 00000000..e8758c25 --- /dev/null +++ b/resources/bytecode/javFiles/SwitchString.jav @@ -0,0 +1,14 @@ +import java.lang.Integer; +import java.lang.String; +import java.lang.Object; + +public class SwitchString { + main(o) { + return switch (o) { + case "AaAaAa" -> 1; // These two have the same hash code! + case "AaAaBB" -> 2; + case "test", "TEST" -> 3; + default -> 4; + }; + } +} \ No newline at end of file diff --git a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java index 73b4ba78..5a758899 100644 --- a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java +++ b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java @@ -1111,21 +1111,19 @@ public class Codegen { var mt = MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, Object[].class); var bootstrap = new Handle(H_INVOKESTATIC, "java/lang/runtime/SwitchBootstraps", "typeSwitch", mt.toMethodDescriptorString(), false); - var types = new Object[aSwitch.cases().size()]; - for (var i = 0; i < types.length; i++) { - var cse = aSwitch.cases().get(i); - var label = cse.labels().get(0); + var types = new ArrayList(aSwitch.cases().size()); + for (var cse : aSwitch.cases()) for (var label : cse.labels()) { if (label instanceof SimplePattern || label instanceof ComplexPattern) - types[i] = Type.getObjectType(label.type().getInternalName()); + types.add(Type.getObjectType(label.type().getInternalName())); else if (label instanceof TargetLiteral lit) - types[i] = lit.value(); + types.add(lit.value()); else if (label instanceof Guard guard) - types[i] = Type.getObjectType(guard.inner().type().getInternalName()); + types.add(Type.getObjectType(guard.inner().type().getInternalName())); // TODO Same here we need to evaluate constant; else throw new NotImplementedException(); } - mv.visitInvokeDynamicInsn("typeSwitch", "(Ljava/lang/Object;I)I", bootstrap, types); + mv.visitInvokeDynamicInsn("typeSwitch", "(Ljava/lang/Object;I)I", bootstrap, types.toArray()); var caseLabels = new Label[aSwitch.cases().size()]; var labels = new Label[aSwitch.cases().stream().mapToInt(c -> c.labels().size()).sum()]; diff --git a/src/main/java/de/dhbwstuttgart/typeinference/typeAlgo/TYPEStmt.java b/src/main/java/de/dhbwstuttgart/typeinference/typeAlgo/TYPEStmt.java index 7f1a9159..88912144 100644 --- a/src/main/java/de/dhbwstuttgart/typeinference/typeAlgo/TYPEStmt.java +++ b/src/main/java/de/dhbwstuttgart/typeinference/typeAlgo/TYPEStmt.java @@ -735,8 +735,13 @@ public class TYPEStmt implements StatementVisitor { public void visit(Switch switchStmt) { switchStack.push(switchStmt); for (var child : switchStmt.getBlocks()) { - for (var label : child.getLabels()) if (label.getExpression() instanceof Pattern) - constraintsSet.addUndConstraint(new Pair(label.getExpression().getType(), switchStmt.getSwitch().getType(), PairOperator.SMALLERDOT)); + for (var label : child.getLabels()) { + if (label.getExpression() instanceof Pattern) { + constraintsSet.addUndConstraint(new Pair(label.getExpression().getType(), switchStmt.getSwitch().getType(), PairOperator.SMALLERDOT)); + } else { + constraintsSet.addUndConstraint(new Pair(label.getType(), switchStmt.getSwitch().getType(), PairOperator.SMALLERDOT)); + } + } child.accept(this); constraintsSet.addUndConstraint(new Pair(child.getType(), switchStmt.getType(), PairOperator.SMALLERDOT)); diff --git a/src/test/java/TestComplete.java b/src/test/java/TestComplete.java index 1d76bc68..5f3d592e 100644 --- a/src/test/java/TestComplete.java +++ b/src/test/java/TestComplete.java @@ -670,4 +670,18 @@ public class TestComplete { assertEquals(swtch.invoke(instance, 50), 50); assertEquals(swtch.invoke(instance, "Some string"), 0); } + + @Test + public void testStringSwitch() throws Exception { + var classFiles = generateClassFiles(new ByteArrayClassLoader(), "SwitchString.jav"); + var clazz = classFiles.get("SwitchString"); + var instance = clazz.getDeclaredConstructor().newInstance(); + var main = clazz.getDeclaredMethod("main", String.class); + + assertEquals(main.invoke(instance, "AaAaAa"), 1); + assertEquals(main.invoke(instance, "AaAaBB"), 2); + assertEquals(main.invoke(instance, "test"), 3); + assertEquals(main.invoke(instance, "TEST"), 3); + assertEquals(main.invoke(instance, "awawa"), 4); + } }