xref: /llvm-project/polly/lib/Transform/MatmulOptimizer.cpp (revision 5aafc6d58f3405662902cee006be11e599801b88)
1d123e983SMichael Kruse //===- MatmulOptimizer.cpp -----------------------------------------------===//
2d123e983SMichael Kruse //
3d123e983SMichael Kruse // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d123e983SMichael Kruse // See https://llvm.org/LICENSE.txt for license information.
5d123e983SMichael Kruse // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d123e983SMichael Kruse //
7d123e983SMichael Kruse //===----------------------------------------------------------------------===//
8d123e983SMichael Kruse 
9d123e983SMichael Kruse #include "polly/MatmulOptimizer.h"
10d123e983SMichael Kruse #include "polly/DependenceInfo.h"
11d123e983SMichael Kruse #include "polly/Options.h"
12d123e983SMichael Kruse #include "polly/ScheduleTreeTransform.h"
13d123e983SMichael Kruse #include "polly/ScopInfo.h"
14d123e983SMichael Kruse #include "polly/ScopPass.h"
15d123e983SMichael Kruse #include "polly/Simplify.h"
16b02c7e2bSRoman Gareev #include "polly/Support/GICHelper.h"
17a56bd7deSMichael Kruse #include "polly/Support/ISLTools.h"
18d123e983SMichael Kruse #include "llvm/ADT/ArrayRef.h"
19b02c7e2bSRoman Gareev #include "llvm/ADT/DenseSet.h"
20d123e983SMichael Kruse #include "llvm/ADT/Sequence.h"
21b02c7e2bSRoman Gareev #include "llvm/ADT/SetOperations.h"
22d123e983SMichael Kruse #include "llvm/ADT/SmallVector.h"
23d123e983SMichael Kruse #include "llvm/ADT/StringRef.h"
24d123e983SMichael Kruse #include "llvm/ADT/iterator_range.h"
25d123e983SMichael Kruse #include "llvm/Analysis/TargetTransformInfo.h"
26d123e983SMichael Kruse #include "llvm/IR/DataLayout.h"
27d123e983SMichael Kruse #include "llvm/IR/Function.h"
28d123e983SMichael Kruse #include "llvm/IR/Module.h"
29d123e983SMichael Kruse #include "llvm/Support/CommandLine.h"
30d123e983SMichael Kruse #include "llvm/Support/Debug.h"
31d123e983SMichael Kruse #include "llvm/Support/TypeSize.h"
32d123e983SMichael Kruse #include "llvm/Support/raw_ostream.h"
33d123e983SMichael Kruse #include "isl/ctx.h"
34d123e983SMichael Kruse #include "isl/schedule_node.h"
35d123e983SMichael Kruse #include "isl/schedule_type.h"
36d123e983SMichael Kruse #include "isl/union_map.h"
37d123e983SMichael Kruse #include "isl/union_set.h"
38d123e983SMichael Kruse #include <algorithm>
39d123e983SMichael Kruse #include <cassert>
40d123e983SMichael Kruse #include <cmath>
41d123e983SMichael Kruse #include <cstdint>
42d123e983SMichael Kruse #include <string>
43d123e983SMichael Kruse #include <vector>
44d123e983SMichael Kruse 
45601d7eabSKarthika Devi C #include "polly/Support/PollyDebug.h"
46d123e983SMichael Kruse #define DEBUG_TYPE "polly-opt-isl"
47d123e983SMichael Kruse 
48d123e983SMichael Kruse using namespace llvm;
49d123e983SMichael Kruse using namespace polly;
50d123e983SMichael Kruse 
// Forward declaration of llvm::Value.
namespace llvm {
class Value;
} // namespace llvm
54d123e983SMichael Kruse 
55d123e983SMichael Kruse static cl::opt<int> LatencyVectorFma(
56d123e983SMichael Kruse     "polly-target-latency-vector-fma",
57d123e983SMichael Kruse     cl::desc("The minimal number of cycles between issuing two "
58d123e983SMichael Kruse              "dependent consecutive vector fused multiply-add "
59d123e983SMichael Kruse              "instructions."),
6095a13425SFangrui Song     cl::Hidden, cl::init(8), cl::cat(PollyCategory));
61d123e983SMichael Kruse 
62d123e983SMichael Kruse static cl::opt<int> ThroughputVectorFma(
63d123e983SMichael Kruse     "polly-target-throughput-vector-fma",
64d123e983SMichael Kruse     cl::desc("A throughput of the processor floating-point arithmetic units "
65d123e983SMichael Kruse              "expressed in the number of vector fused multiply-add "
66d123e983SMichael Kruse              "instructions per clock cycle."),
6795a13425SFangrui Song     cl::Hidden, cl::init(1), cl::cat(PollyCategory));
68d123e983SMichael Kruse 
69d123e983SMichael Kruse static cl::opt<int> FirstCacheLevelSize(
70d123e983SMichael Kruse     "polly-target-1st-cache-level-size",
71d123e983SMichael Kruse     cl::desc("The size of the first cache level specified in bytes."),
7236c7d79dSFangrui Song     cl::Hidden, cl::init(-1), cl::cat(PollyCategory));
73d123e983SMichael Kruse 
74d123e983SMichael Kruse static cl::opt<int> FirstCacheLevelDefaultSize(
75d123e983SMichael Kruse     "polly-target-1st-cache-level-default-size",
76d123e983SMichael Kruse     cl::desc("The default size of the first cache level specified in bytes"
77d123e983SMichael Kruse              " (if not enough were provided by the TargetTransformInfo)."),
7836c7d79dSFangrui Song     cl::Hidden, cl::init(32768), cl::cat(PollyCategory));
79d123e983SMichael Kruse 
80d123e983SMichael Kruse static cl::opt<int> SecondCacheLevelSize(
81d123e983SMichael Kruse     "polly-target-2nd-cache-level-size",
82d123e983SMichael Kruse     cl::desc("The size of the second level specified in bytes."), cl::Hidden,
8336c7d79dSFangrui Song     cl::init(-1), cl::cat(PollyCategory));
84d123e983SMichael Kruse 
85d123e983SMichael Kruse static cl::opt<int> SecondCacheLevelDefaultSize(
86d123e983SMichael Kruse     "polly-target-2nd-cache-level-default-size",
87d123e983SMichael Kruse     cl::desc("The default size of the second cache level specified in bytes"
88d123e983SMichael Kruse              " (if not enough were provided by the TargetTransformInfo)."),
8936c7d79dSFangrui Song     cl::Hidden, cl::init(262144), cl::cat(PollyCategory));
90d123e983SMichael Kruse 
// This option, along with --polly-target-2nd-cache-level-associativity,
// --polly-target-1st-cache-level-size, and --polly-target-2nd-cache-level-size,
// represents the parameters of the target cache, which do not have typical
// values that can be used by default. However, to apply the pattern matching
// optimizations, we use the values of the parameters of Intel Core i7-3820
// SandyBridge in case the parameters are not specified or not provided by the
// TargetTransformInfo.
98d123e983SMichael Kruse static cl::opt<int> FirstCacheLevelAssociativity(
99d123e983SMichael Kruse     "polly-target-1st-cache-level-associativity",
100d123e983SMichael Kruse     cl::desc("The associativity of the first cache level."), cl::Hidden,
10136c7d79dSFangrui Song     cl::init(-1), cl::cat(PollyCategory));
102d123e983SMichael Kruse 
103d123e983SMichael Kruse static cl::opt<int> FirstCacheLevelDefaultAssociativity(
104d123e983SMichael Kruse     "polly-target-1st-cache-level-default-associativity",
105d123e983SMichael Kruse     cl::desc("The default associativity of the first cache level"
106d123e983SMichael Kruse              " (if not enough were provided by the TargetTransformInfo)."),
10736c7d79dSFangrui Song     cl::Hidden, cl::init(8), cl::cat(PollyCategory));
108d123e983SMichael Kruse 
109d123e983SMichael Kruse static cl::opt<int> SecondCacheLevelAssociativity(
110d123e983SMichael Kruse     "polly-target-2nd-cache-level-associativity",
111d123e983SMichael Kruse     cl::desc("The associativity of the second cache level."), cl::Hidden,
11236c7d79dSFangrui Song     cl::init(-1), cl::cat(PollyCategory));
113d123e983SMichael Kruse 
114d123e983SMichael Kruse static cl::opt<int> SecondCacheLevelDefaultAssociativity(
115d123e983SMichael Kruse     "polly-target-2nd-cache-level-default-associativity",
116d123e983SMichael Kruse     cl::desc("The default associativity of the second cache level"
117d123e983SMichael Kruse              " (if not enough were provided by the TargetTransformInfo)."),
11836c7d79dSFangrui Song     cl::Hidden, cl::init(8), cl::cat(PollyCategory));
119d123e983SMichael Kruse 
120d123e983SMichael Kruse static cl::opt<int> VectorRegisterBitwidth(
121d123e983SMichael Kruse     "polly-target-vector-register-bitwidth",
122d123e983SMichael Kruse     cl::desc("The size in bits of a vector register (if not set, this "
123d123e983SMichael Kruse              "information is taken from LLVM's target information."),
12436c7d79dSFangrui Song     cl::Hidden, cl::init(-1), cl::cat(PollyCategory));
125d123e983SMichael Kruse 
126d123e983SMichael Kruse static cl::opt<int> PollyPatternMatchingNcQuotient(
127d123e983SMichael Kruse     "polly-pattern-matching-nc-quotient",
128d123e983SMichael Kruse     cl::desc("Quotient that is obtained by dividing Nc, the parameter of the"
129d123e983SMichael Kruse              "macro-kernel, by Nr, the parameter of the micro-kernel"),
13036c7d79dSFangrui Song     cl::Hidden, cl::init(256), cl::cat(PollyCategory));
131d123e983SMichael Kruse 
132b02c7e2bSRoman Gareev static cl::opt<bool>
133b02c7e2bSRoman Gareev     PMBasedTCOpts("polly-tc-opt",
134b02c7e2bSRoman Gareev                   cl::desc("Perform optimizations of tensor contractions based "
135b02c7e2bSRoman Gareev                            "on pattern matching"),
136b02c7e2bSRoman Gareev                   cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
137b02c7e2bSRoman Gareev 
138b02c7e2bSRoman Gareev static cl::opt<bool>
139b02c7e2bSRoman Gareev     PMBasedMMMOpts("polly-matmul-opt",
140b02c7e2bSRoman Gareev                    cl::desc("Perform optimizations of matrix multiplications "
141b02c7e2bSRoman Gareev                             "based on pattern matching"),
142b02c7e2bSRoman Gareev                    cl::init(true), cl::ZeroOrMore, cl::cat(PollyCategory));
143b02c7e2bSRoman Gareev 
144b02c7e2bSRoman Gareev static cl::opt<int> OptComputeOut(
145b02c7e2bSRoman Gareev     "polly-tc-dependences-computeout",
146b02c7e2bSRoman Gareev     cl::desc("Bound the dependence analysis by a maximal amount of "
147b02c7e2bSRoman Gareev              "computational steps (0 means no bound)"),
148b02c7e2bSRoman Gareev     cl::Hidden, cl::init(500000), cl::ZeroOrMore, cl::cat(PollyCategory));
149b02c7e2bSRoman Gareev 
150d123e983SMichael Kruse namespace {
/// Micro-kernel parameters.
///
/// Mr and Nr determine the sizes of the rank-1 (i.e., outer product) update
/// that the optimized matrix multiplication is built from.
struct MicroKernelParamsTy {
  int Mr; ///< Micro-kernel parameter Mr.
  int Nr; ///< Micro-kernel parameter Nr.
};
159d123e983SMichael Kruse 
/// Macro-kernel parameters.
///
/// Mc, Nc and Kc determine the sizes of the blocks into which the matrices
/// are partitioned by the optimized matrix multiplication.
struct MacroKernelParamsTy {
  int Mc; ///< Macro-kernel parameter Mc.
  int Nc; ///< Macro-kernel parameter Nc.
  int Kc; ///< Macro-kernel parameter Kc.
};
169d123e983SMichael Kruse 
170d123e983SMichael Kruse /// Parameters of the matrix multiplication operands.
171d123e983SMichael Kruse ///
172d123e983SMichael Kruse /// Parameters, which describe access relations that represent operands of the
173d123e983SMichael Kruse /// matrix multiplication.
174d123e983SMichael Kruse struct MatMulInfoTy {
175d123e983SMichael Kruse   MemoryAccess *A = nullptr;
176d123e983SMichael Kruse   MemoryAccess *B = nullptr;
177d123e983SMichael Kruse   MemoryAccess *ReadFromC = nullptr;
178d123e983SMichael Kruse   MemoryAccess *WriteToC = nullptr;
179d123e983SMichael Kruse   int i = -1;
180d123e983SMichael Kruse   int j = -1;
181d123e983SMichael Kruse   int k = -1;
182d123e983SMichael Kruse };
183d123e983SMichael Kruse 
184b02c7e2bSRoman Gareev /// Parameters of the tensor contraction operands.
185b02c7e2bSRoman Gareev ///
186b02c7e2bSRoman Gareev /// A general d-dimensional tensor T ∈ R ^ Nu0 x ... x Nud−1 can be defined
187b02c7e2bSRoman Gareev /// as the set of scalar elements indexed by the set of indices u0 ... ud,
188b02c7e2bSRoman Gareev ///
189b02c7e2bSRoman Gareev /// T ≡ {Anu0...nud−1 ∈ R | (u0,...,ud−1) ∈ Nu0 x ... x Nud−1}.
190b02c7e2bSRoman Gareev ///
191b02c7e2bSRoman Gareev /// Let A, B, and C be dA, dB, and dC-dimensional tensors, respectively.
192b02c7e2bSRoman Gareev /// Let the free and the contracted indices of the tensor A be grouped into
193b02c7e2bSRoman Gareev /// two bundles I = i0...ir−1 and P = p0...pt−1, respectively. Similarly,
194b02c7e2bSRoman Gareev /// the free and the contracted indices of B are grouped into bundles
195b02c7e2bSRoman Gareev /// J = j0..js−1 and P and the free indices of C are grouped into
196b02c7e2bSRoman Gareev /// bundles I and J.
197b02c7e2bSRoman Gareev ///
198b02c7e2bSRoman Gareev /// Tensor contraction (TC) of tensors A, B into tensor C can be represented as
199b02c7e2bSRoman Gareev /// C(shuffle(I,J))=∑α·A(shuffle(I,P))·B(shuffle(P,J))+β·C(shuffle(I,J)),
200b02c7e2bSRoman Gareev /// where ∑ is a summation over all contracted indices of P,
201b02c7e2bSRoman Gareev /// α, β ∈ R, Npi is the length of the tensor dimension that corresponds
202b02c7e2bSRoman Gareev /// to the index pi, A(shuffle(I, P)), B(shuffle(P, J)), C(shuffle(I, J)) are
203b02c7e2bSRoman Gareev /// accesses to tensors A, B, C, respectively,
204b02c7e2bSRoman Gareev /// shuffle(I, J), shuffle(I, P), and shuffle(P, J) are permutations of
205b02c7e2bSRoman Gareev /// the enclosed indices.
206b02c7e2bSRoman Gareev ///
207b02c7e2bSRoman Gareev /// Multiplication of C(shuffle(I,J)) by β can be moved into a different SCoP
208b02c7e2bSRoman Gareev /// statement by loop distribution, which is done by the isl scheduler.
209b02c7e2bSRoman Gareev //  If β is not equal to one, the optimization of TC of Polly requires
210b02c7e2bSRoman Gareev /// such a transformation.
211b02c7e2bSRoman Gareev ///
212b02c7e2bSRoman Gareev /// TCInfoTy contains parameters, which describe access relations that represent
213b02c7e2bSRoman Gareev /// operands of the tensor contraction.
214b02c7e2bSRoman Gareev struct TCInfoTy {
215b02c7e2bSRoman Gareev   /// @{
216b02c7e2bSRoman Gareev   /// Memory accesses that represent reading from tensors, which are operands of
217b02c7e2bSRoman Gareev   /// the tensor contraction.
218b02c7e2bSRoman Gareev   MemoryAccess *A = nullptr;
219b02c7e2bSRoman Gareev   MemoryAccess *B = nullptr;
220b02c7e2bSRoman Gareev   /// @}
221b02c7e2bSRoman Gareev 
222b02c7e2bSRoman Gareev   /// @{
223b02c7e2bSRoman Gareev   /// Memory accesses that represent reading from and writing into the tensor,
224b02c7e2bSRoman Gareev   /// which contains the result of the tensor contraction.
225b02c7e2bSRoman Gareev   MemoryAccess *ReadFromC = nullptr;
226b02c7e2bSRoman Gareev   MemoryAccess *WriteToC = nullptr;
227b02c7e2bSRoman Gareev   /// @}
228b02c7e2bSRoman Gareev 
229b02c7e2bSRoman Gareev   /// @{
230b02c7e2bSRoman Gareev   /// Input dimensions of the schedule space, which represent free
231b02c7e2bSRoman Gareev   /// indices of tensors.
232b02c7e2bSRoman Gareev   SmallDenseSet<int> I;
233b02c7e2bSRoman Gareev   SmallDenseSet<int> J;
234b02c7e2bSRoman Gareev   /// @}
235b02c7e2bSRoman Gareev 
236b02c7e2bSRoman Gareev   /// Input dimension of the schedule space, which represents contracted
237b02c7e2bSRoman Gareev   /// indices of tensors.
238b02c7e2bSRoman Gareev   SmallDenseSet<int> P;
239b02c7e2bSRoman Gareev 
240b02c7e2bSRoman Gareev   /// @{
241b02c7e2bSRoman Gareev   /// Sizes of tensor dimensions for corresponding input dimensions of
242b02c7e2bSRoman Gareev   /// the schedule space. The size of the tensor dimension can be larger than
243b02c7e2bSRoman Gareev   /// the size of the corresponding input dimension of the schedule space.
244b02c7e2bSRoman Gareev   /// This does not correspond to a tensor contraction. However, such a pattern
245b02c7e2bSRoman Gareev   /// will be optimized by the transformation.
246b02c7e2bSRoman Gareev   SmallVector<int> DimensionSizes;
247b02c7e2bSRoman Gareev   SmallVector<int> ADimensions;
248b02c7e2bSRoman Gareev   SmallVector<int> BDimensions;
249b02c7e2bSRoman Gareev   SmallVector<int> CDimensions;
250b02c7e2bSRoman Gareev   /// @}
251b02c7e2bSRoman Gareev 
252b02c7e2bSRoman Gareev   /// @{
253b02c7e2bSRoman Gareev   /// Permutations of indices of I, J, and P, which describe operands of
254b02c7e2bSRoman Gareev   /// the tensor contraction and its result.
255b02c7e2bSRoman Gareev   SmallVector<int> OrderedI;
256b02c7e2bSRoman Gareev   SmallVector<int> OrderedJ;
257b02c7e2bSRoman Gareev   SmallVector<int> OrderedP;
258b02c7e2bSRoman Gareev   /// @}
259b02c7e2bSRoman Gareev };
260b02c7e2bSRoman Gareev 
261d123e983SMichael Kruse /// Create an isl::union_set, which describes the option of the form
262d123e983SMichael Kruse /// [isolate[] -> unroll[x]].
263d123e983SMichael Kruse ///
264d123e983SMichael Kruse /// @param Ctx An isl::ctx, which is used to create the isl::union_set.
265d123e983SMichael Kruse static isl::union_set getUnrollIsolatedSetOptions(isl::ctx Ctx) {
266d123e983SMichael Kruse   isl::space Space = isl::space(Ctx, 0, 0, 1);
267d123e983SMichael Kruse   isl::map UnrollIsolatedSetOption = isl::map::universe(Space);
268d123e983SMichael Kruse   isl::id DimInId = isl::id::alloc(Ctx, "isolate", nullptr);
269d123e983SMichael Kruse   isl::id DimOutId = isl::id::alloc(Ctx, "unroll", nullptr);
270d123e983SMichael Kruse   UnrollIsolatedSetOption =
271d123e983SMichael Kruse       UnrollIsolatedSetOption.set_tuple_id(isl::dim::in, DimInId);
272d123e983SMichael Kruse   UnrollIsolatedSetOption =
273d123e983SMichael Kruse       UnrollIsolatedSetOption.set_tuple_id(isl::dim::out, DimOutId);
274d123e983SMichael Kruse   return UnrollIsolatedSetOption.wrap();
275d123e983SMichael Kruse }
276d123e983SMichael Kruse 
277d123e983SMichael Kruse /// Permute the two dimensions of the isl map.
278d123e983SMichael Kruse ///
279d123e983SMichael Kruse /// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that
280d123e983SMichael Kruse /// have type @p DimType.
281d123e983SMichael Kruse ///
282d123e983SMichael Kruse /// @param Map     The isl map to be modified.
283d123e983SMichael Kruse /// @param DimType The type of the dimensions.
284d123e983SMichael Kruse /// @param DstPos  The first dimension.
285d123e983SMichael Kruse /// @param SrcPos  The second dimension.
286d123e983SMichael Kruse /// @return        The modified map.
287d123e983SMichael Kruse static isl::map permuteDimensions(isl::map Map, isl::dim DimType,
288d123e983SMichael Kruse                                   unsigned DstPos, unsigned SrcPos) {
28944596fe6SRiccardo Mori   assert(DstPos < unsignedFromIslSize(Map.dim(DimType)) &&
29044596fe6SRiccardo Mori          SrcPos < unsignedFromIslSize(Map.dim(DimType)));
291d123e983SMichael Kruse   if (DstPos == SrcPos)
292d123e983SMichael Kruse     return Map;
293d123e983SMichael Kruse   isl::id DimId;
294d123e983SMichael Kruse   if (Map.has_tuple_id(DimType))
295d123e983SMichael Kruse     DimId = Map.get_tuple_id(DimType);
296d123e983SMichael Kruse   auto FreeDim = DimType == isl::dim::in ? isl::dim::out : isl::dim::in;
297d123e983SMichael Kruse   isl::id FreeDimId;
298d123e983SMichael Kruse   if (Map.has_tuple_id(FreeDim))
299d123e983SMichael Kruse     FreeDimId = Map.get_tuple_id(FreeDim);
300d123e983SMichael Kruse   auto MaxDim = std::max(DstPos, SrcPos);
301d123e983SMichael Kruse   auto MinDim = std::min(DstPos, SrcPos);
302d123e983SMichael Kruse   Map = Map.move_dims(FreeDim, 0, DimType, MaxDim, 1);
303d123e983SMichael Kruse   Map = Map.move_dims(FreeDim, 0, DimType, MinDim, 1);
304d123e983SMichael Kruse   Map = Map.move_dims(DimType, MinDim, FreeDim, 1, 1);
305d123e983SMichael Kruse   Map = Map.move_dims(DimType, MaxDim, FreeDim, 0, 1);
3067c7978a1Spatacca   if (!DimId.is_null())
307d123e983SMichael Kruse     Map = Map.set_tuple_id(DimType, DimId);
3087c7978a1Spatacca   if (!FreeDimId.is_null())
309d123e983SMichael Kruse     Map = Map.set_tuple_id(FreeDim, FreeDimId);
310d123e983SMichael Kruse   return Map;
311d123e983SMichael Kruse }
312d123e983SMichael Kruse 
313d123e983SMichael Kruse /// Check the form of the access relation.
314d123e983SMichael Kruse ///
315d123e983SMichael Kruse /// Check that the access relation @p AccMap has the form M[i][j], where i
316d123e983SMichael Kruse /// is a @p FirstPos and j is a @p SecondPos.
317d123e983SMichael Kruse ///
318d123e983SMichael Kruse /// @param AccMap    The access relation to be checked.
319d123e983SMichael Kruse /// @param FirstPos  The index of the input dimension that is mapped to
320d123e983SMichael Kruse ///                  the first output dimension.
321d123e983SMichael Kruse /// @param SecondPos The index of the input dimension that is mapped to the
322d123e983SMichael Kruse ///                  second output dimension.
323d123e983SMichael Kruse /// @return          True in case @p AccMap has the expected form and false,
324d123e983SMichael Kruse ///                  otherwise.
325d123e983SMichael Kruse static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
326d123e983SMichael Kruse                                int &SecondPos) {
327d123e983SMichael Kruse   isl::space Space = AccMap.get_space();
328d123e983SMichael Kruse   isl::map Universe = isl::map::universe(Space);
329d123e983SMichael Kruse 
33044596fe6SRiccardo Mori   if (unsignedFromIslSize(Space.dim(isl::dim::out)) != 2)
331d123e983SMichael Kruse     return false;
332d123e983SMichael Kruse 
333d123e983SMichael Kruse   // MatMul has the form:
334d123e983SMichael Kruse   // for (i = 0; i < N; i++)
335d123e983SMichael Kruse   //   for (j = 0; j < M; j++)
336d123e983SMichael Kruse   //     for (k = 0; k < P; k++)
337d123e983SMichael Kruse   //       C[i, j] += A[i, k] * B[k, j]
338d123e983SMichael Kruse   //
339d123e983SMichael Kruse   // Permutation of three outer loops: 3! = 6 possibilities.
340d123e983SMichael Kruse   int FirstDims[] = {0, 0, 1, 1, 2, 2};
341d123e983SMichael Kruse   int SecondDims[] = {1, 2, 2, 0, 0, 1};
342d123e983SMichael Kruse   for (int i = 0; i < 6; i += 1) {
343d123e983SMichael Kruse     auto PossibleMatMul =
344d123e983SMichael Kruse         Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0)
345d123e983SMichael Kruse             .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1);
346d123e983SMichael Kruse 
347d123e983SMichael Kruse     AccMap = AccMap.intersect_domain(Domain);
348d123e983SMichael Kruse     PossibleMatMul = PossibleMatMul.intersect_domain(Domain);
349d123e983SMichael Kruse 
350d123e983SMichael Kruse     // If AccMap spans entire domain (Non-partial write),
351d123e983SMichael Kruse     // compute FirstPos and SecondPos.
352d123e983SMichael Kruse     // If AccMap != PossibleMatMul here (the two maps have been gisted at
353d123e983SMichael Kruse     // this point), it means that the writes are not complete, or in other
354d123e983SMichael Kruse     // words, it is a Partial write and Partial writes must be rejected.
355d123e983SMichael Kruse     if (AccMap.is_equal(PossibleMatMul)) {
356d123e983SMichael Kruse       if (FirstPos != -1 && FirstPos != FirstDims[i])
357d123e983SMichael Kruse         continue;
358d123e983SMichael Kruse       FirstPos = FirstDims[i];
359d123e983SMichael Kruse       if (SecondPos != -1 && SecondPos != SecondDims[i])
360d123e983SMichael Kruse         continue;
361d123e983SMichael Kruse       SecondPos = SecondDims[i];
362d123e983SMichael Kruse       return true;
363d123e983SMichael Kruse     }
364d123e983SMichael Kruse   }
365d123e983SMichael Kruse 
366d123e983SMichael Kruse   return false;
367d123e983SMichael Kruse }
368d123e983SMichael Kruse 
369d123e983SMichael Kruse /// Does the memory access represent a non-scalar operand of the matrix
370d123e983SMichael Kruse /// multiplication.
371d123e983SMichael Kruse ///
372d123e983SMichael Kruse /// Check that the memory access @p MemAccess is the read access to a non-scalar
373d123e983SMichael Kruse /// operand of the matrix multiplication or its result.
374d123e983SMichael Kruse ///
375d123e983SMichael Kruse /// @param MemAccess The memory access to be checked.
376d123e983SMichael Kruse /// @param MMI       Parameters of the matrix multiplication operands.
377d123e983SMichael Kruse /// @return          True in case the memory access represents the read access
378d123e983SMichael Kruse ///                  to a non-scalar operand of the matrix multiplication and
379d123e983SMichael Kruse ///                  false, otherwise.
380d123e983SMichael Kruse static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess,
381d123e983SMichael Kruse                                         MatMulInfoTy &MMI) {
382d123e983SMichael Kruse   if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
383d123e983SMichael Kruse     return false;
384d123e983SMichael Kruse   auto AccMap = MemAccess->getLatestAccessRelation();
385d123e983SMichael Kruse   isl::set StmtDomain = MemAccess->getStatement()->getDomain();
386d123e983SMichael Kruse   if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) {
387d123e983SMichael Kruse     MMI.ReadFromC = MemAccess;
388d123e983SMichael Kruse     return true;
389d123e983SMichael Kruse   }
390d123e983SMichael Kruse   if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) {
391d123e983SMichael Kruse     MMI.A = MemAccess;
392d123e983SMichael Kruse     return true;
393d123e983SMichael Kruse   }
394d123e983SMichael Kruse   if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) {
395d123e983SMichael Kruse     MMI.B = MemAccess;
396d123e983SMichael Kruse     return true;
397d123e983SMichael Kruse   }
398d123e983SMichael Kruse   return false;
399d123e983SMichael Kruse }
400d123e983SMichael Kruse 
401d123e983SMichael Kruse /// Check accesses to operands of the matrix multiplication.
402d123e983SMichael Kruse ///
403d123e983SMichael Kruse /// Check that accesses of the SCoP statement, which corresponds to
404d123e983SMichael Kruse /// the partial schedule @p PartialSchedule, are scalar in terms of loops
405d123e983SMichael Kruse /// containing the matrix multiplication, in case they do not represent
406d123e983SMichael Kruse /// accesses to the non-scalar operands of the matrix multiplication or
407d123e983SMichael Kruse /// its result.
408d123e983SMichael Kruse ///
409d123e983SMichael Kruse /// @param  PartialSchedule The partial schedule of the SCoP statement.
410d123e983SMichael Kruse /// @param  MMI             Parameters of the matrix multiplication operands.
411d123e983SMichael Kruse /// @return                 True in case the corresponding SCoP statement
412d123e983SMichael Kruse ///                         represents matrix multiplication and false,
413d123e983SMichael Kruse ///                         otherwise.
414d123e983SMichael Kruse static bool containsOnlyMatrMultAcc(isl::map PartialSchedule,
415d123e983SMichael Kruse                                     MatMulInfoTy &MMI) {
416d123e983SMichael Kruse   auto InputDimId = PartialSchedule.get_tuple_id(isl::dim::in);
417d123e983SMichael Kruse   auto *Stmt = static_cast<ScopStmt *>(InputDimId.get_user());
41844596fe6SRiccardo Mori   unsigned OutDimNum = unsignedFromIslSize(PartialSchedule.range_tuple_dim());
419d123e983SMichael Kruse   assert(OutDimNum > 2 && "In case of the matrix multiplication the loop nest "
420d123e983SMichael Kruse                           "and, consequently, the corresponding scheduling "
421d123e983SMichael Kruse                           "functions have at least three dimensions.");
422d123e983SMichael Kruse   auto MapI =
423d123e983SMichael Kruse       permuteDimensions(PartialSchedule, isl::dim::out, MMI.i, OutDimNum - 1);
424d123e983SMichael Kruse   auto MapJ =
425d123e983SMichael Kruse       permuteDimensions(PartialSchedule, isl::dim::out, MMI.j, OutDimNum - 1);
426d123e983SMichael Kruse   auto MapK =
427d123e983SMichael Kruse       permuteDimensions(PartialSchedule, isl::dim::out, MMI.k, OutDimNum - 1);
428d123e983SMichael Kruse 
429d123e983SMichael Kruse   auto Accesses = getAccessesInOrder(*Stmt);
430d123e983SMichael Kruse   for (auto *MemA = Accesses.begin(); MemA != Accesses.end() - 1; MemA++) {
431d123e983SMichael Kruse     auto *MemAccessPtr = *MemA;
432d123e983SMichael Kruse     if (MemAccessPtr->isLatestArrayKind() && MemAccessPtr != MMI.WriteToC &&
433d123e983SMichael Kruse         !isMatMulNonScalarReadAccess(MemAccessPtr, MMI) &&
434113fa82cSRoman Gareev         !(MemAccessPtr->isStrideZero(MapI) &&
435113fa82cSRoman Gareev           MemAccessPtr->isStrideZero(MapJ) && MemAccessPtr->isStrideZero(MapK)))
436d123e983SMichael Kruse       return false;
437d123e983SMichael Kruse   }
438d123e983SMichael Kruse   return true;
439d123e983SMichael Kruse }
440d123e983SMichael Kruse 
441d123e983SMichael Kruse /// Check for dependencies corresponding to the matrix multiplication.
442d123e983SMichael Kruse ///
443d123e983SMichael Kruse /// Check that there is only true dependence of the form
444d123e983SMichael Kruse /// S(..., k, ...) -> S(..., k + 1, …), where S is the SCoP statement
445d123e983SMichael Kruse /// represented by @p Schedule and k is @p Pos. Such a dependence corresponds
446d123e983SMichael Kruse /// to the dependency produced by the matrix multiplication.
447d123e983SMichael Kruse ///
448d123e983SMichael Kruse /// @param  Schedule The schedule of the SCoP statement.
449d123e983SMichael Kruse /// @param  D The SCoP dependencies.
450d123e983SMichael Kruse /// @param  Pos The parameter to describe an acceptable true dependence.
451d123e983SMichael Kruse ///             In case it has a negative value, try to determine its
452d123e983SMichael Kruse ///             acceptable value.
453d123e983SMichael Kruse /// @return True in case dependencies correspond to the matrix multiplication
454d123e983SMichael Kruse ///         and false, otherwise.
455d123e983SMichael Kruse static bool containsOnlyMatMulDep(isl::map Schedule, const Dependences *D,
456d123e983SMichael Kruse                                   int &Pos) {
457d123e983SMichael Kruse   isl::union_map Dep = D->getDependences(Dependences::TYPE_RAW);
458d123e983SMichael Kruse   isl::union_map Red = D->getDependences(Dependences::TYPE_RED);
4597c7978a1Spatacca   if (!Red.is_null())
460d123e983SMichael Kruse     Dep = Dep.unite(Red);
461d123e983SMichael Kruse   auto DomainSpace = Schedule.get_space().domain();
462d123e983SMichael Kruse   auto Space = DomainSpace.map_from_domain_and_range(DomainSpace);
463d123e983SMichael Kruse   auto Deltas = Dep.extract_map(Space).deltas();
46444596fe6SRiccardo Mori   int DeltasDimNum = unsignedFromIslSize(Deltas.dim(isl::dim::set));
465d123e983SMichael Kruse   for (int i = 0; i < DeltasDimNum; i++) {
466d123e983SMichael Kruse     auto Val = Deltas.plain_get_val_if_fixed(isl::dim::set, i);
467d123e983SMichael Kruse     Pos = Pos < 0 && Val.is_one() ? i : Pos;
468d123e983SMichael Kruse     if (Val.is_nan() || !(Val.is_zero() || (i == Pos && Val.is_one())))
469d123e983SMichael Kruse       return false;
470d123e983SMichael Kruse   }
471d123e983SMichael Kruse   if (DeltasDimNum == 0 || Pos < 0)
472d123e983SMichael Kruse     return false;
473d123e983SMichael Kruse   return true;
474d123e983SMichael Kruse }
475d123e983SMichael Kruse 
476d123e983SMichael Kruse /// Check if the SCoP statement could probably be optimized with analytical
477d123e983SMichael Kruse /// modeling.
478d123e983SMichael Kruse ///
479d123e983SMichael Kruse /// containsMatrMult tries to determine whether the following conditions
480d123e983SMichael Kruse /// are true:
481d123e983SMichael Kruse /// 1. The last memory access modeling an array, MA1, represents writing to
482d123e983SMichael Kruse ///    memory and has the form S(..., i1, ..., i2, ...) -> M(i1, i2) or
483d123e983SMichael Kruse ///    S(..., i2, ..., i1, ...) -> M(i1, i2), where S is the SCoP statement
484d123e983SMichael Kruse ///    under consideration.
485d123e983SMichael Kruse /// 2. There is only one loop-carried true dependency, and it has the
486d123e983SMichael Kruse ///    form S(..., i3, ...) -> S(..., i3 + 1, ...), and there are no
487d123e983SMichael Kruse ///    loop-carried or anti dependencies.
488d123e983SMichael Kruse /// 3. SCoP contains three access relations, MA2, MA3, and MA4 that represent
489d123e983SMichael Kruse ///    reading from memory and have the form S(..., i3, ...) -> M(i1, i3),
490d123e983SMichael Kruse ///    S(..., i3, ...) -> M(i3, i2), S(...) -> M(i1, i2), respectively,
491d123e983SMichael Kruse ///    and all memory accesses of the SCoP that are different from MA1, MA2,
492d123e983SMichael Kruse ///    MA3, and MA4 have stride 0, if the innermost loop is exchanged with any
493d123e983SMichael Kruse ///    of loops i1, i2 and i3.
494d123e983SMichael Kruse ///
495d123e983SMichael Kruse /// @param PartialSchedule The PartialSchedule that contains a SCoP statement
496d123e983SMichael Kruse ///        to check.
497d123e983SMichael Kruse /// @D     The SCoP dependencies.
498d123e983SMichael Kruse /// @MMI   Parameters of the matrix multiplication operands.
499d123e983SMichael Kruse static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D,
500d123e983SMichael Kruse                              MatMulInfoTy &MMI) {
501d123e983SMichael Kruse   auto InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in);
502d123e983SMichael Kruse   auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
503d123e983SMichael Kruse   if (Stmt->size() <= 1)
504d123e983SMichael Kruse     return false;
505d123e983SMichael Kruse 
506d123e983SMichael Kruse   auto Accesses = getAccessesInOrder(*Stmt);
507d123e983SMichael Kruse   for (auto *MemA = Accesses.end() - 1; MemA != Accesses.begin(); MemA--) {
508d123e983SMichael Kruse     auto *MemAccessPtr = *MemA;
509d123e983SMichael Kruse     if (!MemAccessPtr->isLatestArrayKind())
510d123e983SMichael Kruse       continue;
511d123e983SMichael Kruse     if (!MemAccessPtr->isWrite())
512d123e983SMichael Kruse       return false;
513d123e983SMichael Kruse     auto AccMap = MemAccessPtr->getLatestAccessRelation();
514d123e983SMichael Kruse     if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
515d123e983SMichael Kruse       return false;
516d123e983SMichael Kruse     MMI.WriteToC = MemAccessPtr;
517d123e983SMichael Kruse     break;
518d123e983SMichael Kruse   }
519d123e983SMichael Kruse 
520d123e983SMichael Kruse   if (!containsOnlyMatMulDep(PartialSchedule, D, MMI.k))
521d123e983SMichael Kruse     return false;
522d123e983SMichael Kruse 
523d123e983SMichael Kruse   if (!MMI.WriteToC || !containsOnlyMatrMultAcc(PartialSchedule, MMI))
524d123e983SMichael Kruse     return false;
525d123e983SMichael Kruse 
526d123e983SMichael Kruse   if (!MMI.A || !MMI.B || !MMI.ReadFromC)
527d123e983SMichael Kruse     return false;
528d123e983SMichael Kruse   return true;
529d123e983SMichael Kruse }
530d123e983SMichael Kruse 
531d123e983SMichael Kruse /// Permute two dimensions of the band node.
532d123e983SMichael Kruse ///
533d123e983SMichael Kruse /// Permute FirstDim and SecondDim dimensions of the Node.
534d123e983SMichael Kruse ///
535d123e983SMichael Kruse /// @param Node The band node to be modified.
536d123e983SMichael Kruse /// @param FirstDim The first dimension to be permuted.
537d123e983SMichael Kruse /// @param SecondDim The second dimension to be permuted.
538d123e983SMichael Kruse static isl::schedule_node permuteBandNodeDimensions(isl::schedule_node Node,
539d123e983SMichael Kruse                                                     unsigned FirstDim,
540d123e983SMichael Kruse                                                     unsigned SecondDim) {
541d123e983SMichael Kruse   assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band &&
542d123e983SMichael Kruse          (unsigned)isl_schedule_node_band_n_member(Node.get()) >
543d123e983SMichael Kruse              std::max(FirstDim, SecondDim));
544d123e983SMichael Kruse   auto PartialSchedule =
545d123e983SMichael Kruse       isl::manage(isl_schedule_node_band_get_partial_schedule(Node.get()));
546d3fdbda6SRiccardo Mori   auto PartialScheduleFirstDim = PartialSchedule.at(FirstDim);
547d3fdbda6SRiccardo Mori   auto PartialScheduleSecondDim = PartialSchedule.at(SecondDim);
548d123e983SMichael Kruse   PartialSchedule =
549d123e983SMichael Kruse       PartialSchedule.set_union_pw_aff(SecondDim, PartialScheduleFirstDim);
550d123e983SMichael Kruse   PartialSchedule =
551d123e983SMichael Kruse       PartialSchedule.set_union_pw_aff(FirstDim, PartialScheduleSecondDim);
552d123e983SMichael Kruse   Node = isl::manage(isl_schedule_node_delete(Node.release()));
553d123e983SMichael Kruse   return Node.insert_partial_schedule(PartialSchedule);
554d123e983SMichael Kruse }
555d123e983SMichael Kruse 
556d123e983SMichael Kruse static isl::schedule_node
557d123e983SMichael Kruse createMicroKernel(isl::schedule_node Node,
558d123e983SMichael Kruse                   MicroKernelParamsTy MicroKernelParams) {
559d123e983SMichael Kruse   Node = applyRegisterTiling(Node, {MicroKernelParams.Mr, MicroKernelParams.Nr},
560d123e983SMichael Kruse                              1);
561d123e983SMichael Kruse   Node = Node.parent().parent();
562d123e983SMichael Kruse   return permuteBandNodeDimensions(Node, 0, 1).child(0).child(0);
563d123e983SMichael Kruse }
564d123e983SMichael Kruse 
565d123e983SMichael Kruse /// Create the BLIS macro-kernel.
566d123e983SMichael Kruse ///
567d123e983SMichael Kruse /// We create the BLIS macro-kernel by applying a combination of tiling
568d123e983SMichael Kruse /// of dimensions of the band node and interchanging of two innermost
569678e3ee1SFangrui Song /// modified dimensions. The values of MacroKernelParams's fields are used
570d123e983SMichael Kruse /// as tile sizes.
571d123e983SMichael Kruse ///
572d123e983SMichael Kruse /// @param Node The schedule node to be modified.
573d123e983SMichael Kruse /// @param MacroKernelParams Parameters of the macro kernel
574d123e983SMichael Kruse ///                          to be used as tile sizes.
575d123e983SMichael Kruse static isl::schedule_node
576d123e983SMichael Kruse createMacroKernel(isl::schedule_node Node,
577d123e983SMichael Kruse                   MacroKernelParamsTy MacroKernelParams) {
578d123e983SMichael Kruse   assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
579d123e983SMichael Kruse   if (MacroKernelParams.Mc == 1 && MacroKernelParams.Nc == 1 &&
580d123e983SMichael Kruse       MacroKernelParams.Kc == 1)
581d123e983SMichael Kruse     return Node;
582d123e983SMichael Kruse   int DimOutNum = isl_schedule_node_band_n_member(Node.get());
583d123e983SMichael Kruse   std::vector<int> TileSizes(DimOutNum, 1);
584d123e983SMichael Kruse   TileSizes[DimOutNum - 3] = MacroKernelParams.Mc;
585d123e983SMichael Kruse   TileSizes[DimOutNum - 2] = MacroKernelParams.Nc;
586d123e983SMichael Kruse   TileSizes[DimOutNum - 1] = MacroKernelParams.Kc;
587d123e983SMichael Kruse   Node = tileNode(Node, "1st level tiling", TileSizes, 1);
588d123e983SMichael Kruse   Node = Node.parent().parent();
589d123e983SMichael Kruse   Node = permuteBandNodeDimensions(Node, DimOutNum - 2, DimOutNum - 1);
590d123e983SMichael Kruse   Node = permuteBandNodeDimensions(Node, DimOutNum - 3, DimOutNum - 1);
591d123e983SMichael Kruse 
592d123e983SMichael Kruse   return Node.child(0).child(0);
593d123e983SMichael Kruse }
594d123e983SMichael Kruse 
595d123e983SMichael Kruse /// Get the size of the widest type of the matrix multiplication operands
596d123e983SMichael Kruse /// in bytes, including alignment padding.
597d123e983SMichael Kruse ///
598d123e983SMichael Kruse /// @param MMI Parameters of the matrix multiplication operands.
599d123e983SMichael Kruse /// @return The size of the widest type of the matrix multiplication operands
600d123e983SMichael Kruse ///         in bytes, including alignment padding.
6014c6ae8ebSKarthika Devi C static uint64_t getMatMulAlignTypeSize(const MatMulInfoTy &MMI) {
602d123e983SMichael Kruse   auto *S = MMI.A->getStatement()->getParent();
603d123e983SMichael Kruse   auto &DL = S->getFunction().getParent()->getDataLayout();
604d123e983SMichael Kruse   auto ElementSizeA = DL.getTypeAllocSize(MMI.A->getElementType());
605d123e983SMichael Kruse   auto ElementSizeB = DL.getTypeAllocSize(MMI.B->getElementType());
606d123e983SMichael Kruse   auto ElementSizeC = DL.getTypeAllocSize(MMI.WriteToC->getElementType());
607d123e983SMichael Kruse   return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
608d123e983SMichael Kruse }
609d123e983SMichael Kruse 
610d123e983SMichael Kruse /// Get the size of the widest type of the matrix multiplication operands
611d123e983SMichael Kruse /// in bits.
612d123e983SMichael Kruse ///
613d123e983SMichael Kruse /// @param MMI Parameters of the matrix multiplication operands.
614d123e983SMichael Kruse /// @return The size of the widest type of the matrix multiplication operands
615d123e983SMichael Kruse ///         in bits.
6164c6ae8ebSKarthika Devi C static uint64_t getMatMulTypeSize(const MatMulInfoTy &MMI) {
617d123e983SMichael Kruse   auto *S = MMI.A->getStatement()->getParent();
618d123e983SMichael Kruse   auto &DL = S->getFunction().getParent()->getDataLayout();
619d123e983SMichael Kruse   auto ElementSizeA = DL.getTypeSizeInBits(MMI.A->getElementType());
620d123e983SMichael Kruse   auto ElementSizeB = DL.getTypeSizeInBits(MMI.B->getElementType());
621d123e983SMichael Kruse   auto ElementSizeC = DL.getTypeSizeInBits(MMI.WriteToC->getElementType());
622d123e983SMichael Kruse   return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
623d123e983SMichael Kruse }
624d123e983SMichael Kruse 
625d123e983SMichael Kruse /// Get parameters of the BLIS micro kernel.
626d123e983SMichael Kruse ///
627d123e983SMichael Kruse /// We choose the Mr and Nr parameters of the micro kernel to be large enough
628d123e983SMichael Kruse /// such that no stalls caused by the combination of latencies and dependencies
629d123e983SMichael Kruse /// are introduced during the updates of the resulting matrix of the matrix
630d123e983SMichael Kruse /// multiplication. However, they should also be as small as possible to
631d123e983SMichael Kruse /// release more registers for entries of multiplied matrices.
632d123e983SMichael Kruse ///
633d123e983SMichael Kruse /// @param TTI Target Transform Info.
634d123e983SMichael Kruse /// @param MMI Parameters of the matrix multiplication operands.
635d123e983SMichael Kruse /// @return The structure of type MicroKernelParamsTy.
636d123e983SMichael Kruse /// @see MicroKernelParamsTy
637bd93df93SMichael Kruse static MicroKernelParamsTy getMicroKernelParams(const TargetTransformInfo *TTI,
6384c6ae8ebSKarthika Devi C                                                 const MatMulInfoTy &MMI) {
639d123e983SMichael Kruse   assert(TTI && "The target transform info should be provided.");
640d123e983SMichael Kruse 
641d123e983SMichael Kruse   // Nvec - Number of double-precision floating-point numbers that can be hold
642d123e983SMichael Kruse   // by a vector register. Use 2 by default.
643d123e983SMichael Kruse   long RegisterBitwidth = VectorRegisterBitwidth;
644d123e983SMichael Kruse 
645d123e983SMichael Kruse   if (RegisterBitwidth == -1)
646d123e983SMichael Kruse     RegisterBitwidth =
647d123e983SMichael Kruse         TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector);
648d123e983SMichael Kruse   auto ElementSize = getMatMulTypeSize(MMI);
649d123e983SMichael Kruse   assert(ElementSize > 0 && "The element size of the matrix multiplication "
650d123e983SMichael Kruse                             "operands should be greater than zero.");
651d123e983SMichael Kruse   auto Nvec = RegisterBitwidth / ElementSize;
652d123e983SMichael Kruse   if (Nvec == 0)
653d123e983SMichael Kruse     Nvec = 2;
654d123e983SMichael Kruse   int Nr = ceil(sqrt((double)(Nvec * LatencyVectorFma * ThroughputVectorFma)) /
655d123e983SMichael Kruse                 Nvec) *
656d123e983SMichael Kruse            Nvec;
657d123e983SMichael Kruse   int Mr = ceil((double)(Nvec * LatencyVectorFma * ThroughputVectorFma / Nr));
658d123e983SMichael Kruse   return {Mr, Nr};
659d123e983SMichael Kruse }
660d123e983SMichael Kruse 
661d123e983SMichael Kruse /// Determine parameters of the target cache.
662d123e983SMichael Kruse ///
663d123e983SMichael Kruse /// @param TTI Target Transform Info.
664d123e983SMichael Kruse static void getTargetCacheParameters(const llvm::TargetTransformInfo *TTI) {
665d123e983SMichael Kruse   auto L1DCache = llvm::TargetTransformInfo::CacheLevel::L1D;
666d123e983SMichael Kruse   auto L2DCache = llvm::TargetTransformInfo::CacheLevel::L2D;
667d123e983SMichael Kruse   if (FirstCacheLevelSize == -1) {
66894460f51SKazu Hirata     if (TTI->getCacheSize(L1DCache))
6695cff5142SKazu Hirata       FirstCacheLevelSize = TTI->getCacheSize(L1DCache).value();
670d123e983SMichael Kruse     else
671d123e983SMichael Kruse       FirstCacheLevelSize = static_cast<int>(FirstCacheLevelDefaultSize);
672d123e983SMichael Kruse   }
673d123e983SMichael Kruse   if (SecondCacheLevelSize == -1) {
67494460f51SKazu Hirata     if (TTI->getCacheSize(L2DCache))
6755cff5142SKazu Hirata       SecondCacheLevelSize = TTI->getCacheSize(L2DCache).value();
676d123e983SMichael Kruse     else
677d123e983SMichael Kruse       SecondCacheLevelSize = static_cast<int>(SecondCacheLevelDefaultSize);
678d123e983SMichael Kruse   }
679d123e983SMichael Kruse   if (FirstCacheLevelAssociativity == -1) {
68094460f51SKazu Hirata     if (TTI->getCacheAssociativity(L1DCache))
681d123e983SMichael Kruse       FirstCacheLevelAssociativity =
6825cff5142SKazu Hirata           TTI->getCacheAssociativity(L1DCache).value();
683d123e983SMichael Kruse     else
684d123e983SMichael Kruse       FirstCacheLevelAssociativity =
685d123e983SMichael Kruse           static_cast<int>(FirstCacheLevelDefaultAssociativity);
686d123e983SMichael Kruse   }
687d123e983SMichael Kruse   if (SecondCacheLevelAssociativity == -1) {
68894460f51SKazu Hirata     if (TTI->getCacheAssociativity(L2DCache))
689d123e983SMichael Kruse       SecondCacheLevelAssociativity =
6905cff5142SKazu Hirata           TTI->getCacheAssociativity(L2DCache).value();
691d123e983SMichael Kruse     else
692d123e983SMichael Kruse       SecondCacheLevelAssociativity =
693d123e983SMichael Kruse           static_cast<int>(SecondCacheLevelDefaultAssociativity);
694d123e983SMichael Kruse   }
695d123e983SMichael Kruse }
696d123e983SMichael Kruse 
697d123e983SMichael Kruse /// Get parameters of the BLIS macro kernel.
698d123e983SMichael Kruse ///
699d123e983SMichael Kruse /// During the computation of matrix multiplication, blocks of partitioned
700d123e983SMichael Kruse /// matrices are mapped to different layers of the memory hierarchy.
701d123e983SMichael Kruse /// To optimize data reuse, blocks should be ideally kept in cache between
702d123e983SMichael Kruse /// iterations. Since parameters of the macro kernel determine sizes of these
703d123e983SMichael Kruse /// blocks, there are upper and lower bounds on these parameters.
704d123e983SMichael Kruse ///
705d123e983SMichael Kruse /// @param TTI Target Transform Info.
706d123e983SMichael Kruse /// @param MicroKernelParams Parameters of the micro-kernel
707d123e983SMichael Kruse ///                          to be taken into account.
708d123e983SMichael Kruse /// @param MMI Parameters of the matrix multiplication operands.
709d123e983SMichael Kruse /// @return The structure of type MacroKernelParamsTy.
710d123e983SMichael Kruse /// @see MacroKernelParamsTy
711d123e983SMichael Kruse /// @see MicroKernelParamsTy
712bd93df93SMichael Kruse static MacroKernelParamsTy
713d123e983SMichael Kruse getMacroKernelParams(const llvm::TargetTransformInfo *TTI,
714d123e983SMichael Kruse                      const MicroKernelParamsTy &MicroKernelParams,
7154c6ae8ebSKarthika Devi C                      const MatMulInfoTy &MMI) {
716d123e983SMichael Kruse   getTargetCacheParameters(TTI);
717d123e983SMichael Kruse   // According to www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf,
718d123e983SMichael Kruse   // it requires information about the first two levels of a cache to determine
719d123e983SMichael Kruse   // all the parameters of a macro-kernel. It also checks that an associativity
720d123e983SMichael Kruse   // degree of a cache level is greater than two. Otherwise, another algorithm
721d123e983SMichael Kruse   // for determination of the parameters should be used.
722d123e983SMichael Kruse   if (!(MicroKernelParams.Mr > 0 && MicroKernelParams.Nr > 0 &&
723d123e983SMichael Kruse         FirstCacheLevelSize > 0 && SecondCacheLevelSize > 0 &&
724d123e983SMichael Kruse         FirstCacheLevelAssociativity > 2 && SecondCacheLevelAssociativity > 2))
725d123e983SMichael Kruse     return {1, 1, 1};
726d123e983SMichael Kruse   // The quotient should be greater than zero.
727d123e983SMichael Kruse   if (PollyPatternMatchingNcQuotient <= 0)
728d123e983SMichael Kruse     return {1, 1, 1};
729d123e983SMichael Kruse   int Car = floor(
730d123e983SMichael Kruse       (FirstCacheLevelAssociativity - 1) /
731d123e983SMichael Kruse       (1 + static_cast<double>(MicroKernelParams.Nr) / MicroKernelParams.Mr));
732d123e983SMichael Kruse 
733d123e983SMichael Kruse   // Car can be computed to be zero since it is floor to int.
734d123e983SMichael Kruse   // On Mac OS, division by 0 does not raise a signal. This causes negative
735d123e983SMichael Kruse   // tile sizes to be computed. Prevent division by Cac==0 by early returning
736d123e983SMichael Kruse   // if this happens.
737d123e983SMichael Kruse   if (Car == 0)
738d123e983SMichael Kruse     return {1, 1, 1};
739d123e983SMichael Kruse 
740d123e983SMichael Kruse   auto ElementSize = getMatMulAlignTypeSize(MMI);
741d123e983SMichael Kruse   assert(ElementSize > 0 && "The element size of the matrix multiplication "
742d123e983SMichael Kruse                             "operands should be greater than zero.");
743d123e983SMichael Kruse   int Kc = (Car * FirstCacheLevelSize) /
744d123e983SMichael Kruse            (MicroKernelParams.Mr * FirstCacheLevelAssociativity * ElementSize);
745d123e983SMichael Kruse   double Cac =
746d123e983SMichael Kruse       static_cast<double>(Kc * ElementSize * SecondCacheLevelAssociativity) /
747d123e983SMichael Kruse       SecondCacheLevelSize;
748d123e983SMichael Kruse   int Mc = floor((SecondCacheLevelAssociativity - 2) / Cac);
749d123e983SMichael Kruse   int Nc = PollyPatternMatchingNcQuotient * MicroKernelParams.Nr;
750d123e983SMichael Kruse 
751d123e983SMichael Kruse   assert(Mc > 0 && Nc > 0 && Kc > 0 &&
752d123e983SMichael Kruse          "Matrix block sizes should be  greater than zero");
753d123e983SMichael Kruse   return {Mc, Nc, Kc};
754d123e983SMichael Kruse }
755d123e983SMichael Kruse 
756d123e983SMichael Kruse /// Create an access relation that is specific to
757d123e983SMichael Kruse ///        the matrix multiplication pattern.
758d123e983SMichael Kruse ///
759d123e983SMichael Kruse /// Create an access relation of the following form:
760d123e983SMichael Kruse /// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [OI, O5, OJ]
761d123e983SMichael Kruse /// where I is @p FirstDim, J is @p SecondDim.
762d123e983SMichael Kruse ///
763d123e983SMichael Kruse /// It can be used, for example, to create relations that helps to consequently
764d123e983SMichael Kruse /// access elements of operands of a matrix multiplication after creation of
765d123e983SMichael Kruse /// the BLIS micro and macro kernels.
766d123e983SMichael Kruse ///
767d123e983SMichael Kruse /// @see ScheduleTreeOptimizer::createMicroKernel
768d123e983SMichael Kruse /// @see ScheduleTreeOptimizer::createMacroKernel
769d123e983SMichael Kruse ///
770d123e983SMichael Kruse /// Subsequently, the described access relation is applied to the range of
771d123e983SMichael Kruse /// @p MapOldIndVar, that is used to map original induction variables to
772d123e983SMichael Kruse /// the ones, which are produced by schedule transformations. It helps to
773d123e983SMichael Kruse /// define relations using a new space and, at the same time, keep them
774d123e983SMichael Kruse /// in the original one.
775d123e983SMichael Kruse ///
776d123e983SMichael Kruse /// @param MapOldIndVar The relation, which maps original induction variables
777d123e983SMichael Kruse ///                     to the ones, which are produced by schedule
778d123e983SMichael Kruse ///                     transformations.
779d123e983SMichael Kruse /// @param FirstDim, SecondDim The input dimensions that are used to define
780d123e983SMichael Kruse ///        the specified access relation.
781d123e983SMichael Kruse /// @return The specified access relation.
782d123e983SMichael Kruse static isl::map getMatMulAccRel(isl::map MapOldIndVar, unsigned FirstDim,
783d123e983SMichael Kruse                                 unsigned SecondDim) {
7840813bd16SRiccardo Mori   auto AccessRelSpace = isl::space(MapOldIndVar.ctx(), 0, 9, 3);
785d123e983SMichael Kruse   auto AccessRel = isl::map::universe(AccessRelSpace);
786d123e983SMichael Kruse   AccessRel = AccessRel.equate(isl::dim::in, FirstDim, isl::dim::out, 0);
787d123e983SMichael Kruse   AccessRel = AccessRel.equate(isl::dim::in, 5, isl::dim::out, 1);
788d123e983SMichael Kruse   AccessRel = AccessRel.equate(isl::dim::in, SecondDim, isl::dim::out, 2);
789d123e983SMichael Kruse   return MapOldIndVar.apply_range(AccessRel);
790d123e983SMichael Kruse }
791d123e983SMichael Kruse 
792d123e983SMichael Kruse static isl::schedule_node createExtensionNode(isl::schedule_node Node,
793d123e983SMichael Kruse                                               isl::map ExtensionMap) {
794d123e983SMichael Kruse   auto Extension = isl::union_map(ExtensionMap);
795d123e983SMichael Kruse   auto NewNode = isl::schedule_node::from_extension(Extension);
796d123e983SMichael Kruse   return Node.graft_before(NewNode);
797d123e983SMichael Kruse }
798d123e983SMichael Kruse 
799a56bd7deSMichael Kruse static isl::schedule_node optimizePackedB(isl::schedule_node Node,
800a56bd7deSMichael Kruse                                           ScopStmt *Stmt, isl::map MapOldIndVar,
801a56bd7deSMichael Kruse                                           MicroKernelParamsTy MicroParams,
802a56bd7deSMichael Kruse                                           MacroKernelParamsTy MacroParams,
803a56bd7deSMichael Kruse                                           MatMulInfoTy &MMI) {
804a56bd7deSMichael Kruse   Scop *S = Stmt->getParent();
805a56bd7deSMichael Kruse   isl::set Domain = Stmt->getDomain();
806a56bd7deSMichael Kruse 
807a56bd7deSMichael Kruse   // Create packed array.
808a56bd7deSMichael Kruse   unsigned FirstDimSize = MacroParams.Nc / MicroParams.Nr;
809a56bd7deSMichael Kruse   unsigned SecondDimSize = MacroParams.Kc;
810a56bd7deSMichael Kruse   unsigned ThirdDimSize = MicroParams.Nr;
811a56bd7deSMichael Kruse   ScopArrayInfo *PackedB =
812a56bd7deSMichael Kruse       S->createScopArrayInfo(MMI.B->getElementType(), "Packed_B",
813a56bd7deSMichael Kruse                              {FirstDimSize, SecondDimSize, ThirdDimSize});
814a56bd7deSMichael Kruse 
815a56bd7deSMichael Kruse   // Compute the access relation for copying from B to PackedB.
816a56bd7deSMichael Kruse   isl::map AccRelB = MMI.B->getLatestAccessRelation();
817a56bd7deSMichael Kruse   isl::map AccRelPackedB = getMatMulAccRel(MapOldIndVar, 3, 7);
818a56bd7deSMichael Kruse   AccRelPackedB =
819a56bd7deSMichael Kruse       AccRelPackedB.set_tuple_id(isl::dim::out, PackedB->getBasePtrId());
820a56bd7deSMichael Kruse 
821a56bd7deSMichael Kruse   // Create the copy statement and redirect access.
822a56bd7deSMichael Kruse   ScopStmt *CopyStmt = S->addScopStmt(AccRelB, AccRelPackedB, Domain);
823a56bd7deSMichael Kruse   MMI.B->setNewAccessRelation(AccRelPackedB);
824a56bd7deSMichael Kruse 
82544596fe6SRiccardo Mori   unsigned Dim = unsignedFromIslSize(MapOldIndVar.range_tuple_dim());
82644596fe6SRiccardo Mori   assert(Dim >= 2);
827a56bd7deSMichael Kruse   // Insert into the schedule tree.
82844596fe6SRiccardo Mori   isl::map ExtMap = MapOldIndVar.project_out(isl::dim::out, 2, Dim - 2);
829a56bd7deSMichael Kruse   ExtMap = ExtMap.reverse();
830a56bd7deSMichael Kruse   ExtMap = ExtMap.fix_si(isl::dim::out, MMI.i, 0);
831a56bd7deSMichael Kruse   ExtMap = ExtMap.intersect_range(Domain);
832a56bd7deSMichael Kruse   ExtMap = ExtMap.set_tuple_id(isl::dim::out, CopyStmt->getDomainId());
833a56bd7deSMichael Kruse   return createExtensionNode(Node, ExtMap);
834a56bd7deSMichael Kruse }
835a56bd7deSMichael Kruse 
836a56bd7deSMichael Kruse static isl::schedule_node optimizePackedA(isl::schedule_node Node, ScopStmt *,
837a56bd7deSMichael Kruse                                           isl::map MapOldIndVar,
838a56bd7deSMichael Kruse                                           MicroKernelParamsTy MicroParams,
839a56bd7deSMichael Kruse                                           MacroKernelParamsTy MacroParams,
840a56bd7deSMichael Kruse                                           MatMulInfoTy &MMI) {
841a56bd7deSMichael Kruse   isl::id InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in);
842a56bd7deSMichael Kruse   ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
843a56bd7deSMichael Kruse   isl::set Domain = Stmt->getDomain();
844a56bd7deSMichael Kruse   isl::id DomainId = Domain.get_tuple_id();
845a56bd7deSMichael Kruse 
846a56bd7deSMichael Kruse   // Create the packed array.
847a56bd7deSMichael Kruse   unsigned FirstDimSize = MacroParams.Mc / MicroParams.Mr;
848a56bd7deSMichael Kruse   unsigned SecondDimSize = MacroParams.Kc;
849a56bd7deSMichael Kruse   unsigned ThirdDimSize = MicroParams.Mr;
850a56bd7deSMichael Kruse   ScopArrayInfo *PackedA = Stmt->getParent()->createScopArrayInfo(
851a56bd7deSMichael Kruse       MMI.A->getElementType(), "Packed_A",
852a56bd7deSMichael Kruse       {FirstDimSize, SecondDimSize, ThirdDimSize});
853a56bd7deSMichael Kruse 
854a56bd7deSMichael Kruse   // Compute the access relation for copying from A to PackedA.
855a56bd7deSMichael Kruse   isl::map AccRelA = MMI.A->getLatestAccessRelation();
856a56bd7deSMichael Kruse   isl::map AccRelPackedA = getMatMulAccRel(MapOldIndVar, 4, 6);
857a56bd7deSMichael Kruse   AccRelPackedA =
858a56bd7deSMichael Kruse       AccRelPackedA.set_tuple_id(isl::dim::out, PackedA->getBasePtrId());
859a56bd7deSMichael Kruse   // { MemrefA[] -> PackedA[] }
860a56bd7deSMichael Kruse   isl::map PackedATranslator = AccRelPackedA.apply_domain(AccRelA);
861a56bd7deSMichael Kruse 
862a56bd7deSMichael Kruse   // Compute the domain for the copy statement.
863a56bd7deSMichael Kruse   // Construct the copy statement domain out of the 3 outermost scatter
864a56bd7deSMichael Kruse   // dimensions (to match the 3 band nodes surrounding the extension node) and
865a56bd7deSMichael Kruse   // the array elements to copy (one statement instance per array element).
866a56bd7deSMichael Kruse   // { Scatter[] }
867a56bd7deSMichael Kruse   isl::set ScatterDomain = MapOldIndVar.intersect_domain(Domain).range();
868a56bd7deSMichael Kruse   // { Scatter[] -> OutermostScatter[] }
869a56bd7deSMichael Kruse   isl::map OuterDomainMap =
870a56bd7deSMichael Kruse       makeIdentityMap(ScatterDomain, true).project_out(isl::dim::out, 3, 6);
871a56bd7deSMichael Kruse   // { Scatter[] -> MemrefA[] }
872a56bd7deSMichael Kruse   isl::map CopyFrom = MapOldIndVar.reverse().apply_range(AccRelA);
873a56bd7deSMichael Kruse   // { Scatter[] -> CopyStmt[] }
874a56bd7deSMichael Kruse   isl::map DomainTranslator = OuterDomainMap.range_product(CopyFrom);
875a56bd7deSMichael Kruse   // { CopyStmt[] }
876a56bd7deSMichael Kruse   isl::set CopyDomain = DomainTranslator.range();
877a56bd7deSMichael Kruse 
878a56bd7deSMichael Kruse   // Translate the access relations to the new domain.
879a56bd7deSMichael Kruse   // { CopyStmt[] -> MemrefA[] }
880a56bd7deSMichael Kruse   CopyFrom = CopyFrom.apply_domain(DomainTranslator);
881a56bd7deSMichael Kruse   // { CopyStmt[] -> PackedA[] }
882a56bd7deSMichael Kruse   isl::map CopyTo = CopyFrom.apply_range(PackedATranslator);
883a56bd7deSMichael Kruse 
884a56bd7deSMichael Kruse   // Create the copy statement and redirect access.
885a56bd7deSMichael Kruse   ScopStmt *CopyStmt =
886a56bd7deSMichael Kruse       Stmt->getParent()->addScopStmt(CopyFrom, CopyTo, CopyDomain);
887a56bd7deSMichael Kruse   MMI.A->setNewAccessRelation(AccRelPackedA);
888a56bd7deSMichael Kruse 
889a56bd7deSMichael Kruse   // Insert into the schedule tree.
890a56bd7deSMichael Kruse   // { Scatter[] -> CopyStmt[] }
891a56bd7deSMichael Kruse   isl::map ExtScatterCopy = makeIdentityMap(CopyStmt->getDomain(), true);
892a56bd7deSMichael Kruse   ExtScatterCopy = ExtScatterCopy.project_out(isl::dim::in, 3, 2);
893a56bd7deSMichael Kruse   return createExtensionNode(Node, ExtScatterCopy);
894a56bd7deSMichael Kruse }
895a56bd7deSMichael Kruse 
896d123e983SMichael Kruse /// Apply the packing transformation.
897d123e983SMichael Kruse ///
898d123e983SMichael Kruse /// The packing transformation can be described as a data-layout
899d123e983SMichael Kruse /// transformation that requires to introduce a new array, copy data
900d123e983SMichael Kruse /// to the array, and change memory access locations to reference the array.
901d123e983SMichael Kruse /// It can be used to ensure that elements of the new array are read in-stride
902d123e983SMichael Kruse /// access, aligned to cache lines boundaries, and preloaded into certain cache
903d123e983SMichael Kruse /// levels.
904d123e983SMichael Kruse ///
905d123e983SMichael Kruse /// As an example let us consider the packing of the array A that would help
906d123e983SMichael Kruse /// to read its elements with in-stride access. An access to the array A
907d123e983SMichael Kruse /// is represented by an access relation that has the form
908d123e983SMichael Kruse /// S[i, j, k] -> A[i, k]. The scheduling function of the SCoP statement S has
909d123e983SMichael Kruse /// the form S[i,j, k] -> [floor((j mod Nc) / Nr), floor((i mod Mc) / Mr),
910d123e983SMichael Kruse /// k mod Kc, j mod Nr, i mod Mr].
911d123e983SMichael Kruse ///
912d123e983SMichael Kruse /// To ensure that elements of the array A are read in-stride access, we add
913d123e983SMichael Kruse /// a new array Packed_A[Mc/Mr][Kc][Mr] to the SCoP, using
914d123e983SMichael Kruse /// Scop::createScopArrayInfo, change the access relation
915d123e983SMichael Kruse /// S[i, j, k] -> A[i, k] to
916d123e983SMichael Kruse /// S[i, j, k] -> Packed_A[floor((i mod Mc) / Mr), k mod Kc, i mod Mr], using
917d123e983SMichael Kruse /// MemoryAccess::setNewAccessRelation, and copy the data to the array, using
918d123e983SMichael Kruse /// the copy statement created by Scop::addScopStmt.
919d123e983SMichael Kruse ///
920d123e983SMichael Kruse /// @param Node The schedule node to be optimized.
921d123e983SMichael Kruse /// @param MapOldIndVar The relation, which maps original induction variables
922d123e983SMichael Kruse ///                     to the ones, which are produced by schedule
923d123e983SMichael Kruse ///                     transformations.
924d123e983SMichael Kruse /// @param MicroParams, MacroParams Parameters of the BLIS kernel
925d123e983SMichael Kruse ///                                 to be taken into account.
926d123e983SMichael Kruse /// @param MMI Parameters of the matrix multiplication operands.
927d123e983SMichael Kruse /// @return The optimized schedule node.
928d123e983SMichael Kruse static isl::schedule_node
929d123e983SMichael Kruse optimizeDataLayoutMatrMulPattern(isl::schedule_node Node, isl::map MapOldIndVar,
930d123e983SMichael Kruse                                  MicroKernelParamsTy MicroParams,
931d123e983SMichael Kruse                                  MacroKernelParamsTy MacroParams,
932d123e983SMichael Kruse                                  MatMulInfoTy &MMI) {
933a56bd7deSMichael Kruse   isl::id InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in);
934a56bd7deSMichael Kruse   ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
935d123e983SMichael Kruse 
936d123e983SMichael Kruse   Node = Node.parent().parent().parent().parent().parent().parent();
937a56bd7deSMichael Kruse   Node = isl::manage(isl_schedule_node_band_split(Node.release(), 2));
938d123e983SMichael Kruse 
939d123e983SMichael Kruse   Node = Node.child(0);
940a56bd7deSMichael Kruse   Node =
941a56bd7deSMichael Kruse       optimizePackedB(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI);
942d123e983SMichael Kruse 
943a56bd7deSMichael Kruse   Node = Node.child(0);
944a56bd7deSMichael Kruse   Node =
945a56bd7deSMichael Kruse       optimizePackedA(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI);
946a56bd7deSMichael Kruse 
947d123e983SMichael Kruse   return Node.child(0).child(0).child(0).child(0).child(0);
948d123e983SMichael Kruse }
949d123e983SMichael Kruse 
950d123e983SMichael Kruse /// Get a relation mapping induction variables produced by schedule
951d123e983SMichael Kruse /// transformations to the original ones.
952d123e983SMichael Kruse ///
953d123e983SMichael Kruse /// @param Node The schedule node produced as the result of creation
954d123e983SMichael Kruse ///        of the BLIS kernels.
955d123e983SMichael Kruse /// @param MicroKernelParams, MacroKernelParams Parameters of the BLIS kernel
956d123e983SMichael Kruse ///                                             to be taken into account.
957d123e983SMichael Kruse /// @return  The relation mapping original induction variables to the ones
958d123e983SMichael Kruse ///          produced by schedule transformation.
959d123e983SMichael Kruse /// @see ScheduleTreeOptimizer::createMicroKernel
960d123e983SMichael Kruse /// @see ScheduleTreeOptimizer::createMacroKernel
961d123e983SMichael Kruse /// @see getMacroKernelParams
962d123e983SMichael Kruse static isl::map
963d123e983SMichael Kruse getInductionVariablesSubstitution(isl::schedule_node Node,
964d123e983SMichael Kruse                                   MicroKernelParamsTy MicroKernelParams,
965d123e983SMichael Kruse                                   MacroKernelParamsTy MacroKernelParams) {
966d123e983SMichael Kruse   auto Child = Node.child(0);
967d123e983SMichael Kruse   auto UnMapOldIndVar = Child.get_prefix_schedule_union_map();
968d123e983SMichael Kruse   auto MapOldIndVar = isl::map::from_union_map(UnMapOldIndVar);
96944596fe6SRiccardo Mori   unsigned Dim = unsignedFromIslSize(MapOldIndVar.range_tuple_dim());
97044596fe6SRiccardo Mori   if (Dim > 9u)
97144596fe6SRiccardo Mori     return MapOldIndVar.project_out(isl::dim::out, 0, Dim - 9);
972d123e983SMichael Kruse   return MapOldIndVar;
973d123e983SMichael Kruse }
974d123e983SMichael Kruse 
975d123e983SMichael Kruse /// Isolate a set of partial tile prefixes and unroll the isolated part.
976d123e983SMichael Kruse ///
977d123e983SMichael Kruse /// The set should ensure that it contains only partial tile prefixes that have
978d123e983SMichael Kruse /// exactly Mr x Nr iterations of the two innermost loops produced by
979d123e983SMichael Kruse /// the optimization of the matrix multiplication. Mr and Nr are parameters of
980d123e983SMichael Kruse /// the micro-kernel.
981d123e983SMichael Kruse ///
982d123e983SMichael Kruse /// In case of parametric bounds, this helps to auto-vectorize the unrolled
983d123e983SMichael Kruse /// innermost loops, using the SLP vectorizer.
984d123e983SMichael Kruse ///
985d123e983SMichael Kruse /// @param Node              The schedule node to be modified.
986d123e983SMichael Kruse /// @param MicroKernelParams Parameters of the micro-kernel
987d123e983SMichael Kruse ///                          to be taken into account.
988d123e983SMichael Kruse /// @return The modified isl_schedule_node.
989d123e983SMichael Kruse static isl::schedule_node
990d123e983SMichael Kruse isolateAndUnrollMatMulInnerLoops(isl::schedule_node Node,
991bd93df93SMichael Kruse                                  MicroKernelParamsTy MicroKernelParams) {
992d3fdbda6SRiccardo Mori   isl::schedule_node Child = Node.child(0);
993d123e983SMichael Kruse   isl::union_map UnMapOldIndVar = Child.get_prefix_schedule_relation();
994d123e983SMichael Kruse   isl::set Prefix = isl::map::from_union_map(UnMapOldIndVar).range();
99544596fe6SRiccardo Mori   unsigned Dims = unsignedFromIslSize(Prefix.tuple_dim());
99644596fe6SRiccardo Mori   assert(Dims >= 1);
997d123e983SMichael Kruse   Prefix = Prefix.project_out(isl::dim::set, Dims - 1, 1);
998d123e983SMichael Kruse   Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Nr);
999d123e983SMichael Kruse   Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Mr);
1000d123e983SMichael Kruse 
1001d123e983SMichael Kruse   isl::union_set IsolateOption =
1002d123e983SMichael Kruse       getIsolateOptions(Prefix.add_dims(isl::dim::set, 3), 3);
10030813bd16SRiccardo Mori   isl::ctx Ctx = Node.ctx();
1004d123e983SMichael Kruse   auto Options = IsolateOption.unite(getDimOptions(Ctx, "unroll"));
1005d123e983SMichael Kruse   Options = Options.unite(getUnrollIsolatedSetOptions(Ctx));
1006d3fdbda6SRiccardo Mori   Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options);
1007d123e983SMichael Kruse   Node = Node.parent().parent().parent();
1008d123e983SMichael Kruse   IsolateOption = getIsolateOptions(Prefix, 3);
1009d123e983SMichael Kruse   Options = IsolateOption.unite(getDimOptions(Ctx, "separate"));
1010d3fdbda6SRiccardo Mori   Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options);
1011d123e983SMichael Kruse   Node = Node.child(0).child(0).child(0);
1012d123e983SMichael Kruse   return Node;
1013d123e983SMichael Kruse }
1014d123e983SMichael Kruse 
1015d123e983SMichael Kruse /// Insert "Loop Vectorizer Disabled" mark node.
1016d123e983SMichael Kruse ///
1017d123e983SMichael Kruse /// @param Node The child of the mark node to be inserted.
1018d123e983SMichael Kruse /// @return The modified isl_schedule_node.
1019d123e983SMichael Kruse static isl::schedule_node markLoopVectorizerDisabled(isl::schedule_node Node) {
10200813bd16SRiccardo Mori   auto Id = isl::id::alloc(Node.ctx(), "Loop Vectorizer Disabled", nullptr);
1021d123e983SMichael Kruse   return Node.insert_mark(Id).child(0);
1022d123e983SMichael Kruse }
1023d123e983SMichael Kruse 
1024d123e983SMichael Kruse /// Restore the initial ordering of dimensions of the band node
1025d123e983SMichael Kruse ///
1026d123e983SMichael Kruse /// In case the band node represents all the dimensions of the iteration
1027d123e983SMichael Kruse /// domain, recreate the band node to restore the initial ordering of the
1028d123e983SMichael Kruse /// dimensions.
1029d123e983SMichael Kruse ///
1030d123e983SMichael Kruse /// @param Node The band node to be modified.
1031d123e983SMichael Kruse /// @return The modified schedule node.
1032d123e983SMichael Kruse static isl::schedule_node
1033d123e983SMichael Kruse getBandNodeWithOriginDimOrder(isl::schedule_node Node) {
1034d123e983SMichael Kruse   assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
1035d123e983SMichael Kruse   if (isl_schedule_node_get_type(Node.child(0).get()) != isl_schedule_node_leaf)
1036d123e983SMichael Kruse     return Node;
1037d123e983SMichael Kruse   auto Domain = Node.get_universe_domain();
1038d123e983SMichael Kruse   assert(isl_union_set_n_set(Domain.get()) == 1);
1039d3fdbda6SRiccardo Mori   if (Node.get_schedule_depth().release() != 0 ||
104044596fe6SRiccardo Mori       (unsignedFromIslSize(isl::set(Domain).tuple_dim()) !=
104144596fe6SRiccardo Mori        unsignedFromIslSize(Node.as<isl::schedule_node_band>().n_member())))
1042d123e983SMichael Kruse     return Node;
1043d123e983SMichael Kruse   Node = isl::manage(isl_schedule_node_delete(Node.copy()));
1044d123e983SMichael Kruse   auto PartialSchedulePwAff = Domain.identity_union_pw_multi_aff();
1045d123e983SMichael Kruse   auto PartialScheduleMultiPwAff =
1046d123e983SMichael Kruse       isl::multi_union_pw_aff(PartialSchedulePwAff);
1047d123e983SMichael Kruse   PartialScheduleMultiPwAff =
1048d123e983SMichael Kruse       PartialScheduleMultiPwAff.reset_tuple_id(isl::dim::set);
1049d123e983SMichael Kruse   return Node.insert_partial_schedule(PartialScheduleMultiPwAff);
1050d123e983SMichael Kruse }
1051d123e983SMichael Kruse 
1052d123e983SMichael Kruse static isl::schedule_node optimizeMatMulPattern(isl::schedule_node Node,
1053d123e983SMichael Kruse                                                 const TargetTransformInfo *TTI,
1054d123e983SMichael Kruse                                                 MatMulInfoTy &MMI) {
1055d123e983SMichael Kruse   assert(TTI && "The target transform info should be provided.");
1056d123e983SMichael Kruse   int DimOutNum = isl_schedule_node_band_n_member(Node.get());
1057d123e983SMichael Kruse   assert(DimOutNum > 2 && "In case of the matrix multiplication the loop nest "
1058d123e983SMichael Kruse                           "and, consequently, the corresponding scheduling "
1059d123e983SMichael Kruse                           "functions have at least three dimensions.");
1060d123e983SMichael Kruse   Node = getBandNodeWithOriginDimOrder(Node);
1061d123e983SMichael Kruse   Node = permuteBandNodeDimensions(Node, MMI.i, DimOutNum - 3);
1062d123e983SMichael Kruse   int NewJ = MMI.j == DimOutNum - 3 ? MMI.i : MMI.j;
1063d123e983SMichael Kruse   int NewK = MMI.k == DimOutNum - 3 ? MMI.i : MMI.k;
1064d123e983SMichael Kruse   Node = permuteBandNodeDimensions(Node, NewJ, DimOutNum - 2);
1065d123e983SMichael Kruse   NewK = NewK == DimOutNum - 2 ? NewJ : NewK;
1066d123e983SMichael Kruse   Node = permuteBandNodeDimensions(Node, NewK, DimOutNum - 1);
1067d123e983SMichael Kruse   auto MicroKernelParams = getMicroKernelParams(TTI, MMI);
1068d123e983SMichael Kruse   auto MacroKernelParams = getMacroKernelParams(TTI, MicroKernelParams, MMI);
1069d123e983SMichael Kruse   Node = createMacroKernel(Node, MacroKernelParams);
1070d123e983SMichael Kruse   Node = createMicroKernel(Node, MicroKernelParams);
1071d123e983SMichael Kruse   if (MacroKernelParams.Mc == 1 || MacroKernelParams.Nc == 1 ||
1072d123e983SMichael Kruse       MacroKernelParams.Kc == 1)
1073d123e983SMichael Kruse     return Node;
1074d123e983SMichael Kruse   auto MapOldIndVar = getInductionVariablesSubstitution(Node, MicroKernelParams,
1075d123e983SMichael Kruse                                                         MacroKernelParams);
10767c7978a1Spatacca   if (MapOldIndVar.is_null())
1077d123e983SMichael Kruse     return Node;
1078d123e983SMichael Kruse   Node = markLoopVectorizerDisabled(Node.parent()).child(0);
1079d123e983SMichael Kruse   Node = isolateAndUnrollMatMulInnerLoops(Node, MicroKernelParams);
1080d123e983SMichael Kruse   return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
1081d123e983SMichael Kruse                                           MacroKernelParams, MMI);
1082d123e983SMichael Kruse }
1083d123e983SMichael Kruse 
1084d123e983SMichael Kruse /// Check if this node contains a partial schedule that could
1085d123e983SMichael Kruse ///        probably be optimized with analytical modeling.
1086d123e983SMichael Kruse ///
1087d123e983SMichael Kruse /// isMatrMultPattern tries to determine whether the following conditions
1088d123e983SMichael Kruse /// are true:
1089d123e983SMichael Kruse /// 1. the partial schedule contains only one statement.
1090d123e983SMichael Kruse /// 2. there are exactly three input dimensions.
1091d123e983SMichael Kruse /// 3. all memory accesses of the statement will have stride 0 or 1, if we
1092d123e983SMichael Kruse ///    interchange loops (switch the variable used in the inner loop to
1093d123e983SMichael Kruse ///    the outer loop).
1094d123e983SMichael Kruse /// 4. all memory accesses of the statement except from the last one, are
1095d123e983SMichael Kruse ///    read memory access and the last one is write memory access.
1096d123e983SMichael Kruse /// 5. all subscripts of the last memory access of the statement don't
1097d123e983SMichael Kruse ///    contain the variable used in the inner loop.
1098d123e983SMichael Kruse /// If this is the case, we could try to use an approach that is similar to
1099d123e983SMichael Kruse /// the one used to get close-to-peak performance of matrix multiplications.
1100d123e983SMichael Kruse ///
1101d123e983SMichael Kruse /// @param Node The node to check.
1102d123e983SMichael Kruse /// @param D    The SCoP dependencies.
1103d123e983SMichael Kruse /// @param MMI  Parameters of the matrix multiplication operands.
1104d123e983SMichael Kruse static bool isMatrMultPattern(isl::schedule_node Node, const Dependences *D,
1105d123e983SMichael Kruse                               MatMulInfoTy &MMI) {
1106d123e983SMichael Kruse   auto PartialSchedule = isl::manage(
1107d123e983SMichael Kruse       isl_schedule_node_band_get_partial_schedule_union_map(Node.get()));
1108b02c7e2bSRoman Gareev   if (isl_schedule_node_band_n_member(Node.get()) < 3 ||
1109d3fdbda6SRiccardo Mori       Node.get_schedule_depth().release() != 0 ||
1110d123e983SMichael Kruse       isl_union_map_n_map(PartialSchedule.get()) != 1)
1111d123e983SMichael Kruse     return false;
1112d123e983SMichael Kruse   auto NewPartialSchedule = isl::map::from_union_map(PartialSchedule);
1113d123e983SMichael Kruse   if (containsMatrMult(NewPartialSchedule, D, MMI))
1114d123e983SMichael Kruse     return true;
1115d123e983SMichael Kruse   return false;
1116d123e983SMichael Kruse }
1117d123e983SMichael Kruse 
1118b02c7e2bSRoman Gareev /// Get the dimension size.
1119b02c7e2bSRoman Gareev ///
1120b02c7e2bSRoman Gareev /// Return the size of the dimension @p Pos, which is obtained from @p SAI.
1121b02c7e2bSRoman Gareev /// Return -1 in the case of the first dimension of a multi-dimensional array,
1122b02c7e2bSRoman Gareev /// since the ScopArrayInfo class does not carry size information.
1123b02c7e2bSRoman Gareev ///
1124b02c7e2bSRoman Gareev /// @param SAI The information about the array.
1125b02c7e2bSRoman Gareev /// @param Pos The position of the dimension.
1126b02c7e2bSRoman Gareev /// @return The size of the dimension.
1127b02c7e2bSRoman Gareev static int getDimSize(const ScopArrayInfo *SAI, unsigned Pos) {
1128b02c7e2bSRoman Gareev   if (Pos == 0)
1129b02c7e2bSRoman Gareev     return -1;
1130b02c7e2bSRoman Gareev   const llvm::SCEV *SCEVDimSize = SAI->getDimensionSize(Pos);
1131b02c7e2bSRoman Gareev   assert(SCEVDimSize);
1132b02c7e2bSRoman Gareev   auto *ConstantDimSize = dyn_cast<const SCEVConstant>(SCEVDimSize);
1133b02c7e2bSRoman Gareev   assert(ConstantDimSize);
1134b02c7e2bSRoman Gareev   auto *IntDimSize = dyn_cast<ConstantInt>(ConstantDimSize->getValue());
1135b02c7e2bSRoman Gareev   assert(IntDimSize);
1136b02c7e2bSRoman Gareev   return IntDimSize->getSExtValue();
1137b02c7e2bSRoman Gareev }
1138b02c7e2bSRoman Gareev 
1139b02c7e2bSRoman Gareev /// Check whether the access relation has the specified form.
1140b02c7e2bSRoman Gareev ///
1141b02c7e2bSRoman Gareev /// Check that the access relation @p AccMap has the form T[I0, …, In], where
1142b02c7e2bSRoman Gareev /// indexes I0, …, In are specified by @p Dimensions.
1143b02c7e2bSRoman Gareev ///
1144b02c7e2bSRoman Gareev /// @param Domain     The domain of the access relation.
1145b02c7e2bSRoman Gareev /// @param AccMap     The access relation to be checked.
1146b02c7e2bSRoman Gareev /// @param Dimensions The permutation of the subset of the input dimensions.
1147b02c7e2bSRoman Gareev /// @return True if @p AccMap has the expected form and false,
1148b02c7e2bSRoman Gareev ///         otherwise.
1149b02c7e2bSRoman Gareev static bool isCorrectAccessMap(isl::set Domain, isl::map AccMap,
1150b02c7e2bSRoman Gareev                                ArrayRef<int> Dimensions) {
1151b02c7e2bSRoman Gareev   isl::space Space = AccMap.get_space();
1152b02c7e2bSRoman Gareev   if (unsignedFromIslSize(Space.dim(isl::dim::out)) != Dimensions.size())
1153b02c7e2bSRoman Gareev     return false;
1154b02c7e2bSRoman Gareev 
1155b02c7e2bSRoman Gareev   // Create an access relation of the following form:
1156b02c7e2bSRoman Gareev   // [I0, …, Im] -> [Il, …, In], where indexes
1157b02c7e2bSRoman Gareev   // Il, …, In are specified by @p Dimensions.
1158b02c7e2bSRoman Gareev   isl::map PossibleTensor = isl::map::universe(Space);
1159b02c7e2bSRoman Gareev   unsigned DimInSize = unsignedFromIslSize(Space.dim(isl::dim::in));
1160b02c7e2bSRoman Gareev   for (unsigned i = 0; i < Dimensions.size(); i++) {
1161b02c7e2bSRoman Gareev     const int InPos = Dimensions[i];
1162b02c7e2bSRoman Gareev     if ((InPos >= static_cast<int>(DimInSize)) || (InPos < 0))
1163b02c7e2bSRoman Gareev       return false;
1164b02c7e2bSRoman Gareev     PossibleTensor =
1165b02c7e2bSRoman Gareev         PossibleTensor.equate(isl::dim::in, InPos, isl::dim::out, i);
1166b02c7e2bSRoman Gareev   }
1167b02c7e2bSRoman Gareev 
1168b02c7e2bSRoman Gareev   AccMap = AccMap.intersect_domain(Domain);
1169b02c7e2bSRoman Gareev   PossibleTensor = PossibleTensor.intersect_domain(Domain);
1170b02c7e2bSRoman Gareev 
1171b02c7e2bSRoman Gareev   // If AccMap != PossibleTensor here (the two maps have been gisted at
1172b02c7e2bSRoman Gareev   // this point), it means that the writes are not complete, or in other
1173b02c7e2bSRoman Gareev   // words, it is a Partial write and Partial writes must be rejected.
1174b02c7e2bSRoman Gareev   return AccMap.is_equal(PossibleTensor);
1175b02c7e2bSRoman Gareev }
1176b02c7e2bSRoman Gareev 
1177b02c7e2bSRoman Gareev /// Check whether the access represents the tensor contraction operand.
1178b02c7e2bSRoman Gareev ///
1179b02c7e2bSRoman Gareev /// Check that the access relation @p AccMap has the form T[i1, …, in].
1180b02c7e2bSRoman Gareev /// Obtained indexes i1, …, in, their sizes and their permutation are stored
1181b02c7e2bSRoman Gareev /// into @p IndexSet, @p DimensionSizes, and @p Dimensions, respectively.
1182b02c7e2bSRoman Gareev ///
1183b02c7e2bSRoman Gareev /// @param Domain         The domain of the access relation.
1184b02c7e2bSRoman Gareev /// @param AccMap         The access relation to be checked.
1185b02c7e2bSRoman Gareev /// @param IndexSet       The subset of the input dimensions.
1186b02c7e2bSRoman Gareev /// @param DimensionSizes Sizes of the input dimensions of @p Dimensions.
1187b02c7e2bSRoman Gareev /// @param Dimensions     The permutation of the subset of the input dimensions.
1188b02c7e2bSRoman Gareev /// @return True if @p AccMap has the expected form and false,
1189b02c7e2bSRoman Gareev ///         otherwise.
1190b02c7e2bSRoman Gareev static bool isTCOperandAcc(isl::set Domain, isl::map AccMap,
1191b02c7e2bSRoman Gareev                            SmallDenseSet<int> &IndexSet,
1192b02c7e2bSRoman Gareev                            SmallVectorImpl<int> &DimensionSizes,
1193b02c7e2bSRoman Gareev                            SmallVectorImpl<int> &Dimensions) {
1194b02c7e2bSRoman Gareev   isl::id Id = AccMap.get_tuple_id(isl::dim::out);
1195b02c7e2bSRoman Gareev   const ScopArrayInfo *SAI = ScopArrayInfo::getFromId(Id);
1196b02c7e2bSRoman Gareev   assert(SAI && "AccMap should represent memory access");
1197b02c7e2bSRoman Gareev 
1198b02c7e2bSRoman Gareev   // Fix values of output dimensions with respect to their positions.
1199b02c7e2bSRoman Gareev   // In the case of the tensor contraction, values of output dimensions are
1200b02c7e2bSRoman Gareev   // fixed and form a permutation of a subset of values of input dimensions.
1201b02c7e2bSRoman Gareev   //
1202b02c7e2bSRoman Gareev   // For example, in the case of Stmt[i][j][k] -> A[k][i], which represents
1203b02c7e2bSRoman Gareev   // the operand of the tensor contraction, we get the following map by fixing
1204b02c7e2bSRoman Gareev   // the output dimensions Stmt[1][j][0] -> A[0][1].
1205b02c7e2bSRoman Gareev   //
1206b02c7e2bSRoman Gareev   // We store the permutation of the subset of the input dimensions {2, 0} into
1207b02c7e2bSRoman Gareev   // @p Dimensions.
1208b02c7e2bSRoman Gareev   //
1209b02c7e2bSRoman Gareev   // The obtained permutation and the isCorrectAccessMap function are used to
1210b02c7e2bSRoman Gareev   // check whether the access relation @p AccMap represents the tensor
1211b02c7e2bSRoman Gareev   // contraction operand. For example, in the case of
1212b02c7e2bSRoman Gareev   // Stmt[i][j][k] -> A[i-1][j+1], we get Stmt[1][0][k] -> A[0][1] and,
1213b02c7e2bSRoman Gareev   // consequently, {1, 0}, which is rejected by isCorrectAccessMap,
1214b02c7e2bSRoman Gareev   // since it corresponds to Stmt[i][j][k] -> A[j][i].
1215b02c7e2bSRoman Gareev   isl::map CheckMap = isl::manage(AccMap.copy());
1216b02c7e2bSRoman Gareev   unsigned OutDimNum = unsignedFromIslSize(CheckMap.dim(isl::dim::out));
1217b02c7e2bSRoman Gareev   for (unsigned i = 0; i < OutDimNum; i++)
1218b02c7e2bSRoman Gareev     CheckMap = CheckMap.fix_si(isl::dim::out, i, i);
1219b02c7e2bSRoman Gareev 
1220b02c7e2bSRoman Gareev   // Try to obtain the permutation and sizes of corresponding input dimensions.
1221b02c7e2bSRoman Gareev   Dimensions.assign(OutDimNum, -1);
1222b02c7e2bSRoman Gareev   for (unsigned i : rangeIslSize(0, CheckMap.dim(isl::dim::in))) {
1223b02c7e2bSRoman Gareev     isl::val Val = getConstant(CheckMap, isl::dim::in, i);
1224b02c7e2bSRoman Gareev     if (!Val.is_int())
1225b02c7e2bSRoman Gareev       continue;
1226b02c7e2bSRoman Gareev     int OutPos = -1;
1227b02c7e2bSRoman Gareev     llvm::APInt ValAPInt = APIntFromVal(Val);
1228b02c7e2bSRoman Gareev     if (ValAPInt.isSignedIntN(32))
1229b02c7e2bSRoman Gareev       OutPos = ValAPInt.getSExtValue();
1230b02c7e2bSRoman Gareev     if ((OutPos < 0) || (OutPos >= static_cast<int>(OutDimNum)) ||
1231b02c7e2bSRoman Gareev         IndexSet.count(i))
1232b02c7e2bSRoman Gareev       return false;
1233b02c7e2bSRoman Gareev     IndexSet.insert(i);
1234b02c7e2bSRoman Gareev     Dimensions[OutPos] = i;
1235b02c7e2bSRoman Gareev     if (DimensionSizes[i] <= 0)
1236b02c7e2bSRoman Gareev       DimensionSizes[i] = getDimSize(SAI, OutPos);
1237b02c7e2bSRoman Gareev   }
1238b02c7e2bSRoman Gareev 
1239b02c7e2bSRoman Gareev   return isCorrectAccessMap(Domain, AccMap, Dimensions);
1240b02c7e2bSRoman Gareev }
1241b02c7e2bSRoman Gareev 
1242b02c7e2bSRoman Gareev /// Find the intersection of two sets.
1243b02c7e2bSRoman Gareev ///
1244b02c7e2bSRoman Gareev /// Find the intersection of the set @p A and the set @p B.
1245b02c7e2bSRoman Gareev ///
1246b02c7e2bSRoman Gareev /// @param A, B Sets to intersect.
1247b02c7e2bSRoman Gareev /// @return The set intersection.
1248b02c7e2bSRoman Gareev static SmallDenseSet<int> intersect(const SmallDenseSet<int> &A,
1249b02c7e2bSRoman Gareev                                     const SmallDenseSet<int> &B) {
1250b02c7e2bSRoman Gareev   SmallDenseSet<int> Intersection = A;
1251b02c7e2bSRoman Gareev   set_intersect(Intersection, B);
1252b02c7e2bSRoman Gareev   return Intersection;
1253b02c7e2bSRoman Gareev }
1254b02c7e2bSRoman Gareev 
1255b02c7e2bSRoman Gareev /// Check whether the set is a superset.
1256b02c7e2bSRoman Gareev ///
1257b02c7e2bSRoman Gareev /// Check that the set @p A is a superset of @p B.
1258b02c7e2bSRoman Gareev ///
1259b02c7e2bSRoman Gareev /// @param A, B Sets to be checked.
1260b02c7e2bSRoman Gareev /// @return True if the set A is a superset of B.
1261b02c7e2bSRoman Gareev static bool isSuperset(const SmallDenseSet<int> &A,
1262b02c7e2bSRoman Gareev                        const SmallDenseSet<int> &B) {
1263b02c7e2bSRoman Gareev   return intersect(A, B).size() == B.size();
1264b02c7e2bSRoman Gareev }
1265b02c7e2bSRoman Gareev 
1266b02c7e2bSRoman Gareev /// Find the union of two sets.
1267b02c7e2bSRoman Gareev ///
1268b02c7e2bSRoman Gareev /// Find the union of the set @p A and the set @p B.
1269b02c7e2bSRoman Gareev ///
1270b02c7e2bSRoman Gareev /// @param A, B Sets to unite.
1271b02c7e2bSRoman Gareev /// @return The set union.
1272b02c7e2bSRoman Gareev static SmallDenseSet<int> unite(const SmallDenseSet<int> &A,
1273b02c7e2bSRoman Gareev                                 const SmallDenseSet<int> &B) {
1274b02c7e2bSRoman Gareev   SmallDenseSet<int> Union = A;
1275b02c7e2bSRoman Gareev   set_union(Union, B);
1276b02c7e2bSRoman Gareev   return Union;
1277b02c7e2bSRoman Gareev }
1278b02c7e2bSRoman Gareev 
1279b02c7e2bSRoman Gareev /// Determine the access that writes to the tensor, which contains
1280b02c7e2bSRoman Gareev /// the result of the tensor contraction.
1281b02c7e2bSRoman Gareev ///
1282b02c7e2bSRoman Gareev /// @param Domain        The domain of the statement.
1283b02c7e2bSRoman Gareev /// @param Stmt          The statement, which writes to memory.
1284b02c7e2bSRoman Gareev /// @param TCI           The information about the tensor contraction.
1285b02c7e2bSRoman Gareev /// @param IandJIndexSet The set, which contains free indexes of tensors.
1286b02c7e2bSRoman Gareev /// @return The determined MemoryAccess, or nullptr if there is no necessary
1287b02c7e2bSRoman Gareev ///         access within the SCoP.
1288b02c7e2bSRoman Gareev static MemoryAccess *getWriteAccess(isl::set Domain, ScopStmt *Stmt,
1289b02c7e2bSRoman Gareev                                     TCInfoTy &TCI,
1290b02c7e2bSRoman Gareev                                     SmallDenseSet<int> &IandJIndexSet) {
1291b02c7e2bSRoman Gareev   TCI.WriteToC = nullptr;
1292b02c7e2bSRoman Gareev   SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(*Stmt);
1293b02c7e2bSRoman Gareev   for (MemoryAccess *MemA : reverse(Accesses)) {
1294b02c7e2bSRoman Gareev     // A TC-like does not contain write scalar memory accesses
1295b02c7e2bSRoman Gareev     if (!MemA->isLatestArrayKind())
1296b02c7e2bSRoman Gareev       return nullptr;
1297b02c7e2bSRoman Gareev     // The last memory access should be a write memory access.
1298b02c7e2bSRoman Gareev     if (!MemA->isWrite())
1299b02c7e2bSRoman Gareev       return nullptr;
1300b02c7e2bSRoman Gareev 
1301b02c7e2bSRoman Gareev     isl::map AccMap = MemA->getLatestAccessRelation();
1302b02c7e2bSRoman Gareev     if (!isTCOperandAcc(Domain, AccMap, IandJIndexSet, TCI.DimensionSizes,
1303b02c7e2bSRoman Gareev                         TCI.CDimensions))
1304b02c7e2bSRoman Gareev       return nullptr;
1305b02c7e2bSRoman Gareev 
1306b02c7e2bSRoman Gareev     return MemA;
1307b02c7e2bSRoman Gareev   }
1308b02c7e2bSRoman Gareev   return nullptr;
1309b02c7e2bSRoman Gareev }
1310b02c7e2bSRoman Gareev 
1311b02c7e2bSRoman Gareev /// Determine an access, which reads elements of an operand of the tensor
1312b02c7e2bSRoman Gareev /// contraction
1313b02c7e2bSRoman Gareev ///
1314b02c7e2bSRoman Gareev /// @param MemAccessPtr  The access, which reads elements of the tensor.
1315b02c7e2bSRoman Gareev /// @param IndexSet      The set, which contains indexes of the tensors.
1316b02c7e2bSRoman Gareev /// @param IandJIndexSet The set, which contains free indexes of tensors.
1317b02c7e2bSRoman Gareev /// @param Dimensions    The permutation of the subset of the input dimensions.
1318b02c7e2bSRoman Gareev /// @param TCI           The information about the tensor contraction.
1319b02c7e2bSRoman Gareev /// @return True if the memory access @p MemAccessPtr corresponds
1320b02c7e2bSRoman Gareev ///         to the tensor contraction.
1321b02c7e2bSRoman Gareev static bool setReadAccess(MemoryAccess *MemAccessPtr,
1322b02c7e2bSRoman Gareev                           const SmallDenseSet<int> &IndexSet,
1323b02c7e2bSRoman Gareev                           const SmallDenseSet<int> &IandJIndexSet,
1324b02c7e2bSRoman Gareev                           ArrayRef<int> Dimensions, TCInfoTy &TCI) {
1325b02c7e2bSRoman Gareev   if (!TCI.A) {
1326b02c7e2bSRoman Gareev     // Probably IndexSet is a union of I and P sets.
1327b02c7e2bSRoman Gareev     if (!isSuperset(IndexSet, TCI.P))
1328b02c7e2bSRoman Gareev       return false;
1329b02c7e2bSRoman Gareev 
1330b02c7e2bSRoman Gareev     // Obtain the set I.
1331b02c7e2bSRoman Gareev     TCI.I = set_difference(IndexSet, TCI.P);
1332b02c7e2bSRoman Gareev     if (!isSuperset(IandJIndexSet, TCI.I))
1333b02c7e2bSRoman Gareev       return false;
1334b02c7e2bSRoman Gareev 
1335b02c7e2bSRoman Gareev     // Obtain the set J.
1336b02c7e2bSRoman Gareev     TCI.J = set_difference(IandJIndexSet, TCI.I);
1337b02c7e2bSRoman Gareev 
1338b02c7e2bSRoman Gareev     // Set the first operand of the tensor contraction.
1339b02c7e2bSRoman Gareev     TCI.A = MemAccessPtr;
1340b02c7e2bSRoman Gareev     llvm::replace(TCI.ADimensions, TCI.ADimensions.begin(),
1341b02c7e2bSRoman Gareev                   TCI.ADimensions.end(), Dimensions.begin(), Dimensions.end());
1342b02c7e2bSRoman Gareev     return true;
1343b02c7e2bSRoman Gareev   }
1344b02c7e2bSRoman Gareev 
1345b02c7e2bSRoman Gareev   if (!TCI.B) {
1346b02c7e2bSRoman Gareev     // IndexSet should be a union of J and P sets.
1347b02c7e2bSRoman Gareev     if (unite(TCI.P, TCI.J) != IndexSet)
1348b02c7e2bSRoman Gareev       return false;
1349b02c7e2bSRoman Gareev 
1350b02c7e2bSRoman Gareev     // Set the second operand of the tensor contraction.
1351b02c7e2bSRoman Gareev     TCI.B = MemAccessPtr;
1352b02c7e2bSRoman Gareev     llvm::replace(TCI.BDimensions, TCI.BDimensions.begin(),
1353b02c7e2bSRoman Gareev                   TCI.BDimensions.end(), Dimensions.begin(), Dimensions.end());
1354b02c7e2bSRoman Gareev     return true;
1355b02c7e2bSRoman Gareev   }
1356b02c7e2bSRoman Gareev 
1357b02c7e2bSRoman Gareev   return false;
1358b02c7e2bSRoman Gareev }
1359b02c7e2bSRoman Gareev 
1360b02c7e2bSRoman Gareev /// Check that all memory accesses of the statement, except from the last
1361b02c7e2bSRoman Gareev /// one, are read memory accesses, which read elements of operands of the tensor
1362b02c7e2bSRoman Gareev /// contraction and its result.
1363b02c7e2bSRoman Gareev ///
1364b02c7e2bSRoman Gareev /// @param Domain        The domain of the statement.
1365b02c7e2bSRoman Gareev /// @param Stmt          The statement, which writes to memory.
1366b02c7e2bSRoman Gareev /// @param TCI           The information about the tensor contraction.
1367b02c7e2bSRoman Gareev /// @param IandJIndexSet The set, which contains free indexes of tensors.
1368b02c7e2bSRoman Gareev /// @return True if all read memory accesses of the statement @p Stmt correspond
1369b02c7e2bSRoman Gareev ///         to the tensor contraction.
1370b02c7e2bSRoman Gareev static bool setReadAccesses(isl::set Domain, ScopStmt *Stmt, TCInfoTy &TCI,
1371b02c7e2bSRoman Gareev                             SmallDenseSet<int> &IandJIndexSet) {
1372b02c7e2bSRoman Gareev   TCI.A = nullptr;
1373b02c7e2bSRoman Gareev   TCI.B = nullptr;
1374b02c7e2bSRoman Gareev   TCI.ReadFromC = nullptr;
1375b02c7e2bSRoman Gareev   SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(*Stmt);
1376b02c7e2bSRoman Gareev   for (auto *MemA = Accesses.begin(); *MemA != TCI.WriteToC; MemA++) {
1377b02c7e2bSRoman Gareev     MemoryAccess *MemAccessPtr = *MemA;
1378b02c7e2bSRoman Gareev 
1379b02c7e2bSRoman Gareev     // All memory accesses, except from the last one, should be read memory
1380b02c7e2bSRoman Gareev     // accesses.
1381b02c7e2bSRoman Gareev     if (MemAccessPtr->isWrite())
1382b02c7e2bSRoman Gareev       return false;
1383b02c7e2bSRoman Gareev 
1384b02c7e2bSRoman Gareev     isl::map AccMap = MemAccessPtr->getLatestAccessRelation();
1385b02c7e2bSRoman Gareev 
1386b02c7e2bSRoman Gareev     if (!MemAccessPtr->isLatestArrayKind()) {
1387b02c7e2bSRoman Gareev       // Check whether the scalar read memory access is not partial.
1388b02c7e2bSRoman Gareev       if (!Domain.is_subset(AccMap.domain()))
1389b02c7e2bSRoman Gareev         return false;
1390b02c7e2bSRoman Gareev       continue;
1391b02c7e2bSRoman Gareev       return false;
1392b02c7e2bSRoman Gareev     }
1393b02c7e2bSRoman Gareev 
1394b02c7e2bSRoman Gareev     // There is only one memory access, which reads elements of the result of
1395b02c7e2bSRoman Gareev     // the tensor contraction.
1396b02c7e2bSRoman Gareev     if (AccMap.is_equal(TCI.WriteToC->getLatestAccessRelation())) {
1397b02c7e2bSRoman Gareev       if (TCI.ReadFromC)
1398b02c7e2bSRoman Gareev         return false;
1399b02c7e2bSRoman Gareev       TCI.ReadFromC = MemAccessPtr;
1400b02c7e2bSRoman Gareev       continue;
1401b02c7e2bSRoman Gareev     }
1402b02c7e2bSRoman Gareev 
1403b02c7e2bSRoman Gareev     SmallVector<int> Dimensions;
1404b02c7e2bSRoman Gareev     SmallDenseSet<int> IndexSet;
1405b02c7e2bSRoman Gareev     if (!isTCOperandAcc(Domain, AccMap, IndexSet, TCI.DimensionSizes,
1406b02c7e2bSRoman Gareev                         Dimensions))
1407b02c7e2bSRoman Gareev       return false;
1408b02c7e2bSRoman Gareev 
1409b02c7e2bSRoman Gareev     if (!setReadAccess(MemAccessPtr, IndexSet, IandJIndexSet, Dimensions, TCI))
1410b02c7e2bSRoman Gareev       return false;
1411b02c7e2bSRoman Gareev   }
1412b02c7e2bSRoman Gareev 
1413b02c7e2bSRoman Gareev   // Check that there are read memory accesses, which read elements of operands
1414b02c7e2bSRoman Gareev   // of the tensor contraction and its result.
1415b02c7e2bSRoman Gareev   return TCI.ReadFromC && TCI.A && TCI.B;
1416b02c7e2bSRoman Gareev }
1417b02c7e2bSRoman Gareev 
1418b02c7e2bSRoman Gareev /// Check accesses to operands of the tensor contraction.
1419b02c7e2bSRoman Gareev ///
1420b02c7e2bSRoman Gareev /// Check that accesses of the SCoP statement, which corresponds to
1421b02c7e2bSRoman Gareev /// the partial schedule @p PartialSchedule, represent accesses
1422b02c7e2bSRoman Gareev /// to the non-scalar operands of the tensor contraction.
1423b02c7e2bSRoman Gareev ///
1424b02c7e2bSRoman Gareev /// @param  Domain          The domain of the SCoP statement.
1425b02c7e2bSRoman Gareev /// @param  PartialSchedule The partial schedule of the SCoP statement.
1426b02c7e2bSRoman Gareev /// @param  TCI             Parameters of the tensor contraction operands.
1427b02c7e2bSRoman Gareev /// @return                 True if the corresponding SCoP statement
1428b02c7e2bSRoman Gareev ///                         represents tensor contraction and false,
1429b02c7e2bSRoman Gareev ///                         otherwise.
1430b02c7e2bSRoman Gareev static bool containsOnlyTCAcc(isl::set Domain, isl::map PartialSchedule,
1431b02c7e2bSRoman Gareev                               TCInfoTy &TCI) {
1432b02c7e2bSRoman Gareev   isl::id InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in);
1433b02c7e2bSRoman Gareev   ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
1434b02c7e2bSRoman Gareev 
1435b02c7e2bSRoman Gareev   // In region statements, the order of memory accesses execution is not
1436b02c7e2bSRoman Gareev   // predictable at compile-time.
1437b02c7e2bSRoman Gareev   if ((Stmt->size() <= 1) || Stmt->isRegionStmt())
1438b02c7e2bSRoman Gareev     return false;
1439b02c7e2bSRoman Gareev 
1440b02c7e2bSRoman Gareev   unsigned DimNum = unsignedFromIslSize(PartialSchedule.dim(isl::dim::in));
1441b02c7e2bSRoman Gareev   TCI.DimensionSizes.resize(DimNum);
1442b02c7e2bSRoman Gareev   SmallDenseSet<int> IandJIndexSet;
1443b02c7e2bSRoman Gareev 
1444b02c7e2bSRoman Gareev   TCI.WriteToC = getWriteAccess(Domain, Stmt, TCI, IandJIndexSet);
1445b02c7e2bSRoman Gareev   if (!TCI.WriteToC)
1446b02c7e2bSRoman Gareev     return false;
1447b02c7e2bSRoman Gareev 
1448b02c7e2bSRoman Gareev   if (intersect(IandJIndexSet, TCI.P).size() != 0)
1449b02c7e2bSRoman Gareev     return false;
1450b02c7e2bSRoman Gareev 
1451b02c7e2bSRoman Gareev   if (!setReadAccesses(Domain, Stmt, TCI, IandJIndexSet))
1452b02c7e2bSRoman Gareev     return false;
1453b02c7e2bSRoman Gareev 
1454b02c7e2bSRoman Gareev   return true;
1455b02c7e2bSRoman Gareev }
1456b02c7e2bSRoman Gareev 
1457b02c7e2bSRoman Gareev /// Check that dependency corresponds to the tensor contraction carried over
1458b02c7e2bSRoman Gareev /// loop dimension @p Dim.
1459b02c7e2bSRoman Gareev ///
1460b02c7e2bSRoman Gareev /// Check that the dependency has the form
1461b02c7e2bSRoman Gareev /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) ->
1462b02c7e2bSRoman Gareev /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP
1463b02c7e2bSRoman Gareev /// statement. For this purpose, we analyze the set @p DepDelta, which
1464b02c7e2bSRoman Gareev /// represents the differences between image elements and domain elements of
1465b02c7e2bSRoman Gareev /// the corresponding map.
1466b02c7e2bSRoman Gareev ///
1467b02c7e2bSRoman Gareev /// @param  DepDelta    The set contains the differences between image elements
1468b02c7e2bSRoman Gareev ///                     and corresponding domain elements of the map, which
1469b02c7e2bSRoman Gareev ///                     represents the dependency.
1470b02c7e2bSRoman Gareev /// @param  Dim         The position of the index ki.
1471b02c7e2bSRoman Gareev /// @param  BoundDeltas In the case of indexes of ki, the difference between
1472b02c7e2bSRoman Gareev ///                     image elements and corresponding domain elements
1473b02c7e2bSRoman Gareev ///                     corresponds to the difference between lexicographic
1474b02c7e2bSRoman Gareev ///                     minimum and lexicographic maximum of the corresponding
1475b02c7e2bSRoman Gareev ///                     dimension of the domain of the statement.
1476b02c7e2bSRoman Gareev /// @param  IndexSet    Obtained indexes ki, which describe the dependency.
1477b02c7e2bSRoman Gareev /// @return True if dependencies correspond to the tensor contraction
1478b02c7e2bSRoman Gareev ///         and false, otherwise.
1479b02c7e2bSRoman Gareev static bool isReductionCarriedOverDim(isl::set DepDelta, unsigned Dim,
1480b02c7e2bSRoman Gareev                                       isl::pw_multi_aff BoundDeltas,
1481b02c7e2bSRoman Gareev                                       const SmallDenseSet<int> &IndexSet) {
1482b02c7e2bSRoman Gareev   isl::space Space = DepDelta.get_space();
1483b02c7e2bSRoman Gareev   isl::set Superset = isl::set::universe(Space);
1484b02c7e2bSRoman Gareev   for (unsigned i = 0; i < Dim; i += 1)
1485b02c7e2bSRoman Gareev     Superset = Superset.fix_si(isl::dim::set, i, 0);
1486b02c7e2bSRoman Gareev   Superset = Superset.fix_si(isl::dim::set, Dim, 1);
1487b02c7e2bSRoman Gareev 
1488b02c7e2bSRoman Gareev   // Check that the difference between the image element and the domain element
1489b02c7e2bSRoman Gareev   // is equal to one in the case of the index ki. Image elements and
1490b02c7e2bSRoman Gareev   // corresponding domain elements should be equal in the case of positions,
1491b02c7e2bSRoman Gareev   // which are lower than the specified position.
1492b02c7e2bSRoman Gareev   if (!DepDelta.is_subset(Superset))
1493b02c7e2bSRoman Gareev     return false;
1494b02c7e2bSRoman Gareev 
1495b02c7e2bSRoman Gareev   // Compute a set, which is used to analyze how values of
1496b02c7e2bSRoman Gareev   // the domain are related to the map that describes the dependency.
1497b02c7e2bSRoman Gareev   isl_pw_multi_aff *DepDeltaPW = isl_pw_multi_aff_from_set(DepDelta.copy());
1498b02c7e2bSRoman Gareev   BoundDeltas = BoundDeltas.add(isl::manage(DepDeltaPW));
1499b02c7e2bSRoman Gareev   isl_set *ComplementRawSet = isl_set_from_pw_multi_aff(BoundDeltas.release());
1500b02c7e2bSRoman Gareev   isl::set Complement = isl::manage(ComplementRawSet);
1501b02c7e2bSRoman Gareev 
1502b02c7e2bSRoman Gareev   for (unsigned i : rangeIslSize(Dim + 1, DepDelta.dim(isl::dim::set))) {
1503b02c7e2bSRoman Gareev     if (!IndexSet.count(i)) {
1504b02c7e2bSRoman Gareev       // Check the difference between the image element and the domain element
1505b02c7e2bSRoman Gareev       // in the case of indexes, which do not describe the dependency.
1506b02c7e2bSRoman Gareev       if (DepDelta.plain_get_val_if_fixed(isl::dim::set, i).is_zero())
1507b02c7e2bSRoman Gareev         continue;
1508b02c7e2bSRoman Gareev       return false;
1509b02c7e2bSRoman Gareev     }
1510b02c7e2bSRoman Gareev 
1511b02c7e2bSRoman Gareev     // In the case of other indexes, which describe the dependency,
1512b02c7e2bSRoman Gareev     // the difference between the image element and the domain element
1513b02c7e2bSRoman Gareev     // should be equal to the difference between lexicographic minimum and
1514b02c7e2bSRoman Gareev     // lexicographic maximum of the domain of the statement.
1515b02c7e2bSRoman Gareev     if (!Complement.plain_get_val_if_fixed(isl::dim::set, i).is_zero())
1516b02c7e2bSRoman Gareev       return false;
1517b02c7e2bSRoman Gareev   }
1518b02c7e2bSRoman Gareev 
1519b02c7e2bSRoman Gareev   return true;
1520b02c7e2bSRoman Gareev }
1521b02c7e2bSRoman Gareev 
1522b02c7e2bSRoman Gareev /// Check whether dependencies are over the complete domain.
1523b02c7e2bSRoman Gareev ///
1524b02c7e2bSRoman Gareev /// In the case of the tensor contraction RAW, WAW, WAR dependencies
1525b02c7e2bSRoman Gareev /// have the form
1526b02c7e2bSRoman Gareev /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) ->
1527b02c7e2bSRoman Gareev /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP
1528b02c7e2bSRoman Gareev /// statement. Consequently, the domain of the dependencies
1529b02c7e2bSRoman Gareev /// can be described as
1530b02c7e2bSRoman Gareev /// Domain / Domain ∩ S(…, max(kn),…) ∩ S(…, max(k(i + 1)),…),
1531b02c7e2bSRoman Gareev /// where Domain is the domain of the statement S.
1532b02c7e2bSRoman Gareev ///
1533b02c7e2bSRoman Gareev /// For example, in the case of the following tensor contraction,
1534b02c7e2bSRoman Gareev /// corresponding domains will have the following form.
1535b02c7e2bSRoman Gareev ///
1536b02c7e2bSRoman Gareev /// An example of the tensor contraction:
1537b02c7e2bSRoman Gareev /// for (i = 0; i < 1024; i++)
1538b02c7e2bSRoman Gareev ///   for (j = 0; j < 1024; j++)
1539b02c7e2bSRoman Gareev ///     for (l = 0; l < 64; ++l)
1540b02c7e2bSRoman Gareev ///       for (w = 0; w < 64; ++w)
1541b02c7e2bSRoman Gareev ///         C[i][j] += A[i][l][w] * B[w][j][l];
1542b02c7e2bSRoman Gareev ///
1543b02c7e2bSRoman Gareev /// The domain of the statement:
1544b02c7e2bSRoman Gareev /// { S[i0, i1, i2, i3] : i0 >= 0 and i0 <= 1023 and
1545b02c7e2bSRoman Gareev ///                       i1 >= 0 and i1 <= 1023 and
1546b02c7e2bSRoman Gareev ///                       i2 >= 0 and i2 <= 63 and
1547b02c7e2bSRoman Gareev ///                       i3 >= 0 and i3 <= 63 }
1548b02c7e2bSRoman Gareev ///
1549b02c7e2bSRoman Gareev /// The domain of the dependencies:
1550b02c7e2bSRoman Gareev /// { S[i0, i1, i2, i3] : (i0 >= 0 and i0 <= 1023 and
1551b02c7e2bSRoman Gareev ///                        i1 >= 0 and i1 <= 1023 and
1552b02c7e2bSRoman Gareev ///                        i2 >= 0 and i2 <= 63 and
1553b02c7e2bSRoman Gareev ///                        i3 >= 0 and i3 <= 62) or
1554b02c7e2bSRoman Gareev ///                       (i3 = 63 and i0 >= 0 and i0 <= 1023 and
1555b02c7e2bSRoman Gareev ///                        i1 >= 0 and i1 <= 1023 and
1556b02c7e2bSRoman Gareev ///                        i2 >= 0 and i2 <= 62) }
1557b02c7e2bSRoman Gareev ///
1558b02c7e2bSRoman Gareev /// @param  Domain       The domain of the statement.
1559b02c7e2bSRoman Gareev /// @param  DepsForStmt  RAW and RED dependencies for the statement.
1560b02c7e2bSRoman Gareev /// @param  UpperBound   The lexicographic maximum of the elements in
1561b02c7e2bSRoman Gareev ///                      the @p Domain.
1562b02c7e2bSRoman Gareev /// @param  IndexSet     Obtained indexes ki, which describe the dependencies.
1563b02c7e2bSRoman Gareev /// @return True if dependencies are over the complete domain
1564b02c7e2bSRoman Gareev ///         and false, otherwise.
1565b02c7e2bSRoman Gareev static bool areDepsOverCompleteDomain(isl::set Domain, isl::map DepsForStmt,
1566b02c7e2bSRoman Gareev                                       isl::pw_multi_aff UpperBound,
1567b02c7e2bSRoman Gareev                                       SmallDenseSet<int> &IndexSet) {
1568b02c7e2bSRoman Gareev   isl_set *UpperBoundRawSet = isl_set_from_pw_multi_aff(UpperBound.copy());
1569b02c7e2bSRoman Gareev   isl::set UpperBoundSet = isl::manage(UpperBoundRawSet);
1570b02c7e2bSRoman Gareev 
1571b02c7e2bSRoman Gareev   isl::set DomainRed = isl::manage(Domain.copy());
1572b02c7e2bSRoman Gareev   for (const auto It : IndexSet) {
1573b02c7e2bSRoman Gareev     isl::val FixedVal = UpperBoundSet.plain_get_val_if_fixed(isl::dim::set, It);
1574b02c7e2bSRoman Gareev     if (FixedVal.is_nan())
1575b02c7e2bSRoman Gareev       return false;
1576b02c7e2bSRoman Gareev     DomainRed = isl::manage(
1577b02c7e2bSRoman Gareev         isl_set_fix_val(DomainRed.copy(), isl_dim_set, It, FixedVal.release()));
1578b02c7e2bSRoman Gareev   }
1579b02c7e2bSRoman Gareev   return DepsForStmt.domain().intersect(Domain).is_equal(
1580b02c7e2bSRoman Gareev       Domain.subtract(DomainRed));
1581b02c7e2bSRoman Gareev }
1582b02c7e2bSRoman Gareev 
1583b02c7e2bSRoman Gareev /// Check that dependencies correspond to the tensor contraction.
1584b02c7e2bSRoman Gareev ///
1585b02c7e2bSRoman Gareev /// Check that there are only true dependencies of the form
1586b02c7e2bSRoman Gareev /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) ->
1587b02c7e2bSRoman Gareev /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP
1588b02c7e2bSRoman Gareev /// statement represented by @p Schedule. Such dependencies are produced by
1589b02c7e2bSRoman Gareev /// the tensor contraction. Obtained indexes ki are stored into @p IndexSet.
1590b02c7e2bSRoman Gareev ///
1591b02c7e2bSRoman Gareev /// The form of anti and output dependencies is specified implicitly by
1592b02c7e2bSRoman Gareev /// the form the SCoP statement, which is checked by subsequent analysis.
1593b02c7e2bSRoman Gareev ///
1594b02c7e2bSRoman Gareev /// @param  Schedule The schedule of the SCoP statement.
1595b02c7e2bSRoman Gareev /// @param  D        The SCoP dependencies.
1596b02c7e2bSRoman Gareev /// @param  Domain   The domain of the statement.
1597b02c7e2bSRoman Gareev /// @param  IndexSet Obtained indexes ki, which describe the dependencies.
1598b02c7e2bSRoman Gareev /// @return True if dependencies correspond to the tensor contraction
1599b02c7e2bSRoman Gareev ///         and false, otherwise.
1600b02c7e2bSRoman Gareev static bool containsOnlyTcDeps(isl::map Schedule, const Dependences *D,
1601b02c7e2bSRoman Gareev                                SmallDenseSet<int> &IndexSet, isl::set Domain) {
1602b02c7e2bSRoman Gareev   IslMaxOperationsGuard MaxOpGuard(Schedule.ctx().get(), OptComputeOut);
1603b02c7e2bSRoman Gareev 
1604b02c7e2bSRoman Gareev   isl::union_map Dep =
1605b02c7e2bSRoman Gareev       D->getDependences(Dependences::TYPE_RAW | Dependences::TYPE_RED);
1606b02c7e2bSRoman Gareev 
1607b02c7e2bSRoman Gareev   isl::space DomainSpace = Schedule.get_space().domain();
1608b02c7e2bSRoman Gareev   isl::space Space = DomainSpace.map_from_domain_and_range(DomainSpace);
1609b02c7e2bSRoman Gareev   isl::map DepsForStmt = Dep.extract_map(Space);
1610b02c7e2bSRoman Gareev   isl::set DepDeltas = DepsForStmt.deltas();
1611b02c7e2bSRoman Gareev   isl::size DeltasDimNum = DepDeltas.dim(isl::dim::set);
1612b02c7e2bSRoman Gareev   isl::pw_multi_aff LowerBound = Domain.lexmin_pw_multi_aff();
1613b02c7e2bSRoman Gareev   isl::pw_multi_aff UpperBound = Domain.lexmax_pw_multi_aff();
1614b02c7e2bSRoman Gareev   isl::pw_multi_aff BoundDeltas = UpperBound.sub(LowerBound);
1615b02c7e2bSRoman Gareev 
1616b02c7e2bSRoman Gareev   for (int i : reverse(rangeIslSize(0, DeltasDimNum))) {
1617b02c7e2bSRoman Gareev     // In the case of the tensor contraction, the difference between image
1618b02c7e2bSRoman Gareev     // elements and domain elements lies on a hyperplane where a dimension
1619b02c7e2bSRoman Gareev     // has the fixed value one.
1620b02c7e2bSRoman Gareev     isl::set Intersection = DepDeltas.fix_si(isl::dim::set, i, 1);
1621b02c7e2bSRoman Gareev     if (Intersection.is_empty())
1622b02c7e2bSRoman Gareev       continue;
1623b02c7e2bSRoman Gareev 
1624b02c7e2bSRoman Gareev     if (!isReductionCarriedOverDim(Intersection, i, BoundDeltas, IndexSet))
1625b02c7e2bSRoman Gareev       return false;
1626b02c7e2bSRoman Gareev 
1627b02c7e2bSRoman Gareev     IndexSet.insert(i);
1628b02c7e2bSRoman Gareev     DepDeltas = DepDeltas.subtract(Intersection);
1629b02c7e2bSRoman Gareev   }
1630b02c7e2bSRoman Gareev 
1631b02c7e2bSRoman Gareev   // In the case of the tensor contraction, all dependencies should have
1632b02c7e2bSRoman Gareev   // the previously described form.
1633b02c7e2bSRoman Gareev   if ((unsignedFromIslSize(DeltasDimNum) == 0) || !DepDeltas.is_empty())
1634b02c7e2bSRoman Gareev     return false;
1635b02c7e2bSRoman Gareev 
1636b02c7e2bSRoman Gareev   return areDepsOverCompleteDomain(Domain, DepsForStmt, UpperBound, IndexSet);
1637b02c7e2bSRoman Gareev }
1638b02c7e2bSRoman Gareev 
1639b02c7e2bSRoman Gareev /// Check if the SCoP statement could probably be optimized with analytical
1640b02c7e2bSRoman Gareev /// modeling.
1641b02c7e2bSRoman Gareev ///
1642b02c7e2bSRoman Gareev /// containsTCInfoTy tries to determine whether the following conditions
1643b02c7e2bSRoman Gareev /// are true:
1644b02c7e2bSRoman Gareev ///
1645b02c7e2bSRoman Gareev /// 1. The last memory access modeling an array, MA1, represents writing to
1646b02c7e2bSRoman Gareev ///    memory and has the form S(..., I, ..., J, ...) -> M(shuffle(I, J)),
1647b02c7e2bSRoman Gareev ///    where S is the SCoP statement under consideration and shuffle(I, J)
1648b02c7e2bSRoman Gareev ///    is a permutation of indexes of sets I and J.
1649b02c7e2bSRoman Gareev /// 2. There are only true dependencies of the form
1650b02c7e2bSRoman Gareev ///    S(..., ki, max(k(i + 1)), ..., max(kn), ...) ->
1651b02c7e2bSRoman Gareev ///    S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP
1652b02c7e2bSRoman Gareev ///    statement represented by @p Schedule and ki are indexes of the set P.
1653b02c7e2bSRoman Gareev /// 3. SCoP contains an arbitrary number of reads from constants and only three
1654b02c7e2bSRoman Gareev ///    access relations, MA2, MA3, and MA4 that represent reading from memory
1655b02c7e2bSRoman Gareev ///    and have the form
1656b02c7e2bSRoman Gareev ///    S(..., I, ..., P, ...) -> M(shuffle(I, P)),
1657b02c7e2bSRoman Gareev ///    S(..., P, ..., J, ...) -> M(shuffle(J, P)),
1658b02c7e2bSRoman Gareev ///    S(...) -> M(shuffle(I, J)), respectively.
1659b02c7e2bSRoman Gareev ///
1660b02c7e2bSRoman Gareev /// @param  PartialSchedule The PartialSchedule that contains a SCoP statement
1661b02c7e2bSRoman Gareev ///                         to check.
1662b02c7e2bSRoman Gareev /// @param  D               The SCoP dependencies.
1663b02c7e2bSRoman Gareev /// @param  TCI             Parameters of the tensor contraction operands.
1664b02c7e2bSRoman Gareev /// @param  Domain          The domain of the statement.
1665b02c7e2bSRoman Gareev /// @return True if dependencies and memory accesses correspond to the tensor
1666b02c7e2bSRoman Gareev ///              contraction and false, otherwise.
1667b02c7e2bSRoman Gareev static bool containsTCInfoTy(isl::map PartialSchedule, const Dependences *D,
1668b02c7e2bSRoman Gareev                              TCInfoTy &TCI, isl::set Domain) {
1669b02c7e2bSRoman Gareev   if (!containsOnlyTcDeps(PartialSchedule, D, TCI.P, Domain))
1670b02c7e2bSRoman Gareev     return false;
1671b02c7e2bSRoman Gareev 
1672b02c7e2bSRoman Gareev   // TODO: handle cases of scalar multiplication if needed.
1673b02c7e2bSRoman Gareev   if (TCI.P.size() == 0)
1674b02c7e2bSRoman Gareev     return false;
1675b02c7e2bSRoman Gareev 
1676b02c7e2bSRoman Gareev   if (!containsOnlyTCAcc(Domain, PartialSchedule, TCI))
1677b02c7e2bSRoman Gareev     return false;
1678b02c7e2bSRoman Gareev 
1679b02c7e2bSRoman Gareev   // TODO: handle cases of GEMV if needed.
1680b02c7e2bSRoman Gareev   if ((TCI.I.size() == 0) || (TCI.J.size() == 0))
1681b02c7e2bSRoman Gareev     return false;
1682b02c7e2bSRoman Gareev 
1683b02c7e2bSRoman Gareev   return true;
1684b02c7e2bSRoman Gareev }
1685b02c7e2bSRoman Gareev 
1686b02c7e2bSRoman Gareev /// Check if this node contains a partial schedule that could
1687b02c7e2bSRoman Gareev /// probably be optimized with analytical modeling.
1688b02c7e2bSRoman Gareev ///
1689b02c7e2bSRoman Gareev /// isTCPattern is used to determine whether the SCoP represents a TC-like
1690b02c7e2bSRoman Gareev /// kernel [1], which is a perfectly nested set of loops, with a data usage
1691b02c7e2bSRoman Gareev /// pattern that is similar to that produced by the tensor contraction.
1692b02c7e2bSRoman Gareev ///
1693b02c7e2bSRoman Gareev /// A TC-like kernel can be defined as follows:
1694b02c7e2bSRoman Gareev ///
1695b02c7e2bSRoman Gareev /// 1. It satisfies the requirements of the polyhedral model.
1696b02c7e2bSRoman Gareev /// 2. Without loss of generality, it contains three nonempty bundles of
1697b02c7e2bSRoman Gareev ///    one-dimensional for-loops with induction variables that are grouped into
1698b02c7e2bSRoman Gareev ///    bundles I = i0...i(r-1), J = j0..j(s-1), and P = p0...p(t-1), and they
1699b02c7e2bSRoman Gareev ///    are incremented by one.
1700b02c7e2bSRoman Gareev /// 3. The innermost loop body can be represented as a statement of the form
1701b02c7e2bSRoman Gareev ///    C(shuffle(I, J)) = E(A(shuffle(I, P)), B(shuffle(P, J)),
1702b02c7e2bSRoman Gareev ///    C(shuffle(I, J))), where A(shuffle(I, P)), B(shuffle(P, J)),
1703b02c7e2bSRoman Gareev ///    C(shuffle(I, J)) are accesses to tensors A, B, C, respectively,
1704b02c7e2bSRoman Gareev ///    shuffle(I, J), shuffle(I, P), and shuffle(P, J) are permutations of the
1705b02c7e2bSRoman Gareev ///    enclosed indices, and E is an expression that contains reads from
1706b02c7e2bSRoman Gareev ///    the tensors A, B, C, and an arbitrary number of reads from constants
1707b02c7e2bSRoman Gareev ///    with respect to bundles I, J, and P.
1708b02c7e2bSRoman Gareev ///
1709b02c7e2bSRoman Gareev /// TC can be considered as a particular case of a TC-like kernel.
1710b02c7e2bSRoman Gareev ///
1711b02c7e2bSRoman Gareev /// The order of loops with indexes from P should be preserved. Otherwise,
1712b02c7e2bSRoman Gareev /// isTCPattern should check if a commutative operation is used.
1713b02c7e2bSRoman Gareev ///
1714b02c7e2bSRoman Gareev /// isTCPattern performs the following steps to check whether the SCoP
1715b02c7e2bSRoman Gareev /// corresponds to a definition of a TC-like kernel:
1716b02c7e2bSRoman Gareev ///
1717b02c7e2bSRoman Gareev /// 1. Checks that the node is the innermost band node.
1718b02c7e2bSRoman Gareev /// 2. Checks that the partial schedule contains only one statement.
1719b02c7e2bSRoman Gareev /// 3. Check that all ancestors of the node contain all band nodes for
1720b02c7e2bSRoman Gareev ///    the statement and only mark nodes interleave such band nodes. This
1721b02c7e2bSRoman Gareev ///    corresponds to a straightforward implementation of TC.
1722b02c7e2bSRoman Gareev /// 4. Analyses the dependencies to determine contraction dimensions.
1723b02c7e2bSRoman Gareev /// 5. Check that the last memory access modeling an array, represents writing
1724b02c7e2bSRoman Gareev ///    to the result of the TC-like kernel.
1725b02c7e2bSRoman Gareev /// 6. Check that SCoP contains only three access relations that represent
1726b02c7e2bSRoman Gareev ///    reading of the operands of the TC-like kernel and an arbitrary number of
1727b02c7e2bSRoman Gareev ///    reads from constants.
1728b02c7e2bSRoman Gareev ///
1729b02c7e2bSRoman Gareev /// [1] - Gareev R., Grosser T., Kruse M. High-Performance Generalized Tensor
1730b02c7e2bSRoman Gareev ///       Operations: A Compiler-Oriented Approach // ACM Transactions
1731b02c7e2bSRoman Gareev ///       Architecture and Code Optimization (TACO). 2018.
1732b02c7e2bSRoman Gareev ///       Vol. 15, no. 3. P. 34:1–34:27. DOI: 10.1145/3235029.
1733b02c7e2bSRoman Gareev ///
1734b02c7e2bSRoman Gareev /// If this is the case, we could logically represent tensors as matrices and
1735b02c7e2bSRoman Gareev /// apply algorithms, which are used to get close-to-peak performance of
1736b02c7e2bSRoman Gareev /// matrix multiplications in manually tuned BLAS libraries (e.g., BLIS).
1737b02c7e2bSRoman Gareev ///
1738b02c7e2bSRoman Gareev /// @param Node The node to check.
1739b02c7e2bSRoman Gareev /// @param D    The SCoP dependencies.
1740b02c7e2bSRoman Gareev /// @param TCI  Parameters of the tensor contraction operands.
1741b02c7e2bSRoman Gareev static bool isTCPattern(isl::schedule_node Node, const Dependences *D,
1742b02c7e2bSRoman Gareev                         TCInfoTy &TCI) {
1743b02c7e2bSRoman Gareev   Node = Node.child(0);
1744b02c7e2bSRoman Gareev   isl::union_map PartialSchedule = Node.get_prefix_schedule_union_map();
1745b02c7e2bSRoman Gareev   isl::union_set Domain = Node.domain();
1746b02c7e2bSRoman Gareev   Node = Node.parent();
1747b02c7e2bSRoman Gareev 
1748b02c7e2bSRoman Gareev   // The partial schedule should contain only one statement.
1749b02c7e2bSRoman Gareev   // TODO: This constraint should not be intrinsic to the algorithm.
1750b02c7e2bSRoman Gareev   if (isl_union_set_n_set(Domain.get()) != 1)
1751b02c7e2bSRoman Gareev     return false;
1752b02c7e2bSRoman Gareev 
1753b02c7e2bSRoman Gareev   isl_schedule_node_type NodeType = isl_schedule_node_get_type(Node.get());
1754b02c7e2bSRoman Gareev 
1755b02c7e2bSRoman Gareev   // Check that all ancestors of the node contain all band nodes for
1756b02c7e2bSRoman Gareev   // the statement, which represents the TC-like kernel, and only mark nodes
1757b02c7e2bSRoman Gareev   // interleave such band nodes. This corresponds to a straightforward
1758b02c7e2bSRoman Gareev   // implementation of TC with/without DeLICM applied.
1759b02c7e2bSRoman Gareev   //
1760b02c7e2bSRoman Gareev   // For example, this covers the matrix multiplication pattern after a full
1761b02c7e2bSRoman Gareev   // run of -polly-optree and -polly-delicm, where the write access is not
1762*5aafc6d5SChristian Clauss   // through the original memory access, but through a PHI node that was
1763b02c7e2bSRoman Gareev   // delicmed. Subsequently, such band nodes will be replaced by a single band
1764b02c7e2bSRoman Gareev   // node.
1765b02c7e2bSRoman Gareev   //
1766b02c7e2bSRoman Gareev   // The corresponding schedule can be the following, where Stmt_for_body8
1767b02c7e2bSRoman Gareev   // contains the matrix multiplication:
1768b02c7e2bSRoman Gareev   //
1769b02c7e2bSRoman Gareev   // domain: "{ Stmt_for_body8[i0, i1, i2]  : 0 <= i0 <= 1599 and
1770b02c7e2bSRoman Gareev   //                                          0 <= i1 <= 1799 and
1771b02c7e2bSRoman Gareev   //                                          0 <= i2 <= 2199;
1772b02c7e2bSRoman Gareev   //            Stmt_for_body3[i0, i1] :      0 <= i0 <= 1599 and
1773b02c7e2bSRoman Gareev   //                                          0 <= i1 <= 1799;
1774b02c7e2bSRoman Gareev   //            Stmt_for_body3_last[i0, i1] : 0 <= i0 <= 1599 and
1775b02c7e2bSRoman Gareev   //                                          0 <= i1 <= 1799 }"
1776b02c7e2bSRoman Gareev   // child:
1777b02c7e2bSRoman Gareev   //  sequence:
1778b02c7e2bSRoman Gareev   //  - filter: "{ Stmt_for_body3[i0, i1] }"
1779b02c7e2bSRoman Gareev   //    child:
1780b02c7e2bSRoman Gareev   //      schedule: "[{ Stmt_for_body3[i0, i1] -> [(i0)] },
1781b02c7e2bSRoman Gareev   //                  { Stmt_for_body3[i0, i1] -> [(i1)] }]"
1782b02c7e2bSRoman Gareev   //      permutable: 1
1783b02c7e2bSRoman Gareev   //      coincident: [ 1, 1 ]
1784b02c7e2bSRoman Gareev   //  - filter: "{ Stmt_for_body3_last[i0, i1] }"
1785b02c7e2bSRoman Gareev   //    child:
1786b02c7e2bSRoman Gareev   //      schedule: "[{ Stmt_for_body3_last[i0, i1] -> [(i0)] },
1787b02c7e2bSRoman Gareev   //                  { Stmt_for_body3_last[i0, i1] -> [(i1)] }]"
1788b02c7e2bSRoman Gareev   //      permutable: 1
1789b02c7e2bSRoman Gareev   //      coincident: [ 1, 1 ]
1790b02c7e2bSRoman Gareev   //  - filter: "{ Stmt_for_body8[i0, i1, i2] }"
1791b02c7e2bSRoman Gareev   //    child:
1792b02c7e2bSRoman Gareev   //      schedule: "[{ Stmt_for_body8[i0, i1, i2] -> [(i0)] },
1793b02c7e2bSRoman Gareev   //                  { Stmt_for_body8[i0, i1, i2] -> [(i1)] },
1794b02c7e2bSRoman Gareev   //                  { Stmt_for_body8[i0, i1, i2] -> [(i2)] }]"
1795b02c7e2bSRoman Gareev   //      permutable: 1
1796b02c7e2bSRoman Gareev   //      coincident: [ 1, 1, 0 ]
1797b02c7e2bSRoman Gareev   //
1798b02c7e2bSRoman Gareev   while (NodeType != isl_schedule_node_domain) {
1799b02c7e2bSRoman Gareev     if (NodeType == isl_schedule_node_filter) {
1800b02c7e2bSRoman Gareev       if (!Node.parent().isa<isl::schedule_node_sequence>() ||
1801b02c7e2bSRoman Gareev           !Node.parent().parent().isa<isl::schedule_node_domain>())
1802b02c7e2bSRoman Gareev         return false;
1803b02c7e2bSRoman Gareev       break;
1804b02c7e2bSRoman Gareev     }
1805b02c7e2bSRoman Gareev 
1806b02c7e2bSRoman Gareev     if ((NodeType != isl_schedule_node_band) &&
1807b02c7e2bSRoman Gareev         (NodeType != isl_schedule_node_mark))
1808b02c7e2bSRoman Gareev       return false;
1809b02c7e2bSRoman Gareev 
1810b02c7e2bSRoman Gareev     Node = Node.parent();
1811b02c7e2bSRoman Gareev     NodeType = isl_schedule_node_get_type(Node.get());
1812b02c7e2bSRoman Gareev   }
1813b02c7e2bSRoman Gareev 
1814b02c7e2bSRoman Gareev   isl::map PartialScheduleMap = isl::map::from_union_map(PartialSchedule);
1815b02c7e2bSRoman Gareev   if (containsTCInfoTy(PartialScheduleMap, D, TCI, isl::set(Domain)))
1816b02c7e2bSRoman Gareev     return true;
1817b02c7e2bSRoman Gareev 
1818b02c7e2bSRoman Gareev   return false;
1819b02c7e2bSRoman Gareev }
1820b02c7e2bSRoman Gareev 
1821d123e983SMichael Kruse } // namespace
1822d123e983SMichael Kruse 
1823d123e983SMichael Kruse isl::schedule_node
1824d123e983SMichael Kruse polly::tryOptimizeMatMulPattern(isl::schedule_node Node,
1825d123e983SMichael Kruse                                 const llvm::TargetTransformInfo *TTI,
1826d123e983SMichael Kruse                                 const Dependences *D) {
1827b02c7e2bSRoman Gareev   TCInfoTy TCI;
1828b02c7e2bSRoman Gareev   if (PMBasedTCOpts && isTCPattern(Node, D, TCI))
1829601d7eabSKarthika Devi C     POLLY_DEBUG(dbgs() << "The tensor contraction pattern was detected\n");
1830d123e983SMichael Kruse   MatMulInfoTy MMI;
1831b02c7e2bSRoman Gareev   if (PMBasedMMMOpts && isMatrMultPattern(Node, D, MMI)) {
1832601d7eabSKarthika Devi C     POLLY_DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
1833d123e983SMichael Kruse     return optimizeMatMulPattern(Node, TTI, MMI);
1834d123e983SMichael Kruse   }
1835d123e983SMichael Kruse   return {};
1836d123e983SMichael Kruse }
1837