//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Implement multiply & add fusion
//  * Add remark, summarizing the available matrix optimization opportunities.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation("matrix-propagate-shape",
                                            cl::init(true));

static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "produce different results, due to less rounding error."));

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of column \p Col with type (\p EltType x \p NumRows),
// assuming \p Stride elements between the start of two consecutive columns.
// \p Stride must be >= \p NumRows.
//
// Consider a 4x4 matrix like below
//
//      0     1     2     3
// 0  v_0_0 v_0_1 v_0_2 v_0_3
// 1  v_1_0 v_1_1 v_1_2 v_1_3
// 2  v_2_0 v_2_1 v_2_2 v_2_3
// 3  v_3_0 v_3_1 v_3_2 v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeColumnAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |     |      |
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
//
Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
                         unsigned NumRows, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
         "Stride must be >= the number of rows.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the column with index Col as Col * Stride.
  Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start");

  // Get pointer to the start of the selected column. Skip GEP creation,
  // if we select column 0.
  if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
    ColumnStart = BasePtr;
  else
    ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep");

  // Cast elementwise column start pointer to a pointer to a column
  // (EltType x NumRows)*.
  Type *ColumnType = VectorType::get(EltType, NumRows);
  Type *ColumnPtrType = PointerType::get(ColumnType, AS);
  return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast");
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of columnwise operations, which
///       yields a set of column vectors containing the result matrix. Note
///       that we lower all instructions that have shape information. Besides
///       the intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
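///
/// As an illustrative sketch (IR simplified, names chosen for exposition),
/// a 2x2 multiply such as
///
///   %c = call <4 x float> @llvm.matrix.multiply(<4 x float> %a,
///                                               <4 x float> %b,
///                                               i32 2, i32 2, i32 2)
///
/// is lowered by splitting %a and %b into <2 x float> column vectors,
/// computing each result column as a sum of columns of %a scaled by scalars
/// extracted from the corresponding column of %b, and caching the resulting
/// column vectors as the lowering of %c.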
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;

  /// Wrapper class representing a matrix as a set of column vectors.
  /// All column vectors must have the same vector type.
  class ColumnMatrixTy {
    SmallVector<Value *, 16> Columns;

  public:
    ColumnMatrixTy() : Columns() {}
    ColumnMatrixTy(ArrayRef<Value *> Cols)
        : Columns(Cols.begin(), Cols.end()) {}

    Value *getColumn(unsigned i) const { return Columns[i]; }

    void setColumn(unsigned i, Value *V) { Columns[i] = V; }

    size_t getNumColumns() const { return Columns.size(); }
    size_t getNumRows() const {
      assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
      return cast<VectorType>(Columns[0]->getType())->getNumElements();
    }

    const SmallVectorImpl<Value *> &getColumnVectors() const {
      return Columns;
    }

    SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }

    void addColumn(Value *V) { Columns.push_back(V); }

    iterator_range<SmallVector<Value *, 16>::iterator> columns() {
      return make_range(Columns.begin(), Columns.end());
    }

    /// Embed the columns of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Columns.size() == 1 ? Columns[0]
                                 : concatenateVectors(Builder, Columns);
    }
  };

  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
          NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }
  };

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape
  /// of the result value of the instruction, with the only exceptions being
  /// store instructions and the matrix_columnwise_store intrinsics. For
  /// those, the shape information indicates that those instructions should
  /// be lowered using shape information as well.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we are not replacing
  /// all users of a lowered instruction, if shape information is available.
  /// The obsolete instructions need to be removed after we have finished
  /// lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  DenseMap<Value *, ColumnMatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {}

  /// Return the set of column vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result column matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into column vectors.
  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                           IRBuilder<> Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing column matrix, if it matches the requested shape
    // information. If there is a mismatch, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      ColumnMatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
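    // For example (illustrative): a <8 x double> value holding a 4 x 2
    // matrix is split into two <4 x double> column vectors, using
    // shufflevector masks <0, 1, 2, 3> and <4, 5, 6, 7>.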
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
         MaskStart += SI.NumRows) {
      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }
  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_columnwise_load:
      case Intrinsic::matrix_columnwise_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
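  ///
  /// As an illustrative example (IR simplified), given
  ///
  ///   %a = call <4 x double> @llvm.matrix.multiply(..., i32 2, i32 2, i32 2)
  ///   %d = fadd <4 x double> %a, %b
  ///
  /// the intrinsic defines %a as 2 x 2; the fadd then inherits that shape
  /// from its operand, and its users are queued for further propagation.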
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // New entry, set the value and insert operands.
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                           m_Value(MatrixA), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst,
                       m_Intrinsic<Intrinsic::matrix_columnwise_load>(
                           m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands; if an
    // operand's shape derives from the result shape and is still unknown,
    // set it and add the operand to the worklist.
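    // For example (illustrative): if %c = fadd %a, %b has a known 2 x 2
    // shape, both %a and %b are assigned shape 2 x 2 and are added to the
    // worklist.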
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.back();
      WorkList.pop_back();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                           m_Value(MatrixA), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

  bool Visit() {
    if (EnableShapePropagation) {
      SmallVector<Instruction *, 32> WorkList;

      // Initially only the shape of matrix intrinsics is known.
      // Initialize the work list with ops carrying shape information.
      for (BasicBlock &BB : Func)
        for (Instruction &Inst : BB) {
          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
          if (!II)
            continue;

          switch (II->getIntrinsicID()) {
          case Intrinsic::matrix_multiply:
          case Intrinsic::matrix_transpose:
          case Intrinsic::matrix_columnwise_load:
          case Intrinsic::matrix_columnwise_store:
            WorkList.push_back(&Inst);
            break;
          default:
            break;
          }
        }
      // Propagate shapes until nothing changes any longer.
      while (!WorkList.empty()) {
        WorkList = propagateShapeForward(WorkList);
        WorkList = propagateShapeBackward(WorkList);
      }
    }

    ReversePostOrderTraversal<Function *> RPOT(&Func);
    bool Changed = false;
    for (auto *BB : RPOT) {
      for (Instruction &Inst : make_early_inc_range(*BB)) {
        IRBuilder<> Builder(&Inst);

        if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
          Changed |= VisitCallInst(CInst);

        Value *Op1;
        Value *Op2;
        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
          Changed |= VisitBinaryOperator(BinOp);
        if (match(&Inst, m_Load(m_Value(Op1))))
          Changed |= VisitLoad(&Inst, Op1, Builder);
        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
          Changed |= VisitStore(&Inst, Op1, Op2, Builder);
      }
    }

    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }

  LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
                             IRBuilder<> Builder) {
    unsigned Align = DL.getABITypeAlignment(EltType);
    return Builder.CreateAlignedLoad(ColumnPtr, Align, "col.load");
  }

  StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
                               Type *EltType, IRBuilder<> Builder) {
    unsigned Align = DL.getABITypeAlignment(EltType);
    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align);
  }

  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  Value *createElementPtr(Value *BasePtr, Type *EltType,
                          IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() ||
        !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_columnwise_load:
      LowerColumnwiseLoad(Inst);
      break;
    case Intrinsic::matrix_columnwise_store:
      LowerColumnwiseStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }

  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
                 ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto VType = cast<VectorType>(Inst->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    ColumnMatrixTy Result;
    // Stride is the distance between the start of one column and the start
    // of the next.
    for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C), Stride,
                            Shape.NumRows, VType->getElementType(), Builder);
      Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
      Result.addColumn(Column);
    }

    finalizeLowering(Inst, Result, Builder);
  }

  /// Lowers llvm.matrix.columnwise.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
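  ///
  /// As an illustrative sketch (IR simplified), for
  ///
  ///   %m = call <6 x double> @llvm.matrix.columnwise.load(<6 x double>* %p,
  ///                                                       i32 %s, i32 3,
  ///                                                       i32 2)
  ///
  /// the lowering emits two <3 x double> loads, one at %p and one %s elements
  /// past %p, and caches them as the two columns of %m.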
  void LowerColumnwiseLoad(CallInst *Inst) {
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Stride,
              {Inst->getArgOperand(2), Inst->getArgOperand(3)});
  }

  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
                  ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto VType = cast<VectorType>(Matrix->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    auto LM = getMatrix(Matrix, Shape, Builder);
    for (auto C : enumerate(LM.columns())) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
                            Shape.NumRows, VType->getElementType(), Builder);
      createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
    }

    ToRemove.push_back(Inst);
  }

  /// Lowers llvm.matrix.columnwise.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnwiseStore(CallInst *Inst) {
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Stride,
               {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Extract a column vector of \p NumElts starting at index (\p I, \p J)
  /// from the matrix \p LM represented as a vector of column vectors.
  Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
                       unsigned NumElts, IRBuilder<> Builder) {
    Value *Col = LM.getColumn(J);
    Value *Undef = UndefValue::get(Col->getType());
    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
    return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
  }

  // Set elements I..I+NumElts-1 to Block.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> Builder) {

    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<VectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Value *ExtendMask =
        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<Constant *, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(Builder.getInt32(i));

    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(Builder.getInt32(i - I + VecNumElts));

    for (; i < VecNumElts; i++)
      Mask.push_back(Builder.getInt32(i));

    Value *MaskVal = ConstantVector::get(Mask);

    return Builder.CreateShuffleVector(Col, Block, MaskVal);
  }

  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction) {

    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend
        // decide if that's profitable.
        Value *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst.
  /// For users with shape information, there's nothing to do: they will use
  /// the cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
                        IRBuilder<> &Builder) {
    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
      Use &U = *I++;
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Lowers llvm.matrix.multiply.
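  ///
  /// As an illustrative sketch, for a 2 x 2 multiply, column J of the result
  /// is computed as
  ///
  ///   Result[J] = Lhs[0] * splat(Rhs[J][0]) + Lhs[1] * splat(Rhs[J][1])
  ///
  /// where Lhs[K] are column vectors, Rhs[J][K] are scalars extracted from
  /// column J of the right-hand side, and the products are accumulated with
  /// vector adds (or fmuladd calls when contraction is allowed).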
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const ColumnMatrixTy &Lhs =
        getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const ColumnMatrixTy &Rhs =
        getMatrix(MatMul->getArgOperand(1), RShape, Builder);

    const unsigned R = LShape.NumRows;
    const unsigned M = LShape.NumColumns;
    const unsigned C = RShape.NumColumns;
    assert(M == RShape.NumRows);

    // Initialize the output.
    ColumnMatrixTy Result;
    for (unsigned J = 0; J < C; ++J)
      Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));

    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
                                     EltType->getPrimitiveSizeInBits(),
                                 uint64_t(1));

    bool AllowContract =
        AllowContractEnabled ||
        (isa<FPMathOperator>(MatMul) && MatMul->hasAllowContract());
    // Multiply columns from the first operand with scalars from the second
    // operand. Then move along the K axis and accumulate the columns. With
    // this the adds can be vectorized without reassociation.
    for (unsigned J = 0; J < C; ++J) {
      unsigned BlockSize = VF;
      for (unsigned I = 0; I < R; I += BlockSize) {
        // Gradually lower the vectorization factor to cover the remainder.
        while (I + BlockSize > R)
          BlockSize /= 2;

        Value *Sum = nullptr;
        for (unsigned K = 0; K < M; ++K) {
          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
                             Builder, AllowContract);
        }
        Result.setColumn(J,
                         insertVector(Result.getColumn(J), I, Sum, Builder));
      }
    }
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
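  ///
  /// As an illustrative example, a 2 x 3 input with columns (a b), (c d) and
  /// (e f) is lowered to a 3 x 2 result with columns (a c e) and (b d f):
  /// result column Row is built by extracting element Row from each input
  /// column.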
  void LowerTranspose(CallInst *Inst) {
    ColumnMatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
      // Build a single column vector for this row. First initialize it.
      Value *ResultColumn = UndefValue::get(
          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));

      // Go through the elements of this row and insert them into the
      // resulting column vector.
      for (auto C : enumerate(InputMatrix.columns())) {
        Value *Elt = Builder.CreateExtractElement(C.value(), Row);
        // We insert at index Column since that is the row index after the
        // transpose.
        ResultColumn =
            Builder.CreateInsertElement(ResultColumn, Elt, C.index());
      }
      Result.addColumn(ResultColumn);
    }

    finalizeLowering(Inst, Result, Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
    return true;
  }

  /// Lower store instructions, if shape information is available.
  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
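  ///
  /// For example (illustrative), an fadd of two values with a known 2 x 2
  /// shape and double element type is lowered to two fadds on <2 x double>
  /// column vectors.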
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);

    // Apply the operation to each pair of columns and store the result
    // columns in the result matrix.
    ColumnMatrixTy Result;
    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };
    for (unsigned C = 0; C < Shape.NumColumns; ++C)
      Result.addColumn(
          BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));

    finalizeLowering(Inst, Result, Builder);
    return true;
  }
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  LowerMatrixIntrinsics LMT(F, TTI);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, *TTI);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}