Merge: compiler: improve `poset_from_mtypes` used for type coloring.
[nit.git] / src / compiler / separate_compiler.nit
index a51ef09..99a18d6 100644 (file)
@@ -456,14 +456,9 @@ class SeparateCompiler
                # Collect types to colorize
                var live_types = runtime_type_analysis.live_types
                var live_cast_types = runtime_type_analysis.live_cast_types
-               var mtypes = new HashSet[MType]
-               mtypes.add_all(live_types)
-               for c in self.box_kinds.keys do
-                       mtypes.add(c.mclass_type)
-               end
 
                # Compute colors
-               var poset = poset_from_mtypes(mtypes, live_cast_types)
+               var poset = poset_from_mtypes(live_types, live_cast_types)
                var colorer = new POSetColorer[MType]
                colorer.colorize(poset)
                type_ids = colorer.ids
@@ -471,20 +466,42 @@ class SeparateCompiler
                type_tables = build_type_tables(poset)
 
                # VT and FT are stored with other unresolved types in the big resolution_tables
-               self.compile_resolution_tables(mtypes)
+               self.compute_resolution_tables(live_types)
 
                return poset
        end
 
        private fun poset_from_mtypes(mtypes, cast_types: Set[MType]): POSet[MType] do
                var poset = new POSet[MType]
+
+               # Instead of comparing every pair of the full matrix mtypes X cast_types,
+               # the types are grouped by their base class so that only pairs
+               # whose base classes are related by inheritance are compared.
+
+               var mtypes_by_class = new MultiHashMap[MClass, MType]
                for e in mtypes do
+                       var c = e.as_notnullable.as(MClassType).mclass
+                       mtypes_by_class[c].add(e)
+                       poset.add_node(e)
+               end
+
+               var casttypes_by_class = new MultiHashMap[MClass, MType]
+               for e in cast_types do
+                       var c = e.as_notnullable.as(MClassType).mclass
+                       casttypes_by_class[c].add(e)
                        poset.add_node(e)
-                       for o in cast_types do
-                               if e == o then continue
-                               poset.add_node(o)
-                               if e.is_subtype(mainmodule, null, o) then
-                                       poset.add_edge(e, o)
+               end
+
+               for c1, ts1 in mtypes_by_class do
+                       for c2 in c1.in_hierarchy(mainmodule).greaters do
+                               var ts2 = casttypes_by_class[c2]
+                               for e in ts1 do
+                                       for o in ts2 do
+                                               if e == o then continue
+                                               if e.is_subtype(mainmodule, null, o) then
+                                                       poset.add_edge(e, o)
+                                               end
+                                       end
                                end
                        end
                end
@@ -510,9 +527,8 @@ class SeparateCompiler
                return tables
        end
 
-       protected fun compile_resolution_tables(mtypes: Set[MType]) do
-               # resolution_tables is used to perform a type resolution at runtime in O(1)
-
+       # resolution_tables is used to perform a type resolution at runtime in O(1)
+       private fun compute_resolution_tables(mtypes: Set[MType]) do
                # During the visit of the body of classes, live_unresolved_types are collected
                # and associated to
                # Collect all live_unresolved_types (visited in the body of classes)
@@ -973,6 +989,7 @@ class SeparateCompiler
                                v.add_decl("NULL,")
                        else
                                var s = "type_{t.c_name}"
+                               undead_types.add(t.mclass_type)
                                v.require_declaration(s)
                                v.add_decl("&{s},")
                        end