xref: /llvm-project/polly/lib/Transform/FlattenAlgo.cpp (revision 601d7eab0665ba298d81952da11593124fd893a0)
1 //===------ FlattenAlgo.cpp ------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Main algorithm of the FlattenSchedulePass. This is a separate file to avoid
10 // the unittest for this requiring linking against LLVM.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "polly/FlattenAlgo.h"
15 #include "polly/Support/ISLOStream.h"
16 #include "polly/Support/ISLTools.h"
17 #include "polly/Support/PollyDebug.h"
18 #include "llvm/Support/Debug.h"
19 #define DEBUG_TYPE "polly-flatten-algo"
20 
21 using namespace polly;
22 using namespace llvm;
23 
24 namespace {
25 
26 /// Whether a dimension of a set is bounded (lower and upper) by a constant,
27 /// i.e. there are two constants Min and Max, such that every value x of the
28 /// chosen dimensions is Min <= x <= Max.
isDimBoundedByConstant(isl::set Set,unsigned dim)29 bool isDimBoundedByConstant(isl::set Set, unsigned dim) {
30   auto ParamDims = unsignedFromIslSize(Set.dim(isl::dim::param));
31   Set = Set.project_out(isl::dim::param, 0, ParamDims);
32   Set = Set.project_out(isl::dim::set, 0, dim);
33   auto SetDims = unsignedFromIslSize(Set.tuple_dim());
34   assert(SetDims >= 1);
35   Set = Set.project_out(isl::dim::set, 1, SetDims - 1);
36   return bool(Set.is_bounded());
37 }
38 
39 /// Whether a dimension of a set is (lower and upper) bounded by a constant or
40 /// parameters, i.e. there are two expressions Min_p and Max_p of the parameters
41 /// p, such that every value x of the chosen dimensions is
42 /// Min_p <= x <= Max_p.
isDimBoundedByParameter(isl::set Set,unsigned dim)43 bool isDimBoundedByParameter(isl::set Set, unsigned dim) {
44   Set = Set.project_out(isl::dim::set, 0, dim);
45   auto SetDims = unsignedFromIslSize(Set.tuple_dim());
46   assert(SetDims >= 1);
47   Set = Set.project_out(isl::dim::set, 1, SetDims - 1);
48   return bool(Set.is_bounded());
49 }
50 
51 /// Whether BMap's first out-dimension is not a constant.
isVariableDim(const isl::basic_map & BMap)52 bool isVariableDim(const isl::basic_map &BMap) {
53   auto FixedVal = BMap.plain_get_val_if_fixed(isl::dim::out, 0);
54   return FixedVal.is_null() || FixedVal.is_nan();
55 }
56 
57 /// Whether Map's first out dimension is no constant nor piecewise constant.
isVariableDim(const isl::map & Map)58 bool isVariableDim(const isl::map &Map) {
59   for (isl::basic_map BMap : Map.get_basic_map_list())
60     if (isVariableDim(BMap))
61       return false;
62 
63   return true;
64 }
65 
66 /// Whether UMap's first out dimension is no (piecewise) constant.
isVariableDim(const isl::union_map & UMap)67 bool isVariableDim(const isl::union_map &UMap) {
68   for (isl::map Map : UMap.get_map_list())
69     if (isVariableDim(Map))
70       return false;
71   return true;
72 }
73 
74 /// Compute @p UPwAff - @p Val.
subtract(isl::union_pw_aff UPwAff,isl::val Val)75 isl::union_pw_aff subtract(isl::union_pw_aff UPwAff, isl::val Val) {
76   if (Val.is_zero())
77     return UPwAff;
78 
79   auto Result = isl::union_pw_aff::empty(UPwAff.get_space());
80   isl::stat Stat =
81       UPwAff.foreach_pw_aff([=, &Result](isl::pw_aff PwAff) -> isl::stat {
82         auto ValAff =
83             isl::pw_aff(isl::set::universe(PwAff.get_space().domain()), Val);
84         auto Subtracted = PwAff.sub(ValAff);
85         Result = Result.union_add(isl::union_pw_aff(Subtracted));
86         return isl::stat::ok();
87       });
88   if (Stat.is_error())
89     return {};
90   return Result;
91 }
92 
93 /// Compute @UPwAff * @p Val.
multiply(isl::union_pw_aff UPwAff,isl::val Val)94 isl::union_pw_aff multiply(isl::union_pw_aff UPwAff, isl::val Val) {
95   if (Val.is_one())
96     return UPwAff;
97 
98   auto Result = isl::union_pw_aff::empty(UPwAff.get_space());
99   isl::stat Stat =
100       UPwAff.foreach_pw_aff([=, &Result](isl::pw_aff PwAff) -> isl::stat {
101         auto ValAff =
102             isl::pw_aff(isl::set::universe(PwAff.get_space().domain()), Val);
103         auto Multiplied = PwAff.mul(ValAff);
104         Result = Result.union_add(Multiplied);
105         return isl::stat::ok();
106       });
107   if (Stat.is_error())
108     return {};
109   return Result;
110 }
111 
112 /// Remove @p n dimensions from @p UMap's range, starting at @p first.
113 ///
114 /// It is assumed that all maps in the maps have at least the necessary number
115 /// of out dimensions.
scheduleProjectOut(const isl::union_map & UMap,unsigned first,unsigned n)116 isl::union_map scheduleProjectOut(const isl::union_map &UMap, unsigned first,
117                                   unsigned n) {
118   if (n == 0)
119     return UMap; /* isl_map_project_out would also reset the tuple, which should
120                     have no effect on schedule ranges */
121 
122   auto Result = isl::union_map::empty(UMap.ctx());
123   for (isl::map Map : UMap.get_map_list()) {
124     auto Outprojected = Map.project_out(isl::dim::out, first, n);
125     Result = Result.unite(Outprojected);
126   }
127   return Result;
128 }
129 
130 /// Return the @p pos' range dimension, converted to an isl_union_pw_aff.
scheduleExtractDimAff(isl::union_map UMap,unsigned pos)131 isl::union_pw_aff scheduleExtractDimAff(isl::union_map UMap, unsigned pos) {
132   auto SingleUMap = isl::union_map::empty(UMap.ctx());
133   for (isl::map Map : UMap.get_map_list()) {
134     unsigned MapDims = unsignedFromIslSize(Map.range_tuple_dim());
135     assert(MapDims > pos);
136     isl::map SingleMap = Map.project_out(isl::dim::out, 0, pos);
137     SingleMap = SingleMap.project_out(isl::dim::out, 1, MapDims - pos - 1);
138     SingleUMap = SingleUMap.unite(SingleMap);
139   };
140 
141   auto UAff = isl::union_pw_multi_aff(SingleUMap);
142   auto FirstMAff = isl::multi_union_pw_aff(UAff);
143   return FirstMAff.at(0);
144 }
145 
146 /// Flatten a sequence-like first dimension.
147 ///
148 /// A sequence-like scatter dimension is constant, or at least only small
149 /// variation, typically the result of ordering a sequence of different
150 /// statements. An example would be:
151 ///   { Stmt_A[] -> [0, X, ...]; Stmt_B[] -> [1, Y, ...] }
152 /// to schedule all instances of Stmt_A before any instance of Stmt_B.
153 ///
154 /// To flatten, first begin with an offset of zero. Then determine the lowest
155 /// possible value of the dimension, call it "i" [In the example we start at 0].
156 /// Considering only schedules with that value, consider only instances with
157 /// that value and determine the extent of the next dimension. Let l_X(i) and
158 /// u_X(i) its minimum (lower bound) and maximum (upper bound) value. Add them
159 /// as "Offset + X - l_X(i)" to the new schedule, then add "u_X(i) - l_X(i) + 1"
160 /// to Offset and remove all i-instances from the old schedule. Repeat with the
161 /// remaining lowest value i' until there are no instances in the old schedule
162 /// left.
163 /// The example schedule would be transformed to:
164 ///   { Stmt_X[] -> [X - l_X, ...]; Stmt_B -> [l_X - u_X + 1 + Y - l_Y, ...] }
tryFlattenSequence(isl::union_map Schedule)165 isl::union_map tryFlattenSequence(isl::union_map Schedule) {
166   auto IslCtx = Schedule.ctx();
167   auto ScatterSet = isl::set(Schedule.range());
168 
169   auto ParamSpace = Schedule.get_space().params();
170   auto Dims = unsignedFromIslSize(ScatterSet.tuple_dim());
171   assert(Dims >= 2u);
172 
173   // Would cause an infinite loop.
174   if (!isDimBoundedByConstant(ScatterSet, 0)) {
175     POLLY_DEBUG(dbgs() << "Abort; dimension is not of fixed size\n");
176     return {};
177   }
178 
179   auto AllDomains = Schedule.domain();
180   auto AllDomainsToNull = isl::union_pw_multi_aff(AllDomains);
181 
182   auto NewSchedule = isl::union_map::empty(ParamSpace.ctx());
183   auto Counter = isl::pw_aff(isl::local_space(ParamSpace.set_from_params()));
184 
185   while (!ScatterSet.is_empty()) {
186     POLLY_DEBUG(dbgs() << "Next counter:\n  " << Counter << "\n");
187     POLLY_DEBUG(dbgs() << "Remaining scatter set:\n  " << ScatterSet << "\n");
188     auto ThisSet = ScatterSet.project_out(isl::dim::set, 1, Dims - 1);
189     auto ThisFirst = ThisSet.lexmin();
190     auto ScatterFirst = ThisFirst.add_dims(isl::dim::set, Dims - 1);
191 
192     auto SubSchedule = Schedule.intersect_range(ScatterFirst);
193     SubSchedule = scheduleProjectOut(SubSchedule, 0, 1);
194     SubSchedule = flattenSchedule(SubSchedule);
195 
196     unsigned SubDims = getNumScatterDims(SubSchedule);
197     assert(SubDims >= 1);
198     auto FirstSubSchedule = scheduleProjectOut(SubSchedule, 1, SubDims - 1);
199     auto FirstScheduleAff = scheduleExtractDimAff(FirstSubSchedule, 0);
200     auto RemainingSubSchedule = scheduleProjectOut(SubSchedule, 0, 1);
201 
202     auto FirstSubScatter = isl::set(FirstSubSchedule.range());
203     POLLY_DEBUG(dbgs() << "Next step in sequence is:\n  " << FirstSubScatter
204                        << "\n");
205 
206     if (!isDimBoundedByParameter(FirstSubScatter, 0)) {
207       POLLY_DEBUG(dbgs() << "Abort; sequence step is not bounded\n");
208       return {};
209     }
210 
211     auto FirstSubScatterMap = isl::map::from_range(FirstSubScatter);
212 
213     // isl_set_dim_max returns a strange isl_pw_aff with domain tuple_id of
214     // 'none'. It doesn't match with any space including a 0-dimensional
215     // anonymous tuple.
216     // Interesting, one can create such a set using
217     // isl_set_universe(ParamSpace). Bug?
218     auto PartMin = FirstSubScatterMap.dim_min(0);
219     auto PartMax = FirstSubScatterMap.dim_max(0);
220     auto One = isl::pw_aff(isl::set::universe(ParamSpace.set_from_params()),
221                            isl::val::one(IslCtx));
222     auto PartLen = PartMax.add(PartMin.neg()).add(One);
223 
224     auto AllPartMin = isl::union_pw_aff(PartMin).pullback(AllDomainsToNull);
225     auto FirstScheduleAffNormalized = FirstScheduleAff.sub(AllPartMin);
226     auto AllCounter = isl::union_pw_aff(Counter).pullback(AllDomainsToNull);
227     auto FirstScheduleAffWithOffset =
228         FirstScheduleAffNormalized.add(AllCounter);
229 
230     auto ScheduleWithOffset =
231         isl::union_map::from(
232             isl::union_pw_multi_aff(FirstScheduleAffWithOffset))
233             .flat_range_product(RemainingSubSchedule);
234     NewSchedule = NewSchedule.unite(ScheduleWithOffset);
235 
236     ScatterSet = ScatterSet.subtract(ScatterFirst);
237     Counter = Counter.add(PartLen);
238   }
239 
240   POLLY_DEBUG(dbgs() << "Sequence-flatten result is:\n  " << NewSchedule
241                      << "\n");
242   return NewSchedule;
243 }
244 
245 /// Flatten a loop-like first dimension.
246 ///
247 /// A loop-like dimension is one that depends on a variable (usually a loop's
248 /// induction variable). Let the input schedule look like this:
249 ///   { Stmt[i] -> [i, X, ...] }
250 ///
251 /// To flatten, we determine the largest extent of X which may not depend on the
252 /// actual value of i. Let l_X() the smallest possible value of X and u_X() its
253 /// largest value. Then, construct a new schedule
254 ///   { Stmt[i] -> [i * (u_X() - l_X() + 1), ...] }
tryFlattenLoop(isl::union_map Schedule)255 isl::union_map tryFlattenLoop(isl::union_map Schedule) {
256   assert(getNumScatterDims(Schedule) >= 2);
257 
258   auto Remaining = scheduleProjectOut(Schedule, 0, 1);
259   auto SubSchedule = flattenSchedule(Remaining);
260   unsigned SubDims = getNumScatterDims(SubSchedule);
261 
262   assert(SubDims >= 1);
263 
264   auto SubExtent = isl::set(SubSchedule.range());
265   auto SubExtentDims = unsignedFromIslSize(SubExtent.dim(isl::dim::param));
266   SubExtent = SubExtent.project_out(isl::dim::param, 0, SubExtentDims);
267   SubExtent = SubExtent.project_out(isl::dim::set, 1, SubDims - 1);
268 
269   if (!isDimBoundedByConstant(SubExtent, 0)) {
270     POLLY_DEBUG(dbgs() << "Abort; dimension not bounded by constant\n");
271     return {};
272   }
273 
274   auto Min = SubExtent.dim_min(0);
275   POLLY_DEBUG(dbgs() << "Min bound:\n  " << Min << "\n");
276   auto MinVal = getConstant(Min, false, true);
277   auto Max = SubExtent.dim_max(0);
278   POLLY_DEBUG(dbgs() << "Max bound:\n  " << Max << "\n");
279   auto MaxVal = getConstant(Max, true, false);
280 
281   if (MinVal.is_null() || MaxVal.is_null() || MinVal.is_nan() ||
282       MaxVal.is_nan()) {
283     POLLY_DEBUG(dbgs() << "Abort; dimension bounds could not be determined\n");
284     return {};
285   }
286 
287   auto FirstSubScheduleAff = scheduleExtractDimAff(SubSchedule, 0);
288   auto RemainingSubSchedule = scheduleProjectOut(std::move(SubSchedule), 0, 1);
289 
290   auto LenVal = MaxVal.sub(MinVal).add(1);
291   auto FirstSubScheduleNormalized = subtract(FirstSubScheduleAff, MinVal);
292 
293   // TODO: Normalize FirstAff to zero (convert to isl_map, determine minimum,
294   // subtract it)
295   auto FirstAff = scheduleExtractDimAff(Schedule, 0);
296   auto Offset = multiply(FirstAff, LenVal);
297   isl::union_pw_multi_aff Index = FirstSubScheduleNormalized.add(Offset);
298   auto IndexMap = isl::union_map::from(Index);
299 
300   auto Result = IndexMap.flat_range_product(RemainingSubSchedule);
301   POLLY_DEBUG(dbgs() << "Loop-flatten result is:\n  " << Result << "\n");
302   return Result;
303 }
304 } // anonymous namespace
305 
flattenSchedule(isl::union_map Schedule)306 isl::union_map polly::flattenSchedule(isl::union_map Schedule) {
307   unsigned Dims = getNumScatterDims(Schedule);
308   POLLY_DEBUG(dbgs() << "Recursive schedule to process:\n  " << Schedule
309                      << "\n");
310 
311   // Base case; no dimensions left
312   if (Dims == 0) {
313     // TODO: Add one dimension?
314     return Schedule;
315   }
316 
317   // Base case; already one-dimensional
318   if (Dims == 1)
319     return Schedule;
320 
321   // Fixed dimension; no need to preserve variabledness.
322   if (!isVariableDim(Schedule)) {
323     POLLY_DEBUG(dbgs() << "Fixed dimension; try sequence flattening\n");
324     auto NewScheduleSequence = tryFlattenSequence(Schedule);
325     if (!NewScheduleSequence.is_null())
326       return NewScheduleSequence;
327   }
328 
329   // Constant stride
330   POLLY_DEBUG(dbgs() << "Try loop flattening\n");
331   auto NewScheduleLoop = tryFlattenLoop(Schedule);
332   if (!NewScheduleLoop.is_null())
333     return NewScheduleLoop;
334 
335   // Try again without loop condition (may blow up the number of pieces!!)
336   POLLY_DEBUG(dbgs() << "Try sequence flattening again\n");
337   auto NewScheduleSequence = tryFlattenSequence(Schedule);
338   if (!NewScheduleSequence.is_null())
339     return NewScheduleSequence;
340 
341   // Cannot flatten
342   return Schedule;
343 }
344