//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//    * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//      transposed.
//    * Improve cost-modeling, e.g. choose different number of rows/columns
//      for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation(
    "matrix-propagate-shape", cl::init(true), cl::Hidden,
    cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
             "instructions."));

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

/// Helper function to either return \p Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p
// NumElements), assuming \p Stride elements between the starts of two
// consecutive vectors. \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//        0      1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |      |      |
//         v_1_0  |v_1_1 |v_1_2 |v_1_3
//         v_2_0  |v_2_1 |v_2_2 |v_2_3
//         v_3_0  {v_3_1 {v_3_2  v_3_3
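//
// In IR, the address computation for a vector with dynamic index %vec.idx and
// stride 4 on a double matrix becomes roughly the following sketch
// (illustrative only; the "vec.start", "vec.gep" and "vec.cast" names match
// the value names created below, with NumElements = 2):
//
//   %vec.start = mul i64 %vec.idx, 4
//   %vec.gep   = getelementptr double, double* %base, i64 %vec.start
//   %vec.cast  = bitcast double* %vec.gep to <2 x double>*
//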
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  // Cast elementwise vector start pointer to a pointer to a vector
  // (EltType x NumElements)*.
  auto *VecType = FixedVectorType::get(EltType, NumElements);
  Type *VecPtrType = PointerType::get(VecType, AS);
  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information (assuming column-major
///    layout). The lowering works similarly using row-major layout.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of column major operations, which
///       yields a set of column vectors containing the result matrix. Note
///       that we lower all instructions that have shape information. Besides
///       the intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
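///
/// As an illustrative sketch (not taken from a particular test), a 2x2
/// single-precision multiply
///
///   %c = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(
///            <4 x float> %a, <4 x float> %b, i32 2, i32 2, i32 2)
///
/// is lowered to fmul/fadd (or @llvm.fmuladd) operations on <2 x float>
/// column vectors split out of %a and %b.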
///
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis &AA;
  DominatorTree &DT;
  LoopInfo &LI;
  OptimizationRemarkEmitter &ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      return *this;
    }
  };

  /// Wrapper class representing a matrix as a set of vectors, either in row or
  /// column major layout. All vectors must have the same vector type.
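  ///
  /// For example (illustrative): a 2 x 3 matrix of double is held as 3
  /// vectors of <2 x double> in column-major layout, and as 2 vectors of
  /// <3 x double> in row-major layout.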
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy()
        : Vectors(),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {

      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(UndefValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }

    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      else {
        assert(Vectors.size() > 0 && "Cannot call getNumColumns without rows");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      } else
        return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }
    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }

    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      Value *Undef = UndefValue::get(Vec->getType());
      return Builder.CreateShuffleVector(
          Vec, Undef,
          createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), "block");
    }
  };

  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    bool IsColumnMajor;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }

    unsigned getStride() const {
      if (IsColumnMajor)
        return NumRows;
      return NumColumns;
    }

    unsigned getNumVectors() const {
      if (IsColumnMajor)
        return NumColumns;
      return NumRows;
    }
  };

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_column_major_store intrinsics. For those, the
  /// shape information indicates that those instructions should be lowered
  /// using shape information as well.
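  /// For example (illustrative): a <8 x double> value used as a 2 x 4 matrix
  /// would be mapped to ShapeInfo{2, 4}; in column-major layout it is then
  /// lowered as 4 column vectors of <2 x double> with stride 2.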
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction, if shape information is available. Those
  /// users need to be removed after we finish lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI,
                        OptimizationRemarkEmitter &ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
        LI(LI), ORE(ORE) {}

  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT * N.
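  ///
  /// For example (illustrative, assuming 256-bit vector registers): an
  /// operation on 4 double elements occupies 4 * 64 = 256 bits and counts as
  /// ceil(256 / 256) = 1 vector op, while 8 doubles count as 2.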
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
                     double(TTI.getRegisterBitWidth(true)));
  }

  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing matrix, if it matches the requested shape
    // information. If there is a mismatch, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, Undef, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
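  ///
  /// For example (illustrative): for
  ///   %c = call <6 x double> @llvm.matrix.multiply.v6f64.v8f64.v12f64(
  ///            <8 x double> %a, <12 x double> %b, i32 2, i32 4, i32 3)
  /// the call itself gets shape 2 x 3 ({M, K}), and the users of %c are then
  /// queued for another propagation round.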
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // New entry, set the value and insert operands
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                           m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                           m_Value(), m_Value(), m_Value(), m_Value(M),
                           m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p Worklist contains the instruction for which we already know the shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands, if their shape
    // derives from the result shape and is unknown, add it and add them to the
    // worklist.
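    //
    // For example (illustrative): for a multiply
    //   %c = @llvm.matrix.multiply(%a, %b, i32 %M, i32 %N, i32 %K)
    // with known result shape M x K, the operands receive the shapes
    // %a: M x N and %b: N x K, matching the setShapeInfo calls below.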
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.back();
      WorkList.pop_back();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                           m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V,
                       m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

  bool Visit() {
    if (EnableShapePropagation) {
      SmallVector<Instruction *, 32> WorkList;

      // Initially only the shape of matrix intrinsics is known.
      // Initialize the work list with ops carrying shape information.
      for (BasicBlock &BB : Func)
        for (Instruction &Inst : BB) {
          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
          if (!II)
            continue;

          switch (II->getIntrinsicID()) {
          case Intrinsic::matrix_multiply:
          case Intrinsic::matrix_transpose:
          case Intrinsic::matrix_column_major_load:
          case Intrinsic::matrix_column_major_store:
            WorkList.push_back(&Inst);
            break;
          default:
            break;
          }
        }
      // Propagate shapes until nothing changes any longer.
      while (!WorkList.empty()) {
        WorkList = propagateShapeForward(WorkList);
        WorkList = propagateShapeBackward(WorkList);
      }
    }

    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;

    // First, collect all instructions with shape information and candidates
    // for fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to fuse candidates.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      LowerMatrixMultiplyFused(CI, FusedInsts);
    Changed = !FusedInsts.empty();

    // Third, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1;
      Value *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func);
    RemarkGen.emitRemarks();

    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }

  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }

  /// Compute the alignment for a column/row \p Idx with \p Stride between
  /// them. The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  /// ConstantInt, reduce the initial alignment based on the byte offset. For
  /// non-ConstantInt strides, return the common alignment of the initial
  /// alignment and the element size in bytes.
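  ///
  /// For example (illustrative): for double elements, a base alignment of 16
  /// and a constant stride of 3 elements (24 bytes between column starts),
  /// column 1 begins at byte offset 24, so its alignment is
  /// commonAlignment(16, 24) = 8.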
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }

  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride
  /// between vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride,
                                     Shape.getStride(),
                                     VType->getElementType(), Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign),
          IsVolatile, "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }

  /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
  /// starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {

    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    return loadMatrix(TileTy, TilePtr, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }

  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
                 Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride,
                                IsVolatile, Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at
  /// \p MatrixPtr[I][J].
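  ///
  /// For example (illustrative): in a column-major 8 x 8 matrix, the tile
  /// element (I, J) = (2, 3) lives J * Stride + I = 3 * 8 + 2 = 26 elements
  /// past the matrix base, which is exactly the offset computed below.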
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
                                     Stride, StoreVal.getStride(),
                                     VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }

  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }

  // Set elements I..I+NumElts-1 to Block
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {

    // First, bring Block to the same size as Col
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(
        Block, Undef,
        createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }
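  /// Compute \p Sum += \p A * \p B, using integer or FP ops depending on
  /// \p UseFPOp, and add the emitted ops to \p NumComputeOps. If \p Sum is
  /// nullptr, just return the product \p A * \p B. For FP operands with
  /// \p AllowContraction set, a single @llvm.fmuladd call is emitted instead
  /// of separate fmul and fadd instructions.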
971*5ffd83dbSDimitry Andric Function *FMulAdd = Intrinsic::getDeclaration( 972480093f4SDimitry Andric Func.getParent(), Intrinsic::fmuladd, A->getType()); 973480093f4SDimitry Andric return Builder.CreateCall(FMulAdd, {A, B, Sum}); 974480093f4SDimitry Andric } 975*5ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType()); 976480093f4SDimitry Andric Value *Mul = Builder.CreateFMul(A, B); 977480093f4SDimitry Andric return Builder.CreateFAdd(Sum, Mul); 978480093f4SDimitry Andric } 979480093f4SDimitry Andric 980*5ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType()); 981480093f4SDimitry Andric Value *Mul = Builder.CreateMul(A, B); 982480093f4SDimitry Andric return Builder.CreateAdd(Sum, Mul); 983480093f4SDimitry Andric } 984480093f4SDimitry Andric 985480093f4SDimitry Andric /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 986480093f4SDimitry Andric /// users with shape information, there's nothing to do: they will use the 987480093f4SDimitry Andric /// cached value when they are lowered. For other users, \p Matrix is 988480093f4SDimitry Andric /// flattened and the uses are updated to use it. Also marks \p Inst for 989480093f4SDimitry Andric /// deletion. 990*5ffd83dbSDimitry Andric void finalizeLowering(Instruction *Inst, MatrixTy Matrix, 991480093f4SDimitry Andric IRBuilder<> &Builder) { 992480093f4SDimitry Andric Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); 993480093f4SDimitry Andric 994480093f4SDimitry Andric ToRemove.push_back(Inst); 995480093f4SDimitry Andric Value *Flattened = nullptr; 996480093f4SDimitry Andric for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) { 997480093f4SDimitry Andric Use &U = *I++; 998480093f4SDimitry Andric if (ShapeMap.find(U.getUser()) == ShapeMap.end()) { 999480093f4SDimitry Andric if (!Flattened) 1000480093f4SDimitry Andric Flattened = Matrix.embedInVector(Builder); 1001480093f4SDimitry Andric U.set(Flattened); 1002480093f4SDimitry Andric } 1003480093f4SDimitry Andric } 1004480093f4SDimitry Andric } 1005480093f4SDimitry Andric 1006*5ffd83dbSDimitry Andric /// Compute \p Result += \p A * \p B for input matrices with left-associating 1007*5ffd83dbSDimitry Andric /// addition. 1008*5ffd83dbSDimitry Andric void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1009*5ffd83dbSDimitry Andric const MatrixTy &B, bool AllowContraction, 1010*5ffd83dbSDimitry Andric IRBuilder<> &Builder, bool isTiled) { 1011*5ffd83dbSDimitry Andric const unsigned VF = std::max<unsigned>( 1012*5ffd83dbSDimitry Andric TTI.getRegisterBitWidth(true) / 1013*5ffd83dbSDimitry Andric Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), 1014*5ffd83dbSDimitry Andric 1U); 1015*5ffd83dbSDimitry Andric unsigned R = Result.getNumRows(); 1016*5ffd83dbSDimitry Andric unsigned C = Result.getNumColumns(); 1017*5ffd83dbSDimitry Andric unsigned M = A.getNumColumns(); 1018*5ffd83dbSDimitry Andric 1019*5ffd83dbSDimitry Andric bool IsFP = Result.getElementType()->isFloatingPointTy(); 1020*5ffd83dbSDimitry Andric assert(A.isColumnMajor() == B.isColumnMajor() && 1021*5ffd83dbSDimitry Andric Result.isColumnMajor() == A.isColumnMajor() && 1022*5ffd83dbSDimitry Andric "operands must agree on matrix layout"); 1023*5ffd83dbSDimitry Andric unsigned NumComputeOps = 0; 1024*5ffd83dbSDimitry Andric if (A.isColumnMajor()) { 1025*5ffd83dbSDimitry Andric // Multiply columns from the first operand with scalars from the second 1026*5ffd83dbSDimitry Andric // operand. Then move along the K axes and accumulate the columns.
With 1027*5ffd83dbSDimitry Andric // this the adds can be vectorized without reassociation. 1028*5ffd83dbSDimitry Andric for (unsigned J = 0; J < C; ++J) { 1029*5ffd83dbSDimitry Andric unsigned BlockSize = VF; 1030*5ffd83dbSDimitry Andric // If Result is zero, we don't need to accumulate in the K==0 iteration. 1031*5ffd83dbSDimitry Andric bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1032*5ffd83dbSDimitry Andric 1033*5ffd83dbSDimitry Andric for (unsigned I = 0; I < R; I += BlockSize) { 1034*5ffd83dbSDimitry Andric // Gradually lower the vectorization factor to cover the remainder. 1035*5ffd83dbSDimitry Andric while (I + BlockSize > R) 1036*5ffd83dbSDimitry Andric BlockSize /= 2; 1037*5ffd83dbSDimitry Andric 1038*5ffd83dbSDimitry Andric Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder) 1039*5ffd83dbSDimitry Andric : nullptr; 1040*5ffd83dbSDimitry Andric for (unsigned K = 0; K < M; ++K) { 1041*5ffd83dbSDimitry Andric Value *L = A.extractVector(I, K, BlockSize, Builder); 1042*5ffd83dbSDimitry Andric Value *RH = Builder.CreateExtractElement(B.getColumn(J), K); 1043*5ffd83dbSDimitry Andric Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1044*5ffd83dbSDimitry Andric Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1045*5ffd83dbSDimitry Andric Result.getElementType()->isFloatingPointTy(), 1046*5ffd83dbSDimitry Andric Builder, AllowContraction, NumComputeOps); 1047*5ffd83dbSDimitry Andric } 1048*5ffd83dbSDimitry Andric Result.setVector(J, 1049*5ffd83dbSDimitry Andric insertVector(Result.getVector(J), I, Sum, Builder)); 1050*5ffd83dbSDimitry Andric } 1051*5ffd83dbSDimitry Andric } 1052*5ffd83dbSDimitry Andric } else { 1053*5ffd83dbSDimitry Andric // Multiply rows from the second operand with scalars from the first 1054*5ffd83dbSDimitry Andric // operand. Then move along the K axes and accumulate the rows. With this 1055*5ffd83dbSDimitry Andric // the adds can be vectorized without reassociation. 1056*5ffd83dbSDimitry Andric for (unsigned I = 0; I < R; ++I) { 1057*5ffd83dbSDimitry Andric unsigned BlockSize = VF; 1058*5ffd83dbSDimitry Andric bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); 1059*5ffd83dbSDimitry Andric for (unsigned J = 0; J < C; J += BlockSize) { 1060*5ffd83dbSDimitry Andric // Gradually lower the vectorization factor to cover the remainder. 1061*5ffd83dbSDimitry Andric while (J + BlockSize > C) 1062*5ffd83dbSDimitry Andric BlockSize /= 2; 1063*5ffd83dbSDimitry Andric 1064*5ffd83dbSDimitry Andric Value *Sum = nullptr; 1065*5ffd83dbSDimitry Andric for (unsigned K = 0; K < M; ++K) { 1066*5ffd83dbSDimitry Andric Value *R = B.extractVector(K, J, BlockSize, Builder); 1067*5ffd83dbSDimitry Andric Value *LH = Builder.CreateExtractElement(A.getVector(I), K); 1068*5ffd83dbSDimitry Andric Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); 1069*5ffd83dbSDimitry Andric Sum = createMulAdd(isSumZero && K == 0 ? 
nullptr : Sum, Splat, R, 1070*5ffd83dbSDimitry Andric IsFP, Builder, AllowContraction, NumComputeOps); 1071*5ffd83dbSDimitry Andric } 1072*5ffd83dbSDimitry Andric Result.setVector(I, 1073*5ffd83dbSDimitry Andric insertVector(Result.getVector(I), J, Sum, Builder)); 1074*5ffd83dbSDimitry Andric } 1075*5ffd83dbSDimitry Andric } 1076*5ffd83dbSDimitry Andric } 1077*5ffd83dbSDimitry Andric Result.addNumComputeOps(NumComputeOps); 1078*5ffd83dbSDimitry Andric } 1079*5ffd83dbSDimitry Andric 1080*5ffd83dbSDimitry Andric /// Ensure that the memory in \p Load does not alias \p Store by potentially 1081*5ffd83dbSDimitry Andric /// copying it to a new location. The new location, or otherwise the original 1082*5ffd83dbSDimitry Andric /// one, is returned. 1083*5ffd83dbSDimitry Andric Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store, 1084*5ffd83dbSDimitry Andric CallInst *MatMul) { 1085*5ffd83dbSDimitry Andric MemoryLocation StoreLoc = MemoryLocation::get(Store); 1086*5ffd83dbSDimitry Andric MemoryLocation LoadLoc = MemoryLocation::get(Load); 1087*5ffd83dbSDimitry Andric 1088*5ffd83dbSDimitry Andric AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc); 1089*5ffd83dbSDimitry Andric 1090*5ffd83dbSDimitry Andric // If we can statically determine noalias we're good. 1091*5ffd83dbSDimitry Andric if (!LdAliased) 1092*5ffd83dbSDimitry Andric return Load->getPointerOperand(); 1093*5ffd83dbSDimitry Andric 1094*5ffd83dbSDimitry Andric // Create code to check if the memory locations of the Load and Store 1095*5ffd83dbSDimitry Andric // overlap and if they do, copy Load's operand to a new buffer. 1096*5ffd83dbSDimitry Andric 1097*5ffd83dbSDimitry Andric // First, create new blocks for the 2nd part of the check and the copy. 1098*5ffd83dbSDimitry Andric BasicBlock *Check0 = MatMul->getParent(); 1099*5ffd83dbSDimitry Andric // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a 1100*5ffd83dbSDimitry Andric // DT. Manually collect dominator tree updates, to avoid unnecessary work, 1101*5ffd83dbSDimitry Andric // as we adjust Check0 and Check1's branches. 1102*5ffd83dbSDimitry Andric SmallVector<DominatorTree::UpdateType, 4> DTUpdates; 1103*5ffd83dbSDimitry Andric for (BasicBlock *Succ : successors(Check0)) 1104*5ffd83dbSDimitry Andric DTUpdates.push_back({DT.Delete, Check0, Succ}); 1105*5ffd83dbSDimitry Andric 1106*5ffd83dbSDimitry Andric BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, 1107*5ffd83dbSDimitry Andric nullptr, "alias_cont"); 1108*5ffd83dbSDimitry Andric BasicBlock *Copy = 1109*5ffd83dbSDimitry Andric SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy"); 1110*5ffd83dbSDimitry Andric BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, 1111*5ffd83dbSDimitry Andric nullptr, "no_alias"); 1112*5ffd83dbSDimitry Andric 1113*5ffd83dbSDimitry Andric // Check if the loaded memory location begins before the end of the store 1114*5ffd83dbSDimitry Andric // location. If the condition holds, they might overlap, otherwise they are 1115*5ffd83dbSDimitry Andric // guaranteed to not overlap.
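// In other words (a sketch of the emitted predicate, using unsigned
// comparisons): the two locations can only overlap if
//   LoadBegin < StoreEnd && StoreBegin < LoadEnd.
// The first comparison is emitted at the end of Check0 below, the second
// one in Check1.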
1116*5ffd83dbSDimitry Andric IRBuilder<> Builder(MatMul); 1117*5ffd83dbSDimitry Andric Check0->getTerminator()->eraseFromParent(); 1118*5ffd83dbSDimitry Andric Builder.SetInsertPoint(Check0); 1119*5ffd83dbSDimitry Andric Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout()); 1120*5ffd83dbSDimitry Andric Value *StoreBegin = Builder.CreatePtrToInt( 1121*5ffd83dbSDimitry Andric const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); 1122*5ffd83dbSDimitry Andric Value *StoreEnd = Builder.CreateAdd( 1123*5ffd83dbSDimitry Andric StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), 1124*5ffd83dbSDimitry Andric "store.end", true, true); 1125*5ffd83dbSDimitry Andric Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), 1126*5ffd83dbSDimitry Andric IntPtrTy, "load.begin"); 1127*5ffd83dbSDimitry Andric Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, 1128*5ffd83dbSDimitry Andric Fusion); 1129*5ffd83dbSDimitry Andric 1130*5ffd83dbSDimitry Andric // Check if the store begins before the end of the load location. If the 1131*5ffd83dbSDimitry Andric // condition holds, they alias, otherwise they are guaranteed to not 1132*5ffd83dbSDimitry Andric // overlap. 1133*5ffd83dbSDimitry Andric Check1->getTerminator()->eraseFromParent(); 1134*5ffd83dbSDimitry Andric Builder.SetInsertPoint(Check1, Check1->begin()); 1135*5ffd83dbSDimitry Andric Value *LoadEnd = Builder.CreateAdd( 1136*5ffd83dbSDimitry Andric LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 1137*5ffd83dbSDimitry Andric "load.end", true, true); 1138*5ffd83dbSDimitry Andric Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 1139*5ffd83dbSDimitry Andric Fusion); 1140*5ffd83dbSDimitry Andric 1141*5ffd83dbSDimitry Andric // Copy load operand to new alloca. 1142*5ffd83dbSDimitry Andric Builder.SetInsertPoint(Copy, Copy->begin()); 1143*5ffd83dbSDimitry Andric AllocaInst *NewLd = 1144*5ffd83dbSDimitry Andric Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace()); 1145*5ffd83dbSDimitry Andric Builder.CreateMemCpy(NewLd, NewLd->getAlign(), 1146*5ffd83dbSDimitry Andric Load->getPointerOperand(), Load->getAlign(), 1147*5ffd83dbSDimitry Andric LoadLoc.Size.getValue()); 1148*5ffd83dbSDimitry Andric Builder.SetInsertPoint(Fusion, Fusion->begin()); 1149*5ffd83dbSDimitry Andric PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1150*5ffd83dbSDimitry Andric PHI->addIncoming(Load->getPointerOperand(), Check0); 1151*5ffd83dbSDimitry Andric PHI->addIncoming(Load->getPointerOperand(), Check1); 1152*5ffd83dbSDimitry Andric PHI->addIncoming(NewLd, Copy); 1153*5ffd83dbSDimitry Andric 1154*5ffd83dbSDimitry Andric // Adjust DT. 
1155*5ffd83dbSDimitry Andric DTUpdates.push_back({DT.Insert, Check0, Check1}); 1156*5ffd83dbSDimitry Andric DTUpdates.push_back({DT.Insert, Check0, Fusion}); 1157*5ffd83dbSDimitry Andric DTUpdates.push_back({DT.Insert, Check1, Copy}); 1158*5ffd83dbSDimitry Andric DTUpdates.push_back({DT.Insert, Check1, Fusion}); 1159*5ffd83dbSDimitry Andric DT.applyUpdates(DTUpdates); 1160*5ffd83dbSDimitry Andric return PHI; 1161*5ffd83dbSDimitry Andric } 1162*5ffd83dbSDimitry Andric 1163*5ffd83dbSDimitry Andric bool isFusionProfitable(CallInst *MatMul) { 1164*5ffd83dbSDimitry Andric if (ForceFusion) 1165*5ffd83dbSDimitry Andric return true; 1166*5ffd83dbSDimitry Andric 1167*5ffd83dbSDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1168*5ffd83dbSDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1169*5ffd83dbSDimitry Andric 1170*5ffd83dbSDimitry Andric const unsigned R = LShape.NumRows; 1171*5ffd83dbSDimitry Andric const unsigned C = RShape.NumColumns; 1172*5ffd83dbSDimitry Andric const unsigned M = LShape.NumColumns; 1173*5ffd83dbSDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1174*5ffd83dbSDimitry Andric 1175*5ffd83dbSDimitry Andric const unsigned VF = 1176*5ffd83dbSDimitry Andric std::max<unsigned>(TTI.getRegisterBitWidth(true) / 1177*5ffd83dbSDimitry Andric EltType->getPrimitiveSizeInBits().getFixedSize(), 1178*5ffd83dbSDimitry Andric 1U); 1179*5ffd83dbSDimitry Andric 1180*5ffd83dbSDimitry Andric // Cost model for tiling 1181*5ffd83dbSDimitry Andric // 1182*5ffd83dbSDimitry Andric // For tiling to be beneficial, we need reuse either along the R or 1183*5ffd83dbSDimitry Andric // the C axis. We vectorize along the R axis so that means at least 1184*5ffd83dbSDimitry Andric // 3 elements. 1185*5ffd83dbSDimitry Andric // TODO: Also consider cost of copying if operands alias. 1186*5ffd83dbSDimitry Andric if (R <= VF && C == 1) 1187*5ffd83dbSDimitry Andric return false; 1188*5ffd83dbSDimitry Andric // Then we need enough elements to exceed the number of vector 1189*5ffd83dbSDimitry Andric // registers we have. Note that this is an oversimplification since 1190*5ffd83dbSDimitry Andric // fusing also takes some extra loads which may exceed the number of 1191*5ffd83dbSDimitry Andric // reloads necessary. 
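// As a hypothetical example: with R = C = M = 8 and VF = 4 (four floats per
// 128-bit vector register), Op0Regs = Op1Regs = ((8 + 4 - 1) / 4) * 8 = 16,
// i.e. 32 vectors in total, which exceeds e.g. the 16 vector registers of a
// typical target, so tiling is considered profitable.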
1192*5ffd83dbSDimitry Andric unsigned Op0Regs = (R + VF - 1) / VF * M; 1193*5ffd83dbSDimitry Andric unsigned Op1Regs = (M + VF - 1) / VF * C; 1194*5ffd83dbSDimitry Andric return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true); 1195*5ffd83dbSDimitry Andric } 1196*5ffd83dbSDimitry Andric 1197*5ffd83dbSDimitry Andric MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1198*5ffd83dbSDimitry Andric MatrixTy Res; 1199*5ffd83dbSDimitry Andric auto *ColumnType = FixedVectorType::get(EltType, R); 1200*5ffd83dbSDimitry Andric for (unsigned I = 0; I < C; ++I) 1201*5ffd83dbSDimitry Andric Res.addVector(ConstantAggregateZero::get(ColumnType)); 1202*5ffd83dbSDimitry Andric return Res; 1203*5ffd83dbSDimitry Andric } 1204*5ffd83dbSDimitry Andric 1205*5ffd83dbSDimitry Andric void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 1206*5ffd83dbSDimitry Andric StoreInst *Store, 1207*5ffd83dbSDimitry Andric SmallPtrSetImpl<Instruction *> &FusedInsts) { 1208*5ffd83dbSDimitry Andric assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1209*5ffd83dbSDimitry Andric "Tiling only supported for column-major matrixes at the moment!"); 1210*5ffd83dbSDimitry Andric if (!isFusionProfitable(MatMul)) 1211*5ffd83dbSDimitry Andric return; 1212*5ffd83dbSDimitry Andric 1213*5ffd83dbSDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1214*5ffd83dbSDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1215*5ffd83dbSDimitry Andric 1216*5ffd83dbSDimitry Andric const unsigned R = LShape.NumRows; 1217*5ffd83dbSDimitry Andric const unsigned C = RShape.NumColumns; 1218*5ffd83dbSDimitry Andric const unsigned M = LShape.NumColumns; 1219*5ffd83dbSDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1220*5ffd83dbSDimitry Andric 1221*5ffd83dbSDimitry Andric Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 1222*5ffd83dbSDimitry Andric Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 1223*5ffd83dbSDimitry Andric Value *CPtr = Store->getPointerOperand(); 1224*5ffd83dbSDimitry Andric 1225*5ffd83dbSDimitry Andric bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 1226*5ffd83dbSDimitry Andric MatMul->hasAllowContract()); 1227*5ffd83dbSDimitry Andric IRBuilder<> Builder(Store); 1228*5ffd83dbSDimitry Andric for (unsigned J = 0; J < C; J += TileSize) 1229*5ffd83dbSDimitry Andric for (unsigned I = 0; I < R; I += TileSize) { 1230*5ffd83dbSDimitry Andric const unsigned TileR = std::min(R - I, unsigned(TileSize)); 1231*5ffd83dbSDimitry Andric const unsigned TileC = std::min(C - J, unsigned(TileSize)); 1232*5ffd83dbSDimitry Andric MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 1233*5ffd83dbSDimitry Andric 1234*5ffd83dbSDimitry Andric for (unsigned K = 0; K < M; K += TileSize) { 1235*5ffd83dbSDimitry Andric const unsigned TileM = std::min(M - K, unsigned(TileSize)); 1236*5ffd83dbSDimitry Andric MatrixTy A = 1237*5ffd83dbSDimitry Andric loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 1238*5ffd83dbSDimitry Andric LShape, Builder.getInt64(I), Builder.getInt64(K), 1239*5ffd83dbSDimitry Andric {TileR, TileM}, EltType, Builder); 1240*5ffd83dbSDimitry Andric MatrixTy B = 1241*5ffd83dbSDimitry Andric loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), 1242*5ffd83dbSDimitry Andric RShape, Builder.getInt64(K), Builder.getInt64(J), 1243*5ffd83dbSDimitry Andric {TileM, TileC}, EltType, Builder); 1244*5ffd83dbSDimitry Andric emitMatrixMultiply(Res, A, B,
AllowContract, Builder, true); 1245*5ffd83dbSDimitry Andric } 1246*5ffd83dbSDimitry Andric storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 1247*5ffd83dbSDimitry Andric Builder.getInt64(I), Builder.getInt64(J), EltType, Builder); 1248*5ffd83dbSDimitry Andric } 1249*5ffd83dbSDimitry Andric 1250*5ffd83dbSDimitry Andric // Mark eliminated instructions as fused and remove them. 1251*5ffd83dbSDimitry Andric FusedInsts.insert(Store); 1252*5ffd83dbSDimitry Andric FusedInsts.insert(MatMul); 1253*5ffd83dbSDimitry Andric Store->eraseFromParent(); 1254*5ffd83dbSDimitry Andric MatMul->eraseFromParent(); 1255*5ffd83dbSDimitry Andric if (LoadOp0->hasNUses(0)) { 1256*5ffd83dbSDimitry Andric FusedInsts.insert(LoadOp0); 1257*5ffd83dbSDimitry Andric LoadOp0->eraseFromParent(); 1258*5ffd83dbSDimitry Andric } 1259*5ffd83dbSDimitry Andric if (LoadOp1->hasNUses(0)) { 1260*5ffd83dbSDimitry Andric FusedInsts.insert(LoadOp1); 1261*5ffd83dbSDimitry Andric LoadOp1->eraseFromParent(); 1262*5ffd83dbSDimitry Andric } 1263*5ffd83dbSDimitry Andric } 1264*5ffd83dbSDimitry Andric 1265*5ffd83dbSDimitry Andric /// Try to lower matrix multiply chains by fusing operations. 1266*5ffd83dbSDimitry Andric /// 1267*5ffd83dbSDimitry Andric /// Currently we only lower {ld, ld} -> matmul -> st chains. 1268*5ffd83dbSDimitry Andric /// 1269*5ffd83dbSDimitry Andric /// No need to return a MatrixTy object for the result of the operation, since 1270*5ffd83dbSDimitry Andric /// the single store user will be lowered as part of this. Instructions that 1271*5ffd83dbSDimitry Andric /// are completely eliminated by fusion are added to \p FusedInsts. 1272*5ffd83dbSDimitry Andric void LowerMatrixMultiplyFused(CallInst *MatMul, 1273*5ffd83dbSDimitry Andric SmallPtrSetImpl<Instruction *> &FusedInsts) { 1274*5ffd83dbSDimitry Andric if (!FuseMatrix || !MatMul->hasOneUse() || 1275*5ffd83dbSDimitry Andric MatrixLayout != MatrixLayoutTy::ColumnMajor) 1276*5ffd83dbSDimitry Andric return; 1277*5ffd83dbSDimitry Andric 1278*5ffd83dbSDimitry Andric auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0)); 1279*5ffd83dbSDimitry Andric auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1)); 1280*5ffd83dbSDimitry Andric auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); 1281*5ffd83dbSDimitry Andric if (LoadOp0 && LoadOp1 && Store) { 1282*5ffd83dbSDimitry Andric // The store address must dominate the MatMul instruction, otherwise 1283*5ffd83dbSDimitry Andric // we create invalid IR. 1284*5ffd83dbSDimitry Andric // FIXME: See if we can hoist the store address computation. 1285*5ffd83dbSDimitry Andric auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1)); 1286*5ffd83dbSDimitry Andric if (AddrI && (!DT.dominates(AddrI, MatMul))) 1287*5ffd83dbSDimitry Andric return; 1288*5ffd83dbSDimitry Andric 1289*5ffd83dbSDimitry Andric emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); 1290*5ffd83dbSDimitry Andric return; 1291*5ffd83dbSDimitry Andric } 1292*5ffd83dbSDimitry Andric } 1293*5ffd83dbSDimitry Andric 1294480093f4SDimitry Andric /// Lowers llvm.matrix.multiply.
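///
/// As an illustration (hypothetical IR, value names invented), a 2x3 * 3x2
/// single-precision multiply reaches this point as
///   %c = call <4 x float> @llvm.matrix.multiply.v4f32.v6f32.v6f32(
///            <6 x float> %a, <6 x float> %b, i32 2, i32 3, i32 2)
/// and is expanded into vector extracts, splats and (f)mul/(f)add chains.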
1295480093f4SDimitry Andric void LowerMultiply(CallInst *MatMul) { 1296480093f4SDimitry Andric IRBuilder<> Builder(MatMul); 1297480093f4SDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1298480093f4SDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1299480093f4SDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1300480093f4SDimitry Andric 1301*5ffd83dbSDimitry Andric const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 1302*5ffd83dbSDimitry Andric const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 1303480093f4SDimitry Andric 1304480093f4SDimitry Andric const unsigned R = LShape.NumRows; 1305480093f4SDimitry Andric const unsigned C = RShape.NumColumns; 1306*5ffd83dbSDimitry Andric assert(LShape.NumColumns == RShape.NumRows); 1307480093f4SDimitry Andric 1308480093f4SDimitry Andric // Initialize the output 1309*5ffd83dbSDimitry Andric MatrixTy Result(R, C, EltType); 1310480093f4SDimitry Andric 1311480093f4SDimitry Andric bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 1312480093f4SDimitry Andric MatMul->hasAllowContract()); 1313*5ffd83dbSDimitry Andric emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false); 1314480093f4SDimitry Andric finalizeLowering(MatMul, Result, Builder); 1315480093f4SDimitry Andric } 1316480093f4SDimitry Andric 1317480093f4SDimitry Andric /// Lowers llvm.matrix.transpose. 1318480093f4SDimitry Andric void LowerTranspose(CallInst *Inst) { 1319*5ffd83dbSDimitry Andric MatrixTy Result; 1320480093f4SDimitry Andric IRBuilder<> Builder(Inst); 1321480093f4SDimitry Andric Value *InputVal = Inst->getArgOperand(0); 1322480093f4SDimitry Andric VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 1323480093f4SDimitry Andric ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 1324*5ffd83dbSDimitry Andric MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 1325480093f4SDimitry Andric 1326*5ffd83dbSDimitry Andric const unsigned NewNumVecs = 1327*5ffd83dbSDimitry Andric InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; 1328*5ffd83dbSDimitry Andric const unsigned NewNumElts = 1329*5ffd83dbSDimitry Andric InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows; 1330480093f4SDimitry Andric 1331*5ffd83dbSDimitry Andric for (unsigned I = 0; I < NewNumVecs; ++I) { 1332*5ffd83dbSDimitry Andric // Build a single result vector. First initialize it. 1333*5ffd83dbSDimitry Andric Value *ResultVector = UndefValue::get( 1334*5ffd83dbSDimitry Andric FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); 1335*5ffd83dbSDimitry Andric // Go through the old elements and insert them into the resulting vector. 1336*5ffd83dbSDimitry Andric for (auto J : enumerate(InputMatrix.vectors())) { 1337*5ffd83dbSDimitry Andric Value *Elt = Builder.CreateExtractElement(J.value(), I); 1338*5ffd83dbSDimitry Andric // Row and column indices are transposed. 1339*5ffd83dbSDimitry Andric ResultVector = 1340*5ffd83dbSDimitry Andric Builder.CreateInsertElement(ResultVector, Elt, J.index()); 1341480093f4SDimitry Andric } 1342*5ffd83dbSDimitry Andric Result.addVector(ResultVector); 1343480093f4SDimitry Andric } 1344480093f4SDimitry Andric 1345*5ffd83dbSDimitry Andric // TODO: Improve estimate of operations needed for transposes.
Currently we 1346*5ffd83dbSDimitry Andric // just count the insertelement/extractelement instructions, but do not 1347*5ffd83dbSDimitry Andric // account for later simplifications/combines. 1348*5ffd83dbSDimitry Andric finalizeLowering( 1349*5ffd83dbSDimitry Andric Inst, 1350*5ffd83dbSDimitry Andric Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns), 1351*5ffd83dbSDimitry Andric Builder); 1352480093f4SDimitry Andric } 1353480093f4SDimitry Andric 1354480093f4SDimitry Andric /// Lower load instructions, if shape information is available. 1355*5ffd83dbSDimitry Andric bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { 1356480093f4SDimitry Andric auto I = ShapeMap.find(Inst); 1357480093f4SDimitry Andric if (I == ShapeMap.end()) 1358480093f4SDimitry Andric return false; 1359480093f4SDimitry Andric 1360*5ffd83dbSDimitry Andric LowerLoad(Inst, Ptr, Inst->getAlign(), 1361*5ffd83dbSDimitry Andric Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 1362*5ffd83dbSDimitry Andric I->second); 1363480093f4SDimitry Andric return true; 1364480093f4SDimitry Andric } 1365480093f4SDimitry Andric 1366*5ffd83dbSDimitry Andric bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, 1367480093f4SDimitry Andric IRBuilder<> &Builder) { 1368480093f4SDimitry Andric auto I = ShapeMap.find(StoredVal); 1369480093f4SDimitry Andric if (I == ShapeMap.end()) 1370480093f4SDimitry Andric return false; 1371480093f4SDimitry Andric 1372*5ffd83dbSDimitry Andric LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), 1373*5ffd83dbSDimitry Andric Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 1374*5ffd83dbSDimitry Andric I->second); 1375480093f4SDimitry Andric return true; 1376480093f4SDimitry Andric } 1377480093f4SDimitry Andric 1378480093f4SDimitry Andric /// Lower binary operators, if shape information is available. 1379480093f4SDimitry Andric bool VisitBinaryOperator(BinaryOperator *Inst) { 1380480093f4SDimitry Andric auto I = ShapeMap.find(Inst); 1381480093f4SDimitry Andric if (I == ShapeMap.end()) 1382480093f4SDimitry Andric return false; 1383480093f4SDimitry Andric 1384480093f4SDimitry Andric Value *Lhs = Inst->getOperand(0); 1385480093f4SDimitry Andric Value *Rhs = Inst->getOperand(1); 1386480093f4SDimitry Andric 1387480093f4SDimitry Andric IRBuilder<> Builder(Inst); 1388480093f4SDimitry Andric ShapeInfo &Shape = I->second; 1389480093f4SDimitry Andric 1390*5ffd83dbSDimitry Andric MatrixTy Result; 1391*5ffd83dbSDimitry Andric MatrixTy A = getMatrix(Lhs, Shape, Builder); 1392*5ffd83dbSDimitry Andric MatrixTy B = getMatrix(Rhs, Shape, Builder); 1393*5ffd83dbSDimitry Andric assert(A.isColumnMajor() == B.isColumnMajor() && 1394*5ffd83dbSDimitry Andric Result.isColumnMajor() == A.isColumnMajor() && 1395*5ffd83dbSDimitry Andric "operands must agree on matrix layout"); 1396480093f4SDimitry Andric 1397*5ffd83dbSDimitry Andric // Helper to perform binary op on vectors. 
1398*5ffd83dbSDimitry Andric auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) { 1399480093f4SDimitry Andric switch (Inst->getOpcode()) { 1400480093f4SDimitry Andric case Instruction::Add: 1401480093f4SDimitry Andric return Builder.CreateAdd(LHS, RHS); 1402480093f4SDimitry Andric case Instruction::Mul: 1403480093f4SDimitry Andric return Builder.CreateMul(LHS, RHS); 1404480093f4SDimitry Andric case Instruction::Sub: 1405480093f4SDimitry Andric return Builder.CreateSub(LHS, RHS); 1406480093f4SDimitry Andric case Instruction::FAdd: 1407480093f4SDimitry Andric return Builder.CreateFAdd(LHS, RHS); 1408480093f4SDimitry Andric case Instruction::FMul: 1409480093f4SDimitry Andric return Builder.CreateFMul(LHS, RHS); 1410480093f4SDimitry Andric case Instruction::FSub: 1411480093f4SDimitry Andric return Builder.CreateFSub(LHS, RHS); 1412480093f4SDimitry Andric default: 1413480093f4SDimitry Andric llvm_unreachable("Unsupported binary operator for matrix"); 1414480093f4SDimitry Andric } 1415480093f4SDimitry Andric }; 1416480093f4SDimitry Andric 1417*5ffd83dbSDimitry Andric for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 1418*5ffd83dbSDimitry Andric Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I))); 1419*5ffd83dbSDimitry Andric 1420*5ffd83dbSDimitry Andric finalizeLowering(Inst, 1421*5ffd83dbSDimitry Andric Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 1422*5ffd83dbSDimitry Andric Result.getNumVectors()), 1423*5ffd83dbSDimitry Andric Builder); 1424480093f4SDimitry Andric return true; 1425480093f4SDimitry Andric } 1426*5ffd83dbSDimitry Andric 1427*5ffd83dbSDimitry Andric /// Helper to linearize a matrix expression tree into a string. Currently 1428*5ffd83dbSDimitry Andric /// matrix expressions are linearized by starting at an expression leaf and 1429*5ffd83dbSDimitry Andric /// linearizing bottom up. 1430*5ffd83dbSDimitry Andric struct ExprLinearizer { 1431*5ffd83dbSDimitry Andric unsigned LengthToBreak = 100; 1432*5ffd83dbSDimitry Andric std::string Str; 1433*5ffd83dbSDimitry Andric raw_string_ostream Stream; 1434*5ffd83dbSDimitry Andric unsigned LineLength = 0; 1435*5ffd83dbSDimitry Andric const DataLayout &DL; 1436*5ffd83dbSDimitry Andric 1437*5ffd83dbSDimitry Andric /// Mapping from instructions to matrixes. It is used to identify 1438*5ffd83dbSDimitry Andric /// matrix instructions. 1439*5ffd83dbSDimitry Andric const MapVector<Value *, MatrixTy> &Inst2Matrix; 1440*5ffd83dbSDimitry Andric 1441*5ffd83dbSDimitry Andric /// Mapping from values to the leaves of all expressions that the value is 1442*5ffd83dbSDimitry Andric /// part of. 1443*5ffd83dbSDimitry Andric const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; 1444*5ffd83dbSDimitry Andric 1445*5ffd83dbSDimitry Andric /// Set of matrix expressions in the scope of a given DISubprogram. 1446*5ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram; 1447*5ffd83dbSDimitry Andric 1448*5ffd83dbSDimitry Andric /// Leaf node of the expression to linearize. 1449*5ffd83dbSDimitry Andric Value *Leaf; 1450*5ffd83dbSDimitry Andric 1451*5ffd83dbSDimitry Andric /// Used to keep track of sub-expressions that get reused while linearizing 1452*5ffd83dbSDimitry Andric /// the expression. Re-used sub-expressions are marked as (reused).
1453*5ffd83dbSDimitry Andric SmallPtrSet<Value *, 8> ReusedExprs; 1454*5ffd83dbSDimitry Andric 1455*5ffd83dbSDimitry Andric ExprLinearizer(const DataLayout &DL, 1456*5ffd83dbSDimitry Andric const MapVector<Value *, MatrixTy> &Inst2Matrix, 1457*5ffd83dbSDimitry Andric const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 1458*5ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 1459*5ffd83dbSDimitry Andric Value *Leaf) 1460*5ffd83dbSDimitry Andric : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared), 1461*5ffd83dbSDimitry Andric ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} 1462*5ffd83dbSDimitry Andric 1463*5ffd83dbSDimitry Andric void indent(unsigned N) { 1464*5ffd83dbSDimitry Andric LineLength += N; 1465*5ffd83dbSDimitry Andric for (unsigned i = 0; i < N; i++) 1466*5ffd83dbSDimitry Andric Stream << " "; 1467*5ffd83dbSDimitry Andric } 1468*5ffd83dbSDimitry Andric 1469*5ffd83dbSDimitry Andric void lineBreak() { 1470*5ffd83dbSDimitry Andric Stream << "\n"; 1471*5ffd83dbSDimitry Andric LineLength = 0; 1472*5ffd83dbSDimitry Andric } 1473*5ffd83dbSDimitry Andric 1474*5ffd83dbSDimitry Andric void maybeIndent(unsigned Indent) { 1475*5ffd83dbSDimitry Andric if (LineLength >= LengthToBreak) 1476*5ffd83dbSDimitry Andric lineBreak(); 1477*5ffd83dbSDimitry Andric 1478*5ffd83dbSDimitry Andric if (LineLength == 0) 1479*5ffd83dbSDimitry Andric indent(Indent); 1480*5ffd83dbSDimitry Andric } 1481*5ffd83dbSDimitry Andric 1482*5ffd83dbSDimitry Andric void write(StringRef S) { 1483*5ffd83dbSDimitry Andric LineLength += S.size(); 1484*5ffd83dbSDimitry Andric Stream << S; 1485*5ffd83dbSDimitry Andric } 1486*5ffd83dbSDimitry Andric 1487*5ffd83dbSDimitry Andric Value *getUnderlyingObjectThroughLoads(Value *V) { 1488*5ffd83dbSDimitry Andric if (Value *Ptr = getPointerOperand(V)) 1489*5ffd83dbSDimitry Andric return getUnderlyingObjectThroughLoads(Ptr); 1490*5ffd83dbSDimitry Andric else if (V->getType()->isPointerTy()) 1491*5ffd83dbSDimitry Andric return GetUnderlyingObject(V, DL); 1492*5ffd83dbSDimitry Andric return V; 1493*5ffd83dbSDimitry Andric } 1494*5ffd83dbSDimitry Andric 1495*5ffd83dbSDimitry Andric /// Returns true if \p V is a matrix value in the given subprogram. 1496*5ffd83dbSDimitry Andric bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } 1497*5ffd83dbSDimitry Andric 1498*5ffd83dbSDimitry Andric /// If \p V is a matrix value, print its shape as NumRows x NumColumns to 1499*5ffd83dbSDimitry Andric /// \p SS. 1500*5ffd83dbSDimitry Andric void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { 1501*5ffd83dbSDimitry Andric auto M = Inst2Matrix.find(V); 1502*5ffd83dbSDimitry Andric if (M == Inst2Matrix.end()) 1503*5ffd83dbSDimitry Andric SS << "unknown"; 1504*5ffd83dbSDimitry Andric else { 1505*5ffd83dbSDimitry Andric SS << M->second.getNumRows(); 1506*5ffd83dbSDimitry Andric SS << "x"; 1507*5ffd83dbSDimitry Andric SS << M->second.getNumColumns(); 1508*5ffd83dbSDimitry Andric } 1509*5ffd83dbSDimitry Andric } 1510*5ffd83dbSDimitry Andric 1511*5ffd83dbSDimitry Andric /// Write the called function name. Handles calls to llvm.matrix.* 1512*5ffd83dbSDimitry Andric /// specially: we write the name, followed by the dimensions of the input 1513*5ffd83dbSDimitry Andric /// matrixes, followed by the scalar type name.
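/// For example, an llvm.matrix.multiply of a 2x3 and a 3x2 double matrix is
/// written as "multiply.2x3.3x2.double".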
1514*5ffd83dbSDimitry Andric void writeFnName(CallInst *CI) { 1515*5ffd83dbSDimitry Andric if (!CI->getCalledFunction()) 1516*5ffd83dbSDimitry Andric write("<no called fn>"); 1517*5ffd83dbSDimitry Andric else { 1518*5ffd83dbSDimitry Andric StringRef Name = CI->getCalledFunction()->getName(); 1519*5ffd83dbSDimitry Andric if (!Name.startswith("llvm.matrix")) { 1520*5ffd83dbSDimitry Andric write(Name); 1521*5ffd83dbSDimitry Andric return; 1522*5ffd83dbSDimitry Andric } 1523*5ffd83dbSDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI); 1524*5ffd83dbSDimitry Andric write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {})) 1525*5ffd83dbSDimitry Andric .drop_front(StringRef("llvm.matrix.").size())); 1526*5ffd83dbSDimitry Andric write("."); 1527*5ffd83dbSDimitry Andric std::string Tmp = ""; 1528*5ffd83dbSDimitry Andric raw_string_ostream SS(Tmp); 1529*5ffd83dbSDimitry Andric 1530*5ffd83dbSDimitry Andric switch (II->getIntrinsicID()) { 1531*5ffd83dbSDimitry Andric case Intrinsic::matrix_multiply: 1532*5ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(0), SS); 1533*5ffd83dbSDimitry Andric SS << "."; 1534*5ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(1), SS); 1535*5ffd83dbSDimitry Andric SS << "." << *II->getType()->getScalarType(); 1536*5ffd83dbSDimitry Andric break; 1537*5ffd83dbSDimitry Andric case Intrinsic::matrix_transpose: 1538*5ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(0), SS); 1539*5ffd83dbSDimitry Andric SS << "." << *II->getType()->getScalarType(); 1540*5ffd83dbSDimitry Andric break; 1541*5ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 1542*5ffd83dbSDimitry Andric prettyPrintMatrixType(II, SS); 1543*5ffd83dbSDimitry Andric SS << "." << *II->getType()->getScalarType(); 1544*5ffd83dbSDimitry Andric break; 1545*5ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 1546*5ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(0), SS); 1547*5ffd83dbSDimitry Andric SS << "." << *II->getOperand(0)->getType()->getScalarType(); 1548*5ffd83dbSDimitry Andric break; 1549*5ffd83dbSDimitry Andric default: 1550*5ffd83dbSDimitry Andric llvm_unreachable("Unhandled case"); 1551*5ffd83dbSDimitry Andric } 1552*5ffd83dbSDimitry Andric SS.flush(); 1553*5ffd83dbSDimitry Andric write(Tmp); 1554*5ffd83dbSDimitry Andric } 1555*5ffd83dbSDimitry Andric } 1556*5ffd83dbSDimitry Andric 1557*5ffd83dbSDimitry Andric unsigned getNumShapeArgs(CallInst *CI) const { 1558*5ffd83dbSDimitry Andric if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { 1559*5ffd83dbSDimitry Andric switch (II->getIntrinsicID()) { 1560*5ffd83dbSDimitry Andric case Intrinsic::matrix_multiply: 1561*5ffd83dbSDimitry Andric return 3; 1562*5ffd83dbSDimitry Andric case Intrinsic::matrix_transpose: 1563*5ffd83dbSDimitry Andric return 2; 1564*5ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 1565*5ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 1566*5ffd83dbSDimitry Andric return 3; 1567*5ffd83dbSDimitry Andric default: 1568*5ffd83dbSDimitry Andric return 0; 1569*5ffd83dbSDimitry Andric } 1570*5ffd83dbSDimitry Andric } 1571*5ffd83dbSDimitry Andric return 0; 1572*5ffd83dbSDimitry Andric } 1573*5ffd83dbSDimitry Andric 1574*5ffd83dbSDimitry Andric /// Special printing for values: for pointers, we print whether they refer to 1575*5ffd83dbSDimitry Andric /// an external (function) address or a stack address; for other values we 1576*5ffd83dbSDimitry Andric /// either print the constant or "scalar"/"matrix".
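/// For example, an alloca named %A is written as "stack addr %A", a constant
/// integer operand as its value (e.g. "8"), and a non-constant matrix
/// operand simply as "matrix".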
1577*5ffd83dbSDimitry Andric void write(Value *V) { 1578*5ffd83dbSDimitry Andric V = getUnderlyingObjectThroughLoads(V); 1579*5ffd83dbSDimitry Andric if (V->getType()->isPointerTy()) { 1580*5ffd83dbSDimitry Andric if (isa<AllocaInst>(V)) { 1581*5ffd83dbSDimitry Andric Stream << "stack addr"; 1582*5ffd83dbSDimitry Andric LineLength += StringRef("stack addr").size(); 1583*5ffd83dbSDimitry Andric } else { 1584*5ffd83dbSDimitry Andric Stream << "addr"; 1585*5ffd83dbSDimitry Andric LineLength += StringRef("addr").size(); 1586*5ffd83dbSDimitry Andric } 1587*5ffd83dbSDimitry Andric if (!V->getName().empty()) { 1588*5ffd83dbSDimitry Andric Stream << " %" << V->getName() << ""; 1589*5ffd83dbSDimitry Andric LineLength += V->getName().size() + 2; 1590*5ffd83dbSDimitry Andric } 1591*5ffd83dbSDimitry Andric return; 1592*5ffd83dbSDimitry Andric } 1593*5ffd83dbSDimitry Andric 1594*5ffd83dbSDimitry Andric std::string Tmp; 1595*5ffd83dbSDimitry Andric raw_string_ostream TmpStream(Tmp); 1596*5ffd83dbSDimitry Andric 1597*5ffd83dbSDimitry Andric if (auto *CI = dyn_cast<ConstantInt>(V)) 1598*5ffd83dbSDimitry Andric TmpStream << CI->getValue(); 1599*5ffd83dbSDimitry Andric else if (isa<Constant>(V)) 1600*5ffd83dbSDimitry Andric TmpStream << "constant"; 1601*5ffd83dbSDimitry Andric else { 1602*5ffd83dbSDimitry Andric if (isMatrix(V)) 1603*5ffd83dbSDimitry Andric TmpStream << "matrix"; 1604*5ffd83dbSDimitry Andric else 1605*5ffd83dbSDimitry Andric TmpStream << "scalar"; 1606*5ffd83dbSDimitry Andric } 1607*5ffd83dbSDimitry Andric TmpStream.flush(); 1608*5ffd83dbSDimitry Andric Tmp = std::string(StringRef(Tmp).trim()); 1609*5ffd83dbSDimitry Andric LineLength += Tmp.size(); 1610*5ffd83dbSDimitry Andric Stream << Tmp; 1611*5ffd83dbSDimitry Andric } 1612*5ffd83dbSDimitry Andric 1613*5ffd83dbSDimitry Andric /// Linearize expression \p Expr starting at an indentation of \p Indent. 1614*5ffd83dbSDimitry Andric /// Expressions that are re-used multiple times are prefixed with (reused) 1615*5ffd83dbSDimitry Andric /// at the re-used root instruction. 1616*5ffd83dbSDimitry Andric void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused, 1617*5ffd83dbSDimitry Andric bool ParentShared) { 1618*5ffd83dbSDimitry Andric auto *I = cast<Instruction>(Expr); 1619*5ffd83dbSDimitry Andric maybeIndent(Indent); 1620*5ffd83dbSDimitry Andric SmallVector<Value *, 8> Ops; 1621*5ffd83dbSDimitry Andric 1622*5ffd83dbSDimitry Andric // Is Expr shared with other expression leaves? 1623*5ffd83dbSDimitry Andric bool ExprShared = false; 1624*5ffd83dbSDimitry Andric 1625*5ffd83dbSDimitry Andric // Deal with shared subtrees. Mark them as shared, if required. 
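// (E.g., hypothetically, a transpose feeding two multiplies that end in
// different stores belongs to both leaf expressions; Shared then maps the
// transpose to both leaves and the linearized output annotates it.)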
1626*5ffd83dbSDimitry Andric if (!ParentShared) { 1627*5ffd83dbSDimitry Andric auto SI = Shared.find(Expr); 1628*5ffd83dbSDimitry Andric assert(SI != Shared.end() && SI->second.count(Leaf)); 1629*5ffd83dbSDimitry Andric 1630*5ffd83dbSDimitry Andric for (Value *S : SI->second) { 1631*5ffd83dbSDimitry Andric if (S == Leaf) 1632*5ffd83dbSDimitry Andric continue; 1633*5ffd83dbSDimitry Andric DebugLoc DL = cast<Instruction>(S)->getDebugLoc(); 1634*5ffd83dbSDimitry Andric write("shared with remark at line " + std::to_string(DL.getLine()) + 1635*5ffd83dbSDimitry Andric " column " + std::to_string(DL.getCol()) + " ("); 1636*5ffd83dbSDimitry Andric } 1637*5ffd83dbSDimitry Andric ExprShared = SI->second.size() > 1; 1638*5ffd83dbSDimitry Andric } 1639*5ffd83dbSDimitry Andric 1640*5ffd83dbSDimitry Andric bool Reused = !ReusedExprs.insert(Expr).second; 1641*5ffd83dbSDimitry Andric if (Reused && !ParentReused) 1642*5ffd83dbSDimitry Andric write("(reused) "); 1643*5ffd83dbSDimitry Andric 1644*5ffd83dbSDimitry Andric if (auto *CI = dyn_cast<CallInst>(I)) { 1645*5ffd83dbSDimitry Andric writeFnName(CI); 1646*5ffd83dbSDimitry Andric 1647*5ffd83dbSDimitry Andric Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI)); 1648*5ffd83dbSDimitry Andric } else if (isa<BitCastInst>(Expr)) { 1649*5ffd83dbSDimitry Andric // Special case bitcasts, which are used to materialize matrixes from 1650*5ffd83dbSDimitry Andric // non-matrix ops. 1651*5ffd83dbSDimitry Andric write("matrix"); 1652*5ffd83dbSDimitry Andric return; 1653*5ffd83dbSDimitry Andric } else { 1654*5ffd83dbSDimitry Andric Ops.append(I->value_op_begin(), I->value_op_end()); 1655*5ffd83dbSDimitry Andric write(std::string(I->getOpcodeName())); 1656*5ffd83dbSDimitry Andric } 1657*5ffd83dbSDimitry Andric 1658*5ffd83dbSDimitry Andric write(std::string("(")); 1659*5ffd83dbSDimitry Andric 1660*5ffd83dbSDimitry Andric unsigned NumOpsToBreak = 1; 1661*5ffd83dbSDimitry Andric if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>())) 1662*5ffd83dbSDimitry Andric NumOpsToBreak = 2; 1663*5ffd83dbSDimitry Andric 1664*5ffd83dbSDimitry Andric for (Value *Op : Ops) { 1665*5ffd83dbSDimitry Andric if (Ops.size() > NumOpsToBreak) 1666*5ffd83dbSDimitry Andric lineBreak(); 1667*5ffd83dbSDimitry Andric 1668*5ffd83dbSDimitry Andric maybeIndent(Indent + 1); 1669*5ffd83dbSDimitry Andric if (isMatrix(Op)) 1670*5ffd83dbSDimitry Andric linearizeExpr(Op, Indent + 1, Reused, ExprShared); 1671*5ffd83dbSDimitry Andric else 1672*5ffd83dbSDimitry Andric write(Op); 1673*5ffd83dbSDimitry Andric if (Op != Ops.back()) 1674*5ffd83dbSDimitry Andric write(", "); 1675*5ffd83dbSDimitry Andric } 1676*5ffd83dbSDimitry Andric 1677*5ffd83dbSDimitry Andric write(")"); 1678*5ffd83dbSDimitry Andric } 1679*5ffd83dbSDimitry Andric 1680*5ffd83dbSDimitry Andric const std::string &getResult() { 1681*5ffd83dbSDimitry Andric Stream.flush(); 1682*5ffd83dbSDimitry Andric return Str; 1683*5ffd83dbSDimitry Andric } 1684*5ffd83dbSDimitry Andric }; 1685*5ffd83dbSDimitry Andric 1686*5ffd83dbSDimitry Andric /// Generate remarks for matrix operations in a function. To generate remarks 1687*5ffd83dbSDimitry Andric /// for matrix expressions, the following approach is used: 1688*5ffd83dbSDimitry Andric /// 1. Use the inlined-at debug information to group matrix operations to the 1689*5ffd83dbSDimitry Andric /// DISubprograms they are contained in. 1690*5ffd83dbSDimitry Andric /// 2. 
Collect leaves of matrix expressions (done in 1691*5ffd83dbSDimitry Andric /// RemarkGenerator::getExpressionLeaves) for each subprogram-to-expression 1692*5ffd83dbSDimitry Andric /// mapping. Leaves are lowered matrix instructions without other matrix 1693*5ffd83dbSDimitry Andric /// users (like stores) in the current subprogram. 1694*5ffd83dbSDimitry Andric /// 3. For each leaf, create a remark containing a linearized version of the 1695*5ffd83dbSDimitry Andric /// matrix expression. The expression is linearized by a recursive 1696*5ffd83dbSDimitry Andric /// bottom-up traversal of the matrix operands, starting at a leaf. Note 1697*5ffd83dbSDimitry Andric /// that multiple leaves can share sub-expressions. Shared subexpressions 1698*5ffd83dbSDimitry Andric /// are explicitly marked as shared(). 1699*5ffd83dbSDimitry Andric struct RemarkGenerator { 1700*5ffd83dbSDimitry Andric const MapVector<Value *, MatrixTy> &Inst2Matrix; 1701*5ffd83dbSDimitry Andric OptimizationRemarkEmitter &ORE; 1702*5ffd83dbSDimitry Andric Function &Func; 1703*5ffd83dbSDimitry Andric const DataLayout &DL; 1704*5ffd83dbSDimitry Andric 1705*5ffd83dbSDimitry Andric RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix, 1706*5ffd83dbSDimitry Andric OptimizationRemarkEmitter &ORE, Function &Func) 1707*5ffd83dbSDimitry Andric : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func), 1708*5ffd83dbSDimitry Andric DL(Func.getParent()->getDataLayout()) {} 1709*5ffd83dbSDimitry Andric 1710*5ffd83dbSDimitry Andric /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are 1711*5ffd83dbSDimitry Andric /// instructions in Inst2Matrix returning void or without any users in 1712*5ffd83dbSDimitry Andric /// \p ExprsInSubprogram. Currently that should only include stores. 1713*5ffd83dbSDimitry Andric SmallVector<Value *, 4> 1714*5ffd83dbSDimitry Andric getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) { 1715*5ffd83dbSDimitry Andric SmallVector<Value *, 4> Leaves; 1716*5ffd83dbSDimitry Andric for (auto *Expr : ExprsInSubprogram) 1717*5ffd83dbSDimitry Andric if (Expr->getType()->isVoidTy() || 1718*5ffd83dbSDimitry Andric !any_of(Expr->users(), [&ExprsInSubprogram](User *U) { 1719*5ffd83dbSDimitry Andric return ExprsInSubprogram.count(U); 1720*5ffd83dbSDimitry Andric })) 1721*5ffd83dbSDimitry Andric Leaves.push_back(Expr); 1722*5ffd83dbSDimitry Andric return Leaves; 1723*5ffd83dbSDimitry Andric } 1724*5ffd83dbSDimitry Andric 1725*5ffd83dbSDimitry Andric /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf 1726*5ffd83dbSDimitry Andric /// to all visited expressions in \p Shared. Limit the matrix operations to 1727*5ffd83dbSDimitry Andric /// the ones in \p ExprsInSubprogram.
1728*5ffd83dbSDimitry Andric void collectSharedInfo(Value *Leaf, Value *V, 1729*5ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 1730*5ffd83dbSDimitry Andric DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) { 1731*5ffd83dbSDimitry Andric 1732*5ffd83dbSDimitry Andric if (!ExprsInSubprogram.count(V)) 1733*5ffd83dbSDimitry Andric return; 1734*5ffd83dbSDimitry Andric 1735*5ffd83dbSDimitry Andric auto I = Shared.insert({V, {}}); 1736*5ffd83dbSDimitry Andric I.first->second.insert(Leaf); 1737*5ffd83dbSDimitry Andric 1738*5ffd83dbSDimitry Andric for (Value *Op : cast<Instruction>(V)->operand_values()) 1739*5ffd83dbSDimitry Andric collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared); 1740*5ffd83dbSDimitry Andric return; 1741*5ffd83dbSDimitry Andric } 1742*5ffd83dbSDimitry Andric 1743*5ffd83dbSDimitry Andric /// Calculate the exclusive and shared op counts for the expression 1744*5ffd83dbSDimitry Andric /// starting at \p Root. Expressions used multiple times are counted once. 1745*5ffd83dbSDimitry Andric /// Limit the matrix operations to the ones in \p ExprsInSubprogram. 1746*5ffd83dbSDimitry Andric std::pair<OpInfoTy, OpInfoTy> 1747*5ffd83dbSDimitry Andric sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs, 1748*5ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 1749*5ffd83dbSDimitry Andric DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const { 1750*5ffd83dbSDimitry Andric if (!ExprsInSubprogram.count(Root)) 1751*5ffd83dbSDimitry Andric return {}; 1752*5ffd83dbSDimitry Andric 1753*5ffd83dbSDimitry Andric // Already counted this expression. Stop. 1754*5ffd83dbSDimitry Andric if (!ReusedExprs.insert(Root).second) 1755*5ffd83dbSDimitry Andric return {}; 1756*5ffd83dbSDimitry Andric 1757*5ffd83dbSDimitry Andric OpInfoTy SharedCount; 1758*5ffd83dbSDimitry Andric OpInfoTy Count; 1759*5ffd83dbSDimitry Andric 1760*5ffd83dbSDimitry Andric auto I = Shared.find(Root); 1761*5ffd83dbSDimitry Andric auto CM = Inst2Matrix.find(Root); 1762*5ffd83dbSDimitry Andric if (I->second.size() == 1) 1763*5ffd83dbSDimitry Andric Count = CM->second.getOpInfo(); 1764*5ffd83dbSDimitry Andric else 1765*5ffd83dbSDimitry Andric SharedCount = CM->second.getOpInfo(); 1766*5ffd83dbSDimitry Andric 1767*5ffd83dbSDimitry Andric for (Value *Op : cast<Instruction>(Root)->operand_values()) { 1768*5ffd83dbSDimitry Andric auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared); 1769*5ffd83dbSDimitry Andric Count += C.first; 1770*5ffd83dbSDimitry Andric SharedCount += C.second; 1771*5ffd83dbSDimitry Andric } 1772*5ffd83dbSDimitry Andric return {Count, SharedCount}; 1773*5ffd83dbSDimitry Andric } 1774*5ffd83dbSDimitry Andric 1775*5ffd83dbSDimitry Andric void emitRemarks() { 1776*5ffd83dbSDimitry Andric if (!ORE.allowExtraAnalysis(DEBUG_TYPE)) 1777*5ffd83dbSDimitry Andric return; 1778*5ffd83dbSDimitry Andric 1779*5ffd83dbSDimitry Andric // Map matrix operations to their containing subprograms, by traversing 1780*5ffd83dbSDimitry Andric // the inlinedAt chain. If the function does not have a DISubprogram, we 1781*5ffd83dbSDimitry Andric // only map them to the containing function.
1782*5ffd83dbSDimitry Andric MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs; 1783*5ffd83dbSDimitry Andric for (auto &KV : Inst2Matrix) { 1784*5ffd83dbSDimitry Andric if (Func.getSubprogram()) { 1785*5ffd83dbSDimitry Andric auto *I = cast<Instruction>(KV.first); 1786*5ffd83dbSDimitry Andric DILocation *Context = I->getDebugLoc(); 1787*5ffd83dbSDimitry Andric while (Context) { 1788*5ffd83dbSDimitry Andric auto I = 1789*5ffd83dbSDimitry Andric Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}}); 1790*5ffd83dbSDimitry Andric I.first->second.push_back(KV.first); 1791*5ffd83dbSDimitry Andric Context = DebugLoc(Context).getInlinedAt(); 1792*5ffd83dbSDimitry Andric } 1793*5ffd83dbSDimitry Andric } else { 1794*5ffd83dbSDimitry Andric auto I = Subprog2Exprs.insert({nullptr, {}}); 1795*5ffd83dbSDimitry Andric I.first->second.push_back(KV.first); 1796*5ffd83dbSDimitry Andric } 1797*5ffd83dbSDimitry Andric } 1798*5ffd83dbSDimitry Andric for (auto &KV : Subprog2Exprs) { 1799*5ffd83dbSDimitry Andric SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(), 1800*5ffd83dbSDimitry Andric KV.second.end()); 1801*5ffd83dbSDimitry Andric auto Leaves = getExpressionLeaves(ExprsInSubprogram); 1802*5ffd83dbSDimitry Andric 1803*5ffd83dbSDimitry Andric DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared; 1804*5ffd83dbSDimitry Andric for (Value *Leaf : Leaves) 1805*5ffd83dbSDimitry Andric collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared); 1806*5ffd83dbSDimitry Andric 1807*5ffd83dbSDimitry Andric // Generate remarks for each leaf. 1808*5ffd83dbSDimitry Andric for (auto *L : Leaves) { 1809*5ffd83dbSDimitry Andric 1810*5ffd83dbSDimitry Andric DebugLoc Loc = cast<Instruction>(L)->getDebugLoc(); 1811*5ffd83dbSDimitry Andric DILocation *Context = cast<Instruction>(L)->getDebugLoc(); 1812*5ffd83dbSDimitry Andric while (Context) { 1813*5ffd83dbSDimitry Andric if (getSubprogram(Context->getScope()) == KV.first) { 1814*5ffd83dbSDimitry Andric Loc = Context; 1815*5ffd83dbSDimitry Andric break; 1816*5ffd83dbSDimitry Andric } 1817*5ffd83dbSDimitry Andric Context = DebugLoc(Context).getInlinedAt(); 1818*5ffd83dbSDimitry Andric } 1819*5ffd83dbSDimitry Andric 1820*5ffd83dbSDimitry Andric SmallPtrSet<Value *, 8> ReusedExprs; 1821*5ffd83dbSDimitry Andric OpInfoTy Counts, SharedCounts; 1822*5ffd83dbSDimitry Andric std::tie(Counts, SharedCounts) = 1823*5ffd83dbSDimitry Andric sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared); 1824*5ffd83dbSDimitry Andric 1825*5ffd83dbSDimitry Andric OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc, 1826*5ffd83dbSDimitry Andric cast<Instruction>(L)->getParent()); 1827*5ffd83dbSDimitry Andric 1828*5ffd83dbSDimitry Andric Rem << "Lowered with "; 1829*5ffd83dbSDimitry Andric Rem << ore::NV("NumStores", Counts.NumStores) << " stores, " 1830*5ffd83dbSDimitry Andric << ore::NV("NumLoads", Counts.NumLoads) << " loads, " 1831*5ffd83dbSDimitry Andric << ore::NV("NumComputeOps", Counts.NumComputeOps) 1832*5ffd83dbSDimitry Andric << " compute ops"; 1833*5ffd83dbSDimitry Andric 1834*5ffd83dbSDimitry Andric if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 || 1835*5ffd83dbSDimitry Andric SharedCounts.NumComputeOps > 0) { 1836*5ffd83dbSDimitry Andric Rem << ",\nadditionally " 1837*5ffd83dbSDimitry Andric << ore::NV("NumStores", SharedCounts.NumStores) << " stores, " 1838*5ffd83dbSDimitry Andric << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, " 1839*5ffd83dbSDimitry Andric << ore::NV("NumFPOps", SharedCounts.NumComputeOps) 1840*5ffd83dbSDimitry 
Andric << " compute ops" 1841*5ffd83dbSDimitry Andric << " are shared with other expressions"; 1842*5ffd83dbSDimitry Andric } 1843*5ffd83dbSDimitry Andric 1844*5ffd83dbSDimitry Andric Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL)); 1845*5ffd83dbSDimitry Andric ORE.emit(Rem); 1846*5ffd83dbSDimitry Andric } 1847*5ffd83dbSDimitry Andric } 1848*5ffd83dbSDimitry Andric } 1849*5ffd83dbSDimitry Andric 1850*5ffd83dbSDimitry Andric std::string 1851*5ffd83dbSDimitry Andric linearize(Value *L, 1852*5ffd83dbSDimitry Andric const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 1853*5ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 1854*5ffd83dbSDimitry Andric const DataLayout &DL) { 1855*5ffd83dbSDimitry Andric ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L); 1856*5ffd83dbSDimitry Andric Lin.linearizeExpr(L, 0, false, false); 1857*5ffd83dbSDimitry Andric return Lin.getResult(); 1858*5ffd83dbSDimitry Andric } 1859*5ffd83dbSDimitry Andric }; 1860480093f4SDimitry Andric }; 1861480093f4SDimitry Andric } // namespace 1862480093f4SDimitry Andric 1863480093f4SDimitry Andric PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, 1864480093f4SDimitry Andric FunctionAnalysisManager &AM) { 1865480093f4SDimitry Andric auto &TTI = AM.getResult<TargetIRAnalysis>(F); 1866*5ffd83dbSDimitry Andric auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); 1867*5ffd83dbSDimitry Andric auto &AA = AM.getResult<AAManager>(F); 1868*5ffd83dbSDimitry Andric auto &DT = AM.getResult<DominatorTreeAnalysis>(F); 1869*5ffd83dbSDimitry Andric auto &LI = AM.getResult<LoopAnalysis>(F); 1870*5ffd83dbSDimitry Andric 1871*5ffd83dbSDimitry Andric LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE); 1872480093f4SDimitry Andric if (LMT.Visit()) { 1873480093f4SDimitry Andric PreservedAnalyses PA; 1874480093f4SDimitry Andric PA.preserveSet<CFGAnalyses>(); 1875480093f4SDimitry Andric return PA; 1876480093f4SDimitry Andric } 1877480093f4SDimitry Andric return PreservedAnalyses::all(); 1878480093f4SDimitry Andric } 1879480093f4SDimitry Andric 1880480093f4SDimitry Andric namespace { 1881480093f4SDimitry Andric 1882480093f4SDimitry Andric class LowerMatrixIntrinsicsLegacyPass : public FunctionPass { 1883480093f4SDimitry Andric public: 1884480093f4SDimitry Andric static char ID; 1885480093f4SDimitry Andric 1886480093f4SDimitry Andric LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) { 1887480093f4SDimitry Andric initializeLowerMatrixIntrinsicsLegacyPassPass( 1888480093f4SDimitry Andric *PassRegistry::getPassRegistry()); 1889480093f4SDimitry Andric } 1890480093f4SDimitry Andric 1891480093f4SDimitry Andric bool runOnFunction(Function &F) override { 1892*5ffd83dbSDimitry Andric auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 1893*5ffd83dbSDimitry Andric auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); 1894*5ffd83dbSDimitry Andric auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); 1895*5ffd83dbSDimitry Andric auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 1896*5ffd83dbSDimitry Andric auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 1897*5ffd83dbSDimitry Andric LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE); 1898480093f4SDimitry Andric bool C = LMT.Visit(); 1899480093f4SDimitry Andric return C; 1900480093f4SDimitry Andric } 1901480093f4SDimitry Andric 1902480093f4SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 1903480093f4SDimitry Andric 
AU.addRequired<TargetTransformInfoWrapperPass>(); 1904*5ffd83dbSDimitry Andric AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); 1905*5ffd83dbSDimitry Andric AU.addRequired<AAResultsWrapperPass>(); 1906*5ffd83dbSDimitry Andric AU.addRequired<DominatorTreeWrapperPass>(); 1907*5ffd83dbSDimitry Andric AU.addPreserved<DominatorTreeWrapperPass>(); 1908*5ffd83dbSDimitry Andric AU.addRequired<LoopInfoWrapperPass>(); 1909*5ffd83dbSDimitry Andric AU.addPreserved<LoopInfoWrapperPass>(); 1910480093f4SDimitry Andric } 1911480093f4SDimitry Andric }; 1912480093f4SDimitry Andric } // namespace 1913480093f4SDimitry Andric 1914480093f4SDimitry Andric static const char pass_name[] = "Lower the matrix intrinsics"; 1915480093f4SDimitry Andric char LowerMatrixIntrinsicsLegacyPass::ID = 0; 1916480093f4SDimitry Andric INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, 1917480093f4SDimitry Andric false, false) 1918*5ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) 1919*5ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) 1920*5ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 1921*5ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 1922480093f4SDimitry Andric INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, 1923480093f4SDimitry Andric false, false) 1924480093f4SDimitry Andric 1925480093f4SDimitry Andric Pass *llvm::createLowerMatrixIntrinsicsPass() { 1926480093f4SDimitry Andric return new LowerMatrixIntrinsicsLegacyPass(); 1927480093f4SDimitry Andric } 1928
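// Usage sketch (not part of the pass itself): the lowering can be exercised
// in isolation with the legacy pass manager via
//   opt -lower-matrix-intrinsics -S in.ll
// or with the new pass manager via
//   opt -passes=lower-matrix-intrinsics -S in.ll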