xref: /llvm-project/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1edd9515bSthomasraoux //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2edd9515bSthomasraoux //
3edd9515bSthomasraoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4edd9515bSthomasraoux // See https://llvm.org/LICENSE.txt for license information.
5edd9515bSthomasraoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6edd9515bSthomasraoux //
7edd9515bSthomasraoux //===----------------------------------------------------------------------===//
8edd9515bSthomasraoux //
9edd9515bSthomasraoux // This file implements lowering of vector operations to GPU dialect ops.
10edd9515bSthomasraoux //
11edd9515bSthomasraoux //===----------------------------------------------------------------------===//
12edd9515bSthomasraoux 
13edd9515bSthomasraoux #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14edd9515bSthomasraoux 
15ea2ed80eSChristopher Bate #include <type_traits>
16ea2ed80eSChristopher Bate 
17edd9515bSthomasraoux #include "mlir/Analysis/SliceAnalysis.h"
18b00e0c16SChristian Ulmann #include "mlir/Analysis/TopologicalSortUtils.h"
19ea2ed80eSChristopher Bate #include "mlir/Dialect/Affine/IR/AffineOps.h"
20abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
21d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
2266f878ceSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
2351b925dfSChristopher Bate #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
24ea2ed80eSChristopher Bate #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
258b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
26edd9515bSthomasraoux #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2799ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
28fb7ef637SJakub Kuderski #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
2999ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
30edd9515bSthomasraoux #include "mlir/IR/Builders.h"
315ef7ceaeSNicolas Vasilache #include "mlir/IR/BuiltinOps.h"
325ef7ceaeSNicolas Vasilache #include "mlir/IR/Region.h"
33edd9515bSthomasraoux #include "mlir/Pass/Pass.h"
34edd9515bSthomasraoux #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35edd9515bSthomasraoux #include "mlir/Transforms/Passes.h"
365ef7ceaeSNicolas Vasilache #include "llvm/ADT/STLExtras.h"
371ca772edSChristopher Bate #include "llvm/ADT/TypeSwitch.h"
38edd9515bSthomasraoux 
395ef7ceaeSNicolas Vasilache #define DEBUG_TYPE "vector-to-gpu"
405ef7ceaeSNicolas Vasilache #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
415ef7ceaeSNicolas Vasilache #define DBGSNL() (llvm::dbgs() << "\n")
425ef7ceaeSNicolas Vasilache 
4367d0d7acSMichele Scuttari namespace mlir {
4467d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTVECTORTOGPU
4567d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
4667d0d7acSMichele Scuttari } // namespace mlir
4767d0d7acSMichele Scuttari 
48edd9515bSthomasraoux using namespace mlir;
49edd9515bSthomasraoux 
501ca772edSChristopher Bate /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
511ca772edSChristopher Bate /// AffineMap representing offsets to apply to indices, the function fills
521ca772edSChristopher Bate /// `indices` with the original indices plus the offsets. The offsets are
531ca772edSChristopher Bate /// applied by taking into account the permutation map of the transfer op. If
541ca772edSChristopher Bate /// the `offsetMap` has dimension placeholders, those should be provided in
551ca772edSChristopher Bate /// `dimValues`.
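/// For illustration (not from this file; names are hypothetical): with
/// `offsetMap = (d0) -> (d0 floordiv 4, d0 mod 4)` and `dimValues = {%laneId}`,
/// the two indices selected by the permutation map become
/// `prevIdx + %laneId floordiv 4` and `prevIdx + %laneId mod 4`, materialized
/// as affine.apply ops.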
561ca772edSChristopher Bate template <typename TransferOpType>
575ef7ceaeSNicolas Vasilache static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
581ca772edSChristopher Bate                            AffineMap offsetMap, ArrayRef<Value> dimValues,
591ca772edSChristopher Bate                            SmallVector<Value, 4> &indices) {
601ca772edSChristopher Bate   indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
611ca772edSChristopher Bate   Location loc = xferOp.getLoc();
621ca772edSChristopher Bate   unsigned offsetsIdx = 0;
631ca772edSChristopher Bate   for (auto expr : xferOp.getPermutationMap().getResults()) {
641609f1c2Slong.chen     if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
651ca772edSChristopher Bate       Value prevIdx = indices[dim.getPosition()];
665262865aSKazu Hirata       SmallVector<OpFoldResult, 3> dims(dimValues);
671ca772edSChristopher Bate       dims.push_back(prevIdx);
685ef7ceaeSNicolas Vasilache       AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
694c48f016SMatthias Springer       indices[dim.getPosition()] = affine::makeComposedAffineApply(
705ef7ceaeSNicolas Vasilache           rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
711ca772edSChristopher Bate       continue;
721ca772edSChristopher Bate     }
731ca772edSChristopher Bate   }
741ca772edSChristopher Bate }
751ca772edSChristopher Bate 
76edd9515bSthomasraoux // Return true if the contract op can be converted to an MMA matmul.
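// For illustration (not from this file), with dims (d0, d1, d2) = (m, n, k):
//   - the wmma path (useNvGpu = false) expects the indexing maps
//     [(m, k), (k, n), (m, n)], i.e. row-major A times row-major B;
//   - the mma.sync path (useNvGpu = true) expects [(m, k), (n, k), (m, n)],
//     i.e. the B operand given in column-major form.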
771ca772edSChristopher Bate static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
781ca772edSChristopher Bate                                           bool useNvGpu) {
79edd9515bSthomasraoux   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80fe8a62c4SUday Bondhugula   auto infer = [&](MapList m) {
81fe8a62c4SUday Bondhugula     return AffineMap::inferFromExprList(m, contract.getContext());
82fe8a62c4SUday Bondhugula   };
83edd9515bSthomasraoux   AffineExpr m, n, k;
84edd9515bSthomasraoux   bindDims(contract.getContext(), m, n, k);
857c38fd60SJacques Pienaar   auto iteratorTypes = contract.getIteratorTypes().getValue();
864758e916SOleg Shyshkov   if (!(vector::isParallelIterator(iteratorTypes[0]) &&
874758e916SOleg Shyshkov         vector::isParallelIterator(iteratorTypes[1]) &&
884758e916SOleg Shyshkov         vector::isReductionIterator(iteratorTypes[2])))
89edd9515bSthomasraoux     return false;
90edd9515bSthomasraoux 
91edd9515bSthomasraoux   // The contract needs to represent a matmul to be convertible to an
92edd9515bSthomasraoux   // MMAMatrix matmul.
931ca772edSChristopher Bate   if (!useNvGpu &&
94d2c0572bSJacques Pienaar       contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
951ca772edSChristopher Bate     return false;
96d2c0572bSJacques Pienaar   if (useNvGpu &&
97d2c0572bSJacques Pienaar       contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
98edd9515bSthomasraoux     return false;
99edd9515bSthomasraoux 
100edd9515bSthomasraoux   return true;
101edd9515bSthomasraoux }
102edd9515bSthomasraoux 
103c0321edcSQuinn Dawkins // Return true if the given map represents a transposed matrix load,
104c0321edcSQuinn Dawkins // i.e. (d0, d1, ...) -> (dn-1, dn-2).
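// For example (illustrative), with 3 input dims the accepted maps are
//   affine_map<(d0, d1, d2) -> (d2, d1)>  (transposed) and
//   affine_map<(d0, d1, d2) -> (d2, 0)>   (transposed + broadcasted);
// the 1-D case affine_map<(d0) -> (d0, 0)> is also accepted.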
1055ef7ceaeSNicolas Vasilache static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
1065ef7ceaeSNicolas Vasilache   MLIRContext *ctx = permutationMap.getContext();
1075ef7ceaeSNicolas Vasilache   // Local OpBuilder is fine here, we just build attributes.
1085ef7ceaeSNicolas Vasilache   OpBuilder b(ctx);
109c0321edcSQuinn Dawkins   auto nDim = permutationMap.getNumDims();
110f1db4aecSLei Zhang   AffineExpr zero = b.getAffineConstantExpr(0);
111dbddd4f6SLei Zhang   if (nDim < 2) {
112dbddd4f6SLei Zhang     // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
113dbddd4f6SLei Zhang     AffineExpr dim0 = b.getAffineDimExpr(0);
114f1db4aecSLei Zhang     return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
115dbddd4f6SLei Zhang   }
116c0321edcSQuinn Dawkins 
117c0321edcSQuinn Dawkins   AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
118c0321edcSQuinn Dawkins   AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
119f1db4aecSLei Zhang   // Support both transposed and transposed+broadcasted cases.
120f1db4aecSLei Zhang   return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
121f1db4aecSLei Zhang          permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
122c0321edcSQuinn Dawkins }
123c0321edcSQuinn Dawkins 
124cafb6284SChristopher Bate // Return the stride for the second-to-last dimension of |type| if it is a memref
125cafb6284SChristopher Bate // and has a constant stride.
126cafb6284SChristopher Bate static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
1275550c821STres Popp   auto memrefType = dyn_cast<MemRefType>(type);
128edd9515bSthomasraoux   if (!memrefType)
129edd9515bSthomasraoux     return std::nullopt;
130a57ccad5SThomas Raoux   // If the memref is 0-D or 1-D, the horizontal stride is 0.
131a57ccad5SThomas Raoux   if (memrefType.getRank() < 2)
132a57ccad5SThomas Raoux     return 0;
133edd9515bSthomasraoux   int64_t offset = 0;
134edd9515bSthomasraoux   SmallVector<int64_t, 2> strides;
135*6aaa8f25SMatthias Springer   if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
136d77f4836SThomas Raoux       strides.back() != 1)
1371a36588eSKazu Hirata     return std::nullopt;
138a57ccad5SThomas Raoux   int64_t stride = strides[strides.size() - 2];
139399638f9SAliia Khasanova   if (stride == ShapedType::kDynamic)
1401a36588eSKazu Hirata     return std::nullopt;
141a57ccad5SThomas Raoux   return stride;
142edd9515bSthomasraoux }
143edd9515bSthomasraoux 
144edd9515bSthomasraoux // Return true if the transfer op can be converted to an MMA matrix load.
145cafb6284SChristopher Bate static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
1467c38fd60SJacques Pienaar   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
147edd9515bSthomasraoux       readOp.getVectorType().getRank() != 2)
148edd9515bSthomasraoux     return false;
149cafb6284SChristopher Bate   if (!getStaticallyKnownRowStride(readOp.getShapedType()))
150edd9515bSthomasraoux     return false;
151985f7ff6SQuinn Dawkins 
152985f7ff6SQuinn Dawkins   // Only allow integer types if the signedness can be inferred.
153cafb6284SChristopher Bate   if (readOp.getVectorType().getElementType().isInteger(8))
1545205c712SQuinn Dawkins     if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
1555205c712SQuinn Dawkins                                  !isa<arith::ExtUIOp>(*readOp->user_begin())))
156985f7ff6SQuinn Dawkins       return false;
157985f7ff6SQuinn Dawkins 
1587c38fd60SJacques Pienaar   AffineMap map = readOp.getPermutationMap();
1595ef7ceaeSNicolas Vasilache   MLIRContext *ctx = readOp.getContext();
1605ef7ceaeSNicolas Vasilache   AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
1615ef7ceaeSNicolas Vasilache   AffineExpr zero = getAffineConstantExpr(0, ctx);
1625ef7ceaeSNicolas Vasilache   auto broadcastInnerDim =
1635ef7ceaeSNicolas Vasilache       AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
164cafb6284SChristopher Bate   return map.isMinorIdentity() || map == broadcastInnerDim ||
1655ef7ceaeSNicolas Vasilache          isTransposeMatrixLoadMap(map);
166edd9515bSthomasraoux }
167edd9515bSthomasraoux 
168edd9515bSthomasraoux // Return true if the transfer op can be converted to an MMA matrix store.
169edd9515bSthomasraoux static bool
170edd9515bSthomasraoux transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
171c537a943SNicolas Vasilache   // TODO: support 0-d corner case.
172c537a943SNicolas Vasilache   if (writeOp.getTransferRank() == 0)
173c537a943SNicolas Vasilache     return false;
174c537a943SNicolas Vasilache 
1757c38fd60SJacques Pienaar   if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
176edd9515bSthomasraoux       writeOp.getVectorType().getRank() != 2)
177edd9515bSthomasraoux     return false;
178cafb6284SChristopher Bate   if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
179edd9515bSthomasraoux     return false;
180edd9515bSthomasraoux   // TODO: Support transpose once it is added to GPU dialect ops.
1817c38fd60SJacques Pienaar   if (!writeOp.getPermutationMap().isMinorIdentity())
182edd9515bSthomasraoux     return false;
183edd9515bSthomasraoux   return true;
184edd9515bSthomasraoux }
185edd9515bSthomasraoux 
1866413226dSthomasraoux /// Return true if the constant is a 2D vector splat so that it can be
1876413226dSthomasraoux /// converted to an MMA constant matrix op.
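/// For example (illustrative), `arith.constant dense<0.0> : vector<16x16xf16>`
/// qualifies, while a rank-1 constant or a non-splat dense constant does not.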
188a54f4eaeSMogball static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
1895550c821STres Popp   auto vecType = dyn_cast<VectorType>(constantOp.getType());
1906413226dSthomasraoux   if (!vecType || vecType.getRank() != 2)
1916413226dSthomasraoux     return false;
1925550c821STres Popp   return isa<SplatElementsAttr>(constantOp.getValue());
1936413226dSthomasraoux }
1946413226dSthomasraoux 
19543928419Sthomasraoux /// Return true if this is a broadcast from scalar to a 2D vector.
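/// For example (illustrative), `vector.broadcast %s : f16 to vector<16x16xf16>`
/// qualifies because the result vector is 2-D.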
19643928419Sthomasraoux static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
197a1aad28dSLei Zhang   return broadcastOp.getResultVectorType().getRank() == 2;
198985f7ff6SQuinn Dawkins }
199985f7ff6SQuinn Dawkins 
2005205c712SQuinn Dawkins /// Return true if this integer extend op can be folded into a contract op.
2015205c712SQuinn Dawkins template <typename ExtOpTy>
2025205c712SQuinn Dawkins static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
203927559d2SLongsheng Mou   auto transferReadOp =
204927559d2SLongsheng Mou       extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
205927559d2SLongsheng Mou   if (!transferReadOp)
206985f7ff6SQuinn Dawkins     return false;
207971b8525SJakub Kuderski   return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
20843928419Sthomasraoux }
20943928419Sthomasraoux 
210a0119437SLei Zhang static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
211a0119437SLei Zhang 
2127fbb0678Sthomasraoux /// Return the MMA elementwise enum associated with `op` if it is supported.
213192d9dd7SKazu Hirata /// Return `std::nullopt` otherwise.
214d32ec523SRamkumar Ramachandra static std::optional<gpu::MMAElementwiseOp>
2157fbb0678Sthomasraoux convertElementwiseOpToMMA(Operation *op) {
2167fbb0678Sthomasraoux   if (isa<arith::AddFOp>(op))
2177fbb0678Sthomasraoux     return gpu::MMAElementwiseOp::ADDF;
2187fbb0678Sthomasraoux   if (isa<arith::MulFOp>(op))
2197fbb0678Sthomasraoux     return gpu::MMAElementwiseOp::MULF;
22050882b4dSLei Zhang   if (isa<arith::SubFOp>(op))
22150882b4dSLei Zhang     return gpu::MMAElementwiseOp::SUBF;
2228a6e54c9SDaniil Dudkin   if (isa<arith::MaximumFOp>(op))
2237fbb0678Sthomasraoux     return gpu::MMAElementwiseOp::MAXF;
2248a6e54c9SDaniil Dudkin   if (isa<arith::MinimumFOp>(op))
2257fbb0678Sthomasraoux     return gpu::MMAElementwiseOp::MINF;
226e7969240SThomas Raoux   if (isa<arith::DivFOp>(op))
227e7969240SThomas Raoux     return gpu::MMAElementwiseOp::DIVF;
22850882b4dSLei Zhang   if (isa<arith::AddIOp>(op))
22950882b4dSLei Zhang     return gpu::MMAElementwiseOp::ADDI;
23050882b4dSLei Zhang   if (isa<arith::MulIOp>(op))
23150882b4dSLei Zhang     return gpu::MMAElementwiseOp::MULI;
23250882b4dSLei Zhang   if (isa<arith::SubIOp>(op))
23350882b4dSLei Zhang     return gpu::MMAElementwiseOp::SUBI;
23450882b4dSLei Zhang   if (isa<arith::DivSIOp>(op))
23550882b4dSLei Zhang     return gpu::MMAElementwiseOp::DIVS;
23650882b4dSLei Zhang   if (isa<arith::DivUIOp>(op))
23750882b4dSLei Zhang     return gpu::MMAElementwiseOp::DIVU;
23850882b4dSLei Zhang   if (isa<arith::NegFOp>(op))
23950882b4dSLei Zhang     return gpu::MMAElementwiseOp::NEGATEF;
240a0119437SLei Zhang   if (isa<arith::ExtFOp>(op))
241a0119437SLei Zhang     return gpu::MMAElementwiseOp::EXTF;
2421a36588eSKazu Hirata   return std::nullopt;
2437fbb0678Sthomasraoux }
2447fbb0678Sthomasraoux 
2457fbb0678Sthomasraoux /// Return true if the op is supported as elementwise op on MMAMatrix type.
2467fbb0678Sthomasraoux static bool elementwiseSupportsMMAMatrixType(Operation *op) {
247064a08cdSKazu Hirata   return convertElementwiseOpToMMA(op).has_value();
2487fbb0678Sthomasraoux }
2497fbb0678Sthomasraoux 
250114ba722SManish Gupta /// Returns true if the extract strided slice op is supported with `mma.sync`
251114ba722SManish Gupta /// path.
252114ba722SManish Gupta static bool
253114ba722SManish Gupta extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
254114ba722SManish Gupta 
255114ba722SManish Gupta   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
256114ba722SManish Gupta       nvgpu::getWarpMatrixInfo(op);
257114ba722SManish Gupta   if (failed(warpMatrixInfo))
258114ba722SManish Gupta     return false;
259114ba722SManish Gupta 
260114ba722SManish Gupta   FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
261114ba722SManish Gupta   if (failed(contractOp))
262114ba722SManish Gupta     return false;
263114ba722SManish Gupta 
2643af64383SNicolas Vasilache   // Handle vector.extract_strided_slice on registers containing
2653af64383SNicolas Vasilache   // matrixB and matrixC operands. vector.extract_strided_slice op
2663af64383SNicolas Vasilache   // is not supported on registers containing matrixA operands.
267114ba722SManish Gupta   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
2685550c821STres Popp     return (cast<VectorType>(op->getResult(0).getType()) ==
2695550c821STres Popp             cast<VectorType>((*contractOp).getRhs().getType()));
2706a7a1188SMehdi Amini   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
2715550c821STres Popp     return (cast<VectorType>(op->getResult(0).getType()) ==
2725550c821STres Popp             cast<VectorType>((*contractOp).getAcc().getType()));
273114ba722SManish Gupta 
274114ba722SManish Gupta   return false;
275114ba722SManish Gupta }
276114ba722SManish Gupta 
2771ca772edSChristopher Bate static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
2781a865592Sthomasraoux   if (isa<scf::ForOp, scf::YieldOp>(op))
2791a865592Sthomasraoux     return true;
280edd9515bSthomasraoux   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
281cafb6284SChristopher Bate     return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
282cafb6284SChristopher Bate                     : transferReadSupportsMMAMatrixType(transferRead);
283edd9515bSthomasraoux   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
284cafb6284SChristopher Bate     return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
285cafb6284SChristopher Bate                     : transferWriteSupportsMMAMatrixType(transferWrite);
286114ba722SManish Gupta   if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
287114ba722SManish Gupta     return useNvGpu &&
288114ba722SManish Gupta            extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
289edd9515bSthomasraoux   if (auto contract = dyn_cast<vector::ContractionOp>(op))
2901ca772edSChristopher Bate     return contractSupportsMMAMatrixType(contract, useNvGpu);
291a54f4eaeSMogball   if (auto constant = dyn_cast<arith::ConstantOp>(op))
2926413226dSthomasraoux     return constantSupportsMMAMatrixType(constant);
2933af64383SNicolas Vasilache   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
2943af64383SNicolas Vasilache     return broadcastSupportsMMAMatrixType(broadcast);
2955205c712SQuinn Dawkins   if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
2965205c712SQuinn Dawkins     return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
2975205c712SQuinn Dawkins   if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
2985205c712SQuinn Dawkins     return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
299a0119437SLei Zhang   if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
300a0119437SLei Zhang     return fpExtendSupportsMMAMatrixType(fpExtend);
3017fbb0678Sthomasraoux   return elementwiseSupportsMMAMatrixType(op);
302edd9515bSthomasraoux }
303edd9515bSthomasraoux 
304e7969240SThomas Raoux /// Return an unsorted slice, handling the scf.for region differently from
305e7969240SThomas Raoux /// `getSlice`. For scf.for we only want to include, as part of the slice, the
306e7969240SThomas Raoux /// elements that are part of the use/def chain.
307641b12e9SMahesh Ravishankar static SetVector<Operation *>
308847d8457SMehdi Amini getSliceContract(Operation *op,
309847d8457SMehdi Amini                  const BackwardSliceOptions &backwardSliceOptions,
310847d8457SMehdi Amini                  const ForwardSliceOptions &forwardSliceOptions) {
311e7969240SThomas Raoux   SetVector<Operation *> slice;
312e7969240SThomas Raoux   slice.insert(op);
313e7969240SThomas Raoux   unsigned currentIndex = 0;
314e7969240SThomas Raoux   SetVector<Operation *> backwardSlice;
315e7969240SThomas Raoux   SetVector<Operation *> forwardSlice;
316e7969240SThomas Raoux   while (currentIndex != slice.size()) {
317e7969240SThomas Raoux     auto *currentOp = (slice)[currentIndex];
3183af64383SNicolas Vasilache     // Compute and insert the backwardSlice starting from currentOp.
319e7969240SThomas Raoux     backwardSlice.clear();
320641b12e9SMahesh Ravishankar     getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
321e7969240SThomas Raoux     slice.insert(backwardSlice.begin(), backwardSlice.end());
322e7969240SThomas Raoux 
3233af64383SNicolas Vasilache     // Compute and insert the forwardSlice starting from currentOp.
324e7969240SThomas Raoux     forwardSlice.clear();
3253af64383SNicolas Vasilache     // Special case for ForOp: we don't want to include the whole region, but
3263af64383SNicolas Vasilache     // only the values using the region arguments.
3273af64383SNicolas Vasilache     // TODO: We should refine this to only care about the region arguments being
3283af64383SNicolas Vasilache     // converted to matrix type.
329e7969240SThomas Raoux     if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
330e7969240SThomas Raoux       for (Value forOpResult : forOp.getResults())
331641b12e9SMahesh Ravishankar         getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
332e7969240SThomas Raoux       for (BlockArgument &arg : forOp.getRegionIterArgs())
333641b12e9SMahesh Ravishankar         getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
334e7969240SThomas Raoux     } else {
335641b12e9SMahesh Ravishankar       getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
336e7969240SThomas Raoux     }
337e7969240SThomas Raoux     slice.insert(forwardSlice.begin(), forwardSlice.end());
338e7969240SThomas Raoux     ++currentIndex;
339e7969240SThomas Raoux   }
340e7969240SThomas Raoux   return slice;
341e7969240SThomas Raoux }
342e7969240SThomas Raoux 
343edd9515bSthomasraoux // Analyze the slice of operations anchored on each vector.contract op to figure
344edd9515bSthomasraoux // out whether the whole slice can be converted to MMA operations.
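// For illustration (not from this file; names are hypothetical), a typical
// convertible slice anchored on a vector.contract looks like:
//   %a = vector.transfer_read %lhsMem[...] : memref<...>, vector<16x16xf16>
//   %b = vector.transfer_read %rhsMem[...] : memref<...>, vector<16x16xf16>
//   %c = vector.transfer_read %accMem[...] : memref<...>, vector<16x16xf16>
//   %d = vector.contract ... %a, %b, %c ... : ... into vector<16x16xf16>
//   vector.transfer_write %d, %accMem[...] : vector<16x16xf16>, memref<...>
// If any op in the backward/forward slice cannot be mapped to the MMA matrix
// type, the whole slice is skipped.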
3451ca772edSChristopher Bate static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
3461ca772edSChristopher Bate                                              bool useNvGpu) {
347edd9515bSthomasraoux   auto hasVectorDest = [](Operation *op) {
348971b8525SJakub Kuderski     return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
34943928419Sthomasraoux   };
350641b12e9SMahesh Ravishankar   BackwardSliceOptions backwardSliceOptions;
351641b12e9SMahesh Ravishankar   backwardSliceOptions.filter = hasVectorDest;
352641b12e9SMahesh Ravishankar 
35343928419Sthomasraoux   auto hasVectorSrc = [](Operation *op) {
354971b8525SJakub Kuderski     return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
355edd9515bSthomasraoux   };
356641b12e9SMahesh Ravishankar   ForwardSliceOptions forwardSliceOptions;
357641b12e9SMahesh Ravishankar   forwardSliceOptions.filter = hasVectorSrc;
358641b12e9SMahesh Ravishankar 
359edd9515bSthomasraoux   SetVector<Operation *> opToConvert;
360edd9515bSthomasraoux   op->walk([&](vector::ContractionOp contract) {
361edd9515bSthomasraoux     if (opToConvert.contains(contract.getOperation()))
362edd9515bSthomasraoux       return;
363edd9515bSthomasraoux     SetVector<Operation *> dependentOps =
364641b12e9SMahesh Ravishankar         getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
3653af64383SNicolas Vasilache     // If any instruction cannot use the MMA matrix type, drop the whole
3663af64383SNicolas Vasilache     // chain. MMA matrices are stored in an opaque type, so they cannot be used
3673af64383SNicolas Vasilache     // by all operations.
3681ca772edSChristopher Bate     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
369cafb6284SChristopher Bate           if (!supportsMMaMatrixType(op, useNvGpu)) {
370cafb6284SChristopher Bate             LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
371cafb6284SChristopher Bate             return true;
372cafb6284SChristopher Bate           }
373cafb6284SChristopher Bate           return false;
3741ca772edSChristopher Bate         }))
375edd9515bSthomasraoux       return;
376cafb6284SChristopher Bate 
377edd9515bSthomasraoux     opToConvert.insert(dependentOps.begin(), dependentOps.end());
378edd9515bSthomasraoux   });
3793af64383SNicolas Vasilache   // Sort the operations so that we can convert them in topological order.
380e7969240SThomas Raoux   return topologicalSort(opToConvert);
381edd9515bSthomasraoux }
382edd9515bSthomasraoux 
383edd9515bSthomasraoux namespace {
384edd9515bSthomasraoux // Transform a contract op into (m, k) x (k, n) x (m, n) form so that it can be
385edd9515bSthomasraoux // converted to an MMA matmul.
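// For illustration (not from this file), a contraction whose indexing maps are
// [(k, m), (k, n), (m, n)] is prepared by transposing the lhs with
// `vector.transpose %lhs, [1, 0]` and rebuilding the contract with the
// canonical [(m, k), (k, n), (m, n)] maps.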
386edd9515bSthomasraoux struct PrepareContractToGPUMMA
387edd9515bSthomasraoux     : public OpRewritePattern<vector::ContractionOp> {
388edd9515bSthomasraoux   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
389edd9515bSthomasraoux 
390edd9515bSthomasraoux   LogicalResult matchAndRewrite(vector::ContractionOp op,
391edd9515bSthomasraoux                                 PatternRewriter &rewriter) const override {
392edd9515bSthomasraoux     Location loc = op.getLoc();
3937c38fd60SJacques Pienaar     Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
394edd9515bSthomasraoux 
395edd9515bSthomasraoux     // Set up the parallel/reduction structure in right form.
396edd9515bSthomasraoux     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
397fe8a62c4SUday Bondhugula     auto infer = [&](MapList m) {
398fe8a62c4SUday Bondhugula       return AffineMap::inferFromExprList(m, op.getContext());
399fe8a62c4SUday Bondhugula     };
400edd9515bSthomasraoux     AffineExpr m, n, k;
401edd9515bSthomasraoux     bindDims(rewriter.getContext(), m, n, k);
402edd9515bSthomasraoux     static constexpr std::array<int64_t, 2> perm = {1, 0};
4037c38fd60SJacques Pienaar     auto iteratorTypes = op.getIteratorTypes().getValue();
404d2c0572bSJacques Pienaar     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
4054758e916SOleg Shyshkov     if (!(vector::isParallelIterator(iteratorTypes[0]) &&
4064758e916SOleg Shyshkov           vector::isParallelIterator(iteratorTypes[1]) &&
4074758e916SOleg Shyshkov           vector::isReductionIterator(iteratorTypes[2])))
4085ef7ceaeSNicolas Vasilache       return rewriter.notifyMatchFailure(op, "not a gemm contraction");
409edd9515bSthomasraoux     //
410edd9515bSthomasraoux     // Two outer parallel, one inner reduction (matmat flavor).
411edd9515bSthomasraoux     //
412edd9515bSthomasraoux     // This is the classical row-major matmul, nothing to do.
4135ef7ceaeSNicolas Vasilache     if (maps == infer({{m, k}, {k, n}, {m, n}}))
4145ef7ceaeSNicolas Vasilache       return rewriter.notifyMatchFailure(op, "contraction already prepared");
415edd9515bSthomasraoux     if (maps == infer({{m, k}, {n, k}, {m, n}})) {
416edd9515bSthomasraoux       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
417edd9515bSthomasraoux     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
418edd9515bSthomasraoux       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
419edd9515bSthomasraoux     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
420edd9515bSthomasraoux       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
421edd9515bSthomasraoux       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
422edd9515bSthomasraoux     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
423edd9515bSthomasraoux       std::swap(rhs, lhs);
424edd9515bSthomasraoux       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
425edd9515bSthomasraoux       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
426edd9515bSthomasraoux     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
427edd9515bSthomasraoux       std::swap(rhs, lhs);
428edd9515bSthomasraoux       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
429edd9515bSthomasraoux     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
430edd9515bSthomasraoux       std::swap(lhs, rhs);
431edd9515bSthomasraoux       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
432edd9515bSthomasraoux     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
433edd9515bSthomasraoux       std::swap(lhs, rhs);
434edd9515bSthomasraoux     } else {
4355ef7ceaeSNicolas Vasilache       // TODO: llvm_unreachable ?
4365ef7ceaeSNicolas Vasilache       return rewriter.notifyMatchFailure(op, "unexpected contraction case");
437edd9515bSthomasraoux     }
438edd9515bSthomasraoux     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
439edd9515bSthomasraoux         op, lhs, rhs, res,
440edd9515bSthomasraoux         rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
4417c38fd60SJacques Pienaar         op.getIteratorTypes());
442edd9515bSthomasraoux     return success();
443edd9515bSthomasraoux   }
444edd9515bSthomasraoux };
445edd9515bSthomasraoux 
446baa5beecStyb0807 // Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
447114ba722SManish Gupta // row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
448114ba722SManish Gupta // respectively. We can fold the transpose operation when loading the data from
449114ba722SManish Gupta // Shared Memory to registers.
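// For illustration (not from this file; names are hypothetical):
//   %r = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true, true]}
//          : memref<?x?xf16>, vector<16x8xf16>
//   %t = vector.transpose %r, [1, 0] : vector<16x8xf16> to vector<8x16xf16>
// folds into a single transfer_read producing vector<8x16xf16> whose
// permutation map is the transpose composed with the original read's map.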
450edd9515bSthomasraoux struct CombineTransferReadOpTranspose final
451edd9515bSthomasraoux     : public OpRewritePattern<vector::TransposeOp> {
452edd9515bSthomasraoux   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
453edd9515bSthomasraoux 
454edd9515bSthomasraoux   LogicalResult matchAndRewrite(vector::TransposeOp op,
455edd9515bSthomasraoux                                 PatternRewriter &rewriter) const override {
456985f7ff6SQuinn Dawkins     // Look through integer extend ops.
457985f7ff6SQuinn Dawkins     Value source = op.getVector();
4583cf7f224SThomas Raoux     Type resultType = op.getType();
4595205c712SQuinn Dawkins     Operation *extOp;
4605205c712SQuinn Dawkins     if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
46142bba97fSharsh-nod         (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
46242bba97fSharsh-nod         (extOp = source.getDefiningOp<arith::ExtFOp>())) {
4635205c712SQuinn Dawkins       source = extOp->getOperand(0);
464985f7ff6SQuinn Dawkins       resultType =
4655550c821STres Popp           VectorType::get(cast<VectorType>(resultType).getShape(),
4665550c821STres Popp                           cast<VectorType>(source.getType()).getElementType());
467985f7ff6SQuinn Dawkins     }
468985f7ff6SQuinn Dawkins 
469985f7ff6SQuinn Dawkins     auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
470edd9515bSthomasraoux     if (!transferReadOp)
4715ef7ceaeSNicolas Vasilache       return rewriter.notifyMatchFailure(op, "no transfer read");
472c537a943SNicolas Vasilache 
473c537a943SNicolas Vasilache     // TODO: support 0-d corner case.
474c537a943SNicolas Vasilache     if (transferReadOp.getTransferRank() == 0)
4755ef7ceaeSNicolas Vasilache       return rewriter.notifyMatchFailure(op, "0-D transfer read");
476c537a943SNicolas Vasilache 
4777c38fd60SJacques Pienaar     if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
4785ef7ceaeSNicolas Vasilache       return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
4795ef7ceaeSNicolas Vasilache 
480edd9515bSthomasraoux     AffineMap permutationMap =
48132c3decbSMatthias Springer         AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
4827c38fd60SJacques Pienaar     AffineMap newMap =
4837c38fd60SJacques Pienaar         permutationMap.compose(transferReadOp.getPermutationMap());
484985f7ff6SQuinn Dawkins 
485985f7ff6SQuinn Dawkins     auto loc = op.getLoc();
486985f7ff6SQuinn Dawkins     Value result =
487985f7ff6SQuinn Dawkins         rewriter
488985f7ff6SQuinn Dawkins             .create<vector::TransferReadOp>(
489985f7ff6SQuinn Dawkins                 loc, resultType, transferReadOp.getSource(),
4907c38fd60SJacques Pienaar                 transferReadOp.getIndices(), AffineMapAttr::get(newMap),
4917c38fd60SJacques Pienaar                 transferReadOp.getPadding(), transferReadOp.getMask(),
492985f7ff6SQuinn Dawkins                 transferReadOp.getInBoundsAttr())
493985f7ff6SQuinn Dawkins             .getResult();
494985f7ff6SQuinn Dawkins 
495985f7ff6SQuinn Dawkins     // Fuse through the integer extend op.
4965205c712SQuinn Dawkins     if (extOp) {
4975205c712SQuinn Dawkins       if (isa<arith::ExtSIOp>(extOp))
498985f7ff6SQuinn Dawkins         result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
499985f7ff6SQuinn Dawkins                      .getResult();
50042bba97fSharsh-nod       else if (isa<arith::ExtUIOp>(extOp))
5015205c712SQuinn Dawkins         result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
5025205c712SQuinn Dawkins                      .getResult();
50342bba97fSharsh-nod       else
50442bba97fSharsh-nod         result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
50542bba97fSharsh-nod                      .getResult();
5065205c712SQuinn Dawkins     }
507985f7ff6SQuinn Dawkins 
508985f7ff6SQuinn Dawkins     rewriter.replaceOp(op, result);
509edd9515bSthomasraoux     return success();
510edd9515bSthomasraoux   }
511edd9515bSthomasraoux };
512edd9515bSthomasraoux 
513edd9515bSthomasraoux } // namespace
514edd9515bSthomasraoux 
515edd9515bSthomasraoux // MMA types have different layouts based on how they are used in matmul ops.
5166413226dSthomasraoux // Figure out the right layout to use by looking at op uses.
517edd9515bSthomasraoux // TODO: Change the GPU dialect to abstract the layout at this level and
518edd9515bSthomasraoux // only care about it during lowering to NVVM.
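// For illustration (not from this file): a value feeding the lhs of a
// vector.contract becomes an "AOp" fragment, one feeding the rhs a "BOp"
// fragment, and anything else (typically the accumulator) a "COp" fragment.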
5195205c712SQuinn Dawkins static const char *inferFragType(Operation *op) {
520a037d889SLei Zhang   // We can have arith.ext ops before reaching contract ops. See through them
521a037d889SLei Zhang   // and other kinds of elementwise ops.
522a037d889SLei Zhang   if (op->hasOneUse()) {
523a037d889SLei Zhang     Operation *userOp = *op->user_begin();
524a037d889SLei Zhang     if (userOp->hasTrait<OpTrait::Elementwise>())
525a037d889SLei Zhang       return inferFragType(userOp);
526a037d889SLei Zhang   }
527a037d889SLei Zhang 
528edd9515bSthomasraoux   for (Operation *users : op->getUsers()) {
529edd9515bSthomasraoux     auto contract = dyn_cast<vector::ContractionOp>(users);
530edd9515bSthomasraoux     if (!contract)
531edd9515bSthomasraoux       continue;
5325205c712SQuinn Dawkins     assert(op->getNumResults() == 1);
5335205c712SQuinn Dawkins     if (contract.getLhs() == op->getResult(0))
534edd9515bSthomasraoux       return "AOp";
5355205c712SQuinn Dawkins     if (contract.getRhs() == op->getResult(0))
536edd9515bSthomasraoux       return "BOp";
537edd9515bSthomasraoux   }
538edd9515bSthomasraoux   return "COp";
539edd9515bSthomasraoux }
540edd9515bSthomasraoux 
5415ef7ceaeSNicolas Vasilache static LogicalResult
5425ef7ceaeSNicolas Vasilache convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
543edd9515bSthomasraoux                       llvm::DenseMap<Value, Value> &valueMapping) {
5445ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
5455ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
5465ef7ceaeSNicolas Vasilache 
547c537a943SNicolas Vasilache   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
548cafb6284SChristopher Bate   assert(transferReadSupportsMMAMatrixType(op) &&
549cafb6284SChristopher Bate          "expected convertible operation");
550dbddd4f6SLei Zhang 
551d32ec523SRamkumar Ramachandra   std::optional<int64_t> stride =
552cafb6284SChristopher Bate       getStaticallyKnownRowStride(op.getShapedType());
5535ef7ceaeSNicolas Vasilache   if (!stride.has_value()) {
5545ef7ceaeSNicolas Vasilache     LLVM_DEBUG(DBGS() << "no stride\n");
5555ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no stride");
5565ef7ceaeSNicolas Vasilache   }
557dbddd4f6SLei Zhang 
5587c38fd60SJacques Pienaar   AffineMap map = op.getPermutationMap();
5595ef7ceaeSNicolas Vasilache   bool isTranspose = isTransposeMatrixLoadMap(map);
560dbddd4f6SLei Zhang 
561e7969240SThomas Raoux   // Handle broadcast by setting the stride to 0.
5621609f1c2Slong.chen   if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
563dbddd4f6SLei Zhang     assert(cstExpr.getValue() == 0);
564e7969240SThomas Raoux     stride = 0;
565e7969240SThomas Raoux   }
5665ef7ceaeSNicolas Vasilache 
567985f7ff6SQuinn Dawkins   Value mappingResult = op.getResult();
568985f7ff6SQuinn Dawkins   auto elType = op.getVectorType().getElementType();
569edd9515bSthomasraoux   const char *fragType = inferFragType(op);
570985f7ff6SQuinn Dawkins   if (op->hasOneUse()) {
571b8a3f0fdSMehdi Amini     auto *user = *op->user_begin();
5725205c712SQuinn Dawkins     // Infer the signedness of the mma type from the integer extend.
573a037d889SLei Zhang     if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
5745205c712SQuinn Dawkins       elType = IntegerType::get(
5755550c821STres Popp           op.getContext(), cast<IntegerType>(elType).getWidth(),
576a037d889SLei Zhang           isa<arith::ExtSIOp>(user) ? IntegerType::Signed
577a037d889SLei Zhang                                     : IntegerType::Unsigned);
5785205c712SQuinn Dawkins       mappingResult = user->getResult(0);
579985f7ff6SQuinn Dawkins     }
580985f7ff6SQuinn Dawkins   }
581edd9515bSthomasraoux   gpu::MMAMatrixType type =
582985f7ff6SQuinn Dawkins       gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
5835ef7ceaeSNicolas Vasilache   Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
5847c38fd60SJacques Pienaar       op.getLoc(), type, op.getSource(), op.getIndices(),
5855ef7ceaeSNicolas Vasilache       rewriter.getIndexAttr(*stride),
5865ef7ceaeSNicolas Vasilache       isTranspose ? rewriter.getUnitAttr() : UnitAttr());
587985f7ff6SQuinn Dawkins   valueMapping[mappingResult] = load;
5885ef7ceaeSNicolas Vasilache 
5895ef7ceaeSNicolas Vasilache   LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
5905ef7ceaeSNicolas Vasilache   return success();
591edd9515bSthomasraoux }
592edd9515bSthomasraoux 
5935ef7ceaeSNicolas Vasilache static LogicalResult
5945ef7ceaeSNicolas Vasilache convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
595edd9515bSthomasraoux                        llvm::DenseMap<Value, Value> &valueMapping) {
5965ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
5975ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
5985ef7ceaeSNicolas Vasilache 
599edd9515bSthomasraoux   assert(transferWriteSupportsMMAMatrixType(op));
600d32ec523SRamkumar Ramachandra   std::optional<int64_t> stride =
601cafb6284SChristopher Bate       getStaticallyKnownRowStride(op.getShapedType());
6025ef7ceaeSNicolas Vasilache   if (!stride.has_value()) {
6035ef7ceaeSNicolas Vasilache     LLVM_DEBUG(DBGS() << "no stride\n");
6045ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no stride");
6055ef7ceaeSNicolas Vasilache   }
6065ef7ceaeSNicolas Vasilache 
6075ef7ceaeSNicolas Vasilache   auto it = valueMapping.find(op.getVector());
6085ef7ceaeSNicolas Vasilache   if (it == valueMapping.end()) {
6095ef7ceaeSNicolas Vasilache     LLVM_DEBUG(DBGS() << "no mapping\n");
6105ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no mapping");
6115ef7ceaeSNicolas Vasilache   }
6125ef7ceaeSNicolas Vasilache 
6135ef7ceaeSNicolas Vasilache   Value matrix = it->second;
6145ef7ceaeSNicolas Vasilache   auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
6153d35546cSNavdeep Katel       op.getLoc(), matrix, op.getSource(), op.getIndices(),
6165ef7ceaeSNicolas Vasilache       rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
6175ef7ceaeSNicolas Vasilache   (void)store;
6185ef7ceaeSNicolas Vasilache 
6195ef7ceaeSNicolas Vasilache   LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
6205ef7ceaeSNicolas Vasilache 
6215ef7ceaeSNicolas Vasilache   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
6225ef7ceaeSNicolas Vasilache   rewriter.eraseOp(op);
6235ef7ceaeSNicolas Vasilache   return success();
624edd9515bSthomasraoux }
625edd9515bSthomasraoux 
6261ca772edSChristopher Bate /// Returns the vector type which represents a matrix fragment.
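/// For example (illustrative), a fragment described by 2 registers per thread
/// with 2 f16 elements per register maps to vector<2x2xf16>.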
6271ca772edSChristopher Bate static VectorType
6281ca772edSChristopher Bate getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
6291ca772edSChristopher Bate   SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
6301ca772edSChristopher Bate                              regInfo.elementsPerRegister};
6311ca772edSChristopher Bate   Type elType = regInfo.registerLLVMType;
6325550c821STres Popp   if (auto vecType = dyn_cast<VectorType>(elType))
6331ca772edSChristopher Bate     elType = vecType.getElementType();
6341ca772edSChristopher Bate   return VectorType::get(shape, elType);
6351ca772edSChristopher Bate }
6361ca772edSChristopher Bate 
6371ca772edSChristopher Bate /// Convert a 2D splat ConstantOp to a splat of the mma.sync fragment vector type.
6381ca772edSChristopher Bate static LogicalResult
6395ef7ceaeSNicolas Vasilache convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
6401ca772edSChristopher Bate                          llvm::DenseMap<Value, Value> &valueMapping) {
6415ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
6425ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
6435ef7ceaeSNicolas Vasilache 
6441ca772edSChristopher Bate   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
6451ca772edSChristopher Bate       nvgpu::getWarpMatrixInfo(op);
6465ef7ceaeSNicolas Vasilache   if (failed(warpMatrixInfo)) {
6475ef7ceaeSNicolas Vasilache     LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
6485ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
6495ef7ceaeSNicolas Vasilache   }
6501ca772edSChristopher Bate 
6511ca772edSChristopher Bate   FailureOr<nvgpu::FragmentElementInfo> regInfo =
6521ca772edSChristopher Bate       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
6535ef7ceaeSNicolas Vasilache   if (failed(regInfo)) {
6545ef7ceaeSNicolas Vasilache     LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
6555ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
6565ef7ceaeSNicolas Vasilache   }
6571ca772edSChristopher Bate 
6581ca772edSChristopher Bate   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
6595550c821STres Popp   auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
6605ef7ceaeSNicolas Vasilache   if (!dense) {
6615ef7ceaeSNicolas Vasilache     LLVM_DEBUG(DBGS() << "not a splat\n");
6625ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "not a splat");
6635ef7ceaeSNicolas Vasilache   }
6645ef7ceaeSNicolas Vasilache 
6655ef7ceaeSNicolas Vasilache   Value result = rewriter.create<arith::ConstantOp>(
6661ca772edSChristopher Bate       op.getLoc(), vectorType,
6671ca772edSChristopher Bate       DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
6681ca772edSChristopher Bate   valueMapping[op.getResult()] = result;
6691ca772edSChristopher Bate   return success();
6701ca772edSChristopher Bate }
6711ca772edSChristopher Bate 
6726b82fc77SManish Gupta /// Check if the loaded matrix operand requires a transpose.
6736b82fc77SManish Gupta /// Transposed Map Example:
6746b82fc77SManish Gupta /// Example 1   : (..., d0, d1) -> (d1 * 1, d0 * 2)
6756b82fc77SManish Gupta /// Example 2   : (d0, d1, d2, d3) -> (d3, d2)
6766b82fc77SManish Gupta /// The code below checks if the output 2D is transposed using a generalized
6776b82fc77SManish Gupta /// version     : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
6786b82fc77SManish Gupta /// Returns     : true if m > n, false otherwise.
67984eed784SManish Gupta static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
6806b82fc77SManish Gupta   mlir::AffineMap map = op.getPermutationMap();
68184eed784SManish Gupta 
6826b82fc77SManish Gupta   if (map.getNumResults() != 2) {
68384eed784SManish Gupta     LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
68484eed784SManish Gupta                          "is not a 2d operand\n");
68584eed784SManish Gupta     return failure();
6866b82fc77SManish Gupta   }
6876b82fc77SManish Gupta 
6886b82fc77SManish Gupta   // Output 2D matrix dimensions in the order of d0, d1.
68984eed784SManish Gupta   mlir::AffineExpr dM = map.getResult(0);
69084eed784SManish Gupta   mlir::AffineExpr dN = map.getResult(1);
6916b82fc77SManish Gupta 
6926b82fc77SManish Gupta   //  Find the position of these expressions in the input.
6931609f1c2Slong.chen   auto exprM = dyn_cast<AffineDimExpr>(dM);
6941609f1c2Slong.chen   auto exprN = dyn_cast<AffineDimExpr>(dN);
69584eed784SManish Gupta 
6966b82fc77SManish Gupta   if (!exprM || !exprN) {
69784eed784SManish Gupta     LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
69884eed784SManish Gupta                          "expressions, then transpose cannot be determined.\n");
69984eed784SManish Gupta     return failure();
7006b82fc77SManish Gupta   }
70184eed784SManish Gupta 
7026b82fc77SManish Gupta   return exprM.getPosition() > exprN.getPosition();
7036b82fc77SManish Gupta }
7046b82fc77SManish Gupta 
7051ca772edSChristopher Bate static LogicalResult
7065ef7ceaeSNicolas Vasilache createLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
7071ca772edSChristopher Bate                              llvm::DenseMap<Value, Value> &valueMapping) {
7085ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
7095ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
7101ca772edSChristopher Bate   Location loc = op->getLoc();
7111ca772edSChristopher Bate 
7121ca772edSChristopher Bate   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
7131ca772edSChristopher Bate       nvgpu::getWarpMatrixInfo(op);
7145ef7ceaeSNicolas Vasilache   if (failed(warpMatrixInfo)) {
7155ef7ceaeSNicolas Vasilache     LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
7165ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
7175ef7ceaeSNicolas Vasilache   }
7181ca772edSChristopher Bate 
7191ca772edSChristopher Bate   FailureOr<nvgpu::FragmentElementInfo> regInfo =
7201ca772edSChristopher Bate       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
7215ef7ceaeSNicolas Vasilache   if (failed(regInfo)) {
7225ef7ceaeSNicolas Vasilache     LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
7235ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
7245ef7ceaeSNicolas Vasilache   }
7251ca772edSChristopher Bate 
72684eed784SManish Gupta   FailureOr<bool> transpose = isTransposed(op);
72784eed784SManish Gupta   if (failed(transpose)) {
72884eed784SManish Gupta     LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
72984eed784SManish Gupta     return rewriter.notifyMatchFailure(
73084eed784SManish Gupta         op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
73184eed784SManish Gupta   }
73284eed784SManish Gupta 
7336b82fc77SManish Gupta   FailureOr<nvgpu::LdMatrixParams> params =
73484eed784SManish Gupta       nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
7356b82fc77SManish Gupta 
7361ca772edSChristopher Bate   if (failed(params)) {
7375ef7ceaeSNicolas Vasilache     LLVM_DEBUG(
7385ef7ceaeSNicolas Vasilache         DBGS()
7395ef7ceaeSNicolas Vasilache         << "failed to convert vector.transfer_read to ldmatrix. "
7405ef7ceaeSNicolas Vasilache         << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
7415ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(
7425ef7ceaeSNicolas Vasilache         op, "failed to convert vector.transfer_read to ldmatrix; this op "
7435ef7ceaeSNicolas Vasilache             "likely should not be converted to a nvgpu.ldmatrix call.");
7441ca772edSChristopher Bate   }
7451ca772edSChristopher Bate 
7461ca772edSChristopher Bate   // Adjust the load offset.
74743fd4c49SKrzysztof Drewniak   auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
7481ca772edSChristopher Bate   FailureOr<AffineMap> offsets =
7495ef7ceaeSNicolas Vasilache       nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
7505ef7ceaeSNicolas Vasilache   if (failed(offsets)) {
7515ef7ceaeSNicolas Vasilache     LLVM_DEBUG(DBGS() << "no offsets\n");
7525ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no offsets");
7535ef7ceaeSNicolas Vasilache   }
7541ca772edSChristopher Bate 
7551ca772edSChristopher Bate   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
7561ca772edSChristopher Bate 
7571ca772edSChristopher Bate   SmallVector<Value, 4> indices;
7585ef7ceaeSNicolas Vasilache   getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
7591ca772edSChristopher Bate                                          indices);
76084eed784SManish Gupta 
7615ef7ceaeSNicolas Vasilache   nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
76284eed784SManish Gupta       loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
7631ca772edSChristopher Bate   valueMapping[op] = newOp->getResult(0);
7641ca772edSChristopher Bate   return success();
7651ca772edSChristopher Bate }
7661ca772edSChristopher Bate 
7671ca772edSChristopher Bate static LogicalResult
7685ef7ceaeSNicolas Vasilache createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
7691ca772edSChristopher Bate                        llvm::DenseMap<Value, Value> &valueMapping) {
7705ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
7715ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
7725ef7ceaeSNicolas Vasilache 
7731ca772edSChristopher Bate   Location loc = op.getLoc();
7741ca772edSChristopher Bate   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
7751ca772edSChristopher Bate       nvgpu::getWarpMatrixInfo(op);
7761ca772edSChristopher Bate   if (failed(warpMatrixInfo))
7775ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
7781ca772edSChristopher Bate   FailureOr<nvgpu::FragmentElementInfo> regInfo =
7791ca772edSChristopher Bate       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
7801ca772edSChristopher Bate   if (failed(regInfo)) {
781c0f504dfSJie Fu     return rewriter.notifyMatchFailure(
7825ef7ceaeSNicolas Vasilache         op, "Failed to deduce register fragment type during "
7835ef7ceaeSNicolas Vasilache             "conversion to distributed non-ldmatrix compatible load");
7841ca772edSChristopher Bate   }
7851ca772edSChristopher Bate 
78643fd4c49SKrzysztof Drewniak   Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
7871ca772edSChristopher Bate   SmallVector<Value, 4> elements;
7881ca772edSChristopher Bate 
7891ca772edSChristopher Bate   // This is the individual element type.
7901ca772edSChristopher Bate   Type loadedElType = regInfo->registerLLVMType;
7911ca772edSChristopher Bate   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
7921ca772edSChristopher Bate 
7935ef7ceaeSNicolas Vasilache   Value fill = rewriter.create<arith::ConstantOp>(
7941ca772edSChristopher Bate       op.getLoc(), vectorType.getElementType(),
7955ef7ceaeSNicolas Vasilache       rewriter.getZeroAttr(vectorType.getElementType()));
7965ef7ceaeSNicolas Vasilache   Value result =
7975ef7ceaeSNicolas Vasilache       rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
7981ca772edSChristopher Bate 
7991ca772edSChristopher Bate   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
8001ca772edSChristopher Bate 
8013af64383SNicolas Vasilache   // If we are not transposing, then we can use vectorized loads. Otherwise, we
8023af64383SNicolas Vasilache   // must load each element individually.
803670eee08SChristopher Bate   if (!isTransposeLoad) {
8045550c821STres Popp     if (!isa<VectorType>(loadedElType)) {
8051ca772edSChristopher Bate       loadedElType = VectorType::get({1}, loadedElType);
8061ca772edSChristopher Bate     }
8071ca772edSChristopher Bate 
8081ca772edSChristopher Bate     for (int i = 0; i < vectorType.getShape()[0]; i++) {
8091ca772edSChristopher Bate       FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
8105ef7ceaeSNicolas Vasilache           rewriter, op.getLoc(), *warpMatrixInfo);
8111ca772edSChristopher Bate       if (failed(coords))
8125ef7ceaeSNicolas Vasilache         return rewriter.notifyMatchFailure(op, "no coords");
8135ef7ceaeSNicolas Vasilache 
8145ef7ceaeSNicolas Vasilache       Value logicalValueId = rewriter.create<arith::ConstantOp>(
8155ef7ceaeSNicolas Vasilache           loc, rewriter.getIndexType(),
8165ef7ceaeSNicolas Vasilache           rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
8171ca772edSChristopher Bate       SmallVector<Value, 4> newIndices;
8181ca772edSChristopher Bate       getXferIndices<vector::TransferReadOp>(
8195ef7ceaeSNicolas Vasilache           rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
8201ca772edSChristopher Bate 
8215ef7ceaeSNicolas Vasilache       Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
8221ca772edSChristopher Bate                                                  op.getSource(), newIndices);
82316b75cd2SMatthias Springer       result = rewriter.create<vector::InsertOp>(loc, el, result, i);
8241ca772edSChristopher Bate     }
825670eee08SChristopher Bate   } else {
8265550c821STres Popp     if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
8271ca772edSChristopher Bate       loadedElType = vecType.getElementType();
8281ca772edSChristopher Bate     }
8291ca772edSChristopher Bate     for (int i = 0; i < vectorType.getShape()[0]; i++) {
8301ca772edSChristopher Bate       for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
8311ca772edSChristopher Bate            innerIdx++) {
8321ca772edSChristopher Bate 
8335ef7ceaeSNicolas Vasilache         Value logicalValueId = rewriter.create<arith::ConstantOp>(
8345ef7ceaeSNicolas Vasilache             loc, rewriter.getIndexType(),
8355ef7ceaeSNicolas Vasilache             rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
8361ca772edSChristopher Bate         FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
8375ef7ceaeSNicolas Vasilache             rewriter, op.getLoc(), *warpMatrixInfo);
8381ca772edSChristopher Bate         if (failed(coords))
8395ef7ceaeSNicolas Vasilache           return rewriter.notifyMatchFailure(op, "no coords");
8401ca772edSChristopher Bate 
8411ca772edSChristopher Bate         SmallVector<Value, 4> newIndices;
8421ca772edSChristopher Bate         getXferIndices<vector::TransferReadOp>(
8435ef7ceaeSNicolas Vasilache             rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
8445ef7ceaeSNicolas Vasilache         Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
8451ca772edSChristopher Bate                                                    op.getSource(), newIndices);
8465ef7ceaeSNicolas Vasilache         result = rewriter.create<vector::InsertOp>(
84716b75cd2SMatthias Springer             op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
8481ca772edSChristopher Bate       }
8491ca772edSChristopher Bate     }
8501ca772edSChristopher Bate   }
8511ca772edSChristopher Bate 
8521ca772edSChristopher Bate   valueMapping[op.getResult()] = result;
8531ca772edSChristopher Bate   return success();
8541ca772edSChristopher Bate }
8551ca772edSChristopher Bate 
856066b4fcbSThomas Raoux /// Return true if this is a shared memory memref type.
857066b4fcbSThomas Raoux static bool isSharedMemory(MemRefType type) {
858066b4fcbSThomas Raoux   auto addressSpace =
8595550c821STres Popp       dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
860b984045dSMehdi Amini   return addressSpace &&
861b984045dSMehdi Amini          addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
862066b4fcbSThomas Raoux }
863066b4fcbSThomas Raoux 
8641ca772edSChristopher Bate /// Converts a `vector.transfer_read` operation directly to either a
8651ca772edSChristopher Bate /// `vector.load` or an `nvgpu.ldmatrix` operation. This function should only be
8661ca772edSChristopher Bate /// used when converting to `nvgpu.mma.sync` operations.
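/// A sketch of the kind of IR handled here (illustrative only; the shapes and
/// the chosen lowering depend on the inferred warp-level MMA configuration):
///
///   %v = vector.transfer_read %src[%c0, %c0], %pad
///       : memref<16x16xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
///
/// is rewritten either to an `nvgpu.ldmatrix` from shared memory or to a set
/// of per-lane `vector.load`/`memref.load` ops (see `createNonLdMatrixLoads`).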
8671ca772edSChristopher Bate static LogicalResult
8685ef7ceaeSNicolas Vasilache convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
8691ca772edSChristopher Bate                            llvm::DenseMap<Value, Value> &valueMapping) {
8705ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
8715ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
8721ca772edSChristopher Bate 
8731ca772edSChristopher Bate   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
8741ca772edSChristopher Bate       nvgpu::getWarpMatrixInfo(op);
8751ca772edSChristopher Bate   if (failed(warpMatrixInfo))
8765ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
8771ca772edSChristopher Bate 
8781ca772edSChristopher Bate   bool isLdMatrixCompatible =
8795550c821STres Popp       isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
8801ca772edSChristopher Bate       nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
8811ca772edSChristopher Bate 
8821ca772edSChristopher Bate   VectorType vecTy = op.getVectorType();
8831ca772edSChristopher Bate   int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
8841ca772edSChristopher Bate 
8853af64383SNicolas Vasilache   // When transposing the B operand, ldmatrix only works for 16-bit element
8863af64383SNicolas Vasilache   // types, when there are at least 8 rows to read, and when the width read
8873af64383SNicolas Vasilache   // for the transpose is 128 bits.
8881ca772edSChristopher Bate   if (!op.getPermutationMap().isMinorIdentity() &&
889271a48e0SThomas Raoux       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
890271a48e0SThomas Raoux        vecTy.getDimSize(0) * bitWidth < 128))
8911ca772edSChristopher Bate     isLdMatrixCompatible = false;
8921ca772edSChristopher Bate 
8931ca772edSChristopher Bate   if (!isLdMatrixCompatible)
8945ef7ceaeSNicolas Vasilache     return createNonLdMatrixLoads(rewriter, op, valueMapping);
8951ca772edSChristopher Bate 
8965ef7ceaeSNicolas Vasilache   return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
8971ca772edSChristopher Bate }
8981ca772edSChristopher Bate 
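/// Converts a `vector.transfer_write` whose source vector has been mapped to
/// an `nvgpu.mma.sync` register fragment into per-lane `vector.store` ops.
/// This is the store-side counterpart of `convertTransferReadToLoads`.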
8991ca772edSChristopher Bate static LogicalResult
9005ef7ceaeSNicolas Vasilache convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
9011ca772edSChristopher Bate                              llvm::DenseMap<Value, Value> &valueMapping) {
9025ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
9035ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
9045ef7ceaeSNicolas Vasilache 
9051ca772edSChristopher Bate   Location loc = op->getLoc();
9065ef7ceaeSNicolas Vasilache   auto it = valueMapping.find(op.getVector());
9075ef7ceaeSNicolas Vasilache   if (it == valueMapping.end())
9085ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no mapping");
9095ef7ceaeSNicolas Vasilache   Value matrix = it->second;
9101ca772edSChristopher Bate 
9111ca772edSChristopher Bate   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
9121ca772edSChristopher Bate       nvgpu::getWarpMatrixInfo(op);
9131ca772edSChristopher Bate   if (failed(warpMatrixInfo))
9145ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
9151ca772edSChristopher Bate   FailureOr<nvgpu::FragmentElementInfo> regInfo =
9161ca772edSChristopher Bate       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
9171ca772edSChristopher Bate   if (failed(regInfo))
9185ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
9191ca772edSChristopher Bate 
9201ca772edSChristopher Bate   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
92143fd4c49SKrzysztof Drewniak   Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
9221ca772edSChristopher Bate 
9231ca772edSChristopher Bate   for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
9245ef7ceaeSNicolas Vasilache     Value logicalValueId = rewriter.create<arith::ConstantOp>(
9255ef7ceaeSNicolas Vasilache         loc, rewriter.getIndexType(),
9265ef7ceaeSNicolas Vasilache         rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
9271ca772edSChristopher Bate     FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
9285ef7ceaeSNicolas Vasilache         rewriter, op.getLoc(), *warpMatrixInfo);
9291ca772edSChristopher Bate     if (failed(coords))
9305ef7ceaeSNicolas Vasilache       return rewriter.notifyMatchFailure(op, "no coords");
9311ca772edSChristopher Bate 
9325ef7ceaeSNicolas Vasilache     Value el =
9335ef7ceaeSNicolas Vasilache         rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
9341ca772edSChristopher Bate     SmallVector<Value, 4> newIndices;
9351ca772edSChristopher Bate     getXferIndices<vector::TransferWriteOp>(
9365ef7ceaeSNicolas Vasilache         rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
9375ef7ceaeSNicolas Vasilache     rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
9381ca772edSChristopher Bate   }
9395ef7ceaeSNicolas Vasilache 
9405ef7ceaeSNicolas Vasilache   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
9415ef7ceaeSNicolas Vasilache   rewriter.eraseOp(op);
9421ca772edSChristopher Bate   return success();
9431ca772edSChristopher Bate }
9441ca772edSChristopher Bate 
945114ba722SManish Gupta static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
946114ba722SManish Gupta                                        SmallVectorImpl<int64_t> &results) {
947114ba722SManish Gupta   for (auto attr : arrayAttr)
9485550c821STres Popp     results.push_back(cast<IntegerAttr>(attr).getInt());
949114ba722SManish Gupta }
950114ba722SManish Gupta 
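/// Converts a `vector.extract_strided_slice` of a `vector.transfer_read` that
/// feeds an `nvgpu.mma.sync` into an extract_strided_slice over the
/// thread-owned register fragment. Only slicing along the fragment dimension
/// is supported.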
951114ba722SManish Gupta static LogicalResult
9525ef7ceaeSNicolas Vasilache convertExtractStridedSlice(RewriterBase &rewriter,
9535ef7ceaeSNicolas Vasilache                            vector::ExtractStridedSliceOp op,
954114ba722SManish Gupta                            llvm::DenseMap<Value, Value> &valueMapping) {
9555ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
9565ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
957114ba722SManish Gupta 
958114ba722SManish Gupta   Location loc = op->getLoc();
959114ba722SManish Gupta 
960114ba722SManish Gupta   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
961114ba722SManish Gupta       nvgpu::getWarpMatrixInfo(op);
962114ba722SManish Gupta   if (failed(warpMatrixInfo))
9635ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
964114ba722SManish Gupta 
965114ba722SManish Gupta   FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
966114ba722SManish Gupta       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
967114ba722SManish Gupta   if (failed(mmaSyncFragmentInfo))
9685ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
969114ba722SManish Gupta 
9703af64383SNicolas Vasilache   // Find the vector.transfer_read whose result vector is being sliced.
971114ba722SManish Gupta   auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
972114ba722SManish Gupta   if (!transferReadOp)
9735ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no transfer read");
974114ba722SManish Gupta 
975114ba722SManish Gupta   warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
976114ba722SManish Gupta   if (failed(warpMatrixInfo))
9775ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
978114ba722SManish Gupta 
979114ba722SManish Gupta   FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
980114ba722SManish Gupta       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
981114ba722SManish Gupta   if (failed(ldFragmentInfo))
9825ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
983114ba722SManish Gupta 
9843af64383SNicolas Vasilache   assert(
9853af64383SNicolas Vasilache       (mmaSyncFragmentInfo->elementsPerRegister ==
986114ba722SManish Gupta        ldFragmentInfo->elementsPerRegister) &&
9873af64383SNicolas Vasilache       "Number of elements per register should be the same for load and mma.sync");
988114ba722SManish Gupta 
9893af64383SNicolas Vasilache   // Create vector.extract_strided_slice op for thread-owned fragments.
990114ba722SManish Gupta   std::array<int64_t, 2> strides = {1,
991114ba722SManish Gupta                                     1}; // stride for extract slice is always 1.
992114ba722SManish Gupta   std::array<int64_t, 2> sliceShape = {
993114ba722SManish Gupta       mmaSyncFragmentInfo->numRegistersPerFragment,
994114ba722SManish Gupta       mmaSyncFragmentInfo->elementsPerRegister};
9955ef7ceaeSNicolas Vasilache   auto it = valueMapping.find(transferReadOp);
9965ef7ceaeSNicolas Vasilache   if (it == valueMapping.end())
9975ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no mapping");
9985ef7ceaeSNicolas Vasilache   auto sourceVector = it->second;
999114ba722SManish Gupta 
1000114ba722SManish Gupta   // Offset and sizes at warp-level of ownership.
1001114ba722SManish Gupta   SmallVector<int64_t> offsets;
1002114ba722SManish Gupta   populateFromInt64AttrArray(op.getOffsets(), offsets);
1003114ba722SManish Gupta 
1004114ba722SManish Gupta   SmallVector<int64_t> sizes;
1005114ba722SManish Gupta   populateFromInt64AttrArray(op.getSizes(), sizes);
1006a1aad28dSLei Zhang   ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1007114ba722SManish Gupta 
10083af64383SNicolas Vasilache   // Compute the offset in vector registers. Note that the mma.sync vector
10093af64383SNicolas Vasilache   // registers are shaped as numberOfFragments x numberOfRegistersPerFragment,
10103af64383SNicolas Vasilache   // and can only be sliced along numberOfFragments, i.e., sliceOffset[0].
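  // Illustrative example (numbers assumed for clarity): for a warp-level
  // vector shape of 16x16 and a slice offset of [8, 0], the code below
  // computes sliceOffset[0] = 16 / 8 = 2 in the register-fragment vector.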
1011114ba722SManish Gupta   std::array<int64_t, 2> sliceOffset = {0, 0};
1012114ba722SManish Gupta 
1013114ba722SManish Gupta   if (offsets[0] && offsets[1])
1014114ba722SManish Gupta     return op->emitError() << "Slicing fragments in 2D is not supported.";
10156a7a1188SMehdi Amini   if (offsets[0])
1016114ba722SManish Gupta     sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1017114ba722SManish Gupta   else if (offsets[1])
1018114ba722SManish Gupta     sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1019114ba722SManish Gupta 
10205ef7ceaeSNicolas Vasilache   Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1021114ba722SManish Gupta       loc, sourceVector, sliceOffset, sliceShape, strides);
1022114ba722SManish Gupta 
1023114ba722SManish Gupta   valueMapping[op] = newOp;
1024114ba722SManish Gupta   return success();
1025114ba722SManish Gupta }
1026114ba722SManish Gupta 
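/// Converts a `vector.contract` to a `gpu.subgroup_mma_compute` op, using the
/// MMA matrix values previously created for its LHS, RHS, and accumulator.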
10275ef7ceaeSNicolas Vasilache static LogicalResult
10285ef7ceaeSNicolas Vasilache convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1029edd9515bSthomasraoux                   llvm::DenseMap<Value, Value> &valueMapping) {
10305ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
10315ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
10325ef7ceaeSNicolas Vasilache 
10335ef7ceaeSNicolas Vasilache   auto itA = valueMapping.find(op.getLhs());
10345ef7ceaeSNicolas Vasilache   auto itB = valueMapping.find(op.getRhs());
10355ef7ceaeSNicolas Vasilache   auto itC = valueMapping.find(op.getAcc());
10365ef7ceaeSNicolas Vasilache   if (itA == valueMapping.end() || itB == valueMapping.end() ||
10375ef7ceaeSNicolas Vasilache       itC == valueMapping.end())
10385ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no mapping");
10395ef7ceaeSNicolas Vasilache   Value opA = itA->second, opB = itB->second, opC = itC->second;
10405ef7ceaeSNicolas Vasilache   Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
10413d35546cSNavdeep Katel       op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
10423d35546cSNavdeep Katel       /*b_transpose=*/UnitAttr());
1043edd9515bSthomasraoux   valueMapping[op.getResult()] = matmul;
10445ef7ceaeSNicolas Vasilache   return success();
1045edd9515bSthomasraoux }
1046edd9515bSthomasraoux 
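/// Converts a `vector.contract` to an `nvgpu.mma.sync` op. The (m, n, k)
/// shape of the mma.sync is derived from the shapes of the LHS and RHS
/// operand vectors.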
10471ca772edSChristopher Bate static LogicalResult
10485ef7ceaeSNicolas Vasilache convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
10491ca772edSChristopher Bate                            llvm::DenseMap<Value, Value> &valueMapping) {
10505ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
10515ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
10525ef7ceaeSNicolas Vasilache 
10535ef7ceaeSNicolas Vasilache   auto itA = valueMapping.find(op.getLhs());
10545ef7ceaeSNicolas Vasilache   auto itB = valueMapping.find(op.getRhs());
10555ef7ceaeSNicolas Vasilache   auto itC = valueMapping.find(op.getAcc());
10565ef7ceaeSNicolas Vasilache   if (itA == valueMapping.end() || itB == valueMapping.end() ||
10575ef7ceaeSNicolas Vasilache       itC == valueMapping.end())
10585ef7ceaeSNicolas Vasilache     return rewriter.notifyMatchFailure(op, "no mapping");
10595ef7ceaeSNicolas Vasilache   Value opA = itA->second, opB = itB->second, opC = itC->second;
10605550c821STres Popp   int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
10615550c821STres Popp   int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
10625550c821STres Popp   int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
10635ef7ceaeSNicolas Vasilache   Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
10645ef7ceaeSNicolas Vasilache       op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
10651ca772edSChristopher Bate   valueMapping[op.getResult()] = matmul;
10661ca772edSChristopher Bate   return success();
10671ca772edSChristopher Bate }
10681ca772edSChristopher Bate 
10696413226dSthomasraoux /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
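/// An illustrative sketch (shapes assumed for clarity):
///
///   %c = arith.constant dense<0.0> : vector<16x16xf16>
///
/// becomes a scalar `arith.constant` fed into a
/// `gpu.subgroup_mma_constant_matrix` op whose result type is
/// `!gpu.mma_matrix<16x16xf16, "COp">`, where the fragment kind is inferred
/// from the constant's use.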
10705ef7ceaeSNicolas Vasilache static LogicalResult
10715ef7ceaeSNicolas Vasilache convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
10726413226dSthomasraoux                   llvm::DenseMap<Value, Value> &valueMapping) {
10735ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
10745ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
10755ef7ceaeSNicolas Vasilache 
10766413226dSthomasraoux   assert(constantSupportsMMAMatrixType(op));
10775ef7ceaeSNicolas Vasilache 
1078e1795322SJeff Niu   auto splat =
10795550c821STres Popp       cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
10806413226dSthomasraoux   auto scalarConstant =
10815ef7ceaeSNicolas Vasilache       rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
10826413226dSthomasraoux   const char *fragType = inferFragType(op);
10835550c821STres Popp   auto vecType = cast<VectorType>(op.getType());
10846413226dSthomasraoux   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
10856413226dSthomasraoux       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
10865ef7ceaeSNicolas Vasilache   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
10875ef7ceaeSNicolas Vasilache       op.getLoc(), type, scalarConstant);
10886413226dSthomasraoux   valueMapping[op.getResult()] = matrix;
10895ef7ceaeSNicolas Vasilache   return success();
10906413226dSthomasraoux }
10916413226dSthomasraoux 
109243928419Sthomasraoux /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
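/// An illustrative sketch (shapes assumed for clarity):
///
///   %m = vector.broadcast %scalar : f16 to vector<16x16xf16>
///
/// becomes `gpu.subgroup_mma_constant_matrix %scalar` with result type
/// `!gpu.mma_matrix<16x16xf16, "COp">`, where the fragment kind is inferred
/// from the broadcast's use.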
10935ef7ceaeSNicolas Vasilache static LogicalResult
10945ef7ceaeSNicolas Vasilache convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
109543928419Sthomasraoux                    llvm::DenseMap<Value, Value> &valueMapping) {
10965ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
10975ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
10985ef7ceaeSNicolas Vasilache 
10993af64383SNicolas Vasilache   assert(broadcastSupportsMMAMatrixType(op));
11005ef7ceaeSNicolas Vasilache 
110143928419Sthomasraoux   const char *fragType = inferFragType(op);
1102a1aad28dSLei Zhang   auto vecType = op.getResultVectorType();
110343928419Sthomasraoux   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
110443928419Sthomasraoux       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
11055ef7ceaeSNicolas Vasilache   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
11065ef7ceaeSNicolas Vasilache       op.getLoc(), type, op.getSource());
110743928419Sthomasraoux   valueMapping[op.getResult()] = matrix;
11085ef7ceaeSNicolas Vasilache   return success();
110943928419Sthomasraoux }
111043928419Sthomasraoux 
11111a865592Sthomasraoux // Replace the ForOp with a new ForOp that has extra operands. The YieldOp is
11129b5ef2beSMatthias Springer // not updated and must be updated separately for the loop to be correct.
11135ef7ceaeSNicolas Vasilache static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
11145ef7ceaeSNicolas Vasilache                                                scf::ForOp loop,
11155cf714bbSMatthias Springer                                                ValueRange newInitArgs) {
11165ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
11175ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(loop);
11185ef7ceaeSNicolas Vasilache 
11193af64383SNicolas Vasilache   // Create a new loop before the existing one, with the extra operands.
11205ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(loop);
11215cf714bbSMatthias Springer   auto operands = llvm::to_vector<4>(loop.getInitArgs());
11225cf714bbSMatthias Springer   llvm::append_range(operands, newInitArgs);
11235ef7ceaeSNicolas Vasilache   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
11245ef7ceaeSNicolas Vasilache       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
11255ef7ceaeSNicolas Vasilache       operands);
112662bf7710SMatthias Springer   rewriter.eraseBlock(newLoop.getBody());
11275ef7ceaeSNicolas Vasilache 
11289b5ef2beSMatthias Springer   newLoop.getRegion().getBlocks().splice(
11299b5ef2beSMatthias Springer       newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
11305cf714bbSMatthias Springer   for (Value operand : newInitArgs)
1131e084679fSRiver Riddle     newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
11321a865592Sthomasraoux 
11331a865592Sthomasraoux   for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
11341a865592Sthomasraoux                                                   loop.getNumResults())))
11355ef7ceaeSNicolas Vasilache     rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
11365ef7ceaeSNicolas Vasilache 
11375ef7ceaeSNicolas Vasilache   LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
11385ef7ceaeSNicolas Vasilache   LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
11395ef7ceaeSNicolas Vasilache   LLVM_DEBUG(DBGS() << "erase: " << loop);
11405ef7ceaeSNicolas Vasilache 
11415ef7ceaeSNicolas Vasilache   rewriter.eraseOp(loop);
11421a865592Sthomasraoux   return newLoop;
11431a865592Sthomasraoux }
11441a865592Sthomasraoux 
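/// Converts an `scf.for` whose init args are mapped to MMA matrix values by
/// appending the mapped values as additional loop-carried operands and
/// extending the value mapping to the new results and block arguments.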
11455ef7ceaeSNicolas Vasilache static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
11461a865592Sthomasraoux                                   llvm::DenseMap<Value, Value> &valueMapping) {
11475ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
11485ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
11495ef7ceaeSNicolas Vasilache 
11501a865592Sthomasraoux   SmallVector<Value> newOperands;
11511a865592Sthomasraoux   SmallVector<std::pair<size_t, size_t>> argMapping;
11525cf714bbSMatthias Springer   for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
11531a865592Sthomasraoux     auto it = valueMapping.find(operand.value());
11545ef7ceaeSNicolas Vasilache     if (it == valueMapping.end()) {
11555ef7ceaeSNicolas Vasilache       LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
11561a865592Sthomasraoux       continue;
11575ef7ceaeSNicolas Vasilache     }
11581a865592Sthomasraoux     argMapping.push_back(std::make_pair(
11595cf714bbSMatthias Springer         operand.index(), op.getInitArgs().size() + newOperands.size()));
11601a865592Sthomasraoux     newOperands.push_back(it->second);
11611a865592Sthomasraoux   }
11625ef7ceaeSNicolas Vasilache 
11635ef7ceaeSNicolas Vasilache   scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
11641a865592Sthomasraoux   Block &loopBody = *newForOp.getBody();
11651a865592Sthomasraoux   for (auto mapping : argMapping) {
11661a865592Sthomasraoux     valueMapping[newForOp.getResult(mapping.first)] =
11671a865592Sthomasraoux         newForOp.getResult(mapping.second);
11681a865592Sthomasraoux     valueMapping[loopBody.getArgument(mapping.first +
11691a865592Sthomasraoux                                       newForOp.getNumInductionVars())] =
11701a865592Sthomasraoux         loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
11711a865592Sthomasraoux   }
11725ef7ceaeSNicolas Vasilache 
11735ef7ceaeSNicolas Vasilache   LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
11745ef7ceaeSNicolas Vasilache   return success();
11751a865592Sthomasraoux }
11761a865592Sthomasraoux 
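/// Updates the `scf.yield` of a loop rewritten by `convertForOp`: mapped
/// values are appended to the yielded operands, and the original yields are
/// redirected to the loop's init args so the replaced values become dead.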
11775ef7ceaeSNicolas Vasilache static LogicalResult
11785ef7ceaeSNicolas Vasilache convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
11791a865592Sthomasraoux                llvm::DenseMap<Value, Value> &valueMapping) {
11805ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
11815ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
11825ef7ceaeSNicolas Vasilache 
11831a865592Sthomasraoux   auto loop = cast<scf::ForOp>(op->getParentOp());
11841a865592Sthomasraoux   auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1185e4853be2SMehdi Amini   for (const auto &operand : llvm::enumerate(op.getOperands())) {
11861a865592Sthomasraoux     auto it = valueMapping.find(operand.value());
11871a865592Sthomasraoux     if (it == valueMapping.end())
11881a865592Sthomasraoux       continue;
11893af64383SNicolas Vasilache     // Replace the yield of the old value with the for op argument to make it
11903af64383SNicolas Vasilache     // easier to remove the dead code.
11915cf714bbSMatthias Springer     yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
11921a865592Sthomasraoux     yieldOperands.push_back(it->second);
11931a865592Sthomasraoux   }
11945ef7ceaeSNicolas Vasilache   rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
11955ef7ceaeSNicolas Vasilache 
11965ef7ceaeSNicolas Vasilache   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
11975ef7ceaeSNicolas Vasilache   rewriter.eraseOp(op);
11985ef7ceaeSNicolas Vasilache   return success();
11991a865592Sthomasraoux }
12001a865592Sthomasraoux 
12017fbb0678Sthomasraoux /// Convert an elementwise op to the equivalent elementwise op on MMA matrices.
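/// For example (illustrative): an `arith.addf` whose vector operands are
/// mapped to MMA matrices is rewritten to a `gpu.subgroup_mma_elementwise`
/// op with `addf` as its elementwise kind, operating on `!gpu.mma_matrix`
/// values.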
12025ef7ceaeSNicolas Vasilache static LogicalResult
12035ef7ceaeSNicolas Vasilache convertElementwiseOp(RewriterBase &rewriter, Operation *op,
12045ef7ceaeSNicolas Vasilache                      gpu::MMAElementwiseOp opType,
12057fbb0678Sthomasraoux                      llvm::DenseMap<Value, Value> &valueMapping) {
12065ef7ceaeSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
12075ef7ceaeSNicolas Vasilache   rewriter.setInsertionPoint(op);
12085ef7ceaeSNicolas Vasilache 
12097fbb0678Sthomasraoux   SmallVector<Value> matrixOperands;
12105ef7ceaeSNicolas Vasilache   for (Value operand : op->getOperands()) {
12115ef7ceaeSNicolas Vasilache     auto it = valueMapping.find(operand);
12125ef7ceaeSNicolas Vasilache     if (it == valueMapping.end())
12135ef7ceaeSNicolas Vasilache       return rewriter.notifyMatchFailure(op, "no mapping");
12145ef7ceaeSNicolas Vasilache     matrixOperands.push_back(it->second);
12155ef7ceaeSNicolas Vasilache   }
1216a5757c5bSChristian Sigg   auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1217a0119437SLei Zhang   if (opType == gpu::MMAElementwiseOp::EXTF) {
1218a0119437SLei Zhang     // The floating point extension case has a different result type.
1219a5757c5bSChristian Sigg     auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1220a0119437SLei Zhang     resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1221a0119437SLei Zhang                                          vectorType.getElementType(),
1222a0119437SLei Zhang                                          resultType.getOperand());
1223a0119437SLei Zhang   }
1224a0119437SLei Zhang 
12255ef7ceaeSNicolas Vasilache   Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1226a0119437SLei Zhang       op->getLoc(), resultType, matrixOperands, opType);
12277fbb0678Sthomasraoux   valueMapping[op->getResult(0)] = newOp;
12285ef7ceaeSNicolas Vasilache   return success();
12297fbb0678Sthomasraoux }
12307fbb0678Sthomasraoux 
12311ca772edSChristopher Bate void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
12321ca772edSChristopher Bate                                               bool useNvGpu) {
12331ca772edSChristopher Bate   if (!useNvGpu) {
1234edd9515bSthomasraoux     patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1235edd9515bSthomasraoux         patterns.getContext());
12361ca772edSChristopher Bate     return;
12371ca772edSChristopher Bate   }
1238fb7ef637SJakub Kuderski   vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1239fb7ef637SJakub Kuderski   patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1240edd9515bSthomasraoux }
1241edd9515bSthomasraoux 
12425ef7ceaeSNicolas Vasilache LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
12435ef7ceaeSNicolas Vasilache                                           Operation *rootOp) {
12441ca772edSChristopher Bate   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1245edd9515bSthomasraoux   llvm::DenseMap<Value, Value> valueMapping;
12465ef7ceaeSNicolas Vasilache 
12475ef7ceaeSNicolas Vasilache   auto globalRes = LogicalResult::success();
1248edd9515bSthomasraoux   for (Operation *op : ops) {
12495ef7ceaeSNicolas Vasilache     LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
12505ef7ceaeSNicolas Vasilache     // Apparently callers do not want to exit early on failure here.
12515ef7ceaeSNicolas Vasilache     auto res = LogicalResult::success();
1252edd9515bSthomasraoux     if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
12535ef7ceaeSNicolas Vasilache       res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1254edd9515bSthomasraoux     } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
12555ef7ceaeSNicolas Vasilache       res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1256edd9515bSthomasraoux     } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
12575ef7ceaeSNicolas Vasilache       res = convertContractOp(rewriter, contractOp, valueMapping);
1258a54f4eaeSMogball     } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
12595ef7ceaeSNicolas Vasilache       res = convertConstantOp(rewriter, constantOp, valueMapping);
126043928419Sthomasraoux     } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
12615ef7ceaeSNicolas Vasilache       res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
12621a865592Sthomasraoux     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
12635ef7ceaeSNicolas Vasilache       res = convertForOp(rewriter, forOp, valueMapping);
12645ef7ceaeSNicolas Vasilache     } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
12655ef7ceaeSNicolas Vasilache       res = convertYieldOp(rewriter, yieldOp, valueMapping);
12667fbb0678Sthomasraoux     } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
12675ef7ceaeSNicolas Vasilache       res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1268edd9515bSthomasraoux     }
12695ef7ceaeSNicolas Vasilache     if (failed(res))
12705ef7ceaeSNicolas Vasilache       globalRes = failure();
1271edd9515bSthomasraoux   }
12725ef7ceaeSNicolas Vasilache   return globalRes;
1273edd9515bSthomasraoux }
1274edd9515bSthomasraoux 
12755ef7ceaeSNicolas Vasilache LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
12765ef7ceaeSNicolas Vasilache                                                          Operation *rootOp) {
12771ca772edSChristopher Bate   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
12781ca772edSChristopher Bate   llvm::DenseMap<Value, Value> valueMapping;
12791ca772edSChristopher Bate   for (Operation *op : ops) {
12801ca772edSChristopher Bate     if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
12811ca772edSChristopher Bate             .Case([&](vector::TransferReadOp transferReadOp) {
12825ef7ceaeSNicolas Vasilache               return convertTransferReadToLoads(rewriter, transferReadOp,
12835ef7ceaeSNicolas Vasilache                                                 valueMapping);
12841ca772edSChristopher Bate             })
12851ca772edSChristopher Bate             .Case([&](vector::TransferWriteOp transferWriteOp) {
12865ef7ceaeSNicolas Vasilache               return convertTransferWriteToStores(rewriter, transferWriteOp,
12871ca772edSChristopher Bate                                                   valueMapping);
12881ca772edSChristopher Bate             })
1289114ba722SManish Gupta             .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
12905ef7ceaeSNicolas Vasilache               return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1291114ba722SManish Gupta                                                 valueMapping);
1292114ba722SManish Gupta             })
12931ca772edSChristopher Bate             .Case([&](vector::ContractionOp contractionOp) {
12945ef7ceaeSNicolas Vasilache               return convertContractOpToMmaSync(rewriter, contractionOp,
12955ef7ceaeSNicolas Vasilache                                                 valueMapping);
12961ca772edSChristopher Bate             })
12971ca772edSChristopher Bate             .Case([&](scf::ForOp forOp) {
12985ef7ceaeSNicolas Vasilache               return convertForOp(rewriter, forOp, valueMapping);
12991ca772edSChristopher Bate             })
13001ca772edSChristopher Bate             .Case([&](scf::YieldOp yieldOp) {
13015ef7ceaeSNicolas Vasilache               return convertYieldOp(rewriter, yieldOp, valueMapping);
13021ca772edSChristopher Bate             })
13031ca772edSChristopher Bate             .Case([&](arith::ConstantOp constOp) {
13045ef7ceaeSNicolas Vasilache               return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
13051ca772edSChristopher Bate             })
13061ca772edSChristopher Bate             .Default([&](Operation *op) {
13075ef7ceaeSNicolas Vasilache               return op->emitError() << "unhandled vector to mma type: " << *op;
13081ca772edSChristopher Bate             })
13091ca772edSChristopher Bate             .failed()) {
1310cafb6284SChristopher Bate       return op->emitOpError()
1311cafb6284SChristopher Bate              << "failed to convert op during vector-to-nvgpu conversion";
13121ca772edSChristopher Bate     }
13131ca772edSChristopher Bate   }
13141ca772edSChristopher Bate   return success();
13151ca772edSChristopher Bate }
13161ca772edSChristopher Bate 
1317edd9515bSthomasraoux namespace {
1318edd9515bSthomasraoux 
1319edd9515bSthomasraoux struct ConvertVectorToGPUPass
132067d0d7acSMichele Scuttari     : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
13211ca772edSChristopher Bate 
13221ca772edSChristopher Bate   explicit ConvertVectorToGPUPass(bool useNvGpu_) {
13231ca772edSChristopher Bate     useNvGpu.setValue(useNvGpu_);
13241ca772edSChristopher Bate   }
13251ca772edSChristopher Bate 
132641574554SRiver Riddle   void runOnOperation() override {
132747f175b0SRiver Riddle     RewritePatternSet patterns(&getContext());
13281ca772edSChristopher Bate     populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
132909dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
13301ca772edSChristopher Bate       return signalPassFailure();
1331edd9515bSthomasraoux 
13325ef7ceaeSNicolas Vasilache     IRRewriter rewriter(&getContext());
1333cafb6284SChristopher Bate     if (useNvGpu) {
13345ef7ceaeSNicolas Vasilache       if (failed(
13355ef7ceaeSNicolas Vasilache               convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
13361ca772edSChristopher Bate         return signalPassFailure();
1337cafb6284SChristopher Bate       return;
13381ca772edSChristopher Bate     }
13395ef7ceaeSNicolas Vasilache     (void)convertVectorToMMAOps(rewriter, getOperation());
1340edd9515bSthomasraoux   }
1341edd9515bSthomasraoux };
1342edd9515bSthomasraoux 
1343edd9515bSthomasraoux } // namespace
1344edd9515bSthomasraoux 
13451ca772edSChristopher Bate std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
13461ca772edSChristopher Bate   return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1347edd9515bSthomasraoux }
1348