diff --git a/project.cabal b/project.cabal index 5a98c99..82617b1 100644 --- a/project.cabal +++ b/project.cabal @@ -16,7 +16,7 @@ executable compiler src/ByteCode, src/ByteCode/ClassFile build-tool-depends: alex:alex, happy:happy - other-modules: Parser.Lexer, Ast, Parser.JavaParser, ByteCode.ByteUtil, ByteCode.ClassFile, ByteCode.ClassFile.Generator, ByteCode.Constants + other-modules: Ast, Example, ByteCode.ByteUtil, ByteCode.ClassFile, ByteCode.ClassFile.Generator, ByteCode.Constants test-suite tests type: exitcode-stdio-1.0 @@ -28,4 +28,4 @@ test-suite tests utf8-string, bytestring build-tool-depends: alex:alex, happy:happy - other-modules: Parser.Lexer, TestLexer + other-modules: TestLexer diff --git a/src/Ast.hs b/src/Ast.hs index 2c131c6..83ccaa1 100644 --- a/src/Ast.hs +++ b/src/Ast.hs @@ -1,54 +1,357 @@ module Ast where -type CompilationUnit = [Class] +import Data.List (find) + +type CompilationUnit = [Class] + type DataType = String -type Identifier = String -data ParameterDeclaration = ParameterDeclaration DataType Identifier -data VariableDeclaration = VariableDeclaration DataType Identifier (Maybe Expression) -data Class = Class DataType [MethodDeclaration] [VariableDeclaration] -data MethodDeclaration = MethodDeclaration DataType Identifier [ParameterDeclaration] Statement +type Identifier = String -data Statement = If Expression Statement (Maybe Statement) - | LocalVariableDeclaration VariableDeclaration - | While Expression Statement - | Block [Statement] - | Return (Maybe Expression) - | StatementExpressionStatement StatementExpression - | TypedStatement DataType Statement +data ParameterDeclaration = ParameterDeclaration DataType Identifier deriving (Show) -data StatementExpression = Assignment Identifier Expression - | ConstructorCall DataType [Expression] - | MethodCall Identifier [Expression] - | TypedStatementExpression DataType StatementExpression +data VariableDeclaration = VariableDeclaration DataType Identifier (Maybe Expression) deriving (Show) -data BinaryOperator = Addition - | Subtraction - | Multiplication - | Division - | BitwiseAnd - | BitwiseOr - | BitwiseXor - | CompareLessThan - | CompareLessOrEqual - | CompareGreaterThan - | CompareGreaterOrEqual - | CompareEqual - | CompareNotEqual - | And - | Or - | NameResolution +data Class = Class DataType [MethodDeclaration] [VariableDeclaration] deriving (Show) -data UnaryOperator = Not - | Minus +data MethodDeclaration = MethodDeclaration DataType Identifier [ParameterDeclaration] Statement deriving (Show) -data Expression = IntegerLiteral Int - | CharacterLiteral Char - | BooleanLiteral Bool - | NullLiteral - | Reference Identifier - | BinaryOperation BinaryOperator Expression Expression - | UnaryOperation UnaryOperator Expression - | StatementExpressionExpression StatementExpression - | TypedExpression DataType Expression - \ No newline at end of file +data Statement + = If Expression Statement (Maybe Statement) + | LocalVariableDeclaration VariableDeclaration + | While Expression Statement + | Block [Statement] + | Return (Maybe Expression) + | StatementExpressionStatement StatementExpression + | TypedStatement DataType Statement + deriving (Show) + +data StatementExpression + = Assignment Identifier Expression + | ConstructorCall DataType [Expression] + | MethodCall Identifier [Expression] + | TypedStatementExpression DataType StatementExpression + deriving (Show) + +data BinaryOperator + = Addition + | Subtraction + | Multiplication + | Division + | BitwiseAnd + | BitwiseOr + | BitwiseXor + | CompareLessThan + | CompareLessOrEqual + | CompareGreaterThan + | CompareGreaterOrEqual + | CompareEqual + | CompareNotEqual + | And + | Or + | NameResolution + deriving (Show) + +data UnaryOperator + = Not + | Minus + deriving (Show) + +data Expression + = IntegerLiteral Int + | CharacterLiteral Char + | BooleanLiteral Bool + | NullLiteral + | Reference Identifier + | BinaryOperation BinaryOperator Expression Expression + | UnaryOperation UnaryOperator Expression + | StatementExpressionExpression StatementExpression + | TypedExpression DataType Expression + deriving (Show) + +typeCheckCompilationUnit :: CompilationUnit -> CompilationUnit +typeCheckCompilationUnit classes = map (`typeCheckClass` classes) classes + +typeCheckClass :: Class -> [Class] -> Class +typeCheckClass (Class className methods fields) classes = + let + -- Create a symbol table from class fields + classFields = [(id, dt) | VariableDeclaration dt id _ <- fields] + checkedMethods = map (\method -> typeCheckMethodDeclaration method classFields classes) methods + in Class className checkedMethods fields + +typeCheckMethodDeclaration :: MethodDeclaration -> [(Identifier, DataType)] -> [Class] -> MethodDeclaration +typeCheckMethodDeclaration (MethodDeclaration retType name params body) classFields classes = + let + -- Combine class fields with method parameters to form the initial symbol table for the method + methodParams = [(identifier, dataType) | ParameterDeclaration dataType identifier <- params] + -- Ensure method parameters shadow class fields if names collide + initialSymtab = classFields ++ methodParams + -- Type check the body of the method using the combined symbol table + checkedBody = typeCheckStatement body initialSymtab classes + bodyType = getTypeFromStmt checkedBody + -- Check if the type of the body matches the declared return type + in if bodyType == retType || (bodyType == "void" && retType == "void") + then MethodDeclaration retType name params checkedBody + else error $ "Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType + +-- ********************************** Type Checking: Expressions ********************************** + +typeCheckExpression :: Expression -> [(Identifier, DataType)] -> [Class] -> Expression +typeCheckExpression (IntegerLiteral i) _ _ = TypedExpression "int" (IntegerLiteral i) +typeCheckExpression (CharacterLiteral c) _ _ = TypedExpression "char" (CharacterLiteral c) +typeCheckExpression (BooleanLiteral b) _ _ = TypedExpression "boolean" (BooleanLiteral b) +typeCheckExpression NullLiteral _ _ = TypedExpression "null" NullLiteral +typeCheckExpression (Reference id) symtab classes = + let type' = lookupType id symtab + in TypedExpression type' (Reference id) +typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes = + let expr1' = typeCheckExpression expr1 symtab classes + expr2' = typeCheckExpression expr2 symtab classes + type1 = getTypeFromExpr expr1' + type2 = getTypeFromExpr expr2' + in case op of + Addition -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Addition operation requires two operands of type int" + Subtraction -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Subtraction operation requires two operands of type int" + Multiplication -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Multiplication operation requires two operands of type int" + Division -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Division operation requires two operands of type int" + BitwiseAnd -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Bitwise AND operation requires two operands of type int" + BitwiseOr -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Bitwise OR operation requires two operands of type int" + BitwiseXor -> + if type1 == "int" && type2 == "int" + then + TypedExpression "int" (BinaryOperation op expr1' expr2') + else + error "Bitwise XOR operation requires two operands of type int" + CompareLessThan -> + if type1 == "int" && type2 == "int" + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Less than operation requires two operands of type int" + CompareLessOrEqual -> + if type1 == "int" && type2 == "int" + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Less than or equal operation requires two operands of type int" + CompareGreaterThan -> + if type1 == "int" && type2 == "int" + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Greater than operation requires two operands of type int" + CompareGreaterOrEqual -> + if type1 == "int" && type2 == "int" + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Greater than or equal operation requires two operands of type int" + CompareEqual -> + if type1 == type2 + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Equality operation requires two operands of the same type" + CompareNotEqual -> + if type1 == type2 + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Inequality operation requires two operands of the same type" + And -> + if type1 == "boolean" && type2 == "boolean" + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Logical AND operation requires two operands of type boolean" + Or -> + if type1 == "boolean" && type2 == "boolean" + then + TypedExpression "boolean" (BinaryOperation op expr1' expr2') + else + error "Logical OR operation requires two operands of type boolean" + NameResolution -> + case (expr1', expr2) of + (TypedExpression t1 (Reference obj), Reference member) -> + -- Lookup the class type of obj from the symbol table + let objectType = lookupType obj symtab + classDetails = find (\(Class className _ _) -> className == objectType) classes + in case classDetails of + Just (Class _ methods fields) -> + -- Check both fields and methods to find a match for member + let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == member] + methodTypes = [dt | MethodDeclaration dt id _ _ <- methods, id == member] + in case fieldTypes ++ methodTypes of + [resolvedType] -> TypedExpression resolvedType (BinaryOperation op expr1' (TypedExpression resolvedType (Reference member))) + [] -> error $ "Member '" ++ member ++ "' not found in class '" ++ objectType ++ "'" + _ -> error $ "Ambiguous reference to '" ++ member ++ "' in class '" ++ objectType ++ "'" + Nothing -> error $ "Object '" ++ obj ++ "' does not correspond to a known class" + _ -> error "Name resolution requires object reference and member name" + +typeCheckExpression (UnaryOperation op expr) symtab classes = + let expr' = typeCheckExpression expr symtab classes + type' = getTypeFromExpr expr' + in case op of + Not -> + if type' == "boolean" + then + TypedExpression "boolean" (UnaryOperation op expr') + else + error "Logical NOT operation requires an operand of type boolean" + Minus -> + if type' == "int" + then + TypedExpression "int" (UnaryOperation op expr') + else + error "Unary minus operation requires an operand of type int" + +typeCheckExpression (StatementExpressionExpression stmtExpr) symtab classes = + let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes + in TypedExpression (getTypeFromStmtExpr stmtExpr') (StatementExpressionExpression stmtExpr') + +-- ********************************** Type Checking: StatementExpressions ********************************** +-- TODO: Implement type checking for StatementExpressions +typeCheckStatementExpression :: StatementExpression -> [(Identifier, DataType)] -> [Class] -> StatementExpression +typeCheckStatementExpression (Assignment id expr) symtab classes = + let expr' = typeCheckExpression expr symtab classes + type' = getTypeFromExpr expr' + type'' = lookupType id symtab + in if type' == type'' + then + TypedStatementExpression type' (Assignment id expr') + else + error "Assignment type mismatch" + +typeCheckStatementExpression (ConstructorCall className args) symtab classes = + let args' = map (\arg -> typeCheckExpression arg symtab classes) args + in TypedStatementExpression className (ConstructorCall className args') + +typeCheckStatementExpression (MethodCall methodName args) symtab classes = + let args' = map (\arg -> typeCheckExpression arg symtab classes) args + in TypedStatementExpression "Object" (MethodCall methodName args') + +-- ********************************** Type Checking: Statements ********************************** + +typeCheckStatement :: Statement -> [(Identifier, DataType)] -> [Class] -> Statement +typeCheckStatement (If cond thenStmt elseStmt) symtab classes = + let cond' = typeCheckExpression cond symtab classes + thenStmt' = typeCheckStatement thenStmt symtab classes + elseStmt' = case elseStmt of + Just stmt -> Just (typeCheckStatement stmt symtab classes) + Nothing -> Nothing + in if getTypeFromExpr cond' == "boolean" + then + TypedStatement (getTypeFromStmt thenStmt') (If cond' thenStmt' elseStmt') + else + error "If condition must be of type boolean" + +typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr)) symtab classes = + -- Check for redefinition in the current scope + if any ((== identifier) . snd) symtab + then error $ "Variable '" ++ identifier ++ "' is redefined in the same scope" + else + -- If there's an initializer expression, type check it + let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr + exprType = fmap getTypeFromExpr checkedExpr + in case exprType of + Just t | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t + _ -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr)) + +typeCheckStatement (While cond stmt) symtab classes = + let cond' = typeCheckExpression cond symtab classes + stmt' = typeCheckStatement stmt symtab classes + in if getTypeFromExpr cond' == "boolean" + then + TypedStatement (getTypeFromStmt stmt') (While cond' stmt') + else + error "While condition must be of type boolean" + +typeCheckStatement (Block statements) symtab classes = + let + processStatements (accSts, currentSymtab, types) stmt = + let + checkedStmt = typeCheckStatement stmt currentSymtab classes + stmtType = getTypeFromStmt checkedStmt + in case stmt of + LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr) -> + let + checkedExpr = fmap (\expr -> typeCheckExpression expr currentSymtab classes) maybeExpr + newSymtab = (identifier, dataType) : currentSymtab + in (accSts ++ [checkedStmt], newSymtab, types) + + If {} -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types) + While _ _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types) + Return _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types) + _ -> (accSts ++ [checkedStmt], currentSymtab, types) + + -- Initial accumulator: empty statements list, initial symbol table, empty types list + (checkedStatements, finalSymtab, collectedTypes) = foldl processStatements ([], symtab, []) statements + + -- Determine the block's type: unify all collected types, default to "Void" if none + blockType = if null collectedTypes then "void" else foldl1 unifyReturnTypes collectedTypes + + in TypedStatement blockType (Block checkedStatements) + +typeCheckStatement (Return expr) symtab classes = + let expr' = case expr of + Just e -> Just (typeCheckExpression e symtab classes) + Nothing -> Nothing + in case expr' of + Just e' -> TypedStatement (getTypeFromExpr e') (Return (Just e')) + Nothing -> TypedStatement "Void" (Return Nothing) + +-- ********************************** Type Checking: Helpers ********************************** + +getTypeFromExpr :: Expression -> DataType +getTypeFromExpr (TypedExpression t _) = t +getTypeFromExpr _ = error "Untyped expression found where typed was expected" + +getTypeFromStmt :: Statement -> DataType +getTypeFromStmt (TypedStatement t _) = t +getTypeFromStmt _ = error "Untyped statement found where typed was expected" + +getTypeFromStmtExpr :: StatementExpression -> DataType +getTypeFromStmtExpr (TypedStatementExpression t _) = t +getTypeFromStmtExpr _ = error "Untyped statement expression found where typed was expected" + +unifyReturnTypes :: DataType -> DataType -> DataType +unifyReturnTypes dt1 dt2 + | dt1 == dt2 = dt1 + | otherwise = "Object" + +lookupType :: Identifier -> [(Identifier, DataType)] -> DataType +lookupType id symtab = + case lookup id symtab of + Just t -> t + Nothing -> error ("Identifier " ++ id ++ " not found in symbol table") \ No newline at end of file diff --git a/src/Example.hs b/src/Example.hs new file mode 100644 index 0000000..b45130c --- /dev/null +++ b/src/Example.hs @@ -0,0 +1,49 @@ +module Example where + +import Ast +import Control.Exception (catch, evaluate, SomeException, displayException) +import Control.Exception.Base + +-- Example classes and their methods and fields +sampleClasses :: [Class] +sampleClasses = [ + Class "Person" [ + MethodDeclaration "void" "setAge" [ParameterDeclaration "Int" "newAge"] + (Block [ + LocalVariableDeclaration (VariableDeclaration "Int" "age" (Just (Reference "newAge"))) + ]), + MethodDeclaration "Int" "getAge" [] (Return (Just (Reference "age"))) + ] [ + VariableDeclaration "Int" "age" (Just (IntegerLiteral 25)), + VariableDeclaration "String" "name" (Just (CharacterLiteral 'A')) + ] + ] + +-- Symbol table, mapping identifiers to their data types +initialSymtab :: [(DataType, Identifier)] +initialSymtab = [] + +-- An example block of statements to type check +exampleBlock :: Statement +exampleBlock = Block [ + LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [])))), + StatementExpressionStatement (MethodCall "setAge" [IntegerLiteral 30]), + Return (Just (StatementExpressionExpression (MethodCall "getAge" []))) + ] + +exampleExpression :: Expression +exampleExpression = BinaryOperation NameResolution (Reference "bob") (Reference "age") + +-- Function to perform type checking and handle errors +runTypeCheck :: IO () +runTypeCheck = do + -- Evaluate the block of statements + --evaluatedBlock <- evaluate (typeCheckStatement exampleBlock initialSymtab sampleClasses) + --putStrLn "Type checking of block completed successfully:" + --print evaluatedBlock + + -- Evaluate the expression + evaluatedExpression <- evaluate (typeCheckExpression exampleExpression [("bob", "Person"), ("age", "int")] sampleClasses) + putStrLn "Type checking of expression completed successfully:" + print evaluatedExpression + diff --git a/src/Main.hs b/src/Main.hs index 3b01e3c..4062f85 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -1,6 +1,9 @@ module Main where -import Parser.Lexer ( alexScanTokens ) +import Parser.Lexer +import Example main = do - print $ alexScanTokens "/**/" + --print $ alexScanTokens "/**/" + Example.runTypeCheck +