1 //===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
19 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
20 #include "mlir/Dialect/Linalg/IR/Linalg.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
23 #include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
26 #include "mlir/Dialect/Utils/IndexingUtils.h"
27 #include "mlir/Dialect/Utils/StaticValueUtils.h"
28 #include "mlir/Dialect/Vector/IR/VectorOps.h"
29 #include "mlir/IR/AffineExpr.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/Value.h"
32 #include "llvm/ADT/ArrayRef.h"
33 
34 using namespace mlir;
35 using namespace mlir::linalg;
36 using namespace mlir::nvgpu;
37 using namespace mlir::NVVM;
38 using namespace mlir::transform;
39 
40 #define DEBUG_TYPE "nvgpu-transforms"
41 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
42 #define DBGSNL() (llvm::dbgs() << "\n")
43 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
44 
45 //===----------------------------------------------------------------------===//
46 // Apply...ConversionPatternsOp
47 //===----------------------------------------------------------------------===//
48 
49 void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
50     TypeConverter &typeConverter, RewritePatternSet &patterns) {
51   auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
52   // Device-side async tokens cannot be materialized in NVVM. The conversion
53   // registered below turns them into a dummy i32 type so that they can be
54   // easily dropped during conversion.
55   populateGpuMemorySpaceAttributeConversions(
56       llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
57         switch (space) {
58         case gpu::AddressSpace::Global:
59           return static_cast<unsigned>(
60               NVVM::NVVMMemorySpace::kGlobalMemorySpace);
61         case gpu::AddressSpace::Workgroup:
62           return static_cast<unsigned>(
63               NVVM::NVVMMemorySpace::kSharedMemorySpace);
64         case gpu::AddressSpace::Private:
65           return 0;
66         }
67         llvm_unreachable("unknown address space enum value");
68         return 0;
69       });
70   llvmTypeConverter.addConversion(
71       [&](nvgpu::DeviceAsyncTokenType type) -> Type {
72         return llvmTypeConverter.convertType(
73             IntegerType::get(type.getContext(), 32));
74       });
75   llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
76     return llvmTypeConverter.convertType(
77         IntegerType::get(type.getContext(), 64));
78   });
79   llvmTypeConverter.addConversion(
80       [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
81         Type elemType = type.getFragmented().getElementType();
82         int64_t sizeM = type.getFragmented().getDimSize(0);
83         int64_t sizeN = type.getFragmented().getDimSize(1);
84 
85         unsigned numMembers;
86         if (elemType.isF32() || elemType.isInteger(32))
87           numMembers = sizeN / 2;
88         else if (elemType.isF16())
89           numMembers = sizeN / 4;
90         else
91           llvm_unreachable("unsupported type for warpgroup accumulator");
92 
93         SmallVector<Type> innerStructBody;
94         for (unsigned i = 0; i < numMembers; i++)
95           innerStructBody.push_back(elemType);
96         auto innerStructType = LLVM::LLVMStructType::getLiteral(
97             type.getContext(), innerStructBody);
98 
99         SmallVector<Type> structBody;
100         for (int i = 0; i < sizeM; i += kWgmmaSizeM)
101           structBody.push_back(innerStructType);
102 
103         auto convertedType =
104             LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
105         return llvmTypeConverter.convertType(convertedType);
106       });
107   llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
108     return llvmTypeConverter.convertType(
109         getMBarrierMemrefType(type.getContext(), type));
110   });
111   llvmTypeConverter.addConversion(
112       [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
113         return llvmTypeConverter.convertType(
114             IntegerType::get(type.getContext(), 64));
115       });
116   llvmTypeConverter.addConversion(
117       [&](nvgpu::TensorMapDescriptorType type) -> Type {
118         return LLVM::LLVMPointerType::get(type.getContext());
119       });
120   populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
121 }
122 
123 LogicalResult
124 transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
125     transform::TypeConverterBuilderOpInterface builder) {
126   if (builder.getTypeConverterType() != "LLVMTypeConverter")
127     return emitOpError("expected LLVMTypeConverter");
128   return success();
129 }
130 
131 //===---------------------------------------------------------------------===//
132 // CreateAsyncGroupsOp
133 //===---------------------------------------------------------------------===//
134 
135 void transform::CreateAsyncGroupsOp::getEffects(
136     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
137   transform::consumesHandle(getTargetMutable(), effects);
138   transform::producesHandle(getOperation()->getOpResults(), effects);
139   transform::modifiesPayload(effects);
140 }
141 
142 DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
143     TransformRewriter &rewriter, Operation *target,
144     ApplyToEachResultList &results, TransformState &state) {
145   nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
146   results.push_back(target);
147   return DiagnosedSilenceableFailure::success();
148 }
149 
150 //===----------------------------------------------------------------------===//
151 // PipelineSharedMemoryCopiesOp
152 //===----------------------------------------------------------------------===//
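//
// This op is driven from a transform dialect script. A minimal usage sketch
// (the assembly syntax below is illustrative and may not match the op's exact
// format):
//
//   %pipelined = transform.nvgpu.pipeline_shared_memory_copies
//       failures(suppress) %for {depth = 2 : i64}
//       : (!transform.op<"scf.for">) -> !transform.op<"scf.for">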
153 
154 /// Returns true if the given type has the default memory space.
155 static bool hasDefaultMemorySpace(BaseMemRefType type) {
156   return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
157 }
158 
159 /// Returns true if the given type has the shared (workgroup) memory space.
160 static bool hasSharedMemorySpace(BaseMemRefType type) {
161   auto space =
162       dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
163   return space &&
164          space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
165 }
166 
167 /// Returns the value produced by a load from the default memory space. Returns
168 /// null if the operation is not such a load.
169 static Value getValueLoadedFromGlobal(Operation *op) {
170   // TODO: consider an interface or leveraging the memory effects interface.
171   auto load = dyn_cast<vector::TransferReadOp>(op);
172   if (!load)
173     return nullptr;
174 
175   auto loadType = dyn_cast<MemRefType>(load.getSource().getType());
176   if (!loadType || !hasDefaultMemorySpace(loadType))
177     return nullptr;
178   return load;
179 }
180 
181 /// Returns true if the operation is storing the given value into shared memory.
182 static bool isStoreToShared(Operation *op, Value v) {
183   // TODO: consider an interface or leveraging the memory effects interface.
184   auto store = dyn_cast<vector::TransferWriteOp>(op);
185   if (!store || store.getVector() != v)
186     return false;
187 
188   auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
189   return storeType && hasSharedMemorySpace(storeType);
190 }
191 
192 /// Returns true if the operation is a load from the default memory space the
193 /// result of which is only stored into the shared memory space.
194 static bool isLoadFromGlobalStoredToShared(Operation *op) {
195   Value loaded = getValueLoadedFromGlobal(op);
196   if (!loaded || !loaded.hasOneUse())
197     return false;
198 
199   return isStoreToShared(*loaded.getUsers().begin(), loaded);
200 }
201 
202 /// Populate `ops` with the set of operations that belong to the stage 0 of the
203 /// pipelined version of the given loop when pipelining copies to shared memory.
204 /// Specifically, this collects:
205 ///
206 ///   1. all loads from global memory, both sync and async;
207 ///   2. the barriers for async loads.
208 ///
209 /// In particular, barriers are omitted if they do not dominate at least one
210 /// async load for which there is not yet a barrier.
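///
/// For example, given a loop body containing (schematically):
///
///   gpu.barrier
///   %token = nvgpu.device_async_copy %global[...], %shared[...], 4
///   nvgpu.device_async_create_group %token
///
/// the barrier and both nvgpu ops are collected as stage 0 operations.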
211 static LogicalResult
212 collectStage0PipeliningOps(scf::ForOp forOp,
213                            llvm::SmallPtrSet<Operation *, 16> &ops) {
214 
215   llvm::SmallPtrSet<Operation *, 4> barriers;
216   for (Operation &op : *forOp.getBody()) {
217     // Bail on nested ops for now.
218     if (op.getNumRegions() > 0)
219       return failure();
220 
221     if (isa<gpu::BarrierOp>(op)) {
222       barriers.insert(&op);
223       continue;
224     }
225 
226     if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
227       ops.insert(&op);
228       ops.insert(std::make_move_iterator(barriers.begin()),
229                  std::make_move_iterator(barriers.end()));
230       assert(barriers.empty() &&
231              "expected to have moved the barriers into another set");
232       continue;
233     }
234 
235     if (isLoadFromGlobalStoredToShared(&op)) {
236       ops.insert(&op);
237       continue;
238     }
239   }
240 
241   return success();
242 }
243 
244 /// Hook for the loop pipeliner that sets the "num groups in flight" attribute
245 /// of async wait operations corresponding to pipelined shared memory copies.
246 // TODO: this currently assumes that there are no groups that could be in flight
247 // in the existing code.
248 static void
249 setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
250                            scf::PipeliningOption::PipelinerPart part,
251                            unsigned iteration, unsigned depth) {
252   // Based on the order of copies within the loop we need to set the number
253   // of copies in flight, unless it is already set.
254   auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
255   if (!waitOp || waitOp.getNumGroups())
256     return;
257 
258   int numGroupInFlight = 0;
259   if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
260       part == scf::PipeliningOption::PipelinerPart::Prologue) {
261     numGroupInFlight = depth - 1;
262   } else {
263     // By construction there should be no wait op in the prologue as all the
264     // waits should be in the last stage.
265     assert(part == scf::PipeliningOption::PipelinerPart::Epilogue);
266     // Based on the schedule we pick we know how many groups are in flight for
267     // each iteration of the epilogue.
268     numGroupInFlight = depth - 1 - iteration;
269   }
270   waitOp.setNumGroups(numGroupInFlight);
271 }
272 
273 /// Hook for the loop pipeliner that populates `ops` with the stage information
274 /// as follows:
275 ///
276 ///   - operations in `stage0Ops` (typically loads from global memory and
277 ///     related barriers) are at stage 0;
278 ///   - operations in the backward slice of any stage0Ops are all at stage 0;
279 ///   - other operations are at stage `depth`;
280 ///   - the internal order of the pipelined loop has ops at stage `depth` first,
281 ///   then those at stage 0, with relative order within each group preserved.
282 ///
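/// For instance, with `depth` == 2, the global-to-shared copies and the index
/// computations feeding them are placed at stage 0, while the compute
/// operations consuming the shared memory buffers are placed at stage 2.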
283 static void getPipelineStages(
284     scf::ForOp forOp,
285     std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
286     unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
287   SetVector<Operation *> dependencies;
288   BackwardSliceOptions options([&](Operation *visited) {
289     return visited->getBlock() == forOp.getBody();
290   });
291   options.inclusive = true;
292   for (Operation &op : forOp.getBody()->getOperations()) {
293     if (stage0Ops.contains(&op))
294       getBackwardSlice(&op, &dependencies, options);
295   }
296 
297   for (Operation &op : forOp.getBody()->getOperations()) {
298     if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
299       opsWithPipelineStages.emplace_back(&op, depth);
300   }
301   for (Operation &op : forOp.getBody()->getOperations()) {
302     if (dependencies.contains(&op))
303       opsWithPipelineStages.emplace_back(&op, 0);
304   }
305 }
306 
307 /// Hook for the loop pipeliner. Replaces op with a predicated version and
308 /// returns the resulting operation. Returns the original op if the predication
309 /// isn't necessary for the given op. Returns null if predication is needed but
310 /// not supported.
311 static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
312                                             Operation *op, Value predicate) {
313   // Some operations may be fine to execute "speculatively" more times than the
314   // original number of iterations, in particular side-effect free operations
315   // and barriers, even if they cannot be predicated.
316   if (isMemoryEffectFree(op) ||
317       isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
318           nvgpu::DeviceAsyncWaitOp>(op)) {
319     return op;
320   }
321 
322   // Otherwise, only async copies can currently be predicated.
323   auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
324   if (!asyncCopyOp)
325     return nullptr;
326 
327   // Create srcElement Value based on `predicate`. The next lines generate
328   // the following code:
329   //
330   //   srcElement = (pred) ?  prevSrcElements : 0;
331   //
332   Location loc = asyncCopyOp->getLoc();
333   Value dstElements =
334       rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
335   Value originalSrcElement =
336       asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
337   Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
338   auto srcElements = rewriter.create<arith::SelectOp>(
339       loc, predicate, originalSrcElement, c0Index);
340   auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
341       loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
342       asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
343       asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
344       UnitAttr());
345   rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
346   return asyncCopyZeroFillOp;
347 }
348 
349 /// Applies loop pipelining with the given depth to the given loop so that
350 /// copies into the shared memory are pipelined. Doesn't affect other loops.
351 /// Returns a tuple containing the error state and the pipelined op, the latter
352 /// being null in case of any failure. The error state contains a definite error
353 /// if the IR has been modified and a silenceable error otherwise.
354 static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
355 pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
356                         bool epiloguePeeling) {
357   llvm::SmallPtrSet<Operation *, 16> stage0Ops;
358   if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
359     return std::make_tuple(
360         emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
361         scf::ForOp());
362   }
363   if (stage0Ops.empty()) {
364     return std::make_tuple(
365         emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
366   }
367 
368   scf::PipeliningOption options;
369   unsigned maxDepth = depth;
370   auto setAnnotation = [&](Operation *op,
371                            scf::PipeliningOption::PipelinerPart part,
372                            unsigned iteration) {
373     return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
374   };
375   options.getScheduleFn =
376       [&](scf::ForOp schedulingFor,
377           std::vector<std::pair<Operation *, unsigned>> &ops) {
378         if (schedulingFor != forOp)
379           return;
380         return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
381       };
382   options.annotateFn = setAnnotation;
383   if (!epiloguePeeling) {
384     options.peelEpilogue = false;
385     options.predicateFn = replaceOpWithPredicatedOp;
386   }
387 
388   OpBuilder::InsertionGuard guard(rewriter);
389   rewriter.setInsertionPoint(forOp);
390   bool modifiedIR;
391   FailureOr<scf::ForOp> maybePipelined =
392       pipelineForLoop(rewriter, forOp, options, &modifiedIR);
393   if (succeeded(maybePipelined)) {
394     return std::make_tuple(DiagnosedSilenceableFailure::success(),
395                            *maybePipelined);
396   }
397   return std::make_tuple(
398       modifiedIR
399           ? DiagnosedSilenceableFailure::definiteFailure()
400           : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
401       scf::ForOp());
402 }
403 
404 DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
405     TransformRewriter &rewriter, scf::ForOp forOp,
406     ApplyToEachResultList &results, TransformState &state) {
407   auto [diag, pipelined] = pipelineForSharedCopies(
408       rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
409   if (diag.succeeded()) {
410     results.push_back(pipelined);
411     return DiagnosedSilenceableFailure::success();
412   }
413   if (diag.isDefiniteFailure()) {
414     auto diag = emitDefiniteFailure("irreversible pipelining failure");
415     if (!getPeelEpilogue()) {
416       diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
417       diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
418     }
419     return diag;
420   }
421 
422   return std::move(diag);
423 }
424 
425 //===----------------------------------------------------------------------===//
426 // RewriteMatmulAsMmaSyncOp
427 //===----------------------------------------------------------------------===//
428 
429 /// Helper struct to encode a pair of row/column indexings in the form of
430 /// affine expressions.
431 struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
432   RowColIndexing(AffineExpr row, AffineExpr col)
433       : std::pair<AffineExpr, AffineExpr>(row, col) {}
434 
435   AffineExpr row() const { return first; }
436   AffineExpr col() const { return second; }
437 
438   void print(llvm::raw_ostream &os) const {
439     os << "- indexing: " << first << ", " << second;
440   }
441 };
442 
443 /// Helper struct to provide a simple mapping from matmul operations to the
444 /// corresponding mma.sync operation. This is constrained to the case where the
445 /// matmul matches the mma.sync operation 1-1.
446 struct MmaSyncBuilder {
447   MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
448       : b(b), loc(loc), laneId(laneId) {}
449 
450   using IndexCalculator =
451       std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
452 
453   /// Create the mma.sync operation corresponding to `linalgOp` along with all
454   /// the supporting load/store and vector operations.
455   FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
456 
457 private:
458   struct MmaSyncInfo {
459     std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
460     std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
461         vectorShapes;
462     SmallVector<int64_t> mmaShape;
463     bool tf32Enabled;
464   };
465 
466   /// Return the specific index calculator for the given `linalgOp` or failure
467   /// if the op is not supported. This is the toplevel switch that should just
468   /// be Tablegen'd in the future.
469   FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
470                                              TypeRange elementalTypes);
471 
472   //===--------------------------------------------------------------------===//
473   // Instruction-specific row, column indexing expression builders.
474   // These should all be declaratively specified via Tablegen in the future.
475   // The Tablegen specification should be as straightforward as possible to
476   // only model the existing size and type combinations.
477   //===--------------------------------------------------------------------===//
478   //
479   // TODO: Tablegen all this.
480   //===--------------------------------------------------------------------===//
481   // m16n8k4 tf32 case.
482   //===--------------------------------------------------------------------===//
483   /// From the NVIDIA doc:
484   /// groupID           = %laneid >> 2
485   /// threadIDInGroup = %laneid % 4
486   /// row =      groupID            for a0
487   ///            groupID + 8        for a1
488   /// col =  threadIDInGroup
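  ///
  /// For example, for %laneid == 5: groupID == 1 and threadIDInGroup == 1, so
  /// the lane holds a0 at (row 1, col 1) and a1 at (row 9, col 1).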
489   static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
490     auto dim = getAffineDimExpr(0, ctx);
491     AffineExpr groupID = dim.floorDiv(4);
492     AffineExpr threadIDInGroup = dim % 4;
493     return {RowColIndexing{groupID, threadIDInGroup},
494             RowColIndexing{groupID + 8, threadIDInGroup}};
495   }
496 
497   /// From the NVIDIA doc:
498   /// groupID           = %laneid >> 2
499   /// threadIDInGroup = %laneid % 4
500   /// row =  threadIDInGroup
501   /// col =  groupID
502   static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
503     auto dim = getAffineDimExpr(0, ctx);
504     AffineExpr groupID = dim.floorDiv(4);
505     AffineExpr threadIDInGroup = dim % 4;
506     return {RowColIndexing{threadIDInGroup, groupID}};
507   }
508 
509   /// From the NVIDIA doc:
510   /// groupID          = %laneid >> 2
511   /// threadIDInGroup = %laneid % 4
512   /// row =      groupID                            for c0 and c1
513   ///          groupID + 8                          for c2 and c3
514   /// col =  (threadIDInGroup * 2) + (i & 0x1)    for ci   where i = {0,..,3}
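  ///
  /// For example, for %laneid == 5 (groupID == 1, threadIDInGroup == 1), the
  /// lane holds c0 at (1, 2), c1 at (1, 3), c2 at (9, 2) and c3 at (9, 3).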
515   static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
516     auto dim = getAffineDimExpr(0, ctx);
517     AffineExpr groupID = dim.floorDiv(4);
518     AffineExpr threadIDInGroup = dim % 4;
519     return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
520             RowColIndexing{groupID, threadIDInGroup * 2 + 1},
521             RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
522             RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
523   }
524 
525   //===--------------------------------------------------------------------===//
526   // m16n8k16 f16 case.
527   //===--------------------------------------------------------------------===//
528   /// From the NVIDIA doc:
529   /// groupID           = %laneid >> 2
530   /// threadIDInGroup = %laneid % 4
531   ///
532   /// row =      groupID            for ai where  0 <= i < 2 || 4 <= i < 6
533   ///           groupID + 8         Otherwise
534   ///
535   /// col =  (threadIDInGroup * 2) + (i & 0x1)          for ai where i <  4
536   ///        (threadIDInGroup * 2) + (i & 0x1) + 8      for ai where i >= 4
537   static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
538     auto dim = getAffineDimExpr(0, ctx);
539     AffineExpr groupID = dim.floorDiv(4);
540     AffineExpr threadIDInGroup = dim % 4;
541     // clang-format off
542     return {
543       RowColIndexing{groupID, threadIDInGroup * 2 + 0},         // i == 0
544       RowColIndexing{groupID, threadIDInGroup * 2 + 1},         // i == 1
545       RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},     // i == 2
546       RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},     // i == 3
547       RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},     // i == 4
548       RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},     // i == 5
549       RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
550       RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}  // i == 7
551     };
552     // clang-format on
553   }
554 
555   /// From the NVIDIA doc:
556   /// groupID           = %laneid >> 2
557   /// threadIDInGroup = %laneid % 4
558   ///
559   /// row =  (threadIDInGroup * 2) + (i & 0x1)           for bi where i <  2
560   ///        (threadIDInGroup * 2) + (i & 0x1) + 8       for bi where i >= 2
561   ///
562   /// col = groupID
563   static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
564     auto dim = getAffineDimExpr(0, ctx);
565     AffineExpr groupID = dim.floorDiv(4);
566     AffineExpr threadIDInGroup = dim % 4;
567     // clang-format off
568     return {
569       RowColIndexing{threadIDInGroup * 2 + 0, groupID},        // i == 0
570       RowColIndexing{threadIDInGroup * 2 + 1, groupID},        // i == 1
571       RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID},    // i == 2
572       RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}     // i == 3
573     };
574     // clang-format on
575   }
576 
577   /// From the NVIDIA doc:
578   /// groupID           = %laneid >> 2
579   /// threadIDInGroup = %laneid % 4
580   ///
581   /// row =      groupID                               for ci where i <  2
582   ///          groupID + 8                             for ci where i >= 2
583   ///
584   /// col =  (threadIDInGroup * 2) + (i & 0x1)      for ci where i = {0,..,3}
585   static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
586     auto dim = getAffineDimExpr(0, ctx);
587     AffineExpr groupID = dim.floorDiv(4);
588     AffineExpr threadIDInGroup = dim % 4;
589     // clang-format off
590     return {
591       RowColIndexing{groupID, threadIDInGroup * 2 + 0},        // i == 0
592       RowColIndexing{groupID, threadIDInGroup * 2 + 1},        // i == 1
593       RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},    // i == 2
594       RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}     // i == 3
595     };
596     // clang-format on
597   }
598 
599   //===--------------------------------------------------------------------===//
600   /// Helper functions to create customizable load and store operations. The
601   /// specific shapes of each MMA instruction are passed via the
602   /// IndexCalculator callback.
603   //===--------------------------------------------------------------------===//
604   /// Build a list of memref.load operations indexed at `(row, col)` indices
605   /// that make sense for a particular MMA instruction and specified via the
606   /// IndexCalculator callback.
607   SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
608                                       OpFoldResult laneId, Value memref,
609                                       const IndexCalculator &indexFn);
610 
611   /// Perform a distributed load of a vector operand of `vectorShape` for a
612   /// particular MMA instruction whose `(row, col)` indices are specified via
613   /// the IndexCalculator callback. Each `laneId` loads the subportion of the
614   /// data that makes sense for the particular MMA operation.
615   /// The `vectorShape` matches existing NVGPU dialect op specification but
616   /// could also be flattened in the future if needed for simplification.
617   Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
618                                       OpFoldResult laneId, Value memref,
619                                       IndexCalculator indexFn,
620                                       ArrayRef<int64_t> vectorShape);
621 
622   /// Build a list of memref.store operations indexed at `(row, col)` indices
623   /// that make sense for a particular MMA instruction and specified via the
624   /// IndexCalculator callback.
625   SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
626                                              ValueRange toStore,
627                                              OpFoldResult laneId, Value memref,
628                                              const IndexCalculator &indexFn);
629 
630   /// Perform a distributed store of a vector operand of `vectorShape` for a
631   /// particular MMA instruction whose `(row, col)` indices are specified via
632   /// the IndexCalculator callback. Each `laneId` stores the subportion of the
633   /// data that makes sense for the particular MMA operation.
634   /// The `vectorShape` matches existing NVGPU dialect op specification but
635   /// could also be flattened in the future if needed for simplification.
636   SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
637       OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
638       Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
639 
640   OpBuilder &b;
641   Location loc;
642   OpFoldResult laneId;
643 };
644 
645 //===--------------------------------------------------------------------===//
646 /// Helper functions to create customizable load and store operations. The
647 /// specific shapes of each MMA instruction are passed via the
648 /// IndexCalculator callback.
649 //===--------------------------------------------------------------------===//
650 
651 template <typename ApplyFn, typename ReduceFn>
652 static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
653                                            ReduceFn reduceFn) {
654   VectorType vectorType = cast<VectorType>(vector.getType());
655   auto vectorShape = vectorType.getShape();
656   auto strides = computeStrides(vectorShape);
657   for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
658     auto indices = delinearize(idx, strides);
659     reduceFn(applyFn(vector, idx, indices), idx, indices);
660   }
661 }
662 
663 SmallVector<Value>
664 MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
665                                  OpFoldResult laneId, Value memref,
666                                  const IndexCalculator &indexFn) {
667   auto aff = [&](AffineExpr e) {
668     return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
669   };
670   SmallVector<Value> res;
671   SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
672   for (auto indexing : indexings) {
673     Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
674     Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
675     auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
676     res.push_back(load);
677   }
678   return res;
679 }
680 
681 Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
682     OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
683     IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
684   auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
685 
686   Type elementType = getElementTypeOrSelf(memref.getType());
687   auto vt = VectorType::get(vectorShape, elementType);
688   Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
689   foreachIndividualVectorElement(
690       res,
691       /*applyFn=*/
692       [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
693         return loads[linearIdx];
694       },
695       /*reduceFn=*/
696       [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
697         res = b.create<vector::InsertOp>(loc, v, res, indices);
698       });
699 
700   return res;
701 }
702 
703 SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
704     OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
705     Value memref, const IndexCalculator &indexFn) {
706   auto aff = [&](AffineExpr e) {
707     return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
708   };
709   SmallVector<Operation *> res;
710   for (auto [indexing, val] :
711        llvm::zip_equal(indexFn(b.getContext()), toStore)) {
712     Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
713     Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
714     Operation *store =
715         b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
716     res.push_back(store);
717   }
718   return res;
719 }
720 
721 SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
722     OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
723     Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
724   SmallVector<Value> toStore;
725   toStore.reserve(32);
726   foreachIndividualVectorElement(
727       vectorToStore,
728       /*applyFn=*/
729       [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
730         return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
731       },
732       /*reduceFn=*/
733       [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
734         toStore.push_back(v);
735       });
736   return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
737 }
738 
739 static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
740                   SmallVector<int64_t>>
741 makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
742                  ArrayRef<int64_t> res) {
743   SmallVector<int64_t> vlhs(lhs);
744   SmallVector<int64_t> vrhs(rhs);
745   SmallVector<int64_t> vres(res);
746   return std::make_tuple(vlhs, vrhs, vres);
747 }
748 
749 FailureOr<MmaSyncBuilder::MmaSyncInfo>
750 MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
751                                     TypeRange elementalTypes) {
752   // TODO: Tablegen all this.
753   Type f16 = b.getF16Type();
754   Type f32 = b.getF32Type();
755   if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
756       elementalTypes == TypeRange{f32, f32, f32}) {
757     return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
758                                        &MmaSyncBuilder::m16n8k4tf32Rhs,
759                                        &MmaSyncBuilder::m16n8k4tf32Res),
760                        makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
761                        SmallVector<int64_t>{opShape},
762                        /*tf32Enabled=*/true};
763   }
764   // This is the version with f16 accumulation.
765   // TODO: version with f32 accumulation.
766   if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
767       elementalTypes == TypeRange{f16, f16, f16}) {
768     return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
769                                        &MmaSyncBuilder::m16n8k16f16Rhs,
770                                        &MmaSyncBuilder::m16n8k16f16Res),
771                        makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
772                        SmallVector<int64_t>{opShape},
773                        /*tf32Enabled=*/false};
774   }
775   return failure();
776 }
777 
778 FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
779   Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
780   Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
781   Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
782   assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
783          "expected lhs to be a 2D memref");
784   assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
785          "expected rhs to be a 2D memref");
786   assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
787          "expected res to be a 2D memref");
788 
789   int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
790   int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
791   int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
792   Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
793   Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
794   Type resType = getElementTypeOrSelf(resMemRef.getType());
795 
796   FailureOr<MmaSyncInfo> maybeInfo =
797       getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
798   if (failed(maybeInfo))
799     return failure();
800 
801   MmaSyncInfo info = *maybeInfo;
802   auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
803   auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
804   Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
805                                             lhsIndexFn, lhsShape);
806   Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
807                                             rhsIndexFn, rhsShape);
808   Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
809                                             resIndexFn, resShape);
810   res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
811                                    info.tf32Enabled);
812   buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
813                                  resShape);
814   return res.getDefiningOp();
815 }
816 
817 DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
818     transform::TransformRewriter &rewriter, LinalgOp linalgOp,
819     transform::ApplyToEachResultList &results,
820     transform::TransformState &state) {
821   bool fail = true;
822   // TODO: more robust detection of matmulOp, with transposes etc.
823   if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
824     // Do not let matmul ops with extended semantics (user-defined indexing
825     // maps) through this transform.
826     if (linalgOp.hasUserDefinedMaps()) {
827       return emitSilenceableError()
828              << "only matmul ops with non-extended semantics are supported";
829     }
830     Location loc = linalgOp.getLoc();
831     // TODO: more robust computation of laneId, for now assume a single warp.
832     Value laneId = rewriter.create<gpu::ThreadIdOp>(
833         loc, rewriter.getIndexType(), gpu::Dimension::x);
834     if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
835       fail = false;
836   }
837 
838   if (fail) {
839     DiagnosedSilenceableFailure diag = emitSilenceableError()
840                                        << "unsupported target op: " << linalgOp;
841     diag.attachNote(linalgOp->getLoc()) << "target op";
842     return diag;
843   }
844 
845   rewriter.eraseOp(linalgOp);
846   return DiagnosedSilenceableFailure::success();
847 }
848 
849 //===----------------------------------------------------------------------===//
850 // Hopper builders.
851 //===----------------------------------------------------------------------===//
852 
853 /// Helper to create the base Hopper-specific operations that are reused in
854 /// various other places.
855 struct HopperBuilder {
856   HopperBuilder(RewriterBase &rewriter, Location loc)
857       : rewriter(rewriter), loc(loc) {}
858 
859   TypedValue<nvgpu::MBarrierGroupType>
860   buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
861 
862   /// Create tma descriptor op to initiate transfer from global to shared
863   /// memory. This must be done before the launch op, on the host.
864   TypedValue<nvgpu::TensorMapDescriptorType>
865   buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
866                               gpu::LaunchOp launchOp);
867 
868   /// Build a tma load from global memory to shared memory using `barrier` to
869   /// synchronize. Return the number of bytes that will be transferred.
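  /// For example, a 64x8 shared memory destination of f16 elements corresponds
  /// to an expected transaction count of 64 * 8 * 2 = 1024 bytes.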
870   OpFoldResult
871   buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
872                     TypedValue<MemRefType> sharedMemref,
873                     TypedValue<nvgpu::MBarrierGroupType> barrier,
874                     SmallVectorImpl<Operation *> &loadOps);
875   void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
876                             ArrayRef<OpFoldResult> sizes);
877 
878   /// If threadIdx.x == 0, issue the TMA requests and barrier arrival; other
879   /// threads just arrive on the barrier. Returns the transfer ops on thread 0.
880   // TODO: In the future, don't hardcode to thread 0 but elect a leader.
881   SmallVector<Operation *> buildPredicateLoadsOnThread0(
882       ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
883       ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
884       TypedValue<nvgpu::MBarrierGroupType> barrier);
885 
886   void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);
887 
888   RewriterBase &rewriter;
889   Location loc;
890 };
891 
892 SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
893     ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
894     ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
895     TypedValue<nvgpu::MBarrierGroupType> barrier) {
896   SmallVector<Operation *> loadOps;
897   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
898   Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
899   Value cond =
900       rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
901   // clang-format off
902   rewriter.create<scf::IfOp>(
903     /*location=*/loc,
904     /*conditional=*/cond,
905     /*thenBuilder=*/
906     [&](OpBuilder &lb, Location loc) {
907       SmallVector<OpFoldResult> sizes;
908       sizes.reserve(globalDescriptors.size());
909       for (auto [desc, shmem] : llvm::zip_equal(
910               globalDescriptors, sharedMemBuffers)) {
911         OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
912         sizes.push_back(sz);
913       }
914       // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
915       // This may or may not have perf implications.
916       buildBarrierArriveTx(barrier, sizes);
917       rewriter.create<scf::YieldOp>(loc);
918     },
919     /*elseBuilder=*/
920     [&](OpBuilder &lb, Location loc) {
921       // TODO: is this for no-thread divergence?
922       // Should we just yield the size and hoist?
923       buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
924       rewriter.create<scf::YieldOp>(loc);
925     });
926   // clang-format on
927   return loadOps;
928 }
929 
930 static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
931   return gpu::AddressSpaceAttr::get(
932       b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
933   // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
934 }
935 
936 TypedValue<nvgpu::MBarrierGroupType>
937 HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
938   auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
939   Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
940       loc,
941       nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
942   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
943   rewriter.create<nvgpu::MBarrierInitOp>(
944       loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
945       zero, Value());
946   rewriter.create<gpu::BarrierOp>(loc);
947   return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
948 }
949 
950 TypedValue<nvgpu::TensorMapDescriptorType>
951 HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
952                                            gpu::LaunchOp launchOp) {
953   OpBuilder::InsertionGuard guard(rewriter);
954   rewriter.setInsertionPoint(launchOp);
955   Value unrankedMemRef = rewriter.create<memref::CastOp>(
956       loc,
957       UnrankedMemRefType::get(memref.getType().getElementType(),
958                               memref.getType().getMemorySpace()),
959       memref);
960   SmallVector<OpFoldResult> mixedSizes =
961       memref::getMixedSizes(rewriter, loc, memref);
962   SmallVector<Value> sizes =
963       getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
964 
965   auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
966   Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
967       loc,
968       nvgpu::TensorMapDescriptorType::get(
969           rewriter.getContext(),
970           MemRefType::Builder(memref.getType())
971               .setMemorySpace(sharedMemorySpace),
972           TensorMapSwizzleKind::SWIZZLE_NONE,
973           TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
974           TensorMapInterleaveKind::INTERLEAVE_NONE),
975       unrankedMemRef, sizes);
976   return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
977 }
978 
979 OpFoldResult HopperBuilder::buildTmaAsyncLoad(
980     TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
981     TypedValue<MemRefType> sharedMemref,
982     TypedValue<nvgpu::MBarrierGroupType> barrier,
983     SmallVectorImpl<Operation *> &loadOps) {
984   MLIRContext *ctx = rewriter.getContext();
985   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
986   Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
987       loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
988       Value(), Value());
989   loadOps.push_back(loadOp);
990   auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
991   SmallVector<AffineExpr> symbols(mixedSizes.size());
992   bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
993   AffineExpr prodExprInBytes =
994       computeProduct(ctx, symbols) *
995       (sharedMemref.getType().getElementTypeBitWidth() / 8);
996   auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
997                                                    prodExprInBytes, mixedSizes);
998   return res;
999 }
1000 
1001 void HopperBuilder::buildBarrierArriveTx(
1002     TypedValue<nvgpu::MBarrierGroupType> barrier,
1003     ArrayRef<OpFoldResult> mixedSizes) {
1004   assert(!mixedSizes.empty() && "expected non-empty sizes");
1005   MLIRContext *ctx = rewriter.getContext();
1006   SmallVector<AffineExpr> symbols(mixedSizes.size());
1007   bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
1008   AffineExpr sumExpr = computeSum(ctx, symbols);
1009   OpFoldResult size =
1010       affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
1011   Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1012   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1013   rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
1014                                                    Value());
1015 }
1016 
1017 void HopperBuilder::buildTryWaitParity(
1018     TypedValue<nvgpu::MBarrierGroupType> barrier) {
1019   Type i1 = rewriter.getI1Type();
1020   Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
1021   // 10M is an arbitrary, not too small or too big number to specify the number
1022   // of ticks before retry.
1023   // TODO: hoist this in a default dialect constant.
1024   Value ticksBeforeRetry =
1025       rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
1026   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1027   rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
1028                                                   ticksBeforeRetry, zero);
1029 }
1030 
1031 //===----------------------------------------------------------------------===//
1032 // RewriteCopyAsTmaOp
1033 //===----------------------------------------------------------------------===//
1034 
1035 /// Helper to create the tma operations corresponding to `linalg::CopyOp`.
1036 struct CopyBuilder : public HopperBuilder {
1037   CopyBuilder(RewriterBase &rewriter, Location loc)
1038       : HopperBuilder(rewriter, loc) {}
1039 
1040   SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
1041 };
1042 
1043 SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
1044   MLIRContext *ctx = rewriter.getContext();
1045   if (copyOps.empty())
1046     return SmallVector<Operation *>();
1047 
1048   auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
1049   assert(launchOp && "expected launch op");
1050 
1051   // 1. Init a barrier object in shared memory.
1052   OpBuilder::InsertionGuard g(rewriter);
1053   rewriter.setInsertionPoint(copyOps.front());
1054   AffineExpr bx, by, bz;
1055   bindSymbols(ctx, bx, by, bz);
1056   AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
1057   OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
1058       rewriter, loc, prod,
1059       ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
1060                              launchOp.getBlockSizeZ()});
1061 
1062   TypedValue<nvgpu::MBarrierGroupType> barrier =
1063       buildAndInitBarrierInSharedMemory(numThreads);
1064 
1065   SmallVector<TypedValue<MemRefType>> shmems;
1066   SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
1067   for (Operation *op : copyOps) {
1068     auto copyOp = cast<linalg::CopyOp>(op);
1069     auto inMemRef =
1070         cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
1071     assert(inMemRef.getType().getRank() == 2 &&
1072            "expected in to be a 2D memref");
1073 
1074     // 2. Build global memory descriptor.
1075     TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
1076         buildGlobalMemRefDescriptor(inMemRef, launchOp);
1077     globalDescs.push_back(globalDesc);
1078 
1079     // 3. Shared memory and descriptor for the tmp array.
1080     auto shmem =
1081         cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
1082     shmems.push_back(shmem);
1083   }
1084 
1085   // 4. Load in from global memory to shared memory using tma.
1086   OpBuilder::InsertionGuard g2(rewriter);
1087   rewriter.setInsertionPoint(copyOps.front());
1088   SmallVector<Operation *> results =
1089       buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
1090 
1091   // 5. Spin-loop until data is ready.
1092   buildTryWaitParity(barrier);
1093 
1094   // 6. Erase the ops that have now been rewritten.
1095   for (Operation *op : copyOps)
1096     rewriter.eraseOp(op);
1097 
1098   return results;
1099 }
1100 
1101 DiagnosedSilenceableFailure
1102 transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
1103                                      transform::TransformResults &results,
1104                                      transform::TransformState &state) {
1105   auto payloadOps = state.getPayloadOps(getTarget());
1106   gpu::LaunchOp commonLaunchOp;
1107   Operation *firstOp, *failingOp;
1108   if (llvm::any_of(payloadOps, [&](Operation *op) {
1109         if (!commonLaunchOp) {
1110           commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
1111           firstOp = op;
1112         }
1113         auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
1114                     commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
1115                     !isa<linalg::CopyOp>(op);
1116         if (fail)
1117           failingOp = op;
1118         return fail;
1119       })) {
1120     DiagnosedSilenceableFailure diag =
1121         emitSilenceableError()
1122         << "target ops must be linalg::CopyOp nested under a common "
1123            "gpu.LaunchOp to be rewritten because the tma descriptors need to "
1124            "be created on the host.\nBut got: "
1125         << *firstOp << "\nand " << *failingOp;
1126     return diag;
1127   }
1128 
1129   // TODO: more robust detection of copy, with transposes etc.
1130   CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
1131 
1132   return DiagnosedSilenceableFailure::success();
1133 }
1134 
1135 //===----------------------------------------------------------------------===//
1136 // Transform op registration
1137 //===----------------------------------------------------------------------===//
1138 
1139 namespace {
1140 class NVGPUTransformDialectExtension
1141     : public transform::TransformDialectExtension<
1142           NVGPUTransformDialectExtension> {
1143 public:
1144   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)
1145 
1146   NVGPUTransformDialectExtension() {
1147     declareGeneratedDialect<arith::ArithDialect>();
1148     declareGeneratedDialect<affine::AffineDialect>();
1149     declareGeneratedDialect<nvgpu::NVGPUDialect>();
1150     declareGeneratedDialect<NVVM::NVVMDialect>();
1151     declareGeneratedDialect<vector::VectorDialect>();
1152     registerTransformOps<
1153 #define GET_OP_LIST
1154 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1155         >();
1156   }
1157 };
1158 } // namespace
1159 
1160 #define GET_OP_CLASSES
1161 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1162 
1163 void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
1164   registry.addExtensions<NVGPUTransformDialectExtension>();
1165 }
1166