Add classic switch

This commit is contained in:
Daniel Holle 2023-07-27 10:02:28 +02:00
parent 3de9fde672
commit be55d661cb
6 changed files with 132 additions and 11 deletions

View File

@ -1,6 +1,7 @@
package de.dhbwstuttgart.bytecode;
import de.dhbwstuttgart.exceptions.NotImplementedException;
import de.dhbwstuttgart.syntaxtree.statement.Break;
import de.dhbwstuttgart.target.tree.*;
import de.dhbwstuttgart.target.tree.expression.*;
import de.dhbwstuttgart.target.tree.type.*;
@ -10,6 +11,7 @@ import java.lang.invoke.CallSite;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.sql.Array;
import java.util.*;
import static org.objectweb.asm.Opcodes.*;
@ -56,12 +58,20 @@ public class Codegen {
}
}
private static class BreakEnv {
String labelName; // TODO This is for labeled statements (Not implemented)
Label startLabel;
Label endLabel;
}
private static class State {
Scope scope = new Scope(null);
int localCounter;
MethodVisitor mv;
TargetType returnType;
Stack<BreakEnv> breakStack = new Stack<>();
State(TargetType returnType, MethodVisitor mv, int localCounter) {
this.returnType = returnType;
this.mv = mv;
@ -863,7 +873,15 @@ public class Codegen {
else
mv.visitInsn(ICONST_1);
mv.visitJumpInsn(IFEQ, end);
var env = new BreakEnv();
env.startLabel = start;
env.endLabel = end;
state.breakStack.push(env);
generate(state, _for.body());
state.breakStack.pop();
if (_for.increment() != null) {
generate(state, _for.increment());
if (_for.increment().type() != null) {
@ -872,6 +890,7 @@ public class Codegen {
}
mv.visitJumpInsn(GOTO, start);
mv.visitLabel(end);
mv.visitInsn(NOP);
state.exitScope();
state.localCounter = localCounter;
break;
@ -882,9 +901,18 @@ public class Codegen {
mv.visitLabel(start);
generate(state, _while.cond());
mv.visitJumpInsn(IFEQ, end);
var env = new BreakEnv();
env.startLabel = start;
env.endLabel = end;
state.breakStack.push(env);
generate(state, _while.body());
state.breakStack.pop();
mv.visitJumpInsn(GOTO, start);
mv.visitLabel(end);
mv.visitInsn(NOP);
break;
}
case TargetIf _if: {
@ -899,6 +927,7 @@ public class Codegen {
generate(state, _if.else_body());
}
mv.visitLabel(end);
mv.visitInsn(NOP);
break;
}
case TargetReturn ret: {
@ -911,6 +940,22 @@ public class Codegen {
mv.visitInsn(RETURN);
break;
}
case TargetSwitch _switch: {
generateSwitch(state, _switch);
break;
}
case TargetBreak brk: {
if (state.breakStack.isEmpty()) throw new CodeGenException("Break outside of switch or loop");
mv.visitJumpInsn(GOTO, state.breakStack.peek().endLabel);
break;
}
case TargetContinue cnt: {
if (state.breakStack.isEmpty()) throw new CodeGenException("Continue outside of loop");
var env = state.breakStack.peek();
if (env.startLabel == null) throw new CodeGenException("Continue outside of loop");
mv.visitJumpInsn(GOTO, env.startLabel);
break;
}
case TargetThis _this: {
mv.visitVarInsn(ALOAD, 0);
break;
@ -959,6 +1004,76 @@ public class Codegen {
}
}
private void generateClassicSwitch(State state, TargetSwitch aSwitch) {
// TODO Constant expressions are allowed, we need to evaluate them somehow...
// For now we just assume we get literals
// TODO This always uses a lookupswitch, a tableswitch may be faster in some cases but we can't generate that in all cases
var mv = state.mv;
generate(state, aSwitch.expr());
var keys = new int[aSwitch.cases().stream().mapToInt(c -> c.labels().size()).sum()];
var labels = new Label[keys.length];
var bodyLabels = new Label[aSwitch.cases().size()];
var end = new Label();
var env = new BreakEnv();
env.endLabel = end;
state.breakStack.push(env);
var i = 0;
var j = 0;
for (var case_ : aSwitch.cases()) {
bodyLabels[j] = new Label();
for (var label : case_.labels()) {
if (!(label instanceof TargetLiteral literal))
throw new CodeGenException("Labels may only be constants for now");
keys[i] = (int) literal.value();
labels[i] = bodyLabels[j];
i += 1;
}
j += 1;
}
var defaultLabel = end;
if (aSwitch.default_() != null) {
defaultLabel = new Label();
}
mv.visitLookupSwitchInsn(defaultLabel, keys, labels);
for (var k = 0; k < aSwitch.cases().size(); k++) {
mv.visitLabel(bodyLabels[k]);
generate(state, aSwitch.cases().get(k).body());
}
if (aSwitch.default_() != null) {
mv.visitLabel(defaultLabel);
generate(state, aSwitch.default_());
}
mv.visitLabel(end);
mv.visitInsn(NOP);
state.breakStack.pop();
}
private void generateEnhancedSwitch(State state, TargetSwitch aSwitch) {
}
final Set<TargetType> wrapperTypes = Set.of(TargetType.Long, TargetType.Integer, TargetType.Byte, TargetType.Char, TargetType.Boolean, TargetType.Double, TargetType.Float);
private void generateSwitch(State state, TargetSwitch aSwitch) {
if (!wrapperTypes.contains(aSwitch.expr().type()))
generateEnhancedSwitch(state, aSwitch);
else for (var case_ : aSwitch.cases()) {
if (case_.labels().stream().anyMatch(c -> c instanceof TargetSwitch.Pattern)) {
generateEnhancedSwitch(state, aSwitch);
return;
}
}
generateClassicSwitch(state, aSwitch);
}
private void generateField(TargetField field) {
cw.visitField(field.access() | ACC_PUBLIC, field.name(), field.type().toSignature(), field.type().toDescriptor(), null);
}

View File

@ -361,14 +361,14 @@ public class StatementToTargetExpression implements StatementVisitor {
@Override
public void visit(Pattern aPattern) {
result = new TargetSwitch.Pattern(converter.convert(aPattern.getType()), aPattern.getName());
result = new TargetSwitch.SimplePattern(converter.convert(aPattern.getType()), aPattern.getName());
}
@Override
public void visit(RecordPattern aRecordPattern) {
result = new TargetSwitch.ComplexPattern(
converter.convert(aRecordPattern.getType()),
aRecordPattern.getSubPattern().stream().map(x -> (TargetSwitch.Pattern) converter.convert(x)).toList()
converter.convert(aRecordPattern.getType()),
aRecordPattern.getSubPattern().stream().map(x -> (TargetSwitch.Pattern) converter.convert(x)).toList()
);
}
@ -376,6 +376,6 @@ public class StatementToTargetExpression implements StatementVisitor {
public void visit(GuardedPattern aGuardedPattern) {
//FIXME This isn't done properly inside the parser, really you should only have one guard (Chaining them together with && just yields another expression)
//And then it also needs to be able to accept complex patterns. Because of this we only accept one condition for now.
result = new TargetSwitch.Guard(new TargetSwitch.Pattern(converter.convert(aGuardedPattern.getType()), aGuardedPattern.getName()), converter.convert(aGuardedPattern.getConditions().get(0)));
result = new TargetSwitch.Guard(new TargetSwitch.SimplePattern(converter.convert(aGuardedPattern.getType()), aGuardedPattern.getName()), converter.convert(aGuardedPattern.getConditions().get(0)));
}
}

View File

@ -2,5 +2,6 @@ package de.dhbwstuttgart.target.tree.expression;
import de.dhbwstuttgart.target.tree.type.TargetType;
// TODO This needs a label
public record TargetBreak() implements TargetExpression {
}

View File

@ -3,7 +3,7 @@ package de.dhbwstuttgart.target.tree.expression;
import de.dhbwstuttgart.target.tree.type.*;
public sealed interface TargetExpression
permits TargetBinaryOp, TargetBlock, TargetBreak, TargetCast, TargetClassName, TargetContinue, TargetFieldVar, TargetFor, TargetForEach, TargetIf, TargetInstanceOf, TargetLambdaExpression, TargetLiteral, TargetLocalVar, TargetReturn, TargetStatementExpression, TargetSuper, TargetSwitch, TargetSwitch.ComplexPattern, TargetSwitch.Guard, TargetSwitch.Pattern, TargetTernary, TargetThis, TargetUnaryOp, TargetVarDecl, TargetWhile, TargetYield {
permits TargetBinaryOp, TargetBlock, TargetBreak, TargetCast, TargetClassName, TargetContinue, TargetFieldVar, TargetFor, TargetForEach, TargetIf, TargetInstanceOf, TargetLambdaExpression, TargetLiteral, TargetLocalVar, TargetReturn, TargetStatementExpression, TargetSuper, TargetSwitch, TargetSwitch.Pattern, TargetTernary, TargetThis, TargetUnaryOp, TargetVarDecl, TargetWhile, TargetYield {
default TargetType type() {
return null;

View File

@ -20,8 +20,10 @@ public record TargetSwitch(TargetExpression expr, List<Case> cases, TargetBlock
public record Case(List<TargetExpression> labels, TargetBlock body) {}
public record Pattern(TargetType type, String name) implements TargetExpression {}
public record ComplexPattern(TargetType type, List<Pattern> subPatterns) implements TargetExpression {}
public sealed interface Pattern extends TargetExpression {}
public record Guard(TargetExpression inner, TargetExpression expression) implements TargetExpression {}
public record SimplePattern(TargetType type, String name) implements Pattern {}
public record ComplexPattern(TargetType type, List<Pattern> subPatterns) implements Pattern {}
public record Guard(TargetExpression inner, TargetExpression expression) implements Pattern {}
}

View File

@ -213,8 +213,11 @@ public class TestCodegen {
))
), new TargetBlock(
List.of(new TargetAssign(TargetType.Integer, new TargetLocalVar(TargetType.Integer, "res"), new TargetLiteral.IntLiteral(2)), new TargetBreak())
))
)),
new TargetReturn(new TargetLocalVar(TargetType.Integer, "res"))
)));
var clazz = generateClass(targetClass, new ByteArrayClassLoader());
}
@Test
@ -222,10 +225,10 @@ public class TestCodegen {
var targetClass = new TargetClass(Opcodes.ACC_PUBLIC, "Switch");
targetClass.addMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC, "switchType", List.of(new MethodParameter(TargetType.Object, "obj")), TargetType.Integer, new TargetBlock(List.of(
new TargetReturn(new TargetSwitch(new TargetLocalVar(TargetType.Object, "obj"), List.of(
new TargetSwitch.Case(List.of(new TargetSwitch.Pattern(TargetType.String, "aString")), new TargetBlock(
new TargetSwitch.Case(List.of(new TargetSwitch.SimplePattern(TargetType.String, "aString")), new TargetBlock(
List.of(new TargetLiteral.IntLiteral(0))
)),
new TargetSwitch.Case(List.of(new TargetSwitch.Pattern(TargetType.Integer, "anInteger")), new TargetBlock(
new TargetSwitch.Case(List.of(new TargetSwitch.SimplePattern(TargetType.Integer, "anInteger")), new TargetBlock(
List.of(new TargetLiteral.IntLiteral(1))
))
), new TargetBlock(