1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Lower matrix intrinsics to vector operations. 10 // 11 // TODO: 12 // * Improve fusion: 13 // * Support more cases, e.g. multiply-add, multiply-sub, operands/results 14 // transposed. 15 // * Improve cost-modeling, e.g. choose different number of rows/columns 16 // columns for tiles, consider cost of copies on alias. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" 21 #include "llvm/ADT/GraphTraits.h" 22 #include "llvm/ADT/PostOrderIterator.h" 23 #include "llvm/ADT/SmallVector.h" 24 #include "llvm/Analysis/AliasAnalysis.h" 25 #include "llvm/Analysis/DomTreeUpdater.h" 26 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 27 #include "llvm/Analysis/TargetTransformInfo.h" 28 #include "llvm/Analysis/ValueTracking.h" 29 #include "llvm/Analysis/VectorUtils.h" 30 #include "llvm/IR/CFG.h" 31 #include "llvm/IR/DataLayout.h" 32 #include "llvm/IR/DebugInfoMetadata.h" 33 #include "llvm/IR/Function.h" 34 #include "llvm/IR/IRBuilder.h" 35 #include "llvm/IR/Instructions.h" 36 #include "llvm/IR/IntrinsicInst.h" 37 #include "llvm/IR/PatternMatch.h" 38 #include "llvm/InitializePasses.h" 39 #include "llvm/Pass.h" 40 #include "llvm/Support/Alignment.h" 41 #include "llvm/Support/CommandLine.h" 42 #include "llvm/Support/Debug.h" 43 #include "llvm/Transforms/Scalar.h" 44 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 45 #include "llvm/Transforms/Utils/LoopUtils.h" 46 #include "llvm/Transforms/Utils/MatrixUtils.h" 47 48 using namespace llvm; 49 using namespace PatternMatch; 50 51 #define DEBUG_TYPE "lower-matrix-intrinsics" 52 53 static cl::opt<bool> 54 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, 55 cl::desc("Enable/disable fusing matrix instructions.")); 56 // TODO: Allow and use non-square tiles. 57 static cl::opt<unsigned> TileSize( 58 "fuse-matrix-tile-size", cl::init(4), cl::Hidden, 59 cl::desc( 60 "Tile size for matrix instruction fusion using square-shaped tiles.")); 61 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false), 62 cl::Hidden, 63 cl::desc("Generate loop nest for tiling.")); 64 static cl::opt<bool> ForceFusion( 65 "force-fuse-matrix", cl::init(false), cl::Hidden, 66 cl::desc("Force matrix instruction fusion even if not profitable.")); 67 static cl::opt<bool> AllowContractEnabled( 68 "matrix-allow-contract", cl::init(false), cl::Hidden, 69 cl::desc("Allow the use of FMAs if available and profitable. This may " 70 "result in different results, due to less rounding error.")); 71 72 enum class MatrixLayoutTy { ColumnMajor, RowMajor }; 73 74 static cl::opt<MatrixLayoutTy> MatrixLayout( 75 "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), 76 cl::desc("Sets the default matrix layout"), 77 cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", 78 "Use column-major layout"), 79 clEnumValN(MatrixLayoutTy::RowMajor, "row-major", 80 "Use row-major layout"))); 81 82 /// Helper function to either return Scope, if it is a subprogram or the 83 /// attached subprogram for a local scope. 
84 static DISubprogram *getSubprogram(DIScope *Scope) { 85 if (auto *Subprogram = dyn_cast<DISubprogram>(Scope)) 86 return Subprogram; 87 return cast<DILocalScope>(Scope)->getSubprogram(); 88 } 89 90 namespace { 91 92 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute 93 // the start address of vector \p VecIdx with type (\p EltType x \p NumElements) 94 // assuming \p Stride elements between the starts of two consecutive vectors. 95 // \p Stride must be >= \p NumElements. 96 // For column-major matrixes, the function computes the address of a column 97 // vector and \p NumElements must be set to the number of elements in a column 98 // (= number of rows of the matrix). For row-major matrixes, the function 99 // computes the address of a row vector and \p NumElements must be set to the 100 // number of elements in a row (= number of columns of the matrix). 101 // 102 // Consider a 4x4 matrix in column-major layout like below 103 // 104 // 0 1 2 3 105 // 0 v_0_0 v_0_1 v_0_2 v_0_3 106 // 1 v_1_0 v_1_1 v_1_2 v_1_3 107 // 2 v_2_0 v_2_1 v_2_2 v_2_3 108 // 3 v_3_0 v_3_1 v_3_2 v_3_3 109 110 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1, 111 // we need a pointer to the first element of the submatrix as base pointer. 112 // Then we can use computeVectorAddr to compute the addresses for the columns 113 // of the sub-matrix. 114 // 115 // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..) 116 // -> just returns Base 117 // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..) 118 // -> returns Base + (1 * 4) 119 // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..) 120 // -> returns Base + (2 * 4) 121 // 122 // The graphic below illustrates the number of elements in a column (marked 123 // with |) and the number of skipped elements (marked with {). 124 // 125 // v_0_0 v_0_1 {v_0_2 {v_0_3 126 // Base Col 1 Col 2 127 // | | | 128 // v_1_0 |v_1_1 |v_1_2 |v_1_3 129 // v_2_0 |v_2_1 |v_2_2 |v_2_3 130 // v_3_0 {v_3_1 {v_3_2 v_3_3 131 // 132 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, 133 unsigned NumElements, Type *EltType, 134 IRBuilder<> &Builder) { 135 136 assert((!isa<ConstantInt>(Stride) || 137 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && 138 "Stride must be >= the number of elements in the result vector."); 139 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 140 141 // Compute the start of the vector with index VecIdx as VecIdx * Stride. 142 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); 143 144 // Get pointer to the start of the selected vector. Skip GEP creation, 145 // if we select vector 0. 146 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) 147 VecStart = BasePtr; 148 else 149 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); 150 151 // Cast elementwise vector start pointer to a pointer to a vector 152 // (EltType x NumElements)*. 153 auto *VecType = FixedVectorType::get(EltType, NumElements); 154 Type *VecPtrType = PointerType::get(VecType, AS); 155 return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast"); 156 } 157 158 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 159 /// 160 /// Currently, the lowering for each matrix intrinsic is done as follows: 161 /// 1. Propagate the shape information from intrinsics to connected 162 /// instructions. 163 /// 2.
Lower instructions with shape information (assuming column-major layout). 164 /// The lowering works similarly using row-major layout. 165 /// 2.1. Get column vectors for each argument. If we already lowered the 166 /// definition of an argument, use the produced column vectors directly. 167 /// If not, split the operand vector containing an embedded matrix into 168 /// a set of column vectors, 169 /// 2.2. Lower the instruction in terms of column major operations, which 170 /// yields a set of column vectors containing result matrix. Note that we 171 /// lower all instructions that have shape information. Besides the 172 /// intrinsics, this includes stores for example. 173 /// 2.3. Update uses of the lowered instruction. If we have shape information 174 /// for a user, there is nothing to do, as we will look up the result 175 /// column matrix when lowering the user. For other uses, we embed the 176 /// result matrix in a flat vector and update the use. 177 /// 2.4. Cache the result column matrix for the instruction we lowered 178 /// 3. After we lowered all instructions in a function, remove the now 179 /// obsolete instructions. 180 /// 181 class LowerMatrixIntrinsics { 182 Function &Func; 183 const DataLayout &DL; 184 const TargetTransformInfo &TTI; 185 AliasAnalysis *AA; 186 DominatorTree *DT; 187 LoopInfo *LI; 188 OptimizationRemarkEmitter *ORE; 189 190 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 191 struct OpInfoTy { 192 /// Number of stores emitted to generate this matrix. 193 unsigned NumStores = 0; 194 /// Number of loads emitted to generate this matrix. 195 unsigned NumLoads = 0; 196 /// Number of compute operations emitted to generate this matrix. 197 unsigned NumComputeOps = 0; 198 199 OpInfoTy &operator+=(const OpInfoTy &RHS) { 200 NumStores += RHS.NumStores; 201 NumLoads += RHS.NumLoads; 202 NumComputeOps += RHS.NumComputeOps; 203 return *this; 204 } 205 }; 206 207 /// Wrapper class representing a matrix as a set of vectors, either in row or 208 /// column major layout. All vectors must have the same vector type. 209 class MatrixTy { 210 SmallVector<Value *, 16> Vectors; 211 212 OpInfoTy OpInfo; 213 214 bool IsColumnMajor = true; 215 216 public: 217 MatrixTy() 218 : Vectors(), 219 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 220 MatrixTy(ArrayRef<Value *> Vectors) 221 : Vectors(Vectors.begin(), Vectors.end()), 222 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 223 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 224 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 225 226 unsigned D = isColumnMajor() ? NumColumns : NumRows; 227 for (unsigned J = 0; J < D; ++J) 228 addVector(UndefValue::get(FixedVectorType::get( 229 EltTy, isColumnMajor() ? 
NumRows : NumColumns))); 230 } 231 232 Value *getVector(unsigned i) const { return Vectors[i]; } 233 Value *getColumn(unsigned i) const { 234 assert(isColumnMajor() && "only supported for column-major matrixes"); 235 return Vectors[i]; 236 } 237 Value *getRow(unsigned i) const { 238 assert(!isColumnMajor() && "only supported for row-major matrixes"); 239 return Vectors[i]; 240 } 241 242 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 243 244 Type *getElementType() const { return getVectorTy()->getElementType(); } 245 246 unsigned getNumVectors() const { 247 if (isColumnMajor()) 248 return getNumColumns(); 249 return getNumRows(); 250 } 251 252 unsigned getNumColumns() const { 253 if (isColumnMajor()) 254 return Vectors.size(); 255 else { 256 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 257 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 258 } 259 } 260 unsigned getNumRows() const { 261 if (isColumnMajor()) { 262 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 263 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 264 } else 265 return Vectors.size(); 266 } 267 268 void addVector(Value *V) { Vectors.push_back(V); } 269 VectorType *getColumnTy() { 270 assert(isColumnMajor() && "only supported for column-major matrixes"); 271 return getVectorTy(); 272 } 273 274 VectorType *getVectorTy() const { 275 return cast<VectorType>(Vectors[0]->getType()); 276 } 277 278 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 279 assert(isColumnMajor() && 280 "columns() only supported for column-major matrixes"); 281 return make_range(Vectors.begin(), Vectors.end()); 282 } 283 284 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 285 return make_range(Vectors.begin(), Vectors.end()); 286 } 287 288 /// Embed the vectors of the matrix into a flat vector by concatenating 289 /// them. 290 Value *embedInVector(IRBuilder<> &Builder) const { 291 return Vectors.size() == 1 ? Vectors[0] 292 : concatenateVectors(Builder, Vectors); 293 } 294 295 MatrixTy &addNumLoads(unsigned N) { 296 OpInfo.NumLoads += N; 297 return *this; 298 } 299 300 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 301 302 MatrixTy &addNumStores(unsigned N) { 303 OpInfo.NumStores += N; 304 return *this; 305 } 306 307 MatrixTy &addNumComputeOps(unsigned N) { 308 OpInfo.NumComputeOps += N; 309 return *this; 310 } 311 312 unsigned getNumStores() const { return OpInfo.NumStores; } 313 unsigned getNumLoads() const { return OpInfo.NumLoads; } 314 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 315 316 const OpInfoTy &getOpInfo() const { return OpInfo; } 317 318 bool isColumnMajor() const { return IsColumnMajor; } 319 320 unsigned getStride() const { 321 if (isColumnMajor()) 322 return getNumRows(); 323 return getNumColumns(); 324 } 325 326 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 327 /// matrix is column-major, the result vector is extracted from a column 328 /// vector, otherwise from a row vector. 329 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 330 IRBuilder<> &Builder) const { 331 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 332 return Builder.CreateShuffleVector( 333 Vec, createSequentialMask(isColumnMajor() ? 
I : J, NumElts, 0), 334 "block"); 335 } 336 }; 337 338 struct ShapeInfo { 339 unsigned NumRows; 340 unsigned NumColumns; 341 342 bool IsColumnMajor; 343 344 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 345 : NumRows(NumRows), NumColumns(NumColumns), 346 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 347 348 ShapeInfo(Value *NumRows, Value *NumColumns) 349 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 350 cast<ConstantInt>(NumColumns)->getZExtValue()) {} 351 352 bool operator==(const ShapeInfo &other) { 353 return NumRows == other.NumRows && NumColumns == other.NumColumns; 354 } 355 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 356 357 /// Returns true if shape-information is defined, meaning both dimensions 358 /// are != 0. 359 operator bool() const { 360 assert(NumRows == 0 || NumColumns != 0); 361 return NumRows != 0; 362 } 363 364 unsigned getStride() const { 365 if (IsColumnMajor) 366 return NumRows; 367 return NumColumns; 368 } 369 370 unsigned getNumVectors() const { 371 if (IsColumnMajor) 372 return NumColumns; 373 return NumRows; 374 } 375 }; 376 377 /// Maps instructions to their shape information. The shape information 378 /// describes the shape to be used while lowering. This matches the shape of 379 /// the result value of the instruction, with the only exceptions being store 380 /// instructions and the matrix_column_major_store intrinsics. For those, the 381 /// shape information indicates that those instructions should be lowered 382 /// using shape information as well. 383 DenseMap<Value *, ShapeInfo> ShapeMap; 384 385 /// List of instructions to remove. While lowering, we are not replacing all 386 /// users of a lowered instruction, if shape information is available and 387 /// those need to be removed after we finished lowering. 388 SmallVector<Instruction *, 16> ToRemove; 389 390 /// Map from instructions to their produced column matrix. 391 MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 392 393 public: 394 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 395 AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, 396 OptimizationRemarkEmitter *ORE) 397 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), 398 LI(LI), ORE(ORE) {} 399 400 unsigned getNumOps(Type *VT) { 401 assert(isa<VectorType>(VT) && "Expected vector type"); 402 return getNumOps(VT->getScalarType(), 403 cast<FixedVectorType>(VT)->getNumElements()); 404 } 405 406 // 407 /// Return the estimated number of vector ops required for an operation on 408 /// \p VT * N. 409 unsigned getNumOps(Type *ST, unsigned N) { 410 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / 411 double(TTI.getRegisterBitWidth( 412 TargetTransformInfo::RGK_FixedWidthVector) 413 .getFixedSize())); 414 } 415 416 /// Return the set of vectors that a matrix value is lowered to. 417 /// 418 /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 419 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 420 /// into vectors. 421 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 422 IRBuilder<> &Builder) { 423 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 424 assert(VType && "MatrixVal must be a vector type"); 425 assert(cast<FixedVectorType>(VType)->getNumElements() == 426 SI.NumRows * SI.NumColumns && 427 "The vector size must match the number of matrix elements"); 428 429 // Check if we lowered MatrixVal using shape information. 
In that case, 430 // return the existing matrix, if it matches the requested shape 431 // information. If there is a mis-match, embed the result in a flat 432 // vector and split it later. 433 auto Found = Inst2ColumnMatrix.find(MatrixVal); 434 if (Found != Inst2ColumnMatrix.end()) { 435 MatrixTy &M = Found->second; 436 // Return the found matrix, if its shape matches the requested shape 437 // information 438 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 439 return M; 440 441 MatrixVal = M.embedInVector(Builder); 442 } 443 444 // Otherwise split MatrixVal. 445 SmallVector<Value *, 16> SplitVecs; 446 for (unsigned MaskStart = 0; 447 MaskStart < cast<FixedVectorType>(VType)->getNumElements(); 448 MaskStart += SI.getStride()) { 449 Value *V = Builder.CreateShuffleVector( 450 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), 451 "split"); 452 SplitVecs.push_back(V); 453 } 454 455 return {SplitVecs}; 456 } 457 458 /// If \p V already has a known shape return false. Otherwise set the shape 459 /// for instructions that support it. 460 bool setShapeInfo(Value *V, ShapeInfo Shape) { 461 assert(Shape && "Shape not set"); 462 if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 463 return false; 464 465 auto SIter = ShapeMap.find(V); 466 if (SIter != ShapeMap.end()) { 467 LLVM_DEBUG(dbgs() << " not overriding existing shape: " 468 << SIter->second.NumRows << " " 469 << SIter->second.NumColumns << " for " << *V << "\n"); 470 return false; 471 } 472 473 ShapeMap.insert({V, Shape}); 474 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 475 << " for " << *V << "\n"); 476 return true; 477 } 478 479 bool isUniformShape(Value *V) { 480 Instruction *I = dyn_cast<Instruction>(V); 481 if (!I) 482 return true; 483 484 switch (I->getOpcode()) { 485 case Instruction::FAdd: 486 case Instruction::FSub: 487 case Instruction::FMul: // Scalar multiply. 488 case Instruction::FNeg: 489 case Instruction::Add: 490 case Instruction::Mul: 491 case Instruction::Sub: 492 return true; 493 default: 494 return false; 495 } 496 } 497 498 /// Returns true if shape information can be used for \p V. The supported 499 /// instructions must match the instructions that can be lowered by this pass. 500 bool supportsShapeInfo(Value *V) { 501 Instruction *Inst = dyn_cast<Instruction>(V); 502 if (!Inst) 503 return false; 504 505 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 506 if (II) 507 switch (II->getIntrinsicID()) { 508 case Intrinsic::matrix_multiply: 509 case Intrinsic::matrix_transpose: 510 case Intrinsic::matrix_column_major_load: 511 case Intrinsic::matrix_column_major_store: 512 return true; 513 default: 514 return false; 515 } 516 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 517 } 518 519 /// Propagate the shape information of instructions to their users. 520 /// The work list contains instructions for which we can compute the shape, 521 /// either based on the information provided by matrix intrinsics or known 522 /// shapes of operands. 523 SmallVector<Instruction *, 32> 524 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 525 SmallVector<Instruction *, 32> NewWorkList; 526 // Pop an element for which we guaranteed to have at least one of the 527 // operand shapes. Add the shape for this and then add users to the work 528 // list. 
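    // Illustrative sketch of the forward step handled below (IR abbreviated,
    // not taken from a real test case): for
    //   %c = call <4 x double> @llvm.matrix.multiply(<6 x double> %a,
    //                                                <6 x double> %b,
    //                                                i32 2, i32 3, i32 2)
    // the multiply itself is assigned shape 2x2 (M x K), and any of its users
    // that do not have a shape yet are added to the worklist.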
529 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 530 while (!WorkList.empty()) { 531 Instruction *Inst = WorkList.pop_back_val(); 532 533 // New entry, set the value and insert operands 534 bool Propagate = false; 535 536 Value *MatrixA; 537 Value *MatrixB; 538 Value *M; 539 Value *N; 540 Value *K; 541 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>( 542 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 543 m_Value(N), m_Value(K)))) { 544 Propagate = setShapeInfo(Inst, {M, K}); 545 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>( 546 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 547 // Flip dimensions. 548 Propagate = setShapeInfo(Inst, {N, M}); 549 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>( 550 m_Value(MatrixA), m_Value(), m_Value(), 551 m_Value(), m_Value(M), m_Value(N)))) { 552 Propagate = setShapeInfo(Inst, {N, M}); 553 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( 554 m_Value(), m_Value(), m_Value(), m_Value(M), 555 m_Value(N)))) { 556 Propagate = setShapeInfo(Inst, {M, N}); 557 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { 558 auto OpShape = ShapeMap.find(MatrixA); 559 if (OpShape != ShapeMap.end()) 560 setShapeInfo(Inst, OpShape->second); 561 continue; 562 } else if (isUniformShape(Inst)) { 563 // Find the first operand that has a known shape and use that. 564 for (auto &Op : Inst->operands()) { 565 auto OpShape = ShapeMap.find(Op.get()); 566 if (OpShape != ShapeMap.end()) { 567 Propagate |= setShapeInfo(Inst, OpShape->second); 568 break; 569 } 570 } 571 } 572 573 if (Propagate) { 574 NewWorkList.push_back(Inst); 575 for (auto *User : Inst->users()) 576 if (ShapeMap.count(User) == 0) 577 WorkList.push_back(cast<Instruction>(User)); 578 } 579 } 580 581 return NewWorkList; 582 } 583 584 /// Propagate the shape to operands of instructions with shape information. 585 /// \p Worklist contains the instruction for which we already know the shape. 586 SmallVector<Instruction *, 32> 587 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 588 SmallVector<Instruction *, 32> NewWorkList; 589 590 auto pushInstruction = [](Value *V, 591 SmallVectorImpl<Instruction *> &WorkList) { 592 Instruction *I = dyn_cast<Instruction>(V); 593 if (I) 594 WorkList.push_back(I); 595 }; 596 // Pop an element with known shape. Traverse the operands, if their shape 597 // derives from the result shape and is unknown, add it and add them to the 598 // worklist. 599 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 600 while (!WorkList.empty()) { 601 Value *V = WorkList.pop_back_val(); 602 603 size_t BeforeProcessingV = WorkList.size(); 604 if (!isa<Instruction>(V)) 605 continue; 606 607 Value *MatrixA; 608 Value *MatrixB; 609 Value *M; 610 Value *N; 611 Value *K; 612 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 613 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 614 m_Value(N), m_Value(K)))) { 615 if (setShapeInfo(MatrixA, {M, N})) 616 pushInstruction(MatrixA, WorkList); 617 618 if (setShapeInfo(MatrixB, {N, K})) 619 pushInstruction(MatrixB, WorkList); 620 621 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 622 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 623 // Flip dimensions. 
624 if (setShapeInfo(MatrixA, {M, N})) 625 pushInstruction(MatrixA, WorkList); 626 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 627 m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 628 m_Value(M), m_Value(N)))) { 629 if (setShapeInfo(MatrixA, {M, N})) { 630 pushInstruction(MatrixA, WorkList); 631 } 632 } else if (isa<LoadInst>(V) || 633 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 634 // Nothing to do, no matrix input. 635 } else if (isa<StoreInst>(V)) { 636 // Nothing to do. We forward-propagated to this so we would just 637 // backward propagate to an instruction with an already known shape. 638 } else if (isUniformShape(V)) { 639 // Propagate to all operands. 640 ShapeInfo Shape = ShapeMap[V]; 641 for (Use &U : cast<Instruction>(V)->operands()) { 642 if (setShapeInfo(U.get(), Shape)) 643 pushInstruction(U.get(), WorkList); 644 } 645 } 646 // After we discovered new shape info for new instructions in the 647 // worklist, we use their users as seeds for the next round of forward 648 // propagation. 649 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++) 650 for (User *U : WorkList[I]->users()) 651 if (isa<Instruction>(U) && V != U) 652 NewWorkList.push_back(cast<Instruction>(U)); 653 } 654 return NewWorkList; 655 } 656 657 bool Visit() { 658 SmallVector<Instruction *, 32> WorkList; 659 660 // Initially only the shape of matrix intrinsics is known. 661 // Initialize the work list with ops carrying shape information. 662 for (BasicBlock &BB : Func) 663 for (Instruction &Inst : BB) { 664 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 665 if (!II) 666 continue; 667 668 switch (II->getIntrinsicID()) { 669 case Intrinsic::matrix_multiply: 670 case Intrinsic::matrix_transpose: 671 case Intrinsic::matrix_column_major_load: 672 case Intrinsic::matrix_column_major_store: 673 WorkList.push_back(&Inst); 674 break; 675 default: 676 break; 677 } 678 } 679 680 // Avoid unnecessary work if there are no matrix intrinsics in the function. 681 if (WorkList.empty()) 682 return false; 683 684 // Propagate shapes until nothing changes any longer. 685 while (!WorkList.empty()) { 686 WorkList = propagateShapeForward(WorkList); 687 WorkList = propagateShapeBackward(WorkList); 688 } 689 690 bool Changed = false; 691 SmallVector<CallInst *, 16> MaybeFusableInsts; 692 SmallVector<Instruction *, 16> MatrixInsts; 693 694 // First, collect all instructions with shape information and candidates for 695 // fusion (currently only matrix multiplies). 696 ReversePostOrderTraversal<Function *> RPOT(&Func); 697 for (auto *BB : RPOT) 698 for (Instruction &I : *BB) { 699 if (ShapeMap.find(&I) == ShapeMap.end()) 700 continue; 701 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 702 MaybeFusableInsts.push_back(cast<CallInst>(&I)); 703 MatrixInsts.push_back(&I); 704 } 705 706 // Second, try to fuse candidates. 707 SmallPtrSet<Instruction *, 16> FusedInsts; 708 for (CallInst *CI : MaybeFusableInsts) 709 LowerMatrixMultiplyFused(CI, FusedInsts); 710 Changed = !FusedInsts.empty(); 711 712 // Third, lower remaining instructions with shape information. 
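    // The remaining instructions are those that received a shape during
    // propagation but were not eliminated by fusion: matrix intrinsics,
    // element-wise binary/unary operators, and loads/stores of matrix values.
    // Each one is dispatched to the matching Visit* helper below; e.g. an fadd
    // of two 4x4 column-major operands is lowered to four vector-wide fadds,
    // one per column.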
713 for (Instruction *Inst : MatrixInsts) { 714 if (FusedInsts.count(Inst)) 715 continue; 716 717 IRBuilder<> Builder(Inst); 718 719 if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 720 Changed |= VisitCallInst(CInst); 721 722 Value *Op1; 723 Value *Op2; 724 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 725 Changed |= VisitBinaryOperator(BinOp); 726 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 727 Changed |= VisitUnaryOperator(UnOp); 728 if (match(Inst, m_Load(m_Value(Op1)))) 729 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 730 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 731 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 732 } 733 734 if (ORE) { 735 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 736 RemarkGen.emitRemarks(); 737 } 738 739 // Delete the instructions backwards, as it has a reduced likelihood of 740 // having to update as many def-use and use-def chains. 741 for (auto *Inst : reverse(ToRemove)) { 742 if (!Inst->use_empty()) 743 Inst->replaceAllUsesWith(UndefValue::get(Inst->getType())); 744 Inst->eraseFromParent(); 745 } 746 747 return Changed; 748 } 749 750 /// Turns \p BasePtr into an elementwise pointer to \p EltType. 751 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 752 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 753 Type *EltPtrType = PointerType::get(EltType, AS); 754 return Builder.CreatePointerCast(BasePtr, EltPtrType); 755 } 756 757 /// Replace intrinsic calls 758 bool VisitCallInst(CallInst *Inst) { 759 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 760 return false; 761 762 switch (Inst->getCalledFunction()->getIntrinsicID()) { 763 case Intrinsic::matrix_multiply: 764 LowerMultiply(Inst); 765 break; 766 case Intrinsic::matrix_transpose: 767 LowerTranspose(Inst); 768 break; 769 case Intrinsic::matrix_column_major_load: 770 LowerColumnMajorLoad(Inst); 771 break; 772 case Intrinsic::matrix_column_major_store: 773 LowerColumnMajorStore(Inst); 774 break; 775 default: 776 return false; 777 } 778 return true; 779 } 780 781 /// Compute the alignment for a column/row \p Idx with \p Stride between them. 782 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 783 /// ConstantInt, reduce the initial alignment based on the byte offset. For 784 /// non-ConstantInt strides, return the common alignment of the initial 785 /// alignment and the element size in bytes. 786 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 787 MaybeAlign A) const { 788 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 789 if (Idx == 0) 790 return InitialAlign; 791 792 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 793 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 794 uint64_t StrideInBytes = 795 ConstStride->getZExtValue() * ElementSizeInBits / 8; 796 return commonAlignment(InitialAlign, Idx * StrideInBytes); 797 } 798 return commonAlignment(InitialAlign, ElementSizeInBits / 8); 799 } 800 801 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 802 /// vectors. 
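  /// For example, loading a column-major 4x3 matrix embedded in a larger
  /// matrix with 6 rows (\p Stride = 6) emits three <4 x EltTy> loads starting
  /// at element offsets 0, 6 and 12 from \p Ptr (see computeVectorAddr).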
803 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 804 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 805 auto *VType = cast<VectorType>(Ty); 806 Type *EltTy = VType->getElementType(); 807 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); 808 Value *EltPtr = createElementPtr(Ptr, EltTy, Builder); 809 MatrixTy Result; 810 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 811 Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride, 812 Shape.getStride(), EltTy, Builder); 813 Value *Vector = Builder.CreateAlignedLoad( 814 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), 815 IsVolatile, "col.load"); 816 817 Result.addVector(Vector); 818 } 819 return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 820 Result.getNumVectors()); 821 } 822 823 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 824 /// starting at \p MatrixPtr[I][J]. 825 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 826 ShapeInfo MatrixShape, Value *I, Value *J, 827 ShapeInfo ResultShape, Type *EltTy, 828 IRBuilder<> &Builder) { 829 830 Value *Offset = Builder.CreateAdd( 831 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 832 833 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 834 Value *EltPtr = 835 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 836 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 837 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 838 ResultShape.NumColumns); 839 Type *TilePtrTy = PointerType::get(TileTy, AS); 840 Value *TilePtr = 841 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 842 843 return loadMatrix(TileTy, TilePtr, Align, 844 Builder.getInt64(MatrixShape.getStride()), IsVolatile, 845 ResultShape, Builder); 846 } 847 848 /// Lower a load instruction with shape information. 849 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 850 bool IsVolatile, ShapeInfo Shape) { 851 IRBuilder<> Builder(Inst); 852 finalizeLowering(Inst, 853 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 854 Shape, Builder), 855 Builder); 856 } 857 858 /// Lowers llvm.matrix.column.major.load. 859 /// 860 /// The intrinsic loads a matrix from memory using a stride between columns. 861 void LowerColumnMajorLoad(CallInst *Inst) { 862 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 863 "Intrinsic only supports column-major layout!"); 864 Value *Ptr = Inst->getArgOperand(0); 865 Value *Stride = Inst->getArgOperand(1); 866 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 867 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 868 {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 869 } 870 871 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 872 /// MatrixPtr[I][J]. 
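  /// The start of the tile is computed as the element offset
  ///   Offset = J * Stride + I
  /// with Stride = MatrixShape.getStride(); e.g. a tile starting at row 4,
  /// column 2 of a column-major 8x8 matrix begins at element offset
  /// 2 * 8 + 4 = 20.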
873 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 874 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, 875 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { 876 Value *Offset = Builder.CreateAdd( 877 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 878 879 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 880 Value *EltPtr = 881 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 882 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 883 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * 884 StoreVal.getNumColumns()); 885 Type *TilePtrTy = PointerType::get(TileTy, AS); 886 Value *TilePtr = 887 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 888 889 storeMatrix(TileTy, StoreVal, TilePtr, MAlign, 890 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); 891 } 892 893 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 894 /// vectors. 895 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, 896 MaybeAlign MAlign, Value *Stride, bool IsVolatile, 897 IRBuilder<> &Builder) { 898 auto VType = cast<VectorType>(Ty); 899 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 900 for (auto Vec : enumerate(StoreVal.vectors())) { 901 Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()), 902 Stride, StoreVal.getStride(), 903 VType->getElementType(), Builder); 904 Builder.CreateAlignedStore(Vec.value(), GEP, 905 getAlignForIndex(Vec.index(), Stride, 906 VType->getElementType(), 907 MAlign), 908 IsVolatile); 909 } 910 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * 911 StoreVal.getNumVectors()); 912 } 913 914 /// Lower a store instruction with shape information. 915 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, 916 Value *Stride, bool IsVolatile, ShapeInfo Shape) { 917 IRBuilder<> Builder(Inst); 918 auto StoreVal = getMatrix(Matrix, Shape, Builder); 919 finalizeLowering(Inst, 920 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, 921 IsVolatile, Builder), 922 Builder); 923 } 924 925 /// Lowers llvm.matrix.column.major.store. 926 /// 927 /// The intrinsic stores a matrix back to memory using a stride between columns.
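  /// The operands are read in the order handled below:
  ///   (%matrix, %ptr, %stride, %volatile, %rows, %cols)
  /// e.g. (sketch, intrinsic name and types abbreviated):
  ///   call void @llvm.matrix.column.major.store(<8 x double> %m, %p,
  ///                                             i64 4, i1 false, i32 4, i32 2)
  /// stores a 4x2 matrix with 4 elements between the starts of consecutive
  /// columns.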
928 void LowerColumnMajorStore(CallInst *Inst) { 929 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 930 "Intrinsic only supports column-major layout!"); 931 Value *Matrix = Inst->getArgOperand(0); 932 Value *Ptr = Inst->getArgOperand(1); 933 Value *Stride = Inst->getArgOperand(2); 934 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 935 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 936 {Inst->getArgOperand(4), Inst->getArgOperand(5)}); 937 } 938 939 // Set elements I..I+NumElts-1 to Block 940 Value *insertVector(Value *Col, unsigned I, Value *Block, 941 IRBuilder<> &Builder) { 942 943 // First, bring Block to the same size as Col 944 unsigned BlockNumElts = 945 cast<FixedVectorType>(Block->getType())->getNumElements(); 946 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); 947 assert(NumElts >= BlockNumElts && "Too few elements for current block"); 948 949 Block = Builder.CreateShuffleVector( 950 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 951 952 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 953 // 8, 4, 5, 6 954 SmallVector<int, 16> Mask; 955 unsigned i; 956 for (i = 0; i < I; i++) 957 Mask.push_back(i); 958 959 unsigned VecNumElts = 960 cast<FixedVectorType>(Col->getType())->getNumElements(); 961 for (; i < I + BlockNumElts; i++) 962 Mask.push_back(i - I + VecNumElts); 963 964 for (; i < VecNumElts; i++) 965 Mask.push_back(i); 966 967 return Builder.CreateShuffleVector(Col, Block, Mask); 968 } 969 970 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 971 IRBuilder<> &Builder, bool AllowContraction, 972 unsigned &NumComputeOps) { 973 NumComputeOps += getNumOps(A->getType()); 974 if (!Sum) 975 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 976 977 if (UseFPOp) { 978 if (AllowContraction) { 979 // Use fmuladd for floating point operations and let the backend decide 980 // if that's profitable. 981 Function *FMulAdd = Intrinsic::getDeclaration( 982 Func.getParent(), Intrinsic::fmuladd, A->getType()); 983 return Builder.CreateCall(FMulAdd, {A, B, Sum}); 984 } 985 NumComputeOps += getNumOps(A->getType()); 986 Value *Mul = Builder.CreateFMul(A, B); 987 return Builder.CreateFAdd(Sum, Mul); 988 } 989 990 NumComputeOps += getNumOps(A->getType()); 991 Value *Mul = Builder.CreateMul(A, B); 992 return Builder.CreateAdd(Sum, Mul); 993 } 994 995 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 996 /// users with shape information, there's nothing to do: they will use the 997 /// cached value when they are lowered. For other users, \p Matrix is 998 /// flattened and the uses are updated to use it. Also marks \p Inst for 999 /// deletion. 1000 void finalizeLowering(Instruction *Inst, MatrixTy Matrix, 1001 IRBuilder<> &Builder) { 1002 Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); 1003 1004 ToRemove.push_back(Inst); 1005 Value *Flattened = nullptr; 1006 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1007 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) { 1008 if (!Flattened) 1009 Flattened = Matrix.embedInVector(Builder); 1010 U.set(Flattened); 1011 } 1012 } 1013 } 1014 1015 /// Compute \p Result += \p A * \p B for input matrices with left-associating 1016 /// addition. 1017 /// 1018 /// We can fold a transpose into the operand that is used to extract scalars. 1019 /// This is the first operands with row-major and the second with 1020 /// column-major. 
If \p IsScalarMatrixTransposed we assume the appropriate 1021 /// operand is transposed. 1022 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1023 const MatrixTy &B, bool AllowContraction, 1024 IRBuilder<> &Builder, bool IsTiled, 1025 bool IsScalarMatrixTransposed) { 1026 const unsigned VF = std::max<unsigned>( 1027 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1028 .getFixedSize() / 1029 Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), 1030 1U); 1031 unsigned R = Result.getNumRows(); 1032 unsigned C = Result.getNumColumns(); 1033 unsigned M = A.getNumColumns(); 1034 1035 bool IsFP = Result.getElementType()->isFloatingPointTy(); 1036 assert(A.isColumnMajor() == B.isColumnMajor() && 1037 Result.isColumnMajor() == A.isColumnMajor() && 1038 "operands must agree on matrix layout"); 1039 unsigned NumComputeOps = 0; 1040 if (A.isColumnMajor()) { 1041 // Multiply columns from the first operand with scalars from the second 1042 // operand. Then move along the K axes and accumulate the columns. With 1043 // this the adds can be vectorized without reassociation. 1044 for (unsigned J = 0; J < C; ++J) { 1045 unsigned BlockSize = VF; 1046 // If Result is zero, we don't need to accumulate in the K==0 iteration. 1047 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1048 1049 for (unsigned I = 0; I < R; I += BlockSize) { 1050 // Gradually lower the vectorization factor to cover the remainder. 1051 while (I + BlockSize > R) 1052 BlockSize /= 2; 1053 1054 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder) 1055 : nullptr; 1056 for (unsigned K = 0; K < M; ++K) { 1057 Value *L = A.extractVector(I, K, BlockSize, Builder); 1058 Value *RH = Builder.CreateExtractElement( 1059 B.getColumn(IsScalarMatrixTransposed ? K : J), 1060 IsScalarMatrixTransposed ? J : K); 1061 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1062 Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1063 Result.getElementType()->isFloatingPointTy(), 1064 Builder, AllowContraction, NumComputeOps); 1065 } 1066 Result.setVector(J, 1067 insertVector(Result.getVector(J), I, Sum, Builder)); 1068 } 1069 } 1070 } else { 1071 // Multiply rows from the second operand with scalars from the first 1072 // operand. Then move along the K axes and accumulate the rows. With this 1073 // the adds can be vectorized without reassociation. 1074 for (unsigned I = 0; I < R; ++I) { 1075 unsigned BlockSize = VF; 1076 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); 1077 for (unsigned J = 0; J < C; J += BlockSize) { 1078 // Gradually lower the vectorization factor to cover the remainder. 1079 while (J + BlockSize > C) 1080 BlockSize /= 2; 1081 1082 Value *Sum = nullptr; 1083 for (unsigned K = 0; K < M; ++K) { 1084 Value *R = B.extractVector(K, J, BlockSize, Builder); 1085 Value *LH = Builder.CreateExtractElement( 1086 A.getVector(IsScalarMatrixTransposed ? K : I), 1087 IsScalarMatrixTransposed ? I : K); 1088 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); 1089 Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R, 1090 IsFP, Builder, AllowContraction, NumComputeOps); 1091 } 1092 Result.setVector(I, 1093 insertVector(Result.getVector(I), J, Sum, Builder)); 1094 } 1095 } 1096 } 1097 Result.addNumComputeOps(NumComputeOps); 1098 } 1099 1100 /// Ensure that the memory in \p Load does not alias \p Store by potentially 1101 /// copying it to a new location. 
This new or otherwise the original location 1102 /// is returned. 1103 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store, 1104 CallInst *MatMul) { 1105 MemoryLocation StoreLoc = MemoryLocation::get(Store); 1106 MemoryLocation LoadLoc = MemoryLocation::get(Load); 1107 1108 // If we can statically determine noalias we're good. 1109 if (AA->isNoAlias(LoadLoc, StoreLoc)) 1110 return Load->getPointerOperand(); 1111 1112 // Create code to check if the memory locations of the Load and Store 1113 // overlap and if they do, copy Load's operand to a new buffer. 1114 1115 // First, create new blocks for 2n part of the check and the copy. 1116 BasicBlock *Check0 = MatMul->getParent(); 1117 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a 1118 // DT. Manually collect dominator tree updates, to avoid unnecessary work, 1119 // as we adjust Check0 and Check1's branches. 1120 SmallVector<DominatorTree::UpdateType, 4> DTUpdates; 1121 for (BasicBlock *Succ : successors(Check0)) 1122 DTUpdates.push_back({DT->Delete, Check0, Succ}); 1123 1124 BasicBlock *Check1 = 1125 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1126 nullptr, "alias_cont"); 1127 BasicBlock *Copy = 1128 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1129 nullptr, "copy"); 1130 BasicBlock *Fusion = 1131 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1132 nullptr, "no_alias"); 1133 1134 // Check if the loaded memory location begins before the end of the store 1135 // location. If the condition holds, they might overlap, otherwise they are 1136 // guaranteed to not overlap. 1137 IRBuilder<> Builder(MatMul); 1138 Check0->getTerminator()->eraseFromParent(); 1139 Builder.SetInsertPoint(Check0); 1140 Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout()); 1141 Value *StoreBegin = Builder.CreatePtrToInt( 1142 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); 1143 Value *StoreEnd = Builder.CreateAdd( 1144 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), 1145 "store.end", true, true); 1146 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), 1147 IntPtrTy, "load.begin"); 1148 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, 1149 Fusion); 1150 1151 // Check if the store begins before the end of the load location. If the 1152 // condition holds, they alias, otherwise they are guaranteed to not 1153 // overlap. 1154 Check1->getTerminator()->eraseFromParent(); 1155 Builder.SetInsertPoint(Check1, Check1->begin()); 1156 Value *LoadEnd = Builder.CreateAdd( 1157 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 1158 "load.end", true, true); 1159 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 1160 Fusion); 1161 1162 // Copy load operand to new alloca. 1163 Builder.SetInsertPoint(Copy, Copy->begin()); 1164 AllocaInst *NewLd = 1165 Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace()); 1166 Builder.CreateMemCpy(NewLd, NewLd->getAlign(), 1167 Load->getPointerOperand(), Load->getAlign(), 1168 LoadLoc.Size.getValue()); 1169 Builder.SetInsertPoint(Fusion, Fusion->begin()); 1170 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1171 PHI->addIncoming(Load->getPointerOperand(), Check0); 1172 PHI->addIncoming(Load->getPointerOperand(), Check1); 1173 PHI->addIncoming(NewLd, Copy); 1174 1175 // Adjust DT. 
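    // Sketch of the resulting control flow: Check0 conditionally branches to
    // Check1 or directly to Fusion; Check1 conditionally branches to Copy or
    // to Fusion; Copy falls through to Fusion, where the PHI above selects
    // between the original pointer and the newly allocated copy.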
1176 DTUpdates.push_back({DT->Insert, Check0, Check1}); 1177 DTUpdates.push_back({DT->Insert, Check0, Fusion}); 1178 DTUpdates.push_back({DT->Insert, Check1, Copy}); 1179 DTUpdates.push_back({DT->Insert, Check1, Fusion}); 1180 DT->applyUpdates(DTUpdates); 1181 return PHI; 1182 } 1183 1184 bool isFusionProfitable(CallInst *MatMul) { 1185 if (ForceFusion) 1186 return true; 1187 1188 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1189 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1190 1191 const unsigned R = LShape.NumRows; 1192 const unsigned C = RShape.NumColumns; 1193 const unsigned M = LShape.NumColumns; 1194 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1195 1196 const unsigned VF = std::max<unsigned>( 1197 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1198 .getFixedSize() / 1199 EltType->getPrimitiveSizeInBits().getFixedSize(), 1200 1U); 1201 1202 // Cost model for tiling 1203 // 1204 // For tiling to be beneficial, we need reuse either along the R or 1205 // the C axis. We vectorize along the R axis so that means at least 1206 // 3 elements. 1207 // TODO: Also consider cost of copying if operands alias. 1208 if (R <= VF && C == 1) 1209 return false; 1210 // Then we need enough elements to exceed the number of vector 1211 // registers we have. Note that this is an oversimplification since 1212 // fusing also takes some extra loads which may exceed the number of 1213 // reloads necessary. 1214 unsigned Op0Regs = (R + VF - 1) / VF * M; 1215 unsigned Op1Regs = (M + VF - 1) / VF * C; 1216 return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true); 1217 } 1218 1219 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1220 MatrixTy Res; 1221 auto *ColumType = FixedVectorType::get(EltType, R); 1222 for (unsigned I = 0; I < C; ++I) 1223 Res.addVector(ConstantAggregateZero::get(ColumType)); 1224 return Res; 1225 } 1226 1227 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape, 1228 Value *RPtr, ShapeInfo RShape, StoreInst *Store, 1229 bool AllowContract) { 1230 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1231 1232 // Create the main tiling loop nest. 1233 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize); 1234 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 1235 Instruction *InsertI = cast<Instruction>(MatMul); 1236 BasicBlock *Start = InsertI->getParent(); 1237 BasicBlock *End = 1238 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue"); 1239 IRBuilder<> Builder(MatMul); 1240 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI); 1241 1242 Type *TileVecTy = 1243 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize); 1244 MatrixTy TileResult; 1245 // Insert in the inner loop header. 1246 Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator()); 1247 // Create PHI nodes for the result columns to accumulate across iterations. 1248 SmallVector<PHINode *, 4> ColumnPhis; 1249 for (unsigned I = 0; I < TileSize; I++) { 1250 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." 
+ Twine(I)); 1251 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy), 1252 TI.RowLoopHeader->getSingleSuccessor()); 1253 TileResult.addVector(Phi); 1254 ColumnPhis.push_back(Phi); 1255 } 1256 1257 // Insert in the inner loop body, which computes 1258 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn) 1259 Builder.SetInsertPoint(InnerBody->getTerminator()); 1260 // Load tiles of the operands. 1261 MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK, 1262 {TileSize, TileSize}, EltType, Builder); 1263 MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol, 1264 {TileSize, TileSize}, EltType, Builder); 1265 emitMatrixMultiply(TileResult, A, B, AllowContract, Builder, true, false); 1266 // Store result after the inner loop is done. 1267 Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator()); 1268 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(), 1269 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns}, 1270 TI.CurrentRow, TI.CurrentCol, EltType, Builder); 1271 1272 for (unsigned I = 0; I < TileResult.getNumVectors(); I++) 1273 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch); 1274 1275 // Force unrolling of a few iterations of the inner loop, to make sure there 1276 // is enough work per iteration. 1277 // FIXME: The unroller should make this decision directly instead, but 1278 // currently the cost-model is not up to the task. 1279 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize); 1280 addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader), 1281 "llvm.loop.unroll.count", InnerLoopUnrollCount); 1282 } 1283 1284 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 1285 StoreInst *Store, 1286 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1287 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1288 "Tiling only supported for column-major matrixes at the moment!"); 1289 if (!isFusionProfitable(MatMul)) 1290 return; 1291 1292 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1293 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1294 1295 const unsigned R = LShape.NumRows; 1296 const unsigned C = RShape.NumColumns; 1297 const unsigned M = LShape.NumColumns; 1298 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1299 1300 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 1301 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 1302 Value *CPtr = Store->getPointerOperand(); 1303 1304 bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 1305 MatMul->hasAllowContract()); 1306 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0)) 1307 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store, 1308 AllowContract); 1309 else { 1310 IRBuilder<> Builder(Store); 1311 for (unsigned J = 0; J < C; J += TileSize) 1312 for (unsigned I = 0; I < R; I += TileSize) { 1313 const unsigned TileR = std::min(R - I, unsigned(TileSize)); 1314 const unsigned TileC = std::min(C - J, unsigned(TileSize)); 1315 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 1316 1317 for (unsigned K = 0; K < M; K += TileSize) { 1318 const unsigned TileM = std::min(M - K, unsigned(TileSize)); 1319 MatrixTy A = 1320 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 1321 LShape, Builder.getInt64(I), Builder.getInt64(K), 1322 {TileR, TileM}, EltType, Builder); 1323 MatrixTy B = 1324 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), 1325 RShape, 
Builder.getInt64(K), Builder.getInt64(J), 1326 {TileM, TileC}, EltType, Builder); 1327 emitMatrixMultiply(Res, A, B, AllowContract, Builder, true, false); 1328 } 1329 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 1330 Builder.getInt64(I), Builder.getInt64(J), EltType, 1331 Builder); 1332 } 1333 } 1334 1335 // Mark eliminated instructions as fused and remove them. 1336 FusedInsts.insert(Store); 1337 FusedInsts.insert(MatMul); 1338 Store->eraseFromParent(); 1339 MatMul->eraseFromParent(); 1340 if (LoadOp0->hasNUses(0)) { 1341 FusedInsts.insert(LoadOp0); 1342 LoadOp0->eraseFromParent(); 1343 } 1344 if (LoadOp1->hasNUses(0)) { 1345 FusedInsts.insert(LoadOp1); 1346 LoadOp1->eraseFromParent(); 1347 } 1348 } 1349 1350 /// Try to lower matrix multiply chains by fusing operations. 1351 /// 1352 /// Call finalizeLowering on lowered instructions. Instructions that are 1353 /// completely eliminated by fusion are added to \p FusedInsts. 1354 void LowerMatrixMultiplyFused(CallInst *MatMul, 1355 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1356 if (!FuseMatrix || !DT) 1357 return; 1358 1359 assert(AA && LI && "Analyses should be available"); 1360 1361 Value *A = MatMul->getArgOperand(0); 1362 Value *B = MatMul->getArgOperand(1); 1363 1364 // We can fold the transpose into the operand that is used to fetch scalars. 1365 Value *T; 1366 if (MatrixLayout == MatrixLayoutTy::ColumnMajor 1367 ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T))) 1368 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) { 1369 IRBuilder<> Builder(MatMul); 1370 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1371 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1372 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1373 const unsigned R = LShape.NumRows; 1374 const unsigned M = LShape.NumColumns; 1375 const unsigned C = RShape.NumColumns; 1376 1377 MatrixTy MA; 1378 MatrixTy MB; 1379 1380 Value *Transpose; 1381 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) { 1382 MA = getMatrix(A, ShapeInfo(R, M), Builder); 1383 MB = getMatrix(T, ShapeInfo(C, M), Builder); 1384 Transpose = B; 1385 } else { 1386 MA = getMatrix(T, ShapeInfo(R, M), Builder); 1387 MB = getMatrix(B, ShapeInfo(C, M), Builder); 1388 Transpose = A; 1389 } 1390 1391 // Initialize the output 1392 MatrixTy Result(R, C, EltType); 1393 1394 bool AllowContract = 1395 AllowContractEnabled || 1396 (isa<FPMathOperator>(MatMul) && MatMul->hasAllowContract()); 1397 emitMatrixMultiply(Result, MA, MB, AllowContract, Builder, false, true); 1398 1399 FusedInsts.insert(MatMul); 1400 FusedInsts.insert(cast<Instruction>(Transpose)); 1401 if (Transpose->hasOneUse()) 1402 ToRemove.push_back(cast<Instruction>(Transpose)); 1403 finalizeLowering(MatMul, Result, Builder); 1404 // TODO: add a fake entry for the folded instruction so that this is 1405 // included in the expression in the remark. 1406 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType); 1407 return; 1408 } 1409 1410 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor) 1411 return; 1412 1413 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering 1414 // since the single store user will be lowered as part of this. 
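    // Sketch of the pattern matched below (abbreviated, not from a real test):
    //   %a = load <16 x double>, <16 x double>* %A
    //   %b = load <16 x double>, <16 x double>* %B
    //   %c = call <16 x double> @llvm.matrix.multiply(%a, %b, i32 4, i32 4, i32 4)
    //   store <16 x double> %c, <16 x double>* %C
    // where the store is the only user of the multiply.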
1415 auto *LoadOp0 = dyn_cast<LoadInst>(A); 1416 auto *LoadOp1 = dyn_cast<LoadInst>(B); 1417 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); 1418 if (LoadOp0 && LoadOp1 && Store) { 1419 // The store address must dominate the MatMul instruction, otherwise 1420 // we create invalid IR. 1421 // FIXME: See if we can hoist the store address computation. 1422 auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1)); 1423 if (AddrI && (!DT->dominates(AddrI, MatMul))) 1424 return; 1425 1426 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); 1427 return; 1428 } 1429 } 1430 1431 /// Lowers llvm.matrix.multiply. 1432 void LowerMultiply(CallInst *MatMul) { 1433 IRBuilder<> Builder(MatMul); 1434 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1435 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1436 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1437 1438 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 1439 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 1440 assert(Lhs.getElementType() == Rhs.getElementType() && 1441 "Matrix multiply argument element types do not match."); 1442 1443 const unsigned R = LShape.NumRows; 1444 const unsigned C = RShape.NumColumns; 1445 assert(LShape.NumColumns == RShape.NumRows); 1446 1447 // Initialize the output 1448 MatrixTy Result(R, C, EltType); 1449 assert(Lhs.getElementType() == Result.getElementType() && 1450 "Matrix multiply result element type does not match arguments."); 1451 1452 bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 1453 MatMul->hasAllowContract()); 1454 emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false, false); 1455 finalizeLowering(MatMul, Result, Builder); 1456 } 1457 1458 /// Lowers llvm.matrix.transpose. 1459 void LowerTranspose(CallInst *Inst) { 1460 MatrixTy Result; 1461 IRBuilder<> Builder(Inst); 1462 Value *InputVal = Inst->getArgOperand(0); 1463 VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 1464 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 1465 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 1466 1467 const unsigned NewNumVecs = 1468 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; 1469 const unsigned NewNumElts = 1470 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows; 1471 1472 for (unsigned I = 0; I < NewNumVecs; ++I) { 1473 // Build a single result vector. First initialize it. 1474 Value *ResultVector = UndefValue::get( 1475 FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); 1476 // Go through the old elements and insert it into the resulting vector. 1477 for (auto J : enumerate(InputMatrix.vectors())) { 1478 Value *Elt = Builder.CreateExtractElement(J.value(), I); 1479 // Row and column indices are transposed. 1480 ResultVector = 1481 Builder.CreateInsertElement(ResultVector, Elt, J.index()); 1482 } 1483 Result.addVector(ResultVector); 1484 } 1485 1486 // TODO: Improve estimate of operations needed for transposes. Currently we 1487 // just count the insertelement/extractelement instructions, but do not 1488 // account for later simplifications/combines. 1489 finalizeLowering( 1490 Inst, 1491 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns), 1492 Builder); 1493 } 1494 1495 /// Lower load instructions, if shape information is available. 
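  /// The stride passed to LowerLoad is the shape's own stride (the number of
  /// rows for column-major matrixes), so e.g. a 4x2 column-major load is
  /// lowered to two consecutive <4 x EltTy> loads.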
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  /// Lower store instructions, if shape information is available.
  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    // Helper to perform unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
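  ///
  /// A linearized expression for a multiply fed by two loads and consumed by
  /// a store might look roughly like this (output shown is illustrative):
  ///
  ///   store(
  ///    multiply.2x6.6x2.double(
  ///     load(addr %A),
  ///     load(addr %B)),
  ///    addr %C)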
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "."
             << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to an external (function) address or a stack address; for other values
    /// we either print the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
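      // For example (output shown is illustrative), a sub-expression that also
      // feeds a remark emitted at line 30, column 7 is prefixed with
      //   shared with remark at line 30 column 7 (
      // so a reader can correlate the two remarks.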
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram-to-expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
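    ///
    /// For example, in a lowered chain load -> load -> multiply -> store, only
    /// the store has no matrix users in the subprogram, so it is the single
    /// leaf returned here.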
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the exclusive and shared op counts for the expression
    /// starting at \p Root. Expressions used multiple times are counted once.
    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
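        // Each remark starts with a message of the form (illustrative):
        //   Lowered with 4 stores, 8 loads, 16 compute ops
        // optionally followed by the counts shared with other expressions, and
        // ends with the linearized expression produced by ExprLinearizer.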
        for (auto *L : Leaves) {

          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}

namespace {

/// A lightweight version of the matrix lowering pass that only requires TTI.
/// Advanced features that require DT, AA or ORE, such as tiling, are disabled.
/// This is used to lower matrix intrinsics if the main lowering pass is not
/// run, for example with -O0.
class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
                      "lower-matrix-intrinsics-minimal", pass_name_minimal,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
                    "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
                    false)

Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
  return new LowerMatrixIntrinsicsMinimalLegacyPass();
}
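
// Usage sketch (illustrative, not part of the pass itself): the legacy wrapper
// created above can be added to a legacy pass manager, e.g.
//   legacy::PassManager PM;
//   PM.add(createLowerMatrixIntrinsicsPass());
//   PM.run(M); // M is the Module containing the matrix intrinsics
// With the new pass manager, the pass is typically scheduled via
//   opt -passes=lower-matrix-intrinsics input.ll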