Insert minimal set of generics. correctly output AST

This commit is contained in:
JanUlrich 2021-11-22 19:10:20 +01:00
parent 39c9adb794
commit fb7593cf3b
3 changed files with 49 additions and 29 deletions

View File

@ -2,7 +2,12 @@ package hb.dhbw
object InsertTypes {
def normalize(eq:Set[UnifyConstraint]) = {
/**
* Remove a <. b constraints
* @param eq
* @return
*/
private def normalize(eq:Set[UnifyConstraint]) = {
def substHelper(a: UnifyTV, withType: UnifyType,in: UnifyType) :UnifyType = in match {
case UnifyRefType(n, p) => UnifyRefType(n,p.map(t => substHelper(a, withType, t)))
case UnifyTV(n) =>
@ -24,7 +29,7 @@ object InsertTypes {
case _ => true
})
alessdotB.foreach(it => ret = subst(it.left.asInstanceOf[UnifyTV], it.right, ret))
ret
ret ++ alessdotB.map(cons => UnifyEqualsDot(cons.left, cons.right))
}
def insert(unifyResult: Set[Set[UnifyConstraint]], into: Class): Class = {
@ -48,24 +53,26 @@ object InsertTypes {
Class(into.name, into.genericParams, into.superType, into.fields, newMethods)
}
private def insert(constraints: Set[Constraint], into: Method): Method = {
def getAllGenericTypes(from: Constraint): Set[Type] = from match {
case EqualsDot(a,b) => getAllGenerics(a) ++ getAllGenerics(b)
case LessDot(a,b) => getAllGenerics(a) ++ getAllGenerics(b)
}
def getAllGenerics(from: Type): Set[Type] = from match {
case RefType(name, params) => params.flatMap(getAllGenerics(_)).toSet
case GenericType(a) => Set(GenericType(a))
private def getLinkedConstraints(linkedTypes: Set[Type], in: Set[Constraint]): Set[Constraint] ={
var typesWithoutBounds = linkedTypes
in.flatMap(_ match {
case LessDot(GenericType(a), RefType(name, params)) => {
if(linkedTypes.contains(GenericType(a))){
typesWithoutBounds = typesWithoutBounds - GenericType(a)
val genericsInParams = params.filter(_ match {
case GenericType(_) => true
case _ => false
}).toSet
getLinkedConstraints(genericsInParams, in) +LessDot(GenericType(a), RefType(name, params))
}else{
Set()
}
}
case _ => Set()
}
def getLinkedCons(linkedTypes: Set[Type], in: Set[Constraint]): Set[Constraint] ={
val linkedCons = in.filter(it => getAllGenericTypes(it).exists(linkedTypes.contains(_)))
val newLinkedTypes = linkedCons.flatMap(getAllGenericTypes(_))
if(newLinkedTypes.equals(linkedTypes))
linkedCons
else
getLinkedCons(newLinkedTypes, in)
}
}) ++ typesWithoutBounds.map(t => LessDot(t, RefType("Object", List())))
}
private def insert(constraints: Set[Constraint], into: Method): Method = {
def replaceTVWithGeneric(in: Type): Type = in match {
case TypeVariable(name) => GenericType(name)
case RefType(name, params) => RefType(name, params.map(replaceTVWithGeneric(_)))
@ -75,13 +82,16 @@ object InsertTypes {
case _ => null
}).find(_ != null).getOrElse(t)
def getAllGenerics(from: Type): Set[Type] = from match {
case RefType(name, params) => params.flatMap(getAllGenerics(_)).toSet
case GenericType(a) => Set(GenericType(a))
}
val genericRetType = substType(replaceTVWithGeneric(into.retType))
val genericParams = into.params.map(p => (substType(replaceTVWithGeneric(p._1)), p._2))
val constraintsForMethod = getLinkedCons(Set(genericRetType) ++ genericParams.map(_._1), constraints)
.filter(_ match {
case LessDot(GenericType(_), RefType(_,_)) => true
case _ => false
})
val genericsUsedInMethod = (Set(genericRetType) ++ genericParams.map(_._1)).flatMap(getAllGenerics(_))
val constraintsForMethod = getLinkedConstraints(genericsUsedInMethod, constraints)
Method(into.genericParams ++ constraintsForMethod, genericRetType, into.name, genericParams, into.retExpr)
}
private def replaceTVWithGeneric(in: UnifyConstraint, genericNames: Set[String]): Constraint= in match {

View File

@ -68,11 +68,22 @@ object Main {
def prettyPrintCons(constraint: Constraint)= constraint match{
case LessDot(l, r) => prettyPrintType(l) + " extends " + prettyPrintType(r)
}
def prettyPrintGenericList(constraints: List[Constraint]) = {
if(constraints.isEmpty){
""
}else{
"<" + constraints.map(prettyPrintCons(_)).mkString(", ") + ">"
}
}
ast.map(cl => {
"class " + cl.name + "{\n" +
"class " + cl.name + prettyPrintGenericList(cl.genericParams.map(it => LessDot(it._1, it._2))) +
" extends " + prettyPrintType(cl.superType) + "{\n" +
cl.fields.map(f => {
prettyPrintType(f._1) + " " + f._2 + ";"
}).mkString("\n") +
cl.methods.map(m => {
" "+m.genericParams.map(prettyPrintCons(_)).mkString(", ") + " " +
prettyPrintType(m.retType) +" "+ m.name +"(" + ") {\n"+
" "+ prettyPrintGenericList(m.genericParams) + " " +
prettyPrintType(m.retType) +" "+ m.name +"(" + m.params.map(tp=>prettyPrintType(tp._1) + " " + tp._2).mkString(", ") + ") {\n"+
" return " + prettyPrintExpr(m.retExpr) + ";\n" +
" }"
}).mkString("\n") + "\n}"

View File

@ -71,7 +71,6 @@ class IntegrationTest extends FunSuite {
test("list.add") {
val input = "class List<A extends Object> extends Object{\nA f;\n add( a){\n return new List(a);\n}\n}\nclass Test extends Object{\n\nm(a){return a.add(this);}\n}"
val result = FJTypeinference.typeinference(input )
println(result.map(it => Main.prettyPrint(it._1)))
println(result.map(it => Main.prettyPrintAST(it._2)))
}
}