Make switches work with set types

This commit is contained in:
Daniel Holle 2023-08-16 17:13:28 +02:00
parent be5591f7dc
commit 762d344e42
6 changed files with 80 additions and 24 deletions

View File

@ -5,10 +5,10 @@ import java.lang.Float;
record Rec(Integer a, Object b) {}
public class Switch {
main(Object o) {
Integer main(Object o) {
return switch (o) {
case Rec(Integer a, Integer b) -> { yield a + b; }
case Rec(Integer a, Float b) -> { yield a * b; }
case Rec(Integer a, Float b) -> { yield a + 10; }
case Rec(Integer a, Rec(Integer b, Integer c)) -> { yield a + b + c; }
case Integer i -> { yield i; }
default -> { yield 0; }

View File

@ -28,7 +28,12 @@ public class Codegen {
public Codegen(TargetStructure clazz, JavaTXCompiler compiler) {
this.clazz = clazz;
this.className = clazz.qualifiedName();
this.cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
this.cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS) {
@Override
protected ClassLoader getClassLoader() {
return compiler.getClassLoader();
}
};
this.compiler = compiler;
}
@ -1084,7 +1089,7 @@ public class Codegen {
private void generateEnhancedSwitch(State state, TargetSwitch aSwitch) {
var mv = state.mv;
generate(state, aSwitch.expr());
var tmp = state.localCounter++;
var tmp = state.localCounter;
mv.visitInsn(DUP);
mv.visitVarInsn(ASTORE, tmp);
@ -1148,12 +1153,12 @@ public class Codegen {
if (cse.labels().size() == 1) {
var label = cse.labels().get(0);
if (label instanceof Guard gd){
if (label instanceof Guard gd) {
state.mv.visitVarInsn(ALOAD, tmp);
bindPattern(state, aSwitch.expr().type(), gd.inner(), start);
bindPattern(state, aSwitch.expr().type(), gd.inner(), start, i, 1);
} else if (label instanceof TargetPattern pat) {
state.mv.visitVarInsn(ALOAD, tmp);
bindPattern(state, aSwitch.expr().type(), pat, start);
bindPattern(state, aSwitch.expr().type(), pat, start, i, 1);
}
if (label instanceof Guard gd) {
@ -1194,30 +1199,51 @@ public class Codegen {
state.exitScope();
}
private void bindPattern(State state, TargetType type, TargetPattern pat, Label start) {
private void bindPattern(State state, TargetType type, TargetPattern pat, Label start, int index, int depth) {
if (pat.type() instanceof TargetPrimitiveType)
boxPrimitive(state, pat.type());
state.mv.visitInsn(DUP);
state.mv.visitTypeInsn(INSTANCEOF, pat.type().getInternalName());
var cont = new Label();
state.mv.visitJumpInsn(IFNE, cont);
for (var i = 0; i < depth; i++) {
state.mv.visitInsn(POP);
}
state.mv.visitVarInsn(ALOAD, state.switchResultValue.peek());
state.mv.visitLdcInsn(index + 1);
state.mv.visitJumpInsn(GOTO, start);
state.mv.visitLabel(cont);
state.mv.visitTypeInsn(CHECKCAST, pat.type().getInternalName());
if (pat instanceof SimplePattern sp) {
var local = state.createVariable(sp.name(), sp.type());
convertTo(state, type, sp.type());
boxPrimitive(state, sp.type());
state.mv.visitVarInsn(ASTORE, local.index);
} else if (pat instanceof ComplexPattern cp) {
convertTo(state, type, cp.type());
boxPrimitive(state, cp.type());
if (cp.name() != null) {
state.mv.visitInsn(DUP);
var local = state.createVariable(cp.name(), cp.type());
state.mv.visitVarInsn(ASTORE, local.index);
}
var clazz = findClass(new JavaClassName(cp.type().name()));
if (clazz == null) throw new CodeGenException("Class definition for '" + cp.type().name() + "' not found");
// TODO Check if class is a Record
for (var i = 0; i < cp.subPatterns().size(); i++) {
state.mv.visitInsn(DUP);
var subPattern = cp.subPatterns().get(i);
if (i >= clazz.getFieldDecl().size())
throw new CodeGenException("Couldn't find suitable field accessor for '" + cp.type().name() + "'");
var field = clazz.getFieldDecl().get(i);
var fieldType = new TargetRefType(((RefType) field.getType()).getName().toString());
state.mv.visitMethodInsn(INVOKEDYNAMIC, cp.type().getInternalName(), field.getName(), "()" + fieldType.toDescriptor(), false);
convertTo(state, fieldType, subPattern.type());
bindPattern(state, subPattern.type(), subPattern, start);
state.mv.visitMethodInsn(INVOKEVIRTUAL, cp.type().getInternalName(), field.getName(), "()" + fieldType.toDescriptor(), false);
bindPattern(state, subPattern.type(), subPattern, start, index, depth + 1);
}
state.mv.visitInsn(POP);
}
}

View File

@ -74,6 +74,10 @@ public class JavaTXCompiler {
Boolean log = true; //gibt an ob ein Log-File nach System.getProperty("user.dir")+""/logFiles/"" geschrieben werden soll?
public volatile UnifyTaskModel usedTasks = new UnifyTaskModel();
private final DirectoryClassLoader classLoader;
public DirectoryClassLoader getClassLoader() {
return classLoader;
}
public JavaTXCompiler(File sourceFile) throws IOException, ClassNotFoundException {
this(Arrays.asList(sourceFile), null);

View File

@ -355,7 +355,7 @@ public class StatementGenerator {
for (SwitchBlockStatementGroupContext blockstmt : stmt.switchBlockStatementGroup()) {
switchBlocks.add(convert(blockstmt));
}
return new Switch(switched, switchBlocks, switched.getType(), true, stmt.getStart());
return new Switch(switched, switchBlocks, TypePlaceholder.fresh(switched.getOffset()), true, stmt.getStart());
}
// Um switchExpressions als Statement zu behandeln
@ -474,7 +474,10 @@ public class StatementGenerator {
switch (pPattern) {
case TPatternContext tPattern:
TypePatternContext typePattern = tPattern.typePattern();
return new Pattern(typePattern.identifier().getText(), TypeGenerator.convert(typePattern.typeType(), reg, generics), typePattern.getStart());
var text = typePattern.identifier().getText();
var type = TypeGenerator.convert(typePattern.typeType(), reg, generics);
localVars.put(text, type);
return new Pattern(text, type, typePattern.getStart());
case RPatternContext rPattern:
RecordPatternContext recordPattern = rPattern.recordPattern();
return convert(recordPattern);
@ -490,7 +493,10 @@ public class StatementGenerator {
return (Pattern) convert(patternCtx);
}).collect(Collectors.toList());
IdentifierContext identifierCtx = recordPatternCtx.identifier();
return new RecordPattern(subPattern, (identifierCtx != null) ? identifierCtx.getText() : null, TypeGenerator.convert(recordPatternCtx.typeType(), reg, generics), recordPatternCtx.getStart());
var text = (identifierCtx != null) ? identifierCtx.getText() : null;
var type = TypeGenerator.convert(recordPatternCtx.typeType(), reg, generics);
if (text != null) localVars.put(text, type);
return new RecordPattern(subPattern, text, type, recordPatternCtx.getStart());
}
private Statement convert(Java17Parser.WhileloopContext stmt) {

View File

@ -1,10 +1,7 @@
//PL 2018-12-19: Merge chekcen
package de.dhbwstuttgart.typeinference.typeAlgo;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
import de.dhbwstuttgart.exceptions.NotImplementedException;
@ -78,6 +75,7 @@ public class TYPEStmt implements StatementVisitor {
private final TypeInferenceBlockInformation info;
private final ConstraintSet constraintsSet = new ConstraintSet();
private final Stack<Switch> switchStack = new Stack<>();
public TYPEStmt(TypeInferenceBlockInformation info) {
this.info = info;
@ -735,12 +733,19 @@ public class TYPEStmt implements StatementVisitor {
@Override
public void visit(Switch switchStmt) {
// TODO Auto-generated method stub
switchStack.push(switchStmt);
for (var child : switchStmt.getBlocks()) {
child.accept(this);
constraintsSet.addUndConstraint(new Pair(child.getType(), switchStmt.getType(), PairOperator.SMALLERDOT));
}
switchStack.pop();
}
@Override
public void visit(SwitchBlock switchBlock) {
// TODO Auto-generated method stub
for (var stmt : switchBlock.statements) {
stmt.accept(this);
}
}
@Override
@ -750,6 +755,8 @@ public class TYPEStmt implements StatementVisitor {
@Override
public void visit(Yield aYield) {
aYield.retexpr.accept(this);
constraintsSet.addUndConstraint(new Pair(aYield.getType(), switchStack.peek().getType(), PairOperator.SMALLERDOT));
// TODO Auto-generated method stub
}

View File

@ -656,5 +656,18 @@ public class TestComplete {
var classFiles = generateClassFiles(new ByteArrayClassLoader(), "Switch.jav");
var clazz = classFiles.get("Switch");
var instance = clazz.getDeclaredConstructor().newInstance();
var swtch = clazz.getDeclaredMethod("main", Object.class);
var record = classFiles.get("Rec");
var ctor = record.getDeclaredConstructor(Integer.class, Object.class);
var r1 = ctor.newInstance(10, 20);
var r2 = ctor.newInstance(10, 20f);
var r3 = ctor.newInstance(10, r1);
assertEquals(swtch.invoke(instance, r1), 30);
assertEquals(swtch.invoke(instance, r2), 20);
assertEquals(swtch.invoke(instance, r3), 40);
assertEquals(swtch.invoke(instance, 50), 50);
assertEquals(swtch.invoke(instance, "Some string"), 0);
}
}