-- | Abstract syntax tree for a small Java-like language, together with a
-- type checker that annotates expressions, statement expressions and
-- statements with their types.
--
-- Types are plain strings ("int", "char", "boolean", "null", "void", class
-- names, ...). Type errors are reported with 'error'.
--
-- The type attached to a checked 'Statement' is the type that statement may
-- return from the enclosing method ("Void" for a block without any return),
-- except for local variable declarations, which carry the declared type.
module Ast
  ( CompilationUnit
  , DataType
  , Identifier
  , ParameterDeclaration (..)
  , VariableDeclaration (..)
  , Class (..)
  , MethodDeclaration (..)
  , Statement (..)
  , StatementExpression (..)
  , BinaryOperator (..)
  , UnaryOperator (..)
  , Expression (..)
  , typeCheckCompilationUnit
  , typeCheckClass
  , typeCheckMethodDeclaration
  , typeCheckExpression
  , typeCheckStatementExpression
  , typeCheckStatement
  , getTypeFromExpr
  , getTypeFromStmt
  , getTypeFromStmtExpr
  , unifyReturnTypes
  , lookupType
  ) where

import Data.List (foldl', mapAccumL)
import Data.Maybe (mapMaybe)

-- ********************************** AST **********************************

type CompilationUnit = [Class]

type DataType = String

type Identifier = String

data ParameterDeclaration = ParameterDeclaration DataType Identifier

data VariableDeclaration = VariableDeclaration DataType Identifier (Maybe Expression)

-- | A class: its name, its methods and its fields.
data Class = Class DataType [MethodDeclaration] [VariableDeclaration]

-- | A method: return type, name, parameters and body.
data MethodDeclaration = MethodDeclaration DataType Identifier [ParameterDeclaration] Statement

data Statement
  = If Expression Statement (Maybe Statement)
  | LocalVariableDeclaration VariableDeclaration
  | While Expression Statement
  | Block [Statement]
  | Return (Maybe Expression)
  | StatementExpressionStatement StatementExpression
  | TypedStatement DataType Statement

data StatementExpression
  = Assignment Identifier Expression
  | ConstructorCall DataType [Expression]
  | MethodCall Identifier [Expression]
  | TypedStatementExpression DataType StatementExpression

data BinaryOperator
  = Addition
  | Subtraction
  | Multiplication
  | Division
  | BitwiseAnd
  | BitwiseOr
  | BitwiseXor
  | CompareLessThan
  | CompareLessOrEqual
  | CompareGreaterThan
  | CompareGreaterOrEqual
  | CompareEqual
  | CompareNotEqual
  | And
  | Or
  | NameResolution

data UnaryOperator = Not | Minus

data Expression
  = IntegerLiteral Int
  | CharacterLiteral Char
  | BooleanLiteral Bool
  | NullLiteral
  | Reference Identifier
  | BinaryOperation BinaryOperator Expression Expression
  | UnaryOperation UnaryOperator Expression
  | StatementExpressionExpression StatementExpression
  | TypedExpression DataType Expression

-- ********************************** Type Checking: Classes and Methods **********************************

-- | Type check every class of a compilation unit. Each class is checked
-- against the complete list of classes.
typeCheckCompilationUnit :: CompilationUnit -> CompilationUnit
typeCheckCompilationUnit classes = map (`typeCheckClass` classes) classes

-- | Type check all methods of a class. The class fields form the outermost
-- scope of every method body. Field declarations (including their
-- initialisers) are returned unchanged.
typeCheckClass :: Class -> [Class] -> Class
typeCheckClass (Class className methods fields) classes =
  let classFields = [(dataType, name) | VariableDeclaration dataType name _ <- fields]
      checkedMethods = map (\method -> typeCheckMethodDeclaration method classFields classes) methods
   in Class className checkedMethods fields

-- | Type check a method body and verify that the type it returns is
-- compatible with the declared return type.
--
-- The second argument holds the class fields. Method parameters are put in
-- front of them: 'lookupType' returns the first match, so a parameter shadows
-- a field with the same name.
typeCheckMethodDeclaration :: MethodDeclaration -> [(DataType, Identifier)] -> [Class] -> MethodDeclaration
typeCheckMethodDeclaration (MethodDeclaration retType name params body) classFields classes =
  let methodParams = [(dataType, ident) | ParameterDeclaration dataType ident <- params]
      -- Parameters first so that they shadow fields of the same name.
      initialSymtab = methodParams ++ classFields
      checkedBody = typeCheckStatement body initialSymtab classes
      bodyType = getTypeFromStmt checkedBody
      -- A block without any return is typed "Void", which is what a void
      -- method body is allowed to be.
      returnsOk =
        bodyType `isAssignableTo` retType
          || (bodyType == "Void" && retType == "void")
   in if returnsOk
        then MethodDeclaration retType name params checkedBody
        else error $ "Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType

-- ********************************** Type Checking: Expressions **********************************

-- | Annotate an expression (and all of its sub-expressions) with its type.
-- Already-annotated expressions are returned unchanged.
typeCheckExpression :: Expression -> [(DataType, Identifier)] -> [Class] -> Expression
typeCheckExpression (IntegerLiteral i) _ _ = TypedExpression "int" (IntegerLiteral i)
typeCheckExpression (CharacterLiteral c) _ _ = TypedExpression "char" (CharacterLiteral c)
typeCheckExpression (BooleanLiteral b) _ _ = TypedExpression "boolean" (BooleanLiteral b)
typeCheckExpression NullLiteral _ _ = TypedExpression "null" NullLiteral
typeCheckExpression (Reference ident) symtab _ =
  TypedExpression (lookupType ident symtab) (Reference ident)
typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes =
  let expr1' = typeCheckExpression expr1 symtab classes
      expr2' = typeCheckExpression expr2 symtab classes
      type1 = getTypeFromExpr expr1'
      type2 = getTypeFromExpr expr2'
      typed resultType = TypedExpression resultType (BinaryOperation op expr1' expr2')
      -- Both operands must have exactly the given operand type.
      requireBoth operandType resultType message
        | type1 == operandType && type2 == operandType = typed resultType
        | otherwise = error message
      -- (In)equality: identical types, or null compared with a reference type.
      requireComparable message
        | type1 `isAssignableTo` type2 || type2 `isAssignableTo` type1 = typed "boolean"
        | otherwise = error message
   in case op of
        Addition -> requireBoth "int" "int" "Addition operation requires two operands of type int"
        Subtraction -> requireBoth "int" "int" "Subtraction operation requires two operands of type int"
        Multiplication -> requireBoth "int" "int" "Multiplication operation requires two operands of type int"
        Division -> requireBoth "int" "int" "Division operation requires two operands of type int"
        BitwiseAnd -> requireBoth "int" "int" "Bitwise AND operation requires two operands of type int"
        BitwiseOr -> requireBoth "int" "int" "Bitwise OR operation requires two operands of type int"
        BitwiseXor -> requireBoth "int" "int" "Bitwise XOR operation requires two operands of type int"
        CompareLessThan -> requireBoth "int" "boolean" "Less than operation requires two operands of type int"
        CompareLessOrEqual -> requireBoth "int" "boolean" "Less than or equal operation requires two operands of type int"
        CompareGreaterThan -> requireBoth "int" "boolean" "Greater than operation requires two operands of type int"
        CompareGreaterOrEqual -> requireBoth "int" "boolean" "Greater than or equal operation requires two operands of type int"
        CompareEqual -> requireComparable "Equality operation requires two operands of the same type"
        CompareNotEqual -> requireComparable "Inequality operation requires two operands of the same type"
        And -> requireBoth "boolean" "boolean" "Logical AND operation requires two operands of type boolean"
        Or -> requireBoth "boolean" "boolean" "Logical OR operation requires two operands of type boolean"
        -- NOTE(review): open question carried over from the original author:
        -- expr2 is checked in the *local* scope and the result is typed as
        -- expr1. A proper implementation would look up the class named by
        -- type1 in 'classes' and resolve expr2 as a field/method of it.
        NameResolution -> typed type1
typeCheckExpression (UnaryOperation op expr) symtab classes =
  let expr' = typeCheckExpression expr symtab classes
      operandType = getTypeFromExpr expr'
      -- Both unary operators return the same type as their operand.
      (expected, message) = case op of
        Not -> ("boolean", "Logical NOT operation requires an operand of type boolean")
        Minus -> ("int", "Unary minus operation requires an operand of type int")
   in if operandType == expected
        then TypedExpression expected (UnaryOperation op expr')
        else error message
typeCheckExpression (StatementExpressionExpression stmtExpr) symtab classes =
  let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes
   in TypedExpression (getTypeFromStmtExpr stmtExpr') (StatementExpressionExpression stmtExpr')
typeCheckExpression typed@(TypedExpression _ _) _ _ = typed

-- ********************************** Type Checking: StatementExpressions **********************************

-- | Annotate a statement expression with its type. Already-annotated
-- statement expressions are returned unchanged.
typeCheckStatementExpression :: StatementExpression -> [(DataType, Identifier)] -> [Class] -> StatementExpression
typeCheckStatementExpression (Assignment ident expr) symtab classes =
  let expr' = typeCheckExpression expr symtab classes
      valueType = getTypeFromExpr expr'
      targetType = lookupType ident symtab
   in if valueType `isAssignableTo` targetType
        then TypedStatementExpression targetType (Assignment ident expr')
        else error "Assignment type mismatch"
typeCheckStatementExpression (ConstructorCall className args) symtab classes =
  let args' = map (\arg -> typeCheckExpression arg symtab classes) args
   in TypedStatementExpression className (ConstructorCall className args')
typeCheckStatementExpression (MethodCall methodName args) symtab classes =
  let args' = map (\arg -> typeCheckExpression arg symtab classes) args
   in -- NOTE(review): placeholder type; the real return type needs the
      -- enclosing class, which is not available here.
      TypedStatementExpression "Object" (MethodCall methodName args')
typeCheckStatementExpression typed@(TypedStatementExpression _ _) _ _ = typed

-- ********************************** Type Checking: Statements **********************************

-- | Annotate a statement with the type it may return (see the module
-- header). Already-annotated statements are returned unchanged.
typeCheckStatement :: Statement -> [(DataType, Identifier)] -> [Class] -> Statement
typeCheckStatement (If cond thenStmt elseStmt) symtab classes =
  let cond' = typeCheckExpression cond symtab classes
      thenStmt' = typeCheckStatement thenStmt symtab classes
      elseStmt' = fmap (\stmt -> typeCheckStatement stmt symtab classes) elseStmt
      thenType = getTypeFromStmt thenStmt'
      -- With an else branch, the statement may return from either side.
      ifType = maybe thenType (mergeBranchTypes thenType . getTypeFromStmt) elseStmt'
   in if getTypeFromExpr cond' == "boolean"
        then TypedStatement ifType (If cond' thenStmt' elseStmt')
        else error "If condition must be of type boolean"
typeCheckStatement (While cond body) symtab classes =
  let cond' = typeCheckExpression cond symtab classes
      body' = typeCheckStatement body symtab classes
   in if getTypeFromExpr cond' == "boolean"
        then TypedStatement (getTypeFromStmt body') (While cond' body')
        else error "While condition must be of type boolean"
typeCheckStatement (Block statements) symtab classes =
  let -- Thread the scope through the block. Each statement is checked in the
      -- scope left by the statements before it. A local declaration's
      -- initialiser is checked before its own variable is added. The variable
      -- is then prepended, so it shadows outer names of the same identifier.
      checkInScope scope stmt =
        ( maybe scope (: scope) (declaredVariable stmt)
        , typeCheckStatement stmt scope classes
        )
      (_, checkedStatements) = mapAccumL checkInScope symtab statements
      -- The block's type unifies the types of everything that may return.
      blockType = case mapMaybe returnedType checkedStatements of
        [] -> "Void"
        (firstType : otherTypes) -> foldl' unifyReturnTypes firstType otherTypes
   in TypedStatement blockType (Block checkedStatements)
typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType ident maybeInit)) symtab classes =
  let checkedInit = fmap (\expr -> typeCheckExpression expr symtab classes) maybeInit
   in -- NOTE(review): the initialiser's type is not compared with dataType.
      TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType ident checkedInit))
typeCheckStatement (Return Nothing) _ _ = TypedStatement "void" (Return Nothing)
typeCheckStatement (Return (Just expr)) symtab classes =
  let expr' = typeCheckExpression expr symtab classes
   in TypedStatement (getTypeFromExpr expr') (Return (Just expr'))
typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes =
  -- An expression statement never returns from the method, hence "void".
  TypedStatement "void" (StatementExpressionStatement (typeCheckStatementExpression stmtExpr symtab classes))
typeCheckStatement typed@(TypedStatement _ _) _ _ = typed

-- ********************************** Type Checking: Helpers **********************************

-- | The type stored in an annotated expression; errors on an unannotated one.
getTypeFromExpr :: Expression -> DataType
getTypeFromExpr (TypedExpression t _) = t
getTypeFromExpr _ = error "Untyped expression found where typed was expected"

-- | The type stored in an annotated statement; errors on an unannotated one.
getTypeFromStmt :: Statement -> DataType
getTypeFromStmt (TypedStatement t _) = t
getTypeFromStmt _ = error "Untyped statement found where typed was expected"

-- | The type stored in an annotated statement expression; errors otherwise.
getTypeFromStmtExpr :: StatementExpression -> DataType
getTypeFromStmtExpr (TypedStatementExpression t _) = t
getTypeFromStmtExpr _ = error "Untyped statement expression found where typed was expected"

-- | Combine two return types. Identical types unify to themselves, and null
-- unifies with any reference type. Anything else falls back to "Object".
unifyReturnTypes :: DataType -> DataType -> DataType
unifyReturnTypes dt1 dt2
  | dt2 `isAssignableTo` dt1 = dt1
  | dt1 `isAssignableTo` dt2 = dt2
  | otherwise = "Object"

-- | The type bound to an identifier. The symbol table holds
-- (type, identifier) pairs, and the first matching entry wins, so entries
-- earlier in the list shadow later ones.
lookupType :: Identifier -> [(DataType, Identifier)] -> DataType
lookupType ident symtab =
  case [dataType | (dataType, name) <- symtab, name == ident] of
    (dataType : _) -> dataType
    [] -> error ("Identifier " ++ ident ++ " not found in symbol table")

-- | The variable introduced by a local variable declaration, looking through
-- a 'TypedStatement' wrapper. Returns Nothing for any other statement.
declaredVariable :: Statement -> Maybe (DataType, Identifier)
declaredVariable (LocalVariableDeclaration (VariableDeclaration dataType ident _)) = Just (dataType, ident)
declaredVariable (TypedStatement _ stmt) = declaredVariable stmt
declaredVariable _ = Nothing

-- | The type a checked statement may return from the enclosing method.
-- A return statement always contributes. A nested if/while/block contributes
-- only when it returns a non-void value. Other statements never contribute.
returnedType :: Statement -> Maybe DataType
returnedType (TypedStatement t inner) =
  case inner of
    Return _ -> Just t
    If {} | not (isVoidType t) -> Just t
    While {} | not (isVoidType t) -> Just t
    Block {} | not (isVoidType t) -> Just t
    _ -> Nothing
returnedType _ = Nothing

-- | Merge the types of the two branches of an if/else. A void branch defers
-- to the other branch.
mergeBranchTypes :: DataType -> DataType -> DataType
mergeBranchTypes thenType elseType
  | isVoidType thenType = elseType
  | isVoidType elseType = thenType
  | otherwise = unifyReturnTypes thenType elseType

-- | "void" (from @return;@ and expression statements) and "Void" (a block
-- without returns) both mean "returns no value".
isVoidType :: DataType -> Bool
isVoidType t = t == "void" || t == "Void"

-- | Whether a value of the first type may be assigned to / returned as the
-- second type: identical types, or the null literal into a reference type.
isAssignableTo :: DataType -> DataType -> Bool
isAssignableTo source target =
  source == target || (source == "null" && isReferenceType target)

-- | Everything except the primitive types, the void markers and null itself.
isReferenceType :: DataType -> Bool
isReferenceType t = t `notElem` ["int", "char", "boolean", "void", "Void", "null"]