diff --git a/.gitignore b/.gitignore index d20ba06..5dcf875 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,6 @@ cabal-dev *.chs.h *.dyn_o *.dyn_hi -*.java *.class *.local~* src/Parser/JavaParser.hs diff --git a/Test/JavaSources/Main.java b/Test/JavaSources/Main.java new file mode 100644 index 0000000..ca1eb22 --- /dev/null +++ b/Test/JavaSources/Main.java @@ -0,0 +1,38 @@ +// compile all test files using: +// ls Test/JavaSources/*.java | grep -v ".*Main.java" | xargs -I {} cabal run compiler {} +// compile (in project root) using: +// javac -g:none -sourcepath Test/JavaSources/ Test/JavaSources/Main.java +// afterwards, run using +// java -ea -cp Test/JavaSources/ Main + +public class Main { + public static void main(String[] args) + { + TestEmpty empty = new TestEmpty(); + TestFields fields = new TestFields(); + TestConstructor constructor = new TestConstructor(42); + TestMultipleClasses multipleClasses = new TestMultipleClasses(); + TestRecursion recursion = new TestRecursion(10); + TestMalicious malicious = new TestMalicious(); + + // constructing a basic class works + assert empty != null; + // initializers (and default initializers to 0/null) work + assert fields.a == 0 && fields.b == 42; + // constructor parameters override initializers + assert constructor.a == 42; + // multiple classes within one file work. Referencing another classes fields/methods works. + assert multipleClasses.a.a == 42; + // self-referencing classes work. + assert recursion.child.child.child.child.child.value == 5; + // self-referencing methods work. 
+ assert recursion.fibonacci(15) == 610; + // intentionally dodgy expressions work + assert malicious.assignNegativeIncrement(42) == -42; + assert malicious.tripleAddition(1, 2, 3) == 6; + for(int i = 0; i < 3; i++) + { + assert malicious.cursedFormatting(i) == i; + } + } +} diff --git a/Test/JavaSources/TestConstructor.java b/Test/JavaSources/TestConstructor.java new file mode 100644 index 0000000..f676e1f --- /dev/null +++ b/Test/JavaSources/TestConstructor.java @@ -0,0 +1,9 @@ +public class TestConstructor +{ + public int a = -1; + + public TestConstructor(int initial_value) + { + a = initial_value; + } +} diff --git a/Test/JavaSources/TestEmpty.java b/Test/JavaSources/TestEmpty.java new file mode 100644 index 0000000..184895d --- /dev/null +++ b/Test/JavaSources/TestEmpty.java @@ -0,0 +1,4 @@ +public class TestEmpty +{ + +} diff --git a/Test/JavaSources/TestFields.java b/Test/JavaSources/TestFields.java new file mode 100644 index 0000000..2baaf80 --- /dev/null +++ b/Test/JavaSources/TestFields.java @@ -0,0 +1,5 @@ +public class TestFields +{ + public int a; + public int b = 42; +} diff --git a/Test/JavaSources/TestMalicious.java b/Test/JavaSources/TestMalicious.java new file mode 100644 index 0000000..f71785b --- /dev/null +++ b/Test/JavaSources/TestMalicious.java @@ -0,0 +1,41 @@ +public class TestMalicious { + public int assignNegativeIncrement(int n) + { + return n=-++n+1; + } + + public int tripleAddition(int a, int b, int c) + { + return a+++b+++c++; + } + + public int cursedFormatting(int n) + { + if + + + (n == 0) + + { + + return ((((0)))); + } + + else + + + if(n == + + 1) + { + return + + + 1; + }else { + return + 2 + ; + } + } +} diff --git a/Test/JavaSources/TestMultipleClasses.java b/Test/JavaSources/TestMultipleClasses.java new file mode 100644 index 0000000..c5e418e --- /dev/null +++ b/Test/JavaSources/TestMultipleClasses.java @@ -0,0 +1,9 @@ +public class TestMultipleClasses +{ + public AnotherTestClass a = new AnotherTestClass(); +} + +class 
AnotherTestClass +{ + public int a = 42; +} diff --git a/Test/JavaSources/TestRecursion.java b/Test/JavaSources/TestRecursion.java new file mode 100644 index 0000000..baa6a31 --- /dev/null +++ b/Test/JavaSources/TestRecursion.java @@ -0,0 +1,27 @@ +public class TestRecursion { + + public int value = 0; + public TestRecursion child = null; + + public TestRecursion(int n) + { + value = n; + + if(n > 0) + { + child = new TestRecursion(n - 1); + } + } + + public int fibonacci(int n) + { + if(n < 2) + { + return n; + } + else + { + return fibonacci(n - 1) + this.fibonacci(n - 2); + } + } +} diff --git a/Test/TestByteCodeGenerator.hs b/Test/TestByteCodeGenerator.hs deleted file mode 100644 index 2c6f4d2..0000000 --- a/Test/TestByteCodeGenerator.hs +++ /dev/null @@ -1,118 +0,0 @@ -module TestByteCodeGenerator where - -import Test.HUnit -import ByteCode.ClassFile.Generator -import ByteCode.ClassFile -import ByteCode.Constants -import Ast - -nakedClass = Class "Testklasse" [] [] -expectedClass = ClassFile { - constantPool = [ - ClassInfo 4, - MethodRefInfo 1 3, - NameAndTypeInfo 5 6, - Utf8Info "java/lang/Object", - Utf8Info "", - Utf8Info "()V", - Utf8Info "Code", - ClassInfo 9, - Utf8Info "Testklasse" - ], - accessFlags = accessPublic, - thisClass = 8, - superClass = 1, - fields = [], - methods = [], - attributes = [] - } - -classWithFields = Class "Testklasse" [] [VariableDeclaration "int" "testvariable" Nothing] -expectedClassWithFields = ClassFile { - constantPool = [ - ClassInfo 4, - MethodRefInfo 1 3, - NameAndTypeInfo 5 6, - Utf8Info "java/lang/Object", - Utf8Info "", - Utf8Info "()V", - Utf8Info "Code", - ClassInfo 9, - Utf8Info "Testklasse", - FieldRefInfo 8 11, - NameAndTypeInfo 12 13, - Utf8Info "testvariable", - Utf8Info "I" - ], - accessFlags = accessPublic, - thisClass = 8, - superClass = 1, - fields = [ - MemberInfo { - memberAccessFlags = accessPublic, - memberNameIndex = 12, - memberDescriptorIndex = 13, - memberAttributes = [] - } - ], - methods = [], - 
attributes = [] - } - -method = MethodDeclaration "int" "add_two_numbers" [ - ParameterDeclaration "int" "a", - ParameterDeclaration "int" "b" ] - (Block [Return (Just (BinaryOperation Addition (Reference "a") (Reference "b")))]) - - -classWithMethod = Class "Testklasse" [method] [] -expectedClassWithMethod = ClassFile { - constantPool = [ - ClassInfo 4, - MethodRefInfo 1 3, - NameAndTypeInfo 5 6, - Utf8Info "java/lang/Object", - Utf8Info "", - Utf8Info "()V", - Utf8Info "Code", - ClassInfo 9, - Utf8Info "Testklasse", - FieldRefInfo 8 11, - NameAndTypeInfo 12 13, - Utf8Info "add_two_numbers", - Utf8Info "(II)I" - ], - accessFlags = accessPublic, - thisClass = 8, - superClass = 1, - fields = [], - methods = [ - MemberInfo { - memberAccessFlags = accessPublic, - memberNameIndex = 12, - memberDescriptorIndex = 13, - memberAttributes = [ - CodeAttribute { - attributeMaxStack = 420, - attributeMaxLocals = 420, - attributeCode = [Opiadd] - } - ] - } - ], - attributes = [] - } - -testBasicConstantPool = TestCase $ assertEqual "basic constant pool" expectedClass $ classBuilder nakedClass emptyClassFile -testFields = TestCase $ assertEqual "fields in constant pool" expectedClassWithFields $ classBuilder classWithFields emptyClassFile -testMethodDescriptor = TestCase $ assertEqual "method descriptor" "(II)I" (methodDescriptor method) -testMethodAssembly = TestCase $ assertEqual "method assembly" expectedClassWithMethod (classBuilder classWithMethod emptyClassFile) -testFindMethodIndex = TestCase $ assertEqual "find method" (Just 0) (findMethodIndex expectedClassWithMethod "add_two_numbers") - -tests = TestList [ - TestLabel "Basic constant pool" testBasicConstantPool, - TestLabel "Fields constant pool" testFields, - TestLabel "Method descriptor building" testMethodDescriptor, - TestLabel "Method assembly" testMethodAssembly, - TestLabel "Find method by name" testFindMethodIndex - ] \ No newline at end of file diff --git a/Test/TestSuite.hs b/Test/TestSuite.hs index 
bf2c67e..cf8c1e7 100644 --- a/Test/TestSuite.hs +++ b/Test/TestSuite.hs @@ -2,13 +2,11 @@ module Main where import Test.HUnit import TestLexer -import TestByteCodeGenerator import TestParser - tests = TestList [ - TestLabel "TestLexer" TestLexer.tests, - TestLabel "TestParser" TestParser.tests, - TestLabel "TestByteCodeGenerator" TestByteCodeGenerator.tests] + TestLabel "TestLexer" TestLexer.tests, + TestLabel "TestParser" TestParser.tests + ] main = do runTestTTAndExit Main.tests \ No newline at end of file diff --git a/project.cabal b/project.cabal index 7c0128b..60ff6fb 100644 --- a/project.cabal +++ b/project.cabal @@ -10,7 +10,8 @@ executable compiler array, HUnit, utf8-string, - bytestring + bytestring, + filepath default-language: Haskell2010 hs-source-dirs: src build-tool-depends: alex:alex, happy:happy @@ -19,14 +20,11 @@ executable compiler Ast, Example, Typecheck, + ByteCode.Util, ByteCode.ByteUtil, ByteCode.ClassFile, - ByteCode.Generation.Generator, - ByteCode.Generation.Assembler.ExpressionAndStatement, - ByteCode.Generation.Assembler.Method, - ByteCode.Generation.Builder.Class, - ByteCode.Generation.Builder.Field, - ByteCode.Generation.Builder.Method, + ByteCode.Assembler, + ByteCode.Builder, ByteCode.Constants test-suite tests @@ -37,15 +35,17 @@ test-suite tests array, HUnit, utf8-string, - bytestring + bytestring, + filepath build-tool-depends: alex:alex, happy:happy other-modules: Parser.Lexer, Parser.JavaParser, Ast, TestLexer, TestParser, - TestByteCodeGenerator, + ByteCode.Util, ByteCode.ByteUtil, ByteCode.ClassFile, - ByteCode.ClassFile.Generator, + ByteCode.Assembler, + ByteCode.Builder, ByteCode.Constants diff --git a/src/ByteCode/Assembler.hs b/src/ByteCode/Assembler.hs new file mode 100644 index 0000000..aaa8542 --- /dev/null +++ b/src/ByteCode/Assembler.hs @@ -0,0 +1,276 @@ +module ByteCode.Assembler where + +import ByteCode.Constants +import ByteCode.ClassFile (ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), 
Attribute(..), opcodeEncodingLength) +import ByteCode.Util +import Ast +import Data.Char +import Data.List +import Data.Word + +type Assembler a = ([ConstantInfo], [Operation], [String]) -> a -> ([ConstantInfo], [Operation], [String]) + +assembleExpression :: Assembler Expression +assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b)) + | elem op [Addition, Subtraction, Multiplication, Division, Modulo, BitwiseAnd, BitwiseOr, BitwiseXor, And, Or] = let + (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a + (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b + in + (bConstants, bOps ++ [binaryOperation op], lvars) + | elem op [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let + (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a + (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b + cmp_op = comparisonOperation op 9 + cmp_ops = [cmp_op, Opsipush 0, Opgoto 6, Opsipush 1] + in + (bConstants, bOps ++ cmp_ops, lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression btype (FieldVariable b)))) = let + (fConstants, fieldIndex) = getFieldIndex constants (atype, b, datatypeDescriptor btype) + (aConstants, aOps, _) = assembleExpression (fConstants, ops, lvars) (TypedExpression atype a) + in + (aConstants, aOps ++ [Opgetfield (fromIntegral fieldIndex)], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (CharacterLiteral literal)) = + (constants, ops ++ [Opsipush (fromIntegral (ord literal))], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (BooleanLiteral literal)) = + (constants, ops ++ [Opsipush (if literal then 1 else 0)], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (IntegerLiteral literal)) + | literal <= 32767 && literal >= -32768 = 
(constants, ops ++ [Opsipush (fromIntegral literal)], lvars) + | otherwise = (constants ++ [IntegerInfo (fromIntegral literal)], ops ++ [Opldc_w (fromIntegral (1 + length constants))], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ NullLiteral) = + (constants, ops ++ [Opaconst_null], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression etype (UnaryOperation Not expr)) = let + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + newConstant = fromIntegral (1 + length exprConstants) + in case etype of + "int" -> (exprConstants ++ [IntegerInfo 0x7FFFFFFF], exprOps ++ [Opldc_w newConstant, Opixor], lvars) + "char" -> (exprConstants, exprOps ++ [Opsipush 0xFFFF, Opixor], lvars) + "boolean" -> (exprConstants, exprOps ++ [Opsipush 0x01, Opixor], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (UnaryOperation Minus expr)) = let + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + in + (exprConstants, exprOps ++ [Opineg], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name)) + | name == "this" = (constants, ops ++ [Opaload 0], lvars) + | otherwise = let + localIndex = findIndex ((==) name) lvars + isPrimitive = elem dtype ["char", "boolean", "int"] + in case localIndex of + Just index -> (constants, ops ++ if isPrimitive then [Opiload (fromIntegral index)] else [Opaload (fromIntegral index)], lvars) + Nothing -> error ("No such local variable found in local variable pool: " ++ name) + +assembleExpression (constants, ops, lvars) (TypedExpression dtype (StatementExpressionExpression stmtexp)) = + assembleStatementExpression (constants, ops, lvars) stmtexp + +assembleExpression _ expr = error ("unimplemented: " ++ show expr) + +assembleNameChain :: Assembler Expression +assembleNameChain input (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression _ (FieldVariable _)))) = + 
assembleExpression input (TypedExpression atype a) +assembleNameChain input expr = assembleExpression input expr + + +assembleStatementExpression :: Assembler StatementExpression +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression _ (Assignment (TypedExpression dtype receiver) expr)) = let + target = resolveNameChain (TypedExpression dtype receiver) + in case target of + (TypedExpression dtype (LocalVariable name)) -> let + localIndex = findIndex ((==) name) lvars + (constants_a, ops_a, _) = assembleExpression (constants, ops, lvars) expr + isPrimitive = elem dtype ["char", "boolean", "int"] + in case localIndex of + Just index -> (constants_a, ops_a ++ if isPrimitive then [Opdup, Opistore (fromIntegral index)] else [Opdup, Opastore (fromIntegral index)], lvars) -- dup in both branches: the assigned value must stay on the stack as the result of the assignment + Nothing -> error ("No such local variable found in local variable pool: " ++ name) + (TypedExpression dtype (FieldVariable name)) -> let + owner = resolveNameChainOwner (TypedExpression dtype receiver) + in case owner of + (TypedExpression otype _) -> let + (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) + (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) + (constants_a, ops_a, _) = assembleExpression (constants_r, ops_r, lvars) expr + in + (constants_a, ops_a ++ [Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) + something_else -> error ("expected TypedExpression, but got: " ++ show something_else) + +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression _ (PreIncrement (TypedExpression dtype receiver))) = let + target = resolveNameChain (TypedExpression dtype receiver) + in case target of + (TypedExpression dtype (LocalVariable name)) -> let + localIndex = findIndex ((==) name) lvars + expr = (TypedExpression dtype (LocalVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + in case localIndex of + Just 
index -> (exprConstants, exprOps ++ [Opsipush 1, Opiadd, Opdup, Opistore (fromIntegral index)], lvars) + Nothing -> error("No such local variable found in local variable pool: " ++ name) + (TypedExpression dtype (FieldVariable name)) -> let + owner = resolveNameChainOwner (TypedExpression dtype receiver) + in case owner of + (TypedExpression otype _) -> let + (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) + (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) + in + (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opiadd, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) + something_else -> error ("expected TypedExpression, but got: " ++ show something_else) + +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression _ (PreDecrement (TypedExpression dtype receiver))) = let + target = resolveNameChain (TypedExpression dtype receiver) + in case target of + (TypedExpression dtype (LocalVariable name)) -> let + localIndex = findIndex ((==) name) lvars + expr = (TypedExpression dtype (LocalVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + in case localIndex of + Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opisub, Opdup, Opistore (fromIntegral index)], lvars) + Nothing -> error("No such local variable found in local variable pool: " ++ name) + (TypedExpression dtype (FieldVariable name)) -> let + owner = resolveNameChainOwner (TypedExpression dtype receiver) + in case owner of + (TypedExpression otype _) -> let + (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) + (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) + in + (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opisub, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) + 
something_else -> error ("expected TypedExpression, but got: " ++ show something_else) + +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression _ (PostIncrement (TypedExpression dtype receiver))) = let + target = resolveNameChain (TypedExpression dtype receiver) + in case target of + (TypedExpression dtype (LocalVariable name)) -> let + localIndex = findIndex ((==) name) lvars + expr = (TypedExpression dtype (LocalVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + in case localIndex of + Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opiadd, Opistore (fromIntegral index)], lvars) + Nothing -> error("No such local variable found in local variable pool: " ++ name) + (TypedExpression dtype (FieldVariable name)) -> let + owner = resolveNameChainOwner (TypedExpression dtype receiver) + in case owner of + (TypedExpression otype _) -> let + (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) + (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) + in + (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opiadd, Opputfield (fromIntegral fieldIndex)], lvars) + something_else -> error ("expected TypedExpression, but got: " ++ show something_else) + +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression _ (PostDecrement (TypedExpression dtype receiver))) = let + target = resolveNameChain (TypedExpression dtype receiver) + in case target of + (TypedExpression dtype (LocalVariable name)) -> let + localIndex = findIndex ((==) name) lvars + expr = (TypedExpression dtype (LocalVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + in case localIndex of + Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opisub, Opistore (fromIntegral index)], lvars) + Nothing -> error("No such local 
variable found in local variable pool: " ++ name) + (TypedExpression dtype (FieldVariable name)) -> let + owner = resolveNameChainOwner (TypedExpression dtype receiver) + in case owner of + (TypedExpression otype _) -> let + (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) + (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) + in + (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opisub, Opputfield (fromIntegral fieldIndex)], lvars) + something_else -> error ("expected TypedExpression, but got: " ++ show something_else) + +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression rtype (MethodCall (TypedExpression otype receiver) name params)) = let + (constants_r, ops_r, lvars_r) = assembleExpression (constants, ops, lvars) (TypedExpression otype receiver) + (constants_p, ops_p, lvars_p) = foldl assembleExpression (constants_r, ops_r, lvars_r) params + (constants_m, methodIndex) = getMethodIndex constants_p (otype, name, methodDescriptorFromParamlist params rtype) + in + (constants_m, ops_p ++ [Opinvokevirtual (fromIntegral methodIndex)], lvars_p) + +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression rtype (ConstructorCall name params)) = let + (constants_c, classIndex) = getClassIndex constants name + (constants_p, ops_p, lvars_p) = foldl assembleExpression (constants_c, ops ++ [Opnew (fromIntegral classIndex), Opdup], lvars) params + (constants_m, methodIndex) = getMethodIndex constants_p (name, "", methodDescriptorFromParamlist params "void") + in + (constants_m, ops_p ++ [Opinvokespecial (fromIntegral methodIndex)], lvars_p) + + +assembleStatement :: Assembler Statement +assembleStatement (constants, ops, lvars) (TypedStatement stype (Return expr)) = case expr of + Nothing -> (constants, ops ++ [Opreturn], lvars) + Just expr -> let + (expr_constants, expr_ops, _) = 
assembleExpression (constants, ops, lvars) expr + in + (expr_constants, expr_ops ++ [returnOperation stype], lvars) + +assembleStatement (constants, ops, lvars) (TypedStatement _ (Block statements)) = + foldl assembleStatement (constants, ops, lvars) statements + +assembleStatement (constants, ops, lvars) (TypedStatement dtype (If expr if_stmt else_stmt)) = let + (constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr + (constants_ifa, ops_ifa, _) = assembleStatement (constants_cmp, [], lvars) if_stmt + (constants_elsea, ops_elsea, _) = case else_stmt of + Nothing -> (constants_ifa, [], lvars) + Just stmt -> assembleStatement (constants_ifa, [], lvars) stmt + -- byte length of the if body; the jump offsets below add 3 for the if_icmpeq itself and 3 more when a goto follows the if body + if_length = sum (map opcodeEncodingLength ops_ifa) + -- byte length of the else body (+3 below for the goto that skips it); the result carries constants_elsea so constants added by the else branch are kept + else_length = sum (map opcodeEncodingLength ops_elsea) + in case dtype of + "void" -> (constants_elsea, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 6)] ++ ops_ifa ++ [Opgoto (else_length + 3)] ++ ops_elsea, lvars) + otherwise -> (constants_elsea, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 3)] ++ ops_ifa ++ ops_elsea, lvars) + +assembleStatement (constants, ops, lvars) (TypedStatement _ (While expr stmt)) = let + (constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr + (constants_stmta, ops_stmta, _) = assembleStatement (constants_cmp, [], lvars) stmt + -- +6 because we insert 2 jumps of 3 bytes each, one for the comparison, one for the goto back to the comparison + stmt_length = sum (map opcodeEncodingLength ops_stmta) + 6 + entire_length = stmt_length + sum (map opcodeEncodingLength ops_cmp) + in + (constants_stmta, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq stmt_length] ++ ops_stmta ++ [Opgoto (-entire_length)], lvars) + +assembleStatement (constants, ops, lvars) (TypedStatement _ (LocalVariableDeclaration (VariableDeclaration dtype name expr))) = let + isPrimitive = elem dtype 
["char", "boolean", "int"] + (constants_init, ops_init, _) = case expr of + Just exp -> assembleExpression (constants, ops, lvars) exp + Nothing -> (constants, ops ++ if isPrimitive then [Opsipush 0] else [Opaconst_null], lvars) + localIndex = fromIntegral (length lvars) + storeLocal = if isPrimitive then [Opistore localIndex] else [Opastore localIndex] + in + (constants_init, ops_init ++ storeLocal, lvars ++ [name]) + +assembleStatement (constants, ops, lvars) (TypedStatement _ (StatementExpressionStatement expr)) = let + (constants_e, ops_e, lvars_e) = assembleStatementExpression (constants, ops, lvars) expr + in + (constants_e, ops_e ++ [Oppop], lvars_e) + +assembleStatement _ stmt = error ("Not yet implemented: " ++ show stmt) + + +assembleMethod :: Assembler MethodDeclaration +assembleMethod (constants, ops, lvars) (MethodDeclaration returntype name _ (TypedStatement _ (Block statements))) + | name == "" = let + (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements + init_ops = [Opaload 0, Opinvokespecial 2] + in + (constants_a, init_ops ++ ops_a ++ [Opreturn], lvars_a) + | otherwise = case returntype of + "void" -> let + (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements + in + (constants_a, ops_a ++ [Opreturn], lvars_a) + otherwise -> foldl assembleStatement (constants, ops, lvars) statements +assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Typed block expected for method body, got: " ++ show stmt) diff --git a/src/ByteCode/Builder.hs b/src/ByteCode/Builder.hs new file mode 100644 index 0000000..dfebf89 --- /dev/null +++ b/src/ByteCode/Builder.hs @@ -0,0 +1,117 @@ +module ByteCode.Builder where + +import ByteCode.Constants +import ByteCode.ClassFile (ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) +import ByteCode.Assembler +import ByteCode.Util +import Ast +import Data.Char +import Data.List +import Data.Word + +type 
ClassFileBuilder a = a -> ClassFile -> ClassFile + +fieldBuilder :: ClassFileBuilder VariableDeclaration +fieldBuilder (VariableDeclaration datatype name _) input = let + baseIndex = 1 + length (constantPool input) + constants = [ + FieldRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)), + NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)), + Utf8Info name, + Utf8Info (datatypeDescriptor datatype) + ] + field = MemberInfo { + memberAccessFlags = accessPublic, + memberNameIndex = (fromIntegral (baseIndex + 2)), + memberDescriptorIndex = (fromIntegral (baseIndex + 3)), + memberAttributes = [] + } + in + input { + constantPool = (constantPool input) ++ constants, + fields = (fields input) ++ [field] + } + + + +methodBuilder :: ClassFileBuilder MethodDeclaration +methodBuilder (MethodDeclaration returntype name parameters statement) input = let + baseIndex = 1 + length (constantPool input) + constants = [ + MethodRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)), + NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)), + Utf8Info name, + Utf8Info (methodDescriptor (MethodDeclaration returntype name parameters (Block []))) + ] + + method = MemberInfo { + memberAccessFlags = accessPublic, + memberNameIndex = (fromIntegral (baseIndex + 2)), + memberDescriptorIndex = (fromIntegral (baseIndex + 3)), + memberAttributes = [] + } + in + input { + constantPool = (constantPool input) ++ constants, + methods = (methods input) ++ [method] + } + + + +methodAssembler :: ClassFileBuilder MethodDeclaration +methodAssembler (MethodDeclaration returntype name parameters statement) input = let + methodConstantIndex = findMethodIndex input name + in case methodConstantIndex of + Nothing -> error ("Cannot find method entry in method pool for method: " ++ name) + Just index -> let + declaration = MethodDeclaration returntype name parameters statement + paramNames = "this" : [name | 
ParameterDeclaration _ name <- parameters] + in case (splitAt index (methods input)) of + (pre, []) -> input + (pre, method : post) -> let + (_, bytecode, _) = assembleMethod (constantPool input, [], paramNames) declaration + assembledMethod = method { + memberAttributes = [ + CodeAttribute { + attributeMaxStack = 420, + attributeMaxLocals = 420, + attributeCode = bytecode + } + ] + } + in + input { + methods = pre ++ (assembledMethod : post) + } + + +classBuilder :: ClassFileBuilder Class +classBuilder (Class name methods fields) _ = let + baseConstants = [ + ClassInfo 4, + MethodRefInfo 1 3, + NameAndTypeInfo 5 6, + Utf8Info "java/lang/Object", + Utf8Info "", + Utf8Info "()V", + Utf8Info "Code" + ] + nameConstants = [ClassInfo 9, Utf8Info name] + nakedClassFile = ClassFile { + constantPool = baseConstants ++ nameConstants, + accessFlags = accessPublic, + thisClass = 8, + superClass = 1, + fields = [], + methods = [], + attributes = [] + } + + methodsWithInjectedConstructor = injectDefaultConstructor methods + methodsWithInjectedInitializers = injectFieldInitializers name fields methodsWithInjectedConstructor + + classFileWithFields = foldr fieldBuilder nakedClassFile fields + classFileWithMethods = foldr methodBuilder classFileWithFields methodsWithInjectedInitializers + classFileWithAssembledMethods = foldr methodAssembler classFileWithMethods methodsWithInjectedInitializers + in + classFileWithAssembledMethods \ No newline at end of file diff --git a/src/ByteCode/ByteUtil.hs b/src/ByteCode/ByteUtil.hs index daa419a..a602ac3 100644 --- a/src/ByteCode/ByteUtil.hs +++ b/src/ByteCode/ByteUtil.hs @@ -1,7 +1,6 @@ -module ByteCode.ByteUtil(unpackWord16, unpackWord32) where +module ByteCode.ByteUtil where import Data.Word ( Word8, Word16, Word32 ) -import Data.Int import Data.Bits unpackWord16 :: Word16 -> [Word8] diff --git a/src/ByteCode/ClassFile.hs b/src/ByteCode/ClassFile.hs index 358b91a..1fc15c9 100644 --- a/src/ByteCode/ClassFile.hs +++ 
b/src/ByteCode/ClassFile.hs @@ -6,15 +6,16 @@ module ByteCode.ClassFile( Operation(..), serialize, emptyClassFile, - opcodeEncodingLength + opcodeEncodingLength, + className ) where import Data.Word import Data.Int import Data.ByteString (unpack) import Data.ByteString.UTF8 (fromString) -import ByteCode.ByteUtil import ByteCode.Constants +import ByteCode.ByteUtil data ConstantInfo = ClassInfo Word16 | FieldRefInfo Word16 Word16 @@ -28,11 +29,13 @@ data Operation = Opiadd | Opisub | Opimul | Opidiv + | Opirem | Opiand | Opior | Opixor | Opineg | Opdup + | Opnew Word16 | Opif_icmplt Word16 | Opif_icmple Word16 | Opif_icmpgt Word16 @@ -43,7 +46,10 @@ data Operation = Opiadd | Opreturn | Opireturn | Opareturn + | Opdup_x1 + | Oppop | Opinvokespecial Word16 + | Opinvokevirtual Word16 | Opgoto Word16 | Opsipush Word16 | Opldc_w Word16 @@ -91,16 +97,26 @@ emptyClassFile = ClassFile { attributes = [] } +className :: ClassFile -> String +className classFile = let + classInfo = (constantPool classFile)!!(fromIntegral (thisClass classFile)) + in case classInfo of + Utf8Info className -> className + otherwise -> error ("expected Utf8Info but got: " ++ show otherwise) + + opcodeEncodingLength :: Operation -> Word16 opcodeEncodingLength Opiadd = 1 opcodeEncodingLength Opisub = 1 opcodeEncodingLength Opimul = 1 opcodeEncodingLength Opidiv = 1 +opcodeEncodingLength Opirem = 1 opcodeEncodingLength Opiand = 1 opcodeEncodingLength Opior = 1 opcodeEncodingLength Opixor = 1 opcodeEncodingLength Opineg = 1 opcodeEncodingLength Opdup = 1 +opcodeEncodingLength (Opnew _) = 3 opcodeEncodingLength (Opif_icmplt _) = 3 opcodeEncodingLength (Opif_icmple _) = 3 opcodeEncodingLength (Opif_icmpgt _) = 3 @@ -111,7 +127,10 @@ opcodeEncodingLength Opaconst_null = 1 opcodeEncodingLength Opreturn = 1 opcodeEncodingLength Opireturn = 1 opcodeEncodingLength Opareturn = 1 +opcodeEncodingLength Opdup_x1 = 1 +opcodeEncodingLength Oppop = 1 opcodeEncodingLength (Opinvokespecial _) = 3 +opcodeEncodingLength 
(Opinvokevirtual _) = 3 opcodeEncodingLength (Opgoto _) = 3 opcodeEncodingLength (Opsipush _) = 3 opcodeEncodingLength (Opldc_w _) = 3 @@ -147,11 +166,13 @@ instance Serializable Operation where serialize Opisub = [0x64] serialize Opimul = [0x68] serialize Opidiv = [0x6C] + serialize Opirem = [0x70] serialize Opiand = [0x7E] serialize Opior = [0x80] serialize Opixor = [0x82] serialize Opineg = [0x74] serialize Opdup = [0x59] + serialize (Opnew index) = 0xBB : unpackWord16 index serialize (Opif_icmplt branch) = 0xA1 : unpackWord16 branch serialize (Opif_icmple branch) = 0xA4 : unpackWord16 branch serialize (Opif_icmpgt branch) = 0xA3 : unpackWord16 branch @@ -162,7 +183,10 @@ instance Serializable Operation where serialize Opreturn = [0xB1] serialize Opireturn = [0xAC] serialize Opareturn = [0xB0] + serialize Opdup_x1 = [0x5A] + serialize Oppop = [0x57] serialize (Opinvokespecial index) = 0xB7 : unpackWord16 index + serialize (Opinvokevirtual index) = 0xB6 : unpackWord16 index serialize (Opgoto index) = 0xA7 : unpackWord16 index serialize (Opsipush index) = 0x11 : unpackWord16 index serialize (Opldc_w index) = 0x13 : unpackWord16 index diff --git a/src/ByteCode/Generation/Assembler/ExpressionAndStatement.hs b/src/ByteCode/Generation/Assembler/ExpressionAndStatement.hs deleted file mode 100644 index 4ace628..0000000 --- a/src/ByteCode/Generation/Assembler/ExpressionAndStatement.hs +++ /dev/null @@ -1,228 +0,0 @@ -module ByteCode.Generation.Assembler.ExpressionAndStatement where - -import Ast -import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) -import ByteCode.Generation.Generator -import Data.List -import Data.Char -import ByteCode.Generation.Builder.Field - -assembleExpression :: Assembler Expression -assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b)) - | elem op [Addition, Subtraction, Multiplication, Division, BitwiseAnd, BitwiseOr, BitwiseXor] = let 
- (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a - (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b - in - (bConstants, bOps ++ [binaryOperation op], lvars) - | elem op [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let - (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a - (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b - cmp_op = comparisonOperation op 9 - cmp_ops = [cmp_op, Opsipush 0, Opgoto 6, Opsipush 1] - in - (bConstants, bOps ++ cmp_ops, lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression _ (CharacterLiteral literal)) = - (constants, ops ++ [Opsipush (fromIntegral (ord literal))], lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression _ (BooleanLiteral literal)) = - (constants, ops ++ [Opsipush (if literal then 1 else 0)], lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression _ (IntegerLiteral literal)) - | literal <= 32767 && literal >= -32768 = (constants, ops ++ [Opsipush (fromIntegral literal)], lvars) - | otherwise = (constants ++ [IntegerInfo (fromIntegral literal)], ops ++ [Opldc_w (fromIntegral (1 + length constants))], lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression _ NullLiteral) = - (constants, ops ++ [Opaconst_null], lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression etype (UnaryOperation Not expr)) = let - (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr - newConstant = fromIntegral (1 + length exprConstants) - in case etype of - "int" -> (exprConstants ++ [IntegerInfo 0x7FFFFFFF], exprOps ++ [Opldc_w newConstant, Opixor], lvars) - "char" -> (exprConstants, exprOps ++ [Opsipush 0xFFFF, Opixor], lvars) - "boolean" -> (exprConstants, exprOps ++ [Opsipush 0x01, Opixor], lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression _ (UnaryOperation Minus expr)) = let - 
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr - in - (exprConstants, exprOps ++ [Opineg], lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression _ (FieldVariable name)) = let - fieldIndex = findFieldIndex constants name - in case fieldIndex of - Just index -> (constants, ops ++ [Opaload 0, Opgetfield (fromIntegral index)], lvars) - Nothing -> error ("No such field found in constant pool: " ++ name) - -assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name)) = let - localIndex = findIndex ((==) name) lvars - isPrimitive = elem dtype ["char", "boolean", "int"] - in case localIndex of - Just index -> (constants, ops ++ if isPrimitive then [Opiload (fromIntegral index)] else [Opaload (fromIntegral index)], lvars) - Nothing -> error ("No such local variable found in local variable pool: " ++ name) - -assembleExpression (constants, ops, lvars) (TypedExpression dtype (StatementExpressionExpression stmtexp)) = - assembleStatementExpression (constants, ops, lvars) stmtexp - -assembleExpression _ expr = error ("unimplemented: " ++ show expr) - - - - --- TODO untested -assembleStatementExpression :: Assembler StatementExpression -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (Assignment (TypedExpression dtype (LocalVariable name)) expr)) = let - localIndex = findIndex ((==) name) lvars - (constants_a, ops_a, _) = assembleExpression (constants, ops, lvars) expr - isPrimitive = elem dtype ["char", "boolean", "int"] - in case localIndex of - Just index -> (constants_a, ops_a ++ if isPrimitive then [Opistore (fromIntegral index)] else [Opastore (fromIntegral index)], lvars) - Nothing -> error ("No such local variable found in local variable pool: " ++ name) - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (Assignment (TypedExpression dtype (FieldVariable name)) expr)) = let - fieldIndex = findFieldIndex constants name - 
(constants_a, ops_a, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr - in case fieldIndex of - Just index -> (constants_a, ops_a ++ [Opputfield (fromIntegral index)], lvars) - Nothing -> error ("No such field variable found in constant pool: " ++ name) - - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (PreIncrement (TypedExpression dtype (LocalVariable name)))) = let - localIndex = findIndex ((==) name) lvars - expr = (TypedExpression dtype (LocalVariable name)) - (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr - incrOps = exprOps ++ [Opsipush 1, Opiadd, Opdup] - in case localIndex of - Just index -> (exprConstants, incrOps ++ [Opistore (fromIntegral index)], lvars) - Nothing -> error("No such local variable found in local variable pool: " ++ name) - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (PostIncrement (TypedExpression dtype (LocalVariable name)))) = let - localIndex = findIndex ((==) name) lvars - expr = (TypedExpression dtype (LocalVariable name)) - (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr - incrOps = exprOps ++ [Opdup, Opsipush 1, Opiadd] - in case localIndex of - Just index -> (exprConstants, incrOps ++ [Opistore (fromIntegral index)], lvars) - Nothing -> error("No such local variable found in local variable pool: " ++ name) - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (PreDecrement (TypedExpression dtype (LocalVariable name)))) = let - localIndex = findIndex ((==) name) lvars - expr = (TypedExpression dtype (LocalVariable name)) - (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr - incrOps = exprOps ++ [Opsipush 1, Opiadd, Opisub] - in case localIndex of - Just index -> (exprConstants, incrOps ++ [Opistore (fromIntegral index)], lvars) - Nothing -> error("No such local variable found in local variable pool: " ++ name) - 
-assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (PostDecrement (TypedExpression dtype (LocalVariable name)))) = let - localIndex = findIndex ((==) name) lvars - expr = (TypedExpression dtype (LocalVariable name)) - (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr - incrOps = exprOps ++ [Opdup, Opsipush 1, Opisub] - in case localIndex of - Just index -> (exprConstants, incrOps ++ [Opistore (fromIntegral index)], lvars) - Nothing -> error("No such local variable found in local variable pool: " ++ name) - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (PreIncrement (TypedExpression dtype (FieldVariable name)))) = let - fieldIndex = findFieldIndex constants name - expr = (TypedExpression dtype (FieldVariable name)) - (exprConstants, exprOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr - incrOps = exprOps ++ [Opsipush 1, Opiadd, Opdup] - in case fieldIndex of - Just index -> (exprConstants, incrOps ++ [Opputfield (fromIntegral index)], lvars) - Nothing -> error("No such field variable found in field variable pool: " ++ name) - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (PostIncrement (TypedExpression dtype (FieldVariable name)))) = let - fieldIndex = findFieldIndex constants name - expr = (TypedExpression dtype (FieldVariable name)) - (exprConstants, exprOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr - incrOps = exprOps ++ [Opdup, Opsipush 1, Opiadd] - in case fieldIndex of - Just index -> (exprConstants, incrOps ++ [Opputfield (fromIntegral index)], lvars) - Nothing -> error("No such field variable found in field variable pool: " ++ name) - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (PreDecrement (TypedExpression dtype (FieldVariable name)))) = let - fieldIndex = findFieldIndex constants name - expr = (TypedExpression dtype 
(FieldVariable name)) - (exprConstants, exprOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr - incrOps = exprOps ++ [Opsipush 1, Opiadd, Opisub] - in case fieldIndex of - Just index -> (exprConstants, incrOps ++ [Opputfield (fromIntegral index)], lvars) - Nothing -> error("No such field variable found in field variable pool: " ++ name) - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (PostDecrement (TypedExpression dtype (FieldVariable name)))) = let - fieldIndex = findFieldIndex constants name - expr = (TypedExpression dtype (FieldVariable name)) - (exprConstants, exprOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr - incrOps = exprOps ++ [Opdup, Opsipush 1, Opisub] - in case fieldIndex of - Just index -> (exprConstants, incrOps ++ [Opputfield (fromIntegral index)], lvars) - Nothing -> error("No such field variable found in field variable pool: " ++ name) - - - - - -assembleStatement :: Assembler Statement -assembleStatement (constants, ops, lvars) (TypedStatement stype (Return expr)) = case expr of - Nothing -> (constants, ops ++ [Opreturn], lvars) - Just expr -> let - (expr_constants, expr_ops, _) = assembleExpression (constants, ops, lvars) expr - in - (expr_constants, expr_ops ++ [returnOperation stype], lvars) -assembleStatement (constants, ops, lvars) (TypedStatement _ (Block statements)) = - foldl assembleStatement (constants, ops, lvars) statements -assembleStatement (constants, ops, lvars) (TypedStatement _ (If expr if_stmt else_stmt)) = let - (constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr - (constants_ifa, ops_ifa, _) = assembleStatement (constants_cmp, [], lvars) if_stmt - (constants_elsea, ops_elsea, _) = case else_stmt of - Nothing -> (constants_ifa, [], lvars) - Just stmt -> assembleStatement (constants_ifa, [], lvars) stmt - -- +6 because we insert 2 gotos, one for if, one for else - if_length = sum (map opcodeEncodingLength ops_ifa) + 
6 - -- +3 because we need to account for the goto in the if statement. - else_length = sum (map opcodeEncodingLength ops_elsea) + 3 - in - (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq if_length] ++ ops_ifa ++ [Opgoto else_length] ++ ops_elsea, lvars) -assembleStatement (constants, ops, lvars) (TypedStatement _ (While expr stmt)) = let - (constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr - (constants_stmta, ops_stmta, _) = assembleStatement (constants_cmp, [], lvars) stmt - -- +3 because we insert 2 gotos, one for the comparison, one for the goto back to the comparison - stmt_length = sum (map opcodeEncodingLength ops_stmta) + 6 - entire_length = stmt_length + sum (map opcodeEncodingLength ops_cmp) - in - (constants_stmta, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq stmt_length] ++ ops_stmta ++ [Opgoto (-entire_length)], lvars) -assembleStatement (constants, ops, lvars) (TypedStatement _ (LocalVariableDeclaration (VariableDeclaration dtype name expr))) = let - isPrimitive = elem dtype ["char", "boolean", "int"] - (constants_init, ops_init, _) = case expr of - Just exp -> assembleExpression (constants, ops, lvars) exp - Nothing -> (constants, ops ++ if isPrimitive then [Opsipush 0] else [Opaconst_null], lvars) - localIndex = fromIntegral (length lvars) - storeLocal = if isPrimitive then [Opistore localIndex] else [Opastore localIndex] - in - (constants_init, ops_init ++ storeLocal, lvars ++ [name]) - -assembleStatement (constants, ops, lvars) (TypedStatement _ (StatementExpressionStatement expr)) = - assembleStatementExpression (constants, ops, lvars) expr - -assembleStatement _ stmt = error ("Not yet implemented: " ++ show stmt) diff --git a/src/ByteCode/Generation/Assembler/Method.hs b/src/ByteCode/Generation/Assembler/Method.hs deleted file mode 100644 index a1b896e..0000000 --- a/src/ByteCode/Generation/Assembler/Method.hs +++ /dev/null @@ -1,20 +0,0 @@ -module ByteCode.Generation.Assembler.Method where - -import Ast -import 
ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) -import ByteCode.Generation.Generator -import ByteCode.Generation.Assembler.ExpressionAndStatement - -assembleMethod :: Assembler MethodDeclaration -assembleMethod (constants, ops, lvars) (MethodDeclaration _ name _ (TypedStatement _ (Block statements))) - | name == "" = let - (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements - init_ops = [Opaload 0, Opinvokespecial 2] - in - (constants_a, init_ops ++ ops_a ++ [Opreturn], lvars_a) - | otherwise = let - (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements - init_ops = [Opaload 0] - in - (constants_a, init_ops ++ ops_a, lvars_a) -assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Block expected for method body, got: " ++ show stmt) diff --git a/src/ByteCode/Generation/Builder/Class.hs b/src/ByteCode/Generation/Builder/Class.hs deleted file mode 100644 index 16fef21..0000000 --- a/src/ByteCode/Generation/Builder/Class.hs +++ /dev/null @@ -1,44 +0,0 @@ -module ByteCode.Generation.Builder.Class where - -import ByteCode.Generation.Builder.Field -import ByteCode.Generation.Builder.Method -import ByteCode.Generation.Generator -import Ast -import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) -import ByteCode.Constants - -injectDefaultConstructor :: [MethodDeclaration] -> [MethodDeclaration] -injectDefaultConstructor pre - | any (\(MethodDeclaration _ name _ _) -> name == "") pre = pre - | otherwise = pre ++ [MethodDeclaration "void" "" [] (TypedStatement "void" (Block []))] - - -classBuilder :: ClassFileBuilder Class -classBuilder (Class name methods fields) _ = let - baseConstants = [ - ClassInfo 4, - MethodRefInfo 1 3, - NameAndTypeInfo 5 6, - Utf8Info "java/lang/Object", - Utf8Info "", - Utf8Info "()V", - Utf8Info "Code" - ] - nameConstants = 
[ClassInfo 9, Utf8Info name] - nakedClassFile = ClassFile { - constantPool = baseConstants ++ nameConstants, - accessFlags = accessPublic, - thisClass = 8, - superClass = 1, - fields = [], - methods = [], - attributes = [] - } - - methodsWithInjectedConstructor = injectDefaultConstructor methods - - classFileWithFields = foldr fieldBuilder nakedClassFile fields - classFileWithMethods = foldr methodBuilder classFileWithFields methodsWithInjectedConstructor - classFileWithAssembledMethods = foldr methodAssembler classFileWithMethods methodsWithInjectedConstructor - in - classFileWithAssembledMethods \ No newline at end of file diff --git a/src/ByteCode/Generation/Builder/Field.hs b/src/ByteCode/Generation/Builder/Field.hs deleted file mode 100644 index ec1f711..0000000 --- a/src/ByteCode/Generation/Builder/Field.hs +++ /dev/null @@ -1,46 +0,0 @@ -module ByteCode.Generation.Builder.Field where - -import Ast -import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) -import ByteCode.Generation.Generator -import ByteCode.Constants -import Data.List - -findFieldIndex :: [ConstantInfo] -> String -> Maybe Int -findFieldIndex constants name = let - fieldRefNameInfos = [ - -- we only skip one entry to get the name since the Java constant pool - -- is 1-indexed (why) - (index, constants!!(fromIntegral index + 1)) - | (index, FieldRefInfo _ _) <- (zip [1..] 
constants) - ] - fieldRefNames = map (\(index, nameInfo) -> case nameInfo of - Utf8Info fieldName -> (index, fieldName) - something_else -> error ("Expected UTF8Info but got" ++ show something_else)) - fieldRefNameInfos - fieldIndex = find (\(index, fieldName) -> fieldName == name) fieldRefNames - in case fieldIndex of - Just (index, _) -> Just index - Nothing -> Nothing - - -fieldBuilder :: ClassFileBuilder VariableDeclaration -fieldBuilder (VariableDeclaration datatype name _) input = let - baseIndex = 1 + length (constantPool input) - constants = [ - FieldRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)), - NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)), - Utf8Info name, - Utf8Info (datatypeDescriptor datatype) - ] - field = MemberInfo { - memberAccessFlags = accessPublic, - memberNameIndex = (fromIntegral (baseIndex + 2)), - memberDescriptorIndex = (fromIntegral (baseIndex + 3)), - memberAttributes = [] - } - in - input { - constantPool = (constantPool input) ++ constants, - fields = (fields input) ++ [field] - } diff --git a/src/ByteCode/Generation/Builder/Method.hs b/src/ByteCode/Generation/Builder/Method.hs deleted file mode 100644 index 5475d4d..0000000 --- a/src/ByteCode/Generation/Builder/Method.hs +++ /dev/null @@ -1,80 +0,0 @@ -module ByteCode.Generation.Builder.Method where - -import Ast -import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) -import ByteCode.Generation.Generator -import ByteCode.Generation.Assembler.Method -import ByteCode.Constants -import Data.List - -methodDescriptor :: MethodDeclaration -> String -methodDescriptor (MethodDeclaration returntype _ parameters _) = let - parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters] - in - "(" - ++ (concat (map methodParameterDescriptor parameter_types)) - ++ ")" - ++ methodParameterDescriptor returntype - -methodParameterDescriptor :: String -> 
String -methodParameterDescriptor "void" = "V" -methodParameterDescriptor "int" = "I" -methodParameterDescriptor "char" = "C" -methodParameterDescriptor "boolean" = "B" -methodParameterDescriptor x = "L" ++ x ++ ";" - -memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool -memberInfoIsMethod constants info = elem '(' (memberInfoDescriptor constants info) - -findMethodIndex :: ClassFile -> String -> Maybe Int -findMethodIndex classFile name = let - constants = constantPool classFile - in - findIndex (\method -> ((memberInfoIsMethod constants method) && (memberInfoName constants method) == name)) (methods classFile) - - -methodBuilder :: ClassFileBuilder MethodDeclaration -methodBuilder (MethodDeclaration returntype name parameters statement) input = let - baseIndex = 1 + length (constantPool input) - constants = [ - Utf8Info name, - Utf8Info (methodDescriptor (MethodDeclaration returntype name parameters (Block []))) - ] - - method = MemberInfo { - memberAccessFlags = accessPublic, - memberNameIndex = (fromIntegral baseIndex), - memberDescriptorIndex = (fromIntegral (baseIndex + 1)), - memberAttributes = [] - } - in - input { - constantPool = (constantPool input) ++ constants, - methods = (methods input) ++ [method] - } - - - -methodAssembler :: ClassFileBuilder MethodDeclaration -methodAssembler (MethodDeclaration returntype name parameters statement) input = let - methodConstantIndex = findMethodIndex input name - in case methodConstantIndex of - Nothing -> error ("Cannot find method entry in method pool for method: " ++ name) - Just index -> let - declaration = MethodDeclaration returntype name parameters statement - paramNames = "this" : [name | ParameterDeclaration _ name <- parameters] - (pre, method : post) = splitAt index (methods input) - (_, bytecode, _) = assembleMethod (constantPool input, [], paramNames) declaration - assembledMethod = method { - memberAttributes = [ - CodeAttribute { - attributeMaxStack = 420, - attributeMaxLocals = 420, - 
attributeCode = bytecode - } - ] - } - in - input { - methods = pre ++ (assembledMethod : post) - } diff --git a/src/ByteCode/Generation/Generator.hs b/src/ByteCode/Generation/Generator.hs deleted file mode 100644 index 6d42ba0..0000000 --- a/src/ByteCode/Generation/Generator.hs +++ /dev/null @@ -1,73 +0,0 @@ -module ByteCode.Generation.Generator( - datatypeDescriptor, - memberInfoName, - memberInfoDescriptor, - returnOperation, - binaryOperation, - comparisonOperation, - ClassFileBuilder, - Assembler -) where - -import ByteCode.Constants -import ByteCode.ClassFile (ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) -import Ast -import Data.Char -import Data.List -import Data.Word - -type ClassFileBuilder a = a -> ClassFile -> ClassFile -type Assembler a = ([ConstantInfo], [Operation], [String]) -> a -> ([ConstantInfo], [Operation], [String]) - -datatypeDescriptor :: String -> String -datatypeDescriptor "void" = "V" -datatypeDescriptor "int" = "I" -datatypeDescriptor "char" = "C" -datatypeDescriptor "boolean" = "B" -datatypeDescriptor x = "L" ++ x - -memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String -memberInfoDescriptor constants MemberInfo { - memberAccessFlags = _, - memberNameIndex = _, - memberDescriptorIndex = descriptorIndex, - memberAttributes = _ } = let - descriptor = constants!!((fromIntegral descriptorIndex) - 1) - in case descriptor of - Utf8Info descriptorText -> descriptorText - _ -> ("Invalid Item at Constant pool index " ++ show descriptorIndex) - - -memberInfoName :: [ConstantInfo] -> MemberInfo -> String -memberInfoName constants MemberInfo { - memberAccessFlags = _, - memberNameIndex = nameIndex, - memberDescriptorIndex = _, - memberAttributes = _ } = let - name = constants!!((fromIntegral nameIndex) - 1) - in case name of - Utf8Info nameText -> nameText - _ -> ("Invalid Item at Constant pool index " ++ show nameIndex) - - -returnOperation :: DataType -> Operation -returnOperation 
dtype - | elem dtype ["int", "char", "boolean"] = Opireturn - | otherwise = Opareturn - -binaryOperation :: BinaryOperator -> Operation -binaryOperation Addition = Opiadd -binaryOperation Subtraction = Opisub -binaryOperation Multiplication = Opimul -binaryOperation Division = Opidiv -binaryOperation BitwiseAnd = Opiand -binaryOperation BitwiseOr = Opior -binaryOperation BitwiseXor = Opixor - -comparisonOperation :: BinaryOperator -> Word16 -> Operation -comparisonOperation CompareEqual branchLocation = Opif_icmpeq branchLocation -comparisonOperation CompareNotEqual branchLocation = Opif_icmpne branchLocation -comparisonOperation CompareLessThan branchLocation = Opif_icmplt branchLocation -comparisonOperation CompareLessOrEqual branchLocation = Opif_icmple branchLocation -comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation -comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation \ No newline at end of file diff --git a/src/ByteCode/Util.hs b/src/ByteCode/Util.hs new file mode 100644 index 0000000..45e8816 --- /dev/null +++ b/src/ByteCode/Util.hs @@ -0,0 +1,245 @@ +module ByteCode.Util where + +import Data.Int +import Ast +import ByteCode.ClassFile +import Data.List +import Data.Maybe (mapMaybe) +import Data.Word (Word8, Word16, Word32) + +-- walks the name resolution chain. returns the innermost Just LocalVariable/FieldVariable or Nothing. +resolveNameChain :: Expression -> Expression +resolveNameChain (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b +resolveNameChain (TypedExpression dtype (LocalVariable name)) = (TypedExpression dtype (LocalVariable name)) +resolveNameChain (TypedExpression dtype (FieldVariable name)) = (TypedExpression dtype (FieldVariable name)) +resolveNameChain invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression)) + +-- walks the name resolution chain. 
returns the second-to-last item of the namechain. +resolveNameChainOwner :: Expression -> Expression +resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a (TypedExpression dtype (FieldVariable name)))) = a +resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b +resolveNameChainOwner invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression)) + + +methodDescriptor :: MethodDeclaration -> String +methodDescriptor (MethodDeclaration returntype _ parameters _) = let + parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters] + in + "(" + ++ (concat (map datatypeDescriptor parameter_types)) + ++ ")" + ++ datatypeDescriptor returntype + +methodDescriptorFromParamlist :: [Expression] -> String -> String +methodDescriptorFromParamlist parameters returntype = let + parameter_types = [datatype | TypedExpression datatype _ <- parameters] + in + "(" + ++ (concat (map datatypeDescriptor parameter_types)) + ++ ")" + ++ datatypeDescriptor returntype + +memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool +memberInfoIsMethod constants info = elem '(' (memberInfoDescriptor constants info) + + +datatypeDescriptor :: String -> String +datatypeDescriptor "void" = "V" +datatypeDescriptor "int" = "I" +datatypeDescriptor "char" = "C" +datatypeDescriptor "boolean" = "B" +datatypeDescriptor x = "L" ++ x ++ ";" + + +memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String +memberInfoDescriptor constants MemberInfo { + memberAccessFlags = _, + memberNameIndex = _, + memberDescriptorIndex = descriptorIndex, + memberAttributes = _ } = let + descriptor = constants!!((fromIntegral descriptorIndex) - 1) + in case descriptor of + Utf8Info descriptorText -> descriptorText + _ -> ("Invalid Item at Constant pool index " ++ show descriptorIndex) + + +memberInfoName :: [ConstantInfo] -> MemberInfo -> String +memberInfoName constants MemberInfo { + 
memberAccessFlags = _, + memberNameIndex = nameIndex, + memberDescriptorIndex = _, + memberAttributes = _ } = let + name = constants!!((fromIntegral nameIndex) - 1) + in case name of + Utf8Info nameText -> nameText + _ -> ("Invalid Item at Constant pool index " ++ show nameIndex) + + +returnOperation :: DataType -> Operation +returnOperation dtype + | elem dtype ["int", "char", "boolean"] = Opireturn + | otherwise = Opareturn + +binaryOperation :: BinaryOperator -> Operation +binaryOperation Addition = Opiadd +binaryOperation Subtraction = Opisub +binaryOperation Multiplication = Opimul +binaryOperation Division = Opidiv +binaryOperation Modulo = Opirem +binaryOperation BitwiseAnd = Opiand +binaryOperation BitwiseOr = Opior +binaryOperation BitwiseXor = Opixor +binaryOperation And = Opiand +binaryOperation Or = Opior + +comparisonOperation :: BinaryOperator -> Word16 -> Operation +comparisonOperation CompareEqual branchLocation = Opif_icmpeq branchLocation +comparisonOperation CompareNotEqual branchLocation = Opif_icmpne branchLocation +comparisonOperation CompareLessThan branchLocation = Opif_icmplt branchLocation +comparisonOperation CompareLessOrEqual branchLocation = Opif_icmple branchLocation +comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation +comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation + +findFieldIndex :: [ConstantInfo] -> String -> Maybe Int +findFieldIndex constants name = let + fieldRefNameInfos = [ + -- we only skip one entry to get the name since the Java constant pool + -- is 1-indexed (why) + (index, constants!!(fromIntegral index + 1)) + | (index, FieldRefInfo classIndex _) <- (zip [1..] 
constants) + ] + fieldRefNames = map (\(index, nameInfo) -> case nameInfo of + Utf8Info fieldName -> (index, fieldName) + something_else -> error ("Expected UTF8Info but got" ++ show something_else)) + fieldRefNameInfos + fieldIndex = find (\(index, fieldName) -> fieldName == name) fieldRefNames + in case fieldIndex of + Just (index, _) -> Just index + Nothing -> Nothing + +findMethodRefIndex :: [ConstantInfo] -> String -> Maybe Int +findMethodRefIndex constants name = let + methodRefNameInfos = [ + -- we only skip one entry to get the name since the Java constant pool + -- is 1-indexed (why) + (index, constants!!(fromIntegral index + 1)) + | (index, MethodRefInfo _ _) <- (zip [1..] constants) + ] + methodRefNames = map (\(index, nameInfo) -> case nameInfo of + Utf8Info methodName -> (index, methodName) + something_else -> error ("Expected UTF8Info but got " ++ show something_else)) + methodRefNameInfos + methodIndex = find (\(index, methodName) -> methodName == name) methodRefNames + in case methodIndex of + Just (index, _) -> Just index + Nothing -> Nothing + + +findMethodIndex :: ClassFile -> String -> Maybe Int +findMethodIndex classFile name = let + constants = constantPool classFile + in + findIndex (\method -> ((memberInfoIsMethod constants method) && (memberInfoName constants method) == name)) (methods classFile) + +findClassIndex :: [ConstantInfo] -> String -> Maybe Int +findClassIndex constants name = let + classNameIndices = [(index, constants!!(fromIntegral nameIndex - 1)) | (index, ClassInfo nameIndex) <- (zip[1..] 
constants)] + classNames = map (\(index, nameInfo) -> case nameInfo of + Utf8Info className -> (index, className) + something_else -> error("Expected UTF8Info but got " ++ show something_else)) + classNameIndices + desiredClassIndex = find (\(index, className) -> className == name) classNames + in case desiredClassIndex of + Just (index, _) -> Just index + Nothing -> Nothing
+
+-- Resolve every FieldRefInfo and MethodRefInfo entry of the constant pool to
+-- (pool index, (class name, member name, type string)).
+-- Pool indices are 1-based: the entry with index i is `constants !! (i - 1)`.
+-- Calls `error` when a referenced entry is not of the expected constant kind.
+getKnownMembers :: [ConstantInfo] -> [(Int, (String, String, String))]
+getKnownMembers constants = let
+    -- step 1: pair each field/method reference with its raw class and name-and-type entries
+    fieldsClassAndNT = [
+      (index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1))
+      | (index, FieldRefInfo classIndex nameTypeIndex) <- (zip [1..] constants)
+      ] ++ [
+      (index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1))
+      | (index, MethodRefInfo classIndex nameTypeIndex) <- (zip [1..] constants)
+      ]
+
+    -- step 2: follow the ClassInfo / NameAndTypeInfo indirections to the entries holding the strings
+    fieldsClassNameType = map (\(index, nameInfo, nameTypeInfo) -> case (nameInfo, nameTypeInfo) of
+        (ClassInfo nameIndex, NameAndTypeInfo fnameIndex ftypeIndex) -> (index, (constants!!(fromIntegral nameIndex - 1), constants!!(fromIntegral fnameIndex - 1), constants!!(fromIntegral ftypeIndex - 1)))
+        something_else -> error ("Expected Class and NameType info, but got: " ++ show nameInfo ++ " and " ++ show nameTypeInfo))
+      fieldsClassAndNT
+
+    -- step 3: unwrap the three Utf8Info entries into plain strings
+    fieldsResolved = map (\(index, (nameInfo, fnameInfo, ftypeInfo)) -> case (nameInfo, fnameInfo, ftypeInfo) of
+        (Utf8Info cname, Utf8Info fname, Utf8Info ftype) -> (index, (cname, fname, ftype))
+        something_else -> error("Expected UTF8Infos but got " ++ show something_else))
+      fieldsClassNameType
+  in
+    fieldsResolved
+
+-- same as findClassIndex, but inserts a new entry into constant pool if not existing.
+-- Returns the (possibly extended) pool together with the 1-based index of the ClassInfo entry.
+getClassIndex :: [ConstantInfo] -> String -> ([ConstantInfo], Int)
+getClassIndex constants name = case findClassIndex constants name of
+  Just index -> (constants, index)
+  Nothing -> let
+      -- Pool indices are 1-based (see getKnownMembers / getFieldIndex), so the
+      -- appended ClassInfo lands at `length constants + 1` and the Utf8Info
+      -- holding its name directly behind it at `length constants + 2`.
+      classIndex = 1 + length constants
+    in
+      (constants ++ [ClassInfo (fromIntegral (classIndex + 1)), Utf8Info name], classIndex)
+
+-- get the index for a field within a class, creating it if it does not exist.
+-- Takes (class name, field name, field type) and returns the (possibly extended)
+-- pool together with the 1-based index of the FieldRefInfo entry.
+-- NOTE(review): the lookup compares the raw `ftype`, while new entries store
+-- `datatypeDescriptor ftype`; if those differ the lookup never hits and every
+-- call appends a duplicate entry -- confirm against datatypeDescriptor.
+getFieldIndex :: [ConstantInfo] -> (String, String, String) -> ([ConstantInfo], Int)
+getFieldIndex constants (cname, fname, ftype) = case findMemberIndex constants (cname, fname, ftype) of
+  Just index -> (constants, index)
+  Nothing -> let
+      (constantsWithClass, classIndex) = getClassIndex constants cname
+      -- 1-based index of the first entry appended below
+      baseIndex = 1 + length constantsWithClass
+    in
+      (constantsWithClass ++ [
+        FieldRefInfo (fromIntegral classIndex) (fromIntegral (baseIndex + 1)),
+        NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)),
+        Utf8Info fname,
+        Utf8Info (datatypeDescriptor ftype)
+      ], baseIndex)
+
+-- get the index for a method within a class, creating it if it does not exist.
+-- Takes (class name, method name, method type string); the type string is stored
+-- in the pool unchanged. Returns the (possibly extended) pool together with the
+-- 1-based index of the MethodRefInfo entry.
+getMethodIndex :: [ConstantInfo] -> (String, String, String) -> ([ConstantInfo], Int)
+getMethodIndex constants (cname, mname, mtype) = case findMemberIndex constants (cname, mname, mtype) of
+  Just index -> (constants, index)
+  Nothing -> let
+      (constantsWithClass, classIndex) = getClassIndex constants cname
+      -- 1-based index of the first entry appended below
+      baseIndex = 1 + length constantsWithClass
+    in
+      (constantsWithClass ++ [
+        MethodRefInfo (fromIntegral classIndex) (fromIntegral (baseIndex + 1)),
+        NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)),
+        Utf8Info mname,
+        Utf8Info mtype
+      ], baseIndex)
+
+-- Look up the pool index of the first field or method reference whose resolved
+-- (class name, member name, type string) equals the given triple.
+-- Field and method references are searched together (fields first).
+findMemberIndex :: [ConstantInfo] -> (String, String, String) -> Maybe Int
+findMemberIndex constants (cname, fname, ftype) = let
+    allMembers = getKnownMembers constants
+    desiredMember = find (\(index, (c, f, ft)) -> (c, f, ft) == (cname, fname, ftype)) allMembers
+  in
+    fmap (\(index, _) -> index) desiredMember
+
+-- Append a parameterless "void" method with the empty name and an empty block
+-- body, unless the list already contains a method with the empty name.
+-- NOTE(review): the empty method name presumably marks a constructor -- confirm
+-- against the parser / typechecker.
+injectDefaultConstructor :: [MethodDeclaration] -> [MethodDeclaration]
+injectDefaultConstructor pre
+  | any (\(MethodDeclaration _ name _ _) -> name == "") pre = pre
+  | otherwise = pre ++ [MethodDeclaration "void" "" [] (TypedStatement "void" (Block []))]
+
+-- Build one typed assignment statement `this.<field> = <initializer>` for every
+-- variable declaration that carries an initializer, and prepend those statements
+-- to the body of every "void" method with the empty name whose body is a Block.
+-- All other methods are returned unchanged.
+injectFieldInitializers :: String -> [VariableDeclaration] -> [MethodDeclaration] -> [MethodDeclaration]
+injectFieldInitializers classname vars pre = let
+    initializers = mapMaybe (\(variable) -> case variable of
+        VariableDeclaration dtype name (Just initializer) -> Just (
+            TypedStatement dtype (
+              StatementExpressionStatement (
+                TypedStatementExpression dtype (
+                  Assignment
+                    (TypedExpression dtype (BinaryOperation NameResolution (TypedExpression classname (LocalVariable "this")) (TypedExpression dtype (FieldVariable name))))
+                    initializer
+                )
+              )
+            )
+          )
+        otherwise -> Nothing
+      ) vars
+  in
+    map (\(method) -> case method of
+        MethodDeclaration "void" "" params (TypedStatement "void" (Block statements)) -> MethodDeclaration "void" "" params (TypedStatement "void" (Block (initializers ++ statements)))
+        otherwise -> method
+      ) pre
\ No newline at end of file
diff --git a/src/Main.hs b/src/Main.hs
index 588efc2..3f4c329 100644
--- a/src/Main.hs
+++ b/src/Main.hs
@@ -4,17 +4,30 @@
 import Example
 import Typecheck
 import Parser.Lexer (alexScanTokens)
 import Parser.JavaParser
-import ByteCode.Generation.Generator
-import ByteCode.Generation.Builder.Class
+import ByteCode.Builder
 import ByteCode.ClassFile
 import Data.ByteString (pack, writeFile)
+import System.Environment
+import System.FilePath.Posix (takeDirectory)
 
 main = do
-  file <- readFile "Testklasse.java"
+  args <- getArgs -- the first command line argument is the source file to compile
+  let filename = if null args
+        then error "Missing filename, I need to know what to compile"
+        else args!!0
+  let outputDirectory = takeDirectory filename -- class files are written next to the source file
+  print ("Compiling " ++ filename)
+  file <- readFile filename
   let untypedAST = parse $ alexScanTokens file
-  let typedAST = head (typeCheckCompilationUnit untypedAST)
-  let abstractClassFile = classBuilder typedAST emptyClassFile
-  let assembledClassFile = pack (serialize abstractClassFile)
+  let typedAST = (typeCheckCompilationUnit untypedAST)
+  let assembledClasses = map (\(typedClass) -> classBuilder typedClass emptyClassFile) typedAST -- one class file per typed class
 
-  Data.ByteString.writeFile "Testklasse.class" assembledClassFile
+  mapM_ (\(classFile) -> let -- serialize each class file to <outputDirectory>/<className>.class
+      fileContent = pack (serialize classFile)
+      fileName = outputDirectory ++ "/" ++ (className classFile) ++ ".class"
+    in Data.ByteString.writeFile fileName fileContent
+    ) assembledClasses
+
+
+