1480093f4SDimitry Andric //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===// 2480093f4SDimitry Andric // 3480093f4SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4480093f4SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5480093f4SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6480093f4SDimitry Andric // 7480093f4SDimitry Andric //===----------------------------------------------------------------------===// 8480093f4SDimitry Andric // 9480093f4SDimitry Andric // Lower matrix intrinsics to vector operations. 10480093f4SDimitry Andric // 11480093f4SDimitry Andric // TODO: 125ffd83dbSDimitry Andric // * Improve fusion: 135ffd83dbSDimitry Andric // * Support more cases, e.g. multiply-add, multiply-sub, operands/results 145ffd83dbSDimitry Andric // transposed. 155ffd83dbSDimitry Andric // * Improve cost-modeling, e.g. choose different number of rows/columns 165ffd83dbSDimitry Andric // for tiles, consider cost of copies on alias. 
17480093f4SDimitry Andric // 18480093f4SDimitry Andric //===----------------------------------------------------------------------===// 19480093f4SDimitry Andric 20480093f4SDimitry Andric #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" 21480093f4SDimitry Andric #include "llvm/ADT/PostOrderIterator.h" 22480093f4SDimitry Andric #include "llvm/ADT/SmallVector.h" 235ffd83dbSDimitry Andric #include "llvm/Analysis/AliasAnalysis.h" 245ffd83dbSDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h" 25*81ad6265SDimitry Andric #include "llvm/Analysis/LoopInfo.h" 265ffd83dbSDimitry Andric #include "llvm/Analysis/OptimizationRemarkEmitter.h" 27480093f4SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h" 285ffd83dbSDimitry Andric #include "llvm/Analysis/ValueTracking.h" 29480093f4SDimitry Andric #include "llvm/Analysis/VectorUtils.h" 30480093f4SDimitry Andric #include "llvm/IR/CFG.h" 31480093f4SDimitry Andric #include "llvm/IR/DataLayout.h" 325ffd83dbSDimitry Andric #include "llvm/IR/DebugInfoMetadata.h" 33480093f4SDimitry Andric #include "llvm/IR/Function.h" 34480093f4SDimitry Andric #include "llvm/IR/IRBuilder.h" 35480093f4SDimitry Andric #include "llvm/IR/Instructions.h" 36480093f4SDimitry Andric #include "llvm/IR/IntrinsicInst.h" 37fe6060f1SDimitry Andric #include "llvm/IR/MatrixBuilder.h" 38480093f4SDimitry Andric #include "llvm/IR/PatternMatch.h" 39480093f4SDimitry Andric #include "llvm/InitializePasses.h" 40480093f4SDimitry Andric #include "llvm/Pass.h" 415ffd83dbSDimitry Andric #include "llvm/Support/Alignment.h" 425ffd83dbSDimitry Andric #include "llvm/Support/CommandLine.h" 43480093f4SDimitry Andric #include "llvm/Support/Debug.h" 44480093f4SDimitry Andric #include "llvm/Transforms/Scalar.h" 455ffd83dbSDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h" 46e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h" 47e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/MatrixUtils.h" 48480093f4SDimitry Andric 
49480093f4SDimitry Andric using namespace llvm; 50480093f4SDimitry Andric using namespace PatternMatch; 51480093f4SDimitry Andric 52480093f4SDimitry Andric #define DEBUG_TYPE "lower-matrix-intrinsics" 53480093f4SDimitry Andric 545ffd83dbSDimitry Andric static cl::opt<bool> 555ffd83dbSDimitry Andric FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, 565ffd83dbSDimitry Andric cl::desc("Enable/disable fusing matrix instructions.")); 575ffd83dbSDimitry Andric // TODO: Allow and use non-square tiles. 585ffd83dbSDimitry Andric static cl::opt<unsigned> TileSize( 595ffd83dbSDimitry Andric "fuse-matrix-tile-size", cl::init(4), cl::Hidden, 605ffd83dbSDimitry Andric cl::desc( 615ffd83dbSDimitry Andric "Tile size for matrix instruction fusion using square-shaped tiles.")); 62e8d8bef9SDimitry Andric static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false), 63e8d8bef9SDimitry Andric cl::Hidden, 64e8d8bef9SDimitry Andric cl::desc("Generate loop nest for tiling.")); 655ffd83dbSDimitry Andric static cl::opt<bool> ForceFusion( 665ffd83dbSDimitry Andric "force-fuse-matrix", cl::init(false), cl::Hidden, 675ffd83dbSDimitry Andric cl::desc("Force matrix instruction fusion even if not profitable.")); 68480093f4SDimitry Andric static cl::opt<bool> AllowContractEnabled( 69480093f4SDimitry Andric "matrix-allow-contract", cl::init(false), cl::Hidden, 70480093f4SDimitry Andric cl::desc("Allow the use of FMAs if available and profitable. 
This may " 71480093f4SDimitry Andric "result in different results, due to less rounding error.")); 72480093f4SDimitry Andric 735ffd83dbSDimitry Andric enum class MatrixLayoutTy { ColumnMajor, RowMajor }; 745ffd83dbSDimitry Andric 755ffd83dbSDimitry Andric static cl::opt<MatrixLayoutTy> MatrixLayout( 765ffd83dbSDimitry Andric "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), 775ffd83dbSDimitry Andric cl::desc("Sets the default matrix layout"), 785ffd83dbSDimitry Andric cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", 795ffd83dbSDimitry Andric "Use column-major layout"), 805ffd83dbSDimitry Andric clEnumValN(MatrixLayoutTy::RowMajor, "row-major", 815ffd83dbSDimitry Andric "Use row-major layout"))); 825ffd83dbSDimitry Andric 835ffd83dbSDimitry Andric /// Helper function to either return Scope, if it is a subprogram or the 845ffd83dbSDimitry Andric /// attached subprogram for a local scope. 855ffd83dbSDimitry Andric static DISubprogram *getSubprogram(DIScope *Scope) { 865ffd83dbSDimitry Andric if (auto *Subprogram = dyn_cast<DISubprogram>(Scope)) 875ffd83dbSDimitry Andric return Subprogram; 885ffd83dbSDimitry Andric return cast<DILocalScope>(Scope)->getSubprogram(); 895ffd83dbSDimitry Andric } 905ffd83dbSDimitry Andric 91480093f4SDimitry Andric namespace { 92480093f4SDimitry Andric 935ffd83dbSDimitry Andric // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute 945ffd83dbSDimitry Andric // the start address of vector \p VecIdx with type (\p EltType x \p NumElements) 955ffd83dbSDimitry Andric // assuming \p Stride elements between start two consecutive vectors. 965ffd83dbSDimitry Andric // \p Stride must be >= \p NumElements. 975ffd83dbSDimitry Andric // For column-major matrixes, the function computes the address of a column 985ffd83dbSDimitry Andric // vectors and \p NumElements must be set to the number of elements in a column 995ffd83dbSDimitry Andric // (= number of rows of the matrix). 
For row-major matrixes, the function 1005ffd83dbSDimitry Andric // computes the address of a row vector and \p NumElements must be set to the 1015ffd83dbSDimitry Andric // number of elements in a row (= number of columns of the matrix). 102480093f4SDimitry Andric // 1035ffd83dbSDimitry Andric // Consider a 4x4 matrix in column-major layout like below 104480093f4SDimitry Andric // 105480093f4SDimitry Andric // 0 1 2 3 106480093f4SDimitry Andric // 0 v_0_0 v_0_1 v_0_2 v_0_3 107480093f4SDimitry Andric // 1 v_1_0 v_1_1 v_1_2 v_1_3 108480093f4SDimitry Andric // 2 v_2_0 v_2_1 v_2_2 v_2_3 109480093f4SDimitry Andric // 3 v_3_0 v_3_1 v_3_2 v_3_3 110480093f4SDimitry Andric 111480093f4SDimitry Andric // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1, 112480093f4SDimitry Andric // we need a pointer to the first element of the submatrix as base pointer. 1135ffd83dbSDimitry Andric // Then we can use computeVectorAddr to compute the addresses for the columns 114480093f4SDimitry Andric // of the sub-matrix. 115480093f4SDimitry Andric // 1165ffd83dbSDimitry Andric // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..) 117480093f4SDimitry Andric // -> just returns Base 1185ffd83dbSDimitry Andric // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..) 119480093f4SDimitry Andric // -> returns Base + (1 * 4) 1205ffd83dbSDimitry Andric // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..) 121480093f4SDimitry Andric // -> returns Base + (2 * 4) 122480093f4SDimitry Andric // 123480093f4SDimitry Andric // The graphic below illustrates the number of elements in a column (marked 124480093f4SDimitry Andric // with |) and the number of skipped elements (marked with }). 
125480093f4SDimitry Andric // 126480093f4SDimitry Andric // v_0_0 v_0_1 {v_0_2 {v_0_3 127480093f4SDimitry Andric // Base Col 1 Col 2 128480093f4SDimitry Andric // | | | 129480093f4SDimitry Andric // v_1_0 |v_1_1 |v_1_2 |v_1_3 130480093f4SDimitry Andric // v_2_0 |v_2_1 |v_2_2 |v_2_3 131480093f4SDimitry Andric // v_3_0 {v_3_1 {v_3_2 v_3_3 132480093f4SDimitry Andric // 1335ffd83dbSDimitry Andric Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, 1345ffd83dbSDimitry Andric unsigned NumElements, Type *EltType, 135480093f4SDimitry Andric IRBuilder<> &Builder) { 136480093f4SDimitry Andric 137480093f4SDimitry Andric assert((!isa<ConstantInt>(Stride) || 1385ffd83dbSDimitry Andric cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && 1395ffd83dbSDimitry Andric "Stride must be >= the number of elements in the result vector."); 140480093f4SDimitry Andric unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 141480093f4SDimitry Andric 1425ffd83dbSDimitry Andric // Compute the start of the vector with index VecIdx as VecIdx * Stride. 1435ffd83dbSDimitry Andric Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); 144480093f4SDimitry Andric 1455ffd83dbSDimitry Andric // Get pointer to the start of the selected vector. Skip GEP creation, 1465ffd83dbSDimitry Andric // if we select vector 0. 1475ffd83dbSDimitry Andric if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) 1485ffd83dbSDimitry Andric VecStart = BasePtr; 149480093f4SDimitry Andric else 1505ffd83dbSDimitry Andric VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); 151480093f4SDimitry Andric 1525ffd83dbSDimitry Andric // Cast elementwise vector start pointer to a pointer to a vector 1535ffd83dbSDimitry Andric // (EltType x NumElements)*. 
1545ffd83dbSDimitry Andric auto *VecType = FixedVectorType::get(EltType, NumElements); 1555ffd83dbSDimitry Andric Type *VecPtrType = PointerType::get(VecType, AS); 1565ffd83dbSDimitry Andric return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast"); 157480093f4SDimitry Andric } 158480093f4SDimitry Andric 159480093f4SDimitry Andric /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 160480093f4SDimitry Andric /// 161480093f4SDimitry Andric /// Currently, the lowering for each matrix intrinsic is done as follows: 162480093f4SDimitry Andric /// 1. Propagate the shape information from intrinsics to connected 163480093f4SDimitry Andric /// instructions. 1645ffd83dbSDimitry Andric /// 2. Lower instructions with shape information (assuming column-major layout). 1655ffd83dbSDimitry Andric /// The lowering works similarly using row-major layout. 166480093f4SDimitry Andric /// 2.1. Get column vectors for each argument. If we already lowered the 167480093f4SDimitry Andric /// definition of an argument, use the produced column vectors directly. 168480093f4SDimitry Andric /// If not, split the operand vector containing an embedded matrix into 169480093f4SDimitry Andric /// a set of column vectors, 1705ffd83dbSDimitry Andric /// 2.2. Lower the instruction in terms of column major operations, which 1715ffd83dbSDimitry Andric /// yields a set of column vectors containing result matrix. Note that we 1725ffd83dbSDimitry Andric /// lower all instructions that have shape information. Besides the 1735ffd83dbSDimitry Andric /// intrinsics, this includes stores for example. 174480093f4SDimitry Andric /// 2.3. Update uses of the lowered instruction. If we have shape information 175480093f4SDimitry Andric /// for a user, there is nothing to do, as we will look up the result 176480093f4SDimitry Andric /// column matrix when lowering the user. For other uses, we embed the 177480093f4SDimitry Andric /// result matrix in a flat vector and update the use. 
178480093f4SDimitry Andric /// 2.4. Cache the result column matrix for the instruction we lowered 179480093f4SDimitry Andric /// 3. After we lowered all instructions in a function, remove the now 180480093f4SDimitry Andric /// obsolete instructions. 181480093f4SDimitry Andric /// 182480093f4SDimitry Andric class LowerMatrixIntrinsics { 183480093f4SDimitry Andric Function &Func; 184480093f4SDimitry Andric const DataLayout &DL; 185480093f4SDimitry Andric const TargetTransformInfo &TTI; 186e8d8bef9SDimitry Andric AliasAnalysis *AA; 187e8d8bef9SDimitry Andric DominatorTree *DT; 188e8d8bef9SDimitry Andric LoopInfo *LI; 189e8d8bef9SDimitry Andric OptimizationRemarkEmitter *ORE; 190480093f4SDimitry Andric 1915ffd83dbSDimitry Andric /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 1925ffd83dbSDimitry Andric struct OpInfoTy { 1935ffd83dbSDimitry Andric /// Number of stores emitted to generate this matrix. 1945ffd83dbSDimitry Andric unsigned NumStores = 0; 1955ffd83dbSDimitry Andric /// Number of loads emitted to generate this matrix. 1965ffd83dbSDimitry Andric unsigned NumLoads = 0; 1975ffd83dbSDimitry Andric /// Number of compute operations emitted to generate this matrix. 1985ffd83dbSDimitry Andric unsigned NumComputeOps = 0; 199fe6060f1SDimitry Andric /// Most of the time transposes can be fused with matrix multiplies or can 200fe6060f1SDimitry Andric /// be folded away via algebraic simplifications. This is the number of 201fe6060f1SDimitry Andric /// transposes that we failed to make "free" via such optimizations. 
202fe6060f1SDimitry Andric unsigned NumExposedTransposes = 0; 2035ffd83dbSDimitry Andric 2045ffd83dbSDimitry Andric OpInfoTy &operator+=(const OpInfoTy &RHS) { 2055ffd83dbSDimitry Andric NumStores += RHS.NumStores; 2065ffd83dbSDimitry Andric NumLoads += RHS.NumLoads; 2075ffd83dbSDimitry Andric NumComputeOps += RHS.NumComputeOps; 208fe6060f1SDimitry Andric NumExposedTransposes += RHS.NumExposedTransposes; 2095ffd83dbSDimitry Andric return *this; 2105ffd83dbSDimitry Andric } 2115ffd83dbSDimitry Andric }; 2125ffd83dbSDimitry Andric 2135ffd83dbSDimitry Andric /// Wrapper class representing a matrix as a set of vectors, either in row or 2145ffd83dbSDimitry Andric /// column major layout. All vectors must have the same vector type. 2155ffd83dbSDimitry Andric class MatrixTy { 2165ffd83dbSDimitry Andric SmallVector<Value *, 16> Vectors; 2175ffd83dbSDimitry Andric 2185ffd83dbSDimitry Andric OpInfoTy OpInfo; 2195ffd83dbSDimitry Andric 2205ffd83dbSDimitry Andric bool IsColumnMajor = true; 221480093f4SDimitry Andric 222480093f4SDimitry Andric public: 22304eeddc0SDimitry Andric MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 2245ffd83dbSDimitry Andric MatrixTy(ArrayRef<Value *> Vectors) 2255ffd83dbSDimitry Andric : Vectors(Vectors.begin(), Vectors.end()), 2265ffd83dbSDimitry Andric IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 2275ffd83dbSDimitry Andric MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 2285ffd83dbSDimitry Andric : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 229480093f4SDimitry Andric 2305ffd83dbSDimitry Andric unsigned D = isColumnMajor() ? NumColumns : NumRows; 2315ffd83dbSDimitry Andric for (unsigned J = 0; J < D; ++J) 2325ffd83dbSDimitry Andric addVector(UndefValue::get(FixedVectorType::get( 2335ffd83dbSDimitry Andric EltTy, isColumnMajor() ? 
NumRows : NumColumns))); 234480093f4SDimitry Andric } 235480093f4SDimitry Andric 2365ffd83dbSDimitry Andric Value *getVector(unsigned i) const { return Vectors[i]; } 2375ffd83dbSDimitry Andric Value *getColumn(unsigned i) const { 2385ffd83dbSDimitry Andric assert(isColumnMajor() && "only supported for column-major matrixes"); 2395ffd83dbSDimitry Andric return Vectors[i]; 2405ffd83dbSDimitry Andric } 2415ffd83dbSDimitry Andric Value *getRow(unsigned i) const { 2425ffd83dbSDimitry Andric assert(!isColumnMajor() && "only supported for row-major matrixes"); 2435ffd83dbSDimitry Andric return Vectors[i]; 2445ffd83dbSDimitry Andric } 245480093f4SDimitry Andric 2465ffd83dbSDimitry Andric void setVector(unsigned i, Value *V) { Vectors[i] = V; } 247480093f4SDimitry Andric 248e8d8bef9SDimitry Andric Type *getElementType() const { return getVectorTy()->getElementType(); } 2495ffd83dbSDimitry Andric 2505ffd83dbSDimitry Andric unsigned getNumVectors() const { 2515ffd83dbSDimitry Andric if (isColumnMajor()) 2525ffd83dbSDimitry Andric return getNumColumns(); 2535ffd83dbSDimitry Andric return getNumRows(); 2545ffd83dbSDimitry Andric } 2555ffd83dbSDimitry Andric 2565ffd83dbSDimitry Andric unsigned getNumColumns() const { 2575ffd83dbSDimitry Andric if (isColumnMajor()) 2585ffd83dbSDimitry Andric return Vectors.size(); 2595ffd83dbSDimitry Andric else { 2605ffd83dbSDimitry Andric assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 2615ffd83dbSDimitry Andric return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 2625ffd83dbSDimitry Andric } 2635ffd83dbSDimitry Andric } 2645ffd83dbSDimitry Andric unsigned getNumRows() const { 2655ffd83dbSDimitry Andric if (isColumnMajor()) { 2665ffd83dbSDimitry Andric assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 2675ffd83dbSDimitry Andric return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 2685ffd83dbSDimitry Andric } else 2695ffd83dbSDimitry Andric return Vectors.size(); 
2705ffd83dbSDimitry Andric } 2715ffd83dbSDimitry Andric 2725ffd83dbSDimitry Andric void addVector(Value *V) { Vectors.push_back(V); } 2735ffd83dbSDimitry Andric VectorType *getColumnTy() { 2745ffd83dbSDimitry Andric assert(isColumnMajor() && "only supported for column-major matrixes"); 2755ffd83dbSDimitry Andric return getVectorTy(); 2765ffd83dbSDimitry Andric } 2775ffd83dbSDimitry Andric 278e8d8bef9SDimitry Andric VectorType *getVectorTy() const { 2795ffd83dbSDimitry Andric return cast<VectorType>(Vectors[0]->getType()); 2805ffd83dbSDimitry Andric } 281480093f4SDimitry Andric 282480093f4SDimitry Andric iterator_range<SmallVector<Value *, 8>::iterator> columns() { 2835ffd83dbSDimitry Andric assert(isColumnMajor() && 2845ffd83dbSDimitry Andric "columns() only supported for column-major matrixes"); 2855ffd83dbSDimitry Andric return make_range(Vectors.begin(), Vectors.end()); 286480093f4SDimitry Andric } 287480093f4SDimitry Andric 2885ffd83dbSDimitry Andric iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 2895ffd83dbSDimitry Andric return make_range(Vectors.begin(), Vectors.end()); 2905ffd83dbSDimitry Andric } 2915ffd83dbSDimitry Andric 2925ffd83dbSDimitry Andric /// Embed the vectors of the matrix into a flat vector by concatenating 293480093f4SDimitry Andric /// them. 294480093f4SDimitry Andric Value *embedInVector(IRBuilder<> &Builder) const { 2955ffd83dbSDimitry Andric return Vectors.size() == 1 ? 
Vectors[0] 2965ffd83dbSDimitry Andric : concatenateVectors(Builder, Vectors); 2975ffd83dbSDimitry Andric } 2985ffd83dbSDimitry Andric 2995ffd83dbSDimitry Andric MatrixTy &addNumLoads(unsigned N) { 3005ffd83dbSDimitry Andric OpInfo.NumLoads += N; 3015ffd83dbSDimitry Andric return *this; 3025ffd83dbSDimitry Andric } 3035ffd83dbSDimitry Andric 3045ffd83dbSDimitry Andric void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 3055ffd83dbSDimitry Andric 3065ffd83dbSDimitry Andric MatrixTy &addNumStores(unsigned N) { 3075ffd83dbSDimitry Andric OpInfo.NumStores += N; 3085ffd83dbSDimitry Andric return *this; 3095ffd83dbSDimitry Andric } 3105ffd83dbSDimitry Andric 311fe6060f1SDimitry Andric MatrixTy &addNumExposedTransposes(unsigned N) { 312fe6060f1SDimitry Andric OpInfo.NumExposedTransposes += N; 313fe6060f1SDimitry Andric return *this; 314fe6060f1SDimitry Andric } 315fe6060f1SDimitry Andric 3165ffd83dbSDimitry Andric MatrixTy &addNumComputeOps(unsigned N) { 3175ffd83dbSDimitry Andric OpInfo.NumComputeOps += N; 3185ffd83dbSDimitry Andric return *this; 3195ffd83dbSDimitry Andric } 3205ffd83dbSDimitry Andric 3215ffd83dbSDimitry Andric unsigned getNumStores() const { return OpInfo.NumStores; } 3225ffd83dbSDimitry Andric unsigned getNumLoads() const { return OpInfo.NumLoads; } 3235ffd83dbSDimitry Andric unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 3245ffd83dbSDimitry Andric 3255ffd83dbSDimitry Andric const OpInfoTy &getOpInfo() const { return OpInfo; } 3265ffd83dbSDimitry Andric 3275ffd83dbSDimitry Andric bool isColumnMajor() const { return IsColumnMajor; } 3285ffd83dbSDimitry Andric 3295ffd83dbSDimitry Andric unsigned getStride() const { 3305ffd83dbSDimitry Andric if (isColumnMajor()) 3315ffd83dbSDimitry Andric return getNumRows(); 3325ffd83dbSDimitry Andric return getNumColumns(); 3335ffd83dbSDimitry Andric } 3345ffd83dbSDimitry Andric 3355ffd83dbSDimitry Andric /// Extract a vector of \p NumElts starting at index (\p I, \p J). 
If the 3365ffd83dbSDimitry Andric /// matrix is column-major, the result vector is extracted from a column 3375ffd83dbSDimitry Andric /// vector, otherwise from a row vector. 3385ffd83dbSDimitry Andric Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 3395ffd83dbSDimitry Andric IRBuilder<> &Builder) const { 3405ffd83dbSDimitry Andric Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 3415ffd83dbSDimitry Andric return Builder.CreateShuffleVector( 342e8d8bef9SDimitry Andric Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), 3435ffd83dbSDimitry Andric "block"); 344480093f4SDimitry Andric } 345480093f4SDimitry Andric }; 346480093f4SDimitry Andric 347480093f4SDimitry Andric struct ShapeInfo { 348480093f4SDimitry Andric unsigned NumRows; 349480093f4SDimitry Andric unsigned NumColumns; 350480093f4SDimitry Andric 3515ffd83dbSDimitry Andric bool IsColumnMajor; 3525ffd83dbSDimitry Andric 353480093f4SDimitry Andric ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 3545ffd83dbSDimitry Andric : NumRows(NumRows), NumColumns(NumColumns), 3555ffd83dbSDimitry Andric IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 356480093f4SDimitry Andric 357480093f4SDimitry Andric ShapeInfo(Value *NumRows, Value *NumColumns) 3585ffd83dbSDimitry Andric : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 3595ffd83dbSDimitry Andric cast<ConstantInt>(NumColumns)->getZExtValue()) {} 360480093f4SDimitry Andric 361480093f4SDimitry Andric bool operator==(const ShapeInfo &other) { 362480093f4SDimitry Andric return NumRows == other.NumRows && NumColumns == other.NumColumns; 363480093f4SDimitry Andric } 364480093f4SDimitry Andric bool operator!=(const ShapeInfo &other) { return !(*this == other); } 365480093f4SDimitry Andric 366480093f4SDimitry Andric /// Returns true if shape-information is defined, meaning both dimensions 367480093f4SDimitry Andric /// are != 0. 
368480093f4SDimitry Andric operator bool() const { 369480093f4SDimitry Andric assert(NumRows == 0 || NumColumns != 0); 370480093f4SDimitry Andric return NumRows != 0; 371480093f4SDimitry Andric } 3725ffd83dbSDimitry Andric 3735ffd83dbSDimitry Andric unsigned getStride() const { 3745ffd83dbSDimitry Andric if (IsColumnMajor) 3755ffd83dbSDimitry Andric return NumRows; 3765ffd83dbSDimitry Andric return NumColumns; 3775ffd83dbSDimitry Andric } 3785ffd83dbSDimitry Andric 3795ffd83dbSDimitry Andric unsigned getNumVectors() const { 3805ffd83dbSDimitry Andric if (IsColumnMajor) 3815ffd83dbSDimitry Andric return NumColumns; 3825ffd83dbSDimitry Andric return NumRows; 3835ffd83dbSDimitry Andric } 384480093f4SDimitry Andric }; 385480093f4SDimitry Andric 386480093f4SDimitry Andric /// Maps instructions to their shape information. The shape information 387480093f4SDimitry Andric /// describes the shape to be used while lowering. This matches the shape of 388480093f4SDimitry Andric /// the result value of the instruction, with the only exceptions being store 3895ffd83dbSDimitry Andric /// instructions and the matrix_column_major_store intrinsics. For those, the 390480093f4SDimitry Andric /// shape information indicates that those instructions should be lowered 391fe6060f1SDimitry Andric /// using shape information as well. A ValueMap is used so that when 392fe6060f1SDimitry Andric /// sub-passes like optimizeTransposes performs RAUW the map stays 393fe6060f1SDimitry Andric /// up-to-date. 394fe6060f1SDimitry Andric ValueMap<Value *, ShapeInfo> ShapeMap; 395480093f4SDimitry Andric 396480093f4SDimitry Andric /// List of instructions to remove. While lowering, we are not replacing all 397480093f4SDimitry Andric /// users of a lowered instruction, if shape information is available and 398480093f4SDimitry Andric /// those need to be removed after we finished lowering. 
399480093f4SDimitry Andric SmallVector<Instruction *, 16> ToRemove; 400480093f4SDimitry Andric 401480093f4SDimitry Andric /// Map from instructions to their produced column matrix. 4025ffd83dbSDimitry Andric MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 403480093f4SDimitry Andric 404fe6060f1SDimitry Andric private: 405fe6060f1SDimitry Andric static FastMathFlags getFastMathFlags(Instruction *Inst) { 406fe6060f1SDimitry Andric FastMathFlags FMF; 407fe6060f1SDimitry Andric 408fe6060f1SDimitry Andric if (isa<FPMathOperator>(*Inst)) 409fe6060f1SDimitry Andric FMF = Inst->getFastMathFlags(); 410fe6060f1SDimitry Andric 411fe6060f1SDimitry Andric FMF.setAllowContract(AllowContractEnabled || FMF.allowContract()); 412fe6060f1SDimitry Andric 413fe6060f1SDimitry Andric return FMF; 414fe6060f1SDimitry Andric } 415fe6060f1SDimitry Andric 416480093f4SDimitry Andric public: 4175ffd83dbSDimitry Andric LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 418e8d8bef9SDimitry Andric AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, 419e8d8bef9SDimitry Andric OptimizationRemarkEmitter *ORE) 4205ffd83dbSDimitry Andric : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), 4215ffd83dbSDimitry Andric LI(LI), ORE(ORE) {} 422480093f4SDimitry Andric 4235ffd83dbSDimitry Andric unsigned getNumOps(Type *VT) { 4245ffd83dbSDimitry Andric assert(isa<VectorType>(VT) && "Expected vector type"); 4255ffd83dbSDimitry Andric return getNumOps(VT->getScalarType(), 4265ffd83dbSDimitry Andric cast<FixedVectorType>(VT)->getNumElements()); 4275ffd83dbSDimitry Andric } 4285ffd83dbSDimitry Andric 429fe6060f1SDimitry Andric /// Is this the minimal version executed in the backend pipelines. 430fe6060f1SDimitry Andric bool isMinimal() const { 431fe6060f1SDimitry Andric return !DT; 432fe6060f1SDimitry Andric } 433fe6060f1SDimitry Andric 4345ffd83dbSDimitry Andric /// Return the estimated number of vector ops required for an operation on 4355ffd83dbSDimitry Andric /// \p VT * N. 
4365ffd83dbSDimitry Andric unsigned getNumOps(Type *ST, unsigned N) { 4375ffd83dbSDimitry Andric return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / 438fe6060f1SDimitry Andric double(TTI.getRegisterBitWidth( 439fe6060f1SDimitry Andric TargetTransformInfo::RGK_FixedWidthVector) 440fe6060f1SDimitry Andric .getFixedSize())); 4415ffd83dbSDimitry Andric } 4425ffd83dbSDimitry Andric 4435ffd83dbSDimitry Andric /// Return the set of vectors that a matrix value is lowered to. 444480093f4SDimitry Andric /// 4455ffd83dbSDimitry Andric /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 4465ffd83dbSDimitry Andric /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 4475ffd83dbSDimitry Andric /// into vectors. 4485ffd83dbSDimitry Andric MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 4495ffd83dbSDimitry Andric IRBuilder<> &Builder) { 450480093f4SDimitry Andric VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 451480093f4SDimitry Andric assert(VType && "MatrixVal must be a vector type"); 4525ffd83dbSDimitry Andric assert(cast<FixedVectorType>(VType)->getNumElements() == 4535ffd83dbSDimitry Andric SI.NumRows * SI.NumColumns && 454480093f4SDimitry Andric "The vector size must match the number of matrix elements"); 455480093f4SDimitry Andric 456480093f4SDimitry Andric // Check if we lowered MatrixVal using shape information. In that case, 4575ffd83dbSDimitry Andric // return the existing matrix, if it matches the requested shape 458480093f4SDimitry Andric // information. If there is a mis-match, embed the result in a flat 459480093f4SDimitry Andric // vector and split it later. 
460480093f4SDimitry Andric auto Found = Inst2ColumnMatrix.find(MatrixVal); 461480093f4SDimitry Andric if (Found != Inst2ColumnMatrix.end()) { 4625ffd83dbSDimitry Andric MatrixTy &M = Found->second; 463480093f4SDimitry Andric // Return the found matrix, if its shape matches the requested shape 464480093f4SDimitry Andric // information 465480093f4SDimitry Andric if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 466480093f4SDimitry Andric return M; 467480093f4SDimitry Andric 468480093f4SDimitry Andric MatrixVal = M.embedInVector(Builder); 469480093f4SDimitry Andric } 470480093f4SDimitry Andric 471480093f4SDimitry Andric // Otherwise split MatrixVal. 472480093f4SDimitry Andric SmallVector<Value *, 16> SplitVecs; 4735ffd83dbSDimitry Andric for (unsigned MaskStart = 0; 4745ffd83dbSDimitry Andric MaskStart < cast<FixedVectorType>(VType)->getNumElements(); 4755ffd83dbSDimitry Andric MaskStart += SI.getStride()) { 4765ffd83dbSDimitry Andric Value *V = Builder.CreateShuffleVector( 477e8d8bef9SDimitry Andric MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), 4785ffd83dbSDimitry Andric "split"); 479480093f4SDimitry Andric SplitVecs.push_back(V); 480480093f4SDimitry Andric } 481480093f4SDimitry Andric 482480093f4SDimitry Andric return {SplitVecs}; 483480093f4SDimitry Andric } 484480093f4SDimitry Andric 485480093f4SDimitry Andric /// If \p V already has a known shape return false. Otherwise set the shape 486480093f4SDimitry Andric /// for instructions that support it. 
487480093f4SDimitry Andric bool setShapeInfo(Value *V, ShapeInfo Shape) { 488480093f4SDimitry Andric assert(Shape && "Shape not set"); 489480093f4SDimitry Andric if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 490480093f4SDimitry Andric return false; 491480093f4SDimitry Andric 492480093f4SDimitry Andric auto SIter = ShapeMap.find(V); 493480093f4SDimitry Andric if (SIter != ShapeMap.end()) { 494480093f4SDimitry Andric LLVM_DEBUG(dbgs() << " not overriding existing shape: " 495480093f4SDimitry Andric << SIter->second.NumRows << " " 496480093f4SDimitry Andric << SIter->second.NumColumns << " for " << *V << "\n"); 497480093f4SDimitry Andric return false; 498480093f4SDimitry Andric } 499480093f4SDimitry Andric 500480093f4SDimitry Andric ShapeMap.insert({V, Shape}); 501480093f4SDimitry Andric LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 502480093f4SDimitry Andric << " for " << *V << "\n"); 503480093f4SDimitry Andric return true; 504480093f4SDimitry Andric } 505480093f4SDimitry Andric 506480093f4SDimitry Andric bool isUniformShape(Value *V) { 507480093f4SDimitry Andric Instruction *I = dyn_cast<Instruction>(V); 508480093f4SDimitry Andric if (!I) 509480093f4SDimitry Andric return true; 510480093f4SDimitry Andric 511480093f4SDimitry Andric switch (I->getOpcode()) { 512480093f4SDimitry Andric case Instruction::FAdd: 513480093f4SDimitry Andric case Instruction::FSub: 514480093f4SDimitry Andric case Instruction::FMul: // Scalar multiply. 515e8d8bef9SDimitry Andric case Instruction::FNeg: 516480093f4SDimitry Andric case Instruction::Add: 517480093f4SDimitry Andric case Instruction::Mul: 518480093f4SDimitry Andric case Instruction::Sub: 519480093f4SDimitry Andric return true; 520480093f4SDimitry Andric default: 521480093f4SDimitry Andric return false; 522480093f4SDimitry Andric } 523480093f4SDimitry Andric } 524480093f4SDimitry Andric 525480093f4SDimitry Andric /// Returns true if shape information can be used for \p V. 
The supported 526480093f4SDimitry Andric /// instructions must match the instructions that can be lowered by this pass. 527480093f4SDimitry Andric bool supportsShapeInfo(Value *V) { 528480093f4SDimitry Andric Instruction *Inst = dyn_cast<Instruction>(V); 529480093f4SDimitry Andric if (!Inst) 530480093f4SDimitry Andric return false; 531480093f4SDimitry Andric 532480093f4SDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 533480093f4SDimitry Andric if (II) 534480093f4SDimitry Andric switch (II->getIntrinsicID()) { 535480093f4SDimitry Andric case Intrinsic::matrix_multiply: 536480093f4SDimitry Andric case Intrinsic::matrix_transpose: 5375ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 5385ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 539480093f4SDimitry Andric return true; 540480093f4SDimitry Andric default: 541480093f4SDimitry Andric return false; 542480093f4SDimitry Andric } 543480093f4SDimitry Andric return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 544480093f4SDimitry Andric } 545480093f4SDimitry Andric 546480093f4SDimitry Andric /// Propagate the shape information of instructions to their users. 547480093f4SDimitry Andric /// The work list contains instructions for which we can compute the shape, 548480093f4SDimitry Andric /// either based on the information provided by matrix intrinsics or known 549480093f4SDimitry Andric /// shapes of operands. 550480093f4SDimitry Andric SmallVector<Instruction *, 32> 551480093f4SDimitry Andric propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 552480093f4SDimitry Andric SmallVector<Instruction *, 32> NewWorkList; 553480093f4SDimitry Andric // Pop an element for which we guaranteed to have at least one of the 554480093f4SDimitry Andric // operand shapes. Add the shape for this and then add users to the work 555480093f4SDimitry Andric // list. 
556480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 557480093f4SDimitry Andric while (!WorkList.empty()) { 558e8d8bef9SDimitry Andric Instruction *Inst = WorkList.pop_back_val(); 559480093f4SDimitry Andric 560480093f4SDimitry Andric // New entry, set the value and insert operands 561480093f4SDimitry Andric bool Propagate = false; 562480093f4SDimitry Andric 563480093f4SDimitry Andric Value *MatrixA; 564480093f4SDimitry Andric Value *MatrixB; 565480093f4SDimitry Andric Value *M; 566480093f4SDimitry Andric Value *N; 567480093f4SDimitry Andric Value *K; 568480093f4SDimitry Andric if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>( 569480093f4SDimitry Andric m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 570480093f4SDimitry Andric m_Value(N), m_Value(K)))) { 571480093f4SDimitry Andric Propagate = setShapeInfo(Inst, {M, K}); 572480093f4SDimitry Andric } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>( 573480093f4SDimitry Andric m_Value(MatrixA), m_Value(M), m_Value(N)))) { 574480093f4SDimitry Andric // Flip dimensions. 
575480093f4SDimitry Andric Propagate = setShapeInfo(Inst, {N, M}); 5765ffd83dbSDimitry Andric } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>( 577480093f4SDimitry Andric m_Value(MatrixA), m_Value(), m_Value(), 5785ffd83dbSDimitry Andric m_Value(), m_Value(M), m_Value(N)))) { 579480093f4SDimitry Andric Propagate = setShapeInfo(Inst, {N, M}); 5805ffd83dbSDimitry Andric } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( 5815ffd83dbSDimitry Andric m_Value(), m_Value(), m_Value(), m_Value(M), 5825ffd83dbSDimitry Andric m_Value(N)))) { 583480093f4SDimitry Andric Propagate = setShapeInfo(Inst, {M, N}); 584480093f4SDimitry Andric } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { 585480093f4SDimitry Andric auto OpShape = ShapeMap.find(MatrixA); 586480093f4SDimitry Andric if (OpShape != ShapeMap.end()) 587480093f4SDimitry Andric setShapeInfo(Inst, OpShape->second); 588480093f4SDimitry Andric continue; 589480093f4SDimitry Andric } else if (isUniformShape(Inst)) { 590480093f4SDimitry Andric // Find the first operand that has a known shape and use that. 
591480093f4SDimitry Andric for (auto &Op : Inst->operands()) { 592480093f4SDimitry Andric auto OpShape = ShapeMap.find(Op.get()); 593480093f4SDimitry Andric if (OpShape != ShapeMap.end()) { 594480093f4SDimitry Andric Propagate |= setShapeInfo(Inst, OpShape->second); 595480093f4SDimitry Andric break; 596480093f4SDimitry Andric } 597480093f4SDimitry Andric } 598480093f4SDimitry Andric } 599480093f4SDimitry Andric 600480093f4SDimitry Andric if (Propagate) { 601480093f4SDimitry Andric NewWorkList.push_back(Inst); 602480093f4SDimitry Andric for (auto *User : Inst->users()) 603480093f4SDimitry Andric if (ShapeMap.count(User) == 0) 604480093f4SDimitry Andric WorkList.push_back(cast<Instruction>(User)); 605480093f4SDimitry Andric } 606480093f4SDimitry Andric } 607480093f4SDimitry Andric 608480093f4SDimitry Andric return NewWorkList; 609480093f4SDimitry Andric } 610480093f4SDimitry Andric 611480093f4SDimitry Andric /// Propagate the shape to operands of instructions with shape information. 612480093f4SDimitry Andric /// \p Worklist contains the instruction for which we already know the shape. 613480093f4SDimitry Andric SmallVector<Instruction *, 32> 614480093f4SDimitry Andric propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 615480093f4SDimitry Andric SmallVector<Instruction *, 32> NewWorkList; 616480093f4SDimitry Andric 617480093f4SDimitry Andric auto pushInstruction = [](Value *V, 618480093f4SDimitry Andric SmallVectorImpl<Instruction *> &WorkList) { 619480093f4SDimitry Andric Instruction *I = dyn_cast<Instruction>(V); 620480093f4SDimitry Andric if (I) 621480093f4SDimitry Andric WorkList.push_back(I); 622480093f4SDimitry Andric }; 623480093f4SDimitry Andric // Pop an element with known shape. Traverse the operands, if their shape 624480093f4SDimitry Andric // derives from the result shape and is unknown, add it and add them to the 625480093f4SDimitry Andric // worklist. 
626480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 627480093f4SDimitry Andric while (!WorkList.empty()) { 628e8d8bef9SDimitry Andric Value *V = WorkList.pop_back_val(); 629480093f4SDimitry Andric 630480093f4SDimitry Andric size_t BeforeProcessingV = WorkList.size(); 631480093f4SDimitry Andric if (!isa<Instruction>(V)) 632480093f4SDimitry Andric continue; 633480093f4SDimitry Andric 634480093f4SDimitry Andric Value *MatrixA; 635480093f4SDimitry Andric Value *MatrixB; 636480093f4SDimitry Andric Value *M; 637480093f4SDimitry Andric Value *N; 638480093f4SDimitry Andric Value *K; 639480093f4SDimitry Andric if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 640480093f4SDimitry Andric m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 641480093f4SDimitry Andric m_Value(N), m_Value(K)))) { 642480093f4SDimitry Andric if (setShapeInfo(MatrixA, {M, N})) 643480093f4SDimitry Andric pushInstruction(MatrixA, WorkList); 644480093f4SDimitry Andric 645480093f4SDimitry Andric if (setShapeInfo(MatrixB, {N, K})) 646480093f4SDimitry Andric pushInstruction(MatrixB, WorkList); 647480093f4SDimitry Andric 648480093f4SDimitry Andric } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 649480093f4SDimitry Andric m_Value(MatrixA), m_Value(M), m_Value(N)))) { 650480093f4SDimitry Andric // Flip dimensions. 
651480093f4SDimitry Andric if (setShapeInfo(MatrixA, {M, N})) 652480093f4SDimitry Andric pushInstruction(MatrixA, WorkList); 6535ffd83dbSDimitry Andric } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 6545ffd83dbSDimitry Andric m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 655480093f4SDimitry Andric m_Value(M), m_Value(N)))) { 656480093f4SDimitry Andric if (setShapeInfo(MatrixA, {M, N})) { 657480093f4SDimitry Andric pushInstruction(MatrixA, WorkList); 658480093f4SDimitry Andric } 659480093f4SDimitry Andric } else if (isa<LoadInst>(V) || 6605ffd83dbSDimitry Andric match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 661480093f4SDimitry Andric // Nothing to do, no matrix input. 662480093f4SDimitry Andric } else if (isa<StoreInst>(V)) { 663480093f4SDimitry Andric // Nothing to do. We forward-propagated to this so we would just 664480093f4SDimitry Andric // backward propagate to an instruction with an already known shape. 665480093f4SDimitry Andric } else if (isUniformShape(V)) { 666480093f4SDimitry Andric // Propagate to all operands. 667480093f4SDimitry Andric ShapeInfo Shape = ShapeMap[V]; 668480093f4SDimitry Andric for (Use &U : cast<Instruction>(V)->operands()) { 669480093f4SDimitry Andric if (setShapeInfo(U.get(), Shape)) 670480093f4SDimitry Andric pushInstruction(U.get(), WorkList); 671480093f4SDimitry Andric } 672480093f4SDimitry Andric } 673480093f4SDimitry Andric // After we discovered new shape info for new instructions in the 674480093f4SDimitry Andric // worklist, we use their users as seeds for the next round of forward 675480093f4SDimitry Andric // propagation. 
676480093f4SDimitry Andric for (size_t I = BeforeProcessingV; I != WorkList.size(); I++) 677480093f4SDimitry Andric for (User *U : WorkList[I]->users()) 678480093f4SDimitry Andric if (isa<Instruction>(U) && V != U) 679480093f4SDimitry Andric NewWorkList.push_back(cast<Instruction>(U)); 680480093f4SDimitry Andric } 681480093f4SDimitry Andric return NewWorkList; 682480093f4SDimitry Andric } 683480093f4SDimitry Andric 684fe6060f1SDimitry Andric /// Try moving transposes in order to fold them away or into multiplies. 685fe6060f1SDimitry Andric void optimizeTransposes() { 686fe6060f1SDimitry Andric auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) { 687fe6060f1SDimitry Andric // We need to remove Old from the ShapeMap otherwise RAUW will replace it 688fe6060f1SDimitry Andric // with New. We should only add New it it supportsShapeInfo so we insert 689fe6060f1SDimitry Andric // it conditionally instead. 690fe6060f1SDimitry Andric auto S = ShapeMap.find(&Old); 691fe6060f1SDimitry Andric if (S != ShapeMap.end()) { 692fe6060f1SDimitry Andric ShapeMap.erase(S); 693fe6060f1SDimitry Andric if (supportsShapeInfo(New)) 694fe6060f1SDimitry Andric ShapeMap.insert({New, S->second}); 695fe6060f1SDimitry Andric } 696fe6060f1SDimitry Andric Old.replaceAllUsesWith(New); 697fe6060f1SDimitry Andric }; 698fe6060f1SDimitry Andric 699fe6060f1SDimitry Andric // First sink all transposes inside matmuls, hoping that we end up with NN, 700fe6060f1SDimitry Andric // NT or TN variants. 701fe6060f1SDimitry Andric for (BasicBlock &BB : reverse(Func)) { 702fe6060f1SDimitry Andric for (auto II = BB.rbegin(); II != BB.rend();) { 703fe6060f1SDimitry Andric Instruction &I = *II; 704fe6060f1SDimitry Andric // We may remove II. By default continue on the next/prev instruction. 705fe6060f1SDimitry Andric ++II; 706fe6060f1SDimitry Andric // If we were to erase II, move again. 
707*81ad6265SDimitry Andric auto EraseFromParent = [&II, &BB](Value *V) { 708fe6060f1SDimitry Andric auto *Inst = cast<Instruction>(V); 709fe6060f1SDimitry Andric if (Inst->use_empty()) { 710*81ad6265SDimitry Andric if (II != BB.rend() && Inst == &*II) { 711fe6060f1SDimitry Andric ++II; 712fe6060f1SDimitry Andric } 713fe6060f1SDimitry Andric Inst->eraseFromParent(); 714fe6060f1SDimitry Andric } 715fe6060f1SDimitry Andric }; 716fe6060f1SDimitry Andric 717fe6060f1SDimitry Andric // If we're creating a new instruction, continue from there. 718fe6060f1SDimitry Andric Instruction *NewInst = nullptr; 719fe6060f1SDimitry Andric 720fe6060f1SDimitry Andric IRBuilder<> IB(&I); 721*81ad6265SDimitry Andric MatrixBuilder Builder(IB); 722fe6060f1SDimitry Andric 723fe6060f1SDimitry Andric Value *TA, *TAMA, *TAMB; 724fe6060f1SDimitry Andric ConstantInt *R, *K, *C; 725fe6060f1SDimitry Andric if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) { 726fe6060f1SDimitry Andric 727fe6060f1SDimitry Andric // Transpose of a transpose is a nop 728fe6060f1SDimitry Andric Value *TATA; 729fe6060f1SDimitry Andric if (match(TA, 730fe6060f1SDimitry Andric m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { 731fe6060f1SDimitry Andric ReplaceAllUsesWith(I, TATA); 732fe6060f1SDimitry Andric EraseFromParent(&I); 733fe6060f1SDimitry Andric EraseFromParent(TA); 734fe6060f1SDimitry Andric } 735fe6060f1SDimitry Andric 736fe6060f1SDimitry Andric // (A * B)^t -> B^t * A^t 737fe6060f1SDimitry Andric // RxK KxC CxK KxR 738fe6060f1SDimitry Andric else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( 739fe6060f1SDimitry Andric m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), 740fe6060f1SDimitry Andric m_ConstantInt(K), m_ConstantInt(C)))) { 741fe6060f1SDimitry Andric Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(), 742fe6060f1SDimitry Andric C->getZExtValue(), 743fe6060f1SDimitry Andric TAMB->getName() + "_t"); 744fe6060f1SDimitry Andric // We are being run after 
shape prop, add shape for newly created 745fe6060f1SDimitry Andric // instructions so that we lower them later. 746fe6060f1SDimitry Andric setShapeInfo(T0, {C, K}); 747fe6060f1SDimitry Andric Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(), 748fe6060f1SDimitry Andric K->getZExtValue(), 749fe6060f1SDimitry Andric TAMA->getName() + "_t"); 750fe6060f1SDimitry Andric setShapeInfo(T1, {K, R}); 751fe6060f1SDimitry Andric NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(), 752fe6060f1SDimitry Andric K->getZExtValue(), 753fe6060f1SDimitry Andric R->getZExtValue(), "mmul"); 754fe6060f1SDimitry Andric ReplaceAllUsesWith(I, NewInst); 755fe6060f1SDimitry Andric EraseFromParent(&I); 756fe6060f1SDimitry Andric EraseFromParent(TA); 757fe6060f1SDimitry Andric } 758fe6060f1SDimitry Andric } 759fe6060f1SDimitry Andric 760fe6060f1SDimitry Andric // If we replaced I with a new instruction, continue from there. 761fe6060f1SDimitry Andric if (NewInst) 762fe6060f1SDimitry Andric II = std::next(BasicBlock::reverse_iterator(NewInst)); 763fe6060f1SDimitry Andric } 764fe6060f1SDimitry Andric } 765fe6060f1SDimitry Andric 766fe6060f1SDimitry Andric // If we have a TT matmul, lift the transpose. We may be able to fold into 767fe6060f1SDimitry Andric // consuming multiply. 
768fe6060f1SDimitry Andric for (BasicBlock &BB : Func) { 769*81ad6265SDimitry Andric for (Instruction &I : llvm::make_early_inc_range(BB)) { 770fe6060f1SDimitry Andric Value *A, *B, *AT, *BT; 771fe6060f1SDimitry Andric ConstantInt *R, *K, *C; 772fe6060f1SDimitry Andric // A^t * B ^t -> (B * A)^t 773*81ad6265SDimitry Andric if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>( 774fe6060f1SDimitry Andric m_Value(A), m_Value(B), m_ConstantInt(R), 775fe6060f1SDimitry Andric m_ConstantInt(K), m_ConstantInt(C))) && 776fe6060f1SDimitry Andric match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && 777fe6060f1SDimitry Andric match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { 778*81ad6265SDimitry Andric IRBuilder<> IB(&I); 779*81ad6265SDimitry Andric MatrixBuilder Builder(IB); 780fe6060f1SDimitry Andric Value *M = Builder.CreateMatrixMultiply( 781fe6060f1SDimitry Andric BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); 782fe6060f1SDimitry Andric setShapeInfo(M, {C, R}); 783fe6060f1SDimitry Andric Instruction *NewInst = Builder.CreateMatrixTranspose( 784fe6060f1SDimitry Andric M, C->getZExtValue(), R->getZExtValue()); 785*81ad6265SDimitry Andric ReplaceAllUsesWith(I, NewInst); 786*81ad6265SDimitry Andric if (I.use_empty()) 787*81ad6265SDimitry Andric I.eraseFromParent(); 788fe6060f1SDimitry Andric if (A->use_empty()) 789fe6060f1SDimitry Andric cast<Instruction>(A)->eraseFromParent(); 790fe6060f1SDimitry Andric if (A != B && B->use_empty()) 791fe6060f1SDimitry Andric cast<Instruction>(B)->eraseFromParent(); 792fe6060f1SDimitry Andric } 793fe6060f1SDimitry Andric } 794fe6060f1SDimitry Andric } 795fe6060f1SDimitry Andric } 796fe6060f1SDimitry Andric 797480093f4SDimitry Andric bool Visit() { 798480093f4SDimitry Andric SmallVector<Instruction *, 32> WorkList; 799480093f4SDimitry Andric 800480093f4SDimitry Andric // Initially only the shape of matrix intrinsics is known. 
801480093f4SDimitry Andric // Initialize the work list with ops carrying shape information. 802480093f4SDimitry Andric for (BasicBlock &BB : Func) 803480093f4SDimitry Andric for (Instruction &Inst : BB) { 804480093f4SDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 805480093f4SDimitry Andric if (!II) 806480093f4SDimitry Andric continue; 807480093f4SDimitry Andric 808480093f4SDimitry Andric switch (II->getIntrinsicID()) { 809480093f4SDimitry Andric case Intrinsic::matrix_multiply: 810480093f4SDimitry Andric case Intrinsic::matrix_transpose: 8115ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 8125ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 813480093f4SDimitry Andric WorkList.push_back(&Inst); 814480093f4SDimitry Andric break; 815480093f4SDimitry Andric default: 816480093f4SDimitry Andric break; 817480093f4SDimitry Andric } 818480093f4SDimitry Andric } 819fe6060f1SDimitry Andric 820fe6060f1SDimitry Andric // Avoid unnecessary work if there are no matrix intrinsics in the function. 821fe6060f1SDimitry Andric if (WorkList.empty()) 822fe6060f1SDimitry Andric return false; 823fe6060f1SDimitry Andric 824480093f4SDimitry Andric // Propagate shapes until nothing changes any longer. 
825480093f4SDimitry Andric while (!WorkList.empty()) { 826480093f4SDimitry Andric WorkList = propagateShapeForward(WorkList); 827480093f4SDimitry Andric WorkList = propagateShapeBackward(WorkList); 828480093f4SDimitry Andric } 829fe6060f1SDimitry Andric 830fe6060f1SDimitry Andric if (!isMinimal()) { 831fe6060f1SDimitry Andric optimizeTransposes(); 832fe6060f1SDimitry Andric LLVM_DEBUG({ 833fe6060f1SDimitry Andric dbgs() << "Dump after matrix transpose optimization:\n"; 834fe6060f1SDimitry Andric Func.dump(); 835fe6060f1SDimitry Andric }); 836480093f4SDimitry Andric } 837480093f4SDimitry Andric 838480093f4SDimitry Andric bool Changed = false; 8395ffd83dbSDimitry Andric SmallVector<CallInst *, 16> MaybeFusableInsts; 8405ffd83dbSDimitry Andric SmallVector<Instruction *, 16> MatrixInsts; 841480093f4SDimitry Andric 8425ffd83dbSDimitry Andric // First, collect all instructions with shape information and candidates for 8435ffd83dbSDimitry Andric // fusion (currently only matrix multiplies). 8445ffd83dbSDimitry Andric ReversePostOrderTraversal<Function *> RPOT(&Func); 8455ffd83dbSDimitry Andric for (auto *BB : RPOT) 8465ffd83dbSDimitry Andric for (Instruction &I : *BB) { 8475ffd83dbSDimitry Andric if (ShapeMap.find(&I) == ShapeMap.end()) 8485ffd83dbSDimitry Andric continue; 8495ffd83dbSDimitry Andric if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 8505ffd83dbSDimitry Andric MaybeFusableInsts.push_back(cast<CallInst>(&I)); 8515ffd83dbSDimitry Andric MatrixInsts.push_back(&I); 8525ffd83dbSDimitry Andric } 8535ffd83dbSDimitry Andric 8545ffd83dbSDimitry Andric // Second, try to fuse candidates. 8555ffd83dbSDimitry Andric SmallPtrSet<Instruction *, 16> FusedInsts; 8565ffd83dbSDimitry Andric for (CallInst *CI : MaybeFusableInsts) 8575ffd83dbSDimitry Andric LowerMatrixMultiplyFused(CI, FusedInsts); 8585ffd83dbSDimitry Andric Changed = !FusedInsts.empty(); 8595ffd83dbSDimitry Andric 8605ffd83dbSDimitry Andric // Third, lower remaining instructions with shape information. 
8615ffd83dbSDimitry Andric for (Instruction *Inst : MatrixInsts) { 8625ffd83dbSDimitry Andric if (FusedInsts.count(Inst)) 8635ffd83dbSDimitry Andric continue; 8645ffd83dbSDimitry Andric 8655ffd83dbSDimitry Andric IRBuilder<> Builder(Inst); 8665ffd83dbSDimitry Andric 8675ffd83dbSDimitry Andric if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 868480093f4SDimitry Andric Changed |= VisitCallInst(CInst); 869480093f4SDimitry Andric 870480093f4SDimitry Andric Value *Op1; 871480093f4SDimitry Andric Value *Op2; 8725ffd83dbSDimitry Andric if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 873480093f4SDimitry Andric Changed |= VisitBinaryOperator(BinOp); 874e8d8bef9SDimitry Andric if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 875e8d8bef9SDimitry Andric Changed |= VisitUnaryOperator(UnOp); 8765ffd83dbSDimitry Andric if (match(Inst, m_Load(m_Value(Op1)))) 8775ffd83dbSDimitry Andric Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 8785ffd83dbSDimitry Andric else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 8795ffd83dbSDimitry Andric Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 880480093f4SDimitry Andric } 8815ffd83dbSDimitry Andric 882e8d8bef9SDimitry Andric if (ORE) { 883e8d8bef9SDimitry Andric RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 8845ffd83dbSDimitry Andric RemarkGen.emitRemarks(); 885e8d8bef9SDimitry Andric } 886480093f4SDimitry Andric 887fe6060f1SDimitry Andric // Delete the instructions backwards, as it has a reduced likelihood of 888fe6060f1SDimitry Andric // having to update as many def-use and use-def chains. 889fe6060f1SDimitry Andric // 890fe6060f1SDimitry Andric // Because we add to ToRemove during fusion we can't guarantee that defs 891*81ad6265SDimitry Andric // are before uses. Change uses to poison temporarily as these should get 892fe6060f1SDimitry Andric // removed as well. 
893fe6060f1SDimitry Andric // 894*81ad6265SDimitry Andric // For verification, we keep track of where we changed uses to poison in 895*81ad6265SDimitry Andric // PoisonedInsts and then check that we in fact remove them. 896*81ad6265SDimitry Andric SmallSet<Instruction *, 16> PoisonedInsts; 897fe6060f1SDimitry Andric for (auto *Inst : reverse(ToRemove)) { 898349cc55cSDimitry Andric for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 899*81ad6265SDimitry Andric if (auto *Poisoned = dyn_cast<Instruction>(U.getUser())) 900*81ad6265SDimitry Andric PoisonedInsts.insert(Poisoned); 901*81ad6265SDimitry Andric U.set(PoisonValue::get(Inst->getType())); 902fe6060f1SDimitry Andric } 903480093f4SDimitry Andric Inst->eraseFromParent(); 904*81ad6265SDimitry Andric PoisonedInsts.erase(Inst); 905fe6060f1SDimitry Andric } 906*81ad6265SDimitry Andric if (!PoisonedInsts.empty()) { 907*81ad6265SDimitry Andric // If we didn't remove all poisoned instructions, it's a hard error. 908*81ad6265SDimitry Andric dbgs() << "Poisoned but present instructions:\n"; 909*81ad6265SDimitry Andric for (auto *I : PoisonedInsts) 910fe6060f1SDimitry Andric dbgs() << *I << "\n"; 911*81ad6265SDimitry Andric llvm_unreachable("Poisoned but instruction not removed"); 912fe6060f1SDimitry Andric } 913480093f4SDimitry Andric 914480093f4SDimitry Andric return Changed; 915480093f4SDimitry Andric } 916480093f4SDimitry Andric 917480093f4SDimitry Andric /// Turns \p BasePtr into an elementwise pointer to \p EltType. 
918480093f4SDimitry Andric Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 919480093f4SDimitry Andric unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 920480093f4SDimitry Andric Type *EltPtrType = PointerType::get(EltType, AS); 921480093f4SDimitry Andric return Builder.CreatePointerCast(BasePtr, EltPtrType); 922480093f4SDimitry Andric } 923480093f4SDimitry Andric 924480093f4SDimitry Andric /// Replace intrinsic calls 925480093f4SDimitry Andric bool VisitCallInst(CallInst *Inst) { 926480093f4SDimitry Andric if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 927480093f4SDimitry Andric return false; 928480093f4SDimitry Andric 929480093f4SDimitry Andric switch (Inst->getCalledFunction()->getIntrinsicID()) { 930480093f4SDimitry Andric case Intrinsic::matrix_multiply: 931480093f4SDimitry Andric LowerMultiply(Inst); 932480093f4SDimitry Andric break; 933480093f4SDimitry Andric case Intrinsic::matrix_transpose: 934480093f4SDimitry Andric LowerTranspose(Inst); 935480093f4SDimitry Andric break; 9365ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 9375ffd83dbSDimitry Andric LowerColumnMajorLoad(Inst); 938480093f4SDimitry Andric break; 9395ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 9405ffd83dbSDimitry Andric LowerColumnMajorStore(Inst); 941480093f4SDimitry Andric break; 942480093f4SDimitry Andric default: 943480093f4SDimitry Andric return false; 944480093f4SDimitry Andric } 945480093f4SDimitry Andric return true; 946480093f4SDimitry Andric } 947480093f4SDimitry Andric 9485ffd83dbSDimitry Andric /// Compute the alignment for a column/row \p Idx with \p Stride between them. 9495ffd83dbSDimitry Andric /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 9505ffd83dbSDimitry Andric /// ConstantInt, reduce the initial alignment based on the byte offset. 
For 9515ffd83dbSDimitry Andric /// non-ConstantInt strides, return the common alignment of the initial 9525ffd83dbSDimitry Andric /// alignment and the element size in bytes. 9535ffd83dbSDimitry Andric Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 9545ffd83dbSDimitry Andric MaybeAlign A) const { 9555ffd83dbSDimitry Andric Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 9565ffd83dbSDimitry Andric if (Idx == 0) 9575ffd83dbSDimitry Andric return InitialAlign; 9585ffd83dbSDimitry Andric 9595ffd83dbSDimitry Andric TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 9605ffd83dbSDimitry Andric if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 9615ffd83dbSDimitry Andric uint64_t StrideInBytes = 9625ffd83dbSDimitry Andric ConstStride->getZExtValue() * ElementSizeInBits / 8; 9635ffd83dbSDimitry Andric return commonAlignment(InitialAlign, Idx * StrideInBytes); 9645ffd83dbSDimitry Andric } 9655ffd83dbSDimitry Andric return commonAlignment(InitialAlign, ElementSizeInBits / 8); 9665ffd83dbSDimitry Andric } 9675ffd83dbSDimitry Andric 9685ffd83dbSDimitry Andric /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 9695ffd83dbSDimitry Andric /// vectors. 
9705ffd83dbSDimitry Andric MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 9715ffd83dbSDimitry Andric bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 972fe6060f1SDimitry Andric auto *VType = cast<VectorType>(Ty); 973fe6060f1SDimitry Andric Type *EltTy = VType->getElementType(); 974fe6060f1SDimitry Andric Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); 975fe6060f1SDimitry Andric Value *EltPtr = createElementPtr(Ptr, EltTy, Builder); 9765ffd83dbSDimitry Andric MatrixTy Result; 9775ffd83dbSDimitry Andric for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 978349cc55cSDimitry Andric Value *GEP = computeVectorAddr( 979349cc55cSDimitry Andric EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), 980349cc55cSDimitry Andric Stride, Shape.getStride(), EltTy, Builder); 9815ffd83dbSDimitry Andric Value *Vector = Builder.CreateAlignedLoad( 982fe6060f1SDimitry Andric VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), 9835ffd83dbSDimitry Andric IsVolatile, "col.load"); 9845ffd83dbSDimitry Andric 9855ffd83dbSDimitry Andric Result.addVector(Vector); 9865ffd83dbSDimitry Andric } 9875ffd83dbSDimitry Andric return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 9885ffd83dbSDimitry Andric Result.getNumVectors()); 989480093f4SDimitry Andric } 990480093f4SDimitry Andric 9915ffd83dbSDimitry Andric /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 9925ffd83dbSDimitry Andric /// starting at \p MatrixPtr[I][J]. 
9935ffd83dbSDimitry Andric MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 9945ffd83dbSDimitry Andric ShapeInfo MatrixShape, Value *I, Value *J, 9955ffd83dbSDimitry Andric ShapeInfo ResultShape, Type *EltTy, 9965ffd83dbSDimitry Andric IRBuilder<> &Builder) { 9975ffd83dbSDimitry Andric 9985ffd83dbSDimitry Andric Value *Offset = Builder.CreateAdd( 9995ffd83dbSDimitry Andric Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 10005ffd83dbSDimitry Andric 10015ffd83dbSDimitry Andric unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 10025ffd83dbSDimitry Andric Value *EltPtr = 10035ffd83dbSDimitry Andric Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 10045ffd83dbSDimitry Andric Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 10055ffd83dbSDimitry Andric auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 10065ffd83dbSDimitry Andric ResultShape.NumColumns); 10075ffd83dbSDimitry Andric Type *TilePtrTy = PointerType::get(TileTy, AS); 10085ffd83dbSDimitry Andric Value *TilePtr = 10095ffd83dbSDimitry Andric Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 10105ffd83dbSDimitry Andric 10115ffd83dbSDimitry Andric return loadMatrix(TileTy, TilePtr, Align, 10125ffd83dbSDimitry Andric Builder.getInt64(MatrixShape.getStride()), IsVolatile, 10135ffd83dbSDimitry Andric ResultShape, Builder); 1014480093f4SDimitry Andric } 1015480093f4SDimitry Andric 10165ffd83dbSDimitry Andric /// Lower a load instruction with shape information. 
10175ffd83dbSDimitry Andric void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 10185ffd83dbSDimitry Andric bool IsVolatile, ShapeInfo Shape) { 10195ffd83dbSDimitry Andric IRBuilder<> Builder(Inst); 10205ffd83dbSDimitry Andric finalizeLowering(Inst, 10215ffd83dbSDimitry Andric loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 10225ffd83dbSDimitry Andric Shape, Builder), 10235ffd83dbSDimitry Andric Builder); 10245ffd83dbSDimitry Andric } 10255ffd83dbSDimitry Andric 10265ffd83dbSDimitry Andric /// Lowers llvm.matrix.column.major.load. 1027480093f4SDimitry Andric /// 1028480093f4SDimitry Andric /// The intrinsic loads a matrix from memory using a stride between columns. 10295ffd83dbSDimitry Andric void LowerColumnMajorLoad(CallInst *Inst) { 10305ffd83dbSDimitry Andric assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 10315ffd83dbSDimitry Andric "Intrinsic only supports column-major layout!"); 1032480093f4SDimitry Andric Value *Ptr = Inst->getArgOperand(0); 1033480093f4SDimitry Andric Value *Stride = Inst->getArgOperand(1); 10345ffd83dbSDimitry Andric LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 10355ffd83dbSDimitry Andric cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 1036480093f4SDimitry Andric {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 1037480093f4SDimitry Andric } 1038480093f4SDimitry Andric 10395ffd83dbSDimitry Andric /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 10405ffd83dbSDimitry Andric /// MatrixPtr[I][J]. 
10415ffd83dbSDimitry Andric void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 10425ffd83dbSDimitry Andric MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, 10435ffd83dbSDimitry Andric Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { 10445ffd83dbSDimitry Andric Value *Offset = Builder.CreateAdd( 10455ffd83dbSDimitry Andric Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 10465ffd83dbSDimitry Andric 10475ffd83dbSDimitry Andric unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 10485ffd83dbSDimitry Andric Value *EltPtr = 10495ffd83dbSDimitry Andric Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 10505ffd83dbSDimitry Andric Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 10515ffd83dbSDimitry Andric auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * 10525ffd83dbSDimitry Andric StoreVal.getNumColumns()); 10535ffd83dbSDimitry Andric Type *TilePtrTy = PointerType::get(TileTy, AS); 10545ffd83dbSDimitry Andric Value *TilePtr = 10555ffd83dbSDimitry Andric Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 10565ffd83dbSDimitry Andric 10575ffd83dbSDimitry Andric storeMatrix(TileTy, StoreVal, TilePtr, MAlign, 10585ffd83dbSDimitry Andric Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); 10595ffd83dbSDimitry Andric } 10605ffd83dbSDimitry Andric 10615ffd83dbSDimitry Andric /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 10625ffd83dbSDimitry Andric /// vectors. 
10635ffd83dbSDimitry Andric MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, 10645ffd83dbSDimitry Andric MaybeAlign MAlign, Value *Stride, bool IsVolatile, 10655ffd83dbSDimitry Andric IRBuilder<> &Builder) { 10665ffd83dbSDimitry Andric auto VType = cast<VectorType>(Ty); 10675ffd83dbSDimitry Andric Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 10685ffd83dbSDimitry Andric for (auto Vec : enumerate(StoreVal.vectors())) { 1069349cc55cSDimitry Andric Value *GEP = computeVectorAddr( 1070349cc55cSDimitry Andric EltPtr, 1071349cc55cSDimitry Andric Builder.getIntN(Stride->getType()->getScalarSizeInBits(), 1072349cc55cSDimitry Andric Vec.index()), 1073349cc55cSDimitry Andric Stride, StoreVal.getStride(), VType->getElementType(), Builder); 10745ffd83dbSDimitry Andric Builder.CreateAlignedStore(Vec.value(), GEP, 10755ffd83dbSDimitry Andric getAlignForIndex(Vec.index(), Stride, 10765ffd83dbSDimitry Andric VType->getElementType(), 10775ffd83dbSDimitry Andric MAlign), 10785ffd83dbSDimitry Andric IsVolatile); 10795ffd83dbSDimitry Andric } 10805ffd83dbSDimitry Andric return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * 10815ffd83dbSDimitry Andric StoreVal.getNumVectors()); 10825ffd83dbSDimitry Andric } 10835ffd83dbSDimitry Andric 10845ffd83dbSDimitry Andric /// Lower a store instruction with shape information. 
10855ffd83dbSDimitry Andric void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, 10865ffd83dbSDimitry Andric Value *Stride, bool IsVolatile, ShapeInfo Shape) { 10875ffd83dbSDimitry Andric IRBuilder<> Builder(Inst); 10885ffd83dbSDimitry Andric auto StoreVal = getMatrix(Matrix, Shape, Builder); 10895ffd83dbSDimitry Andric finalizeLowering(Inst, 10905ffd83dbSDimitry Andric storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, 10915ffd83dbSDimitry Andric IsVolatile, Builder), 10925ffd83dbSDimitry Andric Builder); 10935ffd83dbSDimitry Andric } 10945ffd83dbSDimitry Andric 10955ffd83dbSDimitry Andric /// Lowers llvm.matrix.column.major.store. 10965ffd83dbSDimitry Andric /// 10975ffd83dbSDimitry Andric /// The intrinsic store a matrix back memory using a stride between columns. 10985ffd83dbSDimitry Andric void LowerColumnMajorStore(CallInst *Inst) { 10995ffd83dbSDimitry Andric assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 11005ffd83dbSDimitry Andric "Intrinsic only supports column-major layout!"); 11015ffd83dbSDimitry Andric Value *Matrix = Inst->getArgOperand(0); 11025ffd83dbSDimitry Andric Value *Ptr = Inst->getArgOperand(1); 11035ffd83dbSDimitry Andric Value *Stride = Inst->getArgOperand(2); 11045ffd83dbSDimitry Andric LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 11055ffd83dbSDimitry Andric cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 11065ffd83dbSDimitry Andric {Inst->getArgOperand(4), Inst->getArgOperand(5)}); 1107480093f4SDimitry Andric } 1108480093f4SDimitry Andric 1109480093f4SDimitry Andric // Set elements I..I+NumElts-1 to Block 1110480093f4SDimitry Andric Value *insertVector(Value *Col, unsigned I, Value *Block, 11115ffd83dbSDimitry Andric IRBuilder<> &Builder) { 1112480093f4SDimitry Andric 1113480093f4SDimitry Andric // First, bring Block to the same size as Col 1114480093f4SDimitry Andric unsigned BlockNumElts = 11155ffd83dbSDimitry Andric cast<FixedVectorType>(Block->getType())->getNumElements(); 
11165ffd83dbSDimitry Andric unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); 1117480093f4SDimitry Andric assert(NumElts >= BlockNumElts && "Too few elements for current block"); 1118480093f4SDimitry Andric 11195ffd83dbSDimitry Andric Block = Builder.CreateShuffleVector( 1120e8d8bef9SDimitry Andric Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 1121480093f4SDimitry Andric 1122480093f4SDimitry Andric // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 1123480093f4SDimitry Andric // 8, 4, 5, 6 11245ffd83dbSDimitry Andric SmallVector<int, 16> Mask; 1125480093f4SDimitry Andric unsigned i; 1126480093f4SDimitry Andric for (i = 0; i < I; i++) 11275ffd83dbSDimitry Andric Mask.push_back(i); 1128480093f4SDimitry Andric 11295ffd83dbSDimitry Andric unsigned VecNumElts = 11305ffd83dbSDimitry Andric cast<FixedVectorType>(Col->getType())->getNumElements(); 1131480093f4SDimitry Andric for (; i < I + BlockNumElts; i++) 11325ffd83dbSDimitry Andric Mask.push_back(i - I + VecNumElts); 1133480093f4SDimitry Andric 1134480093f4SDimitry Andric for (; i < VecNumElts; i++) 11355ffd83dbSDimitry Andric Mask.push_back(i); 1136480093f4SDimitry Andric 11375ffd83dbSDimitry Andric return Builder.CreateShuffleVector(Col, Block, Mask); 1138480093f4SDimitry Andric } 1139480093f4SDimitry Andric 1140480093f4SDimitry Andric Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 11415ffd83dbSDimitry Andric IRBuilder<> &Builder, bool AllowContraction, 11425ffd83dbSDimitry Andric unsigned &NumComputeOps) { 11435ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType()); 1144480093f4SDimitry Andric if (!Sum) 1145480093f4SDimitry Andric return UseFPOp ? 
Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 1146480093f4SDimitry Andric 1147480093f4SDimitry Andric if (UseFPOp) { 1148480093f4SDimitry Andric if (AllowContraction) { 1149480093f4SDimitry Andric // Use fmuladd for floating point operations and let the backend decide 1150480093f4SDimitry Andric // if that's profitable. 11515ffd83dbSDimitry Andric Function *FMulAdd = Intrinsic::getDeclaration( 1152480093f4SDimitry Andric Func.getParent(), Intrinsic::fmuladd, A->getType()); 1153480093f4SDimitry Andric return Builder.CreateCall(FMulAdd, {A, B, Sum}); 1154480093f4SDimitry Andric } 11555ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType()); 1156480093f4SDimitry Andric Value *Mul = Builder.CreateFMul(A, B); 1157480093f4SDimitry Andric return Builder.CreateFAdd(Sum, Mul); 1158480093f4SDimitry Andric } 1159480093f4SDimitry Andric 11605ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType()); 1161480093f4SDimitry Andric Value *Mul = Builder.CreateMul(A, B); 1162480093f4SDimitry Andric return Builder.CreateAdd(Sum, Mul); 1163480093f4SDimitry Andric } 1164480093f4SDimitry Andric 1165480093f4SDimitry Andric /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 1166fe6060f1SDimitry Andric /// users with shape information, there's nothing to do: they will use the 1167480093f4SDimitry Andric /// cached value when they are lowered. For other users, \p Matrix is 1168480093f4SDimitry Andric /// flattened and the uses are updated to use it. Also marks \p Inst for 1169480093f4SDimitry Andric /// deletion. 
11705ffd83dbSDimitry Andric void finalizeLowering(Instruction *Inst, MatrixTy Matrix, 1171480093f4SDimitry Andric IRBuilder<> &Builder) { 1172fe6060f1SDimitry Andric auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); 1173fe6060f1SDimitry Andric (void)inserted; 1174fe6060f1SDimitry Andric assert(inserted.second && "multiple matrix lowering mapping"); 1175480093f4SDimitry Andric 1176480093f4SDimitry Andric ToRemove.push_back(Inst); 1177480093f4SDimitry Andric Value *Flattened = nullptr; 1178fe6060f1SDimitry Andric for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1179480093f4SDimitry Andric if (ShapeMap.find(U.getUser()) == ShapeMap.end()) { 1180480093f4SDimitry Andric if (!Flattened) 1181480093f4SDimitry Andric Flattened = Matrix.embedInVector(Builder); 1182480093f4SDimitry Andric U.set(Flattened); 1183480093f4SDimitry Andric } 1184480093f4SDimitry Andric } 1185480093f4SDimitry Andric } 1186480093f4SDimitry Andric 11875ffd83dbSDimitry Andric /// Compute \p Result += \p A * \p B for input matrices with left-associating 11885ffd83dbSDimitry Andric /// addition. 1189fe6060f1SDimitry Andric /// 1190fe6060f1SDimitry Andric /// We can fold a transpose into the operand that is used to extract scalars. 1191fe6060f1SDimitry Andric /// This is the first operands with row-major and the second with 1192fe6060f1SDimitry Andric /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate 1193fe6060f1SDimitry Andric /// operand is transposed. 
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
                          bool IsScalarMatrixTransposed, FastMathFlags FMF) {
    // Vectorization factor: how many elements of the result vector fit in one
    // fixed-width vector register (at least 1).
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedSize() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;

    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand. Then move along the K axes and accumulate the columns. With
      // this the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K==0 iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          // When tiling, Result accumulates across tiles, so start the sum
          // from the previously computed partial result.
          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            // When the scalar operand is transposed, swap the row/column
            // indices used to fetch the scalar.
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand. Then move along the K axes and accumulate the rows. With this
      // the adds can be vectorized without reassociation.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (J + BlockSize > C)
            BlockSize /= 2;

          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(
                A.getVector(IsScalarMatrixTransposed ? K : I),
                IsScalarMatrixTransposed ? I : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }

  /// Ensure that the memory in \p Load does not alias \p Store by potentially
  /// copying it to a new location. This new or otherwise the original location
  /// is returned.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    // If we can statically determine noalias we're good.
    if (AA->isNoAlias(LoadLoc, StoreLoc))
      return Load->getPointerOperand();

    // Create code to check if the memory locations of the Load and Store
    // overlap and if they do, copy Load's operand to a new buffer.

    // First, create new blocks for the 2nd part of the check and the copy.
    BasicBlock *Check0 = MatMul->getParent();
    // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
    // DT. Manually collect dominator tree updates, to avoid unnecessary work,
    // as we adjust Check0 and Check1's branches.
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    for (BasicBlock *Succ : successors(Check0))
      DTUpdates.push_back({DT->Delete, Check0, Succ});

    BasicBlock *Check1 =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "alias_cont");
    BasicBlock *Copy =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "copy");
    BasicBlock *Fusion =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location. If the condition holds, they might overlap, otherwise they are
    // guaranteed to not overlap.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    // Begin + size, with nuw/nsw: a valid memory range cannot wrap.
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location. If the
    // condition holds, they alias, otherwise they are guaranteed to not
    // overlap.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy load operand to new alloca.
    Builder.SetInsertPoint(Copy, Copy->begin());
    auto *VT = cast<FixedVectorType>(Load->getType());
    // Use an array type for the alloca, to avoid potentially huge alignment
    // requirements for large vector types.
    auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
    AllocaInst *Alloca =
        Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
    Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());

    Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
                         Load->getAlign(), LoadLoc.Size.getValue());
    // Select the pointer to use in the fused code: the original pointer on
    // the no-alias paths, the copy's buffer otherwise.
    Builder.SetInsertPoint(Fusion, Fusion->begin());
    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(BC, Copy);

    // Adjust DT.
    DTUpdates.push_back({DT->Insert, Check0, Check1});
    DTUpdates.push_back({DT->Insert, Check0, Fusion});
    DTUpdates.push_back({DT->Insert, Check1, Copy});
    DTUpdates.push_back({DT->Insert, Check1, Fusion});
    DT->applyUpdates(DTUpdates);
    return PHI;
  }

  /// Return true if tiling the R x M x C multiply \p MatMul is expected to
  /// pay off, based on a simple register-pressure heuristic.
  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Elements per fixed-width vector register (at least 1).
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedSize() /
            EltType->getPrimitiveSizeInBits().getFixedSize(),
        1U);

    // Cost model for tiling
    //
    // For tiling to be beneficial, we need reuse either along the R or
    // the C axis.  We vectorize along the R axis so that means at least
    // 3 elements.
    // TODO: Also consider cost of copying if operands alias.
    if (R <= VF && C == 1)
      return false;
    // Then we need enough elements to exceed the number of vector
    // registers we have.  Note that this is an oversimplification since
    // fusing also takes some extra loads which may exceed the number of
    // reloads necessary.
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs >
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  }

  /// Build an R x C matrix whose columns are all-zero vectors of \p EltType.
  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumType));
    return Res;
  }

  /// Emit the multiply as a tiled loop nest (row/column/inner-K loops), with
  /// each iteration multiplying TileSize x TileSize tiles loaded from \p LPtr
  /// and \p RPtr and accumulating into PHI nodes across the inner loop.
  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Create the main tiling loop nest.
    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    Instruction *InsertI = cast<Instruction>(MatMul);
    BasicBlock *Start = InsertI->getParent();
    BasicBlock *End =
        SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
    IRBuilder<> Builder(MatMul);
    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

    Type *TileVecTy =
        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
    MatrixTy TileResult;
    // Insert in the inner loop header.
    Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
    // Create PHI nodes for the result columns to accumulate across iterations.
    SmallVector<PHINode *, 4> ColumnPhis;
    for (unsigned I = 0; I < TileSize; I++) {
      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
                       TI.RowLoopHeader->getSingleSuccessor());
      TileResult.addVector(Phi);
      ColumnPhis.push_back(Phi);
    }

    // Insert in the inner loop body, which computes
    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
    Builder.SetInsertPoint(InnerBody->getTerminator());
    // Load tiles of the operands.
    MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
                            {TileSize, TileSize}, EltType, Builder);
    MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
                            {TileSize, TileSize}, EltType, Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store result after the inner loop is done.
    Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.CurrentRow, TI.CurrentCol, EltType, Builder);

    // Close the accumulation cycle: feed each tile result back into its PHI.
    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);

    // Force unrolling of a few iterations of the inner loop, to make sure there
    // is enough work per iteration.
    // FIXME: The unroller should make this decision directly instead, but
    // currently the cost-model is not up to the task.
    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }

  /// Lower a {load, load} -> multiply -> store chain as a tiled multiply,
  /// either via runtime loops or fully unrolled tiles. Eliminated
  /// instructions are added to \p FusedInsts and erased.
  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
                      StoreInst *Store,
                      SmallPtrSetImpl<Instruction *> &FusedInsts) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Tiling only supported for column-major matrixes at the moment!");
    if (!isFusionProfitable(MatMul))
      return;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Guard against the loads aliasing the store by introducing runtime
    // checks and a copy if needed.
    Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
    Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
    Value *CPtr = Store->getPointerOperand();

    if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
      createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
    else {
      IRBuilder<> Builder(Store);
      // Fully unrolled tiling: iterate result tiles (I, J), accumulating the
      // partial products over K before storing each tile.
      for (unsigned J = 0; J < C; J += TileSize)
        for (unsigned I = 0; I < R; I += TileSize) {
          const unsigned TileR = std::min(R - I, unsigned(TileSize));
          const unsigned TileC = std::min(C - J, unsigned(TileSize));
          MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);

          for (unsigned K = 0; K < M; K += TileSize) {
            const unsigned TileM = std::min(M - K, unsigned(TileSize));
            MatrixTy A =
                loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
                           LShape, Builder.getInt64(I), Builder.getInt64(K),
                           {TileR, TileM}, EltType, Builder);
            MatrixTy B =
                loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
                           RShape, Builder.getInt64(K), Builder.getInt64(J),
                           {TileM, TileC}, EltType, Builder);
            emitMatrixMultiply(Res, A, B, Builder, true, false,
                               getFastMathFlags(MatMul));
          }
          storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
                      Builder.getInt64(I), Builder.getInt64(J), EltType,
                      Builder);
        }
    }

    // Mark eliminated instructions as fused and remove them.
    FusedInsts.insert(Store);
    FusedInsts.insert(MatMul);
    Store->eraseFromParent();
    MatMul->eraseFromParent();
    if (LoadOp0->hasNUses(0)) {
      FusedInsts.insert(LoadOp0);
      LoadOp0->eraseFromParent();
    }
    // The two operands may be the same load; don't erase it twice.
    if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
      FusedInsts.insert(LoadOp1);
      LoadOp1->eraseFromParent();
    }
  }

  /// Try to lower matrix multiply chains by fusing operations.
  ///
  /// Call finalizeLowering on lowered instructions. Instructions that are
  /// completely eliminated by fusion are added to \p FusedInsts.
  void LowerMatrixMultiplyFused(CallInst *MatMul,
                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // We can fold the transpose into the operand that is used to fetch
    // scalars: the second operand for column-major, the first for row-major.
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      // Use the transpose's operand directly; the folded transpose operand
      // has its shape dimensions swapped.
      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // included in the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains.  No need to call finalizeLowering
    // since the single store user will be lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      // Walk the address computation backwards; hoist side-effect-free
      // instructions that don't already dominate MatMul, and give up on
      // PHIs or anything that touches memory.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      // Hoist in dominance order so operands stay defined before uses.
      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }

  /// Lowers llvm.matrix.multiply.
1629480093f4SDimitry Andric void LowerMultiply(CallInst *MatMul) { 1630480093f4SDimitry Andric IRBuilder<> Builder(MatMul); 1631480093f4SDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1632480093f4SDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1633480093f4SDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1634480093f4SDimitry Andric 16355ffd83dbSDimitry Andric const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 16365ffd83dbSDimitry Andric const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 1637e8d8bef9SDimitry Andric assert(Lhs.getElementType() == Rhs.getElementType() && 1638e8d8bef9SDimitry Andric "Matrix multiply argument element types do not match."); 1639480093f4SDimitry Andric 1640480093f4SDimitry Andric const unsigned R = LShape.NumRows; 1641480093f4SDimitry Andric const unsigned C = RShape.NumColumns; 16425ffd83dbSDimitry Andric assert(LShape.NumColumns == RShape.NumRows); 1643480093f4SDimitry Andric 1644480093f4SDimitry Andric // Initialize the output 16455ffd83dbSDimitry Andric MatrixTy Result(R, C, EltType); 1646e8d8bef9SDimitry Andric assert(Lhs.getElementType() == Result.getElementType() && 1647e8d8bef9SDimitry Andric "Matrix multiply result element type does not match arguments."); 1648480093f4SDimitry Andric 1649fe6060f1SDimitry Andric emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false, 1650fe6060f1SDimitry Andric getFastMathFlags(MatMul)); 1651480093f4SDimitry Andric finalizeLowering(MatMul, Result, Builder); 1652480093f4SDimitry Andric } 1653480093f4SDimitry Andric 1654480093f4SDimitry Andric /// Lowers llvm.matrix.transpose. 
void LowerTranspose(CallInst *Inst) {
  MatrixTy Result;
  IRBuilder<> Builder(Inst);
  Value *InputVal = Inst->getArgOperand(0);
  VectorType *VectorTy = cast<VectorType>(InputVal->getType());
  // Shape of the *input* matrix; args 1 and 2 are the constant dimensions.
  ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
  MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

  // The transposed matrix swaps the roles of rows and columns: in
  // column-major layout the result has ArgShape.NumRows vectors of
  // ArgShape.NumColumns elements each (and vice versa for row-major).
  const unsigned NewNumVecs =
      InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
  const unsigned NewNumElts =
      InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

  for (unsigned I = 0; I < NewNumVecs; ++I) {
    // Build a single result vector. First initialize it.
    Value *ResultVector = PoisonValue::get(
        FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
    // Go through the old elements and insert it into the resulting vector.
    for (auto J : enumerate(InputMatrix.vectors())) {
      Value *Elt = Builder.CreateExtractElement(J.value(), I);
      // Row and column indices are transposed.
      ResultVector =
          Builder.CreateInsertElement(ResultVector, Elt, J.index());
    }
    Result.addVector(ResultVector);
  }

  // TODO: Improve estimate of operations needed for transposes. Currently we
  // just count the insertelement/extractelement instructions, but do not
  // account for later simplifications/combines.
  finalizeLowering(
      Inst,
      Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
          .addNumExposedTransposes(1),
      Builder);
}

/// Lower load instructions, if shape information is available.
bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
  // Only loads with a known matrix shape are lowered here.
  auto I = ShapeMap.find(Inst);
  if (I == ShapeMap.end())
    return false;

  LowerLoad(Inst, Ptr, Inst->getAlign(),
            Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
            I->second);
  return true;
}

/// Lower store instructions, if shape information is available.
bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                IRBuilder<> &Builder) {
  // The shape is tracked on the stored *value*, not on the store itself.
  auto I = ShapeMap.find(StoredVal);
  if (I == ShapeMap.end())
    return false;

  LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
             Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
             I->second);
  return true;
}

/// Lower binary operators, if shape information is available.
1717480093f4SDimitry Andric bool VisitBinaryOperator(BinaryOperator *Inst) { 1718480093f4SDimitry Andric auto I = ShapeMap.find(Inst); 1719480093f4SDimitry Andric if (I == ShapeMap.end()) 1720480093f4SDimitry Andric return false; 1721480093f4SDimitry Andric 1722480093f4SDimitry Andric Value *Lhs = Inst->getOperand(0); 1723480093f4SDimitry Andric Value *Rhs = Inst->getOperand(1); 1724480093f4SDimitry Andric 1725480093f4SDimitry Andric IRBuilder<> Builder(Inst); 1726480093f4SDimitry Andric ShapeInfo &Shape = I->second; 1727480093f4SDimitry Andric 17285ffd83dbSDimitry Andric MatrixTy Result; 17295ffd83dbSDimitry Andric MatrixTy A = getMatrix(Lhs, Shape, Builder); 17305ffd83dbSDimitry Andric MatrixTy B = getMatrix(Rhs, Shape, Builder); 17315ffd83dbSDimitry Andric assert(A.isColumnMajor() == B.isColumnMajor() && 17325ffd83dbSDimitry Andric Result.isColumnMajor() == A.isColumnMajor() && 17335ffd83dbSDimitry Andric "operands must agree on matrix layout"); 1734480093f4SDimitry Andric 1735fe6060f1SDimitry Andric Builder.setFastMathFlags(getFastMathFlags(Inst)); 1736fe6060f1SDimitry Andric 17375ffd83dbSDimitry Andric // Helper to perform binary op on vectors. 
17385ffd83dbSDimitry Andric auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) { 1739480093f4SDimitry Andric switch (Inst->getOpcode()) { 1740480093f4SDimitry Andric case Instruction::Add: 1741480093f4SDimitry Andric return Builder.CreateAdd(LHS, RHS); 1742480093f4SDimitry Andric case Instruction::Mul: 1743480093f4SDimitry Andric return Builder.CreateMul(LHS, RHS); 1744480093f4SDimitry Andric case Instruction::Sub: 1745480093f4SDimitry Andric return Builder.CreateSub(LHS, RHS); 1746480093f4SDimitry Andric case Instruction::FAdd: 1747480093f4SDimitry Andric return Builder.CreateFAdd(LHS, RHS); 1748480093f4SDimitry Andric case Instruction::FMul: 1749480093f4SDimitry Andric return Builder.CreateFMul(LHS, RHS); 1750480093f4SDimitry Andric case Instruction::FSub: 1751480093f4SDimitry Andric return Builder.CreateFSub(LHS, RHS); 1752480093f4SDimitry Andric default: 1753480093f4SDimitry Andric llvm_unreachable("Unsupported binary operator for matrix"); 1754480093f4SDimitry Andric } 1755480093f4SDimitry Andric }; 1756480093f4SDimitry Andric 17575ffd83dbSDimitry Andric for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 17585ffd83dbSDimitry Andric Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I))); 17595ffd83dbSDimitry Andric 17605ffd83dbSDimitry Andric finalizeLowering(Inst, 17615ffd83dbSDimitry Andric Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 17625ffd83dbSDimitry Andric Result.getNumVectors()), 17635ffd83dbSDimitry Andric Builder); 1764480093f4SDimitry Andric return true; 1765480093f4SDimitry Andric } 17665ffd83dbSDimitry Andric 1767e8d8bef9SDimitry Andric /// Lower unary operators, if shape information is available. 
1768e8d8bef9SDimitry Andric bool VisitUnaryOperator(UnaryOperator *Inst) { 1769e8d8bef9SDimitry Andric auto I = ShapeMap.find(Inst); 1770e8d8bef9SDimitry Andric if (I == ShapeMap.end()) 1771e8d8bef9SDimitry Andric return false; 1772e8d8bef9SDimitry Andric 1773e8d8bef9SDimitry Andric Value *Op = Inst->getOperand(0); 1774e8d8bef9SDimitry Andric 1775e8d8bef9SDimitry Andric IRBuilder<> Builder(Inst); 1776e8d8bef9SDimitry Andric ShapeInfo &Shape = I->second; 1777e8d8bef9SDimitry Andric 1778e8d8bef9SDimitry Andric MatrixTy Result; 1779e8d8bef9SDimitry Andric MatrixTy M = getMatrix(Op, Shape, Builder); 1780e8d8bef9SDimitry Andric 1781fe6060f1SDimitry Andric Builder.setFastMathFlags(getFastMathFlags(Inst)); 1782fe6060f1SDimitry Andric 1783e8d8bef9SDimitry Andric // Helper to perform unary op on vectors. 1784e8d8bef9SDimitry Andric auto BuildVectorOp = [&Builder, Inst](Value *Op) { 1785e8d8bef9SDimitry Andric switch (Inst->getOpcode()) { 1786e8d8bef9SDimitry Andric case Instruction::FNeg: 1787e8d8bef9SDimitry Andric return Builder.CreateFNeg(Op); 1788e8d8bef9SDimitry Andric default: 1789e8d8bef9SDimitry Andric llvm_unreachable("Unsupported unary operator for matrix"); 1790e8d8bef9SDimitry Andric } 1791e8d8bef9SDimitry Andric }; 1792e8d8bef9SDimitry Andric 1793e8d8bef9SDimitry Andric for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 1794e8d8bef9SDimitry Andric Result.addVector(BuildVectorOp(M.getVector(I))); 1795e8d8bef9SDimitry Andric 1796e8d8bef9SDimitry Andric finalizeLowering(Inst, 1797e8d8bef9SDimitry Andric Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 1798e8d8bef9SDimitry Andric Result.getNumVectors()), 1799e8d8bef9SDimitry Andric Builder); 1800e8d8bef9SDimitry Andric return true; 1801e8d8bef9SDimitry Andric } 1802e8d8bef9SDimitry Andric 18035ffd83dbSDimitry Andric /// Helper to linearize a matrix expression tree into a string. 
/// Currently matrix expressions are linearized by starting at an expression
/// leaf and linearizing bottom up.
struct ExprLinearizer {
  // Start a new output line once the current one grows past this length.
  unsigned LengthToBreak = 100;
  std::string Str;
  raw_string_ostream Stream;
  unsigned LineLength = 0;
  const DataLayout &DL;

  /// Mapping from instructions to matrixes. It is used to identify
  /// matrix instructions.
  const MapVector<Value *, MatrixTy> &Inst2Matrix;

  /// Mapping from values to the leaves of all expressions that the value is
  /// part of.
  const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

  /// Set of matrix expressions in the scope of a given DISubprogram.
  const SmallSetVector<Value *, 32> &ExprsInSubprogram;

  /// Leaf node of the expression to linearize.
  Value *Leaf;

  /// Used to keep track of sub-expressions that get reused while linearizing
  /// the expression. Re-used sub-expressions are marked as (reused).
  SmallPtrSet<Value *, 8> ReusedExprs;

  ExprLinearizer(const DataLayout &DL,
                 const MapVector<Value *, MatrixTy> &Inst2Matrix,
                 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                 Value *Leaf)
      : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
        ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

  /// Emit \p N spaces and account for them in the current line length.
  void indent(unsigned N) {
    LineLength += N;
    for (unsigned i = 0; i < N; i++)
      Stream << " ";
  }

  void lineBreak() {
    Stream << "\n";
    LineLength = 0;
  }

  /// Break the line if it got too long; indent the start of a fresh line.
  void maybeIndent(unsigned Indent) {
    if (LineLength >= LengthToBreak)
      lineBreak();

    if (LineLength == 0)
      indent(Indent);
  }

  void write(StringRef S) {
    LineLength += S.size();
    Stream << S;
  }

  /// Look through loads and pointer arithmetic to find the underlying object
  /// \p V ultimately refers to (following loaded pointers transitively).
  Value *getUnderlyingObjectThroughLoads(Value *V) {
    if (Value *Ptr = getPointerOperand(V))
      return getUnderlyingObjectThroughLoads(Ptr);
    else if (V->getType()->isPointerTy())
      return getUnderlyingObject(V);
    return V;
  }

  /// Returns true if \p V is a matrix value in the given subprogram.
  bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

  /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
  /// \p SS.
  void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
    auto M = Inst2Matrix.find(V);
    if (M == Inst2Matrix.end())
      SS << "unknown";
    else {
      SS << M->second.getNumRows();
      SS << "x";
      SS << M->second.getNumColumns();
    }
  }

  /// Write the called function name. Handles calls to llvm.matrix.*
  /// specially: we write the name, followed by the dimensions of the input
  /// matrixes, followed by the scalar type name.
  void writeFnName(CallInst *CI) {
    if (!CI->getCalledFunction())
      write("<no called fn>");
    else {
      StringRef Name = CI->getCalledFunction()->getName();
      if (!Name.startswith("llvm.matrix")) {
        write(Name);
        return;
      }
      auto *II = cast<IntrinsicInst>(CI);
      // Drop the "llvm.matrix." prefix; the shape/type suffix is appended
      // below per intrinsic.
      write(Intrinsic::getBaseName(II->getIntrinsicID())
                .drop_front(StringRef("llvm.matrix.").size()));
      write(".");
      std::string Tmp;
      raw_string_ostream SS(Tmp);

      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
        prettyPrintMatrixType(II->getOperand(0), SS);
        SS << ".";
        prettyPrintMatrixType(II->getOperand(1), SS);
        SS << "." << *II->getType()->getScalarType();
        break;
      case Intrinsic::matrix_transpose:
        prettyPrintMatrixType(II->getOperand(0), SS);
        SS << "." << *II->getType()->getScalarType();
        break;
      case Intrinsic::matrix_column_major_load:
        prettyPrintMatrixType(II, SS);
        SS << "." << *II->getType()->getScalarType();
        break;
      case Intrinsic::matrix_column_major_store:
        prettyPrintMatrixType(II->getOperand(0), SS);
        SS << "."
           << *II->getOperand(0)->getType()->getScalarType();
        break;
      default:
        llvm_unreachable("Unhandled case");
      }
      SS.flush();
      write(Tmp);
    }
  }

  /// Returns the number of trailing shape arguments of the matrix intrinsic
  /// called by \p CI, or 0 for non-intrinsic calls.
  unsigned getNumShapeArgs(CallInst *CI) const {
    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
        return 3;
      case Intrinsic::matrix_transpose:
        return 2;
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return 3;
      default:
        return 0;
      }
    }
    return 0;
  }

  /// Special printing for values: for pointers, we print if they refer to an
  /// (function) external address or a stack address, for other values we
  /// either print the constant or "scalar"/"matrix" for other values.
19535ffd83dbSDimitry Andric void write(Value *V) { 19545ffd83dbSDimitry Andric V = getUnderlyingObjectThroughLoads(V); 19555ffd83dbSDimitry Andric if (V->getType()->isPointerTy()) { 19565ffd83dbSDimitry Andric if (isa<AllocaInst>(V)) { 19575ffd83dbSDimitry Andric Stream << "stack addr"; 19585ffd83dbSDimitry Andric LineLength += StringRef("stack addr").size(); 19595ffd83dbSDimitry Andric } else { 19605ffd83dbSDimitry Andric Stream << "addr"; 19615ffd83dbSDimitry Andric LineLength += StringRef("addr").size(); 19625ffd83dbSDimitry Andric } 19635ffd83dbSDimitry Andric if (!V->getName().empty()) { 19645ffd83dbSDimitry Andric Stream << " %" << V->getName() << ""; 19655ffd83dbSDimitry Andric LineLength += V->getName().size() + 2; 19665ffd83dbSDimitry Andric } 19675ffd83dbSDimitry Andric return; 19685ffd83dbSDimitry Andric } 19695ffd83dbSDimitry Andric 19705ffd83dbSDimitry Andric std::string Tmp; 19715ffd83dbSDimitry Andric raw_string_ostream TmpStream(Tmp); 19725ffd83dbSDimitry Andric 19735ffd83dbSDimitry Andric if (auto *CI = dyn_cast<ConstantInt>(V)) 19745ffd83dbSDimitry Andric TmpStream << CI->getValue(); 19755ffd83dbSDimitry Andric else if (isa<Constant>(V)) 19765ffd83dbSDimitry Andric TmpStream << "constant"; 19775ffd83dbSDimitry Andric else { 19785ffd83dbSDimitry Andric if (isMatrix(V)) 19795ffd83dbSDimitry Andric TmpStream << "matrix"; 19805ffd83dbSDimitry Andric else 19815ffd83dbSDimitry Andric TmpStream << "scalar"; 19825ffd83dbSDimitry Andric } 19835ffd83dbSDimitry Andric TmpStream.flush(); 19845ffd83dbSDimitry Andric Tmp = std::string(StringRef(Tmp).trim()); 19855ffd83dbSDimitry Andric LineLength += Tmp.size(); 19865ffd83dbSDimitry Andric Stream << Tmp; 19875ffd83dbSDimitry Andric } 19885ffd83dbSDimitry Andric 19895ffd83dbSDimitry Andric /// Linearize expression \p Expr starting at an indentation of \p Indent. 
/// Expressions that are re-used multiple times are prefixed with (reused)
/// at the re-used root instruction.
void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                   bool ParentShared) {
  auto *I = cast<Instruction>(Expr);
  maybeIndent(Indent);
  SmallVector<Value *, 8> Ops;

  // Is Expr shared with other expression leaves?
  bool ExprShared = false;

  // Deal with shared subtrees. Mark them as shared, if required.
  if (!ParentShared) {
    auto SI = Shared.find(Expr);
    assert(SI != Shared.end() && SI->second.count(Leaf));

    // Cross-reference every other leaf that also contains this expression.
    for (Value *S : SI->second) {
      if (S == Leaf)
        continue;
      DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
      write("shared with remark at line " + std::to_string(DL.getLine()) +
            " column " + std::to_string(DL.getCol()) + " (");
    }
    ExprShared = SI->second.size() > 1;
  }

  // Only the topmost node of a reused subtree gets the (reused) marker.
  bool Reused = !ReusedExprs.insert(Expr).second;
  if (Reused && !ParentReused)
    write("(reused) ");

  if (auto *CI = dyn_cast<CallInst>(I)) {
    writeFnName(CI);

    // Drop the trailing shape arguments; they are not part of the
    // user-visible expression.
    Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
  } else if (isa<BitCastInst>(Expr)) {
    // Special case bitcasts, which are used to materialize matrixes from
    // non-matrix ops.
    write("matrix");
    return;
  } else {
    Ops.append(I->value_op_begin(), I->value_op_end());
    write(std::string(I->getOpcodeName()));
  }

  write(std::string("("));

  unsigned NumOpsToBreak = 1;
  if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
    NumOpsToBreak = 2;

  for (Value *Op : Ops) {
    if (Ops.size() > NumOpsToBreak)
      lineBreak();

    maybeIndent(Indent + 1);
    if (isMatrix(Op))
      linearizeExpr(Op, Indent + 1, Reused, ExprShared);
    else
      write(Op);
    if (Op != Ops.back())
      write(", ");
  }

  write(")");
}

/// Returns the linearized expression accumulated so far.
const std::string &getResult() {
  Stream.flush();
  return Str;
}
};

/// Generate remarks for matrix operations in a function. To generate remarks
/// for matrix expressions, the following approach is used:
/// 1.
///    Use the inlined-at debug information to group matrix operations to the
///    DISubprograms they are contained in.
/// 2. Collect leaves of matrix expressions (done in
///    RemarkGenerator::getExpressionLeaves) for each subprogram - expression
///    mapping. Leaves are lowered matrix instructions without other matrix
///    users (like stores) in the current subprogram.
/// 3. For each leaf, create a remark containing a linearized version of the
///    matrix expression. The expression is linearized by a recursive
///    bottom-up traversal of the matrix operands, starting at a leaf. Note
///    that multiple leaves can share sub-expressions. Shared subexpressions
///    are explicitly marked as shared().
struct RemarkGenerator {
  const MapVector<Value *, MatrixTy> &Inst2Matrix;
  OptimizationRemarkEmitter &ORE;
  Function &Func;
  const DataLayout &DL;

  RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                  OptimizationRemarkEmitter &ORE, Function &Func)
      : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
        DL(Func.getParent()->getDataLayout()) {}

  /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
  /// instructions in Inst2Matrix returning void or without any users in
  /// \p ExprsInSubprogram. Currently that should only include stores.
  SmallVector<Value *, 4>
  getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
    SmallVector<Value *, 4> Leaves;
    for (auto *Expr : ExprsInSubprogram)
      if (Expr->getType()->isVoidTy() ||
          !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
            return ExprsInSubprogram.count(U);
          }))
        Leaves.push_back(Expr);
    return Leaves;
  }

  /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
  /// to all visited expressions in \p Shared. Limit the matrix operations to
  /// the ones in \p ExprsInSubprogram.
  void collectSharedInfo(Value *Leaf, Value *V,
                         const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                         DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

    if (!ExprsInSubprogram.count(V))
      return;

    auto I = Shared.insert({V, {}});
    I.first->second.insert(Leaf);

    for (Value *Op : cast<Instruction>(V)->operand_values())
      collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
  }

  /// Calculate the number of exclusive and shared op counts for expression
  /// starting at \p V. Expressions used multiple times are counted once.
  /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
  std::pair<OpInfoTy, OpInfoTy>
  sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
             const SmallSetVector<Value *, 32> &ExprsInSubprogram,
             DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
    if (!ExprsInSubprogram.count(Root))
      return {};

    // Already counted this expression. Stop.
    if (!ReusedExprs.insert(Root).second)
      return {};

    OpInfoTy SharedCount;
    OpInfoTy Count;

    // Ops of an expression reachable from a single leaf are exclusive to
    // that leaf's remark; otherwise they are accounted as shared.
    auto I = Shared.find(Root);
    auto CM = Inst2Matrix.find(Root);
    if (I->second.size() == 1)
      Count = CM->second.getOpInfo();
    else
      SharedCount = CM->second.getOpInfo();

    for (Value *Op : cast<Instruction>(Root)->operand_values()) {
      auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
      Count += C.first;
      SharedCount += C.second;
    }
    return {Count, SharedCount};
  }

  void emitRemarks() {
    if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
      return;

    // Map matrix operations to their containing subprograms, by traversing
    // the inlinedAt chain. If the function does not have a DISubprogram, we
    // only map them to the containing function.
    MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
    for (auto &KV : Inst2Matrix) {
      if (Func.getSubprogram()) {
        auto *I = cast<Instruction>(KV.first);
        DILocation *Context = I->getDebugLoc();
        // Record the instruction once per subprogram in its inlining chain.
        while (Context) {
          auto I =
              Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
          I.first->second.push_back(KV.first);
          Context = DebugLoc(Context).getInlinedAt();
        }
      } else {
        auto I = Subprog2Exprs.insert({nullptr, {}});
        I.first->second.push_back(KV.first);
      }
    }
    for (auto &KV : Subprog2Exprs) {
      SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                    KV.second.end());
      auto Leaves = getExpressionLeaves(ExprsInSubprogram);

      DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
      for (Value *Leaf : Leaves)
        collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

      // Generate remarks for each leaf.
      for (auto *L : Leaves) {

        // Use the debug location of the leaf in the current subprogram, if
        // one exists in its inlining chain.
        DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
        DILocation *Context = cast<Instruction>(L)->getDebugLoc();
        while (Context) {
          if (getSubprogram(Context->getScope()) == KV.first) {
            Loc = Context;
            break;
          }
          Context = DebugLoc(Context).getInlinedAt();
        }

        SmallPtrSet<Value *, 8> ReusedExprs;
        OpInfoTy Counts, SharedCounts;
        std::tie(Counts, SharedCounts) =
            sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

        OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                               cast<Instruction>(L)->getParent());

        Rem << "Lowered with ";
        Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
            << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
            << ore::NV("NumComputeOps", Counts.NumComputeOps)
            << " compute ops, "
            << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
            << " exposed transposes";

        if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
            SharedCounts.NumComputeOps > 0) {
          Rem << ",\nadditionally "
              << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
              << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
              << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
              << " compute ops"
              << " are shared with other expressions";
        }

        Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
        ORE.emit(Rem);
      }
    }
  }

  /// Produce the textual, linearized form of the expression rooted at \p L.
  std::string
  linearize(Value *L,
            const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
            const SmallSetVector<Value *, 32> &ExprsInSubprogram,
            const DataLayout &DL) {
    ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
    Lin.linearizeExpr(L, 0, false, false);
    return Lin.getResult();
  }
};
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  // In minimal mode no extra analyses are requested; fusion and remark
  // emission are skipped in that case.
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI =
&AM.getResult<LoopAnalysis>(F); 2253e8d8bef9SDimitry Andric } 22545ffd83dbSDimitry Andric 22555ffd83dbSDimitry Andric LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE); 2256480093f4SDimitry Andric if (LMT.Visit()) { 2257480093f4SDimitry Andric PreservedAnalyses PA; 2258e8d8bef9SDimitry Andric if (!Minimal) { 2259e8d8bef9SDimitry Andric PA.preserve<LoopAnalysis>(); 2260e8d8bef9SDimitry Andric PA.preserve<DominatorTreeAnalysis>(); 2261e8d8bef9SDimitry Andric } 2262480093f4SDimitry Andric return PA; 2263480093f4SDimitry Andric } 2264480093f4SDimitry Andric return PreservedAnalyses::all(); 2265480093f4SDimitry Andric } 2266480093f4SDimitry Andric 2267349cc55cSDimitry Andric void LowerMatrixIntrinsicsPass::printPipeline( 2268349cc55cSDimitry Andric raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { 2269349cc55cSDimitry Andric static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline( 2270349cc55cSDimitry Andric OS, MapClassName2PassName); 2271349cc55cSDimitry Andric OS << "<"; 2272349cc55cSDimitry Andric if (Minimal) 2273349cc55cSDimitry Andric OS << "minimal"; 2274349cc55cSDimitry Andric OS << ">"; 2275349cc55cSDimitry Andric } 2276349cc55cSDimitry Andric 2277480093f4SDimitry Andric namespace { 2278480093f4SDimitry Andric 2279480093f4SDimitry Andric class LowerMatrixIntrinsicsLegacyPass : public FunctionPass { 2280480093f4SDimitry Andric public: 2281480093f4SDimitry Andric static char ID; 2282480093f4SDimitry Andric 2283480093f4SDimitry Andric LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) { 2284480093f4SDimitry Andric initializeLowerMatrixIntrinsicsLegacyPassPass( 2285480093f4SDimitry Andric *PassRegistry::getPassRegistry()); 2286480093f4SDimitry Andric } 2287480093f4SDimitry Andric 2288480093f4SDimitry Andric bool runOnFunction(Function &F) override { 22895ffd83dbSDimitry Andric auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 22905ffd83dbSDimitry Andric auto &ORE = 
getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); 22915ffd83dbSDimitry Andric auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); 22925ffd83dbSDimitry Andric auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 22935ffd83dbSDimitry Andric auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 2294e8d8bef9SDimitry Andric LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE); 2295480093f4SDimitry Andric bool C = LMT.Visit(); 2296480093f4SDimitry Andric return C; 2297480093f4SDimitry Andric } 2298480093f4SDimitry Andric 2299480093f4SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 2300480093f4SDimitry Andric AU.addRequired<TargetTransformInfoWrapperPass>(); 23015ffd83dbSDimitry Andric AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); 23025ffd83dbSDimitry Andric AU.addRequired<AAResultsWrapperPass>(); 23035ffd83dbSDimitry Andric AU.addRequired<DominatorTreeWrapperPass>(); 23045ffd83dbSDimitry Andric AU.addPreserved<DominatorTreeWrapperPass>(); 23055ffd83dbSDimitry Andric AU.addRequired<LoopInfoWrapperPass>(); 23065ffd83dbSDimitry Andric AU.addPreserved<LoopInfoWrapperPass>(); 2307480093f4SDimitry Andric } 2308480093f4SDimitry Andric }; 2309480093f4SDimitry Andric } // namespace 2310480093f4SDimitry Andric 2311480093f4SDimitry Andric static const char pass_name[] = "Lower the matrix intrinsics"; 2312480093f4SDimitry Andric char LowerMatrixIntrinsicsLegacyPass::ID = 0; 2313480093f4SDimitry Andric INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, 2314480093f4SDimitry Andric false, false) 23155ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) 23165ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) 23175ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 23185ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 2319480093f4SDimitry Andric 
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, 2320480093f4SDimitry Andric false, false) 2321480093f4SDimitry Andric 2322480093f4SDimitry Andric Pass *llvm::createLowerMatrixIntrinsicsPass() { 2323480093f4SDimitry Andric return new LowerMatrixIntrinsicsLegacyPass(); 2324480093f4SDimitry Andric } 2325e8d8bef9SDimitry Andric 2326e8d8bef9SDimitry Andric namespace { 2327e8d8bef9SDimitry Andric 2328e8d8bef9SDimitry Andric /// A lightweight version of the matrix lowering pass that only requires TTI. 2329e8d8bef9SDimitry Andric /// Advanced features that require DT, AA or ORE like tiling are disabled. This 2330e8d8bef9SDimitry Andric /// is used to lower matrix intrinsics if the main lowering pass is not run, for 2331e8d8bef9SDimitry Andric /// example with -O0. 2332e8d8bef9SDimitry Andric class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass { 2333e8d8bef9SDimitry Andric public: 2334e8d8bef9SDimitry Andric static char ID; 2335e8d8bef9SDimitry Andric 2336e8d8bef9SDimitry Andric LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) { 2337e8d8bef9SDimitry Andric initializeLowerMatrixIntrinsicsMinimalLegacyPassPass( 2338e8d8bef9SDimitry Andric *PassRegistry::getPassRegistry()); 2339e8d8bef9SDimitry Andric } 2340e8d8bef9SDimitry Andric 2341e8d8bef9SDimitry Andric bool runOnFunction(Function &F) override { 2342e8d8bef9SDimitry Andric auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 2343e8d8bef9SDimitry Andric LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr); 2344e8d8bef9SDimitry Andric bool C = LMT.Visit(); 2345e8d8bef9SDimitry Andric return C; 2346e8d8bef9SDimitry Andric } 2347e8d8bef9SDimitry Andric 2348e8d8bef9SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 2349e8d8bef9SDimitry Andric AU.addRequired<TargetTransformInfoWrapperPass>(); 2350e8d8bef9SDimitry Andric AU.setPreservesCFG(); 2351e8d8bef9SDimitry Andric } 2352e8d8bef9SDimitry Andric }; 
2353e8d8bef9SDimitry Andric } // namespace 2354e8d8bef9SDimitry Andric 2355e8d8bef9SDimitry Andric static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)"; 2356e8d8bef9SDimitry Andric char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0; 2357e8d8bef9SDimitry Andric INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass, 2358e8d8bef9SDimitry Andric "lower-matrix-intrinsics-minimal", pass_name_minimal, 2359e8d8bef9SDimitry Andric false, false) 2360e8d8bef9SDimitry Andric INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass, 2361e8d8bef9SDimitry Andric "lower-matrix-intrinsics-minimal", pass_name_minimal, false, 2362e8d8bef9SDimitry Andric false) 2363e8d8bef9SDimitry Andric 2364e8d8bef9SDimitry Andric Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() { 2365e8d8bef9SDimitry Andric return new LowerMatrixIntrinsicsMinimalLegacyPass(); 2366e8d8bef9SDimitry Andric } 2367