//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Region.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "vector-to-gpu"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap representing offsets to apply to indices, the function fills
/// `indices` with the original indices plus the offsets. The offsets are
/// applied by taking into account the permutation map of the transfer op. If
/// the `offsetMap` has dimension placeholders, those should be provided in
/// `dimValues`.
template <typename TransferOpType>
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = expr.template dyn_cast<AffineDimExpr>()) {
      Value prevIdx = indices[dim.getPosition()];
      SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
      dims.push_back(prevIdx);
      AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
      continue;
    }
  }
}

// Return true if the contract op can be converted to MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                          bool useNvGpu) {
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
        vector::isParallelIterator(iteratorTypes[1]) &&
        vector::isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
  if (!useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
    return false;

  return true;
}

// Return true if the given map represents a transposed matrix load,
// i.e. (d0, d1, ...) -> (dn-1, dn-2).
static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
  MLIRContext *ctx = permutationMap.getContext();
  // Local OpBuilder is fine here, we just build attributes.
  OpBuilder b(ctx);
  auto nDim = permutationMap.getNumDims();
  AffineExpr zero = b.getAffineConstantExpr(0);
  if (nDim < 2) {
    // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
    AffineExpr dim0 = b.getAffineDimExpr(0);
    return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
  }

  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
  // Support both transposed and transposed+broadcasted cases.
  return permutationMap ==
             AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
         permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
}

// Return the stride for dimension 0 of |type| if it is a memref and has a
// constant stride.
static std::optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = dyn_cast<MemRefType>(type);
  if (!memrefType)
    return std::nullopt;
  // If the memref is 0 or 1D the horizontal stride is 0.
  if (memrefType.getRank() < 2)
    return 0;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
      strides.back() != 1)
    return std::nullopt;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamic)
    return std::nullopt;
  return stride;
}

// Return true if the transfer op can be converted to a MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
                                              bool useNvGpu) {
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;

  // Only allow integer types if the signedness can be inferred.
  if (!useNvGpu && readOp.getVectorType().getElementType().isInteger(8))
    if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
                                 !isa<arith::ExtUIOp>(*readOp->user_begin())))
      return false;

  AffineMap map = readOp.getPermutationMap();

  MLIRContext *ctx = readOp.getContext();
  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  auto broadcastInnerDim =
      AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);

  if (!useNvGpu) {
    bool result = map.isMinorIdentity() || map == broadcastInnerDim ||
                  isTransposeMatrixLoadMap(map);
    return result;
  }

  return true;
}

// Return true if the transfer op can be converted to a MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  // TODO: support 0-d corner case.
  if (writeOp.getTransferRank() == 0)
    return false;

  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.getPermutationMap().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to a MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = dyn_cast<VectorType>(constantOp.getType());
  if (!vecType || vecType.getRank() != 2)
    return false;
  return isa<SplatElementsAttr>(constantOp.getValue());
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getResultVectorType().getRank() == 2;
}

/// Return true if this integer extend op can be folded into a contract op.
template <typename ExtOpTy>
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
  if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
    return false;
  return llvm::all_of(extOp->getUsers(), [](Operation *user) {
    return isa<vector::ContractionOp>(user);
  });
}

static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }

/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `std::nullopt` otherwise.
static std::optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::SubFOp>(op))
    return gpu::MMAElementwiseOp::SUBF;
  if (isa<arith::MaxFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  if (isa<arith::AddIOp>(op))
    return gpu::MMAElementwiseOp::ADDI;
  if (isa<arith::MulIOp>(op))
    return gpu::MMAElementwiseOp::MULI;
  if (isa<arith::SubIOp>(op))
    return gpu::MMAElementwiseOp::SUBI;
  if (isa<arith::DivSIOp>(op))
    return gpu::MMAElementwiseOp::DIVS;
  if (isa<arith::DivUIOp>(op))
    return gpu::MMAElementwiseOp::DIVU;
  if (isa<arith::NegFOp>(op))
    return gpu::MMAElementwiseOp::NEGATEF;
  if (isa<arith::ExtFOp>(op))
    return gpu::MMAElementwiseOp::EXTF;
  return std::nullopt;
}

/// Return true if the op is supported as elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {
  return convertElementwiseOpToMMA(op).has_value();
}

/// Returns true if the extract strided slice op is supported with `mma.sync`
/// path.
static bool
extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return false;

  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
  if (failed(contractOp))
    return false;

  // Handle vector.extract_strided_slice on registers containing
  // matrixB and matrixC operands. vector.extract_strided_slice op
  // is not supported on registers containing matrixA operands.
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getRhs().getType()));
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getAcc().getType()));

  return false;
}

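/// Return true if `op` can be converted to an operation on MMA matrix types
/// for the selected lowering path: the nvgpu mma.sync path when `useNvGpu` is
/// set, the gpu.subgroup_mma (WMMA) path otherwise.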
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
    return useNvGpu &&
           extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
    return fpExtendSupportsMMAMatrixType(fpExtend);
  return elementwiseSupportsMMAMatrixType(op);
}

/// Return an unsorted slice, handling scf.for regions differently than
/// `getSlice`. In scf.for we only want to include, as part of the slice, the
/// elements that are part of the use/def chain.
static SetVector<Operation *>
getSliceContract(Operation *op, BackwardSliceOptions backwardSliceOptions,
                 ForwardSliceOptions forwardSliceOptions) {
  SetVector<Operation *> slice;
  slice.insert(op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = slice[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    // Special case for ForOp: we don't want to include the whole region but
    // only the values using the region arguments.
    // TODO: We should refine this to only care about the region arguments
    // being converted to matrix type.
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
      for (BlockArgument &arg : forOp.getRegionIterArgs())
        getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    ++currentIndex;
  }
  return slice;
}

// Analyze the slice of operations rooted at each contract op to figure out if
// the whole slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(),
                        [](Type t) { return isa<VectorType>(t); });
  };
  BackwardSliceOptions backwardSliceOptions;
  backwardSliceOptions.filter = hasVectorDest;

  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(),
                        [](Type t) { return isa<VectorType>(t); });
  };
  ForwardSliceOptions forwardSliceOptions;
  forwardSliceOptions.filter = hasVectorSrc;

  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type so they cannot be used
    // by all operations.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          return !supportsMMaMatrixType(op, useNvGpu);
        }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}

namespace {
// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be
// converted to MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (!(vector::isParallelIterator(iteratorTypes[0]) &&
          vector::isParallelIterator(iteratorTypes[1]) &&
          vector::isReductionIterator(iteratorTypes[2])))
      return rewriter.notifyMatchFailure(op, "not a gemm contraction");
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    // This is the classical row-major matmul, nothing to do.
    if (maps == infer({{m, k}, {k, n}, {m, n}}))
      return rewriter.notifyMatchFailure(op, "contraction already prepared");
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      // TODO: llvm_unreachable ?
      return rewriter.notifyMatchFailure(op, "unexpected contraction case");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.getIteratorTypes());
    return success();
  }
};

// Fold transpose op into the transfer read op. Nvgpu mma.sync op only supports
// row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
// respectively. We can fold the transpose operation when loading the data from
// Shared Memory to registers.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    // Look through integer extend ops.
    Value source = op.getVector();
    Type resultType = op.getType();
    Operation *extOp;
    if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
        (extOp = source.getDefiningOp<arith::ExtUIOp>())) {
      source = extOp->getOperand(0);
      resultType =
          VectorType::get(cast<VectorType>(resultType).getShape(),
                          cast<VectorType>(source.getType()).getElementType());
    }

    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return rewriter.notifyMatchFailure(op, "no transfer read");

    // TODO: support 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-D transfer read");

    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(op, "not inbounds transfer read");

    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.getPermutationMap());

    auto loc = op.getLoc();
    Value result =
        rewriter
            .create<vector::TransferReadOp>(
                loc, resultType, transferReadOp.getSource(),
                transferReadOp.getIndices(), AffineMapAttr::get(newMap),
                transferReadOp.getPadding(), transferReadOp.getMask(),
                transferReadOp.getInBoundsAttr())
            .getResult();

    // Fuse through the integer extend op.
    if (extOp) {
      if (isa<arith::ExtSIOp>(extOp))
        result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
                     .getResult();
      else
        result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
                     .getResult();
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

// MMA types have different layouts based on how they are used in matmul ops.
// Figure out the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and
// only care about it during lowering to NVVM.
static const char *inferFragType(Operation *op) {
  for (Operation *users : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(users);
    if (!contract)
      continue;
    assert(op->getNumResults() == 1);
    if (contract.getLhs() == op->getResult(0))
      return "AOp";
    if (contract.getRhs() == op->getResult(0))
      return "BOp";
  }
  return "COp";
}

static LogicalResult
convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
                      llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));

  std::optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  AffineMap map = op.getPermutationMap();
  bool isTranspose = isTransposeMatrixLoadMap(map);

  // Handle broadcast by setting the stride to 0.
  if (auto cstExpr =
          map.getResult(isTranspose).dyn_cast<AffineConstantExpr>()) {
    assert(cstExpr.getValue() == 0);
    stride = 0;
  }

  Value mappingResult = op.getResult();
  auto elType = op.getVectorType().getElementType();
  const char *fragType = inferFragType(op);
  if (op->hasOneUse()) {
    auto user = *op->user_begin();
    // Infer the signedness of the mma type from the integer extend.
    bool isSignedExtend = isa<arith::ExtSIOp>(user);
    if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
      elType = IntegerType::get(
          op.getContext(), cast<IntegerType>(elType).getWidth(),
          isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
      mappingResult = user->getResult(0);
      fragType = inferFragType(user);
    }
  }
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.getSource(), op.getIndices(),
      rewriter.getIndexAttr(*stride),
      isTranspose ? rewriter.getUnitAttr() : UnitAttr());
  valueMapping[mappingResult] = load;

  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
  return success();
}

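/// Convert a vector.transfer_write to a gpu.subgroup_mma_store_matrix op,
/// storing the MMA matrix previously mapped for the written vector.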
static LogicalResult
convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(transferWriteSupportsMMAMatrixType(op));
  std::optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end()) {
    LLVM_DEBUG(DBGS() << "no mapping\n");
    return rewriter.notifyMatchFailure(op, "no mapping");
  }

  Value matrix = it->second;
  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.getSource(), op.getIndices(),
      rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
  (void)store;

  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

/// Returns the vector type which represents a matrix fragment.
static VectorType
getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
                             regInfo.elementsPerRegister};
  Type elType = regInfo.registerLLVMType;
  if (auto vecType = dyn_cast<VectorType>(elType))
    elType = vecType.getElementType();
  return VectorType::get(shape, elType);
}

/// Convert a 2D splat ConstantOp to a constant of the distributed mma.sync
/// register vector type.
static LogicalResult
convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
  if (!dense) {
    LLVM_DEBUG(DBGS() << "not a splat\n");
    return rewriter.notifyMatchFailure(op, "not a splat");
  }

  Value result = rewriter.create<arith::ConstantOp>(
      op.getLoc(), vectorType,
      DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
  valueMapping[op.getResult()] = result;
  return success();
}

/// Check if the loaded matrix operand requires a transposed load.
/// Transposed Map Example:
/// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
/// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
/// The code below checks if the output 2D is transposed using a generalized
/// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
/// Returns true if m > n, false otherwise.
static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
  mlir::AffineMap map = op.getPermutationMap();

  if (map.getNumResults() != 2) {
    LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
                         "is not a 2d operand\n");
    return failure();
  }

  // Output 2D matrix dimensions in the order of d0, d1.
  mlir::AffineExpr dM = map.getResult(0);
  mlir::AffineExpr dN = map.getResult(1);

  // Find the position of these expressions in the input.
  auto exprM = dM.dyn_cast<AffineDimExpr>();
  auto exprN = dN.dyn_cast<AffineDimExpr>();

  if (!exprM || !exprN) {
    LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
                         "expressions, so the transpose cannot be "
                         "determined.\n");
    return failure();
  }

  return exprM.getPosition() > exprN.getPosition();
}

/// Build an nvgpu.ldmatrix load for a transfer_read that is ldmatrix
/// compatible, distributing the load over the lanes of the warp.
static LogicalResult
createLdMatrixCompatibleLoads(RewriterBase &rewriter,
                              vector::TransferReadOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);
  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }

  FailureOr<bool> transpose = isTransposed(op);
  if (failed(transpose)) {
    LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
    return rewriter.notifyMatchFailure(
        op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
  }

  FailureOr<nvgpu::LdMatrixParams> params =
      nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);

  if (failed(params)) {
    LLVM_DEBUG(
        DBGS()
        << "failed to convert vector.transfer_read to ldmatrix. "
        << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
    return rewriter.notifyMatchFailure(
        op, "failed to convert vector.transfer_read to ldmatrix; this op "
            "likely should not be converted to a nvgpu.ldmatrix call.");
  }

  // Adjust the load offset.
  auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
  if (failed(offsets)) {
    LLVM_DEBUG(DBGS() << "no offsets\n");
    return rewriter.notifyMatchFailure(op, "no offsets");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  SmallVector<Value, 4> indices;
  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
                                         indices);

  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
      loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  return success();
}

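/// Lower a non-ldmatrix-compatible vector.transfer_read by computing, for
/// each lane, the matrix elements it owns and loading them individually (or
/// with small vector loads when no transpose is required).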
static LogicalResult
createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op.getLoc();
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    return rewriter.notifyMatchFailure(
        op, "Failed to deduce register fragment type during "
            "conversion to distributed non-ldmatrix compatible load");
  }

  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
  SmallVector<Value, 4> elements;

  // This is the individual element type.
  Type loadedElType = regInfo->registerLLVMType;
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  Value fill = rewriter.create<arith::ConstantOp>(
      op.getLoc(), vectorType.getElementType(),
      rewriter.getZeroAttr(vectorType.getElementType()));
  Value result =
      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // If we are not transposing, then we can use vectorized loads. Otherwise, we
  // must load each element individually.
  if (!isTransposeLoad) {
    if (!isa<VectorType>(loadedElType)) {
      loadedElType = VectorType::get({1}, loadedElType);
    }

    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          rewriter, op.getLoc(), *warpMatrixInfo);
      if (failed(coords))
        return rewriter.notifyMatchFailure(op, "no coords");

      Value logicalValueId = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          rewriter, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
                                                 op.getSource(), newIndices);
      result = rewriter.create<vector::InsertOp>(loc, el, result, i);
    }
  } else {
    if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
      loadedElType = vecType.getElementType();
    }
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
           innerIdx++) {

        Value logicalValueId = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIndexType(),
            rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
            rewriter, op.getLoc(), *warpMatrixInfo);
        if (failed(coords))
          return rewriter.notifyMatchFailure(op, "no coords");

        SmallVector<Value, 4> newIndices;
        getXferIndices<vector::TransferReadOp>(
            rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
                                                   op.getSource(), newIndices);
        result = rewriter.create<vector::InsertOp>(
            op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
      }
    }
  }

  valueMapping[op.getResult()] = result;
  return success();
}

/// Return true if this is a shared memory memref type.
static bool isSharedMemory(MemRefType type) {
  auto addressSpace =
      dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
  if (addressSpace &&
      addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace())
    return true;
  return false;
}

/// Converts a `vector.transfer_read` operation directly to either a
/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
/// used when converting to `nvgpu.mma.sync` operations.
static LogicalResult
convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  bool isLdMatrixCompatible =
      isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  // When we are transposing the B operand, ldmatrix will only work if we have
  // at least 8 rows to read and the width to read for the transpose is 128
  // bits.
  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  if (!isLdMatrixCompatible)
    return createNonLdMatrixLoads(rewriter, op, valueMapping);

  return createLdMatrixCompatibleLoads(rewriter, op, valueMapping);
}

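/// Convert a vector.transfer_write of an mma.sync register fragment into
/// per-lane vector.store ops addressed by the lane id.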
static LogicalResult
convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op->getLoc();
  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value matrix = it->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        rewriter, op.getLoc(), *warpMatrixInfo);
    if (failed(coords))
      return rewriter.notifyMatchFailure(op, "no coords");

    Value el =
        rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
    rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
  }

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(cast<IntegerAttr>(attr).getInt());
}

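/// Convert a vector.extract_strided_slice of an mma.sync register fragment
/// into a vector.extract_strided_slice of the thread-owned registers of the
/// producing transfer_read.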
static LogicalResult
convertExtractStridedSlice(RewriterBase &rewriter,
                           vector::ExtractStridedSliceOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(mmaSyncFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");

  // Find the vector.transfer_read whose result vector is being sliced.
  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
    return rewriter.notifyMatchFailure(op, "no transfer read");

  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(ldFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");

  assert(
      (mmaSyncFragmentInfo->elementsPerRegister ==
       ldFragmentInfo->elementsPerRegister) &&
      "Number of elements per register should be same for load and mma.sync");

  // Create vector.extract_strided_slice op for thread-owned fragments.
  std::array<int64_t, 2> strides = {1, 1}; // Strides for extract slice are
                                           // always 1.
  std::array<int64_t, 2> sliceShape = {
      mmaSyncFragmentInfo->numRegistersPerFragment,
      mmaSyncFragmentInfo->elementsPerRegister};
  auto it = valueMapping.find(transferReadOp);
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  auto sourceVector = it->second;

  // Offsets and sizes at warp level of ownership.
  SmallVector<int64_t> offsets;
  populateFromInt64AttrArray(op.getOffsets(), offsets);

  SmallVector<int64_t> sizes;
  populateFromInt64AttrArray(op.getSizes(), sizes);
  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();

  // Compute the offset in vector registers. Note that the mma.sync vector
  // registers are shaped as numberOfFragments x numberOfRegistersPerFragment.
  // The vector registers can only be sliced along numberOfFragments, i.e.,
  // sliceOffset[0].
  std::array<int64_t, 2> sliceOffset = {0, 0};

  if (offsets[0] && offsets[1])
    return op->emitError() << "Slicing fragments in 2D is not supported.";
  if (offsets[0])
    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
  else if (offsets[1])
    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);

  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
      loc, sourceVector, sliceOffset, sliceShape, strides);

  valueMapping[op] = newOp;
  return success();
}

/// Convert a vector.contract to a gpu.subgroup_mma_compute op.
static LogicalResult
convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
      /*b_transpose=*/UnitAttr());
  valueMapping[op.getResult()] = matmul;
  return success();
}

/// Convert a vector.contract to an nvgpu.mma.sync op.
static LogicalResult
convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(constantSupportsMMAMatrixType(op));

  auto splat =
      cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
  auto scalarConstant =
      rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = cast<VectorType>(op.getType());
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, scalarConstant);
  valueMapping[op.getResult()] = matrix;
  return success();
}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
                   llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(broadcastSupportsMMAMatrixType(op));

  const char *fragType = inferFragType(op);
  auto vecType = op.getResultVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, op.getSource());
  valueMapping[op.getResult()] = matrix;
  return success();
}

// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
                                               scf::ForOp loop,
                                               ValueRange newIterOperands) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loop);

  // Create a new loop before the existing one, with the extra operands.
  rewriter.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
      operands);
  newLoop.getBody()->erase();

  newLoop.getLoopBody().getBlocks().splice(
      newLoop.getLoopBody().getBlocks().begin(),
      loop.getLoopBody().getBlocks());
  for (Value operand : newIterOperands)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));

  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
  LLVM_DEBUG(DBGS() << "erase: " << loop);

  rewriter.eraseOp(loop);
  return newLoop;
}

static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end()) {
      LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
      continue;
    }
    argMapping.push_back(std::make_pair(
        operand.index(), op.getNumIterOperands() + newOperands.size()));
    newOperands.push_back(it->second);
  }

  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }

  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
  return success();
}

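/// Rewrite the scf.yield of a converted loop: forward the loop's original
/// iter operands in the old positions (so the now-dead chain is easy to
/// remove) and additionally yield the values mapped to MMA matrices for the
/// newly added iter_args.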
static LogicalResult
convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
               llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static LogicalResult
convertElementwiseOp(RewriterBase &rewriter, Operation *op,
                     gpu::MMAElementwiseOp opType,
                     llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands()) {
    auto it = valueMapping.find(operand);
    if (it == valueMapping.end())
      return rewriter.notifyMatchFailure(op, "no mapping");
    matrixOperands.push_back(it->second);
  }
  auto resultType = matrixOperands[0].getType().cast<gpu::MMAMatrixType>();
  if (opType == gpu::MMAElementwiseOp::EXTF) {
    // The floating point extension case has a different result type.
    auto vectorType = op->getResultTypes()[0].cast<VectorType>();
    resultType = gpu::MMAMatrixType::get(resultType.getShape(),
                                         vectorType.getElementType(),
                                         resultType.getOperand());
  }

  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), resultType, matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
  return success();
}

void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
                                              bool useNvGpu) {
  if (!useNvGpu) {
    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
        patterns.getContext());
    return;
  }
  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
}

LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
                                          Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
  llvm::DenseMap<Value, Value> valueMapping;

  auto globalRes = LogicalResult::success();
  for (Operation *op : ops) {
    LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
    // Apparently callers do not want to early exit on failure here.
    auto res = LogicalResult::success();
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      res = convertTransferReadOp(rewriter, transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      res = convertContractOp(rewriter, contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      res = convertConstantOp(rewriter, constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      res = convertForOp(rewriter, forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      res = convertYieldOp(rewriter, yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
    }
    if (failed(res))
      globalRes = failure();
  }
  return globalRes;
}

LogicalResult
mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
                                           Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
            .Case([&](vector::TransferReadOp transferReadOp) {
              return convertTransferReadToLoads(rewriter, transferReadOp,
                                                valueMapping);
            })
            .Case([&](vector::TransferWriteOp transferWriteOp) {
              return convertTransferWriteToStores(rewriter, transferWriteOp,
                                                  valueMapping);
            })
            .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
              return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
                                                valueMapping);
            })
            .Case([&](vector::ContractionOp contractionOp) {
              return convertContractOpToMmaSync(rewriter, contractionOp,
                                                valueMapping);
            })
            .Case([&](scf::ForOp forOp) {
              return convertForOp(rewriter, forOp, valueMapping);
            })
            .Case([&](scf::YieldOp yieldOp) {
              return convertYieldOp(rewriter, yieldOp, valueMapping);
            })
            .Case([&](arith::ConstantOp constOp) {
              return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
            })
            .Default([&](Operation *op) {
              return op->emitError() << "unhandled vector to mma type: " << *op;
            })
            .failed()) {
      return op->emitError() << "Failed to convert op " << *op;
    }
  }
  return success();
}

namespace {

struct ConvertVectorToGPUPass
    : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    IRRewriter rewriter(&getContext());
    if (useNvGpu.getValue()) {
      if (failed(
              convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
        return signalPassFailure();
    }
    (void)convertVectorToMMAOps(rewriter, getOperation());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}