1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lower matrix intrinsics to vector operations.
10 //
11 // TODO:
12 //  * Improve fusion:
13 //    * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14 //      transposed.
15 //  * Improve cost-modeling, e.g. choose different numbers of rows and
16 //    columns for tiles, consider cost of copies on alias.
17 //
18 //===----------------------------------------------------------------------===//
19
20 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
21 #include "llvm/ADT/PostOrderIterator.h"
22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Analysis/AliasAnalysis.h"
26 #include "llvm/Analysis/DomTreeUpdater.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
29 #include "llvm/Analysis/TargetTransformInfo.h"
30 #include "llvm/Analysis/ValueTracking.h"
31 #include "llvm/Analysis/VectorUtils.h"
32 #include "llvm/IR/CFG.h"
33 #include "llvm/IR/DataLayout.h"
34 #include "llvm/IR/DebugInfoMetadata.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/IRBuilder.h"
37 #include "llvm/IR/Instructions.h"
38 #include "llvm/IR/IntrinsicInst.h"
39 #include "llvm/IR/MatrixBuilder.h"
40 #include "llvm/IR/PatternMatch.h"
41 #include "llvm/Support/Alignment.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
45 #include "llvm/Transforms/Utils/LoopUtils.h"
46 #include "llvm/Transforms/Utils/MatrixUtils.h"
47
48 #include <cmath>
49
50 using namespace llvm;
51 using namespace PatternMatch;
52
53 #define DEBUG_TYPE "lower-matrix-intrinsics"
54
55 static cl::opt<bool>
56     FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
57                cl::desc("Enable/disable fusing matrix instructions."));
58 // TODO: Allow and use non-square tiles.
59 static cl::opt<unsigned> TileSize(
60     "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
61     cl::desc(
62         "Tile size for matrix instruction fusion using square-shaped tiles."));
63 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
64                                   cl::Hidden,
65                                   cl::desc("Generate loop nest for tiling."));
66 static cl::opt<bool> ForceFusion(
67     "force-fuse-matrix", cl::init(false), cl::Hidden,
68     cl::desc("Force matrix instruction fusion even if not profitable."));
69 static cl::opt<bool> AllowContractEnabled(
70     "matrix-allow-contract", cl::init(false), cl::Hidden,
71     cl::desc("Allow the use of FMAs if available and profitable. This may "
72              "produce different results, due to less rounding error."));
73
74 static cl::opt<bool>
75     VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
76                     cl::desc("Enable/disable matrix shape verification."),
77                     cl::init(false));
78
79 enum class MatrixLayoutTy { ColumnMajor, RowMajor };
80
81 static cl::opt<MatrixLayoutTy> MatrixLayout(
82     "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
83     cl::desc("Sets the default matrix layout"),
84     cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
85                           "Use column-major layout"),
86                clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
87                           "Use row-major layout")));
88
89 static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
90                                             cl::init(false));
91
92 /// Helper function to return \p Scope if it is a subprogram, or the
93 /// subprogram attached to \p Scope if it is a local scope.
94 static DISubprogram *getSubprogram(DIScope *Scope) {
95   if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
96     return Subprogram;
97   return cast<DILocalScope>(Scope)->getSubprogram();
98 }
99
100 /// Return true if V is a splat of a value (which is used when multiplying a
101 /// matrix with a scalar).
102 static bool isSplat(Value *V) {
103   if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
104     return SV->isZeroEltSplat();
105   return false;
106 }
107
108 /// Match any mul operation (fp or integer).
109 template <typename LTy, typename RTy>
110 auto m_AnyMul(const LTy &L, const RTy &R) {
111   return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
112 }
113
114 /// Match any add operation (fp or integer).
115 template <typename LTy, typename RTy>
116 auto m_AnyAdd(const LTy &L, const RTy &R) {
117   return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
118 }
119
120 namespace {
121
122 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
123 // the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
124 // assuming \p Stride elements between the starts of two consecutive vectors.
125 // \p Stride must be >= \p NumElements.
126 // For column-major matrices, the function computes the address of a column
127 // vector and \p NumElements must be set to the number of elements in a column
128 // (= number of rows of the matrix). For row-major matrices, the function
129 // computes the address of a row vector and \p NumElements must be set to the
130 // number of elements in a row (= number of columns of the matrix).
131 //
132 // Consider a 4x4 matrix in column-major layout like below:
133 //
134 //      0       1      2      3
135 // 0   v_0_0  v_0_1  v_0_2  v_0_3
136 // 1   v_1_0  v_1_1  v_1_2  v_1_3
137 // 2   v_2_0  v_2_1  v_2_2  v_2_3
138 // 3   v_3_0  v_3_1  v_3_2  v_3_3
139 //
140 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
141 // we need a pointer to the first element of the submatrix as base pointer.
142 // Then we can use computeVectorAddr to compute the addresses for the columns
143 // of the sub-matrix.
144 //
145 // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
146 //           -> just returns Base
147 // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
148 //           -> returns Base + (1 * 4)
149 // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
150 //           -> returns Base + (2 * 4)
151 //
152 // The graphic below illustrates the number of elements in a column (marked
153 // with |) and the number of skipped elements (marked with {).
154 // 155 // v_0_0 v_0_1 {v_0_2 {v_0_3 156 // Base Col 1 Col 2 157 // | | | 158 // v_1_0 |v_1_1 |v_1_2 |v_1_3 159 // v_2_0 |v_2_1 |v_2_2 |v_2_3 160 // v_3_0 {v_3_1 {v_3_2 v_3_3 161 // 162 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, 163 unsigned NumElements, Type *EltType, 164 IRBuilder<> &Builder) { 165 166 assert((!isa<ConstantInt>(Stride) || 167 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && 168 "Stride must be >= the number of elements in the result vector."); 169 170 // Compute the start of the vector with index VecIdx as VecIdx * Stride. 171 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); 172 173 // Get pointer to the start of the selected vector. Skip GEP creation, 174 // if we select vector 0. 175 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) 176 VecStart = BasePtr; 177 else 178 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); 179 180 return VecStart; 181 } 182 183 namespace { 184 struct ShapeInfo { 185 unsigned NumRows; 186 unsigned NumColumns; 187 188 bool IsColumnMajor; 189 190 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 191 : NumRows(NumRows), NumColumns(NumColumns), 192 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 193 194 ShapeInfo(Value *NumRows, Value *NumColumns) 195 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 196 cast<ConstantInt>(NumColumns)->getZExtValue()) {} 197 198 bool operator==(const ShapeInfo &other) { 199 return NumRows == other.NumRows && NumColumns == other.NumColumns; 200 } 201 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 202 203 /// Returns true if shape-information is defined, meaning both dimensions 204 /// are != 0. 205 operator bool() const { 206 assert(NumRows == 0 || NumColumns != 0); 207 return NumRows != 0; 208 } 209 210 unsigned getStride() const { 211 if (IsColumnMajor) 212 return NumRows; 213 return NumColumns; 214 } 215 216 unsigned getNumVectors() const { 217 if (IsColumnMajor) 218 return NumColumns; 219 return NumRows; 220 } 221 222 /// Returns the transposed shape. 223 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); } 224 }; 225 } // namespace 226 227 static bool isUniformShape(Value *V) { 228 Instruction *I = dyn_cast<Instruction>(V); 229 if (!I) 230 return true; 231 232 switch (I->getOpcode()) { 233 case Instruction::FAdd: 234 case Instruction::FSub: 235 case Instruction::FMul: // Scalar multiply. 236 case Instruction::FNeg: 237 case Instruction::Add: 238 case Instruction::Mul: 239 case Instruction::Sub: 240 return true; 241 default: 242 return false; 243 } 244 } 245 246 /// Return the ShapeInfo for the result of \p I, it it can be determined. 247 static std::optional<ShapeInfo> 248 computeShapeInfoForInst(Instruction *I, 249 const DenseMap<Value *, ShapeInfo> &ShapeMap) { 250 Value *M; 251 Value *N; 252 Value *K; 253 if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>( 254 m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K)))) 255 return ShapeInfo(M, K); 256 if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M), 257 m_Value(N)))) { 258 // Flip dimensions. 
259 return ShapeInfo(N, M); 260 } 261 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>( 262 m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M), 263 m_Value(N)))) 264 return ShapeInfo(N, M); 265 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>( 266 m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N)))) 267 return ShapeInfo(M, N); 268 Value *MatrixA; 269 if (match(I, m_Store(m_Value(MatrixA), m_Value()))) { 270 auto OpShape = ShapeMap.find(MatrixA); 271 if (OpShape != ShapeMap.end()) 272 return OpShape->second; 273 } 274 275 if (isUniformShape(I)) { 276 // Find the first operand that has a known shape and use that. 277 for (auto &Op : I->operands()) { 278 auto OpShape = ShapeMap.find(Op.get()); 279 if (OpShape != ShapeMap.end()) 280 return OpShape->second; 281 } 282 } 283 return std::nullopt; 284 } 285 286 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 287 /// 288 /// Currently, the lowering for each matrix intrinsic is done as follows: 289 /// 1. Propagate the shape information from intrinsics to connected 290 /// instructions. 291 /// 2. Lower instructions with shape information (assuming column-major layout). 292 /// The lowering works similarly using row-major layout. 293 /// 2.1. Get column vectors for each argument. If we already lowered the 294 /// definition of an argument, use the produced column vectors directly. 295 /// If not, split the operand vector containing an embedded matrix into 296 /// a set of column vectors, 297 /// 2.2. Lower the instruction in terms of column major operations, which 298 /// yields a set of column vectors containing result matrix. Note that we 299 /// lower all instructions that have shape information. Besides the 300 /// intrinsics, this includes stores for example. 301 /// 2.3. Update uses of the lowered instruction. If we have shape information 302 /// for a user, there is nothing to do, as we will look up the result 303 /// column matrix when lowering the user. For other uses, we embed the 304 /// result matrix in a flat vector and update the use. 305 /// 2.4. Cache the result column matrix for the instruction we lowered 306 /// 3. After we lowered all instructions in a function, remove the now 307 /// obsolete instructions. 308 /// 309 class LowerMatrixIntrinsics { 310 Function &Func; 311 const DataLayout &DL; 312 const TargetTransformInfo &TTI; 313 FunctionAnalysisManager *AM; 314 AliasAnalysis *AA = nullptr; 315 DominatorTree *DT = nullptr; 316 LoopInfo *LI = nullptr; 317 OptimizationRemarkEmitter *ORE = nullptr; 318 319 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 320 struct OpInfoTy { 321 /// Number of stores emitted to generate this matrix. 322 unsigned NumStores = 0; 323 /// Number of loads emitted to generate this matrix. 324 unsigned NumLoads = 0; 325 /// Number of compute operations emitted to generate this matrix. 326 unsigned NumComputeOps = 0; 327 /// Most of the time transposes can be fused with matrix multiplies or can 328 /// be folded away via algebraic simplifications. This is the number of 329 /// transposes that we failed to make "free" via such optimizations. 
330 unsigned NumExposedTransposes = 0; 331 332 OpInfoTy &operator+=(const OpInfoTy &RHS) { 333 NumStores += RHS.NumStores; 334 NumLoads += RHS.NumLoads; 335 NumComputeOps += RHS.NumComputeOps; 336 NumExposedTransposes += RHS.NumExposedTransposes; 337 return *this; 338 } 339 }; 340 341 /// Wrapper class representing a matrix as a set of vectors, either in row or 342 /// column major layout. All vectors must have the same vector type. 343 class MatrixTy { 344 SmallVector<Value *, 16> Vectors; 345 346 OpInfoTy OpInfo; 347 348 bool IsColumnMajor = true; 349 350 public: 351 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 352 MatrixTy(ArrayRef<Value *> Vectors) 353 : Vectors(Vectors), 354 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 355 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 356 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 357 358 unsigned D = isColumnMajor() ? NumColumns : NumRows; 359 for (unsigned J = 0; J < D; ++J) 360 addVector(PoisonValue::get(FixedVectorType::get( 361 EltTy, isColumnMajor() ? NumRows : NumColumns))); 362 } 363 364 Value *getVector(unsigned i) const { return Vectors[i]; } 365 Value *getColumn(unsigned i) const { 366 assert(isColumnMajor() && "only supported for column-major matrixes"); 367 return Vectors[i]; 368 } 369 Value *getRow(unsigned i) const { 370 assert(!isColumnMajor() && "only supported for row-major matrixes"); 371 return Vectors[i]; 372 } 373 374 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 375 376 Type *getElementType() const { return getVectorTy()->getElementType(); } 377 378 unsigned getNumVectors() const { 379 if (isColumnMajor()) 380 return getNumColumns(); 381 return getNumRows(); 382 } 383 384 unsigned getNumColumns() const { 385 if (isColumnMajor()) 386 return Vectors.size(); 387 else { 388 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 389 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 390 } 391 } 392 unsigned getNumRows() const { 393 if (isColumnMajor()) { 394 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 395 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 396 } else 397 return Vectors.size(); 398 } 399 400 void addVector(Value *V) { Vectors.push_back(V); } 401 VectorType *getColumnTy() { 402 assert(isColumnMajor() && "only supported for column-major matrixes"); 403 return getVectorTy(); 404 } 405 406 VectorType *getVectorTy() const { 407 return cast<VectorType>(Vectors[0]->getType()); 408 } 409 410 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 411 assert(isColumnMajor() && 412 "columns() only supported for column-major matrixes"); 413 return make_range(Vectors.begin(), Vectors.end()); 414 } 415 416 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 417 return make_range(Vectors.begin(), Vectors.end()); 418 } 419 420 /// Embed the vectors of the matrix into a flat vector by concatenating 421 /// them. 422 Value *embedInVector(IRBuilder<> &Builder) const { 423 return Vectors.size() == 1 ? 
Vectors[0] 424 : concatenateVectors(Builder, Vectors); 425 } 426 427 MatrixTy &addNumLoads(unsigned N) { 428 OpInfo.NumLoads += N; 429 return *this; 430 } 431 432 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 433 434 MatrixTy &addNumStores(unsigned N) { 435 OpInfo.NumStores += N; 436 return *this; 437 } 438 439 MatrixTy &addNumExposedTransposes(unsigned N) { 440 OpInfo.NumExposedTransposes += N; 441 return *this; 442 } 443 444 MatrixTy &addNumComputeOps(unsigned N) { 445 OpInfo.NumComputeOps += N; 446 return *this; 447 } 448 449 unsigned getNumStores() const { return OpInfo.NumStores; } 450 unsigned getNumLoads() const { return OpInfo.NumLoads; } 451 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 452 453 const OpInfoTy &getOpInfo() const { return OpInfo; } 454 455 bool isColumnMajor() const { return IsColumnMajor; } 456 457 unsigned getStride() const { 458 if (isColumnMajor()) 459 return getNumRows(); 460 return getNumColumns(); 461 } 462 463 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 464 /// matrix is column-major, the result vector is extracted from a column 465 /// vector, otherwise from a row vector. 466 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 467 IRBuilder<> &Builder) const { 468 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 469 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >= 470 NumElts && 471 "Extracted vector will contain poison values"); 472 return Builder.CreateShuffleVector( 473 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), 474 "block"); 475 } 476 }; 477 478 /// Maps instructions to their shape information. The shape information 479 /// describes the shape to be used while lowering. This matches the shape of 480 /// the result value of the instruction, with the only exceptions being store 481 /// instructions and the matrix_column_major_store intrinsics. For those, the 482 /// shape information indicates that those instructions should be lowered 483 /// using shape information as well. Note that extra care is needed when 484 /// erasing or RAUW'ing a value that is present in ShapeMap. If the 485 /// replacement is also a matrix operation, use 486 /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to 487 /// ShapeMap. We don't use ValueMap, as there are also cases where we do not 488 /// want to add shape information for a replacement instruction. When directly 489 /// erasing a value with an entry in ShapeMap, use 490 /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated 491 /// accordingly. 492 DenseMap<Value *, ShapeInfo> ShapeMap; 493 494 /// List of instructions to remove. While lowering, we are not replacing all 495 /// users of a lowered instruction, if shape information is available and 496 /// those need to be removed after we finished lowering. 497 SmallVector<Instruction *, 16> ToRemove; 498 499 /// Map from instructions to their produced column matrix. 
500 MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 501 502 private: 503 static FastMathFlags getFastMathFlags(Instruction *Inst) { 504 FastMathFlags FMF; 505 506 if (isa<FPMathOperator>(*Inst)) 507 FMF = Inst->getFastMathFlags(); 508 509 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract()); 510 511 return FMF; 512 } 513 514 public: 515 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 516 FunctionAnalysisManager *AM) 517 : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {} 518 519 unsigned getNumOps(Type *VT) { 520 assert(isa<VectorType>(VT) && "Expected vector type"); 521 return getNumOps(VT->getScalarType(), 522 cast<FixedVectorType>(VT)->getNumElements()); 523 } 524 525 /// Is this the minimal version executed in the backend pipelines. 526 bool isMinimal() const { 527 return !DT; 528 } 529 530 /// Return the estimated number of vector ops required for an operation on 531 /// \p VT * N. 532 unsigned getNumOps(Type *ST, unsigned N) { 533 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() / 534 double(TTI.getRegisterBitWidth( 535 TargetTransformInfo::RGK_FixedWidthVector) 536 .getFixedValue())); 537 } 538 539 /// Return the set of vectors that a matrix value is lowered to. 540 /// 541 /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 542 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 543 /// into vectors. 544 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 545 IRBuilder<> &Builder) { 546 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 547 assert(VType && "MatrixVal must be a vector type"); 548 assert(cast<FixedVectorType>(VType)->getNumElements() == 549 SI.NumRows * SI.NumColumns && 550 "The vector size must match the number of matrix elements"); 551 552 // Check if we lowered MatrixVal using shape information. In that case, 553 // return the existing matrix, if it matches the requested shape 554 // information. If there is a mis-match, embed the result in a flat 555 // vector and split it later. 556 auto Found = Inst2ColumnMatrix.find(MatrixVal); 557 if (Found != Inst2ColumnMatrix.end()) { 558 MatrixTy &M = Found->second; 559 // Return the found matrix, if its shape matches the requested shape 560 // information 561 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 562 return M; 563 564 MatrixVal = M.embedInVector(Builder); 565 } 566 567 // Otherwise split MatrixVal. 568 SmallVector<Value *, 16> SplitVecs; 569 for (unsigned MaskStart = 0; 570 MaskStart < cast<FixedVectorType>(VType)->getNumElements(); 571 MaskStart += SI.getStride()) { 572 Value *V = Builder.CreateShuffleVector( 573 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), 574 "split"); 575 SplitVecs.push_back(V); 576 } 577 578 return {SplitVecs}; 579 } 580 581 /// If \p V already has a known shape return false. Otherwise set the shape 582 /// for instructions that support it. 
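  /// Returns true only if a new (\p V, \p Shape) entry was added to ShapeMap.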
583 bool setShapeInfo(Value *V, ShapeInfo Shape) { 584 assert(Shape && "Shape not set"); 585 if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 586 return false; 587 588 auto SIter = ShapeMap.find(V); 589 if (SIter != ShapeMap.end()) { 590 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows || 591 SIter->second.NumColumns != Shape.NumColumns)) { 592 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x" 593 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x" 594 << Shape.NumColumns << ") for " << *V << "\n"; 595 report_fatal_error( 596 "Matrix shape verification failed, compilation aborted!"); 597 } 598 599 LLVM_DEBUG(dbgs() << " not overriding existing shape: " 600 << SIter->second.NumRows << " " 601 << SIter->second.NumColumns << " for " << *V << "\n"); 602 return false; 603 } 604 605 ShapeMap.insert({V, Shape}); 606 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 607 << " for " << *V << "\n"); 608 return true; 609 } 610 611 /// Returns true if shape information can be used for \p V. The supported 612 /// instructions must match the instructions that can be lowered by this pass. 613 bool supportsShapeInfo(Value *V) { 614 Instruction *Inst = dyn_cast<Instruction>(V); 615 if (!Inst) 616 return false; 617 618 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 619 if (II) 620 switch (II->getIntrinsicID()) { 621 case Intrinsic::matrix_multiply: 622 case Intrinsic::matrix_transpose: 623 case Intrinsic::matrix_column_major_load: 624 case Intrinsic::matrix_column_major_store: 625 return true; 626 default: 627 return false; 628 } 629 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 630 } 631 632 /// Propagate the shape information of instructions to their users. 633 /// The work list contains instructions for which we can compute the shape, 634 /// either based on the information provided by matrix intrinsics or known 635 /// shapes of operands. 636 SmallVector<Instruction *, 32> 637 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 638 SmallVector<Instruction *, 32> NewWorkList; 639 // Pop an element for which we guaranteed to have at least one of the 640 // operand shapes. Add the shape for this and then add users to the work 641 // list. 642 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 643 while (!WorkList.empty()) { 644 Instruction *Inst = WorkList.pop_back_val(); 645 646 // New entry, set the value and insert operands 647 bool Propagate = false; 648 if (auto SI = computeShapeInfoForInst(Inst, ShapeMap)) 649 Propagate = setShapeInfo(Inst, *SI); 650 651 if (Propagate) { 652 NewWorkList.push_back(Inst); 653 for (auto *User : Inst->users()) 654 if (ShapeMap.count(User) == 0) 655 WorkList.push_back(cast<Instruction>(User)); 656 } 657 } 658 659 return NewWorkList; 660 } 661 662 /// Propagate the shape to operands of instructions with shape information. 663 /// \p Worklist contains the instruction for which we already know the shape. 664 SmallVector<Instruction *, 32> 665 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 666 SmallVector<Instruction *, 32> NewWorkList; 667 668 auto pushInstruction = [](Value *V, 669 SmallVectorImpl<Instruction *> &WorkList) { 670 Instruction *I = dyn_cast<Instruction>(V); 671 if (I) 672 WorkList.push_back(I); 673 }; 674 // Pop an element with known shape. Traverse the operands, if their shape 675 // derives from the result shape and is unknown, add it and add them to the 676 // worklist. 
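    // For example, for
    //   %c = call <8 x double> @llvm.matrix.multiply(<4 x double> %a,
    //                                                <8 x double> %b,
    //                                                i32 2, i32 2, i32 4)
    // the known 2x4 shape of %c implies a 2x2 shape for %a and a 2x4 shape for
    // %b, which get recorded below if they are not known yet.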
677 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 678 while (!WorkList.empty()) { 679 Value *V = WorkList.pop_back_val(); 680 681 size_t BeforeProcessingV = WorkList.size(); 682 if (!isa<Instruction>(V)) 683 continue; 684 685 Value *MatrixA; 686 Value *MatrixB; 687 Value *M; 688 Value *N; 689 Value *K; 690 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 691 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 692 m_Value(N), m_Value(K)))) { 693 if (setShapeInfo(MatrixA, {M, N})) 694 pushInstruction(MatrixA, WorkList); 695 696 if (setShapeInfo(MatrixB, {N, K})) 697 pushInstruction(MatrixB, WorkList); 698 699 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 700 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 701 // Flip dimensions. 702 if (setShapeInfo(MatrixA, {M, N})) 703 pushInstruction(MatrixA, WorkList); 704 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 705 m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 706 m_Value(M), m_Value(N)))) { 707 if (setShapeInfo(MatrixA, {M, N})) { 708 pushInstruction(MatrixA, WorkList); 709 } 710 } else if (isa<LoadInst>(V) || 711 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 712 // Nothing to do, no matrix input. 713 } else if (isa<StoreInst>(V)) { 714 // Nothing to do. We forward-propagated to this so we would just 715 // backward propagate to an instruction with an already known shape. 716 } else if (isUniformShape(V)) { 717 // Propagate to all operands. 718 ShapeInfo Shape = ShapeMap[V]; 719 for (Use &U : cast<Instruction>(V)->operands()) { 720 if (setShapeInfo(U.get(), Shape)) 721 pushInstruction(U.get(), WorkList); 722 } 723 } 724 // After we discovered new shape info for new instructions in the 725 // worklist, we use their users as seeds for the next round of forward 726 // propagation. 727 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++) 728 for (User *U : WorkList[I]->users()) 729 if (isa<Instruction>(U) && V != U) 730 NewWorkList.push_back(cast<Instruction>(U)); 731 } 732 return NewWorkList; 733 } 734 735 /// (Op0 op Op1)^T -> Op0^T op Op1^T 736 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use 737 /// them on both sides of \p Operation. 738 Instruction *distributeTransposes( 739 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1, 740 MatrixBuilder &Builder, 741 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)> 742 Operation) { 743 Value *T0 = Builder.CreateMatrixTranspose( 744 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t"); 745 // We are being run after shape prop, add shape for newly created 746 // instructions so that we lower them later. 747 setShapeInfo(T0, Shape0.t()); 748 Value *T1 = Builder.CreateMatrixTranspose( 749 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t"); 750 setShapeInfo(T1, Shape1.t()); 751 return Operation(T0, Shape0.t(), T1, Shape1.t()); 752 } 753 754 /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst 755 /// itself. 756 void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) { 757 auto Iter = ShapeMap.find(Inst); 758 if (Iter != ShapeMap.end()) 759 ShapeMap.erase(Iter); 760 Inst->eraseFromParent(); 761 } 762 763 /// Erase \p V from \p BB and move \II forward to avoid invalidating 764 /// iterators. 765 void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II, 766 BasicBlock &BB) { 767 auto *Inst = cast<Instruction>(V); 768 // Still used, don't erase. 
769 if (!Inst->use_empty()) 770 return; 771 if (II != BB.rend() && Inst == &*II) 772 ++II; 773 eraseFromParentAndRemoveFromShapeMap(Inst); 774 } 775 776 /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the 777 /// entry for \p Old and replace all uses of \p Old with \p New. 778 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) { 779 // We need to remove Old from the ShapeMap otherwise RAUW will replace it 780 // with New. We should only add New it it supportsShapeInfo so we insert 781 // it conditionally instead. 782 auto S = ShapeMap.find(&Old); 783 if (S != ShapeMap.end()) { 784 ShapeMap.erase(S); 785 if (supportsShapeInfo(New)) 786 ShapeMap.insert({New, S->second}); 787 } 788 Old.replaceAllUsesWith(New); 789 } 790 791 /// Sink a top-level transpose inside matmuls and adds. 792 /// This creates and erases instructions as needed, and returns the newly 793 /// created instruction while updating the iterator to avoid invalidation. If 794 /// this returns nullptr, no new instruction was created. 795 Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) { 796 BasicBlock &BB = *I.getParent(); 797 IRBuilder<> IB(&I); 798 MatrixBuilder Builder(IB); 799 800 Value *TA, *TAMA, *TAMB; 801 ConstantInt *R, *K, *C; 802 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>( 803 m_Value(TA), m_ConstantInt(R), m_ConstantInt(C)))) 804 return nullptr; 805 806 // Transpose of a transpose is a nop 807 Value *TATA; 808 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { 809 updateShapeAndReplaceAllUsesWith(I, TATA); 810 eraseFromParentAndMove(&I, II, BB); 811 eraseFromParentAndMove(TA, II, BB); 812 return nullptr; 813 } 814 815 // k^T -> k 816 if (isSplat(TA)) { 817 updateShapeAndReplaceAllUsesWith(I, TA); 818 eraseFromParentAndMove(&I, II, BB); 819 return nullptr; 820 } 821 822 // (A * B)^t -> B^t * A^t 823 // RxK KxC CxK KxR 824 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( 825 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), 826 m_ConstantInt(K), m_ConstantInt(C)))) { 827 auto NewInst = distributeTransposes( 828 TAMB, {K, C}, TAMA, {R, K}, Builder, 829 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 830 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows, 831 Shape0.NumColumns, 832 Shape1.NumColumns, "mmul"); 833 }); 834 updateShapeAndReplaceAllUsesWith(I, NewInst); 835 eraseFromParentAndMove(&I, II, BB); 836 eraseFromParentAndMove(TA, II, BB); 837 return NewInst; 838 } 839 840 // Same as above, but with a mul, which occurs when multiplied 841 // with a scalar. 842 // (A * k)^t -> A^t * k 843 // R x C RxC 844 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) && 845 (isSplat(TAMA) || isSplat(TAMB))) { 846 IRBuilder<> LocalBuilder(&I); 847 // We know that the transposed operand is of shape RxC. 848 // An when multiplied with a scalar, the shape is preserved. 849 auto NewInst = distributeTransposes( 850 TAMA, {R, C}, TAMB, {R, C}, Builder, 851 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 852 bool IsFP = I.getType()->isFPOrFPVectorTy(); 853 auto *Mul = IsFP ? 
LocalBuilder.CreateFMul(T0, T1, "mmul") 854 : LocalBuilder.CreateMul(T0, T1, "mmul"); 855 auto *Result = cast<Instruction>(Mul); 856 setShapeInfo(Result, Shape0); 857 return Result; 858 }); 859 updateShapeAndReplaceAllUsesWith(I, NewInst); 860 eraseFromParentAndMove(&I, II, BB); 861 eraseFromParentAndMove(TA, II, BB); 862 return NewInst; 863 } 864 865 // (A + B)^t -> A^t + B^t 866 // RxC RxC CxR CxR 867 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) { 868 IRBuilder<> LocalBuilder(&I); 869 auto NewInst = distributeTransposes( 870 TAMA, {R, C}, TAMB, {R, C}, Builder, 871 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 872 bool IsFP = I.getType()->isFPOrFPVectorTy(); 873 auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd") 874 : LocalBuilder.CreateAdd(T0, T1, "madd"); 875 876 auto *Result = cast<Instruction>(Add); 877 setShapeInfo(Result, Shape0); 878 return Result; 879 }); 880 updateShapeAndReplaceAllUsesWith(I, NewInst); 881 eraseFromParentAndMove(&I, II, BB); 882 eraseFromParentAndMove(TA, II, BB); 883 return NewInst; 884 } 885 886 return nullptr; 887 } 888 889 void liftTranspose(Instruction &I) { 890 // Erase dead Instructions after lifting transposes from binops. 891 auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) { 892 if (T.use_empty()) 893 eraseFromParentAndRemoveFromShapeMap(&T); 894 if (A->use_empty()) 895 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A)); 896 if (A != B && B->use_empty()) 897 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B)); 898 }; 899 900 Value *A, *B, *AT, *BT; 901 ConstantInt *R, *K, *C; 902 // A^t * B ^t -> (B * A)^t 903 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>( 904 m_Value(A), m_Value(B), m_ConstantInt(R), 905 m_ConstantInt(K), m_ConstantInt(C))) && 906 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && 907 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { 908 IRBuilder<> IB(&I); 909 MatrixBuilder Builder(IB); 910 Value *M = Builder.CreateMatrixMultiply( 911 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); 912 setShapeInfo(M, {C, R}); 913 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(), 914 R->getZExtValue()); 915 updateShapeAndReplaceAllUsesWith(I, NewInst); 916 CleanupBinOp(I, A, B); 917 } 918 // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If 919 // the shape of the second transpose is different, there's a shape conflict 920 // which gets resolved by picking the shape of the first operand. 
921 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) && 922 match(A, m_Intrinsic<Intrinsic::matrix_transpose>( 923 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) && 924 match(B, m_Intrinsic<Intrinsic::matrix_transpose>( 925 m_Value(BT), m_ConstantInt(), m_ConstantInt()))) { 926 IRBuilder<> Builder(&I); 927 auto *Add = Builder.CreateFAdd(AT, BT, "mfadd"); 928 MatrixBuilder MBuilder(Builder); 929 Instruction *NewInst = MBuilder.CreateMatrixTranspose( 930 Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t"); 931 updateShapeAndReplaceAllUsesWith(I, NewInst); 932 assert(computeShapeInfoForInst(NewInst, ShapeMap) == 933 computeShapeInfoForInst(&I, ShapeMap) && 934 "Shape of new instruction doesn't match original shape."); 935 CleanupBinOp(I, A, B); 936 if (auto *AddI = dyn_cast<Instruction>(Add)) { 937 setShapeInfo(AddI, {R, C}); 938 assert( 939 computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) == 940 ShapeMap[AddI] && 941 "Shape of updated addition doesn't match cached shape."); 942 } 943 } 944 } 945 946 /// Try moving transposes in order to fold them away or into multiplies. 947 void optimizeTransposes() { 948 // First sink all transposes inside matmuls and adds, hoping that we end up 949 // with NN, NT or TN variants. 950 for (BasicBlock &BB : reverse(Func)) { 951 for (auto II = BB.rbegin(); II != BB.rend();) { 952 Instruction &I = *II; 953 // We may remove II. By default continue on the next/prev instruction. 954 ++II; 955 if (Instruction *NewInst = sinkTranspose(I, II)) 956 II = std::next(BasicBlock::reverse_iterator(NewInst)); 957 } 958 } 959 960 // If we have a TT matmul or a TT add, lift the transpose. We may be able 961 // to fold into consuming multiply or add. 962 for (BasicBlock &BB : Func) { 963 for (Instruction &I : llvm::make_early_inc_range(BB)) { 964 liftTranspose(I); 965 } 966 } 967 } 968 969 bool Visit() { 970 SmallVector<Instruction *, 32> WorkList; 971 972 // Initially only the shape of matrix intrinsics is known. 973 // Initialize the work list with ops carrying shape information. 974 for (BasicBlock &BB : Func) 975 for (Instruction &Inst : BB) { 976 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 977 if (!II) 978 continue; 979 980 switch (II->getIntrinsicID()) { 981 case Intrinsic::matrix_multiply: 982 case Intrinsic::matrix_transpose: 983 case Intrinsic::matrix_column_major_load: 984 case Intrinsic::matrix_column_major_store: 985 WorkList.push_back(&Inst); 986 break; 987 default: 988 break; 989 } 990 } 991 992 // Avoid unnecessary work if there are no matrix intrinsics in the function. 993 if (WorkList.empty()) 994 return false; 995 996 if (AM) { 997 ORE = &AM->getResult<OptimizationRemarkEmitterAnalysis>(Func); 998 AA = &AM->getResult<AAManager>(Func); 999 DT = &AM->getResult<DominatorTreeAnalysis>(Func); 1000 LI = &AM->getResult<LoopAnalysis>(Func); 1001 } 1002 1003 // Propagate shapes until nothing changes any longer. 
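    // Each forward pass returns the instructions whose shape was newly
    // discovered; these seed the backward pass, whose discoveries in turn seed
    // the next forward pass, until both return an empty list.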
1004 while (!WorkList.empty()) { 1005 WorkList = propagateShapeForward(WorkList); 1006 WorkList = propagateShapeBackward(WorkList); 1007 } 1008 1009 if (!isMinimal()) { 1010 optimizeTransposes(); 1011 if (PrintAfterTransposeOpt) { 1012 dbgs() << "Dump after matrix transpose optimization:\n"; 1013 Func.print(dbgs()); 1014 } 1015 } 1016 1017 bool Changed = false; 1018 SmallVector<CallInst *, 16> MaybeFusableInsts; 1019 SmallVector<Instruction *, 16> MatrixInsts; 1020 SmallVector<IntrinsicInst *, 16> LifetimeEnds; 1021 1022 // First, collect all instructions with shape information and candidates for 1023 // fusion (currently only matrix multiplies). 1024 ReversePostOrderTraversal<Function *> RPOT(&Func); 1025 for (auto *BB : RPOT) 1026 for (Instruction &I : *BB) { 1027 if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>())) 1028 LifetimeEnds.push_back(cast<IntrinsicInst>(&I)); 1029 if (ShapeMap.find(&I) == ShapeMap.end()) 1030 continue; 1031 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 1032 MaybeFusableInsts.push_back(cast<CallInst>(&I)); 1033 MatrixInsts.push_back(&I); 1034 } 1035 1036 // Second, try to lower any dot products 1037 SmallPtrSet<Instruction *, 16> FusedInsts; 1038 for (CallInst *CI : MaybeFusableInsts) 1039 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI)); 1040 1041 // Third, try to fuse candidates. 1042 for (CallInst *CI : MaybeFusableInsts) 1043 if (!FusedInsts.contains(CI)) 1044 LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds); 1045 1046 Changed = !FusedInsts.empty(); 1047 1048 // Fourth, lower remaining instructions with shape information. 1049 for (Instruction *Inst : MatrixInsts) { 1050 if (FusedInsts.count(Inst)) 1051 continue; 1052 1053 IRBuilder<> Builder(Inst); 1054 1055 if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 1056 Changed |= VisitCallInst(CInst); 1057 1058 Value *Op1; 1059 Value *Op2; 1060 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 1061 Changed |= VisitBinaryOperator(BinOp); 1062 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 1063 Changed |= VisitUnaryOperator(UnOp); 1064 if (match(Inst, m_Load(m_Value(Op1)))) 1065 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 1066 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 1067 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 1068 } 1069 1070 if (ORE) { 1071 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 1072 RemarkGen.emitRemarks(); 1073 } 1074 1075 // Delete the instructions backwards, as it has a reduced likelihood of 1076 // having to update as many def-use and use-def chains. 1077 // 1078 // Because we add to ToRemove during fusion we can't guarantee that defs 1079 // are before uses. Change uses to poison temporarily as these should get 1080 // removed as well. 1081 // 1082 // For verification, we keep track of where we changed uses to poison in 1083 // PoisonedInsts and then check that we in fact remove them. 1084 SmallSet<Instruction *, 16> PoisonedInsts; 1085 for (auto *Inst : reverse(ToRemove)) { 1086 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1087 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser())) 1088 PoisonedInsts.insert(Poisoned); 1089 U.set(PoisonValue::get(Inst->getType())); 1090 } 1091 Inst->eraseFromParent(); 1092 PoisonedInsts.erase(Inst); 1093 } 1094 if (!PoisonedInsts.empty()) { 1095 // If we didn't remove all poisoned instructions, it's a hard error. 
1096 dbgs() << "Poisoned but present instructions:\n"; 1097 for (auto *I : PoisonedInsts) 1098 dbgs() << *I << "\n"; 1099 llvm_unreachable("Poisoned but instruction not removed"); 1100 } 1101 1102 return Changed; 1103 } 1104 1105 /// Replace intrinsic calls 1106 bool VisitCallInst(CallInst *Inst) { 1107 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 1108 return false; 1109 1110 switch (Inst->getCalledFunction()->getIntrinsicID()) { 1111 case Intrinsic::matrix_multiply: 1112 LowerMultiply(Inst); 1113 break; 1114 case Intrinsic::matrix_transpose: 1115 LowerTranspose(Inst); 1116 break; 1117 case Intrinsic::matrix_column_major_load: 1118 LowerColumnMajorLoad(Inst); 1119 break; 1120 case Intrinsic::matrix_column_major_store: 1121 LowerColumnMajorStore(Inst); 1122 break; 1123 default: 1124 return false; 1125 } 1126 return true; 1127 } 1128 1129 /// Compute the alignment for a column/row \p Idx with \p Stride between them. 1130 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 1131 /// ConstantInt, reduce the initial alignment based on the byte offset. For 1132 /// non-ConstantInt strides, return the common alignment of the initial 1133 /// alignment and the element size in bytes. 1134 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 1135 MaybeAlign A) const { 1136 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 1137 if (Idx == 0) 1138 return InitialAlign; 1139 1140 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 1141 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 1142 uint64_t StrideInBytes = 1143 ConstStride->getZExtValue() * ElementSizeInBits / 8; 1144 return commonAlignment(InitialAlign, Idx * StrideInBytes); 1145 } 1146 return commonAlignment(InitialAlign, ElementSizeInBits / 8); 1147 } 1148 1149 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 1150 /// vectors. 1151 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 1152 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 1153 auto *VType = cast<VectorType>(Ty); 1154 Type *EltTy = VType->getElementType(); 1155 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); 1156 Value *EltPtr = Ptr; 1157 MatrixTy Result; 1158 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 1159 Value *GEP = computeVectorAddr( 1160 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), 1161 Stride, Shape.getStride(), EltTy, Builder); 1162 Value *Vector = Builder.CreateAlignedLoad( 1163 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), 1164 IsVolatile, "col.load"); 1165 1166 Result.addVector(Vector); 1167 } 1168 return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 1169 Result.getNumVectors()); 1170 } 1171 1172 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 1173 /// starting at \p MatrixPtr[I][J]. 
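  /// For a column-major matrix, the element at (\p I, \p J) is located
  /// J * Stride + I elements past \p MatrixPtr, where Stride is the number of
  /// rows of the full matrix described by \p MatrixShape.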
1174 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 1175 ShapeInfo MatrixShape, Value *I, Value *J, 1176 ShapeInfo ResultShape, Type *EltTy, 1177 IRBuilder<> &Builder) { 1178 1179 Value *Offset = Builder.CreateAdd( 1180 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1181 1182 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); 1183 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 1184 ResultShape.NumColumns); 1185 1186 return loadMatrix(TileTy, TileStart, Align, 1187 Builder.getInt64(MatrixShape.getStride()), IsVolatile, 1188 ResultShape, Builder); 1189 } 1190 1191 /// Lower a load instruction with shape information. 1192 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 1193 bool IsVolatile, ShapeInfo Shape) { 1194 IRBuilder<> Builder(Inst); 1195 finalizeLowering(Inst, 1196 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 1197 Shape, Builder), 1198 Builder); 1199 } 1200 1201 /// Lowers llvm.matrix.column.major.load. 1202 /// 1203 /// The intrinsic loads a matrix from memory using a stride between columns. 1204 void LowerColumnMajorLoad(CallInst *Inst) { 1205 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1206 "Intrinsic only supports column-major layout!"); 1207 Value *Ptr = Inst->getArgOperand(0); 1208 Value *Stride = Inst->getArgOperand(1); 1209 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 1210 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 1211 {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 1212 } 1213 1214 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 1215 /// MatrixPtr[I][J]. 1216 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 1217 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, 1218 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { 1219 Value *Offset = Builder.CreateAdd( 1220 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1221 1222 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); 1223 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * 1224 StoreVal.getNumColumns()); 1225 1226 storeMatrix(TileTy, StoreVal, TileStart, MAlign, 1227 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); 1228 } 1229 1230 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 1231 /// vectors. 1232 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, 1233 MaybeAlign MAlign, Value *Stride, bool IsVolatile, 1234 IRBuilder<> &Builder) { 1235 auto VType = cast<VectorType>(Ty); 1236 Value *EltPtr = Ptr; 1237 for (auto Vec : enumerate(StoreVal.vectors())) { 1238 Value *GEP = computeVectorAddr( 1239 EltPtr, 1240 Builder.getIntN(Stride->getType()->getScalarSizeInBits(), 1241 Vec.index()), 1242 Stride, StoreVal.getStride(), VType->getElementType(), Builder); 1243 Builder.CreateAlignedStore(Vec.value(), GEP, 1244 getAlignForIndex(Vec.index(), Stride, 1245 VType->getElementType(), 1246 MAlign), 1247 IsVolatile); 1248 } 1249 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * 1250 StoreVal.getNumVectors()); 1251 } 1252 1253 /// Lower a store instruction with shape information. 
1254 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, 1255 Value *Stride, bool IsVolatile, ShapeInfo Shape) { 1256 IRBuilder<> Builder(Inst); 1257 auto StoreVal = getMatrix(Matrix, Shape, Builder); 1258 finalizeLowering(Inst, 1259 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, 1260 IsVolatile, Builder), 1261 Builder); 1262 } 1263 1264 /// Lowers llvm.matrix.column.major.store. 1265 /// 1266 /// The intrinsic store a matrix back memory using a stride between columns. 1267 void LowerColumnMajorStore(CallInst *Inst) { 1268 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1269 "Intrinsic only supports column-major layout!"); 1270 Value *Matrix = Inst->getArgOperand(0); 1271 Value *Ptr = Inst->getArgOperand(1); 1272 Value *Stride = Inst->getArgOperand(2); 1273 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 1274 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 1275 {Inst->getArgOperand(4), Inst->getArgOperand(5)}); 1276 } 1277 1278 // Set elements I..I+NumElts-1 to Block 1279 Value *insertVector(Value *Col, unsigned I, Value *Block, 1280 IRBuilder<> &Builder) { 1281 1282 // First, bring Block to the same size as Col 1283 unsigned BlockNumElts = 1284 cast<FixedVectorType>(Block->getType())->getNumElements(); 1285 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); 1286 assert(NumElts >= BlockNumElts && "Too few elements for current block"); 1287 1288 Block = Builder.CreateShuffleVector( 1289 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 1290 1291 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 1292 // 8, 4, 5, 6 1293 SmallVector<int, 16> Mask; 1294 unsigned i; 1295 for (i = 0; i < I; i++) 1296 Mask.push_back(i); 1297 1298 unsigned VecNumElts = 1299 cast<FixedVectorType>(Col->getType())->getNumElements(); 1300 for (; i < I + BlockNumElts; i++) 1301 Mask.push_back(i - I + VecNumElts); 1302 1303 for (; i < VecNumElts; i++) 1304 Mask.push_back(i); 1305 1306 return Builder.CreateShuffleVector(Col, Block, Mask); 1307 } 1308 1309 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 1310 IRBuilder<> &Builder, bool AllowContraction, 1311 unsigned &NumComputeOps) { 1312 NumComputeOps += getNumOps(A->getType()); 1313 if (!Sum) 1314 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 1315 1316 if (UseFPOp) { 1317 if (AllowContraction) { 1318 // Use fmuladd for floating point operations and let the backend decide 1319 // if that's profitable. 1320 return Builder.CreateIntrinsic(Intrinsic::fmuladd, A->getType(), 1321 {A, B, Sum}); 1322 } 1323 NumComputeOps += getNumOps(A->getType()); 1324 Value *Mul = Builder.CreateFMul(A, B); 1325 return Builder.CreateFAdd(Sum, Mul); 1326 } 1327 1328 NumComputeOps += getNumOps(A->getType()); 1329 Value *Mul = Builder.CreateMul(A, B); 1330 return Builder.CreateAdd(Sum, Mul); 1331 } 1332 1333 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 1334 /// users with shape information, there's nothing to do: they will use the 1335 /// cached value when they are lowered. For other users, \p Matrix is 1336 /// flattened and the uses are updated to use it. Also marks \p Inst for 1337 /// deletion. 
1338   void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1339                         IRBuilder<> &Builder) {
1340     auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1341     (void)inserted;
1342     assert(inserted.second && "multiple matrix lowering mapping");
1343
1344     ToRemove.push_back(Inst);
1345     Value *Flattened = nullptr;
1346     for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1347       if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1348         if (!Flattened)
1349           Flattened = Matrix.embedInVector(Builder);
1350         U.set(Flattened);
1351       }
1352     }
1353   }
1354
1355   /// Special case for MatMul lowering: prevents scalar loads of row-major
1356   /// vectors and lowers to a vector reduction add instead of sequential adds
1357   /// if reassociation is enabled.
1358   void lowerDotProduct(CallInst *MatMul,
1359                        SmallPtrSet<Instruction *, 16> &FusedInsts,
1360                        FastMathFlags FMF) {
1361     if (FusedInsts.contains(MatMul) ||
1362         MatrixLayout != MatrixLayoutTy::ColumnMajor)
1363       return;
1364     ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1365     ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1366
1367     if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
1368       return;
1369
1370     Value *LHS = MatMul->getArgOperand(0);
1371     Value *RHS = MatMul->getArgOperand(1);
1372
1373     Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
1374     bool IsIntVec = ElementType->isIntegerTy();
1375
1376     // Floating point reductions require reassociation.
1377     if (!IsIntVec && !FMF.allowReassoc())
1378       return;
1379
1380     auto CanBeFlattened = [](Value *Op) {
1381       if (match(Op, m_BinOp()))
1382         return true;
1383       return match(
1384           Op, m_OneUse(m_CombineOr(
1385                   m_Load(m_Value()),
1386                   m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1387                               m_Intrinsic<Intrinsic::matrix_column_major_load>(
1388                                   m_Value(), m_SpecificInt(1))))));
1389     };
1390     // Returns the cost benefit of using \p Op with the dot product lowering. If
1391     // the returned cost is < 0, the argument is cheaper to use in the
1392     // dot-product lowering.
1393     auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1394       if (ShapeMap.find(Op) == ShapeMap.end())
1395         return InstructionCost::getInvalid();
1396
1397       if (!isa<Instruction>(Op))
1398         return InstructionCost(0);
1399
1400       FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
1401       Type *EltTy = VecTy->getElementType();
1402
1403       if (!CanBeFlattened(Op)) {
1404         InstructionCost EmbedCost(0);
1405         // Roughly estimate the cost for embedding the columns into a vector.
1406         for (unsigned I = 1; I < N; ++I)
1407           EmbedCost +=
1408               TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1409                                  {}, TTI::TCK_RecipThroughput);
1410         return EmbedCost;
1411       }
1412
1413       if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
1414         InstructionCost OriginalCost =
1415             TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
1416                                        EltTy) *
1417             N;
1418         InstructionCost NewCost = TTI.getArithmeticInstrCost(
1419             cast<Instruction>(Op)->getOpcode(), VecTy);
1420         return NewCost - OriginalCost;
1421       }
1422
1423       if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1424         // The transpose can be skipped for the dot product lowering; roughly
1425         // estimate the savings as the cost of embedding the columns in a
1426         // vector.
1427 InstructionCost EmbedCost(0); 1428 for (unsigned I = 1; I < N; ++I) 1429 EmbedCost -= 1430 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), 1431 {}, TTI::TCK_RecipThroughput); 1432 return EmbedCost; 1433 } 1434 1435 // Costs for loads. 1436 if (N == 1) 1437 return InstructionCost(0); 1438 1439 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) - 1440 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0); 1441 }; 1442 1443 // Iterate over LHS and operations feeding LHS and check if it is profitable 1444 // to flatten the visited ops. For each op, we compute the difference 1445 // between the flattened and matrix versions. 1446 SmallPtrSet<Value *, 4> Seen; 1447 SmallVector<Value *> WorkList; 1448 SmallVector<Value *> ToFlatten; 1449 WorkList.push_back(LHS); 1450 InstructionCost LHSCost(0); 1451 while (!WorkList.empty()) { 1452 Value *Op = WorkList.pop_back_val(); 1453 if (!Seen.insert(Op).second) 1454 continue; 1455 1456 InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns); 1457 if (OpCost + LHSCost >= LHSCost) 1458 continue; 1459 1460 LHSCost += OpCost; 1461 ToFlatten.push_back(Op); 1462 if (auto *I = dyn_cast<Instruction>(Op)) 1463 WorkList.append(I->op_begin(), I->op_end()); 1464 } 1465 1466 // We compare the costs of a vector.reduce.add to sequential add. 1467 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd; 1468 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul; 1469 InstructionCost ReductionCost = 1470 TTI.getArithmeticReductionCost( 1471 AddOpCode, cast<VectorType>(LHS->getType()), 1472 IsIntVec ? std::nullopt : std::optional(FMF)) + 1473 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType()); 1474 InstructionCost SequentialAddCost = 1475 TTI.getArithmeticInstrCost(AddOpCode, ElementType) * 1476 (LShape.NumColumns - 1) + 1477 TTI.getArithmeticInstrCost(MulOpCode, ElementType) * 1478 (LShape.NumColumns); 1479 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0)) 1480 return; 1481 1482 FusedInsts.insert(MatMul); 1483 IRBuilder<> Builder(MatMul); 1484 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened, 1485 this](Value *Op) { 1486 // Matmul must be the only user of loads because we don't use LowerLoad 1487 // for row vectors (LowerLoad results in scalar loads and shufflevectors 1488 // instead of single vector load). 1489 if (!CanBeFlattened(Op)) 1490 return; 1491 1492 if (match(Op, m_BinOp())) { 1493 auto It = ShapeMap.find(Op); 1494 if (It != ShapeMap.end()) { 1495 It->second = It->second.t(); 1496 return; 1497 } 1498 } 1499 1500 FusedInsts.insert(cast<Instruction>(Op)); 1501 // If vector uses the builtin load, lower to a LoadInst 1502 Value *Arg; 1503 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>( 1504 m_Value(Arg)))) { 1505 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg); 1506 Op->replaceAllUsesWith(NewLoad); 1507 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op)); 1508 return; 1509 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>( 1510 m_Value(Arg)))) { 1511 ToRemove.push_back(cast<Instruction>(Op)); 1512 Op->replaceAllUsesWith(Arg); 1513 return; 1514 } 1515 }; 1516 1517 for (auto *V : ToFlatten) 1518 FlattenArg(V); 1519 1520 LHS = MatMul->getArgOperand(0); 1521 1522 // Insert mul/fmul and llvm.vector.reduce.fadd 1523 Value *Mul = 1524 IsIntVec ? 
Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS); 1525 1526 Value *Result; 1527 if (IsIntVec) 1528 Result = Builder.CreateAddReduce(Mul); 1529 else { 1530 Result = Builder.CreateFAddReduce( 1531 ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(), 1532 0.0), 1533 Mul); 1534 cast<Instruction>(Result)->setFastMathFlags(FMF); 1535 } 1536 1537 // pack scalar back into a matrix and then replace matmul inst 1538 Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()), 1539 Result, uint64_t(0)); 1540 MatMul->replaceAllUsesWith(Result); 1541 FusedInsts.insert(MatMul); 1542 ToRemove.push_back(MatMul); 1543 } 1544 1545 /// Compute \p Result += \p A * \p B for input matrices with left-associating 1546 /// addition. 1547 /// 1548 /// We can fold a transpose into the operand that is used to extract scalars. 1549 /// This is the first operands with row-major and the second with 1550 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate 1551 /// operand is transposed. 1552 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1553 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled, 1554 bool IsScalarMatrixTransposed, FastMathFlags FMF) { 1555 const unsigned VF = std::max<unsigned>( 1556 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1557 .getFixedValue() / 1558 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(), 1559 1U); 1560 unsigned R = Result.getNumRows(); 1561 unsigned C = Result.getNumColumns(); 1562 unsigned M = A.getNumColumns(); 1563 1564 bool IsFP = Result.getElementType()->isFloatingPointTy(); 1565 assert(A.isColumnMajor() == B.isColumnMajor() && 1566 Result.isColumnMajor() == A.isColumnMajor() && 1567 "operands must agree on matrix layout"); 1568 unsigned NumComputeOps = 0; 1569 1570 Builder.setFastMathFlags(FMF); 1571 1572 if (A.isColumnMajor()) { 1573 // Multiply columns from the first operand with scalars from the second 1574 // operand. Then move along the K axes and accumulate the columns. With 1575 // this the adds can be vectorized without reassociation. 1576 for (unsigned J = 0; J < C; ++J) { 1577 unsigned BlockSize = VF; 1578 // If Result is zero, we don't need to accumulate in the K==0 iteration. 1579 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1580 1581 for (unsigned I = 0; I < R; I += BlockSize) { 1582 // Gradually lower the vectorization factor to cover the remainder. 1583 while (I + BlockSize > R) 1584 BlockSize /= 2; 1585 1586 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder) 1587 : nullptr; 1588 for (unsigned K = 0; K < M; ++K) { 1589 Value *L = A.extractVector(I, K, BlockSize, Builder); 1590 Value *RH = Builder.CreateExtractElement( 1591 B.getColumn(IsScalarMatrixTransposed ? K : J), 1592 IsScalarMatrixTransposed ? J : K); 1593 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1594 Sum = 1595 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1596 IsFP, Builder, FMF.allowContract(), NumComputeOps); 1597 } 1598 Result.setVector(J, 1599 insertVector(Result.getVector(J), I, Sum, Builder)); 1600 } 1601 } 1602 } else { 1603 // Multiply rows from the second operand with scalars from the first 1604 // operand. Then move along the K axes and accumulate the rows. With this 1605 // the adds can be vectorized without reassociation. 
1606 for (unsigned I = 0; I < R; ++I) {
1607 unsigned BlockSize = VF;
1608 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1609 for (unsigned J = 0; J < C; J += BlockSize) {
1610 // Gradually lower the vectorization factor to cover the remainder.
1611 while (J + BlockSize > C)
1612 BlockSize /= 2;
1613
1614 Value *Sum = nullptr;
1615 for (unsigned K = 0; K < M; ++K) {
1616 Value *R = B.extractVector(K, J, BlockSize, Builder);
1617 Value *LH = Builder.CreateExtractElement(
1618 A.getVector(IsScalarMatrixTransposed ? K : I),
1619 IsScalarMatrixTransposed ? I : K);
1620 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1621 Sum =
1622 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1623 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1624 }
1625 Result.setVector(I,
1626 insertVector(Result.getVector(I), J, Sum, Builder));
1627 }
1628 }
1629 }
1630 Result.addNumComputeOps(NumComputeOps);
1631 }
1632
1633 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1634 /// copying it to a new location. Returns the new location if a copy was
1635 /// made, otherwise the original location.
1636 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1637 CallInst *MatMul) {
1638 MemoryLocation StoreLoc = MemoryLocation::get(Store);
1639 MemoryLocation LoadLoc = MemoryLocation::get(Load);
1640
1641 // If we can statically determine noalias we're good.
1642 if (AA->isNoAlias(LoadLoc, StoreLoc))
1643 return Load->getPointerOperand();
1644
1645 // Create code to check if the memory locations of the Load and Store
1646 // overlap and if they do, copy Load's operand to a new buffer.
1647
1648 // First, create new blocks for the 2nd part of the check and the copy.
1649 BasicBlock *Check0 = MatMul->getParent();
1650 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1651 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1652 // as we adjust Check0 and Check1's branches.
1653 SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1654 for (BasicBlock *Succ : successors(Check0))
1655 DTUpdates.push_back({DT->Delete, Check0, Succ});
1656
1657 BasicBlock *Check1 =
1658 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1659 nullptr, "alias_cont");
1660 BasicBlock *Copy =
1661 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1662 nullptr, "copy");
1663 BasicBlock *Fusion =
1664 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1665 nullptr, "no_alias");
1666
1667 // Check if the loaded memory location begins before the end of the store
1668 // location. If the condition holds, they might overlap, otherwise they are
1669 // guaranteed to not overlap.
1670 IRBuilder<> Builder(MatMul);
1671 Check0->getTerminator()->eraseFromParent();
1672 Builder.SetInsertPoint(Check0);
1673 Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout());
1674 Value *StoreBegin = Builder.CreatePtrToInt(
1675 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1676 Value *StoreEnd = Builder.CreateAdd(
1677 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1678 "store.end", true, true);
1679 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1680 IntPtrTy, "load.begin");
1681 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1682 Fusion);
1683
1684 // Check if the store begins before the end of the load location. If the
1685 // condition holds, they may overlap, otherwise they are guaranteed to not
1686 // overlap.
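// Taken together, the two guards form the usual interval-overlap test on the
// integer addresses; a sketch of the decision (CopyToNewAlloca is a
// hypothetical name for what the "copy" block below does):
//
//   bool MayOverlap = LoadBegin < StoreEnd && StoreBegin < LoadEnd;
//   Ptr = MayOverlap ? CopyToNewAlloca(LoadPtr) : LoadPtr;
//
// Only when both comparisons are true do we reach the "copy" block.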
1687 Check1->getTerminator()->eraseFromParent(); 1688 Builder.SetInsertPoint(Check1, Check1->begin()); 1689 Value *LoadEnd = Builder.CreateAdd( 1690 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 1691 "load.end", true, true); 1692 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 1693 Fusion); 1694 1695 // Copy load operand to new alloca. 1696 Builder.SetInsertPoint(Copy, Copy->begin()); 1697 auto *VT = cast<FixedVectorType>(Load->getType()); 1698 // Use an array type for the alloca, to avoid potentially huge alignment 1699 // requirements for large vector types. 1700 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements()); 1701 AllocaInst *Alloca = 1702 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace()); 1703 1704 Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(), 1705 Load->getAlign(), LoadLoc.Size.getValue()); 1706 Builder.SetInsertPoint(Fusion, Fusion->begin()); 1707 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1708 PHI->addIncoming(Load->getPointerOperand(), Check0); 1709 PHI->addIncoming(Load->getPointerOperand(), Check1); 1710 PHI->addIncoming(Alloca, Copy); 1711 1712 // Adjust DT. 1713 DTUpdates.push_back({DT->Insert, Check0, Check1}); 1714 DTUpdates.push_back({DT->Insert, Check0, Fusion}); 1715 DTUpdates.push_back({DT->Insert, Check1, Copy}); 1716 DTUpdates.push_back({DT->Insert, Check1, Fusion}); 1717 DT->applyUpdates(DTUpdates); 1718 return PHI; 1719 } 1720 1721 bool isFusionProfitable(CallInst *MatMul) { 1722 if (ForceFusion) 1723 return true; 1724 1725 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1726 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1727 1728 const unsigned R = LShape.NumRows; 1729 const unsigned C = RShape.NumColumns; 1730 const unsigned M = LShape.NumColumns; 1731 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1732 1733 const unsigned VF = std::max<unsigned>( 1734 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1735 .getFixedValue() / 1736 EltType->getPrimitiveSizeInBits().getFixedValue(), 1737 1U); 1738 1739 // Cost model for tiling 1740 // 1741 // For tiling to be beneficial, we need reuse either along the R or 1742 // the C axis. We vectorize along the R axis so that means at least 1743 // 3 elements. 1744 // TODO: Also consider cost of copying if operands alias. 1745 if (R <= VF && C == 1) 1746 return false; 1747 // Then we need enough elements to exceed the number of vector 1748 // registers we have. Note that this is an oversimplification since 1749 // fusing also takes some extra loads which may exceed the number of 1750 // reloads necessary. 1751 unsigned Op0Regs = (R + VF - 1) / VF * M; 1752 unsigned Op1Regs = (M + VF - 1) / VF * C; 1753 return Op0Regs + Op1Regs > 1754 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true)); 1755 } 1756 1757 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1758 MatrixTy Res; 1759 auto *ColumType = FixedVectorType::get(EltType, R); 1760 for (unsigned I = 0; I < C; ++I) 1761 Res.addVector(ConstantAggregateZero::get(ColumType)); 1762 return Res; 1763 } 1764 1765 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape, 1766 Value *RPtr, ShapeInfo RShape, StoreInst *Store) { 1767 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1768 1769 // Create the main tiling loop nest. 
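// A scalar sketch of what the nest computes (the actual loop order and IR
// shape are produced by TileInfo::CreateTiledLoops below):
//
//   for each TileSize x TileSize tile (I, J) of the result
//     Res = 0
//     for (K = 0; K < LShape.NumColumns; K += TileSize)   // TI.KLoop
//       Res += A[I..I+TS][K..K+TS] * B[K..K+TS][J..J+TS]
//     store Res into C[I..I+TS][J..J+TS]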
1770 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize); 1771 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 1772 Instruction *InsertI = cast<Instruction>(MatMul); 1773 BasicBlock *Start = InsertI->getParent(); 1774 BasicBlock *End = 1775 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue"); 1776 IRBuilder<> Builder(MatMul); 1777 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI); 1778 1779 Type *TileVecTy = 1780 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize); 1781 MatrixTy TileResult; 1782 // Insert in the inner loop header. 1783 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator()); 1784 // Create PHI nodes for the result columns to accumulate across iterations. 1785 SmallVector<PHINode *, 4> ColumnPhis; 1786 for (unsigned I = 0; I < TileSize; I++) { 1787 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I)); 1788 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy), 1789 TI.RowLoop.Header->getSingleSuccessor()); 1790 TileResult.addVector(Phi); 1791 ColumnPhis.push_back(Phi); 1792 } 1793 1794 // Insert in the inner loop body, which computes 1795 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn) 1796 Builder.SetInsertPoint(InnerBody->getTerminator()); 1797 // Load tiles of the operands. 1798 MatrixTy A = 1799 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index, 1800 {TileSize, TileSize}, EltType, Builder); 1801 MatrixTy B = 1802 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index, 1803 {TileSize, TileSize}, EltType, Builder); 1804 emitMatrixMultiply(TileResult, A, B, Builder, true, false, 1805 getFastMathFlags(MatMul)); 1806 // Store result after the inner loop is done. 1807 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator()); 1808 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(), 1809 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns}, 1810 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder); 1811 1812 for (unsigned I = 0; I < TileResult.getNumVectors(); I++) 1813 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch); 1814 1815 // Force unrolling of a few iterations of the inner loop, to make sure there 1816 // is enough work per iteration. 1817 // FIXME: The unroller should make this decision directly instead, but 1818 // currently the cost-model is not up to the task. 
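// For example (illustrative numbers only): with LShape.NumColumns == 64 and
// the default TileSize of 4, the computation below requests
// llvm.loop.unroll.count = min(10, 64 / 4) = 10 for the inner loop.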
1819 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize); 1820 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header), 1821 "llvm.loop.unroll.count", InnerLoopUnrollCount); 1822 } 1823 1824 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 1825 StoreInst *Store, 1826 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1827 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1828 "Tiling only supported for column-major matrixes at the moment!"); 1829 if (!isFusionProfitable(MatMul)) 1830 return; 1831 1832 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1833 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1834 1835 const unsigned R = LShape.NumRows; 1836 const unsigned C = RShape.NumColumns; 1837 const unsigned M = LShape.NumColumns; 1838 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1839 1840 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 1841 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 1842 Value *CPtr = Store->getPointerOperand(); 1843 1844 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0)) 1845 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store); 1846 else { 1847 IRBuilder<> Builder(Store); 1848 for (unsigned J = 0; J < C; J += TileSize) 1849 for (unsigned I = 0; I < R; I += TileSize) { 1850 const unsigned TileR = std::min(R - I, unsigned(TileSize)); 1851 const unsigned TileC = std::min(C - J, unsigned(TileSize)); 1852 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 1853 1854 for (unsigned K = 0; K < M; K += TileSize) { 1855 const unsigned TileM = std::min(M - K, unsigned(TileSize)); 1856 MatrixTy A = 1857 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 1858 LShape, Builder.getInt64(I), Builder.getInt64(K), 1859 {TileR, TileM}, EltType, Builder); 1860 MatrixTy B = 1861 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), 1862 RShape, Builder.getInt64(K), Builder.getInt64(J), 1863 {TileM, TileC}, EltType, Builder); 1864 emitMatrixMultiply(Res, A, B, Builder, true, false, 1865 getFastMathFlags(MatMul)); 1866 } 1867 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 1868 Builder.getInt64(I), Builder.getInt64(J), EltType, 1869 Builder); 1870 } 1871 } 1872 1873 // Mark eliminated instructions as fused and remove them. 1874 FusedInsts.insert(Store); 1875 FusedInsts.insert(MatMul); 1876 eraseFromParentAndRemoveFromShapeMap(Store); 1877 eraseFromParentAndRemoveFromShapeMap(MatMul); 1878 if (LoadOp0->hasNUses(0)) { 1879 FusedInsts.insert(LoadOp0); 1880 eraseFromParentAndRemoveFromShapeMap(LoadOp0); 1881 } 1882 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) { 1883 FusedInsts.insert(LoadOp1); 1884 eraseFromParentAndRemoveFromShapeMap(LoadOp1); 1885 } 1886 } 1887 1888 /// Try to lower matrix multiply chains by fusing operations. 1889 /// 1890 /// Call finalizeLowering on lowered instructions. Instructions that are 1891 /// completely eliminated by fusion are added to \p FusedInsts. 1892 void 1893 LowerMatrixMultiplyFused(CallInst *MatMul, 1894 SmallPtrSetImpl<Instruction *> &FusedInsts, 1895 SmallVector<IntrinsicInst *, 16> &LifetimeEnds) { 1896 if (!FuseMatrix || !DT) 1897 return; 1898 1899 assert(AA && LI && "Analyses should be available"); 1900 1901 Value *A = MatMul->getArgOperand(0); 1902 Value *B = MatMul->getArgOperand(1); 1903 1904 // We can fold the transpose into the operand that is used to fetch scalars. 1905 Value *T; 1906 if (MatrixLayout == MatrixLayoutTy::ColumnMajor 1907 ? 
match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1908 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1909 IRBuilder<> Builder(MatMul);
1910 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1911 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1912 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1913 const unsigned R = LShape.NumRows;
1914 const unsigned M = LShape.NumColumns;
1915 const unsigned C = RShape.NumColumns;
1916
1917 MatrixTy MA;
1918 MatrixTy MB;
1919
1920 Value *Transpose;
1921 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1922 MA = getMatrix(A, ShapeInfo(R, M), Builder);
1923 MB = getMatrix(T, ShapeInfo(C, M), Builder);
1924 Transpose = B;
1925 } else {
1926 MA = getMatrix(T, ShapeInfo(R, M), Builder);
1927 MB = getMatrix(B, ShapeInfo(C, M), Builder);
1928 Transpose = A;
1929 }
1930
1931 // Initialize the output
1932 MatrixTy Result(R, C, EltType);
1933
1934 emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1935 getFastMathFlags(MatMul));
1936
1937 FusedInsts.insert(MatMul);
1938 if (Transpose->hasOneUse()) {
1939 FusedInsts.insert(cast<Instruction>(Transpose));
1940 ToRemove.push_back(cast<Instruction>(Transpose));
1941 // TODO: add a fake entry for the folded instruction so that this is
1942 // included in the expression in the remark.
1943 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1944 }
1945 finalizeLowering(MatMul, Result, Builder);
1946 return;
1947 }
1948
1949 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1950 return;
1951
1952 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1953 // since the single store user will be lowered as part of this.
1954 auto *LoadOp0 = dyn_cast<LoadInst>(A);
1955 auto *LoadOp1 = dyn_cast<LoadInst>(B);
1956 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1957 if (LoadOp0 && LoadOp1 && Store) {
1958 // The store address must dominate the MatMul instruction, otherwise
1959 // we create invalid IR.
1960 SetVector<Value *> WorkList;
1961 WorkList.insert(Store->getOperand(1));
1962 SmallVector<Instruction *> ToHoist;
1963 for (unsigned I = 0; I != WorkList.size(); ++I) {
1964 Value *Current = WorkList[I];
1965 auto *CurrI = dyn_cast<Instruction>(Current);
1966 if (!CurrI)
1967 continue;
1968 if (isa<PHINode>(CurrI))
1969 return;
1970 if (DT->dominates(CurrI, MatMul))
1971 continue;
1972 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1973 return;
1974 ToHoist.push_back(CurrI);
1975 WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1976 }
1977
1978 sort(ToHoist, [this](Instruction *A, Instruction *B) {
1979 return DT->dominates(A, B);
1980 });
1981 for (Instruction *I : ToHoist)
1982 I->moveBefore(MatMul->getIterator());
1983
1984 // Deal with lifetime.end calls that might be between Load0/Load1 and the
1985 // store. To avoid introducing loads to dead objects (i.e. after the
1986 // lifetime has been terminated by @llvm.lifetime.end), either sink them
1987 // after the store if in the same block, or remove the lifetime.end marker
1988 // otherwise. This might pessimize further optimizations by extending the
1989 // lifetime of the object until the function returns, but should be
1990 // conservatively correct.
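// Schematic example of the situation handled below (pseudo-IR; operand lists
// are abbreviated and are not the actual intrinsic signatures):
//
//   %a = load <8 x double>, ptr %A              ; LoadOp0
//   call void @llvm.lifetime.end(... ptr %A)    ; ends %A's lifetime early
//   store <4 x double> %v, ptr %C               ; Store (fusion point)
//
// The fused tiles re-load from %A next to the store, so a conflicting
// lifetime.end in the same block is sunk after the store, while one in a
// different block is dropped instead.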
1991 MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0); 1992 MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1); 1993 BasicBlock *StoreParent = Store->getParent(); 1994 bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent && 1995 LoadOp1->getParent() == StoreParent; 1996 for (unsigned Idx = 0; Idx != LifetimeEnds.size();) { 1997 IntrinsicInst *End = LifetimeEnds[Idx]; 1998 auto Inc = make_scope_exit([&Idx]() { Idx++; }); 1999 // If the lifetime.end is guaranteed to be before the loads or after the 2000 // store, it won't interfere with fusion. 2001 if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1)) 2002 continue; 2003 if (DT->dominates(Store, End)) 2004 continue; 2005 // If all fusable ops are in the same block and the lifetime.end is in a 2006 // different block, it won't interfere with fusion. 2007 if (FusableOpsInSameBlock && End->getParent() != StoreParent) 2008 continue; 2009 2010 // If the loads don't alias the lifetime.end, it won't interfere with 2011 // fusion. 2012 MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr); 2013 if (!EndLoc.Ptr) 2014 continue; 2015 if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc)) 2016 continue; 2017 2018 // If both lifetime.end and the store are in the same block, extend the 2019 // lifetime until after the store, so the new lifetime covers the loads 2020 // we introduce later. 2021 if (End->getParent() == StoreParent) { 2022 End->moveAfter(Store); 2023 continue; 2024 } 2025 2026 // Otherwise remove the conflicting lifetime.end marker. 2027 ToRemove.push_back(End); 2028 std::swap(LifetimeEnds[Idx], LifetimeEnds.back()); 2029 LifetimeEnds.pop_back(); 2030 Inc.release(); 2031 } 2032 2033 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); 2034 return; 2035 } 2036 } 2037 2038 /// Lowers llvm.matrix.multiply. 2039 void LowerMultiply(CallInst *MatMul) { 2040 IRBuilder<> Builder(MatMul); 2041 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 2042 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 2043 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 2044 2045 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 2046 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 2047 assert(Lhs.getElementType() == Rhs.getElementType() && 2048 "Matrix multiply argument element types do not match."); 2049 2050 const unsigned R = LShape.NumRows; 2051 const unsigned C = RShape.NumColumns; 2052 assert(LShape.NumColumns == RShape.NumRows); 2053 2054 // Initialize the output 2055 MatrixTy Result(R, C, EltType); 2056 assert(Lhs.getElementType() == Result.getElementType() && 2057 "Matrix multiply result element type does not match arguments."); 2058 2059 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false, 2060 getFastMathFlags(MatMul)); 2061 finalizeLowering(MatMul, Result, Builder); 2062 } 2063 2064 /// Lowers llvm.matrix.transpose. 2065 void LowerTranspose(CallInst *Inst) { 2066 MatrixTy Result; 2067 IRBuilder<> Builder(Inst); 2068 Value *InputVal = Inst->getArgOperand(0); 2069 VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 2070 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 2071 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 2072 2073 const unsigned NewNumVecs = 2074 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; 2075 const unsigned NewNumElts = 2076 InputMatrix.isColumnMajor() ? 
ArgShape.NumColumns : ArgShape.NumRows; 2077 2078 for (unsigned I = 0; I < NewNumVecs; ++I) { 2079 // Build a single result vector. First initialize it. 2080 Value *ResultVector = PoisonValue::get( 2081 FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); 2082 // Go through the old elements and insert it into the resulting vector. 2083 for (auto J : enumerate(InputMatrix.vectors())) { 2084 Value *Elt = Builder.CreateExtractElement(J.value(), I); 2085 // Row and column indices are transposed. 2086 ResultVector = 2087 Builder.CreateInsertElement(ResultVector, Elt, J.index()); 2088 } 2089 Result.addVector(ResultVector); 2090 } 2091 2092 // TODO: Improve estimate of operations needed for transposes. Currently we 2093 // just count the insertelement/extractelement instructions, but do not 2094 // account for later simplifications/combines. 2095 finalizeLowering( 2096 Inst, 2097 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns) 2098 .addNumExposedTransposes(1), 2099 Builder); 2100 } 2101 2102 /// Lower load instructions, if shape information is available. 2103 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { 2104 auto I = ShapeMap.find(Inst); 2105 if (I == ShapeMap.end()) 2106 return false; 2107 2108 LowerLoad(Inst, Ptr, Inst->getAlign(), 2109 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 2110 I->second); 2111 return true; 2112 } 2113 2114 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, 2115 IRBuilder<> &Builder) { 2116 auto I = ShapeMap.find(StoredVal); 2117 if (I == ShapeMap.end()) 2118 return false; 2119 2120 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), 2121 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 2122 I->second); 2123 return true; 2124 } 2125 2126 /// Lower binary operators, if shape information is available. 2127 bool VisitBinaryOperator(BinaryOperator *Inst) { 2128 auto I = ShapeMap.find(Inst); 2129 if (I == ShapeMap.end()) 2130 return false; 2131 2132 Value *Lhs = Inst->getOperand(0); 2133 Value *Rhs = Inst->getOperand(1); 2134 2135 IRBuilder<> Builder(Inst); 2136 ShapeInfo &Shape = I->second; 2137 2138 MatrixTy Result; 2139 MatrixTy A = getMatrix(Lhs, Shape, Builder); 2140 MatrixTy B = getMatrix(Rhs, Shape, Builder); 2141 assert(A.isColumnMajor() == B.isColumnMajor() && 2142 Result.isColumnMajor() == A.isColumnMajor() && 2143 "operands must agree on matrix layout"); 2144 2145 Builder.setFastMathFlags(getFastMathFlags(Inst)); 2146 2147 // Helper to perform binary op on vectors. 2148 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) { 2149 switch (Inst->getOpcode()) { 2150 case Instruction::Add: 2151 return Builder.CreateAdd(LHS, RHS); 2152 case Instruction::Mul: 2153 return Builder.CreateMul(LHS, RHS); 2154 case Instruction::Sub: 2155 return Builder.CreateSub(LHS, RHS); 2156 case Instruction::FAdd: 2157 return Builder.CreateFAdd(LHS, RHS); 2158 case Instruction::FMul: 2159 return Builder.CreateFMul(LHS, RHS); 2160 case Instruction::FSub: 2161 return Builder.CreateFSub(LHS, RHS); 2162 default: 2163 llvm_unreachable("Unsupported binary operator for matrix"); 2164 } 2165 }; 2166 2167 for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 2168 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I))); 2169 2170 finalizeLowering(Inst, 2171 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 2172 Result.getNumVectors()), 2173 Builder); 2174 return true; 2175 } 2176 2177 /// Lower unary operators, if shape information is available. 
2178 bool VisitUnaryOperator(UnaryOperator *Inst) { 2179 auto I = ShapeMap.find(Inst); 2180 if (I == ShapeMap.end()) 2181 return false; 2182 2183 Value *Op = Inst->getOperand(0); 2184 2185 IRBuilder<> Builder(Inst); 2186 ShapeInfo &Shape = I->second; 2187 2188 MatrixTy Result; 2189 MatrixTy M = getMatrix(Op, Shape, Builder); 2190 2191 Builder.setFastMathFlags(getFastMathFlags(Inst)); 2192 2193 // Helper to perform unary op on vectors. 2194 auto BuildVectorOp = [&Builder, Inst](Value *Op) { 2195 switch (Inst->getOpcode()) { 2196 case Instruction::FNeg: 2197 return Builder.CreateFNeg(Op); 2198 default: 2199 llvm_unreachable("Unsupported unary operator for matrix"); 2200 } 2201 }; 2202 2203 for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 2204 Result.addVector(BuildVectorOp(M.getVector(I))); 2205 2206 finalizeLowering(Inst, 2207 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 2208 Result.getNumVectors()), 2209 Builder); 2210 return true; 2211 } 2212 2213 /// Helper to linearize a matrix expression tree into a string. Currently 2214 /// matrix expressions are linarized by starting at an expression leaf and 2215 /// linearizing bottom up. 2216 struct ExprLinearizer { 2217 unsigned LengthToBreak = 100; 2218 std::string Str; 2219 raw_string_ostream Stream; 2220 unsigned LineLength = 0; 2221 const DataLayout &DL; 2222 2223 /// Mapping from instructions to matrixes. It is used to identify 2224 /// matrix instructions. 2225 const MapVector<Value *, MatrixTy> &Inst2Matrix; 2226 2227 /// Mapping from values to the leaves of all expressions that the value is 2228 /// part of. 2229 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; 2230 2231 /// Set of matrix expressions in the scope of a given DISubprogram. 2232 const SmallSetVector<Value *, 32> &ExprsInSubprogram; 2233 2234 /// Leaf node of the expression to linearize. 2235 Value *Leaf; 2236 2237 /// Used to keep track of sub-expressions that get reused while linearizing 2238 /// the expression. Re-used sub-expressions are marked as (reused). 2239 SmallPtrSet<Value *, 8> ReusedExprs; 2240 2241 ExprLinearizer(const DataLayout &DL, 2242 const MapVector<Value *, MatrixTy> &Inst2Matrix, 2243 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 2244 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 2245 Value *Leaf) 2246 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared), 2247 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} 2248 2249 void indent(unsigned N) { 2250 LineLength += N; 2251 for (unsigned i = 0; i < N; i++) 2252 Stream << " "; 2253 } 2254 2255 void lineBreak() { 2256 Stream << "\n"; 2257 LineLength = 0; 2258 } 2259 2260 void maybeIndent(unsigned Indent) { 2261 if (LineLength >= LengthToBreak) 2262 lineBreak(); 2263 2264 if (LineLength == 0) 2265 indent(Indent); 2266 } 2267 2268 void write(StringRef S) { 2269 LineLength += S.size(); 2270 Stream << S; 2271 } 2272 2273 Value *getUnderlyingObjectThroughLoads(Value *V) { 2274 if (Value *Ptr = getPointerOperand(V)) 2275 return getUnderlyingObjectThroughLoads(Ptr); 2276 else if (V->getType()->isPointerTy()) 2277 return getUnderlyingObject(V); 2278 return V; 2279 } 2280 2281 /// Returns true if \p V is a matrix value in the given subprogram. 2282 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } 2283 2284 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to 2285 /// \p SS. 
2286 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { 2287 auto M = Inst2Matrix.find(V); 2288 if (M == Inst2Matrix.end()) 2289 SS << "unknown"; 2290 else { 2291 SS << M->second.getNumRows(); 2292 SS << "x"; 2293 SS << M->second.getNumColumns(); 2294 } 2295 } 2296 2297 /// Write the called function name. Handles calls to llvm.matrix.* 2298 /// specially: we write the name, followed by the dimensions of the input 2299 /// matrixes, followed by the scalar type name. 2300 void writeFnName(CallInst *CI) { 2301 if (!CI->getCalledFunction()) 2302 write("<no called fn>"); 2303 else { 2304 StringRef Name = CI->getCalledFunction()->getName(); 2305 if (!Name.starts_with("llvm.matrix")) { 2306 write(Name); 2307 return; 2308 } 2309 auto *II = cast<IntrinsicInst>(CI); 2310 write(Intrinsic::getBaseName(II->getIntrinsicID()) 2311 .drop_front(StringRef("llvm.matrix.").size())); 2312 write("."); 2313 std::string Tmp; 2314 raw_string_ostream SS(Tmp); 2315 2316 switch (II->getIntrinsicID()) { 2317 case Intrinsic::matrix_multiply: 2318 prettyPrintMatrixType(II->getOperand(0), SS); 2319 SS << "."; 2320 prettyPrintMatrixType(II->getOperand(1), SS); 2321 SS << "." << *II->getType()->getScalarType(); 2322 break; 2323 case Intrinsic::matrix_transpose: 2324 prettyPrintMatrixType(II->getOperand(0), SS); 2325 SS << "." << *II->getType()->getScalarType(); 2326 break; 2327 case Intrinsic::matrix_column_major_load: 2328 prettyPrintMatrixType(II, SS); 2329 SS << "." << *II->getType()->getScalarType(); 2330 break; 2331 case Intrinsic::matrix_column_major_store: 2332 prettyPrintMatrixType(II->getOperand(0), SS); 2333 SS << "." << *II->getOperand(0)->getType()->getScalarType(); 2334 break; 2335 default: 2336 llvm_unreachable("Unhandled case"); 2337 } 2338 write(Tmp); 2339 } 2340 } 2341 2342 unsigned getNumShapeArgs(CallInst *CI) const { 2343 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { 2344 switch (II->getIntrinsicID()) { 2345 case Intrinsic::matrix_multiply: 2346 return 3; 2347 case Intrinsic::matrix_transpose: 2348 return 2; 2349 case Intrinsic::matrix_column_major_load: 2350 case Intrinsic::matrix_column_major_store: 2351 return 3; 2352 default: 2353 return 0; 2354 } 2355 } 2356 return 0; 2357 } 2358 2359 /// Special printing for values: for pointers, we print if they refer to an 2360 /// (function) external address or a stack address, for other values we 2361 /// either print the constant or "scalar"/"matrix" for other values. 2362 void write(Value *V) { 2363 V = getUnderlyingObjectThroughLoads(V); 2364 if (V->getType()->isPointerTy()) { 2365 if (isa<AllocaInst>(V)) { 2366 Stream << "stack addr"; 2367 LineLength += StringRef("stack addr").size(); 2368 } else { 2369 Stream << "addr"; 2370 LineLength += StringRef("addr").size(); 2371 } 2372 if (!V->getName().empty()) { 2373 Stream << " %" << V->getName() << ""; 2374 LineLength += V->getName().size() + 2; 2375 } 2376 return; 2377 } 2378 2379 std::string Tmp; 2380 raw_string_ostream TmpStream(Tmp); 2381 2382 if (auto *CI = dyn_cast<ConstantInt>(V)) 2383 TmpStream << CI->getValue(); 2384 else if (isa<Constant>(V)) 2385 TmpStream << "constant"; 2386 else { 2387 if (isMatrix(V)) 2388 TmpStream << "matrix"; 2389 else 2390 TmpStream << "scalar"; 2391 } 2392 Tmp = std::string(StringRef(Tmp).trim()); 2393 LineLength += Tmp.size(); 2394 Stream << Tmp; 2395 } 2396 2397 /// Linearize expression \p Expr starting at an indentation of \p Indent. 
2398 /// Expressions that are re-used multiple times are prefixed with (reused) 2399 /// at the re-used root instruction. 2400 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused, 2401 bool ParentShared) { 2402 auto *I = cast<Instruction>(Expr); 2403 maybeIndent(Indent); 2404 SmallVector<Value *, 8> Ops; 2405 2406 // Is Expr shared with other expression leaves? 2407 bool ExprShared = false; 2408 2409 // Deal with shared subtrees. Mark them as shared, if required. 2410 if (!ParentShared) { 2411 auto SI = Shared.find(Expr); 2412 assert(SI != Shared.end() && SI->second.count(Leaf)); 2413 2414 for (Value *S : SI->second) { 2415 if (S == Leaf) 2416 continue; 2417 DebugLoc DL = cast<Instruction>(S)->getDebugLoc(); 2418 write("shared with remark at line " + std::to_string(DL.getLine()) + 2419 " column " + std::to_string(DL.getCol()) + " ("); 2420 } 2421 ExprShared = SI->second.size() > 1; 2422 } 2423 2424 bool Reused = !ReusedExprs.insert(Expr).second; 2425 if (Reused && !ParentReused) 2426 write("(reused) "); 2427 2428 if (auto *CI = dyn_cast<CallInst>(I)) { 2429 writeFnName(CI); 2430 2431 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI)); 2432 } else if (isa<BitCastInst>(Expr)) { 2433 // Special case bitcasts, which are used to materialize matrixes from 2434 // non-matrix ops. 2435 write("matrix"); 2436 return; 2437 } else { 2438 Ops.append(I->value_op_begin(), I->value_op_end()); 2439 write(std::string(I->getOpcodeName())); 2440 } 2441 2442 write(std::string("(")); 2443 2444 unsigned NumOpsToBreak = 1; 2445 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>())) 2446 NumOpsToBreak = 2; 2447 2448 for (Value *Op : Ops) { 2449 if (Ops.size() > NumOpsToBreak) 2450 lineBreak(); 2451 2452 maybeIndent(Indent + 1); 2453 if (isMatrix(Op)) 2454 linearizeExpr(Op, Indent + 1, Reused, ExprShared); 2455 else 2456 write(Op); 2457 if (Op != Ops.back()) 2458 write(", "); 2459 } 2460 2461 write(")"); 2462 } 2463 2464 const std::string &getResult() { 2465 return Str; 2466 } 2467 }; 2468 2469 /// Generate remarks for matrix operations in a function. To generate remarks 2470 /// for matrix expressions, the following approach is used: 2471 /// 1. Use the inlined-at debug information to group matrix operations to the 2472 /// DISubprograms they are contained in. 2473 /// 2. Collect leaves of matrix expressions (done in 2474 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression 2475 // mapping. Leaves are lowered matrix instructions without other matrix 2476 // users (like stores) in the current subprogram. 2477 /// 3. For each leaf, create a remark containing a linearizied version of the 2478 /// matrix expression. The expression is linearized by a recursive 2479 /// bottom-up traversal of the matrix operands, starting at a leaf. Note 2480 /// that multiple leaves can share sub-expressions. Shared subexpressions 2481 /// are explicitly marked as shared(). 2482 struct RemarkGenerator { 2483 const MapVector<Value *, MatrixTy> &Inst2Matrix; 2484 OptimizationRemarkEmitter &ORE; 2485 Function &Func; 2486 const DataLayout &DL; 2487 2488 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix, 2489 OptimizationRemarkEmitter &ORE, Function &Func) 2490 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func), 2491 DL(Func.getDataLayout()) {} 2492 2493 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are 2494 /// instructions in Inst2Matrix returning void or without any users in 2495 /// \p ExprsInSubprogram. 
Currently that should only include stores.
2496 SmallVector<Value *, 4>
2497 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2498 SmallVector<Value *, 4> Leaves;
2499 for (auto *Expr : ExprsInSubprogram)
2500 if (Expr->getType()->isVoidTy() ||
2501 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2502 return ExprsInSubprogram.count(U);
2503 }))
2504 Leaves.push_back(Expr);
2505 return Leaves;
2506 }
2507
2508 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2509 /// to all visited expressions in \p Shared. Limit the matrix operations to
2510 /// the ones in \p ExprsInSubprogram.
2511 void collectSharedInfo(Value *Leaf, Value *V,
2512 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2513 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2514
2515 if (!ExprsInSubprogram.count(V))
2516 return;
2517
2518 Shared[V].insert(Leaf);
2519
2520 for (Value *Op : cast<Instruction>(V)->operand_values())
2521 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2522 }
2523
2524 /// Calculate the exclusive and shared op counts for the expression starting
2525 /// at \p Root. Expressions used multiple times are counted once.
2526 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2527 std::pair<OpInfoTy, OpInfoTy>
2528 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2529 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2530 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2531 if (!ExprsInSubprogram.count(Root))
2532 return {};
2533
2534 // Already counted this expression. Stop.
2535 if (!ReusedExprs.insert(Root).second)
2536 return {};
2537
2538 OpInfoTy SharedCount;
2539 OpInfoTy Count;
2540
2541 auto I = Shared.find(Root);
2542 auto CM = Inst2Matrix.find(Root);
2543 if (I->second.size() == 1)
2544 Count = CM->second.getOpInfo();
2545 else
2546 SharedCount = CM->second.getOpInfo();
2547
2548 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2549 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2550 Count += C.first;
2551 SharedCount += C.second;
2552 }
2553 return {Count, SharedCount};
2554 }
2555
2556 void emitRemarks() {
2557 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2558 return;
2559
2560 // Map matrix operations to their containing subprograms by traversing
2561 // the inlinedAt chain. If the function does not have a DISubprogram, we
2562 // only map them to the containing function.
2563 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2564 for (const auto &KV : Inst2Matrix) {
2565 if (Func.getSubprogram()) {
2566 auto *I = cast<Instruction>(KV.first);
2567 DILocation *Context = I->getDebugLoc();
2568 while (Context) {
2569 Subprog2Exprs[getSubprogram(Context->getScope())].push_back(
2570 KV.first);
2571 Context = DebugLoc(Context).getInlinedAt();
2572 }
2573 } else {
2574 Subprog2Exprs[nullptr].push_back(KV.first);
2575 }
2576 }
2577 for (auto &KV : Subprog2Exprs) {
2578 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2579 KV.second.end());
2580 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2581
2582 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2583 for (Value *Leaf : Leaves)
2584 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2585
2586 // Generate remarks for each leaf.
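// The emitted remark looks roughly like the following (the location, counts
// and linearized expression are illustrative only):
//
//   remark: matmul.c:12:10: Lowered with 6 stores, 6 loads, 24 compute ops,
//   0 exposed transposes
//   store(
//    multiply.2x6.6x2.double(
//     load(addr %A),
//     load(addr %B)),
//    addr %C)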
2587 for (auto *L : Leaves) { 2588 2589 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc(); 2590 DILocation *Context = cast<Instruction>(L)->getDebugLoc(); 2591 while (Context) { 2592 if (getSubprogram(Context->getScope()) == KV.first) { 2593 Loc = Context; 2594 break; 2595 } 2596 Context = DebugLoc(Context).getInlinedAt(); 2597 } 2598 2599 SmallPtrSet<Value *, 8> ReusedExprs; 2600 OpInfoTy Counts, SharedCounts; 2601 std::tie(Counts, SharedCounts) = 2602 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared); 2603 2604 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc, 2605 cast<Instruction>(L)->getParent()); 2606 2607 Rem << "Lowered with "; 2608 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, " 2609 << ore::NV("NumLoads", Counts.NumLoads) << " loads, " 2610 << ore::NV("NumComputeOps", Counts.NumComputeOps) 2611 << " compute ops, " 2612 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes) 2613 << " exposed transposes"; 2614 2615 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 || 2616 SharedCounts.NumComputeOps > 0) { 2617 Rem << ",\nadditionally " 2618 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, " 2619 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, " 2620 << ore::NV("NumFPOps", SharedCounts.NumComputeOps) 2621 << " compute ops" 2622 << " are shared with other expressions"; 2623 } 2624 2625 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL)); 2626 ORE.emit(Rem); 2627 } 2628 } 2629 } 2630 2631 std::string 2632 linearize(Value *L, 2633 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 2634 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 2635 const DataLayout &DL) { 2636 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L); 2637 Lin.linearizeExpr(L, 0, false, false); 2638 return Lin.getResult(); 2639 } 2640 }; 2641 }; 2642 } // namespace 2643 2644 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, 2645 FunctionAnalysisManager &AM) { 2646 auto &TTI = AM.getResult<TargetIRAnalysis>(F); 2647 2648 LowerMatrixIntrinsics LMT(F, TTI, Minimal ? nullptr : &AM); 2649 if (LMT.Visit()) { 2650 PreservedAnalyses PA; 2651 if (!Minimal) { 2652 PA.preserve<LoopAnalysis>(); 2653 PA.preserve<DominatorTreeAnalysis>(); 2654 } 2655 return PA; 2656 } 2657 return PreservedAnalyses::all(); 2658 } 2659 2660 void LowerMatrixIntrinsicsPass::printPipeline( 2661 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { 2662 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline( 2663 OS, MapClassName2PassName); 2664 OS << '<'; 2665 if (Minimal) 2666 OS << "minimal"; 2667 OS << '>'; 2668 } 2669
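// Pipeline usage note (illustrative): with the new pass manager the pass is
// requested by the name printed above, e.g.
//
//   opt -passes='lower-matrix-intrinsics' -S in.ll -o out.ll
//
// and the variant constructed with Minimal == true round-trips textually as
// 'lower-matrix-intrinsics<minimal>'.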