//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//   * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//     transposed.
//   * Improve cost-modeling, e.g. choose different number of rows/columns
//     for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"

#include <cmath>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

static cl::opt<bool>
    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
                    cl::desc("Enable/disable matrix shape verification."),
                    cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));
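
// Illustrative only: the knobs above are cl::opt flags, so they can be
// toggled from the opt tool, e.g.:
//   opt -passes=lower-matrix-intrinsics -fuse-matrix-tile-size=8 \
//       -matrix-default-layout=row-major -S in.ll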

/// Helper function to either return Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                                   BasicBlock &BB) {
  auto *Inst = cast<Instruction>(V);
  // Still used, don't erase.
  if (!Inst->use_empty())
    return;
  if (II != BB.rend() && Inst == &*II)
    ++II;
  Inst->eraseFromParent();
}

/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}

/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0       1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//      v_0_0  v_0_1 {v_0_2 {v_0_3
//        Base Col 1  Col 2
//          |        |      |
//      v_1_0 |v_1_1 |v_1_2 |v_1_3
//      v_2_0 |v_2_1 |v_2_2 |v_2_3
//      v_3_0 {v_3_1 {v_3_2  v_3_3
//
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  return VecStart;
}
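
// For a row-major matrix the same computation yields row addresses instead:
// row I starts at BasePtr + I * Stride elements, e.g. with Stride = 4 row 2
// of the 4x4 matrix above starts 8 elements past Base.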

namespace {
struct ShapeInfo {
  unsigned NumRows;
  unsigned NumColumns;

  bool IsColumnMajor;

  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
      : NumRows(NumRows), NumColumns(NumColumns),
        IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

  ShapeInfo(Value *NumRows, Value *NumColumns)
      : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                  cast<ConstantInt>(NumColumns)->getZExtValue()) {}

  bool operator==(const ShapeInfo &other) {
    return NumRows == other.NumRows && NumColumns == other.NumColumns;
  }
  bool operator!=(const ShapeInfo &other) { return !(*this == other); }

  /// Returns true if shape-information is defined, meaning both dimensions
  /// are != 0.
  operator bool() const {
    assert(NumRows == 0 || NumColumns != 0);
    return NumRows != 0;
  }

  unsigned getStride() const {
    if (IsColumnMajor)
      return NumRows;
    return NumColumns;
  }

  unsigned getNumVectors() const {
    if (IsColumnMajor)
      return NumColumns;
    return NumRows;
  }

  /// Returns the transposed shape.
  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
};
} // namespace
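
// For example, a column-major 3x5 matrix has getStride() == 3 (the distance
// in elements between the starts of consecutive columns) and
// getNumVectors() == 5, and its transposed shape t() is 5x3.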

static bool isUniformShape(Value *V) {
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return true;

  switch (I->getOpcode()) {
  case Instruction::FAdd:
  case Instruction::FSub:
  case Instruction::FMul: // Scalar multiply.
  case Instruction::FNeg:
  case Instruction::Add:
  case Instruction::Mul:
  case Instruction::Sub:
    return true;
  default:
    return false;
  }
}

/// Return the ShapeInfo for the result of \p I, if it can be determined.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
                        const ValueMap<Value *, ShapeInfo> &ShapeMap) {
  Value *M;
  Value *N;
  Value *K;
  if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
                   m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
    return ShapeInfo(M, K);
  if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
                                                        m_Value(N)))) {
    // Flip dimensions.
    return ShapeInfo(N, M);
  }
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                   m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
                   m_Value(N))))
    return ShapeInfo(N, M);
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                   m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
    return ShapeInfo(M, N);
  Value *MatrixA;
  if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
    auto OpShape = ShapeMap.find(MatrixA);
    if (OpShape != ShapeMap.end())
      return OpShape->second;
  }

  if (isUniformShape(I)) {
    // Find the first operand that has a known shape and use that.
    for (auto &Op : I->operands()) {
      auto OpShape = ShapeMap.find(Op.get());
      if (OpShape != ShapeMap.end())
        return OpShape->second;
    }
  }
  return std::nullopt;
}
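
// For example, @llvm.matrix.multiply(%A, %B, i32 2, i32 4, i32 3) produces a
// 2x3 result (ShapeInfo(M, K)), while a transpose declared with dimensions
// 2x4 produces a 4x2 result. Elementwise (uniform-shape) operations inherit
// the shape of their first operand with a known shape.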

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information (assuming column-major
///    layout). The lowering works similarly using row-major layout.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of column major operations, which
///       yields a set of column vectors containing the result matrix. Note
///       that we lower all instructions that have shape information. Besides
///       the intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis *AA;
  DominatorTree *DT;
  LoopInfo *LI;
  OptimizationRemarkEmitter *ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Most of the time transposes can be fused with matrix multiplies or can
    /// be folded away via algebraic simplifications. This is the number of
    /// transposes that we failed to make "free" via such optimizations.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };

  /// Wrapper class representing a matrix as a set of vectors, either in row
  /// or column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {

      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(PoisonValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }

    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      else {
        assert(Vectors.size() > 0 &&
               "Cannot call getNumColumns without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      } else
        return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }
    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }

    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
                 NumElts &&
             "Extracted vector will contain poison values");
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };
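
  // For example, a column-major 4x4 matrix of floats is represented as four
  // <4 x float> column vectors; embedInVector concatenates them back into a
  // single flat <16 x float> value for users without shape information.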

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_column_major_store intrinsics. For those, the
  /// shape information indicates that those instructions should be lowered
  /// using shape information as well. A ValueMap is used so that when
  /// sub-passes like optimizeTransposes perform RAUW the map stays
  /// up-to-date.
  ValueMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction if shape information is available; those
  /// instructions need to be removed after we finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

private:
  static FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;

    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();

    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());

    return FMF;
  }

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
                        OptimizationRemarkEmitter *ORE)
      : Func(F), DL(F.getDataLayout()), TTI(TTI), AA(AA), DT(DT), LI(LI),
        ORE(ORE) {}

  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Is this the minimal version executed in the backend pipelines.
  bool isMinimal() const {
    return !DT;
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT * N.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
                     double(TTI.getRegisterBitWidth(
                                TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedValue()));
  }
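
  // For example, an operation on 8 x float (256 bits) on a target with
  // 128-bit fixed-width vector registers is estimated as
  // ceil(256 / 128) = 2 vector ops.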

  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
                              SIter->second.NumColumns != Shape.NumColumns)) {
        errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
               << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
               << Shape.NumColumns << ") for " << *V << "\n";
        report_fatal_error(
            "Matrix shape verification failed, compilation aborted!");
      }

      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands.
      bool Propagate = false;
      if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
        Propagate = setShapeInfo(Inst, *SI);

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p Worklist contains the instruction for which we already know the shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands: if their shape
    // derives from the result shape and is unknown, add it and add them to
    // the worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                           m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

  /// (Op0 op Op1)^T -> Op0^T op Op1^T
  /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
  /// them on both sides of \p Operation.
  Instruction *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We are being run after shape prop; add shape info for newly created
    // instructions so that we lower them later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }

  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
    // We need to remove Old from the ShapeMap otherwise RAUW will replace it
    // with New. We should only add New if it supportsShapeInfo so we insert
    // it conditionally instead.
    auto S = ShapeMap.find(&Old);
    if (S != ShapeMap.end()) {
      // Copy the shape before erasing; the iterator is invalid afterwards.
      ShapeInfo Shape = S->second;
      ShapeMap.erase(S);
      if (supportsShapeInfo(New))
        ShapeMap.insert({New, Shape});
    }
    Old.replaceAllUsesWith(New);
  }
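
  // For example, sinking the transpose in (A * B)^t with A of shape RxK and
  // B of shape KxC yields B^t * A^t, where B^t has shape CxK and A^t has
  // shape KxR, so the new product has shape CxR, the transpose of the
  // original RxC result.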

  /// Sink a top-level transpose inside matmuls and adds.
  /// This creates and erases instructions as needed, and returns the newly
  /// created instruction while updating the iterator to avoid invalidation.
  /// If this returns nullptr, no new instruction was created.
  Instruction *sinkTranspose(Instruction &I,
                             BasicBlock::reverse_iterator &II) {
    BasicBlock &BB = *I.getParent();
    IRBuilder<> IB(&I);
    MatrixBuilder Builder(IB);

    Value *TA, *TAMA, *TAMB;
    ConstantInt *R, *K, *C;
    if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
      return nullptr;

    // Transpose of a transpose is a nop.
    Value *TATA;
    if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
      updateShapeAndReplaceAllUsesWith(I, TATA);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return nullptr;
    }

    // k^T -> k
    if (isSplat(TA)) {
      updateShapeAndReplaceAllUsesWith(I, TA);
      eraseFromParentAndMove(&I, II, BB);
      return nullptr;
    }

    // (A * B)^t -> B^t * A^t
    // RxK KxC      CxK KxR
    if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C)))) {
      auto NewInst = distributeTransposes(
          TAMB, {K, C}, TAMA, {R, K}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
                                                Shape0.NumColumns,
                                                Shape1.NumColumns, "mmul");
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // Same as above, but with a mul, which occurs when multiplied
    // with a scalar.
    // (A * k)^t -> A^t * k
    //  R x C        RxC
    if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
        (isSplat(TAMA) || isSplat(TAMB))) {
      IRBuilder<> LocalBuilder(&I);
      // We know that the transposed operand is of shape RxC.
      // And when multiplied with a scalar, the shape is preserved.
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
                             : LocalBuilder.CreateMul(T0, T1, "mmul");
            auto *Result = cast<Instruction>(Mul);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // (A + B)^t -> A^t + B^t
    // RxC RxC      CxR CxR
    if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
                             : LocalBuilder.CreateAdd(T0, T1, "madd");

            auto *Result = cast<Instruction>(Add);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    return nullptr;
  }

  void liftTranspose(Instruction &I) {
    // Erase dead instructions after lifting transposes from binops.
    auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
      if (T.use_empty())
        T.eraseFromParent();
      if (A->use_empty())
        cast<Instruction>(A)->eraseFromParent();
      if (A != B && B->use_empty())
        cast<Instruction>(B)->eraseFromParent();
    };

    Value *A, *B, *AT, *BT;
    ConstantInt *R, *K, *C;
    // A^t * B^t -> (B * A)^t
    if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(A), m_Value(B), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C))) &&
        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
        match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
      IRBuilder<> IB(&I);
      MatrixBuilder Builder(IB);
      Value *M = Builder.CreateMatrixMultiply(
          BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
      setShapeInfo(M, {C, R});
      Instruction *NewInst = Builder.CreateMatrixTranspose(
          M, C->getZExtValue(), R->getZExtValue());
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
    // A^t + B^t -> (A + B)^t. Pick rows and columns from the first transpose.
    // If the shape of the second transpose is different, there's a shape
    // conflict which gets resolved by picking the shape of the first operand.
    else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
             match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
             match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
      IRBuilder<> Builder(&I);
      auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
      setShapeInfo(Add, {R, C});
      MatrixBuilder MBuilder(Builder);
      Instruction *NewInst = MBuilder.CreateMatrixTranspose(
          Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
                 computeShapeInfoForInst(&I, ShapeMap) &&
             "Shape of new instruction doesn't match original shape.");
      CleanupBinOp(I, A, B);
      assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
                 ShapeMap[Add] &&
             "Shape of updated addition doesn't match cached shape.");
    }
  }

  /// Try moving transposes in order to fold them away or into multiplies.
  void optimizeTransposes() {
    // First sink all transposes inside matmuls and adds, hoping that we end
    // up with NN, NT or TN variants.
    for (BasicBlock &BB : reverse(Func)) {
      for (auto II = BB.rbegin(); II != BB.rend();) {
        Instruction &I = *II;
        // We may remove II. By default continue on the next/prev instruction.
        ++II;
        if (Instruction *NewInst = sinkTranspose(I, II))
          II = std::next(BasicBlock::reverse_iterator(NewInst));
      }
    }

    // If we have a TT matmul or a TT add, lift the transpose. We may be able
    // to fold it into a consuming multiply or add.
    for (BasicBlock &BB : Func) {
      for (Instruction &I : llvm::make_early_inc_range(BB)) {
        liftTranspose(I);
      }
    }
  }
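
  // For example, t(t(A)) is folded to A by the sink step, and a
  // t(A) * t(B) product left over after sinking is rewritten to t(B * A) by
  // the lift step, so at most one transpose of the product remains exposed.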
955480093f4SDimitry Andric for (BasicBlock &BB : Func) 956480093f4SDimitry Andric for (Instruction &Inst : BB) { 957480093f4SDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 958480093f4SDimitry Andric if (!II) 959480093f4SDimitry Andric continue; 960480093f4SDimitry Andric 961480093f4SDimitry Andric switch (II->getIntrinsicID()) { 962480093f4SDimitry Andric case Intrinsic::matrix_multiply: 963480093f4SDimitry Andric case Intrinsic::matrix_transpose: 9645ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 9655ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 966480093f4SDimitry Andric WorkList.push_back(&Inst); 967480093f4SDimitry Andric break; 968480093f4SDimitry Andric default: 969480093f4SDimitry Andric break; 970480093f4SDimitry Andric } 971480093f4SDimitry Andric } 972fe6060f1SDimitry Andric 973fe6060f1SDimitry Andric // Avoid unnecessary work if there are no matrix intrinsics in the function. 974fe6060f1SDimitry Andric if (WorkList.empty()) 975fe6060f1SDimitry Andric return false; 976fe6060f1SDimitry Andric 977480093f4SDimitry Andric // Propagate shapes until nothing changes any longer. 978480093f4SDimitry Andric while (!WorkList.empty()) { 979480093f4SDimitry Andric WorkList = propagateShapeForward(WorkList); 980480093f4SDimitry Andric WorkList = propagateShapeBackward(WorkList); 981480093f4SDimitry Andric } 982fe6060f1SDimitry Andric 983fe6060f1SDimitry Andric if (!isMinimal()) { 984fe6060f1SDimitry Andric optimizeTransposes(); 985bdd1243dSDimitry Andric if (PrintAfterTransposeOpt) { 986fe6060f1SDimitry Andric dbgs() << "Dump after matrix transpose optimization:\n"; 987bdd1243dSDimitry Andric Func.print(dbgs()); 988bdd1243dSDimitry Andric } 989480093f4SDimitry Andric } 990480093f4SDimitry Andric 991480093f4SDimitry Andric bool Changed = false; 9925ffd83dbSDimitry Andric SmallVector<CallInst *, 16> MaybeFusableInsts; 9935ffd83dbSDimitry Andric SmallVector<Instruction *, 16> MatrixInsts; 994*0fca6ea1SDimitry Andric SmallVector<IntrinsicInst *, 16> LifetimeEnds; 995480093f4SDimitry Andric 9965ffd83dbSDimitry Andric // First, collect all instructions with shape information and candidates for 9975ffd83dbSDimitry Andric // fusion (currently only matrix multiplies). 9985ffd83dbSDimitry Andric ReversePostOrderTraversal<Function *> RPOT(&Func); 9995ffd83dbSDimitry Andric for (auto *BB : RPOT) 10005ffd83dbSDimitry Andric for (Instruction &I : *BB) { 1001*0fca6ea1SDimitry Andric if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>())) 1002*0fca6ea1SDimitry Andric LifetimeEnds.push_back(cast<IntrinsicInst>(&I)); 10035ffd83dbSDimitry Andric if (ShapeMap.find(&I) == ShapeMap.end()) 10045ffd83dbSDimitry Andric continue; 10055ffd83dbSDimitry Andric if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 10065ffd83dbSDimitry Andric MaybeFusableInsts.push_back(cast<CallInst>(&I)); 10075ffd83dbSDimitry Andric MatrixInsts.push_back(&I); 10085ffd83dbSDimitry Andric } 10095ffd83dbSDimitry Andric 101006c3fb27SDimitry Andric // Second, try to lower any dot products 10115ffd83dbSDimitry Andric SmallPtrSet<Instruction *, 16> FusedInsts; 10125ffd83dbSDimitry Andric for (CallInst *CI : MaybeFusableInsts) 101306c3fb27SDimitry Andric lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI)); 101406c3fb27SDimitry Andric 101506c3fb27SDimitry Andric // Third, try to fuse candidates. 
101606c3fb27SDimitry Andric for (CallInst *CI : MaybeFusableInsts) 1017*0fca6ea1SDimitry Andric LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds); 101806c3fb27SDimitry Andric 10195ffd83dbSDimitry Andric Changed = !FusedInsts.empty(); 10205ffd83dbSDimitry Andric 102106c3fb27SDimitry Andric // Fourth, lower remaining instructions with shape information. 10225ffd83dbSDimitry Andric for (Instruction *Inst : MatrixInsts) { 10235ffd83dbSDimitry Andric if (FusedInsts.count(Inst)) 10245ffd83dbSDimitry Andric continue; 10255ffd83dbSDimitry Andric 10265ffd83dbSDimitry Andric IRBuilder<> Builder(Inst); 10275ffd83dbSDimitry Andric 10285ffd83dbSDimitry Andric if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 1029480093f4SDimitry Andric Changed |= VisitCallInst(CInst); 1030480093f4SDimitry Andric 1031480093f4SDimitry Andric Value *Op1; 1032480093f4SDimitry Andric Value *Op2; 10335ffd83dbSDimitry Andric if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 1034480093f4SDimitry Andric Changed |= VisitBinaryOperator(BinOp); 1035e8d8bef9SDimitry Andric if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 1036e8d8bef9SDimitry Andric Changed |= VisitUnaryOperator(UnOp); 10375ffd83dbSDimitry Andric if (match(Inst, m_Load(m_Value(Op1)))) 10385ffd83dbSDimitry Andric Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 10395ffd83dbSDimitry Andric else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 10405ffd83dbSDimitry Andric Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 1041480093f4SDimitry Andric } 10425ffd83dbSDimitry Andric 1043e8d8bef9SDimitry Andric if (ORE) { 1044e8d8bef9SDimitry Andric RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 10455ffd83dbSDimitry Andric RemarkGen.emitRemarks(); 1046e8d8bef9SDimitry Andric } 1047480093f4SDimitry Andric 1048fe6060f1SDimitry Andric // Delete the instructions backwards, as it has a reduced likelihood of 1049fe6060f1SDimitry Andric // having to update as many def-use and use-def chains. 1050fe6060f1SDimitry Andric // 1051fe6060f1SDimitry Andric // Because we add to ToRemove during fusion we can't guarantee that defs 105281ad6265SDimitry Andric // are before uses. Change uses to poison temporarily as these should get 1053fe6060f1SDimitry Andric // removed as well. 1054fe6060f1SDimitry Andric // 105581ad6265SDimitry Andric // For verification, we keep track of where we changed uses to poison in 105681ad6265SDimitry Andric // PoisonedInsts and then check that we in fact remove them. 105781ad6265SDimitry Andric SmallSet<Instruction *, 16> PoisonedInsts; 1058fe6060f1SDimitry Andric for (auto *Inst : reverse(ToRemove)) { 1059349cc55cSDimitry Andric for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 106081ad6265SDimitry Andric if (auto *Poisoned = dyn_cast<Instruction>(U.getUser())) 106181ad6265SDimitry Andric PoisonedInsts.insert(Poisoned); 106281ad6265SDimitry Andric U.set(PoisonValue::get(Inst->getType())); 1063fe6060f1SDimitry Andric } 1064480093f4SDimitry Andric Inst->eraseFromParent(); 106581ad6265SDimitry Andric PoisonedInsts.erase(Inst); 1066fe6060f1SDimitry Andric } 106781ad6265SDimitry Andric if (!PoisonedInsts.empty()) { 106881ad6265SDimitry Andric // If we didn't remove all poisoned instructions, it's a hard error. 
106981ad6265SDimitry Andric dbgs() << "Poisoned but present instructions:\n"; 107081ad6265SDimitry Andric for (auto *I : PoisonedInsts) 1071fe6060f1SDimitry Andric dbgs() << *I << "\n"; 107281ad6265SDimitry Andric llvm_unreachable("Poisoned but instruction not removed"); 1073fe6060f1SDimitry Andric } 1074480093f4SDimitry Andric 1075480093f4SDimitry Andric return Changed; 1076480093f4SDimitry Andric } 1077480093f4SDimitry Andric 1078480093f4SDimitry Andric /// Replace intrinsic calls 1079480093f4SDimitry Andric bool VisitCallInst(CallInst *Inst) { 1080480093f4SDimitry Andric if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 1081480093f4SDimitry Andric return false; 1082480093f4SDimitry Andric 1083480093f4SDimitry Andric switch (Inst->getCalledFunction()->getIntrinsicID()) { 1084480093f4SDimitry Andric case Intrinsic::matrix_multiply: 1085480093f4SDimitry Andric LowerMultiply(Inst); 1086480093f4SDimitry Andric break; 1087480093f4SDimitry Andric case Intrinsic::matrix_transpose: 1088480093f4SDimitry Andric LowerTranspose(Inst); 1089480093f4SDimitry Andric break; 10905ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 10915ffd83dbSDimitry Andric LowerColumnMajorLoad(Inst); 1092480093f4SDimitry Andric break; 10935ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 10945ffd83dbSDimitry Andric LowerColumnMajorStore(Inst); 1095480093f4SDimitry Andric break; 1096480093f4SDimitry Andric default: 1097480093f4SDimitry Andric return false; 1098480093f4SDimitry Andric } 1099480093f4SDimitry Andric return true; 1100480093f4SDimitry Andric } 1101480093f4SDimitry Andric 11025ffd83dbSDimitry Andric /// Compute the alignment for a column/row \p Idx with \p Stride between them. 11035ffd83dbSDimitry Andric /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 11045ffd83dbSDimitry Andric /// ConstantInt, reduce the initial alignment based on the byte offset. For 11055ffd83dbSDimitry Andric /// non-ConstantInt strides, return the common alignment of the initial 11065ffd83dbSDimitry Andric /// alignment and the element size in bytes. 11075ffd83dbSDimitry Andric Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 11085ffd83dbSDimitry Andric MaybeAlign A) const { 11095ffd83dbSDimitry Andric Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 11105ffd83dbSDimitry Andric if (Idx == 0) 11115ffd83dbSDimitry Andric return InitialAlign; 11125ffd83dbSDimitry Andric 11135ffd83dbSDimitry Andric TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 11145ffd83dbSDimitry Andric if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 11155ffd83dbSDimitry Andric uint64_t StrideInBytes = 11165ffd83dbSDimitry Andric ConstStride->getZExtValue() * ElementSizeInBits / 8; 11175ffd83dbSDimitry Andric return commonAlignment(InitialAlign, Idx * StrideInBytes); 11185ffd83dbSDimitry Andric } 11195ffd83dbSDimitry Andric return commonAlignment(InitialAlign, ElementSizeInBits / 8); 11205ffd83dbSDimitry Andric } 11215ffd83dbSDimitry Andric 11225ffd83dbSDimitry Andric /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 11235ffd83dbSDimitry Andric /// vectors. 
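/// For illustration (hypothetical values): loading a column-major 4 x 2
/// matrix with a stride of 6 emits two <4 x elt> vector loads, at element
/// offsets 0 * 6 and 1 * 6 from \p Ptr.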
11245ffd83dbSDimitry Andric MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 11255ffd83dbSDimitry Andric bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 1126fe6060f1SDimitry Andric auto *VType = cast<VectorType>(Ty); 1127fe6060f1SDimitry Andric Type *EltTy = VType->getElementType(); 1128fe6060f1SDimitry Andric Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); 11295f757f3fSDimitry Andric Value *EltPtr = Ptr; 11305ffd83dbSDimitry Andric MatrixTy Result; 11315ffd83dbSDimitry Andric for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 1132349cc55cSDimitry Andric Value *GEP = computeVectorAddr( 1133349cc55cSDimitry Andric EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), 1134349cc55cSDimitry Andric Stride, Shape.getStride(), EltTy, Builder); 11355ffd83dbSDimitry Andric Value *Vector = Builder.CreateAlignedLoad( 1136fe6060f1SDimitry Andric VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), 11375ffd83dbSDimitry Andric IsVolatile, "col.load"); 11385ffd83dbSDimitry Andric 11395ffd83dbSDimitry Andric Result.addVector(Vector); 11405ffd83dbSDimitry Andric } 11415ffd83dbSDimitry Andric return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 11425ffd83dbSDimitry Andric Result.getNumVectors()); 1143480093f4SDimitry Andric } 1144480093f4SDimitry Andric 11455ffd83dbSDimitry Andric /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 11465ffd83dbSDimitry Andric /// starting at \p MatrixPtr[I][J]. 11475ffd83dbSDimitry Andric MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 11485ffd83dbSDimitry Andric ShapeInfo MatrixShape, Value *I, Value *J, 11495ffd83dbSDimitry Andric ShapeInfo ResultShape, Type *EltTy, 11505ffd83dbSDimitry Andric IRBuilder<> &Builder) { 11515ffd83dbSDimitry Andric 11525ffd83dbSDimitry Andric Value *Offset = Builder.CreateAdd( 11535ffd83dbSDimitry Andric Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 11545ffd83dbSDimitry Andric 11555f757f3fSDimitry Andric Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); 11565ffd83dbSDimitry Andric auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 11575ffd83dbSDimitry Andric ResultShape.NumColumns); 11585ffd83dbSDimitry Andric 11595f757f3fSDimitry Andric return loadMatrix(TileTy, TileStart, Align, 11605ffd83dbSDimitry Andric Builder.getInt64(MatrixShape.getStride()), IsVolatile, 11615ffd83dbSDimitry Andric ResultShape, Builder); 1162480093f4SDimitry Andric } 1163480093f4SDimitry Andric 11645ffd83dbSDimitry Andric /// Lower a load instruction with shape information. 11655ffd83dbSDimitry Andric void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 11665ffd83dbSDimitry Andric bool IsVolatile, ShapeInfo Shape) { 11675ffd83dbSDimitry Andric IRBuilder<> Builder(Inst); 11685ffd83dbSDimitry Andric finalizeLowering(Inst, 11695ffd83dbSDimitry Andric loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 11705ffd83dbSDimitry Andric Shape, Builder), 11715ffd83dbSDimitry Andric Builder); 11725ffd83dbSDimitry Andric } 11735ffd83dbSDimitry Andric 11745ffd83dbSDimitry Andric /// Lowers llvm.matrix.column.major.load. 1175480093f4SDimitry Andric /// 1176480093f4SDimitry Andric /// The intrinsic loads a matrix from memory using a stride between columns. 
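/// A hypothetical call being lowered here (operands abbreviated):
///   %m = call <8 x float> @llvm.matrix.column.major.load.v8f32.i64(
///            ptr %p, i64 6, i1 false, i32 4, i32 2)
/// becomes two strided <4 x float> loads, as described for loadMatrix above.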
11775ffd83dbSDimitry Andric void LowerColumnMajorLoad(CallInst *Inst) { 11785ffd83dbSDimitry Andric assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 11795ffd83dbSDimitry Andric "Intrinsic only supports column-major layout!"); 1180480093f4SDimitry Andric Value *Ptr = Inst->getArgOperand(0); 1181480093f4SDimitry Andric Value *Stride = Inst->getArgOperand(1); 11825ffd83dbSDimitry Andric LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 11835ffd83dbSDimitry Andric cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 1184480093f4SDimitry Andric {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 1185480093f4SDimitry Andric } 1186480093f4SDimitry Andric 11875ffd83dbSDimitry Andric /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 11885ffd83dbSDimitry Andric /// MatrixPtr[I][J]. 11895ffd83dbSDimitry Andric void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 11905ffd83dbSDimitry Andric MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, 11915ffd83dbSDimitry Andric Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { 11925ffd83dbSDimitry Andric Value *Offset = Builder.CreateAdd( 11935ffd83dbSDimitry Andric Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 11945ffd83dbSDimitry Andric 11955f757f3fSDimitry Andric Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); 11965ffd83dbSDimitry Andric auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * 11975ffd83dbSDimitry Andric StoreVal.getNumColumns()); 11985ffd83dbSDimitry Andric 11995f757f3fSDimitry Andric storeMatrix(TileTy, StoreVal, TileStart, MAlign, 12005ffd83dbSDimitry Andric Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); 12015ffd83dbSDimitry Andric } 12025ffd83dbSDimitry Andric 12035ffd83dbSDimitry Andric /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 12045ffd83dbSDimitry Andric /// vectors. 12055ffd83dbSDimitry Andric MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, 12065ffd83dbSDimitry Andric MaybeAlign MAlign, Value *Stride, bool IsVolatile, 12075ffd83dbSDimitry Andric IRBuilder<> &Builder) { 12085ffd83dbSDimitry Andric auto VType = cast<VectorType>(Ty); 12095f757f3fSDimitry Andric Value *EltPtr = Ptr; 12105ffd83dbSDimitry Andric for (auto Vec : enumerate(StoreVal.vectors())) { 1211349cc55cSDimitry Andric Value *GEP = computeVectorAddr( 1212349cc55cSDimitry Andric EltPtr, 1213349cc55cSDimitry Andric Builder.getIntN(Stride->getType()->getScalarSizeInBits(), 1214349cc55cSDimitry Andric Vec.index()), 1215349cc55cSDimitry Andric Stride, StoreVal.getStride(), VType->getElementType(), Builder); 12165ffd83dbSDimitry Andric Builder.CreateAlignedStore(Vec.value(), GEP, 12175ffd83dbSDimitry Andric getAlignForIndex(Vec.index(), Stride, 12185ffd83dbSDimitry Andric VType->getElementType(), 12195ffd83dbSDimitry Andric MAlign), 12205ffd83dbSDimitry Andric IsVolatile); 12215ffd83dbSDimitry Andric } 12225ffd83dbSDimitry Andric return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * 12235ffd83dbSDimitry Andric StoreVal.getNumVectors()); 12245ffd83dbSDimitry Andric } 12255ffd83dbSDimitry Andric 12265ffd83dbSDimitry Andric /// Lower a store instruction with shape information. 
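/// The matrix operand is split into its column vectors, each written back
/// with a separate aligned (and, if requested, volatile) vector store; see
/// storeMatrix above.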
12275ffd83dbSDimitry Andric void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
12285ffd83dbSDimitry Andric Value *Stride, bool IsVolatile, ShapeInfo Shape) {
12295ffd83dbSDimitry Andric IRBuilder<> Builder(Inst);
12305ffd83dbSDimitry Andric auto StoreVal = getMatrix(Matrix, Shape, Builder);
12315ffd83dbSDimitry Andric finalizeLowering(Inst,
12325ffd83dbSDimitry Andric storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
12335ffd83dbSDimitry Andric IsVolatile, Builder),
12345ffd83dbSDimitry Andric Builder);
12355ffd83dbSDimitry Andric }
12365ffd83dbSDimitry Andric
12375ffd83dbSDimitry Andric /// Lowers llvm.matrix.column.major.store.
12385ffd83dbSDimitry Andric ///
12395ffd83dbSDimitry Andric /// The intrinsic stores a matrix back to memory using a stride between columns.
12405ffd83dbSDimitry Andric void LowerColumnMajorStore(CallInst *Inst) {
12415ffd83dbSDimitry Andric assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
12425ffd83dbSDimitry Andric "Intrinsic only supports column-major layout!");
12435ffd83dbSDimitry Andric Value *Matrix = Inst->getArgOperand(0);
12445ffd83dbSDimitry Andric Value *Ptr = Inst->getArgOperand(1);
12455ffd83dbSDimitry Andric Value *Stride = Inst->getArgOperand(2);
12465ffd83dbSDimitry Andric LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
12475ffd83dbSDimitry Andric cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
12485ffd83dbSDimitry Andric {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1249480093f4SDimitry Andric }
1250480093f4SDimitry Andric
1251480093f4SDimitry Andric // Set elements I..I+BlockNumElts-1 to Block
1252480093f4SDimitry Andric Value *insertVector(Value *Col, unsigned I, Value *Block,
12535ffd83dbSDimitry Andric IRBuilder<> &Builder) {
1254480093f4SDimitry Andric
1255480093f4SDimitry Andric // First, bring Block to the same size as Col
1256480093f4SDimitry Andric unsigned BlockNumElts =
12575ffd83dbSDimitry Andric cast<FixedVectorType>(Block->getType())->getNumElements();
12585ffd83dbSDimitry Andric unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1259480093f4SDimitry Andric assert(NumElts >= BlockNumElts && "Too few elements for current block");
1260480093f4SDimitry Andric
12615ffd83dbSDimitry Andric Block = Builder.CreateShuffleVector(
1262e8d8bef9SDimitry Andric Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1263480093f4SDimitry Andric
1264480093f4SDimitry Andric // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1265480093f4SDimitry Andric // 8, 4, 5, 6
12665ffd83dbSDimitry Andric SmallVector<int, 16> Mask;
1267480093f4SDimitry Andric unsigned i;
1268480093f4SDimitry Andric for (i = 0; i < I; i++)
12695ffd83dbSDimitry Andric Mask.push_back(i);
1270480093f4SDimitry Andric
12715ffd83dbSDimitry Andric unsigned VecNumElts =
12725ffd83dbSDimitry Andric cast<FixedVectorType>(Col->getType())->getNumElements();
1273480093f4SDimitry Andric for (; i < I + BlockNumElts; i++)
12745ffd83dbSDimitry Andric Mask.push_back(i - I + VecNumElts);
1275480093f4SDimitry Andric
1276480093f4SDimitry Andric for (; i < VecNumElts; i++)
12775ffd83dbSDimitry Andric Mask.push_back(i);
1278480093f4SDimitry Andric
12795ffd83dbSDimitry Andric return Builder.CreateShuffleVector(Col, Block, Mask);
1280480093f4SDimitry Andric }
1281480093f4SDimitry Andric
1282480093f4SDimitry Andric Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
12835ffd83dbSDimitry Andric IRBuilder<> &Builder, bool AllowContraction,
12845ffd83dbSDimitry Andric unsigned
&NumComputeOps) {
12855ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType());
1286480093f4SDimitry Andric if (!Sum)
1287480093f4SDimitry Andric return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1288480093f4SDimitry Andric
1289480093f4SDimitry Andric if (UseFPOp) {
1290480093f4SDimitry Andric if (AllowContraction) {
1291480093f4SDimitry Andric // Use fmuladd for floating point operations and let the backend decide
1292480093f4SDimitry Andric // if that's profitable.
12935ffd83dbSDimitry Andric Function *FMulAdd = Intrinsic::getDeclaration(
1294480093f4SDimitry Andric Func.getParent(), Intrinsic::fmuladd, A->getType());
1295480093f4SDimitry Andric return Builder.CreateCall(FMulAdd, {A, B, Sum});
1296480093f4SDimitry Andric }
12975ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType());
1298480093f4SDimitry Andric Value *Mul = Builder.CreateFMul(A, B);
1299480093f4SDimitry Andric return Builder.CreateFAdd(Sum, Mul);
1300480093f4SDimitry Andric }
1301480093f4SDimitry Andric
13025ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType());
1303480093f4SDimitry Andric Value *Mul = Builder.CreateMul(A, B);
1304480093f4SDimitry Andric return Builder.CreateAdd(Sum, Mul);
1305480093f4SDimitry Andric }
1306480093f4SDimitry Andric
1307480093f4SDimitry Andric /// Cache \p Matrix as the result of \p Inst and update the uses of \p Inst. For
1308fe6060f1SDimitry Andric /// users with shape information, there's nothing to do: they will use the
1309480093f4SDimitry Andric /// cached value when they are lowered. For other users, \p Matrix is
1310480093f4SDimitry Andric /// flattened and the uses are updated to use it. Also marks \p Inst for
1311480093f4SDimitry Andric /// deletion.
13125ffd83dbSDimitry Andric void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1313480093f4SDimitry Andric IRBuilder<> &Builder) {
1314fe6060f1SDimitry Andric auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1315fe6060f1SDimitry Andric (void)inserted;
1316fe6060f1SDimitry Andric assert(inserted.second && "multiple matrix lowering mapping");
1317480093f4SDimitry Andric
1318480093f4SDimitry Andric ToRemove.push_back(Inst);
1319480093f4SDimitry Andric Value *Flattened = nullptr;
1320fe6060f1SDimitry Andric for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1321480093f4SDimitry Andric if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1322480093f4SDimitry Andric if (!Flattened)
1323480093f4SDimitry Andric Flattened = Matrix.embedInVector(Builder);
1324480093f4SDimitry Andric U.set(Flattened);
1325480093f4SDimitry Andric }
1326480093f4SDimitry Andric }
1327480093f4SDimitry Andric }
1328480093f4SDimitry Andric
132906c3fb27SDimitry Andric /// Special case for MatMul lowering. Prevents scalar loads of row-major
133006c3fb27SDimitry Andric /// vectors. Lowers to a vector reduction add instead of a sequential add if
133106c3fb27SDimitry Andric /// reassociation is enabled.
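/// A sketch of the rewrite (hypothetical names, 1x3 * 3x1 case with
/// reassociation allowed):
///   %m = call <1 x double> @llvm.matrix.multiply.v1f64.v3f64.v3f64(
///            <3 x double> %lhs, <3 x double> %rhs, i32 1, i32 3, i32 1)
/// is lowered to an elementwise multiply followed by a reduction:
///   %mul = fmul <3 x double> %lhs, %rhs
///   %red = call double @llvm.vector.reduce.fadd.v3f64(double 0.0, <3 x double> %mul)
/// with %red inserted into a <1 x double> that replaces %m.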
133206c3fb27SDimitry Andric void lowerDotProduct(CallInst *MatMul,
133306c3fb27SDimitry Andric SmallPtrSet<Instruction *, 16> &FusedInsts,
133406c3fb27SDimitry Andric FastMathFlags FMF) {
133506c3fb27SDimitry Andric if (FusedInsts.contains(MatMul) ||
133606c3fb27SDimitry Andric MatrixLayout != MatrixLayoutTy::ColumnMajor)
133706c3fb27SDimitry Andric return;
133806c3fb27SDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
133906c3fb27SDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
134006c3fb27SDimitry Andric
134106c3fb27SDimitry Andric if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
134206c3fb27SDimitry Andric return;
134306c3fb27SDimitry Andric
134406c3fb27SDimitry Andric Value *LHS = MatMul->getArgOperand(0);
134506c3fb27SDimitry Andric Value *RHS = MatMul->getArgOperand(1);
134606c3fb27SDimitry Andric
134706c3fb27SDimitry Andric Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
134806c3fb27SDimitry Andric bool IsIntVec = ElementType->isIntegerTy();
134906c3fb27SDimitry Andric
135006c3fb27SDimitry Andric // Floating point reductions require reassociation.
135106c3fb27SDimitry Andric if (!IsIntVec && !FMF.allowReassoc())
135206c3fb27SDimitry Andric return;
135306c3fb27SDimitry Andric
1354*0fca6ea1SDimitry Andric auto CanBeFlattened = [](Value *Op) {
1355*0fca6ea1SDimitry Andric if (match(Op, m_BinOp()))
135606c3fb27SDimitry Andric return true;
135706c3fb27SDimitry Andric return match(
135806c3fb27SDimitry Andric Op, m_OneUse(m_CombineOr(
135906c3fb27SDimitry Andric m_Load(m_Value()),
136006c3fb27SDimitry Andric m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
136106c3fb27SDimitry Andric m_Intrinsic<Intrinsic::matrix_column_major_load>(
136206c3fb27SDimitry Andric m_Value(), m_SpecificInt(1))))));
136306c3fb27SDimitry Andric };
136406c3fb27SDimitry Andric // Returns the cost benefit of using \p Op with the dot product lowering. If
136506c3fb27SDimitry Andric // the returned cost is < 0, the argument is cheaper to use in the
136606c3fb27SDimitry Andric // dot-product lowering.
136706c3fb27SDimitry Andric auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1368*0fca6ea1SDimitry Andric if (ShapeMap.find(Op) == ShapeMap.end())
1369*0fca6ea1SDimitry Andric return InstructionCost::getInvalid();
1370*0fca6ea1SDimitry Andric
137106c3fb27SDimitry Andric if (!isa<Instruction>(Op))
137206c3fb27SDimitry Andric return InstructionCost(0);
137306c3fb27SDimitry Andric
137406c3fb27SDimitry Andric FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
137506c3fb27SDimitry Andric Type *EltTy = VecTy->getElementType();
137606c3fb27SDimitry Andric
137706c3fb27SDimitry Andric if (!CanBeFlattened(Op)) {
137806c3fb27SDimitry Andric InstructionCost EmbedCost(0);
137906c3fb27SDimitry Andric // Roughly estimate the cost for embedding the columns into a vector.
138006c3fb27SDimitry Andric for (unsigned I = 1; I < N; ++I) 1381*0fca6ea1SDimitry Andric EmbedCost += 138206c3fb27SDimitry Andric TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), 138306c3fb27SDimitry Andric std::nullopt, TTI::TCK_RecipThroughput); 138406c3fb27SDimitry Andric return EmbedCost; 138506c3fb27SDimitry Andric } 138606c3fb27SDimitry Andric 138706c3fb27SDimitry Andric if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { 138806c3fb27SDimitry Andric InstructionCost OriginalCost = 138906c3fb27SDimitry Andric TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(), 139006c3fb27SDimitry Andric EltTy) * 139106c3fb27SDimitry Andric N; 139206c3fb27SDimitry Andric InstructionCost NewCost = TTI.getArithmeticInstrCost( 139306c3fb27SDimitry Andric cast<Instruction>(Op)->getOpcode(), VecTy); 139406c3fb27SDimitry Andric return NewCost - OriginalCost; 139506c3fb27SDimitry Andric } 139606c3fb27SDimitry Andric 139706c3fb27SDimitry Andric if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) { 139806c3fb27SDimitry Andric // The transpose can be skipped for the dot product lowering, roughly 139906c3fb27SDimitry Andric // estimate the savings as the cost of embedding the columns in a 140006c3fb27SDimitry Andric // vector. 140106c3fb27SDimitry Andric InstructionCost EmbedCost(0); 140206c3fb27SDimitry Andric for (unsigned I = 1; I < N; ++I) 1403*0fca6ea1SDimitry Andric EmbedCost -= 140406c3fb27SDimitry Andric TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), 140506c3fb27SDimitry Andric std::nullopt, TTI::TCK_RecipThroughput); 140606c3fb27SDimitry Andric return EmbedCost; 140706c3fb27SDimitry Andric } 140806c3fb27SDimitry Andric 140906c3fb27SDimitry Andric // Costs for loads. 141006c3fb27SDimitry Andric if (N == 1) 141106c3fb27SDimitry Andric return InstructionCost(0); 141206c3fb27SDimitry Andric 141306c3fb27SDimitry Andric return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) - 141406c3fb27SDimitry Andric N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0); 141506c3fb27SDimitry Andric }; 1416*0fca6ea1SDimitry Andric 1417*0fca6ea1SDimitry Andric // Iterate over LHS and operations feeding LHS and check if it is profitable 1418*0fca6ea1SDimitry Andric // to flatten the visited ops. For each op, we compute the difference 1419*0fca6ea1SDimitry Andric // between the flattened and matrix versions. 
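// Note (explanatory): GetCostForArg returns a cost *delta* between the
// flattened and the matrix form, so the check below only accepts operations
// whose delta is negative, i.e. ops that are strictly cheaper when
// flattened. Invalid costs compare as greater than any valid cost and
// therefore never pass the test.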
1420*0fca6ea1SDimitry Andric SmallPtrSet<Value *, 4> Seen; 1421*0fca6ea1SDimitry Andric SmallVector<Value *> WorkList; 1422*0fca6ea1SDimitry Andric SmallVector<Value *> ToFlatten; 1423*0fca6ea1SDimitry Andric WorkList.push_back(LHS); 1424*0fca6ea1SDimitry Andric InstructionCost LHSCost(0); 1425*0fca6ea1SDimitry Andric while (!WorkList.empty()) { 1426*0fca6ea1SDimitry Andric Value *Op = WorkList.pop_back_val(); 1427*0fca6ea1SDimitry Andric if (!Seen.insert(Op).second) 1428*0fca6ea1SDimitry Andric continue; 1429*0fca6ea1SDimitry Andric 1430*0fca6ea1SDimitry Andric InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns); 1431*0fca6ea1SDimitry Andric if (OpCost + LHSCost >= LHSCost) 1432*0fca6ea1SDimitry Andric continue; 1433*0fca6ea1SDimitry Andric 1434*0fca6ea1SDimitry Andric LHSCost += OpCost; 1435*0fca6ea1SDimitry Andric ToFlatten.push_back(Op); 1436*0fca6ea1SDimitry Andric if (auto *I = dyn_cast<Instruction>(Op)) 1437*0fca6ea1SDimitry Andric WorkList.append(I->op_begin(), I->op_end()); 1438*0fca6ea1SDimitry Andric } 143906c3fb27SDimitry Andric 144006c3fb27SDimitry Andric // We compare the costs of a vector.reduce.add to sequential add. 144106c3fb27SDimitry Andric int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd; 144206c3fb27SDimitry Andric int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul; 144306c3fb27SDimitry Andric InstructionCost ReductionCost = 144406c3fb27SDimitry Andric TTI.getArithmeticReductionCost( 144506c3fb27SDimitry Andric AddOpCode, cast<VectorType>(LHS->getType()), 144606c3fb27SDimitry Andric IsIntVec ? std::nullopt : std::optional(FMF)) + 144706c3fb27SDimitry Andric TTI.getArithmeticInstrCost(MulOpCode, LHS->getType()); 144806c3fb27SDimitry Andric InstructionCost SequentialAddCost = 144906c3fb27SDimitry Andric TTI.getArithmeticInstrCost(AddOpCode, ElementType) * 145006c3fb27SDimitry Andric (LShape.NumColumns - 1) + 145106c3fb27SDimitry Andric TTI.getArithmeticInstrCost(MulOpCode, ElementType) * 145206c3fb27SDimitry Andric (LShape.NumColumns); 145306c3fb27SDimitry Andric if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0)) 145406c3fb27SDimitry Andric return; 145506c3fb27SDimitry Andric 145606c3fb27SDimitry Andric FusedInsts.insert(MatMul); 145706c3fb27SDimitry Andric IRBuilder<> Builder(MatMul); 145806c3fb27SDimitry Andric auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened, 1459*0fca6ea1SDimitry Andric this](Value *Op) { 146006c3fb27SDimitry Andric // Matmul must be the only user of loads because we don't use LowerLoad 146106c3fb27SDimitry Andric // for row vectors (LowerLoad results in scalar loads and shufflevectors 146206c3fb27SDimitry Andric // instead of single vector load). 
146306c3fb27SDimitry Andric if (!CanBeFlattened(Op))
1464*0fca6ea1SDimitry Andric return;
146506c3fb27SDimitry Andric
146606c3fb27SDimitry Andric if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
146706c3fb27SDimitry Andric ShapeMap[Op] = ShapeMap[Op].t();
1468*0fca6ea1SDimitry Andric return;
146906c3fb27SDimitry Andric }
147006c3fb27SDimitry Andric
147106c3fb27SDimitry Andric FusedInsts.insert(cast<Instruction>(Op));
147206c3fb27SDimitry Andric // If the vector uses the builtin column-major load, lower it to a LoadInst.
147306c3fb27SDimitry Andric Value *Arg;
147406c3fb27SDimitry Andric if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
147506c3fb27SDimitry Andric m_Value(Arg)))) {
147606c3fb27SDimitry Andric auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
147706c3fb27SDimitry Andric Op->replaceAllUsesWith(NewLoad);
147806c3fb27SDimitry Andric cast<Instruction>(Op)->eraseFromParent();
1479*0fca6ea1SDimitry Andric return;
148006c3fb27SDimitry Andric } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
148106c3fb27SDimitry Andric m_Value(Arg)))) {
148206c3fb27SDimitry Andric ToRemove.push_back(cast<Instruction>(Op));
1483*0fca6ea1SDimitry Andric Op->replaceAllUsesWith(Arg);
1484*0fca6ea1SDimitry Andric return;
148506c3fb27SDimitry Andric }
148606c3fb27SDimitry Andric };
1487*0fca6ea1SDimitry Andric
1488*0fca6ea1SDimitry Andric for (auto *V : ToFlatten)
1489*0fca6ea1SDimitry Andric FlattenArg(V);
1490*0fca6ea1SDimitry Andric
1491*0fca6ea1SDimitry Andric LHS = MatMul->getArgOperand(0);
149206c3fb27SDimitry Andric
149306c3fb27SDimitry Andric // Insert mul/fmul and llvm.vector.reduce.fadd
149406c3fb27SDimitry Andric Value *Mul =
149506c3fb27SDimitry Andric IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
149606c3fb27SDimitry Andric
149706c3fb27SDimitry Andric Value *Result;
149806c3fb27SDimitry Andric if (IsIntVec)
149906c3fb27SDimitry Andric Result = Builder.CreateAddReduce(Mul);
150006c3fb27SDimitry Andric else {
150106c3fb27SDimitry Andric Result = Builder.CreateFAddReduce(
150206c3fb27SDimitry Andric ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
150306c3fb27SDimitry Andric 0.0),
150406c3fb27SDimitry Andric Mul);
150506c3fb27SDimitry Andric cast<Instruction>(Result)->setFastMathFlags(FMF);
150606c3fb27SDimitry Andric }
150706c3fb27SDimitry Andric
150806c3fb27SDimitry Andric // Pack the scalar back into a matrix and then replace the matmul inst.
150906c3fb27SDimitry Andric Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
151006c3fb27SDimitry Andric Result, uint64_t(0));
151106c3fb27SDimitry Andric MatMul->replaceAllUsesWith(Result);
151206c3fb27SDimitry Andric FusedInsts.insert(MatMul);
151306c3fb27SDimitry Andric ToRemove.push_back(MatMul);
151406c3fb27SDimitry Andric }
151506c3fb27SDimitry Andric
15165ffd83dbSDimitry Andric /// Compute \p Result += \p A * \p B for input matrices with left-associating
15175ffd83dbSDimitry Andric /// addition.
1518fe6060f1SDimitry Andric ///
1519fe6060f1SDimitry Andric /// We can fold a transpose into the operand that is used to extract scalars.
1520fe6060f1SDimitry Andric /// This is the first operand with row-major and the second with
1521fe6060f1SDimitry Andric /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
1522fe6060f1SDimitry Andric /// operand is transposed.
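/// For example (illustrative): for a column-major multiply, the inner loop
/// below walks K = 0..M-1, multiplying <BlockSize>-wide slices of A's columns
/// with splats of scalars extracted from B and accumulating the partial sums
/// into Result's columns, so the adds stay vectorized without requiring
/// reassociation.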
15235ffd83dbSDimitry Andric void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1524fe6060f1SDimitry Andric const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled, 1525fe6060f1SDimitry Andric bool IsScalarMatrixTransposed, FastMathFlags FMF) { 15265ffd83dbSDimitry Andric const unsigned VF = std::max<unsigned>( 1527fe6060f1SDimitry Andric TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1528bdd1243dSDimitry Andric .getFixedValue() / 1529bdd1243dSDimitry Andric Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(), 15305ffd83dbSDimitry Andric 1U); 15315ffd83dbSDimitry Andric unsigned R = Result.getNumRows(); 15325ffd83dbSDimitry Andric unsigned C = Result.getNumColumns(); 15335ffd83dbSDimitry Andric unsigned M = A.getNumColumns(); 15345ffd83dbSDimitry Andric 15355ffd83dbSDimitry Andric bool IsFP = Result.getElementType()->isFloatingPointTy(); 15365ffd83dbSDimitry Andric assert(A.isColumnMajor() == B.isColumnMajor() && 15375ffd83dbSDimitry Andric Result.isColumnMajor() == A.isColumnMajor() && 15385ffd83dbSDimitry Andric "operands must agree on matrix layout"); 15395ffd83dbSDimitry Andric unsigned NumComputeOps = 0; 1540fe6060f1SDimitry Andric 1541fe6060f1SDimitry Andric Builder.setFastMathFlags(FMF); 1542fe6060f1SDimitry Andric 15435ffd83dbSDimitry Andric if (A.isColumnMajor()) { 15445ffd83dbSDimitry Andric // Multiply columns from the first operand with scalars from the second 15455ffd83dbSDimitry Andric // operand. Then move along the K axes and accumulate the columns. With 15465ffd83dbSDimitry Andric // this the adds can be vectorized without reassociation. 15475ffd83dbSDimitry Andric for (unsigned J = 0; J < C; ++J) { 15485ffd83dbSDimitry Andric unsigned BlockSize = VF; 15495ffd83dbSDimitry Andric // If Result is zero, we don't need to accumulate in the K==0 iteration. 15505ffd83dbSDimitry Andric bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 15515ffd83dbSDimitry Andric 15525ffd83dbSDimitry Andric for (unsigned I = 0; I < R; I += BlockSize) { 15535ffd83dbSDimitry Andric // Gradually lower the vectorization factor to cover the remainder. 15545ffd83dbSDimitry Andric while (I + BlockSize > R) 15555ffd83dbSDimitry Andric BlockSize /= 2; 15565ffd83dbSDimitry Andric 1557fe6060f1SDimitry Andric Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder) 15585ffd83dbSDimitry Andric : nullptr; 15595ffd83dbSDimitry Andric for (unsigned K = 0; K < M; ++K) { 15605ffd83dbSDimitry Andric Value *L = A.extractVector(I, K, BlockSize, Builder); 1561fe6060f1SDimitry Andric Value *RH = Builder.CreateExtractElement( 1562fe6060f1SDimitry Andric B.getColumn(IsScalarMatrixTransposed ? K : J), 1563fe6060f1SDimitry Andric IsScalarMatrixTransposed ? J : K); 15645ffd83dbSDimitry Andric Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1565fe6060f1SDimitry Andric Sum = 1566fe6060f1SDimitry Andric createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1567fe6060f1SDimitry Andric IsFP, Builder, FMF.allowContract(), NumComputeOps); 15685ffd83dbSDimitry Andric } 15695ffd83dbSDimitry Andric Result.setVector(J, 15705ffd83dbSDimitry Andric insertVector(Result.getVector(J), I, Sum, Builder)); 15715ffd83dbSDimitry Andric } 15725ffd83dbSDimitry Andric } 15735ffd83dbSDimitry Andric } else { 15745ffd83dbSDimitry Andric // Multiply rows from the second operand with scalars from the first 15755ffd83dbSDimitry Andric // operand. Then move along the K axes and accumulate the rows. 
With this
15765ffd83dbSDimitry Andric // the adds can be vectorized without reassociation.
15775ffd83dbSDimitry Andric for (unsigned I = 0; I < R; ++I) {
15785ffd83dbSDimitry Andric unsigned BlockSize = VF;
15795ffd83dbSDimitry Andric bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
15805ffd83dbSDimitry Andric for (unsigned J = 0; J < C; J += BlockSize) {
15815ffd83dbSDimitry Andric // Gradually lower the vectorization factor to cover the remainder.
15825ffd83dbSDimitry Andric while (J + BlockSize > C)
15835ffd83dbSDimitry Andric BlockSize /= 2;
15845ffd83dbSDimitry Andric
15855ffd83dbSDimitry Andric Value *Sum = nullptr;
15865ffd83dbSDimitry Andric for (unsigned K = 0; K < M; ++K) {
15875ffd83dbSDimitry Andric Value *R = B.extractVector(K, J, BlockSize, Builder);
1588fe6060f1SDimitry Andric Value *LH = Builder.CreateExtractElement(
1589fe6060f1SDimitry Andric A.getVector(IsScalarMatrixTransposed ? K : I),
1590fe6060f1SDimitry Andric IsScalarMatrixTransposed ? I : K);
15915ffd83dbSDimitry Andric Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1592fe6060f1SDimitry Andric Sum =
1593fe6060f1SDimitry Andric createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1594fe6060f1SDimitry Andric IsFP, Builder, FMF.allowContract(), NumComputeOps);
15955ffd83dbSDimitry Andric }
15965ffd83dbSDimitry Andric Result.setVector(I,
15975ffd83dbSDimitry Andric insertVector(Result.getVector(I), J, Sum, Builder));
15985ffd83dbSDimitry Andric }
15995ffd83dbSDimitry Andric }
16005ffd83dbSDimitry Andric }
16015ffd83dbSDimitry Andric Result.addNumComputeOps(NumComputeOps);
16025ffd83dbSDimitry Andric }
16035ffd83dbSDimitry Andric
16045ffd83dbSDimitry Andric /// Ensure that the memory in \p Load does not alias \p Store by potentially
16055ffd83dbSDimitry Andric /// copying it to a new location. Returns either the new location or, if no
16065ffd83dbSDimitry Andric /// copy is needed, the original one.
16075ffd83dbSDimitry Andric Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
16085ffd83dbSDimitry Andric CallInst *MatMul) {
16095ffd83dbSDimitry Andric MemoryLocation StoreLoc = MemoryLocation::get(Store);
16105ffd83dbSDimitry Andric MemoryLocation LoadLoc = MemoryLocation::get(Load);
16115ffd83dbSDimitry Andric
16125ffd83dbSDimitry Andric // If we can statically determine noalias we're good.
1613fe6060f1SDimitry Andric if (AA->isNoAlias(LoadLoc, StoreLoc))
16145ffd83dbSDimitry Andric return Load->getPointerOperand();
16155ffd83dbSDimitry Andric
16165ffd83dbSDimitry Andric // Create code to check if the memory locations of the Load and Store
16175ffd83dbSDimitry Andric // overlap and if they do, copy Load's operand to a new buffer.
16185ffd83dbSDimitry Andric
16195ffd83dbSDimitry Andric // First, create new blocks for the 2nd part of the check and the copy.
16205ffd83dbSDimitry Andric BasicBlock *Check0 = MatMul->getParent();
16215ffd83dbSDimitry Andric // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
16225ffd83dbSDimitry Andric // DT. Manually collect dominator tree updates, to avoid unnecessary work,
16235ffd83dbSDimitry Andric // as we adjust Check0 and Check1's branches.
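// The resulting CFG is a double diamond (sketch): Check0 branches to Check1
// or Fusion, Check1 branches to Copy or Fusion, and Copy falls through to
// Fusion, where a three-way PHI picks either the load's original pointer or
// the freshly copied alloca.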
16245ffd83dbSDimitry Andric SmallVector<DominatorTree::UpdateType, 4> DTUpdates; 16255ffd83dbSDimitry Andric for (BasicBlock *Succ : successors(Check0)) 1626e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Delete, Check0, Succ}); 16275ffd83dbSDimitry Andric 1628e8d8bef9SDimitry Andric BasicBlock *Check1 = 1629e8d8bef9SDimitry Andric SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 16305ffd83dbSDimitry Andric nullptr, "alias_cont"); 16315ffd83dbSDimitry Andric BasicBlock *Copy = 1632e8d8bef9SDimitry Andric SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1633e8d8bef9SDimitry Andric nullptr, "copy"); 1634e8d8bef9SDimitry Andric BasicBlock *Fusion = 1635e8d8bef9SDimitry Andric SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 16365ffd83dbSDimitry Andric nullptr, "no_alias"); 16375ffd83dbSDimitry Andric 16385ffd83dbSDimitry Andric // Check if the loaded memory location begins before the end of the store 16395ffd83dbSDimitry Andric // location. If the condition holds, they might overlap, otherwise they are 16405ffd83dbSDimitry Andric // guaranteed to not overlap. 16415ffd83dbSDimitry Andric IRBuilder<> Builder(MatMul); 16425ffd83dbSDimitry Andric Check0->getTerminator()->eraseFromParent(); 16435ffd83dbSDimitry Andric Builder.SetInsertPoint(Check0); 1644*0fca6ea1SDimitry Andric Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout()); 16455ffd83dbSDimitry Andric Value *StoreBegin = Builder.CreatePtrToInt( 16465ffd83dbSDimitry Andric const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); 16475ffd83dbSDimitry Andric Value *StoreEnd = Builder.CreateAdd( 16485ffd83dbSDimitry Andric StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), 16495ffd83dbSDimitry Andric "store.end", true, true); 16505ffd83dbSDimitry Andric Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), 16515ffd83dbSDimitry Andric IntPtrTy, "load.begin"); 16525ffd83dbSDimitry Andric Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, 16535ffd83dbSDimitry Andric Fusion); 16545ffd83dbSDimitry Andric 16555ffd83dbSDimitry Andric // Check if the store begins before the end of the load location. If the 16565ffd83dbSDimitry Andric // condition holds, they alias, otherwise they are guaranteed to not 16575ffd83dbSDimitry Andric // overlap. 16585ffd83dbSDimitry Andric Check1->getTerminator()->eraseFromParent(); 16595ffd83dbSDimitry Andric Builder.SetInsertPoint(Check1, Check1->begin()); 16605ffd83dbSDimitry Andric Value *LoadEnd = Builder.CreateAdd( 16615ffd83dbSDimitry Andric LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 16625ffd83dbSDimitry Andric "load.end", true, true); 16635ffd83dbSDimitry Andric Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 16645ffd83dbSDimitry Andric Fusion); 16655ffd83dbSDimitry Andric 16665ffd83dbSDimitry Andric // Copy load operand to new alloca. 16675ffd83dbSDimitry Andric Builder.SetInsertPoint(Copy, Copy->begin()); 16681fd87a68SDimitry Andric auto *VT = cast<FixedVectorType>(Load->getType()); 16691fd87a68SDimitry Andric // Use an array type for the alloca, to avoid potentially huge alignment 16701fd87a68SDimitry Andric // requirements for large vector types. 
16711fd87a68SDimitry Andric auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements()); 16721fd87a68SDimitry Andric AllocaInst *Alloca = 16731fd87a68SDimitry Andric Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace()); 16741fd87a68SDimitry Andric 167506c3fb27SDimitry Andric Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(), 16761fd87a68SDimitry Andric Load->getAlign(), LoadLoc.Size.getValue()); 16775ffd83dbSDimitry Andric Builder.SetInsertPoint(Fusion, Fusion->begin()); 16785ffd83dbSDimitry Andric PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 16795ffd83dbSDimitry Andric PHI->addIncoming(Load->getPointerOperand(), Check0); 16805ffd83dbSDimitry Andric PHI->addIncoming(Load->getPointerOperand(), Check1); 168106c3fb27SDimitry Andric PHI->addIncoming(Alloca, Copy); 16825ffd83dbSDimitry Andric 16835ffd83dbSDimitry Andric // Adjust DT. 1684e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Insert, Check0, Check1}); 1685e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Insert, Check0, Fusion}); 1686e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Insert, Check1, Copy}); 1687e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Insert, Check1, Fusion}); 1688e8d8bef9SDimitry Andric DT->applyUpdates(DTUpdates); 16895ffd83dbSDimitry Andric return PHI; 16905ffd83dbSDimitry Andric } 16915ffd83dbSDimitry Andric 16925ffd83dbSDimitry Andric bool isFusionProfitable(CallInst *MatMul) { 16935ffd83dbSDimitry Andric if (ForceFusion) 16945ffd83dbSDimitry Andric return true; 16955ffd83dbSDimitry Andric 16965ffd83dbSDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 16975ffd83dbSDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 16985ffd83dbSDimitry Andric 16995ffd83dbSDimitry Andric const unsigned R = LShape.NumRows; 17005ffd83dbSDimitry Andric const unsigned C = RShape.NumColumns; 17015ffd83dbSDimitry Andric const unsigned M = LShape.NumColumns; 17025ffd83dbSDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 17035ffd83dbSDimitry Andric 1704fe6060f1SDimitry Andric const unsigned VF = std::max<unsigned>( 1705fe6060f1SDimitry Andric TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1706bdd1243dSDimitry Andric .getFixedValue() / 1707bdd1243dSDimitry Andric EltType->getPrimitiveSizeInBits().getFixedValue(), 17085ffd83dbSDimitry Andric 1U); 17095ffd83dbSDimitry Andric 17105ffd83dbSDimitry Andric // Cost model for tiling 17115ffd83dbSDimitry Andric // 17125ffd83dbSDimitry Andric // For tiling to be beneficial, we need reuse either along the R or 17135ffd83dbSDimitry Andric // the C axis. We vectorize along the R axis so that means at least 17145ffd83dbSDimitry Andric // 3 elements. 17155ffd83dbSDimitry Andric // TODO: Also consider cost of copying if operands alias. 17165ffd83dbSDimitry Andric if (R <= VF && C == 1) 17175ffd83dbSDimitry Andric return false; 17185ffd83dbSDimitry Andric // Then we need enough elements to exceed the number of vector 17195ffd83dbSDimitry Andric // registers we have. Note that this is an oversimplification since 17205ffd83dbSDimitry Andric // fusing also takes some extra loads which may exceed the number of 17215ffd83dbSDimitry Andric // reloads necessary. 
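// Worked example (illustrative numbers): with R = C = M = 16 and VF = 4,
// Op0Regs = (16 + 3) / 4 * 16 = 64 and Op1Regs = 64, so fusion is considered
// profitable only if 128 exceeds the target's vector register count.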
17225ffd83dbSDimitry Andric unsigned Op0Regs = (R + VF - 1) / VF * M;
17235ffd83dbSDimitry Andric unsigned Op1Regs = (M + VF - 1) / VF * C;
172404eeddc0SDimitry Andric return Op0Regs + Op1Regs >
172504eeddc0SDimitry Andric TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
17265ffd83dbSDimitry Andric }
17275ffd83dbSDimitry Andric
17285ffd83dbSDimitry Andric MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
17295ffd83dbSDimitry Andric MatrixTy Res;
17305ffd83dbSDimitry Andric auto *ColumnType = FixedVectorType::get(EltType, R);
17315ffd83dbSDimitry Andric for (unsigned I = 0; I < C; ++I)
17325ffd83dbSDimitry Andric Res.addVector(ConstantAggregateZero::get(ColumnType));
17335ffd83dbSDimitry Andric return Res;
17345ffd83dbSDimitry Andric }
17355ffd83dbSDimitry Andric
1736e8d8bef9SDimitry Andric void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1737fe6060f1SDimitry Andric Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1738e8d8bef9SDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1739e8d8bef9SDimitry Andric
1740e8d8bef9SDimitry Andric // Create the main tiling loop nest.
1741e8d8bef9SDimitry Andric TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1742e8d8bef9SDimitry Andric DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1743e8d8bef9SDimitry Andric Instruction *InsertI = cast<Instruction>(MatMul);
1744e8d8bef9SDimitry Andric BasicBlock *Start = InsertI->getParent();
1745e8d8bef9SDimitry Andric BasicBlock *End =
1746e8d8bef9SDimitry Andric SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1747e8d8bef9SDimitry Andric IRBuilder<> Builder(MatMul);
1748e8d8bef9SDimitry Andric BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1749e8d8bef9SDimitry Andric
1750e8d8bef9SDimitry Andric Type *TileVecTy =
1751e8d8bef9SDimitry Andric FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1752e8d8bef9SDimitry Andric MatrixTy TileResult;
1753e8d8bef9SDimitry Andric // Insert in the inner loop header.
1754972a253aSDimitry Andric Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1755e8d8bef9SDimitry Andric // Create PHI nodes for the result columns to accumulate across iterations.
1756e8d8bef9SDimitry Andric SmallVector<PHINode *, 4> ColumnPhis;
1757e8d8bef9SDimitry Andric for (unsigned I = 0; I < TileSize; I++) {
1758e8d8bef9SDimitry Andric auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1759e8d8bef9SDimitry Andric Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1760972a253aSDimitry Andric TI.RowLoop.Header->getSingleSuccessor());
1761e8d8bef9SDimitry Andric TileResult.addVector(Phi);
1762e8d8bef9SDimitry Andric ColumnPhis.push_back(Phi);
1763e8d8bef9SDimitry Andric }
1764e8d8bef9SDimitry Andric
1765e8d8bef9SDimitry Andric // Insert in the inner loop body, which computes
1766e8d8bef9SDimitry Andric // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1767e8d8bef9SDimitry Andric Builder.SetInsertPoint(InnerBody->getTerminator());
1768e8d8bef9SDimitry Andric // Load tiles of the operands.
1769972a253aSDimitry Andric MatrixTy A =
1770972a253aSDimitry Andric loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1771e8d8bef9SDimitry Andric {TileSize, TileSize}, EltType, Builder);
1772972a253aSDimitry Andric MatrixTy B =
1773972a253aSDimitry Andric loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1774e8d8bef9SDimitry Andric {TileSize, TileSize}, EltType, Builder);
1775fe6060f1SDimitry Andric emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1776fe6060f1SDimitry Andric getFastMathFlags(MatMul));
1777e8d8bef9SDimitry Andric // Store result after the inner loop is done.
1778972a253aSDimitry Andric Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1779e8d8bef9SDimitry Andric storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1780e8d8bef9SDimitry Andric Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1781972a253aSDimitry Andric TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1782e8d8bef9SDimitry Andric
1783e8d8bef9SDimitry Andric for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1784972a253aSDimitry Andric ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1785e8d8bef9SDimitry Andric
1786e8d8bef9SDimitry Andric // Force unrolling of a few iterations of the inner loop, to make sure there
1787e8d8bef9SDimitry Andric // is enough work per iteration.
1788e8d8bef9SDimitry Andric // FIXME: The unroller should make this decision directly instead, but
1789e8d8bef9SDimitry Andric // currently the cost-model is not up to the task.
1790e8d8bef9SDimitry Andric unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1791972a253aSDimitry Andric addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1792e8d8bef9SDimitry Andric "llvm.loop.unroll.count", InnerLoopUnrollCount);
1793e8d8bef9SDimitry Andric }
1794e8d8bef9SDimitry Andric
17955ffd83dbSDimitry Andric void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
17965ffd83dbSDimitry Andric StoreInst *Store,
17975ffd83dbSDimitry Andric SmallPtrSetImpl<Instruction *> &FusedInsts) {
17985ffd83dbSDimitry Andric assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
17995ffd83dbSDimitry Andric "Tiling only supported for column-major matrices at the moment!");
18005ffd83dbSDimitry Andric if (!isFusionProfitable(MatMul))
18015ffd83dbSDimitry Andric return;
18025ffd83dbSDimitry Andric
18035ffd83dbSDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
18045ffd83dbSDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
18055ffd83dbSDimitry Andric
18065ffd83dbSDimitry Andric const unsigned R = LShape.NumRows;
18075ffd83dbSDimitry Andric const unsigned C = RShape.NumColumns;
18085ffd83dbSDimitry Andric const unsigned M = LShape.NumColumns;
18095ffd83dbSDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
18105ffd83dbSDimitry Andric
18115ffd83dbSDimitry Andric Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
18125ffd83dbSDimitry Andric Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
18135ffd83dbSDimitry Andric Value *CPtr = Store->getPointerOperand();
18145ffd83dbSDimitry Andric
1815e8d8bef9SDimitry Andric if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1816fe6060f1SDimitry Andric createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1817e8d8bef9SDimitry Andric else {
18185ffd83dbSDimitry Andric IRBuilder<> Builder(Store);
18195ffd83dbSDimitry Andric for
(unsigned J = 0; J < C; J += TileSize) 18205ffd83dbSDimitry Andric for (unsigned I = 0; I < R; I += TileSize) { 18215ffd83dbSDimitry Andric const unsigned TileR = std::min(R - I, unsigned(TileSize)); 18225ffd83dbSDimitry Andric const unsigned TileC = std::min(C - J, unsigned(TileSize)); 18235ffd83dbSDimitry Andric MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 18245ffd83dbSDimitry Andric 18255ffd83dbSDimitry Andric for (unsigned K = 0; K < M; K += TileSize) { 18265ffd83dbSDimitry Andric const unsigned TileM = std::min(M - K, unsigned(TileSize)); 18275ffd83dbSDimitry Andric MatrixTy A = 18285ffd83dbSDimitry Andric loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 18295ffd83dbSDimitry Andric LShape, Builder.getInt64(I), Builder.getInt64(K), 18305ffd83dbSDimitry Andric {TileR, TileM}, EltType, Builder); 18315ffd83dbSDimitry Andric MatrixTy B = 18325ffd83dbSDimitry Andric loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), 18335ffd83dbSDimitry Andric RShape, Builder.getInt64(K), Builder.getInt64(J), 18345ffd83dbSDimitry Andric {TileM, TileC}, EltType, Builder); 1835fe6060f1SDimitry Andric emitMatrixMultiply(Res, A, B, Builder, true, false, 1836fe6060f1SDimitry Andric getFastMathFlags(MatMul)); 18375ffd83dbSDimitry Andric } 18385ffd83dbSDimitry Andric storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 1839e8d8bef9SDimitry Andric Builder.getInt64(I), Builder.getInt64(J), EltType, 1840e8d8bef9SDimitry Andric Builder); 1841e8d8bef9SDimitry Andric } 18425ffd83dbSDimitry Andric } 18435ffd83dbSDimitry Andric 18445ffd83dbSDimitry Andric // Mark eliminated instructions as fused and remove them. 18455ffd83dbSDimitry Andric FusedInsts.insert(Store); 18465ffd83dbSDimitry Andric FusedInsts.insert(MatMul); 18475ffd83dbSDimitry Andric Store->eraseFromParent(); 18485ffd83dbSDimitry Andric MatMul->eraseFromParent(); 18495ffd83dbSDimitry Andric if (LoadOp0->hasNUses(0)) { 18505ffd83dbSDimitry Andric FusedInsts.insert(LoadOp0); 18515ffd83dbSDimitry Andric LoadOp0->eraseFromParent(); 18525ffd83dbSDimitry Andric } 1853fe6060f1SDimitry Andric if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) { 18545ffd83dbSDimitry Andric FusedInsts.insert(LoadOp1); 18555ffd83dbSDimitry Andric LoadOp1->eraseFromParent(); 18565ffd83dbSDimitry Andric } 18575ffd83dbSDimitry Andric } 18585ffd83dbSDimitry Andric 18595ffd83dbSDimitry Andric /// Try to lower matrix multiply chains by fusing operations. 18605ffd83dbSDimitry Andric /// 1861fe6060f1SDimitry Andric /// Call finalizeLowering on lowered instructions. Instructions that are 1862fe6060f1SDimitry Andric /// completely eliminated by fusion are added to \p FusedInsts. 1863*0fca6ea1SDimitry Andric void 1864*0fca6ea1SDimitry Andric LowerMatrixMultiplyFused(CallInst *MatMul, 1865*0fca6ea1SDimitry Andric SmallPtrSetImpl<Instruction *> &FusedInsts, 1866*0fca6ea1SDimitry Andric SmallVector<IntrinsicInst *, 16> &LifetimeEnds) { 1867fe6060f1SDimitry Andric if (!FuseMatrix || !DT) 18685ffd83dbSDimitry Andric return; 18695ffd83dbSDimitry Andric 1870e8d8bef9SDimitry Andric assert(AA && LI && "Analyses should be available"); 1871e8d8bef9SDimitry Andric 1872fe6060f1SDimitry Andric Value *A = MatMul->getArgOperand(0); 1873fe6060f1SDimitry Andric Value *B = MatMul->getArgOperand(1); 1874fe6060f1SDimitry Andric 1875fe6060f1SDimitry Andric // We can fold the transpose into the operand that is used to fetch scalars. 1876fe6060f1SDimitry Andric Value *T; 1877fe6060f1SDimitry Andric if (MatrixLayout == MatrixLayoutTy::ColumnMajor 1878fe6060f1SDimitry Andric ? 
match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T))) 1879fe6060f1SDimitry Andric : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) { 1880fe6060f1SDimitry Andric IRBuilder<> Builder(MatMul); 1881fe6060f1SDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1882fe6060f1SDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1883fe6060f1SDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1884fe6060f1SDimitry Andric const unsigned R = LShape.NumRows; 1885fe6060f1SDimitry Andric const unsigned M = LShape.NumColumns; 1886fe6060f1SDimitry Andric const unsigned C = RShape.NumColumns; 1887fe6060f1SDimitry Andric 1888fe6060f1SDimitry Andric MatrixTy MA; 1889fe6060f1SDimitry Andric MatrixTy MB; 1890fe6060f1SDimitry Andric 1891fe6060f1SDimitry Andric Value *Transpose; 1892fe6060f1SDimitry Andric if (MatrixLayout == MatrixLayoutTy::ColumnMajor) { 1893fe6060f1SDimitry Andric MA = getMatrix(A, ShapeInfo(R, M), Builder); 1894fe6060f1SDimitry Andric MB = getMatrix(T, ShapeInfo(C, M), Builder); 1895fe6060f1SDimitry Andric Transpose = B; 1896fe6060f1SDimitry Andric } else { 1897fe6060f1SDimitry Andric MA = getMatrix(T, ShapeInfo(R, M), Builder); 1898fe6060f1SDimitry Andric MB = getMatrix(B, ShapeInfo(C, M), Builder); 1899fe6060f1SDimitry Andric Transpose = A; 1900fe6060f1SDimitry Andric } 1901fe6060f1SDimitry Andric 1902fe6060f1SDimitry Andric // Initialize the output 1903fe6060f1SDimitry Andric MatrixTy Result(R, C, EltType); 1904fe6060f1SDimitry Andric 1905fe6060f1SDimitry Andric emitMatrixMultiply(Result, MA, MB, Builder, false, true, 1906fe6060f1SDimitry Andric getFastMathFlags(MatMul)); 1907fe6060f1SDimitry Andric 1908fe6060f1SDimitry Andric FusedInsts.insert(MatMul); 1909fe6060f1SDimitry Andric if (Transpose->hasOneUse()) { 1910fe6060f1SDimitry Andric FusedInsts.insert(cast<Instruction>(Transpose)); 1911fe6060f1SDimitry Andric ToRemove.push_back(cast<Instruction>(Transpose)); 1912fe6060f1SDimitry Andric // TODO: add a fake entry for the folded instruction so that this is 1913fe6060f1SDimitry Andric // included in the expression in the remark. 1914fe6060f1SDimitry Andric Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType); 1915fe6060f1SDimitry Andric } 1916fe6060f1SDimitry Andric finalizeLowering(MatMul, Result, Builder); 1917fe6060f1SDimitry Andric return; 1918fe6060f1SDimitry Andric } 1919fe6060f1SDimitry Andric 1920fe6060f1SDimitry Andric if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor) 1921fe6060f1SDimitry Andric return; 1922fe6060f1SDimitry Andric 1923fe6060f1SDimitry Andric // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering 1924fe6060f1SDimitry Andric // since the single store user will be lowered as part of this. 1925fe6060f1SDimitry Andric auto *LoadOp0 = dyn_cast<LoadInst>(A); 1926fe6060f1SDimitry Andric auto *LoadOp1 = dyn_cast<LoadInst>(B); 19275ffd83dbSDimitry Andric auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); 19285ffd83dbSDimitry Andric if (LoadOp0 && LoadOp1 && Store) { 19295ffd83dbSDimitry Andric // The store address must dominate the MatMul instruction, otherwise 19305ffd83dbSDimitry Andric // we create invalid IR. 
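// [Editor's note: illustrative sketch, not part of the original source;
// names and sizes below are made up.] If the destination address is
// computed between the multiply and the store, e.g.
//
//   %r = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
//            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
//   %dst = getelementptr double, ptr %base, i64 4
//   store <4 x double> %r, ptr %dst
//
// then the computation of %dst must be hoisted above the multiply before
// the fused loads and stores can be emitted at the multiply's position.
// The worklist walk below collects such address computations; PHIs and
// memory-touching instructions are never hoisted and abort fusion instead.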
1931fe6060f1SDimitry Andric SetVector<Value *> WorkList; 1932fe6060f1SDimitry Andric WorkList.insert(Store->getOperand(1)); 1933fe6060f1SDimitry Andric SmallVector<Instruction *> ToHoist; 1934fe6060f1SDimitry Andric for (unsigned I = 0; I != WorkList.size(); ++I) { 1935fe6060f1SDimitry Andric Value *Current = WorkList[I]; 1936fe6060f1SDimitry Andric auto *CurrI = dyn_cast<Instruction>(Current); 1937fe6060f1SDimitry Andric if (!CurrI) 1938fe6060f1SDimitry Andric continue; 1939fe6060f1SDimitry Andric if (isa<PHINode>(CurrI)) 19405ffd83dbSDimitry Andric return; 1941fe6060f1SDimitry Andric if (DT->dominates(CurrI, MatMul)) 1942fe6060f1SDimitry Andric continue; 1943fe6060f1SDimitry Andric if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory()) 1944fe6060f1SDimitry Andric return; 1945fe6060f1SDimitry Andric ToHoist.push_back(CurrI); 1946fe6060f1SDimitry Andric WorkList.insert(CurrI->op_begin(), CurrI->op_end()); 1947fe6060f1SDimitry Andric } 1948fe6060f1SDimitry Andric 1949fe6060f1SDimitry Andric sort(ToHoist, [this](Instruction *A, Instruction *B) { 1950fe6060f1SDimitry Andric return DT->dominates(A, B); 1951fe6060f1SDimitry Andric }); 1952fe6060f1SDimitry Andric for (Instruction *I : ToHoist) 1953fe6060f1SDimitry Andric I->moveBefore(MatMul); 19545ffd83dbSDimitry Andric 1955*0fca6ea1SDimitry Andric // Deal with lifetime.end calls that might be between Load0/Load1 and the 1956*0fca6ea1SDimitry Andric // store. To avoid introducing loads to dead objects (i.e. after the 1957*0fca6ea1SDimitry Andric // lifetime has been terminated by @llvm.lifetime.end), either sink them 1958*0fca6ea1SDimitry Andric // after the store if in the same block, or remove the lifetime.end marker 1959*0fca6ea1SDimitry Andric // otherwise. This might pessimize further optimizations, by extending the 1960*0fca6ea1SDimitry Andric // lifetime of the object until the function returns, but should be 1961*0fca6ea1SDimitry Andric // conservatively correct. 1962*0fca6ea1SDimitry Andric MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0); 1963*0fca6ea1SDimitry Andric MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1); 1964*0fca6ea1SDimitry Andric BasicBlock *StoreParent = Store->getParent(); 1965*0fca6ea1SDimitry Andric bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent && 1966*0fca6ea1SDimitry Andric LoadOp1->getParent() == StoreParent; 1967*0fca6ea1SDimitry Andric for (unsigned Idx = 0; Idx != LifetimeEnds.size();) { 1968*0fca6ea1SDimitry Andric IntrinsicInst *End = LifetimeEnds[Idx]; 1969*0fca6ea1SDimitry Andric auto Inc = make_scope_exit([&Idx]() { Idx++; }); 1970*0fca6ea1SDimitry Andric // If the lifetime.end is guaranteed to be before the loads or after the 1971*0fca6ea1SDimitry Andric // store, it won't interfere with fusion. 1972*0fca6ea1SDimitry Andric if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1)) 1973*0fca6ea1SDimitry Andric continue; 1974*0fca6ea1SDimitry Andric if (DT->dominates(Store, End)) 1975*0fca6ea1SDimitry Andric continue; 1976*0fca6ea1SDimitry Andric // If all fusable ops are in the same block and the lifetime.end is in a 1977*0fca6ea1SDimitry Andric // different block, it won't interfere with fusion. 1978*0fca6ea1SDimitry Andric if (FusableOpsInSameBlock && End->getParent() != StoreParent) 1979*0fca6ea1SDimitry Andric continue; 1980*0fca6ea1SDimitry Andric 1981*0fca6ea1SDimitry Andric // If the loads don't alias the lifetime.end, it won't interfere with 1982*0fca6ea1SDimitry Andric // fusion. 
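// [Editor's note: a sketch of the conflicting case, with made-up names;
// not part of the original source.]
//
//   %v = load <4 x double>, ptr %src
//   call void @llvm.lifetime.end.p0(i64 32, ptr %src)
//   store <4 x double> %r, ptr %dst
//
// The tiled code emitted later re-loads from %src next to the store, so a
// lifetime.end like the one above must be sunk past the store (same
// block) or dropped (different block). The aliasing check below filters
// out markers that cannot affect the re-loaded objects.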
1983*0fca6ea1SDimitry Andric MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr); 1984*0fca6ea1SDimitry Andric if (!EndLoc.Ptr) 1985*0fca6ea1SDimitry Andric continue; 1986*0fca6ea1SDimitry Andric if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc)) 1987*0fca6ea1SDimitry Andric continue; 1988*0fca6ea1SDimitry Andric 1989*0fca6ea1SDimitry Andric // If both lifetime.end and the store are in the same block, extend the 1990*0fca6ea1SDimitry Andric // lifetime until after the store, so the new lifetime covers the loads 1991*0fca6ea1SDimitry Andric // we introduce later. 1992*0fca6ea1SDimitry Andric if (End->getParent() == StoreParent) { 1993*0fca6ea1SDimitry Andric End->moveAfter(Store); 1994*0fca6ea1SDimitry Andric continue; 1995*0fca6ea1SDimitry Andric } 1996*0fca6ea1SDimitry Andric 1997*0fca6ea1SDimitry Andric // Otherwise remove the conflicting lifetime.end marker. 1998*0fca6ea1SDimitry Andric ToRemove.push_back(End); 1999*0fca6ea1SDimitry Andric std::swap(LifetimeEnds[Idx], LifetimeEnds.back()); 2000*0fca6ea1SDimitry Andric LifetimeEnds.pop_back(); 2001*0fca6ea1SDimitry Andric Inc.release(); 2002*0fca6ea1SDimitry Andric } 2003*0fca6ea1SDimitry Andric 20045ffd83dbSDimitry Andric emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); 20055ffd83dbSDimitry Andric return; 20065ffd83dbSDimitry Andric } 20075ffd83dbSDimitry Andric } 20085ffd83dbSDimitry Andric 2009480093f4SDimitry Andric /// Lowers llvm.matrix.multiply. 2010480093f4SDimitry Andric void LowerMultiply(CallInst *MatMul) { 2011480093f4SDimitry Andric IRBuilder<> Builder(MatMul); 2012480093f4SDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 2013480093f4SDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 2014480093f4SDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 2015480093f4SDimitry Andric 20165ffd83dbSDimitry Andric const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 20175ffd83dbSDimitry Andric const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 2018e8d8bef9SDimitry Andric assert(Lhs.getElementType() == Rhs.getElementType() && 2019e8d8bef9SDimitry Andric "Matrix multiply argument element types do not match."); 2020480093f4SDimitry Andric 2021480093f4SDimitry Andric const unsigned R = LShape.NumRows; 2022480093f4SDimitry Andric const unsigned C = RShape.NumColumns; 20235ffd83dbSDimitry Andric assert(LShape.NumColumns == RShape.NumRows); 2024480093f4SDimitry Andric 2025480093f4SDimitry Andric // Initialize the output 20265ffd83dbSDimitry Andric MatrixTy Result(R, C, EltType); 2027e8d8bef9SDimitry Andric assert(Lhs.getElementType() == Result.getElementType() && 2028e8d8bef9SDimitry Andric "Matrix multiply result element type does not match arguments."); 2029480093f4SDimitry Andric 2030fe6060f1SDimitry Andric emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false, 2031fe6060f1SDimitry Andric getFastMathFlags(MatMul)); 2032480093f4SDimitry Andric finalizeLowering(MatMul, Result, Builder); 2033480093f4SDimitry Andric } 2034480093f4SDimitry Andric 2035480093f4SDimitry Andric /// Lowers llvm.matrix.transpose. 
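/// (Editor's sketch, not from the original source: for a column-major 2x3
/// input, e.g.
///   %t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %m,
///                                                       i32 2, i32 3)
/// the loop below builds 2 result vectors of 3 elements each, moving
/// element (i, j) to (j, i) with one extractelement/insertelement pair
/// per scalar.)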
2036480093f4SDimitry Andric void LowerTranspose(CallInst *Inst) { 20375ffd83dbSDimitry Andric MatrixTy Result; 2038480093f4SDimitry Andric IRBuilder<> Builder(Inst); 2039480093f4SDimitry Andric Value *InputVal = Inst->getArgOperand(0); 2040480093f4SDimitry Andric VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 2041480093f4SDimitry Andric ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 20425ffd83dbSDimitry Andric MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 2043480093f4SDimitry Andric 20445ffd83dbSDimitry Andric const unsigned NewNumVecs = 20455ffd83dbSDimitry Andric InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; 20465ffd83dbSDimitry Andric const unsigned NewNumElts = 20475ffd83dbSDimitry Andric InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows; 2048480093f4SDimitry Andric 20495ffd83dbSDimitry Andric for (unsigned I = 0; I < NewNumVecs; ++I) { 20505ffd83dbSDimitry Andric // Build a single result vector. First initialize it. 205181ad6265SDimitry Andric Value *ResultVector = PoisonValue::get( 20525ffd83dbSDimitry Andric FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); 20535ffd83dbSDimitry Andric // Go through the old elements and insert them into the resulting vector. 20545ffd83dbSDimitry Andric for (auto J : enumerate(InputMatrix.vectors())) { 20555ffd83dbSDimitry Andric Value *Elt = Builder.CreateExtractElement(J.value(), I); 20565ffd83dbSDimitry Andric // Row and column indices are transposed. 20575ffd83dbSDimitry Andric ResultVector = 20585ffd83dbSDimitry Andric Builder.CreateInsertElement(ResultVector, Elt, J.index()); 2059480093f4SDimitry Andric } 20605ffd83dbSDimitry Andric Result.addVector(ResultVector); 2061480093f4SDimitry Andric } 2062480093f4SDimitry Andric 20635ffd83dbSDimitry Andric // TODO: Improve estimate of operations needed for transposes. Currently we 20645ffd83dbSDimitry Andric // just count the insertelement/extractelement instructions, but do not 20655ffd83dbSDimitry Andric // account for later simplifications/combines. 20665ffd83dbSDimitry Andric finalizeLowering( 20675ffd83dbSDimitry Andric Inst, 2068fe6060f1SDimitry Andric Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns) 2069fe6060f1SDimitry Andric .addNumExposedTransposes(1), 20705ffd83dbSDimitry Andric Builder); 2071480093f4SDimitry Andric } 2072480093f4SDimitry Andric 2073480093f4SDimitry Andric /// Lower load instructions, if shape information is available. 
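/// (Editor's sketch, not from the original source: given a load whose
/// shape was inferred as 2 x 3 column-major, e.g.
///   %m = load <6 x double>, ptr %p
/// LowerLoad splits it into 3 loads of <2 x double>, one per column, at
/// offsets 0, Stride and 2 * Stride elements from %p; for a compact
/// matrix the stride equals the number of rows, here 2.)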
20745ffd83dbSDimitry Andric bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { 2075480093f4SDimitry Andric auto I = ShapeMap.find(Inst); 2076480093f4SDimitry Andric if (I == ShapeMap.end()) 2077480093f4SDimitry Andric return false; 2078480093f4SDimitry Andric 20795ffd83dbSDimitry Andric LowerLoad(Inst, Ptr, Inst->getAlign(), 20805ffd83dbSDimitry Andric Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 20815ffd83dbSDimitry Andric I->second); 2082480093f4SDimitry Andric return true; 2083480093f4SDimitry Andric } 2084480093f4SDimitry Andric 20855ffd83dbSDimitry Andric bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, 2086480093f4SDimitry Andric IRBuilder<> &Builder) { 2087480093f4SDimitry Andric auto I = ShapeMap.find(StoredVal); 2088480093f4SDimitry Andric if (I == ShapeMap.end()) 2089480093f4SDimitry Andric return false; 2090480093f4SDimitry Andric 20915ffd83dbSDimitry Andric LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), 20925ffd83dbSDimitry Andric Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 20935ffd83dbSDimitry Andric I->second); 2094480093f4SDimitry Andric return true; 2095480093f4SDimitry Andric } 2096480093f4SDimitry Andric 2097480093f4SDimitry Andric /// Lower binary operators, if shape information is available. 2098480093f4SDimitry Andric bool VisitBinaryOperator(BinaryOperator *Inst) { 2099480093f4SDimitry Andric auto I = ShapeMap.find(Inst); 2100480093f4SDimitry Andric if (I == ShapeMap.end()) 2101480093f4SDimitry Andric return false; 2102480093f4SDimitry Andric 2103480093f4SDimitry Andric Value *Lhs = Inst->getOperand(0); 2104480093f4SDimitry Andric Value *Rhs = Inst->getOperand(1); 2105480093f4SDimitry Andric 2106480093f4SDimitry Andric IRBuilder<> Builder(Inst); 2107480093f4SDimitry Andric ShapeInfo &Shape = I->second; 2108480093f4SDimitry Andric 21095ffd83dbSDimitry Andric MatrixTy Result; 21105ffd83dbSDimitry Andric MatrixTy A = getMatrix(Lhs, Shape, Builder); 21115ffd83dbSDimitry Andric MatrixTy B = getMatrix(Rhs, Shape, Builder); 21125ffd83dbSDimitry Andric assert(A.isColumnMajor() == B.isColumnMajor() && 21135ffd83dbSDimitry Andric Result.isColumnMajor() == A.isColumnMajor() && 21145ffd83dbSDimitry Andric "operands must agree on matrix layout"); 2115480093f4SDimitry Andric 2116fe6060f1SDimitry Andric Builder.setFastMathFlags(getFastMathFlags(Inst)); 2117fe6060f1SDimitry Andric 21185ffd83dbSDimitry Andric // Helper to perform binary op on vectors. 
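// (Editor's note, illustrative and not from the original source: for a
// 2x2 fadd this ends up emitting one vector op per column, roughly
//   %c0 = fadd <2 x double> %a0, %b0
//   %c1 = fadd <2 x double> %a1, %b1
// where %a0/%a1 and %b0/%b1 are the split column vectors of the two
// operands.)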
21195ffd83dbSDimitry Andric auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) { 2120480093f4SDimitry Andric switch (Inst->getOpcode()) { 2121480093f4SDimitry Andric case Instruction::Add: 2122480093f4SDimitry Andric return Builder.CreateAdd(LHS, RHS); 2123480093f4SDimitry Andric case Instruction::Mul: 2124480093f4SDimitry Andric return Builder.CreateMul(LHS, RHS); 2125480093f4SDimitry Andric case Instruction::Sub: 2126480093f4SDimitry Andric return Builder.CreateSub(LHS, RHS); 2127480093f4SDimitry Andric case Instruction::FAdd: 2128480093f4SDimitry Andric return Builder.CreateFAdd(LHS, RHS); 2129480093f4SDimitry Andric case Instruction::FMul: 2130480093f4SDimitry Andric return Builder.CreateFMul(LHS, RHS); 2131480093f4SDimitry Andric case Instruction::FSub: 2132480093f4SDimitry Andric return Builder.CreateFSub(LHS, RHS); 2133480093f4SDimitry Andric default: 2134480093f4SDimitry Andric llvm_unreachable("Unsupported binary operator for matrix"); 2135480093f4SDimitry Andric } 2136480093f4SDimitry Andric }; 2137480093f4SDimitry Andric 21385ffd83dbSDimitry Andric for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 21395ffd83dbSDimitry Andric Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I))); 21405ffd83dbSDimitry Andric 21415ffd83dbSDimitry Andric finalizeLowering(Inst, 21425ffd83dbSDimitry Andric Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 21435ffd83dbSDimitry Andric Result.getNumVectors()), 21445ffd83dbSDimitry Andric Builder); 2145480093f4SDimitry Andric return true; 2146480093f4SDimitry Andric } 21475ffd83dbSDimitry Andric 2148e8d8bef9SDimitry Andric /// Lower unary operators, if shape information is available. 2149e8d8bef9SDimitry Andric bool VisitUnaryOperator(UnaryOperator *Inst) { 2150e8d8bef9SDimitry Andric auto I = ShapeMap.find(Inst); 2151e8d8bef9SDimitry Andric if (I == ShapeMap.end()) 2152e8d8bef9SDimitry Andric return false; 2153e8d8bef9SDimitry Andric 2154e8d8bef9SDimitry Andric Value *Op = Inst->getOperand(0); 2155e8d8bef9SDimitry Andric 2156e8d8bef9SDimitry Andric IRBuilder<> Builder(Inst); 2157e8d8bef9SDimitry Andric ShapeInfo &Shape = I->second; 2158e8d8bef9SDimitry Andric 2159e8d8bef9SDimitry Andric MatrixTy Result; 2160e8d8bef9SDimitry Andric MatrixTy M = getMatrix(Op, Shape, Builder); 2161e8d8bef9SDimitry Andric 2162fe6060f1SDimitry Andric Builder.setFastMathFlags(getFastMathFlags(Inst)); 2163fe6060f1SDimitry Andric 2164e8d8bef9SDimitry Andric // Helper to perform unary op on vectors. 2165e8d8bef9SDimitry Andric auto BuildVectorOp = [&Builder, Inst](Value *Op) { 2166e8d8bef9SDimitry Andric switch (Inst->getOpcode()) { 2167e8d8bef9SDimitry Andric case Instruction::FNeg: 2168e8d8bef9SDimitry Andric return Builder.CreateFNeg(Op); 2169e8d8bef9SDimitry Andric default: 2170e8d8bef9SDimitry Andric llvm_unreachable("Unsupported unary operator for matrix"); 2171e8d8bef9SDimitry Andric } 2172e8d8bef9SDimitry Andric }; 2173e8d8bef9SDimitry Andric 2174e8d8bef9SDimitry Andric for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 2175e8d8bef9SDimitry Andric Result.addVector(BuildVectorOp(M.getVector(I))); 2176e8d8bef9SDimitry Andric 2177e8d8bef9SDimitry Andric finalizeLowering(Inst, 2178e8d8bef9SDimitry Andric Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 2179e8d8bef9SDimitry Andric Result.getNumVectors()), 2180e8d8bef9SDimitry Andric Builder); 2181e8d8bef9SDimitry Andric return true; 2182e8d8bef9SDimitry Andric } 2183e8d8bef9SDimitry Andric 21845ffd83dbSDimitry Andric /// Helper to linearize a matrix expression tree into a string. 
Currently 21855ffd83dbSDimitry Andric /// matrix expressions are linearized by starting at an expression leaf and 21865ffd83dbSDimitry Andric /// linearizing bottom up. 21875ffd83dbSDimitry Andric struct ExprLinearizer { 21885ffd83dbSDimitry Andric unsigned LengthToBreak = 100; 21895ffd83dbSDimitry Andric std::string Str; 21905ffd83dbSDimitry Andric raw_string_ostream Stream; 21915ffd83dbSDimitry Andric unsigned LineLength = 0; 21925ffd83dbSDimitry Andric const DataLayout &DL; 21935ffd83dbSDimitry Andric 21945ffd83dbSDimitry Andric /// Mapping from instructions to matrixes. It is used to identify 21955ffd83dbSDimitry Andric /// matrix instructions. 21965ffd83dbSDimitry Andric const MapVector<Value *, MatrixTy> &Inst2Matrix; 21975ffd83dbSDimitry Andric 21985ffd83dbSDimitry Andric /// Mapping from values to the leaves of all expressions that the value is 21995ffd83dbSDimitry Andric /// part of. 22005ffd83dbSDimitry Andric const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; 22015ffd83dbSDimitry Andric 22025ffd83dbSDimitry Andric /// Set of matrix expressions in the scope of a given DISubprogram. 22035ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram; 22045ffd83dbSDimitry Andric 22055ffd83dbSDimitry Andric /// Leaf node of the expression to linearize. 22065ffd83dbSDimitry Andric Value *Leaf; 22075ffd83dbSDimitry Andric 22085ffd83dbSDimitry Andric /// Used to keep track of sub-expressions that get reused while linearizing 22095ffd83dbSDimitry Andric /// the expression. Re-used sub-expressions are marked as (reused). 22105ffd83dbSDimitry Andric SmallPtrSet<Value *, 8> ReusedExprs; 22115ffd83dbSDimitry Andric 22125ffd83dbSDimitry Andric ExprLinearizer(const DataLayout &DL, 22135ffd83dbSDimitry Andric const MapVector<Value *, MatrixTy> &Inst2Matrix, 22145ffd83dbSDimitry Andric const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 22155ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 22165ffd83dbSDimitry Andric Value *Leaf) 221704eeddc0SDimitry Andric : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared), 22185ffd83dbSDimitry Andric ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} 22195ffd83dbSDimitry Andric 22205ffd83dbSDimitry Andric void indent(unsigned N) { 22215ffd83dbSDimitry Andric LineLength += N; 22225ffd83dbSDimitry Andric for (unsigned i = 0; i < N; i++) 22235ffd83dbSDimitry Andric Stream << " "; 22245ffd83dbSDimitry Andric } 22255ffd83dbSDimitry Andric 22265ffd83dbSDimitry Andric void lineBreak() { 22275ffd83dbSDimitry Andric Stream << "\n"; 22285ffd83dbSDimitry Andric LineLength = 0; 22295ffd83dbSDimitry Andric } 22305ffd83dbSDimitry Andric 22315ffd83dbSDimitry Andric void maybeIndent(unsigned Indent) { 22325ffd83dbSDimitry Andric if (LineLength >= LengthToBreak) 22335ffd83dbSDimitry Andric lineBreak(); 22345ffd83dbSDimitry Andric 22355ffd83dbSDimitry Andric if (LineLength == 0) 22365ffd83dbSDimitry Andric indent(Indent); 22375ffd83dbSDimitry Andric } 22385ffd83dbSDimitry Andric 22395ffd83dbSDimitry Andric void write(StringRef S) { 22405ffd83dbSDimitry Andric LineLength += S.size(); 22415ffd83dbSDimitry Andric Stream << S; 22425ffd83dbSDimitry Andric } 22435ffd83dbSDimitry Andric 22445ffd83dbSDimitry Andric Value *getUnderlyingObjectThroughLoads(Value *V) { 22455ffd83dbSDimitry Andric if (Value *Ptr = getPointerOperand(V)) 22465ffd83dbSDimitry Andric return getUnderlyingObjectThroughLoads(Ptr); 22475ffd83dbSDimitry Andric else if (V->getType()->isPointerTy()) 2248e8d8bef9SDimitry Andric return 
getUnderlyingObject(V); 22495ffd83dbSDimitry Andric return V; 22505ffd83dbSDimitry Andric } 22515ffd83dbSDimitry Andric 22525ffd83dbSDimitry Andric /// Returns true if \p V is a matrix value in the given subprogram. 22535ffd83dbSDimitry Andric bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } 22545ffd83dbSDimitry Andric 22555f757f3fSDimitry Andric /// If \p V is a matrix value, print its shape as NumRows x NumColumns to 22565ffd83dbSDimitry Andric /// \p SS. 22575ffd83dbSDimitry Andric void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { 22585ffd83dbSDimitry Andric auto M = Inst2Matrix.find(V); 22595ffd83dbSDimitry Andric if (M == Inst2Matrix.end()) 22605ffd83dbSDimitry Andric SS << "unknown"; 22615ffd83dbSDimitry Andric else { 22625ffd83dbSDimitry Andric SS << M->second.getNumRows(); 22635ffd83dbSDimitry Andric SS << "x"; 22645ffd83dbSDimitry Andric SS << M->second.getNumColumns(); 22655ffd83dbSDimitry Andric } 22665ffd83dbSDimitry Andric } 22675ffd83dbSDimitry Andric 22685ffd83dbSDimitry Andric /// Write the called function name. Handles calls to llvm.matrix.* 22695ffd83dbSDimitry Andric /// specially: we write the name, followed by the dimensions of the input 22705ffd83dbSDimitry Andric /// matrixes, followed by the scalar type name. 22715ffd83dbSDimitry Andric void writeFnName(CallInst *CI) { 22725ffd83dbSDimitry Andric if (!CI->getCalledFunction()) 22735ffd83dbSDimitry Andric write("<no called fn>"); 22745ffd83dbSDimitry Andric else { 22755ffd83dbSDimitry Andric StringRef Name = CI->getCalledFunction()->getName(); 22765f757f3fSDimitry Andric if (!Name.starts_with("llvm.matrix")) { 22775ffd83dbSDimitry Andric write(Name); 22785ffd83dbSDimitry Andric return; 22795ffd83dbSDimitry Andric } 228004eeddc0SDimitry Andric auto *II = cast<IntrinsicInst>(CI); 2281fe6060f1SDimitry Andric write(Intrinsic::getBaseName(II->getIntrinsicID()) 22825ffd83dbSDimitry Andric .drop_front(StringRef("llvm.matrix.").size())); 22835ffd83dbSDimitry Andric write("."); 2284e8d8bef9SDimitry Andric std::string Tmp; 22855ffd83dbSDimitry Andric raw_string_ostream SS(Tmp); 22865ffd83dbSDimitry Andric 22875ffd83dbSDimitry Andric switch (II->getIntrinsicID()) { 22885ffd83dbSDimitry Andric case Intrinsic::matrix_multiply: 22895ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(0), SS); 22905ffd83dbSDimitry Andric SS << "."; 22915ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(1), SS); 22925ffd83dbSDimitry Andric SS << "." << *II->getType()->getScalarType(); 22935ffd83dbSDimitry Andric break; 22945ffd83dbSDimitry Andric case Intrinsic::matrix_transpose: 22955ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(0), SS); 22965ffd83dbSDimitry Andric SS << "." << *II->getType()->getScalarType(); 22975ffd83dbSDimitry Andric break; 22985ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 22995ffd83dbSDimitry Andric prettyPrintMatrixType(II, SS); 23005ffd83dbSDimitry Andric SS << "." << *II->getType()->getScalarType(); 23015ffd83dbSDimitry Andric break; 23025ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 23035ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(0), SS); 23045ffd83dbSDimitry Andric SS << "." 
<< *II->getOperand(0)->getType()->getScalarType(); 23055ffd83dbSDimitry Andric break; 23065ffd83dbSDimitry Andric default: 23075ffd83dbSDimitry Andric llvm_unreachable("Unhandled case"); 23085ffd83dbSDimitry Andric } 23095ffd83dbSDimitry Andric SS.flush(); 23105ffd83dbSDimitry Andric write(Tmp); 23115ffd83dbSDimitry Andric } 23125ffd83dbSDimitry Andric } 23135ffd83dbSDimitry Andric 23145ffd83dbSDimitry Andric unsigned getNumShapeArgs(CallInst *CI) const { 23155ffd83dbSDimitry Andric if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { 23165ffd83dbSDimitry Andric switch (II->getIntrinsicID()) { 23175ffd83dbSDimitry Andric case Intrinsic::matrix_multiply: 23185ffd83dbSDimitry Andric return 3; 23195ffd83dbSDimitry Andric case Intrinsic::matrix_transpose: 23205ffd83dbSDimitry Andric return 2; 23215ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 23225ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 23235ffd83dbSDimitry Andric return 3; 23245ffd83dbSDimitry Andric default: 23255ffd83dbSDimitry Andric return 0; 23265ffd83dbSDimitry Andric } 23275ffd83dbSDimitry Andric } 23285ffd83dbSDimitry Andric return 0; 23295ffd83dbSDimitry Andric } 23305ffd83dbSDimitry Andric 23315ffd83dbSDimitry Andric /// Special printing for values: for pointers, we print whether they refer to 23325ffd83dbSDimitry Andric /// an (function) external address or a stack address; for other values we 23335ffd83dbSDimitry Andric /// either print the constant or "scalar"/"matrix". 23345ffd83dbSDimitry Andric void write(Value *V) { 23355ffd83dbSDimitry Andric V = getUnderlyingObjectThroughLoads(V); 23365ffd83dbSDimitry Andric if (V->getType()->isPointerTy()) { 23375ffd83dbSDimitry Andric if (isa<AllocaInst>(V)) { 23385ffd83dbSDimitry Andric Stream << "stack addr"; 23395ffd83dbSDimitry Andric LineLength += StringRef("stack addr").size(); 23405ffd83dbSDimitry Andric } else { 23415ffd83dbSDimitry Andric Stream << "addr"; 23425ffd83dbSDimitry Andric LineLength += StringRef("addr").size(); 23435ffd83dbSDimitry Andric } 23445ffd83dbSDimitry Andric if (!V->getName().empty()) { 23455ffd83dbSDimitry Andric Stream << " %" << V->getName() << ""; 23465ffd83dbSDimitry Andric LineLength += V->getName().size() + 2; 23475ffd83dbSDimitry Andric } 23485ffd83dbSDimitry Andric return; 23495ffd83dbSDimitry Andric } 23505ffd83dbSDimitry Andric 23515ffd83dbSDimitry Andric std::string Tmp; 23525ffd83dbSDimitry Andric raw_string_ostream TmpStream(Tmp); 23535ffd83dbSDimitry Andric 23545ffd83dbSDimitry Andric if (auto *CI = dyn_cast<ConstantInt>(V)) 23555ffd83dbSDimitry Andric TmpStream << CI->getValue(); 23565ffd83dbSDimitry Andric else if (isa<Constant>(V)) 23575ffd83dbSDimitry Andric TmpStream << "constant"; 23585ffd83dbSDimitry Andric else { 23595ffd83dbSDimitry Andric if (isMatrix(V)) 23605ffd83dbSDimitry Andric TmpStream << "matrix"; 23615ffd83dbSDimitry Andric else 23625ffd83dbSDimitry Andric TmpStream << "scalar"; 23635ffd83dbSDimitry Andric } 23645ffd83dbSDimitry Andric TmpStream.flush(); 23655ffd83dbSDimitry Andric Tmp = std::string(StringRef(Tmp).trim()); 23665ffd83dbSDimitry Andric LineLength += Tmp.size(); 23675ffd83dbSDimitry Andric Stream << Tmp; 23685ffd83dbSDimitry Andric } 23695ffd83dbSDimitry Andric 23705ffd83dbSDimitry Andric /// Linearize expression \p Expr starting at an indentation of \p Indent. 23715ffd83dbSDimitry Andric /// Expressions that are re-used multiple times are prefixed with (reused) 23725ffd83dbSDimitry Andric /// at the re-used root instruction. 
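/// (Editor's sketch of typical output, with made-up names and shapes,
/// not from the original source:
///   store(
///    multiply.2x6.6x2.double(addr %A, addr %B),
///    addr %C)
/// Matrix intrinsics are printed with their input shapes and scalar type;
/// plain instructions print their opcode name.)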
23735ffd83dbSDimitry Andric void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused, 23745ffd83dbSDimitry Andric bool ParentShared) { 23755ffd83dbSDimitry Andric auto *I = cast<Instruction>(Expr); 23765ffd83dbSDimitry Andric maybeIndent(Indent); 23775ffd83dbSDimitry Andric SmallVector<Value *, 8> Ops; 23785ffd83dbSDimitry Andric 23795ffd83dbSDimitry Andric // Is Expr shared with other expression leaves? 23805ffd83dbSDimitry Andric bool ExprShared = false; 23815ffd83dbSDimitry Andric 23825ffd83dbSDimitry Andric // Deal with shared subtrees. Mark them as shared, if required. 23835ffd83dbSDimitry Andric if (!ParentShared) { 23845ffd83dbSDimitry Andric auto SI = Shared.find(Expr); 23855ffd83dbSDimitry Andric assert(SI != Shared.end() && SI->second.count(Leaf)); 23865ffd83dbSDimitry Andric 23875ffd83dbSDimitry Andric for (Value *S : SI->second) { 23885ffd83dbSDimitry Andric if (S == Leaf) 23895ffd83dbSDimitry Andric continue; 23905ffd83dbSDimitry Andric DebugLoc DL = cast<Instruction>(S)->getDebugLoc(); 23915ffd83dbSDimitry Andric write("shared with remark at line " + std::to_string(DL.getLine()) + 23925ffd83dbSDimitry Andric " column " + std::to_string(DL.getCol()) + " ("); 23935ffd83dbSDimitry Andric } 23945ffd83dbSDimitry Andric ExprShared = SI->second.size() > 1; 23955ffd83dbSDimitry Andric } 23965ffd83dbSDimitry Andric 23975ffd83dbSDimitry Andric bool Reused = !ReusedExprs.insert(Expr).second; 23985ffd83dbSDimitry Andric if (Reused && !ParentReused) 23995ffd83dbSDimitry Andric write("(reused) "); 24005ffd83dbSDimitry Andric 24015ffd83dbSDimitry Andric if (auto *CI = dyn_cast<CallInst>(I)) { 24025ffd83dbSDimitry Andric writeFnName(CI); 24035ffd83dbSDimitry Andric 24045ffd83dbSDimitry Andric Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI)); 24055ffd83dbSDimitry Andric } else if (isa<BitCastInst>(Expr)) { 24065ffd83dbSDimitry Andric // Special case bitcasts, which are used to materialize matrixes from 24075ffd83dbSDimitry Andric // non-matrix ops. 
24085ffd83dbSDimitry Andric write("matrix"); 24095ffd83dbSDimitry Andric return; 24105ffd83dbSDimitry Andric } else { 24115ffd83dbSDimitry Andric Ops.append(I->value_op_begin(), I->value_op_end()); 24125ffd83dbSDimitry Andric write(std::string(I->getOpcodeName())); 24135ffd83dbSDimitry Andric } 24145ffd83dbSDimitry Andric 24155ffd83dbSDimitry Andric write(std::string("(")); 24165ffd83dbSDimitry Andric 24175ffd83dbSDimitry Andric unsigned NumOpsToBreak = 1; 24185ffd83dbSDimitry Andric if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>())) 24195ffd83dbSDimitry Andric NumOpsToBreak = 2; 24205ffd83dbSDimitry Andric 24215ffd83dbSDimitry Andric for (Value *Op : Ops) { 24225ffd83dbSDimitry Andric if (Ops.size() > NumOpsToBreak) 24235ffd83dbSDimitry Andric lineBreak(); 24245ffd83dbSDimitry Andric 24255ffd83dbSDimitry Andric maybeIndent(Indent + 1); 24265ffd83dbSDimitry Andric if (isMatrix(Op)) 24275ffd83dbSDimitry Andric linearizeExpr(Op, Indent + 1, Reused, ExprShared); 24285ffd83dbSDimitry Andric else 24295ffd83dbSDimitry Andric write(Op); 24305ffd83dbSDimitry Andric if (Op != Ops.back()) 24315ffd83dbSDimitry Andric write(", "); 24325ffd83dbSDimitry Andric } 24335ffd83dbSDimitry Andric 24345ffd83dbSDimitry Andric write(")"); 24355ffd83dbSDimitry Andric } 24365ffd83dbSDimitry Andric 24375ffd83dbSDimitry Andric const std::string &getResult() { 24385ffd83dbSDimitry Andric Stream.flush(); 24395ffd83dbSDimitry Andric return Str; 24405ffd83dbSDimitry Andric } 24415ffd83dbSDimitry Andric }; 24425ffd83dbSDimitry Andric 24435ffd83dbSDimitry Andric /// Generate remarks for matrix operations in a function. To generate remarks 24445ffd83dbSDimitry Andric /// for matrix expressions, the following approach is used: 24455ffd83dbSDimitry Andric /// 1. Use the inlined-at debug information to group matrix operations to the 24465ffd83dbSDimitry Andric /// DISubprograms they are contained in. 24475ffd83dbSDimitry Andric /// 2. Collect leaves of matrix expressions (done in 24485ffd83dbSDimitry Andric /// RemarkGenerator::getExpressionLeaves) for each subprogram-expression 24495ffd83dbSDimitry Andric /// mapping. Leaves are lowered matrix instructions without other matrix 24505ffd83dbSDimitry Andric /// users (like stores) in the current subprogram. 24515ffd83dbSDimitry Andric /// 3. For each leaf, create a remark containing a linearized version of the 24525ffd83dbSDimitry Andric /// matrix expression. The expression is linearized by a recursive 24535ffd83dbSDimitry Andric /// bottom-up traversal of the matrix operands, starting at a leaf. Note 24545ffd83dbSDimitry Andric /// that multiple leaves can share sub-expressions. Shared subexpressions 24555ffd83dbSDimitry Andric /// are explicitly marked as shared(). 24565ffd83dbSDimitry Andric struct RemarkGenerator { 24575ffd83dbSDimitry Andric const MapVector<Value *, MatrixTy> &Inst2Matrix; 24585ffd83dbSDimitry Andric OptimizationRemarkEmitter &ORE; 24595ffd83dbSDimitry Andric Function &Func; 24605ffd83dbSDimitry Andric const DataLayout &DL; 24615ffd83dbSDimitry Andric 24625ffd83dbSDimitry Andric RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix, 24635ffd83dbSDimitry Andric OptimizationRemarkEmitter &ORE, Function &Func) 24645ffd83dbSDimitry Andric : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func), 2465*0fca6ea1SDimitry Andric DL(Func.getDataLayout()) {} 24665ffd83dbSDimitry Andric 24675ffd83dbSDimitry Andric /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
Those are 24685ffd83dbSDimitry Andric /// instructions in Inst2Matrix returning void or without any users in 24695ffd83dbSDimitry Andric /// \p ExprsInSubprogram. Currently that should only include stores. 24705ffd83dbSDimitry Andric SmallVector<Value *, 4> 24715ffd83dbSDimitry Andric getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) { 24725ffd83dbSDimitry Andric SmallVector<Value *, 4> Leaves; 24735ffd83dbSDimitry Andric for (auto *Expr : ExprsInSubprogram) 24745ffd83dbSDimitry Andric if (Expr->getType()->isVoidTy() || 24755ffd83dbSDimitry Andric !any_of(Expr->users(), [&ExprsInSubprogram](User *U) { 24765ffd83dbSDimitry Andric return ExprsInSubprogram.count(U); 24775ffd83dbSDimitry Andric })) 24785ffd83dbSDimitry Andric Leaves.push_back(Expr); 24795ffd83dbSDimitry Andric return Leaves; 24805ffd83dbSDimitry Andric } 24815ffd83dbSDimitry Andric 24825ffd83dbSDimitry Andric /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf 24835ffd83dbSDimitry Andric /// to all visited expressions in \p Shared. Limit the matrix operations to 24845ffd83dbSDimitry Andric /// the ones in \p ExprsInSubprogram. 24855ffd83dbSDimitry Andric void collectSharedInfo(Value *Leaf, Value *V, 24865ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 24875ffd83dbSDimitry Andric DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) { 24885ffd83dbSDimitry Andric 24895ffd83dbSDimitry Andric if (!ExprsInSubprogram.count(V)) 24905ffd83dbSDimitry Andric return; 24915ffd83dbSDimitry Andric 24925ffd83dbSDimitry Andric auto I = Shared.insert({V, {}}); 24935ffd83dbSDimitry Andric I.first->second.insert(Leaf); 24945ffd83dbSDimitry Andric 24955ffd83dbSDimitry Andric for (Value *Op : cast<Instruction>(V)->operand_values()) 24965ffd83dbSDimitry Andric collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared); 24975ffd83dbSDimitry Andric } 24985ffd83dbSDimitry Andric 24995ffd83dbSDimitry Andric /// Calculate the number of exclusive and shared op counts for expression 25005ffd83dbSDimitry Andric /// starting at \p V. Expressions used multiple times are counted once. 25015ffd83dbSDimitry Andric /// Limit the matrix operations to the ones in \p ExprsInSubprogram. 25025ffd83dbSDimitry Andric std::pair<OpInfoTy, OpInfoTy> 25035ffd83dbSDimitry Andric sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs, 25045ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 25055ffd83dbSDimitry Andric DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const { 25065ffd83dbSDimitry Andric if (!ExprsInSubprogram.count(Root)) 25075ffd83dbSDimitry Andric return {}; 25085ffd83dbSDimitry Andric 25095ffd83dbSDimitry Andric // Already counted this expression. Stop. 
25105ffd83dbSDimitry Andric if (!ReusedExprs.insert(Root).second) 25115ffd83dbSDimitry Andric return {}; 25125ffd83dbSDimitry Andric 25135ffd83dbSDimitry Andric OpInfoTy SharedCount; 25145ffd83dbSDimitry Andric OpInfoTy Count; 25155ffd83dbSDimitry Andric 25165ffd83dbSDimitry Andric auto I = Shared.find(Root); 25175ffd83dbSDimitry Andric auto CM = Inst2Matrix.find(Root); 25185ffd83dbSDimitry Andric if (I->second.size() == 1) 25195ffd83dbSDimitry Andric Count = CM->second.getOpInfo(); 25205ffd83dbSDimitry Andric else 25215ffd83dbSDimitry Andric SharedCount = CM->second.getOpInfo(); 25225ffd83dbSDimitry Andric 25235ffd83dbSDimitry Andric for (Value *Op : cast<Instruction>(Root)->operand_values()) { 25245ffd83dbSDimitry Andric auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared); 25255ffd83dbSDimitry Andric Count += C.first; 25265ffd83dbSDimitry Andric SharedCount += C.second; 25275ffd83dbSDimitry Andric } 25285ffd83dbSDimitry Andric return {Count, SharedCount}; 25295ffd83dbSDimitry Andric } 25305ffd83dbSDimitry Andric 25315ffd83dbSDimitry Andric void emitRemarks() { 25325ffd83dbSDimitry Andric if (!ORE.allowExtraAnalysis(DEBUG_TYPE)) 25335ffd83dbSDimitry Andric return; 25345ffd83dbSDimitry Andric 25355ffd83dbSDimitry Andric // Map matrix operations to their containing subprograms, by traversing 25365ffd83dbSDimitry Andric // the inlinedAt chain. If the function does not have a DISubprogram, we 25375ffd83dbSDimitry Andric // only map them to the containing function. 25385ffd83dbSDimitry Andric MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs; 2539bdd1243dSDimitry Andric for (const auto &KV : Inst2Matrix) { 25405ffd83dbSDimitry Andric if (Func.getSubprogram()) { 25415ffd83dbSDimitry Andric auto *I = cast<Instruction>(KV.first); 25425ffd83dbSDimitry Andric DILocation *Context = I->getDebugLoc(); 25435ffd83dbSDimitry Andric while (Context) { 25445ffd83dbSDimitry Andric auto I = 25455ffd83dbSDimitry Andric Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}}); 25465ffd83dbSDimitry Andric I.first->second.push_back(KV.first); 25475ffd83dbSDimitry Andric Context = DebugLoc(Context).getInlinedAt(); 25485ffd83dbSDimitry Andric } 25495ffd83dbSDimitry Andric } else { 25505ffd83dbSDimitry Andric auto I = Subprog2Exprs.insert({nullptr, {}}); 25515ffd83dbSDimitry Andric I.first->second.push_back(KV.first); 25525ffd83dbSDimitry Andric } 25535ffd83dbSDimitry Andric } 25545ffd83dbSDimitry Andric for (auto &KV : Subprog2Exprs) { 25555ffd83dbSDimitry Andric SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(), 25565ffd83dbSDimitry Andric KV.second.end()); 25575ffd83dbSDimitry Andric auto Leaves = getExpressionLeaves(ExprsInSubprogram); 25585ffd83dbSDimitry Andric 25595ffd83dbSDimitry Andric DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared; 25605ffd83dbSDimitry Andric for (Value *Leaf : Leaves) 25615ffd83dbSDimitry Andric collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared); 25625ffd83dbSDimitry Andric 25635ffd83dbSDimitry Andric // Generate remarks for each leaf. 
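// (Editor's sketch of the emitted remark, with made-up location and
// numbers, not from the original source:
//   remark: test.c:35:7: Lowered with 6 stores, 6 loads, 24 compute ops,
//   0 exposed transposes
// followed by the linearized expression built by linearize() below.)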
25645ffd83dbSDimitry Andric for (auto *L : Leaves) { 25655ffd83dbSDimitry Andric 25665ffd83dbSDimitry Andric DebugLoc Loc = cast<Instruction>(L)->getDebugLoc(); 25675ffd83dbSDimitry Andric DILocation *Context = cast<Instruction>(L)->getDebugLoc(); 25685ffd83dbSDimitry Andric while (Context) { 25695ffd83dbSDimitry Andric if (getSubprogram(Context->getScope()) == KV.first) { 25705ffd83dbSDimitry Andric Loc = Context; 25715ffd83dbSDimitry Andric break; 25725ffd83dbSDimitry Andric } 25735ffd83dbSDimitry Andric Context = DebugLoc(Context).getInlinedAt(); 25745ffd83dbSDimitry Andric } 25755ffd83dbSDimitry Andric 25765ffd83dbSDimitry Andric SmallPtrSet<Value *, 8> ReusedExprs; 25775ffd83dbSDimitry Andric OpInfoTy Counts, SharedCounts; 25785ffd83dbSDimitry Andric std::tie(Counts, SharedCounts) = 25795ffd83dbSDimitry Andric sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared); 25805ffd83dbSDimitry Andric 25815ffd83dbSDimitry Andric OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc, 25825ffd83dbSDimitry Andric cast<Instruction>(L)->getParent()); 25835ffd83dbSDimitry Andric 25845ffd83dbSDimitry Andric Rem << "Lowered with "; 25855ffd83dbSDimitry Andric Rem << ore::NV("NumStores", Counts.NumStores) << " stores, " 25865ffd83dbSDimitry Andric << ore::NV("NumLoads", Counts.NumLoads) << " loads, " 25875ffd83dbSDimitry Andric << ore::NV("NumComputeOps", Counts.NumComputeOps) 2588fe6060f1SDimitry Andric << " compute ops, " 2589fe6060f1SDimitry Andric << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes) 2590fe6060f1SDimitry Andric << " exposed transposes"; 25915ffd83dbSDimitry Andric 25925ffd83dbSDimitry Andric if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 || 25935ffd83dbSDimitry Andric SharedCounts.NumComputeOps > 0) { 25945ffd83dbSDimitry Andric Rem << ",\nadditionally " 25955ffd83dbSDimitry Andric << ore::NV("NumStores", SharedCounts.NumStores) << " stores, " 25965ffd83dbSDimitry Andric << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, " 25975ffd83dbSDimitry Andric << ore::NV("NumFPOps", SharedCounts.NumComputeOps) 25985ffd83dbSDimitry Andric << " compute ops" 25995ffd83dbSDimitry Andric << " are shared with other expressions"; 26005ffd83dbSDimitry Andric } 26015ffd83dbSDimitry Andric 26025ffd83dbSDimitry Andric Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL)); 26035ffd83dbSDimitry Andric ORE.emit(Rem); 26045ffd83dbSDimitry Andric } 26055ffd83dbSDimitry Andric } 26065ffd83dbSDimitry Andric } 26075ffd83dbSDimitry Andric 26085ffd83dbSDimitry Andric std::string 26095ffd83dbSDimitry Andric linearize(Value *L, 26105ffd83dbSDimitry Andric const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 26115ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 26125ffd83dbSDimitry Andric const DataLayout &DL) { 26135ffd83dbSDimitry Andric ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L); 26145ffd83dbSDimitry Andric Lin.linearizeExpr(L, 0, false, false); 26155ffd83dbSDimitry Andric return Lin.getResult(); 26165ffd83dbSDimitry Andric } 26175ffd83dbSDimitry Andric }; 2618480093f4SDimitry Andric }; 2619480093f4SDimitry Andric } // namespace 2620480093f4SDimitry Andric 2621480093f4SDimitry Andric PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, 2622480093f4SDimitry Andric FunctionAnalysisManager &AM) { 2623480093f4SDimitry Andric auto &TTI = AM.getResult<TargetIRAnalysis>(F); 2624e8d8bef9SDimitry Andric OptimizationRemarkEmitter *ORE = nullptr; 2625e8d8bef9SDimitry Andric AAResults *AA = nullptr; 2626e8d8bef9SDimitry 
Andric DominatorTree *DT = nullptr; 2627e8d8bef9SDimitry Andric LoopInfo *LI = nullptr; 2628e8d8bef9SDimitry Andric 2629e8d8bef9SDimitry Andric if (!Minimal) { 2630e8d8bef9SDimitry Andric ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F); 2631e8d8bef9SDimitry Andric AA = &AM.getResult<AAManager>(F); 2632e8d8bef9SDimitry Andric DT = &AM.getResult<DominatorTreeAnalysis>(F); 2633e8d8bef9SDimitry Andric LI = &AM.getResult<LoopAnalysis>(F); 2634e8d8bef9SDimitry Andric } 26355ffd83dbSDimitry Andric 26365ffd83dbSDimitry Andric LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE); 2637480093f4SDimitry Andric if (LMT.Visit()) { 2638480093f4SDimitry Andric PreservedAnalyses PA; 2639e8d8bef9SDimitry Andric if (!Minimal) { 2640e8d8bef9SDimitry Andric PA.preserve<LoopAnalysis>(); 2641e8d8bef9SDimitry Andric PA.preserve<DominatorTreeAnalysis>(); 2642e8d8bef9SDimitry Andric } 2643480093f4SDimitry Andric return PA; 2644480093f4SDimitry Andric } 2645480093f4SDimitry Andric return PreservedAnalyses::all(); 2646480093f4SDimitry Andric } 2647480093f4SDimitry Andric 2648349cc55cSDimitry Andric void LowerMatrixIntrinsicsPass::printPipeline( 2649349cc55cSDimitry Andric raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { 2650349cc55cSDimitry Andric static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline( 2651349cc55cSDimitry Andric OS, MapClassName2PassName); 265206c3fb27SDimitry Andric OS << '<'; 2653349cc55cSDimitry Andric if (Minimal) 2654349cc55cSDimitry Andric OS << "minimal"; 265506c3fb27SDimitry Andric OS << '>'; 2656e8d8bef9SDimitry Andric } 2657
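// [Editor's note, not part of the original source: the pass can be run
// directly with the new pass manager, e.g.
//   opt -passes='lower-matrix-intrinsics' -S input.ll
// and printPipeline above round-trips the Minimal variant as
// 'lower-matrix-intrinsics<minimal>'.]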