From 762d344e42285ef088adaf67f485b7bd2dab741d Mon Sep 17 00:00:00 2001 From: Daniel Holle Date: Wed, 16 Aug 2023 17:13:28 +0200 Subject: [PATCH] Make switches work with set types --- resources/bytecode/javFiles/Switch.jav | 4 +- .../de/dhbwstuttgart/bytecode/Codegen.java | 52 ++++++++++++++----- .../de/dhbwstuttgart/core/JavaTXCompiler.java | 4 ++ .../StatementGenerator.java | 12 +++-- .../typeinference/typeAlgo/TYPEStmt.java | 19 ++++--- src/test/java/TestComplete.java | 13 +++++ 6 files changed, 80 insertions(+), 24 deletions(-) diff --git a/resources/bytecode/javFiles/Switch.jav b/resources/bytecode/javFiles/Switch.jav index 4166eb7e..95de49e1 100644 --- a/resources/bytecode/javFiles/Switch.jav +++ b/resources/bytecode/javFiles/Switch.jav @@ -5,10 +5,10 @@ import java.lang.Float; record Rec(Integer a, Object b) {} public class Switch { - main(Object o) { + Integer main(Object o) { return switch (o) { case Rec(Integer a, Integer b) -> { yield a + b; } - case Rec(Integer a, Float b) -> { yield a * b; } + case Rec(Integer a, Float b) -> { yield a + 10; } case Rec(Integer a, Rec(Integer b, Integer c)) -> { yield a + b + c; } case Integer i -> { yield i; } default -> { yield 0; } diff --git a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java index 25c1eca2..73b4ba78 100644 --- a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java +++ b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java @@ -28,7 +28,12 @@ public class Codegen { public Codegen(TargetStructure clazz, JavaTXCompiler compiler) { this.clazz = clazz; this.className = clazz.qualifiedName(); - this.cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS); + this.cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS) { + @Override + protected ClassLoader getClassLoader() { + return compiler.getClassLoader(); + } + }; this.compiler = compiler; } @@ -1084,7 +1089,7 @@ public class Codegen { private void generateEnhancedSwitch(State state, TargetSwitch aSwitch) { var mv = state.mv; generate(state, aSwitch.expr()); - var tmp = state.localCounter++; + var tmp = state.localCounter; mv.visitInsn(DUP); mv.visitVarInsn(ASTORE, tmp); @@ -1148,12 +1153,12 @@ public class Codegen { if (cse.labels().size() == 1) { var label = cse.labels().get(0); - if (label instanceof Guard gd){ + if (label instanceof Guard gd) { state.mv.visitVarInsn(ALOAD, tmp); - bindPattern(state, aSwitch.expr().type(), gd.inner(), start); + bindPattern(state, aSwitch.expr().type(), gd.inner(), start, i, 1); } else if (label instanceof TargetPattern pat) { state.mv.visitVarInsn(ALOAD, tmp); - bindPattern(state, aSwitch.expr().type(), pat, start); + bindPattern(state, aSwitch.expr().type(), pat, start, i, 1); } if (label instanceof Guard gd) { @@ -1194,30 +1199,51 @@ public class Codegen { state.exitScope(); } - private void bindPattern(State state, TargetType type, TargetPattern pat, Label start) { + private void bindPattern(State state, TargetType type, TargetPattern pat, Label start, int index, int depth) { + if (pat.type() instanceof TargetPrimitiveType) + boxPrimitive(state, pat.type()); + + state.mv.visitInsn(DUP); + state.mv.visitTypeInsn(INSTANCEOF, pat.type().getInternalName()); + + var cont = new Label(); + state.mv.visitJumpInsn(IFNE, cont); + for (var i = 0; i < depth; i++) { + state.mv.visitInsn(POP); + } + state.mv.visitVarInsn(ALOAD, state.switchResultValue.peek()); + state.mv.visitLdcInsn(index + 1); + state.mv.visitJumpInsn(GOTO, start); + state.mv.visitLabel(cont); + + state.mv.visitTypeInsn(CHECKCAST, pat.type().getInternalName()); + if (pat instanceof SimplePattern sp) { var local = state.createVariable(sp.name(), sp.type()); - convertTo(state, type, sp.type()); - boxPrimitive(state, sp.type()); state.mv.visitVarInsn(ASTORE, local.index); } else if (pat instanceof ComplexPattern cp) { - convertTo(state, type, cp.type()); - boxPrimitive(state, cp.type()); + if (cp.name() != null) { + state.mv.visitInsn(DUP); + var local = state.createVariable(cp.name(), cp.type()); + state.mv.visitVarInsn(ASTORE, local.index); + } var clazz = findClass(new JavaClassName(cp.type().name())); if (clazz == null) throw new CodeGenException("Class definition for '" + cp.type().name() + "' not found"); // TODO Check if class is a Record for (var i = 0; i < cp.subPatterns().size(); i++) { + state.mv.visitInsn(DUP); + var subPattern = cp.subPatterns().get(i); if (i >= clazz.getFieldDecl().size()) throw new CodeGenException("Couldn't find suitable field accessor for '" + cp.type().name() + "'"); var field = clazz.getFieldDecl().get(i); var fieldType = new TargetRefType(((RefType) field.getType()).getName().toString()); - state.mv.visitMethodInsn(INVOKEDYNAMIC, cp.type().getInternalName(), field.getName(), "()" + fieldType.toDescriptor(), false); - convertTo(state, fieldType, subPattern.type()); - bindPattern(state, subPattern.type(), subPattern, start); + state.mv.visitMethodInsn(INVOKEVIRTUAL, cp.type().getInternalName(), field.getName(), "()" + fieldType.toDescriptor(), false); + bindPattern(state, subPattern.type(), subPattern, start, index, depth + 1); } + state.mv.visitInsn(POP); } } diff --git a/src/main/java/de/dhbwstuttgart/core/JavaTXCompiler.java b/src/main/java/de/dhbwstuttgart/core/JavaTXCompiler.java index 192b8b66..a6b809c1 100644 --- a/src/main/java/de/dhbwstuttgart/core/JavaTXCompiler.java +++ b/src/main/java/de/dhbwstuttgart/core/JavaTXCompiler.java @@ -74,6 +74,10 @@ public class JavaTXCompiler { Boolean log = true; //gibt an ob ein Log-File nach System.getProperty("user.dir")+""/logFiles/"" geschrieben werden soll? public volatile UnifyTaskModel usedTasks = new UnifyTaskModel(); private final DirectoryClassLoader classLoader; + + public DirectoryClassLoader getClassLoader() { + return classLoader; + } public JavaTXCompiler(File sourceFile) throws IOException, ClassNotFoundException { this(Arrays.asList(sourceFile), null); diff --git a/src/main/java/de/dhbwstuttgart/parser/SyntaxTreeGenerator/StatementGenerator.java b/src/main/java/de/dhbwstuttgart/parser/SyntaxTreeGenerator/StatementGenerator.java index 4ee5121d..1da80097 100644 --- a/src/main/java/de/dhbwstuttgart/parser/SyntaxTreeGenerator/StatementGenerator.java +++ b/src/main/java/de/dhbwstuttgart/parser/SyntaxTreeGenerator/StatementGenerator.java @@ -355,7 +355,7 @@ public class StatementGenerator { for (SwitchBlockStatementGroupContext blockstmt : stmt.switchBlockStatementGroup()) { switchBlocks.add(convert(blockstmt)); } - return new Switch(switched, switchBlocks, switched.getType(), true, stmt.getStart()); + return new Switch(switched, switchBlocks, TypePlaceholder.fresh(switched.getOffset()), true, stmt.getStart()); } // Um switchExpressions als Statement zu behandeln @@ -474,7 +474,10 @@ public class StatementGenerator { switch (pPattern) { case TPatternContext tPattern: TypePatternContext typePattern = tPattern.typePattern(); - return new Pattern(typePattern.identifier().getText(), TypeGenerator.convert(typePattern.typeType(), reg, generics), typePattern.getStart()); + var text = typePattern.identifier().getText(); + var type = TypeGenerator.convert(typePattern.typeType(), reg, generics); + localVars.put(text, type); + return new Pattern(text, type, typePattern.getStart()); case RPatternContext rPattern: RecordPatternContext recordPattern = rPattern.recordPattern(); return convert(recordPattern); @@ -490,7 +493,10 @@ public class StatementGenerator { return (Pattern) convert(patternCtx); }).collect(Collectors.toList()); IdentifierContext identifierCtx = recordPatternCtx.identifier(); - return new RecordPattern(subPattern, (identifierCtx != null) ? identifierCtx.getText() : null, TypeGenerator.convert(recordPatternCtx.typeType(), reg, generics), recordPatternCtx.getStart()); + var text = (identifierCtx != null) ? identifierCtx.getText() : null; + var type = TypeGenerator.convert(recordPatternCtx.typeType(), reg, generics); + if (text != null) localVars.put(text, type); + return new RecordPattern(subPattern, text, type, recordPatternCtx.getStart()); } private Statement convert(Java17Parser.WhileloopContext stmt) { diff --git a/src/main/java/de/dhbwstuttgart/typeinference/typeAlgo/TYPEStmt.java b/src/main/java/de/dhbwstuttgart/typeinference/typeAlgo/TYPEStmt.java index e54a7f77..5c68974d 100644 --- a/src/main/java/de/dhbwstuttgart/typeinference/typeAlgo/TYPEStmt.java +++ b/src/main/java/de/dhbwstuttgart/typeinference/typeAlgo/TYPEStmt.java @@ -1,10 +1,7 @@ //PL 2018-12-19: Merge chekcen package de.dhbwstuttgart.typeinference.typeAlgo; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; import de.dhbwstuttgart.exceptions.NotImplementedException; @@ -78,6 +75,7 @@ public class TYPEStmt implements StatementVisitor { private final TypeInferenceBlockInformation info; private final ConstraintSet constraintsSet = new ConstraintSet(); + private final Stack switchStack = new Stack<>(); public TYPEStmt(TypeInferenceBlockInformation info) { this.info = info; @@ -735,12 +733,19 @@ public class TYPEStmt implements StatementVisitor { @Override public void visit(Switch switchStmt) { - // TODO Auto-generated method stub + switchStack.push(switchStmt); + for (var child : switchStmt.getBlocks()) { + child.accept(this); + constraintsSet.addUndConstraint(new Pair(child.getType(), switchStmt.getType(), PairOperator.SMALLERDOT)); + } + switchStack.pop(); } @Override public void visit(SwitchBlock switchBlock) { - // TODO Auto-generated method stub + for (var stmt : switchBlock.statements) { + stmt.accept(this); + } } @Override @@ -750,6 +755,8 @@ public class TYPEStmt implements StatementVisitor { @Override public void visit(Yield aYield) { + aYield.retexpr.accept(this); + constraintsSet.addUndConstraint(new Pair(aYield.getType(), switchStack.peek().getType(), PairOperator.SMALLERDOT)); // TODO Auto-generated method stub } diff --git a/src/test/java/TestComplete.java b/src/test/java/TestComplete.java index 9679efab..1d76bc68 100644 --- a/src/test/java/TestComplete.java +++ b/src/test/java/TestComplete.java @@ -656,5 +656,18 @@ public class TestComplete { var classFiles = generateClassFiles(new ByteArrayClassLoader(), "Switch.jav"); var clazz = classFiles.get("Switch"); var instance = clazz.getDeclaredConstructor().newInstance(); + var swtch = clazz.getDeclaredMethod("main", Object.class); + + var record = classFiles.get("Rec"); + var ctor = record.getDeclaredConstructor(Integer.class, Object.class); + var r1 = ctor.newInstance(10, 20); + var r2 = ctor.newInstance(10, 20f); + var r3 = ctor.newInstance(10, r1); + + assertEquals(swtch.invoke(instance, r1), 30); + assertEquals(swtch.invoke(instance, r2), 20); + assertEquals(swtch.invoke(instance, r3), 40); + assertEquals(swtch.invoke(instance, 50), 50); + assertEquals(swtch.invoke(instance, "Some string"), 0); } }