diff --git a/project.cabal b/project.cabal index 82617b1..02431af 100644 --- a/project.cabal +++ b/project.cabal @@ -16,7 +16,7 @@ executable compiler src/ByteCode, src/ByteCode/ClassFile build-tool-depends: alex:alex, happy:happy - other-modules: Ast, Example, ByteCode.ByteUtil, ByteCode.ClassFile, ByteCode.ClassFile.Generator, ByteCode.Constants + other-modules: Ast, Example, Typecheck, ByteCode.ByteUtil, ByteCode.ClassFile, ByteCode.ClassFile.Generator, ByteCode.Constants test-suite tests type: exitcode-stdio-1.0 diff --git a/src/Ast.hs b/src/Ast.hs index 83ccaa1..80f4a2b 100644 --- a/src/Ast.hs +++ b/src/Ast.hs @@ -1,19 +1,12 @@ module Ast where -import Data.List (find) - type CompilationUnit = [Class] - type DataType = String - type Identifier = String data ParameterDeclaration = ParameterDeclaration DataType Identifier deriving (Show) - data VariableDeclaration = VariableDeclaration DataType Identifier (Maybe Expression) deriving (Show) - data Class = Class DataType [MethodDeclaration] [VariableDeclaration] deriving (Show) - data MethodDeclaration = MethodDeclaration DataType Identifier [ParameterDeclaration] Statement deriving (Show) data Statement @@ -68,290 +61,3 @@ data Expression | StatementExpressionExpression StatementExpression | TypedExpression DataType Expression deriving (Show) - -typeCheckCompilationUnit :: CompilationUnit -> CompilationUnit -typeCheckCompilationUnit classes = map (`typeCheckClass` classes) classes - -typeCheckClass :: Class -> [Class] -> Class -typeCheckClass (Class className methods fields) classes = - let - -- Create a symbol table from class fields - classFields = [(id, dt) | VariableDeclaration dt id _ <- fields] - checkedMethods = map (\method -> typeCheckMethodDeclaration method classFields classes) methods - in Class className checkedMethods fields - -typeCheckMethodDeclaration :: MethodDeclaration -> [(Identifier, DataType)] -> [Class] -> MethodDeclaration -typeCheckMethodDeclaration (MethodDeclaration retType name params body) classFields classes = - let - -- Combine class fields with method parameters to form the initial symbol table for the method - methodParams = [(identifier, dataType) | ParameterDeclaration dataType identifier <- params] - -- Ensure method parameters shadow class fields if names collide - initialSymtab = classFields ++ methodParams - -- Type check the body of the method using the combined symbol table - checkedBody = typeCheckStatement body initialSymtab classes - bodyType = getTypeFromStmt checkedBody - -- Check if the type of the body matches the declared return type - in if bodyType == retType || (bodyType == "void" && retType == "void") - then MethodDeclaration retType name params checkedBody - else error $ "Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType - --- ********************************** Type Checking: Expressions ********************************** - -typeCheckExpression :: Expression -> [(Identifier, DataType)] -> [Class] -> Expression -typeCheckExpression (IntegerLiteral i) _ _ = TypedExpression "int" (IntegerLiteral i) -typeCheckExpression (CharacterLiteral c) _ _ = TypedExpression "char" (CharacterLiteral c) -typeCheckExpression (BooleanLiteral b) _ _ = TypedExpression "boolean" (BooleanLiteral b) -typeCheckExpression NullLiteral _ _ = TypedExpression "null" NullLiteral -typeCheckExpression (Reference id) symtab classes = - let type' = lookupType id symtab - in TypedExpression type' (Reference id) -typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes = - let expr1' = typeCheckExpression expr1 symtab classes - expr2' = typeCheckExpression expr2 symtab classes - type1 = getTypeFromExpr expr1' - type2 = getTypeFromExpr expr2' - in case op of - Addition -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation op expr1' expr2') - else - error "Addition operation requires two operands of type int" - Subtraction -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation op expr1' expr2') - else - error "Subtraction operation requires two operands of type int" - Multiplication -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation op expr1' expr2') - else - error "Multiplication operation requires two operands of type int" - Division -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation op expr1' expr2') - else - error "Division operation requires two operands of type int" - BitwiseAnd -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation op expr1' expr2') - else - error "Bitwise AND operation requires two operands of type int" - BitwiseOr -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation op expr1' expr2') - else - error "Bitwise OR operation requires two operands of type int" - BitwiseXor -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation op expr1' expr2') - else - error "Bitwise XOR operation requires two operands of type int" - CompareLessThan -> - if type1 == "int" && type2 == "int" - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Less than operation requires two operands of type int" - CompareLessOrEqual -> - if type1 == "int" && type2 == "int" - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Less than or equal operation requires two operands of type int" - CompareGreaterThan -> - if type1 == "int" && type2 == "int" - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Greater than operation requires two operands of type int" - CompareGreaterOrEqual -> - if type1 == "int" && type2 == "int" - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Greater than or equal operation requires two operands of type int" - CompareEqual -> - if type1 == type2 - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Equality operation requires two operands of the same type" - CompareNotEqual -> - if type1 == type2 - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Inequality operation requires two operands of the same type" - And -> - if type1 == "boolean" && type2 == "boolean" - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Logical AND operation requires two operands of type boolean" - Or -> - if type1 == "boolean" && type2 == "boolean" - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Logical OR operation requires two operands of type boolean" - NameResolution -> - case (expr1', expr2) of - (TypedExpression t1 (Reference obj), Reference member) -> - -- Lookup the class type of obj from the symbol table - let objectType = lookupType obj symtab - classDetails = find (\(Class className _ _) -> className == objectType) classes - in case classDetails of - Just (Class _ methods fields) -> - -- Check both fields and methods to find a match for member - let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == member] - methodTypes = [dt | MethodDeclaration dt id _ _ <- methods, id == member] - in case fieldTypes ++ methodTypes of - [resolvedType] -> TypedExpression resolvedType (BinaryOperation op expr1' (TypedExpression resolvedType (Reference member))) - [] -> error $ "Member '" ++ member ++ "' not found in class '" ++ objectType ++ "'" - _ -> error $ "Ambiguous reference to '" ++ member ++ "' in class '" ++ objectType ++ "'" - Nothing -> error $ "Object '" ++ obj ++ "' does not correspond to a known class" - _ -> error "Name resolution requires object reference and member name" - -typeCheckExpression (UnaryOperation op expr) symtab classes = - let expr' = typeCheckExpression expr symtab classes - type' = getTypeFromExpr expr' - in case op of - Not -> - if type' == "boolean" - then - TypedExpression "boolean" (UnaryOperation op expr') - else - error "Logical NOT operation requires an operand of type boolean" - Minus -> - if type' == "int" - then - TypedExpression "int" (UnaryOperation op expr') - else - error "Unary minus operation requires an operand of type int" - -typeCheckExpression (StatementExpressionExpression stmtExpr) symtab classes = - let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes - in TypedExpression (getTypeFromStmtExpr stmtExpr') (StatementExpressionExpression stmtExpr') - --- ********************************** Type Checking: StatementExpressions ********************************** --- TODO: Implement type checking for StatementExpressions -typeCheckStatementExpression :: StatementExpression -> [(Identifier, DataType)] -> [Class] -> StatementExpression -typeCheckStatementExpression (Assignment id expr) symtab classes = - let expr' = typeCheckExpression expr symtab classes - type' = getTypeFromExpr expr' - type'' = lookupType id symtab - in if type' == type'' - then - TypedStatementExpression type' (Assignment id expr') - else - error "Assignment type mismatch" - -typeCheckStatementExpression (ConstructorCall className args) symtab classes = - let args' = map (\arg -> typeCheckExpression arg symtab classes) args - in TypedStatementExpression className (ConstructorCall className args') - -typeCheckStatementExpression (MethodCall methodName args) symtab classes = - let args' = map (\arg -> typeCheckExpression arg symtab classes) args - in TypedStatementExpression "Object" (MethodCall methodName args') - --- ********************************** Type Checking: Statements ********************************** - -typeCheckStatement :: Statement -> [(Identifier, DataType)] -> [Class] -> Statement -typeCheckStatement (If cond thenStmt elseStmt) symtab classes = - let cond' = typeCheckExpression cond symtab classes - thenStmt' = typeCheckStatement thenStmt symtab classes - elseStmt' = case elseStmt of - Just stmt -> Just (typeCheckStatement stmt symtab classes) - Nothing -> Nothing - in if getTypeFromExpr cond' == "boolean" - then - TypedStatement (getTypeFromStmt thenStmt') (If cond' thenStmt' elseStmt') - else - error "If condition must be of type boolean" - -typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr)) symtab classes = - -- Check for redefinition in the current scope - if any ((== identifier) . snd) symtab - then error $ "Variable '" ++ identifier ++ "' is redefined in the same scope" - else - -- If there's an initializer expression, type check it - let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr - exprType = fmap getTypeFromExpr checkedExpr - in case exprType of - Just t | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t - _ -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr)) - -typeCheckStatement (While cond stmt) symtab classes = - let cond' = typeCheckExpression cond symtab classes - stmt' = typeCheckStatement stmt symtab classes - in if getTypeFromExpr cond' == "boolean" - then - TypedStatement (getTypeFromStmt stmt') (While cond' stmt') - else - error "While condition must be of type boolean" - -typeCheckStatement (Block statements) symtab classes = - let - processStatements (accSts, currentSymtab, types) stmt = - let - checkedStmt = typeCheckStatement stmt currentSymtab classes - stmtType = getTypeFromStmt checkedStmt - in case stmt of - LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr) -> - let - checkedExpr = fmap (\expr -> typeCheckExpression expr currentSymtab classes) maybeExpr - newSymtab = (identifier, dataType) : currentSymtab - in (accSts ++ [checkedStmt], newSymtab, types) - - If {} -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types) - While _ _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types) - Return _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types) - _ -> (accSts ++ [checkedStmt], currentSymtab, types) - - -- Initial accumulator: empty statements list, initial symbol table, empty types list - (checkedStatements, finalSymtab, collectedTypes) = foldl processStatements ([], symtab, []) statements - - -- Determine the block's type: unify all collected types, default to "Void" if none - blockType = if null collectedTypes then "void" else foldl1 unifyReturnTypes collectedTypes - - in TypedStatement blockType (Block checkedStatements) - -typeCheckStatement (Return expr) symtab classes = - let expr' = case expr of - Just e -> Just (typeCheckExpression e symtab classes) - Nothing -> Nothing - in case expr' of - Just e' -> TypedStatement (getTypeFromExpr e') (Return (Just e')) - Nothing -> TypedStatement "Void" (Return Nothing) - --- ********************************** Type Checking: Helpers ********************************** - -getTypeFromExpr :: Expression -> DataType -getTypeFromExpr (TypedExpression t _) = t -getTypeFromExpr _ = error "Untyped expression found where typed was expected" - -getTypeFromStmt :: Statement -> DataType -getTypeFromStmt (TypedStatement t _) = t -getTypeFromStmt _ = error "Untyped statement found where typed was expected" - -getTypeFromStmtExpr :: StatementExpression -> DataType -getTypeFromStmtExpr (TypedStatementExpression t _) = t -getTypeFromStmtExpr _ = error "Untyped statement expression found where typed was expected" - -unifyReturnTypes :: DataType -> DataType -> DataType -unifyReturnTypes dt1 dt2 - | dt1 == dt2 = dt1 - | otherwise = "Object" - -lookupType :: Identifier -> [(Identifier, DataType)] -> DataType -lookupType id symtab = - case lookup id symtab of - Just t -> t - Nothing -> error ("Identifier " ++ id ++ " not found in symbol table") \ No newline at end of file diff --git a/src/Example.hs b/src/Example.hs index b45130c..26f7ab8 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -3,6 +3,7 @@ module Example where import Ast import Control.Exception (catch, evaluate, SomeException, displayException) import Control.Exception.Base +import Typecheck -- Example classes and their methods and fields sampleClasses :: [Class] diff --git a/src/Main.hs b/src/Main.hs index 4062f85..5ee22ff 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -1,9 +1,8 @@ module Main where -import Parser.Lexer import Example +import Typecheck main = do - --print $ alexScanTokens "/**/" Example.runTypeCheck diff --git a/src/Typecheck.hs b/src/Typecheck.hs new file mode 100644 index 0000000..2794204 --- /dev/null +++ b/src/Typecheck.hs @@ -0,0 +1,290 @@ +module Typecheck where +import Data.List (find) +import Ast + +typeCheckCompilationUnit :: CompilationUnit -> CompilationUnit +typeCheckCompilationUnit classes = map (`typeCheckClass` classes) classes + +typeCheckClass :: Class -> [Class] -> Class +typeCheckClass (Class className methods fields) classes = + let + -- Create a symbol table from class fields + classFields = [(id, dt) | VariableDeclaration dt id _ <- fields] + checkedMethods = map (\method -> typeCheckMethodDeclaration method classFields classes) methods + in Class className checkedMethods fields + +typeCheckMethodDeclaration :: MethodDeclaration -> [(Identifier, DataType)] -> [Class] -> MethodDeclaration +typeCheckMethodDeclaration (MethodDeclaration retType name params body) classFields classes = + let + -- Combine class fields with method parameters to form the initial symbol table for the method + methodParams = [(identifier, dataType) | ParameterDeclaration dataType identifier <- params] + -- Ensure method parameters shadow class fields if names collide + initialSymtab = classFields ++ methodParams + -- Type check the body of the method using the combined symbol table + checkedBody = typeCheckStatement body initialSymtab classes + bodyType = getTypeFromStmt checkedBody + -- Check if the type of the body matches the declared return type + in if bodyType == retType || (bodyType == "void" && retType == "void") + then MethodDeclaration retType name params checkedBody + else error $ "Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType + +-- ********************************** Type Checking: Expressions ********************************** + +typeCheckExpression :: Expression -> [(Identifier, DataType)] -> [Class] -> Expression +typeCheckExpression (IntegerLiteral i) _ _ = TypedExpression "int" (IntegerLiteral i) +typeCheckExpression (CharacterLiteral c) _ _ = TypedExpression "char" (CharacterLiteral c) +typeCheckExpression (BooleanLiteral b) _ _ = TypedExpression "boolean" (BooleanLiteral b) +typeCheckExpression NullLiteral _ _ = TypedExpression "null" NullLiteral +typeCheckExpression (Reference id) symtab classes = + let type' = lookupType id symtab + in TypedExpression type' (Reference id) +typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes = + let expr1' = typeCheckExpression expr1 symtab classes + expr2' = typeCheckExpression expr2 symtab classes + type1 = getTypeFromExpr expr1' + type2 = getTypeFromExpr expr2' + in case op of + Addition -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Addition operation requires two operands of type int" + Subtraction -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Subtraction operation requires two operands of type int" + Multiplication -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Multiplication operation requires two operands of type int" + Division -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Division operation requires two operands of type int" + BitwiseAnd -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Bitwise AND operation requires two operands of type int" + BitwiseOr -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Bitwise OR operation requires two operands of type int" + BitwiseXor -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Bitwise XOR operation requires two operands of type int" + CompareLessThan -> + if type1 == "int" && type2 == "int" + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Less than operation requires two operands of type int" + CompareLessOrEqual -> + if type1 == "int" && type2 == "int" + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Less than or equal operation requires two operands of type int" + CompareGreaterThan -> + if type1 == "int" && type2 == "int" + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Greater than operation requires two operands of type int" + CompareGreaterOrEqual -> + if type1 == "int" && type2 == "int" + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Greater than or equal operation requires two operands of type int" + CompareEqual -> + if type1 == type2 + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Equality operation requires two operands of the same type" + CompareNotEqual -> + if type1 == type2 + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Inequality operation requires two operands of the same type" + And -> + if type1 == "boolean" && type2 == "boolean" + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Logical AND operation requires two operands of type boolean" + Or -> + if type1 == "boolean" && type2 == "boolean" + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Logical OR operation requires two operands of type boolean" + NameResolution -> + case (expr1', expr2) of + (TypedExpression t1 (Reference obj), Reference member) -> + -- Lookup the class type of obj from the symbol table + let objectType = lookupType obj symtab + classDetails = find (\(Class className _ _) -> className == objectType) classes + in case classDetails of + Just (Class _ methods fields) -> + -- Check both fields and methods to find a match for member + let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == member] + methodTypes = [dt | MethodDeclaration dt id _ _ <- methods, id == member] + in case fieldTypes ++ methodTypes of + [resolvedType] -> TypedExpression resolvedType (BinaryOperation op expr1' (TypedExpression resolvedType (Reference member))) + [] -> error $ "Member '" ++ member ++ "' not found in class '" ++ objectType ++ "'" + _ -> error $ "Ambiguous reference to '" ++ member ++ "' in class '" ++ objectType ++ "'" + Nothing -> error $ "Object '" ++ obj ++ "' does not correspond to a known class" + _ -> error "Name resolution requires object reference and member name" + +typeCheckExpression (UnaryOperation op expr) symtab classes = + let expr' = typeCheckExpression expr symtab classes + type' = getTypeFromExpr expr' + in case op of + Not -> + if type' == "boolean" + then + TypedExpression "boolean" (UnaryOperation op expr') + else + error "Logical NOT operation requires an operand of type boolean" + Minus -> + if type' == "int" + then + TypedExpression "int" (UnaryOperation op expr') + else + error "Unary minus operation requires an operand of type int" + +typeCheckExpression (StatementExpressionExpression stmtExpr) symtab classes = + let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes + in TypedExpression (getTypeFromStmtExpr stmtExpr') (StatementExpressionExpression stmtExpr') + +-- ********************************** Type Checking: StatementExpressions ********************************** +-- TODO: Implement type checking for StatementExpressions +typeCheckStatementExpression :: StatementExpression -> [(Identifier, DataType)] -> [Class] -> StatementExpression +typeCheckStatementExpression (Assignment id expr) symtab classes = + let expr' = typeCheckExpression expr symtab classes + type' = getTypeFromExpr expr' + type'' = lookupType id symtab + in if type' == type'' + then + TypedStatementExpression type' (Assignment id expr') + else + error "Assignment type mismatch" + +typeCheckStatementExpression (ConstructorCall className args) symtab classes = + let args' = map (\arg -> typeCheckExpression arg symtab classes) args + in TypedStatementExpression className (ConstructorCall className args') + +typeCheckStatementExpression (MethodCall methodName args) symtab classes = + let args' = map (\arg -> typeCheckExpression arg symtab classes) args + in TypedStatementExpression "Object" (MethodCall methodName args') + +-- ********************************** Type Checking: Statements ********************************** + +typeCheckStatement :: Statement -> [(Identifier, DataType)] -> [Class] -> Statement +typeCheckStatement (If cond thenStmt elseStmt) symtab classes = + let cond' = typeCheckExpression cond symtab classes + thenStmt' = typeCheckStatement thenStmt symtab classes + elseStmt' = case elseStmt of + Just stmt -> Just (typeCheckStatement stmt symtab classes) + Nothing -> Nothing + in if getTypeFromExpr cond' == "boolean" + then + TypedStatement (getTypeFromStmt thenStmt') (If cond' thenStmt' elseStmt') + else + error "If condition must be of type boolean" + +typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr)) symtab classes = + -- Check for redefinition in the current scope + if any ((== identifier) . snd) symtab + then error $ "Variable '" ++ identifier ++ "' is redefined in the same scope" + else + -- If there's an initializer expression, type check it + let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr + exprType = fmap getTypeFromExpr checkedExpr + in case exprType of + Just t | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t + _ -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr)) + +typeCheckStatement (While cond stmt) symtab classes = + let cond' = typeCheckExpression cond symtab classes + stmt' = typeCheckStatement stmt symtab classes + in if getTypeFromExpr cond' == "boolean" + then + TypedStatement (getTypeFromStmt stmt') (While cond' stmt') + else + error "While condition must be of type boolean" + +typeCheckStatement (Block statements) symtab classes = + let + processStatements (accSts, currentSymtab, types) stmt = + let + checkedStmt = typeCheckStatement stmt currentSymtab classes + stmtType = getTypeFromStmt checkedStmt + in case stmt of + LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr) -> + let + checkedExpr = fmap (\expr -> typeCheckExpression expr currentSymtab classes) maybeExpr + newSymtab = (identifier, dataType) : currentSymtab + in (accSts ++ [checkedStmt], newSymtab, types) + + If {} -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types) + While _ _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types) + Return _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types) + _ -> (accSts ++ [checkedStmt], currentSymtab, types) + + -- Initial accumulator: empty statements list, initial symbol table, empty types list + (checkedStatements, finalSymtab, collectedTypes) = foldl processStatements ([], symtab, []) statements + + -- Determine the block's type: unify all collected types, default to "Void" if none + blockType = if null collectedTypes then "void" else foldl1 unifyReturnTypes collectedTypes + + in TypedStatement blockType (Block checkedStatements) + +typeCheckStatement (Return expr) symtab classes = + let expr' = case expr of + Just e -> Just (typeCheckExpression e symtab classes) + Nothing -> Nothing + in case expr' of + Just e' -> TypedStatement (getTypeFromExpr e') (Return (Just e')) + Nothing -> TypedStatement "Void" (Return Nothing) + +-- ********************************** Type Checking: Helpers ********************************** + +getTypeFromExpr :: Expression -> DataType +getTypeFromExpr (TypedExpression t _) = t +getTypeFromExpr _ = error "Untyped expression found where typed was expected" + +getTypeFromStmt :: Statement -> DataType +getTypeFromStmt (TypedStatement t _) = t +getTypeFromStmt _ = error "Untyped statement found where typed was expected" + +getTypeFromStmtExpr :: StatementExpression -> DataType +getTypeFromStmtExpr (TypedStatementExpression t _) = t +getTypeFromStmtExpr _ = error "Untyped statement expression found where typed was expected" + +unifyReturnTypes :: DataType -> DataType -> DataType +unifyReturnTypes dt1 dt2 + | dt1 == dt2 = dt1 + | otherwise = "Object" + +lookupType :: Identifier -> [(Identifier, DataType)] -> DataType +lookupType id symtab = + case lookup id symtab of + Just t -> t + Nothing -> error ("Identifier " ++ id ++ " not found in symbol table") \ No newline at end of file