xref: /llvm-project/polly/lib/Support/ISLTools.cpp (revision b02c7e2b630a04701d12efd2376f25eff2767279)
1 //===------ ISLTools.cpp ----------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Tools, utilities, helpers and extensions useful in conjunction with the
10 // Integer Set Library (isl).
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "polly/Support/ISLTools.h"
15 #include "polly/Support/GICHelper.h"
16 #include "llvm/Support/raw_ostream.h"
17 #include <cassert>
18 #include <vector>
19 
20 using namespace polly;
21 
22 namespace {
23 /// Create a map that shifts one dimension by an offset.
24 ///
25 /// Example:
26 /// makeShiftDimAff({ [i0, i1] -> [o0, o1] }, 1, -2)
27 ///   = { [i0, i1] -> [i0, i1 - 1] }
28 ///
29 /// @param Space  The map space of the result. Must have equal number of in- and
30 ///               out-dimensions.
31 /// @param Pos    Position to shift.
32 /// @param Amount Value added to the shifted dimension.
33 ///
34 /// @return An isl_multi_aff for the map with this shifted dimension.
makeShiftDimAff(isl::space Space,int Pos,int Amount)35 isl::multi_aff makeShiftDimAff(isl::space Space, int Pos, int Amount) {
36   auto Identity = isl::multi_aff::identity(Space);
37   if (Amount == 0)
38     return Identity;
39   auto ShiftAff = Identity.at(Pos);
40   ShiftAff = ShiftAff.set_constant_si(Amount);
41   return Identity.set_aff(Pos, ShiftAff);
42 }
43 
44 /// Construct a map that swaps two nested tuples.
45 ///
46 /// @param FromSpace1 { Space1[] }
47 /// @param FromSpace2 { Space2[] }
48 ///
49 /// @return { [Space1[] -> Space2[]] -> [Space2[] -> Space1[]] }
makeTupleSwapBasicMap(isl::space FromSpace1,isl::space FromSpace2)50 isl::basic_map makeTupleSwapBasicMap(isl::space FromSpace1,
51                                      isl::space FromSpace2) {
52   // Fast-path on out-of-quota.
53   if (FromSpace1.is_null() || FromSpace2.is_null())
54     return {};
55 
56   assert(FromSpace1.is_set());
57   assert(FromSpace2.is_set());
58 
59   unsigned Dims1 = unsignedFromIslSize(FromSpace1.dim(isl::dim::set));
60   unsigned Dims2 = unsignedFromIslSize(FromSpace2.dim(isl::dim::set));
61 
62   isl::space FromSpace =
63       FromSpace1.map_from_domain_and_range(FromSpace2).wrap();
64   isl::space ToSpace = FromSpace2.map_from_domain_and_range(FromSpace1).wrap();
65   isl::space MapSpace = FromSpace.map_from_domain_and_range(ToSpace);
66 
67   isl::basic_map Result = isl::basic_map::universe(MapSpace);
68   for (unsigned i = 0u; i < Dims1; i += 1)
69     Result = Result.equate(isl::dim::in, i, isl::dim::out, Dims2 + i);
70   for (unsigned i = 0u; i < Dims2; i += 1) {
71     Result = Result.equate(isl::dim::in, Dims1 + i, isl::dim::out, i);
72   }
73 
74   return Result;
75 }
76 
77 /// Like makeTupleSwapBasicMap(isl::space,isl::space), but returns
78 /// an isl_map.
makeTupleSwapMap(isl::space FromSpace1,isl::space FromSpace2)79 isl::map makeTupleSwapMap(isl::space FromSpace1, isl::space FromSpace2) {
80   isl::basic_map BMapResult = makeTupleSwapBasicMap(FromSpace1, FromSpace2);
81   return isl::map(BMapResult);
82 }
83 } // anonymous namespace
84 
beforeScatter(isl::map Map,bool Strict)85 isl::map polly::beforeScatter(isl::map Map, bool Strict) {
86   isl::space RangeSpace = Map.get_space().range();
87   isl::map ScatterRel =
88       Strict ? isl::map::lex_gt(RangeSpace) : isl::map::lex_ge(RangeSpace);
89   return Map.apply_range(ScatterRel);
90 }
91 
beforeScatter(isl::union_map UMap,bool Strict)92 isl::union_map polly::beforeScatter(isl::union_map UMap, bool Strict) {
93   isl::union_map Result = isl::union_map::empty(UMap.ctx());
94 
95   for (isl::map Map : UMap.get_map_list()) {
96     isl::map After = beforeScatter(Map, Strict);
97     Result = Result.unite(After);
98   }
99 
100   return Result;
101 }
102 
afterScatter(isl::map Map,bool Strict)103 isl::map polly::afterScatter(isl::map Map, bool Strict) {
104   isl::space RangeSpace = Map.get_space().range();
105   isl::map ScatterRel =
106       Strict ? isl::map::lex_lt(RangeSpace) : isl::map::lex_le(RangeSpace);
107   return Map.apply_range(ScatterRel);
108 }
109 
// Union version of afterScatter: apply the single-map version to each piece
// and collect the results.
isl::union_map polly::afterScatter(const isl::union_map &UMap, bool Strict) {
  isl::union_map Accumulated = isl::union_map::empty(UMap.ctx());
  for (isl::map Piece : UMap.get_map_list())
    Accumulated = Accumulated.unite(afterScatter(Piece, Strict));
  return Accumulated;
}
118 
betweenScatter(isl::map From,isl::map To,bool InclFrom,bool InclTo)119 isl::map polly::betweenScatter(isl::map From, isl::map To, bool InclFrom,
120                                bool InclTo) {
121   isl::map AfterFrom = afterScatter(From, !InclFrom);
122   isl::map BeforeTo = beforeScatter(To, !InclTo);
123 
124   return AfterFrom.intersect(BeforeTo);
125 }
126 
betweenScatter(isl::union_map From,isl::union_map To,bool InclFrom,bool InclTo)127 isl::union_map polly::betweenScatter(isl::union_map From, isl::union_map To,
128                                      bool InclFrom, bool InclTo) {
129   isl::union_map AfterFrom = afterScatter(From, !InclFrom);
130   isl::union_map BeforeTo = beforeScatter(To, !InclTo);
131 
132   return AfterFrom.intersect(BeforeTo);
133 }
134 
singleton(isl::union_map UMap,isl::space ExpectedSpace)135 isl::map polly::singleton(isl::union_map UMap, isl::space ExpectedSpace) {
136   if (UMap.is_null())
137     return {};
138 
139   if (isl_union_map_n_map(UMap.get()) == 0)
140     return isl::map::empty(ExpectedSpace);
141 
142   isl::map Result = isl::map::from_union_map(UMap);
143   assert(Result.is_null() ||
144          Result.get_space().has_equal_tuples(ExpectedSpace));
145 
146   return Result;
147 }
148 
singleton(isl::union_set USet,isl::space ExpectedSpace)149 isl::set polly::singleton(isl::union_set USet, isl::space ExpectedSpace) {
150   if (USet.is_null())
151     return {};
152 
153   if (isl_union_set_n_set(USet.get()) == 0)
154     return isl::set::empty(ExpectedSpace);
155 
156   isl::set Result(USet);
157   assert(Result.is_null() ||
158          Result.get_space().has_equal_tuples(ExpectedSpace));
159 
160   return Result;
161 }
162 
getNumScatterDims(const isl::union_map & Schedule)163 unsigned polly::getNumScatterDims(const isl::union_map &Schedule) {
164   unsigned Dims = 0;
165   for (isl::map Map : Schedule.get_map_list()) {
166     if (Map.is_null())
167       continue;
168 
169     Dims = std::max(Dims, unsignedFromIslSize(Map.range_tuple_dim()));
170   }
171   return Dims;
172 }
173 
// An anonymous set space over Schedule's parameters, with as many dimensions
// as the widest schedule range (see getNumScatterDims).
isl::space polly::getScatterSpace(const isl::union_map &Schedule) {
  if (Schedule.is_null())
    return {};
  unsigned NumDims = getNumScatterDims(Schedule);
  isl::space ParamSpace = Schedule.get_space();
  return ParamSpace.set_from_params().add_dims(isl::dim::set, NumDims);
}
181 
// Identity map over Set's space; when RestrictDomain is set, only the elements
// of Set are mapped.
isl::map polly::makeIdentityMap(const isl::set &Set, bool RestrictDomain) {
  isl::map Identity = isl::map::identity(Set.get_space().map_from_set());
  if (!RestrictDomain)
    return Identity;
  return Identity.intersect_domain(Set);
}
188 
// Union version of makeIdentityMap: one identity map per set in USet.
isl::union_map polly::makeIdentityMap(const isl::union_set &USet,
                                      bool RestrictDomain) {
  isl::union_map Identities = isl::union_map::empty(USet.ctx());
  for (isl::set Piece : USet.get_set_list())
    Identities = Identities.unite(makeIdentityMap(Piece, RestrictDomain));
  return Identities;
}
198 
reverseDomain(isl::map Map)199 isl::map polly::reverseDomain(isl::map Map) {
200   isl::space DomSpace = Map.get_space().domain().unwrap();
201   isl::space Space1 = DomSpace.domain();
202   isl::space Space2 = DomSpace.range();
203   isl::map Swap = makeTupleSwapMap(Space1, Space2);
204   return Map.apply_domain(Swap);
205 }
206 
// Union version of reverseDomain: swap the nested domain tuples of every piece.
isl::union_map polly::reverseDomain(const isl::union_map &UMap) {
  isl::union_map Swapped = isl::union_map::empty(UMap.ctx());
  for (isl::map Piece : UMap.get_map_list())
    Swapped = Swapped.unite(reverseDomain(std::move(Piece)));
  return Swapped;
}
215 
shiftDim(isl::set Set,int Pos,int Amount)216 isl::set polly::shiftDim(isl::set Set, int Pos, int Amount) {
217   unsigned NumDims = unsignedFromIslSize(Set.tuple_dim());
218   if (Pos < 0)
219     Pos = NumDims + Pos;
220   assert(unsigned(Pos) < NumDims && "Dimension index must be in range");
221   isl::space Space = Set.get_space();
222   Space = Space.map_from_domain_and_range(Space);
223   isl::multi_aff Translator = makeShiftDimAff(Space, Pos, Amount);
224   isl::map TranslatorMap = isl::map::from_multi_aff(Translator);
225   return Set.apply(TranslatorMap);
226 }
227 
shiftDim(isl::union_set USet,int Pos,int Amount)228 isl::union_set polly::shiftDim(isl::union_set USet, int Pos, int Amount) {
229   isl::union_set Result = isl::union_set::empty(USet.ctx());
230   for (isl::set Set : USet.get_set_list()) {
231     isl::set Shifted = shiftDim(Set, Pos, Amount);
232     Result = Result.unite(Shifted);
233   }
234   return Result;
235 }
236 
shiftDim(isl::map Map,isl::dim Dim,int Pos,int Amount)237 isl::map polly::shiftDim(isl::map Map, isl::dim Dim, int Pos, int Amount) {
238   unsigned NumDims = unsignedFromIslSize(Map.dim(Dim));
239   if (Pos < 0)
240     Pos = NumDims + Pos;
241   assert(unsigned(Pos) < NumDims && "Dimension index must be in range");
242   isl::space Space = Map.get_space();
243   switch (Dim) {
244   case isl::dim::in:
245     Space = Space.domain();
246     break;
247   case isl::dim::out:
248     Space = Space.range();
249     break;
250   default:
251     llvm_unreachable("Unsupported value for 'dim'");
252   }
253   Space = Space.map_from_domain_and_range(Space);
254   isl::multi_aff Translator = makeShiftDimAff(Space, Pos, Amount);
255   isl::map TranslatorMap = isl::map::from_multi_aff(Translator);
256   switch (Dim) {
257   case isl::dim::in:
258     return Map.apply_domain(TranslatorMap);
259   case isl::dim::out:
260     return Map.apply_range(TranslatorMap);
261   default:
262     llvm_unreachable("Unsupported value for 'dim'");
263   }
264 }
265 
getConstant(isl::map Map,isl::dim Dim,int Pos)266 isl::val polly::getConstant(isl::map Map, isl::dim Dim, int Pos) {
267   unsigned NumDims = unsignedFromIslSize(Map.dim(Dim));
268   if (Pos < 0)
269     Pos = NumDims + Pos;
270   assert(unsigned(Pos) < NumDims && "Dimension index must be in range");
271   // TODO: The isl_map_plain_get_val_if_fixed function is not robust, since its
272   // result is different depending on the internal representation.
273   // Replace it with a different implementation.
274   return isl::manage(isl_map_plain_get_val_if_fixed(
275       Map.get(), static_cast<enum isl_dim_type>(Dim), Pos));
276 }
277 
shiftDim(isl::union_map UMap,isl::dim Dim,int Pos,int Amount)278 isl::union_map polly::shiftDim(isl::union_map UMap, isl::dim Dim, int Pos,
279                                int Amount) {
280   isl::union_map Result = isl::union_map::empty(UMap.ctx());
281 
282   for (isl::map Map : UMap.get_map_list()) {
283     isl::map Shifted = shiftDim(Map, Dim, Pos, Amount);
284     Result = Result.unite(Shifted);
285   }
286   return Result;
287 }
288 
simplify(isl::set & Set)289 void polly::simplify(isl::set &Set) {
290   Set = isl::manage(isl_set_compute_divs(Set.copy()));
291   Set = Set.detect_equalities();
292   Set = Set.coalesce();
293 }
294 
simplify(isl::union_set & USet)295 void polly::simplify(isl::union_set &USet) {
296   USet = isl::manage(isl_union_set_compute_divs(USet.copy()));
297   USet = USet.detect_equalities();
298   USet = USet.coalesce();
299 }
300 
simplify(isl::map & Map)301 void polly::simplify(isl::map &Map) {
302   Map = isl::manage(isl_map_compute_divs(Map.copy()));
303   Map = Map.detect_equalities();
304   Map = Map.coalesce();
305 }
306 
simplify(isl::union_map & UMap)307 void polly::simplify(isl::union_map &UMap) {
308   UMap = isl::manage(isl_union_map_compute_divs(UMap.copy()));
309   UMap = UMap.detect_equalities();
310   UMap = UMap.coalesce();
311 }
312 
computeReachingWrite(isl::union_map Schedule,isl::union_map Writes,bool Reverse,bool InclPrevDef,bool InclNextDef)313 isl::union_map polly::computeReachingWrite(isl::union_map Schedule,
314                                            isl::union_map Writes, bool Reverse,
315                                            bool InclPrevDef, bool InclNextDef) {
316 
317   // { Scatter[] }
318   isl::space ScatterSpace = getScatterSpace(Schedule);
319 
320   // { ScatterRead[] -> ScatterWrite[] }
321   isl::map Relation;
322   if (Reverse)
323     Relation = InclPrevDef ? isl::map::lex_lt(ScatterSpace)
324                            : isl::map::lex_le(ScatterSpace);
325   else
326     Relation = InclNextDef ? isl::map::lex_gt(ScatterSpace)
327                            : isl::map::lex_ge(ScatterSpace);
328 
329   // { ScatterWrite[] -> [ScatterRead[] -> ScatterWrite[]] }
330   isl::map RelationMap = Relation.range_map().reverse();
331 
332   // { Element[] -> ScatterWrite[] }
333   isl::union_map WriteAction = Schedule.apply_domain(Writes);
334 
335   // { ScatterWrite[] -> Element[] }
336   isl::union_map WriteActionRev = WriteAction.reverse();
337 
338   // { Element[] -> [ScatterUse[] -> ScatterWrite[]] }
339   isl::union_map DefSchedRelation =
340       isl::union_map(RelationMap).apply_domain(WriteActionRev);
341 
342   // For each element, at every point in time, map to the times of previous
343   // definitions. { [Element[] -> ScatterRead[]] -> ScatterWrite[] }
344   isl::union_map ReachableWrites = DefSchedRelation.uncurry();
345   if (Reverse)
346     ReachableWrites = ReachableWrites.lexmin();
347   else
348     ReachableWrites = ReachableWrites.lexmax();
349 
350   // { [Element[] -> ScatterWrite[]] -> ScatterWrite[] }
351   isl::union_map SelfUse = WriteAction.range_map();
352 
353   if (InclPrevDef && InclNextDef) {
354     // Add the Def itself to the solution.
355     ReachableWrites = ReachableWrites.unite(SelfUse).coalesce();
356   } else if (!InclPrevDef && !InclNextDef) {
357     // Remove Def itself from the solution.
358     ReachableWrites = ReachableWrites.subtract(SelfUse);
359   }
360 
361   // { [Element[] -> ScatterRead[]] -> Domain[] }
362   return ReachableWrites.apply_range(Schedule.reverse());
363 }
364 
365 isl::union_map
computeArrayUnused(isl::union_map Schedule,isl::union_map Writes,isl::union_map Reads,bool ReadEltInSameInst,bool IncludeLastRead,bool IncludeWrite)366 polly::computeArrayUnused(isl::union_map Schedule, isl::union_map Writes,
367                           isl::union_map Reads, bool ReadEltInSameInst,
368                           bool IncludeLastRead, bool IncludeWrite) {
369   // { Element[] -> Scatter[] }
370   isl::union_map ReadActions = Schedule.apply_domain(Reads);
371   isl::union_map WriteActions = Schedule.apply_domain(Writes);
372 
373   // { [Element[] -> DomainWrite[]] -> Scatter[] }
374   isl::union_map EltDomWrites =
375       Writes.reverse().range_map().apply_range(Schedule);
376 
377   // { [Element[] -> Scatter[]] -> DomainWrite[] }
378   isl::union_map ReachingOverwrite = computeReachingWrite(
379       Schedule, Writes, true, ReadEltInSameInst, !ReadEltInSameInst);
380 
381   // { [Element[] -> Scatter[]] -> DomainWrite[] }
382   isl::union_map ReadsOverwritten =
383       ReachingOverwrite.intersect_domain(ReadActions.wrap());
384 
385   // { [Element[] -> DomainWrite[]] -> Scatter[] }
386   isl::union_map ReadsOverwrittenRotated =
387       reverseDomain(ReadsOverwritten).curry().reverse();
388   isl::union_map LastOverwrittenRead = ReadsOverwrittenRotated.lexmax();
389 
390   // { [Element[] -> DomainWrite[]] -> Scatter[] }
391   isl::union_map BetweenLastReadOverwrite = betweenScatter(
392       LastOverwrittenRead, EltDomWrites, IncludeLastRead, IncludeWrite);
393 
394   // { [Element[] -> Scatter[]] -> DomainWrite[] }
395   isl::union_map ReachingOverwriteZone = computeReachingWrite(
396       Schedule, Writes, true, IncludeLastRead, IncludeWrite);
397 
398   // { [Element[] -> DomainWrite[]] -> Scatter[] }
399   isl::union_map ReachingOverwriteRotated =
400       reverseDomain(ReachingOverwriteZone).curry().reverse();
401 
402   // { [Element[] -> DomainWrite[]] -> Scatter[] }
403   isl::union_map WritesWithoutReads = ReachingOverwriteRotated.subtract_domain(
404       ReadsOverwrittenRotated.domain());
405 
406   return BetweenLastReadOverwrite.unite(WritesWithoutReads)
407       .domain_factor_domain();
408 }
409 
convertZoneToTimepoints(isl::union_set Zone,bool InclStart,bool InclEnd)410 isl::union_set polly::convertZoneToTimepoints(isl::union_set Zone,
411                                               bool InclStart, bool InclEnd) {
412   if (!InclStart && InclEnd)
413     return Zone;
414 
415   auto ShiftedZone = shiftDim(Zone, -1, -1);
416   if (InclStart && !InclEnd)
417     return ShiftedZone;
418   else if (!InclStart && !InclEnd)
419     return Zone.intersect(ShiftedZone);
420 
421   assert(InclStart && InclEnd);
422   return Zone.unite(ShiftedZone);
423 }
424 
convertZoneToTimepoints(isl::union_map Zone,isl::dim Dim,bool InclStart,bool InclEnd)425 isl::union_map polly::convertZoneToTimepoints(isl::union_map Zone, isl::dim Dim,
426                                               bool InclStart, bool InclEnd) {
427   if (!InclStart && InclEnd)
428     return Zone;
429 
430   auto ShiftedZone = shiftDim(Zone, Dim, -1, -1);
431   if (InclStart && !InclEnd)
432     return ShiftedZone;
433   else if (!InclStart && !InclEnd)
434     return Zone.intersect(ShiftedZone);
435 
436   assert(InclStart && InclEnd);
437   return Zone.unite(ShiftedZone);
438 }
439 
convertZoneToTimepoints(isl::map Zone,isl::dim Dim,bool InclStart,bool InclEnd)440 isl::map polly::convertZoneToTimepoints(isl::map Zone, isl::dim Dim,
441                                         bool InclStart, bool InclEnd) {
442   if (!InclStart && InclEnd)
443     return Zone;
444 
445   auto ShiftedZone = shiftDim(Zone, Dim, -1, -1);
446   if (InclStart && !InclEnd)
447     return ShiftedZone;
448   else if (!InclStart && !InclEnd)
449     return Zone.intersect(ShiftedZone);
450 
451   assert(InclStart && InclEnd);
452   return Zone.unite(ShiftedZone);
453 }
454 
distributeDomain(isl::map Map)455 isl::map polly::distributeDomain(isl::map Map) {
456   // Note that we cannot take Map apart into { Domain[] -> Range1[] } and {
457   // Domain[] -> Range2[] } and combine again. We would loose any relation
458   // between Range1[] and Range2[] that is not also a constraint to Domain[].
459 
460   isl::space Space = Map.get_space();
461   isl::space DomainSpace = Space.domain();
462   if (DomainSpace.is_null())
463     return {};
464   unsigned DomainDims = unsignedFromIslSize(DomainSpace.dim(isl::dim::set));
465   isl::space RangeSpace = Space.range().unwrap();
466   isl::space Range1Space = RangeSpace.domain();
467   if (Range1Space.is_null())
468     return {};
469   unsigned Range1Dims = unsignedFromIslSize(Range1Space.dim(isl::dim::set));
470   isl::space Range2Space = RangeSpace.range();
471   if (Range2Space.is_null())
472     return {};
473   unsigned Range2Dims = unsignedFromIslSize(Range2Space.dim(isl::dim::set));
474 
475   isl::space OutputSpace =
476       DomainSpace.map_from_domain_and_range(Range1Space)
477           .wrap()
478           .map_from_domain_and_range(
479               DomainSpace.map_from_domain_and_range(Range2Space).wrap());
480 
481   isl::basic_map Translator = isl::basic_map::universe(
482       Space.wrap().map_from_domain_and_range(OutputSpace.wrap()));
483 
484   for (unsigned i = 0; i < DomainDims; i += 1) {
485     Translator = Translator.equate(isl::dim::in, i, isl::dim::out, i);
486     Translator = Translator.equate(isl::dim::in, i, isl::dim::out,
487                                    DomainDims + Range1Dims + i);
488   }
489   for (unsigned i = 0; i < Range1Dims; i += 1)
490     Translator = Translator.equate(isl::dim::in, DomainDims + i, isl::dim::out,
491                                    DomainDims + i);
492   for (unsigned i = 0; i < Range2Dims; i += 1)
493     Translator = Translator.equate(isl::dim::in, DomainDims + Range1Dims + i,
494                                    isl::dim::out,
495                                    DomainDims + Range1Dims + DomainDims + i);
496 
497   return Map.wrap().apply(Translator).unwrap();
498 }
499 
distributeDomain(isl::union_map UMap)500 isl::union_map polly::distributeDomain(isl::union_map UMap) {
501   isl::union_map Result = isl::union_map::empty(UMap.ctx());
502   for (isl::map Map : UMap.get_map_list()) {
503     auto Distributed = distributeDomain(Map);
504     Result = Result.unite(Distributed);
505   }
506   return Result;
507 }
508 
liftDomains(isl::union_map UMap,isl::union_set Factor)509 isl::union_map polly::liftDomains(isl::union_map UMap, isl::union_set Factor) {
510 
511   // { Factor[] -> Factor[] }
512   isl::union_map Factors = makeIdentityMap(Factor, true);
513 
514   return Factors.product(UMap);
515 }
516 
applyDomainRange(isl::union_map UMap,isl::union_map Func)517 isl::union_map polly::applyDomainRange(isl::union_map UMap,
518                                        isl::union_map Func) {
519   // This implementation creates unnecessary cross products of the
520   // DomainDomain[] and Func. An alternative implementation could reverse
521   // domain+uncurry,apply Func to what now is the domain, then undo the
522   // preparing transformation. Another alternative implementation could create a
523   // translator map for each piece.
524 
525   // { DomainDomain[] }
526   isl::union_set DomainDomain = UMap.domain().unwrap().domain();
527 
528   // { [DomainDomain[] -> DomainRange[]] -> [DomainDomain[] -> NewDomainRange[]]
529   // }
530   isl::union_map LifetedFunc = liftDomains(std::move(Func), DomainDomain);
531 
532   return UMap.apply_domain(LifetedFunc);
533 }
534 
intersectRange(isl::map Map,isl::union_set Range)535 isl::map polly::intersectRange(isl::map Map, isl::union_set Range) {
536   isl::set RangeSet = Range.extract_set(Map.get_space().range());
537   return Map.intersect_range(RangeSet);
538 }
539 
subtractParams(isl::map Map,isl::set Params)540 isl::map polly::subtractParams(isl::map Map, isl::set Params) {
541   auto MapSpace = Map.get_space();
542   auto ParamsMap = isl::map::universe(MapSpace).intersect_params(Params);
543   return Map.subtract(ParamsMap);
544 }
545 
subtractParams(isl::set Set,isl::set Params)546 isl::set polly::subtractParams(isl::set Set, isl::set Params) {
547   isl::space SetSpace = Set.get_space();
548   isl::set ParamsSet = isl::set::universe(SetSpace).intersect_params(Params);
549   return Set.subtract(ParamsSet);
550 }
551 
getConstant(isl::pw_aff PwAff,bool Max,bool Min)552 isl::val polly::getConstant(isl::pw_aff PwAff, bool Max, bool Min) {
553   assert(!Max || !Min); // Cannot return min and max at the same time.
554   isl::val Result;
555   isl::stat Stat = PwAff.foreach_piece(
556       [=, &Result](isl::set Set, isl::aff Aff) -> isl::stat {
557         if (!Result.is_null() && Result.is_nan())
558           return isl::stat::ok();
559 
560         // TODO: If Min/Max, we can also determine a minimum/maximum value if
561         // Set is constant-bounded.
562         if (!Aff.is_cst()) {
563           Result = isl::val::nan(Aff.ctx());
564           return isl::stat::error();
565         }
566 
567         isl::val ThisVal = Aff.get_constant_val();
568         if (Result.is_null()) {
569           Result = ThisVal;
570           return isl::stat::ok();
571         }
572 
573         if (Result.eq(ThisVal))
574           return isl::stat::ok();
575 
576         if (Max && ThisVal.gt(Result)) {
577           Result = ThisVal;
578           return isl::stat::ok();
579         }
580 
581         if (Min && ThisVal.lt(Result)) {
582           Result = ThisVal;
583           return isl::stat::ok();
584         }
585 
586         // Not compatible
587         Result = isl::val::nan(Aff.ctx());
588         return isl::stat::error();
589       });
590 
591   if (Stat.is_error())
592     return {};
593 
594   return Result;
595 }
596 
rangeIslSize(unsigned Begin,isl::size End)597 llvm::iota_range<unsigned> polly::rangeIslSize(unsigned Begin, isl::size End) {
598   unsigned UEnd = unsignedFromIslSize(End);
599   return llvm::seq<unsigned>(std::min(Begin, UEnd), UEnd);
600 }
601 
602 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
foreachPoint(const isl::set & Set,const std::function<void (isl::point P)> & F)603 static void foreachPoint(const isl::set &Set,
604                          const std::function<void(isl::point P)> &F) {
605   Set.foreach_point([&](isl::point P) -> isl::stat {
606     F(P);
607     return isl::stat::ok();
608   });
609 }
610 
foreachPoint(isl::basic_set BSet,const std::function<void (isl::point P)> & F)611 static void foreachPoint(isl::basic_set BSet,
612                          const std::function<void(isl::point P)> &F) {
613   foreachPoint(isl::set(BSet), F);
614 }
615 
616 /// Determine the sorting order of the sets @p A and @p B without considering
617 /// the space structure.
618 ///
619 /// Ordering is based on the lower bounds of the set's dimensions. First
620 /// dimensions are considered first.
flatCompare(const isl::basic_set & A,const isl::basic_set & B)621 static int flatCompare(const isl::basic_set &A, const isl::basic_set &B) {
622   // Quick bail-out on out-of-quota.
623   if (A.is_null() || B.is_null())
624     return 0;
625 
626   unsigned ALen = unsignedFromIslSize(A.dim(isl::dim::set));
627   unsigned BLen = unsignedFromIslSize(B.dim(isl::dim::set));
628   unsigned Len = std::min(ALen, BLen);
629 
630   for (unsigned i = 0; i < Len; i += 1) {
631     isl::basic_set ADim =
632         A.project_out(isl::dim::param, 0,
633                       unsignedFromIslSize(A.dim(isl::dim::param)))
634             .project_out(isl::dim::set, i + 1, ALen - i - 1)
635             .project_out(isl::dim::set, 0, i);
636     isl::basic_set BDim =
637         B.project_out(isl::dim::param, 0,
638                       unsignedFromIslSize(B.dim(isl::dim::param)))
639             .project_out(isl::dim::set, i + 1, BLen - i - 1)
640             .project_out(isl::dim::set, 0, i);
641 
642     isl::basic_set AHull = isl::set(ADim).convex_hull();
643     isl::basic_set BHull = isl::set(BDim).convex_hull();
644 
645     bool ALowerBounded =
646         bool(isl::set(AHull).dim_has_any_lower_bound(isl::dim::set, 0));
647     bool BLowerBounded =
648         bool(isl::set(BHull).dim_has_any_lower_bound(isl::dim::set, 0));
649 
650     int BoundedCompare = BLowerBounded - ALowerBounded;
651     if (BoundedCompare != 0)
652       return BoundedCompare;
653 
654     if (!ALowerBounded || !BLowerBounded)
655       continue;
656 
657     isl::pw_aff AMin = isl::set(ADim).dim_min(0);
658     isl::pw_aff BMin = isl::set(BDim).dim_min(0);
659 
660     isl::val AMinVal = polly::getConstant(AMin, false, true);
661     isl::val BMinVal = polly::getConstant(BMin, false, true);
662 
663     int MinCompare = AMinVal.sub(BMinVal).sgn();
664     if (MinCompare != 0)
665       return MinCompare;
666   }
667 
668   // If all the dimensions' lower bounds are equal or incomparable, sort based
669   // on the number of dimensions.
670   return ALen - BLen;
671 }
672 
673 /// Compare the sets @p A and @p B according to their nested space structure.
674 /// Returns 0 if the structure is considered equal.
675 /// If @p ConsiderTupleLen is false, the number of dimensions in a tuple are
676 /// ignored, i.e. a tuple with the same name but different number of dimensions
677 /// are considered equal.
structureCompare(const isl::space & ASpace,const isl::space & BSpace,bool ConsiderTupleLen)678 static int structureCompare(const isl::space &ASpace, const isl::space &BSpace,
679                             bool ConsiderTupleLen) {
680   int WrappingCompare = bool(ASpace.is_wrapping()) - bool(BSpace.is_wrapping());
681   if (WrappingCompare != 0)
682     return WrappingCompare;
683 
684   if (ASpace.is_wrapping() && BSpace.is_wrapping()) {
685     isl::space AMap = ASpace.unwrap();
686     isl::space BMap = BSpace.unwrap();
687 
688     int FirstResult =
689         structureCompare(AMap.domain(), BMap.domain(), ConsiderTupleLen);
690     if (FirstResult != 0)
691       return FirstResult;
692 
693     return structureCompare(AMap.range(), BMap.range(), ConsiderTupleLen);
694   }
695 
696   std::string AName;
697   if (!ASpace.is_params() && ASpace.has_tuple_name(isl::dim::set))
698     AName = ASpace.get_tuple_name(isl::dim::set);
699 
700   std::string BName;
701   if (!BSpace.is_params() && BSpace.has_tuple_name(isl::dim::set))
702     BName = BSpace.get_tuple_name(isl::dim::set);
703 
704   int NameCompare = AName.compare(BName);
705   if (NameCompare != 0)
706     return NameCompare;
707 
708   if (ConsiderTupleLen) {
709     int LenCompare = (int)unsignedFromIslSize(BSpace.dim(isl::dim::set)) -
710                      (int)unsignedFromIslSize(ASpace.dim(isl::dim::set));
711     if (LenCompare != 0)
712       return LenCompare;
713   }
714 
715   return 0;
716 }
717 
718 /// Compare the sets @p A and @p B according to their nested space structure. If
719 /// the structure is the same, sort using the dimension lower bounds.
720 /// Returns an std::sort compatible bool.
orderComparer(const isl::basic_set & A,const isl::basic_set & B)721 static bool orderComparer(const isl::basic_set &A, const isl::basic_set &B) {
722   isl::space ASpace = A.get_space();
723   isl::space BSpace = B.get_space();
724 
725   // Ignoring number of dimensions first ensures that structures with same tuple
726   // names, but different number of dimensions are still sorted close together.
727   int TupleNestingCompare = structureCompare(ASpace, BSpace, false);
728   if (TupleNestingCompare != 0)
729     return TupleNestingCompare < 0;
730 
731   int TupleCompare = structureCompare(ASpace, BSpace, true);
732   if (TupleCompare != 0)
733     return TupleCompare < 0;
734 
735   return flatCompare(A, B) < 0;
736 }
737 
738 /// Print a string representation of @p USet to @p OS.
739 ///
740 /// The pieces of @p USet are printed in a sorted order. Spaces with equal or
741 /// similar nesting structure are printed together. Compared to isl's own
742 /// printing function the uses the structure itself as base of the sorting, not
743 /// a hash of it. It ensures that e.g. maps spaces with same domain structure
744 /// are printed together. Set pieces with same structure are printed in order of
745 /// their lower bounds.
746 ///
747 /// @param USet     Polyhedra to print.
748 /// @param OS       Target stream.
749 /// @param Simplify Whether to simplify the polyhedron before printing.
750 /// @param IsMap    Whether @p USet is a wrapped map. If true, sets are
751 ///                 unwrapped before printing to again appear as a map.
printSortedPolyhedra(isl::union_set USet,llvm::raw_ostream & OS,bool Simplify,bool IsMap)752 static void printSortedPolyhedra(isl::union_set USet, llvm::raw_ostream &OS,
753                                  bool Simplify, bool IsMap) {
754   if (USet.is_null()) {
755     OS << "<null>\n";
756     return;
757   }
758 
759   if (Simplify)
760     simplify(USet);
761 
762   // Get all the polyhedra.
763   std::vector<isl::basic_set> BSets;
764 
765   for (isl::set Set : USet.get_set_list()) {
766     for (isl::basic_set BSet : Set.get_basic_set_list()) {
767       BSets.push_back(BSet);
768     }
769   }
770 
771   if (BSets.empty()) {
772     OS << "{\n}\n";
773     return;
774   }
775 
776   // Sort the polyhedra.
777   llvm::sort(BSets, orderComparer);
778 
779   // Print the polyhedra.
780   bool First = true;
781   for (const isl::basic_set &BSet : BSets) {
782     std::string Str;
783     if (IsMap)
784       Str = stringFromIslObj(isl::map(BSet.unwrap()));
785     else
786       Str = stringFromIslObj(isl::set(BSet));
787     size_t OpenPos = Str.find_first_of('{');
788     assert(OpenPos != std::string::npos);
789     size_t ClosePos = Str.find_last_of('}');
790     assert(ClosePos != std::string::npos);
791 
792     if (First)
793       OS << llvm::StringRef(Str).substr(0, OpenPos + 1) << "\n ";
794     else
795       OS << ";\n ";
796 
797     OS << llvm::StringRef(Str).substr(OpenPos + 1, ClosePos - OpenPos - 2);
798     First = false;
799   }
800   assert(!First);
801   OS << "\n}\n";
802 }
803 
recursiveExpand(isl::basic_set BSet,unsigned Dim,isl::set & Expanded)804 static void recursiveExpand(isl::basic_set BSet, unsigned Dim,
805                             isl::set &Expanded) {
806   unsigned Dims = unsignedFromIslSize(BSet.dim(isl::dim::set));
807   if (Dim >= Dims) {
808     Expanded = Expanded.unite(BSet);
809     return;
810   }
811 
812   isl::basic_set DimOnly =
813       BSet.project_out(isl::dim::param, 0,
814                        unsignedFromIslSize(BSet.dim(isl::dim::param)))
815           .project_out(isl::dim::set, Dim + 1, Dims - Dim - 1)
816           .project_out(isl::dim::set, 0, Dim);
817   if (!DimOnly.is_bounded()) {
818     recursiveExpand(BSet, Dim + 1, Expanded);
819     return;
820   }
821 
822   foreachPoint(DimOnly, [&, Dim](isl::point P) {
823     isl::val Val = P.get_coordinate_val(isl::dim::set, 0);
824     isl::basic_set FixBSet = BSet.fix_val(isl::dim::set, Dim, Val);
825     recursiveExpand(FixBSet, Dim + 1, Expanded);
826   });
827 }
828 
829 /// Make each point of a set explicit.
830 ///
831 /// "Expanding" makes each point a set contains explicit. That is, the result is
832 /// a set of singleton polyhedra. Unbounded dimensions are not expanded.
833 ///
834 /// Example:
835 ///   { [i] : 0 <= i < 2 }
836 /// is expanded to:
837 ///   { [0]; [1] }
expand(const isl::set & Set)838 static isl::set expand(const isl::set &Set) {
839   isl::set Expanded = isl::set::empty(Set.get_space());
840   for (isl::basic_set BSet : Set.get_basic_set_list())
841     recursiveExpand(BSet, 0, Expanded);
842   return Expanded;
843 }
844 
845 /// Expand all points of a union set explicit.
846 ///
847 /// @see expand(const isl::set)
expand(const isl::union_set & USet)848 static isl::union_set expand(const isl::union_set &USet) {
849   isl::union_set Expanded = isl::union_set::empty(USet.ctx());
850   for (isl::set Set : USet.get_set_list()) {
851     isl::set SetExpanded = expand(Set);
852     Expanded = Expanded.unite(SetExpanded);
853   }
854   return Expanded;
855 }
856 
dumpPw(const isl::set & Set)857 LLVM_DUMP_METHOD void polly::dumpPw(const isl::set &Set) {
858   printSortedPolyhedra(Set, llvm::errs(), true, false);
859 }
860 
dumpPw(const isl::map & Map)861 LLVM_DUMP_METHOD void polly::dumpPw(const isl::map &Map) {
862   printSortedPolyhedra(Map.wrap(), llvm::errs(), true, true);
863 }
864 
dumpPw(const isl::union_set & USet)865 LLVM_DUMP_METHOD void polly::dumpPw(const isl::union_set &USet) {
866   printSortedPolyhedra(USet, llvm::errs(), true, false);
867 }
868 
dumpPw(const isl::union_map & UMap)869 LLVM_DUMP_METHOD void polly::dumpPw(const isl::union_map &UMap) {
870   printSortedPolyhedra(UMap.wrap(), llvm::errs(), true, true);
871 }
872 
dumpPw(__isl_keep isl_set * Set)873 LLVM_DUMP_METHOD void polly::dumpPw(__isl_keep isl_set *Set) {
874   dumpPw(isl::manage_copy(Set));
875 }
876 
dumpPw(__isl_keep isl_map * Map)877 LLVM_DUMP_METHOD void polly::dumpPw(__isl_keep isl_map *Map) {
878   dumpPw(isl::manage_copy(Map));
879 }
880 
dumpPw(__isl_keep isl_union_set * USet)881 LLVM_DUMP_METHOD void polly::dumpPw(__isl_keep isl_union_set *USet) {
882   dumpPw(isl::manage_copy(USet));
883 }
884 
dumpPw(__isl_keep isl_union_map * UMap)885 LLVM_DUMP_METHOD void polly::dumpPw(__isl_keep isl_union_map *UMap) {
886   dumpPw(isl::manage_copy(UMap));
887 }
888 
dumpExpanded(const isl::set & Set)889 LLVM_DUMP_METHOD void polly::dumpExpanded(const isl::set &Set) {
890   printSortedPolyhedra(expand(Set), llvm::errs(), false, false);
891 }
892 
dumpExpanded(const isl::map & Map)893 LLVM_DUMP_METHOD void polly::dumpExpanded(const isl::map &Map) {
894   printSortedPolyhedra(expand(Map.wrap()), llvm::errs(), false, true);
895 }
896 
dumpExpanded(const isl::union_set & USet)897 LLVM_DUMP_METHOD void polly::dumpExpanded(const isl::union_set &USet) {
898   printSortedPolyhedra(expand(USet), llvm::errs(), false, false);
899 }
900 
dumpExpanded(const isl::union_map & UMap)901 LLVM_DUMP_METHOD void polly::dumpExpanded(const isl::union_map &UMap) {
902   printSortedPolyhedra(expand(UMap.wrap()), llvm::errs(), false, true);
903 }
904 
dumpExpanded(__isl_keep isl_set * Set)905 LLVM_DUMP_METHOD void polly::dumpExpanded(__isl_keep isl_set *Set) {
906   dumpExpanded(isl::manage_copy(Set));
907 }
908 
dumpExpanded(__isl_keep isl_map * Map)909 LLVM_DUMP_METHOD void polly::dumpExpanded(__isl_keep isl_map *Map) {
910   dumpExpanded(isl::manage_copy(Map));
911 }
912 
dumpExpanded(__isl_keep isl_union_set * USet)913 LLVM_DUMP_METHOD void polly::dumpExpanded(__isl_keep isl_union_set *USet) {
914   dumpExpanded(isl::manage_copy(USet));
915 }
916 
dumpExpanded(__isl_keep isl_union_map * UMap)917 LLVM_DUMP_METHOD void polly::dumpExpanded(__isl_keep isl_union_map *UMap) {
918   dumpExpanded(isl::manage_copy(UMap));
919 }
920 #endif
921