//===- MatmulOptimizer.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "polly/MatmulOptimizer.h"
#include "polly/DependenceInfo.h"
#include "polly/Options.h"
#include "polly/ScheduleTreeTransform.h"
#include "polly/ScopInfo.h"
#include "polly/ScopPass.h"
#include "polly/Simplify.h"
#include "polly/Support/GICHelper.h"
#include "polly/Support/ISLTools.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Support/raw_ostream.h"
#include "isl/ctx.h"
#include "isl/schedule_node.h"
#include "isl/schedule_type.h"
#include "isl/union_map.h"
#include "isl/union_set.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>

#include "polly/Support/PollyDebug.h"
#define DEBUG_TYPE "polly-opt-isl"

using namespace llvm;
using namespace polly;

namespace llvm {
class Value;
}

static cl::opt<int> LatencyVectorFma(
    "polly-target-latency-vector-fma",
    cl::desc("The minimal number of cycles between issuing two "
             "dependent consecutive vector fused multiply-add "
             "instructions."),
    cl::Hidden, cl::init(8), cl::cat(PollyCategory));

static cl::opt<int> ThroughputVectorFma(
    "polly-target-throughput-vector-fma",
    cl::desc("A throughput of the processor floating-point arithmetic units "
             "expressed in the number of vector fused multiply-add "
             "instructions per clock cycle."),
    cl::Hidden, cl::init(1), cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelSize(
    "polly-target-1st-cache-level-size",
    cl::desc("The size of the first cache level specified in bytes."),
    cl::Hidden, cl::init(-1), cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelDefaultSize(
    "polly-target-1st-cache-level-default-size",
    cl::desc("The default size of the first cache level specified in bytes"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(32768), cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelSize(
    "polly-target-2nd-cache-level-size",
    cl::desc("The size of the second cache level specified in bytes."),
    cl::Hidden, cl::init(-1), cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelDefaultSize(
    "polly-target-2nd-cache-level-default-size",
    cl::desc("The default size of the second cache level specified in bytes"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(262144), cl::cat(PollyCategory));

// This option, along with --polly-target-2nd-cache-level-associativity,
// --polly-target-1st-cache-level-size, and --polly-target-2nd-cache-level-size
// represent the parameters of the target cache, which do not have typical
// values that can be used by default. However, to apply the pattern matching
// optimizations, we use the values of the parameters of Intel Core i7-3820
// SandyBridge in case the parameters are not specified or not provided by the
// TargetTransformInfo.
static cl::opt<int> FirstCacheLevelAssociativity(
    "polly-target-1st-cache-level-associativity",
    cl::desc("The associativity of the first cache level."), cl::Hidden,
    cl::init(-1), cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelDefaultAssociativity(
    "polly-target-1st-cache-level-default-associativity",
    cl::desc("The default associativity of the first cache level"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(8), cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelAssociativity(
    "polly-target-2nd-cache-level-associativity",
    cl::desc("The associativity of the second cache level."), cl::Hidden,
    cl::init(-1), cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelDefaultAssociativity(
    "polly-target-2nd-cache-level-default-associativity",
    cl::desc("The default associativity of the second cache level"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(8), cl::cat(PollyCategory));

static cl::opt<int> VectorRegisterBitwidth(
    "polly-target-vector-register-bitwidth",
    cl::desc("The size in bits of a vector register (if not set, this "
             "information is taken from LLVM's target information)."),
    cl::Hidden, cl::init(-1), cl::cat(PollyCategory));

static cl::opt<int> PollyPatternMatchingNcQuotient(
    "polly-pattern-matching-nc-quotient",
    cl::desc("Quotient that is obtained by dividing Nc, the parameter of the "
             "macro-kernel, by Nr, the parameter of the micro-kernel"),
    cl::Hidden, cl::init(256), cl::cat(PollyCategory));

static cl::opt<bool>
    PMBasedTCOpts("polly-tc-opt",
                  cl::desc("Perform optimizations of tensor contractions based "
                           "on pattern matching"),
                  cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<bool>
    PMBasedMMMOpts("polly-matmul-opt",
                   cl::desc("Perform optimizations of matrix multiplications "
                            "based on pattern matching"),
                   cl::init(true), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> OptComputeOut(
    "polly-tc-dependences-computeout",
    cl::desc("Bound the dependence analysis by a maximal amount of "
             "computational steps (0 means no bound)"),
    cl::Hidden, cl::init(500000), cl::ZeroOrMore, cl::cat(PollyCategory));

namespace {
/// Parameters of the micro kernel.
///
/// Parameters, which determine sizes of rank-1 (i.e., outer product) update
/// used in the optimized matrix multiplication.
struct MicroKernelParamsTy {
  int Mr;
  int Nr;
};

/// Parameters of the macro kernel.
///
/// Parameters, which determine sizes of blocks of partitioned matrices
/// used in the optimized matrix multiplication.
struct MacroKernelParamsTy {
  int Mc;
  int Nc;
  int Kc;
};

/// Parameters of the matrix multiplication operands.
///
/// Parameters, which describe access relations that represent operands of the
/// matrix multiplication.
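///
/// For example, for a statement computing C[i][j] += A[i][k] * B[k][j], the
/// fields A and B record the read accesses to A[i][k] and B[k][j], ReadFromC
/// and WriteToC record the read of and the write to C[i][j], and i, j, and k
/// record the positions of the corresponding schedule input dimensions.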
struct MatMulInfoTy {
  MemoryAccess *A = nullptr;
  MemoryAccess *B = nullptr;
  MemoryAccess *ReadFromC = nullptr;
  MemoryAccess *WriteToC = nullptr;
  int i = -1;
  int j = -1;
  int k = -1;
};

/// Parameters of the tensor contraction operands.
///
/// A general d-dimensional tensor T ∈ R ^ Nu0 x ... x Nud−1 can be defined
/// as the set of scalar elements indexed by the set of indices nu0 ... nud−1,
///
/// T ≡ {Anu0...nud−1 ∈ R | (nu0,...,nud−1) ∈ Nu0 x ... x Nud−1}.
///
/// Let A, B, and C be dA, dB, and dC-dimensional tensors, respectively.
/// Let the free and the contracted indices of the tensor A be grouped into
/// two bundles I = i0...ir−1 and P = p0...pt−1, respectively. Similarly,
/// the free and the contracted indices of B are grouped into bundles
/// J = j0..js−1 and P and the free indices of C are grouped into
/// bundles I and J.
///
/// Tensor contraction (TC) of tensors A, B into tensor C can be represented as
/// C(shuffle(I,J))=∑α·A(shuffle(I,P))·B(shuffle(P,J))+β·C(shuffle(I,J)),
/// where ∑ is a summation over all contracted indices of P,
/// α, β ∈ R, Npi is the length of the tensor dimension that corresponds
/// to the index pi, A(shuffle(I, P)), B(shuffle(P, J)), C(shuffle(I, J)) are
/// accesses to tensors A, B, C, respectively,
/// shuffle(I, J), shuffle(I, P), and shuffle(P, J) are permutations of
/// the enclosed indices.
///
/// Multiplication of C(shuffle(I,J)) by β can be moved into a different SCoP
/// statement by loop distribution, which is done by the isl scheduler.
/// If β is not equal to one, the optimization of TC of Polly requires
/// such a transformation.
///
/// TCInfoTy contains parameters, which describe access relations that
/// represent operands of the tensor contraction.
struct TCInfoTy {
  /// @{
  /// Memory accesses that represent reading from tensors, which are operands
  /// of the tensor contraction.
  MemoryAccess *A = nullptr;
  MemoryAccess *B = nullptr;
  /// @}

  /// @{
  /// Memory accesses that represent reading from and writing into the tensor,
  /// which contains the result of the tensor contraction.
  MemoryAccess *ReadFromC = nullptr;
  MemoryAccess *WriteToC = nullptr;
  /// @}

  /// @{
  /// Input dimensions of the schedule space, which represent free
  /// indices of tensors.
  SmallDenseSet<int> I;
  SmallDenseSet<int> J;
  /// @}

  /// Input dimension of the schedule space, which represents contracted
  /// indices of tensors.
  SmallDenseSet<int> P;

  /// @{
  /// Sizes of tensor dimensions for corresponding input dimensions of
  /// the schedule space. The size of the tensor dimension can be larger than
  /// the size of the corresponding input dimension of the schedule space.
  /// Strictly speaking, this does not correspond to a tensor contraction;
  /// however, such a pattern will still be optimized by the transformation.
  SmallVector<int> DimensionSizes;
  SmallVector<int> ADimensions;
  SmallVector<int> BDimensions;
  SmallVector<int> CDimensions;
  /// @}

  /// @{
  /// Permutations of indices of I, J, and P, which describe operands of
  /// the tensor contraction and its result.
  SmallVector<int> OrderedI;
  SmallVector<int> OrderedJ;
  SmallVector<int> OrderedP;
  /// @}
};

/// Create an isl::union_set, which describes the option of the form
/// [isolate[] -> unroll[x]].
263 /// 264 /// @param Ctx An isl::ctx, which is used to create the isl::union_set. 265 static isl::union_set getUnrollIsolatedSetOptions(isl::ctx Ctx) { 266 isl::space Space = isl::space(Ctx, 0, 0, 1); 267 isl::map UnrollIsolatedSetOption = isl::map::universe(Space); 268 isl::id DimInId = isl::id::alloc(Ctx, "isolate", nullptr); 269 isl::id DimOutId = isl::id::alloc(Ctx, "unroll", nullptr); 270 UnrollIsolatedSetOption = 271 UnrollIsolatedSetOption.set_tuple_id(isl::dim::in, DimInId); 272 UnrollIsolatedSetOption = 273 UnrollIsolatedSetOption.set_tuple_id(isl::dim::out, DimOutId); 274 return UnrollIsolatedSetOption.wrap(); 275 } 276 277 /// Permute the two dimensions of the isl map. 278 /// 279 /// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that 280 /// have type @p DimType. 281 /// 282 /// @param Map The isl map to be modified. 283 /// @param DimType The type of the dimensions. 284 /// @param DstPos The first dimension. 285 /// @param SrcPos The second dimension. 286 /// @return The modified map. 287 static isl::map permuteDimensions(isl::map Map, isl::dim DimType, 288 unsigned DstPos, unsigned SrcPos) { 289 assert(DstPos < unsignedFromIslSize(Map.dim(DimType)) && 290 SrcPos < unsignedFromIslSize(Map.dim(DimType))); 291 if (DstPos == SrcPos) 292 return Map; 293 isl::id DimId; 294 if (Map.has_tuple_id(DimType)) 295 DimId = Map.get_tuple_id(DimType); 296 auto FreeDim = DimType == isl::dim::in ? isl::dim::out : isl::dim::in; 297 isl::id FreeDimId; 298 if (Map.has_tuple_id(FreeDim)) 299 FreeDimId = Map.get_tuple_id(FreeDim); 300 auto MaxDim = std::max(DstPos, SrcPos); 301 auto MinDim = std::min(DstPos, SrcPos); 302 Map = Map.move_dims(FreeDim, 0, DimType, MaxDim, 1); 303 Map = Map.move_dims(FreeDim, 0, DimType, MinDim, 1); 304 Map = Map.move_dims(DimType, MinDim, FreeDim, 1, 1); 305 Map = Map.move_dims(DimType, MaxDim, FreeDim, 0, 1); 306 if (!DimId.is_null()) 307 Map = Map.set_tuple_id(DimType, DimId); 308 if (!FreeDimId.is_null()) 309 Map = Map.set_tuple_id(FreeDim, FreeDimId); 310 return Map; 311 } 312 313 /// Check the form of the access relation. 314 /// 315 /// Check that the access relation @p AccMap has the form M[i][j], where i 316 /// is a @p FirstPos and j is a @p SecondPos. 317 /// 318 /// @param AccMap The access relation to be checked. 319 /// @param FirstPos The index of the input dimension that is mapped to 320 /// the first output dimension. 321 /// @param SecondPos The index of the input dimension that is mapped to the 322 /// second output dimension. 323 /// @return True in case @p AccMap has the expected form and false, 324 /// otherwise. 325 static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos, 326 int &SecondPos) { 327 isl::space Space = AccMap.get_space(); 328 isl::map Universe = isl::map::universe(Space); 329 330 if (unsignedFromIslSize(Space.dim(isl::dim::out)) != 2) 331 return false; 332 333 // MatMul has the form: 334 // for (i = 0; i < N; i++) 335 // for (j = 0; j < M; j++) 336 // for (k = 0; k < P; k++) 337 // C[i, j] += A[i, k] * B[k, j] 338 // 339 // Permutation of three outer loops: 3! = 6 possibilities. 
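  // Each pair (FirstDims[i], SecondDims[i]) below selects which two of the
  // three input dimensions are equated with the two output dimensions; the
  // six pairs enumerate all ordered pairs of distinct loop dimensions.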
340 int FirstDims[] = {0, 0, 1, 1, 2, 2}; 341 int SecondDims[] = {1, 2, 2, 0, 0, 1}; 342 for (int i = 0; i < 6; i += 1) { 343 auto PossibleMatMul = 344 Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0) 345 .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1); 346 347 AccMap = AccMap.intersect_domain(Domain); 348 PossibleMatMul = PossibleMatMul.intersect_domain(Domain); 349 350 // If AccMap spans entire domain (Non-partial write), 351 // compute FirstPos and SecondPos. 352 // If AccMap != PossibleMatMul here (the two maps have been gisted at 353 // this point), it means that the writes are not complete, or in other 354 // words, it is a Partial write and Partial writes must be rejected. 355 if (AccMap.is_equal(PossibleMatMul)) { 356 if (FirstPos != -1 && FirstPos != FirstDims[i]) 357 continue; 358 FirstPos = FirstDims[i]; 359 if (SecondPos != -1 && SecondPos != SecondDims[i]) 360 continue; 361 SecondPos = SecondDims[i]; 362 return true; 363 } 364 } 365 366 return false; 367 } 368 369 /// Does the memory access represent a non-scalar operand of the matrix 370 /// multiplication. 371 /// 372 /// Check that the memory access @p MemAccess is the read access to a non-scalar 373 /// operand of the matrix multiplication or its result. 374 /// 375 /// @param MemAccess The memory access to be checked. 376 /// @param MMI Parameters of the matrix multiplication operands. 377 /// @return True in case the memory access represents the read access 378 /// to a non-scalar operand of the matrix multiplication and 379 /// false, otherwise. 380 static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess, 381 MatMulInfoTy &MMI) { 382 if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead()) 383 return false; 384 auto AccMap = MemAccess->getLatestAccessRelation(); 385 isl::set StmtDomain = MemAccess->getStatement()->getDomain(); 386 if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) { 387 MMI.ReadFromC = MemAccess; 388 return true; 389 } 390 if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) { 391 MMI.A = MemAccess; 392 return true; 393 } 394 if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) { 395 MMI.B = MemAccess; 396 return true; 397 } 398 return false; 399 } 400 401 /// Check accesses to operands of the matrix multiplication. 402 /// 403 /// Check that accesses of the SCoP statement, which corresponds to 404 /// the partial schedule @p PartialSchedule, are scalar in terms of loops 405 /// containing the matrix multiplication, in case they do not represent 406 /// accesses to the non-scalar operands of the matrix multiplication or 407 /// its result. 408 /// 409 /// @param PartialSchedule The partial schedule of the SCoP statement. 410 /// @param MMI Parameters of the matrix multiplication operands. 411 /// @return True in case the corresponding SCoP statement 412 /// represents matrix multiplication and false, 413 /// otherwise. 
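///
/// For example, in a statement computing C[i][j] += alpha[0] * A[i][k] *
/// B[k][j] (alpha is only an illustrative name), the read of the
/// loop-invariant element alpha[0] has stride zero with respect to the i, j,
/// and k loops and is therefore accepted.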
414 static bool containsOnlyMatrMultAcc(isl::map PartialSchedule, 415 MatMulInfoTy &MMI) { 416 auto InputDimId = PartialSchedule.get_tuple_id(isl::dim::in); 417 auto *Stmt = static_cast<ScopStmt *>(InputDimId.get_user()); 418 unsigned OutDimNum = unsignedFromIslSize(PartialSchedule.range_tuple_dim()); 419 assert(OutDimNum > 2 && "In case of the matrix multiplication the loop nest " 420 "and, consequently, the corresponding scheduling " 421 "functions have at least three dimensions."); 422 auto MapI = 423 permuteDimensions(PartialSchedule, isl::dim::out, MMI.i, OutDimNum - 1); 424 auto MapJ = 425 permuteDimensions(PartialSchedule, isl::dim::out, MMI.j, OutDimNum - 1); 426 auto MapK = 427 permuteDimensions(PartialSchedule, isl::dim::out, MMI.k, OutDimNum - 1); 428 429 auto Accesses = getAccessesInOrder(*Stmt); 430 for (auto *MemA = Accesses.begin(); MemA != Accesses.end() - 1; MemA++) { 431 auto *MemAccessPtr = *MemA; 432 if (MemAccessPtr->isLatestArrayKind() && MemAccessPtr != MMI.WriteToC && 433 !isMatMulNonScalarReadAccess(MemAccessPtr, MMI) && 434 !(MemAccessPtr->isStrideZero(MapI) && 435 MemAccessPtr->isStrideZero(MapJ) && MemAccessPtr->isStrideZero(MapK))) 436 return false; 437 } 438 return true; 439 } 440 441 /// Check for dependencies corresponding to the matrix multiplication. 442 /// 443 /// Check that there is only true dependence of the form 444 /// S(..., k, ...) -> S(..., k + 1, …), where S is the SCoP statement 445 /// represented by @p Schedule and k is @p Pos. Such a dependence corresponds 446 /// to the dependency produced by the matrix multiplication. 447 /// 448 /// @param Schedule The schedule of the SCoP statement. 449 /// @param D The SCoP dependencies. 450 /// @param Pos The parameter to describe an acceptable true dependence. 451 /// In case it has a negative value, try to determine its 452 /// acceptable value. 453 /// @return True in case dependencies correspond to the matrix multiplication 454 /// and false, otherwise. 455 static bool containsOnlyMatMulDep(isl::map Schedule, const Dependences *D, 456 int &Pos) { 457 isl::union_map Dep = D->getDependences(Dependences::TYPE_RAW); 458 isl::union_map Red = D->getDependences(Dependences::TYPE_RED); 459 if (!Red.is_null()) 460 Dep = Dep.unite(Red); 461 auto DomainSpace = Schedule.get_space().domain(); 462 auto Space = DomainSpace.map_from_domain_and_range(DomainSpace); 463 auto Deltas = Dep.extract_map(Space).deltas(); 464 int DeltasDimNum = unsignedFromIslSize(Deltas.dim(isl::dim::set)); 465 for (int i = 0; i < DeltasDimNum; i++) { 466 auto Val = Deltas.plain_get_val_if_fixed(isl::dim::set, i); 467 Pos = Pos < 0 && Val.is_one() ? i : Pos; 468 if (Val.is_nan() || !(Val.is_zero() || (i == Pos && Val.is_one()))) 469 return false; 470 } 471 if (DeltasDimNum == 0 || Pos < 0) 472 return false; 473 return true; 474 } 475 476 /// Check if the SCoP statement could probably be optimized with analytical 477 /// modeling. 478 /// 479 /// containsMatrMult tries to determine whether the following conditions 480 /// are true: 481 /// 1. The last memory access modeling an array, MA1, represents writing to 482 /// memory and has the form S(..., i1, ..., i2, ...) -> M(i1, i2) or 483 /// S(..., i2, ..., i1, ...) -> M(i1, i2), where S is the SCoP statement 484 /// under consideration. 485 /// 2. There is only one loop-carried true dependency, and it has the 486 /// form S(..., i3, ...) -> S(..., i3 + 1, ...), and there are no 487 /// loop-carried or anti dependencies. 488 /// 3. 
SCoP contains three access relations, MA2, MA3, and MA4 that represent
///    reading from memory and have the form S(..., i3, ...) -> M(i1, i3),
///    S(..., i3, ...) -> M(i3, i2), S(...) -> M(i1, i2), respectively,
///    and all memory accesses of the SCoP that are different from MA1, MA2,
///    MA3, and MA4 have stride 0, if the innermost loop is exchanged with any
///    of loops i1, i2 and i3.
///
/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
///        to check.
/// @param D The SCoP dependencies.
/// @param MMI Parameters of the matrix multiplication operands.
static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D,
                             MatMulInfoTy &MMI) {
  auto InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
  if (Stmt->size() <= 1)
    return false;

  auto Accesses = getAccessesInOrder(*Stmt);
  for (auto *MemA = Accesses.end() - 1; MemA != Accesses.begin(); MemA--) {
    auto *MemAccessPtr = *MemA;
    if (!MemAccessPtr->isLatestArrayKind())
      continue;
    if (!MemAccessPtr->isWrite())
      return false;
    auto AccMap = MemAccessPtr->getLatestAccessRelation();
    if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
      return false;
    MMI.WriteToC = MemAccessPtr;
    break;
  }

  if (!containsOnlyMatMulDep(PartialSchedule, D, MMI.k))
    return false;

  if (!MMI.WriteToC || !containsOnlyMatrMultAcc(PartialSchedule, MMI))
    return false;

  if (!MMI.A || !MMI.B || !MMI.ReadFromC)
    return false;
  return true;
}

/// Permute two dimensions of the band node.
///
/// Permute FirstDim and SecondDim dimensions of the Node.
///
/// @param Node The band node to be modified.
/// @param FirstDim The first dimension to be permuted.
/// @param SecondDim The second dimension to be permuted.
static isl::schedule_node permuteBandNodeDimensions(isl::schedule_node Node,
                                                    unsigned FirstDim,
                                                    unsigned SecondDim) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band &&
         (unsigned)isl_schedule_node_band_n_member(Node.get()) >
             std::max(FirstDim, SecondDim));
  auto PartialSchedule =
      isl::manage(isl_schedule_node_band_get_partial_schedule(Node.get()));
  auto PartialScheduleFirstDim = PartialSchedule.at(FirstDim);
  auto PartialScheduleSecondDim = PartialSchedule.at(SecondDim);
  PartialSchedule =
      PartialSchedule.set_union_pw_aff(SecondDim, PartialScheduleFirstDim);
  PartialSchedule =
      PartialSchedule.set_union_pw_aff(FirstDim, PartialScheduleSecondDim);
  Node = isl::manage(isl_schedule_node_delete(Node.release()));
  return Node.insert_partial_schedule(PartialSchedule);
}

/// Create the BLIS micro-kernel.
///
/// Apply register tiling with the Mr and Nr parameters of the micro-kernel
/// as tile sizes and interchange two of the resulting band dimensions.
///
/// @param Node              The schedule node to be modified.
/// @param MicroKernelParams Parameters of the micro kernel
///                          to be used as tile sizes.
static isl::schedule_node
createMicroKernel(isl::schedule_node Node,
                  MicroKernelParamsTy MicroKernelParams) {
  Node = applyRegisterTiling(Node, {MicroKernelParams.Mr, MicroKernelParams.Nr},
                             1);
  Node = Node.parent().parent();
  return permuteBandNodeDimensions(Node, 0, 1).child(0).child(0);
}

/// Create the BLIS macro-kernel.
///
/// We create the BLIS macro-kernel by applying a combination of tiling
/// of dimensions of the band node and interchanging of two innermost
/// modified dimensions. The values of MacroKernelParams's fields are used
/// as tile sizes.
///
/// @param Node The schedule node to be modified.
573 /// @param MacroKernelParams Parameters of the macro kernel 574 /// to be used as tile sizes. 575 static isl::schedule_node 576 createMacroKernel(isl::schedule_node Node, 577 MacroKernelParamsTy MacroKernelParams) { 578 assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band); 579 if (MacroKernelParams.Mc == 1 && MacroKernelParams.Nc == 1 && 580 MacroKernelParams.Kc == 1) 581 return Node; 582 int DimOutNum = isl_schedule_node_band_n_member(Node.get()); 583 std::vector<int> TileSizes(DimOutNum, 1); 584 TileSizes[DimOutNum - 3] = MacroKernelParams.Mc; 585 TileSizes[DimOutNum - 2] = MacroKernelParams.Nc; 586 TileSizes[DimOutNum - 1] = MacroKernelParams.Kc; 587 Node = tileNode(Node, "1st level tiling", TileSizes, 1); 588 Node = Node.parent().parent(); 589 Node = permuteBandNodeDimensions(Node, DimOutNum - 2, DimOutNum - 1); 590 Node = permuteBandNodeDimensions(Node, DimOutNum - 3, DimOutNum - 1); 591 592 return Node.child(0).child(0); 593 } 594 595 /// Get the size of the widest type of the matrix multiplication operands 596 /// in bytes, including alignment padding. 597 /// 598 /// @param MMI Parameters of the matrix multiplication operands. 599 /// @return The size of the widest type of the matrix multiplication operands 600 /// in bytes, including alignment padding. 601 static uint64_t getMatMulAlignTypeSize(const MatMulInfoTy &MMI) { 602 auto *S = MMI.A->getStatement()->getParent(); 603 auto &DL = S->getFunction().getParent()->getDataLayout(); 604 auto ElementSizeA = DL.getTypeAllocSize(MMI.A->getElementType()); 605 auto ElementSizeB = DL.getTypeAllocSize(MMI.B->getElementType()); 606 auto ElementSizeC = DL.getTypeAllocSize(MMI.WriteToC->getElementType()); 607 return std::max({ElementSizeA, ElementSizeB, ElementSizeC}); 608 } 609 610 /// Get the size of the widest type of the matrix multiplication operands 611 /// in bits. 612 /// 613 /// @param MMI Parameters of the matrix multiplication operands. 614 /// @return The size of the widest type of the matrix multiplication operands 615 /// in bits. 616 static uint64_t getMatMulTypeSize(const MatMulInfoTy &MMI) { 617 auto *S = MMI.A->getStatement()->getParent(); 618 auto &DL = S->getFunction().getParent()->getDataLayout(); 619 auto ElementSizeA = DL.getTypeSizeInBits(MMI.A->getElementType()); 620 auto ElementSizeB = DL.getTypeSizeInBits(MMI.B->getElementType()); 621 auto ElementSizeC = DL.getTypeSizeInBits(MMI.WriteToC->getElementType()); 622 return std::max({ElementSizeA, ElementSizeB, ElementSizeC}); 623 } 624 625 /// Get parameters of the BLIS micro kernel. 626 /// 627 /// We choose the Mr and Nr parameters of the micro kernel to be large enough 628 /// such that no stalls caused by the combination of latencies and dependencies 629 /// are introduced during the updates of the resulting matrix of the matrix 630 /// multiplication. However, they should also be as small as possible to 631 /// release more registers for entries of multiplied matrices. 632 /// 633 /// @param TTI Target Transform Info. 634 /// @param MMI Parameters of the matrix multiplication operands. 635 /// @return The structure of type MicroKernelParamsTy. 636 /// @see MicroKernelParamsTy 637 static MicroKernelParamsTy getMicroKernelParams(const TargetTransformInfo *TTI, 638 const MatMulInfoTy &MMI) { 639 assert(TTI && "The target transform info should be provided."); 640 641 // Nvec - Number of double-precision floating-point numbers that can be hold 642 // by a vector register. Use 2 by default. 
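  // For illustration (assuming 256-bit vector registers and 64-bit elements):
  // Nvec = 256 / 64 = 4. With the default LatencyVectorFma = 8 and
  // ThroughputVectorFma = 1, the formulas below give
  // Nr = ceil(sqrt(4 * 8 * 1) / 4) * 4 = 8 and Mr = ceil(4 * 8 * 1 / 8) = 4.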
643 long RegisterBitwidth = VectorRegisterBitwidth; 644 645 if (RegisterBitwidth == -1) 646 RegisterBitwidth = 647 TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector); 648 auto ElementSize = getMatMulTypeSize(MMI); 649 assert(ElementSize > 0 && "The element size of the matrix multiplication " 650 "operands should be greater than zero."); 651 auto Nvec = RegisterBitwidth / ElementSize; 652 if (Nvec == 0) 653 Nvec = 2; 654 int Nr = ceil(sqrt((double)(Nvec * LatencyVectorFma * ThroughputVectorFma)) / 655 Nvec) * 656 Nvec; 657 int Mr = ceil((double)(Nvec * LatencyVectorFma * ThroughputVectorFma / Nr)); 658 return {Mr, Nr}; 659 } 660 661 /// Determine parameters of the target cache. 662 /// 663 /// @param TTI Target Transform Info. 664 static void getTargetCacheParameters(const llvm::TargetTransformInfo *TTI) { 665 auto L1DCache = llvm::TargetTransformInfo::CacheLevel::L1D; 666 auto L2DCache = llvm::TargetTransformInfo::CacheLevel::L2D; 667 if (FirstCacheLevelSize == -1) { 668 if (TTI->getCacheSize(L1DCache)) 669 FirstCacheLevelSize = TTI->getCacheSize(L1DCache).value(); 670 else 671 FirstCacheLevelSize = static_cast<int>(FirstCacheLevelDefaultSize); 672 } 673 if (SecondCacheLevelSize == -1) { 674 if (TTI->getCacheSize(L2DCache)) 675 SecondCacheLevelSize = TTI->getCacheSize(L2DCache).value(); 676 else 677 SecondCacheLevelSize = static_cast<int>(SecondCacheLevelDefaultSize); 678 } 679 if (FirstCacheLevelAssociativity == -1) { 680 if (TTI->getCacheAssociativity(L1DCache)) 681 FirstCacheLevelAssociativity = 682 TTI->getCacheAssociativity(L1DCache).value(); 683 else 684 FirstCacheLevelAssociativity = 685 static_cast<int>(FirstCacheLevelDefaultAssociativity); 686 } 687 if (SecondCacheLevelAssociativity == -1) { 688 if (TTI->getCacheAssociativity(L2DCache)) 689 SecondCacheLevelAssociativity = 690 TTI->getCacheAssociativity(L2DCache).value(); 691 else 692 SecondCacheLevelAssociativity = 693 static_cast<int>(SecondCacheLevelDefaultAssociativity); 694 } 695 } 696 697 /// Get parameters of the BLIS macro kernel. 698 /// 699 /// During the computation of matrix multiplication, blocks of partitioned 700 /// matrices are mapped to different layers of the memory hierarchy. 701 /// To optimize data reuse, blocks should be ideally kept in cache between 702 /// iterations. Since parameters of the macro kernel determine sizes of these 703 /// blocks, there are upper and lower bounds on these parameters. 704 /// 705 /// @param TTI Target Transform Info. 706 /// @param MicroKernelParams Parameters of the micro-kernel 707 /// to be taken into account. 708 /// @param MMI Parameters of the matrix multiplication operands. 709 /// @return The structure of type MacroKernelParamsTy. 710 /// @see MacroKernelParamsTy 711 /// @see MicroKernelParamsTy 712 static MacroKernelParamsTy 713 getMacroKernelParams(const llvm::TargetTransformInfo *TTI, 714 const MicroKernelParamsTy &MicroKernelParams, 715 const MatMulInfoTy &MMI) { 716 getTargetCacheParameters(TTI); 717 // According to www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf, 718 // it requires information about the first two levels of a cache to determine 719 // all the parameters of a macro-kernel. It also checks that an associativity 720 // degree of a cache level is greater than two. Otherwise, another algorithm 721 // for determination of the parameters should be used. 
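  // For illustration (assuming the default cache parameters above, i.e.,
  // a 32 KiB 8-way L1 and a 256 KiB 8-way L2, plus Mr = 4, Nr = 8, and
  // 8-byte elements): Car = floor(7 / (1 + 8 / 4)) = 2,
  // Kc = (2 * 32768) / (4 * 8 * 8) = 256,
  // Cac = (256 * 8 * 8) / 262144 = 0.0625, Mc = floor(6 / 0.0625) = 96, and
  // Nc = 256 * 8 = 2048.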
722 if (!(MicroKernelParams.Mr > 0 && MicroKernelParams.Nr > 0 && 723 FirstCacheLevelSize > 0 && SecondCacheLevelSize > 0 && 724 FirstCacheLevelAssociativity > 2 && SecondCacheLevelAssociativity > 2)) 725 return {1, 1, 1}; 726 // The quotient should be greater than zero. 727 if (PollyPatternMatchingNcQuotient <= 0) 728 return {1, 1, 1}; 729 int Car = floor( 730 (FirstCacheLevelAssociativity - 1) / 731 (1 + static_cast<double>(MicroKernelParams.Nr) / MicroKernelParams.Mr)); 732 733 // Car can be computed to be zero since it is floor to int. 734 // On Mac OS, division by 0 does not raise a signal. This causes negative 735 // tile sizes to be computed. Prevent division by Cac==0 by early returning 736 // if this happens. 737 if (Car == 0) 738 return {1, 1, 1}; 739 740 auto ElementSize = getMatMulAlignTypeSize(MMI); 741 assert(ElementSize > 0 && "The element size of the matrix multiplication " 742 "operands should be greater than zero."); 743 int Kc = (Car * FirstCacheLevelSize) / 744 (MicroKernelParams.Mr * FirstCacheLevelAssociativity * ElementSize); 745 double Cac = 746 static_cast<double>(Kc * ElementSize * SecondCacheLevelAssociativity) / 747 SecondCacheLevelSize; 748 int Mc = floor((SecondCacheLevelAssociativity - 2) / Cac); 749 int Nc = PollyPatternMatchingNcQuotient * MicroKernelParams.Nr; 750 751 assert(Mc > 0 && Nc > 0 && Kc > 0 && 752 "Matrix block sizes should be greater than zero"); 753 return {Mc, Nc, Kc}; 754 } 755 756 /// Create an access relation that is specific to 757 /// the matrix multiplication pattern. 758 /// 759 /// Create an access relation of the following form: 760 /// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [OI, O5, OJ] 761 /// where I is @p FirstDim, J is @p SecondDim. 762 /// 763 /// It can be used, for example, to create relations that helps to consequently 764 /// access elements of operands of a matrix multiplication after creation of 765 /// the BLIS micro and macro kernels. 766 /// 767 /// @see ScheduleTreeOptimizer::createMicroKernel 768 /// @see ScheduleTreeOptimizer::createMacroKernel 769 /// 770 /// Subsequently, the described access relation is applied to the range of 771 /// @p MapOldIndVar, that is used to map original induction variables to 772 /// the ones, which are produced by schedule transformations. It helps to 773 /// define relations using a new space and, at the same time, keep them 774 /// in the original one. 775 /// 776 /// @param MapOldIndVar The relation, which maps original induction variables 777 /// to the ones, which are produced by schedule 778 /// transformations. 779 /// @param FirstDim, SecondDim The input dimensions that are used to define 780 /// the specified access relation. 781 /// @return The specified access relation. 
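///
/// For example, the call getMatMulAccRel(MapOldIndVar, 3, 7), as used below
/// for the packed copy of B, creates the relation
/// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [O3, O5, O7] and applies it to the
/// range of @p MapOldIndVar.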
782 static isl::map getMatMulAccRel(isl::map MapOldIndVar, unsigned FirstDim, 783 unsigned SecondDim) { 784 auto AccessRelSpace = isl::space(MapOldIndVar.ctx(), 0, 9, 3); 785 auto AccessRel = isl::map::universe(AccessRelSpace); 786 AccessRel = AccessRel.equate(isl::dim::in, FirstDim, isl::dim::out, 0); 787 AccessRel = AccessRel.equate(isl::dim::in, 5, isl::dim::out, 1); 788 AccessRel = AccessRel.equate(isl::dim::in, SecondDim, isl::dim::out, 2); 789 return MapOldIndVar.apply_range(AccessRel); 790 } 791 792 static isl::schedule_node createExtensionNode(isl::schedule_node Node, 793 isl::map ExtensionMap) { 794 auto Extension = isl::union_map(ExtensionMap); 795 auto NewNode = isl::schedule_node::from_extension(Extension); 796 return Node.graft_before(NewNode); 797 } 798 799 static isl::schedule_node optimizePackedB(isl::schedule_node Node, 800 ScopStmt *Stmt, isl::map MapOldIndVar, 801 MicroKernelParamsTy MicroParams, 802 MacroKernelParamsTy MacroParams, 803 MatMulInfoTy &MMI) { 804 Scop *S = Stmt->getParent(); 805 isl::set Domain = Stmt->getDomain(); 806 807 // Create packed array. 808 unsigned FirstDimSize = MacroParams.Nc / MicroParams.Nr; 809 unsigned SecondDimSize = MacroParams.Kc; 810 unsigned ThirdDimSize = MicroParams.Nr; 811 ScopArrayInfo *PackedB = 812 S->createScopArrayInfo(MMI.B->getElementType(), "Packed_B", 813 {FirstDimSize, SecondDimSize, ThirdDimSize}); 814 815 // Compute the access relation for copying from B to PackedB. 816 isl::map AccRelB = MMI.B->getLatestAccessRelation(); 817 isl::map AccRelPackedB = getMatMulAccRel(MapOldIndVar, 3, 7); 818 AccRelPackedB = 819 AccRelPackedB.set_tuple_id(isl::dim::out, PackedB->getBasePtrId()); 820 821 // Create the copy statement and redirect access. 822 ScopStmt *CopyStmt = S->addScopStmt(AccRelB, AccRelPackedB, Domain); 823 MMI.B->setNewAccessRelation(AccRelPackedB); 824 825 unsigned Dim = unsignedFromIslSize(MapOldIndVar.range_tuple_dim()); 826 assert(Dim >= 2); 827 // Insert into the schedule tree. 828 isl::map ExtMap = MapOldIndVar.project_out(isl::dim::out, 2, Dim - 2); 829 ExtMap = ExtMap.reverse(); 830 ExtMap = ExtMap.fix_si(isl::dim::out, MMI.i, 0); 831 ExtMap = ExtMap.intersect_range(Domain); 832 ExtMap = ExtMap.set_tuple_id(isl::dim::out, CopyStmt->getDomainId()); 833 return createExtensionNode(Node, ExtMap); 834 } 835 836 static isl::schedule_node optimizePackedA(isl::schedule_node Node, ScopStmt *, 837 isl::map MapOldIndVar, 838 MicroKernelParamsTy MicroParams, 839 MacroKernelParamsTy MacroParams, 840 MatMulInfoTy &MMI) { 841 isl::id InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in); 842 ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user()); 843 isl::set Domain = Stmt->getDomain(); 844 isl::id DomainId = Domain.get_tuple_id(); 845 846 // Create the packed array. 847 unsigned FirstDimSize = MacroParams.Mc / MicroParams.Mr; 848 unsigned SecondDimSize = MacroParams.Kc; 849 unsigned ThirdDimSize = MicroParams.Mr; 850 ScopArrayInfo *PackedA = Stmt->getParent()->createScopArrayInfo( 851 MMI.A->getElementType(), "Packed_A", 852 {FirstDimSize, SecondDimSize, ThirdDimSize}); 853 854 // Compute the access relation for copying from A to PackedA. 855 isl::map AccRelA = MMI.A->getLatestAccessRelation(); 856 isl::map AccRelPackedA = getMatMulAccRel(MapOldIndVar, 4, 6); 857 AccRelPackedA = 858 AccRelPackedA.set_tuple_id(isl::dim::out, PackedA->getBasePtrId()); 859 // { MemrefA[] -> PackedA[] } 860 isl::map PackedATranslator = AccRelPackedA.apply_domain(AccRelA); 861 862 // Compute the domain for the copy statement. 
863 // Construct the copy statement domain out of the 3 outermost scatter 864 // dimensions (to match the 3 band nodes surrounding the extension node) and 865 // the array elements to copy (one statement instance per array element). 866 // { Scatter[] } 867 isl::set ScatterDomain = MapOldIndVar.intersect_domain(Domain).range(); 868 // { Scatter[] -> OutermostScatter[] } 869 isl::map OuterDomainMap = 870 makeIdentityMap(ScatterDomain, true).project_out(isl::dim::out, 3, 6); 871 // { Scatter[] -> MemrefA[] } 872 isl::map CopyFrom = MapOldIndVar.reverse().apply_range(AccRelA); 873 // { Scatter[] -> CopyStmt[] } 874 isl::map DomainTranslator = OuterDomainMap.range_product(CopyFrom); 875 // { CopyStmt[] } 876 isl::set CopyDomain = DomainTranslator.range(); 877 878 // Translate the access relations to the new domain. 879 // { CopyStmt[] -> MemrefA[] } 880 CopyFrom = CopyFrom.apply_domain(DomainTranslator); 881 // { CopyStmt[] -> PackedA[] } 882 isl::map CopyTo = CopyFrom.apply_range(PackedATranslator); 883 884 // Create the copy statement and redirect access. 885 ScopStmt *CopyStmt = 886 Stmt->getParent()->addScopStmt(CopyFrom, CopyTo, CopyDomain); 887 MMI.A->setNewAccessRelation(AccRelPackedA); 888 889 // Insert into the schedule tree. 890 // { Scatter[] -> CopyStmt[] } 891 isl::map ExtScatterCopy = makeIdentityMap(CopyStmt->getDomain(), true); 892 ExtScatterCopy = ExtScatterCopy.project_out(isl::dim::in, 3, 2); 893 return createExtensionNode(Node, ExtScatterCopy); 894 } 895 896 /// Apply the packing transformation. 897 /// 898 /// The packing transformation can be described as a data-layout 899 /// transformation that requires to introduce a new array, copy data 900 /// to the array, and change memory access locations to reference the array. 901 /// It can be used to ensure that elements of the new array are read in-stride 902 /// access, aligned to cache lines boundaries, and preloaded into certain cache 903 /// levels. 904 /// 905 /// As an example let us consider the packing of the array A that would help 906 /// to read its elements with in-stride access. An access to the array A 907 /// is represented by an access relation that has the form 908 /// S[i, j, k] -> A[i, k]. The scheduling function of the SCoP statement S has 909 /// the form S[i,j, k] -> [floor((j mod Nc) / Nr), floor((i mod Mc) / Mr), 910 /// k mod Kc, j mod Nr, i mod Mr]. 911 /// 912 /// To ensure that elements of the array A are read in-stride access, we add 913 /// a new array Packed_A[Mc/Mr][Kc][Mr] to the SCoP, using 914 /// Scop::createScopArrayInfo, change the access relation 915 /// S[i, j, k] -> A[i, k] to 916 /// S[i, j, k] -> Packed_A[floor((i mod Mc) / Mr), k mod Kc, i mod Mr], using 917 /// MemoryAccess::setNewAccessRelation, and copy the data to the array, using 918 /// the copy statement created by Scop::addScopStmt. 919 /// 920 /// @param Node The schedule node to be optimized. 921 /// @param MapOldIndVar The relation, which maps original induction variables 922 /// to the ones, which are produced by schedule 923 /// transformations. 924 /// @param MicroParams, MacroParams Parameters of the BLIS kernel 925 /// to be taken into account. 926 /// @param MMI Parameters of the matrix multiplication operands. 927 /// @return The optimized schedule node. 
928 static isl::schedule_node 929 optimizeDataLayoutMatrMulPattern(isl::schedule_node Node, isl::map MapOldIndVar, 930 MicroKernelParamsTy MicroParams, 931 MacroKernelParamsTy MacroParams, 932 MatMulInfoTy &MMI) { 933 isl::id InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in); 934 ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user()); 935 936 Node = Node.parent().parent().parent().parent().parent().parent(); 937 Node = isl::manage(isl_schedule_node_band_split(Node.release(), 2)); 938 939 Node = Node.child(0); 940 Node = 941 optimizePackedB(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI); 942 943 Node = Node.child(0); 944 Node = 945 optimizePackedA(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI); 946 947 return Node.child(0).child(0).child(0).child(0).child(0); 948 } 949 950 /// Get a relation mapping induction variables produced by schedule 951 /// transformations to the original ones. 952 /// 953 /// @param Node The schedule node produced as the result of creation 954 /// of the BLIS kernels. 955 /// @param MicroKernelParams, MacroKernelParams Parameters of the BLIS kernel 956 /// to be taken into account. 957 /// @return The relation mapping original induction variables to the ones 958 /// produced by schedule transformation. 959 /// @see ScheduleTreeOptimizer::createMicroKernel 960 /// @see ScheduleTreeOptimizer::createMacroKernel 961 /// @see getMacroKernelParams 962 static isl::map 963 getInductionVariablesSubstitution(isl::schedule_node Node, 964 MicroKernelParamsTy MicroKernelParams, 965 MacroKernelParamsTy MacroKernelParams) { 966 auto Child = Node.child(0); 967 auto UnMapOldIndVar = Child.get_prefix_schedule_union_map(); 968 auto MapOldIndVar = isl::map::from_union_map(UnMapOldIndVar); 969 unsigned Dim = unsignedFromIslSize(MapOldIndVar.range_tuple_dim()); 970 if (Dim > 9u) 971 return MapOldIndVar.project_out(isl::dim::out, 0, Dim - 9); 972 return MapOldIndVar; 973 } 974 975 /// Isolate a set of partial tile prefixes and unroll the isolated part. 976 /// 977 /// The set should ensure that it contains only partial tile prefixes that have 978 /// exactly Mr x Nr iterations of the two innermost loops produced by 979 /// the optimization of the matrix multiplication. Mr and Nr are parameters of 980 /// the micro-kernel. 981 /// 982 /// In case of parametric bounds, this helps to auto-vectorize the unrolled 983 /// innermost loops, using the SLP vectorizer. 984 /// 985 /// @param Node The schedule node to be modified. 986 /// @param MicroKernelParams Parameters of the micro-kernel 987 /// to be taken into account. 988 /// @return The modified isl_schedule_node. 
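///
/// For example, with the micro-kernel parameters Mr = 4 and Nr = 8, the
/// isolated part contains only full 4 x 8 blocks of the two innermost loops,
/// which can then be fully unrolled and vectorized by the SLP vectorizer.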
989 static isl::schedule_node 990 isolateAndUnrollMatMulInnerLoops(isl::schedule_node Node, 991 MicroKernelParamsTy MicroKernelParams) { 992 isl::schedule_node Child = Node.child(0); 993 isl::union_map UnMapOldIndVar = Child.get_prefix_schedule_relation(); 994 isl::set Prefix = isl::map::from_union_map(UnMapOldIndVar).range(); 995 unsigned Dims = unsignedFromIslSize(Prefix.tuple_dim()); 996 assert(Dims >= 1); 997 Prefix = Prefix.project_out(isl::dim::set, Dims - 1, 1); 998 Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Nr); 999 Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Mr); 1000 1001 isl::union_set IsolateOption = 1002 getIsolateOptions(Prefix.add_dims(isl::dim::set, 3), 3); 1003 isl::ctx Ctx = Node.ctx(); 1004 auto Options = IsolateOption.unite(getDimOptions(Ctx, "unroll")); 1005 Options = Options.unite(getUnrollIsolatedSetOptions(Ctx)); 1006 Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options); 1007 Node = Node.parent().parent().parent(); 1008 IsolateOption = getIsolateOptions(Prefix, 3); 1009 Options = IsolateOption.unite(getDimOptions(Ctx, "separate")); 1010 Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options); 1011 Node = Node.child(0).child(0).child(0); 1012 return Node; 1013 } 1014 1015 /// Insert "Loop Vectorizer Disabled" mark node. 1016 /// 1017 /// @param Node The child of the mark node to be inserted. 1018 /// @return The modified isl_schedule_node. 1019 static isl::schedule_node markLoopVectorizerDisabled(isl::schedule_node Node) { 1020 auto Id = isl::id::alloc(Node.ctx(), "Loop Vectorizer Disabled", nullptr); 1021 return Node.insert_mark(Id).child(0); 1022 } 1023 1024 /// Restore the initial ordering of dimensions of the band node 1025 /// 1026 /// In case the band node represents all the dimensions of the iteration 1027 /// domain, recreate the band node to restore the initial ordering of the 1028 /// dimensions. 1029 /// 1030 /// @param Node The band node to be modified. 1031 /// @return The modified schedule node. 
1032 static isl::schedule_node 1033 getBandNodeWithOriginDimOrder(isl::schedule_node Node) { 1034 assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band); 1035 if (isl_schedule_node_get_type(Node.child(0).get()) != isl_schedule_node_leaf) 1036 return Node; 1037 auto Domain = Node.get_universe_domain(); 1038 assert(isl_union_set_n_set(Domain.get()) == 1); 1039 if (Node.get_schedule_depth().release() != 0 || 1040 (unsignedFromIslSize(isl::set(Domain).tuple_dim()) != 1041 unsignedFromIslSize(Node.as<isl::schedule_node_band>().n_member()))) 1042 return Node; 1043 Node = isl::manage(isl_schedule_node_delete(Node.copy())); 1044 auto PartialSchedulePwAff = Domain.identity_union_pw_multi_aff(); 1045 auto PartialScheduleMultiPwAff = 1046 isl::multi_union_pw_aff(PartialSchedulePwAff); 1047 PartialScheduleMultiPwAff = 1048 PartialScheduleMultiPwAff.reset_tuple_id(isl::dim::set); 1049 return Node.insert_partial_schedule(PartialScheduleMultiPwAff); 1050 } 1051 1052 static isl::schedule_node optimizeMatMulPattern(isl::schedule_node Node, 1053 const TargetTransformInfo *TTI, 1054 MatMulInfoTy &MMI) { 1055 assert(TTI && "The target transform info should be provided."); 1056 int DimOutNum = isl_schedule_node_band_n_member(Node.get()); 1057 assert(DimOutNum > 2 && "In case of the matrix multiplication the loop nest " 1058 "and, consequently, the corresponding scheduling " 1059 "functions have at least three dimensions."); 1060 Node = getBandNodeWithOriginDimOrder(Node); 1061 Node = permuteBandNodeDimensions(Node, MMI.i, DimOutNum - 3); 1062 int NewJ = MMI.j == DimOutNum - 3 ? MMI.i : MMI.j; 1063 int NewK = MMI.k == DimOutNum - 3 ? MMI.i : MMI.k; 1064 Node = permuteBandNodeDimensions(Node, NewJ, DimOutNum - 2); 1065 NewK = NewK == DimOutNum - 2 ? NewJ : NewK; 1066 Node = permuteBandNodeDimensions(Node, NewK, DimOutNum - 1); 1067 auto MicroKernelParams = getMicroKernelParams(TTI, MMI); 1068 auto MacroKernelParams = getMacroKernelParams(TTI, MicroKernelParams, MMI); 1069 Node = createMacroKernel(Node, MacroKernelParams); 1070 Node = createMicroKernel(Node, MicroKernelParams); 1071 if (MacroKernelParams.Mc == 1 || MacroKernelParams.Nc == 1 || 1072 MacroKernelParams.Kc == 1) 1073 return Node; 1074 auto MapOldIndVar = getInductionVariablesSubstitution(Node, MicroKernelParams, 1075 MacroKernelParams); 1076 if (MapOldIndVar.is_null()) 1077 return Node; 1078 Node = markLoopVectorizerDisabled(Node.parent()).child(0); 1079 Node = isolateAndUnrollMatMulInnerLoops(Node, MicroKernelParams); 1080 return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams, 1081 MacroKernelParams, MMI); 1082 } 1083 1084 /// Check if this node contains a partial schedule that could 1085 /// probably be optimized with analytical modeling. 1086 /// 1087 /// isMatrMultPattern tries to determine whether the following conditions 1088 /// are true: 1089 /// 1. the partial schedule contains only one statement. 1090 /// 2. there are exactly three input dimensions. 1091 /// 3. all memory accesses of the statement will have stride 0 or 1, if we 1092 /// interchange loops (switch the variable used in the inner loop to 1093 /// the outer loop). 1094 /// 4. all memory accesses of the statement except from the last one, are 1095 /// read memory access and the last one is write memory access. 1096 /// 5. all subscripts of the last memory access of the statement don't 1097 /// contain the variable used in the inner loop. 
1098 /// If this is the case, we could try to use an approach that is similar to 1099 /// the one used to get close-to-peak performance of matrix multiplications. 1100 /// 1101 /// @param Node The node to check. 1102 /// @param D The SCoP dependencies. 1103 /// @param MMI Parameters of the matrix multiplication operands. 1104 static bool isMatrMultPattern(isl::schedule_node Node, const Dependences *D, 1105 MatMulInfoTy &MMI) { 1106 auto PartialSchedule = isl::manage( 1107 isl_schedule_node_band_get_partial_schedule_union_map(Node.get())); 1108 if (isl_schedule_node_band_n_member(Node.get()) < 3 || 1109 Node.get_schedule_depth().release() != 0 || 1110 isl_union_map_n_map(PartialSchedule.get()) != 1) 1111 return false; 1112 auto NewPartialSchedule = isl::map::from_union_map(PartialSchedule); 1113 if (containsMatrMult(NewPartialSchedule, D, MMI)) 1114 return true; 1115 return false; 1116 } 1117 1118 /// Get the dimension size. 1119 /// 1120 /// Return the size of the dimension @p Pos, which is obtained from @p SAI. 1121 /// Return -1 in the case of the first dimension of a multi-dimensional array, 1122 /// since the ScopArrayInfo class does not carry size information. 1123 /// 1124 /// @param SAI The information about the array. 1125 /// @param Pos The position of the dimension. 1126 /// @return The size of the dimension. 1127 static int getDimSize(const ScopArrayInfo *SAI, unsigned Pos) { 1128 if (Pos == 0) 1129 return -1; 1130 const llvm::SCEV *SCEVDimSize = SAI->getDimensionSize(Pos); 1131 assert(SCEVDimSize); 1132 auto *ConstantDimSize = dyn_cast<const SCEVConstant>(SCEVDimSize); 1133 assert(ConstantDimSize); 1134 auto *IntDimSize = dyn_cast<ConstantInt>(ConstantDimSize->getValue()); 1135 assert(IntDimSize); 1136 return IntDimSize->getSExtValue(); 1137 } 1138 1139 /// Check whether the access relation has the specified form. 1140 /// 1141 /// Check that the access relation @p AccMap has the form T[I0, …, In], where 1142 /// indexes I0, …, In are specified by @p Dimensions. 1143 /// 1144 /// @param Domain The domain of the access relation. 1145 /// @param AccMap The access relation to be checked. 1146 /// @param Dimensions The permutation of the subset of the input dimensions. 1147 /// @return True if @p AccMap has the expected form and false, 1148 /// otherwise. 1149 static bool isCorrectAccessMap(isl::set Domain, isl::map AccMap, 1150 ArrayRef<int> Dimensions) { 1151 isl::space Space = AccMap.get_space(); 1152 if (unsignedFromIslSize(Space.dim(isl::dim::out)) != Dimensions.size()) 1153 return false; 1154 1155 // Create an access relation of the following form: 1156 // [I0, …, Im] -> [Il, …, In], where indexes 1157 // Il, …, In are specified by @p Dimensions. 1158 isl::map PossibleTensor = isl::map::universe(Space); 1159 unsigned DimInSize = unsignedFromIslSize(Space.dim(isl::dim::in)); 1160 for (unsigned i = 0; i < Dimensions.size(); i++) { 1161 const int InPos = Dimensions[i]; 1162 if ((InPos >= static_cast<int>(DimInSize)) || (InPos < 0)) 1163 return false; 1164 PossibleTensor = 1165 PossibleTensor.equate(isl::dim::in, InPos, isl::dim::out, i); 1166 } 1167 1168 AccMap = AccMap.intersect_domain(Domain); 1169 PossibleTensor = PossibleTensor.intersect_domain(Domain); 1170 1171 // If AccMap != PossibleTensor here (the two maps have been gisted at 1172 // this point), it means that the writes are not complete, or in other 1173 // words, it is a Partial write and Partial writes must be rejected. 
1174 return AccMap.is_equal(PossibleTensor); 1175 } 1176 1177 /// Check whether the access represents the tensor contraction operand. 1178 /// 1179 /// Check that the access relation @p AccMap has the form T[i1, …, in]. 1180 /// Obtained indexes i1, …, in, their sizes and their permutation are stored 1181 /// into @p IndexSet, @p DimensionSizes, and @p Dimensions, respectively. 1182 /// 1183 /// @param Domain The domain of the access relation. 1184 /// @param AccMap The access relation to be checked. 1185 /// @param IndexSet The subset of the input dimensions. 1186 /// @param DimensionSizes Sizes of the input dimensions of @p Dimensions. 1187 /// @param Dimensions The permutation of the subset of the input dimensions. 1188 /// @return True if @p AccMap has the expected form and false, 1189 /// otherwise. 1190 static bool isTCOperandAcc(isl::set Domain, isl::map AccMap, 1191 SmallDenseSet<int> &IndexSet, 1192 SmallVectorImpl<int> &DimensionSizes, 1193 SmallVectorImpl<int> &Dimensions) { 1194 isl::id Id = AccMap.get_tuple_id(isl::dim::out); 1195 const ScopArrayInfo *SAI = ScopArrayInfo::getFromId(Id); 1196 assert(SAI && "AccMap should represent memory access"); 1197 1198 // Fix values of output dimensions with respect to their positions. 1199 // In the case of the tensor contraction, values of output dimensions are 1200 // fixed and form a permutation of a subset of values of input dimensions. 1201 // 1202 // For example, in the case of Stmt[i][j][k] -> A[k][i], which represents 1203 // the operand of the tensor contraction, we get the following map by fixing 1204 // the output dimensions Stmt[1][j][0] -> A[0][1]. 1205 // 1206 // We store the permutation of the subset of the input dimensions {2, 0} into 1207 // @p Dimensions. 1208 // 1209 // The obtained permutation and the isCorrectAccessMap function are used to 1210 // check whether the access relation @p AccMap represents the tensor 1211 // contraction operand. For example, in the case of 1212 // Stmt[i][j][k] -> A[i-1][j+1], we get Stmt[1][0][k] -> A[0][1] and, 1213 // consequently, {1, 0}, which is rejected by isCorrectAccessMap, 1214 // since it corresponds to Stmt[i][j][k] -> A[j][i]. 1215 isl::map CheckMap = isl::manage(AccMap.copy()); 1216 unsigned OutDimNum = unsignedFromIslSize(CheckMap.dim(isl::dim::out)); 1217 for (unsigned i = 0; i < OutDimNum; i++) 1218 CheckMap = CheckMap.fix_si(isl::dim::out, i, i); 1219 1220 // Try to obtain the permutation and sizes of corresponding input dimensions. 1221 Dimensions.assign(OutDimNum, -1); 1222 for (unsigned i : rangeIslSize(0, CheckMap.dim(isl::dim::in))) { 1223 isl::val Val = getConstant(CheckMap, isl::dim::in, i); 1224 if (!Val.is_int()) 1225 continue; 1226 int OutPos = -1; 1227 llvm::APInt ValAPInt = APIntFromVal(Val); 1228 if (ValAPInt.isSignedIntN(32)) 1229 OutPos = ValAPInt.getSExtValue(); 1230 if ((OutPos < 0) || (OutPos >= static_cast<int>(OutDimNum)) || 1231 IndexSet.count(i)) 1232 return false; 1233 IndexSet.insert(i); 1234 Dimensions[OutPos] = i; 1235 if (DimensionSizes[i] <= 0) 1236 DimensionSizes[i] = getDimSize(SAI, OutPos); 1237 } 1238 1239 return isCorrectAccessMap(Domain, AccMap, Dimensions); 1240 } 1241 1242 /// Find the intersection of two sets. 1243 /// 1244 /// Find the intersection of the set @p A and the set @p B. 1245 /// 1246 /// @param A, B Sets to intersect. 1247 /// @return The set intersection. 
1248 static SmallDenseSet<int> intersect(const SmallDenseSet<int> &A, 1249 const SmallDenseSet<int> &B) { 1250 SmallDenseSet<int> Intersection = A; 1251 set_intersect(Intersection, B); 1252 return Intersection; 1253 } 1254 1255 /// Check whether the set is a superset. 1256 /// 1257 /// Check that the set @p A is a superset of @p B. 1258 /// 1259 /// @param A, B Sets to be checked. 1260 /// @return True if the set A is a superset of B. 1261 static bool isSuperset(const SmallDenseSet<int> &A, 1262 const SmallDenseSet<int> &B) { 1263 return intersect(A, B).size() == B.size(); 1264 } 1265 1266 /// Find the union of two sets. 1267 /// 1268 /// Find the union of the set @p A and the set @p B. 1269 /// 1270 /// @param A, B Sets to unite. 1271 /// @return The set union. 1272 static SmallDenseSet<int> unite(const SmallDenseSet<int> &A, 1273 const SmallDenseSet<int> &B) { 1274 SmallDenseSet<int> Union = A; 1275 set_union(Union, B); 1276 return Union; 1277 } 1278 1279 /// Determine the access that writes to the tensor, which contains 1280 /// the result of the tensor contraction. 1281 /// 1282 /// @param Domain The domain of the statement. 1283 /// @param Stmt The statement, which writes to memory. 1284 /// @param TCI The information about the tensor contraction. 1285 /// @param IandJIndexSet The set, which contains free indexes of tensors. 1286 /// @return The determined MemoryAccess, or nullptr if there is no necessary 1287 /// access within the SCoP. 1288 static MemoryAccess *getWriteAccess(isl::set Domain, ScopStmt *Stmt, 1289 TCInfoTy &TCI, 1290 SmallDenseSet<int> &IandJIndexSet) { 1291 TCI.WriteToC = nullptr; 1292 SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(*Stmt); 1293 for (MemoryAccess *MemA : reverse(Accesses)) { 1294 // A TC-like does not contain write scalar memory accesses 1295 if (!MemA->isLatestArrayKind()) 1296 return nullptr; 1297 // The last memory access should be a write memory access. 1298 if (!MemA->isWrite()) 1299 return nullptr; 1300 1301 isl::map AccMap = MemA->getLatestAccessRelation(); 1302 if (!isTCOperandAcc(Domain, AccMap, IandJIndexSet, TCI.DimensionSizes, 1303 TCI.CDimensions)) 1304 return nullptr; 1305 1306 return MemA; 1307 } 1308 return nullptr; 1309 } 1310 1311 /// Determine an access, which reads elements of an operand of the tensor 1312 /// contraction 1313 /// 1314 /// @param MemAccessPtr The access, which reads elements of the tensor. 1315 /// @param IndexSet The set, which contains indexes of the tensors. 1316 /// @param IandJIndexSet The set, which contains free indexes of tensors. 1317 /// @param Dimensions The permutation of the subset of the input dimensions. 1318 /// @param TCI The information about the tensor contraction. 1319 /// @return True if the memory access @p MemAccessPtr corresponds 1320 /// to the tensor contraction. 1321 static bool setReadAccess(MemoryAccess *MemAccessPtr, 1322 const SmallDenseSet<int> &IndexSet, 1323 const SmallDenseSet<int> &IandJIndexSet, 1324 ArrayRef<int> Dimensions, TCInfoTy &TCI) { 1325 if (!TCI.A) { 1326 // Probably IndexSet is a union of I and P sets. 1327 if (!isSuperset(IndexSet, TCI.P)) 1328 return false; 1329 1330 // Obtain the set I. 1331 TCI.I = set_difference(IndexSet, TCI.P); 1332 if (!isSuperset(IandJIndexSet, TCI.I)) 1333 return false; 1334 1335 // Obtain the set J. 1336 TCI.J = set_difference(IandJIndexSet, TCI.I); 1337 1338 // Set the first operand of the tensor contraction. 
1339     TCI.A = MemAccessPtr;
1340     llvm::replace(TCI.ADimensions, TCI.ADimensions.begin(),
1341                   TCI.ADimensions.end(), Dimensions.begin(), Dimensions.end());
1342     return true;
1343   }
1344 
1345   if (!TCI.B) {
1346     // IndexSet should be a union of J and P sets.
1347     if (unite(TCI.P, TCI.J) != IndexSet)
1348       return false;
1349 
1350     // Set the second operand of the tensor contraction.
1351     TCI.B = MemAccessPtr;
1352     llvm::replace(TCI.BDimensions, TCI.BDimensions.begin(),
1353                   TCI.BDimensions.end(), Dimensions.begin(), Dimensions.end());
1354     return true;
1355   }
1356 
1357   return false;
1358 }
1359 
1360 /// Check that all memory accesses of the statement, except for the last
1361 /// one, are read memory accesses, which read elements of operands of the tensor
1362 /// contraction and its result.
1363 ///
1364 /// @param Domain        The domain of the statement.
1365 /// @param Stmt          The statement, which writes to memory.
1366 /// @param TCI           The information about the tensor contraction.
1367 /// @param IandJIndexSet The set, which contains free indexes of tensors.
1368 /// @return True if all read memory accesses of the statement @p Stmt correspond
1369 ///         to the tensor contraction.
1370 static bool setReadAccesses(isl::set Domain, ScopStmt *Stmt, TCInfoTy &TCI,
1371                             SmallDenseSet<int> &IandJIndexSet) {
1372   TCI.A = nullptr;
1373   TCI.B = nullptr;
1374   TCI.ReadFromC = nullptr;
1375   SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(*Stmt);
1376   for (auto *MemA = Accesses.begin(); *MemA != TCI.WriteToC; MemA++) {
1377     MemoryAccess *MemAccessPtr = *MemA;
1378 
1379     // All memory accesses, except for the last one, should be read memory
1380     // accesses.
1381     if (MemAccessPtr->isWrite())
1382       return false;
1383 
1384     isl::map AccMap = MemAccessPtr->getLatestAccessRelation();
1385 
1386     if (!MemAccessPtr->isLatestArrayKind()) {
1387       // Check that the scalar read memory access is not partial.
1388       if (!Domain.is_subset(AccMap.domain()))
1389         return false;
1390       continue;
1392     }
1393 
1394     // There is only one memory access, which reads elements of the result of
1395     // the tensor contraction.
1396     if (AccMap.is_equal(TCI.WriteToC->getLatestAccessRelation())) {
1397       if (TCI.ReadFromC)
1398         return false;
1399       TCI.ReadFromC = MemAccessPtr;
1400       continue;
1401     }
1402 
1403     SmallVector<int> Dimensions;
1404     SmallDenseSet<int> IndexSet;
1405     if (!isTCOperandAcc(Domain, AccMap, IndexSet, TCI.DimensionSizes,
1406                         Dimensions))
1407       return false;
1408 
1409     if (!setReadAccess(MemAccessPtr, IndexSet, IandJIndexSet, Dimensions, TCI))
1410       return false;
1411   }
1412 
1413   // Check that there are read memory accesses, which read elements of operands
1414   // of the tensor contraction and its result.
1415   return TCI.ReadFromC && TCI.A && TCI.B;
1416 }
1417 
1418 /// Check accesses to operands of the tensor contraction.
1419 ///
1420 /// Check that accesses of the SCoP statement, which corresponds to
1421 /// the partial schedule @p PartialSchedule, represent accesses
1422 /// to the non-scalar operands of the tensor contraction.
1423 ///
1424 /// @param Domain          The domain of the SCoP statement.
1425 /// @param PartialSchedule The partial schedule of the SCoP statement.
1426 /// @param TCI             Parameters of the tensor contraction operands.
1427 /// @return True if the corresponding SCoP statement
1428 ///         represents tensor contraction and false,
1429 ///         otherwise.
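///
/// For illustration (a hypothetical input; M, N, and K are placeholder loop
/// bounds), consider the statement
///
///   for (i = 0; i < M; i++)
///     for (j = 0; j < N; j++)
///       for (k = 0; k < K; k++)
///         C[i][j] += A[i][k] * B[k][j];
///
/// Here the write access C[i][j] determines the free indexes {i, j}, the
/// reads A[i][k] and B[k][j] are matched as the operands, and the read of
/// C[i][j] is matched against the write access.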
1430 static bool containsOnlyTCAcc(isl::set Domain, isl::map PartialSchedule,
1431                               TCInfoTy &TCI) {
1432   isl::id InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in);
1433   ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
1434 
1435   // In region statements, the order in which memory accesses are executed is
1436   // not predictable at compile time.
1437   if ((Stmt->size() <= 1) || Stmt->isRegionStmt())
1438     return false;
1439 
1440   unsigned DimNum = unsignedFromIslSize(PartialSchedule.dim(isl::dim::in));
1441   TCI.DimensionSizes.resize(DimNum);
1442   SmallDenseSet<int> IandJIndexSet;
1443 
1444   TCI.WriteToC = getWriteAccess(Domain, Stmt, TCI, IandJIndexSet);
1445   if (!TCI.WriteToC)
1446     return false;
1447 
1448   if (intersect(IandJIndexSet, TCI.P).size() != 0)
1449     return false;
1450 
1451   if (!setReadAccesses(Domain, Stmt, TCI, IandJIndexSet))
1452     return false;
1453 
1454   return true;
1455 }
1456 
1457 /// Check that the dependency corresponds to the tensor contraction carried
1458 /// over loop dimension @p Dim.
1459 ///
1460 /// Check that the dependency has the form
1461 /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) ->
1462 /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP
1463 /// statement. For this purpose, we analyze the set @p DepDelta, which
1464 /// represents the differences between image elements and domain elements of
1465 /// the corresponding map.
1466 ///
1467 /// @param DepDelta    The set that contains the differences between image
1468 ///                    elements and corresponding domain elements of the map,
1469 ///                    which represents the dependency.
1470 /// @param Dim         The position of the index ki.
1471 /// @param BoundDeltas In the case of the indexes ki, the difference between
1472 ///                    image elements and corresponding domain elements
1473 ///                    corresponds to the difference between the lexicographic
1474 ///                    minimum and the lexicographic maximum of the
1475 ///                    corresponding dimension of the domain of the statement.
1476 /// @param IndexSet    Obtained indexes ki, which describe the dependency.
1477 /// @return True if the dependency corresponds to the tensor contraction
1478 ///         and false, otherwise.
1479 static bool isReductionCarriedOverDim(isl::set DepDelta, unsigned Dim,
1480                                       isl::pw_multi_aff BoundDeltas,
1481                                       const SmallDenseSet<int> &IndexSet) {
1482   isl::space Space = DepDelta.get_space();
1483   isl::set Superset = isl::set::universe(Space);
1484   for (unsigned i = 0; i < Dim; i += 1)
1485     Superset = Superset.fix_si(isl::dim::set, i, 0);
1486   Superset = Superset.fix_si(isl::dim::set, Dim, 1);
1487 
1488   // Check that the difference between the image element and the domain element
1489   // is equal to one in the case of the index ki. Image elements and
1490   // corresponding domain elements should be equal in the case of positions,
1491   // which are lower than the specified position.
1492   if (!DepDelta.is_subset(Superset))
1493     return false;
1494 
1495   // Compute a set, which is used to analyze how values of
1496   // the domain are related to the map that describes the dependency.
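  // As an illustration (using the contraction example given below in the
  // description of areDepsOverCompleteDomain): for
  // C[i][j] += A[i][l][w] * B[w][j][l] with l and w in [0, 63], the
  // dependency carried over l maps S[i, j, l, 63] to S[i, j, l + 1, 0], so
  // DepDelta = (0, 0, 1, -63) and BoundDeltas = (1023, 1023, 63, 63). Their
  // sum is zero in the position of w, which is what the loop below checks
  // for the positions contained in IndexSet.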
1497 isl_pw_multi_aff *DepDeltaPW = isl_pw_multi_aff_from_set(DepDelta.copy()); 1498 BoundDeltas = BoundDeltas.add(isl::manage(DepDeltaPW)); 1499 isl_set *ComplementRawSet = isl_set_from_pw_multi_aff(BoundDeltas.release()); 1500 isl::set Complement = isl::manage(ComplementRawSet); 1501 1502 for (unsigned i : rangeIslSize(Dim + 1, DepDelta.dim(isl::dim::set))) { 1503 if (!IndexSet.count(i)) { 1504 // Check the difference between the image element and the domain element 1505 // in the case of indexes, which do not describe the dependency. 1506 if (DepDelta.plain_get_val_if_fixed(isl::dim::set, i).is_zero()) 1507 continue; 1508 return false; 1509 } 1510 1511 // In the case of other indexes, which describe the dependency, 1512 // the difference between the image element and the domain element 1513 // should be equal to the difference between lexicographic minimum and 1514 // lexicographic maximum of the domain of the statement. 1515 if (!Complement.plain_get_val_if_fixed(isl::dim::set, i).is_zero()) 1516 return false; 1517 } 1518 1519 return true; 1520 } 1521 1522 /// Check whether dependencies are over the complete domain. 1523 /// 1524 /// In the case of the tensor contraction RAW, WAW, WAR dependencies 1525 /// have the form 1526 /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) -> 1527 /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP 1528 /// statement. Consequently, the domain of the dependencies 1529 /// can be described as 1530 /// Domain / Domain ∩ S(…, max(kn),…) ∩ S(…, max(k(i + 1)),…), 1531 /// where Domain is the domain of the statement S. 1532 /// 1533 /// For example, in the case of the following tensor contraction, 1534 /// corresponding domains will have the following form. 1535 /// 1536 /// An example of the tensor contraction: 1537 /// for (i = 0; i < 1024; i++) 1538 /// for (j = 0; j < 1024; j++) 1539 /// for (l = 0; l < 64; ++l) 1540 /// for (w = 0; w < 64; ++w) 1541 /// C[i][j] += A[i][l][w] * B[w][j][l]; 1542 /// 1543 /// The domain of the statement: 1544 /// { S[i0, i1, i2, i3] : i0 >= 0 and i0 <= 1023 and 1545 /// i1 >= 0 and i1 <= 1023 and 1546 /// i2 >= 0 and i2 <= 63 and 1547 /// i3 >= 0 and i3 <= 63 } 1548 /// 1549 /// The domain of the dependencies: 1550 /// { S[i0, i1, i2, i3] : (i0 >= 0 and i0 <= 1023 and 1551 /// i1 >= 0 and i1 <= 1023 and 1552 /// i2 >= 0 and i2 <= 63 and 1553 /// i3 >= 0 and i3 <= 62) or 1554 /// (i3 = 63 and i0 >= 0 and i0 <= 1023 and 1555 /// i1 >= 0 and i1 <= 1023 and 1556 /// i2 >= 0 and i2 <= 62) } 1557 /// 1558 /// @param Domain The domain of the statement. 1559 /// @param DepsForStmt RAW and RED dependencies for the statement. 1560 /// @param UpperBound The lexicographic maximum of the elements in 1561 /// the @p Domain. 1562 /// @param IndexSet Obtained indexes ki, which describe the dependencies. 1563 /// @return True if dependencies are over the complete domain 1564 /// and false, otherwise. 
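///
/// For the example above, the obtained indexes are {2, 3}, so DomainRed
/// computed below is the subset of the domain with i2 and i3 fixed to their
/// maxima (63), and the check succeeds if and only if the domain of the
/// dependencies is equal to Domain minus DomainRed.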
1565 static bool areDepsOverCompleteDomain(isl::set Domain, isl::map DepsForStmt,
1566                                       isl::pw_multi_aff UpperBound,
1567                                       SmallDenseSet<int> &IndexSet) {
1568   isl_set *UpperBoundRawSet = isl_set_from_pw_multi_aff(UpperBound.copy());
1569   isl::set UpperBoundSet = isl::manage(UpperBoundRawSet);
1570 
1571   isl::set DomainRed = isl::manage(Domain.copy());
1572   for (const auto It : IndexSet) {
1573     isl::val FixedVal = UpperBoundSet.plain_get_val_if_fixed(isl::dim::set, It);
1574     if (FixedVal.is_nan())
1575       return false;
1576     DomainRed = isl::manage(
1577         isl_set_fix_val(DomainRed.copy(), isl_dim_set, It, FixedVal.release()));
1578   }
1579   return DepsForStmt.domain().intersect(Domain).is_equal(
1580       Domain.subtract(DomainRed));
1581 }
1582 
1583 /// Check that dependencies correspond to the tensor contraction.
1584 ///
1585 /// Check that there are only true dependencies of the form
1586 /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) ->
1587 /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP
1588 /// statement represented by @p Schedule. Such dependencies are produced by
1589 /// the tensor contraction. Obtained indexes ki are stored into @p IndexSet.
1590 ///
1591 /// The form of anti and output dependencies is specified implicitly by
1592 /// the form of the SCoP statement, which is checked by subsequent analysis.
1593 ///
1594 /// @param Schedule The schedule of the SCoP statement.
1595 /// @param D        The SCoP dependencies.
1596 /// @param Domain   The domain of the statement.
1597 /// @param IndexSet Obtained indexes ki, which describe the dependencies.
1598 /// @return True if the dependencies correspond to the tensor contraction
1599 ///         and false, otherwise.
1600 static bool containsOnlyTcDeps(isl::map Schedule, const Dependences *D,
1601                                SmallDenseSet<int> &IndexSet, isl::set Domain) {
1602   IslMaxOperationsGuard MaxOpGuard(Schedule.ctx().get(), OptComputeOut);
1603 
1604   isl::union_map Dep =
1605       D->getDependences(Dependences::TYPE_RAW | Dependences::TYPE_RED);
1606 
1607   isl::space DomainSpace = Schedule.get_space().domain();
1608   isl::space Space = DomainSpace.map_from_domain_and_range(DomainSpace);
1609   isl::map DepsForStmt = Dep.extract_map(Space);
1610   isl::set DepDeltas = DepsForStmt.deltas();
1611   isl::size DeltasDimNum = DepDeltas.dim(isl::dim::set);
1612   isl::pw_multi_aff LowerBound = Domain.lexmin_pw_multi_aff();
1613   isl::pw_multi_aff UpperBound = Domain.lexmax_pw_multi_aff();
1614   isl::pw_multi_aff BoundDeltas = UpperBound.sub(LowerBound);
1615 
1616   for (int i : reverse(rangeIslSize(0, DeltasDimNum))) {
1617     // In the case of the tensor contraction, the difference between image
1618     // elements and domain elements lies on a hyperplane where a dimension
1619     // has the fixed value one.
1620     isl::set Intersection = DepDeltas.fix_si(isl::dim::set, i, 1);
1621     if (Intersection.is_empty())
1622       continue;
1623 
1624     if (!isReductionCarriedOverDim(Intersection, i, BoundDeltas, IndexSet))
1625       return false;
1626 
1627     IndexSet.insert(i);
1628     DepDeltas = DepDeltas.subtract(Intersection);
1629   }
1630 
1631   // In the case of the tensor contraction, all dependencies should have
1632   // the previously described form.
1633   if ((unsignedFromIslSize(DeltasDimNum) == 0) || !DepDeltas.is_empty())
1634     return false;
1635 
1636   return areDepsOverCompleteDomain(Domain, DepsForStmt, UpperBound, IndexSet);
1637 }
1638 
1639 /// Check if the SCoP statement could probably be optimized with analytical
1640 /// modeling.
1641 /// 1642 /// containsTCInfoTy tries to determine whether the following conditions 1643 /// are true: 1644 /// 1645 /// 1. The last memory access modeling an array, MA1, represents writing to 1646 /// memory and has the form S(..., I, ..., J, ...) -> M(shuffle(I, J)), 1647 /// where S is the SCoP statement under consideration and shuffle(I, J) 1648 /// is a permutation of indexes of sets I and J. 1649 /// 2. There are only true dependencies of the form 1650 /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) -> 1651 /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP 1652 /// statement represented by @p Schedule and ki are indexes of the set P. 1653 /// 3. SCoP contains an arbitrary number of reads from constants and only three 1654 /// access relations, MA2, MA3, and MA4 that represent reading from memory 1655 /// and have the form 1656 /// S(..., I, ..., P, ...) -> M(shuffle(I, P)), 1657 /// S(..., P, ..., J, ...) -> M(shuffle(J, P)), 1658 /// S(...) -> M(shuffle(I, J)), respectively. 1659 /// 1660 /// @param PartialSchedule The PartialSchedule that contains a SCoP statement 1661 /// to check. 1662 /// @param D The SCoP dependencies. 1663 /// @param TCI Parameters of the tensor contraction operands. 1664 /// @param Domain The domain of the statement. 1665 /// @return True if dependencies and memory accesses correspond to the tensor 1666 /// contraction and false, otherwise. 1667 static bool containsTCInfoTy(isl::map PartialSchedule, const Dependences *D, 1668 TCInfoTy &TCI, isl::set Domain) { 1669 if (!containsOnlyTcDeps(PartialSchedule, D, TCI.P, Domain)) 1670 return false; 1671 1672 // TODO: handle cases of scalar multiplication if needed. 1673 if (TCI.P.size() == 0) 1674 return false; 1675 1676 if (!containsOnlyTCAcc(Domain, PartialSchedule, TCI)) 1677 return false; 1678 1679 // TODO: handle cases of GEMV if needed. 1680 if ((TCI.I.size() == 0) || (TCI.J.size() == 0)) 1681 return false; 1682 1683 return true; 1684 } 1685 1686 /// Check if this node contains a partial schedule that could 1687 /// probably be optimized with analytical modeling. 1688 /// 1689 /// isTCPattern is used to determine whether the SCoP represents a TC-like 1690 /// kernel [1], which is a perfectly nested set of loops, with a data usage 1691 /// pattern that is similar to that produced by the tensor contraction. 1692 /// 1693 /// A TC-like kernel can be defined as follows: 1694 /// 1695 /// 1. It satisfies the requirements of the polyhedral model. 1696 /// 2. Without loss of generality, it contains three nonempty bundles of 1697 /// one-dimensional for-loops with induction variables that are grouped into 1698 /// bundles I = i0...i(r-1), J = j0..j(s-1), and P = p0...p(t-1), and they 1699 /// are incremented by one. 1700 /// 3. The innermost loop body can be represented as a statement of the form 1701 /// C(shuffle(I, J)) = E(A(shuffle(I, P)), B(shuffle(P, J)), 1702 /// C(shuffle(I, J))), where A(shuffle(I, P)), B(shuffle(P, J)), 1703 /// C(shuffle(I, J)) are accesses to tensors A, B, C, respectively, 1704 /// shuffle(I, J), shuffle(I, P), and shuffle(P, J) are permutations of the 1705 /// enclosed indices, and E is an expression that contains reads from 1706 /// the tensors A, B, C, and an arbitrary number of reads from constants 1707 /// with respect to bundles I, J, and P. 1708 /// 1709 /// TC can be considered as a particular case of a TC-like kernel. 1710 /// 1711 /// The order of loops with indexes from P should be preserved. 
Otherwise,
1712 /// isTCPattern should check if a commutative operation is used.
1713 ///
1714 /// isTCPattern performs the following steps to check whether the SCoP
1715 /// corresponds to a definition of a TC-like kernel:
1716 ///
1717 /// 1. Checks that the node is the innermost band node.
1718 /// 2. Checks that the partial schedule contains only one statement.
1719 /// 3. Checks that all ancestors of the node contain all band nodes for
1720 ///    the statement and that only mark nodes interleave such band nodes. This
1721 ///    corresponds to a straightforward implementation of TC.
1722 /// 4. Analyzes the dependencies to determine contraction dimensions.
1723 /// 5. Checks that the last memory access modeling an array represents writing
1724 ///    to the result of the TC-like kernel.
1725 /// 6. Checks that the SCoP contains only three access relations that represent
1726 ///    reading of the operands of the TC-like kernel and an arbitrary number of
1727 ///    reads from constants.
1728 ///
1729 /// [1] - Gareev R., Grosser T., Kruse M. High-Performance Generalized Tensor
1730 ///       Operations: A Compiler-Oriented Approach // ACM Transactions on
1731 ///       Architecture and Code Optimization (TACO). 2018.
1732 ///       Vol. 15, no. 3. P. 34:1–34:27. DOI: 10.1145/3235029.
1733 ///
1734 /// If this is the case, we could logically represent tensors as matrices and
1735 /// apply algorithms, which are used to get close-to-peak performance of
1736 /// matrix multiplications in manually tuned BLAS libraries (e.g., BLIS).
1737 ///
1738 /// @param Node The node to check.
1739 /// @param D    The SCoP dependencies.
1740 /// @param TCI  Parameters of the tensor contraction operands.
1741 static bool isTCPattern(isl::schedule_node Node, const Dependences *D,
1742                         TCInfoTy &TCI) {
1743   Node = Node.child(0);
1744   isl::union_map PartialSchedule = Node.get_prefix_schedule_union_map();
1745   isl::union_set Domain = Node.domain();
1746   Node = Node.parent();
1747 
1748   // The partial schedule should contain only one statement.
1749   // TODO: This constraint should not be intrinsic to the algorithm.
1750   if (isl_union_set_n_set(Domain.get()) != 1)
1751     return false;
1752 
1753   isl_schedule_node_type NodeType = isl_schedule_node_get_type(Node.get());
1754 
1755   // Check that all ancestors of the node contain all band nodes for
1756   // the statement, which represents the TC-like kernel, and that only mark
1757   // nodes interleave such band nodes. This corresponds to a straightforward
1758   // implementation of TC with/without DeLICM applied.
1759   //
1760   // For example, this covers the matrix multiplication pattern after a full
1761   // run of -polly-optree and -polly-delicm, where the write access is not
1762   // through the original memory access, but through a PHI node that was
1763   // delicmed. Subsequently, such band nodes will be replaced by a single band
1764   // node.
1765 // 1766 // The corresponding schedule can be the following, where Stmt_for_body8 1767 // contains the matrix multiplication: 1768 // 1769 // domain: "{ Stmt_for_body8[i0, i1, i2] : 0 <= i0 <= 1599 and 1770 // 0 <= i1 <= 1799 and 1771 // 0 <= i2 <= 2199; 1772 // Stmt_for_body3[i0, i1] : 0 <= i0 <= 1599 and 1773 // 0 <= i1 <= 1799; 1774 // Stmt_for_body3_last[i0, i1] : 0 <= i0 <= 1599 and 1775 // 0 <= i1 <= 1799 }" 1776 // child: 1777 // sequence: 1778 // - filter: "{ Stmt_for_body3[i0, i1] }" 1779 // child: 1780 // schedule: "[{ Stmt_for_body3[i0, i1] -> [(i0)] }, 1781 // { Stmt_for_body3[i0, i1] -> [(i1)] }]" 1782 // permutable: 1 1783 // coincident: [ 1, 1 ] 1784 // - filter: "{ Stmt_for_body3_last[i0, i1] }" 1785 // child: 1786 // schedule: "[{ Stmt_for_body3_last[i0, i1] -> [(i0)] }, 1787 // { Stmt_for_body3_last[i0, i1] -> [(i1)] }]" 1788 // permutable: 1 1789 // coincident: [ 1, 1 ] 1790 // - filter: "{ Stmt_for_body8[i0, i1, i2] }" 1791 // child: 1792 // schedule: "[{ Stmt_for_body8[i0, i1, i2] -> [(i0)] }, 1793 // { Stmt_for_body8[i0, i1, i2] -> [(i1)] }, 1794 // { Stmt_for_body8[i0, i1, i2] -> [(i2)] }]" 1795 // permutable: 1 1796 // coincident: [ 1, 1, 0 ] 1797 // 1798 while (NodeType != isl_schedule_node_domain) { 1799 if (NodeType == isl_schedule_node_filter) { 1800 if (!Node.parent().isa<isl::schedule_node_sequence>() || 1801 !Node.parent().parent().isa<isl::schedule_node_domain>()) 1802 return false; 1803 break; 1804 } 1805 1806 if ((NodeType != isl_schedule_node_band) && 1807 (NodeType != isl_schedule_node_mark)) 1808 return false; 1809 1810 Node = Node.parent(); 1811 NodeType = isl_schedule_node_get_type(Node.get()); 1812 } 1813 1814 isl::map PartialScheduleMap = isl::map::from_union_map(PartialSchedule); 1815 if (containsTCInfoTy(PartialScheduleMap, D, TCI, isl::set(Domain))) 1816 return true; 1817 1818 return false; 1819 } 1820 1821 } // namespace 1822 1823 isl::schedule_node 1824 polly::tryOptimizeMatMulPattern(isl::schedule_node Node, 1825 const llvm::TargetTransformInfo *TTI, 1826 const Dependences *D) { 1827 TCInfoTy TCI; 1828 if (PMBasedTCOpts && isTCPattern(Node, D, TCI)) 1829 POLLY_DEBUG(dbgs() << "The tensor contraction pattern was detected\n"); 1830 MatMulInfoTy MMI; 1831 if (PMBasedMMMOpts && isMatrMultPattern(Node, D, MMI)) { 1832 POLLY_DEBUG(dbgs() << "The matrix multiplication pattern was detected\n"); 1833 return optimizeMatMulPattern(Node, TTI, MMI); 1834 } 1835 return {}; 1836 } 1837
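// Example usage (an illustrative sketch, not code from this file): a
// schedule-tree optimizer is expected to call tryOptimizeMatMulPattern on a
// band node and keep the result only when the pattern was matched:
//
//   isl::schedule_node Optimized = tryOptimizeMatMulPattern(Node, TTI, D);
//   if (!Optimized.is_null())
//     return Optimized; // Use the pattern-based schedule.
//   // Otherwise fall back to the generic schedule optimization.
//
// Here Node, TTI, and D stand for the band node under inspection, the
// TargetTransformInfo of the surrounding function, and the SCoP dependences,
// matching the parameters of tryOptimizeMatMulPattern above.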