Add initial typechecker for AST #2

Merged
mrab merged 121 commits from typedAST into master 2024-06-14 07:53:30 +00:00
2 changed files with 54 additions and 30 deletions
Showing only changes of commit 1d5463582f - Show all commits

View File

@ -58,7 +58,7 @@ exampleConstructorCall :: Statement
exampleConstructorCall = LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))) exampleConstructorCall = LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30]))))
exampleNameResolution :: Expression exampleNameResolution :: Expression
exampleNameResolution = BinaryOperation NameResolution (Reference "b") (Reference "age") exampleNameResolution = BinaryOperation NameResolution (Reference "bob2") (Reference "age")
exampleBlockResolution :: Statement exampleBlockResolution :: Statement
exampleBlockResolution = Block [ exampleBlockResolution = Block [
@ -113,10 +113,12 @@ testClasses = [
(Block [ (Block [
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 25])))), LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 25])))),
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]), StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]),
LocalVariableDeclaration (VariableDeclaration "int" "bobAge" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))), LocalVariableDeclaration (VariableDeclaration "int" "bobAge" (Just (StatementExpressionExpression (MethodCall (Reference "bob2") "getAge" [])))),
Return (Just (Reference "bobAge")) Return (Just (Reference "bobAge"))
]) ])
] [] ] [
VariableDeclaration "Person" "bob2" Nothing
]
] ]
runTypeCheck :: IO () runTypeCheck :: IO ()
@ -151,7 +153,7 @@ runTypeCheck = do
catch (do catch (do
print "=====================================================================================" print "====================================================================================="
evaluatedNameResolution <- evaluate (typeCheckExpression exampleNameResolution [("b", "Person")] sampleClasses) evaluatedNameResolution <- evaluate (typeCheckExpression exampleNameResolution [("this", "Main")] testClasses)
printSuccess "Type checking of name resolution completed successfully" printSuccess "Type checking of name resolution completed successfully"
printResult "Result Name Resolution:" evaluatedNameResolution printResult "Result Name Resolution:" evaluatedNameResolution
) handleError ) handleError
@ -189,7 +191,7 @@ runTypeCheck = do
let mainClass = fromJust $ find (\(Class className _ _) -> className == "Main") testClasses let mainClass = fromJust $ find (\(Class className _ _) -> className == "Main") testClasses
case mainClass of case mainClass of
Class _ [mainMethod] _ -> do Class _ [mainMethod] _ -> do
let result = typeCheckMethodDeclaration mainMethod [] testClasses let result = typeCheckMethodDeclaration mainMethod [("this", "Main")] testClasses
printSuccess "Full program type checking completed successfully." printSuccess "Full program type checking completed successfully."
printResult "Main method result:" result printResult "Main method result:" result
) handleError ) handleError

View File

@ -10,9 +10,9 @@ typeCheckClass :: Class -> [Class] -> Class
typeCheckClass (Class className methods fields) classes = typeCheckClass (Class className methods fields) classes =
let let
-- Create a symbol table from class fields and method entries -- Create a symbol table from class fields and method entries
classFields = [(id, dt) | VariableDeclaration dt id _ <- fields] -- TODO: Maybe remove method entries from the symbol table?
methodEntries = [(methodName, className) | MethodDeclaration _ methodName _ _ <- methods] methodEntries = [(methodName, className) | MethodDeclaration _ methodName _ _ <- methods]
initalSymTab = ("this", className) : classFields ++ methodEntries initalSymTab = ("this", className) : methodEntries
checkedMethods = map (\method -> typeCheckMethodDeclaration method initalSymTab classes) methods checkedMethods = map (\method -> typeCheckMethodDeclaration method initalSymTab classes) methods
in Class className checkedMethods fields in Class className checkedMethods fields
@ -37,8 +37,21 @@ typeCheckExpression (CharacterLiteral c) _ _ = TypedExpression "char" (Character
typeCheckExpression (BooleanLiteral b) _ _ = TypedExpression "boolean" (BooleanLiteral b) typeCheckExpression (BooleanLiteral b) _ _ = TypedExpression "boolean" (BooleanLiteral b)
typeCheckExpression NullLiteral _ _ = TypedExpression "null" NullLiteral typeCheckExpression NullLiteral _ _ = TypedExpression "null" NullLiteral
typeCheckExpression (Reference id) symtab classes = typeCheckExpression (Reference id) symtab classes =
let type' = lookupType id symtab case lookup id symtab of
in TypedExpression type' (Reference id) Just t -> TypedExpression t (LocalVariable id)
Nothing ->
case lookup "this" symtab of
Just className ->
let classDetails = find (\(Class name _ _) -> name == className) classes
in case classDetails of
Just (Class _ _ fields) ->
let fieldTypes = [dt | VariableDeclaration dt fieldId _ <- fields, fieldId == id]
in case fieldTypes of
[fieldType] -> TypedExpression fieldType (FieldVariable id)
[] -> error $ "Field '" ++ id ++ "' not found in class '" ++ className ++ "'"
_ -> error $ "Ambiguous reference to field '" ++ id ++ "' in class '" ++ className ++ "'"
Nothing -> error $ "Class '" ++ className ++ "' not found for 'this'"
Nothing -> error $ "Context for 'this' not found in symbol table, unable to resolve '" ++ id ++ "'"
typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes = typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes =
let expr1' = typeCheckExpression expr1 symtab classes let expr1' = typeCheckExpression expr1 symtab classes
expr2' = typeCheckExpression expr2 symtab classes expr2' = typeCheckExpression expr2 symtab classes
@ -137,17 +150,24 @@ typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes =
error "Logical OR operation requires two operands of type boolean" error "Logical OR operation requires two operands of type boolean"
NameResolution -> NameResolution ->
case (expr1', expr2) of case (expr1', expr2) of
(TypedExpression t1 (Reference obj), Reference member) -> (TypedExpression objType (LocalVariable ident), Reference ident2) ->
let objectType = lookupType obj symtab case find (\(Class className _ _) -> className == objType) classes of
classDetails = find (\(Class className _ _) -> className == objectType) classes
in case classDetails of
Just (Class _ _ fields) -> Just (Class _ _ fields) ->
let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == member] let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == ident2]
in case fieldTypes of in case fieldTypes of
[resolvedType] -> TypedExpression resolvedType (BinaryOperation NameResolution expr1' (TypedExpression resolvedType expr2)) [resolvedType] -> TypedExpression resolvedType (BinaryOperation NameResolution expr1' (TypedExpression resolvedType (FieldVariable ident2)))
[] -> error $ "Field '" ++ member ++ "' not found in class '" ++ objectType ++ "'" [] -> error $ "Field '" ++ ident2 ++ "' not found in class '" ++ objType ++ "'"
_ -> error $ "Ambiguous reference to field '" ++ member ++ "' in class '" ++ objectType ++ "'" _ -> error $ "Ambiguous reference to field '" ++ ident ++ "' in class '" ++ objType ++ "'"
Nothing -> error $ "Object '" ++ obj ++ "' does not correspond to a known class" Nothing -> error $ "Class '" ++ objType ++ "' not found"
(TypedExpression objType (FieldVariable ident), Reference ident2) ->
case find (\(Class className _ _) -> className == objType) classes of
Just (Class _ _ fields) ->
let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == ident2]
in case fieldTypes of
[resolvedType] -> TypedExpression resolvedType (BinaryOperation NameResolution expr1' (TypedExpression resolvedType (FieldVariable ident2)))
[] -> error $ "Field '" ++ ident2 ++ "' not found in class '" ++ objType ++ "'"
_ -> error $ "Ambiguous reference to field '" ++ ident ++ "' in class '" ++ objType ++ "'"
Nothing -> error $ "Class '" ++ objType ++ "' not found"
_ -> error "Name resolution requires object reference and field name" _ -> error "Name resolution requires object reference and field name"
typeCheckExpression (UnaryOperation op expr) symtab classes = typeCheckExpression (UnaryOperation op expr) symtab classes =
@ -177,12 +197,14 @@ typeCheckStatementExpression :: StatementExpression -> [(Identifier, DataType)]
typeCheckStatementExpression (Assignment id expr) symtab classes = typeCheckStatementExpression (Assignment id expr) symtab classes =
let expr' = typeCheckExpression expr symtab classes let expr' = typeCheckExpression expr symtab classes
type' = getTypeFromExpr expr' type' = getTypeFromExpr expr'
type'' = lookupType id symtab maybeType'' = lookupType id symtab
in if type' == type'' in case maybeType'' of
then Just type'' ->
if type' == type'' then
TypedStatementExpression type' (Assignment id expr') TypedStatementExpression type' (Assignment id expr')
else else
error "Assignment type mismatch" error $ "Assignment type mismatch: expected " ++ type'' ++ ", found " ++ type'
Nothing -> error $ "Identifier '" ++ id ++ "' not found in symbol table"
typeCheckStatementExpression (ConstructorCall className args) symtab classes = typeCheckStatementExpression (ConstructorCall className args) symtab classes =
case find (\(Class name _ _) -> name == className) classes of case find (\(Class name _ _) -> name == className) classes of
@ -327,8 +349,8 @@ unifyReturnTypes dt1 dt2
| dt1 == dt2 = dt1 | dt1 == dt2 = dt1
| otherwise = "Object" | otherwise = "Object"
lookupType :: Identifier -> [(Identifier, DataType)] -> DataType lookupType :: Identifier -> [(Identifier, DataType)] -> Maybe DataType
lookupType id symtab = lookupType id symtab =
case lookup id symtab of case lookup id symtab of
Just t -> t Just t -> Just t
Nothing -> error ("Identifier " ++ id ++ " not found in symbol table") Nothing -> Nothing