diff --git a/project.cabal b/project.cabal index cf8bce0..ad9b335 100644 --- a/project.cabal +++ b/project.cabal @@ -12,18 +12,23 @@ executable compiler utf8-string, bytestring default-language: Haskell2010 - hs-source-dirs: src, - src/ByteCode, - src/ByteCode/ClassFile + hs-source-dirs: src build-tool-depends: alex:alex, happy:happy other-modules: Parser.Lexer, - Parser.JavaParser + Parser.JavaParser, Ast, Example, Typecheck, ByteCode.ByteUtil, ByteCode.ClassFile, - ByteCode.ClassFile.Generator, + ByteCode.Generation.Generator, + ByteCode.Generation.Assembler.Expression, + ByteCode.Generation.Assembler.Method, + ByteCode.Generation.Assembler.Statement, + ByteCode.Generation.Assembler.StatementExpression, + ByteCode.Generation.Builder.Class, + ByteCode.Generation.Builder.Field, + ByteCode.Generation.Builder.Method, ByteCode.Constants test-suite tests diff --git a/src/ByteCode/ClassFile.hs b/src/ByteCode/ClassFile.hs index 7798cf8..358b91a 100644 --- a/src/ByteCode/ClassFile.hs +++ b/src/ByteCode/ClassFile.hs @@ -32,6 +32,7 @@ data Operation = Opiadd | Opior | Opixor | Opineg + | Opdup | Opif_icmplt Word16 | Opif_icmple Word16 | Opif_icmpgt Word16 @@ -51,7 +52,7 @@ data Operation = Opiadd | Opastore Word16 | Opistore Word16 | Opputfield Word16 - | OpgetField Word16 + | Opgetfield Word16 deriving (Show, Eq) @@ -99,6 +100,7 @@ opcodeEncodingLength Opiand = 1 opcodeEncodingLength Opior = 1 opcodeEncodingLength Opixor = 1 opcodeEncodingLength Opineg = 1 +opcodeEncodingLength Opdup = 1 opcodeEncodingLength (Opif_icmplt _) = 3 opcodeEncodingLength (Opif_icmple _) = 3 opcodeEncodingLength (Opif_icmpgt _) = 3 @@ -113,12 +115,12 @@ opcodeEncodingLength (Opinvokespecial _) = 3 opcodeEncodingLength (Opgoto _) = 3 opcodeEncodingLength (Opsipush _) = 3 opcodeEncodingLength (Opldc_w _) = 3 -opcodeEncodingLength (Opaload _) = 3 -opcodeEncodingLength (Opiload _) = 3 -opcodeEncodingLength (Opastore _) = 3 -opcodeEncodingLength (Opistore _) = 3 +opcodeEncodingLength (Opaload _) = 4 +opcodeEncodingLength (Opiload _) = 4 +opcodeEncodingLength (Opastore _) = 4 +opcodeEncodingLength (Opistore _) = 4 opcodeEncodingLength (Opputfield _) = 3 -opcodeEncodingLength (OpgetField _) = 3 +opcodeEncodingLength (Opgetfield _) = 3 class Serializable a where serialize :: a -> [Word8] @@ -149,6 +151,7 @@ instance Serializable Operation where serialize Opior = [0x80] serialize Opixor = [0x82] serialize Opineg = [0x74] + serialize Opdup = [0x59] serialize (Opif_icmplt branch) = 0xA1 : unpackWord16 branch serialize (Opif_icmple branch) = 0xA4 : unpackWord16 branch serialize (Opif_icmpgt branch) = 0xA3 : unpackWord16 branch @@ -168,7 +171,7 @@ instance Serializable Operation where serialize (Opastore index) = [0xC4, 0x3A] ++ unpackWord16 index serialize (Opistore index) = [0xC4, 0x36] ++ unpackWord16 index serialize (Opputfield index) = 0xB5 : unpackWord16 index - serialize (OpgetField index) = 0xB4 : unpackWord16 index + serialize (Opgetfield index) = 0xB4 : unpackWord16 index instance Serializable Attribute where serialize (CodeAttribute { attributeMaxStack = maxStack, diff --git a/src/ByteCode/ClassFile/Generator.hs b/src/ByteCode/ClassFile/Generator.hs deleted file mode 100644 index 37919e0..0000000 --- a/src/ByteCode/ClassFile/Generator.hs +++ /dev/null @@ -1,273 +0,0 @@ -module ByteCode.ClassFile.Generator( - classBuilder, - datatypeDescriptor, - methodParameterDescriptor, - methodDescriptor, - memberInfoIsMethod, - memberInfoName, - memberInfoDescriptor, - findMethodIndex -) where - -import ByteCode.Constants -import ByteCode.ClassFile (ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) -import Ast -import Data.Char -import Data.List -import Data.Word - - -type ClassFileBuilder a = a -> ClassFile -> ClassFile - - -datatypeDescriptor :: String -> String -datatypeDescriptor "void" = "V" -datatypeDescriptor "int" = "I" -datatypeDescriptor "char" = "C" -datatypeDescriptor "boolean" = "B" -datatypeDescriptor x = "L" ++ x - -methodParameterDescriptor :: String -> String -methodParameterDescriptor "void" = "V" -methodParameterDescriptor "int" = "I" -methodParameterDescriptor "char" = "C" -methodParameterDescriptor "boolean" = "B" -methodParameterDescriptor x = "L" ++ x ++ ";" - -memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool -memberInfoIsMethod constants info = elem '(' (memberInfoDescriptor constants info) - -findMethodIndex :: ClassFile -> String -> Maybe Int -findMethodIndex classFile name = let - constants = constantPool classFile - in - findIndex (\method -> ((memberInfoIsMethod constants method) && (memberInfoName constants method) == name)) (methods classFile) - -memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String -memberInfoDescriptor constants MemberInfo { - memberAccessFlags = _, - memberNameIndex = _, - memberDescriptorIndex = descriptorIndex, - memberAttributes = _ } = let - descriptor = constants!!((fromIntegral descriptorIndex) - 1) - in case descriptor of - Utf8Info descriptorText -> descriptorText - _ -> ("Invalid Item at Constant pool index " ++ show descriptorIndex) - - -memberInfoName :: [ConstantInfo] -> MemberInfo -> String -memberInfoName constants MemberInfo { - memberAccessFlags = _, - memberNameIndex = nameIndex, - memberDescriptorIndex = _, - memberAttributes = _ } = let - name = constants!!((fromIntegral nameIndex) - 1) - in case name of - Utf8Info nameText -> nameText - _ -> ("Invalid Item at Constant pool index " ++ show nameIndex) - - -methodDescriptor :: MethodDeclaration -> String -methodDescriptor (MethodDeclaration returntype _ parameters _) = let - parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters] - in - "(" - ++ (concat (map methodParameterDescriptor parameter_types)) - ++ ")" - ++ datatypeDescriptor returntype - - -classBuilder :: ClassFileBuilder Class -classBuilder (Class name methods fields) _ = let - baseConstants = [ - ClassInfo 4, - MethodRefInfo 1 3, - NameAndTypeInfo 5 6, - Utf8Info "java/lang/Object", - Utf8Info "", - Utf8Info "()V", - Utf8Info "Code" - ] - nameConstants = [ClassInfo 9, Utf8Info name] - nakedClassFile = ClassFile { - constantPool = baseConstants ++ nameConstants, - accessFlags = accessPublic, - thisClass = 8, - superClass = 1, - fields = [], - methods = [], - attributes = [] - } - - classFileWithFields = foldr fieldBuilder nakedClassFile fields - classFileWithMethods = foldr methodBuilder classFileWithFields methods - classFileWithAssembledMethods = foldr methodAssembler classFileWithMethods methods - in - classFileWithAssembledMethods - - - -fieldBuilder :: ClassFileBuilder VariableDeclaration -fieldBuilder (VariableDeclaration datatype name _) input = let - baseIndex = 1 + length (constantPool input) - constants = [ - FieldRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)), - NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)), - Utf8Info name, - Utf8Info (datatypeDescriptor datatype) - ] - field = MemberInfo { - memberAccessFlags = accessPublic, - memberNameIndex = (fromIntegral (baseIndex + 2)), - memberDescriptorIndex = (fromIntegral (baseIndex + 3)), - memberAttributes = [] - } - in - input { - constantPool = (constantPool input) ++ constants, - fields = (fields input) ++ [field] - } - -methodBuilder :: ClassFileBuilder MethodDeclaration -methodBuilder (MethodDeclaration returntype name parameters statement) input = let - baseIndex = 1 + length (constantPool input) - constants = [ - Utf8Info name, - Utf8Info (methodDescriptor (MethodDeclaration returntype name parameters (Block []))) - ] - - method = MemberInfo { - memberAccessFlags = accessPublic, - memberNameIndex = (fromIntegral baseIndex), - memberDescriptorIndex = (fromIntegral (baseIndex + 1)), - memberAttributes = [] - } - in - input { - constantPool = (constantPool input) ++ constants, - methods = (methods input) ++ [method] - } - - -methodAssembler :: ClassFileBuilder MethodDeclaration -methodAssembler (MethodDeclaration returntype name parameters statement) input = let - methodConstantIndex = findMethodIndex input name - in case methodConstantIndex of - Nothing -> error ("Cannot find method entry in method pool for method: " ++ name) - Just index -> let - declaration = MethodDeclaration returntype name parameters statement - (pre, method : post) = splitAt index (methods input) - (_, bytecode) = assembleMethod (constantPool input, []) declaration - assembledMethod = method { - memberAttributes = [ - CodeAttribute { - attributeMaxStack = 420, - attributeMaxLocals = 420, - attributeCode = bytecode - } - ] - } - in - input { - methods = pre ++ (assembledMethod : post) - } - - - - -type Assembler a = ([ConstantInfo], [Operation]) -> a -> ([ConstantInfo], [Operation]) - -returnOperation :: DataType -> Operation -returnOperation dtype - | elem dtype ["int", "char", "boolean"] = Opireturn - | otherwise = Opareturn - -binaryOperation :: BinaryOperator -> Operation -binaryOperation Addition = Opiadd -binaryOperation Subtraction = Opisub -binaryOperation Multiplication = Opimul -binaryOperation Division = Opidiv -binaryOperation BitwiseAnd = Opiand -binaryOperation BitwiseOr = Opior -binaryOperation BitwiseXor = Opixor - -comparisonOperation :: BinaryOperator -> Word16 -> Operation -comparisonOperation CompareEqual branchLocation = Opif_icmpeq branchLocation -comparisonOperation CompareNotEqual branchLocation = Opif_icmpne branchLocation -comparisonOperation CompareLessThan branchLocation = Opif_icmplt branchLocation -comparisonOperation CompareLessOrEqual branchLocation = Opif_icmple branchLocation -comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation -comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation - - -assembleMethod :: Assembler MethodDeclaration -assembleMethod (constants, ops) (MethodDeclaration _ name _ (TypedStatement _ (Block statements))) - | name == "" = let - (constants_a, ops_a) = foldl assembleStatement (constants, ops) statements - init_ops = [Opaload 0, Opinvokespecial 2] - in - (constants_a, init_ops ++ ops_a ++ [Opreturn]) - | otherwise = let - (constants_a, ops_a) = foldl assembleStatement (constants, ops) statements - init_ops = [Opaload 0] - in - (constants_a, init_ops ++ ops_a) -assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Block expected for method body, got: " ++ show stmt) - -assembleStatement :: Assembler Statement -assembleStatement (constants, ops) (TypedStatement stype (Return expr)) = case expr of - Nothing -> (constants, ops ++ [Opreturn]) - Just expr -> let - (expr_constants, expr_ops) = assembleExpression (constants, ops) expr - in - (expr_constants, expr_ops ++ [returnOperation stype]) -assembleStatement (constants, ops) (TypedStatement _ (Block statements)) = - foldl assembleStatement (constants, ops) statements -assembleStatement (constants, ops) (TypedStatement _ (If expr if_stmt else_stmt)) = let - (constants_cmp, ops_cmp) = assembleExpression (constants, []) expr - (constants_ifa, ops_ifa) = assembleStatement (constants_cmp, []) if_stmt - (constants_elsea, ops_elsea) = case else_stmt of - Nothing -> (constants_ifa, []) - Just stmt -> assembleStatement (constants_ifa, []) stmt - -- +6 because we insert 2 gotos, one for if, one for else - if_length = sum (map opcodeEncodingLength ops_ifa) + 6 - -- +3 because we need to account for the goto in the if statement. - else_length = sum (map opcodeEncodingLength ops_elsea) + 3 - in - (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq if_length] ++ ops_ifa ++ [Opgoto else_length] ++ ops_elsea) -assembleStatement stmt _ = error ("Not yet implemented: " ++ show stmt) - -assembleExpression :: Assembler Expression -assembleExpression (constants, ops) (TypedExpression _ (BinaryOperation op a b)) - | elem op [Addition, Subtraction, Multiplication, Division, BitwiseAnd, BitwiseOr, BitwiseXor] = let - (aConstants, aOps) = assembleExpression (constants, ops) a - (bConstants, bOps) = assembleExpression (aConstants, aOps) b - in - (bConstants, bOps ++ [binaryOperation op]) - | elem op [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let - (aConstants, aOps) = assembleExpression (constants, ops) a - (bConstants, bOps) = assembleExpression (aConstants, aOps) b - cmp_op = comparisonOperation op 9 - cmp_ops = [cmp_op, Opsipush 0, Opgoto 6, Opsipush 1] - in - (bConstants, bOps ++ cmp_ops) -assembleExpression (constants, ops) (TypedExpression _ (CharacterLiteral literal)) = - (constants, ops ++ [Opsipush (fromIntegral (ord literal))]) -assembleExpression (constants, ops) (TypedExpression _ (BooleanLiteral literal)) = - (constants, ops ++ [Opsipush (if literal then 1 else 0)]) -assembleExpression (constants, ops) (TypedExpression _ (IntegerLiteral literal)) - | literal <= 32767 && literal >= -32768 = (constants, ops ++ [Opsipush (fromIntegral literal)]) - | otherwise = (constants ++ [IntegerInfo (fromIntegral literal)], ops ++ [Opldc_w (fromIntegral (1 + length constants))]) -assembleExpression (constants, ops) (TypedExpression _ NullLiteral) = - (constants, ops ++ [Opaconst_null]) -assembleExpression (constants, ops) (TypedExpression etype (UnaryOperation Not expr)) = let - (exprConstants, exprOps) = assembleExpression (constants, ops) expr - newConstant = fromIntegral (1 + length exprConstants) - in case etype of - "int" -> (exprConstants ++ [IntegerInfo 0x7FFFFFFF], exprOps ++ [Opldc_w newConstant, Opixor]) - "char" -> (exprConstants, exprOps ++ [Opsipush 0xFFFF, Opixor]) - "boolean" -> (exprConstants, exprOps ++ [Opsipush 0x01, Opixor]) -assembleExpression (constants, ops) (TypedExpression _ (UnaryOperation Minus expr)) = let - (exprConstants, exprOps) = assembleExpression (constants, ops) expr - in - (exprConstants, exprOps ++ [Opineg]) diff --git a/src/ByteCode/Generation/Assembler/Expression.hs b/src/ByteCode/Generation/Assembler/Expression.hs new file mode 100644 index 0000000..6921956 --- /dev/null +++ b/src/ByteCode/Generation/Assembler/Expression.hs @@ -0,0 +1,153 @@ +module ByteCode.Generation.Assembler.Expression where + +import Ast +import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) +import ByteCode.Generation.Generator +import Data.List +import Data.Char +import ByteCode.Generation.Builder.Field + +assembleExpression :: Assembler Expression +assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b)) + | elem op [Addition, Subtraction, Multiplication, Division, BitwiseAnd, BitwiseOr, BitwiseXor] = let + (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a + (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b + in + (bConstants, bOps ++ [binaryOperation op], lvars) + | elem op [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let + (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a + (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b + cmp_op = comparisonOperation op 9 + cmp_ops = [cmp_op, Opsipush 0, Opgoto 6, Opsipush 1] + in + (bConstants, bOps ++ cmp_ops, lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (CharacterLiteral literal)) = + (constants, ops ++ [Opsipush (fromIntegral (ord literal))], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (BooleanLiteral literal)) = + (constants, ops ++ [Opsipush (if literal then 1 else 0)], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (IntegerLiteral literal)) + | literal <= 32767 && literal >= -32768 = (constants, ops ++ [Opsipush (fromIntegral literal)], lvars) + | otherwise = (constants ++ [IntegerInfo (fromIntegral literal)], ops ++ [Opldc_w (fromIntegral (1 + length constants))], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ NullLiteral) = + (constants, ops ++ [Opaconst_null], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression etype (UnaryOperation Not expr)) = let + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + newConstant = fromIntegral (1 + length exprConstants) + in case etype of + "int" -> (exprConstants ++ [IntegerInfo 0x7FFFFFFF], exprOps ++ [Opldc_w newConstant, Opixor], lvars) + "char" -> (exprConstants, exprOps ++ [Opsipush 0xFFFF, Opixor], lvars) + "boolean" -> (exprConstants, exprOps ++ [Opsipush 0x01, Opixor], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (UnaryOperation Minus expr)) = let + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + in + (exprConstants, exprOps ++ [Opineg], lvars) + +assembleExpression + (constants, ops, lvars) + (TypedExpression _ (UnaryOperation PreIncrement (TypedExpression dtype (LocalVariable name)))) = let + localIndex = findIndex ((==) name) lvars + expr = (TypedExpression dtype (LocalVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + incrOps = exprOps ++ [Opsipush 1, Opiadd, Opdup] + in case localIndex of + Just index -> (exprConstants, incrOps ++ [Opistore (fromIntegral index)], lvars) + Nothing -> error("No such local variable found in local variable pool: " ++ name) + +assembleExpression + (constants, ops, lvars) + (TypedExpression _ (UnaryOperation PostIncrement (TypedExpression dtype (LocalVariable name)))) = let + localIndex = findIndex ((==) name) lvars + expr = (TypedExpression dtype (LocalVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + incrOps = exprOps ++ [Opdup, Opsipush 1, Opiadd] + in case localIndex of + Just index -> (exprConstants, incrOps ++ [Opistore (fromIntegral index)], lvars) + Nothing -> error("No such local variable found in local variable pool: " ++ name) + +assembleExpression + (constants, ops, lvars) + (TypedExpression _ (UnaryOperation PreDecrement (TypedExpression dtype (LocalVariable name)))) = let + localIndex = findIndex ((==) name) lvars + expr = (TypedExpression dtype (LocalVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + incrOps = exprOps ++ [Opsipush 1, Opiadd, Opisub] + in case localIndex of + Just index -> (exprConstants, incrOps ++ [Opistore (fromIntegral index)], lvars) + Nothing -> error("No such local variable found in local variable pool: " ++ name) + +assembleExpression + (constants, ops, lvars) + (TypedExpression _ (UnaryOperation PostDecrement (TypedExpression dtype (LocalVariable name)))) = let + localIndex = findIndex ((==) name) lvars + expr = (TypedExpression dtype (LocalVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + incrOps = exprOps ++ [Opdup, Opsipush 1, Opisub] + in case localIndex of + Just index -> (exprConstants, incrOps ++ [Opistore (fromIntegral index)], lvars) + Nothing -> error("No such local variable found in local variable pool: " ++ name) + +assembleExpression + (constants, ops, lvars) + (TypedExpression _ (UnaryOperation PreIncrement (TypedExpression dtype (FieldVariable name)))) = let + fieldIndex = findFieldIndex constants name + expr = (TypedExpression dtype (FieldVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr + incrOps = exprOps ++ [Opsipush 1, Opiadd, Opdup] + in case fieldIndex of + Just index -> (exprConstants, incrOps ++ [Opputfield (fromIntegral index)], lvars) + Nothing -> error("No such field variable found in field variable pool: " ++ name) + +assembleExpression + (constants, ops, lvars) + (TypedExpression _ (UnaryOperation PostIncrement (TypedExpression dtype (FieldVariable name)))) = let + fieldIndex = findFieldIndex constants name + expr = (TypedExpression dtype (FieldVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr + incrOps = exprOps ++ [Opdup, Opsipush 1, Opiadd] + in case fieldIndex of + Just index -> (exprConstants, incrOps ++ [Opputfield (fromIntegral index)], lvars) + Nothing -> error("No such field variable found in field variable pool: " ++ name) + +assembleExpression + (constants, ops, lvars) + (TypedExpression _ (UnaryOperation PreDecrement (TypedExpression dtype (FieldVariable name)))) = let + fieldIndex = findFieldIndex constants name + expr = (TypedExpression dtype (FieldVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr + incrOps = exprOps ++ [Opsipush 1, Opiadd, Opisub] + in case fieldIndex of + Just index -> (exprConstants, incrOps ++ [Opputfield (fromIntegral index)], lvars) + Nothing -> error("No such field variable found in field variable pool: " ++ name) + +assembleExpression + (constants, ops, lvars) + (TypedExpression _ (UnaryOperation PostDecrement (TypedExpression dtype (FieldVariable name)))) = let + fieldIndex = findFieldIndex constants name + expr = (TypedExpression dtype (FieldVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr + incrOps = exprOps ++ [Opdup, Opsipush 1, Opisub] + in case fieldIndex of + Just index -> (exprConstants, incrOps ++ [Opputfield (fromIntegral index)], lvars) + Nothing -> error("No such field variable found in field variable pool: " ++ name) + + +assembleExpression (constants, ops, lvars) (TypedExpression _ (FieldVariable name)) = let + fieldIndex = findFieldIndex constants name + in case fieldIndex of + Just index -> (constants, ops ++ [Opaload 0, Opgetfield (fromIntegral index)], lvars) + Nothing -> error ("No such field found in constant pool: " ++ name) + +assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name)) = let + localIndex = findIndex ((==) name) lvars + isPrimitive = elem dtype ["char", "boolean", "int"] + in case localIndex of + Just index -> (constants, ops ++ if isPrimitive then [Opiload (fromIntegral index)] else [Opaload (fromIntegral index)], lvars) + Nothing -> error ("No such local variable found in local variable pool: " ++ name) + +assembleExpression _ expr = error ("unimplemented: " ++ show expr) diff --git a/src/ByteCode/Generation/Assembler/Method.hs b/src/ByteCode/Generation/Assembler/Method.hs new file mode 100644 index 0000000..c4826c1 --- /dev/null +++ b/src/ByteCode/Generation/Assembler/Method.hs @@ -0,0 +1,20 @@ +module ByteCode.Generation.Assembler.Method where + +import Ast +import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) +import ByteCode.Generation.Generator +import ByteCode.Generation.Assembler.Statement + +assembleMethod :: Assembler MethodDeclaration +assembleMethod (constants, ops, lvars) (MethodDeclaration _ name _ (TypedStatement _ (Block statements))) + | name == "" = let + (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements + init_ops = [Opaload 0, Opinvokespecial 2] + in + (constants_a, init_ops ++ ops_a ++ [Opreturn], lvars_a) + | otherwise = let + (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements + init_ops = [Opaload 0] + in + (constants_a, init_ops ++ ops_a, lvars_a) +assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Block expected for method body, got: " ++ show stmt) diff --git a/src/ByteCode/Generation/Assembler/Statement.hs b/src/ByteCode/Generation/Assembler/Statement.hs new file mode 100644 index 0000000..ed4dcc9 --- /dev/null +++ b/src/ByteCode/Generation/Assembler/Statement.hs @@ -0,0 +1,51 @@ +module ByteCode.Generation.Assembler.Statement where + +import Ast +import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) +import ByteCode.Generation.Generator +import ByteCode.Generation.Assembler.Expression +import ByteCode.Generation.Assembler.StatementExpression + +assembleStatement :: Assembler Statement +assembleStatement (constants, ops, lvars) (TypedStatement stype (Return expr)) = case expr of + Nothing -> (constants, ops ++ [Opreturn], lvars) + Just expr -> let + (expr_constants, expr_ops, _) = assembleExpression (constants, ops, lvars) expr + in + (expr_constants, expr_ops ++ [returnOperation stype], lvars) +assembleStatement (constants, ops, lvars) (TypedStatement _ (Block statements)) = + foldl assembleStatement (constants, ops, lvars) statements +assembleStatement (constants, ops, lvars) (TypedStatement _ (If expr if_stmt else_stmt)) = let + (constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr + (constants_ifa, ops_ifa, _) = assembleStatement (constants_cmp, [], lvars) if_stmt + (constants_elsea, ops_elsea, _) = case else_stmt of + Nothing -> (constants_ifa, [], lvars) + Just stmt -> assembleStatement (constants_ifa, [], lvars) stmt + -- +6 because we insert 2 gotos, one for if, one for else + if_length = sum (map opcodeEncodingLength ops_ifa) + 6 + -- +3 because we need to account for the goto in the if statement. + else_length = sum (map opcodeEncodingLength ops_elsea) + 3 + in + (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq if_length] ++ ops_ifa ++ [Opgoto else_length] ++ ops_elsea, lvars) +assembleStatement (constants, ops, lvars) (TypedStatement _ (While expr stmt)) = let + (constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr + (constants_stmta, ops_stmta, _) = assembleStatement (constants_cmp, [], lvars) stmt + -- +3 because we insert 2 gotos, one for the comparison, one for the goto back to the comparison + stmt_length = sum (map opcodeEncodingLength ops_stmta) + 6 + entire_length = stmt_length + sum (map opcodeEncodingLength ops_cmp) + in + (constants_stmta, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq stmt_length] ++ ops_stmta ++ [Opgoto (-entire_length)], lvars) +assembleStatement (constants, ops, lvars) (TypedStatement _ (LocalVariableDeclaration (VariableDeclaration dtype name expr))) = let + isPrimitive = elem dtype ["char", "boolean", "int"] + (constants_init, ops_init, _) = case expr of + Just exp -> assembleExpression (constants, ops, lvars) exp + Nothing -> (constants, ops ++ if isPrimitive then [Opsipush 0] else [Opaconst_null], lvars) + localIndex = fromIntegral (length lvars) + storeLocal = if isPrimitive then [Opistore localIndex] else [Opastore localIndex] + in + (constants_init, ops_init ++ storeLocal, lvars ++ [name]) + +assembleStatement (constants, ops, lvars) (TypedStatement _ (StatementExpressionStatement expr)) = + assembleStatementExpression (constants, ops, lvars) expr + +assembleStatement _ stmt = error ("Not yet implemented: " ++ show stmt) diff --git a/src/ByteCode/Generation/Assembler/StatementExpression.hs b/src/ByteCode/Generation/Assembler/StatementExpression.hs new file mode 100644 index 0000000..e9fcb07 --- /dev/null +++ b/src/ByteCode/Generation/Assembler/StatementExpression.hs @@ -0,0 +1,29 @@ +module ByteCode.Generation.Assembler.StatementExpression where + +import Ast +import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) +import ByteCode.Generation.Generator +import Data.List +import ByteCode.Generation.Assembler.Expression +import ByteCode.Generation.Builder.Field + +-- TODO untested +assembleStatementExpression :: Assembler StatementExpression +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression _ (Assignment (TypedExpression dtype (LocalVariable name)) expr)) = let + localIndex = findIndex ((==) name) lvars + (constants_a, ops_a, _) = assembleExpression (constants, ops, lvars) expr + isPrimitive = elem dtype ["char", "boolean", "int"] + in case localIndex of + Just index -> (constants_a, ops_a ++ if isPrimitive then [Opistore (fromIntegral index)] else [Opastore (fromIntegral index)], lvars) + Nothing -> error ("No such local variable found in local variable pool: " ++ name) + +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression _ (Assignment (TypedExpression dtype (FieldVariable name)) expr)) = let + fieldIndex = findFieldIndex constants name + (constants_a, ops_a, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr + in case fieldIndex of + Just index -> (constants_a, ops_a ++ [Opputfield (fromIntegral index)], lvars) + Nothing -> error ("No such field variable found in constant pool: " ++ name) \ No newline at end of file diff --git a/src/ByteCode/Generation/Builder/Class.hs b/src/ByteCode/Generation/Builder/Class.hs new file mode 100644 index 0000000..16fef21 --- /dev/null +++ b/src/ByteCode/Generation/Builder/Class.hs @@ -0,0 +1,44 @@ +module ByteCode.Generation.Builder.Class where + +import ByteCode.Generation.Builder.Field +import ByteCode.Generation.Builder.Method +import ByteCode.Generation.Generator +import Ast +import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) +import ByteCode.Constants + +injectDefaultConstructor :: [MethodDeclaration] -> [MethodDeclaration] +injectDefaultConstructor pre + | any (\(MethodDeclaration _ name _ _) -> name == "") pre = pre + | otherwise = pre ++ [MethodDeclaration "void" "" [] (TypedStatement "void" (Block []))] + + +classBuilder :: ClassFileBuilder Class +classBuilder (Class name methods fields) _ = let + baseConstants = [ + ClassInfo 4, + MethodRefInfo 1 3, + NameAndTypeInfo 5 6, + Utf8Info "java/lang/Object", + Utf8Info "", + Utf8Info "()V", + Utf8Info "Code" + ] + nameConstants = [ClassInfo 9, Utf8Info name] + nakedClassFile = ClassFile { + constantPool = baseConstants ++ nameConstants, + accessFlags = accessPublic, + thisClass = 8, + superClass = 1, + fields = [], + methods = [], + attributes = [] + } + + methodsWithInjectedConstructor = injectDefaultConstructor methods + + classFileWithFields = foldr fieldBuilder nakedClassFile fields + classFileWithMethods = foldr methodBuilder classFileWithFields methodsWithInjectedConstructor + classFileWithAssembledMethods = foldr methodAssembler classFileWithMethods methodsWithInjectedConstructor + in + classFileWithAssembledMethods \ No newline at end of file diff --git a/src/ByteCode/Generation/Builder/Field.hs b/src/ByteCode/Generation/Builder/Field.hs new file mode 100644 index 0000000..ec1f711 --- /dev/null +++ b/src/ByteCode/Generation/Builder/Field.hs @@ -0,0 +1,46 @@ +module ByteCode.Generation.Builder.Field where + +import Ast +import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) +import ByteCode.Generation.Generator +import ByteCode.Constants +import Data.List + +findFieldIndex :: [ConstantInfo] -> String -> Maybe Int +findFieldIndex constants name = let + fieldRefNameInfos = [ + -- we only skip one entry to get the name since the Java constant pool + -- is 1-indexed (why) + (index, constants!!(fromIntegral index + 1)) + | (index, FieldRefInfo _ _) <- (zip [1..] constants) + ] + fieldRefNames = map (\(index, nameInfo) -> case nameInfo of + Utf8Info fieldName -> (index, fieldName) + something_else -> error ("Expected UTF8Info but got" ++ show something_else)) + fieldRefNameInfos + fieldIndex = find (\(index, fieldName) -> fieldName == name) fieldRefNames + in case fieldIndex of + Just (index, _) -> Just index + Nothing -> Nothing + + +fieldBuilder :: ClassFileBuilder VariableDeclaration +fieldBuilder (VariableDeclaration datatype name _) input = let + baseIndex = 1 + length (constantPool input) + constants = [ + FieldRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)), + NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)), + Utf8Info name, + Utf8Info (datatypeDescriptor datatype) + ] + field = MemberInfo { + memberAccessFlags = accessPublic, + memberNameIndex = (fromIntegral (baseIndex + 2)), + memberDescriptorIndex = (fromIntegral (baseIndex + 3)), + memberAttributes = [] + } + in + input { + constantPool = (constantPool input) ++ constants, + fields = (fields input) ++ [field] + } diff --git a/src/ByteCode/Generation/Builder/Method.hs b/src/ByteCode/Generation/Builder/Method.hs new file mode 100644 index 0000000..5475d4d --- /dev/null +++ b/src/ByteCode/Generation/Builder/Method.hs @@ -0,0 +1,80 @@ +module ByteCode.Generation.Builder.Method where + +import Ast +import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) +import ByteCode.Generation.Generator +import ByteCode.Generation.Assembler.Method +import ByteCode.Constants +import Data.List + +methodDescriptor :: MethodDeclaration -> String +methodDescriptor (MethodDeclaration returntype _ parameters _) = let + parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters] + in + "(" + ++ (concat (map methodParameterDescriptor parameter_types)) + ++ ")" + ++ methodParameterDescriptor returntype + +methodParameterDescriptor :: String -> String +methodParameterDescriptor "void" = "V" +methodParameterDescriptor "int" = "I" +methodParameterDescriptor "char" = "C" +methodParameterDescriptor "boolean" = "B" +methodParameterDescriptor x = "L" ++ x ++ ";" + +memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool +memberInfoIsMethod constants info = elem '(' (memberInfoDescriptor constants info) + +findMethodIndex :: ClassFile -> String -> Maybe Int +findMethodIndex classFile name = let + constants = constantPool classFile + in + findIndex (\method -> ((memberInfoIsMethod constants method) && (memberInfoName constants method) == name)) (methods classFile) + + +methodBuilder :: ClassFileBuilder MethodDeclaration +methodBuilder (MethodDeclaration returntype name parameters statement) input = let + baseIndex = 1 + length (constantPool input) + constants = [ + Utf8Info name, + Utf8Info (methodDescriptor (MethodDeclaration returntype name parameters (Block []))) + ] + + method = MemberInfo { + memberAccessFlags = accessPublic, + memberNameIndex = (fromIntegral baseIndex), + memberDescriptorIndex = (fromIntegral (baseIndex + 1)), + memberAttributes = [] + } + in + input { + constantPool = (constantPool input) ++ constants, + methods = (methods input) ++ [method] + } + + + +methodAssembler :: ClassFileBuilder MethodDeclaration +methodAssembler (MethodDeclaration returntype name parameters statement) input = let + methodConstantIndex = findMethodIndex input name + in case methodConstantIndex of + Nothing -> error ("Cannot find method entry in method pool for method: " ++ name) + Just index -> let + declaration = MethodDeclaration returntype name parameters statement + paramNames = "this" : [name | ParameterDeclaration _ name <- parameters] + (pre, method : post) = splitAt index (methods input) + (_, bytecode, _) = assembleMethod (constantPool input, [], paramNames) declaration + assembledMethod = method { + memberAttributes = [ + CodeAttribute { + attributeMaxStack = 420, + attributeMaxLocals = 420, + attributeCode = bytecode + } + ] + } + in + input { + methods = pre ++ (assembledMethod : post) + } diff --git a/src/ByteCode/Generation/Generator.hs b/src/ByteCode/Generation/Generator.hs new file mode 100644 index 0000000..6d42ba0 --- /dev/null +++ b/src/ByteCode/Generation/Generator.hs @@ -0,0 +1,73 @@ +module ByteCode.Generation.Generator( + datatypeDescriptor, + memberInfoName, + memberInfoDescriptor, + returnOperation, + binaryOperation, + comparisonOperation, + ClassFileBuilder, + Assembler +) where + +import ByteCode.Constants +import ByteCode.ClassFile (ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) +import Ast +import Data.Char +import Data.List +import Data.Word + +type ClassFileBuilder a = a -> ClassFile -> ClassFile +type Assembler a = ([ConstantInfo], [Operation], [String]) -> a -> ([ConstantInfo], [Operation], [String]) + +datatypeDescriptor :: String -> String +datatypeDescriptor "void" = "V" +datatypeDescriptor "int" = "I" +datatypeDescriptor "char" = "C" +datatypeDescriptor "boolean" = "B" +datatypeDescriptor x = "L" ++ x + +memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String +memberInfoDescriptor constants MemberInfo { + memberAccessFlags = _, + memberNameIndex = _, + memberDescriptorIndex = descriptorIndex, + memberAttributes = _ } = let + descriptor = constants!!((fromIntegral descriptorIndex) - 1) + in case descriptor of + Utf8Info descriptorText -> descriptorText + _ -> ("Invalid Item at Constant pool index " ++ show descriptorIndex) + + +memberInfoName :: [ConstantInfo] -> MemberInfo -> String +memberInfoName constants MemberInfo { + memberAccessFlags = _, + memberNameIndex = nameIndex, + memberDescriptorIndex = _, + memberAttributes = _ } = let + name = constants!!((fromIntegral nameIndex) - 1) + in case name of + Utf8Info nameText -> nameText + _ -> ("Invalid Item at Constant pool index " ++ show nameIndex) + + +returnOperation :: DataType -> Operation +returnOperation dtype + | elem dtype ["int", "char", "boolean"] = Opireturn + | otherwise = Opareturn + +binaryOperation :: BinaryOperator -> Operation +binaryOperation Addition = Opiadd +binaryOperation Subtraction = Opisub +binaryOperation Multiplication = Opimul +binaryOperation Division = Opidiv +binaryOperation BitwiseAnd = Opiand +binaryOperation BitwiseOr = Opior +binaryOperation BitwiseXor = Opixor + +comparisonOperation :: BinaryOperator -> Word16 -> Operation +comparisonOperation CompareEqual branchLocation = Opif_icmpeq branchLocation +comparisonOperation CompareNotEqual branchLocation = Opif_icmpne branchLocation +comparisonOperation CompareLessThan branchLocation = Opif_icmplt branchLocation +comparisonOperation CompareLessOrEqual branchLocation = Opif_icmple branchLocation +comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation +comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation \ No newline at end of file diff --git a/src/Main.hs b/src/Main.hs index 858d5bb..588efc2 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -4,7 +4,8 @@ import Example import Typecheck import Parser.Lexer (alexScanTokens) import Parser.JavaParser -import ByteCode.ClassFile.Generator +import ByteCode.Generation.Generator +import ByteCode.Generation.Builder.Class import ByteCode.ClassFile import Data.ByteString (pack, writeFile)