Compare commits
No commits in common. "8de030910703ccb0e178af5eb869c447bbf5107a" and "c49b7f556c055432dc18236311c3b0cc65a09e7a" have entirely different histories.
8de0309107
...
c49b7f556c
6
.gitignore
vendored
6
.gitignore
vendored
@ -8,12 +8,6 @@ cabal-dev
|
||||
*.chs.h
|
||||
*.dyn_o
|
||||
*.dyn_hi
|
||||
*.java
|
||||
*.class
|
||||
*.local~*
|
||||
src/Parser/JavaParser.hs
|
||||
src/Parser/Parser.hs
|
||||
src/Parser/Lexer.hs
|
||||
.hpc
|
||||
.hsenv
|
||||
.cabal-sandbox/
|
||||
|
@ -1,116 +0,0 @@
|
||||
module TestByteCodeGenerator where
|
||||
|
||||
import Test.HUnit
|
||||
import ByteCode.ClassFile.Generator
|
||||
import ByteCode.ClassFile
|
||||
import ByteCode.Constants
|
||||
import Ast
|
||||
|
||||
nakedClass = Class "Testklasse" [] []
|
||||
expectedClass = ClassFile {
|
||||
constantPool = [
|
||||
ClassInfo 4,
|
||||
MethodRefInfo 1 3,
|
||||
NameAndTypeInfo 5 6,
|
||||
Utf8Info "java/lang/Object",
|
||||
Utf8Info "<init>",
|
||||
Utf8Info "()V",
|
||||
Utf8Info "Code",
|
||||
ClassInfo 9,
|
||||
Utf8Info "Testklasse"
|
||||
],
|
||||
accessFlags = accessPublic,
|
||||
thisClass = 8,
|
||||
superClass = 1,
|
||||
fields = [],
|
||||
methods = [],
|
||||
attributes = []
|
||||
}
|
||||
|
||||
classWithFields = Class "Testklasse" [] [VariableDeclaration "int" "testvariable" Nothing]
|
||||
expectedClassWithFields = ClassFile {
|
||||
constantPool = [
|
||||
ClassInfo 4,
|
||||
MethodRefInfo 1 3,
|
||||
NameAndTypeInfo 5 6,
|
||||
Utf8Info "java/lang/Object",
|
||||
Utf8Info "<init>",
|
||||
Utf8Info "()V",
|
||||
Utf8Info "Code",
|
||||
ClassInfo 9,
|
||||
Utf8Info "Testklasse",
|
||||
FieldRefInfo 8 11,
|
||||
NameAndTypeInfo 12 13,
|
||||
Utf8Info "testvariable",
|
||||
Utf8Info "I"
|
||||
],
|
||||
accessFlags = accessPublic,
|
||||
thisClass = 8,
|
||||
superClass = 1,
|
||||
fields = [
|
||||
MemberInfo {
|
||||
memberAccessFlags = accessPublic,
|
||||
memberNameIndex = 12,
|
||||
memberDescriptorIndex = 13,
|
||||
memberAttributes = []
|
||||
}
|
||||
],
|
||||
methods = [],
|
||||
attributes = []
|
||||
}
|
||||
|
||||
method = MethodDeclaration "int" "add_two_numbers" [
|
||||
ParameterDeclaration "int" "a",
|
||||
ParameterDeclaration "int" "b" ]
|
||||
(Block [Return (Just (BinaryOperation Addition (Reference "a") (Reference "b")))])
|
||||
|
||||
|
||||
classWithMethod = Class "Testklasse" [method] []
|
||||
expectedClassWithMethod = ClassFile {
|
||||
constantPool = [
|
||||
ClassInfo 4,
|
||||
MethodRefInfo 1 3,
|
||||
NameAndTypeInfo 5 6,
|
||||
Utf8Info "java/lang/Object",
|
||||
Utf8Info "<init>",
|
||||
Utf8Info "()V",
|
||||
Utf8Info "Code",
|
||||
ClassInfo 9,
|
||||
Utf8Info "Testklasse",
|
||||
FieldRefInfo 8 11,
|
||||
NameAndTypeInfo 12 13,
|
||||
Utf8Info "add_two_numbers",
|
||||
Utf8Info "(II)I"
|
||||
],
|
||||
accessFlags = accessPublic,
|
||||
thisClass = 8,
|
||||
superClass = 1,
|
||||
fields = [],
|
||||
methods = [
|
||||
MemberInfo {
|
||||
memberAccessFlags = accessPublic,
|
||||
memberNameIndex = 12,
|
||||
memberDescriptorIndex = 13,
|
||||
memberAttributes = [
|
||||
CodeAttribute {
|
||||
attributeMaxStack = 420,
|
||||
attributeMaxLocals = 420,
|
||||
attributeCode = [Opiadd]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
attributes = []
|
||||
}
|
||||
|
||||
testBasicConstantPool = TestCase $ assertEqual "basic constant pool" expectedClass $ classBuilder nakedClass emptyClassFile
|
||||
testFields = TestCase $ assertEqual "fields in constant pool" expectedClassWithFields $ classBuilder classWithFields emptyClassFile
|
||||
testMethodDescriptor = TestCase $ assertEqual "method descriptor" "(II)I" (methodDescriptor method)
|
||||
testMethodAssembly = TestCase $ assertEqual "method assembly" expectedClassWithMethod (classBuilder classWithMethod emptyClassFile)
|
||||
|
||||
tests = TestList [
|
||||
TestLabel "Basic constant pool" testBasicConstantPool,
|
||||
TestLabel "Fields constant pool" testFields,
|
||||
TestLabel "Method descriptor building" testMethodDescriptor,
|
||||
TestLabel "Method assembly" testMethodAssembly
|
||||
]
|
@ -2,13 +2,12 @@ module Main where
|
||||
|
||||
import Test.HUnit
|
||||
import TestLexer
|
||||
import TestByteCodeGenerator
|
||||
import TestParser
|
||||
|
||||
|
||||
tests = TestList [
|
||||
TestLabel "TestLexer" TestLexer.tests,
|
||||
TestLabel "TestParser" TestParser.tests,
|
||||
TestLabel "TestByteCodeGenerator" TestByteCodeGenerator.tests]
|
||||
TestLabel "TestParser" TestParser.tests
|
||||
]
|
||||
|
||||
main = do runTestTTAndExit Main.tests
|
@ -8,23 +8,11 @@ executable compiler
|
||||
main-is: Main.hs
|
||||
build-depends: base,
|
||||
array,
|
||||
HUnit,
|
||||
utf8-string,
|
||||
bytestring
|
||||
HUnit
|
||||
default-language: Haskell2010
|
||||
hs-source-dirs: src,
|
||||
src/ByteCode,
|
||||
src/ByteCode/ClassFile
|
||||
hs-source-dirs: src
|
||||
build-tool-depends: alex:alex, happy:happy
|
||||
other-modules: Parser.Lexer,
|
||||
Parser.JavaParser
|
||||
Ast,
|
||||
Example,
|
||||
Typecheck,
|
||||
ByteCode.ByteUtil,
|
||||
ByteCode.ClassFile,
|
||||
ByteCode.ClassFile.Generator,
|
||||
ByteCode.Constants
|
||||
other-modules: Parser.Lexer, Parser.JavaParser, Ast
|
||||
|
||||
test-suite tests
|
||||
type: exitcode-stdio-1.0
|
||||
@ -32,17 +20,10 @@ test-suite tests
|
||||
hs-source-dirs: src,Test
|
||||
build-depends: base,
|
||||
array,
|
||||
HUnit,
|
||||
utf8-string,
|
||||
bytestring
|
||||
HUnit
|
||||
build-tool-depends: alex:alex, happy:happy
|
||||
other-modules: Parser.Lexer,
|
||||
Parser.JavaParser,
|
||||
Ast,
|
||||
TestLexer,
|
||||
TestParser,
|
||||
TestByteCodeGenerator,
|
||||
ByteCode.ByteUtil,
|
||||
ByteCode.ClassFile,
|
||||
ByteCode.ClassFile.Generator,
|
||||
ByteCode.Constants
|
||||
TestLexer
|
||||
TestParser
|
||||
|
@ -1,2 +0,0 @@
|
||||
# Questions
|
||||
- Enum?
|
42
src/Ast.hs
42
src/Ast.hs
@ -4,30 +4,31 @@ type CompilationUnit = [Class]
|
||||
type DataType = String
|
||||
type Identifier = String
|
||||
|
||||
data ParameterDeclaration = ParameterDeclaration DataType Identifier deriving (Show, Eq)
|
||||
data VariableDeclaration = VariableDeclaration DataType Identifier (Maybe Expression) deriving (Show, Eq)
|
||||
data Class = Class DataType [MethodDeclaration] [VariableDeclaration] deriving (Show, Eq)
|
||||
data MethodDeclaration = MethodDeclaration DataType Identifier [ParameterDeclaration] Statement deriving (Show, Eq)
|
||||
data ParameterDeclaration = ParameterDeclaration DataType Identifier
|
||||
deriving(Show, Eq)
|
||||
data VariableDeclaration = VariableDeclaration DataType Identifier (Maybe Expression)
|
||||
deriving(Show, Eq)
|
||||
data Class = Class DataType [MethodDeclaration] [VariableDeclaration]
|
||||
deriving(Show, Eq)
|
||||
data MethodDeclaration = MethodDeclaration DataType Identifier [ParameterDeclaration] Statement
|
||||
deriving(Show, Eq)
|
||||
|
||||
data Statement
|
||||
= If Expression Statement (Maybe Statement)
|
||||
data Statement = If Expression Statement (Maybe Statement)
|
||||
| LocalVariableDeclaration VariableDeclaration
|
||||
| While Expression Statement
|
||||
| Block [Statement]
|
||||
| Return (Maybe Expression)
|
||||
| StatementExpressionStatement StatementExpression
|
||||
| TypedStatement DataType Statement
|
||||
deriving (Show, Eq)
|
||||
deriving(Show, Eq)
|
||||
|
||||
data StatementExpression
|
||||
= Assignment Identifier Expression
|
||||
data StatementExpression = Assignment Identifier Expression
|
||||
| ConstructorCall DataType [Expression]
|
||||
| MethodCall Expression Identifier [Expression]
|
||||
| MethodCall Identifier [Expression]
|
||||
| TypedStatementExpression DataType StatementExpression
|
||||
deriving (Show, Eq)
|
||||
deriving(Show, Eq)
|
||||
|
||||
data BinaryOperator
|
||||
= Addition
|
||||
data BinaryOperator = Addition
|
||||
| Subtraction
|
||||
| Multiplication
|
||||
| Division
|
||||
@ -43,23 +44,20 @@ data BinaryOperator
|
||||
| And
|
||||
| Or
|
||||
| NameResolution
|
||||
deriving (Show, Eq)
|
||||
deriving(Show, Eq)
|
||||
|
||||
data UnaryOperator
|
||||
= Not
|
||||
data UnaryOperator = Not
|
||||
| Minus
|
||||
deriving (Show, Eq)
|
||||
deriving(Show, Eq)
|
||||
|
||||
data Expression
|
||||
= IntegerLiteral Int
|
||||
data Expression = IntegerLiteral Int
|
||||
| CharacterLiteral Char
|
||||
| BooleanLiteral Bool
|
||||
| NullLiteral
|
||||
| Reference Identifier
|
||||
| LocalVariable Identifier
|
||||
| FieldVariable Identifier
|
||||
| BinaryOperation BinaryOperator Expression Expression
|
||||
| UnaryOperation UnaryOperator Expression
|
||||
| StatementExpressionExpression StatementExpression
|
||||
| TypedExpression DataType Expression
|
||||
deriving (Show, Eq)
|
||||
deriving(Show, Eq)
|
||||
|
@ -1,19 +0,0 @@
|
||||
module ByteCode.ByteUtil(unpackWord16, unpackWord32) where
|
||||
|
||||
import Data.Word ( Word8, Word16, Word32 )
|
||||
import Data.Int
|
||||
import Data.Bits
|
||||
|
||||
unpackWord16 :: Word16 -> [Word8]
|
||||
unpackWord16 v = [
|
||||
fromIntegral (shiftR ((.&.) v 0xFF00) 8),
|
||||
fromIntegral (shiftR ((.&.) v 0x00FF) 0)
|
||||
]
|
||||
|
||||
unpackWord32 :: Word32 -> [Word8]
|
||||
unpackWord32 v = [
|
||||
fromIntegral (shiftR ((.&.) v 0xFF000000) 24),
|
||||
fromIntegral (shiftR ((.&.) v 0x00FF0000) 16),
|
||||
fromIntegral (shiftR ((.&.) v 0x0000FF00) 8),
|
||||
fromIntegral (shiftR ((.&.) v 0x000000FF) 0)
|
||||
]
|
@ -1,168 +0,0 @@
|
||||
module ByteCode.ClassFile(
|
||||
ConstantInfo(..),
|
||||
Attribute(..),
|
||||
MemberInfo(..),
|
||||
ClassFile(..),
|
||||
Operation(..),
|
||||
serialize,
|
||||
emptyClassFile
|
||||
) where
|
||||
|
||||
import Data.Word
|
||||
import Data.Int
|
||||
import Data.ByteString (unpack)
|
||||
import Data.ByteString.UTF8 (fromString)
|
||||
import ByteCode.ByteUtil
|
||||
import ByteCode.Constants
|
||||
|
||||
data ConstantInfo = ClassInfo Word16
|
||||
| FieldRefInfo Word16 Word16
|
||||
| MethodRefInfo Word16 Word16
|
||||
| NameAndTypeInfo Word16 Word16
|
||||
| IntegerInfo Int32
|
||||
| Utf8Info [Char]
|
||||
deriving (Show, Eq)
|
||||
|
||||
data Operation = Opiadd
|
||||
| Opisub
|
||||
| Opimul
|
||||
| Opidiv
|
||||
| Opiand
|
||||
| Opior
|
||||
| Opixor
|
||||
| Opineg
|
||||
| Opif_icmplt Word16
|
||||
| Opif_icmple Word16
|
||||
| Opif_icmpgt Word16
|
||||
| Opif_icmpge Word16
|
||||
| Opif_icmpeq Word16
|
||||
| Opif_icmpne Word16
|
||||
| Opaconst_null
|
||||
| Opreturn
|
||||
| Opireturn
|
||||
| Opareturn
|
||||
| Opsipush Word16
|
||||
| Opldc_w Word16
|
||||
| Opaload Word16
|
||||
| Opiload Word16
|
||||
| Opastore Word16
|
||||
| Opistore Word16
|
||||
| Opputfield Word16
|
||||
| OpgetField Word16
|
||||
deriving (Show, Eq)
|
||||
|
||||
|
||||
data Attribute = CodeAttribute {
|
||||
attributeMaxStack :: Word16,
|
||||
attributeMaxLocals :: Word16,
|
||||
attributeCode :: [Operation]
|
||||
} deriving (Show, Eq)
|
||||
|
||||
|
||||
data MemberInfo = MemberInfo {
|
||||
memberAccessFlags :: Word16,
|
||||
memberNameIndex :: Word16,
|
||||
memberDescriptorIndex :: Word16,
|
||||
memberAttributes :: [Attribute]
|
||||
} deriving (Show, Eq)
|
||||
|
||||
data ClassFile = ClassFile {
|
||||
constantPool :: [ConstantInfo],
|
||||
accessFlags :: Word16,
|
||||
thisClass :: Word16,
|
||||
superClass :: Word16,
|
||||
fields :: [MemberInfo],
|
||||
methods :: [MemberInfo],
|
||||
attributes :: [Attribute]
|
||||
} deriving (Show, Eq)
|
||||
|
||||
emptyClassFile :: ClassFile
|
||||
emptyClassFile = ClassFile {
|
||||
constantPool = [],
|
||||
accessFlags = accessPublic,
|
||||
thisClass = 0,
|
||||
superClass = 0,
|
||||
fields = [],
|
||||
methods = [],
|
||||
attributes = []
|
||||
}
|
||||
|
||||
class Serializable a where
|
||||
serialize :: a -> [Word8]
|
||||
|
||||
instance Serializable ConstantInfo where
|
||||
serialize (ClassInfo nameIndex) = tagClass : unpackWord16 nameIndex
|
||||
serialize (FieldRefInfo classIndex nameAndTypeIndex) = tagFieldref : (unpackWord16 classIndex ++ unpackWord16 nameAndTypeIndex)
|
||||
serialize (MethodRefInfo classIndex nameAndTypeIndex) = tagMethodref : (unpackWord16 classIndex ++ unpackWord16 nameAndTypeIndex)
|
||||
serialize (NameAndTypeInfo classIndex descriptorIndex) = tagNameandtype : (unpackWord16 classIndex ++ unpackWord16 descriptorIndex)
|
||||
serialize (IntegerInfo value) = tagInteger : unpackWord32 (fromIntegral value)
|
||||
serialize (Utf8Info string) = tagUtf8 : unpackWord16 num_bytes ++ bytes where
|
||||
bytes = unpack (fromString string)
|
||||
num_bytes = fromIntegral $ length bytes
|
||||
|
||||
instance Serializable MemberInfo where
|
||||
serialize member = unpackWord16 (memberAccessFlags member)
|
||||
++ unpackWord16 (memberNameIndex member)
|
||||
++ unpackWord16 (memberDescriptorIndex member)
|
||||
++ unpackWord16 (fromIntegral (length (memberAttributes member)))
|
||||
++ concatMap serialize (memberAttributes member)
|
||||
|
||||
instance Serializable Operation where
|
||||
serialize Opiadd = [0x60]
|
||||
serialize Opisub = [0x64]
|
||||
serialize Opimul = [0x68]
|
||||
serialize Opidiv = [0x6C]
|
||||
serialize Opiand = [0x7E]
|
||||
serialize Opior = [0x80]
|
||||
serialize Opixor = [0x82]
|
||||
serialize Opineg = [0x74]
|
||||
serialize (Opif_icmplt branch) = 0xA1 : unpackWord16 branch
|
||||
serialize (Opif_icmple branch) = 0xA4 : unpackWord16 branch
|
||||
serialize (Opif_icmpgt branch) = 0xA3 : unpackWord16 branch
|
||||
serialize (Opif_icmpge branch) = 0xA2 : unpackWord16 branch
|
||||
serialize (Opif_icmpeq branch) = 0x9F : unpackWord16 branch
|
||||
serialize (Opif_icmpne branch) = 0xA0 : unpackWord16 branch
|
||||
serialize Opaconst_null = [0x01]
|
||||
serialize Opreturn = [0xB1]
|
||||
serialize Opireturn = [0xAC]
|
||||
serialize Opareturn = [0xB0]
|
||||
serialize (Opsipush index) = 0x11 : unpackWord16 index
|
||||
serialize (Opldc_w index) = 0x13 : unpackWord16 index
|
||||
serialize (Opaload index) = [0xC4, 0x19] ++ unpackWord16 index
|
||||
serialize (Opiload index) = [0xC4, 0x15] ++ unpackWord16 index
|
||||
serialize (Opastore index) = [0xC4, 0x3A] ++ unpackWord16 index
|
||||
serialize (Opistore index) = [0xC4, 0x36] ++ unpackWord16 index
|
||||
serialize (Opputfield index) = 0xB5 : unpackWord16 index
|
||||
serialize (OpgetField index) = 0xB4 : unpackWord16 index
|
||||
|
||||
instance Serializable Attribute where
|
||||
serialize (CodeAttribute { attributeMaxStack = maxStack,
|
||||
attributeMaxLocals = maxLocals,
|
||||
attributeCode = code }) = let
|
||||
assembledCode = concat (map serialize code)
|
||||
in
|
||||
unpackWord16 7 -- attribute_name_index
|
||||
++ unpackWord32 (12 + (fromIntegral (length assembledCode))) -- attribute_length
|
||||
++ unpackWord16 maxStack -- max_stack
|
||||
++ unpackWord16 maxLocals -- max_locals
|
||||
++ unpackWord32 (fromIntegral (length assembledCode)) -- code_length
|
||||
++ assembledCode -- code
|
||||
++ unpackWord16 0 -- exception_table_length
|
||||
++ unpackWord16 0 -- attributes_count
|
||||
|
||||
instance Serializable ClassFile where
|
||||
serialize classfile = unpackWord32 0xC0FEBABE -- magic
|
||||
++ unpackWord16 0 -- minor version
|
||||
++ unpackWord16 49 -- major version
|
||||
++ unpackWord16 (fromIntegral (1 + length (constantPool classfile))) -- constant pool count
|
||||
++ concatMap serialize (constantPool classfile) -- constant pool
|
||||
++ unpackWord16 (accessFlags classfile) -- access flags
|
||||
++ unpackWord16 (thisClass classfile) -- this class
|
||||
++ unpackWord16 (superClass classfile) -- super class
|
||||
++ unpackWord16 0 -- interface count
|
||||
++ unpackWord16 (fromIntegral (length (fields classfile))) -- fields count
|
||||
++ concatMap serialize (fields classfile) -- fields info
|
||||
++ unpackWord16 (fromIntegral (length (methods classfile))) -- methods count
|
||||
++ concatMap serialize (methods classfile) -- methods info
|
||||
++ unpackWord16 (fromIntegral (length (attributes classfile))) -- attributes count
|
||||
++ concatMap serialize (attributes classfile) -- attributes info
|
@ -1,169 +0,0 @@
|
||||
module ByteCode.ClassFile.Generator(
|
||||
classBuilder,
|
||||
datatypeDescriptor,
|
||||
methodParameterDescriptor,
|
||||
methodDescriptor,
|
||||
) where
|
||||
|
||||
import ByteCode.Constants
|
||||
import ByteCode.ClassFile (ClassFile (..), ConstantInfo (..), MemberInfo(..), Operation(..), Attribute(..))
|
||||
import Ast
|
||||
import Data.Char
|
||||
|
||||
|
||||
type ClassFileBuilder a = a -> ClassFile -> ClassFile
|
||||
|
||||
|
||||
datatypeDescriptor :: String -> String
|
||||
datatypeDescriptor "void" = "V"
|
||||
datatypeDescriptor "int" = "I"
|
||||
datatypeDescriptor "char" = "C"
|
||||
datatypeDescriptor "boolean" = "B"
|
||||
datatypeDescriptor x = "L" ++ x
|
||||
|
||||
methodParameterDescriptor :: String -> String
|
||||
methodParameterDescriptor "void" = "V"
|
||||
methodParameterDescriptor "int" = "I"
|
||||
methodParameterDescriptor "char" = "C"
|
||||
methodParameterDescriptor "boolean" = "B"
|
||||
methodParameterDescriptor x = "L" ++ x ++ ";"
|
||||
|
||||
methodDescriptor :: MethodDeclaration -> String
|
||||
methodDescriptor (MethodDeclaration returntype _ parameters _) = let
|
||||
parameter_types = [datatype | ParameterDeclaration datatype _ <- parameters]
|
||||
in
|
||||
"("
|
||||
++ (concat (map methodParameterDescriptor parameter_types))
|
||||
++ ")"
|
||||
++ datatypeDescriptor returntype
|
||||
|
||||
classBuilder :: ClassFileBuilder Class
|
||||
classBuilder (Class name methods fields) _ = let
|
||||
baseConstants = [
|
||||
ClassInfo 4,
|
||||
MethodRefInfo 1 3,
|
||||
NameAndTypeInfo 5 6,
|
||||
Utf8Info "java/lang/Object",
|
||||
Utf8Info "<init>",
|
||||
Utf8Info "()V",
|
||||
Utf8Info "Code"
|
||||
]
|
||||
nameConstants = [ClassInfo 9, Utf8Info name]
|
||||
nakedClassFile = ClassFile {
|
||||
constantPool = baseConstants ++ nameConstants,
|
||||
accessFlags = accessPublic,
|
||||
thisClass = 8,
|
||||
superClass = 1,
|
||||
fields = [],
|
||||
methods = [],
|
||||
attributes = []
|
||||
}
|
||||
in
|
||||
foldr methodBuilder (foldr fieldBuilder nakedClassFile fields) methods
|
||||
|
||||
|
||||
|
||||
fieldBuilder :: ClassFileBuilder VariableDeclaration
|
||||
fieldBuilder (VariableDeclaration datatype name _) input = let
|
||||
baseIndex = 1 + length (constantPool input)
|
||||
constants = [
|
||||
FieldRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)),
|
||||
NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)),
|
||||
Utf8Info name,
|
||||
Utf8Info (datatypeDescriptor datatype)
|
||||
]
|
||||
field = MemberInfo {
|
||||
memberAccessFlags = accessPublic,
|
||||
memberNameIndex = (fromIntegral (baseIndex + 2)),
|
||||
memberDescriptorIndex = (fromIntegral (baseIndex + 3)),
|
||||
memberAttributes = []
|
||||
}
|
||||
in
|
||||
input {
|
||||
constantPool = (constantPool input) ++ constants,
|
||||
fields = (fields input) ++ [field]
|
||||
}
|
||||
|
||||
methodBuilder :: ClassFileBuilder MethodDeclaration
|
||||
methodBuilder (MethodDeclaration returntype name parameters statement) input = let
|
||||
baseIndex = 1 + length (constantPool input)
|
||||
constants = [
|
||||
FieldRefInfo (fromIntegral (thisClass input)) (fromIntegral (baseIndex + 1)),
|
||||
NameAndTypeInfo (fromIntegral (baseIndex + 2)) (fromIntegral (baseIndex + 3)),
|
||||
Utf8Info name,
|
||||
Utf8Info (methodDescriptor (MethodDeclaration returntype name parameters (Block [])))
|
||||
]
|
||||
--code = assembleByteCode statement
|
||||
method = MemberInfo {
|
||||
memberAccessFlags = accessPublic,
|
||||
memberNameIndex = (fromIntegral (baseIndex + 2)),
|
||||
memberDescriptorIndex = (fromIntegral (baseIndex + 3)),
|
||||
memberAttributes = [
|
||||
CodeAttribute {
|
||||
attributeMaxStack = 420,
|
||||
attributeMaxLocals = 420,
|
||||
attributeCode = [Opiadd]
|
||||
}
|
||||
]
|
||||
}
|
||||
in
|
||||
input {
|
||||
constantPool = (constantPool input) ++ constants,
|
||||
methods = (fields input) ++ [method]
|
||||
}
|
||||
|
||||
type Assembler a = a -> ([ConstantInfo], [Operation]) -> ([ConstantInfo], [Operation])
|
||||
|
||||
returnOperation :: DataType -> Operation
|
||||
returnOperation dtype
|
||||
| elem dtype ["int", "char", "boolean"] = Opireturn
|
||||
| otherwise = Opareturn
|
||||
|
||||
binaryOperation :: BinaryOperator -> Operation
|
||||
binaryOperation Addition = Opiadd
|
||||
binaryOperation Subtraction = Opisub
|
||||
binaryOperation Multiplication = Opimul
|
||||
binaryOperation Division = Opidiv
|
||||
binaryOperation BitwiseAnd = Opiand
|
||||
binaryOperation BitwiseOr = Opior
|
||||
binaryOperation BitwiseXor = Opixor
|
||||
|
||||
assembleMethod :: Assembler MethodDeclaration
|
||||
assembleMethod (MethodDeclaration _ _ _ (Block statements)) (constants, ops) =
|
||||
foldr assembleStatement (constants, ops) statements
|
||||
|
||||
assembleStatement :: Assembler Statement
|
||||
assembleStatement (TypedStatement stype (Return expr)) (constants, ops) = case expr of
|
||||
Nothing -> (constants, ops ++ [Opreturn])
|
||||
Just expr -> let
|
||||
(expr_constants, expr_ops) = assembleExpression expr (constants, ops)
|
||||
in
|
||||
(expr_constants, expr_ops ++ [returnOperation stype])
|
||||
|
||||
assembleExpression :: Assembler Expression
|
||||
assembleExpression (TypedExpression _ (BinaryOperation op a b)) (constants, ops)
|
||||
| elem op [Addition, Subtraction, Multiplication, Division, BitwiseAnd, BitwiseOr, BitwiseXor] = let
|
||||
(aConstants, aOps) = assembleExpression a (constants, ops)
|
||||
(bConstants, bOps) = assembleExpression b (aConstants, aOps)
|
||||
in
|
||||
(bConstants, bOps ++ [binaryOperation op])
|
||||
assembleExpression (TypedExpression _ (CharacterLiteral literal)) (constants, ops) =
|
||||
(constants, ops ++ [Opsipush (fromIntegral (ord literal))])
|
||||
assembleExpression (TypedExpression _ (BooleanLiteral literal)) (constants, ops) =
|
||||
(constants, ops ++ [Opsipush (if literal then 1 else 0)])
|
||||
assembleExpression (TypedExpression _ (IntegerLiteral literal)) (constants, ops)
|
||||
| literal <= 32767 && literal >= -32768 = (constants, ops ++ [Opsipush (fromIntegral literal)])
|
||||
| otherwise = (constants ++ [IntegerInfo (fromIntegral literal)], ops ++ [Opldc_w (fromIntegral (1 + length constants))])
|
||||
assembleExpression (TypedExpression _ NullLiteral) (constants, ops) =
|
||||
(constants, ops ++ [Opaconst_null])
|
||||
assembleExpression (TypedExpression etype (UnaryOperation Not expr)) (constants, ops) = let
|
||||
(exprConstants, exprOps) = assembleExpression expr (constants, ops)
|
||||
newConstant = fromIntegral (1 + length exprConstants)
|
||||
in case etype of
|
||||
"int" -> (exprConstants ++ [IntegerInfo 0x7FFFFFFF], exprOps ++ [Opldc_w newConstant, Opixor])
|
||||
"char" -> (exprConstants, exprOps ++ [Opsipush 0xFFFF, Opixor])
|
||||
"boolean" -> (exprConstants, exprOps ++ [Opsipush 0x01, Opixor])
|
||||
assembleExpression (TypedExpression _ (UnaryOperation Minus expr)) (constants, ops) = let
|
||||
(exprConstants, exprOps) = assembleExpression expr (constants, ops)
|
||||
in
|
||||
(exprConstants, exprOps ++ [Opineg])
|
@ -1,25 +0,0 @@
|
||||
module ByteCode.Constants where
|
||||
import Data.Word
|
||||
|
||||
|
||||
tagClass :: Word8
|
||||
tagFieldref :: Word8
|
||||
tagMethodref :: Word8
|
||||
tagNameandtype :: Word8
|
||||
tagInteger :: Word8
|
||||
tagUtf8 :: Word8
|
||||
|
||||
accessPublic :: Word16
|
||||
accessPrivate :: Word16
|
||||
accessProtected :: Word16
|
||||
|
||||
tagClass = 0x07
|
||||
tagFieldref = 0x09
|
||||
tagMethodref = 0x0A
|
||||
tagNameandtype = 0x0C
|
||||
tagInteger = 0x03
|
||||
tagUtf8 = 0x01
|
||||
|
||||
accessPublic = 0x01
|
||||
accessPrivate = 0x02
|
||||
accessProtected = 0x04
|
203
src/Example.hs
203
src/Example.hs
@ -1,203 +0,0 @@
|
||||
module Example where
|
||||
|
||||
import Ast
|
||||
import Typecheck
|
||||
import Control.Exception (catch, evaluate, SomeException, displayException)
|
||||
import Control.Exception.Base
|
||||
import System.IO (stderr, hPutStrLn)
|
||||
import Data.Maybe
|
||||
import Data.List
|
||||
|
||||
green, red, yellow, blue, magenta, cyan, white :: String -> String
|
||||
green str = "\x1b[32m" ++ str ++ "\x1b[0m"
|
||||
red str = "\x1b[31m" ++ str ++ "\x1b[0m"
|
||||
yellow str = "\x1b[33m" ++ str ++ "\x1b[0m"
|
||||
blue str = "\x1b[34m" ++ str ++ "\x1b[0m"
|
||||
magenta str = "\x1b[35m" ++ str ++ "\x1b[0m"
|
||||
cyan str = "\x1b[36m" ++ str ++ "\x1b[0m"
|
||||
white str = "\x1b[37m" ++ str ++ "\x1b[0m"
|
||||
|
||||
printSuccess :: String -> IO ()
|
||||
printSuccess msg = putStrLn $ green "Success:" ++ white msg
|
||||
|
||||
handleError :: SomeException -> IO ()
|
||||
handleError e = hPutStrLn stderr $ red ("Error: " ++ displayException e)
|
||||
|
||||
printResult :: Show a => String -> a -> IO ()
|
||||
printResult title result = do
|
||||
putStrLn $ green title
|
||||
print result
|
||||
|
||||
sampleClasses :: [Class]
|
||||
sampleClasses = [
|
||||
Class "Person" [
|
||||
MethodDeclaration "void" "setAge" [ParameterDeclaration "int" "newAge"]
|
||||
(Block [
|
||||
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (Reference "newAge")))
|
||||
]),
|
||||
MethodDeclaration "int" "getAge" [] (Return (Just (Reference "age"))),
|
||||
MethodDeclaration "Person" "Person" [ParameterDeclaration "int" "initialAge"] (Block [])
|
||||
] [
|
||||
VariableDeclaration "int" "age" (Just (IntegerLiteral 25))
|
||||
]
|
||||
]
|
||||
|
||||
initialSymtab :: [(DataType, Identifier)]
|
||||
initialSymtab = []
|
||||
|
||||
exampleExpression :: Expression
|
||||
exampleExpression = BinaryOperation NameResolution (Reference "bob") (Reference "age")
|
||||
|
||||
exampleAssignment :: Expression
|
||||
exampleAssignment = StatementExpressionExpression (Assignment "a" (IntegerLiteral 30))
|
||||
|
||||
exampleMethodCall :: Statement
|
||||
exampleMethodCall = StatementExpressionStatement (MethodCall (Reference "this") "setAge" [IntegerLiteral 30])
|
||||
|
||||
exampleConstructorCall :: Statement
|
||||
exampleConstructorCall = LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30]))))
|
||||
|
||||
exampleNameResolution :: Expression
|
||||
exampleNameResolution = BinaryOperation NameResolution (Reference "b") (Reference "age")
|
||||
|
||||
exampleBlockResolution :: Statement
|
||||
exampleBlockResolution = Block [
|
||||
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))),
|
||||
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
|
||||
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30])
|
||||
]
|
||||
|
||||
exampleBlockResolutionFail :: Statement
|
||||
exampleBlockResolutionFail = Block [
|
||||
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))),
|
||||
LocalVariableDeclaration (VariableDeclaration "bool" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
|
||||
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30])
|
||||
]
|
||||
|
||||
exampleMethodCallAndAssignment :: Statement
|
||||
exampleMethodCallAndAssignment = Block [
|
||||
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))),
|
||||
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
|
||||
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]),
|
||||
LocalVariableDeclaration (VariableDeclaration "int" "a" Nothing),
|
||||
StatementExpressionStatement (Assignment "a" (Reference "age"))
|
||||
]
|
||||
|
||||
|
||||
exampleMethodCallAndAssignmentFail :: Statement
|
||||
exampleMethodCallAndAssignmentFail = Block [
|
||||
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 30])))),
|
||||
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
|
||||
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]),
|
||||
StatementExpressionStatement (Assignment "a" (Reference "age"))
|
||||
]
|
||||
|
||||
testClasses :: [Class]
|
||||
testClasses = [
|
||||
Class "Person" [
|
||||
MethodDeclaration "Person" "Person" [ParameterDeclaration "int" "initialAge"]
|
||||
(Block [
|
||||
Return (Just (Reference "this"))
|
||||
]),
|
||||
MethodDeclaration "void" "setAge" [ParameterDeclaration "int" "newAge"]
|
||||
(Block [
|
||||
LocalVariableDeclaration (VariableDeclaration "int" "age" (Just (Reference "newAge")))
|
||||
]),
|
||||
MethodDeclaration "int" "getAge" []
|
||||
(Return (Just (Reference "age")))
|
||||
] [
|
||||
VariableDeclaration "int" "age" Nothing -- initially unassigned
|
||||
],
|
||||
Class "Main" [
|
||||
MethodDeclaration "int" "main" []
|
||||
(Block [
|
||||
LocalVariableDeclaration (VariableDeclaration "Person" "bob" (Just (StatementExpressionExpression (ConstructorCall "Person" [IntegerLiteral 25])))),
|
||||
StatementExpressionStatement (MethodCall (Reference "bob") "setAge" [IntegerLiteral 30]),
|
||||
LocalVariableDeclaration (VariableDeclaration "int" "bobAge" (Just (StatementExpressionExpression (MethodCall (Reference "bob") "getAge" [])))),
|
||||
Return (Just (Reference "bobAge"))
|
||||
])
|
||||
] []
|
||||
]
|
||||
|
||||
runTypeCheck :: IO ()
|
||||
runTypeCheck = do
|
||||
catch (do
|
||||
print "====================================================================================="
|
||||
evaluatedExpression <- evaluate (typeCheckExpression exampleExpression [("bob", "Person")] sampleClasses)
|
||||
printSuccess "Type checking of expression completed successfully"
|
||||
printResult "Result Expression:" evaluatedExpression
|
||||
) handleError
|
||||
|
||||
catch (do
|
||||
print "====================================================================================="
|
||||
evaluatedAssignment <- evaluate (typeCheckExpression exampleAssignment [("a", "int")] sampleClasses)
|
||||
printSuccess "Type checking of assignment completed successfully"
|
||||
printResult "Result Assignment:" evaluatedAssignment
|
||||
) handleError
|
||||
|
||||
catch (do
|
||||
print "====================================================================================="
|
||||
evaluatedMethodCall <- evaluate (typeCheckStatement exampleMethodCall [("this", "Person"), ("setAge", "Person"), ("getAge", "Person")] sampleClasses)
|
||||
printSuccess "Type checking of method call this completed successfully"
|
||||
printResult "Result MethodCall:" evaluatedMethodCall
|
||||
) handleError
|
||||
|
||||
catch (do
|
||||
print "====================================================================================="
|
||||
evaluatedConstructorCall <- evaluate (typeCheckStatement exampleConstructorCall [] sampleClasses)
|
||||
printSuccess "Type checking of constructor call completed successfully"
|
||||
printResult "Result Constructor Call:" evaluatedConstructorCall
|
||||
) handleError
|
||||
|
||||
catch (do
|
||||
print "====================================================================================="
|
||||
evaluatedNameResolution <- evaluate (typeCheckExpression exampleNameResolution [("b", "Person")] sampleClasses)
|
||||
printSuccess "Type checking of name resolution completed successfully"
|
||||
printResult "Result Name Resolution:" evaluatedNameResolution
|
||||
) handleError
|
||||
|
||||
catch (do
|
||||
print "====================================================================================="
|
||||
evaluatedBlockResolution <- evaluate (typeCheckStatement exampleBlockResolution [] sampleClasses)
|
||||
printSuccess "Type checking of block resolution completed successfully"
|
||||
printResult "Result Block Resolution:" evaluatedBlockResolution
|
||||
) handleError
|
||||
|
||||
catch (do
|
||||
print "====================================================================================="
|
||||
evaluatedBlockResolutionFail <- evaluate (typeCheckStatement exampleBlockResolutionFail [] sampleClasses)
|
||||
printSuccess "Type checking of block resolution failed"
|
||||
printResult "Result Block Resolution:" evaluatedBlockResolutionFail
|
||||
) handleError
|
||||
|
||||
catch (do
|
||||
print "====================================================================================="
|
||||
evaluatedMethodCallAndAssignment <- evaluate (typeCheckStatement exampleMethodCallAndAssignment [] sampleClasses)
|
||||
printSuccess "Type checking of method call and assignment completed successfully"
|
||||
printResult "Result Method Call and Assignment:" evaluatedMethodCallAndAssignment
|
||||
) handleError
|
||||
|
||||
catch (do
|
||||
print "====================================================================================="
|
||||
evaluatedMethodCallAndAssignmentFail <- evaluate (typeCheckStatement exampleMethodCallAndAssignmentFail [] sampleClasses)
|
||||
printSuccess "Type checking of method call and assignment failed"
|
||||
printResult "Result Method Call and Assignment:" evaluatedMethodCallAndAssignmentFail
|
||||
) handleError
|
||||
|
||||
catch (do
|
||||
print "====================================================================================="
|
||||
let mainClass = fromJust $ find (\(Class className _ _) -> className == "Main") testClasses
|
||||
case mainClass of
|
||||
Class _ [mainMethod] _ -> do
|
||||
let result = typeCheckMethodDeclaration mainMethod [] testClasses
|
||||
printSuccess "Full program type checking completed successfully."
|
||||
printResult "Main method result:" result
|
||||
) handleError
|
||||
|
||||
catch (do
|
||||
print "====================================================================================="
|
||||
let typedProgram = typeCheckCompilationUnit testClasses
|
||||
printSuccess "Type checking of Program completed successfully"
|
||||
printResult "Typed Program:" typedProgram
|
||||
) handleError
|
||||
|
@ -1,8 +1,6 @@
|
||||
module Main where
|
||||
|
||||
import Example
|
||||
import Typecheck
|
||||
import Parser.Lexer
|
||||
|
||||
main = do
|
||||
Example.runTypeCheck
|
||||
|
||||
print $ alexScanTokens "/**/"
|
334
src/Typecheck.hs
334
src/Typecheck.hs
@ -1,334 +0,0 @@
|
||||
module Typecheck where
|
||||
import Data.List (find)
|
||||
import Data.Maybe
|
||||
import Ast
|
||||
|
||||
typeCheckCompilationUnit :: CompilationUnit -> CompilationUnit
|
||||
typeCheckCompilationUnit classes = map (`typeCheckClass` classes) classes
|
||||
|
||||
typeCheckClass :: Class -> [Class] -> Class
|
||||
typeCheckClass (Class className methods fields) classes =
|
||||
let
|
||||
-- Create a symbol table from class fields and method entries
|
||||
classFields = [(id, dt) | VariableDeclaration dt id _ <- fields]
|
||||
methodEntries = [(methodName, className) | MethodDeclaration _ methodName _ _ <- methods]
|
||||
initalSymTab = ("this", className) : classFields ++ methodEntries
|
||||
checkedMethods = map (\method -> typeCheckMethodDeclaration method initalSymTab classes) methods
|
||||
in Class className checkedMethods fields
|
||||
|
||||
typeCheckMethodDeclaration :: MethodDeclaration -> [(Identifier, DataType)] -> [Class] -> MethodDeclaration
|
||||
typeCheckMethodDeclaration (MethodDeclaration retType name params body) classFields classes =
|
||||
let
|
||||
-- Combine class fields with method parameters to form the initial symbol table for the method
|
||||
methodParams = [(identifier, dataType) | ParameterDeclaration dataType identifier <- params]
|
||||
initialSymtab = classFields ++ methodParams
|
||||
checkedBody = typeCheckStatement body initialSymtab classes
|
||||
bodyType = getTypeFromStmt checkedBody
|
||||
-- Check if the type of the body matches the declared return type
|
||||
in if bodyType == retType || (bodyType == "void" && retType == "void")
|
||||
then MethodDeclaration retType name params checkedBody
|
||||
else error $ "Return type mismatch in method " ++ name ++ ": expected " ++ retType ++ ", found " ++ bodyType
|
||||
|
||||
-- ********************************** Type Checking: Expressions **********************************
|
||||
|
||||
typeCheckExpression :: Expression -> [(Identifier, DataType)] -> [Class] -> Expression
|
||||
typeCheckExpression (IntegerLiteral i) _ _ = TypedExpression "int" (IntegerLiteral i)
|
||||
typeCheckExpression (CharacterLiteral c) _ _ = TypedExpression "char" (CharacterLiteral c)
|
||||
typeCheckExpression (BooleanLiteral b) _ _ = TypedExpression "boolean" (BooleanLiteral b)
|
||||
typeCheckExpression NullLiteral _ _ = TypedExpression "null" NullLiteral
|
||||
typeCheckExpression (Reference id) symtab classes =
|
||||
let type' = lookupType id symtab
|
||||
in TypedExpression type' (Reference id)
|
||||
typeCheckExpression (BinaryOperation op expr1 expr2) symtab classes =
|
||||
let expr1' = typeCheckExpression expr1 symtab classes
|
||||
expr2' = typeCheckExpression expr2 symtab classes
|
||||
type1 = getTypeFromExpr expr1'
|
||||
type2 = getTypeFromExpr expr2'
|
||||
in case op of
|
||||
Addition ->
|
||||
if type1 == "int" && type2 == "int"
|
||||
then
|
||||
TypedExpression "int" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Addition operation requires two operands of type int"
|
||||
Subtraction ->
|
||||
if type1 == "int" && type2 == "int"
|
||||
then
|
||||
TypedExpression "int" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Subtraction operation requires two operands of type int"
|
||||
Multiplication ->
|
||||
if type1 == "int" && type2 == "int"
|
||||
then
|
||||
TypedExpression "int" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Multiplication operation requires two operands of type int"
|
||||
Division ->
|
||||
if type1 == "int" && type2 == "int"
|
||||
then
|
||||
TypedExpression "int" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Division operation requires two operands of type int"
|
||||
BitwiseAnd ->
|
||||
if type1 == "int" && type2 == "int"
|
||||
then
|
||||
TypedExpression "int" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Bitwise AND operation requires two operands of type int"
|
||||
BitwiseOr ->
|
||||
if type1 == "int" && type2 == "int"
|
||||
then
|
||||
TypedExpression "int" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Bitwise OR operation requires two operands of type int"
|
||||
BitwiseXor ->
|
||||
if type1 == "int" && type2 == "int"
|
||||
then
|
||||
TypedExpression "int" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Bitwise XOR operation requires two operands of type int"
|
||||
CompareLessThan ->
|
||||
if type1 == "int" && type2 == "int"
|
||||
then
|
||||
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Less than operation requires two operands of type int"
|
||||
CompareLessOrEqual ->
|
||||
if type1 == "int" && type2 == "int"
|
||||
then
|
||||
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Less than or equal operation requires two operands of type int"
|
||||
CompareGreaterThan ->
|
||||
if type1 == "int" && type2 == "int"
|
||||
then
|
||||
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Greater than operation requires two operands of type int"
|
||||
CompareGreaterOrEqual ->
|
||||
if type1 == "int" && type2 == "int"
|
||||
then
|
||||
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Greater than or equal operation requires two operands of type int"
|
||||
CompareEqual ->
|
||||
if type1 == type2
|
||||
then
|
||||
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Equality operation requires two operands of the same type"
|
||||
CompareNotEqual ->
|
||||
if type1 == type2
|
||||
then
|
||||
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Inequality operation requires two operands of the same type"
|
||||
And ->
|
||||
if type1 == "boolean" && type2 == "boolean"
|
||||
then
|
||||
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Logical AND operation requires two operands of type boolean"
|
||||
Or ->
|
||||
if type1 == "boolean" && type2 == "boolean"
|
||||
then
|
||||
TypedExpression "boolean" (BinaryOperation op expr1' expr2')
|
||||
else
|
||||
error "Logical OR operation requires two operands of type boolean"
|
||||
NameResolution ->
|
||||
case (expr1', expr2) of
|
||||
(TypedExpression t1 (Reference obj), Reference member) ->
|
||||
let objectType = lookupType obj symtab
|
||||
classDetails = find (\(Class className _ _) -> className == objectType) classes
|
||||
in case classDetails of
|
||||
Just (Class _ _ fields) ->
|
||||
let fieldTypes = [dt | VariableDeclaration dt id _ <- fields, id == member]
|
||||
in case fieldTypes of
|
||||
[resolvedType] -> TypedExpression resolvedType (BinaryOperation NameResolution expr1' (TypedExpression resolvedType expr2))
|
||||
[] -> error $ "Field '" ++ member ++ "' not found in class '" ++ objectType ++ "'"
|
||||
_ -> error $ "Ambiguous reference to field '" ++ member ++ "' in class '" ++ objectType ++ "'"
|
||||
Nothing -> error $ "Object '" ++ obj ++ "' does not correspond to a known class"
|
||||
_ -> error "Name resolution requires object reference and field name"
|
||||
|
||||
typeCheckExpression (UnaryOperation op expr) symtab classes =
|
||||
let expr' = typeCheckExpression expr symtab classes
|
||||
type' = getTypeFromExpr expr'
|
||||
in case op of
|
||||
Not ->
|
||||
if type' == "boolean"
|
||||
then
|
||||
TypedExpression "boolean" (UnaryOperation op expr')
|
||||
else
|
||||
error "Logical NOT operation requires an operand of type boolean"
|
||||
Minus ->
|
||||
if type' == "int"
|
||||
then
|
||||
TypedExpression "int" (UnaryOperation op expr')
|
||||
else
|
||||
error "Unary minus operation requires an operand of type int"
|
||||
|
||||
typeCheckExpression (StatementExpressionExpression stmtExpr) symtab classes =
|
||||
let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes
|
||||
in TypedExpression (getTypeFromStmtExpr stmtExpr') (StatementExpressionExpression stmtExpr')
|
||||
|
||||
-- ********************************** Type Checking: StatementExpressions **********************************
|
||||
|
||||
typeCheckStatementExpression :: StatementExpression -> [(Identifier, DataType)] -> [Class] -> StatementExpression
|
||||
typeCheckStatementExpression (Assignment id expr) symtab classes =
|
||||
let expr' = typeCheckExpression expr symtab classes
|
||||
type' = getTypeFromExpr expr'
|
||||
type'' = lookupType id symtab
|
||||
in if type' == type''
|
||||
then
|
||||
TypedStatementExpression type' (Assignment id expr')
|
||||
else
|
||||
error "Assignment type mismatch"
|
||||
|
||||
typeCheckStatementExpression (ConstructorCall className args) symtab classes =
|
||||
case find (\(Class name _ _) -> name == className) classes of
|
||||
Nothing -> error $ "Class '" ++ className ++ "' not found."
|
||||
Just (Class _ methods fields) ->
|
||||
-- Constructor needs the same name as the class
|
||||
case find (\(MethodDeclaration retType name params _) -> name == className && retType == className) methods of
|
||||
Nothing -> error $ "No valid constructor found for class '" ++ className ++ "'."
|
||||
Just (MethodDeclaration _ _ params _) ->
|
||||
let
|
||||
args' = map (\arg -> typeCheckExpression arg symtab classes) args
|
||||
-- Extract expected parameter types from the constructor's parameters
|
||||
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
|
||||
argTypes = map getTypeFromExpr args'
|
||||
-- Check if the types of the provided arguments match the expected types
|
||||
typeMatches = zipWith (\expected actual -> if expected == actual then Nothing else Just (expected, actual)) expectedTypes argTypes
|
||||
mismatchErrors = map (\(exp, act) -> "Expected type '" ++ exp ++ "', found '" ++ act ++ "'.") (catMaybes typeMatches)
|
||||
in
|
||||
if length args /= length params then
|
||||
error $ "Constructor for class '" ++ className ++ "' expects " ++ show (length params) ++ " arguments, but got " ++ show (length args) ++ "."
|
||||
else if not (null mismatchErrors) then
|
||||
error $ unlines $ ("Type mismatch in constructor arguments for class '" ++ className ++ "':") : mismatchErrors
|
||||
else
|
||||
TypedStatementExpression className (ConstructorCall className args')
|
||||
|
||||
typeCheckStatementExpression (MethodCall expr methodName args) symtab classes =
|
||||
let objExprTyped = typeCheckExpression expr symtab classes
|
||||
in case objExprTyped of
|
||||
TypedExpression objType _ ->
|
||||
case find (\(Class className _ _) -> className == objType) classes of
|
||||
Just (Class _ methods _) ->
|
||||
case find (\(MethodDeclaration retType name params _) -> name == methodName) methods of
|
||||
Just (MethodDeclaration retType _ params _) ->
|
||||
let args' = map (\arg -> typeCheckExpression arg symtab classes) args
|
||||
expectedTypes = [dataType | ParameterDeclaration dataType _ <- params]
|
||||
argTypes = map getTypeFromExpr args'
|
||||
typeMatches = zipWith (\expType argType -> (expType == argType, expType, argType)) expectedTypes argTypes
|
||||
mismatches = filter (not . fst3) typeMatches
|
||||
where fst3 (a, _, _) = a
|
||||
in
|
||||
if null mismatches && length args == length params then
|
||||
TypedStatementExpression retType (MethodCall objExprTyped methodName args')
|
||||
else if not (null mismatches) then
|
||||
error $ unlines $ ("Argument type mismatches for method '" ++ methodName ++ "':")
|
||||
: [ "Expected: " ++ expType ++ ", Found: " ++ argType | (_, expType, argType) <- mismatches ]
|
||||
else
|
||||
error $ "Incorrect number of arguments for method '" ++ methodName ++ "'. Expected " ++ show (length expectedTypes) ++ ", found " ++ show (length args) ++ "."
|
||||
Nothing -> error $ "Method '" ++ methodName ++ "' not found in class '" ++ objType ++ "'."
|
||||
Nothing -> error $ "Class for object type '" ++ objType ++ "' not found."
|
||||
_ -> error "Invalid object type for method call. Object must have a class type."
|
||||
|
||||
-- ********************************** Type Checking: Statements **********************************
|
||||
|
||||
typeCheckStatement :: Statement -> [(Identifier, DataType)] -> [Class] -> Statement
|
||||
typeCheckStatement (If cond thenStmt elseStmt) symtab classes =
|
||||
let cond' = typeCheckExpression cond symtab classes
|
||||
thenStmt' = typeCheckStatement thenStmt symtab classes
|
||||
elseStmt' = case elseStmt of
|
||||
Just stmt -> Just (typeCheckStatement stmt symtab classes)
|
||||
Nothing -> Nothing
|
||||
in if getTypeFromExpr cond' == "boolean"
|
||||
then
|
||||
TypedStatement (getTypeFromStmt thenStmt') (If cond' thenStmt' elseStmt')
|
||||
else
|
||||
error "If condition must be of type boolean"
|
||||
|
||||
typeCheckStatement (LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr)) symtab classes =
|
||||
-- Check for redefinition in the current scope
|
||||
if any ((== identifier) . snd) symtab
|
||||
then error $ "Variable '" ++ identifier ++ "' is redefined in the same scope"
|
||||
else
|
||||
-- If there's an initializer expression, type check it
|
||||
let checkedExpr = fmap (\expr -> typeCheckExpression expr symtab classes) maybeExpr
|
||||
exprType = fmap getTypeFromExpr checkedExpr
|
||||
in case exprType of
|
||||
Just t | t /= dataType -> error $ "Type mismatch in declaration of '" ++ identifier ++ "': expected " ++ dataType ++ ", found " ++ t
|
||||
_ -> TypedStatement dataType (LocalVariableDeclaration (VariableDeclaration dataType identifier checkedExpr))
|
||||
|
||||
typeCheckStatement (While cond stmt) symtab classes =
|
||||
let cond' = typeCheckExpression cond symtab classes
|
||||
stmt' = typeCheckStatement stmt symtab classes
|
||||
in if getTypeFromExpr cond' == "boolean"
|
||||
then
|
||||
TypedStatement (getTypeFromStmt stmt') (While cond' stmt')
|
||||
else
|
||||
error "While condition must be of type boolean"
|
||||
|
||||
typeCheckStatement (Block statements) symtab classes =
|
||||
let
|
||||
processStatements (accSts, currentSymtab, types) stmt =
|
||||
let
|
||||
checkedStmt = typeCheckStatement stmt currentSymtab classes
|
||||
stmtType = getTypeFromStmt checkedStmt
|
||||
in case stmt of
|
||||
LocalVariableDeclaration (VariableDeclaration dataType identifier maybeExpr) ->
|
||||
let
|
||||
checkedExpr = fmap (\expr -> typeCheckExpression expr currentSymtab classes) maybeExpr
|
||||
newSymtab = (identifier, dataType) : currentSymtab
|
||||
in (accSts ++ [checkedStmt], newSymtab, types)
|
||||
|
||||
If {} -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
|
||||
While _ _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
|
||||
Return _ -> (accSts ++ [checkedStmt], currentSymtab, if stmtType /= "void" then types ++ [stmtType] else types)
|
||||
_ -> (accSts ++ [checkedStmt], currentSymtab, types)
|
||||
|
||||
-- Initial accumulator: empty statements list, initial symbol table, empty types list
|
||||
(checkedStatements, finalSymtab, collectedTypes) = foldl processStatements ([], symtab, []) statements
|
||||
|
||||
-- Determine the block's type: unify all collected types, default to "Void" if none
|
||||
blockType = if null collectedTypes then "void" else foldl1 unifyReturnTypes collectedTypes
|
||||
|
||||
in TypedStatement blockType (Block checkedStatements)
|
||||
|
||||
typeCheckStatement (Return expr) symtab classes =
|
||||
let expr' = case expr of
|
||||
Just e -> Just (typeCheckExpression e symtab classes)
|
||||
Nothing -> Nothing
|
||||
in case expr' of
|
||||
Just e' -> TypedStatement (getTypeFromExpr e') (Return (Just e'))
|
||||
Nothing -> TypedStatement "Void" (Return Nothing)
|
||||
|
||||
typeCheckStatement (StatementExpressionStatement stmtExpr) symtab classes =
|
||||
let stmtExpr' = typeCheckStatementExpression stmtExpr symtab classes
|
||||
in TypedStatement (getTypeFromStmtExpr stmtExpr') (StatementExpressionStatement stmtExpr')
|
||||
|
||||
-- ********************************** Type Checking: Helpers **********************************
|
||||
|
||||
getTypeFromExpr :: Expression -> DataType
|
||||
getTypeFromExpr (TypedExpression t _) = t
|
||||
getTypeFromExpr _ = error "Untyped expression found where typed was expected"
|
||||
|
||||
getTypeFromStmt :: Statement -> DataType
|
||||
getTypeFromStmt (TypedStatement t _) = t
|
||||
getTypeFromStmt _ = error "Untyped statement found where typed was expected"
|
||||
|
||||
getTypeFromStmtExpr :: StatementExpression -> DataType
|
||||
getTypeFromStmtExpr (TypedStatementExpression t _) = t
|
||||
getTypeFromStmtExpr _ = error "Untyped statement expression found where typed was expected"
|
||||
|
||||
unifyReturnTypes :: DataType -> DataType -> DataType
|
||||
unifyReturnTypes dt1 dt2
|
||||
| dt1 == dt2 = dt1
|
||||
| otherwise = "Object"
|
||||
|
||||
lookupType :: Identifier -> [(Identifier, DataType)] -> DataType
|
||||
lookupType id symtab =
|
||||
case lookup id symtab of
|
||||
Just t -> t
|
||||
Nothing -> error ("Identifier " ++ id ++ " not found in symbol table")
|
Loading…
Reference in New Issue
Block a user