From ce48177ac281f508637556bf847309d1950ae19e Mon Sep 17 00:00:00 2001 From: JanUlrich Date: Wed, 3 Nov 2021 00:13:33 +0100 Subject: [PATCH] Error in WrongEQSet Test, but runnable --- src/main/scala/hb/dhbw/AST.scala | 17 ++- src/main/scala/hb/dhbw/CartesianProduct.scala | 13 ++ src/main/scala/hb/dhbw/FJTypeinference.scala | 16 ++- src/main/scala/hb/dhbw/FiniteClosure.scala | 22 ++-- src/main/scala/hb/dhbw/InsertTypes.scala | 5 + src/main/scala/hb/dhbw/Main.scala | 13 +- src/main/scala/hb/dhbw/TYPE.scala | 40 ++++-- src/main/scala/hb/dhbw/Unify.scala | 120 ++++++++++-------- src/test/scala/CartesianProductTest.scala | 22 ++++ src/test/scala/IntegrationTest.scala | 9 ++ src/test/scala/UnifyTest.scala | 17 +-- 11 files changed, 206 insertions(+), 88 deletions(-) create mode 100644 src/main/scala/hb/dhbw/InsertTypes.scala create mode 100644 src/test/scala/CartesianProductTest.scala diff --git a/src/main/scala/hb/dhbw/AST.scala b/src/main/scala/hb/dhbw/AST.scala index 7c0681b..08f35c8 100644 --- a/src/main/scala/hb/dhbw/AST.scala +++ b/src/main/scala/hb/dhbw/AST.scala @@ -5,6 +5,7 @@ final case class Method(retType: Type, name: String, params: List[(Type, String) sealed trait Type final case class RefType(name: String, params: List[Type]) extends Type +final case class GenericType(name: String) extends Type final case class TypeVariable(name: String) extends Type sealed trait Expr @@ -20,9 +21,12 @@ object ASTBuilder { var tpvNum = 0 - def fromParseTree(toAst: List[ParserClass]) = toAst.map(c => Class(c.name, c.params.map(p => (nTypeToType(p._1), nTypeToType(p._2))), - nTypeToType(c.superType).asInstanceOf[RefType], - c.fields.map(f => (nTypeToType(f._1),f._2)), c.methods.map(m => Method(freshTPV(), m.name, m.params.map(p => (freshTPV(), p)), m.retExpr)))) + def fromParseTree(toAst: List[ParserClass]) = toAst.map(c => { + val genericNames = c.params.map(_._1).map(_.name).toSet + Class(c.name, c.params.map(p => (nTypeToType(p._1, genericNames), nTypeToType(p._2, genericNames))), + nTypeToType(c.superType, genericNames).asInstanceOf[RefType], + c.fields.map(f => (nTypeToType(f._1, genericNames),f._2)), c.methods.map(m => Method(freshTPV(), m.name, m.params.map(p => (freshTPV(), p)), m.retExpr))) + }) private def freshTPV() = { def numToLetter(num: Int) = { @@ -39,6 +43,11 @@ object ASTBuilder { tpvNum = tpvNum+1 TypeVariable(numToLetter(tpvNum)) } - private def nTypeToType(t : NType): Type = RefType(t.name, t.params.map(nTypeToType)) + + private def nTypeToType(t : NType, genericNames: Set[String]): Type = if(t.params.isEmpty && genericNames.contains(t.name)) { + GenericType(t.name) + }else{ + RefType(t.name, t.params.map(p => nTypeToType(p, genericNames))) + } } } diff --git a/src/main/scala/hb/dhbw/CartesianProduct.scala b/src/main/scala/hb/dhbw/CartesianProduct.scala index 0053abd..798c87a 100644 --- a/src/main/scala/hb/dhbw/CartesianProduct.scala +++ b/src/main/scala/hb/dhbw/CartesianProduct.scala @@ -1,6 +1,19 @@ package hb.dhbw class CartesianProduct[A](private val setOfSets: List[List[A]]){ + def productWith(product: CartesianProduct[A]) = { + val ret = new CartesianProduct[A](setOfSets ++ product.setOfSets) + var base: Long = 1 + ret.sizes = ret.setOfSets.map(_.size) + ret.sizes.foreach(size => { + ret.bases = ret.bases :+ size + base = base * size + }) + ret.max = base + ret.i = i + ret + } + private var sizes: List[Int] = null private var bases: List[Long] = List() private var max: Long = 1 diff --git a/src/main/scala/hb/dhbw/FJTypeinference.scala b/src/main/scala/hb/dhbw/FJTypeinference.scala index 09ec8de..b196f23 100644 --- a/src/main/scala/hb/dhbw/FJTypeinference.scala +++ b/src/main/scala/hb/dhbw/FJTypeinference.scala @@ -13,15 +13,25 @@ object FJTypeinference { private def convertOrConstraints(constraints: List[Constraint]): Set[Set[Set[UnifyConstraint]]] = constraints.map( convertOrCons ).toSet + def convertType(t: Type): UnifyType = t match { + case GenericType(name) => UnifyRefType(name, List()) + case RefType(n, p) => UnifyRefType(n,p.map(convertType)) + case TypeVariable(n) => UnifyTV(n) + } + private def convertSingleConstraint(constraint: Constraint) = constraint match { - case LessDot(l, r) => UnifyLessDot(l,r) - case EqualsDot(l, r) => UnifyEqualsDot(l,r) + case LessDot(l, r) => UnifyLessDot(convertType(l),convertType(r)) + case EqualsDot(l, r) => UnifyEqualsDot(convertType(l),convertType(r)) case _ => throw new Exception("Error: Möglicherweise zu tiefe Verschachtelung von OrConstraints") } + private def generateFC(ast: List[Class]): FiniteClosure = new FiniteClosure( + ast.map(c => (cToUnifyType(c), convertType(c.superType).asInstanceOf[UnifyRefType])).toSet) + private def cToUnifyType(c: Class) = UnifyRefType(c.name, c.params.map(it => convertType(it._1))) + def typeinference(str: String): Either[String, Set[Set[UnifyConstraint]]] = { val ast = Parser.parse(str).map(ASTBuilder.fromParseTree(_)) - val typeResult = ast.map(TYPE.generateConstraints(_)) + val typeResult = ast.map(ast => TYPE.generateConstraints(ast, generateFC(ast))) val unifyResult = typeResult.map(res => Unify.unify(convertOrConstraints(res._1), res._2)) unifyResult } diff --git a/src/main/scala/hb/dhbw/FiniteClosure.scala b/src/main/scala/hb/dhbw/FiniteClosure.scala index 9ad2494..c1eb6c7 100644 --- a/src/main/scala/hb/dhbw/FiniteClosure.scala +++ b/src/main/scala/hb/dhbw/FiniteClosure.scala @@ -1,9 +1,9 @@ package hb.dhbw -class FiniteClosure(val extendsRelations : Set[(RefType, RefType)]){ +class FiniteClosure(val extendsRelations : Set[(UnifyRefType, UnifyRefType)]){ - private def calculateSupertypes(of: RefType) ={ + private def calculateSupertypes(of: UnifyRefType) ={ var rel = Set((of, of)) var size = rel.size do { @@ -12,32 +12,32 @@ class FiniteClosure(val extendsRelations : Set[(RefType, RefType)]){ }while(rel.size > size) rel.map(_._2) } - private def reflexiveTypes(of: Set[(RefType, RefType)]) ={ - val ref = Set.newBuilder[(RefType, RefType)] + private def reflexiveTypes(of: Set[(UnifyRefType, UnifyRefType)]) ={ + val ref = Set.newBuilder[(UnifyRefType, UnifyRefType)] ref ++= of.map(pair => (pair._1, pair._1)) ref ++= of.map(pair => (pair._2, pair._2)) ref.result() } - private def transitiveTypes(of: Set[(RefType, RefType)]) ={ - val ref = Set.newBuilder[(RefType, RefType)] + private def transitiveTypes(of: Set[(UnifyRefType, UnifyRefType)]) ={ + val ref = Set.newBuilder[(UnifyRefType, UnifyRefType)] ref ++= of.map(pair => (pair._1, pair._1)) ref ++= of.map(pair => (pair._2, pair._2)) ref.result() } - private def superClassTypes(of: RefType) = { + private def superClassTypes(of: UnifyRefType) = { val extendsRelation = extendsRelations.filter(pair => pair._1.name.equals(of.name)) extendsRelation.map(p => { val paramMap = p._1.params.zip(of.params).toMap - (of,RefType(p._2.name, p._2.params.map(paramMap))) + (of,UnifyRefType(p._2.name, p._2.params.map(paramMap))) }) } - private def superClassTypes(of: Set[(RefType, RefType)]) : Set[(RefType, RefType)] ={ - val sClass = Set.newBuilder[(RefType, RefType)] + private def superClassTypes(of: Set[(UnifyRefType, UnifyRefType)]) : Set[(UnifyRefType, UnifyRefType)] ={ + val sClass = Set.newBuilder[(UnifyRefType, UnifyRefType)] sClass ++= of.flatMap(pair => Set(pair._2, pair._1)).flatMap(t => superClassTypes(t)) sClass.result() } - def superTypes(of : RefType) : Set[RefType] = calculateSupertypes(of) + def superTypes(of : UnifyRefType) : Set[UnifyRefType] = calculateSupertypes(of) def isPossibleSupertype(of: String, superType: String): Boolean = { val extendsMap = extendsRelations.map(p => (p._1.name,p._2.name)).toMap diff --git a/src/main/scala/hb/dhbw/InsertTypes.scala b/src/main/scala/hb/dhbw/InsertTypes.scala new file mode 100644 index 0000000..3e87a57 --- /dev/null +++ b/src/main/scala/hb/dhbw/InsertTypes.scala @@ -0,0 +1,5 @@ +package hb.dhbw + +class InsertTypes { + +} diff --git a/src/main/scala/hb/dhbw/Main.scala b/src/main/scala/hb/dhbw/Main.scala index a657881..1519d79 100644 --- a/src/main/scala/hb/dhbw/Main.scala +++ b/src/main/scala/hb/dhbw/Main.scala @@ -64,10 +64,16 @@ object Main { case UnifyLessDot(a, b) => "("+prettyPrintHTML(a)+" <. "+prettyPrintHTML(b)+")" case UnifyEqualsDot(a, b) => "("+prettyPrintHTML(a)+" =. "+prettyPrintHTML(b)+")" } + def prettyPrintHTML(t: UnifyType): String = t match { + case UnifyRefType(name, List()) => name + case UnifyRefType(name, params) => name + "<" + params.map(prettyPrintHTML).mkString(", ") + ">" + case UnifyTV(name) => "" + name + "" + } def prettyPrintHTML(t: Type): String = t match { case RefType(name, List()) => name case RefType(name, params) => name + "<" + params.map(prettyPrintHTML).mkString(", ") + ">" case TypeVariable(name) => "" + name + "" + case GenericType(name) => name } def prettyPrint(unifyResult : Set[Set[UnifyConstraint]]): String = unifyResult.map( @@ -81,6 +87,11 @@ object Main { case RefType(name, List()) => name case RefType(name, params) => name + "<" + params.map(prettyPrint).mkString(", ") + ">" case TypeVariable(name) => "_" + name + "_" + case GenericType(name) => name + } + def prettyPrint(t: UnifyType): String = t match { + case UnifyRefType(name, List()) => name + case UnifyRefType(name, params) => name + "<" + params.map(prettyPrint).mkString(", ") + ">" + case UnifyTV(name) => "_" + name + "_" } - } diff --git a/src/main/scala/hb/dhbw/TYPE.scala b/src/main/scala/hb/dhbw/TYPE.scala index a8a1ef6..3cf23ef 100644 --- a/src/main/scala/hb/dhbw/TYPE.scala +++ b/src/main/scala/hb/dhbw/TYPE.scala @@ -8,20 +8,39 @@ final case class EqualsDot(l: Type, r: Type) extends Constraint final case class LessDot(l: Type, r: Type) extends Constraint object TYPE { - def generateConstraints(c: List[Class]) = { - new TYPEMonad().TYPEClass(c) + def generateConstraints(c: List[Class], finiteClosure: FiniteClosure) = { + new TYPEMonad().TYPEClass(c, finiteClosure) + } + + private class GenericTypeReplaceMonad(tpvs: TYPEMonad){ + var genericNameToTVMap: Map[String, TypeVariable] = Map() + + def replaceGenerics(inConstraint: Constraint): Constraint = inConstraint match { + case OrConstraint(cons) => OrConstraint(cons.map(replaceGenerics(_))) + case AndConstraint(andCons) => AndConstraint(andCons.map(replaceGenerics(_))) + case LessDot(l, r) => LessDot(replaceGenerics(l), replaceGenerics(r)) + case EqualsDot(l, r) => EqualsDot(replaceGenerics(l), replaceGenerics(r)) + } + def replaceGenerics(inType: Type): Type= inType match { + case RefType(name, params) =>RefType(name, params.map(replaceGenerics(_))) + case GenericType(name) => genericNameToTVMap.get(name) + .getOrElse{ + val newTV = tpvs.freshTPV() + genericNameToTVMap = genericNameToTVMap + (name -> newTV) + newTV + } + case x => x + } } private class TYPEMonad{ var tpvNum = 0 - private def generateFC(ast: List[Class]): FiniteClosure = new FiniteClosure(ast.map(c => (cToType(c), c.superType)).toSet) - - def TYPEClass(ast: List[Class]) = { - (ast.flatMap(cl => cl.methods.flatMap(m => TYPEMethod(m, cToType(cl), ast))), generateFC(ast)) + def TYPEClass(ast: List[Class], fc: FiniteClosure) = { + (ast.flatMap(cl => cl.methods.flatMap(m => TYPEMethod(m, cToType(cl), ast))), fc) } - private def freshTPV() = { + def freshTPV() = { tpvNum = tpvNum+1 TypeVariable(tpvNum.toString) } @@ -35,13 +54,15 @@ object TYPE { case LocalVar(n) => localVars.find(it => it._2.equals(n)).map(p => (p._1, List())) .getOrElse(throw new Exception("Local Variable "+ n + " not found")) case FieldVar(e, f) => { + val genericReplace = new GenericTypeReplaceMonad(this) val (rty, cons) = TYPEExpr(e, localVars, ast) val fields = findFields(f, ast) val a = freshTPV() - val orCons = OrConstraint(fields.map(f => AndConstraint(List(EqualsDot(rty, cToType(f._1)), EqualsDot(a, f._2))))) + val orCons = OrConstraint(fields.map(f => AndConstraint(List(EqualsDot(rty, genericReplace.replaceGenerics(cToType(f._1))), EqualsDot(a, genericReplace.replaceGenerics(f._2)))))) (a, orCons :: cons) } case MethodCall(e, name, params) => { + val genericReplace = new GenericTypeReplaceMonad(this) val a = freshTPV() val (rty, cons) = TYPEExpr(e, localVars, ast) val es = params.map(ex => TYPEExpr(ex, localVars, ast)) @@ -50,7 +71,8 @@ object TYPE { List(EqualsDot(rty, cToType(m._1)), EqualsDot(a, m._2.retType)) ++ m._2.params.map(_._1).zip(es.map(_._1)).map(a => LessDot(a._2, a._1)) )) - (a, cons ++ es.flatMap(_._2) ++ List(OrConstraint(consM))) + val retCons = (cons ++ es.flatMap(_._2) ++ List(OrConstraint(consM))).map(genericReplace.replaceGenerics(_)) + (a, retCons) } case Constructor(className, params) => { throw new NotImplementedError() diff --git a/src/main/scala/hb/dhbw/Unify.scala b/src/main/scala/hb/dhbw/Unify.scala index 9604c26..9909762 100644 --- a/src/main/scala/hb/dhbw/Unify.scala +++ b/src/main/scala/hb/dhbw/Unify.scala @@ -1,12 +1,27 @@ package hb.dhbw -sealed abstract class UnifyConstraint(val left: Type, val right: Type) -final case class UnifyLessDot(override val left: Type, override val right: Type) extends UnifyConstraint(left, right) -final case class UnifyEqualsDot(override val left: Type, override val right: Type) extends UnifyConstraint(left, right) +sealed abstract class UnifyConstraint(val left: UnifyType, val right: UnifyType) +final case class UnifyLessDot(override val left: UnifyType, override val right: UnifyType) extends UnifyConstraint(left, right) +final case class UnifyEqualsDot(override val left: UnifyType, override val right: UnifyType) extends UnifyConstraint(left, right) + +sealed abstract class UnifyType +final case class UnifyRefType(name: String, params: List[UnifyType]) extends UnifyType +final case class UnifyTV(name: String) extends UnifyType object Unify { + def unifyIteratove(orCons: Set[Set[Set[UnifyConstraint]]], fc: FiniteClosure) : Set[Set[UnifyConstraint]] = { + var eqSets = new CartesianProduct[Set[UnifyConstraint]](orCons) + while(eqSets.hasNext()){ + val eqSet = eqSets.nextProduct() + val rulesResult = applyRules(fc)(eqSet.flatten) + val step2Result = step2(rulesResult, fc) + + } + Set() + } + def unify(orCons: Set[Set[Set[UnifyConstraint]]], fc: FiniteClosure) : Set[Set[UnifyConstraint]] = { val eqSets = cartesianProduct(orCons) val step2Results = eqSets.flatMap(eqSet => { @@ -25,30 +40,30 @@ object Unify { def step2(eq : Set[UnifyConstraint], fc: FiniteClosure) ={ val eq1 = eq.filter(c => c match{ - case UnifyLessDot(TypeVariable(_), TypeVariable(_)) => true - case UnifyEqualsDot(TypeVariable(_), TypeVariable(_)) => true + case UnifyLessDot(UnifyTV(_), UnifyTV(_)) => true + case UnifyEqualsDot(UnifyTV(_), UnifyTV(_)) => true case _ => false }) val cUnifyLessDotACons: Set[Set[Set[UnifyConstraint]]] = eq.map(c => c match{ - case UnifyLessDot(RefType(name,params), TypeVariable(a)) => - fc.superTypes(RefType(name,params)) - .map(superType => Set(UnifyEqualsDot(TypeVariable(a), superType).asInstanceOf[UnifyConstraint])) + case UnifyLessDot(UnifyRefType(name,params), UnifyTV(a)) => + fc.superTypes(UnifyRefType(name,params)) + .map(superType => Set(UnifyEqualsDot(UnifyTV(a), superType).asInstanceOf[UnifyConstraint])) case _ => null }).filter(s => s!=null) val aUnifyLessDota = eq1.filter(c => c match{ - case UnifyLessDot(TypeVariable(_), TypeVariable(_)) => true + case UnifyLessDot(UnifyTV(_), UnifyTV(_)) => true case _ => false }).asInstanceOf[Set[UnifyLessDot]] - val aUnifyLessDotCConsAndBs: Set[(UnifyLessDot,Option[TypeVariable])] = eq.map(c => c match{ - case UnifyLessDot(TypeVariable(a),RefType(name,params)) =>{ - val bs = aUnifyLessDota.flatMap(c => Set(c.left, c.right)).asInstanceOf[Set[TypeVariable]] - .filter(c => !a.equals(c) && isLinked(TypeVariable(a), c, aUnifyLessDota)) + val aUnifyLessDotCConsAndBs: Set[(UnifyLessDot,Option[UnifyTV])] = eq.map(c => c match{ + case UnifyLessDot(UnifyTV(a),UnifyRefType(name,params)) =>{ + val bs = aUnifyLessDota.flatMap(c => Set(c.left, c.right)).asInstanceOf[Set[UnifyTV]] + .filter(c => !a.equals(c) && isLinked(UnifyTV(a), c, aUnifyLessDota)) if(bs.isEmpty){ - Set((UnifyLessDot(TypeVariable(a),RefType(name,params)),None)) + Set((UnifyLessDot(UnifyTV(a),UnifyRefType(name,params)),None)) }else{ - bs.map(b => (UnifyLessDot(TypeVariable(a),RefType(name,params)),Some(b))) + bs.map(b => (UnifyLessDot(UnifyTV(a),UnifyRefType(name,params)),Some(b))) } } case _ => null @@ -57,30 +72,31 @@ object Unify { val aUnifyLessDotCCons = aUnifyLessDotCConsAndBs.map{ case (ac:UnifyLessDot,Some(b)) => Set(Set(UnifyLessDot(b, ac.right))) ++ - fc.superTypes(ac.right.asInstanceOf[RefType]) + fc.superTypes(ac.right.asInstanceOf[UnifyRefType]) .map(superType => Set(UnifyEqualsDot(b, superType))) - case (ac, None) => null + case (ac, None) => null }.filter(c => c != null).asInstanceOf[Set[Set[Set[UnifyConstraint]]]] val eq2 = eq.filter(c => c match{ - case UnifyLessDot(TypeVariable(_), RefType(_,_)) => true - case UnifyEqualsDot(TypeVariable(_), RefType(_,_)) => true - case UnifyEqualsDot(RefType(_,_),TypeVariable(_)) => true + case UnifyLessDot(UnifyTV(_), UnifyRefType(_,_)) => true + case UnifyEqualsDot(UnifyTV(_), UnifyRefType(_,_)) => true + case UnifyEqualsDot(UnifyRefType(_,_),UnifyTV(_)) => true case _ => false }) val eqSet = cartesianProduct(Set(Set(eq1)) ++ Set(Set(eq2)) ++ aUnifyLessDotCCons ++ cUnifyLessDotACons) - eqSet.map( s => s.flatten) + val ret = eqSet.map( s => s.flatten) + ret } private def getAUnifyLessDotC(from: Set[UnifyConstraint]) = from.filter(c => c match{ - case UnifyLessDot(TypeVariable(_), RefType(_,_)) => true + case UnifyLessDot(UnifyTV(_), UnifyRefType(_,_)) => true case _ => false }).asInstanceOf[Set[UnifyLessDot]] def matchRule(eq : Set[UnifyConstraint], fc: FiniteClosure) = { val aUnifyLessDotC = getAUnifyLessDotC(eq) (eq -- aUnifyLessDotC) ++ aUnifyLessDotC.map(c => { - val smallerC = aUnifyLessDotC.find(c2 => c2 != c && c2.left.equals(c.left) && fc.isPossibleSupertype(c2.right.asInstanceOf[RefType].name,c.right.asInstanceOf[RefType].name)) + val smallerC = aUnifyLessDotC.find(c2 => c2 != c && c2.left.equals(c.left) && fc.isPossibleSupertype(c2.right.asInstanceOf[UnifyRefType].name,c.right.asInstanceOf[UnifyRefType].name)) if(smallerC.isEmpty){ c }else{ @@ -91,28 +107,28 @@ object Unify { } def reduceRule(eq: Set[UnifyConstraint]) = eq.flatMap(c => c match { - case UnifyEqualsDot(RefType(an, ap), RefType(bn, bp)) => { + case UnifyEqualsDot(UnifyRefType(an, ap), UnifyRefType(bn, bp)) => { if(an.equals(bn)){ ap.zip(bp).map(p => UnifyEqualsDot(p._1, p._2)) }else{ - Set(UnifyLessDot(RefType(an, ap), RefType(bn, bp))) + Set(UnifyEqualsDot(UnifyRefType(an, ap), UnifyRefType(bn, bp))) } } case x => Set(x) }) def swapRule(eq : Set[UnifyConstraint]) = eq.map(c => c match { - case UnifyEqualsDot(RefType(an, ap), TypeVariable(a)) => UnifyEqualsDot(TypeVariable(a), RefType(an, ap)) + case UnifyEqualsDot(UnifyRefType(an, ap), UnifyTV(a)) => UnifyEqualsDot(UnifyTV(a), UnifyRefType(an, ap)) case x => x }) def adaptRule(eq: Set[UnifyConstraint], fc: FiniteClosure) = { eq.map(c => c match { - case UnifyLessDot(RefType(an, ap), RefType(bn, bp)) => { + case UnifyLessDot(UnifyRefType(an, ap), UnifyRefType(bn, bp)) => { if(fc.isPossibleSupertype(an, bn)){ - UnifyEqualsDot(fc.superTypes(RefType(an, ap)).find(r => r.name.equals(bn)).get, RefType(bn, bp)) + UnifyEqualsDot(fc.superTypes(UnifyRefType(an, ap)).find(r => r.name.equals(bn)).get, UnifyRefType(bn, bp)) }else{ - UnifyLessDot(RefType(an, ap), RefType(bn, bp)) + UnifyLessDot(UnifyRefType(an, ap), UnifyRefType(bn, bp)) } } case x => x @@ -121,14 +137,14 @@ object Unify { def adoptRule(eq: Set[UnifyConstraint], fc: FiniteClosure) ={ val aUnifyLessDota = eq.filter(c => c match{ - case UnifyLessDot(TypeVariable(_), TypeVariable(_)) => true + case UnifyLessDot(UnifyTV(_), UnifyTV(_)) => true case _ => false }).asInstanceOf[Set[UnifyLessDot]] val aUnifyLessDotC = getAUnifyLessDotC(eq) (eq -- aUnifyLessDotC) ++ aUnifyLessDotC.map(c => { val smallerC = aUnifyLessDotC.find(c2 => c2 != c - && isLinked(c2.left.asInstanceOf[TypeVariable], c.left.asInstanceOf[TypeVariable], aUnifyLessDota) - && fc.isPossibleSupertype(c2.right.asInstanceOf[RefType].name,c.right.asInstanceOf[RefType].name)) + && isLinked(c2.left.asInstanceOf[UnifyTV], c.left.asInstanceOf[UnifyTV], aUnifyLessDota) + && fc.isPossibleSupertype(c2.right.asInstanceOf[UnifyRefType].name,c.right.asInstanceOf[UnifyRefType].name)) if(smallerC.isEmpty){ c }else{ @@ -138,9 +154,9 @@ object Unify { ) } - private def isLinked(a: TypeVariable, b: TypeVariable, aUnifyLessDota: Set[UnifyLessDot]): Boolean = { - def getRightSides(of: TypeVariable) ={ - aUnifyLessDota.filter(c => c.left.asInstanceOf[TypeVariable].name.equals(of.name)) + private def isLinked(a: UnifyTV, b: UnifyTV, aUnifyLessDota: Set[UnifyLessDot]): Boolean = { + def getRightSides(of: UnifyTV) ={ + aUnifyLessDota.filter(c => c.left.asInstanceOf[UnifyTV].name.equals(of.name)) } val rightsides = getRightSides(a).map(c => c.right) if(rightsides.isEmpty){ @@ -148,16 +164,16 @@ object Unify { } else if (rightsides.contains(b)){ true }else{ - rightsides.foldLeft(false)((r, c) => r || isLinked(c.asInstanceOf[TypeVariable],b, aUnifyLessDota)) + rightsides.foldLeft(false)((r, c) => r || isLinked(c.asInstanceOf[UnifyTV],b, aUnifyLessDota)) } } private def findCircles(aUnifyLessDota: Set[UnifyLessDot]) ={ - def getRightSides(of: TypeVariable) ={ - aUnifyLessDota.filter(c => c.left.asInstanceOf[TypeVariable].name.equals(of.name)) + def getRightSides(of: UnifyTV) ={ + aUnifyLessDota.filter(c => c.left.asInstanceOf[UnifyTV].name.equals(of.name)) } def findCircle(graph: List[UnifyLessDot]): List[UnifyLessDot] = { - val newAdditions = getRightSides(graph.last.right.asInstanceOf[TypeVariable]) + val newAdditions = getRightSides(graph.last.right.asInstanceOf[UnifyTV]) var circle: List[UnifyLessDot] = List() val iterator = newAdditions.iterator while(iterator.hasNext && circle.isEmpty){ @@ -175,7 +191,7 @@ object Unify { def equalsRule(eq: Set[UnifyConstraint]) ={ val aUnifyLessDota = eq.filter(c => c match{ - case UnifyLessDot(TypeVariable(_), TypeVariable(_)) => true + case UnifyLessDot(UnifyTV(_), UnifyTV(_)) => true case _ => false }).asInstanceOf[Set[UnifyLessDot]] val circle = findCircles(aUnifyLessDota).find(!_.isEmpty) @@ -187,24 +203,24 @@ object Unify { } } - private def paramsContain(tv: TypeVariable, inParams: RefType): Boolean = + private def paramsContain(tv: UnifyTV, inParams: UnifyRefType): Boolean = inParams.params.find(t => t match { - case TypeVariable(a) => tv.equals(TypeVariable(a)) - case RefType(a,p) => paramsContain(tv, RefType(a,p)) + case UnifyTV(a) => tv.equals(UnifyTV(a)) + case UnifyRefType(a,p) => paramsContain(tv, UnifyRefType(a,p)) }).isDefined def substStep(eq: Set[UnifyConstraint]) = eq.find(c => c match { - case UnifyEqualsDot(TypeVariable(a), RefType(n, p)) => !paramsContain(TypeVariable(a), RefType(n,p)) - case UnifyEqualsDot(TypeVariable(a), TypeVariable(b)) => !a.equals(b) + case UnifyEqualsDot(UnifyTV(a), UnifyRefType(n, p)) => !paramsContain(UnifyTV(a), UnifyRefType(n,p)) + case UnifyEqualsDot(UnifyTV(a), UnifyTV(b)) => !a.equals(b) case _ => false - }).map(c => (subst(c.left.asInstanceOf[TypeVariable], c.right, eq), Some(c))).getOrElse((eq, None)) + }).map(c => (subst(c.left.asInstanceOf[UnifyTV], c.right, eq.filter(!_.equals(c))), Some(c))).getOrElse((eq, None)) - private def substHelper(a: TypeVariable, withType: Type,in: Type) :Type = in match { - case RefType(n, p) => RefType(n,p.map(t => substHelper(a, withType, t)).asInstanceOf[List[Type]]) - case TypeVariable(n) => + private def substHelper(a: UnifyTV, withType: UnifyType,in: UnifyType) :UnifyType = in match { + case UnifyRefType(n, p) => UnifyRefType(n,p.map(t => substHelper(a, withType, t))) + case UnifyTV(n) => if(a.equals(in)){withType}else{in} } - def subst(a: TypeVariable, substType: Type,eq: Set[UnifyConstraint]): Set[UnifyConstraint] = { + def subst(a: UnifyTV, substType: UnifyType,eq: Set[UnifyConstraint]): Set[UnifyConstraint] = { eq.map(c => c match { case UnifyLessDot(left, right) => UnifyLessDot(substHelper(a, substType, left), substHelper(a, substType, right)) case UnifyEqualsDot(left, right) => UnifyEqualsDot(substHelper(a, substType, left), substHelper(a, substType, right)) @@ -218,8 +234,8 @@ object Unify { var eqNew: Set[UnifyConstraint] = null var eqFinish: Set[UnifyConstraint] = eq do{ - eqNew = doWhileSome(Unify.equalsRule,eqFinish) - eqFinish = reduceRule(matchRule(adaptRule(adaptRule(eqNew, fc), fc), fc)) + eqNew = doWhileSome(Unify.equalsRule,eqFinish) //We have to apply equals rule first, to get rid of circles + eqFinish = swapRule(reduceRule(matchRule(adoptRule(adaptRule(eqNew, fc), fc), fc))) }while(!eqNew.equals(eqFinish)) eqNew } diff --git a/src/test/scala/CartesianProductTest.scala b/src/test/scala/CartesianProductTest.scala new file mode 100644 index 0000000..e6281dc --- /dev/null +++ b/src/test/scala/CartesianProductTest.scala @@ -0,0 +1,22 @@ +import hb.dhbw.CartesianProduct +import org.scalatest.FunSuite + +class CartesianProductTest extends FunSuite{ + + test("nextProduct"){ + val test = new CartesianProduct[Int](Set(Set(1,2),Set(4,3))) + val result = List(1,2,3,4).map( _ => test.nextProduct()) + assert(result.contains(Set(1,4))) + } + + test("productWith"){ + val test = new CartesianProduct[Int](Set(Set(1,2),Set(4,3))) + val test2 = new CartesianProduct[Int](Set(Set(5,6), Set(7,8))) + val test3 = test.productWith(test2) + val result = for( i <- 1 to 16) yield test3.nextProduct() + assert(result.contains(Set(2,3,6,8))) + assert(result.toSet.size == 16) + println(result) + } + +} diff --git a/src/test/scala/IntegrationTest.scala b/src/test/scala/IntegrationTest.scala index 8b54900..5aa97f0 100644 --- a/src/test/scala/IntegrationTest.scala +++ b/src/test/scala/IntegrationTest.scala @@ -27,6 +27,11 @@ class IntegrationTest extends FunSuite { println(result.map(Main.prettyPrint(_))) } + test("GenericVar"){ + val result = FJTypeinference.typeinference("class List extends Object{\nA a;\n\nget(){ return this.a;\n\n}\n}\n\n\nclass Test extends Object{\nList test;\n\nm(a){\n return this.test.get();\n}\n\n}") + println(result.map(Main.prettyPrint(_))) + } + test("IdentCallExample"){ val result = FJTypeinference.typeinference("class Test extends Object{\n\n m(a,b){return this.m(a);\n}\nm(a){return a;}\n}") println(result.map(Main.prettyPrint(_))) @@ -37,4 +42,8 @@ class IntegrationTest extends FunSuite { println(result.map(Main.prettyPrint(_))) } + test("GetMethods"){ + val result = FJTypeinference.typeinference("class Test extends Object{\nget(){ return this.get().get();}\n}\n\nclass Test2 extends Object{\nget(){ return this;}\n}" ) + println(result.map(Main.prettyPrint(_))) + } } diff --git a/src/test/scala/UnifyTest.scala b/src/test/scala/UnifyTest.scala index 96070bf..5245024 100644 --- a/src/test/scala/UnifyTest.scala +++ b/src/test/scala/UnifyTest.scala @@ -1,5 +1,5 @@ -import hb.dhbw.{FiniteClosure, RefType, TypeVariable, Unify, UnifyEqualsDot, UnifyLessDot} +import hb.dhbw.{FiniteClosure, RefType, TypeVariable, Unify, UnifyEqualsDot, UnifyLessDot, UnifyRefType, UnifyTV} import org.scalatest.FunSuite class UnifyTest extends FunSuite { @@ -7,14 +7,10 @@ class UnifyTest extends FunSuite { val fcPair3 = (RefType("List", List(TypeVariable("A"))), RefType("Object", List())) val fcPair2 = (RefType("MyMap", List(TypeVariable("A"), TypeVariable("B"))), RefType("Map", List(RefType("String", List(TypeVariable("A"))), TypeVariable("B")))) - val fc = new FiniteClosure(Set(fcPair1, fcPair2, fcPair3)) + val fc = new FiniteClosure(Set()) - test("Unify.step2.alinkedb"){ - var step2 = Unify.step2(Set(UnifyLessDot(TypeVariable("c"), TypeVariable("b")), - UnifyLessDot(TypeVariable("a"), RefType("List", List(RefType("Object", List()))))), fc) - assert(step2.head.size == 2) - } + /* test("Unify.step2") { var step2 = Unify.step2(Set(UnifyLessDot(TypeVariable("a"), TypeVariable("b")), UnifyLessDot(TypeVariable("a"), RefType("List", List(RefType("Object", List()))))), fc) @@ -25,5 +21,10 @@ class UnifyTest extends FunSuite { UnifyLessDot(RefType("List", List(RefType("Object", List()))), TypeVariable("a"))), fc) println(step2) } - +*/ + test("Unify.applyRules.WrongEQSet"){ + val unifyRes = Unify.unify(Set(Set(Set(UnifyEqualsDot(UnifyRefType("List", List()),UnifyRefType("Object", List()))))), fc) + println(unifyRes) + assert(!unifyRes.isEmpty) + } }