Add initial typechecker for AST #2
@ -58,7 +58,7 @@ exampleConstructorCall :: Statement
|
|||||||
exampleConstructorCall = LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30]))))
|
exampleConstructorCall = LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30]))))
|
||||||
|
|
||||||
exampleNameResolution :: Expression
|
exampleNameResolution :: Expression
|
||||||
exampleNameResolution = BinaryOperation NameResolution (Reference "b") (Reference "age")
|
exampleNameResolution = BinaryOperation NameResolution (Reference "bob2") (Reference "age")
|
||||||
|
|
||||||
exampleBlockResolution :: Statement
|
exampleBlockResolution :: Statement
|
||||||
exampleBlockResolution = Block [
|
exampleBlockResolution = Block [
|
||||||
@ -113,10 +113,12 @@ testClasses = [
|
|||||||
(Block [
|
(Block [
|
||||||
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 25])))),
|
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 25])))),
|
||||||
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]),
|
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]),
|
||||||
LocalVariableDeclaration (VariableDeclaration "int" "bobAge" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
|
LocalVariableDeclaration (VariableDeclaration "int" "bobAge" (Just (StatementExpressionExpression (MethodCall (Reference "bob2") "getAge" [])))),
|
||||||
Return (Just (Reference "bobAge"))
|
Return (Just (Reference "bobAge"))
|
||||||
])
|
])
|
||||||
] []
|
] [
|
||||||
|
VariableDeclaration "Person" "bob2" Nothing
|
||||||
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
runTypeCheck :: IO ()
|
runTypeCheck :: IO ()
|
||||||
@ -151,7 +153,7 @@ runTypeCheck = do
|
|||||||
|
|
||||||
catch (do
|
catch (do
|
||||||
print "====================================================================================="
|
print "====================================================================================="
|
||||||
evaluatedNameResolution <- evaluate (typeCheckExpression exampleNameResolution [("b", "Person")] sampleClasses)
|
evaluatedNameResolution <- evaluate (typeCheckExpression exampleNameResolution [("this", "Main")] testClasses)
|
||||||
printSuccess "Type checking of name resolution completed successfully"
|
printSuccess "Type checking of name resolution completed successfully"
|
||||||
printResult "Result Name Resolution:" evaluatedNameResolution
|
printResult "Result Name Resolution:" evaluatedNameResolution
|
||||||
) handleError
|
) handleError
|
||||||
@ -189,7 +191,7 @@ runTypeCheck = do
|
|||||||
let mainClass = fromJust $ find (\(Class className _ _) -> className == "Main") testClasses
|
let mainClass = fromJust $ find (\(Class className _ _) -> className == "Main") testClasses
|
||||||
case mainClass of
|
case mainClass of
|
||||||
Class _ [mainMethod] _ -> do
|
Class _ [mainMethod] _ -> do
|
||||||
let result = typeCheckMethodDeclaration mainMethod [] testClasses
|
let result = typeCheckMethodDeclaration mainMethod [("this", "Main")] testClasses
|
||||||
printSuccess "Full program type checking completed successfully."
|
printSuccess "Full program type checking completed successfully."
|
||||||
printResult "Main method result:" result
|
printResult "Main method result:" result
|
||||||
) handleError
|
) handleError
|
||||||
|
@ -10,9 +10,9 @@ typeCheckClass :: Class -> [Class] -> Class
|
|||||||
typeCheckClass (Class className methods fields) classes =
|
typeCheckClass (Class className methods fields) classes =
|
||||||
let
|
let
|
||||||
-- Create a symbol table from class fields and method entries
|
-- Create a symbol table from class fields and method entries
|
||||||
classFields = [(id, dt) | VariableDeclaration dt id _ <- fields]
|
-- TODO: Maybe remove method entries from the symbol table?
|
||||||
methodEntries = [(methodName, className) | MethodDeclaration _ methodName _ _ <- methods]
|
methodEntries = [(methodName, className) | MethodDeclaration _ methodName _ _ <- methods]
|
||||||
initalSymTab = ("this", className) : classFields ++ methodEntries
|
initalSymTab = ("this", className) : methodEntries
|
||||||
checkedMethods = map (\method -> typeCheckMethodDeclaration method initalSymTab classes) methods
|
checkedMethods = map (\method -> typeCheckMethodDeclaration method initalSymTab classes) methods
|
||||||
in Class className checkedMethods fields
|
in Class className checkedMethods fields
|
||||||
|
|
||||||
@ -37,8 +37,21 @@ typeCheckExpression (CharacterLiteral c) _ _ = TypedExpression "char" (Character
|
|||||||
typeCheckExpression (BooleanLiteral b) _ _ = TypedExpression "boolean" (BooleanLiteral b)
|
typeCheckExpression (BooleanLiteral b) _ _ = TypedExpression "boolean" (BooleanLiteral b)
|
||||||
typeCheckExpression NullLiteral _ _ = TypedExpression "null" NullLiteral
|
typeCheckExpression NullLiteral _ _ = TypedExpression "null" NullLiteral
|
||||||
typeCheckExpression (Reference id) symtab classes =
|
typeCheckExpression (Reference id) symtab classes =
|
||||||
let type' = lookupType id symtab
|
case lookup id symtab of
|
||||||
in TypedExpression type' (Reference id)
|
Just t -> TypedExpression t (LocalVariable id)
|
||||||
|
Nothing ->
|
||||||
|
case lookup "this" symtab of
|
||||||
|
Just className ->
|
||||||
|
let classDetails = find (\(Class name _ _) -> name == className) classes
|
||||||
|
in case classDetails of
|
||||||
|
Just (Class _ _ fields) ->
|
||||||
|
let fieldTypes = [dt | VariableDeclaration dt fieldId _ <- fields, fieldId == id]
|
||||||
|
in case fieldTypes of
|
||||||
|
[fieldType] -> TypedExpression fieldType (FieldVariable id)
|
||||||
|
[] -> error $ "Field '" ++ id ++ "' not found in class '" ++ className ++ "'"
|
||||||
|
_ -> error $ "Ambiguous reference to field '" ++ id ++ "' in class '" ++ className ++ "'"
|
||||||
|
Nothing -> error $ "Class '" ++ className ++ "' not found for 'this'"
|
||||||
|
Nothing -> error $ "Context for 'this' not found in symbol table, unable to resolve '" ++ id ++ "'"
|
||||||
typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes =
|
typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes =
|
||||||
let expr1' = typeCheckExpression expr1 symtab classes
|
let expr1' = typeCheckExpression expr1 symtab classes
|
||||||
expr2' = typeCheckExpression expr2 symtab classes
|
expr2' = typeCheckExpression expr2 symtab classes
|
||||||
@ -137,17 +150,24 @@ typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes =
|
|||||||
error "Logical OR operation requires two operands of type boolean"
|
error "Logical OR operation requires two operands of type boolean"
|
||||||
NameResolution ->
|
NameResolution ->
|
||||||
case (expr1', expr2) of
|
case (expr1', expr2) of
|
||||||
(TypedExpression t1 (Reference obj), Reference member) ->
|
(TypedExpression objType (LocalVariable ident), Reference ident2) ->
|
||||||
let objectType = lookupType obj symtab
|
case find (\(Class className _ _) -> className == objType) classes of
|
||||||
classDetails = find (\(Class className _ _) -> className == objectType) classes
|
|
||||||
in case classDetails of
|
|
||||||
Just (Class _ _ fields) ->
|
Just (Class _ _ fields) ->
|
||||||
let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == member]
|
let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == ident2]
|
||||||
in case fieldTypes of
|
in case fieldTypes of
|
||||||
[resolvedType] -> TypedExpression resolvedType (BinaryOperation NameResolution expr1' (TypedExpression resolvedType expr2))
|
[resolvedType] -> TypedExpression resolvedType (BinaryOperation NameResolution expr1' (TypedExpression resolvedType (FieldVariable ident2)))
|
||||||
[] -> error $ "Field '" ++ member ++ "' not found in class '" ++ objectType ++ "'"
|
[] -> error $ "Field '" ++ ident2 ++ "' not found in class '" ++ objType ++ "'"
|
||||||
_ -> error $ "Ambiguous reference to field '" ++ member ++ "' in class '" ++ objectType ++ "'"
|
_ -> error $ "Ambiguous reference to field '" ++ ident ++ "' in class '" ++ objType ++ "'"
|
||||||
Nothing -> error $ "Object '" ++ obj ++ "' does not correspond to a known class"
|
Nothing -> error $ "Class '" ++ objType ++ "' not found"
|
||||||
|
(TypedExpression objType (FieldVariable ident), Reference ident2) ->
|
||||||
|
case find (\(Class className _ _) -> className == objType) classes of
|
||||||
|
Just (Class _ _ fields) ->
|
||||||
|
let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == ident2]
|
||||||
|
in case fieldTypes of
|
||||||
|
[resolvedType] -> TypedExpression resolvedType (BinaryOperation NameResolution expr1' (TypedExpression resolvedType (FieldVariable ident2)))
|
||||||
|
[] -> error $ "Field '" ++ ident2 ++ "' not found in class '" ++ objType ++ "'"
|
||||||
|
_ -> error $ "Ambiguous reference to field '" ++ ident ++ "' in class '" ++ objType ++ "'"
|
||||||
|
Nothing -> error $ "Class '" ++ objType ++ "' not found"
|
||||||
_ -> error "Name resolution requires object reference and field name"
|
_ -> error "Name resolution requires object reference and field name"
|
||||||
|
|
||||||
typeCheckExpression (UnaryOperation op expr) symtab classes =
|
typeCheckExpression (UnaryOperation op expr) symtab classes =
|
||||||
@ -177,12 +197,14 @@ typeCheckStatementExpression :: StatementExpression -> [(Identifier, DataType)]
|
|||||||
typeCheckStatementExpression (Assignment id expr) symtab classes =
|
typeCheckStatementExpression (Assignment id expr) symtab classes =
|
||||||
let expr' = typeCheckExpression expr symtab classes
|
let expr' = typeCheckExpression expr symtab classes
|
||||||
type' = getTypeFromExpr expr'
|
type' = getTypeFromExpr expr'
|
||||||
type'' = lookupType id symtab
|
maybeType'' = lookupType id symtab
|
||||||
in if type' == type''
|
in case maybeType'' of
|
||||||
then
|
Just type'' ->
|
||||||
|
if type' == type'' then
|
||||||
TypedStatementExpression type' (Assignment id expr')
|
TypedStatementExpression type' (Assignment id expr')
|
||||||
else
|
else
|
||||||
error "Assignment type mismatch"
|
error $ "Assignment type mismatch: expected " ++ type'' ++ ", found " ++ type'
|
||||||
|
Nothing -> error $ "Identifier '" ++ id ++ "' not found in symbol table"
|
||||||
|
|
||||||
typeCheckStatementExpression (ConstructorCall className args) symtab classes =
|
typeCheckStatementExpression (ConstructorCall className args) symtab classes =
|
||||||
case find (\(Class name _ _) -> name == className) classes of
|
case find (\(Class name _ _) -> name == className) classes of
|
||||||
@ -327,8 +349,8 @@ unifyReturnTypes dt1 dt2
|
|||||||
| dt1 == dt2 = dt1
|
| dt1 == dt2 = dt1
|
||||||
| otherwise = "Object"
|
| otherwise = "Object"
|
||||||
|
|
||||||
lookupType :: Identifier -> [(Identifier, DataType)] -> DataType
|
lookupType :: Identifier -> [(Identifier, DataType)] -> Maybe DataType
|
||||||
lookupType id symtab =
|
lookupType id symtab =
|
||||||
case lookup id symtab of
|
case lookup id symtab of
|
||||||
Just t -> t
|
Just t -> Just t
|
||||||
Nothing -> error ("Identifier " ++ id ++ " not found in symbol table")
|
Nothing -> Nothing
|
||||||
|
Loading…
Reference in New Issue
Block a user