add name resolution, fix symbol table key, value switchup

This commit is contained in:

View File

@ -1,18 +1,20 @@
module Ast where module Ast where
import Data.List (find)
type CompilationUnit = [Class] type CompilationUnit = [Class]
type DataType = String type DataType = String
type Identifier = String type Identifier = String
data ParameterDeclaration = ParameterDeclaration DataType Identifier data ParameterDeclaration = ParameterDeclaration DataType Identifier deriving (Show)
data VariableDeclaration = VariableDeclaration DataType Identifier (Maybe Expression) data VariableDeclaration = VariableDeclaration DataType Identifier (Maybe Expression) deriving (Show)
data Class = Class DataType [MethodDeclaration] [VariableDeclaration] data Class = Class DataType [MethodDeclaration] [VariableDeclaration] deriving (Show)
data MethodDeclaration = MethodDeclaration DataType Identifier [ParameterDeclaration] Statement data MethodDeclaration = MethodDeclaration DataType Identifier [ParameterDeclaration] Statement deriving (Show)
data Statement data Statement
= If Expression Statement (Maybe Statement) = If Expression Statement (Maybe Statement)
@ -22,12 +24,14 @@ data Statement
| Return (Maybe Expression) | Return (Maybe Expression)
| StatementExpressionStatement StatementExpression | StatementExpressionStatement StatementExpression
| TypedStatement DataType Statement | TypedStatement DataType Statement
deriving (Show)
data StatementExpression data StatementExpression
= Assignment Identifier Expression = Assignment Identifier Expression
| ConstructorCall DataType [Expression] | ConstructorCall DataType [Expression]
| MethodCall Identifier [Expression] | MethodCall Identifier [Expression]
| TypedStatementExpression DataType StatementExpression | TypedStatementExpression DataType StatementExpression
deriving (Show)
data BinaryOperator data BinaryOperator
= Addition = Addition
@ -46,10 +50,12 @@ data BinaryOperator
| And | And
| Or | Or
| NameResolution | NameResolution
deriving (Show)
data UnaryOperator data UnaryOperator
= Not = Not
| Minus | Minus
deriving (Show)
data Expression data Expression
= IntegerLiteral Int = IntegerLiteral Int
@ -61,6 +67,7 @@ data Expression
| UnaryOperation UnaryOperator Expression | UnaryOperation UnaryOperator Expression
| StatementExpressionExpression StatementExpression | StatementExpressionExpression StatementExpression
| TypedExpression DataType Expression | TypedExpression DataType Expression
deriving (Show)
typeCheckCompilationUnit :: CompilationUnit -> CompilationUnit typeCheckCompilationUnit :: CompilationUnit -> CompilationUnit
typeCheckCompilationUnit classes = map (`typeCheckClass` classes) classes typeCheckCompilationUnit classes = map (`typeCheckClass` classes) classes
@ -69,28 +76,28 @@ typeCheckClass :: Class -> [Class] -> Class
typeCheckClass (Class className methods fields) classes = typeCheckClass (Class className methods fields) classes =
let let
-- Create a symbol table from class fields -- Create a symbol table from class fields
classFields = [(dt, id) | VariableDeclaration dt id _ <- fields] classFields = [(id, dt) | VariableDeclaration dt id _ <- fields]
checkedMethods = map (\method -> typeCheckMethodDeclaration method classFields classes) methods checkedMethods = map (\method -> typeCheckMethodDeclaration method classFields classes) methods
in Class className checkedMethods fields in Class className checkedMethods fields
typeCheckMethodDeclaration :: MethodDeclaration -> [(DataType, Identifier)] -> [Class] -> MethodDeclaration typeCheckMethodDeclaration :: MethodDeclaration -> [(Identifier, DataType)] -> [Class] -> MethodDeclaration
typeCheckMethodDeclaration (MethodDeclaration retType name params body) classFields classes = typeCheckMethodDeclaration (MethodDeclaration retType name params body) classFields classes =
let let
-- Combine class fields with method parameters to form the initial symbol table for the method -- Combine class fields with method parameters to form the initial symbol table for the method
methodParams = [(dataType, identifier) | ParameterDeclaration dataType identifier <- params] methodParams = [(identifier, dataType) | ParameterDeclaration dataType identifier <- params]
-- Ensure method parameters shadow class fields if names collide -- Ensure method parameters shadow class fields if names collide
initialSymtab = classFields ++ methodParams initialSymtab = classFields ++ methodParams
-- Type check the body of the method using the combined symbol table -- Type check the body of the method using the combined symbol table
checkedBody = typeCheckStatement body initialSymtab classes checkedBody = typeCheckStatement body initialSymtab classes
bodyType = getTypeFromStmt checkedBody bodyType = getTypeFromStmt checkedBody
-- Check if the type of the body matches the declared return type -- Check if the type of the body matches the declared return type
in if bodyType == retType || (bodyType == "Void" && retType == "void") in if bodyType == retType || (bodyType == "void" && retType == "void")
then MethodDeclaration retType name params checkedBody then MethodDeclaration retType name params checkedBody
else error $ "Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType else error $ "Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType
-- ********************************** Type Checking: Expressions ********************************** -- ********************************** Type Checking: Expressions **********************************
typeCheckExpression :: Expression -> [(DataType, Identifier)] -> [Class] -> Expression typeCheckExpression :: Expression -> [(Identifier, DataType)] -> [Class] -> Expression
typeCheckExpression (IntegerLiteral i) _ _ = TypedExpression "int" (IntegerLiteral i) typeCheckExpression (IntegerLiteral i) _ _ = TypedExpression "int" (IntegerLiteral i)
typeCheckExpression (CharacterLiteral c) _ _ = TypedExpression "char" (CharacterLiteral c) typeCheckExpression (CharacterLiteral c) _ _ = TypedExpression "char" (CharacterLiteral c)
typeCheckExpression (BooleanLiteral b) _ _ = TypedExpression "boolean" (BooleanLiteral b) typeCheckExpression (BooleanLiteral b) _ _ = TypedExpression "boolean" (BooleanLiteral b)
@ -194,8 +201,23 @@ typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes =
TypedExpression "boolean" (BinaryOperation op expr1' expr2') TypedExpression "boolean" (BinaryOperation op expr1' expr2')
else else
error "Logical OR operation requires two operands of type boolean" error "Logical OR operation requires two operands of type boolean"
-- dont i have to lookup in classes if expr1 is in the list of classes? and if it is, then i have to check if expr2 is a method / variable in that class NameResolution ->
NameResolution -> TypedExpression type1 (BinaryOperation op expr1' expr2') case (expr1', expr2') of
(TypedExpression t1 (Reference obj), TypedExpression t2 (Reference member)) ->
-- Lookup the class type of obj from the symbol table
let objectType = lookupType obj symtab
classDetails = find (\(Class className _ _) -> className == objectType) classes
in case classDetails of
Just (Class _ methods fields) ->
-- Check both fields and methods to find a match for member
let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == member]
methodTypes = [dt | MethodDeclaration dt id _ _ <- methods, id == member]
in case fieldTypes ++ methodTypes of
[resolvedType] -> TypedExpression resolvedType (BinaryOperation op expr1' expr2')
[] -> error $ "Member '" ++ member ++ "' not found in class '" ++ objectType ++ "'"
_ -> error $ "Ambiguous reference to '" ++ member ++ "' in class '" ++ objectType ++ "'"
Nothing -> error $ "Object '" ++ obj ++ "' does not correspond to a known class"
_ -> error "Name resolution requires object reference and member name"
typeCheckExpression (UnaryOperation op expr) symtab classes = typeCheckExpression (UnaryOperation op expr) symtab classes =
let expr' = typeCheckExpression expr symtab classes let expr' = typeCheckExpression expr symtab classes
@ -220,7 +242,7 @@ typeCheckExpression (StatementExpressionExpression stmtExpr) symtab classes =
-- ********************************** Type Checking: StatementExpressions ********************************** -- ********************************** Type Checking: StatementExpressions **********************************
typeCheckStatementExpression :: StatementExpression -> [(DataType, Identifier)] -> [Class] -> StatementExpression typeCheckStatementExpression :: StatementExpression -> [(Identifier, DataType)] -> [Class] -> StatementExpression
typeCheckStatementExpression (Assignment id expr) symtab classes = typeCheckStatementExpression (Assignment id expr) symtab classes =
let expr' = typeCheckExpression expr symtab classes let expr' = typeCheckExpression expr symtab classes
type' = getTypeFromExpr expr' type' = getTypeFromExpr expr'
@ -241,7 +263,7 @@ typeCheckStatementExpression (MethodCall methodName args) symtab classes =
-- ********************************** Type Checking: Statements ********************************** -- ********************************** Type Checking: Statements **********************************
typeCheckStatement :: Statement -> [(DataType, Identifier)] -> [Class] -> Statement typeCheckStatement :: Statement -> [(Identifier, DataType)] -> [Class] -> Statement
typeCheckStatement (If cond thenStmt elseStmt) symtab classes = typeCheckStatement (If cond thenStmt elseStmt) symtab classes =
let cond' = typeCheckExpression cond symtab classes let cond' = typeCheckExpression cond symtab classes
thenStmt' = typeCheckStatement thenStmt symtab classes thenStmt' = typeCheckStatement thenStmt symtab classes
@ -254,6 +276,18 @@ typeCheckStatement (If cond thenStmt elseStmt) symtab classes =
else else
error "If condition must be of type boolean" error "If condition must be of type boolean"
typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr)) symtab classes =
-- Check for redefinition in the current scope
if any ((== identifier) . snd) symtab
then error $ "Variable '" ++ identifier ++ "' is redefined in the same scope"
else
-- If there's an initializer expression, type check it
let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
exprType = fmap getTypeFromExpr checkedExpr
in case exprType of
Just t | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
_ -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
typeCheckStatement (While cond stmt) symtab classes = typeCheckStatement (While cond stmt) symtab classes =
let cond' = typeCheckExpression cond symtab classes let cond' = typeCheckExpression cond symtab classes
stmt' = typeCheckStatement stmt symtab classes stmt' = typeCheckStatement stmt symtab classes
@ -273,34 +307,22 @@ typeCheckStatement (Block statements) symtab classes =
LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr) -> LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr) ->
let let
checkedExpr = fmap (\expr -> typeCheckExpression expr currentSymtab classes) maybeExpr checkedExpr = fmap (\expr -> typeCheckExpression expr currentSymtab classes) maybeExpr
newSymtab = (dataType, identifier) : currentSymtab newSymtab = (identifier, dataType) : currentSymtab
in (accSts ++ [checkedStmt], newSymtab, types) in (accSts ++ [checkedStmt], newSymtab, types)
If {} -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "Void" then types ++ [stmtType] else types) If {} -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
While _ _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "Void" then types ++ [stmtType] else types) While _ _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
Return _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "Void" then types ++ [stmtType] else types) Return _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
_ -> (accSts ++ [checkedStmt], currentSymtab, types) _ -> (accSts ++ [checkedStmt], currentSymtab, types)
-- Initial accumulator: empty statements list, initial symbol table, empty types list -- Initial accumulator: empty statements list, initial symbol table, empty types list
(checkedStatements, finalSymtab, collectedTypes) = foldl processStatements ([], symtab, []) statements (checkedStatements, finalSymtab, collectedTypes) = foldl processStatements ([], symtab, []) statements
-- Determine the block's type: unify all collected types, default to "Void" if none -- Determine the block's type: unify all collected types, default to "Void" if none
blockType = if null collectedTypes then "Void" else foldl1 unifyReturnTypes collectedTypes blockType = if null collectedTypes then "void" else foldl1 unifyReturnTypes collectedTypes
in TypedStatement blockType (Block checkedStatements) in TypedStatement blockType (Block checkedStatements)
typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr)) symtab classes =
-- Check for redefinition in the current scope
if any ((== identifier) . snd) symtab
then error $ "Variable '" ++ identifier ++ "' is redefined in the same scope"
else
-- If there's an initializer expression, type check it
let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
exprType = fmap getTypeFromExpr checkedExpr
in case exprType of
Just t | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
_ -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
typeCheckStatement (Return expr) symtab classes = typeCheckStatement (Return expr) symtab classes =
let expr' = case expr of let expr' = case expr of
Just e -> Just (typeCheckExpression e symtab classes) Just e -> Just (typeCheckExpression e symtab classes)
@ -328,7 +350,7 @@ unifyReturnTypes dt1 dt2
| dt1 == dt2 = dt1 | dt1 == dt2 = dt1
| otherwise = "Object" | otherwise = "Object"
lookupType :: Identifier -> [(DataType, Identifier)] -> DataType lookupType :: Identifier -> [(Identifier, DataType)] -> DataType
lookupType id symtab = lookupType id symtab =
case lookup id symtab of case lookup id symtab of
Just t -> t Just t -> t