//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Region.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "vector-to-gpu"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap representing offsets to apply to indices, the function fills
/// `indices` with the original indices plus the offsets. The offsets are
/// applied by taking into account the permutation map of the transfer op. If
/// the `offsetMap` has dimension placeholders, those should be provided in
/// `dimValues`.
template <typename TransferOpType>
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
      Value prevIdx = indices[dim.getPosition()];
      SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
      dims.push_back(prevIdx);
      AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
      continue;
    }
  }
}

// Return true if the contract op can be converted to an MMA matmul.
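// With the WMMA path (useNvGpu == false) the indexing maps must already be in
// row-major (m, k) x (k, n) -> (m, n) form; with the mma.sync path the rhs is
// expected in transposed (n, k) form, matching what PrepareContractToGPUMMA
// and populateVectorContractCanonicalizeMatmulToMMT produce below.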
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                          bool useNvGpu) {
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
        vector::isParallelIterator(iteratorTypes[1]) &&
        vector::isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
  if (!useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
    return false;

  return true;
}

// Return true if the given map represents a transposed matrix load,
// i.e. (d0, d1, ...) -> (dn-1, dn-2).
static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
  MLIRContext *ctx = permutationMap.getContext();
  // Local OpBuilder is fine here, we just build attributes.
  OpBuilder b(ctx);
  auto nDim = permutationMap.getNumDims();
  AffineExpr zero = b.getAffineConstantExpr(0);
  if (nDim < 2) {
    // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
    AffineExpr dim0 = b.getAffineDimExpr(0);
    return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
  }

  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
  // Support both transposed and transposed+broadcasted cases.
  return permutationMap ==
             AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
         permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
}

// Return the stride for the second-to-last dimension of |type| if it is a
// memref and has a constant stride.
static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
  auto memrefType = dyn_cast<MemRefType>(type);
  if (!memrefType)
    return std::nullopt;
  // If the memref is 0 or 1D the horizontal stride is 0.
  if (memrefType.getRank() < 2)
    return 0;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
      strides.back() != 1)
    return std::nullopt;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamic)
    return std::nullopt;
  return stride;
}

// Return true if the transfer op can be converted to a MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
    return false;

  // Only allow integer types if the signedness can be inferred.
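  // An i8 load must have exactly one user, and that user must be an
  // arith.extsi or arith.extui, so the extension tells us whether the
  // fragment element type is signed or unsigned (see convertTransferReadOp).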
  if (readOp.getVectorType().getElementType().isInteger(8))
    if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
                                 !isa<arith::ExtUIOp>(*readOp->user_begin())))
      return false;

  AffineMap map = readOp.getPermutationMap();
  MLIRContext *ctx = readOp.getContext();
  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  auto broadcastInnerDim =
      AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
  return map.isMinorIdentity() || map == broadcastInnerDim ||
         isTransposeMatrixLoadMap(map);
}

// Return true if the transfer op can be converted to a MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  // TODO: support 0-d corner case.
  if (writeOp.getTransferRank() == 0)
    return false;

  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.getPermutationMap().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to a MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = dyn_cast<VectorType>(constantOp.getType());
  if (!vecType || vecType.getRank() != 2)
    return false;
  return isa<SplatElementsAttr>(constantOp.getValue());
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getResultVectorType().getRank() == 2;
}

/// Return true if this integer extend op can be folded into a contract op.
template <typename ExtOpTy>
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
  if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
    return false;
  return llvm::all_of(extOp->getUsers(), [](Operation *user) {
    return isa<vector::ContractionOp>(user);
  });
}

static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }

/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `std::nullopt` otherwise.
static std::optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::SubFOp>(op))
    return gpu::MMAElementwiseOp::SUBF;
  if (isa<arith::MaximumFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinimumFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  if (isa<arith::AddIOp>(op))
    return gpu::MMAElementwiseOp::ADDI;
  if (isa<arith::MulIOp>(op))
    return gpu::MMAElementwiseOp::MULI;
  if (isa<arith::SubIOp>(op))
    return gpu::MMAElementwiseOp::SUBI;
  if (isa<arith::DivSIOp>(op))
    return gpu::MMAElementwiseOp::DIVS;
  if (isa<arith::DivUIOp>(op))
    return gpu::MMAElementwiseOp::DIVU;
  if (isa<arith::NegFOp>(op))
    return gpu::MMAElementwiseOp::NEGATEF;
  if (isa<arith::ExtFOp>(op))
    return gpu::MMAElementwiseOp::EXTF;
  return std::nullopt;
}

/// Return true if the op is supported as elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {
  return convertElementwiseOpToMMA(op).has_value();
}

/// Returns true if the extract strided slice op is supported with the
/// `mma.sync` path.
static bool
extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return false;

  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
  if (failed(contractOp))
    return false;

  // Handle vector.extract_strided_slice on registers containing
  // matrixB and matrixC operands. vector.extract_strided_slice op
  // is not supported on registers containing matrixA operands.
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getRhs().getType()));
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getAcc().getType()));

  return false;
}

static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
                    : transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
                    : transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
    return useNvGpu &&
           extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
    return fpExtendSupportsMMAMatrixType(fpExtend);
  return elementwiseSupportsMMAMatrixType(op);
}

/// Return an unsorted slice handling scf.for regions differently than
/// `getSlice`. In scf.for we only want to include as part of the slice the
/// elements that are part of the use/def chain.
static SetVector<Operation *>
getSliceContract(Operation *op,
                 const BackwardSliceOptions &backwardSliceOptions,
                 const ForwardSliceOptions &forwardSliceOptions) {
  SetVector<Operation *> slice;
  slice.insert(op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    // Special case for ForOp: we don't want to include the whole region, but
    // only the values using the region arguments.
    // TODO: We should refine this to only care about the region arguments
    // being converted to matrix type.
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
      for (BlockArgument &arg : forOp.getRegionIterArgs())
        getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    ++currentIndex;
  }
  return slice;
}

// Analyze the slice of operations rooted at each contract op to figure out if
// the whole slice can be converted to MMA operations.
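// The slice is seeded at every vector.contract and grown backward through
// vector-producing ops and forward through vector-consuming ops; if any op in
// the slice is unsupported, the whole contraction is left untouched.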
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(),
                        [](Type t) { return isa<VectorType>(t); });
  };
  BackwardSliceOptions backwardSliceOptions;
  backwardSliceOptions.filter = hasVectorDest;

  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(),
                        [](Type t) { return isa<VectorType>(t); });
  };
  ForwardSliceOptions forwardSliceOptions;
  forwardSliceOptions.filter = hasVectorSrc;

  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type, so they cannot be
    // used by all operations.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          if (!supportsMMaMatrixType(op, useNvGpu)) {
            LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
            return true;
          }
          return false;
        }))
      return;

    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}

namespace {
// Transform a contraction into canonical (m, k) x (k, n) -> (m, n) form so
// that it can be converted to an MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (!(vector::isParallelIterator(iteratorTypes[0]) &&
          vector::isParallelIterator(iteratorTypes[1]) &&
          vector::isReductionIterator(iteratorTypes[2])))
      return rewriter.notifyMatchFailure(op, "not a gemm contraction");
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    // This is the classical row-major matmul, nothing to do.
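    // All other supported cases below swap and/or transpose lhs/rhs so that
    // the maps end up in this canonical form.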
    if (maps == infer({{m, k}, {k, n}, {m, n}}))
      return rewriter.notifyMatchFailure(op, "contraction already prepared");
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      // TODO: llvm_unreachable ?
      return rewriter.notifyMatchFailure(op, "unexpected contraction case");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.getIteratorTypes());
    return success();
  }
};

// Fold transpose op into the transfer read op. The NVGPU mma.sync op only
// supports row-, column-, and row-major layouts for matrixA, matrixB, and
// matrixC, respectively. We can fold the transpose operation when loading the
// data from shared memory to registers.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    // Look through integer extend ops.
    Value source = op.getVector();
    Type resultType = op.getType();
    Operation *extOp;
    if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
        (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
        (extOp = source.getDefiningOp<arith::ExtFOp>())) {
      source = extOp->getOperand(0);
      resultType =
          VectorType::get(cast<VectorType>(resultType).getShape(),
                          cast<VectorType>(source.getType()).getElementType());
    }

    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return rewriter.notifyMatchFailure(op, "no transfer read");

    // TODO: support 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-D transfer read");

    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(op, "not inbounds transfer read");

    AffineMap permutationMap =
        AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.getPermutationMap());

    auto loc = op.getLoc();
    Value result =
        rewriter
            .create<vector::TransferReadOp>(
                loc, resultType, transferReadOp.getSource(),
                transferReadOp.getIndices(), AffineMapAttr::get(newMap),
                transferReadOp.getPadding(), transferReadOp.getMask(),
                transferReadOp.getInBoundsAttr())
            .getResult();

    // Fuse through the integer extend op.
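    // The extend is re-created on top of the new transfer_read so users of the
    // original transpose still see the extended element type.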
    if (extOp) {
      if (isa<arith::ExtSIOp>(extOp))
        result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
                     .getResult();
      else if (isa<arith::ExtUIOp>(extOp))
        result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
                     .getResult();
      else
        result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
                     .getResult();
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

// MMA types have different layouts based on how they are used in matmul ops.
// Figure out the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and
// only care about it during lowering to NVVM.
static const char *inferFragType(Operation *op) {
  for (Operation *users : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(users);
    if (!contract)
      continue;
    assert(op->getNumResults() == 1);
    if (contract.getLhs() == op->getResult(0))
      return "AOp";
    if (contract.getRhs() == op->getResult(0))
      return "BOp";
  }
  return "COp";
}

static LogicalResult
convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
                      llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op) &&
         "expected convertible operation");

  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  AffineMap map = op.getPermutationMap();
  bool isTranspose = isTransposeMatrixLoadMap(map);

  // Handle broadcast by setting the stride to 0.
  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
    assert(cstExpr.getValue() == 0);
    stride = 0;
  }

  Value mappingResult = op.getResult();
  auto elType = op.getVectorType().getElementType();
  const char *fragType = inferFragType(op);
  if (op->hasOneUse()) {
    auto *user = *op->user_begin();
    // Infer the signedness of the mma type from the integer extend.
    bool isSignedExtend = isa<arith::ExtSIOp>(user);
    if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
      elType = IntegerType::get(
          op.getContext(), cast<IntegerType>(elType).getWidth(),
          isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
      mappingResult = user->getResult(0);
      fragType = inferFragType(user);
    }
  }
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.getSource(), op.getIndices(),
      rewriter.getIndexAttr(*stride),
      isTranspose ? rewriter.getUnitAttr() : UnitAttr());
  valueMapping[mappingResult] = load;

  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
  return success();
}

static LogicalResult
convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(transferWriteSupportsMMAMatrixType(op));
  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end()) {
    LLVM_DEBUG(DBGS() << "no mapping\n");
    return rewriter.notifyMatchFailure(op, "no mapping");
  }

  Value matrix = it->second;
  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.getSource(), op.getIndices(),
      rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
  (void)store;

  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

/// Returns the vector type which represents a matrix fragment.
static VectorType
getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
                             regInfo.elementsPerRegister};
  Type elType = regInfo.registerLLVMType;
  if (auto vecType = dyn_cast<VectorType>(elType))
    elType = vecType.getElementType();
  return VectorType::get(shape, elType);
}

/// Convert a 2D splat ConstantOp to a per-lane arith.constant with the
/// mma.sync fragment vector type.
static LogicalResult
convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
  if (!dense) {
    LLVM_DEBUG(DBGS() << "not a splat\n");
    return rewriter.notifyMatchFailure(op, "not a splat");
  }

  Value result = rewriter.create<arith::ConstantOp>(
      op.getLoc(), vectorType,
      DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
  valueMapping[op.getResult()] = result;
  return success();
}

/// Check if the loaded matrix operand requires a transpose.
/// Transposed map examples:
///   Example 1: (..., d0, d1) -> (d1 * 1, d0 * 2)
///   Example 2: (d0, d1, d2, d3) -> (d3, d2)
/// The code below checks if the output 2D map is transposed using a
/// generalized form: (d0, d1, ..., dn, ..., dm, ...) -> (dm, dn).
/// Returns true if m > n, false otherwise.
static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
  mlir::AffineMap map = op.getPermutationMap();

  if (map.getNumResults() != 2) {
    LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
                         "is not a 2d operand\n");
    return failure();
  }

  // Output 2D matrix dimensions in the order of d0, d1.
  mlir::AffineExpr dM = map.getResult(0);
  mlir::AffineExpr dN = map.getResult(1);

  // Find the position of these expressions in the input.
  auto exprM = dyn_cast<AffineDimExpr>(dM);
  auto exprN = dyn_cast<AffineDimExpr>(dN);

  if (!exprM || !exprN) {
    LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
                         "expressions, so the transpose cannot be "
                         "determined.\n");
    return failure();
  }

  return exprM.getPosition() > exprN.getPosition();
}

static LogicalResult
creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);
  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }

  FailureOr<bool> transpose = isTransposed(op);
  if (failed(transpose)) {
    LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
    return rewriter.notifyMatchFailure(
        op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
  }

  FailureOr<nvgpu::LdMatrixParams> params =
      nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);

  if (failed(params)) {
    LLVM_DEBUG(
        DBGS()
        << "failed to convert vector.transfer_read to ldmatrix. "
        << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
    return rewriter.notifyMatchFailure(
        op, "failed to convert vector.transfer_read to ldmatrix; this op "
            "likely should not be converted to a nvgpu.ldmatrix call.");
  }

  // Adjust the load offset.
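  // getLaneIdToLdMatrixMatrixCoord returns an affine map from the lane id to
  // the per-lane matrix coordinates; getXferIndices composes it with the
  // original transfer indices.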
  auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
  if (failed(offsets)) {
    LLVM_DEBUG(DBGS() << "no offsets\n");
    return rewriter.notifyMatchFailure(op, "no offsets");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  SmallVector<Value, 4> indices;
  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
                                         indices);

  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
      loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  return success();
}

static LogicalResult
createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op.getLoc();
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    return rewriter.notifyMatchFailure(
        op, "Failed to deduce register fragment type during "
            "conversion to distributed non-ldmatrix compatible load");
  }

  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
  SmallVector<Value, 4> elements;

  // This is the individual element type.
  Type loadedElType = regInfo->registerLLVMType;
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  Value fill = rewriter.create<arith::ConstantOp>(
      op.getLoc(), vectorType.getElementType(),
      rewriter.getZeroAttr(vectorType.getElementType()));
  Value result =
      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // If we are not transposing, then we can use vectorized loads. Otherwise, we
  // must load each element individually.
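  // With a non-minor-identity permutation map, the fragment elements are not
  // contiguous along the innermost memref dimension, so each one is fetched
  // with a scalar memref.load instead.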
  if (!isTransposeLoad) {
    if (!isa<VectorType>(loadedElType)) {
      loadedElType = VectorType::get({1}, loadedElType);
    }

    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          rewriter, op.getLoc(), *warpMatrixInfo);
      if (failed(coords))
        return rewriter.notifyMatchFailure(op, "no coords");

      Value logicalValueId = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          rewriter, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
                                                 op.getSource(), newIndices);
      result = rewriter.create<vector::InsertOp>(loc, el, result, i);
    }
  } else {
    if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
      loadedElType = vecType.getElementType();
    }
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
           innerIdx++) {

        Value logicalValueId = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIndexType(),
            rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
            rewriter, op.getLoc(), *warpMatrixInfo);
        if (failed(coords))
          return rewriter.notifyMatchFailure(op, "no coords");

        SmallVector<Value, 4> newIndices;
        getXferIndices<vector::TransferReadOp>(
            rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
                                                   op.getSource(), newIndices);
        result = rewriter.create<vector::InsertOp>(
            op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
      }
    }
  }

  valueMapping[op.getResult()] = result;
  return success();
}

/// Return true if this is a shared memory memref type.
static bool isSharedMemory(MemRefType type) {
  auto addressSpace =
      dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
  if (addressSpace &&
      addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace())
    return true;
  return false;
}

/// Converts a `vector.transfer_read` operation directly to either a
/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
/// used when converting to `nvgpu.mma.sync` operations.
static LogicalResult
convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  bool isLdMatrixCompatible =
      isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  // When we are transposing the B operand, ldmatrix will only work if we have
  // at least 8 rows to read and the width to read for the transpose is 128
  // bits.
  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  if (!isLdMatrixCompatible)
    return createNonLdMatrixLoads(rewriter, op, valueMapping);

  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
}

static LogicalResult
convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op->getLoc();
  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value matrix = it->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        rewriter, op.getLoc(), *warpMatrixInfo);
    if (failed(coords))
      return rewriter.notifyMatchFailure(op, "no coords");

    Value el =
        rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
    rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
  }

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(cast<IntegerAttr>(attr).getInt());
}

static LogicalResult
convertExtractStridedSlice(RewriterBase &rewriter,
                           vector::ExtractStridedSliceOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(mmaSyncFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");

  // Find the vector.transfer_read whose result vector is being sliced.
  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
    return rewriter.notifyMatchFailure(op, "no transfer read");

  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(ldFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");

  assert(
      (mmaSyncFragmentInfo->elementsPerRegister ==
       ldFragmentInfo->elementsPerRegister) &&
      "Number of elements per register should be same for load and mma.sync");

  // Create vector.extract_strided_slice op for thread-owned fragments.
  // The stride for extract_strided_slice is always 1.
  std::array<int64_t, 2> strides = {1, 1};
  std::array<int64_t, 2> sliceShape = {
      mmaSyncFragmentInfo->numRegistersPerFragment,
      mmaSyncFragmentInfo->elementsPerRegister};
  auto it = valueMapping.find(transferReadOp);
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  auto sourceVector = it->second;

  // Offsets and sizes at the warp level of ownership.
  SmallVector<int64_t> offsets;
  populateFromInt64AttrArray(op.getOffsets(), offsets);

  SmallVector<int64_t> sizes;
  populateFromInt64AttrArray(op.getSizes(), sizes);
  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();

  // Compute the offset in vector registers. Note that the mma.sync vector
  // registers are shaped as numberOfFragments x numberOfRegistersPerFragment.
  // The vector registers can only be sliced along numberOfFragments, i.e.,
  // sliceOffset[0].
  std::array<int64_t, 2> sliceOffset = {0, 0};

  if (offsets[0] && offsets[1])
    return op->emitError() << "Slicing fragments in 2D is not supported.";
  if (offsets[0])
    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
  else if (offsets[1])
    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);

  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
      loc, sourceVector, sliceOffset, sliceShape, strides);

  valueMapping[op] = newOp;
  return success();
}

static LogicalResult
convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
      /*b_transpose=*/UnitAttr());
  valueMapping[op.getResult()] = matmul;
  return success();
}

static LogicalResult
convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(constantSupportsMMAMatrixType(op));

  auto splat =
      cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
  auto scalarConstant =
      rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = cast<VectorType>(op.getType());
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, scalarConstant);
  valueMapping[op.getResult()] = matrix;
  return success();
}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
                   llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(broadcastSupportsMMAMatrixType(op));

  const char *fragType = inferFragType(op);
  auto vecType = op.getResultVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, op.getSource());
  valueMapping[op.getResult()] = matrix;
  return success();
}

// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
                                               scf::ForOp loop,
                                               ValueRange newInitArgs) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loop);

  // Create a new loop before the existing one, with the extra operands.
  rewriter.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getInitArgs());
  llvm::append_range(operands, newInitArgs);
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
      operands);
  newLoop.getBody()->erase();

  newLoop.getRegion().getBlocks().splice(
      newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
  for (Value operand : newInitArgs)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));

  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
  LLVM_DEBUG(DBGS() << "erase: " << loop);

  rewriter.eraseOp(loop);
  return newLoop;
}

static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end()) {
      LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
      continue;
    }
    argMapping.push_back(std::make_pair(
        operand.index(), op.getInitArgs().size() + newOperands.size()));
    newOperands.push_back(it->second);
  }

  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }

  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
  return success();
}

static LogicalResult
convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
               llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static LogicalResult
convertElementwiseOp(RewriterBase &rewriter, Operation *op,
                     gpu::MMAElementwiseOp opType,
                     llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands()) {
    auto it = valueMapping.find(operand);
    if (it == valueMapping.end())
      return rewriter.notifyMatchFailure(op, "no mapping");
    matrixOperands.push_back(it->second);
  }
  auto resultType = matrixOperands[0].getType().cast<gpu::MMAMatrixType>();
  if (opType == gpu::MMAElementwiseOp::EXTF) {
    // The floating point extension case has a different result type.
    auto vectorType = op->getResultTypes()[0].cast<VectorType>();
    resultType = gpu::MMAMatrixType::get(resultType.getShape(),
                                         vectorType.getElementType(),
                                         resultType.getOperand());
  }

  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), resultType, matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
  return success();
}

void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
                                              bool useNvGpu) {
  if (!useNvGpu) {
    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
        patterns.getContext());
    return;
  }
  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
}

LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
                                          Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
  llvm::DenseMap<Value, Value> valueMapping;

  auto globalRes = LogicalResult::success();
  for (Operation *op : ops) {
    LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
    // Apparently callers do not want to early exit on failure here.
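    // Instead, record each failure and report a combined result at the end.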
    auto res = LogicalResult::success();
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      res = convertTransferReadOp(rewriter, transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      res = convertContractOp(rewriter, contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      res = convertConstantOp(rewriter, constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      res = convertForOp(rewriter, forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      res = convertYieldOp(rewriter, yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
    }
    if (failed(res))
      globalRes = failure();
  }
  return globalRes;
}

LogicalResult
mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
                                           Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
            .Case([&](vector::TransferReadOp transferReadOp) {
              return convertTransferReadToLoads(rewriter, transferReadOp,
                                                valueMapping);
            })
            .Case([&](vector::TransferWriteOp transferWriteOp) {
              return convertTransferWriteToStores(rewriter, transferWriteOp,
                                                  valueMapping);
            })
            .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
              return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
                                                valueMapping);
            })
            .Case([&](vector::ContractionOp contractionOp) {
              return convertContractOpToMmaSync(rewriter, contractionOp,
                                                valueMapping);
            })
            .Case([&](scf::ForOp forOp) {
              return convertForOp(rewriter, forOp, valueMapping);
            })
            .Case([&](scf::YieldOp yieldOp) {
              return convertYieldOp(rewriter, yieldOp, valueMapping);
            })
            .Case([&](arith::ConstantOp constOp) {
              return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
            })
            .Default([&](Operation *op) {
              return op->emitError() << "unhandled vector to mma type: " << *op;
            })
            .failed()) {
      return op->emitOpError()
             << "failed to convert op during vector-to-nvgpu conversion";
    }
  }
  return success();
}

namespace {

struct ConvertVectorToGPUPass
    : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    IRRewriter rewriter(&getContext());
    if (useNvGpu) {
      if (failed(
              convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
        return signalPassFailure();
      return;
    }
    (void)convertVectorToMMAOps(rewriter, getOperation());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}