diff --git a/src/main/scala/hb/dhbw/FJTypeinference.scala b/src/main/scala/hb/dhbw/FJTypeinference.scala index 1220667..31a3d2d 100644 --- a/src/main/scala/hb/dhbw/FJTypeinference.scala +++ b/src/main/scala/hb/dhbw/FJTypeinference.scala @@ -31,16 +31,20 @@ object FJTypeinference { def typeinference(str: String): Either[String, Set[Set[UnifyConstraint]]] = { val ast = Parser.parse(str).map(ASTBuilder.fromParseTree(_)) + val typeResult = ast.map(ast => { - /*ast.foldLeft(List())((cOld, c) => { - val typeResult = TYPE.generateConstraints(ast, generateFC(ast)) + var unifyResults = Set[Set[Set[UnifyConstraint]]]() + ast.foldLeft(List[Class]())((cOld, c) => { + val newClassList = cOld :+ c + val typeResult = TYPE.generateConstraints(newClassList, generateFC(newClassList)) val unifyResult = Unify.unify(convertOrConstraints(typeResult._1), typeResult._2) - //TODO: Insert intersection types - List(c) - }) */ - TYPE.generateConstraints(ast, generateFC(ast)) + //Insert intersection types + val typeInsertedC = InsertTypes.insert(unifyResult, c) + unifyResults = unifyResults + unifyResult + cOld :+ typeInsertedC + }) + unifyResults }) - val unifyResult = typeResult.map(res => Unify.unify(convertOrConstraints(res._1), res._2)) - unifyResult + typeResult.map(_.flatten) } } diff --git a/src/main/scala/hb/dhbw/InsertTypes.scala b/src/main/scala/hb/dhbw/InsertTypes.scala index c2682e1..9a19904 100644 --- a/src/main/scala/hb/dhbw/InsertTypes.scala +++ b/src/main/scala/hb/dhbw/InsertTypes.scala @@ -1,22 +1,40 @@ package hb.dhbw -class InsertTypes { +object InsertTypes { def insert(unifyResult: Set[Set[UnifyConstraint]], into: Class): Class = { - val constraints = unifyResult.map(_.map(replaceTVWithGeneric(_))) + + def extractTVNames(unifyType: UnifyType): Set[String] = unifyType match { + case UnifyTV(name) => Set(name) + case UnifyRefType(_, params) => params.flatMap(extractTVNames(_)).toSet + } + + val genericNames:Set[String] = into.genericParams.map(_._1).flatMap(_ match { + case GenericType(name) => Some(name) + case _ => None + }).toSet ++ unifyResult.flatMap(_.flatMap(_ match{ + case UnifyLessDot(a,b) => Set(a, b) + case UnifyEqualsDot(a,b) => Set(a,b) + })).flatMap(extractTVNames(_)) + val constraints = unifyResult.map(_.map(replaceTVWithGeneric(_, genericNames))) val newMethods = into.methods.flatMap(m => constraints.map(cons => insert(cons, m))) Class(into.name, into.genericParams, into.superType, into.fields, newMethods) } private def insert(constraints: Set[Constraint], into: Method): Method = { - Method(into.genericParams ++ constraints, into.retType, into.name, into.params, into.retExpr) + def replaceTVWithGeneric(in: Type): Type = in match { + case TypeVariable(name) => GenericType(name) + case RefType(name, params) => RefType(name, params.map(replaceTVWithGeneric(_))) + } + Method(into.genericParams ++ constraints, replaceTVWithGeneric(into.retType), into.name, into.params.map(p => (replaceTVWithGeneric(p._1), p._2)), into.retExpr) } - private def replaceTVWithGeneric(in: UnifyConstraint): Constraint= in match { - case UnifyLessDot(a,b) => LessDot(replaceTVWithGeneric(a), replaceTVWithGeneric(b)) - case UnifyEqualsDot(a, b) => EqualsDot(replaceTVWithGeneric(a), replaceTVWithGeneric(b)) + private def replaceTVWithGeneric(in: UnifyConstraint, genericNames: Set[String]): Constraint= in match { + case UnifyLessDot(a,b) => LessDot(replaceTVWithGeneric(a, genericNames), replaceTVWithGeneric(b, genericNames)) + case UnifyEqualsDot(a, b) => EqualsDot(replaceTVWithGeneric(a, genericNames), replaceTVWithGeneric(b, genericNames)) } - private def replaceTVWithGeneric(in: UnifyType) : Type = in match { - case UnifyRefType(name, params) => RefType(name, params.map(replaceTVWithGeneric(_))) + private def replaceTVWithGeneric(in: UnifyType, genericNames: Set[String]) : Type = in match { + case UnifyRefType(name, List()) => if(genericNames.contains(name)) GenericType(name) else RefType(name,List()) + case UnifyRefType(name, params) => RefType(name, params.map(replaceTVWithGeneric(_, genericNames))) case UnifyTV(name) => GenericType(name) }