1aa8a9761SMichael Kruse //===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===// 2aa8a9761SMichael Kruse // 3aa8a9761SMichael Kruse // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4aa8a9761SMichael Kruse // See https://llvm.org/LICENSE.txt for license information. 5aa8a9761SMichael Kruse // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6aa8a9761SMichael Kruse // 7aa8a9761SMichael Kruse //===----------------------------------------------------------------------===// 8aa8a9761SMichael Kruse // 9aa8a9761SMichael Kruse // Make changes to isl's schedule tree data structure. 10aa8a9761SMichael Kruse // 11aa8a9761SMichael Kruse //===----------------------------------------------------------------------===// 12aa8a9761SMichael Kruse 13aa8a9761SMichael Kruse #include "polly/ScheduleTreeTransform.h" 1464489255SMichael Kruse #include "polly/Support/GICHelper.h" 15aa8a9761SMichael Kruse #include "polly/Support/ISLTools.h" 163f170eb1SMichael Kruse #include "polly/Support/ScopHelper.h" 17aa8a9761SMichael Kruse #include "llvm/ADT/ArrayRef.h" 183f170eb1SMichael Kruse #include "llvm/ADT/Sequence.h" 19aa8a9761SMichael Kruse #include "llvm/ADT/SmallVector.h" 203f170eb1SMichael Kruse #include "llvm/IR/Constants.h" 213f170eb1SMichael Kruse #include "llvm/IR/Metadata.h" 223f170eb1SMichael Kruse #include "llvm/Transforms/Utils/UnrollLoop.h" 23aa8a9761SMichael Kruse 24601d7eabSKarthika Devi C #include "polly/Support/PollyDebug.h" 2564489255SMichael Kruse #define DEBUG_TYPE "polly-opt-isl" 2664489255SMichael Kruse 27aa8a9761SMichael Kruse using namespace polly; 283f170eb1SMichael Kruse using namespace llvm; 29aa8a9761SMichael Kruse 30aa8a9761SMichael Kruse namespace { 3164489255SMichael Kruse 3264489255SMichael Kruse /// Copy the band member attributes (coincidence, loop type, isolate ast loop 3364489255SMichael Kruse /// type) from one band to another. 3464489255SMichael Kruse static isl::schedule_node_band 3564489255SMichael Kruse applyBandMemberAttributes(isl::schedule_node_band Target, int TargetIdx, 3664489255SMichael Kruse const isl::schedule_node_band &Source, 3764489255SMichael Kruse int SourceIdx) { 3864489255SMichael Kruse bool Coincident = Source.member_get_coincident(SourceIdx).release(); 3964489255SMichael Kruse Target = Target.member_set_coincident(TargetIdx, Coincident); 4064489255SMichael Kruse 4164489255SMichael Kruse isl_ast_loop_type LoopType = 4264489255SMichael Kruse isl_schedule_node_band_member_get_ast_loop_type(Source.get(), SourceIdx); 4364489255SMichael Kruse Target = isl::manage(isl_schedule_node_band_member_set_ast_loop_type( 4464489255SMichael Kruse Target.release(), TargetIdx, LoopType)) 4564489255SMichael Kruse .as<isl::schedule_node_band>(); 4664489255SMichael Kruse 4764489255SMichael Kruse isl_ast_loop_type IsolateType = 4864489255SMichael Kruse isl_schedule_node_band_member_get_isolate_ast_loop_type(Source.get(), 4964489255SMichael Kruse SourceIdx); 5064489255SMichael Kruse Target = isl::manage(isl_schedule_node_band_member_set_isolate_ast_loop_type( 5164489255SMichael Kruse Target.release(), TargetIdx, IsolateType)) 5264489255SMichael Kruse .as<isl::schedule_node_band>(); 5364489255SMichael Kruse 5464489255SMichael Kruse return Target; 5564489255SMichael Kruse } 5664489255SMichael Kruse 5764489255SMichael Kruse /// Create a new band by copying members from another @p Band. @p IncludeCb 5864489255SMichael Kruse /// decides which band indices are copied to the result. 5964489255SMichael Kruse template <typename CbTy> 6064489255SMichael Kruse static isl::schedule rebuildBand(isl::schedule_node_band OldBand, 6164489255SMichael Kruse isl::schedule Body, CbTy IncludeCb) { 6244596fe6SRiccardo Mori int NumBandDims = unsignedFromIslSize(OldBand.n_member()); 6364489255SMichael Kruse 6464489255SMichael Kruse bool ExcludeAny = false; 6564489255SMichael Kruse bool IncludeAny = false; 6664489255SMichael Kruse for (auto OldIdx : seq<int>(0, NumBandDims)) { 6764489255SMichael Kruse if (IncludeCb(OldIdx)) 6864489255SMichael Kruse IncludeAny = true; 6964489255SMichael Kruse else 7064489255SMichael Kruse ExcludeAny = true; 7164489255SMichael Kruse } 7264489255SMichael Kruse 7364489255SMichael Kruse // Instead of creating a zero-member band, don't create a band at all. 7464489255SMichael Kruse if (!IncludeAny) 7564489255SMichael Kruse return Body; 7664489255SMichael Kruse 7764489255SMichael Kruse isl::multi_union_pw_aff PartialSched = OldBand.get_partial_schedule(); 7864489255SMichael Kruse isl::multi_union_pw_aff NewPartialSched; 7964489255SMichael Kruse if (ExcludeAny) { 8064489255SMichael Kruse // Select the included partial scatter functions. 8164489255SMichael Kruse isl::union_pw_aff_list List = PartialSched.list(); 8264489255SMichael Kruse int NewIdx = 0; 8364489255SMichael Kruse for (auto OldIdx : seq<int>(0, NumBandDims)) { 8464489255SMichael Kruse if (IncludeCb(OldIdx)) 8564489255SMichael Kruse NewIdx += 1; 8664489255SMichael Kruse else 8764489255SMichael Kruse List = List.drop(NewIdx, 1); 8864489255SMichael Kruse } 8964489255SMichael Kruse isl::space ParamSpace = PartialSched.get_space().params(); 9064489255SMichael Kruse isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(NewIdx); 9164489255SMichael Kruse NewPartialSched = isl::multi_union_pw_aff(NewScatterSpace, List); 9264489255SMichael Kruse } else { 9364489255SMichael Kruse // Just reuse original scatter function of copying all of them. 9464489255SMichael Kruse NewPartialSched = PartialSched; 9564489255SMichael Kruse } 9664489255SMichael Kruse 9764489255SMichael Kruse // Create the new band node. 9864489255SMichael Kruse isl::schedule_node_band NewBand = 9964489255SMichael Kruse Body.insert_partial_schedule(NewPartialSched) 10064489255SMichael Kruse .get_root() 10164489255SMichael Kruse .child(0) 10264489255SMichael Kruse .as<isl::schedule_node_band>(); 10364489255SMichael Kruse 10464489255SMichael Kruse // If OldBand was permutable, so is the new one, even if some dimensions are 10564489255SMichael Kruse // missing. 10664489255SMichael Kruse bool IsPermutable = OldBand.permutable().release(); 10764489255SMichael Kruse NewBand = NewBand.set_permutable(IsPermutable); 10864489255SMichael Kruse 10964489255SMichael Kruse // Reapply member attributes. 11064489255SMichael Kruse int NewIdx = 0; 11164489255SMichael Kruse for (auto OldIdx : seq<int>(0, NumBandDims)) { 11264489255SMichael Kruse if (!IncludeCb(OldIdx)) 11364489255SMichael Kruse continue; 11464489255SMichael Kruse NewBand = 11564489255SMichael Kruse applyBandMemberAttributes(std::move(NewBand), NewIdx, OldBand, OldIdx); 11664489255SMichael Kruse NewIdx += 1; 11764489255SMichael Kruse } 11864489255SMichael Kruse 11964489255SMichael Kruse return NewBand.get_schedule(); 12064489255SMichael Kruse } 12164489255SMichael Kruse 122aa8a9761SMichael Kruse /// Rewrite a schedule tree by reconstructing it bottom-up. 123aa8a9761SMichael Kruse /// 124aa8a9761SMichael Kruse /// By default, the original schedule tree is reconstructed. To build a 125aa8a9761SMichael Kruse /// different tree, redefine visitor methods in a derived class (CRTP). 126aa8a9761SMichael Kruse /// 127aa8a9761SMichael Kruse /// Note that AST build options are not applied; Setting the isolate[] option 128aa8a9761SMichael Kruse /// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence, 129aa8a9761SMichael Kruse /// AST build options must be set after the tree has been constructed. 130aa8a9761SMichael Kruse template <typename Derived, typename... Args> 131aa8a9761SMichael Kruse struct ScheduleTreeRewriter 132bd93df93SMichael Kruse : RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> { 133aa8a9761SMichael Kruse Derived &getDerived() { return *static_cast<Derived *>(this); } 134aa8a9761SMichael Kruse const Derived &getDerived() const { 135aa8a9761SMichael Kruse return *static_cast<const Derived *>(this); 136aa8a9761SMichael Kruse } 137aa8a9761SMichael Kruse 138c62d9a5cSMichael Kruse isl::schedule visitDomain(isl::schedule_node_domain Node, Args... args) { 139aa8a9761SMichael Kruse // Every schedule_tree already has a domain node, no need to add one. 140aa8a9761SMichael Kruse return getDerived().visit(Node.first_child(), std::forward<Args>(args)...); 141aa8a9761SMichael Kruse } 142aa8a9761SMichael Kruse 143c62d9a5cSMichael Kruse isl::schedule visitBand(isl::schedule_node_band Band, Args... args) { 144aa8a9761SMichael Kruse isl::schedule NewChild = 145aa8a9761SMichael Kruse getDerived().visit(Band.child(0), std::forward<Args>(args)...); 14664489255SMichael Kruse return rebuildBand(Band, NewChild, [](int) { return true; }); 147aa8a9761SMichael Kruse } 148aa8a9761SMichael Kruse 149c62d9a5cSMichael Kruse isl::schedule visitSequence(isl::schedule_node_sequence Sequence, 150aa8a9761SMichael Kruse Args... args) { 151aa8a9761SMichael Kruse int NumChildren = isl_schedule_node_n_children(Sequence.get()); 152aa8a9761SMichael Kruse isl::schedule Result = 153aa8a9761SMichael Kruse getDerived().visit(Sequence.child(0), std::forward<Args>(args)...); 154aa8a9761SMichael Kruse for (int i = 1; i < NumChildren; i += 1) 155aa8a9761SMichael Kruse Result = Result.sequence( 156aa8a9761SMichael Kruse getDerived().visit(Sequence.child(i), std::forward<Args>(args)...)); 157aa8a9761SMichael Kruse return Result; 158aa8a9761SMichael Kruse } 159aa8a9761SMichael Kruse 160c62d9a5cSMichael Kruse isl::schedule visitSet(isl::schedule_node_set Set, Args... args) { 161aa8a9761SMichael Kruse int NumChildren = isl_schedule_node_n_children(Set.get()); 162aa8a9761SMichael Kruse isl::schedule Result = 163aa8a9761SMichael Kruse getDerived().visit(Set.child(0), std::forward<Args>(args)...); 164aa8a9761SMichael Kruse for (int i = 1; i < NumChildren; i += 1) 165aa8a9761SMichael Kruse Result = isl::manage( 166aa8a9761SMichael Kruse isl_schedule_set(Result.release(), 167aa8a9761SMichael Kruse getDerived() 168aa8a9761SMichael Kruse .visit(Set.child(i), std::forward<Args>(args)...) 169aa8a9761SMichael Kruse .release())); 170aa8a9761SMichael Kruse return Result; 171aa8a9761SMichael Kruse } 172aa8a9761SMichael Kruse 173c62d9a5cSMichael Kruse isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, Args... args) { 174aa8a9761SMichael Kruse return isl::schedule::from_domain(Leaf.get_domain()); 175aa8a9761SMichael Kruse } 176aa8a9761SMichael Kruse 177aa8a9761SMichael Kruse isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) { 178d3fdbda6SRiccardo Mori 179d3fdbda6SRiccardo Mori isl::id TheMark = Mark.as<isl::schedule_node_mark>().get_id(); 180aa8a9761SMichael Kruse isl::schedule_node NewChild = 181aa8a9761SMichael Kruse getDerived() 182aa8a9761SMichael Kruse .visit(Mark.first_child(), std::forward<Args>(args)...) 183aa8a9761SMichael Kruse .get_root() 184aa8a9761SMichael Kruse .first_child(); 185aa8a9761SMichael Kruse return NewChild.insert_mark(TheMark).get_schedule(); 186aa8a9761SMichael Kruse } 187aa8a9761SMichael Kruse 188c62d9a5cSMichael Kruse isl::schedule visitExtension(isl::schedule_node_extension Extension, 189aa8a9761SMichael Kruse Args... args) { 190d3fdbda6SRiccardo Mori isl::union_map TheExtension = 191d3fdbda6SRiccardo Mori Extension.as<isl::schedule_node_extension>().get_extension(); 192aa8a9761SMichael Kruse isl::schedule_node NewChild = getDerived() 193aa8a9761SMichael Kruse .visit(Extension.child(0), args...) 194aa8a9761SMichael Kruse .get_root() 195aa8a9761SMichael Kruse .first_child(); 196aa8a9761SMichael Kruse isl::schedule_node NewExtension = 197aa8a9761SMichael Kruse isl::schedule_node::from_extension(TheExtension); 198aa8a9761SMichael Kruse return NewChild.graft_before(NewExtension).get_schedule(); 199aa8a9761SMichael Kruse } 200aa8a9761SMichael Kruse 201c62d9a5cSMichael Kruse isl::schedule visitFilter(isl::schedule_node_filter Filter, Args... args) { 202d3fdbda6SRiccardo Mori isl::union_set FilterDomain = 203d3fdbda6SRiccardo Mori Filter.as<isl::schedule_node_filter>().get_filter(); 204aa8a9761SMichael Kruse isl::schedule NewSchedule = 205aa8a9761SMichael Kruse getDerived().visit(Filter.child(0), std::forward<Args>(args)...); 206aa8a9761SMichael Kruse return NewSchedule.intersect_domain(FilterDomain); 207aa8a9761SMichael Kruse } 208aa8a9761SMichael Kruse 209c62d9a5cSMichael Kruse isl::schedule visitNode(isl::schedule_node Node, Args... args) { 210aa8a9761SMichael Kruse llvm_unreachable("Not implemented"); 211aa8a9761SMichael Kruse } 212aa8a9761SMichael Kruse }; 213aa8a9761SMichael Kruse 21464489255SMichael Kruse /// Rewrite the schedule tree without any changes. Useful to copy a subtree into 21564489255SMichael Kruse /// a new schedule, discarding everything but. 216bd93df93SMichael Kruse struct IdentityRewriter : ScheduleTreeRewriter<IdentityRewriter> {}; 21764489255SMichael Kruse 218aa8a9761SMichael Kruse /// Rewrite a schedule tree to an equivalent one without extension nodes. 219aa8a9761SMichael Kruse /// 220aa8a9761SMichael Kruse /// Each visit method takes two additional arguments: 221aa8a9761SMichael Kruse /// 222aa8a9761SMichael Kruse /// * The new domain the node, which is the inherited domain plus any domains 223aa8a9761SMichael Kruse /// added by extension nodes. 224aa8a9761SMichael Kruse /// 225aa8a9761SMichael Kruse /// * A map of extension domains of all children is returned; it is required by 226aa8a9761SMichael Kruse /// band nodes to schedule the additional domains at the same position as the 227aa8a9761SMichael Kruse /// extension node would. 228aa8a9761SMichael Kruse /// 229bd93df93SMichael Kruse struct ExtensionNodeRewriter final 230bd93df93SMichael Kruse : ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &, 231aa8a9761SMichael Kruse isl::union_map &> { 232aa8a9761SMichael Kruse using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter, 233aa8a9761SMichael Kruse const isl::union_set &, isl::union_map &>; 234aa8a9761SMichael Kruse BaseTy &getBase() { return *this; } 235aa8a9761SMichael Kruse const BaseTy &getBase() const { return *this; } 236aa8a9761SMichael Kruse 237c62d9a5cSMichael Kruse isl::schedule visitSchedule(isl::schedule Schedule) { 238aa8a9761SMichael Kruse isl::union_map Extensions; 239aa8a9761SMichael Kruse isl::schedule Result = 240aa8a9761SMichael Kruse visit(Schedule.get_root(), Schedule.get_domain(), Extensions); 2417c7978a1Spatacca assert(!Extensions.is_null() && Extensions.is_empty()); 242aa8a9761SMichael Kruse return Result; 243aa8a9761SMichael Kruse } 244aa8a9761SMichael Kruse 245c62d9a5cSMichael Kruse isl::schedule visitSequence(isl::schedule_node_sequence Sequence, 246aa8a9761SMichael Kruse const isl::union_set &Domain, 247aa8a9761SMichael Kruse isl::union_map &Extensions) { 248aa8a9761SMichael Kruse int NumChildren = isl_schedule_node_n_children(Sequence.get()); 249aa8a9761SMichael Kruse isl::schedule NewNode = visit(Sequence.first_child(), Domain, Extensions); 250aa8a9761SMichael Kruse for (int i = 1; i < NumChildren; i += 1) { 251aa8a9761SMichael Kruse isl::schedule_node OldChild = Sequence.child(i); 252aa8a9761SMichael Kruse isl::union_map NewChildExtensions; 253aa8a9761SMichael Kruse isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions); 254aa8a9761SMichael Kruse NewNode = NewNode.sequence(NewChildNode); 255aa8a9761SMichael Kruse Extensions = Extensions.unite(NewChildExtensions); 256aa8a9761SMichael Kruse } 257aa8a9761SMichael Kruse return NewNode; 258aa8a9761SMichael Kruse } 259aa8a9761SMichael Kruse 260c62d9a5cSMichael Kruse isl::schedule visitSet(isl::schedule_node_set Set, 261aa8a9761SMichael Kruse const isl::union_set &Domain, 262aa8a9761SMichael Kruse isl::union_map &Extensions) { 263aa8a9761SMichael Kruse int NumChildren = isl_schedule_node_n_children(Set.get()); 264aa8a9761SMichael Kruse isl::schedule NewNode = visit(Set.first_child(), Domain, Extensions); 265aa8a9761SMichael Kruse for (int i = 1; i < NumChildren; i += 1) { 266aa8a9761SMichael Kruse isl::schedule_node OldChild = Set.child(i); 267aa8a9761SMichael Kruse isl::union_map NewChildExtensions; 268aa8a9761SMichael Kruse isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions); 269aa8a9761SMichael Kruse NewNode = isl::manage( 270aa8a9761SMichael Kruse isl_schedule_set(NewNode.release(), NewChildNode.release())); 271aa8a9761SMichael Kruse Extensions = Extensions.unite(NewChildExtensions); 272aa8a9761SMichael Kruse } 273aa8a9761SMichael Kruse return NewNode; 274aa8a9761SMichael Kruse } 275aa8a9761SMichael Kruse 276c62d9a5cSMichael Kruse isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, 277aa8a9761SMichael Kruse const isl::union_set &Domain, 278aa8a9761SMichael Kruse isl::union_map &Extensions) { 279bad3ebbaSRiccardo Mori Extensions = isl::union_map::empty(Leaf.ctx()); 280aa8a9761SMichael Kruse return isl::schedule::from_domain(Domain); 281aa8a9761SMichael Kruse } 282aa8a9761SMichael Kruse 283c62d9a5cSMichael Kruse isl::schedule visitBand(isl::schedule_node_band OldNode, 284aa8a9761SMichael Kruse const isl::union_set &Domain, 285aa8a9761SMichael Kruse isl::union_map &OuterExtensions) { 286aa8a9761SMichael Kruse isl::schedule_node OldChild = OldNode.first_child(); 287aa8a9761SMichael Kruse isl::multi_union_pw_aff PartialSched = 288aa8a9761SMichael Kruse isl::manage(isl_schedule_node_band_get_partial_schedule(OldNode.get())); 289aa8a9761SMichael Kruse 290aa8a9761SMichael Kruse isl::union_map NewChildExtensions; 291aa8a9761SMichael Kruse isl::schedule NewChild = visit(OldChild, Domain, NewChildExtensions); 292aa8a9761SMichael Kruse 293aa8a9761SMichael Kruse // Add the extensions to the partial schedule. 294bad3ebbaSRiccardo Mori OuterExtensions = isl::union_map::empty(NewChildExtensions.ctx()); 295aa8a9761SMichael Kruse isl::union_map NewPartialSchedMap = isl::union_map::from(PartialSched); 296aa8a9761SMichael Kruse unsigned BandDims = isl_schedule_node_band_n_member(OldNode.get()); 297aa8a9761SMichael Kruse for (isl::map Ext : NewChildExtensions.get_map_list()) { 29844596fe6SRiccardo Mori unsigned ExtDims = unsignedFromIslSize(Ext.domain_tuple_dim()); 299aa8a9761SMichael Kruse assert(ExtDims >= BandDims); 300aa8a9761SMichael Kruse unsigned OuterDims = ExtDims - BandDims; 301aa8a9761SMichael Kruse 302aa8a9761SMichael Kruse isl::map BandSched = 303aa8a9761SMichael Kruse Ext.project_out(isl::dim::in, 0, OuterDims).reverse(); 304aa8a9761SMichael Kruse NewPartialSchedMap = NewPartialSchedMap.unite(BandSched); 305aa8a9761SMichael Kruse 306aa8a9761SMichael Kruse // There might be more outer bands that have to schedule the extensions. 307aa8a9761SMichael Kruse if (OuterDims > 0) { 308aa8a9761SMichael Kruse isl::map OuterSched = 309aa8a9761SMichael Kruse Ext.project_out(isl::dim::in, OuterDims, BandDims); 310d5ee355fSRiccardo Mori OuterExtensions = OuterExtensions.unite(OuterSched); 311aa8a9761SMichael Kruse } 312aa8a9761SMichael Kruse } 313aa8a9761SMichael Kruse isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff = 314aa8a9761SMichael Kruse isl::multi_union_pw_aff::from_union_map(NewPartialSchedMap); 315aa8a9761SMichael Kruse isl::schedule_node NewNode = 316aa8a9761SMichael Kruse NewChild.insert_partial_schedule(NewPartialSchedAsAsMultiUnionPwAff) 317aa8a9761SMichael Kruse .get_root() 318d3fdbda6SRiccardo Mori .child(0); 319aa8a9761SMichael Kruse 320aa8a9761SMichael Kruse // Reapply permutability and coincidence attributes. 321aa8a9761SMichael Kruse NewNode = isl::manage(isl_schedule_node_band_set_permutable( 322aa8a9761SMichael Kruse NewNode.release(), 323aa8a9761SMichael Kruse isl_schedule_node_band_get_permutable(OldNode.get()))); 32464489255SMichael Kruse for (unsigned i = 0; i < BandDims; i += 1) 32564489255SMichael Kruse NewNode = applyBandMemberAttributes(NewNode.as<isl::schedule_node_band>(), 32664489255SMichael Kruse i, OldNode, i); 327aa8a9761SMichael Kruse 328aa8a9761SMichael Kruse return NewNode.get_schedule(); 329aa8a9761SMichael Kruse } 330aa8a9761SMichael Kruse 331c62d9a5cSMichael Kruse isl::schedule visitFilter(isl::schedule_node_filter Filter, 332aa8a9761SMichael Kruse const isl::union_set &Domain, 333aa8a9761SMichael Kruse isl::union_map &Extensions) { 334d3fdbda6SRiccardo Mori isl::union_set FilterDomain = 335d3fdbda6SRiccardo Mori Filter.as<isl::schedule_node_filter>().get_filter(); 336aa8a9761SMichael Kruse isl::union_set NewDomain = Domain.intersect(FilterDomain); 337aa8a9761SMichael Kruse 338aa8a9761SMichael Kruse // A filter is added implicitly if necessary when joining schedule trees. 339aa8a9761SMichael Kruse return visit(Filter.first_child(), NewDomain, Extensions); 340aa8a9761SMichael Kruse } 341aa8a9761SMichael Kruse 342c62d9a5cSMichael Kruse isl::schedule visitExtension(isl::schedule_node_extension Extension, 343aa8a9761SMichael Kruse const isl::union_set &Domain, 344aa8a9761SMichael Kruse isl::union_map &Extensions) { 345d3fdbda6SRiccardo Mori isl::union_map ExtDomain = 346d3fdbda6SRiccardo Mori Extension.as<isl::schedule_node_extension>().get_extension(); 347aa8a9761SMichael Kruse isl::union_set NewDomain = Domain.unite(ExtDomain.range()); 348aa8a9761SMichael Kruse isl::union_map ChildExtensions; 349aa8a9761SMichael Kruse isl::schedule NewChild = 350aa8a9761SMichael Kruse visit(Extension.first_child(), NewDomain, ChildExtensions); 351aa8a9761SMichael Kruse Extensions = ChildExtensions.unite(ExtDomain); 352aa8a9761SMichael Kruse return NewChild; 353aa8a9761SMichael Kruse } 354aa8a9761SMichael Kruse }; 355aa8a9761SMichael Kruse 356aa8a9761SMichael Kruse /// Collect all AST build options in any schedule tree band. 357aa8a9761SMichael Kruse /// 358aa8a9761SMichael Kruse /// ScheduleTreeRewriter cannot apply the schedule tree options. This class 359aa8a9761SMichael Kruse /// collects these options to apply them later. 360bd93df93SMichael Kruse struct CollectASTBuildOptions final 361bd93df93SMichael Kruse : RecursiveScheduleTreeVisitor<CollectASTBuildOptions> { 362aa8a9761SMichael Kruse using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>; 363aa8a9761SMichael Kruse BaseTy &getBase() { return *this; } 364aa8a9761SMichael Kruse const BaseTy &getBase() const { return *this; } 365aa8a9761SMichael Kruse 366aa8a9761SMichael Kruse llvm::SmallVector<isl::union_set, 8> ASTBuildOptions; 367aa8a9761SMichael Kruse 368c62d9a5cSMichael Kruse void visitBand(isl::schedule_node_band Band) { 369aa8a9761SMichael Kruse ASTBuildOptions.push_back( 370aa8a9761SMichael Kruse isl::manage(isl_schedule_node_band_get_ast_build_options(Band.get()))); 371aa8a9761SMichael Kruse return getBase().visitBand(Band); 372aa8a9761SMichael Kruse } 373aa8a9761SMichael Kruse }; 374aa8a9761SMichael Kruse 375aa8a9761SMichael Kruse /// Apply AST build options to the bands in a schedule tree. 376aa8a9761SMichael Kruse /// 377aa8a9761SMichael Kruse /// This rewrites a schedule tree with the AST build options applied. We assume 378aa8a9761SMichael Kruse /// that the band nodes are visited in the same order as they were when the 379aa8a9761SMichael Kruse /// build options were collected, typically by CollectASTBuildOptions. 380bd93df93SMichael Kruse struct ApplyASTBuildOptions final : ScheduleNodeRewriter<ApplyASTBuildOptions> { 381aa8a9761SMichael Kruse using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>; 382aa8a9761SMichael Kruse BaseTy &getBase() { return *this; } 383aa8a9761SMichael Kruse const BaseTy &getBase() const { return *this; } 384aa8a9761SMichael Kruse 385bd9e810bSMichael Kruse size_t Pos; 386aa8a9761SMichael Kruse llvm::ArrayRef<isl::union_set> ASTBuildOptions; 387aa8a9761SMichael Kruse 388aa8a9761SMichael Kruse ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions) 389aa8a9761SMichael Kruse : ASTBuildOptions(ASTBuildOptions) {} 390aa8a9761SMichael Kruse 391c62d9a5cSMichael Kruse isl::schedule visitSchedule(isl::schedule Schedule) { 392aa8a9761SMichael Kruse Pos = 0; 393aa8a9761SMichael Kruse isl::schedule Result = visit(Schedule).get_schedule(); 394aa8a9761SMichael Kruse assert(Pos == ASTBuildOptions.size() && 395aa8a9761SMichael Kruse "AST build options must match to band nodes"); 396aa8a9761SMichael Kruse return Result; 397aa8a9761SMichael Kruse } 398aa8a9761SMichael Kruse 399c62d9a5cSMichael Kruse isl::schedule_node visitBand(isl::schedule_node_band Band) { 400c62d9a5cSMichael Kruse isl::schedule_node_band Result = 401c62d9a5cSMichael Kruse Band.set_ast_build_options(ASTBuildOptions[Pos]); 402aa8a9761SMichael Kruse Pos += 1; 403aa8a9761SMichael Kruse return getBase().visitBand(Result); 404aa8a9761SMichael Kruse } 405aa8a9761SMichael Kruse }; 406aa8a9761SMichael Kruse 407aa8a9761SMichael Kruse /// Return whether the schedule contains an extension node. 408aa8a9761SMichael Kruse static bool containsExtensionNode(isl::schedule Schedule) { 409aa8a9761SMichael Kruse assert(!Schedule.is_null()); 410aa8a9761SMichael Kruse 411aa8a9761SMichael Kruse auto Callback = [](__isl_keep isl_schedule_node *Node, 412aa8a9761SMichael Kruse void *User) -> isl_bool { 413aa8a9761SMichael Kruse if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) { 414aa8a9761SMichael Kruse // Stop walking the schedule tree. 415aa8a9761SMichael Kruse return isl_bool_error; 416aa8a9761SMichael Kruse } 417aa8a9761SMichael Kruse 418aa8a9761SMichael Kruse // Continue searching the subtree. 419aa8a9761SMichael Kruse return isl_bool_true; 420aa8a9761SMichael Kruse }; 421aa8a9761SMichael Kruse isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down( 422aa8a9761SMichael Kruse Schedule.get(), Callback, nullptr); 423aa8a9761SMichael Kruse 424aa8a9761SMichael Kruse // We assume that the traversal itself does not fail, i.e. the only reason to 425aa8a9761SMichael Kruse // return isl_stat_error is that an extension node was found. 426aa8a9761SMichael Kruse return RetVal == isl_stat_error; 427aa8a9761SMichael Kruse } 428aa8a9761SMichael Kruse 4293f170eb1SMichael Kruse /// Find a named MDNode property in a LoopID. 4303f170eb1SMichael Kruse static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) { 4313f170eb1SMichael Kruse return dyn_cast_or_null<MDNode>( 43230c67587SKazu Hirata findMetadataOperand(LoopMD, Name).value_or(nullptr)); 4333f170eb1SMichael Kruse } 4343f170eb1SMichael Kruse 4353f170eb1SMichael Kruse /// Is this node of type mark? 4363f170eb1SMichael Kruse static bool isMark(const isl::schedule_node &Node) { 4373f170eb1SMichael Kruse return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_mark; 4383f170eb1SMichael Kruse } 4393f170eb1SMichael Kruse 44030df6d5dSDavid Blaikie /// Is this node of type band? 44130df6d5dSDavid Blaikie static bool isBand(const isl::schedule_node &Node) { 44230df6d5dSDavid Blaikie return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band; 44330df6d5dSDavid Blaikie } 44430df6d5dSDavid Blaikie 445e470f926SMichael Kruse #ifndef NDEBUG 4463f170eb1SMichael Kruse /// Is this node a band of a single dimension (i.e. could represent a loop)? 4473f170eb1SMichael Kruse static bool isBandWithSingleLoop(const isl::schedule_node &Node) { 4483f170eb1SMichael Kruse return isBand(Node) && isl_schedule_node_band_n_member(Node.get()) == 1; 4493f170eb1SMichael Kruse } 45030df6d5dSDavid Blaikie #endif 4513f170eb1SMichael Kruse 452e470f926SMichael Kruse static bool isLeaf(const isl::schedule_node &Node) { 453e470f926SMichael Kruse return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_leaf; 454e470f926SMichael Kruse } 455e470f926SMichael Kruse 4563f170eb1SMichael Kruse /// Create an isl::id representing the output loop after a transformation. 4573f170eb1SMichael Kruse static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) { 4583f170eb1SMichael Kruse // Don't need to id the followup. 4593f170eb1SMichael Kruse // TODO: Append llvm.loop.disable_heustistics metadata unless overridden by 4603f170eb1SMichael Kruse // user followup-MD 4613f170eb1SMichael Kruse if (!FollowupLoopMD) 4623f170eb1SMichael Kruse return {}; 4633f170eb1SMichael Kruse 4643f170eb1SMichael Kruse BandAttr *Attr = new BandAttr(); 4653f170eb1SMichael Kruse Attr->Metadata = FollowupLoopMD; 4663f170eb1SMichael Kruse return getIslLoopAttr(Ctx, Attr); 4673f170eb1SMichael Kruse } 4683f170eb1SMichael Kruse 4693f170eb1SMichael Kruse /// A loop consists of a band and an optional marker that wraps it. Return the 4703f170eb1SMichael Kruse /// outermost of the two. 4713f170eb1SMichael Kruse 4723f170eb1SMichael Kruse /// That is, either the mark or, if there is not mark, the loop itself. Can 4733f170eb1SMichael Kruse /// start with either the mark or the band. 4743f170eb1SMichael Kruse static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) { 4753f170eb1SMichael Kruse if (isBandMark(BandOrMark)) { 476d3fdbda6SRiccardo Mori assert(isBandWithSingleLoop(BandOrMark.child(0))); 4773f170eb1SMichael Kruse return BandOrMark; 4783f170eb1SMichael Kruse } 4793f170eb1SMichael Kruse assert(isBandWithSingleLoop(BandOrMark)); 4803f170eb1SMichael Kruse 4813f170eb1SMichael Kruse isl::schedule_node Mark = BandOrMark.parent(); 4823f170eb1SMichael Kruse if (isBandMark(Mark)) 4833f170eb1SMichael Kruse return Mark; 4843f170eb1SMichael Kruse 4853f170eb1SMichael Kruse // Band has no loop marker. 4863f170eb1SMichael Kruse return BandOrMark; 4873f170eb1SMichael Kruse } 4883f170eb1SMichael Kruse 4893f170eb1SMichael Kruse static isl::schedule_node removeMark(isl::schedule_node MarkOrBand, 4903f170eb1SMichael Kruse BandAttr *&Attr) { 4913f170eb1SMichael Kruse MarkOrBand = moveToBandMark(MarkOrBand); 4923f170eb1SMichael Kruse 4933f170eb1SMichael Kruse isl::schedule_node Band; 4943f170eb1SMichael Kruse if (isMark(MarkOrBand)) { 495d3fdbda6SRiccardo Mori Attr = getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id()); 4963f170eb1SMichael Kruse Band = isl::manage(isl_schedule_node_delete(MarkOrBand.release())); 4973f170eb1SMichael Kruse } else { 4983f170eb1SMichael Kruse Attr = nullptr; 4993f170eb1SMichael Kruse Band = MarkOrBand; 5003f170eb1SMichael Kruse } 5013f170eb1SMichael Kruse 5023f170eb1SMichael Kruse assert(isBandWithSingleLoop(Band)); 5033f170eb1SMichael Kruse return Band; 5043f170eb1SMichael Kruse } 5053f170eb1SMichael Kruse 5063f170eb1SMichael Kruse /// Remove the mark that wraps a loop. Return the band representing the loop. 5073f170eb1SMichael Kruse static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) { 5083f170eb1SMichael Kruse BandAttr *Attr; 5093f170eb1SMichael Kruse return removeMark(MarkOrBand, Attr); 5103f170eb1SMichael Kruse } 5113f170eb1SMichael Kruse 5123f170eb1SMichael Kruse static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) { 5133f170eb1SMichael Kruse assert(isBand(Band)); 5143f170eb1SMichael Kruse assert(moveToBandMark(Band).is_equal(Band) && 5153f170eb1SMichael Kruse "Don't add a two marks for a band"); 5163f170eb1SMichael Kruse 517d3fdbda6SRiccardo Mori return Band.insert_mark(Mark).child(0); 5183f170eb1SMichael Kruse } 5193f170eb1SMichael Kruse 5203f170eb1SMichael Kruse /// Return the (one-dimensional) set of numbers that are divisible by @p Factor 5213f170eb1SMichael Kruse /// with remainder @p Offset. 5223f170eb1SMichael Kruse /// 5233f170eb1SMichael Kruse /// isDivisibleBySet(Ctx, 4, 0) = { [i] : floord(i,4) = 0 } 5243f170eb1SMichael Kruse /// isDivisibleBySet(Ctx, 4, 1) = { [i] : floord(i,4) = 1 } 5253f170eb1SMichael Kruse /// 5263f170eb1SMichael Kruse static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor, 5273f170eb1SMichael Kruse long Offset) { 5283f170eb1SMichael Kruse isl::val ValFactor{Ctx, Factor}; 5293f170eb1SMichael Kruse isl::val ValOffset{Ctx, Offset}; 5303f170eb1SMichael Kruse 5313f170eb1SMichael Kruse isl::space Unispace{Ctx, 0, 1}; 5323f170eb1SMichael Kruse isl::local_space LUnispace{Unispace}; 5333f170eb1SMichael Kruse isl::aff AffFactor{LUnispace, ValFactor}; 5343f170eb1SMichael Kruse isl::aff AffOffset{LUnispace, ValOffset}; 5353f170eb1SMichael Kruse 5363f170eb1SMichael Kruse isl::aff Id = isl::aff::var_on_domain(LUnispace, isl::dim::out, 0); 5373f170eb1SMichael Kruse isl::aff DivMul = Id.mod(ValFactor); 5383f170eb1SMichael Kruse isl::basic_map Divisible = isl::basic_map::from_aff(DivMul); 5393f170eb1SMichael Kruse isl::basic_map Modulo = Divisible.fix_val(isl::dim::out, 0, ValOffset); 5403f170eb1SMichael Kruse return Modulo.domain(); 5413f170eb1SMichael Kruse } 5423f170eb1SMichael Kruse 543d123e983SMichael Kruse /// Make the last dimension of Set to take values from 0 to VectorWidth - 1. 544d123e983SMichael Kruse /// 545d123e983SMichael Kruse /// @param Set A set, which should be modified. 546d123e983SMichael Kruse /// @param VectorWidth A parameter, which determines the constraint. 547d123e983SMichael Kruse static isl::set addExtentConstraints(isl::set Set, int VectorWidth) { 54844596fe6SRiccardo Mori unsigned Dims = unsignedFromIslSize(Set.tuple_dim()); 54944596fe6SRiccardo Mori assert(Dims >= 1); 550d123e983SMichael Kruse isl::space Space = Set.get_space(); 551d123e983SMichael Kruse isl::local_space LocalSpace = isl::local_space(Space); 552d123e983SMichael Kruse isl::constraint ExtConstr = isl::constraint::alloc_inequality(LocalSpace); 553d123e983SMichael Kruse ExtConstr = ExtConstr.set_constant_si(0); 554d123e983SMichael Kruse ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, 1); 555d123e983SMichael Kruse Set = Set.add_constraint(ExtConstr); 556d123e983SMichael Kruse ExtConstr = isl::constraint::alloc_inequality(LocalSpace); 557d123e983SMichael Kruse ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1); 558d123e983SMichael Kruse ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, -1); 559d123e983SMichael Kruse return Set.add_constraint(ExtConstr); 560d123e983SMichael Kruse } 56164489255SMichael Kruse 56264489255SMichael Kruse /// Collapse perfectly nested bands into a single band. 563bd93df93SMichael Kruse class BandCollapseRewriter final 564bd93df93SMichael Kruse : public ScheduleTreeRewriter<BandCollapseRewriter> { 56564489255SMichael Kruse private: 56664489255SMichael Kruse using BaseTy = ScheduleTreeRewriter<BandCollapseRewriter>; 56764489255SMichael Kruse BaseTy &getBase() { return *this; } 56864489255SMichael Kruse const BaseTy &getBase() const { return *this; } 56964489255SMichael Kruse 57064489255SMichael Kruse public: 57164489255SMichael Kruse isl::schedule visitBand(isl::schedule_node_band RootBand) { 57264489255SMichael Kruse isl::schedule_node_band Band = RootBand; 57364489255SMichael Kruse isl::ctx Ctx = Band.ctx(); 57464489255SMichael Kruse 575*5aafc6d5SChristian Clauss // Do not merge permutable band to avoid losing the permutability property. 57664489255SMichael Kruse // Cannot collapse even two permutable loops, they might be permutable 577ea540bc2SGabriel Ravier // individually, but not necassarily across. 57844596fe6SRiccardo Mori if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable()) 57964489255SMichael Kruse return getBase().visitBand(Band); 58064489255SMichael Kruse 581*5aafc6d5SChristian Clauss // Find collapsible bands. 58264489255SMichael Kruse SmallVector<isl::schedule_node_band> Nest; 58364489255SMichael Kruse int NumTotalLoops = 0; 58464489255SMichael Kruse isl::schedule_node Body; 58564489255SMichael Kruse while (true) { 58664489255SMichael Kruse Nest.push_back(Band); 58744596fe6SRiccardo Mori NumTotalLoops += unsignedFromIslSize(Band.n_member()); 58864489255SMichael Kruse Body = Band.first_child(); 58964489255SMichael Kruse if (!Body.isa<isl::schedule_node_band>()) 59064489255SMichael Kruse break; 59164489255SMichael Kruse Band = Body.as<isl::schedule_node_band>(); 59264489255SMichael Kruse 59364489255SMichael Kruse // Do not include next band if it is permutable to not lose its 59464489255SMichael Kruse // permutability property. 59544596fe6SRiccardo Mori if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable()) 59664489255SMichael Kruse break; 59764489255SMichael Kruse } 59864489255SMichael Kruse 59964489255SMichael Kruse // Nothing to collapse, preserve permutability. 60064489255SMichael Kruse if (Nest.size() <= 1) 60164489255SMichael Kruse return getBase().visitBand(Band); 60264489255SMichael Kruse 603601d7eabSKarthika Devi C POLLY_DEBUG({ 60464489255SMichael Kruse dbgs() << "Found loops to collapse between\n"; 60564489255SMichael Kruse dumpIslObj(RootBand, dbgs()); 60664489255SMichael Kruse dbgs() << "and\n"; 60764489255SMichael Kruse dumpIslObj(Body, dbgs()); 60864489255SMichael Kruse dbgs() << "\n"; 60964489255SMichael Kruse }); 61064489255SMichael Kruse 61164489255SMichael Kruse isl::schedule NewBody = visit(Body); 61264489255SMichael Kruse 61364489255SMichael Kruse // Collect partial schedules from all members. 61464489255SMichael Kruse isl::union_pw_aff_list PartScheds{Ctx, NumTotalLoops}; 61564489255SMichael Kruse for (isl::schedule_node_band Band : Nest) { 61644596fe6SRiccardo Mori int NumLoops = unsignedFromIslSize(Band.n_member()); 61764489255SMichael Kruse isl::multi_union_pw_aff BandScheds = Band.get_partial_schedule(); 61864489255SMichael Kruse for (auto j : seq<int>(0, NumLoops)) 61964489255SMichael Kruse PartScheds = PartScheds.add(BandScheds.at(j)); 62064489255SMichael Kruse } 62164489255SMichael Kruse isl::space ScatterSpace = isl::space(Ctx, 0, NumTotalLoops); 62264489255SMichael Kruse isl::multi_union_pw_aff PartSchedsMulti{ScatterSpace, PartScheds}; 62364489255SMichael Kruse 62464489255SMichael Kruse isl::schedule_node_band CollapsedBand = 62564489255SMichael Kruse NewBody.insert_partial_schedule(PartSchedsMulti) 62664489255SMichael Kruse .get_root() 62764489255SMichael Kruse .first_child() 62864489255SMichael Kruse .as<isl::schedule_node_band>(); 62964489255SMichael Kruse 63064489255SMichael Kruse // Copy over loop attributes form original bands. 63164489255SMichael Kruse int LoopIdx = 0; 63264489255SMichael Kruse for (isl::schedule_node_band Band : Nest) { 63344596fe6SRiccardo Mori int NumLoops = unsignedFromIslSize(Band.n_member()); 63464489255SMichael Kruse for (int i : seq<int>(0, NumLoops)) { 63564489255SMichael Kruse CollapsedBand = applyBandMemberAttributes(std::move(CollapsedBand), 63664489255SMichael Kruse LoopIdx, Band, i); 63764489255SMichael Kruse LoopIdx += 1; 63864489255SMichael Kruse } 63964489255SMichael Kruse } 64064489255SMichael Kruse assert(LoopIdx == NumTotalLoops && 64164489255SMichael Kruse "Expect the same number of loops to add up again"); 64264489255SMichael Kruse 64364489255SMichael Kruse return CollapsedBand.get_schedule(); 64464489255SMichael Kruse } 64564489255SMichael Kruse }; 64664489255SMichael Kruse 64764489255SMichael Kruse static isl::schedule collapseBands(isl::schedule Sched) { 648601d7eabSKarthika Devi C POLLY_DEBUG(dbgs() << "Collapse bands in schedule\n"); 64964489255SMichael Kruse BandCollapseRewriter Rewriter; 65064489255SMichael Kruse return Rewriter.visit(Sched); 65164489255SMichael Kruse } 65264489255SMichael Kruse 65364489255SMichael Kruse /// Collect sequentially executed bands (or anything else), even if nested in a 65464489255SMichael Kruse /// mark or other nodes whose child is executed just once. If we can 65564489255SMichael Kruse /// successfully fuse the bands, we allow them to be removed. 65664489255SMichael Kruse static void collectPotentiallyFusableBands( 65764489255SMichael Kruse isl::schedule_node Node, 65864489255SMichael Kruse SmallVectorImpl<std::pair<isl::schedule_node, isl::schedule_node>> 65964489255SMichael Kruse &ScheduleBands, 66064489255SMichael Kruse const isl::schedule_node &DirectChild) { 66164489255SMichael Kruse switch (isl_schedule_node_get_type(Node.get())) { 66264489255SMichael Kruse case isl_schedule_node_sequence: 66364489255SMichael Kruse case isl_schedule_node_set: 66464489255SMichael Kruse case isl_schedule_node_mark: 66564489255SMichael Kruse case isl_schedule_node_domain: 66664489255SMichael Kruse case isl_schedule_node_filter: 66764489255SMichael Kruse if (Node.has_children()) { 66864489255SMichael Kruse isl::schedule_node C = Node.first_child(); 66964489255SMichael Kruse while (true) { 67064489255SMichael Kruse collectPotentiallyFusableBands(C, ScheduleBands, DirectChild); 67164489255SMichael Kruse if (!C.has_next_sibling()) 67264489255SMichael Kruse break; 67364489255SMichael Kruse C = C.next_sibling(); 67464489255SMichael Kruse } 67564489255SMichael Kruse } 67664489255SMichael Kruse break; 67764489255SMichael Kruse 67864489255SMichael Kruse default: 67964489255SMichael Kruse // Something that does not execute suquentially (e.g. a band) 68064489255SMichael Kruse ScheduleBands.push_back({Node, DirectChild}); 68164489255SMichael Kruse break; 68264489255SMichael Kruse } 68364489255SMichael Kruse } 68464489255SMichael Kruse 68564489255SMichael Kruse /// Remove dependencies that are resolved by @p PartSched. That is, remove 68664489255SMichael Kruse /// everything that we already know is executed in-order. 68764489255SMichael Kruse static isl::union_map remainingDepsFromPartialSchedule(isl::union_map PartSched, 68864489255SMichael Kruse isl::union_map Deps) { 68944596fe6SRiccardo Mori unsigned NumDims = getNumScatterDims(PartSched); 69064489255SMichael Kruse auto ParamSpace = PartSched.get_space().params(); 69164489255SMichael Kruse 69264489255SMichael Kruse // { Scatter[] } 69364489255SMichael Kruse isl::space ScatterSpace = 69464489255SMichael Kruse ParamSpace.set_from_params().add_dims(isl::dim::set, NumDims); 69564489255SMichael Kruse 69664489255SMichael Kruse // { Scatter[] -> Domain[] } 69764489255SMichael Kruse isl::union_map PartSchedRev = PartSched.reverse(); 69864489255SMichael Kruse 69964489255SMichael Kruse // { Scatter[] -> Scatter[] } 70064489255SMichael Kruse isl::map MaybeBefore = isl::map::lex_le(ScatterSpace); 70164489255SMichael Kruse 70264489255SMichael Kruse // { Domain[] -> Domain[] } 70364489255SMichael Kruse isl::union_map DomMaybeBefore = 70464489255SMichael Kruse MaybeBefore.apply_domain(PartSchedRev).apply_range(PartSchedRev); 70564489255SMichael Kruse 70664489255SMichael Kruse // { Domain[] -> Domain[] } 70764489255SMichael Kruse isl::union_map ChildRemainingDeps = Deps.intersect(DomMaybeBefore); 70864489255SMichael Kruse 70964489255SMichael Kruse return ChildRemainingDeps; 71064489255SMichael Kruse } 71164489255SMichael Kruse 71264489255SMichael Kruse /// Remove dependencies that are resolved by executing them in the order 71364489255SMichael Kruse /// specified by @p Domains; 71464489255SMichael Kruse static isl::union_map remainigDepsFromSequence(ArrayRef<isl::union_set> Domains, 71564489255SMichael Kruse isl::union_map Deps) { 71664489255SMichael Kruse isl::ctx Ctx = Deps.ctx(); 71764489255SMichael Kruse isl::space ParamSpace = Deps.get_space().params(); 71864489255SMichael Kruse 71964489255SMichael Kruse // Create a partial schedule mapping to constants that reflect the execution 72064489255SMichael Kruse // order. 72164489255SMichael Kruse isl::union_map PartialSchedules = isl::union_map::empty(Ctx); 72264489255SMichael Kruse for (auto P : enumerate(Domains)) { 72364489255SMichael Kruse isl::val ExecTime = isl::val(Ctx, P.index()); 72464489255SMichael Kruse isl::union_pw_aff DomSched{P.value(), ExecTime}; 72564489255SMichael Kruse PartialSchedules = PartialSchedules.unite(DomSched.as_union_map()); 72664489255SMichael Kruse } 72764489255SMichael Kruse 72864489255SMichael Kruse return remainingDepsFromPartialSchedule(PartialSchedules, Deps); 72964489255SMichael Kruse } 73064489255SMichael Kruse 73164489255SMichael Kruse /// Determine whether the outermost loop of to bands can be fused while 73264489255SMichael Kruse /// respecting validity dependencies. 73364489255SMichael Kruse static bool canFuseOutermost(const isl::schedule_node_band &LHS, 73464489255SMichael Kruse const isl::schedule_node_band &RHS, 73564489255SMichael Kruse const isl::union_map &Deps) { 73664489255SMichael Kruse // { LHSDomain[] -> Scatter[] } 73764489255SMichael Kruse isl::union_map LHSPartSched = 73864489255SMichael Kruse LHS.get_partial_schedule().get_at(0).as_union_map(); 73964489255SMichael Kruse 74064489255SMichael Kruse // { Domain[] -> Scatter[] } 74164489255SMichael Kruse isl::union_map RHSPartSched = 74264489255SMichael Kruse RHS.get_partial_schedule().get_at(0).as_union_map(); 74364489255SMichael Kruse 74464489255SMichael Kruse // Dependencies that are already resolved because LHS executes before RHS, but 74564489255SMichael Kruse // will not be anymore after fusion. { DefDomain[] -> UseDomain[] } 74664489255SMichael Kruse isl::union_map OrderedBySequence = 74764489255SMichael Kruse Deps.intersect_domain(LHSPartSched.domain()) 74864489255SMichael Kruse .intersect_range(RHSPartSched.domain()); 74964489255SMichael Kruse 75064489255SMichael Kruse isl::space ParamSpace = OrderedBySequence.get_space().params(); 75164489255SMichael Kruse isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(1); 75264489255SMichael Kruse 75364489255SMichael Kruse // { Scatter[] -> Scatter[] } 75464489255SMichael Kruse isl::map After = isl::map::lex_gt(NewScatterSpace); 75564489255SMichael Kruse 75664489255SMichael Kruse // After fusion, instances with smaller (or equal, which means they will be 75764489255SMichael Kruse // executed in the same iteration, but the LHS instance is still sequenced 75864489255SMichael Kruse // before RHS) scatter value will still be executed before. This are the 75964489255SMichael Kruse // orderings where this is not necessarily the case. 76064489255SMichael Kruse // { LHSDomain[] -> RHSDomain[] } 76164489255SMichael Kruse isl::union_map MightBeAfterDoms = After.apply_domain(LHSPartSched.reverse()) 76264489255SMichael Kruse .apply_range(RHSPartSched.reverse()); 76364489255SMichael Kruse 76464489255SMichael Kruse // Dependencies that are not resolved by the new execution order. 76564489255SMichael Kruse isl::union_map WithBefore = OrderedBySequence.intersect(MightBeAfterDoms); 76664489255SMichael Kruse 76764489255SMichael Kruse return WithBefore.is_empty(); 76864489255SMichael Kruse } 76964489255SMichael Kruse 77064489255SMichael Kruse /// Fuse @p LHS and @p RHS if possible while preserving validity dependenvies. 77164489255SMichael Kruse static isl::schedule tryGreedyFuse(isl::schedule_node_band LHS, 77264489255SMichael Kruse isl::schedule_node_band RHS, 77364489255SMichael Kruse const isl::union_map &Deps) { 77464489255SMichael Kruse if (!canFuseOutermost(LHS, RHS, Deps)) 77564489255SMichael Kruse return {}; 77664489255SMichael Kruse 777601d7eabSKarthika Devi C POLLY_DEBUG({ 77864489255SMichael Kruse dbgs() << "Found loops for greedy fusion:\n"; 77964489255SMichael Kruse dumpIslObj(LHS, dbgs()); 78064489255SMichael Kruse dbgs() << "and\n"; 78164489255SMichael Kruse dumpIslObj(RHS, dbgs()); 78264489255SMichael Kruse dbgs() << "\n"; 78364489255SMichael Kruse }); 78464489255SMichael Kruse 78564489255SMichael Kruse // The partial schedule of the bands outermost loop that we need to combine 78664489255SMichael Kruse // for the fusion. 78764489255SMichael Kruse isl::union_pw_aff LHSPartOuterSched = LHS.get_partial_schedule().get_at(0); 78864489255SMichael Kruse isl::union_pw_aff RHSPartOuterSched = RHS.get_partial_schedule().get_at(0); 78964489255SMichael Kruse 79064489255SMichael Kruse // Isolate band bodies as roots of their own schedule trees. 79164489255SMichael Kruse IdentityRewriter Rewriter; 79264489255SMichael Kruse isl::schedule LHSBody = Rewriter.visit(LHS.first_child()); 79364489255SMichael Kruse isl::schedule RHSBody = Rewriter.visit(RHS.first_child()); 79464489255SMichael Kruse 79564489255SMichael Kruse // Reconstruct the non-outermost (not going to be fused) loops from both 79664489255SMichael Kruse // bands. 79764489255SMichael Kruse // TODO: Maybe it is possibly to transfer the 'permutability' property from 79864489255SMichael Kruse // LHS+RHS. At minimum we need merge multiple band members at once, otherwise 79964489255SMichael Kruse // permutability has no meaning. 80064489255SMichael Kruse isl::schedule LHSNewBody = 80164489255SMichael Kruse rebuildBand(LHS, LHSBody, [](int i) { return i > 0; }); 80264489255SMichael Kruse isl::schedule RHSNewBody = 80364489255SMichael Kruse rebuildBand(RHS, RHSBody, [](int i) { return i > 0; }); 80464489255SMichael Kruse 80564489255SMichael Kruse // The loop body of the fused loop. 80664489255SMichael Kruse isl::schedule NewCommonBody = LHSNewBody.sequence(RHSNewBody); 80764489255SMichael Kruse 80864489255SMichael Kruse // Combine the partial schedules of both loops to a new one. Instances with 80964489255SMichael Kruse // the same scatter value are put together. 81064489255SMichael Kruse isl::union_map NewCommonPartialSched = 81164489255SMichael Kruse LHSPartOuterSched.as_union_map().unite(RHSPartOuterSched.as_union_map()); 81264489255SMichael Kruse isl::schedule NewCommonSchedule = NewCommonBody.insert_partial_schedule( 81364489255SMichael Kruse NewCommonPartialSched.as_multi_union_pw_aff()); 81464489255SMichael Kruse 81564489255SMichael Kruse return NewCommonSchedule; 81664489255SMichael Kruse } 81764489255SMichael Kruse 81864489255SMichael Kruse static isl::schedule tryGreedyFuse(isl::schedule_node LHS, 81964489255SMichael Kruse isl::schedule_node RHS, 82064489255SMichael Kruse const isl::union_map &Deps) { 82164489255SMichael Kruse // TODO: Non-bands could be interpreted as a band with just as single 82264489255SMichael Kruse // iteration. However, this is only useful if both ends of a fused loop were 82364489255SMichael Kruse // originally loops themselves. 82464489255SMichael Kruse if (!LHS.isa<isl::schedule_node_band>()) 82564489255SMichael Kruse return {}; 82664489255SMichael Kruse if (!RHS.isa<isl::schedule_node_band>()) 82764489255SMichael Kruse return {}; 82864489255SMichael Kruse 82964489255SMichael Kruse return tryGreedyFuse(LHS.as<isl::schedule_node_band>(), 83064489255SMichael Kruse RHS.as<isl::schedule_node_band>(), Deps); 83164489255SMichael Kruse } 83264489255SMichael Kruse 83364489255SMichael Kruse /// Fuse all fusable loop top-down in a schedule tree. 83464489255SMichael Kruse /// 83564489255SMichael Kruse /// The isl::union_map parameters is the set of validity dependencies that have 83664489255SMichael Kruse /// not been resolved/carried by a parent schedule node. 837bd93df93SMichael Kruse class GreedyFusionRewriter final 83864489255SMichael Kruse : public ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map> { 83964489255SMichael Kruse private: 84064489255SMichael Kruse using BaseTy = ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map>; 84164489255SMichael Kruse BaseTy &getBase() { return *this; } 84264489255SMichael Kruse const BaseTy &getBase() const { return *this; } 84364489255SMichael Kruse 84464489255SMichael Kruse public: 84564489255SMichael Kruse /// Is set to true if anything has been fused. 84664489255SMichael Kruse bool AnyChange = false; 84764489255SMichael Kruse 84864489255SMichael Kruse isl::schedule visitBand(isl::schedule_node_band Band, isl::union_map Deps) { 84964489255SMichael Kruse // { Domain[] -> Scatter[] } 85064489255SMichael Kruse isl::union_map PartSched = 85164489255SMichael Kruse isl::union_map::from(Band.get_partial_schedule()); 85244596fe6SRiccardo Mori assert(getNumScatterDims(PartSched) == 85344596fe6SRiccardo Mori unsignedFromIslSize(Band.n_member())); 85464489255SMichael Kruse isl::space ParamSpace = PartSched.get_space().params(); 85564489255SMichael Kruse 85664489255SMichael Kruse // { Scatter[] -> Domain[] } 85764489255SMichael Kruse isl::union_map PartSchedRev = PartSched.reverse(); 85864489255SMichael Kruse 85964489255SMichael Kruse // Possible within the same iteration. Dependencies with smaller scatter 86064489255SMichael Kruse // value are carried by this loop and therefore have been resolved by the 86164489255SMichael Kruse // in-order execution if the loop iteration. A dependency with small scatter 86264489255SMichael Kruse // value would be a dependency violation that we assume did not happen. { 86364489255SMichael Kruse // Domain[] -> Domain[] } 86464489255SMichael Kruse isl::union_map Unsequenced = PartSchedRev.apply_domain(PartSchedRev); 86564489255SMichael Kruse 86664489255SMichael Kruse // Actual dependencies within the same iteration. 86764489255SMichael Kruse // { DefDomain[] -> UseDomain[] } 86864489255SMichael Kruse isl::union_map RemDeps = Deps.intersect(Unsequenced); 86964489255SMichael Kruse 87064489255SMichael Kruse return getBase().visitBand(Band, RemDeps); 87164489255SMichael Kruse } 87264489255SMichael Kruse 87364489255SMichael Kruse isl::schedule visitSequence(isl::schedule_node_sequence Sequence, 87464489255SMichael Kruse isl::union_map Deps) { 87564489255SMichael Kruse int NumChildren = isl_schedule_node_n_children(Sequence.get()); 87664489255SMichael Kruse 87764489255SMichael Kruse // List of fusion candidates. The first element is the fusion candidate, the 87864489255SMichael Kruse // second is candidate's ancestor that is the sequence's direct child. It is 87964489255SMichael Kruse // preferable to use the direct child if not if its non-direct children is 88064489255SMichael Kruse // fused to preserve its structure such as mark nodes. 88164489255SMichael Kruse SmallVector<std::pair<isl::schedule_node, isl::schedule_node>> Bands; 88264489255SMichael Kruse for (auto i : seq<int>(0, NumChildren)) { 88364489255SMichael Kruse isl::schedule_node Child = Sequence.child(i); 88464489255SMichael Kruse collectPotentiallyFusableBands(Child, Bands, Child); 88564489255SMichael Kruse } 88664489255SMichael Kruse 887*5aafc6d5SChristian Clauss // Direct children that had at least one of its descendants fused. 88864489255SMichael Kruse SmallDenseSet<isl_schedule_node *, 4> ChangedDirectChildren; 88964489255SMichael Kruse 890*5aafc6d5SChristian Clauss // Fuse neighboring bands until reaching the end of candidates. 89164489255SMichael Kruse int i = 0; 89264489255SMichael Kruse while (i + 1 < (int)Bands.size()) { 89364489255SMichael Kruse isl::schedule Fused = 89464489255SMichael Kruse tryGreedyFuse(Bands[i].first, Bands[i + 1].first, Deps); 89564489255SMichael Kruse if (Fused.is_null()) { 89664489255SMichael Kruse // Cannot merge this node with the next; look at next pair. 89764489255SMichael Kruse i += 1; 89864489255SMichael Kruse continue; 89964489255SMichael Kruse } 90064489255SMichael Kruse 90164489255SMichael Kruse // Mark the direct children as (partially) fused. 90264489255SMichael Kruse if (!Bands[i].second.is_null()) 90364489255SMichael Kruse ChangedDirectChildren.insert(Bands[i].second.get()); 90464489255SMichael Kruse if (!Bands[i + 1].second.is_null()) 90564489255SMichael Kruse ChangedDirectChildren.insert(Bands[i + 1].second.get()); 90664489255SMichael Kruse 90764489255SMichael Kruse // Collapse the neigbros to a single new candidate that could be fused 90864489255SMichael Kruse // with the next candidate. 90964489255SMichael Kruse Bands[i] = {Fused.get_root(), {}}; 91064489255SMichael Kruse Bands.erase(Bands.begin() + i + 1); 91164489255SMichael Kruse 91264489255SMichael Kruse AnyChange = true; 91364489255SMichael Kruse } 91464489255SMichael Kruse 91564489255SMichael Kruse // By construction equal if done with collectPotentiallyFusableBands's 91664489255SMichael Kruse // output. 91764489255SMichael Kruse SmallVector<isl::union_set> SubDomains; 91864489255SMichael Kruse SubDomains.reserve(NumChildren); 91964489255SMichael Kruse for (int i = 0; i < NumChildren; i += 1) 92064489255SMichael Kruse SubDomains.push_back(Sequence.child(i).domain()); 92164489255SMichael Kruse auto SubRemainingDeps = remainigDepsFromSequence(SubDomains, Deps); 92264489255SMichael Kruse 92364489255SMichael Kruse // We may iterate over direct children multiple times, be sure to add each 92464489255SMichael Kruse // at most once. 92564489255SMichael Kruse SmallDenseSet<isl_schedule_node *, 4> AlreadyAdded; 92664489255SMichael Kruse 92764489255SMichael Kruse isl::schedule Result; 92864489255SMichael Kruse for (auto &P : Bands) { 92964489255SMichael Kruse isl::schedule_node MaybeFused = P.first; 93064489255SMichael Kruse isl::schedule_node DirectChild = P.second; 93164489255SMichael Kruse 93264489255SMichael Kruse // If not modified, use the direct child. 93364489255SMichael Kruse if (!DirectChild.is_null() && 93464489255SMichael Kruse !ChangedDirectChildren.count(DirectChild.get())) { 93564489255SMichael Kruse if (AlreadyAdded.count(DirectChild.get())) 93664489255SMichael Kruse continue; 93764489255SMichael Kruse AlreadyAdded.insert(DirectChild.get()); 93864489255SMichael Kruse MaybeFused = DirectChild; 93964489255SMichael Kruse } else { 94064489255SMichael Kruse assert(AnyChange && 94164489255SMichael Kruse "Need changed flag for be consistent with actual change"); 94264489255SMichael Kruse } 94364489255SMichael Kruse 94464489255SMichael Kruse // Top-down recursion: If the outermost loop has been fused, their nested 94564489255SMichael Kruse // bands might be fusable now as well. 94664489255SMichael Kruse isl::schedule InnerFused = visit(MaybeFused, SubRemainingDeps); 94764489255SMichael Kruse 94864489255SMichael Kruse // Reconstruct the sequence, with some of the children fused. 94964489255SMichael Kruse if (Result.is_null()) 95064489255SMichael Kruse Result = InnerFused; 95164489255SMichael Kruse else 95264489255SMichael Kruse Result = Result.sequence(InnerFused); 95364489255SMichael Kruse } 95464489255SMichael Kruse 95564489255SMichael Kruse return Result; 95664489255SMichael Kruse } 95764489255SMichael Kruse }; 95864489255SMichael Kruse 9593f170eb1SMichael Kruse } // namespace 9603f170eb1SMichael Kruse 9613f170eb1SMichael Kruse bool polly::isBandMark(const isl::schedule_node &Node) { 962d3fdbda6SRiccardo Mori return isMark(Node) && 963d3fdbda6SRiccardo Mori isLoopAttr(Node.as<isl::schedule_node_mark>().get_id()); 9643f170eb1SMichael Kruse } 9653f170eb1SMichael Kruse 9663f170eb1SMichael Kruse BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) { 9673f170eb1SMichael Kruse MarkOrBand = moveToBandMark(MarkOrBand); 9683f170eb1SMichael Kruse if (!isMark(MarkOrBand)) 9693f170eb1SMichael Kruse return nullptr; 9703f170eb1SMichael Kruse 971d3fdbda6SRiccardo Mori return getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id()); 9723f170eb1SMichael Kruse } 9733f170eb1SMichael Kruse 974aa8a9761SMichael Kruse isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) { 975aa8a9761SMichael Kruse // If there is no extension node in the first place, return the original 976aa8a9761SMichael Kruse // schedule tree. 977aa8a9761SMichael Kruse if (!containsExtensionNode(Sched)) 978aa8a9761SMichael Kruse return Sched; 979aa8a9761SMichael Kruse 980aa8a9761SMichael Kruse // Build options can anchor schedule nodes, such that the schedule tree cannot 981aa8a9761SMichael Kruse // be modified anymore. Therefore, apply build options after the tree has been 982aa8a9761SMichael Kruse // created. 983aa8a9761SMichael Kruse CollectASTBuildOptions Collector; 984aa8a9761SMichael Kruse Collector.visit(Sched); 985aa8a9761SMichael Kruse 986aa8a9761SMichael Kruse // Rewrite the schedule tree without extension nodes. 987aa8a9761SMichael Kruse ExtensionNodeRewriter Rewriter; 988aa8a9761SMichael Kruse isl::schedule NewSched = Rewriter.visitSchedule(Sched); 989aa8a9761SMichael Kruse 990aa8a9761SMichael Kruse // Reapply the AST build options. The rewriter must not change the iteration 991aa8a9761SMichael Kruse // order of bands. Any other node type is ignored. 992aa8a9761SMichael Kruse ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions); 993aa8a9761SMichael Kruse NewSched = Applicator.visitSchedule(NewSched); 994aa8a9761SMichael Kruse 995aa8a9761SMichael Kruse return NewSched; 996aa8a9761SMichael Kruse } 9973f170eb1SMichael Kruse 9983f170eb1SMichael Kruse isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) { 9990813bd16SRiccardo Mori isl::ctx Ctx = BandToUnroll.ctx(); 10003f170eb1SMichael Kruse 10013f170eb1SMichael Kruse // Remove the loop's mark, the loop will disappear anyway. 10023f170eb1SMichael Kruse BandToUnroll = removeMark(BandToUnroll); 10033f170eb1SMichael Kruse assert(isBandWithSingleLoop(BandToUnroll)); 10043f170eb1SMichael Kruse 10053f170eb1SMichael Kruse isl::multi_union_pw_aff PartialSched = isl::manage( 10063f170eb1SMichael Kruse isl_schedule_node_band_get_partial_schedule(BandToUnroll.get())); 100744596fe6SRiccardo Mori assert(unsignedFromIslSize(PartialSched.dim(isl::dim::out)) == 1u && 10083f170eb1SMichael Kruse "Can only unroll a single dimension"); 1009d3fdbda6SRiccardo Mori isl::union_pw_aff PartialSchedUAff = PartialSched.at(0); 10103f170eb1SMichael Kruse 10113f170eb1SMichael Kruse isl::union_set Domain = BandToUnroll.get_domain(); 10123f170eb1SMichael Kruse PartialSchedUAff = PartialSchedUAff.intersect_domain(Domain); 1013d3fdbda6SRiccardo Mori isl::union_map PartialSchedUMap = 1014d3fdbda6SRiccardo Mori isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff)); 10153f170eb1SMichael Kruse 1016f51427afSMichael Kruse // Enumerator only the scatter elements. 1017f51427afSMichael Kruse isl::union_set ScatterList = PartialSchedUMap.range(); 10183f170eb1SMichael Kruse 1019f51427afSMichael Kruse // Enumerate all loop iterations. 10203f170eb1SMichael Kruse // TODO: Diagnose if not enumerable or depends on a parameter. 1021f51427afSMichael Kruse SmallVector<isl::point, 16> Elts; 1022f51427afSMichael Kruse ScatterList.foreach_point([&Elts](isl::point P) -> isl::stat { 10233f170eb1SMichael Kruse Elts.push_back(P); 10243f170eb1SMichael Kruse return isl::stat::ok(); 10253f170eb1SMichael Kruse }); 10263f170eb1SMichael Kruse 10273f170eb1SMichael Kruse // Don't assume that foreach_point returns in execution order. 10283f170eb1SMichael Kruse llvm::sort(Elts, [](isl::point P1, isl::point P2) -> bool { 10293f170eb1SMichael Kruse isl::val C1 = P1.get_coordinate_val(isl::dim::set, 0); 10303f170eb1SMichael Kruse isl::val C2 = P2.get_coordinate_val(isl::dim::set, 0); 10313f170eb1SMichael Kruse return C1.lt(C2); 10323f170eb1SMichael Kruse }); 10333f170eb1SMichael Kruse 10343f170eb1SMichael Kruse // Convert the points to a sequence of filters. 1035d3fdbda6SRiccardo Mori isl::union_set_list List = isl::union_set_list(Ctx, Elts.size()); 10363f170eb1SMichael Kruse for (isl::point P : Elts) { 1037f51427afSMichael Kruse // Determine the domains that map this scatter element. 1038f51427afSMichael Kruse isl::union_set DomainFilter = PartialSchedUMap.intersect_range(P).domain(); 10393f170eb1SMichael Kruse 1040f51427afSMichael Kruse List = List.add(DomainFilter); 10413f170eb1SMichael Kruse } 10423f170eb1SMichael Kruse 10433f170eb1SMichael Kruse // Replace original band with unrolled sequence. 10443f170eb1SMichael Kruse isl::schedule_node Body = 10453f170eb1SMichael Kruse isl::manage(isl_schedule_node_delete(BandToUnroll.release())); 10463f170eb1SMichael Kruse Body = Body.insert_sequence(List); 10473f170eb1SMichael Kruse return Body.get_schedule(); 10483f170eb1SMichael Kruse } 10493f170eb1SMichael Kruse 10503f170eb1SMichael Kruse isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll, 10513f170eb1SMichael Kruse int Factor) { 10523f170eb1SMichael Kruse assert(Factor > 0 && "Positive unroll factor required"); 10530813bd16SRiccardo Mori isl::ctx Ctx = BandToUnroll.ctx(); 10543f170eb1SMichael Kruse 10553f170eb1SMichael Kruse // Remove the mark, save the attribute for later use. 10563f170eb1SMichael Kruse BandAttr *Attr; 10573f170eb1SMichael Kruse BandToUnroll = removeMark(BandToUnroll, Attr); 10583f170eb1SMichael Kruse assert(isBandWithSingleLoop(BandToUnroll)); 10593f170eb1SMichael Kruse 10603f170eb1SMichael Kruse isl::multi_union_pw_aff PartialSched = isl::manage( 10613f170eb1SMichael Kruse isl_schedule_node_band_get_partial_schedule(BandToUnroll.get())); 10623f170eb1SMichael Kruse 10633f170eb1SMichael Kruse // { Stmt[] -> [x] } 1064d3fdbda6SRiccardo Mori isl::union_pw_aff PartialSchedUAff = PartialSched.at(0); 10653f170eb1SMichael Kruse 10663f170eb1SMichael Kruse // Here we assume the schedule stride is one and starts with 0, which is not 10673f170eb1SMichael Kruse // necessarily the case. 10683f170eb1SMichael Kruse isl::union_pw_aff StridedPartialSchedUAff = 10693f170eb1SMichael Kruse isl::union_pw_aff::empty(PartialSchedUAff.get_space()); 10703f170eb1SMichael Kruse isl::val ValFactor{Ctx, Factor}; 10713f170eb1SMichael Kruse PartialSchedUAff.foreach_pw_aff([&StridedPartialSchedUAff, 10723f170eb1SMichael Kruse &ValFactor](isl::pw_aff PwAff) -> isl::stat { 10733f170eb1SMichael Kruse isl::space Space = PwAff.get_space(); 10743f170eb1SMichael Kruse isl::set Universe = isl::set::universe(Space.domain()); 10753f170eb1SMichael Kruse isl::pw_aff AffFactor{Universe, ValFactor}; 10763f170eb1SMichael Kruse isl::pw_aff DivSchedAff = PwAff.div(AffFactor).floor().mul(AffFactor); 10773f170eb1SMichael Kruse StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(DivSchedAff); 10783f170eb1SMichael Kruse return isl::stat::ok(); 10793f170eb1SMichael Kruse }); 10803f170eb1SMichael Kruse 1081d3fdbda6SRiccardo Mori isl::union_set_list List = isl::union_set_list(Ctx, Factor); 10823f170eb1SMichael Kruse for (auto i : seq<int>(0, Factor)) { 10833f170eb1SMichael Kruse // { Stmt[] -> [x] } 1084d3fdbda6SRiccardo Mori isl::union_map UMap = 1085d3fdbda6SRiccardo Mori isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff)); 10863f170eb1SMichael Kruse 10873f170eb1SMichael Kruse // { [x] } 10883f170eb1SMichael Kruse isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, i); 10893f170eb1SMichael Kruse 10903f170eb1SMichael Kruse // { Stmt[] } 10913f170eb1SMichael Kruse isl::union_set UnrolledDomain = UMap.intersect_range(Divisible).domain(); 10923f170eb1SMichael Kruse 10933f170eb1SMichael Kruse List = List.add(UnrolledDomain); 10943f170eb1SMichael Kruse } 10953f170eb1SMichael Kruse 10963f170eb1SMichael Kruse isl::schedule_node Body = 10973f170eb1SMichael Kruse isl::manage(isl_schedule_node_delete(BandToUnroll.copy())); 10983f170eb1SMichael Kruse Body = Body.insert_sequence(List); 10993f170eb1SMichael Kruse isl::schedule_node NewLoop = 11003f170eb1SMichael Kruse Body.insert_partial_schedule(StridedPartialSchedUAff); 11013f170eb1SMichael Kruse 11023f170eb1SMichael Kruse MDNode *FollowupMD = nullptr; 11033f170eb1SMichael Kruse if (Attr && Attr->Metadata) 11043f170eb1SMichael Kruse FollowupMD = 11053f170eb1SMichael Kruse findOptionalNodeOperand(Attr->Metadata, LLVMLoopUnrollFollowupUnrolled); 11063f170eb1SMichael Kruse 11073f170eb1SMichael Kruse isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupMD); 11087c7978a1Spatacca if (!NewBandId.is_null()) 11093f170eb1SMichael Kruse NewLoop = insertMark(NewLoop, NewBandId); 11103f170eb1SMichael Kruse 11113f170eb1SMichael Kruse return NewLoop.get_schedule(); 11123f170eb1SMichael Kruse } 1113d123e983SMichael Kruse 1114d123e983SMichael Kruse isl::set polly::getPartialTilePrefixes(isl::set ScheduleRange, 1115d123e983SMichael Kruse int VectorWidth) { 111644596fe6SRiccardo Mori unsigned Dims = unsignedFromIslSize(ScheduleRange.tuple_dim()); 111744596fe6SRiccardo Mori assert(Dims >= 1); 1118d123e983SMichael Kruse isl::set LoopPrefixes = 1119d123e983SMichael Kruse ScheduleRange.drop_constraints_involving_dims(isl::dim::set, Dims - 1, 1); 1120d123e983SMichael Kruse auto ExtentPrefixes = addExtentConstraints(LoopPrefixes, VectorWidth); 1121d123e983SMichael Kruse isl::set BadPrefixes = ExtentPrefixes.subtract(ScheduleRange); 1122d123e983SMichael Kruse BadPrefixes = BadPrefixes.project_out(isl::dim::set, Dims - 1, 1); 1123d123e983SMichael Kruse LoopPrefixes = LoopPrefixes.project_out(isl::dim::set, Dims - 1, 1); 1124d123e983SMichael Kruse return LoopPrefixes.subtract(BadPrefixes); 1125d123e983SMichael Kruse } 1126d123e983SMichael Kruse 1127d123e983SMichael Kruse isl::union_set polly::getIsolateOptions(isl::set IsolateDomain, 112844596fe6SRiccardo Mori unsigned OutDimsNum) { 112944596fe6SRiccardo Mori unsigned Dims = unsignedFromIslSize(IsolateDomain.tuple_dim()); 1130d123e983SMichael Kruse assert(OutDimsNum <= Dims && 1131d123e983SMichael Kruse "The isl::set IsolateDomain is used to describe the range of schedule " 1132d123e983SMichael Kruse "dimensions values, which should be isolated. Consequently, the " 1133d123e983SMichael Kruse "number of its dimensions should be greater than or equal to the " 1134d123e983SMichael Kruse "number of the schedule dimensions."); 1135d123e983SMichael Kruse isl::map IsolateRelation = isl::map::from_domain(IsolateDomain); 1136d123e983SMichael Kruse IsolateRelation = IsolateRelation.move_dims(isl::dim::out, 0, isl::dim::in, 1137d123e983SMichael Kruse Dims - OutDimsNum, OutDimsNum); 1138d123e983SMichael Kruse isl::set IsolateOption = IsolateRelation.wrap(); 11390813bd16SRiccardo Mori isl::id Id = isl::id::alloc(IsolateOption.ctx(), "isolate", nullptr); 1140d123e983SMichael Kruse IsolateOption = IsolateOption.set_tuple_id(Id); 1141d123e983SMichael Kruse return isl::union_set(IsolateOption); 1142d123e983SMichael Kruse } 1143d123e983SMichael Kruse 1144d123e983SMichael Kruse isl::union_set polly::getDimOptions(isl::ctx Ctx, const char *Option) { 1145d123e983SMichael Kruse isl::space Space(Ctx, 0, 1); 1146d123e983SMichael Kruse auto DimOption = isl::set::universe(Space); 1147d123e983SMichael Kruse auto Id = isl::id::alloc(Ctx, Option, nullptr); 1148d123e983SMichael Kruse DimOption = DimOption.set_tuple_id(Id); 1149d123e983SMichael Kruse return isl::union_set(DimOption); 1150d123e983SMichael Kruse } 1151d123e983SMichael Kruse 1152d123e983SMichael Kruse isl::schedule_node polly::tileNode(isl::schedule_node Node, 1153d123e983SMichael Kruse const char *Identifier, 1154d123e983SMichael Kruse ArrayRef<int> TileSizes, 1155d123e983SMichael Kruse int DefaultTileSize) { 1156d123e983SMichael Kruse auto Space = isl::manage(isl_schedule_node_band_get_space(Node.get())); 1157d123e983SMichael Kruse auto Dims = Space.dim(isl::dim::set); 1158d123e983SMichael Kruse auto Sizes = isl::multi_val::zero(Space); 1159d123e983SMichael Kruse std::string IdentifierString(Identifier); 116044596fe6SRiccardo Mori for (unsigned i : rangeIslSize(0, Dims)) { 116144596fe6SRiccardo Mori unsigned tileSize = i < TileSizes.size() ? TileSizes[i] : DefaultTileSize; 11620813bd16SRiccardo Mori Sizes = Sizes.set_val(i, isl::val(Node.ctx(), tileSize)); 1163d123e983SMichael Kruse } 1164d123e983SMichael Kruse auto TileLoopMarkerStr = IdentifierString + " - Tiles"; 11650813bd16SRiccardo Mori auto TileLoopMarker = isl::id::alloc(Node.ctx(), TileLoopMarkerStr, nullptr); 1166d123e983SMichael Kruse Node = Node.insert_mark(TileLoopMarker); 1167d123e983SMichael Kruse Node = Node.child(0); 1168d123e983SMichael Kruse Node = 1169d123e983SMichael Kruse isl::manage(isl_schedule_node_band_tile(Node.release(), Sizes.release())); 1170d123e983SMichael Kruse Node = Node.child(0); 1171d123e983SMichael Kruse auto PointLoopMarkerStr = IdentifierString + " - Points"; 1172d123e983SMichael Kruse auto PointLoopMarker = 11730813bd16SRiccardo Mori isl::id::alloc(Node.ctx(), PointLoopMarkerStr, nullptr); 1174d123e983SMichael Kruse Node = Node.insert_mark(PointLoopMarker); 1175d123e983SMichael Kruse return Node.child(0); 1176d123e983SMichael Kruse } 1177d123e983SMichael Kruse 1178d123e983SMichael Kruse isl::schedule_node polly::applyRegisterTiling(isl::schedule_node Node, 1179d123e983SMichael Kruse ArrayRef<int> TileSizes, 1180d123e983SMichael Kruse int DefaultTileSize) { 1181d123e983SMichael Kruse Node = tileNode(Node, "Register tiling", TileSizes, DefaultTileSize); 11820813bd16SRiccardo Mori auto Ctx = Node.ctx(); 1183d3fdbda6SRiccardo Mori return Node.as<isl::schedule_node_band>().set_ast_build_options( 1184d3fdbda6SRiccardo Mori isl::union_set(Ctx, "{unroll[x]}")); 1185d123e983SMichael Kruse } 1186e470f926SMichael Kruse 1187e470f926SMichael Kruse /// Find statements and sub-loops in (possibly nested) sequences. 1188e470f926SMichael Kruse static void 1189b554c643SMichael Kruse collectFissionableStmts(isl::schedule_node Node, 1190e470f926SMichael Kruse SmallVectorImpl<isl::schedule_node> &ScheduleStmts) { 1191e470f926SMichael Kruse if (isBand(Node) || isLeaf(Node)) { 1192e470f926SMichael Kruse ScheduleStmts.push_back(Node); 1193e470f926SMichael Kruse return; 1194e470f926SMichael Kruse } 1195e470f926SMichael Kruse 1196e470f926SMichael Kruse if (Node.has_children()) { 1197e470f926SMichael Kruse isl::schedule_node C = Node.first_child(); 1198e470f926SMichael Kruse while (true) { 1199b554c643SMichael Kruse collectFissionableStmts(C, ScheduleStmts); 1200e470f926SMichael Kruse if (!C.has_next_sibling()) 1201e470f926SMichael Kruse break; 1202e470f926SMichael Kruse C = C.next_sibling(); 1203e470f926SMichael Kruse } 1204e470f926SMichael Kruse } 1205e470f926SMichael Kruse } 1206e470f926SMichael Kruse 1207e470f926SMichael Kruse isl::schedule polly::applyMaxFission(isl::schedule_node BandToFission) { 1208e470f926SMichael Kruse isl::ctx Ctx = BandToFission.ctx(); 1209e470f926SMichael Kruse BandToFission = removeMark(BandToFission); 1210e470f926SMichael Kruse isl::schedule_node BandBody = BandToFission.child(0); 1211e470f926SMichael Kruse 1212e470f926SMichael Kruse SmallVector<isl::schedule_node> FissionableStmts; 1213b554c643SMichael Kruse collectFissionableStmts(BandBody, FissionableStmts); 1214e470f926SMichael Kruse size_t N = FissionableStmts.size(); 1215e470f926SMichael Kruse 1216e470f926SMichael Kruse // Collect the domain for each of the statements that will get their own loop. 1217e470f926SMichael Kruse isl::union_set_list DomList = isl::union_set_list(Ctx, N); 1218e470f926SMichael Kruse for (size_t i = 0; i < N; ++i) { 1219e470f926SMichael Kruse isl::schedule_node BodyPart = FissionableStmts[i]; 1220e470f926SMichael Kruse DomList = DomList.add(BodyPart.get_domain()); 1221e470f926SMichael Kruse } 1222e470f926SMichael Kruse 1223e470f926SMichael Kruse // Apply the fission by copying the entire loop, but inserting a filter for 1224e470f926SMichael Kruse // the statement domains for each fissioned loop. 1225e470f926SMichael Kruse isl::schedule_node Fissioned = BandToFission.insert_sequence(DomList); 1226e470f926SMichael Kruse 1227e470f926SMichael Kruse return Fissioned.get_schedule(); 1228e470f926SMichael Kruse } 122964489255SMichael Kruse 123064489255SMichael Kruse isl::schedule polly::applyGreedyFusion(isl::schedule Sched, 123164489255SMichael Kruse const isl::union_map &Deps) { 1232601d7eabSKarthika Devi C POLLY_DEBUG(dbgs() << "Greedy loop fusion\n"); 123364489255SMichael Kruse 123464489255SMichael Kruse GreedyFusionRewriter Rewriter; 123564489255SMichael Kruse isl::schedule Result = Rewriter.visit(Sched, Deps); 123664489255SMichael Kruse if (!Rewriter.AnyChange) { 1237601d7eabSKarthika Devi C POLLY_DEBUG(dbgs() << "Found nothing to fuse\n"); 123864489255SMichael Kruse return Sched; 123964489255SMichael Kruse } 124064489255SMichael Kruse 124164489255SMichael Kruse // GreedyFusionRewriter due to working loop-by-loop, bands with multiple loops 124264489255SMichael Kruse // may have been split into multiple bands. 124364489255SMichael Kruse return collapseBands(Result); 124464489255SMichael Kruse } 1245