diff --git a/Test/TestByteCodeGenerator.hs b/Test/TestByteCodeGenerator.hs index 3e44349..2c6f4d2 100644 --- a/Test/TestByteCodeGenerator.hs +++ b/Test/TestByteCodeGenerator.hs @@ -107,10 +107,12 @@ testBasicConstantPool = TestCase $ assertEqual "basic constant pool" expectedCla testFields = TestCase $ assertEqual "fields in constant pool" expectedClassWithFields $ classBuilder classWithFields emptyClassFile testMethodDescriptor = TestCase $ assertEqual "method descriptor" "(II)I" (methodDescriptor method) testMethodAssembly = TestCase $ assertEqual "method assembly" expectedClassWithMethod (classBuilder classWithMethod emptyClassFile) +testFindMethodIndex = TestCase $ assertEqual "find method" (Just 0) (findMethodIndex expectedClassWithMethod "add_two_numbers") tests = TestList [ TestLabel "Basic constant pool" testBasicConstantPool, TestLabel "Fields constant pool" testFields, TestLabel "Method descriptor building" testMethodDescriptor, - TestLabel "Method assembly" testMethodAssembly + TestLabel "Method assembly" testMethodAssembly, + TestLabel "Find method by name" testFindMethodIndex ] \ No newline at end of file diff --git a/Test/TestParser.hs b/Test/TestParser.hs index ce3f2cc..5041e3e 100644 --- a/Test/TestParser.hs +++ b/Test/TestParser.hs @@ -18,11 +18,286 @@ testBooleanField = TestCase $ testIntField = TestCase $ assertEqual "expect class with int field" [Class "WithInt" [] [VariableDeclaration "int" "value" Nothing]] $ parse [CLASS,IDENTIFIER "WithInt",LBRACKET,INT,IDENTIFIER "value",SEMICOLON,RBRACKET] +testCustomTypeField = TestCase $ + assertEqual "expect class with foo field" [Class "WithFoo" [] [VariableDeclaration "Foo" "value" Nothing]] $ + parse [CLASS,IDENTIFIER "WithFoo",LBRACKET,IDENTIFIER "Foo",IDENTIFIER "value",SEMICOLON,RBRACKET] +testMultipleDeclarationSameLine = TestCase $ + assertEqual "expect class with two int fields" [Class "TwoInts" [] [VariableDeclaration "int" "num1" Nothing, VariableDeclaration "int" "num2" Nothing]] $ + parse 
[CLASS,IDENTIFIER "TwoInts",LBRACKET,INT,IDENTIFIER "num1",COMMA,IDENTIFIER "num2",SEMICOLON,RBRACKET] +testMultipleDeclarations = TestCase $ + assertEqual "expect class with int and char field" [Class "Multiple" [] [VariableDeclaration "int" "value" Nothing, VariableDeclaration "char" "letter" Nothing]] $ + parse [CLASS,IDENTIFIER "Multiple",LBRACKET,INT,IDENTIFIER "value",SEMICOLON,CHAR,IDENTIFIER "letter",SEMICOLON,RBRACKET] +testWithModifier = TestCase $ + assertEqual "expect class with int field" [Class "WithInt" [] [VariableDeclaration "int" "value" Nothing]] $ + parse [ABSTRACT,CLASS,IDENTIFIER "WithInt",LBRACKET,PUBLIC,INT,IDENTIFIER "value",SEMICOLON,RBRACKET] + +testEmptyMethod = TestCase $ + assertEqual "expect class with method" [Class "WithMethod" [MethodDeclaration "int" "foo" [] (Block [])] []] $ + parse [CLASS,IDENTIFIER "WithMethod",LBRACKET,INT,IDENTIFIER "foo",LBRACE,RBRACE,SEMICOLON,RBRACKET] +testEmptyPrivateMethod = TestCase $ + assertEqual "expect class with method" [Class "WithMethod" [MethodDeclaration "int" "foo" [] (Block [])] []] $ + parse [CLASS,IDENTIFIER "WithMethod",LBRACKET,PRIVATE,INT,IDENTIFIER "foo",LBRACE,RBRACE,LBRACKET,RBRACKET,RBRACKET] +testEmptyVoidMethod = TestCase $ + assertEqual "expect class with method" [Class "WithMethod" [MethodDeclaration "void" "foo" [] (Block [])] []] $ + parse [CLASS,IDENTIFIER "WithMethod",LBRACKET,VOID,IDENTIFIER "foo",LBRACE,RBRACE,LBRACKET,RBRACKET,RBRACKET] +testEmptyMethodWithParam = TestCase $ + assertEqual "expect class with method with param" [Class "WithParam" [MethodDeclaration "void" "foo" [ParameterDeclaration "int" "param"] (Block [])] []] $ + parse [CLASS,IDENTIFIER "WithParam",LBRACKET,VOID,IDENTIFIER "foo",LBRACE,INT,IDENTIFIER "param",RBRACE,SEMICOLON,RBRACKET] +testEmptyMethodWithParams = TestCase $ + assertEqual "expect class with multiple params" [Class "WithParams" [MethodDeclaration "void" "foo" [ParameterDeclaration "int" "p1",ParameterDeclaration "Custom" "p2"] (Block 
[])] []] $ + parse [CLASS,IDENTIFIER "WithParams",LBRACKET,VOID,IDENTIFIER "foo",LBRACE,INT,IDENTIFIER "p1",COMMA,IDENTIFIER "Custom",IDENTIFIER "p2",RBRACE,SEMICOLON,RBRACKET] +testClassWithMethodAndField = TestCase $ + assertEqual "expect class with method and field" [Class "WithMethodAndField" [MethodDeclaration "void" "foo" [] (Block []), MethodDeclaration "int" "bar" [] (Block [])] [VariableDeclaration "int" "value" Nothing]] $ + parse [CLASS,IDENTIFIER "WithMethodAndField",LBRACKET,VOID,IDENTIFIER "foo",LBRACE,RBRACE,LBRACKET,RBRACKET,INT,IDENTIFIER "value",SEMICOLON,INT,IDENTIFIER "bar",LBRACE,RBRACE,SEMICOLON,RBRACKET] +testClassWithConstructor = TestCase $ + assertEqual "expect class with constructor" [Class "WithConstructor" [MethodDeclaration "void" "" [] (Block [])] []] $ + parse [CLASS,IDENTIFIER "WithConstructor",LBRACKET,IDENTIFIER "WithConstructor",LBRACE,RBRACE,LBRACKET,RBRACKET,RBRACKET] +testConstructorWithParams = TestCase $ + assertEqual "expect constructor with params" [Class "WithParams" [MethodDeclaration "void" "" [ParameterDeclaration "int" "p1"] (Block [])] []] $ + parse [CLASS,IDENTIFIER "WithParams",LBRACKET,IDENTIFIER "WithParams",LBRACE,INT,IDENTIFIER "p1",RBRACE,LBRACKET,RBRACKET,RBRACKET] +testConstructorWithStatements = TestCase $ + assertEqual "expect constructor with statement" [Class "WithConstructor" [MethodDeclaration "void" "" [] (Block [Return Nothing])] []] $ + parse [CLASS,IDENTIFIER "WithConstructor",LBRACKET,IDENTIFIER "WithConstructor",LBRACE,RBRACE,LBRACKET,RETURN,SEMICOLON,RBRACKET,RBRACKET] + + +testEmptyBlock = TestCase $ assertEqual "expect empty block" [Block []] $ parseStatement [LBRACKET,RBRACKET] +testBlockWithLocalVarDecl = TestCase $ + assertEqual "expect block with local var delcaration" [Block [LocalVariableDeclaration $ VariableDeclaration "int" "localvar" Nothing]] $ + parseStatement [LBRACKET,INT,IDENTIFIER "localvar",SEMICOLON,RBRACKET] +testBlockWithMultipleLocalVarDecls = TestCase $ + assertEqual 
"expect block with multiple local var declarations" [Block [LocalVariableDeclaration $ VariableDeclaration "int" "var1" Nothing, LocalVariableDeclaration $ VariableDeclaration "boolean" "var2" Nothing]] $ + parseStatement [LBRACKET,INT,IDENTIFIER "var1",SEMICOLON,BOOLEAN,IDENTIFIER "var2",SEMICOLON,RBRACKET] +testNestedBlocks = TestCase $ + assertEqual "expect block with block inside" [Block [Block []]] $ + parseStatement [LBRACKET,LBRACKET,RBRACKET,RBRACKET] +testBlockWithEmptyStatement = TestCase $ + assertEqual "expect empty block" [Block []] $ + parseStatement [LBRACKET,SEMICOLON,SEMICOLON,RBRACKET] + +testExpressionIntLiteral = TestCase $ + assertEqual "expect IntLiteral" (IntegerLiteral 3) $ + parseExpression [INTEGERLITERAL 3] +testFieldWithInitialization = TestCase $ + assertEqual "expect Class with initialized field" [Class "WithInitField" [] [VariableDeclaration "int" "number" $ Just $ IntegerLiteral 3]] $ + parse [CLASS,IDENTIFIER "WithInitField",LBRACKET,INT,IDENTIFIER "number",ASSIGN,INTEGERLITERAL 3,SEMICOLON,RBRACKET] +testLocalBoolWithInitialization = TestCase $ + assertEqual "expect block with with initialized local var" [Block [LocalVariableDeclaration $ VariableDeclaration "boolean" "b" $ Just $ BooleanLiteral False]] $ + parseStatement [LBRACKET,BOOLEAN,IDENTIFIER "b",ASSIGN,BOOLLITERAL False,SEMICOLON,RBRACKET] +testFieldNullWithInitialization = TestCase $ + assertEqual "expect Class with initialized field" [Class "WithInitField" [] [VariableDeclaration "Object" "bar" $ Just NullLiteral]] $ + parse [CLASS,IDENTIFIER "WithInitField",LBRACKET,IDENTIFIER "Object",IDENTIFIER "bar",ASSIGN,NULLLITERAL,SEMICOLON,RBRACKET] +testReturnVoid = TestCase $ + assertEqual "expect block with return nothing" [Block [Return Nothing]] $ + parseStatement [LBRACKET,RETURN,SEMICOLON,RBRACKET] + +testExpressionNot = TestCase $ + assertEqual "expect expression not" (UnaryOperation Not (Reference "boar")) $ + parseExpression [NOT,IDENTIFIER "boar"] +testExpressionMinus 
= TestCase $ + assertEqual "expect expression minus" (UnaryOperation Minus (Reference "boo")) $ + parseExpression [MINUS,IDENTIFIER "boo"] +testExpressionMultiplication = TestCase $ + assertEqual "expect multiplication" (BinaryOperation Multiplication (Reference "bar") (IntegerLiteral 3)) $ + parseExpression [IDENTIFIER "bar",TIMES,INTEGERLITERAL 3] +testExpressionDivision = TestCase $ + assertEqual "expect division" (BinaryOperation Division (Reference "bar") (IntegerLiteral 3)) $ + parseExpression [IDENTIFIER "bar",DIV,INTEGERLITERAL 3] +testExpressionModulo = TestCase $ + assertEqual "expect modulo operation" (BinaryOperation Modulo (Reference "bar") (IntegerLiteral 3)) $ + parseExpression [IDENTIFIER "bar",MODULO,INTEGERLITERAL 3] +testExpressionAddition = TestCase $ + assertEqual "expect addition" (BinaryOperation Addition (Reference "bar") (IntegerLiteral 3)) $ + parseExpression [IDENTIFIER "bar",PLUS,INTEGERLITERAL 3] +testExpressionSubtraction = TestCase $ + assertEqual "expect subtraction" (BinaryOperation Subtraction (Reference "bar") (IntegerLiteral 3)) $ + parseExpression [IDENTIFIER "bar",MINUS,INTEGERLITERAL 3] +testExpressionLessThan = TestCase $ + assertEqual "expect comparision less than" (BinaryOperation CompareLessThan (Reference "bar") (IntegerLiteral 3)) $ + parseExpression [IDENTIFIER "bar",LESS,INTEGERLITERAL 3] +testExpressionGreaterThan = TestCase $ + assertEqual "expect comparision greater than" (BinaryOperation CompareGreaterThan (Reference "bar") (IntegerLiteral 3)) $ + parseExpression [IDENTIFIER "bar",GREATER,INTEGERLITERAL 3] +testExpressionLessThanEqual = TestCase $ + assertEqual "expect comparision less than or equal" (BinaryOperation CompareLessOrEqual (Reference "bar") (IntegerLiteral 3)) $ + parseExpression [IDENTIFIER "bar",LESSEQUAL,INTEGERLITERAL 3] +testExpressionGreaterThanOrEqual = TestCase $ + assertEqual "expect comparision greater than or equal" (BinaryOperation CompareGreaterOrEqual (Reference "bar") (IntegerLiteral 3)) 
$ + parseExpression [IDENTIFIER "bar",GREATEREQUAL,INTEGERLITERAL 3] +testExpressionEqual = TestCase $ + assertEqual "expect comparison equal" (BinaryOperation CompareEqual (Reference "bar") (IntegerLiteral 3)) $ + parseExpression [IDENTIFIER "bar",EQUAL,INTEGERLITERAL 3] +testExpressionNotEqual = TestCase $ + assertEqual "expect comparison not equal" (BinaryOperation CompareNotEqual (Reference "bar") (IntegerLiteral 3)) $ + parseExpression [IDENTIFIER "bar",NOTEQUAL,INTEGERLITERAL 3] +testExpressionAnd = TestCase $ + assertEqual "expect and expression" (BinaryOperation And (Reference "bar") (Reference "baz")) $ + parseExpression [IDENTIFIER "bar",AND,IDENTIFIER "baz"] +testExpressionXor = TestCase $ + assertEqual "expect xor expression" (BinaryOperation BitwiseXor (Reference "bar") (Reference "baz")) $ + parseExpression [IDENTIFIER "bar",XOR,IDENTIFIER "baz"] +testExpressionOr = TestCase $ + assertEqual "expect or expression" (BinaryOperation Or (Reference "bar") (Reference "baz")) $ + parseExpression [IDENTIFIER "bar",OR,IDENTIFIER "baz"] +testExpressionPostIncrement = TestCase $ + assertEqual "expect PostIncrement" (StatementExpressionExpression $ PostIncrement (Reference "a")) $ + parseExpression [IDENTIFIER "a",INCREMENT] +testExpressionPostDecrement = TestCase $ + assertEqual "expect PostDecrement" (StatementExpressionExpression $ PostDecrement (Reference "a")) $ + parseExpression [IDENTIFIER "a",DECREMENT] +testExpressionPreIncrement = TestCase $ + assertEqual "expect PreIncrement" (StatementExpressionExpression $ PreIncrement (Reference "a")) $ + parseExpression [INCREMENT,IDENTIFIER "a"] +testExpressionPreDecrement = TestCase $ + assertEqual "expect PreDecrement" (StatementExpressionExpression $ PreDecrement (Reference "a")) $ + parseExpression [DECREMENT,IDENTIFIER "a"] +testExpressionAssign = TestCase $ + assertEqual "expect assign 5 to a" (StatementExpressionExpression (Assignment (Reference "a") (IntegerLiteral 5))) $ + parseExpression [IDENTIFIER 
"a",ASSIGN,INTEGERLITERAL 5] +testExpressionTimesEqual = TestCase $ + assertEqual "expect assign and multiplication" (StatementExpressionExpression (Assignment (Reference "a") (BinaryOperation Multiplication (Reference "a") (IntegerLiteral 5)))) $ + parseExpression [IDENTIFIER "a",TIMESEQUAL,INTEGERLITERAL 5] +testExpressionDivideEqual = TestCase $ + assertEqual "expect assign and division" (StatementExpressionExpression (Assignment (Reference "a") (BinaryOperation Division (Reference "a") (IntegerLiteral 5)))) $ + parseExpression [IDENTIFIER "a",DIVEQUAL,INTEGERLITERAL 5] +testExpressionPlusEqual = TestCase $ + assertEqual "expect assign and addition" (StatementExpressionExpression (Assignment (Reference "a") (BinaryOperation Addition (Reference "a") (IntegerLiteral 5)))) $ + parseExpression [IDENTIFIER "a",PLUSEQUAL,INTEGERLITERAL 5] +testExpressionMinusEqual = TestCase $ + assertEqual "expect assign and subtraction" (StatementExpressionExpression (Assignment (Reference "a") (BinaryOperation Subtraction (Reference "a") (IntegerLiteral 5)))) $ + parseExpression [IDENTIFIER "a",MINUSEQUAL,INTEGERLITERAL 5] +testExpressionThis = TestCase $ + assertEqual "expect this" (Reference "this") $ + parseExpression [THIS] +testExpressionBraced = TestCase $ + assertEqual "expect braced expresssion" (BinaryOperation Multiplication (Reference "b") (BinaryOperation Addition (Reference "a") (IntegerLiteral 3))) $ + parseExpression [IDENTIFIER "b",TIMES,LBRACE,IDENTIFIER "a",PLUS,INTEGERLITERAL 3,RBRACE] + +testExpressionPrecedence = TestCase $ + assertEqual "expect times to be inner expression" (BinaryOperation Addition (BinaryOperation Multiplication (Reference "b") (Reference "a")) (IntegerLiteral 3)) $ + parseExpression [IDENTIFIER "b",TIMES,IDENTIFIER "a",PLUS,INTEGERLITERAL 3] + +testExpressionMethodCallNoParams = TestCase $ + assertEqual "expect methodcall no params" (StatementExpressionExpression (MethodCall (Reference "this") "foo" [])) $ + parseExpression [IDENTIFIER 
"foo",LBRACE,RBRACE] +testExpressionMethodCallOneParam = TestCase $ + assertEqual "expect methodcall one param" (StatementExpressionExpression (MethodCall (Reference "this") "foo" [Reference "a"])) $ + parseExpression [IDENTIFIER "foo",LBRACE,IDENTIFIER "a",RBRACE] +testExpressionMethodCallTwoParams = TestCase $ + assertEqual "expect methocall two params" (StatementExpressionExpression (MethodCall (Reference "this") "foo" [Reference "a", IntegerLiteral 5])) $ + parseExpression [IDENTIFIER "foo",LBRACE,IDENTIFIER "a",COMMA,INTEGERLITERAL 5,RBRACE] +testExpressionThisMethodCall = TestCase $ + assertEqual "expect this methocall" (StatementExpressionExpression (MethodCall (Reference "this") "foo" [])) $ + parseExpression [THIS,DOT,IDENTIFIER "foo",LBRACE,RBRACE] +testExpressionThisMethodCallParam = TestCase $ + assertEqual "expect this methocall" (StatementExpressionExpression (MethodCall (Reference "this") "foo" [Reference "x"])) $ + parseExpression [THIS,DOT,IDENTIFIER "foo",LBRACE,IDENTIFIER "x",RBRACE] +testExpressionFieldAccess = TestCase $ + assertEqual "expect NameResolution" (BinaryOperation NameResolution (Reference "this") (Reference "b")) $ + parseExpression [THIS,DOT,IDENTIFIER "b"] +testExpressionSimpleFieldAccess = TestCase $ + assertEqual "expect Reference" (Reference "a") $ + parseExpression [IDENTIFIER "a"] +testExpressionFieldSubAccess = TestCase $ + assertEqual "expect NameResolution without this" (BinaryOperation NameResolution (Reference "a") (Reference "b")) $ + parseExpression [IDENTIFIER "a",DOT,IDENTIFIER "b"] +testExpressionConstructorCall = TestCase $ + assertEqual "expect constructor call" (StatementExpressionExpression (ConstructorCall "Foo" [])) $ + parseExpression [NEW,IDENTIFIER "Foo",LBRACE,RBRACE] + +testStatementIfThen = TestCase $ + assertEqual "expect empty ifthen" [If (Reference "a") (Block [Block []]) Nothing] $ + parseStatement [IF,LBRACE,IDENTIFIER "a",RBRACE,LBRACKET,RBRACKET] +testStatementIfThenElse = TestCase $ + assertEqual 
"expect empty ifthen" [If (Reference "a") (Block [Block []]) (Just (Block [Block []]))] $ + parseStatement [IF,LBRACE,IDENTIFIER "a",RBRACE,LBRACKET,RBRACKET,ELSE,LBRACKET,RBRACKET] +testStatementWhile = TestCase $ + assertEqual "expect while" [While (Reference "a") (Block [Block []])] $ + parseStatement [WHILE,LBRACE,IDENTIFIER "a",RBRACE,LBRACKET,RBRACKET] +testStatementAssign = TestCase $ + assertEqual "expect assign 5" [StatementExpressionStatement (Assignment (Reference "a") (IntegerLiteral 5))] $ + parseStatement [IDENTIFIER "a",ASSIGN,INTEGERLITERAL 5,SEMICOLON] + +testStatementMethodCallNoParams = TestCase $ + assertEqual "expect methodcall statement no params" [StatementExpressionStatement (MethodCall (Reference "this") "foo" [])] $ + parseStatement [IDENTIFIER "foo",LBRACE,RBRACE,SEMICOLON] +testStatementConstructorCall = TestCase $ + assertEqual "expect constructor call" [StatementExpressionStatement (ConstructorCall "Foo" [])] $ + parseStatement [NEW,IDENTIFIER "Foo",LBRACE,RBRACE,SEMICOLON] +testStatementConstructorCallWithArgs = TestCase $ + assertEqual "expect constructor call" [StatementExpressionStatement (ConstructorCall "Foo" [Reference "b"])] $ + parseStatement [NEW,IDENTIFIER "Foo",LBRACE,IDENTIFIER "b",RBRACE,SEMICOLON] + +testStatementPreIncrement = TestCase $ + assertEqual "expect increment" [StatementExpressionStatement $ PostIncrement $ Reference "a"] $ + parseStatement [IDENTIFIER "a",INCREMENT,SEMICOLON] tests = TestList [ testSingleEmptyClass, testTwoEmptyClasses, testBooleanField, - testIntField + testIntField, + testCustomTypeField, + testMultipleDeclarations, + testWithModifier, + testEmptyMethod, + testEmptyPrivateMethod, + testEmptyVoidMethod, + testEmptyMethodWithParam, + testEmptyMethodWithParams, + testClassWithMethodAndField, + testClassWithConstructor, + testConstructorWithParams, + testConstructorWithStatements, + testEmptyBlock, + testBlockWithLocalVarDecl, + testBlockWithMultipleLocalVarDecls, + testNestedBlocks, + 
testBlockWithEmptyStatement, + testExpressionIntLiteral, + testFieldWithInitialization, + testLocalBoolWithInitialization, + testFieldNullWithInitialization, + testReturnVoid, + testExpressionNot, + testExpressionMinus, + testExpressionLessThan, + testExpressionGreaterThan, + testExpressionLessThanEqual, + testExpressionGreaterThanOrEqual, + testExpressionEqual, + testExpressionNotEqual, + testExpressionAnd, + testExpressionXor, + testExpressionOr, + testExpressionPostIncrement, + testExpressionPostDecrement, + testExpressionPreIncrement, + testExpressionPreDecrement, + testExpressionAssign, + testExpressionTimesEqual, + testExpressionMultiplication, testExpressionDivision, testExpressionModulo, testExpressionAddition, testExpressionSubtraction, testMultipleDeclarationSameLine, + testExpressionDivideEqual, + testExpressionPlusEqual, + testExpressionMinusEqual, + testExpressionBraced, + testExpressionThis, + testExpressionPrecedence, + testExpressionMethodCallNoParams, + testExpressionMethodCallOneParam, + testExpressionMethodCallTwoParams, + testExpressionThisMethodCall, + testExpressionThisMethodCallParam, + testExpressionFieldAccess, + testExpressionSimpleFieldAccess, + testExpressionFieldSubAccess, + testExpressionConstructorCall, + testStatementIfThen, + testStatementIfThenElse, + testStatementWhile, + testStatementAssign, + testStatementMethodCallNoParams, + testStatementConstructorCall, + testStatementConstructorCallWithArgs, + testStatementPreIncrement ] \ No newline at end of file diff --git a/project.cabal b/project.cabal index cf8bce0..7c0128b 100644 --- a/project.cabal +++ b/project.cabal @@ -12,18 +12,21 @@ executable compiler utf8-string, bytestring default-language: Haskell2010 - hs-source-dirs: src, - src/ByteCode, - src/ByteCode/ClassFile + hs-source-dirs: src build-tool-depends: alex:alex, happy:happy other-modules: Parser.Lexer, - Parser.JavaParser + Parser.JavaParser, Ast, Example, Typecheck, ByteCode.ByteUtil, ByteCode.ClassFile, - ByteCode.ClassFile.Generator, + ByteCode.Generation.Generator, + ByteCode.Generation.Assembler.ExpressionAndStatement, + 
ByteCode.Generation.Assembler.Method, + ByteCode.Generation.Builder.Class, + ByteCode.Generation.Builder.Field, + ByteCode.Generation.Builder.Method, ByteCode.Constants test-suite tests diff --git a/src/Ast.hs b/src/Ast.hs index e04cd3d..a20b8e8 100644 --- a/src/Ast.hs +++ b/src/Ast.hs @@ -20,10 +20,14 @@ data Statement deriving (Show, Eq) data StatementExpression - = Assignment Identifier Expression + = Assignment Expression Expression | ConstructorCall DataType [Expression] | MethodCall Expression Identifier [Expression] | TypedStatementExpression DataType StatementExpression + | PostIncrement Expression + | PostDecrement Expression + | PreIncrement Expression + | PreDecrement Expression deriving (Show, Eq) data BinaryOperator @@ -31,6 +35,7 @@ data BinaryOperator | Subtraction | Multiplication | Division + | Modulo | BitwiseAnd | BitwiseOr | BitwiseXor diff --git a/src/ByteCode/ClassFile.hs b/src/ByteCode/ClassFile.hs index a7b0779..358b91a 100644 --- a/src/ByteCode/ClassFile.hs +++ b/src/ByteCode/ClassFile.hs @@ -5,7 +5,8 @@ module ByteCode.ClassFile( ClassFile(..), Operation(..), serialize, - emptyClassFile + emptyClassFile, + opcodeEncodingLength ) where import Data.Word @@ -31,6 +32,7 @@ data Operation = Opiadd | Opior | Opixor | Opineg + | Opdup | Opif_icmplt Word16 | Opif_icmple Word16 | Opif_icmpgt Word16 @@ -41,6 +43,8 @@ data Operation = Opiadd | Opreturn | Opireturn | Opareturn + | Opinvokespecial Word16 + | Opgoto Word16 | Opsipush Word16 | Opldc_w Word16 | Opaload Word16 @@ -48,7 +52,7 @@ data Operation = Opiadd | Opastore Word16 | Opistore Word16 | Opputfield Word16 - | OpgetField Word16 + | Opgetfield Word16 deriving (Show, Eq) @@ -87,6 +91,37 @@ emptyClassFile = ClassFile { attributes = [] } +opcodeEncodingLength :: Operation -> Word16 +opcodeEncodingLength Opiadd = 1 +opcodeEncodingLength Opisub = 1 +opcodeEncodingLength Opimul = 1 +opcodeEncodingLength Opidiv = 1 +opcodeEncodingLength Opiand = 1 +opcodeEncodingLength Opior = 1 
+opcodeEncodingLength Opixor = 1 +opcodeEncodingLength Opineg = 1 +opcodeEncodingLength Opdup = 1 +opcodeEncodingLength (Opif_icmplt _) = 3 +opcodeEncodingLength (Opif_icmple _) = 3 +opcodeEncodingLength (Opif_icmpgt _) = 3 +opcodeEncodingLength (Opif_icmpge _) = 3 +opcodeEncodingLength (Opif_icmpeq _) = 3 +opcodeEncodingLength (Opif_icmpne _) = 3 +opcodeEncodingLength Opaconst_null = 1 +opcodeEncodingLength Opreturn = 1 +opcodeEncodingLength Opireturn = 1 +opcodeEncodingLength Opareturn = 1 +opcodeEncodingLength (Opinvokespecial _) = 3 +opcodeEncodingLength (Opgoto _) = 3 +opcodeEncodingLength (Opsipush _) = 3 +opcodeEncodingLength (Opldc_w _) = 3 +opcodeEncodingLength (Opaload _) = 4 +opcodeEncodingLength (Opiload _) = 4 +opcodeEncodingLength (Opastore _) = 4 +opcodeEncodingLength (Opistore _) = 4 +opcodeEncodingLength (Opputfield _) = 3 +opcodeEncodingLength (Opgetfield _) = 3 + class Serializable a where serialize :: a -> [Word8] @@ -108,32 +143,35 @@ instance Serializable MemberInfo where ++ concatMap serialize (memberAttributes member) instance Serializable Operation where - serialize Opiadd = [0x60] - serialize Opisub = [0x64] - serialize Opimul = [0x68] - serialize Opidiv = [0x6C] - serialize Opiand = [0x7E] - serialize Opior = [0x80] - serialize Opixor = [0x82] - serialize Opineg = [0x74] - serialize (Opif_icmplt branch) = 0xA1 : unpackWord16 branch - serialize (Opif_icmple branch) = 0xA4 : unpackWord16 branch - serialize (Opif_icmpgt branch) = 0xA3 : unpackWord16 branch - serialize (Opif_icmpge branch) = 0xA2 : unpackWord16 branch - serialize (Opif_icmpeq branch) = 0x9F : unpackWord16 branch - serialize (Opif_icmpne branch) = 0xA0 : unpackWord16 branch - serialize Opaconst_null = [0x01] - serialize Opreturn = [0xB1] - serialize Opireturn = [0xAC] - serialize Opareturn = [0xB0] - serialize (Opsipush index) = 0x11 : unpackWord16 index - serialize (Opldc_w index) = 0x13 : unpackWord16 index - serialize (Opaload index) = [0xC4, 0x19] ++ unpackWord16 index - 
serialize (Opiload index) = [0xC4, 0x15] ++ unpackWord16 index - serialize (Opastore index) = [0xC4, 0x3A] ++ unpackWord16 index - serialize (Opistore index) = [0xC4, 0x36] ++ unpackWord16 index - serialize (Opputfield index) = 0xB5 : unpackWord16 index - serialize (OpgetField index) = 0xB4 : unpackWord16 index + serialize Opiadd = [0x60] + serialize Opisub = [0x64] + serialize Opimul = [0x68] + serialize Opidiv = [0x6C] + serialize Opiand = [0x7E] + serialize Opior = [0x80] + serialize Opixor = [0x82] + serialize Opineg = [0x74] + serialize Opdup = [0x59] + serialize (Opif_icmplt branch) = 0xA1 : unpackWord16 branch + serialize (Opif_icmple branch) = 0xA4 : unpackWord16 branch + serialize (Opif_icmpgt branch) = 0xA3 : unpackWord16 branch + serialize (Opif_icmpge branch) = 0xA2 : unpackWord16 branch + serialize (Opif_icmpeq branch) = 0x9F : unpackWord16 branch + serialize (Opif_icmpne branch) = 0xA0 : unpackWord16 branch + serialize Opaconst_null = [0x01] + serialize Opreturn = [0xB1] + serialize Opireturn = [0xAC] + serialize Opareturn = [0xB0] + serialize (Opinvokespecial index) = 0xB7 : unpackWord16 index + serialize (Opgoto index) = 0xA7 : unpackWord16 index + serialize (Opsipush index) = 0x11 : unpackWord16 index + serialize (Opldc_w index) = 0x13 : unpackWord16 index + serialize (Opaload index) = [0xC4, 0x19] ++ unpackWord16 index + serialize (Opiload index) = [0xC4, 0x15] ++ unpackWord16 index + serialize (Opastore index) = [0xC4, 0x3A] ++ unpackWord16 index + serialize (Opistore index) = [0xC4, 0x36] ++ unpackWord16 index + serialize (Opputfield index) = 0xB5 : unpackWord16 index + serialize (Opgetfield index) = 0xB4 : unpackWord16 index instance Serializable Attribute where serialize (CodeAttribute { attributeMaxStack = maxStack, @@ -151,7 +189,7 @@ instance Serializable Attribute where ++ unpackWord16 0 -- attributes_count instance Serializable ClassFile where - serialize classfile = unpackWord32 0xC0FEBABE -- magic + serialize classfile = unpackWord32 
0xCAFEBABE -- magic ++ unpackWord16 0 -- minor version ++ unpackWord16 49 -- major version ++ unpackWord16 (fromIntegral (1 + length (constantPool classfile))) -- constant pool count diff --git a/src/ByteCode/ClassFile/Generator.hs b/src/ByteCode/ClassFile/Generator.hs deleted file mode 100644 index 3bdd0ec..0000000 --- a/src/ByteCode/ClassFile/Generator.hs +++ /dev/null @@ -1,169 +0,0 @@ -module ByteCode.ClassFile.Generator( - classBuilder, - datatypeDescriptor, - methodParameterDescriptor, - methodDescriptor, -) where - -import ByteCode.Constants -import ByteCode.ClassFile (ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..)) -import Ast -import Data.Char - - -type ClassFileBuilder a = a -> ClassFile -> ClassFile - - -datatypeDescriptor :: String -> String -datatypeDescriptor "void" = "V" -datatypeDescriptor "int" = "I" -datatypeDescriptor "char" = "C" -datatypeDescriptor "boolean" = "B" -datatypeDescriptor x = "L" ++ x - -methodParameterDescriptor :: String -> String -methodParameterDescriptor "void" = "V" -methodParameterDescriptor "int" = "I" -methodParameterDescriptor "char" = "C" -methodParameterDescriptor "boolean" = "B" -methodParameterDescriptor x = "L" ++ x ++ ";" - -methodDescriptor :: MethodDeclaration -> String -methodDescriptor (MethodDeclaration returntype _ parameters _) = let - parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters] - in - "(" - ++ (concat (map methodParameterDescriptor parameter_types)) - ++ ")" - ++ datatypeDescriptor returntype - -classBuilder :: ClassFileBuilder Class -classBuilder (Class name methods fields) _ = let - baseConstants = [ - ClassInfo 4, - MethodRefInfo 1 3, - NameAndTypeInfo 5 6, - Utf8Info "java/lang/Object", - Utf8Info "", - Utf8Info "()V", - Utf8Info "Code" - ] - nameConstants = [ClassInfo 9, Utf8Info name] - nakedClassFile = ClassFile { - constantPool = baseConstants ++ nameConstants, - accessFlags = accessPublic, - thisClass = 8, - superClass = 1, - fields = 
[], - methods = [], - attributes = [] - } - in - foldr methodBuilder (foldr fieldBuilder nakedClassFile fields) methods - - - -fieldBuilder :: ClassFileBuilder VariableDeclaration -fieldBuilder (VariableDeclaration datatype name _) input = let - baseIndex = 1 + length (constantPool input) - constants = [ - FieldRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)), - NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)), - Utf8Info name, - Utf8Info (datatypeDescriptor datatype) - ] - field = MemberInfo { - memberAccessFlags = accessPublic, - memberNameIndex = (fromIntegral (baseIndex + 2)), - memberDescriptorIndex = (fromIntegral (baseIndex + 3)), - memberAttributes = [] - } - in - input { - constantPool = (constantPool input) ++ constants, - fields = (fields input) ++ [field] - } - -methodBuilder :: ClassFileBuilder MethodDeclaration -methodBuilder (MethodDeclaration returntype name parameters statement) input = let - baseIndex = 1 + length (constantPool input) - constants = [ - FieldRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)), - NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)), - Utf8Info name, - Utf8Info (methodDescriptor (MethodDeclaration returntype name parameters (Block []))) - ] - --code = assembleByteCode statement - method = MemberInfo { - memberAccessFlags = accessPublic, - memberNameIndex = (fromIntegral (baseIndex + 2)), - memberDescriptorIndex = (fromIntegral (baseIndex + 3)), - memberAttributes = [ - CodeAttribute { - attributeMaxStack = 420, - attributeMaxLocals = 420, - attributeCode = [Opiadd] - } - ] - } - in - input { - constantPool = (constantPool input) ++ constants, - methods = (fields input) ++ [method] - } - -type Assembler a = a -> ([ConstantInfo], [Operation]) -> ([ConstantInfo], [Operation]) - -returnOperation :: DataType -> Operation -returnOperation dtype - | elem dtype ["int", "char", "boolean"] = Opireturn - | otherwise = Opareturn - 
-binaryOperation :: BinaryOperator -> Operation -binaryOperation Addition = Opiadd -binaryOperation Subtraction = Opisub -binaryOperation Multiplication = Opimul -binaryOperation Division = Opidiv -binaryOperation BitwiseAnd = Opiand -binaryOperation BitwiseOr = Opior -binaryOperation BitwiseXor = Opixor - -assembleMethod :: Assembler MethodDeclaration -assembleMethod (MethodDeclaration _ _ _ (Block statements)) (constants, ops) = - foldr assembleStatement (constants, ops) statements - -assembleStatement :: Assembler Statement -assembleStatement (TypedStatement stype (Return expr)) (constants, ops) = case expr of - Nothing -> (constants, ops ++ [Opreturn]) - Just expr -> let - (expr_constants, expr_ops) = assembleExpression expr (constants, ops) - in - (expr_constants, expr_ops ++ [returnOperation stype]) - -assembleExpression :: Assembler Expression -assembleExpression (TypedExpression _ (BinaryOperation op a b)) (constants, ops) - | elem op [Addition, Subtraction, Multiplication, Division, BitwiseAnd, BitwiseOr, BitwiseXor] = let - (aConstants, aOps) = assembleExpression a (constants, ops) - (bConstants, bOps) = assembleExpression b (aConstants, aOps) - in - (bConstants, bOps ++ [binaryOperation op]) -assembleExpression (TypedExpression _ (CharacterLiteral literal)) (constants, ops) = - (constants, ops ++ [Opsipush (fromIntegral (ord literal))]) -assembleExpression (TypedExpression _ (BooleanLiteral literal)) (constants, ops) = - (constants, ops ++ [Opsipush (if literal then 1 else 0)]) -assembleExpression (TypedExpression _ (IntegerLiteral literal)) (constants, ops) - | literal <= 32767 && literal >= -32768 = (constants, ops ++ [Opsipush (fromIntegral literal)]) - | otherwise = (constants ++ [IntegerInfo (fromIntegral literal)], ops ++ [Opldc_w (fromIntegral (1 + length constants))]) -assembleExpression (TypedExpression _ NullLiteral) (constants, ops) = - (constants, ops ++ [Opaconst_null]) -assembleExpression (TypedExpression etype (UnaryOperation Not expr)) 
-- | Bytecode assembly for typed expressions, statement expressions and
-- statements.
--
-- Every assembler receives the state built so far -- the constant pool, the
-- operations emitted so far and the local variable names (list position is
-- the slot number) -- and returns that state extended by one AST node.
module ByteCode.Generation.Assembler.ExpressionAndStatement where

import Ast
import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength)
import ByteCode.Generation.Generator
import Data.List
import Data.Char
import ByteCode.Generation.Builder.Field

-- | Whether values of the given type are handled with the integer
-- instruction family (@iload@, @istore@) rather than the reference one.
usesIntInstructions :: String -> Bool
usesIntInstructions dtype = dtype `elem` ["char", "boolean", "int"]

-- | Load instruction for the local variable of the given type and slot.
loadLocal :: String -> Int -> Operation
loadLocal dtype slot
  | usesIntInstructions dtype = Opiload (fromIntegral slot)
  | otherwise = Opaload (fromIntegral slot)

-- | Store instruction for the local variable of the given type and slot.
storeLocal :: String -> Int -> Operation
storeLocal dtype slot
  | usesIntInstructions dtype = Opistore (fromIntegral slot)
  | otherwise = Opastore (fromIntegral slot)

-- | Append the code that evaluates a typed expression.
--
-- Expression forms without an equation here end in the final catch-all,
-- which raises an @unimplemented@ error that shows the expression.
assembleExpression :: Assembler Expression
assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b))
  | op `elem` [Addition, Subtraction, Multiplication, Division, BitwiseAnd, BitwiseOr, BitwiseXor] =
      (bConstants, bOps ++ [binaryOperation op], lvars)
  | op `elem` [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] =
      -- Materialise the comparison as 0/1: the branch skips to "push 1",
      -- otherwise "push 0" runs and the goto jumps over "push 1".
      -- The fixed offsets 9 and 6 presume each of these four instructions
      -- encodes to 3 bytes -- TODO confirm against opcodeEncodingLength.
      (bConstants, bOps ++ [comparisonOperation op 9, Opsipush 0, Opgoto 6, Opsipush 1], lvars)
  where
    -- Operands are assembled left to right; the right operand sees the
    -- constants and operations produced by the left one.
    (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
    (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b

assembleExpression (constants, ops, lvars) (TypedExpression _ (CharacterLiteral literal)) =
  (constants, ops ++ [Opsipush (fromIntegral (ord literal))], lvars)

assembleExpression (constants, ops, lvars) (TypedExpression _ (BooleanLiteral literal)) =
  (constants, ops ++ [Opsipush (if literal then 1 else 0)], lvars)

assembleExpression (constants, ops, lvars) (TypedExpression _ (IntegerLiteral literal))
  -- Values in the signed 16-bit range are pushed directly.
  | literal >= -32768 && literal <= 32767 =
      (constants, ops ++ [Opsipush (fromIntegral literal)], lvars)
  -- Anything larger becomes a new constant pool entry (1-based index).
  | otherwise =
      ( constants ++ [IntegerInfo (fromIntegral literal)]
      , ops ++ [Opldc_w (fromIntegral (1 + length constants))]
      , lvars
      )

assembleExpression (constants, ops, lvars) (TypedExpression _ NullLiteral) =
  (constants, ops ++ [Opaconst_null], lvars)

assembleExpression (constants, ops, lvars) (TypedExpression etype (UnaryOperation Not expr)) =
  case etype of
    -- XOR with an all-ones word (-1). The previous mask 0x7FFFFFFF left
    -- the sign bit untouched, so the complement of 0 came out as
    -- 2147483647 instead of -1.
    "int" ->
      ( exprConstants ++ [IntegerInfo (fromIntegral (-1 :: Int))]
      , exprOps ++ [Opldc_w newConstant, Opixor]
      , lvars
      )
    "char" -> (exprConstants, exprOps ++ [Opsipush 0xFFFF, Opixor], lvars)
    "boolean" -> (exprConstants, exprOps ++ [Opsipush 0x01, Opixor], lvars)
    -- Previously a non-exhaustive case; now a readable error.
    other -> error ("Cannot apply Not to expression of type: " ++ other)
  where
    (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
    newConstant = fromIntegral (1 + length exprConstants)

assembleExpression (constants, ops, lvars) (TypedExpression _ (UnaryOperation Minus expr)) =
  (exprConstants, exprOps ++ [Opineg], lvars)
  where
    (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr

assembleExpression (constants, ops, lvars) (TypedExpression _ (FieldVariable name)) =
  case findFieldIndex constants name of
    -- Slot 0 holds the first entry of the local variable list ("this"
    -- when the caller seeds the list that way).
    Just index -> (constants, ops ++ [Opaload 0, Opgetfield (fromIntegral index)], lvars)
    Nothing -> error ("No such field found in constant pool: " ++ name)

assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name)) =
  case elemIndex name lvars of
    Just slot -> (constants, ops ++ [loadLocal dtype slot], lvars)
    Nothing -> error ("No such local variable found in local variable pool: " ++ name)

assembleExpression (constants, ops, lvars) (TypedExpression _ (StatementExpressionExpression stmtexp)) =
  assembleStatementExpression (constants, ops, lvars) stmtexp

assembleExpression _ expr = error ("unimplemented: " ++ show expr)


-- | Shared code for @++x@, @x++@, @--x@ and @x--@: load the target, run
-- the given step operations on the loaded value and store the result back.
--
-- NOTE(review): for field targets the step operations end with two values
-- above the object reference (or interleave a copy with it) before
-- @putfield@ runs. This mirrors the original sequences; it looks like a
-- @dup_x1@-style shuffle is needed -- confirm against the JVM operand
-- stack rules.
assembleStep :: [Operation] -> Assembler Expression
assembleStep stepOps (constants, ops, lvars) target@(TypedExpression _ (LocalVariable name)) =
  case elemIndex name lvars of
    Just slot ->
      let (loadConstants, loadOps, _) = assembleExpression (constants, ops, lvars) target
       in (loadConstants, loadOps ++ stepOps ++ [Opistore (fromIntegral slot)], lvars)
    Nothing -> error ("No such local variable found in local variable pool: " ++ name)
assembleStep stepOps (constants, ops, lvars) target@(TypedExpression _ (FieldVariable name)) =
  case findFieldIndex constants name of
    Just index ->
      let (loadConstants, loadOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) target
       in (loadConstants, loadOps ++ stepOps ++ [Opputfield (fromIntegral index)], lvars)
    Nothing -> error ("No such field variable found in field variable pool: " ++ name)
assembleStep _ _ target = error ("unimplemented: " ++ show target)


-- | Append the code for an assignment or an increment/decrement.
--
-- TODO untested
assembleStatementExpression :: Assembler StatementExpression
assembleStatementExpression
  (constants, ops, lvars)
  (TypedStatementExpression _ (Assignment (TypedExpression dtype (LocalVariable name)) expr)) =
    case elemIndex name lvars of
      Just slot -> (valueConstants, valueOps ++ [storeLocal dtype slot], lvars)
      Nothing -> error ("No such local variable found in local variable pool: " ++ name)
  where
    (valueConstants, valueOps, _) = assembleExpression (constants, ops, lvars) expr

assembleStatementExpression
  (constants, ops, lvars)
  (TypedStatementExpression _ (Assignment (TypedExpression _ (FieldVariable name)) expr)) =
    case findFieldIndex constants name of
      Just index -> (valueConstants, valueOps ++ [Opputfield (fromIntegral index)], lvars)
      Nothing -> error ("No such field variable found in constant pool: " ++ name)
  where
    -- The object reference has to be on the stack below the value.
    (valueConstants, valueOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr

-- Pre-forms update first and duplicate the new value; post-forms
-- duplicate the old value first and update the copy.
assembleStatementExpression state (TypedStatementExpression _ (PreIncrement target)) =
  assembleStep [Opsipush 1, Opiadd, Opdup] state target
assembleStatementExpression state (TypedStatementExpression _ (PostIncrement target)) =
  assembleStep [Opdup, Opsipush 1, Opiadd] state target
-- The original emitted [sipush 1, iadd, isub] here: it added instead of
-- subtracting and then subtracted with a missing operand.
assembleStatementExpression state (TypedStatementExpression _ (PreDecrement target)) =
  assembleStep [Opsipush 1, Opisub, Opdup] state target
assembleStatementExpression state (TypedStatementExpression _ (PostDecrement target)) =
  assembleStep [Opdup, Opsipush 1, Opisub] state target

-- Previously any other statement expression died with a pattern match
-- failure; report what was encountered instead.
assembleStatementExpression _ stmtexp = error ("unimplemented: " ++ show stmtexp)


-- | Append the code for a statement.
--
-- Conditions of @if@ and @while@ are assembled into separate operation
-- lists so that their encoded lengths can be used as relative jump
-- offsets.
assembleStatement :: Assembler Statement
assembleStatement (constants, ops, lvars) (TypedStatement stype (Return result)) =
  case result of
    Nothing -> (constants, ops ++ [Opreturn], lvars)
    Just expr ->
      let (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
       in (exprConstants, exprOps ++ [returnOperation stype], lvars)

assembleStatement (constants, ops, lvars) (TypedStatement _ (Block statements)) =
  foldl' assembleStatement (constants, ops, lvars) statements

assembleStatement (constants, ops, lvars) (TypedStatement _ (If condition thenStmt elseStmt)) =
  -- Return the pool produced by the LAST assembled branch. The original
  -- returned the pool after the then-branch and silently dropped every
  -- constant the else-branch had added.
  ( elseConstants
  , ops ++ condOps ++ [Opsipush 0, Opif_icmpeq skipThen] ++ thenOps ++ [Opgoto skipElse] ++ elseOps
  , lvars
  )
  where
    (condConstants, condOps, _) = assembleExpression (constants, [], lvars) condition
    (thenConstants, thenOps, _) = assembleStatement (condConstants, [], lvars) thenStmt
    (elseConstants, elseOps, _) = case elseStmt of
      Nothing -> (thenConstants, [], lvars)
      Just stmt -> assembleStatement (thenConstants, [], lvars) stmt
    -- +6 because we insert 2 jumps: the conditional branch itself and the
    -- goto that ends the then-branch.
    skipThen = sum (map opcodeEncodingLength thenOps) + 6
    -- +3 to account for the goto at the end of the then-branch.
    skipElse = sum (map opcodeEncodingLength elseOps) + 3

assembleStatement (constants, ops, lvars) (TypedStatement _ (While condition body)) =
  ( bodyConstants
  , ops ++ condOps ++ [Opsipush 0, Opif_icmpeq exitOffset] ++ bodyOps ++ [Opgoto (-loopLength)]
  , lvars
  )
  where
    (condConstants, condOps, _) = assembleExpression (constants, [], lvars) condition
    (bodyConstants, bodyOps, _) = assembleStatement (condConstants, [], lvars) body
    -- +6 because we insert 2 jumps: the conditional branch and the goto
    -- back to the condition (the earlier comment said +3, the code never did).
    exitOffset = sum (map opcodeEncodingLength bodyOps) + 6
    -- Distance from the closing goto back to the start of the condition.
    loopLength = exitOffset + sum (map opcodeEncodingLength condOps)

assembleStatement (constants, ops, lvars) (TypedStatement _ (LocalVariableDeclaration (VariableDeclaration dtype name initializer))) =
  -- The new variable takes the next free slot, i.e. the end of the list.
  (initConstants, initOps ++ [storeLocal dtype (length lvars)], lvars ++ [name])
  where
    (initConstants, initOps, _) = case initializer of
      Just expr -> assembleExpression (constants, ops, lvars) expr
      -- Without an initializer the slot is filled with 0 or null.
      Nothing ->
        (constants, ops ++ [if usesIntInstructions dtype then Opsipush 0 else Opaconst_null], lvars)

assembleStatement (constants, ops, lvars) (TypedStatement _ (StatementExpressionStatement expr)) =
  assembleStatementExpression (constants, ops, lvars) expr

assembleStatement _ stmt = error ("Not yet implemented: " ++ show stmt)
-- | Turns a typed class declaration into a class file: base constant
-- pool, then fields, then method entries, then assembled method bodies.
module ByteCode.Generation.Builder.Class where

import ByteCode.Generation.Builder.Field
import ByteCode.Generation.Builder.Method
import ByteCode.Generation.Generator
import Ast
import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength)
import ByteCode.Constants

-- | Append an empty @void@ constructor unless the list already contains a
-- method carrying the constructor marker name (the empty string here).
--
-- NOTE(review): the marker is an empty string; the JVM names instance
-- initializers @<init>@ -- confirm the literal is what was intended.
injectDefaultConstructor :: [MethodDeclaration] -> [MethodDeclaration]
injectDefaultConstructor declared =
  if any isConstructor declared
    then declared
    else declared ++ [defaultConstructor]
  where
    isConstructor (MethodDeclaration _ name _ _) = name == ""
    defaultConstructor = MethodDeclaration "void" "" [] (TypedStatement "void" (Block []))


-- | Build the complete class file for one class. The class file passed as
-- second argument is ignored; construction always starts from scratch.
classBuilder :: ClassFileBuilder Class
classBuilder (Class name declaredMethods declaredFields) _ =
  foldr methodAssembler withMethodEntries allMethods
  where
    allMethods = injectDefaultConstructor declaredMethods

    -- Each stage folds from the right, so the last declaration is the
    -- first one handed to the builder.
    withFields = foldr fieldBuilder skeleton declaredFields
    withMethodEntries = foldr methodBuilder withFields allMethods

    -- Entries 1-7: the superclass, a method reference on it and the
    -- strings those entries point at, followed by "Code".
    superConstants =
      [ ClassInfo 4
      , MethodRefInfo 1 3
      , NameAndTypeInfo 5 6
      , Utf8Info "java/lang/Object"
      , Utf8Info ""
      , Utf8Info "()V"
      , Utf8Info "Code"
      ]

    -- Entries 8-9: this class and its name; thisClass below points at 8.
    selfConstants = [ClassInfo 9, Utf8Info name]

    skeleton =
      ClassFile
        { constantPool = superConstants ++ selfConstants
        , accessFlags = accessPublic
        , thisClass = 8
        , superClass = 1
        , fields = []
        , methods = []
        , attributes = []
        }
-- | Method entries of a class file: descriptor construction, lookup of a
-- method by name, creation of the (still empty) member entry and finally
-- attaching the assembled bytecode to it.
module ByteCode.Generation.Builder.Method where

import Ast
import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength)
import ByteCode.Generation.Generator
import ByteCode.Generation.Assembler.Method
import ByteCode.Constants
import Data.List

-- | Descriptor of a method: the parameter descriptors in parentheses
-- followed by the descriptor of the return type, e.g. @(II)I@.
methodDescriptor :: MethodDeclaration -> String
methodDescriptor (MethodDeclaration returntype _ parameters _) =
  "(" ++ concatMap methodParameterDescriptor parameterTypes ++ ")" ++ methodParameterDescriptor returntype
  where
    parameterTypes = [datatype | ParameterDeclaration datatype _ <- parameters]

-- | Descriptor of a single type. Every name that is not one of the
-- built-in types is treated as a class name.
methodParameterDescriptor :: String -> String
methodParameterDescriptor "void" = "V"
methodParameterDescriptor "int" = "I"
methodParameterDescriptor "char" = "C"
-- "Z" is the JVM descriptor for boolean; the previous "B" denotes byte.
methodParameterDescriptor "boolean" = "Z"
methodParameterDescriptor x = "L" ++ x ++ ";"

-- | A member counts as a method when its descriptor contains a @(@.
memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool
memberInfoIsMethod constants info = '(' `elem` memberInfoDescriptor constants info

-- | Position of the first method with the given name in the method list
-- of the class file, if any. Only the name is compared.
findMethodIndex :: ClassFile -> String -> Maybe Int
findMethodIndex classFile name = findIndex isWanted (methods classFile)
  where
    constants = constantPool classFile
    isWanted method = memberInfoIsMethod constants method && memberInfoName constants method == name


-- | Register a method: append its name and descriptor to the constant
-- pool and add a public member entry without attributes that points at
-- these two new constants.
methodBuilder :: ClassFileBuilder MethodDeclaration
methodBuilder declaration@(MethodDeclaration _ name _ _) input =
  input
    { constantPool = constantPool input ++ newConstants
    , methods = methods input ++ [method]
    }
  where
    -- Constant pool indices are 1-based, so the next free one is length + 1.
    baseIndex = 1 + length (constantPool input)
    newConstants =
      [ Utf8Info name
      , Utf8Info (methodDescriptor declaration)
      ]
    method =
      MemberInfo
        { memberAccessFlags = accessPublic
        , memberNameIndex = fromIntegral baseIndex
        , memberDescriptorIndex = fromIntegral (baseIndex + 1)
        , memberAttributes = []
        }


-- | Assemble the body of an already registered method and attach it to
-- the method entry as its code attribute.
--
-- Raises an error when no method entry with this name exists.
methodAssembler :: ClassFileBuilder MethodDeclaration
methodAssembler declaration@(MethodDeclaration _ name parameters _) input =
  case findMethodIndex input name of
    Nothing -> error ("Cannot find method entry in method pool for method: " ++ name)
    Just index ->
      let -- Slot 0 is "this", the parameters follow in declaration order.
          paramNames = "this" : [paramName | ParameterDeclaration _ paramName <- parameters]
          -- index comes from findMethodIndex, so the split always has a head.
          (pre, method : post) = splitAt index (methods input)
          (assembledConstants, bytecode, _) =
            assembleMethod (constantPool input, [], paramNames) declaration
          assembledMethod =
            method
              { memberAttributes =
                  [ CodeAttribute
                      { attributeMaxStack = 420
                      , attributeMaxLocals = 420
                      , attributeCode = bytecode
                      }
                  ]
              }
       in input
            { -- Keep the constants the assembler appended (e.g. large
              -- integer literals). They were dropped before, which left
              -- the emitted code pointing past the end of the pool.
              constantPool = assembledConstants
            , methods = pre ++ (assembledMethod : post)
            }
-- | Integer instruction for an arithmetic or bitwise operator.
--
-- Raises an error for every other operator instead of failing with an
-- anonymous pattern match failure.
binaryOperation :: BinaryOperator -> Operation
binaryOperation Addition = Opiadd
binaryOperation Subtraction = Opisub
binaryOperation Multiplication = Opimul
binaryOperation Division = Opidiv
binaryOperation BitwiseAnd = Opiand
binaryOperation BitwiseOr = Opior
binaryOperation BitwiseXor = Opixor
binaryOperation other =
  error ("No arithmetic instruction for binary operator: " ++ show other)

-- | Conditional branch instruction for a comparison operator, jumping to
-- the given branch location when the comparison holds.
--
-- Raises an error for every operator that is not a comparison.
comparisonOperation :: BinaryOperator -> Word16 -> Operation
comparisonOperation CompareEqual branchLocation = Opif_icmpeq branchLocation
comparisonOperation CompareNotEqual branchLocation = Opif_icmpne branchLocation
comparisonOperation CompareLessThan branchLocation = Opif_icmplt branchLocation
comparisonOperation CompareLessOrEqual branchLocation = Opif_icmple branchLocation
comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation
comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation
comparisonOperation other _ =
  error ("No comparison instruction for binary operator: " ++ show other)
NameResolution (Reference "bob2") (Reference "age") exampleBlockResolution :: Statement exampleBlockResolution = Block [ @@ -80,7 +80,7 @@ exampleMethodCallAndAssignment = Block [ LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))), StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]), LocalVariableDeclaration (VariableDeclaration "int" "a" Nothing), - StatementExpressionStatement (Assignment "a" (Reference "age")) + StatementExpressionStatement (Assignment (Reference "a") (Reference "age")) ] @@ -89,34 +89,57 @@ exampleMethodCallAndAssignmentFail = Block [ LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))), LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))), StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]), - StatementExpressionStatement (Assignment "a" (Reference "age")) + StatementExpressionStatement (Assignment (Reference "a") (Reference "age")) ] +exampleNameResolutionAssignment :: Statement +exampleNameResolutionAssignment = Block [ + LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))), + StatementExpressionStatement (Assignment (BinaryOperation NameResolution (Reference "bob") (Reference "age")) (IntegerLiteral 30)) + ] + +exampleCharIntOperation :: Expression +exampleCharIntOperation = BinaryOperation Addition (CharacterLiteral 'a') (IntegerLiteral 1) + +exampleNullDeclaration :: Statement +exampleNullDeclaration = LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just NullLiteral)) + +exampleNullDeclarationFail :: Statement +exampleNullDeclarationFail = LocalVariableDeclaration (VariableDeclaration "int" 
"a" (Just NullLiteral)) + +exampleNullAssignment :: Statement +exampleNullAssignment = StatementExpressionStatement (Assignment (Reference "a") NullLiteral) + +exampleIncrement :: Statement +exampleIncrement = StatementExpressionStatement (PostIncrement (Reference "a")) + testClasses :: [Class] testClasses = [ Class "Person" [ - MethodDeclaration "Person" "Person" [ParameterDeclaration "int" "initialAge"] + MethodDeclaration "Person" "Person" [ParameterDeclaration "int" "initialAge"] (Block [ Return (Just (Reference "this")) ]), - MethodDeclaration "void" "setAge" [ParameterDeclaration "int" "newAge"] + MethodDeclaration "void" "setAge" [ParameterDeclaration "int" "newAge"] (Block [ LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (Reference "newAge"))) ]), - MethodDeclaration "int" "getAge" [] + MethodDeclaration "int" "getAge" [] (Return (Just (Reference "age"))) ] [ VariableDeclaration "int" "age" Nothing -- initially unassigned ], Class "Main" [ - MethodDeclaration "int" "main" [] + MethodDeclaration "int" "main" [] (Block [ LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 25])))), StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]), - LocalVariableDeclaration (VariableDeclaration "int" "bobAge" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))), + LocalVariableDeclaration (VariableDeclaration "int" "bobAge" (Just (StatementExpressionExpression (MethodCall (Reference "bob2") "getAge" [])))), Return (Just (Reference "bobAge")) ]) - ] [] + ] [ + VariableDeclaration "Person" "bob2" Nothing + ] ] runTypeCheck :: IO () @@ -151,7 +174,7 @@ runTypeCheck = do catch (do print "=====================================================================================" - evaluatedNameResolution <- evaluate (typeCheckExpression exampleNameResolution [("b", "Person")] sampleClasses) + 
evaluatedNameResolution <- evaluate (typeCheckExpression exampleNameResolution [("this", "Main")] testClasses) printSuccess "Type checking of name resolution completed successfully" printResult "Result Name Resolution:" evaluatedNameResolution ) handleError @@ -189,7 +212,7 @@ runTypeCheck = do let mainClass = fromJust $ find (\(Class className _ _) -> className == "Main") testClasses case mainClass of Class _ [mainMethod] _ -> do - let result = typeCheckMethodDeclaration mainMethod [] testClasses + let result = typeCheckMethodDeclaration mainMethod [("this", "Main")] testClasses printSuccess "Full program type checking completed successfully." printResult "Main method result:" result ) handleError @@ -201,3 +224,44 @@ runTypeCheck = do printResult "Typed Program:" typedProgram ) handleError + catch (do + print "=====================================================================================" + typedAssignment <- evaluate (typeCheckStatement exampleNameResolutionAssignment [] sampleClasses) + printSuccess "Type checking of name resolution assignment completed successfully" + printResult "Result Name Resolution Assignment:" typedAssignment + ) handleError + + catch (do + print "=====================================================================================" + evaluatedCharIntOperation <- evaluate (typeCheckExpression exampleCharIntOperation [] sampleClasses) + printSuccess "Type checking of char int operation completed successfully" + printResult "Result Char Int Operation:" evaluatedCharIntOperation + ) handleError + + catch (do + print "=====================================================================================" + evaluatedNullDeclaration <- evaluate (typeCheckStatement exampleNullDeclaration [] sampleClasses) + printSuccess "Type checking of null declaration completed successfully" + printResult "Result Null Declaration:" evaluatedNullDeclaration + ) handleError + + catch (do + print 
"=====================================================================================" + evaluatedNullDeclarationFail <- evaluate (typeCheckStatement exampleNullDeclarationFail [] sampleClasses) + printSuccess "Type checking of null declaration failed" + printResult "Result Null Declaration:" evaluatedNullDeclarationFail + ) handleError + + catch (do + print "=====================================================================================" + evaluatedNullAssignment <- evaluate (typeCheckStatement exampleNullAssignment [("a", "Person")] sampleClasses) + printSuccess "Type checking of null assignment completed successfully" + printResult "Result Null Assignment:" evaluatedNullAssignment + ) handleError + + catch (do + print "=====================================================================================" + evaluatedIncrement <- evaluate (typeCheckStatement exampleIncrement [("a", "int")] sampleClasses) + printSuccess "Type checking of increment completed successfully" + printResult "Result Increment:" evaluatedIncrement + ) handleError diff --git a/src/Main.hs b/src/Main.hs index 5ee22ff..588efc2 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -2,7 +2,19 @@ module Main where import Example import Typecheck +import Parser.Lexer (alexScanTokens) +import Parser.JavaParser +import ByteCode.Generation.Generator +import ByteCode.Generation.Builder.Class +import ByteCode.ClassFile +import Data.ByteString (pack, writeFile) main = do - Example.runTypeCheck - + file <- readFile "Testklasse.java" + + let untypedAST = parse $ alexScanTokens file + let typedAST = head (typeCheckCompilationUnit untypedAST) + let abstractClassFile = classBuilder typedAST emptyClassFile + let assembledClassFile = pack (serialize abstractClassFile) + + Data.ByteString.writeFile "Testklasse.class" assembledClassFile diff --git a/src/Parser/JavaParser.y b/src/Parser/JavaParser.y index bfe74b2..aa22bb8 100644 --- a/src/Parser/JavaParser.y +++ b/src/Parser/JavaParser.y @@ -1,12 +1,15 @@ { 
-module Parser.JavaParser (parse) where +module Parser.JavaParser (parse, parseStatement, parseExpression) where import Ast import Parser.Lexer } %name parse +%name parseStatement statement +%name parseExpression expression %tokentype { Token } %error { parseError } +%errorhandlertype explist %token BOOLEAN { BOOLEAN } @@ -41,7 +44,6 @@ import Parser.Lexer JNULL { NULLLITERAL } BOOLLITERAL { BOOLLITERAL $$ } DIV { DIV } - LOGICALOR { OR } NOTEQUAL { NOTEQUAL } INSTANCEOF { INSTANCEOF } ANDEQUAL { ANDEQUAL } @@ -80,17 +82,17 @@ compilationunit : typedeclarations { $1 } typedeclarations : typedeclaration { [$1] } | typedeclarations typedeclaration { $1 ++ [$2] } -name : qualifiedname { } - | simplename { } +name : simplename { Reference $1 } + | qualifiedname { $1 } typedeclaration : classdeclaration { $1 } -qualifiedname : name DOT IDENTIFIER { } +qualifiedname : name DOT IDENTIFIER { BinaryOperation NameResolution $1 (Reference $3) } -simplename : IDENTIFIER { } +simplename : IDENTIFIER { $1 } classdeclaration : CLASS IDENTIFIER classbody { case $3 of (methods, fields) -> Class $2 methods fields } - -- | modifiers CLASS IDENTIFIER classbody { case $4 of (methods, fields) -> Class $3 methods fields } + | modifiers CLASS IDENTIFIER classbody { case $4 of (methods, fields) -> Class $3 methods fields } classbody : LBRACKET RBRACKET { ([], []) } | LBRACKET classbodydeclarations RBRACKET { $2 } @@ -103,11 +105,11 @@ classbodydeclarations : classbodydeclaration { MethodDecl method -> ([method], []) FieldDecls fields -> ([], fields) } - -- | classbodydeclarations classbodydeclaration { - -- case ($1, $2) of - -- ((methods, fields), MethodDecl method) -> ((methods ++ [method]), fields) - -- ((methods, fields), FieldDecl field) -> (methods, (fields ++ [field])) - -- } + | classbodydeclarations classbodydeclaration { + case ($1, $2) of + ((methods, fields), MethodDecl method) -> ((methods ++ [method]), fields) + ((methods, fields), FieldDecls newFields) -> (methods, (fields 
++ newFields)) + } modifier : PUBLIC { } | PROTECTED { } @@ -115,54 +117,54 @@ modifier : PUBLIC { } | STATIC { } | ABSTRACT { } -classtype : classorinterfacetype{ } +classtype : classorinterfacetype { $1 } classbodydeclaration : classmemberdeclaration { $1 } - -- | constructordeclaration { FieldDecl $ VariableDeclaration "int" "a" Nothing } -- TODO + | constructordeclaration { $1 } -classorinterfacetype : name{ } +classorinterfacetype : simplename { $1 } classmemberdeclaration : fielddeclaration { $1 } - -- | methoddeclaration { } + | methoddeclaration { $1 } -constructordeclaration : constructordeclarator constructorbody { } - | modifiers constructordeclarator constructorbody { } +constructordeclaration : constructordeclarator constructorbody { MethodDecl $ MethodDeclaration "void" "" $1 $2 } + | modifiers constructordeclarator constructorbody { MethodDecl $ MethodDeclaration "void" "" $2 $3 } fielddeclaration : type variabledeclarators SEMICOLON { FieldDecls $ map (convertDeclarator $1) $2 } - -- | modifiers type variabledeclarators SEMICOLON {} + | modifiers type variabledeclarators SEMICOLON { FieldDecls $ map (convertDeclarator $2) $3 } -methoddeclaration : methodheader methodbody { } +methoddeclaration : methodheader methodbody { case $1 of (returnType, (name, parameters)) -> MethodDecl (MethodDeclaration returnType name parameters $2) } -block : LBRACKET RBRACKET { } - | LBRACKET blockstatements RBRACKET { } +block : LBRACKET RBRACKET { Block [] } + | LBRACKET blockstatements RBRACKET { Block $2 } -constructordeclarator : simplename LBRACE RBRACE { } - | simplename LBRACE formalparameterlist RBRACE { } +constructordeclarator : simplename LBRACE RBRACE { [] } + | simplename LBRACE formalparameterlist RBRACE { $3 } -constructorbody : LBRACKET RBRACKET { } - | LBRACKET explicitconstructorinvocation RBRACKET { } - | LBRACKET blockstatements RBRACKET { } - | LBRACKET explicitconstructorinvocation blockstatements RBRACKET { } +constructorbody : LBRACKET RBRACKET 
{ Block [] } + -- | LBRACKET explicitconstructorinvocation RBRACKET { } + | LBRACKET blockstatements RBRACKET { Block $2 } + -- | LBRACKET explicitconstructorinvocation blockstatements RBRACKET { } -methodheader : type methoddeclarator { } - | modifiers type methoddeclarator { } - | VOID methoddeclarator { } - | modifiers VOID methoddeclarator { } +methodheader : type methoddeclarator { ($1, $2) } + | modifiers type methoddeclarator { ($2, $3) } + | VOID methoddeclarator { ("void", $2) } + | modifiers VOID methoddeclarator { ("void", $3)} type : primitivetype { $1 } - -- | referencetype { } + | referencetype { $1 } variabledeclarators : variabledeclarator { [$1] } - -- | variabledeclarators COMMA variabledeclarator { $1 ++ [$3] } + | variabledeclarators COMMA variabledeclarator { $1 ++ [$3] } -methodbody : block { } - | SEMICOLON { } +methodbody : block { $1 } + | SEMICOLON { Block [] } -blockstatements : blockstatement { } - | blockstatements blockstatement { } +blockstatements : blockstatement { $1 } + | blockstatements blockstatement { $1 ++ $2} -formalparameterlist : formalparameter { } - | formalparameterlist COMMA formalparameter{ } +formalparameterlist : formalparameter { [$1] } + | formalparameterlist COMMA formalparameter { $1 ++ [$3] } explicitconstructorinvocation : THIS LBRACE RBRACE SEMICOLON { } | THIS LBRACE argumentlist RBRACE SEMICOLON { } @@ -170,192 +172,196 @@ explicitconstructorinvocation : THIS LBRACE RBRACE SEMICOLON { } classtypelist : classtype { } | classtypelist COMMA classtype { } -methoddeclarator : IDENTIFIER LBRACE RBRACE { } - | IDENTIFIER LBRACE formalparameterlist RBRACE { } +methoddeclarator : IDENTIFIER LBRACE RBRACE { ($1, []) } + | IDENTIFIER LBRACE formalparameterlist RBRACE { ($1, $3) } primitivetype : BOOLEAN { "boolean" } | numerictype { $1 } -referencetype : classorinterfacetype { } +referencetype : classorinterfacetype { $1 } variabledeclarator : variabledeclaratorid { Declarator $1 Nothing } - -- | variabledeclaratorid 
ASSIGN variableinitializer { Declarator $1 Nothing } -- TODO + | variabledeclaratorid ASSIGN variableinitializer { Declarator $1 (Just $3) } -blockstatement : localvariabledeclarationstatement { } - | statement { } +blockstatement : localvariabledeclarationstatement { $1 } -- expected type statement + | statement { $1 } -formalparameter : type variabledeclaratorid { } +formalparameter : type variabledeclaratorid { ParameterDeclaration $1 $2 } -argumentlist : expression { } - | argumentlist COMMA expression { } +argumentlist : expression { [$1] } + | argumentlist COMMA expression { $1 ++ [$3] } numerictype : integraltype { $1 } variabledeclaratorid : IDENTIFIER { $1 } -variableinitializer : expression { } +variableinitializer : expression { $1 } -localvariabledeclarationstatement : localvariabledeclaration SEMICOLON { } +localvariabledeclarationstatement : localvariabledeclaration SEMICOLON { $1 } -statement : statementwithouttrailingsubstatement{ } - | ifthenstatement { } - | ifthenelsestatement { } - | whilestatement { } +statement : statementwithouttrailingsubstatement{ $1 } -- statement returns a list of statements + | ifthenstatement { [$1] } + | ifthenelsestatement { [$1] } + | whilestatement { [$1] } -expression : assignmentexpression { } +expression : assignmentexpression { $1 } integraltype : INT { "int" } | CHAR { "char" } -localvariabledeclaration : type variabledeclarators { } +localvariabledeclaration : type variabledeclarators { map LocalVariableDeclaration $ map (convertDeclarator $1) $2 } -statementwithouttrailingsubstatement : block { } - | emptystatement { } - | expressionstatement { } - | returnstatement { } +statementwithouttrailingsubstatement : block { [$1] } + | emptystatement { [] } + | expressionstatement { [$1] } + | returnstatement { [$1] } -ifthenstatement : IF LBRACE expression RBRACE statement { } +ifthenstatement : IF LBRACE expression RBRACE statement { If $3 (Block $5) Nothing } -ifthenelsestatement : IF LBRACE expression RBRACE 
statementnoshortif ELSE statement { } +ifthenelsestatement : IF LBRACE expression RBRACE statementnoshortif ELSE statement { If $3 (Block $5) (Just (Block $7)) } -whilestatement : WHILE LBRACE expression RBRACE statement { } +whilestatement : WHILE LBRACE expression RBRACE statement { While $3 (Block $5) } -assignmentexpression : conditionalexpression { } - | assignment{ } +assignmentexpression : conditionalexpression { $1 } + | assignment { StatementExpressionExpression $1 } -emptystatement : SEMICOLON { } +emptystatement : SEMICOLON { Block [] } -expressionstatement : statementexpression SEMICOLON { } +expressionstatement : statementexpression SEMICOLON { StatementExpressionStatement $1 } -returnstatement : RETURN SEMICOLON { } - | RETURN expression SEMICOLON { } +returnstatement : RETURN SEMICOLON { Return Nothing } + | RETURN expression SEMICOLON { Return $ Just $2 } -statementnoshortif : statementwithouttrailingsubstatement { } - | ifthenelsestatementnoshortif { } - | whilestatementnoshortif { } +statementnoshortif : statementwithouttrailingsubstatement { $1 } + -- | ifthenelsestatementnoshortif { } + -- | whilestatementnoshortif { } -conditionalexpression : conditionalorexpression { } - | conditionalorexpression QUESMARK expression COLON conditionalexpression { } +conditionalexpression : conditionalorexpression { $1 } + -- | conditionalorexpression QUESMARK expression COLON conditionalexpression { } -assignment :lefthandside assignmentoperator assignmentexpression { } +assignment : lefthandside assignmentoperator assignmentexpression { + case $2 of + Nothing -> Assignment $1 $3 + Just operator -> Assignment $1 (BinaryOperation operator $1 $3) + } -statementexpression : assignment { } - | preincrementexpression { } - | predecrementexpression { } - | postincrementexpression { } - | postdecrementexpression { } - | methodinvocation { } - | classinstancecreationexpression { } +statementexpression : assignment { $1 } + | preincrementexpression { $1 } + | 
predecrementexpression { $1 } + | postincrementexpression { $1 } + | postdecrementexpression { $1 } + | methodinvocation { $1 } + | classinstancecreationexpression { $1 } ifthenelsestatementnoshortif :IF LBRACE expression RBRACE statementnoshortif ELSE statementnoshortif { } whilestatementnoshortif : WHILE LBRACE expression RBRACE statementnoshortif { } -conditionalorexpression : conditionalandexpression { } - | conditionalorexpression LOGICALOR conditionalandexpression{ } +conditionalorexpression : conditionalandexpression { $1 } + -- | conditionalorexpression LOGICALOR conditionalandexpression{ } -lefthandside : name { } +lefthandside : name { $1 } -assignmentoperator : ASSIGN{ } - | TIMESEQUAL { } - | DIVIDEEQUAL { } - | MODULOEQUAL { } - | PLUSEQUAL { } - | MINUSEQUAL { } - | SHIFTLEFTEQUAL { } - | SIGNEDSHIFTRIGHTEQUAL { } - | UNSIGNEDSHIFTRIGHTEQUAL { } - | ANDEQUAL { } - | XOREQUAL { } - | OREQUAL{ } +assignmentoperator : ASSIGN { Nothing } + | TIMESEQUAL { Just Multiplication } + | DIVIDEEQUAL { Just Division } + | MODULOEQUAL { Just Modulo } + | PLUSEQUAL { Just Addition } + | MINUSEQUAL { Just Subtraction } + -- | SHIFTLEFTEQUAL { } + -- | SIGNEDSHIFTRIGHTEQUAL { } + -- | UNSIGNEDSHIFTRIGHTEQUAL { } + | ANDEQUAL { Just BitwiseAnd } + | XOREQUAL { Just BitwiseXor } + | OREQUAL{ Just BitwiseOr } -preincrementexpression : INCREMENT unaryexpression { } +preincrementexpression : INCREMENT unaryexpression { PreIncrement $2 } -predecrementexpression : DECREMENT unaryexpression { } +predecrementexpression : DECREMENT unaryexpression { PreDecrement $2 } -postincrementexpression : postfixexpression INCREMENT { } +postincrementexpression : postfixexpression INCREMENT { PostIncrement $1 } -postdecrementexpression : postfixexpression DECREMENT { } +postdecrementexpression : postfixexpression DECREMENT { PostDecrement $1 } -methodinvocation : name LBRACE RBRACE { } - | name LBRACE argumentlist RBRACE { } - | primary DOT IDENTIFIER LBRACE RBRACE { } - | primary DOT 
IDENTIFIER LBRACE argumentlist RBRACE { } +methodinvocation : simplename LBRACE RBRACE { MethodCall (Reference "this") $1 [] } + | simplename LBRACE argumentlist RBRACE { MethodCall (Reference "this") $1 $3 } + | primary DOT IDENTIFIER LBRACE RBRACE { MethodCall $1 $3 [] } + | primary DOT IDENTIFIER LBRACE argumentlist RBRACE { MethodCall $1 $3 $5 } -classinstancecreationexpression : NEW classtype LBRACE RBRACE { } - | NEW classtype LBRACE argumentlist RBRACE { } +classinstancecreationexpression : NEW classtype LBRACE RBRACE { ConstructorCall $2 [] } + | NEW classtype LBRACE argumentlist RBRACE { ConstructorCall $2 $4 } -conditionalandexpression : inclusiveorexpression { } +conditionalandexpression : inclusiveorexpression { $1 } -fieldaccess : primary DOT IDENTIFIER { } +fieldaccess : primary DOT IDENTIFIER { BinaryOperation NameResolution $1 (Reference $3) } -unaryexpression : preincrementexpression { } - | predecrementexpression { } - | PLUS unaryexpression { } - | MINUS unaryexpression { } - | unaryexpressionnotplusminus { } +unaryexpression : unaryexpressionnotplusminus { $1 } + | predecrementexpression { StatementExpressionExpression $1 } + | PLUS unaryexpression { $2 } + | MINUS unaryexpression { UnaryOperation Minus $2 } + | preincrementexpression { StatementExpressionExpression $1 } -postfixexpression : primary { } - | name { } - | postincrementexpression { } - | postdecrementexpression{ } +postfixexpression : primary { $1 } + | name { $1 } + | postincrementexpression { StatementExpressionExpression $1 } + | postdecrementexpression { StatementExpressionExpression $1 } -primary : primarynonewarray { } +primary : primarynonewarray { $1 } -inclusiveorexpression : exclusiveorexpression { } - | inclusiveorexpression OR exclusiveorexpression { } +inclusiveorexpression : exclusiveorexpression { $1 } + | inclusiveorexpression OR exclusiveorexpression { BinaryOperation Or $1 $3 } -primarynonewarray : literal { } - | THIS { } - | LBRACE expression RBRACE { } - | 
classinstancecreationexpression { } - | fieldaccess { } - | methodinvocation { } +primarynonewarray : literal { $1 } + | THIS { Reference "this" } + | LBRACE expression RBRACE { $2 } + | classinstancecreationexpression { StatementExpressionExpression $1 } + | fieldaccess { $1 } + | methodinvocation { StatementExpressionExpression $1 } -unaryexpressionnotplusminus : postfixexpression { } - | TILDE unaryexpression { } - | EXCLMARK unaryexpression { } - | castexpression{ } +unaryexpressionnotplusminus : postfixexpression { $1 } + -- | TILDE unaryexpression { } + | EXCLMARK unaryexpression { UnaryOperation Not $2 } + -- | castexpression{ } -exclusiveorexpression : andexpression { } - | exclusiveorexpression XOR andexpression { } +exclusiveorexpression : andexpression { $1 } + | exclusiveorexpression XOR andexpression { BinaryOperation BitwiseXor $1 $3 } -literal : INTLITERAL { } - | BOOLLITERAL { } - | CHARLITERAL { } - | JNULL { } +literal : INTLITERAL { IntegerLiteral $1 } + | BOOLLITERAL { BooleanLiteral $1 } + | CHARLITERAL { CharacterLiteral $1 } + | JNULL { NullLiteral } castexpression : LBRACE primitivetype RBRACE unaryexpression { } | LBRACE expression RBRACE unaryexpressionnotplusminus{ } -andexpression : equalityexpression { } - | andexpression AND equalityexpression { } +andexpression : equalityexpression { $1 } + | andexpression AND equalityexpression { BinaryOperation And $1 $3 } -equalityexpression : relationalexpression { } - | equalityexpression EQUAL relationalexpression { } - | equalityexpression NOTEQUAL relationalexpression { } +equalityexpression : relationalexpression { $1 } + | equalityexpression EQUAL relationalexpression { BinaryOperation CompareEqual $1 $3 } + | equalityexpression NOTEQUAL relationalexpression { BinaryOperation CompareNotEqual $1 $3 } -relationalexpression : shiftexpression { } - | relationalexpression LESS shiftexpression { } - | relationalexpression GREATER shiftexpression { } - | relationalexpression LESSEQUAL 
shiftexpression { } - | relationalexpression GREATEREQUAL shiftexpression { } - | relationalexpression INSTANCEOF referencetype { } +relationalexpression : shiftexpression { $1 } + | relationalexpression LESS shiftexpression { BinaryOperation CompareLessThan $1 $3 } + | relationalexpression GREATER shiftexpression { BinaryOperation CompareGreaterThan $1 $3 } + | relationalexpression LESSEQUAL shiftexpression { BinaryOperation CompareLessOrEqual $1 $3 } + | relationalexpression GREATEREQUAL shiftexpression { BinaryOperation CompareGreaterOrEqual $1 $3 } + -- | relationalexpression INSTANCEOF referencetype { } -shiftexpression : additiveexpression { } +shiftexpression : additiveexpression { $1 } -additiveexpression : multiplicativeexpression { } - | additiveexpression PLUS multiplicativeexpression { } - | additiveexpression MINUS multiplicativeexpression { } +additiveexpression : multiplicativeexpression { $1 } + | additiveexpression PLUS multiplicativeexpression { BinaryOperation Addition $1 $3 } + | additiveexpression MINUS multiplicativeexpression { BinaryOperation Subtraction $1 $3 } -multiplicativeexpression : unaryexpression { } - | multiplicativeexpression MUL unaryexpression { } - | multiplicativeexpression DIV unaryexpression { } - | multiplicativeexpression MOD unaryexpression { } +multiplicativeexpression : unaryexpression { $1 } + | multiplicativeexpression MUL unaryexpression { BinaryOperation Multiplication $1 $3 } + | multiplicativeexpression DIV unaryexpression { BinaryOperation Division $1 $3 } + | multiplicativeexpression MOD unaryexpression { BinaryOperation Modulo $1 $3 } { @@ -365,13 +371,13 @@ data MethodOrFieldDeclaration = MethodDecl MethodDeclaration data Declarator = Declarator Identifier (Maybe Expression) --- convertDeclaratorList :: [DataType] -> MethodOrFieldDeclaration --- convertDeclaratorList = FieldDecls $ map - convertDeclarator :: DataType -> Declarator -> VariableDeclaration convertDeclarator dataType (Declarator id assigment) = 
VariableDeclaration dataType id assigment -parseError :: [Token] -> a -parseError msg = error ("Parse error: " ++ show msg) +data StatementWithoutSub = Statement + + +parseError :: ([Token], [String]) -> a +parseError (errortoken, expected) = error ("parse error on token: " ++ show errortoken ++ "\nexpected one of: " ++ show expected) } diff --git a/src/Parser/Lexer.x b/src/Parser/Lexer.x index ef76773..cb0d075 100644 --- a/src/Parser/Lexer.x +++ b/src/Parser/Lexer.x @@ -72,7 +72,7 @@ tokens :- -- end keywords $JavaLetter$JavaLetterOrDigit* { \s -> IDENTIFIER s } -- Literals - [1-9]([0-9\_]*[0-9])* { \s -> case readMaybe $ filter ((/=) '_') s of Just a -> INTEGERLITERAL a; Nothing -> error ("failed to parse INTLITERAL " ++ s) } + [0-9]([0-9\_]*[0-9])* { \s -> case readMaybe $ filter ((/=) '_') s of Just a -> INTEGERLITERAL a; Nothing -> error ("failed to parse INTLITERAL " ++ s) } "'"."'" { \s -> case (s) of _ : c : _ -> CHARLITERAL c; _ -> error ("failed to parse CHARLITERAL " ++ s) } -- separators "(" { \_ -> LBRACE } diff --git a/src/Typecheck.hs b/src/Typecheck.hs index 7cdc3a7..0be409f 100644 --- a/src/Typecheck.hs +++ b/src/Typecheck.hs @@ -9,25 +9,42 @@ typeCheckCompilationUnit classes = map (`typeCheckClass` classes) classes typeCheckClass :: Class -> [Class] -> Class typeCheckClass (Class className methods fields) classes = let - -- Create a symbol table from class fields and method entries - classFields = [(id, dt) | VariableDeclaration dt id _ <- fields] - methodEntries = [(methodName, className) | MethodDeclaration _ methodName _ _ <- methods] - initalSymTab = ("this", className) : classFields ++ methodEntries + -- Fields and methods dont need to be added to the symtab because they are looked upon automatically under "this" + -- if its not a declared local variable. Also shadowing wouldnt be possible then. 
+ initalSymTab = [("this", className)] checkedMethods = map (\method -> typeCheckMethodDeclaration method initalSymTab classes) methods - in Class className checkedMethods fields + checkedFields = map (\field -> typeCheckVariableDeclaration field initalSymTab classes) fields + in Class className checkedMethods checkedFields typeCheckMethodDeclaration :: MethodDeclaration -> [(Identifier, DataType)] -> [Class] -> MethodDeclaration -typeCheckMethodDeclaration (MethodDeclaration retType name params body) classFields classes = +typeCheckMethodDeclaration (MethodDeclaration retType name params body) symtab classes = let - -- Combine class fields with method parameters to form the initial symbol table for the method methodParams = [(identifier, dataType) | ParameterDeclaration dataType identifier <- params] - initialSymtab = classFields ++ methodParams + initialSymtab = ("thisMeth", retType) : symtab ++ methodParams checkedBody = typeCheckStatement body initialSymtab classes bodyType = getTypeFromStmt checkedBody - -- Check if the type of the body matches the declared return type - in if bodyType == retType || (bodyType == "void" && retType == "void") + in if bodyType == retType || (bodyType == "void" && retType == "void") || (bodyType == "null" && isObjectType retType) || isSubtype bodyType retType classes then MethodDeclaration retType name params checkedBody - else error $ "Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType + else error $ "Method Declaration: Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType + +typeCheckVariableDeclaration :: VariableDeclaration -> [(Identifier, DataType)] -> [Class] -> VariableDeclaration +typeCheckVariableDeclaration (VariableDeclaration dataType identifier maybeExpr) symtab classes = + let + -- Ensure the type is valid (either a primitive type or a valid class name) + validType = dataType `elem` ["int", "boolean", "char"] || isUserDefinedClass 
dataType classes + -- Ensure no redefinition in the same scope + redefined = any ((== identifier) . snd) symtab + -- Type check the initializer expression if it exists + checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr + exprType = fmap getTypeFromExpr checkedExpr + in case (validType, redefined, exprType) of + (False, _, _) -> error $ "Type '" ++ dataType ++ "' is not a valid type for variable '" ++ identifier ++ "'" + (_, True, _) -> error $ "Variable '" ++ identifier ++ "' is redefined in the same scope" + (_, _, Just t) + | t == "null" && isObjectType dataType -> VariableDeclaration dataType identifier checkedExpr + | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t + | otherwise -> VariableDeclaration dataType identifier checkedExpr + (_, _, Nothing) -> VariableDeclaration dataType identifier checkedExpr -- ********************************** Type Checking: Expressions ********************************** @@ -37,118 +54,47 @@ typeCheckExpression (CharacterLiteral c) _ _ = TypedExpression "char" (Character typeCheckExpression (BooleanLiteral b) _ _ = TypedExpression "boolean" (BooleanLiteral b) typeCheckExpression NullLiteral _ _ = TypedExpression "null" NullLiteral typeCheckExpression (Reference id) symtab classes = - let type' = lookupType id symtab - in TypedExpression type' (Reference id) + case lookup id symtab of + Just t -> TypedExpression t (LocalVariable id) + Nothing -> + case lookup "this" symtab of + Just className -> + let classDetails = find (\(Class name _ _) -> name == className) classes + in case classDetails of + Just (Class _ _ fields) -> + let fieldTypes = [dt | VariableDeclaration dt fieldId _ <- fields, fieldId == id] + -- this case only happens when its a field of its own class so the implicit this will be converted to explicit this + in case fieldTypes of + [fieldType] -> TypedExpression fieldType (BinaryOperation NameResolution 
(TypedExpression className (LocalVariable "this")) (TypedExpression fieldType (FieldVariable id))) + [] -> error $ "Field '" ++ id ++ "' not found in class '" ++ className ++ "'" + _ -> error $ "Ambiguous reference to field '" ++ id ++ "' in class '" ++ className ++ "'" + Nothing -> error $ "Class '" ++ className ++ "' not found for 'this'" + Nothing -> error $ "Context for 'this' not found in symbol table, unable to resolve '" ++ id ++ "'" + typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes = let expr1' = typeCheckExpression expr1 symtab classes expr2' = typeCheckExpression expr2 symtab classes type1 = getTypeFromExpr expr1' type2 = getTypeFromExpr expr2' + resultType = resolveResultType type1 type2 in case op of - Addition -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation op expr1' expr2') - else - error "Addition operation requires two operands of type int" - Subtraction -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation op expr1' expr2') - else - error "Subtraction operation requires two operands of type int" - Multiplication -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation op expr1' expr2') - else - error "Multiplication operation requires two operands of type int" - Division -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation op expr1' expr2') - else - error "Division operation requires two operands of type int" - BitwiseAnd -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation op expr1' expr2') - else - error "Bitwise AND operation requires two operands of type int" - BitwiseOr -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation op expr1' expr2') - else - error "Bitwise OR operation requires two operands of type int" - BitwiseXor -> - if type1 == "int" && type2 == "int" - then - TypedExpression "int" (BinaryOperation 
op expr1' expr2') - else - error "Bitwise XOR operation requires two operands of type int" - CompareLessThan -> - if type1 == "int" && type2 == "int" - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Less than operation requires two operands of type int" - CompareLessOrEqual -> - if type1 == "int" && type2 == "int" - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Less than or equal operation requires two operands of type int" - CompareGreaterThan -> - if type1 == "int" && type2 == "int" - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Greater than operation requires two operands of type int" - CompareGreaterOrEqual -> - if type1 == "int" && type2 == "int" - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Greater than or equal operation requires two operands of type int" - CompareEqual -> - if type1 == type2 - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Equality operation requires two operands of the same type" - CompareNotEqual -> - if type1 == type2 - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Inequality operation requires two operands of the same type" - And -> - if type1 == "boolean" && type2 == "boolean" - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Logical AND operation requires two operands of type boolean" - Or -> - if type1 == "boolean" && type2 == "boolean" - then - TypedExpression "boolean" (BinaryOperation op expr1' expr2') - else - error "Logical OR operation requires two operands of type boolean" - NameResolution -> - case (expr1', expr2) of - (TypedExpression t1 (Reference obj), Reference member) -> - let objectType = lookupType obj symtab - classDetails = find (\(Class className _ _) -> className == objectType) classes - in case classDetails of - Just (Class _ _ fields) -> - let 
fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == member] - in case fieldTypes of - [resolvedType] -> TypedExpression resolvedType (BinaryOperation NameResolution expr1' (TypedExpression resolvedType expr2)) - [] -> error $ "Field '" ++ member ++ "' not found in class '" ++ objectType ++ "'" - _ -> error $ "Ambiguous reference to field '" ++ member ++ "' in class '" ++ objectType ++ "'" - Nothing -> error $ "Object '" ++ obj ++ "' does not correspond to a known class" - _ -> error "Name resolution requires object reference and field name" + Addition -> checkArithmeticOperation op expr1' expr2' type1 type2 resultType + Subtraction -> checkArithmeticOperation op expr1' expr2' type1 type2 resultType + Multiplication -> checkArithmeticOperation op expr1' expr2' type1 type2 resultType + Division -> checkArithmeticOperation op expr1' expr2' type1 type2 resultType + Modulo -> checkArithmeticOperation op expr1' expr2' type1 type2 resultType + BitwiseAnd -> checkBitwiseOperation op expr1' expr2' type1 type2 + BitwiseOr -> checkBitwiseOperation op expr1' expr2' type1 type2 + BitwiseXor -> checkBitwiseOperation op expr1' expr2' type1 type2 + CompareLessThan -> checkComparisonOperation op expr1' expr2' type1 type2 + CompareLessOrEqual -> checkComparisonOperation op expr1' expr2' type1 type2 + CompareGreaterThan -> checkComparisonOperation op expr1' expr2' type1 type2 + CompareGreaterOrEqual -> checkComparisonOperation op expr1' expr2' type1 type2 + CompareEqual -> checkEqualityOperation op expr1' expr2' type1 type2 + CompareNotEqual -> checkEqualityOperation op expr1' expr2' type1 type2 + And -> checkLogicalOperation op expr1' expr2' type1 type2 + Or -> checkLogicalOperation op expr1' expr2' type1 type2 + NameResolution -> resolveNameResolution expr1' expr2 symtab classes typeCheckExpression (UnaryOperation op expr) symtab classes = let expr' = typeCheckExpression expr symtab classes @@ -161,11 +107,11 @@ typeCheckExpression (UnaryOperation op expr) symtab 
classes = else error "Logical NOT operation requires an operand of type boolean" Minus -> - if type' == "int" + if type' == "int" || type' == "char" then - TypedExpression "int" (UnaryOperation op expr') + TypedExpression type' (UnaryOperation op expr') else - error "Unary minus operation requires an operand of type int" + error "Unary minus operation requires an operand of type int or char" typeCheckExpression (StatementExpressionExpression stmtExpr) symtab classes = let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes @@ -174,23 +120,29 @@ typeCheckExpression (StatementExpressionExpression stmtExpr) symtab classes = -- ********************************** Type Checking: StatementExpressions ********************************** typeCheckStatementExpression :: StatementExpression -> [(Identifier, DataType)] -> [Class] -> StatementExpression -typeCheckStatementExpression (Assignment id expr) symtab classes = +typeCheckStatementExpression (Assignment ref expr) symtab classes = let expr' = typeCheckExpression expr symtab classes + ref' = typeCheckExpression ref symtab classes type' = getTypeFromExpr expr' - type'' = lookupType id symtab - in if type' == type'' - then - TypedStatementExpression type' (Assignment id expr') - else - error "Assignment type mismatch" + type'' = getTypeFromExpr ref' + in + if type'' == type' || (type' == "null" && isObjectType type'') then + TypedStatementExpression type'' (Assignment ref' expr') + else + error $ "Type mismatch in assignment to variable: expected " ++ type'' ++ ", found " ++ type' typeCheckStatementExpression (ConstructorCall className args) symtab classes = case find (\(Class name _ _) -> name == className) classes of Nothing -> error $ "Class '" ++ className ++ "' not found." 
Just (Class _ methods fields) -> - -- Constructor needs the same name as the class - case find (\(MethodDeclaration retType name params _) -> name == className && retType == className) methods of - Nothing -> error $ "No valid constructor found for class '" ++ className ++ "'." + -- Find constructor matching the class name with void return type + case find (\(MethodDeclaration retType name params _) -> name == "" && retType == "void") methods of + -- If no constructor is found, assume standard constructor with no parameters + Nothing -> + if null args then + TypedStatementExpression className (ConstructorCall className args) + else + error $ "No valid constructor found for class '" ++ className ++ "', but arguments were provided." Just (MethodDeclaration _ _ params _) -> let args' = map (\arg -> typeCheckExpression arg symtab classes) args @@ -234,20 +186,62 @@ typeCheckStatementExpression (MethodCall expr methodName args) symtab classes = Nothing -> error $ "Class for object type '" ++ objType ++ "' not found." _ -> error "Invalid object type for method call. Object must have a class type." 
+typeCheckStatementExpression (PostIncrement expr) symtab classes =
+    let expr' = typeCheckExpression expr symtab classes
+        type' = getTypeFromExpr expr'
+    in if type' == "int" || type' == "char"
+        then
+            TypedStatementExpression type' (PostIncrement expr')
+        else
+            error "Post-increment operation requires an operand of type int or char"
+
+typeCheckStatementExpression (PostDecrement expr) symtab classes =
+    let expr' = typeCheckExpression expr symtab classes
+        type' = getTypeFromExpr expr'
+    in if type' == "int" || type' == "char"
+        then
+            TypedStatementExpression type' (PostDecrement expr')
+        else
+            error "Post-decrement operation requires an operand of type int or char"
+
+typeCheckStatementExpression (PreIncrement expr) symtab classes =
+    let expr' = typeCheckExpression expr symtab classes
+        type' = getTypeFromExpr expr'
+    in if type' == "int" || type' == "char"
+        then
+            TypedStatementExpression type' (PreIncrement expr')
+        else
+            error "Pre-increment operation requires an operand of type int or char"
+
+typeCheckStatementExpression (PreDecrement expr) symtab classes =
+    let expr' = typeCheckExpression expr symtab classes
+        type' = getTypeFromExpr expr'
+    in if type' == "int" || type' == "char"
+        then
+            TypedStatementExpression type' (PreDecrement expr')
+        else
+            error "Pre-decrement operation requires an operand of type int or char"
+
 -- ********************************** Type Checking: Statements **********************************
 
 typeCheckStatement :: Statement -> [(Identifier, DataType)] -> [Class] -> Statement
 typeCheckStatement (If cond thenStmt elseStmt) symtab classes =
-    let cond' = typeCheckExpression cond symtab classes
-        thenStmt' = typeCheckStatement thenStmt symtab classes
-        elseStmt' = case elseStmt of
-            Just stmt -> Just (typeCheckStatement stmt symtab classes)
-            Nothing -> Nothing
-    in if getTypeFromExpr cond' == "boolean"
-        then
-            TypedStatement (getTypeFromStmt thenStmt') (If cond' thenStmt' elseStmt')
-        else
-            error "If condition must be of type boolean"
+    let
+        cond' = typeCheckExpression cond symtab classes
+        thenStmt' = typeCheckStatement thenStmt symtab classes
+        elseStmt' = fmap (\stmt -> typeCheckStatement stmt symtab classes) elseStmt
+
+        thenType = getTypeFromStmt thenStmt'
+        elseType = maybe "void" getTypeFromStmt elseStmt'  -- a missing else branch counts as "void"
+
+        ifType = if thenType == "void" || elseType == "void"  -- typed only when both branches yield a type
+            then "void"
+            else unifyReturnTypes thenType elseType
+
+    in if getTypeFromExpr cond' == "boolean"
+        then TypedStatement ifType (If cond' thenStmt' elseStmt')
+        else error "If condition must be of type boolean"
+
 
 typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr)) symtab classes =
     -- Check for redefinition in the current scope
@@ -258,8 +252,12 @@ typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType ident
     let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
         exprType = fmap getTypeFromExpr checkedExpr
     in case exprType of
-        Just t | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
-        _ -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
+        Just t
+            | t == "null" && isObjectType dataType ->  -- a null initializer is accepted for any object-typed variable
+                TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
+            | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
+            | otherwise -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
+        Nothing -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
 
 typeCheckStatement (While cond stmt) symtab classes =
     let cond' = typeCheckExpression cond symtab classes
@@ -273,36 +271,46 @@ typeCheckStatement (While cond stmt) symtab classes =
 typeCheckStatement (Block statements) symtab classes =
     let
         processStatements (accSts, currentSymtab, types) stmt =
-            let
-                checkedStmt = typeCheckStatement stmt currentSymtab classes
-                stmtType = getTypeFromStmt checkedStmt
-            in case stmt of
+            case stmt of
                 LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr) ->
                     let
+                        alreadyDefined = any (\(id, _) -> id == identifier) currentSymtab
+                        newSymtab = if alreadyDefined  -- NOTE(review): this error is lazy; it fires only once newSymtab is forced
+                            then error ("Variable " ++ identifier ++ " already defined in this scope.")
+                            else (identifier, dataType) : currentSymtab
                         checkedExpr = fmap (\expr -> typeCheckExpression expr currentSymtab classes) maybeExpr
-                        newSymtab = (identifier, dataType) : currentSymtab
+                        checkedStmt = typeCheckStatement stmt newSymtab classes
                     in (accSts ++ [checkedStmt], newSymtab, types)
-                If {} -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
-                While _ _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
-                Return _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
-                _ -> (accSts ++ [checkedStmt], currentSymtab, types)
+                _ ->  -- every non-declaration statement is checked against the unchanged symbol table
+                    let
+                        checkedStmt = typeCheckStatement stmt currentSymtab classes
+                        stmtType = getTypeFromStmt checkedStmt
+                    in case stmt of
+                        If {} -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
+                        While _ _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
+                        Return _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
+                        Block _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
+                        _ -> (accSts ++ [checkedStmt], currentSymtab, types)
 
         -- Initial accumulator: empty statements list, initial symbol table, empty types list
         (checkedStatements, finalSymtab, collectedTypes) =
             foldl processStatements ([], symtab, []) statements
 
-        -- Determine the block's type: unify all collected types, default to "Void" if none
+        -- Determine the block's type: unify all collected types, default to "void" if none (UpperBound)
         blockType = if null collectedTypes then "void" else foldl1 unifyReturnTypes collectedTypes
 
     in TypedStatement blockType (Block checkedStatements)
 
+
 typeCheckStatement (Return expr) symtab classes =
-    let expr' = case expr of
+    let methodReturnType = fromMaybe (error "Method return type not found in symbol table") (lookup "thisMeth" symtab)
+        expr' = case expr of
             Just e -> Just (typeCheckExpression e symtab classes)
             Nothing -> Nothing
-    in case expr' of
-        Just e' -> TypedStatement (getTypeFromExpr e') (Return (Just e'))
-        Nothing -> TypedStatement "Void" (Return Nothing)
+        returnType = maybe "void" getTypeFromExpr expr'  -- a bare `return` has type "void"
+    in if returnType == methodReturnType || isSubtype returnType methodReturnType classes
+        then TypedStatement returnType (Return expr')
+        else error $ "Return: Return type mismatch: expected " ++ methodReturnType ++ ", found " ++ returnType
 
 typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes =
     let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes
@@ -310,6 +318,20 @@ typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes =
 
 -- ********************************** Type Checking: Helpers **********************************
 
+isSubtype :: DataType -> DataType -> [Class] -> Bool
+isSubtype subType superType classes
+    | subType == superType = True
+    | subType == "null" && isObjectType superType = True  -- null fits every object type
+    | superType == "Object" && isObjectType subType = True  -- every object type fits "Object"
+    | superType == "Object" && isUserDefinedClass subType classes = True
+    | otherwise = False
+
+isUserDefinedClass :: DataType -> [Class] -> Bool
+isUserDefinedClass dt classes = dt `elem` map (\(Class name _ _) -> name) classes
+
+isObjectType :: DataType -> Bool
+isObjectType dt = dt /= "int" && dt /= "boolean" && dt /= "char" && dt /= "void"  -- "void" is no object type: keeps `return;` / `return null;` from passing isSubtype
+
 getTypeFromExpr :: Expression -> DataType
 getTypeFromExpr (TypedExpression t _) = t
 getTypeFromExpr _ = error "Untyped expression found where typed was expected"
@@ -324,11 +346,60 @@ getTypeFromStmtExpr _ = error "Untyped statement expression found where typed wa
 
 unifyReturnTypes :: DataType -> DataType -> DataType
 unifyReturnTypes dt1 dt2
-    | dt1 == dt2 = dt1
-    | otherwise = "Object"
+    | dt1 == dt2 = dt1
+    | dt1 == "null" = dt2
+    | dt2 == "null" = dt1
+    | otherwise = "Object"
 
-lookupType :: Identifier -> [(Identifier, DataType)] -> DataType
-lookupType id symtab =
-    case lookup id symtab of
-        Just t -> t
-        Nothing -> error ("Identifier " ++ id ++ " not found in symbol table")
+resolveResultType :: DataType -> DataType -> DataType
+resolveResultType "char" "char" = "char"
+resolveResultType "int" "int" = "int"
+resolveResultType "char" "int" = "int"
+resolveResultType "int" "char" = "int"
+resolveResultType t1 t2
+    | t1 == t2 = t1
+    | otherwise = error $ "Incompatible types: " ++ t1 ++ " and " ++ t2
+
+checkArithmeticOperation :: BinaryOperator -> Expression -> Expression -> DataType -> DataType -> DataType -> Expression
+checkArithmeticOperation op expr1' expr2' type1 type2 resultType
+    | (type1 == "int" || type1 == "char") && (type2 == "int" || type2 == "char") =
+        TypedExpression resultType (BinaryOperation op expr1' expr2')
+    | otherwise = error $ "Arithmetic operation " ++ show op ++ " requires operands of type int or char"
+
+checkBitwiseOperation :: BinaryOperator -> Expression -> Expression -> DataType -> DataType -> Expression
+checkBitwiseOperation op expr1' expr2' type1 type2
+    | type1 == "int" && type2 == "int" =  -- only int operands are accepted here
+        TypedExpression "int" (BinaryOperation op expr1' expr2')
+    | otherwise = error $ "Bitwise operation " ++ show op ++ " requires operands of type int"
+
+checkComparisonOperation :: BinaryOperator -> Expression -> Expression -> DataType -> DataType -> Expression
+checkComparisonOperation op expr1' expr2' type1 type2
+    | (type1 == "int" || type1 == "char") && (type2 == "int" || type2 == "char") =
+        TypedExpression "boolean" (BinaryOperation op expr1' expr2')
+    | otherwise = error $ "Comparison operation " ++ show op ++ " requires operands of type int or char"
+
+checkEqualityOperation :: BinaryOperator -> Expression -> Expression -> DataType -> DataType -> Expression
+checkEqualityOperation op expr1' expr2' type1 type2
+    | type1 == type2 =
+        TypedExpression "boolean" (BinaryOperation op expr1' expr2')
+    | otherwise = error $ "Equality operation " ++ show op ++ " requires operands of the same type"
+
+checkLogicalOperation :: BinaryOperator -> Expression -> Expression -> DataType -> DataType -> Expression
+checkLogicalOperation op expr1' expr2' type1 type2
+    | type1 == "boolean" && type2 == "boolean" =
+        TypedExpression "boolean" (BinaryOperation op expr1' expr2')
+    | otherwise = error $ "Logical operation " ++ show op ++ " requires operands of type boolean"
+
+resolveNameResolution :: Expression -> Expression -> [(Identifier, DataType)] -> [Class] -> Expression
+resolveNameResolution expr1' (Reference ident2) symtab classes =
+    case getTypeFromExpr expr1' of
+        objType ->
+            case find (\(Class className _ _) -> className == objType) classes of
+                Just (Class _ _ fields) ->
+                    let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == ident2]
+                    in case fieldTypes of
+                        [resolvedType] -> TypedExpression resolvedType (BinaryOperation NameResolution expr1' (TypedExpression resolvedType (FieldVariable ident2)))
+                        [] -> error $ "Field '" ++ ident2 ++ "' not found in class '" ++ objType ++ "'"
+                        _ -> error $ "Ambiguous reference to field '" ++ ident2 ++ "' in class '" ++ objType ++ "'"
+                Nothing -> error $ "Class '" ++ objType ++ "' not found"
+resolveNameResolution _ _ _ _ = error "Name resolution requires object reference and field name"