Compare commits

..

7 Commits

4 changed files with 169 additions and 50 deletions

View File

@ -57,7 +57,6 @@ methodBuilder (MethodDeclaration returntype name parameters statement) input = l
} }
methodAssembler :: ClassFileBuilder MethodDeclaration methodAssembler :: ClassFileBuilder MethodDeclaration
methodAssembler (MethodDeclaration returntype name parameters statement) input = let methodAssembler (MethodDeclaration returntype name parameters statement) input = let
methodConstantIndex = findMethodIndex input name methodConstantIndex = findMethodIndex input name
@ -73,7 +72,7 @@ methodAssembler (MethodDeclaration returntype name parameters statement) input =
assembledMethod = method { assembledMethod = method {
memberAttributes = [ memberAttributes = [
CodeAttribute { CodeAttribute {
attributeMaxStack = 420, attributeMaxStack = fromIntegral $ maxStackDepth constants bytecode,
attributeMaxLocals = fromIntegral $ length aParamNames, attributeMaxLocals = fromIntegral $ length aParamNames,
attributeCode = bytecode attributeCode = bytecode
} }

View File

@ -4,7 +4,7 @@ import Data.Int
import Ast import Ast
import ByteCode.ClassFile import ByteCode.ClassFile
import Data.List import Data.List
import Data.Maybe (mapMaybe) import Data.Maybe (mapMaybe, isJust)
import Data.Word (Word8, Word16, Word32) import Data.Word (Word8, Word16, Word32)
-- walks the name resolution chain. returns the innermost Just LocalVariable/FieldVariable or Nothing. -- walks the name resolution chain. returns the innermost Just LocalVariable/FieldVariable or Nothing.
@ -12,13 +12,13 @@ resolveNameChain :: Expression -> Expression
resolveNameChain (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b resolveNameChain (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b
resolveNameChain (TypedExpression dtype (LocalVariable name)) = (TypedExpression dtype (LocalVariable name)) resolveNameChain (TypedExpression dtype (LocalVariable name)) = (TypedExpression dtype (LocalVariable name))
resolveNameChain (TypedExpression dtype (FieldVariable name)) = (TypedExpression dtype (FieldVariable name)) resolveNameChain (TypedExpression dtype (FieldVariable name)) = (TypedExpression dtype (FieldVariable name))
resolveNameChain invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression)) resolveNameChain invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show (invalidExpression))
-- walks the name resolution chain. returns the second-to-last item of the namechain. -- walks the name resolution chain. returns the second-to-last item of the namechain.
resolveNameChainOwner :: Expression -> Expression resolveNameChainOwner :: Expression -> Expression
resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a (TypedExpression dtype (FieldVariable name)))) = a resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a (TypedExpression dtype (FieldVariable name)))) = a
resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b
resolveNameChainOwner invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression)) resolveNameChainOwner invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show (invalidExpression))
methodDescriptor :: MethodDeclaration -> String methodDescriptor :: MethodDeclaration -> String
@ -39,10 +39,39 @@ methodDescriptorFromParamlist parameters returntype = let
++ ")" ++ ")"
++ datatypeDescriptor returntype ++ datatypeDescriptor returntype
parseMethodType :: ([String], String) -> String -> ([String], String)
parseMethodType (params, returnType) ('(' : descriptor) = parseMethodType (params, returnType) descriptor
parseMethodType (params, returnType) ('I' : descriptor) = parseMethodType (params ++ ["I"], returnType) descriptor
parseMethodType (params, returnType) ('C' : descriptor) = parseMethodType (params ++ ["C"], returnType) descriptor
parseMethodType (params, returnType) ('B' : descriptor) = parseMethodType (params ++ ["B"], returnType) descriptor
parseMethodType (params, returnType) ('L' : descriptor) = let
typeLength = elemIndex ';' descriptor
in case typeLength of
Just length -> let
(typeName, semicolon : restOfDescriptor) = splitAt length descriptor
in
parseMethodType (params ++ [typeName], returnType) restOfDescriptor
Nothing -> error $ "unterminated class type in function signature: " ++ (show descriptor)
parseMethodType (params, _) (')' : descriptor) = (params, descriptor)
parseMethodType _ descriptor = error $ "expected start of type name (L, I, C, B) but got: " ++ descriptor
-- given a method index (constant pool index),
-- returns the full type of the method. (i.e (LSomething;II)V)
methodTypeFromIndex :: [ConstantInfo] -> Int -> String
methodTypeFromIndex constants index = case constants!!(fromIntegral (index - 1)) of
MethodRefInfo _ nameAndTypeIndex -> case constants!!(fromIntegral (nameAndTypeIndex - 1)) of
NameAndTypeInfo _ typeIndex -> case constants!!(fromIntegral (typeIndex - 1)) of
Utf8Info typeLiteral -> typeLiteral
unexpectedElement -> error "Expected Utf8Info but got: " ++ show unexpectedElement
unexpectedElement -> error "Expected NameAndTypeInfo but got: " ++ show unexpectedElement
unexpectedElement -> error "Expected MethodRefInfo but got: " ++ show unexpectedElement
methodParametersFromIndex :: [ConstantInfo] -> Int -> ([String], String)
methodParametersFromIndex constants index = parseMethodType ([], "V") (methodTypeFromIndex constants index)
memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool
memberInfoIsMethod constants info = elem '(' (memberInfoDescriptor constants info) memberInfoIsMethod constants info = elem '(' (memberInfoDescriptor constants info)
datatypeDescriptor :: String -> String datatypeDescriptor :: String -> String
datatypeDescriptor "void" = "V" datatypeDescriptor "void" = "V"
datatypeDescriptor "int" = "I" datatypeDescriptor "int" = "I"
@ -100,6 +129,18 @@ comparisonOperation CompareLessOrEqual branchLocation = Opif_icmple branchLoc
comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation
comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation
comparisonOffset :: Operation -> Maybe Int
comparisonOffset (Opif_icmpeq offset) = Just $ fromIntegral offset
comparisonOffset (Opif_icmpne offset) = Just $ fromIntegral offset
comparisonOffset (Opif_icmplt offset) = Just $ fromIntegral offset
comparisonOffset (Opif_icmple offset) = Just $ fromIntegral offset
comparisonOffset (Opif_icmpgt offset) = Just $ fromIntegral offset
comparisonOffset (Opif_icmpge offset) = Just $ fromIntegral offset
comparisonOffset anything_else = Nothing
isComparisonOperation :: Operation -> Bool
isComparisonOperation op = isJust (comparisonOffset op)
findFieldIndex :: [ConstantInfo] -> String -> Maybe Int findFieldIndex :: [ConstantInfo] -> String -> Maybe Int
findFieldIndex constants name = let findFieldIndex constants name = let
fieldRefNameInfos = [ fieldRefNameInfos = [
@ -143,10 +184,10 @@ findMethodIndex classFile name = let
findClassIndex :: [ConstantInfo] -> String -> Maybe Int findClassIndex :: [ConstantInfo] -> String -> Maybe Int
findClassIndex constants name = let findClassIndex constants name = let
classNameIndices = [(index, constants!!(fromIntegral nameIndex - 1)) | (index, ClassInfo nameIndex) <- (zip[1..] constants)] classNameIndices = [(index, constants!!(fromIntegral nameIndex - 1)) | (index, ClassInfo nameIndex) <- (zip [1..] constants)]
classNames = map (\(index, nameInfo) -> case nameInfo of classNames = map (\(index, nameInfo) -> case nameInfo of
Utf8Info className -> (index, className) Utf8Info className -> (index, className)
something_else -> error("Expected UTF8Info but got " ++ show something_else)) something_else -> error ("Expected UTF8Info but got " ++ show something_else))
classNameIndices classNameIndices
desiredClassIndex = find (\(index, className) -> className == name) classNames desiredClassIndex = find (\(index, className) -> className == name) classNames
in case desiredClassIndex of in case desiredClassIndex of
@ -170,7 +211,7 @@ getKnownMembers constants = let
fieldsResolved = map (\(index, (nameInfo, fnameInfo, ftypeInfo)) -> case (nameInfo, fnameInfo, ftypeInfo) of fieldsResolved = map (\(index, (nameInfo, fnameInfo, ftypeInfo)) -> case (nameInfo, fnameInfo, ftypeInfo) of
(Utf8Info cname, Utf8Info fname, Utf8Info ftype) -> (index, (cname, fname, ftype)) (Utf8Info cname, Utf8Info fname, Utf8Info ftype) -> (index, (cname, fname, ftype))
something_else -> error("Expected UTF8Infos but got " ++ show something_else)) something_else -> error ("Expected UTF8Infos but got " ++ show something_else))
fieldsClassNameType fieldsClassNameType
in in
fieldsResolved fieldsResolved
@ -243,3 +284,54 @@ injectFieldInitializers classname vars pre = let
MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block statements)) -> MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block (initializers ++ statements))) MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block statements)) -> MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block (initializers ++ statements)))
otherwise -> method otherwise -> method
) pre ) pre
-- effect of one instruction/operation on the stack
operationStackCost :: [ConstantInfo] -> Operation -> Int
operationStackCost constants Opiadd = -1
operationStackCost constants Opisub = -1
operationStackCost constants Opimul = -1
operationStackCost constants Opidiv = -1
operationStackCost constants Opirem = -1
operationStackCost constants Opiand = -1
operationStackCost constants Opior = -1
operationStackCost constants Opixor = -1
operationStackCost constants Opineg = 0
operationStackCost constants Opdup = 1
operationStackCost constants (Opnew _) = 1
operationStackCost constants (Opif_icmplt _) = -2
operationStackCost constants (Opif_icmple _) = -2
operationStackCost constants (Opif_icmpgt _) = -2
operationStackCost constants (Opif_icmpge _) = -2
operationStackCost constants (Opif_icmpeq _) = -2
operationStackCost constants (Opif_icmpne _) = -2
operationStackCost constants Opaconst_null = 1
operationStackCost constants Opreturn = 0
operationStackCost constants Opireturn = -1
operationStackCost constants Opareturn = -1
operationStackCost constants Opdup_x1 = 1
operationStackCost constants Oppop = -1
operationStackCost constants (Opinvokespecial idx) = let
(params, returnType) = methodParametersFromIndex constants (fromIntegral idx)
in (length params + 1) - (fromEnum (returnType /= "V"))
operationStackCost constants (Opinvokevirtual idx) = let
(params, returnType) = methodParametersFromIndex constants (fromIntegral idx)
in (length params + 1) - (fromEnum (returnType /= "V"))
operationStackCost constants (Opgoto _) = 0
operationStackCost constants (Opsipush _) = 1
operationStackCost constants (Opldc_w _) = 1
operationStackCost constants (Opaload _) = 1
operationStackCost constants (Opiload _) = 1
operationStackCost constants (Opastore _) = -1
operationStackCost constants (Opistore _) = -1
operationStackCost constants (Opputfield _) = -2
operationStackCost constants (Opgetfield _) = -1
simulateStackOperation :: [ConstantInfo] -> Operation -> (Int, Int) -> (Int, Int)
simulateStackOperation constants op (cd, md) = let
depth = cd + operationStackCost constants op
in if depth < 0
then error ("Consuming value off of empty stack: " ++ show op)
else (depth, max depth md)
maxStackDepth :: [ConstantInfo] -> [Operation] -> Int
maxStackDepth constants ops = snd $ foldr (simulateStackOperation constants) (0, 0) (reverse ops)

View File

@ -37,14 +37,18 @@ typeCheckVariableDeclaration (VariableDeclaration dataType identifier maybeExpr)
-- Type check the initializer expression if it exists -- Type check the initializer expression if it exists
checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
exprType = fmap getTypeFromExpr checkedExpr exprType = fmap getTypeFromExpr checkedExpr
checkedExprWithType = case exprType of
Just "null" | isObjectType dataType -> Just (TypedExpression dataType NullLiteral)
_ -> checkedExpr
in case (validType, redefined, exprType) of in case (validType, redefined, exprType) of
(False, _, _) -> error $ "Type '" ++ dataType ++ "' is not a valid type for variable '" ++ identifier ++ "'" (False, _, _) -> error $ "Type '" ++ dataType ++ "' is not a valid type for variable '" ++ identifier ++ "'"
(_, True, _) -> error $ "Variable '" ++ identifier ++ "' is redefined in the same scope" (_, True, _) -> error $ "Variable '" ++ identifier ++ "' is redefined in the same scope"
(_, _, Just t) (_, _, Just t)
| t == "null" && isObjectType dataType -> VariableDeclaration dataType identifier checkedExpr | t == "null" && isObjectType dataType -> VariableDeclaration dataType identifier checkedExprWithType
| t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
| otherwise -> VariableDeclaration dataType identifier checkedExpr | otherwise -> VariableDeclaration dataType identifier checkedExprWithType
(_, _, Nothing) -> VariableDeclaration dataType identifier checkedExpr (_, _, Nothing) -> VariableDeclaration dataType identifier checkedExprWithType
-- ********************************** Type Checking: Expressions ********************************** -- ********************************** Type Checking: Expressions **********************************
@ -125,18 +129,21 @@ typeCheckStatementExpression (Assignment ref expr) symtab classes =
ref' = typeCheckExpression ref symtab classes ref' = typeCheckExpression ref symtab classes
type' = getTypeFromExpr expr' type' = getTypeFromExpr expr'
type'' = getTypeFromExpr ref' type'' = getTypeFromExpr ref'
typeToAssign = if type' == "null" && isObjectType type'' then type'' else type'
exprWithType = if type' == "null" && isObjectType type'' then TypedExpression type'' NullLiteral else expr'
in in
if type'' == type' || (type' == "null" && isObjectType type'') then if type'' == typeToAssign then
TypedStatementExpression type'' (Assignment ref' expr') TypedStatementExpression type'' (Assignment ref' exprWithType)
else else
error $ "Type mismatch in assignment to variable: expected " ++ type'' ++ ", found " ++ type' error $ "Type mismatch in assignment to variable: expected " ++ type'' ++ ", found " ++ typeToAssign
typeCheckStatementExpression (ConstructorCall className args) symtab classes = typeCheckStatementExpression (ConstructorCall className args) symtab classes =
case find (\(Class name _ _) -> name == className) classes of case find (\(Class name _ _) -> name == className) classes of
Nothing -> error $ "Class '" ++ className ++ "' not found." Nothing -> error $ "Class '" ++ className ++ "' not found."
Just (Class _ methods fields) -> Just (Class _ methods _) ->
-- Find constructor matching the class name with void return type -- Find constructor matching the class name with void return type
case find (\(MethodDeclaration retType name params _) -> name == "<init>" && retType == "void") methods of case find (\(MethodDeclaration _ name params _) -> name == "<init>") methods of
-- If no constructor is found, assume standard constructor with no parameters -- If no constructor is found, assume standard constructor with no parameters
Nothing -> Nothing ->
if null args then if null args then
@ -144,21 +151,28 @@ typeCheckStatementExpression (ConstructorCall className args) symtab classes =
else else
error $ "No valid constructor found for class '" ++ className ++ "', but arguments were provided." error $ "No valid constructor found for class '" ++ className ++ "', but arguments were provided."
Just (MethodDeclaration _ _ params _) -> Just (MethodDeclaration _ _ params _) ->
let let args' = zipWith
args' = map (\arg -> typeCheckExpression arg symtab classes) args (\arg (ParameterDeclaration paramType _) ->
-- Extract expected parameter types from the constructor's parameters let argTyped = typeCheckExpression arg symtab classes
in if getTypeFromExpr argTyped == "null" && isObjectType paramType
then TypedExpression paramType NullLiteral
else argTyped
) args params
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params] expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
argTypes = map getTypeFromExpr args' argTypes = map getTypeFromExpr args'
-- Check if the types of the provided arguments match the expected types typeMatches = zipWith
typeMatches = zipWith (\expected actual -> if expected == actual then Nothing else Just (expected, actual)) expectedTypes argTypes (\expType argType -> (expType == argType || (argType == "null" && isObjectType expType), expType, argType))
mismatchErrors = map (\(exp, act) -> "Expected type '" ++ exp ++ "', found '" ++ act ++ "'.") (catMaybes typeMatches) expectedTypes argTypes
mismatches = filter (not . fst3) typeMatches
fst3 (a, _, _) = a
in in
if length args /= length params then if null mismatches && length args == length params then
error $ "Constructor for class '" ++ className ++ "' expects " ++ show (length params) ++ " arguments, but got " ++ show (length args) ++ "."
else if not (null mismatchErrors) then
error $ unlines $ ("Type mismatch in constructor arguments for class '" ++ className ++ "':") : mismatchErrors
else
TypedStatementExpression className (ConstructorCall className args') TypedStatementExpression className (ConstructorCall className args')
else if not (null mismatches) then
error $ unlines $ ("Type mismatch in constructor arguments for class '" ++ className ++ "':")
: [ "Expected: " ++ expType ++ ", Found: " ++ argType | (_, expType, argType) <- mismatches ]
else
error $ "Incorrect number of arguments for constructor of class '" ++ className ++ "'. Expected " ++ show (length expectedTypes) ++ ", found " ++ show (length args) ++ "."
typeCheckStatementExpression (MethodCall expr methodName args) symtab classes = typeCheckStatementExpression (MethodCall expr methodName args) symtab classes =
let objExprTyped = typeCheckExpression expr symtab classes let objExprTyped = typeCheckExpression expr symtab classes
@ -168,20 +182,26 @@ typeCheckStatementExpression (MethodCall expr methodName args) symtab classes =
Just (Class _ methods _) -> Just (Class _ methods _) ->
case find (\(MethodDeclaration retType name params _) -> name == methodName) methods of case find (\(MethodDeclaration retType name params _) -> name == methodName) methods of
Just (MethodDeclaration retType _ params _) -> Just (MethodDeclaration retType _ params _) ->
let args' = map (\arg -> typeCheckExpression arg symtab classes) args let args' = zipWith
(\arg (ParameterDeclaration paramType _) ->
let argTyped = typeCheckExpression arg symtab classes
in if getTypeFromExpr argTyped == "null" && isObjectType paramType
then TypedExpression paramType NullLiteral
else argTyped
) args params
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params] expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
argTypes = map getTypeFromExpr args' argTypes = map getTypeFromExpr args'
typeMatches = zipWith (\expType argType -> (expType == argType, expType, argType)) expectedTypes argTypes typeMatches = zipWith
(\expType argType -> (expType == argType || (argType == "null" && isObjectType expType), expType, argType))
expectedTypes argTypes
mismatches = filter (not . fst3) typeMatches mismatches = filter (not . fst3) typeMatches
where fst3 (a, _, _) = a fst3 (a, _, _) = a
in in if null mismatches && length args == length params
if null mismatches && length args == length params then then TypedStatementExpression retType (MethodCall objExprTyped methodName args')
TypedStatementExpression retType (MethodCall objExprTyped methodName args') else if not (null mismatches)
else if not (null mismatches) then then error $ unlines $ ("Argument type mismatches for method '" ++ methodName ++ "':")
error $ unlines $ ("Argument type mismatches for method '" ++ methodName ++ "':")
: [ "Expected: " ++ expType ++ ", Found: " ++ argType | (_, expType, argType) <- mismatches ] : [ "Expected: " ++ expType ++ ", Found: " ++ argType | (_, expType, argType) <- mismatches ]
else else error $ "Incorrect number of arguments for method '" ++ methodName ++ "'. Expected " ++ show (length expectedTypes) ++ ", found " ++ show (length args) ++ "."
error $ "Incorrect number of arguments for method '" ++ methodName ++ "'. Expected " ++ show (length expectedTypes) ++ ", found " ++ show (length args) ++ "."
Nothing -> error $ "Method '" ++ methodName ++ "' not found in class '" ++ objType ++ "'." Nothing -> error $ "Method '" ++ methodName ++ "' not found in class '" ++ objType ++ "'."
Nothing -> error $ "Class for object type '" ++ objType ++ "' not found." Nothing -> error $ "Class for object type '" ++ objType ++ "' not found."
_ -> error "Invalid object type for method call. Object must have a class type." _ -> error "Invalid object type for method call. Object must have a class type."
@ -251,14 +271,18 @@ typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType ident
-- If there's an initializer expression, type check it -- If there's an initializer expression, type check it
let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
exprType = fmap getTypeFromExpr checkedExpr exprType = fmap getTypeFromExpr checkedExpr
checkedExprWithType = case (exprType, dataType) of
(Just "null", _) | isObjectType dataType -> Just (TypedExpression dataType NullLiteral)
_ -> checkedExpr
in case exprType of in case exprType of
Just t Just t
| t == "null" && isObjectType dataType -> | t == "null" && isObjectType dataType ->
TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr)) TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExprWithType))
| t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
| otherwise -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr)) | otherwise -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExprWithType))
Nothing -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr)) Nothing -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
typeCheckStatement (While cond stmt) symtab classes = typeCheckStatement (While cond stmt) symtab classes =
let cond' = typeCheckExpression cond symtab classes let cond' = typeCheckExpression cond symtab classes
stmt' = typeCheckStatement stmt symtab classes stmt' = typeCheckStatement stmt symtab classes
@ -305,13 +329,17 @@ typeCheckStatement (Block statements) symtab classes =
typeCheckStatement (Return expr) symtab classes = typeCheckStatement (Return expr) symtab classes =
let methodReturnType = fromMaybe (error "Method return type not found in symbol table") (lookup "thisMeth" symtab) let methodReturnType = fromMaybe (error "Method return type not found in symbol table") (lookup "thisMeth" symtab)
expr' = case expr of expr' = case expr of
Just e -> Just (typeCheckExpression e symtab classes) Just e -> let eTyped = typeCheckExpression e symtab classes
in if getTypeFromExpr eTyped == "null" && isObjectType methodReturnType
then Just (TypedExpression methodReturnType NullLiteral)
else Just eTyped
Nothing -> Nothing Nothing -> Nothing
returnType = maybe "void" getTypeFromExpr expr' returnType = maybe "void" getTypeFromExpr expr'
in if returnType == methodReturnType || isSubtype returnType methodReturnType classes in if returnType == methodReturnType || isSubtype returnType methodReturnType classes
then TypedStatement returnType (Return expr') then TypedStatement returnType (Return expr')
else error $ "Return: Return type mismatch: expected " ++ methodReturnType ++ ", found " ++ returnType else error $ "Return: Return type mismatch: expected " ++ methodReturnType ++ ", found " ++ returnType
typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes = typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes =
let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes
in TypedStatement (getTypeFromStmtExpr stmtExpr') (StatementExpressionStatement stmtExpr') in TypedStatement (getTypeFromStmtExpr stmtExpr') (StatementExpressionStatement stmtExpr')