bytecode #6
@ -1,7 +1,7 @@
|
|||||||
// compile all test files using:
|
// compile all test files using:
|
||||||
// ls Test/JavaSources/*.java | grep -v ".*Main.java" | xargs -I {} cabal run compiler {}
|
// ls Test/JavaSources/*.java | grep -v ".*Main.java" | xargs -I {} cabal run compiler {}
|
||||||
// compile (in project root) using:
|
// compile (in project root) using:
|
||||||
// javac -g:none -sourcepath Test/JavaSources/ Test/JavaSources/Main.java
|
// pushd Test/JavaSources; javac -g:none Main.java; popd
|
||||||
// afterwards, run using
|
// afterwards, run using
|
||||||
// java -ea -cp Test/JavaSources/ Main
|
// java -ea -cp Test/JavaSources/ Main
|
||||||
|
|
||||||
@ -11,6 +11,7 @@ public class Main {
|
|||||||
TestEmpty empty = new TestEmpty();
|
TestEmpty empty = new TestEmpty();
|
||||||
TestFields fields = new TestFields();
|
TestFields fields = new TestFields();
|
||||||
TestConstructor constructor = new TestConstructor(42);
|
TestConstructor constructor = new TestConstructor(42);
|
||||||
|
TestArithmetic arithmetic = new TestArithmetic();
|
||||||
TestMultipleClasses multipleClasses = new TestMultipleClasses();
|
TestMultipleClasses multipleClasses = new TestMultipleClasses();
|
||||||
TestRecursion recursion = new TestRecursion(10);
|
TestRecursion recursion = new TestRecursion(10);
|
||||||
TestMalicious malicious = new TestMalicious();
|
TestMalicious malicious = new TestMalicious();
|
||||||
@ -21,6 +22,10 @@ public class Main {
|
|||||||
assert fields.a == 0 && fields.b == 42;
|
assert fields.a == 0 && fields.b == 42;
|
||||||
// constructor parameters override initializers
|
// constructor parameters override initializers
|
||||||
assert constructor.a == 42;
|
assert constructor.a == 42;
|
||||||
|
// basic arithmetics
|
||||||
|
assert arithmetic.basic(1, 2, 3) == 2;
|
||||||
|
// we have boolean logic as well
|
||||||
|
assert arithmetic.logic(false, false, true) == true;
|
||||||
// multiple classes within one file work. Referencing another classes fields/methods works.
|
// multiple classes within one file work. Referencing another classes fields/methods works.
|
||||||
assert multipleClasses.a.a == 42;
|
assert multipleClasses.a.a == 42;
|
||||||
// self-referencing classes work.
|
// self-referencing classes work.
|
||||||
|
11
Test/JavaSources/TestArithmetic.java
Normal file
11
Test/JavaSources/TestArithmetic.java
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
public class TestArithmetic {
|
||||||
|
public int basic(int a, int b, int c)
|
||||||
|
{
|
||||||
|
return a + b - c * a / b % c;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean logic(boolean a, boolean b, boolean c)
|
||||||
|
{
|
||||||
|
return !a && (c || b);
|
||||||
|
}
|
||||||
|
}
|
@ -24,4 +24,11 @@ public class TestRecursion {
|
|||||||
return fibonacci(n - 1) + this.fibonacci(n - 2);
|
return fibonacci(n - 1) + this.fibonacci(n - 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public int ackermann(int m, int n)
|
||||||
|
{
|
||||||
|
if (m == 0) return n + 1;
|
||||||
|
if (n == 0) return ackermann(m - 1, 1);
|
||||||
|
return ackermann(m - 1, ackermann(m, n - 1));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,20 +4,7 @@ Die Bytecodegenerierung ist letztendlich eine zweistufige Transformation:
|
|||||||
|
|
||||||
`Getypter AST -> [ClassFile] -> [[Word8]]`
|
`Getypter AST -> [ClassFile] -> [[Word8]]`
|
||||||
|
|
||||||
Vom AST, der bereits den Typcheck durchlaufen hat, wird zunächst eine Abbildung in die einzelnen ClassFiles vorgenommen. Diese ClassFiles werden anschließend in deren Byte-Repräsentation serialisiert.
|
Vom AST, der bereits den Typcheck durchlaufen hat, wird zunächst eine Abbildung in die einzelnen ClassFiles vorgenommen. Diese ClassFiles werden anschließend in deren Byte-Repräsentation serialisiert. Dieser Teil der Aufgabenstellung wurde gemeinsam von Christian Brier und Matthias Raba umgesetzt.
|
||||||
|
|
||||||
## Serialisierung
|
|
||||||
|
|
||||||
Damit Bytecode generiert werden kann, braucht es Strukturen, die die Daten halten, die letztendlich serialisiert werden. Die JVM erwartet den kompilierten Code in handliche Pakete verpackt. Die Struktur dieser Pakete ist [so definiert](https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-4.html).
|
|
||||||
|
|
||||||
Jede Struktur, die in dieser übergreifenden Class File vorkommt, haben wir in Haskell abgebildet. Es gibt z.B die Struktur "ClassFile", die wiederum weitere Strukturen wie z.B Informationen über Felder oder Methoden der Klasse. Alle diese Strukturen implementieren folgendes TypeClass:
|
|
||||||
|
|
||||||
```
|
|
||||||
class Serializable a where
|
|
||||||
serialize :: a -> [Word8]
|
|
||||||
```
|
|
||||||
|
|
||||||
Die Struktur ClassFile ruft für deren Kinder rekursiv diese `serialize` Funktion auf. Am Ende bleibt eine flache Word8-Liste übrig, die Serialisierung ist damit abgeschlossen.
|
|
||||||
|
|
||||||
## Codegenerierung
|
## Codegenerierung
|
||||||
|
|
||||||
@ -32,7 +19,6 @@ Die Idee hinter beiden ist, dass sie jeweils zwei Inputs haben, wobei der Rückg
|
|||||||
Der Nutzer ruft beispielsweise die Funktion `classBuilder` auf. Diese wendet nach und nach folgende Transformationen an:
|
Der Nutzer ruft beispielsweise die Funktion `classBuilder` auf. Diese wendet nach und nach folgende Transformationen an:
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
methodsWithInjectedConstructor = injectDefaultConstructor methods
|
methodsWithInjectedConstructor = injectDefaultConstructor methods
|
||||||
methodsWithInjectedInitializers = injectFieldInitializers name fields methodsWithInjectedConstructor
|
methodsWithInjectedInitializers = injectFieldInitializers name fields methodsWithInjectedConstructor
|
||||||
|
|
||||||
@ -47,7 +33,7 @@ Zuerst wird (falls notwendig) ein leerer Defaultkonstruktor in die Classfile ein
|
|||||||
2. Hinzufügen aller Methoden (nur Prototypen)
|
2. Hinzufügen aller Methoden (nur Prototypen)
|
||||||
3. Hinzufügen des Bytecodes in allen Methoden
|
3. Hinzufügen des Bytecodes in allen Methoden
|
||||||
|
|
||||||
Die Unterteilung von Schritt 2 und 3 ist deswegen notwendig, weil der Code einer Methode auch eine andere, erst nachher deklarierte Methode aufrufen kann. Nach Schritt 2 sind alle Methoden der Klasse bekannt. Wie beschrieben wird auch hier der Zustand über alle Faltungen mitgenommen. Jeder Schritt hat Zugriff auf alle Daten, die aus dem vorherigen Schritt bleiben. Sukkzessive wird eine korrekte ClassFile aufgebaut.
|
Die Unterteilung von Schritt 2 und 3 ist deswegen notwendig, weil der Code einer Methode auch eine andere, erst nachher deklarierte Methode aufrufen kann. Nach Schritt 2 sind alle Methoden der Klasse bekannt. Wie beschrieben wird auch hier der Zustand über alle Faltungen mitgenommen. Jeder Schritt hat Zugriff auf alle Daten, die aus dem vorherigen Schritt bleiben. Sukzessive wird eine korrekte ClassFile aufgebaut.
|
||||||
|
|
||||||
Besonders interessant ist hierbei Schritt 3. Dort wird das Verhalten jeder einzelnen Methode in Bytecode übersetzt. In diesem Schritt werden zusätzlich zu den `Buildern` noch die `Assembler` verwendet (Definition siehe oben.) Die Assembler funktionieren ähnlich wie die Builder, arbeiten allerdings nicht auf einer ClassFile, sondern auf dem Inhalt einer Methode: Sie verarbeiten jeweils ein Tupel:
|
Besonders interessant ist hierbei Schritt 3. Dort wird das Verhalten jeder einzelnen Methode in Bytecode übersetzt. In diesem Schritt werden zusätzlich zu den `Buildern` noch die `Assembler` verwendet (Definition siehe oben.) Die Assembler funktionieren ähnlich wie die Builder, arbeiten allerdings nicht auf einer ClassFile, sondern auf dem Inhalt einer Methode: Sie verarbeiten jeweils ein Tupel:
|
||||||
|
|
||||||
@ -67,3 +53,27 @@ assembleExpression (constants, ops, lvars) (TypedExpression _ NullLiteral) =
|
|||||||
Hier werden die Konstanten und lokalen Variablen des Inputs nicht berührt, dem Bytecode wird lediglich die Operation `aconst_null` hinzugefügt. Damit ist das Verhalten des gematchten Inputs - eines Nullliterals - abgebildet.
|
Hier werden die Konstanten und lokalen Variablen des Inputs nicht berührt, dem Bytecode wird lediglich die Operation `aconst_null` hinzugefügt. Damit ist das Verhalten des gematchten Inputs - eines Nullliterals - abgebildet.
|
||||||
|
|
||||||
Die Assembler rufen sich teilweise rekursiv selbst auf, da ja auch der AST verschachteltes Verhalten abbilden kann. Der Startpunkt für die Assembly einer Methode ist der Builder `methodAssembler`. Dieser entspricht Schritt 3 in der obigen Übersicht.
|
Die Assembler rufen sich teilweise rekursiv selbst auf, da ja auch der AST verschachteltes Verhalten abbilden kann. Der Startpunkt für die Assembly einer Methode ist der Builder `methodAssembler`. Dieser entspricht Schritt 3 in der obigen Übersicht.
|
||||||
|
|
||||||
|
## Serialisierung
|
||||||
|
|
||||||
|
Damit Bytecode generiert werden kann, braucht es Strukturen, die die Daten halten, die letztendlich serialisiert werden. Die JVM erwartet den kompilierten Code in handliche Pakete verpackt. Die Struktur dieser Pakete ist [so definiert](https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-4.html).
|
||||||
|
|
||||||
|
Jede Struktur, die in dieser übergreifenden Class File vorkommt, haben wir in Haskell abgebildet. Es gibt z.B die Struktur "ClassFile", die wiederum weitere Strukturen wie z.B Informationen über Felder oder Methoden der Klasse beinhaltet. Alle diese Strukturen implementieren folgende TypeClass:
|
||||||
|
|
||||||
|
```
|
||||||
|
class Serializable a where
|
||||||
|
serialize :: a -> [Word8]
|
||||||
|
```
|
||||||
|
|
||||||
|
Hier ist ein Beispiel anhand der Serialisierung der einzelnen Operationen:
|
||||||
|
|
||||||
|
```
|
||||||
|
instance Serializable Operation where
|
||||||
|
serialize Opiadd = [0x60]
|
||||||
|
serialize Opisub = [0x64]
|
||||||
|
serialize Opimul = [0x68]
|
||||||
|
...
|
||||||
|
serialize (Opgetfield index) = 0xB4 : unpackWord16 index
|
||||||
|
```
|
||||||
|
|
||||||
|
Die Struktur ClassFile ruft für deren Kinder rekursiv diese `serialize` Funktion auf und konkateniert die Ergebnisse. Am Ende bleibt eine flache Word8-Liste übrig, die Serialisierung ist damit abgeschlossen. Da der Typecheck sicherstellt, dass alle referenzierten Methoden/Felder gültig sind, kann die Übersetzung der einzelnen Klassen voneinander unabhängig geschehen.
|
@ -12,12 +12,12 @@ type Assembler a = ([ConstantInfo], [Operation], [String]) -> a -> ([ConstantInf
|
|||||||
|
|
||||||
assembleExpression :: Assembler Expression
|
assembleExpression :: Assembler Expression
|
||||||
assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b))
|
assembleExpression (constants, ops, lvars) (TypedExpression _ (BinaryOperation op a b))
|
||||||
| elem op [Addition, Subtraction, Multiplication, Division, Modulo, BitwiseAnd, BitwiseOr, BitwiseXor, And, Or] = let
|
| op `elem` [Addition, Subtraction, Multiplication, Division, Modulo, BitwiseAnd, BitwiseOr, BitwiseXor, And, Or] = let
|
||||||
(aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
|
(aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
|
||||||
(bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
|
(bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
|
||||||
in
|
in
|
||||||
(bConstants, bOps ++ [binaryOperation op], lvars)
|
(bConstants, bOps ++ [binaryOperation op], lvars)
|
||||||
| elem op [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let
|
| op `elem` [CompareEqual, CompareNotEqual, CompareLessThan, CompareLessOrEqual, CompareGreaterThan, CompareGreaterOrEqual] = let
|
||||||
(aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
|
(aConstants, aOps, _) = assembleExpression (constants, ops, lvars) a
|
||||||
(bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
|
(bConstants, bOps, _) = assembleExpression (aConstants, aOps, lvars) b
|
||||||
cmp_op = comparisonOperation op 9
|
cmp_op = comparisonOperation op 9
|
||||||
@ -60,7 +60,7 @@ assembleExpression (constants, ops, lvars) (TypedExpression _ (UnaryOperation Mi
|
|||||||
assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name))
|
assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable name))
|
||||||
| name == "this" = (constants, ops ++ [Opaload 0], lvars)
|
| name == "this" = (constants, ops ++ [Opaload 0], lvars)
|
||||||
| otherwise = let
|
| otherwise = let
|
||||||
localIndex = findIndex ((==) name) lvars
|
localIndex = elemIndex name lvars
|
||||||
isPrimitive = elem dtype ["char", "boolean", "int"]
|
isPrimitive = elem dtype ["char", "boolean", "int"]
|
||||||
in case localIndex of
|
in case localIndex of
|
||||||
Just index -> (constants, ops ++ if isPrimitive then [Opiload (fromIntegral index)] else [Opaload (fromIntegral index)], lvars)
|
Just index -> (constants, ops ++ if isPrimitive then [Opiload (fromIntegral index)] else [Opaload (fromIntegral index)], lvars)
|
||||||
@ -69,7 +69,7 @@ assembleExpression (constants, ops, lvars) (TypedExpression dtype (LocalVariable
|
|||||||
assembleExpression (constants, ops, lvars) (TypedExpression dtype (StatementExpressionExpression stmtexp)) =
|
assembleExpression (constants, ops, lvars) (TypedExpression dtype (StatementExpressionExpression stmtexp)) =
|
||||||
assembleStatementExpression (constants, ops, lvars) stmtexp
|
assembleStatementExpression (constants, ops, lvars) stmtexp
|
||||||
|
|
||||||
assembleExpression _ expr = error ("unimplemented: " ++ show expr)
|
assembleExpression _ expr = error ("Unknown expression: " ++ show expr)
|
||||||
|
|
||||||
assembleNameChain :: Assembler Expression
|
assembleNameChain :: Assembler Expression
|
||||||
assembleNameChain input (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression _ (FieldVariable _)))) =
|
assembleNameChain input (TypedExpression _ (BinaryOperation NameResolution (TypedExpression atype a) (TypedExpression _ (FieldVariable _)))) =
|
||||||
@ -84,7 +84,7 @@ assembleStatementExpression
|
|||||||
target = resolveNameChain (TypedExpression dtype receiver)
|
target = resolveNameChain (TypedExpression dtype receiver)
|
||||||
in case target of
|
in case target of
|
||||||
(TypedExpression dtype (LocalVariable name)) -> let
|
(TypedExpression dtype (LocalVariable name)) -> let
|
||||||
localIndex = findIndex ((==) name) lvars
|
localIndex = elemIndex name lvars
|
||||||
(constants_a, ops_a, _) = assembleExpression (constants, ops, lvars) expr
|
(constants_a, ops_a, _) = assembleExpression (constants, ops, lvars) expr
|
||||||
isPrimitive = elem dtype ["char", "boolean", "int"]
|
isPrimitive = elem dtype ["char", "boolean", "int"]
|
||||||
in case localIndex of
|
in case localIndex of
|
||||||
@ -99,7 +99,7 @@ assembleStatementExpression
|
|||||||
(constants_a, ops_a, _) = assembleExpression (constants_r, ops_r, lvars) expr
|
(constants_a, ops_a, _) = assembleExpression (constants_r, ops_r, lvars) expr
|
||||||
in
|
in
|
||||||
(constants_a, ops_a ++ [Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
|
(constants_a, ops_a ++ [Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
|
||||||
something_else -> error ("expected TypedExpression, but got: " ++ show something_else)
|
something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
|
||||||
|
|
||||||
assembleStatementExpression
|
assembleStatementExpression
|
||||||
(constants, ops, lvars)
|
(constants, ops, lvars)
|
||||||
@ -107,8 +107,8 @@ assembleStatementExpression
|
|||||||
target = resolveNameChain (TypedExpression dtype receiver)
|
target = resolveNameChain (TypedExpression dtype receiver)
|
||||||
in case target of
|
in case target of
|
||||||
(TypedExpression dtype (LocalVariable name)) -> let
|
(TypedExpression dtype (LocalVariable name)) -> let
|
||||||
localIndex = findIndex ((==) name) lvars
|
localIndex = elemIndex name lvars
|
||||||
expr = (TypedExpression dtype (LocalVariable name))
|
expr = TypedExpression dtype (LocalVariable name)
|
||||||
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
|
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
|
||||||
in case localIndex of
|
in case localIndex of
|
||||||
Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opiadd, Opdup, Opistore (fromIntegral index)], lvars)
|
Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opiadd, Opdup, Opistore (fromIntegral index)], lvars)
|
||||||
@ -121,7 +121,7 @@ assembleStatementExpression
|
|||||||
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
|
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
|
||||||
in
|
in
|
||||||
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opiadd, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
|
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opiadd, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
|
||||||
something_else -> error ("expected TypedExpression, but got: " ++ show something_else)
|
something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
|
||||||
|
|
||||||
assembleStatementExpression
|
assembleStatementExpression
|
||||||
(constants, ops, lvars)
|
(constants, ops, lvars)
|
||||||
@ -129,8 +129,8 @@ assembleStatementExpression
|
|||||||
target = resolveNameChain (TypedExpression dtype receiver)
|
target = resolveNameChain (TypedExpression dtype receiver)
|
||||||
in case target of
|
in case target of
|
||||||
(TypedExpression dtype (LocalVariable name)) -> let
|
(TypedExpression dtype (LocalVariable name)) -> let
|
||||||
localIndex = findIndex ((==) name) lvars
|
localIndex = elemIndex name lvars
|
||||||
expr = (TypedExpression dtype (LocalVariable name))
|
expr = TypedExpression dtype (LocalVariable name)
|
||||||
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
|
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
|
||||||
in case localIndex of
|
in case localIndex of
|
||||||
Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opisub, Opdup, Opistore (fromIntegral index)], lvars)
|
Just index -> (exprConstants, exprOps ++ [Opsipush 1, Opisub, Opdup, Opistore (fromIntegral index)], lvars)
|
||||||
@ -143,7 +143,7 @@ assembleStatementExpression
|
|||||||
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
|
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
|
||||||
in
|
in
|
||||||
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opisub, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
|
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opsipush 1, Opisub, Opdup_x1, Opputfield (fromIntegral fieldIndex)], lvars)
|
||||||
something_else -> error ("expected TypedExpression, but got: " ++ show something_else)
|
something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
|
||||||
|
|
||||||
assembleStatementExpression
|
assembleStatementExpression
|
||||||
(constants, ops, lvars)
|
(constants, ops, lvars)
|
||||||
@ -151,8 +151,8 @@ assembleStatementExpression
|
|||||||
target = resolveNameChain (TypedExpression dtype receiver)
|
target = resolveNameChain (TypedExpression dtype receiver)
|
||||||
in case target of
|
in case target of
|
||||||
(TypedExpression dtype (LocalVariable name)) -> let
|
(TypedExpression dtype (LocalVariable name)) -> let
|
||||||
localIndex = findIndex ((==) name) lvars
|
localIndex = elemIndex name lvars
|
||||||
expr = (TypedExpression dtype (LocalVariable name))
|
expr = TypedExpression dtype (LocalVariable name)
|
||||||
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
|
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
|
||||||
in case localIndex of
|
in case localIndex of
|
||||||
Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opiadd, Opistore (fromIntegral index)], lvars)
|
Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opiadd, Opistore (fromIntegral index)], lvars)
|
||||||
@ -165,7 +165,7 @@ assembleStatementExpression
|
|||||||
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
|
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
|
||||||
in
|
in
|
||||||
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opiadd, Opputfield (fromIntegral fieldIndex)], lvars)
|
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opiadd, Opputfield (fromIntegral fieldIndex)], lvars)
|
||||||
something_else -> error ("expected TypedExpression, but got: " ++ show something_else)
|
something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
|
||||||
|
|
||||||
assembleStatementExpression
|
assembleStatementExpression
|
||||||
(constants, ops, lvars)
|
(constants, ops, lvars)
|
||||||
@ -173,8 +173,8 @@ assembleStatementExpression
|
|||||||
target = resolveNameChain (TypedExpression dtype receiver)
|
target = resolveNameChain (TypedExpression dtype receiver)
|
||||||
in case target of
|
in case target of
|
||||||
(TypedExpression dtype (LocalVariable name)) -> let
|
(TypedExpression dtype (LocalVariable name)) -> let
|
||||||
localIndex = findIndex ((==) name) lvars
|
localIndex = elemIndex name lvars
|
||||||
expr = (TypedExpression dtype (LocalVariable name))
|
expr = TypedExpression dtype (LocalVariable name)
|
||||||
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
|
(exprConstants, exprOps, _) = assembleExpression (constants, ops, lvars) expr
|
||||||
in case localIndex of
|
in case localIndex of
|
||||||
Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opisub, Opistore (fromIntegral index)], lvars)
|
Just index -> (exprConstants, exprOps ++ [Opdup, Opsipush 1, Opisub, Opistore (fromIntegral index)], lvars)
|
||||||
@ -187,7 +187,7 @@ assembleStatementExpression
|
|||||||
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
|
(constants_r, ops_r, _) = assembleNameChain (constants_f, ops, lvars) (TypedExpression dtype receiver)
|
||||||
in
|
in
|
||||||
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opisub, Opputfield (fromIntegral fieldIndex)], lvars)
|
(constants_r, ops_r ++ [Opdup, Opgetfield (fromIntegral fieldIndex), Opdup_x1, Opsipush 1, Opisub, Opputfield (fromIntegral fieldIndex)], lvars)
|
||||||
something_else -> error ("expected TypedExpression, but got: " ++ show something_else)
|
something_else -> error ("Expected TypedExpression, but got: " ++ show something_else)
|
||||||
|
|
||||||
assembleStatementExpression
|
assembleStatementExpression
|
||||||
(constants, ops, lvars)
|
(constants, ops, lvars)
|
||||||
@ -231,7 +231,7 @@ assembleStatement (constants, ops, lvars) (TypedStatement dtype (If expr if_stmt
|
|||||||
else_length = sum (map opcodeEncodingLength ops_elsea)
|
else_length = sum (map opcodeEncodingLength ops_elsea)
|
||||||
in case dtype of
|
in case dtype of
|
||||||
"void" -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 6)] ++ ops_ifa ++ [Opgoto (else_length + 3)] ++ ops_elsea, lvars)
|
"void" -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 6)] ++ ops_ifa ++ [Opgoto (else_length + 3)] ++ ops_elsea, lvars)
|
||||||
otherwise -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 3)] ++ ops_ifa ++ ops_elsea, lvars)
|
_ -> (constants_ifa, ops ++ ops_cmp ++ [Opsipush 0, Opif_icmpeq (if_length + 3)] ++ ops_ifa ++ ops_elsea, lvars)
|
||||||
|
|
||||||
assembleStatement (constants, ops, lvars) (TypedStatement _ (While expr stmt)) = let
|
assembleStatement (constants, ops, lvars) (TypedStatement _ (While expr stmt)) = let
|
||||||
(constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr
|
(constants_cmp, ops_cmp, _) = assembleExpression (constants, [], lvars) expr
|
||||||
@ -257,20 +257,19 @@ assembleStatement (constants, ops, lvars) (TypedStatement _ (StatementExpression
|
|||||||
in
|
in
|
||||||
(constants_e, ops_e ++ [Oppop], lvars_e)
|
(constants_e, ops_e ++ [Oppop], lvars_e)
|
||||||
|
|
||||||
assembleStatement _ stmt = error ("Not yet implemented: " ++ show stmt)
|
assembleStatement _ stmt = error ("Unknown statement: " ++ show stmt)
|
||||||
|
|
||||||
|
|
||||||
assembleMethod :: Assembler MethodDeclaration
|
assembleMethod :: Assembler MethodDeclaration
|
||||||
assembleMethod (constants, ops, lvars) (MethodDeclaration returntype name _ (TypedStatement _ (Block statements)))
|
assembleMethod (constants, ops, lvars) (MethodDeclaration returntype name _ (TypedStatement _ (Block statements)))
|
||||||
| name == "<init>" = let
|
| name == "<init>" = let
|
||||||
(constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
|
(constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
|
||||||
init_ops = [Opaload 0, Opinvokespecial 2]
|
|
||||||
in
|
in
|
||||||
(constants_a, init_ops ++ ops_a ++ [Opreturn], lvars_a)
|
(constants_a, [Opaload 0, Opinvokespecial 2] ++ ops_a ++ [Opreturn], lvars_a)
|
||||||
| otherwise = case returntype of
|
| otherwise = case returntype of
|
||||||
"void" -> let
|
"void" -> let
|
||||||
(constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
|
(constants_a, ops_a, lvars_a) = foldl assembleStatement (constants, ops, lvars) statements
|
||||||
in
|
in
|
||||||
(constants_a, ops_a ++ [Opreturn], lvars_a)
|
(constants_a, ops_a ++ [Opreturn], lvars_a)
|
||||||
otherwise -> foldl assembleStatement (constants, ops, lvars) statements
|
_ -> foldl assembleStatement (constants, ops, lvars) statements
|
||||||
assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Typed block expected for method body, got: " ++ show stmt)
|
assembleMethod _ (MethodDeclaration _ _ _ stmt) = error ("Typed block expected for method body, got: " ++ show stmt)
|
||||||
|
@ -22,14 +22,14 @@ fieldBuilder (VariableDeclaration datatype name _) input = let
|
|||||||
]
|
]
|
||||||
field = MemberInfo {
|
field = MemberInfo {
|
||||||
memberAccessFlags = accessPublic,
|
memberAccessFlags = accessPublic,
|
||||||
memberNameIndex = (fromIntegral (baseIndex + 2)),
|
memberNameIndex = fromIntegral (baseIndex + 2),
|
||||||
memberDescriptorIndex = (fromIntegral (baseIndex + 3)),
|
memberDescriptorIndex = fromIntegral (baseIndex + 3),
|
||||||
memberAttributes = []
|
memberAttributes = []
|
||||||
}
|
}
|
||||||
in
|
in
|
||||||
input {
|
input {
|
||||||
constantPool = (constantPool input) ++ constants,
|
constantPool = constantPool input ++ constants,
|
||||||
fields = (fields input) ++ [field]
|
fields = fields input ++ [field]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -46,18 +46,17 @@ methodBuilder (MethodDeclaration returntype name parameters statement) input = l
|
|||||||
|
|
||||||
method = MemberInfo {
|
method = MemberInfo {
|
||||||
memberAccessFlags = accessPublic,
|
memberAccessFlags = accessPublic,
|
||||||
memberNameIndex = (fromIntegral (baseIndex + 2)),
|
memberNameIndex = fromIntegral (baseIndex + 2),
|
||||||
memberDescriptorIndex = (fromIntegral (baseIndex + 3)),
|
memberDescriptorIndex = fromIntegral (baseIndex + 3),
|
||||||
memberAttributes = []
|
memberAttributes = []
|
||||||
}
|
}
|
||||||
in
|
in
|
||||||
input {
|
input {
|
||||||
constantPool = (constantPool input) ++ constants,
|
constantPool = constantPool input ++ constants,
|
||||||
methods = (methods input) ++ [method]
|
methods = methods input ++ [method]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
methodAssembler :: ClassFileBuilder MethodDeclaration
|
methodAssembler :: ClassFileBuilder MethodDeclaration
|
||||||
methodAssembler (MethodDeclaration returntype name parameters statement) input = let
|
methodAssembler (MethodDeclaration returntype name parameters statement) input = let
|
||||||
methodConstantIndex = findMethodIndex input name
|
methodConstantIndex = findMethodIndex input name
|
||||||
@ -66,21 +65,22 @@ methodAssembler (MethodDeclaration returntype name parameters statement) input =
|
|||||||
Just index -> let
|
Just index -> let
|
||||||
declaration = MethodDeclaration returntype name parameters statement
|
declaration = MethodDeclaration returntype name parameters statement
|
||||||
paramNames = "this" : [name | ParameterDeclaration _ name <- parameters]
|
paramNames = "this" : [name | ParameterDeclaration _ name <- parameters]
|
||||||
in case (splitAt index (methods input)) of
|
in case splitAt index (methods input) of
|
||||||
(pre, []) -> input
|
(pre, []) -> input
|
||||||
(pre, method : post) -> let
|
(pre, method : post) -> let
|
||||||
(_, bytecode, _) = assembleMethod (constantPool input, [], paramNames) declaration
|
(constants, bytecode, aParamNames) = assembleMethod (constantPool input, [], paramNames) declaration
|
||||||
assembledMethod = method {
|
assembledMethod = method {
|
||||||
memberAttributes = [
|
memberAttributes = [
|
||||||
CodeAttribute {
|
CodeAttribute {
|
||||||
attributeMaxStack = 420,
|
attributeMaxStack = fromIntegral $ maxStackDepth constants bytecode,
|
||||||
attributeMaxLocals = 420,
|
attributeMaxLocals = fromIntegral $ length aParamNames,
|
||||||
attributeCode = bytecode
|
attributeCode = bytecode
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
in
|
in
|
||||||
input {
|
input {
|
||||||
|
constantPool = constants,
|
||||||
methods = pre ++ (assembledMethod : post)
|
methods = pre ++ (assembledMethod : post)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -94,11 +94,12 @@ classBuilder (Class name methods fields) _ = let
|
|||||||
Utf8Info "java/lang/Object",
|
Utf8Info "java/lang/Object",
|
||||||
Utf8Info "<init>",
|
Utf8Info "<init>",
|
||||||
Utf8Info "()V",
|
Utf8Info "()V",
|
||||||
Utf8Info "Code"
|
Utf8Info "Code",
|
||||||
|
ClassInfo 9,
|
||||||
|
Utf8Info name
|
||||||
]
|
]
|
||||||
nameConstants = [ClassInfo 9, Utf8Info name]
|
|
||||||
nakedClassFile = ClassFile {
|
nakedClassFile = ClassFile {
|
||||||
constantPool = baseConstants ++ nameConstants,
|
constantPool = baseConstants,
|
||||||
accessFlags = accessPublic,
|
accessFlags = accessPublic,
|
||||||
thisClass = 8,
|
thisClass = 8,
|
||||||
superClass = 1,
|
superClass = 1,
|
||||||
@ -107,9 +108,13 @@ classBuilder (Class name methods fields) _ = let
|
|||||||
attributes = []
|
attributes = []
|
||||||
}
|
}
|
||||||
|
|
||||||
|
-- if a class has no constructor, inject an empty one.
|
||||||
methodsWithInjectedConstructor = injectDefaultConstructor methods
|
methodsWithInjectedConstructor = injectDefaultConstructor methods
|
||||||
|
-- for every constructor, prepend all initialization assignments for fields.
|
||||||
methodsWithInjectedInitializers = injectFieldInitializers name fields methodsWithInjectedConstructor
|
methodsWithInjectedInitializers = injectFieldInitializers name fields methodsWithInjectedConstructor
|
||||||
|
|
||||||
|
-- add fields, then method bodies to the classfile. After all referable names are known,
|
||||||
|
-- assemble the methods into bytecode.
|
||||||
classFileWithFields = foldr fieldBuilder nakedClassFile fields
|
classFileWithFields = foldr fieldBuilder nakedClassFile fields
|
||||||
classFileWithMethods = foldr methodBuilder classFileWithFields methodsWithInjectedInitializers
|
classFileWithMethods = foldr methodBuilder classFileWithFields methodsWithInjectedInitializers
|
||||||
classFileWithAssembledMethods = foldr methodAssembler classFileWithMethods methodsWithInjectedInitializers
|
classFileWithAssembledMethods = foldr methodAssembler classFileWithMethods methodsWithInjectedInitializers
|
||||||
|
@ -1,14 +1,4 @@
|
|||||||
module ByteCode.ClassFile(
|
module ByteCode.ClassFile where
|
||||||
ConstantInfo(..),
|
|
||||||
Attribute(..),
|
|
||||||
MemberInfo(..),
|
|
||||||
ClassFile(..),
|
|
||||||
Operation(..),
|
|
||||||
serialize,
|
|
||||||
emptyClassFile,
|
|
||||||
opcodeEncodingLength,
|
|
||||||
className
|
|
||||||
) where
|
|
||||||
|
|
||||||
import Data.Word
|
import Data.Word
|
||||||
import Data.Int
|
import Data.Int
|
||||||
@ -99,10 +89,10 @@ emptyClassFile = ClassFile {
|
|||||||
|
|
||||||
className :: ClassFile -> String
|
className :: ClassFile -> String
|
||||||
className classFile = let
|
className classFile = let
|
||||||
classInfo = (constantPool classFile)!!(fromIntegral (thisClass classFile))
|
classInfo = constantPool classFile !! fromIntegral (thisClass classFile)
|
||||||
in case classInfo of
|
in case classInfo of
|
||||||
Utf8Info className -> className
|
Utf8Info className -> className
|
||||||
otherwise -> error ("expected Utf8Info but got: " ++ show otherwise)
|
unexpected_element -> error ("expected Utf8Info but got: " ++ show unexpected_element)
|
||||||
|
|
||||||
|
|
||||||
opcodeEncodingLength :: Operation -> Word16
|
opcodeEncodingLength :: Operation -> Word16
|
||||||
@ -201,10 +191,10 @@ instance Serializable Attribute where
|
|||||||
serialize (CodeAttribute { attributeMaxStack = maxStack,
|
serialize (CodeAttribute { attributeMaxStack = maxStack,
|
||||||
attributeMaxLocals = maxLocals,
|
attributeMaxLocals = maxLocals,
|
||||||
attributeCode = code }) = let
|
attributeCode = code }) = let
|
||||||
assembledCode = concat (map serialize code)
|
assembledCode = concatMap serialize code
|
||||||
in
|
in
|
||||||
unpackWord16 7 -- attribute_name_index
|
unpackWord16 7 -- attribute_name_index
|
||||||
++ unpackWord32 (12 + (fromIntegral (length assembledCode))) -- attribute_length
|
++ unpackWord32 (12 + fromIntegral (length assembledCode)) -- attribute_length
|
||||||
++ unpackWord16 maxStack -- max_stack
|
++ unpackWord16 maxStack -- max_stack
|
||||||
++ unpackWord16 maxLocals -- max_locals
|
++ unpackWord16 maxLocals -- max_locals
|
||||||
++ unpackWord32 (fromIntegral (length assembledCode)) -- code_length
|
++ unpackWord32 (fromIntegral (length assembledCode)) -- code_length
|
||||||
|
@ -4,29 +4,28 @@ import Data.Int
|
|||||||
import Ast
|
import Ast
|
||||||
import ByteCode.ClassFile
|
import ByteCode.ClassFile
|
||||||
import Data.List
|
import Data.List
|
||||||
import Data.Maybe (mapMaybe)
|
import Data.Maybe (mapMaybe, isJust)
|
||||||
import Data.Word (Word8, Word16, Word32)
|
import Data.Word (Word8, Word16, Word32)
|
||||||
|
|
||||||
-- walks the name resolution chain. returns the innermost Just LocalVariable/FieldVariable or Nothing.
|
-- walks the name resolution chain. returns the innermost Just LocalVariable/FieldVariable or Nothing.
|
||||||
resolveNameChain :: Expression -> Expression
|
resolveNameChain :: Expression -> Expression
|
||||||
resolveNameChain (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b
|
resolveNameChain (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b
|
||||||
resolveNameChain (TypedExpression dtype (LocalVariable name)) = (TypedExpression dtype (LocalVariable name))
|
resolveNameChain (TypedExpression dtype (LocalVariable name)) = TypedExpression dtype (LocalVariable name)
|
||||||
resolveNameChain (TypedExpression dtype (FieldVariable name)) = (TypedExpression dtype (FieldVariable name))
|
resolveNameChain (TypedExpression dtype (FieldVariable name)) = TypedExpression dtype (FieldVariable name)
|
||||||
resolveNameChain invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression))
|
resolveNameChain invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show invalidExpression)
|
||||||
|
|
||||||
-- walks the name resolution chain. returns the second-to-last item of the namechain.
|
-- walks the name resolution chain. returns the second-to-last item of the namechain.
|
||||||
resolveNameChainOwner :: Expression -> Expression
|
resolveNameChainOwner :: Expression -> Expression
|
||||||
resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a (TypedExpression dtype (FieldVariable name)))) = a
|
resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a (TypedExpression dtype (FieldVariable name)))) = a
|
||||||
resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b
|
resolveNameChainOwner (TypedExpression _ (BinaryOperation NameResolution a b)) = resolveNameChain b
|
||||||
resolveNameChainOwner invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show(invalidExpression))
|
resolveNameChainOwner invalidExpression = error ("expected a NameResolution or Local/Field Variable but got: " ++ show invalidExpression)
|
||||||
|
|
||||||
|
|
||||||
methodDescriptor :: MethodDeclaration -> String
|
methodDescriptor :: MethodDeclaration -> String
|
||||||
methodDescriptor (MethodDeclaration returntype _ parameters _) = let
|
methodDescriptor (MethodDeclaration returntype _ parameters _) = let
|
||||||
parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters]
|
parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters]
|
||||||
in
|
in
|
||||||
"("
|
"("
|
||||||
++ (concat (map datatypeDescriptor parameter_types))
|
++ concatMap datatypeDescriptor parameter_types
|
||||||
++ ")"
|
++ ")"
|
||||||
++ datatypeDescriptor returntype
|
++ datatypeDescriptor returntype
|
||||||
|
|
||||||
@ -35,49 +34,69 @@ methodDescriptorFromParamlist parameters returntype = let
|
|||||||
parameter_types = [datatype | TypedExpression datatype _ <- parameters]
|
parameter_types = [datatype | TypedExpression datatype _ <- parameters]
|
||||||
in
|
in
|
||||||
"("
|
"("
|
||||||
++ (concat (map datatypeDescriptor parameter_types))
|
++ concatMap datatypeDescriptor parameter_types
|
||||||
++ ")"
|
++ ")"
|
||||||
++ datatypeDescriptor returntype
|
++ datatypeDescriptor returntype
|
||||||
|
|
||||||
memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool
|
-- recursively parses a given type signature into a list of parameter types and the method return type.
|
||||||
memberInfoIsMethod constants info = elem '(' (memberInfoDescriptor constants info)
|
-- As an initial parameter, you can supply ([], "void").
|
||||||
|
parseMethodType :: ([String], String) -> String -> ([String], String)
|
||||||
|
parseMethodType (params, returnType) ('(' : descriptor) = parseMethodType (params, returnType) descriptor
|
||||||
|
parseMethodType (params, returnType) ('I' : descriptor) = parseMethodType (params ++ ["I"], returnType) descriptor
|
||||||
|
parseMethodType (params, returnType) ('C' : descriptor) = parseMethodType (params ++ ["C"], returnType) descriptor
|
||||||
|
parseMethodType (params, returnType) ('Z' : descriptor) = parseMethodType (params ++ ["Z"], returnType) descriptor
|
||||||
|
parseMethodType (params, returnType) ('L' : descriptor) = let
|
||||||
|
typeLength = elemIndex ';' descriptor
|
||||||
|
in case typeLength of
|
||||||
|
Just length -> let
|
||||||
|
(typeName, semicolon : restOfDescriptor) = splitAt length descriptor
|
||||||
|
in
|
||||||
|
parseMethodType (params ++ [typeName], returnType) restOfDescriptor
|
||||||
|
Nothing -> error $ "unterminated class type in function signature: " ++ show descriptor
|
||||||
|
parseMethodType (params, _) (')' : descriptor) = (params, descriptor)
|
||||||
|
parseMethodType _ descriptor = error $ "expected start of type name (L, I, C, Z) but got: " ++ descriptor
|
||||||
|
|
||||||
|
-- given a method index (constant pool index),
|
||||||
|
-- returns the full type of the method. (i.e (LSomething;II)V)
|
||||||
|
methodTypeFromIndex :: [ConstantInfo] -> Int -> String
|
||||||
|
methodTypeFromIndex constants index = case constants !! fromIntegral (index - 1) of
|
||||||
|
MethodRefInfo _ nameAndTypeIndex -> case constants !! fromIntegral (nameAndTypeIndex - 1) of
|
||||||
|
NameAndTypeInfo _ typeIndex -> case constants !! fromIntegral (typeIndex - 1) of
|
||||||
|
Utf8Info typeLiteral -> typeLiteral
|
||||||
|
unexpectedElement -> error "Expected Utf8Info but got: " ++ show unexpectedElement
|
||||||
|
unexpectedElement -> error "Expected NameAndTypeInfo but got: " ++ show unexpectedElement
|
||||||
|
unexpectedElement -> error "Expected MethodRefInfo but got: " ++ show unexpectedElement
|
||||||
|
|
||||||
|
methodParametersFromIndex :: [ConstantInfo] -> Int -> ([String], String)
|
||||||
|
methodParametersFromIndex constants index = parseMethodType ([], "V") (methodTypeFromIndex constants index)
|
||||||
|
|
||||||
|
memberInfoIsMethod :: [ConstantInfo] -> MemberInfo -> Bool
|
||||||
|
memberInfoIsMethod constants info = '(' `elem` memberInfoDescriptor constants info
|
||||||
|
|
||||||
datatypeDescriptor :: String -> String
|
datatypeDescriptor :: String -> String
|
||||||
datatypeDescriptor "void" = "V"
|
datatypeDescriptor "void" = "V"
|
||||||
datatypeDescriptor "int" = "I"
|
datatypeDescriptor "int" = "I"
|
||||||
datatypeDescriptor "char" = "C"
|
datatypeDescriptor "char" = "C"
|
||||||
datatypeDescriptor "boolean" = "B"
|
datatypeDescriptor "boolean" = "Z"
|
||||||
datatypeDescriptor x = "L" ++ x ++ ";"
|
datatypeDescriptor x = "L" ++ x ++ ";"
|
||||||
|
|
||||||
|
|
||||||
memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String
|
memberInfoDescriptor :: [ConstantInfo] -> MemberInfo -> String
|
||||||
memberInfoDescriptor constants MemberInfo {
|
memberInfoDescriptor constants MemberInfo { memberDescriptorIndex = descriptorIndex } = let
|
||||||
memberAccessFlags = _,
|
descriptor = constants !! (fromIntegral descriptorIndex - 1)
|
||||||
memberNameIndex = _,
|
|
||||||
memberDescriptorIndex = descriptorIndex,
|
|
||||||
memberAttributes = _ } = let
|
|
||||||
descriptor = constants!!((fromIntegral descriptorIndex) - 1)
|
|
||||||
in case descriptor of
|
in case descriptor of
|
||||||
Utf8Info descriptorText -> descriptorText
|
Utf8Info descriptorText -> descriptorText
|
||||||
_ -> ("Invalid Item at Constant pool index " ++ show descriptorIndex)
|
_ -> "Invalid Item at Constant pool index " ++ show descriptorIndex
|
||||||
|
|
||||||
|
|
||||||
memberInfoName :: [ConstantInfo] -> MemberInfo -> String
|
memberInfoName :: [ConstantInfo] -> MemberInfo -> String
|
||||||
memberInfoName constants MemberInfo {
|
memberInfoName constants MemberInfo { memberNameIndex = nameIndex } = let
|
||||||
memberAccessFlags = _,
|
name = constants !! (fromIntegral nameIndex - 1)
|
||||||
memberNameIndex = nameIndex,
|
|
||||||
memberDescriptorIndex = _,
|
|
||||||
memberAttributes = _ } = let
|
|
||||||
name = constants!!((fromIntegral nameIndex) - 1)
|
|
||||||
in case name of
|
in case name of
|
||||||
Utf8Info nameText -> nameText
|
Utf8Info nameText -> nameText
|
||||||
_ -> ("Invalid Item at Constant pool index " ++ show nameIndex)
|
_ -> "Invalid Item at Constant pool index " ++ show nameIndex
|
||||||
|
|
||||||
|
|
||||||
returnOperation :: DataType -> Operation
|
returnOperation :: DataType -> Operation
|
||||||
returnOperation dtype
|
returnOperation dtype
|
||||||
| elem dtype ["int", "char", "boolean"] = Opireturn
|
| dtype `elem` ["int", "char", "boolean"] = Opireturn
|
||||||
| otherwise = Opareturn
|
| otherwise = Opareturn
|
||||||
|
|
||||||
binaryOperation :: BinaryOperator -> Operation
|
binaryOperation :: BinaryOperator -> Operation
|
||||||
@ -100,50 +119,27 @@ comparisonOperation CompareLessOrEqual branchLocation = Opif_icmple branchLoc
|
|||||||
comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation
|
comparisonOperation CompareGreaterThan branchLocation = Opif_icmpgt branchLocation
|
||||||
comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation
|
comparisonOperation CompareGreaterOrEqual branchLocation = Opif_icmpge branchLocation
|
||||||
|
|
||||||
findFieldIndex :: [ConstantInfo] -> String -> Maybe Int
|
comparisonOffset :: Operation -> Maybe Int
|
||||||
findFieldIndex constants name = let
|
comparisonOffset (Opif_icmpeq offset) = Just $ fromIntegral offset
|
||||||
fieldRefNameInfos = [
|
comparisonOffset (Opif_icmpne offset) = Just $ fromIntegral offset
|
||||||
-- we only skip one entry to get the name since the Java constant pool
|
comparisonOffset (Opif_icmplt offset) = Just $ fromIntegral offset
|
||||||
-- is 1-indexed (why)
|
comparisonOffset (Opif_icmple offset) = Just $ fromIntegral offset
|
||||||
(index, constants!!(fromIntegral index + 1))
|
comparisonOffset (Opif_icmpgt offset) = Just $ fromIntegral offset
|
||||||
| (index, FieldRefInfo classIndex _) <- (zip [1..] constants)
|
comparisonOffset (Opif_icmpge offset) = Just $ fromIntegral offset
|
||||||
]
|
comparisonOffset anything_else = Nothing
|
||||||
fieldRefNames = map (\(index, nameInfo) -> case nameInfo of
|
|
||||||
Utf8Info fieldName -> (index, fieldName)
|
|
||||||
something_else -> error ("Expected UTF8Info but got" ++ show something_else))
|
|
||||||
fieldRefNameInfos
|
|
||||||
fieldIndex = find (\(index, fieldName) -> fieldName == name) fieldRefNames
|
|
||||||
in case fieldIndex of
|
|
||||||
Just (index, _) -> Just index
|
|
||||||
Nothing -> Nothing
|
|
||||||
|
|
||||||
findMethodRefIndex :: [ConstantInfo] -> String -> Maybe Int
|
|
||||||
findMethodRefIndex constants name = let
|
|
||||||
methodRefNameInfos = [
|
|
||||||
-- we only skip one entry to get the name since the Java constant pool
|
|
||||||
-- is 1-indexed (why)
|
|
||||||
(index, constants!!(fromIntegral index + 1))
|
|
||||||
| (index, MethodRefInfo _ _) <- (zip [1..] constants)
|
|
||||||
]
|
|
||||||
methodRefNames = map (\(index, nameInfo) -> case nameInfo of
|
|
||||||
Utf8Info methodName -> (index, methodName)
|
|
||||||
something_else -> error ("Expected UTF8Info but got " ++ show something_else))
|
|
||||||
methodRefNameInfos
|
|
||||||
methodIndex = find (\(index, methodName) -> methodName == name) methodRefNames
|
|
||||||
in case methodIndex of
|
|
||||||
Just (index, _) -> Just index
|
|
||||||
Nothing -> Nothing
|
|
||||||
|
|
||||||
|
isComparisonOperation :: Operation -> Bool
|
||||||
|
isComparisonOperation op = isJust (comparisonOffset op)
|
||||||
|
|
||||||
findMethodIndex :: ClassFile -> String -> Maybe Int
|
findMethodIndex :: ClassFile -> String -> Maybe Int
|
||||||
findMethodIndex classFile name = let
|
findMethodIndex classFile name = let
|
||||||
constants = constantPool classFile
|
constants = constantPool classFile
|
||||||
in
|
in
|
||||||
findIndex (\method -> ((memberInfoIsMethod constants method) && (memberInfoName constants method) == name)) (methods classFile)
|
findIndex (\method -> memberInfoIsMethod constants method && memberInfoName constants method == name) (methods classFile)
|
||||||
|
|
||||||
findClassIndex :: [ConstantInfo] -> String -> Maybe Int
|
findClassIndex :: [ConstantInfo] -> String -> Maybe Int
|
||||||
findClassIndex constants name = let
|
findClassIndex constants name = let
|
||||||
classNameIndices = [(index, constants!!(fromIntegral nameIndex - 1)) | (index, ClassInfo nameIndex) <- (zip[1..] constants)]
|
classNameIndices = [(index, constants!!(fromIntegral nameIndex - 1)) | (index, ClassInfo nameIndex) <- zip [1..] constants]
|
||||||
classNames = map (\(index, nameInfo) -> case nameInfo of
|
classNames = map (\(index, nameInfo) -> case nameInfo of
|
||||||
Utf8Info className -> (index, className)
|
Utf8Info className -> (index, className)
|
||||||
something_else -> error ("Expected UTF8Info but got " ++ show something_else))
|
something_else -> error ("Expected UTF8Info but got " ++ show something_else))
|
||||||
@ -157,10 +153,10 @@ getKnownMembers :: [ConstantInfo] -> [(Int, (String, String, String))]
|
|||||||
getKnownMembers constants = let
|
getKnownMembers constants = let
|
||||||
fieldsClassAndNT = [
|
fieldsClassAndNT = [
|
||||||
(index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1))
|
(index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1))
|
||||||
| (index, FieldRefInfo classIndex nameTypeIndex) <- (zip [1..] constants)
|
| (index, FieldRefInfo classIndex nameTypeIndex) <- zip [1..] constants
|
||||||
] ++ [
|
] ++ [
|
||||||
(index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1))
|
(index, constants!!(fromIntegral classIndex - 1), constants!!(fromIntegral nameTypeIndex - 1))
|
||||||
| (index, MethodRefInfo classIndex nameTypeIndex) <- (zip [1..] constants)
|
| (index, MethodRefInfo classIndex nameTypeIndex) <- zip [1..] constants
|
||||||
]
|
]
|
||||||
|
|
||||||
fieldsClassNameType = map (\(index, nameInfo, nameTypeInfo) -> case (nameInfo, nameTypeInfo) of
|
fieldsClassNameType = map (\(index, nameInfo, nameTypeInfo) -> case (nameInfo, nameTypeInfo) of
|
||||||
@ -179,7 +175,7 @@ getKnownMembers constants = let
|
|||||||
getClassIndex :: [ConstantInfo] -> String -> ([ConstantInfo], Int)
|
getClassIndex :: [ConstantInfo] -> String -> ([ConstantInfo], Int)
|
||||||
getClassIndex constants name = case findClassIndex constants name of
|
getClassIndex constants name = case findClassIndex constants name of
|
||||||
Just index -> (constants, index)
|
Just index -> (constants, index)
|
||||||
Nothing -> (constants ++ [ClassInfo (fromIntegral (length constants)), Utf8Info name], fromIntegral (length constants))
|
Nothing -> (constants ++ [ClassInfo (fromIntegral (length constants) + 2), Utf8Info name], fromIntegral (length constants) + 1)
|
||||||
|
|
||||||
-- get the index for a field within a class, creating it if it does not exist.
|
-- get the index for a field within a class, creating it if it does not exist.
|
||||||
getFieldIndex :: [ConstantInfo] -> (String, String, String) -> ([ConstantInfo], Int)
|
getFieldIndex :: [ConstantInfo] -> (String, String, String) -> ([ConstantInfo], Int)
|
||||||
@ -239,7 +235,58 @@ injectFieldInitializers classname vars pre = let
|
|||||||
otherwise -> Nothing
|
otherwise -> Nothing
|
||||||
) vars
|
) vars
|
||||||
in
|
in
|
||||||
map (\(method) -> case method of
|
map (\method -> case method of
|
||||||
MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block statements)) -> MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block (initializers ++ statements)))
|
MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block statements)) -> MethodDeclaration "void" "<init>" params (TypedStatement "void" (Block (initializers ++ statements)))
|
||||||
otherwise -> method
|
_ -> method
|
||||||
) pre
|
) pre
|
||||||
|
|
||||||
|
-- effect of one instruction/operation on the stack
|
||||||
|
operationStackCost :: [ConstantInfo] -> Operation -> Int
|
||||||
|
operationStackCost constants Opiadd = -1
|
||||||
|
operationStackCost constants Opisub = -1
|
||||||
|
operationStackCost constants Opimul = -1
|
||||||
|
operationStackCost constants Opidiv = -1
|
||||||
|
operationStackCost constants Opirem = -1
|
||||||
|
operationStackCost constants Opiand = -1
|
||||||
|
operationStackCost constants Opior = -1
|
||||||
|
operationStackCost constants Opixor = -1
|
||||||
|
operationStackCost constants Opineg = 0
|
||||||
|
operationStackCost constants Opdup = 1
|
||||||
|
operationStackCost constants (Opnew _) = 1
|
||||||
|
operationStackCost constants (Opif_icmplt _) = -2
|
||||||
|
operationStackCost constants (Opif_icmple _) = -2
|
||||||
|
operationStackCost constants (Opif_icmpgt _) = -2
|
||||||
|
operationStackCost constants (Opif_icmpge _) = -2
|
||||||
|
operationStackCost constants (Opif_icmpeq _) = -2
|
||||||
|
operationStackCost constants (Opif_icmpne _) = -2
|
||||||
|
operationStackCost constants Opaconst_null = 1
|
||||||
|
operationStackCost constants Opreturn = 0
|
||||||
|
operationStackCost constants Opireturn = -1
|
||||||
|
operationStackCost constants Opareturn = -1
|
||||||
|
operationStackCost constants Opdup_x1 = 1
|
||||||
|
operationStackCost constants Oppop = -1
|
||||||
|
operationStackCost constants (Opinvokespecial idx) = let
|
||||||
|
(params, returnType) = methodParametersFromIndex constants (fromIntegral idx)
|
||||||
|
in (length params + 1) - fromEnum (returnType /= "V")
|
||||||
|
operationStackCost constants (Opinvokevirtual idx) = let
|
||||||
|
(params, returnType) = methodParametersFromIndex constants (fromIntegral idx)
|
||||||
|
in (length params + 1) - fromEnum (returnType /= "V")
|
||||||
|
operationStackCost constants (Opgoto _) = 0
|
||||||
|
operationStackCost constants (Opsipush _) = 1
|
||||||
|
operationStackCost constants (Opldc_w _) = 1
|
||||||
|
operationStackCost constants (Opaload _) = 1
|
||||||
|
operationStackCost constants (Opiload _) = 1
|
||||||
|
operationStackCost constants (Opastore _) = -1
|
||||||
|
operationStackCost constants (Opistore _) = -1
|
||||||
|
operationStackCost constants (Opputfield _) = -2
|
||||||
|
operationStackCost constants (Opgetfield _) = -1
|
||||||
|
|
||||||
|
simulateStackOperation :: [ConstantInfo] -> Operation -> (Int, Int) -> (Int, Int)
|
||||||
|
simulateStackOperation constants op (cd, md) = let
|
||||||
|
depth = cd + operationStackCost constants op
|
||||||
|
in if depth < 0
|
||||||
|
then error ("Consuming value off of empty stack: " ++ show op)
|
||||||
|
else (depth, max depth md)
|
||||||
|
|
||||||
|
maxStackDepth :: [ConstantInfo] -> [Operation] -> Int
|
||||||
|
maxStackDepth constants ops = snd $ foldr (simulateStackOperation constants) (0, 0) (reverse ops)
|
||||||
|
11
src/Main.hs
11
src/Main.hs
@ -1,6 +1,5 @@
|
|||||||
module Main where
|
module Main where
|
||||||
|
|
||||||
import Example
|
|
||||||
import Typecheck
|
import Typecheck
|
||||||
import Parser.Lexer (alexScanTokens)
|
import Parser.Lexer (alexScanTokens)
|
||||||
import Parser.JavaParser
|
import Parser.JavaParser
|
||||||
@ -14,18 +13,18 @@ main = do
|
|||||||
args <- getArgs
|
args <- getArgs
|
||||||
let filename = if null args
|
let filename = if null args
|
||||||
then error "Missing filename, I need to know what to compile"
|
then error "Missing filename, I need to know what to compile"
|
||||||
else args!!0
|
else head args
|
||||||
let outputDirectory = takeDirectory filename
|
let outputDirectory = takeDirectory filename
|
||||||
print ("Compiling " ++ filename)
|
print ("Compiling " ++ filename)
|
||||||
file <- readFile filename
|
file <- readFile filename
|
||||||
|
|
||||||
let untypedAST = parse $ alexScanTokens file
|
let untypedAST = parse $ alexScanTokens file
|
||||||
let typedAST = (typeCheckCompilationUnit untypedAST)
|
let typedAST = typeCheckCompilationUnit untypedAST
|
||||||
let assembledClasses = map (\(typedClass) -> classBuilder typedClass emptyClassFile) typedAST
|
let assembledClasses = map (`classBuilder` emptyClassFile) typedAST
|
||||||
|
|
||||||
mapM_ (\(classFile) -> let
|
mapM_ (\classFile -> let
|
||||||
fileContent = pack (serialize classFile)
|
fileContent = pack (serialize classFile)
|
||||||
fileName = outputDirectory ++ "/" ++ (className classFile) ++ ".class"
|
fileName = outputDirectory ++ "/" ++ className classFile ++ ".class"
|
||||||
in Data.ByteString.writeFile fileName fileContent
|
in Data.ByteString.writeFile fileName fileContent
|
||||||
) assembledClasses
|
) assembledClasses
|
||||||
|
|
||||||
|
@ -37,14 +37,18 @@ typeCheckVariableDeclaration (VariableDeclaration dataType identifier maybeExpr)
|
|||||||
-- Type check the initializer expression if it exists
|
-- Type check the initializer expression if it exists
|
||||||
checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
|
checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
|
||||||
exprType = fmap getTypeFromExpr checkedExpr
|
exprType = fmap getTypeFromExpr checkedExpr
|
||||||
|
checkedExprWithType = case exprType of
|
||||||
|
Just "null" | isObjectType dataType -> Just (TypedExpression dataType NullLiteral)
|
||||||
|
_ -> checkedExpr
|
||||||
in case (validType, redefined, exprType) of
|
in case (validType, redefined, exprType) of
|
||||||
(False, _, _) -> error $ "Type '" ++ dataType ++ "' is not a valid type for variable '" ++ identifier ++ "'"
|
(False, _, _) -> error $ "Type '" ++ dataType ++ "' is not a valid type for variable '" ++ identifier ++ "'"
|
||||||
(_, True, _) -> error $ "Variable '" ++ identifier ++ "' is redefined in the same scope"
|
(_, True, _) -> error $ "Variable '" ++ identifier ++ "' is redefined in the same scope"
|
||||||
(_, _, Just t)
|
(_, _, Just t)
|
||||||
| t == "null" && isObjectType dataType -> VariableDeclaration dataType identifier checkedExpr
|
| t == "null" && isObjectType dataType -> VariableDeclaration dataType identifier checkedExprWithType
|
||||||
| t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
|
| t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
|
||||||
| otherwise -> VariableDeclaration dataType identifier checkedExpr
|
| otherwise -> VariableDeclaration dataType identifier checkedExprWithType
|
||||||
(_, _, Nothing) -> VariableDeclaration dataType identifier checkedExpr
|
(_, _, Nothing) -> VariableDeclaration dataType identifier checkedExprWithType
|
||||||
|
|
||||||
|
|
||||||
-- ********************************** Type Checking: Expressions **********************************
|
-- ********************************** Type Checking: Expressions **********************************
|
||||||
|
|
||||||
@ -125,18 +129,21 @@ typeCheckStatementExpression (Assignment ref expr) symtab classes =
|
|||||||
ref' = typeCheckExpression ref symtab classes
|
ref' = typeCheckExpression ref symtab classes
|
||||||
type' = getTypeFromExpr expr'
|
type' = getTypeFromExpr expr'
|
||||||
type'' = getTypeFromExpr ref'
|
type'' = getTypeFromExpr ref'
|
||||||
|
typeToAssign = if type' == "null" && isObjectType type'' then type'' else type'
|
||||||
|
exprWithType = if type' == "null" && isObjectType type'' then TypedExpression type'' NullLiteral else expr'
|
||||||
in
|
in
|
||||||
if type'' == type' || (type' == "null" && isObjectType type'') then
|
if type'' == typeToAssign then
|
||||||
TypedStatementExpression type'' (Assignment ref' expr')
|
TypedStatementExpression type'' (Assignment ref' exprWithType)
|
||||||
else
|
else
|
||||||
error $ "Type mismatch in assignment to variable: expected " ++ type'' ++ ", found " ++ type'
|
error $ "Type mismatch in assignment to variable: expected " ++ type'' ++ ", found " ++ typeToAssign
|
||||||
|
|
||||||
|
|
||||||
typeCheckStatementExpression (ConstructorCall className args) symtab classes =
|
typeCheckStatementExpression (ConstructorCall className args) symtab classes =
|
||||||
case find (\(Class name _ _) -> name == className) classes of
|
case find (\(Class name _ _) -> name == className) classes of
|
||||||
Nothing -> error $ "Class '" ++ className ++ "' not found."
|
Nothing -> error $ "Class '" ++ className ++ "' not found."
|
||||||
Just (Class _ methods fields) ->
|
Just (Class _ methods _) ->
|
||||||
-- Find constructor matching the class name with void return type
|
-- Find constructor matching the class name with void return type
|
||||||
case find (\(MethodDeclaration retType name params _) -> name == "<init>" && retType == "void") methods of
|
case find (\(MethodDeclaration _ name params _) -> name == "<init>") methods of
|
||||||
-- If no constructor is found, assume standard constructor with no parameters
|
-- If no constructor is found, assume standard constructor with no parameters
|
||||||
Nothing ->
|
Nothing ->
|
||||||
if null args then
|
if null args then
|
||||||
@ -144,21 +151,28 @@ typeCheckStatementExpression (ConstructorCall className args) symtab classes =
|
|||||||
else
|
else
|
||||||
error $ "No valid constructor found for class '" ++ className ++ "', but arguments were provided."
|
error $ "No valid constructor found for class '" ++ className ++ "', but arguments were provided."
|
||||||
Just (MethodDeclaration _ _ params _) ->
|
Just (MethodDeclaration _ _ params _) ->
|
||||||
let
|
let args' = zipWith
|
||||||
args' = map (\arg -> typeCheckExpression arg symtab classes) args
|
(\arg (ParameterDeclaration paramType _) ->
|
||||||
-- Extract expected parameter types from the constructor's parameters
|
let argTyped = typeCheckExpression arg symtab classes
|
||||||
|
in if getTypeFromExpr argTyped == "null" && isObjectType paramType
|
||||||
|
then TypedExpression paramType NullLiteral
|
||||||
|
else argTyped
|
||||||
|
) args params
|
||||||
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
|
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
|
||||||
argTypes = map getTypeFromExpr args'
|
argTypes = map getTypeFromExpr args'
|
||||||
-- Check if the types of the provided arguments match the expected types
|
typeMatches = zipWith
|
||||||
typeMatches = zipWith (\expected actual -> if expected == actual then Nothing else Just (expected, actual)) expectedTypes argTypes
|
(\expType argType -> (expType == argType || (argType == "null" && isObjectType expType), expType, argType))
|
||||||
mismatchErrors = map (\(exp, act) -> "Expected type '" ++ exp ++ "', found '" ++ act ++ "'.") (catMaybes typeMatches)
|
expectedTypes argTypes
|
||||||
|
mismatches = filter (not . fst3) typeMatches
|
||||||
|
fst3 (a, _, _) = a
|
||||||
in
|
in
|
||||||
if length args /= length params then
|
if null mismatches && length args == length params then
|
||||||
error $ "Constructor for class '" ++ className ++ "' expects " ++ show (length params) ++ " arguments, but got " ++ show (length args) ++ "."
|
|
||||||
else if not (null mismatchErrors) then
|
|
||||||
error $ unlines $ ("Type mismatch in constructor arguments for class '" ++ className ++ "':") : mismatchErrors
|
|
||||||
else
|
|
||||||
TypedStatementExpression className (ConstructorCall className args')
|
TypedStatementExpression className (ConstructorCall className args')
|
||||||
|
else if not (null mismatches) then
|
||||||
|
error $ unlines $ ("Type mismatch in constructor arguments for class '" ++ className ++ "':")
|
||||||
|
: [ "Expected: " ++ expType ++ ", Found: " ++ argType | (_, expType, argType) <- mismatches ]
|
||||||
|
else
|
||||||
|
error $ "Incorrect number of arguments for constructor of class '" ++ className ++ "'. Expected " ++ show (length expectedTypes) ++ ", found " ++ show (length args) ++ "."
|
||||||
|
|
||||||
typeCheckStatementExpression (MethodCall expr methodName args) symtab classes =
|
typeCheckStatementExpression (MethodCall expr methodName args) symtab classes =
|
||||||
let objExprTyped = typeCheckExpression expr symtab classes
|
let objExprTyped = typeCheckExpression expr symtab classes
|
||||||
@ -168,20 +182,26 @@ typeCheckStatementExpression (MethodCall expr methodName args) symtab classes =
|
|||||||
Just (Class _ methods _) ->
|
Just (Class _ methods _) ->
|
||||||
case find (\(MethodDeclaration retType name params _) -> name == methodName) methods of
|
case find (\(MethodDeclaration retType name params _) -> name == methodName) methods of
|
||||||
Just (MethodDeclaration retType _ params _) ->
|
Just (MethodDeclaration retType _ params _) ->
|
||||||
let args' = map (\arg -> typeCheckExpression arg symtab classes) args
|
let args' = zipWith
|
||||||
|
(\arg (ParameterDeclaration paramType _) ->
|
||||||
|
let argTyped = typeCheckExpression arg symtab classes
|
||||||
|
in if getTypeFromExpr argTyped == "null" && isObjectType paramType
|
||||||
|
then TypedExpression paramType NullLiteral
|
||||||
|
else argTyped
|
||||||
|
) args params
|
||||||
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
|
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
|
||||||
argTypes = map getTypeFromExpr args'
|
argTypes = map getTypeFromExpr args'
|
||||||
typeMatches = zipWith (\expType argType -> (expType == argType, expType, argType)) expectedTypes argTypes
|
typeMatches = zipWith
|
||||||
|
(\expType argType -> (expType == argType || (argType == "null" && isObjectType expType), expType, argType))
|
||||||
|
expectedTypes argTypes
|
||||||
mismatches = filter (not . fst3) typeMatches
|
mismatches = filter (not . fst3) typeMatches
|
||||||
where fst3 (a, _, _) = a
|
fst3 (a, _, _) = a
|
||||||
in
|
in if null mismatches && length args == length params
|
||||||
if null mismatches && length args == length params then
|
then TypedStatementExpression retType (MethodCall objExprTyped methodName args')
|
||||||
TypedStatementExpression retType (MethodCall objExprTyped methodName args')
|
else if not (null mismatches)
|
||||||
else if not (null mismatches) then
|
then error $ unlines $ ("Argument type mismatches for method '" ++ methodName ++ "':")
|
||||||
error $ unlines $ ("Argument type mismatches for method '" ++ methodName ++ "':")
|
|
||||||
: [ "Expected: " ++ expType ++ ", Found: " ++ argType | (_, expType, argType) <- mismatches ]
|
: [ "Expected: " ++ expType ++ ", Found: " ++ argType | (_, expType, argType) <- mismatches ]
|
||||||
else
|
else error $ "Incorrect number of arguments for method '" ++ methodName ++ "'. Expected " ++ show (length expectedTypes) ++ ", found " ++ show (length args) ++ "."
|
||||||
error $ "Incorrect number of arguments for method '" ++ methodName ++ "'. Expected " ++ show (length expectedTypes) ++ ", found " ++ show (length args) ++ "."
|
|
||||||
Nothing -> error $ "Method '" ++ methodName ++ "' not found in class '" ++ objType ++ "'."
|
Nothing -> error $ "Method '" ++ methodName ++ "' not found in class '" ++ objType ++ "'."
|
||||||
Nothing -> error $ "Class for object type '" ++ objType ++ "' not found."
|
Nothing -> error $ "Class for object type '" ++ objType ++ "' not found."
|
||||||
_ -> error "Invalid object type for method call. Object must have a class type."
|
_ -> error "Invalid object type for method call. Object must have a class type."
|
||||||
@ -251,14 +271,18 @@ typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType ident
|
|||||||
-- If there's an initializer expression, type check it
|
-- If there's an initializer expression, type check it
|
||||||
let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
|
let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
|
||||||
exprType = fmap getTypeFromExpr checkedExpr
|
exprType = fmap getTypeFromExpr checkedExpr
|
||||||
|
checkedExprWithType = case (exprType, dataType) of
|
||||||
|
(Just "null", _) | isObjectType dataType -> Just (TypedExpression dataType NullLiteral)
|
||||||
|
_ -> checkedExpr
|
||||||
in case exprType of
|
in case exprType of
|
||||||
Just t
|
Just t
|
||||||
| t == "null" && isObjectType dataType ->
|
| t == "null" && isObjectType dataType ->
|
||||||
TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
|
TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExprWithType))
|
||||||
| t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
|
| t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
|
||||||
| otherwise -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
|
| otherwise -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExprWithType))
|
||||||
Nothing -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
|
Nothing -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
|
||||||
|
|
||||||
|
|
||||||
typeCheckStatement (While cond stmt) symtab classes =
|
typeCheckStatement (While cond stmt) symtab classes =
|
||||||
let cond' = typeCheckExpression cond symtab classes
|
let cond' = typeCheckExpression cond symtab classes
|
||||||
stmt' = typeCheckStatement stmt symtab classes
|
stmt' = typeCheckStatement stmt symtab classes
|
||||||
@ -305,13 +329,17 @@ typeCheckStatement (Block statements) symtab classes =
|
|||||||
typeCheckStatement (Return expr) symtab classes =
|
typeCheckStatement (Return expr) symtab classes =
|
||||||
let methodReturnType = fromMaybe (error "Method return type not found in symbol table") (lookup "thisMeth" symtab)
|
let methodReturnType = fromMaybe (error "Method return type not found in symbol table") (lookup "thisMeth" symtab)
|
||||||
expr' = case expr of
|
expr' = case expr of
|
||||||
Just e -> Just (typeCheckExpression e symtab classes)
|
Just e -> let eTyped = typeCheckExpression e symtab classes
|
||||||
|
in if getTypeFromExpr eTyped == "null" && isObjectType methodReturnType
|
||||||
|
then Just (TypedExpression methodReturnType NullLiteral)
|
||||||
|
else Just eTyped
|
||||||
Nothing -> Nothing
|
Nothing -> Nothing
|
||||||
returnType = maybe "void" getTypeFromExpr expr'
|
returnType = maybe "void" getTypeFromExpr expr'
|
||||||
in if returnType == methodReturnType || isSubtype returnType methodReturnType classes
|
in if returnType == methodReturnType || isSubtype returnType methodReturnType classes
|
||||||
then TypedStatement returnType (Return expr')
|
then TypedStatement returnType (Return expr')
|
||||||
else error $ "Return: Return type mismatch: expected " ++ methodReturnType ++ ", found " ++ returnType
|
else error $ "Return: Return type mismatch: expected " ++ methodReturnType ++ ", found " ++ returnType
|
||||||
|
|
||||||
|
|
||||||
typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes =
|
typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes =
|
||||||
let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes
|
let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes
|
||||||
in TypedStatement (getTypeFromStmtExpr stmtExpr') (StatementExpressionStatement stmtExpr')
|
in TypedStatement (getTypeFromStmtExpr stmtExpr') (StatementExpressionStatement stmtExpr')
|
||||||
|
Loading…
Reference in New Issue
Block a user