Compare commits

..

49 Commits

Author SHA1 Message Date
8de0309107 Merge branch 'create-parser' of https://gitea.hb.dhbw-stuttgart.de/MisterChaos69/MiniJavaCompiler into create-parser 2024-05-08 15:51:40 +02:00
90fa658c8f parser implement varibale initialization 2024-05-08 15:36:05 +02:00
ebf54bf4cb parser implement int initializiation 2024-05-08 15:36:05 +02:00
b957f512ad parser add emptystatement 2024-05-08 15:36:05 +02:00
b82e205bcd parser implement nested blocks 2024-05-08 15:36:04 +02:00
98321fa162 parser implement multiple blockstatements 2024-05-08 15:36:04 +02:00
58367a95e6 parser implement localvardeclaration 2024-05-08 15:36:04 +02:00
ea18431b77 parser add constructor 2024-05-08 15:36:04 +02:00
9e6f31479f parser improve error messages 2024-05-08 15:36:04 +02:00
f183b8b183 parser implement methods with params 2024-05-08 15:36:04 +02:00
3130f7f7d4 add void methods 2024-05-08 15:36:04 +02:00
f57f360abf parser add empty methods 2024-05-08 15:36:04 +02:00
5e90f0d0ee parser add modifier 2024-05-08 15:36:04 +02:00
301b87c9ac parser implement multiple declarations 2024-05-08 15:36:04 +02:00
decc909c23 parser add custom type fields 2024-05-08 15:36:04 +02:00
mrab
7cbf8aad08 Merge branch 'create-parser' of ssh://gitea.hb.dhbw-stuttgart.de:2222/MisterChaos69/MiniJavaCompiler into bytecode 2024-05-08 11:04:43 +02:00
mrab
93702071fb Merge branch 'typedAST' of ssh://gitea.hb.dhbw-stuttgart.de:2222/MisterChaos69/MiniJavaCompiler into bytecode 2024-05-08 10:57:21 +02:00
7bfdeed620 add local and field var 2024-05-08 10:26:19 +02:00
mrab
176b98d659 expression assembler 2024-05-07 21:18:01 +02:00
ced5d1df9c add program tests 2024-05-07 20:10:39 +02:00
acab0add95 added methodcall and assigment test 2024-05-07 16:42:01 +02:00
mrab
c0d48b5274 Merge branch 'bytecode' of ssh://gitea.hb.dhbw-stuttgart.de:2222/MisterChaos69/MiniJavaCompiler into bytecode 2024-05-07 16:38:28 +02:00
b040130569 removed unneccesary comments 2024-05-07 16:36:42 +02:00
e723e9cb18 Merge branch 'bytecode' of https://gitea.hb.dhbw-stuttgart.de/MisterChaos69/MiniJavaCompiler into bytecode 2024-05-07 16:35:49 +02:00
d47c7f5a45 added tests 2024-05-07 16:34:43 +02:00
8c3c3625b9 fixed method and name resolution typecheck 2024-05-07 16:34:34 +02:00
f9df24f456 Merge branch 'bytecode' of https://gitea.hb.dhbw-stuttgart.de/MisterChaos69/MiniJavaCompiler into bytecode 2024-05-07 16:34:22 +02:00
b7d8f19433 add: MethodBuilder 2024-05-07 16:32:36 +02:00
mrab
8b5650dd61 code serialization, minimal opcodes 2024-05-07 15:46:08 +02:00
068b97e0e7 add: constantPoolBuilding 2024-05-07 14:27:20 +02:00
c920a3dfb6 add: Classbuilder & ClassFileBuilder 2024-05-07 11:54:54 +02:00
358293d4d4 add typecheck method and constructor call 2024-05-07 11:28:35 +02:00
7422510267 add expression to methodcall 2024-05-07 11:09:10 +02:00
40c7cab0e3 apply pull and add tests 2024-05-07 11:04:40 +02:00
mrab
f988c80b83 Merge branch 'typedAST' of ssh://gitea.hb.dhbw-stuttgart.de:2222/MisterChaos69/MiniJavaCompiler into bytecode 2024-05-07 09:59:01 +02:00
mrab
ff03be2671 moved type check to its own file 2024-05-07 09:53:16 +02:00
mrab
c248dc3676 Merge branch 'typedAST' of ssh://gitea.hb.dhbw-stuttgart.de:2222/MisterChaos69/MiniJavaCompiler into bytecode 2024-05-07 09:49:32 +02:00
75272498c4 fixed typechecker trying to find a field of a class in the symboltable 2024-05-07 08:20:10 +02:00
8d9d28cb1f added means of testing 2024-05-06 23:15:22 +02:00
f14470d8cc add name resolution, fix symbol table key, value switchup 2024-05-06 23:15:10 +02:00
ae6232380e Fix type mismatch and redefinition issues in typeCheckStatement function 2024-05-06 19:52:06 +02:00
0debfe995d add full block type checking 2024-05-06 19:13:46 +02:00
62bd191314 update block and local variable typecheck 2024-05-06 10:49:33 +02:00
dbdd2f85bd inital commit with first version 2024-05-05 10:00:11 +02:00
mrab
f93a1ba58c working export of binary class files 2024-05-04 19:08:41 +02:00
mrab
967c3a4b41 Moved bytecode stuff into separate folder 2024-05-04 16:28:42 +02:00
mrab
e2f978ca93 serialization of class files 2024-05-04 15:41:05 +02:00
f4c70d4d30 add:Serialization Framework 2024-05-04 13:02:35 +02:00
mrab
580e56ebbb Class file data structure 2024-05-04 11:39:44 +02:00
13 changed files with 1130 additions and 64 deletions

6
.gitignore vendored
View File

@ -8,6 +8,12 @@ cabal-dev
*.chs.h *.chs.h
*.dyn_o *.dyn_o
*.dyn_hi *.dyn_hi
*.java
*.class
*.local~*
src/Parser/JavaParser.hs
src/Parser/Parser.hs
src/Parser/Lexer.hs
.hpc .hpc
.hsenv .hsenv
.cabal-sandbox/ .cabal-sandbox/

View File

@ -0,0 +1,116 @@
module TestByteCodeGenerator where
import Test.HUnit
import ByteCode.ClassFile.Generator
import ByteCode.ClassFile
import ByteCode.Constants
import Ast
nakedClass = Class "Testklasse" [] []
expectedClass = ClassFile {
constantPool = [
ClassInfo 4,
MethodRefInfo 1 3,
NameAndTypeInfo 5 6,
Utf8Info "java/lang/Object",
Utf8Info "<init>",
Utf8Info "()V",
Utf8Info "Code",
ClassInfo 9,
Utf8Info "Testklasse"
],
accessFlags = accessPublic,
thisClass = 8,
superClass = 1,
fields = [],
methods = [],
attributes = []
}
classWithFields = Class "Testklasse" [] [VariableDeclaration "int" "testvariable" Nothing]
expectedClassWithFields = ClassFile {
constantPool = [
ClassInfo 4,
MethodRefInfo 1 3,
NameAndTypeInfo 5 6,
Utf8Info "java/lang/Object",
Utf8Info "<init>",
Utf8Info "()V",
Utf8Info "Code",
ClassInfo 9,
Utf8Info "Testklasse",
FieldRefInfo 8 11,
NameAndTypeInfo 12 13,
Utf8Info "testvariable",
Utf8Info "I"
],
accessFlags = accessPublic,
thisClass = 8,
superClass = 1,
fields = [
MemberInfo {
memberAccessFlags = accessPublic,
memberNameIndex = 12,
memberDescriptorIndex = 13,
memberAttributes = []
}
],
methods = [],
attributes = []
}
method = MethodDeclaration "int" "add_two_numbers" [
ParameterDeclaration "int" "a",
ParameterDeclaration "int" "b" ]
(Block [Return (Just (BinaryOperation Addition (Reference "a") (Reference "b")))])
classWithMethod = Class "Testklasse" [method] []
expectedClassWithMethod = ClassFile {
constantPool = [
ClassInfo 4,
MethodRefInfo 1 3,
NameAndTypeInfo 5 6,
Utf8Info "java/lang/Object",
Utf8Info "<init>",
Utf8Info "()V",
Utf8Info "Code",
ClassInfo 9,
Utf8Info "Testklasse",
FieldRefInfo 8 11,
NameAndTypeInfo 12 13,
Utf8Info "add_two_numbers",
Utf8Info "(II)I"
],
accessFlags = accessPublic,
thisClass = 8,
superClass = 1,
fields = [],
methods = [
MemberInfo {
memberAccessFlags = accessPublic,
memberNameIndex = 12,
memberDescriptorIndex = 13,
memberAttributes = [
CodeAttribute {
attributeMaxStack = 420,
attributeMaxLocals = 420,
attributeCode = [Opiadd]
}
]
}
],
attributes = []
}
testBasicConstantPool = TestCase $ assertEqual "basic constant pool" expectedClass $ classBuilder nakedClass emptyClassFile
testFields = TestCase $ assertEqual "fields in constant pool" expectedClassWithFields $ classBuilder classWithFields emptyClassFile
testMethodDescriptor = TestCase $ assertEqual "method descriptor" "(II)I" (methodDescriptor method)
testMethodAssembly = TestCase $ assertEqual "method assembly" expectedClassWithMethod (classBuilder classWithMethod emptyClassFile)
tests = TestList [
TestLabel "Basic constant pool" testBasicConstantPool,
TestLabel "Fields constant pool" testFields,
TestLabel "Method descriptor building" testMethodDescriptor,
TestLabel "Method assembly" testMethodAssembly
]

View File

@ -2,12 +2,13 @@ module Main where
import Test.HUnit import Test.HUnit
import TestLexer import TestLexer
import TestByteCodeGenerator
import TestParser import TestParser
tests = TestList [ tests = TestList [
TestLabel "TestLexer" TestLexer.tests, TestLabel "TestLexer" TestLexer.tests,
TestLabel "TestParser" TestParser.tests TestLabel "TestParser" TestParser.tests,
] TestLabel "TestByteCodeGenerator" TestByteCodeGenerator.tests]
main = do runTestTTAndExit Main.tests main = do runTestTTAndExit Main.tests

View File

@ -8,11 +8,23 @@ executable compiler
main-is: Main.hs main-is: Main.hs
build-depends: base, build-depends: base,
array, array,
HUnit HUnit,
utf8-string,
bytestring
default-language: Haskell2010 default-language: Haskell2010
hs-source-dirs: src hs-source-dirs: src,
src/ByteCode,
src/ByteCode/ClassFile
build-tool-depends: alex:alex, happy:happy build-tool-depends: alex:alex, happy:happy
other-modules: Parser.Lexer, Parser.JavaParser, Ast other-modules: Parser.Lexer,
Parser.JavaParser
Ast,
Example,
Typecheck,
ByteCode.ByteUtil,
ByteCode.ClassFile,
ByteCode.ClassFile.Generator,
ByteCode.Constants
test-suite tests test-suite tests
type: exitcode-stdio-1.0 type: exitcode-stdio-1.0
@ -20,10 +32,17 @@ test-suite tests
hs-source-dirs: src,Test hs-source-dirs: src,Test
build-depends: base, build-depends: base,
array, array,
HUnit HUnit,
utf8-string,
bytestring
build-tool-depends: alex:alex, happy:happy build-tool-depends: alex:alex, happy:happy
other-modules: Parser.Lexer, other-modules: Parser.Lexer,
Parser.JavaParser, Parser.JavaParser,
Ast, Ast,
TestLexer TestLexer,
TestParser TestParser,
TestByteCodeGenerator,
ByteCode.ByteUtil,
ByteCode.ClassFile,
ByteCode.ClassFile.Generator,
ByteCode.Constants

2
questions.md Normal file
View File

@ -0,0 +1,2 @@
# Questions
- Enum?

View File

@ -4,16 +4,13 @@ type CompilationUnit = [Class]
type DataType = String type DataType = String
type Identifier = String type Identifier = String
data ParameterDeclaration = ParameterDeclaration DataType Identifier data ParameterDeclaration = ParameterDeclaration DataType Identifier deriving (Show, Eq)
deriving(Show, Eq) data VariableDeclaration = VariableDeclaration DataType Identifier (Maybe Expression) deriving (Show, Eq)
data VariableDeclaration = VariableDeclaration DataType Identifier (Maybe Expression) data Class = Class DataType [MethodDeclaration] [VariableDeclaration] deriving (Show, Eq)
deriving(Show, Eq) data MethodDeclaration = MethodDeclaration DataType Identifier [ParameterDeclaration] Statement deriving (Show, Eq)
data Class = Class DataType [MethodDeclaration] [VariableDeclaration]
deriving(Show, Eq)
data MethodDeclaration = MethodDeclaration DataType Identifier [ParameterDeclaration] Statement
deriving(Show, Eq)
data Statement = If Expression Statement (Maybe Statement) data Statement
= If Expression Statement (Maybe Statement)
| LocalVariableDeclaration VariableDeclaration | LocalVariableDeclaration VariableDeclaration
| While Expression Statement | While Expression Statement
| Block [Statement] | Block [Statement]
@ -22,13 +19,15 @@ data Statement = If Expression Statement (Maybe Statement)
| TypedStatement DataType Statement | TypedStatement DataType Statement
deriving (Show, Eq) deriving (Show, Eq)
data StatementExpression = Assignment Identifier Expression data StatementExpression
= Assignment Identifier Expression
| ConstructorCall DataType [Expression] | ConstructorCall DataType [Expression]
| MethodCall Identifier [Expression] | MethodCall Expression Identifier [Expression]
| TypedStatementExpression DataType StatementExpression | TypedStatementExpression DataType StatementExpression
deriving (Show, Eq) deriving (Show, Eq)
data BinaryOperator = Addition data BinaryOperator
= Addition
| Subtraction | Subtraction
| Multiplication | Multiplication
| Division | Division
@ -46,18 +45,21 @@ data BinaryOperator = Addition
| NameResolution | NameResolution
deriving (Show, Eq) deriving (Show, Eq)
data UnaryOperator = Not data UnaryOperator
= Not
| Minus | Minus
deriving (Show, Eq) deriving (Show, Eq)
data Expression = IntegerLiteral Int data Expression
= IntegerLiteral Int
| CharacterLiteral Char | CharacterLiteral Char
| BooleanLiteral Bool | BooleanLiteral Bool
| NullLiteral | NullLiteral
| Reference Identifier | Reference Identifier
| LocalVariable Identifier
| FieldVariable Identifier
| BinaryOperation BinaryOperator Expression Expression | BinaryOperation BinaryOperator Expression Expression
| UnaryOperation UnaryOperator Expression | UnaryOperation UnaryOperator Expression
| StatementExpressionExpression StatementExpression | StatementExpressionExpression StatementExpression
| TypedExpression DataType Expression | TypedExpression DataType Expression
deriving (Show, Eq) deriving (Show, Eq)

19
src/ByteCode/ByteUtil.hs Normal file
View File

@ -0,0 +1,19 @@
module ByteCode.ByteUtil(unpackWord16, unpackWord32) where
import Data.Word ( Word8, Word16, Word32 )
import Data.Int
import Data.Bits
unpackWord16 :: Word16 -> [Word8]
unpackWord16 v = [
fromIntegral (shiftR ((.&.) v 0xFF00) 8),
fromIntegral (shiftR ((.&.) v 0x00FF) 0)
]
unpackWord32 :: Word32 -> [Word8]
unpackWord32 v = [
fromIntegral (shiftR ((.&.) v 0xFF000000) 24),
fromIntegral (shiftR ((.&.) v 0x00FF0000) 16),
fromIntegral (shiftR ((.&.) v 0x0000FF00) 8),
fromIntegral (shiftR ((.&.) v 0x000000FF) 0)
]

168
src/ByteCode/ClassFile.hs Normal file
View File

@ -0,0 +1,168 @@
module ByteCode.ClassFile(
ConstantInfo(..),
Attribute(..),
MemberInfo(..),
ClassFile(..),
Operation(..),
serialize,
emptyClassFile
) where
import Data.Word
import Data.Int
import Data.ByteString (unpack)
import Data.ByteString.UTF8 (fromString)
import ByteCode.ByteUtil
import ByteCode.Constants
data ConstantInfo = ClassInfo Word16
| FieldRefInfo Word16 Word16
| MethodRefInfo Word16 Word16
| NameAndTypeInfo Word16 Word16
| IntegerInfo Int32
| Utf8Info [Char]
deriving (Show, Eq)
data Operation = Opiadd
| Opisub
| Opimul
| Opidiv
| Opiand
| Opior
| Opixor
| Opineg
| Opif_icmplt Word16
| Opif_icmple Word16
| Opif_icmpgt Word16
| Opif_icmpge Word16
| Opif_icmpeq Word16
| Opif_icmpne Word16
| Opaconst_null
| Opreturn
| Opireturn
| Opareturn
| Opsipush Word16
| Opldc_w Word16
| Opaload Word16
| Opiload Word16
| Opastore Word16
| Opistore Word16
| Opputfield Word16
| OpgetField Word16
deriving (Show, Eq)
data Attribute = CodeAttribute {
attributeMaxStack :: Word16,
attributeMaxLocals :: Word16,
attributeCode :: [Operation]
} deriving (Show, Eq)
data MemberInfo = MemberInfo {
memberAccessFlags :: Word16,
memberNameIndex :: Word16,
memberDescriptorIndex :: Word16,
memberAttributes :: [Attribute]
} deriving (Show, Eq)
data ClassFile = ClassFile {
constantPool :: [ConstantInfo],
accessFlags :: Word16,
thisClass :: Word16,
superClass :: Word16,
fields :: [MemberInfo],
methods :: [MemberInfo],
attributes :: [Attribute]
} deriving (Show, Eq)
emptyClassFile :: ClassFile
emptyClassFile = ClassFile {
constantPool = [],
accessFlags = accessPublic,
thisClass = 0,
superClass = 0,
fields = [],
methods = [],
attributes = []
}
class Serializable a where
serialize :: a -> [Word8]
instance Serializable ConstantInfo where
serialize (ClassInfo nameIndex) = tagClass : unpackWord16 nameIndex
serialize (FieldRefInfo classIndex nameAndTypeIndex) = tagFieldref : (unpackWord16 classIndex ++ unpackWord16 nameAndTypeIndex)
serialize (MethodRefInfo classIndex nameAndTypeIndex) = tagMethodref : (unpackWord16 classIndex ++ unpackWord16 nameAndTypeIndex)
serialize (NameAndTypeInfo classIndex descriptorIndex) = tagNameandtype : (unpackWord16 classIndex ++ unpackWord16 descriptorIndex)
serialize (IntegerInfo value) = tagInteger : unpackWord32 (fromIntegral value)
serialize (Utf8Info string) = tagUtf8 : unpackWord16 num_bytes ++ bytes where
bytes = unpack (fromString string)
num_bytes = fromIntegral $ length bytes
instance Serializable MemberInfo where
serialize member = unpackWord16 (memberAccessFlags member)
++ unpackWord16 (memberNameIndex member)
++ unpackWord16 (memberDescriptorIndex member)
++ unpackWord16 (fromIntegral (length (memberAttributes member)))
++ concatMap serialize (memberAttributes member)
instance Serializable Operation where
serialize Opiadd = [0x60]
serialize Opisub = [0x64]
serialize Opimul = [0x68]
serialize Opidiv = [0x6C]
serialize Opiand = [0x7E]
serialize Opior = [0x80]
serialize Opixor = [0x82]
serialize Opineg = [0x74]
serialize (Opif_icmplt branch) = 0xA1 : unpackWord16 branch
serialize (Opif_icmple branch) = 0xA4 : unpackWord16 branch
serialize (Opif_icmpgt branch) = 0xA3 : unpackWord16 branch
serialize (Opif_icmpge branch) = 0xA2 : unpackWord16 branch
serialize (Opif_icmpeq branch) = 0x9F : unpackWord16 branch
serialize (Opif_icmpne branch) = 0xA0 : unpackWord16 branch
serialize Opaconst_null = [0x01]
serialize Opreturn = [0xB1]
serialize Opireturn = [0xAC]
serialize Opareturn = [0xB0]
serialize (Opsipush index) = 0x11 : unpackWord16 index
serialize (Opldc_w index) = 0x13 : unpackWord16 index
serialize (Opaload index) = [0xC4, 0x19] ++ unpackWord16 index
serialize (Opiload index) = [0xC4, 0x15] ++ unpackWord16 index
serialize (Opastore index) = [0xC4, 0x3A] ++ unpackWord16 index
serialize (Opistore index) = [0xC4, 0x36] ++ unpackWord16 index
serialize (Opputfield index) = 0xB5 : unpackWord16 index
serialize (OpgetField index) = 0xB4 : unpackWord16 index
instance Serializable Attribute where
serialize (CodeAttribute { attributeMaxStack = maxStack,
attributeMaxLocals = maxLocals,
attributeCode = code }) = let
assembledCode = concat (map serialize code)
in
unpackWord16 7 -- attribute_name_index
++ unpackWord32 (12 + (fromIntegral (length assembledCode))) -- attribute_length
++ unpackWord16 maxStack -- max_stack
++ unpackWord16 maxLocals -- max_locals
++ unpackWord32 (fromIntegral (length assembledCode)) -- code_length
++ assembledCode -- code
++ unpackWord16 0 -- exception_table_length
++ unpackWord16 0 -- attributes_count
instance Serializable ClassFile where
serialize classfile = unpackWord32 0xC0FEBABE -- magic
++ unpackWord16 0 -- minor version
++ unpackWord16 49 -- major version
++ unpackWord16 (fromIntegral (1 + length (constantPool classfile))) -- constant pool count
++ concatMap serialize (constantPool classfile) -- constant pool
++ unpackWord16 (accessFlags classfile) -- access flags
++ unpackWord16 (thisClass classfile) -- this class
++ unpackWord16 (superClass classfile) -- super class
++ unpackWord16 0 -- interface count
++ unpackWord16 (fromIntegral (length (fields classfile))) -- fields count
++ concatMap serialize (fields classfile) -- fields info
++ unpackWord16 (fromIntegral (length (methods classfile))) -- methods count
++ concatMap serialize (methods classfile) -- methods info
++ unpackWord16 (fromIntegral (length (attributes classfile))) -- attributes count
++ concatMap serialize (attributes classfile) -- attributes info

View File

@ -0,0 +1,169 @@
module ByteCode.ClassFile.Generator(
classBuilder,
datatypeDescriptor,
methodParameterDescriptor,
methodDescriptor,
) where
import ByteCode.Constants
import ByteCode.ClassFile (ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..))
import Ast
import Data.Char
type ClassFileBuilder a = a -> ClassFile -> ClassFile
datatypeDescriptor :: String -> String
datatypeDescriptor "void" = "V"
datatypeDescriptor "int" = "I"
datatypeDescriptor "char" = "C"
datatypeDescriptor "boolean" = "B"
datatypeDescriptor x = "L" ++ x
methodParameterDescriptor :: String -> String
methodParameterDescriptor "void" = "V"
methodParameterDescriptor "int" = "I"
methodParameterDescriptor "char" = "C"
methodParameterDescriptor "boolean" = "B"
methodParameterDescriptor x = "L" ++ x ++ ";"
methodDescriptor :: MethodDeclaration -> String
methodDescriptor (MethodDeclaration returntype _ parameters _) = let
parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters]
in
"("
++ (concat (map methodParameterDescriptor parameter_types))
++ ")"
++ datatypeDescriptor returntype
classBuilder :: ClassFileBuilder Class
classBuilder (Class name methods fields) _ = let
baseConstants = [
ClassInfo 4,
MethodRefInfo 1 3,
NameAndTypeInfo 5 6,
Utf8Info "java/lang/Object",
Utf8Info "<init>",
Utf8Info "()V",
Utf8Info "Code"
]
nameConstants = [ClassInfo 9, Utf8Info name]
nakedClassFile = ClassFile {
constantPool = baseConstants ++ nameConstants,
accessFlags = accessPublic,
thisClass = 8,
superClass = 1,
fields = [],
methods = [],
attributes = []
}
in
foldr methodBuilder (foldr fieldBuilder nakedClassFile fields) methods
fieldBuilder :: ClassFileBuilder VariableDeclaration
fieldBuilder (VariableDeclaration datatype name _) input = let
baseIndex = 1 + length (constantPool input)
constants = [
FieldRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)),
NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)),
Utf8Info name,
Utf8Info (datatypeDescriptor datatype)
]
field = MemberInfo {
memberAccessFlags = accessPublic,
memberNameIndex = (fromIntegral (baseIndex + 2)),
memberDescriptorIndex = (fromIntegral (baseIndex + 3)),
memberAttributes = []
}
in
input {
constantPool = (constantPool input) ++ constants,
fields = (fields input) ++ [field]
}
methodBuilder :: ClassFileBuilder MethodDeclaration
methodBuilder (MethodDeclaration returntype name parameters statement) input = let
baseIndex = 1 + length (constantPool input)
constants = [
FieldRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)),
NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)),
Utf8Info name,
Utf8Info (methodDescriptor (MethodDeclaration returntype name parameters (Block [])))
]
--code = assembleByteCode statement
method = MemberInfo {
memberAccessFlags = accessPublic,
memberNameIndex = (fromIntegral (baseIndex + 2)),
memberDescriptorIndex = (fromIntegral (baseIndex + 3)),
memberAttributes = [
CodeAttribute {
attributeMaxStack = 420,
attributeMaxLocals = 420,
attributeCode = [Opiadd]
}
]
}
in
input {
constantPool = (constantPool input) ++ constants,
methods = (fields input) ++ [method]
}
type Assembler a = a -> ([ConstantInfo], [Operation]) -> ([ConstantInfo], [Operation])
returnOperation :: DataType -> Operation
returnOperation dtype
| elem dtype ["int", "char", "boolean"] = Opireturn
| otherwise = Opareturn
binaryOperation :: BinaryOperator -> Operation
binaryOperation Addition = Opiadd
binaryOperation Subtraction = Opisub
binaryOperation Multiplication = Opimul
binaryOperation Division = Opidiv
binaryOperation BitwiseAnd = Opiand
binaryOperation BitwiseOr = Opior
binaryOperation BitwiseXor = Opixor
assembleMethod :: Assembler MethodDeclaration
assembleMethod (MethodDeclaration _ _ _ (Block statements)) (constants, ops) =
foldr assembleStatement (constants, ops) statements
assembleStatement :: Assembler Statement
assembleStatement (TypedStatement stype (Return expr)) (constants, ops) = case expr of
Nothing -> (constants, ops ++ [Opreturn])
Just expr -> let
(expr_constants, expr_ops) = assembleExpression expr (constants, ops)
in
(expr_constants, expr_ops ++ [returnOperation stype])
assembleExpression :: Assembler Expression
assembleExpression (TypedExpression _ (BinaryOperation op a b)) (constants, ops)
| elem op [Addition, Subtraction, Multiplication, Division, BitwiseAnd, BitwiseOr, BitwiseXor] = let
(aConstants, aOps) = assembleExpression a (constants, ops)
(bConstants, bOps) = assembleExpression b (aConstants, aOps)
in
(bConstants, bOps ++ [binaryOperation op])
assembleExpression (TypedExpression _ (CharacterLiteral literal)) (constants, ops) =
(constants, ops ++ [Opsipush (fromIntegral (ord literal))])
assembleExpression (TypedExpression _ (BooleanLiteral literal)) (constants, ops) =
(constants, ops ++ [Opsipush (if literal then 1 else 0)])
assembleExpression (TypedExpression _ (IntegerLiteral literal)) (constants, ops)
| literal <= 32767 && literal >= -32768 = (constants, ops ++ [Opsipush (fromIntegral literal)])
| otherwise = (constants ++ [IntegerInfo (fromIntegral literal)], ops ++ [Opldc_w (fromIntegral (1 + length constants))])
assembleExpression (TypedExpression _ NullLiteral) (constants, ops) =
(constants, ops ++ [Opaconst_null])
assembleExpression (TypedExpression etype (UnaryOperation Not expr)) (constants, ops) = let
(exprConstants, exprOps) = assembleExpression expr (constants, ops)
newConstant = fromIntegral (1 + length exprConstants)
in case etype of
"int" -> (exprConstants ++ [IntegerInfo 0x7FFFFFFF], exprOps ++ [Opldc_w newConstant, Opixor])
"char" -> (exprConstants, exprOps ++ [Opsipush 0xFFFF, Opixor])
"boolean" -> (exprConstants, exprOps ++ [Opsipush 0x01, Opixor])
assembleExpression (TypedExpression _ (UnaryOperation Minus expr)) (constants, ops) = let
(exprConstants, exprOps) = assembleExpression expr (constants, ops)
in
(exprConstants, exprOps ++ [Opineg])

25
src/ByteCode/Constants.hs Normal file
View File

@ -0,0 +1,25 @@
module ByteCode.Constants where
import Data.Word
tagClass :: Word8
tagFieldref :: Word8
tagMethodref :: Word8
tagNameandtype :: Word8
tagInteger :: Word8
tagUtf8 :: Word8
accessPublic :: Word16
accessPrivate :: Word16
accessProtected :: Word16
tagClass = 0x07
tagFieldref = 0x09
tagMethodref = 0x0A
tagNameandtype = 0x0C
tagInteger = 0x03
tagUtf8 = 0x01
accessPublic = 0x01
accessPrivate = 0x02
accessProtected = 0x04

203
src/Example.hs Normal file
View File

@ -0,0 +1,203 @@
module Example where
import Ast
import Typecheck
import Control.Exception (catch, evaluate, SomeException, displayException)
import Control.Exception.Base
import System.IO (stderr, hPutStrLn)
import Data.Maybe
import Data.List
green, red, yellow, blue, magenta, cyan, white :: String -> String
green str = "\x1b[32m" ++ str ++ "\x1b[0m"
red str = "\x1b[31m" ++ str ++ "\x1b[0m"
yellow str = "\x1b[33m" ++ str ++ "\x1b[0m"
blue str = "\x1b[34m" ++ str ++ "\x1b[0m"
magenta str = "\x1b[35m" ++ str ++ "\x1b[0m"
cyan str = "\x1b[36m" ++ str ++ "\x1b[0m"
white str = "\x1b[37m" ++ str ++ "\x1b[0m"
printSuccess :: String -> IO ()
printSuccess msg = putStrLn $ green "Success:" ++ white msg
handleError :: SomeException -> IO ()
handleError e = hPutStrLn stderr $ red ("Error: " ++ displayException e)
printResult :: Show a => String -> a -> IO ()
printResult title result = do
putStrLn $ green title
print result
sampleClasses :: [Class]
sampleClasses = [
Class "Person" [
MethodDeclaration "void" "setAge" [ParameterDeclaration "int" "newAge"]
(Block [
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (Reference "newAge")))
]),
MethodDeclaration "int" "getAge" [] (Return (Just (Reference "age"))),
MethodDeclaration "Person" "Person" [ParameterDeclaration "int" "initialAge"] (Block [])
] [
VariableDeclaration "int" "age" (Just (IntegerLiteral 25))
]
]
initialSymtab :: [(DataType, Identifier)]
initialSymtab = []
exampleExpression :: Expression
exampleExpression = BinaryOperation NameResolution (Reference "bob") (Reference "age")
exampleAssignment :: Expression
exampleAssignment = StatementExpressionExpression (Assignment "a" (IntegerLiteral 30))
exampleMethodCall :: Statement
exampleMethodCall = StatementExpressionStatement (MethodCall (Reference "this") "setAge" [IntegerLiteral 30])
exampleConstructorCall :: Statement
exampleConstructorCall = LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30]))))
exampleNameResolution :: Expression
exampleNameResolution = BinaryOperation NameResolution (Reference "b") (Reference "age")
exampleBlockResolution :: Statement
exampleBlockResolution = Block [
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))),
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30])
]
exampleBlockResolutionFail :: Statement
exampleBlockResolutionFail = Block [
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))),
LocalVariableDeclaration (VariableDeclaration "bool" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30])
]
exampleMethodCallAndAssignment :: Statement
exampleMethodCallAndAssignment = Block [
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))),
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]),
LocalVariableDeclaration (VariableDeclaration "int" "a" Nothing),
StatementExpressionStatement (Assignment "a" (Reference "age"))
]
exampleMethodCallAndAssignmentFail :: Statement
exampleMethodCallAndAssignmentFail = Block [
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))),
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]),
StatementExpressionStatement (Assignment "a" (Reference "age"))
]
testClasses :: [Class]
testClasses = [
Class "Person" [
MethodDeclaration "Person" "Person" [ParameterDeclaration "int" "initialAge"]
(Block [
Return (Just (Reference "this"))
]),
MethodDeclaration "void" "setAge" [ParameterDeclaration "int" "newAge"]
(Block [
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (Reference "newAge")))
]),
MethodDeclaration "int" "getAge" []
(Return (Just (Reference "age")))
] [
VariableDeclaration "int" "age" Nothing -- initially unassigned
],
Class "Main" [
MethodDeclaration "int" "main" []
(Block [
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 25])))),
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]),
LocalVariableDeclaration (VariableDeclaration "int" "bobAge" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
Return (Just (Reference "bobAge"))
])
] []
]
runTypeCheck :: IO ()
runTypeCheck = do
catch (do
print "====================================================================================="
evaluatedExpression <- evaluate (typeCheckExpression exampleExpression [("bob", "Person")] sampleClasses)
printSuccess "Type checking of expression completed successfully"
printResult "Result Expression:" evaluatedExpression
) handleError
catch (do
print "====================================================================================="
evaluatedAssignment <- evaluate (typeCheckExpression exampleAssignment [("a", "int")] sampleClasses)
printSuccess "Type checking of assignment completed successfully"
printResult "Result Assignment:" evaluatedAssignment
) handleError
catch (do
print "====================================================================================="
evaluatedMethodCall <- evaluate (typeCheckStatement exampleMethodCall [("this", "Person"), ("setAge", "Person"), ("getAge", "Person")] sampleClasses)
printSuccess "Type checking of method call this completed successfully"
printResult "Result MethodCall:" evaluatedMethodCall
) handleError
catch (do
print "====================================================================================="
evaluatedConstructorCall <- evaluate (typeCheckStatement exampleConstructorCall [] sampleClasses)
printSuccess "Type checking of constructor call completed successfully"
printResult "Result Constructor Call:" evaluatedConstructorCall
) handleError
catch (do
print "====================================================================================="
evaluatedNameResolution <- evaluate (typeCheckExpression exampleNameResolution [("b", "Person")] sampleClasses)
printSuccess "Type checking of name resolution completed successfully"
printResult "Result Name Resolution:" evaluatedNameResolution
) handleError
catch (do
print "====================================================================================="
evaluatedBlockResolution <- evaluate (typeCheckStatement exampleBlockResolution [] sampleClasses)
printSuccess "Type checking of block resolution completed successfully"
printResult "Result Block Resolution:" evaluatedBlockResolution
) handleError
catch (do
print "====================================================================================="
evaluatedBlockResolutionFail <- evaluate (typeCheckStatement exampleBlockResolutionFail [] sampleClasses)
printSuccess "Type checking of block resolution failed"
printResult "Result Block Resolution:" evaluatedBlockResolutionFail
) handleError
catch (do
print "====================================================================================="
evaluatedMethodCallAndAssignment <- evaluate (typeCheckStatement exampleMethodCallAndAssignment [] sampleClasses)
printSuccess "Type checking of method call and assignment completed successfully"
printResult "Result Method Call and Assignment:" evaluatedMethodCallAndAssignment
) handleError
catch (do
print "====================================================================================="
evaluatedMethodCallAndAssignmentFail <- evaluate (typeCheckStatement exampleMethodCallAndAssignmentFail [] sampleClasses)
printSuccess "Type checking of method call and assignment failed"
printResult "Result Method Call and Assignment:" evaluatedMethodCallAndAssignmentFail
) handleError
catch (do
print "====================================================================================="
let mainClass = fromJust $ find (\(Class className _ _) -> className == "Main") testClasses
case mainClass of
Class _ [mainMethod] _ -> do
let result = typeCheckMethodDeclaration mainMethod [] testClasses
printSuccess "Full program type checking completed successfully."
printResult "Main method result:" result
) handleError
catch (do
print "====================================================================================="
let typedProgram = typeCheckCompilationUnit testClasses
printSuccess "Type checking of Program completed successfully"
printResult "Typed Program:" typedProgram
) handleError

View File

@ -1,6 +1,8 @@
module Main where module Main where
import Parser.Lexer import Example
import Typecheck
main = do main = do
print $ alexScanTokens "/**/" Example.runTypeCheck

334
src/Typecheck.hs Normal file
View File

@ -0,0 +1,334 @@
module Typecheck where
import Data.List (find)
import Data.Maybe
import Ast
typeCheckCompilationUnit :: CompilationUnit -> CompilationUnit
typeCheckCompilationUnit classes = map (`typeCheckClass` classes) classes
typeCheckClass :: Class -> [Class] -> Class
typeCheckClass (Class className methods fields) classes =
let
-- Create a symbol table from class fields and method entries
classFields = [(id, dt) | VariableDeclaration dt id _ <- fields]
methodEntries = [(methodName, className) | MethodDeclaration _ methodName _ _ <- methods]
initalSymTab = ("this", className) : classFields ++ methodEntries
checkedMethods = map (\method -> typeCheckMethodDeclaration method initalSymTab classes) methods
in Class className checkedMethods fields
typeCheckMethodDeclaration :: MethodDeclaration -> [(Identifier, DataType)] -> [Class] -> MethodDeclaration
typeCheckMethodDeclaration (MethodDeclaration retType name params body) classFields classes =
let
-- Combine class fields with method parameters to form the initial symbol table for the method
methodParams = [(identifier, dataType) | ParameterDeclaration dataType identifier <- params]
initialSymtab = classFields ++ methodParams
checkedBody = typeCheckStatement body initialSymtab classes
bodyType = getTypeFromStmt checkedBody
-- Check if the type of the body matches the declared return type
in if bodyType == retType || (bodyType == "void" && retType == "void")
then MethodDeclaration retType name params checkedBody
else error $ "Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType
-- ********************************** Type Checking: Expressions **********************************
typeCheckExpression :: Expression -> [(Identifier, DataType)] -> [Class] -> Expression
typeCheckExpression (IntegerLiteral i) _ _ = TypedExpression "int" (IntegerLiteral i)
typeCheckExpression (CharacterLiteral c) _ _ = TypedExpression "char" (CharacterLiteral c)
typeCheckExpression (BooleanLiteral b) _ _ = TypedExpression "boolean" (BooleanLiteral b)
typeCheckExpression NullLiteral _ _ = TypedExpression "null" NullLiteral
typeCheckExpression (Reference id) symtab classes =
let type' = lookupType id symtab
in TypedExpression type' (Reference id)
typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes =
let expr1' = typeCheckExpression expr1 symtab classes
expr2' = typeCheckExpression expr2 symtab classes
type1 = getTypeFromExpr expr1'
type2 = getTypeFromExpr expr2'
in case op of
Addition ->
if type1 == "int" && type2 == "int"
then
TypedExpression "int" (BinaryOperation op expr1' expr2')
else
error "Addition operation requires two operands of type int"
Subtraction ->
if type1 == "int" && type2 == "int"
then
TypedExpression "int" (BinaryOperation op expr1' expr2')
else
error "Subtraction operation requires two operands of type int"
Multiplication ->
if type1 == "int" && type2 == "int"
then
TypedExpression "int" (BinaryOperation op expr1' expr2')
else
error "Multiplication operation requires two operands of type int"
Division ->
if type1 == "int" && type2 == "int"
then
TypedExpression "int" (BinaryOperation op expr1' expr2')
else
error "Division operation requires two operands of type int"
BitwiseAnd ->
if type1 == "int" && type2 == "int"
then
TypedExpression "int" (BinaryOperation op expr1' expr2')
else
error "Bitwise AND operation requires two operands of type int"
BitwiseOr ->
if type1 == "int" && type2 == "int"
then
TypedExpression "int" (BinaryOperation op expr1' expr2')
else
error "Bitwise OR operation requires two operands of type int"
BitwiseXor ->
if type1 == "int" && type2 == "int"
then
TypedExpression "int" (BinaryOperation op expr1' expr2')
else
error "Bitwise XOR operation requires two operands of type int"
CompareLessThan ->
if type1 == "int" && type2 == "int"
then
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
else
error "Less than operation requires two operands of type int"
CompareLessOrEqual ->
if type1 == "int" && type2 == "int"
then
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
else
error "Less than or equal operation requires two operands of type int"
CompareGreaterThan ->
if type1 == "int" && type2 == "int"
then
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
else
error "Greater than operation requires two operands of type int"
CompareGreaterOrEqual ->
if type1 == "int" && type2 == "int"
then
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
else
error "Greater than or equal operation requires two operands of type int"
CompareEqual ->
if type1 == type2
then
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
else
error "Equality operation requires two operands of the same type"
CompareNotEqual ->
if type1 == type2
then
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
else
error "Inequality operation requires two operands of the same type"
And ->
if type1 == "boolean" && type2 == "boolean"
then
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
else
error "Logical AND operation requires two operands of type boolean"
Or ->
if type1 == "boolean" && type2 == "boolean"
then
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
else
error "Logical OR operation requires two operands of type boolean"
NameResolution ->
case (expr1', expr2) of
(TypedExpression t1 (Reference obj), Reference member) ->
let objectType = lookupType obj symtab
classDetails = find (\(Class className _ _) -> className == objectType) classes
in case classDetails of
Just (Class _ _ fields) ->
let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == member]
in case fieldTypes of
[resolvedType] -> TypedExpression resolvedType (BinaryOperation NameResolution expr1' (TypedExpression resolvedType expr2))
[] -> error $ "Field '" ++ member ++ "' not found in class '" ++ objectType ++ "'"
_ -> error $ "Ambiguous reference to field '" ++ member ++ "' in class '" ++ objectType ++ "'"
Nothing -> error $ "Object '" ++ obj ++ "' does not correspond to a known class"
_ -> error "Name resolution requires object reference and field name"
typeCheckExpression (UnaryOperation op expr) symtab classes =
let expr' = typeCheckExpression expr symtab classes
type' = getTypeFromExpr expr'
in case op of
Not ->
if type' == "boolean"
then
TypedExpression "boolean" (UnaryOperation op expr')
else
error "Logical NOT operation requires an operand of type boolean"
Minus ->
if type' == "int"
then
TypedExpression "int" (UnaryOperation op expr')
else
error "Unary minus operation requires an operand of type int"
typeCheckExpression (StatementExpressionExpression stmtExpr) symtab classes =
let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes
in TypedExpression (getTypeFromStmtExpr stmtExpr') (StatementExpressionExpression stmtExpr')
-- ********************************** Type Checking: StatementExpressions **********************************
typeCheckStatementExpression :: StatementExpression -> [(Identifier, DataType)] -> [Class] -> StatementExpression
typeCheckStatementExpression (Assignment id expr) symtab classes =
let expr' = typeCheckExpression expr symtab classes
type' = getTypeFromExpr expr'
type'' = lookupType id symtab
in if type' == type''
then
TypedStatementExpression type' (Assignment id expr')
else
error "Assignment type mismatch"
typeCheckStatementExpression (ConstructorCall className args) symtab classes =
case find (\(Class name _ _) -> name == className) classes of
Nothing -> error $ "Class '" ++ className ++ "' not found."
Just (Class _ methods fields) ->
-- Constructor needs the same name as the class
case find (\(MethodDeclaration retType name params _) -> name == className && retType == className) methods of
Nothing -> error $ "No valid constructor found for class '" ++ className ++ "'."
Just (MethodDeclaration _ _ params _) ->
let
args' = map (\arg -> typeCheckExpression arg symtab classes) args
-- Extract expected parameter types from the constructor's parameters
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
argTypes = map getTypeFromExpr args'
-- Check if the types of the provided arguments match the expected types
typeMatches = zipWith (\expected actual -> if expected == actual then Nothing else Just (expected, actual)) expectedTypes argTypes
mismatchErrors = map (\(exp, act) -> "Expected type '" ++ exp ++ "', found '" ++ act ++ "'.") (catMaybes typeMatches)
in
if length args /= length params then
error $ "Constructor for class '" ++ className ++ "' expects " ++ show (length params) ++ " arguments, but got " ++ show (length args) ++ "."
else if not (null mismatchErrors) then
error $ unlines $ ("Type mismatch in constructor arguments for class '" ++ className ++ "':") : mismatchErrors
else
TypedStatementExpression className (ConstructorCall className args')
typeCheckStatementExpression (MethodCall expr methodName args) symtab classes =
let objExprTyped = typeCheckExpression expr symtab classes
in case objExprTyped of
TypedExpression objType _ ->
case find (\(Class className _ _) -> className == objType) classes of
Just (Class _ methods _) ->
case find (\(MethodDeclaration retType name params _) -> name == methodName) methods of
Just (MethodDeclaration retType _ params _) ->
let args' = map (\arg -> typeCheckExpression arg symtab classes) args
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
argTypes = map getTypeFromExpr args'
typeMatches = zipWith (\expType argType -> (expType == argType, expType, argType)) expectedTypes argTypes
mismatches = filter (not . fst3) typeMatches
where fst3 (a, _, _) = a
in
if null mismatches && length args == length params then
TypedStatementExpression retType (MethodCall objExprTyped methodName args')
else if not (null mismatches) then
error $ unlines $ ("Argument type mismatches for method '" ++ methodName ++ "':")
: [ "Expected: " ++ expType ++ ", Found: " ++ argType | (_, expType, argType) <- mismatches ]
else
error $ "Incorrect number of arguments for method '" ++ methodName ++ "'. Expected " ++ show (length expectedTypes) ++ ", found " ++ show (length args) ++ "."
Nothing -> error $ "Method '" ++ methodName ++ "' not found in class '" ++ objType ++ "'."
Nothing -> error $ "Class for object type '" ++ objType ++ "' not found."
_ -> error "Invalid object type for method call. Object must have a class type."
-- ********************************** Type Checking: Statements **********************************
typeCheckStatement :: Statement -> [(Identifier, DataType)] -> [Class] -> Statement
typeCheckStatement (If cond thenStmt elseStmt) symtab classes =
let cond' = typeCheckExpression cond symtab classes
thenStmt' = typeCheckStatement thenStmt symtab classes
elseStmt' = case elseStmt of
Just stmt -> Just (typeCheckStatement stmt symtab classes)
Nothing -> Nothing
in if getTypeFromExpr cond' == "boolean"
then
TypedStatement (getTypeFromStmt thenStmt') (If cond' thenStmt' elseStmt')
else
error "If condition must be of type boolean"
typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr)) symtab classes =
-- Check for redefinition in the current scope
if any ((== identifier) . snd) symtab
then error $ "Variable '" ++ identifier ++ "' is redefined in the same scope"
else
-- If there's an initializer expression, type check it
let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
exprType = fmap getTypeFromExpr checkedExpr
in case exprType of
Just t | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
_ -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
typeCheckStatement (While cond stmt) symtab classes =
let cond' = typeCheckExpression cond symtab classes
stmt' = typeCheckStatement stmt symtab classes
in if getTypeFromExpr cond' == "boolean"
then
TypedStatement (getTypeFromStmt stmt') (While cond' stmt')
else
error "While condition must be of type boolean"
typeCheckStatement (Block statements) symtab classes =
let
processStatements (accSts, currentSymtab, types) stmt =
let
checkedStmt = typeCheckStatement stmt currentSymtab classes
stmtType = getTypeFromStmt checkedStmt
in case stmt of
LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr) ->
let
checkedExpr = fmap (\expr -> typeCheckExpression expr currentSymtab classes) maybeExpr
newSymtab = (identifier, dataType) : currentSymtab
in (accSts ++ [checkedStmt], newSymtab, types)
If {} -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
While _ _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
Return _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
_ -> (accSts ++ [checkedStmt], currentSymtab, types)
-- Initial accumulator: empty statements list, initial symbol table, empty types list
(checkedStatements, finalSymtab, collectedTypes) = foldl processStatements ([], symtab, []) statements
-- Determine the block's type: unify all collected types, default to "Void" if none
blockType = if null collectedTypes then "void" else foldl1 unifyReturnTypes collectedTypes
in TypedStatement blockType (Block checkedStatements)
typeCheckStatement (Return expr) symtab classes =
let expr' = case expr of
Just e -> Just (typeCheckExpression e symtab classes)
Nothing -> Nothing
in case expr' of
Just e' -> TypedStatement (getTypeFromExpr e') (Return (Just e'))
Nothing -> TypedStatement "Void" (Return Nothing)
typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes =
let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes
in TypedStatement (getTypeFromStmtExpr stmtExpr') (StatementExpressionStatement stmtExpr')
-- ********************************** Type Checking: Helpers **********************************
getTypeFromExpr :: Expression -> DataType
getTypeFromExpr (TypedExpression t _) = t
getTypeFromExpr _ = error "Untyped expression found where typed was expected"
getTypeFromStmt :: Statement -> DataType
getTypeFromStmt (TypedStatement t _) = t
getTypeFromStmt _ = error "Untyped statement found where typed was expected"
getTypeFromStmtExpr :: StatementExpression -> DataType
getTypeFromStmtExpr (TypedStatementExpression t _) = t
getTypeFromStmtExpr _ = error "Untyped statement expression found where typed was expected"
unifyReturnTypes :: DataType -> DataType -> DataType
unifyReturnTypes dt1 dt2
| dt1 == dt2 = dt1
| otherwise = "Object"
lookupType :: Identifier -> [(Identifier, DataType)] -> DataType
lookupType id symtab =
case lookup id symtab of
Just t -> t
Nothing -> error ("Identifier " ++ id ++ " not found in symbol table")