Add initial typechecker for AST #2

Merged
mrab merged 121 commits from typedAST into master 2024-06-14 07:53:30 +00:00
Showing only changes of commit 207fb5c5f3 - Show all commits

View File

@ -92,7 +92,7 @@ methodDescriptor (MethodDeclaration returntype _ parameters _) = let
"(" "("
++ (concat (map methodParameterDescriptor parameter_types)) ++ (concat (map methodParameterDescriptor parameter_types))
++ ")" ++ ")"
++ datatypeDescriptor returntype ++ methodParameterDescriptor returntype
classBuilder :: ClassFileBuilder Class classBuilder :: ClassFileBuilder Class
@ -160,11 +160,12 @@ methodBuilder (MethodDeclaration returntype name parameters statement) input = l
memberDescriptorIndex = (fromIntegral (baseIndex + 1)), memberDescriptorIndex = (fromIntegral (baseIndex + 1)),
memberAttributes = [] memberAttributes = []
} }
in in
input { input {
constantPool = (constantPool input) ++ constants, constantPool = (constantPool input) ++ constants,
methods = (methods input) ++ [method] methods = (methods input) ++ [method]
} }
methodAssembler :: ClassFileBuilder MethodDeclaration methodAssembler :: ClassFileBuilder MethodDeclaration
@ -174,8 +175,9 @@ methodAssembler (MethodDeclaration returntype name parameters statement) input =
Nothing -> error ("Cannot find method entry in method pool for method: " ++ name) Nothing -> error ("Cannot find method entry in method pool for method: " ++ name)
Just index -> let Just index -> let
declaration = MethodDeclaration returntype name parameters statement declaration = MethodDeclaration returntype name parameters statement
paramNames = "this" : [name | ParameterDeclaration _ name <- parameters]
(pre, method : post) = splitAt index (methods input) (pre, method : post) = splitAt index (methods input)
(_, bytecode) = assembleMethod (constantPool input, []) declaration (_, bytecode, _) = assembleMethod (constantPool input, [], paramNames) declaration
assembledMethod = method { assembledMethod = method {
memberAttributes = [ memberAttributes = [
CodeAttribute { CodeAttribute {
@ -193,7 +195,7 @@ methodAssembler (MethodDeclaration returntype name parameters statement) input =
type Assembler a = ([ConstantInfo], [Operation]) -> a -> ([ConstantInfo], [Operation]) type Assembler a = ([ConstantInfo], [Operation], [String]) -> a -> ([ConstantInfo], [Operation], [String])
returnOperation :: DataType -> Operation returnOperation :: DataType -> Operation
returnOperation dtype returnOperation dtype
@ -219,78 +221,93 @@ comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLoc
assembleMethod :: Assembler MethodDeclaration assembleMethod :: Assembler MethodDeclaration
assembleMethod (constants, ops) (MethodDeclaration _ name _ (TypedStatement _ (Block statements))) assembleMethod (constants, ops, lvars) (MethodDeclaration _ name _ (TypedStatement _ (Block statements)))
| name == "<init>" = let | name == "<init>" = let
(constants_a, ops_a) = foldl assembleStatement (constants, ops) statements (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
init_ops = [Opaload 0, Opinvokespecial 2] init_ops = [Opaload 0, Opinvokespecial 2]
in in
(constants_a, init_ops ++ ops_a ++ [Opreturn]) (constants_a, init_ops ++ ops_a ++ [Opreturn], lvars_a)
| otherwise = let | otherwise = let
(constants_a, ops_a) = foldl assembleStatement (constants, ops) statements (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
init_ops = [Opaload 0] init_ops = [Opaload 0]
in in
(constants_a, init_ops ++ ops_a) (constants_a, init_ops ++ ops_a, lvars_a)
assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Block expected for method body, got: " ++ show stmt) assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Block expected for method body, got: " ++ show stmt)
assembleStatement :: Assembler Statement assembleStatement :: Assembler Statement
assembleStatement (constants, ops) (TypedStatement stype (Return expr)) = case expr of assembleStatement (constants, ops, lvars) (TypedStatement stype (Return expr)) = case expr of
Nothing -> (constants, ops ++ [Opreturn]) Nothing -> (constants, ops ++ [Opreturn], lvars)
Just expr -> let Just expr -> let
(expr_constants, expr_ops) = assembleExpression (constants, ops) expr (expr_constants, expr_ops, _) = assembleExpression (constants, ops, lvars) expr
in in
(expr_constants, expr_ops ++ [returnOperation stype]) (expr_constants, expr_ops ++ [returnOperation stype], lvars)
assembleStatement (constants, ops) (TypedStatement _ (Block statements)) = assembleStatement (constants, ops, lvars) (TypedStatement _ (Block statements)) =
foldl assembleStatement (constants, ops) statements foldl assembleStatement (constants, ops, lvars) statements
assembleStatement (constants, ops) (TypedStatement _ (If expr if_stmt else_stmt)) = let assembleStatement (constants, ops, lvars) (TypedStatement _ (If expr if_stmt else_stmt)) = let
(constants_cmp, ops_cmp) = assembleExpression (constants, []) expr (constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr
(constants_ifa, ops_ifa) = assembleStatement (constants_cmp, []) if_stmt (constants_ifa, ops_ifa, _) = assembleStatement (constants_cmp, [], lvars) if_stmt
(constants_elsea, ops_elsea) = case else_stmt of (constants_elsea, ops_elsea, _) = case else_stmt of
Nothing -> (constants_ifa, []) Nothing -> (constants_ifa, [], lvars)
Just stmt -> assembleStatement (constants_ifa, []) stmt Just stmt -> assembleStatement (constants_ifa, [], lvars) stmt
-- +6 because we insert 2 gotos, one for if, one for else -- +6 because we insert 2 gotos, one for if, one for else
if_length = sum (map opcodeEncodingLength ops_ifa) + 6 if_length = sum (map opcodeEncodingLength ops_ifa) + 6
-- +3 because we need to account for the goto in the if statement. -- +3 because we need to account for the goto in the if statement.
else_length = sum (map opcodeEncodingLength ops_elsea) + 3 else_length = sum (map opcodeEncodingLength ops_elsea) + 3
in in
(constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq if_length] ++ ops_ifa ++ [Opgoto else_length] ++ ops_elsea) (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq if_length] ++ ops_ifa ++ [Opgoto else_length] ++ ops_elsea, lvars)
assembleStatement stmt _ = error ("Not yet implemented: " ++ show stmt) assembleStatement (constants, ops, lvars) (TypedStatement _ (LocalVariableDeclaration (VariableDeclaration dtype name expr))) = let
isPrimitive = elem dtype ["char", "boolean", "int"]
(constants_init, ops_init, _) = case expr of
Just exp -> assembleExpression (constants, ops, lvars) exp
Nothing -> (constants, ops ++ if isPrimitive then [Opsipush 0] else [Opaconst_null], lvars)
localIndex = fromIntegral (length lvars)
loadLocal = if isPrimitive then [Opistore localIndex] else [Opastore localIndex]
in
(constants_init, ops_init ++ loadLocal, lvars ++ [name])
assembleStatement _ stmt = error ("Not yet implemented: " ++ show stmt)
assembleExpression :: Assembler Expression assembleExpression :: Assembler Expression
assembleExpression (constants, ops) (TypedExpression _ (BinaryOperation op a b)) assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b))
| elem op [Addition, Subtraction, Multiplication, Division, BitwiseAnd, BitwiseOr, BitwiseXor] = let | elem op [Addition, Subtraction, Multiplication, Division, BitwiseAnd, BitwiseOr, BitwiseXor] = let
(aConstants, aOps) = assembleExpression (constants, ops) a (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
(bConstants, bOps) = assembleExpression (aConstants, aOps) b (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
in in
(bConstants, bOps ++ [binaryOperation op]) (bConstants, bOps ++ [binaryOperation op], lvars)
| elem op [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let | elem op [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let
(aConstants, aOps) = assembleExpression (constants, ops) a (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
(bConstants, bOps) = assembleExpression (aConstants, aOps) b (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
cmp_op = comparisonOperation op 9 cmp_op = comparisonOperation op 9
cmp_ops = [cmp_op, Opsipush 0, Opgoto 6, Opsipush 1] cmp_ops = [cmp_op, Opsipush 0, Opgoto 6, Opsipush 1]
in in
(bConstants, bOps ++ cmp_ops) (bConstants, bOps ++ cmp_ops, lvars)
assembleExpression (constants, ops) (TypedExpression _ (CharacterLiteral literal)) = assembleExpression (constants, ops, lvars) (TypedExpression _ (CharacterLiteral literal)) =
(constants, ops ++ [Opsipush (fromIntegral (ord literal))]) (constants, ops ++ [Opsipush (fromIntegral (ord literal))], lvars)
assembleExpression (constants, ops) (TypedExpression _ (BooleanLiteral literal)) = assembleExpression (constants, ops, lvars) (TypedExpression _ (BooleanLiteral literal)) =
(constants, ops ++ [Opsipush (if literal then 1 else 0)]) (constants, ops ++ [Opsipush (if literal then 1 else 0)], lvars)
assembleExpression (constants, ops) (TypedExpression _ (IntegerLiteral literal)) assembleExpression (constants, ops, lvars) (TypedExpression _ (IntegerLiteral literal))
| literal <= 32767 && literal >= -32768 = (constants, ops ++ [Opsipush (fromIntegral literal)]) | literal <= 32767 && literal >= -32768 = (constants, ops ++ [Opsipush (fromIntegral literal)], lvars)
| otherwise = (constants ++ [IntegerInfo (fromIntegral literal)], ops ++ [Opldc_w (fromIntegral (1 + length constants))]) | otherwise = (constants ++ [IntegerInfo (fromIntegral literal)], ops ++ [Opldc_w (fromIntegral (1 + length constants))], lvars)
assembleExpression (constants, ops) (TypedExpression _ NullLiteral) = assembleExpression (constants, ops, lvars) (TypedExpression _ NullLiteral) =
(constants, ops ++ [Opaconst_null]) (constants, ops ++ [Opaconst_null], lvars)
assembleExpression (constants, ops) (TypedExpression etype (UnaryOperation Not expr)) = let assembleExpression (constants, ops, lvars) (TypedExpression etype (UnaryOperation Not expr)) = let
(exprConstants, exprOps) = assembleExpression (constants, ops) expr (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
newConstant = fromIntegral (1 + length exprConstants) newConstant = fromIntegral (1 + length exprConstants)
in case etype of in case etype of
"int" -> (exprConstants ++ [IntegerInfo 0x7FFFFFFF], exprOps ++ [Opldc_w newConstant, Opixor]) "int" -> (exprConstants ++ [IntegerInfo 0x7FFFFFFF], exprOps ++ [Opldc_w newConstant, Opixor], lvars)
"char" -> (exprConstants, exprOps ++ [Opsipush 0xFFFF, Opixor]) "char" -> (exprConstants, exprOps ++ [Opsipush 0xFFFF, Opixor], lvars)
"boolean" -> (exprConstants, exprOps ++ [Opsipush 0x01, Opixor]) "boolean" -> (exprConstants, exprOps ++ [Opsipush 0x01, Opixor], lvars)
assembleExpression (constants, ops) (TypedExpression _ (UnaryOperation Minus expr)) = let assembleExpression (constants, ops, lvars) (TypedExpression _ (UnaryOperation Minus expr)) = let
(exprConstants, exprOps) = assembleExpression (constants, ops) expr (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in in
(exprConstants, exprOps ++ [Opineg]) (exprConstants, exprOps ++ [Opineg], lvars)
assembleExpression (constants, ops) (TypedExpression _ (FieldVariable name)) = let assembleExpression (constants, ops, lvars) (TypedExpression _ (FieldVariable name)) = let
fieldIndex = findFieldIndex constants name fieldIndex = findFieldIndex constants name
in case fieldIndex of in case fieldIndex of
Just index -> (constants, ops ++ [Opaload 0, Opgetfield (fromIntegral index)]) Just index -> (constants, ops ++ [Opaload 0, Opgetfield (fromIntegral index)], lvars)
Nothing -> error ("No such field found in constant pool: " ++ name) Nothing -> error ("No such field found in constant pool: " ++ name)
assembleExpression (constants, ops, lvars) (TypedExpression _ (LocalVariable name)) = let
localIndex = findIndex ((==) name) lvars
in case localIndex of
Just index -> (constants, ops ++ [Opiload (fromIntegral index)], lvars)
Nothing -> error ("No such local variable found in local variable pool: " ++ name)
assembleExpression _ expr = error ("unimplemented: " ++ show expr)