1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Analysis/TopologicalSortUtils.h"
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"
23 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
24 #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
27 #include "mlir/Dialect/Vector/IR/VectorOps.h"
28 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
29 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/Region.h"
33 #include "mlir/Pass/Pass.h"
34 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 
39 #define DEBUG_TYPE "vector-to-gpu"
40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
41 #define DBGSNL() (llvm::dbgs() << "\n")
42 
43 namespace mlir {
44 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
45 #include "mlir/Conversion/Passes.h.inc"
46 } // namespace mlir
47 
48 using namespace mlir;
49 
50 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
51 /// AffineMap representing offsets to apply to indices, the function fills
52 /// `indices` with the original indices plus the offsets. The offsets are
53 /// applied by taking into account the permutation map of the transfer op. If
54 /// the `offsetMap` has dimension placeholders, those should be provided in
55 /// `dimValues`.
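///
/// A purely illustrative sketch (not from the original source): assuming a
/// transfer op with permutation map (d0, d1) -> (d0, d1), original indices
/// [%i, %j], an `offsetMap` (lane) -> (lane floordiv 4, lane mod 4), and
/// `dimValues` = {%laneId}, the resulting indices are roughly
/// [%i + %laneId floordiv 4, %j + %laneId mod 4], materialized as
/// affine.apply ops.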
56 template <typename TransferOpType>
57 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
58                            AffineMap offsetMap, ArrayRef<Value> dimValues,
59                            SmallVector<Value, 4> &indices) {
60   indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
61   Location loc = xferOp.getLoc();
62   unsigned offsetsIdx = 0;
63   for (auto expr : xferOp.getPermutationMap().getResults()) {
64     if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
65       Value prevIdx = indices[dim.getPosition()];
66       SmallVector<OpFoldResult, 3> dims(dimValues);
67       dims.push_back(prevIdx);
68       AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
69       indices[dim.getPosition()] = affine::makeComposedAffineApply(
70           rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
71       continue;
72     }
73   }
74 }
75 
76 // Return true if the contract op can be converted to an MMA matmul.
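//
// For the WMMA path (useNvGpu == false) the accepted indexing maps are, for
// example (hypothetical attribute, for illustration only):
//
//   indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                    affine_map<(m, n, k) -> (k, n)>,
//                    affine_map<(m, n, k) -> (m, n)>]
//
// while the mma.sync path (useNvGpu == true) expects the B operand in
// transposed form:
//
//   indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                    affine_map<(m, n, k) -> (n, k)>,
//                    affine_map<(m, n, k) -> (m, n)>]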
77 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
78                                           bool useNvGpu) {
79   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80   auto infer = [&](MapList m) {
81     return AffineMap::inferFromExprList(m, contract.getContext());
82   };
83   AffineExpr m, n, k;
84   bindDims(contract.getContext(), m, n, k);
85   auto iteratorTypes = contract.getIteratorTypes().getValue();
86   if (!(vector::isParallelIterator(iteratorTypes[0]) &&
87         vector::isParallelIterator(iteratorTypes[1]) &&
88         vector::isReductionIterator(iteratorTypes[2])))
89     return false;
90 
91   // The contraction needs to represent a matmul to be convertible to an
92   // MMAMatrix matmul.
93   if (!useNvGpu &&
94       contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
95     return false;
96   if (useNvGpu &&
97       contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
98     return false;
99 
100   return true;
101 }
102 
103 // Return true if the given map represents a transposed matrix load,
104 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
105 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
106   MLIRContext *ctx = permutationMap.getContext();
107   // Local OpBuilder is fine here, we just build attributes.
108   OpBuilder b(ctx);
109   auto nDim = permutationMap.getNumDims();
110   AffineExpr zero = b.getAffineConstantExpr(0);
111   if (nDim < 2) {
112     // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
113     AffineExpr dim0 = b.getAffineDimExpr(0);
114     return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
115   }
116 
117   AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
118   AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
119   // Support both transposed and transposed+broadcasted cases.
120   return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
121          permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
122 }
123 
124 // Return the stride of the second-to-last dimension of |type| if it is a
125 // memref and has a constant stride.
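//
// A minimal sketch of the expected behavior (assumed example, not from the
// original source): for memref<16x32xf16, strided<[32, 1]>> the row stride is
// 32; for memref<16x32xf16, strided<[?, 1]>> the stride is dynamic and
// std::nullopt is returned.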
126 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
127   auto memrefType = dyn_cast<MemRefType>(type);
128   if (!memrefType)
129     return std::nullopt;
130   // If the memref is 0 or 1D the horizontal stride is 0.
131   if (memrefType.getRank() < 2)
132     return 0;
133   int64_t offset = 0;
134   SmallVector<int64_t, 2> strides;
135   if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
136       strides.back() != 1)
137     return std::nullopt;
138   int64_t stride = strides[strides.size() - 2];
139   if (stride == ShapedType::kDynamic)
140     return std::nullopt;
141   return stride;
142 }
143 
144 // Return true if the transfer op can be converted to a MMA matrix load.
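//
// A hypothetical transfer_read accepted by this check (illustration only):
//
//   %v = vector.transfer_read %mem[%c0, %c0], %pad {in_bounds = [true, true]}
//          : memref<16x16xf16, 3>, vector<16x16xf16>
//
// i.e. an unmasked, in-bounds, 2-D read with a minor-identity (or transposed,
// or inner-dimension-broadcast) permutation map from a memref with a
// statically known row stride.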
145 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
146   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
147       readOp.getVectorType().getRank() != 2)
148     return false;
149   if (!getStaticallyKnownRowStride(readOp.getShapedType()))
150     return false;
151 
152   // Only allow integer types if the signedness can be inferred.
153   if (readOp.getVectorType().getElementType().isInteger(8))
154     if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
155                                  !isa<arith::ExtUIOp>(*readOp->user_begin())))
156       return false;
157 
158   AffineMap map = readOp.getPermutationMap();
159   MLIRContext *ctx = readOp.getContext();
160   AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
161   AffineExpr zero = getAffineConstantExpr(0, ctx);
162   auto broadcastInnerDim =
163       AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
164   return map.isMinorIdentity() || map == broadcastInnerDim ||
165          isTransposeMatrixLoadMap(map);
166 }
167 
168 // Return true if the transfer op can be converted to a MMA matrix store.
169 static bool
170 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
171   // TODO: support 0-d corner case.
172   if (writeOp.getTransferRank() == 0)
173     return false;
174 
175   if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
176       writeOp.getVectorType().getRank() != 2)
177     return false;
178   if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
179     return false;
180   // TODO: Support transpose once it is added to GPU dialect ops.
181   if (!writeOp.getPermutationMap().isMinorIdentity())
182     return false;
183   return true;
184 }
185 
186 /// Return true if the constant is a splat of a 2D vector so that it can be
187 /// converted to an MMA constant matrix op.
188 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
189   auto vecType = dyn_cast<VectorType>(constantOp.getType());
190   if (!vecType || vecType.getRank() != 2)
191     return false;
192   return isa<SplatElementsAttr>(constantOp.getValue());
193 }
194 
195 /// Return true if this is a broadcast from scalar to a 2D vector.
196 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
197   return broadcastOp.getResultVectorType().getRank() == 2;
198 }
199 
200 /// Return true if this integer extend op can be folded into a contract op.
201 template <typename ExtOpTy>
202 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
203   auto transferReadOp =
204       extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
205   if (!transferReadOp)
206     return false;
207   return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
208 }
209 
210 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
211 
212 /// Return the MMA elementwise enum associated with `op` if it is supported.
213 /// Return `std::nullopt` otherwise.
214 static std::optional<gpu::MMAElementwiseOp>
215 convertElementwiseOpToMMA(Operation *op) {
216   if (isa<arith::AddFOp>(op))
217     return gpu::MMAElementwiseOp::ADDF;
218   if (isa<arith::MulFOp>(op))
219     return gpu::MMAElementwiseOp::MULF;
220   if (isa<arith::SubFOp>(op))
221     return gpu::MMAElementwiseOp::SUBF;
222   if (isa<arith::MaximumFOp>(op))
223     return gpu::MMAElementwiseOp::MAXF;
224   if (isa<arith::MinimumFOp>(op))
225     return gpu::MMAElementwiseOp::MINF;
226   if (isa<arith::DivFOp>(op))
227     return gpu::MMAElementwiseOp::DIVF;
228   if (isa<arith::AddIOp>(op))
229     return gpu::MMAElementwiseOp::ADDI;
230   if (isa<arith::MulIOp>(op))
231     return gpu::MMAElementwiseOp::MULI;
232   if (isa<arith::SubIOp>(op))
233     return gpu::MMAElementwiseOp::SUBI;
234   if (isa<arith::DivSIOp>(op))
235     return gpu::MMAElementwiseOp::DIVS;
236   if (isa<arith::DivUIOp>(op))
237     return gpu::MMAElementwiseOp::DIVU;
238   if (isa<arith::NegFOp>(op))
239     return gpu::MMAElementwiseOp::NEGATEF;
240   if (isa<arith::ExtFOp>(op))
241     return gpu::MMAElementwiseOp::EXTF;
242   return std::nullopt;
243 }
244 
245 /// Return true if the op is supported as elementwise op on MMAMatrix type.
246 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
247   return convertElementwiseOpToMMA(op).has_value();
248 }
249 
250 /// Returns true if the extract strided slice op is supported with `mma.sync`
251 /// path.
252 static bool
253 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
254 
255   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
256       nvgpu::getWarpMatrixInfo(op);
257   if (failed(warpMatrixInfo))
258     return false;
259 
260   FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
261   if (failed(contractOp))
262     return false;
263 
264   // Handle vector.extract_strided_slice on registers containing
265   // matrixB and matrixC operands. vector.extract_strided_slice op
266   // is not supported on registers containing matrixA operands.
267   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
268     return (cast<VectorType>(op->getResult(0).getType()) ==
269             cast<VectorType>((*contractOp).getRhs().getType()));
270   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
271     return (cast<VectorType>(op->getResult(0).getType()) ==
272             cast<VectorType>((*contractOp).getAcc().getType()));
273 
274   return false;
275 }
276 
277 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
278   if (isa<scf::ForOp, scf::YieldOp>(op))
279     return true;
280   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
281     return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
282                     : transferReadSupportsMMAMatrixType(transferRead);
283   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
284     return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
285                     : transferWriteSupportsMMAMatrixType(transferWrite);
286   if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
287     return useNvGpu &&
288            extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
289   if (auto contract = dyn_cast<vector::ContractionOp>(op))
290     return contractSupportsMMAMatrixType(contract, useNvGpu);
291   if (auto constant = dyn_cast<arith::ConstantOp>(op))
292     return constantSupportsMMAMatrixType(constant);
293   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
294     return broadcastSupportsMMAMatrixType(broadcast);
295   if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
296     return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
297   if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
298     return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
299   if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
300     return fpExtendSupportsMMAMatrixType(fpExtend);
301   return elementwiseSupportsMMAMatrixType(op);
302 }
303 
304 /// Return an unsorted slice, handling the scf.for region differently than
305 /// `getSlice`. For scf.for we only want to include, as part of the slice,
306 /// elements that are part of the use/def chain.
307 static SetVector<Operation *>
308 getSliceContract(Operation *op,
309                  const BackwardSliceOptions &backwardSliceOptions,
310                  const ForwardSliceOptions &forwardSliceOptions) {
311   SetVector<Operation *> slice;
312   slice.insert(op);
313   unsigned currentIndex = 0;
314   SetVector<Operation *> backwardSlice;
315   SetVector<Operation *> forwardSlice;
316   while (currentIndex != slice.size()) {
317     auto *currentOp = (slice)[currentIndex];
318     // Compute and insert the backwardSlice starting from currentOp.
319     backwardSlice.clear();
320     getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
321     slice.insert(backwardSlice.begin(), backwardSlice.end());
322 
323     // Compute and insert the forwardSlice starting from currentOp.
324     forwardSlice.clear();
325     // Special case for ForOp: we don't want to include the whole region, but
326     // only the values using the region arguments.
327     // TODO: We should refine this to only care about the region arguments being
328     // converted to matrix type.
329     if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
330       for (Value forOpResult : forOp.getResults())
331         getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
332       for (BlockArgument &arg : forOp.getRegionIterArgs())
333         getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
334     } else {
335       getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
336     }
337     slice.insert(forwardSlice.begin(), forwardSlice.end());
338     ++currentIndex;
339   }
340   return slice;
341 }
342 
343 // Analyze the slice of operations rooted at each vector.contract op to figure
344 // out if the whole slice can be converted to MMA operations.
345 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
346                                              bool useNvGpu) {
347   auto hasVectorDest = [](Operation *op) {
348     return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
349   };
350   BackwardSliceOptions backwardSliceOptions;
351   backwardSliceOptions.filter = hasVectorDest;
352 
353   auto hasVectorSrc = [](Operation *op) {
354     return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
355   };
356   ForwardSliceOptions forwardSliceOptions;
357   forwardSliceOptions.filter = hasVectorSrc;
358 
359   SetVector<Operation *> opToConvert;
360   op->walk([&](vector::ContractionOp contract) {
361     if (opToConvert.contains(contract.getOperation()))
362       return;
363     SetVector<Operation *> dependentOps =
364         getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
365     // If any operation cannot use the MMA matrix type, drop the whole
366     // chain. MMA matrices are stored in an opaque type, so they cannot be used
367     // by all operations.
368     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
369           if (!supportsMMaMatrixType(op, useNvGpu)) {
370             LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
371             return true;
372           }
373           return false;
374         }))
375       return;
376 
377     opToConvert.insert(dependentOps.begin(), dependentOps.end());
378   });
379   // Sort the operations so that we can convert them in topological order.
380   return topologicalSort(opToConvert);
381 }
382 
383 namespace {
384 // Transform a contraction into (m, k) x (k, n) x (m, n) form so that it can be
385 // converted to an MMA matmul.
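//
// For instance (hypothetical example, for illustration only), a contraction
// with indexing maps {(m, k), (n, k), (m, n)} is rewritten by transposing the
// RHS:
//
//   %rhsT = vector.transpose %rhs, [1, 0]
//             : vector<16x8xf16> to vector<8x16xf16>
//
// and emitting a new vector.contract with the canonical
// {(m, k), (k, n), (m, n)} maps.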
386 struct PrepareContractToGPUMMA
387     : public OpRewritePattern<vector::ContractionOp> {
388   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
389 
390   LogicalResult matchAndRewrite(vector::ContractionOp op,
391                                 PatternRewriter &rewriter) const override {
392     Location loc = op.getLoc();
393     Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
394 
395     // Set up the parallel/reduction structure in right form.
396     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
397     auto infer = [&](MapList m) {
398       return AffineMap::inferFromExprList(m, op.getContext());
399     };
400     AffineExpr m, n, k;
401     bindDims(rewriter.getContext(), m, n, k);
402     static constexpr std::array<int64_t, 2> perm = {1, 0};
403     auto iteratorTypes = op.getIteratorTypes().getValue();
404     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
405     if (!(vector::isParallelIterator(iteratorTypes[0]) &&
406           vector::isParallelIterator(iteratorTypes[1]) &&
407           vector::isReductionIterator(iteratorTypes[2])))
408       return rewriter.notifyMatchFailure(op, "not a gemm contraction");
409     //
410     // Two outer parallel, one inner reduction (matmat flavor).
411     //
412     // This is the classical row-major matmul, nothing to do.
413     if (maps == infer({{m, k}, {k, n}, {m, n}}))
414       return rewriter.notifyMatchFailure(op, "contraction already prepared");
415     if (maps == infer({{m, k}, {n, k}, {m, n}})) {
416       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
417     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
418       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
419     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
420       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
421       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
422     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
423       std::swap(rhs, lhs);
424       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
425       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
426     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
427       std::swap(rhs, lhs);
428       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
429     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
430       std::swap(lhs, rhs);
431       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
432     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
433       std::swap(lhs, rhs);
434     } else {
435       // TODO: llvm_unreachable ?
436       return rewriter.notifyMatchFailure(op, "unexpected contraction case");
437     }
438     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
439         op, lhs, rhs, res,
440         rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
441         op.getIteratorTypes());
442     return success();
443   }
444 };
445 
446 // Fold a transpose op into the transfer read op. The NVGPU mma.sync op only
447 // supports row-, column-, and row-major layouts for matrixA, matrixB, and
448 // matrixC, respectively, so the transpose can be folded when loading the data
449 // from shared memory into registers.
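//
// Sketch of the intended folding (hypothetical IR, for illustration only):
//
//   %a = vector.transfer_read %mem[%i, %j], %pad
//          : memref<?x?xf16>, vector<4x8xf16>
//   %b = vector.transpose %a, [1, 0] : vector<4x8xf16> to vector<8x4xf16>
//
// becomes a single vector.transfer_read with the composed permutation map
// affine_map<(d0, d1) -> (d1, d0)>, so the transpose is realized by the load
// itself (possibly looking through an intervening arith.ext* op).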
450 struct CombineTransferReadOpTranspose final
451     : public OpRewritePattern<vector::TransposeOp> {
452   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
453 
454   LogicalResult matchAndRewrite(vector::TransposeOp op,
455                                 PatternRewriter &rewriter) const override {
456     // Look through integer extend ops.
457     Value source = op.getVector();
458     Type resultType = op.getType();
459     Operation *extOp;
460     if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
461         (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
462         (extOp = source.getDefiningOp<arith::ExtFOp>())) {
463       source = extOp->getOperand(0);
464       resultType =
465           VectorType::get(cast<VectorType>(resultType).getShape(),
466                           cast<VectorType>(source.getType()).getElementType());
467     }
468 
469     auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
470     if (!transferReadOp)
471       return rewriter.notifyMatchFailure(op, "no transfer read");
472 
473     // TODO: support 0-d corner case.
474     if (transferReadOp.getTransferRank() == 0)
475       return rewriter.notifyMatchFailure(op, "0-D transfer read");
476 
477     if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
478       return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
479 
480     AffineMap permutationMap =
481         AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
482     AffineMap newMap =
483         permutationMap.compose(transferReadOp.getPermutationMap());
484 
485     auto loc = op.getLoc();
486     Value result =
487         rewriter
488             .create<vector::TransferReadOp>(
489                 loc, resultType, transferReadOp.getSource(),
490                 transferReadOp.getIndices(), AffineMapAttr::get(newMap),
491                 transferReadOp.getPadding(), transferReadOp.getMask(),
492                 transferReadOp.getInBoundsAttr())
493             .getResult();
494 
495     // Fuse through the integer extend op.
496     if (extOp) {
497       if (isa<arith::ExtSIOp>(extOp))
498         result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
499                      .getResult();
500       else if (isa<arith::ExtUIOp>(extOp))
501         result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
502                      .getResult();
503       else
504         result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
505                      .getResult();
506     }
507 
508     rewriter.replaceOp(op, result);
509     return success();
510   }
511 };
512 
513 } // namespace
514 
515 // MMA types have different layouts based on how they are used in matmul ops.
516 // Figure out the right layout to use by looking at op uses.
517 // TODO: Change the GPU dialect to abstract the layout at this level and
518 // only care about it during lowering to NVVM.
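//
// For example (assumed scenario, not from the original source): a value that
// feeds the LHS of a vector.contract maps to "AOp", a value feeding the RHS
// maps to "BOp", and anything else, including the accumulator, maps to "COp".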
519 static const char *inferFragType(Operation *op) {
520   // We can have arith.ext ops before reaching contract ops. See through them
521   // and other kinds of elementwise ops.
522   if (op->hasOneUse()) {
523     Operation *userOp = *op->user_begin();
524     if (userOp->hasTrait<OpTrait::Elementwise>())
525       return inferFragType(userOp);
526   }
527 
528   for (Operation *users : op->getUsers()) {
529     auto contract = dyn_cast<vector::ContractionOp>(users);
530     if (!contract)
531       continue;
532     assert(op->getNumResults() == 1);
533     if (contract.getLhs() == op->getResult(0))
534       return "AOp";
535     if (contract.getRhs() == op->getResult(0))
536       return "BOp";
537   }
538   return "COp";
539 }
540 
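/// Convert a 2-D vector.transfer_read into a gpu.subgroup_mma_load_matrix op
/// (WMMA path). A minimal sketch of the rewrite, assuming a row-major f16 load
/// used as the A operand (hypothetical IR, for illustration only):
///
///   %v = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true, true]}
///          : memref<16x16xf16>, vector<16x16xf16>
///
/// becomes
///
///   %m = gpu.subgroup_mma_load_matrix %mem[%i, %j]
///          {leadDimension = 16 : index}
///          : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">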
541 static LogicalResult
542 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
543                       llvm::DenseMap<Value, Value> &valueMapping) {
544   OpBuilder::InsertionGuard g(rewriter);
545   rewriter.setInsertionPoint(op);
546 
547   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
548   assert(transferReadSupportsMMAMatrixType(op) &&
549          "expected convertible operation");
550 
551   std::optional<int64_t> stride =
552       getStaticallyKnownRowStride(op.getShapedType());
553   if (!stride.has_value()) {
554     LLVM_DEBUG(DBGS() << "no stride\n");
555     return rewriter.notifyMatchFailure(op, "no stride");
556   }
557 
558   AffineMap map = op.getPermutationMap();
559   bool isTranspose = isTransposeMatrixLoadMap(map);
560 
561   // Handle broadcast by setting the stride to 0.
562   if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
563     assert(cstExpr.getValue() == 0);
564     stride = 0;
565   }
566 
567   Value mappingResult = op.getResult();
568   auto elType = op.getVectorType().getElementType();
569   const char *fragType = inferFragType(op);
570   if (op->hasOneUse()) {
571     auto *user = *op->user_begin();
572     // Infer the signedness of the mma type from the integer extend.
573     if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
574       elType = IntegerType::get(
575           op.getContext(), cast<IntegerType>(elType).getWidth(),
576           isa<arith::ExtSIOp>(user) ? IntegerType::Signed
577                                     : IntegerType::Unsigned);
578       mappingResult = user->getResult(0);
579     }
580   }
581   gpu::MMAMatrixType type =
582       gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
583   Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
584       op.getLoc(), type, op.getSource(), op.getIndices(),
585       rewriter.getIndexAttr(*stride),
586       isTranspose ? rewriter.getUnitAttr() : UnitAttr());
587   valueMapping[mappingResult] = load;
588 
589   LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
590   return success();
591 }
592 
593 static LogicalResult
594 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
595                        llvm::DenseMap<Value, Value> &valueMapping) {
596   OpBuilder::InsertionGuard g(rewriter);
597   rewriter.setInsertionPoint(op);
598 
599   assert(transferWriteSupportsMMAMatrixType(op));
600   std::optional<int64_t> stride =
601       getStaticallyKnownRowStride(op.getShapedType());
602   if (!stride.has_value()) {
603     LLVM_DEBUG(DBGS() << "no stride\n");
604     return rewriter.notifyMatchFailure(op, "no stride");
605   }
606 
607   auto it = valueMapping.find(op.getVector());
608   if (it == valueMapping.end()) {
609     LLVM_DEBUG(DBGS() << "no mapping\n");
610     return rewriter.notifyMatchFailure(op, "no mapping");
611   }
612 
613   Value matrix = it->second;
614   auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
615       op.getLoc(), matrix, op.getSource(), op.getIndices(),
616       rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
617   (void)store;
618 
619   LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
620 
621   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
622   rewriter.eraseOp(op);
623   return success();
624 }
625 
626 /// Returns the vector type which represents a matrix fragment.
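///
/// For instance (hypothetical values, for illustration only): a fragment with
/// numRegistersPerFragment = 2, elementsPerRegister = 2, and a
/// registerLLVMType of vector<2xf16> maps to the per-thread operand type
/// vector<2x2xf16>.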
627 static VectorType
628 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
629   SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
630                              regInfo.elementsPerRegister};
631   Type elType = regInfo.registerLLVMType;
632   if (auto vecType = dyn_cast<VectorType>(elType))
633     elType = vecType.getElementType();
634   return VectorType::get(shape, elType);
635 }
636 
637 /// Convert a 2D splat ConstantOp to a splat of the mma.sync fragment type.
638 static LogicalResult
639 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
640                          llvm::DenseMap<Value, Value> &valueMapping) {
641   OpBuilder::InsertionGuard g(rewriter);
642   rewriter.setInsertionPoint(op);
643 
644   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
645       nvgpu::getWarpMatrixInfo(op);
646   if (failed(warpMatrixInfo)) {
647     LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
648     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
649   }
650 
651   FailureOr<nvgpu::FragmentElementInfo> regInfo =
652       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
653   if (failed(regInfo)) {
654     LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
655     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
656   }
657 
658   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
659   auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
660   if (!dense) {
661     LLVM_DEBUG(DBGS() << "not a splat\n");
662     return rewriter.notifyMatchFailure(op, "not a splat");
663   }
664 
665   Value result = rewriter.create<arith::ConstantOp>(
666       op.getLoc(), vectorType,
667       DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
668   valueMapping[op.getResult()] = result;
669   return success();
670 }
671 
672 /// Check if the loaded matrix operand requires a transpose.
673 /// Transposed map examples:
674 /// Example 1   : (..., d0, d1) -> (d1 * 1, d0 * 2)
675 /// Example 2   : (d0, d1, d2, d3) -> (d3, d2)
676 /// The code below checks whether the 2-D output is transposed, using the
677 /// generalized form: (d0, d1, dn, ..., dm, ...) -> (dm, dn)
678 /// Returns true if m > n, false otherwise.
679 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
680   mlir::AffineMap map = op.getPermutationMap();
681 
682   if (map.getNumResults() != 2) {
683     LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
684                          "is not a 2d operand\n");
685     return failure();
686   }
687 
688   // Output 2D matrix dimensions in the order of d0, d1.
689   mlir::AffineExpr dM = map.getResult(0);
690   mlir::AffineExpr dN = map.getResult(1);
691 
692   //  Find the position of these expressions in the input.
693   auto exprM = dyn_cast<AffineDimExpr>(dM);
694   auto exprN = dyn_cast<AffineDimExpr>(dN);
695 
696   if (!exprM || !exprN) {
697     LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
698                          "expressions, then transpose cannot be determined.\n");
699     return failure();
700   }
701 
702   return exprM.getPosition() > exprN.getPosition();
703 }
704 
705 static LogicalResult
706 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
707                              llvm::DenseMap<Value, Value> &valueMapping) {
708   OpBuilder::InsertionGuard g(rewriter);
709   rewriter.setInsertionPoint(op);
710   Location loc = op->getLoc();
711 
712   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
713       nvgpu::getWarpMatrixInfo(op);
714   if (failed(warpMatrixInfo)) {
715     LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
716     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
717   }
718 
719   FailureOr<nvgpu::FragmentElementInfo> regInfo =
720       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
721   if (failed(regInfo)) {
722     LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
723     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
724   }
725 
726   FailureOr<bool> transpose = isTransposed(op);
727   if (failed(transpose)) {
728     LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
729     return rewriter.notifyMatchFailure(
730         op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
731   }
732 
733   FailureOr<nvgpu::LdMatrixParams> params =
734       nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
735 
736   if (failed(params)) {
737     LLVM_DEBUG(
738         DBGS()
739         << "failed to convert vector.transfer_read to ldmatrix. "
740         << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
741     return rewriter.notifyMatchFailure(
742         op, "failed to convert vector.transfer_read to ldmatrix; this op "
743             "likely should not be converted to a nvgpu.ldmatrix call.");
744   }
745 
746   // Adjust the load offset.
747   auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
748   FailureOr<AffineMap> offsets =
749       nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
750   if (failed(offsets)) {
751     LLVM_DEBUG(DBGS() << "no offsets\n");
752     return rewriter.notifyMatchFailure(op, "no offsets");
753   }
754 
755   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
756 
757   SmallVector<Value, 4> indices;
758   getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
759                                          indices);
760 
761   nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
762       loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
763   valueMapping[op] = newOp->getResult(0);
764   return success();
765 }
766 
767 static LogicalResult
768 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
769                        llvm::DenseMap<Value, Value> &valueMapping) {
770   OpBuilder::InsertionGuard g(rewriter);
771   rewriter.setInsertionPoint(op);
772 
773   Location loc = op.getLoc();
774   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
775       nvgpu::getWarpMatrixInfo(op);
776   if (failed(warpMatrixInfo))
777     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
778   FailureOr<nvgpu::FragmentElementInfo> regInfo =
779       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
780   if (failed(regInfo)) {
781     return rewriter.notifyMatchFailure(
782         op, "Failed to deduce register fragment type during "
783             "conversion to distributed non-ldmatrix compatible load");
784   }
785 
786   Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
787   SmallVector<Value, 4> elements;
788 
789   // This is the individual element type.
790   Type loadedElType = regInfo->registerLLVMType;
791   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
792 
793   Value fill = rewriter.create<arith::ConstantOp>(
794       op.getLoc(), vectorType.getElementType(),
795       rewriter.getZeroAttr(vectorType.getElementType()));
796   Value result =
797       rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
798 
799   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
800 
801   // If we are not transposing, then we can use vectorized loads. Otherwise, we
802   // must load each element individually.
803   if (!isTransposeLoad) {
804     if (!isa<VectorType>(loadedElType)) {
805       loadedElType = VectorType::get({1}, loadedElType);
806     }
807 
808     for (int i = 0; i < vectorType.getShape()[0]; i++) {
809       FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
810           rewriter, op.getLoc(), *warpMatrixInfo);
811       if (failed(coords))
812         return rewriter.notifyMatchFailure(op, "no coords");
813 
814       Value logicalValueId = rewriter.create<arith::ConstantOp>(
815           loc, rewriter.getIndexType(),
816           rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
817       SmallVector<Value, 4> newIndices;
818       getXferIndices<vector::TransferReadOp>(
819           rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
820 
821       Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
822                                                  op.getSource(), newIndices);
823       result = rewriter.create<vector::InsertOp>(loc, el, result, i);
824     }
825   } else {
826     if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
827       loadedElType = vecType.getElementType();
828     }
829     for (int i = 0; i < vectorType.getShape()[0]; i++) {
830       for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
831            innerIdx++) {
832 
833         Value logicalValueId = rewriter.create<arith::ConstantOp>(
834             loc, rewriter.getIndexType(),
835             rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
836         FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
837             rewriter, op.getLoc(), *warpMatrixInfo);
838         if (failed(coords))
839           return rewriter.notifyMatchFailure(op, "no coords");
840 
841         SmallVector<Value, 4> newIndices;
842         getXferIndices<vector::TransferReadOp>(
843             rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
844         Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
845                                                    op.getSource(), newIndices);
846         result = rewriter.create<vector::InsertOp>(
847             op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
848       }
849     }
850   }
851 
852   valueMapping[op.getResult()] = result;
853   return success();
854 }
855 
856 /// Return true if this is a shared memory memref type.
857 static bool isSharedMemory(MemRefType type) {
858   auto addressSpace =
859       dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
860   return addressSpace &&
861          addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
862 }
863 
864 /// Converts a `vector.transfer_read` operation directly to either a
865 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
866 /// used when converting to `nvgpu.mma.sync` operations.
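///
/// Rough decision sketch (summarizing the code below, not additional
/// behavior): reads from workgroup (shared) memory whose inferred tile width
/// is 128 bits are lowered to nvgpu.ldmatrix; transposed reads only qualify
/// for 16-bit element types with large enough tiles; everything else falls
/// back to per-thread vector.load / memref.load based loads.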
867 static LogicalResult
868 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
869                            llvm::DenseMap<Value, Value> &valueMapping) {
870   OpBuilder::InsertionGuard g(rewriter);
871   rewriter.setInsertionPoint(op);
872 
873   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
874       nvgpu::getWarpMatrixInfo(op);
875   if (failed(warpMatrixInfo))
876     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
877 
878   bool isLdMatrixCompatible =
879       isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
880       nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
881 
882   VectorType vecTy = op.getVectorType();
883   int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
884 
885   // When we are transposing the B operand, ldmatrix will only work if we have
886   // at least 8 rows to read and the width to read for the transpose is 128
887   // bits.
888   if (!op.getPermutationMap().isMinorIdentity() &&
889       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
890        vecTy.getDimSize(0) * bitWidth < 128))
891     isLdMatrixCompatible = false;
892 
893   if (!isLdMatrixCompatible)
894     return createNonLdMatrixLoads(rewriter, op, valueMapping);
895 
896   return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
897 }
898 
899 static LogicalResult
900 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
901                              llvm::DenseMap<Value, Value> &valueMapping) {
902   OpBuilder::InsertionGuard g(rewriter);
903   rewriter.setInsertionPoint(op);
904 
905   Location loc = op->getLoc();
906   auto it = valueMapping.find(op.getVector());
907   if (it == valueMapping.end())
908     return rewriter.notifyMatchFailure(op, "no mapping");
909   Value matrix = it->second;
910 
911   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
912       nvgpu::getWarpMatrixInfo(op);
913   if (failed(warpMatrixInfo))
914     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
915   FailureOr<nvgpu::FragmentElementInfo> regInfo =
916       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
917   if (failed(regInfo))
918     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
919 
920   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
921   Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
922 
923   for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
924     Value logicalValueId = rewriter.create<arith::ConstantOp>(
925         loc, rewriter.getIndexType(),
926         rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
927     FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
928         rewriter, op.getLoc(), *warpMatrixInfo);
929     if (failed(coords))
930       return rewriter.notifyMatchFailure(op, "no coords");
931 
932     Value el =
933         rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
934     SmallVector<Value, 4> newIndices;
935     getXferIndices<vector::TransferWriteOp>(
936         rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
937     rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
938   }
939 
940   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
941   rewriter.eraseOp(op);
942   return success();
943 }
944 
945 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
946                                        SmallVectorImpl<int64_t> &results) {
947   for (auto attr : arrayAttr)
948     results.push_back(cast<IntegerAttr>(attr).getInt());
949 }
950 
951 static LogicalResult
952 convertExtractStridedSlice(RewriterBase &rewriter,
953                            vector::ExtractStridedSliceOp op,
954                            llvm::DenseMap<Value, Value> &valueMapping) {
955   OpBuilder::InsertionGuard g(rewriter);
956   rewriter.setInsertionPoint(op);
957 
958   Location loc = op->getLoc();
959 
960   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
961       nvgpu::getWarpMatrixInfo(op);
962   if (failed(warpMatrixInfo))
963     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
964 
965   FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
966       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
967   if (failed(mmaSyncFragmentInfo))
968     return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
969 
970   // Find the vector.transfer_read whose result vector is being sliced.
971   auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
972   if (!transferReadOp)
973     return rewriter.notifyMatchFailure(op, "no transfer read");
974 
975   warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
976   if (failed(warpMatrixInfo))
977     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
978 
979   FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
980       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
981   if (failed(ldFragmentInfo))
982     return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
983 
984   assert(
985       (mmaSyncFragmentInfo->elementsPerRegister ==
986        ldFragmentInfo->elementsPerRegister) &&
987       "Number of elements per register should be same for load and mma.sync");
988 
989   // Create vector.extract_strided_slice op for thread-owned fragments.
990   // The stride for extract_strided_slice is always 1.
991   std::array<int64_t, 2> strides = {1, 1};
992   std::array<int64_t, 2> sliceShape = {
993       mmaSyncFragmentInfo->numRegistersPerFragment,
994       mmaSyncFragmentInfo->elementsPerRegister};
995   auto it = valueMapping.find(transferReadOp);
996   if (it == valueMapping.end())
997     return rewriter.notifyMatchFailure(op, "no mapping");
998   auto sourceVector = it->second;
999 
1000   // Offsets and sizes at the warp level of ownership.
1001   SmallVector<int64_t> offsets;
1002   populateFromInt64AttrArray(op.getOffsets(), offsets);
1003 
1004   SmallVector<int64_t> sizes;
1005   populateFromInt64AttrArray(op.getSizes(), sizes);
1006   ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1007 
1008   // Compute the offset in vector registers. Note that the mma.sync vector
1009   // registers are shaped as numberOfFragments x numberOfRegistersPerFragment.
1010   // They can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1011   std::array<int64_t, 2> sliceOffset = {0, 0};
1012 
1013   if (offsets[0] && offsets[1])
1014     return op->emitError() << "Slicing fragments in 2D is not supported.";
1015   if (offsets[0])
1016     sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1017   else if (offsets[1])
1018     sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1019 
1020   Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1021       loc, sourceVector, sliceOffset, sliceShape, strides);
1022 
1023   valueMapping[op] = newOp;
1024   return success();
1025 }
1026 
1027 static LogicalResult
1028 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1029                   llvm::DenseMap<Value, Value> &valueMapping) {
1030   OpBuilder::InsertionGuard g(rewriter);
1031   rewriter.setInsertionPoint(op);
1032 
1033   auto itA = valueMapping.find(op.getLhs());
1034   auto itB = valueMapping.find(op.getRhs());
1035   auto itC = valueMapping.find(op.getAcc());
1036   if (itA == valueMapping.end() || itB == valueMapping.end() ||
1037       itC == valueMapping.end())
1038     return rewriter.notifyMatchFailure(op, "no mapping");
1039   Value opA = itA->second, opB = itB->second, opC = itC->second;
1040   Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1041       op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1042       /*b_transpose=*/UnitAttr());
1043   valueMapping[op.getResult()] = matmul;
1044   return success();
1045 }
1046 
1047 static LogicalResult
1048 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1049                            llvm::DenseMap<Value, Value> &valueMapping) {
1050   OpBuilder::InsertionGuard g(rewriter);
1051   rewriter.setInsertionPoint(op);
1052 
1053   auto itA = valueMapping.find(op.getLhs());
1054   auto itB = valueMapping.find(op.getRhs());
1055   auto itC = valueMapping.find(op.getAcc());
1056   if (itA == valueMapping.end() || itB == valueMapping.end() ||
1057       itC == valueMapping.end())
1058     return rewriter.notifyMatchFailure(op, "no mapping");
1059   Value opA = itA->second, opB = itB->second, opC = itC->second;
1060   int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1061   int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1062   int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1063   Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1064       op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1065   valueMapping[op.getResult()] = matmul;
1066   return success();
1067 }
1068 
1069 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1070 static LogicalResult
1071 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1072                   llvm::DenseMap<Value, Value> &valueMapping) {
1073   OpBuilder::InsertionGuard g(rewriter);
1074   rewriter.setInsertionPoint(op);
1075 
1076   assert(constantSupportsMMAMatrixType(op));
1077 
1078   auto splat =
1079       cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1080   auto scalarConstant =
1081       rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1082   const char *fragType = inferFragType(op);
1083   auto vecType = cast<VectorType>(op.getType());
1084   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1085       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1086   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1087       op.getLoc(), type, scalarConstant);
1088   valueMapping[op.getResult()] = matrix;
1089   return success();
1090 }
1091 
1092 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1093 static LogicalResult
1094 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1095                    llvm::DenseMap<Value, Value> &valueMapping) {
1096   OpBuilder::InsertionGuard g(rewriter);
1097   rewriter.setInsertionPoint(op);
1098 
1099   assert(broadcastSupportsMMAMatrixType(op));
1100 
1101   const char *fragType = inferFragType(op);
1102   auto vecType = op.getResultVectorType();
1103   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1104       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1105   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1106       op.getLoc(), type, op.getSource());
1107   valueMapping[op.getResult()] = matrix;
1108   return success();
1109 }
1110 
1111 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1112 // updated and needs to be updated separately for the loop to be correct.
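//
// For example (hypothetical, for illustration only): given
//
//   scf.for ... iter_args(%acc = %init) -> (vector<16x16xf16>)
//
// and one extra init value of type !gpu.mma_matrix<16x16xf16, "COp">, the new
// loop carries both values; the original results are replaced by the leading
// results of the new loop, and the extra result/argument is wired up later by
// convertForOp / convertYieldOp.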
1113 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1114                                                scf::ForOp loop,
1115                                                ValueRange newInitArgs) {
1116   OpBuilder::InsertionGuard g(rewriter);
1117   rewriter.setInsertionPoint(loop);
1118 
1119   // Create a new loop before the existing one, with the extra operands.
1120   rewriter.setInsertionPoint(loop);
1121   auto operands = llvm::to_vector<4>(loop.getInitArgs());
1122   llvm::append_range(operands, newInitArgs);
1123   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1124       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
1125       operands);
1126   rewriter.eraseBlock(newLoop.getBody());
1127 
1128   newLoop.getRegion().getBlocks().splice(
1129       newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1130   for (Value operand : newInitArgs)
1131     newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1132 
1133   for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1134                                                   loop.getNumResults())))
1135     rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1136 
1137   LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1138   LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1139   LLVM_DEBUG(DBGS() << "erase: " << loop);
1140 
1141   rewriter.eraseOp(loop);
1142   return newLoop;
1143 }
1144 
1145 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1146                                   llvm::DenseMap<Value, Value> &valueMapping) {
1147   OpBuilder::InsertionGuard g(rewriter);
1148   rewriter.setInsertionPoint(op);
1149 
1150   SmallVector<Value> newOperands;
1151   SmallVector<std::pair<size_t, size_t>> argMapping;
1152   for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1153     auto it = valueMapping.find(operand.value());
1154     if (it == valueMapping.end()) {
1155       LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1156       continue;
1157     }
1158     argMapping.push_back(std::make_pair(
1159         operand.index(), op.getInitArgs().size() + newOperands.size()));
1160     newOperands.push_back(it->second);
1161   }
1162 
1163   scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1164   Block &loopBody = *newForOp.getBody();
1165   for (auto mapping : argMapping) {
1166     valueMapping[newForOp.getResult(mapping.first)] =
1167         newForOp.getResult(mapping.second);
1168     valueMapping[loopBody.getArgument(mapping.first +
1169                                       newForOp.getNumInductionVars())] =
1170         loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1171   }
1172 
1173   LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1174   return success();
1175 }
1176 
1177 static LogicalResult
1178 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1179                llvm::DenseMap<Value, Value> &valueMapping) {
1180   OpBuilder::InsertionGuard g(rewriter);
1181   rewriter.setInsertionPoint(op);
1182 
1183   auto loop = cast<scf::ForOp>(op->getParentOp());
1184   auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1185   for (const auto &operand : llvm::enumerate(op.getOperands())) {
1186     auto it = valueMapping.find(operand.value());
1187     if (it == valueMapping.end())
1188       continue;
1189     // Replace the yield of the old value with the loop's init argument to make
1190     // it easier to remove the dead code.
1191     yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1192     yieldOperands.push_back(it->second);
1193   }
1194   rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1195 
1196   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1197   rewriter.eraseOp(op);
1198   return success();
1199 }
1200 
1201 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1202 static LogicalResult
1203 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1204                      gpu::MMAElementwiseOp opType,
1205                      llvm::DenseMap<Value, Value> &valueMapping) {
1206   OpBuilder::InsertionGuard g(rewriter);
1207   rewriter.setInsertionPoint(op);
1208 
1209   SmallVector<Value> matrixOperands;
1210   for (Value operand : op->getOperands()) {
1211     auto it = valueMapping.find(operand);
1212     if (it == valueMapping.end())
1213       return rewriter.notifyMatchFailure(op, "no mapping");
1214     matrixOperands.push_back(it->second);
1215   }
1216   auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1217   if (opType == gpu::MMAElementwiseOp::EXTF) {
1218     // The floating point extension case has a different result type.
1219     auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1220     resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1221                                          vectorType.getElementType(),
1222                                          resultType.getOperand());
1223   }
1224 
1225   Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1226       op->getLoc(), resultType, matrixOperands, opType);
1227   valueMapping[op->getResult(0)] = newOp;
1228   return success();
1229 }
1230 
1231 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1232                                               bool useNvGpu) {
1233   if (!useNvGpu) {
1234     patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1235         patterns.getContext());
1236     return;
1237   }
1238   vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1239   patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1240 }
1241 
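// A typical way to exercise the patterns and the conversion together (a
// hedged sketch; the ConvertVectorToGPUPass below does the equivalent):
//
//   RewritePatternSet patterns(ctx);
//   populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
//   (void)applyPatternsGreedily(op, std::move(patterns));
//   IRRewriter rewriter(ctx);
//   (void)convertVectorToMMAOps(rewriter, op);
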
1242 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1243                                           Operation *rootOp) {
1244   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1245   llvm::DenseMap<Value, Value> valueMapping;
1246 
1247   auto globalRes = LogicalResult::success();
1248   for (Operation *op : ops) {
1249     LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1250     // Apparently callers do not want to early exit on failure here.
1251     auto res = LogicalResult::success();
1252     if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1253       res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1254     } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1255       res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1256     } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1257       res = convertContractOp(rewriter, contractOp, valueMapping);
1258     } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1259       res = convertConstantOp(rewriter, constantOp, valueMapping);
1260     } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1261       res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1262     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1263       res = convertForOp(rewriter, forOp, valueMapping);
1264     } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1265       res = convertYieldOp(rewriter, yieldOp, valueMapping);
1266     } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1267       res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1268     }
1269     if (failed(res))
1270       globalRes = failure();
1271   }
1272   return globalRes;
1273 }
1274 
1275 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1276                                                          Operation *rootOp) {
1277   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1278   llvm::DenseMap<Value, Value> valueMapping;
1279   for (Operation *op : ops) {
1280     if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1281             .Case([&](vector::TransferReadOp transferReadOp) {
1282               return convertTransferReadToLoads(rewriter, transferReadOp,
1283                                                 valueMapping);
1284             })
1285             .Case([&](vector::TransferWriteOp transferWriteOp) {
1286               return convertTransferWriteToStores(rewriter, transferWriteOp,
1287                                                   valueMapping);
1288             })
1289             .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1290               return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1291                                                 valueMapping);
1292             })
1293             .Case([&](vector::ContractionOp contractionOp) {
1294               return convertContractOpToMmaSync(rewriter, contractionOp,
1295                                                 valueMapping);
1296             })
1297             .Case([&](scf::ForOp forOp) {
1298               return convertForOp(rewriter, forOp, valueMapping);
1299             })
1300             .Case([&](scf::YieldOp yieldOp) {
1301               return convertYieldOp(rewriter, yieldOp, valueMapping);
1302             })
1303             .Case([&](arith::ConstantOp constOp) {
1304               return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1305             })
1306             .Default([&](Operation *op) {
1307               return op->emitError() << "unhandled vector to mma type: " << *op;
1308             })
1309             .failed()) {
1310       return op->emitOpError()
1311              << "failed to convert op during vector-to-nvgpu conversion";
1312     }
1313   }
1314   return success();
1315 }
1316 
1317 namespace {
1318 
1319 struct ConvertVectorToGPUPass
1320     : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1321 
1322   explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1323     useNvGpu.setValue(useNvGpu_);
1324   }
1325 
1326   void runOnOperation() override {
1327     RewritePatternSet patterns(&getContext());
1328     populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1329     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
1330       return signalPassFailure();
1331 
1332     IRRewriter rewriter(&getContext());
1333     if (useNvGpu) {
1334       if (failed(
1335               convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1336         return signalPassFailure();
1337       return;
1338     }
1339     (void)convertVectorToMMAOps(rewriter, getOperation());
1340   }
1341 };
1342 
1343 } // namespace
1344 
1345 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1346   return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1347 }
1348