diff --git a/src/main/scala/hb/dhbw/InsertTypes.scala b/src/main/scala/hb/dhbw/InsertTypes.scala index 361c741..41cef41 100644 --- a/src/main/scala/hb/dhbw/InsertTypes.scala +++ b/src/main/scala/hb/dhbw/InsertTypes.scala @@ -46,17 +46,17 @@ object InsertTypes { case GenericType(name) => Some(name) case _ => None }).toSet - def convertFromUnifyConstraint(t: UnifyType): Type = t match { + def refTypeToGenerics(t: UnifyType): Type = t match { case UnifyTV(a) => TypeVariable(a) case UnifyRefType(n, List()) => if(genericNames.contains(n)) GenericType(n) else RefType(n, List()) - case UnifyRefType(n, params) => RefType(n, params.map(convertFromUnifyConstraint(_))) + case UnifyRefType(n, params) => RefType(n, params.map(refTypeToGenerics(_))) } - def convertFromUnifyConstraint(c: UnifyConstraint): Constraint = c match { - case UnifyLessDot(a, b) => LessDot(convertFromUnifyConstraint(a), convertFromUnifyConstraint(b)) + def refTypeInConsToGenerics(c: UnifyConstraint): Constraint = c match { + case UnifyLessDot(a, b) => LessDot(refTypeToGenerics(a), refTypeToGenerics(b)) + case UnifyEqualsDot(a, b) => EqualsDot(refTypeToGenerics(a), refTypeToGenerics(b)) } - def convertFromUnifyConstraint(cons: Set[UnifyConstraint]): Set[Constraint] = cons.map(convertFromUnifyConstraint(_)) - // TODO - val constraints = flatted.map(convertFromUnifyConstraint(_)) + + val constraints = flatted.map(_.map(refTypeInConsToGenerics(_))) /* @@ -101,6 +101,7 @@ object InsertTypes { def replaceTVWithGeneric(in: Type): Type = in match { case TypeVariable(name) => GenericType(name) case RefType(name, params) => RefType(name, params.map(replaceTVWithGeneric(_))) + case GenericType(n) => GenericType(n) } def substType(t: Type) = constraints.map(_ match { case EqualsDot(t1, t2) => if(t.equals(t1)) t2 else null @@ -109,27 +110,28 @@ object InsertTypes { .map(replaceTVWithGeneric(_)) .getOrElse(if(t.isInstanceOf[TypeVariable]) GenericType(t.asInstanceOf[TypeVariable].name) else t) - def getAllGenerics(from: Type): Set[Type] = from match { - case RefType(name, params) => params.flatMap(getAllGenerics(_)).toSet - case GenericType(a) => Set(GenericType(a)) + def getAllTVs(from: Type): Set[Type] = from match { + case RefType(name, params) => params.flatMap(getAllTVs(_)).toSet + case GenericType(a) => Set() + case TypeVariable(a) => Set(TypeVariable(a)) } val genericRetType = substType(into.retType) val genericParams = into.params.map(p => (substType(p._1), p._2)) - val tvsUsedInMethod = (Set(genericRetType) ++ genericParams.map(_._1)).flatMap(getAllGenerics(_)) + val tvsUsedInMethod = (Set(into.retType) ++ into.params.map(_._1)).flatMap(getAllTVs(_)) val constraintsForMethod = getLinkedConstraints(tvsUsedInMethod, constraints) + val mCons = (into.genericParams ++ constraintsForMethod).map(replaceTypeVarWithGeneric(_)) - Method(into.genericParams ++ constraintsForMethod, genericRetType, into.name, genericParams, into.retExpr) + Method(mCons, genericRetType, into.name, genericParams, into.retExpr) } - private def replaceRefTypeWithGeneric(in: UnifyConstraint, genericNames: Set[String]): Constraint= in match { - case UnifyLessDot(a,b) => LessDot(replaceRefTypeWithGeneric(a, genericNames), replaceRefTypeWithGeneric(b, genericNames)) - case UnifyEqualsDot(a, b) => EqualsDot(replaceRefTypeWithGeneric(a, genericNames), replaceRefTypeWithGeneric(b, genericNames)) + private def replaceTypeVarWithGeneric(in: Constraint): Constraint= in match { + case LessDot(a,b) => LessDot(replaceTypeVarWithGeneric(a), replaceTypeVarWithGeneric(b)) + case EqualsDot(a, b) => EqualsDot(replaceTypeVarWithGeneric(a), replaceTypeVarWithGeneric(b)) } - private def replaceRefTypeWithGeneric(in: UnifyType, genericNames: Set[String]) : Type = in match { - case UnifyRefType(name, List()) => if(genericNames.contains(name)) GenericType(name) else RefType(name,List()) - case UnifyRefType(name, params) => RefType(name, params.map(replaceRefTypeWithGeneric(_, genericNames))) - case UnifyTV(name) => GenericType(name) + private def replaceTypeVarWithGeneric(in: Type) : Type = in match { + case TypeVariable(name) => GenericType(name) + case x => x } } diff --git a/src/main/scala/hb/dhbw/Unify.scala b/src/main/scala/hb/dhbw/Unify.scala index 05f44d9..3652ae6 100644 --- a/src/main/scala/hb/dhbw/Unify.scala +++ b/src/main/scala/hb/dhbw/Unify.scala @@ -259,7 +259,7 @@ object Unify { case UnifyTV(a) => tv.equals(UnifyTV(a)) case UnifyRefType(a,p) => paramsContain(tv, UnifyRefType(a,p)) }).isDefined - def substStep(eq: Set[UnifyConstraint]) = { + def substStep(eq: Set[UnifyConstraint]): Step4Result = { def substCall(eq: Set[UnifyConstraint]) = eq.find(c => c match { case UnifyEqualsDot(UnifyTV(a), UnifyRefType(n, p)) => !paramsContain(UnifyTV(a), UnifyRefType(n,p)) case UnifyEqualsDot(UnifyTV(a), UnifyTV(b)) => !a.equals(b)