1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering of vector operations to GPU dialect ops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" 14 15 #include <type_traits> 16 17 #include "mlir/Analysis/SliceAnalysis.h" 18 #include "mlir/Dialect/Affine/IR/AffineOps.h" 19 #include "mlir/Dialect/Arith/IR/Arith.h" 20 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 21 #include "mlir/Dialect/MemRef/IR/MemRef.h" 22 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" 23 #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h" 24 #include "mlir/Dialect/SCF/IR/SCF.h" 25 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 26 #include "mlir/Dialect/Vector/IR/VectorOps.h" 27 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 28 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 29 #include "mlir/IR/Builders.h" 30 #include "mlir/IR/BuiltinOps.h" 31 #include "mlir/IR/Region.h" 32 #include "mlir/Pass/Pass.h" 33 #include "mlir/Support/LogicalResult.h" 34 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 35 #include "mlir/Transforms/Passes.h" 36 #include "llvm/ADT/STLExtras.h" 37 #include "llvm/ADT/TypeSwitch.h" 38 39 #define DEBUG_TYPE "vector-to-gpu" 40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 41 #define DBGSNL() (llvm::dbgs() << "\n") 42 43 namespace mlir { 44 #define GEN_PASS_DEF_CONVERTVECTORTOGPU 45 #include "mlir/Conversion/Passes.h.inc" 46 } // namespace mlir 47 48 using namespace mlir; 49 50 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an 51 /// AffineMap representing offsets to apply to indices, the function fills 52 /// `indices` with the original indices plus the offsets. The offsets are 53 /// applied by taking into account the permutation map of the transfer op. If 54 /// the `offsetMap` has dimension placeholders, those should be provided in 55 /// `dimValues`. 56 template <typename TransferOpType> 57 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, 58 AffineMap offsetMap, ArrayRef<Value> dimValues, 59 SmallVector<Value, 4> &indices) { 60 indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end()); 61 Location loc = xferOp.getLoc(); 62 unsigned offsetsIdx = 0; 63 for (auto expr : xferOp.getPermutationMap().getResults()) { 64 if (auto dim = dyn_cast<AffineDimExpr>(expr)) { 65 Value prevIdx = indices[dim.getPosition()]; 66 SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end()); 67 dims.push_back(prevIdx); 68 AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims()); 69 indices[dim.getPosition()] = affine::makeComposedAffineApply( 70 rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims); 71 continue; 72 } 73 } 74 } 75 76 // Return true if the contract op can be convert to MMA matmul. 
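// Illustrative shapes of the accepted indexing maps (an illustration of the
// checks below, not an exhaustive list): the wmma path expects a plain
// row-major matmul,
//   affine_map<(m, n, k) -> (m, k)>   // A
//   affine_map<(m, n, k) -> (k, n)>   // B
//   affine_map<(m, n, k) -> (m, n)>   // C
// while the mma.sync path expects B in transposed (n, k) form, i.e. the MMT
// shape produced by populateVectorContractCanonicalizeMatmulToMMT.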
77 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
78                                            bool useNvGpu) {
79   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
81   AffineExpr m, n, k;
82   bindDims(contract.getContext(), m, n, k);
83   auto iteratorTypes = contract.getIteratorTypes().getValue();
84   if (!(vector::isParallelIterator(iteratorTypes[0]) &&
85         vector::isParallelIterator(iteratorTypes[1]) &&
86         vector::isReductionIterator(iteratorTypes[2])))
87     return false;
88 
89   // The contraction needs to represent a matmul to be convertible to an
90   // MMAMatrix matmul.
91   if (!useNvGpu &&
92       contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
93     return false;
94   if (useNvGpu &&
95       contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
96     return false;
97 
98   return true;
99 }
100 
101 // Return true if the given map represents a transposed matrix load,
102 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
103 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
104   MLIRContext *ctx = permutationMap.getContext();
105   // Local OpBuilder is fine here, we just build attributes.
106   OpBuilder b(ctx);
107   auto nDim = permutationMap.getNumDims();
108   AffineExpr zero = b.getAffineConstantExpr(0);
109   if (nDim < 2) {
110     // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
111     AffineExpr dim0 = b.getAffineDimExpr(0);
112     return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
113   }
114 
115   AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
116   AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
117   // Support both transposed and transposed+broadcasted cases.
118   return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
119          permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
120 }
121 
122 // Return the stride of the second-to-last dimension of |type| if it is a
123 // memref and has a constant stride.
124 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
125   auto memrefType = dyn_cast<MemRefType>(type);
126   if (!memrefType)
127     return std::nullopt;
128   // If the memref is 0-D or 1-D, the horizontal stride is 0.
129   if (memrefType.getRank() < 2)
130     return 0;
131   int64_t offset = 0;
132   SmallVector<int64_t, 2> strides;
133   if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
134       strides.back() != 1)
135     return std::nullopt;
136   int64_t stride = strides[strides.size() - 2];
137   if (stride == ShapedType::kDynamic)
138     return std::nullopt;
139   return stride;
140 }
141 
142 // Return true if the transfer op can be converted to an MMA matrix load.
143 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
144   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
145       readOp.getVectorType().getRank() != 2)
146     return false;
147   if (!getStaticallyKnownRowStride(readOp.getShapedType()))
148     return false;
149 
150   // Only allow integer types if the signedness can be inferred.
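  // For example, an i8 transfer_read is only accepted when its single user is
  // an arith.extsi or arith.extui, since that extension is what later tells us
  // whether to build a signed or unsigned MMA fragment type.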
151 if (readOp.getVectorType().getElementType().isInteger(8)) 152 if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) && 153 !isa<arith::ExtUIOp>(*readOp->user_begin()))) 154 return false; 155 156 AffineMap map = readOp.getPermutationMap(); 157 MLIRContext *ctx = readOp.getContext(); 158 AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx); 159 AffineExpr zero = getAffineConstantExpr(0, ctx); 160 auto broadcastInnerDim = 161 AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx); 162 return map.isMinorIdentity() || map == broadcastInnerDim || 163 isTransposeMatrixLoadMap(map); 164 } 165 166 // Return true if the transfer op can be converted to a MMA matrix store. 167 static bool 168 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) { 169 // TODO: support 0-d corner case. 170 if (writeOp.getTransferRank() == 0) 171 return false; 172 173 if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() || 174 writeOp.getVectorType().getRank() != 2) 175 return false; 176 if (!getStaticallyKnownRowStride(writeOp.getShapedType())) 177 return false; 178 // TODO: Support transpose once it is added to GPU dialect ops. 179 if (!writeOp.getPermutationMap().isMinorIdentity()) 180 return false; 181 return true; 182 } 183 184 /// Return true if the constant is a splat to a 2D vector so that it can be 185 /// converted to a MMA constant matrix op. 186 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) { 187 auto vecType = dyn_cast<VectorType>(constantOp.getType()); 188 if (!vecType || vecType.getRank() != 2) 189 return false; 190 return isa<SplatElementsAttr>(constantOp.getValue()); 191 } 192 193 /// Return true if this is a broadcast from scalar to a 2D vector. 194 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) { 195 return broadcastOp.getResultVectorType().getRank() == 2; 196 } 197 198 /// Return true if this integer extend op can be folded into a contract op. 199 template <typename ExtOpTy> 200 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) { 201 if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp())) 202 return false; 203 return llvm::all_of(extOp->getUsers(), [](Operation *user) { 204 return isa<vector::ContractionOp>(user); 205 }); 206 } 207 208 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; } 209 210 /// Return the MMA elementwise enum associated with `op` if it is supported. 211 /// Return `std::nullopt` otherwise. 
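/// For example, arith.addf maps to gpu::MMAElementwiseOp::ADDF; an op that is
/// not listed below keeps the surrounding contraction chain from being
/// converted (see supportsMMaMatrixType).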
212 static std::optional<gpu::MMAElementwiseOp> 213 convertElementwiseOpToMMA(Operation *op) { 214 if (isa<arith::AddFOp>(op)) 215 return gpu::MMAElementwiseOp::ADDF; 216 if (isa<arith::MulFOp>(op)) 217 return gpu::MMAElementwiseOp::MULF; 218 if (isa<arith::SubFOp>(op)) 219 return gpu::MMAElementwiseOp::SUBF; 220 if (isa<arith::MaximumFOp>(op)) 221 return gpu::MMAElementwiseOp::MAXF; 222 if (isa<arith::MinimumFOp>(op)) 223 return gpu::MMAElementwiseOp::MINF; 224 if (isa<arith::DivFOp>(op)) 225 return gpu::MMAElementwiseOp::DIVF; 226 if (isa<arith::AddIOp>(op)) 227 return gpu::MMAElementwiseOp::ADDI; 228 if (isa<arith::MulIOp>(op)) 229 return gpu::MMAElementwiseOp::MULI; 230 if (isa<arith::SubIOp>(op)) 231 return gpu::MMAElementwiseOp::SUBI; 232 if (isa<arith::DivSIOp>(op)) 233 return gpu::MMAElementwiseOp::DIVS; 234 if (isa<arith::DivUIOp>(op)) 235 return gpu::MMAElementwiseOp::DIVU; 236 if (isa<arith::NegFOp>(op)) 237 return gpu::MMAElementwiseOp::NEGATEF; 238 if (isa<arith::ExtFOp>(op)) 239 return gpu::MMAElementwiseOp::EXTF; 240 return std::nullopt; 241 } 242 243 /// Return true if the op is supported as elementwise op on MMAMatrix type. 244 static bool elementwiseSupportsMMAMatrixType(Operation *op) { 245 return convertElementwiseOpToMMA(op).has_value(); 246 } 247 248 /// Returns true if the extract strided slice op is supported with `mma.sync` 249 /// path. 250 static bool 251 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) { 252 253 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 254 nvgpu::getWarpMatrixInfo(op); 255 if (failed(warpMatrixInfo)) 256 return false; 257 258 FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op); 259 if (failed(contractOp)) 260 return false; 261 262 // Handle vector.extract_strided_slice on registers containing 263 // matrixB and matrixC operands. vector.extract_strided_slice op 264 // is not supported on registers containing matrixA operands. 265 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B) 266 return (cast<VectorType>(op->getResult(0).getType()) == 267 cast<VectorType>((*contractOp).getRhs().getType())); 268 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C) 269 return (cast<VectorType>(op->getResult(0).getType()) == 270 cast<VectorType>((*contractOp).getAcc().getType())); 271 272 return false; 273 } 274 275 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) { 276 if (isa<scf::ForOp, scf::YieldOp>(op)) 277 return true; 278 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) 279 return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead) 280 : transferReadSupportsMMAMatrixType(transferRead); 281 if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) 282 return useNvGpu ? 
nvgpu::canLowerToWarpMatrixOperation(transferWrite) 283 : transferWriteSupportsMMAMatrixType(transferWrite); 284 if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op)) 285 return useNvGpu && 286 extractStridedSliceSupportsMMAMatrixType(extractStridedSlice); 287 if (auto contract = dyn_cast<vector::ContractionOp>(op)) 288 return contractSupportsMMAMatrixType(contract, useNvGpu); 289 if (auto constant = dyn_cast<arith::ConstantOp>(op)) 290 return constantSupportsMMAMatrixType(constant); 291 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) 292 return broadcastSupportsMMAMatrixType(broadcast); 293 if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op)) 294 return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend); 295 if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op)) 296 return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend); 297 if (auto fpExtend = dyn_cast<arith::ExtFOp>(op)) 298 return fpExtendSupportsMMAMatrixType(fpExtend); 299 return elementwiseSupportsMMAMatrixType(op); 300 } 301 302 /// Return an unsorted slice handling scf.for region differently than 303 /// `getSlice`. In scf.for we only want to include as part of the slice elements 304 /// that are part of the use/def chain. 305 static SetVector<Operation *> 306 getSliceContract(Operation *op, BackwardSliceOptions backwardSliceOptions, 307 ForwardSliceOptions forwardSliceOptions) { 308 SetVector<Operation *> slice; 309 slice.insert(op); 310 unsigned currentIndex = 0; 311 SetVector<Operation *> backwardSlice; 312 SetVector<Operation *> forwardSlice; 313 while (currentIndex != slice.size()) { 314 auto *currentOp = (slice)[currentIndex]; 315 // Compute and insert the backwardSlice starting from currentOp. 316 backwardSlice.clear(); 317 getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); 318 slice.insert(backwardSlice.begin(), backwardSlice.end()); 319 320 // Compute and insert the forwardSlice starting from currentOp. 321 forwardSlice.clear(); 322 // Special case for ForOp, we don't want to include the whole region but 323 // only the value using the region arguments. 324 // TODO: We should refine this to only care about the region arguments being 325 // converted to matrix type. 326 if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) { 327 for (Value forOpResult : forOp.getResults()) 328 getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions); 329 for (BlockArgument &arg : forOp.getRegionIterArgs()) 330 getForwardSlice(arg, &forwardSlice, forwardSliceOptions); 331 } else { 332 getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions); 333 } 334 slice.insert(forwardSlice.begin(), forwardSlice.end()); 335 ++currentIndex; 336 } 337 return slice; 338 } 339 340 // Analyze slice of operations based on convert op to figure out if the whole 341 // slice can be converted to MMA operations. 
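// The returned set is topologically sorted so that, during conversion,
// producers are visited before their consumers and the value mapping is
// always populated before it is queried.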
342 static SetVector<Operation *> getOpToConvert(mlir::Operation *op, 343 bool useNvGpu) { 344 auto hasVectorDest = [](Operation *op) { 345 return llvm::any_of(op->getResultTypes(), 346 [](Type t) { return isa<VectorType>(t); }); 347 }; 348 BackwardSliceOptions backwardSliceOptions; 349 backwardSliceOptions.filter = hasVectorDest; 350 351 auto hasVectorSrc = [](Operation *op) { 352 return llvm::any_of(op->getOperandTypes(), 353 [](Type t) { return isa<VectorType>(t); }); 354 }; 355 ForwardSliceOptions forwardSliceOptions; 356 forwardSliceOptions.filter = hasVectorSrc; 357 358 SetVector<Operation *> opToConvert; 359 op->walk([&](vector::ContractionOp contract) { 360 if (opToConvert.contains(contract.getOperation())) 361 return; 362 SetVector<Operation *> dependentOps = 363 getSliceContract(contract, backwardSliceOptions, forwardSliceOptions); 364 // If any instruction cannot use MMA matrix type drop the whole 365 // chain. MMA matrix are stored in an opaque type so they cannot be used 366 // by all operations. 367 if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) { 368 if (!supportsMMaMatrixType(op, useNvGpu)) { 369 LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n"); 370 return true; 371 } 372 return false; 373 })) 374 return; 375 376 opToConvert.insert(dependentOps.begin(), dependentOps.end()); 377 }); 378 // Sort the operations so that we can convert them in topological order. 379 return topologicalSort(opToConvert); 380 } 381 382 namespace { 383 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted 384 // to MMA matmul. 385 struct PrepareContractToGPUMMA 386 : public OpRewritePattern<vector::ContractionOp> { 387 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; 388 389 LogicalResult matchAndRewrite(vector::ContractionOp op, 390 PatternRewriter &rewriter) const override { 391 Location loc = op.getLoc(); 392 Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc(); 393 394 // Set up the parallel/reduction structure in right form. 395 using MapList = ArrayRef<ArrayRef<AffineExpr>>; 396 auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; 397 AffineExpr m, n, k; 398 bindDims(rewriter.getContext(), m, n, k); 399 static constexpr std::array<int64_t, 2> perm = {1, 0}; 400 auto iteratorTypes = op.getIteratorTypes().getValue(); 401 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray(); 402 if (!(vector::isParallelIterator(iteratorTypes[0]) && 403 vector::isParallelIterator(iteratorTypes[1]) && 404 vector::isReductionIterator(iteratorTypes[2]))) 405 return rewriter.notifyMatchFailure(op, "not a gemm contraction"); 406 // 407 // Two outer parallel, one inner reduction (matmat flavor). 408 // 409 // This is the classical row-major matmul, nothing to do. 
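    // The branches below transpose and/or swap the operands to reach that
    // form. Illustrative example: for maps {(k, m), (k, n), (m, n)}, i.e. A
    // transposed, the rewrite emits roughly
    //   %lhsT = vector.transpose %lhs, [1, 0]
    //   vector.contract %lhsT, %rhs, %acc  // maps {(m, k), (k, n), (m, n)}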
410 if (maps == infer({{m, k}, {k, n}, {m, n}})) 411 return rewriter.notifyMatchFailure(op, "contraction already prepared"); 412 if (maps == infer({{m, k}, {n, k}, {m, n}})) { 413 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 414 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { 415 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 416 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { 417 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 418 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 419 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { 420 std::swap(rhs, lhs); 421 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 422 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 423 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { 424 std::swap(rhs, lhs); 425 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 426 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { 427 std::swap(lhs, rhs); 428 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 429 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { 430 std::swap(lhs, rhs); 431 } else { 432 // TODO: llvm_unreachable ? 433 return rewriter.notifyMatchFailure(op, "unexpected contraction case"); 434 } 435 rewriter.replaceOpWithNewOp<vector::ContractionOp>( 436 op, lhs, rhs, res, 437 rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})), 438 op.getIteratorTypes()); 439 return success(); 440 } 441 }; 442 443 // Fold transpose op into the transfer read op. Nvgpu mma.sync op only supports 444 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC, 445 // respectively. We can fold the transpose operation when loading the data from 446 // Shared Memory to registers. 447 struct CombineTransferReadOpTranspose final 448 : public OpRewritePattern<vector::TransposeOp> { 449 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 450 451 LogicalResult matchAndRewrite(vector::TransposeOp op, 452 PatternRewriter &rewriter) const override { 453 // Look through integer extend ops. 454 Value source = op.getVector(); 455 Type resultType = op.getType(); 456 Operation *extOp; 457 if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) || 458 (extOp = source.getDefiningOp<arith::ExtUIOp>())) { 459 source = extOp->getOperand(0); 460 resultType = 461 VectorType::get(cast<VectorType>(resultType).getShape(), 462 cast<VectorType>(source.getType()).getElementType()); 463 } 464 465 auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>(); 466 if (!transferReadOp) 467 return rewriter.notifyMatchFailure(op, "no transfer read"); 468 469 // TODO: support 0-d corner case. 470 if (transferReadOp.getTransferRank() == 0) 471 return rewriter.notifyMatchFailure(op, "0-D transfer read"); 472 473 if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim()) 474 return rewriter.notifyMatchFailure(op, "not inbounds transfer read"); 475 476 AffineMap permutationMap = 477 AffineMap::getPermutationMap(op.getPermutation(), op.getContext()); 478 AffineMap newMap = 479 permutationMap.compose(transferReadOp.getPermutationMap()); 480 481 auto loc = op.getLoc(); 482 Value result = 483 rewriter 484 .create<vector::TransferReadOp>( 485 loc, resultType, transferReadOp.getSource(), 486 transferReadOp.getIndices(), AffineMapAttr::get(newMap), 487 transferReadOp.getPadding(), transferReadOp.getMask(), 488 transferReadOp.getInBoundsAttr()) 489 .getResult(); 490 491 // Fuse through the integer extend op. 
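    // i.e. transpose(extsi(transfer_read)) becomes extsi(transfer_read with
    // the composed permutation map), so the narrow element type is what gets
    // loaded and the extension is re-applied on top of the new read.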
492 if (extOp) { 493 if (isa<arith::ExtSIOp>(extOp)) 494 result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result) 495 .getResult(); 496 else 497 result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result) 498 .getResult(); 499 } 500 501 rewriter.replaceOp(op, result); 502 return success(); 503 } 504 }; 505 506 } // namespace 507 508 // MMA types have different layout based on how they are used in matmul ops. 509 // Figure the right layout to use by looking at op uses. 510 // TODO: Change the GPU dialect to abstract the layout at the this level and 511 // only care about it during lowering to NVVM. 512 static const char *inferFragType(Operation *op) { 513 for (Operation *users : op->getUsers()) { 514 auto contract = dyn_cast<vector::ContractionOp>(users); 515 if (!contract) 516 continue; 517 assert(op->getNumResults() == 1); 518 if (contract.getLhs() == op->getResult(0)) 519 return "AOp"; 520 if (contract.getRhs() == op->getResult(0)) 521 return "BOp"; 522 } 523 return "COp"; 524 } 525 526 static LogicalResult 527 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, 528 llvm::DenseMap<Value, Value> &valueMapping) { 529 OpBuilder::InsertionGuard g(rewriter); 530 rewriter.setInsertionPoint(op); 531 532 assert(op.getTransferRank() > 0 && "unexpected 0-d transfer"); 533 assert(transferReadSupportsMMAMatrixType(op) && 534 "expected convertible operation"); 535 536 std::optional<int64_t> stride = 537 getStaticallyKnownRowStride(op.getShapedType()); 538 if (!stride.has_value()) { 539 LLVM_DEBUG(DBGS() << "no stride\n"); 540 return rewriter.notifyMatchFailure(op, "no stride"); 541 } 542 543 AffineMap map = op.getPermutationMap(); 544 bool isTranspose = isTransposeMatrixLoadMap(map); 545 546 // Handle broadcast by setting the stride to 0. 547 if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) { 548 assert(cstExpr.getValue() == 0); 549 stride = 0; 550 } 551 552 Value mappingResult = op.getResult(); 553 auto elType = op.getVectorType().getElementType(); 554 const char *fragType = inferFragType(op); 555 if (op->hasOneUse()) { 556 auto user = *op->user_begin(); 557 // Infer the signedness of the mma type from the integer extend. 558 bool isSignedExtend = isa<arith::ExtSIOp>(user); 559 if (isSignedExtend || isa<arith::ExtUIOp>(user)) { 560 elType = IntegerType::get( 561 op.getContext(), cast<IntegerType>(elType).getWidth(), 562 isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned); 563 mappingResult = user->getResult(0); 564 fragType = inferFragType(user); 565 } 566 } 567 gpu::MMAMatrixType type = 568 gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType); 569 Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>( 570 op.getLoc(), type, op.getSource(), op.getIndices(), 571 rewriter.getIndexAttr(*stride), 572 isTranspose ? 
rewriter.getUnitAttr() : UnitAttr()); 573 valueMapping[mappingResult] = load; 574 575 LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n"); 576 return success(); 577 } 578 579 static LogicalResult 580 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, 581 llvm::DenseMap<Value, Value> &valueMapping) { 582 OpBuilder::InsertionGuard g(rewriter); 583 rewriter.setInsertionPoint(op); 584 585 assert(transferWriteSupportsMMAMatrixType(op)); 586 std::optional<int64_t> stride = 587 getStaticallyKnownRowStride(op.getShapedType()); 588 if (!stride.has_value()) { 589 LLVM_DEBUG(DBGS() << "no stride\n"); 590 return rewriter.notifyMatchFailure(op, "no stride"); 591 } 592 593 auto it = valueMapping.find(op.getVector()); 594 if (it == valueMapping.end()) { 595 LLVM_DEBUG(DBGS() << "no mapping\n"); 596 return rewriter.notifyMatchFailure(op, "no mapping"); 597 } 598 599 Value matrix = it->second; 600 auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>( 601 op.getLoc(), matrix, op.getSource(), op.getIndices(), 602 rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr()); 603 (void)store; 604 605 LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n"); 606 607 LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); 608 rewriter.eraseOp(op); 609 return success(); 610 } 611 612 /// Returns the vector type which represents a matrix fragment. 613 static VectorType 614 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info) { 615 SmallVector<int64_t> shape{regInfo.numRegistersPerFragment, 616 regInfo.elementsPerRegister}; 617 Type elType = regInfo.registerLLVMType; 618 if (auto vecType = dyn_cast<VectorType>(elType)) 619 elType = vecType.getElementType(); 620 return VectorType::get(shape, elType); 621 } 622 623 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. 624 static LogicalResult 625 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, 626 llvm::DenseMap<Value, Value> &valueMapping) { 627 OpBuilder::InsertionGuard g(rewriter); 628 rewriter.setInsertionPoint(op); 629 630 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 631 nvgpu::getWarpMatrixInfo(op); 632 if (failed(warpMatrixInfo)) { 633 LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); 634 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 635 } 636 637 FailureOr<nvgpu::FragmentElementInfo> regInfo = 638 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 639 if (failed(regInfo)) { 640 LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); 641 return rewriter.notifyMatchFailure(op, "not mma sync reg info"); 642 } 643 644 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 645 auto dense = dyn_cast<SplatElementsAttr>(op.getValue()); 646 if (!dense) { 647 LLVM_DEBUG(DBGS() << "not a splat\n"); 648 return rewriter.notifyMatchFailure(op, "not a splat"); 649 } 650 651 Value result = rewriter.create<arith::ConstantOp>( 652 op.getLoc(), vectorType, 653 DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>())); 654 valueMapping[op.getResult()] = result; 655 return success(); 656 } 657 658 /// Check if the loaded matrix operand requires transposed. 659 /// Transposed Map Example: 660 /// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2) 661 /// Example 2 : (d0, d1, d2, d3) -> (d3, d2) 662 /// The code below checks if the output 2D is transposed using a generalized 663 /// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn) 664 /// Returns : true; if m > n, false o.w. 
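/// For instance, affine_map<(d0, d1, d2, d3) -> (d3, d2)> is reported as
/// transposed (position 3 > position 2), while the minor identity
/// affine_map<(d0, d1) -> (d0, d1)> is not.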
665 static FailureOr<bool> isTransposed(vector::TransferReadOp op) { 666 mlir::AffineMap map = op.getPermutationMap(); 667 668 if (map.getNumResults() != 2) { 669 LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` " 670 "is not a 2d operand\n"); 671 return failure(); 672 } 673 674 // Output 2D matrix dimensions in the order of d0, d1. 675 mlir::AffineExpr dM = map.getResult(0); 676 mlir::AffineExpr dN = map.getResult(1); 677 678 // Find the position of these expressions in the input. 679 auto exprM = dyn_cast<AffineDimExpr>(dM); 680 auto exprN = dyn_cast<AffineDimExpr>(dN); 681 682 if (!exprM || !exprN) { 683 LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim " 684 "expressions, then transpose cannot be determined.\n"); 685 return failure(); 686 } 687 688 return exprM.getPosition() > exprN.getPosition(); 689 } 690 691 static LogicalResult 692 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, 693 llvm::DenseMap<Value, Value> &valueMapping) { 694 OpBuilder::InsertionGuard g(rewriter); 695 rewriter.setInsertionPoint(op); 696 Location loc = op->getLoc(); 697 698 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 699 nvgpu::getWarpMatrixInfo(op); 700 if (failed(warpMatrixInfo)) { 701 LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); 702 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 703 } 704 705 FailureOr<nvgpu::FragmentElementInfo> regInfo = 706 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 707 if (failed(regInfo)) { 708 LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); 709 return rewriter.notifyMatchFailure(op, "not mma sync reg info"); 710 } 711 712 FailureOr<bool> transpose = isTransposed(op); 713 if (failed(transpose)) { 714 LLVM_DEBUG(DBGS() << "failed to determine the transpose\n"); 715 return rewriter.notifyMatchFailure( 716 op, "Op should likely not be converted to a nvgpu.ldmatrix call."); 717 } 718 719 FailureOr<nvgpu::LdMatrixParams> params = 720 nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose); 721 722 if (failed(params)) { 723 LLVM_DEBUG( 724 DBGS() 725 << "failed to convert vector.transfer_read to ldmatrix. " 726 << "Op should likely not be converted to a nvgpu.ldmatrix call.\n"); 727 return rewriter.notifyMatchFailure( 728 op, "failed to convert vector.transfer_read to ldmatrix; this op " 729 "likely should not be converted to a nvgpu.ldmatrix call."); 730 } 731 732 // Adjust the load offset. 
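  // getLaneIdToLdMatrixMatrixCoord maps the lane id to the matrix coordinates
  // that lane should read from; getXferIndices below adds these offsets to the
  // original transfer indices.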
733 auto laneId = rewriter.create<gpu::LaneIdOp>(loc); 734 FailureOr<AffineMap> offsets = 735 nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params); 736 if (failed(offsets)) { 737 LLVM_DEBUG(DBGS() << "no offsets\n"); 738 return rewriter.notifyMatchFailure(op, "no offsets"); 739 } 740 741 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 742 743 SmallVector<Value, 4> indices; 744 getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId}, 745 indices); 746 747 nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>( 748 loc, vectorType, op.getSource(), indices, *transpose, params->numTiles); 749 valueMapping[op] = newOp->getResult(0); 750 return success(); 751 } 752 753 static LogicalResult 754 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, 755 llvm::DenseMap<Value, Value> &valueMapping) { 756 OpBuilder::InsertionGuard g(rewriter); 757 rewriter.setInsertionPoint(op); 758 759 Location loc = op.getLoc(); 760 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 761 nvgpu::getWarpMatrixInfo(op); 762 if (failed(warpMatrixInfo)) 763 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 764 FailureOr<nvgpu::FragmentElementInfo> regInfo = 765 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 766 if (failed(regInfo)) { 767 return rewriter.notifyMatchFailure( 768 op, "Failed to deduce register fragment type during " 769 "conversion to distributed non-ldmatrix compatible load"); 770 } 771 772 Value laneId = rewriter.create<gpu::LaneIdOp>(loc); 773 SmallVector<Value, 4> elements; 774 775 // This is the individual element type. 776 Type loadedElType = regInfo->registerLLVMType; 777 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 778 779 Value fill = rewriter.create<arith::ConstantOp>( 780 op.getLoc(), vectorType.getElementType(), 781 rewriter.getZeroAttr(vectorType.getElementType())); 782 Value result = 783 rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType); 784 785 bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity(); 786 787 // If we are not transposing, then we can use vectorized loads. Otherwise, we 788 // must load each element individually. 
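  // Concretely, the non-transposed path issues one vector.load per register of
  // the fragment, while the transposed path falls back to scalar memref.load
  // ops, one per element.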
789 if (!isTransposeLoad) { 790 if (!isa<VectorType>(loadedElType)) { 791 loadedElType = VectorType::get({1}, loadedElType); 792 } 793 794 for (int i = 0; i < vectorType.getShape()[0]; i++) { 795 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( 796 rewriter, op.getLoc(), *warpMatrixInfo); 797 if (failed(coords)) 798 return rewriter.notifyMatchFailure(op, "no coords"); 799 800 Value logicalValueId = rewriter.create<arith::ConstantOp>( 801 loc, rewriter.getIndexType(), 802 rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); 803 SmallVector<Value, 4> newIndices; 804 getXferIndices<vector::TransferReadOp>( 805 rewriter, op, *coords, {laneId, logicalValueId}, newIndices); 806 807 Value el = rewriter.create<vector::LoadOp>(loc, loadedElType, 808 op.getSource(), newIndices); 809 result = rewriter.create<vector::InsertOp>(loc, el, result, i); 810 } 811 } else { 812 if (auto vecType = dyn_cast<VectorType>(loadedElType)) { 813 loadedElType = vecType.getElementType(); 814 } 815 for (int i = 0; i < vectorType.getShape()[0]; i++) { 816 for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1]; 817 innerIdx++) { 818 819 Value logicalValueId = rewriter.create<arith::ConstantOp>( 820 loc, rewriter.getIndexType(), 821 rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx)); 822 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( 823 rewriter, op.getLoc(), *warpMatrixInfo); 824 if (failed(coords)) 825 return rewriter.notifyMatchFailure(op, "no coords"); 826 827 SmallVector<Value, 4> newIndices; 828 getXferIndices<vector::TransferReadOp>( 829 rewriter, op, *coords, {laneId, logicalValueId}, newIndices); 830 Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType, 831 op.getSource(), newIndices); 832 result = rewriter.create<vector::InsertOp>( 833 op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx}); 834 } 835 } 836 } 837 838 valueMapping[op.getResult()] = result; 839 return success(); 840 } 841 842 /// Return true if this is a shared memory memref type. 843 static bool isSharedMemory(MemRefType type) { 844 auto addressSpace = 845 dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace()); 846 if (addressSpace && 847 addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace()) 848 return true; 849 return false; 850 } 851 852 /// Converts a `vector.transfer_read` operation directly to either a 853 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be 854 /// used when converting to `nvgpu.mma.sync` operations. 855 static LogicalResult 856 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, 857 llvm::DenseMap<Value, Value> &valueMapping) { 858 OpBuilder::InsertionGuard g(rewriter); 859 rewriter.setInsertionPoint(op); 860 861 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 862 nvgpu::getWarpMatrixInfo(op); 863 if (failed(warpMatrixInfo)) 864 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 865 866 bool isLdMatrixCompatible = 867 isSharedMemory(cast<MemRefType>(op.getSource().getType())) && 868 nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128; 869 870 VectorType vecTy = op.getVectorType(); 871 int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth(); 872 873 // When we are transposing the B operand, ldmatrix will only work if we have 874 // at least 8 rows to read and the width to read for the transpose is 128 875 // bits. 
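  // For example, a transposed f16 read of a 16x8 tile satisfies both
  // constraints, whereas an 8-bit element type, or a tile with fewer than 8
  // elements in the inner dimension, falls back to createNonLdMatrixLoads.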
876 if (!op.getPermutationMap().isMinorIdentity() && 877 (bitWidth != 16 || vecTy.getDimSize(1) < 8 || 878 vecTy.getDimSize(0) * bitWidth < 128)) 879 isLdMatrixCompatible = false; 880 881 if (!isLdMatrixCompatible) 882 return createNonLdMatrixLoads(rewriter, op, valueMapping); 883 884 return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping); 885 } 886 887 static LogicalResult 888 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, 889 llvm::DenseMap<Value, Value> &valueMapping) { 890 OpBuilder::InsertionGuard g(rewriter); 891 rewriter.setInsertionPoint(op); 892 893 Location loc = op->getLoc(); 894 auto it = valueMapping.find(op.getVector()); 895 if (it == valueMapping.end()) 896 return rewriter.notifyMatchFailure(op, "no mapping"); 897 Value matrix = it->second; 898 899 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 900 nvgpu::getWarpMatrixInfo(op); 901 if (failed(warpMatrixInfo)) 902 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 903 FailureOr<nvgpu::FragmentElementInfo> regInfo = 904 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 905 if (failed(regInfo)) 906 return rewriter.notifyMatchFailure(op, "not mma sync reg info"); 907 908 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 909 Value laneId = rewriter.create<gpu::LaneIdOp>(loc); 910 911 for (unsigned i = 0; i < vectorType.getShape()[0]; i++) { 912 Value logicalValueId = rewriter.create<arith::ConstantOp>( 913 loc, rewriter.getIndexType(), 914 rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); 915 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( 916 rewriter, op.getLoc(), *warpMatrixInfo); 917 if (failed(coords)) 918 return rewriter.notifyMatchFailure(op, "no coords"); 919 920 Value el = 921 rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i}); 922 SmallVector<Value, 4> newIndices; 923 getXferIndices<vector::TransferWriteOp>( 924 rewriter, op, *coords, {laneId, logicalValueId}, newIndices); 925 rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices); 926 } 927 928 LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); 929 rewriter.eraseOp(op); 930 return success(); 931 } 932 933 static void populateFromInt64AttrArray(ArrayAttr arrayAttr, 934 SmallVectorImpl<int64_t> &results) { 935 for (auto attr : arrayAttr) 936 results.push_back(cast<IntegerAttr>(attr).getInt()); 937 } 938 939 static LogicalResult 940 convertExtractStridedSlice(RewriterBase &rewriter, 941 vector::ExtractStridedSliceOp op, 942 llvm::DenseMap<Value, Value> &valueMapping) { 943 OpBuilder::InsertionGuard g(rewriter); 944 rewriter.setInsertionPoint(op); 945 946 Location loc = op->getLoc(); 947 948 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 949 nvgpu::getWarpMatrixInfo(op); 950 if (failed(warpMatrixInfo)) 951 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 952 953 FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo = 954 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 955 if (failed(mmaSyncFragmentInfo)) 956 return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo"); 957 958 // Find the vector.transer_read whose result vector is being sliced. 
959   auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
960   if (!transferReadOp)
961     return rewriter.notifyMatchFailure(op, "no transfer read");
962 
963   warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
964   if (failed(warpMatrixInfo))
965     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
966 
967   FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
968       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
969   if (failed(ldFragmentInfo))
970     return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
971 
972   assert(
973       (mmaSyncFragmentInfo->elementsPerRegister ==
974        ldFragmentInfo->elementsPerRegister) &&
975       "Number of elements per register should be same for load and mma.sync");
976 
977   // Create vector.extract_strided_slice op for thread-owned fragments.
978   std::array<int64_t, 2> strides = {1,
979                                     1}; // stride for extract slice is always 1.
980   std::array<int64_t, 2> sliceShape = {
981       mmaSyncFragmentInfo->numRegistersPerFragment,
982       mmaSyncFragmentInfo->elementsPerRegister};
983   auto it = valueMapping.find(transferReadOp);
984   if (it == valueMapping.end())
985     return rewriter.notifyMatchFailure(op, "no mapping");
986   auto sourceVector = it->second;
987 
988   // Offsets and sizes at the warp level of ownership.
989   SmallVector<int64_t> offsets;
990   populateFromInt64AttrArray(op.getOffsets(), offsets);
991 
992   SmallVector<int64_t> sizes;
993   populateFromInt64AttrArray(op.getSizes(), sizes);
994   ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
995 
996   // Compute the offset in vector registers. Note that the mma.sync vector
997   // registers are shaped as numberOfFragments x numberOfRegistersPerFragment;
998   // they can only be sliced along numberOfFragments, i.e., sliceOffset[0].
999   std::array<int64_t, 2> sliceOffset = {0, 0};
1000 
1001   if (offsets[0] && offsets[1])
1002     return op->emitError() << "Slicing fragments in 2D is not supported. 
"; 1003 if (offsets[0]) 1004 sliceOffset[0] = (warpVectorShape[0] / offsets[0]); 1005 else if (offsets[1]) 1006 sliceOffset[0] = (warpVectorShape[1] / offsets[1]); 1007 1008 Value newOp = rewriter.create<vector::ExtractStridedSliceOp>( 1009 loc, sourceVector, sliceOffset, sliceShape, strides); 1010 1011 valueMapping[op] = newOp; 1012 return success(); 1013 } 1014 1015 static LogicalResult 1016 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, 1017 llvm::DenseMap<Value, Value> &valueMapping) { 1018 OpBuilder::InsertionGuard g(rewriter); 1019 rewriter.setInsertionPoint(op); 1020 1021 auto itA = valueMapping.find(op.getLhs()); 1022 auto itB = valueMapping.find(op.getRhs()); 1023 auto itC = valueMapping.find(op.getAcc()); 1024 if (itA == valueMapping.end() || itB == valueMapping.end() || 1025 itC == valueMapping.end()) 1026 return rewriter.notifyMatchFailure(op, "no mapping"); 1027 Value opA = itA->second, opB = itB->second, opC = itC->second; 1028 Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>( 1029 op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(), 1030 /*b_transpose=*/UnitAttr()); 1031 valueMapping[op.getResult()] = matmul; 1032 return success(); 1033 } 1034 1035 static LogicalResult 1036 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, 1037 llvm::DenseMap<Value, Value> &valueMapping) { 1038 OpBuilder::InsertionGuard g(rewriter); 1039 rewriter.setInsertionPoint(op); 1040 1041 auto itA = valueMapping.find(op.getLhs()); 1042 auto itB = valueMapping.find(op.getRhs()); 1043 auto itC = valueMapping.find(op.getAcc()); 1044 if (itA == valueMapping.end() || itB == valueMapping.end() || 1045 itC == valueMapping.end()) 1046 return rewriter.notifyMatchFailure(op, "no mapping"); 1047 Value opA = itA->second, opB = itB->second, opC = itC->second; 1048 int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0]; 1049 int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0]; 1050 int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1]; 1051 Value matmul = rewriter.create<nvgpu::MmaSyncOp>( 1052 op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k})); 1053 valueMapping[op.getResult()] = matmul; 1054 return success(); 1055 } 1056 1057 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. 1058 static LogicalResult 1059 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, 1060 llvm::DenseMap<Value, Value> &valueMapping) { 1061 OpBuilder::InsertionGuard g(rewriter); 1062 rewriter.setInsertionPoint(op); 1063 1064 assert(constantSupportsMMAMatrixType(op)); 1065 1066 auto splat = 1067 cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>(); 1068 auto scalarConstant = 1069 rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat); 1070 const char *fragType = inferFragType(op); 1071 auto vecType = cast<VectorType>(op.getType()); 1072 gpu::MMAMatrixType type = gpu::MMAMatrixType::get( 1073 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); 1074 auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>( 1075 op.getLoc(), type, scalarConstant); 1076 valueMapping[op.getResult()] = matrix; 1077 return success(); 1078 } 1079 1080 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op. 
1081 static LogicalResult 1082 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, 1083 llvm::DenseMap<Value, Value> &valueMapping) { 1084 OpBuilder::InsertionGuard g(rewriter); 1085 rewriter.setInsertionPoint(op); 1086 1087 assert(broadcastSupportsMMAMatrixType(op)); 1088 1089 const char *fragType = inferFragType(op); 1090 auto vecType = op.getResultVectorType(); 1091 gpu::MMAMatrixType type = gpu::MMAMatrixType::get( 1092 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); 1093 auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>( 1094 op.getLoc(), type, op.getSource()); 1095 valueMapping[op.getResult()] = matrix; 1096 return success(); 1097 } 1098 1099 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not 1100 // updated and needs to be updated separately for the loop to be correct. 1101 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, 1102 scf::ForOp loop, 1103 ValueRange newInitArgs) { 1104 OpBuilder::InsertionGuard g(rewriter); 1105 rewriter.setInsertionPoint(loop); 1106 1107 // Create a new loop before the existing one, with the extra operands. 1108 rewriter.setInsertionPoint(loop); 1109 auto operands = llvm::to_vector<4>(loop.getInitArgs()); 1110 llvm::append_range(operands, newInitArgs); 1111 scf::ForOp newLoop = rewriter.create<scf::ForOp>( 1112 loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), 1113 operands); 1114 newLoop.getBody()->erase(); 1115 1116 newLoop.getRegion().getBlocks().splice( 1117 newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); 1118 for (Value operand : newInitArgs) 1119 newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); 1120 1121 for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( 1122 loop.getNumResults()))) 1123 rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); 1124 1125 LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n"); 1126 LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n"); 1127 LLVM_DEBUG(DBGS() << "erase: " << loop); 1128 1129 rewriter.eraseOp(loop); 1130 return newLoop; 1131 } 1132 1133 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, 1134 llvm::DenseMap<Value, Value> &valueMapping) { 1135 OpBuilder::InsertionGuard g(rewriter); 1136 rewriter.setInsertionPoint(op); 1137 1138 SmallVector<Value> newOperands; 1139 SmallVector<std::pair<size_t, size_t>> argMapping; 1140 for (const auto &operand : llvm::enumerate(op.getInitArgs())) { 1141 auto it = valueMapping.find(operand.value()); 1142 if (it == valueMapping.end()) { 1143 LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n"); 1144 continue; 1145 } 1146 argMapping.push_back(std::make_pair( 1147 operand.index(), op.getInitArgs().size() + newOperands.size())); 1148 newOperands.push_back(it->second); 1149 } 1150 1151 scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands); 1152 Block &loopBody = *newForOp.getBody(); 1153 for (auto mapping : argMapping) { 1154 valueMapping[newForOp.getResult(mapping.first)] = 1155 newForOp.getResult(mapping.second); 1156 valueMapping[loopBody.getArgument(mapping.first + 1157 newForOp.getNumInductionVars())] = 1158 loopBody.getArgument(mapping.second + newForOp.getNumInductionVars()); 1159 } 1160 1161 LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n"); 1162 return success(); 1163 } 1164 1165 static LogicalResult 1166 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, 1167 llvm::DenseMap<Value, 
Value> &valueMapping) { 1168 OpBuilder::InsertionGuard g(rewriter); 1169 rewriter.setInsertionPoint(op); 1170 1171 auto loop = cast<scf::ForOp>(op->getParentOp()); 1172 auto yieldOperands = llvm::to_vector<4>(op.getOperands()); 1173 for (const auto &operand : llvm::enumerate(op.getOperands())) { 1174 auto it = valueMapping.find(operand.value()); 1175 if (it == valueMapping.end()) 1176 continue; 1177 // Replace the yield of old value with the for op argument to make it easier 1178 // to remove the dead code. 1179 yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()]; 1180 yieldOperands.push_back(it->second); 1181 } 1182 rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands); 1183 1184 LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); 1185 rewriter.eraseOp(op); 1186 return success(); 1187 } 1188 1189 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix. 1190 static LogicalResult 1191 convertElementwiseOp(RewriterBase &rewriter, Operation *op, 1192 gpu::MMAElementwiseOp opType, 1193 llvm::DenseMap<Value, Value> &valueMapping) { 1194 OpBuilder::InsertionGuard g(rewriter); 1195 rewriter.setInsertionPoint(op); 1196 1197 SmallVector<Value> matrixOperands; 1198 for (Value operand : op->getOperands()) { 1199 auto it = valueMapping.find(operand); 1200 if (it == valueMapping.end()) 1201 return rewriter.notifyMatchFailure(op, "no mapping"); 1202 matrixOperands.push_back(it->second); 1203 } 1204 auto resultType = matrixOperands[0].getType().cast<gpu::MMAMatrixType>(); 1205 if (opType == gpu::MMAElementwiseOp::EXTF) { 1206 // The floating point extension case has a different result type. 1207 auto vectorType = op->getResultTypes()[0].cast<VectorType>(); 1208 resultType = gpu::MMAMatrixType::get(resultType.getShape(), 1209 vectorType.getElementType(), 1210 resultType.getOperand()); 1211 } 1212 1213 Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>( 1214 op->getLoc(), resultType, matrixOperands, opType); 1215 valueMapping[op->getResult(0)] = newOp; 1216 return success(); 1217 } 1218 1219 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, 1220 bool useNvGpu) { 1221 if (!useNvGpu) { 1222 patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>( 1223 patterns.getContext()); 1224 return; 1225 } 1226 vector::populateVectorContractCanonicalizeMatmulToMMT(patterns); 1227 patterns.add<CombineTransferReadOpTranspose>(patterns.getContext()); 1228 } 1229 1230 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter, 1231 Operation *rootOp) { 1232 SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false); 1233 llvm::DenseMap<Value, Value> valueMapping; 1234 1235 auto globalRes = LogicalResult::success(); 1236 for (Operation *op : ops) { 1237 LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n"); 1238 // Apparently callers do not want to early exit on failure here. 
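    // Instead, every op is converted independently; a failed conversion only
    // downgrades the aggregate result while successfully converted ops are
    // kept.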
1239 auto res = LogicalResult::success(); 1240 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) { 1241 res = convertTransferReadOp(rewriter, transferRead, valueMapping); 1242 } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) { 1243 res = convertTransferWriteOp(rewriter, transferWrite, valueMapping); 1244 } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) { 1245 res = convertContractOp(rewriter, contractOp, valueMapping); 1246 } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) { 1247 res = convertConstantOp(rewriter, constantOp, valueMapping); 1248 } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) { 1249 res = convertBroadcastOp(rewriter, broadcastOp, valueMapping); 1250 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) { 1251 res = convertForOp(rewriter, forOp, valueMapping); 1252 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) { 1253 res = convertYieldOp(rewriter, yieldOp, valueMapping); 1254 } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) { 1255 res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping); 1256 } 1257 if (failed(res)) 1258 globalRes = failure(); 1259 } 1260 return globalRes; 1261 } 1262 1263 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, 1264 Operation *rootOp) { 1265 SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true); 1266 llvm::DenseMap<Value, Value> valueMapping; 1267 for (Operation *op : ops) { 1268 if (llvm::TypeSwitch<Operation *, LogicalResult>(op) 1269 .Case([&](vector::TransferReadOp transferReadOp) { 1270 return convertTransferReadToLoads(rewriter, transferReadOp, 1271 valueMapping); 1272 }) 1273 .Case([&](vector::TransferWriteOp transferWriteOp) { 1274 return convertTransferWriteToStores(rewriter, transferWriteOp, 1275 valueMapping); 1276 }) 1277 .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) { 1278 return convertExtractStridedSlice(rewriter, extractStridedSliceOp, 1279 valueMapping); 1280 }) 1281 .Case([&](vector::ContractionOp contractionOp) { 1282 return convertContractOpToMmaSync(rewriter, contractionOp, 1283 valueMapping); 1284 }) 1285 .Case([&](scf::ForOp forOp) { 1286 return convertForOp(rewriter, forOp, valueMapping); 1287 }) 1288 .Case([&](scf::YieldOp yieldOp) { 1289 return convertYieldOp(rewriter, yieldOp, valueMapping); 1290 }) 1291 .Case([&](arith::ConstantOp constOp) { 1292 return convertConstantOpMmaSync(rewriter, constOp, valueMapping); 1293 }) 1294 .Default([&](Operation *op) { 1295 return op->emitError() << "unhandled vector to mma type: " << *op; 1296 }) 1297 .failed()) { 1298 return op->emitOpError() 1299 << "failed to convert op during vector-to-nvgpu conversion"; 1300 } 1301 } 1302 return success(); 1303 } 1304 1305 namespace { 1306 1307 struct ConvertVectorToGPUPass 1308 : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> { 1309 1310 explicit ConvertVectorToGPUPass(bool useNvGpu_) { 1311 useNvGpu.setValue(useNvGpu_); 1312 } 1313 1314 void runOnOperation() override { 1315 RewritePatternSet patterns(&getContext()); 1316 populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue()); 1317 if (failed( 1318 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) 1319 return signalPassFailure(); 1320 1321 IRRewriter rewriter(&getContext()); 1322 if (useNvGpu) { 1323 if (failed( 1324 convertVectorToNVVMCompatibleMMASync(rewriter, getOperation()))) 1325 return signalPassFailure(); 1326 return; 1327 } 1328 (void)convertVectorToMMAOps(rewriter, 
getOperation()); 1329 } 1330 }; 1331 1332 } // namespace 1333 1334 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) { 1335 return std::make_unique<ConvertVectorToGPUPass>(useNvGpu); 1336 } 1337