bytecode #3

Merged
mrab merged 21 commits from bytecode into master 2024-06-14 07:56:36 +00:00
23 changed files with 831 additions and 636 deletions

1
.gitignore vendored
View File

@ -8,7 +8,6 @@ cabal-dev
*.chs.h
*.dyn_o
*.dyn_hi
*.java
*.class
*.local~*
src/Parser/JavaParser.hs

View File

@ -0,0 +1,38 @@
// compile all test files using:
// ls Test/JavaSources/*.java | grep -v ".*Main.java" | xargs -I {} cabal run compiler {}
// compile (in project root) using:
// javac -g:none -sourcepath Test/JavaSources/ Test/JavaSources/Main.java
// afterwards, run using
// java -ea -cp Test/JavaSources/ Main
public class Main {
public static void main(String[] args)
{
TestEmpty empty = new TestEmpty();
TestFields fields = new TestFields();
TestConstructor constructor = new TestConstructor(42);
TestMultipleClasses multipleClasses = new TestMultipleClasses();
TestRecursion recursion = new TestRecursion(10);
TestMalicious malicious = new TestMalicious();
// constructing a basic class works
assert empty != null;
// initializers (and default initializers to 0/null) work
assert fields.a == 0 && fields.b == 42;
// constructor parameters override initializers
assert constructor.a == 42;
// multiple classes within one file work. Referencing another classes fields/methods works.
assert multipleClasses.a.a == 42;
// self-referencing classes work.
assert recursion.child.child.child.child.child.value == 5;
// self-referencing methods work.
assert recursion.fibonacci(15) == 610;
// intentionally dodgy expressions work
assert malicious.assignNegativeIncrement(42) == -42;
assert malicious.tripleAddition(1, 2, 3) == 6;
for(int i = 0; i < 3; i++)
{
assert malicious.cursedFormatting(i) == i;
}
}
}

View File

@ -0,0 +1,9 @@
public class TestConstructor
{
public int a = -1;
public TestConstructor(int initial_value)
{
a = initial_value;
}
}

View File

@ -0,0 +1,4 @@
public class TestEmpty
{
}

View File

@ -0,0 +1,5 @@
public class TestFields
{
public int a;
public int b = 42;
}

View File

@ -0,0 +1,41 @@
public class TestMalicious {
public int assignNegativeIncrement(int n)
{
return n=-++n+1;
}
public int tripleAddition(int a, int b, int c)
{
return a+++b+++c++;
}
public int cursedFormatting(int n)
{
if
(n == 0)
{
return ((((0))));
}
else
if(n ==
1)
{
return
1;
}else {
return
2
;
}
}
}

View File

@ -0,0 +1,9 @@
public class TestMultipleClasses
{
public AnotherTestClass a = new AnotherTestClass();
}
class AnotherTestClass
{
public int a = 42;
}

View File

@ -0,0 +1,27 @@
public class TestRecursion {
public int value = 0;
public TestRecursion child = null;
public TestRecursion(int n)
{
value = n;
if(n > 0)
{
child = new TestRecursion(n - 1);
}
}
public int fibonacci(int n)
{
if(n < 2)
{
return n;
}
else
{
return fibonacci(n - 1) + this.fibonacci(n - 2);
}
}
}

View File

@ -1,118 +0,0 @@
module TestByteCodeGenerator where
import Test.HUnit
import ByteCode.ClassFile.Generator
import ByteCode.ClassFile
import ByteCode.Constants
import Ast
nakedClass = Class "Testklasse" [] []
expectedClass = ClassFile {
constantPool = [
ClassInfo 4,
MethodRefInfo 1 3,
NameAndTypeInfo 5 6,
Utf8Info "java/lang/Object",
Utf8Info "<init>",
Utf8Info "()V",
Utf8Info "Code",
ClassInfo 9,
Utf8Info "Testklasse"
],
accessFlags = accessPublic,
thisClass = 8,
superClass = 1,
fields = [],
methods = [],
attributes = []
}
classWithFields = Class "Testklasse" [] [VariableDeclaration "int" "testvariable" Nothing]
expectedClassWithFields = ClassFile {
constantPool = [
ClassInfo 4,
MethodRefInfo 1 3,
NameAndTypeInfo 5 6,
Utf8Info "java/lang/Object",
Utf8Info "<init>",
Utf8Info "()V",
Utf8Info "Code",
ClassInfo 9,
Utf8Info "Testklasse",
FieldRefInfo 8 11,
NameAndTypeInfo 12 13,
Utf8Info "testvariable",
Utf8Info "I"
],
accessFlags = accessPublic,
thisClass = 8,
superClass = 1,
fields = [
MemberInfo {
memberAccessFlags = accessPublic,
memberNameIndex = 12,
memberDescriptorIndex = 13,
memberAttributes = []
}
],
methods = [],
attributes = []
}
method = MethodDeclaration "int" "add_two_numbers" [
ParameterDeclaration "int" "a",
ParameterDeclaration "int" "b" ]
(Block [Return (Just (BinaryOperation Addition (Reference "a") (Reference "b")))])
classWithMethod = Class "Testklasse" [method] []
expectedClassWithMethod = ClassFile {
constantPool = [
ClassInfo 4,
MethodRefInfo 1 3,
NameAndTypeInfo 5 6,
Utf8Info "java/lang/Object",
Utf8Info "<init>",
Utf8Info "()V",
Utf8Info "Code",
ClassInfo 9,
Utf8Info "Testklasse",
FieldRefInfo 8 11,
NameAndTypeInfo 12 13,
Utf8Info "add_two_numbers",
Utf8Info "(II)I"
],
accessFlags = accessPublic,
thisClass = 8,
superClass = 1,
fields = [],
methods = [
MemberInfo {
memberAccessFlags = accessPublic,
memberNameIndex = 12,
memberDescriptorIndex = 13,
memberAttributes = [
CodeAttribute {
attributeMaxStack = 420,
attributeMaxLocals = 420,
attributeCode = [Opiadd]
}
]
}
],
attributes = []
}
testBasicConstantPool = TestCase $ assertEqual "basic constant pool" expectedClass $ classBuilder nakedClass emptyClassFile
testFields = TestCase $ assertEqual "fields in constant pool" expectedClassWithFields $ classBuilder classWithFields emptyClassFile
testMethodDescriptor = TestCase $ assertEqual "method descriptor" "(II)I" (methodDescriptor method)
testMethodAssembly = TestCase $ assertEqual "method assembly" expectedClassWithMethod (classBuilder classWithMethod emptyClassFile)
testFindMethodIndex = TestCase $ assertEqual "find method" (Just 0) (findMethodIndex expectedClassWithMethod "add_two_numbers")
tests = TestList [
TestLabel "Basic constant pool" testBasicConstantPool,
TestLabel "Fields constant pool" testFields,
TestLabel "Method descriptor building" testMethodDescriptor,
TestLabel "Method assembly" testMethodAssembly,
TestLabel "Find method by name" testFindMethodIndex
]

View File

@ -2,13 +2,11 @@ module Main where
import Test.HUnit
import TestLexer
import TestByteCodeGenerator
import TestParser
tests = TestList [
TestLabel "TestLexer" TestLexer.tests,
TestLabel "TestParser" TestParser.tests,
TestLabel "TestByteCodeGenerator" TestByteCodeGenerator.tests]
TestLabel "TestLexer" TestLexer.tests,
TestLabel "TestParser" TestParser.tests
]
main = do runTestTTAndExit Main.tests

View File

@ -10,7 +10,8 @@ executable compiler
array,
HUnit,
utf8-string,
bytestring
bytestring,
filepath
default-language: Haskell2010
hs-source-dirs: src
build-tool-depends: alex:alex, happy:happy
@ -19,14 +20,11 @@ executable compiler
Ast,
Example,
Typecheck,
ByteCode.Util,
ByteCode.ByteUtil,
ByteCode.ClassFile,
ByteCode.Generation.Generator,
ByteCode.Generation.Assembler.ExpressionAndStatement,
ByteCode.Generation.Assembler.Method,
ByteCode.Generation.Builder.Class,
ByteCode.Generation.Builder.Field,
ByteCode.Generation.Builder.Method,
ByteCode.Assembler,
ByteCode.Builder,
ByteCode.Constants
test-suite tests
@ -37,15 +35,17 @@ test-suite tests
array,
HUnit,
utf8-string,
bytestring
bytestring,
filepath
build-tool-depends: alex:alex, happy:happy
other-modules: Parser.Lexer,
Parser.JavaParser,
Ast,
TestLexer,
TestParser,
TestByteCodeGenerator,
ByteCode.Util,
ByteCode.ByteUtil,
ByteCode.ClassFile,
ByteCode.ClassFile.Generator,
ByteCode.Assembler,
ByteCode.Builder,
ByteCode.Constants

276
src/ByteCode/Assembler.hs Normal file
View File

@ -0,0 +1,276 @@
module ByteCode.Assembler where
import ByteCode.Constants
import ByteCode.ClassFile (ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength)
import ByteCode.Util
import Ast
import Data.Char
import Data.List
import Data.Word
type Assembler a = ([ConstantInfo], [Operation], [String]) -> a -> ([ConstantInfo], [Operation], [String])
assembleExpression :: Assembler Expression
assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b))
| elem op [Addition, Subtraction, Multiplication, Division, Modulo, BitwiseAnd, BitwiseOr, BitwiseXor, And, Or] = let
(aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
(bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
in
(bConstants, bOps ++ [binaryOperation op], lvars)
| elem op [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let
(aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
(bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
cmp_op = comparisonOperation op 9
cmp_ops = [cmp_op, Opsipush 0, Opgoto 6, Opsipush 1]
in
(bConstants, bOps ++ cmp_ops, lvars)
assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression btype (FieldVariable b)))) = let
(fConstants, fieldIndex) = getFieldIndex constants (atype, b, datatypeDescriptor btype)
(aConstants, aOps, _) = assembleExpression (fConstants, ops, lvars) (TypedExpression atype a)
in
(aConstants, aOps ++ [Opgetfield (fromIntegral fieldIndex)], lvars)
assembleExpression (constants, ops, lvars) (TypedExpression _ (CharacterLiteral literal)) =
(constants, ops ++ [Opsipush (fromIntegral (ord literal))], lvars)
assembleExpression (constants, ops, lvars) (TypedExpression _ (BooleanLiteral literal)) =
(constants, ops ++ [Opsipush (if literal then 1 else 0)], lvars)
assembleExpression (constants, ops, lvars) (TypedExpression _ (IntegerLiteral literal))
| literal <= 32767 && literal >= -32768 = (constants, ops ++ [Opsipush (fromIntegral literal)], lvars)
| otherwise = (constants ++ [IntegerInfo (fromIntegral literal)], ops ++ [Opldc_w (fromIntegral (1 + length constants))], lvars)
assembleExpression (constants, ops, lvars) (TypedExpression _ NullLiteral) =
(constants, ops ++ [Opaconst_null], lvars)
assembleExpression (constants, ops, lvars) (TypedExpression etype (UnaryOperation Not expr)) = let
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
newConstant = fromIntegral (1 + length exprConstants)
in case etype of
"int" -> (exprConstants ++ [IntegerInfo 0x7FFFFFFF], exprOps ++ [Opldc_w newConstant, Opixor], lvars)
"char" -> (exprConstants, exprOps ++ [Opsipush 0xFFFF, Opixor], lvars)
"boolean" -> (exprConstants, exprOps ++ [Opsipush 0x01, Opixor], lvars)
assembleExpression (constants, ops, lvars) (TypedExpression _ (UnaryOperation Minus expr)) = let
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in
(exprConstants, exprOps ++ [Opineg], lvars)
assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name))
| name == "this" = (constants, ops ++ [Opaload 0], lvars)
| otherwise = let
localIndex = findIndex ((==) name) lvars
isPrimitive = elem dtype ["char", "boolean", "int"]
in case localIndex of
Just index -> (constants, ops ++ if isPrimitive then [Opiload (fromIntegral index)] else [Opaload (fromIntegral index)], lvars)
Nothing -> error ("No such local variable found in local variable pool: " ++ name)
assembleExpression (constants, ops, lvars) (TypedExpression dtype (StatementExpressionExpression stmtexp)) =
assembleStatementExpression (constants, ops, lvars) stmtexp
assembleExpression _ expr = error ("unimplemented: " ++ show expr)
assembleNameChain :: Assembler Expression
assembleNameChain input (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression _ (FieldVariable _)))) =
assembleExpression input (TypedExpression atype a)
assembleNameChain input expr = assembleExpression input expr
assembleStatementExpression :: Assembler StatementExpression
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (Assignment (TypedExpression dtype receiver) expr)) = let
target = resolveNameChain (TypedExpression dtype receiver)
in case target of
(TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars
(constants_a, ops_a, _) = assembleExpression (constants, ops, lvars) expr
isPrimitive = elem dtype ["char", "boolean", "int"]
in case localIndex of
Just index -> (constants_a, ops_a ++ if isPrimitive then [Opdup, Opistore (fromIntegral index)] else [Opastore (fromIntegral index)], lvars)
Nothing -> error ("No such local variable found in local variable pool: " ++ name)
(TypedExpression dtype (FieldVariable name)) -> let
owner = resolveNameChainOwner (TypedExpression dtype receiver)
in case owner of
(TypedExpression otype _) -> let
(constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype)
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
(constants_a, ops_a, _) = assembleExpression (constants_r, ops_r, lvars) expr
in
(constants_a, ops_a ++ [Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (PreIncrement (TypedExpression dtype receiver))) = let
target = resolveNameChain (TypedExpression dtype receiver)
in case target of
(TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars
expr = (TypedExpression dtype (LocalVariable name))
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in case localIndex of
Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opiadd, Opdup, Opistore (fromIntegral index)], lvars)
Nothing -> error("No such local variable found in local variable pool: " ++ name)
(TypedExpression dtype (FieldVariable name)) -> let
owner = resolveNameChainOwner (TypedExpression dtype receiver)
in case owner of
(TypedExpression otype _) -> let
(constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype)
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
in
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opiadd, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (PreDecrement (TypedExpression dtype receiver))) = let
target = resolveNameChain (TypedExpression dtype receiver)
in case target of
(TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars
expr = (TypedExpression dtype (LocalVariable name))
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in case localIndex of
Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opisub, Opdup, Opistore (fromIntegral index)], lvars)
Nothing -> error("No such local variable found in local variable pool: " ++ name)
(TypedExpression dtype (FieldVariable name)) -> let
owner = resolveNameChainOwner (TypedExpression dtype receiver)
in case owner of
(TypedExpression otype _) -> let
(constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype)
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
in
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opisub, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (PostIncrement (TypedExpression dtype receiver))) = let
target = resolveNameChain (TypedExpression dtype receiver)
in case target of
(TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars
expr = (TypedExpression dtype (LocalVariable name))
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in case localIndex of
Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opiadd, Opistore (fromIntegral index)], lvars)
Nothing -> error("No such local variable found in local variable pool: " ++ name)
(TypedExpression dtype (FieldVariable name)) -> let
owner = resolveNameChainOwner (TypedExpression dtype receiver)
in case owner of
(TypedExpression otype _) -> let
(constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype)
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
in
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opiadd, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (PostDecrement (TypedExpression dtype receiver))) = let
target = resolveNameChain (TypedExpression dtype receiver)
in case target of
(TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars
expr = (TypedExpression dtype (LocalVariable name))
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in case localIndex of
Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opisub, Opistore (fromIntegral index)], lvars)
Nothing -> error("No such local variable found in local variable pool: " ++ name)
(TypedExpression dtype (FieldVariable name)) -> let
owner = resolveNameChainOwner (TypedExpression dtype receiver)
in case owner of
(TypedExpression otype _) -> let
(constants_f, fieldIndex) = getFieldIndex constants (otype, name, datatypeDescriptor dtype)
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
in
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opisub, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression rtype (MethodCall (TypedExpression otype receiver) name params)) = let
(constants_r, ops_r, lvars_r) = assembleExpression (constants, ops, lvars) (TypedExpression otype receiver)
(constants_p, ops_p, lvars_p) = foldl assembleExpression (constants_r, ops_r, lvars_r) params
(constants_m, methodIndex) = getMethodIndex constants_p (otype, name, methodDescriptorFromParamlist params rtype)
in
(constants_m, ops_p ++ [Opinvokevirtual (fromIntegral methodIndex)], lvars_p)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression rtype (ConstructorCall name params)) = let
(constants_c, classIndex) = getClassIndex constants name
(constants_p, ops_p, lvars_p) = foldl assembleExpression (constants_c, ops ++ [Opnew (fromIntegral classIndex), Opdup], lvars) params
(constants_m, methodIndex) = getMethodIndex constants_p (name, "<init>", methodDescriptorFromParamlist params "void")
in
(constants_m, ops_p ++ [Opinvokespecial (fromIntegral methodIndex)], lvars_p)
assembleStatement :: Assembler Statement
assembleStatement (constants, ops, lvars) (TypedStatement stype (Return expr)) = case expr of
Nothing -> (constants, ops ++ [Opreturn], lvars)
Just expr -> let
(expr_constants, expr_ops, _) = assembleExpression (constants, ops, lvars) expr
in
(expr_constants, expr_ops ++ [returnOperation stype], lvars)
assembleStatement (constants, ops, lvars) (TypedStatement _ (Block statements)) =
foldl assembleStatement (constants, ops, lvars) statements
assembleStatement (constants, ops, lvars) (TypedStatement dtype (If expr if_stmt else_stmt)) = let
(constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr
(constants_ifa, ops_ifa, _) = assembleStatement (constants_cmp, [], lvars) if_stmt
(constants_elsea, ops_elsea, _) = case else_stmt of
Nothing -> (constants_ifa, [], lvars)
Just stmt -> assembleStatement (constants_ifa, [], lvars) stmt
-- +6 because we insert 2 gotos, one for if, one for else
if_length = sum (map opcodeEncodingLength ops_ifa)
-- +3 because we need to account for the goto in the if statement.
else_length = sum (map opcodeEncodingLength ops_elsea)
in case dtype of
"void" -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 6)] ++ ops_ifa ++ [Opgoto (else_length + 3)] ++ ops_elsea, lvars)
otherwise -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 3)] ++ ops_ifa ++ ops_elsea, lvars)
assembleStatement (constants, ops, lvars) (TypedStatement _ (While expr stmt)) = let
(constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr
(constants_stmta, ops_stmta, _) = assembleStatement (constants_cmp, [], lvars) stmt
-- +3 because we insert 2 gotos, one for the comparison, one for the goto back to the comparison
stmt_length = sum (map opcodeEncodingLength ops_stmta) + 6
entire_length = stmt_length + sum (map opcodeEncodingLength ops_cmp)
in
(constants_stmta, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq stmt_length] ++ ops_stmta ++ [Opgoto (-entire_length)], lvars)
assembleStatement (constants, ops, lvars) (TypedStatement _ (LocalVariableDeclaration (VariableDeclaration dtype name expr))) = let
isPrimitive = elem dtype ["char", "boolean", "int"]
(constants_init, ops_init, _) = case expr of
Just exp -> assembleExpression (constants, ops, lvars) exp
Nothing -> (constants, ops ++ if isPrimitive then [Opsipush 0] else [Opaconst_null], lvars)
localIndex = fromIntegral (length lvars)
storeLocal = if isPrimitive then [Opistore localIndex] else [Opastore localIndex]
in
(constants_init, ops_init ++ storeLocal, lvars ++ [name])
assembleStatement (constants, ops, lvars) (TypedStatement _ (StatementExpressionStatement expr)) = let
(constants_e, ops_e, lvars_e) = assembleStatementExpression (constants, ops, lvars) expr
in
(constants_e, ops_e ++ [Oppop], lvars_e)
assembleStatement _ stmt = error ("Not yet implemented: " ++ show stmt)
assembleMethod :: Assembler MethodDeclaration
assembleMethod (constants, ops, lvars) (MethodDeclaration returntype name _ (TypedStatement _ (Block statements)))
| name == "<init>" = let
(constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
init_ops = [Opaload 0, Opinvokespecial 2]
in
(constants_a, init_ops ++ ops_a ++ [Opreturn], lvars_a)
| otherwise = case returntype of
"void" -> let
(constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
in
(constants_a, ops_a ++ [Opreturn], lvars_a)
otherwise -> foldl assembleStatement (constants, ops, lvars) statements
assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Typed block expected for method body, got: " ++ show stmt)

117
src/ByteCode/Builder.hs Normal file
View File

@ -0,0 +1,117 @@
module ByteCode.Builder where
import ByteCode.Constants
import ByteCode.ClassFile (ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength)
import ByteCode.Assembler
import ByteCode.Util
import Ast
import Data.Char
import Data.List
import Data.Word
type ClassFileBuilder a = a -> ClassFile -> ClassFile
fieldBuilder :: ClassFileBuilder VariableDeclaration
fieldBuilder (VariableDeclaration datatype name _) input = let
baseIndex = 1 + length (constantPool input)
constants = [
FieldRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)),
NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)),
Utf8Info name,
Utf8Info (datatypeDescriptor datatype)
]
field = MemberInfo {
memberAccessFlags = accessPublic,
memberNameIndex = (fromIntegral (baseIndex + 2)),
memberDescriptorIndex = (fromIntegral (baseIndex + 3)),
memberAttributes = []
}
in
input {
constantPool = (constantPool input) ++ constants,
fields = (fields input) ++ [field]
}
methodBuilder :: ClassFileBuilder MethodDeclaration
methodBuilder (MethodDeclaration returntype name parameters statement) input = let
baseIndex = 1 + length (constantPool input)
constants = [
MethodRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)),
NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)),
Utf8Info name,
Utf8Info (methodDescriptor (MethodDeclaration returntype name parameters (Block [])))
]
method = MemberInfo {
memberAccessFlags = accessPublic,
memberNameIndex = (fromIntegral (baseIndex + 2)),
memberDescriptorIndex = (fromIntegral (baseIndex + 3)),
memberAttributes = []
}
in
input {
constantPool = (constantPool input) ++ constants,
methods = (methods input) ++ [method]
}
methodAssembler :: ClassFileBuilder MethodDeclaration
methodAssembler (MethodDeclaration returntype name parameters statement) input = let
methodConstantIndex = findMethodIndex input name
in case methodConstantIndex of
Nothing -> error ("Cannot find method entry in method pool for method: " ++ name)
Just index -> let
declaration = MethodDeclaration returntype name parameters statement
paramNames = "this" : [name | ParameterDeclaration _ name <- parameters]
in case (splitAt index (methods input)) of
(pre, []) -> input
(pre, method : post) -> let
(_, bytecode, _) = assembleMethod (constantPool input, [], paramNames) declaration
assembledMethod = method {
memberAttributes = [
CodeAttribute {
attributeMaxStack = 420,
attributeMaxLocals = 420,
attributeCode = bytecode
}
]
}
in
input {
methods = pre ++ (assembledMethod : post)
}
classBuilder :: ClassFileBuilder Class
classBuilder (Class name methods fields) _ = let
baseConstants = [
ClassInfo 4,
MethodRefInfo 1 3,
NameAndTypeInfo 5 6,
Utf8Info "java/lang/Object",
Utf8Info "<init>",
Utf8Info "()V",
Utf8Info "Code"
]
nameConstants = [ClassInfo 9, Utf8Info name]
nakedClassFile = ClassFile {
constantPool = baseConstants ++ nameConstants,
accessFlags = accessPublic,
thisClass = 8,
superClass = 1,
fields = [],
methods = [],
attributes = []
}
methodsWithInjectedConstructor = injectDefaultConstructor methods
methodsWithInjectedInitializers = injectFieldInitializers name fields methodsWithInjectedConstructor
classFileWithFields = foldr fieldBuilder nakedClassFile fields
classFileWithMethods = foldr methodBuilder classFileWithFields methodsWithInjectedInitializers
classFileWithAssembledMethods = foldr methodAssembler classFileWithMethods methodsWithInjectedInitializers
in
classFileWithAssembledMethods

View File

@ -1,7 +1,6 @@
module ByteCode.ByteUtil(unpackWord16, unpackWord32) where
module ByteCode.ByteUtil where
import Data.Word ( Word8, Word16, Word32 )
import Data.Int
import Data.Bits
unpackWord16 :: Word16 -> [Word8]

View File

@ -6,15 +6,16 @@ module ByteCode.ClassFile(
Operation(..),
serialize,
emptyClassFile,
opcodeEncodingLength
opcodeEncodingLength,
className
) where
import Data.Word
import Data.Int
import Data.ByteString (unpack)
import Data.ByteString.UTF8 (fromString)
import ByteCode.ByteUtil
import ByteCode.Constants
import ByteCode.ByteUtil
data ConstantInfo = ClassInfo Word16
| FieldRefInfo Word16 Word16
@ -28,11 +29,13 @@ data Operation = Opiadd
| Opisub
| Opimul
| Opidiv
| Opirem
| Opiand
| Opior
| Opixor
| Opineg
| Opdup
| Opnew Word16
| Opif_icmplt Word16
| Opif_icmple Word16
| Opif_icmpgt Word16
@ -43,7 +46,10 @@ data Operation = Opiadd
| Opreturn
| Opireturn
| Opareturn
| Opdup_x1
| Oppop
| Opinvokespecial Word16
| Opinvokevirtual Word16
| Opgoto Word16
| Opsipush Word16
| Opldc_w Word16
@ -91,16 +97,26 @@ emptyClassFile = ClassFile {
attributes = []
}
className :: ClassFile -> String
className classFile = let
classInfo = (constantPool classFile)!!(fromIntegral (thisClass classFile))
in case classInfo of
Utf8Info className -> className
otherwise -> error ("expected Utf8Info but got: " ++ show otherwise)
opcodeEncodingLength :: Operation -> Word16
opcodeEncodingLength Opiadd = 1
opcodeEncodingLength Opisub = 1
opcodeEncodingLength Opimul = 1
opcodeEncodingLength Opidiv = 1
opcodeEncodingLength Opirem = 1
opcodeEncodingLength Opiand = 1
opcodeEncodingLength Opior = 1
opcodeEncodingLength Opixor = 1
opcodeEncodingLength Opineg = 1
opcodeEncodingLength Opdup = 1
opcodeEncodingLength (Opnew _) = 3
opcodeEncodingLength (Opif_icmplt _) = 3
opcodeEncodingLength (Opif_icmple _) = 3
opcodeEncodingLength (Opif_icmpgt _) = 3
@ -111,7 +127,10 @@ opcodeEncodingLength Opaconst_null = 1
opcodeEncodingLength Opreturn = 1
opcodeEncodingLength Opireturn = 1
opcodeEncodingLength Opareturn = 1
opcodeEncodingLength Opdup_x1 = 1
opcodeEncodingLength Oppop = 1
opcodeEncodingLength (Opinvokespecial _) = 3
opcodeEncodingLength (Opinvokevirtual _) = 3
opcodeEncodingLength (Opgoto _) = 3
opcodeEncodingLength (Opsipush _) = 3
opcodeEncodingLength (Opldc_w _) = 3
@ -147,11 +166,13 @@ instance Serializable Operation where
serialize Opisub = [0x64]
serialize Opimul = [0x68]
serialize Opidiv = [0x6C]
serialize Opirem = [0x70]
serialize Opiand = [0x7E]
serialize Opior = [0x80]
serialize Opixor = [0x82]
serialize Opineg = [0x74]
serialize Opdup = [0x59]
serialize (Opnew index) = 0xBB : unpackWord16 index
serialize (Opif_icmplt branch) = 0xA1 : unpackWord16 branch
serialize (Opif_icmple branch) = 0xA4 : unpackWord16 branch
serialize (Opif_icmpgt branch) = 0xA3 : unpackWord16 branch
@ -162,7 +183,10 @@ instance Serializable Operation where
serialize Opreturn = [0xB1]
serialize Opireturn = [0xAC]
serialize Opareturn = [0xB0]
serialize Opdup_x1 = [0x5A]
serialize Oppop = [0x57]
serialize (Opinvokespecial index) = 0xB7 : unpackWord16 index
serialize (Opinvokevirtual index) = 0xB6 : unpackWord16 index
serialize (Opgoto index) = 0xA7 : unpackWord16 index
serialize (Opsipush index) = 0x11 : unpackWord16 index
serialize (Opldc_w index) = 0x13 : unpackWord16 index

View File

@ -1,228 +0,0 @@
module ByteCode.Generation.Assembler.ExpressionAndStatement where
import Ast
import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength)
import ByteCode.Generation.Generator
import Data.List
import Data.Char
import ByteCode.Generation.Builder.Field
assembleExpression :: Assembler Expression
assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b))
| elem op [Addition, Subtraction, Multiplication, Division, BitwiseAnd, BitwiseOr, BitwiseXor] = let
(aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
(bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
in
(bConstants, bOps ++ [binaryOperation op], lvars)
| elem op [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let
(aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
(bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
cmp_op = comparisonOperation op 9
cmp_ops = [cmp_op, Opsipush 0, Opgoto 6, Opsipush 1]
in
(bConstants, bOps ++ cmp_ops, lvars)
assembleExpression (constants, ops, lvars) (TypedExpression _ (CharacterLiteral literal)) =
(constants, ops ++ [Opsipush (fromIntegral (ord literal))], lvars)
assembleExpression (constants, ops, lvars) (TypedExpression _ (BooleanLiteral literal)) =
(constants, ops ++ [Opsipush (if literal then 1 else 0)], lvars)
assembleExpression (constants, ops, lvars) (TypedExpression _ (IntegerLiteral literal))
| literal <= 32767 && literal >= -32768 = (constants, ops ++ [Opsipush (fromIntegral literal)], lvars)
| otherwise = (constants ++ [IntegerInfo (fromIntegral literal)], ops ++ [Opldc_w (fromIntegral (1 + length constants))], lvars)
assembleExpression (constants, ops, lvars) (TypedExpression _ NullLiteral) =
(constants, ops ++ [Opaconst_null], lvars)
assembleExpression (constants, ops, lvars) (TypedExpression etype (UnaryOperation Not expr)) = let
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
newConstant = fromIntegral (1 + length exprConstants)
in case etype of
"int" -> (exprConstants ++ [IntegerInfo 0x7FFFFFFF], exprOps ++ [Opldc_w newConstant, Opixor], lvars)
"char" -> (exprConstants, exprOps ++ [Opsipush 0xFFFF, Opixor], lvars)
"boolean" -> (exprConstants, exprOps ++ [Opsipush 0x01, Opixor], lvars)
assembleExpression (constants, ops, lvars) (TypedExpression _ (UnaryOperation Minus expr)) = let
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in
(exprConstants, exprOps ++ [Opineg], lvars)
assembleExpression (constants, ops, lvars) (TypedExpression _ (FieldVariable name)) = let
fieldIndex = findFieldIndex constants name
in case fieldIndex of
Just index -> (constants, ops ++ [Opaload 0, Opgetfield (fromIntegral index)], lvars)
Nothing -> error ("No such field found in constant pool: " ++ name)
assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name)) = let
localIndex = findIndex ((==) name) lvars
isPrimitive = elem dtype ["char", "boolean", "int"]
in case localIndex of
Just index -> (constants, ops ++ if isPrimitive then [Opiload (fromIntegral index)] else [Opaload (fromIntegral index)], lvars)
Nothing -> error ("No such local variable found in local variable pool: " ++ name)
assembleExpression (constants, ops, lvars) (TypedExpression dtype (StatementExpressionExpression stmtexp)) =
assembleStatementExpression (constants, ops, lvars) stmtexp
assembleExpression _ expr = error ("unimplemented: " ++ show expr)
-- TODO untested
assembleStatementExpression :: Assembler StatementExpression
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (Assignment (TypedExpression dtype (LocalVariable name)) expr)) = let
localIndex = findIndex ((==) name) lvars
(constants_a, ops_a, _) = assembleExpression (constants, ops, lvars) expr
isPrimitive = elem dtype ["char", "boolean", "int"]
in case localIndex of
Just index -> (constants_a, ops_a ++ if isPrimitive then [Opistore (fromIntegral index)] else [Opastore (fromIntegral index)], lvars)
Nothing -> error ("No such local variable found in local variable pool: " ++ name)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (Assignment (TypedExpression dtype (FieldVariable name)) expr)) = let
fieldIndex = findFieldIndex constants name
(constants_a, ops_a, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr
in case fieldIndex of
Just index -> (constants_a, ops_a ++ [Opputfield (fromIntegral index)], lvars)
Nothing -> error ("No such field variable found in constant pool: " ++ name)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (PreIncrement (TypedExpression dtype (LocalVariable name)))) = let
localIndex = findIndex ((==) name) lvars
expr = (TypedExpression dtype (LocalVariable name))
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
incrOps = exprOps ++ [Opsipush 1, Opiadd, Opdup]
in case localIndex of
Just index -> (exprConstants, incrOps ++ [Opistore (fromIntegral index)], lvars)
Nothing -> error("No such local variable found in local variable pool: " ++ name)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (PostIncrement (TypedExpression dtype (LocalVariable name)))) = let
localIndex = findIndex ((==) name) lvars
expr = (TypedExpression dtype (LocalVariable name))
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
incrOps = exprOps ++ [Opdup, Opsipush 1, Opiadd]
in case localIndex of
Just index -> (exprConstants, incrOps ++ [Opistore (fromIntegral index)], lvars)
Nothing -> error("No such local variable found in local variable pool: " ++ name)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (PreDecrement (TypedExpression dtype (LocalVariable name)))) = let
localIndex = findIndex ((==) name) lvars
expr = (TypedExpression dtype (LocalVariable name))
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
incrOps = exprOps ++ [Opsipush 1, Opiadd, Opisub]
in case localIndex of
Just index -> (exprConstants, incrOps ++ [Opistore (fromIntegral index)], lvars)
Nothing -> error("No such local variable found in local variable pool: " ++ name)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (PostDecrement (TypedExpression dtype (LocalVariable name)))) = let
localIndex = findIndex ((==) name) lvars
expr = (TypedExpression dtype (LocalVariable name))
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
incrOps = exprOps ++ [Opdup, Opsipush 1, Opisub]
in case localIndex of
Just index -> (exprConstants, incrOps ++ [Opistore (fromIntegral index)], lvars)
Nothing -> error("No such local variable found in local variable pool: " ++ name)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (PreIncrement (TypedExpression dtype (FieldVariable name)))) = let
fieldIndex = findFieldIndex constants name
expr = (TypedExpression dtype (FieldVariable name))
(exprConstants, exprOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr
incrOps = exprOps ++ [Opsipush 1, Opiadd, Opdup]
in case fieldIndex of
Just index -> (exprConstants, incrOps ++ [Opputfield (fromIntegral index)], lvars)
Nothing -> error("No such field variable found in field variable pool: " ++ name)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (PostIncrement (TypedExpression dtype (FieldVariable name)))) = let
fieldIndex = findFieldIndex constants name
expr = (TypedExpression dtype (FieldVariable name))
(exprConstants, exprOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr
incrOps = exprOps ++ [Opdup, Opsipush 1, Opiadd]
in case fieldIndex of
Just index -> (exprConstants, incrOps ++ [Opputfield (fromIntegral index)], lvars)
Nothing -> error("No such field variable found in field variable pool: " ++ name)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (PreDecrement (TypedExpression dtype (FieldVariable name)))) = let
fieldIndex = findFieldIndex constants name
expr = (TypedExpression dtype (FieldVariable name))
(exprConstants, exprOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr
incrOps = exprOps ++ [Opsipush 1, Opiadd, Opisub]
in case fieldIndex of
Just index -> (exprConstants, incrOps ++ [Opputfield (fromIntegral index)], lvars)
Nothing -> error("No such field variable found in field variable pool: " ++ name)
assembleStatementExpression
(constants, ops, lvars)
(TypedStatementExpression _ (PostDecrement (TypedExpression dtype (FieldVariable name)))) = let
fieldIndex = findFieldIndex constants name
expr = (TypedExpression dtype (FieldVariable name))
(exprConstants, exprOps, _) = assembleExpression (constants, ops ++ [Opaload 0], lvars) expr
incrOps = exprOps ++ [Opdup, Opsipush 1, Opisub]
in case fieldIndex of
Just index -> (exprConstants, incrOps ++ [Opputfield (fromIntegral index)], lvars)
Nothing -> error("No such field variable found in field variable pool: " ++ name)
assembleStatement :: Assembler Statement
assembleStatement (constants, ops, lvars) (TypedStatement stype (Return expr)) = case expr of
Nothing -> (constants, ops ++ [Opreturn], lvars)
Just expr -> let
(expr_constants, expr_ops, _) = assembleExpression (constants, ops, lvars) expr
in
(expr_constants, expr_ops ++ [returnOperation stype], lvars)
assembleStatement (constants, ops, lvars) (TypedStatement _ (Block statements)) =
foldl assembleStatement (constants, ops, lvars) statements
assembleStatement (constants, ops, lvars) (TypedStatement _ (If expr if_stmt else_stmt)) = let
(constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr
(constants_ifa, ops_ifa, _) = assembleStatement (constants_cmp, [], lvars) if_stmt
(constants_elsea, ops_elsea, _) = case else_stmt of
Nothing -> (constants_ifa, [], lvars)
Just stmt -> assembleStatement (constants_ifa, [], lvars) stmt
-- +6 because we insert 2 gotos, one for if, one for else
if_length = sum (map opcodeEncodingLength ops_ifa) + 6
-- +3 because we need to account for the goto in the if statement.
else_length = sum (map opcodeEncodingLength ops_elsea) + 3
in
(constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq if_length] ++ ops_ifa ++ [Opgoto else_length] ++ ops_elsea, lvars)
assembleStatement (constants, ops, lvars) (TypedStatement _ (While expr stmt)) = let
(constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr
(constants_stmta, ops_stmta, _) = assembleStatement (constants_cmp, [], lvars) stmt
-- +3 because we insert 2 gotos, one for the comparison, one for the goto back to the comparison
stmt_length = sum (map opcodeEncodingLength ops_stmta) + 6
entire_length = stmt_length + sum (map opcodeEncodingLength ops_cmp)
in
(constants_stmta, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq stmt_length] ++ ops_stmta ++ [Opgoto (-entire_length)], lvars)
assembleStatement (constants, ops, lvars) (TypedStatement _ (LocalVariableDeclaration (VariableDeclaration dtype name expr))) = let
isPrimitive = elem dtype ["char", "boolean", "int"]
(constants_init, ops_init, _) = case expr of
Just exp -> assembleExpression (constants, ops, lvars) exp
Nothing -> (constants, ops ++ if isPrimitive then [Opsipush 0] else [Opaconst_null], lvars)
localIndex = fromIntegral (length lvars)
storeLocal = if isPrimitive then [Opistore localIndex] else [Opastore localIndex]
in
(constants_init, ops_init ++ storeLocal, lvars ++ [name])
assembleStatement (constants, ops, lvars) (TypedStatement _ (StatementExpressionStatement expr)) =
assembleStatementExpression (constants, ops, lvars) expr
assembleStatement _ stmt = error ("Not yet implemented: " ++ show stmt)

View File

@ -1,20 +0,0 @@
module ByteCode.Generation.Assembler.Method where
import Ast
import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength)
import ByteCode.Generation.Generator
import ByteCode.Generation.Assembler.ExpressionAndStatement
assembleMethod :: Assembler MethodDeclaration
assembleMethod (constants, ops, lvars) (MethodDeclaration _ name _ (TypedStatement _ (Block statements)))
| name == "<init>" = let
(constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
init_ops = [Opaload 0, Opinvokespecial 2]
in
(constants_a, init_ops ++ ops_a ++ [Opreturn], lvars_a)
| otherwise = let
(constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
init_ops = [Opaload 0]
in
(constants_a, init_ops ++ ops_a, lvars_a)
assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Block expected for method body, got: " ++ show stmt)

View File

@ -1,44 +0,0 @@
module ByteCode.Generation.Builder.Class where
import ByteCode.Generation.Builder.Field
import ByteCode.Generation.Builder.Method
import ByteCode.Generation.Generator
import Ast
import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength)
import ByteCode.Constants
injectDefaultConstructor :: [MethodDeclaration] -> [MethodDeclaration]
injectDefaultConstructor pre
| any (\(MethodDeclaration _ name _ _) -> name == "<init>") pre = pre
| otherwise = pre ++ [MethodDeclaration "void" "<init>" [] (TypedStatement "void" (Block []))]
classBuilder :: ClassFileBuilder Class
classBuilder (Class name methods fields) _ = let
baseConstants = [
ClassInfo 4,
MethodRefInfo 1 3,
NameAndTypeInfo 5 6,
Utf8Info "java/lang/Object",
Utf8Info "<init>",
Utf8Info "()V",
Utf8Info "Code"
]
nameConstants = [ClassInfo 9, Utf8Info name]
nakedClassFile = ClassFile {
constantPool = baseConstants ++ nameConstants,
accessFlags = accessPublic,
thisClass = 8,
superClass = 1,
fields = [],
methods = [],
attributes = []
}
methodsWithInjectedConstructor = injectDefaultConstructor methods
classFileWithFields = foldr fieldBuilder nakedClassFile fields
classFileWithMethods = foldr methodBuilder classFileWithFields methodsWithInjectedConstructor
classFileWithAssembledMethods = foldr methodAssembler classFileWithMethods methodsWithInjectedConstructor
in
classFileWithAssembledMethods

View File

@ -1,46 +0,0 @@
module ByteCode.Generation.Builder.Field where
import Ast
import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength)
import ByteCode.Generation.Generator
import ByteCode.Constants
import Data.List
findFieldIndex :: [ConstantInfo] -> String -> Maybe Int
findFieldIndex constants name = let
fieldRefNameInfos = [
-- we only skip one entry to get the name since the Java constant pool
-- is 1-indexed (why)
(index, constants!!(fromIntegral index + 1))
| (index, FieldRefInfo _ _) <- (zip [1..] constants)
]
fieldRefNames = map (\(index, nameInfo) -> case nameInfo of
Utf8Info fieldName -> (index, fieldName)
something_else -> error ("Expected UTF8Info but got" ++ show something_else))
fieldRefNameInfos
fieldIndex = find (\(index, fieldName) -> fieldName == name) fieldRefNames
in case fieldIndex of
Just (index, _) -> Just index
Nothing -> Nothing
fieldBuilder :: ClassFileBuilder VariableDeclaration
fieldBuilder (VariableDeclaration datatype name _) input = let
baseIndex = 1 + length (constantPool input)
constants = [
FieldRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)),
NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)),
Utf8Info name,
Utf8Info (datatypeDescriptor datatype)
]
field = MemberInfo {
memberAccessFlags = accessPublic,
memberNameIndex = (fromIntegral (baseIndex + 2)),
memberDescriptorIndex = (fromIntegral (baseIndex + 3)),
memberAttributes = []
}
in
input {
constantPool = (constantPool input) ++ constants,
fields = (fields input) ++ [field]
}

View File

@ -1,80 +0,0 @@
module ByteCode.Generation.Builder.Method where
import Ast
import ByteCode.ClassFile(ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength)
import ByteCode.Generation.Generator
import ByteCode.Generation.Assembler.Method
import ByteCode.Constants
import Data.List
methodDescriptor :: MethodDeclaration -> String
methodDescriptor (MethodDeclaration returntype _ parameters _) = let
parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters]
in
"("
++ (concat (map methodParameterDescriptor parameter_types))
++ ")"
++ methodParameterDescriptor returntype
methodParameterDescriptor :: String -> String
methodParameterDescriptor "void" = "V"
methodParameterDescriptor "int" = "I"
methodParameterDescriptor "char" = "C"
methodParameterDescriptor "boolean" = "B"
methodParameterDescriptor x = "L" ++ x ++ ";"
memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool
memberInfoIsMethod constants info = elem '(' (memberInfoDescriptor constants info)
findMethodIndex :: ClassFile -> String -> Maybe Int
findMethodIndex classFile name = let
constants = constantPool classFile
in
findIndex (\method -> ((memberInfoIsMethod constants method) && (memberInfoName constants method) == name)) (methods classFile)
methodBuilder :: ClassFileBuilder MethodDeclaration
methodBuilder (MethodDeclaration returntype name parameters statement) input = let
baseIndex = 1 + length (constantPool input)
constants = [
Utf8Info name,
Utf8Info (methodDescriptor (MethodDeclaration returntype name parameters (Block [])))
]
method = MemberInfo {
memberAccessFlags = accessPublic,
memberNameIndex = (fromIntegral baseIndex),
memberDescriptorIndex = (fromIntegral (baseIndex + 1)),
memberAttributes = []
}
in
input {
constantPool = (constantPool input) ++ constants,
methods = (methods input) ++ [method]
}
methodAssembler :: ClassFileBuilder MethodDeclaration
methodAssembler (MethodDeclaration returntype name parameters statement) input = let
methodConstantIndex = findMethodIndex input name
in case methodConstantIndex of
Nothing -> error ("Cannot find method entry in method pool for method: " ++ name)
Just index -> let
declaration = MethodDeclaration returntype name parameters statement
paramNames = "this" : [name | ParameterDeclaration _ name <- parameters]
(pre, method : post) = splitAt index (methods input)
(_, bytecode, _) = assembleMethod (constantPool input, [], paramNames) declaration
assembledMethod = method {
memberAttributes = [
CodeAttribute {
attributeMaxStack = 420,
attributeMaxLocals = 420,
attributeCode = bytecode
}
]
}
in
input {
methods = pre ++ (assembledMethod : post)
}

View File

@ -1,73 +0,0 @@
module ByteCode.Generation.Generator(
datatypeDescriptor,
memberInfoName,
memberInfoDescriptor,
returnOperation,
binaryOperation,
comparisonOperation,
ClassFileBuilder,
Assembler
) where
import ByteCode.Constants
import ByteCode.ClassFile (ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..), opcodeEncodingLength)
import Ast
import Data.Char
import Data.List
import Data.Word
type ClassFileBuilder a = a -> ClassFile -> ClassFile
type Assembler a = ([ConstantInfo], [Operation], [String]) -> a -> ([ConstantInfo], [Operation], [String])
datatypeDescriptor :: String -> String
datatypeDescriptor "void" = "V"
datatypeDescriptor "int" = "I"
datatypeDescriptor "char" = "C"
datatypeDescriptor "boolean" = "B"
datatypeDescriptor x = "L" ++ x
memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String
memberInfoDescriptor constants MemberInfo {
memberAccessFlags = _,
memberNameIndex = _,
memberDescriptorIndex = descriptorIndex,
memberAttributes = _ } = let
descriptor = constants!!((fromIntegral descriptorIndex) - 1)
in case descriptor of
Utf8Info descriptorText -> descriptorText
_ -> ("Invalid Item at Constant pool index " ++ show descriptorIndex)
memberInfoName :: [ConstantInfo] -> MemberInfo -> String
memberInfoName constants MemberInfo {
memberAccessFlags = _,
memberNameIndex = nameIndex,
memberDescriptorIndex = _,
memberAttributes = _ } = let
name = constants!!((fromIntegral nameIndex) - 1)
in case name of
Utf8Info nameText -> nameText
_ -> ("Invalid Item at Constant pool index " ++ show nameIndex)
returnOperation :: DataType -> Operation
returnOperation dtype
| elem dtype ["int", "char", "boolean"] = Opireturn
| otherwise = Opareturn
binaryOperation :: BinaryOperator -> Operation
binaryOperation Addition = Opiadd
binaryOperation Subtraction = Opisub
binaryOperation Multiplication = Opimul
binaryOperation Division = Opidiv
binaryOperation BitwiseAnd = Opiand
binaryOperation BitwiseOr = Opior
binaryOperation BitwiseXor = Opixor
comparisonOperation :: BinaryOperator -> Word16 -> Operation
comparisonOperation CompareEqual branchLocation = Opif_icmpeq branchLocation
comparisonOperation CompareNotEqual branchLocation = Opif_icmpne branchLocation
comparisonOperation CompareLessThan branchLocation = Opif_icmplt branchLocation
comparisonOperation CompareLessOrEqual branchLocation = Opif_icmple branchLocation
comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation
comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation

245
src/ByteCode/Util.hs Normal file
View File

@ -0,0 +1,245 @@
module ByteCode.Util where
import Data.Int
import Ast
import ByteCode.ClassFile
import Data.List
import Data.Maybe (mapMaybe)
import Data.Word (Word8, Word16, Word32)
-- walks the name resolution chain. returns the innermost Just LocalVariable/FieldVariable or Nothing.
resolveNameChain :: Expression -> Expression
resolveNameChain (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b
resolveNameChain (TypedExpression dtype (LocalVariable name)) = (TypedExpression dtype (LocalVariable name))
resolveNameChain (TypedExpression dtype (FieldVariable name)) = (TypedExpression dtype (FieldVariable name))
resolveNameChain invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression))
-- walks the name resolution chain. returns the second-to-last item of the namechain.
resolveNameChainOwner :: Expression -> Expression
resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a (TypedExpression dtype (FieldVariable name)))) = a
resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b
resolveNameChainOwner invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression))
methodDescriptor :: MethodDeclaration -> String
methodDescriptor (MethodDeclaration returntype _ parameters _) = let
parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters]
in
"("
++ (concat (map datatypeDescriptor parameter_types))
++ ")"
++ datatypeDescriptor returntype
methodDescriptorFromParamlist :: [Expression] -> String -> String
methodDescriptorFromParamlist parameters returntype = let
parameter_types = [datatype | TypedExpression datatype _ <- parameters]
in
"("
++ (concat (map datatypeDescriptor parameter_types))
++ ")"
++ datatypeDescriptor returntype
memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool
memberInfoIsMethod constants info = elem '(' (memberInfoDescriptor constants info)
datatypeDescriptor :: String -> String
datatypeDescriptor "void" = "V"
datatypeDescriptor "int" = "I"
datatypeDescriptor "char" = "C"
datatypeDescriptor "boolean" = "B"
datatypeDescriptor x = "L" ++ x ++ ";"
memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String
memberInfoDescriptor constants MemberInfo {
memberAccessFlags = _,
memberNameIndex = _,
memberDescriptorIndex = descriptorIndex,
memberAttributes = _ } = let
descriptor = constants!!((fromIntegral descriptorIndex) - 1)
in case descriptor of
Utf8Info descriptorText -> descriptorText
_ -> ("Invalid Item at Constant pool index " ++ show descriptorIndex)
memberInfoName :: [ConstantInfo] -> MemberInfo -> String
memberInfoName constants MemberInfo {
memberAccessFlags = _,
memberNameIndex = nameIndex,
memberDescriptorIndex = _,
memberAttributes = _ } = let
name = constants!!((fromIntegral nameIndex) - 1)
in case name of
Utf8Info nameText -> nameText
_ -> ("Invalid Item at Constant pool index " ++ show nameIndex)
returnOperation :: DataType -> Operation
returnOperation dtype
| elem dtype ["int", "char", "boolean"] = Opireturn
| otherwise = Opareturn
binaryOperation :: BinaryOperator -> Operation
binaryOperation Addition = Opiadd
binaryOperation Subtraction = Opisub
binaryOperation Multiplication = Opimul
binaryOperation Division = Opidiv
binaryOperation Modulo = Opirem
binaryOperation BitwiseAnd = Opiand
binaryOperation BitwiseOr = Opior
binaryOperation BitwiseXor = Opixor
binaryOperation And = Opiand
binaryOperation Or = Opior
comparisonOperation :: BinaryOperator -> Word16 -> Operation
comparisonOperation CompareEqual branchLocation = Opif_icmpeq branchLocation
comparisonOperation CompareNotEqual branchLocation = Opif_icmpne branchLocation
comparisonOperation CompareLessThan branchLocation = Opif_icmplt branchLocation
comparisonOperation CompareLessOrEqual branchLocation = Opif_icmple branchLocation
comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation
comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation
findFieldIndex :: [ConstantInfo] -> String -> Maybe Int
findFieldIndex constants name = let
fieldRefNameInfos = [
-- we only skip one entry to get the name since the Java constant pool
-- is 1-indexed (why)
(index, constants!!(fromIntegral index + 1))
| (index, FieldRefInfo classIndex _) <- (zip [1..] constants)
]
fieldRefNames = map (\(index, nameInfo) -> case nameInfo of
Utf8Info fieldName -> (index, fieldName)
something_else -> error ("Expected UTF8Info but got" ++ show something_else))
fieldRefNameInfos
fieldIndex = find (\(index, fieldName) -> fieldName == name) fieldRefNames
in case fieldIndex of
Just (index, _) -> Just index
Nothing -> Nothing
findMethodRefIndex :: [ConstantInfo] -> String -> Maybe Int
findMethodRefIndex constants name = let
methodRefNameInfos = [
-- we only skip one entry to get the name since the Java constant pool
-- is 1-indexed (why)
(index, constants!!(fromIntegral index + 1))
| (index, MethodRefInfo _ _) <- (zip [1..] constants)
]
methodRefNames = map (\(index, nameInfo) -> case nameInfo of
Utf8Info methodName -> (index, methodName)
something_else -> error ("Expected UTF8Info but got " ++ show something_else))
methodRefNameInfos
methodIndex = find (\(index, methodName) -> methodName == name) methodRefNames
in case methodIndex of
Just (index, _) -> Just index
Nothing -> Nothing
findMethodIndex :: ClassFile -> String -> Maybe Int
findMethodIndex classFile name = let
constants = constantPool classFile
in
findIndex (\method -> ((memberInfoIsMethod constants method) && (memberInfoName constants method) == name)) (methods classFile)
findClassIndex :: [ConstantInfo] -> String -> Maybe Int
findClassIndex constants name = let
classNameIndices = [(index, constants!!(fromIntegral nameIndex - 1)) | (index, ClassInfo nameIndex) <- (zip[1..] constants)]
classNames = map (\(index, nameInfo) -> case nameInfo of
Utf8Info className -> (index, className)
something_else -> error("Expected UTF8Info but got " ++ show something_else))
classNameIndices
desiredClassIndex = find (\(index, className) -> className == name) classNames
in case desiredClassIndex of
Just (index, _) -> Just index
Nothing -> Nothing
getKnownMembers :: [ConstantInfo] -> [(Int, (String, String, String))]
getKnownMembers constants = let
fieldsClassAndNT = [
(index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1))
| (index, FieldRefInfo classIndex nameTypeIndex) <- (zip [1..] constants)
] ++ [
(index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1))
| (index, MethodRefInfo classIndex nameTypeIndex) <- (zip [1..] constants)
]
fieldsClassNameType = map (\(index, nameInfo, nameTypeInfo) -> case (nameInfo, nameTypeInfo) of
(ClassInfo nameIndex, NameAndTypeInfo fnameIndex ftypeIndex) -> (index, (constants!!(fromIntegral nameIndex - 1), constants!!(fromIntegral fnameIndex - 1), constants!!(fromIntegral ftypeIndex - 1)))
something_else -> error ("Expected Class and NameType info, but got: " ++ show nameInfo ++ " and " ++ show nameTypeInfo))
fieldsClassAndNT
fieldsResolved = map (\(index, (nameInfo, fnameInfo, ftypeInfo)) -> case (nameInfo, fnameInfo, ftypeInfo) of
(Utf8Info cname, Utf8Info fname, Utf8Info ftype) -> (index, (cname, fname, ftype))
something_else -> error("Expected UTF8Infos but got " ++ show something_else))
fieldsClassNameType
in
fieldsResolved
-- same as findClassIndex, but inserts a new entry into constant pool if not existing
getClassIndex :: [ConstantInfo] -> String -> ([ConstantInfo], Int)
getClassIndex constants name = case findClassIndex constants name of
Just index -> (constants, index)
Nothing -> (constants ++ [ClassInfo (fromIntegral (length constants)), Utf8Info name], fromIntegral (length constants))
-- get the index for a field within a class, creating it if it does not exist.
getFieldIndex :: [ConstantInfo] -> (String, String, String) -> ([ConstantInfo], Int)
getFieldIndex constants (cname, fname, ftype) = case findMemberIndex constants (cname, fname, ftype) of
Just index -> (constants, index)
Nothing -> let
(constantsWithClass, classIndex) = getClassIndex constants cname
baseIndex = 1 + length constantsWithClass
in
(constantsWithClass ++ [
FieldRefInfo (fromIntegral classIndex) (fromIntegral (baseIndex + 1)),
NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)),
Utf8Info fname,
Utf8Info (datatypeDescriptor ftype)
], baseIndex)
getMethodIndex :: [ConstantInfo] -> (String, String, String) -> ([ConstantInfo], Int)
getMethodIndex constants (cname, mname, mtype) = case findMemberIndex constants (cname, mname, mtype) of
Just index -> (constants, index)
Nothing -> let
(constantsWithClass, classIndex) = getClassIndex constants cname
baseIndex = 1 + length constantsWithClass
in
(constantsWithClass ++ [
MethodRefInfo (fromIntegral classIndex) (fromIntegral (baseIndex + 1)),
NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)),
Utf8Info mname,
Utf8Info mtype
], baseIndex)
findMemberIndex :: [ConstantInfo] -> (String, String, String) -> Maybe Int
findMemberIndex constants (cname, fname, ftype) = let
allMembers = getKnownMembers constants
desiredMember = find (\(index, (c, f, ft)) -> (c, f, ft) == (cname, fname, ftype)) allMembers
in
fmap (\(index, _) -> index) desiredMember
injectDefaultConstructor :: [MethodDeclaration] -> [MethodDeclaration]
injectDefaultConstructor pre
| any (\(MethodDeclaration _ name _ _) -> name == "<init>") pre = pre
| otherwise = pre ++ [MethodDeclaration "void" "<init>" [] (TypedStatement "void" (Block []))]
injectFieldInitializers :: String -> [VariableDeclaration] -> [MethodDeclaration] -> [MethodDeclaration]
injectFieldInitializers classname vars pre = let
initializers = mapMaybe (\(variable) -> case variable of
VariableDeclaration dtype name (Just initializer) -> Just (
TypedStatement dtype (
StatementExpressionStatement (
TypedStatementExpression dtype (
Assignment
(TypedExpression dtype (BinaryOperation NameResolution (TypedExpression classname (LocalVariable "this")) (TypedExpression dtype (FieldVariable name))))
initializer
)
)
)
)
otherwise -> Nothing
) vars
in
map (\(method) -> case method of
MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block statements)) -> MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block (initializers ++ statements)))
otherwise -> method
) pre

View File

@ -4,17 +4,30 @@ import Example
import Typecheck
import Parser.Lexer (alexScanTokens)
import Parser.JavaParser
import ByteCode.Generation.Generator
import ByteCode.Generation.Builder.Class
import ByteCode.Builder
import ByteCode.ClassFile
import Data.ByteString (pack, writeFile)
import System.Environment
import System.FilePath.Posix (takeDirectory)
main = do
file <- readFile "Testklasse.java"
args <- getArgs
let filename = if null args
then error "Missing filename, I need to know what to compile"
else args!!0
let outputDirectory = takeDirectory filename
print ("Compiling " ++ filename)
file <- readFile filename
let untypedAST = parse $ alexScanTokens file
let typedAST = head (typeCheckCompilationUnit untypedAST)
let abstractClassFile = classBuilder typedAST emptyClassFile
let assembledClassFile = pack (serialize abstractClassFile)
let typedAST = (typeCheckCompilationUnit untypedAST)
let assembledClasses = map (\(typedClass) -> classBuilder typedClass emptyClassFile) typedAST
mapM_ (\(classFile) -> let
fileContent = pack (serialize classFile)
fileName = outputDirectory ++ "/" ++ (className classFile) ++ ".class"
in Data.ByteString.writeFile fileName fileContent
) assembledClasses
Data.ByteString.writeFile "Testklasse.class" assembledClassFile