1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Arith/IR/Arith.h"
20 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
23 #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
26 #include "mlir/Dialect/Vector/IR/VectorOps.h"
27 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
28 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinOps.h"
31 #include "mlir/IR/Region.h"
32 #include "mlir/Pass/Pass.h"
33 #include "mlir/Support/LogicalResult.h"
34 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 
39 #define DEBUG_TYPE "vector-to-gpu"
40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
41 #define DBGSNL() (llvm::dbgs() << "\n")
42 
43 namespace mlir {
44 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
45 #include "mlir/Conversion/Passes.h.inc"
46 } // namespace mlir
47 
48 using namespace mlir;
49 
50 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
51 /// AffineMap representing offsets to apply to indices, the function fills
52 /// `indices` with the original indices plus the offsets. The offsets are
53 /// applied by taking into account the permutation map of the transfer op. If
54 /// the `offsetMap` has dimension placeholders, those should be provided in
55 /// `dimValues`.
56 template <typename TransferOpType>
57 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
58                            AffineMap offsetMap, ArrayRef<Value> dimValues,
59                            SmallVector<Value, 4> &indices) {
60   indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
61   Location loc = xferOp.getLoc();
62   unsigned offsetsIdx = 0;
63   for (auto expr : xferOp.getPermutationMap().getResults()) {
64     if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
65       Value prevIdx = indices[dim.getPosition()];
66       SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
67       dims.push_back(prevIdx);
68       AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
69       indices[dim.getPosition()] = affine::makeComposedAffineApply(
70           rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
71       continue;
72     }
73   }
74 }
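
// A minimal illustration of the index rewriting above (values and maps are
// hypothetical): for a transfer op with permutation map (d0, d1) -> (d1, d0),
// indices [%i, %j], an `offsetMap` (lane) -> (lane floordiv 4, lane mod 4) and
// `dimValues` = {%laneId}, the emitted affine.apply ops compute roughly:
//   indices[1] = %j + (%laneId floordiv 4)
//   indices[0] = %i + (%laneId mod 4)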
75 
76 // Return true if the contract op can be converted to an MMA matmul.
77 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
78                                           bool useNvGpu) {
79   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
81   AffineExpr m, n, k;
82   bindDims(contract.getContext(), m, n, k);
83   auto iteratorTypes = contract.getIteratorTypes().getValue();
84   if (!(vector::isParallelIterator(iteratorTypes[0]) &&
85         vector::isParallelIterator(iteratorTypes[1]) &&
86         vector::isReductionIterator(iteratorTypes[2])))
87     return false;
88 
89   // The contract needs to represent a matmul to be convertible to an
90   // MMAMatrix matmul.
91   if (!useNvGpu &&
92       contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
93     return false;
94   if (useNvGpu &&
95       contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
96     return false;
97 
98   return true;
99 }
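
// For reference, a contraction accepted by the non-nvgpu (wmma) path above
// looks roughly like this (shapes are only illustrative):
//
//   %d = vector.contract {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1)>,
//                           affine_map<(d0, d1, d2) -> (d0, d1)>],
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>}
//        %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
//
// The nvgpu (mma.sync) path instead expects the RHS map (d0, d1, d2) -> (d1, d2),
// i.e. B in transposed (col-major) form.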
100 
101 // Return true if the given map represents a transposed matrix load,
102 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
103 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
104   MLIRContext *ctx = permutationMap.getContext();
105   // A local OpBuilder is fine here; we only build affine expressions.
106   OpBuilder b(ctx);
107   auto nDim = permutationMap.getNumDims();
108   AffineExpr zero = b.getAffineConstantExpr(0);
109   if (nDim < 2) {
110     // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
111     AffineExpr dim0 = b.getAffineDimExpr(0);
112     return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
113   }
114 
115   AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
116   AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
117   // Support both transposed and transposed+broadcasted cases.
118   return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
119          permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
120 }
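
// Examples (assuming a 2-D underlying memref): affine_map<(d0, d1) -> (d1, d0)>
// is a transposed load and affine_map<(d0, d1) -> (d1, 0)> is a transposed +
// broadcasted load, whereas the minor identity affine_map<(d0, d1) -> (d0, d1)>
// is not considered a transpose by this helper.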
121 
122 // Return the stride for the second-to-last dimension of |type| if it is a
123 // memref and has a constant stride.
124 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
125   auto memrefType = dyn_cast<MemRefType>(type);
126   if (!memrefType)
127     return std::nullopt;
128   // If the memref is 0 or 1D the horizontal stride is 0.
129   if (memrefType.getRank() < 2)
130     return 0;
131   int64_t offset = 0;
132   SmallVector<int64_t, 2> strides;
133   if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
134       strides.back() != 1)
135     return std::nullopt;
136   int64_t stride = strides[strides.size() - 2];
137   if (stride == ShapedType::kDynamic)
138     return std::nullopt;
139   return stride;
140 }
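
// For instance (illustrative types only): `memref<16x16xf16, strided<[32, 1]>>`
// yields a row stride of 32, any 0-D/1-D memref yields 0, and
// `memref<16x16xf16, strided<[?, 1]>>` yields std::nullopt because the leading
// stride is dynamic.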
141 
142 // Return true if the transfer op can be converted to an MMA matrix load.
143 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
144   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
145       readOp.getVectorType().getRank() != 2)
146     return false;
147   if (!getStaticallyKnownRowStride(readOp.getShapedType()))
148     return false;
149 
150   // Only allow integer types if the signedness can be inferred.
151   if (readOp.getVectorType().getElementType().isInteger(8))
152     if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
153                                  !isa<arith::ExtUIOp>(*readOp->user_begin())))
154       return false;
155 
156   AffineMap map = readOp.getPermutationMap();
157   MLIRContext *ctx = readOp.getContext();
158   AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
159   AffineExpr zero = getAffineConstantExpr(0, ctx);
160   auto broadcastInnerDim =
161       AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
162   return map.isMinorIdentity() || map == broadcastInnerDim ||
163          isTransposeMatrixLoadMap(map);
164 }
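
// A transfer_read convertible by this (wmma) path looks roughly like the
// following (hypothetical shapes; in-bounds, unmasked, minor-identity map):
//
//   %v = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true, true]}
//        : memref<128x128xf16>, vector<16x16xf16>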
165 
166 // Return true if the transfer op can be converted to an MMA matrix store.
167 static bool
168 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
169   // TODO: support 0-d corner case.
170   if (writeOp.getTransferRank() == 0)
171     return false;
172 
173   if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
174       writeOp.getVectorType().getRank() != 2)
175     return false;
176   if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
177     return false;
178   // TODO: Support transpose once it is added to GPU dialect ops.
179   if (!writeOp.getPermutationMap().isMinorIdentity())
180     return false;
181   return true;
182 }
183 
184 /// Return true if the constant is a splat to a 2D vector so that it can be
185 /// converted to an MMA constant matrix op.
186 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
187   auto vecType = dyn_cast<VectorType>(constantOp.getType());
188   if (!vecType || vecType.getRank() != 2)
189     return false;
190   return isa<SplatElementsAttr>(constantOp.getValue());
191 }
192 
193 /// Return true if this is a broadcast from scalar to a 2D vector.
194 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
195   return broadcastOp.getResultVectorType().getRank() == 2;
196 }
197 
198 /// Return true if this integer extend op can be folded into a contract op.
199 template <typename ExtOpTy>
200 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
201   if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
202     return false;
203   return llvm::all_of(extOp->getUsers(), [](Operation *user) {
204     return isa<vector::ContractionOp>(user);
205   });
206 }
207 
208 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
209 
210 /// Return the MMA elementwise enum associated with `op` if it is supported.
211 /// Return `std::nullopt` otherwise.
212 static std::optional<gpu::MMAElementwiseOp>
213 convertElementwiseOpToMMA(Operation *op) {
214   if (isa<arith::AddFOp>(op))
215     return gpu::MMAElementwiseOp::ADDF;
216   if (isa<arith::MulFOp>(op))
217     return gpu::MMAElementwiseOp::MULF;
218   if (isa<arith::SubFOp>(op))
219     return gpu::MMAElementwiseOp::SUBF;
220   if (isa<arith::MaximumFOp>(op))
221     return gpu::MMAElementwiseOp::MAXF;
222   if (isa<arith::MinimumFOp>(op))
223     return gpu::MMAElementwiseOp::MINF;
224   if (isa<arith::DivFOp>(op))
225     return gpu::MMAElementwiseOp::DIVF;
226   if (isa<arith::AddIOp>(op))
227     return gpu::MMAElementwiseOp::ADDI;
228   if (isa<arith::MulIOp>(op))
229     return gpu::MMAElementwiseOp::MULI;
230   if (isa<arith::SubIOp>(op))
231     return gpu::MMAElementwiseOp::SUBI;
232   if (isa<arith::DivSIOp>(op))
233     return gpu::MMAElementwiseOp::DIVS;
234   if (isa<arith::DivUIOp>(op))
235     return gpu::MMAElementwiseOp::DIVU;
236   if (isa<arith::NegFOp>(op))
237     return gpu::MMAElementwiseOp::NEGATEF;
238   if (isa<arith::ExtFOp>(op))
239     return gpu::MMAElementwiseOp::EXTF;
240   return std::nullopt;
241 }
242 
243 /// Return true if the op is supported as elementwise op on MMAMatrix type.
244 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
245   return convertElementwiseOpToMMA(op).has_value();
246 }
247 
248 /// Returns true if the extract strided slice op is supported with `mma.sync`
249 /// path.
250 static bool
251 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
252 
253   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
254       nvgpu::getWarpMatrixInfo(op);
255   if (failed(warpMatrixInfo))
256     return false;
257 
258   FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
259   if (failed(contractOp))
260     return false;
261 
262   // Handle vector.extract_strided_slice on registers containing
263   // matrixB and matrixC operands. vector.extract_strided_slice op
264   // is not supported on registers containing matrixA operands.
265   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
266     return (cast<VectorType>(op->getResult(0).getType()) ==
267             cast<VectorType>((*contractOp).getRhs().getType()));
268   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
269     return (cast<VectorType>(op->getResult(0).getType()) ==
270             cast<VectorType>((*contractOp).getAcc().getType()));
271 
272   return false;
273 }
274 
275 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
276   if (isa<scf::ForOp, scf::YieldOp>(op))
277     return true;
278   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
279     return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
280                     : transferReadSupportsMMAMatrixType(transferRead);
281   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
282     return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
283                     : transferWriteSupportsMMAMatrixType(transferWrite);
284   if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
285     return useNvGpu &&
286            extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
287   if (auto contract = dyn_cast<vector::ContractionOp>(op))
288     return contractSupportsMMAMatrixType(contract, useNvGpu);
289   if (auto constant = dyn_cast<arith::ConstantOp>(op))
290     return constantSupportsMMAMatrixType(constant);
291   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
292     return broadcastSupportsMMAMatrixType(broadcast);
293   if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
294     return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
295   if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
296     return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
297   if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
298     return fpExtendSupportsMMAMatrixType(fpExtend);
299   return elementwiseSupportsMMAMatrixType(op);
300 }
301 
302 /// Return an unsorted slice, handling the scf.for region differently than
303 /// `getSlice`: inside scf.for we only want to include in the slice the
304 /// elements that are part of the use/def chain.
305 static SetVector<Operation *>
306 getSliceContract(Operation *op,
307                  const BackwardSliceOptions &backwardSliceOptions,
308                  const ForwardSliceOptions &forwardSliceOptions) {
309   SetVector<Operation *> slice;
310   slice.insert(op);
311   unsigned currentIndex = 0;
312   SetVector<Operation *> backwardSlice;
313   SetVector<Operation *> forwardSlice;
314   while (currentIndex != slice.size()) {
315     auto *currentOp = (slice)[currentIndex];
316     // Compute and insert the backwardSlice starting from currentOp.
317     backwardSlice.clear();
318     getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
319     slice.insert(backwardSlice.begin(), backwardSlice.end());
320 
321     // Compute and insert the forwardSlice starting from currentOp.
322     forwardSlice.clear();
323     // Special case for ForOp, we don't want to include the whole region but
324     // only the value using the region arguments.
325     // TODO: We should refine this to only care about the region arguments being
326     // converted to matrix type.
327     if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
328       for (Value forOpResult : forOp.getResults())
329         getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
330       for (BlockArgument &arg : forOp.getRegionIterArgs())
331         getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
332     } else {
333       getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
334     }
335     slice.insert(forwardSlice.begin(), forwardSlice.end());
336     ++currentIndex;
337   }
338   return slice;
339 }
340 
341 // Analyze the slice of operations rooted at each vector.contract op to figure
342 // out whether the whole slice can be converted to MMA operations.
343 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
344                                              bool useNvGpu) {
345   auto hasVectorDest = [](Operation *op) {
346     return llvm::any_of(op->getResultTypes(),
347                         [](Type t) { return isa<VectorType>(t); });
348   };
349   BackwardSliceOptions backwardSliceOptions;
350   backwardSliceOptions.filter = hasVectorDest;
351 
352   auto hasVectorSrc = [](Operation *op) {
353     return llvm::any_of(op->getOperandTypes(),
354                         [](Type t) { return isa<VectorType>(t); });
355   };
356   ForwardSliceOptions forwardSliceOptions;
357   forwardSliceOptions.filter = hasVectorSrc;
358 
359   SetVector<Operation *> opToConvert;
360   op->walk([&](vector::ContractionOp contract) {
361     if (opToConvert.contains(contract.getOperation()))
362       return;
363     SetVector<Operation *> dependentOps =
364         getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
365     // If any instruction cannot use the MMA matrix type, drop the whole
366     // chain. MMA matrices are stored in an opaque type, so they cannot be
367     // used by all operations.
368     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
369           if (!supportsMMaMatrixType(op, useNvGpu)) {
370             LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
371             return true;
372           }
373           return false;
374         }))
375       return;
376 
377     opToConvert.insert(dependentOps.begin(), dependentOps.end());
378   });
379   // Sort the operations so that we can convert them in topological order.
380   return topologicalSort(opToConvert);
381 }
382 
383 namespace {
384 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
385 // to MMA matmul.
386 struct PrepareContractToGPUMMA
387     : public OpRewritePattern<vector::ContractionOp> {
388   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
389 
390   LogicalResult matchAndRewrite(vector::ContractionOp op,
391                                 PatternRewriter &rewriter) const override {
392     Location loc = op.getLoc();
393     Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
394 
395     // Set up the parallel/reduction structure in right form.
396     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
397     auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
398     AffineExpr m, n, k;
399     bindDims(rewriter.getContext(), m, n, k);
400     static constexpr std::array<int64_t, 2> perm = {1, 0};
401     auto iteratorTypes = op.getIteratorTypes().getValue();
402     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
403     if (!(vector::isParallelIterator(iteratorTypes[0]) &&
404           vector::isParallelIterator(iteratorTypes[1]) &&
405           vector::isReductionIterator(iteratorTypes[2])))
406       return rewriter.notifyMatchFailure(op, "not a gemm contraction");
407     //
408     // Two outer parallel, one inner reduction (matmat flavor).
409     //
410     // This is the classical row-major matmul, nothing to do.
411     if (maps == infer({{m, k}, {k, n}, {m, n}}))
412       return rewriter.notifyMatchFailure(op, "contraction already prepared");
413     if (maps == infer({{m, k}, {n, k}, {m, n}})) {
414       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
415     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
416       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
417     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
418       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
419       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
420     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
421       std::swap(rhs, lhs);
422       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
423       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
424     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
425       std::swap(rhs, lhs);
426       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
427     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
428       std::swap(lhs, rhs);
429       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
430     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
431       std::swap(lhs, rhs);
432     } else {
433       // TODO: llvm_unreachable ?
434       return rewriter.notifyMatchFailure(op, "unexpected contraction case");
435     }
436     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
437         op, lhs, rhs, res,
438         rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
439         op.getIteratorTypes());
440     return success();
441   }
442 };
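
// As an illustration (hypothetical 16x16xf16 shapes), a contraction whose maps
// are {(d0, d1, d2) -> (d0, d2), (d0, d1, d2) -> (d1, d2), (d0, d1, d2) -> (d0, d1)},
// i.e. B given as (n, k), is rewritten by the pattern above into
//
//   %rhs_t = vector.transpose %rhs, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
//
// followed by the same vector.contract with the canonical
// {(m, k), (k, n), (m, n)} indexing maps.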
443 
444 // Fold transpose op into the transfer read op. Nvgpu mma.sync op only supports
445 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
446 // respectively. We can fold the transpose operation when loading the data from
447 // Shared Memory to registers.
448 struct CombineTransferReadOpTranspose final
449     : public OpRewritePattern<vector::TransposeOp> {
450   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
451 
452   LogicalResult matchAndRewrite(vector::TransposeOp op,
453                                 PatternRewriter &rewriter) const override {
454     // Look through integer extend ops.
455     Value source = op.getVector();
456     Type resultType = op.getType();
457     Operation *extOp;
458     if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
459         (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
460         (extOp = source.getDefiningOp<arith::ExtFOp>())) {
461       source = extOp->getOperand(0);
462       resultType =
463           VectorType::get(cast<VectorType>(resultType).getShape(),
464                           cast<VectorType>(source.getType()).getElementType());
465     }
466 
467     auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
468     if (!transferReadOp)
469       return rewriter.notifyMatchFailure(op, "no transfer read");
470 
471     // TODO: support 0-d corner case.
472     if (transferReadOp.getTransferRank() == 0)
473       return rewriter.notifyMatchFailure(op, "0-D transfer read");
474 
475     if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
476       return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
477 
478     AffineMap permutationMap =
479         AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
480     AffineMap newMap =
481         permutationMap.compose(transferReadOp.getPermutationMap());
482 
483     auto loc = op.getLoc();
484     Value result =
485         rewriter
486             .create<vector::TransferReadOp>(
487                 loc, resultType, transferReadOp.getSource(),
488                 transferReadOp.getIndices(), AffineMapAttr::get(newMap),
489                 transferReadOp.getPadding(), transferReadOp.getMask(),
490                 transferReadOp.getInBoundsAttr())
491             .getResult();
492 
493     // Fuse through the extend op (integer or floating point).
494     if (extOp) {
495       if (isa<arith::ExtSIOp>(extOp))
496         result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
497                      .getResult();
498       else if (isa<arith::ExtUIOp>(extOp))
499         result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
500                      .getResult();
501       else
502         result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
503                      .getResult();
504     }
505 
506     rewriter.replaceOp(op, result);
507     return success();
508   }
509 };
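
// Roughly, this pattern rewrites (shapes illustrative, in_bounds omitted):
//
//   %r = vector.transfer_read %mem[%i, %j], %pad : memref<?x?xf16>, vector<16x8xf16>
//   %t = vector.transpose %r, [1, 0] : vector<16x8xf16> to vector<8x16xf16>
//
// into a single read with the transpose folded into the permutation map:
//
//   %t = vector.transfer_read %mem[%i, %j], %pad
//          {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
//        : memref<?x?xf16>, vector<8x16xf16>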
510 
511 } // namespace
512 
513 // MMA types have different layouts based on how they are used in matmul ops.
514 // Figure out the right layout to use by looking at op uses.
515 // TODO: Change the GPU dialect to abstract the layout at this level and
516 // only care about it during lowering to NVVM.
517 static const char *inferFragType(Operation *op) {
518   for (Operation *user : op->getUsers()) {
519     auto contract = dyn_cast<vector::ContractionOp>(user);
520     if (!contract)
521       continue;
522     assert(op->getNumResults() == 1);
523     if (contract.getLhs() == op->getResult(0))
524       return "AOp";
525     if (contract.getRhs() == op->getResult(0))
526       return "BOp";
527   }
528   return "COp";
529 }
530 
531 static LogicalResult
532 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
533                       llvm::DenseMap<Value, Value> &valueMapping) {
534   OpBuilder::InsertionGuard g(rewriter);
535   rewriter.setInsertionPoint(op);
536 
537   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
538   assert(transferReadSupportsMMAMatrixType(op) &&
539          "expected convertible operation");
540 
541   std::optional<int64_t> stride =
542       getStaticallyKnownRowStride(op.getShapedType());
543   if (!stride.has_value()) {
544     LLVM_DEBUG(DBGS() << "no stride\n");
545     return rewriter.notifyMatchFailure(op, "no stride");
546   }
547 
548   AffineMap map = op.getPermutationMap();
549   bool isTranspose = isTransposeMatrixLoadMap(map);
550 
551   // Handle broadcast by setting the stride to 0.
552   if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
553     assert(cstExpr.getValue() == 0);
554     stride = 0;
555   }
556 
557   Value mappingResult = op.getResult();
558   auto elType = op.getVectorType().getElementType();
559   const char *fragType = inferFragType(op);
560   if (op->hasOneUse()) {
561     auto *user = *op->user_begin();
562     // Infer the signedness of the mma type from the integer extend.
563     bool isSignedExtend = isa<arith::ExtSIOp>(user);
564     if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
565       elType = IntegerType::get(
566           op.getContext(), cast<IntegerType>(elType).getWidth(),
567           isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
568       mappingResult = user->getResult(0);
569       fragType = inferFragType(user);
570     }
571   }
572   gpu::MMAMatrixType type =
573       gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
574   Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
575       op.getLoc(), type, op.getSource(), op.getIndices(),
576       rewriter.getIndexAttr(*stride),
577       isTranspose ? rewriter.getUnitAttr() : UnitAttr());
578   valueMapping[mappingResult] = load;
579 
580   LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
581   return success();
582 }
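
// The rewrite performed here looks roughly like (hypothetical shapes, assuming
// the read feeds the LHS of a contraction, hence the "AOp" fragment):
//
//   %v = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true, true]}
//        : memref<128x128xf16>, vector<16x16xf16>
// becomes
//   %m = gpu.subgroup_mma_load_matrix %mem[%i, %j] {leadDimension = 128 : index}
//        : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">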
583 
584 static LogicalResult
585 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
586                        llvm::DenseMap<Value, Value> &valueMapping) {
587   OpBuilder::InsertionGuard g(rewriter);
588   rewriter.setInsertionPoint(op);
589 
590   assert(transferWriteSupportsMMAMatrixType(op));
591   std::optional<int64_t> stride =
592       getStaticallyKnownRowStride(op.getShapedType());
593   if (!stride.has_value()) {
594     LLVM_DEBUG(DBGS() << "no stride\n");
595     return rewriter.notifyMatchFailure(op, "no stride");
596   }
597 
598   auto it = valueMapping.find(op.getVector());
599   if (it == valueMapping.end()) {
600     LLVM_DEBUG(DBGS() << "no mapping\n");
601     return rewriter.notifyMatchFailure(op, "no mapping");
602   }
603 
604   Value matrix = it->second;
605   auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
606       op.getLoc(), matrix, op.getSource(), op.getIndices(),
607       rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
608   (void)store;
609 
610   LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
611 
612   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
613   rewriter.eraseOp(op);
614   return success();
615 }
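
// Roughly (hypothetical shapes):
//
//   vector.transfer_write %val, %mem[%i, %j] {in_bounds = [true, true]}
//     : vector<16x16xf16>, memref<128x128xf16>
// becomes
//   gpu.subgroup_mma_store_matrix %matrix, %mem[%i, %j] {leadDimension = 128 : index}
//     : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>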
616 
617 /// Returns the vector type which represents a matrix fragment.
618 static VectorType
619 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
620   SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
621                              regInfo.elementsPerRegister};
622   Type elType = regInfo.registerLLVMType;
623   if (auto vecType = dyn_cast<VectorType>(elType))
624     elType = vecType.getElementType();
625   return VectorType::get(shape, elType);
626 }
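
// For example, for the A operand of an f16 16x8x16 mma.sync, where each lane
// owns 4 registers of 2 f16 elements each, this returns vector<4x2xf16>
// (fragment sizes here are illustrative).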
627 
628 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
629 static LogicalResult
630 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
631                          llvm::DenseMap<Value, Value> &valueMapping) {
632   OpBuilder::InsertionGuard g(rewriter);
633   rewriter.setInsertionPoint(op);
634 
635   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
636       nvgpu::getWarpMatrixInfo(op);
637   if (failed(warpMatrixInfo)) {
638     LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
639     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
640   }
641 
642   FailureOr<nvgpu::FragmentElementInfo> regInfo =
643       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
644   if (failed(regInfo)) {
645     LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
646     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
647   }
648 
649   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
650   auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
651   if (!dense) {
652     LLVM_DEBUG(DBGS() << "not a splat\n");
653     return rewriter.notifyMatchFailure(op, "not a splat");
654   }
655 
656   Value result = rewriter.create<arith::ConstantOp>(
657       op.getLoc(), vectorType,
658       DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
659   valueMapping[op.getResult()] = result;
660   return success();
661 }
662 
663 /// Check if the loaded matrix operand requires a transpose.
664 /// Transposed Map Example:
665 /// Example 1   : (..., d0, d1) -> (d1 * 1, d0 * 2)
666 /// Example 2   : (d0, d1, d2, d3) -> (d3, d2)
667 /// The code below checks if the output 2D is transposed using a generalized
668 /// version     : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
669 /// Returns     : true if m > n, false otherwise.
670 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
671   mlir::AffineMap map = op.getPermutationMap();
672 
673   if (map.getNumResults() != 2) {
674     LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
675                          "is not a 2d operand\n");
676     return failure();
677   }
678 
679   // Output 2D matrix dimensions in the order of d0, d1.
680   mlir::AffineExpr dM = map.getResult(0);
681   mlir::AffineExpr dN = map.getResult(1);
682 
683   //  Find the position of these expressions in the input.
684   auto exprM = dyn_cast<AffineDimExpr>(dM);
685   auto exprN = dyn_cast<AffineDimExpr>(dN);
686 
687   if (!exprM || !exprN) {
688     LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
689                          "expressions, then transpose cannot be determined.\n");
690     return failure();
691   }
692 
693   return exprM.getPosition() > exprN.getPosition();
694 }
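
// For example, affine_map<(d0, d1) -> (d1, d0)> is reported as transposed,
// the identity affine_map<(d0, d1) -> (d0, d1)> is not, and a map with a
// non-dim result such as (d0, d1) -> (d0, 0) returns failure().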
695 
696 static LogicalResult
697 createLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
698                               llvm::DenseMap<Value, Value> &valueMapping) {
699   OpBuilder::InsertionGuard g(rewriter);
700   rewriter.setInsertionPoint(op);
701   Location loc = op->getLoc();
702 
703   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
704       nvgpu::getWarpMatrixInfo(op);
705   if (failed(warpMatrixInfo)) {
706     LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
707     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
708   }
709 
710   FailureOr<nvgpu::FragmentElementInfo> regInfo =
711       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
712   if (failed(regInfo)) {
713     LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
714     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
715   }
716 
717   FailureOr<bool> transpose = isTransposed(op);
718   if (failed(transpose)) {
719     LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
720     return rewriter.notifyMatchFailure(
721         op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
722   }
723 
724   FailureOr<nvgpu::LdMatrixParams> params =
725       nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
726 
727   if (failed(params)) {
728     LLVM_DEBUG(
729         DBGS()
730         << "failed to convert vector.transfer_read to ldmatrix. "
731         << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
732     return rewriter.notifyMatchFailure(
733         op, "failed to convert vector.transfer_read to ldmatrix; this op "
734             "likely should not be converted to a nvgpu.ldmatrix call.");
735   }
736 
737   // Adjust the load offset.
738   auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
739   FailureOr<AffineMap> offsets =
740       nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
741   if (failed(offsets)) {
742     LLVM_DEBUG(DBGS() << "no offsets\n");
743     return rewriter.notifyMatchFailure(op, "no offsets");
744   }
745 
746   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
747 
748   SmallVector<Value, 4> indices;
749   getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
750                                          indices);
751 
752   nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
753       loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
754   valueMapping[op] = newOp->getResult(0);
755   return success();
756 }
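
// The load emitted above looks roughly like (hypothetical operands and shapes):
//
//   %lane = gpu.lane_id
//   %frag = nvgpu.ldmatrix %shmem[%row, %col] {numTiles = 4 : i32, transpose = false}
//           : memref<128x32xf16, #gpu.address_space<workgroup>> -> vector<4x2xf16>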
757 
758 static LogicalResult
759 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
760                        llvm::DenseMap<Value, Value> &valueMapping) {
761   OpBuilder::InsertionGuard g(rewriter);
762   rewriter.setInsertionPoint(op);
763 
764   Location loc = op.getLoc();
765   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
766       nvgpu::getWarpMatrixInfo(op);
767   if (failed(warpMatrixInfo))
768     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
769   FailureOr<nvgpu::FragmentElementInfo> regInfo =
770       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
771   if (failed(regInfo)) {
772     return rewriter.notifyMatchFailure(
773         op, "Failed to deduce register fragment type during "
774             "conversion to distributed non-ldmatrix compatible load");
775   }
776 
777   Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
778   SmallVector<Value, 4> elements;
779 
780   // This is the individual element type.
781   Type loadedElType = regInfo->registerLLVMType;
782   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
783 
784   Value fill = rewriter.create<arith::ConstantOp>(
785       op.getLoc(), vectorType.getElementType(),
786       rewriter.getZeroAttr(vectorType.getElementType()));
787   Value result =
788       rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
789 
790   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
791 
792   // If we are not transposing, then we can use vectorized loads. Otherwise, we
793   // must load each element individually.
794   if (!isTransposeLoad) {
795     if (!isa<VectorType>(loadedElType)) {
796       loadedElType = VectorType::get({1}, loadedElType);
797     }
798 
799     for (int i = 0; i < vectorType.getShape()[0]; i++) {
800       FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
801           rewriter, op.getLoc(), *warpMatrixInfo);
802       if (failed(coords))
803         return rewriter.notifyMatchFailure(op, "no coords");
804 
805       Value logicalValueId = rewriter.create<arith::ConstantOp>(
806           loc, rewriter.getIndexType(),
807           rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
808       SmallVector<Value, 4> newIndices;
809       getXferIndices<vector::TransferReadOp>(
810           rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
811 
812       Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
813                                                  op.getSource(), newIndices);
814       result = rewriter.create<vector::InsertOp>(loc, el, result, i);
815     }
816   } else {
817     if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
818       loadedElType = vecType.getElementType();
819     }
820     for (int i = 0; i < vectorType.getShape()[0]; i++) {
821       for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
822            innerIdx++) {
823 
824         Value logicalValueId = rewriter.create<arith::ConstantOp>(
825             loc, rewriter.getIndexType(),
826             rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
827         FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
828             rewriter, op.getLoc(), *warpMatrixInfo);
829         if (failed(coords))
830           return rewriter.notifyMatchFailure(op, "no coords");
831 
832         SmallVector<Value, 4> newIndices;
833         getXferIndices<vector::TransferReadOp>(
834             rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
835         Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
836                                                    op.getSource(), newIndices);
837         result = rewriter.create<vector::InsertOp>(
838             op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
839       }
840     }
841   }
842 
843   valueMapping[op.getResult()] = result;
844   return success();
845 }
846 
847 /// Return true if this is a shared memory memref type.
848 static bool isSharedMemory(MemRefType type) {
849   auto addressSpace =
850       dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
851   return addressSpace &&
852          addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
853 }
854 
855 /// Converts a `vector.transfer_read` operation directly to either a
856 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
857 /// used when converting to `nvgpu.mma.sync` operations.
858 static LogicalResult
859 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
860                            llvm::DenseMap<Value, Value> &valueMapping) {
861   OpBuilder::InsertionGuard g(rewriter);
862   rewriter.setInsertionPoint(op);
863 
864   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
865       nvgpu::getWarpMatrixInfo(op);
866   if (failed(warpMatrixInfo))
867     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
868 
869   bool isLdMatrixCompatible =
870       isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
871       nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
872 
873   VectorType vecTy = op.getVectorType();
874   int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
875 
876   // When we are transposing the B operand, ldmatrix will only work if we have
877   // at least 8 rows to read and the width to read for the transpose is 128
878   // bits.
879   if (!op.getPermutationMap().isMinorIdentity() &&
880       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
881        vecTy.getDimSize(0) * bitWidth < 128))
882     isLdMatrixCompatible = false;
883 
884   if (!isLdMatrixCompatible)
885     return createNonLdMatrixLoads(rewriter, op, valueMapping);
886 
887   return createLdMatrixCompatibleLoads(rewriter, op, valueMapping);
888 }
889 
890 static LogicalResult
891 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
892                              llvm::DenseMap<Value, Value> &valueMapping) {
893   OpBuilder::InsertionGuard g(rewriter);
894   rewriter.setInsertionPoint(op);
895 
896   Location loc = op->getLoc();
897   auto it = valueMapping.find(op.getVector());
898   if (it == valueMapping.end())
899     return rewriter.notifyMatchFailure(op, "no mapping");
900   Value matrix = it->second;
901 
902   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
903       nvgpu::getWarpMatrixInfo(op);
904   if (failed(warpMatrixInfo))
905     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
906   FailureOr<nvgpu::FragmentElementInfo> regInfo =
907       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
908   if (failed(regInfo))
909     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
910 
911   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
912   Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
913 
914   for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
915     Value logicalValueId = rewriter.create<arith::ConstantOp>(
916         loc, rewriter.getIndexType(),
917         rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
918     FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
919         rewriter, op.getLoc(), *warpMatrixInfo);
920     if (failed(coords))
921       return rewriter.notifyMatchFailure(op, "no coords");
922 
923     Value el =
924         rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
925     SmallVector<Value, 4> newIndices;
926     getXferIndices<vector::TransferWriteOp>(
927         rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
928     rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
929   }
930 
931   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
932   rewriter.eraseOp(op);
933   return success();
934 }
935 
936 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
937                                        SmallVectorImpl<int64_t> &results) {
938   for (auto attr : arrayAttr)
939     results.push_back(cast<IntegerAttr>(attr).getInt());
940 }
941 
942 static LogicalResult
943 convertExtractStridedSlice(RewriterBase &rewriter,
944                            vector::ExtractStridedSliceOp op,
945                            llvm::DenseMap<Value, Value> &valueMapping) {
946   OpBuilder::InsertionGuard g(rewriter);
947   rewriter.setInsertionPoint(op);
948 
949   Location loc = op->getLoc();
950 
951   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
952       nvgpu::getWarpMatrixInfo(op);
953   if (failed(warpMatrixInfo))
954     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
955 
956   FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
957       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
958   if (failed(mmaSyncFragmentInfo))
959     return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
960 
961   // Find the vector.transfer_read whose result vector is being sliced.
962   auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
963   if (!transferReadOp)
964     return rewriter.notifyMatchFailure(op, "no transfer read");
965 
966   warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
967   if (failed(warpMatrixInfo))
968     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
969 
970   FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
971       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
972   if (failed(ldFragmentInfo))
973     return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
974 
975   assert(
976       (mmaSyncFragmentInfo->elementsPerRegister ==
977        ldFragmentInfo->elementsPerRegister) &&
978       "Number of elements per register should be same for load and mma.sync");
979 
980   // Create vector.extract_strided_slice op for thread-owned fragments.
981   // The stride for vector.extract_strided_slice is always 1.
982   std::array<int64_t, 2> strides = {1, 1};
983   std::array<int64_t, 2> sliceShape = {
984       mmaSyncFragmentInfo->numRegistersPerFragment,
985       mmaSyncFragmentInfo->elementsPerRegister};
986   auto it = valueMapping.find(transferReadOp);
987   if (it == valueMapping.end())
988     return rewriter.notifyMatchFailure(op, "no mapping");
989   auto sourceVector = it->second;
990 
991   // Offsets and sizes at the warp level of ownership.
992   SmallVector<int64_t> offsets;
993   populateFromInt64AttrArray(op.getOffsets(), offsets);
994 
995   SmallVector<int64_t> sizes;
996   populateFromInt64AttrArray(op.getSizes(), sizes);
997   ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
998 
999   // Compute offset in vector registers. Note that the mma.sync vector registers
1000   // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1001   // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1002   std::array<int64_t, 2> sliceOffset = {0, 0};
1003 
1004   if (offsets[0] && offsets[1])
1005     return op->emitError() << "Slicing fragments in 2D is not supported. ";
1006   if (offsets[0])
1007     sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1008   else if (offsets[1])
1009     sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1010 
1011   Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1012       loc, sourceVector, sliceOffset, sliceShape, strides);
1013 
1014   valueMapping[op] = newOp;
1015   return success();
1016 }
1017 
1018 static LogicalResult
1019 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1020                   llvm::DenseMap<Value, Value> &valueMapping) {
1021   OpBuilder::InsertionGuard g(rewriter);
1022   rewriter.setInsertionPoint(op);
1023 
1024   auto itA = valueMapping.find(op.getLhs());
1025   auto itB = valueMapping.find(op.getRhs());
1026   auto itC = valueMapping.find(op.getAcc());
1027   if (itA == valueMapping.end() || itB == valueMapping.end() ||
1028       itC == valueMapping.end())
1029     return rewriter.notifyMatchFailure(op, "no mapping");
1030   Value opA = itA->second, opB = itB->second, opC = itC->second;
1031   Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1032       op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1033       /*b_transpose=*/UnitAttr());
1034   valueMapping[op.getResult()] = matmul;
1035   return success();
1036 }
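
// The generated wmma op looks roughly like (hypothetical 16x16x16 shapes):
//
//   %d = gpu.subgroup_mma_compute %a, %b, %c
//        : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
//          -> !gpu.mma_matrix<16x16xf16, "COp">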
1037 
1038 static LogicalResult
1039 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1040                            llvm::DenseMap<Value, Value> &valueMapping) {
1041   OpBuilder::InsertionGuard g(rewriter);
1042   rewriter.setInsertionPoint(op);
1043 
1044   auto itA = valueMapping.find(op.getLhs());
1045   auto itB = valueMapping.find(op.getRhs());
1046   auto itC = valueMapping.find(op.getAcc());
1047   if (itA == valueMapping.end() || itB == valueMapping.end() ||
1048       itC == valueMapping.end())
1049     return rewriter.notifyMatchFailure(op, "no mapping");
1050   Value opA = itA->second, opB = itB->second, opC = itC->second;
1051   int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1052   int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1053   int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1054   Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1055       op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1056   valueMapping[op.getResult()] = matmul;
1057   return success();
1058 }
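
// The generated op looks roughly like (fragment shapes illustrative of an
// m16n8k16 f16 mma.sync):
//
//   %d = nvgpu.mma.sync(%a, %b, %c) {mmaShape = [16, 8, 16]}
//        : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>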
1059 
1060 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1061 static LogicalResult
1062 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1063                   llvm::DenseMap<Value, Value> &valueMapping) {
1064   OpBuilder::InsertionGuard g(rewriter);
1065   rewriter.setInsertionPoint(op);
1066 
1067   assert(constantSupportsMMAMatrixType(op));
1068 
1069   auto splat =
1070       cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1071   auto scalarConstant =
1072       rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1073   const char *fragType = inferFragType(op);
1074   auto vecType = cast<VectorType>(op.getType());
1075   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1076       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1077   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1078       op.getLoc(), type, scalarConstant);
1079   valueMapping[op.getResult()] = matrix;
1080   return success();
1081 }
1082 
1083 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1084 static LogicalResult
1085 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1086                    llvm::DenseMap<Value, Value> &valueMapping) {
1087   OpBuilder::InsertionGuard g(rewriter);
1088   rewriter.setInsertionPoint(op);
1089 
1090   assert(broadcastSupportsMMAMatrixType(op));
1091 
1092   const char *fragType = inferFragType(op);
1093   auto vecType = op.getResultVectorType();
1094   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1095       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1096   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1097       op.getLoc(), type, op.getSource());
1098   valueMapping[op.getResult()] = matrix;
1099   return success();
1100 }
1101 
1102 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1103 // updated and needs to be updated separately for the loop to be correct.
1104 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1105                                                scf::ForOp loop,
1106                                                ValueRange newInitArgs) {
1107   OpBuilder::InsertionGuard g(rewriter);
1108   rewriter.setInsertionPoint(loop);
1109 
1110   // Create a new loop before the existing one, with the extra operands.
1111   rewriter.setInsertionPoint(loop);
1112   auto operands = llvm::to_vector<4>(loop.getInitArgs());
1113   llvm::append_range(operands, newInitArgs);
1114   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1115       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
1116       operands);
1117   newLoop.getBody()->erase();
1118 
1119   newLoop.getRegion().getBlocks().splice(
1120       newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1121   for (Value operand : newInitArgs)
1122     newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1123 
1124   for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1125                                                   loop.getNumResults())))
1126     rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1127 
1128   LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1129   LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1130   LLVM_DEBUG(DBGS() << "erase: " << loop);
1131 
1132   rewriter.eraseOp(loop);
1133   return newLoop;
1134 }
1135 
1136 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1137                                   llvm::DenseMap<Value, Value> &valueMapping) {
1138   OpBuilder::InsertionGuard g(rewriter);
1139   rewriter.setInsertionPoint(op);
1140 
1141   SmallVector<Value> newOperands;
1142   SmallVector<std::pair<size_t, size_t>> argMapping;
1143   for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1144     auto it = valueMapping.find(operand.value());
1145     if (it == valueMapping.end()) {
1146       LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1147       continue;
1148     }
1149     argMapping.push_back(std::make_pair(
1150         operand.index(), op.getInitArgs().size() + newOperands.size()));
1151     newOperands.push_back(it->second);
1152   }
1153 
1154   scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1155   Block &loopBody = *newForOp.getBody();
1156   for (auto mapping : argMapping) {
1157     valueMapping[newForOp.getResult(mapping.first)] =
1158         newForOp.getResult(mapping.second);
1159     valueMapping[loopBody.getArgument(mapping.first +
1160                                       newForOp.getNumInductionVars())] =
1161         loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1162   }
1163 
1164   LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1165   return success();
1166 }
1167 
1168 static LogicalResult
1169 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1170                llvm::DenseMap<Value, Value> &valueMapping) {
1171   OpBuilder::InsertionGuard g(rewriter);
1172   rewriter.setInsertionPoint(op);
1173 
1174   auto loop = cast<scf::ForOp>(op->getParentOp());
1175   auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1176   for (const auto &operand : llvm::enumerate(op.getOperands())) {
1177     auto it = valueMapping.find(operand.value());
1178     if (it == valueMapping.end())
1179       continue;
1180     // Replace the yield of the old value with the for op argument to make
1181     // it easier to remove the dead code.
1182     yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1183     yieldOperands.push_back(it->second);
1184   }
1185   rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1186 
1187   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1188   rewriter.eraseOp(op);
1189   return success();
1190 }
1191 
1192 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1193 static LogicalResult
1194 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1195                      gpu::MMAElementwiseOp opType,
1196                      llvm::DenseMap<Value, Value> &valueMapping) {
1197   OpBuilder::InsertionGuard g(rewriter);
1198   rewriter.setInsertionPoint(op);
1199 
1200   SmallVector<Value> matrixOperands;
1201   for (Value operand : op->getOperands()) {
1202     auto it = valueMapping.find(operand);
1203     if (it == valueMapping.end())
1204       return rewriter.notifyMatchFailure(op, "no mapping");
1205     matrixOperands.push_back(it->second);
1206   }
1207   auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1208   if (opType == gpu::MMAElementwiseOp::EXTF) {
1209     // The floating point extension case has a different result type.
1210     auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1211     resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1212                                          vectorType.getElementType(),
1213                                          resultType.getOperand());
1214   }
1215 
1216   Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1217       op->getLoc(), resultType, matrixOperands, opType);
1218   valueMapping[op->getResult(0)] = newOp;
1219   return success();
1220 }
1221 
1222 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1223                                               bool useNvGpu) {
1224   if (!useNvGpu) {
1225     patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1226         patterns.getContext());
1227     return;
1228   }
1229   vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1230   patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1231 }
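
// Typical standalone use of these patterns, a sketch of what the pass below
// already does:
//
//   RewritePatternSet patterns(ctx);
//   populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     /* handle failure */;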
1232 
1233 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1234                                           Operation *rootOp) {
1235   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1236   llvm::DenseMap<Value, Value> valueMapping;
1237 
1238   auto globalRes = LogicalResult::success();
1239   for (Operation *op : ops) {
1240     LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1241     // Apparently callers do not want to early exit on failure here.
1242     auto res = LogicalResult::success();
1243     if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1244       res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1245     } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1246       res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1247     } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1248       res = convertContractOp(rewriter, contractOp, valueMapping);
1249     } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1250       res = convertConstantOp(rewriter, constantOp, valueMapping);
1251     } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1252       res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1253     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1254       res = convertForOp(rewriter, forOp, valueMapping);
1255     } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1256       res = convertYieldOp(rewriter, yieldOp, valueMapping);
1257     } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1258       res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1259     }
1260     if (failed(res))
1261       globalRes = failure();
1262   }
1263   return globalRes;
1264 }
1265 
1266 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1267                                                          Operation *rootOp) {
1268   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1269   llvm::DenseMap<Value, Value> valueMapping;
1270   for (Operation *op : ops) {
1271     if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1272             .Case([&](vector::TransferReadOp transferReadOp) {
1273               return convertTransferReadToLoads(rewriter, transferReadOp,
1274                                                 valueMapping);
1275             })
1276             .Case([&](vector::TransferWriteOp transferWriteOp) {
1277               return convertTransferWriteToStores(rewriter, transferWriteOp,
1278                                                   valueMapping);
1279             })
1280             .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1281               return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1282                                                 valueMapping);
1283             })
1284             .Case([&](vector::ContractionOp contractionOp) {
1285               return convertContractOpToMmaSync(rewriter, contractionOp,
1286                                                 valueMapping);
1287             })
1288             .Case([&](scf::ForOp forOp) {
1289               return convertForOp(rewriter, forOp, valueMapping);
1290             })
1291             .Case([&](scf::YieldOp yieldOp) {
1292               return convertYieldOp(rewriter, yieldOp, valueMapping);
1293             })
1294             .Case([&](arith::ConstantOp constOp) {
1295               return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1296             })
1297             .Default([&](Operation *op) {
1298               return op->emitError() << "unhandled vector to mma type: " << *op;
1299             })
1300             .failed()) {
1301       return op->emitOpError()
1302              << "failed to convert op during vector-to-nvgpu conversion";
1303     }
1304   }
1305   return success();
1306 }
1307 
1308 namespace {
1309 
1310 struct ConvertVectorToGPUPass
1311     : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1312 
1313   explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1314     useNvGpu.setValue(useNvGpu_);
1315   }
1316 
1317   void runOnOperation() override {
1318     RewritePatternSet patterns(&getContext());
1319     populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1320     if (failed(
1321             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
1322       return signalPassFailure();
1323 
1324     IRRewriter rewriter(&getContext());
1325     if (useNvGpu) {
1326       if (failed(
1327               convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1328         return signalPassFailure();
1329       return;
1330     }
1331     (void)convertVectorToMMAOps(rewriter, getOperation());
1332   }
1333 };
1334 
1335 } // namespace
1336 
1337 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1338   return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1339 }
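
// Example usage (a sketch): add the pass to a pipeline with
//   pm.addNestedPass<func::FuncOp>(createConvertVectorToGPUPass(/*useNvGpu=*/true));
// or run it from mlir-opt via `-convert-vector-to-gpu` (optionally with
// `use-nvgpu=true`).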
1340