diff --git a/project.cabal b/project.cabal index 9b12543..eccbaa4 100644 --- a/project.cabal +++ b/project.cabal @@ -20,8 +20,10 @@ executable compiler Ast, Example, Typecheck, + ByteCode.Util, ByteCode.ByteUtil, ByteCode.ClassFile, + ByteCode.Assembler, ByteCode.Generator, ByteCode.Constants @@ -41,7 +43,9 @@ test-suite tests Ast, TestLexer, TestParser, + ByteCode.Util, ByteCode.ByteUtil, ByteCode.ClassFile, + ByteCode.Assembler, ByteCode.Generator, ByteCode.Constants diff --git a/src/ByteCode/Assembler.hs b/src/ByteCode/Assembler.hs new file mode 100644 index 0000000..aaa8542 --- /dev/null +++ b/src/ByteCode/Assembler.hs @@ -0,0 +1,276 @@ +module ByteCode.Assembler where + +import ByteCode.Constants +import ByteCode.ClassFile (ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) +import ByteCode.Util +import Ast +import Data.Char +import Data.List +import Data.Word + +type Assembler a = ([ConstantInfo], [Operation], [String]) -> a -> ([ConstantInfo], [Operation], [String]) + +assembleExpression :: Assembler Expression +assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b)) + | elem op [Addition, Subtraction, Multiplication, Division, Modulo, BitwiseAnd, BitwiseOr, BitwiseXor, And, Or] = let + (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a + (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b + in + (bConstants, bOps ++ [binaryOperation op], lvars) + | elem op [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let + (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a + (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b + cmp_op = comparisonOperation op 9 + cmp_ops = [cmp_op, Opsipush 0, Opgoto 6, Opsipush 1] + in + (bConstants, bOps ++ cmp_ops, lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression btype (FieldVariable b)))) = let + (fConstants, fieldIndex) = getFieldIndex constants (atype, b, datatypeDescriptor btype) + (aConstants, aOps, _) = assembleExpression (fConstants, ops, lvars) (TypedExpression atype a) + in + (aConstants, aOps ++ [Opgetfield (fromIntegral fieldIndex)], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (CharacterLiteral literal)) = + (constants, ops ++ [Opsipush (fromIntegral (ord literal))], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (BooleanLiteral literal)) = + (constants, ops ++ [Opsipush (if literal then 1 else 0)], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (IntegerLiteral literal)) + | literal <= 32767 && literal >= -32768 = (constants, ops ++ [Opsipush (fromIntegral literal)], lvars) + | otherwise = (constants ++ [IntegerInfo (fromIntegral literal)], ops ++ [Opldc_w (fromIntegral (1 + length constants))], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ NullLiteral) = + (constants, ops ++ [Opaconst_null], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression etype (UnaryOperation Not expr)) = let + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + newConstant = fromIntegral (1 + length exprConstants) + in case etype of + "int" -> (exprConstants ++ [IntegerInfo 0x7FFFFFFF], exprOps ++ [Opldc_w newConstant, Opixor], lvars) + "char" -> (exprConstants, exprOps ++ [Opsipush 0xFFFF, Opixor], lvars) + "boolean" -> (exprConstants, exprOps ++ [Opsipush 0x01, Opixor], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression _ (UnaryOperation Minus expr)) = let + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + in + (exprConstants, exprOps ++ [Opineg], lvars) + +assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name)) + | name == "this" = (constants, ops ++ [Opaload 0], lvars) + | otherwise = let + localIndex = findIndex ((==) name) lvars + isPrimitive = elem dtype ["char", "boolean", "int"] + in case localIndex of + Just index -> (constants, ops ++ if isPrimitive then [Opiload (fromIntegral index)] else [Opaload (fromIntegral index)], lvars) + Nothing -> error ("No such local variable found in local variable pool: " ++ name) + +assembleExpression (constants, ops, lvars) (TypedExpression dtype (StatementExpressionExpression stmtexp)) = + assembleStatementExpression (constants, ops, lvars) stmtexp + +assembleExpression _ expr = error ("unimplemented: " ++ show expr) + +assembleNameChain :: Assembler Expression +assembleNameChain input (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression _ (FieldVariable _)))) = + assembleExpression input (TypedExpression atype a) +assembleNameChain input expr = assembleExpression input expr + + +assembleStatementExpression :: Assembler StatementExpression +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression _ (Assignment (TypedExpression dtype receiver) expr)) = let + target = resolveNameChain (TypedExpression dtype receiver) + in case target of + (TypedExpression dtype (LocalVariable name)) -> let + localIndex = findIndex ((==) name) lvars + (constants_a, ops_a, _) = assembleExpression (constants, ops, lvars) expr + isPrimitive = elem dtype ["char", "boolean", "int"] + in case localIndex of + Just index -> (constants_a, ops_a ++ if isPrimitive then [Opdup, Opistore (fromIntegral index)] else [Opastore (fromIntegral index)], lvars) + Nothing -> error ("No such local variable found in local variable pool: " ++ name) + (TypedExpression dtype (FieldVariable name)) -> let + owner = resolveNameChainOwner (TypedExpression dtype receiver) + in case owner of + (TypedExpression otype _) -> let + (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) + (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) + (constants_a, ops_a, _) = assembleExpression (constants_r, ops_r, lvars) expr + in + (constants_a, ops_a ++ [Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) + something_else -> error ("expected TypedExpression, but got: " ++ show something_else) + +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression _ (PreIncrement (TypedExpression dtype receiver))) = let + target = resolveNameChain (TypedExpression dtype receiver) + in case target of + (TypedExpression dtype (LocalVariable name)) -> let + localIndex = findIndex ((==) name) lvars + expr = (TypedExpression dtype (LocalVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + in case localIndex of + Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opiadd, Opdup, Opistore (fromIntegral index)], lvars) + Nothing -> error("No such local variable found in local variable pool: " ++ name) + (TypedExpression dtype (FieldVariable name)) -> let + owner = resolveNameChainOwner (TypedExpression dtype receiver) + in case owner of + (TypedExpression otype _) -> let + (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) + (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) + in + (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opiadd, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) + something_else -> error ("expected TypedExpression, but got: " ++ show something_else) + +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression _ (PreDecrement (TypedExpression dtype receiver))) = let + target = resolveNameChain (TypedExpression dtype receiver) + in case target of + (TypedExpression dtype (LocalVariable name)) -> let + localIndex = findIndex ((==) name) lvars + expr = (TypedExpression dtype (LocalVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + in case localIndex of + Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opisub, Opdup, Opistore (fromIntegral index)], lvars) + Nothing -> error("No such local variable found in local variable pool: " ++ name) + (TypedExpression dtype (FieldVariable name)) -> let + owner = resolveNameChainOwner (TypedExpression dtype receiver) + in case owner of + (TypedExpression otype _) -> let + (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) + (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) + in + (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opisub, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) + something_else -> error ("expected TypedExpression, but got: " ++ show something_else) + +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression _ (PostIncrement (TypedExpression dtype receiver))) = let + target = resolveNameChain (TypedExpression dtype receiver) + in case target of + (TypedExpression dtype (LocalVariable name)) -> let + localIndex = findIndex ((==) name) lvars + expr = (TypedExpression dtype (LocalVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + in case localIndex of + Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opiadd, Opistore (fromIntegral index)], lvars) + Nothing -> error("No such local variable found in local variable pool: " ++ name) + (TypedExpression dtype (FieldVariable name)) -> let + owner = resolveNameChainOwner (TypedExpression dtype receiver) + in case owner of + (TypedExpression otype _) -> let + (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) + (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) + in + (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opiadd, Opputfield (fromIntegral fieldIndex)], lvars) + something_else -> error ("expected TypedExpression, but got: " ++ show something_else) + +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression _ (PostDecrement (TypedExpression dtype receiver))) = let + target = resolveNameChain (TypedExpression dtype receiver) + in case target of + (TypedExpression dtype (LocalVariable name)) -> let + localIndex = findIndex ((==) name) lvars + expr = (TypedExpression dtype (LocalVariable name)) + (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr + in case localIndex of + Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opisub, Opistore (fromIntegral index)], lvars) + Nothing -> error("No such local variable found in local variable pool: " ++ name) + (TypedExpression dtype (FieldVariable name)) -> let + owner = resolveNameChainOwner (TypedExpression dtype receiver) + in case owner of + (TypedExpression otype _) -> let + (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) + (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) + in + (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opisub, Opputfield (fromIntegral fieldIndex)], lvars) + something_else -> error ("expected TypedExpression, but got: " ++ show something_else) + +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression rtype (MethodCall (TypedExpression otype receiver) name params)) = let + (constants_r, ops_r, lvars_r) = assembleExpression (constants, ops, lvars) (TypedExpression otype receiver) + (constants_p, ops_p, lvars_p) = foldl assembleExpression (constants_r, ops_r, lvars_r) params + (constants_m, methodIndex) = getMethodIndex constants_p (otype, name, methodDescriptorFromParamlist params rtype) + in + (constants_m, ops_p ++ [Opinvokevirtual (fromIntegral methodIndex)], lvars_p) + +assembleStatementExpression + (constants, ops, lvars) + (TypedStatementExpression rtype (ConstructorCall name params)) = let + (constants_c, classIndex) = getClassIndex constants name + (constants_p, ops_p, lvars_p) = foldl assembleExpression (constants_c, ops ++ [Opnew (fromIntegral classIndex), Opdup], lvars) params + (constants_m, methodIndex) = getMethodIndex constants_p (name, "", methodDescriptorFromParamlist params "void") + in + (constants_m, ops_p ++ [Opinvokespecial (fromIntegral methodIndex)], lvars_p) + + +assembleStatement :: Assembler Statement +assembleStatement (constants, ops, lvars) (TypedStatement stype (Return expr)) = case expr of + Nothing -> (constants, ops ++ [Opreturn], lvars) + Just expr -> let + (expr_constants, expr_ops, _) = assembleExpression (constants, ops, lvars) expr + in + (expr_constants, expr_ops ++ [returnOperation stype], lvars) + +assembleStatement (constants, ops, lvars) (TypedStatement _ (Block statements)) = + foldl assembleStatement (constants, ops, lvars) statements + +assembleStatement (constants, ops, lvars) (TypedStatement dtype (If expr if_stmt else_stmt)) = let + (constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr + (constants_ifa, ops_ifa, _) = assembleStatement (constants_cmp, [], lvars) if_stmt + (constants_elsea, ops_elsea, _) = case else_stmt of + Nothing -> (constants_ifa, [], lvars) + Just stmt -> assembleStatement (constants_ifa, [], lvars) stmt + -- +6 because we insert 2 gotos, one for if, one for else + if_length = sum (map opcodeEncodingLength ops_ifa) + -- +3 because we need to account for the goto in the if statement. + else_length = sum (map opcodeEncodingLength ops_elsea) + in case dtype of + "void" -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 6)] ++ ops_ifa ++ [Opgoto (else_length + 3)] ++ ops_elsea, lvars) + otherwise -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 3)] ++ ops_ifa ++ ops_elsea, lvars) + +assembleStatement (constants, ops, lvars) (TypedStatement _ (While expr stmt)) = let + (constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr + (constants_stmta, ops_stmta, _) = assembleStatement (constants_cmp, [], lvars) stmt + -- +3 because we insert 2 gotos, one for the comparison, one for the goto back to the comparison + stmt_length = sum (map opcodeEncodingLength ops_stmta) + 6 + entire_length = stmt_length + sum (map opcodeEncodingLength ops_cmp) + in + (constants_stmta, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq stmt_length] ++ ops_stmta ++ [Opgoto (-entire_length)], lvars) + +assembleStatement (constants, ops, lvars) (TypedStatement _ (LocalVariableDeclaration (VariableDeclaration dtype name expr))) = let + isPrimitive = elem dtype ["char", "boolean", "int"] + (constants_init, ops_init, _) = case expr of + Just exp -> assembleExpression (constants, ops, lvars) exp + Nothing -> (constants, ops ++ if isPrimitive then [Opsipush 0] else [Opaconst_null], lvars) + localIndex = fromIntegral (length lvars) + storeLocal = if isPrimitive then [Opistore localIndex] else [Opastore localIndex] + in + (constants_init, ops_init ++ storeLocal, lvars ++ [name]) + +assembleStatement (constants, ops, lvars) (TypedStatement _ (StatementExpressionStatement expr)) = let + (constants_e, ops_e, lvars_e) = assembleStatementExpression (constants, ops, lvars) expr + in + (constants_e, ops_e ++ [Oppop], lvars_e) + +assembleStatement _ stmt = error ("Not yet implemented: " ++ show stmt) + + +assembleMethod :: Assembler MethodDeclaration +assembleMethod (constants, ops, lvars) (MethodDeclaration returntype name _ (TypedStatement _ (Block statements))) + | name == "" = let + (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements + init_ops = [Opaload 0, Opinvokespecial 2] + in + (constants_a, init_ops ++ ops_a ++ [Opreturn], lvars_a) + | otherwise = case returntype of + "void" -> let + (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements + in + (constants_a, ops_a ++ [Opreturn], lvars_a) + otherwise -> foldl assembleStatement (constants, ops, lvars) statements +assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Typed block expected for method body, got: " ++ show stmt) diff --git a/src/ByteCode/ByteUtil.hs b/src/ByteCode/ByteUtil.hs index daa419a..a602ac3 100644 --- a/src/ByteCode/ByteUtil.hs +++ b/src/ByteCode/ByteUtil.hs @@ -1,7 +1,6 @@ -module ByteCode.ByteUtil(unpackWord16, unpackWord32) where +module ByteCode.ByteUtil where import Data.Word ( Word8, Word16, Word32 ) -import Data.Int import Data.Bits unpackWord16 :: Word16 -> [Word8] diff --git a/src/ByteCode/ClassFile.hs b/src/ByteCode/ClassFile.hs index 660e94d..1fc15c9 100644 --- a/src/ByteCode/ClassFile.hs +++ b/src/ByteCode/ClassFile.hs @@ -14,8 +14,8 @@ import Data.Word import Data.Int import Data.ByteString (unpack) import Data.ByteString.UTF8 (fromString) -import ByteCode.ByteUtil import ByteCode.Constants +import ByteCode.ByteUtil data ConstantInfo = ClassInfo Word16 | FieldRefInfo Word16 Word16 diff --git a/src/ByteCode/Generator.hs b/src/ByteCode/Generator.hs index e448fb3..d6a8980 100644 --- a/src/ByteCode/Generator.hs +++ b/src/ByteCode/Generator.hs @@ -1,224 +1,15 @@ -module ByteCode.Generator( - datatypeDescriptor, - memberInfoName, - memberInfoDescriptor, - methodDescriptor, - returnOperation, - binaryOperation, - comparisonOperation, - findMethodIndex, - ClassFileBuilder, - Assembler, - classBuilder -) where +module ByteCode.Generator where import ByteCode.Constants import ByteCode.ClassFile (ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength) +import ByteCode.Assembler +import ByteCode.Util import Ast import Data.Char import Data.List import Data.Word -import Data.Maybe (mapMaybe) type ClassFileBuilder a = a -> ClassFile -> ClassFile -type Assembler a = ([ConstantInfo], [Operation], [String]) -> a -> ([ConstantInfo], [Operation], [String]) - -methodDescriptor :: MethodDeclaration -> String -methodDescriptor (MethodDeclaration returntype _ parameters _) = let - parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters] - in - "(" - ++ (concat (map datatypeDescriptor parameter_types)) - ++ ")" - ++ datatypeDescriptor returntype - -methodDescriptorFromParamlist :: [Expression] -> String -> String -methodDescriptorFromParamlist parameters returntype = let - parameter_types = [datatype | TypedExpression datatype _ <- parameters] - in - "(" - ++ (concat (map datatypeDescriptor parameter_types)) - ++ ")" - ++ datatypeDescriptor returntype - -memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool -memberInfoIsMethod constants info = elem '(' (memberInfoDescriptor constants info) - - -datatypeDescriptor :: String -> String -datatypeDescriptor "void" = "V" -datatypeDescriptor "int" = "I" -datatypeDescriptor "char" = "C" -datatypeDescriptor "boolean" = "B" -datatypeDescriptor x = "L" ++ x ++ ";" - - -memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String -memberInfoDescriptor constants MemberInfo { - memberAccessFlags = _, - memberNameIndex = _, - memberDescriptorIndex = descriptorIndex, - memberAttributes = _ } = let - descriptor = constants!!((fromIntegral descriptorIndex) - 1) - in case descriptor of - Utf8Info descriptorText -> descriptorText - _ -> ("Invalid Item at Constant pool index " ++ show descriptorIndex) - - -memberInfoName :: [ConstantInfo] -> MemberInfo -> String -memberInfoName constants MemberInfo { - memberAccessFlags = _, - memberNameIndex = nameIndex, - memberDescriptorIndex = _, - memberAttributes = _ } = let - name = constants!!((fromIntegral nameIndex) - 1) - in case name of - Utf8Info nameText -> nameText - _ -> ("Invalid Item at Constant pool index " ++ show nameIndex) - - -returnOperation :: DataType -> Operation -returnOperation dtype - | elem dtype ["int", "char", "boolean"] = Opireturn - | otherwise = Opareturn - -binaryOperation :: BinaryOperator -> Operation -binaryOperation Addition = Opiadd -binaryOperation Subtraction = Opisub -binaryOperation Multiplication = Opimul -binaryOperation Division = Opidiv -binaryOperation Modulo = Opirem -binaryOperation BitwiseAnd = Opiand -binaryOperation BitwiseOr = Opior -binaryOperation BitwiseXor = Opixor -binaryOperation And = Opiand -binaryOperation Or = Opior - -comparisonOperation :: BinaryOperator -> Word16 -> Operation -comparisonOperation CompareEqual branchLocation = Opif_icmpeq branchLocation -comparisonOperation CompareNotEqual branchLocation = Opif_icmpne branchLocation -comparisonOperation CompareLessThan branchLocation = Opif_icmplt branchLocation -comparisonOperation CompareLessOrEqual branchLocation = Opif_icmple branchLocation -comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation -comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation - -findFieldIndex :: [ConstantInfo] -> String -> Maybe Int -findFieldIndex constants name = let - fieldRefNameInfos = [ - -- we only skip one entry to get the name since the Java constant pool - -- is 1-indexed (why) - (index, constants!!(fromIntegral index + 1)) - | (index, FieldRefInfo classIndex _) <- (zip [1..] constants) - ] - fieldRefNames = map (\(index, nameInfo) -> case nameInfo of - Utf8Info fieldName -> (index, fieldName) - something_else -> error ("Expected UTF8Info but got" ++ show something_else)) - fieldRefNameInfos - fieldIndex = find (\(index, fieldName) -> fieldName == name) fieldRefNames - in case fieldIndex of - Just (index, _) -> Just index - Nothing -> Nothing - -findMethodRefIndex :: [ConstantInfo] -> String -> Maybe Int -findMethodRefIndex constants name = let - methodRefNameInfos = [ - -- we only skip one entry to get the name since the Java constant pool - -- is 1-indexed (why) - (index, constants!!(fromIntegral index + 1)) - | (index, MethodRefInfo _ _) <- (zip [1..] constants) - ] - methodRefNames = map (\(index, nameInfo) -> case nameInfo of - Utf8Info methodName -> (index, methodName) - something_else -> error ("Expected UTF8Info but got " ++ show something_else)) - methodRefNameInfos - methodIndex = find (\(index, methodName) -> methodName == name) methodRefNames - in case methodIndex of - Just (index, _) -> Just index - Nothing -> Nothing - - -findMethodIndex :: ClassFile -> String -> Maybe Int -findMethodIndex classFile name = let - constants = constantPool classFile - in - findIndex (\method -> ((memberInfoIsMethod constants method) && (memberInfoName constants method) == name)) (methods classFile) - -findClassIndex :: [ConstantInfo] -> String -> Maybe Int -findClassIndex constants name = let - classNameIndices = [(index, constants!!(fromIntegral nameIndex - 1)) | (index, ClassInfo nameIndex) <- (zip[1..] constants)] - classNames = map (\(index, nameInfo) -> case nameInfo of - Utf8Info className -> (index, className) - something_else -> error("Expected UTF8Info but got " ++ show something_else)) - classNameIndices - desiredClassIndex = find (\(index, className) -> className == name) classNames - in case desiredClassIndex of - Just (index, _) -> Just index - Nothing -> Nothing - -getKnownMembers :: [ConstantInfo] -> [(Int, (String, String, String))] -getKnownMembers constants = let - fieldsClassAndNT = [ - (index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1)) - | (index, FieldRefInfo classIndex nameTypeIndex) <- (zip [1..] constants) - ] ++ [ - (index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1)) - | (index, MethodRefInfo classIndex nameTypeIndex) <- (zip [1..] constants) - ] - - fieldsClassNameType = map (\(index, nameInfo, nameTypeInfo) -> case (nameInfo, nameTypeInfo) of - (ClassInfo nameIndex, NameAndTypeInfo fnameIndex ftypeIndex) -> (index, (constants!!(fromIntegral nameIndex - 1), constants!!(fromIntegral fnameIndex - 1), constants!!(fromIntegral ftypeIndex - 1))) - something_else -> error ("Expected Class and NameType info, but got: " ++ show nameInfo ++ " and " ++ show nameTypeInfo)) - fieldsClassAndNT - - fieldsResolved = map (\(index, (nameInfo, fnameInfo, ftypeInfo)) -> case (nameInfo, fnameInfo, ftypeInfo) of - (Utf8Info cname, Utf8Info fname, Utf8Info ftype) -> (index, (cname, fname, ftype)) - something_else -> error("Expected UTF8Infos but got " ++ show something_else)) - fieldsClassNameType - in - fieldsResolved - --- same as findClassIndex, but inserts a new entry into constant pool if not existing -getClassIndex :: [ConstantInfo] -> String -> ([ConstantInfo], Int) -getClassIndex constants name = case findClassIndex constants name of - Just index -> (constants, index) - Nothing -> (constants ++ [ClassInfo (fromIntegral (length constants)), Utf8Info name], fromIntegral (length constants)) - --- get the index for a field within a class, creating it if it does not exist. -getFieldIndex :: [ConstantInfo] -> (String, String, String) -> ([ConstantInfo], Int) -getFieldIndex constants (cname, fname, ftype) = case findMemberIndex constants (cname, fname, ftype) of - Just index -> (constants, index) - Nothing -> let - (constantsWithClass, classIndex) = getClassIndex constants cname - baseIndex = 1 + length constantsWithClass - in - (constantsWithClass ++ [ - FieldRefInfo (fromIntegral classIndex) (fromIntegral (baseIndex + 1)), - NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)), - Utf8Info fname, - Utf8Info (datatypeDescriptor ftype) - ], baseIndex) - -getMethodIndex :: [ConstantInfo] -> (String, String, String) -> ([ConstantInfo], Int) -getMethodIndex constants (cname, mname, mtype) = case findMemberIndex constants (cname, mname, mtype) of - Just index -> (constants, index) - Nothing -> let - (constantsWithClass, classIndex) = getClassIndex constants cname - baseIndex = 1 + length constantsWithClass - in - (constantsWithClass ++ [ - MethodRefInfo (fromIntegral classIndex) (fromIntegral (baseIndex + 1)), - NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)), - Utf8Info mname, - Utf8Info mtype - ], baseIndex) - -findMemberIndex :: [ConstantInfo] -> (String, String, String) -> Maybe Int -findMemberIndex constants (cname, fname, ftype) = let - allMembers = getKnownMembers constants - desiredMember = find (\(index, (c, f, ft)) -> (c, f, ft) == (cname, fname, ftype)) allMembers - in - fmap (\(index, _) -> index) desiredMember - fieldBuilder :: ClassFileBuilder VariableDeclaration fieldBuilder (VariableDeclaration datatype name _) input = let @@ -241,18 +32,7 @@ fieldBuilder (VariableDeclaration datatype name _) input = let fields = (fields input) ++ [field] } --- walks the name resolution chain. returns the innermost Just LocalVariable/FieldVariable or Nothing. -resolveNameChain :: Expression -> Expression -resolveNameChain (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b -resolveNameChain (TypedExpression dtype (LocalVariable name)) = (TypedExpression dtype (LocalVariable name)) -resolveNameChain (TypedExpression dtype (FieldVariable name)) = (TypedExpression dtype (FieldVariable name)) -resolveNameChain invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression)) --- walks the name resolution chain. returns the second-to-last item of the namechain. -resolveNameChainOwner :: Expression -> Expression -resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a (TypedExpression dtype (FieldVariable name)))) = a -resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b -resolveNameChainOwner invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression)) methodBuilder :: ClassFileBuilder MethodDeclaration methodBuilder (MethodDeclaration returntype name parameters statement) input = let @@ -303,35 +83,6 @@ methodAssembler (MethodDeclaration returntype name parameters statement) input = input { methods = pre ++ (assembledMethod : post) } - - - -injectDefaultConstructor :: [MethodDeclaration] -> [MethodDeclaration] -injectDefaultConstructor pre - | any (\(MethodDeclaration _ name _ _) -> name == "") pre = pre - | otherwise = pre ++ [MethodDeclaration "void" "" [] (TypedStatement "void" (Block []))] - -injectFieldInitializers :: String -> [VariableDeclaration] -> [MethodDeclaration] -> [MethodDeclaration] -injectFieldInitializers classname vars pre = let - initializers = mapMaybe (\(variable) -> case variable of - VariableDeclaration dtype name (Just initializer) -> Just ( - TypedStatement dtype ( - StatementExpressionStatement ( - TypedStatementExpression dtype ( - Assignment - (TypedExpression dtype (BinaryOperation NameResolution (TypedExpression classname (LocalVariable "this")) (TypedExpression dtype (FieldVariable name)))) - initializer - ) - ) - ) - ) - otherwise -> Nothing - ) vars - in - map (\(method) -> case method of - MethodDeclaration "void" "" params (TypedStatement "void" (Block statements)) -> MethodDeclaration "void" "" params (TypedStatement "void" (Block (initializers ++ statements))) - otherwise -> method - ) pre classBuilder :: ClassFileBuilder Class @@ -363,271 +114,4 @@ classBuilder (Class name methods fields) _ = let classFileWithMethods = foldr methodBuilder classFileWithFields methodsWithInjectedInitializers classFileWithAssembledMethods = foldr methodAssembler classFileWithMethods methodsWithInjectedInitializers in - classFileWithAssembledMethods - - -assembleExpression :: Assembler Expression -assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b)) - | elem op [Addition, Subtraction, Multiplication, Division, Modulo, BitwiseAnd, BitwiseOr, BitwiseXor, And, Or] = let - (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a - (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b - in - (bConstants, bOps ++ [binaryOperation op], lvars) - | elem op [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let - (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a - (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b - cmp_op = comparisonOperation op 9 - cmp_ops = [cmp_op, Opsipush 0, Opgoto 6, Opsipush 1] - in - (bConstants, bOps ++ cmp_ops, lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression btype (FieldVariable b)))) = let - (fConstants, fieldIndex) = getFieldIndex constants (atype, b, datatypeDescriptor btype) - (aConstants, aOps, _) = assembleExpression (fConstants, ops, lvars) (TypedExpression atype a) - in - (aConstants, aOps ++ [Opgetfield (fromIntegral fieldIndex)], lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression _ (CharacterLiteral literal)) = - (constants, ops ++ [Opsipush (fromIntegral (ord literal))], lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression _ (BooleanLiteral literal)) = - (constants, ops ++ [Opsipush (if literal then 1 else 0)], lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression _ (IntegerLiteral literal)) - | literal <= 32767 && literal >= -32768 = (constants, ops ++ [Opsipush (fromIntegral literal)], lvars) - | otherwise = (constants ++ [IntegerInfo (fromIntegral literal)], ops ++ [Opldc_w (fromIntegral (1 + length constants))], lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression _ NullLiteral) = - (constants, ops ++ [Opaconst_null], lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression etype (UnaryOperation Not expr)) = let - (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr - newConstant = fromIntegral (1 + length exprConstants) - in case etype of - "int" -> (exprConstants ++ [IntegerInfo 0x7FFFFFFF], exprOps ++ [Opldc_w newConstant, Opixor], lvars) - "char" -> (exprConstants, exprOps ++ [Opsipush 0xFFFF, Opixor], lvars) - "boolean" -> (exprConstants, exprOps ++ [Opsipush 0x01, Opixor], lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression _ (UnaryOperation Minus expr)) = let - (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr - in - (exprConstants, exprOps ++ [Opineg], lvars) - -assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name)) - | name == "this" = (constants, ops ++ [Opaload 0], lvars) - | otherwise = let - localIndex = findIndex ((==) name) lvars - isPrimitive = elem dtype ["char", "boolean", "int"] - in case localIndex of - Just index -> (constants, ops ++ if isPrimitive then [Opiload (fromIntegral index)] else [Opaload (fromIntegral index)], lvars) - Nothing -> error ("No such local variable found in local variable pool: " ++ name) - -assembleExpression (constants, ops, lvars) (TypedExpression dtype (StatementExpressionExpression stmtexp)) = - assembleStatementExpression (constants, ops, lvars) stmtexp - -assembleExpression _ expr = error ("unimplemented: " ++ show expr) - -assembleNameChain :: Assembler Expression -assembleNameChain input (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression _ (FieldVariable _)))) = - assembleExpression input (TypedExpression atype a) -assembleNameChain input expr = assembleExpression input expr - - --- TODO untested -assembleStatementExpression :: Assembler StatementExpression -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (Assignment (TypedExpression dtype receiver) expr)) = let - target = resolveNameChain (TypedExpression dtype receiver) - in case target of - (TypedExpression dtype (LocalVariable name)) -> let - localIndex = findIndex ((==) name) lvars - (constants_a, ops_a, _) = assembleExpression (constants, ops, lvars) expr - isPrimitive = elem dtype ["char", "boolean", "int"] - in case localIndex of - Just index -> (constants_a, ops_a ++ if isPrimitive then [Opdup, Opistore (fromIntegral index)] else [Opastore (fromIntegral index)], lvars) - Nothing -> error ("No such local variable found in local variable pool: " ++ name) - (TypedExpression dtype (FieldVariable name)) -> let - owner = resolveNameChainOwner (TypedExpression dtype receiver) - in case owner of - (TypedExpression otype _) -> let - (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) - (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) - (constants_a, ops_a, _) = assembleExpression (constants_r, ops_r, lvars) expr - in - (constants_a, ops_a ++ [Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) - something_else -> error ("expected TypedExpression, but got: " ++ show something_else) - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (PreIncrement (TypedExpression dtype receiver))) = let - target = resolveNameChain (TypedExpression dtype receiver) - in case target of - (TypedExpression dtype (LocalVariable name)) -> let - localIndex = findIndex ((==) name) lvars - expr = (TypedExpression dtype (LocalVariable name)) - (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr - in case localIndex of - Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opiadd, Opdup, Opistore (fromIntegral index)], lvars) - Nothing -> error("No such local variable found in local variable pool: " ++ name) - (TypedExpression dtype (FieldVariable name)) -> let - owner = resolveNameChainOwner (TypedExpression dtype receiver) - in case owner of - (TypedExpression otype _) -> let - (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) - (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) - in - (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opiadd, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) - something_else -> error ("expected TypedExpression, but got: " ++ show something_else) - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (PreDecrement (TypedExpression dtype receiver))) = let - target = resolveNameChain (TypedExpression dtype receiver) - in case target of - (TypedExpression dtype (LocalVariable name)) -> let - localIndex = findIndex ((==) name) lvars - expr = (TypedExpression dtype (LocalVariable name)) - (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr - in case localIndex of - Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opisub, Opdup, Opistore (fromIntegral index)], lvars) - Nothing -> error("No such local variable found in local variable pool: " ++ name) - (TypedExpression dtype (FieldVariable name)) -> let - owner = resolveNameChainOwner (TypedExpression dtype receiver) - in case owner of - (TypedExpression otype _) -> let - (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) - (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) - in - (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opisub, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) - something_else -> error ("expected TypedExpression, but got: " ++ show something_else) - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (PostIncrement (TypedExpression dtype receiver))) = let - target = resolveNameChain (TypedExpression dtype receiver) - in case target of - (TypedExpression dtype (LocalVariable name)) -> let - localIndex = findIndex ((==) name) lvars - expr = (TypedExpression dtype (LocalVariable name)) - (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr - in case localIndex of - Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opiadd, Opistore (fromIntegral index)], lvars) - Nothing -> error("No such local variable found in local variable pool: " ++ name) - (TypedExpression dtype (FieldVariable name)) -> let - owner = resolveNameChainOwner (TypedExpression dtype receiver) - in case owner of - (TypedExpression otype _) -> let - (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) - (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) - in - (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opiadd, Opputfield (fromIntegral fieldIndex)], lvars) - something_else -> error ("expected TypedExpression, but got: " ++ show something_else) - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression _ (PostDecrement (TypedExpression dtype receiver))) = let - target = resolveNameChain (TypedExpression dtype receiver) - in case target of - (TypedExpression dtype (LocalVariable name)) -> let - localIndex = findIndex ((==) name) lvars - expr = (TypedExpression dtype (LocalVariable name)) - (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr - in case localIndex of - Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opisub, Opistore (fromIntegral index)], lvars) - Nothing -> error("No such local variable found in local variable pool: " ++ name) - (TypedExpression dtype (FieldVariable name)) -> let - owner = resolveNameChainOwner (TypedExpression dtype receiver) - in case owner of - (TypedExpression otype _) -> let - (constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype) - (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) - in - (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opisub, Opputfield (fromIntegral fieldIndex)], lvars) - something_else -> error ("expected TypedExpression, but got: " ++ show something_else) - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression rtype (MethodCall (TypedExpression otype receiver) name params)) = let - (constants_r, ops_r, lvars_r) = assembleExpression (constants, ops, lvars) (TypedExpression otype receiver) - (constants_p, ops_p, lvars_p) = foldl assembleExpression (constants_r, ops_r, lvars_r) params - (constants_m, methodIndex) = getMethodIndex constants_p (otype, name, methodDescriptorFromParamlist params rtype) - in - (constants_m, ops_p ++ [Opinvokevirtual (fromIntegral methodIndex)], lvars_p) - -assembleStatementExpression - (constants, ops, lvars) - (TypedStatementExpression rtype (ConstructorCall name params)) = let - (constants_c, classIndex) = getClassIndex constants name - (constants_p, ops_p, lvars_p) = foldl assembleExpression (constants_c, ops ++ [Opnew (fromIntegral classIndex), Opdup], lvars) params - (constants_m, methodIndex) = getMethodIndex constants_p (name, "", methodDescriptorFromParamlist params "void") - in - (constants_m, ops_p ++ [Opinvokespecial (fromIntegral methodIndex)], lvars_p) - - -assembleStatement :: Assembler Statement -assembleStatement (constants, ops, lvars) (TypedStatement stype (Return expr)) = case expr of - Nothing -> (constants, ops ++ [Opreturn], lvars) - Just expr -> let - (expr_constants, expr_ops, _) = assembleExpression (constants, ops, lvars) expr - in - (expr_constants, expr_ops ++ [returnOperation stype], lvars) - -assembleStatement (constants, ops, lvars) (TypedStatement _ (Block statements)) = - foldl assembleStatement (constants, ops, lvars) statements - -assembleStatement (constants, ops, lvars) (TypedStatement dtype (If expr if_stmt else_stmt)) = let - (constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr - (constants_ifa, ops_ifa, _) = assembleStatement (constants_cmp, [], lvars) if_stmt - (constants_elsea, ops_elsea, _) = case else_stmt of - Nothing -> (constants_ifa, [], lvars) - Just stmt -> assembleStatement (constants_ifa, [], lvars) stmt - -- +6 because we insert 2 gotos, one for if, one for else - if_length = sum (map opcodeEncodingLength ops_ifa) - -- +3 because we need to account for the goto in the if statement. - else_length = sum (map opcodeEncodingLength ops_elsea) - in case dtype of - "void" -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 6)] ++ ops_ifa ++ [Opgoto (else_length + 3)] ++ ops_elsea, lvars) - otherwise -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 3)] ++ ops_ifa ++ ops_elsea, lvars) - -assembleStatement (constants, ops, lvars) (TypedStatement _ (While expr stmt)) = let - (constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr - (constants_stmta, ops_stmta, _) = assembleStatement (constants_cmp, [], lvars) stmt - -- +3 because we insert 2 gotos, one for the comparison, one for the goto back to the comparison - stmt_length = sum (map opcodeEncodingLength ops_stmta) + 6 - entire_length = stmt_length + sum (map opcodeEncodingLength ops_cmp) - in - (constants_stmta, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq stmt_length] ++ ops_stmta ++ [Opgoto (-entire_length)], lvars) - -assembleStatement (constants, ops, lvars) (TypedStatement _ (LocalVariableDeclaration (VariableDeclaration dtype name expr))) = let - isPrimitive = elem dtype ["char", "boolean", "int"] - (constants_init, ops_init, _) = case expr of - Just exp -> assembleExpression (constants, ops, lvars) exp - Nothing -> (constants, ops ++ if isPrimitive then [Opsipush 0] else [Opaconst_null], lvars) - localIndex = fromIntegral (length lvars) - storeLocal = if isPrimitive then [Opistore localIndex] else [Opastore localIndex] - in - (constants_init, ops_init ++ storeLocal, lvars ++ [name]) - -assembleStatement (constants, ops, lvars) (TypedStatement _ (StatementExpressionStatement expr)) = let - (constants_e, ops_e, lvars_e) = assembleStatementExpression (constants, ops, lvars) expr - in - (constants_e, ops_e ++ [Oppop], lvars_e) - -assembleStatement _ stmt = error ("Not yet implemented: " ++ show stmt) - - -assembleMethod :: Assembler MethodDeclaration -assembleMethod (constants, ops, lvars) (MethodDeclaration returntype name _ (TypedStatement _ (Block statements))) - | name == "" = let - (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements - init_ops = [Opaload 0, Opinvokespecial 2] - in - (constants_a, init_ops ++ ops_a ++ [Opreturn], lvars_a) - | otherwise = case returntype of - "void" -> let - (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements - in - (constants_a, ops_a ++ [Opreturn], lvars_a) - otherwise -> foldl assembleStatement (constants, ops, lvars) statements -assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Typed block expected for method body, got: " ++ show stmt) + classFileWithAssembledMethods \ No newline at end of file diff --git a/src/ByteCode/Util.hs b/src/ByteCode/Util.hs new file mode 100644 index 0000000..45e8816 --- /dev/null +++ b/src/ByteCode/Util.hs @@ -0,0 +1,245 @@ +module ByteCode.Util where + +import Data.Int +import Ast +import ByteCode.ClassFile +import Data.List +import Data.Maybe (mapMaybe) +import Data.Word (Word8, Word16, Word32) + +-- walks the name resolution chain. returns the innermost Just LocalVariable/FieldVariable or Nothing. +resolveNameChain :: Expression -> Expression +resolveNameChain (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b +resolveNameChain (TypedExpression dtype (LocalVariable name)) = (TypedExpression dtype (LocalVariable name)) +resolveNameChain (TypedExpression dtype (FieldVariable name)) = (TypedExpression dtype (FieldVariable name)) +resolveNameChain invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression)) + +-- walks the name resolution chain. returns the second-to-last item of the namechain. +resolveNameChainOwner :: Expression -> Expression +resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a (TypedExpression dtype (FieldVariable name)))) = a +resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b +resolveNameChainOwner invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression)) + + +methodDescriptor :: MethodDeclaration -> String +methodDescriptor (MethodDeclaration returntype _ parameters _) = let + parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters] + in + "(" + ++ (concat (map datatypeDescriptor parameter_types)) + ++ ")" + ++ datatypeDescriptor returntype + +methodDescriptorFromParamlist :: [Expression] -> String -> String +methodDescriptorFromParamlist parameters returntype = let + parameter_types = [datatype | TypedExpression datatype _ <- parameters] + in + "(" + ++ (concat (map datatypeDescriptor parameter_types)) + ++ ")" + ++ datatypeDescriptor returntype + +memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool +memberInfoIsMethod constants info = elem '(' (memberInfoDescriptor constants info) + + +datatypeDescriptor :: String -> String +datatypeDescriptor "void" = "V" +datatypeDescriptor "int" = "I" +datatypeDescriptor "char" = "C" +datatypeDescriptor "boolean" = "B" +datatypeDescriptor x = "L" ++ x ++ ";" + + +memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String +memberInfoDescriptor constants MemberInfo { + memberAccessFlags = _, + memberNameIndex = _, + memberDescriptorIndex = descriptorIndex, + memberAttributes = _ } = let + descriptor = constants!!((fromIntegral descriptorIndex) - 1) + in case descriptor of + Utf8Info descriptorText -> descriptorText + _ -> ("Invalid Item at Constant pool index " ++ show descriptorIndex) + + +memberInfoName :: [ConstantInfo] -> MemberInfo -> String +memberInfoName constants MemberInfo { + memberAccessFlags = _, + memberNameIndex = nameIndex, + memberDescriptorIndex = _, + memberAttributes = _ } = let + name = constants!!((fromIntegral nameIndex) - 1) + in case name of + Utf8Info nameText -> nameText + _ -> ("Invalid Item at Constant pool index " ++ show nameIndex) + + +returnOperation :: DataType -> Operation +returnOperation dtype + | elem dtype ["int", "char", "boolean"] = Opireturn + | otherwise = Opareturn + +binaryOperation :: BinaryOperator -> Operation +binaryOperation Addition = Opiadd +binaryOperation Subtraction = Opisub +binaryOperation Multiplication = Opimul +binaryOperation Division = Opidiv +binaryOperation Modulo = Opirem +binaryOperation BitwiseAnd = Opiand +binaryOperation BitwiseOr = Opior +binaryOperation BitwiseXor = Opixor +binaryOperation And = Opiand +binaryOperation Or = Opior + +comparisonOperation :: BinaryOperator -> Word16 -> Operation +comparisonOperation CompareEqual branchLocation = Opif_icmpeq branchLocation +comparisonOperation CompareNotEqual branchLocation = Opif_icmpne branchLocation +comparisonOperation CompareLessThan branchLocation = Opif_icmplt branchLocation +comparisonOperation CompareLessOrEqual branchLocation = Opif_icmple branchLocation +comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation +comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation + +findFieldIndex :: [ConstantInfo] -> String -> Maybe Int +findFieldIndex constants name = let + fieldRefNameInfos = [ + -- we only skip one entry to get the name since the Java constant pool + -- is 1-indexed (why) + (index, constants!!(fromIntegral index + 1)) + | (index, FieldRefInfo classIndex _) <- (zip [1..] constants) + ] + fieldRefNames = map (\(index, nameInfo) -> case nameInfo of + Utf8Info fieldName -> (index, fieldName) + something_else -> error ("Expected UTF8Info but got" ++ show something_else)) + fieldRefNameInfos + fieldIndex = find (\(index, fieldName) -> fieldName == name) fieldRefNames + in case fieldIndex of + Just (index, _) -> Just index + Nothing -> Nothing + +findMethodRefIndex :: [ConstantInfo] -> String -> Maybe Int +findMethodRefIndex constants name = let + methodRefNameInfos = [ + -- we only skip one entry to get the name since the Java constant pool + -- is 1-indexed (why) + (index, constants!!(fromIntegral index + 1)) + | (index, MethodRefInfo _ _) <- (zip [1..] constants) + ] + methodRefNames = map (\(index, nameInfo) -> case nameInfo of + Utf8Info methodName -> (index, methodName) + something_else -> error ("Expected UTF8Info but got " ++ show something_else)) + methodRefNameInfos + methodIndex = find (\(index, methodName) -> methodName == name) methodRefNames + in case methodIndex of + Just (index, _) -> Just index + Nothing -> Nothing + + +findMethodIndex :: ClassFile -> String -> Maybe Int +findMethodIndex classFile name = let + constants = constantPool classFile + in + findIndex (\method -> ((memberInfoIsMethod constants method) && (memberInfoName constants method) == name)) (methods classFile) + +findClassIndex :: [ConstantInfo] -> String -> Maybe Int +findClassIndex constants name = let + classNameIndices = [(index, constants!!(fromIntegral nameIndex - 1)) | (index, ClassInfo nameIndex) <- (zip[1..] constants)] + classNames = map (\(index, nameInfo) -> case nameInfo of + Utf8Info className -> (index, className) + something_else -> error("Expected UTF8Info but got " ++ show something_else)) + classNameIndices + desiredClassIndex = find (\(index, className) -> className == name) classNames + in case desiredClassIndex of + Just (index, _) -> Just index + Nothing -> Nothing + +getKnownMembers :: [ConstantInfo] -> [(Int, (String, String, String))] +getKnownMembers constants = let + fieldsClassAndNT = [ + (index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1)) + | (index, FieldRefInfo classIndex nameTypeIndex) <- (zip [1..] constants) + ] ++ [ + (index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1)) + | (index, MethodRefInfo classIndex nameTypeIndex) <- (zip [1..] constants) + ] + + fieldsClassNameType = map (\(index, nameInfo, nameTypeInfo) -> case (nameInfo, nameTypeInfo) of + (ClassInfo nameIndex, NameAndTypeInfo fnameIndex ftypeIndex) -> (index, (constants!!(fromIntegral nameIndex - 1), constants!!(fromIntegral fnameIndex - 1), constants!!(fromIntegral ftypeIndex - 1))) + something_else -> error ("Expected Class and NameType info, but got: " ++ show nameInfo ++ " and " ++ show nameTypeInfo)) + fieldsClassAndNT + + fieldsResolved = map (\(index, (nameInfo, fnameInfo, ftypeInfo)) -> case (nameInfo, fnameInfo, ftypeInfo) of + (Utf8Info cname, Utf8Info fname, Utf8Info ftype) -> (index, (cname, fname, ftype)) + something_else -> error("Expected UTF8Infos but got " ++ show something_else)) + fieldsClassNameType + in + fieldsResolved + +-- same as findClassIndex, but inserts a new entry into constant pool if not existing +getClassIndex :: [ConstantInfo] -> String -> ([ConstantInfo], Int) +getClassIndex constants name = case findClassIndex constants name of + Just index -> (constants, index) + Nothing -> (constants ++ [ClassInfo (fromIntegral (length constants)), Utf8Info name], fromIntegral (length constants)) + +-- get the index for a field within a class, creating it if it does not exist. +getFieldIndex :: [ConstantInfo] -> (String, String, String) -> ([ConstantInfo], Int) +getFieldIndex constants (cname, fname, ftype) = case findMemberIndex constants (cname, fname, ftype) of + Just index -> (constants, index) + Nothing -> let + (constantsWithClass, classIndex) = getClassIndex constants cname + baseIndex = 1 + length constantsWithClass + in + (constantsWithClass ++ [ + FieldRefInfo (fromIntegral classIndex) (fromIntegral (baseIndex + 1)), + NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)), + Utf8Info fname, + Utf8Info (datatypeDescriptor ftype) + ], baseIndex) + +getMethodIndex :: [ConstantInfo] -> (String, String, String) -> ([ConstantInfo], Int) +getMethodIndex constants (cname, mname, mtype) = case findMemberIndex constants (cname, mname, mtype) of + Just index -> (constants, index) + Nothing -> let + (constantsWithClass, classIndex) = getClassIndex constants cname + baseIndex = 1 + length constantsWithClass + in + (constantsWithClass ++ [ + MethodRefInfo (fromIntegral classIndex) (fromIntegral (baseIndex + 1)), + NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)), + Utf8Info mname, + Utf8Info mtype + ], baseIndex) + +findMemberIndex :: [ConstantInfo] -> (String, String, String) -> Maybe Int +findMemberIndex constants (cname, fname, ftype) = let + allMembers = getKnownMembers constants + desiredMember = find (\(index, (c, f, ft)) -> (c, f, ft) == (cname, fname, ftype)) allMembers + in + fmap (\(index, _) -> index) desiredMember + +injectDefaultConstructor :: [MethodDeclaration] -> [MethodDeclaration] +injectDefaultConstructor pre + | any (\(MethodDeclaration _ name _ _) -> name == "") pre = pre + | otherwise = pre ++ [MethodDeclaration "void" "" [] (TypedStatement "void" (Block []))] + +injectFieldInitializers :: String -> [VariableDeclaration] -> [MethodDeclaration] -> [MethodDeclaration] +injectFieldInitializers classname vars pre = let + initializers = mapMaybe (\(variable) -> case variable of + VariableDeclaration dtype name (Just initializer) -> Just ( + TypedStatement dtype ( + StatementExpressionStatement ( + TypedStatementExpression dtype ( + Assignment + (TypedExpression dtype (BinaryOperation NameResolution (TypedExpression classname (LocalVariable "this")) (TypedExpression dtype (FieldVariable name)))) + initializer + ) + ) + ) + ) + otherwise -> Nothing + ) vars + in + map (\(method) -> case method of + MethodDeclaration "void" "" params (TypedStatement "void" (Block statements)) -> MethodDeclaration "void" "" params (TypedStatement "void" (Block (initializers ++ statements))) + otherwise -> method + ) pre \ No newline at end of file