Add initial typechecker for AST #2

Merged
mrab merged 121 commits from typedAST into master 2024-06-14 07:53:30 +00:00
5 changed files with 149 additions and 75 deletions
Showing only changes of commit 2ac3b60e79 - Show all commits

View File

@ -144,6 +144,21 @@ testExpressionPreIncrement = TestCase $
testExpressionPreDecrement = TestCase $ testExpressionPreDecrement = TestCase $
assertEqual "expect PreIncrement" (UnaryOperation PreDecrement (Reference "a")) $ assertEqual "expect PreIncrement" (UnaryOperation PreDecrement (Reference "a")) $
parseExpression [DECREMENT,IDENTIFIER "a"] parseExpression [DECREMENT,IDENTIFIER "a"]
testExpressionAssign = TestCase $
assertEqual "expect assign 5 to a" (StatementExpressionExpression (Assignment (Reference "a") (IntegerLiteral 5))) $
parseExpression [IDENTIFIER "a",ASSIGN,INTEGERLITERAL 5]
testExpressionTimesEqual = TestCase $
assertEqual "expect assign and multiplication" (StatementExpressionExpression (Assignment (Reference "a") (BinaryOperation Multiplication (Reference "a") (IntegerLiteral 5)))) $
parseExpression [IDENTIFIER "a",TIMESEQUAL,INTEGERLITERAL 5]
testExpressionDivideEqual = TestCase $
assertEqual "expect assign and division" (StatementExpressionExpression (Assignment (Reference "a") (BinaryOperation Division (Reference "a") (IntegerLiteral 5)))) $
parseExpression [IDENTIFIER "a",DIVEQUAL,INTEGERLITERAL 5]
testExpressionPlusEqual = TestCase $
assertEqual "expect assign and addition" (StatementExpressionExpression (Assignment (Reference "a") (BinaryOperation Addition (Reference "a") (IntegerLiteral 5)))) $
parseExpression [IDENTIFIER "a",PLUSEQUAL,INTEGERLITERAL 5]
testExpressionMinusEqual = TestCase $
assertEqual "expect assign and subtraction" (StatementExpressionExpression (Assignment (Reference "a") (BinaryOperation Subtraction (Reference "a") (IntegerLiteral 5)))) $
parseExpression [IDENTIFIER "a",MINUSEQUAL,INTEGERLITERAL 5]
testStatementIfThen = TestCase $ testStatementIfThen = TestCase $
assertEqual "expect empty ifthen" [If (Reference "a") (Block [Block []]) Nothing] $ assertEqual "expect empty ifthen" [If (Reference "a") (Block [Block []]) Nothing] $
@ -151,6 +166,10 @@ testStatementIfThen = TestCase $
testStatementIfThenElse = TestCase $ testStatementIfThenElse = TestCase $
assertEqual "expect empty ifthen" [If (Reference "a") (Block [Block []]) (Just (Block [Block []]))] $ assertEqual "expect empty ifthen" [If (Reference "a") (Block [Block []]) (Just (Block [Block []]))] $
parseStatement [IF,LBRACE,IDENTIFIER "a",RBRACE,LBRACKET,RBRACKET,ELSE,LBRACKET,RBRACKET] parseStatement [IF,LBRACE,IDENTIFIER "a",RBRACE,LBRACKET,RBRACKET,ELSE,LBRACKET,RBRACKET]
testStatementWhile = TestCase $
assertEqual "expect while" [While (Reference "a") (Block [Block []])] $
parseStatement [WHILE,LBRACE,IDENTIFIER "a",RBRACE,LBRACKET,RBRACKET]
tests = TestList [ tests = TestList [
@ -193,6 +212,13 @@ tests = TestList [
testExpressionPostDecrement, testExpressionPostDecrement,
testExpressionPreIncrement, testExpressionPreIncrement,
testExpressionPreDecrement, testExpressionPreDecrement,
testExpressionAssign,
testExpressionTimesEqual,
testExpressionTimesEqual,
testExpressionDivideEqual,
testExpressionPlusEqual,
testExpressionMinusEqual,
testStatementIfThen, testStatementIfThen,
testStatementIfThenElse testStatementIfThenElse,
testStatementWhile
] ]

View File

@ -24,6 +24,10 @@ data StatementExpression
| ConstructorCall DataType [Expression] | ConstructorCall DataType [Expression]
| MethodCall Expression Identifier [Expression] | MethodCall Expression Identifier [Expression]
| TypedStatementExpression DataType StatementExpression | TypedStatementExpression DataType StatementExpression
| PostIncrement Expression
| PostDecrement Expression
| PreIncrement Expression
| PreDecrement Expression
deriving (Show, Eq) deriving (Show, Eq)
data BinaryOperator data BinaryOperator
@ -49,10 +53,6 @@ data BinaryOperator
data UnaryOperator data UnaryOperator
= Not = Not
| Minus | Minus
| PostIncrement
| PostDecrement
| PreIncrement
| PreDecrement
deriving (Show, Eq) deriving (Show, Eq)
data Expression data Expression

View File

@ -101,24 +101,36 @@ exampleNameResolutionAssignment = Block [
exampleCharIntOperation :: Expression exampleCharIntOperation :: Expression
exampleCharIntOperation = BinaryOperation Addition (CharacterLiteral 'a') (IntegerLiteral 1) exampleCharIntOperation = BinaryOperation Addition (CharacterLiteral 'a') (IntegerLiteral 1)
exampleNullDeclaration :: Statement
exampleNullDeclaration = LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just NullLiteral))
exampleNullDeclarationFail :: Statement
exampleNullDeclarationFail = LocalVariableDeclaration (VariableDeclaration "int" "a" (Just NullLiteral))
exampleNullAssignment :: Statement
exampleNullAssignment = StatementExpressionStatement (Assignment (Reference "a") NullLiteral)
exampleIncrement :: Statement
exampleIncrement = StatementExpressionStatement (PostIncrement (Reference "a"))
testClasses :: [Class] testClasses :: [Class]
testClasses = [ testClasses = [
Class "Person" [ Class "Person" [
MethodDeclaration "Person" "Person" [ParameterDeclaration "int" "initialAge"] MethodDeclaration "Person" "Person" [ParameterDeclaration "int" "initialAge"]
(Block [ (Block [
Return (Just (Reference "this")) Return (Just (Reference "this"))
]), ]),
MethodDeclaration "void" "setAge" [ParameterDeclaration "int" "newAge"] MethodDeclaration "void" "setAge" [ParameterDeclaration "int" "newAge"]
(Block [ (Block [
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (Reference "newAge"))) LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (Reference "newAge")))
]), ]),
MethodDeclaration "int" "getAge" [] MethodDeclaration "int" "getAge" []
(Return (Just (Reference "age"))) (Return (Just (Reference "age")))
] [ ] [
VariableDeclaration "int" "age" Nothing -- initially unassigned VariableDeclaration "int" "age" Nothing -- initially unassigned
], ],
Class "Main" [ Class "Main" [
MethodDeclaration "int" "main" [] MethodDeclaration "int" "main" []
(Block [ (Block [
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 25])))), LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 25])))),
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]), StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]),
@ -211,7 +223,7 @@ runTypeCheck = do
printSuccess "Type checking of Program completed successfully" printSuccess "Type checking of Program completed successfully"
printResult "Typed Program:" typedProgram printResult "Typed Program:" typedProgram
) handleError ) handleError
catch (do catch (do
print "=====================================================================================" print "====================================================================================="
typedAssignment <- evaluate (typeCheckStatement exampleNameResolutionAssignment [] sampleClasses) typedAssignment <- evaluate (typeCheckStatement exampleNameResolutionAssignment [] sampleClasses)
@ -226,3 +238,30 @@ runTypeCheck = do
printResult "Result Char Int Operation:" evaluatedCharIntOperation printResult "Result Char Int Operation:" evaluatedCharIntOperation
) handleError ) handleError
catch (do
print "====================================================================================="
evaluatedNullDeclaration <- evaluate (typeCheckStatement exampleNullDeclaration [] sampleClasses)
printSuccess "Type checking of null declaration completed successfully"
printResult "Result Null Declaration:" evaluatedNullDeclaration
) handleError
catch (do
print "====================================================================================="
evaluatedNullDeclarationFail <- evaluate (typeCheckStatement exampleNullDeclarationFail [] sampleClasses)
printSuccess "Type checking of null declaration failed"
printResult "Result Null Declaration:" evaluatedNullDeclarationFail
) handleError
catch (do
print "====================================================================================="
evaluatedNullAssignment <- evaluate (typeCheckStatement exampleNullAssignment [("a", "Person")] sampleClasses)
printSuccess "Type checking of null assignment completed successfully"
printResult "Result Null Assignment:" evaluatedNullAssignment
) handleError
catch (do
print "====================================================================================="
evaluatedIncrement <- evaluate (typeCheckStatement exampleIncrement [("a", "int")] sampleClasses)
printSuccess "Type checking of increment completed successfully"
printResult "Result Increment:" evaluatedIncrement
) handleError

View File

@ -82,7 +82,7 @@ compilationunit : typedeclarations { $1 }
typedeclarations : typedeclaration { [$1] } typedeclarations : typedeclaration { [$1] }
| typedeclarations typedeclaration { $1 ++ [$2] } | typedeclarations typedeclaration { $1 ++ [$2] }
name : simplename { $1 } name : simplename { Reference $1 }
-- | qualifiedname { } -- | qualifiedname { }
typedeclaration : classdeclaration { $1 } typedeclaration : classdeclaration { $1 }
@ -122,7 +122,7 @@ classtype : classorinterfacetype{ }
classbodydeclaration : classmemberdeclaration { $1 } classbodydeclaration : classmemberdeclaration { $1 }
| constructordeclaration { $1 } | constructordeclaration { $1 }
classorinterfacetype : name { $1 } classorinterfacetype : simplename { $1 }
classmemberdeclaration : fielddeclaration { $1 } classmemberdeclaration : fielddeclaration { $1 }
| methoddeclaration { $1 } | methoddeclaration { $1 }
@ -203,7 +203,7 @@ localvariabledeclarationstatement : localvariabledeclaration SEMICOLON { $1 }
statement : statementwithouttrailingsubstatement{ $1 } -- statement returns a list of statements statement : statementwithouttrailingsubstatement{ $1 } -- statement returns a list of statements
| ifthenstatement { [$1] } | ifthenstatement { [$1] }
| ifthenelsestatement { [$1] } | ifthenelsestatement { [$1] }
-- | whilestatement { } | whilestatement { [$1] }
expression : assignmentexpression { $1 } expression : assignmentexpression { $1 }
@ -222,10 +222,10 @@ ifthenstatement : IF LBRACE expression RBRACE statement { If $3 (Block $5) No
ifthenelsestatement : IF LBRACE expression RBRACE statementnoshortif ELSE statement { If $3 (Block $5) (Just (Block $7)) } ifthenelsestatement : IF LBRACE expression RBRACE statementnoshortif ELSE statement { If $3 (Block $5) (Just (Block $7)) }
whilestatement : WHILE LBRACE expression RBRACE statement { } whilestatement : WHILE LBRACE expression RBRACE statement { While $3 (Block $5) }
assignmentexpression : conditionalexpression { $1 } assignmentexpression : conditionalexpression { $1 }
-- | assignment { } | assignment { StatementExpressionExpression $1 }
emptystatement : SEMICOLON { Block [] } emptystatement : SEMICOLON { Block [] }
@ -241,7 +241,11 @@ statementnoshortif : statementwithouttrailingsubstatement { $1 }
conditionalexpression : conditionalorexpression { $1 } conditionalexpression : conditionalorexpression { $1 }
-- | conditionalorexpression QUESMARK expression COLON conditionalexpression { } -- | conditionalorexpression QUESMARK expression COLON conditionalexpression { }
assignment : lefthandside assignmentoperator assignmentexpression { } assignment : lefthandside assignmentoperator assignmentexpression {
case $2 of
Nothing -> Assignment $1 $3
Just operator -> Assignment $1 (BinaryOperation operator $1 $3)
}
statementexpression : assignment { } statementexpression : assignment { }
@ -262,18 +266,18 @@ conditionalorexpression : conditionalandexpression { $1 }
lefthandside : name { $1 } lefthandside : name { $1 }
assignmentoperator : ASSIGN{ } assignmentoperator : ASSIGN { Nothing }
-- | TIMESEQUAL { } | TIMESEQUAL { Just Multiplication }
-- | DIVIDEEQUAL { } | DIVIDEEQUAL { Just Division }
-- | MODULOEQUAL { } | MODULOEQUAL { Just Modulo }
-- | PLUSEQUAL { } | PLUSEQUAL { Just Addition }
-- | MINUSEQUAL { } | MINUSEQUAL { Just Subtraction }
-- | SHIFTLEFTEQUAL { } -- | SHIFTLEFTEQUAL { }
-- | SIGNEDSHIFTRIGHTEQUAL { } -- | SIGNEDSHIFTRIGHTEQUAL { }
-- | UNSIGNEDSHIFTRIGHTEQUAL { } -- | UNSIGNEDSHIFTRIGHTEQUAL { }
-- | ANDEQUAL { } | ANDEQUAL { Just BitwiseAnd }
-- | XOREQUAL { } | XOREQUAL { Just BitwiseXor }
-- | OREQUAL{ } | OREQUAL{ Just BitwiseOr }
preincrementexpression : INCREMENT unaryexpression { UnaryOperation PreIncrement $2 } preincrementexpression : INCREMENT unaryexpression { UnaryOperation PreIncrement $2 }
@ -302,7 +306,7 @@ unaryexpression : unaryexpressionnotplusminus { $1 }
| preincrementexpression { $1 } | preincrementexpression { $1 }
postfixexpression : primary { $1 } postfixexpression : primary { $1 }
| name { Reference $1 } | name { $1 }
| postincrementexpression { $1 } | postincrementexpression { $1 }
| postdecrementexpression{ $1 } | postdecrementexpression{ $1 }

View File

@ -25,7 +25,7 @@ typeCheckMethodDeclaration (MethodDeclaration retType name params body) classFie
checkedBody = typeCheckStatement body initialSymtab classes checkedBody = typeCheckStatement body initialSymtab classes
bodyType = getTypeFromStmt checkedBody bodyType = getTypeFromStmt checkedBody
-- Check if the type of the body matches the declared return type -- Check if the type of the body matches the declared return type
in if bodyType == retType || (bodyType == "void" && retType == "void") in if bodyType == retType || (bodyType == "void" && retType == "void") || (bodyType == "null" && isObjectType retType)
then MethodDeclaration retType name params checkedBody then MethodDeclaration retType name params checkedBody
else error $ "Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType else error $ "Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType
@ -89,50 +89,11 @@ typeCheckExpression (UnaryOperation op expr) symtab classes =
else else
error "Logical NOT operation requires an operand of type boolean" error "Logical NOT operation requires an operand of type boolean"
Minus -> Minus ->
if type' == "int" if type' == "int" || type' == "char"
then then
TypedExpression "int" (UnaryOperation op expr') TypedExpression type' (UnaryOperation op expr')
else if type' == "char"
then
TypedExpression "char" (UnaryOperation op expr')
else else
error "Unary minus operation requires an operand of type int or char" error "Unary minus operation requires an operand of type int or char"
PostIncrement ->
if type' == "int"
then
TypedExpression "int" (UnaryOperation op expr')
else if type' == "char"
then
TypedExpression "char" (UnaryOperation op expr')
else
error "Post-increment operation requires an operand of type int or char"
PostDecrement ->
if type' == "int"
then
TypedExpression "int" (UnaryOperation op expr')
else if type' == "char"
then
TypedExpression "char" (UnaryOperation op expr')
else
error "Post-decrement operation requires an operand of type int or char"
PreIncrement ->
if type' == "int"
then
TypedExpression "int" (UnaryOperation op expr')
else if type' == "char"
then
TypedExpression "char" (UnaryOperation op expr')
else
error "Pre-increment operation requires an operand of type int or char"
PreDecrement ->
if type' == "int"
then
TypedExpression "int" (UnaryOperation op expr')
else if type' == "char"
then
TypedExpression "char" (UnaryOperation op expr')
else
error "Pre-decrement operation requires an operand of type int or char"
typeCheckExpression (StatementExpressionExpression stmtExpr) symtab classes = typeCheckExpression (StatementExpressionExpression stmtExpr) symtab classes =
let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes
@ -146,11 +107,11 @@ typeCheckStatementExpression (Assignment ref expr) symtab classes =
ref' = typeCheckExpression ref symtab classes ref' = typeCheckExpression ref symtab classes
type' = getTypeFromExpr expr' type' = getTypeFromExpr expr'
type'' = getTypeFromExpr ref' type'' = getTypeFromExpr ref'
in in
if type'' == type' then if type'' == type' || (type' == "null" && isObjectType type'') then
TypedStatementExpression type' (Assignment ref' expr') TypedStatementExpression type'' (Assignment ref' expr')
else else
error $ "Type mismatch in assignment to variable: expected " ++ type'' ++ ", found " ++ type' error $ "Type mismatch in assignment to variable: expected " ++ type'' ++ ", found " ++ type'
typeCheckStatementExpression (ConstructorCall className args) symtab classes = typeCheckStatementExpression (ConstructorCall className args) symtab classes =
case find (\(Class name _ _) -> name == className) classes of case find (\(Class name _ _) -> name == className) classes of
@ -202,6 +163,42 @@ typeCheckStatementExpression (MethodCall expr methodName args) symtab classes =
Nothing -> error $ "Class for object type '" ++ objType ++ "' not found." Nothing -> error $ "Class for object type '" ++ objType ++ "' not found."
_ -> error "Invalid object type for method call. Object must have a class type." _ -> error "Invalid object type for method call. Object must have a class type."
typeCheckStatementExpression (PostIncrement expr) symtab classes =
let expr' = typeCheckExpression expr symtab classes
type' = getTypeFromExpr expr'
in if type' == "int" || type' == "char"
then
TypedStatementExpression type' (PostIncrement expr')
else
error "Post-increment operation requires an operand of type int or char"
typeCheckStatementExpression (PostDecrement expr) symtab classes =
let expr' = typeCheckExpression expr symtab classes
type' = getTypeFromExpr expr'
in if type' == "int" || type' == "char"
then
TypedStatementExpression type' (PostDecrement expr')
else
error "Post-decrement operation requires an operand of type int or char"
typeCheckStatementExpression (PreIncrement expr) symtab classes =
let expr' = typeCheckExpression expr symtab classes
type' = getTypeFromExpr expr'
in if type' == "int" || type' == "char"
then
TypedStatementExpression type' (PreIncrement expr')
else
error "Pre-increment operation requires an operand of type int or char"
typeCheckStatementExpression (PreDecrement expr) symtab classes =
let expr' = typeCheckExpression expr symtab classes
type' = getTypeFromExpr expr'
in if type' == "int" || type' == "char"
then
TypedStatementExpression type' (PreDecrement expr')
else
error "Pre-decrement operation requires an operand of type int or char"
-- ********************************** Type Checking: Statements ********************************** -- ********************************** Type Checking: Statements **********************************
typeCheckStatement :: Statement -> [(Identifier, DataType)] -> [Class] -> Statement typeCheckStatement :: Statement -> [(Identifier, DataType)] -> [Class] -> Statement
@ -226,8 +223,12 @@ typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType ident
let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
exprType = fmap getTypeFromExpr checkedExpr exprType = fmap getTypeFromExpr checkedExpr
in case exprType of in case exprType of
Just t | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t Just t
_ -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr)) | t == "null" && isObjectType dataType ->
TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
| t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
| otherwise -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
Nothing -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
typeCheckStatement (While cond stmt) symtab classes = typeCheckStatement (While cond stmt) symtab classes =
let cond' = typeCheckExpression cond symtab classes let cond' = typeCheckExpression cond symtab classes
@ -254,6 +255,7 @@ typeCheckStatement (Block statements) symtab classes =
If {} -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types) If {} -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
While _ _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types) While _ _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
Return _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types) Return _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
Block _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
_ -> (accSts ++ [checkedStmt], currentSymtab, types) _ -> (accSts ++ [checkedStmt], currentSymtab, types)
-- Initial accumulator: empty statements list, initial symbol table, empty types list -- Initial accumulator: empty statements list, initial symbol table, empty types list
@ -278,6 +280,9 @@ typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes =
-- ********************************** Type Checking: Helpers ********************************** -- ********************************** Type Checking: Helpers **********************************
isObjectType :: DataType -> Bool
isObjectType dt = dt /= "int" && dt /= "boolean" && dt /= "char"
getTypeFromExpr :: Expression -> DataType getTypeFromExpr :: Expression -> DataType
getTypeFromExpr (TypedExpression t _) = t getTypeFromExpr (TypedExpression t _) = t
getTypeFromExpr _ = error "Untyped expression found where typed was expected" getTypeFromExpr _ = error "Untyped expression found where typed was expected"