//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Region.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "vector-to-gpu"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap representing offsets to apply to indices, the function fills
/// `indices` with the original indices plus the offsets. The offsets are
/// applied by taking into account the permutation map of the transfer op. If
/// the `offsetMap` has dimension placeholders, those should be provided in
/// `dimValues`.
template <typename TransferOpType>
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
      Value prevIdx = indices[dim.getPosition()];
      SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
      dims.push_back(prevIdx);
      AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
      continue;
    }
  }
}

// Return true if the contract op can be converted to MMA matmul.
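// For the wmma path (useNvGpu == false) the indexing maps must be in the
// canonical row-major matmul form, i.e.
//   (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n),
// while the mma.sync path (useNvGpu == true) expects the B operand already
// transposed, i.e. (m, n, k) -> (n, k) for the second map. In both cases the
// iterator types must be (parallel, parallel, reduction).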
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                          bool useNvGpu) {
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, contract.getContext());
  };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
        vector::isParallelIterator(iteratorTypes[1]) &&
        vector::isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
  if (!useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
    return false;

  return true;
}

// Return true if the given map represents a transposed matrix load,
// i.e. (d0, d1, ...) -> (dn-1, dn-2).
static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
  MLIRContext *ctx = permutationMap.getContext();
  // Local OpBuilder is fine here, we just build attributes.
  OpBuilder b(ctx);
  auto nDim = permutationMap.getNumDims();
  AffineExpr zero = b.getAffineConstantExpr(0);
  if (nDim < 2) {
    // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
    AffineExpr dim0 = b.getAffineDimExpr(0);
    return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
  }

  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
  // Support both transposed and transposed+broadcasted cases.
  return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
         permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
}

// Return the stride for the second-to-last dimension of |type| if it is a
// memref and has a constant stride.
static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
  auto memrefType = dyn_cast<MemRefType>(type);
  if (!memrefType)
    return std::nullopt;
  // If the memref is 0 or 1D the horizontal stride is 0.
  if (memrefType.getRank() < 2)
    return 0;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
      strides.back() != 1)
    return std::nullopt;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamic)
    return std::nullopt;
  return stride;
}

// Return true if the transfer op can be converted to a MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
    return false;

  // Only allow integer types if the signedness can be inferred.
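  // For example, a transfer_read of i8 is only accepted when its single user
  // is an arith.extsi or arith.extui; that extension op is what later tells us
  // whether to build a signed or an unsigned MMA fragment.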
153 if (readOp.getVectorType().getElementType().isInteger(8)) 154 if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) && 155 !isa<arith::ExtUIOp>(*readOp->user_begin()))) 156 return false; 157 158 AffineMap map = readOp.getPermutationMap(); 159 MLIRContext *ctx = readOp.getContext(); 160 AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx); 161 AffineExpr zero = getAffineConstantExpr(0, ctx); 162 auto broadcastInnerDim = 163 AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx); 164 return map.isMinorIdentity() || map == broadcastInnerDim || 165 isTransposeMatrixLoadMap(map); 166 } 167 168 // Return true if the transfer op can be converted to a MMA matrix store. 169 static bool 170 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) { 171 // TODO: support 0-d corner case. 172 if (writeOp.getTransferRank() == 0) 173 return false; 174 175 if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() || 176 writeOp.getVectorType().getRank() != 2) 177 return false; 178 if (!getStaticallyKnownRowStride(writeOp.getShapedType())) 179 return false; 180 // TODO: Support transpose once it is added to GPU dialect ops. 181 if (!writeOp.getPermutationMap().isMinorIdentity()) 182 return false; 183 return true; 184 } 185 186 /// Return true if the constant is a splat to a 2D vector so that it can be 187 /// converted to a MMA constant matrix op. 188 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) { 189 auto vecType = dyn_cast<VectorType>(constantOp.getType()); 190 if (!vecType || vecType.getRank() != 2) 191 return false; 192 return isa<SplatElementsAttr>(constantOp.getValue()); 193 } 194 195 /// Return true if this is a broadcast from scalar to a 2D vector. 196 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) { 197 return broadcastOp.getResultVectorType().getRank() == 2; 198 } 199 200 /// Return true if this integer extend op can be folded into a contract op. 201 template <typename ExtOpTy> 202 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) { 203 if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp())) 204 return false; 205 return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>); 206 } 207 208 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; } 209 210 /// Return the MMA elementwise enum associated with `op` if it is supported. 211 /// Return `std::nullopt` otherwise. 
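/// For example, an arith.addf whose operands have been mapped to MMA matrices
/// is later rewritten by convertElementwiseOp into a
/// gpu.subgroup_mma_elementwise op with the ADDF kind.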
212 static std::optional<gpu::MMAElementwiseOp> 213 convertElementwiseOpToMMA(Operation *op) { 214 if (isa<arith::AddFOp>(op)) 215 return gpu::MMAElementwiseOp::ADDF; 216 if (isa<arith::MulFOp>(op)) 217 return gpu::MMAElementwiseOp::MULF; 218 if (isa<arith::SubFOp>(op)) 219 return gpu::MMAElementwiseOp::SUBF; 220 if (isa<arith::MaximumFOp>(op)) 221 return gpu::MMAElementwiseOp::MAXF; 222 if (isa<arith::MinimumFOp>(op)) 223 return gpu::MMAElementwiseOp::MINF; 224 if (isa<arith::DivFOp>(op)) 225 return gpu::MMAElementwiseOp::DIVF; 226 if (isa<arith::AddIOp>(op)) 227 return gpu::MMAElementwiseOp::ADDI; 228 if (isa<arith::MulIOp>(op)) 229 return gpu::MMAElementwiseOp::MULI; 230 if (isa<arith::SubIOp>(op)) 231 return gpu::MMAElementwiseOp::SUBI; 232 if (isa<arith::DivSIOp>(op)) 233 return gpu::MMAElementwiseOp::DIVS; 234 if (isa<arith::DivUIOp>(op)) 235 return gpu::MMAElementwiseOp::DIVU; 236 if (isa<arith::NegFOp>(op)) 237 return gpu::MMAElementwiseOp::NEGATEF; 238 if (isa<arith::ExtFOp>(op)) 239 return gpu::MMAElementwiseOp::EXTF; 240 return std::nullopt; 241 } 242 243 /// Return true if the op is supported as elementwise op on MMAMatrix type. 244 static bool elementwiseSupportsMMAMatrixType(Operation *op) { 245 return convertElementwiseOpToMMA(op).has_value(); 246 } 247 248 /// Returns true if the extract strided slice op is supported with `mma.sync` 249 /// path. 250 static bool 251 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) { 252 253 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 254 nvgpu::getWarpMatrixInfo(op); 255 if (failed(warpMatrixInfo)) 256 return false; 257 258 FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op); 259 if (failed(contractOp)) 260 return false; 261 262 // Handle vector.extract_strided_slice on registers containing 263 // matrixB and matrixC operands. vector.extract_strided_slice op 264 // is not supported on registers containing matrixA operands. 265 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B) 266 return (cast<VectorType>(op->getResult(0).getType()) == 267 cast<VectorType>((*contractOp).getRhs().getType())); 268 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C) 269 return (cast<VectorType>(op->getResult(0).getType()) == 270 cast<VectorType>((*contractOp).getAcc().getType())); 271 272 return false; 273 } 274 275 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) { 276 if (isa<scf::ForOp, scf::YieldOp>(op)) 277 return true; 278 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) 279 return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead) 280 : transferReadSupportsMMAMatrixType(transferRead); 281 if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) 282 return useNvGpu ? 
nvgpu::canLowerToWarpMatrixOperation(transferWrite) 283 : transferWriteSupportsMMAMatrixType(transferWrite); 284 if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op)) 285 return useNvGpu && 286 extractStridedSliceSupportsMMAMatrixType(extractStridedSlice); 287 if (auto contract = dyn_cast<vector::ContractionOp>(op)) 288 return contractSupportsMMAMatrixType(contract, useNvGpu); 289 if (auto constant = dyn_cast<arith::ConstantOp>(op)) 290 return constantSupportsMMAMatrixType(constant); 291 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) 292 return broadcastSupportsMMAMatrixType(broadcast); 293 if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op)) 294 return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend); 295 if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op)) 296 return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend); 297 if (auto fpExtend = dyn_cast<arith::ExtFOp>(op)) 298 return fpExtendSupportsMMAMatrixType(fpExtend); 299 return elementwiseSupportsMMAMatrixType(op); 300 } 301 302 /// Return an unsorted slice handling scf.for region differently than 303 /// `getSlice`. In scf.for we only want to include as part of the slice elements 304 /// that are part of the use/def chain. 305 static SetVector<Operation *> 306 getSliceContract(Operation *op, 307 const BackwardSliceOptions &backwardSliceOptions, 308 const ForwardSliceOptions &forwardSliceOptions) { 309 SetVector<Operation *> slice; 310 slice.insert(op); 311 unsigned currentIndex = 0; 312 SetVector<Operation *> backwardSlice; 313 SetVector<Operation *> forwardSlice; 314 while (currentIndex != slice.size()) { 315 auto *currentOp = (slice)[currentIndex]; 316 // Compute and insert the backwardSlice starting from currentOp. 317 backwardSlice.clear(); 318 getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); 319 slice.insert(backwardSlice.begin(), backwardSlice.end()); 320 321 // Compute and insert the forwardSlice starting from currentOp. 322 forwardSlice.clear(); 323 // Special case for ForOp, we don't want to include the whole region but 324 // only the value using the region arguments. 325 // TODO: We should refine this to only care about the region arguments being 326 // converted to matrix type. 327 if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) { 328 for (Value forOpResult : forOp.getResults()) 329 getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions); 330 for (BlockArgument &arg : forOp.getRegionIterArgs()) 331 getForwardSlice(arg, &forwardSlice, forwardSliceOptions); 332 } else { 333 getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions); 334 } 335 slice.insert(forwardSlice.begin(), forwardSlice.end()); 336 ++currentIndex; 337 } 338 return slice; 339 } 340 341 // Analyze slice of operations based on convert op to figure out if the whole 342 // slice can be converted to MMA operations. 
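// For example, for a vector.contract nested in an scf.for that carries the
// accumulator as an iter_arg, the slice typically contains the transfer_read
// ops feeding the contract, the scf.for and scf.yield, and the final
// transfer_write of the result; if any of those ops is unsupported the whole
// chain is left untouched.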
343 static SetVector<Operation *> getOpToConvert(mlir::Operation *op, 344 bool useNvGpu) { 345 auto hasVectorDest = [](Operation *op) { 346 return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>); 347 }; 348 BackwardSliceOptions backwardSliceOptions; 349 backwardSliceOptions.filter = hasVectorDest; 350 351 auto hasVectorSrc = [](Operation *op) { 352 return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>); 353 }; 354 ForwardSliceOptions forwardSliceOptions; 355 forwardSliceOptions.filter = hasVectorSrc; 356 357 SetVector<Operation *> opToConvert; 358 op->walk([&](vector::ContractionOp contract) { 359 if (opToConvert.contains(contract.getOperation())) 360 return; 361 SetVector<Operation *> dependentOps = 362 getSliceContract(contract, backwardSliceOptions, forwardSliceOptions); 363 // If any instruction cannot use MMA matrix type drop the whole 364 // chain. MMA matrix are stored in an opaque type so they cannot be used 365 // by all operations. 366 if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) { 367 if (!supportsMMaMatrixType(op, useNvGpu)) { 368 LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n"); 369 return true; 370 } 371 return false; 372 })) 373 return; 374 375 opToConvert.insert(dependentOps.begin(), dependentOps.end()); 376 }); 377 // Sort the operations so that we can convert them in topological order. 378 return topologicalSort(opToConvert); 379 } 380 381 namespace { 382 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted 383 // to MMA matmul. 384 struct PrepareContractToGPUMMA 385 : public OpRewritePattern<vector::ContractionOp> { 386 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; 387 388 LogicalResult matchAndRewrite(vector::ContractionOp op, 389 PatternRewriter &rewriter) const override { 390 Location loc = op.getLoc(); 391 Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc(); 392 393 // Set up the parallel/reduction structure in right form. 394 using MapList = ArrayRef<ArrayRef<AffineExpr>>; 395 auto infer = [&](MapList m) { 396 return AffineMap::inferFromExprList(m, op.getContext()); 397 }; 398 AffineExpr m, n, k; 399 bindDims(rewriter.getContext(), m, n, k); 400 static constexpr std::array<int64_t, 2> perm = {1, 0}; 401 auto iteratorTypes = op.getIteratorTypes().getValue(); 402 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray(); 403 if (!(vector::isParallelIterator(iteratorTypes[0]) && 404 vector::isParallelIterator(iteratorTypes[1]) && 405 vector::isReductionIterator(iteratorTypes[2]))) 406 return rewriter.notifyMatchFailure(op, "not a gemm contraction"); 407 // 408 // Two outer parallel, one inner reduction (matmat flavor). 409 // 410 // This is the classical row-major matmul, nothing to do. 
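    // The remaining cases are normalized by transposing and/or swapping
    // operands, e.g. (k, m) x (k, n) -> (m, n) becomes (m, k) x (k, n) ->
    // (m, n) by transposing the LHS.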
411 if (maps == infer({{m, k}, {k, n}, {m, n}})) 412 return rewriter.notifyMatchFailure(op, "contraction already prepared"); 413 if (maps == infer({{m, k}, {n, k}, {m, n}})) { 414 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 415 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { 416 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 417 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { 418 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 419 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 420 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { 421 std::swap(rhs, lhs); 422 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 423 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 424 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { 425 std::swap(rhs, lhs); 426 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 427 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { 428 std::swap(lhs, rhs); 429 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 430 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { 431 std::swap(lhs, rhs); 432 } else { 433 // TODO: llvm_unreachable ? 434 return rewriter.notifyMatchFailure(op, "unexpected contraction case"); 435 } 436 rewriter.replaceOpWithNewOp<vector::ContractionOp>( 437 op, lhs, rhs, res, 438 rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})), 439 op.getIteratorTypes()); 440 return success(); 441 } 442 }; 443 444 // Fold transpose op into the transfer read op. NVGPU mma.sync op only supports 445 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC, 446 // respectively. We can fold the transpose operation when loading the data from 447 // Shared Memory to registers. 448 struct CombineTransferReadOpTranspose final 449 : public OpRewritePattern<vector::TransposeOp> { 450 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 451 452 LogicalResult matchAndRewrite(vector::TransposeOp op, 453 PatternRewriter &rewriter) const override { 454 // Look through integer extend ops. 455 Value source = op.getVector(); 456 Type resultType = op.getType(); 457 Operation *extOp; 458 if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) || 459 (extOp = source.getDefiningOp<arith::ExtUIOp>()) || 460 (extOp = source.getDefiningOp<arith::ExtFOp>())) { 461 source = extOp->getOperand(0); 462 resultType = 463 VectorType::get(cast<VectorType>(resultType).getShape(), 464 cast<VectorType>(source.getType()).getElementType()); 465 } 466 467 auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>(); 468 if (!transferReadOp) 469 return rewriter.notifyMatchFailure(op, "no transfer read"); 470 471 // TODO: support 0-d corner case. 472 if (transferReadOp.getTransferRank() == 0) 473 return rewriter.notifyMatchFailure(op, "0-D transfer read"); 474 475 if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim()) 476 return rewriter.notifyMatchFailure(op, "not inbounds transfer read"); 477 478 AffineMap permutationMap = 479 AffineMap::getPermutationMap(op.getPermutation(), op.getContext()); 480 AffineMap newMap = 481 permutationMap.compose(transferReadOp.getPermutationMap()); 482 483 auto loc = op.getLoc(); 484 Value result = 485 rewriter 486 .create<vector::TransferReadOp>( 487 loc, resultType, transferReadOp.getSource(), 488 transferReadOp.getIndices(), AffineMapAttr::get(newMap), 489 transferReadOp.getPadding(), transferReadOp.getMask(), 490 transferReadOp.getInBoundsAttr()) 491 .getResult(); 492 493 // Fuse through the integer extend op. 
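    // For example, vector.transpose(arith.extsi(vector.transfer_read)) becomes
    // arith.extsi(vector.transfer_read) where the new read uses the composed
    // permutation map, so the widening is applied to the already-transposed
    // load.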
494 if (extOp) { 495 if (isa<arith::ExtSIOp>(extOp)) 496 result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result) 497 .getResult(); 498 else if (isa<arith::ExtUIOp>(extOp)) 499 result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result) 500 .getResult(); 501 else 502 result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result) 503 .getResult(); 504 } 505 506 rewriter.replaceOp(op, result); 507 return success(); 508 } 509 }; 510 511 } // namespace 512 513 // MMA types have different layout based on how they are used in matmul ops. 514 // Figure the right layout to use by looking at op uses. 515 // TODO: Change the GPU dialect to abstract the layout at the this level and 516 // only care about it during lowering to NVVM. 517 static const char *inferFragType(Operation *op) { 518 // We can have arith.ext ops before reaching contract ops. See through them 519 // and other kinds of elementwise ops. 520 if (op->hasOneUse()) { 521 Operation *userOp = *op->user_begin(); 522 if (userOp->hasTrait<OpTrait::Elementwise>()) 523 return inferFragType(userOp); 524 } 525 526 for (Operation *users : op->getUsers()) { 527 auto contract = dyn_cast<vector::ContractionOp>(users); 528 if (!contract) 529 continue; 530 assert(op->getNumResults() == 1); 531 if (contract.getLhs() == op->getResult(0)) 532 return "AOp"; 533 if (contract.getRhs() == op->getResult(0)) 534 return "BOp"; 535 } 536 return "COp"; 537 } 538 539 static LogicalResult 540 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, 541 llvm::DenseMap<Value, Value> &valueMapping) { 542 OpBuilder::InsertionGuard g(rewriter); 543 rewriter.setInsertionPoint(op); 544 545 assert(op.getTransferRank() > 0 && "unexpected 0-d transfer"); 546 assert(transferReadSupportsMMAMatrixType(op) && 547 "expected convertible operation"); 548 549 std::optional<int64_t> stride = 550 getStaticallyKnownRowStride(op.getShapedType()); 551 if (!stride.has_value()) { 552 LLVM_DEBUG(DBGS() << "no stride\n"); 553 return rewriter.notifyMatchFailure(op, "no stride"); 554 } 555 556 AffineMap map = op.getPermutationMap(); 557 bool isTranspose = isTransposeMatrixLoadMap(map); 558 559 // Handle broadcast by setting the stride to 0. 560 if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) { 561 assert(cstExpr.getValue() == 0); 562 stride = 0; 563 } 564 565 Value mappingResult = op.getResult(); 566 auto elType = op.getVectorType().getElementType(); 567 const char *fragType = inferFragType(op); 568 if (op->hasOneUse()) { 569 auto *user = *op->user_begin(); 570 // Infer the signedness of the mma type from the integer extend. 571 if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) { 572 elType = IntegerType::get( 573 op.getContext(), cast<IntegerType>(elType).getWidth(), 574 isa<arith::ExtSIOp>(user) ? IntegerType::Signed 575 : IntegerType::Unsigned); 576 mappingResult = user->getResult(0); 577 } 578 } 579 gpu::MMAMatrixType type = 580 gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType); 581 Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>( 582 op.getLoc(), type, op.getSource(), op.getIndices(), 583 rewriter.getIndexAttr(*stride), 584 isTranspose ? 
rewriter.getUnitAttr() : UnitAttr()); 585 valueMapping[mappingResult] = load; 586 587 LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n"); 588 return success(); 589 } 590 591 static LogicalResult 592 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, 593 llvm::DenseMap<Value, Value> &valueMapping) { 594 OpBuilder::InsertionGuard g(rewriter); 595 rewriter.setInsertionPoint(op); 596 597 assert(transferWriteSupportsMMAMatrixType(op)); 598 std::optional<int64_t> stride = 599 getStaticallyKnownRowStride(op.getShapedType()); 600 if (!stride.has_value()) { 601 LLVM_DEBUG(DBGS() << "no stride\n"); 602 return rewriter.notifyMatchFailure(op, "no stride"); 603 } 604 605 auto it = valueMapping.find(op.getVector()); 606 if (it == valueMapping.end()) { 607 LLVM_DEBUG(DBGS() << "no mapping\n"); 608 return rewriter.notifyMatchFailure(op, "no mapping"); 609 } 610 611 Value matrix = it->second; 612 auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>( 613 op.getLoc(), matrix, op.getSource(), op.getIndices(), 614 rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr()); 615 (void)store; 616 617 LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n"); 618 619 LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); 620 rewriter.eraseOp(op); 621 return success(); 622 } 623 624 /// Returns the vector type which represents a matrix fragment. 625 static VectorType 626 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info) { 627 SmallVector<int64_t> shape{regInfo.numRegistersPerFragment, 628 regInfo.elementsPerRegister}; 629 Type elType = regInfo.registerLLVMType; 630 if (auto vecType = dyn_cast<VectorType>(elType)) 631 elType = vecType.getElementType(); 632 return VectorType::get(shape, elType); 633 } 634 635 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. 636 static LogicalResult 637 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, 638 llvm::DenseMap<Value, Value> &valueMapping) { 639 OpBuilder::InsertionGuard g(rewriter); 640 rewriter.setInsertionPoint(op); 641 642 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 643 nvgpu::getWarpMatrixInfo(op); 644 if (failed(warpMatrixInfo)) { 645 LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); 646 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 647 } 648 649 FailureOr<nvgpu::FragmentElementInfo> regInfo = 650 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 651 if (failed(regInfo)) { 652 LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); 653 return rewriter.notifyMatchFailure(op, "not mma sync reg info"); 654 } 655 656 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 657 auto dense = dyn_cast<SplatElementsAttr>(op.getValue()); 658 if (!dense) { 659 LLVM_DEBUG(DBGS() << "not a splat\n"); 660 return rewriter.notifyMatchFailure(op, "not a splat"); 661 } 662 663 Value result = rewriter.create<arith::ConstantOp>( 664 op.getLoc(), vectorType, 665 DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>())); 666 valueMapping[op.getResult()] = result; 667 return success(); 668 } 669 670 /// Check if the loaded matrix operand requires transposed. 671 /// Transposed Map Example: 672 /// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2) 673 /// Example 2 : (d0, d1, d2, d3) -> (d3, d2) 674 /// The code below checks if the output 2D is transposed using a generalized 675 /// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn) 676 /// Returns : true; if m > n, false o.w. 
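/// Note that both results must be plain AffineDimExpr: maps whose results are
/// compound expressions (e.g. d0 * 2) are rejected with failure() rather than
/// being classified as transposed or not.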
677 static FailureOr<bool> isTransposed(vector::TransferReadOp op) { 678 mlir::AffineMap map = op.getPermutationMap(); 679 680 if (map.getNumResults() != 2) { 681 LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` " 682 "is not a 2d operand\n"); 683 return failure(); 684 } 685 686 // Output 2D matrix dimensions in the order of d0, d1. 687 mlir::AffineExpr dM = map.getResult(0); 688 mlir::AffineExpr dN = map.getResult(1); 689 690 // Find the position of these expressions in the input. 691 auto exprM = dyn_cast<AffineDimExpr>(dM); 692 auto exprN = dyn_cast<AffineDimExpr>(dN); 693 694 if (!exprM || !exprN) { 695 LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim " 696 "expressions, then transpose cannot be determined.\n"); 697 return failure(); 698 } 699 700 return exprM.getPosition() > exprN.getPosition(); 701 } 702 703 static LogicalResult 704 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, 705 llvm::DenseMap<Value, Value> &valueMapping) { 706 OpBuilder::InsertionGuard g(rewriter); 707 rewriter.setInsertionPoint(op); 708 Location loc = op->getLoc(); 709 710 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 711 nvgpu::getWarpMatrixInfo(op); 712 if (failed(warpMatrixInfo)) { 713 LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); 714 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 715 } 716 717 FailureOr<nvgpu::FragmentElementInfo> regInfo = 718 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 719 if (failed(regInfo)) { 720 LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); 721 return rewriter.notifyMatchFailure(op, "not mma sync reg info"); 722 } 723 724 FailureOr<bool> transpose = isTransposed(op); 725 if (failed(transpose)) { 726 LLVM_DEBUG(DBGS() << "failed to determine the transpose\n"); 727 return rewriter.notifyMatchFailure( 728 op, "Op should likely not be converted to a nvgpu.ldmatrix call."); 729 } 730 731 FailureOr<nvgpu::LdMatrixParams> params = 732 nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose); 733 734 if (failed(params)) { 735 LLVM_DEBUG( 736 DBGS() 737 << "failed to convert vector.transfer_read to ldmatrix. " 738 << "Op should likely not be converted to a nvgpu.ldmatrix call.\n"); 739 return rewriter.notifyMatchFailure( 740 op, "failed to convert vector.transfer_read to ldmatrix; this op " 741 "likely should not be converted to a nvgpu.ldmatrix call."); 742 } 743 744 // Adjust the load offset. 
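  // nvgpu.ldmatrix is a per-lane operation: each participating lane provides
  // the address of one matrix row, so the original transfer indices are offset
  // by an affine function of the lane id (getLaneIdToLdMatrixMatrixCoord).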
745 auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr); 746 FailureOr<AffineMap> offsets = 747 nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params); 748 if (failed(offsets)) { 749 LLVM_DEBUG(DBGS() << "no offsets\n"); 750 return rewriter.notifyMatchFailure(op, "no offsets"); 751 } 752 753 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 754 755 SmallVector<Value, 4> indices; 756 getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId}, 757 indices); 758 759 nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>( 760 loc, vectorType, op.getSource(), indices, *transpose, params->numTiles); 761 valueMapping[op] = newOp->getResult(0); 762 return success(); 763 } 764 765 static LogicalResult 766 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, 767 llvm::DenseMap<Value, Value> &valueMapping) { 768 OpBuilder::InsertionGuard g(rewriter); 769 rewriter.setInsertionPoint(op); 770 771 Location loc = op.getLoc(); 772 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 773 nvgpu::getWarpMatrixInfo(op); 774 if (failed(warpMatrixInfo)) 775 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 776 FailureOr<nvgpu::FragmentElementInfo> regInfo = 777 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 778 if (failed(regInfo)) { 779 return rewriter.notifyMatchFailure( 780 op, "Failed to deduce register fragment type during " 781 "conversion to distributed non-ldmatrix compatible load"); 782 } 783 784 Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr); 785 SmallVector<Value, 4> elements; 786 787 // This is the individual element type. 788 Type loadedElType = regInfo->registerLLVMType; 789 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 790 791 Value fill = rewriter.create<arith::ConstantOp>( 792 op.getLoc(), vectorType.getElementType(), 793 rewriter.getZeroAttr(vectorType.getElementType())); 794 Value result = 795 rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType); 796 797 bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity(); 798 799 // If we are not transposing, then we can use vectorized loads. Otherwise, we 800 // must load each element individually. 
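  // In the vectorized case each loop iteration below issues one vector.load
  // covering a whole register fragment; in the transposed case the elements of
  // a fragment are no longer contiguous in memory, so they are gathered with
  // scalar memref.load ops one by one.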
801 if (!isTransposeLoad) { 802 if (!isa<VectorType>(loadedElType)) { 803 loadedElType = VectorType::get({1}, loadedElType); 804 } 805 806 for (int i = 0; i < vectorType.getShape()[0]; i++) { 807 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( 808 rewriter, op.getLoc(), *warpMatrixInfo); 809 if (failed(coords)) 810 return rewriter.notifyMatchFailure(op, "no coords"); 811 812 Value logicalValueId = rewriter.create<arith::ConstantOp>( 813 loc, rewriter.getIndexType(), 814 rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); 815 SmallVector<Value, 4> newIndices; 816 getXferIndices<vector::TransferReadOp>( 817 rewriter, op, *coords, {laneId, logicalValueId}, newIndices); 818 819 Value el = rewriter.create<vector::LoadOp>(loc, loadedElType, 820 op.getSource(), newIndices); 821 result = rewriter.create<vector::InsertOp>(loc, el, result, i); 822 } 823 } else { 824 if (auto vecType = dyn_cast<VectorType>(loadedElType)) { 825 loadedElType = vecType.getElementType(); 826 } 827 for (int i = 0; i < vectorType.getShape()[0]; i++) { 828 for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1]; 829 innerIdx++) { 830 831 Value logicalValueId = rewriter.create<arith::ConstantOp>( 832 loc, rewriter.getIndexType(), 833 rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx)); 834 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( 835 rewriter, op.getLoc(), *warpMatrixInfo); 836 if (failed(coords)) 837 return rewriter.notifyMatchFailure(op, "no coords"); 838 839 SmallVector<Value, 4> newIndices; 840 getXferIndices<vector::TransferReadOp>( 841 rewriter, op, *coords, {laneId, logicalValueId}, newIndices); 842 Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType, 843 op.getSource(), newIndices); 844 result = rewriter.create<vector::InsertOp>( 845 op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx}); 846 } 847 } 848 } 849 850 valueMapping[op.getResult()] = result; 851 return success(); 852 } 853 854 /// Return true if this is a shared memory memref type. 855 static bool isSharedMemory(MemRefType type) { 856 auto addressSpace = 857 dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace()); 858 return addressSpace && 859 addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace(); 860 } 861 862 /// Converts a `vector.transfer_read` operation directly to either a 863 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be 864 /// used when converting to `nvgpu.mma.sync` operations. 865 static LogicalResult 866 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, 867 llvm::DenseMap<Value, Value> &valueMapping) { 868 OpBuilder::InsertionGuard g(rewriter); 869 rewriter.setInsertionPoint(op); 870 871 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 872 nvgpu::getWarpMatrixInfo(op); 873 if (failed(warpMatrixInfo)) 874 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 875 876 bool isLdMatrixCompatible = 877 isSharedMemory(cast<MemRefType>(op.getSource().getType())) && 878 nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128; 879 880 VectorType vecTy = op.getVectorType(); 881 int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth(); 882 883 // When we are transposing the B operand, ldmatrix will only work if we have 884 // at least 8 rows to read and the width to read for the transpose is 128 885 // bits. 
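  // For 16-bit element types this means the transposed read must cover at
  // least an 8x8 tile, e.g. vector<16x16xf16> can stay on the ldmatrix path
  // (provided the shared-memory and tile-width checks above passed) while
  // vector<4x16xf16> falls back to createNonLdMatrixLoads below.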
886 if (!op.getPermutationMap().isMinorIdentity() && 887 (bitWidth != 16 || vecTy.getDimSize(1) < 8 || 888 vecTy.getDimSize(0) * bitWidth < 128)) 889 isLdMatrixCompatible = false; 890 891 if (!isLdMatrixCompatible) 892 return createNonLdMatrixLoads(rewriter, op, valueMapping); 893 894 return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping); 895 } 896 897 static LogicalResult 898 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, 899 llvm::DenseMap<Value, Value> &valueMapping) { 900 OpBuilder::InsertionGuard g(rewriter); 901 rewriter.setInsertionPoint(op); 902 903 Location loc = op->getLoc(); 904 auto it = valueMapping.find(op.getVector()); 905 if (it == valueMapping.end()) 906 return rewriter.notifyMatchFailure(op, "no mapping"); 907 Value matrix = it->second; 908 909 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 910 nvgpu::getWarpMatrixInfo(op); 911 if (failed(warpMatrixInfo)) 912 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 913 FailureOr<nvgpu::FragmentElementInfo> regInfo = 914 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 915 if (failed(regInfo)) 916 return rewriter.notifyMatchFailure(op, "not mma sync reg info"); 917 918 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 919 Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr); 920 921 for (unsigned i = 0; i < vectorType.getShape()[0]; i++) { 922 Value logicalValueId = rewriter.create<arith::ConstantOp>( 923 loc, rewriter.getIndexType(), 924 rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); 925 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( 926 rewriter, op.getLoc(), *warpMatrixInfo); 927 if (failed(coords)) 928 return rewriter.notifyMatchFailure(op, "no coords"); 929 930 Value el = 931 rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i}); 932 SmallVector<Value, 4> newIndices; 933 getXferIndices<vector::TransferWriteOp>( 934 rewriter, op, *coords, {laneId, logicalValueId}, newIndices); 935 rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices); 936 } 937 938 LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); 939 rewriter.eraseOp(op); 940 return success(); 941 } 942 943 static void populateFromInt64AttrArray(ArrayAttr arrayAttr, 944 SmallVectorImpl<int64_t> &results) { 945 for (auto attr : arrayAttr) 946 results.push_back(cast<IntegerAttr>(attr).getInt()); 947 } 948 949 static LogicalResult 950 convertExtractStridedSlice(RewriterBase &rewriter, 951 vector::ExtractStridedSliceOp op, 952 llvm::DenseMap<Value, Value> &valueMapping) { 953 OpBuilder::InsertionGuard g(rewriter); 954 rewriter.setInsertionPoint(op); 955 956 Location loc = op->getLoc(); 957 958 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 959 nvgpu::getWarpMatrixInfo(op); 960 if (failed(warpMatrixInfo)) 961 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 962 963 FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo = 964 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 965 if (failed(mmaSyncFragmentInfo)) 966 return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo"); 967 968 // Find the vector.transer_read whose result vector is being sliced. 
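  // The warp-level slice is re-expressed on the per-thread register fragment,
  // which is shaped numRegistersPerFragment x elementsPerRegister and, as
  // noted below, can only be sliced along its first dimension.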
  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
    return rewriter.notifyMatchFailure(op, "no transfer read");

  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(ldFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");

  assert(
      (mmaSyncFragmentInfo->elementsPerRegister ==
       ldFragmentInfo->elementsPerRegister) &&
      "Number of elements per register should be same for load and mma.sync");

  // Create vector.extract_strided_slice op for thread-owned fragments.
  std::array<int64_t, 2> strides = {1,
                                    1}; // stride for extract slice is always 1.
  std::array<int64_t, 2> sliceShape = {
      mmaSyncFragmentInfo->numRegistersPerFragment,
      mmaSyncFragmentInfo->elementsPerRegister};
  auto it = valueMapping.find(transferReadOp);
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  auto sourceVector = it->second;

  // Offsets and sizes at the warp level of ownership.
  SmallVector<int64_t> offsets;
  populateFromInt64AttrArray(op.getOffsets(), offsets);

  SmallVector<int64_t> sizes;
  populateFromInt64AttrArray(op.getSizes(), sizes);
  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();

  // Compute offset in vector registers. Note that the mma.sync vector
  // registers are shaped as numberOfFragments x numberOfRegistersPerFragment.
  // The vector registers can only be sliced along numberOfFragments, i.e.,
  // sliceOffset[0].
  std::array<int64_t, 2> sliceOffset = {0, 0};

  if (offsets[0] && offsets[1])
    return op->emitError() << "Slicing fragments in 2D is not supported. 
"; 1013 if (offsets[0]) 1014 sliceOffset[0] = (warpVectorShape[0] / offsets[0]); 1015 else if (offsets[1]) 1016 sliceOffset[0] = (warpVectorShape[1] / offsets[1]); 1017 1018 Value newOp = rewriter.create<vector::ExtractStridedSliceOp>( 1019 loc, sourceVector, sliceOffset, sliceShape, strides); 1020 1021 valueMapping[op] = newOp; 1022 return success(); 1023 } 1024 1025 static LogicalResult 1026 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, 1027 llvm::DenseMap<Value, Value> &valueMapping) { 1028 OpBuilder::InsertionGuard g(rewriter); 1029 rewriter.setInsertionPoint(op); 1030 1031 auto itA = valueMapping.find(op.getLhs()); 1032 auto itB = valueMapping.find(op.getRhs()); 1033 auto itC = valueMapping.find(op.getAcc()); 1034 if (itA == valueMapping.end() || itB == valueMapping.end() || 1035 itC == valueMapping.end()) 1036 return rewriter.notifyMatchFailure(op, "no mapping"); 1037 Value opA = itA->second, opB = itB->second, opC = itC->second; 1038 Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>( 1039 op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(), 1040 /*b_transpose=*/UnitAttr()); 1041 valueMapping[op.getResult()] = matmul; 1042 return success(); 1043 } 1044 1045 static LogicalResult 1046 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, 1047 llvm::DenseMap<Value, Value> &valueMapping) { 1048 OpBuilder::InsertionGuard g(rewriter); 1049 rewriter.setInsertionPoint(op); 1050 1051 auto itA = valueMapping.find(op.getLhs()); 1052 auto itB = valueMapping.find(op.getRhs()); 1053 auto itC = valueMapping.find(op.getAcc()); 1054 if (itA == valueMapping.end() || itB == valueMapping.end() || 1055 itC == valueMapping.end()) 1056 return rewriter.notifyMatchFailure(op, "no mapping"); 1057 Value opA = itA->second, opB = itB->second, opC = itC->second; 1058 int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0]; 1059 int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0]; 1060 int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1]; 1061 Value matmul = rewriter.create<nvgpu::MmaSyncOp>( 1062 op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k})); 1063 valueMapping[op.getResult()] = matmul; 1064 return success(); 1065 } 1066 1067 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. 1068 static LogicalResult 1069 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, 1070 llvm::DenseMap<Value, Value> &valueMapping) { 1071 OpBuilder::InsertionGuard g(rewriter); 1072 rewriter.setInsertionPoint(op); 1073 1074 assert(constantSupportsMMAMatrixType(op)); 1075 1076 auto splat = 1077 cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>(); 1078 auto scalarConstant = 1079 rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat); 1080 const char *fragType = inferFragType(op); 1081 auto vecType = cast<VectorType>(op.getType()); 1082 gpu::MMAMatrixType type = gpu::MMAMatrixType::get( 1083 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); 1084 auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>( 1085 op.getLoc(), type, scalarConstant); 1086 valueMapping[op.getResult()] = matrix; 1087 return success(); 1088 } 1089 1090 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op. 
1091 static LogicalResult 1092 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, 1093 llvm::DenseMap<Value, Value> &valueMapping) { 1094 OpBuilder::InsertionGuard g(rewriter); 1095 rewriter.setInsertionPoint(op); 1096 1097 assert(broadcastSupportsMMAMatrixType(op)); 1098 1099 const char *fragType = inferFragType(op); 1100 auto vecType = op.getResultVectorType(); 1101 gpu::MMAMatrixType type = gpu::MMAMatrixType::get( 1102 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); 1103 auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>( 1104 op.getLoc(), type, op.getSource()); 1105 valueMapping[op.getResult()] = matrix; 1106 return success(); 1107 } 1108 1109 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not 1110 // updated and needs to be updated separately for the loop to be correct. 1111 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, 1112 scf::ForOp loop, 1113 ValueRange newInitArgs) { 1114 OpBuilder::InsertionGuard g(rewriter); 1115 rewriter.setInsertionPoint(loop); 1116 1117 // Create a new loop before the existing one, with the extra operands. 1118 rewriter.setInsertionPoint(loop); 1119 auto operands = llvm::to_vector<4>(loop.getInitArgs()); 1120 llvm::append_range(operands, newInitArgs); 1121 scf::ForOp newLoop = rewriter.create<scf::ForOp>( 1122 loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), 1123 operands); 1124 rewriter.eraseBlock(newLoop.getBody()); 1125 1126 newLoop.getRegion().getBlocks().splice( 1127 newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); 1128 for (Value operand : newInitArgs) 1129 newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); 1130 1131 for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( 1132 loop.getNumResults()))) 1133 rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); 1134 1135 LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n"); 1136 LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n"); 1137 LLVM_DEBUG(DBGS() << "erase: " << loop); 1138 1139 rewriter.eraseOp(loop); 1140 return newLoop; 1141 } 1142 1143 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, 1144 llvm::DenseMap<Value, Value> &valueMapping) { 1145 OpBuilder::InsertionGuard g(rewriter); 1146 rewriter.setInsertionPoint(op); 1147 1148 SmallVector<Value> newOperands; 1149 SmallVector<std::pair<size_t, size_t>> argMapping; 1150 for (const auto &operand : llvm::enumerate(op.getInitArgs())) { 1151 auto it = valueMapping.find(operand.value()); 1152 if (it == valueMapping.end()) { 1153 LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n"); 1154 continue; 1155 } 1156 argMapping.push_back(std::make_pair( 1157 operand.index(), op.getInitArgs().size() + newOperands.size())); 1158 newOperands.push_back(it->second); 1159 } 1160 1161 scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands); 1162 Block &loopBody = *newForOp.getBody(); 1163 for (auto mapping : argMapping) { 1164 valueMapping[newForOp.getResult(mapping.first)] = 1165 newForOp.getResult(mapping.second); 1166 valueMapping[loopBody.getArgument(mapping.first + 1167 newForOp.getNumInductionVars())] = 1168 loopBody.getArgument(mapping.second + newForOp.getNumInductionVars()); 1169 } 1170 1171 LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n"); 1172 return success(); 1173 } 1174 1175 static LogicalResult 1176 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, 1177 
llvm::DenseMap<Value, Value> &valueMapping) { 1178 OpBuilder::InsertionGuard g(rewriter); 1179 rewriter.setInsertionPoint(op); 1180 1181 auto loop = cast<scf::ForOp>(op->getParentOp()); 1182 auto yieldOperands = llvm::to_vector<4>(op.getOperands()); 1183 for (const auto &operand : llvm::enumerate(op.getOperands())) { 1184 auto it = valueMapping.find(operand.value()); 1185 if (it == valueMapping.end()) 1186 continue; 1187 // Replace the yield of old value with the for op argument to make it easier 1188 // to remove the dead code. 1189 yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()]; 1190 yieldOperands.push_back(it->second); 1191 } 1192 rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands); 1193 1194 LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); 1195 rewriter.eraseOp(op); 1196 return success(); 1197 } 1198 1199 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix. 1200 static LogicalResult 1201 convertElementwiseOp(RewriterBase &rewriter, Operation *op, 1202 gpu::MMAElementwiseOp opType, 1203 llvm::DenseMap<Value, Value> &valueMapping) { 1204 OpBuilder::InsertionGuard g(rewriter); 1205 rewriter.setInsertionPoint(op); 1206 1207 SmallVector<Value> matrixOperands; 1208 for (Value operand : op->getOperands()) { 1209 auto it = valueMapping.find(operand); 1210 if (it == valueMapping.end()) 1211 return rewriter.notifyMatchFailure(op, "no mapping"); 1212 matrixOperands.push_back(it->second); 1213 } 1214 auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType()); 1215 if (opType == gpu::MMAElementwiseOp::EXTF) { 1216 // The floating point extension case has a different result type. 1217 auto vectorType = cast<VectorType>(op->getResultTypes()[0]); 1218 resultType = gpu::MMAMatrixType::get(resultType.getShape(), 1219 vectorType.getElementType(), 1220 resultType.getOperand()); 1221 } 1222 1223 Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>( 1224 op->getLoc(), resultType, matrixOperands, opType); 1225 valueMapping[op->getResult(0)] = newOp; 1226 return success(); 1227 } 1228 1229 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, 1230 bool useNvGpu) { 1231 if (!useNvGpu) { 1232 patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>( 1233 patterns.getContext()); 1234 return; 1235 } 1236 vector::populateVectorContractCanonicalizeMatmulToMMT(patterns); 1237 patterns.add<CombineTransferReadOpTranspose>(patterns.getContext()); 1238 } 1239 1240 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter, 1241 Operation *rootOp) { 1242 SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false); 1243 llvm::DenseMap<Value, Value> valueMapping; 1244 1245 auto globalRes = LogicalResult::success(); 1246 for (Operation *op : ops) { 1247 LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n"); 1248 // Apparently callers do not want to early exit on failure here. 
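    // Each op in the topologically sorted slice is converted independently; a
    // failed conversion is recorded in globalRes but does not stop the
    // remaining ops from being visited.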
1249 auto res = LogicalResult::success(); 1250 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) { 1251 res = convertTransferReadOp(rewriter, transferRead, valueMapping); 1252 } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) { 1253 res = convertTransferWriteOp(rewriter, transferWrite, valueMapping); 1254 } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) { 1255 res = convertContractOp(rewriter, contractOp, valueMapping); 1256 } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) { 1257 res = convertConstantOp(rewriter, constantOp, valueMapping); 1258 } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) { 1259 res = convertBroadcastOp(rewriter, broadcastOp, valueMapping); 1260 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) { 1261 res = convertForOp(rewriter, forOp, valueMapping); 1262 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) { 1263 res = convertYieldOp(rewriter, yieldOp, valueMapping); 1264 } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) { 1265 res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping); 1266 } 1267 if (failed(res)) 1268 globalRes = failure(); 1269 } 1270 return globalRes; 1271 } 1272 1273 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, 1274 Operation *rootOp) { 1275 SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true); 1276 llvm::DenseMap<Value, Value> valueMapping; 1277 for (Operation *op : ops) { 1278 if (llvm::TypeSwitch<Operation *, LogicalResult>(op) 1279 .Case([&](vector::TransferReadOp transferReadOp) { 1280 return convertTransferReadToLoads(rewriter, transferReadOp, 1281 valueMapping); 1282 }) 1283 .Case([&](vector::TransferWriteOp transferWriteOp) { 1284 return convertTransferWriteToStores(rewriter, transferWriteOp, 1285 valueMapping); 1286 }) 1287 .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) { 1288 return convertExtractStridedSlice(rewriter, extractStridedSliceOp, 1289 valueMapping); 1290 }) 1291 .Case([&](vector::ContractionOp contractionOp) { 1292 return convertContractOpToMmaSync(rewriter, contractionOp, 1293 valueMapping); 1294 }) 1295 .Case([&](scf::ForOp forOp) { 1296 return convertForOp(rewriter, forOp, valueMapping); 1297 }) 1298 .Case([&](scf::YieldOp yieldOp) { 1299 return convertYieldOp(rewriter, yieldOp, valueMapping); 1300 }) 1301 .Case([&](arith::ConstantOp constOp) { 1302 return convertConstantOpMmaSync(rewriter, constOp, valueMapping); 1303 }) 1304 .Default([&](Operation *op) { 1305 return op->emitError() << "unhandled vector to mma type: " << *op; 1306 }) 1307 .failed()) { 1308 return op->emitOpError() 1309 << "failed to convert op during vector-to-nvgpu conversion"; 1310 } 1311 } 1312 return success(); 1313 } 1314 1315 namespace { 1316 1317 struct ConvertVectorToGPUPass 1318 : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> { 1319 1320 explicit ConvertVectorToGPUPass(bool useNvGpu_) { 1321 useNvGpu.setValue(useNvGpu_); 1322 } 1323 1324 void runOnOperation() override { 1325 RewritePatternSet patterns(&getContext()); 1326 populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue()); 1327 if (failed( 1328 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) 1329 return signalPassFailure(); 1330 1331 IRRewriter rewriter(&getContext()); 1332 if (useNvGpu) { 1333 if (failed( 1334 convertVectorToNVVMCompatibleMMASync(rewriter, getOperation()))) 1335 return signalPassFailure(); 1336 return; 1337 } 1338 (void)convertVectorToMMAOps(rewriter, 
getOperation()); 1339 } 1340 }; 1341 1342 } // namespace 1343 1344 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) { 1345 return std::make_unique<ConvertVectorToGPUPass>(useNvGpu); 1346 } 1347