Lines Matching +full:matrix +full:- +full:matrix +full:- +full:transpose
1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Lower matrix intrinsics to vector operations.
13 // * Support more cases, e.g. multiply-add, multiply-sub, operands/results
15 // * Improve cost-modeling, e.g. choose different number of rows/columns
18 //===----------------------------------------------------------------------===//
53 #define DEBUG_TYPE "lower-matrix-intrinsics"
56 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
57 cl::desc("Enable/disable fusing matrix instructions."));
58 // TODO: Allow and use non-square tiles.
60 "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
62 "Tile size for matrix instruction fusion using square-shaped tiles."));
63 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
67 "force-fuse-matrix", cl::init(false), cl::Hidden,
68 cl::desc("Force matrix instruction fusion even if not profitable."));
70 "matrix-allow-contract", cl::init(false), cl::Hidden,
75 VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
76 cl::desc("Enable/disable matrix shape verification."),
82 "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
83 cl::desc("Sets the default matrix layout"),
84 cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
85 "Use column-major layout"),
86 clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
87 "Use row-major layout")));
89 static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
97 return cast<DILocalScope>(Scope)->getSubprogram();
101 /// matrix with a scalar).
104 return SV->isZeroEltSplat();
122 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
126 // For column-major matrixes, the function computes the address of a column
128 // (= number of rows of the matrix). For row-major matrixes, the function
130 // number of elements in a column (= number of columns of the matrix).
132 // Consider a 4x4 matrix in column-major layout like below
140 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
143 // of the sub-matrix.
146 // -> just returns Base
148 // -> returns Base + (1 * 4)
150 // -> returns Base + (2 * 4)
167 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
175 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
195 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
196 cast<ConstantInt>(NumColumns)->getZExtValue()) {}
203 /// Returns true if shape-information is defined, meaning both dimensions
232 switch (I->getOpcode()) {
272 return OpShape->second;
277 for (auto &Op : I->operands()) {
280 return OpShape->second;
286 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
288 /// Currently, the lowering for each matrix intrinsic is done as follows:
291 /// 2. Lower instructions with shape information (assuming column-major layout).
292 /// The lowering works similarly using row-major layout.
295 /// If not, split the operand vector containing an embedded matrix into
298 /// yields a set of column vectors containing result matrix. Note that we
303 /// column matrix when lowering the user. For other uses, we embed the
304 /// result matrix in a flat vector and update the use.
305 /// 2.4. Cache the result column matrix for the instruction we lowered
319 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
321 /// Number of stores emitted to generate this matrix.
323 /// Number of loads emitted to generate this matrix.
325 /// Number of compute operations emitted to generate this matrix.
327 /// Most of the time transposes can be fused with matrix multiplies or can
341 /// Wrapper class representing a matrix as a set of vectors, either in row or
366 assert(isColumnMajor() && "only supported for column-major matrixes");
370 assert(!isColumnMajor() && "only supported for row-major matrixes");
376 Type *getElementType() const { return getVectorTy()->getElementType(); }
389 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
395 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
402 assert(isColumnMajor() && "only supported for column-major matrixes");
407 return cast<VectorType>(Vectors[0]->getType());
412 "columns() only supported for column-major matrixes");
420 /// Embed the vectors of the matrix into a flat vector by concatenating
464 /// matrix is column-major, the result vector is extracted from a column
469 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
485 /// replacement is also a matrix operation, use
499 /// Map from instructions to their produced column matrix.
507 FMF = Inst->getFastMathFlags();
521 return getNumOps(VT->getScalarType(),
522 cast<FixedVectorType>(VT)->getNumElements());
533 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
539 /// Return the set of vectors that a matrix value is lowered to.
541 /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise
542 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
546 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
548 assert(cast<FixedVectorType>(VType)->getNumElements() ==
550 "The vector size must match the number of matrix elements");
553 // return the existing matrix, if it matches the requested shape
554 // information. If there is a mis-match, embed the result in a flat
558 MatrixTy &M = Found->second;
559 // Return the found matrix, if its shape matches the requested shape
570 MaskStart < cast<FixedVectorType>(VType)->getNumElements();
590 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
591 SIter->second.NumColumns != Shape.NumColumns)) {
592 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
593 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
596 "Matrix shape verification failed, compilation aborted!");
600 << SIter->second.NumRows << " "
601 << SIter->second.NumColumns << " for " << *V << "\n");
620 switch (II->getIntrinsicID()) {
634 /// either based on the information provided by matrix intrinsics or known
642 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
653 for (auto *User : Inst->users())
677 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
712 // Nothing to do, no matrix input.
714 // Nothing to do. We forward-propagated to this so we would just
719 for (Use &U : cast<Instruction>(V)->operands()) {
728 for (User *U : WorkList[I]->users())
735 /// (Op0 op Op1)^T -> Op0^T op Op1^T
736 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
744 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
749 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
760 Inst->eraseFromParent();
769 if (!Inst->use_empty())
786 ShapeMap.insert({New, S->second});
791 /// Sink a top-level transpose inside matmuls and adds.
806 // Transpose of a transpose is a nop
815 // k^T -> k
822 // (A * B)^t -> B^t * A^t
842 // (A * k)^t -> A^t * k
852 bool IsFP = I.getType()->isFPOrFPVectorTy();
865 // (A + B)^t -> A^t + B^t
872 bool IsFP = I.getType()->isFPOrFPVectorTy();
894 if (A->use_empty())
896 if (A != B && B->use_empty())
902 // A^t * B ^t -> (B * A)^t
911 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
913 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
914 R->getZExtValue());
918 // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
919 // the shape of the second transpose is different, there's a shape conflict
930 Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
960 // If we have a TT matmul or a TT add, lift the transpose. We may be able
972 // Initially only the shape of matrix intrinsics is known.
980 switch (II->getIntrinsicID()) {
992 // Avoid unnecessary work if there are no matrix intrinsics in the function.
997 ORE = &AM->getResult<OptimizationRemarkEmitterAnalysis>(Func);
998 AA = &AM->getResult<AAManager>(Func);
999 DT = &AM->getResult<DominatorTreeAnalysis>(Func);
1000 LI = &AM->getResult<LoopAnalysis>(Func);
1012 dbgs() << "Dump after matrix transpose optimization:\n";
1023 // fusion (currently only matrix multiplies).
1076 // having to update as many def-use and use-def chains.
1086 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1089 U.set(PoisonValue::get(Inst->getType()));
1091 Inst->eraseFromParent();
1107 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
1110 switch (Inst->getCalledFunction()->getIntrinsicID()) {
1132 /// non-ConstantInt strides, return the common alignment of the initial
1143 ConstStride->getZExtValue() * ElementSizeInBits / 8;
1149 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1154 Type *EltTy = VType->getElementType();
1160 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
1172 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1196 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1201 /// Lowers llvm.matrix.column.major.load.
1203 /// The intrinsic loads a matrix from memory using a stride between columns.
1206 "Intrinsic only supports column-major layout!");
1207 Value *Ptr = Inst->getArgOperand(0);
1208 Value *Stride = Inst->getArgOperand(1);
1209 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1210 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1211 {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1214 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1230 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1240 Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1242 Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1245 VType->getElementType(),
1254 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1257 auto StoreVal = getMatrix(Matrix, Shape, Builder);
1259 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1264 /// Lowers llvm.matrix.column.major.store.
1266 /// The intrinsic stores a matrix back to memory using a stride between columns.
1269 "Intrinsic only supports column-major layout!");
1270 Value *Matrix = Inst->getArgOperand(0);
1271 Value *Ptr = Inst->getArgOperand(1);
1272 Value *Stride = Inst->getArgOperand(2);
1273 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1274 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1275 {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1278 // Set elements I..I+NumElts-1 to Block
1284 cast<FixedVectorType>(Block->getType())->getNumElements();
1285 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1289 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1299 cast<FixedVectorType>(Col->getType())->getNumElements();
1301 Mask.push_back(i - I + VecNumElts);
1312 NumComputeOps += getNumOps(A->getType());
1320 return Builder.CreateIntrinsic(Intrinsic::fmuladd, A->getType(),
1323 NumComputeOps += getNumOps(A->getType());
1328 NumComputeOps += getNumOps(A->getType());
1333 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1335 /// cached value when they are lowered. For other users, \p Matrix is
1338 void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1340 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1342 assert(inserted.second && "multiple matrix lowering mapping");
1346 for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1349 Flattened = Matrix.embedInVector(Builder);
1355 /// Special case for MatMul lowering. Prevents scalar loads of row-major
1364 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1365 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1370 Value *LHS = MatMul->getArgOperand(0);
1371 Value *RHS = MatMul->getArgOperand(1);
1373 Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
1374 bool IsIntVec = ElementType->isIntegerTy();
1392 // dot-product lowering.
1400 FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
1401 Type *EltTy = VecTy->getElementType();
1415 TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
1419 cast<Instruction>(Op)->getOpcode(), VecTy);
1420 return NewCost - OriginalCost;
1424 // The transpose can be skipped for the dot product lowering, roughly
1429 EmbedCost -=
1439 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
1445 // between the flattened and matrix versions.
1463 WorkList.append(I->op_begin(), I->op_end());
1471 AddOpCode, cast<VectorType>(LHS->getType()),
1473 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
1476 (LShape.NumColumns - 1) +
1479 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
1495 It->second = It->second.t();
1505 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
1506 Op->replaceAllUsesWith(NewLoad);
1512 Op->replaceAllUsesWith(Arg);
1520 LHS = MatMul->getArgOperand(0);
1531 ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
1534 cast<Instruction>(Result)->setFastMathFlags(FMF);
1537 // pack scalar back into a matrix and then replace matmul inst
1538 Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
1540 MatMul->replaceAllUsesWith(Result);
1545 /// Compute \p Result += \p A * \p B for input matrices with left-associating
1548 /// We can fold a transpose into the operand that is used to extract scalars.
1549 /// This is the first operands with row-major and the second with
1550 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
1558 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1564 bool IsFP = Result.getElementType()->isFloatingPointTy();
1567 "operands must agree on matrix layout");
1642 if (AA->isNoAlias(LoadLoc, StoreLoc))
1643 return Load->getPointerOperand();
1649 BasicBlock *Check0 = MatMul->getParent();
1655 DTUpdates.push_back({DT->Delete, Check0, Succ});
1658 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1661 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1664 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1671 Check0->getTerminator()->eraseFromParent();
1673 Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout());
1687 Check1->getTerminator()->eraseFromParent();
1688 Builder.SetInsertPoint(Check1, Check1->begin());
1696 Builder.SetInsertPoint(Copy, Copy->begin());
1697 auto *VT = cast<FixedVectorType>(Load->getType());
1700 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1702 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1704 Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
1705 Load->getAlign(), LoadLoc.Size.getValue());
1706 Builder.SetInsertPoint(Fusion, Fusion->begin());
1707 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1708 PHI->addIncoming(Load->getPointerOperand(), Check0);
1709 PHI->addIncoming(Load->getPointerOperand(), Check1);
1710 PHI->addIncoming(Alloca, Copy);
1713 DTUpdates.push_back({DT->Insert, Check0, Check1});
1714 DTUpdates.push_back({DT->Insert, Check0, Fusion});
1715 DTUpdates.push_back({DT->Insert, Check1, Copy});
1716 DTUpdates.push_back({DT->Insert, Check1, Fusion});
1717 DT->applyUpdates(DTUpdates);
1725 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1726 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1731 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1736 EltType->getPrimitiveSizeInBits().getFixedValue(),
1751 unsigned Op0Regs = (R + VF - 1) / VF * M;
1752 unsigned Op1Regs = (M + VF - 1) / VF * C;
1767 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1773 BasicBlock *Start = InsertI->getParent();
1775 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1780 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1783 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1788 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1789 TI.RowLoop.Header->getSingleSuccessor());
1796 Builder.SetInsertPoint(InnerBody->getTerminator());
1807 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1808 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1809 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1813 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1818 // currently the cost-model is not up to the task.
1820 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1828 "Tiling only supported for column-major matrixes at the moment!");
1832 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1833 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1838 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1842 Value *CPtr = Store->getPointerOperand();
1850 const unsigned TileR = std::min(R - I, unsigned(TileSize));
1851 const unsigned TileC = std::min(C - J, unsigned(TileSize));
1855 const unsigned TileM = std::min(M - K, unsigned(TileSize));
1857 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
1861 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
1867 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1878 if (LoadOp0->hasNUses(0)) {
1882 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1888 /// Try to lower matrix multiply chains by fusing operations.
1901 Value *A = MatMul->getArgOperand(0);
1902 Value *B = MatMul->getArgOperand(1);
1904 // We can fold the transpose into the operand that is used to fetch scalars.
1910 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1911 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1912 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1920 Value *Transpose;
1924 Transpose = B;
1928 Transpose = A;
1938 if (Transpose->hasOneUse()) {
1939 FusedInsts.insert(cast<Instruction>(Transpose));
1940 ToRemove.push_back(cast<Instruction>(Transpose));
1943 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1949 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1952 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1956 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1961 WorkList.insert(Store->getOperand(1));
1970 if (DT->dominates(CurrI, MatMul))
1972 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1975 WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1979 return DT->dominates(A, B);
1982 I->moveBefore(MatMul->getIterator());
1993 BasicBlock *StoreParent = Store->getParent();
1994 bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
1995 LoadOp1->getParent() == StoreParent;
2001 if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
2003 if (DT->dominates(Store, End))
2007 if (FusableOpsInSameBlock && End->getParent() != StoreParent)
2015 if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
2021 if (End->getParent() == StoreParent) {
2022 End->moveAfter(Store);
2038 /// Lowers llvm.matrix.multiply.
2041 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
2042 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
2043 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
2045 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
2046 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
2048 "Matrix multiply argument element types do not match.");
2057 "Matrix multiply result element type does not match arguments.");
2064 /// Lowers llvm.matrix.transpose.
2068 Value *InputVal = Inst->getArgOperand(0);
2069 VectorType *VectorTy = cast<VectorType>(InputVal->getType());
2070 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
2081 FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
2108 LowerLoad(Inst, Ptr, Inst->getAlign(),
2109 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2110 I->second);
2120 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2121 Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
2122 I->second);
2132 Value *Lhs = Inst->getOperand(0);
2133 Value *Rhs = Inst->getOperand(1);
2136 ShapeInfo &Shape = I->second;
2143 "operands must agree on matrix layout");
2149 switch (Inst->getOpcode()) {
2163 llvm_unreachable("Unsupported binary operator for matrix");
2183 Value *Op = Inst->getOperand(0);
2186 ShapeInfo &Shape = I->second;
2195 switch (Inst->getOpcode()) {
2199 llvm_unreachable("Unsupported unary operator for matrix");
2213 /// Helper to linearize a matrix expression tree into a string. Currently
2214 /// matrix expressions are linearized by starting at an expression leaf and
2224 /// matrix instructions.
2231 /// Set of matrix expressions in the scope of a given DISubprogram.
2237 /// Used to keep track of sub-expressions that get reused while linearizing
2238 /// the expression. Re-used sub-expressions are marked as (reused).
2276 else if (V->getType()->isPointerTy())
2281 /// Returns true if \p V is a matrix value in the given subprogram.
2284 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2291 SS << M->second.getNumRows();
2293 SS << M->second.getNumColumns();
2297 /// Write the called function name. Handles calls to llvm.matrix.*
2301 if (!CI->getCalledFunction())
2304 StringRef Name = CI->getCalledFunction()->getName();
2305 if (!Name.starts_with("llvm.matrix")) {
2310 write(Intrinsic::getBaseName(II->getIntrinsicID())
2311 .drop_front(StringRef("llvm.matrix.").size()));
2316 switch (II->getIntrinsicID()) {
2318 prettyPrintMatrixType(II->getOperand(0), SS);
2320 prettyPrintMatrixType(II->getOperand(1), SS);
2321 SS << "." << *II->getType()->getScalarType();
2324 prettyPrintMatrixType(II->getOperand(0), SS);
2325 SS << "." << *II->getType()->getScalarType();
2329 SS << "." << *II->getType()->getScalarType();
2332 prettyPrintMatrixType(II->getOperand(0), SS);
2333 SS << "." << *II->getOperand(0)->getType()->getScalarType();
2344 switch (II->getIntrinsicID()) {
2361 /// either print the constant or "scalar"/"matrix" for other values.
2364 if (V->getType()->isPointerTy()) {
2372 if (!V->getName().empty()) {
2373 Stream << " %" << V->getName() << "";
2374 LineLength += V->getName().size() + 2;
2383 TmpStream << CI->getValue();
2388 TmpStream << "matrix";
2398 /// Expressions that are re-used multiple times are prefixed with (reused)
2399 /// at the re-used root instruction.
2412 assert(SI != Shared.end() && SI->second.count(Leaf));
2414 for (Value *S : SI->second) {
2417 DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2421 ExprShared = SI->second.size() > 1;
2431 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2434 // non-matrix ops.
2435 write("matrix");
2438 Ops.append(I->value_op_begin(), I->value_op_end());
2439 write(std::string(I->getOpcodeName()));
2469 /// Generate remarks for matrix operations in a function. To generate remarks
2470 /// for matrix expressions, the following approach is used:
2471 /// 1. Use the inlined-at debug information to group matrix operations to the
2473 /// 2. Collect leaves of matrix expressions (done in
2474 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2475 // mapping. Leaves are lowered matrix instructions without other matrix
2478 /// matrix expression. The expression is linearized by a recursive
2479 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2480 /// that multiple leaves can share sub-expressions. Shared subexpressions
2500 if (Expr->getType()->isVoidTy() ||
2501 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2509 /// to all visited expressions in \p Shared. Limit the matrix operations to
2520 for (Value *Op : cast<Instruction>(V)->operand_values())
2526 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2543 if (I->second.size() == 1)
2544 Count = CM->second.getOpInfo();
2546 SharedCount = CM->second.getOpInfo();
2548 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2560 // Map matrix operations to their containing subprograms, by traversing
2567 DILocation *Context = I->getDebugLoc();
2569 Subprog2Exprs[getSubprogram(Context->getScope())].push_back(
2589 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2590 DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2592 if (getSubprogram(Context->getScope()) == KV.first) {
2604 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2605 cast<Instruction>(L)->getParent());
2662 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(