//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Region.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "vector-to-gpu"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap representing offsets to apply to indices, the function fills
/// `indices` with the original indices plus the offsets. The offsets are
/// applied by taking into account the permutation map of the transfer op. If
/// the `offsetMap` has dimension placeholders, those should be provided in
/// `dimValues`.
template <typename TransferOpType>
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
      Value prevIdx = indices[dim.getPosition()];
      SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
      dims.push_back(prevIdx);
      AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
      continue;
    }
  }
}

// Return true if the contract op can be converted to MMA matmul.
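// The wmma path expects the row-major form {(m, k), (k, n), (m, n)}; the
// mma.sync path expects the MMT form {(m, k), (n, k), (m, n)} produced by
// vector::populateVectorContractCanonicalizeMatmulToMMT.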
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                          bool useNvGpu) {
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, contract.getContext());
  };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
        vector::isParallelIterator(iteratorTypes[1]) &&
        vector::isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
  if (!useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
    return false;

  return true;
}

// Return true if the given map represents a transposed matrix load,
// i.e. (d0, d1, ...) -> (dn-1, dn-2).
static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
  MLIRContext *ctx = permutationMap.getContext();
  // Local OpBuilder is fine here, we just build attributes.
  OpBuilder b(ctx);
  auto nDim = permutationMap.getNumDims();
  AffineExpr zero = b.getAffineConstantExpr(0);
  if (nDim < 2) {
    // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
    AffineExpr dim0 = b.getAffineDimExpr(0);
    return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
  }

  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
  // Support both transposed and transposed+broadcasted cases.
  return permutationMap ==
             AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
         permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
}

// Return the stride for the second-to-last dimension of |type| if it is a
// memref and has a constant stride.
static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
  auto memrefType = dyn_cast<MemRefType>(type);
  if (!memrefType)
    return std::nullopt;
  // If the memref is 0 or 1D the horizontal stride is 0.
  if (memrefType.getRank() < 2)
    return 0;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return std::nullopt;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamic)
    return std::nullopt;
  return stride;
}

// Return true if the transfer op can be converted to an MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
    return false;

  // Only allow integer types if the signedness can be inferred.
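  // Concretely, an i8 transfer_read must have a single user that is an
  // arith.extsi or arith.extui, so that the signed/unsigned MMA fragment type
  // can be chosen.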
153 if (readOp.getVectorType().getElementType().isInteger(8)) 154 if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) && 155 !isa<arith::ExtUIOp>(*readOp->user_begin()))) 156 return false; 157 158 AffineMap map = readOp.getPermutationMap(); 159 MLIRContext *ctx = readOp.getContext(); 160 AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx); 161 AffineExpr zero = getAffineConstantExpr(0, ctx); 162 auto broadcastInnerDim = 163 AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx); 164 return map.isMinorIdentity() || map == broadcastInnerDim || 165 isTransposeMatrixLoadMap(map); 166 } 167 168 // Return true if the transfer op can be converted to a MMA matrix store. 169 static bool 170 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) { 171 // TODO: support 0-d corner case. 172 if (writeOp.getTransferRank() == 0) 173 return false; 174 175 if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() || 176 writeOp.getVectorType().getRank() != 2) 177 return false; 178 if (!getStaticallyKnownRowStride(writeOp.getShapedType())) 179 return false; 180 // TODO: Support transpose once it is added to GPU dialect ops. 181 if (!writeOp.getPermutationMap().isMinorIdentity()) 182 return false; 183 return true; 184 } 185 186 /// Return true if the constant is a splat to a 2D vector so that it can be 187 /// converted to a MMA constant matrix op. 188 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) { 189 auto vecType = dyn_cast<VectorType>(constantOp.getType()); 190 if (!vecType || vecType.getRank() != 2) 191 return false; 192 return isa<SplatElementsAttr>(constantOp.getValue()); 193 } 194 195 /// Return true if this is a broadcast from scalar to a 2D vector. 196 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) { 197 return broadcastOp.getResultVectorType().getRank() == 2; 198 } 199 200 /// Return true if this integer extend op can be folded into a contract op. 201 template <typename ExtOpTy> 202 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) { 203 auto transferReadOp = 204 extOp.getOperand().template getDefiningOp<vector::TransferReadOp>(); 205 if (!transferReadOp) 206 return false; 207 return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>); 208 } 209 210 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; } 211 212 /// Return the MMA elementwise enum associated with `op` if it is supported. 213 /// Return `std::nullopt` otherwise. 
214 static std::optional<gpu::MMAElementwiseOp> 215 convertElementwiseOpToMMA(Operation *op) { 216 if (isa<arith::AddFOp>(op)) 217 return gpu::MMAElementwiseOp::ADDF; 218 if (isa<arith::MulFOp>(op)) 219 return gpu::MMAElementwiseOp::MULF; 220 if (isa<arith::SubFOp>(op)) 221 return gpu::MMAElementwiseOp::SUBF; 222 if (isa<arith::MaximumFOp>(op)) 223 return gpu::MMAElementwiseOp::MAXF; 224 if (isa<arith::MinimumFOp>(op)) 225 return gpu::MMAElementwiseOp::MINF; 226 if (isa<arith::DivFOp>(op)) 227 return gpu::MMAElementwiseOp::DIVF; 228 if (isa<arith::AddIOp>(op)) 229 return gpu::MMAElementwiseOp::ADDI; 230 if (isa<arith::MulIOp>(op)) 231 return gpu::MMAElementwiseOp::MULI; 232 if (isa<arith::SubIOp>(op)) 233 return gpu::MMAElementwiseOp::SUBI; 234 if (isa<arith::DivSIOp>(op)) 235 return gpu::MMAElementwiseOp::DIVS; 236 if (isa<arith::DivUIOp>(op)) 237 return gpu::MMAElementwiseOp::DIVU; 238 if (isa<arith::NegFOp>(op)) 239 return gpu::MMAElementwiseOp::NEGATEF; 240 if (isa<arith::ExtFOp>(op)) 241 return gpu::MMAElementwiseOp::EXTF; 242 return std::nullopt; 243 } 244 245 /// Return true if the op is supported as elementwise op on MMAMatrix type. 246 static bool elementwiseSupportsMMAMatrixType(Operation *op) { 247 return convertElementwiseOpToMMA(op).has_value(); 248 } 249 250 /// Returns true if the extract strided slice op is supported with `mma.sync` 251 /// path. 252 static bool 253 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) { 254 255 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 256 nvgpu::getWarpMatrixInfo(op); 257 if (failed(warpMatrixInfo)) 258 return false; 259 260 FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op); 261 if (failed(contractOp)) 262 return false; 263 264 // Handle vector.extract_strided_slice on registers containing 265 // matrixB and matrixC operands. vector.extract_strided_slice op 266 // is not supported on registers containing matrixA operands. 267 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B) 268 return (cast<VectorType>(op->getResult(0).getType()) == 269 cast<VectorType>((*contractOp).getRhs().getType())); 270 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C) 271 return (cast<VectorType>(op->getResult(0).getType()) == 272 cast<VectorType>((*contractOp).getAcc().getType())); 273 274 return false; 275 } 276 277 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) { 278 if (isa<scf::ForOp, scf::YieldOp>(op)) 279 return true; 280 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) 281 return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead) 282 : transferReadSupportsMMAMatrixType(transferRead); 283 if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) 284 return useNvGpu ? 
nvgpu::canLowerToWarpMatrixOperation(transferWrite) 285 : transferWriteSupportsMMAMatrixType(transferWrite); 286 if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op)) 287 return useNvGpu && 288 extractStridedSliceSupportsMMAMatrixType(extractStridedSlice); 289 if (auto contract = dyn_cast<vector::ContractionOp>(op)) 290 return contractSupportsMMAMatrixType(contract, useNvGpu); 291 if (auto constant = dyn_cast<arith::ConstantOp>(op)) 292 return constantSupportsMMAMatrixType(constant); 293 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) 294 return broadcastSupportsMMAMatrixType(broadcast); 295 if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op)) 296 return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend); 297 if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op)) 298 return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend); 299 if (auto fpExtend = dyn_cast<arith::ExtFOp>(op)) 300 return fpExtendSupportsMMAMatrixType(fpExtend); 301 return elementwiseSupportsMMAMatrixType(op); 302 } 303 304 /// Return an unsorted slice handling scf.for region differently than 305 /// `getSlice`. In scf.for we only want to include as part of the slice elements 306 /// that are part of the use/def chain. 307 static SetVector<Operation *> 308 getSliceContract(Operation *op, 309 const BackwardSliceOptions &backwardSliceOptions, 310 const ForwardSliceOptions &forwardSliceOptions) { 311 SetVector<Operation *> slice; 312 slice.insert(op); 313 unsigned currentIndex = 0; 314 SetVector<Operation *> backwardSlice; 315 SetVector<Operation *> forwardSlice; 316 while (currentIndex != slice.size()) { 317 auto *currentOp = (slice)[currentIndex]; 318 // Compute and insert the backwardSlice starting from currentOp. 319 backwardSlice.clear(); 320 getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); 321 slice.insert(backwardSlice.begin(), backwardSlice.end()); 322 323 // Compute and insert the forwardSlice starting from currentOp. 324 forwardSlice.clear(); 325 // Special case for ForOp, we don't want to include the whole region but 326 // only the value using the region arguments. 327 // TODO: We should refine this to only care about the region arguments being 328 // converted to matrix type. 329 if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) { 330 for (Value forOpResult : forOp.getResults()) 331 getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions); 332 for (BlockArgument &arg : forOp.getRegionIterArgs()) 333 getForwardSlice(arg, &forwardSlice, forwardSliceOptions); 334 } else { 335 getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions); 336 } 337 slice.insert(forwardSlice.begin(), forwardSlice.end()); 338 ++currentIndex; 339 } 340 return slice; 341 } 342 343 // Analyze slice of operations based on convert op to figure out if the whole 344 // slice can be converted to MMA operations. 
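// Starting from each vector.contract, the backward and forward slices of
// vector operations are gathered; a chain is converted only if every op in it
// is supported.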
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
  };
  BackwardSliceOptions backwardSliceOptions;
  backwardSliceOptions.filter = hasVectorDest;

  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
  };
  ForwardSliceOptions forwardSliceOptions;
  forwardSliceOptions.filter = hasVectorSrc;

  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type so they cannot be
    // used by all operations.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          if (!supportsMMaMatrixType(op, useNvGpu)) {
            LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
            return true;
          }
          return false;
        }))
      return;

    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}

namespace {
// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be
// converted to MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (!(vector::isParallelIterator(iteratorTypes[0]) &&
          vector::isParallelIterator(iteratorTypes[1]) &&
          vector::isReductionIterator(iteratorTypes[2])))
      return rewriter.notifyMatchFailure(op, "not a gemm contraction");
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    // This is the classical row-major matmul, nothing to do.
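    // Every other supported case below is brought into this form by
    // transposing and/or swapping the lhs and rhs operands.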
413 if (maps == infer({{m, k}, {k, n}, {m, n}})) 414 return rewriter.notifyMatchFailure(op, "contraction already prepared"); 415 if (maps == infer({{m, k}, {n, k}, {m, n}})) { 416 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 417 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { 418 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 419 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { 420 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 421 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 422 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { 423 std::swap(rhs, lhs); 424 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 425 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 426 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { 427 std::swap(rhs, lhs); 428 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 429 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { 430 std::swap(lhs, rhs); 431 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 432 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { 433 std::swap(lhs, rhs); 434 } else { 435 // TODO: llvm_unreachable ? 436 return rewriter.notifyMatchFailure(op, "unexpected contraction case"); 437 } 438 rewriter.replaceOpWithNewOp<vector::ContractionOp>( 439 op, lhs, rhs, res, 440 rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})), 441 op.getIteratorTypes()); 442 return success(); 443 } 444 }; 445 446 // Fold transpose op into the transfer read op. NVGPU mma.sync op only supports 447 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC, 448 // respectively. We can fold the transpose operation when loading the data from 449 // Shared Memory to registers. 450 struct CombineTransferReadOpTranspose final 451 : public OpRewritePattern<vector::TransposeOp> { 452 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 453 454 LogicalResult matchAndRewrite(vector::TransposeOp op, 455 PatternRewriter &rewriter) const override { 456 // Look through integer extend ops. 457 Value source = op.getVector(); 458 Type resultType = op.getType(); 459 Operation *extOp; 460 if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) || 461 (extOp = source.getDefiningOp<arith::ExtUIOp>()) || 462 (extOp = source.getDefiningOp<arith::ExtFOp>())) { 463 source = extOp->getOperand(0); 464 resultType = 465 VectorType::get(cast<VectorType>(resultType).getShape(), 466 cast<VectorType>(source.getType()).getElementType()); 467 } 468 469 auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>(); 470 if (!transferReadOp) 471 return rewriter.notifyMatchFailure(op, "no transfer read"); 472 473 // TODO: support 0-d corner case. 474 if (transferReadOp.getTransferRank() == 0) 475 return rewriter.notifyMatchFailure(op, "0-D transfer read"); 476 477 if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim()) 478 return rewriter.notifyMatchFailure(op, "not inbounds transfer read"); 479 480 AffineMap permutationMap = 481 AffineMap::getPermutationMap(op.getPermutation(), op.getContext()); 482 AffineMap newMap = 483 permutationMap.compose(transferReadOp.getPermutationMap()); 484 485 auto loc = op.getLoc(); 486 Value result = 487 rewriter 488 .create<vector::TransferReadOp>( 489 loc, resultType, transferReadOp.getSource(), 490 transferReadOp.getIndices(), AffineMapAttr::get(newMap), 491 transferReadOp.getPadding(), transferReadOp.getMask(), 492 transferReadOp.getInBoundsAttr()) 493 .getResult(); 494 495 // Fuse through the integer extend op. 
496 if (extOp) { 497 if (isa<arith::ExtSIOp>(extOp)) 498 result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result) 499 .getResult(); 500 else if (isa<arith::ExtUIOp>(extOp)) 501 result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result) 502 .getResult(); 503 else 504 result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result) 505 .getResult(); 506 } 507 508 rewriter.replaceOp(op, result); 509 return success(); 510 } 511 }; 512 513 } // namespace 514 515 // MMA types have different layout based on how they are used in matmul ops. 516 // Figure the right layout to use by looking at op uses. 517 // TODO: Change the GPU dialect to abstract the layout at the this level and 518 // only care about it during lowering to NVVM. 519 static const char *inferFragType(Operation *op) { 520 // We can have arith.ext ops before reaching contract ops. See through them 521 // and other kinds of elementwise ops. 522 if (op->hasOneUse()) { 523 Operation *userOp = *op->user_begin(); 524 if (userOp->hasTrait<OpTrait::Elementwise>()) 525 return inferFragType(userOp); 526 } 527 528 for (Operation *users : op->getUsers()) { 529 auto contract = dyn_cast<vector::ContractionOp>(users); 530 if (!contract) 531 continue; 532 assert(op->getNumResults() == 1); 533 if (contract.getLhs() == op->getResult(0)) 534 return "AOp"; 535 if (contract.getRhs() == op->getResult(0)) 536 return "BOp"; 537 } 538 return "COp"; 539 } 540 541 static LogicalResult 542 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, 543 llvm::DenseMap<Value, Value> &valueMapping) { 544 OpBuilder::InsertionGuard g(rewriter); 545 rewriter.setInsertionPoint(op); 546 547 assert(op.getTransferRank() > 0 && "unexpected 0-d transfer"); 548 assert(transferReadSupportsMMAMatrixType(op) && 549 "expected convertible operation"); 550 551 std::optional<int64_t> stride = 552 getStaticallyKnownRowStride(op.getShapedType()); 553 if (!stride.has_value()) { 554 LLVM_DEBUG(DBGS() << "no stride\n"); 555 return rewriter.notifyMatchFailure(op, "no stride"); 556 } 557 558 AffineMap map = op.getPermutationMap(); 559 bool isTranspose = isTransposeMatrixLoadMap(map); 560 561 // Handle broadcast by setting the stride to 0. 562 if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) { 563 assert(cstExpr.getValue() == 0); 564 stride = 0; 565 } 566 567 Value mappingResult = op.getResult(); 568 auto elType = op.getVectorType().getElementType(); 569 const char *fragType = inferFragType(op); 570 if (op->hasOneUse()) { 571 auto *user = *op->user_begin(); 572 // Infer the signedness of the mma type from the integer extend. 573 if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) { 574 elType = IntegerType::get( 575 op.getContext(), cast<IntegerType>(elType).getWidth(), 576 isa<arith::ExtSIOp>(user) ? IntegerType::Signed 577 : IntegerType::Unsigned); 578 mappingResult = user->getResult(0); 579 } 580 } 581 gpu::MMAMatrixType type = 582 gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType); 583 Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>( 584 op.getLoc(), type, op.getSource(), op.getIndices(), 585 rewriter.getIndexAttr(*stride), 586 isTranspose ? 
rewriter.getUnitAttr() : UnitAttr()); 587 valueMapping[mappingResult] = load; 588 589 LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n"); 590 return success(); 591 } 592 593 static LogicalResult 594 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, 595 llvm::DenseMap<Value, Value> &valueMapping) { 596 OpBuilder::InsertionGuard g(rewriter); 597 rewriter.setInsertionPoint(op); 598 599 assert(transferWriteSupportsMMAMatrixType(op)); 600 std::optional<int64_t> stride = 601 getStaticallyKnownRowStride(op.getShapedType()); 602 if (!stride.has_value()) { 603 LLVM_DEBUG(DBGS() << "no stride\n"); 604 return rewriter.notifyMatchFailure(op, "no stride"); 605 } 606 607 auto it = valueMapping.find(op.getVector()); 608 if (it == valueMapping.end()) { 609 LLVM_DEBUG(DBGS() << "no mapping\n"); 610 return rewriter.notifyMatchFailure(op, "no mapping"); 611 } 612 613 Value matrix = it->second; 614 auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>( 615 op.getLoc(), matrix, op.getSource(), op.getIndices(), 616 rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr()); 617 (void)store; 618 619 LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n"); 620 621 LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); 622 rewriter.eraseOp(op); 623 return success(); 624 } 625 626 /// Returns the vector type which represents a matrix fragment. 627 static VectorType 628 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info) { 629 SmallVector<int64_t> shape{regInfo.numRegistersPerFragment, 630 regInfo.elementsPerRegister}; 631 Type elType = regInfo.registerLLVMType; 632 if (auto vecType = dyn_cast<VectorType>(elType)) 633 elType = vecType.getElementType(); 634 return VectorType::get(shape, elType); 635 } 636 637 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. 638 static LogicalResult 639 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, 640 llvm::DenseMap<Value, Value> &valueMapping) { 641 OpBuilder::InsertionGuard g(rewriter); 642 rewriter.setInsertionPoint(op); 643 644 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 645 nvgpu::getWarpMatrixInfo(op); 646 if (failed(warpMatrixInfo)) { 647 LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); 648 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 649 } 650 651 FailureOr<nvgpu::FragmentElementInfo> regInfo = 652 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 653 if (failed(regInfo)) { 654 LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); 655 return rewriter.notifyMatchFailure(op, "not mma sync reg info"); 656 } 657 658 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 659 auto dense = dyn_cast<SplatElementsAttr>(op.getValue()); 660 if (!dense) { 661 LLVM_DEBUG(DBGS() << "not a splat\n"); 662 return rewriter.notifyMatchFailure(op, "not a splat"); 663 } 664 665 Value result = rewriter.create<arith::ConstantOp>( 666 op.getLoc(), vectorType, 667 DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>())); 668 valueMapping[op.getResult()] = result; 669 return success(); 670 } 671 672 /// Check if the loaded matrix operand requires transposed. 673 /// Transposed Map Example: 674 /// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2) 675 /// Example 2 : (d0, d1, d2, d3) -> (d3, d2) 676 /// The code below checks if the output 2D is transposed using a generalized 677 /// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn) 678 /// Returns : true; if m > n, false o.w. 
679 static FailureOr<bool> isTransposed(vector::TransferReadOp op) { 680 mlir::AffineMap map = op.getPermutationMap(); 681 682 if (map.getNumResults() != 2) { 683 LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` " 684 "is not a 2d operand\n"); 685 return failure(); 686 } 687 688 // Output 2D matrix dimensions in the order of d0, d1. 689 mlir::AffineExpr dM = map.getResult(0); 690 mlir::AffineExpr dN = map.getResult(1); 691 692 // Find the position of these expressions in the input. 693 auto exprM = dyn_cast<AffineDimExpr>(dM); 694 auto exprN = dyn_cast<AffineDimExpr>(dN); 695 696 if (!exprM || !exprN) { 697 LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim " 698 "expressions, then transpose cannot be determined.\n"); 699 return failure(); 700 } 701 702 return exprM.getPosition() > exprN.getPosition(); 703 } 704 705 static LogicalResult 706 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, 707 llvm::DenseMap<Value, Value> &valueMapping) { 708 OpBuilder::InsertionGuard g(rewriter); 709 rewriter.setInsertionPoint(op); 710 Location loc = op->getLoc(); 711 712 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 713 nvgpu::getWarpMatrixInfo(op); 714 if (failed(warpMatrixInfo)) { 715 LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); 716 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 717 } 718 719 FailureOr<nvgpu::FragmentElementInfo> regInfo = 720 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 721 if (failed(regInfo)) { 722 LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); 723 return rewriter.notifyMatchFailure(op, "not mma sync reg info"); 724 } 725 726 FailureOr<bool> transpose = isTransposed(op); 727 if (failed(transpose)) { 728 LLVM_DEBUG(DBGS() << "failed to determine the transpose\n"); 729 return rewriter.notifyMatchFailure( 730 op, "Op should likely not be converted to a nvgpu.ldmatrix call."); 731 } 732 733 FailureOr<nvgpu::LdMatrixParams> params = 734 nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose); 735 736 if (failed(params)) { 737 LLVM_DEBUG( 738 DBGS() 739 << "failed to convert vector.transfer_read to ldmatrix. " 740 << "Op should likely not be converted to a nvgpu.ldmatrix call.\n"); 741 return rewriter.notifyMatchFailure( 742 op, "failed to convert vector.transfer_read to ldmatrix; this op " 743 "likely should not be converted to a nvgpu.ldmatrix call."); 744 } 745 746 // Adjust the load offset. 
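  // Each lane reads from a lane-dependent address: the base indices of the
  // transfer_read are shifted by an affine function of the lane id computed
  // by getLaneIdToLdMatrixMatrixCoord and applied through getXferIndices.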
747 auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr); 748 FailureOr<AffineMap> offsets = 749 nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params); 750 if (failed(offsets)) { 751 LLVM_DEBUG(DBGS() << "no offsets\n"); 752 return rewriter.notifyMatchFailure(op, "no offsets"); 753 } 754 755 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 756 757 SmallVector<Value, 4> indices; 758 getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId}, 759 indices); 760 761 nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>( 762 loc, vectorType, op.getSource(), indices, *transpose, params->numTiles); 763 valueMapping[op] = newOp->getResult(0); 764 return success(); 765 } 766 767 static LogicalResult 768 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, 769 llvm::DenseMap<Value, Value> &valueMapping) { 770 OpBuilder::InsertionGuard g(rewriter); 771 rewriter.setInsertionPoint(op); 772 773 Location loc = op.getLoc(); 774 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 775 nvgpu::getWarpMatrixInfo(op); 776 if (failed(warpMatrixInfo)) 777 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 778 FailureOr<nvgpu::FragmentElementInfo> regInfo = 779 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 780 if (failed(regInfo)) { 781 return rewriter.notifyMatchFailure( 782 op, "Failed to deduce register fragment type during " 783 "conversion to distributed non-ldmatrix compatible load"); 784 } 785 786 Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr); 787 SmallVector<Value, 4> elements; 788 789 // This is the individual element type. 790 Type loadedElType = regInfo->registerLLVMType; 791 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 792 793 Value fill = rewriter.create<arith::ConstantOp>( 794 op.getLoc(), vectorType.getElementType(), 795 rewriter.getZeroAttr(vectorType.getElementType())); 796 Value result = 797 rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType); 798 799 bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity(); 800 801 // If we are not transposing, then we can use vectorized loads. Otherwise, we 802 // must load each element individually. 
803 if (!isTransposeLoad) { 804 if (!isa<VectorType>(loadedElType)) { 805 loadedElType = VectorType::get({1}, loadedElType); 806 } 807 808 for (int i = 0; i < vectorType.getShape()[0]; i++) { 809 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( 810 rewriter, op.getLoc(), *warpMatrixInfo); 811 if (failed(coords)) 812 return rewriter.notifyMatchFailure(op, "no coords"); 813 814 Value logicalValueId = rewriter.create<arith::ConstantOp>( 815 loc, rewriter.getIndexType(), 816 rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); 817 SmallVector<Value, 4> newIndices; 818 getXferIndices<vector::TransferReadOp>( 819 rewriter, op, *coords, {laneId, logicalValueId}, newIndices); 820 821 Value el = rewriter.create<vector::LoadOp>(loc, loadedElType, 822 op.getSource(), newIndices); 823 result = rewriter.create<vector::InsertOp>(loc, el, result, i); 824 } 825 } else { 826 if (auto vecType = dyn_cast<VectorType>(loadedElType)) { 827 loadedElType = vecType.getElementType(); 828 } 829 for (int i = 0; i < vectorType.getShape()[0]; i++) { 830 for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1]; 831 innerIdx++) { 832 833 Value logicalValueId = rewriter.create<arith::ConstantOp>( 834 loc, rewriter.getIndexType(), 835 rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx)); 836 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( 837 rewriter, op.getLoc(), *warpMatrixInfo); 838 if (failed(coords)) 839 return rewriter.notifyMatchFailure(op, "no coords"); 840 841 SmallVector<Value, 4> newIndices; 842 getXferIndices<vector::TransferReadOp>( 843 rewriter, op, *coords, {laneId, logicalValueId}, newIndices); 844 Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType, 845 op.getSource(), newIndices); 846 result = rewriter.create<vector::InsertOp>( 847 op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx}); 848 } 849 } 850 } 851 852 valueMapping[op.getResult()] = result; 853 return success(); 854 } 855 856 /// Return true if this is a shared memory memref type. 857 static bool isSharedMemory(MemRefType type) { 858 auto addressSpace = 859 dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace()); 860 return addressSpace && 861 addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace(); 862 } 863 864 /// Converts a `vector.transfer_read` operation directly to either a 865 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be 866 /// used when converting to `nvgpu.mma.sync` operations. 867 static LogicalResult 868 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, 869 llvm::DenseMap<Value, Value> &valueMapping) { 870 OpBuilder::InsertionGuard g(rewriter); 871 rewriter.setInsertionPoint(op); 872 873 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 874 nvgpu::getWarpMatrixInfo(op); 875 if (failed(warpMatrixInfo)) 876 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 877 878 bool isLdMatrixCompatible = 879 isSharedMemory(cast<MemRefType>(op.getSource().getType())) && 880 nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128; 881 882 VectorType vecTy = op.getVectorType(); 883 int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth(); 884 885 // When we are transposing the B operand, ldmatrix will only work if we have 886 // at least 8 rows to read and the width to read for the transpose is 128 887 // bits. 
888 if (!op.getPermutationMap().isMinorIdentity() && 889 (bitWidth != 16 || vecTy.getDimSize(1) < 8 || 890 vecTy.getDimSize(0) * bitWidth < 128)) 891 isLdMatrixCompatible = false; 892 893 if (!isLdMatrixCompatible) 894 return createNonLdMatrixLoads(rewriter, op, valueMapping); 895 896 return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping); 897 } 898 899 static LogicalResult 900 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, 901 llvm::DenseMap<Value, Value> &valueMapping) { 902 OpBuilder::InsertionGuard g(rewriter); 903 rewriter.setInsertionPoint(op); 904 905 Location loc = op->getLoc(); 906 auto it = valueMapping.find(op.getVector()); 907 if (it == valueMapping.end()) 908 return rewriter.notifyMatchFailure(op, "no mapping"); 909 Value matrix = it->second; 910 911 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 912 nvgpu::getWarpMatrixInfo(op); 913 if (failed(warpMatrixInfo)) 914 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 915 FailureOr<nvgpu::FragmentElementInfo> regInfo = 916 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 917 if (failed(regInfo)) 918 return rewriter.notifyMatchFailure(op, "not mma sync reg info"); 919 920 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 921 Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr); 922 923 for (unsigned i = 0; i < vectorType.getShape()[0]; i++) { 924 Value logicalValueId = rewriter.create<arith::ConstantOp>( 925 loc, rewriter.getIndexType(), 926 rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); 927 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( 928 rewriter, op.getLoc(), *warpMatrixInfo); 929 if (failed(coords)) 930 return rewriter.notifyMatchFailure(op, "no coords"); 931 932 Value el = 933 rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i}); 934 SmallVector<Value, 4> newIndices; 935 getXferIndices<vector::TransferWriteOp>( 936 rewriter, op, *coords, {laneId, logicalValueId}, newIndices); 937 rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices); 938 } 939 940 LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); 941 rewriter.eraseOp(op); 942 return success(); 943 } 944 945 static void populateFromInt64AttrArray(ArrayAttr arrayAttr, 946 SmallVectorImpl<int64_t> &results) { 947 for (auto attr : arrayAttr) 948 results.push_back(cast<IntegerAttr>(attr).getInt()); 949 } 950 951 static LogicalResult 952 convertExtractStridedSlice(RewriterBase &rewriter, 953 vector::ExtractStridedSliceOp op, 954 llvm::DenseMap<Value, Value> &valueMapping) { 955 OpBuilder::InsertionGuard g(rewriter); 956 rewriter.setInsertionPoint(op); 957 958 Location loc = op->getLoc(); 959 960 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 961 nvgpu::getWarpMatrixInfo(op); 962 if (failed(warpMatrixInfo)) 963 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 964 965 FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo = 966 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 967 if (failed(mmaSyncFragmentInfo)) 968 return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo"); 969 970 // Find the vector.transer_read whose result vector is being sliced. 
  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
    return rewriter.notifyMatchFailure(op, "no transfer read");

  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(ldFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");

  assert(
      (mmaSyncFragmentInfo->elementsPerRegister ==
       ldFragmentInfo->elementsPerRegister) &&
      "Number of elements per register should be same for load and mma.sync");

  // Create vector.extract_strided_slice op for thread-owned fragments.
  std::array<int64_t, 2> strides = {1, 1}; // Stride for extract slice is
                                           // always 1.
  std::array<int64_t, 2> sliceShape = {
      mmaSyncFragmentInfo->numRegistersPerFragment,
      mmaSyncFragmentInfo->elementsPerRegister};
  auto it = valueMapping.find(transferReadOp);
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  auto sourceVector = it->second;

  // Offsets and sizes at warp-level of ownership.
  SmallVector<int64_t> offsets;
  populateFromInt64AttrArray(op.getOffsets(), offsets);

  SmallVector<int64_t> sizes;
  populateFromInt64AttrArray(op.getSizes(), sizes);
  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();

  // Compute offset in vector registers. Note that the mma.sync vector
  // registers are shaped as numberOfFragments x numberOfRegistersPerFragment.
  // The vector registers can only be sliced along numberOfFragments, i.e.,
  // sliceOffset[0].
  std::array<int64_t, 2> sliceOffset = {0, 0};

  if (offsets[0] && offsets[1])
    return op->emitError() << "Slicing fragments in 2D is not supported. ";
  if (offsets[0])
    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
  else if (offsets[1])
    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);

  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
      loc, sourceVector, sliceOffset, sliceShape, strides);

  valueMapping[op] = newOp;
  return success();
}

static LogicalResult
convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
      /*b_transpose=*/UnitAttr());
  valueMapping[op.getResult()] = matmul;
  return success();
}

static LogicalResult
convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(constantSupportsMMAMatrixType(op));

  auto splat =
      cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
  auto scalarConstant =
      rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = cast<VectorType>(op.getType());
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, scalarConstant);
  valueMapping[op.getResult()] = matrix;
  return success();
}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1093 static LogicalResult 1094 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, 1095 llvm::DenseMap<Value, Value> &valueMapping) { 1096 OpBuilder::InsertionGuard g(rewriter); 1097 rewriter.setInsertionPoint(op); 1098 1099 assert(broadcastSupportsMMAMatrixType(op)); 1100 1101 const char *fragType = inferFragType(op); 1102 auto vecType = op.getResultVectorType(); 1103 gpu::MMAMatrixType type = gpu::MMAMatrixType::get( 1104 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); 1105 auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>( 1106 op.getLoc(), type, op.getSource()); 1107 valueMapping[op.getResult()] = matrix; 1108 return success(); 1109 } 1110 1111 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not 1112 // updated and needs to be updated separately for the loop to be correct. 1113 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, 1114 scf::ForOp loop, 1115 ValueRange newInitArgs) { 1116 OpBuilder::InsertionGuard g(rewriter); 1117 rewriter.setInsertionPoint(loop); 1118 1119 // Create a new loop before the existing one, with the extra operands. 1120 rewriter.setInsertionPoint(loop); 1121 auto operands = llvm::to_vector<4>(loop.getInitArgs()); 1122 llvm::append_range(operands, newInitArgs); 1123 scf::ForOp newLoop = rewriter.create<scf::ForOp>( 1124 loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), 1125 operands); 1126 rewriter.eraseBlock(newLoop.getBody()); 1127 1128 newLoop.getRegion().getBlocks().splice( 1129 newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); 1130 for (Value operand : newInitArgs) 1131 newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); 1132 1133 for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( 1134 loop.getNumResults()))) 1135 rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); 1136 1137 LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n"); 1138 LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n"); 1139 LLVM_DEBUG(DBGS() << "erase: " << loop); 1140 1141 rewriter.eraseOp(loop); 1142 return newLoop; 1143 } 1144 1145 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, 1146 llvm::DenseMap<Value, Value> &valueMapping) { 1147 OpBuilder::InsertionGuard g(rewriter); 1148 rewriter.setInsertionPoint(op); 1149 1150 SmallVector<Value> newOperands; 1151 SmallVector<std::pair<size_t, size_t>> argMapping; 1152 for (const auto &operand : llvm::enumerate(op.getInitArgs())) { 1153 auto it = valueMapping.find(operand.value()); 1154 if (it == valueMapping.end()) { 1155 LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n"); 1156 continue; 1157 } 1158 argMapping.push_back(std::make_pair( 1159 operand.index(), op.getInitArgs().size() + newOperands.size())); 1160 newOperands.push_back(it->second); 1161 } 1162 1163 scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands); 1164 Block &loopBody = *newForOp.getBody(); 1165 for (auto mapping : argMapping) { 1166 valueMapping[newForOp.getResult(mapping.first)] = 1167 newForOp.getResult(mapping.second); 1168 valueMapping[loopBody.getArgument(mapping.first + 1169 newForOp.getNumInductionVars())] = 1170 loopBody.getArgument(mapping.second + newForOp.getNumInductionVars()); 1171 } 1172 1173 LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n"); 1174 return success(); 1175 } 1176 1177 static LogicalResult 1178 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, 1179 
llvm::DenseMap<Value, Value> &valueMapping) { 1180 OpBuilder::InsertionGuard g(rewriter); 1181 rewriter.setInsertionPoint(op); 1182 1183 auto loop = cast<scf::ForOp>(op->getParentOp()); 1184 auto yieldOperands = llvm::to_vector<4>(op.getOperands()); 1185 for (const auto &operand : llvm::enumerate(op.getOperands())) { 1186 auto it = valueMapping.find(operand.value()); 1187 if (it == valueMapping.end()) 1188 continue; 1189 // Replace the yield of old value with the for op argument to make it easier 1190 // to remove the dead code. 1191 yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()]; 1192 yieldOperands.push_back(it->second); 1193 } 1194 rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands); 1195 1196 LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); 1197 rewriter.eraseOp(op); 1198 return success(); 1199 } 1200 1201 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix. 1202 static LogicalResult 1203 convertElementwiseOp(RewriterBase &rewriter, Operation *op, 1204 gpu::MMAElementwiseOp opType, 1205 llvm::DenseMap<Value, Value> &valueMapping) { 1206 OpBuilder::InsertionGuard g(rewriter); 1207 rewriter.setInsertionPoint(op); 1208 1209 SmallVector<Value> matrixOperands; 1210 for (Value operand : op->getOperands()) { 1211 auto it = valueMapping.find(operand); 1212 if (it == valueMapping.end()) 1213 return rewriter.notifyMatchFailure(op, "no mapping"); 1214 matrixOperands.push_back(it->second); 1215 } 1216 auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType()); 1217 if (opType == gpu::MMAElementwiseOp::EXTF) { 1218 // The floating point extension case has a different result type. 1219 auto vectorType = cast<VectorType>(op->getResultTypes()[0]); 1220 resultType = gpu::MMAMatrixType::get(resultType.getShape(), 1221 vectorType.getElementType(), 1222 resultType.getOperand()); 1223 } 1224 1225 Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>( 1226 op->getLoc(), resultType, matrixOperands, opType); 1227 valueMapping[op->getResult(0)] = newOp; 1228 return success(); 1229 } 1230 1231 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, 1232 bool useNvGpu) { 1233 if (!useNvGpu) { 1234 patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>( 1235 patterns.getContext()); 1236 return; 1237 } 1238 vector::populateVectorContractCanonicalizeMatmulToMMT(patterns); 1239 patterns.add<CombineTransferReadOpTranspose>(patterns.getContext()); 1240 } 1241 1242 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter, 1243 Operation *rootOp) { 1244 SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false); 1245 llvm::DenseMap<Value, Value> valueMapping; 1246 1247 auto globalRes = LogicalResult::success(); 1248 for (Operation *op : ops) { 1249 LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n"); 1250 // Apparently callers do not want to early exit on failure here. 
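    // Each op is converted independently; a failure is recorded in globalRes
    // but the remaining ops are still processed.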
1251 auto res = LogicalResult::success(); 1252 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) { 1253 res = convertTransferReadOp(rewriter, transferRead, valueMapping); 1254 } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) { 1255 res = convertTransferWriteOp(rewriter, transferWrite, valueMapping); 1256 } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) { 1257 res = convertContractOp(rewriter, contractOp, valueMapping); 1258 } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) { 1259 res = convertConstantOp(rewriter, constantOp, valueMapping); 1260 } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) { 1261 res = convertBroadcastOp(rewriter, broadcastOp, valueMapping); 1262 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) { 1263 res = convertForOp(rewriter, forOp, valueMapping); 1264 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) { 1265 res = convertYieldOp(rewriter, yieldOp, valueMapping); 1266 } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) { 1267 res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping); 1268 } 1269 if (failed(res)) 1270 globalRes = failure(); 1271 } 1272 return globalRes; 1273 } 1274 1275 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, 1276 Operation *rootOp) { 1277 SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true); 1278 llvm::DenseMap<Value, Value> valueMapping; 1279 for (Operation *op : ops) { 1280 if (llvm::TypeSwitch<Operation *, LogicalResult>(op) 1281 .Case([&](vector::TransferReadOp transferReadOp) { 1282 return convertTransferReadToLoads(rewriter, transferReadOp, 1283 valueMapping); 1284 }) 1285 .Case([&](vector::TransferWriteOp transferWriteOp) { 1286 return convertTransferWriteToStores(rewriter, transferWriteOp, 1287 valueMapping); 1288 }) 1289 .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) { 1290 return convertExtractStridedSlice(rewriter, extractStridedSliceOp, 1291 valueMapping); 1292 }) 1293 .Case([&](vector::ContractionOp contractionOp) { 1294 return convertContractOpToMmaSync(rewriter, contractionOp, 1295 valueMapping); 1296 }) 1297 .Case([&](scf::ForOp forOp) { 1298 return convertForOp(rewriter, forOp, valueMapping); 1299 }) 1300 .Case([&](scf::YieldOp yieldOp) { 1301 return convertYieldOp(rewriter, yieldOp, valueMapping); 1302 }) 1303 .Case([&](arith::ConstantOp constOp) { 1304 return convertConstantOpMmaSync(rewriter, constOp, valueMapping); 1305 }) 1306 .Default([&](Operation *op) { 1307 return op->emitError() << "unhandled vector to mma type: " << *op; 1308 }) 1309 .failed()) { 1310 return op->emitOpError() 1311 << "failed to convert op during vector-to-nvgpu conversion"; 1312 } 1313 } 1314 return success(); 1315 } 1316 1317 namespace { 1318 1319 struct ConvertVectorToGPUPass 1320 : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> { 1321 1322 explicit ConvertVectorToGPUPass(bool useNvGpu_) { 1323 useNvGpu.setValue(useNvGpu_); 1324 } 1325 1326 void runOnOperation() override { 1327 RewritePatternSet patterns(&getContext()); 1328 populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue()); 1329 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 1330 return signalPassFailure(); 1331 1332 IRRewriter rewriter(&getContext()); 1333 if (useNvGpu) { 1334 if (failed( 1335 convertVectorToNVVMCompatibleMMASync(rewriter, getOperation()))) 1336 return signalPassFailure(); 1337 return; 1338 } 1339 (void)convertVectorToMMAOps(rewriter, 
getOperation()); 1340 } 1341 }; 1342 1343 } // namespace 1344 1345 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) { 1346 return std::make_unique<ConvertVectorToGPUPass>(useNvGpu); 1347 } 1348