1480093f4SDimitry Andric //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===// 2480093f4SDimitry Andric // 3480093f4SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4480093f4SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5480093f4SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6480093f4SDimitry Andric // 7480093f4SDimitry Andric //===----------------------------------------------------------------------===// 8480093f4SDimitry Andric // 9480093f4SDimitry Andric // Lower matrix intrinsics to vector operations. 10480093f4SDimitry Andric // 11480093f4SDimitry Andric // TODO: 125ffd83dbSDimitry Andric // * Improve fusion: 135ffd83dbSDimitry Andric // * Support more cases, e.g. multiply-add, multiply-sub, operands/results 145ffd83dbSDimitry Andric // transposed. 155ffd83dbSDimitry Andric // * Improve cost-modeling, e.g. choose different number of rows/columns 165ffd83dbSDimitry Andric // for tiles, consider cost of copies on alias. 
17480093f4SDimitry Andric // 18480093f4SDimitry Andric //===----------------------------------------------------------------------===// 19480093f4SDimitry Andric 20480093f4SDimitry Andric #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" 21480093f4SDimitry Andric #include "llvm/ADT/GraphTraits.h" 22480093f4SDimitry Andric #include "llvm/ADT/PostOrderIterator.h" 23480093f4SDimitry Andric #include "llvm/ADT/SmallVector.h" 245ffd83dbSDimitry Andric #include "llvm/Analysis/AliasAnalysis.h" 255ffd83dbSDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h" 265ffd83dbSDimitry Andric #include "llvm/Analysis/OptimizationRemarkEmitter.h" 27480093f4SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h" 285ffd83dbSDimitry Andric #include "llvm/Analysis/ValueTracking.h" 29480093f4SDimitry Andric #include "llvm/Analysis/VectorUtils.h" 30480093f4SDimitry Andric #include "llvm/IR/CFG.h" 31480093f4SDimitry Andric #include "llvm/IR/DataLayout.h" 325ffd83dbSDimitry Andric #include "llvm/IR/DebugInfoMetadata.h" 33480093f4SDimitry Andric #include "llvm/IR/Function.h" 34480093f4SDimitry Andric #include "llvm/IR/IRBuilder.h" 35480093f4SDimitry Andric #include "llvm/IR/Instructions.h" 36480093f4SDimitry Andric #include "llvm/IR/IntrinsicInst.h" 37fe6060f1SDimitry Andric #include "llvm/IR/MatrixBuilder.h" 38480093f4SDimitry Andric #include "llvm/IR/PatternMatch.h" 39480093f4SDimitry Andric #include "llvm/InitializePasses.h" 40480093f4SDimitry Andric #include "llvm/Pass.h" 415ffd83dbSDimitry Andric #include "llvm/Support/Alignment.h" 425ffd83dbSDimitry Andric #include "llvm/Support/CommandLine.h" 43480093f4SDimitry Andric #include "llvm/Support/Debug.h" 44480093f4SDimitry Andric #include "llvm/Transforms/Scalar.h" 455ffd83dbSDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h" 46e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h" 47e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/MatrixUtils.h" 48480093f4SDimitry Andric 
49480093f4SDimitry Andric using namespace llvm; 50480093f4SDimitry Andric using namespace PatternMatch; 51480093f4SDimitry Andric 52480093f4SDimitry Andric #define DEBUG_TYPE "lower-matrix-intrinsics" 53480093f4SDimitry Andric 545ffd83dbSDimitry Andric static cl::opt<bool> 555ffd83dbSDimitry Andric FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, 565ffd83dbSDimitry Andric cl::desc("Enable/disable fusing matrix instructions.")); 575ffd83dbSDimitry Andric // TODO: Allow and use non-square tiles. 585ffd83dbSDimitry Andric static cl::opt<unsigned> TileSize( 595ffd83dbSDimitry Andric "fuse-matrix-tile-size", cl::init(4), cl::Hidden, 605ffd83dbSDimitry Andric cl::desc( 615ffd83dbSDimitry Andric "Tile size for matrix instruction fusion using square-shaped tiles.")); 62e8d8bef9SDimitry Andric static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false), 63e8d8bef9SDimitry Andric cl::Hidden, 64e8d8bef9SDimitry Andric cl::desc("Generate loop nest for tiling.")); 655ffd83dbSDimitry Andric static cl::opt<bool> ForceFusion( 665ffd83dbSDimitry Andric "force-fuse-matrix", cl::init(false), cl::Hidden, 675ffd83dbSDimitry Andric cl::desc("Force matrix instruction fusion even if not profitable.")); 68480093f4SDimitry Andric static cl::opt<bool> AllowContractEnabled( 69480093f4SDimitry Andric "matrix-allow-contract", cl::init(false), cl::Hidden, 70480093f4SDimitry Andric cl::desc("Allow the use of FMAs if available and profitable. 
This may " 71480093f4SDimitry Andric "result in different results, due to less rounding error.")); 72480093f4SDimitry Andric 735ffd83dbSDimitry Andric enum class MatrixLayoutTy { ColumnMajor, RowMajor }; 745ffd83dbSDimitry Andric 755ffd83dbSDimitry Andric static cl::opt<MatrixLayoutTy> MatrixLayout( 765ffd83dbSDimitry Andric "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), 775ffd83dbSDimitry Andric cl::desc("Sets the default matrix layout"), 785ffd83dbSDimitry Andric cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", 795ffd83dbSDimitry Andric "Use column-major layout"), 805ffd83dbSDimitry Andric clEnumValN(MatrixLayoutTy::RowMajor, "row-major", 815ffd83dbSDimitry Andric "Use row-major layout"))); 825ffd83dbSDimitry Andric 835ffd83dbSDimitry Andric /// Helper function to either return \p Scope, if it is a subprogram, or the 845ffd83dbSDimitry Andric /// subprogram attached to a local scope. 855ffd83dbSDimitry Andric static DISubprogram *getSubprogram(DIScope *Scope) { 865ffd83dbSDimitry Andric if (auto *Subprogram = dyn_cast<DISubprogram>(Scope)) 875ffd83dbSDimitry Andric return Subprogram; 885ffd83dbSDimitry Andric return cast<DILocalScope>(Scope)->getSubprogram(); 895ffd83dbSDimitry Andric } 905ffd83dbSDimitry Andric 91480093f4SDimitry Andric namespace { 92480093f4SDimitry Andric 935ffd83dbSDimitry Andric // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute 945ffd83dbSDimitry Andric // the start address of vector \p VecIdx with type (\p EltType x \p NumElements) 955ffd83dbSDimitry Andric // assuming \p Stride elements between the starts of two consecutive vectors. 965ffd83dbSDimitry Andric // \p Stride must be >= \p NumElements. 975ffd83dbSDimitry Andric // For column-major matrixes, the function computes the address of a column 985ffd83dbSDimitry Andric // vector and \p NumElements must be set to the number of elements in a column 995ffd83dbSDimitry Andric // (= number of rows of the matrix). 
For row-major matrixes, the function 1005ffd83dbSDimitry Andric // computes the address of a row vector and \p NumElements must be set to the 1015ffd83dbSDimitry Andric // number of elements in a row (= number of columns of the matrix). 102480093f4SDimitry Andric // 1035ffd83dbSDimitry Andric // Consider a 4x4 matrix in column-major layout like below 104480093f4SDimitry Andric // 105480093f4SDimitry Andric // 0 1 2 3 106480093f4SDimitry Andric // 0 v_0_0 v_0_1 v_0_2 v_0_3 107480093f4SDimitry Andric // 1 v_1_0 v_1_1 v_1_2 v_1_3 108480093f4SDimitry Andric // 2 v_2_0 v_2_1 v_2_2 v_2_3 109480093f4SDimitry Andric // 3 v_3_0 v_3_1 v_3_2 v_3_3 110480093f4SDimitry Andric 111480093f4SDimitry Andric // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1, 112480093f4SDimitry Andric // we need a pointer to the first element of the submatrix as base pointer. 1135ffd83dbSDimitry Andric // Then we can use computeVectorAddr to compute the addresses for the columns 114480093f4SDimitry Andric // of the sub-matrix. 115480093f4SDimitry Andric // 1165ffd83dbSDimitry Andric // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..) 117480093f4SDimitry Andric // -> just returns Base 1185ffd83dbSDimitry Andric // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..) 119480093f4SDimitry Andric // -> returns Base + (1 * 4) 1205ffd83dbSDimitry Andric // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..) 121480093f4SDimitry Andric // -> returns Base + (2 * 4) 122480093f4SDimitry Andric // 123480093f4SDimitry Andric // The graphic below illustrates the number of elements in a column (marked 124480093f4SDimitry Andric // with |) and the number of skipped elements (marked with }). 
125480093f4SDimitry Andric // 126480093f4SDimitry Andric // v_0_0 v_0_1 {v_0_2 {v_0_3 127480093f4SDimitry Andric // Base Col 1 Col 2 128480093f4SDimitry Andric // | | | 129480093f4SDimitry Andric // v_1_0 |v_1_1 |v_1_2 |v_1_3 130480093f4SDimitry Andric // v_2_0 |v_2_1 |v_2_2 |v_2_3 131480093f4SDimitry Andric // v_3_0 {v_3_1 {v_3_2 v_3_3 132480093f4SDimitry Andric // 1335ffd83dbSDimitry Andric Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, 1345ffd83dbSDimitry Andric unsigned NumElements, Type *EltType, 135480093f4SDimitry Andric IRBuilder<> &Builder) { 136480093f4SDimitry Andric 137480093f4SDimitry Andric assert((!isa<ConstantInt>(Stride) || 1385ffd83dbSDimitry Andric cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && 1395ffd83dbSDimitry Andric "Stride must be >= the number of elements in the result vector."); 140480093f4SDimitry Andric unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 141480093f4SDimitry Andric 1425ffd83dbSDimitry Andric // Compute the start of the vector with index VecIdx as VecIdx * Stride. 1435ffd83dbSDimitry Andric Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); 144480093f4SDimitry Andric 1455ffd83dbSDimitry Andric // Get pointer to the start of the selected vector. Skip GEP creation, 1465ffd83dbSDimitry Andric // if we select vector 0. 1475ffd83dbSDimitry Andric if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) 1485ffd83dbSDimitry Andric VecStart = BasePtr; 149480093f4SDimitry Andric else 1505ffd83dbSDimitry Andric VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); 151480093f4SDimitry Andric 1525ffd83dbSDimitry Andric // Cast the elementwise vector start pointer to a pointer to a vector 1535ffd83dbSDimitry Andric // (EltType x NumElements)*. 
1545ffd83dbSDimitry Andric auto *VecType = FixedVectorType::get(EltType, NumElements); 1555ffd83dbSDimitry Andric Type *VecPtrType = PointerType::get(VecType, AS); 1565ffd83dbSDimitry Andric return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast"); 157480093f4SDimitry Andric } 158480093f4SDimitry Andric 159480093f4SDimitry Andric /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 160480093f4SDimitry Andric /// 161480093f4SDimitry Andric /// Currently, the lowering for each matrix intrinsic is done as follows: 162480093f4SDimitry Andric /// 1. Propagate the shape information from intrinsics to connected 163480093f4SDimitry Andric /// instructions. 1645ffd83dbSDimitry Andric /// 2. Lower instructions with shape information (assuming column-major layout). 1655ffd83dbSDimitry Andric /// The lowering works similarly using row-major layout. 166480093f4SDimitry Andric /// 2.1. Get column vectors for each argument. If we already lowered the 167480093f4SDimitry Andric /// definition of an argument, use the produced column vectors directly. 168480093f4SDimitry Andric /// If not, split the operand vector containing an embedded matrix into 169480093f4SDimitry Andric /// a set of column vectors. 1705ffd83dbSDimitry Andric /// 2.2. Lower the instruction in terms of column major operations, which 1715ffd83dbSDimitry Andric /// yields a set of column vectors containing the result matrix. Note that we 1725ffd83dbSDimitry Andric /// lower all instructions that have shape information. Besides the 1735ffd83dbSDimitry Andric /// intrinsics, this includes stores for example. 174480093f4SDimitry Andric /// 2.3. Update uses of the lowered instruction. If we have shape information 175480093f4SDimitry Andric /// for a user, there is nothing to do, as we will look up the result 176480093f4SDimitry Andric /// column matrix when lowering the user. For other uses, we embed the 177480093f4SDimitry Andric /// result matrix in a flat vector and update the use. 
178480093f4SDimitry Andric /// 2.4. Cache the result column matrix for the instruction we lowered 179480093f4SDimitry Andric /// 3. After we lowered all instructions in a function, remove the now 180480093f4SDimitry Andric /// obsolete instructions. 181480093f4SDimitry Andric /// 182480093f4SDimitry Andric class LowerMatrixIntrinsics { 183480093f4SDimitry Andric Function &Func; 184480093f4SDimitry Andric const DataLayout &DL; 185480093f4SDimitry Andric const TargetTransformInfo &TTI; 186e8d8bef9SDimitry Andric AliasAnalysis *AA; 187e8d8bef9SDimitry Andric DominatorTree *DT; 188e8d8bef9SDimitry Andric LoopInfo *LI; 189e8d8bef9SDimitry Andric OptimizationRemarkEmitter *ORE; 190480093f4SDimitry Andric 1915ffd83dbSDimitry Andric /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 1925ffd83dbSDimitry Andric struct OpInfoTy { 1935ffd83dbSDimitry Andric /// Number of stores emitted to generate this matrix. 1945ffd83dbSDimitry Andric unsigned NumStores = 0; 1955ffd83dbSDimitry Andric /// Number of loads emitted to generate this matrix. 1965ffd83dbSDimitry Andric unsigned NumLoads = 0; 1975ffd83dbSDimitry Andric /// Number of compute operations emitted to generate this matrix. 1985ffd83dbSDimitry Andric unsigned NumComputeOps = 0; 199fe6060f1SDimitry Andric /// Most of the time transposes can be fused with matrix multiplies or can 200fe6060f1SDimitry Andric /// be folded away via algebraic simplifications. This is the number of 201fe6060f1SDimitry Andric /// transposes that we failed to make "free" via such optimizations. 
202fe6060f1SDimitry Andric unsigned NumExposedTransposes = 0; 2035ffd83dbSDimitry Andric 2045ffd83dbSDimitry Andric OpInfoTy &operator+=(const OpInfoTy &RHS) { 2055ffd83dbSDimitry Andric NumStores += RHS.NumStores; 2065ffd83dbSDimitry Andric NumLoads += RHS.NumLoads; 2075ffd83dbSDimitry Andric NumComputeOps += RHS.NumComputeOps; 208fe6060f1SDimitry Andric NumExposedTransposes += RHS.NumExposedTransposes; 2095ffd83dbSDimitry Andric return *this; 2105ffd83dbSDimitry Andric } 2115ffd83dbSDimitry Andric }; 2125ffd83dbSDimitry Andric 2135ffd83dbSDimitry Andric /// Wrapper class representing a matrix as a set of vectors, either in row or 2145ffd83dbSDimitry Andric /// column major layout. All vectors must have the same vector type. 2155ffd83dbSDimitry Andric class MatrixTy { 2165ffd83dbSDimitry Andric SmallVector<Value *, 16> Vectors; 2175ffd83dbSDimitry Andric 2185ffd83dbSDimitry Andric OpInfoTy OpInfo; 2195ffd83dbSDimitry Andric 2205ffd83dbSDimitry Andric bool IsColumnMajor = true; 221480093f4SDimitry Andric 222480093f4SDimitry Andric public: 2235ffd83dbSDimitry Andric MatrixTy() 2245ffd83dbSDimitry Andric : Vectors(), 2255ffd83dbSDimitry Andric IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 2265ffd83dbSDimitry Andric MatrixTy(ArrayRef<Value *> Vectors) 2275ffd83dbSDimitry Andric : Vectors(Vectors.begin(), Vectors.end()), 2285ffd83dbSDimitry Andric IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 2295ffd83dbSDimitry Andric MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 2305ffd83dbSDimitry Andric : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 231480093f4SDimitry Andric 2325ffd83dbSDimitry Andric unsigned D = isColumnMajor() ? NumColumns : NumRows; 2335ffd83dbSDimitry Andric for (unsigned J = 0; J < D; ++J) 2345ffd83dbSDimitry Andric addVector(UndefValue::get(FixedVectorType::get( 2355ffd83dbSDimitry Andric EltTy, isColumnMajor() ? 
NumRows : NumColumns))); 236480093f4SDimitry Andric } 237480093f4SDimitry Andric 2385ffd83dbSDimitry Andric Value *getVector(unsigned i) const { return Vectors[i]; } 2395ffd83dbSDimitry Andric Value *getColumn(unsigned i) const { 2405ffd83dbSDimitry Andric assert(isColumnMajor() && "only supported for column-major matrixes"); 2415ffd83dbSDimitry Andric return Vectors[i]; 2425ffd83dbSDimitry Andric } 2435ffd83dbSDimitry Andric Value *getRow(unsigned i) const { 2445ffd83dbSDimitry Andric assert(!isColumnMajor() && "only supported for row-major matrixes"); 2455ffd83dbSDimitry Andric return Vectors[i]; 2465ffd83dbSDimitry Andric } 247480093f4SDimitry Andric 2485ffd83dbSDimitry Andric void setVector(unsigned i, Value *V) { Vectors[i] = V; } 249480093f4SDimitry Andric 250e8d8bef9SDimitry Andric Type *getElementType() const { return getVectorTy()->getElementType(); } 2515ffd83dbSDimitry Andric 2525ffd83dbSDimitry Andric unsigned getNumVectors() const { 2535ffd83dbSDimitry Andric if (isColumnMajor()) 2545ffd83dbSDimitry Andric return getNumColumns(); 2555ffd83dbSDimitry Andric return getNumRows(); 2565ffd83dbSDimitry Andric } 2575ffd83dbSDimitry Andric 2585ffd83dbSDimitry Andric unsigned getNumColumns() const { 2595ffd83dbSDimitry Andric if (isColumnMajor()) 2605ffd83dbSDimitry Andric return Vectors.size(); 2615ffd83dbSDimitry Andric else { 2625ffd83dbSDimitry Andric assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 2635ffd83dbSDimitry Andric return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 2645ffd83dbSDimitry Andric } 2655ffd83dbSDimitry Andric } 2665ffd83dbSDimitry Andric unsigned getNumRows() const { 2675ffd83dbSDimitry Andric if (isColumnMajor()) { 2685ffd83dbSDimitry Andric assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 2695ffd83dbSDimitry Andric return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 2705ffd83dbSDimitry Andric } else 2715ffd83dbSDimitry Andric return Vectors.size(); 
2725ffd83dbSDimitry Andric } 2735ffd83dbSDimitry Andric 2745ffd83dbSDimitry Andric void addVector(Value *V) { Vectors.push_back(V); } 2755ffd83dbSDimitry Andric VectorType *getColumnTy() { 2765ffd83dbSDimitry Andric assert(isColumnMajor() && "only supported for column-major matrixes"); 2775ffd83dbSDimitry Andric return getVectorTy(); 2785ffd83dbSDimitry Andric } 2795ffd83dbSDimitry Andric 280e8d8bef9SDimitry Andric VectorType *getVectorTy() const { 2815ffd83dbSDimitry Andric return cast<VectorType>(Vectors[0]->getType()); 2825ffd83dbSDimitry Andric } 283480093f4SDimitry Andric 284480093f4SDimitry Andric iterator_range<SmallVector<Value *, 8>::iterator> columns() { 2855ffd83dbSDimitry Andric assert(isColumnMajor() && 2865ffd83dbSDimitry Andric "columns() only supported for column-major matrixes"); 2875ffd83dbSDimitry Andric return make_range(Vectors.begin(), Vectors.end()); 288480093f4SDimitry Andric } 289480093f4SDimitry Andric 2905ffd83dbSDimitry Andric iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 2915ffd83dbSDimitry Andric return make_range(Vectors.begin(), Vectors.end()); 2925ffd83dbSDimitry Andric } 2935ffd83dbSDimitry Andric 2945ffd83dbSDimitry Andric /// Embed the vectors of the matrix into a flat vector by concatenating 295480093f4SDimitry Andric /// them. 296480093f4SDimitry Andric Value *embedInVector(IRBuilder<> &Builder) const { 2975ffd83dbSDimitry Andric return Vectors.size() == 1 ? 
Vectors[0] 2985ffd83dbSDimitry Andric : concatenateVectors(Builder, Vectors); 2995ffd83dbSDimitry Andric } 3005ffd83dbSDimitry Andric 3015ffd83dbSDimitry Andric MatrixTy &addNumLoads(unsigned N) { 3025ffd83dbSDimitry Andric OpInfo.NumLoads += N; 3035ffd83dbSDimitry Andric return *this; 3045ffd83dbSDimitry Andric } 3055ffd83dbSDimitry Andric 3065ffd83dbSDimitry Andric void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 3075ffd83dbSDimitry Andric 3085ffd83dbSDimitry Andric MatrixTy &addNumStores(unsigned N) { 3095ffd83dbSDimitry Andric OpInfo.NumStores += N; 3105ffd83dbSDimitry Andric return *this; 3115ffd83dbSDimitry Andric } 3125ffd83dbSDimitry Andric 313fe6060f1SDimitry Andric MatrixTy &addNumExposedTransposes(unsigned N) { 314fe6060f1SDimitry Andric OpInfo.NumExposedTransposes += N; 315fe6060f1SDimitry Andric return *this; 316fe6060f1SDimitry Andric } 317fe6060f1SDimitry Andric 3185ffd83dbSDimitry Andric MatrixTy &addNumComputeOps(unsigned N) { 3195ffd83dbSDimitry Andric OpInfo.NumComputeOps += N; 3205ffd83dbSDimitry Andric return *this; 3215ffd83dbSDimitry Andric } 3225ffd83dbSDimitry Andric 3235ffd83dbSDimitry Andric unsigned getNumStores() const { return OpInfo.NumStores; } 3245ffd83dbSDimitry Andric unsigned getNumLoads() const { return OpInfo.NumLoads; } 3255ffd83dbSDimitry Andric unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 3265ffd83dbSDimitry Andric 3275ffd83dbSDimitry Andric const OpInfoTy &getOpInfo() const { return OpInfo; } 3285ffd83dbSDimitry Andric 3295ffd83dbSDimitry Andric bool isColumnMajor() const { return IsColumnMajor; } 3305ffd83dbSDimitry Andric 3315ffd83dbSDimitry Andric unsigned getStride() const { 3325ffd83dbSDimitry Andric if (isColumnMajor()) 3335ffd83dbSDimitry Andric return getNumRows(); 3345ffd83dbSDimitry Andric return getNumColumns(); 3355ffd83dbSDimitry Andric } 3365ffd83dbSDimitry Andric 3375ffd83dbSDimitry Andric /// Extract a vector of \p NumElts starting at index (\p I, \p J). 
If the 3385ffd83dbSDimitry Andric /// matrix is column-major, the result vector is extracted from a column 3395ffd83dbSDimitry Andric /// vector, otherwise from a row vector. 3405ffd83dbSDimitry Andric Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 3415ffd83dbSDimitry Andric IRBuilder<> &Builder) const { 3425ffd83dbSDimitry Andric Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 3435ffd83dbSDimitry Andric return Builder.CreateShuffleVector( 344e8d8bef9SDimitry Andric Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), 3455ffd83dbSDimitry Andric "block"); 346480093f4SDimitry Andric } 347480093f4SDimitry Andric }; 348480093f4SDimitry Andric 349480093f4SDimitry Andric struct ShapeInfo { 350480093f4SDimitry Andric unsigned NumRows; 351480093f4SDimitry Andric unsigned NumColumns; 352480093f4SDimitry Andric 3535ffd83dbSDimitry Andric bool IsColumnMajor; 3545ffd83dbSDimitry Andric 355480093f4SDimitry Andric ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 3565ffd83dbSDimitry Andric : NumRows(NumRows), NumColumns(NumColumns), 3575ffd83dbSDimitry Andric IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 358480093f4SDimitry Andric 359480093f4SDimitry Andric ShapeInfo(Value *NumRows, Value *NumColumns) 3605ffd83dbSDimitry Andric : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 3615ffd83dbSDimitry Andric cast<ConstantInt>(NumColumns)->getZExtValue()) {} 362480093f4SDimitry Andric 363480093f4SDimitry Andric bool operator==(const ShapeInfo &other) { 364480093f4SDimitry Andric return NumRows == other.NumRows && NumColumns == other.NumColumns; 365480093f4SDimitry Andric } 366480093f4SDimitry Andric bool operator!=(const ShapeInfo &other) { return !(*this == other); } 367480093f4SDimitry Andric 368480093f4SDimitry Andric /// Returns true if shape-information is defined, meaning both dimensions 369480093f4SDimitry Andric /// are != 0. 
370480093f4SDimitry Andric operator bool() const { 371480093f4SDimitry Andric assert(NumRows == 0 || NumColumns != 0); 372480093f4SDimitry Andric return NumRows != 0; 373480093f4SDimitry Andric } 3745ffd83dbSDimitry Andric 3755ffd83dbSDimitry Andric unsigned getStride() const { 3765ffd83dbSDimitry Andric if (IsColumnMajor) 3775ffd83dbSDimitry Andric return NumRows; 3785ffd83dbSDimitry Andric return NumColumns; 3795ffd83dbSDimitry Andric } 3805ffd83dbSDimitry Andric 3815ffd83dbSDimitry Andric unsigned getNumVectors() const { 3825ffd83dbSDimitry Andric if (IsColumnMajor) 3835ffd83dbSDimitry Andric return NumColumns; 3845ffd83dbSDimitry Andric return NumRows; 3855ffd83dbSDimitry Andric } 386480093f4SDimitry Andric }; 387480093f4SDimitry Andric 388480093f4SDimitry Andric /// Maps instructions to their shape information. The shape information 389480093f4SDimitry Andric /// describes the shape to be used while lowering. This matches the shape of 390480093f4SDimitry Andric /// the result value of the instruction, with the only exceptions being store 3915ffd83dbSDimitry Andric /// instructions and the matrix_column_major_store intrinsics. For those, the 392480093f4SDimitry Andric /// shape information indicates that those instructions should be lowered 393fe6060f1SDimitry Andric /// using shape information as well. A ValueMap is used so that when 394fe6060f1SDimitry Andric /// sub-passes like optimizeTransposes performs RAUW the map stays 395fe6060f1SDimitry Andric /// up-to-date. 396fe6060f1SDimitry Andric ValueMap<Value *, ShapeInfo> ShapeMap; 397480093f4SDimitry Andric 398480093f4SDimitry Andric /// List of instructions to remove. While lowering, we are not replacing all 399480093f4SDimitry Andric /// users of a lowered instruction, if shape information is available and 400480093f4SDimitry Andric /// those need to be removed after we finished lowering. 
401480093f4SDimitry Andric SmallVector<Instruction *, 16> ToRemove; 402480093f4SDimitry Andric 403480093f4SDimitry Andric /// Map from instructions to their produced column matrix. 4045ffd83dbSDimitry Andric MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 405480093f4SDimitry Andric 406fe6060f1SDimitry Andric private: 407fe6060f1SDimitry Andric static FastMathFlags getFastMathFlags(Instruction *Inst) { 408fe6060f1SDimitry Andric FastMathFlags FMF; 409fe6060f1SDimitry Andric 410fe6060f1SDimitry Andric if (isa<FPMathOperator>(*Inst)) 411fe6060f1SDimitry Andric FMF = Inst->getFastMathFlags(); 412fe6060f1SDimitry Andric 413fe6060f1SDimitry Andric FMF.setAllowContract(AllowContractEnabled || FMF.allowContract()); 414fe6060f1SDimitry Andric 415fe6060f1SDimitry Andric return FMF; 416fe6060f1SDimitry Andric } 417fe6060f1SDimitry Andric 418480093f4SDimitry Andric public: 4195ffd83dbSDimitry Andric LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 420e8d8bef9SDimitry Andric AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, 421e8d8bef9SDimitry Andric OptimizationRemarkEmitter *ORE) 4225ffd83dbSDimitry Andric : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), 4235ffd83dbSDimitry Andric LI(LI), ORE(ORE) {} 424480093f4SDimitry Andric 4255ffd83dbSDimitry Andric unsigned getNumOps(Type *VT) { 4265ffd83dbSDimitry Andric assert(isa<VectorType>(VT) && "Expected vector type"); 4275ffd83dbSDimitry Andric return getNumOps(VT->getScalarType(), 4285ffd83dbSDimitry Andric cast<FixedVectorType>(VT)->getNumElements()); 4295ffd83dbSDimitry Andric } 4305ffd83dbSDimitry Andric 431fe6060f1SDimitry Andric /// Is this the minimal version executed in the backend pipelines. 432fe6060f1SDimitry Andric bool isMinimal() const { 433fe6060f1SDimitry Andric return !DT; 434fe6060f1SDimitry Andric } 435fe6060f1SDimitry Andric 4365ffd83dbSDimitry Andric /// Return the estimated number of vector ops required for an operation on 4375ffd83dbSDimitry Andric /// \p VT * N. 
4385ffd83dbSDimitry Andric unsigned getNumOps(Type *ST, unsigned N) { 4395ffd83dbSDimitry Andric return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / 440fe6060f1SDimitry Andric double(TTI.getRegisterBitWidth( 441fe6060f1SDimitry Andric TargetTransformInfo::RGK_FixedWidthVector) 442fe6060f1SDimitry Andric .getFixedSize())); 4435ffd83dbSDimitry Andric } 4445ffd83dbSDimitry Andric 4455ffd83dbSDimitry Andric /// Return the set of vectors that a matrix value is lowered to. 446480093f4SDimitry Andric /// 4475ffd83dbSDimitry Andric /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 4485ffd83dbSDimitry Andric /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 4495ffd83dbSDimitry Andric /// into vectors. 4505ffd83dbSDimitry Andric MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 4515ffd83dbSDimitry Andric IRBuilder<> &Builder) { 452480093f4SDimitry Andric VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 453480093f4SDimitry Andric assert(VType && "MatrixVal must be a vector type"); 4545ffd83dbSDimitry Andric assert(cast<FixedVectorType>(VType)->getNumElements() == 4555ffd83dbSDimitry Andric SI.NumRows * SI.NumColumns && 456480093f4SDimitry Andric "The vector size must match the number of matrix elements"); 457480093f4SDimitry Andric 458480093f4SDimitry Andric // Check if we lowered MatrixVal using shape information. In that case, 4595ffd83dbSDimitry Andric // return the existing matrix, if it matches the requested shape 460480093f4SDimitry Andric // information. If there is a mis-match, embed the result in a flat 461480093f4SDimitry Andric // vector and split it later. 
462480093f4SDimitry Andric auto Found = Inst2ColumnMatrix.find(MatrixVal); 463480093f4SDimitry Andric if (Found != Inst2ColumnMatrix.end()) { 4645ffd83dbSDimitry Andric MatrixTy &M = Found->second; 465480093f4SDimitry Andric // Return the found matrix, if its shape matches the requested shape 466480093f4SDimitry Andric // information 467480093f4SDimitry Andric if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 468480093f4SDimitry Andric return M; 469480093f4SDimitry Andric 470480093f4SDimitry Andric MatrixVal = M.embedInVector(Builder); 471480093f4SDimitry Andric } 472480093f4SDimitry Andric 473480093f4SDimitry Andric // Otherwise split MatrixVal. 474480093f4SDimitry Andric SmallVector<Value *, 16> SplitVecs; 4755ffd83dbSDimitry Andric for (unsigned MaskStart = 0; 4765ffd83dbSDimitry Andric MaskStart < cast<FixedVectorType>(VType)->getNumElements(); 4775ffd83dbSDimitry Andric MaskStart += SI.getStride()) { 4785ffd83dbSDimitry Andric Value *V = Builder.CreateShuffleVector( 479e8d8bef9SDimitry Andric MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), 4805ffd83dbSDimitry Andric "split"); 481480093f4SDimitry Andric SplitVecs.push_back(V); 482480093f4SDimitry Andric } 483480093f4SDimitry Andric 484480093f4SDimitry Andric return {SplitVecs}; 485480093f4SDimitry Andric } 486480093f4SDimitry Andric 487480093f4SDimitry Andric /// If \p V already has a known shape return false. Otherwise set the shape 488480093f4SDimitry Andric /// for instructions that support it. 
489480093f4SDimitry Andric bool setShapeInfo(Value *V, ShapeInfo Shape) { 490480093f4SDimitry Andric assert(Shape && "Shape not set"); 491480093f4SDimitry Andric if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 492480093f4SDimitry Andric return false; 493480093f4SDimitry Andric 494480093f4SDimitry Andric auto SIter = ShapeMap.find(V); 495480093f4SDimitry Andric if (SIter != ShapeMap.end()) { 496480093f4SDimitry Andric LLVM_DEBUG(dbgs() << " not overriding existing shape: " 497480093f4SDimitry Andric << SIter->second.NumRows << " " 498480093f4SDimitry Andric << SIter->second.NumColumns << " for " << *V << "\n"); 499480093f4SDimitry Andric return false; 500480093f4SDimitry Andric } 501480093f4SDimitry Andric 502480093f4SDimitry Andric ShapeMap.insert({V, Shape}); 503480093f4SDimitry Andric LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 504480093f4SDimitry Andric << " for " << *V << "\n"); 505480093f4SDimitry Andric return true; 506480093f4SDimitry Andric } 507480093f4SDimitry Andric 508480093f4SDimitry Andric bool isUniformShape(Value *V) { 509480093f4SDimitry Andric Instruction *I = dyn_cast<Instruction>(V); 510480093f4SDimitry Andric if (!I) 511480093f4SDimitry Andric return true; 512480093f4SDimitry Andric 513480093f4SDimitry Andric switch (I->getOpcode()) { 514480093f4SDimitry Andric case Instruction::FAdd: 515480093f4SDimitry Andric case Instruction::FSub: 516480093f4SDimitry Andric case Instruction::FMul: // Scalar multiply. 517e8d8bef9SDimitry Andric case Instruction::FNeg: 518480093f4SDimitry Andric case Instruction::Add: 519480093f4SDimitry Andric case Instruction::Mul: 520480093f4SDimitry Andric case Instruction::Sub: 521480093f4SDimitry Andric return true; 522480093f4SDimitry Andric default: 523480093f4SDimitry Andric return false; 524480093f4SDimitry Andric } 525480093f4SDimitry Andric } 526480093f4SDimitry Andric 527480093f4SDimitry Andric /// Returns true if shape information can be used for \p V. 
The supported 528480093f4SDimitry Andric /// instructions must match the instructions that can be lowered by this pass. 529480093f4SDimitry Andric bool supportsShapeInfo(Value *V) { 530480093f4SDimitry Andric Instruction *Inst = dyn_cast<Instruction>(V); 531480093f4SDimitry Andric if (!Inst) 532480093f4SDimitry Andric return false; 533480093f4SDimitry Andric 534480093f4SDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 535480093f4SDimitry Andric if (II) 536480093f4SDimitry Andric switch (II->getIntrinsicID()) { 537480093f4SDimitry Andric case Intrinsic::matrix_multiply: 538480093f4SDimitry Andric case Intrinsic::matrix_transpose: 5395ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 5405ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 541480093f4SDimitry Andric return true; 542480093f4SDimitry Andric default: 543480093f4SDimitry Andric return false; 544480093f4SDimitry Andric } 545480093f4SDimitry Andric return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 546480093f4SDimitry Andric } 547480093f4SDimitry Andric 548480093f4SDimitry Andric /// Propagate the shape information of instructions to their users. 549480093f4SDimitry Andric /// The work list contains instructions for which we can compute the shape, 550480093f4SDimitry Andric /// either based on the information provided by matrix intrinsics or known 551480093f4SDimitry Andric /// shapes of operands. 552480093f4SDimitry Andric SmallVector<Instruction *, 32> 553480093f4SDimitry Andric propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 554480093f4SDimitry Andric SmallVector<Instruction *, 32> NewWorkList; 555480093f4SDimitry Andric // Pop an element for which we guaranteed to have at least one of the 556480093f4SDimitry Andric // operand shapes. Add the shape for this and then add users to the work 557480093f4SDimitry Andric // list. 
558480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 559480093f4SDimitry Andric while (!WorkList.empty()) { 560e8d8bef9SDimitry Andric Instruction *Inst = WorkList.pop_back_val(); 561480093f4SDimitry Andric 562480093f4SDimitry Andric // New entry, set the value and insert operands 563480093f4SDimitry Andric bool Propagate = false; 564480093f4SDimitry Andric 565480093f4SDimitry Andric Value *MatrixA; 566480093f4SDimitry Andric Value *MatrixB; 567480093f4SDimitry Andric Value *M; 568480093f4SDimitry Andric Value *N; 569480093f4SDimitry Andric Value *K; 570480093f4SDimitry Andric if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>( 571480093f4SDimitry Andric m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 572480093f4SDimitry Andric m_Value(N), m_Value(K)))) { 573480093f4SDimitry Andric Propagate = setShapeInfo(Inst, {M, K}); 574480093f4SDimitry Andric } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>( 575480093f4SDimitry Andric m_Value(MatrixA), m_Value(M), m_Value(N)))) { 576480093f4SDimitry Andric // Flip dimensions. 
577480093f4SDimitry Andric Propagate = setShapeInfo(Inst, {N, M}); 5785ffd83dbSDimitry Andric } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>( 579480093f4SDimitry Andric m_Value(MatrixA), m_Value(), m_Value(), 5805ffd83dbSDimitry Andric m_Value(), m_Value(M), m_Value(N)))) { 581480093f4SDimitry Andric Propagate = setShapeInfo(Inst, {N, M}); 5825ffd83dbSDimitry Andric } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( 5835ffd83dbSDimitry Andric m_Value(), m_Value(), m_Value(), m_Value(M), 5845ffd83dbSDimitry Andric m_Value(N)))) { 585480093f4SDimitry Andric Propagate = setShapeInfo(Inst, {M, N}); 586480093f4SDimitry Andric } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { 587480093f4SDimitry Andric auto OpShape = ShapeMap.find(MatrixA); 588480093f4SDimitry Andric if (OpShape != ShapeMap.end()) 589480093f4SDimitry Andric setShapeInfo(Inst, OpShape->second); 590480093f4SDimitry Andric continue; 591480093f4SDimitry Andric } else if (isUniformShape(Inst)) { 592480093f4SDimitry Andric // Find the first operand that has a known shape and use that. 
593480093f4SDimitry Andric for (auto &Op : Inst->operands()) { 594480093f4SDimitry Andric auto OpShape = ShapeMap.find(Op.get()); 595480093f4SDimitry Andric if (OpShape != ShapeMap.end()) { 596480093f4SDimitry Andric Propagate |= setShapeInfo(Inst, OpShape->second); 597480093f4SDimitry Andric break; 598480093f4SDimitry Andric } 599480093f4SDimitry Andric } 600480093f4SDimitry Andric } 601480093f4SDimitry Andric 602480093f4SDimitry Andric if (Propagate) { 603480093f4SDimitry Andric NewWorkList.push_back(Inst); 604480093f4SDimitry Andric for (auto *User : Inst->users()) 605480093f4SDimitry Andric if (ShapeMap.count(User) == 0) 606480093f4SDimitry Andric WorkList.push_back(cast<Instruction>(User)); 607480093f4SDimitry Andric } 608480093f4SDimitry Andric } 609480093f4SDimitry Andric 610480093f4SDimitry Andric return NewWorkList; 611480093f4SDimitry Andric } 612480093f4SDimitry Andric 613480093f4SDimitry Andric /// Propagate the shape to operands of instructions with shape information. 614480093f4SDimitry Andric /// \p Worklist contains the instruction for which we already know the shape. 615480093f4SDimitry Andric SmallVector<Instruction *, 32> 616480093f4SDimitry Andric propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 617480093f4SDimitry Andric SmallVector<Instruction *, 32> NewWorkList; 618480093f4SDimitry Andric 619480093f4SDimitry Andric auto pushInstruction = [](Value *V, 620480093f4SDimitry Andric SmallVectorImpl<Instruction *> &WorkList) { 621480093f4SDimitry Andric Instruction *I = dyn_cast<Instruction>(V); 622480093f4SDimitry Andric if (I) 623480093f4SDimitry Andric WorkList.push_back(I); 624480093f4SDimitry Andric }; 625480093f4SDimitry Andric // Pop an element with known shape. Traverse the operands, if their shape 626480093f4SDimitry Andric // derives from the result shape and is unknown, add it and add them to the 627480093f4SDimitry Andric // worklist. 
628480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 629480093f4SDimitry Andric while (!WorkList.empty()) { 630e8d8bef9SDimitry Andric Value *V = WorkList.pop_back_val(); 631480093f4SDimitry Andric 632480093f4SDimitry Andric size_t BeforeProcessingV = WorkList.size(); 633480093f4SDimitry Andric if (!isa<Instruction>(V)) 634480093f4SDimitry Andric continue; 635480093f4SDimitry Andric 636480093f4SDimitry Andric Value *MatrixA; 637480093f4SDimitry Andric Value *MatrixB; 638480093f4SDimitry Andric Value *M; 639480093f4SDimitry Andric Value *N; 640480093f4SDimitry Andric Value *K; 641480093f4SDimitry Andric if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 642480093f4SDimitry Andric m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 643480093f4SDimitry Andric m_Value(N), m_Value(K)))) { 644480093f4SDimitry Andric if (setShapeInfo(MatrixA, {M, N})) 645480093f4SDimitry Andric pushInstruction(MatrixA, WorkList); 646480093f4SDimitry Andric 647480093f4SDimitry Andric if (setShapeInfo(MatrixB, {N, K})) 648480093f4SDimitry Andric pushInstruction(MatrixB, WorkList); 649480093f4SDimitry Andric 650480093f4SDimitry Andric } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 651480093f4SDimitry Andric m_Value(MatrixA), m_Value(M), m_Value(N)))) { 652480093f4SDimitry Andric // Flip dimensions. 
653480093f4SDimitry Andric if (setShapeInfo(MatrixA, {M, N})) 654480093f4SDimitry Andric pushInstruction(MatrixA, WorkList); 6555ffd83dbSDimitry Andric } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 6565ffd83dbSDimitry Andric m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 657480093f4SDimitry Andric m_Value(M), m_Value(N)))) { 658480093f4SDimitry Andric if (setShapeInfo(MatrixA, {M, N})) { 659480093f4SDimitry Andric pushInstruction(MatrixA, WorkList); 660480093f4SDimitry Andric } 661480093f4SDimitry Andric } else if (isa<LoadInst>(V) || 6625ffd83dbSDimitry Andric match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 663480093f4SDimitry Andric // Nothing to do, no matrix input. 664480093f4SDimitry Andric } else if (isa<StoreInst>(V)) { 665480093f4SDimitry Andric // Nothing to do. We forward-propagated to this so we would just 666480093f4SDimitry Andric // backward propagate to an instruction with an already known shape. 667480093f4SDimitry Andric } else if (isUniformShape(V)) { 668480093f4SDimitry Andric // Propagate to all operands. 669480093f4SDimitry Andric ShapeInfo Shape = ShapeMap[V]; 670480093f4SDimitry Andric for (Use &U : cast<Instruction>(V)->operands()) { 671480093f4SDimitry Andric if (setShapeInfo(U.get(), Shape)) 672480093f4SDimitry Andric pushInstruction(U.get(), WorkList); 673480093f4SDimitry Andric } 674480093f4SDimitry Andric } 675480093f4SDimitry Andric // After we discovered new shape info for new instructions in the 676480093f4SDimitry Andric // worklist, we use their users as seeds for the next round of forward 677480093f4SDimitry Andric // propagation. 
678480093f4SDimitry Andric for (size_t I = BeforeProcessingV; I != WorkList.size(); I++) 679480093f4SDimitry Andric for (User *U : WorkList[I]->users()) 680480093f4SDimitry Andric if (isa<Instruction>(U) && V != U) 681480093f4SDimitry Andric NewWorkList.push_back(cast<Instruction>(U)); 682480093f4SDimitry Andric } 683480093f4SDimitry Andric return NewWorkList; 684480093f4SDimitry Andric } 685480093f4SDimitry Andric 686fe6060f1SDimitry Andric /// Try moving transposes in order to fold them away or into multiplies. 687fe6060f1SDimitry Andric void optimizeTransposes() { 688fe6060f1SDimitry Andric auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) { 689fe6060f1SDimitry Andric // We need to remove Old from the ShapeMap otherwise RAUW will replace it 690fe6060f1SDimitry Andric // with New. We should only add New it it supportsShapeInfo so we insert 691fe6060f1SDimitry Andric // it conditionally instead. 692fe6060f1SDimitry Andric auto S = ShapeMap.find(&Old); 693fe6060f1SDimitry Andric if (S != ShapeMap.end()) { 694fe6060f1SDimitry Andric ShapeMap.erase(S); 695fe6060f1SDimitry Andric if (supportsShapeInfo(New)) 696fe6060f1SDimitry Andric ShapeMap.insert({New, S->second}); 697fe6060f1SDimitry Andric } 698fe6060f1SDimitry Andric Old.replaceAllUsesWith(New); 699fe6060f1SDimitry Andric }; 700fe6060f1SDimitry Andric 701fe6060f1SDimitry Andric // First sink all transposes inside matmuls, hoping that we end up with NN, 702fe6060f1SDimitry Andric // NT or TN variants. 703fe6060f1SDimitry Andric for (BasicBlock &BB : reverse(Func)) { 704fe6060f1SDimitry Andric for (auto II = BB.rbegin(); II != BB.rend();) { 705fe6060f1SDimitry Andric Instruction &I = *II; 706fe6060f1SDimitry Andric // We may remove II. By default continue on the next/prev instruction. 707fe6060f1SDimitry Andric ++II; 708fe6060f1SDimitry Andric // If we were to erase II, move again. 
709fe6060f1SDimitry Andric auto EraseFromParent = [&II](Value *V) { 710fe6060f1SDimitry Andric auto *Inst = cast<Instruction>(V); 711fe6060f1SDimitry Andric if (Inst->use_empty()) { 712fe6060f1SDimitry Andric if (Inst == &*II) { 713fe6060f1SDimitry Andric ++II; 714fe6060f1SDimitry Andric } 715fe6060f1SDimitry Andric Inst->eraseFromParent(); 716fe6060f1SDimitry Andric } 717fe6060f1SDimitry Andric }; 718fe6060f1SDimitry Andric 719fe6060f1SDimitry Andric // If we're creating a new instruction, continue from there. 720fe6060f1SDimitry Andric Instruction *NewInst = nullptr; 721fe6060f1SDimitry Andric 722fe6060f1SDimitry Andric IRBuilder<> IB(&I); 723fe6060f1SDimitry Andric MatrixBuilder<IRBuilder<>> Builder(IB); 724fe6060f1SDimitry Andric 725fe6060f1SDimitry Andric Value *TA, *TAMA, *TAMB; 726fe6060f1SDimitry Andric ConstantInt *R, *K, *C; 727fe6060f1SDimitry Andric if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) { 728fe6060f1SDimitry Andric 729fe6060f1SDimitry Andric // Transpose of a transpose is a nop 730fe6060f1SDimitry Andric Value *TATA; 731fe6060f1SDimitry Andric if (match(TA, 732fe6060f1SDimitry Andric m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { 733fe6060f1SDimitry Andric ReplaceAllUsesWith(I, TATA); 734fe6060f1SDimitry Andric EraseFromParent(&I); 735fe6060f1SDimitry Andric EraseFromParent(TA); 736fe6060f1SDimitry Andric } 737fe6060f1SDimitry Andric 738fe6060f1SDimitry Andric // (A * B)^t -> B^t * A^t 739fe6060f1SDimitry Andric // RxK KxC CxK KxR 740fe6060f1SDimitry Andric else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( 741fe6060f1SDimitry Andric m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), 742fe6060f1SDimitry Andric m_ConstantInt(K), m_ConstantInt(C)))) { 743fe6060f1SDimitry Andric Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(), 744fe6060f1SDimitry Andric C->getZExtValue(), 745fe6060f1SDimitry Andric TAMB->getName() + "_t"); 746fe6060f1SDimitry Andric // We are being run after shape prop, 
add shape for newly created 747fe6060f1SDimitry Andric // instructions so that we lower them later. 748fe6060f1SDimitry Andric setShapeInfo(T0, {C, K}); 749fe6060f1SDimitry Andric Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(), 750fe6060f1SDimitry Andric K->getZExtValue(), 751fe6060f1SDimitry Andric TAMA->getName() + "_t"); 752fe6060f1SDimitry Andric setShapeInfo(T1, {K, R}); 753fe6060f1SDimitry Andric NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(), 754fe6060f1SDimitry Andric K->getZExtValue(), 755fe6060f1SDimitry Andric R->getZExtValue(), "mmul"); 756fe6060f1SDimitry Andric ReplaceAllUsesWith(I, NewInst); 757fe6060f1SDimitry Andric EraseFromParent(&I); 758fe6060f1SDimitry Andric EraseFromParent(TA); 759fe6060f1SDimitry Andric } 760fe6060f1SDimitry Andric } 761fe6060f1SDimitry Andric 762fe6060f1SDimitry Andric // If we replaced I with a new instruction, continue from there. 763fe6060f1SDimitry Andric if (NewInst) 764fe6060f1SDimitry Andric II = std::next(BasicBlock::reverse_iterator(NewInst)); 765fe6060f1SDimitry Andric } 766fe6060f1SDimitry Andric } 767fe6060f1SDimitry Andric 768fe6060f1SDimitry Andric // If we have a TT matmul, lift the transpose. We may be able to fold into 769fe6060f1SDimitry Andric // consuming multiply. 770fe6060f1SDimitry Andric for (BasicBlock &BB : Func) { 771fe6060f1SDimitry Andric for (BasicBlock::iterator II = BB.begin(); II != BB.end();) { 772fe6060f1SDimitry Andric Instruction *I = &*II; 773fe6060f1SDimitry Andric // We may remove I. 
774fe6060f1SDimitry Andric ++II; 775fe6060f1SDimitry Andric Value *A, *B, *AT, *BT; 776fe6060f1SDimitry Andric ConstantInt *R, *K, *C; 777fe6060f1SDimitry Andric // A^t * B ^t -> (B * A)^t 778fe6060f1SDimitry Andric if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>( 779fe6060f1SDimitry Andric m_Value(A), m_Value(B), m_ConstantInt(R), 780fe6060f1SDimitry Andric m_ConstantInt(K), m_ConstantInt(C))) && 781fe6060f1SDimitry Andric match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && 782fe6060f1SDimitry Andric match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { 783fe6060f1SDimitry Andric IRBuilder<> IB(&*I); 784fe6060f1SDimitry Andric MatrixBuilder<IRBuilder<>> Builder(IB); 785fe6060f1SDimitry Andric Value *M = Builder.CreateMatrixMultiply( 786fe6060f1SDimitry Andric BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); 787fe6060f1SDimitry Andric setShapeInfo(M, {C, R}); 788fe6060f1SDimitry Andric Instruction *NewInst = Builder.CreateMatrixTranspose( 789fe6060f1SDimitry Andric M, C->getZExtValue(), R->getZExtValue()); 790fe6060f1SDimitry Andric ReplaceAllUsesWith(*I, NewInst); 791fe6060f1SDimitry Andric if (I->use_empty()) 792fe6060f1SDimitry Andric I->eraseFromParent(); 793fe6060f1SDimitry Andric if (A->use_empty()) 794fe6060f1SDimitry Andric cast<Instruction>(A)->eraseFromParent(); 795fe6060f1SDimitry Andric if (A != B && B->use_empty()) 796fe6060f1SDimitry Andric cast<Instruction>(B)->eraseFromParent(); 797fe6060f1SDimitry Andric } 798fe6060f1SDimitry Andric } 799fe6060f1SDimitry Andric } 800fe6060f1SDimitry Andric } 801fe6060f1SDimitry Andric 802480093f4SDimitry Andric bool Visit() { 803480093f4SDimitry Andric SmallVector<Instruction *, 32> WorkList; 804480093f4SDimitry Andric 805480093f4SDimitry Andric // Initially only the shape of matrix intrinsics is known. 806480093f4SDimitry Andric // Initialize the work list with ops carrying shape information. 
807480093f4SDimitry Andric for (BasicBlock &BB : Func) 808480093f4SDimitry Andric for (Instruction &Inst : BB) { 809480093f4SDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 810480093f4SDimitry Andric if (!II) 811480093f4SDimitry Andric continue; 812480093f4SDimitry Andric 813480093f4SDimitry Andric switch (II->getIntrinsicID()) { 814480093f4SDimitry Andric case Intrinsic::matrix_multiply: 815480093f4SDimitry Andric case Intrinsic::matrix_transpose: 8165ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 8175ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 818480093f4SDimitry Andric WorkList.push_back(&Inst); 819480093f4SDimitry Andric break; 820480093f4SDimitry Andric default: 821480093f4SDimitry Andric break; 822480093f4SDimitry Andric } 823480093f4SDimitry Andric } 824fe6060f1SDimitry Andric 825fe6060f1SDimitry Andric // Avoid unnecessary work if there are no matrix intrinsics in the function. 826fe6060f1SDimitry Andric if (WorkList.empty()) 827fe6060f1SDimitry Andric return false; 828fe6060f1SDimitry Andric 829480093f4SDimitry Andric // Propagate shapes until nothing changes any longer. 
830480093f4SDimitry Andric while (!WorkList.empty()) { 831480093f4SDimitry Andric WorkList = propagateShapeForward(WorkList); 832480093f4SDimitry Andric WorkList = propagateShapeBackward(WorkList); 833480093f4SDimitry Andric } 834fe6060f1SDimitry Andric 835fe6060f1SDimitry Andric if (!isMinimal()) { 836fe6060f1SDimitry Andric optimizeTransposes(); 837fe6060f1SDimitry Andric LLVM_DEBUG({ 838fe6060f1SDimitry Andric dbgs() << "Dump after matrix transpose optimization:\n"; 839fe6060f1SDimitry Andric Func.dump(); 840fe6060f1SDimitry Andric }); 841480093f4SDimitry Andric } 842480093f4SDimitry Andric 843480093f4SDimitry Andric bool Changed = false; 8445ffd83dbSDimitry Andric SmallVector<CallInst *, 16> MaybeFusableInsts; 8455ffd83dbSDimitry Andric SmallVector<Instruction *, 16> MatrixInsts; 846480093f4SDimitry Andric 8475ffd83dbSDimitry Andric // First, collect all instructions with shape information and candidates for 8485ffd83dbSDimitry Andric // fusion (currently only matrix multiplies). 8495ffd83dbSDimitry Andric ReversePostOrderTraversal<Function *> RPOT(&Func); 8505ffd83dbSDimitry Andric for (auto *BB : RPOT) 8515ffd83dbSDimitry Andric for (Instruction &I : *BB) { 8525ffd83dbSDimitry Andric if (ShapeMap.find(&I) == ShapeMap.end()) 8535ffd83dbSDimitry Andric continue; 8545ffd83dbSDimitry Andric if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 8555ffd83dbSDimitry Andric MaybeFusableInsts.push_back(cast<CallInst>(&I)); 8565ffd83dbSDimitry Andric MatrixInsts.push_back(&I); 8575ffd83dbSDimitry Andric } 8585ffd83dbSDimitry Andric 8595ffd83dbSDimitry Andric // Second, try to fuse candidates. 8605ffd83dbSDimitry Andric SmallPtrSet<Instruction *, 16> FusedInsts; 8615ffd83dbSDimitry Andric for (CallInst *CI : MaybeFusableInsts) 8625ffd83dbSDimitry Andric LowerMatrixMultiplyFused(CI, FusedInsts); 8635ffd83dbSDimitry Andric Changed = !FusedInsts.empty(); 8645ffd83dbSDimitry Andric 8655ffd83dbSDimitry Andric // Third, lower remaining instructions with shape information. 
8665ffd83dbSDimitry Andric for (Instruction *Inst : MatrixInsts) { 8675ffd83dbSDimitry Andric if (FusedInsts.count(Inst)) 8685ffd83dbSDimitry Andric continue; 8695ffd83dbSDimitry Andric 8705ffd83dbSDimitry Andric IRBuilder<> Builder(Inst); 8715ffd83dbSDimitry Andric 8725ffd83dbSDimitry Andric if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 873480093f4SDimitry Andric Changed |= VisitCallInst(CInst); 874480093f4SDimitry Andric 875480093f4SDimitry Andric Value *Op1; 876480093f4SDimitry Andric Value *Op2; 8775ffd83dbSDimitry Andric if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 878480093f4SDimitry Andric Changed |= VisitBinaryOperator(BinOp); 879e8d8bef9SDimitry Andric if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 880e8d8bef9SDimitry Andric Changed |= VisitUnaryOperator(UnOp); 8815ffd83dbSDimitry Andric if (match(Inst, m_Load(m_Value(Op1)))) 8825ffd83dbSDimitry Andric Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 8835ffd83dbSDimitry Andric else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 8845ffd83dbSDimitry Andric Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 885480093f4SDimitry Andric } 8865ffd83dbSDimitry Andric 887e8d8bef9SDimitry Andric if (ORE) { 888e8d8bef9SDimitry Andric RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 8895ffd83dbSDimitry Andric RemarkGen.emitRemarks(); 890e8d8bef9SDimitry Andric } 891480093f4SDimitry Andric 892fe6060f1SDimitry Andric // Delete the instructions backwards, as it has a reduced likelihood of 893fe6060f1SDimitry Andric // having to update as many def-use and use-def chains. 894fe6060f1SDimitry Andric // 895fe6060f1SDimitry Andric // Because we add to ToRemove during fusion we can't guarantee that defs 896fe6060f1SDimitry Andric // are before uses. Change uses to undef temporarily as these should get 897fe6060f1SDimitry Andric // removed as well. 
898fe6060f1SDimitry Andric // 899fe6060f1SDimitry Andric // For verification, we keep track of where we changed uses to undefs in 900fe6060f1SDimitry Andric // UndefedInsts and then check that we in fact remove them. 901fe6060f1SDimitry Andric SmallSet<Instruction *, 16> UndefedInsts; 902fe6060f1SDimitry Andric for (auto *Inst : reverse(ToRemove)) { 903*349cc55cSDimitry Andric for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 904fe6060f1SDimitry Andric if (auto *Undefed = dyn_cast<Instruction>(U.getUser())) 905fe6060f1SDimitry Andric UndefedInsts.insert(Undefed); 906fe6060f1SDimitry Andric U.set(UndefValue::get(Inst->getType())); 907fe6060f1SDimitry Andric } 908480093f4SDimitry Andric Inst->eraseFromParent(); 909fe6060f1SDimitry Andric UndefedInsts.erase(Inst); 910fe6060f1SDimitry Andric } 911fe6060f1SDimitry Andric if (!UndefedInsts.empty()) { 912fe6060f1SDimitry Andric // If we didn't remove all undefed instructions, it's a hard error. 913fe6060f1SDimitry Andric dbgs() << "Undefed but present instructions:\n"; 914fe6060f1SDimitry Andric for (auto *I : UndefedInsts) 915fe6060f1SDimitry Andric dbgs() << *I << "\n"; 916fe6060f1SDimitry Andric llvm_unreachable("Undefed but instruction not removed"); 917fe6060f1SDimitry Andric } 918480093f4SDimitry Andric 919480093f4SDimitry Andric return Changed; 920480093f4SDimitry Andric } 921480093f4SDimitry Andric 922480093f4SDimitry Andric /// Turns \p BasePtr into an elementwise pointer to \p EltType. 
923480093f4SDimitry Andric Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 924480093f4SDimitry Andric unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 925480093f4SDimitry Andric Type *EltPtrType = PointerType::get(EltType, AS); 926480093f4SDimitry Andric return Builder.CreatePointerCast(BasePtr, EltPtrType); 927480093f4SDimitry Andric } 928480093f4SDimitry Andric 929480093f4SDimitry Andric /// Replace intrinsic calls 930480093f4SDimitry Andric bool VisitCallInst(CallInst *Inst) { 931480093f4SDimitry Andric if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 932480093f4SDimitry Andric return false; 933480093f4SDimitry Andric 934480093f4SDimitry Andric switch (Inst->getCalledFunction()->getIntrinsicID()) { 935480093f4SDimitry Andric case Intrinsic::matrix_multiply: 936480093f4SDimitry Andric LowerMultiply(Inst); 937480093f4SDimitry Andric break; 938480093f4SDimitry Andric case Intrinsic::matrix_transpose: 939480093f4SDimitry Andric LowerTranspose(Inst); 940480093f4SDimitry Andric break; 9415ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 9425ffd83dbSDimitry Andric LowerColumnMajorLoad(Inst); 943480093f4SDimitry Andric break; 9445ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 9455ffd83dbSDimitry Andric LowerColumnMajorStore(Inst); 946480093f4SDimitry Andric break; 947480093f4SDimitry Andric default: 948480093f4SDimitry Andric return false; 949480093f4SDimitry Andric } 950480093f4SDimitry Andric return true; 951480093f4SDimitry Andric } 952480093f4SDimitry Andric 9535ffd83dbSDimitry Andric /// Compute the alignment for a column/row \p Idx with \p Stride between them. 9545ffd83dbSDimitry Andric /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 9555ffd83dbSDimitry Andric /// ConstantInt, reduce the initial alignment based on the byte offset. 
For 9565ffd83dbSDimitry Andric /// non-ConstantInt strides, return the common alignment of the initial 9575ffd83dbSDimitry Andric /// alignment and the element size in bytes. 9585ffd83dbSDimitry Andric Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 9595ffd83dbSDimitry Andric MaybeAlign A) const { 9605ffd83dbSDimitry Andric Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 9615ffd83dbSDimitry Andric if (Idx == 0) 9625ffd83dbSDimitry Andric return InitialAlign; 9635ffd83dbSDimitry Andric 9645ffd83dbSDimitry Andric TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 9655ffd83dbSDimitry Andric if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 9665ffd83dbSDimitry Andric uint64_t StrideInBytes = 9675ffd83dbSDimitry Andric ConstStride->getZExtValue() * ElementSizeInBits / 8; 9685ffd83dbSDimitry Andric return commonAlignment(InitialAlign, Idx * StrideInBytes); 9695ffd83dbSDimitry Andric } 9705ffd83dbSDimitry Andric return commonAlignment(InitialAlign, ElementSizeInBits / 8); 9715ffd83dbSDimitry Andric } 9725ffd83dbSDimitry Andric 9735ffd83dbSDimitry Andric /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 9745ffd83dbSDimitry Andric /// vectors. 
9755ffd83dbSDimitry Andric MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 9765ffd83dbSDimitry Andric bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 977fe6060f1SDimitry Andric auto *VType = cast<VectorType>(Ty); 978fe6060f1SDimitry Andric Type *EltTy = VType->getElementType(); 979fe6060f1SDimitry Andric Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); 980fe6060f1SDimitry Andric Value *EltPtr = createElementPtr(Ptr, EltTy, Builder); 9815ffd83dbSDimitry Andric MatrixTy Result; 9825ffd83dbSDimitry Andric for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 983*349cc55cSDimitry Andric Value *GEP = computeVectorAddr( 984*349cc55cSDimitry Andric EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), 985*349cc55cSDimitry Andric Stride, Shape.getStride(), EltTy, Builder); 9865ffd83dbSDimitry Andric Value *Vector = Builder.CreateAlignedLoad( 987fe6060f1SDimitry Andric VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), 9885ffd83dbSDimitry Andric IsVolatile, "col.load"); 9895ffd83dbSDimitry Andric 9905ffd83dbSDimitry Andric Result.addVector(Vector); 9915ffd83dbSDimitry Andric } 9925ffd83dbSDimitry Andric return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 9935ffd83dbSDimitry Andric Result.getNumVectors()); 994480093f4SDimitry Andric } 995480093f4SDimitry Andric 9965ffd83dbSDimitry Andric /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 9975ffd83dbSDimitry Andric /// starting at \p MatrixPtr[I][J]. 
9985ffd83dbSDimitry Andric MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 9995ffd83dbSDimitry Andric ShapeInfo MatrixShape, Value *I, Value *J, 10005ffd83dbSDimitry Andric ShapeInfo ResultShape, Type *EltTy, 10015ffd83dbSDimitry Andric IRBuilder<> &Builder) { 10025ffd83dbSDimitry Andric 10035ffd83dbSDimitry Andric Value *Offset = Builder.CreateAdd( 10045ffd83dbSDimitry Andric Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 10055ffd83dbSDimitry Andric 10065ffd83dbSDimitry Andric unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 10075ffd83dbSDimitry Andric Value *EltPtr = 10085ffd83dbSDimitry Andric Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 10095ffd83dbSDimitry Andric Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 10105ffd83dbSDimitry Andric auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 10115ffd83dbSDimitry Andric ResultShape.NumColumns); 10125ffd83dbSDimitry Andric Type *TilePtrTy = PointerType::get(TileTy, AS); 10135ffd83dbSDimitry Andric Value *TilePtr = 10145ffd83dbSDimitry Andric Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 10155ffd83dbSDimitry Andric 10165ffd83dbSDimitry Andric return loadMatrix(TileTy, TilePtr, Align, 10175ffd83dbSDimitry Andric Builder.getInt64(MatrixShape.getStride()), IsVolatile, 10185ffd83dbSDimitry Andric ResultShape, Builder); 1019480093f4SDimitry Andric } 1020480093f4SDimitry Andric 10215ffd83dbSDimitry Andric /// Lower a load instruction with shape information. 
10225ffd83dbSDimitry Andric void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 10235ffd83dbSDimitry Andric bool IsVolatile, ShapeInfo Shape) { 10245ffd83dbSDimitry Andric IRBuilder<> Builder(Inst); 10255ffd83dbSDimitry Andric finalizeLowering(Inst, 10265ffd83dbSDimitry Andric loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 10275ffd83dbSDimitry Andric Shape, Builder), 10285ffd83dbSDimitry Andric Builder); 10295ffd83dbSDimitry Andric } 10305ffd83dbSDimitry Andric 10315ffd83dbSDimitry Andric /// Lowers llvm.matrix.column.major.load. 1032480093f4SDimitry Andric /// 1033480093f4SDimitry Andric /// The intrinsic loads a matrix from memory using a stride between columns. 10345ffd83dbSDimitry Andric void LowerColumnMajorLoad(CallInst *Inst) { 10355ffd83dbSDimitry Andric assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 10365ffd83dbSDimitry Andric "Intrinsic only supports column-major layout!"); 1037480093f4SDimitry Andric Value *Ptr = Inst->getArgOperand(0); 1038480093f4SDimitry Andric Value *Stride = Inst->getArgOperand(1); 10395ffd83dbSDimitry Andric LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 10405ffd83dbSDimitry Andric cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 1041480093f4SDimitry Andric {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 1042480093f4SDimitry Andric } 1043480093f4SDimitry Andric 10445ffd83dbSDimitry Andric /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 10455ffd83dbSDimitry Andric /// MatrixPtr[I][J]. 
10465ffd83dbSDimitry Andric void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 10475ffd83dbSDimitry Andric MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, 10485ffd83dbSDimitry Andric Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { 10495ffd83dbSDimitry Andric Value *Offset = Builder.CreateAdd( 10505ffd83dbSDimitry Andric Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 10515ffd83dbSDimitry Andric 10525ffd83dbSDimitry Andric unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 10535ffd83dbSDimitry Andric Value *EltPtr = 10545ffd83dbSDimitry Andric Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 10555ffd83dbSDimitry Andric Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 10565ffd83dbSDimitry Andric auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * 10575ffd83dbSDimitry Andric StoreVal.getNumColumns()); 10585ffd83dbSDimitry Andric Type *TilePtrTy = PointerType::get(TileTy, AS); 10595ffd83dbSDimitry Andric Value *TilePtr = 10605ffd83dbSDimitry Andric Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 10615ffd83dbSDimitry Andric 10625ffd83dbSDimitry Andric storeMatrix(TileTy, StoreVal, TilePtr, MAlign, 10635ffd83dbSDimitry Andric Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); 10645ffd83dbSDimitry Andric } 10655ffd83dbSDimitry Andric 10665ffd83dbSDimitry Andric /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 10675ffd83dbSDimitry Andric /// vectors. 
10685ffd83dbSDimitry Andric MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, 10695ffd83dbSDimitry Andric MaybeAlign MAlign, Value *Stride, bool IsVolatile, 10705ffd83dbSDimitry Andric IRBuilder<> &Builder) { 10715ffd83dbSDimitry Andric auto VType = cast<VectorType>(Ty); 10725ffd83dbSDimitry Andric Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 10735ffd83dbSDimitry Andric for (auto Vec : enumerate(StoreVal.vectors())) { 1074*349cc55cSDimitry Andric Value *GEP = computeVectorAddr( 1075*349cc55cSDimitry Andric EltPtr, 1076*349cc55cSDimitry Andric Builder.getIntN(Stride->getType()->getScalarSizeInBits(), 1077*349cc55cSDimitry Andric Vec.index()), 1078*349cc55cSDimitry Andric Stride, StoreVal.getStride(), VType->getElementType(), Builder); 10795ffd83dbSDimitry Andric Builder.CreateAlignedStore(Vec.value(), GEP, 10805ffd83dbSDimitry Andric getAlignForIndex(Vec.index(), Stride, 10815ffd83dbSDimitry Andric VType->getElementType(), 10825ffd83dbSDimitry Andric MAlign), 10835ffd83dbSDimitry Andric IsVolatile); 10845ffd83dbSDimitry Andric } 10855ffd83dbSDimitry Andric return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * 10865ffd83dbSDimitry Andric StoreVal.getNumVectors()); 10875ffd83dbSDimitry Andric } 10885ffd83dbSDimitry Andric 10895ffd83dbSDimitry Andric /// Lower a store instruction with shape information. 
10905ffd83dbSDimitry Andric void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, 10915ffd83dbSDimitry Andric Value *Stride, bool IsVolatile, ShapeInfo Shape) { 10925ffd83dbSDimitry Andric IRBuilder<> Builder(Inst); 10935ffd83dbSDimitry Andric auto StoreVal = getMatrix(Matrix, Shape, Builder); 10945ffd83dbSDimitry Andric finalizeLowering(Inst, 10955ffd83dbSDimitry Andric storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, 10965ffd83dbSDimitry Andric IsVolatile, Builder), 10975ffd83dbSDimitry Andric Builder); 10985ffd83dbSDimitry Andric } 10995ffd83dbSDimitry Andric 11005ffd83dbSDimitry Andric /// Lowers llvm.matrix.column.major.store. 11015ffd83dbSDimitry Andric /// 11025ffd83dbSDimitry Andric /// The intrinsic store a matrix back memory using a stride between columns. 11035ffd83dbSDimitry Andric void LowerColumnMajorStore(CallInst *Inst) { 11045ffd83dbSDimitry Andric assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 11055ffd83dbSDimitry Andric "Intrinsic only supports column-major layout!"); 11065ffd83dbSDimitry Andric Value *Matrix = Inst->getArgOperand(0); 11075ffd83dbSDimitry Andric Value *Ptr = Inst->getArgOperand(1); 11085ffd83dbSDimitry Andric Value *Stride = Inst->getArgOperand(2); 11095ffd83dbSDimitry Andric LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 11105ffd83dbSDimitry Andric cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 11115ffd83dbSDimitry Andric {Inst->getArgOperand(4), Inst->getArgOperand(5)}); 1112480093f4SDimitry Andric } 1113480093f4SDimitry Andric 1114480093f4SDimitry Andric // Set elements I..I+NumElts-1 to Block 1115480093f4SDimitry Andric Value *insertVector(Value *Col, unsigned I, Value *Block, 11165ffd83dbSDimitry Andric IRBuilder<> &Builder) { 1117480093f4SDimitry Andric 1118480093f4SDimitry Andric // First, bring Block to the same size as Col 1119480093f4SDimitry Andric unsigned BlockNumElts = 11205ffd83dbSDimitry Andric cast<FixedVectorType>(Block->getType())->getNumElements(); 
11215ffd83dbSDimitry Andric unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); 1122480093f4SDimitry Andric assert(NumElts >= BlockNumElts && "Too few elements for current block"); 1123480093f4SDimitry Andric 11245ffd83dbSDimitry Andric Block = Builder.CreateShuffleVector( 1125e8d8bef9SDimitry Andric Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 1126480093f4SDimitry Andric 1127480093f4SDimitry Andric // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 1128480093f4SDimitry Andric // 8, 4, 5, 6 11295ffd83dbSDimitry Andric SmallVector<int, 16> Mask; 1130480093f4SDimitry Andric unsigned i; 1131480093f4SDimitry Andric for (i = 0; i < I; i++) 11325ffd83dbSDimitry Andric Mask.push_back(i); 1133480093f4SDimitry Andric 11345ffd83dbSDimitry Andric unsigned VecNumElts = 11355ffd83dbSDimitry Andric cast<FixedVectorType>(Col->getType())->getNumElements(); 1136480093f4SDimitry Andric for (; i < I + BlockNumElts; i++) 11375ffd83dbSDimitry Andric Mask.push_back(i - I + VecNumElts); 1138480093f4SDimitry Andric 1139480093f4SDimitry Andric for (; i < VecNumElts; i++) 11405ffd83dbSDimitry Andric Mask.push_back(i); 1141480093f4SDimitry Andric 11425ffd83dbSDimitry Andric return Builder.CreateShuffleVector(Col, Block, Mask); 1143480093f4SDimitry Andric } 1144480093f4SDimitry Andric 1145480093f4SDimitry Andric Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 11465ffd83dbSDimitry Andric IRBuilder<> &Builder, bool AllowContraction, 11475ffd83dbSDimitry Andric unsigned &NumComputeOps) { 11485ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType()); 1149480093f4SDimitry Andric if (!Sum) 1150480093f4SDimitry Andric return UseFPOp ? 
Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 1151480093f4SDimitry Andric 1152480093f4SDimitry Andric if (UseFPOp) { 1153480093f4SDimitry Andric if (AllowContraction) { 1154480093f4SDimitry Andric // Use fmuladd for floating point operations and let the backend decide 1155480093f4SDimitry Andric // if that's profitable. 11565ffd83dbSDimitry Andric Function *FMulAdd = Intrinsic::getDeclaration( 1157480093f4SDimitry Andric Func.getParent(), Intrinsic::fmuladd, A->getType()); 1158480093f4SDimitry Andric return Builder.CreateCall(FMulAdd, {A, B, Sum}); 1159480093f4SDimitry Andric } 11605ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType()); 1161480093f4SDimitry Andric Value *Mul = Builder.CreateFMul(A, B); 1162480093f4SDimitry Andric return Builder.CreateFAdd(Sum, Mul); 1163480093f4SDimitry Andric } 1164480093f4SDimitry Andric 11655ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType()); 1166480093f4SDimitry Andric Value *Mul = Builder.CreateMul(A, B); 1167480093f4SDimitry Andric return Builder.CreateAdd(Sum, Mul); 1168480093f4SDimitry Andric } 1169480093f4SDimitry Andric 1170480093f4SDimitry Andric /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 1171fe6060f1SDimitry Andric /// users with shape information, there's nothing to do: they will use the 1172480093f4SDimitry Andric /// cached value when they are lowered. For other users, \p Matrix is 1173480093f4SDimitry Andric /// flattened and the uses are updated to use it. Also marks \p Inst for 1174480093f4SDimitry Andric /// deletion. 
11755ffd83dbSDimitry Andric void finalizeLowering(Instruction *Inst, MatrixTy Matrix, 1176480093f4SDimitry Andric IRBuilder<> &Builder) { 1177fe6060f1SDimitry Andric auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); 1178fe6060f1SDimitry Andric (void)inserted; 1179fe6060f1SDimitry Andric assert(inserted.second && "multiple matrix lowering mapping"); 1180480093f4SDimitry Andric 1181480093f4SDimitry Andric ToRemove.push_back(Inst); 1182480093f4SDimitry Andric Value *Flattened = nullptr; 1183fe6060f1SDimitry Andric for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1184480093f4SDimitry Andric if (ShapeMap.find(U.getUser()) == ShapeMap.end()) { 1185480093f4SDimitry Andric if (!Flattened) 1186480093f4SDimitry Andric Flattened = Matrix.embedInVector(Builder); 1187480093f4SDimitry Andric U.set(Flattened); 1188480093f4SDimitry Andric } 1189480093f4SDimitry Andric } 1190480093f4SDimitry Andric } 1191480093f4SDimitry Andric 11925ffd83dbSDimitry Andric /// Compute \p Result += \p A * \p B for input matrices with left-associating 11935ffd83dbSDimitry Andric /// addition. 1194fe6060f1SDimitry Andric /// 1195fe6060f1SDimitry Andric /// We can fold a transpose into the operand that is used to extract scalars. 1196fe6060f1SDimitry Andric /// This is the first operands with row-major and the second with 1197fe6060f1SDimitry Andric /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate 1198fe6060f1SDimitry Andric /// operand is transposed. 
11995ffd83dbSDimitry Andric void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1200fe6060f1SDimitry Andric const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled, 1201fe6060f1SDimitry Andric bool IsScalarMatrixTransposed, FastMathFlags FMF) { 12025ffd83dbSDimitry Andric const unsigned VF = std::max<unsigned>( 1203fe6060f1SDimitry Andric TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1204fe6060f1SDimitry Andric .getFixedSize() / 12055ffd83dbSDimitry Andric Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), 12065ffd83dbSDimitry Andric 1U); 12075ffd83dbSDimitry Andric unsigned R = Result.getNumRows(); 12085ffd83dbSDimitry Andric unsigned C = Result.getNumColumns(); 12095ffd83dbSDimitry Andric unsigned M = A.getNumColumns(); 12105ffd83dbSDimitry Andric 12115ffd83dbSDimitry Andric bool IsFP = Result.getElementType()->isFloatingPointTy(); 12125ffd83dbSDimitry Andric assert(A.isColumnMajor() == B.isColumnMajor() && 12135ffd83dbSDimitry Andric Result.isColumnMajor() == A.isColumnMajor() && 12145ffd83dbSDimitry Andric "operands must agree on matrix layout"); 12155ffd83dbSDimitry Andric unsigned NumComputeOps = 0; 1216fe6060f1SDimitry Andric 1217fe6060f1SDimitry Andric Builder.setFastMathFlags(FMF); 1218fe6060f1SDimitry Andric 12195ffd83dbSDimitry Andric if (A.isColumnMajor()) { 12205ffd83dbSDimitry Andric // Multiply columns from the first operand with scalars from the second 12215ffd83dbSDimitry Andric // operand. Then move along the K axes and accumulate the columns. With 12225ffd83dbSDimitry Andric // this the adds can be vectorized without reassociation. 12235ffd83dbSDimitry Andric for (unsigned J = 0; J < C; ++J) { 12245ffd83dbSDimitry Andric unsigned BlockSize = VF; 12255ffd83dbSDimitry Andric // If Result is zero, we don't need to accumulate in the K==0 iteration. 
12265ffd83dbSDimitry Andric bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 12275ffd83dbSDimitry Andric 12285ffd83dbSDimitry Andric for (unsigned I = 0; I < R; I += BlockSize) { 12295ffd83dbSDimitry Andric // Gradually lower the vectorization factor to cover the remainder. 12305ffd83dbSDimitry Andric while (I + BlockSize > R) 12315ffd83dbSDimitry Andric BlockSize /= 2; 12325ffd83dbSDimitry Andric 1233fe6060f1SDimitry Andric Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder) 12345ffd83dbSDimitry Andric : nullptr; 12355ffd83dbSDimitry Andric for (unsigned K = 0; K < M; ++K) { 12365ffd83dbSDimitry Andric Value *L = A.extractVector(I, K, BlockSize, Builder); 1237fe6060f1SDimitry Andric Value *RH = Builder.CreateExtractElement( 1238fe6060f1SDimitry Andric B.getColumn(IsScalarMatrixTransposed ? K : J), 1239fe6060f1SDimitry Andric IsScalarMatrixTransposed ? J : K); 12405ffd83dbSDimitry Andric Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1241fe6060f1SDimitry Andric Sum = 1242fe6060f1SDimitry Andric createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1243fe6060f1SDimitry Andric IsFP, Builder, FMF.allowContract(), NumComputeOps); 12445ffd83dbSDimitry Andric } 12455ffd83dbSDimitry Andric Result.setVector(J, 12465ffd83dbSDimitry Andric insertVector(Result.getVector(J), I, Sum, Builder)); 12475ffd83dbSDimitry Andric } 12485ffd83dbSDimitry Andric } 12495ffd83dbSDimitry Andric } else { 12505ffd83dbSDimitry Andric // Multiply rows from the second operand with scalars from the first 12515ffd83dbSDimitry Andric // operand. Then move along the K axes and accumulate the rows. With this 12525ffd83dbSDimitry Andric // the adds can be vectorized without reassociation. 
12535ffd83dbSDimitry Andric for (unsigned I = 0; I < R; ++I) { 12545ffd83dbSDimitry Andric unsigned BlockSize = VF; 12555ffd83dbSDimitry Andric bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); 12565ffd83dbSDimitry Andric for (unsigned J = 0; J < C; J += BlockSize) { 12575ffd83dbSDimitry Andric // Gradually lower the vectorization factor to cover the remainder. 12585ffd83dbSDimitry Andric while (J + BlockSize > C) 12595ffd83dbSDimitry Andric BlockSize /= 2; 12605ffd83dbSDimitry Andric 12615ffd83dbSDimitry Andric Value *Sum = nullptr; 12625ffd83dbSDimitry Andric for (unsigned K = 0; K < M; ++K) { 12635ffd83dbSDimitry Andric Value *R = B.extractVector(K, J, BlockSize, Builder); 1264fe6060f1SDimitry Andric Value *LH = Builder.CreateExtractElement( 1265fe6060f1SDimitry Andric A.getVector(IsScalarMatrixTransposed ? K : I), 1266fe6060f1SDimitry Andric IsScalarMatrixTransposed ? I : K); 12675ffd83dbSDimitry Andric Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); 1268fe6060f1SDimitry Andric Sum = 1269fe6060f1SDimitry Andric createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R, 1270fe6060f1SDimitry Andric IsFP, Builder, FMF.allowContract(), NumComputeOps); 12715ffd83dbSDimitry Andric } 12725ffd83dbSDimitry Andric Result.setVector(I, 12735ffd83dbSDimitry Andric insertVector(Result.getVector(I), J, Sum, Builder)); 12745ffd83dbSDimitry Andric } 12755ffd83dbSDimitry Andric } 12765ffd83dbSDimitry Andric } 12775ffd83dbSDimitry Andric Result.addNumComputeOps(NumComputeOps); 12785ffd83dbSDimitry Andric } 12795ffd83dbSDimitry Andric 12805ffd83dbSDimitry Andric /// Ensure that the memory in \p Load does not alias \p Store by potentially 12815ffd83dbSDimitry Andric /// copying it to a new location. This new or otherwise the original location 12825ffd83dbSDimitry Andric /// is returned. 
Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                             CallInst *MatMul) {
  MemoryLocation StoreLoc = MemoryLocation::get(Store);
  MemoryLocation LoadLoc = MemoryLocation::get(Load);

  // If we can statically determine noalias we're good.
  if (AA->isNoAlias(LoadLoc, StoreLoc))
    return Load->getPointerOperand();

  // Create code to check if the memory locations of the Load and Store
  // overlap and if they do, copy Load's operand to a new buffer.

  // First, create new blocks for the two parts of the overlap check and for
  // the copy.
  BasicBlock *Check0 = MatMul->getParent();
  // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
  // DT. Manually collect dominator tree updates, to avoid unnecessary work,
  // as we adjust Check0 and Check1's branches.
  SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
  for (BasicBlock *Succ : successors(Check0))
    DTUpdates.push_back({DT->Delete, Check0, Succ});

  // Repeated splitting at MatMul yields the chain Check0 -> Check1 -> Copy
  // -> Fusion; branches between them are rewired below.
  BasicBlock *Check1 =
      SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                 nullptr, "alias_cont");
  BasicBlock *Copy =
      SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                 nullptr, "copy");
  BasicBlock *Fusion =
      SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                 nullptr, "no_alias");

  // Check if the loaded memory location begins before the end of the store
  // location. If the condition holds, they might overlap, otherwise they are
  // guaranteed to not overlap.
  IRBuilder<> Builder(MatMul);
  Check0->getTerminator()->eraseFromParent();
  Builder.SetInsertPoint(Check0);
  Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
  Value *StoreBegin = Builder.CreatePtrToInt(
      const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
  // The trailing (true, true) marks the add as nuw/nsw.
  Value *StoreEnd = Builder.CreateAdd(
      StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
      "store.end", true, true);
  Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                            IntPtrTy, "load.begin");
  Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                       Fusion);

  // Check if the store begins before the end of the load location. If the
  // condition holds, they alias, otherwise they are guaranteed to not
  // overlap.
  Check1->getTerminator()->eraseFromParent();
  Builder.SetInsertPoint(Check1, Check1->begin());
  Value *LoadEnd = Builder.CreateAdd(
      LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
      "load.end", true, true);
  Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                       Fusion);

  // Copy load operand to new alloca.
  Builder.SetInsertPoint(Copy, Copy->begin());
  AllocaInst *NewLd =
      Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace());
  Builder.CreateMemCpy(NewLd, NewLd->getAlign(),
                       Load->getPointerOperand(), Load->getAlign(),
                       LoadLoc.Size.getValue());
  // In Fusion, select the original pointer (no overlap on either path) or
  // the freshly copied buffer.
  Builder.SetInsertPoint(Fusion, Fusion->begin());
  PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
  PHI->addIncoming(Load->getPointerOperand(), Check0);
  PHI->addIncoming(Load->getPointerOperand(), Check1);
  PHI->addIncoming(NewLd, Copy);

  // Adjust DT.
  DTUpdates.push_back({DT->Insert, Check0, Check1});
  DTUpdates.push_back({DT->Insert, Check0, Fusion});
  DTUpdates.push_back({DT->Insert, Check1, Copy});
  DTUpdates.push_back({DT->Insert, Check1, Fusion});
  DT->applyUpdates(DTUpdates);
  return PHI;
}

/// Simple cost model: decide whether fusing/tiling \p MatMul is expected to
/// pay off, based on operand shapes and the target's vector registers.
bool isFusionProfitable(CallInst *MatMul) {
  if (ForceFusion)
    return true;

  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

  const unsigned R = LShape.NumRows;
  const unsigned C = RShape.NumColumns;
  const unsigned M = LShape.NumColumns;
  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

  // Vectorization factor: elements of EltType per widest fixed vector.
  const unsigned VF = std::max<unsigned>(
      TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
              .getFixedSize() /
          EltType->getPrimitiveSizeInBits().getFixedSize(),
      1U);

  // Cost model for tiling
  //
  // For tiling to be beneficial, we need reuse either along the R or
  // the C axis.  We vectorize along the R axis so that means at least
  // 3 elements.
  // TODO: Also consider cost of copying if operands alias.
  if (R <= VF && C == 1)
    return false;
  // Then we need enough elements to exceed the number of vector
  // registers we have.  Note that this is an oversimplification since
  // fusing also takes some extra loads which may exceed the number of
  // reloads necessary.
  unsigned Op0Regs = (R + VF - 1) / VF * M;
  unsigned Op1Regs = (M + VF - 1) / VF * C;
  // 'true' requests the number of *vector* registers.
  return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true);
}

/// Build an R x C matrix whose columns are all-zero constant vectors.
MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
  MatrixTy Res;
  auto *ColumType = FixedVectorType::get(EltType, R);
  for (unsigned I = 0; I < C; ++I)
    Res.addVector(ConstantAggregateZero::get(ColumType));
  return Res;
}

/// Emit \p MatMul as a tiled loop nest: tiles of the operands are loaded
/// from \p LPtr / \p RPtr inside the generated loops, accumulated through
/// PHI nodes, and the finished result tile is stored to \p Store's location.
void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                      Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

  // Create the main tiling loop nest.
  TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  Instruction *InsertI = cast<Instruction>(MatMul);
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
  IRBuilder<> Builder(MatMul);
  BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

  Type *TileVecTy =
      FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
  MatrixTy TileResult;
  // Insert in the inner loop header.
  Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
  // Create PHI nodes for the result columns to accumulate across iterations.
  SmallVector<PHINode *, 4> ColumnPhis;
  for (unsigned I = 0; I < TileSize; I++) {
    auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
    Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
                     TI.RowLoopHeader->getSingleSuccessor());
    TileResult.addVector(Phi);
    ColumnPhis.push_back(Phi);
  }

  // Insert in the inner loop body, which computes
  //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
  Builder.SetInsertPoint(InnerBody->getTerminator());
  // Load tiles of the operands.
  MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
                          {TileSize, TileSize}, EltType, Builder);
  MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
                          {TileSize, TileSize}, EltType, Builder);
  emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                     getFastMathFlags(MatMul));
  // Store result after the inner loop is done.
  Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
  storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
              Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
              TI.CurrentRow, TI.CurrentCol, EltType, Builder);

  // Close the accumulation cycle: each PHI's backedge value is the updated
  // tile column.
  for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
    ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);

  // Force unrolling of a few iterations of the inner loop, to make sure there
  // is enough work per iteration.
  // FIXME: The unroller should make this decision directly instead, but
  // currently the cost-model is not up to the task.
  unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
  addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
                          "llvm.loop.unroll.count", InnerLoopUnrollCount);
}

/// Fuse a load -> multiply -> store chain into tiled code (looped or fully
/// unrolled), eliminating the original instructions when they become dead.
void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
                    StoreInst *Store,
                    SmallPtrSetImpl<Instruction *> &FusedInsts) {
  assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
         "Tiling only supported for column-major matrixes at the moment!");
  if (!isFusionProfitable(MatMul))
    return;

  ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
  ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

  const unsigned R = LShape.NumRows;
  const unsigned C = RShape.NumColumns;
  const unsigned M = LShape.NumColumns;
  auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

  // Guard against the operands aliasing the destination (may insert a
  // runtime check + copy).
  Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
  Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
  Value *CPtr = Store->getPointerOperand();

  if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
    createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
  else {
    IRBuilder<> Builder(Store);
    for (unsigned J = 0; J < C; J += TileSize)
      for (unsigned I = 0; I < R; I += TileSize) {
        const unsigned TileR = std::min(R - I, unsigned(TileSize));
        const unsigned TileC = std::min(C - J, unsigned(TileSize));
        MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);

        for (unsigned K = 0; K < M; K += TileSize) {
          const unsigned TileM = std::min(M - K, unsigned(TileSize));
          MatrixTy A =
              loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
                         LShape, Builder.getInt64(I), Builder.getInt64(K),
                         {TileR, TileM}, EltType, Builder);
          MatrixTy B =
              loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
                         RShape, Builder.getInt64(K), Builder.getInt64(J),
                         {TileM, TileC}, EltType, Builder);
          emitMatrixMultiply(Res, A, B, Builder, true, false,
                             getFastMathFlags(MatMul));
        }
        // NOTE(review): the memory shape here is {R, M}, not {R, C}. For
        // column-major addressing only NumRows (R) feeds the stride, so the
        // second component looks inert — confirm before changing.
        storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
                    Builder.getInt64(I), Builder.getInt64(J), EltType,
                    Builder);
      }
  }

  // Mark eliminated instructions as fused and remove them.
  FusedInsts.insert(Store);
  FusedInsts.insert(MatMul);
  Store->eraseFromParent();
  MatMul->eraseFromParent();
  if (LoadOp0->hasNUses(0)) {
    FusedInsts.insert(LoadOp0);
    LoadOp0->eraseFromParent();
  }
  if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
    FusedInsts.insert(LoadOp1);
    LoadOp1->eraseFromParent();
  }
}

/// Try to lower matrix multiply chains by fusing operations.
///
/// Call finalizeLowering on lowered instructions. Instructions that are
/// completely eliminated by fusion are added to \p FusedInsts.
void LowerMatrixMultiplyFused(CallInst *MatMul,
                              SmallPtrSetImpl<Instruction *> &FusedInsts) {
  if (!FuseMatrix || !DT)
    return;

  assert(AA && LI && "Analyses should be available");

  Value *A = MatMul->getArgOperand(0);
  Value *B = MatMul->getArgOperand(1);

  // We can fold the transpose into the operand that is used to fetch
  // scalars.
  Value *T;
  if (MatrixLayout == MatrixLayoutTy::ColumnMajor
          ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
          : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
    const unsigned R = LShape.NumRows;
    const unsigned M = LShape.NumColumns;
    const unsigned C = RShape.NumColumns;

    MatrixTy MA;
    MatrixTy MB;

    // The transposed operand is read through its untransposed input T; its
    // shape is therefore given transposed (C x M / R x M).
    Value *Transpose;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
      MA = getMatrix(A, ShapeInfo(R, M), Builder);
      MB = getMatrix(T, ShapeInfo(C, M), Builder);
      Transpose = B;
    } else {
      MA = getMatrix(T, ShapeInfo(R, M), Builder);
      MB = getMatrix(B, ShapeInfo(C, M), Builder);
      Transpose = A;
    }

    // Initialize the output
    MatrixTy Result(R, C, EltType);

    emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                       getFastMathFlags(MatMul));

    FusedInsts.insert(MatMul);
    if (Transpose->hasOneUse()) {
      FusedInsts.insert(cast<Instruction>(Transpose));
      ToRemove.push_back(cast<Instruction>(Transpose));
      // TODO: add a fake entry for the folded instruction so that this is
      // included in the expression in the remark.
      Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
    }
    finalizeLowering(MatMul, Result, Builder);
    return;
  }

  if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
    return;

  // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
  // since the single store user will be lowered as part of this.
  auto *LoadOp0 = dyn_cast<LoadInst>(A);
  auto *LoadOp1 = dyn_cast<LoadInst>(B);
  auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
  if (LoadOp0 && LoadOp1 && Store) {
    // The store address must dominate the MatMul instruction, otherwise
    // we create invalid IR.
    // Walk the address computation transitively; hoist side-effect-free
    // instructions above MatMul, bail out on PHIs or memory-reading ops.
    SetVector<Value *> WorkList;
    WorkList.insert(Store->getOperand(1));
    SmallVector<Instruction *> ToHoist;
    for (unsigned I = 0; I != WorkList.size(); ++I) {
      Value *Current = WorkList[I];
      auto *CurrI = dyn_cast<Instruction>(Current);
      if (!CurrI)
        continue;
      if (isa<PHINode>(CurrI))
        return;
      if (DT->dominates(CurrI, MatMul))
        continue;
      if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
        return;
      ToHoist.push_back(CurrI);
      WorkList.insert(CurrI->op_begin(), CurrI->op_end());
    }

    // Hoist in dominance order so operands are placed before their users.
    sort(ToHoist, [this](Instruction *A, Instruction *B) {
      return DT->dominates(A, B);
    });
    for (Instruction *I : ToHoist)
      I->moveBefore(MatMul);

    emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
    return;
  }
}

/// Lowers llvm.matrix.multiply.
1628480093f4SDimitry Andric void LowerMultiply(CallInst *MatMul) { 1629480093f4SDimitry Andric IRBuilder<> Builder(MatMul); 1630480093f4SDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1631480093f4SDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1632480093f4SDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1633480093f4SDimitry Andric 16345ffd83dbSDimitry Andric const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 16355ffd83dbSDimitry Andric const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 1636e8d8bef9SDimitry Andric assert(Lhs.getElementType() == Rhs.getElementType() && 1637e8d8bef9SDimitry Andric "Matrix multiply argument element types do not match."); 1638480093f4SDimitry Andric 1639480093f4SDimitry Andric const unsigned R = LShape.NumRows; 1640480093f4SDimitry Andric const unsigned C = RShape.NumColumns; 16415ffd83dbSDimitry Andric assert(LShape.NumColumns == RShape.NumRows); 1642480093f4SDimitry Andric 1643480093f4SDimitry Andric // Initialize the output 16445ffd83dbSDimitry Andric MatrixTy Result(R, C, EltType); 1645e8d8bef9SDimitry Andric assert(Lhs.getElementType() == Result.getElementType() && 1646e8d8bef9SDimitry Andric "Matrix multiply result element type does not match arguments."); 1647480093f4SDimitry Andric 1648fe6060f1SDimitry Andric emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false, 1649fe6060f1SDimitry Andric getFastMathFlags(MatMul)); 1650480093f4SDimitry Andric finalizeLowering(MatMul, Result, Builder); 1651480093f4SDimitry Andric } 1652480093f4SDimitry Andric 1653480093f4SDimitry Andric /// Lowers llvm.matrix.transpose. 
1654480093f4SDimitry Andric void LowerTranspose(CallInst *Inst) { 16555ffd83dbSDimitry Andric MatrixTy Result; 1656480093f4SDimitry Andric IRBuilder<> Builder(Inst); 1657480093f4SDimitry Andric Value *InputVal = Inst->getArgOperand(0); 1658480093f4SDimitry Andric VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 1659480093f4SDimitry Andric ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 16605ffd83dbSDimitry Andric MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 1661480093f4SDimitry Andric 16625ffd83dbSDimitry Andric const unsigned NewNumVecs = 16635ffd83dbSDimitry Andric InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; 16645ffd83dbSDimitry Andric const unsigned NewNumElts = 16655ffd83dbSDimitry Andric InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows; 1666480093f4SDimitry Andric 16675ffd83dbSDimitry Andric for (unsigned I = 0; I < NewNumVecs; ++I) { 16685ffd83dbSDimitry Andric // Build a single result vector. First initialize it. 16695ffd83dbSDimitry Andric Value *ResultVector = UndefValue::get( 16705ffd83dbSDimitry Andric FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); 16715ffd83dbSDimitry Andric // Go through the old elements and insert it into the resulting vector. 16725ffd83dbSDimitry Andric for (auto J : enumerate(InputMatrix.vectors())) { 16735ffd83dbSDimitry Andric Value *Elt = Builder.CreateExtractElement(J.value(), I); 16745ffd83dbSDimitry Andric // Row and column indices are transposed. 16755ffd83dbSDimitry Andric ResultVector = 16765ffd83dbSDimitry Andric Builder.CreateInsertElement(ResultVector, Elt, J.index()); 1677480093f4SDimitry Andric } 16785ffd83dbSDimitry Andric Result.addVector(ResultVector); 1679480093f4SDimitry Andric } 1680480093f4SDimitry Andric 16815ffd83dbSDimitry Andric // TODO: Improve estimate of operations needed for transposes. 
Currently we 16825ffd83dbSDimitry Andric // just count the insertelement/extractelement instructions, but do not 16835ffd83dbSDimitry Andric // account for later simplifications/combines. 16845ffd83dbSDimitry Andric finalizeLowering( 16855ffd83dbSDimitry Andric Inst, 1686fe6060f1SDimitry Andric Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns) 1687fe6060f1SDimitry Andric .addNumExposedTransposes(1), 16885ffd83dbSDimitry Andric Builder); 1689480093f4SDimitry Andric } 1690480093f4SDimitry Andric 1691480093f4SDimitry Andric /// Lower load instructions, if shape information is available. 16925ffd83dbSDimitry Andric bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { 1693480093f4SDimitry Andric auto I = ShapeMap.find(Inst); 1694480093f4SDimitry Andric if (I == ShapeMap.end()) 1695480093f4SDimitry Andric return false; 1696480093f4SDimitry Andric 16975ffd83dbSDimitry Andric LowerLoad(Inst, Ptr, Inst->getAlign(), 16985ffd83dbSDimitry Andric Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 16995ffd83dbSDimitry Andric I->second); 1700480093f4SDimitry Andric return true; 1701480093f4SDimitry Andric } 1702480093f4SDimitry Andric 17035ffd83dbSDimitry Andric bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, 1704480093f4SDimitry Andric IRBuilder<> &Builder) { 1705480093f4SDimitry Andric auto I = ShapeMap.find(StoredVal); 1706480093f4SDimitry Andric if (I == ShapeMap.end()) 1707480093f4SDimitry Andric return false; 1708480093f4SDimitry Andric 17095ffd83dbSDimitry Andric LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), 17105ffd83dbSDimitry Andric Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 17115ffd83dbSDimitry Andric I->second); 1712480093f4SDimitry Andric return true; 1713480093f4SDimitry Andric } 1714480093f4SDimitry Andric 1715480093f4SDimitry Andric /// Lower binary operators, if shape information is available. 
1716480093f4SDimitry Andric bool VisitBinaryOperator(BinaryOperator *Inst) { 1717480093f4SDimitry Andric auto I = ShapeMap.find(Inst); 1718480093f4SDimitry Andric if (I == ShapeMap.end()) 1719480093f4SDimitry Andric return false; 1720480093f4SDimitry Andric 1721480093f4SDimitry Andric Value *Lhs = Inst->getOperand(0); 1722480093f4SDimitry Andric Value *Rhs = Inst->getOperand(1); 1723480093f4SDimitry Andric 1724480093f4SDimitry Andric IRBuilder<> Builder(Inst); 1725480093f4SDimitry Andric ShapeInfo &Shape = I->second; 1726480093f4SDimitry Andric 17275ffd83dbSDimitry Andric MatrixTy Result; 17285ffd83dbSDimitry Andric MatrixTy A = getMatrix(Lhs, Shape, Builder); 17295ffd83dbSDimitry Andric MatrixTy B = getMatrix(Rhs, Shape, Builder); 17305ffd83dbSDimitry Andric assert(A.isColumnMajor() == B.isColumnMajor() && 17315ffd83dbSDimitry Andric Result.isColumnMajor() == A.isColumnMajor() && 17325ffd83dbSDimitry Andric "operands must agree on matrix layout"); 1733480093f4SDimitry Andric 1734fe6060f1SDimitry Andric Builder.setFastMathFlags(getFastMathFlags(Inst)); 1735fe6060f1SDimitry Andric 17365ffd83dbSDimitry Andric // Helper to perform binary op on vectors. 
17375ffd83dbSDimitry Andric auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) { 1738480093f4SDimitry Andric switch (Inst->getOpcode()) { 1739480093f4SDimitry Andric case Instruction::Add: 1740480093f4SDimitry Andric return Builder.CreateAdd(LHS, RHS); 1741480093f4SDimitry Andric case Instruction::Mul: 1742480093f4SDimitry Andric return Builder.CreateMul(LHS, RHS); 1743480093f4SDimitry Andric case Instruction::Sub: 1744480093f4SDimitry Andric return Builder.CreateSub(LHS, RHS); 1745480093f4SDimitry Andric case Instruction::FAdd: 1746480093f4SDimitry Andric return Builder.CreateFAdd(LHS, RHS); 1747480093f4SDimitry Andric case Instruction::FMul: 1748480093f4SDimitry Andric return Builder.CreateFMul(LHS, RHS); 1749480093f4SDimitry Andric case Instruction::FSub: 1750480093f4SDimitry Andric return Builder.CreateFSub(LHS, RHS); 1751480093f4SDimitry Andric default: 1752480093f4SDimitry Andric llvm_unreachable("Unsupported binary operator for matrix"); 1753480093f4SDimitry Andric } 1754480093f4SDimitry Andric }; 1755480093f4SDimitry Andric 17565ffd83dbSDimitry Andric for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 17575ffd83dbSDimitry Andric Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I))); 17585ffd83dbSDimitry Andric 17595ffd83dbSDimitry Andric finalizeLowering(Inst, 17605ffd83dbSDimitry Andric Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 17615ffd83dbSDimitry Andric Result.getNumVectors()), 17625ffd83dbSDimitry Andric Builder); 1763480093f4SDimitry Andric return true; 1764480093f4SDimitry Andric } 17655ffd83dbSDimitry Andric 1766e8d8bef9SDimitry Andric /// Lower unary operators, if shape information is available. 
1767e8d8bef9SDimitry Andric bool VisitUnaryOperator(UnaryOperator *Inst) { 1768e8d8bef9SDimitry Andric auto I = ShapeMap.find(Inst); 1769e8d8bef9SDimitry Andric if (I == ShapeMap.end()) 1770e8d8bef9SDimitry Andric return false; 1771e8d8bef9SDimitry Andric 1772e8d8bef9SDimitry Andric Value *Op = Inst->getOperand(0); 1773e8d8bef9SDimitry Andric 1774e8d8bef9SDimitry Andric IRBuilder<> Builder(Inst); 1775e8d8bef9SDimitry Andric ShapeInfo &Shape = I->second; 1776e8d8bef9SDimitry Andric 1777e8d8bef9SDimitry Andric MatrixTy Result; 1778e8d8bef9SDimitry Andric MatrixTy M = getMatrix(Op, Shape, Builder); 1779e8d8bef9SDimitry Andric 1780fe6060f1SDimitry Andric Builder.setFastMathFlags(getFastMathFlags(Inst)); 1781fe6060f1SDimitry Andric 1782e8d8bef9SDimitry Andric // Helper to perform unary op on vectors. 1783e8d8bef9SDimitry Andric auto BuildVectorOp = [&Builder, Inst](Value *Op) { 1784e8d8bef9SDimitry Andric switch (Inst->getOpcode()) { 1785e8d8bef9SDimitry Andric case Instruction::FNeg: 1786e8d8bef9SDimitry Andric return Builder.CreateFNeg(Op); 1787e8d8bef9SDimitry Andric default: 1788e8d8bef9SDimitry Andric llvm_unreachable("Unsupported unary operator for matrix"); 1789e8d8bef9SDimitry Andric } 1790e8d8bef9SDimitry Andric }; 1791e8d8bef9SDimitry Andric 1792e8d8bef9SDimitry Andric for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 1793e8d8bef9SDimitry Andric Result.addVector(BuildVectorOp(M.getVector(I))); 1794e8d8bef9SDimitry Andric 1795e8d8bef9SDimitry Andric finalizeLowering(Inst, 1796e8d8bef9SDimitry Andric Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 1797e8d8bef9SDimitry Andric Result.getNumVectors()), 1798e8d8bef9SDimitry Andric Builder); 1799e8d8bef9SDimitry Andric return true; 1800e8d8bef9SDimitry Andric } 1801e8d8bef9SDimitry Andric 18025ffd83dbSDimitry Andric /// Helper to linearize a matrix expression tree into a string. 
Currently 18035ffd83dbSDimitry Andric /// matrix expressions are linarized by starting at an expression leaf and 18045ffd83dbSDimitry Andric /// linearizing bottom up. 18055ffd83dbSDimitry Andric struct ExprLinearizer { 18065ffd83dbSDimitry Andric unsigned LengthToBreak = 100; 18075ffd83dbSDimitry Andric std::string Str; 18085ffd83dbSDimitry Andric raw_string_ostream Stream; 18095ffd83dbSDimitry Andric unsigned LineLength = 0; 18105ffd83dbSDimitry Andric const DataLayout &DL; 18115ffd83dbSDimitry Andric 18125ffd83dbSDimitry Andric /// Mapping from instructions to matrixes. It is used to identify 18135ffd83dbSDimitry Andric /// matrix instructions. 18145ffd83dbSDimitry Andric const MapVector<Value *, MatrixTy> &Inst2Matrix; 18155ffd83dbSDimitry Andric 18165ffd83dbSDimitry Andric /// Mapping from values to the leaves of all expressions that the value is 18175ffd83dbSDimitry Andric /// part of. 18185ffd83dbSDimitry Andric const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; 18195ffd83dbSDimitry Andric 18205ffd83dbSDimitry Andric /// Set of matrix expressions in the scope of a given DISubprogram. 18215ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram; 18225ffd83dbSDimitry Andric 18235ffd83dbSDimitry Andric /// Leaf node of the expression to linearize. 18245ffd83dbSDimitry Andric Value *Leaf; 18255ffd83dbSDimitry Andric 18265ffd83dbSDimitry Andric /// Used to keep track of sub-expressions that get reused while linearizing 18275ffd83dbSDimitry Andric /// the expression. Re-used sub-expressions are marked as (reused). 
18285ffd83dbSDimitry Andric SmallPtrSet<Value *, 8> ReusedExprs; 18295ffd83dbSDimitry Andric 18305ffd83dbSDimitry Andric ExprLinearizer(const DataLayout &DL, 18315ffd83dbSDimitry Andric const MapVector<Value *, MatrixTy> &Inst2Matrix, 18325ffd83dbSDimitry Andric const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 18335ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 18345ffd83dbSDimitry Andric Value *Leaf) 18355ffd83dbSDimitry Andric : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared), 18365ffd83dbSDimitry Andric ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} 18375ffd83dbSDimitry Andric 18385ffd83dbSDimitry Andric void indent(unsigned N) { 18395ffd83dbSDimitry Andric LineLength += N; 18405ffd83dbSDimitry Andric for (unsigned i = 0; i < N; i++) 18415ffd83dbSDimitry Andric Stream << " "; 18425ffd83dbSDimitry Andric } 18435ffd83dbSDimitry Andric 18445ffd83dbSDimitry Andric void lineBreak() { 18455ffd83dbSDimitry Andric Stream << "\n"; 18465ffd83dbSDimitry Andric LineLength = 0; 18475ffd83dbSDimitry Andric } 18485ffd83dbSDimitry Andric 18495ffd83dbSDimitry Andric void maybeIndent(unsigned Indent) { 18505ffd83dbSDimitry Andric if (LineLength >= LengthToBreak) 18515ffd83dbSDimitry Andric lineBreak(); 18525ffd83dbSDimitry Andric 18535ffd83dbSDimitry Andric if (LineLength == 0) 18545ffd83dbSDimitry Andric indent(Indent); 18555ffd83dbSDimitry Andric } 18565ffd83dbSDimitry Andric 18575ffd83dbSDimitry Andric void write(StringRef S) { 18585ffd83dbSDimitry Andric LineLength += S.size(); 18595ffd83dbSDimitry Andric Stream << S; 18605ffd83dbSDimitry Andric } 18615ffd83dbSDimitry Andric 18625ffd83dbSDimitry Andric Value *getUnderlyingObjectThroughLoads(Value *V) { 18635ffd83dbSDimitry Andric if (Value *Ptr = getPointerOperand(V)) 18645ffd83dbSDimitry Andric return getUnderlyingObjectThroughLoads(Ptr); 18655ffd83dbSDimitry Andric else if (V->getType()->isPointerTy()) 1866e8d8bef9SDimitry Andric return getUnderlyingObject(V); 
18675ffd83dbSDimitry Andric return V; 18685ffd83dbSDimitry Andric } 18695ffd83dbSDimitry Andric 18705ffd83dbSDimitry Andric /// Returns true if \p V is a matrix value in the given subprogram. 18715ffd83dbSDimitry Andric bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } 18725ffd83dbSDimitry Andric 18735ffd83dbSDimitry Andric /// If \p V is a matrix value, print its shape as as NumRows x NumColumns to 18745ffd83dbSDimitry Andric /// \p SS. 18755ffd83dbSDimitry Andric void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { 18765ffd83dbSDimitry Andric auto M = Inst2Matrix.find(V); 18775ffd83dbSDimitry Andric if (M == Inst2Matrix.end()) 18785ffd83dbSDimitry Andric SS << "unknown"; 18795ffd83dbSDimitry Andric else { 18805ffd83dbSDimitry Andric SS << M->second.getNumRows(); 18815ffd83dbSDimitry Andric SS << "x"; 18825ffd83dbSDimitry Andric SS << M->second.getNumColumns(); 18835ffd83dbSDimitry Andric } 18845ffd83dbSDimitry Andric } 18855ffd83dbSDimitry Andric 18865ffd83dbSDimitry Andric /// Write the called function name. Handles calls to llvm.matrix.* 18875ffd83dbSDimitry Andric /// specially: we write the name, followed by the dimensions of the input 18885ffd83dbSDimitry Andric /// matrixes, followed by the scalar type name. 
18895ffd83dbSDimitry Andric void writeFnName(CallInst *CI) { 18905ffd83dbSDimitry Andric if (!CI->getCalledFunction()) 18915ffd83dbSDimitry Andric write("<no called fn>"); 18925ffd83dbSDimitry Andric else { 18935ffd83dbSDimitry Andric StringRef Name = CI->getCalledFunction()->getName(); 18945ffd83dbSDimitry Andric if (!Name.startswith("llvm.matrix")) { 18955ffd83dbSDimitry Andric write(Name); 18965ffd83dbSDimitry Andric return; 18975ffd83dbSDimitry Andric } 18985ffd83dbSDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI); 1899fe6060f1SDimitry Andric write(Intrinsic::getBaseName(II->getIntrinsicID()) 19005ffd83dbSDimitry Andric .drop_front(StringRef("llvm.matrix.").size())); 19015ffd83dbSDimitry Andric write("."); 1902e8d8bef9SDimitry Andric std::string Tmp; 19035ffd83dbSDimitry Andric raw_string_ostream SS(Tmp); 19045ffd83dbSDimitry Andric 19055ffd83dbSDimitry Andric switch (II->getIntrinsicID()) { 19065ffd83dbSDimitry Andric case Intrinsic::matrix_multiply: 19075ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(0), SS); 19085ffd83dbSDimitry Andric SS << "."; 19095ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(1), SS); 19105ffd83dbSDimitry Andric SS << "." << *II->getType()->getScalarType(); 19115ffd83dbSDimitry Andric break; 19125ffd83dbSDimitry Andric case Intrinsic::matrix_transpose: 19135ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(0), SS); 19145ffd83dbSDimitry Andric SS << "." << *II->getType()->getScalarType(); 19155ffd83dbSDimitry Andric break; 19165ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 19175ffd83dbSDimitry Andric prettyPrintMatrixType(II, SS); 19185ffd83dbSDimitry Andric SS << "." << *II->getType()->getScalarType(); 19195ffd83dbSDimitry Andric break; 19205ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 19215ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(0), SS); 19225ffd83dbSDimitry Andric SS << "." 
<< *II->getOperand(0)->getType()->getScalarType(); 19235ffd83dbSDimitry Andric break; 19245ffd83dbSDimitry Andric default: 19255ffd83dbSDimitry Andric llvm_unreachable("Unhandled case"); 19265ffd83dbSDimitry Andric } 19275ffd83dbSDimitry Andric SS.flush(); 19285ffd83dbSDimitry Andric write(Tmp); 19295ffd83dbSDimitry Andric } 19305ffd83dbSDimitry Andric } 19315ffd83dbSDimitry Andric 19325ffd83dbSDimitry Andric unsigned getNumShapeArgs(CallInst *CI) const { 19335ffd83dbSDimitry Andric if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { 19345ffd83dbSDimitry Andric switch (II->getIntrinsicID()) { 19355ffd83dbSDimitry Andric case Intrinsic::matrix_multiply: 19365ffd83dbSDimitry Andric return 3; 19375ffd83dbSDimitry Andric case Intrinsic::matrix_transpose: 19385ffd83dbSDimitry Andric return 2; 19395ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 19405ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 19415ffd83dbSDimitry Andric return 3; 19425ffd83dbSDimitry Andric default: 19435ffd83dbSDimitry Andric return 0; 19445ffd83dbSDimitry Andric } 19455ffd83dbSDimitry Andric } 19465ffd83dbSDimitry Andric return 0; 19475ffd83dbSDimitry Andric } 19485ffd83dbSDimitry Andric 19495ffd83dbSDimitry Andric /// Special printing for values: for pointers, we print if they refer to an 19505ffd83dbSDimitry Andric /// (function) external address or a stack address, for other values we 19515ffd83dbSDimitry Andric /// either print the constant or "scalar"/"matrix" for other values. 
19525ffd83dbSDimitry Andric void write(Value *V) { 19535ffd83dbSDimitry Andric V = getUnderlyingObjectThroughLoads(V); 19545ffd83dbSDimitry Andric if (V->getType()->isPointerTy()) { 19555ffd83dbSDimitry Andric if (isa<AllocaInst>(V)) { 19565ffd83dbSDimitry Andric Stream << "stack addr"; 19575ffd83dbSDimitry Andric LineLength += StringRef("stack addr").size(); 19585ffd83dbSDimitry Andric } else { 19595ffd83dbSDimitry Andric Stream << "addr"; 19605ffd83dbSDimitry Andric LineLength += StringRef("addr").size(); 19615ffd83dbSDimitry Andric } 19625ffd83dbSDimitry Andric if (!V->getName().empty()) { 19635ffd83dbSDimitry Andric Stream << " %" << V->getName() << ""; 19645ffd83dbSDimitry Andric LineLength += V->getName().size() + 2; 19655ffd83dbSDimitry Andric } 19665ffd83dbSDimitry Andric return; 19675ffd83dbSDimitry Andric } 19685ffd83dbSDimitry Andric 19695ffd83dbSDimitry Andric std::string Tmp; 19705ffd83dbSDimitry Andric raw_string_ostream TmpStream(Tmp); 19715ffd83dbSDimitry Andric 19725ffd83dbSDimitry Andric if (auto *CI = dyn_cast<ConstantInt>(V)) 19735ffd83dbSDimitry Andric TmpStream << CI->getValue(); 19745ffd83dbSDimitry Andric else if (isa<Constant>(V)) 19755ffd83dbSDimitry Andric TmpStream << "constant"; 19765ffd83dbSDimitry Andric else { 19775ffd83dbSDimitry Andric if (isMatrix(V)) 19785ffd83dbSDimitry Andric TmpStream << "matrix"; 19795ffd83dbSDimitry Andric else 19805ffd83dbSDimitry Andric TmpStream << "scalar"; 19815ffd83dbSDimitry Andric } 19825ffd83dbSDimitry Andric TmpStream.flush(); 19835ffd83dbSDimitry Andric Tmp = std::string(StringRef(Tmp).trim()); 19845ffd83dbSDimitry Andric LineLength += Tmp.size(); 19855ffd83dbSDimitry Andric Stream << Tmp; 19865ffd83dbSDimitry Andric } 19875ffd83dbSDimitry Andric 19885ffd83dbSDimitry Andric /// Linearize expression \p Expr starting at an indentation of \p Indent. 
19895ffd83dbSDimitry Andric /// Expressions that are re-used multiple times are prefixed with (reused) 19905ffd83dbSDimitry Andric /// at the re-used root instruction. 19915ffd83dbSDimitry Andric void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused, 19925ffd83dbSDimitry Andric bool ParentShared) { 19935ffd83dbSDimitry Andric auto *I = cast<Instruction>(Expr); 19945ffd83dbSDimitry Andric maybeIndent(Indent); 19955ffd83dbSDimitry Andric SmallVector<Value *, 8> Ops; 19965ffd83dbSDimitry Andric 19975ffd83dbSDimitry Andric // Is Expr shared with other expression leaves? 19985ffd83dbSDimitry Andric bool ExprShared = false; 19995ffd83dbSDimitry Andric 20005ffd83dbSDimitry Andric // Deal with shared subtrees. Mark them as shared, if required. 20015ffd83dbSDimitry Andric if (!ParentShared) { 20025ffd83dbSDimitry Andric auto SI = Shared.find(Expr); 20035ffd83dbSDimitry Andric assert(SI != Shared.end() && SI->second.count(Leaf)); 20045ffd83dbSDimitry Andric 20055ffd83dbSDimitry Andric for (Value *S : SI->second) { 20065ffd83dbSDimitry Andric if (S == Leaf) 20075ffd83dbSDimitry Andric continue; 20085ffd83dbSDimitry Andric DebugLoc DL = cast<Instruction>(S)->getDebugLoc(); 20095ffd83dbSDimitry Andric write("shared with remark at line " + std::to_string(DL.getLine()) + 20105ffd83dbSDimitry Andric " column " + std::to_string(DL.getCol()) + " ("); 20115ffd83dbSDimitry Andric } 20125ffd83dbSDimitry Andric ExprShared = SI->second.size() > 1; 20135ffd83dbSDimitry Andric } 20145ffd83dbSDimitry Andric 20155ffd83dbSDimitry Andric bool Reused = !ReusedExprs.insert(Expr).second; 20165ffd83dbSDimitry Andric if (Reused && !ParentReused) 20175ffd83dbSDimitry Andric write("(reused) "); 20185ffd83dbSDimitry Andric 20195ffd83dbSDimitry Andric if (auto *CI = dyn_cast<CallInst>(I)) { 20205ffd83dbSDimitry Andric writeFnName(CI); 20215ffd83dbSDimitry Andric 20225ffd83dbSDimitry Andric Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI)); 20235ffd83dbSDimitry Andric } else 
if (isa<BitCastInst>(Expr)) { 20245ffd83dbSDimitry Andric // Special case bitcasts, which are used to materialize matrixes from 20255ffd83dbSDimitry Andric // non-matrix ops. 20265ffd83dbSDimitry Andric write("matrix"); 20275ffd83dbSDimitry Andric return; 20285ffd83dbSDimitry Andric } else { 20295ffd83dbSDimitry Andric Ops.append(I->value_op_begin(), I->value_op_end()); 20305ffd83dbSDimitry Andric write(std::string(I->getOpcodeName())); 20315ffd83dbSDimitry Andric } 20325ffd83dbSDimitry Andric 20335ffd83dbSDimitry Andric write(std::string("(")); 20345ffd83dbSDimitry Andric 20355ffd83dbSDimitry Andric unsigned NumOpsToBreak = 1; 20365ffd83dbSDimitry Andric if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>())) 20375ffd83dbSDimitry Andric NumOpsToBreak = 2; 20385ffd83dbSDimitry Andric 20395ffd83dbSDimitry Andric for (Value *Op : Ops) { 20405ffd83dbSDimitry Andric if (Ops.size() > NumOpsToBreak) 20415ffd83dbSDimitry Andric lineBreak(); 20425ffd83dbSDimitry Andric 20435ffd83dbSDimitry Andric maybeIndent(Indent + 1); 20445ffd83dbSDimitry Andric if (isMatrix(Op)) 20455ffd83dbSDimitry Andric linearizeExpr(Op, Indent + 1, Reused, ExprShared); 20465ffd83dbSDimitry Andric else 20475ffd83dbSDimitry Andric write(Op); 20485ffd83dbSDimitry Andric if (Op != Ops.back()) 20495ffd83dbSDimitry Andric write(", "); 20505ffd83dbSDimitry Andric } 20515ffd83dbSDimitry Andric 20525ffd83dbSDimitry Andric write(")"); 20535ffd83dbSDimitry Andric } 20545ffd83dbSDimitry Andric 20555ffd83dbSDimitry Andric const std::string &getResult() { 20565ffd83dbSDimitry Andric Stream.flush(); 20575ffd83dbSDimitry Andric return Str; 20585ffd83dbSDimitry Andric } 20595ffd83dbSDimitry Andric }; 20605ffd83dbSDimitry Andric 20615ffd83dbSDimitry Andric /// Generate remarks for matrix operations in a function. To generate remarks 20625ffd83dbSDimitry Andric /// for matrix expressions, the following approach is used: 20635ffd83dbSDimitry Andric /// 1. 
Use the inlined-at debug information to group matrix operations to the 20645ffd83dbSDimitry Andric /// DISubprograms they are contained in. 20655ffd83dbSDimitry Andric /// 2. Collect leaves of matrix expressions (done in 20665ffd83dbSDimitry Andric /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression 20675ffd83dbSDimitry Andric // mapping. Leaves are lowered matrix instructions without other matrix 20685ffd83dbSDimitry Andric // users (like stores) in the current subprogram. 20695ffd83dbSDimitry Andric /// 3. For each leaf, create a remark containing a linearizied version of the 20705ffd83dbSDimitry Andric /// matrix expression. The expression is linearized by a recursive 20715ffd83dbSDimitry Andric /// bottom-up traversal of the matrix operands, starting at a leaf. Note 20725ffd83dbSDimitry Andric /// that multiple leaves can share sub-expressions. Shared subexpressions 20735ffd83dbSDimitry Andric /// are explicitly marked as shared(). 20745ffd83dbSDimitry Andric struct RemarkGenerator { 20755ffd83dbSDimitry Andric const MapVector<Value *, MatrixTy> &Inst2Matrix; 20765ffd83dbSDimitry Andric OptimizationRemarkEmitter &ORE; 20775ffd83dbSDimitry Andric Function &Func; 20785ffd83dbSDimitry Andric const DataLayout &DL; 20795ffd83dbSDimitry Andric 20805ffd83dbSDimitry Andric RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix, 20815ffd83dbSDimitry Andric OptimizationRemarkEmitter &ORE, Function &Func) 20825ffd83dbSDimitry Andric : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func), 20835ffd83dbSDimitry Andric DL(Func.getParent()->getDataLayout()) {} 20845ffd83dbSDimitry Andric 20855ffd83dbSDimitry Andric /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are 20865ffd83dbSDimitry Andric /// instructions in Inst2Matrix returning void or without any users in 20875ffd83dbSDimitry Andric /// \p ExprsInSubprogram. Currently that should only include stores. 
20885ffd83dbSDimitry Andric SmallVector<Value *, 4> 20895ffd83dbSDimitry Andric getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) { 20905ffd83dbSDimitry Andric SmallVector<Value *, 4> Leaves; 20915ffd83dbSDimitry Andric for (auto *Expr : ExprsInSubprogram) 20925ffd83dbSDimitry Andric if (Expr->getType()->isVoidTy() || 20935ffd83dbSDimitry Andric !any_of(Expr->users(), [&ExprsInSubprogram](User *U) { 20945ffd83dbSDimitry Andric return ExprsInSubprogram.count(U); 20955ffd83dbSDimitry Andric })) 20965ffd83dbSDimitry Andric Leaves.push_back(Expr); 20975ffd83dbSDimitry Andric return Leaves; 20985ffd83dbSDimitry Andric } 20995ffd83dbSDimitry Andric 21005ffd83dbSDimitry Andric /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf 21015ffd83dbSDimitry Andric /// to all visited expressions in \p Shared. Limit the matrix operations to 21025ffd83dbSDimitry Andric /// the ones in \p ExprsInSubprogram. 21035ffd83dbSDimitry Andric void collectSharedInfo(Value *Leaf, Value *V, 21045ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 21055ffd83dbSDimitry Andric DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) { 21065ffd83dbSDimitry Andric 21075ffd83dbSDimitry Andric if (!ExprsInSubprogram.count(V)) 21085ffd83dbSDimitry Andric return; 21095ffd83dbSDimitry Andric 21105ffd83dbSDimitry Andric auto I = Shared.insert({V, {}}); 21115ffd83dbSDimitry Andric I.first->second.insert(Leaf); 21125ffd83dbSDimitry Andric 21135ffd83dbSDimitry Andric for (Value *Op : cast<Instruction>(V)->operand_values()) 21145ffd83dbSDimitry Andric collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared); 21155ffd83dbSDimitry Andric } 21165ffd83dbSDimitry Andric 21175ffd83dbSDimitry Andric /// Calculate the number of exclusive and shared op counts for expression 21185ffd83dbSDimitry Andric /// starting at \p V. Expressions used multiple times are counted once. 
21195ffd83dbSDimitry Andric /// Limit the matrix operations to the ones in \p ExprsInSubprogram. 21205ffd83dbSDimitry Andric std::pair<OpInfoTy, OpInfoTy> 21215ffd83dbSDimitry Andric sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs, 21225ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 21235ffd83dbSDimitry Andric DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const { 21245ffd83dbSDimitry Andric if (!ExprsInSubprogram.count(Root)) 21255ffd83dbSDimitry Andric return {}; 21265ffd83dbSDimitry Andric 21275ffd83dbSDimitry Andric // Already counted this expression. Stop. 21285ffd83dbSDimitry Andric if (!ReusedExprs.insert(Root).second) 21295ffd83dbSDimitry Andric return {}; 21305ffd83dbSDimitry Andric 21315ffd83dbSDimitry Andric OpInfoTy SharedCount; 21325ffd83dbSDimitry Andric OpInfoTy Count; 21335ffd83dbSDimitry Andric 21345ffd83dbSDimitry Andric auto I = Shared.find(Root); 21355ffd83dbSDimitry Andric auto CM = Inst2Matrix.find(Root); 21365ffd83dbSDimitry Andric if (I->second.size() == 1) 21375ffd83dbSDimitry Andric Count = CM->second.getOpInfo(); 21385ffd83dbSDimitry Andric else 21395ffd83dbSDimitry Andric SharedCount = CM->second.getOpInfo(); 21405ffd83dbSDimitry Andric 21415ffd83dbSDimitry Andric for (Value *Op : cast<Instruction>(Root)->operand_values()) { 21425ffd83dbSDimitry Andric auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared); 21435ffd83dbSDimitry Andric Count += C.first; 21445ffd83dbSDimitry Andric SharedCount += C.second; 21455ffd83dbSDimitry Andric } 21465ffd83dbSDimitry Andric return {Count, SharedCount}; 21475ffd83dbSDimitry Andric } 21485ffd83dbSDimitry Andric 21495ffd83dbSDimitry Andric void emitRemarks() { 21505ffd83dbSDimitry Andric if (!ORE.allowExtraAnalysis(DEBUG_TYPE)) 21515ffd83dbSDimitry Andric return; 21525ffd83dbSDimitry Andric 21535ffd83dbSDimitry Andric // Map matrix operations to their containting subprograms, by traversing 21545ffd83dbSDimitry Andric // the inlinedAt chain. 
If the function does not have a DISubprogram, we 21555ffd83dbSDimitry Andric // only map them to the containing function. 21565ffd83dbSDimitry Andric MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs; 21575ffd83dbSDimitry Andric for (auto &KV : Inst2Matrix) { 21585ffd83dbSDimitry Andric if (Func.getSubprogram()) { 21595ffd83dbSDimitry Andric auto *I = cast<Instruction>(KV.first); 21605ffd83dbSDimitry Andric DILocation *Context = I->getDebugLoc(); 21615ffd83dbSDimitry Andric while (Context) { 21625ffd83dbSDimitry Andric auto I = 21635ffd83dbSDimitry Andric Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}}); 21645ffd83dbSDimitry Andric I.first->second.push_back(KV.first); 21655ffd83dbSDimitry Andric Context = DebugLoc(Context).getInlinedAt(); 21665ffd83dbSDimitry Andric } 21675ffd83dbSDimitry Andric } else { 21685ffd83dbSDimitry Andric auto I = Subprog2Exprs.insert({nullptr, {}}); 21695ffd83dbSDimitry Andric I.first->second.push_back(KV.first); 21705ffd83dbSDimitry Andric } 21715ffd83dbSDimitry Andric } 21725ffd83dbSDimitry Andric for (auto &KV : Subprog2Exprs) { 21735ffd83dbSDimitry Andric SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(), 21745ffd83dbSDimitry Andric KV.second.end()); 21755ffd83dbSDimitry Andric auto Leaves = getExpressionLeaves(ExprsInSubprogram); 21765ffd83dbSDimitry Andric 21775ffd83dbSDimitry Andric DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared; 21785ffd83dbSDimitry Andric for (Value *Leaf : Leaves) 21795ffd83dbSDimitry Andric collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared); 21805ffd83dbSDimitry Andric 21815ffd83dbSDimitry Andric // Generate remarks for each leaf. 
21825ffd83dbSDimitry Andric for (auto *L : Leaves) { 21835ffd83dbSDimitry Andric 21845ffd83dbSDimitry Andric DebugLoc Loc = cast<Instruction>(L)->getDebugLoc(); 21855ffd83dbSDimitry Andric DILocation *Context = cast<Instruction>(L)->getDebugLoc(); 21865ffd83dbSDimitry Andric while (Context) { 21875ffd83dbSDimitry Andric if (getSubprogram(Context->getScope()) == KV.first) { 21885ffd83dbSDimitry Andric Loc = Context; 21895ffd83dbSDimitry Andric break; 21905ffd83dbSDimitry Andric } 21915ffd83dbSDimitry Andric Context = DebugLoc(Context).getInlinedAt(); 21925ffd83dbSDimitry Andric } 21935ffd83dbSDimitry Andric 21945ffd83dbSDimitry Andric SmallPtrSet<Value *, 8> ReusedExprs; 21955ffd83dbSDimitry Andric OpInfoTy Counts, SharedCounts; 21965ffd83dbSDimitry Andric std::tie(Counts, SharedCounts) = 21975ffd83dbSDimitry Andric sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared); 21985ffd83dbSDimitry Andric 21995ffd83dbSDimitry Andric OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc, 22005ffd83dbSDimitry Andric cast<Instruction>(L)->getParent()); 22015ffd83dbSDimitry Andric 22025ffd83dbSDimitry Andric Rem << "Lowered with "; 22035ffd83dbSDimitry Andric Rem << ore::NV("NumStores", Counts.NumStores) << " stores, " 22045ffd83dbSDimitry Andric << ore::NV("NumLoads", Counts.NumLoads) << " loads, " 22055ffd83dbSDimitry Andric << ore::NV("NumComputeOps", Counts.NumComputeOps) 2206fe6060f1SDimitry Andric << " compute ops, " 2207fe6060f1SDimitry Andric << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes) 2208fe6060f1SDimitry Andric << " exposed transposes"; 22095ffd83dbSDimitry Andric 22105ffd83dbSDimitry Andric if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 || 22115ffd83dbSDimitry Andric SharedCounts.NumComputeOps > 0) { 22125ffd83dbSDimitry Andric Rem << ",\nadditionally " 22135ffd83dbSDimitry Andric << ore::NV("NumStores", SharedCounts.NumStores) << " stores, " 22145ffd83dbSDimitry Andric << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, " 
22155ffd83dbSDimitry Andric << ore::NV("NumFPOps", SharedCounts.NumComputeOps) 22165ffd83dbSDimitry Andric << " compute ops" 22175ffd83dbSDimitry Andric << " are shared with other expressions"; 22185ffd83dbSDimitry Andric } 22195ffd83dbSDimitry Andric 22205ffd83dbSDimitry Andric Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL)); 22215ffd83dbSDimitry Andric ORE.emit(Rem); 22225ffd83dbSDimitry Andric } 22235ffd83dbSDimitry Andric } 22245ffd83dbSDimitry Andric } 22255ffd83dbSDimitry Andric 22265ffd83dbSDimitry Andric std::string 22275ffd83dbSDimitry Andric linearize(Value *L, 22285ffd83dbSDimitry Andric const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 22295ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 22305ffd83dbSDimitry Andric const DataLayout &DL) { 22315ffd83dbSDimitry Andric ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L); 22325ffd83dbSDimitry Andric Lin.linearizeExpr(L, 0, false, false); 22335ffd83dbSDimitry Andric return Lin.getResult(); 22345ffd83dbSDimitry Andric } 22355ffd83dbSDimitry Andric }; 2236480093f4SDimitry Andric }; 2237480093f4SDimitry Andric } // namespace 2238480093f4SDimitry Andric 2239480093f4SDimitry Andric PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, 2240480093f4SDimitry Andric FunctionAnalysisManager &AM) { 2241480093f4SDimitry Andric auto &TTI = AM.getResult<TargetIRAnalysis>(F); 2242e8d8bef9SDimitry Andric OptimizationRemarkEmitter *ORE = nullptr; 2243e8d8bef9SDimitry Andric AAResults *AA = nullptr; 2244e8d8bef9SDimitry Andric DominatorTree *DT = nullptr; 2245e8d8bef9SDimitry Andric LoopInfo *LI = nullptr; 2246e8d8bef9SDimitry Andric 2247e8d8bef9SDimitry Andric if (!Minimal) { 2248e8d8bef9SDimitry Andric ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F); 2249e8d8bef9SDimitry Andric AA = &AM.getResult<AAManager>(F); 2250e8d8bef9SDimitry Andric DT = &AM.getResult<DominatorTreeAnalysis>(F); 2251e8d8bef9SDimitry Andric LI = 
&AM.getResult<LoopAnalysis>(F); 2252e8d8bef9SDimitry Andric } 22535ffd83dbSDimitry Andric 22545ffd83dbSDimitry Andric LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE); 2255480093f4SDimitry Andric if (LMT.Visit()) { 2256480093f4SDimitry Andric PreservedAnalyses PA; 2257e8d8bef9SDimitry Andric if (!Minimal) { 2258e8d8bef9SDimitry Andric PA.preserve<LoopAnalysis>(); 2259e8d8bef9SDimitry Andric PA.preserve<DominatorTreeAnalysis>(); 2260e8d8bef9SDimitry Andric } 2261480093f4SDimitry Andric return PA; 2262480093f4SDimitry Andric } 2263480093f4SDimitry Andric return PreservedAnalyses::all(); 2264480093f4SDimitry Andric } 2265480093f4SDimitry Andric 2266*349cc55cSDimitry Andric void LowerMatrixIntrinsicsPass::printPipeline( 2267*349cc55cSDimitry Andric raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { 2268*349cc55cSDimitry Andric static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline( 2269*349cc55cSDimitry Andric OS, MapClassName2PassName); 2270*349cc55cSDimitry Andric OS << "<"; 2271*349cc55cSDimitry Andric if (Minimal) 2272*349cc55cSDimitry Andric OS << "minimal"; 2273*349cc55cSDimitry Andric OS << ">"; 2274*349cc55cSDimitry Andric } 2275*349cc55cSDimitry Andric 2276480093f4SDimitry Andric namespace { 2277480093f4SDimitry Andric 2278480093f4SDimitry Andric class LowerMatrixIntrinsicsLegacyPass : public FunctionPass { 2279480093f4SDimitry Andric public: 2280480093f4SDimitry Andric static char ID; 2281480093f4SDimitry Andric 2282480093f4SDimitry Andric LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) { 2283480093f4SDimitry Andric initializeLowerMatrixIntrinsicsLegacyPassPass( 2284480093f4SDimitry Andric *PassRegistry::getPassRegistry()); 2285480093f4SDimitry Andric } 2286480093f4SDimitry Andric 2287480093f4SDimitry Andric bool runOnFunction(Function &F) override { 22885ffd83dbSDimitry Andric auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 22895ffd83dbSDimitry Andric auto &ORE = 
getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); 22905ffd83dbSDimitry Andric auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); 22915ffd83dbSDimitry Andric auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 22925ffd83dbSDimitry Andric auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 2293e8d8bef9SDimitry Andric LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE); 2294480093f4SDimitry Andric bool C = LMT.Visit(); 2295480093f4SDimitry Andric return C; 2296480093f4SDimitry Andric } 2297480093f4SDimitry Andric 2298480093f4SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 2299480093f4SDimitry Andric AU.addRequired<TargetTransformInfoWrapperPass>(); 23005ffd83dbSDimitry Andric AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); 23015ffd83dbSDimitry Andric AU.addRequired<AAResultsWrapperPass>(); 23025ffd83dbSDimitry Andric AU.addRequired<DominatorTreeWrapperPass>(); 23035ffd83dbSDimitry Andric AU.addPreserved<DominatorTreeWrapperPass>(); 23045ffd83dbSDimitry Andric AU.addRequired<LoopInfoWrapperPass>(); 23055ffd83dbSDimitry Andric AU.addPreserved<LoopInfoWrapperPass>(); 2306480093f4SDimitry Andric } 2307480093f4SDimitry Andric }; 2308480093f4SDimitry Andric } // namespace 2309480093f4SDimitry Andric 2310480093f4SDimitry Andric static const char pass_name[] = "Lower the matrix intrinsics"; 2311480093f4SDimitry Andric char LowerMatrixIntrinsicsLegacyPass::ID = 0; 2312480093f4SDimitry Andric INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, 2313480093f4SDimitry Andric false, false) 23145ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) 23155ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) 23165ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 23175ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 2318480093f4SDimitry Andric 
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, 2319480093f4SDimitry Andric false, false) 2320480093f4SDimitry Andric 2321480093f4SDimitry Andric Pass *llvm::createLowerMatrixIntrinsicsPass() { 2322480093f4SDimitry Andric return new LowerMatrixIntrinsicsLegacyPass(); 2323480093f4SDimitry Andric } 2324e8d8bef9SDimitry Andric 2325e8d8bef9SDimitry Andric namespace { 2326e8d8bef9SDimitry Andric 2327e8d8bef9SDimitry Andric /// A lightweight version of the matrix lowering pass that only requires TTI. 2328e8d8bef9SDimitry Andric /// Advanced features that require DT, AA or ORE like tiling are disabled. This 2329e8d8bef9SDimitry Andric /// is used to lower matrix intrinsics if the main lowering pass is not run, for 2330e8d8bef9SDimitry Andric /// example with -O0. 2331e8d8bef9SDimitry Andric class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass { 2332e8d8bef9SDimitry Andric public: 2333e8d8bef9SDimitry Andric static char ID; 2334e8d8bef9SDimitry Andric 2335e8d8bef9SDimitry Andric LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) { 2336e8d8bef9SDimitry Andric initializeLowerMatrixIntrinsicsMinimalLegacyPassPass( 2337e8d8bef9SDimitry Andric *PassRegistry::getPassRegistry()); 2338e8d8bef9SDimitry Andric } 2339e8d8bef9SDimitry Andric 2340e8d8bef9SDimitry Andric bool runOnFunction(Function &F) override { 2341e8d8bef9SDimitry Andric auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 2342e8d8bef9SDimitry Andric LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr); 2343e8d8bef9SDimitry Andric bool C = LMT.Visit(); 2344e8d8bef9SDimitry Andric return C; 2345e8d8bef9SDimitry Andric } 2346e8d8bef9SDimitry Andric 2347e8d8bef9SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 2348e8d8bef9SDimitry Andric AU.addRequired<TargetTransformInfoWrapperPass>(); 2349e8d8bef9SDimitry Andric AU.setPreservesCFG(); 2350e8d8bef9SDimitry Andric } 2351e8d8bef9SDimitry Andric }; 
2352e8d8bef9SDimitry Andric } // namespace 2353e8d8bef9SDimitry Andric 2354e8d8bef9SDimitry Andric static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)"; 2355e8d8bef9SDimitry Andric char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0; 2356e8d8bef9SDimitry Andric INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass, 2357e8d8bef9SDimitry Andric "lower-matrix-intrinsics-minimal", pass_name_minimal, 2358e8d8bef9SDimitry Andric false, false) 2359e8d8bef9SDimitry Andric INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass, 2360e8d8bef9SDimitry Andric "lower-matrix-intrinsics-minimal", pass_name_minimal, false, 2361e8d8bef9SDimitry Andric false) 2362e8d8bef9SDimitry Andric 2363e8d8bef9SDimitry Andric Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() { 2364e8d8bef9SDimitry Andric return new LowerMatrixIntrinsicsMinimalLegacyPass(); 2365e8d8bef9SDimitry Andric } 2366