diff --git a/src/main/scala/unify/algorithm/Rules.scala b/src/main/scala/unify/algorithm/Rules.scala index 5418a02..1401132 100644 --- a/src/main/scala/unify/algorithm/Rules.scala +++ b/src/main/scala/unify/algorithm/Rules.scala @@ -1,22 +1,62 @@ package unify.algorithm -import unify.model.LessDot +import unify.model.{ExtendsRelations, LessDot} -object Rules { - private def adaptRule(cons: LessDot, cs: ConstraintSet, fc: FiniteClosure) = { - def paramSubst(param: Type, paramMap: Map[Type, Type]): Type = param match { - case RefType(w, n, params) => RefType(w, n, params.map(paramSubst(_, paramMap))) - case typeVariable => paramMap.get(typeVariable).get - } +import unify.model.* - val left = cons.left.asInstanceOf[RefType] - val right = cons.right.asInstanceOf[RefType] - if (fc.isPossibleSupertype(left.name, right.name)) { - val subtypeRelation = fc.supertype(left.name) // C <. D - val paramMap = subtypeRelation._1.params.zip(left.params).toMap - val newParams = subtypeRelation._2.params.map(paramSubst(_, paramMap)) - cs.remove(cons) - cs.add(LessDot(RefType(left.wildcards, right.name, newParams), right)) - } +object Rules: + def applyRules(cs: ConstraintSet, tvFactory: TypeVariableFactory, fc: ExtendsRelations): Unit = { + cs.getEqualsDotCons.foreach(cons => { + if (cons.left.equals(cons.right)) { //erase rule + cs.remove(cons) + return + } + /* + else if (cons.left.isInstanceOf[Wildcard] && cons.right.isInstanceOf[Wildcard]) { + normalizeRule(cons, cs) + return + } + else if (cons.left.isInstanceOf[Wildcard] && !cons.right.isInstanceOf[WTypeVariable]) { + //Tame rule: + tameRule(cons, cs) + return + } + else if (cons.left.isInstanceOf[RefType] && cons.right.isInstanceOf[Wildcard]) { + //Swap rule: + cs.remove(cons) + cs.add(EqualsDot(cons.right, cons.left)) + return + } + else if (!cons.left.isInstanceOf[TypeVariable] && cons.right.isInstanceOf[TypeVariable]) { + //swap rule: + cs.remove(cons) + cs.add(EqualsDot(cons.right, cons.left)) + return + } + else if ((cons.left.isInstanceOf[RefType] || cons.left.isInstanceOf[Wildcard]) && cons.right.isInstanceOf[WTypeVariable]) { + //swap rule: + cs.remove(cons) + cs.add(EqualsDot(cons.right, cons.left)) + return + } + else if (cons.left.isInstanceOf[RefType] && cons.right.isInstanceOf[RefType] && + cs.fv(cons.left).isEmpty && cs.fv(cons.right).isEmpty) { // reduceEq rule + cs.remove(cons) + cs.add(LessDot(cons.left, cons.right)) + cs.add(LessDot(cons.right, cons.left)) + return + } + else if (cons.left.isInstanceOf[RefType] && cons.right.isInstanceOf[RefType] + && cons.left.asInstanceOf[RefType].name.equals(cons.right.asInstanceOf[RefType].name) + && cons.left.asInstanceOf[RefType].wildcards.equals(cons.right.asInstanceOf[RefType].wildcards)) { + reduceEqRule(cons, cs, tvFactory) + return + } + */ + }) } -} + + def adaptRule(cons: LessDot, cs: ConstraintSet, fc: ExtendsRelations) = { + + } + diff --git a/src/main/scala/unify/algorithm/TypeVariableFactory.scala b/src/main/scala/unify/algorithm/TypeVariableFactory.scala new file mode 100644 index 0000000..6afa8ca --- /dev/null +++ b/src/main/scala/unify/algorithm/TypeVariableFactory.scala @@ -0,0 +1,23 @@ +package unify.algorithm + +import unify.model.* +import java.util.concurrent.atomic.AtomicInteger + +class TypeVariableFactory : + val tpvNum = new AtomicInteger() + + def freshWTV(): WTypeVariable = { + val current = tpvNum.incrementAndGet() + WTypeVariable(current.toString) + } + + def freshTV() = { + val current = tpvNum.incrementAndGet() + TypeVariable(current.toString) + } + + def freshName() = { + val current = tpvNum.incrementAndGet() + current.toString + } + diff --git a/src/main/scala/unify/model/ConstraintSet.scala b/src/main/scala/unify/model/ConstraintSet.scala index 1d3c2d4..fcaa86a 100644 --- a/src/main/scala/unify/model/ConstraintSet.scala +++ b/src/main/scala/unify/model/ConstraintSet.scala @@ -1,249 +1,20 @@ package unify.model +class ConstraintSet( private var unifier: Set[EqualsDot] = Set(), + private var lessdot: Set[LessDot] = Set(), + private var equalsdot: Set[EqualsDot] = Set(), + private var bounds: WildcardEnvironment, + private var lessdotCC : List[LessDot] = List(), + private var changed: Boolean = false) { -case class ConstraintSet(private var unifier: Set[EqualsDot] = Set(), - private var processed: Set[(TypeVariable, TypeVariable)] = Set(), - private var lessdot: List[LessDot] = List(), - //private var lessdotFTV: List[LessDot] = List(), //lessdot Constraints containing free type variables - private var equalsdot: Set[EqualsDot] = Set(), - private var bounds: mutable.Map[Wildcard, (Type,Type)] = mutable.HashMap(), //(upperBound, lowerBound) - private var changed: Boolean = false) { - - /** - * Checks if a type variable is used as lower bound and nowhere else - * This means, that the type variable is only present in a <. T constraints on the left side - * and not used as an upper bound in wildcards - */ - def isOnlyLowerBound(tv: TypeVariable): Boolean = { - def tvInType(t: Type): Boolean = t match { - case RefType(_, _, params) => !params.find(tvInType).isEmpty - case x => x.equals(tv) - } - val alessdotT : Set[Constraint] = getLessDotCons.filter(cons => cons.left.equals(tv) && !tvInType(cons.right)).toSet - val tvIsOnlyInAlessdotTCons = getAllConstraints.filter(!alessdotT.contains(_)).find(cons => tvInType(cons.left) || tvInType(cons.right)).isEmpty - bounds.map(_._2._2).toSet.contains(tv) && tvIsOnlyInAlessdotTCons && !bounds.map(_._2._1).toSet.contains(tv) - } - - /** - * ground rule. removes all a <. T constraints in cons and sets a to bottom type - */ - def groundRule(cons: LessDot) = { - this.remove(cons) - this.addUnifier(EqualsDot(cons.left, BotType())) - this.bounds = this.bounds.map(bound => if(bound._2._2.equals(cons.left)) { - (bound._1 -> (bound._2._1, BotType())) - } else bound) - } - - def tvs(t: Type) : Set[TypeVariable] = t match { - case TypeVariable(a) => Set(TypeVariable(a)) - case RefType(wildcards, name, params) => params.flatMap(tvs).toSet - case _ => Set() - } - def fv(t: Type) : Set[Type] = t match { - case TypeVariable(a) => Set() - case WTypeVariable(a) => Set(WTypeVariable(a)) - case RefType(wc, _, ps) => { - ps.map(fv).flatten.filter(t => !wc.contains(t)).toSet - } - case Wildcard(w) => Set(Wildcard(w)) - } - - /** - * Method removes all wildcard type variables a? from the constraint set - * and replaces them with normal type variables a - */ - def trimWTVs() = { - def trimWTV(t: Type): Type = t match { - case WTypeVariable(n) => { - changed = true - TypeVariable(n) - } - case RefType(ws, n, ps) => RefType(ws.map(trimWTV).asInstanceOf[List[Wildcard]], n, ps.map(trimWTV)) - case Wildcard(name) => { - Wildcard(name) - } - case x => x - } - unifier = unifier.map(c => EqualsDot(trimWTV(c.left), trimWTV(c.right))) - lessdot = lessdot.map(c => LessDot(trimWTV(c.left), trimWTV(c.right))) - equalsdot = equalsdot.map(c => EqualsDot(trimWTV(c.left), trimWTV(c.right))) - bounds = bounds.map(b => b._1 -> (trimWTV(b._2._1), trimWTV(b._2._2))) - } - - def generateFreeWildcard(wc: Wildcard, tvFactory: TypeVariableFactory): Wildcard = { - val freshWC = Wildcard(tvFactory.freshName()) - bounds.put(freshWC, bounds(wc)) - freshWC - } - - def freshWildcard(tvFactory: TypeVariableFactory, wcFactory: WildcardFactory): Wildcard = { - val newWC = Wildcard(wcFactory.freshName()) - bounds.put(newWC,(tvFactory.freshTV(),tvFactory.freshTV())) - newWC - } - - /** - * checks if from <.* to - * - * @return - */ - def isLinked(from: TypeVariable, to: TypeVariable): Boolean = { - var searchNext = Set(from) - var visited: Set[TypeVariable] = Set() - do { - searchNext --= visited - searchNext = searchNext.flatMap(tv => { - visited = visited + tv - this.getALessDotBCons.filter(c => c.left.equals(tv)).map(c => c.right.asInstanceOf[TypeVariable]) - }) - if (searchNext.contains(to)) return true - } while (searchNext.nonEmpty) - false - } - - def notProcessed(a: TypeVariable, b: TypeVariable): Boolean = !processed.contains((a,b)) - - def hasChanged(): Boolean = { - val ret = changed - changed = false - ret - } - - def setProcessedByAdopt(a: TypeVariable, b: TypeVariable): Unit = { - processed = processed + ((a,b)) - } - - /** - * Also returns unifier a =. T and all other constraints a <. T, a <. b - * @return - */ - private def getAllConstraints: Set[Constraint] = Set[Constraint]() ++ lessdot ++ equalsdot ++ unifier - def getLessDotCons: List[LessDot] = lessdot - - def getEqualsDotCons: Set[EqualsDot] = equalsdot - - def getALessDotBCons: Set[LessDot] = lessdot.filter(c => c.left.isInstanceOf[TypeVariable] && c.right.isInstanceOf[TypeVariable]).toSet - - def getALessDotCCons() = getLessDotCons.filter(c => c match { - case LessDot(TypeVariable(_), RefType(_, _, _)) => true - case _ => false - }) - - def addUnifier(c: EqualsDot) = { - changed = true - unifier = unifier + c - } - - def copy(): ConstraintSet = { - val ret = ConstraintSet() - ret.processed = processed - ret.lessdot = lessdot - ret.equalsdot = equalsdot - ret.unifier = unifier - ret.bounds ++= bounds - ret - } - - def substitute(withType: Type, toSubstitute: Type): Unit = { - def substitute(inType: Type): Type = if(inType.equals(toSubstitute)) { - withType - } else {inType match { - case RefType(wildcards, name, params) => RefType(wildcards.map(substitute).map(_.asInstanceOf[Wildcard]), name, params.map(substitute)) - case t => t - } - } - getEqualsDotCons.foreach(c => { - val subst = EqualsDot(substitute(c.left), substitute(c.right)) - if(! c.equals(subst)){ - remove(c) - add(subst) - } - }) - lessdot = getLessDotCons.map(c => { - LessDot(substitute(c.left), substitute(c.right)) - }) - unifier = unifier.map(c => { - EqualsDot(substitute(c.left), substitute(c.right)) - }) - this.bounds = this.bounds.map(wc => wc._1 -> (substitute(wc._2._1), substitute(wc._2._2))) - } - - def addAll(value: Iterable[Constraint]) = value.map(add) - - def add(cons: Constraint): Unit = { - cons match { - case EqualsDot(l,r) => equalsdot += (EqualsDot(l,r)) - case LessDot(TypeVariable(a), TypeVariable(b)) => { - lessdot ::= (LessDot(TypeVariable(a), TypeVariable(b))) - } - case LessDot(l,r) => lessdot ::= (LessDot(l,r)) - } - changed = true - } + def getEqualsDotCons = equalsdot def remove(cons: Constraint): Unit = { cons match { - case EqualsDot(l, r) => { - equalsdot -= (EqualsDot(l, r)) - } - case LessDot(TypeVariable(a), TypeVariable(b)) => { - lessdot = lessdot.diff(List(LessDot(TypeVariable(a), TypeVariable(b)))) - } - case LessDot(l, r) => lessdot = lessdot.diff(List(LessDot(l, r))) + case EqualsDot(l, r) => equalsdot -= (EqualsDot(l, r)) + case LessDot(l, r) => lessdot -= LessDot(l, r) + case LessDotCC(left, right) => lessdotCC.diff(List(LessDotCC(left, right))) } changed = true } - def removeAll(cons: List[Constraint]): Unit = cons.foreach(remove) - - private def isSolvedConstraint(c :Constraint) = c match { - case LessDot(TypeVariable(_), RefType(wc, n, ps)) => - fv(RefType(wc, n, ps)).isEmpty - case EqualsDot(TypeVariable(a), RefType(wc, n, ps)) => - !tvs(RefType(wc, n, ps)).contains(TypeVariable(a)) - case EqualsDot(TypeVariable(a), TypeVariable(b)) => true - case _ => false - } - def isSolvedForm: Boolean = getEqualsDotCons.isEmpty && - getLessDotCons.map(isSolvedConstraint).reduceOption((a,b) => a && b).getOrElse(true) && - getAllConstraints.map(_.left).size == getAllConstraints.size //No TV is double on the left side - //unifier.map(isSolvedConstraint).reduceOption((a,b) => a && b).getOrElse(true) - - def solvedForm(): Option[UnifySolution] = { - var sigma = Map[TypeVariable, Type]() - var delta: Map[Wildcard, (Type, Type)] = Map() - - //GenDelta - do{ - val toRemove = getLessDotCons.filter(c => tvs(c.right).isEmpty).map(c => { - val B = Wildcard("X"+c.left.asInstanceOf[TypeVariable].name) - sigma = sigma + (c.left.asInstanceOf[TypeVariable] -> B) - delta = delta + (B -> (c.right, BotType())) - c - }) - this.removeAll(toRemove) - //Do the substitution afterwards: - sigma.map(c => substitute(c._2, c._1)) - }while(hasChanged()) - - //AddDelta - this.bounds.filter(wc => tvs(wc._2._1).isEmpty && tvs(wc._2._2).isEmpty).map(wc =>{ - delta = delta + (wc._1 -> wc._2) - }) - if(this.bounds.filterNot(wc => tvs(wc._2._1).isEmpty && tvs(wc._2._2).isEmpty).nonEmpty)return None - - //GenSigma - val toRemove = unifier.filter(c => tvs(c.right).isEmpty).map(c => { - sigma = sigma + (c.left.asInstanceOf[TypeVariable] -> c.right) - c - }) - this.unifier = this.unifier.filterNot(toRemove.contains(_)) - - if(this.getAllConstraints.isEmpty){ - val ret = new UnifySolution(delta, sigma) - Some(ret) - }else{ - None - } - } } diff --git a/src/main/scala/unify/model/ExtendsRelations.scala b/src/main/scala/unify/model/ExtendsRelations.scala new file mode 100644 index 0000000..3db7692 --- /dev/null +++ b/src/main/scala/unify/model/ExtendsRelations.scala @@ -0,0 +1,13 @@ +package unify.model + +class ExtendsRelations(val extendsRelations : Set[(RefType, RefType)]): + + def supertype(of: String): (RefType, RefType) = extendsRelations.find(r => r._1.name.equals(of)).get + + def isPossibleSupertype(of: String, superType: String): Boolean = { + if (of.equals(superType)) return true + extendsRelations.map(p => { + return false + }) + false + } diff --git a/src/main/scala/unify/model/Type.scala b/src/main/scala/unify/model/Type.scala index 1e2483f..e901af2 100644 --- a/src/main/scala/unify/model/Type.scala +++ b/src/main/scala/unify/model/Type.scala @@ -1,6 +1,11 @@ package unify.model -sealed abstract class Type +sealed abstract class Type: + def tvs(t: Type): Set[TypeVariable] = t match { + case TypeVariable(a) => Set(TypeVariable(a)) + case RefType(wildcards, name, params) => params.flatMap(tvs).toSet + case _ => Set() + } final case class RefType(wildcards: WildcardEnvironment, name: String, params: List[Type]) extends Type{ override def toString: String = {