1 //===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h" 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Arith/Utils/Utils.h" 18 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 19 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 20 #include "mlir/Dialect/Linalg/IR/Linalg.h" 21 #include "mlir/Dialect/MemRef/IR/MemRef.h" 22 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" 23 #include "mlir/Dialect/NVGPU/Transforms/Transforms.h" 24 #include "mlir/Dialect/SCF/IR/SCF.h" 25 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 26 #include "mlir/Dialect/Utils/IndexingUtils.h" 27 #include "mlir/Dialect/Utils/StaticValueUtils.h" 28 #include "mlir/Dialect/Vector/IR/VectorOps.h" 29 #include "mlir/IR/AffineExpr.h" 30 #include "mlir/IR/BuiltinTypes.h" 31 #include "mlir/IR/Value.h" 32 #include "llvm/ADT/ArrayRef.h" 33 34 using namespace mlir; 35 using namespace mlir::linalg; 36 using namespace mlir::nvgpu; 37 using namespace mlir::NVVM; 38 using namespace mlir::transform; 39 40 #define DEBUG_TYPE "nvgpu-transforms" 41 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 42 #define DBGSNL() (llvm::dbgs() << "\n") 43 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n") 44 45 //===----------------------------------------------------------------------===// 46 // Apply...ConversionPatternsOp 47 //===----------------------------------------------------------------------===// 48 49 void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns( 50 TypeConverter &typeConverter, RewritePatternSet &patterns) { 51 auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); 52 /// device-side async tokens cannot be materialized in nvvm. We just 53 /// convert them to a dummy i32 type in order to easily drop them during 54 /// conversion. 
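// Note on the address-space mapping below: with the standard NVVM numbering
// (an assumption spelled out here for readability, see NVVMMemorySpace),
// gpu.address_space<global> maps to address space 1,
// gpu.address_space<workgroup> (shared memory) maps to address space 3, and
// private memory maps to the default address space 0.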
55 populateGpuMemorySpaceAttributeConversions( 56 llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned { 57 switch (space) { 58 case gpu::AddressSpace::Global: 59 return static_cast<unsigned>( 60 NVVM::NVVMMemorySpace::kGlobalMemorySpace); 61 case gpu::AddressSpace::Workgroup: 62 return static_cast<unsigned>( 63 NVVM::NVVMMemorySpace::kSharedMemorySpace); 64 case gpu::AddressSpace::Private: 65 return 0; 66 } 67 llvm_unreachable("unknown address space enum value"); 68 return 0; 69 }); 70 llvmTypeConverter.addConversion( 71 [&](nvgpu::DeviceAsyncTokenType type) -> Type { 72 return llvmTypeConverter.convertType( 73 IntegerType::get(type.getContext(), 32)); 74 }); 75 llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type { 76 return llvmTypeConverter.convertType( 77 IntegerType::get(type.getContext(), 64)); 78 }); 79 llvmTypeConverter.addConversion( 80 [&](nvgpu::WarpgroupAccumulatorType type) -> Type { 81 Type elemType = type.getFragmented().getElementType(); 82 int64_t sizeM = type.getFragmented().getDimSize(0); 83 int64_t sizeN = type.getFragmented().getDimSize(1); 84 85 unsigned numMembers; 86 if (elemType.isF32() || elemType.isInteger(32)) 87 numMembers = sizeN / 2; 88 else if (elemType.isF16()) 89 numMembers = sizeN / 4; 90 else 91 llvm_unreachable("unsupported type for warpgroup accumulator"); 92 93 SmallVector<Type> innerStructBody; 94 for (unsigned i = 0; i < numMembers; i++) 95 innerStructBody.push_back(elemType); 96 auto innerStructType = LLVM::LLVMStructType::getLiteral( 97 type.getContext(), innerStructBody); 98 99 SmallVector<Type> structBody; 100 for (int i = 0; i < sizeM; i += kWgmmaSizeM) 101 structBody.push_back(innerStructType); 102 103 auto convertedType = 104 LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); 105 return llvmTypeConverter.convertType(convertedType); 106 }); 107 llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type { 108 return llvmTypeConverter.convertType( 109 getMBarrierMemrefType(type.getContext(), type)); 110 }); 111 llvmTypeConverter.addConversion( 112 [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type { 113 return llvmTypeConverter.convertType( 114 IntegerType::get(type.getContext(), 64)); 115 }); 116 llvmTypeConverter.addConversion( 117 [&](nvgpu::TensorMapDescriptorType type) -> Type { 118 return LLVM::LLVMPointerType::get(type.getContext()); 119 }); 120 populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns); 121 } 122 123 LogicalResult 124 transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter( 125 transform::TypeConverterBuilderOpInterface builder) { 126 if (builder.getTypeConverterType() != "LLVMTypeConverter") 127 return emitOpError("expected LLVMTypeConverter"); 128 return success(); 129 } 130 131 //===---------------------------------------------------------------------===// 132 // CreateAsyncGroupsOp 133 //===---------------------------------------------------------------------===// 134 135 void transform::CreateAsyncGroupsOp::getEffects( 136 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 137 transform::consumesHandle(getTargetMutable(), effects); 138 transform::producesHandle(getOperation()->getOpResults(), effects); 139 transform::modifiesPayload(effects); 140 } 141 142 DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne( 143 TransformRewriter &rewriter, Operation *target, 144 ApplyToEachResultList &results, TransformState &state) { 145 nvgpu::createAsyncGroups(rewriter, target, getBypassL1()); 146 
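// Illustrative sketch (schematic IR, not verbatim) of the rewrite performed
// by createAsyncGroups above: a global-to-shared vector.transfer_read /
// vector.transfer_write pair becomes an asynchronous copy that is grouped and
// awaited, e.g.
//   %token = nvgpu.device_async_copy %global[...], %shared[...], ...
//   %group = nvgpu.device_async_create_group %token
//   nvgpu.device_async_wait %group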
results.push_back(target);
147 return DiagnosedSilenceableFailure::success();
148 }
149
150 //===----------------------------------------------------------------------===//
151 // PipelineSharedMemoryCopiesOp
152 //===----------------------------------------------------------------------===//
153
154 /// Returns true if the given type has the default memory space.
155 static bool hasDefaultMemorySpace(BaseMemRefType type) {
156 return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
157 }
158
159 /// Returns true if the given type has the shared (workgroup) memory space.
160 static bool hasSharedMemorySpace(BaseMemRefType type) {
161 auto space =
162 dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
163 return space &&
164 space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
165 }
166
167 /// Returns the value produced by a load from the default memory space. Returns
168 /// null if the operation is not such a load.
169 static Value getValueLoadedFromGlobal(Operation *op) {
170 // TODO: consider an interface or leveraging the memory effects interface.
171 auto load = dyn_cast<vector::TransferReadOp>(op);
172 if (!load)
173 return nullptr;
174
175 auto loadType = dyn_cast<MemRefType>(load.getSource().getType());
176 if (!loadType || !hasDefaultMemorySpace(loadType))
177 return nullptr;
178 return load;
179 }
180
181 /// Returns true if the operation is storing the given value into shared memory.
182 static bool isStoreToShared(Operation *op, Value v) {
183 // TODO: consider an interface or leveraging the memory effects interface.
184 auto store = dyn_cast<vector::TransferWriteOp>(op);
185 if (!store || store.getVector() != v)
186 return false;
187
188 auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
189 return storeType && hasSharedMemorySpace(storeType);
190 }
191
192 /// Returns true if the operation is a load from the default memory space the
193 /// result of which is only stored into the shared memory space.
194 static bool isLoadFromGlobalStoredToShared(Operation *op) {
195 Value loaded = getValueLoadedFromGlobal(op);
196 if (!loaded || !loaded.hasOneUse())
197 return false;
198
199 return isStoreToShared(*loaded.getUsers().begin(), loaded);
200 }
201
202 /// Populate `ops` with the set of operations that belong to stage 0 of the
203 /// pipelined version of the given loop when pipelining copies to shared memory.
204 /// Specifically, this collects:
205 ///
206 /// 1. all loads from global memory, both sync and async;
207 /// 2. the barriers for async loads.
208 ///
209 /// In particular, barriers are omitted if they do not dominate at least one
210 /// async load for which there is not yet a barrier.
211 static LogicalResult
212 collectStage0PipeliningOps(scf::ForOp forOp,
213 llvm::SmallPtrSet<Operation *, 16> &ops) {
214
215 llvm::SmallPtrSet<Operation *, 4> barriers;
216 for (Operation &op : *forOp.getBody()) {
217 // Bail on nested ops for now.
218 if (op.getNumRegions() > 0) 219 return failure(); 220 221 if (isa<gpu::BarrierOp>(op)) { 222 barriers.insert(&op); 223 continue; 224 } 225 226 if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) { 227 ops.insert(&op); 228 ops.insert(std::make_move_iterator(barriers.begin()), 229 std::make_move_iterator(barriers.end())); 230 assert(barriers.empty() && 231 "expected to have moved the barriers into another set"); 232 continue; 233 } 234 235 if (isLoadFromGlobalStoredToShared(&op)) { 236 ops.insert(&op); 237 continue; 238 } 239 } 240 241 return success(); 242 } 243 244 /// Hook for the loop pipeliner that sets the "num groups in flight" attribute 245 /// of async wait operations corresponding to pipelined shared memory copies. 246 // TODO: this currently assumes that there are no groups that could be in flight 247 // in the existing code. 248 static void 249 setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op, 250 scf::PipeliningOption::PipelinerPart part, 251 unsigned iteration, unsigned depth) { 252 // Based on the order of copies within the loop we need to set the number 253 // of copies in flight, unless it is already set. 254 auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op); 255 if (!waitOp || waitOp.getNumGroups()) 256 return; 257 258 int numGroupInFlight = 0; 259 if (part == scf::PipeliningOption::PipelinerPart::Kernel || 260 part == scf::PipeliningOption::PipelinerPart::Prologue) { 261 numGroupInFlight = depth - 1; 262 } else { 263 // By construction there should be no wait op in the prologue as all the 264 // wait should be in the last stage. 265 assert(part == scf::PipeliningOption::PipelinerPart::Epilogue); 266 // Based on the schedule we pick we know how many groups are in flight for 267 // each iteration of the epilogue. 268 numGroupInFlight = depth - 1 - iteration; 269 } 270 waitOp.setNumGroups(numGroupInFlight); 271 } 272 273 /// Hook for the loop pipeliner that populates `ops` with the stage information 274 /// as follows: 275 /// 276 /// - operations in `stage0Ops` (typically loads from global memory and 277 /// related barriers) are at stage 0; 278 /// - operations in the backward slice of any stage0Ops are all at stage 0; 279 /// - other operations are at stage `depth`; 280 /// - the internal order of the pipelined loop has ops at stage `depth` first, 281 /// then those at stage 0, with relative order within each group preserved. 282 /// 283 static void getPipelineStages( 284 scf::ForOp forOp, 285 std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages, 286 unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) { 287 SetVector<Operation *> dependencies; 288 BackwardSliceOptions options([&](Operation *visited) { 289 return visited->getBlock() == forOp.getBody(); 290 }); 291 options.inclusive = true; 292 for (Operation &op : forOp.getBody()->getOperations()) { 293 if (stage0Ops.contains(&op)) 294 getBackwardSlice(&op, &dependencies, options); 295 } 296 297 for (Operation &op : forOp.getBody()->getOperations()) { 298 if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op)) 299 opsWithPipelineStages.emplace_back(&op, depth); 300 } 301 for (Operation &op : forOp.getBody()->getOperations()) { 302 if (dependencies.contains(&op)) 303 opsWithPipelineStages.emplace_back(&op, 0); 304 } 305 } 306 307 /// Hook for the loop pipeliner. Replaces op with a predicated version and 308 /// returns the resulting operation. Returns the original op if the predication 309 /// isn't necessary for the given op. 
Returns null if predication is needed but 310 /// not supported. 311 static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, 312 Operation *op, Value predicate) { 313 // Some operations may be fine to execute "speculatively" more times than the 314 // original number of iterations, in particular side-effect free operations 315 // and barriers, even if they cannot be predicated. 316 if (isMemoryEffectFree(op) || 317 isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp, 318 nvgpu::DeviceAsyncWaitOp>(op)) { 319 return op; 320 } 321 322 // Otherwise, only async copies can currently be predicated. 323 auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op); 324 if (!asyncCopyOp) 325 return nullptr; 326 327 // Create srcElement Value based on `predicate`. The next lines generate 328 // the following code: 329 // 330 // srcElement = (pred) ? prevSrcElements : 0; 331 // 332 Location loc = asyncCopyOp->getLoc(); 333 Value dstElements = 334 rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr()); 335 Value originalSrcElement = 336 asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements; 337 Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0); 338 auto srcElements = rewriter.create<arith::SelectOp>( 339 loc, predicate, originalSrcElement, c0Index); 340 auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>( 341 loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()), 342 asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(), 343 asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements, 344 UnitAttr()); 345 rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp); 346 return asyncCopyZeroFillOp; 347 } 348 349 /// Applies loop pipelining with the given depth to the given loop so that 350 /// copies into the shared memory are pipelined. Doesn't affect other loops. 351 /// Returns a pair containing the error state and the pipelined op, the latter 352 /// being null in case of any failure. The error state contains a definite error 353 /// if the IR has been modified and a silenceable error otherwise. 
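///
/// As a schematic example, with depth = 2 a loop of the form
///   scf.for %i { async copy to shared memory ; compute }
/// is rewritten so that the copy for iteration %i + 1 is issued while the
/// compute of iteration %i executes, and the nvgpu.device_async_wait ops are
/// annotated with the number of copy groups still in flight (a sketch of the
/// intended schedule, not verbatim IR).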
354 static std::tuple<DiagnosedSilenceableFailure, scf::ForOp> 355 pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth, 356 bool epiloguePeeling) { 357 llvm::SmallPtrSet<Operation *, 16> stage0Ops; 358 if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) { 359 return std::make_tuple( 360 emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"), 361 scf::ForOp()); 362 } 363 if (stage0Ops.empty()) { 364 return std::make_tuple( 365 emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp()); 366 } 367 368 scf::PipeliningOption options; 369 unsigned maxDepth = depth; 370 auto setAnnotation = [&](Operation *op, 371 scf::PipeliningOption::PipelinerPart part, 372 unsigned iteration) { 373 return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth); 374 }; 375 options.getScheduleFn = 376 [&](scf::ForOp schedulingFor, 377 std::vector<std::pair<Operation *, unsigned>> &ops) { 378 if (schedulingFor != forOp) 379 return; 380 return getPipelineStages(forOp, ops, maxDepth, stage0Ops); 381 }; 382 options.annotateFn = setAnnotation; 383 if (!epiloguePeeling) { 384 options.peelEpilogue = false; 385 options.predicateFn = replaceOpWithPredicatedOp; 386 } 387 388 OpBuilder::InsertionGuard guard(rewriter); 389 rewriter.setInsertionPoint(forOp); 390 bool modifiedIR; 391 FailureOr<scf::ForOp> maybePipelined = 392 pipelineForLoop(rewriter, forOp, options, &modifiedIR); 393 if (succeeded(maybePipelined)) { 394 return std::make_tuple(DiagnosedSilenceableFailure::success(), 395 *maybePipelined); 396 } 397 return std::make_tuple( 398 modifiedIR 399 ? DiagnosedSilenceableFailure::definiteFailure() 400 : emitSilenceableFailure(forOp, "pipelining preconditions failed"), 401 scf::ForOp()); 402 } 403 404 DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne( 405 TransformRewriter &rewriter, scf::ForOp forOp, 406 ApplyToEachResultList &results, TransformState &state) { 407 auto [diag, pipelined] = pipelineForSharedCopies( 408 rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue()); 409 if (diag.succeeded()) { 410 results.push_back(pipelined); 411 return DiagnosedSilenceableFailure::success(); 412 } 413 if (diag.isDefiniteFailure()) { 414 auto diag = emitDefiniteFailure("irreversible pipelining failure"); 415 if (!getPeelEpilogue()) { 416 diag.attachNote(forOp->getLoc()) << "couldn't predicate?"; 417 diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName(); 418 } 419 return diag; 420 } 421 422 return std::move(diag); 423 } 424 425 //===----------------------------------------------------------------------===// 426 // RewriteMatmulAsMmaSyncOp 427 //===----------------------------------------------------------------------===// 428 429 /// Helper struct to encode a pair of row/column indexings in the form of 430 /// affine expressions. 431 struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> { 432 RowColIndexing(AffineExpr row, AffineExpr col) 433 : std::pair<AffineExpr, AffineExpr>(row, col) {} 434 435 AffineExpr row() const { return first; }; 436 AffineExpr col() const { return second; }; 437 438 void print(llvm::raw_ostream &os) const { 439 os << "- indexing: " << first << ", " << second; 440 } 441 }; 442 443 /// Helper struct to provide a simple mapping from matmul operations to the 444 /// corresponding mma.sync operation. This is constrained to the case where the 445 /// matmul matches the mma.sync operation 1-1. 
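///
/// For instance (schematic IR, types and operands abridged), a linalg.matmul
/// on 16x8x16 f16 memrefs is rewritten into per-lane memref.load ops that
/// assemble the three vector operands, a single
///   %d = nvgpu.mma.sync(%a, %b, %c) {mmaShape = [16, 8, 16]}
/// and per-lane memref.store ops that write the accumulator back.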
446 struct MmaSyncBuilder { 447 MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId) 448 : b(b), loc(loc), laneId(laneId) {} 449 450 using IndexCalculator = 451 std::function<SmallVector<RowColIndexing>(MLIRContext *)>; 452 453 /// Create the mma.sync operation corresponding to `linalgOp` along with all 454 /// the supporting load/store and vector operations. 455 FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp); 456 457 private: 458 struct MmaSyncInfo { 459 std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns; 460 std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>> 461 vectorShapes; 462 SmallVector<int64_t> mmaShape; 463 bool tf32Enabled; 464 }; 465 466 /// Return the specific index calculator for the given `linalgOp` or failure 467 /// if the op is not supported. This is the toplevel switch that should just 468 /// be Tablegen'd in the future. 469 FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape, 470 TypeRange elementalTypes); 471 472 //===--------------------------------------------------------------------===// 473 // Instruction-specific row, column indexing expression builders. 474 // These should all be declaratively specified via Tablegen in the future. 475 // The Tablegen specification should be as straightforward as possible to 476 // only model the existing size and type combinations. 477 //===--------------------------------------------------------------------===// 478 // 479 // TODO: Tablegen all this. 480 //===--------------------------------------------------------------------===// 481 // m16n8k4 tf32 case. 482 //===--------------------------------------------------------------------===// 483 /// From the NVIDIA doc: 484 /// groupID = %laneid >> 2 485 /// threadIDInGroup = %laneid % 4 486 /// row = groupID for a0 487 /// groupID + 8 for a1 488 /// col = threadIDInGroup 489 static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) { 490 auto dim = getAffineDimExpr(0, ctx); 491 AffineExpr groupID = dim.floorDiv(4); 492 AffineExpr threadIDInGroup = dim % 4; 493 return {RowColIndexing{groupID, threadIDInGroup}, 494 RowColIndexing{groupID + 8, threadIDInGroup}}; 495 } 496 497 /// From the NVIDIA doc: 498 /// groupID = %laneid >> 2 499 /// threadIDInGroup = %laneid % 4 500 /// row = threadIDInGroup 501 /// col = groupID 502 static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) { 503 auto dim = getAffineDimExpr(0, ctx); 504 AffineExpr groupID = dim.floorDiv(4); 505 AffineExpr threadIDInGroup = dim % 4; 506 return {RowColIndexing{threadIDInGroup, groupID}}; 507 } 508 509 /// From the NVIDIA doc: 510 /// groupID = %laneid >> 2 511 /// threadIDInGroup = %laneid % 4 512 /// row = groupID for c0 and c1 513 /// groupID + 8 for c2 and c3 514 /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3} 515 static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) { 516 auto dim = getAffineDimExpr(0, ctx); 517 AffineExpr groupID = dim.floorDiv(4); 518 AffineExpr threadIDInGroup = dim % 4; 519 return {RowColIndexing{groupID, threadIDInGroup * 2 + 0}, 520 RowColIndexing{groupID, threadIDInGroup * 2 + 1}, 521 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, 522 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}}; 523 } 524 525 //===--------------------------------------------------------------------===// 526 // m16n8k16 f16 case. 
527 //===--------------------------------------------------------------------===// 528 /// From the NVIDIA doc: 529 /// groupID = %laneid >> 2 530 /// threadIDInGroup = %laneid % 4 531 /// 532 /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6 533 /// groupID + 8 Otherwise 534 /// 535 /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4 536 /// (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4 537 static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) { 538 auto dim = getAffineDimExpr(0, ctx); 539 AffineExpr groupID = dim.floorDiv(4); 540 AffineExpr threadIDInGroup = dim % 4; 541 // clang-format off 542 return { 543 RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0 544 RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1 545 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2 546 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}, // i == 3 547 RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8}, // i == 4 548 RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8}, // i == 5 549 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6 550 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8} // i == 7 551 }; 552 // clang-format on 553 } 554 555 /// From the NVIDIA doc: 556 /// groupID = %laneid >> 2 557 /// threadIDInGroup = %laneid % 4 558 /// 559 /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2 560 /// (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2 561 /// 562 /// col = groupID 563 static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) { 564 auto dim = getAffineDimExpr(0, ctx); 565 AffineExpr groupID = dim.floorDiv(4); 566 AffineExpr threadIDInGroup = dim % 4; 567 // clang-format off 568 return { 569 RowColIndexing{threadIDInGroup * 2 + 0, groupID}, // i == 0 570 RowColIndexing{threadIDInGroup * 2 + 1, groupID}, // i == 1 571 RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2 572 RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID} // i == 3 573 }; 574 // clang-format on 575 } 576 577 /// From the NVIDIA doc: 578 /// groupID = %laneid >> 2 579 /// threadIDInGroup = %laneid % 4 580 /// 581 /// row = groupID for ci where i < 2 582 /// groupID + 8 for ci where i >= 2 583 /// 584 /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3} 585 static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) { 586 auto dim = getAffineDimExpr(0, ctx); 587 AffineExpr groupID = dim.floorDiv(4); 588 AffineExpr threadIDInGroup = dim % 4; 589 // clang-format off 590 return { 591 RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0 592 RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1 593 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2 594 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1} // i == 3 595 }; 596 // clang-format on 597 } 598 599 //===--------------------------------------------------------------------===// 600 /// Helper functions to create customizable load and stores operations. The 601 /// specific shapes of each MMA instruction are passed via the 602 /// IndexCalculator callback. 603 //===--------------------------------------------------------------------===// 604 /// Build a list of memref.load operations indexed at `(row, col)` indices 605 /// that make sense for a particular MMA instruction and specified via the 606 /// IndexCalculator callback. 
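/// As a worked example for the m16n8k4 tf32 LHS fragment: lane 5 has
/// groupID = 5 floordiv 4 = 1 and threadIDInGroup = 5 mod 4 = 1, so it loads
/// the two LHS elements at (row, col) = (1, 1) and (9, 1).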
607 SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
608 OpFoldResult laneId, Value memref,
609 const IndexCalculator &indexFn);
610
611 /// Perform a distributed load of a vector operand of `vectorShape` for a
612 /// particular MMA instruction whose `(row, col)` indices are specified via
613 /// the IndexCalculator callback. Each `laneId` loads the subportion of the
614 /// data that makes sense for the particular MMA operation.
615 /// The `vectorShape` matches the existing NVGPU dialect op specification but
616 /// could also be flattened in the future if needed for simplification.
617 Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
618 OpFoldResult laneId, Value memref,
619 IndexCalculator indexFn,
620 ArrayRef<int64_t> vectorShape);
621
622 /// Build a list of memref.store operations indexed at `(row, col)` indices
623 /// that make sense for a particular MMA instruction and specified via the
624 /// IndexCalculator callback.
625 SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
626 ValueRange toStore,
627 OpFoldResult laneId, Value memref,
628 const IndexCalculator &indexFn);
629
630 /// Perform a distributed store of a vector operand of `vectorShape` for a
631 /// particular MMA instruction whose `(row, col)` indices are specified via
632 /// the IndexCalculator callback. Each `laneId` stores the subportion of the
633 /// data that makes sense for the particular MMA operation.
634 /// The `vectorShape` matches the existing NVGPU dialect op specification but
635 /// could also be flattened in the future if needed for simplification.
636 SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
637 OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
638 Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
639
640 OpBuilder &b;
641 Location loc;
642 OpFoldResult laneId;
643 };
644
645 //===--------------------------------------------------------------------===//
646 /// Helper functions to create customizable load and store operations. The
647 /// specific shapes of each MMA instruction are passed via the
648 /// IndexCalculator callback.
649 //===--------------------------------------------------------------------===// 650 651 template <typename ApplyFn, typename ReduceFn> 652 static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn, 653 ReduceFn reduceFn) { 654 VectorType vectorType = cast<VectorType>(vector.getType()); 655 auto vectorShape = vectorType.getShape(); 656 auto strides = computeStrides(vectorShape); 657 for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) { 658 auto indices = delinearize(idx, strides); 659 reduceFn(applyFn(vector, idx, indices), idx, indices); 660 } 661 } 662 663 SmallVector<Value> 664 MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc, 665 OpFoldResult laneId, Value memref, 666 const IndexCalculator &indexFn) { 667 auto aff = [&](AffineExpr e) { 668 return affine::makeComposedFoldedAffineApply(b, loc, e, laneId); 669 }; 670 SmallVector<Value> res; 671 SmallVector<RowColIndexing> indexings = indexFn(b.getContext()); 672 for (auto indexing : indexings) { 673 Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row())); 674 Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col())); 675 auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col}); 676 res.push_back(load); 677 } 678 return res; 679 } 680 681 Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( 682 OpBuilder &b, Location loc, OpFoldResult laneId, Value memref, 683 IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) { 684 auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn)); 685 686 Type elementType = getElementTypeOrSelf(memref.getType()); 687 auto vt = VectorType::get(vectorShape, elementType); 688 Value res = b.create<vector::SplatOp>(loc, vt, loads[0]); 689 foreachIndividualVectorElement( 690 res, 691 /*applyFn=*/ 692 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { 693 return loads[linearIdx]; 694 }, 695 /*reduceFn=*/ 696 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { 697 res = b.create<vector::InsertOp>(loc, v, res, indices); 698 }); 699 700 return res; 701 } 702 703 SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores( 704 OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId, 705 Value memref, const IndexCalculator &indexFn) { 706 auto aff = [&](AffineExpr e) { 707 return affine::makeComposedFoldedAffineApply(b, loc, e, laneId); 708 }; 709 SmallVector<Operation *> res; 710 for (auto [indexing, val] : 711 llvm::zip_equal(indexFn(b.getContext()), toStore)) { 712 Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row())); 713 Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col())); 714 Operation *store = 715 b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col}); 716 res.push_back(store); 717 } 718 return res; 719 } 720 721 SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand( 722 OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId, 723 Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) { 724 SmallVector<Value> toStore; 725 toStore.reserve(32); 726 foreachIndividualVectorElement( 727 vectorToStore, 728 /*applyFn=*/ 729 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { 730 return b.create<vector::ExtractOp>(loc, vectorToStore, indices); 731 }, 732 /*reduceFn=*/ 733 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { 734 toStore.push_back(v); 735 }); 736 return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn)); 737 } 738 739 static 
std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, 740 SmallVector<int64_t>> 741 makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, 742 ArrayRef<int64_t> res) { 743 SmallVector<int64_t> vlhs(lhs); 744 SmallVector<int64_t> vrhs(rhs); 745 SmallVector<int64_t> vres(res); 746 return std::make_tuple(vlhs, vrhs, vres); 747 } 748 749 FailureOr<MmaSyncBuilder::MmaSyncInfo> 750 MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape, 751 TypeRange elementalTypes) { 752 // TODO: Tablegen all this. 753 Type f16 = b.getF16Type(); 754 Type f32 = b.getF32Type(); 755 if (opShape == ArrayRef<int64_t>{16, 8, 4} && 756 elementalTypes == TypeRange{f32, f32, f32}) { 757 return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs, 758 &MmaSyncBuilder::m16n8k4tf32Rhs, 759 &MmaSyncBuilder::m16n8k4tf32Res), 760 makeVectorShapes({2, 1}, {1, 1}, {2, 2}), 761 SmallVector<int64_t>{opShape}, 762 /*tf32Enabled=*/true}; 763 } 764 // This is the version with f16 accumulation. 765 // TODO: version with f32 accumulation. 766 if (opShape == ArrayRef<int64_t>{16, 8, 16} && 767 elementalTypes == TypeRange{f16, f16, f16}) { 768 return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs, 769 &MmaSyncBuilder::m16n8k16f16Rhs, 770 &MmaSyncBuilder::m16n8k16f16Res), 771 makeVectorShapes({4, 2}, {2, 2}, {2, 2}), 772 SmallVector<int64_t>{opShape}, 773 /*tf32Enabled=*/false}; 774 } 775 return failure(); 776 } 777 778 FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { 779 Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get(); 780 Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get(); 781 Value resMemRef = linalgOp.getDpsInitOperand(0)->get(); 782 assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 && 783 "expected lhs to be a 2D memref"); 784 assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 && 785 "expected rhs to be a 2D memref"); 786 assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 && 787 "expected res to be a 2D memref"); 788 789 int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0]; 790 int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1]; 791 int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1]; 792 Type lhsType = getElementTypeOrSelf(lhsMemRef.getType()); 793 Type rhsType = getElementTypeOrSelf(rhsMemRef.getType()); 794 Type resType = getElementTypeOrSelf(resMemRef.getType()); 795 796 FailureOr<MmaSyncInfo> maybeInfo = 797 getIndexCalculators({m, n, k}, {lhsType, rhsType, resType}); 798 if (failed(maybeInfo)) 799 return failure(); 800 801 MmaSyncInfo info = *maybeInfo; 802 auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns; 803 auto [lhsShape, rhsShape, resShape] = info.vectorShapes; 804 Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef, 805 lhsIndexFn, lhsShape); 806 Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef, 807 rhsIndexFn, rhsShape); 808 Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef, 809 resIndexFn, resShape); 810 res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape, 811 info.tf32Enabled); 812 buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn, 813 resShape); 814 return res.getDefiningOp(); 815 } 816 817 DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne( 818 transform::TransformRewriter &rewriter, LinalgOp linalgOp, 819 transform::ApplyToEachResultList &results, 820 transform::TransformState &state) { 821 bool fail = true; 822 // TODO: more robust detection of matmulOp, 
with transposes etc.
823 if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
824 // Do not let matmul ops with extended semantics (e.g. user-defined
825 // indexing maps) through this transform.
826 if (linalgOp.hasUserDefinedMaps()) {
827 return emitSilenceableError()
828 << "only matmul ops with non-extended semantics are supported";
829 }
830 Location loc = linalgOp.getLoc();
831 // TODO: more robust computation of laneId, for now assume a single warp.
832 Value laneId = rewriter.create<gpu::ThreadIdOp>(
833 loc, rewriter.getIndexType(), gpu::Dimension::x);
834 if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
835 fail = false;
836 }
837
838 if (fail) {
839 DiagnosedSilenceableFailure diag = emitSilenceableError()
840 << "unsupported target op: " << linalgOp;
841 diag.attachNote(linalgOp->getLoc()) << "target op";
842 return diag;
843 }
844
845 rewriter.eraseOp(linalgOp);
846 return DiagnosedSilenceableFailure::success();
847 }
848
849 //===----------------------------------------------------------------------===//
850 // Hopper builders.
851 //===----------------------------------------------------------------------===//
852
853 /// Helper to create the base Hopper-specific operations that are reused in
854 /// various other places.
855 struct HopperBuilder {
856 HopperBuilder(RewriterBase &rewriter, Location loc)
857 : rewriter(rewriter), loc(loc) {}
858
859 TypedValue<nvgpu::MBarrierGroupType>
860 buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
861
862 /// Create tma descriptor op to initiate transfer from global to shared
863 /// memory. This must be done before the launch op, on the host.
864 TypedValue<nvgpu::TensorMapDescriptorType>
865 buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
866 gpu::LaunchOp launchOp);
867
868 /// Build a tma load from global memory to shared memory using `barrier` to
869 /// synchronize. Return the number of bytes that will be transferred.
870 OpFoldResult
871 buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
872 TypedValue<MemRefType> sharedMemref,
873 TypedValue<nvgpu::MBarrierGroupType> barrier,
874 SmallVectorImpl<Operation *> &loadOps);
875 void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
876 ArrayRef<OpFoldResult> sizes);
877
878 /// If threadIdx.x == 0, performs the TMA request + wait; else just waits.
879 /// Return the operations that perform the transfer on thread0.
880 // TODO: In the future, don't hardcode to thread 0 but elect a leader.
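// Schematically (pseudo-IR, not verbatim), the generated structure is:
//   %tidx = gpu.thread_id x
//   %cond = arith.cmpi eq, %tidx, %c0
//   scf.if %cond {
//     nvgpu.tma.async.load ...   // one per descriptor/buffer pair
//     nvgpu.mbarrier.arrive.expect_tx %barrier, <total bytes>
//   } else {
//     nvgpu.mbarrier.arrive.expect_tx %barrier, %c0
//   }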
881 SmallVector<Operation *> buildPredicateLoadsOnThread0( 882 ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors, 883 ArrayRef<TypedValue<MemRefType>> sharedMemBuffers, 884 TypedValue<nvgpu::MBarrierGroupType> barrier); 885 886 void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier); 887 888 RewriterBase &rewriter; 889 Location loc; 890 }; 891 892 SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0( 893 ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors, 894 ArrayRef<TypedValue<MemRefType>> sharedMemBuffers, 895 TypedValue<nvgpu::MBarrierGroupType> barrier) { 896 SmallVector<Operation *> loadOps; 897 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 898 Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x); 899 Value cond = 900 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero); 901 // clang-format off 902 rewriter.create<scf::IfOp>( 903 /*location=*/loc, 904 /*conditional=*/cond, 905 /*thenBuilder=*/ 906 [&](OpBuilder &lb, Location loc) { 907 SmallVector<OpFoldResult> sizes; 908 sizes.reserve(globalDescriptors.size()); 909 for (auto [desc, shmem] : llvm::zip_equal( 910 globalDescriptors, sharedMemBuffers)) { 911 OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps); 912 sizes.push_back(sz); 913 } 914 // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load. 915 // This may or may not have perf implications. 916 buildBarrierArriveTx(barrier, sizes); 917 rewriter.create<scf::YieldOp>(loc); 918 }, 919 /*elseBuilder=*/ 920 [&](OpBuilder &lb, Location loc) { 921 // TODO: is this for no-thread divergence? 922 // Should we just yield the size and hoist? 923 buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0)); 924 rewriter.create<scf::YieldOp>(loc); 925 }); 926 // clang-format on 927 return loadOps; 928 } 929 930 static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) { 931 return gpu::AddressSpaceAttr::get( 932 b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace()); 933 // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace)); 934 } 935 936 TypedValue<nvgpu::MBarrierGroupType> 937 HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) { 938 auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); 939 Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>( 940 loc, 941 nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace)); 942 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 943 rewriter.create<nvgpu::MBarrierInitOp>( 944 loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), 945 zero, Value()); 946 rewriter.create<gpu::BarrierOp>(loc); 947 return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier); 948 } 949 950 TypedValue<nvgpu::TensorMapDescriptorType> 951 HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref, 952 gpu::LaunchOp launchOp) { 953 OpBuilder::InsertionGuard guard(rewriter); 954 rewriter.setInsertionPoint(launchOp); 955 Value unrankedMemRef = rewriter.create<memref::CastOp>( 956 loc, 957 UnrankedMemRefType::get(memref.getType().getElementType(), 958 memref.getType().getMemorySpace()), 959 memref); 960 SmallVector<OpFoldResult> mixedSizes = 961 memref::getMixedSizes(rewriter, loc, memref); 962 SmallVector<Value> sizes = 963 getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes); 964 965 auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); 966 Value 
desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
967 loc,
968 nvgpu::TensorMapDescriptorType::get(
969 rewriter.getContext(),
970 MemRefType::Builder(memref.getType())
971 .setMemorySpace(sharedMemorySpace),
972 TensorMapSwizzleKind::SWIZZLE_NONE,
973 TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
974 TensorMapInterleaveKind::INTERLEAVE_NONE),
975 unrankedMemRef, sizes);
976 return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
977 }
978
979 OpFoldResult HopperBuilder::buildTmaAsyncLoad(
980 TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
981 TypedValue<MemRefType> sharedMemref,
982 TypedValue<nvgpu::MBarrierGroupType> barrier,
983 SmallVectorImpl<Operation *> &loadOps) {
984 MLIRContext *ctx = rewriter.getContext();
985 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
986 Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
987 loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
988 Value(), Value());
989 loadOps.push_back(loadOp);
990 auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
991 SmallVector<AffineExpr> symbols(mixedSizes.size());
992 bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
993 AffineExpr prodExprInBytes =
994 computeProduct(ctx, symbols) *
995 (sharedMemref.getType().getElementTypeBitWidth() / 8);
996 auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
997 prodExprInBytes, mixedSizes);
998 return res;
999 }
1000
1001 void HopperBuilder::buildBarrierArriveTx(
1002 TypedValue<nvgpu::MBarrierGroupType> barrier,
1003 ArrayRef<OpFoldResult> mixedSizes) {
1004 assert(!mixedSizes.empty() && "expected non-empty sizes");
1005 MLIRContext *ctx = rewriter.getContext();
1006 SmallVector<AffineExpr> symbols(mixedSizes.size());
1007 bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
1008 AffineExpr sumExpr = computeSum(ctx, symbols);
1009 OpFoldResult size =
1010 affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
1011 Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1012 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1013 rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
1014 Value());
1015 }
1016
1017 void HopperBuilder::buildTryWaitParity(
1018 TypedValue<nvgpu::MBarrierGroupType> barrier) {
1019 Type i1 = rewriter.getI1Type();
1020 Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
1021 // 10M is an arbitrary number of ticks to wait before retrying; it is chosen
1022 // to be neither too small nor too big.
1023 // TODO: hoist this into a default dialect constant.
1024 Value ticksBeforeRetry =
1025 rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
1026 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1027 rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
1028 ticksBeforeRetry, zero);
1029 }
1030
1031 //===----------------------------------------------------------------------===//
1032 // RewriteCopyAsTmaOp
1033 //===----------------------------------------------------------------------===//
1034
1035 /// Helper to create the tma operations corresponding to `linalg::CopyOp`.
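///
/// Schematically (pseudo-IR, not verbatim), a linalg.copy from a global
/// memref into a workgroup memref becomes:
///   %desc = nvgpu.tma.create.descriptor %global ...   (host side, before gpu.launch)
///   %barrier = nvgpu.mbarrier.create ...
///   nvgpu.mbarrier.init %barrier[%c0], %num_threads
///   nvgpu.tma.async.load %desc[...], %barrier[%c0] to %shared
///   nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %num_bytes
///   nvgpu.mbarrier.try_wait.parity %barrier[%c0], ...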
1036 struct CopyBuilder : public HopperBuilder { 1037 CopyBuilder(RewriterBase &rewriter, Location loc) 1038 : HopperBuilder(rewriter, loc) {} 1039 1040 SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps); 1041 }; 1042 1043 SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) { 1044 MLIRContext *ctx = rewriter.getContext(); 1045 if (copyOps.empty()) 1046 return SmallVector<Operation *>(); 1047 1048 auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>(); 1049 assert(launchOp && "expected launch op"); 1050 1051 // 1. Init a barrier object in shared memory. 1052 OpBuilder::InsertionGuard g(rewriter); 1053 rewriter.setInsertionPoint(copyOps.front()); 1054 AffineExpr bx, by, bz; 1055 bindSymbols(ctx, bx, by, bz); 1056 AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz}); 1057 OpFoldResult numThreads = affine::makeComposedFoldedAffineApply( 1058 rewriter, loc, prod, 1059 ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(), 1060 launchOp.getBlockSizeZ()}); 1061 1062 TypedValue<nvgpu::MBarrierGroupType> barrier = 1063 buildAndInitBarrierInSharedMemory(numThreads); 1064 1065 SmallVector<TypedValue<MemRefType>> shmems; 1066 SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs; 1067 for (Operation *op : copyOps) { 1068 auto copyOp = cast<linalg::CopyOp>(op); 1069 auto inMemRef = 1070 cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get()); 1071 assert(inMemRef.getType().getRank() == 2 && 1072 "expected in to be a 2D memref"); 1073 1074 // 2. Build global memory descriptor. 1075 TypedValue<nvgpu::TensorMapDescriptorType> globalDesc = 1076 buildGlobalMemRefDescriptor(inMemRef, launchOp); 1077 globalDescs.push_back(globalDesc); 1078 1079 // 3. Shared memory and descriptor for the tmp array. 1080 auto shmem = 1081 cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get()); 1082 shmems.push_back(shmem); 1083 } 1084 1085 // 4. Load in from global memory to shared memory using tma. 1086 OpBuilder::InsertionGuard g2(rewriter); 1087 rewriter.setInsertionPoint(copyOps.front()); 1088 SmallVector<Operation *> results = 1089 buildPredicateLoadsOnThread0(globalDescs, shmems, barrier); 1090 1091 // 5. Spin-loop until data is ready. 1092 buildTryWaitParity(barrier); 1093 1094 // 6. Erase the ops that have now been rewritten. 
1095 for (Operation *op : copyOps) 1096 rewriter.eraseOp(op); 1097 1098 return results; 1099 } 1100 1101 DiagnosedSilenceableFailure 1102 transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter, 1103 transform::TransformResults &results, 1104 transform::TransformState &state) { 1105 auto payloadOps = state.getPayloadOps(getTarget()); 1106 gpu::LaunchOp commonLaunchOp; 1107 Operation *firstOp, *failingOp; 1108 if (llvm::any_of(payloadOps, [&](Operation *op) { 1109 if (!commonLaunchOp) { 1110 commonLaunchOp = op->getParentOfType<gpu::LaunchOp>(); 1111 firstOp = op; 1112 } 1113 auto fail = !op->getParentOfType<gpu::LaunchOp>() || 1114 commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() || 1115 !isa<linalg::CopyOp>(op); 1116 if (fail) 1117 failingOp = op; 1118 return fail; 1119 })) { 1120 DiagnosedSilenceableFailure diag = 1121 emitSilenceableError() 1122 << "target ops must be linalg::CopyOp nested under a common " 1123 "gpu.LaunchOp to be rewritten because the tma descriptors need to " 1124 "be created on the host.\nBut got: " 1125 << *firstOp << "\nand " << *failingOp; 1126 return diag; 1127 } 1128 1129 // TODO: more robust detection of copy, with transposes etc. 1130 CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps)); 1131 1132 return DiagnosedSilenceableFailure::success(); 1133 } 1134 1135 //===----------------------------------------------------------------------===// 1136 // Transform op registration 1137 //===----------------------------------------------------------------------===// 1138 1139 namespace { 1140 class NVGPUTransformDialectExtension 1141 : public transform::TransformDialectExtension< 1142 NVGPUTransformDialectExtension> { 1143 public: 1144 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension) 1145 1146 NVGPUTransformDialectExtension() { 1147 declareGeneratedDialect<arith::ArithDialect>(); 1148 declareGeneratedDialect<affine::AffineDialect>(); 1149 declareGeneratedDialect<nvgpu::NVGPUDialect>(); 1150 declareGeneratedDialect<NVVM::NVVMDialect>(); 1151 declareGeneratedDialect<vector::VectorDialect>(); 1152 registerTransformOps< 1153 #define GET_OP_LIST 1154 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc" 1155 >(); 1156 } 1157 }; 1158 } // namespace 1159 1160 #define GET_OP_CLASSES 1161 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc" 1162 1163 void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry ®istry) { 1164 registry.addExtensions<NVGPUTransformDialectExtension>(); 1165 } 1166
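// Example (schematic) use of the ops registered above from a transform
// script; the exact assembly formats are defined in NVGPUTransformOps.td and
// the transform dialect tests, so treat this snippet as a sketch only:
//   %grouped = transform.nvgpu.create_async_groups %func ...
//   %piped   = transform.nvgpu.pipeline_shared_memory_copies %for ...
//   transform.nvgpu.rewrite_matmul_as_mma_sync %matmul ...
//   transform.nvgpu.rewrite_copy_as_tma %copies ...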