Compare commits

..

10 Commits

3 changed files with 253 additions and 54 deletions

View File

@ -22,7 +22,7 @@ data Statement
data StatementExpression data StatementExpression
= Assignment Identifier Expression = Assignment Identifier Expression
| ConstructorCall DataType [Expression] | ConstructorCall DataType [Expression]
| MethodCall Identifier [Expression] | MethodCall Expression Identifier [Expression]
| TypedStatementExpression DataType StatementExpression | TypedStatementExpression DataType StatementExpression
deriving (Show, Eq) deriving (Show, Eq)
@ -56,6 +56,8 @@ data Expression
| BooleanLiteral Bool | BooleanLiteral Bool
| NullLiteral | NullLiteral
| Reference Identifier | Reference Identifier
| LocalVariable Identifier
| FieldVariable Identifier
| BinaryOperation BinaryOperator Expression Expression | BinaryOperation BinaryOperator Expression Expression
| UnaryOperation UnaryOperator Expression | UnaryOperation UnaryOperator Expression
| StatementExpressionExpression StatementExpression | StatementExpressionExpression StatementExpression

View File

@ -1,50 +1,203 @@
module Example where module Example where
import Ast import Ast
import Typecheck
import Control.Exception (catch, evaluate, SomeException, displayException) import Control.Exception (catch, evaluate, SomeException, displayException)
import Control.Exception.Base import Control.Exception.Base
import Typecheck import System.IO (stderr, hPutStrLn)
import Data.Maybe
import Data.List
green, red, yellow, blue, magenta, cyan, white :: String -> String
green str = "\x1b[32m" ++ str ++ "\x1b[0m"
red str = "\x1b[31m" ++ str ++ "\x1b[0m"
yellow str = "\x1b[33m" ++ str ++ "\x1b[0m"
blue str = "\x1b[34m" ++ str ++ "\x1b[0m"
magenta str = "\x1b[35m" ++ str ++ "\x1b[0m"
cyan str = "\x1b[36m" ++ str ++ "\x1b[0m"
white str = "\x1b[37m" ++ str ++ "\x1b[0m"
printSuccess :: String -> IO ()
printSuccess msg = putStrLn $ green "Success:" ++ white msg
handleError :: SomeException -> IO ()
handleError e = hPutStrLn stderr $ red ("Error: " ++ displayException e)
printResult :: Show a => String -> a -> IO ()
printResult title result = do
putStrLn $ green title
print result
-- Example classes and their methods and fields
sampleClasses :: [Class] sampleClasses :: [Class]
sampleClasses = [ sampleClasses = [
Class "Person" [ Class "Person" [
MethodDeclaration "void" "setAge" [ParameterDeclaration "Int" "newAge"] MethodDeclaration "void" "setAge" [ParameterDeclaration "int" "newAge"]
(Block [ (Block [
LocalVariableDeclaration (VariableDeclaration "Int" "age" (Just (Reference "newAge"))) LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (Reference "newAge")))
]), ]),
MethodDeclaration "Int" "getAge" [] (Return (Just (Reference "age"))) MethodDeclaration "int" "getAge" [] (Return (Just (Reference "age"))),
MethodDeclaration "Person" "Person" [ParameterDeclaration "int" "initialAge"] (Block [])
] [ ] [
VariableDeclaration "Int" "age" (Just (IntegerLiteral 25)), VariableDeclaration "int" "age" (Just (IntegerLiteral 25))
VariableDeclaration "String" "name" (Just (CharacterLiteral 'A'))
] ]
] ]
-- Symbol table, mapping identifiers to their data types
initialSymtab :: [(DataType, Identifier)] initialSymtab :: [(DataType, Identifier)]
initialSymtab = [] initialSymtab = []
-- An example block of statements to type check
exampleBlock :: Statement
exampleBlock = Block [
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [])))),
StatementExpressionStatement (MethodCall "setAge" [IntegerLiteral 30]),
Return (Just (StatementExpressionExpression (MethodCall "getAge" [])))
]
exampleExpression :: Expression exampleExpression :: Expression
exampleExpression = BinaryOperation NameResolution (Reference "bob") (Reference "age") exampleExpression = BinaryOperation NameResolution (Reference "bob") (Reference "age")
-- Function to perform type checking and handle errors exampleAssignment :: Expression
exampleAssignment = StatementExpressionExpression (Assignment "a" (IntegerLiteral 30))
exampleMethodCall :: Statement
exampleMethodCall = StatementExpressionStatement (MethodCall (Reference "this") "setAge" [IntegerLiteral 30])
exampleConstructorCall :: Statement
exampleConstructorCall = LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30]))))
exampleNameResolution :: Expression
exampleNameResolution = BinaryOperation NameResolution (Reference "b") (Reference "age")
exampleBlockResolution :: Statement
exampleBlockResolution = Block [
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))),
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30])
]
exampleBlockResolutionFail :: Statement
exampleBlockResolutionFail = Block [
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))),
LocalVariableDeclaration (VariableDeclaration "bool" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30])
]
exampleMethodCallAndAssignment :: Statement
exampleMethodCallAndAssignment = Block [
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))),
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]),
LocalVariableDeclaration (VariableDeclaration "int" "a" Nothing),
StatementExpressionStatement (Assignment "a" (Reference "age"))
]
exampleMethodCallAndAssignmentFail :: Statement
exampleMethodCallAndAssignmentFail = Block [
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))),
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]),
StatementExpressionStatement (Assignment "a" (Reference "age"))
]
testClasses :: [Class]
testClasses = [
Class "Person" [
MethodDeclaration "Person" "Person" [ParameterDeclaration "int" "initialAge"]
(Block [
Return (Just (Reference "this"))
]),
MethodDeclaration "void" "setAge" [ParameterDeclaration "int" "newAge"]
(Block [
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (Reference "newAge")))
]),
MethodDeclaration "int" "getAge" []
(Return (Just (Reference "age")))
] [
VariableDeclaration "int" "age" Nothing -- initially unassigned
],
Class "Main" [
MethodDeclaration "int" "main" []
(Block [
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 25])))),
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]),
LocalVariableDeclaration (VariableDeclaration "int" "bobAge" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
Return (Just (Reference "bobAge"))
])
] []
]
runTypeCheck :: IO () runTypeCheck :: IO ()
runTypeCheck = do runTypeCheck = do
-- Evaluate the block of statements catch (do
--evaluatedBlock <- evaluate (typeCheckStatement exampleBlock initialSymtab sampleClasses) print "====================================================================================="
--putStrLn "Type checking of block completed successfully:" evaluatedExpression <- evaluate (typeCheckExpression exampleExpression [("bob", "Person")] sampleClasses)
--print evaluatedBlock printSuccess "Type checking of expression completed successfully"
printResult "Result Expression:" evaluatedExpression
) handleError
-- Evaluate the expression catch (do
evaluatedExpression <- evaluate (typeCheckExpression exampleExpression [("bob", "Person"), ("age", "int")] sampleClasses) print "====================================================================================="
putStrLn "Type checking of expression completed successfully:" evaluatedAssignment <- evaluate (typeCheckExpression exampleAssignment [("a", "int")] sampleClasses)
print evaluatedExpression printSuccess "Type checking of assignment completed successfully"
printResult "Result Assignment:" evaluatedAssignment
) handleError
catch (do
print "====================================================================================="
evaluatedMethodCall <- evaluate (typeCheckStatement exampleMethodCall [("this", "Person"), ("setAge", "Person"), ("getAge", "Person")] sampleClasses)
printSuccess "Type checking of method call this completed successfully"
printResult "Result MethodCall:" evaluatedMethodCall
) handleError
catch (do
print "====================================================================================="
evaluatedConstructorCall <- evaluate (typeCheckStatement exampleConstructorCall [] sampleClasses)
printSuccess "Type checking of constructor call completed successfully"
printResult "Result Constructor Call:" evaluatedConstructorCall
) handleError
catch (do
print "====================================================================================="
evaluatedNameResolution <- evaluate (typeCheckExpression exampleNameResolution [("b", "Person")] sampleClasses)
printSuccess "Type checking of name resolution completed successfully"
printResult "Result Name Resolution:" evaluatedNameResolution
) handleError
catch (do
print "====================================================================================="
evaluatedBlockResolution <- evaluate (typeCheckStatement exampleBlockResolution [] sampleClasses)
printSuccess "Type checking of block resolution completed successfully"
printResult "Result Block Resolution:" evaluatedBlockResolution
) handleError
catch (do
print "====================================================================================="
evaluatedBlockResolutionFail <- evaluate (typeCheckStatement exampleBlockResolutionFail [] sampleClasses)
printSuccess "Type checking of block resolution failed"
printResult "Result Block Resolution:" evaluatedBlockResolutionFail
) handleError
catch (do
print "====================================================================================="
evaluatedMethodCallAndAssignment <- evaluate (typeCheckStatement exampleMethodCallAndAssignment [] sampleClasses)
printSuccess "Type checking of method call and assignment completed successfully"
printResult "Result Method Call and Assignment:" evaluatedMethodCallAndAssignment
) handleError
catch (do
print "====================================================================================="
evaluatedMethodCallAndAssignmentFail <- evaluate (typeCheckStatement exampleMethodCallAndAssignmentFail [] sampleClasses)
printSuccess "Type checking of method call and assignment failed"
printResult "Result Method Call and Assignment:" evaluatedMethodCallAndAssignmentFail
) handleError
catch (do
print "====================================================================================="
let mainClass = fromJust $ find (\(Class className _ _) -> className == "Main") testClasses
case mainClass of
Class _ [mainMethod] _ -> do
let result = typeCheckMethodDeclaration mainMethod [] testClasses
printSuccess "Full program type checking completed successfully."
printResult "Main method result:" result
) handleError
catch (do
print "====================================================================================="
let typedProgram = typeCheckCompilationUnit testClasses
printSuccess "Type checking of Program completed successfully"
printResult "Typed Program:" typedProgram
) handleError

View File

@ -1,5 +1,6 @@
module Typecheck where module Typecheck where
import Data.List (find) import Data.List (find)
import Data.Maybe
import Ast import Ast
typeCheckCompilationUnit :: CompilationUnit -> CompilationUnit typeCheckCompilationUnit :: CompilationUnit -> CompilationUnit
@ -8,9 +9,11 @@ typeCheckCompilationUnit classes = map (`typeCheckClass` classes) classes
typeCheckClass :: Class -> [Class] -> Class typeCheckClass :: Class -> [Class] -> Class
typeCheckClass (Class className methods fields) classes = typeCheckClass (Class className methods fields) classes =
let let
-- Create a symbol table from class fields -- Create a symbol table from class fields and method entries
classFields = [(id, dt) | VariableDeclaration dt id _ <- fields] classFields = [(id, dt) | VariableDeclaration dt id _ <- fields]
checkedMethods = map (\method -> typeCheckMethodDeclaration method classFields classes) methods methodEntries = [(methodName, className) | MethodDeclaration _ methodName _ _ <- methods]
initalSymTab = ("this", className) : classFields ++ methodEntries
checkedMethods = map (\method -> typeCheckMethodDeclaration method initalSymTab classes) methods
in Class className checkedMethods fields in Class className checkedMethods fields
typeCheckMethodDeclaration :: MethodDeclaration -> [(Identifier, DataType)] -> [Class] -> MethodDeclaration typeCheckMethodDeclaration :: MethodDeclaration -> [(Identifier, DataType)] -> [Class] -> MethodDeclaration
@ -18,9 +21,7 @@ typeCheckMethodDeclaration (MethodDeclaration retType name params body) classFie
let let
-- Combine class fields with method parameters to form the initial symbol table for the method -- Combine class fields with method parameters to form the initial symbol table for the method
methodParams = [(identifier, dataType) | ParameterDeclaration dataType identifier <- params] methodParams = [(identifier, dataType) | ParameterDeclaration dataType identifier <- params]
-- Ensure method parameters shadow class fields if names collide
initialSymtab = classFields ++ methodParams initialSymtab = classFields ++ methodParams
-- Type check the body of the method using the combined symbol table
checkedBody = typeCheckStatement body initialSymtab classes checkedBody = typeCheckStatement body initialSymtab classes
bodyType = getTypeFromStmt checkedBody bodyType = getTypeFromStmt checkedBody
-- Check if the type of the body matches the declared return type -- Check if the type of the body matches the declared return type
@ -136,21 +137,18 @@ typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes =
error "Logical OR operation requires two operands of type boolean" error "Logical OR operation requires two operands of type boolean"
NameResolution -> NameResolution ->
case (expr1', expr2) of case (expr1', expr2) of
(TypedExpression t1 (Reference obj), Reference member) -> (TypedExpression t1 (Reference obj), Reference member) ->
-- Lookup the class type of obj from the symbol table let objectType = lookupType obj symtab
let objectType = lookupType obj symtab classDetails = find (\(Class className _ _) -> className == objectType) classes
classDetails = find (\(Class className _ _) -> className == objectType) classes in case classDetails of
in case classDetails of Just (Class _ _ fields) ->
Just (Class _ methods fields) -> let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == member]
-- Check both fields and methods to find a match for member in case fieldTypes of
let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == member] [resolvedType] -> TypedExpression resolvedType (BinaryOperation NameResolution expr1' (TypedExpression resolvedType expr2))
methodTypes = [dt | MethodDeclaration dt id _ _ <- methods, id == member] [] -> error $ "Field '" ++ member ++ "' not found in class '" ++ objectType ++ "'"
in case fieldTypes ++ methodTypes of _ -> error $ "Ambiguous reference to field '" ++ member ++ "' in class '" ++ objectType ++ "'"
[resolvedType] -> TypedExpression resolvedType (BinaryOperation op expr1' (TypedExpression resolvedType (Reference member))) Nothing -> error $ "Object '" ++ obj ++ "' does not correspond to a known class"
[] -> error $ "Member '" ++ member ++ "' not found in class '" ++ objectType ++ "'" _ -> error "Name resolution requires object reference and field name"
_ -> error $ "Ambiguous reference to '" ++ member ++ "' in class '" ++ objectType ++ "'"
Nothing -> error $ "Object '" ++ obj ++ "' does not correspond to a known class"
_ -> error "Name resolution requires object reference and member name"
typeCheckExpression (UnaryOperation op expr) symtab classes = typeCheckExpression (UnaryOperation op expr) symtab classes =
let expr' = typeCheckExpression expr symtab classes let expr' = typeCheckExpression expr symtab classes
@ -174,7 +172,7 @@ typeCheckExpression (StatementExpressionExpression stmtExpr) symtab classes =
in TypedExpression (getTypeFromStmtExpr stmtExpr') (StatementExpressionExpression stmtExpr') in TypedExpression (getTypeFromStmtExpr stmtExpr') (StatementExpressionExpression stmtExpr')
-- ********************************** Type Checking: StatementExpressions ********************************** -- ********************************** Type Checking: StatementExpressions **********************************
-- TODO: Implement type checking for StatementExpressions
typeCheckStatementExpression :: StatementExpression -> [(Identifier, DataType)] -> [Class] -> StatementExpression typeCheckStatementExpression :: StatementExpression -> [(Identifier, DataType)] -> [Class] -> StatementExpression
typeCheckStatementExpression (Assignment id expr) symtab classes = typeCheckStatementExpression (Assignment id expr) symtab classes =
let expr' = typeCheckExpression expr symtab classes let expr' = typeCheckExpression expr symtab classes
@ -187,12 +185,54 @@ typeCheckStatementExpression (Assignment id expr) symtab classes =
error "Assignment type mismatch" error "Assignment type mismatch"
typeCheckStatementExpression (ConstructorCall className args) symtab classes = typeCheckStatementExpression (ConstructorCall className args) symtab classes =
let args' = map (\arg -> typeCheckExpression arg symtab classes) args case find (\(Class name _ _) -> name == className) classes of
in TypedStatementExpression className (ConstructorCall className args') Nothing -> error $ "Class '" ++ className ++ "' not found."
Just (Class _ methods fields) ->
-- Constructor needs the same name as the class
case find (\(MethodDeclaration retType name params _) -> name == className && retType == className) methods of
Nothing -> error $ "No valid constructor found for class '" ++ className ++ "'."
Just (MethodDeclaration _ _ params _) ->
let
args' = map (\arg -> typeCheckExpression arg symtab classes) args
-- Extract expected parameter types from the constructor's parameters
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
argTypes = map getTypeFromExpr args'
-- Check if the types of the provided arguments match the expected types
typeMatches = zipWith (\expected actual -> if expected == actual then Nothing else Just (expected, actual)) expectedTypes argTypes
mismatchErrors = map (\(exp, act) -> "Expected type '" ++ exp ++ "', found '" ++ act ++ "'.") (catMaybes typeMatches)
in
if length args /= length params then
error $ "Constructor for class '" ++ className ++ "' expects " ++ show (length params) ++ " arguments, but got " ++ show (length args) ++ "."
else if not (null mismatchErrors) then
error $ unlines $ ("Type mismatch in constructor arguments for class '" ++ className ++ "':") : mismatchErrors
else
TypedStatementExpression className (ConstructorCall className args')
typeCheckStatementExpression (MethodCall methodName args) symtab classes = typeCheckStatementExpression (MethodCall expr methodName args) symtab classes =
let args' = map (\arg -> typeCheckExpression arg symtab classes) args let objExprTyped = typeCheckExpression expr symtab classes
in TypedStatementExpression "Object" (MethodCall methodName args') in case objExprTyped of
TypedExpression objType _ ->
case find (\(Class className _ _) -> className == objType) classes of
Just (Class _ methods _) ->
case find (\(MethodDeclaration retType name params _) -> name == methodName) methods of
Just (MethodDeclaration retType _ params _) ->
let args' = map (\arg -> typeCheckExpression arg symtab classes) args
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
argTypes = map getTypeFromExpr args'
typeMatches = zipWith (\expType argType -> (expType == argType, expType, argType)) expectedTypes argTypes
mismatches = filter (not . fst3) typeMatches
where fst3 (a, _, _) = a
in
if null mismatches && length args == length params then
TypedStatementExpression retType (MethodCall objExprTyped methodName args')
else if not (null mismatches) then
error $ unlines $ ("Argument type mismatches for method '" ++ methodName ++ "':")
: [ "Expected: " ++ expType ++ ", Found: " ++ argType | (_, expType, argType) <- mismatches ]
else
error $ "Incorrect number of arguments for method '" ++ methodName ++ "'. Expected " ++ show (length expectedTypes) ++ ", found " ++ show (length args) ++ "."
Nothing -> error $ "Method '" ++ methodName ++ "' not found in class '" ++ objType ++ "'."
Nothing -> error $ "Class for object type '" ++ objType ++ "' not found."
_ -> error "Invalid object type for method call. Object must have a class type."
-- ********************************** Type Checking: Statements ********************************** -- ********************************** Type Checking: Statements **********************************
@ -264,6 +304,10 @@ typeCheckStatement (Return expr) symtab classes =
Just e' -> TypedStatement (getTypeFromExpr e') (Return (Just e')) Just e' -> TypedStatement (getTypeFromExpr e') (Return (Just e'))
Nothing -> TypedStatement "Void" (Return Nothing) Nothing -> TypedStatement "Void" (Return Nothing)
typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes =
let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes
in TypedStatement (getTypeFromStmtExpr stmtExpr') (StatementExpressionStatement stmtExpr')
-- ********************************** Type Checking: Helpers ********************************** -- ********************************** Type Checking: Helpers **********************************
getTypeFromExpr :: Expression -> DataType getTypeFromExpr :: Expression -> DataType