Allow the first patterns

This commit is contained in:
Daniel Holle 2023-07-28 12:04:14 +02:00
parent be55d661cb
commit 4f3164a48a
10 changed files with 390 additions and 214 deletions

View File

@ -717,10 +717,9 @@ switchLabelCase
| DEFAULT (ARROW | COLON) #labeledRuleDefault | DEFAULT (ARROW | COLON) #labeledRuleDefault
; ;
// Java17 // Java20
guardedPattern guardedPattern
: variableModifier* typeType annotation* identifier ('&&' expression)* : primaryPattern WITH expression
| guardedPattern '&&' expression
; ;
// Java17 // Java17

View File

@ -5,6 +5,7 @@ import de.dhbwstuttgart.syntaxtree.statement.Break;
import de.dhbwstuttgart.target.tree.*; import de.dhbwstuttgart.target.tree.*;
import de.dhbwstuttgart.target.tree.expression.*; import de.dhbwstuttgart.target.tree.expression.*;
import de.dhbwstuttgart.target.tree.type.*; import de.dhbwstuttgart.target.tree.type.*;
import org.antlr.v4.codegen.Target;
import org.objectweb.asm.*; import org.objectweb.asm.*;
import java.lang.invoke.CallSite; import java.lang.invoke.CallSite;
@ -71,6 +72,7 @@ public class Codegen {
TargetType returnType; TargetType returnType;
Stack<BreakEnv> breakStack = new Stack<>(); Stack<BreakEnv> breakStack = new Stack<>();
Stack<Integer> switchResultValue = new Stack<>();
State(TargetType returnType, MethodVisitor mv, int localCounter) { State(TargetType returnType, MethodVisitor mv, int localCounter) {
this.returnType = returnType; this.returnType = returnType;
@ -92,6 +94,13 @@ public class Codegen {
localCounter += 1; localCounter += 1;
return local; return local;
} }
void pushSwitch() {
switchResultValue.push(this.localCounter++);
}
void popSwitch() {
switchResultValue.pop();
}
} }
private void popValue(State state, TargetType type) { private void popValue(State state, TargetType type) {
@ -582,11 +591,10 @@ public class Codegen {
private void generateUnaryOp(State state, TargetUnaryOp op) { private void generateUnaryOp(State state, TargetUnaryOp op) {
var mv = state.mv; var mv = state.mv;
switch (op) { switch (op) {
case TargetUnaryOp.Add add: case TargetUnaryOp.Add add ->
// This literally does nothing // This literally does nothing
generate(state, add.expr()); generate(state, add.expr());
break; case TargetUnaryOp.Negate negate -> {
case TargetUnaryOp.Negate negate:
generate(state, negate.expr()); generate(state, negate.expr());
if (negate.type().equals(TargetType.Double)) if (negate.type().equals(TargetType.Double))
mv.visitInsn(DNEG); mv.visitInsn(DNEG);
@ -596,8 +604,8 @@ public class Codegen {
mv.visitInsn(LNEG); mv.visitInsn(LNEG);
else else
mv.visitInsn(INEG); mv.visitInsn(INEG);
break; }
case TargetUnaryOp.Not not: case TargetUnaryOp.Not not -> {
generate(state, not.expr()); generate(state, not.expr());
if (not.type().equals(TargetType.Long)) { if (not.type().equals(TargetType.Long)) {
mv.visitLdcInsn(-1L); mv.visitLdcInsn(-1L);
@ -606,8 +614,8 @@ public class Codegen {
mv.visitInsn(ICONST_M1); mv.visitInsn(ICONST_M1);
mv.visitInsn(IXOR); mv.visitInsn(IXOR);
} }
break; }
case TargetUnaryOp.PreIncrement preIncrement: case TargetUnaryOp.PreIncrement preIncrement -> {
generate(state, preIncrement.expr()); generate(state, preIncrement.expr());
if (preIncrement.type().equals(TargetType.Float)) { if (preIncrement.type().equals(TargetType.Float)) {
mv.visitLdcInsn(1F); mv.visitLdcInsn(1F);
@ -628,8 +636,8 @@ public class Codegen {
} }
boxPrimitive(state, preIncrement.type()); boxPrimitive(state, preIncrement.type());
afterIncDec(state, preIncrement); afterIncDec(state, preIncrement);
break; }
case TargetUnaryOp.PreDecrement preDecrement: case TargetUnaryOp.PreDecrement preDecrement -> {
generate(state, preDecrement.expr()); generate(state, preDecrement.expr());
if (preDecrement.type().equals(TargetType.Float)) { if (preDecrement.type().equals(TargetType.Float)) {
mv.visitLdcInsn(1F); mv.visitLdcInsn(1F);
@ -650,8 +658,8 @@ public class Codegen {
} }
boxPrimitive(state, preDecrement.type()); boxPrimitive(state, preDecrement.type());
afterIncDec(state, preDecrement); afterIncDec(state, preDecrement);
break; }
case TargetUnaryOp.PostIncrement postIncrement: case TargetUnaryOp.PostIncrement postIncrement -> {
generate(state, postIncrement.expr()); generate(state, postIncrement.expr());
if (postIncrement.type().equals(TargetType.Float)) { if (postIncrement.type().equals(TargetType.Float)) {
mv.visitInsn(DUP); mv.visitInsn(DUP);
@ -672,8 +680,8 @@ public class Codegen {
} }
boxPrimitive(state, postIncrement.type()); boxPrimitive(state, postIncrement.type());
afterIncDec(state, postIncrement); afterIncDec(state, postIncrement);
break; }
case TargetUnaryOp.PostDecrement postDecrement: case TargetUnaryOp.PostDecrement postDecrement -> {
generate(state, postDecrement.expr()); generate(state, postDecrement.expr());
if (postDecrement.type().equals(TargetType.Float)) { if (postDecrement.type().equals(TargetType.Float)) {
mv.visitInsn(DUP); mv.visitInsn(DUP);
@ -694,7 +702,7 @@ public class Codegen {
} }
boxPrimitive(state, postDecrement.type()); boxPrimitive(state, postDecrement.type());
afterIncDec(state, postDecrement); afterIncDec(state, postDecrement);
break; }
} }
} }
@ -773,31 +781,19 @@ public class Codegen {
break; break;
case TargetLiteral literal: case TargetLiteral literal:
switch (literal) { switch (literal) {
case IntLiteral intLiteral: case IntLiteral intLiteral -> mv.visitLdcInsn(intLiteral.value());
mv.visitLdcInsn(intLiteral.value()); case FloatLiteral floatLiteral -> mv.visitLdcInsn(floatLiteral.value());
break; case LongLiteral longLiteral -> mv.visitLdcInsn(longLiteral.value());
case FloatLiteral floatLiteral: case StringLiteral stringLiteral -> mv.visitLdcInsn(stringLiteral.value());
mv.visitLdcInsn(floatLiteral.value()); case CharLiteral charLiteral -> mv.visitIntInsn(BIPUSH, charLiteral.value());
break; case DoubleLiteral doubleLiteral -> mv.visitLdcInsn(doubleLiteral.value());
case LongLiteral longLiteral: case BooleanLiteral booleanLiteral -> {
mv.visitLdcInsn(longLiteral.value());
break;
case StringLiteral stringLiteral:
mv.visitLdcInsn(stringLiteral.value());
break;
case CharLiteral charLiteral:
mv.visitIntInsn(BIPUSH, charLiteral.value());
break;
case DoubleLiteral doubleLiteral:
mv.visitLdcInsn(doubleLiteral.value());
break;
case BooleanLiteral booleanLiteral:
if (booleanLiteral.value()) { if (booleanLiteral.value()) {
mv.visitInsn(ICONST_1); mv.visitInsn(ICONST_1);
} else { } else {
mv.visitInsn(ICONST_0); mv.visitInsn(ICONST_0);
} }
break; }
} }
break; break;
case TargetVarDecl varDecl: { case TargetVarDecl varDecl: {
@ -820,16 +816,15 @@ public class Codegen {
break; break;
case TargetAssign assign: { case TargetAssign assign: {
switch (assign.left()) { switch (assign.left()) {
case TargetLocalVar localVar: { case TargetLocalVar localVar -> {
generate(state, assign.right()); generate(state, assign.right());
convertTo(state, assign.right().type(), localVar.type()); convertTo(state, assign.right().type(), localVar.type());
boxPrimitive(state, localVar.type()); boxPrimitive(state, localVar.type());
var local = state.scope.get(localVar.name()); var local = state.scope.get(localVar.name());
mv.visitInsn(DUP); mv.visitInsn(DUP);
mv.visitVarInsn(ASTORE, local.index()); mv.visitVarInsn(ASTORE, local.index());
break;
} }
case TargetFieldVar dot: { case TargetFieldVar dot -> {
var fieldType = dot.type(); var fieldType = dot.type();
generate(state, dot.left()); generate(state, dot.left());
generate(state, assign.right()); generate(state, assign.right());
@ -840,10 +835,8 @@ public class Codegen {
else else
mv.visitInsn(DUP_X1); mv.visitInsn(DUP_X1);
mv.visitFieldInsn(dot.isStatic() ? PUTSTATIC : PUTFIELD, dot.owner().getInternalName(), dot.right(), fieldType.toSignature()); mv.visitFieldInsn(dot.isStatic() ? PUTSTATIC : PUTFIELD, dot.owner().getInternalName(), dot.right(), fieldType.toSignature());
break;
} }
default: default -> throw new CodeGenException("Invalid assignment");
throw new CodeGenException("Invalid assignment");
} }
break; break;
} }
@ -890,7 +883,6 @@ public class Codegen {
} }
mv.visitJumpInsn(GOTO, start); mv.visitJumpInsn(GOTO, start);
mv.visitLabel(end); mv.visitLabel(end);
mv.visitInsn(NOP);
state.exitScope(); state.exitScope();
state.localCounter = localCounter; state.localCounter = localCounter;
break; break;
@ -912,7 +904,6 @@ public class Codegen {
mv.visitJumpInsn(GOTO, start); mv.visitJumpInsn(GOTO, start);
mv.visitLabel(end); mv.visitLabel(end);
mv.visitInsn(NOP);
break; break;
} }
case TargetIf _if: { case TargetIf _if: {
@ -927,7 +918,6 @@ public class Codegen {
generate(state, _if.else_body()); generate(state, _if.else_body());
} }
mv.visitLabel(end); mv.visitLabel(end);
mv.visitInsn(NOP);
break; break;
} }
case TargetReturn ret: { case TargetReturn ret: {
@ -940,6 +930,16 @@ public class Codegen {
mv.visitInsn(RETURN); mv.visitInsn(RETURN);
break; break;
} }
case TargetYield yield: {
generate(state, yield.expression());
try {
yieldValue(state, yield.expression().type());
mv.visitJumpInsn(GOTO, state.breakStack.peek().endLabel);
} catch (EmptyStackException e) {
throw new CodeGenException("Yield outside of switch expression");
}
break;
}
case TargetSwitch _switch: { case TargetSwitch _switch: {
generateSwitch(state, _switch); generateSwitch(state, _switch);
break; break;
@ -1004,13 +1004,25 @@ public class Codegen {
} }
} }
private void yieldValue(State state, TargetType type) {
boxPrimitive(state, type);
state.mv.visitVarInsn(ASTORE, state.switchResultValue.peek());
}
private void generateClassicSwitch(State state, TargetSwitch aSwitch) { private void generateClassicSwitch(State state, TargetSwitch aSwitch) {
// TODO Constant expressions are allowed, we need to evaluate them somehow... // TODO Constant expressions are allowed, we need to evaluate them somehow...
// For now we just assume we get literals // For now we just assume we get literals...
// TODO This always uses a lookupswitch, a tableswitch may be faster in some cases but we can't generate that in all cases // TODO This always uses a lookupswitch, a tableswitch may be faster in some cases but we can't generate that every time
// TODO We can't switch on Strings yet, the idea for this (like javac does it) would be to implement the hash code at compile time
// and switch based on that, adding an equals check for every case and going to yet another tableswitch which finally decides which branch to take
var mv = state.mv; var mv = state.mv;
if (aSwitch.isExpression())
state.pushSwitch();
generate(state, aSwitch.expr()); generate(state, aSwitch.expr());
state.enterScope();
var keys = new int[aSwitch.cases().stream().mapToInt(c -> c.labels().size()).sum()]; var keys = new int[aSwitch.cases().stream().mapToInt(c -> c.labels().size()).sum()];
var labels = new Label[keys.length]; var labels = new Label[keys.length];
var bodyLabels = new Label[aSwitch.cases().size()]; var bodyLabels = new Label[aSwitch.cases().size()];
@ -1043,28 +1055,159 @@ public class Codegen {
for (var k = 0; k < aSwitch.cases().size(); k++) { for (var k = 0; k < aSwitch.cases().size(); k++) {
mv.visitLabel(bodyLabels[k]); mv.visitLabel(bodyLabels[k]);
generate(state, aSwitch.cases().get(k).body()); var cse = aSwitch.cases().get(k);
generate(state, cse.body());
if (cse.isSingleExpression() && aSwitch.isExpression())
yieldValue(state, cse.body().statements().get(0).type());
if (aSwitch.isExpression()) mv.visitJumpInsn(GOTO, end);
} }
if (aSwitch.default_() != null) { if (aSwitch.default_() != null) {
mv.visitLabel(defaultLabel); mv.visitLabel(defaultLabel);
generate(state, aSwitch.default_()); generate(state, aSwitch.default_().body());
if (aSwitch.default_().isSingleExpression() && aSwitch.isExpression())
yieldValue(state, aSwitch.default_().body().statements().get(0).type());
} }
mv.visitLabel(end); mv.visitLabel(end);
mv.visitInsn(NOP);
state.breakStack.pop(); state.breakStack.pop();
if (aSwitch.isExpression()) {
mv.visitVarInsn(ALOAD, state.switchResultValue.peek());
unboxPrimitive(state, aSwitch.type());
state.popSwitch();
}
state.exitScope();
} }
private void generateEnhancedSwitch(State state, TargetSwitch aSwitch) { private void generateEnhancedSwitch(State state, TargetSwitch aSwitch) {
var mv = state.mv;
generate(state, aSwitch.expr());
var tmp = state.localCounter++;
mv.visitInsn(DUP);
mv.visitVarInsn(ASTORE, tmp);
state.enterScope();
// This is the index to start the switch from
mv.visitInsn(ICONST_0);
if (aSwitch.isExpression())
state.pushSwitch();
// To be able to skip ahead to the next case
var start = new Label();
mv.visitLabel(start);
var end = new Label();
var env = new BreakEnv();
env.endLabel = end;
state.breakStack.push(env);
var mt = MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, Object[].class);
var bootstrap = new Handle(H_INVOKESTATIC, "java/lang/runtime/SwitchBootstraps", "typeSwitch", mt.toMethodDescriptorString(), false);
var types = new Object[aSwitch.cases().size()];
for (var i = 0; i < types.length; i++) {
var cse = aSwitch.cases().get(i);
var label = cse.labels().get(0);
if (label instanceof TargetSwitch.SimplePattern || label instanceof TargetSwitch.ComplexPattern)
types[i] = Type.getObjectType(label.type().getInternalName());
else if (label instanceof TargetLiteral lit)
types[i] = lit.value();
else if (label instanceof TargetSwitch.Guard guard)
types[i] = Type.getObjectType(guard.inner().type().getInternalName());
// TODO Same here we need to evaluate constants
else throw new NotImplementedException();
}
mv.visitInvokeDynamicInsn("typeSwitch", "(Ljava/lang/Object;I)I", bootstrap, types);
var caseLabels = new Label[aSwitch.cases().size()];
var labels = new Label[aSwitch.cases().stream().mapToInt(c -> c.labels().size()).sum()];
var j = 0;
for (var i = 0; i < caseLabels.length; i++) {
var cse = aSwitch.cases().get(i);
var label = new Label();
caseLabels[i] = label;
for (var k = 0; k < cse.labels().size(); k++) {
labels[j] = label;
j += 1;
}
}
var defaultLabel = end;
if (aSwitch.default_() != null) {
defaultLabel = new Label();
}
mv.visitTableSwitchInsn(0, labels.length - 1, defaultLabel, labels);
for (var i = 0; i < aSwitch.cases().size(); i++) {
mv.visitLabel(caseLabels[i]);
var cse = aSwitch.cases().get(i);
if (cse.labels().size() == 1) {
var label = cse.labels().get(0);
if (label instanceof TargetSwitch.Guard gd)
bindLabel(state, tmp, aSwitch.expr().type(), gd.inner());
else if (label instanceof TargetSwitch.Pattern pat)
bindLabel(state, tmp, aSwitch.expr().type(), pat);
if (label instanceof TargetSwitch.Guard gd) {
generate(state, gd.expression());
var next = new Label();
mv.visitJumpInsn(IFNE, next);
mv.visitVarInsn(ALOAD, tmp);
// Push the offset onto the stack (this is used by the invokedynamic call)
mv.visitLdcInsn(i + 1);
mv.visitJumpInsn(GOTO, start);
mv.visitLabel(next);
}
}
generate(state, cse.body());
if (cse.isSingleExpression() && aSwitch.isExpression())
yieldValue(state, cse.body().statements().get(0).type());
if (aSwitch.isExpression()) mv.visitJumpInsn(GOTO, end);
}
if (aSwitch.default_() != null) {
mv.visitLabel(defaultLabel);
generate(state, aSwitch.default_().body());
if (aSwitch.default_().isSingleExpression() && aSwitch.isExpression())
yieldValue(state, aSwitch.default_().body().statements().get(0).type());
}
mv.visitLabel(end);
//mv.visitInsn(POP);
state.breakStack.pop();
if (aSwitch.isExpression()) {
mv.visitVarInsn(ALOAD, state.switchResultValue.peek());
unboxPrimitive(state, aSwitch.type());
state.popSwitch();
}
state.exitScope();
}
private void bindLabel(State state, int tmp, TargetType type, TargetSwitch.Pattern pat) {
if (pat instanceof TargetSwitch.SimplePattern sp) {
state.mv.visitVarInsn(ALOAD, tmp);
var local = state.createVariable(sp.name(), sp.type());
convertTo(state, type, local.type);
boxPrimitive(state, local.type);
state.mv.visitVarInsn(ASTORE, local.index);
}
} }
final Set<TargetType> wrapperTypes = Set.of(TargetType.Long, TargetType.Integer, TargetType.Byte, TargetType.Char, TargetType.Boolean, TargetType.Double, TargetType.Float); final Set<TargetType> wrapperTypes = Set.of(TargetType.Long, TargetType.Integer, TargetType.Byte, TargetType.Char, TargetType.Boolean, TargetType.Double, TargetType.Float);
private void generateSwitch(State state, TargetSwitch aSwitch) { private void generateSwitch(State state, TargetSwitch aSwitch) {
if (!wrapperTypes.contains(aSwitch.expr().type())) if (!wrapperTypes.contains(aSwitch.expr().type())) {
generateEnhancedSwitch(state, aSwitch); generateEnhancedSwitch(state, aSwitch);
return;
}
else for (var case_ : aSwitch.cases()) { else for (var case_ : aSwitch.cases()) {
if (case_.labels().stream().anyMatch(c -> c instanceof TargetSwitch.Pattern)) { if (case_.labels().stream().anyMatch(c -> c instanceof TargetSwitch.Pattern)) {
generateEnhancedSwitch(state, aSwitch); generateEnhancedSwitch(state, aSwitch);

View File

@ -20,7 +20,6 @@ import de.dhbwstuttgart.parser.antlr.Java17Parser.AssignexpressionContext;
import de.dhbwstuttgart.parser.antlr.Java17Parser.BitwiseandexpressionContext; import de.dhbwstuttgart.parser.antlr.Java17Parser.BitwiseandexpressionContext;
import de.dhbwstuttgart.parser.antlr.Java17Parser.BitwiseorexpressionContext; import de.dhbwstuttgart.parser.antlr.Java17Parser.BitwiseorexpressionContext;
import de.dhbwstuttgart.parser.antlr.Java17Parser.BitwisexorexpressionContext; import de.dhbwstuttgart.parser.antlr.Java17Parser.BitwisexorexpressionContext;
import de.dhbwstuttgart.parser.antlr.Java17Parser.BlockStatementContext;
import de.dhbwstuttgart.parser.antlr.Java17Parser.BlockstmtContext; import de.dhbwstuttgart.parser.antlr.Java17Parser.BlockstmtContext;
import de.dhbwstuttgart.parser.antlr.Java17Parser.BoolLiteralContext; import de.dhbwstuttgart.parser.antlr.Java17Parser.BoolLiteralContext;
import de.dhbwstuttgart.parser.antlr.Java17Parser.BreakstmtContext; import de.dhbwstuttgart.parser.antlr.Java17Parser.BreakstmtContext;
@ -381,7 +380,7 @@ public class StatementGenerator {
return new Switch(switched, switchBlocks, TypePlaceholder.fresh(offset), false, offset); return new Switch(switched, switchBlocks, TypePlaceholder.fresh(offset), false, offset);
} }
private SwitchBlock convert(Java17Parser.SwitchLabeledRuleContext labeledRule) { private SwitchBlock convert(SwitchLabeledRuleContext labeledRule) {
Boolean isDefault = false; Boolean isDefault = false;
List<SwitchLabel> labels = switch (labeledRule.switchLabelCase()) { List<SwitchLabel> labels = switch (labeledRule.switchLabelCase()) {
case LabeledRuleExprListContext exprList -> { case LabeledRuleExprListContext exprList -> {
@ -401,11 +400,14 @@ public class StatementGenerator {
} }
default -> throw new NotImplementedException(); default -> throw new NotImplementedException();
}; };
var isSingleExpression = false;
Token offset = labeledRule.getStart(); Token offset = labeledRule.getStart();
SwitchRuleOutcomeContext outcome = labeledRule.switchRuleOutcome(); SwitchRuleOutcomeContext outcome = labeledRule.switchRuleOutcome();
Block block; Block block;
if (Objects.isNull(outcome.block())) { if (Objects.isNull(outcome.block())) {
List<Statement> stmts = new ArrayList<>(); List<Statement> stmts = new ArrayList<>();
if (outcome.blockStatement().size() == 1) isSingleExpression = true;
outcome.blockStatement().stream().forEach((stmt) -> { outcome.blockStatement().stream().forEach((stmt) -> {
stmts.addAll(convert(stmt)); stmts.addAll(convert(stmt));
}); });
@ -414,7 +416,7 @@ public class StatementGenerator {
} else { } else {
block = convert(outcome.block(), false); block = convert(outcome.block(), false);
} }
return new SwitchBlock(labels, block, isDefault, offset); return new SwitchBlock(labels, block, isDefault, isSingleExpression, offset);
} }
private Statement convert(Java17Parser.YieldstmtContext yieldstmt) { private Statement convert(Java17Parser.YieldstmtContext yieldstmt) {
@ -430,7 +432,7 @@ public class StatementGenerator {
stmt.blockStatement().stream().forEach((blockStmt) -> { stmt.blockStatement().stream().forEach((blockStmt) -> {
block.addAll(convert(blockStmt)); block.addAll(convert(blockStmt));
}); });
return new SwitchBlock(labels, new Block(block, stmt.blockStatement(0).getStart()), stmt.getStart()); return new SwitchBlock(labels, new Block(block, stmt.blockStatement(0).getStart()), false, stmt.getStart());
} }
private SwitchLabel convert(SwitchLabelContext switchLabel) { private SwitchLabel convert(SwitchLabelContext switchLabel) {
@ -454,17 +456,15 @@ public class StatementGenerator {
} }
} }
private Pattern convert(PatternContext pattern) { private Expression convert(PatternContext pattern) {
return switch (pattern) { return switch (pattern) {
case PPatternContext pPattern -> { case PPatternContext pPattern -> {
yield convert(pPattern.primaryPattern()); yield convert(pPattern.primaryPattern());
} }
case GPatternContext gPattern -> { case GPatternContext gPattern -> {
GuardedPatternContext guarded = gPattern.guardedPattern(); GuardedPatternContext guarded = gPattern.guardedPattern();
List<Expression> conditions = guarded.expression().stream().map((expr) -> { Expression condition = convert(guarded.expression());
return convert(expr); yield new GuardedPattern(condition, convert(guarded.primaryPattern()), guarded.getStart());
}).toList();
yield new GuardedPattern(conditions, guarded.identifier().getText(), TypeGenerator.convert(guarded.typeType(), reg, generics), guarded.getStart());
} }
default -> throw new NotImplementedException(); default -> throw new NotImplementedException();
}; };
@ -487,7 +487,7 @@ public class StatementGenerator {
private RecordPattern convert(RecordPatternContext recordPatternCtx) { private RecordPattern convert(RecordPatternContext recordPatternCtx) {
List<PatternContext> subPatternCtx = recordPatternCtx.recordStructurePattern().recordComponentPatternList().pattern(); List<PatternContext> subPatternCtx = recordPatternCtx.recordStructurePattern().recordComponentPatternList().pattern();
List<Pattern> subPattern = subPatternCtx.stream().map((patternCtx) -> { List<Pattern> subPattern = subPatternCtx.stream().map((patternCtx) -> {
return convert(patternCtx); return (Pattern) convert(patternCtx);
}).collect(Collectors.toList()); }).collect(Collectors.toList());
IdentifierContext identifierCtx = recordPatternCtx.identifier(); IdentifierContext identifierCtx = recordPatternCtx.identifier();
return new RecordPattern(subPattern, (identifierCtx != null) ? identifierCtx.getText() : null, TypeGenerator.convert(recordPatternCtx.typeType(), reg, generics), recordPatternCtx.getStart()); return new RecordPattern(subPattern, (identifierCtx != null) ? identifierCtx.getText() : null, TypeGenerator.convert(recordPatternCtx.typeType(), reg, generics), recordPatternCtx.getStart());

View File

@ -2,21 +2,32 @@ package de.dhbwstuttgart.syntaxtree.statement;
import java.util.List; import java.util.List;
import de.dhbwstuttgart.syntaxtree.StatementVisitor;
import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.Token;
import de.dhbwstuttgart.syntaxtree.type.RefTypeOrTPHOrWildcardOrGeneric; import de.dhbwstuttgart.syntaxtree.type.RefTypeOrTPHOrWildcardOrGeneric;
public class GuardedPattern extends Pattern { public class GuardedPattern extends Expression {
private List<Expression> conditions; private final Expression condition;
private final Pattern nested;
public GuardedPattern(List<Expression> conditions, String name, RefTypeOrTPHOrWildcardOrGeneric type, Token offset) { public GuardedPattern(Expression condition, Pattern nested, Token offset) {
super(name, type, offset); super(nested.getType(), offset);
this.conditions = conditions; this.condition = condition;
this.nested = nested;
} }
public List<Expression> getConditions() { public Expression getCondition() {
return conditions; return condition;
} }
public Pattern getNestedPattern() {
return nested;
}
@Override
public void accept(StatementVisitor visitor) {
visitor.visit(this);
}
} }

View File

@ -12,16 +12,19 @@ public class SwitchBlock extends Block {
private List<SwitchLabel> labels = new ArrayList<>(); private List<SwitchLabel> labels = new ArrayList<>();
private boolean defaultBlock = false; private boolean defaultBlock = false;
public final boolean isExpression; // This is for single expressions that yield a value
public SwitchBlock(List<SwitchLabel> labels, Block statements, Token offset) { public SwitchBlock(List<SwitchLabel> labels, Block statements, boolean isExpression, Token offset) {
super(statements.getStatements(), offset); super(statements.getStatements(), offset);
this.labels = labels; this.labels = labels;
this.isExpression = isExpression;
} }
public SwitchBlock(List<SwitchLabel> labels, Block statements, boolean isDefault, Token offset) { public SwitchBlock(List<SwitchLabel> labels, Block statements, boolean isDefault, boolean isExpression, Token offset) {
super(statements.getStatements(), offset); super(statements.getStatements(), offset);
this.labels = labels; this.labels = labels;
this.defaultBlock = isDefault; this.defaultBlock = isDefault;
this.isExpression = isExpression;
} }
public boolean isDefault() { public boolean isDefault() {

View File

@ -474,11 +474,8 @@ public class OutputGenerator implements ASTVisitor {
@Override @Override
public void visit(GuardedPattern aGuardedPattern) { public void visit(GuardedPattern aGuardedPattern) {
aGuardedPattern.getType().accept(this); aGuardedPattern.getNestedPattern().accept(this);
out.append(aGuardedPattern.getName()); out.append(" with ");
for (Expression cond : aGuardedPattern.getConditions()) { aGuardedPattern.getCondition().accept(this);
out.append("&&");
cond.accept(this);
}
} }
} }

View File

@ -208,7 +208,7 @@ public class ASTToTargetAST {
} }
protected TargetSwitch.Case convert(SwitchBlock block) { protected TargetSwitch.Case convert(SwitchBlock block) {
return new TargetSwitch.Case(block.getLabels().stream().map(this::convert).toList(), convert((Block) block)); return new TargetSwitch.Case(block.getLabels().stream().map(this::convert).toList(), convert((Block) block), block.isExpression);
} }
protected TargetBlock convert(Block block) { protected TargetBlock convert(Block block) {

View File

@ -344,7 +344,13 @@ public class StatementToTargetExpression implements StatementVisitor {
@Override @Override
public void visit(Switch switchStmt) { public void visit(Switch switchStmt) {
var cases = switchStmt.getBlocks().stream().filter(s -> !s.isDefault()).map(converter::convert).toList(); var cases = switchStmt.getBlocks().stream().filter(s -> !s.isDefault()).map(converter::convert).toList();
var default_ = switchStmt.getBlocks().stream().filter(SwitchBlock::isDefault).map(s -> converter.convert((Block) s)).findFirst().orElse(null);
TargetSwitch.Case default_ = null;
for (var block : switchStmt.getBlocks()) {
if (block.isDefault()) {
default_ = new TargetSwitch.Case(converter.convert((Block) block), block.isExpression);
}
}
result = new TargetSwitch(converter.convert(switchStmt.getSwitch()), cases, default_ , converter.convert(switchStmt.getType()), !switchStmt.getStatement()); result = new TargetSwitch(converter.convert(switchStmt.getSwitch()), cases, default_ , converter.convert(switchStmt.getType()), !switchStmt.getStatement());
} }
@ -374,8 +380,6 @@ public class StatementToTargetExpression implements StatementVisitor {
@Override @Override
public void visit(GuardedPattern aGuardedPattern) { public void visit(GuardedPattern aGuardedPattern) {
//FIXME This isn't done properly inside the parser, really you should only have one guard (Chaining them together with && just yields another expression) result = new TargetSwitch.Guard((TargetSwitch.Pattern) converter.convert(aGuardedPattern.getNestedPattern()), converter.convert(aGuardedPattern.getCondition()));
//And then it also needs to be able to accept complex patterns. Because of this we only accept one condition for now.
result = new TargetSwitch.Guard(new TargetSwitch.SimplePattern(converter.convert(aGuardedPattern.getType()), aGuardedPattern.getName()), converter.convert(aGuardedPattern.getConditions().get(0)));
} }
} }

View File

@ -4,26 +4,33 @@ import de.dhbwstuttgart.target.tree.type.TargetType;
import java.util.List; import java.util.List;
public record TargetSwitch(TargetExpression expr, List<Case> cases, TargetBlock default_, TargetType type, boolean isExpression) implements TargetExpression { public record TargetSwitch(TargetExpression expr, List<Case> cases, Case default_, TargetType type, boolean isExpression) implements TargetExpression {
public TargetSwitch(TargetExpression expr, List<Case> cases, TargetBlock default_) { public TargetSwitch(TargetExpression expr, List<Case> cases, Case default_) {
this(expr, cases, default_, null, false); this(expr, cases, default_, null, false);
} }
public TargetSwitch(TargetExpression expr, List<Case> cases, TargetBlock default_, TargetType type) { public TargetSwitch(TargetExpression expr, List<Case> cases, Case default_, TargetType type) {
this(expr, cases, default_, type, true); this(expr, cases, default_, type, true);
} }
public TargetSwitch(TargetExpression expr, List<Case> cases, TargetBlock default_, boolean isExpression) { public TargetSwitch(TargetExpression expr, List<Case> cases, Case default_, boolean isExpression) {
this(expr, cases, default_, null, isExpression); this(expr, cases, default_, null, isExpression);
} }
public record Case(List<TargetExpression> labels, TargetBlock body) {} public record Case(List<TargetExpression> labels, TargetBlock body, boolean isSingleExpression) {
public Case(List<TargetExpression> labels, TargetBlock body) {
this(labels, body, false);
}
public Case(TargetBlock body, boolean isSingleExpression) {
this(List.of(), body, isSingleExpression);
}
}
public sealed interface Pattern extends TargetExpression {} public sealed interface Pattern extends TargetExpression {}
public record SimplePattern(TargetType type, String name) implements Pattern {} public record SimplePattern(TargetType type, String name) implements Pattern {}
public record ComplexPattern(TargetType type, List<Pattern> subPatterns) implements Pattern {} public record ComplexPattern(TargetType type, List<Pattern> subPatterns) implements Pattern {}
public record Guard(TargetExpression inner, TargetExpression expression) implements Pattern {} public record Guard(Pattern inner, TargetExpression expression) implements Pattern {}
} }

View File

@ -200,46 +200,58 @@ public class TestCodegen {
@Test @Test
public void testClassicSwitch() throws Exception { public void testClassicSwitch() throws Exception {
var targetClass = new TargetClass(Opcodes.ACC_PUBLIC , "Switch"); var targetClass = new TargetClass(Opcodes.ACC_PUBLIC , "SwitchClassic");
targetClass.addMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC, "switchClassic", List.of(new MethodParameter(TargetType.Integer, "i")), TargetType.Integer, new TargetBlock(List.of( targetClass.addMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC, "switchClassic", List.of(new MethodParameter(TargetType.Integer, "i")), TargetType.Integer, new TargetBlock(List.of(
new TargetVarDecl(TargetType.Integer, "res", null), new TargetVarDecl(TargetType.Integer, "res", null),
new TargetSwitch(new TargetLocalVar(TargetType.Integer, "i"), List.of( new TargetSwitch(new TargetLocalVar(TargetType.Integer, "i"), List.of(
new TargetSwitch.Case(List.of(new TargetLiteral.IntLiteral(10)), new TargetBlock( new TargetSwitch.Case(List.of(new TargetLiteral.IntLiteral(10)), new TargetBlock(
List.of(new TargetAssign(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "res"), new TargetLiteral.IntLiteral(0)), new TargetBreak()) List.of(new TargetAssign(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "res"), new TargetLiteral.IntLiteral(0)), new TargetBreak())
)), )),
new TargetSwitch.Case(List.of(new TargetLiteral.IntLiteral(20)), new TargetBlock(List.of())), new TargetSwitch.Case(List.of(new TargetLiteral.IntLiteral(15), new TargetLiteral.IntLiteral(20)), new TargetBlock(List.of())),
new TargetSwitch.Case(List.of(new TargetLiteral.IntLiteral(30)), new TargetBlock( new TargetSwitch.Case(List.of(new TargetLiteral.IntLiteral(30)), new TargetBlock(
List.of(new TargetAssign(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "res"), new TargetLiteral.IntLiteral(1)), new TargetBreak()) List.of(new TargetAssign(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "res"), new TargetLiteral.IntLiteral(1)), new TargetBreak())
)) ))
), new TargetBlock( ), new TargetSwitch.Case(new TargetBlock(
List.of(new TargetAssign(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "res"), new TargetLiteral.IntLiteral(2)), new TargetBreak()) List.of(new TargetAssign(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "res"), new TargetLiteral.IntLiteral(2)), new TargetBreak())
)), ), false)),
new TargetReturn(new TargetLocalVar(TargetType.Integer, "res")) new TargetReturn(new TargetLocalVar(TargetType.Integer, "res"))
))); )));
var clazz = generateClass(targetClass, new ByteArrayClassLoader()); var clazz = generateClass(targetClass, new ByteArrayClassLoader());
var m = clazz.getDeclaredMethod("switchClassic", Integer.class);
assertEquals(m.invoke(null, 10), 0);
assertEquals(m.invoke(null, 15), 1);
assertEquals(m.invoke(null, 20), 1);
assertEquals(m.invoke(null, 30), 1);
assertEquals(m.invoke(null, 99), 2);
} }
@Test @Test
public void testTypeSwitch() throws Exception { public void testTypeSwitch() throws Exception {
var targetClass = new TargetClass(Opcodes.ACC_PUBLIC, "Switch"); var targetClass = new TargetClass(Opcodes.ACC_PUBLIC, "SwitchEnhanced");
targetClass.addMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC, "switchType", List.of(new MethodParameter(TargetType.Object, "obj")), TargetType.Integer, new TargetBlock(List.of( targetClass.addMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC, "switchType", List.of(new MethodParameter(TargetType.Object, "obj")), TargetType.Integer, new TargetBlock(List.of(
new TargetReturn(new TargetSwitch(new TargetLocalVar(TargetType.Object, "obj"), List.of( new TargetReturn(new TargetSwitch(new TargetLocalVar(TargetType.Object, "obj"), List.of(
new TargetSwitch.Case(List.of(new TargetSwitch.SimplePattern(TargetType.String, "aString")), new TargetBlock( new TargetSwitch.Case(List.of(new TargetSwitch.SimplePattern(TargetType.String, "aString")), new TargetBlock(
List.of(new TargetLiteral.IntLiteral(0)) List.of(new TargetYield(new TargetLiteral.IntLiteral(0)))
)), ), false),
new TargetSwitch.Case(List.of(
new TargetSwitch.Guard(new TargetSwitch.SimplePattern(TargetType.Integer, "i"), new TargetBinaryOp.Less(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "i"), new TargetLiteral.IntLiteral(10)))
), new TargetBlock(
List.of(new TargetLiteral.IntLiteral(3))
), true),
new TargetSwitch.Case(List.of(new TargetSwitch.SimplePattern(TargetType.Integer, "anInteger")), new TargetBlock( new TargetSwitch.Case(List.of(new TargetSwitch.SimplePattern(TargetType.Integer, "anInteger")), new TargetBlock(
List.of(new TargetLiteral.IntLiteral(1)) List.of(new TargetLiteral.IntLiteral(1))
)) ), true)
), new TargetBlock( ), new TargetSwitch.Case(new TargetBlock(
List.of(new TargetLiteral.IntLiteral(2)) List.of(new TargetLiteral.IntLiteral(2))
), TargetType.Integer) ), true), TargetType.Integer)
)))); ))));
var clazz = generateClass(targetClass, new ByteArrayClassLoader()); var clazz = generateClass(targetClass, new ByteArrayClassLoader());
var m = clazz.getDeclaredMethod("switchType", Object.class); var m = clazz.getDeclaredMethod("switchType", Object.class);
assertEquals(m.invoke(null, "String"), 0); assertEquals(m.invoke(null, "String"), 0);
assertEquals(m.invoke(null, 10), 1); assertEquals(m.invoke(null, 10), 1);
assertEquals(m.invoke(null, 'A'), 2); assertEquals(m.invoke(null, 'A'), 2);
assertEquals(m.invoke(null, 5), 3);
} }
@Test @Test