1d123e983SMichael Kruse //===- MatmulOptimizer.cpp -----------------------------------------------===// 2d123e983SMichael Kruse // 3d123e983SMichael Kruse // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4d123e983SMichael Kruse // See https://llvm.org/LICENSE.txt for license information. 5d123e983SMichael Kruse // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6d123e983SMichael Kruse // 7d123e983SMichael Kruse //===----------------------------------------------------------------------===// 8d123e983SMichael Kruse 9d123e983SMichael Kruse #include "polly/MatmulOptimizer.h" 10d123e983SMichael Kruse #include "polly/DependenceInfo.h" 11d123e983SMichael Kruse #include "polly/Options.h" 12d123e983SMichael Kruse #include "polly/ScheduleTreeTransform.h" 13d123e983SMichael Kruse #include "polly/ScopInfo.h" 14d123e983SMichael Kruse #include "polly/ScopPass.h" 15d123e983SMichael Kruse #include "polly/Simplify.h" 16b02c7e2bSRoman Gareev #include "polly/Support/GICHelper.h" 17a56bd7deSMichael Kruse #include "polly/Support/ISLTools.h" 18d123e983SMichael Kruse #include "llvm/ADT/ArrayRef.h" 19b02c7e2bSRoman Gareev #include "llvm/ADT/DenseSet.h" 20d123e983SMichael Kruse #include "llvm/ADT/Sequence.h" 21b02c7e2bSRoman Gareev #include "llvm/ADT/SetOperations.h" 22d123e983SMichael Kruse #include "llvm/ADT/SmallVector.h" 23d123e983SMichael Kruse #include "llvm/ADT/StringRef.h" 24d123e983SMichael Kruse #include "llvm/ADT/iterator_range.h" 25d123e983SMichael Kruse #include "llvm/Analysis/TargetTransformInfo.h" 26d123e983SMichael Kruse #include "llvm/IR/DataLayout.h" 27d123e983SMichael Kruse #include "llvm/IR/Function.h" 28d123e983SMichael Kruse #include "llvm/IR/Module.h" 29d123e983SMichael Kruse #include "llvm/Support/CommandLine.h" 30d123e983SMichael Kruse #include "llvm/Support/Debug.h" 31d123e983SMichael Kruse #include "llvm/Support/TypeSize.h" 32d123e983SMichael Kruse #include "llvm/Support/raw_ostream.h" 33d123e983SMichael Kruse #include "isl/ctx.h" 34d123e983SMichael Kruse #include "isl/schedule_node.h" 35d123e983SMichael Kruse #include "isl/schedule_type.h" 36d123e983SMichael Kruse #include "isl/union_map.h" 37d123e983SMichael Kruse #include "isl/union_set.h" 38d123e983SMichael Kruse #include <algorithm> 39d123e983SMichael Kruse #include <cassert> 40d123e983SMichael Kruse #include <cmath> 41d123e983SMichael Kruse #include <cstdint> 42d123e983SMichael Kruse #include <string> 43d123e983SMichael Kruse #include <vector> 44d123e983SMichael Kruse 45601d7eabSKarthika Devi C #include "polly/Support/PollyDebug.h" 46d123e983SMichael Kruse #define DEBUG_TYPE "polly-opt-isl" 47d123e983SMichael Kruse 48d123e983SMichael Kruse using namespace llvm; 49d123e983SMichael Kruse using namespace polly; 50d123e983SMichael Kruse 51d123e983SMichael Kruse namespace llvm { 52d123e983SMichael Kruse class Value; 53d123e983SMichael Kruse } 54d123e983SMichael Kruse 55d123e983SMichael Kruse static cl::opt<int> LatencyVectorFma( 56d123e983SMichael Kruse "polly-target-latency-vector-fma", 57d123e983SMichael Kruse cl::desc("The minimal number of cycles between issuing two " 58d123e983SMichael Kruse "dependent consecutive vector fused multiply-add " 59d123e983SMichael Kruse "instructions."), 6095a13425SFangrui Song cl::Hidden, cl::init(8), cl::cat(PollyCategory)); 61d123e983SMichael Kruse 62d123e983SMichael Kruse static cl::opt<int> ThroughputVectorFma( 63d123e983SMichael Kruse "polly-target-throughput-vector-fma", 64d123e983SMichael Kruse cl::desc("A throughput of the processor floating-point arithmetic units " 65d123e983SMichael Kruse "expressed in the number of vector fused multiply-add " 66d123e983SMichael Kruse "instructions per clock cycle."), 6795a13425SFangrui Song cl::Hidden, cl::init(1), cl::cat(PollyCategory)); 68d123e983SMichael Kruse 69d123e983SMichael Kruse static cl::opt<int> FirstCacheLevelSize( 70d123e983SMichael Kruse "polly-target-1st-cache-level-size", 71d123e983SMichael Kruse cl::desc("The size of the first cache level specified in bytes."), 7236c7d79dSFangrui Song cl::Hidden, cl::init(-1), cl::cat(PollyCategory)); 73d123e983SMichael Kruse 74d123e983SMichael Kruse static cl::opt<int> FirstCacheLevelDefaultSize( 75d123e983SMichael Kruse "polly-target-1st-cache-level-default-size", 76d123e983SMichael Kruse cl::desc("The default size of the first cache level specified in bytes" 77d123e983SMichael Kruse " (if not enough were provided by the TargetTransformInfo)."), 7836c7d79dSFangrui Song cl::Hidden, cl::init(32768), cl::cat(PollyCategory)); 79d123e983SMichael Kruse 80d123e983SMichael Kruse static cl::opt<int> SecondCacheLevelSize( 81d123e983SMichael Kruse "polly-target-2nd-cache-level-size", 82d123e983SMichael Kruse cl::desc("The size of the second level specified in bytes."), cl::Hidden, 8336c7d79dSFangrui Song cl::init(-1), cl::cat(PollyCategory)); 84d123e983SMichael Kruse 85d123e983SMichael Kruse static cl::opt<int> SecondCacheLevelDefaultSize( 86d123e983SMichael Kruse "polly-target-2nd-cache-level-default-size", 87d123e983SMichael Kruse cl::desc("The default size of the second cache level specified in bytes" 88d123e983SMichael Kruse " (if not enough were provided by the TargetTransformInfo)."), 8936c7d79dSFangrui Song cl::Hidden, cl::init(262144), cl::cat(PollyCategory)); 90d123e983SMichael Kruse 91d123e983SMichael Kruse // This option, along with --polly-target-2nd-cache-level-associativity, 92d123e983SMichael Kruse // --polly-target-1st-cache-level-size, and --polly-target-2st-cache-level-size 93d123e983SMichael Kruse // represent the parameters of the target cache, which do not have typical 94d123e983SMichael Kruse // values that can be used by default. However, to apply the pattern matching 95d123e983SMichael Kruse // optimizations, we use the values of the parameters of Intel Core i7-3820 96d123e983SMichael Kruse // SandyBridge in case the parameters are not specified or not provided by the 97d123e983SMichael Kruse // TargetTransformInfo. 98d123e983SMichael Kruse static cl::opt<int> FirstCacheLevelAssociativity( 99d123e983SMichael Kruse "polly-target-1st-cache-level-associativity", 100d123e983SMichael Kruse cl::desc("The associativity of the first cache level."), cl::Hidden, 10136c7d79dSFangrui Song cl::init(-1), cl::cat(PollyCategory)); 102d123e983SMichael Kruse 103d123e983SMichael Kruse static cl::opt<int> FirstCacheLevelDefaultAssociativity( 104d123e983SMichael Kruse "polly-target-1st-cache-level-default-associativity", 105d123e983SMichael Kruse cl::desc("The default associativity of the first cache level" 106d123e983SMichael Kruse " (if not enough were provided by the TargetTransformInfo)."), 10736c7d79dSFangrui Song cl::Hidden, cl::init(8), cl::cat(PollyCategory)); 108d123e983SMichael Kruse 109d123e983SMichael Kruse static cl::opt<int> SecondCacheLevelAssociativity( 110d123e983SMichael Kruse "polly-target-2nd-cache-level-associativity", 111d123e983SMichael Kruse cl::desc("The associativity of the second cache level."), cl::Hidden, 11236c7d79dSFangrui Song cl::init(-1), cl::cat(PollyCategory)); 113d123e983SMichael Kruse 114d123e983SMichael Kruse static cl::opt<int> SecondCacheLevelDefaultAssociativity( 115d123e983SMichael Kruse "polly-target-2nd-cache-level-default-associativity", 116d123e983SMichael Kruse cl::desc("The default associativity of the second cache level" 117d123e983SMichael Kruse " (if not enough were provided by the TargetTransformInfo)."), 11836c7d79dSFangrui Song cl::Hidden, cl::init(8), cl::cat(PollyCategory)); 119d123e983SMichael Kruse 120d123e983SMichael Kruse static cl::opt<int> VectorRegisterBitwidth( 121d123e983SMichael Kruse "polly-target-vector-register-bitwidth", 122d123e983SMichael Kruse cl::desc("The size in bits of a vector register (if not set, this " 123d123e983SMichael Kruse "information is taken from LLVM's target information."), 12436c7d79dSFangrui Song cl::Hidden, cl::init(-1), cl::cat(PollyCategory)); 125d123e983SMichael Kruse 126d123e983SMichael Kruse static cl::opt<int> PollyPatternMatchingNcQuotient( 127d123e983SMichael Kruse "polly-pattern-matching-nc-quotient", 128d123e983SMichael Kruse cl::desc("Quotient that is obtained by dividing Nc, the parameter of the" 129d123e983SMichael Kruse "macro-kernel, by Nr, the parameter of the micro-kernel"), 13036c7d79dSFangrui Song cl::Hidden, cl::init(256), cl::cat(PollyCategory)); 131d123e983SMichael Kruse 132b02c7e2bSRoman Gareev static cl::opt<bool> 133b02c7e2bSRoman Gareev PMBasedTCOpts("polly-tc-opt", 134b02c7e2bSRoman Gareev cl::desc("Perform optimizations of tensor contractions based " 135b02c7e2bSRoman Gareev "on pattern matching"), 136b02c7e2bSRoman Gareev cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 137b02c7e2bSRoman Gareev 138b02c7e2bSRoman Gareev static cl::opt<bool> 139b02c7e2bSRoman Gareev PMBasedMMMOpts("polly-matmul-opt", 140b02c7e2bSRoman Gareev cl::desc("Perform optimizations of matrix multiplications " 141b02c7e2bSRoman Gareev "based on pattern matching"), 142b02c7e2bSRoman Gareev cl::init(true), cl::ZeroOrMore, cl::cat(PollyCategory)); 143b02c7e2bSRoman Gareev 144b02c7e2bSRoman Gareev static cl::opt<int> OptComputeOut( 145b02c7e2bSRoman Gareev "polly-tc-dependences-computeout", 146b02c7e2bSRoman Gareev cl::desc("Bound the dependence analysis by a maximal amount of " 147b02c7e2bSRoman Gareev "computational steps (0 means no bound)"), 148b02c7e2bSRoman Gareev cl::Hidden, cl::init(500000), cl::ZeroOrMore, cl::cat(PollyCategory)); 149b02c7e2bSRoman Gareev 150d123e983SMichael Kruse namespace { 151d123e983SMichael Kruse /// Parameters of the micro kernel. 152d123e983SMichael Kruse /// 153d123e983SMichael Kruse /// Parameters, which determine sizes of rank-1 (i.e., outer product) update 154d123e983SMichael Kruse /// used in the optimized matrix multiplication. 155d123e983SMichael Kruse struct MicroKernelParamsTy { 156d123e983SMichael Kruse int Mr; 157d123e983SMichael Kruse int Nr; 158d123e983SMichael Kruse }; 159d123e983SMichael Kruse 160d123e983SMichael Kruse /// Parameters of the macro kernel. 161d123e983SMichael Kruse /// 162d123e983SMichael Kruse /// Parameters, which determine sizes of blocks of partitioned matrices 163d123e983SMichael Kruse /// used in the optimized matrix multiplication. 164d123e983SMichael Kruse struct MacroKernelParamsTy { 165d123e983SMichael Kruse int Mc; 166d123e983SMichael Kruse int Nc; 167d123e983SMichael Kruse int Kc; 168d123e983SMichael Kruse }; 169d123e983SMichael Kruse 170d123e983SMichael Kruse /// Parameters of the matrix multiplication operands. 171d123e983SMichael Kruse /// 172d123e983SMichael Kruse /// Parameters, which describe access relations that represent operands of the 173d123e983SMichael Kruse /// matrix multiplication. 174d123e983SMichael Kruse struct MatMulInfoTy { 175d123e983SMichael Kruse MemoryAccess *A = nullptr; 176d123e983SMichael Kruse MemoryAccess *B = nullptr; 177d123e983SMichael Kruse MemoryAccess *ReadFromC = nullptr; 178d123e983SMichael Kruse MemoryAccess *WriteToC = nullptr; 179d123e983SMichael Kruse int i = -1; 180d123e983SMichael Kruse int j = -1; 181d123e983SMichael Kruse int k = -1; 182d123e983SMichael Kruse }; 183d123e983SMichael Kruse 184b02c7e2bSRoman Gareev /// Parameters of the tensor contraction operands. 185b02c7e2bSRoman Gareev /// 186b02c7e2bSRoman Gareev /// A general d-dimensional tensor T ∈ R ^ Nu0 x ... x Nud−1 can be defined 187b02c7e2bSRoman Gareev /// as the set of scalar elements indexed by the set of indices u0 ... ud, 188b02c7e2bSRoman Gareev /// 189b02c7e2bSRoman Gareev /// T ≡ {Anu0...nud−1 ∈ R | (u0,...,ud−1) ∈ Nu0 x ... x Nud−1}. 190b02c7e2bSRoman Gareev /// 191b02c7e2bSRoman Gareev /// Let A, B, and C be dA, dB, and dC-dimensional tensors, respectively. 192b02c7e2bSRoman Gareev /// Let the free and the contracted indices of the tensor A be grouped into 193b02c7e2bSRoman Gareev /// two bundles I = i0...ir−1 and P = p0...pt−1, respectively. Similarly, 194b02c7e2bSRoman Gareev /// the free and the contracted indices of B are grouped into bundles 195b02c7e2bSRoman Gareev /// J = j0..js−1 and P and the free indices of C are grouped into 196b02c7e2bSRoman Gareev /// bundles I and J. 197b02c7e2bSRoman Gareev /// 198b02c7e2bSRoman Gareev /// Tensor contraction (TC) of tensors A, B into tensor C can be represented as 199b02c7e2bSRoman Gareev /// C(shuffle(I,J))=∑α·A(shuffle(I,P))·B(shuffle(P,J))+β·C(shuffle(I,J)), 200b02c7e2bSRoman Gareev /// where ∑ is a summation over all contracted indices of P, 201b02c7e2bSRoman Gareev /// α, β ∈ R, Npi is the length of the tensor dimension that corresponds 202b02c7e2bSRoman Gareev /// to the index pi, A(shuffle(I, P)), B(shuffle(P, J)), C(shuffle(I, J)) are 203b02c7e2bSRoman Gareev /// accesses to tensors A, B, C, respectively, 204b02c7e2bSRoman Gareev /// shuffle(I, J), shuffle(I, P), and shuffle(P, J) are permutations of 205b02c7e2bSRoman Gareev /// the enclosed indices. 206b02c7e2bSRoman Gareev /// 207b02c7e2bSRoman Gareev /// Multiplication of C(shuffle(I,J)) by β can be moved into a different SCoP 208b02c7e2bSRoman Gareev /// statement by loop distribution, which is done by the isl scheduler. 209b02c7e2bSRoman Gareev // If β is not equal to one, the optimization of TC of Polly requires 210b02c7e2bSRoman Gareev /// such a transformation. 211b02c7e2bSRoman Gareev /// 212b02c7e2bSRoman Gareev /// TCInfoTy contains parameters, which describe access relations that represent 213b02c7e2bSRoman Gareev /// operands of the tensor contraction. 214b02c7e2bSRoman Gareev struct TCInfoTy { 215b02c7e2bSRoman Gareev /// @{ 216b02c7e2bSRoman Gareev /// Memory accesses that represent reading from tensors, which are operands of 217b02c7e2bSRoman Gareev /// the tensor contraction. 218b02c7e2bSRoman Gareev MemoryAccess *A = nullptr; 219b02c7e2bSRoman Gareev MemoryAccess *B = nullptr; 220b02c7e2bSRoman Gareev /// @} 221b02c7e2bSRoman Gareev 222b02c7e2bSRoman Gareev /// @{ 223b02c7e2bSRoman Gareev /// Memory accesses that represent reading from and writing into the tensor, 224b02c7e2bSRoman Gareev /// which contains the result of the tensor contraction. 225b02c7e2bSRoman Gareev MemoryAccess *ReadFromC = nullptr; 226b02c7e2bSRoman Gareev MemoryAccess *WriteToC = nullptr; 227b02c7e2bSRoman Gareev /// @} 228b02c7e2bSRoman Gareev 229b02c7e2bSRoman Gareev /// @{ 230b02c7e2bSRoman Gareev /// Input dimensions of the schedule space, which represent free 231b02c7e2bSRoman Gareev /// indices of tensors. 232b02c7e2bSRoman Gareev SmallDenseSet<int> I; 233b02c7e2bSRoman Gareev SmallDenseSet<int> J; 234b02c7e2bSRoman Gareev /// @} 235b02c7e2bSRoman Gareev 236b02c7e2bSRoman Gareev /// Input dimension of the schedule space, which represents contracted 237b02c7e2bSRoman Gareev /// indices of tensors. 238b02c7e2bSRoman Gareev SmallDenseSet<int> P; 239b02c7e2bSRoman Gareev 240b02c7e2bSRoman Gareev /// @{ 241b02c7e2bSRoman Gareev /// Sizes of tensor dimensions for corresponding input dimensions of 242b02c7e2bSRoman Gareev /// the schedule space. The size of the tensor dimension can be larger than 243b02c7e2bSRoman Gareev /// the size of the corresponding input dimension of the schedule space. 244b02c7e2bSRoman Gareev /// This does not correspond to a tensor contraction. However, such a pattern 245b02c7e2bSRoman Gareev /// will be optimized by the transformation. 246b02c7e2bSRoman Gareev SmallVector<int> DimensionSizes; 247b02c7e2bSRoman Gareev SmallVector<int> ADimensions; 248b02c7e2bSRoman Gareev SmallVector<int> BDimensions; 249b02c7e2bSRoman Gareev SmallVector<int> CDimensions; 250b02c7e2bSRoman Gareev /// @} 251b02c7e2bSRoman Gareev 252b02c7e2bSRoman Gareev /// @{ 253b02c7e2bSRoman Gareev /// Permutations of indices of I, J, and P, which describe operands of 254b02c7e2bSRoman Gareev /// the tensor contraction and its result. 255b02c7e2bSRoman Gareev SmallVector<int> OrderedI; 256b02c7e2bSRoman Gareev SmallVector<int> OrderedJ; 257b02c7e2bSRoman Gareev SmallVector<int> OrderedP; 258b02c7e2bSRoman Gareev /// @} 259b02c7e2bSRoman Gareev }; 260b02c7e2bSRoman Gareev 261d123e983SMichael Kruse /// Create an isl::union_set, which describes the option of the form 262d123e983SMichael Kruse /// [isolate[] -> unroll[x]]. 263d123e983SMichael Kruse /// 264d123e983SMichael Kruse /// @param Ctx An isl::ctx, which is used to create the isl::union_set. 265d123e983SMichael Kruse static isl::union_set getUnrollIsolatedSetOptions(isl::ctx Ctx) { 266d123e983SMichael Kruse isl::space Space = isl::space(Ctx, 0, 0, 1); 267d123e983SMichael Kruse isl::map UnrollIsolatedSetOption = isl::map::universe(Space); 268d123e983SMichael Kruse isl::id DimInId = isl::id::alloc(Ctx, "isolate", nullptr); 269d123e983SMichael Kruse isl::id DimOutId = isl::id::alloc(Ctx, "unroll", nullptr); 270d123e983SMichael Kruse UnrollIsolatedSetOption = 271d123e983SMichael Kruse UnrollIsolatedSetOption.set_tuple_id(isl::dim::in, DimInId); 272d123e983SMichael Kruse UnrollIsolatedSetOption = 273d123e983SMichael Kruse UnrollIsolatedSetOption.set_tuple_id(isl::dim::out, DimOutId); 274d123e983SMichael Kruse return UnrollIsolatedSetOption.wrap(); 275d123e983SMichael Kruse } 276d123e983SMichael Kruse 277d123e983SMichael Kruse /// Permute the two dimensions of the isl map. 278d123e983SMichael Kruse /// 279d123e983SMichael Kruse /// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that 280d123e983SMichael Kruse /// have type @p DimType. 281d123e983SMichael Kruse /// 282d123e983SMichael Kruse /// @param Map The isl map to be modified. 283d123e983SMichael Kruse /// @param DimType The type of the dimensions. 284d123e983SMichael Kruse /// @param DstPos The first dimension. 285d123e983SMichael Kruse /// @param SrcPos The second dimension. 286d123e983SMichael Kruse /// @return The modified map. 287d123e983SMichael Kruse static isl::map permuteDimensions(isl::map Map, isl::dim DimType, 288d123e983SMichael Kruse unsigned DstPos, unsigned SrcPos) { 28944596fe6SRiccardo Mori assert(DstPos < unsignedFromIslSize(Map.dim(DimType)) && 29044596fe6SRiccardo Mori SrcPos < unsignedFromIslSize(Map.dim(DimType))); 291d123e983SMichael Kruse if (DstPos == SrcPos) 292d123e983SMichael Kruse return Map; 293d123e983SMichael Kruse isl::id DimId; 294d123e983SMichael Kruse if (Map.has_tuple_id(DimType)) 295d123e983SMichael Kruse DimId = Map.get_tuple_id(DimType); 296d123e983SMichael Kruse auto FreeDim = DimType == isl::dim::in ? isl::dim::out : isl::dim::in; 297d123e983SMichael Kruse isl::id FreeDimId; 298d123e983SMichael Kruse if (Map.has_tuple_id(FreeDim)) 299d123e983SMichael Kruse FreeDimId = Map.get_tuple_id(FreeDim); 300d123e983SMichael Kruse auto MaxDim = std::max(DstPos, SrcPos); 301d123e983SMichael Kruse auto MinDim = std::min(DstPos, SrcPos); 302d123e983SMichael Kruse Map = Map.move_dims(FreeDim, 0, DimType, MaxDim, 1); 303d123e983SMichael Kruse Map = Map.move_dims(FreeDim, 0, DimType, MinDim, 1); 304d123e983SMichael Kruse Map = Map.move_dims(DimType, MinDim, FreeDim, 1, 1); 305d123e983SMichael Kruse Map = Map.move_dims(DimType, MaxDim, FreeDim, 0, 1); 3067c7978a1Spatacca if (!DimId.is_null()) 307d123e983SMichael Kruse Map = Map.set_tuple_id(DimType, DimId); 3087c7978a1Spatacca if (!FreeDimId.is_null()) 309d123e983SMichael Kruse Map = Map.set_tuple_id(FreeDim, FreeDimId); 310d123e983SMichael Kruse return Map; 311d123e983SMichael Kruse } 312d123e983SMichael Kruse 313d123e983SMichael Kruse /// Check the form of the access relation. 314d123e983SMichael Kruse /// 315d123e983SMichael Kruse /// Check that the access relation @p AccMap has the form M[i][j], where i 316d123e983SMichael Kruse /// is a @p FirstPos and j is a @p SecondPos. 317d123e983SMichael Kruse /// 318d123e983SMichael Kruse /// @param AccMap The access relation to be checked. 319d123e983SMichael Kruse /// @param FirstPos The index of the input dimension that is mapped to 320d123e983SMichael Kruse /// the first output dimension. 321d123e983SMichael Kruse /// @param SecondPos The index of the input dimension that is mapped to the 322d123e983SMichael Kruse /// second output dimension. 323d123e983SMichael Kruse /// @return True in case @p AccMap has the expected form and false, 324d123e983SMichael Kruse /// otherwise. 325d123e983SMichael Kruse static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos, 326d123e983SMichael Kruse int &SecondPos) { 327d123e983SMichael Kruse isl::space Space = AccMap.get_space(); 328d123e983SMichael Kruse isl::map Universe = isl::map::universe(Space); 329d123e983SMichael Kruse 33044596fe6SRiccardo Mori if (unsignedFromIslSize(Space.dim(isl::dim::out)) != 2) 331d123e983SMichael Kruse return false; 332d123e983SMichael Kruse 333d123e983SMichael Kruse // MatMul has the form: 334d123e983SMichael Kruse // for (i = 0; i < N; i++) 335d123e983SMichael Kruse // for (j = 0; j < M; j++) 336d123e983SMichael Kruse // for (k = 0; k < P; k++) 337d123e983SMichael Kruse // C[i, j] += A[i, k] * B[k, j] 338d123e983SMichael Kruse // 339d123e983SMichael Kruse // Permutation of three outer loops: 3! = 6 possibilities. 340d123e983SMichael Kruse int FirstDims[] = {0, 0, 1, 1, 2, 2}; 341d123e983SMichael Kruse int SecondDims[] = {1, 2, 2, 0, 0, 1}; 342d123e983SMichael Kruse for (int i = 0; i < 6; i += 1) { 343d123e983SMichael Kruse auto PossibleMatMul = 344d123e983SMichael Kruse Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0) 345d123e983SMichael Kruse .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1); 346d123e983SMichael Kruse 347d123e983SMichael Kruse AccMap = AccMap.intersect_domain(Domain); 348d123e983SMichael Kruse PossibleMatMul = PossibleMatMul.intersect_domain(Domain); 349d123e983SMichael Kruse 350d123e983SMichael Kruse // If AccMap spans entire domain (Non-partial write), 351d123e983SMichael Kruse // compute FirstPos and SecondPos. 352d123e983SMichael Kruse // If AccMap != PossibleMatMul here (the two maps have been gisted at 353d123e983SMichael Kruse // this point), it means that the writes are not complete, or in other 354d123e983SMichael Kruse // words, it is a Partial write and Partial writes must be rejected. 355d123e983SMichael Kruse if (AccMap.is_equal(PossibleMatMul)) { 356d123e983SMichael Kruse if (FirstPos != -1 && FirstPos != FirstDims[i]) 357d123e983SMichael Kruse continue; 358d123e983SMichael Kruse FirstPos = FirstDims[i]; 359d123e983SMichael Kruse if (SecondPos != -1 && SecondPos != SecondDims[i]) 360d123e983SMichael Kruse continue; 361d123e983SMichael Kruse SecondPos = SecondDims[i]; 362d123e983SMichael Kruse return true; 363d123e983SMichael Kruse } 364d123e983SMichael Kruse } 365d123e983SMichael Kruse 366d123e983SMichael Kruse return false; 367d123e983SMichael Kruse } 368d123e983SMichael Kruse 369d123e983SMichael Kruse /// Does the memory access represent a non-scalar operand of the matrix 370d123e983SMichael Kruse /// multiplication. 371d123e983SMichael Kruse /// 372d123e983SMichael Kruse /// Check that the memory access @p MemAccess is the read access to a non-scalar 373d123e983SMichael Kruse /// operand of the matrix multiplication or its result. 374d123e983SMichael Kruse /// 375d123e983SMichael Kruse /// @param MemAccess The memory access to be checked. 376d123e983SMichael Kruse /// @param MMI Parameters of the matrix multiplication operands. 377d123e983SMichael Kruse /// @return True in case the memory access represents the read access 378d123e983SMichael Kruse /// to a non-scalar operand of the matrix multiplication and 379d123e983SMichael Kruse /// false, otherwise. 380d123e983SMichael Kruse static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess, 381d123e983SMichael Kruse MatMulInfoTy &MMI) { 382d123e983SMichael Kruse if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead()) 383d123e983SMichael Kruse return false; 384d123e983SMichael Kruse auto AccMap = MemAccess->getLatestAccessRelation(); 385d123e983SMichael Kruse isl::set StmtDomain = MemAccess->getStatement()->getDomain(); 386d123e983SMichael Kruse if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) { 387d123e983SMichael Kruse MMI.ReadFromC = MemAccess; 388d123e983SMichael Kruse return true; 389d123e983SMichael Kruse } 390d123e983SMichael Kruse if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) { 391d123e983SMichael Kruse MMI.A = MemAccess; 392d123e983SMichael Kruse return true; 393d123e983SMichael Kruse } 394d123e983SMichael Kruse if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) { 395d123e983SMichael Kruse MMI.B = MemAccess; 396d123e983SMichael Kruse return true; 397d123e983SMichael Kruse } 398d123e983SMichael Kruse return false; 399d123e983SMichael Kruse } 400d123e983SMichael Kruse 401d123e983SMichael Kruse /// Check accesses to operands of the matrix multiplication. 402d123e983SMichael Kruse /// 403d123e983SMichael Kruse /// Check that accesses of the SCoP statement, which corresponds to 404d123e983SMichael Kruse /// the partial schedule @p PartialSchedule, are scalar in terms of loops 405d123e983SMichael Kruse /// containing the matrix multiplication, in case they do not represent 406d123e983SMichael Kruse /// accesses to the non-scalar operands of the matrix multiplication or 407d123e983SMichael Kruse /// its result. 408d123e983SMichael Kruse /// 409d123e983SMichael Kruse /// @param PartialSchedule The partial schedule of the SCoP statement. 410d123e983SMichael Kruse /// @param MMI Parameters of the matrix multiplication operands. 411d123e983SMichael Kruse /// @return True in case the corresponding SCoP statement 412d123e983SMichael Kruse /// represents matrix multiplication and false, 413d123e983SMichael Kruse /// otherwise. 414d123e983SMichael Kruse static bool containsOnlyMatrMultAcc(isl::map PartialSchedule, 415d123e983SMichael Kruse MatMulInfoTy &MMI) { 416d123e983SMichael Kruse auto InputDimId = PartialSchedule.get_tuple_id(isl::dim::in); 417d123e983SMichael Kruse auto *Stmt = static_cast<ScopStmt *>(InputDimId.get_user()); 41844596fe6SRiccardo Mori unsigned OutDimNum = unsignedFromIslSize(PartialSchedule.range_tuple_dim()); 419d123e983SMichael Kruse assert(OutDimNum > 2 && "In case of the matrix multiplication the loop nest " 420d123e983SMichael Kruse "and, consequently, the corresponding scheduling " 421d123e983SMichael Kruse "functions have at least three dimensions."); 422d123e983SMichael Kruse auto MapI = 423d123e983SMichael Kruse permuteDimensions(PartialSchedule, isl::dim::out, MMI.i, OutDimNum - 1); 424d123e983SMichael Kruse auto MapJ = 425d123e983SMichael Kruse permuteDimensions(PartialSchedule, isl::dim::out, MMI.j, OutDimNum - 1); 426d123e983SMichael Kruse auto MapK = 427d123e983SMichael Kruse permuteDimensions(PartialSchedule, isl::dim::out, MMI.k, OutDimNum - 1); 428d123e983SMichael Kruse 429d123e983SMichael Kruse auto Accesses = getAccessesInOrder(*Stmt); 430d123e983SMichael Kruse for (auto *MemA = Accesses.begin(); MemA != Accesses.end() - 1; MemA++) { 431d123e983SMichael Kruse auto *MemAccessPtr = *MemA; 432d123e983SMichael Kruse if (MemAccessPtr->isLatestArrayKind() && MemAccessPtr != MMI.WriteToC && 433d123e983SMichael Kruse !isMatMulNonScalarReadAccess(MemAccessPtr, MMI) && 434113fa82cSRoman Gareev !(MemAccessPtr->isStrideZero(MapI) && 435113fa82cSRoman Gareev MemAccessPtr->isStrideZero(MapJ) && MemAccessPtr->isStrideZero(MapK))) 436d123e983SMichael Kruse return false; 437d123e983SMichael Kruse } 438d123e983SMichael Kruse return true; 439d123e983SMichael Kruse } 440d123e983SMichael Kruse 441d123e983SMichael Kruse /// Check for dependencies corresponding to the matrix multiplication. 442d123e983SMichael Kruse /// 443d123e983SMichael Kruse /// Check that there is only true dependence of the form 444d123e983SMichael Kruse /// S(..., k, ...) -> S(..., k + 1, …), where S is the SCoP statement 445d123e983SMichael Kruse /// represented by @p Schedule and k is @p Pos. Such a dependence corresponds 446d123e983SMichael Kruse /// to the dependency produced by the matrix multiplication. 447d123e983SMichael Kruse /// 448d123e983SMichael Kruse /// @param Schedule The schedule of the SCoP statement. 449d123e983SMichael Kruse /// @param D The SCoP dependencies. 450d123e983SMichael Kruse /// @param Pos The parameter to describe an acceptable true dependence. 451d123e983SMichael Kruse /// In case it has a negative value, try to determine its 452d123e983SMichael Kruse /// acceptable value. 453d123e983SMichael Kruse /// @return True in case dependencies correspond to the matrix multiplication 454d123e983SMichael Kruse /// and false, otherwise. 455d123e983SMichael Kruse static bool containsOnlyMatMulDep(isl::map Schedule, const Dependences *D, 456d123e983SMichael Kruse int &Pos) { 457d123e983SMichael Kruse isl::union_map Dep = D->getDependences(Dependences::TYPE_RAW); 458d123e983SMichael Kruse isl::union_map Red = D->getDependences(Dependences::TYPE_RED); 4597c7978a1Spatacca if (!Red.is_null()) 460d123e983SMichael Kruse Dep = Dep.unite(Red); 461d123e983SMichael Kruse auto DomainSpace = Schedule.get_space().domain(); 462d123e983SMichael Kruse auto Space = DomainSpace.map_from_domain_and_range(DomainSpace); 463d123e983SMichael Kruse auto Deltas = Dep.extract_map(Space).deltas(); 46444596fe6SRiccardo Mori int DeltasDimNum = unsignedFromIslSize(Deltas.dim(isl::dim::set)); 465d123e983SMichael Kruse for (int i = 0; i < DeltasDimNum; i++) { 466d123e983SMichael Kruse auto Val = Deltas.plain_get_val_if_fixed(isl::dim::set, i); 467d123e983SMichael Kruse Pos = Pos < 0 && Val.is_one() ? i : Pos; 468d123e983SMichael Kruse if (Val.is_nan() || !(Val.is_zero() || (i == Pos && Val.is_one()))) 469d123e983SMichael Kruse return false; 470d123e983SMichael Kruse } 471d123e983SMichael Kruse if (DeltasDimNum == 0 || Pos < 0) 472d123e983SMichael Kruse return false; 473d123e983SMichael Kruse return true; 474d123e983SMichael Kruse } 475d123e983SMichael Kruse 476d123e983SMichael Kruse /// Check if the SCoP statement could probably be optimized with analytical 477d123e983SMichael Kruse /// modeling. 478d123e983SMichael Kruse /// 479d123e983SMichael Kruse /// containsMatrMult tries to determine whether the following conditions 480d123e983SMichael Kruse /// are true: 481d123e983SMichael Kruse /// 1. The last memory access modeling an array, MA1, represents writing to 482d123e983SMichael Kruse /// memory and has the form S(..., i1, ..., i2, ...) -> M(i1, i2) or 483d123e983SMichael Kruse /// S(..., i2, ..., i1, ...) -> M(i1, i2), where S is the SCoP statement 484d123e983SMichael Kruse /// under consideration. 485d123e983SMichael Kruse /// 2. There is only one loop-carried true dependency, and it has the 486d123e983SMichael Kruse /// form S(..., i3, ...) -> S(..., i3 + 1, ...), and there are no 487d123e983SMichael Kruse /// loop-carried or anti dependencies. 488d123e983SMichael Kruse /// 3. SCoP contains three access relations, MA2, MA3, and MA4 that represent 489d123e983SMichael Kruse /// reading from memory and have the form S(..., i3, ...) -> M(i1, i3), 490d123e983SMichael Kruse /// S(..., i3, ...) -> M(i3, i2), S(...) -> M(i1, i2), respectively, 491d123e983SMichael Kruse /// and all memory accesses of the SCoP that are different from MA1, MA2, 492d123e983SMichael Kruse /// MA3, and MA4 have stride 0, if the innermost loop is exchanged with any 493d123e983SMichael Kruse /// of loops i1, i2 and i3. 494d123e983SMichael Kruse /// 495d123e983SMichael Kruse /// @param PartialSchedule The PartialSchedule that contains a SCoP statement 496d123e983SMichael Kruse /// to check. 497d123e983SMichael Kruse /// @D The SCoP dependencies. 498d123e983SMichael Kruse /// @MMI Parameters of the matrix multiplication operands. 499d123e983SMichael Kruse static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D, 500d123e983SMichael Kruse MatMulInfoTy &MMI) { 501d123e983SMichael Kruse auto InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in); 502d123e983SMichael Kruse auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user()); 503d123e983SMichael Kruse if (Stmt->size() <= 1) 504d123e983SMichael Kruse return false; 505d123e983SMichael Kruse 506d123e983SMichael Kruse auto Accesses = getAccessesInOrder(*Stmt); 507d123e983SMichael Kruse for (auto *MemA = Accesses.end() - 1; MemA != Accesses.begin(); MemA--) { 508d123e983SMichael Kruse auto *MemAccessPtr = *MemA; 509d123e983SMichael Kruse if (!MemAccessPtr->isLatestArrayKind()) 510d123e983SMichael Kruse continue; 511d123e983SMichael Kruse if (!MemAccessPtr->isWrite()) 512d123e983SMichael Kruse return false; 513d123e983SMichael Kruse auto AccMap = MemAccessPtr->getLatestAccessRelation(); 514d123e983SMichael Kruse if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j)) 515d123e983SMichael Kruse return false; 516d123e983SMichael Kruse MMI.WriteToC = MemAccessPtr; 517d123e983SMichael Kruse break; 518d123e983SMichael Kruse } 519d123e983SMichael Kruse 520d123e983SMichael Kruse if (!containsOnlyMatMulDep(PartialSchedule, D, MMI.k)) 521d123e983SMichael Kruse return false; 522d123e983SMichael Kruse 523d123e983SMichael Kruse if (!MMI.WriteToC || !containsOnlyMatrMultAcc(PartialSchedule, MMI)) 524d123e983SMichael Kruse return false; 525d123e983SMichael Kruse 526d123e983SMichael Kruse if (!MMI.A || !MMI.B || !MMI.ReadFromC) 527d123e983SMichael Kruse return false; 528d123e983SMichael Kruse return true; 529d123e983SMichael Kruse } 530d123e983SMichael Kruse 531d123e983SMichael Kruse /// Permute two dimensions of the band node. 532d123e983SMichael Kruse /// 533d123e983SMichael Kruse /// Permute FirstDim and SecondDim dimensions of the Node. 534d123e983SMichael Kruse /// 535d123e983SMichael Kruse /// @param Node The band node to be modified. 536d123e983SMichael Kruse /// @param FirstDim The first dimension to be permuted. 537d123e983SMichael Kruse /// @param SecondDim The second dimension to be permuted. 538d123e983SMichael Kruse static isl::schedule_node permuteBandNodeDimensions(isl::schedule_node Node, 539d123e983SMichael Kruse unsigned FirstDim, 540d123e983SMichael Kruse unsigned SecondDim) { 541d123e983SMichael Kruse assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band && 542d123e983SMichael Kruse (unsigned)isl_schedule_node_band_n_member(Node.get()) > 543d123e983SMichael Kruse std::max(FirstDim, SecondDim)); 544d123e983SMichael Kruse auto PartialSchedule = 545d123e983SMichael Kruse isl::manage(isl_schedule_node_band_get_partial_schedule(Node.get())); 546d3fdbda6SRiccardo Mori auto PartialScheduleFirstDim = PartialSchedule.at(FirstDim); 547d3fdbda6SRiccardo Mori auto PartialScheduleSecondDim = PartialSchedule.at(SecondDim); 548d123e983SMichael Kruse PartialSchedule = 549d123e983SMichael Kruse PartialSchedule.set_union_pw_aff(SecondDim, PartialScheduleFirstDim); 550d123e983SMichael Kruse PartialSchedule = 551d123e983SMichael Kruse PartialSchedule.set_union_pw_aff(FirstDim, PartialScheduleSecondDim); 552d123e983SMichael Kruse Node = isl::manage(isl_schedule_node_delete(Node.release())); 553d123e983SMichael Kruse return Node.insert_partial_schedule(PartialSchedule); 554d123e983SMichael Kruse } 555d123e983SMichael Kruse 556d123e983SMichael Kruse static isl::schedule_node 557d123e983SMichael Kruse createMicroKernel(isl::schedule_node Node, 558d123e983SMichael Kruse MicroKernelParamsTy MicroKernelParams) { 559d123e983SMichael Kruse Node = applyRegisterTiling(Node, {MicroKernelParams.Mr, MicroKernelParams.Nr}, 560d123e983SMichael Kruse 1); 561d123e983SMichael Kruse Node = Node.parent().parent(); 562d123e983SMichael Kruse return permuteBandNodeDimensions(Node, 0, 1).child(0).child(0); 563d123e983SMichael Kruse } 564d123e983SMichael Kruse 565d123e983SMichael Kruse /// Create the BLIS macro-kernel. 566d123e983SMichael Kruse /// 567d123e983SMichael Kruse /// We create the BLIS macro-kernel by applying a combination of tiling 568d123e983SMichael Kruse /// of dimensions of the band node and interchanging of two innermost 569678e3ee1SFangrui Song /// modified dimensions. The values of MacroKernelParams's fields are used 570d123e983SMichael Kruse /// as tile sizes. 571d123e983SMichael Kruse /// 572d123e983SMichael Kruse /// @param Node The schedule node to be modified. 573d123e983SMichael Kruse /// @param MacroKernelParams Parameters of the macro kernel 574d123e983SMichael Kruse /// to be used as tile sizes. 575d123e983SMichael Kruse static isl::schedule_node 576d123e983SMichael Kruse createMacroKernel(isl::schedule_node Node, 577d123e983SMichael Kruse MacroKernelParamsTy MacroKernelParams) { 578d123e983SMichael Kruse assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band); 579d123e983SMichael Kruse if (MacroKernelParams.Mc == 1 && MacroKernelParams.Nc == 1 && 580d123e983SMichael Kruse MacroKernelParams.Kc == 1) 581d123e983SMichael Kruse return Node; 582d123e983SMichael Kruse int DimOutNum = isl_schedule_node_band_n_member(Node.get()); 583d123e983SMichael Kruse std::vector<int> TileSizes(DimOutNum, 1); 584d123e983SMichael Kruse TileSizes[DimOutNum - 3] = MacroKernelParams.Mc; 585d123e983SMichael Kruse TileSizes[DimOutNum - 2] = MacroKernelParams.Nc; 586d123e983SMichael Kruse TileSizes[DimOutNum - 1] = MacroKernelParams.Kc; 587d123e983SMichael Kruse Node = tileNode(Node, "1st level tiling", TileSizes, 1); 588d123e983SMichael Kruse Node = Node.parent().parent(); 589d123e983SMichael Kruse Node = permuteBandNodeDimensions(Node, DimOutNum - 2, DimOutNum - 1); 590d123e983SMichael Kruse Node = permuteBandNodeDimensions(Node, DimOutNum - 3, DimOutNum - 1); 591d123e983SMichael Kruse 592d123e983SMichael Kruse return Node.child(0).child(0); 593d123e983SMichael Kruse } 594d123e983SMichael Kruse 595d123e983SMichael Kruse /// Get the size of the widest type of the matrix multiplication operands 596d123e983SMichael Kruse /// in bytes, including alignment padding. 597d123e983SMichael Kruse /// 598d123e983SMichael Kruse /// @param MMI Parameters of the matrix multiplication operands. 599d123e983SMichael Kruse /// @return The size of the widest type of the matrix multiplication operands 600d123e983SMichael Kruse /// in bytes, including alignment padding. 6014c6ae8ebSKarthika Devi C static uint64_t getMatMulAlignTypeSize(const MatMulInfoTy &MMI) { 602d123e983SMichael Kruse auto *S = MMI.A->getStatement()->getParent(); 603d123e983SMichael Kruse auto &DL = S->getFunction().getParent()->getDataLayout(); 604d123e983SMichael Kruse auto ElementSizeA = DL.getTypeAllocSize(MMI.A->getElementType()); 605d123e983SMichael Kruse auto ElementSizeB = DL.getTypeAllocSize(MMI.B->getElementType()); 606d123e983SMichael Kruse auto ElementSizeC = DL.getTypeAllocSize(MMI.WriteToC->getElementType()); 607d123e983SMichael Kruse return std::max({ElementSizeA, ElementSizeB, ElementSizeC}); 608d123e983SMichael Kruse } 609d123e983SMichael Kruse 610d123e983SMichael Kruse /// Get the size of the widest type of the matrix multiplication operands 611d123e983SMichael Kruse /// in bits. 612d123e983SMichael Kruse /// 613d123e983SMichael Kruse /// @param MMI Parameters of the matrix multiplication operands. 614d123e983SMichael Kruse /// @return The size of the widest type of the matrix multiplication operands 615d123e983SMichael Kruse /// in bits. 6164c6ae8ebSKarthika Devi C static uint64_t getMatMulTypeSize(const MatMulInfoTy &MMI) { 617d123e983SMichael Kruse auto *S = MMI.A->getStatement()->getParent(); 618d123e983SMichael Kruse auto &DL = S->getFunction().getParent()->getDataLayout(); 619d123e983SMichael Kruse auto ElementSizeA = DL.getTypeSizeInBits(MMI.A->getElementType()); 620d123e983SMichael Kruse auto ElementSizeB = DL.getTypeSizeInBits(MMI.B->getElementType()); 621d123e983SMichael Kruse auto ElementSizeC = DL.getTypeSizeInBits(MMI.WriteToC->getElementType()); 622d123e983SMichael Kruse return std::max({ElementSizeA, ElementSizeB, ElementSizeC}); 623d123e983SMichael Kruse } 624d123e983SMichael Kruse 625d123e983SMichael Kruse /// Get parameters of the BLIS micro kernel. 626d123e983SMichael Kruse /// 627d123e983SMichael Kruse /// We choose the Mr and Nr parameters of the micro kernel to be large enough 628d123e983SMichael Kruse /// such that no stalls caused by the combination of latencies and dependencies 629d123e983SMichael Kruse /// are introduced during the updates of the resulting matrix of the matrix 630d123e983SMichael Kruse /// multiplication. However, they should also be as small as possible to 631d123e983SMichael Kruse /// release more registers for entries of multiplied matrices. 632d123e983SMichael Kruse /// 633d123e983SMichael Kruse /// @param TTI Target Transform Info. 634d123e983SMichael Kruse /// @param MMI Parameters of the matrix multiplication operands. 635d123e983SMichael Kruse /// @return The structure of type MicroKernelParamsTy. 636d123e983SMichael Kruse /// @see MicroKernelParamsTy 637bd93df93SMichael Kruse static MicroKernelParamsTy getMicroKernelParams(const TargetTransformInfo *TTI, 6384c6ae8ebSKarthika Devi C const MatMulInfoTy &MMI) { 639d123e983SMichael Kruse assert(TTI && "The target transform info should be provided."); 640d123e983SMichael Kruse 641d123e983SMichael Kruse // Nvec - Number of double-precision floating-point numbers that can be hold 642d123e983SMichael Kruse // by a vector register. Use 2 by default. 643d123e983SMichael Kruse long RegisterBitwidth = VectorRegisterBitwidth; 644d123e983SMichael Kruse 645d123e983SMichael Kruse if (RegisterBitwidth == -1) 646d123e983SMichael Kruse RegisterBitwidth = 647d123e983SMichael Kruse TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector); 648d123e983SMichael Kruse auto ElementSize = getMatMulTypeSize(MMI); 649d123e983SMichael Kruse assert(ElementSize > 0 && "The element size of the matrix multiplication " 650d123e983SMichael Kruse "operands should be greater than zero."); 651d123e983SMichael Kruse auto Nvec = RegisterBitwidth / ElementSize; 652d123e983SMichael Kruse if (Nvec == 0) 653d123e983SMichael Kruse Nvec = 2; 654d123e983SMichael Kruse int Nr = ceil(sqrt((double)(Nvec * LatencyVectorFma * ThroughputVectorFma)) / 655d123e983SMichael Kruse Nvec) * 656d123e983SMichael Kruse Nvec; 657d123e983SMichael Kruse int Mr = ceil((double)(Nvec * LatencyVectorFma * ThroughputVectorFma / Nr)); 658d123e983SMichael Kruse return {Mr, Nr}; 659d123e983SMichael Kruse } 660d123e983SMichael Kruse 661d123e983SMichael Kruse /// Determine parameters of the target cache. 662d123e983SMichael Kruse /// 663d123e983SMichael Kruse /// @param TTI Target Transform Info. 664d123e983SMichael Kruse static void getTargetCacheParameters(const llvm::TargetTransformInfo *TTI) { 665d123e983SMichael Kruse auto L1DCache = llvm::TargetTransformInfo::CacheLevel::L1D; 666d123e983SMichael Kruse auto L2DCache = llvm::TargetTransformInfo::CacheLevel::L2D; 667d123e983SMichael Kruse if (FirstCacheLevelSize == -1) { 66894460f51SKazu Hirata if (TTI->getCacheSize(L1DCache)) 6695cff5142SKazu Hirata FirstCacheLevelSize = TTI->getCacheSize(L1DCache).value(); 670d123e983SMichael Kruse else 671d123e983SMichael Kruse FirstCacheLevelSize = static_cast<int>(FirstCacheLevelDefaultSize); 672d123e983SMichael Kruse } 673d123e983SMichael Kruse if (SecondCacheLevelSize == -1) { 67494460f51SKazu Hirata if (TTI->getCacheSize(L2DCache)) 6755cff5142SKazu Hirata SecondCacheLevelSize = TTI->getCacheSize(L2DCache).value(); 676d123e983SMichael Kruse else 677d123e983SMichael Kruse SecondCacheLevelSize = static_cast<int>(SecondCacheLevelDefaultSize); 678d123e983SMichael Kruse } 679d123e983SMichael Kruse if (FirstCacheLevelAssociativity == -1) { 68094460f51SKazu Hirata if (TTI->getCacheAssociativity(L1DCache)) 681d123e983SMichael Kruse FirstCacheLevelAssociativity = 6825cff5142SKazu Hirata TTI->getCacheAssociativity(L1DCache).value(); 683d123e983SMichael Kruse else 684d123e983SMichael Kruse FirstCacheLevelAssociativity = 685d123e983SMichael Kruse static_cast<int>(FirstCacheLevelDefaultAssociativity); 686d123e983SMichael Kruse } 687d123e983SMichael Kruse if (SecondCacheLevelAssociativity == -1) { 68894460f51SKazu Hirata if (TTI->getCacheAssociativity(L2DCache)) 689d123e983SMichael Kruse SecondCacheLevelAssociativity = 6905cff5142SKazu Hirata TTI->getCacheAssociativity(L2DCache).value(); 691d123e983SMichael Kruse else 692d123e983SMichael Kruse SecondCacheLevelAssociativity = 693d123e983SMichael Kruse static_cast<int>(SecondCacheLevelDefaultAssociativity); 694d123e983SMichael Kruse } 695d123e983SMichael Kruse } 696d123e983SMichael Kruse 697d123e983SMichael Kruse /// Get parameters of the BLIS macro kernel. 698d123e983SMichael Kruse /// 699d123e983SMichael Kruse /// During the computation of matrix multiplication, blocks of partitioned 700d123e983SMichael Kruse /// matrices are mapped to different layers of the memory hierarchy. 701d123e983SMichael Kruse /// To optimize data reuse, blocks should be ideally kept in cache between 702d123e983SMichael Kruse /// iterations. Since parameters of the macro kernel determine sizes of these 703d123e983SMichael Kruse /// blocks, there are upper and lower bounds on these parameters. 704d123e983SMichael Kruse /// 705d123e983SMichael Kruse /// @param TTI Target Transform Info. 706d123e983SMichael Kruse /// @param MicroKernelParams Parameters of the micro-kernel 707d123e983SMichael Kruse /// to be taken into account. 708d123e983SMichael Kruse /// @param MMI Parameters of the matrix multiplication operands. 709d123e983SMichael Kruse /// @return The structure of type MacroKernelParamsTy. 710d123e983SMichael Kruse /// @see MacroKernelParamsTy 711d123e983SMichael Kruse /// @see MicroKernelParamsTy 712bd93df93SMichael Kruse static MacroKernelParamsTy 713d123e983SMichael Kruse getMacroKernelParams(const llvm::TargetTransformInfo *TTI, 714d123e983SMichael Kruse const MicroKernelParamsTy &MicroKernelParams, 7154c6ae8ebSKarthika Devi C const MatMulInfoTy &MMI) { 716d123e983SMichael Kruse getTargetCacheParameters(TTI); 717d123e983SMichael Kruse // According to www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf, 718d123e983SMichael Kruse // it requires information about the first two levels of a cache to determine 719d123e983SMichael Kruse // all the parameters of a macro-kernel. It also checks that an associativity 720d123e983SMichael Kruse // degree of a cache level is greater than two. Otherwise, another algorithm 721d123e983SMichael Kruse // for determination of the parameters should be used. 722d123e983SMichael Kruse if (!(MicroKernelParams.Mr > 0 && MicroKernelParams.Nr > 0 && 723d123e983SMichael Kruse FirstCacheLevelSize > 0 && SecondCacheLevelSize > 0 && 724d123e983SMichael Kruse FirstCacheLevelAssociativity > 2 && SecondCacheLevelAssociativity > 2)) 725d123e983SMichael Kruse return {1, 1, 1}; 726d123e983SMichael Kruse // The quotient should be greater than zero. 727d123e983SMichael Kruse if (PollyPatternMatchingNcQuotient <= 0) 728d123e983SMichael Kruse return {1, 1, 1}; 729d123e983SMichael Kruse int Car = floor( 730d123e983SMichael Kruse (FirstCacheLevelAssociativity - 1) / 731d123e983SMichael Kruse (1 + static_cast<double>(MicroKernelParams.Nr) / MicroKernelParams.Mr)); 732d123e983SMichael Kruse 733d123e983SMichael Kruse // Car can be computed to be zero since it is floor to int. 734d123e983SMichael Kruse // On Mac OS, division by 0 does not raise a signal. This causes negative 735d123e983SMichael Kruse // tile sizes to be computed. Prevent division by Cac==0 by early returning 736d123e983SMichael Kruse // if this happens. 737d123e983SMichael Kruse if (Car == 0) 738d123e983SMichael Kruse return {1, 1, 1}; 739d123e983SMichael Kruse 740d123e983SMichael Kruse auto ElementSize = getMatMulAlignTypeSize(MMI); 741d123e983SMichael Kruse assert(ElementSize > 0 && "The element size of the matrix multiplication " 742d123e983SMichael Kruse "operands should be greater than zero."); 743d123e983SMichael Kruse int Kc = (Car * FirstCacheLevelSize) / 744d123e983SMichael Kruse (MicroKernelParams.Mr * FirstCacheLevelAssociativity * ElementSize); 745d123e983SMichael Kruse double Cac = 746d123e983SMichael Kruse static_cast<double>(Kc * ElementSize * SecondCacheLevelAssociativity) / 747d123e983SMichael Kruse SecondCacheLevelSize; 748d123e983SMichael Kruse int Mc = floor((SecondCacheLevelAssociativity - 2) / Cac); 749d123e983SMichael Kruse int Nc = PollyPatternMatchingNcQuotient * MicroKernelParams.Nr; 750d123e983SMichael Kruse 751d123e983SMichael Kruse assert(Mc > 0 && Nc > 0 && Kc > 0 && 752d123e983SMichael Kruse "Matrix block sizes should be greater than zero"); 753d123e983SMichael Kruse return {Mc, Nc, Kc}; 754d123e983SMichael Kruse } 755d123e983SMichael Kruse 756d123e983SMichael Kruse /// Create an access relation that is specific to 757d123e983SMichael Kruse /// the matrix multiplication pattern. 758d123e983SMichael Kruse /// 759d123e983SMichael Kruse /// Create an access relation of the following form: 760d123e983SMichael Kruse /// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [OI, O5, OJ] 761d123e983SMichael Kruse /// where I is @p FirstDim, J is @p SecondDim. 762d123e983SMichael Kruse /// 763d123e983SMichael Kruse /// It can be used, for example, to create relations that helps to consequently 764d123e983SMichael Kruse /// access elements of operands of a matrix multiplication after creation of 765d123e983SMichael Kruse /// the BLIS micro and macro kernels. 766d123e983SMichael Kruse /// 767d123e983SMichael Kruse /// @see ScheduleTreeOptimizer::createMicroKernel 768d123e983SMichael Kruse /// @see ScheduleTreeOptimizer::createMacroKernel 769d123e983SMichael Kruse /// 770d123e983SMichael Kruse /// Subsequently, the described access relation is applied to the range of 771d123e983SMichael Kruse /// @p MapOldIndVar, that is used to map original induction variables to 772d123e983SMichael Kruse /// the ones, which are produced by schedule transformations. It helps to 773d123e983SMichael Kruse /// define relations using a new space and, at the same time, keep them 774d123e983SMichael Kruse /// in the original one. 775d123e983SMichael Kruse /// 776d123e983SMichael Kruse /// @param MapOldIndVar The relation, which maps original induction variables 777d123e983SMichael Kruse /// to the ones, which are produced by schedule 778d123e983SMichael Kruse /// transformations. 779d123e983SMichael Kruse /// @param FirstDim, SecondDim The input dimensions that are used to define 780d123e983SMichael Kruse /// the specified access relation. 781d123e983SMichael Kruse /// @return The specified access relation. 782d123e983SMichael Kruse static isl::map getMatMulAccRel(isl::map MapOldIndVar, unsigned FirstDim, 783d123e983SMichael Kruse unsigned SecondDim) { 7840813bd16SRiccardo Mori auto AccessRelSpace = isl::space(MapOldIndVar.ctx(), 0, 9, 3); 785d123e983SMichael Kruse auto AccessRel = isl::map::universe(AccessRelSpace); 786d123e983SMichael Kruse AccessRel = AccessRel.equate(isl::dim::in, FirstDim, isl::dim::out, 0); 787d123e983SMichael Kruse AccessRel = AccessRel.equate(isl::dim::in, 5, isl::dim::out, 1); 788d123e983SMichael Kruse AccessRel = AccessRel.equate(isl::dim::in, SecondDim, isl::dim::out, 2); 789d123e983SMichael Kruse return MapOldIndVar.apply_range(AccessRel); 790d123e983SMichael Kruse } 791d123e983SMichael Kruse 792d123e983SMichael Kruse static isl::schedule_node createExtensionNode(isl::schedule_node Node, 793d123e983SMichael Kruse isl::map ExtensionMap) { 794d123e983SMichael Kruse auto Extension = isl::union_map(ExtensionMap); 795d123e983SMichael Kruse auto NewNode = isl::schedule_node::from_extension(Extension); 796d123e983SMichael Kruse return Node.graft_before(NewNode); 797d123e983SMichael Kruse } 798d123e983SMichael Kruse 799a56bd7deSMichael Kruse static isl::schedule_node optimizePackedB(isl::schedule_node Node, 800a56bd7deSMichael Kruse ScopStmt *Stmt, isl::map MapOldIndVar, 801a56bd7deSMichael Kruse MicroKernelParamsTy MicroParams, 802a56bd7deSMichael Kruse MacroKernelParamsTy MacroParams, 803a56bd7deSMichael Kruse MatMulInfoTy &MMI) { 804a56bd7deSMichael Kruse Scop *S = Stmt->getParent(); 805a56bd7deSMichael Kruse isl::set Domain = Stmt->getDomain(); 806a56bd7deSMichael Kruse 807a56bd7deSMichael Kruse // Create packed array. 808a56bd7deSMichael Kruse unsigned FirstDimSize = MacroParams.Nc / MicroParams.Nr; 809a56bd7deSMichael Kruse unsigned SecondDimSize = MacroParams.Kc; 810a56bd7deSMichael Kruse unsigned ThirdDimSize = MicroParams.Nr; 811a56bd7deSMichael Kruse ScopArrayInfo *PackedB = 812a56bd7deSMichael Kruse S->createScopArrayInfo(MMI.B->getElementType(), "Packed_B", 813a56bd7deSMichael Kruse {FirstDimSize, SecondDimSize, ThirdDimSize}); 814a56bd7deSMichael Kruse 815a56bd7deSMichael Kruse // Compute the access relation for copying from B to PackedB. 816a56bd7deSMichael Kruse isl::map AccRelB = MMI.B->getLatestAccessRelation(); 817a56bd7deSMichael Kruse isl::map AccRelPackedB = getMatMulAccRel(MapOldIndVar, 3, 7); 818a56bd7deSMichael Kruse AccRelPackedB = 819a56bd7deSMichael Kruse AccRelPackedB.set_tuple_id(isl::dim::out, PackedB->getBasePtrId()); 820a56bd7deSMichael Kruse 821a56bd7deSMichael Kruse // Create the copy statement and redirect access. 822a56bd7deSMichael Kruse ScopStmt *CopyStmt = S->addScopStmt(AccRelB, AccRelPackedB, Domain); 823a56bd7deSMichael Kruse MMI.B->setNewAccessRelation(AccRelPackedB); 824a56bd7deSMichael Kruse 82544596fe6SRiccardo Mori unsigned Dim = unsignedFromIslSize(MapOldIndVar.range_tuple_dim()); 82644596fe6SRiccardo Mori assert(Dim >= 2); 827a56bd7deSMichael Kruse // Insert into the schedule tree. 82844596fe6SRiccardo Mori isl::map ExtMap = MapOldIndVar.project_out(isl::dim::out, 2, Dim - 2); 829a56bd7deSMichael Kruse ExtMap = ExtMap.reverse(); 830a56bd7deSMichael Kruse ExtMap = ExtMap.fix_si(isl::dim::out, MMI.i, 0); 831a56bd7deSMichael Kruse ExtMap = ExtMap.intersect_range(Domain); 832a56bd7deSMichael Kruse ExtMap = ExtMap.set_tuple_id(isl::dim::out, CopyStmt->getDomainId()); 833a56bd7deSMichael Kruse return createExtensionNode(Node, ExtMap); 834a56bd7deSMichael Kruse } 835a56bd7deSMichael Kruse 836a56bd7deSMichael Kruse static isl::schedule_node optimizePackedA(isl::schedule_node Node, ScopStmt *, 837a56bd7deSMichael Kruse isl::map MapOldIndVar, 838a56bd7deSMichael Kruse MicroKernelParamsTy MicroParams, 839a56bd7deSMichael Kruse MacroKernelParamsTy MacroParams, 840a56bd7deSMichael Kruse MatMulInfoTy &MMI) { 841a56bd7deSMichael Kruse isl::id InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in); 842a56bd7deSMichael Kruse ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user()); 843a56bd7deSMichael Kruse isl::set Domain = Stmt->getDomain(); 844a56bd7deSMichael Kruse isl::id DomainId = Domain.get_tuple_id(); 845a56bd7deSMichael Kruse 846a56bd7deSMichael Kruse // Create the packed array. 847a56bd7deSMichael Kruse unsigned FirstDimSize = MacroParams.Mc / MicroParams.Mr; 848a56bd7deSMichael Kruse unsigned SecondDimSize = MacroParams.Kc; 849a56bd7deSMichael Kruse unsigned ThirdDimSize = MicroParams.Mr; 850a56bd7deSMichael Kruse ScopArrayInfo *PackedA = Stmt->getParent()->createScopArrayInfo( 851a56bd7deSMichael Kruse MMI.A->getElementType(), "Packed_A", 852a56bd7deSMichael Kruse {FirstDimSize, SecondDimSize, ThirdDimSize}); 853a56bd7deSMichael Kruse 854a56bd7deSMichael Kruse // Compute the access relation for copying from A to PackedA. 855a56bd7deSMichael Kruse isl::map AccRelA = MMI.A->getLatestAccessRelation(); 856a56bd7deSMichael Kruse isl::map AccRelPackedA = getMatMulAccRel(MapOldIndVar, 4, 6); 857a56bd7deSMichael Kruse AccRelPackedA = 858a56bd7deSMichael Kruse AccRelPackedA.set_tuple_id(isl::dim::out, PackedA->getBasePtrId()); 859a56bd7deSMichael Kruse // { MemrefA[] -> PackedA[] } 860a56bd7deSMichael Kruse isl::map PackedATranslator = AccRelPackedA.apply_domain(AccRelA); 861a56bd7deSMichael Kruse 862a56bd7deSMichael Kruse // Compute the domain for the copy statement. 863a56bd7deSMichael Kruse // Construct the copy statement domain out of the 3 outermost scatter 864a56bd7deSMichael Kruse // dimensions (to match the 3 band nodes surrounding the extension node) and 865a56bd7deSMichael Kruse // the array elements to copy (one statement instance per array element). 866a56bd7deSMichael Kruse // { Scatter[] } 867a56bd7deSMichael Kruse isl::set ScatterDomain = MapOldIndVar.intersect_domain(Domain).range(); 868a56bd7deSMichael Kruse // { Scatter[] -> OutermostScatter[] } 869a56bd7deSMichael Kruse isl::map OuterDomainMap = 870a56bd7deSMichael Kruse makeIdentityMap(ScatterDomain, true).project_out(isl::dim::out, 3, 6); 871a56bd7deSMichael Kruse // { Scatter[] -> MemrefA[] } 872a56bd7deSMichael Kruse isl::map CopyFrom = MapOldIndVar.reverse().apply_range(AccRelA); 873a56bd7deSMichael Kruse // { Scatter[] -> CopyStmt[] } 874a56bd7deSMichael Kruse isl::map DomainTranslator = OuterDomainMap.range_product(CopyFrom); 875a56bd7deSMichael Kruse // { CopyStmt[] } 876a56bd7deSMichael Kruse isl::set CopyDomain = DomainTranslator.range(); 877a56bd7deSMichael Kruse 878a56bd7deSMichael Kruse // Translate the access relations to the new domain. 879a56bd7deSMichael Kruse // { CopyStmt[] -> MemrefA[] } 880a56bd7deSMichael Kruse CopyFrom = CopyFrom.apply_domain(DomainTranslator); 881a56bd7deSMichael Kruse // { CopyStmt[] -> PackedA[] } 882a56bd7deSMichael Kruse isl::map CopyTo = CopyFrom.apply_range(PackedATranslator); 883a56bd7deSMichael Kruse 884a56bd7deSMichael Kruse // Create the copy statement and redirect access. 885a56bd7deSMichael Kruse ScopStmt *CopyStmt = 886a56bd7deSMichael Kruse Stmt->getParent()->addScopStmt(CopyFrom, CopyTo, CopyDomain); 887a56bd7deSMichael Kruse MMI.A->setNewAccessRelation(AccRelPackedA); 888a56bd7deSMichael Kruse 889a56bd7deSMichael Kruse // Insert into the schedule tree. 890a56bd7deSMichael Kruse // { Scatter[] -> CopyStmt[] } 891a56bd7deSMichael Kruse isl::map ExtScatterCopy = makeIdentityMap(CopyStmt->getDomain(), true); 892a56bd7deSMichael Kruse ExtScatterCopy = ExtScatterCopy.project_out(isl::dim::in, 3, 2); 893a56bd7deSMichael Kruse return createExtensionNode(Node, ExtScatterCopy); 894a56bd7deSMichael Kruse } 895a56bd7deSMichael Kruse 896d123e983SMichael Kruse /// Apply the packing transformation. 897d123e983SMichael Kruse /// 898d123e983SMichael Kruse /// The packing transformation can be described as a data-layout 899d123e983SMichael Kruse /// transformation that requires to introduce a new array, copy data 900d123e983SMichael Kruse /// to the array, and change memory access locations to reference the array. 901d123e983SMichael Kruse /// It can be used to ensure that elements of the new array are read in-stride 902d123e983SMichael Kruse /// access, aligned to cache lines boundaries, and preloaded into certain cache 903d123e983SMichael Kruse /// levels. 904d123e983SMichael Kruse /// 905d123e983SMichael Kruse /// As an example let us consider the packing of the array A that would help 906d123e983SMichael Kruse /// to read its elements with in-stride access. An access to the array A 907d123e983SMichael Kruse /// is represented by an access relation that has the form 908d123e983SMichael Kruse /// S[i, j, k] -> A[i, k]. The scheduling function of the SCoP statement S has 909d123e983SMichael Kruse /// the form S[i,j, k] -> [floor((j mod Nc) / Nr), floor((i mod Mc) / Mr), 910d123e983SMichael Kruse /// k mod Kc, j mod Nr, i mod Mr]. 911d123e983SMichael Kruse /// 912d123e983SMichael Kruse /// To ensure that elements of the array A are read in-stride access, we add 913d123e983SMichael Kruse /// a new array Packed_A[Mc/Mr][Kc][Mr] to the SCoP, using 914d123e983SMichael Kruse /// Scop::createScopArrayInfo, change the access relation 915d123e983SMichael Kruse /// S[i, j, k] -> A[i, k] to 916d123e983SMichael Kruse /// S[i, j, k] -> Packed_A[floor((i mod Mc) / Mr), k mod Kc, i mod Mr], using 917d123e983SMichael Kruse /// MemoryAccess::setNewAccessRelation, and copy the data to the array, using 918d123e983SMichael Kruse /// the copy statement created by Scop::addScopStmt. 919d123e983SMichael Kruse /// 920d123e983SMichael Kruse /// @param Node The schedule node to be optimized. 921d123e983SMichael Kruse /// @param MapOldIndVar The relation, which maps original induction variables 922d123e983SMichael Kruse /// to the ones, which are produced by schedule 923d123e983SMichael Kruse /// transformations. 924d123e983SMichael Kruse /// @param MicroParams, MacroParams Parameters of the BLIS kernel 925d123e983SMichael Kruse /// to be taken into account. 926d123e983SMichael Kruse /// @param MMI Parameters of the matrix multiplication operands. 927d123e983SMichael Kruse /// @return The optimized schedule node. 928d123e983SMichael Kruse static isl::schedule_node 929d123e983SMichael Kruse optimizeDataLayoutMatrMulPattern(isl::schedule_node Node, isl::map MapOldIndVar, 930d123e983SMichael Kruse MicroKernelParamsTy MicroParams, 931d123e983SMichael Kruse MacroKernelParamsTy MacroParams, 932d123e983SMichael Kruse MatMulInfoTy &MMI) { 933a56bd7deSMichael Kruse isl::id InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in); 934a56bd7deSMichael Kruse ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user()); 935d123e983SMichael Kruse 936d123e983SMichael Kruse Node = Node.parent().parent().parent().parent().parent().parent(); 937a56bd7deSMichael Kruse Node = isl::manage(isl_schedule_node_band_split(Node.release(), 2)); 938d123e983SMichael Kruse 939d123e983SMichael Kruse Node = Node.child(0); 940a56bd7deSMichael Kruse Node = 941a56bd7deSMichael Kruse optimizePackedB(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI); 942d123e983SMichael Kruse 943a56bd7deSMichael Kruse Node = Node.child(0); 944a56bd7deSMichael Kruse Node = 945a56bd7deSMichael Kruse optimizePackedA(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI); 946a56bd7deSMichael Kruse 947d123e983SMichael Kruse return Node.child(0).child(0).child(0).child(0).child(0); 948d123e983SMichael Kruse } 949d123e983SMichael Kruse 950d123e983SMichael Kruse /// Get a relation mapping induction variables produced by schedule 951d123e983SMichael Kruse /// transformations to the original ones. 952d123e983SMichael Kruse /// 953d123e983SMichael Kruse /// @param Node The schedule node produced as the result of creation 954d123e983SMichael Kruse /// of the BLIS kernels. 955d123e983SMichael Kruse /// @param MicroKernelParams, MacroKernelParams Parameters of the BLIS kernel 956d123e983SMichael Kruse /// to be taken into account. 957d123e983SMichael Kruse /// @return The relation mapping original induction variables to the ones 958d123e983SMichael Kruse /// produced by schedule transformation. 959d123e983SMichael Kruse /// @see ScheduleTreeOptimizer::createMicroKernel 960d123e983SMichael Kruse /// @see ScheduleTreeOptimizer::createMacroKernel 961d123e983SMichael Kruse /// @see getMacroKernelParams 962d123e983SMichael Kruse static isl::map 963d123e983SMichael Kruse getInductionVariablesSubstitution(isl::schedule_node Node, 964d123e983SMichael Kruse MicroKernelParamsTy MicroKernelParams, 965d123e983SMichael Kruse MacroKernelParamsTy MacroKernelParams) { 966d123e983SMichael Kruse auto Child = Node.child(0); 967d123e983SMichael Kruse auto UnMapOldIndVar = Child.get_prefix_schedule_union_map(); 968d123e983SMichael Kruse auto MapOldIndVar = isl::map::from_union_map(UnMapOldIndVar); 96944596fe6SRiccardo Mori unsigned Dim = unsignedFromIslSize(MapOldIndVar.range_tuple_dim()); 97044596fe6SRiccardo Mori if (Dim > 9u) 97144596fe6SRiccardo Mori return MapOldIndVar.project_out(isl::dim::out, 0, Dim - 9); 972d123e983SMichael Kruse return MapOldIndVar; 973d123e983SMichael Kruse } 974d123e983SMichael Kruse 975d123e983SMichael Kruse /// Isolate a set of partial tile prefixes and unroll the isolated part. 976d123e983SMichael Kruse /// 977d123e983SMichael Kruse /// The set should ensure that it contains only partial tile prefixes that have 978d123e983SMichael Kruse /// exactly Mr x Nr iterations of the two innermost loops produced by 979d123e983SMichael Kruse /// the optimization of the matrix multiplication. Mr and Nr are parameters of 980d123e983SMichael Kruse /// the micro-kernel. 981d123e983SMichael Kruse /// 982d123e983SMichael Kruse /// In case of parametric bounds, this helps to auto-vectorize the unrolled 983d123e983SMichael Kruse /// innermost loops, using the SLP vectorizer. 984d123e983SMichael Kruse /// 985d123e983SMichael Kruse /// @param Node The schedule node to be modified. 986d123e983SMichael Kruse /// @param MicroKernelParams Parameters of the micro-kernel 987d123e983SMichael Kruse /// to be taken into account. 988d123e983SMichael Kruse /// @return The modified isl_schedule_node. 989d123e983SMichael Kruse static isl::schedule_node 990d123e983SMichael Kruse isolateAndUnrollMatMulInnerLoops(isl::schedule_node Node, 991bd93df93SMichael Kruse MicroKernelParamsTy MicroKernelParams) { 992d3fdbda6SRiccardo Mori isl::schedule_node Child = Node.child(0); 993d123e983SMichael Kruse isl::union_map UnMapOldIndVar = Child.get_prefix_schedule_relation(); 994d123e983SMichael Kruse isl::set Prefix = isl::map::from_union_map(UnMapOldIndVar).range(); 99544596fe6SRiccardo Mori unsigned Dims = unsignedFromIslSize(Prefix.tuple_dim()); 99644596fe6SRiccardo Mori assert(Dims >= 1); 997d123e983SMichael Kruse Prefix = Prefix.project_out(isl::dim::set, Dims - 1, 1); 998d123e983SMichael Kruse Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Nr); 999d123e983SMichael Kruse Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Mr); 1000d123e983SMichael Kruse 1001d123e983SMichael Kruse isl::union_set IsolateOption = 1002d123e983SMichael Kruse getIsolateOptions(Prefix.add_dims(isl::dim::set, 3), 3); 10030813bd16SRiccardo Mori isl::ctx Ctx = Node.ctx(); 1004d123e983SMichael Kruse auto Options = IsolateOption.unite(getDimOptions(Ctx, "unroll")); 1005d123e983SMichael Kruse Options = Options.unite(getUnrollIsolatedSetOptions(Ctx)); 1006d3fdbda6SRiccardo Mori Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options); 1007d123e983SMichael Kruse Node = Node.parent().parent().parent(); 1008d123e983SMichael Kruse IsolateOption = getIsolateOptions(Prefix, 3); 1009d123e983SMichael Kruse Options = IsolateOption.unite(getDimOptions(Ctx, "separate")); 1010d3fdbda6SRiccardo Mori Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options); 1011d123e983SMichael Kruse Node = Node.child(0).child(0).child(0); 1012d123e983SMichael Kruse return Node; 1013d123e983SMichael Kruse } 1014d123e983SMichael Kruse 1015d123e983SMichael Kruse /// Insert "Loop Vectorizer Disabled" mark node. 1016d123e983SMichael Kruse /// 1017d123e983SMichael Kruse /// @param Node The child of the mark node to be inserted. 1018d123e983SMichael Kruse /// @return The modified isl_schedule_node. 1019d123e983SMichael Kruse static isl::schedule_node markLoopVectorizerDisabled(isl::schedule_node Node) { 10200813bd16SRiccardo Mori auto Id = isl::id::alloc(Node.ctx(), "Loop Vectorizer Disabled", nullptr); 1021d123e983SMichael Kruse return Node.insert_mark(Id).child(0); 1022d123e983SMichael Kruse } 1023d123e983SMichael Kruse 1024d123e983SMichael Kruse /// Restore the initial ordering of dimensions of the band node 1025d123e983SMichael Kruse /// 1026d123e983SMichael Kruse /// In case the band node represents all the dimensions of the iteration 1027d123e983SMichael Kruse /// domain, recreate the band node to restore the initial ordering of the 1028d123e983SMichael Kruse /// dimensions. 1029d123e983SMichael Kruse /// 1030d123e983SMichael Kruse /// @param Node The band node to be modified. 1031d123e983SMichael Kruse /// @return The modified schedule node. 1032d123e983SMichael Kruse static isl::schedule_node 1033d123e983SMichael Kruse getBandNodeWithOriginDimOrder(isl::schedule_node Node) { 1034d123e983SMichael Kruse assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band); 1035d123e983SMichael Kruse if (isl_schedule_node_get_type(Node.child(0).get()) != isl_schedule_node_leaf) 1036d123e983SMichael Kruse return Node; 1037d123e983SMichael Kruse auto Domain = Node.get_universe_domain(); 1038d123e983SMichael Kruse assert(isl_union_set_n_set(Domain.get()) == 1); 1039d3fdbda6SRiccardo Mori if (Node.get_schedule_depth().release() != 0 || 104044596fe6SRiccardo Mori (unsignedFromIslSize(isl::set(Domain).tuple_dim()) != 104144596fe6SRiccardo Mori unsignedFromIslSize(Node.as<isl::schedule_node_band>().n_member()))) 1042d123e983SMichael Kruse return Node; 1043d123e983SMichael Kruse Node = isl::manage(isl_schedule_node_delete(Node.copy())); 1044d123e983SMichael Kruse auto PartialSchedulePwAff = Domain.identity_union_pw_multi_aff(); 1045d123e983SMichael Kruse auto PartialScheduleMultiPwAff = 1046d123e983SMichael Kruse isl::multi_union_pw_aff(PartialSchedulePwAff); 1047d123e983SMichael Kruse PartialScheduleMultiPwAff = 1048d123e983SMichael Kruse PartialScheduleMultiPwAff.reset_tuple_id(isl::dim::set); 1049d123e983SMichael Kruse return Node.insert_partial_schedule(PartialScheduleMultiPwAff); 1050d123e983SMichael Kruse } 1051d123e983SMichael Kruse 1052d123e983SMichael Kruse static isl::schedule_node optimizeMatMulPattern(isl::schedule_node Node, 1053d123e983SMichael Kruse const TargetTransformInfo *TTI, 1054d123e983SMichael Kruse MatMulInfoTy &MMI) { 1055d123e983SMichael Kruse assert(TTI && "The target transform info should be provided."); 1056d123e983SMichael Kruse int DimOutNum = isl_schedule_node_band_n_member(Node.get()); 1057d123e983SMichael Kruse assert(DimOutNum > 2 && "In case of the matrix multiplication the loop nest " 1058d123e983SMichael Kruse "and, consequently, the corresponding scheduling " 1059d123e983SMichael Kruse "functions have at least three dimensions."); 1060d123e983SMichael Kruse Node = getBandNodeWithOriginDimOrder(Node); 1061d123e983SMichael Kruse Node = permuteBandNodeDimensions(Node, MMI.i, DimOutNum - 3); 1062d123e983SMichael Kruse int NewJ = MMI.j == DimOutNum - 3 ? MMI.i : MMI.j; 1063d123e983SMichael Kruse int NewK = MMI.k == DimOutNum - 3 ? MMI.i : MMI.k; 1064d123e983SMichael Kruse Node = permuteBandNodeDimensions(Node, NewJ, DimOutNum - 2); 1065d123e983SMichael Kruse NewK = NewK == DimOutNum - 2 ? NewJ : NewK; 1066d123e983SMichael Kruse Node = permuteBandNodeDimensions(Node, NewK, DimOutNum - 1); 1067d123e983SMichael Kruse auto MicroKernelParams = getMicroKernelParams(TTI, MMI); 1068d123e983SMichael Kruse auto MacroKernelParams = getMacroKernelParams(TTI, MicroKernelParams, MMI); 1069d123e983SMichael Kruse Node = createMacroKernel(Node, MacroKernelParams); 1070d123e983SMichael Kruse Node = createMicroKernel(Node, MicroKernelParams); 1071d123e983SMichael Kruse if (MacroKernelParams.Mc == 1 || MacroKernelParams.Nc == 1 || 1072d123e983SMichael Kruse MacroKernelParams.Kc == 1) 1073d123e983SMichael Kruse return Node; 1074d123e983SMichael Kruse auto MapOldIndVar = getInductionVariablesSubstitution(Node, MicroKernelParams, 1075d123e983SMichael Kruse MacroKernelParams); 10767c7978a1Spatacca if (MapOldIndVar.is_null()) 1077d123e983SMichael Kruse return Node; 1078d123e983SMichael Kruse Node = markLoopVectorizerDisabled(Node.parent()).child(0); 1079d123e983SMichael Kruse Node = isolateAndUnrollMatMulInnerLoops(Node, MicroKernelParams); 1080d123e983SMichael Kruse return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams, 1081d123e983SMichael Kruse MacroKernelParams, MMI); 1082d123e983SMichael Kruse } 1083d123e983SMichael Kruse 1084d123e983SMichael Kruse /// Check if this node contains a partial schedule that could 1085d123e983SMichael Kruse /// probably be optimized with analytical modeling. 1086d123e983SMichael Kruse /// 1087d123e983SMichael Kruse /// isMatrMultPattern tries to determine whether the following conditions 1088d123e983SMichael Kruse /// are true: 1089d123e983SMichael Kruse /// 1. the partial schedule contains only one statement. 1090d123e983SMichael Kruse /// 2. there are exactly three input dimensions. 1091d123e983SMichael Kruse /// 3. all memory accesses of the statement will have stride 0 or 1, if we 1092d123e983SMichael Kruse /// interchange loops (switch the variable used in the inner loop to 1093d123e983SMichael Kruse /// the outer loop). 1094d123e983SMichael Kruse /// 4. all memory accesses of the statement except from the last one, are 1095d123e983SMichael Kruse /// read memory access and the last one is write memory access. 1096d123e983SMichael Kruse /// 5. all subscripts of the last memory access of the statement don't 1097d123e983SMichael Kruse /// contain the variable used in the inner loop. 1098d123e983SMichael Kruse /// If this is the case, we could try to use an approach that is similar to 1099d123e983SMichael Kruse /// the one used to get close-to-peak performance of matrix multiplications. 1100d123e983SMichael Kruse /// 1101d123e983SMichael Kruse /// @param Node The node to check. 1102d123e983SMichael Kruse /// @param D The SCoP dependencies. 1103d123e983SMichael Kruse /// @param MMI Parameters of the matrix multiplication operands. 1104d123e983SMichael Kruse static bool isMatrMultPattern(isl::schedule_node Node, const Dependences *D, 1105d123e983SMichael Kruse MatMulInfoTy &MMI) { 1106d123e983SMichael Kruse auto PartialSchedule = isl::manage( 1107d123e983SMichael Kruse isl_schedule_node_band_get_partial_schedule_union_map(Node.get())); 1108b02c7e2bSRoman Gareev if (isl_schedule_node_band_n_member(Node.get()) < 3 || 1109d3fdbda6SRiccardo Mori Node.get_schedule_depth().release() != 0 || 1110d123e983SMichael Kruse isl_union_map_n_map(PartialSchedule.get()) != 1) 1111d123e983SMichael Kruse return false; 1112d123e983SMichael Kruse auto NewPartialSchedule = isl::map::from_union_map(PartialSchedule); 1113d123e983SMichael Kruse if (containsMatrMult(NewPartialSchedule, D, MMI)) 1114d123e983SMichael Kruse return true; 1115d123e983SMichael Kruse return false; 1116d123e983SMichael Kruse } 1117d123e983SMichael Kruse 1118b02c7e2bSRoman Gareev /// Get the dimension size. 1119b02c7e2bSRoman Gareev /// 1120b02c7e2bSRoman Gareev /// Return the size of the dimension @p Pos, which is obtained from @p SAI. 1121b02c7e2bSRoman Gareev /// Return -1 in the case of the first dimension of a multi-dimensional array, 1122b02c7e2bSRoman Gareev /// since the ScopArrayInfo class does not carry size information. 1123b02c7e2bSRoman Gareev /// 1124b02c7e2bSRoman Gareev /// @param SAI The information about the array. 1125b02c7e2bSRoman Gareev /// @param Pos The position of the dimension. 1126b02c7e2bSRoman Gareev /// @return The size of the dimension. 1127b02c7e2bSRoman Gareev static int getDimSize(const ScopArrayInfo *SAI, unsigned Pos) { 1128b02c7e2bSRoman Gareev if (Pos == 0) 1129b02c7e2bSRoman Gareev return -1; 1130b02c7e2bSRoman Gareev const llvm::SCEV *SCEVDimSize = SAI->getDimensionSize(Pos); 1131b02c7e2bSRoman Gareev assert(SCEVDimSize); 1132b02c7e2bSRoman Gareev auto *ConstantDimSize = dyn_cast<const SCEVConstant>(SCEVDimSize); 1133b02c7e2bSRoman Gareev assert(ConstantDimSize); 1134b02c7e2bSRoman Gareev auto *IntDimSize = dyn_cast<ConstantInt>(ConstantDimSize->getValue()); 1135b02c7e2bSRoman Gareev assert(IntDimSize); 1136b02c7e2bSRoman Gareev return IntDimSize->getSExtValue(); 1137b02c7e2bSRoman Gareev } 1138b02c7e2bSRoman Gareev 1139b02c7e2bSRoman Gareev /// Check whether the access relation has the specified form. 1140b02c7e2bSRoman Gareev /// 1141b02c7e2bSRoman Gareev /// Check that the access relation @p AccMap has the form T[I0, …, In], where 1142b02c7e2bSRoman Gareev /// indexes I0, …, In are specified by @p Dimensions. 1143b02c7e2bSRoman Gareev /// 1144b02c7e2bSRoman Gareev /// @param Domain The domain of the access relation. 1145b02c7e2bSRoman Gareev /// @param AccMap The access relation to be checked. 1146b02c7e2bSRoman Gareev /// @param Dimensions The permutation of the subset of the input dimensions. 1147b02c7e2bSRoman Gareev /// @return True if @p AccMap has the expected form and false, 1148b02c7e2bSRoman Gareev /// otherwise. 1149b02c7e2bSRoman Gareev static bool isCorrectAccessMap(isl::set Domain, isl::map AccMap, 1150b02c7e2bSRoman Gareev ArrayRef<int> Dimensions) { 1151b02c7e2bSRoman Gareev isl::space Space = AccMap.get_space(); 1152b02c7e2bSRoman Gareev if (unsignedFromIslSize(Space.dim(isl::dim::out)) != Dimensions.size()) 1153b02c7e2bSRoman Gareev return false; 1154b02c7e2bSRoman Gareev 1155b02c7e2bSRoman Gareev // Create an access relation of the following form: 1156b02c7e2bSRoman Gareev // [I0, …, Im] -> [Il, …, In], where indexes 1157b02c7e2bSRoman Gareev // Il, …, In are specified by @p Dimensions. 1158b02c7e2bSRoman Gareev isl::map PossibleTensor = isl::map::universe(Space); 1159b02c7e2bSRoman Gareev unsigned DimInSize = unsignedFromIslSize(Space.dim(isl::dim::in)); 1160b02c7e2bSRoman Gareev for (unsigned i = 0; i < Dimensions.size(); i++) { 1161b02c7e2bSRoman Gareev const int InPos = Dimensions[i]; 1162b02c7e2bSRoman Gareev if ((InPos >= static_cast<int>(DimInSize)) || (InPos < 0)) 1163b02c7e2bSRoman Gareev return false; 1164b02c7e2bSRoman Gareev PossibleTensor = 1165b02c7e2bSRoman Gareev PossibleTensor.equate(isl::dim::in, InPos, isl::dim::out, i); 1166b02c7e2bSRoman Gareev } 1167b02c7e2bSRoman Gareev 1168b02c7e2bSRoman Gareev AccMap = AccMap.intersect_domain(Domain); 1169b02c7e2bSRoman Gareev PossibleTensor = PossibleTensor.intersect_domain(Domain); 1170b02c7e2bSRoman Gareev 1171b02c7e2bSRoman Gareev // If AccMap != PossibleTensor here (the two maps have been gisted at 1172b02c7e2bSRoman Gareev // this point), it means that the writes are not complete, or in other 1173b02c7e2bSRoman Gareev // words, it is a Partial write and Partial writes must be rejected. 1174b02c7e2bSRoman Gareev return AccMap.is_equal(PossibleTensor); 1175b02c7e2bSRoman Gareev } 1176b02c7e2bSRoman Gareev 1177b02c7e2bSRoman Gareev /// Check whether the access represents the tensor contraction operand. 1178b02c7e2bSRoman Gareev /// 1179b02c7e2bSRoman Gareev /// Check that the access relation @p AccMap has the form T[i1, …, in]. 1180b02c7e2bSRoman Gareev /// Obtained indexes i1, …, in, their sizes and their permutation are stored 1181b02c7e2bSRoman Gareev /// into @p IndexSet, @p DimensionSizes, and @p Dimensions, respectively. 1182b02c7e2bSRoman Gareev /// 1183b02c7e2bSRoman Gareev /// @param Domain The domain of the access relation. 1184b02c7e2bSRoman Gareev /// @param AccMap The access relation to be checked. 1185b02c7e2bSRoman Gareev /// @param IndexSet The subset of the input dimensions. 1186b02c7e2bSRoman Gareev /// @param DimensionSizes Sizes of the input dimensions of @p Dimensions. 1187b02c7e2bSRoman Gareev /// @param Dimensions The permutation of the subset of the input dimensions. 1188b02c7e2bSRoman Gareev /// @return True if @p AccMap has the expected form and false, 1189b02c7e2bSRoman Gareev /// otherwise. 1190b02c7e2bSRoman Gareev static bool isTCOperandAcc(isl::set Domain, isl::map AccMap, 1191b02c7e2bSRoman Gareev SmallDenseSet<int> &IndexSet, 1192b02c7e2bSRoman Gareev SmallVectorImpl<int> &DimensionSizes, 1193b02c7e2bSRoman Gareev SmallVectorImpl<int> &Dimensions) { 1194b02c7e2bSRoman Gareev isl::id Id = AccMap.get_tuple_id(isl::dim::out); 1195b02c7e2bSRoman Gareev const ScopArrayInfo *SAI = ScopArrayInfo::getFromId(Id); 1196b02c7e2bSRoman Gareev assert(SAI && "AccMap should represent memory access"); 1197b02c7e2bSRoman Gareev 1198b02c7e2bSRoman Gareev // Fix values of output dimensions with respect to their positions. 1199b02c7e2bSRoman Gareev // In the case of the tensor contraction, values of output dimensions are 1200b02c7e2bSRoman Gareev // fixed and form a permutation of a subset of values of input dimensions. 1201b02c7e2bSRoman Gareev // 1202b02c7e2bSRoman Gareev // For example, in the case of Stmt[i][j][k] -> A[k][i], which represents 1203b02c7e2bSRoman Gareev // the operand of the tensor contraction, we get the following map by fixing 1204b02c7e2bSRoman Gareev // the output dimensions Stmt[1][j][0] -> A[0][1]. 1205b02c7e2bSRoman Gareev // 1206b02c7e2bSRoman Gareev // We store the permutation of the subset of the input dimensions {2, 0} into 1207b02c7e2bSRoman Gareev // @p Dimensions. 1208b02c7e2bSRoman Gareev // 1209b02c7e2bSRoman Gareev // The obtained permutation and the isCorrectAccessMap function are used to 1210b02c7e2bSRoman Gareev // check whether the access relation @p AccMap represents the tensor 1211b02c7e2bSRoman Gareev // contraction operand. For example, in the case of 1212b02c7e2bSRoman Gareev // Stmt[i][j][k] -> A[i-1][j+1], we get Stmt[1][0][k] -> A[0][1] and, 1213b02c7e2bSRoman Gareev // consequently, {1, 0}, which is rejected by isCorrectAccessMap, 1214b02c7e2bSRoman Gareev // since it corresponds to Stmt[i][j][k] -> A[j][i]. 1215b02c7e2bSRoman Gareev isl::map CheckMap = isl::manage(AccMap.copy()); 1216b02c7e2bSRoman Gareev unsigned OutDimNum = unsignedFromIslSize(CheckMap.dim(isl::dim::out)); 1217b02c7e2bSRoman Gareev for (unsigned i = 0; i < OutDimNum; i++) 1218b02c7e2bSRoman Gareev CheckMap = CheckMap.fix_si(isl::dim::out, i, i); 1219b02c7e2bSRoman Gareev 1220b02c7e2bSRoman Gareev // Try to obtain the permutation and sizes of corresponding input dimensions. 1221b02c7e2bSRoman Gareev Dimensions.assign(OutDimNum, -1); 1222b02c7e2bSRoman Gareev for (unsigned i : rangeIslSize(0, CheckMap.dim(isl::dim::in))) { 1223b02c7e2bSRoman Gareev isl::val Val = getConstant(CheckMap, isl::dim::in, i); 1224b02c7e2bSRoman Gareev if (!Val.is_int()) 1225b02c7e2bSRoman Gareev continue; 1226b02c7e2bSRoman Gareev int OutPos = -1; 1227b02c7e2bSRoman Gareev llvm::APInt ValAPInt = APIntFromVal(Val); 1228b02c7e2bSRoman Gareev if (ValAPInt.isSignedIntN(32)) 1229b02c7e2bSRoman Gareev OutPos = ValAPInt.getSExtValue(); 1230b02c7e2bSRoman Gareev if ((OutPos < 0) || (OutPos >= static_cast<int>(OutDimNum)) || 1231b02c7e2bSRoman Gareev IndexSet.count(i)) 1232b02c7e2bSRoman Gareev return false; 1233b02c7e2bSRoman Gareev IndexSet.insert(i); 1234b02c7e2bSRoman Gareev Dimensions[OutPos] = i; 1235b02c7e2bSRoman Gareev if (DimensionSizes[i] <= 0) 1236b02c7e2bSRoman Gareev DimensionSizes[i] = getDimSize(SAI, OutPos); 1237b02c7e2bSRoman Gareev } 1238b02c7e2bSRoman Gareev 1239b02c7e2bSRoman Gareev return isCorrectAccessMap(Domain, AccMap, Dimensions); 1240b02c7e2bSRoman Gareev } 1241b02c7e2bSRoman Gareev 1242b02c7e2bSRoman Gareev /// Find the intersection of two sets. 1243b02c7e2bSRoman Gareev /// 1244b02c7e2bSRoman Gareev /// Find the intersection of the set @p A and the set @p B. 1245b02c7e2bSRoman Gareev /// 1246b02c7e2bSRoman Gareev /// @param A, B Sets to intersect. 1247b02c7e2bSRoman Gareev /// @return The set intersection. 1248b02c7e2bSRoman Gareev static SmallDenseSet<int> intersect(const SmallDenseSet<int> &A, 1249b02c7e2bSRoman Gareev const SmallDenseSet<int> &B) { 1250b02c7e2bSRoman Gareev SmallDenseSet<int> Intersection = A; 1251b02c7e2bSRoman Gareev set_intersect(Intersection, B); 1252b02c7e2bSRoman Gareev return Intersection; 1253b02c7e2bSRoman Gareev } 1254b02c7e2bSRoman Gareev 1255b02c7e2bSRoman Gareev /// Check whether the set is a superset. 1256b02c7e2bSRoman Gareev /// 1257b02c7e2bSRoman Gareev /// Check that the set @p A is a superset of @p B. 1258b02c7e2bSRoman Gareev /// 1259b02c7e2bSRoman Gareev /// @param A, B Sets to be checked. 1260b02c7e2bSRoman Gareev /// @return True if the set A is a superset of B. 1261b02c7e2bSRoman Gareev static bool isSuperset(const SmallDenseSet<int> &A, 1262b02c7e2bSRoman Gareev const SmallDenseSet<int> &B) { 1263b02c7e2bSRoman Gareev return intersect(A, B).size() == B.size(); 1264b02c7e2bSRoman Gareev } 1265b02c7e2bSRoman Gareev 1266b02c7e2bSRoman Gareev /// Find the union of two sets. 1267b02c7e2bSRoman Gareev /// 1268b02c7e2bSRoman Gareev /// Find the union of the set @p A and the set @p B. 1269b02c7e2bSRoman Gareev /// 1270b02c7e2bSRoman Gareev /// @param A, B Sets to unite. 1271b02c7e2bSRoman Gareev /// @return The set union. 1272b02c7e2bSRoman Gareev static SmallDenseSet<int> unite(const SmallDenseSet<int> &A, 1273b02c7e2bSRoman Gareev const SmallDenseSet<int> &B) { 1274b02c7e2bSRoman Gareev SmallDenseSet<int> Union = A; 1275b02c7e2bSRoman Gareev set_union(Union, B); 1276b02c7e2bSRoman Gareev return Union; 1277b02c7e2bSRoman Gareev } 1278b02c7e2bSRoman Gareev 1279b02c7e2bSRoman Gareev /// Determine the access that writes to the tensor, which contains 1280b02c7e2bSRoman Gareev /// the result of the tensor contraction. 1281b02c7e2bSRoman Gareev /// 1282b02c7e2bSRoman Gareev /// @param Domain The domain of the statement. 1283b02c7e2bSRoman Gareev /// @param Stmt The statement, which writes to memory. 1284b02c7e2bSRoman Gareev /// @param TCI The information about the tensor contraction. 1285b02c7e2bSRoman Gareev /// @param IandJIndexSet The set, which contains free indexes of tensors. 1286b02c7e2bSRoman Gareev /// @return The determined MemoryAccess, or nullptr if there is no necessary 1287b02c7e2bSRoman Gareev /// access within the SCoP. 1288b02c7e2bSRoman Gareev static MemoryAccess *getWriteAccess(isl::set Domain, ScopStmt *Stmt, 1289b02c7e2bSRoman Gareev TCInfoTy &TCI, 1290b02c7e2bSRoman Gareev SmallDenseSet<int> &IandJIndexSet) { 1291b02c7e2bSRoman Gareev TCI.WriteToC = nullptr; 1292b02c7e2bSRoman Gareev SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(*Stmt); 1293b02c7e2bSRoman Gareev for (MemoryAccess *MemA : reverse(Accesses)) { 1294b02c7e2bSRoman Gareev // A TC-like does not contain write scalar memory accesses 1295b02c7e2bSRoman Gareev if (!MemA->isLatestArrayKind()) 1296b02c7e2bSRoman Gareev return nullptr; 1297b02c7e2bSRoman Gareev // The last memory access should be a write memory access. 1298b02c7e2bSRoman Gareev if (!MemA->isWrite()) 1299b02c7e2bSRoman Gareev return nullptr; 1300b02c7e2bSRoman Gareev 1301b02c7e2bSRoman Gareev isl::map AccMap = MemA->getLatestAccessRelation(); 1302b02c7e2bSRoman Gareev if (!isTCOperandAcc(Domain, AccMap, IandJIndexSet, TCI.DimensionSizes, 1303b02c7e2bSRoman Gareev TCI.CDimensions)) 1304b02c7e2bSRoman Gareev return nullptr; 1305b02c7e2bSRoman Gareev 1306b02c7e2bSRoman Gareev return MemA; 1307b02c7e2bSRoman Gareev } 1308b02c7e2bSRoman Gareev return nullptr; 1309b02c7e2bSRoman Gareev } 1310b02c7e2bSRoman Gareev 1311b02c7e2bSRoman Gareev /// Determine an access, which reads elements of an operand of the tensor 1312b02c7e2bSRoman Gareev /// contraction 1313b02c7e2bSRoman Gareev /// 1314b02c7e2bSRoman Gareev /// @param MemAccessPtr The access, which reads elements of the tensor. 1315b02c7e2bSRoman Gareev /// @param IndexSet The set, which contains indexes of the tensors. 1316b02c7e2bSRoman Gareev /// @param IandJIndexSet The set, which contains free indexes of tensors. 1317b02c7e2bSRoman Gareev /// @param Dimensions The permutation of the subset of the input dimensions. 1318b02c7e2bSRoman Gareev /// @param TCI The information about the tensor contraction. 1319b02c7e2bSRoman Gareev /// @return True if the memory access @p MemAccessPtr corresponds 1320b02c7e2bSRoman Gareev /// to the tensor contraction. 1321b02c7e2bSRoman Gareev static bool setReadAccess(MemoryAccess *MemAccessPtr, 1322b02c7e2bSRoman Gareev const SmallDenseSet<int> &IndexSet, 1323b02c7e2bSRoman Gareev const SmallDenseSet<int> &IandJIndexSet, 1324b02c7e2bSRoman Gareev ArrayRef<int> Dimensions, TCInfoTy &TCI) { 1325b02c7e2bSRoman Gareev if (!TCI.A) { 1326b02c7e2bSRoman Gareev // Probably IndexSet is a union of I and P sets. 1327b02c7e2bSRoman Gareev if (!isSuperset(IndexSet, TCI.P)) 1328b02c7e2bSRoman Gareev return false; 1329b02c7e2bSRoman Gareev 1330b02c7e2bSRoman Gareev // Obtain the set I. 1331b02c7e2bSRoman Gareev TCI.I = set_difference(IndexSet, TCI.P); 1332b02c7e2bSRoman Gareev if (!isSuperset(IandJIndexSet, TCI.I)) 1333b02c7e2bSRoman Gareev return false; 1334b02c7e2bSRoman Gareev 1335b02c7e2bSRoman Gareev // Obtain the set J. 1336b02c7e2bSRoman Gareev TCI.J = set_difference(IandJIndexSet, TCI.I); 1337b02c7e2bSRoman Gareev 1338b02c7e2bSRoman Gareev // Set the first operand of the tensor contraction. 1339b02c7e2bSRoman Gareev TCI.A = MemAccessPtr; 1340b02c7e2bSRoman Gareev llvm::replace(TCI.ADimensions, TCI.ADimensions.begin(), 1341b02c7e2bSRoman Gareev TCI.ADimensions.end(), Dimensions.begin(), Dimensions.end()); 1342b02c7e2bSRoman Gareev return true; 1343b02c7e2bSRoman Gareev } 1344b02c7e2bSRoman Gareev 1345b02c7e2bSRoman Gareev if (!TCI.B) { 1346b02c7e2bSRoman Gareev // IndexSet should be a union of J and P sets. 1347b02c7e2bSRoman Gareev if (unite(TCI.P, TCI.J) != IndexSet) 1348b02c7e2bSRoman Gareev return false; 1349b02c7e2bSRoman Gareev 1350b02c7e2bSRoman Gareev // Set the second operand of the tensor contraction. 1351b02c7e2bSRoman Gareev TCI.B = MemAccessPtr; 1352b02c7e2bSRoman Gareev llvm::replace(TCI.BDimensions, TCI.BDimensions.begin(), 1353b02c7e2bSRoman Gareev TCI.BDimensions.end(), Dimensions.begin(), Dimensions.end()); 1354b02c7e2bSRoman Gareev return true; 1355b02c7e2bSRoman Gareev } 1356b02c7e2bSRoman Gareev 1357b02c7e2bSRoman Gareev return false; 1358b02c7e2bSRoman Gareev } 1359b02c7e2bSRoman Gareev 1360b02c7e2bSRoman Gareev /// Check that all memory accesses of the statement, except from the last 1361b02c7e2bSRoman Gareev /// one, are read memory accesses, which read elements of operands of the tensor 1362b02c7e2bSRoman Gareev /// contraction and its result. 1363b02c7e2bSRoman Gareev /// 1364b02c7e2bSRoman Gareev /// @param Domain The domain of the statement. 1365b02c7e2bSRoman Gareev /// @param Stmt The statement, which writes to memory. 1366b02c7e2bSRoman Gareev /// @param TCI The information about the tensor contraction. 1367b02c7e2bSRoman Gareev /// @param IandJIndexSet The set, which contains free indexes of tensors. 1368b02c7e2bSRoman Gareev /// @return True if all read memory accesses of the statement @p Stmt correspond 1369b02c7e2bSRoman Gareev /// to the tensor contraction. 1370b02c7e2bSRoman Gareev static bool setReadAccesses(isl::set Domain, ScopStmt *Stmt, TCInfoTy &TCI, 1371b02c7e2bSRoman Gareev SmallDenseSet<int> &IandJIndexSet) { 1372b02c7e2bSRoman Gareev TCI.A = nullptr; 1373b02c7e2bSRoman Gareev TCI.B = nullptr; 1374b02c7e2bSRoman Gareev TCI.ReadFromC = nullptr; 1375b02c7e2bSRoman Gareev SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(*Stmt); 1376b02c7e2bSRoman Gareev for (auto *MemA = Accesses.begin(); *MemA != TCI.WriteToC; MemA++) { 1377b02c7e2bSRoman Gareev MemoryAccess *MemAccessPtr = *MemA; 1378b02c7e2bSRoman Gareev 1379b02c7e2bSRoman Gareev // All memory accesses, except from the last one, should be read memory 1380b02c7e2bSRoman Gareev // accesses. 1381b02c7e2bSRoman Gareev if (MemAccessPtr->isWrite()) 1382b02c7e2bSRoman Gareev return false; 1383b02c7e2bSRoman Gareev 1384b02c7e2bSRoman Gareev isl::map AccMap = MemAccessPtr->getLatestAccessRelation(); 1385b02c7e2bSRoman Gareev 1386b02c7e2bSRoman Gareev if (!MemAccessPtr->isLatestArrayKind()) { 1387b02c7e2bSRoman Gareev // Check whether the scalar read memory access is not partial. 1388b02c7e2bSRoman Gareev if (!Domain.is_subset(AccMap.domain())) 1389b02c7e2bSRoman Gareev return false; 1390b02c7e2bSRoman Gareev continue; 1391b02c7e2bSRoman Gareev return false; 1392b02c7e2bSRoman Gareev } 1393b02c7e2bSRoman Gareev 1394b02c7e2bSRoman Gareev // There is only one memory access, which reads elements of the result of 1395b02c7e2bSRoman Gareev // the tensor contraction. 1396b02c7e2bSRoman Gareev if (AccMap.is_equal(TCI.WriteToC->getLatestAccessRelation())) { 1397b02c7e2bSRoman Gareev if (TCI.ReadFromC) 1398b02c7e2bSRoman Gareev return false; 1399b02c7e2bSRoman Gareev TCI.ReadFromC = MemAccessPtr; 1400b02c7e2bSRoman Gareev continue; 1401b02c7e2bSRoman Gareev } 1402b02c7e2bSRoman Gareev 1403b02c7e2bSRoman Gareev SmallVector<int> Dimensions; 1404b02c7e2bSRoman Gareev SmallDenseSet<int> IndexSet; 1405b02c7e2bSRoman Gareev if (!isTCOperandAcc(Domain, AccMap, IndexSet, TCI.DimensionSizes, 1406b02c7e2bSRoman Gareev Dimensions)) 1407b02c7e2bSRoman Gareev return false; 1408b02c7e2bSRoman Gareev 1409b02c7e2bSRoman Gareev if (!setReadAccess(MemAccessPtr, IndexSet, IandJIndexSet, Dimensions, TCI)) 1410b02c7e2bSRoman Gareev return false; 1411b02c7e2bSRoman Gareev } 1412b02c7e2bSRoman Gareev 1413b02c7e2bSRoman Gareev // Check that there are read memory accesses, which read elements of operands 1414b02c7e2bSRoman Gareev // of the tensor contraction and its result. 1415b02c7e2bSRoman Gareev return TCI.ReadFromC && TCI.A && TCI.B; 1416b02c7e2bSRoman Gareev } 1417b02c7e2bSRoman Gareev 1418b02c7e2bSRoman Gareev /// Check accesses to operands of the tensor contraction. 1419b02c7e2bSRoman Gareev /// 1420b02c7e2bSRoman Gareev /// Check that accesses of the SCoP statement, which corresponds to 1421b02c7e2bSRoman Gareev /// the partial schedule @p PartialSchedule, represent accesses 1422b02c7e2bSRoman Gareev /// to the non-scalar operands of the tensor contraction. 1423b02c7e2bSRoman Gareev /// 1424b02c7e2bSRoman Gareev /// @param Domain The domain of the SCoP statement. 1425b02c7e2bSRoman Gareev /// @param PartialSchedule The partial schedule of the SCoP statement. 1426b02c7e2bSRoman Gareev /// @param TCI Parameters of the tensor contraction operands. 1427b02c7e2bSRoman Gareev /// @return True if the corresponding SCoP statement 1428b02c7e2bSRoman Gareev /// represents tensor contraction and false, 1429b02c7e2bSRoman Gareev /// otherwise. 1430b02c7e2bSRoman Gareev static bool containsOnlyTCAcc(isl::set Domain, isl::map PartialSchedule, 1431b02c7e2bSRoman Gareev TCInfoTy &TCI) { 1432b02c7e2bSRoman Gareev isl::id InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in); 1433b02c7e2bSRoman Gareev ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user()); 1434b02c7e2bSRoman Gareev 1435b02c7e2bSRoman Gareev // In region statements, the order of memory accesses execution is not 1436b02c7e2bSRoman Gareev // predictable at compile-time. 1437b02c7e2bSRoman Gareev if ((Stmt->size() <= 1) || Stmt->isRegionStmt()) 1438b02c7e2bSRoman Gareev return false; 1439b02c7e2bSRoman Gareev 1440b02c7e2bSRoman Gareev unsigned DimNum = unsignedFromIslSize(PartialSchedule.dim(isl::dim::in)); 1441b02c7e2bSRoman Gareev TCI.DimensionSizes.resize(DimNum); 1442b02c7e2bSRoman Gareev SmallDenseSet<int> IandJIndexSet; 1443b02c7e2bSRoman Gareev 1444b02c7e2bSRoman Gareev TCI.WriteToC = getWriteAccess(Domain, Stmt, TCI, IandJIndexSet); 1445b02c7e2bSRoman Gareev if (!TCI.WriteToC) 1446b02c7e2bSRoman Gareev return false; 1447b02c7e2bSRoman Gareev 1448b02c7e2bSRoman Gareev if (intersect(IandJIndexSet, TCI.P).size() != 0) 1449b02c7e2bSRoman Gareev return false; 1450b02c7e2bSRoman Gareev 1451b02c7e2bSRoman Gareev if (!setReadAccesses(Domain, Stmt, TCI, IandJIndexSet)) 1452b02c7e2bSRoman Gareev return false; 1453b02c7e2bSRoman Gareev 1454b02c7e2bSRoman Gareev return true; 1455b02c7e2bSRoman Gareev } 1456b02c7e2bSRoman Gareev 1457b02c7e2bSRoman Gareev /// Check that dependency corresponds to the tensor contraction carried over 1458b02c7e2bSRoman Gareev /// loop dimension @p Dim. 1459b02c7e2bSRoman Gareev /// 1460b02c7e2bSRoman Gareev /// Check that the dependency has the form 1461b02c7e2bSRoman Gareev /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) -> 1462b02c7e2bSRoman Gareev /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP 1463b02c7e2bSRoman Gareev /// statement. For this purpose, we analyze the set @p DepDelta, which 1464b02c7e2bSRoman Gareev /// represents the differences between image elements and domain elements of 1465b02c7e2bSRoman Gareev /// the corresponding map. 1466b02c7e2bSRoman Gareev /// 1467b02c7e2bSRoman Gareev /// @param DepDelta The set contains the differences between image elements 1468b02c7e2bSRoman Gareev /// and corresponding domain elements of the map, which 1469b02c7e2bSRoman Gareev /// represents the dependency. 1470b02c7e2bSRoman Gareev /// @param Dim The position of the index ki. 1471b02c7e2bSRoman Gareev /// @param BoundDeltas In the case of indexes of ki, the difference between 1472b02c7e2bSRoman Gareev /// image elements and corresponding domain elements 1473b02c7e2bSRoman Gareev /// corresponds to the difference between lexicographic 1474b02c7e2bSRoman Gareev /// minimum and lexicographic maximum of the corresponding 1475b02c7e2bSRoman Gareev /// dimension of the domain of the statement. 1476b02c7e2bSRoman Gareev /// @param IndexSet Obtained indexes ki, which describe the dependency. 1477b02c7e2bSRoman Gareev /// @return True if dependencies correspond to the tensor contraction 1478b02c7e2bSRoman Gareev /// and false, otherwise. 1479b02c7e2bSRoman Gareev static bool isReductionCarriedOverDim(isl::set DepDelta, unsigned Dim, 1480b02c7e2bSRoman Gareev isl::pw_multi_aff BoundDeltas, 1481b02c7e2bSRoman Gareev const SmallDenseSet<int> &IndexSet) { 1482b02c7e2bSRoman Gareev isl::space Space = DepDelta.get_space(); 1483b02c7e2bSRoman Gareev isl::set Superset = isl::set::universe(Space); 1484b02c7e2bSRoman Gareev for (unsigned i = 0; i < Dim; i += 1) 1485b02c7e2bSRoman Gareev Superset = Superset.fix_si(isl::dim::set, i, 0); 1486b02c7e2bSRoman Gareev Superset = Superset.fix_si(isl::dim::set, Dim, 1); 1487b02c7e2bSRoman Gareev 1488b02c7e2bSRoman Gareev // Check that the difference between the image element and the domain element 1489b02c7e2bSRoman Gareev // is equal to one in the case of the index ki. Image elements and 1490b02c7e2bSRoman Gareev // corresponding domain elements should be equal in the case of positions, 1491b02c7e2bSRoman Gareev // which are lower than the specified position. 1492b02c7e2bSRoman Gareev if (!DepDelta.is_subset(Superset)) 1493b02c7e2bSRoman Gareev return false; 1494b02c7e2bSRoman Gareev 1495b02c7e2bSRoman Gareev // Compute a set, which is used to analyze how values of 1496b02c7e2bSRoman Gareev // the domain are related to the map that describes the dependency. 1497b02c7e2bSRoman Gareev isl_pw_multi_aff *DepDeltaPW = isl_pw_multi_aff_from_set(DepDelta.copy()); 1498b02c7e2bSRoman Gareev BoundDeltas = BoundDeltas.add(isl::manage(DepDeltaPW)); 1499b02c7e2bSRoman Gareev isl_set *ComplementRawSet = isl_set_from_pw_multi_aff(BoundDeltas.release()); 1500b02c7e2bSRoman Gareev isl::set Complement = isl::manage(ComplementRawSet); 1501b02c7e2bSRoman Gareev 1502b02c7e2bSRoman Gareev for (unsigned i : rangeIslSize(Dim + 1, DepDelta.dim(isl::dim::set))) { 1503b02c7e2bSRoman Gareev if (!IndexSet.count(i)) { 1504b02c7e2bSRoman Gareev // Check the difference between the image element and the domain element 1505b02c7e2bSRoman Gareev // in the case of indexes, which do not describe the dependency. 1506b02c7e2bSRoman Gareev if (DepDelta.plain_get_val_if_fixed(isl::dim::set, i).is_zero()) 1507b02c7e2bSRoman Gareev continue; 1508b02c7e2bSRoman Gareev return false; 1509b02c7e2bSRoman Gareev } 1510b02c7e2bSRoman Gareev 1511b02c7e2bSRoman Gareev // In the case of other indexes, which describe the dependency, 1512b02c7e2bSRoman Gareev // the difference between the image element and the domain element 1513b02c7e2bSRoman Gareev // should be equal to the difference between lexicographic minimum and 1514b02c7e2bSRoman Gareev // lexicographic maximum of the domain of the statement. 1515b02c7e2bSRoman Gareev if (!Complement.plain_get_val_if_fixed(isl::dim::set, i).is_zero()) 1516b02c7e2bSRoman Gareev return false; 1517b02c7e2bSRoman Gareev } 1518b02c7e2bSRoman Gareev 1519b02c7e2bSRoman Gareev return true; 1520b02c7e2bSRoman Gareev } 1521b02c7e2bSRoman Gareev 1522b02c7e2bSRoman Gareev /// Check whether dependencies are over the complete domain. 1523b02c7e2bSRoman Gareev /// 1524b02c7e2bSRoman Gareev /// In the case of the tensor contraction RAW, WAW, WAR dependencies 1525b02c7e2bSRoman Gareev /// have the form 1526b02c7e2bSRoman Gareev /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) -> 1527b02c7e2bSRoman Gareev /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP 1528b02c7e2bSRoman Gareev /// statement. Consequently, the domain of the dependencies 1529b02c7e2bSRoman Gareev /// can be described as 1530b02c7e2bSRoman Gareev /// Domain / Domain ∩ S(…, max(kn),…) ∩ S(…, max(k(i + 1)),…), 1531b02c7e2bSRoman Gareev /// where Domain is the domain of the statement S. 1532b02c7e2bSRoman Gareev /// 1533b02c7e2bSRoman Gareev /// For example, in the case of the following tensor contraction, 1534b02c7e2bSRoman Gareev /// corresponding domains will have the following form. 1535b02c7e2bSRoman Gareev /// 1536b02c7e2bSRoman Gareev /// An example of the tensor contraction: 1537b02c7e2bSRoman Gareev /// for (i = 0; i < 1024; i++) 1538b02c7e2bSRoman Gareev /// for (j = 0; j < 1024; j++) 1539b02c7e2bSRoman Gareev /// for (l = 0; l < 64; ++l) 1540b02c7e2bSRoman Gareev /// for (w = 0; w < 64; ++w) 1541b02c7e2bSRoman Gareev /// C[i][j] += A[i][l][w] * B[w][j][l]; 1542b02c7e2bSRoman Gareev /// 1543b02c7e2bSRoman Gareev /// The domain of the statement: 1544b02c7e2bSRoman Gareev /// { S[i0, i1, i2, i3] : i0 >= 0 and i0 <= 1023 and 1545b02c7e2bSRoman Gareev /// i1 >= 0 and i1 <= 1023 and 1546b02c7e2bSRoman Gareev /// i2 >= 0 and i2 <= 63 and 1547b02c7e2bSRoman Gareev /// i3 >= 0 and i3 <= 63 } 1548b02c7e2bSRoman Gareev /// 1549b02c7e2bSRoman Gareev /// The domain of the dependencies: 1550b02c7e2bSRoman Gareev /// { S[i0, i1, i2, i3] : (i0 >= 0 and i0 <= 1023 and 1551b02c7e2bSRoman Gareev /// i1 >= 0 and i1 <= 1023 and 1552b02c7e2bSRoman Gareev /// i2 >= 0 and i2 <= 63 and 1553b02c7e2bSRoman Gareev /// i3 >= 0 and i3 <= 62) or 1554b02c7e2bSRoman Gareev /// (i3 = 63 and i0 >= 0 and i0 <= 1023 and 1555b02c7e2bSRoman Gareev /// i1 >= 0 and i1 <= 1023 and 1556b02c7e2bSRoman Gareev /// i2 >= 0 and i2 <= 62) } 1557b02c7e2bSRoman Gareev /// 1558b02c7e2bSRoman Gareev /// @param Domain The domain of the statement. 1559b02c7e2bSRoman Gareev /// @param DepsForStmt RAW and RED dependencies for the statement. 1560b02c7e2bSRoman Gareev /// @param UpperBound The lexicographic maximum of the elements in 1561b02c7e2bSRoman Gareev /// the @p Domain. 1562b02c7e2bSRoman Gareev /// @param IndexSet Obtained indexes ki, which describe the dependencies. 1563b02c7e2bSRoman Gareev /// @return True if dependencies are over the complete domain 1564b02c7e2bSRoman Gareev /// and false, otherwise. 1565b02c7e2bSRoman Gareev static bool areDepsOverCompleteDomain(isl::set Domain, isl::map DepsForStmt, 1566b02c7e2bSRoman Gareev isl::pw_multi_aff UpperBound, 1567b02c7e2bSRoman Gareev SmallDenseSet<int> &IndexSet) { 1568b02c7e2bSRoman Gareev isl_set *UpperBoundRawSet = isl_set_from_pw_multi_aff(UpperBound.copy()); 1569b02c7e2bSRoman Gareev isl::set UpperBoundSet = isl::manage(UpperBoundRawSet); 1570b02c7e2bSRoman Gareev 1571b02c7e2bSRoman Gareev isl::set DomainRed = isl::manage(Domain.copy()); 1572b02c7e2bSRoman Gareev for (const auto It : IndexSet) { 1573b02c7e2bSRoman Gareev isl::val FixedVal = UpperBoundSet.plain_get_val_if_fixed(isl::dim::set, It); 1574b02c7e2bSRoman Gareev if (FixedVal.is_nan()) 1575b02c7e2bSRoman Gareev return false; 1576b02c7e2bSRoman Gareev DomainRed = isl::manage( 1577b02c7e2bSRoman Gareev isl_set_fix_val(DomainRed.copy(), isl_dim_set, It, FixedVal.release())); 1578b02c7e2bSRoman Gareev } 1579b02c7e2bSRoman Gareev return DepsForStmt.domain().intersect(Domain).is_equal( 1580b02c7e2bSRoman Gareev Domain.subtract(DomainRed)); 1581b02c7e2bSRoman Gareev } 1582b02c7e2bSRoman Gareev 1583b02c7e2bSRoman Gareev /// Check that dependencies correspond to the tensor contraction. 1584b02c7e2bSRoman Gareev /// 1585b02c7e2bSRoman Gareev /// Check that there are only true dependencies of the form 1586b02c7e2bSRoman Gareev /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) -> 1587b02c7e2bSRoman Gareev /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP 1588b02c7e2bSRoman Gareev /// statement represented by @p Schedule. Such dependencies are produced by 1589b02c7e2bSRoman Gareev /// the tensor contraction. Obtained indexes ki are stored into @p IndexSet. 1590b02c7e2bSRoman Gareev /// 1591b02c7e2bSRoman Gareev /// The form of anti and output dependencies is specified implicitly by 1592b02c7e2bSRoman Gareev /// the form the SCoP statement, which is checked by subsequent analysis. 1593b02c7e2bSRoman Gareev /// 1594b02c7e2bSRoman Gareev /// @param Schedule The schedule of the SCoP statement. 1595b02c7e2bSRoman Gareev /// @param D The SCoP dependencies. 1596b02c7e2bSRoman Gareev /// @param Domain The domain of the statement. 1597b02c7e2bSRoman Gareev /// @param IndexSet Obtained indexes ki, which describe the dependencies. 1598b02c7e2bSRoman Gareev /// @return True if dependencies correspond to the tensor contraction 1599b02c7e2bSRoman Gareev /// and false, otherwise. 1600b02c7e2bSRoman Gareev static bool containsOnlyTcDeps(isl::map Schedule, const Dependences *D, 1601b02c7e2bSRoman Gareev SmallDenseSet<int> &IndexSet, isl::set Domain) { 1602b02c7e2bSRoman Gareev IslMaxOperationsGuard MaxOpGuard(Schedule.ctx().get(), OptComputeOut); 1603b02c7e2bSRoman Gareev 1604b02c7e2bSRoman Gareev isl::union_map Dep = 1605b02c7e2bSRoman Gareev D->getDependences(Dependences::TYPE_RAW | Dependences::TYPE_RED); 1606b02c7e2bSRoman Gareev 1607b02c7e2bSRoman Gareev isl::space DomainSpace = Schedule.get_space().domain(); 1608b02c7e2bSRoman Gareev isl::space Space = DomainSpace.map_from_domain_and_range(DomainSpace); 1609b02c7e2bSRoman Gareev isl::map DepsForStmt = Dep.extract_map(Space); 1610b02c7e2bSRoman Gareev isl::set DepDeltas = DepsForStmt.deltas(); 1611b02c7e2bSRoman Gareev isl::size DeltasDimNum = DepDeltas.dim(isl::dim::set); 1612b02c7e2bSRoman Gareev isl::pw_multi_aff LowerBound = Domain.lexmin_pw_multi_aff(); 1613b02c7e2bSRoman Gareev isl::pw_multi_aff UpperBound = Domain.lexmax_pw_multi_aff(); 1614b02c7e2bSRoman Gareev isl::pw_multi_aff BoundDeltas = UpperBound.sub(LowerBound); 1615b02c7e2bSRoman Gareev 1616b02c7e2bSRoman Gareev for (int i : reverse(rangeIslSize(0, DeltasDimNum))) { 1617b02c7e2bSRoman Gareev // In the case of the tensor contraction, the difference between image 1618b02c7e2bSRoman Gareev // elements and domain elements lies on a hyperplane where a dimension 1619b02c7e2bSRoman Gareev // has the fixed value one. 1620b02c7e2bSRoman Gareev isl::set Intersection = DepDeltas.fix_si(isl::dim::set, i, 1); 1621b02c7e2bSRoman Gareev if (Intersection.is_empty()) 1622b02c7e2bSRoman Gareev continue; 1623b02c7e2bSRoman Gareev 1624b02c7e2bSRoman Gareev if (!isReductionCarriedOverDim(Intersection, i, BoundDeltas, IndexSet)) 1625b02c7e2bSRoman Gareev return false; 1626b02c7e2bSRoman Gareev 1627b02c7e2bSRoman Gareev IndexSet.insert(i); 1628b02c7e2bSRoman Gareev DepDeltas = DepDeltas.subtract(Intersection); 1629b02c7e2bSRoman Gareev } 1630b02c7e2bSRoman Gareev 1631b02c7e2bSRoman Gareev // In the case of the tensor contraction, all dependencies should have 1632b02c7e2bSRoman Gareev // the previously described form. 1633b02c7e2bSRoman Gareev if ((unsignedFromIslSize(DeltasDimNum) == 0) || !DepDeltas.is_empty()) 1634b02c7e2bSRoman Gareev return false; 1635b02c7e2bSRoman Gareev 1636b02c7e2bSRoman Gareev return areDepsOverCompleteDomain(Domain, DepsForStmt, UpperBound, IndexSet); 1637b02c7e2bSRoman Gareev } 1638b02c7e2bSRoman Gareev 1639b02c7e2bSRoman Gareev /// Check if the SCoP statement could probably be optimized with analytical 1640b02c7e2bSRoman Gareev /// modeling. 1641b02c7e2bSRoman Gareev /// 1642b02c7e2bSRoman Gareev /// containsTCInfoTy tries to determine whether the following conditions 1643b02c7e2bSRoman Gareev /// are true: 1644b02c7e2bSRoman Gareev /// 1645b02c7e2bSRoman Gareev /// 1. The last memory access modeling an array, MA1, represents writing to 1646b02c7e2bSRoman Gareev /// memory and has the form S(..., I, ..., J, ...) -> M(shuffle(I, J)), 1647b02c7e2bSRoman Gareev /// where S is the SCoP statement under consideration and shuffle(I, J) 1648b02c7e2bSRoman Gareev /// is a permutation of indexes of sets I and J. 1649b02c7e2bSRoman Gareev /// 2. There are only true dependencies of the form 1650b02c7e2bSRoman Gareev /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) -> 1651b02c7e2bSRoman Gareev /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP 1652b02c7e2bSRoman Gareev /// statement represented by @p Schedule and ki are indexes of the set P. 1653b02c7e2bSRoman Gareev /// 3. SCoP contains an arbitrary number of reads from constants and only three 1654b02c7e2bSRoman Gareev /// access relations, MA2, MA3, and MA4 that represent reading from memory 1655b02c7e2bSRoman Gareev /// and have the form 1656b02c7e2bSRoman Gareev /// S(..., I, ..., P, ...) -> M(shuffle(I, P)), 1657b02c7e2bSRoman Gareev /// S(..., P, ..., J, ...) -> M(shuffle(J, P)), 1658b02c7e2bSRoman Gareev /// S(...) -> M(shuffle(I, J)), respectively. 1659b02c7e2bSRoman Gareev /// 1660b02c7e2bSRoman Gareev /// @param PartialSchedule The PartialSchedule that contains a SCoP statement 1661b02c7e2bSRoman Gareev /// to check. 1662b02c7e2bSRoman Gareev /// @param D The SCoP dependencies. 1663b02c7e2bSRoman Gareev /// @param TCI Parameters of the tensor contraction operands. 1664b02c7e2bSRoman Gareev /// @param Domain The domain of the statement. 1665b02c7e2bSRoman Gareev /// @return True if dependencies and memory accesses correspond to the tensor 1666b02c7e2bSRoman Gareev /// contraction and false, otherwise. 1667b02c7e2bSRoman Gareev static bool containsTCInfoTy(isl::map PartialSchedule, const Dependences *D, 1668b02c7e2bSRoman Gareev TCInfoTy &TCI, isl::set Domain) { 1669b02c7e2bSRoman Gareev if (!containsOnlyTcDeps(PartialSchedule, D, TCI.P, Domain)) 1670b02c7e2bSRoman Gareev return false; 1671b02c7e2bSRoman Gareev 1672b02c7e2bSRoman Gareev // TODO: handle cases of scalar multiplication if needed. 1673b02c7e2bSRoman Gareev if (TCI.P.size() == 0) 1674b02c7e2bSRoman Gareev return false; 1675b02c7e2bSRoman Gareev 1676b02c7e2bSRoman Gareev if (!containsOnlyTCAcc(Domain, PartialSchedule, TCI)) 1677b02c7e2bSRoman Gareev return false; 1678b02c7e2bSRoman Gareev 1679b02c7e2bSRoman Gareev // TODO: handle cases of GEMV if needed. 1680b02c7e2bSRoman Gareev if ((TCI.I.size() == 0) || (TCI.J.size() == 0)) 1681b02c7e2bSRoman Gareev return false; 1682b02c7e2bSRoman Gareev 1683b02c7e2bSRoman Gareev return true; 1684b02c7e2bSRoman Gareev } 1685b02c7e2bSRoman Gareev 1686b02c7e2bSRoman Gareev /// Check if this node contains a partial schedule that could 1687b02c7e2bSRoman Gareev /// probably be optimized with analytical modeling. 1688b02c7e2bSRoman Gareev /// 1689b02c7e2bSRoman Gareev /// isTCPattern is used to determine whether the SCoP represents a TC-like 1690b02c7e2bSRoman Gareev /// kernel [1], which is a perfectly nested set of loops, with a data usage 1691b02c7e2bSRoman Gareev /// pattern that is similar to that produced by the tensor contraction. 1692b02c7e2bSRoman Gareev /// 1693b02c7e2bSRoman Gareev /// A TC-like kernel can be defined as follows: 1694b02c7e2bSRoman Gareev /// 1695b02c7e2bSRoman Gareev /// 1. It satisfies the requirements of the polyhedral model. 1696b02c7e2bSRoman Gareev /// 2. Without loss of generality, it contains three nonempty bundles of 1697b02c7e2bSRoman Gareev /// one-dimensional for-loops with induction variables that are grouped into 1698b02c7e2bSRoman Gareev /// bundles I = i0...i(r-1), J = j0..j(s-1), and P = p0...p(t-1), and they 1699b02c7e2bSRoman Gareev /// are incremented by one. 1700b02c7e2bSRoman Gareev /// 3. The innermost loop body can be represented as a statement of the form 1701b02c7e2bSRoman Gareev /// C(shuffle(I, J)) = E(A(shuffle(I, P)), B(shuffle(P, J)), 1702b02c7e2bSRoman Gareev /// C(shuffle(I, J))), where A(shuffle(I, P)), B(shuffle(P, J)), 1703b02c7e2bSRoman Gareev /// C(shuffle(I, J)) are accesses to tensors A, B, C, respectively, 1704b02c7e2bSRoman Gareev /// shuffle(I, J), shuffle(I, P), and shuffle(P, J) are permutations of the 1705b02c7e2bSRoman Gareev /// enclosed indices, and E is an expression that contains reads from 1706b02c7e2bSRoman Gareev /// the tensors A, B, C, and an arbitrary number of reads from constants 1707b02c7e2bSRoman Gareev /// with respect to bundles I, J, and P. 1708b02c7e2bSRoman Gareev /// 1709b02c7e2bSRoman Gareev /// TC can be considered as a particular case of a TC-like kernel. 1710b02c7e2bSRoman Gareev /// 1711b02c7e2bSRoman Gareev /// The order of loops with indexes from P should be preserved. Otherwise, 1712b02c7e2bSRoman Gareev /// isTCPattern should check if a commutative operation is used. 1713b02c7e2bSRoman Gareev /// 1714b02c7e2bSRoman Gareev /// isTCPattern performs the following steps to check whether the SCoP 1715b02c7e2bSRoman Gareev /// corresponds to a definition of a TC-like kernel: 1716b02c7e2bSRoman Gareev /// 1717b02c7e2bSRoman Gareev /// 1. Checks that the node is the innermost band node. 1718b02c7e2bSRoman Gareev /// 2. Checks that the partial schedule contains only one statement. 1719b02c7e2bSRoman Gareev /// 3. Check that all ancestors of the node contain all band nodes for 1720b02c7e2bSRoman Gareev /// the statement and only mark nodes interleave such band nodes. This 1721b02c7e2bSRoman Gareev /// corresponds to a straightforward implementation of TC. 1722b02c7e2bSRoman Gareev /// 4. Analyses the dependencies to determine contraction dimensions. 1723b02c7e2bSRoman Gareev /// 5. Check that the last memory access modeling an array, represents writing 1724b02c7e2bSRoman Gareev /// to the result of the TC-like kernel. 1725b02c7e2bSRoman Gareev /// 6. Check that SCoP contains only three access relations that represent 1726b02c7e2bSRoman Gareev /// reading of the operands of the TC-like kernel and an arbitrary number of 1727b02c7e2bSRoman Gareev /// reads from constants. 1728b02c7e2bSRoman Gareev /// 1729b02c7e2bSRoman Gareev /// [1] - Gareev R., Grosser T., Kruse M. High-Performance Generalized Tensor 1730b02c7e2bSRoman Gareev /// Operations: A Compiler-Oriented Approach // ACM Transactions 1731b02c7e2bSRoman Gareev /// Architecture and Code Optimization (TACO). 2018. 1732b02c7e2bSRoman Gareev /// Vol. 15, no. 3. P. 34:1–34:27. DOI: 10.1145/3235029. 1733b02c7e2bSRoman Gareev /// 1734b02c7e2bSRoman Gareev /// If this is the case, we could logically represent tensors as matrices and 1735b02c7e2bSRoman Gareev /// apply algorithms, which are used to get close-to-peak performance of 1736b02c7e2bSRoman Gareev /// matrix multiplications in manually tuned BLAS libraries (e.g., BLIS). 1737b02c7e2bSRoman Gareev /// 1738b02c7e2bSRoman Gareev /// @param Node The node to check. 1739b02c7e2bSRoman Gareev /// @param D The SCoP dependencies. 1740b02c7e2bSRoman Gareev /// @param TCI Parameters of the tensor contraction operands. 1741b02c7e2bSRoman Gareev static bool isTCPattern(isl::schedule_node Node, const Dependences *D, 1742b02c7e2bSRoman Gareev TCInfoTy &TCI) { 1743b02c7e2bSRoman Gareev Node = Node.child(0); 1744b02c7e2bSRoman Gareev isl::union_map PartialSchedule = Node.get_prefix_schedule_union_map(); 1745b02c7e2bSRoman Gareev isl::union_set Domain = Node.domain(); 1746b02c7e2bSRoman Gareev Node = Node.parent(); 1747b02c7e2bSRoman Gareev 1748b02c7e2bSRoman Gareev // The partial schedule should contain only one statement. 1749b02c7e2bSRoman Gareev // TODO: This constraint should not be intrinsic to the algorithm. 1750b02c7e2bSRoman Gareev if (isl_union_set_n_set(Domain.get()) != 1) 1751b02c7e2bSRoman Gareev return false; 1752b02c7e2bSRoman Gareev 1753b02c7e2bSRoman Gareev isl_schedule_node_type NodeType = isl_schedule_node_get_type(Node.get()); 1754b02c7e2bSRoman Gareev 1755b02c7e2bSRoman Gareev // Check that all ancestors of the node contain all band nodes for 1756b02c7e2bSRoman Gareev // the statement, which represents the TC-like kernel, and only mark nodes 1757b02c7e2bSRoman Gareev // interleave such band nodes. This corresponds to a straightforward 1758b02c7e2bSRoman Gareev // implementation of TC with/without DeLICM applied. 1759b02c7e2bSRoman Gareev // 1760b02c7e2bSRoman Gareev // For example, this covers the matrix multiplication pattern after a full 1761b02c7e2bSRoman Gareev // run of -polly-optree and -polly-delicm, where the write access is not 1762*5aafc6d5SChristian Clauss // through the original memory access, but through a PHI node that was 1763b02c7e2bSRoman Gareev // delicmed. Subsequently, such band nodes will be replaced by a single band 1764b02c7e2bSRoman Gareev // node. 1765b02c7e2bSRoman Gareev // 1766b02c7e2bSRoman Gareev // The corresponding schedule can be the following, where Stmt_for_body8 1767b02c7e2bSRoman Gareev // contains the matrix multiplication: 1768b02c7e2bSRoman Gareev // 1769b02c7e2bSRoman Gareev // domain: "{ Stmt_for_body8[i0, i1, i2] : 0 <= i0 <= 1599 and 1770b02c7e2bSRoman Gareev // 0 <= i1 <= 1799 and 1771b02c7e2bSRoman Gareev // 0 <= i2 <= 2199; 1772b02c7e2bSRoman Gareev // Stmt_for_body3[i0, i1] : 0 <= i0 <= 1599 and 1773b02c7e2bSRoman Gareev // 0 <= i1 <= 1799; 1774b02c7e2bSRoman Gareev // Stmt_for_body3_last[i0, i1] : 0 <= i0 <= 1599 and 1775b02c7e2bSRoman Gareev // 0 <= i1 <= 1799 }" 1776b02c7e2bSRoman Gareev // child: 1777b02c7e2bSRoman Gareev // sequence: 1778b02c7e2bSRoman Gareev // - filter: "{ Stmt_for_body3[i0, i1] }" 1779b02c7e2bSRoman Gareev // child: 1780b02c7e2bSRoman Gareev // schedule: "[{ Stmt_for_body3[i0, i1] -> [(i0)] }, 1781b02c7e2bSRoman Gareev // { Stmt_for_body3[i0, i1] -> [(i1)] }]" 1782b02c7e2bSRoman Gareev // permutable: 1 1783b02c7e2bSRoman Gareev // coincident: [ 1, 1 ] 1784b02c7e2bSRoman Gareev // - filter: "{ Stmt_for_body3_last[i0, i1] }" 1785b02c7e2bSRoman Gareev // child: 1786b02c7e2bSRoman Gareev // schedule: "[{ Stmt_for_body3_last[i0, i1] -> [(i0)] }, 1787b02c7e2bSRoman Gareev // { Stmt_for_body3_last[i0, i1] -> [(i1)] }]" 1788b02c7e2bSRoman Gareev // permutable: 1 1789b02c7e2bSRoman Gareev // coincident: [ 1, 1 ] 1790b02c7e2bSRoman Gareev // - filter: "{ Stmt_for_body8[i0, i1, i2] }" 1791b02c7e2bSRoman Gareev // child: 1792b02c7e2bSRoman Gareev // schedule: "[{ Stmt_for_body8[i0, i1, i2] -> [(i0)] }, 1793b02c7e2bSRoman Gareev // { Stmt_for_body8[i0, i1, i2] -> [(i1)] }, 1794b02c7e2bSRoman Gareev // { Stmt_for_body8[i0, i1, i2] -> [(i2)] }]" 1795b02c7e2bSRoman Gareev // permutable: 1 1796b02c7e2bSRoman Gareev // coincident: [ 1, 1, 0 ] 1797b02c7e2bSRoman Gareev // 1798b02c7e2bSRoman Gareev while (NodeType != isl_schedule_node_domain) { 1799b02c7e2bSRoman Gareev if (NodeType == isl_schedule_node_filter) { 1800b02c7e2bSRoman Gareev if (!Node.parent().isa<isl::schedule_node_sequence>() || 1801b02c7e2bSRoman Gareev !Node.parent().parent().isa<isl::schedule_node_domain>()) 1802b02c7e2bSRoman Gareev return false; 1803b02c7e2bSRoman Gareev break; 1804b02c7e2bSRoman Gareev } 1805b02c7e2bSRoman Gareev 1806b02c7e2bSRoman Gareev if ((NodeType != isl_schedule_node_band) && 1807b02c7e2bSRoman Gareev (NodeType != isl_schedule_node_mark)) 1808b02c7e2bSRoman Gareev return false; 1809b02c7e2bSRoman Gareev 1810b02c7e2bSRoman Gareev Node = Node.parent(); 1811b02c7e2bSRoman Gareev NodeType = isl_schedule_node_get_type(Node.get()); 1812b02c7e2bSRoman Gareev } 1813b02c7e2bSRoman Gareev 1814b02c7e2bSRoman Gareev isl::map PartialScheduleMap = isl::map::from_union_map(PartialSchedule); 1815b02c7e2bSRoman Gareev if (containsTCInfoTy(PartialScheduleMap, D, TCI, isl::set(Domain))) 1816b02c7e2bSRoman Gareev return true; 1817b02c7e2bSRoman Gareev 1818b02c7e2bSRoman Gareev return false; 1819b02c7e2bSRoman Gareev } 1820b02c7e2bSRoman Gareev 1821d123e983SMichael Kruse } // namespace 1822d123e983SMichael Kruse 1823d123e983SMichael Kruse isl::schedule_node 1824d123e983SMichael Kruse polly::tryOptimizeMatMulPattern(isl::schedule_node Node, 1825d123e983SMichael Kruse const llvm::TargetTransformInfo *TTI, 1826d123e983SMichael Kruse const Dependences *D) { 1827b02c7e2bSRoman Gareev TCInfoTy TCI; 1828b02c7e2bSRoman Gareev if (PMBasedTCOpts && isTCPattern(Node, D, TCI)) 1829601d7eabSKarthika Devi C POLLY_DEBUG(dbgs() << "The tensor contraction pattern was detected\n"); 1830d123e983SMichael Kruse MatMulInfoTy MMI; 1831b02c7e2bSRoman Gareev if (PMBasedMMMOpts && isMatrMultPattern(Node, D, MMI)) { 1832601d7eabSKarthika Devi C POLLY_DEBUG(dbgs() << "The matrix multiplication pattern was detected\n"); 1833d123e983SMichael Kruse return optimizeMatMulPattern(Node, TTI, MMI); 1834d123e983SMichael Kruse } 1835d123e983SMichael Kruse return {}; 1836d123e983SMichael Kruse } 1837