//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//   * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//     transposed.
//   * Improve cost-modeling, e.g. choose different number of rows/columns
//     for tiles, consider cost of copies on alias.
17480093f4SDimitry Andric // 18480093f4SDimitry Andric //===----------------------------------------------------------------------===// 19480093f4SDimitry Andric 20480093f4SDimitry Andric #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" 21480093f4SDimitry Andric #include "llvm/ADT/GraphTraits.h" 22480093f4SDimitry Andric #include "llvm/ADT/PostOrderIterator.h" 23480093f4SDimitry Andric #include "llvm/ADT/SmallVector.h" 245ffd83dbSDimitry Andric #include "llvm/Analysis/AliasAnalysis.h" 255ffd83dbSDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h" 265ffd83dbSDimitry Andric #include "llvm/Analysis/OptimizationRemarkEmitter.h" 27480093f4SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h" 285ffd83dbSDimitry Andric #include "llvm/Analysis/ValueTracking.h" 29480093f4SDimitry Andric #include "llvm/Analysis/VectorUtils.h" 30480093f4SDimitry Andric #include "llvm/IR/CFG.h" 31480093f4SDimitry Andric #include "llvm/IR/DataLayout.h" 325ffd83dbSDimitry Andric #include "llvm/IR/DebugInfoMetadata.h" 33480093f4SDimitry Andric #include "llvm/IR/Function.h" 34480093f4SDimitry Andric #include "llvm/IR/IRBuilder.h" 35480093f4SDimitry Andric #include "llvm/IR/Instructions.h" 36480093f4SDimitry Andric #include "llvm/IR/IntrinsicInst.h" 37480093f4SDimitry Andric #include "llvm/IR/PatternMatch.h" 38480093f4SDimitry Andric #include "llvm/InitializePasses.h" 39480093f4SDimitry Andric #include "llvm/Pass.h" 405ffd83dbSDimitry Andric #include "llvm/Support/Alignment.h" 415ffd83dbSDimitry Andric #include "llvm/Support/CommandLine.h" 42480093f4SDimitry Andric #include "llvm/Support/Debug.h" 43480093f4SDimitry Andric #include "llvm/Transforms/Scalar.h" 445ffd83dbSDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h" 45*e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h" 46*e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/MatrixUtils.h" 47480093f4SDimitry Andric 48480093f4SDimitry Andric using namespace llvm; 49480093f4SDimitry Andric 
using namespace PatternMatch; 50480093f4SDimitry Andric 51480093f4SDimitry Andric #define DEBUG_TYPE "lower-matrix-intrinsics" 52480093f4SDimitry Andric 535ffd83dbSDimitry Andric static cl::opt<bool> EnableShapePropagation( 545ffd83dbSDimitry Andric "matrix-propagate-shape", cl::init(true), cl::Hidden, 555ffd83dbSDimitry Andric cl::desc("Enable/disable shape propagation from matrix intrinsics to other " 565ffd83dbSDimitry Andric "instructions.")); 57480093f4SDimitry Andric 585ffd83dbSDimitry Andric static cl::opt<bool> 595ffd83dbSDimitry Andric FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, 605ffd83dbSDimitry Andric cl::desc("Enable/disable fusing matrix instructions.")); 615ffd83dbSDimitry Andric // TODO: Allow and use non-square tiles. 625ffd83dbSDimitry Andric static cl::opt<unsigned> TileSize( 635ffd83dbSDimitry Andric "fuse-matrix-tile-size", cl::init(4), cl::Hidden, 645ffd83dbSDimitry Andric cl::desc( 655ffd83dbSDimitry Andric "Tile size for matrix instruction fusion using square-shaped tiles.")); 66*e8d8bef9SDimitry Andric static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false), 67*e8d8bef9SDimitry Andric cl::Hidden, 68*e8d8bef9SDimitry Andric cl::desc("Generate loop nest for tiling.")); 695ffd83dbSDimitry Andric static cl::opt<bool> ForceFusion( 705ffd83dbSDimitry Andric "force-fuse-matrix", cl::init(false), cl::Hidden, 715ffd83dbSDimitry Andric cl::desc("Force matrix instruction fusion even if not profitable.")); 72480093f4SDimitry Andric static cl::opt<bool> AllowContractEnabled( 73480093f4SDimitry Andric "matrix-allow-contract", cl::init(false), cl::Hidden, 74480093f4SDimitry Andric cl::desc("Allow the use of FMAs if available and profitable. 
This may " 75480093f4SDimitry Andric "result in different results, due to less rounding error.")); 76480093f4SDimitry Andric 775ffd83dbSDimitry Andric enum class MatrixLayoutTy { ColumnMajor, RowMajor }; 785ffd83dbSDimitry Andric 795ffd83dbSDimitry Andric static cl::opt<MatrixLayoutTy> MatrixLayout( 805ffd83dbSDimitry Andric "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), 815ffd83dbSDimitry Andric cl::desc("Sets the default matrix layout"), 825ffd83dbSDimitry Andric cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", 835ffd83dbSDimitry Andric "Use column-major layout"), 845ffd83dbSDimitry Andric clEnumValN(MatrixLayoutTy::RowMajor, "row-major", 855ffd83dbSDimitry Andric "Use row-major layout"))); 865ffd83dbSDimitry Andric 875ffd83dbSDimitry Andric /// Helper function to either return Scope, if it is a subprogram or the 885ffd83dbSDimitry Andric /// attached subprogram for a local scope. 895ffd83dbSDimitry Andric static DISubprogram *getSubprogram(DIScope *Scope) { 905ffd83dbSDimitry Andric if (auto *Subprogram = dyn_cast<DISubprogram>(Scope)) 915ffd83dbSDimitry Andric return Subprogram; 925ffd83dbSDimitry Andric return cast<DILocalScope>(Scope)->getSubprogram(); 935ffd83dbSDimitry Andric } 945ffd83dbSDimitry Andric 95480093f4SDimitry Andric namespace { 96480093f4SDimitry Andric 975ffd83dbSDimitry Andric // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute 985ffd83dbSDimitry Andric // the start address of vector \p VecIdx with type (\p EltType x \p NumElements) 995ffd83dbSDimitry Andric // assuming \p Stride elements between start two consecutive vectors. 1005ffd83dbSDimitry Andric // \p Stride must be >= \p NumElements. 1015ffd83dbSDimitry Andric // For column-major matrixes, the function computes the address of a column 1025ffd83dbSDimitry Andric // vectors and \p NumElements must be set to the number of elements in a column 1035ffd83dbSDimitry Andric // (= number of rows of the matrix). 
// For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0       1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
129480093f4SDimitry Andric // 130480093f4SDimitry Andric // v_0_0 v_0_1 {v_0_2 {v_0_3 131480093f4SDimitry Andric // Base Col 1 Col 2 132480093f4SDimitry Andric // | | | 133480093f4SDimitry Andric // v_1_0 |v_1_1 |v_1_2 |v_1_3 134480093f4SDimitry Andric // v_2_0 |v_2_1 |v_2_2 |v_2_3 135480093f4SDimitry Andric // v_3_0 {v_3_1 {v_3_2 v_3_3 136480093f4SDimitry Andric // 1375ffd83dbSDimitry Andric Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, 1385ffd83dbSDimitry Andric unsigned NumElements, Type *EltType, 139480093f4SDimitry Andric IRBuilder<> &Builder) { 140480093f4SDimitry Andric 141480093f4SDimitry Andric assert((!isa<ConstantInt>(Stride) || 1425ffd83dbSDimitry Andric cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && 1435ffd83dbSDimitry Andric "Stride must be >= the number of elements in the result vector."); 144480093f4SDimitry Andric unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 145480093f4SDimitry Andric 1465ffd83dbSDimitry Andric // Compute the start of the vector with index VecIdx as VecIdx * Stride. 1475ffd83dbSDimitry Andric Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); 148480093f4SDimitry Andric 1495ffd83dbSDimitry Andric // Get pointer to the start of the selected vector. Skip GEP creation, 1505ffd83dbSDimitry Andric // if we select vector 0. 1515ffd83dbSDimitry Andric if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) 1525ffd83dbSDimitry Andric VecStart = BasePtr; 153480093f4SDimitry Andric else 1545ffd83dbSDimitry Andric VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); 155480093f4SDimitry Andric 1565ffd83dbSDimitry Andric // Cast elementwise vector start pointer to a pointer to a vector 1575ffd83dbSDimitry Andric // (EltType x NumElements)*. 
1585ffd83dbSDimitry Andric auto *VecType = FixedVectorType::get(EltType, NumElements); 1595ffd83dbSDimitry Andric Type *VecPtrType = PointerType::get(VecType, AS); 1605ffd83dbSDimitry Andric return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast"); 161480093f4SDimitry Andric } 162480093f4SDimitry Andric 163480093f4SDimitry Andric /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 164480093f4SDimitry Andric /// 165480093f4SDimitry Andric /// Currently, the lowering for each matrix intrinsic is done as follows: 166480093f4SDimitry Andric /// 1. Propagate the shape information from intrinsics to connected 167480093f4SDimitry Andric /// instructions. 1685ffd83dbSDimitry Andric /// 2. Lower instructions with shape information (assuming column-major layout). 1695ffd83dbSDimitry Andric /// The lowering works similarly using row-major layout. 170480093f4SDimitry Andric /// 2.1. Get column vectors for each argument. If we already lowered the 171480093f4SDimitry Andric /// definition of an argument, use the produced column vectors directly. 172480093f4SDimitry Andric /// If not, split the operand vector containing an embedded matrix into 173480093f4SDimitry Andric /// a set of column vectors, 1745ffd83dbSDimitry Andric /// 2.2. Lower the instruction in terms of column major operations, which 1755ffd83dbSDimitry Andric /// yields a set of column vectors containing result matrix. Note that we 1765ffd83dbSDimitry Andric /// lower all instructions that have shape information. Besides the 1775ffd83dbSDimitry Andric /// intrinsics, this includes stores for example. 178480093f4SDimitry Andric /// 2.3. Update uses of the lowered instruction. If we have shape information 179480093f4SDimitry Andric /// for a user, there is nothing to do, as we will look up the result 180480093f4SDimitry Andric /// column matrix when lowering the user. For other uses, we embed the 181480093f4SDimitry Andric /// result matrix in a flat vector and update the use. 
182480093f4SDimitry Andric /// 2.4. Cache the result column matrix for the instruction we lowered 183480093f4SDimitry Andric /// 3. After we lowered all instructions in a function, remove the now 184480093f4SDimitry Andric /// obsolete instructions. 185480093f4SDimitry Andric /// 186480093f4SDimitry Andric class LowerMatrixIntrinsics { 187480093f4SDimitry Andric Function &Func; 188480093f4SDimitry Andric const DataLayout &DL; 189480093f4SDimitry Andric const TargetTransformInfo &TTI; 190*e8d8bef9SDimitry Andric AliasAnalysis *AA; 191*e8d8bef9SDimitry Andric DominatorTree *DT; 192*e8d8bef9SDimitry Andric LoopInfo *LI; 193*e8d8bef9SDimitry Andric OptimizationRemarkEmitter *ORE; 194480093f4SDimitry Andric 1955ffd83dbSDimitry Andric /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 1965ffd83dbSDimitry Andric struct OpInfoTy { 1975ffd83dbSDimitry Andric /// Number of stores emitted to generate this matrix. 1985ffd83dbSDimitry Andric unsigned NumStores = 0; 1995ffd83dbSDimitry Andric /// Number of loads emitted to generate this matrix. 2005ffd83dbSDimitry Andric unsigned NumLoads = 0; 2015ffd83dbSDimitry Andric /// Number of compute operations emitted to generate this matrix. 2025ffd83dbSDimitry Andric unsigned NumComputeOps = 0; 2035ffd83dbSDimitry Andric 2045ffd83dbSDimitry Andric OpInfoTy &operator+=(const OpInfoTy &RHS) { 2055ffd83dbSDimitry Andric NumStores += RHS.NumStores; 2065ffd83dbSDimitry Andric NumLoads += RHS.NumLoads; 2075ffd83dbSDimitry Andric NumComputeOps += RHS.NumComputeOps; 2085ffd83dbSDimitry Andric return *this; 2095ffd83dbSDimitry Andric } 2105ffd83dbSDimitry Andric }; 2115ffd83dbSDimitry Andric 2125ffd83dbSDimitry Andric /// Wrapper class representing a matrix as a set of vectors, either in row or 2135ffd83dbSDimitry Andric /// column major layout. All vectors must have the same vector type. 
2145ffd83dbSDimitry Andric class MatrixTy { 2155ffd83dbSDimitry Andric SmallVector<Value *, 16> Vectors; 2165ffd83dbSDimitry Andric 2175ffd83dbSDimitry Andric OpInfoTy OpInfo; 2185ffd83dbSDimitry Andric 2195ffd83dbSDimitry Andric bool IsColumnMajor = true; 220480093f4SDimitry Andric 221480093f4SDimitry Andric public: 2225ffd83dbSDimitry Andric MatrixTy() 2235ffd83dbSDimitry Andric : Vectors(), 2245ffd83dbSDimitry Andric IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 2255ffd83dbSDimitry Andric MatrixTy(ArrayRef<Value *> Vectors) 2265ffd83dbSDimitry Andric : Vectors(Vectors.begin(), Vectors.end()), 2275ffd83dbSDimitry Andric IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 2285ffd83dbSDimitry Andric MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 2295ffd83dbSDimitry Andric : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 230480093f4SDimitry Andric 2315ffd83dbSDimitry Andric unsigned D = isColumnMajor() ? NumColumns : NumRows; 2325ffd83dbSDimitry Andric for (unsigned J = 0; J < D; ++J) 2335ffd83dbSDimitry Andric addVector(UndefValue::get(FixedVectorType::get( 2345ffd83dbSDimitry Andric EltTy, isColumnMajor() ? 
NumRows : NumColumns))); 235480093f4SDimitry Andric } 236480093f4SDimitry Andric 2375ffd83dbSDimitry Andric Value *getVector(unsigned i) const { return Vectors[i]; } 2385ffd83dbSDimitry Andric Value *getColumn(unsigned i) const { 2395ffd83dbSDimitry Andric assert(isColumnMajor() && "only supported for column-major matrixes"); 2405ffd83dbSDimitry Andric return Vectors[i]; 2415ffd83dbSDimitry Andric } 2425ffd83dbSDimitry Andric Value *getRow(unsigned i) const { 2435ffd83dbSDimitry Andric assert(!isColumnMajor() && "only supported for row-major matrixes"); 2445ffd83dbSDimitry Andric return Vectors[i]; 2455ffd83dbSDimitry Andric } 246480093f4SDimitry Andric 2475ffd83dbSDimitry Andric void setVector(unsigned i, Value *V) { Vectors[i] = V; } 248480093f4SDimitry Andric 249*e8d8bef9SDimitry Andric Type *getElementType() const { return getVectorTy()->getElementType(); } 2505ffd83dbSDimitry Andric 2515ffd83dbSDimitry Andric unsigned getNumVectors() const { 2525ffd83dbSDimitry Andric if (isColumnMajor()) 2535ffd83dbSDimitry Andric return getNumColumns(); 2545ffd83dbSDimitry Andric return getNumRows(); 2555ffd83dbSDimitry Andric } 2565ffd83dbSDimitry Andric 2575ffd83dbSDimitry Andric unsigned getNumColumns() const { 2585ffd83dbSDimitry Andric if (isColumnMajor()) 2595ffd83dbSDimitry Andric return Vectors.size(); 2605ffd83dbSDimitry Andric else { 2615ffd83dbSDimitry Andric assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 2625ffd83dbSDimitry Andric return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 2635ffd83dbSDimitry Andric } 2645ffd83dbSDimitry Andric } 2655ffd83dbSDimitry Andric unsigned getNumRows() const { 2665ffd83dbSDimitry Andric if (isColumnMajor()) { 2675ffd83dbSDimitry Andric assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 2685ffd83dbSDimitry Andric return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 2695ffd83dbSDimitry Andric } else 2705ffd83dbSDimitry Andric return Vectors.size(); 
2715ffd83dbSDimitry Andric } 2725ffd83dbSDimitry Andric 2735ffd83dbSDimitry Andric void addVector(Value *V) { Vectors.push_back(V); } 2745ffd83dbSDimitry Andric VectorType *getColumnTy() { 2755ffd83dbSDimitry Andric assert(isColumnMajor() && "only supported for column-major matrixes"); 2765ffd83dbSDimitry Andric return getVectorTy(); 2775ffd83dbSDimitry Andric } 2785ffd83dbSDimitry Andric 279*e8d8bef9SDimitry Andric VectorType *getVectorTy() const { 2805ffd83dbSDimitry Andric return cast<VectorType>(Vectors[0]->getType()); 2815ffd83dbSDimitry Andric } 282480093f4SDimitry Andric 283480093f4SDimitry Andric iterator_range<SmallVector<Value *, 8>::iterator> columns() { 2845ffd83dbSDimitry Andric assert(isColumnMajor() && 2855ffd83dbSDimitry Andric "columns() only supported for column-major matrixes"); 2865ffd83dbSDimitry Andric return make_range(Vectors.begin(), Vectors.end()); 287480093f4SDimitry Andric } 288480093f4SDimitry Andric 2895ffd83dbSDimitry Andric iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 2905ffd83dbSDimitry Andric return make_range(Vectors.begin(), Vectors.end()); 2915ffd83dbSDimitry Andric } 2925ffd83dbSDimitry Andric 2935ffd83dbSDimitry Andric /// Embed the vectors of the matrix into a flat vector by concatenating 294480093f4SDimitry Andric /// them. 295480093f4SDimitry Andric Value *embedInVector(IRBuilder<> &Builder) const { 2965ffd83dbSDimitry Andric return Vectors.size() == 1 ? 
Vectors[0] 2975ffd83dbSDimitry Andric : concatenateVectors(Builder, Vectors); 2985ffd83dbSDimitry Andric } 2995ffd83dbSDimitry Andric 3005ffd83dbSDimitry Andric MatrixTy &addNumLoads(unsigned N) { 3015ffd83dbSDimitry Andric OpInfo.NumLoads += N; 3025ffd83dbSDimitry Andric return *this; 3035ffd83dbSDimitry Andric } 3045ffd83dbSDimitry Andric 3055ffd83dbSDimitry Andric void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 3065ffd83dbSDimitry Andric 3075ffd83dbSDimitry Andric MatrixTy &addNumStores(unsigned N) { 3085ffd83dbSDimitry Andric OpInfo.NumStores += N; 3095ffd83dbSDimitry Andric return *this; 3105ffd83dbSDimitry Andric } 3115ffd83dbSDimitry Andric 3125ffd83dbSDimitry Andric MatrixTy &addNumComputeOps(unsigned N) { 3135ffd83dbSDimitry Andric OpInfo.NumComputeOps += N; 3145ffd83dbSDimitry Andric return *this; 3155ffd83dbSDimitry Andric } 3165ffd83dbSDimitry Andric 3175ffd83dbSDimitry Andric unsigned getNumStores() const { return OpInfo.NumStores; } 3185ffd83dbSDimitry Andric unsigned getNumLoads() const { return OpInfo.NumLoads; } 3195ffd83dbSDimitry Andric unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 3205ffd83dbSDimitry Andric 3215ffd83dbSDimitry Andric const OpInfoTy &getOpInfo() const { return OpInfo; } 3225ffd83dbSDimitry Andric 3235ffd83dbSDimitry Andric bool isColumnMajor() const { return IsColumnMajor; } 3245ffd83dbSDimitry Andric 3255ffd83dbSDimitry Andric unsigned getStride() const { 3265ffd83dbSDimitry Andric if (isColumnMajor()) 3275ffd83dbSDimitry Andric return getNumRows(); 3285ffd83dbSDimitry Andric return getNumColumns(); 3295ffd83dbSDimitry Andric } 3305ffd83dbSDimitry Andric 3315ffd83dbSDimitry Andric /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 3325ffd83dbSDimitry Andric /// matrix is column-major, the result vector is extracted from a column 3335ffd83dbSDimitry Andric /// vector, otherwise from a row vector. 
3345ffd83dbSDimitry Andric Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 3355ffd83dbSDimitry Andric IRBuilder<> &Builder) const { 3365ffd83dbSDimitry Andric Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 3375ffd83dbSDimitry Andric return Builder.CreateShuffleVector( 338*e8d8bef9SDimitry Andric Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), 3395ffd83dbSDimitry Andric "block"); 340480093f4SDimitry Andric } 341480093f4SDimitry Andric }; 342480093f4SDimitry Andric 343480093f4SDimitry Andric struct ShapeInfo { 344480093f4SDimitry Andric unsigned NumRows; 345480093f4SDimitry Andric unsigned NumColumns; 346480093f4SDimitry Andric 3475ffd83dbSDimitry Andric bool IsColumnMajor; 3485ffd83dbSDimitry Andric 349480093f4SDimitry Andric ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 3505ffd83dbSDimitry Andric : NumRows(NumRows), NumColumns(NumColumns), 3515ffd83dbSDimitry Andric IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 352480093f4SDimitry Andric 353480093f4SDimitry Andric ShapeInfo(Value *NumRows, Value *NumColumns) 3545ffd83dbSDimitry Andric : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 3555ffd83dbSDimitry Andric cast<ConstantInt>(NumColumns)->getZExtValue()) {} 356480093f4SDimitry Andric 357480093f4SDimitry Andric bool operator==(const ShapeInfo &other) { 358480093f4SDimitry Andric return NumRows == other.NumRows && NumColumns == other.NumColumns; 359480093f4SDimitry Andric } 360480093f4SDimitry Andric bool operator!=(const ShapeInfo &other) { return !(*this == other); } 361480093f4SDimitry Andric 362480093f4SDimitry Andric /// Returns true if shape-information is defined, meaning both dimensions 363480093f4SDimitry Andric /// are != 0. 
364480093f4SDimitry Andric operator bool() const { 365480093f4SDimitry Andric assert(NumRows == 0 || NumColumns != 0); 366480093f4SDimitry Andric return NumRows != 0; 367480093f4SDimitry Andric } 3685ffd83dbSDimitry Andric 3695ffd83dbSDimitry Andric unsigned getStride() const { 3705ffd83dbSDimitry Andric if (IsColumnMajor) 3715ffd83dbSDimitry Andric return NumRows; 3725ffd83dbSDimitry Andric return NumColumns; 3735ffd83dbSDimitry Andric } 3745ffd83dbSDimitry Andric 3755ffd83dbSDimitry Andric unsigned getNumVectors() const { 3765ffd83dbSDimitry Andric if (IsColumnMajor) 3775ffd83dbSDimitry Andric return NumColumns; 3785ffd83dbSDimitry Andric return NumRows; 3795ffd83dbSDimitry Andric } 380480093f4SDimitry Andric }; 381480093f4SDimitry Andric 382480093f4SDimitry Andric /// Maps instructions to their shape information. The shape information 383480093f4SDimitry Andric /// describes the shape to be used while lowering. This matches the shape of 384480093f4SDimitry Andric /// the result value of the instruction, with the only exceptions being store 3855ffd83dbSDimitry Andric /// instructions and the matrix_column_major_store intrinsics. For those, the 386480093f4SDimitry Andric /// shape information indicates that those instructions should be lowered 387480093f4SDimitry Andric /// using shape information as well. 388480093f4SDimitry Andric DenseMap<Value *, ShapeInfo> ShapeMap; 389480093f4SDimitry Andric 390480093f4SDimitry Andric /// List of instructions to remove. While lowering, we are not replacing all 391480093f4SDimitry Andric /// users of a lowered instruction, if shape information is available and 392480093f4SDimitry Andric /// those need to be removed after we finished lowering. 393480093f4SDimitry Andric SmallVector<Instruction *, 16> ToRemove; 394480093f4SDimitry Andric 395480093f4SDimitry Andric /// Map from instructions to their produced column matrix. 
3965ffd83dbSDimitry Andric MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 397480093f4SDimitry Andric 398480093f4SDimitry Andric public: 3995ffd83dbSDimitry Andric LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 400*e8d8bef9SDimitry Andric AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, 401*e8d8bef9SDimitry Andric OptimizationRemarkEmitter *ORE) 4025ffd83dbSDimitry Andric : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), 4035ffd83dbSDimitry Andric LI(LI), ORE(ORE) {} 404480093f4SDimitry Andric 4055ffd83dbSDimitry Andric unsigned getNumOps(Type *VT) { 4065ffd83dbSDimitry Andric assert(isa<VectorType>(VT) && "Expected vector type"); 4075ffd83dbSDimitry Andric return getNumOps(VT->getScalarType(), 4085ffd83dbSDimitry Andric cast<FixedVectorType>(VT)->getNumElements()); 4095ffd83dbSDimitry Andric } 4105ffd83dbSDimitry Andric 4115ffd83dbSDimitry Andric // 4125ffd83dbSDimitry Andric /// Return the estimated number of vector ops required for an operation on 4135ffd83dbSDimitry Andric /// \p VT * N. 4145ffd83dbSDimitry Andric unsigned getNumOps(Type *ST, unsigned N) { 4155ffd83dbSDimitry Andric return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / 4165ffd83dbSDimitry Andric double(TTI.getRegisterBitWidth(true))); 4175ffd83dbSDimitry Andric } 4185ffd83dbSDimitry Andric 4195ffd83dbSDimitry Andric /// Return the set of vectors that a matrix value is lowered to. 420480093f4SDimitry Andric /// 4215ffd83dbSDimitry Andric /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 4225ffd83dbSDimitry Andric /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 4235ffd83dbSDimitry Andric /// into vectors. 
4245ffd83dbSDimitry Andric MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 4255ffd83dbSDimitry Andric IRBuilder<> &Builder) { 426480093f4SDimitry Andric VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 427480093f4SDimitry Andric assert(VType && "MatrixVal must be a vector type"); 4285ffd83dbSDimitry Andric assert(cast<FixedVectorType>(VType)->getNumElements() == 4295ffd83dbSDimitry Andric SI.NumRows * SI.NumColumns && 430480093f4SDimitry Andric "The vector size must match the number of matrix elements"); 431480093f4SDimitry Andric 432480093f4SDimitry Andric // Check if we lowered MatrixVal using shape information. In that case, 4335ffd83dbSDimitry Andric // return the existing matrix, if it matches the requested shape 434480093f4SDimitry Andric // information. If there is a mis-match, embed the result in a flat 435480093f4SDimitry Andric // vector and split it later. 436480093f4SDimitry Andric auto Found = Inst2ColumnMatrix.find(MatrixVal); 437480093f4SDimitry Andric if (Found != Inst2ColumnMatrix.end()) { 4385ffd83dbSDimitry Andric MatrixTy &M = Found->second; 439480093f4SDimitry Andric // Return the found matrix, if its shape matches the requested shape 440480093f4SDimitry Andric // information 441480093f4SDimitry Andric if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 442480093f4SDimitry Andric return M; 443480093f4SDimitry Andric 444480093f4SDimitry Andric MatrixVal = M.embedInVector(Builder); 445480093f4SDimitry Andric } 446480093f4SDimitry Andric 447480093f4SDimitry Andric // Otherwise split MatrixVal. 
448480093f4SDimitry Andric SmallVector<Value *, 16> SplitVecs; 4495ffd83dbSDimitry Andric for (unsigned MaskStart = 0; 4505ffd83dbSDimitry Andric MaskStart < cast<FixedVectorType>(VType)->getNumElements(); 4515ffd83dbSDimitry Andric MaskStart += SI.getStride()) { 4525ffd83dbSDimitry Andric Value *V = Builder.CreateShuffleVector( 453*e8d8bef9SDimitry Andric MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), 4545ffd83dbSDimitry Andric "split"); 455480093f4SDimitry Andric SplitVecs.push_back(V); 456480093f4SDimitry Andric } 457480093f4SDimitry Andric 458480093f4SDimitry Andric return {SplitVecs}; 459480093f4SDimitry Andric } 460480093f4SDimitry Andric 461480093f4SDimitry Andric /// If \p V already has a known shape return false. Otherwise set the shape 462480093f4SDimitry Andric /// for instructions that support it. 463480093f4SDimitry Andric bool setShapeInfo(Value *V, ShapeInfo Shape) { 464480093f4SDimitry Andric assert(Shape && "Shape not set"); 465480093f4SDimitry Andric if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 466480093f4SDimitry Andric return false; 467480093f4SDimitry Andric 468480093f4SDimitry Andric auto SIter = ShapeMap.find(V); 469480093f4SDimitry Andric if (SIter != ShapeMap.end()) { 470480093f4SDimitry Andric LLVM_DEBUG(dbgs() << " not overriding existing shape: " 471480093f4SDimitry Andric << SIter->second.NumRows << " " 472480093f4SDimitry Andric << SIter->second.NumColumns << " for " << *V << "\n"); 473480093f4SDimitry Andric return false; 474480093f4SDimitry Andric } 475480093f4SDimitry Andric 476480093f4SDimitry Andric ShapeMap.insert({V, Shape}); 477480093f4SDimitry Andric LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 478480093f4SDimitry Andric << " for " << *V << "\n"); 479480093f4SDimitry Andric return true; 480480093f4SDimitry Andric } 481480093f4SDimitry Andric 482480093f4SDimitry Andric bool isUniformShape(Value *V) { 483480093f4SDimitry Andric Instruction *I = dyn_cast<Instruction>(V); 
484480093f4SDimitry Andric if (!I) 485480093f4SDimitry Andric return true; 486480093f4SDimitry Andric 487480093f4SDimitry Andric switch (I->getOpcode()) { 488480093f4SDimitry Andric case Instruction::FAdd: 489480093f4SDimitry Andric case Instruction::FSub: 490480093f4SDimitry Andric case Instruction::FMul: // Scalar multiply. 491*e8d8bef9SDimitry Andric case Instruction::FNeg: 492480093f4SDimitry Andric case Instruction::Add: 493480093f4SDimitry Andric case Instruction::Mul: 494480093f4SDimitry Andric case Instruction::Sub: 495480093f4SDimitry Andric return true; 496480093f4SDimitry Andric default: 497480093f4SDimitry Andric return false; 498480093f4SDimitry Andric } 499480093f4SDimitry Andric } 500480093f4SDimitry Andric 501480093f4SDimitry Andric /// Returns true if shape information can be used for \p V. The supported 502480093f4SDimitry Andric /// instructions must match the instructions that can be lowered by this pass. 503480093f4SDimitry Andric bool supportsShapeInfo(Value *V) { 504480093f4SDimitry Andric Instruction *Inst = dyn_cast<Instruction>(V); 505480093f4SDimitry Andric if (!Inst) 506480093f4SDimitry Andric return false; 507480093f4SDimitry Andric 508480093f4SDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 509480093f4SDimitry Andric if (II) 510480093f4SDimitry Andric switch (II->getIntrinsicID()) { 511480093f4SDimitry Andric case Intrinsic::matrix_multiply: 512480093f4SDimitry Andric case Intrinsic::matrix_transpose: 5135ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 5145ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 515480093f4SDimitry Andric return true; 516480093f4SDimitry Andric default: 517480093f4SDimitry Andric return false; 518480093f4SDimitry Andric } 519480093f4SDimitry Andric return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 520480093f4SDimitry Andric } 521480093f4SDimitry Andric 522480093f4SDimitry Andric /// Propagate the shape information of instructions to their 
/// users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  ///
  /// \p WorkList is consumed (emptied) by this function. The returned list
  /// holds every instruction for which a shape was newly recorded via
  /// setShapeInfo, except plain stores (see below).
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we guaranteed to have at least one of the
    // operand shapes.  Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands
      bool Propagate = false;

      // Bound by the pattern matchers below; only meaningful inside the
      // branch whose match() succeeded.
      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        // The result shape is built from the 3rd and 5th operands.
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(), m_Value(M), m_Value(N)))) {
        // NOTE(review): the 5th/6th operands are passed as {N, M}, i.e. in
        // swapped order compared to the load case below - confirm intended.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                 m_Value(), m_Value(), m_Value(), m_Value(M),
                                 m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        // A plain store takes the shape of the stored value, if known. It is
        // neither added to NewWorkList nor are its users visited.
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      // A newly recorded shape may let users derive their own shape; queue
      // every user that has no shape yet.
      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p Worklist contains the instruction for which we already know the shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape.  Traverse the operands, if their shape
    // derives from the result shape and is unknown, add it and add them to the
    // worklist.
602480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 603480093f4SDimitry Andric while (!WorkList.empty()) { 604*e8d8bef9SDimitry Andric Value *V = WorkList.pop_back_val(); 605480093f4SDimitry Andric 606480093f4SDimitry Andric size_t BeforeProcessingV = WorkList.size(); 607480093f4SDimitry Andric if (!isa<Instruction>(V)) 608480093f4SDimitry Andric continue; 609480093f4SDimitry Andric 610480093f4SDimitry Andric Value *MatrixA; 611480093f4SDimitry Andric Value *MatrixB; 612480093f4SDimitry Andric Value *M; 613480093f4SDimitry Andric Value *N; 614480093f4SDimitry Andric Value *K; 615480093f4SDimitry Andric if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 616480093f4SDimitry Andric m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 617480093f4SDimitry Andric m_Value(N), m_Value(K)))) { 618480093f4SDimitry Andric if (setShapeInfo(MatrixA, {M, N})) 619480093f4SDimitry Andric pushInstruction(MatrixA, WorkList); 620480093f4SDimitry Andric 621480093f4SDimitry Andric if (setShapeInfo(MatrixB, {N, K})) 622480093f4SDimitry Andric pushInstruction(MatrixB, WorkList); 623480093f4SDimitry Andric 624480093f4SDimitry Andric } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 625480093f4SDimitry Andric m_Value(MatrixA), m_Value(M), m_Value(N)))) { 626480093f4SDimitry Andric // Flip dimensions. 
627480093f4SDimitry Andric if (setShapeInfo(MatrixA, {M, N})) 628480093f4SDimitry Andric pushInstruction(MatrixA, WorkList); 6295ffd83dbSDimitry Andric } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 6305ffd83dbSDimitry Andric m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 631480093f4SDimitry Andric m_Value(M), m_Value(N)))) { 632480093f4SDimitry Andric if (setShapeInfo(MatrixA, {M, N})) { 633480093f4SDimitry Andric pushInstruction(MatrixA, WorkList); 634480093f4SDimitry Andric } 635480093f4SDimitry Andric } else if (isa<LoadInst>(V) || 6365ffd83dbSDimitry Andric match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 637480093f4SDimitry Andric // Nothing to do, no matrix input. 638480093f4SDimitry Andric } else if (isa<StoreInst>(V)) { 639480093f4SDimitry Andric // Nothing to do. We forward-propagated to this so we would just 640480093f4SDimitry Andric // backward propagate to an instruction with an already known shape. 641480093f4SDimitry Andric } else if (isUniformShape(V)) { 642480093f4SDimitry Andric // Propagate to all operands. 643480093f4SDimitry Andric ShapeInfo Shape = ShapeMap[V]; 644480093f4SDimitry Andric for (Use &U : cast<Instruction>(V)->operands()) { 645480093f4SDimitry Andric if (setShapeInfo(U.get(), Shape)) 646480093f4SDimitry Andric pushInstruction(U.get(), WorkList); 647480093f4SDimitry Andric } 648480093f4SDimitry Andric } 649480093f4SDimitry Andric // After we discovered new shape info for new instructions in the 650480093f4SDimitry Andric // worklist, we use their users as seeds for the next round of forward 651480093f4SDimitry Andric // propagation. 
652480093f4SDimitry Andric for (size_t I = BeforeProcessingV; I != WorkList.size(); I++) 653480093f4SDimitry Andric for (User *U : WorkList[I]->users()) 654480093f4SDimitry Andric if (isa<Instruction>(U) && V != U) 655480093f4SDimitry Andric NewWorkList.push_back(cast<Instruction>(U)); 656480093f4SDimitry Andric } 657480093f4SDimitry Andric return NewWorkList; 658480093f4SDimitry Andric } 659480093f4SDimitry Andric 660480093f4SDimitry Andric bool Visit() { 661480093f4SDimitry Andric if (EnableShapePropagation) { 662480093f4SDimitry Andric SmallVector<Instruction *, 32> WorkList; 663480093f4SDimitry Andric 664480093f4SDimitry Andric // Initially only the shape of matrix intrinsics is known. 665480093f4SDimitry Andric // Initialize the work list with ops carrying shape information. 666480093f4SDimitry Andric for (BasicBlock &BB : Func) 667480093f4SDimitry Andric for (Instruction &Inst : BB) { 668480093f4SDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 669480093f4SDimitry Andric if (!II) 670480093f4SDimitry Andric continue; 671480093f4SDimitry Andric 672480093f4SDimitry Andric switch (II->getIntrinsicID()) { 673480093f4SDimitry Andric case Intrinsic::matrix_multiply: 674480093f4SDimitry Andric case Intrinsic::matrix_transpose: 6755ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 6765ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 677480093f4SDimitry Andric WorkList.push_back(&Inst); 678480093f4SDimitry Andric break; 679480093f4SDimitry Andric default: 680480093f4SDimitry Andric break; 681480093f4SDimitry Andric } 682480093f4SDimitry Andric } 683480093f4SDimitry Andric // Propagate shapes until nothing changes any longer. 
684480093f4SDimitry Andric while (!WorkList.empty()) { 685480093f4SDimitry Andric WorkList = propagateShapeForward(WorkList); 686480093f4SDimitry Andric WorkList = propagateShapeBackward(WorkList); 687480093f4SDimitry Andric } 688480093f4SDimitry Andric } 689480093f4SDimitry Andric 690480093f4SDimitry Andric bool Changed = false; 6915ffd83dbSDimitry Andric SmallVector<CallInst *, 16> MaybeFusableInsts; 6925ffd83dbSDimitry Andric SmallVector<Instruction *, 16> MatrixInsts; 693480093f4SDimitry Andric 6945ffd83dbSDimitry Andric // First, collect all instructions with shape information and candidates for 6955ffd83dbSDimitry Andric // fusion (currently only matrix multiplies). 6965ffd83dbSDimitry Andric ReversePostOrderTraversal<Function *> RPOT(&Func); 6975ffd83dbSDimitry Andric for (auto *BB : RPOT) 6985ffd83dbSDimitry Andric for (Instruction &I : *BB) { 6995ffd83dbSDimitry Andric if (ShapeMap.find(&I) == ShapeMap.end()) 7005ffd83dbSDimitry Andric continue; 7015ffd83dbSDimitry Andric if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 7025ffd83dbSDimitry Andric MaybeFusableInsts.push_back(cast<CallInst>(&I)); 7035ffd83dbSDimitry Andric MatrixInsts.push_back(&I); 7045ffd83dbSDimitry Andric } 7055ffd83dbSDimitry Andric 7065ffd83dbSDimitry Andric // Second, try to fuse candidates. 7075ffd83dbSDimitry Andric SmallPtrSet<Instruction *, 16> FusedInsts; 7085ffd83dbSDimitry Andric for (CallInst *CI : MaybeFusableInsts) 7095ffd83dbSDimitry Andric LowerMatrixMultiplyFused(CI, FusedInsts); 7105ffd83dbSDimitry Andric Changed = !FusedInsts.empty(); 7115ffd83dbSDimitry Andric 7125ffd83dbSDimitry Andric // Third, lower remaining instructions with shape information. 
7135ffd83dbSDimitry Andric for (Instruction *Inst : MatrixInsts) { 7145ffd83dbSDimitry Andric if (FusedInsts.count(Inst)) 7155ffd83dbSDimitry Andric continue; 7165ffd83dbSDimitry Andric 7175ffd83dbSDimitry Andric IRBuilder<> Builder(Inst); 7185ffd83dbSDimitry Andric 7195ffd83dbSDimitry Andric if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 720480093f4SDimitry Andric Changed |= VisitCallInst(CInst); 721480093f4SDimitry Andric 722480093f4SDimitry Andric Value *Op1; 723480093f4SDimitry Andric Value *Op2; 7245ffd83dbSDimitry Andric if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 725480093f4SDimitry Andric Changed |= VisitBinaryOperator(BinOp); 726*e8d8bef9SDimitry Andric if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 727*e8d8bef9SDimitry Andric Changed |= VisitUnaryOperator(UnOp); 7285ffd83dbSDimitry Andric if (match(Inst, m_Load(m_Value(Op1)))) 7295ffd83dbSDimitry Andric Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 7305ffd83dbSDimitry Andric else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 7315ffd83dbSDimitry Andric Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 732480093f4SDimitry Andric } 7335ffd83dbSDimitry Andric 734*e8d8bef9SDimitry Andric if (ORE) { 735*e8d8bef9SDimitry Andric RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 7365ffd83dbSDimitry Andric RemarkGen.emitRemarks(); 737*e8d8bef9SDimitry Andric } 738480093f4SDimitry Andric 739480093f4SDimitry Andric for (Instruction *Inst : reverse(ToRemove)) 740480093f4SDimitry Andric Inst->eraseFromParent(); 741480093f4SDimitry Andric 742480093f4SDimitry Andric return Changed; 743480093f4SDimitry Andric } 744480093f4SDimitry Andric 745480093f4SDimitry Andric /// Turns \p BasePtr into an elementwise pointer to \p EltType. 
746480093f4SDimitry Andric Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 747480093f4SDimitry Andric unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 748480093f4SDimitry Andric Type *EltPtrType = PointerType::get(EltType, AS); 749480093f4SDimitry Andric return Builder.CreatePointerCast(BasePtr, EltPtrType); 750480093f4SDimitry Andric } 751480093f4SDimitry Andric 752480093f4SDimitry Andric /// Replace intrinsic calls 753480093f4SDimitry Andric bool VisitCallInst(CallInst *Inst) { 754480093f4SDimitry Andric if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 755480093f4SDimitry Andric return false; 756480093f4SDimitry Andric 757480093f4SDimitry Andric switch (Inst->getCalledFunction()->getIntrinsicID()) { 758480093f4SDimitry Andric case Intrinsic::matrix_multiply: 759480093f4SDimitry Andric LowerMultiply(Inst); 760480093f4SDimitry Andric break; 761480093f4SDimitry Andric case Intrinsic::matrix_transpose: 762480093f4SDimitry Andric LowerTranspose(Inst); 763480093f4SDimitry Andric break; 7645ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 7655ffd83dbSDimitry Andric LowerColumnMajorLoad(Inst); 766480093f4SDimitry Andric break; 7675ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 7685ffd83dbSDimitry Andric LowerColumnMajorStore(Inst); 769480093f4SDimitry Andric break; 770480093f4SDimitry Andric default: 771480093f4SDimitry Andric return false; 772480093f4SDimitry Andric } 773480093f4SDimitry Andric return true; 774480093f4SDimitry Andric } 775480093f4SDimitry Andric 7765ffd83dbSDimitry Andric /// Compute the alignment for a column/row \p Idx with \p Stride between them. 7775ffd83dbSDimitry Andric /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 7785ffd83dbSDimitry Andric /// ConstantInt, reduce the initial alignment based on the byte offset. 
For 7795ffd83dbSDimitry Andric /// non-ConstantInt strides, return the common alignment of the initial 7805ffd83dbSDimitry Andric /// alignment and the element size in bytes. 7815ffd83dbSDimitry Andric Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 7825ffd83dbSDimitry Andric MaybeAlign A) const { 7835ffd83dbSDimitry Andric Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 7845ffd83dbSDimitry Andric if (Idx == 0) 7855ffd83dbSDimitry Andric return InitialAlign; 7865ffd83dbSDimitry Andric 7875ffd83dbSDimitry Andric TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 7885ffd83dbSDimitry Andric if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 7895ffd83dbSDimitry Andric uint64_t StrideInBytes = 7905ffd83dbSDimitry Andric ConstStride->getZExtValue() * ElementSizeInBits / 8; 7915ffd83dbSDimitry Andric return commonAlignment(InitialAlign, Idx * StrideInBytes); 7925ffd83dbSDimitry Andric } 7935ffd83dbSDimitry Andric return commonAlignment(InitialAlign, ElementSizeInBits / 8); 7945ffd83dbSDimitry Andric } 7955ffd83dbSDimitry Andric 7965ffd83dbSDimitry Andric /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 7975ffd83dbSDimitry Andric /// vectors. 
7985ffd83dbSDimitry Andric MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 7995ffd83dbSDimitry Andric bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 8005ffd83dbSDimitry Andric auto VType = cast<VectorType>(Ty); 801480093f4SDimitry Andric Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 8025ffd83dbSDimitry Andric MatrixTy Result; 8035ffd83dbSDimitry Andric for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 8045ffd83dbSDimitry Andric Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride, 8055ffd83dbSDimitry Andric Shape.getStride(), VType->getElementType(), 8065ffd83dbSDimitry Andric Builder); 8075ffd83dbSDimitry Andric Value *Vector = Builder.CreateAlignedLoad( 8085ffd83dbSDimitry Andric GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign), 8095ffd83dbSDimitry Andric IsVolatile, "col.load"); 8105ffd83dbSDimitry Andric 8115ffd83dbSDimitry Andric Result.addVector(Vector); 8125ffd83dbSDimitry Andric } 8135ffd83dbSDimitry Andric return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 8145ffd83dbSDimitry Andric Result.getNumVectors()); 815480093f4SDimitry Andric } 816480093f4SDimitry Andric 8175ffd83dbSDimitry Andric /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 8185ffd83dbSDimitry Andric /// starting at \p MatrixPtr[I][J]. 
8195ffd83dbSDimitry Andric MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 8205ffd83dbSDimitry Andric ShapeInfo MatrixShape, Value *I, Value *J, 8215ffd83dbSDimitry Andric ShapeInfo ResultShape, Type *EltTy, 8225ffd83dbSDimitry Andric IRBuilder<> &Builder) { 8235ffd83dbSDimitry Andric 8245ffd83dbSDimitry Andric Value *Offset = Builder.CreateAdd( 8255ffd83dbSDimitry Andric Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 8265ffd83dbSDimitry Andric 8275ffd83dbSDimitry Andric unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 8285ffd83dbSDimitry Andric Value *EltPtr = 8295ffd83dbSDimitry Andric Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 8305ffd83dbSDimitry Andric Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 8315ffd83dbSDimitry Andric auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 8325ffd83dbSDimitry Andric ResultShape.NumColumns); 8335ffd83dbSDimitry Andric Type *TilePtrTy = PointerType::get(TileTy, AS); 8345ffd83dbSDimitry Andric Value *TilePtr = 8355ffd83dbSDimitry Andric Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 8365ffd83dbSDimitry Andric 8375ffd83dbSDimitry Andric return loadMatrix(TileTy, TilePtr, Align, 8385ffd83dbSDimitry Andric Builder.getInt64(MatrixShape.getStride()), IsVolatile, 8395ffd83dbSDimitry Andric ResultShape, Builder); 840480093f4SDimitry Andric } 841480093f4SDimitry Andric 8425ffd83dbSDimitry Andric /// Lower a load instruction with shape information. 
8435ffd83dbSDimitry Andric void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 8445ffd83dbSDimitry Andric bool IsVolatile, ShapeInfo Shape) { 8455ffd83dbSDimitry Andric IRBuilder<> Builder(Inst); 8465ffd83dbSDimitry Andric finalizeLowering(Inst, 8475ffd83dbSDimitry Andric loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 8485ffd83dbSDimitry Andric Shape, Builder), 8495ffd83dbSDimitry Andric Builder); 8505ffd83dbSDimitry Andric } 8515ffd83dbSDimitry Andric 8525ffd83dbSDimitry Andric /// Lowers llvm.matrix.column.major.load. 853480093f4SDimitry Andric /// 854480093f4SDimitry Andric /// The intrinsic loads a matrix from memory using a stride between columns. 8555ffd83dbSDimitry Andric void LowerColumnMajorLoad(CallInst *Inst) { 8565ffd83dbSDimitry Andric assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 8575ffd83dbSDimitry Andric "Intrinsic only supports column-major layout!"); 858480093f4SDimitry Andric Value *Ptr = Inst->getArgOperand(0); 859480093f4SDimitry Andric Value *Stride = Inst->getArgOperand(1); 8605ffd83dbSDimitry Andric LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 8615ffd83dbSDimitry Andric cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 862480093f4SDimitry Andric {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 863480093f4SDimitry Andric } 864480093f4SDimitry Andric 8655ffd83dbSDimitry Andric /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 8665ffd83dbSDimitry Andric /// MatrixPtr[I][J]. 
8675ffd83dbSDimitry Andric void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 8685ffd83dbSDimitry Andric MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, 8695ffd83dbSDimitry Andric Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { 8705ffd83dbSDimitry Andric Value *Offset = Builder.CreateAdd( 8715ffd83dbSDimitry Andric Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 8725ffd83dbSDimitry Andric 8735ffd83dbSDimitry Andric unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 8745ffd83dbSDimitry Andric Value *EltPtr = 8755ffd83dbSDimitry Andric Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 8765ffd83dbSDimitry Andric Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 8775ffd83dbSDimitry Andric auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * 8785ffd83dbSDimitry Andric StoreVal.getNumColumns()); 8795ffd83dbSDimitry Andric Type *TilePtrTy = PointerType::get(TileTy, AS); 8805ffd83dbSDimitry Andric Value *TilePtr = 8815ffd83dbSDimitry Andric Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 8825ffd83dbSDimitry Andric 8835ffd83dbSDimitry Andric storeMatrix(TileTy, StoreVal, TilePtr, MAlign, 8845ffd83dbSDimitry Andric Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); 8855ffd83dbSDimitry Andric } 8865ffd83dbSDimitry Andric 8875ffd83dbSDimitry Andric /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 8885ffd83dbSDimitry Andric /// vectors. 
8895ffd83dbSDimitry Andric MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, 8905ffd83dbSDimitry Andric MaybeAlign MAlign, Value *Stride, bool IsVolatile, 8915ffd83dbSDimitry Andric IRBuilder<> &Builder) { 8925ffd83dbSDimitry Andric auto VType = cast<VectorType>(Ty); 8935ffd83dbSDimitry Andric Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 8945ffd83dbSDimitry Andric for (auto Vec : enumerate(StoreVal.vectors())) { 8955ffd83dbSDimitry Andric Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()), 8965ffd83dbSDimitry Andric Stride, StoreVal.getStride(), 8975ffd83dbSDimitry Andric VType->getElementType(), Builder); 8985ffd83dbSDimitry Andric Builder.CreateAlignedStore(Vec.value(), GEP, 8995ffd83dbSDimitry Andric getAlignForIndex(Vec.index(), Stride, 9005ffd83dbSDimitry Andric VType->getElementType(), 9015ffd83dbSDimitry Andric MAlign), 9025ffd83dbSDimitry Andric IsVolatile); 9035ffd83dbSDimitry Andric } 9045ffd83dbSDimitry Andric return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * 9055ffd83dbSDimitry Andric StoreVal.getNumVectors()); 9065ffd83dbSDimitry Andric } 9075ffd83dbSDimitry Andric 9085ffd83dbSDimitry Andric /// Lower a store instruction with shape information. 9095ffd83dbSDimitry Andric void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, 9105ffd83dbSDimitry Andric Value *Stride, bool IsVolatile, ShapeInfo Shape) { 9115ffd83dbSDimitry Andric IRBuilder<> Builder(Inst); 9125ffd83dbSDimitry Andric auto StoreVal = getMatrix(Matrix, Shape, Builder); 9135ffd83dbSDimitry Andric finalizeLowering(Inst, 9145ffd83dbSDimitry Andric storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, 9155ffd83dbSDimitry Andric IsVolatile, Builder), 9165ffd83dbSDimitry Andric Builder); 9175ffd83dbSDimitry Andric } 9185ffd83dbSDimitry Andric 9195ffd83dbSDimitry Andric /// Lowers llvm.matrix.column.major.store. 
9205ffd83dbSDimitry Andric /// 9215ffd83dbSDimitry Andric /// The intrinsic store a matrix back memory using a stride between columns. 9225ffd83dbSDimitry Andric void LowerColumnMajorStore(CallInst *Inst) { 9235ffd83dbSDimitry Andric assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 9245ffd83dbSDimitry Andric "Intrinsic only supports column-major layout!"); 9255ffd83dbSDimitry Andric Value *Matrix = Inst->getArgOperand(0); 9265ffd83dbSDimitry Andric Value *Ptr = Inst->getArgOperand(1); 9275ffd83dbSDimitry Andric Value *Stride = Inst->getArgOperand(2); 9285ffd83dbSDimitry Andric LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 9295ffd83dbSDimitry Andric cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 9305ffd83dbSDimitry Andric {Inst->getArgOperand(4), Inst->getArgOperand(5)}); 931480093f4SDimitry Andric } 932480093f4SDimitry Andric 933480093f4SDimitry Andric // Set elements I..I+NumElts-1 to Block 934480093f4SDimitry Andric Value *insertVector(Value *Col, unsigned I, Value *Block, 9355ffd83dbSDimitry Andric IRBuilder<> &Builder) { 936480093f4SDimitry Andric 937480093f4SDimitry Andric // First, bring Block to the same size as Col 938480093f4SDimitry Andric unsigned BlockNumElts = 9395ffd83dbSDimitry Andric cast<FixedVectorType>(Block->getType())->getNumElements(); 9405ffd83dbSDimitry Andric unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); 941480093f4SDimitry Andric assert(NumElts >= BlockNumElts && "Too few elements for current block"); 942480093f4SDimitry Andric 9435ffd83dbSDimitry Andric Block = Builder.CreateShuffleVector( 944*e8d8bef9SDimitry Andric Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 945480093f4SDimitry Andric 946480093f4SDimitry Andric // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 947480093f4SDimitry Andric // 8, 4, 5, 6 9485ffd83dbSDimitry Andric SmallVector<int, 16> Mask; 949480093f4SDimitry Andric unsigned i; 950480093f4SDimitry Andric for (i = 0; 
i < I; i++) 9515ffd83dbSDimitry Andric Mask.push_back(i); 952480093f4SDimitry Andric 9535ffd83dbSDimitry Andric unsigned VecNumElts = 9545ffd83dbSDimitry Andric cast<FixedVectorType>(Col->getType())->getNumElements(); 955480093f4SDimitry Andric for (; i < I + BlockNumElts; i++) 9565ffd83dbSDimitry Andric Mask.push_back(i - I + VecNumElts); 957480093f4SDimitry Andric 958480093f4SDimitry Andric for (; i < VecNumElts; i++) 9595ffd83dbSDimitry Andric Mask.push_back(i); 960480093f4SDimitry Andric 9615ffd83dbSDimitry Andric return Builder.CreateShuffleVector(Col, Block, Mask); 962480093f4SDimitry Andric } 963480093f4SDimitry Andric 964480093f4SDimitry Andric Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 9655ffd83dbSDimitry Andric IRBuilder<> &Builder, bool AllowContraction, 9665ffd83dbSDimitry Andric unsigned &NumComputeOps) { 9675ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType()); 968480093f4SDimitry Andric if (!Sum) 969480093f4SDimitry Andric return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 970480093f4SDimitry Andric 971480093f4SDimitry Andric if (UseFPOp) { 972480093f4SDimitry Andric if (AllowContraction) { 973480093f4SDimitry Andric // Use fmuladd for floating point operations and let the backend decide 974480093f4SDimitry Andric // if that's profitable. 
9755ffd83dbSDimitry Andric Function *FMulAdd = Intrinsic::getDeclaration( 976480093f4SDimitry Andric Func.getParent(), Intrinsic::fmuladd, A->getType()); 977480093f4SDimitry Andric return Builder.CreateCall(FMulAdd, {A, B, Sum}); 978480093f4SDimitry Andric } 9795ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType()); 980480093f4SDimitry Andric Value *Mul = Builder.CreateFMul(A, B); 981480093f4SDimitry Andric return Builder.CreateFAdd(Sum, Mul); 982480093f4SDimitry Andric } 983480093f4SDimitry Andric 9845ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType()); 985480093f4SDimitry Andric Value *Mul = Builder.CreateMul(A, B); 986480093f4SDimitry Andric return Builder.CreateAdd(Sum, Mul); 987480093f4SDimitry Andric } 988480093f4SDimitry Andric 989480093f4SDimitry Andric /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 990480093f4SDimitry Andric /// users with shape information, there's nothing to do: the will use the 991480093f4SDimitry Andric /// cached value when they are lowered. For other users, \p Matrix is 992480093f4SDimitry Andric /// flattened and the uses are updated to use it. Also marks \p Inst for 993480093f4SDimitry Andric /// deletion. 
9945ffd83dbSDimitry Andric void finalizeLowering(Instruction *Inst, MatrixTy Matrix, 995480093f4SDimitry Andric IRBuilder<> &Builder) { 996480093f4SDimitry Andric Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); 997480093f4SDimitry Andric 998480093f4SDimitry Andric ToRemove.push_back(Inst); 999480093f4SDimitry Andric Value *Flattened = nullptr; 1000480093f4SDimitry Andric for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) { 1001480093f4SDimitry Andric Use &U = *I++; 1002480093f4SDimitry Andric if (ShapeMap.find(U.getUser()) == ShapeMap.end()) { 1003480093f4SDimitry Andric if (!Flattened) 1004480093f4SDimitry Andric Flattened = Matrix.embedInVector(Builder); 1005480093f4SDimitry Andric U.set(Flattened); 1006480093f4SDimitry Andric } 1007480093f4SDimitry Andric } 1008480093f4SDimitry Andric } 1009480093f4SDimitry Andric 10105ffd83dbSDimitry Andric /// Compute \p Result += \p A * \p B for input matrices with left-associating 10115ffd83dbSDimitry Andric /// addition. 10125ffd83dbSDimitry Andric void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 10135ffd83dbSDimitry Andric const MatrixTy &B, bool AllowContraction, 10145ffd83dbSDimitry Andric IRBuilder<> &Builder, bool isTiled) { 10155ffd83dbSDimitry Andric const unsigned VF = std::max<unsigned>( 10165ffd83dbSDimitry Andric TTI.getRegisterBitWidth(true) / 10175ffd83dbSDimitry Andric Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), 10185ffd83dbSDimitry Andric 1U); 10195ffd83dbSDimitry Andric unsigned R = Result.getNumRows(); 10205ffd83dbSDimitry Andric unsigned C = Result.getNumColumns(); 10215ffd83dbSDimitry Andric unsigned M = A.getNumColumns(); 10225ffd83dbSDimitry Andric 10235ffd83dbSDimitry Andric bool IsFP = Result.getElementType()->isFloatingPointTy(); 10245ffd83dbSDimitry Andric assert(A.isColumnMajor() == B.isColumnMajor() && 10255ffd83dbSDimitry Andric Result.isColumnMajor() == A.isColumnMajor() && 10265ffd83dbSDimitry Andric "operands must agree on matrix layout"); 
10275ffd83dbSDimitry Andric unsigned NumComputeOps = 0; 10285ffd83dbSDimitry Andric if (A.isColumnMajor()) { 10295ffd83dbSDimitry Andric // Multiply columns from the first operand with scalars from the second 10305ffd83dbSDimitry Andric // operand. Then move along the K axes and accumulate the columns. With 10315ffd83dbSDimitry Andric // this the adds can be vectorized without reassociation. 10325ffd83dbSDimitry Andric for (unsigned J = 0; J < C; ++J) { 10335ffd83dbSDimitry Andric unsigned BlockSize = VF; 10345ffd83dbSDimitry Andric // If Result is zero, we don't need to accumulate in the K==0 iteration. 10355ffd83dbSDimitry Andric bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 10365ffd83dbSDimitry Andric 10375ffd83dbSDimitry Andric for (unsigned I = 0; I < R; I += BlockSize) { 10385ffd83dbSDimitry Andric // Gradually lower the vectorization factor to cover the remainder. 10395ffd83dbSDimitry Andric while (I + BlockSize > R) 10405ffd83dbSDimitry Andric BlockSize /= 2; 10415ffd83dbSDimitry Andric 10425ffd83dbSDimitry Andric Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder) 10435ffd83dbSDimitry Andric : nullptr; 10445ffd83dbSDimitry Andric for (unsigned K = 0; K < M; ++K) { 10455ffd83dbSDimitry Andric Value *L = A.extractVector(I, K, BlockSize, Builder); 10465ffd83dbSDimitry Andric Value *RH = Builder.CreateExtractElement(B.getColumn(J), K); 10475ffd83dbSDimitry Andric Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 10485ffd83dbSDimitry Andric Sum = createMulAdd(isSumZero && K == 0 ? 
nullptr : Sum, L, Splat, 10495ffd83dbSDimitry Andric Result.getElementType()->isFloatingPointTy(), 10505ffd83dbSDimitry Andric Builder, AllowContraction, NumComputeOps); 10515ffd83dbSDimitry Andric } 10525ffd83dbSDimitry Andric Result.setVector(J, 10535ffd83dbSDimitry Andric insertVector(Result.getVector(J), I, Sum, Builder)); 10545ffd83dbSDimitry Andric } 10555ffd83dbSDimitry Andric } 10565ffd83dbSDimitry Andric } else { 10575ffd83dbSDimitry Andric // Multiply rows from the second operand with scalars from the first 10585ffd83dbSDimitry Andric // operand. Then move along the K axes and accumulate the rows. With this 10595ffd83dbSDimitry Andric // the adds can be vectorized without reassociation. 10605ffd83dbSDimitry Andric for (unsigned I = 0; I < R; ++I) { 10615ffd83dbSDimitry Andric unsigned BlockSize = VF; 10625ffd83dbSDimitry Andric bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); 10635ffd83dbSDimitry Andric for (unsigned J = 0; J < C; J += BlockSize) { 10645ffd83dbSDimitry Andric // Gradually lower the vectorization factor to cover the remainder. 10655ffd83dbSDimitry Andric while (J + BlockSize > C) 10665ffd83dbSDimitry Andric BlockSize /= 2; 10675ffd83dbSDimitry Andric 10685ffd83dbSDimitry Andric Value *Sum = nullptr; 10695ffd83dbSDimitry Andric for (unsigned K = 0; K < M; ++K) { 10705ffd83dbSDimitry Andric Value *R = B.extractVector(K, J, BlockSize, Builder); 10715ffd83dbSDimitry Andric Value *LH = Builder.CreateExtractElement(A.getVector(I), K); 10725ffd83dbSDimitry Andric Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); 10735ffd83dbSDimitry Andric Sum = createMulAdd(isSumZero && K == 0 ? 
nullptr : Sum, Splat, R, 10745ffd83dbSDimitry Andric IsFP, Builder, AllowContraction, NumComputeOps); 10755ffd83dbSDimitry Andric } 10765ffd83dbSDimitry Andric Result.setVector(I, 10775ffd83dbSDimitry Andric insertVector(Result.getVector(I), J, Sum, Builder)); 10785ffd83dbSDimitry Andric } 10795ffd83dbSDimitry Andric } 10805ffd83dbSDimitry Andric } 10815ffd83dbSDimitry Andric Result.addNumComputeOps(NumComputeOps); 10825ffd83dbSDimitry Andric } 10835ffd83dbSDimitry Andric 10845ffd83dbSDimitry Andric /// Ensure that the memory in \p Load does not alias \p Store by potentially 10855ffd83dbSDimitry Andric /// copying it to a new location. This new or otherwise the original location 10865ffd83dbSDimitry Andric /// is returned. 10875ffd83dbSDimitry Andric Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store, 10885ffd83dbSDimitry Andric CallInst *MatMul) { 10895ffd83dbSDimitry Andric MemoryLocation StoreLoc = MemoryLocation::get(Store); 10905ffd83dbSDimitry Andric MemoryLocation LoadLoc = MemoryLocation::get(Load); 10915ffd83dbSDimitry Andric 1092*e8d8bef9SDimitry Andric AliasResult LdAliased = AA->alias(LoadLoc, StoreLoc); 10935ffd83dbSDimitry Andric 10945ffd83dbSDimitry Andric // If we can statically determine noalias we're good. 10955ffd83dbSDimitry Andric if (!LdAliased) 10965ffd83dbSDimitry Andric return Load->getPointerOperand(); 10975ffd83dbSDimitry Andric 10985ffd83dbSDimitry Andric // Create code to check if the memory locations of the Load and Store 10995ffd83dbSDimitry Andric // overlap and if they do, copy Load's operand to a new buffer. 11005ffd83dbSDimitry Andric 11015ffd83dbSDimitry Andric // First, create new blocks for 2n part of the check and the copy. 11025ffd83dbSDimitry Andric BasicBlock *Check0 = MatMul->getParent(); 11035ffd83dbSDimitry Andric // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a 11045ffd83dbSDimitry Andric // DT. 
Manually collect dominator tree updates, to avoid unnecessary work, 11055ffd83dbSDimitry Andric // as we adjust Check0 and Check1's branches. 11065ffd83dbSDimitry Andric SmallVector<DominatorTree::UpdateType, 4> DTUpdates; 11075ffd83dbSDimitry Andric for (BasicBlock *Succ : successors(Check0)) 1108*e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Delete, Check0, Succ}); 11095ffd83dbSDimitry Andric 1110*e8d8bef9SDimitry Andric BasicBlock *Check1 = 1111*e8d8bef9SDimitry Andric SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 11125ffd83dbSDimitry Andric nullptr, "alias_cont"); 11135ffd83dbSDimitry Andric BasicBlock *Copy = 1114*e8d8bef9SDimitry Andric SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1115*e8d8bef9SDimitry Andric nullptr, "copy"); 1116*e8d8bef9SDimitry Andric BasicBlock *Fusion = 1117*e8d8bef9SDimitry Andric SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 11185ffd83dbSDimitry Andric nullptr, "no_alias"); 11195ffd83dbSDimitry Andric 11205ffd83dbSDimitry Andric // Check if the loaded memory location begins before the end of the store 11215ffd83dbSDimitry Andric // location. If the condition holds, they might overlap, otherwise they are 11225ffd83dbSDimitry Andric // guaranteed to not overlap. 
11235ffd83dbSDimitry Andric IRBuilder<> Builder(MatMul); 11245ffd83dbSDimitry Andric Check0->getTerminator()->eraseFromParent(); 11255ffd83dbSDimitry Andric Builder.SetInsertPoint(Check0); 11265ffd83dbSDimitry Andric Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout()); 11275ffd83dbSDimitry Andric Value *StoreBegin = Builder.CreatePtrToInt( 11285ffd83dbSDimitry Andric const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); 11295ffd83dbSDimitry Andric Value *StoreEnd = Builder.CreateAdd( 11305ffd83dbSDimitry Andric StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), 11315ffd83dbSDimitry Andric "store.end", true, true); 11325ffd83dbSDimitry Andric Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), 11335ffd83dbSDimitry Andric IntPtrTy, "load.begin"); 11345ffd83dbSDimitry Andric Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, 11355ffd83dbSDimitry Andric Fusion); 11365ffd83dbSDimitry Andric 11375ffd83dbSDimitry Andric // Check if the store begins before the end of the load location. If the 11385ffd83dbSDimitry Andric // condition holds, they alias, otherwise they are guaranteed to not 11395ffd83dbSDimitry Andric // overlap. 11405ffd83dbSDimitry Andric Check1->getTerminator()->eraseFromParent(); 11415ffd83dbSDimitry Andric Builder.SetInsertPoint(Check1, Check1->begin()); 11425ffd83dbSDimitry Andric Value *LoadEnd = Builder.CreateAdd( 11435ffd83dbSDimitry Andric LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 11445ffd83dbSDimitry Andric "load.end", true, true); 11455ffd83dbSDimitry Andric Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 11465ffd83dbSDimitry Andric Fusion); 11475ffd83dbSDimitry Andric 11485ffd83dbSDimitry Andric // Copy load operand to new alloca. 
11495ffd83dbSDimitry Andric Builder.SetInsertPoint(Copy, Copy->begin()); 11505ffd83dbSDimitry Andric AllocaInst *NewLd = 11515ffd83dbSDimitry Andric Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace()); 11525ffd83dbSDimitry Andric Builder.CreateMemCpy(NewLd, NewLd->getAlign(), 11535ffd83dbSDimitry Andric Load->getPointerOperand(), Load->getAlign(), 11545ffd83dbSDimitry Andric LoadLoc.Size.getValue()); 11555ffd83dbSDimitry Andric Builder.SetInsertPoint(Fusion, Fusion->begin()); 11565ffd83dbSDimitry Andric PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 11575ffd83dbSDimitry Andric PHI->addIncoming(Load->getPointerOperand(), Check0); 11585ffd83dbSDimitry Andric PHI->addIncoming(Load->getPointerOperand(), Check1); 11595ffd83dbSDimitry Andric PHI->addIncoming(NewLd, Copy); 11605ffd83dbSDimitry Andric 11615ffd83dbSDimitry Andric // Adjust DT. 1162*e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Insert, Check0, Check1}); 1163*e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Insert, Check0, Fusion}); 1164*e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Insert, Check1, Copy}); 1165*e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Insert, Check1, Fusion}); 1166*e8d8bef9SDimitry Andric DT->applyUpdates(DTUpdates); 11675ffd83dbSDimitry Andric return PHI; 11685ffd83dbSDimitry Andric } 11695ffd83dbSDimitry Andric 11705ffd83dbSDimitry Andric bool isFusionProfitable(CallInst *MatMul) { 11715ffd83dbSDimitry Andric if (ForceFusion) 11725ffd83dbSDimitry Andric return true; 11735ffd83dbSDimitry Andric 11745ffd83dbSDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 11755ffd83dbSDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 11765ffd83dbSDimitry Andric 11775ffd83dbSDimitry Andric const unsigned R = LShape.NumRows; 11785ffd83dbSDimitry Andric const unsigned C = RShape.NumColumns; 11795ffd83dbSDimitry Andric const unsigned M = LShape.NumColumns; 11805ffd83dbSDimitry Andric auto 
*EltType = cast<VectorType>(MatMul->getType())->getElementType(); 11815ffd83dbSDimitry Andric 11825ffd83dbSDimitry Andric const unsigned VF = 11835ffd83dbSDimitry Andric std::max<unsigned>(TTI.getRegisterBitWidth(true) / 11845ffd83dbSDimitry Andric EltType->getPrimitiveSizeInBits().getFixedSize(), 11855ffd83dbSDimitry Andric 1U); 11865ffd83dbSDimitry Andric 11875ffd83dbSDimitry Andric // Cost model for tiling 11885ffd83dbSDimitry Andric // 11895ffd83dbSDimitry Andric // For tiling to be beneficial, we need reuse either along the R or 11905ffd83dbSDimitry Andric // the C axis. We vectorize along the R axis so that means at least 11915ffd83dbSDimitry Andric // 3 elements. 11925ffd83dbSDimitry Andric // TODO: Also consider cost of copying if operands alias. 11935ffd83dbSDimitry Andric if (R <= VF && C == 1) 11945ffd83dbSDimitry Andric return false; 11955ffd83dbSDimitry Andric // Then we need enough elements to exceed the number of vector 11965ffd83dbSDimitry Andric // registers we have. Note that this is an oversimplification since 11975ffd83dbSDimitry Andric // fusing also takes some extra loads which may exceed the number of 11985ffd83dbSDimitry Andric // reloads necessary. 
11995ffd83dbSDimitry Andric unsigned Op0Regs = (R + VF - 1) / VF * M; 12005ffd83dbSDimitry Andric unsigned Op1Regs = (M + VF - 1) / VF * C; 12015ffd83dbSDimitry Andric return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true); 12025ffd83dbSDimitry Andric } 12035ffd83dbSDimitry Andric 12045ffd83dbSDimitry Andric MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 12055ffd83dbSDimitry Andric MatrixTy Res; 12065ffd83dbSDimitry Andric auto *ColumType = FixedVectorType::get(EltType, R); 12075ffd83dbSDimitry Andric for (unsigned I = 0; I < C; ++I) 12085ffd83dbSDimitry Andric Res.addVector(ConstantAggregateZero::get(ColumType)); 12095ffd83dbSDimitry Andric return Res; 12105ffd83dbSDimitry Andric } 12115ffd83dbSDimitry Andric 1212*e8d8bef9SDimitry Andric void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape, 1213*e8d8bef9SDimitry Andric Value *RPtr, ShapeInfo RShape, StoreInst *Store, 1214*e8d8bef9SDimitry Andric bool AllowContract) { 1215*e8d8bef9SDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1216*e8d8bef9SDimitry Andric 1217*e8d8bef9SDimitry Andric // Create the main tiling loop nest. 
1218*e8d8bef9SDimitry Andric TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize); 1219*e8d8bef9SDimitry Andric DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 1220*e8d8bef9SDimitry Andric Instruction *InsertI = cast<Instruction>(MatMul); 1221*e8d8bef9SDimitry Andric BasicBlock *Start = InsertI->getParent(); 1222*e8d8bef9SDimitry Andric BasicBlock *End = 1223*e8d8bef9SDimitry Andric SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue"); 1224*e8d8bef9SDimitry Andric IRBuilder<> Builder(MatMul); 1225*e8d8bef9SDimitry Andric BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI); 1226*e8d8bef9SDimitry Andric 1227*e8d8bef9SDimitry Andric Type *TileVecTy = 1228*e8d8bef9SDimitry Andric FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize); 1229*e8d8bef9SDimitry Andric MatrixTy TileResult; 1230*e8d8bef9SDimitry Andric // Insert in the inner loop header. 1231*e8d8bef9SDimitry Andric Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator()); 1232*e8d8bef9SDimitry Andric // Create PHI nodes for the result columns to accumulate across iterations. 1233*e8d8bef9SDimitry Andric SmallVector<PHINode *, 4> ColumnPhis; 1234*e8d8bef9SDimitry Andric for (unsigned I = 0; I < TileSize; I++) { 1235*e8d8bef9SDimitry Andric auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." 
+ Twine(I)); 1236*e8d8bef9SDimitry Andric Phi->addIncoming(ConstantAggregateZero::get(TileVecTy), 1237*e8d8bef9SDimitry Andric TI.RowLoopHeader->getSingleSuccessor()); 1238*e8d8bef9SDimitry Andric TileResult.addVector(Phi); 1239*e8d8bef9SDimitry Andric ColumnPhis.push_back(Phi); 1240*e8d8bef9SDimitry Andric } 1241*e8d8bef9SDimitry Andric 1242*e8d8bef9SDimitry Andric // Insert in the inner loop body, which computes 1243*e8d8bef9SDimitry Andric // Res += Load(CurrentRow, K) * Load(K, CurrentColumn) 1244*e8d8bef9SDimitry Andric Builder.SetInsertPoint(InnerBody->getTerminator()); 1245*e8d8bef9SDimitry Andric // Load tiles of the operands. 1246*e8d8bef9SDimitry Andric MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK, 1247*e8d8bef9SDimitry Andric {TileSize, TileSize}, EltType, Builder); 1248*e8d8bef9SDimitry Andric MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol, 1249*e8d8bef9SDimitry Andric {TileSize, TileSize}, EltType, Builder); 1250*e8d8bef9SDimitry Andric emitMatrixMultiply(TileResult, A, B, AllowContract, Builder, true); 1251*e8d8bef9SDimitry Andric // Store result after the inner loop is done. 1252*e8d8bef9SDimitry Andric Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator()); 1253*e8d8bef9SDimitry Andric storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(), 1254*e8d8bef9SDimitry Andric Store->isVolatile(), {LShape.NumRows, RShape.NumColumns}, 1255*e8d8bef9SDimitry Andric TI.CurrentRow, TI.CurrentCol, EltType, Builder); 1256*e8d8bef9SDimitry Andric 1257*e8d8bef9SDimitry Andric for (unsigned I = 0; I < TileResult.getNumVectors(); I++) 1258*e8d8bef9SDimitry Andric ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch); 1259*e8d8bef9SDimitry Andric 1260*e8d8bef9SDimitry Andric // Force unrolling of a few iterations of the inner loop, to make sure there 1261*e8d8bef9SDimitry Andric // is enough work per iteration. 
1262*e8d8bef9SDimitry Andric // FIXME: The unroller should make this decision directly instead, but 1263*e8d8bef9SDimitry Andric // currently the cost-model is not up to the task. 1264*e8d8bef9SDimitry Andric unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize); 1265*e8d8bef9SDimitry Andric addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader), 1266*e8d8bef9SDimitry Andric "llvm.loop.unroll.count", InnerLoopUnrollCount); 1267*e8d8bef9SDimitry Andric } 1268*e8d8bef9SDimitry Andric 12695ffd83dbSDimitry Andric void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 12705ffd83dbSDimitry Andric StoreInst *Store, 12715ffd83dbSDimitry Andric SmallPtrSetImpl<Instruction *> &FusedInsts) { 12725ffd83dbSDimitry Andric assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 12735ffd83dbSDimitry Andric "Tiling only supported for column-major matrixes at the moment!"); 12745ffd83dbSDimitry Andric if (!isFusionProfitable(MatMul)) 12755ffd83dbSDimitry Andric return; 12765ffd83dbSDimitry Andric 12775ffd83dbSDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 12785ffd83dbSDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 12795ffd83dbSDimitry Andric 12805ffd83dbSDimitry Andric const unsigned R = LShape.NumRows; 12815ffd83dbSDimitry Andric const unsigned C = RShape.NumColumns; 12825ffd83dbSDimitry Andric const unsigned M = LShape.NumColumns; 12835ffd83dbSDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 12845ffd83dbSDimitry Andric 12855ffd83dbSDimitry Andric Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 12865ffd83dbSDimitry Andric Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 12875ffd83dbSDimitry Andric Value *CPtr = Store->getPointerOperand(); 12885ffd83dbSDimitry Andric 12895ffd83dbSDimitry Andric bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 12905ffd83dbSDimitry Andric 
MatMul->hasAllowContract()); 1291*e8d8bef9SDimitry Andric if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0)) 1292*e8d8bef9SDimitry Andric createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store, 1293*e8d8bef9SDimitry Andric AllowContract); 1294*e8d8bef9SDimitry Andric else { 12955ffd83dbSDimitry Andric IRBuilder<> Builder(Store); 12965ffd83dbSDimitry Andric for (unsigned J = 0; J < C; J += TileSize) 12975ffd83dbSDimitry Andric for (unsigned I = 0; I < R; I += TileSize) { 12985ffd83dbSDimitry Andric const unsigned TileR = std::min(R - I, unsigned(TileSize)); 12995ffd83dbSDimitry Andric const unsigned TileC = std::min(C - J, unsigned(TileSize)); 13005ffd83dbSDimitry Andric MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 13015ffd83dbSDimitry Andric 13025ffd83dbSDimitry Andric for (unsigned K = 0; K < M; K += TileSize) { 13035ffd83dbSDimitry Andric const unsigned TileM = std::min(M - K, unsigned(TileSize)); 13045ffd83dbSDimitry Andric MatrixTy A = 13055ffd83dbSDimitry Andric loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 13065ffd83dbSDimitry Andric LShape, Builder.getInt64(I), Builder.getInt64(K), 13075ffd83dbSDimitry Andric {TileR, TileM}, EltType, Builder); 13085ffd83dbSDimitry Andric MatrixTy B = 13095ffd83dbSDimitry Andric loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), 13105ffd83dbSDimitry Andric RShape, Builder.getInt64(K), Builder.getInt64(J), 13115ffd83dbSDimitry Andric {TileM, TileC}, EltType, Builder); 13125ffd83dbSDimitry Andric emitMatrixMultiply(Res, A, B, AllowContract, Builder, true); 13135ffd83dbSDimitry Andric } 13145ffd83dbSDimitry Andric storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 1315*e8d8bef9SDimitry Andric Builder.getInt64(I), Builder.getInt64(J), EltType, 1316*e8d8bef9SDimitry Andric Builder); 1317*e8d8bef9SDimitry Andric } 13185ffd83dbSDimitry Andric } 13195ffd83dbSDimitry Andric 13205ffd83dbSDimitry Andric // Mark eliminated instructions as fused and remove them. 
13215ffd83dbSDimitry Andric FusedInsts.insert(Store); 13225ffd83dbSDimitry Andric FusedInsts.insert(MatMul); 13235ffd83dbSDimitry Andric Store->eraseFromParent(); 13245ffd83dbSDimitry Andric MatMul->eraseFromParent(); 13255ffd83dbSDimitry Andric if (LoadOp0->hasNUses(0)) { 13265ffd83dbSDimitry Andric FusedInsts.insert(LoadOp0); 13275ffd83dbSDimitry Andric LoadOp0->eraseFromParent(); 13285ffd83dbSDimitry Andric } 13295ffd83dbSDimitry Andric if (LoadOp1->hasNUses(0)) { 13305ffd83dbSDimitry Andric FusedInsts.insert(LoadOp1); 13315ffd83dbSDimitry Andric LoadOp1->eraseFromParent(); 13325ffd83dbSDimitry Andric } 13335ffd83dbSDimitry Andric } 13345ffd83dbSDimitry Andric 13355ffd83dbSDimitry Andric /// Try to lower matrix multiply chains by fusing operations. 13365ffd83dbSDimitry Andric /// 13375ffd83dbSDimitry Andric /// Currently we only lower {ld, ld} -> matmul -> st chains. 13385ffd83dbSDimitry Andric // 13395ffd83dbSDimitry Andric /// No need to return a MatrixTy object for the result of the operation, since 13405ffd83dbSDimitry Andric /// the single store user will be lowered as part of this. Instructions that 13415ffd83dbSDimitry Andric /// are completely eliminated by fusion are added to \p FusedInsts. 
13425ffd83dbSDimitry Andric void LowerMatrixMultiplyFused(CallInst *MatMul, 13435ffd83dbSDimitry Andric SmallPtrSetImpl<Instruction *> &FusedInsts) { 13445ffd83dbSDimitry Andric if (!FuseMatrix || !MatMul->hasOneUse() || 1345*e8d8bef9SDimitry Andric MatrixLayout != MatrixLayoutTy::ColumnMajor || !DT) 13465ffd83dbSDimitry Andric return; 13475ffd83dbSDimitry Andric 1348*e8d8bef9SDimitry Andric assert(AA && LI && "Analyses should be available"); 1349*e8d8bef9SDimitry Andric 13505ffd83dbSDimitry Andric auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0)); 13515ffd83dbSDimitry Andric auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1)); 13525ffd83dbSDimitry Andric auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); 13535ffd83dbSDimitry Andric if (LoadOp0 && LoadOp1 && Store) { 13545ffd83dbSDimitry Andric // The store address must dominate the MatMul instruction, otherwise 13555ffd83dbSDimitry Andric // we create invalid IR. 13565ffd83dbSDimitry Andric // FIXME: See if we can hoist the store address computation. 13575ffd83dbSDimitry Andric auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1)); 1358*e8d8bef9SDimitry Andric if (AddrI && (!DT->dominates(AddrI, MatMul))) 13595ffd83dbSDimitry Andric return; 13605ffd83dbSDimitry Andric 13615ffd83dbSDimitry Andric emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); 13625ffd83dbSDimitry Andric return; 13635ffd83dbSDimitry Andric } 13645ffd83dbSDimitry Andric } 13655ffd83dbSDimitry Andric 1366480093f4SDimitry Andric /// Lowers llvm.matrix.multiply. 
1367480093f4SDimitry Andric void LowerMultiply(CallInst *MatMul) { 1368480093f4SDimitry Andric IRBuilder<> Builder(MatMul); 1369480093f4SDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1370480093f4SDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1371480093f4SDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1372480093f4SDimitry Andric 13735ffd83dbSDimitry Andric const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 13745ffd83dbSDimitry Andric const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 1375*e8d8bef9SDimitry Andric assert(Lhs.getElementType() == Rhs.getElementType() && 1376*e8d8bef9SDimitry Andric "Matrix multiply argument element types do not match."); 1377480093f4SDimitry Andric 1378480093f4SDimitry Andric const unsigned R = LShape.NumRows; 1379480093f4SDimitry Andric const unsigned C = RShape.NumColumns; 13805ffd83dbSDimitry Andric assert(LShape.NumColumns == RShape.NumRows); 1381480093f4SDimitry Andric 1382480093f4SDimitry Andric // Initialize the output 13835ffd83dbSDimitry Andric MatrixTy Result(R, C, EltType); 1384*e8d8bef9SDimitry Andric assert(Lhs.getElementType() == Result.getElementType() && 1385*e8d8bef9SDimitry Andric "Matrix multiply result element type does not match arguments."); 1386480093f4SDimitry Andric 1387480093f4SDimitry Andric bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 1388480093f4SDimitry Andric MatMul->hasAllowContract()); 13895ffd83dbSDimitry Andric emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false); 1390480093f4SDimitry Andric finalizeLowering(MatMul, Result, Builder); 1391480093f4SDimitry Andric } 1392480093f4SDimitry Andric 1393480093f4SDimitry Andric /// Lowers llvm.matrix.transpose. 
1394480093f4SDimitry Andric void LowerTranspose(CallInst *Inst) { 13955ffd83dbSDimitry Andric MatrixTy Result; 1396480093f4SDimitry Andric IRBuilder<> Builder(Inst); 1397480093f4SDimitry Andric Value *InputVal = Inst->getArgOperand(0); 1398480093f4SDimitry Andric VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 1399480093f4SDimitry Andric ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 14005ffd83dbSDimitry Andric MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 1401480093f4SDimitry Andric 14025ffd83dbSDimitry Andric const unsigned NewNumVecs = 14035ffd83dbSDimitry Andric InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; 14045ffd83dbSDimitry Andric const unsigned NewNumElts = 14055ffd83dbSDimitry Andric InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows; 1406480093f4SDimitry Andric 14075ffd83dbSDimitry Andric for (unsigned I = 0; I < NewNumVecs; ++I) { 14085ffd83dbSDimitry Andric // Build a single result vector. First initialize it. 14095ffd83dbSDimitry Andric Value *ResultVector = UndefValue::get( 14105ffd83dbSDimitry Andric FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); 14115ffd83dbSDimitry Andric // Go through the old elements and insert it into the resulting vector. 14125ffd83dbSDimitry Andric for (auto J : enumerate(InputMatrix.vectors())) { 14135ffd83dbSDimitry Andric Value *Elt = Builder.CreateExtractElement(J.value(), I); 14145ffd83dbSDimitry Andric // Row and column indices are transposed. 14155ffd83dbSDimitry Andric ResultVector = 14165ffd83dbSDimitry Andric Builder.CreateInsertElement(ResultVector, Elt, J.index()); 1417480093f4SDimitry Andric } 14185ffd83dbSDimitry Andric Result.addVector(ResultVector); 1419480093f4SDimitry Andric } 1420480093f4SDimitry Andric 14215ffd83dbSDimitry Andric // TODO: Improve estimate of operations needed for transposes. 
Currently we 14225ffd83dbSDimitry Andric // just count the insertelement/extractelement instructions, but do not 14235ffd83dbSDimitry Andric // account for later simplifications/combines. 14245ffd83dbSDimitry Andric finalizeLowering( 14255ffd83dbSDimitry Andric Inst, 14265ffd83dbSDimitry Andric Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns), 14275ffd83dbSDimitry Andric Builder); 1428480093f4SDimitry Andric } 1429480093f4SDimitry Andric 1430480093f4SDimitry Andric /// Lower load instructions, if shape information is available. 14315ffd83dbSDimitry Andric bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { 1432480093f4SDimitry Andric auto I = ShapeMap.find(Inst); 1433480093f4SDimitry Andric if (I == ShapeMap.end()) 1434480093f4SDimitry Andric return false; 1435480093f4SDimitry Andric 14365ffd83dbSDimitry Andric LowerLoad(Inst, Ptr, Inst->getAlign(), 14375ffd83dbSDimitry Andric Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 14385ffd83dbSDimitry Andric I->second); 1439480093f4SDimitry Andric return true; 1440480093f4SDimitry Andric } 1441480093f4SDimitry Andric 14425ffd83dbSDimitry Andric bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, 1443480093f4SDimitry Andric IRBuilder<> &Builder) { 1444480093f4SDimitry Andric auto I = ShapeMap.find(StoredVal); 1445480093f4SDimitry Andric if (I == ShapeMap.end()) 1446480093f4SDimitry Andric return false; 1447480093f4SDimitry Andric 14485ffd83dbSDimitry Andric LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), 14495ffd83dbSDimitry Andric Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 14505ffd83dbSDimitry Andric I->second); 1451480093f4SDimitry Andric return true; 1452480093f4SDimitry Andric } 1453480093f4SDimitry Andric 1454480093f4SDimitry Andric /// Lower binary operators, if shape information is available. 
1455480093f4SDimitry Andric bool VisitBinaryOperator(BinaryOperator *Inst) { 1456480093f4SDimitry Andric auto I = ShapeMap.find(Inst); 1457480093f4SDimitry Andric if (I == ShapeMap.end()) 1458480093f4SDimitry Andric return false; 1459480093f4SDimitry Andric 1460480093f4SDimitry Andric Value *Lhs = Inst->getOperand(0); 1461480093f4SDimitry Andric Value *Rhs = Inst->getOperand(1); 1462480093f4SDimitry Andric 1463480093f4SDimitry Andric IRBuilder<> Builder(Inst); 1464480093f4SDimitry Andric ShapeInfo &Shape = I->second; 1465480093f4SDimitry Andric 14665ffd83dbSDimitry Andric MatrixTy Result; 14675ffd83dbSDimitry Andric MatrixTy A = getMatrix(Lhs, Shape, Builder); 14685ffd83dbSDimitry Andric MatrixTy B = getMatrix(Rhs, Shape, Builder); 14695ffd83dbSDimitry Andric assert(A.isColumnMajor() == B.isColumnMajor() && 14705ffd83dbSDimitry Andric Result.isColumnMajor() == A.isColumnMajor() && 14715ffd83dbSDimitry Andric "operands must agree on matrix layout"); 1472480093f4SDimitry Andric 14735ffd83dbSDimitry Andric // Helper to perform binary op on vectors. 
14745ffd83dbSDimitry Andric auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) { 1475480093f4SDimitry Andric switch (Inst->getOpcode()) { 1476480093f4SDimitry Andric case Instruction::Add: 1477480093f4SDimitry Andric return Builder.CreateAdd(LHS, RHS); 1478480093f4SDimitry Andric case Instruction::Mul: 1479480093f4SDimitry Andric return Builder.CreateMul(LHS, RHS); 1480480093f4SDimitry Andric case Instruction::Sub: 1481480093f4SDimitry Andric return Builder.CreateSub(LHS, RHS); 1482480093f4SDimitry Andric case Instruction::FAdd: 1483480093f4SDimitry Andric return Builder.CreateFAdd(LHS, RHS); 1484480093f4SDimitry Andric case Instruction::FMul: 1485480093f4SDimitry Andric return Builder.CreateFMul(LHS, RHS); 1486480093f4SDimitry Andric case Instruction::FSub: 1487480093f4SDimitry Andric return Builder.CreateFSub(LHS, RHS); 1488480093f4SDimitry Andric default: 1489480093f4SDimitry Andric llvm_unreachable("Unsupported binary operator for matrix"); 1490480093f4SDimitry Andric } 1491480093f4SDimitry Andric }; 1492480093f4SDimitry Andric 14935ffd83dbSDimitry Andric for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 14945ffd83dbSDimitry Andric Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I))); 14955ffd83dbSDimitry Andric 14965ffd83dbSDimitry Andric finalizeLowering(Inst, 14975ffd83dbSDimitry Andric Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 14985ffd83dbSDimitry Andric Result.getNumVectors()), 14995ffd83dbSDimitry Andric Builder); 1500480093f4SDimitry Andric return true; 1501480093f4SDimitry Andric } 15025ffd83dbSDimitry Andric 1503*e8d8bef9SDimitry Andric /// Lower unary operators, if shape information is available. 
1504*e8d8bef9SDimitry Andric bool VisitUnaryOperator(UnaryOperator *Inst) { 1505*e8d8bef9SDimitry Andric auto I = ShapeMap.find(Inst); 1506*e8d8bef9SDimitry Andric if (I == ShapeMap.end()) 1507*e8d8bef9SDimitry Andric return false; 1508*e8d8bef9SDimitry Andric 1509*e8d8bef9SDimitry Andric Value *Op = Inst->getOperand(0); 1510*e8d8bef9SDimitry Andric 1511*e8d8bef9SDimitry Andric IRBuilder<> Builder(Inst); 1512*e8d8bef9SDimitry Andric ShapeInfo &Shape = I->second; 1513*e8d8bef9SDimitry Andric 1514*e8d8bef9SDimitry Andric MatrixTy Result; 1515*e8d8bef9SDimitry Andric MatrixTy M = getMatrix(Op, Shape, Builder); 1516*e8d8bef9SDimitry Andric 1517*e8d8bef9SDimitry Andric // Helper to perform unary op on vectors. 1518*e8d8bef9SDimitry Andric auto BuildVectorOp = [&Builder, Inst](Value *Op) { 1519*e8d8bef9SDimitry Andric switch (Inst->getOpcode()) { 1520*e8d8bef9SDimitry Andric case Instruction::FNeg: 1521*e8d8bef9SDimitry Andric return Builder.CreateFNeg(Op); 1522*e8d8bef9SDimitry Andric default: 1523*e8d8bef9SDimitry Andric llvm_unreachable("Unsupported unary operator for matrix"); 1524*e8d8bef9SDimitry Andric } 1525*e8d8bef9SDimitry Andric }; 1526*e8d8bef9SDimitry Andric 1527*e8d8bef9SDimitry Andric for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 1528*e8d8bef9SDimitry Andric Result.addVector(BuildVectorOp(M.getVector(I))); 1529*e8d8bef9SDimitry Andric 1530*e8d8bef9SDimitry Andric finalizeLowering(Inst, 1531*e8d8bef9SDimitry Andric Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 1532*e8d8bef9SDimitry Andric Result.getNumVectors()), 1533*e8d8bef9SDimitry Andric Builder); 1534*e8d8bef9SDimitry Andric return true; 1535*e8d8bef9SDimitry Andric } 1536*e8d8bef9SDimitry Andric 15375ffd83dbSDimitry Andric /// Helper to linearize a matrix expression tree into a string. Currently 15385ffd83dbSDimitry Andric /// matrix expressions are linarized by starting at an expression leaf and 15395ffd83dbSDimitry Andric /// linearizing bottom up. 
15405ffd83dbSDimitry Andric struct ExprLinearizer { 15415ffd83dbSDimitry Andric unsigned LengthToBreak = 100; 15425ffd83dbSDimitry Andric std::string Str; 15435ffd83dbSDimitry Andric raw_string_ostream Stream; 15445ffd83dbSDimitry Andric unsigned LineLength = 0; 15455ffd83dbSDimitry Andric const DataLayout &DL; 15465ffd83dbSDimitry Andric 15475ffd83dbSDimitry Andric /// Mapping from instructions to matrixes. It is used to identify 15485ffd83dbSDimitry Andric /// matrix instructions. 15495ffd83dbSDimitry Andric const MapVector<Value *, MatrixTy> &Inst2Matrix; 15505ffd83dbSDimitry Andric 15515ffd83dbSDimitry Andric /// Mapping from values to the leaves of all expressions that the value is 15525ffd83dbSDimitry Andric /// part of. 15535ffd83dbSDimitry Andric const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; 15545ffd83dbSDimitry Andric 15555ffd83dbSDimitry Andric /// Set of matrix expressions in the scope of a given DISubprogram. 15565ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram; 15575ffd83dbSDimitry Andric 15585ffd83dbSDimitry Andric /// Leaf node of the expression to linearize. 15595ffd83dbSDimitry Andric Value *Leaf; 15605ffd83dbSDimitry Andric 15615ffd83dbSDimitry Andric /// Used to keep track of sub-expressions that get reused while linearizing 15625ffd83dbSDimitry Andric /// the expression. Re-used sub-expressions are marked as (reused). 
15635ffd83dbSDimitry Andric SmallPtrSet<Value *, 8> ReusedExprs; 15645ffd83dbSDimitry Andric 15655ffd83dbSDimitry Andric ExprLinearizer(const DataLayout &DL, 15665ffd83dbSDimitry Andric const MapVector<Value *, MatrixTy> &Inst2Matrix, 15675ffd83dbSDimitry Andric const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 15685ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 15695ffd83dbSDimitry Andric Value *Leaf) 15705ffd83dbSDimitry Andric : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared), 15715ffd83dbSDimitry Andric ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} 15725ffd83dbSDimitry Andric 15735ffd83dbSDimitry Andric void indent(unsigned N) { 15745ffd83dbSDimitry Andric LineLength += N; 15755ffd83dbSDimitry Andric for (unsigned i = 0; i < N; i++) 15765ffd83dbSDimitry Andric Stream << " "; 15775ffd83dbSDimitry Andric } 15785ffd83dbSDimitry Andric 15795ffd83dbSDimitry Andric void lineBreak() { 15805ffd83dbSDimitry Andric Stream << "\n"; 15815ffd83dbSDimitry Andric LineLength = 0; 15825ffd83dbSDimitry Andric } 15835ffd83dbSDimitry Andric 15845ffd83dbSDimitry Andric void maybeIndent(unsigned Indent) { 15855ffd83dbSDimitry Andric if (LineLength >= LengthToBreak) 15865ffd83dbSDimitry Andric lineBreak(); 15875ffd83dbSDimitry Andric 15885ffd83dbSDimitry Andric if (LineLength == 0) 15895ffd83dbSDimitry Andric indent(Indent); 15905ffd83dbSDimitry Andric } 15915ffd83dbSDimitry Andric 15925ffd83dbSDimitry Andric void write(StringRef S) { 15935ffd83dbSDimitry Andric LineLength += S.size(); 15945ffd83dbSDimitry Andric Stream << S; 15955ffd83dbSDimitry Andric } 15965ffd83dbSDimitry Andric 15975ffd83dbSDimitry Andric Value *getUnderlyingObjectThroughLoads(Value *V) { 15985ffd83dbSDimitry Andric if (Value *Ptr = getPointerOperand(V)) 15995ffd83dbSDimitry Andric return getUnderlyingObjectThroughLoads(Ptr); 16005ffd83dbSDimitry Andric else if (V->getType()->isPointerTy()) 1601*e8d8bef9SDimitry Andric return 
getUnderlyingObject(V); 16025ffd83dbSDimitry Andric return V; 16035ffd83dbSDimitry Andric } 16045ffd83dbSDimitry Andric 16055ffd83dbSDimitry Andric /// Returns true if \p V is a matrix value in the given subprogram. 16065ffd83dbSDimitry Andric bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } 16075ffd83dbSDimitry Andric 16085ffd83dbSDimitry Andric /// If \p V is a matrix value, print its shape as as NumRows x NumColumns to 16095ffd83dbSDimitry Andric /// \p SS. 16105ffd83dbSDimitry Andric void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { 16115ffd83dbSDimitry Andric auto M = Inst2Matrix.find(V); 16125ffd83dbSDimitry Andric if (M == Inst2Matrix.end()) 16135ffd83dbSDimitry Andric SS << "unknown"; 16145ffd83dbSDimitry Andric else { 16155ffd83dbSDimitry Andric SS << M->second.getNumRows(); 16165ffd83dbSDimitry Andric SS << "x"; 16175ffd83dbSDimitry Andric SS << M->second.getNumColumns(); 16185ffd83dbSDimitry Andric } 16195ffd83dbSDimitry Andric } 16205ffd83dbSDimitry Andric 16215ffd83dbSDimitry Andric /// Write the called function name. Handles calls to llvm.matrix.* 16225ffd83dbSDimitry Andric /// specially: we write the name, followed by the dimensions of the input 16235ffd83dbSDimitry Andric /// matrixes, followed by the scalar type name. 
16245ffd83dbSDimitry Andric void writeFnName(CallInst *CI) { 16255ffd83dbSDimitry Andric if (!CI->getCalledFunction()) 16265ffd83dbSDimitry Andric write("<no called fn>"); 16275ffd83dbSDimitry Andric else { 16285ffd83dbSDimitry Andric StringRef Name = CI->getCalledFunction()->getName(); 16295ffd83dbSDimitry Andric if (!Name.startswith("llvm.matrix")) { 16305ffd83dbSDimitry Andric write(Name); 16315ffd83dbSDimitry Andric return; 16325ffd83dbSDimitry Andric } 16335ffd83dbSDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI); 16345ffd83dbSDimitry Andric write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {})) 16355ffd83dbSDimitry Andric .drop_front(StringRef("llvm.matrix.").size())); 16365ffd83dbSDimitry Andric write("."); 1637*e8d8bef9SDimitry Andric std::string Tmp; 16385ffd83dbSDimitry Andric raw_string_ostream SS(Tmp); 16395ffd83dbSDimitry Andric 16405ffd83dbSDimitry Andric switch (II->getIntrinsicID()) { 16415ffd83dbSDimitry Andric case Intrinsic::matrix_multiply: 16425ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(0), SS); 16435ffd83dbSDimitry Andric SS << "."; 16445ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(1), SS); 16455ffd83dbSDimitry Andric SS << "." << *II->getType()->getScalarType(); 16465ffd83dbSDimitry Andric break; 16475ffd83dbSDimitry Andric case Intrinsic::matrix_transpose: 16485ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(0), SS); 16495ffd83dbSDimitry Andric SS << "." << *II->getType()->getScalarType(); 16505ffd83dbSDimitry Andric break; 16515ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 16525ffd83dbSDimitry Andric prettyPrintMatrixType(II, SS); 16535ffd83dbSDimitry Andric SS << "." << *II->getType()->getScalarType(); 16545ffd83dbSDimitry Andric break; 16555ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 16565ffd83dbSDimitry Andric prettyPrintMatrixType(II->getOperand(0), SS); 16575ffd83dbSDimitry Andric SS << "." 
<< *II->getOperand(0)->getType()->getScalarType(); 16585ffd83dbSDimitry Andric break; 16595ffd83dbSDimitry Andric default: 16605ffd83dbSDimitry Andric llvm_unreachable("Unhandled case"); 16615ffd83dbSDimitry Andric } 16625ffd83dbSDimitry Andric SS.flush(); 16635ffd83dbSDimitry Andric write(Tmp); 16645ffd83dbSDimitry Andric } 16655ffd83dbSDimitry Andric } 16665ffd83dbSDimitry Andric 16675ffd83dbSDimitry Andric unsigned getNumShapeArgs(CallInst *CI) const { 16685ffd83dbSDimitry Andric if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { 16695ffd83dbSDimitry Andric switch (II->getIntrinsicID()) { 16705ffd83dbSDimitry Andric case Intrinsic::matrix_multiply: 16715ffd83dbSDimitry Andric return 3; 16725ffd83dbSDimitry Andric case Intrinsic::matrix_transpose: 16735ffd83dbSDimitry Andric return 2; 16745ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load: 16755ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store: 16765ffd83dbSDimitry Andric return 3; 16775ffd83dbSDimitry Andric default: 16785ffd83dbSDimitry Andric return 0; 16795ffd83dbSDimitry Andric } 16805ffd83dbSDimitry Andric } 16815ffd83dbSDimitry Andric return 0; 16825ffd83dbSDimitry Andric } 16835ffd83dbSDimitry Andric 16845ffd83dbSDimitry Andric /// Special printing for values: for pointers, we print if they refer to an 16855ffd83dbSDimitry Andric /// (function) external address or a stack address, for other values we 16865ffd83dbSDimitry Andric /// either print the constant or "scalar"/"matrix" for other values. 
16875ffd83dbSDimitry Andric void write(Value *V) { 16885ffd83dbSDimitry Andric V = getUnderlyingObjectThroughLoads(V); 16895ffd83dbSDimitry Andric if (V->getType()->isPointerTy()) { 16905ffd83dbSDimitry Andric if (isa<AllocaInst>(V)) { 16915ffd83dbSDimitry Andric Stream << "stack addr"; 16925ffd83dbSDimitry Andric LineLength += StringRef("stack addr").size(); 16935ffd83dbSDimitry Andric } else { 16945ffd83dbSDimitry Andric Stream << "addr"; 16955ffd83dbSDimitry Andric LineLength += StringRef("addr").size(); 16965ffd83dbSDimitry Andric } 16975ffd83dbSDimitry Andric if (!V->getName().empty()) { 16985ffd83dbSDimitry Andric Stream << " %" << V->getName() << ""; 16995ffd83dbSDimitry Andric LineLength += V->getName().size() + 2; 17005ffd83dbSDimitry Andric } 17015ffd83dbSDimitry Andric return; 17025ffd83dbSDimitry Andric } 17035ffd83dbSDimitry Andric 17045ffd83dbSDimitry Andric std::string Tmp; 17055ffd83dbSDimitry Andric raw_string_ostream TmpStream(Tmp); 17065ffd83dbSDimitry Andric 17075ffd83dbSDimitry Andric if (auto *CI = dyn_cast<ConstantInt>(V)) 17085ffd83dbSDimitry Andric TmpStream << CI->getValue(); 17095ffd83dbSDimitry Andric else if (isa<Constant>(V)) 17105ffd83dbSDimitry Andric TmpStream << "constant"; 17115ffd83dbSDimitry Andric else { 17125ffd83dbSDimitry Andric if (isMatrix(V)) 17135ffd83dbSDimitry Andric TmpStream << "matrix"; 17145ffd83dbSDimitry Andric else 17155ffd83dbSDimitry Andric TmpStream << "scalar"; 17165ffd83dbSDimitry Andric } 17175ffd83dbSDimitry Andric TmpStream.flush(); 17185ffd83dbSDimitry Andric Tmp = std::string(StringRef(Tmp).trim()); 17195ffd83dbSDimitry Andric LineLength += Tmp.size(); 17205ffd83dbSDimitry Andric Stream << Tmp; 17215ffd83dbSDimitry Andric } 17225ffd83dbSDimitry Andric 17235ffd83dbSDimitry Andric /// Linearize expression \p Expr starting at an indentation of \p Indent. 
17245ffd83dbSDimitry Andric /// Expressions that are re-used multiple times are prefixed with (reused) 17255ffd83dbSDimitry Andric /// at the re-used root instruction. 17265ffd83dbSDimitry Andric void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused, 17275ffd83dbSDimitry Andric bool ParentShared) { 17285ffd83dbSDimitry Andric auto *I = cast<Instruction>(Expr); 17295ffd83dbSDimitry Andric maybeIndent(Indent); 17305ffd83dbSDimitry Andric SmallVector<Value *, 8> Ops; 17315ffd83dbSDimitry Andric 17325ffd83dbSDimitry Andric // Is Expr shared with other expression leaves? 17335ffd83dbSDimitry Andric bool ExprShared = false; 17345ffd83dbSDimitry Andric 17355ffd83dbSDimitry Andric // Deal with shared subtrees. Mark them as shared, if required. 17365ffd83dbSDimitry Andric if (!ParentShared) { 17375ffd83dbSDimitry Andric auto SI = Shared.find(Expr); 17385ffd83dbSDimitry Andric assert(SI != Shared.end() && SI->second.count(Leaf)); 17395ffd83dbSDimitry Andric 17405ffd83dbSDimitry Andric for (Value *S : SI->second) { 17415ffd83dbSDimitry Andric if (S == Leaf) 17425ffd83dbSDimitry Andric continue; 17435ffd83dbSDimitry Andric DebugLoc DL = cast<Instruction>(S)->getDebugLoc(); 17445ffd83dbSDimitry Andric write("shared with remark at line " + std::to_string(DL.getLine()) + 17455ffd83dbSDimitry Andric " column " + std::to_string(DL.getCol()) + " ("); 17465ffd83dbSDimitry Andric } 17475ffd83dbSDimitry Andric ExprShared = SI->second.size() > 1; 17485ffd83dbSDimitry Andric } 17495ffd83dbSDimitry Andric 17505ffd83dbSDimitry Andric bool Reused = !ReusedExprs.insert(Expr).second; 17515ffd83dbSDimitry Andric if (Reused && !ParentReused) 17525ffd83dbSDimitry Andric write("(reused) "); 17535ffd83dbSDimitry Andric 17545ffd83dbSDimitry Andric if (auto *CI = dyn_cast<CallInst>(I)) { 17555ffd83dbSDimitry Andric writeFnName(CI); 17565ffd83dbSDimitry Andric 17575ffd83dbSDimitry Andric Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI)); 17585ffd83dbSDimitry Andric } else 
if (isa<BitCastInst>(Expr)) { 17595ffd83dbSDimitry Andric // Special case bitcasts, which are used to materialize matrixes from 17605ffd83dbSDimitry Andric // non-matrix ops. 17615ffd83dbSDimitry Andric write("matrix"); 17625ffd83dbSDimitry Andric return; 17635ffd83dbSDimitry Andric } else { 17645ffd83dbSDimitry Andric Ops.append(I->value_op_begin(), I->value_op_end()); 17655ffd83dbSDimitry Andric write(std::string(I->getOpcodeName())); 17665ffd83dbSDimitry Andric } 17675ffd83dbSDimitry Andric 17685ffd83dbSDimitry Andric write(std::string("(")); 17695ffd83dbSDimitry Andric 17705ffd83dbSDimitry Andric unsigned NumOpsToBreak = 1; 17715ffd83dbSDimitry Andric if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>())) 17725ffd83dbSDimitry Andric NumOpsToBreak = 2; 17735ffd83dbSDimitry Andric 17745ffd83dbSDimitry Andric for (Value *Op : Ops) { 17755ffd83dbSDimitry Andric if (Ops.size() > NumOpsToBreak) 17765ffd83dbSDimitry Andric lineBreak(); 17775ffd83dbSDimitry Andric 17785ffd83dbSDimitry Andric maybeIndent(Indent + 1); 17795ffd83dbSDimitry Andric if (isMatrix(Op)) 17805ffd83dbSDimitry Andric linearizeExpr(Op, Indent + 1, Reused, ExprShared); 17815ffd83dbSDimitry Andric else 17825ffd83dbSDimitry Andric write(Op); 17835ffd83dbSDimitry Andric if (Op != Ops.back()) 17845ffd83dbSDimitry Andric write(", "); 17855ffd83dbSDimitry Andric } 17865ffd83dbSDimitry Andric 17875ffd83dbSDimitry Andric write(")"); 17885ffd83dbSDimitry Andric } 17895ffd83dbSDimitry Andric 17905ffd83dbSDimitry Andric const std::string &getResult() { 17915ffd83dbSDimitry Andric Stream.flush(); 17925ffd83dbSDimitry Andric return Str; 17935ffd83dbSDimitry Andric } 17945ffd83dbSDimitry Andric }; 17955ffd83dbSDimitry Andric 17965ffd83dbSDimitry Andric /// Generate remarks for matrix operations in a function. To generate remarks 17975ffd83dbSDimitry Andric /// for matrix expressions, the following approach is used: 17985ffd83dbSDimitry Andric /// 1. 
Use the inlined-at debug information to group matrix operations to the 17995ffd83dbSDimitry Andric /// DISubprograms they are contained in. 18005ffd83dbSDimitry Andric /// 2. Collect leaves of matrix expressions (done in 18015ffd83dbSDimitry Andric /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression 18025ffd83dbSDimitry Andric // mapping. Leaves are lowered matrix instructions without other matrix 18035ffd83dbSDimitry Andric // users (like stores) in the current subprogram. 18045ffd83dbSDimitry Andric /// 3. For each leaf, create a remark containing a linearizied version of the 18055ffd83dbSDimitry Andric /// matrix expression. The expression is linearized by a recursive 18065ffd83dbSDimitry Andric /// bottom-up traversal of the matrix operands, starting at a leaf. Note 18075ffd83dbSDimitry Andric /// that multiple leaves can share sub-expressions. Shared subexpressions 18085ffd83dbSDimitry Andric /// are explicitly marked as shared(). 18095ffd83dbSDimitry Andric struct RemarkGenerator { 18105ffd83dbSDimitry Andric const MapVector<Value *, MatrixTy> &Inst2Matrix; 18115ffd83dbSDimitry Andric OptimizationRemarkEmitter &ORE; 18125ffd83dbSDimitry Andric Function &Func; 18135ffd83dbSDimitry Andric const DataLayout &DL; 18145ffd83dbSDimitry Andric 18155ffd83dbSDimitry Andric RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix, 18165ffd83dbSDimitry Andric OptimizationRemarkEmitter &ORE, Function &Func) 18175ffd83dbSDimitry Andric : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func), 18185ffd83dbSDimitry Andric DL(Func.getParent()->getDataLayout()) {} 18195ffd83dbSDimitry Andric 18205ffd83dbSDimitry Andric /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are 18215ffd83dbSDimitry Andric /// instructions in Inst2Matrix returning void or without any users in 18225ffd83dbSDimitry Andric /// \p ExprsInSubprogram. Currently that should only include stores. 
18235ffd83dbSDimitry Andric SmallVector<Value *, 4> 18245ffd83dbSDimitry Andric getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) { 18255ffd83dbSDimitry Andric SmallVector<Value *, 4> Leaves; 18265ffd83dbSDimitry Andric for (auto *Expr : ExprsInSubprogram) 18275ffd83dbSDimitry Andric if (Expr->getType()->isVoidTy() || 18285ffd83dbSDimitry Andric !any_of(Expr->users(), [&ExprsInSubprogram](User *U) { 18295ffd83dbSDimitry Andric return ExprsInSubprogram.count(U); 18305ffd83dbSDimitry Andric })) 18315ffd83dbSDimitry Andric Leaves.push_back(Expr); 18325ffd83dbSDimitry Andric return Leaves; 18335ffd83dbSDimitry Andric } 18345ffd83dbSDimitry Andric 18355ffd83dbSDimitry Andric /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf 18365ffd83dbSDimitry Andric /// to all visited expressions in \p Shared. Limit the matrix operations to 18375ffd83dbSDimitry Andric /// the ones in \p ExprsInSubprogram. 18385ffd83dbSDimitry Andric void collectSharedInfo(Value *Leaf, Value *V, 18395ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 18405ffd83dbSDimitry Andric DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) { 18415ffd83dbSDimitry Andric 18425ffd83dbSDimitry Andric if (!ExprsInSubprogram.count(V)) 18435ffd83dbSDimitry Andric return; 18445ffd83dbSDimitry Andric 18455ffd83dbSDimitry Andric auto I = Shared.insert({V, {}}); 18465ffd83dbSDimitry Andric I.first->second.insert(Leaf); 18475ffd83dbSDimitry Andric 18485ffd83dbSDimitry Andric for (Value *Op : cast<Instruction>(V)->operand_values()) 18495ffd83dbSDimitry Andric collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared); 18505ffd83dbSDimitry Andric } 18515ffd83dbSDimitry Andric 18525ffd83dbSDimitry Andric /// Calculate the number of exclusive and shared op counts for expression 18535ffd83dbSDimitry Andric /// starting at \p V. Expressions used multiple times are counted once. 
18545ffd83dbSDimitry Andric /// Limit the matrix operations to the ones in \p ExprsInSubprogram. 18555ffd83dbSDimitry Andric std::pair<OpInfoTy, OpInfoTy> 18565ffd83dbSDimitry Andric sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs, 18575ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 18585ffd83dbSDimitry Andric DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const { 18595ffd83dbSDimitry Andric if (!ExprsInSubprogram.count(Root)) 18605ffd83dbSDimitry Andric return {}; 18615ffd83dbSDimitry Andric 18625ffd83dbSDimitry Andric // Already counted this expression. Stop. 18635ffd83dbSDimitry Andric if (!ReusedExprs.insert(Root).second) 18645ffd83dbSDimitry Andric return {}; 18655ffd83dbSDimitry Andric 18665ffd83dbSDimitry Andric OpInfoTy SharedCount; 18675ffd83dbSDimitry Andric OpInfoTy Count; 18685ffd83dbSDimitry Andric 18695ffd83dbSDimitry Andric auto I = Shared.find(Root); 18705ffd83dbSDimitry Andric auto CM = Inst2Matrix.find(Root); 18715ffd83dbSDimitry Andric if (I->second.size() == 1) 18725ffd83dbSDimitry Andric Count = CM->second.getOpInfo(); 18735ffd83dbSDimitry Andric else 18745ffd83dbSDimitry Andric SharedCount = CM->second.getOpInfo(); 18755ffd83dbSDimitry Andric 18765ffd83dbSDimitry Andric for (Value *Op : cast<Instruction>(Root)->operand_values()) { 18775ffd83dbSDimitry Andric auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared); 18785ffd83dbSDimitry Andric Count += C.first; 18795ffd83dbSDimitry Andric SharedCount += C.second; 18805ffd83dbSDimitry Andric } 18815ffd83dbSDimitry Andric return {Count, SharedCount}; 18825ffd83dbSDimitry Andric } 18835ffd83dbSDimitry Andric 18845ffd83dbSDimitry Andric void emitRemarks() { 18855ffd83dbSDimitry Andric if (!ORE.allowExtraAnalysis(DEBUG_TYPE)) 18865ffd83dbSDimitry Andric return; 18875ffd83dbSDimitry Andric 18885ffd83dbSDimitry Andric // Map matrix operations to their containting subprograms, by traversing 18895ffd83dbSDimitry Andric // the inlinedAt chain. 
If the function does not have a DISubprogram, we 18905ffd83dbSDimitry Andric // only map them to the containing function. 18915ffd83dbSDimitry Andric MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs; 18925ffd83dbSDimitry Andric for (auto &KV : Inst2Matrix) { 18935ffd83dbSDimitry Andric if (Func.getSubprogram()) { 18945ffd83dbSDimitry Andric auto *I = cast<Instruction>(KV.first); 18955ffd83dbSDimitry Andric DILocation *Context = I->getDebugLoc(); 18965ffd83dbSDimitry Andric while (Context) { 18975ffd83dbSDimitry Andric auto I = 18985ffd83dbSDimitry Andric Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}}); 18995ffd83dbSDimitry Andric I.first->second.push_back(KV.first); 19005ffd83dbSDimitry Andric Context = DebugLoc(Context).getInlinedAt(); 19015ffd83dbSDimitry Andric } 19025ffd83dbSDimitry Andric } else { 19035ffd83dbSDimitry Andric auto I = Subprog2Exprs.insert({nullptr, {}}); 19045ffd83dbSDimitry Andric I.first->second.push_back(KV.first); 19055ffd83dbSDimitry Andric } 19065ffd83dbSDimitry Andric } 19075ffd83dbSDimitry Andric for (auto &KV : Subprog2Exprs) { 19085ffd83dbSDimitry Andric SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(), 19095ffd83dbSDimitry Andric KV.second.end()); 19105ffd83dbSDimitry Andric auto Leaves = getExpressionLeaves(ExprsInSubprogram); 19115ffd83dbSDimitry Andric 19125ffd83dbSDimitry Andric DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared; 19135ffd83dbSDimitry Andric for (Value *Leaf : Leaves) 19145ffd83dbSDimitry Andric collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared); 19155ffd83dbSDimitry Andric 19165ffd83dbSDimitry Andric // Generate remarks for each leaf. 
19175ffd83dbSDimitry Andric for (auto *L : Leaves) { 19185ffd83dbSDimitry Andric 19195ffd83dbSDimitry Andric DebugLoc Loc = cast<Instruction>(L)->getDebugLoc(); 19205ffd83dbSDimitry Andric DILocation *Context = cast<Instruction>(L)->getDebugLoc(); 19215ffd83dbSDimitry Andric while (Context) { 19225ffd83dbSDimitry Andric if (getSubprogram(Context->getScope()) == KV.first) { 19235ffd83dbSDimitry Andric Loc = Context; 19245ffd83dbSDimitry Andric break; 19255ffd83dbSDimitry Andric } 19265ffd83dbSDimitry Andric Context = DebugLoc(Context).getInlinedAt(); 19275ffd83dbSDimitry Andric } 19285ffd83dbSDimitry Andric 19295ffd83dbSDimitry Andric SmallPtrSet<Value *, 8> ReusedExprs; 19305ffd83dbSDimitry Andric OpInfoTy Counts, SharedCounts; 19315ffd83dbSDimitry Andric std::tie(Counts, SharedCounts) = 19325ffd83dbSDimitry Andric sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared); 19335ffd83dbSDimitry Andric 19345ffd83dbSDimitry Andric OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc, 19355ffd83dbSDimitry Andric cast<Instruction>(L)->getParent()); 19365ffd83dbSDimitry Andric 19375ffd83dbSDimitry Andric Rem << "Lowered with "; 19385ffd83dbSDimitry Andric Rem << ore::NV("NumStores", Counts.NumStores) << " stores, " 19395ffd83dbSDimitry Andric << ore::NV("NumLoads", Counts.NumLoads) << " loads, " 19405ffd83dbSDimitry Andric << ore::NV("NumComputeOps", Counts.NumComputeOps) 19415ffd83dbSDimitry Andric << " compute ops"; 19425ffd83dbSDimitry Andric 19435ffd83dbSDimitry Andric if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 || 19445ffd83dbSDimitry Andric SharedCounts.NumComputeOps > 0) { 19455ffd83dbSDimitry Andric Rem << ",\nadditionally " 19465ffd83dbSDimitry Andric << ore::NV("NumStores", SharedCounts.NumStores) << " stores, " 19475ffd83dbSDimitry Andric << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, " 19485ffd83dbSDimitry Andric << ore::NV("NumFPOps", SharedCounts.NumComputeOps) 19495ffd83dbSDimitry Andric << " compute ops" 19505ffd83dbSDimitry 
Andric << " are shared with other expressions"; 19515ffd83dbSDimitry Andric } 19525ffd83dbSDimitry Andric 19535ffd83dbSDimitry Andric Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL)); 19545ffd83dbSDimitry Andric ORE.emit(Rem); 19555ffd83dbSDimitry Andric } 19565ffd83dbSDimitry Andric } 19575ffd83dbSDimitry Andric } 19585ffd83dbSDimitry Andric 19595ffd83dbSDimitry Andric std::string 19605ffd83dbSDimitry Andric linearize(Value *L, 19615ffd83dbSDimitry Andric const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 19625ffd83dbSDimitry Andric const SmallSetVector<Value *, 32> &ExprsInSubprogram, 19635ffd83dbSDimitry Andric const DataLayout &DL) { 19645ffd83dbSDimitry Andric ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L); 19655ffd83dbSDimitry Andric Lin.linearizeExpr(L, 0, false, false); 19665ffd83dbSDimitry Andric return Lin.getResult(); 19675ffd83dbSDimitry Andric } 19685ffd83dbSDimitry Andric }; 1969480093f4SDimitry Andric }; 1970480093f4SDimitry Andric } // namespace 1971480093f4SDimitry Andric 1972480093f4SDimitry Andric PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, 1973480093f4SDimitry Andric FunctionAnalysisManager &AM) { 1974480093f4SDimitry Andric auto &TTI = AM.getResult<TargetIRAnalysis>(F); 1975*e8d8bef9SDimitry Andric OptimizationRemarkEmitter *ORE = nullptr; 1976*e8d8bef9SDimitry Andric AAResults *AA = nullptr; 1977*e8d8bef9SDimitry Andric DominatorTree *DT = nullptr; 1978*e8d8bef9SDimitry Andric LoopInfo *LI = nullptr; 1979*e8d8bef9SDimitry Andric 1980*e8d8bef9SDimitry Andric if (!Minimal) { 1981*e8d8bef9SDimitry Andric ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F); 1982*e8d8bef9SDimitry Andric AA = &AM.getResult<AAManager>(F); 1983*e8d8bef9SDimitry Andric DT = &AM.getResult<DominatorTreeAnalysis>(F); 1984*e8d8bef9SDimitry Andric LI = &AM.getResult<LoopAnalysis>(F); 1985*e8d8bef9SDimitry Andric } 19865ffd83dbSDimitry Andric 19875ffd83dbSDimitry Andric LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, 
ORE); 1988480093f4SDimitry Andric if (LMT.Visit()) { 1989480093f4SDimitry Andric PreservedAnalyses PA; 1990*e8d8bef9SDimitry Andric if (!Minimal) { 1991*e8d8bef9SDimitry Andric PA.preserve<LoopAnalysis>(); 1992*e8d8bef9SDimitry Andric PA.preserve<DominatorTreeAnalysis>(); 1993*e8d8bef9SDimitry Andric } 1994480093f4SDimitry Andric return PA; 1995480093f4SDimitry Andric } 1996480093f4SDimitry Andric return PreservedAnalyses::all(); 1997480093f4SDimitry Andric } 1998480093f4SDimitry Andric 1999480093f4SDimitry Andric namespace { 2000480093f4SDimitry Andric 2001480093f4SDimitry Andric class LowerMatrixIntrinsicsLegacyPass : public FunctionPass { 2002480093f4SDimitry Andric public: 2003480093f4SDimitry Andric static char ID; 2004480093f4SDimitry Andric 2005480093f4SDimitry Andric LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) { 2006480093f4SDimitry Andric initializeLowerMatrixIntrinsicsLegacyPassPass( 2007480093f4SDimitry Andric *PassRegistry::getPassRegistry()); 2008480093f4SDimitry Andric } 2009480093f4SDimitry Andric 2010480093f4SDimitry Andric bool runOnFunction(Function &F) override { 20115ffd83dbSDimitry Andric auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 20125ffd83dbSDimitry Andric auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); 20135ffd83dbSDimitry Andric auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); 20145ffd83dbSDimitry Andric auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 20155ffd83dbSDimitry Andric auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 2016*e8d8bef9SDimitry Andric LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE); 2017480093f4SDimitry Andric bool C = LMT.Visit(); 2018480093f4SDimitry Andric return C; 2019480093f4SDimitry Andric } 2020480093f4SDimitry Andric 2021480093f4SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 2022480093f4SDimitry Andric AU.addRequired<TargetTransformInfoWrapperPass>(); 20235ffd83dbSDimitry 
Andric AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); 20245ffd83dbSDimitry Andric AU.addRequired<AAResultsWrapperPass>(); 20255ffd83dbSDimitry Andric AU.addRequired<DominatorTreeWrapperPass>(); 20265ffd83dbSDimitry Andric AU.addPreserved<DominatorTreeWrapperPass>(); 20275ffd83dbSDimitry Andric AU.addRequired<LoopInfoWrapperPass>(); 20285ffd83dbSDimitry Andric AU.addPreserved<LoopInfoWrapperPass>(); 2029480093f4SDimitry Andric } 2030480093f4SDimitry Andric }; 2031480093f4SDimitry Andric } // namespace 2032480093f4SDimitry Andric 2033480093f4SDimitry Andric static const char pass_name[] = "Lower the matrix intrinsics"; 2034480093f4SDimitry Andric char LowerMatrixIntrinsicsLegacyPass::ID = 0; 2035480093f4SDimitry Andric INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, 2036480093f4SDimitry Andric false, false) 20375ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) 20385ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) 20395ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 20405ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 2041480093f4SDimitry Andric INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, 2042480093f4SDimitry Andric false, false) 2043480093f4SDimitry Andric 2044480093f4SDimitry Andric Pass *llvm::createLowerMatrixIntrinsicsPass() { 2045480093f4SDimitry Andric return new LowerMatrixIntrinsicsLegacyPass(); 2046480093f4SDimitry Andric } 2047*e8d8bef9SDimitry Andric 2048*e8d8bef9SDimitry Andric namespace { 2049*e8d8bef9SDimitry Andric 2050*e8d8bef9SDimitry Andric /// A lightweight version of the matrix lowering pass that only requires TTI. 2051*e8d8bef9SDimitry Andric /// Advanced features that require DT, AA or ORE like tiling are disabled. 
This 2052*e8d8bef9SDimitry Andric /// is used to lower matrix intrinsics if the main lowering pass is not run, for 2053*e8d8bef9SDimitry Andric /// example with -O0. 2054*e8d8bef9SDimitry Andric class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass { 2055*e8d8bef9SDimitry Andric public: 2056*e8d8bef9SDimitry Andric static char ID; 2057*e8d8bef9SDimitry Andric 2058*e8d8bef9SDimitry Andric LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) { 2059*e8d8bef9SDimitry Andric initializeLowerMatrixIntrinsicsMinimalLegacyPassPass( 2060*e8d8bef9SDimitry Andric *PassRegistry::getPassRegistry()); 2061*e8d8bef9SDimitry Andric } 2062*e8d8bef9SDimitry Andric 2063*e8d8bef9SDimitry Andric bool runOnFunction(Function &F) override { 2064*e8d8bef9SDimitry Andric auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 2065*e8d8bef9SDimitry Andric LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr); 2066*e8d8bef9SDimitry Andric bool C = LMT.Visit(); 2067*e8d8bef9SDimitry Andric return C; 2068*e8d8bef9SDimitry Andric } 2069*e8d8bef9SDimitry Andric 2070*e8d8bef9SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 2071*e8d8bef9SDimitry Andric AU.addRequired<TargetTransformInfoWrapperPass>(); 2072*e8d8bef9SDimitry Andric AU.setPreservesCFG(); 2073*e8d8bef9SDimitry Andric } 2074*e8d8bef9SDimitry Andric }; 2075*e8d8bef9SDimitry Andric } // namespace 2076*e8d8bef9SDimitry Andric 2077*e8d8bef9SDimitry Andric static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)"; 2078*e8d8bef9SDimitry Andric char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0; 2079*e8d8bef9SDimitry Andric INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass, 2080*e8d8bef9SDimitry Andric "lower-matrix-intrinsics-minimal", pass_name_minimal, 2081*e8d8bef9SDimitry Andric false, false) 2082*e8d8bef9SDimitry Andric INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass, 2083*e8d8bef9SDimitry Andric 
"lower-matrix-intrinsics-minimal", pass_name_minimal, false, 2084*e8d8bef9SDimitry Andric false) 2085*e8d8bef9SDimitry Andric 2086*e8d8bef9SDimitry Andric Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() { 2087*e8d8bef9SDimitry Andric return new LowerMatrixIntrinsicsMinimalLegacyPass(); 2088*e8d8bef9SDimitry Andric } 2089