1 //===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the GPU kernel-related dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinAttributes.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Diagnostics.h"
24 #include "mlir/IR/DialectImplementation.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/OpImplementation.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/IR/SymbolTable.h"
29 #include "mlir/IR/TypeUtilities.h"
30 #include "mlir/Interfaces/FunctionImplementation.h"
31 #include "mlir/Interfaces/SideEffectInterfaces.h"
32 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
33 #include "mlir/Transforms/InliningUtils.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/TypeSwitch.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/ErrorHandling.h"
38 #include "llvm/Support/StringSaver.h"
39 #include <cassert>
40 #include <numeric>
41 
42 using namespace mlir;
43 using namespace mlir::gpu;
44 
45 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
46 
47 //===----------------------------------------------------------------------===//
48 // GPU Device Mapping Attributes
49 //===----------------------------------------------------------------------===//
50 
51 int64_t GPUBlockMappingAttr::getMappingId() const {
52   return static_cast<int64_t>(getBlock());
53 }
54 
55 bool GPUBlockMappingAttr::isLinearMapping() const {
56   return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
57 }
58 
59 int64_t GPUBlockMappingAttr::getRelativeIndex() const {
60   return isLinearMapping()
61              ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
62              : getMappingId();
63 }
64 
65 int64_t GPUWarpgroupMappingAttr::getMappingId() const {
66   return static_cast<int64_t>(getWarpgroup());
67 }
68 
69 bool GPUWarpgroupMappingAttr::isLinearMapping() const {
70   return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
71 }
72 
73 int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
74   return isLinearMapping()
75              ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
76              : getMappingId();
77 }
78 
79 int64_t GPUWarpMappingAttr::getMappingId() const {
80   return static_cast<int64_t>(getWarp());
81 }
82 
83 bool GPUWarpMappingAttr::isLinearMapping() const {
84   return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
85 }
86 
87 int64_t GPUWarpMappingAttr::getRelativeIndex() const {
88   return isLinearMapping()
89              ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
90              : getMappingId();
91 }
92 
93 int64_t GPUThreadMappingAttr::getMappingId() const {
94   return static_cast<int64_t>(getThread());
95 }
96 
97 bool GPUThreadMappingAttr::isLinearMapping() const {
98   return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
99 }
100 
101 int64_t GPUThreadMappingAttr::getRelativeIndex() const {
102   return isLinearMapping()
103              ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
104              : getMappingId();
105 }
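
// Illustrative sketch (not part of the original file): how the linear/3-D
// split above plays out for thread mappings. The helper name is hypothetical
// and assumes the tablegen-generated builder takes the MappingId enum.
[[maybe_unused]] static void exampleThreadMappingIndices(MLIRContext *ctx) {
  // LinearDim2 is a linear mapping; its relative index is counted from
  // LinearDim0, so it is 2.
  auto linear = GPUThreadMappingAttr::get(ctx, MappingId::LinearDim2);
  assert(linear.isLinearMapping() && linear.getRelativeIndex() == 2);
  // DimY is a 3-D mapping; its relative index is just its raw mapping id.
  auto dimY = GPUThreadMappingAttr::get(ctx, MappingId::DimY);
  assert(!dimY.isLinearMapping() &&
         dimY.getRelativeIndex() == static_cast<int64_t>(MappingId::DimY));
  (void)linear;
  (void)dimY;
}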
106 
107 int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
108   return static_cast<int64_t>(getAddressSpace());
109 }
110 
111 bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
112   llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
113 }
114 
115 int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
116   llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
117 }
118 
119 //===----------------------------------------------------------------------===//
120 // MMAMatrixType
121 //===----------------------------------------------------------------------===//
122 
123 MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
124                                  StringRef operand) {
125   return Base::get(elementType.getContext(), shape, elementType, operand);
126 }
127 
128 MMAMatrixType
129 MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
130                           ArrayRef<int64_t> shape, Type elementType,
131                           StringRef operand) {
132   return Base::getChecked(emitError, elementType.getContext(), shape,
133                           elementType, operand);
134 }
135 
136 unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
137 
138 ArrayRef<int64_t> MMAMatrixType::getShape() const {
139   return getImpl()->getShape();
140 }
141 
142 Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
143 
144 StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
145 
146 bool MMAMatrixType::isValidElementType(Type elementType) {
147   return elementType.isF16() || elementType.isF32() ||
148          elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
149          elementType.isInteger(32);
150 }
151 
152 LogicalResult
153 MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
154                                 ArrayRef<int64_t> shape, Type elementType,
155                                 StringRef operand) {
156   if (operand != "AOp" && operand != "BOp" && operand != "COp")
157     return emitError() << "operand expected to be one of AOp, BOp or COp";
158 
159   if (shape.size() != 2)
160     return emitError() << "MMAMatrixType must have exactly two dimensions";
161 
162   if (!MMAMatrixType::isValidElementType(elementType))
163     return emitError()
164            << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
165 
166   return success();
167 }
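
// Illustrative sketch (not part of the original file): building a fragment
// type that satisfies verifyInvariants above. The helper name and the 16x16
// f16 "AOp" choice are hypothetical; any two-dimensional shape with a valid
// element type and operand string works the same way.
[[maybe_unused]] static MMAMatrixType
buildExampleMMAFragment(MLIRContext *ctx) {
  Builder b(ctx);
  // Textually this is !gpu.mma_matrix<16x16xf16, "AOp">.
  return MMAMatrixType::get({16, 16}, b.getF16Type(), "AOp");
}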
168 
169 //===----------------------------------------------------------------------===//
170 // GPUDialect
171 //===----------------------------------------------------------------------===//
172 
173 bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
174   if (!memorySpace)
175     return false;
176   if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
177     return gpuAttr.getValue() == getWorkgroupAddressSpace();
178   return false;
179 }
180 
181 bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
182   Attribute memorySpace = type.getMemorySpace();
183   return isWorkgroupMemoryAddressSpace(memorySpace);
184 }
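
// Illustrative sketch (not part of the original file): the predicates above
// only fire for memrefs whose memory space is the gpu workgroup address space
// attribute; numeric or missing memory spaces are rejected. The helper name
// is hypothetical.
[[maybe_unused]] static bool exampleIsWorkgroupBuffer(MLIRContext *ctx) {
  Builder b(ctx);
  auto wgSpace = AddressSpaceAttr::get(ctx, AddressSpace::Workgroup);
  auto bufferTy = MemRefType::get({32}, b.getF32Type(),
                                  MemRefLayoutAttrInterface{}, wgSpace);
  return GPUDialect::hasWorkgroupMemoryAddressSpace(bufferTy); // true
}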
185 
186 bool GPUDialect::isKernel(Operation *op) {
187   UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
188   return static_cast<bool>(isKernelAttr);
189 }
190 
191 namespace {
192 /// This class defines the interface for handling inlining with gpu
193 /// operations.
194 struct GPUInlinerInterface : public DialectInlinerInterface {
195   using DialectInlinerInterface::DialectInlinerInterface;
196 
197   /// All gpu dialect ops can be inlined.
198   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
199     return true;
200   }
201 };
202 } // namespace
203 
204 void GPUDialect::initialize() {
205   addTypes<AsyncTokenType>();
206   addTypes<MMAMatrixType>();
207   addTypes<SparseDnTensorHandleType>();
208   addTypes<SparseSpMatHandleType>();
209   addTypes<SparseSpGEMMOpHandleType>();
210   addOperations<
211 #define GET_OP_LIST
212 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
213       >();
214   addAttributes<
215 #define GET_ATTRDEF_LIST
216 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
217       >();
218   addInterfaces<GPUInlinerInterface>();
219   declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
220                            TerminatorOp>();
221   declarePromisedInterfaces<
222       ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
223       ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
224       SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
225 }
226 
227 static std::string getSparseHandleKeyword(SparseHandleKind kind) {
228   switch (kind) {
229   case SparseHandleKind::DnTensor:
230     return "sparse.dntensor_handle";
231   case SparseHandleKind::SpMat:
232     return "sparse.spmat_handle";
233   case SparseHandleKind::SpGEMMOp:
234     return "sparse.spgemmop_handle";
235   }
236   llvm_unreachable("unknown sparse handle kind");
237   return "";
238 }
239 
240 Type GPUDialect::parseType(DialectAsmParser &parser) const {
241   // Parse the main keyword for the type.
242   StringRef keyword;
243   if (parser.parseKeyword(&keyword))
244     return Type();
245   MLIRContext *context = getContext();
246 
247   // Handle 'async token' types.
248   if (keyword == "async.token")
249     return AsyncTokenType::get(context);
250 
251   if (keyword == "mma_matrix") {
252     SMLoc beginLoc = parser.getNameLoc();
253 
254     // Parse '<'.
255     if (parser.parseLess())
256       return nullptr;
257 
258     // Parse the size and elementType.
259     SmallVector<int64_t> shape;
260     Type elementType;
261     if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
262         parser.parseType(elementType))
263       return nullptr;
264 
265     // Parse ','
266     if (parser.parseComma())
267       return nullptr;
268 
269     // Parse operand.
270     std::string operand;
271     if (failed(parser.parseOptionalString(&operand)))
272       return nullptr;
273 
274     // Parse '>'.
275     if (parser.parseGreater())
276       return nullptr;
277 
278     return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
279                                          parser.getEncodedSourceLoc(beginLoc)),
280                                      shape, elementType, operand);
281   }
282 
283   if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
284     return SparseDnTensorHandleType::get(context);
285   if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
286     return SparseSpMatHandleType::get(context);
287   if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
288     return SparseSpGEMMOpHandleType::get(context);
289 
290   parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
291   return Type();
292 }
293 // TODO: Print a more refined type here. Note that the output has to stay in
294 // sync with the parser above.
295 void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
296   TypeSwitch<Type>(type)
297       .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
298       .Case<SparseDnTensorHandleType>([&](Type) {
299         os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
300       })
301       .Case<SparseSpMatHandleType>(
302           [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
303       .Case<SparseSpGEMMOpHandleType>([&](Type) {
304         os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
305       })
306       .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
307         os << "mma_matrix<";
308         auto shape = fragTy.getShape();
309         for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
310           os << *dim << 'x';
311         os << shape.back() << 'x' << fragTy.getElementType();
312         os << ", \"" << fragTy.getOperand() << "\"" << '>';
313       })
314       .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
315 }
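
// For reference, the textual forms handled by the parser and printer above
// look like:
//   !gpu.async.token
//   !gpu.mma_matrix<16x16xf16, "AOp">
//   !gpu.sparse.dntensor_handle
//   !gpu.sparse.spmat_handle
//   !gpu.sparse.spgemmop_handle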
316 
317 static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
318                                                NamedAttribute attr) {
319   auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
320   if (!array)
321     return op->emitOpError(Twine(attr.getName()) +
322                            " must be a dense i32 array");
323   if (array.size() != 3)
324     return op->emitOpError(Twine(attr.getName()) +
325                            " must contain exactly 3 elements");
326   return success();
327 }
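
// For illustration (an assumption about the exact discardable attribute
// names; they come from the known block/grid size attr helpers used below),
// an accepted annotation looks like:
//   gpu.known_block_size = array<i32: 128, 1, 1>
//   gpu.known_grid_size = array<i32: 16, 16, 1>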
328 
329 LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
330                                                    NamedAttribute attr) {
331   if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
332     return verifyKnownLaunchSizeAttr(op, attr);
333   if (attr.getName() == getKnownGridSizeAttrHelper().getName())
334     return verifyKnownLaunchSizeAttr(op, attr);
335   if (!llvm::isa<UnitAttr>(attr.getValue()) ||
336       attr.getName() != getContainerModuleAttrName())
337     return success();
338 
339   auto module = dyn_cast<ModuleOp>(op);
340   if (!module)
341     return op->emitError("expected '")
342            << getContainerModuleAttrName() << "' attribute to be attached to '"
343            << ModuleOp::getOperationName() << '\'';
344 
345   auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
346     // Ignore launches that are nested at a different depth than functions in
347     // the module we are currently checking.
348     if (!launchOp->getParentOp() ||
349         launchOp->getParentOp()->getParentOp() != module)
350       return success();
351 
352     // Ignore launch ops with missing attributes here. The errors will be
353     // reported by the verifiers of those ops.
354     if (!launchOp->getAttrOfType<SymbolRefAttr>(
355             LaunchFuncOp::getKernelAttrName(launchOp->getName())))
356       return success();
357 
358     // Check that `launch_func` refers to a well-formed GPU kernel container.
359     StringAttr kernelContainerName = launchOp.getKernelModuleName();
360     Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
361     if (!kernelContainer)
362       return launchOp.emitOpError()
363              << "kernel container '" << kernelContainerName.getValue()
364              << "' is undefined";
365 
366     // If the container is a GPU binary op return success.
367     if (isa<BinaryOp>(kernelContainer))
368       return success();
369 
370     auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
371     if (!kernelModule)
372       return launchOp.emitOpError()
373              << "kernel module '" << kernelContainerName.getValue()
374              << "' is undefined";
375 
376     // Check that `launch_func` refers to a well-formed kernel function.
377     Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
378     if (!kernelFunc)
379       return launchOp.emitOpError("kernel function '")
380              << launchOp.getKernel() << "' is undefined";
381     auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
382     if (!kernelConvertedFunction) {
383       InFlightDiagnostic diag = launchOp.emitOpError()
384                                 << "referenced kernel '" << launchOp.getKernel()
385                                 << "' is not a function";
386       diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
387       return diag;
388     }
389 
390     if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
391             GPUDialect::getKernelFuncAttrName()))
392       return launchOp.emitOpError("kernel function is missing the '")
393              << GPUDialect::getKernelFuncAttrName() << "' attribute";
394 
395     // TODO: If the kernel isn't a GPU function (which happens during separate
396     // compilation), do not check type correspondence as it would require the
397     // verifier to be aware of the type conversion.
398     auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
399     if (!kernelGPUFunction)
400       return success();
401 
402     unsigned actualNumArguments = launchOp.getNumKernelOperands();
403     unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
404     if (expectedNumArguments != actualNumArguments)
405       return launchOp.emitOpError("got ")
406              << actualNumArguments << " kernel operands but expected "
407              << expectedNumArguments;
408 
409     auto functionType = kernelGPUFunction.getFunctionType();
410     for (unsigned i = 0; i < expectedNumArguments; ++i) {
411       if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
412         return launchOp.emitOpError("type of function argument ")
413                << i << " does not match";
414       }
415     }
416 
417     return success();
418   });
419 
420   return walkResult.wasInterrupted() ? failure() : success();
421 }
422 
423 /// Parses an optional list of async operands with an optional leading keyword.
424 /// (`async`)? (`[` ssa-id-list `]`)?
425 ///
426 /// This method is used by the tablegen assembly format for async ops as well.
427 static ParseResult parseAsyncDependencies(
428     OpAsmParser &parser, Type &asyncTokenType,
429     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
430   auto loc = parser.getCurrentLocation();
431   if (succeeded(parser.parseOptionalKeyword("async"))) {
432     if (parser.getNumResults() == 0)
433       return parser.emitError(loc, "needs to be named when marked 'async'");
434     asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
435   }
436   return parser.parseOperandList(asyncDependencies,
437                                  OpAsmParser::Delimiter::OptionalSquare);
438 }
439 
440 /// Prints optional async dependencies with its leading keyword.
441 ///   (`async`)? (`[` ssa-id-list `]`)?
442 // Used by the tablegen assembly format for several async ops.
443 static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
444                                    Type asyncTokenType,
445                                    OperandRange asyncDependencies) {
446   if (asyncTokenType)
447     printer << "async";
448   if (asyncDependencies.empty())
449     return;
450   if (asyncTokenType)
451     printer << ' ';
452   printer << '[';
453   llvm::interleaveComma(asyncDependencies, printer);
454   printer << ']';
455 }
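
// For example, on an async op such as gpu.wait the two helpers above produce
// and consume syntax like (sketch):
//   %t1 = gpu.wait async [%t0]   // named result, returns a !gpu.async.token
//   gpu.wait [%t1]               // blocking form, no result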
456 
457 // GPU memory attribution functions shared by LaunchOp and GPUFuncOp.
458 /// Parses a GPU function memory attribution.
459 ///
460 /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
461 ///                        (`private` `(` ssa-id-and-type-list `)`)?
462 ///
463 /// Note that this function parses only one of the two similar parts, with the
464 /// keyword provided as argument.
465 static ParseResult
466 parseAttributions(OpAsmParser &parser, StringRef keyword,
467                   SmallVectorImpl<OpAsmParser::Argument> &args) {
468   // If we could not parse the keyword, just assume empty list and succeed.
469   if (failed(parser.parseOptionalKeyword(keyword)))
470     return success();
471 
472   return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
473                                   /*allowType=*/true);
474 }
475 
476 /// Prints a GPU function memory attribution.
477 static void printAttributions(OpAsmPrinter &p, StringRef keyword,
478                               ArrayRef<BlockArgument> values) {
479   if (values.empty())
480     return;
481 
482   p << ' ' << keyword << '(';
483   llvm::interleaveComma(
484       values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
485   p << ')';
486 }
487 
488 /// Verifies a GPU function memory attribution.
489 static LogicalResult verifyAttributions(Operation *op,
490                                         ArrayRef<BlockArgument> attributions,
491                                         gpu::AddressSpace memorySpace) {
492   for (Value v : attributions) {
493     auto type = llvm::dyn_cast<MemRefType>(v.getType());
494     if (!type)
495       return op->emitOpError() << "expected memref type in attribution";
496 
497     // We can only verify the address space if it hasn't already been lowered
498     // from the AddressSpaceAttr to a target-specific numeric value.
499     auto addressSpace =
500         llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
501     if (!addressSpace)
502       continue;
503     if (addressSpace.getValue() != memorySpace)
504       return op->emitOpError()
505              << "expected memory space " << stringifyAddressSpace(memorySpace)
506              << " in attribution";
507   }
508   return success();
509 }
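
// For example, attributions on a gpu.func are spelled as follows (sketch; the
// memory spaces are the #gpu.address_space values checked above):
//   gpu.func @kernel(%arg0: memref<?xf32>)
//       workgroup(%wg: memref<32xf32, #gpu.address_space<workgroup>>)
//       private(%priv: memref<4xf32, #gpu.address_space<private>>) kernel {
//     gpu.return
//   }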
510 
511 //===----------------------------------------------------------------------===//
512 // AllReduceOp
513 //===----------------------------------------------------------------------===//
514 
515 static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
516                                            Type resType) {
517   using Kind = gpu::AllReduceOperation;
518   if (llvm::is_contained(
519           {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
520           opName)) {
521     if (!isa<FloatType>(resType))
522       return failure();
523   }
524 
525   if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
526                           Kind::AND, Kind::OR, Kind::XOR},
527                          opName)) {
528     if (!isa<IntegerType>(resType))
529       return failure();
530   }
531 
532   return success();
533 }
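
// For example (syntax sketch), the checks above accept
//   %a = gpu.all_reduce minnumf %x {} : (f32) -> (f32)
//   %b = gpu.all_reduce xor %y {} : (i32) -> (i32)
// but reject `minnumf` on an integer type or `xor` on a float type.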
534 
535 LogicalResult gpu::AllReduceOp::verifyRegions() {
536   if (getBody().empty() != getOp().has_value())
537     return emitError("expected either an op attribute or a non-empty body");
538   if (!getBody().empty()) {
539     if (getBody().getNumArguments() != 2)
540       return emitError("expected two region arguments");
541     for (auto argument : getBody().getArguments()) {
542       if (argument.getType() != getType())
543         return emitError("incorrect region argument type");
544     }
545     unsigned yieldCount = 0;
546     for (Block &block : getBody()) {
547       if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
548         if (yield.getNumOperands() != 1)
549           return emitError("expected one gpu.yield operand");
550         if (yield.getOperand(0).getType() != getType())
551           return emitError("incorrect gpu.yield type");
552         ++yieldCount;
553       }
554     }
555     if (yieldCount == 0)
556       return emitError("expected gpu.yield op in region");
557   } else {
558     gpu::AllReduceOperation opName = *getOp();
559     if (failed(verifyReduceOpAndType(opName, getType()))) {
560       return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
561                          << "` reduction operation is not compatible with type "
562                          << getType();
563     }
564   }
565 
566   return success();
567 }
568 
569 static bool canMakeGroupOpUniform(Operation *op) {
570   auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
571   if (!launchOp)
572     return false;
573 
574   Region &body = launchOp.getBody();
575   assert(!body.empty() && "Invalid region");
576 
577   // Only convert ops in gpu::launch entry block for now.
578   return op->getBlock() == &body.front();
579 }
580 
581 OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
582   if (!getUniform() && canMakeGroupOpUniform(*this)) {
583     setUniform(true);
584     return getResult();
585   }
586 
587   return nullptr;
588 }
589 
590 // TODO: Support optional custom attributes (without dialect prefix).
591 static ParseResult parseAllReduceOperation(AsmParser &parser,
592                                            AllReduceOperationAttr &attr) {
593   StringRef enumStr;
594   if (!parser.parseOptionalKeyword(&enumStr)) {
595     std::optional<AllReduceOperation> op =
596         gpu::symbolizeAllReduceOperation(enumStr);
597     if (!op)
598       return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
599     attr = AllReduceOperationAttr::get(parser.getContext(), *op);
600   }
601   return success();
602 }
603 
604 static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
605                                     AllReduceOperationAttr attr) {
606   if (attr)
607     attr.print(printer);
608 }
609 
610 //===----------------------------------------------------------------------===//
611 // SubgroupReduceOp
612 //===----------------------------------------------------------------------===//
613 
614 LogicalResult gpu::SubgroupReduceOp::verify() {
615   Type elemType = getType();
616   if (auto vecTy = dyn_cast<VectorType>(elemType)) {
617     if (vecTy.isScalable())
618       return emitOpError() << "is not compatible with scalable vector types";
619 
620     elemType = vecTy.getElementType();
621   }
622 
623   gpu::AllReduceOperation opName = getOp();
624   if (failed(verifyReduceOpAndType(opName, elemType))) {
625     return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
626                        << "` reduction operation is not compatible with type "
627                        << getType();
628   }
629 
630   auto clusterSize = getClusterSize();
631   if (clusterSize) {
632     uint32_t size = *clusterSize;
633     if (!llvm::isPowerOf2_32(size)) {
634       return emitOpError() << "cluster size " << size
635                            << " is not a power of two";
636     }
637   }
638 
639   uint32_t stride = getClusterStride();
640   if (stride != 1 && !clusterSize) {
641     return emitOpError() << "cluster stride can only be specified if cluster "
642                             "size is specified";
643   }
644   if (!llvm::isPowerOf2_32(stride)) {
645     return emitOpError() << "cluster stride " << stride
646                          << " is not a power of two";
647   }
648 
649   return success();
650 }
651 
652 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
653   if (getClusterSize() == 1)
654     return getValue();
655 
656   if (!getUniform() && canMakeGroupOpUniform(*this)) {
657     setUniform(true);
658     return getResult();
659   }
660 
661   return nullptr;
662 }
663 
664 //===----------------------------------------------------------------------===//
665 // AsyncOpInterface
666 //===----------------------------------------------------------------------===//
667 
668 void gpu::addAsyncDependency(Operation *op, Value token) {
669   op->insertOperands(0, {token});
670   if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
671     return;
672   auto attrName =
673       OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
674   auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
675 
676   // The async dependencies are the only variadic operands.
677   if (!sizeAttr)
678     return;
679 
680   SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
681   ++sizes.front();
682   op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
683 }
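
// For example, an op carrying operandSegmentSizes = [2, 1, 1, 1] (two async
// dependencies followed by three singleton segments) ends up with
// [3, 1, 1, 1] after a token is added; ops without the attribute just get the
// new leading operand.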
684 
685 //===----------------------------------------------------------------------===//
686 // LaunchOp
687 //===----------------------------------------------------------------------===//
688 
689 void LaunchOp::build(OpBuilder &builder, OperationState &result,
690                      Value gridSizeX, Value gridSizeY, Value gridSizeZ,
691                      Value getBlockSizeX, Value getBlockSizeY,
692                      Value getBlockSizeZ, Value dynamicSharedMemorySize,
693                      Type asyncTokenType, ValueRange asyncDependencies,
694                      TypeRange workgroupAttributions,
695                      TypeRange privateAttributions, Value clusterSizeX,
696                      Value clusterSizeY, Value clusterSizeZ) {
697   OpBuilder::InsertionGuard g(builder);
698 
699   // Add a WorkGroup attribution attribute. This attribute is required to
700   // identify private attributions in the list of block arguments.
701   result.addAttribute(getNumWorkgroupAttributionsAttrName(),
702                       builder.getI64IntegerAttr(workgroupAttributions.size()));
703 
704   // Add Op operands.
705   result.addOperands(asyncDependencies);
706   if (asyncTokenType)
707     result.types.push_back(builder.getType<AsyncTokenType>());
708 
709   // Add grid and block sizes as op operands, followed by the data operands.
710   result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
711                       getBlockSizeY, getBlockSizeZ});
712   if (clusterSizeX)
713     result.addOperands(clusterSizeX);
714   if (clusterSizeY)
715     result.addOperands(clusterSizeY);
716   if (clusterSizeZ)
717     result.addOperands(clusterSizeZ);
718   if (dynamicSharedMemorySize)
719     result.addOperands(dynamicSharedMemorySize);
720 
721   // Create a kernel body region with kNumConfigRegionAttributes + N memory
722   // attributions, where the first kNumConfigRegionAttributes arguments have
723   // `index` type and the rest have the types of the memory attributions.
724   Region *kernelRegion = result.addRegion();
725   Block *body = builder.createBlock(kernelRegion);
726   // TODO: Allow passing in proper locations here.
727   for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
728     body->addArgument(builder.getIndexType(), result.location);
729   // Add WorkGroup & Private attributions to the region arguments.
730   for (Type argTy : workgroupAttributions)
731     body->addArgument(argTy, result.location);
732   for (Type argTy : privateAttributions)
733     body->addArgument(argTy, result.location);
734   // Fill OperandSegmentSize Attribute.
735   SmallVector<int32_t, 11> segmentSizes(11, 1);
736   segmentSizes.front() = asyncDependencies.size();
737   segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
738   segmentSizes[7] = clusterSizeX ? 1 : 0;
739   segmentSizes[8] = clusterSizeY ? 1 : 0;
740   segmentSizes[9] = clusterSizeZ ? 1 : 0;
741   result.addAttribute(getOperandSegmentSizeAttr(),
742                       builder.getDenseI32ArrayAttr(segmentSizes));
743 }
744 
745 KernelDim3 LaunchOp::getBlockIds() {
746   assert(!getBody().empty() && "LaunchOp body must not be empty.");
747   auto args = getBody().getArguments();
748   return KernelDim3{args[0], args[1], args[2]};
749 }
750 
751 KernelDim3 LaunchOp::getThreadIds() {
752   assert(!getBody().empty() && "LaunchOp body must not be empty.");
753   auto args = getBody().getArguments();
754   return KernelDim3{args[3], args[4], args[5]};
755 }
756 
757 KernelDim3 LaunchOp::getGridSize() {
758   assert(!getBody().empty() && "LaunchOp body must not be empty.");
759   auto args = getBody().getArguments();
760   return KernelDim3{args[6], args[7], args[8]};
761 }
762 
763 KernelDim3 LaunchOp::getBlockSize() {
764   assert(!getBody().empty() && "LaunchOp body must not be empty.");
765   auto args = getBody().getArguments();
766   return KernelDim3{args[9], args[10], args[11]};
767 }
768 
769 std::optional<KernelDim3> LaunchOp::getClusterIds() {
770   assert(!getBody().empty() && "LaunchOp body must not be empty.");
771   if (!hasClusterSize())
772     return std::nullopt;
773   auto args = getBody().getArguments();
774   return KernelDim3{args[12], args[13], args[14]};
775 }
776 
777 std::optional<KernelDim3> LaunchOp::getClusterSize() {
778   assert(!getBody().empty() && "LaunchOp body must not be empty.");
779   if (!hasClusterSize())
780     return std::nullopt;
781   auto args = getBody().getArguments();
782   return KernelDim3{args[15], args[16], args[17]};
783 }
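
// Layout of the body region arguments used by the accessors above (all of
// `index` type): [0..2] block ids, [3..5] thread ids, [6..8] grid size,
// [9..11] block size, and, when a cluster is present, [12..14] cluster ids
// and [15..17] cluster size. Memory attributions follow these config
// arguments.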
784 
785 KernelDim3 LaunchOp::getGridSizeOperandValues() {
786   auto operands = getOperands().drop_front(getAsyncDependencies().size());
787   return KernelDim3{operands[0], operands[1], operands[2]};
788 }
789 
790 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
791   auto operands = getOperands().drop_front(getAsyncDependencies().size());
792   return KernelDim3{operands[3], operands[4], operands[5]};
793 }
794 
795 std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
796   auto operands = getOperands().drop_front(getAsyncDependencies().size());
797   if (!hasClusterSize())
798     return std::nullopt;
799   return KernelDim3{operands[6], operands[7], operands[8]};
800 }
801 
802 LogicalResult LaunchOp::verify() {
803   if (!(hasClusterSize()) &&
804       (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
805     return emitOpError() << "cluster size must be all present";
806   return success();
807 }
808 
809 LogicalResult LaunchOp::verifyRegions() {
810   // Kernel launch takes kNumConfigOperands leading operands for grid/block
811   // sizes and transforms them into kNumConfigRegionAttributes region arguments
812   // for block/thread identifiers and grid/block sizes.
813   if (!getBody().empty()) {
814     if (getBody().getNumArguments() <
815         kNumConfigRegionAttributes + getNumWorkgroupAttributions())
816       return emitOpError("unexpected number of region arguments");
817   }
818 
819   // Verify Attributions Address Spaces.
820   if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
821                                 GPUDialect::getWorkgroupAddressSpace())) ||
822       failed(verifyAttributions(getOperation(), getPrivateAttributions(),
823                                 GPUDialect::getPrivateAddressSpace())))
824     return failure();
825 
826   // Block terminators without successors are expected to exit the kernel region
827   // and must be `gpu.terminator`.
828   for (Block &block : getBody()) {
829     if (block.empty())
830       continue;
831     if (block.back().getNumSuccessors() != 0)
832       continue;
833     if (!isa<gpu::TerminatorOp>(&block.back())) {
834       return block.back()
835           .emitError()
836           .append("expected '", gpu::TerminatorOp::getOperationName(),
837                   "' or a terminator with successors")
838           .attachNote(getLoc())
839           .append("in '", LaunchOp::getOperationName(), "' body region");
840     }
841   }
842 
843   if (getNumResults() == 0 && getAsyncToken())
844     return emitOpError("needs to be named when async keyword is specified");
845 
846   return success();
847 }
848 
849 // Pretty-print the kernel grid/block size assignment as
850 //   (%iter-x, %iter-y, %iter-z) in
851 //   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
852 // where %size-* and %iter-* will correspond to the body region arguments.
853 static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
854                                 KernelDim3 operands, KernelDim3 ids) {
855   p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
856   p << size.x << " = " << operands.x << ", ";
857   p << size.y << " = " << operands.y << ", ";
858   p << size.z << " = " << operands.z << ')';
859 }
860 
861 void LaunchOp::print(OpAsmPrinter &p) {
862   if (getAsyncToken()) {
863     p << " async";
864     if (!getAsyncDependencies().empty())
865       p << " [" << getAsyncDependencies() << ']';
866   }
867   // Print the launch configuration.
868   if (hasClusterSize()) {
869     p << ' ' << getClustersKeyword();
870     printSizeAssignment(p, getClusterSize().value(),
871                         getClusterSizeOperandValues().value(),
872                         getClusterIds().value());
873   }
874   p << ' ' << getBlocksKeyword();
875   printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
876                       getBlockIds());
877   p << ' ' << getThreadsKeyword();
878   printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
879                       getThreadIds());
880   if (getDynamicSharedMemorySize())
881     p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
882       << getDynamicSharedMemorySize();
883 
884   printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
885   printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
886 
887   p << ' ';
888 
889   p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
890   p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
891                               LaunchOp::getOperandSegmentSizeAttr(),
892                               getNumWorkgroupAttributionsAttrName()});
893 }
894 
895 // Parse the size assignment blocks for blocks and threads.  These have the form
896 //   (%region_arg, %region_arg, %region_arg) in
897 //   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
898 // where %region_arg are percent-identifiers for the region arguments to be
899 // introduced further (SSA defs), and %operand are percent-identifiers for the
900 // SSA value uses.
901 static ParseResult
902 parseSizeAssignment(OpAsmParser &parser,
903                     MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
904                     MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
905                     MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
906   assert(indices.size() == 3 && "space for three indices expected");
907   SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
908   if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
909                               /*allowResultNumber=*/false) ||
910       parser.parseKeyword("in") || parser.parseLParen())
911     return failure();
912   std::move(args.begin(), args.end(), indices.begin());
913 
914   for (int i = 0; i < 3; ++i) {
915     if (i != 0 && parser.parseComma())
916       return failure();
917     if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
918         parser.parseEqual() || parser.parseOperand(sizes[i]))
919       return failure();
920   }
921 
922   return parser.parseRParen();
923 }
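
// For example, parsing `(%bx, %by, %bz) in (%sx = %0, %sy = %1, %sz = %2)`
// fills `indices` with %bx/%by/%bz, `regionSizes` with %sx/%sy/%sz, and
// `sizes` with the SSA uses %0/%1/%2.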
924 
925 /// Parses a Launch operation.
926 /// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
927 ///       (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
928 ///       `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
929 ///       `threads` `(` ssa-id-list `)` `in` ssa-reassignment
930 ///       memory-attribution
931 ///       region attr-dict?
932 /// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
933 ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
934   // Sizes of the grid and block.
935   SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
936       sizes(LaunchOp::kNumConfigOperands);
937 
938   // Actual (data) operands passed to the kernel.
939   SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;
940 
941   // Region arguments to be created.
942   SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
943       LaunchOp::kNumConfigRegionAttributes);
944 
945   // Parse optional async dependencies.
946   SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
947   Type asyncTokenType;
948   if (failed(
949           parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
950       parser.resolveOperands(asyncDependencies, asyncTokenType,
951                              result.operands))
952     return failure();
953   if (parser.getNumResults() > 0)
954     result.types.push_back(asyncTokenType);
955 
956   bool hasCluster = false;
957   if (succeeded(
958           parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
959     hasCluster = true;
960     sizes.resize(9);
961     regionArgs.resize(18);
962   }
963   MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
964   MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
965 
966   // The last three size operands assign the cluster size. In the region
967   // argument list, these correspond to the last six arguments.
968   if (hasCluster) {
969     if (parseSizeAssignment(parser, sizesRef.drop_front(6),
970                             regionArgsRef.slice(15, 3),
971                             regionArgsRef.slice(12, 3)))
972       return failure();
973   }
974   // Parse the size assignment segments: the first segment assigns grid sizes
975   // and defines values for block identifiers; the second segment assigns block
976   // sizes and defines values for thread identifiers.  In the region argument
977   // list, identifiers precede sizes, and block-related values precede
978   // thread-related values.
979   if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
980       parseSizeAssignment(parser, sizesRef.take_front(3),
981                           regionArgsRef.slice(6, 3),
982                           regionArgsRef.slice(0, 3)) ||
983       parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
984       parseSizeAssignment(parser, sizesRef.drop_front(3),
985                           regionArgsRef.slice(9, 3),
986                           regionArgsRef.slice(3, 3)) ||
987       parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
988                              result.operands))
989     return failure();
990 
991   OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
992   bool hasDynamicSharedMemorySize = false;
993   if (!parser.parseOptionalKeyword(
994           LaunchOp::getDynamicSharedMemorySizeKeyword())) {
995     hasDynamicSharedMemorySize = true;
996     if (parser.parseOperand(dynamicSharedMemorySize) ||
997         parser.resolveOperand(dynamicSharedMemorySize,
998                               parser.getBuilder().getI32Type(),
999                               result.operands))
1000       return failure();
1001   }
1002 
1003   // Create the region arguments. The region has kNumConfigRegionAttributes
1004   // arguments that correspond to block/thread identifiers and grid/block
1005   // sizes, all of `index` type, followed by a variadic number of workgroup
1006   // attributions and a variadic number of private attributions. The number
1007   // of workgroup attributions is stored in the attribute named by
1008   // LaunchOp::getNumWorkgroupAttributionsAttrName().
1009   Type index = parser.getBuilder().getIndexType();
1010   SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
1011       LaunchOp::kNumConfigRegionAttributes + 6, index);
1012 
1013   SmallVector<OpAsmParser::Argument> regionArguments;
1014   for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1015     OpAsmParser::Argument arg;
1016     arg.ssaName = std::get<0>(ssaValueAndType);
1017     arg.type = std::get<1>(ssaValueAndType);
1018     regionArguments.push_back(arg);
1019   }
1020 
1021   Builder &builder = parser.getBuilder();
1022   // Parse workgroup memory attributions.
1023   if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
1024                                regionArguments)))
1025     return failure();
1026 
1027   // Store the number of region arguments we just parsed as the number of
1028   // workgroup memory attributions.
1029   unsigned numWorkgroupAttrs = regionArguments.size() -
1030                                LaunchOp::kNumConfigRegionAttributes -
1031                                (hasCluster ? 6 : 0);
1032   result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
1033                       builder.getI64IntegerAttr(numWorkgroupAttrs));
1034 
1035   // Parse private memory attributions.
1036   if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
1037                                regionArguments)))
1038     return failure();
1039 
1040   // Introduce the body region and parse it. The region has
1041   // kNumConfigRegionAttributes arguments that correspond to
1042   // block/thread identifiers and grid/block sizes, all having `index` type.
1043   Region *body = result.addRegion();
1044   if (parser.parseRegion(*body, regionArguments) ||
1045       parser.parseOptionalAttrDict(result.attributes))
1046     return failure();
1047 
1048   SmallVector<int32_t, 11> segmentSizes(11, 1);
1049   segmentSizes.front() = asyncDependencies.size();
1050 
1051   if (!hasCluster) {
1052     segmentSizes[7] = 0;
1053     segmentSizes[8] = 0;
1054     segmentSizes[9] = 0;
1055   }
1056   segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1057   result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1058                       parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
1059   return success();
1060 }
1061 
1062 /// Simplify the gpu.launch when the range of a thread or block ID is
1063 /// trivially known to be one.
1064 struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
1065   using OpRewritePattern<LaunchOp>::OpRewritePattern;
1066   LogicalResult matchAndRewrite(LaunchOp op,
1067                                 PatternRewriter &rewriter) const override {
1068     // If the range implies a single value for `id`, replace `id`'s uses by
1069     // zero.
1070     Value zero;
1071     bool simplified = false;
1072     auto constPropIdUses = [&](Value id, Value size) {
1073       // Check if size is trivially one.
1074       if (!matchPattern(size, m_One()))
1075         return;
1076       if (id.getUses().empty())
1077         return;
1078       if (!simplified) {
1079         // Create a zero value the first time.
1080         OpBuilder::InsertionGuard guard(rewriter);
1081         rewriter.setInsertionPointToStart(&op.getBody().front());
1082         zero =
1083             rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
1084       }
1085       rewriter.replaceAllUsesWith(id, zero);
1086       simplified = true;
1087     };
1088     constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1089     constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1090     constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1091     constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1092     constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1093     constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1094 
1095     return success(simplified);
1096   }
1097 };
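
// For example, in a launch whose y grid size operand is a constant 1, every
// use of the y block id inside the body is replaced by a single
// arith.constant 0 created at the top of the entry block.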
1098 
1099 void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1100                                            MLIRContext *context) {
1101   rewrites.add<FoldLaunchArguments>(context);
1102 }
1103 
1104 /// Adds a new block argument that corresponds to buffers located in
1105 /// workgroup memory.
1106 BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
1107   auto attrName = getNumWorkgroupAttributionsAttrName();
1108   auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1109   (*this)->setAttr(attrName,
1110                    IntegerAttr::get(attr.getType(), attr.getValue() + 1));
1111   return getBody().insertArgument(
1112       LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1113 }
1114 
1115 /// Adds a new block argument that corresponds to buffers located in
1116 /// private memory.
1117 BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1118   // Buffers in private memory always come after buffers in workgroup
1119   // memory.
1120   return getBody().addArgument(type, loc);
1121 }
1122 
1123 //===----------------------------------------------------------------------===//
1124 // LaunchFuncOp
1125 //===----------------------------------------------------------------------===//
1126 
1127 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1128                          SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
1129                          KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1130                          ValueRange kernelOperands, Type asyncTokenType,
1131                          ValueRange asyncDependencies,
1132                          std::optional<KernelDim3> clusterSize) {
1133   assert(kernelSymbol.getNestedReferences().size() == 1 &&
1134          "expected a symbol reference with a single nested reference");
1135   result.addOperands(asyncDependencies);
1136   if (asyncTokenType)
1137     result.types.push_back(builder.getType<AsyncTokenType>());
1138 
1139   // Add grid and block sizes as op operands, followed by the data operands.
1140   result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
1141                       getBlockSize.y, getBlockSize.z});
1142   if (clusterSize.has_value())
1143     result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1144   if (dynamicSharedMemorySize)
1145     result.addOperands(dynamicSharedMemorySize);
1146   result.addOperands(kernelOperands);
1147 
1148   Properties &prop = result.getOrAddProperties<Properties>();
1149   prop.kernel = kernelSymbol;
1150   size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1151   // Initialize the segment sizes to 1.
1152   for (auto &sz : prop.operandSegmentSizes)
1153     sz = 1;
1154   prop.operandSegmentSizes[0] = asyncDependencies.size();
1155   if (!clusterSize.has_value()) {
1156     prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1157     prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1158     prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1159   }
1160   prop.operandSegmentSizes[segmentSizesLen - 3] =
1161       dynamicSharedMemorySize ? 1 : 0;
1162   prop.operandSegmentSizes[segmentSizesLen - 2] =
1163       static_cast<int32_t>(kernelOperands.size());
1164   prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1165 }
1166 
1167 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1168                          GPUFuncOp kernelFunc, KernelDim3 gridSize,
1169                          KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1170                          ValueRange kernelOperands, Type asyncTokenType,
1171                          ValueRange asyncDependencies,
1172                          std::optional<KernelDim3> clusterSize) {
1173   auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1174   auto kernelSymbol =
1175       SymbolRefAttr::get(kernelModule.getNameAttr(),
1176                          {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1177   build(builder, result, kernelSymbol, gridSize, getBlockSize,
1178         dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1179         asyncDependencies, clusterSize);
1180 }
1181 
1182 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1183                          SymbolRefAttr kernel, KernelDim3 gridSize,
1184                          KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1185                          ValueRange kernelOperands, Value asyncObject,
1186                          std::optional<KernelDim3> clusterSize) {
1187   // Add grid and block sizes as op operands, followed by the data operands.
1188   result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
1189                       getBlockSize.y, getBlockSize.z});
1190   if (clusterSize.has_value())
1191     result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1192   if (dynamicSharedMemorySize)
1193     result.addOperands(dynamicSharedMemorySize);
1194   result.addOperands(kernelOperands);
1195   if (asyncObject)
1196     result.addOperands(asyncObject);
1197   Properties &prop = result.getOrAddProperties<Properties>();
1198   prop.kernel = kernel;
1199   size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1200   // Initialize the segment sizes to 1.
1201   for (auto &sz : prop.operandSegmentSizes)
1202     sz = 1;
1203   prop.operandSegmentSizes[0] = 0;
1204   if (!clusterSize.has_value()) {
1205     prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1206     prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1207     prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1208   }
1209   prop.operandSegmentSizes[segmentSizesLen - 3] =
1210       dynamicSharedMemorySize ? 1 : 0;
1211   prop.operandSegmentSizes[segmentSizesLen - 2] =
1212       static_cast<int32_t>(kernelOperands.size());
1213   prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1214 }
1215 
1216 StringAttr LaunchFuncOp::getKernelModuleName() {
1217   return getKernel().getRootReference();
1218 }
1219 
1220 StringAttr LaunchFuncOp::getKernelName() {
1221   return getKernel().getLeafReference();
1222 }
1223 
1224 unsigned LaunchFuncOp::getNumKernelOperands() {
1225   return getKernelOperands().size();
1226 }
1227 
1228 Value LaunchFuncOp::getKernelOperand(unsigned i) {
1229   return getKernelOperands()[i];
1230 }
1231 
1232 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1233   auto operands = getOperands().drop_front(getAsyncDependencies().size());
1234   return KernelDim3{operands[0], operands[1], operands[2]};
1235 }
1236 
1237 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1238   auto operands = getOperands().drop_front(getAsyncDependencies().size());
1239   return KernelDim3{operands[3], operands[4], operands[5]};
1240 }
1241 
1242 KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1243   assert(hasClusterSize() &&
1244          "cluster size is not set, check hasClusterSize() first");
1245   auto operands = getOperands().drop_front(getAsyncDependencies().size());
1246   return KernelDim3{operands[6], operands[7], operands[8]};
1247 }
1248 
1249 LogicalResult LaunchFuncOp::verify() {
1250   auto module = (*this)->getParentOfType<ModuleOp>();
1251   if (!module)
1252     return emitOpError("expected to belong to a module");
1253 
1254   if (!module->getAttrOfType<UnitAttr>(
1255           GPUDialect::getContainerModuleAttrName()))
1256     return emitOpError("expected the closest surrounding module to have the '" +
1257                        GPUDialect::getContainerModuleAttrName() +
1258                        "' attribute");
1259 
1260   if (hasClusterSize()) {
1261     if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
1262         getClusterSizeZ().getType() != getClusterSizeX().getType())
1263       return emitOpError()
1264              << "expects the types of the cluster dimensions to be the same";
1265   }
1266 
1267   return success();
1268 }
1269 
1270 static ParseResult
1271 parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
1272                    std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1273                    Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
1274   if (succeeded(parser.parseOptionalColon())) {
1275     if (parser.parseType(dimTy))
1276       return failure();
1277   } else {
1278     dimTy = IndexType::get(parser.getContext());
1279   }
1280   if (clusterValue.has_value()) {
1281     clusterXTy = clusterYTy = clusterZTy = dimTy;
1282   }
1283   return success();
1284 }
1285 
1286 static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
1287                                Value clusterValue, Type clusterXTy,
1288                                Type clusterYTy, Type clusterZTy) {
1289   if (!dimTy.isIndex())
1290     printer << ": " << dimTy;
1291 }
1292 
1293 static ParseResult parseLaunchFuncOperands(
1294     OpAsmParser &parser,
1295     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
1296     SmallVectorImpl<Type> &argTypes) {
1297   if (parser.parseOptionalKeyword("args"))
1298     return success();
1299 
1300   auto parseElement = [&]() -> ParseResult {
1301     return failure(parser.parseOperand(argNames.emplace_back()) ||
1302                    parser.parseColonType(argTypes.emplace_back()));
1303   };
1304 
1305   return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
1306                                         parseElement, " in argument list");
1307 }
1308 
1309 static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
1310                                     OperandRange operands, TypeRange types) {
1311   if (operands.empty())
1312     return;
1313   printer << "args(";
1314   llvm::interleaveComma(llvm::zip(operands, types), printer,
1315                         [&](const auto &pair) {
1316                           printer.printOperand(std::get<0>(pair));
1317                           printer << " : ";
1318                           printer.printType(std::get<1>(pair));
1319                         });
1320   printer << ")";
1321 }
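
// For example, these helpers handle the trailing operand list of a
// gpu.launch_func (syntax sketch):
//   gpu.launch_func @kernels::@my_kernel
//       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
//       args(%arg0 : f32, %arg1 : memref<?xf32>)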
1322 
1323 //===----------------------------------------------------------------------===//
1324 // ShuffleOp
1325 //===----------------------------------------------------------------------===//
1326 
1327 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
1328                       int32_t offset, int32_t width, ShuffleMode mode) {
1329   build(builder, result, value,
1330         builder.create<arith::ConstantOp>(result.location,
1331                                           builder.getI32IntegerAttr(offset)),
1332         builder.create<arith::ConstantOp>(result.location,
1333                                           builder.getI32IntegerAttr(width)),
1334         mode);
1335 }
1336 
1337 //===----------------------------------------------------------------------===//
1338 // BarrierOp
1339 //===----------------------------------------------------------------------===//
1340 
1341 namespace {
1342 
1343 /// Erase a gpu.barrier immediately followed by another barrier; the threads are already synchronized.
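     /// For example (illustrative IR), given two adjacent barriers the pattern
     /// erases the first one and keeps the second:
     ///
     ///   gpu.barrier  // erased by this pattern
     ///   gpu.barrier  // kept; still synchronizes the threads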
1344 LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1345                                           PatternRewriter &rewriter) {
1346   if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
1347     rewriter.eraseOp(op);
1348     return success();
1349   }
1350   return failure();
1351 }
1352 
1353 } // end anonymous namespace
1354 
1355 void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1356                                             MLIRContext *context) {
1357   results.add(eraseRedundantGpuBarrierOps);
1358 }
1359 
1360 //===----------------------------------------------------------------------===//
1361 // GPUFuncOp
1362 //===----------------------------------------------------------------------===//
1363 
1364 /// Adds a new block argument that corresponds to buffers located in
1365 /// workgroup memory.
1366 BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1367   auto attrName = getNumWorkgroupAttributionsAttrName();
1368   auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1369   (*this)->setAttr(attrName,
1370                    IntegerAttr::get(attr.getType(), attr.getValue() + 1));
1371   return getBody().insertArgument(
1372       getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1373 }
1374 
1375 /// Adds a new block argument that corresponds to buffers located in
1376 /// private memory.
1377 BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1378   // Buffers in private memory always come after buffers in workgroup
1379   // memory.
1380   return getBody().addArgument(type, loc);
1381 }
1382 
1383 void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
1384                       StringRef name, FunctionType type,
1385                       TypeRange workgroupAttributions,
1386                       TypeRange privateAttributions,
1387                       ArrayRef<NamedAttribute> attrs) {
1388   OpBuilder::InsertionGuard g(builder);
1389 
1390   result.addAttribute(SymbolTable::getSymbolAttrName(),
1391                       builder.getStringAttr(name));
1392   result.addAttribute(getFunctionTypeAttrName(result.name),
1393                       TypeAttr::get(type));
1394   result.addAttribute(getNumWorkgroupAttributionsAttrName(),
1395                       builder.getI64IntegerAttr(workgroupAttributions.size()));
1396   result.addAttributes(attrs);
1397   Region *body = result.addRegion();
1398   Block *entryBlock = builder.createBlock(body);
1399 
1400   // TODO: Allow passing in proper locations here.
1401   for (Type argTy : type.getInputs())
1402     entryBlock->addArgument(argTy, result.location);
1403   for (Type argTy : workgroupAttributions)
1404     entryBlock->addArgument(argTy, result.location);
1405   for (Type argTy : privateAttributions)
1406     entryBlock->addArgument(argTy, result.location);
1407 }
1408 
1409 /// Parses a GPU function memory attribution.
1410 ///
1411 /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1412 ///                        (`private` `(` ssa-id-and-type-list `)`)?
1413 ///
1414 /// Note that this function parses only one of the two similar parts, with the
1415 /// keyword provided as argument.
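     /// For example (illustrative SSA names and types), a workgroup attribution
     /// list may look like:
     ///
     ///   workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)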
1416 static ParseResult
1417 parseAttributions(OpAsmParser &parser, StringRef keyword,
1418                   SmallVectorImpl<OpAsmParser::Argument> &args,
1419                   Attribute &attributionAttrs) {
1420   // If we could not parse the keyword, just assume empty list and succeed.
1421   if (failed(parser.parseOptionalKeyword(keyword)))
1422     return success();
1423 
1424   size_t existingArgs = args.size();
1425   ParseResult result =
1426       parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
1427                                /*allowType=*/true, /*allowAttrs=*/true);
1428   if (failed(result))
1429     return result;
1430 
1431   bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
1432                                [](const OpAsmParser::Argument &arg) -> bool {
1433                                  return arg.attrs && !arg.attrs.empty();
1434                                });
1435   if (!hadAttrs) {
1436     attributionAttrs = nullptr;
1437     return result;
1438   }
1439 
1440   Builder &builder = parser.getBuilder();
1441   SmallVector<Attribute> attributionAttrsVec;
1442   for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1443     if (!argument.attrs)
1444       attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
1445     else
1446       attributionAttrsVec.push_back(argument.attrs);
1447   }
1448   attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1449   return result;
1450 }
1451 
1452 /// Parses a GPU function.
1453 ///
1454 /// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1455 ///                 (`->` function-result-list)? memory-attribution `kernel`?
1456 ///                 function-attributes? region
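     /// For example (illustrative IR):
     ///
     ///   gpu.func @foo(%arg0: index)
     ///       workgroup(%workgroup: memref<32xf32, 3>)
     ///       private(%private: memref<1xf32, 5>)
     ///       kernel {
     ///     gpu.return
     ///   }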
1457 ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
1458   SmallVector<OpAsmParser::Argument> entryArgs;
1459   SmallVector<DictionaryAttr> resultAttrs;
1460   SmallVector<Type> resultTypes;
1461   bool isVariadic;
1462 
1463   // Parse the function name.
1464   StringAttr nameAttr;
1465   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1466                              result.attributes))
1467     return failure();
1468 
1469   auto signatureLocation = parser.getCurrentLocation();
1470   if (failed(function_interface_impl::parseFunctionSignature(
1471           parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1472           resultAttrs)))
1473     return failure();
1474 
1475   if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1476     return parser.emitError(signatureLocation)
1477            << "gpu.func requires named arguments";
1478 
1479   // Construct the function type. More types will be added to the region, but
1480   // not to the function type.
1481   Builder &builder = parser.getBuilder();
1482 
1483   SmallVector<Type> argTypes;
1484   for (auto &arg : entryArgs)
1485     argTypes.push_back(arg.type);
1486   auto type = builder.getFunctionType(argTypes, resultTypes);
1487   result.addAttribute(getFunctionTypeAttrName(result.name),
1488                       TypeAttr::get(type));
1489 
1490   function_interface_impl::addArgAndResultAttrs(
1491       builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1492       getResAttrsAttrName(result.name));
1493 
1494   Attribute workgroupAttributionAttrs;
1495   // Parse workgroup memory attributions.
1496   if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1497                                entryArgs, workgroupAttributionAttrs)))
1498     return failure();
1499 
1500   // Store the number of operands we just parsed as the number of workgroup
1501   // memory attributions.
1502   unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1503   result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1504                       builder.getI64IntegerAttr(numWorkgroupAttrs));
1505   if (workgroupAttributionAttrs)
1506     result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1507                         workgroupAttributionAttrs);
1508 
1509   Attribute privateAttributionAttrs;
1510   // Parse private memory attributions.
1511   if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1512                                entryArgs, privateAttributionAttrs)))
1513     return failure();
1514   if (privateAttributionAttrs)
1515     result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1516                         privateAttributionAttrs);
1517 
1518   // Parse the kernel attribute if present.
1519   if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1520     result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1521                         builder.getUnitAttr());
1522 
1523   // Parse attributes.
1524   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1525     return failure();
1526 
1527   // Parse the region. If no argument names were provided, take all names
1528   // (including those of attributions) from the entry block.
1529   auto *body = result.addRegion();
1530   return parser.parseRegion(*body, entryArgs);
1531 }
1532 
1533 static void printAttributions(OpAsmPrinter &p, StringRef keyword,
1534                               ArrayRef<BlockArgument> values,
1535                               ArrayAttr attributes) {
1536   if (values.empty())
1537     return;
1538 
1539   p << ' ' << keyword << '(';
1540   llvm::interleaveComma(
1541       llvm::enumerate(values), p, [&p, attributes](auto pair) {
1542         BlockArgument v = pair.value();
1543         p << v << " : " << v.getType();
1544 
1545         size_t attributionIndex = pair.index();
1546         DictionaryAttr attrs;
1547         if (attributes && attributionIndex < attributes.size())
1548           attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
1549         if (attrs)
1550           p.printOptionalAttrDict(attrs.getValue());
1551       });
1552   p << ')';
1553 }
1554 
1555 void GPUFuncOp::print(OpAsmPrinter &p) {
1556   p << ' ';
1557   p.printSymbolName(getName());
1558 
1559   FunctionType type = getFunctionType();
1560   function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1561                                                   /*isVariadic=*/false,
1562                                                   type.getResults());
1563 
1564   printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1565                     getWorkgroupAttribAttrs().value_or(nullptr));
1566   printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1567                     getPrivateAttribAttrs().value_or(nullptr));
1568   if (isKernel())
1569     p << ' ' << getKernelKeyword();
1570 
1571   function_interface_impl::printFunctionAttributes(
1572       p, *this,
1573       {getNumWorkgroupAttributionsAttrName(),
1574        GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1575        getArgAttrsAttrName(), getResAttrsAttrName(),
1576        getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1577   p << ' ';
1578   p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1579 }
1580 
1581 static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1582                                           StringAttr attrName) {
1583   auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1584   if (!allAttrs || index >= allAttrs.size())
1585     return DictionaryAttr();
1586   return llvm::cast<DictionaryAttr>(allAttrs[index]);
1587 }
1588 
1589 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1590   return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1591 }
1592 
1593 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1594   return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1595 }
1596 
1597 static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1598                                 DictionaryAttr value, StringAttr attrName) {
1599   MLIRContext *ctx = op.getContext();
1600   auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1601   SmallVector<Attribute> elements;
1602   if (allAttrs)
1603     elements.append(allAttrs.begin(), allAttrs.end());
1604   while (elements.size() <= index)
1605     elements.push_back(DictionaryAttr::get(ctx));
1606   if (!value)
1607     elements[index] = DictionaryAttr::get(ctx);
1608   else
1609     elements[index] = value;
1610   ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1611   op->setAttr(attrName, newValue);
1612 }
1613 
1614 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1615                                              DictionaryAttr value) {
1616   setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1617 }
1618 
1619 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1620                                            DictionaryAttr value) {
1621   setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1622 }
1623 
1624 static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1625                                     StringAttr name, StringAttr attrsName) {
1626   DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1627   if (!dict)
1628     return Attribute();
1629   return dict.get(name);
1630 }
1631 
1632 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1633                                                  StringAttr name) {
1634   assert(index < getNumWorkgroupAttributions() &&
1635          "index must map to a workgroup attribution");
1636   return getAttributionAttr(*this, index, name,
1637                             getWorkgroupAttribAttrsAttrName());
1638 }
1639 
1640 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1641                                                StringAttr name) {
1642   assert(index < getNumPrivateAttributions() &&
1643          "index must map to a private attribution");
1644   return getAttributionAttr(*this, index, name,
1645                             getPrivateAttribAttrsAttrName());
1646 }
1647 
1648 static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1649                                Attribute value, StringAttr attrsName) {
1650   MLIRContext *ctx = op.getContext();
1651   SmallVector<NamedAttribute> elems;
1652   DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1653   if (oldDict)
1654     elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1655 
1656   bool found = false;
1657   bool mustSort = true;
1658   for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1659     if (elems[i].getName() == name) {
1660       found = true;
1661       if (!value) {
1662         std::swap(elems[i], elems[elems.size() - 1]);
1663         elems.pop_back();
1664       } else {
1665         mustSort = false;
1666         elems[i] = NamedAttribute(elems[i].getName(), value);
1667       }
1668       break;
1669     }
1670   }
1671   if (!found) {
1672     if (!value)
1673       return;
1674     elems.emplace_back(name, value);
1675   }
1676   if (mustSort) {
1677     DictionaryAttr::sortInPlace(elems);
1678   }
1679   auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1680   setAttributionAttrs(op, index, newDict, attrsName);
1681 }
1682 
1683 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1684                                             Attribute value) {
1685   assert(index < getNumWorkgroupAttributions() &&
1686          "index must map to a workgroup attribution");
1687   setAttributionAttr(*this, index, name, value,
1688                      getWorkgroupAttribAttrsAttrName());
1689 }
1690 
1691 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1692                                           Attribute value) {
1693   assert(index < getNumPrivateAttributions() &&
1694          "index must map to a private attribution");
1695   setAttributionAttr(*this, index, name, value,
1696                      getPrivateAttribAttrsAttrName());
1697 }
1698 
1699 LogicalResult GPUFuncOp::verifyType() {
1700   if (isKernel() && getFunctionType().getNumResults() != 0)
1701     return emitOpError() << "expected void return type for kernel function";
1702 
1703   return success();
1704 }
1705 
1706 /// Verifies the body of the function.
1707 LogicalResult GPUFuncOp::verifyBody() {
1708   if (empty())
1709     return emitOpError() << "expected body with at least one block";
1710   unsigned numFuncArguments = getNumArguments();
1711   unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1712   unsigned numBlockArguments = front().getNumArguments();
1713   if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1714     return emitOpError() << "expected at least "
1715                          << numFuncArguments + numWorkgroupAttributions
1716                          << " arguments to body region";
1717 
1718   ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1719   for (unsigned i = 0; i < numFuncArguments; ++i) {
1720     Type blockArgType = front().getArgument(i).getType();
1721     if (funcArgTypes[i] != blockArgType)
1722       return emitOpError() << "expected body region argument #" << i
1723                            << " to be of type " << funcArgTypes[i] << ", got "
1724                            << blockArgType;
1725   }
1726 
1727   if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1728                                 GPUDialect::getWorkgroupAddressSpace())) ||
1729       failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1730                                 GPUDialect::getPrivateAddressSpace())))
1731     return failure();
1732 
1733   return success();
1734 }
1735 
1736 //===----------------------------------------------------------------------===//
1737 // ReturnOp
1738 //===----------------------------------------------------------------------===//
1739 
1740 LogicalResult gpu::ReturnOp::verify() {
1741   GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1742 
1743   FunctionType funType = function.getFunctionType();
1744 
1745   if (funType.getNumResults() != getOperands().size())
1746     return emitOpError()
1747         .append("expected ", funType.getNumResults(), " result operands")
1748         .attachNote(function.getLoc())
1749         .append("return type declared here");
1750 
1751   for (const auto &pair : llvm::enumerate(
1752            llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1753     auto [type, operand] = pair.value();
1754     if (type != operand.getType())
1755       return emitOpError() << "unexpected type `" << operand.getType()
1756                            << "' for operand #" << pair.index();
1757   }
1758   return success();
1759 }
1760 
1761 //===----------------------------------------------------------------------===//
1762 // GPUModuleOp
1763 //===----------------------------------------------------------------------===//
1764 
1765 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1766                         StringRef name, ArrayAttr targets,
1767                         Attribute offloadingHandler) {
1768   result.addRegion()->emplaceBlock();
1769   Properties &props = result.getOrAddProperties<Properties>();
1770   if (targets)
1771     props.targets = targets;
1772   props.setSymName(builder.getStringAttr(name));
1773   props.offloadingHandler = offloadingHandler;
1774 }
1775 
1776 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1777                         StringRef name, ArrayRef<Attribute> targets,
1778                         Attribute offloadingHandler) {
1779   build(builder, result, name,
1780         targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1781         offloadingHandler);
1782 }
1783 
1784 bool GPUModuleOp::hasTarget(Attribute target) {
1785   if (ArrayAttr targets = getTargetsAttr())
1786     return llvm::count(targets.getValue(), target);
1787   return false;
1788 }
1789 
1790 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1791   ArrayAttr &targetsAttr = getProperties().targets;
1792   SmallVector<Attribute> targetsVector(targets);
1793   targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1794 }
1795 
1796 //===----------------------------------------------------------------------===//
1797 // GPUBinaryOp
1798 //===----------------------------------------------------------------------===//
1799 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1800                      Attribute offloadingHandler, ArrayAttr objects) {
1801   auto &properties = result.getOrAddProperties<Properties>();
1802   result.attributes.push_back(builder.getNamedAttr(
1803       SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1804   properties.objects = objects;
1805   if (offloadingHandler)
1806     properties.offloadingHandler = offloadingHandler;
1807   else
1808     properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1809 }
1810 
1811 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1812                      Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1813   build(builder, result, name, offloadingHandler,
1814         objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1815 }
1816 
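     /// Parses the optional offloading handler of a `gpu.binary` op, written
     /// between angle brackets before the object list, e.g. (illustrative
     /// syntax):
     ///
     ///   gpu.binary @kernels <#gpu.select_object<1>> [#gpu.object<...>]
     ///
     /// When the handler is absent, a `SelectObjectAttr` with a null target is
     /// used as the default.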
1817 static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1818                                           Attribute &offloadingHandler) {
1819   if (succeeded(parser.parseOptionalLess())) {
1820     if (parser.parseAttribute(offloadingHandler))
1821       return failure();
1822     if (parser.parseGreater())
1823       return failure();
1824   }
1825   if (!offloadingHandler)
1826     offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1827   return success();
1828 }
1829 
1830 static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1831                                    Attribute offloadingHandler) {
1832   if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1833     printer << '<' << offloadingHandler << '>';
1834 }
1835 
1836 //===----------------------------------------------------------------------===//
1837 // GPUMemcpyOp
1838 //===----------------------------------------------------------------------===//
1839 
1840 LogicalResult MemcpyOp::verify() {
1841   auto srcType = getSrc().getType();
1842   auto dstType = getDst().getType();
1843 
1844   if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1845     return emitOpError("arguments have incompatible element type");
1846 
1847   if (failed(verifyCompatibleShape(srcType, dstType)))
1848     return emitOpError("arguments have incompatible shape");
1849 
1850   return success();
1851 }
1852 
1853 namespace {
1854 
1855 /// Erases a common case of copy ops where a destination value is used only by
1856 /// the copy op, alloc and dealloc ops.
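     /// For example (illustrative IR), the buffer below is only allocated,
     /// copied into, and deallocated, so the copy can be erased:
     ///
     ///   %dst = gpu.alloc () : memref<16xf32, 1>
     ///   gpu.memcpy %dst, %src : memref<16xf32, 1>, memref<16xf32>
     ///   gpu.dealloc %dst : memref<16xf32, 1>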
1857 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1858   using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1859 
1860   LogicalResult matchAndRewrite(MemcpyOp op,
1861                                 PatternRewriter &rewriter) const override {
1862     Value dest = op.getDst();
1863     Operation *destDefOp = dest.getDefiningOp();
1864     // `dest` must be defined by an op having Allocate memory effect in order to
1865     // perform the folding.
1866     if (!destDefOp ||
1867         !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1868       return failure();
1869     // We can erase `op` iff `dest` has no other use apart from its
1870     // use by `op` and dealloc ops.
1871     if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1872           return user != op &&
1873                  !hasSingleEffect<MemoryEffects::Free>(user, dest);
1874         }))
1875       return failure();
1876     // We can perform the folding if and only if op has a single async
1877     // dependency and produces an async token as result, or if it does not have
1878     // any async dependency and does not produce any async token result.
1879     if (op.getAsyncDependencies().size() > 1 ||
1880         ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1881          (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1882       return failure();
1883     rewriter.replaceOp(op, op.getAsyncDependencies());
1884     return success();
1885   }
1886 };
1887 
1888 } // end anonymous namespace
1889 
1890 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1891                                            MLIRContext *context) {
1892   results.add<EraseTrivialCopyOp>(context);
1893 }
1894 
1895 //===----------------------------------------------------------------------===//
1896 // GPU_SubgroupMmaLoadMatrixOp
1897 //===----------------------------------------------------------------------===//
1898 
1899 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1900   auto srcType = getSrcMemref().getType();
1901   auto resType = getRes().getType();
1902   auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1903   auto operand = resMatrixType.getOperand();
1904   auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1905 
1906   if (!srcMemrefType.isLastDimUnitStride())
1907     return emitError(
1908         "expected source memref most minor dim to have unit stride");
1909 
1910   if (operand != "AOp" && operand != "BOp" && operand != "COp")
1911     return emitError("only AOp, BOp and COp can be loaded");
1912 
1913   return success();
1914 }
1915 
1916 //===----------------------------------------------------------------------===//
1917 // GPU_SubgroupMmaStoreMatrixOp
1918 //===----------------------------------------------------------------------===//
1919 
1920 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1921   auto srcType = getSrc().getType();
1922   auto dstType = getDstMemref().getType();
1923   auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1924   auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1925 
1926   if (!dstMemrefType.isLastDimUnitStride())
1927     return emitError(
1928         "expected destination memref most minor dim to have unit stride");
1929 
1930   if (srcMatrixType.getOperand() != "COp")
1931     return emitError(
1932         "expected the operand matrix being stored to have 'COp' operand type");
1933 
1934   return success();
1935 }
1936 
1937 //===----------------------------------------------------------------------===//
1938 // GPU_SubgroupMmaComputeOp
1939 //===----------------------------------------------------------------------===//
1940 
1941 LogicalResult SubgroupMmaComputeOp::verify() {
1942   enum OperandMap { A, B, C };
1943   SmallVector<MMAMatrixType, 3> opTypes;
1944   opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1945   opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1946   opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1947 
1948   if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1949       opTypes[C].getOperand() != "COp")
1950     return emitError("operands must be in the order AOp, BOp, COp");
1951 
1952   ArrayRef<int64_t> aShape, bShape, cShape;
1953   aShape = opTypes[A].getShape();
1954   bShape = opTypes[B].getShape();
1955   cShape = opTypes[C].getShape();
1956 
1957   if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1958       bShape[1] != cShape[1])
1959     return emitError("operand shapes do not satisfy matmul constraints");
1960 
1961   return success();
1962 }
1963 
1964 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
1965                              SmallVectorImpl<::mlir::OpFoldResult> &results) {
1966   return memref::foldMemRefCast(*this);
1967 }
1968 
1969 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
1970                              SmallVectorImpl<::mlir::OpFoldResult> &results) {
1971   return memref::foldMemRefCast(*this);
1972 }
1973 
1974 //===----------------------------------------------------------------------===//
1975 // GPU_WaitOp
1976 //===----------------------------------------------------------------------===//
1977 
1978 namespace {
1979 
1980 /// Remove gpu.wait op use of gpu.wait op def without async dependencies.
1981 /// %t = gpu.wait async []       // No async dependencies.
1982 /// ...  gpu.wait ... [%t, ...]  // %t can be removed.
1983 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
1984 public:
1985   using OpRewritePattern::OpRewritePattern;
1986 
1987   LogicalResult matchAndRewrite(WaitOp op,
1988                                 PatternRewriter &rewriter) const final {
1989     auto predicate = [](Value value) {
1990       auto waitOp = value.getDefiningOp<WaitOp>();
1991       return waitOp && waitOp->getNumOperands() == 0;
1992     };
1993     if (llvm::none_of(op.getAsyncDependencies(), predicate))
1994       return failure();
1995     SmallVector<Value> validOperands;
1996     for (Value operand : op->getOperands()) {
1997       if (predicate(operand))
1998         continue;
1999       validOperands.push_back(operand);
2000     }
2001     rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2002     return success();
2003   }
2004 };
2005 
2006 /// Simplify trivial gpu.wait ops for the following patterns.
2007 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2008 /// dependencies).
2009 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2010 /// %t0.
2011 /// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
2012 /// dependencies nor return any token.
2013 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2014 public:
2015   using OpRewritePattern::OpRewritePattern;
2016 
2017   LogicalResult matchAndRewrite(WaitOp op,
2018                                 PatternRewriter &rewriter) const final {
2019     // Erase gpu.wait ops that neither have any async dependencies nor return
2020     // any async token.
2021     if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2022       rewriter.eraseOp(op);
2023       return success();
2024     }
2025     // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2026     if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2027         op.getAsyncToken()) {
2028       rewriter.replaceOp(op, op.getAsyncDependencies());
2029       return success();
2030     }
2031     // Erase %t = gpu.wait async ... ops, where %t has no uses.
2032     if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2033       rewriter.eraseOp(op);
2034       return success();
2035     }
2036     return failure();
2037   }
2038 };
2039 
2040 } // end anonymous namespace
2041 
2042 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2043                                          MLIRContext *context) {
2044   results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2045 }
2046 
2047 //===----------------------------------------------------------------------===//
2048 // GPU_AllocOp
2049 //===----------------------------------------------------------------------===//
2050 
2051 LogicalResult AllocOp::verify() {
2052   auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2053 
2054   if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2055     return emitOpError("dimension operand count does not equal memref "
2056                        "dynamic dimension count");
2057 
2058   unsigned numSymbols = 0;
2059   if (!memRefType.getLayout().isIdentity())
2060     numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2061   if (getSymbolOperands().size() != numSymbols) {
2062     return emitOpError(
2063         "symbol operand count does not equal memref symbol count");
2064   }
2065 
2066   return success();
2067 }
2068 
2069 namespace {
2070 
2071 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2072 /// `memref::AllocOp`.
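     /// For example (illustrative IR):
     ///
     ///   %m = gpu.alloc (%size) : memref<?xf32, 1>
     ///   %d = memref.dim %m, %c0 : memref<?xf32, 1>   // folds to %size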
2073 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2074   using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2075 
2076   LogicalResult matchAndRewrite(memref::DimOp dimOp,
2077                                 PatternRewriter &rewriter) const override {
2078     std::optional<int64_t> index = dimOp.getConstantIndex();
2079     if (!index)
2080       return failure();
2081 
2082     auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2083     if (!memrefType || index.value() >= memrefType.getRank() ||
2084         !memrefType.isDynamicDim(index.value()))
2085       return failure();
2086 
2087     auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2088     if (!alloc)
2089       return failure();
2090 
2091     Value substituteOp = *(alloc.getDynamicSizes().begin() +
2092                            memrefType.getDynamicDimIndex(index.value()));
2093     rewriter.replaceOp(dimOp, substituteOp);
2094     return success();
2095   }
2096 };
2097 
2098 } // namespace
2099 
2100 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2101                                           MLIRContext *context) {
2102   results.add<SimplifyDimOfAllocOp>(context);
2103 }
2104 
2105 //===----------------------------------------------------------------------===//
2106 // GPU object attribute
2107 //===----------------------------------------------------------------------===//
2108 
2109 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2110                                  Attribute target, CompilationTarget format,
2111                                  StringAttr object, DictionaryAttr properties,
2112                                  KernelTableAttr kernels) {
2113   if (!target)
2114     return emitError() << "the target attribute cannot be null";
2115   if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2116     return success();
2117   return emitError() << "the target attribute must implement or promise the "
2118                         "`gpu::TargetAttrInterface`";
2119 }
2120 
2121 namespace {
2122 ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2123                         StringAttr &object) {
2124   std::optional<CompilationTarget> formatResult;
2125   StringRef enumKeyword;
2126   auto loc = odsParser.getCurrentLocation();
2127   if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2128     formatResult = CompilationTarget::Fatbin;
2129   if (!formatResult &&
2130       (formatResult =
2131            gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2132       odsParser.parseEqual())
2133     return odsParser.emitError(loc, "expected an equal sign");
2134   if (!formatResult)
2135     return odsParser.emitError(loc, "expected keyword for GPU object format");
2136   FailureOr<StringAttr> objectResult =
2137       FieldParser<StringAttr>::parse(odsParser);
2138   if (failed(objectResult))
2139     return odsParser.emitError(odsParser.getCurrentLocation(),
2140                                "failed to parse GPU_ObjectAttr parameter "
2141                                "'object' which is to be a `StringAttr`");
2142   format = *formatResult;
2143   object = *objectResult;
2144   return success();
2145 }
2146 
2147 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2148                  StringAttr object) {
2149   if (format != CompilationTarget::Fatbin)
2150     odsParser << stringifyEnum(format) << " = ";
2151   odsParser << object;
2152 }
2153 } // namespace
2154 
2155 //===----------------------------------------------------------------------===//
2156 // GPU select object attribute
2157 //===----------------------------------------------------------------------===//
2158 
2159 LogicalResult
2160 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2161                               Attribute target) {
2162   // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2163   if (target) {
2164     if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2165       if (intAttr.getInt() < 0) {
2166         return emitError() << "the object index must be non-negative";
2167       }
2168     } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2169       return emitError()
2170              << "the target attribute must be a GPU Target attribute";
2171     }
2172   }
2173   return success();
2174 }
2175 
2176 //===----------------------------------------------------------------------===//
2177 // DynamicSharedMemoryOp
2178 //===----------------------------------------------------------------------===//
2179 
2180 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2181   if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2182     return emitOpError() << "must be inside an op with symbol table";
2183 
2184   MemRefType memrefType = getResultMemref().getType();
2185   // Check address space
2186   if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2187     return emitOpError() << "address space must be "
2188                          << gpu::AddressSpaceAttr::getMnemonic() << "<"
2189                          << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2190   }
2191   if (memrefType.hasStaticShape()) {
2192     return emitOpError() << "result memref type must be memref<?xi8, "
2193                             "#gpu.address_space<workgroup>>";
2194   }
2195   return success();
2196 }
2197 
2198 //===----------------------------------------------------------------------===//
2199 // GPU WarpExecuteOnLane0Op
2200 //===----------------------------------------------------------------------===//
2201 
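     // Custom assembly for gpu.warp_execute_on_lane_0. The printed form looks
     // like the following (illustrative IR; the region block argument carries
     // the expanded, warp-wide type, while operands and results carry the
     // distributed per-lane types):
     //
     //   %r = gpu.warp_execute_on_lane_0(%laneid)[32]
     //       args(%v : vector<4xf32>) -> (vector<1xf32>) {
     //   ^bb0(%arg0: vector<128xf32>):
     //     ...
     //     gpu.yield %val : vector<32xf32>
     //   }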
2202 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2203   p << "(" << getLaneid() << ")";
2204 
2205   SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2206   auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2207   p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2208 
2209   if (!getArgs().empty())
2210     p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2211   if (!getResults().empty())
2212     p << " -> (" << getResults().getTypes() << ')';
2213   p << " ";
2214   p.printRegion(getRegion(),
2215                 /*printEntryBlockArgs=*/true,
2216                 /*printBlockTerminators=*/!getResults().empty());
2217   p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2218 }
2219 
2220 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2221                                         OperationState &result) {
2222   // Create the region.
2223   result.regions.reserve(1);
2224   Region *warpRegion = result.addRegion();
2225 
2226   auto &builder = parser.getBuilder();
2227   OpAsmParser::UnresolvedOperand laneId;
2228 
2229   // Parse predicate operand.
2230   if (parser.parseLParen() ||
2231       parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2232       parser.parseRParen())
2233     return failure();
2234 
2235   int64_t warpSize;
2236   if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2237       parser.parseRSquare())
2238     return failure();
2239   result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2240                                                         builder.getContext())),
2241                       builder.getI64IntegerAttr(warpSize));
2242 
2243   if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2244     return failure();
2245 
2246   llvm::SMLoc inputsOperandsLoc;
2247   SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2248   SmallVector<Type> inputTypes;
2249   if (succeeded(parser.parseOptionalKeyword("args"))) {
2250     if (parser.parseLParen())
2251       return failure();
2252 
2253     inputsOperandsLoc = parser.getCurrentLocation();
2254     if (parser.parseOperandList(inputsOperands) ||
2255         parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2256       return failure();
2257   }
2258   if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2259                              result.operands))
2260     return failure();
2261 
2262   // Parse optional results type list.
2263   if (parser.parseOptionalArrowTypeList(result.types))
2264     return failure();
2265   // Parse the region.
2266   if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2267                          /*argTypes=*/{}))
2268     return failure();
2269   WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2270 
2271   // Parse the optional attribute list.
2272   if (parser.parseOptionalAttrDict(result.attributes))
2273     return failure();
2274   return success();
2275 }
2276 
2277 void WarpExecuteOnLane0Op::getSuccessorRegions(
2278     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2279   if (!point.isParent()) {
2280     regions.push_back(RegionSuccessor(getResults()));
2281     return;
2282   }
2283 
2284   // The warp region is always executed
2285   regions.push_back(RegionSuccessor(&getWarpRegion()));
2286 }
2287 
2288 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2289                                  TypeRange resultTypes, Value laneId,
2290                                  int64_t warpSize) {
2291   build(builder, result, resultTypes, laneId, warpSize,
2292         /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
2293 }
2294 
2295 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2296                                  TypeRange resultTypes, Value laneId,
2297                                  int64_t warpSize, ValueRange args,
2298                                  TypeRange blockArgTypes) {
2299   result.addOperands(laneId);
2300   result.addAttribute(getAttributeNames()[0],
2301                       builder.getI64IntegerAttr(warpSize));
2302   result.addTypes(resultTypes);
2303   result.addOperands(args);
2304   assert(args.size() == blockArgTypes.size());
2305   OpBuilder::InsertionGuard guard(builder);
2306   Region *warpRegion = result.addRegion();
2307   Block *block = builder.createBlock(warpRegion);
2308   for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2309     block->addArgument(type, arg.getLoc());
2310 }
2311 
2312 /// Helper check if the distributed vector type is consistent with the expanded
2313 /// type and distributed size.
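     /// For example, with a warp size of 32 an expanded type vector<64x1xf32>
     /// is consistent with the distributed type vector<2x1xf32>: the
     /// per-dimension scaling factors are 32 and 1, and their product equals
     /// the warp size.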
2314 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2315                                            int64_t warpSize, Operation *op) {
2316   // If the types matches there is no distribution.
2317   if (expanded == distributed)
2318     return success();
2319   auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2320   auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2321   if (!expandedVecType || !distributedVecType)
2322     return op->emitOpError("expected vector type for distributed operands.");
2323   if (expandedVecType.getRank() != distributedVecType.getRank() ||
2324       expandedVecType.getElementType() != distributedVecType.getElementType())
2325     return op->emitOpError(
2326         "expected distributed vectors to have same rank and element type.");
2327 
2328   SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2329   for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2330     int64_t eDim = expandedVecType.getDimSize(i);
2331     int64_t dDim = distributedVecType.getDimSize(i);
2332     if (eDim == dDim)
2333       continue;
2334     if (eDim % dDim != 0)
2335       return op->emitOpError()
2336              << "expected expanded vector dimension #" << i << " (" << eDim
2337              << ") to be a multiple of the distributed vector dimension ("
2338              << dDim << ")";
2339     scales[i] = eDim / dDim;
2340   }
2341   if (std::accumulate(scales.begin(), scales.end(), 1,
2342                       std::multiplies<int64_t>()) != warpSize)
2343     return op->emitOpError()
2344            << "incompatible distribution dimensions from " << expandedVecType
2345            << " to " << distributedVecType << " with warp size = " << warpSize;
2346 
2347   return success();
2348 }
2349 
2350 LogicalResult WarpExecuteOnLane0Op::verify() {
2351   if (getArgs().size() != getWarpRegion().getNumArguments())
2352     return emitOpError(
2353         "expected same number of op arguments and block arguments.");
2354   auto yield =
2355       cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
2356   if (yield.getNumOperands() != getNumResults())
2357     return emitOpError(
2358         "expected same number of yield operands and return values.");
2359   int64_t warpSize = getWarpSize();
2360   for (auto [regionArg, arg] :
2361        llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2362     if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2363                                      warpSize, getOperation())))
2364       return failure();
2365   }
2366   for (auto [yieldOperand, result] :
2367        llvm::zip_equal(yield.getOperands(), getResults())) {
2368     if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2369                                      warpSize, getOperation())))
2370       return failure();
2371   }
2372   return success();
2373 }
2374 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2375   return succeeded(
2376       verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2377 }
2378 
2379 //===----------------------------------------------------------------------===//
2380 // GPU KernelMetadataAttr
2381 //===----------------------------------------------------------------------===//
2382 
2383 KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2384                                            DictionaryAttr metadata) {
2385   assert(kernel && "invalid kernel");
2386   return get(kernel.getNameAttr(), kernel.getFunctionType(),
2387              kernel.getAllArgAttrs(), metadata);
2388 }
2389 
2390 KernelMetadataAttr
2391 KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2392                                FunctionOpInterface kernel,
2393                                DictionaryAttr metadata) {
2394   assert(kernel && "invalid kernel");
2395   return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2396                     kernel.getAllArgAttrs(), metadata);
2397 }
2398 
2399 KernelMetadataAttr
2400 KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2401   if (attrs.empty())
2402     return *this;
2403   NamedAttrList attrList;
2404   if (DictionaryAttr dict = getMetadata())
2405     attrList.append(dict);
2406   attrList.append(attrs);
2407   return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2408                                  attrList.getDictionary(getContext()));
2409 }
2410 
2411 LogicalResult
2412 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2413                            StringAttr name, Type functionType,
2414                            ArrayAttr argAttrs, DictionaryAttr metadata) {
2415   if (name.empty())
2416     return emitError() << "the kernel name can't be empty";
2417   if (argAttrs) {
2418     if (llvm::any_of(argAttrs, [](Attribute attr) {
2419           return !llvm::isa<DictionaryAttr>(attr);
2420         }))
2421       return emitError()
2422              << "all attributes in the array must be dictionary attributes";
2423   }
2424   return success();
2425 }
2426 
2427 //===----------------------------------------------------------------------===//
2428 // GPU KernelTableAttr
2429 //===----------------------------------------------------------------------===//
2430 
2431 KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2432                                      ArrayRef<KernelMetadataAttr> kernels,
2433                                      bool isSorted) {
2434   // Note that `is_sorted` is invoked at most once, even with assertions enabled.
2435   assert((!isSorted || llvm::is_sorted(kernels)) &&
2436          "expected a sorted kernel array");
2437   // Immediately return the attribute if the array is sorted.
2438   if (isSorted || llvm::is_sorted(kernels))
2439     return Base::get(context, kernels);
2440   // Sort the array.
2441   SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2442   llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2443   return Base::get(context, kernelsTmp);
2444 }
2445 
2446 KernelTableAttr KernelTableAttr::getChecked(
2447     function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2448     ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2449   // Note that `is_sorted` is invoked at most once, even with assertions enabled.
2450   assert((!isSorted || llvm::is_sorted(kernels)) &&
2451          "expected a sorted kernel array");
2452   // Immediately return the attribute if the array is sorted.
2453   if (isSorted || llvm::is_sorted(kernels))
2454     return Base::getChecked(emitError, context, kernels);
2455   // Sort the array.
2456   SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2457   llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2458   return Base::getChecked(emitError, context, kernelsTmp);
2459 }
2460 
2461 LogicalResult
2462 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2463                         ArrayRef<KernelMetadataAttr> kernels) {
2464   if (kernels.size() < 2)
2465     return success();
2466   // Check that the kernels are uniquely named.
2467   if (std::adjacent_find(kernels.begin(), kernels.end(),
2468                          [](KernelMetadataAttr l, KernelMetadataAttr r) {
2469                            return l.getName() == r.getName();
2470                          }) != kernels.end()) {
2471     return emitError() << "expected all kernels to be uniquely named";
2472   }
2473   return success();
2474 }
2475 
2476 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2477   auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2478   return found ? *iterator : KernelMetadataAttr();
2479 }
2480 
2481 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2482   auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2483   return found ? *iterator : KernelMetadataAttr();
2484 }
2485 
2486 //===----------------------------------------------------------------------===//
2487 // GPU target options
2488 //===----------------------------------------------------------------------===//
2489 
2490 TargetOptions::TargetOptions(
2491     StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2492     StringRef cmdOptions, StringRef elfSection,
2493     CompilationTarget compilationTarget,
2494     function_ref<SymbolTable *()> getSymbolTableCallback,
2495     function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2496     function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2497     function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2498     function_ref<void(StringRef)> isaCallback)
2499     : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
2500                     cmdOptions, elfSection, compilationTarget,
2501                     getSymbolTableCallback, initialLlvmIRCallback,
2502                     linkedLlvmIRCallback, optimizedLlvmIRCallback,
2503                     isaCallback) {}
2504 
2505 TargetOptions::TargetOptions(
2506     TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2507     StringRef cmdOptions, StringRef elfSection,
2508     CompilationTarget compilationTarget,
2509     function_ref<SymbolTable *()> getSymbolTableCallback,
2510     function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2511     function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2512     function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2513     function_ref<void(StringRef)> isaCallback)
2514     : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
2515       cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
2516       compilationTarget(compilationTarget),
2517       getSymbolTableCallback(getSymbolTableCallback),
2518       initialLlvmIRCallback(initialLlvmIRCallback),
2519       linkedLlvmIRCallback(linkedLlvmIRCallback),
2520       optimizedLlvmIRCallback(optimizedLlvmIRCallback),
2521       isaCallback(isaCallback), typeID(typeID) {}
2522 
2523 TypeID TargetOptions::getTypeID() const { return typeID; }
2524 
2525 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2526 
2527 ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
2528   return librariesToLink;
2529 }
2530 
2531 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2532 
2533 StringRef TargetOptions::getELFSection() const { return elfSection; }
2534 
2535 SymbolTable *TargetOptions::getSymbolTable() const {
2536   return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2537 }
2538 
2539 function_ref<void(llvm::Module &)>
2540 TargetOptions::getInitialLlvmIRCallback() const {
2541   return initialLlvmIRCallback;
2542 }
2543 
2544 function_ref<void(llvm::Module &)>
2545 TargetOptions::getLinkedLlvmIRCallback() const {
2546   return linkedLlvmIRCallback;
2547 }
2548 
2549 function_ref<void(llvm::Module &)>
2550 TargetOptions::getOptimizedLlvmIRCallback() const {
2551   return optimizedLlvmIRCallback;
2552 }
2553 
2554 function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2555   return isaCallback;
2556 }
2557 
2558 CompilationTarget TargetOptions::getCompilationTarget() const {
2559   return compilationTarget;
2560 }
2561 
2562 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2563   return CompilationTarget::Fatbin;
2564 }
2565 
2566 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2567 TargetOptions::tokenizeCmdOptions() const {
2568   std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2569   llvm::StringSaver stringSaver(options.first);
2570   StringRef opts = cmdOptions;
2571   // For a correct tokenization of the command line options, `opts` must be
2572   // unquoted; otherwise the tokenization function returns a single token
2573   // containing the whole of `cmdOptions`, which is not the desired behavior.
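       // For example, a `cmdOptions` value of "-O3 -fgpu-rdc" (including the
       // surrounding quotes) is first unquoted and then tokenized into the
       // two tokens "-O3" and "-fgpu-rdc".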
2574   // Remove any quotes if they are at the beginning and end of the string:
2575   if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2576     opts.consume_front("\""), opts.consume_back("\"");
2577   if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2578     opts.consume_front("'"), opts.consume_back("'");
2579 #ifdef _WIN32
2580   llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2581                                        /*MarkEOLs=*/false);
2582 #else
2583   llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2584                                    /*MarkEOLs=*/false);
2585 #endif // _WIN32
2586   return options;
2587 }
2588 
2589 MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2590 
2591 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2592 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2593 
2594 #define GET_ATTRDEF_CLASSES
2595 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2596 
2597 #define GET_OP_CLASSES
2598 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2599 
2600 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
2601