diff --git a/src/main/scala/hb/dhbw/AST.scala b/src/main/scala/hb/dhbw/AST.scala index 594a2a2..ae1e824 100644 --- a/src/main/scala/hb/dhbw/AST.scala +++ b/src/main/scala/hb/dhbw/AST.scala @@ -14,6 +14,7 @@ final case class FieldVar(e: Expr, f: String) extends Expr final case class MethodCall(e: Expr, name: String, params: List[Expr]) extends Expr final case class Constructor(className: String, params: List[Expr]) extends Expr final case class Cast(to: Type, expr: Expr) extends Expr +final case class Lambda(p: String, expr: Expr) extends Expr object ASTBuilder { def fromParseTree(toAst: List[ParserClass]) = new ASTBuilderMonad().fromParseTree(toAst) @@ -37,6 +38,7 @@ object ASTBuilder { case PFieldVar(e, f) => FieldVar(fromParseExpr(e, genericNames), f) case PCast(ntype, e) => Cast(nTypeToType(ntype, genericNames), fromParseExpr(e, genericNames)) case PLocalVar(n) => LocalVar(n) + case PLambda(p, e) => Lambda(p, fromParseExpr(e, genericNames)) } private def freshTPV() = { diff --git a/src/main/scala/hb/dhbw/Main.scala b/src/main/scala/hb/dhbw/Main.scala index 57fded6..9dc2327 100644 --- a/src/main/scala/hb/dhbw/Main.scala +++ b/src/main/scala/hb/dhbw/Main.scala @@ -53,6 +53,7 @@ object Main { case FieldVar(e, f) => prettyPrintExpr(e)+"."+f case MethodCall(e, name, params) => prettyPrintExpr(e)+"."+name+"("+params.map(prettyPrintExpr(_)).mkString(", ")+")" case Constructor(className, params) => "new "+className+"(" + params.map(prettyPrintExpr(_)).mkString(", ") +")" + case Lambda(a, expr) => "(" + a + ") -> " + prettyPrintExpr(expr) } def prettyPrintType(l: Type): String = l match { case RefType(name, List()) => name diff --git a/src/main/scala/hb/dhbw/Parser.scala b/src/main/scala/hb/dhbw/Parser.scala index d11378d..869e5fa 100644 --- a/src/main/scala/hb/dhbw/Parser.scala +++ b/src/main/scala/hb/dhbw/Parser.scala @@ -12,6 +12,7 @@ final case class PFieldVar(e: ParserExpr, f: String) extends ParserExpr final case class PMethodCall(e: ParserExpr, name: String, params: List[ParserExpr]) extends ParserExpr final case class PConstructor(className: String, params: List[ParserExpr]) extends ParserExpr final case class PCast(to: NType, expr: ParserExpr) extends ParserExpr +final case class PLambda(param: String, expr: ParserExpr) extends ParserExpr final case class NType(name: String, params: List[NType]) @@ -35,7 +36,7 @@ object Parser { .map(ite => ite._2.map(params => params._1 :: params._2).getOrElse(List.empty)) def variable[_: P]: P[ParserExpr] = P(ident).map(PLocalVar) def cast[_: P]: P[ParserExpr] = P("(" ~ typeParser ~ ")" ~ expr).map(x => PCast(x._1, x._2)) - def expr[_: P]: P[ParserExpr] = P( (variable | constructor | cast)~ (prefixMethodCall | fieldVar).rep.map(_.toList) ) + def expr[_: P]: P[ParserExpr] = P( (variable | lambdaExpr | constructor | cast)~ (prefixMethodCall | fieldVar).rep.map(_.toList) ) .map(ite => ite._2.foldLeft(ite._1) { (e1 : ParserExpr, e2 : ParserExpr) => e2 match{ case PMethodCall(_, name, e3) => PMethodCall(e1, name, e3) @@ -43,6 +44,8 @@ object Parser { } }) + def lambdaExpr[_: P]: P[PLambda] = P( "(" ~ ident ~ ")" ~ "->" ~ expr).map(it => PLambda(it._1, it._2)) + def constructor[_: P]: P[ParserExpr] = P( kw("new") ~ methodCall).map(m => PConstructor(m.name,m.params)) def classDefinition[_: P]: P[ParserClass] = P(kw("class") ~ ident ~ genericParamList.? ~ kw("extends") ~ typeParser ~ "{" ~ field.rep(0) ~ method.rep(0) ~ "}") diff --git a/src/main/scala/hb/dhbw/TYPE.scala b/src/main/scala/hb/dhbw/TYPE.scala index 981339e..22b2686 100644 --- a/src/main/scala/hb/dhbw/TYPE.scala +++ b/src/main/scala/hb/dhbw/TYPE.scala @@ -94,6 +94,14 @@ object TYPE { val (rty, cons) = TYPEExpr(expr, localVars, ast) (casttype, cons) } + case Lambda(p, expr) => { + val t = freshTPV() + val a = freshTPV() + val b = freshTPV() + val typeRet = TYPEExpr(expr, (t, p) :: localVars, ast) + val cons = typeRet._2 ++ List(LessDot(t, a), LessDot(typeRet._1, b)) + (RefType("Function", List(a, b)), cons) + } } private def findMethods(m: String, numParams: Int, ast: List[Class]) = diff --git a/src/main/scala/hb/dhbw/Unify.scala b/src/main/scala/hb/dhbw/Unify.scala index 25a2864..4cc5d93 100644 --- a/src/main/scala/hb/dhbw/Unify.scala +++ b/src/main/scala/hb/dhbw/Unify.scala @@ -132,6 +132,11 @@ object Unify { ) } + def eraseFRule(eq: Set[UnifyConstraint]): Set[UnifyConstraint]= eq.filter(_ match { + case UnifyLessDot(UnifyRefType("Function", _), UnifyRefType("Object", _)) => false + case _ => true + }) + def reduceRule(eq: Set[UnifyConstraint]) = eq.flatMap(c => c match { case UnifyEqualsDot(UnifyRefType(an, ap), UnifyRefType(bn, bp)) => { if(an.equals(bn)){ @@ -318,7 +323,7 @@ object Unify { var eqFinish: Set[UnifyConstraint] = eq do{ eqNew = doWhileSome(Unify.equalsRule,eqFinish) //We have to apply equals rule first, to get rid of circles - eqFinish = eraseRule(swapRule(reduceRule(matchRule(adoptRule(adaptRule(eqNew, fc), fc), fc)))) + eqFinish = eraseRule(swapRule(reduceRule(eraseFRule(matchRule(adoptRule(adaptRule(eqNew, fc), fc), fc))))) }while(!eqNew.equals(eqFinish)) eqNew } diff --git a/src/test/scala/IntegrationTest.scala b/src/test/scala/IntegrationTest.scala index b844f16..9617ef7 100644 --- a/src/test/scala/IntegrationTest.scala +++ b/src/test/scala/IntegrationTest.scala @@ -24,6 +24,11 @@ class IntegrationTest extends FunSuite { val result = FJTypeinference.typeinference("class List extends Object{\n add(a){\n return this;\n}\n}") println(result.map(Main.prettyPrintAST(_))) } + + test("lambdaIdentity"){ + val result = FJTypeinference.typeinference("class Test extends Object{\ntest(){\nreturn (a) -> a;\n}\n}") + println(result.map(Main.prettyPrintAST(_))) + } /* test("PaperExample"){ val result = FJTypeinference.typeinference("class List extends Object{\n add( a){\n return this;\n}\n}\nclass Test extends Object{\nm(a){ return a.add(this);}\n}")