xref: /llvm-project/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (revision 847d8457d16a7334ba39bdd35c70faa1b295304d)
1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Arith/IR/Arith.h"
20 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
23 #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
26 #include "mlir/Dialect/Vector/IR/VectorOps.h"
27 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
28 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinOps.h"
31 #include "mlir/IR/Region.h"
32 #include "mlir/Pass/Pass.h"
33 #include "mlir/Support/LogicalResult.h"
34 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 
39 #define DEBUG_TYPE "vector-to-gpu"
40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
41 #define DBGSNL() (llvm::dbgs() << "\n")
42 
43 namespace mlir {
44 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
45 #include "mlir/Conversion/Passes.h.inc"
46 } // namespace mlir
47 
48 using namespace mlir;
49 
50 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
51 /// AffineMap representing offsets to apply to indices, the function fills
52 /// `indices` with the original indices plus the offsets. The offsets are
53 /// applied by taking into account the permutation map of the transfer op. If
54 /// the `offsetMap` has dimension placeholders, those should be provided in
55 /// `dimValues`.
56 template <typename TransferOpType>
57 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
58                            AffineMap offsetMap, ArrayRef<Value> dimValues,
59                            SmallVector<Value, 4> &indices) {
60   indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
61   Location loc = xferOp.getLoc();
62   unsigned offsetsIdx = 0;
63   for (auto expr : xferOp.getPermutationMap().getResults()) {
64     if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
65       Value prevIdx = indices[dim.getPosition()];
66       SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
67       dims.push_back(prevIdx);
68       AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
69       indices[dim.getPosition()] = affine::makeComposedAffineApply(
70           rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
71       continue;
72     }
73   }
74 }
75 
76 // Return true if the contract op can be converted to an MMA matmul.
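// As a hedged illustration (hypothetical IR; the exact printed form may
// differ), the wmma path accepts a row-major matmul-shaped contraction such
// as:
//   %d = vector.contract {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1)>,
//                           affine_map<(d0, d1, d2) -> (d0, d1)>],
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>}
//        %a, %b, %c
//     : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
// whereas the mma.sync path expects the B operand already transposed, i.e.
// indexed as (d1, d2).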
77 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
78                                           bool useNvGpu) {
79   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
81   AffineExpr m, n, k;
82   bindDims(contract.getContext(), m, n, k);
83   auto iteratorTypes = contract.getIteratorTypes().getValue();
84   if (!(vector::isParallelIterator(iteratorTypes[0]) &&
85         vector::isParallelIterator(iteratorTypes[1]) &&
86         vector::isReductionIterator(iteratorTypes[2])))
87     return false;
88 
89   // The contraction needs to represent a matmul to be convertible to an
90   // MMAMatrix matmul.
91   if (!useNvGpu &&
92       contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
93     return false;
94   if (useNvGpu &&
95       contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
96     return false;
97 
98   return true;
99 }
100 
101 // Return true if the given map represents a transposed matrix load,
102 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
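// Illustrative examples of maps accepted here (a sketch; additional leading
// dims are allowed):
//   affine_map<(d0, d1) -> (d1, d0)>   // plain transpose
//   affine_map<(d0, d1) -> (d1, 0)>    // transpose + broadcast
//   affine_map<(d0) -> (d0, 0)>        // rank-1 transposed + broadcast case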
103 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
104   MLIRContext *ctx = permutationMap.getContext();
105   // Local OpBuilder is fine here, we just build attributes.
106   OpBuilder b(ctx);
107   auto nDim = permutationMap.getNumDims();
108   AffineExpr zero = b.getAffineConstantExpr(0);
109   if (nDim < 2) {
110     // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
111     AffineExpr dim0 = b.getAffineDimExpr(0);
112     return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
113   }
114 
115   AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
116   AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
117   // Support both transposed and transposed+broadcasted cases.
118   return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
119          permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
120 }
121 
122 // Return the stride for the second-to-last dimension of |type| if it is a memref
123 // and has a constant stride.
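// For illustration (assumed layouts, not exhaustive): a row-major
//   memref<16x32xf16>
// has a row stride of 32, a strided layout such as
//   memref<16x32xf16, strided<[64, 1], offset: ?>>
// has a row stride of 64, and a dynamic row stride such as
//   memref<16x32xf16, strided<[?, 1]>>
// yields std::nullopt.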
124 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
125   auto memrefType = dyn_cast<MemRefType>(type);
126   if (!memrefType)
127     return std::nullopt;
128   // If the memref is 0 or 1D the horizontal stride is 0.
129   if (memrefType.getRank() < 2)
130     return 0;
131   int64_t offset = 0;
132   SmallVector<int64_t, 2> strides;
133   if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
134       strides.back() != 1)
135     return std::nullopt;
136   int64_t stride = strides[strides.size() - 2];
137   if (stride == ShapedType::kDynamic)
138     return std::nullopt;
139   return stride;
140 }
141 
142 // Return true if the transfer op can be converted to an MMA matrix load.
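// For example (a hedged sketch; exact syntax may vary), an unmasked, in-bounds
// 2-D read with a minor identity map such as
//   %v = vector.transfer_read %mem[%c0, %c0], %pad {in_bounds = [true, true]}
//        : memref<32x32xf16>, vector<16x16xf16>
// qualifies, as do the broadcast and transposed maps checked below.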
143 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
144   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
145       readOp.getVectorType().getRank() != 2)
146     return false;
147   if (!getStaticallyKnownRowStride(readOp.getShapedType()))
148     return false;
149 
150   // Only allow integer types if the signedness can be inferred.
151   if (readOp.getVectorType().getElementType().isInteger(8))
152     if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
153                                  !isa<arith::ExtUIOp>(*readOp->user_begin())))
154       return false;
155 
156   AffineMap map = readOp.getPermutationMap();
157   MLIRContext *ctx = readOp.getContext();
158   AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
159   AffineExpr zero = getAffineConstantExpr(0, ctx);
160   auto broadcastInnerDim =
161       AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
162   return map.isMinorIdentity() || map == broadcastInnerDim ||
163          isTransposeMatrixLoadMap(map);
164 }
165 
166 // Return true if the transfer op can be converted to an MMA matrix store.
167 static bool
168 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
169   // TODO: support 0-d corner case.
170   if (writeOp.getTransferRank() == 0)
171     return false;
172 
173   if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
174       writeOp.getVectorType().getRank() != 2)
175     return false;
176   if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
177     return false;
178   // TODO: Support transpose once it is added to GPU dialect ops.
179   if (!writeOp.getPermutationMap().isMinorIdentity())
180     return false;
181   return true;
182 }
183 
184 /// Return true if the constant is a splat to a 2D vector so that it can be
185 /// converted to an MMA constant matrix op.
186 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
187   auto vecType = dyn_cast<VectorType>(constantOp.getType());
188   if (!vecType || vecType.getRank() != 2)
189     return false;
190   return isa<SplatElementsAttr>(constantOp.getValue());
191 }
192 
193 /// Return true if this is a broadcast from scalar to a 2D vector.
194 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
195   return broadcastOp.getResultVectorType().getRank() == 2;
196 }
197 
198 /// Return true if this integer extend op can be folded into a contract op.
199 template <typename ExtOpTy>
200 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
201   if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
202     return false;
203   return llvm::all_of(extOp->getUsers(), [](Operation *user) {
204     return isa<vector::ContractionOp>(user);
205   });
206 }
207 
208 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
209 
210 /// Return the MMA elementwise enum associated with `op` if it is supported.
211 /// Return `std::nullopt` otherwise.
212 static std::optional<gpu::MMAElementwiseOp>
213 convertElementwiseOpToMMA(Operation *op) {
214   if (isa<arith::AddFOp>(op))
215     return gpu::MMAElementwiseOp::ADDF;
216   if (isa<arith::MulFOp>(op))
217     return gpu::MMAElementwiseOp::MULF;
218   if (isa<arith::SubFOp>(op))
219     return gpu::MMAElementwiseOp::SUBF;
220   if (isa<arith::MaximumFOp>(op))
221     return gpu::MMAElementwiseOp::MAXF;
222   if (isa<arith::MinimumFOp>(op))
223     return gpu::MMAElementwiseOp::MINF;
224   if (isa<arith::DivFOp>(op))
225     return gpu::MMAElementwiseOp::DIVF;
226   if (isa<arith::AddIOp>(op))
227     return gpu::MMAElementwiseOp::ADDI;
228   if (isa<arith::MulIOp>(op))
229     return gpu::MMAElementwiseOp::MULI;
230   if (isa<arith::SubIOp>(op))
231     return gpu::MMAElementwiseOp::SUBI;
232   if (isa<arith::DivSIOp>(op))
233     return gpu::MMAElementwiseOp::DIVS;
234   if (isa<arith::DivUIOp>(op))
235     return gpu::MMAElementwiseOp::DIVU;
236   if (isa<arith::NegFOp>(op))
237     return gpu::MMAElementwiseOp::NEGATEF;
238   if (isa<arith::ExtFOp>(op))
239     return gpu::MMAElementwiseOp::EXTF;
240   return std::nullopt;
241 }
242 
243 /// Return true if the op is supported as elementwise op on MMAMatrix type.
244 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
245   return convertElementwiseOpToMMA(op).has_value();
246 }
247 
248 /// Returns true if the extract strided slice op is supported on the
249 /// `mma.sync` path.
250 static bool
251 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
252 
253   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
254       nvgpu::getWarpMatrixInfo(op);
255   if (failed(warpMatrixInfo))
256     return false;
257 
258   FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
259   if (failed(contractOp))
260     return false;
261 
262   // Handle vector.extract_strided_slice on registers containing
263   // matrixB and matrixC operands. vector.extract_strided_slice op
264   // is not supported on registers containing matrixA operands.
265   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
266     return (cast<VectorType>(op->getResult(0).getType()) ==
267             cast<VectorType>((*contractOp).getRhs().getType()));
268   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
269     return (cast<VectorType>(op->getResult(0).getType()) ==
270             cast<VectorType>((*contractOp).getAcc().getType()));
271 
272   return false;
273 }
274 
275 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
276   if (isa<scf::ForOp, scf::YieldOp>(op))
277     return true;
278   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
279     return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
280                     : transferReadSupportsMMAMatrixType(transferRead);
281   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
282     return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
283                     : transferWriteSupportsMMAMatrixType(transferWrite);
284   if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
285     return useNvGpu &&
286            extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
287   if (auto contract = dyn_cast<vector::ContractionOp>(op))
288     return contractSupportsMMAMatrixType(contract, useNvGpu);
289   if (auto constant = dyn_cast<arith::ConstantOp>(op))
290     return constantSupportsMMAMatrixType(constant);
291   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
292     return broadcastSupportsMMAMatrixType(broadcast);
293   if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
294     return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
295   if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
296     return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
297   if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
298     return fpExtendSupportsMMAMatrixType(fpExtend);
299   return elementwiseSupportsMMAMatrixType(op);
300 }
301 
302 /// Return an unsorted slice, handling the scf.for region differently than
303 /// `getSlice`. For scf.for we only want to include, as part of the slice,
304 /// elements that are part of the use/def chain.
305 static SetVector<Operation *>
306 getSliceContract(Operation *op,
307                  const BackwardSliceOptions &backwardSliceOptions,
308                  const ForwardSliceOptions &forwardSliceOptions) {
309   SetVector<Operation *> slice;
310   slice.insert(op);
311   unsigned currentIndex = 0;
312   SetVector<Operation *> backwardSlice;
313   SetVector<Operation *> forwardSlice;
314   while (currentIndex != slice.size()) {
315     auto *currentOp = slice[currentIndex];
316     // Compute and insert the backwardSlice starting from currentOp.
317     backwardSlice.clear();
318     getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
319     slice.insert(backwardSlice.begin(), backwardSlice.end());
320 
321     // Compute and insert the forwardSlice starting from currentOp.
322     forwardSlice.clear();
323     // Special case for ForOp: we don't want to include the whole region, but
324     // only the values using the region arguments.
325     // TODO: We should refine this to only care about the region arguments being
326     // converted to matrix type.
327     if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
328       for (Value forOpResult : forOp.getResults())
329         getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
330       for (BlockArgument &arg : forOp.getRegionIterArgs())
331         getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
332     } else {
333       getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
334     }
335     slice.insert(forwardSlice.begin(), forwardSlice.end());
336     ++currentIndex;
337   }
338   return slice;
339 }
340 
341 // Analyze the slice of operations anchored on each contraction op to figure
342 // out whether the whole slice can be converted to MMA operations.
343 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
344                                              bool useNvGpu) {
345   auto hasVectorDest = [](Operation *op) {
346     return llvm::any_of(op->getResultTypes(),
347                         [](Type t) { return isa<VectorType>(t); });
348   };
349   BackwardSliceOptions backwardSliceOptions;
350   backwardSliceOptions.filter = hasVectorDest;
351 
352   auto hasVectorSrc = [](Operation *op) {
353     return llvm::any_of(op->getOperandTypes(),
354                         [](Type t) { return isa<VectorType>(t); });
355   };
356   ForwardSliceOptions forwardSliceOptions;
357   forwardSliceOptions.filter = hasVectorSrc;
358 
359   SetVector<Operation *> opToConvert;
360   op->walk([&](vector::ContractionOp contract) {
361     if (opToConvert.contains(contract.getOperation()))
362       return;
363     SetVector<Operation *> dependentOps =
364         getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
365     // If any instruction cannot use the MMA matrix type, drop the whole
366     // chain. MMA matrices are stored in an opaque type so they cannot be used
367     // by all operations.
368     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
369           if (!supportsMMaMatrixType(op, useNvGpu)) {
370             LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
371             return true;
372           }
373           return false;
374         }))
375       return;
376 
377     opToConvert.insert(dependentOps.begin(), dependentOps.end());
378   });
379   // Sort the operations so that we can convert them in topological order.
380   return topologicalSort(opToConvert);
381 }
382 
383 namespace {
384 // Transform a contraction into (m, k)x(k, n)x(m, n) form so that it can be
385 // converted to an MMA matmul.
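// For illustration (a sketch only), a contraction whose indexing maps are
// {(m, k), (n, k), (m, n)} is rewritten as:
//   %rhs_t = vector.transpose %rhs, [1, 0]
//            : vector<16x16xf16> to vector<16x16xf16>
//   %d = vector.contract ... %lhs, %rhs_t, %acc ...
// so that the maps become the canonical {(m, k), (k, n), (m, n)} form.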
386 struct PrepareContractToGPUMMA
387     : public OpRewritePattern<vector::ContractionOp> {
388   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
389 
390   LogicalResult matchAndRewrite(vector::ContractionOp op,
391                                 PatternRewriter &rewriter) const override {
392     Location loc = op.getLoc();
393     Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
394 
395     // Set up the parallel/reduction structure in the right form.
396     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
397     auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
398     AffineExpr m, n, k;
399     bindDims(rewriter.getContext(), m, n, k);
400     static constexpr std::array<int64_t, 2> perm = {1, 0};
401     auto iteratorTypes = op.getIteratorTypes().getValue();
402     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
403     if (!(vector::isParallelIterator(iteratorTypes[0]) &&
404           vector::isParallelIterator(iteratorTypes[1]) &&
405           vector::isReductionIterator(iteratorTypes[2])))
406       return rewriter.notifyMatchFailure(op, "not a gemm contraction");
407     //
408     // Two outer parallel, one inner reduction (matmat flavor).
409     //
410     // This is the classical row-major matmul, nothing to do.
411     if (maps == infer({{m, k}, {k, n}, {m, n}}))
412       return rewriter.notifyMatchFailure(op, "contraction already prepared");
413     if (maps == infer({{m, k}, {n, k}, {m, n}})) {
414       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
415     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
416       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
417     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
418       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
419       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
420     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
421       std::swap(rhs, lhs);
422       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
423       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
424     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
425       std::swap(rhs, lhs);
426       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
427     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
428       std::swap(lhs, rhs);
429       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
430     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
431       std::swap(lhs, rhs);
432     } else {
433       // TODO: llvm_unreachable ?
434       return rewriter.notifyMatchFailure(op, "unexpected contraction case");
435     }
436     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
437         op, lhs, rhs, res,
438         rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
439         op.getIteratorTypes());
440     return success();
441   }
442 };
443 
444 // Fold a transpose op into the transfer read op. The nvgpu mma.sync op only
445 // supports row-, column-, and row-major layouts for matrixA, matrixB, and
446 // matrixC, respectively. We can fold the transpose operation when loading the
447 // data from shared memory to registers.
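// For illustration (a sketch; the printed form may differ), a pair such as
//   %r = vector.transfer_read %mem[%i, %j], %pad
//        : memref<32x32xf16>, vector<16x16xf16>
//   %t = vector.transpose %r, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
// becomes a single transfer_read whose permutation map composes the transpose,
// e.g. {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}.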
448 struct CombineTransferReadOpTranspose final
449     : public OpRewritePattern<vector::TransposeOp> {
450   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
451 
452   LogicalResult matchAndRewrite(vector::TransposeOp op,
453                                 PatternRewriter &rewriter) const override {
454     // Look through integer and floating-point extend ops.
455     Value source = op.getVector();
456     Type resultType = op.getType();
457     Operation *extOp;
458     if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
459         (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
460         (extOp = source.getDefiningOp<arith::ExtFOp>())) {
461       source = extOp->getOperand(0);
462       resultType =
463           VectorType::get(cast<VectorType>(resultType).getShape(),
464                           cast<VectorType>(source.getType()).getElementType());
465     }
466 
467     auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
468     if (!transferReadOp)
469       return rewriter.notifyMatchFailure(op, "no transfer read");
470 
471     // TODO: support 0-d corner case.
472     if (transferReadOp.getTransferRank() == 0)
473       return rewriter.notifyMatchFailure(op, "0-D transfer read");
474 
475     if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
476       return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
477 
478     AffineMap permutationMap =
479         AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
480     AffineMap newMap =
481         permutationMap.compose(transferReadOp.getPermutationMap());
482 
483     auto loc = op.getLoc();
484     Value result =
485         rewriter
486             .create<vector::TransferReadOp>(
487                 loc, resultType, transferReadOp.getSource(),
488                 transferReadOp.getIndices(), AffineMapAttr::get(newMap),
489                 transferReadOp.getPadding(), transferReadOp.getMask(),
490                 transferReadOp.getInBoundsAttr())
491             .getResult();
492 
493     // Fuse through the extend op.
494     if (extOp) {
495       if (isa<arith::ExtSIOp>(extOp))
496         result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
497                      .getResult();
498       else if (isa<arith::ExtUIOp>(extOp))
499         result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
500                      .getResult();
501       else
502         result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
503                      .getResult();
504     }
505 
506     rewriter.replaceOp(op, result);
507     return success();
508   }
509 };
510 
511 } // namespace
512 
513 // MMA types have different layout based on how they are used in matmul ops.
514 // Figure out the right layout to use by looking at op uses.
515 // TODO: Change the GPU dialect to abstract the layout at this level and
516 // only care about it during lowering to NVVM.
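// For example (hypothetical IR), given
//   %d = vector.contract ... %a, %b, %c ...
// the op defining %a is tagged "AOp", the op defining %b "BOp", and any other
// producer (e.g. the accumulator or an elementwise op) falls through to "COp".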
517 static const char *inferFragType(Operation *op) {
518   for (Operation *user : op->getUsers()) {
519     auto contract = dyn_cast<vector::ContractionOp>(user);
520     if (!contract)
521       continue;
522     assert(op->getNumResults() == 1);
523     if (contract.getLhs() == op->getResult(0))
524       return "AOp";
525     if (contract.getRhs() == op->getResult(0))
526       return "BOp";
527   }
528   return "COp";
529 }
530 
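/// Convert a 2-D transfer read into a gpu.subgroup_mma_load_matrix op. A
/// hedged sketch of the result (attribute names as used by the gpu dialect at
/// this revision; exact syntax may vary):
///   %m = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
///        : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
/// with an additional `transpose` unit attribute when the permutation map is a
/// transposed matrix load.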
531 static LogicalResult
532 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
533                       llvm::DenseMap<Value, Value> &valueMapping) {
534   OpBuilder::InsertionGuard g(rewriter);
535   rewriter.setInsertionPoint(op);
536 
537   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
538   assert(transferReadSupportsMMAMatrixType(op) &&
539          "expected convertible operation");
540 
541   std::optional<int64_t> stride =
542       getStaticallyKnownRowStride(op.getShapedType());
543   if (!stride.has_value()) {
544     LLVM_DEBUG(DBGS() << "no stride\n");
545     return rewriter.notifyMatchFailure(op, "no stride");
546   }
547 
548   AffineMap map = op.getPermutationMap();
549   bool isTranspose = isTransposeMatrixLoadMap(map);
550 
551   // Handle broadcast by setting the stride to 0.
552   if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
553     assert(cstExpr.getValue() == 0);
554     stride = 0;
555   }
556 
557   Value mappingResult = op.getResult();
558   auto elType = op.getVectorType().getElementType();
559   const char *fragType = inferFragType(op);
560   if (op->hasOneUse()) {
561     auto *user = *op->user_begin();
562     // Infer the signedness of the mma type from the integer extend.
563     bool isSignedExtend = isa<arith::ExtSIOp>(user);
564     if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
565       elType = IntegerType::get(
566           op.getContext(), cast<IntegerType>(elType).getWidth(),
567           isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
568       mappingResult = user->getResult(0);
569       fragType = inferFragType(user);
570     }
571   }
572   gpu::MMAMatrixType type =
573       gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
574   Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
575       op.getLoc(), type, op.getSource(), op.getIndices(),
576       rewriter.getIndexAttr(*stride),
577       isTranspose ? rewriter.getUnitAttr() : UnitAttr());
578   valueMapping[mappingResult] = load;
579 
580   LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
581   return success();
582 }
583 
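/// Convert a 2-D transfer write into a gpu.subgroup_mma_store_matrix op. A
/// hedged sketch of the result (exact syntax may vary):
///   gpu.subgroup_mma_store_matrix %m, %dst[%i, %j] {leadDimension = 32 : index}
///     : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>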
584 static LogicalResult
585 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
586                        llvm::DenseMap<Value, Value> &valueMapping) {
587   OpBuilder::InsertionGuard g(rewriter);
588   rewriter.setInsertionPoint(op);
589 
590   assert(transferWriteSupportsMMAMatrixType(op));
591   std::optional<int64_t> stride =
592       getStaticallyKnownRowStride(op.getShapedType());
593   if (!stride.has_value()) {
594     LLVM_DEBUG(DBGS() << "no stride\n");
595     return rewriter.notifyMatchFailure(op, "no stride");
596   }
597 
598   auto it = valueMapping.find(op.getVector());
599   if (it == valueMapping.end()) {
600     LLVM_DEBUG(DBGS() << "no mapping\n");
601     return rewriter.notifyMatchFailure(op, "no mapping");
602   }
603 
604   Value matrix = it->second;
605   auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
606       op.getLoc(), matrix, op.getSource(), op.getIndices(),
607       rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
608   (void)store;
609 
610   LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
611 
612   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
613   rewriter.eraseOp(op);
614   return success();
615 }
616 
617 /// Returns the vector type which represents a matrix fragment.
618 static VectorType
619 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
620   SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
621                              regInfo.elementsPerRegister};
622   Type elType = regInfo.registerLLVMType;
623   if (auto vecType = dyn_cast<VectorType>(elType))
624     elType = vecType.getElementType();
625   return VectorType::get(shape, elType);
626 }
627 
628 /// Convert a 2D splat ConstantOp to a per-thread mma.sync fragment constant.
629 static LogicalResult
630 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
631                          llvm::DenseMap<Value, Value> &valueMapping) {
632   OpBuilder::InsertionGuard g(rewriter);
633   rewriter.setInsertionPoint(op);
634 
635   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
636       nvgpu::getWarpMatrixInfo(op);
637   if (failed(warpMatrixInfo)) {
638     LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
639     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
640   }
641 
642   FailureOr<nvgpu::FragmentElementInfo> regInfo =
643       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
644   if (failed(regInfo)) {
645     LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
646     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
647   }
648 
649   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
650   auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
651   if (!dense) {
652     LLVM_DEBUG(DBGS() << "not a splat\n");
653     return rewriter.notifyMatchFailure(op, "not a splat");
654   }
655 
656   Value result = rewriter.create<arith::ConstantOp>(
657       op.getLoc(), vectorType,
658       DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
659   valueMapping[op.getResult()] = result;
660   return success();
661 }
662 
663 /// Check if the loaded matrix operand requires a transposed load.
664 /// Transposed map example:
665 /// Example 1   : (..., d0, d1) -> (d1 * 1, d0 * 2)
666 /// Example 2   : (d0, d1, d2, d3) -> (d3, d2)
667 /// The code below checks if the 2D output is transposed using a generalized
668 /// version     : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
669 /// Returns     : true if m > n, false otherwise.
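/// For illustration: affine_map<(d0, d1) -> (d1, d0)> and
/// affine_map<(d0, d1, d2, d3) -> (d3, d2)> are reported as transposed, while
/// affine_map<(d0, d1) -> (d0, d1)> is not; maps whose two results are not
/// plain dim expressions yield failure().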
670 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
671   mlir::AffineMap map = op.getPermutationMap();
672 
673   if (map.getNumResults() != 2) {
674     LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
675                          "is not a 2d operand\n");
676     return failure();
677   }
678 
679   // Output 2D matrix dimensions in the order of d0, d1.
680   mlir::AffineExpr dM = map.getResult(0);
681   mlir::AffineExpr dN = map.getResult(1);
682 
683   //  Find the position of these expressions in the input.
684   auto exprM = dyn_cast<AffineDimExpr>(dM);
685   auto exprN = dyn_cast<AffineDimExpr>(dN);
686 
687   if (!exprM || !exprN) {
688     LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
689                          "expressions, then transpose cannot be determined.\n");
690     return failure();
691   }
692 
693   return exprM.getPosition() > exprN.getPosition();
694 }
695 
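/// Convert a transfer read into an nvgpu.ldmatrix load of the thread-owned
/// fragment. A hedged sketch of the emitted IR (shapes, attributes, and the
/// shared-memory space annotation are illustrative only):
///   %lane = gpu.lane_id
///   %frag = nvgpu.ldmatrix %src[%row, %col] {numTiles = 4 : i32, transpose = false}
///           : memref<16x32xf16, #gpu.address_space<workgroup>> -> vector<4x2xf16>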
696 static LogicalResult
697 createLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
698                              llvm::DenseMap<Value, Value> &valueMapping) {
699   OpBuilder::InsertionGuard g(rewriter);
700   rewriter.setInsertionPoint(op);
701   Location loc = op->getLoc();
702 
703   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
704       nvgpu::getWarpMatrixInfo(op);
705   if (failed(warpMatrixInfo)) {
706     LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
707     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
708   }
709 
710   FailureOr<nvgpu::FragmentElementInfo> regInfo =
711       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
712   if (failed(regInfo)) {
713     LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
714     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
715   }
716 
717   FailureOr<bool> transpose = isTransposed(op);
718   if (failed(transpose)) {
719     LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
720     return rewriter.notifyMatchFailure(
721         op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
722   }
723 
724   FailureOr<nvgpu::LdMatrixParams> params =
725       nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
726 
727   if (failed(params)) {
728     LLVM_DEBUG(
729         DBGS()
730         << "failed to convert vector.transfer_read to ldmatrix. "
731         << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
732     return rewriter.notifyMatchFailure(
733         op, "failed to convert vector.transfer_read to ldmatrix; this op "
734             "likely should not be converted to a nvgpu.ldmatrix call.");
735   }
736 
737   // Adjust the load offset.
738   auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
739   FailureOr<AffineMap> offsets =
740       nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
741   if (failed(offsets)) {
742     LLVM_DEBUG(DBGS() << "no offsets\n");
743     return rewriter.notifyMatchFailure(op, "no offsets");
744   }
745 
746   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
747 
748   SmallVector<Value, 4> indices;
749   getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
750                                          indices);
751 
752   nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
753       loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
754   valueMapping[op] = newOp->getResult(0);
755   return success();
756 }
757 
758 static LogicalResult
759 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
760                        llvm::DenseMap<Value, Value> &valueMapping) {
761   OpBuilder::InsertionGuard g(rewriter);
762   rewriter.setInsertionPoint(op);
763 
764   Location loc = op.getLoc();
765   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
766       nvgpu::getWarpMatrixInfo(op);
767   if (failed(warpMatrixInfo))
768     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
769   FailureOr<nvgpu::FragmentElementInfo> regInfo =
770       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
771   if (failed(regInfo)) {
772     return rewriter.notifyMatchFailure(
773         op, "Failed to deduce register fragment type during "
774             "conversion to distributed non-ldmatrix compatible load");
775   }
776 
777   Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
778   SmallVector<Value, 4> elements;
779 
780   // This is the individual element type.
781   Type loadedElType = regInfo->registerLLVMType;
782   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
783 
784   Value fill = rewriter.create<arith::ConstantOp>(
785       op.getLoc(), vectorType.getElementType(),
786       rewriter.getZeroAttr(vectorType.getElementType()));
787   Value result =
788       rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
789 
790   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
791 
792   // If we are not transposing, then we can use vectorized loads. Otherwise, we
793   // must load each element individually.
794   if (!isTransposeLoad) {
795     if (!isa<VectorType>(loadedElType)) {
796       loadedElType = VectorType::get({1}, loadedElType);
797     }
798 
799     for (int i = 0; i < vectorType.getShape()[0]; i++) {
800       FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
801           rewriter, op.getLoc(), *warpMatrixInfo);
802       if (failed(coords))
803         return rewriter.notifyMatchFailure(op, "no coords");
804 
805       Value logicalValueId = rewriter.create<arith::ConstantOp>(
806           loc, rewriter.getIndexType(),
807           rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
808       SmallVector<Value, 4> newIndices;
809       getXferIndices<vector::TransferReadOp>(
810           rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
811 
812       Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
813                                                  op.getSource(), newIndices);
814       result = rewriter.create<vector::InsertOp>(loc, el, result, i);
815     }
816   } else {
817     if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
818       loadedElType = vecType.getElementType();
819     }
820     for (int i = 0; i < vectorType.getShape()[0]; i++) {
821       for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
822            innerIdx++) {
823 
824         Value logicalValueId = rewriter.create<arith::ConstantOp>(
825             loc, rewriter.getIndexType(),
826             rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
827         FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
828             rewriter, op.getLoc(), *warpMatrixInfo);
829         if (failed(coords))
830           return rewriter.notifyMatchFailure(op, "no coords");
831 
832         SmallVector<Value, 4> newIndices;
833         getXferIndices<vector::TransferReadOp>(
834             rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
835         Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
836                                                    op.getSource(), newIndices);
837         result = rewriter.create<vector::InsertOp>(
838             op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
839       }
840     }
841   }
842 
843   valueMapping[op.getResult()] = result;
844   return success();
845 }
846 
847 /// Return true if this is a shared memory memref type.
848 static bool isSharedMemory(MemRefType type) {
849   auto addressSpace =
850       dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
851   if (addressSpace &&
852       addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace())
853     return true;
854   return false;
855 }
856 
857 /// Converts a `vector.transfer_read` operation directly to either a
858 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
859 /// used when converting to `nvgpu.mma.sync` operations.
860 static LogicalResult
861 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
862                            llvm::DenseMap<Value, Value> &valueMapping) {
863   OpBuilder::InsertionGuard g(rewriter);
864   rewriter.setInsertionPoint(op);
865 
866   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
867       nvgpu::getWarpMatrixInfo(op);
868   if (failed(warpMatrixInfo))
869     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
870 
871   bool isLdMatrixCompatible =
872       isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
873       nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
874 
875   VectorType vecTy = op.getVectorType();
876   int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
877 
878   // When we are transposing the B operand, ldmatrix will only work if we have
879   // at least 8 rows to read and the width to read for the transpose is 128
880   // bits.
881   if (!op.getPermutationMap().isMinorIdentity() &&
882       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
883        vecTy.getDimSize(0) * bitWidth < 128))
884     isLdMatrixCompatible = false;
885 
886   if (!isLdMatrixCompatible)
887     return createNonLdMatrixLoads(rewriter, op, valueMapping);
888 
889   return createLdMatrixCompatibleLoads(rewriter, op, valueMapping);
890 }
891 
892 static LogicalResult
893 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
894                              llvm::DenseMap<Value, Value> &valueMapping) {
895   OpBuilder::InsertionGuard g(rewriter);
896   rewriter.setInsertionPoint(op);
897 
898   Location loc = op->getLoc();
899   auto it = valueMapping.find(op.getVector());
900   if (it == valueMapping.end())
901     return rewriter.notifyMatchFailure(op, "no mapping");
902   Value matrix = it->second;
903 
904   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
905       nvgpu::getWarpMatrixInfo(op);
906   if (failed(warpMatrixInfo))
907     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
908   FailureOr<nvgpu::FragmentElementInfo> regInfo =
909       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
910   if (failed(regInfo))
911     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
912 
913   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
914   Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
915 
916   for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
917     Value logicalValueId = rewriter.create<arith::ConstantOp>(
918         loc, rewriter.getIndexType(),
919         rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
920     FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
921         rewriter, op.getLoc(), *warpMatrixInfo);
922     if (failed(coords))
923       return rewriter.notifyMatchFailure(op, "no coords");
924 
925     Value el =
926         rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
927     SmallVector<Value, 4> newIndices;
928     getXferIndices<vector::TransferWriteOp>(
929         rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
930     rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
931   }
932 
933   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
934   rewriter.eraseOp(op);
935   return success();
936 }
937 
938 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
939                                        SmallVectorImpl<int64_t> &results) {
940   for (auto attr : arrayAttr)
941     results.push_back(cast<IntegerAttr>(attr).getInt());
942 }
943 
944 static LogicalResult
945 convertExtractStridedSlice(RewriterBase &rewriter,
946                            vector::ExtractStridedSliceOp op,
947                            llvm::DenseMap<Value, Value> &valueMapping) {
948   OpBuilder::InsertionGuard g(rewriter);
949   rewriter.setInsertionPoint(op);
950 
951   Location loc = op->getLoc();
952 
953   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
954       nvgpu::getWarpMatrixInfo(op);
955   if (failed(warpMatrixInfo))
956     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
957 
958   FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
959       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
960   if (failed(mmaSyncFragmentInfo))
961     return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
962 
963   // Find the vector.transfer_read whose result vector is being sliced.
964   auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
965   if (!transferReadOp)
966     return rewriter.notifyMatchFailure(op, "no transfer read");
967 
968   warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
969   if (failed(warpMatrixInfo))
970     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
971 
972   FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
973       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
974   if (failed(ldFragmentInfo))
975     return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
976 
977   assert(
978       (mmaSyncFragmentInfo->elementsPerRegister ==
979        ldFragmentInfo->elementsPerRegister) &&
980       "Number of elements per register should be same for load and mma.sync");
981 
982   // Create vector.extract_strided_slice op for thread-owned fragments.
983   // The stride for vector.extract_strided_slice is always 1.
984   std::array<int64_t, 2> strides = {1, 1};
985   std::array<int64_t, 2> sliceShape = {
986       mmaSyncFragmentInfo->numRegistersPerFragment,
987       mmaSyncFragmentInfo->elementsPerRegister};
988   auto it = valueMapping.find(transferReadOp);
989   if (it == valueMapping.end())
990     return rewriter.notifyMatchFailure(op, "no mapping");
991   auto sourceVector = it->second;
992 
993   // Offsets and sizes at the warp level of ownership.
994   SmallVector<int64_t> offsets;
995   populateFromInt64AttrArray(op.getOffsets(), offsets);
996 
997   SmallVector<int64_t> sizes;
998   populateFromInt64AttrArray(op.getSizes(), sizes);
999   ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1000 
1001   // Compute the offset in vector registers. Note that mma.sync vector registers
1002   // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1003   // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1004   std::array<int64_t, 2> sliceOffset = {0, 0};
1005 
1006   if (offsets[0] && offsets[1])
1007     return op->emitError() << "Slicing fragments in 2D is not supported.";
1008   if (offsets[0])
1009     sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1010   else if (offsets[1])
1011     sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1012 
1013   Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1014       loc, sourceVector, sliceOffset, sliceShape, strides);
1015 
1016   valueMapping[op] = newOp;
1017   return success();
1018 }
1019 
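/// Convert a vector.contract into a gpu.subgroup_mma_compute op on the
/// previously created MMA matrix values. A hedged sketch of the result:
///   %d = gpu.subgroup_mma_compute %a, %b, %c
///        : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
///        -> !gpu.mma_matrix<16x16xf16, "COp">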
1020 static LogicalResult
1021 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1022                   llvm::DenseMap<Value, Value> &valueMapping) {
1023   OpBuilder::InsertionGuard g(rewriter);
1024   rewriter.setInsertionPoint(op);
1025 
1026   auto itA = valueMapping.find(op.getLhs());
1027   auto itB = valueMapping.find(op.getRhs());
1028   auto itC = valueMapping.find(op.getAcc());
1029   if (itA == valueMapping.end() || itB == valueMapping.end() ||
1030       itC == valueMapping.end())
1031     return rewriter.notifyMatchFailure(op, "no mapping");
1032   Value opA = itA->second, opB = itB->second, opC = itC->second;
1033   Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1034       op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1035       /*b_transpose=*/UnitAttr());
1036   valueMapping[op.getResult()] = matmul;
1037   return success();
1038 }
1039 
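/// Convert a vector.contract into an nvgpu.mma.sync op on the per-thread
/// fragments. A hedged sketch of the result (fragment shapes are illustrative):
///   %d = nvgpu.mma.sync(%a, %b, %c) {mmaShape = [16, 8, 16]}
///        : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>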
1040 static LogicalResult
1041 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1042                            llvm::DenseMap<Value, Value> &valueMapping) {
1043   OpBuilder::InsertionGuard g(rewriter);
1044   rewriter.setInsertionPoint(op);
1045 
1046   auto itA = valueMapping.find(op.getLhs());
1047   auto itB = valueMapping.find(op.getRhs());
1048   auto itC = valueMapping.find(op.getAcc());
1049   if (itA == valueMapping.end() || itB == valueMapping.end() ||
1050       itC == valueMapping.end())
1051     return rewriter.notifyMatchFailure(op, "no mapping");
1052   Value opA = itA->second, opB = itB->second, opC = itC->second;
1053   int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1054   int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1055   int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1056   Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1057       op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1058   valueMapping[op.getResult()] = matmul;
1059   return success();
1060 }
1061 
1062 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1063 static LogicalResult
1064 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1065                   llvm::DenseMap<Value, Value> &valueMapping) {
1066   OpBuilder::InsertionGuard g(rewriter);
1067   rewriter.setInsertionPoint(op);
1068 
1069   assert(constantSupportsMMAMatrixType(op));
1070 
1071   auto splat =
1072       cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1073   auto scalarConstant =
1074       rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1075   const char *fragType = inferFragType(op);
1076   auto vecType = cast<VectorType>(op.getType());
1077   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1078       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1079   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1080       op.getLoc(), type, scalarConstant);
1081   valueMapping[op.getResult()] = matrix;
1082   return success();
1083 }
1084 
1085 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1086 static LogicalResult
1087 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1088                    llvm::DenseMap<Value, Value> &valueMapping) {
1089   OpBuilder::InsertionGuard g(rewriter);
1090   rewriter.setInsertionPoint(op);
1091 
1092   assert(broadcastSupportsMMAMatrixType(op));
1093 
1094   const char *fragType = inferFragType(op);
1095   auto vecType = op.getResultVectorType();
1096   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1097       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1098   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1099       op.getLoc(), type, op.getSource());
1100   valueMapping[op.getResult()] = matrix;
1101   return success();
1102 }
1103 
1104 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1105 // updated and needs to be updated separately for the loop to be correct.
1106 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1107                                                scf::ForOp loop,
1108                                                ValueRange newInitArgs) {
1109   OpBuilder::InsertionGuard g(rewriter);
1110   rewriter.setInsertionPoint(loop);
1111 
1112   // Create a new loop before the existing one, with the extra operands.
1113   rewriter.setInsertionPoint(loop);
1114   auto operands = llvm::to_vector<4>(loop.getInitArgs());
1115   llvm::append_range(operands, newInitArgs);
1116   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1117       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
1118       operands);
1119   newLoop.getBody()->erase();
1120 
1121   newLoop.getRegion().getBlocks().splice(
1122       newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1123   for (Value operand : newInitArgs)
1124     newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1125 
1126   for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1127                                                   loop.getNumResults())))
1128     rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1129 
1130   LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1131   LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1132   LLVM_DEBUG(DBGS() << "erase: " << loop);
1133 
1134   rewriter.eraseOp(loop);
1135   return newLoop;
1136 }
1137 
1138 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1139                                   llvm::DenseMap<Value, Value> &valueMapping) {
1140   OpBuilder::InsertionGuard g(rewriter);
1141   rewriter.setInsertionPoint(op);
1142 
1143   SmallVector<Value> newOperands;
1144   SmallVector<std::pair<size_t, size_t>> argMapping;
1145   for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1146     auto it = valueMapping.find(operand.value());
1147     if (it == valueMapping.end()) {
1148       LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1149       continue;
1150     }
1151     argMapping.push_back(std::make_pair(
1152         operand.index(), op.getInitArgs().size() + newOperands.size()));
1153     newOperands.push_back(it->second);
1154   }
1155 
1156   scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1157   Block &loopBody = *newForOp.getBody();
1158   for (auto mapping : argMapping) {
1159     valueMapping[newForOp.getResult(mapping.first)] =
1160         newForOp.getResult(mapping.second);
1161     valueMapping[loopBody.getArgument(mapping.first +
1162                                       newForOp.getNumInductionVars())] =
1163         loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1164   }
1165 
1166   LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1167   return success();
1168 }
1169 
1170 static LogicalResult
1171 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1172                llvm::DenseMap<Value, Value> &valueMapping) {
1173   OpBuilder::InsertionGuard g(rewriter);
1174   rewriter.setInsertionPoint(op);
1175 
1176   auto loop = cast<scf::ForOp>(op->getParentOp());
1177   auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1178   for (const auto &operand : llvm::enumerate(op.getOperands())) {
1179     auto it = valueMapping.find(operand.value());
1180     if (it == valueMapping.end())
1181       continue;
1182     // Replace the yield of the old value with the for op argument to make it
1183     // easier to remove the dead code.
1184     yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1185     yieldOperands.push_back(it->second);
1186   }
1187   rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1188 
1189   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1190   rewriter.eraseOp(op);
1191   return success();
1192 }
1193 
1194 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1195 static LogicalResult
1196 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1197                      gpu::MMAElementwiseOp opType,
1198                      llvm::DenseMap<Value, Value> &valueMapping) {
1199   OpBuilder::InsertionGuard g(rewriter);
1200   rewriter.setInsertionPoint(op);
1201 
1202   SmallVector<Value> matrixOperands;
1203   for (Value operand : op->getOperands()) {
1204     auto it = valueMapping.find(operand);
1205     if (it == valueMapping.end())
1206       return rewriter.notifyMatchFailure(op, "no mapping");
1207     matrixOperands.push_back(it->second);
1208   }
1209   auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1210   if (opType == gpu::MMAElementwiseOp::EXTF) {
1211     // The floating point extension case has a different result type.
1212     auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1213     resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1214                                          vectorType.getElementType(),
1215                                          resultType.getOperand());
1216   }
1217 
1218   Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1219       op->getLoc(), resultType, matrixOperands, opType);
1220   valueMapping[op->getResult(0)] = newOp;
1221   return success();
1222 }
1223 
1224 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1225                                               bool useNvGpu) {
1226   if (!useNvGpu) {
1227     patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1228         patterns.getContext());
1229     return;
1230   }
1231   vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1232   patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1233 }
1234 
1235 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1236                                           Operation *rootOp) {
1237   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1238   llvm::DenseMap<Value, Value> valueMapping;
1239 
1240   auto globalRes = LogicalResult::success();
1241   for (Operation *op : ops) {
1242     LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1243     // Apparently callers do not want to early exit on failure here.
1244     auto res = LogicalResult::success();
1245     if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1246       res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1247     } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1248       res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1249     } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1250       res = convertContractOp(rewriter, contractOp, valueMapping);
1251     } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1252       res = convertConstantOp(rewriter, constantOp, valueMapping);
1253     } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1254       res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1255     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1256       res = convertForOp(rewriter, forOp, valueMapping);
1257     } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1258       res = convertYieldOp(rewriter, yieldOp, valueMapping);
1259     } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1260       res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1261     }
1262     if (failed(res))
1263       globalRes = failure();
1264   }
1265   return globalRes;
1266 }
1267 
1268 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1269                                                          Operation *rootOp) {
1270   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1271   llvm::DenseMap<Value, Value> valueMapping;
1272   for (Operation *op : ops) {
1273     if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1274             .Case([&](vector::TransferReadOp transferReadOp) {
1275               return convertTransferReadToLoads(rewriter, transferReadOp,
1276                                                 valueMapping);
1277             })
1278             .Case([&](vector::TransferWriteOp transferWriteOp) {
1279               return convertTransferWriteToStores(rewriter, transferWriteOp,
1280                                                   valueMapping);
1281             })
1282             .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1283               return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1284                                                 valueMapping);
1285             })
1286             .Case([&](vector::ContractionOp contractionOp) {
1287               return convertContractOpToMmaSync(rewriter, contractionOp,
1288                                                 valueMapping);
1289             })
1290             .Case([&](scf::ForOp forOp) {
1291               return convertForOp(rewriter, forOp, valueMapping);
1292             })
1293             .Case([&](scf::YieldOp yieldOp) {
1294               return convertYieldOp(rewriter, yieldOp, valueMapping);
1295             })
1296             .Case([&](arith::ConstantOp constOp) {
1297               return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1298             })
1299             .Default([&](Operation *op) {
1300               return op->emitError() << "unhandled vector to mma type: " << *op;
1301             })
1302             .failed()) {
1303       return op->emitOpError()
1304              << "failed to convert op during vector-to-nvgpu conversion";
1305     }
1306   }
1307   return success();
1308 }
1309 
1310 namespace {
1311 
1312 struct ConvertVectorToGPUPass
1313     : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1314 
1315   explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1316     useNvGpu.setValue(useNvGpu_);
1317   }
1318 
1319   void runOnOperation() override {
1320     RewritePatternSet patterns(&getContext());
1321     populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1322     if (failed(
1323             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
1324       return signalPassFailure();
1325 
1326     IRRewriter rewriter(&getContext());
1327     if (useNvGpu) {
1328       if (failed(
1329               convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1330         return signalPassFailure();
1331       return;
1332     }
1333     (void)convertVectorToMMAOps(rewriter, getOperation());
1334   }
1335 };
1336 
1337 } // namespace
1338 
1339 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1340   return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1341 }
1342