Simple record patterns in method headers

This commit is contained in:
Daniel Holle 2023-08-18 17:15:15 +02:00
parent 5f1f698530
commit e414da3369
7 changed files with 80 additions and 34 deletions

View File

@ -0,0 +1,13 @@
import java.lang.Integer;
record Point(Integer x, Integer y) {}
public class OverloadPattern {
m(Point(Integer x, Integer y)) {
return x + y;
}
m(Integer x) {
return x;
}
}

View File

@ -1019,7 +1019,7 @@ public class Codegen {
return;
}
throw new NotImplementedException();
}
private void yieldValue(State state, TargetType type) {
@ -1210,6 +1210,14 @@ public class Codegen {
state.exitScope();
}
private void extractField(State state, TargetType type, int i, ClassOrInterface clazz) {
if (i >= clazz.getFieldDecl().size())
throw new CodeGenException("Couldn't find suitable field accessor for '" + type.name() + "'");
var field = clazz.getFieldDecl().get(i);
var fieldType = new TargetRefType(((RefType) field.getType()).getName().toString());
state.mv.visitMethodInsn(INVOKEVIRTUAL, type.getInternalName(), field.getName(), "()" + fieldType.toDescriptor(), false);
}
private void bindPattern(State state, TargetType type, TargetPattern pat, Label start, int index, int depth) {
if (pat.type() instanceof TargetPrimitiveType)
boxPrimitive(state, pat.type());
@ -1247,11 +1255,7 @@ public class Codegen {
state.mv.visitInsn(DUP);
var subPattern = cp.subPatterns().get(i);
if (i >= clazz.getFieldDecl().size())
throw new CodeGenException("Couldn't find suitable field accessor for '" + cp.type().name() + "'");
var field = clazz.getFieldDecl().get(i);
var fieldType = new TargetRefType(((RefType) field.getType()).getName().toString());
state.mv.visitMethodInsn(INVOKEVIRTUAL, cp.type().getInternalName(), field.getName(), "()" + fieldType.toDescriptor(), false);
extractField(state, cp.type(), i, clazz);
bindPattern(state, subPattern.type(), subPattern, start, index, depth + 1);
}
state.mv.visitInsn(POP);
@ -1312,6 +1316,30 @@ public class Codegen {
mv.visitEnd();
}
private int bindLocalVariables(State state, TargetPattern pattern, int offset, int field) {
if (pattern instanceof TargetComplexPattern cp) {
state.mv.visitVarInsn(ALOAD, offset);
var clazz = findClass(new JavaClassName(cp.type().name()));
if (clazz == null) throw new CodeGenException("Class definition for '" + cp.type().name() + "' not found");
for (var i = 0; i < cp.subPatterns().size(); i++) {
var subPattern = cp.subPatterns().get(i);
if (i < cp.subPatterns().size() - 1)
state.mv.visitInsn(DUP);
extractField(state, cp.type(), i, clazz);
state.mv.visitVarInsn(ASTORE, offset);
offset = bindLocalVariables(state, subPattern, offset, i);
}
} else if (pattern instanceof TargetTypePattern tp) {
offset++;
state.createVariable(tp.name(), tp.type());
} else throw new NotImplementedException();
return offset;
}
private void generateMethod(TargetMethod method) {
// TODO The older codegen has set ACC_PUBLIC for all methods, good for testing but bad for everything else
MethodVisitor mv = cw.visitMethod(method.access() | ACC_PUBLIC, method.name(), method.getDescriptor(), method.getSignature(), null);
@ -1322,10 +1350,7 @@ public class Codegen {
mv.visitCode();
var state = new State(method.signature().returnType(), mv, method.isStatic() ? 0 : 1);
for (var param : method.signature().parameters()) {
var pattern = param.pattern();
if (pattern instanceof TargetTypePattern tp)
state.createVariable(tp.name(), tp.type());
else throw new NotImplementedException();
bindLocalVariables(state, param.pattern(), 1, 0);
}
generate(state, method.block());
if (method.signature().returnType() == null)

View File

@ -17,7 +17,6 @@ import de.dhbwstuttgart.syntaxtree.GenericTypeVar;
import de.dhbwstuttgart.syntaxtree.Method;
import de.dhbwstuttgart.syntaxtree.ParameterList;
import de.dhbwstuttgart.syntaxtree.SourceFile;
import de.dhbwstuttgart.syntaxtree.FormalParameter;
import de.dhbwstuttgart.syntaxtree.GenericDeclarationList;
import de.dhbwstuttgart.syntaxtree.factory.ASTFactory;
import de.dhbwstuttgart.syntaxtree.factory.UnifyTypeFactory;
@ -44,7 +43,6 @@ import de.dhbwstuttgart.typeinference.unify.model.PairOperator;
import de.dhbwstuttgart.typeinference.unify.model.PlaceholderType;
import de.dhbwstuttgart.typeinference.unify.model.UnifyPair;
import de.dhbwstuttgart.typeinference.unify.model.UnifyType;
import de.dhbwstuttgart.util.BiRelation;
import de.dhbwstuttgart.typeinference.unify.TypeUnifyTask;
import de.dhbwstuttgart.typeinference.unify.UnifyResultListener;
import de.dhbwstuttgart.typeinference.unify.UnifyResultListenerImpl;
@ -62,7 +60,6 @@ import java.util.Map.Entry;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.antlr.v4.runtime.Token;
import org.apache.commons.io.output.NullOutputStream;
public class JavaTXCompiler {

View File

@ -161,16 +161,19 @@ public class StatementGenerator {
fps = formalParameterListContext.formalParameter();
for (Java17Parser.FormalParameterContext fp : fps) {
if (fp.pattern() != null) throw new NotImplementedException();
String paramName = SyntaxTreeGenerator.convert(fp.variableDeclaratorId());
RefTypeOrTPHOrWildcardOrGeneric type;
if (fp.typeType() != null) {
type = TypeGenerator.convert(fp.typeType(), reg, generics);
if (fp.pattern() != null) {
ret.add(convert(fp.pattern()));
} else {
type = TypePlaceholder.fresh(fp.getStart());
String paramName = SyntaxTreeGenerator.convert(fp.variableDeclaratorId());
RefTypeOrTPHOrWildcardOrGeneric type;
if (fp.typeType() != null) {
type = TypeGenerator.convert(fp.typeType(), reg, generics);
} else {
type = TypePlaceholder.fresh(fp.getStart());
}
ret.add(new FormalParameter(paramName, type, fp.getStart()));
localVars.put(paramName, type);
}
ret.add(new FormalParameter(paramName, type, fp.getStart()));
localVars.put(paramName, type);
}
return new ParameterList(ret, ret.get(0).getOffset());
}

View File

@ -9,10 +9,11 @@ import de.dhbwstuttgart.syntaxtree.type.RefTypeOrTPHOrWildcardOrGeneric;
public class RecordPattern extends FormalParameter {
private List<Pattern> subPattern = new ArrayList<>();
private final List<Pattern> subPattern;
public RecordPattern(String name, RefTypeOrTPHOrWildcardOrGeneric type, Token offset) {
super(name, type, offset);
subPattern = new ArrayList<>();
}
public RecordPattern(List<Pattern> subPattern, String name, RefTypeOrTPHOrWildcardOrGeneric type, Token offset) {
@ -24,10 +25,6 @@ public class RecordPattern extends FormalParameter {
return this.subPattern;
}
public void addSubPattern(Pattern newPattern) {
this.subPattern.add(newPattern);
}
@Override
public void accept(ASTVisitor visitor) {
visitor.visit(this);

View File

@ -10,10 +10,7 @@ import de.dhbwstuttgart.syntaxtree.factory.ASTFactory;
import de.dhbwstuttgart.syntaxtree.statement.*;
import de.dhbwstuttgart.syntaxtree.type.*;
import de.dhbwstuttgart.target.tree.*;
import de.dhbwstuttgart.target.tree.expression.TargetBlock;
import de.dhbwstuttgart.target.tree.expression.TargetExpression;
import de.dhbwstuttgart.target.tree.expression.TargetSwitch;
import de.dhbwstuttgart.target.tree.expression.TargetTypePattern;
import de.dhbwstuttgart.target.tree.expression.*;
import de.dhbwstuttgart.target.tree.type.*;
import de.dhbwstuttgart.typeinference.result.*;
@ -152,10 +149,9 @@ public class ASTToTargetAST {
}
private List<MethodParameter> convert(ParameterList input, GenerateGenerics generics) {
return input.getFormalparalist().stream().map(param -> switch(param) {
case FormalParameter fpm -> new MethodParameter(new TargetTypePattern(convert(param.getType(), generics), fpm.getName()));
default -> throw new NotImplementedException();
}).toList();
return input.getFormalparalist().stream().map(param ->
new MethodParameter((TargetPattern) convert(param))
).toList();
}
private boolean hasGeneric(Set<TargetGeneric> generics, GenericRefType type) {

View File

@ -691,4 +691,19 @@ public class TestComplete {
var clazz = classFiles.get("InstanceOf");
var instance = clazz.getDeclaredConstructor().newInstance();
}
@Test
public void testOverloadPattern() throws Exception {
var classFiles = generateClassFiles(new ByteArrayClassLoader(), "OverloadPattern.jav");
var clazz = classFiles.get("OverloadPattern");
var rec = classFiles.get("Point");
var instance = clazz.getDeclaredConstructor().newInstance();
var m1 = clazz.getDeclaredMethod("m", rec);
var m2 = clazz.getDeclaredMethod("m", Integer.class);
var pt = rec.getDeclaredConstructor(Integer.class, Integer.class).newInstance(10, 20);
assertEquals(m1.invoke(instance, pt), 30);
assertEquals(m2.invoke(instance, 10), 10);
}
}