xref: /llvm-project/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (revision 32c3decb773b0cc7fd1736fe1b47d889c8c0011c)
1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Arith/IR/Arith.h"
20 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
23 #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
26 #include "mlir/Dialect/Vector/IR/VectorOps.h"
27 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
28 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinOps.h"
31 #include "mlir/IR/Region.h"
32 #include "mlir/Pass/Pass.h"
33 #include "mlir/Support/LogicalResult.h"
34 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 
39 #define DEBUG_TYPE "vector-to-gpu"
40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
41 #define DBGSNL() (llvm::dbgs() << "\n")
42 
43 namespace mlir {
44 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
45 #include "mlir/Conversion/Passes.h.inc"
46 } // namespace mlir
47 
48 using namespace mlir;
49 
50 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
51 /// AffineMap representing offsets to apply to indices, the function fills
52 /// `indices` with the original indices plus the offsets. The offsets are
53 /// applied by taking into account the permutation map of the transfer op. If
54 /// the `offsetMap` has dimension placeholders, those should be provided in
55 /// `dimValues`.
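///
/// For illustration only (hypothetical values, not taken from this file): with
/// a transfer op indexed as memref[%i, %j], an identity permutation map, an
/// `offsetMap` of (d0) -> (d0 floordiv 4, d0 mod 4) and `dimValues` = {%laneId},
/// the helper appends one extra dimension standing for the original index and
/// produces roughly
///   indices[0] = affine.apply (d0, d1) -> (d1 + d0 floordiv 4)(%laneId, %i)
///   indices[1] = affine.apply (d0, d1) -> (d1 + d0 mod 4)(%laneId, %j)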
56 template <typename TransferOpType>
57 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
58                            AffineMap offsetMap, ArrayRef<Value> dimValues,
59                            SmallVector<Value, 4> &indices) {
60   indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
61   Location loc = xferOp.getLoc();
62   unsigned offsetsIdx = 0;
63   for (auto expr : xferOp.getPermutationMap().getResults()) {
64     if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
65       Value prevIdx = indices[dim.getPosition()];
66       SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
67       dims.push_back(prevIdx);
68       AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
69       indices[dim.getPosition()] = affine::makeComposedAffineApply(
70           rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
71       continue;
72     }
73   }
74 }
75 
76 // Return true if the contract op can be converted to an MMA matmul.
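//
// For illustration only (example IR, not taken from this file): a contraction
// accepted on the WMMA path (useNvGpu == false) looks roughly like
//   %d = vector.contract {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1)>,
//                           affine_map<(d0, d1, d2) -> (d0, d1)>],
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>}
//        %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
// while the mma.sync path (useNvGpu == true) expects the RHS map in transposed
// (MMT) form, i.e. affine_map<(d0, d1, d2) -> (d1, d2)>.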
77 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
78                                           bool useNvGpu) {
79   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
81   AffineExpr m, n, k;
82   bindDims(contract.getContext(), m, n, k);
83   auto iteratorTypes = contract.getIteratorTypes().getValue();
84   if (!(vector::isParallelIterator(iteratorTypes[0]) &&
85         vector::isParallelIterator(iteratorTypes[1]) &&
86         vector::isReductionIterator(iteratorTypes[2])))
87     return false;
88 
89   // The contract needs to represent a matmul to be convertible to an
90   // MMAMatrix matmul.
91   if (!useNvGpu &&
92       contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
93     return false;
94   if (useNvGpu &&
95       contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
96     return false;
97 
98   return true;
99 }
100 
101 // Return true if the given map represents a transposed matrix load,
102 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
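//
// For illustration: for rank-2 sources this accepts
//   affine_map<(d0, d1) -> (d1, d0)>   // plain transpose
//   affine_map<(d0, d1) -> (d1, 0)>    // transpose + broadcast
// and, for the rank-1 case, affine_map<(d0) -> (d0, 0)>.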
103 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
104   MLIRContext *ctx = permutationMap.getContext();
105   // Local OpBuilder is fine here, we just build attributes.
106   OpBuilder b(ctx);
107   auto nDim = permutationMap.getNumDims();
108   AffineExpr zero = b.getAffineConstantExpr(0);
109   if (nDim < 2) {
110     // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
111     AffineExpr dim0 = b.getAffineDimExpr(0);
112     return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
113   }
114 
115   AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
116   AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
117   // Support both transposed and transposed+broadcasted cases.
118   return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
119          permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
120 }
121 
122 // Return the stride for the second-to-last dimension of |type| if it is a memref
123 // and has a constant stride.
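//
// For illustration (hypothetical shapes): a memref<32x64xf16> with the default
// identity layout has strides [64, 1], so the returned row stride is 64; a
// memref with a dynamic leading stride, e.g. strided<[?, 1]>, yields
// std::nullopt.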
124 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
125   auto memrefType = dyn_cast<MemRefType>(type);
126   if (!memrefType)
127     return std::nullopt;
128   // If the memref is 0 or 1D the horizontal stride is 0.
129   if (memrefType.getRank() < 2)
130     return 0;
131   int64_t offset = 0;
132   SmallVector<int64_t, 2> strides;
133   if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
134       strides.back() != 1)
135     return std::nullopt;
136   int64_t stride = strides[strides.size() - 2];
137   if (stride == ShapedType::kDynamic)
138     return std::nullopt;
139   return stride;
140 }
141 
142 // Return true if the transfer op can be converted to an MMA matrix load.
143 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
144   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
145       readOp.getVectorType().getRank() != 2)
146     return false;
147   if (!getStaticallyKnownRowStride(readOp.getShapedType()))
148     return false;
149 
150   // Only allow integer types if the signedness can be inferred.
151   if (readOp.getVectorType().getElementType().isInteger(8))
152     if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
153                                  !isa<arith::ExtUIOp>(*readOp->user_begin())))
154       return false;
155 
156   AffineMap map = readOp.getPermutationMap();
157   MLIRContext *ctx = readOp.getContext();
158   AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
159   AffineExpr zero = getAffineConstantExpr(0, ctx);
160   auto broadcastInnerDim =
161       AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
162   return map.isMinorIdentity() || map == broadcastInnerDim ||
163          isTransposeMatrixLoadMap(map);
164 }
165 
166 // Return true if the transfer op can be converted to an MMA matrix store.
167 static bool
168 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
169   // TODO: support 0-d corner case.
170   if (writeOp.getTransferRank() == 0)
171     return false;
172 
173   if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
174       writeOp.getVectorType().getRank() != 2)
175     return false;
176   if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
177     return false;
178   // TODO: Support transpose once it is added to GPU dialect ops.
179   if (!writeOp.getPermutationMap().isMinorIdentity())
180     return false;
181   return true;
182 }
183 
184 /// Return true if the constant is a splat to a 2D vector so that it can be
185 /// converted to an MMA constant matrix op.
186 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
187   auto vecType = dyn_cast<VectorType>(constantOp.getType());
188   if (!vecType || vecType.getRank() != 2)
189     return false;
190   return isa<SplatElementsAttr>(constantOp.getValue());
191 }
192 
193 /// Return true if this is a broadcast from scalar to a 2D vector.
194 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
195   return broadcastOp.getResultVectorType().getRank() == 2;
196 }
197 
198 /// Return true if this integer extend op can be folded into a contract op.
199 template <typename ExtOpTy>
200 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
201   if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
202     return false;
203   return llvm::all_of(extOp->getUsers(), [](Operation *user) {
204     return isa<vector::ContractionOp>(user);
205   });
206 }
207 
208 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
209 
210 /// Return the MMA elementwise enum associated with `op` if it is supported.
211 /// Return `std::nullopt` otherwise.
212 static std::optional<gpu::MMAElementwiseOp>
213 convertElementwiseOpToMMA(Operation *op) {
214   if (isa<arith::AddFOp>(op))
215     return gpu::MMAElementwiseOp::ADDF;
216   if (isa<arith::MulFOp>(op))
217     return gpu::MMAElementwiseOp::MULF;
218   if (isa<arith::SubFOp>(op))
219     return gpu::MMAElementwiseOp::SUBF;
220   if (isa<arith::MaximumFOp>(op))
221     return gpu::MMAElementwiseOp::MAXF;
222   if (isa<arith::MinimumFOp>(op))
223     return gpu::MMAElementwiseOp::MINF;
224   if (isa<arith::DivFOp>(op))
225     return gpu::MMAElementwiseOp::DIVF;
226   if (isa<arith::AddIOp>(op))
227     return gpu::MMAElementwiseOp::ADDI;
228   if (isa<arith::MulIOp>(op))
229     return gpu::MMAElementwiseOp::MULI;
230   if (isa<arith::SubIOp>(op))
231     return gpu::MMAElementwiseOp::SUBI;
232   if (isa<arith::DivSIOp>(op))
233     return gpu::MMAElementwiseOp::DIVS;
234   if (isa<arith::DivUIOp>(op))
235     return gpu::MMAElementwiseOp::DIVU;
236   if (isa<arith::NegFOp>(op))
237     return gpu::MMAElementwiseOp::NEGATEF;
238   if (isa<arith::ExtFOp>(op))
239     return gpu::MMAElementwiseOp::EXTF;
240   return std::nullopt;
241 }
242 
243 /// Return true if the op is supported as elementwise op on MMAMatrix type.
244 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
245   return convertElementwiseOpToMMA(op).has_value();
246 }
247 
248 /// Returns true if the extract strided slice op is supported on the `mma.sync`
249 /// path.
250 static bool
251 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
252 
253   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
254       nvgpu::getWarpMatrixInfo(op);
255   if (failed(warpMatrixInfo))
256     return false;
257 
258   FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
259   if (failed(contractOp))
260     return false;
261 
262   // Handle vector.extract_strided_slice on registers containing
263   // matrixB and matrixC operands. vector.extract_strided_slice op
264   // is not supported on registers containing matrixA operands.
265   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
266     return (cast<VectorType>(op->getResult(0).getType()) ==
267             cast<VectorType>((*contractOp).getRhs().getType()));
268   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
269     return (cast<VectorType>(op->getResult(0).getType()) ==
270             cast<VectorType>((*contractOp).getAcc().getType()));
271 
272   return false;
273 }
274 
275 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
276   if (isa<scf::ForOp, scf::YieldOp>(op))
277     return true;
278   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
279     return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
280                     : transferReadSupportsMMAMatrixType(transferRead);
281   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
282     return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
283                     : transferWriteSupportsMMAMatrixType(transferWrite);
284   if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
285     return useNvGpu &&
286            extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
287   if (auto contract = dyn_cast<vector::ContractionOp>(op))
288     return contractSupportsMMAMatrixType(contract, useNvGpu);
289   if (auto constant = dyn_cast<arith::ConstantOp>(op))
290     return constantSupportsMMAMatrixType(constant);
291   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
292     return broadcastSupportsMMAMatrixType(broadcast);
293   if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
294     return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
295   if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
296     return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
297   if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
298     return fpExtendSupportsMMAMatrixType(fpExtend);
299   return elementwiseSupportsMMAMatrixType(op);
300 }
301 
302 /// Return an unsorted slice that handles the scf.for region differently than
303 /// `getSlice`: for scf.for we only want to include, as part of the slice, the
304 /// elements that are part of the use/def chain.
305 static SetVector<Operation *>
306 getSliceContract(Operation *op, BackwardSliceOptions backwardSliceOptions,
307                  ForwardSliceOptions forwardSliceOptions) {
308   SetVector<Operation *> slice;
309   slice.insert(op);
310   unsigned currentIndex = 0;
311   SetVector<Operation *> backwardSlice;
312   SetVector<Operation *> forwardSlice;
313   while (currentIndex != slice.size()) {
314     auto *currentOp = (slice)[currentIndex];
315     // Compute and insert the backwardSlice starting from currentOp.
316     backwardSlice.clear();
317     getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
318     slice.insert(backwardSlice.begin(), backwardSlice.end());
319 
320     // Compute and insert the forwardSlice starting from currentOp.
321     forwardSlice.clear();
322     // Special case for ForOp: we don't want to include the whole region, but
323     // only the values that use the region arguments.
324     // TODO: We should refine this to only care about the region arguments being
325     // converted to matrix type.
326     if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
327       for (Value forOpResult : forOp.getResults())
328         getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
329       for (BlockArgument &arg : forOp.getRegionIterArgs())
330         getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
331     } else {
332       getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
333     }
334     slice.insert(forwardSlice.begin(), forwardSlice.end());
335     ++currentIndex;
336   }
337   return slice;
338 }
339 
340 // Analyze the slice of operations anchored on each contraction op to figure out
341 // whether the whole slice can be converted to MMA operations.
342 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
343                                              bool useNvGpu) {
344   auto hasVectorDest = [](Operation *op) {
345     return llvm::any_of(op->getResultTypes(),
346                         [](Type t) { return isa<VectorType>(t); });
347   };
348   BackwardSliceOptions backwardSliceOptions;
349   backwardSliceOptions.filter = hasVectorDest;
350 
351   auto hasVectorSrc = [](Operation *op) {
352     return llvm::any_of(op->getOperandTypes(),
353                         [](Type t) { return isa<VectorType>(t); });
354   };
355   ForwardSliceOptions forwardSliceOptions;
356   forwardSliceOptions.filter = hasVectorSrc;
357 
358   SetVector<Operation *> opToConvert;
359   op->walk([&](vector::ContractionOp contract) {
360     if (opToConvert.contains(contract.getOperation()))
361       return;
362     SetVector<Operation *> dependentOps =
363         getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
364     // If any operation in the slice cannot use the MMA matrix type, drop the
365     // whole chain. MMA matrices are stored in an opaque type so they cannot be
366     // used by all operations.
367     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
368           if (!supportsMMaMatrixType(op, useNvGpu)) {
369             LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
370             return true;
371           }
372           return false;
373         }))
374       return;
375 
376     opToConvert.insert(dependentOps.begin(), dependentOps.end());
377   });
378   // Sort the operations so that we can convert them in topological order.
379   return topologicalSort(opToConvert);
380 }
381 
382 namespace {
383 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
384 // to MMA matmul.
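//
// For illustration: if the indexing maps are {(m, k), (n, k), (m, n)}, the RHS
// is transposed with `vector.transpose %rhs, [1, 0]` and the contraction is
// rebuilt with the canonical maps {(m, k), (k, n), (m, n)}.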
385 struct PrepareContractToGPUMMA
386     : public OpRewritePattern<vector::ContractionOp> {
387   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
388 
389   LogicalResult matchAndRewrite(vector::ContractionOp op,
390                                 PatternRewriter &rewriter) const override {
391     Location loc = op.getLoc();
392     Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
393 
394     // Set up the parallel/reduction structure in right form.
395     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
396     auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
397     AffineExpr m, n, k;
398     bindDims(rewriter.getContext(), m, n, k);
399     static constexpr std::array<int64_t, 2> perm = {1, 0};
400     auto iteratorTypes = op.getIteratorTypes().getValue();
401     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
402     if (!(vector::isParallelIterator(iteratorTypes[0]) &&
403           vector::isParallelIterator(iteratorTypes[1]) &&
404           vector::isReductionIterator(iteratorTypes[2])))
405       return rewriter.notifyMatchFailure(op, "not a gemm contraction");
406     //
407     // Two outer parallel, one inner reduction (matmat flavor).
408     //
409     // This is the classical row-major matmul, nothing to do.
410     if (maps == infer({{m, k}, {k, n}, {m, n}}))
411       return rewriter.notifyMatchFailure(op, "contraction already prepared");
412     if (maps == infer({{m, k}, {n, k}, {m, n}})) {
413       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
414     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
415       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
416     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
417       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
418       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
419     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
420       std::swap(rhs, lhs);
421       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
422       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
423     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
424       std::swap(rhs, lhs);
425       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
426     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
427       std::swap(lhs, rhs);
428       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
429     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
430       std::swap(lhs, rhs);
431     } else {
432       // TODO: llvm_unreachable ?
433       return rewriter.notifyMatchFailure(op, "unexpected contraction case");
434     }
435     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
436         op, lhs, rhs, res,
437         rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
438         op.getIteratorTypes());
439     return success();
440   }
441 };
442 
443 // Fold the transpose op into the transfer read op. The nvgpu mma.sync op only
444 // supports row-, column-, and row-major layouts for matrixA, matrixB, and
445 // matrixC, respectively. We can therefore fold the transpose when loading the
446 // data from shared memory to registers.
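//
// For illustration only (hypothetical IR):
//   %v = vector.transfer_read %A[%c0, %c0], %cst {in_bounds = [true, true]}
//          : memref<16x16xf16>, vector<16x16xf16>
//   %t = vector.transpose %v, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
// is rewritten into a single vector.transfer_read whose permutation map is the
// composition, here affine_map<(d0, d1) -> (d1, d0)>.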
447 struct CombineTransferReadOpTranspose final
448     : public OpRewritePattern<vector::TransposeOp> {
449   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
450 
451   LogicalResult matchAndRewrite(vector::TransposeOp op,
452                                 PatternRewriter &rewriter) const override {
453     // Look through integer extend ops.
454     Value source = op.getVector();
455     Type resultType = op.getType();
456     Operation *extOp;
457     if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
458         (extOp = source.getDefiningOp<arith::ExtUIOp>())) {
459       source = extOp->getOperand(0);
460       resultType =
461           VectorType::get(cast<VectorType>(resultType).getShape(),
462                           cast<VectorType>(source.getType()).getElementType());
463     }
464 
465     auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
466     if (!transferReadOp)
467       return rewriter.notifyMatchFailure(op, "no transfer read");
468 
469     // TODO: support 0-d corner case.
470     if (transferReadOp.getTransferRank() == 0)
471       return rewriter.notifyMatchFailure(op, "0-D transfer read");
472 
473     if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
474       return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
475 
476     AffineMap permutationMap =
477         AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
478     AffineMap newMap =
479         permutationMap.compose(transferReadOp.getPermutationMap());
480 
481     auto loc = op.getLoc();
482     Value result =
483         rewriter
484             .create<vector::TransferReadOp>(
485                 loc, resultType, transferReadOp.getSource(),
486                 transferReadOp.getIndices(), AffineMapAttr::get(newMap),
487                 transferReadOp.getPadding(), transferReadOp.getMask(),
488                 transferReadOp.getInBoundsAttr())
489             .getResult();
490 
491     // Fuse through the integer extend op.
492     if (extOp) {
493       if (isa<arith::ExtSIOp>(extOp))
494         result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
495                      .getResult();
496       else
497         result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
498                      .getResult();
499     }
500 
501     rewriter.replaceOp(op, result);
502     return success();
503   }
504 };
505 
506 } // namespace
507 
508 // MMA types have different layouts based on how they are used in matmul ops.
509 // Figure out the right layout to use by looking at op uses.
510 // TODO: Change the GPU dialect to abstract the layout at this level and
511 // only care about it during lowering to NVVM.
512 static const char *inferFragType(Operation *op) {
513   for (Operation *users : op->getUsers()) {
514     auto contract = dyn_cast<vector::ContractionOp>(users);
515     if (!contract)
516       continue;
517     assert(op->getNumResults() == 1);
518     if (contract.getLhs() == op->getResult(0))
519       return "AOp";
520     if (contract.getRhs() == op->getResult(0))
521       return "BOp";
522   }
523   return "COp";
524 }
525 
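// For illustration only (hypothetical IR and shapes): a supported read such as
//   %v = vector.transfer_read %A[%i, %j], %cst {in_bounds = [true, true]}
//          : memref<128x64xf16>, vector<16x16xf16>
// feeding the LHS of a contraction is rewritten along the lines of
//   %m = gpu.subgroup_mma_load_matrix %A[%i, %j] {leadDimension = 64 : index}
//          : memref<128x64xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">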
526 static LogicalResult
527 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
528                       llvm::DenseMap<Value, Value> &valueMapping) {
529   OpBuilder::InsertionGuard g(rewriter);
530   rewriter.setInsertionPoint(op);
531 
532   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
533   assert(transferReadSupportsMMAMatrixType(op) &&
534          "expected convertible operation");
535 
536   std::optional<int64_t> stride =
537       getStaticallyKnownRowStride(op.getShapedType());
538   if (!stride.has_value()) {
539     LLVM_DEBUG(DBGS() << "no stride\n");
540     return rewriter.notifyMatchFailure(op, "no stride");
541   }
542 
543   AffineMap map = op.getPermutationMap();
544   bool isTranspose = isTransposeMatrixLoadMap(map);
545 
546   // Handle broadcast by setting the stride to 0.
547   if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
548     assert(cstExpr.getValue() == 0);
549     stride = 0;
550   }
551 
552   Value mappingResult = op.getResult();
553   auto elType = op.getVectorType().getElementType();
554   const char *fragType = inferFragType(op);
555   if (op->hasOneUse()) {
556     auto user = *op->user_begin();
557     // Infer the signedness of the mma type from the integer extend.
558     bool isSignedExtend = isa<arith::ExtSIOp>(user);
559     if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
560       elType = IntegerType::get(
561           op.getContext(), cast<IntegerType>(elType).getWidth(),
562           isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
563       mappingResult = user->getResult(0);
564       fragType = inferFragType(user);
565     }
566   }
567   gpu::MMAMatrixType type =
568       gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
569   Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
570       op.getLoc(), type, op.getSource(), op.getIndices(),
571       rewriter.getIndexAttr(*stride),
572       isTranspose ? rewriter.getUnitAttr() : UnitAttr());
573   valueMapping[mappingResult] = load;
574 
575   LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
576   return success();
577 }
578 
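// For illustration only (hypothetical IR and shapes): the store side is
// rewritten along the lines of
//   gpu.subgroup_mma_store_matrix %m, %C[%i, %j] {leadDimension = 64 : index}
//     : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x64xf16>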
579 static LogicalResult
580 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
581                        llvm::DenseMap<Value, Value> &valueMapping) {
582   OpBuilder::InsertionGuard g(rewriter);
583   rewriter.setInsertionPoint(op);
584 
585   assert(transferWriteSupportsMMAMatrixType(op));
586   std::optional<int64_t> stride =
587       getStaticallyKnownRowStride(op.getShapedType());
588   if (!stride.has_value()) {
589     LLVM_DEBUG(DBGS() << "no stride\n");
590     return rewriter.notifyMatchFailure(op, "no stride");
591   }
592 
593   auto it = valueMapping.find(op.getVector());
594   if (it == valueMapping.end()) {
595     LLVM_DEBUG(DBGS() << "no mapping\n");
596     return rewriter.notifyMatchFailure(op, "no mapping");
597   }
598 
599   Value matrix = it->second;
600   auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
601       op.getLoc(), matrix, op.getSource(), op.getIndices(),
602       rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
603   (void)store;
604 
605   LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
606 
607   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
608   rewriter.eraseOp(op);
609   return success();
610 }
611 
612 /// Returns the vector type which represents a matrix fragment.
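/// For example, a fragment with numRegistersPerFragment = 2 and
/// elementsPerRegister = 2 holding f16 elements maps to vector<2x2xf16>.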
613 static VectorType
614 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
615   SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
616                              regInfo.elementsPerRegister};
617   Type elType = regInfo.registerLLVMType;
618   if (auto vecType = dyn_cast<VectorType>(elType))
619     elType = vecType.getElementType();
620   return VectorType::get(shape, elType);
621 }
622 
623 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
624 static LogicalResult
625 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
626                          llvm::DenseMap<Value, Value> &valueMapping) {
627   OpBuilder::InsertionGuard g(rewriter);
628   rewriter.setInsertionPoint(op);
629 
630   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
631       nvgpu::getWarpMatrixInfo(op);
632   if (failed(warpMatrixInfo)) {
633     LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
634     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
635   }
636 
637   FailureOr<nvgpu::FragmentElementInfo> regInfo =
638       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
639   if (failed(regInfo)) {
640     LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
641     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
642   }
643 
644   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
645   auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
646   if (!dense) {
647     LLVM_DEBUG(DBGS() << "not a splat\n");
648     return rewriter.notifyMatchFailure(op, "not a splat");
649   }
650 
651   Value result = rewriter.create<arith::ConstantOp>(
652       op.getLoc(), vectorType,
653       DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
654   valueMapping[op.getResult()] = result;
655   return success();
656 }
657 
658 /// Check if the loaded matrix operand requires a transpose.
659 /// Transposed Map Example:
660 /// Example 1   : (..., d0, d1) -> (d1 * 1, d0 * 2)
661 /// Example 2   : (d0, d1, d2, d3) -> (d3, d2)
662 /// The code below checks if the output 2D is transposed using a generalized
663 /// version     : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
664 /// Returns     : true if m > n, false otherwise.
665 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
666   mlir::AffineMap map = op.getPermutationMap();
667 
668   if (map.getNumResults() != 2) {
669     LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
670                          "is not a 2d operand\n");
671     return failure();
672   }
673 
674   // Output 2D matrix dimensions in the order of d0, d1.
675   mlir::AffineExpr dM = map.getResult(0);
676   mlir::AffineExpr dN = map.getResult(1);
677 
678   //  Find the position of these expressions in the input.
679   auto exprM = dyn_cast<AffineDimExpr>(dM);
680   auto exprN = dyn_cast<AffineDimExpr>(dN);
681 
682   if (!exprM || !exprN) {
683     LLVM_DEBUG(DBGS() << "Failed because the expressions are not affine dim "
684                          "expressions, so the transpose cannot be determined.\n");
685     return failure();
686   }
687 
688   return exprM.getPosition() > exprN.getPosition();
689 }
690 
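// For illustration only (hypothetical IR, attribute values are assumptions): a
// transfer_read from shared memory that qualifies is rewritten to a collective
// load roughly of the form
//   %r = nvgpu.ldmatrix %sm[%row, %col] {numTiles = 4 : i32, transpose = false}
//          : memref<?x?xf16, 3> -> vector<4x2xf16>
// with the per-lane indices computed from gpu.lane_id via getXferIndices.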
691 static LogicalResult
692 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
693                              llvm::DenseMap<Value, Value> &valueMapping) {
694   OpBuilder::InsertionGuard g(rewriter);
695   rewriter.setInsertionPoint(op);
696   Location loc = op->getLoc();
697 
698   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
699       nvgpu::getWarpMatrixInfo(op);
700   if (failed(warpMatrixInfo)) {
701     LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
702     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
703   }
704 
705   FailureOr<nvgpu::FragmentElementInfo> regInfo =
706       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
707   if (failed(regInfo)) {
708     LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
709     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
710   }
711 
712   FailureOr<bool> transpose = isTransposed(op);
713   if (failed(transpose)) {
714     LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
715     return rewriter.notifyMatchFailure(
716         op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
717   }
718 
719   FailureOr<nvgpu::LdMatrixParams> params =
720       nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
721 
722   if (failed(params)) {
723     LLVM_DEBUG(
724         DBGS()
725         << "failed to convert vector.transfer_read to ldmatrix. "
726         << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
727     return rewriter.notifyMatchFailure(
728         op, "failed to convert vector.transfer_read to ldmatrix; this op "
729             "likely should not be converted to a nvgpu.ldmatrix call.");
730   }
731 
732   // Adjust the load offset.
733   auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
734   FailureOr<AffineMap> offsets =
735       nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
736   if (failed(offsets)) {
737     LLVM_DEBUG(DBGS() << "no offsets\n");
738     return rewriter.notifyMatchFailure(op, "no offsets");
739   }
740 
741   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
742 
743   SmallVector<Value, 4> indices;
744   getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
745                                          indices);
746 
747   nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
748       loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
749   valueMapping[op] = newOp->getResult(0);
750   return success();
751 }
752 
753 static LogicalResult
754 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
755                        llvm::DenseMap<Value, Value> &valueMapping) {
756   OpBuilder::InsertionGuard g(rewriter);
757   rewriter.setInsertionPoint(op);
758 
759   Location loc = op.getLoc();
760   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
761       nvgpu::getWarpMatrixInfo(op);
762   if (failed(warpMatrixInfo))
763     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
764   FailureOr<nvgpu::FragmentElementInfo> regInfo =
765       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
766   if (failed(regInfo)) {
767     return rewriter.notifyMatchFailure(
768         op, "Failed to deduce register fragment type during "
769             "conversion to distributed non-ldmatrix compatible load");
770   }
771 
772   Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
773   SmallVector<Value, 4> elements;
774 
775   // This is the individual element type.
776   Type loadedElType = regInfo->registerLLVMType;
777   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
778 
779   Value fill = rewriter.create<arith::ConstantOp>(
780       op.getLoc(), vectorType.getElementType(),
781       rewriter.getZeroAttr(vectorType.getElementType()));
782   Value result =
783       rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
784 
785   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
786 
787   // If we are not transposing, then we can use vectorized loads. Otherwise, we
788   // must load each element individually.
789   if (!isTransposeLoad) {
790     if (!isa<VectorType>(loadedElType)) {
791       loadedElType = VectorType::get({1}, loadedElType);
792     }
793 
794     for (int i = 0; i < vectorType.getShape()[0]; i++) {
795       FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
796           rewriter, op.getLoc(), *warpMatrixInfo);
797       if (failed(coords))
798         return rewriter.notifyMatchFailure(op, "no coords");
799 
800       Value logicalValueId = rewriter.create<arith::ConstantOp>(
801           loc, rewriter.getIndexType(),
802           rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
803       SmallVector<Value, 4> newIndices;
804       getXferIndices<vector::TransferReadOp>(
805           rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
806 
807       Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
808                                                  op.getSource(), newIndices);
809       result = rewriter.create<vector::InsertOp>(loc, el, result, i);
810     }
811   } else {
812     if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
813       loadedElType = vecType.getElementType();
814     }
815     for (int i = 0; i < vectorType.getShape()[0]; i++) {
816       for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
817            innerIdx++) {
818 
819         Value logicalValueId = rewriter.create<arith::ConstantOp>(
820             loc, rewriter.getIndexType(),
821             rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
822         FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
823             rewriter, op.getLoc(), *warpMatrixInfo);
824         if (failed(coords))
825           return rewriter.notifyMatchFailure(op, "no coords");
826 
827         SmallVector<Value, 4> newIndices;
828         getXferIndices<vector::TransferReadOp>(
829             rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
830         Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
831                                                    op.getSource(), newIndices);
832         result = rewriter.create<vector::InsertOp>(
833             op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
834       }
835     }
836   }
837 
838   valueMapping[op.getResult()] = result;
839   return success();
840 }
841 
842 /// Return true if this is a shared memory memref type.
843 static bool isSharedMemory(MemRefType type) {
844   auto addressSpace =
845       dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
846   if (addressSpace &&
847       addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace())
848     return true;
849   return false;
850 }
851 
852 /// Converts a `vector.transfer_read` operation directly to either a
853 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
854 /// used when converting to `nvgpu.mma.sync` operations.
855 static LogicalResult
856 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
857                            llvm::DenseMap<Value, Value> &valueMapping) {
858   OpBuilder::InsertionGuard g(rewriter);
859   rewriter.setInsertionPoint(op);
860 
861   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
862       nvgpu::getWarpMatrixInfo(op);
863   if (failed(warpMatrixInfo))
864     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
865 
866   bool isLdMatrixCompatible =
867       isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
868       nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
869 
870   VectorType vecTy = op.getVectorType();
871   int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
872 
873   // When we are transposing the B operand, ldmatrix will only work if we have
874   // at least 8 rows to read and the width to read for the transpose is 128
875   // bits.
876   if (!op.getPermutationMap().isMinorIdentity() &&
877       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
878        vecTy.getDimSize(0) * bitWidth < 128))
879     isLdMatrixCompatible = false;
880 
881   if (!isLdMatrixCompatible)
882     return createNonLdMatrixLoads(rewriter, op, valueMapping);
883 
884   return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
885 }
886 
887 static LogicalResult
888 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
889                              llvm::DenseMap<Value, Value> &valueMapping) {
890   OpBuilder::InsertionGuard g(rewriter);
891   rewriter.setInsertionPoint(op);
892 
893   Location loc = op->getLoc();
894   auto it = valueMapping.find(op.getVector());
895   if (it == valueMapping.end())
896     return rewriter.notifyMatchFailure(op, "no mapping");
897   Value matrix = it->second;
898 
899   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
900       nvgpu::getWarpMatrixInfo(op);
901   if (failed(warpMatrixInfo))
902     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
903   FailureOr<nvgpu::FragmentElementInfo> regInfo =
904       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
905   if (failed(regInfo))
906     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
907 
908   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
909   Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
910 
911   for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
912     Value logicalValueId = rewriter.create<arith::ConstantOp>(
913         loc, rewriter.getIndexType(),
914         rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
915     FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
916         rewriter, op.getLoc(), *warpMatrixInfo);
917     if (failed(coords))
918       return rewriter.notifyMatchFailure(op, "no coords");
919 
920     Value el =
921         rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
922     SmallVector<Value, 4> newIndices;
923     getXferIndices<vector::TransferWriteOp>(
924         rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
925     rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
926   }
927 
928   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
929   rewriter.eraseOp(op);
930   return success();
931 }
932 
933 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
934                                        SmallVectorImpl<int64_t> &results) {
935   for (auto attr : arrayAttr)
936     results.push_back(cast<IntegerAttr>(attr).getInt());
937 }
938 
939 static LogicalResult
940 convertExtractStridedSlice(RewriterBase &rewriter,
941                            vector::ExtractStridedSliceOp op,
942                            llvm::DenseMap<Value, Value> &valueMapping) {
943   OpBuilder::InsertionGuard g(rewriter);
944   rewriter.setInsertionPoint(op);
945 
946   Location loc = op->getLoc();
947 
948   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
949       nvgpu::getWarpMatrixInfo(op);
950   if (failed(warpMatrixInfo))
951     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
952 
953   FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
954       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
955   if (failed(mmaSyncFragmentInfo))
956     return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
957 
958   // Find the vector.transfer_read whose result vector is being sliced.
959   auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
960   if (!transferReadOp)
961     return rewriter.notifyMatchFailure(op, "no transfer read");
962 
963   warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
964   if (failed(warpMatrixInfo))
965     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
966 
967   FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
968       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
969   if (failed(ldFragmentInfo))
970     return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
971 
972   assert(
973       (mmaSyncFragmentInfo->elementsPerRegister ==
974        ldFragmentInfo->elementsPerRegister) &&
975       "Number of elements per register should be same for load and mma.sync");
976 
977   // Create vector.extract_strided_slice op for thread-owned fragments.
978   std::array<int64_t, 2> strides = {1,
979                                     1}; // stride for extract slice is always 1.
980   std::array<int64_t, 2> sliceShape = {
981       mmaSyncFragmentInfo->numRegistersPerFragment,
982       mmaSyncFragmentInfo->elementsPerRegister};
983   auto it = valueMapping.find(transferReadOp);
984   if (it == valueMapping.end())
985     return rewriter.notifyMatchFailure(op, "no mapping");
986   auto sourceVector = it->second;
987 
988   // Offsets and sizes at the warp level of ownership.
989   SmallVector<int64_t> offsets;
990   populateFromInt64AttrArray(op.getOffsets(), offsets);
991 
992   SmallVector<int64_t> sizes;
993   populateFromInt64AttrArray(op.getSizes(), sizes);
994   ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
995 
996   // Compute the offset in vector registers. Note that mma.sync vector registers
997   // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
998   // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
999   std::array<int64_t, 2> sliceOffset = {0, 0};
1000 
1001   if (offsets[0] && offsets[1])
1002     return op->emitError() << "Slicing fragments in 2D is not supported. ";
1003   if (offsets[0])
1004     sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1005   else if (offsets[1])
1006     sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1007 
1008   Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1009       loc, sourceVector, sliceOffset, sliceShape, strides);
1010 
1011   valueMapping[op] = newOp;
1012   return success();
1013 }
1014 
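// For illustration only (hypothetical IR): once the three operands are mapped
// to !gpu.mma_matrix values, the contraction is rewritten to roughly
//   %d = gpu.subgroup_mma_compute %a, %b, %c
//          : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
//          -> !gpu.mma_matrix<16x16xf16, "COp">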
1015 static LogicalResult
1016 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1017                   llvm::DenseMap<Value, Value> &valueMapping) {
1018   OpBuilder::InsertionGuard g(rewriter);
1019   rewriter.setInsertionPoint(op);
1020 
1021   auto itA = valueMapping.find(op.getLhs());
1022   auto itB = valueMapping.find(op.getRhs());
1023   auto itC = valueMapping.find(op.getAcc());
1024   if (itA == valueMapping.end() || itB == valueMapping.end() ||
1025       itC == valueMapping.end())
1026     return rewriter.notifyMatchFailure(op, "no mapping");
1027   Value opA = itA->second, opB = itB->second, opC = itC->second;
1028   Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1029       op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1030       /*b_transpose=*/UnitAttr());
1031   valueMapping[op.getResult()] = matmul;
1032   return success();
1033 }
1034 
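// For illustration only (hypothetical IR, per-thread fragment shapes are
// assumptions): with m = 16, n = 8, k = 16 the contraction becomes roughly
//   %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
//          : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>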
1035 static LogicalResult
1036 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1037                            llvm::DenseMap<Value, Value> &valueMapping) {
1038   OpBuilder::InsertionGuard g(rewriter);
1039   rewriter.setInsertionPoint(op);
1040 
1041   auto itA = valueMapping.find(op.getLhs());
1042   auto itB = valueMapping.find(op.getRhs());
1043   auto itC = valueMapping.find(op.getAcc());
1044   if (itA == valueMapping.end() || itB == valueMapping.end() ||
1045       itC == valueMapping.end())
1046     return rewriter.notifyMatchFailure(op, "no mapping");
1047   Value opA = itA->second, opB = itB->second, opC = itC->second;
1048   int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1049   int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1050   int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1051   Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1052       op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1053   valueMapping[op.getResult()] = matmul;
1054   return success();
1055 }
1056 
1057 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1058 static LogicalResult
1059 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1060                   llvm::DenseMap<Value, Value> &valueMapping) {
1061   OpBuilder::InsertionGuard g(rewriter);
1062   rewriter.setInsertionPoint(op);
1063 
1064   assert(constantSupportsMMAMatrixType(op));
1065 
1066   auto splat =
1067       cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1068   auto scalarConstant =
1069       rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1070   const char *fragType = inferFragType(op);
1071   auto vecType = cast<VectorType>(op.getType());
1072   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1073       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1074   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1075       op.getLoc(), type, scalarConstant);
1076   valueMapping[op.getResult()] = matrix;
1077   return success();
1078 }
1079 
1080 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1081 static LogicalResult
1082 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1083                    llvm::DenseMap<Value, Value> &valueMapping) {
1084   OpBuilder::InsertionGuard g(rewriter);
1085   rewriter.setInsertionPoint(op);
1086 
1087   assert(broadcastSupportsMMAMatrixType(op));
1088 
1089   const char *fragType = inferFragType(op);
1090   auto vecType = op.getResultVectorType();
1091   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1092       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1093   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1094       op.getLoc(), type, op.getSource());
1095   valueMapping[op.getResult()] = matrix;
1096   return success();
1097 }
1098 
1099 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1100 // updated and needs to be updated separately for the loop to be correct.
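//
// For illustration: if a matmul accumulator is carried as
//   scf.for ... iter_args(%acc = %init) -> (vector<16x16xf16>)
// the loop is recreated with an additional iter_arg (here it would be of type
// !gpu.mma_matrix<16x16xf16, "COp">); the original result is left to die and
// the terminator is fixed up separately by convertYieldOp.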
1101 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1102                                                scf::ForOp loop,
1103                                                ValueRange newInitArgs) {
1104   OpBuilder::InsertionGuard g(rewriter);
1105   rewriter.setInsertionPoint(loop);
1106 
1107   // Create a new loop before the existing one, with the extra operands.
1108   rewriter.setInsertionPoint(loop);
1109   auto operands = llvm::to_vector<4>(loop.getInitArgs());
1110   llvm::append_range(operands, newInitArgs);
1111   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1112       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
1113       operands);
1114   newLoop.getBody()->erase();
1115 
1116   newLoop.getRegion().getBlocks().splice(
1117       newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1118   for (Value operand : newInitArgs)
1119     newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1120 
1121   for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1122                                                   loop.getNumResults())))
1123     rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1124 
1125   LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1126   LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1127   LLVM_DEBUG(DBGS() << "erase: " << loop);
1128 
1129   rewriter.eraseOp(loop);
1130   return newLoop;
1131 }
1132 
1133 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1134                                   llvm::DenseMap<Value, Value> &valueMapping) {
1135   OpBuilder::InsertionGuard g(rewriter);
1136   rewriter.setInsertionPoint(op);
1137 
1138   SmallVector<Value> newOperands;
1139   SmallVector<std::pair<size_t, size_t>> argMapping;
1140   for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1141     auto it = valueMapping.find(operand.value());
1142     if (it == valueMapping.end()) {
1143       LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1144       continue;
1145     }
1146     argMapping.push_back(std::make_pair(
1147         operand.index(), op.getInitArgs().size() + newOperands.size()));
1148     newOperands.push_back(it->second);
1149   }
1150 
1151   scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1152   Block &loopBody = *newForOp.getBody();
1153   for (auto mapping : argMapping) {
1154     valueMapping[newForOp.getResult(mapping.first)] =
1155         newForOp.getResult(mapping.second);
1156     valueMapping[loopBody.getArgument(mapping.first +
1157                                       newForOp.getNumInductionVars())] =
1158         loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1159   }
1160 
1161   LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1162   return success();
1163 }
1164 
1165 static LogicalResult
1166 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1167                llvm::DenseMap<Value, Value> &valueMapping) {
1168   OpBuilder::InsertionGuard g(rewriter);
1169   rewriter.setInsertionPoint(op);
1170 
1171   auto loop = cast<scf::ForOp>(op->getParentOp());
1172   auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1173   for (const auto &operand : llvm::enumerate(op.getOperands())) {
1174     auto it = valueMapping.find(operand.value());
1175     if (it == valueMapping.end())
1176       continue;
1177     // Replace the yield of the old value with the corresponding init argument
1178     // of the loop, to make it easier to remove the dead code.
1179     yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1180     yieldOperands.push_back(it->second);
1181   }
1182   rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1183 
1184   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1185   rewriter.eraseOp(op);
1186   return success();
1187 }
1188 
1189 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
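///
/// For illustration only (hypothetical IR): an arith.addf on mapped operands
/// becomes roughly
///   %0 = gpu.subgroup_mma_elementwise addf %a, %b
///          : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
///          -> !gpu.mma_matrix<16x16xf16, "COp">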
1190 static LogicalResult
1191 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1192                      gpu::MMAElementwiseOp opType,
1193                      llvm::DenseMap<Value, Value> &valueMapping) {
1194   OpBuilder::InsertionGuard g(rewriter);
1195   rewriter.setInsertionPoint(op);
1196 
1197   SmallVector<Value> matrixOperands;
1198   for (Value operand : op->getOperands()) {
1199     auto it = valueMapping.find(operand);
1200     if (it == valueMapping.end())
1201       return rewriter.notifyMatchFailure(op, "no mapping");
1202     matrixOperands.push_back(it->second);
1203   }
1204   auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1205   if (opType == gpu::MMAElementwiseOp::EXTF) {
1206     // The floating point extension case has a different result type.
1207     auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1208     resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1209                                          vectorType.getElementType(),
1210                                          resultType.getOperand());
1211   }
1212 
1213   Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1214       op->getLoc(), resultType, matrixOperands, opType);
1215   valueMapping[op->getResult(0)] = newOp;
1216   return success();
1217 }
1218 
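// For illustration (hypothetical usage): these patterns are meant to run before
// the conversion entry points below, as done by ConvertVectorToGPUPass, e.g.
//   RewritePatternSet patterns(ctx);
//   populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
//   (void)convertVectorToMMAOps(rewriter, op);
// or, from the command line, via the `convert-vector-to-gpu` pass.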
1219 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1220                                               bool useNvGpu) {
1221   if (!useNvGpu) {
1222     patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1223         patterns.getContext());
1224     return;
1225   }
1226   vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1227   patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1228 }
1229 
1230 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1231                                           Operation *rootOp) {
1232   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1233   llvm::DenseMap<Value, Value> valueMapping;
1234 
1235   auto globalRes = LogicalResult::success();
1236   for (Operation *op : ops) {
1237     LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1238     // Apparently callers do not want to early exit on failure here.
1239     auto res = LogicalResult::success();
1240     if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1241       res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1242     } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1243       res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1244     } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1245       res = convertContractOp(rewriter, contractOp, valueMapping);
1246     } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1247       res = convertConstantOp(rewriter, constantOp, valueMapping);
1248     } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1249       res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1250     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1251       res = convertForOp(rewriter, forOp, valueMapping);
1252     } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1253       res = convertYieldOp(rewriter, yieldOp, valueMapping);
1254     } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1255       res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1256     }
1257     if (failed(res))
1258       globalRes = failure();
1259   }
1260   return globalRes;
1261 }
1262 
1263 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1264                                                          Operation *rootOp) {
1265   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1266   llvm::DenseMap<Value, Value> valueMapping;
1267   for (Operation *op : ops) {
1268     if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1269             .Case([&](vector::TransferReadOp transferReadOp) {
1270               return convertTransferReadToLoads(rewriter, transferReadOp,
1271                                                 valueMapping);
1272             })
1273             .Case([&](vector::TransferWriteOp transferWriteOp) {
1274               return convertTransferWriteToStores(rewriter, transferWriteOp,
1275                                                   valueMapping);
1276             })
1277             .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1278               return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1279                                                 valueMapping);
1280             })
1281             .Case([&](vector::ContractionOp contractionOp) {
1282               return convertContractOpToMmaSync(rewriter, contractionOp,
1283                                                 valueMapping);
1284             })
1285             .Case([&](scf::ForOp forOp) {
1286               return convertForOp(rewriter, forOp, valueMapping);
1287             })
1288             .Case([&](scf::YieldOp yieldOp) {
1289               return convertYieldOp(rewriter, yieldOp, valueMapping);
1290             })
1291             .Case([&](arith::ConstantOp constOp) {
1292               return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1293             })
1294             .Default([&](Operation *op) {
1295               return op->emitError() << "unhandled vector to mma type: " << *op;
1296             })
1297             .failed()) {
1298       return op->emitOpError()
1299              << "failed to convert op during vector-to-nvgpu conversion";
1300     }
1301   }
1302   return success();
1303 }
1304 
1305 namespace {
1306 
1307 struct ConvertVectorToGPUPass
1308     : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1309 
1310   explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1311     useNvGpu.setValue(useNvGpu_);
1312   }
1313 
1314   void runOnOperation() override {
1315     RewritePatternSet patterns(&getContext());
1316     populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1317     if (failed(
1318             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
1319       return signalPassFailure();
1320 
1321     IRRewriter rewriter(&getContext());
1322     if (useNvGpu) {
1323       if (failed(
1324               convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1325         return signalPassFailure();
1326       return;
1327     }
1328     (void)convertVectorToMMAOps(rewriter, getOperation());
1329   }
1330 };
1331 
1332 } // namespace
1333 
1334 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1335   return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1336 }
1337