bytecode #6

Merged
mrab merged 11 commits from bytecode into master 2024-06-21 07:06:07 +00:00
7 changed files with 86 additions and 137 deletions
Showing only changes of commit 8eb9c16c7a - Show all commits

View File

@ -25,7 +25,7 @@ public class Main {
// basic arithmetics // basic arithmetics
assert arithmetic.basic(1, 2, 3) == 2; assert arithmetic.basic(1, 2, 3) == 2;
// we have boolean logic as well // we have boolean logic as well
assert arithmetic.logic(true, false, true) == true; assert arithmetic.logic(false, false, true) == true;
// multiple classes within one file work. Referencing another class's fields/methods works. // multiple classes within one file work. Referencing another class's fields/methods works.
assert multipleClasses.a.a == 42; assert multipleClasses.a.a == 42;
// self-referencing classes work. // self-referencing classes work.

View File

@ -6,6 +6,6 @@ public class TestArithmetic {
public boolean logic(boolean a, boolean b, boolean c) public boolean logic(boolean a, boolean b, boolean c)
{ {
return a && (c || b); return !a && (c || b);
} }
} }

View File

@ -12,12 +12,12 @@ type Assembler a = ([ConstantInfo], [Operation], [String]) -> a -> ([ConstantInf
assembleExpression :: Assembler Expression assembleExpression :: Assembler Expression
assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b)) assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b))
| elem op [Addition, Subtraction, Multiplication, Division, Modulo, BitwiseAnd, BitwiseOr, BitwiseXor, And, Or] = let | op `elem` [Addition, Subtraction, Multiplication, Division, Modulo, BitwiseAnd, BitwiseOr, BitwiseXor, And, Or] = let
(aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
(bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
in in
(bConstants, bOps ++ [binaryOperation op], lvars) (bConstants, bOps ++ [binaryOperation op], lvars)
| elem op [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let | op `elem` [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let
(aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
(bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
cmp_op = comparisonOperation op 9 cmp_op = comparisonOperation op 9
@ -60,7 +60,7 @@ assembleExpression (constants, ops, lvars) (TypedExpression _ (UnaryOperation Mi
assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name)) assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name))
| name == "this" = (constants, ops ++ [Opaload 0], lvars) | name == "this" = (constants, ops ++ [Opaload 0], lvars)
| otherwise = let | otherwise = let
localIndex = findIndex ((==) name) lvars localIndex = elemIndex name lvars
isPrimitive = elem dtype ["char", "boolean", "int"] isPrimitive = elem dtype ["char", "boolean", "int"]
in case localIndex of in case localIndex of
Just index -> (constants, ops ++ if isPrimitive then [Opiload (fromIntegral index)] else [Opaload (fromIntegral index)], lvars) Just index -> (constants, ops ++ if isPrimitive then [Opiload (fromIntegral index)] else [Opaload (fromIntegral index)], lvars)
@ -69,7 +69,7 @@ assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable
assembleExpression (constants, ops, lvars) (TypedExpression dtype (StatementExpressionExpression stmtexp)) = assembleExpression (constants, ops, lvars) (TypedExpression dtype (StatementExpressionExpression stmtexp)) =
assembleStatementExpression (constants, ops, lvars) stmtexp assembleStatementExpression (constants, ops, lvars) stmtexp
assembleExpression _ expr = error ("unimplemented: " ++ show expr) assembleExpression _ expr = error ("Unknown expression: " ++ show expr)
assembleNameChain :: Assembler Expression assembleNameChain :: Assembler Expression
assembleNameChain input (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression _ (FieldVariable _)))) = assembleNameChain input (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression _ (FieldVariable _)))) =
@ -84,7 +84,7 @@ assembleStatementExpression
target = resolveNameChain (TypedExpression dtype receiver) target = resolveNameChain (TypedExpression dtype receiver)
in case target of in case target of
(TypedExpression dtype (LocalVariable name)) -> let (TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars localIndex = elemIndex name lvars
(constants_a, ops_a, _) = assembleExpression (constants, ops, lvars) expr (constants_a, ops_a, _) = assembleExpression (constants, ops, lvars) expr
isPrimitive = elem dtype ["char", "boolean", "int"] isPrimitive = elem dtype ["char", "boolean", "int"]
in case localIndex of in case localIndex of
@ -99,7 +99,7 @@ assembleStatementExpression
(constants_a, ops_a, _) = assembleExpression (constants_r, ops_r, lvars) expr (constants_a, ops_a, _) = assembleExpression (constants_r, ops_r, lvars) expr
in in
(constants_a, ops_a ++ [Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) (constants_a, ops_a ++ [Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else) something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression assembleStatementExpression
(constants, ops, lvars) (constants, ops, lvars)
@ -107,12 +107,12 @@ assembleStatementExpression
target = resolveNameChain (TypedExpression dtype receiver) target = resolveNameChain (TypedExpression dtype receiver)
in case target of in case target of
(TypedExpression dtype (LocalVariable name)) -> let (TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars localIndex = elemIndex name lvars
expr = (TypedExpression dtype (LocalVariable name)) expr = TypedExpression dtype (LocalVariable name)
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in case localIndex of in case localIndex of
Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opiadd, Opdup, Opistore (fromIntegral index)], lvars) Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opiadd, Opdup, Opistore (fromIntegral index)], lvars)
Nothing -> error("No such local variable found in local variable pool: " ++ name) Nothing -> error ("No such local variable found in local variable pool: " ++ name)
(TypedExpression dtype (FieldVariable name)) -> let (TypedExpression dtype (FieldVariable name)) -> let
owner = resolveNameChainOwner (TypedExpression dtype receiver) owner = resolveNameChainOwner (TypedExpression dtype receiver)
in case owner of in case owner of
@ -121,7 +121,7 @@ assembleStatementExpression
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
in in
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opiadd, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opiadd, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else) something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression assembleStatementExpression
(constants, ops, lvars) (constants, ops, lvars)
@ -129,12 +129,12 @@ assembleStatementExpression
target = resolveNameChain (TypedExpression dtype receiver) target = resolveNameChain (TypedExpression dtype receiver)
in case target of in case target of
(TypedExpression dtype (LocalVariable name)) -> let (TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars localIndex = elemIndex name lvars
expr = (TypedExpression dtype (LocalVariable name)) expr = TypedExpression dtype (LocalVariable name)
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in case localIndex of in case localIndex of
Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opisub, Opdup, Opistore (fromIntegral index)], lvars) Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opisub, Opdup, Opistore (fromIntegral index)], lvars)
Nothing -> error("No such local variable found in local variable pool: " ++ name) Nothing -> error ("No such local variable found in local variable pool: " ++ name)
(TypedExpression dtype (FieldVariable name)) -> let (TypedExpression dtype (FieldVariable name)) -> let
owner = resolveNameChainOwner (TypedExpression dtype receiver) owner = resolveNameChainOwner (TypedExpression dtype receiver)
in case owner of in case owner of
@ -143,7 +143,7 @@ assembleStatementExpression
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
in in
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opisub, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opisub, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else) something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression assembleStatementExpression
(constants, ops, lvars) (constants, ops, lvars)
@ -151,12 +151,12 @@ assembleStatementExpression
target = resolveNameChain (TypedExpression dtype receiver) target = resolveNameChain (TypedExpression dtype receiver)
in case target of in case target of
(TypedExpression dtype (LocalVariable name)) -> let (TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars localIndex = elemIndex name lvars
expr = (TypedExpression dtype (LocalVariable name)) expr = TypedExpression dtype (LocalVariable name)
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in case localIndex of in case localIndex of
Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opiadd, Opistore (fromIntegral index)], lvars) Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opiadd, Opistore (fromIntegral index)], lvars)
Nothing -> error("No such local variable found in local variable pool: " ++ name) Nothing -> error ("No such local variable found in local variable pool: " ++ name)
(TypedExpression dtype (FieldVariable name)) -> let (TypedExpression dtype (FieldVariable name)) -> let
owner = resolveNameChainOwner (TypedExpression dtype receiver) owner = resolveNameChainOwner (TypedExpression dtype receiver)
in case owner of in case owner of
@ -165,7 +165,7 @@ assembleStatementExpression
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
in in
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opiadd, Opputfield (fromIntegral fieldIndex)], lvars) (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opiadd, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else) something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression assembleStatementExpression
(constants, ops, lvars) (constants, ops, lvars)
@ -173,12 +173,12 @@ assembleStatementExpression
target = resolveNameChain (TypedExpression dtype receiver) target = resolveNameChain (TypedExpression dtype receiver)
in case target of in case target of
(TypedExpression dtype (LocalVariable name)) -> let (TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars localIndex = elemIndex name lvars
expr = (TypedExpression dtype (LocalVariable name)) expr = TypedExpression dtype (LocalVariable name)
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in case localIndex of in case localIndex of
Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opisub, Opistore (fromIntegral index)], lvars) Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opisub, Opistore (fromIntegral index)], lvars)
Nothing -> error("No such local variable found in local variable pool: " ++ name) Nothing -> error ("No such local variable found in local variable pool: " ++ name)
(TypedExpression dtype (FieldVariable name)) -> let (TypedExpression dtype (FieldVariable name)) -> let
owner = resolveNameChainOwner (TypedExpression dtype receiver) owner = resolveNameChainOwner (TypedExpression dtype receiver)
in case owner of in case owner of
@ -187,7 +187,7 @@ assembleStatementExpression
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
in in
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opisub, Opputfield (fromIntegral fieldIndex)], lvars) (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opisub, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else) something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression assembleStatementExpression
(constants, ops, lvars) (constants, ops, lvars)
@ -231,7 +231,7 @@ assembleStatement (constants, ops, lvars) (TypedStatement dtype (If expr if_stmt
else_length = sum (map opcodeEncodingLength ops_elsea) else_length = sum (map opcodeEncodingLength ops_elsea)
in case dtype of in case dtype of
"void" -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 6)] ++ ops_ifa ++ [Opgoto (else_length + 3)] ++ ops_elsea, lvars) "void" -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 6)] ++ ops_ifa ++ [Opgoto (else_length + 3)] ++ ops_elsea, lvars)
otherwise -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 3)] ++ ops_ifa ++ ops_elsea, lvars) _ -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 3)] ++ ops_ifa ++ ops_elsea, lvars)
assembleStatement (constants, ops, lvars) (TypedStatement _ (While expr stmt)) = let assembleStatement (constants, ops, lvars) (TypedStatement _ (While expr stmt)) = let
(constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr (constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr
@ -257,20 +257,19 @@ assembleStatement (constants, ops, lvars) (TypedStatement _ (StatementExpression
in in
(constants_e, ops_e ++ [Oppop], lvars_e) (constants_e, ops_e ++ [Oppop], lvars_e)
assembleStatement _ stmt = error ("Not yet implemented: " ++ show stmt) assembleStatement _ stmt = error ("Unknown statement: " ++ show stmt)
assembleMethod :: Assembler MethodDeclaration assembleMethod :: Assembler MethodDeclaration
assembleMethod (constants, ops, lvars) (MethodDeclaration returntype name _ (TypedStatement _ (Block statements))) assembleMethod (constants, ops, lvars) (MethodDeclaration returntype name _ (TypedStatement _ (Block statements)))
| name == "<init>" = let | name == "<init>" = let
(constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
init_ops = [Opaload 0, Opinvokespecial 2]
in in
(constants_a, init_ops ++ ops_a ++ [Opreturn], lvars_a) (constants_a, [Opaload 0, Opinvokespecial 2] ++ ops_a ++ [Opreturn], lvars_a)
| otherwise = case returntype of | otherwise = case returntype of
"void" -> let "void" -> let
(constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
in in
(constants_a, ops_a ++ [Opreturn], lvars_a) (constants_a, ops_a ++ [Opreturn], lvars_a)
otherwise -> foldl assembleStatement (constants, ops, lvars) statements _ -> foldl assembleStatement (constants, ops, lvars) statements
assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Typed block expected for method body, got: " ++ show stmt) assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Typed block expected for method body, got: " ++ show stmt)

View File

@ -22,14 +22,14 @@ fieldBuilder (VariableDeclaration datatype name _) input = let
] ]
field = MemberInfo { field = MemberInfo {
memberAccessFlags = accessPublic, memberAccessFlags = accessPublic,
memberNameIndex = (fromIntegral (baseIndex + 2)), memberNameIndex = fromIntegral (baseIndex + 2),
memberDescriptorIndex = (fromIntegral (baseIndex + 3)), memberDescriptorIndex = fromIntegral (baseIndex + 3),
memberAttributes = [] memberAttributes = []
} }
in in
input { input {
constantPool = (constantPool input) ++ constants, constantPool = constantPool input ++ constants,
fields = (fields input) ++ [field] fields = fields input ++ [field]
} }
@ -46,14 +46,14 @@ methodBuilder (MethodDeclaration returntype name parameters statement) input = l
method = MemberInfo { method = MemberInfo {
memberAccessFlags = accessPublic, memberAccessFlags = accessPublic,
memberNameIndex = (fromIntegral (baseIndex + 2)), memberNameIndex = fromIntegral (baseIndex + 2),
memberDescriptorIndex = (fromIntegral (baseIndex + 3)), memberDescriptorIndex = fromIntegral (baseIndex + 3),
memberAttributes = [] memberAttributes = []
} }
in in
input { input {
constantPool = (constantPool input) ++ constants, constantPool = constantPool input ++ constants,
methods = (methods input) ++ [method] methods = methods input ++ [method]
} }
@ -94,11 +94,12 @@ classBuilder (Class name methods fields) _ = let
Utf8Info "java/lang/Object", Utf8Info "java/lang/Object",
Utf8Info "<init>", Utf8Info "<init>",
Utf8Info "()V", Utf8Info "()V",
Utf8Info "Code" Utf8Info "Code",
ClassInfo 9,
Utf8Info name
] ]
nameConstants = [ClassInfo 9, Utf8Info name]
nakedClassFile = ClassFile { nakedClassFile = ClassFile {
constantPool = baseConstants ++ nameConstants, constantPool = baseConstants,
accessFlags = accessPublic, accessFlags = accessPublic,
thisClass = 8, thisClass = 8,
superClass = 1, superClass = 1,
@ -107,9 +108,13 @@ classBuilder (Class name methods fields) _ = let
attributes = [] attributes = []
} }
-- if a class has no constructor, inject an empty one.
methodsWithInjectedConstructor = injectDefaultConstructor methods methodsWithInjectedConstructor = injectDefaultConstructor methods
-- for every constructor, prepend all initialization assignments for fields.
methodsWithInjectedInitializers = injectFieldInitializers name fields methodsWithInjectedConstructor methodsWithInjectedInitializers = injectFieldInitializers name fields methodsWithInjectedConstructor
-- add fields, then method bodies to the classfile. After all referable names are known,
-- assemble the methods into bytecode.
classFileWithFields = foldr fieldBuilder nakedClassFile fields classFileWithFields = foldr fieldBuilder nakedClassFile fields
classFileWithMethods = foldr methodBuilder classFileWithFields methodsWithInjectedInitializers classFileWithMethods = foldr methodBuilder classFileWithFields methodsWithInjectedInitializers
classFileWithAssembledMethods = foldr methodAssembler classFileWithMethods methodsWithInjectedInitializers classFileWithAssembledMethods = foldr methodAssembler classFileWithMethods methodsWithInjectedInitializers

View File

@ -1,14 +1,4 @@
module ByteCode.ClassFile( module ByteCode.ClassFile where
ConstantInfo(..),
Attribute(..),
MemberInfo(..),
ClassFile(..),
Operation(..),
serialize,
emptyClassFile,
opcodeEncodingLength,
className
) where
import Data.Word import Data.Word
import Data.Int import Data.Int
@ -99,10 +89,10 @@ emptyClassFile = ClassFile {
className :: ClassFile -> String className :: ClassFile -> String
className classFile = let className classFile = let
classInfo = (constantPool classFile)!!(fromIntegral (thisClass classFile)) classInfo = constantPool classFile !! fromIntegral (thisClass classFile)
in case classInfo of in case classInfo of
Utf8Info className -> className Utf8Info className -> className
otherwise -> error ("expected Utf8Info but got: " ++ show otherwise) unexpected_element -> error ("expected Utf8Info but got: " ++ show unexpected_element)
opcodeEncodingLength :: Operation -> Word16 opcodeEncodingLength :: Operation -> Word16
@ -201,10 +191,10 @@ instance Serializable Attribute where
serialize (CodeAttribute { attributeMaxStack = maxStack, serialize (CodeAttribute { attributeMaxStack = maxStack,
attributeMaxLocals = maxLocals, attributeMaxLocals = maxLocals,
attributeCode = code }) = let attributeCode = code }) = let
assembledCode = concat (map serialize code) assembledCode = concatMap serialize code
in in
unpackWord16 7 -- attribute_name_index unpackWord16 7 -- attribute_name_index
++ unpackWord32 (12 + (fromIntegral (length assembledCode))) -- attribute_length ++ unpackWord32 (12 + fromIntegral (length assembledCode)) -- attribute_length
++ unpackWord16 maxStack -- max_stack ++ unpackWord16 maxStack -- max_stack
++ unpackWord16 maxLocals -- max_locals ++ unpackWord16 maxLocals -- max_locals
++ unpackWord32 (fromIntegral (length assembledCode)) -- code_length ++ unpackWord32 (fromIntegral (length assembledCode)) -- code_length

View File

@ -10,23 +10,22 @@ import Data.Word (Word8, Word16, Word32)
-- walks the name resolution chain. returns the innermost LocalVariable/FieldVariable expression and raises an error for anything else. -- walks the name resolution chain. returns the innermost LocalVariable/FieldVariable expression and raises an error for anything else.
resolveNameChain :: Expression -> Expression resolveNameChain :: Expression -> Expression
resolveNameChain (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b resolveNameChain (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b
resolveNameChain (TypedExpression dtype (LocalVariable name)) = (TypedExpression dtype (LocalVariable name)) resolveNameChain (TypedExpression dtype (LocalVariable name)) = TypedExpression dtype (LocalVariable name)
resolveNameChain (TypedExpression dtype (FieldVariable name)) = (TypedExpression dtype (FieldVariable name)) resolveNameChain (TypedExpression dtype (FieldVariable name)) = TypedExpression dtype (FieldVariable name)
resolveNameChain invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show (invalidExpression)) resolveNameChain invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show invalidExpression)
-- walks the name resolution chain. returns the second-to-last item of the namechain. -- walks the name resolution chain. returns the second-to-last item of the namechain.
resolveNameChainOwner :: Expression -> Expression resolveNameChainOwner :: Expression -> Expression
resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a (TypedExpression dtype (FieldVariable name)))) = a resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a (TypedExpression dtype (FieldVariable name)))) = a
resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b
resolveNameChainOwner invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show (invalidExpression)) resolveNameChainOwner invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show invalidExpression)
methodDescriptor :: MethodDeclaration -> String methodDescriptor :: MethodDeclaration -> String
methodDescriptor (MethodDeclaration returntype _ parameters _) = let methodDescriptor (MethodDeclaration returntype _ parameters _) = let
parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters] parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters]
in in
"(" "("
++ (concat (map datatypeDescriptor parameter_types)) ++ concatMap datatypeDescriptor parameter_types
++ ")" ++ ")"
++ datatypeDescriptor returntype ++ datatypeDescriptor returntype
@ -35,10 +34,12 @@ methodDescriptorFromParamlist parameters returntype = let
parameter_types = [datatype | TypedExpression datatype _ <- parameters] parameter_types = [datatype | TypedExpression datatype _ <- parameters]
in in
"(" "("
++ (concat (map datatypeDescriptor parameter_types)) ++ concatMap datatypeDescriptor parameter_types
++ ")" ++ ")"
++ datatypeDescriptor returntype ++ datatypeDescriptor returntype
-- recursively parses a given type signature into a list of parameter types and the method return type.
-- As an initial parameter, you can supply ([], "void").
parseMethodType :: ([String], String) -> String -> ([String], String) parseMethodType :: ([String], String) -> String -> ([String], String)
parseMethodType (params, returnType) ('(' : descriptor) = parseMethodType (params, returnType) descriptor parseMethodType (params, returnType) ('(' : descriptor) = parseMethodType (params, returnType) descriptor
parseMethodType (params, returnType) ('I' : descriptor) = parseMethodType (params ++ ["I"], returnType) descriptor parseMethodType (params, returnType) ('I' : descriptor) = parseMethodType (params ++ ["I"], returnType) descriptor
@ -51,16 +52,16 @@ parseMethodType (params, returnType) ('L' : descriptor) = let
(typeName, semicolon : restOfDescriptor) = splitAt length descriptor (typeName, semicolon : restOfDescriptor) = splitAt length descriptor
in in
parseMethodType (params ++ [typeName], returnType) restOfDescriptor parseMethodType (params ++ [typeName], returnType) restOfDescriptor
Nothing -> error $ "unterminated class type in function signature: " ++ (show descriptor) Nothing -> error $ "unterminated class type in function signature: " ++ show descriptor
parseMethodType (params, _) (')' : descriptor) = (params, descriptor) parseMethodType (params, _) (')' : descriptor) = (params, descriptor)
parseMethodType _ descriptor = error $ "expected start of type name (L, I, C, Z) but got: " ++ descriptor parseMethodType _ descriptor = error $ "expected start of type name (L, I, C, Z) but got: " ++ descriptor
-- given a method index (constant pool index), -- given a method index (constant pool index),
-- returns the full type of the method. (e.g. (LSomething;II)V) -- returns the full type of the method. (e.g. (LSomething;II)V)
methodTypeFromIndex :: [ConstantInfo] -> Int -> String methodTypeFromIndex :: [ConstantInfo] -> Int -> String
methodTypeFromIndex constants index = case constants!!(fromIntegral (index - 1)) of methodTypeFromIndex constants index = case constants !! fromIntegral (index - 1) of
MethodRefInfo _ nameAndTypeIndex -> case constants!!(fromIntegral (nameAndTypeIndex - 1)) of MethodRefInfo _ nameAndTypeIndex -> case constants !! fromIntegral (nameAndTypeIndex - 1) of
NameAndTypeInfo _ typeIndex -> case constants!!(fromIntegral (typeIndex - 1)) of NameAndTypeInfo _ typeIndex -> case constants !! fromIntegral (typeIndex - 1) of
Utf8Info typeLiteral -> typeLiteral Utf8Info typeLiteral -> typeLiteral
unexpectedElement -> error "Expected Utf8Info but got: " ++ show unexpectedElement unexpectedElement -> error "Expected Utf8Info but got: " ++ show unexpectedElement
unexpectedElement -> error "Expected NameAndTypeInfo but got: " ++ show unexpectedElement unexpectedElement -> error "Expected NameAndTypeInfo but got: " ++ show unexpectedElement
@ -70,7 +71,7 @@ methodParametersFromIndex :: [ConstantInfo] -> Int -> ([String], String)
methodParametersFromIndex constants index = parseMethodType ([], "V") (methodTypeFromIndex constants index) methodParametersFromIndex constants index = parseMethodType ([], "V") (methodTypeFromIndex constants index)
memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool
memberInfoIsMethod constants info = elem '(' (memberInfoDescriptor constants info) memberInfoIsMethod constants info = '(' `elem` memberInfoDescriptor constants info
datatypeDescriptor :: String -> String datatypeDescriptor :: String -> String
datatypeDescriptor "void" = "V" datatypeDescriptor "void" = "V"
@ -79,35 +80,24 @@ datatypeDescriptor "char" = "C"
datatypeDescriptor "boolean" = "Z" datatypeDescriptor "boolean" = "Z"
datatypeDescriptor x = "L" ++ x ++ ";" datatypeDescriptor x = "L" ++ x ++ ";"
memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String
memberInfoDescriptor constants MemberInfo { memberInfoDescriptor constants MemberInfo { memberDescriptorIndex = descriptorIndex } = let
memberAccessFlags = _, descriptor = constants !! (fromIntegral descriptorIndex - 1)
memberNameIndex = _,
memberDescriptorIndex = descriptorIndex,
memberAttributes = _ } = let
descriptor = constants!!((fromIntegral descriptorIndex) - 1)
in case descriptor of in case descriptor of
Utf8Info descriptorText -> descriptorText Utf8Info descriptorText -> descriptorText
_ -> ("Invalid Item at Constant pool index " ++ show descriptorIndex) _ -> "Invalid Item at Constant pool index " ++ show descriptorIndex
memberInfoName :: [ConstantInfo] -> MemberInfo -> String memberInfoName :: [ConstantInfo] -> MemberInfo -> String
memberInfoName constants MemberInfo { memberInfoName constants MemberInfo { memberNameIndex = nameIndex } = let
memberAccessFlags = _, name = constants !! (fromIntegral nameIndex - 1)
memberNameIndex = nameIndex,
memberDescriptorIndex = _,
memberAttributes = _ } = let
name = constants!!((fromIntegral nameIndex) - 1)
in case name of in case name of
Utf8Info nameText -> nameText Utf8Info nameText -> nameText
_ -> ("Invalid Item at Constant pool index " ++ show nameIndex) _ -> "Invalid Item at Constant pool index " ++ show nameIndex
returnOperation :: DataType -> Operation returnOperation :: DataType -> Operation
returnOperation dtype returnOperation dtype
| elem dtype ["int", "char", "boolean"] = Opireturn | dtype `elem` ["int", "char", "boolean"] = Opireturn
| otherwise = Opareturn | otherwise = Opareturn
binaryOperation :: BinaryOperator -> Operation binaryOperation :: BinaryOperator -> Operation
binaryOperation Addition = Opiadd binaryOperation Addition = Opiadd
@ -141,50 +131,15 @@ comparisonOffset anything_else = Nothing
-- | True exactly when 'comparisonOffset' yields a value for the operation.
isComparisonOperation :: Operation -> Bool
isComparisonOperation = isJust . comparisonOffset
-- | Find the 1-based constant pool index of the first 'FieldRefInfo' whose
-- field name equals @name@, or 'Nothing' when no field reference matches.
--
-- The name is not resolved through the reference's own indices. It is read
-- from the pool entry two slots after the reference (1-based @index + 2@,
-- i.e. 0-based @index + 1@).
-- NOTE(review): this presumably relies on the builder always emitting the
-- name-and-type entry and its name right after the reference -- confirm.
--
-- Raises an 'error' when that slot is missing or does not hold a 'Utf8Info'.
findFieldIndex :: [ConstantInfo] -> String -> Maybe Int
findFieldIndex constants name = fst <$> find ((== name) . snd) fieldRefNames
  where
    fieldRefNames =
      [ (index, fieldNameAfter index)
      | (index, FieldRefInfo _ _) <- zip [1 ..] constants
      ]
    -- 'drop' instead of '!!' so a truncated pool gets a descriptive message.
    fieldNameAfter index = case drop (index + 1) constants of
      (Utf8Info fieldName : _) -> fieldName
      (somethingElse : _) -> error ("Expected UTF8Info but got " ++ show somethingElse)
      [] -> error ("Constant pool ends before the name of field reference " ++ show index)
-- | Find the 1-based constant pool index of the first 'MethodRefInfo' whose
-- method name equals @name@, or 'Nothing' when no method reference matches.
--
-- The name is not resolved through the reference's own indices. It is read
-- from the pool entry two slots after the reference (1-based @index + 2@,
-- i.e. 0-based @index + 1@).
-- NOTE(review): this presumably relies on the builder always emitting the
-- name-and-type entry and its name right after the reference -- confirm.
--
-- Raises an 'error' when that slot is missing or does not hold a 'Utf8Info'.
findMethodRefIndex :: [ConstantInfo] -> String -> Maybe Int
findMethodRefIndex constants name = fst <$> find ((== name) . snd) methodRefNames
  where
    methodRefNames =
      [ (index, methodNameAfter index)
      | (index, MethodRefInfo _ _) <- zip [1 ..] constants
      ]
    -- 'drop' instead of '!!' so a truncated pool gets a descriptive message.
    methodNameAfter index = case drop (index + 1) constants of
      (Utf8Info methodName : _) -> methodName
      (somethingElse : _) -> error ("Expected UTF8Info but got " ++ show somethingElse)
      [] -> error ("Constant pool ends before the name of method reference " ++ show index)
-- | Locate the method called @name@ in @methods classFile@.
--
-- The result comes from 'findIndex', so it is a 0-based position in the
-- method list rather than a constant pool index. 'Nothing' is returned when
-- no entry both passes 'memberInfoIsMethod' and has a matching name.
findMethodIndex :: ClassFile -> String -> Maybe Int
findMethodIndex classFile name = findIndex isWanted (methods classFile)
  where
    constants = constantPool classFile
    isWanted method =
      memberInfoIsMethod constants method && memberInfoName constants method == name
findClassIndex :: [ConstantInfo] -> String -> Maybe Int findClassIndex :: [ConstantInfo] -> String -> Maybe Int
findClassIndex constants name = let findClassIndex constants name = let
classNameIndices = [(index, constants!!(fromIntegral nameIndex - 1)) | (index, ClassInfo nameIndex) <- (zip [1..] constants)] classNameIndices = [(index, constants!!(fromIntegral nameIndex - 1)) | (index, ClassInfo nameIndex) <- zip [1..] constants]
classNames = map (\(index, nameInfo) -> case nameInfo of classNames = map (\(index, nameInfo) -> case nameInfo of
Utf8Info className -> (index, className) Utf8Info className -> (index, className)
something_else -> error ("Expected UTF8Info but got " ++ show something_else)) something_else -> error ("Expected UTF8Info but got " ++ show something_else))
@ -198,10 +153,10 @@ getKnownMembers :: [ConstantInfo] -> [(Int, (String, String, String))]
getKnownMembers constants = let getKnownMembers constants = let
fieldsClassAndNT = [ fieldsClassAndNT = [
(index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1)) (index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1))
| (index, FieldRefInfo classIndex nameTypeIndex) <- (zip [1..] constants) | (index, FieldRefInfo classIndex nameTypeIndex) <- zip [1..] constants
] ++ [ ] ++ [
(index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1)) (index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1))
| (index, MethodRefInfo classIndex nameTypeIndex) <- (zip [1..] constants) | (index, MethodRefInfo classIndex nameTypeIndex) <- zip [1..] constants
] ]
fieldsClassNameType = map (\(index, nameInfo, nameTypeInfo) -> case (nameInfo, nameTypeInfo) of fieldsClassNameType = map (\(index, nameInfo, nameTypeInfo) -> case (nameInfo, nameTypeInfo) of
@ -280,9 +235,9 @@ injectFieldInitializers classname vars pre = let
otherwise -> Nothing otherwise -> Nothing
) vars ) vars
in in
map (\(method) -> case method of map (\method -> case method of
MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block statements)) -> MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block (initializers ++ statements))) MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block statements)) -> MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block (initializers ++ statements)))
otherwise -> method _ -> method
) pre ) pre
-- effect of one instruction/operation on the stack -- effect of one instruction/operation on the stack
@ -312,10 +267,10 @@ operationStackCost constants Opdup_x1 = 1
operationStackCost constants Oppop = -1 operationStackCost constants Oppop = -1
operationStackCost constants (Opinvokespecial idx) = let operationStackCost constants (Opinvokespecial idx) = let
(params, returnType) = methodParametersFromIndex constants (fromIntegral idx) (params, returnType) = methodParametersFromIndex constants (fromIntegral idx)
in (length params + 1) - (fromEnum (returnType /= "V")) in (length params + 1) - fromEnum (returnType /= "V")
operationStackCost constants (Opinvokevirtual idx) = let operationStackCost constants (Opinvokevirtual idx) = let
(params, returnType) = methodParametersFromIndex constants (fromIntegral idx) (params, returnType) = methodParametersFromIndex constants (fromIntegral idx)
in (length params + 1) - (fromEnum (returnType /= "V")) in (length params + 1) - fromEnum (returnType /= "V")
operationStackCost constants (Opgoto _) = 0 operationStackCost constants (Opgoto _) = 0
operationStackCost constants (Opsipush _) = 1 operationStackCost constants (Opsipush _) = 1
operationStackCost constants (Opldc_w _) = 1 operationStackCost constants (Opldc_w _) = 1