From 4f3164a48a209064311a69e7d2e7fcca4fddcbfd Mon Sep 17 00:00:00 2001 From: Daniel Holle Date: Fri, 28 Jul 2023 12:04:14 +0200 Subject: [PATCH] Allow the first patterns --- .../parser/antlr/Java17Parser.g4 | 5 +- .../de/dhbwstuttgart/bytecode/Codegen.java | 471 ++++++++++++------ .../StatementGenerator.java | 20 +- .../syntaxtree/statement/GuardedPattern.java | 25 +- .../syntaxtree/statement/SwitchBlock.java | 7 +- .../syntaxtree/visual/OutputGenerator.java | 9 +- .../target/generate/ASTToTargetAST.java | 2 +- .../generate/StatementToTargetExpression.java | 14 +- .../target/tree/expression/TargetSwitch.java | 19 +- src/test/java/targetast/TestCodegen.java | 32 +- 10 files changed, 390 insertions(+), 214 deletions(-) diff --git a/src/main/antlr4/de/dhbwstuttgart/parser/antlr/Java17Parser.g4 b/src/main/antlr4/de/dhbwstuttgart/parser/antlr/Java17Parser.g4 index cd041fd4..a23f17f8 100644 --- a/src/main/antlr4/de/dhbwstuttgart/parser/antlr/Java17Parser.g4 +++ b/src/main/antlr4/de/dhbwstuttgart/parser/antlr/Java17Parser.g4 @@ -717,10 +717,9 @@ switchLabelCase | DEFAULT (ARROW | COLON) #labeledRuleDefault ; -// Java17 +// Java20 guardedPattern - : variableModifier* typeType annotation* identifier ('&&' expression)* - | guardedPattern '&&' expression + : primaryPattern WITH expression ; // Java17 diff --git a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java index b43f0481..bcbe0573 100644 --- a/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java +++ b/src/main/java/de/dhbwstuttgart/bytecode/Codegen.java @@ -5,6 +5,7 @@ import de.dhbwstuttgart.syntaxtree.statement.Break; import de.dhbwstuttgart.target.tree.*; import de.dhbwstuttgart.target.tree.expression.*; import de.dhbwstuttgart.target.tree.type.*; +import org.antlr.v4.codegen.Target; import org.objectweb.asm.*; import java.lang.invoke.CallSite; @@ -71,6 +72,7 @@ public class Codegen { TargetType returnType; Stack breakStack = new Stack<>(); + Stack switchResultValue = new Stack<>(); State(TargetType returnType, MethodVisitor mv, int localCounter) { this.returnType = returnType; @@ -92,6 +94,13 @@ public class Codegen { localCounter += 1; return local; } + + void pushSwitch() { + switchResultValue.push(this.localCounter++); + } + void popSwitch() { + switchResultValue.pop(); + } } private void popValue(State state, TargetType type) { @@ -582,119 +591,118 @@ public class Codegen { private void generateUnaryOp(State state, TargetUnaryOp op) { var mv = state.mv; switch (op) { - case TargetUnaryOp.Add add: - // This literally does nothing - generate(state, add.expr()); - break; - case TargetUnaryOp.Negate negate: - generate(state, negate.expr()); - if (negate.type().equals(TargetType.Double)) - mv.visitInsn(DNEG); - else if (negate.type().equals(TargetType.Float)) - mv.visitInsn(FNEG); - else if (negate.type().equals(TargetType.Long)) - mv.visitInsn(LNEG); - else - mv.visitInsn(INEG); - break; - case TargetUnaryOp.Not not: - generate(state, not.expr()); - if (not.type().equals(TargetType.Long)) { - mv.visitLdcInsn(-1L); - mv.visitInsn(LXOR); - } else { - mv.visitInsn(ICONST_M1); - mv.visitInsn(IXOR); + case TargetUnaryOp.Add add -> + // This literally does nothing + generate(state, add.expr()); + case TargetUnaryOp.Negate negate -> { + generate(state, negate.expr()); + if (negate.type().equals(TargetType.Double)) + mv.visitInsn(DNEG); + else if (negate.type().equals(TargetType.Float)) + mv.visitInsn(FNEG); + else if (negate.type().equals(TargetType.Long)) + mv.visitInsn(LNEG); + else + mv.visitInsn(INEG); } - break; - case TargetUnaryOp.PreIncrement preIncrement: - generate(state, preIncrement.expr()); - if (preIncrement.type().equals(TargetType.Float)) { - mv.visitLdcInsn(1F); - mv.visitInsn(FADD); - mv.visitInsn(DUP); - } else if (preIncrement.type().equals(TargetType.Double)) { - mv.visitLdcInsn(1D); - mv.visitInsn(DADD); - mv.visitInsn(DUP2); - } else if (preIncrement.type().equals(TargetType.Long)) { - mv.visitLdcInsn(1L); - mv.visitInsn(LADD); - mv.visitInsn(DUP2); - } else { - mv.visitLdcInsn(1); - mv.visitInsn(IADD); - mv.visitInsn(DUP); + case TargetUnaryOp.Not not -> { + generate(state, not.expr()); + if (not.type().equals(TargetType.Long)) { + mv.visitLdcInsn(-1L); + mv.visitInsn(LXOR); + } else { + mv.visitInsn(ICONST_M1); + mv.visitInsn(IXOR); + } } - boxPrimitive(state, preIncrement.type()); - afterIncDec(state, preIncrement); - break; - case TargetUnaryOp.PreDecrement preDecrement: - generate(state, preDecrement.expr()); - if (preDecrement.type().equals(TargetType.Float)) { - mv.visitLdcInsn(1F); - mv.visitInsn(FSUB); - mv.visitInsn(DUP); - } else if (preDecrement.type().equals(TargetType.Double)) { - mv.visitLdcInsn(1D); - mv.visitInsn(DSUB); - mv.visitInsn(DUP2); - } else if (preDecrement.type().equals(TargetType.Long)) { - mv.visitLdcInsn(1L); - mv.visitInsn(LSUB); - mv.visitInsn(DUP2); - } else { - mv.visitLdcInsn(1); - mv.visitInsn(ISUB); - mv.visitInsn(DUP); + case TargetUnaryOp.PreIncrement preIncrement -> { + generate(state, preIncrement.expr()); + if (preIncrement.type().equals(TargetType.Float)) { + mv.visitLdcInsn(1F); + mv.visitInsn(FADD); + mv.visitInsn(DUP); + } else if (preIncrement.type().equals(TargetType.Double)) { + mv.visitLdcInsn(1D); + mv.visitInsn(DADD); + mv.visitInsn(DUP2); + } else if (preIncrement.type().equals(TargetType.Long)) { + mv.visitLdcInsn(1L); + mv.visitInsn(LADD); + mv.visitInsn(DUP2); + } else { + mv.visitLdcInsn(1); + mv.visitInsn(IADD); + mv.visitInsn(DUP); + } + boxPrimitive(state, preIncrement.type()); + afterIncDec(state, preIncrement); } - boxPrimitive(state, preDecrement.type()); - afterIncDec(state, preDecrement); - break; - case TargetUnaryOp.PostIncrement postIncrement: - generate(state, postIncrement.expr()); - if (postIncrement.type().equals(TargetType.Float)) { - mv.visitInsn(DUP); - mv.visitLdcInsn(1F); - mv.visitInsn(FADD); - } else if (postIncrement.type().equals(TargetType.Double)) { - mv.visitInsn(DUP2); - mv.visitLdcInsn(1D); - mv.visitInsn(DADD); - } else if (postIncrement.type().equals(TargetType.Long)) { - mv.visitInsn(DUP2); - mv.visitLdcInsn(1L); - mv.visitInsn(LADD); - } else { - mv.visitInsn(DUP); - mv.visitLdcInsn(1); - mv.visitInsn(IADD); + case TargetUnaryOp.PreDecrement preDecrement -> { + generate(state, preDecrement.expr()); + if (preDecrement.type().equals(TargetType.Float)) { + mv.visitLdcInsn(1F); + mv.visitInsn(FSUB); + mv.visitInsn(DUP); + } else if (preDecrement.type().equals(TargetType.Double)) { + mv.visitLdcInsn(1D); + mv.visitInsn(DSUB); + mv.visitInsn(DUP2); + } else if (preDecrement.type().equals(TargetType.Long)) { + mv.visitLdcInsn(1L); + mv.visitInsn(LSUB); + mv.visitInsn(DUP2); + } else { + mv.visitLdcInsn(1); + mv.visitInsn(ISUB); + mv.visitInsn(DUP); + } + boxPrimitive(state, preDecrement.type()); + afterIncDec(state, preDecrement); } - boxPrimitive(state, postIncrement.type()); - afterIncDec(state, postIncrement); - break; - case TargetUnaryOp.PostDecrement postDecrement: - generate(state, postDecrement.expr()); - if (postDecrement.type().equals(TargetType.Float)) { - mv.visitInsn(DUP); - mv.visitLdcInsn(1F); - mv.visitInsn(FSUB); - } else if (postDecrement.type().equals(TargetType.Double)) { - mv.visitInsn(DUP2); - mv.visitLdcInsn(1D); - mv.visitInsn(DSUB); - } else if (postDecrement.type().equals(TargetType.Long)) { - mv.visitInsn(DUP2); - mv.visitLdcInsn(1L); - mv.visitInsn(LSUB); - } else { - mv.visitInsn(DUP); - mv.visitLdcInsn(1); - mv.visitInsn(ISUB); + case TargetUnaryOp.PostIncrement postIncrement -> { + generate(state, postIncrement.expr()); + if (postIncrement.type().equals(TargetType.Float)) { + mv.visitInsn(DUP); + mv.visitLdcInsn(1F); + mv.visitInsn(FADD); + } else if (postIncrement.type().equals(TargetType.Double)) { + mv.visitInsn(DUP2); + mv.visitLdcInsn(1D); + mv.visitInsn(DADD); + } else if (postIncrement.type().equals(TargetType.Long)) { + mv.visitInsn(DUP2); + mv.visitLdcInsn(1L); + mv.visitInsn(LADD); + } else { + mv.visitInsn(DUP); + mv.visitLdcInsn(1); + mv.visitInsn(IADD); + } + boxPrimitive(state, postIncrement.type()); + afterIncDec(state, postIncrement); + } + case TargetUnaryOp.PostDecrement postDecrement -> { + generate(state, postDecrement.expr()); + if (postDecrement.type().equals(TargetType.Float)) { + mv.visitInsn(DUP); + mv.visitLdcInsn(1F); + mv.visitInsn(FSUB); + } else if (postDecrement.type().equals(TargetType.Double)) { + mv.visitInsn(DUP2); + mv.visitLdcInsn(1D); + mv.visitInsn(DSUB); + } else if (postDecrement.type().equals(TargetType.Long)) { + mv.visitInsn(DUP2); + mv.visitLdcInsn(1L); + mv.visitInsn(LSUB); + } else { + mv.visitInsn(DUP); + mv.visitLdcInsn(1); + mv.visitInsn(ISUB); + } + boxPrimitive(state, postDecrement.type()); + afterIncDec(state, postDecrement); } - boxPrimitive(state, postDecrement.type()); - afterIncDec(state, postDecrement); - break; } } @@ -773,31 +781,19 @@ public class Codegen { break; case TargetLiteral literal: switch (literal) { - case IntLiteral intLiteral: - mv.visitLdcInsn(intLiteral.value()); - break; - case FloatLiteral floatLiteral: - mv.visitLdcInsn(floatLiteral.value()); - break; - case LongLiteral longLiteral: - mv.visitLdcInsn(longLiteral.value()); - break; - case StringLiteral stringLiteral: - mv.visitLdcInsn(stringLiteral.value()); - break; - case CharLiteral charLiteral: - mv.visitIntInsn(BIPUSH, charLiteral.value()); - break; - case DoubleLiteral doubleLiteral: - mv.visitLdcInsn(doubleLiteral.value()); - break; - case BooleanLiteral booleanLiteral: - if (booleanLiteral.value()) { - mv.visitInsn(ICONST_1); - } else { - mv.visitInsn(ICONST_0); + case IntLiteral intLiteral -> mv.visitLdcInsn(intLiteral.value()); + case FloatLiteral floatLiteral -> mv.visitLdcInsn(floatLiteral.value()); + case LongLiteral longLiteral -> mv.visitLdcInsn(longLiteral.value()); + case StringLiteral stringLiteral -> mv.visitLdcInsn(stringLiteral.value()); + case CharLiteral charLiteral -> mv.visitIntInsn(BIPUSH, charLiteral.value()); + case DoubleLiteral doubleLiteral -> mv.visitLdcInsn(doubleLiteral.value()); + case BooleanLiteral booleanLiteral -> { + if (booleanLiteral.value()) { + mv.visitInsn(ICONST_1); + } else { + mv.visitInsn(ICONST_0); + } } - break; } break; case TargetVarDecl varDecl: { @@ -820,30 +816,27 @@ public class Codegen { break; case TargetAssign assign: { switch (assign.left()) { - case TargetLocalVar localVar: { - generate(state, assign.right()); - convertTo(state, assign.right().type(), localVar.type()); - boxPrimitive(state, localVar.type()); - var local = state.scope.get(localVar.name()); - mv.visitInsn(DUP); - mv.visitVarInsn(ASTORE, local.index()); - break; - } - case TargetFieldVar dot: { - var fieldType = dot.type(); - generate(state, dot.left()); - generate(state, assign.right()); - convertTo(state, assign.right().type(), fieldType); - boxPrimitive(state, fieldType); - if (dot.isStatic()) + case TargetLocalVar localVar -> { + generate(state, assign.right()); + convertTo(state, assign.right().type(), localVar.type()); + boxPrimitive(state, localVar.type()); + var local = state.scope.get(localVar.name()); mv.visitInsn(DUP); - else - mv.visitInsn(DUP_X1); - mv.visitFieldInsn(dot.isStatic() ? PUTSTATIC : PUTFIELD, dot.owner().getInternalName(), dot.right(), fieldType.toSignature()); - break; - } - default: - throw new CodeGenException("Invalid assignment"); + mv.visitVarInsn(ASTORE, local.index()); + } + case TargetFieldVar dot -> { + var fieldType = dot.type(); + generate(state, dot.left()); + generate(state, assign.right()); + convertTo(state, assign.right().type(), fieldType); + boxPrimitive(state, fieldType); + if (dot.isStatic()) + mv.visitInsn(DUP); + else + mv.visitInsn(DUP_X1); + mv.visitFieldInsn(dot.isStatic() ? PUTSTATIC : PUTFIELD, dot.owner().getInternalName(), dot.right(), fieldType.toSignature()); + } + default -> throw new CodeGenException("Invalid assignment"); } break; } @@ -890,7 +883,6 @@ public class Codegen { } mv.visitJumpInsn(GOTO, start); mv.visitLabel(end); - mv.visitInsn(NOP); state.exitScope(); state.localCounter = localCounter; break; @@ -912,7 +904,6 @@ public class Codegen { mv.visitJumpInsn(GOTO, start); mv.visitLabel(end); - mv.visitInsn(NOP); break; } case TargetIf _if: { @@ -927,7 +918,6 @@ public class Codegen { generate(state, _if.else_body()); } mv.visitLabel(end); - mv.visitInsn(NOP); break; } case TargetReturn ret: { @@ -940,6 +930,16 @@ public class Codegen { mv.visitInsn(RETURN); break; } + case TargetYield yield: { + generate(state, yield.expression()); + try { + yieldValue(state, yield.expression().type()); + mv.visitJumpInsn(GOTO, state.breakStack.peek().endLabel); + } catch (EmptyStackException e) { + throw new CodeGenException("Yield outside of switch expression"); + } + break; + } case TargetSwitch _switch: { generateSwitch(state, _switch); break; @@ -1004,13 +1004,25 @@ public class Codegen { } } + private void yieldValue(State state, TargetType type) { + boxPrimitive(state, type); + state.mv.visitVarInsn(ASTORE, state.switchResultValue.peek()); + } + private void generateClassicSwitch(State state, TargetSwitch aSwitch) { // TODO Constant expressions are allowed, we need to evaluate them somehow... - // For now we just assume we get literals - // TODO This always uses a lookupswitch, a tableswitch may be faster in some cases but we can't generate that in all cases + // For now we just assume we get literals... + // TODO This always uses a lookupswitch, a tableswitch may be faster in some cases but we can't generate that every time + // TODO We can't switch on Strings yet, the idea for this (like javac does it) would be to implement the hash code at compile time + // and switch based on that, adding an equals check for every case and going to yet another tableswitch which finally decides which branch to take var mv = state.mv; + if (aSwitch.isExpression()) + state.pushSwitch(); generate(state, aSwitch.expr()); + + state.enterScope(); + var keys = new int[aSwitch.cases().stream().mapToInt(c -> c.labels().size()).sum()]; var labels = new Label[keys.length]; var bodyLabels = new Label[aSwitch.cases().size()]; @@ -1043,28 +1055,159 @@ public class Codegen { for (var k = 0; k < aSwitch.cases().size(); k++) { mv.visitLabel(bodyLabels[k]); - generate(state, aSwitch.cases().get(k).body()); + var cse = aSwitch.cases().get(k); + generate(state, cse.body()); + if (cse.isSingleExpression() && aSwitch.isExpression()) + yieldValue(state, cse.body().statements().get(0).type()); + if (aSwitch.isExpression()) mv.visitJumpInsn(GOTO, end); } if (aSwitch.default_() != null) { mv.visitLabel(defaultLabel); - generate(state, aSwitch.default_()); + generate(state, aSwitch.default_().body()); + if (aSwitch.default_().isSingleExpression() && aSwitch.isExpression()) + yieldValue(state, aSwitch.default_().body().statements().get(0).type()); } mv.visitLabel(end); - mv.visitInsn(NOP); state.breakStack.pop(); + + if (aSwitch.isExpression()) { + mv.visitVarInsn(ALOAD, state.switchResultValue.peek()); + unboxPrimitive(state, aSwitch.type()); + state.popSwitch(); + } + + state.exitScope(); } private void generateEnhancedSwitch(State state, TargetSwitch aSwitch) { + var mv = state.mv; + generate(state, aSwitch.expr()); + var tmp = state.localCounter++; + mv.visitInsn(DUP); + mv.visitVarInsn(ASTORE, tmp); + state.enterScope(); + // This is the index to start the switch from + mv.visitInsn(ICONST_0); + if (aSwitch.isExpression()) + state.pushSwitch(); + + // To be able to skip ahead to the next case + var start = new Label(); + mv.visitLabel(start); + + var end = new Label(); + var env = new BreakEnv(); + env.endLabel = end; + state.breakStack.push(env); + + var mt = MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, Object[].class); + var bootstrap = new Handle(H_INVOKESTATIC, "java/lang/runtime/SwitchBootstraps", "typeSwitch", mt.toMethodDescriptorString(), false); + + var types = new Object[aSwitch.cases().size()]; + for (var i = 0; i < types.length; i++) { + var cse = aSwitch.cases().get(i); + var label = cse.labels().get(0); + if (label instanceof TargetSwitch.SimplePattern || label instanceof TargetSwitch.ComplexPattern) + types[i] = Type.getObjectType(label.type().getInternalName()); + else if (label instanceof TargetLiteral lit) + types[i] = lit.value(); + else if (label instanceof TargetSwitch.Guard guard) + types[i] = Type.getObjectType(guard.inner().type().getInternalName()); + // TODO Same here we need to evaluate constants + else throw new NotImplementedException(); + } + + mv.visitInvokeDynamicInsn("typeSwitch", "(Ljava/lang/Object;I)I", bootstrap, types); + + var caseLabels = new Label[aSwitch.cases().size()]; + var labels = new Label[aSwitch.cases().stream().mapToInt(c -> c.labels().size()).sum()]; + var j = 0; + for (var i = 0; i < caseLabels.length; i++) { + var cse = aSwitch.cases().get(i); + var label = new Label(); + caseLabels[i] = label; + for (var k = 0; k < cse.labels().size(); k++) { + labels[j] = label; + j += 1; + } + } + + var defaultLabel = end; + if (aSwitch.default_() != null) { + defaultLabel = new Label(); + } + + mv.visitTableSwitchInsn(0, labels.length - 1, defaultLabel, labels); + + for (var i = 0; i < aSwitch.cases().size(); i++) { + mv.visitLabel(caseLabels[i]); + var cse = aSwitch.cases().get(i); + + if (cse.labels().size() == 1) { + var label = cse.labels().get(0); + if (label instanceof TargetSwitch.Guard gd) + bindLabel(state, tmp, aSwitch.expr().type(), gd.inner()); + else if (label instanceof TargetSwitch.Pattern pat) + bindLabel(state, tmp, aSwitch.expr().type(), pat); + + if (label instanceof TargetSwitch.Guard gd) { + generate(state, gd.expression()); + var next = new Label(); + mv.visitJumpInsn(IFNE, next); + mv.visitVarInsn(ALOAD, tmp); + // Push the offset onto the stack (this is used by the invokedynamic call) + mv.visitLdcInsn(i + 1); + mv.visitJumpInsn(GOTO, start); + mv.visitLabel(next); + } + } + + generate(state, cse.body()); + if (cse.isSingleExpression() && aSwitch.isExpression()) + yieldValue(state, cse.body().statements().get(0).type()); + if (aSwitch.isExpression()) mv.visitJumpInsn(GOTO, end); + } + + if (aSwitch.default_() != null) { + mv.visitLabel(defaultLabel); + generate(state, aSwitch.default_().body()); + if (aSwitch.default_().isSingleExpression() && aSwitch.isExpression()) + yieldValue(state, aSwitch.default_().body().statements().get(0).type()); + } + + mv.visitLabel(end); + //mv.visitInsn(POP); + + state.breakStack.pop(); + if (aSwitch.isExpression()) { + mv.visitVarInsn(ALOAD, state.switchResultValue.peek()); + unboxPrimitive(state, aSwitch.type()); + state.popSwitch(); + } + + state.exitScope(); + } + + private void bindLabel(State state, int tmp, TargetType type, TargetSwitch.Pattern pat) { + if (pat instanceof TargetSwitch.SimplePattern sp) { + state.mv.visitVarInsn(ALOAD, tmp); + var local = state.createVariable(sp.name(), sp.type()); + convertTo(state, type, local.type); + boxPrimitive(state, local.type); + state.mv.visitVarInsn(ASTORE, local.index); + } } final Set wrapperTypes = Set.of(TargetType.Long, TargetType.Integer, TargetType.Byte, TargetType.Char, TargetType.Boolean, TargetType.Double, TargetType.Float); private void generateSwitch(State state, TargetSwitch aSwitch) { - if (!wrapperTypes.contains(aSwitch.expr().type())) + if (!wrapperTypes.contains(aSwitch.expr().type())) { generateEnhancedSwitch(state, aSwitch); + return; + } else for (var case_ : aSwitch.cases()) { if (case_.labels().stream().anyMatch(c -> c instanceof TargetSwitch.Pattern)) { generateEnhancedSwitch(state, aSwitch); diff --git a/src/main/java/de/dhbwstuttgart/parser/SyntaxTreeGenerator/StatementGenerator.java b/src/main/java/de/dhbwstuttgart/parser/SyntaxTreeGenerator/StatementGenerator.java index 8fb199cf..862ecc17 100644 --- a/src/main/java/de/dhbwstuttgart/parser/SyntaxTreeGenerator/StatementGenerator.java +++ b/src/main/java/de/dhbwstuttgart/parser/SyntaxTreeGenerator/StatementGenerator.java @@ -20,7 +20,6 @@ import de.dhbwstuttgart.parser.antlr.Java17Parser.AssignexpressionContext; import de.dhbwstuttgart.parser.antlr.Java17Parser.BitwiseandexpressionContext; import de.dhbwstuttgart.parser.antlr.Java17Parser.BitwiseorexpressionContext; import de.dhbwstuttgart.parser.antlr.Java17Parser.BitwisexorexpressionContext; -import de.dhbwstuttgart.parser.antlr.Java17Parser.BlockStatementContext; import de.dhbwstuttgart.parser.antlr.Java17Parser.BlockstmtContext; import de.dhbwstuttgart.parser.antlr.Java17Parser.BoolLiteralContext; import de.dhbwstuttgart.parser.antlr.Java17Parser.BreakstmtContext; @@ -381,7 +380,7 @@ public class StatementGenerator { return new Switch(switched, switchBlocks, TypePlaceholder.fresh(offset), false, offset); } - private SwitchBlock convert(Java17Parser.SwitchLabeledRuleContext labeledRule) { + private SwitchBlock convert(SwitchLabeledRuleContext labeledRule) { Boolean isDefault = false; List labels = switch (labeledRule.switchLabelCase()) { case LabeledRuleExprListContext exprList -> { @@ -401,11 +400,14 @@ public class StatementGenerator { } default -> throw new NotImplementedException(); }; + + var isSingleExpression = false; Token offset = labeledRule.getStart(); SwitchRuleOutcomeContext outcome = labeledRule.switchRuleOutcome(); Block block; if (Objects.isNull(outcome.block())) { List stmts = new ArrayList<>(); + if (outcome.blockStatement().size() == 1) isSingleExpression = true; outcome.blockStatement().stream().forEach((stmt) -> { stmts.addAll(convert(stmt)); }); @@ -414,7 +416,7 @@ public class StatementGenerator { } else { block = convert(outcome.block(), false); } - return new SwitchBlock(labels, block, isDefault, offset); + return new SwitchBlock(labels, block, isDefault, isSingleExpression, offset); } private Statement convert(Java17Parser.YieldstmtContext yieldstmt) { @@ -430,7 +432,7 @@ public class StatementGenerator { stmt.blockStatement().stream().forEach((blockStmt) -> { block.addAll(convert(blockStmt)); }); - return new SwitchBlock(labels, new Block(block, stmt.blockStatement(0).getStart()), stmt.getStart()); + return new SwitchBlock(labels, new Block(block, stmt.blockStatement(0).getStart()), false, stmt.getStart()); } private SwitchLabel convert(SwitchLabelContext switchLabel) { @@ -454,17 +456,15 @@ public class StatementGenerator { } } - private Pattern convert(PatternContext pattern) { + private Expression convert(PatternContext pattern) { return switch (pattern) { case PPatternContext pPattern -> { yield convert(pPattern.primaryPattern()); } case GPatternContext gPattern -> { GuardedPatternContext guarded = gPattern.guardedPattern(); - List conditions = guarded.expression().stream().map((expr) -> { - return convert(expr); - }).toList(); - yield new GuardedPattern(conditions, guarded.identifier().getText(), TypeGenerator.convert(guarded.typeType(), reg, generics), guarded.getStart()); + Expression condition = convert(guarded.expression()); + yield new GuardedPattern(condition, convert(guarded.primaryPattern()), guarded.getStart()); } default -> throw new NotImplementedException(); }; @@ -487,7 +487,7 @@ public class StatementGenerator { private RecordPattern convert(RecordPatternContext recordPatternCtx) { List subPatternCtx = recordPatternCtx.recordStructurePattern().recordComponentPatternList().pattern(); List subPattern = subPatternCtx.stream().map((patternCtx) -> { - return convert(patternCtx); + return (Pattern) convert(patternCtx); }).collect(Collectors.toList()); IdentifierContext identifierCtx = recordPatternCtx.identifier(); return new RecordPattern(subPattern, (identifierCtx != null) ? identifierCtx.getText() : null, TypeGenerator.convert(recordPatternCtx.typeType(), reg, generics), recordPatternCtx.getStart()); diff --git a/src/main/java/de/dhbwstuttgart/syntaxtree/statement/GuardedPattern.java b/src/main/java/de/dhbwstuttgart/syntaxtree/statement/GuardedPattern.java index dcfd7665..9e7bb485 100644 --- a/src/main/java/de/dhbwstuttgart/syntaxtree/statement/GuardedPattern.java +++ b/src/main/java/de/dhbwstuttgart/syntaxtree/statement/GuardedPattern.java @@ -2,21 +2,32 @@ package de.dhbwstuttgart.syntaxtree.statement; import java.util.List; +import de.dhbwstuttgart.syntaxtree.StatementVisitor; import org.antlr.v4.runtime.Token; import de.dhbwstuttgart.syntaxtree.type.RefTypeOrTPHOrWildcardOrGeneric; -public class GuardedPattern extends Pattern { +public class GuardedPattern extends Expression { - private List conditions; + private final Expression condition; + private final Pattern nested; - public GuardedPattern(List conditions, String name, RefTypeOrTPHOrWildcardOrGeneric type, Token offset) { - super(name, type, offset); - this.conditions = conditions; + public GuardedPattern(Expression condition, Pattern nested, Token offset) { + super(nested.getType(), offset); + this.condition = condition; + this.nested = nested; } - public List getConditions() { - return conditions; + public Expression getCondition() { + return condition; } + public Pattern getNestedPattern() { + return nested; + } + + @Override + public void accept(StatementVisitor visitor) { + visitor.visit(this); + } } diff --git a/src/main/java/de/dhbwstuttgart/syntaxtree/statement/SwitchBlock.java b/src/main/java/de/dhbwstuttgart/syntaxtree/statement/SwitchBlock.java index afe87e89..56a18eb2 100644 --- a/src/main/java/de/dhbwstuttgart/syntaxtree/statement/SwitchBlock.java +++ b/src/main/java/de/dhbwstuttgart/syntaxtree/statement/SwitchBlock.java @@ -12,16 +12,19 @@ public class SwitchBlock extends Block { private List labels = new ArrayList<>(); private boolean defaultBlock = false; + public final boolean isExpression; // This is for single expressions that yield a value - public SwitchBlock(List labels, Block statements, Token offset) { + public SwitchBlock(List labels, Block statements, boolean isExpression, Token offset) { super(statements.getStatements(), offset); this.labels = labels; + this.isExpression = isExpression; } - public SwitchBlock(List labels, Block statements, boolean isDefault, Token offset) { + public SwitchBlock(List labels, Block statements, boolean isDefault, boolean isExpression, Token offset) { super(statements.getStatements(), offset); this.labels = labels; this.defaultBlock = isDefault; + this.isExpression = isExpression; } public boolean isDefault() { diff --git a/src/main/java/de/dhbwstuttgart/syntaxtree/visual/OutputGenerator.java b/src/main/java/de/dhbwstuttgart/syntaxtree/visual/OutputGenerator.java index 219c8172..b914ecab 100644 --- a/src/main/java/de/dhbwstuttgart/syntaxtree/visual/OutputGenerator.java +++ b/src/main/java/de/dhbwstuttgart/syntaxtree/visual/OutputGenerator.java @@ -474,11 +474,8 @@ public class OutputGenerator implements ASTVisitor { @Override public void visit(GuardedPattern aGuardedPattern) { - aGuardedPattern.getType().accept(this); - out.append(aGuardedPattern.getName()); - for (Expression cond : aGuardedPattern.getConditions()) { - out.append("&&"); - cond.accept(this); - } + aGuardedPattern.getNestedPattern().accept(this); + out.append(" with "); + aGuardedPattern.getCondition().accept(this); } } \ No newline at end of file diff --git a/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java b/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java index d479c186..d65b5f10 100644 --- a/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java +++ b/src/main/java/de/dhbwstuttgart/target/generate/ASTToTargetAST.java @@ -208,7 +208,7 @@ public class ASTToTargetAST { } protected TargetSwitch.Case convert(SwitchBlock block) { - return new TargetSwitch.Case(block.getLabels().stream().map(this::convert).toList(), convert((Block) block)); + return new TargetSwitch.Case(block.getLabels().stream().map(this::convert).toList(), convert((Block) block), block.isExpression); } protected TargetBlock convert(Block block) { diff --git a/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java b/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java index cdbbb06f..5cad36bd 100644 --- a/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java +++ b/src/main/java/de/dhbwstuttgart/target/generate/StatementToTargetExpression.java @@ -344,8 +344,14 @@ public class StatementToTargetExpression implements StatementVisitor { @Override public void visit(Switch switchStmt) { var cases = switchStmt.getBlocks().stream().filter(s -> !s.isDefault()).map(converter::convert).toList(); - var default_ = switchStmt.getBlocks().stream().filter(SwitchBlock::isDefault).map(s -> converter.convert((Block) s)).findFirst().orElse(null); - result = new TargetSwitch(converter.convert(switchStmt.getSwitch()), cases, default_, converter.convert(switchStmt.getType()), !switchStmt.getStatement()); + + TargetSwitch.Case default_ = null; + for (var block : switchStmt.getBlocks()) { + if (block.isDefault()) { + default_ = new TargetSwitch.Case(converter.convert((Block) block), block.isExpression); + } + } + result = new TargetSwitch(converter.convert(switchStmt.getSwitch()), cases, default_ , converter.convert(switchStmt.getType()), !switchStmt.getStatement()); } @Override @@ -374,8 +380,6 @@ public class StatementToTargetExpression implements StatementVisitor { @Override public void visit(GuardedPattern aGuardedPattern) { - //FIXME This isn't done properly inside the parser, really you should only have one guard (Chaining them together with && just yields another expression) - //And then it also needs to be able to accept complex patterns. Because of this we only accept one condition for now. - result = new TargetSwitch.Guard(new TargetSwitch.SimplePattern(converter.convert(aGuardedPattern.getType()), aGuardedPattern.getName()), converter.convert(aGuardedPattern.getConditions().get(0))); + result = new TargetSwitch.Guard((TargetSwitch.Pattern) converter.convert(aGuardedPattern.getNestedPattern()), converter.convert(aGuardedPattern.getCondition())); } } diff --git a/src/main/java/de/dhbwstuttgart/target/tree/expression/TargetSwitch.java b/src/main/java/de/dhbwstuttgart/target/tree/expression/TargetSwitch.java index 3a2315e1..44a3c394 100644 --- a/src/main/java/de/dhbwstuttgart/target/tree/expression/TargetSwitch.java +++ b/src/main/java/de/dhbwstuttgart/target/tree/expression/TargetSwitch.java @@ -4,26 +4,33 @@ import de.dhbwstuttgart.target.tree.type.TargetType; import java.util.List; -public record TargetSwitch(TargetExpression expr, List cases, TargetBlock default_, TargetType type, boolean isExpression) implements TargetExpression { +public record TargetSwitch(TargetExpression expr, List cases, Case default_, TargetType type, boolean isExpression) implements TargetExpression { - public TargetSwitch(TargetExpression expr, List cases, TargetBlock default_) { + public TargetSwitch(TargetExpression expr, List cases, Case default_) { this(expr, cases, default_, null, false); } - public TargetSwitch(TargetExpression expr, List cases, TargetBlock default_, TargetType type) { + public TargetSwitch(TargetExpression expr, List cases, Case default_, TargetType type) { this(expr, cases, default_, type, true); } - public TargetSwitch(TargetExpression expr, List cases, TargetBlock default_, boolean isExpression) { + public TargetSwitch(TargetExpression expr, List cases, Case default_, boolean isExpression) { this(expr, cases, default_, null, isExpression); } - public record Case(List labels, TargetBlock body) {} + public record Case(List labels, TargetBlock body, boolean isSingleExpression) { + public Case(List labels, TargetBlock body) { + this(labels, body, false); + } + public Case(TargetBlock body, boolean isSingleExpression) { + this(List.of(), body, isSingleExpression); + } + } public sealed interface Pattern extends TargetExpression {} public record SimplePattern(TargetType type, String name) implements Pattern {} public record ComplexPattern(TargetType type, List subPatterns) implements Pattern {} - public record Guard(TargetExpression inner, TargetExpression expression) implements Pattern {} + public record Guard(Pattern inner, TargetExpression expression) implements Pattern {} } diff --git a/src/test/java/targetast/TestCodegen.java b/src/test/java/targetast/TestCodegen.java index 668d5e39..532335db 100644 --- a/src/test/java/targetast/TestCodegen.java +++ b/src/test/java/targetast/TestCodegen.java @@ -200,46 +200,58 @@ public class TestCodegen { @Test public void testClassicSwitch() throws Exception { - var targetClass = new TargetClass(Opcodes.ACC_PUBLIC , "Switch"); + var targetClass = new TargetClass(Opcodes.ACC_PUBLIC , "SwitchClassic"); targetClass.addMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC, "switchClassic", List.of(new MethodParameter(TargetType.Integer, "i")), TargetType.Integer, new TargetBlock(List.of( new TargetVarDecl(TargetType.Integer, "res", null), new TargetSwitch(new TargetLocalVar(TargetType.Integer, "i"), List.of( new TargetSwitch.Case(List.of(new TargetLiteral.IntLiteral(10)), new TargetBlock( List.of(new TargetAssign(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "res"), new TargetLiteral.IntLiteral(0)), new TargetBreak()) )), - new TargetSwitch.Case(List.of(new TargetLiteral.IntLiteral(20)), new TargetBlock(List.of())), + new TargetSwitch.Case(List.of(new TargetLiteral.IntLiteral(15), new TargetLiteral.IntLiteral(20)), new TargetBlock(List.of())), new TargetSwitch.Case(List.of(new TargetLiteral.IntLiteral(30)), new TargetBlock( List.of(new TargetAssign(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "res"), new TargetLiteral.IntLiteral(1)), new TargetBreak()) )) - ), new TargetBlock( + ), new TargetSwitch.Case(new TargetBlock( List.of(new TargetAssign(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "res"), new TargetLiteral.IntLiteral(2)), new TargetBreak()) - )), + ), false)), new TargetReturn(new TargetLocalVar(TargetType.Integer, "res")) ))); var clazz = generateClass(targetClass, new ByteArrayClassLoader()); + var m = clazz.getDeclaredMethod("switchClassic", Integer.class); + assertEquals(m.invoke(null, 10), 0); + assertEquals(m.invoke(null, 15), 1); + assertEquals(m.invoke(null, 20), 1); + assertEquals(m.invoke(null, 30), 1); + assertEquals(m.invoke(null, 99), 2); } @Test public void testTypeSwitch() throws Exception { - var targetClass = new TargetClass(Opcodes.ACC_PUBLIC, "Switch"); + var targetClass = new TargetClass(Opcodes.ACC_PUBLIC, "SwitchEnhanced"); targetClass.addMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC, "switchType", List.of(new MethodParameter(TargetType.Object, "obj")), TargetType.Integer, new TargetBlock(List.of( new TargetReturn(new TargetSwitch(new TargetLocalVar(TargetType.Object, "obj"), List.of( new TargetSwitch.Case(List.of(new TargetSwitch.SimplePattern(TargetType.String, "aString")), new TargetBlock( - List.of(new TargetLiteral.IntLiteral(0)) - )), + List.of(new TargetYield(new TargetLiteral.IntLiteral(0))) + ), false), + new TargetSwitch.Case(List.of( + new TargetSwitch.Guard(new TargetSwitch.SimplePattern(TargetType.Integer, "i"), new TargetBinaryOp.Less(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "i"), new TargetLiteral.IntLiteral(10))) + ), new TargetBlock( + List.of(new TargetLiteral.IntLiteral(3)) + ), true), new TargetSwitch.Case(List.of(new TargetSwitch.SimplePattern(TargetType.Integer, "anInteger")), new TargetBlock( List.of(new TargetLiteral.IntLiteral(1)) - )) - ), new TargetBlock( + ), true) + ), new TargetSwitch.Case(new TargetBlock( List.of(new TargetLiteral.IntLiteral(2)) - ), TargetType.Integer) + ), true), TargetType.Integer) )))); var clazz = generateClass(targetClass, new ByteArrayClassLoader()); var m = clazz.getDeclaredMethod("switchType", Object.class); assertEquals(m.invoke(null, "String"), 0); assertEquals(m.invoke(null, 10), 1); assertEquals(m.invoke(null, 'A'), 2); + assertEquals(m.invoke(null, 5), 3); } @Test