bytecode #6

Merged
mrab merged 11 commits from bytecode into master 2024-06-21 07:06:07 +00:00
11 changed files with 305 additions and 204 deletions

View File

@ -1,7 +1,7 @@
// compile all test files using: // compile all test files using:
// ls Test/JavaSources/*.java | grep -v ".*Main.java" | xargs -I {} cabal run compiler {} // ls Test/JavaSources/*.java | grep -v ".*Main.java" | xargs -I {} cabal run compiler {}
// compile (in project root) using: // compile (in project root) using:
// javac -g:none -sourcepath Test/JavaSources/ Test/JavaSources/Main.java // pushd Test/JavaSources; javac -g:none Main.java; popd
// afterwards, run using // afterwards, run using
// java -ea -cp Test/JavaSources/ Main // java -ea -cp Test/JavaSources/ Main
@ -11,6 +11,7 @@ public class Main {
TestEmpty empty = new TestEmpty(); TestEmpty empty = new TestEmpty();
TestFields fields = new TestFields(); TestFields fields = new TestFields();
TestConstructor constructor = new TestConstructor(42); TestConstructor constructor = new TestConstructor(42);
TestArithmetic arithmetic = new TestArithmetic();
TestMultipleClasses multipleClasses = new TestMultipleClasses(); TestMultipleClasses multipleClasses = new TestMultipleClasses();
TestRecursion recursion = new TestRecursion(10); TestRecursion recursion = new TestRecursion(10);
TestMalicious malicious = new TestMalicious(); TestMalicious malicious = new TestMalicious();
@ -21,6 +22,10 @@ public class Main {
assert fields.a == 0 && fields.b == 42; assert fields.a == 0 && fields.b == 42;
// constructor parameters override initializers // constructor parameters override initializers
assert constructor.a == 42; assert constructor.a == 42;
// basic arithmetics
assert arithmetic.basic(1, 2, 3) == 2;
// we have boolean logic as well
assert arithmetic.logic(false, false, true) == true;
// multiple classes within one file work. Referencing another classes fields/methods works. // multiple classes within one file work. Referencing another classes fields/methods works.
assert multipleClasses.a.a == 42; assert multipleClasses.a.a == 42;
// self-referencing classes work. // self-referencing classes work.

View File

@ -0,0 +1,11 @@
public class TestArithmetic {
public int basic(int a, int b, int c)
{
return a + b - c * a / b % c;
}
public boolean logic(boolean a, boolean b, boolean c)
{
return !a && (c || b);
}
}

View File

@ -24,4 +24,11 @@ public class TestRecursion {
return fibonacci(n - 1) + this.fibonacci(n - 2); return fibonacci(n - 1) + this.fibonacci(n - 2);
} }
} }
public int ackermann(int m, int n)
{
if (m == 0) return n + 1;
if (n == 0) return ackermann(m - 1, 1);
return ackermann(m - 1, ackermann(m, n - 1));
}
} }

View File

@ -4,20 +4,7 @@ Die Bytecodegenerierung ist letztendlich eine zweistufige Transformation:
`Getypter AST -> [ClassFile] -> [[Word8]]` `Getypter AST -> [ClassFile] -> [[Word8]]`
Vom AST, der bereits den Typcheck durchlaufen hat, wird zunächst eine Abbildung in die einzelnen ClassFiles vorgenommen. Diese ClassFiles werden anschließend in deren Byte-Repräsentation serialisiert. Vom AST, der bereits den Typcheck durchlaufen hat, wird zunächst eine Abbildung in die einzelnen ClassFiles vorgenommen. Diese ClassFiles werden anschließend in deren Byte-Repräsentation serialisiert. Dieser Teil der Aufgabenstellung wurde gemeinsam von Christian Brier und Matthias Raba umgesetzt.
## Serialisierung
Damit Bytecode generiert werden kann, braucht es Strukturen, die die Daten halten, die letztendlich serialisiert werden. Die JVM erwartet den kompilierten Code in handliche Pakete verpackt. Die Struktur dieser Pakete ist [so definiert](https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-4.html).
Jede Struktur, die in dieser übergreifenden Class File vorkommt, haben wir in Haskell abgebildet. Es gibt z.B die Struktur "ClassFile", die wiederum weitere Strukturen wie z.B Informationen über Felder oder Methoden der Klasse. Alle diese Strukturen implementieren folgendes TypeClass:
```
class Serializable a where
serialize :: a -> [Word8]
```
Die Struktur ClassFile ruft für deren Kinder rekursiv diese `serialize` Funktion auf. Am Ende bleibt eine flache Word8-Liste übrig, die Serialisierung ist damit abgeschlossen.
## Codegenerierung ## Codegenerierung
@ -32,7 +19,6 @@ Die Idee hinter beiden ist, dass sie jeweils zwei Inputs haben, wobei der Rückg
Der Nutzer ruft beispielsweise die Funktion `classBuilder` auf. Diese wendet nach und nach folgende Transformationen an: Der Nutzer ruft beispielsweise die Funktion `classBuilder` auf. Diese wendet nach und nach folgende Transformationen an:
``` ```
methodsWithInjectedConstructor = injectDefaultConstructor methods methodsWithInjectedConstructor = injectDefaultConstructor methods
methodsWithInjectedInitializers = injectFieldInitializers name fields methodsWithInjectedConstructor methodsWithInjectedInitializers = injectFieldInitializers name fields methodsWithInjectedConstructor
@ -47,7 +33,7 @@ Zuerst wird (falls notwendig) ein leerer Defaultkonstruktor in die Classfile ein
2. Hinzufügen aller Methoden (nur Prototypen) 2. Hinzufügen aller Methoden (nur Prototypen)
3. Hinzufügen des Bytecodes in allen Methoden 3. Hinzufügen des Bytecodes in allen Methoden
Die Unterteilung von Schritt 2 und 3 ist deswegen notwendig, weil der Code einer Methode auch eine andere, erst nachher deklarierte Methode aufrufen kann. Nach Schritt 2 sind alle Methoden der Klasse bekannt. Wie beschrieben wird auch hier der Zustand über alle Faltungen mitgenommen. Jeder Schritt hat Zugriff auf alle Daten, die aus dem vorherigen Schritt bleiben. Sukkzessive wird eine korrekte ClassFile aufgebaut. Die Unterteilung von Schritt 2 und 3 ist deswegen notwendig, weil der Code einer Methode auch eine andere, erst nachher deklarierte Methode aufrufen kann. Nach Schritt 2 sind alle Methoden der Klasse bekannt. Wie beschrieben wird auch hier der Zustand über alle Faltungen mitgenommen. Jeder Schritt hat Zugriff auf alle Daten, die aus dem vorherigen Schritt bleiben. Sukzessive wird eine korrekte ClassFile aufgebaut.
Besonders interessant ist hierbei Schritt 3. Dort wird das Verhalten jeder einzelnen Methode in Bytecode übersetzt. In diesem Schritt werden zusätzlich zu den `Buildern` noch die `Assembler` verwendet (Definition siehe oben.) Die Assembler funktionieren ähnlich wie die Builder, arbeiten allerdings nicht auf einer ClassFile, sondern auf dem Inhalt einer Methode: Sie verarbeiten jeweils ein Tupel: Besonders interessant ist hierbei Schritt 3. Dort wird das Verhalten jeder einzelnen Methode in Bytecode übersetzt. In diesem Schritt werden zusätzlich zu den `Buildern` noch die `Assembler` verwendet (Definition siehe oben.) Die Assembler funktionieren ähnlich wie die Builder, arbeiten allerdings nicht auf einer ClassFile, sondern auf dem Inhalt einer Methode: Sie verarbeiten jeweils ein Tupel:
@ -67,3 +53,27 @@ assembleExpression (constants, ops, lvars) (TypedExpression _ NullLiteral) =
Hier werden die Konstanten und lokalen Variablen des Inputs nicht berührt, dem Bytecode wird lediglich die Operation `aconst_null` hinzugefügt. Damit ist das Verhalten des gematchten Inputs - eines Nullliterals - abgebildet. Hier werden die Konstanten und lokalen Variablen des Inputs nicht berührt, dem Bytecode wird lediglich die Operation `aconst_null` hinzugefügt. Damit ist das Verhalten des gematchten Inputs - eines Nullliterals - abgebildet.
Die Assembler rufen sich teilweise rekursiv selbst auf, da ja auch der AST verschachteltes Verhalten abbilden kann. Der Startpunkt für die Assembly einer Methode ist der Builder `methodAssembler`. Dieser entspricht Schritt 3 in der obigen Übersicht. Die Assembler rufen sich teilweise rekursiv selbst auf, da ja auch der AST verschachteltes Verhalten abbilden kann. Der Startpunkt für die Assembly einer Methode ist der Builder `methodAssembler`. Dieser entspricht Schritt 3 in der obigen Übersicht.
## Serialisierung
Damit Bytecode generiert werden kann, braucht es Strukturen, die die Daten halten, die letztendlich serialisiert werden. Die JVM erwartet den kompilierten Code in handliche Pakete verpackt. Die Struktur dieser Pakete ist [so definiert](https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-4.html).
Jede Struktur, die in dieser übergreifenden Class File vorkommt, haben wir in Haskell abgebildet. Es gibt z.B die Struktur "ClassFile", die wiederum weitere Strukturen wie z.B Informationen über Felder oder Methoden der Klasse beinhaltet. Alle diese Strukturen implementieren folgende TypeClass:
```
class Serializable a where
serialize :: a -> [Word8]
```
Hier ist ein Beispiel anhand der Serialisierung der einzelnen Operationen:
```
instance Serializable Operation where
serialize Opiadd = [0x60]
serialize Opisub = [0x64]
serialize Opimul = [0x68]
...
serialize (Opgetfield index) = 0xB4 : unpackWord16 index
```
Die Struktur ClassFile ruft für deren Kinder rekursiv diese `serialize` Funktion auf und konkateniert die Ergebnisse. Am Ende bleibt eine flache Word8-Liste übrig, die Serialisierung ist damit abgeschlossen. Da der Typecheck sicherstellt, dass alle referenzierten Methoden/Felder gültig sind, kann die Übersetzung der einzelnen Klassen voneinander unabhängig geschehen.

View File

@ -12,12 +12,12 @@ type Assembler a = ([ConstantInfo], [Operation], [String]) -> a -> ([ConstantInf
assembleExpression :: Assembler Expression assembleExpression :: Assembler Expression
assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b)) assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b))
| elem op [Addition, Subtraction, Multiplication, Division, Modulo, BitwiseAnd, BitwiseOr, BitwiseXor, And, Or] = let | op `elem` [Addition, Subtraction, Multiplication, Division, Modulo, BitwiseAnd, BitwiseOr, BitwiseXor, And, Or] = let
(aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
(bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
in in
(bConstants, bOps ++ [binaryOperation op], lvars) (bConstants, bOps ++ [binaryOperation op], lvars)
| elem op [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let | op `elem` [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let
(aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a (aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
(bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b (bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
cmp_op = comparisonOperation op 9 cmp_op = comparisonOperation op 9
@ -60,7 +60,7 @@ assembleExpression (constants, ops, lvars) (TypedExpression _ (UnaryOperation Mi
assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name)) assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name))
| name == "this" = (constants, ops ++ [Opaload 0], lvars) | name == "this" = (constants, ops ++ [Opaload 0], lvars)
| otherwise = let | otherwise = let
localIndex = findIndex ((==) name) lvars localIndex = elemIndex name lvars
isPrimitive = elem dtype ["char", "boolean", "int"] isPrimitive = elem dtype ["char", "boolean", "int"]
in case localIndex of in case localIndex of
Just index -> (constants, ops ++ if isPrimitive then [Opiload (fromIntegral index)] else [Opaload (fromIntegral index)], lvars) Just index -> (constants, ops ++ if isPrimitive then [Opiload (fromIntegral index)] else [Opaload (fromIntegral index)], lvars)
@ -69,7 +69,7 @@ assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable
assembleExpression (constants, ops, lvars) (TypedExpression dtype (StatementExpressionExpression stmtexp)) = assembleExpression (constants, ops, lvars) (TypedExpression dtype (StatementExpressionExpression stmtexp)) =
assembleStatementExpression (constants, ops, lvars) stmtexp assembleStatementExpression (constants, ops, lvars) stmtexp
assembleExpression _ expr = error ("unimplemented: " ++ show expr) assembleExpression _ expr = error ("Unknown expression: " ++ show expr)
assembleNameChain :: Assembler Expression assembleNameChain :: Assembler Expression
assembleNameChain input (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression _ (FieldVariable _)))) = assembleNameChain input (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression _ (FieldVariable _)))) =
@ -84,7 +84,7 @@ assembleStatementExpression
target = resolveNameChain (TypedExpression dtype receiver) target = resolveNameChain (TypedExpression dtype receiver)
in case target of in case target of
(TypedExpression dtype (LocalVariable name)) -> let (TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars localIndex = elemIndex name lvars
(constants_a, ops_a, _) = assembleExpression (constants, ops, lvars) expr (constants_a, ops_a, _) = assembleExpression (constants, ops, lvars) expr
isPrimitive = elem dtype ["char", "boolean", "int"] isPrimitive = elem dtype ["char", "boolean", "int"]
in case localIndex of in case localIndex of
@ -99,7 +99,7 @@ assembleStatementExpression
(constants_a, ops_a, _) = assembleExpression (constants_r, ops_r, lvars) expr (constants_a, ops_a, _) = assembleExpression (constants_r, ops_r, lvars) expr
in in
(constants_a, ops_a ++ [Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) (constants_a, ops_a ++ [Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else) something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression assembleStatementExpression
(constants, ops, lvars) (constants, ops, lvars)
@ -107,8 +107,8 @@ assembleStatementExpression
target = resolveNameChain (TypedExpression dtype receiver) target = resolveNameChain (TypedExpression dtype receiver)
in case target of in case target of
(TypedExpression dtype (LocalVariable name)) -> let (TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars localIndex = elemIndex name lvars
expr = (TypedExpression dtype (LocalVariable name)) expr = TypedExpression dtype (LocalVariable name)
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in case localIndex of in case localIndex of
Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opiadd, Opdup, Opistore (fromIntegral index)], lvars) Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opiadd, Opdup, Opistore (fromIntegral index)], lvars)
@ -121,7 +121,7 @@ assembleStatementExpression
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
in in
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opiadd, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opiadd, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else) something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression assembleStatementExpression
(constants, ops, lvars) (constants, ops, lvars)
@ -129,8 +129,8 @@ assembleStatementExpression
target = resolveNameChain (TypedExpression dtype receiver) target = resolveNameChain (TypedExpression dtype receiver)
in case target of in case target of
(TypedExpression dtype (LocalVariable name)) -> let (TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars localIndex = elemIndex name lvars
expr = (TypedExpression dtype (LocalVariable name)) expr = TypedExpression dtype (LocalVariable name)
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in case localIndex of in case localIndex of
Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opisub, Opdup, Opistore (fromIntegral index)], lvars) Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opisub, Opdup, Opistore (fromIntegral index)], lvars)
@ -143,7 +143,7 @@ assembleStatementExpression
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
in in
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opisub, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars) (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opisub, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else) something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression assembleStatementExpression
(constants, ops, lvars) (constants, ops, lvars)
@ -151,8 +151,8 @@ assembleStatementExpression
target = resolveNameChain (TypedExpression dtype receiver) target = resolveNameChain (TypedExpression dtype receiver)
in case target of in case target of
(TypedExpression dtype (LocalVariable name)) -> let (TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars localIndex = elemIndex name lvars
expr = (TypedExpression dtype (LocalVariable name)) expr = TypedExpression dtype (LocalVariable name)
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in case localIndex of in case localIndex of
Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opiadd, Opistore (fromIntegral index)], lvars) Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opiadd, Opistore (fromIntegral index)], lvars)
@ -165,7 +165,7 @@ assembleStatementExpression
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
in in
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opiadd, Opputfield (fromIntegral fieldIndex)], lvars) (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opiadd, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else) something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression assembleStatementExpression
(constants, ops, lvars) (constants, ops, lvars)
@ -173,8 +173,8 @@ assembleStatementExpression
target = resolveNameChain (TypedExpression dtype receiver) target = resolveNameChain (TypedExpression dtype receiver)
in case target of in case target of
(TypedExpression dtype (LocalVariable name)) -> let (TypedExpression dtype (LocalVariable name)) -> let
localIndex = findIndex ((==) name) lvars localIndex = elemIndex name lvars
expr = (TypedExpression dtype (LocalVariable name)) expr = TypedExpression dtype (LocalVariable name)
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr (exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
in case localIndex of in case localIndex of
Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opisub, Opistore (fromIntegral index)], lvars) Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opisub, Opistore (fromIntegral index)], lvars)
@ -187,7 +187,7 @@ assembleStatementExpression
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver) (constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
in in
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opisub, Opputfield (fromIntegral fieldIndex)], lvars) (constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opisub, Opputfield (fromIntegral fieldIndex)], lvars)
something_else -> error ("expected TypedExpression, but got: " ++ show something_else) something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
assembleStatementExpression assembleStatementExpression
(constants, ops, lvars) (constants, ops, lvars)
@ -231,7 +231,7 @@ assembleStatement (constants, ops, lvars) (TypedStatement dtype (If expr if_stmt
else_length = sum (map opcodeEncodingLength ops_elsea) else_length = sum (map opcodeEncodingLength ops_elsea)
in case dtype of in case dtype of
"void" -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 6)] ++ ops_ifa ++ [Opgoto (else_length + 3)] ++ ops_elsea, lvars) "void" -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 6)] ++ ops_ifa ++ [Opgoto (else_length + 3)] ++ ops_elsea, lvars)
otherwise -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 3)] ++ ops_ifa ++ ops_elsea, lvars) _ -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 3)] ++ ops_ifa ++ ops_elsea, lvars)
assembleStatement (constants, ops, lvars) (TypedStatement _ (While expr stmt)) = let assembleStatement (constants, ops, lvars) (TypedStatement _ (While expr stmt)) = let
(constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr (constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr
@ -257,20 +257,19 @@ assembleStatement (constants, ops, lvars) (TypedStatement _ (StatementExpression
in in
(constants_e, ops_e ++ [Oppop], lvars_e) (constants_e, ops_e ++ [Oppop], lvars_e)
assembleStatement _ stmt = error ("Not yet implemented: " ++ show stmt) assembleStatement _ stmt = error ("Unknown statement: " ++ show stmt)
assembleMethod :: Assembler MethodDeclaration assembleMethod :: Assembler MethodDeclaration
assembleMethod (constants, ops, lvars) (MethodDeclaration returntype name _ (TypedStatement _ (Block statements))) assembleMethod (constants, ops, lvars) (MethodDeclaration returntype name _ (TypedStatement _ (Block statements)))
| name == "<init>" = let | name == "<init>" = let
(constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
init_ops = [Opaload 0, Opinvokespecial 2]
in in
(constants_a, init_ops ++ ops_a ++ [Opreturn], lvars_a) (constants_a, [Opaload 0, Opinvokespecial 2] ++ ops_a ++ [Opreturn], lvars_a)
| otherwise = case returntype of | otherwise = case returntype of
"void" -> let "void" -> let
(constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements (constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
in in
(constants_a, ops_a ++ [Opreturn], lvars_a) (constants_a, ops_a ++ [Opreturn], lvars_a)
otherwise -> foldl assembleStatement (constants, ops, lvars) statements _ -> foldl assembleStatement (constants, ops, lvars) statements
assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Typed block expected for method body, got: " ++ show stmt) assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Typed block expected for method body, got: " ++ show stmt)

View File

@ -22,14 +22,14 @@ fieldBuilder (VariableDeclaration datatype name _) input = let
] ]
field = MemberInfo { field = MemberInfo {
memberAccessFlags = accessPublic, memberAccessFlags = accessPublic,
memberNameIndex = (fromIntegral (baseIndex + 2)), memberNameIndex = fromIntegral (baseIndex + 2),
memberDescriptorIndex = (fromIntegral (baseIndex + 3)), memberDescriptorIndex = fromIntegral (baseIndex + 3),
memberAttributes = [] memberAttributes = []
} }
in in
input { input {
constantPool = (constantPool input) ++ constants, constantPool = constantPool input ++ constants,
fields = (fields input) ++ [field] fields = fields input ++ [field]
} }
@ -46,18 +46,17 @@ methodBuilder (MethodDeclaration returntype name parameters statement) input = l
method = MemberInfo { method = MemberInfo {
memberAccessFlags = accessPublic, memberAccessFlags = accessPublic,
memberNameIndex = (fromIntegral (baseIndex + 2)), memberNameIndex = fromIntegral (baseIndex + 2),
memberDescriptorIndex = (fromIntegral (baseIndex + 3)), memberDescriptorIndex = fromIntegral (baseIndex + 3),
memberAttributes = [] memberAttributes = []
} }
in in
input { input {
constantPool = (constantPool input) ++ constants, constantPool = constantPool input ++ constants,
methods = (methods input) ++ [method] methods = methods input ++ [method]
} }
methodAssembler :: ClassFileBuilder MethodDeclaration methodAssembler :: ClassFileBuilder MethodDeclaration
methodAssembler (MethodDeclaration returntype name parameters statement) input = let methodAssembler (MethodDeclaration returntype name parameters statement) input = let
methodConstantIndex = findMethodIndex input name methodConstantIndex = findMethodIndex input name
@ -66,21 +65,22 @@ methodAssembler (MethodDeclaration returntype name parameters statement) input =
Just index -> let Just index -> let
declaration = MethodDeclaration returntype name parameters statement declaration = MethodDeclaration returntype name parameters statement
paramNames = "this" : [name | ParameterDeclaration _ name <- parameters] paramNames = "this" : [name | ParameterDeclaration _ name <- parameters]
in case (splitAt index (methods input)) of in case splitAt index (methods input) of
(pre, []) -> input (pre, []) -> input
(pre, method : post) -> let (pre, method : post) -> let
(_, bytecode, _) = assembleMethod (constantPool input, [], paramNames) declaration (constants, bytecode, aParamNames) = assembleMethod (constantPool input, [], paramNames) declaration
assembledMethod = method { assembledMethod = method {
memberAttributes = [ memberAttributes = [
CodeAttribute { CodeAttribute {
attributeMaxStack = 420, attributeMaxStack = fromIntegral $ maxStackDepth constants bytecode,
attributeMaxLocals = 420, attributeMaxLocals = fromIntegral $ length aParamNames,
attributeCode = bytecode attributeCode = bytecode
} }
] ]
} }
in in
input { input {
constantPool = constants,
methods = pre ++ (assembledMethod : post) methods = pre ++ (assembledMethod : post)
} }
@ -94,11 +94,12 @@ classBuilder (Class name methods fields) _ = let
Utf8Info "java/lang/Object", Utf8Info "java/lang/Object",
Utf8Info "<init>", Utf8Info "<init>",
Utf8Info "()V", Utf8Info "()V",
Utf8Info "Code" Utf8Info "Code",
ClassInfo 9,
Utf8Info name
] ]
nameConstants = [ClassInfo 9, Utf8Info name]
nakedClassFile = ClassFile { nakedClassFile = ClassFile {
constantPool = baseConstants ++ nameConstants, constantPool = baseConstants,
accessFlags = accessPublic, accessFlags = accessPublic,
thisClass = 8, thisClass = 8,
superClass = 1, superClass = 1,
@ -107,9 +108,13 @@ classBuilder (Class name methods fields) _ = let
attributes = [] attributes = []
} }
-- if a class has no constructor, inject an empty one.
methodsWithInjectedConstructor = injectDefaultConstructor methods methodsWithInjectedConstructor = injectDefaultConstructor methods
-- for every constructor, prepend all initialization assignments for fields.
methodsWithInjectedInitializers = injectFieldInitializers name fields methodsWithInjectedConstructor methodsWithInjectedInitializers = injectFieldInitializers name fields methodsWithInjectedConstructor
-- add fields, then method bodies to the classfile. After all referable names are known,
-- assemble the methods into bytecode.
classFileWithFields = foldr fieldBuilder nakedClassFile fields classFileWithFields = foldr fieldBuilder nakedClassFile fields
classFileWithMethods = foldr methodBuilder classFileWithFields methodsWithInjectedInitializers classFileWithMethods = foldr methodBuilder classFileWithFields methodsWithInjectedInitializers
classFileWithAssembledMethods = foldr methodAssembler classFileWithMethods methodsWithInjectedInitializers classFileWithAssembledMethods = foldr methodAssembler classFileWithMethods methodsWithInjectedInitializers

View File

@ -1,14 +1,4 @@
module ByteCode.ClassFile( module ByteCode.ClassFile where
ConstantInfo(..),
Attribute(..),
MemberInfo(..),
ClassFile(..),
Operation(..),
serialize,
emptyClassFile,
opcodeEncodingLength,
className
) where
import Data.Word import Data.Word
import Data.Int import Data.Int
@ -99,10 +89,10 @@ emptyClassFile = ClassFile {
className :: ClassFile -> String className :: ClassFile -> String
className classFile = let className classFile = let
classInfo = (constantPool classFile)!!(fromIntegral (thisClass classFile)) classInfo = constantPool classFile !! fromIntegral (thisClass classFile)
in case classInfo of in case classInfo of
Utf8Info className -> className Utf8Info className -> className
otherwise -> error ("expected Utf8Info but got: " ++ show otherwise) unexpected_element -> error ("expected Utf8Info but got: " ++ show unexpected_element)
opcodeEncodingLength :: Operation -> Word16 opcodeEncodingLength :: Operation -> Word16
@ -201,10 +191,10 @@ instance Serializable Attribute where
serialize (CodeAttribute { attributeMaxStack = maxStack, serialize (CodeAttribute { attributeMaxStack = maxStack,
attributeMaxLocals = maxLocals, attributeMaxLocals = maxLocals,
attributeCode = code }) = let attributeCode = code }) = let
assembledCode = concat (map serialize code) assembledCode = concatMap serialize code
in in
unpackWord16 7 -- attribute_name_index unpackWord16 7 -- attribute_name_index
++ unpackWord32 (12 + (fromIntegral (length assembledCode))) -- attribute_length ++ unpackWord32 (12 + fromIntegral (length assembledCode)) -- attribute_length
++ unpackWord16 maxStack -- max_stack ++ unpackWord16 maxStack -- max_stack
++ unpackWord16 maxLocals -- max_locals ++ unpackWord16 maxLocals -- max_locals
++ unpackWord32 (fromIntegral (length assembledCode)) -- code_length ++ unpackWord32 (fromIntegral (length assembledCode)) -- code_length

View File

@ -4,29 +4,28 @@ import Data.Int
import Ast import Ast
import ByteCode.ClassFile import ByteCode.ClassFile
import Data.List import Data.List
import Data.Maybe (mapMaybe) import Data.Maybe (mapMaybe, isJust)
import Data.Word (Word8, Word16, Word32) import Data.Word (Word8, Word16, Word32)
-- walks the name resolution chain. returns the innermost Just LocalVariable/FieldVariable or Nothing. -- walks the name resolution chain. returns the innermost Just LocalVariable/FieldVariable or Nothing.
resolveNameChain :: Expression -> Expression resolveNameChain :: Expression -> Expression
resolveNameChain (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b resolveNameChain (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b
resolveNameChain (TypedExpression dtype (LocalVariable name)) = (TypedExpression dtype (LocalVariable name)) resolveNameChain (TypedExpression dtype (LocalVariable name)) = TypedExpression dtype (LocalVariable name)
resolveNameChain (TypedExpression dtype (FieldVariable name)) = (TypedExpression dtype (FieldVariable name)) resolveNameChain (TypedExpression dtype (FieldVariable name)) = TypedExpression dtype (FieldVariable name)
resolveNameChain invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression)) resolveNameChain invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show invalidExpression)
-- walks the name resolution chain. returns the second-to-last item of the namechain. -- walks the name resolution chain. returns the second-to-last item of the namechain.
resolveNameChainOwner :: Expression -> Expression resolveNameChainOwner :: Expression -> Expression
resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a (TypedExpression dtype (FieldVariable name)))) = a resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a (TypedExpression dtype (FieldVariable name)))) = a
resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b
resolveNameChainOwner invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression)) resolveNameChainOwner invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show invalidExpression)
methodDescriptor :: MethodDeclaration -> String methodDescriptor :: MethodDeclaration -> String
methodDescriptor (MethodDeclaration returntype _ parameters _) = let methodDescriptor (MethodDeclaration returntype _ parameters _) = let
parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters] parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters]
in in
"(" "("
++ (concat (map datatypeDescriptor parameter_types)) ++ concatMap datatypeDescriptor parameter_types
++ ")" ++ ")"
++ datatypeDescriptor returntype ++ datatypeDescriptor returntype
@ -35,49 +34,69 @@ methodDescriptorFromParamlist parameters returntype = let
parameter_types = [datatype | TypedExpression datatype _ <- parameters] parameter_types = [datatype | TypedExpression datatype _ <- parameters]
in in
"(" "("
++ (concat (map datatypeDescriptor parameter_types)) ++ concatMap datatypeDescriptor parameter_types
++ ")" ++ ")"
++ datatypeDescriptor returntype ++ datatypeDescriptor returntype
memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool -- recursively parses a given type signature into a list of parameter types and the method return type.
memberInfoIsMethod constants info = elem '(' (memberInfoDescriptor constants info) -- As an initial parameter, you can supply ([], "void").
parseMethodType :: ([String], String) -> String -> ([String], String)
parseMethodType (params, returnType) ('(' : descriptor) = parseMethodType (params, returnType) descriptor
parseMethodType (params, returnType) ('I' : descriptor) = parseMethodType (params ++ ["I"], returnType) descriptor
parseMethodType (params, returnType) ('C' : descriptor) = parseMethodType (params ++ ["C"], returnType) descriptor
parseMethodType (params, returnType) ('Z' : descriptor) = parseMethodType (params ++ ["Z"], returnType) descriptor
parseMethodType (params, returnType) ('L' : descriptor) = let
typeLength = elemIndex ';' descriptor
in case typeLength of
Just length -> let
(typeName, semicolon : restOfDescriptor) = splitAt length descriptor
in
parseMethodType (params ++ [typeName], returnType) restOfDescriptor
Nothing -> error $ "unterminated class type in function signature: " ++ show descriptor
parseMethodType (params, _) (')' : descriptor) = (params, descriptor)
parseMethodType _ descriptor = error $ "expected start of type name (L, I, C, Z) but got: " ++ descriptor
-- given a method index (constant pool index),
-- returns the full type of the method. (i.e (LSomething;II)V)
methodTypeFromIndex :: [ConstantInfo] -> Int -> String
methodTypeFromIndex constants index = case constants !! fromIntegral (index - 1) of
MethodRefInfo _ nameAndTypeIndex -> case constants !! fromIntegral (nameAndTypeIndex - 1) of
NameAndTypeInfo _ typeIndex -> case constants !! fromIntegral (typeIndex - 1) of
Utf8Info typeLiteral -> typeLiteral
unexpectedElement -> error "Expected Utf8Info but got: " ++ show unexpectedElement
unexpectedElement -> error "Expected NameAndTypeInfo but got: " ++ show unexpectedElement
unexpectedElement -> error "Expected MethodRefInfo but got: " ++ show unexpectedElement
methodParametersFromIndex :: [ConstantInfo] -> Int -> ([String], String)
methodParametersFromIndex constants index = parseMethodType ([], "V") (methodTypeFromIndex constants index)
memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool
memberInfoIsMethod constants info = '(' `elem` memberInfoDescriptor constants info
datatypeDescriptor :: String -> String datatypeDescriptor :: String -> String
datatypeDescriptor "void" = "V" datatypeDescriptor "void" = "V"
datatypeDescriptor "int" = "I" datatypeDescriptor "int" = "I"
datatypeDescriptor "char" = "C" datatypeDescriptor "char" = "C"
datatypeDescriptor "boolean" = "B" datatypeDescriptor "boolean" = "Z"
datatypeDescriptor x = "L" ++ x ++ ";" datatypeDescriptor x = "L" ++ x ++ ";"
memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String
memberInfoDescriptor constants MemberInfo { memberInfoDescriptor constants MemberInfo { memberDescriptorIndex = descriptorIndex } = let
memberAccessFlags = _, descriptor = constants !! (fromIntegral descriptorIndex - 1)
memberNameIndex = _,
memberDescriptorIndex = descriptorIndex,
memberAttributes = _ } = let
descriptor = constants!!((fromIntegral descriptorIndex) - 1)
in case descriptor of in case descriptor of
Utf8Info descriptorText -> descriptorText Utf8Info descriptorText -> descriptorText
_ -> ("Invalid Item at Constant pool index " ++ show descriptorIndex) _ -> "Invalid Item at Constant pool index " ++ show descriptorIndex
memberInfoName :: [ConstantInfo] -> MemberInfo -> String memberInfoName :: [ConstantInfo] -> MemberInfo -> String
memberInfoName constants MemberInfo { memberInfoName constants MemberInfo { memberNameIndex = nameIndex } = let
memberAccessFlags = _, name = constants !! (fromIntegral nameIndex - 1)
memberNameIndex = nameIndex,
memberDescriptorIndex = _,
memberAttributes = _ } = let
name = constants!!((fromIntegral nameIndex) - 1)
in case name of in case name of
Utf8Info nameText -> nameText Utf8Info nameText -> nameText
_ -> ("Invalid Item at Constant pool index " ++ show nameIndex) _ -> "Invalid Item at Constant pool index " ++ show nameIndex
returnOperation :: DataType -> Operation returnOperation :: DataType -> Operation
returnOperation dtype returnOperation dtype
| elem dtype ["int", "char", "boolean"] = Opireturn | dtype `elem` ["int", "char", "boolean"] = Opireturn
| otherwise = Opareturn | otherwise = Opareturn
binaryOperation :: BinaryOperator -> Operation binaryOperation :: BinaryOperator -> Operation
@ -100,50 +119,27 @@ comparisonOperation CompareLessOrEqual branchLocation = Opif_icmple branchLoc
comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation
comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation
findFieldIndex :: [ConstantInfo] -> String -> Maybe Int comparisonOffset :: Operation -> Maybe Int
findFieldIndex constants name = let comparisonOffset (Opif_icmpeq offset) = Just $ fromIntegral offset
fieldRefNameInfos = [ comparisonOffset (Opif_icmpne offset) = Just $ fromIntegral offset
-- we only skip one entry to get the name since the Java constant pool comparisonOffset (Opif_icmplt offset) = Just $ fromIntegral offset
-- is 1-indexed (why) comparisonOffset (Opif_icmple offset) = Just $ fromIntegral offset
(index, constants!!(fromIntegral index + 1)) comparisonOffset (Opif_icmpgt offset) = Just $ fromIntegral offset
| (index, FieldRefInfo classIndex _) <- (zip [1..] constants) comparisonOffset (Opif_icmpge offset) = Just $ fromIntegral offset
] comparisonOffset anything_else = Nothing
fieldRefNames = map (\(index, nameInfo) -> case nameInfo of
Utf8Info fieldName -> (index, fieldName)
something_else -> error ("Expected UTF8Info but got" ++ show something_else))
fieldRefNameInfos
fieldIndex = find (\(index, fieldName) -> fieldName == name) fieldRefNames
in case fieldIndex of
Just (index, _) -> Just index
Nothing -> Nothing
findMethodRefIndex :: [ConstantInfo] -> String -> Maybe Int
findMethodRefIndex constants name = let
methodRefNameInfos = [
-- we only skip one entry to get the name since the Java constant pool
-- is 1-indexed (why)
(index, constants!!(fromIntegral index + 1))
| (index, MethodRefInfo _ _) <- (zip [1..] constants)
]
methodRefNames = map (\(index, nameInfo) -> case nameInfo of
Utf8Info methodName -> (index, methodName)
something_else -> error ("Expected UTF8Info but got " ++ show something_else))
methodRefNameInfos
methodIndex = find (\(index, methodName) -> methodName == name) methodRefNames
in case methodIndex of
Just (index, _) -> Just index
Nothing -> Nothing
isComparisonOperation :: Operation -> Bool
isComparisonOperation op = isJust (comparisonOffset op)
findMethodIndex :: ClassFile -> String -> Maybe Int findMethodIndex :: ClassFile -> String -> Maybe Int
findMethodIndex classFile name = let findMethodIndex classFile name = let
constants = constantPool classFile constants = constantPool classFile
in in
findIndex (\method -> ((memberInfoIsMethod constants method) && (memberInfoName constants method) == name)) (methods classFile) findIndex (\method -> memberInfoIsMethod constants method && memberInfoName constants method == name) (methods classFile)
findClassIndex :: [ConstantInfo] -> String -> Maybe Int findClassIndex :: [ConstantInfo] -> String -> Maybe Int
findClassIndex constants name = let findClassIndex constants name = let
classNameIndices = [(index, constants!!(fromIntegral nameIndex - 1)) | (index, ClassInfo nameIndex) <- (zip[1..] constants)] classNameIndices = [(index, constants!!(fromIntegral nameIndex - 1)) | (index, ClassInfo nameIndex) <- zip [1..] constants]
classNames = map (\(index, nameInfo) -> case nameInfo of classNames = map (\(index, nameInfo) -> case nameInfo of
Utf8Info className -> (index, className) Utf8Info className -> (index, className)
something_else -> error ("Expected UTF8Info but got " ++ show something_else)) something_else -> error ("Expected UTF8Info but got " ++ show something_else))
@ -157,10 +153,10 @@ getKnownMembers :: [ConstantInfo] -> [(Int, (String, String, String))]
getKnownMembers constants = let getKnownMembers constants = let
fieldsClassAndNT = [ fieldsClassAndNT = [
(index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1)) (index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1))
| (index, FieldRefInfo classIndex nameTypeIndex) <- (zip [1..] constants) | (index, FieldRefInfo classIndex nameTypeIndex) <- zip [1..] constants
] ++ [ ] ++ [
(index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1)) (index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1))
| (index, MethodRefInfo classIndex nameTypeIndex) <- (zip [1..] constants) | (index, MethodRefInfo classIndex nameTypeIndex) <- zip [1..] constants
] ]
fieldsClassNameType = map (\(index, nameInfo, nameTypeInfo) -> case (nameInfo, nameTypeInfo) of fieldsClassNameType = map (\(index, nameInfo, nameTypeInfo) -> case (nameInfo, nameTypeInfo) of
@ -179,7 +175,7 @@ getKnownMembers constants = let
getClassIndex :: [ConstantInfo] -> String -> ([ConstantInfo], Int) getClassIndex :: [ConstantInfo] -> String -> ([ConstantInfo], Int)
getClassIndex constants name = case findClassIndex constants name of getClassIndex constants name = case findClassIndex constants name of
Just index -> (constants, index) Just index -> (constants, index)
Nothing -> (constants ++ [ClassInfo (fromIntegral (length constants)), Utf8Info name], fromIntegral (length constants)) Nothing -> (constants ++ [ClassInfo (fromIntegral (length constants) + 2), Utf8Info name], fromIntegral (length constants) + 1)
-- get the index for a field within a class, creating it if it does not exist. -- get the index for a field within a class, creating it if it does not exist.
getFieldIndex :: [ConstantInfo] -> (String, String, String) -> ([ConstantInfo], Int) getFieldIndex :: [ConstantInfo] -> (String, String, String) -> ([ConstantInfo], Int)
@ -239,7 +235,58 @@ injectFieldInitializers classname vars pre = let
otherwise -> Nothing otherwise -> Nothing
) vars ) vars
in in
map (\(method) -> case method of map (\method -> case method of
MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block statements)) -> MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block (initializers ++ statements))) MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block statements)) -> MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block (initializers ++ statements)))
otherwise -> method _ -> method
) pre ) pre
-- effect of one instruction/operation on the stack
operationStackCost :: [ConstantInfo] -> Operation -> Int
operationStackCost constants Opiadd = -1
operationStackCost constants Opisub = -1
operationStackCost constants Opimul = -1
operationStackCost constants Opidiv = -1
operationStackCost constants Opirem = -1
operationStackCost constants Opiand = -1
operationStackCost constants Opior = -1
operationStackCost constants Opixor = -1
operationStackCost constants Opineg = 0
operationStackCost constants Opdup = 1
operationStackCost constants (Opnew _) = 1
operationStackCost constants (Opif_icmplt _) = -2
operationStackCost constants (Opif_icmple _) = -2
operationStackCost constants (Opif_icmpgt _) = -2
operationStackCost constants (Opif_icmpge _) = -2
operationStackCost constants (Opif_icmpeq _) = -2
operationStackCost constants (Opif_icmpne _) = -2
operationStackCost constants Opaconst_null = 1
operationStackCost constants Opreturn = 0
operationStackCost constants Opireturn = -1
operationStackCost constants Opareturn = -1
operationStackCost constants Opdup_x1 = 1
operationStackCost constants Oppop = -1
operationStackCost constants (Opinvokespecial idx) = let
(params, returnType) = methodParametersFromIndex constants (fromIntegral idx)
in (length params + 1) - fromEnum (returnType /= "V")
operationStackCost constants (Opinvokevirtual idx) = let
(params, returnType) = methodParametersFromIndex constants (fromIntegral idx)
in (length params + 1) - fromEnum (returnType /= "V")
operationStackCost constants (Opgoto _) = 0
operationStackCost constants (Opsipush _) = 1
operationStackCost constants (Opldc_w _) = 1
operationStackCost constants (Opaload _) = 1
operationStackCost constants (Opiload _) = 1
operationStackCost constants (Opastore _) = -1
operationStackCost constants (Opistore _) = -1
operationStackCost constants (Opputfield _) = -2
operationStackCost constants (Opgetfield _) = -1
simulateStackOperation :: [ConstantInfo] -> Operation -> (Int, Int) -> (Int, Int)
simulateStackOperation constants op (cd, md) = let
depth = cd + operationStackCost constants op
in if depth < 0
then error ("Consuming value off of empty stack: " ++ show op)
else (depth, max depth md)
maxStackDepth :: [ConstantInfo] -> [Operation] -> Int
maxStackDepth constants ops = snd $ foldr (simulateStackOperation constants) (0, 0) (reverse ops)

View File

@ -1,6 +1,5 @@
module Main where module Main where
import Example
import Typecheck import Typecheck
import Parser.Lexer (alexScanTokens) import Parser.Lexer (alexScanTokens)
import Parser.JavaParser import Parser.JavaParser
@ -14,18 +13,18 @@ main = do
args <- getArgs args <- getArgs
let filename = if null args let filename = if null args
then error "Missing filename, I need to know what to compile" then error "Missing filename, I need to know what to compile"
else args!!0 else head args
let outputDirectory = takeDirectory filename let outputDirectory = takeDirectory filename
print ("Compiling " ++ filename) print ("Compiling " ++ filename)
file <- readFile filename file <- readFile filename
let untypedAST = parse $ alexScanTokens file let untypedAST = parse $ alexScanTokens file
let typedAST = (typeCheckCompilationUnit untypedAST) let typedAST = typeCheckCompilationUnit untypedAST
let assembledClasses = map (\(typedClass) -> classBuilder typedClass emptyClassFile) typedAST let assembledClasses = map (`classBuilder` emptyClassFile) typedAST
mapM_ (\(classFile) -> let mapM_ (\classFile -> let
fileContent = pack (serialize classFile) fileContent = pack (serialize classFile)
fileName = outputDirectory ++ "/" ++ (className classFile) ++ ".class" fileName = outputDirectory ++ "/" ++ className classFile ++ ".class"
in Data.ByteString.writeFile fileName fileContent in Data.ByteString.writeFile fileName fileContent
) assembledClasses ) assembledClasses

View File

@ -37,14 +37,18 @@ typeCheckVariableDeclaration (VariableDeclaration dataType identifier maybeExpr)
-- Type check the initializer expression if it exists -- Type check the initializer expression if it exists
checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
exprType = fmap getTypeFromExpr checkedExpr exprType = fmap getTypeFromExpr checkedExpr
checkedExprWithType = case exprType of
Just "null" | isObjectType dataType -> Just (TypedExpression dataType NullLiteral)
_ -> checkedExpr
in case (validType, redefined, exprType) of in case (validType, redefined, exprType) of
(False, _, _) -> error $ "Type '" ++ dataType ++ "' is not a valid type for variable '" ++ identifier ++ "'" (False, _, _) -> error $ "Type '" ++ dataType ++ "' is not a valid type for variable '" ++ identifier ++ "'"
(_, True, _) -> error $ "Variable '" ++ identifier ++ "' is redefined in the same scope" (_, True, _) -> error $ "Variable '" ++ identifier ++ "' is redefined in the same scope"
(_, _, Just t) (_, _, Just t)
| t == "null" && isObjectType dataType -> VariableDeclaration dataType identifier checkedExpr | t == "null" && isObjectType dataType -> VariableDeclaration dataType identifier checkedExprWithType
| t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
| otherwise -> VariableDeclaration dataType identifier checkedExpr | otherwise -> VariableDeclaration dataType identifier checkedExprWithType
(_, _, Nothing) -> VariableDeclaration dataType identifier checkedExpr (_, _, Nothing) -> VariableDeclaration dataType identifier checkedExprWithType
-- ********************************** Type Checking: Expressions ********************************** -- ********************************** Type Checking: Expressions **********************************
@ -125,18 +129,21 @@ typeCheckStatementExpression (Assignment ref expr) symtab classes =
ref' = typeCheckExpression ref symtab classes ref' = typeCheckExpression ref symtab classes
type' = getTypeFromExpr expr' type' = getTypeFromExpr expr'
type'' = getTypeFromExpr ref' type'' = getTypeFromExpr ref'
typeToAssign = if type' == "null" && isObjectType type'' then type'' else type'
exprWithType = if type' == "null" && isObjectType type'' then TypedExpression type'' NullLiteral else expr'
in in
if type'' == type' || (type' == "null" && isObjectType type'') then if type'' == typeToAssign then
TypedStatementExpression type'' (Assignment ref' expr') TypedStatementExpression type'' (Assignment ref' exprWithType)
else else
error $ "Type mismatch in assignment to variable: expected " ++ type'' ++ ", found " ++ type' error $ "Type mismatch in assignment to variable: expected " ++ type'' ++ ", found " ++ typeToAssign
typeCheckStatementExpression (ConstructorCall className args) symtab classes = typeCheckStatementExpression (ConstructorCall className args) symtab classes =
case find (\(Class name _ _) -> name == className) classes of case find (\(Class name _ _) -> name == className) classes of
Nothing -> error $ "Class '" ++ className ++ "' not found." Nothing -> error $ "Class '" ++ className ++ "' not found."
Just (Class _ methods fields) -> Just (Class _ methods _) ->
-- Find constructor matching the class name with void return type -- Find constructor matching the class name with void return type
case find (\(MethodDeclaration retType name params _) -> name == "<init>" && retType == "void") methods of case find (\(MethodDeclaration _ name params _) -> name == "<init>") methods of
-- If no constructor is found, assume standard constructor with no parameters -- If no constructor is found, assume standard constructor with no parameters
Nothing -> Nothing ->
if null args then if null args then
@ -144,21 +151,28 @@ typeCheckStatementExpression (ConstructorCall className args) symtab classes =
else else
error $ "No valid constructor found for class '" ++ className ++ "', but arguments were provided." error $ "No valid constructor found for class '" ++ className ++ "', but arguments were provided."
Just (MethodDeclaration _ _ params _) -> Just (MethodDeclaration _ _ params _) ->
let let args' = zipWith
args' = map (\arg -> typeCheckExpression arg symtab classes) args (\arg (ParameterDeclaration paramType _) ->
-- Extract expected parameter types from the constructor's parameters let argTyped = typeCheckExpression arg symtab classes
in if getTypeFromExpr argTyped == "null" && isObjectType paramType
then TypedExpression paramType NullLiteral
else argTyped
) args params
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params] expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
argTypes = map getTypeFromExpr args' argTypes = map getTypeFromExpr args'
-- Check if the types of the provided arguments match the expected types typeMatches = zipWith
typeMatches = zipWith (\expected actual -> if expected == actual then Nothing else Just (expected, actual)) expectedTypes argTypes (\expType argType -> (expType == argType || (argType == "null" && isObjectType expType), expType, argType))
mismatchErrors = map (\(exp, act) -> "Expected type '" ++ exp ++ "', found '" ++ act ++ "'.") (catMaybes typeMatches) expectedTypes argTypes
mismatches = filter (not . fst3) typeMatches
fst3 (a, _, _) = a
in in
if length args /= length params then if null mismatches && length args == length params then
error $ "Constructor for class '" ++ className ++ "' expects " ++ show (length params) ++ " arguments, but got " ++ show (length args) ++ "."
else if not (null mismatchErrors) then
error $ unlines $ ("Type mismatch in constructor arguments for class '" ++ className ++ "':") : mismatchErrors
else
TypedStatementExpression className (ConstructorCall className args') TypedStatementExpression className (ConstructorCall className args')
else if not (null mismatches) then
error $ unlines $ ("Type mismatch in constructor arguments for class '" ++ className ++ "':")
: [ "Expected: " ++ expType ++ ", Found: " ++ argType | (_, expType, argType) <- mismatches ]
else
error $ "Incorrect number of arguments for constructor of class '" ++ className ++ "'. Expected " ++ show (length expectedTypes) ++ ", found " ++ show (length args) ++ "."
typeCheckStatementExpression (MethodCall expr methodName args) symtab classes = typeCheckStatementExpression (MethodCall expr methodName args) symtab classes =
let objExprTyped = typeCheckExpression expr symtab classes let objExprTyped = typeCheckExpression expr symtab classes
@ -168,20 +182,26 @@ typeCheckStatementExpression (MethodCall expr methodName args) symtab classes =
Just (Class _ methods _) -> Just (Class _ methods _) ->
case find (\(MethodDeclaration retType name params _) -> name == methodName) methods of case find (\(MethodDeclaration retType name params _) -> name == methodName) methods of
Just (MethodDeclaration retType _ params _) -> Just (MethodDeclaration retType _ params _) ->
let args' = map (\arg -> typeCheckExpression arg symtab classes) args let args' = zipWith
(\arg (ParameterDeclaration paramType _) ->
let argTyped = typeCheckExpression arg symtab classes
in if getTypeFromExpr argTyped == "null" && isObjectType paramType
then TypedExpression paramType NullLiteral
else argTyped
) args params
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params] expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
argTypes = map getTypeFromExpr args' argTypes = map getTypeFromExpr args'
typeMatches = zipWith (\expType argType -> (expType == argType, expType, argType)) expectedTypes argTypes typeMatches = zipWith
(\expType argType -> (expType == argType || (argType == "null" && isObjectType expType), expType, argType))
expectedTypes argTypes
mismatches = filter (not . fst3) typeMatches mismatches = filter (not . fst3) typeMatches
where fst3 (a, _, _) = a fst3 (a, _, _) = a
in in if null mismatches && length args == length params
if null mismatches && length args == length params then then TypedStatementExpression retType (MethodCall objExprTyped methodName args')
TypedStatementExpression retType (MethodCall objExprTyped methodName args') else if not (null mismatches)
else if not (null mismatches) then then error $ unlines $ ("Argument type mismatches for method '" ++ methodName ++ "':")
error $ unlines $ ("Argument type mismatches for method '" ++ methodName ++ "':")
: [ "Expected: " ++ expType ++ ", Found: " ++ argType | (_, expType, argType) <- mismatches ] : [ "Expected: " ++ expType ++ ", Found: " ++ argType | (_, expType, argType) <- mismatches ]
else else error $ "Incorrect number of arguments for method '" ++ methodName ++ "'. Expected " ++ show (length expectedTypes) ++ ", found " ++ show (length args) ++ "."
error $ "Incorrect number of arguments for method '" ++ methodName ++ "'. Expected " ++ show (length expectedTypes) ++ ", found " ++ show (length args) ++ "."
Nothing -> error $ "Method '" ++ methodName ++ "' not found in class '" ++ objType ++ "'." Nothing -> error $ "Method '" ++ methodName ++ "' not found in class '" ++ objType ++ "'."
Nothing -> error $ "Class for object type '" ++ objType ++ "' not found." Nothing -> error $ "Class for object type '" ++ objType ++ "' not found."
_ -> error "Invalid object type for method call. Object must have a class type." _ -> error "Invalid object type for method call. Object must have a class type."
@ -251,14 +271,18 @@ typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType ident
-- If there's an initializer expression, type check it -- If there's an initializer expression, type check it
let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
exprType = fmap getTypeFromExpr checkedExpr exprType = fmap getTypeFromExpr checkedExpr
checkedExprWithType = case (exprType, dataType) of
(Just "null", _) | isObjectType dataType -> Just (TypedExpression dataType NullLiteral)
_ -> checkedExpr
in case exprType of in case exprType of
Just t Just t
| t == "null" && isObjectType dataType -> | t == "null" && isObjectType dataType ->
TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr)) TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExprWithType))
| t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
| otherwise -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr)) | otherwise -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExprWithType))
Nothing -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr)) Nothing -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
typeCheckStatement (While cond stmt) symtab classes = typeCheckStatement (While cond stmt) symtab classes =
let cond' = typeCheckExpression cond symtab classes let cond' = typeCheckExpression cond symtab classes
stmt' = typeCheckStatement stmt symtab classes stmt' = typeCheckStatement stmt symtab classes
@ -305,13 +329,17 @@ typeCheckStatement (Block statements) symtab classes =
typeCheckStatement (Return expr) symtab classes = typeCheckStatement (Return expr) symtab classes =
let methodReturnType = fromMaybe (error "Method return type not found in symbol table") (lookup "thisMeth" symtab) let methodReturnType = fromMaybe (error "Method return type not found in symbol table") (lookup "thisMeth" symtab)
expr' = case expr of expr' = case expr of
Just e -> Just (typeCheckExpression e symtab classes) Just e -> let eTyped = typeCheckExpression e symtab classes
in if getTypeFromExpr eTyped == "null" && isObjectType methodReturnType
then Just (TypedExpression methodReturnType NullLiteral)
else Just eTyped
Nothing -> Nothing Nothing -> Nothing
returnType = maybe "void" getTypeFromExpr expr' returnType = maybe "void" getTypeFromExpr expr'
in if returnType == methodReturnType || isSubtype returnType methodReturnType classes in if returnType == methodReturnType || isSubtype returnType methodReturnType classes
then TypedStatement returnType (Return expr') then TypedStatement returnType (Return expr')
else error $ "Return: Return type mismatch: expected " ++ methodReturnType ++ ", found " ++ returnType else error $ "Return: Return type mismatch: expected " ++ methodReturnType ++ ", found " ++ returnType
typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes = typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes =
let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes
in TypedStatement (getTypeFromStmtExpr stmtExpr') (StatementExpressionStatement stmtExpr') in TypedStatement (getTypeFromStmtExpr stmtExpr') (StatementExpressionStatement stmtExpr')