diff --git a/src/Typecheck.hs b/src/Typecheck.hs index fd58f17..e4aab3b 100644 --- a/src/Typecheck.hs +++ b/src/Typecheck.hs @@ -19,15 +19,14 @@ typeCheckClass (Class className methods fields) classes = typeCheckMethodDeclaration :: MethodDeclaration -> [(Identifier, DataType)] -> [Class] -> MethodDeclaration typeCheckMethodDeclaration (MethodDeclaration retType name params body) classFields classes = let - -- Combine class fields with method parameters to form the initial symbol table for the method methodParams = [(identifier, dataType) | ParameterDeclaration dataType identifier <- params] - initialSymtab = classFields ++ methodParams + initialSymtab = ("thisMeth", retType) : classFields ++ methodParams checkedBody = typeCheckStatement body initialSymtab classes bodyType = getTypeFromStmt checkedBody - -- Check if the type of the body matches the declared return type - in if bodyType == retType || (bodyType == "void" && retType == "void") || (bodyType == "null" && isObjectType retType) + in if bodyType == retType || (bodyType == "void" && retType == "void") || (bodyType == "null" && isObjectType retType) || isSubtype bodyType retType classes then MethodDeclaration retType name params checkedBody - else error $ "Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType + else error $ "Method Declaration: Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType + -- ********************************** Type Checking: Expressions ********************************** @@ -119,7 +118,7 @@ typeCheckStatementExpression (ConstructorCall className args) symtab classes = Nothing -> error $ "Class '" ++ className ++ "' not found." Just (Class _ methods fields) -> -- Constructor needs the same name as the class - case find (\(MethodDeclaration retType name params _) -> name == className && retType == className) methods of + case find (\(MethodDeclaration retType name params _) -> name == "" && retType == "void") methods of Nothing -> error $ "No valid constructor found for class '" ++ className ++ "'." Just (MethodDeclaration _ _ params _) -> let @@ -204,19 +203,21 @@ typeCheckStatementExpression (PreDecrement expr) symtab classes = typeCheckStatement :: Statement -> [(Identifier, DataType)] -> [Class] -> Statement typeCheckStatement (If cond thenStmt elseStmt) symtab classes = - let cond' = typeCheckExpression cond symtab classes - thenStmt' = typeCheckStatement thenStmt symtab classes - elseStmt' = case elseStmt of - Just stmt -> Just (typeCheckStatement stmt symtab classes) - Nothing -> Nothing - thenType = getTypeFromStmt thenStmt' - elseType = maybe "void" getTypeFromStmt elseStmt' - ifType = if thenType /= "void" && elseType /= "void" && thenType == elseType then thenType else "void" - in if getTypeFromExpr cond' == "boolean" - then - TypedStatement ifType (If cond' thenStmt' elseStmt') - else - error "If condition must be of type boolean" + let + cond' = typeCheckExpression cond symtab classes + thenStmt' = typeCheckStatement thenStmt symtab classes + elseStmt' = fmap (\stmt -> typeCheckStatement stmt symtab classes) elseStmt + + thenType = getTypeFromStmt thenStmt' + elseType = maybe "void" getTypeFromStmt elseStmt' + + ifType = if thenType == "void" || elseType == "void" + then "void" + else unifyReturnTypes thenType elseType + + in if getTypeFromExpr cond' == "boolean" + then TypedStatement ifType (If cond' thenStmt' elseStmt') + else error "If condition must be of type boolean" typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr)) symtab classes = @@ -229,7 +230,7 @@ typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType ident exprType = fmap getTypeFromExpr checkedExpr in case exprType of Just t - | t == "null" && isObjectType dataType -> + | t == "null" && isObjectType dataType -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr)) | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t | otherwise -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr)) @@ -272,12 +273,14 @@ typeCheckStatement (Block statements) symtab classes = in TypedStatement blockType (Block checkedStatements) typeCheckStatement (Return expr) symtab classes = - let expr' = case expr of + let methodReturnType = fromMaybe (error "Method return type not found in symbol table") (lookup "thisMeth" symtab) + expr' = case expr of Just e -> Just (typeCheckExpression e symtab classes) Nothing -> Nothing - in case expr' of - Just e' -> TypedStatement (getTypeFromExpr e') (Return (Just e')) - Nothing -> TypedStatement "void" (Return Nothing) + returnType = maybe "void" getTypeFromExpr expr' + in if returnType == methodReturnType || isSubtype returnType methodReturnType classes + then TypedStatement returnType (Return expr') + else error $ "Return: Return type mismatch: expected " ++ methodReturnType ++ ", found " ++ returnType typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes = let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes @@ -285,6 +288,17 @@ typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes = -- ********************************** Type Checking: Helpers ********************************** +isSubtype :: DataType -> DataType -> [Class] -> Bool +isSubtype subType superType classes + | subType == superType = True + | subType == "null" && isObjectType superType = True + | superType == "Object" && isObjectType subType = True + | superType == "Object" && isUserDefinedClass subType classes = True + | otherwise = False + +isUserDefinedClass :: DataType -> [Class] -> Bool +isUserDefinedClass dt classes = dt `elem` map (\(Class name _ _) -> name) classes + isObjectType :: DataType -> Bool isObjectType dt = dt /= "int" && dt /= "boolean" && dt /= "char" @@ -302,8 +316,10 @@ getTypeFromStmtExpr _ = error "Untyped statement expression found where typed wa unifyReturnTypes :: DataType -> DataType -> DataType unifyReturnTypes dt1 dt2 - | dt1 == dt2 = dt1 - | otherwise = "Object" + | dt1 == dt2 = dt1 + | dt1 == "null" = dt2 + | dt2 == "null" = dt1 + | otherwise = "Object" resolveResultType :: DataType -> DataType -> DataType resolveResultType "char" "char" = "char"