1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Analysis/TopologicalSortUtils.h"
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"
23 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
24 #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
27 #include "mlir/Dialect/Vector/IR/VectorOps.h"
28 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
29 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/Region.h"
33 #include "mlir/Pass/Pass.h"
34 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 
39 #define DEBUG_TYPE "vector-to-gpu"
40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
41 #define DBGSNL() (llvm::dbgs() << "\n")
42 
43 namespace mlir {
44 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
45 #include "mlir/Conversion/Passes.h.inc"
46 } // namespace mlir
47 
48 using namespace mlir;
49 
50 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
51 /// AffineMap representing offsets to apply to indices, the function fills
52 /// `indices` with the original indices plus the offsets. The offsets are
53 /// applied by taking into account the permutation map of the transfer op. If
54 /// the `offsetMap` has dimension placeholders, those should be provided in
55 /// `dimValues`.
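/// Illustrative example (assuming a single dim value `%lane` and
/// `offsetMap = (d0) -> (d0 floordiv 4)`): each original index `%i` of a
/// mapped dimension is replaced by roughly
/// `affine.apply affine_map<(lane, i) -> (i + lane floordiv 4)>(%lane, %i)`.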
56 template <typename TransferOpType>
57 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
58                            AffineMap offsetMap, ArrayRef<Value> dimValues,
59                            SmallVector<Value, 4> &indices) {
60   indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
61   Location loc = xferOp.getLoc();
62   unsigned offsetsIdx = 0;
63   for (auto expr : xferOp.getPermutationMap().getResults()) {
64     if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
65       Value prevIdx = indices[dim.getPosition()];
66       SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
67       dims.push_back(prevIdx);
68       AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
69       indices[dim.getPosition()] = affine::makeComposedAffineApply(
70           rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
71       continue;
72     }
73   }
74 }
75 
76 // Return true if the contract op can be converted to an MMA matmul.
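// With iterators (parallel, parallel, reduction), this accepts indexing maps of
// the form
//   (m, k), (k, n), (m, n)   for the gpu.subgroup_mma (wmma) path, and
//   (m, k), (n, k), (m, n)   for the nvgpu mma.sync path (B in MMT form).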
77 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
78                                           bool useNvGpu) {
79   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80   auto infer = [&](MapList m) {
81     return AffineMap::inferFromExprList(m, contract.getContext());
82   };
83   AffineExpr m, n, k;
84   bindDims(contract.getContext(), m, n, k);
85   auto iteratorTypes = contract.getIteratorTypes().getValue();
86   if (!(vector::isParallelIterator(iteratorTypes[0]) &&
87         vector::isParallelIterator(iteratorTypes[1]) &&
88         vector::isReductionIterator(iteratorTypes[2])))
89     return false;
90 
91   // The contraction needs to represent a matmul to be convertible to an
92   // MMAMatrix matmul.
93   if (!useNvGpu &&
94       contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
95     return false;
96   if (useNvGpu &&
97       contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
98     return false;
99 
100   return true;
101 }
102 
103 // Return true if the given map represents a transposed matrix load,
104 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
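// For example, affine_map<(d0, d1) -> (d1, d0)> is a transposed load; the
// transposed+broadcasted forms affine_map<(d0, d1) -> (d1, 0)> and
// affine_map<(d0) -> (d0, 0)> are accepted as well.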
105 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
106   MLIRContext *ctx = permutationMap.getContext();
107   // Local OpBuilder is fine here, we only build affine expressions.
108   OpBuilder b(ctx);
109   auto nDim = permutationMap.getNumDims();
110   AffineExpr zero = b.getAffineConstantExpr(0);
111   if (nDim < 2) {
112     // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
113     AffineExpr dim0 = b.getAffineDimExpr(0);
114     return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
115   }
116 
117   AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
118   AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
119   // Support both transposed and transposed+broadcasted cases.
120   return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
121          permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
122 }
123 
124 // Return the stride of the second-to-last dimension of |type| if it is a
125 // memref and has a constant stride.
126 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
127   auto memrefType = dyn_cast<MemRefType>(type);
128   if (!memrefType)
129     return std::nullopt;
130   // If the memref is 0-D or 1-D, the horizontal stride is 0.
131   if (memrefType.getRank() < 2)
132     return 0;
133   int64_t offset = 0;
134   SmallVector<int64_t, 2> strides;
135   if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
136       strides.back() != 1)
137     return std::nullopt;
138   int64_t stride = strides[strides.size() - 2];
139   if (stride == ShapedType::kDynamic)
140     return std::nullopt;
141   return stride;
142 }
143 
144 // Return true if the transfer op can be converted to an MMA matrix load.
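// The permutation map must be a minor identity (e.g. (d0, d1) -> (d0, d1)), a
// broadcast of the innermost dimension (e.g. (d0, d1) -> (0, d1)), or one of
// the transposed forms accepted by isTransposeMatrixLoadMap.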
145 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
146   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
147       readOp.getVectorType().getRank() != 2)
148     return false;
149   if (!getStaticallyKnownRowStride(readOp.getShapedType()))
150     return false;
151 
152   // Only allow integer types if the signedness can be inferred.
153   if (readOp.getVectorType().getElementType().isInteger(8))
154     if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
155                                  !isa<arith::ExtUIOp>(*readOp->user_begin())))
156       return false;
157 
158   AffineMap map = readOp.getPermutationMap();
159   MLIRContext *ctx = readOp.getContext();
160   AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
161   AffineExpr zero = getAffineConstantExpr(0, ctx);
162   auto broadcastInnerDim =
163       AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
164   return map.isMinorIdentity() || map == broadcastInnerDim ||
165          isTransposeMatrixLoadMap(map);
166 }
167 
168 // Return true if the transfer op can be converted to an MMA matrix store.
169 static bool
170 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
171   // TODO: support 0-d corner case.
172   if (writeOp.getTransferRank() == 0)
173     return false;
174 
175   if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
176       writeOp.getVectorType().getRank() != 2)
177     return false;
178   if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
179     return false;
180   // TODO: Support transpose once it is added to GPU dialect ops.
181   if (!writeOp.getPermutationMap().isMinorIdentity())
182     return false;
183   return true;
184 }
185 
186 /// Return true if the constant is a splat to a 2D vector so that it can be
187 /// converted to an MMA constant matrix op.
188 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
189   auto vecType = dyn_cast<VectorType>(constantOp.getType());
190   if (!vecType || vecType.getRank() != 2)
191     return false;
192   return isa<SplatElementsAttr>(constantOp.getValue());
193 }
194 
195 /// Return true if this is a broadcast from scalar to a 2D vector.
196 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
197   return broadcastOp.getResultVectorType().getRank() == 2;
198 }
199 
200 /// Return true if this integer extend op can be folded into a contract op.
201 template <typename ExtOpTy>
202 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
203   if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
204     return false;
205   return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
206 }
207 
208 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
209 
210 /// Return the MMA elementwise enum associated with `op` if it is supported.
211 /// Return `std::nullopt` otherwise.
212 static std::optional<gpu::MMAElementwiseOp>
213 convertElementwiseOpToMMA(Operation *op) {
214   if (isa<arith::AddFOp>(op))
215     return gpu::MMAElementwiseOp::ADDF;
216   if (isa<arith::MulFOp>(op))
217     return gpu::MMAElementwiseOp::MULF;
218   if (isa<arith::SubFOp>(op))
219     return gpu::MMAElementwiseOp::SUBF;
220   if (isa<arith::MaximumFOp>(op))
221     return gpu::MMAElementwiseOp::MAXF;
222   if (isa<arith::MinimumFOp>(op))
223     return gpu::MMAElementwiseOp::MINF;
224   if (isa<arith::DivFOp>(op))
225     return gpu::MMAElementwiseOp::DIVF;
226   if (isa<arith::AddIOp>(op))
227     return gpu::MMAElementwiseOp::ADDI;
228   if (isa<arith::MulIOp>(op))
229     return gpu::MMAElementwiseOp::MULI;
230   if (isa<arith::SubIOp>(op))
231     return gpu::MMAElementwiseOp::SUBI;
232   if (isa<arith::DivSIOp>(op))
233     return gpu::MMAElementwiseOp::DIVS;
234   if (isa<arith::DivUIOp>(op))
235     return gpu::MMAElementwiseOp::DIVU;
236   if (isa<arith::NegFOp>(op))
237     return gpu::MMAElementwiseOp::NEGATEF;
238   if (isa<arith::ExtFOp>(op))
239     return gpu::MMAElementwiseOp::EXTF;
240   return std::nullopt;
241 }
242 
243 /// Return true if the op is supported as elementwise op on MMAMatrix type.
244 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
245   return convertElementwiseOpToMMA(op).has_value();
246 }
247 
248 /// Returns true if the extract strided slice op is supported on the
249 /// `mma.sync` path.
250 static bool
251 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
252 
253   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
254       nvgpu::getWarpMatrixInfo(op);
255   if (failed(warpMatrixInfo))
256     return false;
257 
258   FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
259   if (failed(contractOp))
260     return false;
261 
262   // Handle vector.extract_strided_slice on registers containing
263   // matrixB and matrixC operands. vector.extract_strided_slice op
264   // is not supported on registers containing matrixA operands.
265   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
266     return (cast<VectorType>(op->getResult(0).getType()) ==
267             cast<VectorType>((*contractOp).getRhs().getType()));
268   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
269     return (cast<VectorType>(op->getResult(0).getType()) ==
270             cast<VectorType>((*contractOp).getAcc().getType()));
271 
272   return false;
273 }
274 
275 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
276   if (isa<scf::ForOp, scf::YieldOp>(op))
277     return true;
278   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
279     return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
280                     : transferReadSupportsMMAMatrixType(transferRead);
281   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
282     return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
283                     : transferWriteSupportsMMAMatrixType(transferWrite);
284   if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
285     return useNvGpu &&
286            extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
287   if (auto contract = dyn_cast<vector::ContractionOp>(op))
288     return contractSupportsMMAMatrixType(contract, useNvGpu);
289   if (auto constant = dyn_cast<arith::ConstantOp>(op))
290     return constantSupportsMMAMatrixType(constant);
291   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
292     return broadcastSupportsMMAMatrixType(broadcast);
293   if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
294     return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
295   if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
296     return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
297   if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
298     return fpExtendSupportsMMAMatrixType(fpExtend);
299   return elementwiseSupportsMMAMatrixType(op);
300 }
301 
302 /// Return an unsorted slice, handling scf.for regions differently than
303 /// `getSlice`: for scf.for we only want to include, as part of the slice,
304 /// elements that are part of the use/def chain.
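/// Starting from `op`, backward and forward slices are unioned for every newly
/// discovered operation until a fixed point is reached.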
305 static SetVector<Operation *>
306 getSliceContract(Operation *op,
307                  const BackwardSliceOptions &backwardSliceOptions,
308                  const ForwardSliceOptions &forwardSliceOptions) {
309   SetVector<Operation *> slice;
310   slice.insert(op);
311   unsigned currentIndex = 0;
312   SetVector<Operation *> backwardSlice;
313   SetVector<Operation *> forwardSlice;
314   while (currentIndex != slice.size()) {
315     auto *currentOp = (slice)[currentIndex];
316     // Compute and insert the backwardSlice starting from currentOp.
317     backwardSlice.clear();
318     getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
319     slice.insert(backwardSlice.begin(), backwardSlice.end());
320 
321     // Compute and insert the forwardSlice starting from currentOp.
322     forwardSlice.clear();
323     // Special case for ForOp: we don't want to include the whole region, but
324     // only the values using the region arguments.
325     // TODO: We should refine this to only care about the region arguments being
326     // converted to matrix type.
327     if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
328       for (Value forOpResult : forOp.getResults())
329         getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
330       for (BlockArgument &arg : forOp.getRegionIterArgs())
331         getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
332     } else {
333       getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
334     }
335     slice.insert(forwardSlice.begin(), forwardSlice.end());
336     ++currentIndex;
337   }
338   return slice;
339 }
340 
341 // Analyze the slice of operations rooted at each vector.contract op to figure
342 // out if the whole slice can be converted to MMA operations.
343 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
344                                              bool useNvGpu) {
345   auto hasVectorDest = [](Operation *op) {
346     return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
347   };
348   BackwardSliceOptions backwardSliceOptions;
349   backwardSliceOptions.filter = hasVectorDest;
350 
351   auto hasVectorSrc = [](Operation *op) {
352     return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
353   };
354   ForwardSliceOptions forwardSliceOptions;
355   forwardSliceOptions.filter = hasVectorSrc;
356 
357   SetVector<Operation *> opToConvert;
358   op->walk([&](vector::ContractionOp contract) {
359     if (opToConvert.contains(contract.getOperation()))
360       return;
361     SetVector<Operation *> dependentOps =
362         getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
363     // If any operation cannot use the MMA matrix type, drop the whole chain.
364     // MMA matrices are stored in an opaque type so they cannot be used by all
365     // operations.
366     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
367           if (!supportsMMaMatrixType(op, useNvGpu)) {
368             LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
369             return true;
370           }
371           return false;
372         }))
373       return;
374 
375     opToConvert.insert(dependentOps.begin(), dependentOps.end());
376   });
377   // Sort the operations so that we can convert them in topological order.
378   return topologicalSort(opToConvert);
379 }
380 
381 namespace {
382 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
383 // to MMA matmul.
384 struct PrepareContractToGPUMMA
385     : public OpRewritePattern<vector::ContractionOp> {
386   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
387 
388   LogicalResult matchAndRewrite(vector::ContractionOp op,
389                                 PatternRewriter &rewriter) const override {
390     Location loc = op.getLoc();
391     Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
392 
393     // Set up the parallel/reduction structure in right form.
394     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
395     auto infer = [&](MapList m) {
396       return AffineMap::inferFromExprList(m, op.getContext());
397     };
398     AffineExpr m, n, k;
399     bindDims(rewriter.getContext(), m, n, k);
400     static constexpr std::array<int64_t, 2> perm = {1, 0};
401     auto iteratorTypes = op.getIteratorTypes().getValue();
402     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
403     if (!(vector::isParallelIterator(iteratorTypes[0]) &&
404           vector::isParallelIterator(iteratorTypes[1]) &&
405           vector::isReductionIterator(iteratorTypes[2])))
406       return rewriter.notifyMatchFailure(op, "not a gemm contraction");
407     //
408     // Two outer parallel, one inner reduction (matmat flavor).
409     //
410     // This is the classical row-major matmul, nothing to do.
411     if (maps == infer({{m, k}, {k, n}, {m, n}}))
412       return rewriter.notifyMatchFailure(op, "contraction already prepared");
413     if (maps == infer({{m, k}, {n, k}, {m, n}})) {
414       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
415     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
416       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
417     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
418       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
419       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
420     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
421       std::swap(rhs, lhs);
422       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
423       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
424     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
425       std::swap(rhs, lhs);
426       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
427     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
428       std::swap(lhs, rhs);
429       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
430     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
431       std::swap(lhs, rhs);
432     } else {
433       // TODO: llvm_unreachable ?
434       return rewriter.notifyMatchFailure(op, "unexpected contraction case");
435     }
436     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
437         op, lhs, rhs, res,
438         rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
439         op.getIteratorTypes());
440     return success();
441   }
442 };
443 
444 // Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
445 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
446 // respectively. We can fold the transpose operation when loading the data from
447 // Shared Memory to registers.
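// Illustrative example (simplified IR, assuming an unmasked in-bounds read):
//   %r = vector.transfer_read %mem[%i, %j]
//          {permutation_map = (d0, d1) -> (d0, d1)} : ..., vector<16x8xf16>
//   %t = vector.transpose %r, [1, 0] : vector<16x8xf16> to vector<8x16xf16>
// is rewritten into a single vector.transfer_read producing vector<8x16xf16>
// with permutation_map (d0, d1) -> (d1, d0), looking through any intervening
// arith.ext* op (which is re-created on the new read's result).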
448 struct CombineTransferReadOpTranspose final
449     : public OpRewritePattern<vector::TransposeOp> {
450   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
451 
452   LogicalResult matchAndRewrite(vector::TransposeOp op,
453                                 PatternRewriter &rewriter) const override {
454     // Look through integer and floating point extend ops.
455     Value source = op.getVector();
456     Type resultType = op.getType();
457     Operation *extOp;
458     if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
459         (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
460         (extOp = source.getDefiningOp<arith::ExtFOp>())) {
461       source = extOp->getOperand(0);
462       resultType =
463           VectorType::get(cast<VectorType>(resultType).getShape(),
464                           cast<VectorType>(source.getType()).getElementType());
465     }
466 
467     auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
468     if (!transferReadOp)
469       return rewriter.notifyMatchFailure(op, "no transfer read");
470 
471     // TODO: support 0-d corner case.
472     if (transferReadOp.getTransferRank() == 0)
473       return rewriter.notifyMatchFailure(op, "0-D transfer read");
474 
475     if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
476       return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
477 
478     AffineMap permutationMap =
479         AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
480     AffineMap newMap =
481         permutationMap.compose(transferReadOp.getPermutationMap());
482 
483     auto loc = op.getLoc();
484     Value result =
485         rewriter
486             .create<vector::TransferReadOp>(
487                 loc, resultType, transferReadOp.getSource(),
488                 transferReadOp.getIndices(), AffineMapAttr::get(newMap),
489                 transferReadOp.getPadding(), transferReadOp.getMask(),
490                 transferReadOp.getInBoundsAttr())
491             .getResult();
492 
493     // Fuse through the extend op, if any.
494     if (extOp) {
495       if (isa<arith::ExtSIOp>(extOp))
496         result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
497                      .getResult();
498       else if (isa<arith::ExtUIOp>(extOp))
499         result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
500                      .getResult();
501       else
502         result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
503                      .getResult();
504     }
505 
506     rewriter.replaceOp(op, result);
507     return success();
508   }
509 };
510 
511 } // namespace
512 
513 // MMA types have different layouts based on how they are used in matmul ops.
514 // Figure out the right layout to use by looking at op uses.
515 // TODO: Change the GPU dialect to abstract the layout at this level and
516 // only care about it during lowering to NVVM.
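// For example, a value feeding the LHS of a vector.contract maps to "AOp", one
// feeding the RHS maps to "BOp", and anything else (e.g. an accumulator
// operand) maps to "COp".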
517 static const char *inferFragType(Operation *op) {
518   // We can have arith.ext ops before reaching contract ops. See through them
519   // and other kinds of elementwise ops.
520   if (op->hasOneUse()) {
521     Operation *userOp = *op->user_begin();
522     if (userOp->hasTrait<OpTrait::Elementwise>())
523       return inferFragType(userOp);
524   }
525 
526   for (Operation *user : op->getUsers()) {
527     auto contract = dyn_cast<vector::ContractionOp>(user);
528     if (!contract)
529       continue;
530     assert(op->getNumResults() == 1);
531     if (contract.getLhs() == op->getResult(0))
532       return "AOp";
533     if (contract.getRhs() == op->getResult(0))
534       return "BOp";
535   }
536   return "COp";
537 }
538 
539 static LogicalResult
540 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
541                       llvm::DenseMap<Value, Value> &valueMapping) {
542   OpBuilder::InsertionGuard g(rewriter);
543   rewriter.setInsertionPoint(op);
544 
545   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
546   assert(transferReadSupportsMMAMatrixType(op) &&
547          "expected convertible operation");
548 
549   std::optional<int64_t> stride =
550       getStaticallyKnownRowStride(op.getShapedType());
551   if (!stride.has_value()) {
552     LLVM_DEBUG(DBGS() << "no stride\n");
553     return rewriter.notifyMatchFailure(op, "no stride");
554   }
555 
556   AffineMap map = op.getPermutationMap();
557   bool isTranspose = isTransposeMatrixLoadMap(map);
558 
559   // Handle broadcast by setting the stride to 0.
560   if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
561     assert(cstExpr.getValue() == 0);
562     stride = 0;
563   }
564 
565   Value mappingResult = op.getResult();
566   auto elType = op.getVectorType().getElementType();
567   const char *fragType = inferFragType(op);
568   if (op->hasOneUse()) {
569     auto *user = *op->user_begin();
570     // Infer the signedness of the mma type from the integer extend.
571     if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
572       elType = IntegerType::get(
573           op.getContext(), cast<IntegerType>(elType).getWidth(),
574           isa<arith::ExtSIOp>(user) ? IntegerType::Signed
575                                     : IntegerType::Unsigned);
576       mappingResult = user->getResult(0);
577     }
578   }
579   gpu::MMAMatrixType type =
580       gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
581   Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
582       op.getLoc(), type, op.getSource(), op.getIndices(),
583       rewriter.getIndexAttr(*stride),
584       isTranspose ? rewriter.getUnitAttr() : UnitAttr());
585   valueMapping[mappingResult] = load;
586 
587   LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
588   return success();
589 }
590 
591 static LogicalResult
592 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
593                        llvm::DenseMap<Value, Value> &valueMapping) {
594   OpBuilder::InsertionGuard g(rewriter);
595   rewriter.setInsertionPoint(op);
596 
597   assert(transferWriteSupportsMMAMatrixType(op));
598   std::optional<int64_t> stride =
599       getStaticallyKnownRowStride(op.getShapedType());
600   if (!stride.has_value()) {
601     LLVM_DEBUG(DBGS() << "no stride\n");
602     return rewriter.notifyMatchFailure(op, "no stride");
603   }
604 
605   auto it = valueMapping.find(op.getVector());
606   if (it == valueMapping.end()) {
607     LLVM_DEBUG(DBGS() << "no mapping\n");
608     return rewriter.notifyMatchFailure(op, "no mapping");
609   }
610 
611   Value matrix = it->second;
612   auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
613       op.getLoc(), matrix, op.getSource(), op.getIndices(),
614       rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
615   (void)store;
616 
617   LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
618 
619   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
620   rewriter.eraseOp(op);
621   return success();
622 }
623 
624 /// Returns the vector type which represents a matrix fragment.
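/// For instance (assuming an f16 m16n8k16 mma.sync accumulator fragment, which
/// holds two 32-bit registers of two f16 elements per thread), this returns
/// vector<2x2xf16>.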
625 static VectorType
626 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
627   SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
628                              regInfo.elementsPerRegister};
629   Type elType = regInfo.registerLLVMType;
630   if (auto vecType = dyn_cast<VectorType>(elType))
631     elType = vecType.getElementType();
632   return VectorType::get(shape, elType);
633 }
634 
635 /// Convert a 2D splat ConstantOp to its per-thread mma.sync fragment constant.
636 static LogicalResult
637 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
638                          llvm::DenseMap<Value, Value> &valueMapping) {
639   OpBuilder::InsertionGuard g(rewriter);
640   rewriter.setInsertionPoint(op);
641 
642   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
643       nvgpu::getWarpMatrixInfo(op);
644   if (failed(warpMatrixInfo)) {
645     LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
646     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
647   }
648 
649   FailureOr<nvgpu::FragmentElementInfo> regInfo =
650       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
651   if (failed(regInfo)) {
652     LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
653     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
654   }
655 
656   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
657   auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
658   if (!dense) {
659     LLVM_DEBUG(DBGS() << "not a splat\n");
660     return rewriter.notifyMatchFailure(op, "not a splat");
661   }
662 
663   Value result = rewriter.create<arith::ConstantOp>(
664       op.getLoc(), vectorType,
665       DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
666   valueMapping[op.getResult()] = result;
667   return success();
668 }
669 
670 /// Check if the loaded matrix operand requires a transpose.
671 /// Transposed map examples:
672 /// Example 1   : (..., d0, d1) -> (d1 * 1, d0 * 2)
673 /// Example 2   : (d0, d1, d2, d3) -> (d3, d2)
674 /// The code below checks if the output 2D is transposed using a generalized
675 /// version     : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
676 /// Returns     : true if m > n, false otherwise.
677 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
678   mlir::AffineMap map = op.getPermutationMap();
679 
680   if (map.getNumResults() != 2) {
681     LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
682                          "is not a 2d operand\n");
683     return failure();
684   }
685 
686   // Output 2D matrix dimensions in the order of d0, d1.
687   mlir::AffineExpr dM = map.getResult(0);
688   mlir::AffineExpr dN = map.getResult(1);
689 
690   //  Find the position of these expressions in the input.
691   auto exprM = dyn_cast<AffineDimExpr>(dM);
692   auto exprN = dyn_cast<AffineDimExpr>(dN);
693 
694   if (!exprM || !exprN) {
695     LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
696                          "expressions, then transpose cannot be determined.\n");
697     return failure();
698   }
699 
700   return exprM.getPosition() > exprN.getPosition();
701 }
702 
703 static LogicalResult
704 createLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
705                              llvm::DenseMap<Value, Value> &valueMapping) {
706   OpBuilder::InsertionGuard g(rewriter);
707   rewriter.setInsertionPoint(op);
708   Location loc = op->getLoc();
709 
710   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
711       nvgpu::getWarpMatrixInfo(op);
712   if (failed(warpMatrixInfo)) {
713     LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
714     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
715   }
716 
717   FailureOr<nvgpu::FragmentElementInfo> regInfo =
718       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
719   if (failed(regInfo)) {
720     LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
721     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
722   }
723 
724   FailureOr<bool> transpose = isTransposed(op);
725   if (failed(transpose)) {
726     LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
727     return rewriter.notifyMatchFailure(
728         op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
729   }
730 
731   FailureOr<nvgpu::LdMatrixParams> params =
732       nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
733 
734   if (failed(params)) {
735     LLVM_DEBUG(
736         DBGS()
737         << "failed to convert vector.transfer_read to ldmatrix. "
738         << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
739     return rewriter.notifyMatchFailure(
740         op, "failed to convert vector.transfer_read to ldmatrix; this op "
741             "likely should not be converted to a nvgpu.ldmatrix call.");
742   }
743 
744   // Adjust the load offset.
745   auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
746   FailureOr<AffineMap> offsets =
747       nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
748   if (failed(offsets)) {
749     LLVM_DEBUG(DBGS() << "no offsets\n");
750     return rewriter.notifyMatchFailure(op, "no offsets");
751   }
752 
753   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
754 
755   SmallVector<Value, 4> indices;
756   getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
757                                          indices);
758 
759   nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
760       loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
761   valueMapping[op] = newOp->getResult(0);
762   return success();
763 }
764 
765 static LogicalResult
766 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
767                        llvm::DenseMap<Value, Value> &valueMapping) {
768   OpBuilder::InsertionGuard g(rewriter);
769   rewriter.setInsertionPoint(op);
770 
771   Location loc = op.getLoc();
772   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
773       nvgpu::getWarpMatrixInfo(op);
774   if (failed(warpMatrixInfo))
775     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
776   FailureOr<nvgpu::FragmentElementInfo> regInfo =
777       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
778   if (failed(regInfo)) {
779     return rewriter.notifyMatchFailure(
780         op, "Failed to deduce register fragment type during "
781             "conversion to distributed non-ldmatrix compatible load");
782   }
783 
784   Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
785   SmallVector<Value, 4> elements;
786 
787   // This is the individual element type.
788   Type loadedElType = regInfo->registerLLVMType;
789   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
790 
791   Value fill = rewriter.create<arith::ConstantOp>(
792       op.getLoc(), vectorType.getElementType(),
793       rewriter.getZeroAttr(vectorType.getElementType()));
794   Value result =
795       rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
796 
797   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
798 
799   // If we are not transposing, then we can use vectorized loads. Otherwise, we
800   // must load each element individually.
801   if (!isTransposeLoad) {
802     if (!isa<VectorType>(loadedElType)) {
803       loadedElType = VectorType::get({1}, loadedElType);
804     }
805 
806     for (int i = 0; i < vectorType.getShape()[0]; i++) {
807       FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
808           rewriter, op.getLoc(), *warpMatrixInfo);
809       if (failed(coords))
810         return rewriter.notifyMatchFailure(op, "no coords");
811 
812       Value logicalValueId = rewriter.create<arith::ConstantOp>(
813           loc, rewriter.getIndexType(),
814           rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
815       SmallVector<Value, 4> newIndices;
816       getXferIndices<vector::TransferReadOp>(
817           rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
818 
819       Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
820                                                  op.getSource(), newIndices);
821       result = rewriter.create<vector::InsertOp>(loc, el, result, i);
822     }
823   } else {
824     if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
825       loadedElType = vecType.getElementType();
826     }
827     for (int i = 0; i < vectorType.getShape()[0]; i++) {
828       for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
829            innerIdx++) {
830 
831         Value logicalValueId = rewriter.create<arith::ConstantOp>(
832             loc, rewriter.getIndexType(),
833             rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
834         FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
835             rewriter, op.getLoc(), *warpMatrixInfo);
836         if (failed(coords))
837           return rewriter.notifyMatchFailure(op, "no coords");
838 
839         SmallVector<Value, 4> newIndices;
840         getXferIndices<vector::TransferReadOp>(
841             rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
842         Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
843                                                    op.getSource(), newIndices);
844         result = rewriter.create<vector::InsertOp>(
845             op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
846       }
847     }
848   }
849 
850   valueMapping[op.getResult()] = result;
851   return success();
852 }
853 
854 /// Return true if this is a shared memory memref type.
855 static bool isSharedMemory(MemRefType type) {
856   auto addressSpace =
857       dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
858   return addressSpace &&
859          addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
860 }
861 
862 /// Converts a `vector.transfer_read` operation directly to either a
863 /// `vector.load` or an `nvgpu.ldmatrix` operation. This function should only be
864 /// used when converting to `nvgpu.mma.sync` operations.
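/// The `nvgpu.ldmatrix` path is only taken when the source memref lives in
/// workgroup (shared) memory and the inferred tile width is 128 bits;
/// otherwise the read is distributed into plain vector/memref loads.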
865 static LogicalResult
866 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
867                            llvm::DenseMap<Value, Value> &valueMapping) {
868   OpBuilder::InsertionGuard g(rewriter);
869   rewriter.setInsertionPoint(op);
870 
871   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
872       nvgpu::getWarpMatrixInfo(op);
873   if (failed(warpMatrixInfo))
874     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
875 
876   bool isLdMatrixCompatible =
877       isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
878       nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
879 
880   VectorType vecTy = op.getVectorType();
881   int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
882 
883   // When we are transposing the B operand, ldmatrix will only work for 16-bit
884   // element types when we have at least 8 rows to read and the width to read
885   // for the transpose is 128 bits.
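  // For example, transposing a vector<16x16xf16> operand satisfies these
  // constraints, while a vector<16x4xf16> operand does not (fewer than 8
  // elements along the inner dimension).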
886   if (!op.getPermutationMap().isMinorIdentity() &&
887       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
888        vecTy.getDimSize(0) * bitWidth < 128))
889     isLdMatrixCompatible = false;
890 
891   if (!isLdMatrixCompatible)
892     return createNonLdMatrixLoads(rewriter, op, valueMapping);
893 
894   return createLdMatrixCompatibleLoads(rewriter, op, valueMapping);
895 }
896 
897 static LogicalResult
898 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
899                              llvm::DenseMap<Value, Value> &valueMapping) {
900   OpBuilder::InsertionGuard g(rewriter);
901   rewriter.setInsertionPoint(op);
902 
903   Location loc = op->getLoc();
904   auto it = valueMapping.find(op.getVector());
905   if (it == valueMapping.end())
906     return rewriter.notifyMatchFailure(op, "no mapping");
907   Value matrix = it->second;
908 
909   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
910       nvgpu::getWarpMatrixInfo(op);
911   if (failed(warpMatrixInfo))
912     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
913   FailureOr<nvgpu::FragmentElementInfo> regInfo =
914       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
915   if (failed(regInfo))
916     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
917 
918   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
919   Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
920 
921   for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
922     Value logicalValueId = rewriter.create<arith::ConstantOp>(
923         loc, rewriter.getIndexType(),
924         rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
925     FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
926         rewriter, op.getLoc(), *warpMatrixInfo);
927     if (failed(coords))
928       return rewriter.notifyMatchFailure(op, "no coords");
929 
930     Value el =
931         rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
932     SmallVector<Value, 4> newIndices;
933     getXferIndices<vector::TransferWriteOp>(
934         rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
935     rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
936   }
937 
938   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
939   rewriter.eraseOp(op);
940   return success();
941 }
942 
943 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
944                                        SmallVectorImpl<int64_t> &results) {
945   for (auto attr : arrayAttr)
946     results.push_back(cast<IntegerAttr>(attr).getInt());
947 }
948 
949 static LogicalResult
950 convertExtractStridedSlice(RewriterBase &rewriter,
951                            vector::ExtractStridedSliceOp op,
952                            llvm::DenseMap<Value, Value> &valueMapping) {
953   OpBuilder::InsertionGuard g(rewriter);
954   rewriter.setInsertionPoint(op);
955 
956   Location loc = op->getLoc();
957 
958   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
959       nvgpu::getWarpMatrixInfo(op);
960   if (failed(warpMatrixInfo))
961     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
962 
963   FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
964       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
965   if (failed(mmaSyncFragmentInfo))
966     return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
967 
968   // Find the vector.transfer_read whose result vector is being sliced.
969   auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
970   if (!transferReadOp)
971     return rewriter.notifyMatchFailure(op, "no transfer read");
972 
973   warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
974   if (failed(warpMatrixInfo))
975     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
976 
977   FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
978       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
979   if (failed(ldFragmentInfo))
980     return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
981 
982   assert(
983       (mmaSyncFragmentInfo->elementsPerRegister ==
984        ldFragmentInfo->elementsPerRegister) &&
985       "Number of elements per register should be same for load and mma.sync");
986 
987   // Create vector.extract_strided_slice op for thread-owned fragments.
988   // The stride for extract_strided_slice is always 1.
989   std::array<int64_t, 2> strides = {1, 1};
990   std::array<int64_t, 2> sliceShape = {
991       mmaSyncFragmentInfo->numRegistersPerFragment,
992       mmaSyncFragmentInfo->elementsPerRegister};
993   auto it = valueMapping.find(transferReadOp);
994   if (it == valueMapping.end())
995     return rewriter.notifyMatchFailure(op, "no mapping");
996   auto sourceVector = it->second;
997 
998   // Offsets and sizes at the warp level of ownership.
999   SmallVector<int64_t> offsets;
1000   populateFromInt64AttrArray(op.getOffsets(), offsets);
1001 
1002   SmallVector<int64_t> sizes;
1003   populateFromInt64AttrArray(op.getSizes(), sizes);
1004   ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1005 
1006   // Compute offset in vector registers. Note that the mma.sync vector registers
1007   // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1008   // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1009   std::array<int64_t, 2> sliceOffset = {0, 0};
1010 
1011   if (offsets[0] && offsets[1])
1012     return op->emitError() << "Slicing fragments in 2D is not supported. ";
1013   if (offsets[0])
1014     sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1015   else if (offsets[1])
1016     sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1017 
1018   Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1019       loc, sourceVector, sliceOffset, sliceShape, strides);
1020 
1021   valueMapping[op] = newOp;
1022   return success();
1023 }
1024 
1025 static LogicalResult
1026 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1027                   llvm::DenseMap<Value, Value> &valueMapping) {
1028   OpBuilder::InsertionGuard g(rewriter);
1029   rewriter.setInsertionPoint(op);
1030 
1031   auto itA = valueMapping.find(op.getLhs());
1032   auto itB = valueMapping.find(op.getRhs());
1033   auto itC = valueMapping.find(op.getAcc());
1034   if (itA == valueMapping.end() || itB == valueMapping.end() ||
1035       itC == valueMapping.end())
1036     return rewriter.notifyMatchFailure(op, "no mapping");
1037   Value opA = itA->second, opB = itB->second, opC = itC->second;
1038   Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1039       op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1040       /*b_transpose=*/UnitAttr());
1041   valueMapping[op.getResult()] = matmul;
1042   return success();
1043 }
1044 
1045 static LogicalResult
1046 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1047                            llvm::DenseMap<Value, Value> &valueMapping) {
1048   OpBuilder::InsertionGuard g(rewriter);
1049   rewriter.setInsertionPoint(op);
1050 
1051   auto itA = valueMapping.find(op.getLhs());
1052   auto itB = valueMapping.find(op.getRhs());
1053   auto itC = valueMapping.find(op.getAcc());
1054   if (itA == valueMapping.end() || itB == valueMapping.end() ||
1055       itC == valueMapping.end())
1056     return rewriter.notifyMatchFailure(op, "no mapping");
1057   Value opA = itA->second, opB = itB->second, opC = itC->second;
1058   int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1059   int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1060   int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1061   Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1062       op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1063   valueMapping[op.getResult()] = matmul;
1064   return success();
1065 }
1066 
1067 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1068 static LogicalResult
1069 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1070                   llvm::DenseMap<Value, Value> &valueMapping) {
1071   OpBuilder::InsertionGuard g(rewriter);
1072   rewriter.setInsertionPoint(op);
1073 
1074   assert(constantSupportsMMAMatrixType(op));
1075 
1076   auto splat =
1077       cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1078   auto scalarConstant =
1079       rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1080   const char *fragType = inferFragType(op);
1081   auto vecType = cast<VectorType>(op.getType());
1082   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1083       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1084   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1085       op.getLoc(), type, scalarConstant);
1086   valueMapping[op.getResult()] = matrix;
1087   return success();
1088 }
1089 
1090 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1091 static LogicalResult
1092 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1093                    llvm::DenseMap<Value, Value> &valueMapping) {
1094   OpBuilder::InsertionGuard g(rewriter);
1095   rewriter.setInsertionPoint(op);
1096 
1097   assert(broadcastSupportsMMAMatrixType(op));
1098 
1099   const char *fragType = inferFragType(op);
1100   auto vecType = op.getResultVectorType();
1101   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1102       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1103   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1104       op.getLoc(), type, op.getSource());
1105   valueMapping[op.getResult()] = matrix;
1106   return success();
1107 }
1108 
1109 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1110 // updated and needs to be updated separately for the loop to be correct.
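// The results of the original loop are replaced by the leading results of the
// new loop; the extra results and iter_args are left for the caller to wire up.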
1111 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1112                                                scf::ForOp loop,
1113                                                ValueRange newInitArgs) {
1114   OpBuilder::InsertionGuard g(rewriter);
1115   rewriter.setInsertionPoint(loop);
1116 
1117   // Create a new loop before the existing one, with the extra operands.
1118   rewriter.setInsertionPoint(loop);
1119   auto operands = llvm::to_vector<4>(loop.getInitArgs());
1120   llvm::append_range(operands, newInitArgs);
1121   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1122       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
1123       operands);
1124   rewriter.eraseBlock(newLoop.getBody());
1125 
1126   newLoop.getRegion().getBlocks().splice(
1127       newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1128   for (Value operand : newInitArgs)
1129     newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1130 
1131   for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1132                                                   loop.getNumResults())))
1133     rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1134 
1135   LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1136   LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1137   LLVM_DEBUG(DBGS() << "erase: " << loop);
1138 
1139   rewriter.eraseOp(loop);
1140   return newLoop;
1141 }
1142 
1143 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1144                                   llvm::DenseMap<Value, Value> &valueMapping) {
1145   OpBuilder::InsertionGuard g(rewriter);
1146   rewriter.setInsertionPoint(op);
1147 
1148   SmallVector<Value> newOperands;
1149   SmallVector<std::pair<size_t, size_t>> argMapping;
1150   for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1151     auto it = valueMapping.find(operand.value());
1152     if (it == valueMapping.end()) {
1153       LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1154       continue;
1155     }
1156     argMapping.push_back(std::make_pair(
1157         operand.index(), op.getInitArgs().size() + newOperands.size()));
1158     newOperands.push_back(it->second);
1159   }
1160 
1161   scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1162   Block &loopBody = *newForOp.getBody();
1163   for (auto mapping : argMapping) {
1164     valueMapping[newForOp.getResult(mapping.first)] =
1165         newForOp.getResult(mapping.second);
1166     valueMapping[loopBody.getArgument(mapping.first +
1167                                       newForOp.getNumInductionVars())] =
1168         loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1169   }
1170 
1171   LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1172   return success();
1173 }
1174 
1175 static LogicalResult
1176 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1177                llvm::DenseMap<Value, Value> &valueMapping) {
1178   OpBuilder::InsertionGuard g(rewriter);
1179   rewriter.setInsertionPoint(op);
1180 
1181   auto loop = cast<scf::ForOp>(op->getParentOp());
1182   auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1183   for (const auto &operand : llvm::enumerate(op.getOperands())) {
1184     auto it = valueMapping.find(operand.value());
1185     if (it == valueMapping.end())
1186       continue;
1187     // Replace the yield of the old value with the for op argument to make it
1188     // easier to remove the dead code.
1189     yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1190     yieldOperands.push_back(it->second);
1191   }
1192   rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1193 
1194   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1195   rewriter.eraseOp(op);
1196   return success();
1197 }
1198 
1199 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1200 static LogicalResult
1201 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1202                      gpu::MMAElementwiseOp opType,
1203                      llvm::DenseMap<Value, Value> &valueMapping) {
1204   OpBuilder::InsertionGuard g(rewriter);
1205   rewriter.setInsertionPoint(op);
1206 
1207   SmallVector<Value> matrixOperands;
1208   for (Value operand : op->getOperands()) {
1209     auto it = valueMapping.find(operand);
1210     if (it == valueMapping.end())
1211       return rewriter.notifyMatchFailure(op, "no mapping");
1212     matrixOperands.push_back(it->second);
1213   }
1214   auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1215   if (opType == gpu::MMAElementwiseOp::EXTF) {
1216     // The floating point extension case has a different result type.
1217     auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1218     resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1219                                          vectorType.getElementType(),
1220                                          resultType.getOperand());
1221   }
1222 
1223   Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1224       op->getLoc(), resultType, matrixOperands, opType);
1225   valueMapping[op->getResult(0)] = newOp;
1226   return success();
1227 }
1228 
1229 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1230                                               bool useNvGpu) {
1231   if (!useNvGpu) {
1232     patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1233         patterns.getContext());
1234     return;
1235   }
1236   vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1237   patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1238 }
1239 
1240 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1241                                           Operation *rootOp) {
1242   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1243   llvm::DenseMap<Value, Value> valueMapping;
1244 
1245   auto globalRes = LogicalResult::success();
1246   for (Operation *op : ops) {
1247     LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1248     // Apparently callers do not want to early exit on failure here.
1249     auto res = LogicalResult::success();
1250     if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1251       res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1252     } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1253       res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1254     } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1255       res = convertContractOp(rewriter, contractOp, valueMapping);
1256     } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1257       res = convertConstantOp(rewriter, constantOp, valueMapping);
1258     } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1259       res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1260     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1261       res = convertForOp(rewriter, forOp, valueMapping);
1262     } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1263       res = convertYieldOp(rewriter, yieldOp, valueMapping);
1264     } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1265       res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1266     }
1267     if (failed(res))
1268       globalRes = failure();
1269   }
1270   return globalRes;
1271 }
1272 
1273 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1274                                                          Operation *rootOp) {
1275   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1276   llvm::DenseMap<Value, Value> valueMapping;
1277   for (Operation *op : ops) {
1278     if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1279             .Case([&](vector::TransferReadOp transferReadOp) {
1280               return convertTransferReadToLoads(rewriter, transferReadOp,
1281                                                 valueMapping);
1282             })
1283             .Case([&](vector::TransferWriteOp transferWriteOp) {
1284               return convertTransferWriteToStores(rewriter, transferWriteOp,
1285                                                   valueMapping);
1286             })
1287             .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1288               return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1289                                                 valueMapping);
1290             })
1291             .Case([&](vector::ContractionOp contractionOp) {
1292               return convertContractOpToMmaSync(rewriter, contractionOp,
1293                                                 valueMapping);
1294             })
1295             .Case([&](scf::ForOp forOp) {
1296               return convertForOp(rewriter, forOp, valueMapping);
1297             })
1298             .Case([&](scf::YieldOp yieldOp) {
1299               return convertYieldOp(rewriter, yieldOp, valueMapping);
1300             })
1301             .Case([&](arith::ConstantOp constOp) {
1302               return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1303             })
1304             .Default([&](Operation *op) {
1305               return op->emitError() << "unhandled vector to mma type: " << *op;
1306             })
1307             .failed()) {
1308       return op->emitOpError()
1309              << "failed to convert op during vector-to-nvgpu conversion";
1310     }
1311   }
1312   return success();
1313 }
1314 
1315 namespace {
1316 
1317 struct ConvertVectorToGPUPass
1318     : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1319 
1320   explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1321     useNvGpu.setValue(useNvGpu_);
1322   }
1323 
1324   void runOnOperation() override {
1325     RewritePatternSet patterns(&getContext());
1326     populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1327     if (failed(
1328             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
1329       return signalPassFailure();
1330 
1331     IRRewriter rewriter(&getContext());
1332     if (useNvGpu) {
1333       if (failed(
1334               convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1335         return signalPassFailure();
1336       return;
1337     }
1338     (void)convertVectorToMMAOps(rewriter, getOperation());
1339   }
1340 };
1341 
1342 } // namespace
1343 
1344 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1345   return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1346 }
1347