1edd9515bSthomasraoux //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===// 2edd9515bSthomasraoux // 3edd9515bSthomasraoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4edd9515bSthomasraoux // See https://llvm.org/LICENSE.txt for license information. 5edd9515bSthomasraoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6edd9515bSthomasraoux // 7edd9515bSthomasraoux //===----------------------------------------------------------------------===// 8edd9515bSthomasraoux // 9edd9515bSthomasraoux // This file implements lowering of vector operations to GPU dialect ops. 10edd9515bSthomasraoux // 11edd9515bSthomasraoux //===----------------------------------------------------------------------===// 12edd9515bSthomasraoux 13edd9515bSthomasraoux #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" 14edd9515bSthomasraoux 15ea2ed80eSChristopher Bate #include <type_traits> 16ea2ed80eSChristopher Bate 17edd9515bSthomasraoux #include "mlir/Analysis/SliceAnalysis.h" 18b00e0c16SChristian Ulmann #include "mlir/Analysis/TopologicalSortUtils.h" 19ea2ed80eSChristopher Bate #include "mlir/Dialect/Affine/IR/AffineOps.h" 20abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 21d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h" 2266f878ceSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 2351b925dfSChristopher Bate #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" 24ea2ed80eSChristopher Bate #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h" 258b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 26edd9515bSthomasraoux #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 2799ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 28fb7ef637SJakub Kuderski #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 2999ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 30edd9515bSthomasraoux #include "mlir/IR/Builders.h" 315ef7ceaeSNicolas 
Vasilache #include "mlir/IR/BuiltinOps.h" 325ef7ceaeSNicolas Vasilache #include "mlir/IR/Region.h" 33edd9515bSthomasraoux #include "mlir/Pass/Pass.h" 34edd9515bSthomasraoux #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 35edd9515bSthomasraoux #include "mlir/Transforms/Passes.h" 365ef7ceaeSNicolas Vasilache #include "llvm/ADT/STLExtras.h" 371ca772edSChristopher Bate #include "llvm/ADT/TypeSwitch.h" 38edd9515bSthomasraoux 395ef7ceaeSNicolas Vasilache #define DEBUG_TYPE "vector-to-gpu" 405ef7ceaeSNicolas Vasilache #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 415ef7ceaeSNicolas Vasilache #define DBGSNL() (llvm::dbgs() << "\n") 425ef7ceaeSNicolas Vasilache 4367d0d7acSMichele Scuttari namespace mlir { 4467d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTVECTORTOGPU 4567d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 4667d0d7acSMichele Scuttari } // namespace mlir 4767d0d7acSMichele Scuttari 48edd9515bSthomasraoux using namespace mlir; 49edd9515bSthomasraoux 501ca772edSChristopher Bate /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an 511ca772edSChristopher Bate /// AffineMap representing offsets to apply to indices, the function fills 521ca772edSChristopher Bate /// `indices` with the original indices plus the offsets. The offsets are 531ca772edSChristopher Bate /// applied by taking into account the permutation map of the transfer op. If 541ca772edSChristopher Bate /// the `offsetMap` has dimension placeholders, those should be provided in 551ca772edSChristopher Bate /// `dimValues`. 
561ca772edSChristopher Bate template <typename TransferOpType> 575ef7ceaeSNicolas Vasilache static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, 581ca772edSChristopher Bate AffineMap offsetMap, ArrayRef<Value> dimValues, 591ca772edSChristopher Bate SmallVector<Value, 4> &indices) { 601ca772edSChristopher Bate indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end()); 611ca772edSChristopher Bate Location loc = xferOp.getLoc(); 621ca772edSChristopher Bate unsigned offsetsIdx = 0; 631ca772edSChristopher Bate for (auto expr : xferOp.getPermutationMap().getResults()) { 641609f1c2Slong.chen if (auto dim = dyn_cast<AffineDimExpr>(expr)) { 651ca772edSChristopher Bate Value prevIdx = indices[dim.getPosition()]; 665262865aSKazu Hirata SmallVector<OpFoldResult, 3> dims(dimValues); 671ca772edSChristopher Bate dims.push_back(prevIdx); 685ef7ceaeSNicolas Vasilache AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims()); 694c48f016SMatthias Springer indices[dim.getPosition()] = affine::makeComposedAffineApply( 705ef7ceaeSNicolas Vasilache rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims); 711ca772edSChristopher Bate continue; 721ca772edSChristopher Bate } 731ca772edSChristopher Bate } 741ca772edSChristopher Bate } 751ca772edSChristopher Bate 76edd9515bSthomasraoux // Return true if the contract op can be convert to MMA matmul. 
771ca772edSChristopher Bate static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, 781ca772edSChristopher Bate bool useNvGpu) { 79edd9515bSthomasraoux using MapList = ArrayRef<ArrayRef<AffineExpr>>; 80fe8a62c4SUday Bondhugula auto infer = [&](MapList m) { 81fe8a62c4SUday Bondhugula return AffineMap::inferFromExprList(m, contract.getContext()); 82fe8a62c4SUday Bondhugula }; 83edd9515bSthomasraoux AffineExpr m, n, k; 84edd9515bSthomasraoux bindDims(contract.getContext(), m, n, k); 857c38fd60SJacques Pienaar auto iteratorTypes = contract.getIteratorTypes().getValue(); 864758e916SOleg Shyshkov if (!(vector::isParallelIterator(iteratorTypes[0]) && 874758e916SOleg Shyshkov vector::isParallelIterator(iteratorTypes[1]) && 884758e916SOleg Shyshkov vector::isReductionIterator(iteratorTypes[2]))) 89edd9515bSthomasraoux return false; 90edd9515bSthomasraoux 91edd9515bSthomasraoux // The contract needs to represent a matmul to be able to convert to 92edd9515bSthomasraoux // MMAMatrix matmul. 931ca772edSChristopher Bate if (!useNvGpu && 94d2c0572bSJacques Pienaar contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}})) 951ca772edSChristopher Bate return false; 96d2c0572bSJacques Pienaar if (useNvGpu && 97d2c0572bSJacques Pienaar contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}})) 98edd9515bSthomasraoux return false; 99edd9515bSthomasraoux 100edd9515bSthomasraoux return true; 101edd9515bSthomasraoux } 102edd9515bSthomasraoux 103c0321edcSQuinn Dawkins // Return true if the given map represents a transposed matrix load, 104c0321edcSQuinn Dawkins // i.e. (d0, d1, ...) -> (dn-1, dn-2). 1055ef7ceaeSNicolas Vasilache static bool isTransposeMatrixLoadMap(AffineMap permutationMap) { 1065ef7ceaeSNicolas Vasilache MLIRContext *ctx = permutationMap.getContext(); 1075ef7ceaeSNicolas Vasilache // Local OpBuilder is fine here, we just build attributes. 
1085ef7ceaeSNicolas Vasilache OpBuilder b(ctx); 109c0321edcSQuinn Dawkins auto nDim = permutationMap.getNumDims(); 110f1db4aecSLei Zhang AffineExpr zero = b.getAffineConstantExpr(0); 111dbddd4f6SLei Zhang if (nDim < 2) { 112dbddd4f6SLei Zhang // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>. 113dbddd4f6SLei Zhang AffineExpr dim0 = b.getAffineDimExpr(0); 114f1db4aecSLei Zhang return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx); 115dbddd4f6SLei Zhang } 116c0321edcSQuinn Dawkins 117c0321edcSQuinn Dawkins AffineExpr innerDim = b.getAffineDimExpr(nDim - 1); 118c0321edcSQuinn Dawkins AffineExpr outerDim = b.getAffineDimExpr(nDim - 2); 119f1db4aecSLei Zhang // Support both transposed and transposed+broadcasted cases. 120f1db4aecSLei Zhang return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) || 121f1db4aecSLei Zhang permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx); 122c0321edcSQuinn Dawkins } 123c0321edcSQuinn Dawkins 124cafb6284SChristopher Bate // Return the stide for the second-to-last dimension of |type| if it is a memref 125cafb6284SChristopher Bate // and has a constant stride. 126cafb6284SChristopher Bate static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) { 1275550c821STres Popp auto memrefType = dyn_cast<MemRefType>(type); 128edd9515bSthomasraoux if (!memrefType) 129edd9515bSthomasraoux return false; 130a57ccad5SThomas Raoux // If the memref is 0 or 1D the horizontal stride is 0. 
131a57ccad5SThomas Raoux if (memrefType.getRank() < 2) 132a57ccad5SThomas Raoux return 0; 133edd9515bSthomasraoux int64_t offset = 0; 134edd9515bSthomasraoux SmallVector<int64_t, 2> strides; 135*6aaa8f25SMatthias Springer if (failed(memrefType.getStridesAndOffset(strides, offset)) || 136d77f4836SThomas Raoux strides.back() != 1) 1371a36588eSKazu Hirata return std::nullopt; 138a57ccad5SThomas Raoux int64_t stride = strides[strides.size() - 2]; 139399638f9SAliia Khasanova if (stride == ShapedType::kDynamic) 1401a36588eSKazu Hirata return std::nullopt; 141a57ccad5SThomas Raoux return stride; 142edd9515bSthomasraoux } 143edd9515bSthomasraoux 144edd9515bSthomasraoux // Return true if the transfer op can be converted to a MMA matrix load. 145cafb6284SChristopher Bate static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) { 1467c38fd60SJacques Pienaar if (readOp.getMask() || readOp.hasOutOfBoundsDim() || 147edd9515bSthomasraoux readOp.getVectorType().getRank() != 2) 148edd9515bSthomasraoux return false; 149cafb6284SChristopher Bate if (!getStaticallyKnownRowStride(readOp.getShapedType())) 150edd9515bSthomasraoux return false; 151985f7ff6SQuinn Dawkins 152985f7ff6SQuinn Dawkins // Only allow integer types if the signedness can be inferred. 
153cafb6284SChristopher Bate if (readOp.getVectorType().getElementType().isInteger(8)) 1545205c712SQuinn Dawkins if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) && 1555205c712SQuinn Dawkins !isa<arith::ExtUIOp>(*readOp->user_begin()))) 156985f7ff6SQuinn Dawkins return false; 157985f7ff6SQuinn Dawkins 1587c38fd60SJacques Pienaar AffineMap map = readOp.getPermutationMap(); 1595ef7ceaeSNicolas Vasilache MLIRContext *ctx = readOp.getContext(); 1605ef7ceaeSNicolas Vasilache AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx); 1615ef7ceaeSNicolas Vasilache AffineExpr zero = getAffineConstantExpr(0, ctx); 1625ef7ceaeSNicolas Vasilache auto broadcastInnerDim = 1635ef7ceaeSNicolas Vasilache AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx); 164cafb6284SChristopher Bate return map.isMinorIdentity() || map == broadcastInnerDim || 1655ef7ceaeSNicolas Vasilache isTransposeMatrixLoadMap(map); 166edd9515bSthomasraoux } 167edd9515bSthomasraoux 168edd9515bSthomasraoux // Return true if the transfer op can be converted to a MMA matrix store. 169edd9515bSthomasraoux static bool 170edd9515bSthomasraoux transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) { 171c537a943SNicolas Vasilache // TODO: support 0-d corner case. 172c537a943SNicolas Vasilache if (writeOp.getTransferRank() == 0) 173c537a943SNicolas Vasilache return false; 174c537a943SNicolas Vasilache 1757c38fd60SJacques Pienaar if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() || 176edd9515bSthomasraoux writeOp.getVectorType().getRank() != 2) 177edd9515bSthomasraoux return false; 178cafb6284SChristopher Bate if (!getStaticallyKnownRowStride(writeOp.getShapedType())) 179edd9515bSthomasraoux return false; 180edd9515bSthomasraoux // TODO: Support transpose once it is added to GPU dialect ops. 
1817c38fd60SJacques Pienaar if (!writeOp.getPermutationMap().isMinorIdentity()) 182edd9515bSthomasraoux return false; 183edd9515bSthomasraoux return true; 184edd9515bSthomasraoux } 185edd9515bSthomasraoux 1866413226dSthomasraoux /// Return true if the constant is a splat to a 2D vector so that it can be 1876413226dSthomasraoux /// converted to a MMA constant matrix op. 188a54f4eaeSMogball static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) { 1895550c821STres Popp auto vecType = dyn_cast<VectorType>(constantOp.getType()); 1906413226dSthomasraoux if (!vecType || vecType.getRank() != 2) 1916413226dSthomasraoux return false; 1925550c821STres Popp return isa<SplatElementsAttr>(constantOp.getValue()); 1936413226dSthomasraoux } 1946413226dSthomasraoux 19543928419Sthomasraoux /// Return true if this is a broadcast from scalar to a 2D vector. 19643928419Sthomasraoux static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) { 197a1aad28dSLei Zhang return broadcastOp.getResultVectorType().getRank() == 2; 198985f7ff6SQuinn Dawkins } 199985f7ff6SQuinn Dawkins 2005205c712SQuinn Dawkins /// Return true if this integer extend op can be folded into a contract op. 2015205c712SQuinn Dawkins template <typename ExtOpTy> 2025205c712SQuinn Dawkins static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) { 203927559d2SLongsheng Mou auto transferReadOp = 204927559d2SLongsheng Mou extOp.getOperand().template getDefiningOp<vector::TransferReadOp>(); 205927559d2SLongsheng Mou if (!transferReadOp) 206985f7ff6SQuinn Dawkins return false; 207971b8525SJakub Kuderski return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>); 20843928419Sthomasraoux } 20943928419Sthomasraoux 210a0119437SLei Zhang static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; } 211a0119437SLei Zhang 2127fbb0678Sthomasraoux /// Return the MMA elementwise enum associated with `op` if it is supported. 
213192d9dd7SKazu Hirata /// Return `std::nullopt` otherwise. 214d32ec523SRamkumar Ramachandra static std::optional<gpu::MMAElementwiseOp> 2157fbb0678Sthomasraoux convertElementwiseOpToMMA(Operation *op) { 2167fbb0678Sthomasraoux if (isa<arith::AddFOp>(op)) 2177fbb0678Sthomasraoux return gpu::MMAElementwiseOp::ADDF; 2187fbb0678Sthomasraoux if (isa<arith::MulFOp>(op)) 2197fbb0678Sthomasraoux return gpu::MMAElementwiseOp::MULF; 22050882b4dSLei Zhang if (isa<arith::SubFOp>(op)) 22150882b4dSLei Zhang return gpu::MMAElementwiseOp::SUBF; 2228a6e54c9SDaniil Dudkin if (isa<arith::MaximumFOp>(op)) 2237fbb0678Sthomasraoux return gpu::MMAElementwiseOp::MAXF; 2248a6e54c9SDaniil Dudkin if (isa<arith::MinimumFOp>(op)) 2257fbb0678Sthomasraoux return gpu::MMAElementwiseOp::MINF; 226e7969240SThomas Raoux if (isa<arith::DivFOp>(op)) 227e7969240SThomas Raoux return gpu::MMAElementwiseOp::DIVF; 22850882b4dSLei Zhang if (isa<arith::AddIOp>(op)) 22950882b4dSLei Zhang return gpu::MMAElementwiseOp::ADDI; 23050882b4dSLei Zhang if (isa<arith::MulIOp>(op)) 23150882b4dSLei Zhang return gpu::MMAElementwiseOp::MULI; 23250882b4dSLei Zhang if (isa<arith::SubIOp>(op)) 23350882b4dSLei Zhang return gpu::MMAElementwiseOp::SUBI; 23450882b4dSLei Zhang if (isa<arith::DivSIOp>(op)) 23550882b4dSLei Zhang return gpu::MMAElementwiseOp::DIVS; 23650882b4dSLei Zhang if (isa<arith::DivUIOp>(op)) 23750882b4dSLei Zhang return gpu::MMAElementwiseOp::DIVU; 23850882b4dSLei Zhang if (isa<arith::NegFOp>(op)) 23950882b4dSLei Zhang return gpu::MMAElementwiseOp::NEGATEF; 240a0119437SLei Zhang if (isa<arith::ExtFOp>(op)) 241a0119437SLei Zhang return gpu::MMAElementwiseOp::EXTF; 2421a36588eSKazu Hirata return std::nullopt; 2437fbb0678Sthomasraoux } 2447fbb0678Sthomasraoux 2457fbb0678Sthomasraoux /// Return true if the op is supported as elementwise op on MMAMatrix type. 
2467fbb0678Sthomasraoux static bool elementwiseSupportsMMAMatrixType(Operation *op) { 247064a08cdSKazu Hirata return convertElementwiseOpToMMA(op).has_value(); 2487fbb0678Sthomasraoux } 2497fbb0678Sthomasraoux 250114ba722SManish Gupta /// Returns true if the extract strided slice op is supported with `mma.sync` 251114ba722SManish Gupta /// path. 252114ba722SManish Gupta static bool 253114ba722SManish Gupta extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) { 254114ba722SManish Gupta 255114ba722SManish Gupta FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 256114ba722SManish Gupta nvgpu::getWarpMatrixInfo(op); 257114ba722SManish Gupta if (failed(warpMatrixInfo)) 258114ba722SManish Gupta return false; 259114ba722SManish Gupta 260114ba722SManish Gupta FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op); 261114ba722SManish Gupta if (failed(contractOp)) 262114ba722SManish Gupta return false; 263114ba722SManish Gupta 2643af64383SNicolas Vasilache // Handle vector.extract_strided_slice on registers containing 2653af64383SNicolas Vasilache // matrixB and matrixC operands. vector.extract_strided_slice op 2663af64383SNicolas Vasilache // is not supported on registers containing matrixA operands. 
267114ba722SManish Gupta if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B) 2685550c821STres Popp return (cast<VectorType>(op->getResult(0).getType()) == 2695550c821STres Popp cast<VectorType>((*contractOp).getRhs().getType())); 2706a7a1188SMehdi Amini if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C) 2715550c821STres Popp return (cast<VectorType>(op->getResult(0).getType()) == 2725550c821STres Popp cast<VectorType>((*contractOp).getAcc().getType())); 273114ba722SManish Gupta 274114ba722SManish Gupta return false; 275114ba722SManish Gupta } 276114ba722SManish Gupta 2771ca772edSChristopher Bate static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) { 2781a865592Sthomasraoux if (isa<scf::ForOp, scf::YieldOp>(op)) 2791a865592Sthomasraoux return true; 280edd9515bSthomasraoux if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) 281cafb6284SChristopher Bate return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead) 282cafb6284SChristopher Bate : transferReadSupportsMMAMatrixType(transferRead); 283edd9515bSthomasraoux if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) 284cafb6284SChristopher Bate return useNvGpu ? 
nvgpu::canLowerToWarpMatrixOperation(transferWrite) 285cafb6284SChristopher Bate : transferWriteSupportsMMAMatrixType(transferWrite); 286114ba722SManish Gupta if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op)) 287114ba722SManish Gupta return useNvGpu && 288114ba722SManish Gupta extractStridedSliceSupportsMMAMatrixType(extractStridedSlice); 289edd9515bSthomasraoux if (auto contract = dyn_cast<vector::ContractionOp>(op)) 2901ca772edSChristopher Bate return contractSupportsMMAMatrixType(contract, useNvGpu); 291a54f4eaeSMogball if (auto constant = dyn_cast<arith::ConstantOp>(op)) 2926413226dSthomasraoux return constantSupportsMMAMatrixType(constant); 2933af64383SNicolas Vasilache if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) 2943af64383SNicolas Vasilache return broadcastSupportsMMAMatrixType(broadcast); 2955205c712SQuinn Dawkins if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op)) 2965205c712SQuinn Dawkins return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend); 2975205c712SQuinn Dawkins if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op)) 2985205c712SQuinn Dawkins return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend); 299a0119437SLei Zhang if (auto fpExtend = dyn_cast<arith::ExtFOp>(op)) 300a0119437SLei Zhang return fpExtendSupportsMMAMatrixType(fpExtend); 3017fbb0678Sthomasraoux return elementwiseSupportsMMAMatrixType(op); 302edd9515bSthomasraoux } 303edd9515bSthomasraoux 304e7969240SThomas Raoux /// Return an unsorted slice handling scf.for region differently than 305e7969240SThomas Raoux /// `getSlice`. In scf.for we only want to include as part of the slice elements 306e7969240SThomas Raoux /// that are part of the use/def chain. 
307641b12e9SMahesh Ravishankar static SetVector<Operation *> 308847d8457SMehdi Amini getSliceContract(Operation *op, 309847d8457SMehdi Amini const BackwardSliceOptions &backwardSliceOptions, 310847d8457SMehdi Amini const ForwardSliceOptions &forwardSliceOptions) { 311e7969240SThomas Raoux SetVector<Operation *> slice; 312e7969240SThomas Raoux slice.insert(op); 313e7969240SThomas Raoux unsigned currentIndex = 0; 314e7969240SThomas Raoux SetVector<Operation *> backwardSlice; 315e7969240SThomas Raoux SetVector<Operation *> forwardSlice; 316e7969240SThomas Raoux while (currentIndex != slice.size()) { 317e7969240SThomas Raoux auto *currentOp = (slice)[currentIndex]; 3183af64383SNicolas Vasilache // Compute and insert the backwardSlice starting from currentOp. 319e7969240SThomas Raoux backwardSlice.clear(); 320641b12e9SMahesh Ravishankar getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); 321e7969240SThomas Raoux slice.insert(backwardSlice.begin(), backwardSlice.end()); 322e7969240SThomas Raoux 3233af64383SNicolas Vasilache // Compute and insert the forwardSlice starting from currentOp. 324e7969240SThomas Raoux forwardSlice.clear(); 3253af64383SNicolas Vasilache // Special case for ForOp, we don't want to include the whole region but 3263af64383SNicolas Vasilache // only the value using the region arguments. 3273af64383SNicolas Vasilache // TODO: We should refine this to only care about the region arguments being 3283af64383SNicolas Vasilache // converted to matrix type. 
329e7969240SThomas Raoux if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) { 330e7969240SThomas Raoux for (Value forOpResult : forOp.getResults()) 331641b12e9SMahesh Ravishankar getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions); 332e7969240SThomas Raoux for (BlockArgument &arg : forOp.getRegionIterArgs()) 333641b12e9SMahesh Ravishankar getForwardSlice(arg, &forwardSlice, forwardSliceOptions); 334e7969240SThomas Raoux } else { 335641b12e9SMahesh Ravishankar getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions); 336e7969240SThomas Raoux } 337e7969240SThomas Raoux slice.insert(forwardSlice.begin(), forwardSlice.end()); 338e7969240SThomas Raoux ++currentIndex; 339e7969240SThomas Raoux } 340e7969240SThomas Raoux return slice; 341e7969240SThomas Raoux } 342e7969240SThomas Raoux 343edd9515bSthomasraoux // Analyze slice of operations based on convert op to figure out if the whole 344edd9515bSthomasraoux // slice can be converted to MMA operations. 3451ca772edSChristopher Bate static SetVector<Operation *> getOpToConvert(mlir::Operation *op, 3461ca772edSChristopher Bate bool useNvGpu) { 347edd9515bSthomasraoux auto hasVectorDest = [](Operation *op) { 348971b8525SJakub Kuderski return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>); 34943928419Sthomasraoux }; 350641b12e9SMahesh Ravishankar BackwardSliceOptions backwardSliceOptions; 351641b12e9SMahesh Ravishankar backwardSliceOptions.filter = hasVectorDest; 352641b12e9SMahesh Ravishankar 35343928419Sthomasraoux auto hasVectorSrc = [](Operation *op) { 354971b8525SJakub Kuderski return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>); 355edd9515bSthomasraoux }; 356641b12e9SMahesh Ravishankar ForwardSliceOptions forwardSliceOptions; 357641b12e9SMahesh Ravishankar forwardSliceOptions.filter = hasVectorSrc; 358641b12e9SMahesh Ravishankar 359edd9515bSthomasraoux SetVector<Operation *> opToConvert; 360edd9515bSthomasraoux op->walk([&](vector::ContractionOp contract) { 
361edd9515bSthomasraoux if (opToConvert.contains(contract.getOperation())) 362edd9515bSthomasraoux return; 363edd9515bSthomasraoux SetVector<Operation *> dependentOps = 364641b12e9SMahesh Ravishankar getSliceContract(contract, backwardSliceOptions, forwardSliceOptions); 3653af64383SNicolas Vasilache // If any instruction cannot use MMA matrix type drop the whole 3663af64383SNicolas Vasilache // chain. MMA matrix are stored in an opaque type so they cannot be used 3673af64383SNicolas Vasilache // by all operations. 3681ca772edSChristopher Bate if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) { 369cafb6284SChristopher Bate if (!supportsMMaMatrixType(op, useNvGpu)) { 370cafb6284SChristopher Bate LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n"); 371cafb6284SChristopher Bate return true; 372cafb6284SChristopher Bate } 373cafb6284SChristopher Bate return false; 3741ca772edSChristopher Bate })) 375edd9515bSthomasraoux return; 376cafb6284SChristopher Bate 377edd9515bSthomasraoux opToConvert.insert(dependentOps.begin(), dependentOps.end()); 378edd9515bSthomasraoux }); 3793af64383SNicolas Vasilache // Sort the operations so that we can convert them in topological order. 380e7969240SThomas Raoux return topologicalSort(opToConvert); 381edd9515bSthomasraoux } 382edd9515bSthomasraoux 383edd9515bSthomasraoux namespace { 384edd9515bSthomasraoux // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted 385edd9515bSthomasraoux // to MMA matmul. 
386edd9515bSthomasraoux struct PrepareContractToGPUMMA 387edd9515bSthomasraoux : public OpRewritePattern<vector::ContractionOp> { 388edd9515bSthomasraoux using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; 389edd9515bSthomasraoux 390edd9515bSthomasraoux LogicalResult matchAndRewrite(vector::ContractionOp op, 391edd9515bSthomasraoux PatternRewriter &rewriter) const override { 392edd9515bSthomasraoux Location loc = op.getLoc(); 3937c38fd60SJacques Pienaar Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc(); 394edd9515bSthomasraoux 395edd9515bSthomasraoux // Set up the parallel/reduction structure in right form. 396edd9515bSthomasraoux using MapList = ArrayRef<ArrayRef<AffineExpr>>; 397fe8a62c4SUday Bondhugula auto infer = [&](MapList m) { 398fe8a62c4SUday Bondhugula return AffineMap::inferFromExprList(m, op.getContext()); 399fe8a62c4SUday Bondhugula }; 400edd9515bSthomasraoux AffineExpr m, n, k; 401edd9515bSthomasraoux bindDims(rewriter.getContext(), m, n, k); 402edd9515bSthomasraoux static constexpr std::array<int64_t, 2> perm = {1, 0}; 4037c38fd60SJacques Pienaar auto iteratorTypes = op.getIteratorTypes().getValue(); 404d2c0572bSJacques Pienaar SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray(); 4054758e916SOleg Shyshkov if (!(vector::isParallelIterator(iteratorTypes[0]) && 4064758e916SOleg Shyshkov vector::isParallelIterator(iteratorTypes[1]) && 4074758e916SOleg Shyshkov vector::isReductionIterator(iteratorTypes[2]))) 4085ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "not a gemm contraction"); 409edd9515bSthomasraoux // 410edd9515bSthomasraoux // Two outer parallel, one inner reduction (matmat flavor). 411edd9515bSthomasraoux // 412edd9515bSthomasraoux // This is the classical row-major matmul, nothing to do. 
4135ef7ceaeSNicolas Vasilache if (maps == infer({{m, k}, {k, n}, {m, n}})) 4145ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "contraction already prepared"); 415edd9515bSthomasraoux if (maps == infer({{m, k}, {n, k}, {m, n}})) { 416edd9515bSthomasraoux rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 417edd9515bSthomasraoux } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { 418edd9515bSthomasraoux lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 419edd9515bSthomasraoux } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { 420edd9515bSthomasraoux rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 421edd9515bSthomasraoux lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 422edd9515bSthomasraoux } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { 423edd9515bSthomasraoux std::swap(rhs, lhs); 424edd9515bSthomasraoux rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 425edd9515bSthomasraoux lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 426edd9515bSthomasraoux } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { 427edd9515bSthomasraoux std::swap(rhs, lhs); 428edd9515bSthomasraoux rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 429edd9515bSthomasraoux } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { 430edd9515bSthomasraoux std::swap(lhs, rhs); 431edd9515bSthomasraoux lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 432edd9515bSthomasraoux } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { 433edd9515bSthomasraoux std::swap(lhs, rhs); 434edd9515bSthomasraoux } else { 4355ef7ceaeSNicolas Vasilache // TODO: llvm_unreachable ? 
4365ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "unexpected contraction case"); 437edd9515bSthomasraoux } 438edd9515bSthomasraoux rewriter.replaceOpWithNewOp<vector::ContractionOp>( 439edd9515bSthomasraoux op, lhs, rhs, res, 440edd9515bSthomasraoux rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})), 4417c38fd60SJacques Pienaar op.getIteratorTypes()); 442edd9515bSthomasraoux return success(); 443edd9515bSthomasraoux } 444edd9515bSthomasraoux }; 445edd9515bSthomasraoux 446baa5beecStyb0807 // Fold transpose op into the transfer read op. NVGPU mma.sync op only supports 447114ba722SManish Gupta // row-, column-, and row-major layout for matrixA, matrixB, and matrixC, 448114ba722SManish Gupta // respectively. We can fold the transpose operation when loading the data from 449114ba722SManish Gupta // Shared Memory to registers. 450edd9515bSthomasraoux struct CombineTransferReadOpTranspose final 451edd9515bSthomasraoux : public OpRewritePattern<vector::TransposeOp> { 452edd9515bSthomasraoux using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 453edd9515bSthomasraoux 454edd9515bSthomasraoux LogicalResult matchAndRewrite(vector::TransposeOp op, 455edd9515bSthomasraoux PatternRewriter &rewriter) const override { 456985f7ff6SQuinn Dawkins // Look through integer extend ops. 
457985f7ff6SQuinn Dawkins Value source = op.getVector(); 4583cf7f224SThomas Raoux Type resultType = op.getType(); 4595205c712SQuinn Dawkins Operation *extOp; 4605205c712SQuinn Dawkins if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) || 46142bba97fSharsh-nod (extOp = source.getDefiningOp<arith::ExtUIOp>()) || 46242bba97fSharsh-nod (extOp = source.getDefiningOp<arith::ExtFOp>())) { 4635205c712SQuinn Dawkins source = extOp->getOperand(0); 464985f7ff6SQuinn Dawkins resultType = 4655550c821STres Popp VectorType::get(cast<VectorType>(resultType).getShape(), 4665550c821STres Popp cast<VectorType>(source.getType()).getElementType()); 467985f7ff6SQuinn Dawkins } 468985f7ff6SQuinn Dawkins 469985f7ff6SQuinn Dawkins auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>(); 470edd9515bSthomasraoux if (!transferReadOp) 4715ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no transfer read"); 472c537a943SNicolas Vasilache 473c537a943SNicolas Vasilache // TODO: support 0-d corner case. 
474c537a943SNicolas Vasilache if (transferReadOp.getTransferRank() == 0) 4755ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "0-D transfer read"); 476c537a943SNicolas Vasilache 4777c38fd60SJacques Pienaar if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim()) 4785ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "not inbounds transfer read"); 4795ef7ceaeSNicolas Vasilache 480edd9515bSthomasraoux AffineMap permutationMap = 48132c3decbSMatthias Springer AffineMap::getPermutationMap(op.getPermutation(), op.getContext()); 4827c38fd60SJacques Pienaar AffineMap newMap = 4837c38fd60SJacques Pienaar permutationMap.compose(transferReadOp.getPermutationMap()); 484985f7ff6SQuinn Dawkins 485985f7ff6SQuinn Dawkins auto loc = op.getLoc(); 486985f7ff6SQuinn Dawkins Value result = 487985f7ff6SQuinn Dawkins rewriter 488985f7ff6SQuinn Dawkins .create<vector::TransferReadOp>( 489985f7ff6SQuinn Dawkins loc, resultType, transferReadOp.getSource(), 4907c38fd60SJacques Pienaar transferReadOp.getIndices(), AffineMapAttr::get(newMap), 4917c38fd60SJacques Pienaar transferReadOp.getPadding(), transferReadOp.getMask(), 492985f7ff6SQuinn Dawkins transferReadOp.getInBoundsAttr()) 493985f7ff6SQuinn Dawkins .getResult(); 494985f7ff6SQuinn Dawkins 495985f7ff6SQuinn Dawkins // Fuse through the integer extend op. 
4965205c712SQuinn Dawkins if (extOp) { 4975205c712SQuinn Dawkins if (isa<arith::ExtSIOp>(extOp)) 498985f7ff6SQuinn Dawkins result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result) 499985f7ff6SQuinn Dawkins .getResult(); 50042bba97fSharsh-nod else if (isa<arith::ExtUIOp>(extOp)) 5015205c712SQuinn Dawkins result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result) 5025205c712SQuinn Dawkins .getResult(); 50342bba97fSharsh-nod else 50442bba97fSharsh-nod result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result) 50542bba97fSharsh-nod .getResult(); 5065205c712SQuinn Dawkins } 507985f7ff6SQuinn Dawkins 508985f7ff6SQuinn Dawkins rewriter.replaceOp(op, result); 509edd9515bSthomasraoux return success(); 510edd9515bSthomasraoux } 511edd9515bSthomasraoux }; 512edd9515bSthomasraoux 513edd9515bSthomasraoux } // namespace 514edd9515bSthomasraoux 515edd9515bSthomasraoux // MMA types have different layout based on how they are used in matmul ops. 5166413226dSthomasraoux // Figure the right layout to use by looking at op uses. 517edd9515bSthomasraoux // TODO: Change the GPU dialect to abstract the layout at the this level and 518edd9515bSthomasraoux // only care about it during lowering to NVVM. 5195205c712SQuinn Dawkins static const char *inferFragType(Operation *op) { 520a037d889SLei Zhang // We can have arith.ext ops before reaching contract ops. See through them 521a037d889SLei Zhang // and other kinds of elementwise ops. 
522a037d889SLei Zhang if (op->hasOneUse()) { 523a037d889SLei Zhang Operation *userOp = *op->user_begin(); 524a037d889SLei Zhang if (userOp->hasTrait<OpTrait::Elementwise>()) 525a037d889SLei Zhang return inferFragType(userOp); 526a037d889SLei Zhang } 527a037d889SLei Zhang 528edd9515bSthomasraoux for (Operation *users : op->getUsers()) { 529edd9515bSthomasraoux auto contract = dyn_cast<vector::ContractionOp>(users); 530edd9515bSthomasraoux if (!contract) 531edd9515bSthomasraoux continue; 5325205c712SQuinn Dawkins assert(op->getNumResults() == 1); 5335205c712SQuinn Dawkins if (contract.getLhs() == op->getResult(0)) 534edd9515bSthomasraoux return "AOp"; 5355205c712SQuinn Dawkins if (contract.getRhs() == op->getResult(0)) 536edd9515bSthomasraoux return "BOp"; 537edd9515bSthomasraoux } 538edd9515bSthomasraoux return "COp"; 539edd9515bSthomasraoux } 540edd9515bSthomasraoux 5415ef7ceaeSNicolas Vasilache static LogicalResult 5425ef7ceaeSNicolas Vasilache convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, 543edd9515bSthomasraoux llvm::DenseMap<Value, Value> &valueMapping) { 5445ef7ceaeSNicolas Vasilache OpBuilder::InsertionGuard g(rewriter); 5455ef7ceaeSNicolas Vasilache rewriter.setInsertionPoint(op); 5465ef7ceaeSNicolas Vasilache 547c537a943SNicolas Vasilache assert(op.getTransferRank() > 0 && "unexpected 0-d transfer"); 548cafb6284SChristopher Bate assert(transferReadSupportsMMAMatrixType(op) && 549cafb6284SChristopher Bate "expected convertible operation"); 550dbddd4f6SLei Zhang 551d32ec523SRamkumar Ramachandra std::optional<int64_t> stride = 552cafb6284SChristopher Bate getStaticallyKnownRowStride(op.getShapedType()); 5535ef7ceaeSNicolas Vasilache if (!stride.has_value()) { 5545ef7ceaeSNicolas Vasilache LLVM_DEBUG(DBGS() << "no stride\n"); 5555ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no stride"); 5565ef7ceaeSNicolas Vasilache } 557dbddd4f6SLei Zhang 5587c38fd60SJacques Pienaar AffineMap map = op.getPermutationMap(); 
5595ef7ceaeSNicolas Vasilache bool isTranspose = isTransposeMatrixLoadMap(map); 560dbddd4f6SLei Zhang 561e7969240SThomas Raoux // Handle broadcast by setting the stride to 0. 5621609f1c2Slong.chen if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) { 563dbddd4f6SLei Zhang assert(cstExpr.getValue() == 0); 564e7969240SThomas Raoux stride = 0; 565e7969240SThomas Raoux } 5665ef7ceaeSNicolas Vasilache 567985f7ff6SQuinn Dawkins Value mappingResult = op.getResult(); 568985f7ff6SQuinn Dawkins auto elType = op.getVectorType().getElementType(); 569edd9515bSthomasraoux const char *fragType = inferFragType(op); 570985f7ff6SQuinn Dawkins if (op->hasOneUse()) { 571b8a3f0fdSMehdi Amini auto *user = *op->user_begin(); 5725205c712SQuinn Dawkins // Infer the signedness of the mma type from the integer extend. 573a037d889SLei Zhang if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) { 5745205c712SQuinn Dawkins elType = IntegerType::get( 5755550c821STres Popp op.getContext(), cast<IntegerType>(elType).getWidth(), 576a037d889SLei Zhang isa<arith::ExtSIOp>(user) ? IntegerType::Signed 577a037d889SLei Zhang : IntegerType::Unsigned); 5785205c712SQuinn Dawkins mappingResult = user->getResult(0); 579985f7ff6SQuinn Dawkins } 580985f7ff6SQuinn Dawkins } 581edd9515bSthomasraoux gpu::MMAMatrixType type = 582985f7ff6SQuinn Dawkins gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType); 5835ef7ceaeSNicolas Vasilache Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>( 5847c38fd60SJacques Pienaar op.getLoc(), type, op.getSource(), op.getIndices(), 5855ef7ceaeSNicolas Vasilache rewriter.getIndexAttr(*stride), 5865ef7ceaeSNicolas Vasilache isTranspose ? 
rewriter.getUnitAttr() : UnitAttr()); 587985f7ff6SQuinn Dawkins valueMapping[mappingResult] = load; 5885ef7ceaeSNicolas Vasilache 5895ef7ceaeSNicolas Vasilache LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n"); 5905ef7ceaeSNicolas Vasilache return success(); 591edd9515bSthomasraoux } 592edd9515bSthomasraoux 5935ef7ceaeSNicolas Vasilache static LogicalResult 5945ef7ceaeSNicolas Vasilache convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, 595edd9515bSthomasraoux llvm::DenseMap<Value, Value> &valueMapping) { 5965ef7ceaeSNicolas Vasilache OpBuilder::InsertionGuard g(rewriter); 5975ef7ceaeSNicolas Vasilache rewriter.setInsertionPoint(op); 5985ef7ceaeSNicolas Vasilache 599edd9515bSthomasraoux assert(transferWriteSupportsMMAMatrixType(op)); 600d32ec523SRamkumar Ramachandra std::optional<int64_t> stride = 601cafb6284SChristopher Bate getStaticallyKnownRowStride(op.getShapedType()); 6025ef7ceaeSNicolas Vasilache if (!stride.has_value()) { 6035ef7ceaeSNicolas Vasilache LLVM_DEBUG(DBGS() << "no stride\n"); 6045ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no stride"); 6055ef7ceaeSNicolas Vasilache } 6065ef7ceaeSNicolas Vasilache 6075ef7ceaeSNicolas Vasilache auto it = valueMapping.find(op.getVector()); 6085ef7ceaeSNicolas Vasilache if (it == valueMapping.end()) { 6095ef7ceaeSNicolas Vasilache LLVM_DEBUG(DBGS() << "no mapping\n"); 6105ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no mapping"); 6115ef7ceaeSNicolas Vasilache } 6125ef7ceaeSNicolas Vasilache 6135ef7ceaeSNicolas Vasilache Value matrix = it->second; 6145ef7ceaeSNicolas Vasilache auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>( 6153d35546cSNavdeep Katel op.getLoc(), matrix, op.getSource(), op.getIndices(), 6165ef7ceaeSNicolas Vasilache rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr()); 6175ef7ceaeSNicolas Vasilache (void)store; 6185ef7ceaeSNicolas Vasilache 6195ef7ceaeSNicolas Vasilache LLVM_DEBUG(DBGS() << "transfer write 
to: " << store << "\n"); 6205ef7ceaeSNicolas Vasilache 6215ef7ceaeSNicolas Vasilache LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); 6225ef7ceaeSNicolas Vasilache rewriter.eraseOp(op); 6235ef7ceaeSNicolas Vasilache return success(); 624edd9515bSthomasraoux } 625edd9515bSthomasraoux 6261ca772edSChristopher Bate /// Returns the vector type which represents a matrix fragment. 6271ca772edSChristopher Bate static VectorType 6281ca772edSChristopher Bate getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info) { 6291ca772edSChristopher Bate SmallVector<int64_t> shape{regInfo.numRegistersPerFragment, 6301ca772edSChristopher Bate regInfo.elementsPerRegister}; 6311ca772edSChristopher Bate Type elType = regInfo.registerLLVMType; 6325550c821STres Popp if (auto vecType = dyn_cast<VectorType>(elType)) 6331ca772edSChristopher Bate elType = vecType.getElementType(); 6341ca772edSChristopher Bate return VectorType::get(shape, elType); 6351ca772edSChristopher Bate } 6361ca772edSChristopher Bate 6371ca772edSChristopher Bate /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. 
6381ca772edSChristopher Bate static LogicalResult 6395ef7ceaeSNicolas Vasilache convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, 6401ca772edSChristopher Bate llvm::DenseMap<Value, Value> &valueMapping) { 6415ef7ceaeSNicolas Vasilache OpBuilder::InsertionGuard g(rewriter); 6425ef7ceaeSNicolas Vasilache rewriter.setInsertionPoint(op); 6435ef7ceaeSNicolas Vasilache 6441ca772edSChristopher Bate FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 6451ca772edSChristopher Bate nvgpu::getWarpMatrixInfo(op); 6465ef7ceaeSNicolas Vasilache if (failed(warpMatrixInfo)) { 6475ef7ceaeSNicolas Vasilache LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); 6485ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 6495ef7ceaeSNicolas Vasilache } 6501ca772edSChristopher Bate 6511ca772edSChristopher Bate FailureOr<nvgpu::FragmentElementInfo> regInfo = 6521ca772edSChristopher Bate nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 6535ef7ceaeSNicolas Vasilache if (failed(regInfo)) { 6545ef7ceaeSNicolas Vasilache LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); 6555ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "not mma sync reg info"); 6565ef7ceaeSNicolas Vasilache } 6571ca772edSChristopher Bate 6581ca772edSChristopher Bate VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 6595550c821STres Popp auto dense = dyn_cast<SplatElementsAttr>(op.getValue()); 6605ef7ceaeSNicolas Vasilache if (!dense) { 6615ef7ceaeSNicolas Vasilache LLVM_DEBUG(DBGS() << "not a splat\n"); 6625ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "not a splat"); 6635ef7ceaeSNicolas Vasilache } 6645ef7ceaeSNicolas Vasilache 6655ef7ceaeSNicolas Vasilache Value result = rewriter.create<arith::ConstantOp>( 6661ca772edSChristopher Bate op.getLoc(), vectorType, 6671ca772edSChristopher Bate DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>())); 6681ca772edSChristopher Bate valueMapping[op.getResult()] = result; 
6691ca772edSChristopher Bate return success(); 6701ca772edSChristopher Bate } 6711ca772edSChristopher Bate 6726b82fc77SManish Gupta /// Check if the loaded matrix operand requires transposed. 6736b82fc77SManish Gupta /// Transposed Map Example: 6746b82fc77SManish Gupta /// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2) 6756b82fc77SManish Gupta /// Example 2 : (d0, d1, d2, d3) -> (d3, d2) 6766b82fc77SManish Gupta /// The code below checks if the output 2D is transposed using a generalized 6776b82fc77SManish Gupta /// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn) 6786b82fc77SManish Gupta /// Returns : true; if m > n, false o.w. 67984eed784SManish Gupta static FailureOr<bool> isTransposed(vector::TransferReadOp op) { 6806b82fc77SManish Gupta mlir::AffineMap map = op.getPermutationMap(); 68184eed784SManish Gupta 6826b82fc77SManish Gupta if (map.getNumResults() != 2) { 68384eed784SManish Gupta LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` " 68484eed784SManish Gupta "is not a 2d operand\n"); 68584eed784SManish Gupta return failure(); 6866b82fc77SManish Gupta } 6876b82fc77SManish Gupta 6886b82fc77SManish Gupta // Output 2D matrix dimensions in the order of d0, d1. 68984eed784SManish Gupta mlir::AffineExpr dM = map.getResult(0); 69084eed784SManish Gupta mlir::AffineExpr dN = map.getResult(1); 6916b82fc77SManish Gupta 6926b82fc77SManish Gupta // Find the position of these expressions in the input. 
6931609f1c2Slong.chen auto exprM = dyn_cast<AffineDimExpr>(dM); 6941609f1c2Slong.chen auto exprN = dyn_cast<AffineDimExpr>(dN); 69584eed784SManish Gupta 6966b82fc77SManish Gupta if (!exprM || !exprN) { 69784eed784SManish Gupta LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim " 69884eed784SManish Gupta "expressions, then transpose cannot be determined.\n"); 69984eed784SManish Gupta return failure(); 7006b82fc77SManish Gupta } 70184eed784SManish Gupta 7026b82fc77SManish Gupta return exprM.getPosition() > exprN.getPosition(); 7036b82fc77SManish Gupta } 7046b82fc77SManish Gupta 7051ca772edSChristopher Bate static LogicalResult 7065ef7ceaeSNicolas Vasilache creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, 7071ca772edSChristopher Bate llvm::DenseMap<Value, Value> &valueMapping) { 7085ef7ceaeSNicolas Vasilache OpBuilder::InsertionGuard g(rewriter); 7095ef7ceaeSNicolas Vasilache rewriter.setInsertionPoint(op); 7101ca772edSChristopher Bate Location loc = op->getLoc(); 7111ca772edSChristopher Bate 7121ca772edSChristopher Bate FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 7131ca772edSChristopher Bate nvgpu::getWarpMatrixInfo(op); 7145ef7ceaeSNicolas Vasilache if (failed(warpMatrixInfo)) { 7155ef7ceaeSNicolas Vasilache LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); 7165ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 7175ef7ceaeSNicolas Vasilache } 7181ca772edSChristopher Bate 7191ca772edSChristopher Bate FailureOr<nvgpu::FragmentElementInfo> regInfo = 7201ca772edSChristopher Bate nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 7215ef7ceaeSNicolas Vasilache if (failed(regInfo)) { 7225ef7ceaeSNicolas Vasilache LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); 7235ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "not mma sync reg info"); 7245ef7ceaeSNicolas Vasilache } 7251ca772edSChristopher Bate 72684eed784SManish Gupta FailureOr<bool> transpose = isTransposed(op); 
72784eed784SManish Gupta if (failed(transpose)) { 72884eed784SManish Gupta LLVM_DEBUG(DBGS() << "failed to determine the transpose\n"); 72984eed784SManish Gupta return rewriter.notifyMatchFailure( 73084eed784SManish Gupta op, "Op should likely not be converted to a nvgpu.ldmatrix call."); 73184eed784SManish Gupta } 73284eed784SManish Gupta 7336b82fc77SManish Gupta FailureOr<nvgpu::LdMatrixParams> params = 73484eed784SManish Gupta nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose); 7356b82fc77SManish Gupta 7361ca772edSChristopher Bate if (failed(params)) { 7375ef7ceaeSNicolas Vasilache LLVM_DEBUG( 7385ef7ceaeSNicolas Vasilache DBGS() 7395ef7ceaeSNicolas Vasilache << "failed to convert vector.transfer_read to ldmatrix. " 7405ef7ceaeSNicolas Vasilache << "Op should likely not be converted to a nvgpu.ldmatrix call.\n"); 7415ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure( 7425ef7ceaeSNicolas Vasilache op, "failed to convert vector.transfer_read to ldmatrix; this op " 7435ef7ceaeSNicolas Vasilache "likely should not be converted to a nvgpu.ldmatrix call."); 7441ca772edSChristopher Bate } 7451ca772edSChristopher Bate 7461ca772edSChristopher Bate // Adjust the load offset. 
74743fd4c49SKrzysztof Drewniak auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr); 7481ca772edSChristopher Bate FailureOr<AffineMap> offsets = 7495ef7ceaeSNicolas Vasilache nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params); 7505ef7ceaeSNicolas Vasilache if (failed(offsets)) { 7515ef7ceaeSNicolas Vasilache LLVM_DEBUG(DBGS() << "no offsets\n"); 7525ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no offsets"); 7535ef7ceaeSNicolas Vasilache } 7541ca772edSChristopher Bate 7551ca772edSChristopher Bate VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 7561ca772edSChristopher Bate 7571ca772edSChristopher Bate SmallVector<Value, 4> indices; 7585ef7ceaeSNicolas Vasilache getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId}, 7591ca772edSChristopher Bate indices); 76084eed784SManish Gupta 7615ef7ceaeSNicolas Vasilache nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>( 76284eed784SManish Gupta loc, vectorType, op.getSource(), indices, *transpose, params->numTiles); 7631ca772edSChristopher Bate valueMapping[op] = newOp->getResult(0); 7641ca772edSChristopher Bate return success(); 7651ca772edSChristopher Bate } 7661ca772edSChristopher Bate 7671ca772edSChristopher Bate static LogicalResult 7685ef7ceaeSNicolas Vasilache createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, 7691ca772edSChristopher Bate llvm::DenseMap<Value, Value> &valueMapping) { 7705ef7ceaeSNicolas Vasilache OpBuilder::InsertionGuard g(rewriter); 7715ef7ceaeSNicolas Vasilache rewriter.setInsertionPoint(op); 7725ef7ceaeSNicolas Vasilache 7731ca772edSChristopher Bate Location loc = op.getLoc(); 7741ca772edSChristopher Bate FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 7751ca772edSChristopher Bate nvgpu::getWarpMatrixInfo(op); 7761ca772edSChristopher Bate if (failed(warpMatrixInfo)) 7775ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 7781ca772edSChristopher 
Bate FailureOr<nvgpu::FragmentElementInfo> regInfo = 7791ca772edSChristopher Bate nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 7801ca772edSChristopher Bate if (failed(regInfo)) { 781c0f504dfSJie Fu return rewriter.notifyMatchFailure( 7825ef7ceaeSNicolas Vasilache op, "Failed to deduce register fragment type during " 7835ef7ceaeSNicolas Vasilache "conversion to distributed non-ldmatrix compatible load"); 7841ca772edSChristopher Bate } 7851ca772edSChristopher Bate 78643fd4c49SKrzysztof Drewniak Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr); 7871ca772edSChristopher Bate SmallVector<Value, 4> elements; 7881ca772edSChristopher Bate 7891ca772edSChristopher Bate // This is the individual element type. 7901ca772edSChristopher Bate Type loadedElType = regInfo->registerLLVMType; 7911ca772edSChristopher Bate VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 7921ca772edSChristopher Bate 7935ef7ceaeSNicolas Vasilache Value fill = rewriter.create<arith::ConstantOp>( 7941ca772edSChristopher Bate op.getLoc(), vectorType.getElementType(), 7955ef7ceaeSNicolas Vasilache rewriter.getZeroAttr(vectorType.getElementType())); 7965ef7ceaeSNicolas Vasilache Value result = 7975ef7ceaeSNicolas Vasilache rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType); 7981ca772edSChristopher Bate 7991ca772edSChristopher Bate bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity(); 8001ca772edSChristopher Bate 8013af64383SNicolas Vasilache // If we are not transposing, then we can use vectorized loads. Otherwise, we 8023af64383SNicolas Vasilache // must load each element individually. 
803670eee08SChristopher Bate if (!isTransposeLoad) { 8045550c821STres Popp if (!isa<VectorType>(loadedElType)) { 8051ca772edSChristopher Bate loadedElType = VectorType::get({1}, loadedElType); 8061ca772edSChristopher Bate } 8071ca772edSChristopher Bate 8081ca772edSChristopher Bate for (int i = 0; i < vectorType.getShape()[0]; i++) { 8091ca772edSChristopher Bate FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( 8105ef7ceaeSNicolas Vasilache rewriter, op.getLoc(), *warpMatrixInfo); 8111ca772edSChristopher Bate if (failed(coords)) 8125ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no coords"); 8135ef7ceaeSNicolas Vasilache 8145ef7ceaeSNicolas Vasilache Value logicalValueId = rewriter.create<arith::ConstantOp>( 8155ef7ceaeSNicolas Vasilache loc, rewriter.getIndexType(), 8165ef7ceaeSNicolas Vasilache rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); 8171ca772edSChristopher Bate SmallVector<Value, 4> newIndices; 8181ca772edSChristopher Bate getXferIndices<vector::TransferReadOp>( 8195ef7ceaeSNicolas Vasilache rewriter, op, *coords, {laneId, logicalValueId}, newIndices); 8201ca772edSChristopher Bate 8215ef7ceaeSNicolas Vasilache Value el = rewriter.create<vector::LoadOp>(loc, loadedElType, 8221ca772edSChristopher Bate op.getSource(), newIndices); 82316b75cd2SMatthias Springer result = rewriter.create<vector::InsertOp>(loc, el, result, i); 8241ca772edSChristopher Bate } 825670eee08SChristopher Bate } else { 8265550c821STres Popp if (auto vecType = dyn_cast<VectorType>(loadedElType)) { 8271ca772edSChristopher Bate loadedElType = vecType.getElementType(); 8281ca772edSChristopher Bate } 8291ca772edSChristopher Bate for (int i = 0; i < vectorType.getShape()[0]; i++) { 8301ca772edSChristopher Bate for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1]; 8311ca772edSChristopher Bate innerIdx++) { 8321ca772edSChristopher Bate 8335ef7ceaeSNicolas Vasilache Value logicalValueId = rewriter.create<arith::ConstantOp>( 
8345ef7ceaeSNicolas Vasilache loc, rewriter.getIndexType(), 8355ef7ceaeSNicolas Vasilache rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx)); 8361ca772edSChristopher Bate FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( 8375ef7ceaeSNicolas Vasilache rewriter, op.getLoc(), *warpMatrixInfo); 8381ca772edSChristopher Bate if (failed(coords)) 8395ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no coords"); 8401ca772edSChristopher Bate 8411ca772edSChristopher Bate SmallVector<Value, 4> newIndices; 8421ca772edSChristopher Bate getXferIndices<vector::TransferReadOp>( 8435ef7ceaeSNicolas Vasilache rewriter, op, *coords, {laneId, logicalValueId}, newIndices); 8445ef7ceaeSNicolas Vasilache Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType, 8451ca772edSChristopher Bate op.getSource(), newIndices); 8465ef7ceaeSNicolas Vasilache result = rewriter.create<vector::InsertOp>( 84716b75cd2SMatthias Springer op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx}); 8481ca772edSChristopher Bate } 8491ca772edSChristopher Bate } 8501ca772edSChristopher Bate } 8511ca772edSChristopher Bate 8521ca772edSChristopher Bate valueMapping[op.getResult()] = result; 8531ca772edSChristopher Bate return success(); 8541ca772edSChristopher Bate } 8551ca772edSChristopher Bate 856066b4fcbSThomas Raoux /// Return true if this is a shared memory memref type. 857066b4fcbSThomas Raoux static bool isSharedMemory(MemRefType type) { 858066b4fcbSThomas Raoux auto addressSpace = 8595550c821STres Popp dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace()); 860b984045dSMehdi Amini return addressSpace && 861b984045dSMehdi Amini addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace(); 862066b4fcbSThomas Raoux } 863066b4fcbSThomas Raoux 8641ca772edSChristopher Bate /// Converts a `vector.transfer_read` operation directly to either a 8651ca772edSChristopher Bate /// `vector.load` or a `nvgpu.ldmatrix` operation. 
This function should only be 8661ca772edSChristopher Bate /// used when converting to `nvgpu.mma.sync` operations. 8671ca772edSChristopher Bate static LogicalResult 8685ef7ceaeSNicolas Vasilache convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, 8691ca772edSChristopher Bate llvm::DenseMap<Value, Value> &valueMapping) { 8705ef7ceaeSNicolas Vasilache OpBuilder::InsertionGuard g(rewriter); 8715ef7ceaeSNicolas Vasilache rewriter.setInsertionPoint(op); 8721ca772edSChristopher Bate 8731ca772edSChristopher Bate FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 8741ca772edSChristopher Bate nvgpu::getWarpMatrixInfo(op); 8751ca772edSChristopher Bate if (failed(warpMatrixInfo)) 8765ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 8771ca772edSChristopher Bate 8781ca772edSChristopher Bate bool isLdMatrixCompatible = 8795550c821STres Popp isSharedMemory(cast<MemRefType>(op.getSource().getType())) && 8801ca772edSChristopher Bate nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128; 8811ca772edSChristopher Bate 8821ca772edSChristopher Bate VectorType vecTy = op.getVectorType(); 8831ca772edSChristopher Bate int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth(); 8841ca772edSChristopher Bate 8853af64383SNicolas Vasilache // When we are transposing the B operand, ldmatrix will only work if we have 8863af64383SNicolas Vasilache // at least 8 rows to read and the width to read for the transpose is 128 8873af64383SNicolas Vasilache // bits. 
8881ca772edSChristopher Bate if (!op.getPermutationMap().isMinorIdentity() && 889271a48e0SThomas Raoux (bitWidth != 16 || vecTy.getDimSize(1) < 8 || 890271a48e0SThomas Raoux vecTy.getDimSize(0) * bitWidth < 128)) 8911ca772edSChristopher Bate isLdMatrixCompatible = false; 8921ca772edSChristopher Bate 8931ca772edSChristopher Bate if (!isLdMatrixCompatible) 8945ef7ceaeSNicolas Vasilache return createNonLdMatrixLoads(rewriter, op, valueMapping); 8951ca772edSChristopher Bate 8965ef7ceaeSNicolas Vasilache return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping); 8971ca772edSChristopher Bate } 8981ca772edSChristopher Bate 8991ca772edSChristopher Bate static LogicalResult 9005ef7ceaeSNicolas Vasilache convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, 9011ca772edSChristopher Bate llvm::DenseMap<Value, Value> &valueMapping) { 9025ef7ceaeSNicolas Vasilache OpBuilder::InsertionGuard g(rewriter); 9035ef7ceaeSNicolas Vasilache rewriter.setInsertionPoint(op); 9045ef7ceaeSNicolas Vasilache 9051ca772edSChristopher Bate Location loc = op->getLoc(); 9065ef7ceaeSNicolas Vasilache auto it = valueMapping.find(op.getVector()); 9075ef7ceaeSNicolas Vasilache if (it == valueMapping.end()) 9085ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no mapping"); 9095ef7ceaeSNicolas Vasilache Value matrix = it->second; 9101ca772edSChristopher Bate 9111ca772edSChristopher Bate FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 9121ca772edSChristopher Bate nvgpu::getWarpMatrixInfo(op); 9131ca772edSChristopher Bate if (failed(warpMatrixInfo)) 9145ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 9151ca772edSChristopher Bate FailureOr<nvgpu::FragmentElementInfo> regInfo = 9161ca772edSChristopher Bate nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 9171ca772edSChristopher Bate if (failed(regInfo)) 9185ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "not mma sync reg info"); 9191ca772edSChristopher 
Bate 9201ca772edSChristopher Bate VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); 92143fd4c49SKrzysztof Drewniak Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr); 9221ca772edSChristopher Bate 9231ca772edSChristopher Bate for (unsigned i = 0; i < vectorType.getShape()[0]; i++) { 9245ef7ceaeSNicolas Vasilache Value logicalValueId = rewriter.create<arith::ConstantOp>( 9255ef7ceaeSNicolas Vasilache loc, rewriter.getIndexType(), 9265ef7ceaeSNicolas Vasilache rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); 9271ca772edSChristopher Bate FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( 9285ef7ceaeSNicolas Vasilache rewriter, op.getLoc(), *warpMatrixInfo); 9291ca772edSChristopher Bate if (failed(coords)) 9305ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no coords"); 9311ca772edSChristopher Bate 9325ef7ceaeSNicolas Vasilache Value el = 9335ef7ceaeSNicolas Vasilache rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i}); 9341ca772edSChristopher Bate SmallVector<Value, 4> newIndices; 9351ca772edSChristopher Bate getXferIndices<vector::TransferWriteOp>( 9365ef7ceaeSNicolas Vasilache rewriter, op, *coords, {laneId, logicalValueId}, newIndices); 9375ef7ceaeSNicolas Vasilache rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices); 9381ca772edSChristopher Bate } 9395ef7ceaeSNicolas Vasilache 9405ef7ceaeSNicolas Vasilache LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); 9415ef7ceaeSNicolas Vasilache rewriter.eraseOp(op); 9421ca772edSChristopher Bate return success(); 9431ca772edSChristopher Bate } 9441ca772edSChristopher Bate 945114ba722SManish Gupta static void populateFromInt64AttrArray(ArrayAttr arrayAttr, 946114ba722SManish Gupta SmallVectorImpl<int64_t> &results) { 947114ba722SManish Gupta for (auto attr : arrayAttr) 9485550c821STres Popp results.push_back(cast<IntegerAttr>(attr).getInt()); 949114ba722SManish Gupta } 950114ba722SManish Gupta 
951114ba722SManish Gupta static LogicalResult 9525ef7ceaeSNicolas Vasilache convertExtractStridedSlice(RewriterBase &rewriter, 9535ef7ceaeSNicolas Vasilache vector::ExtractStridedSliceOp op, 954114ba722SManish Gupta llvm::DenseMap<Value, Value> &valueMapping) { 9555ef7ceaeSNicolas Vasilache OpBuilder::InsertionGuard g(rewriter); 9565ef7ceaeSNicolas Vasilache rewriter.setInsertionPoint(op); 957114ba722SManish Gupta 958114ba722SManish Gupta Location loc = op->getLoc(); 959114ba722SManish Gupta 960114ba722SManish Gupta FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = 961114ba722SManish Gupta nvgpu::getWarpMatrixInfo(op); 962114ba722SManish Gupta if (failed(warpMatrixInfo)) 9635ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 964114ba722SManish Gupta 965114ba722SManish Gupta FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo = 966114ba722SManish Gupta nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 967114ba722SManish Gupta if (failed(mmaSyncFragmentInfo)) 9685ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo"); 969114ba722SManish Gupta 9703af64383SNicolas Vasilache // Find the vector.transer_read whose result vector is being sliced. 
971114ba722SManish Gupta auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>(); 972114ba722SManish Gupta if (!transferReadOp) 9735ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no transfer read"); 974114ba722SManish Gupta 975114ba722SManish Gupta warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp); 976114ba722SManish Gupta if (failed(warpMatrixInfo)) 9775ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); 978114ba722SManish Gupta 979114ba722SManish Gupta FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo = 980114ba722SManish Gupta nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); 981114ba722SManish Gupta if (failed(ldFragmentInfo)) 9825ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no ldFragmentInfo"); 983114ba722SManish Gupta 9843af64383SNicolas Vasilache assert( 9853af64383SNicolas Vasilache (mmaSyncFragmentInfo->elementsPerRegister == 986114ba722SManish Gupta ldFragmentInfo->elementsPerRegister) && 9873af64383SNicolas Vasilache "Number of elements per register should be same for load and mma.sync"); 988114ba722SManish Gupta 9893af64383SNicolas Vasilache // Create vector.extract_strided_slice op for thread-owned fragments. 990114ba722SManish Gupta std::array<int64_t, 2> strides = {1, 991114ba722SManish Gupta 1}; // stride for extract slice is always 1. 992114ba722SManish Gupta std::array<int64_t, 2> sliceShape = { 993114ba722SManish Gupta mmaSyncFragmentInfo->numRegistersPerFragment, 994114ba722SManish Gupta mmaSyncFragmentInfo->elementsPerRegister}; 9955ef7ceaeSNicolas Vasilache auto it = valueMapping.find(transferReadOp); 9965ef7ceaeSNicolas Vasilache if (it == valueMapping.end()) 9975ef7ceaeSNicolas Vasilache return rewriter.notifyMatchFailure(op, "no mapping"); 9985ef7ceaeSNicolas Vasilache auto sourceVector = it->second; 999114ba722SManish Gupta 1000114ba722SManish Gupta // offset and sizes at warp-level of onwership. 
  SmallVector<int64_t> offsets;
  populateFromInt64AttrArray(op.getOffsets(), offsets);

  SmallVector<int64_t> sizes;
  populateFromInt64AttrArray(op.getSizes(), sizes);
  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();

  // Compute offset in vector registers. Note that the mma.sync vector registers
  // are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
  std::array<int64_t, 2> sliceOffset = {0, 0};

  // A non-zero offset in both dimensions would require slicing inside a
  // fragment, which the register layout cannot express.
  if (offsets[0] && offsets[1])
    return op->emitError() << "Slicing fragments in 2D is not supported. ";
  // NOTE(review): the fragment offset is derived by dividing the warp-level
  // vector shape by the requested offset — confirm this mapping against the
  // nvgpu fragment layout utilities.
  if (offsets[0])
    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
  else if (offsets[1])
    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);

  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
      loc, sourceVector, sliceOffset, sliceShape, strides);

  valueMapping[op] = newOp;
  return success();
}

/// Convert a vector.contract op to a gpu.subgroup_mma_compute op. All three
/// operands (lhs, rhs, acc) must already have been converted to
/// !gpu.mma_matrix values and be present in `valueMapping`; otherwise the
/// conversion bails out with a match failure. On success the contract's
/// result is mapped to the new compute op's result.
static LogicalResult
convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
      /*b_transpose=*/UnitAttr());
  valueMapping[op.getResult()] = matmul;
  return success();
}

/// Convert a vector.contract op to an nvgpu.mma.sync op. As above, the lhs,
/// rhs, and acc operands must already be mapped (to register fragments) in
/// `valueMapping`.
static LogicalResult
convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  // Shapes: m and k come from the lhs (m x k); n is the rhs' leading dim —
  // the rhs is expected in (n x k) layout, which the MMT canonicalization
  // pattern (populateVectorContractCanonicalizeMatmulToMMT) establishes.
  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
/// The splat value is materialized as a scalar arith.constant which then
/// feeds the subgroup_mma_constant_matrix op.
static LogicalResult
convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  // Callers are expected to have filtered ops via getOpToConvert.
  assert(constantSupportsMMAMatrixType(op));

  auto splat =
      cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
  auto scalarConstant =
      rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  // Fragment kind ("AOp"/"BOp"/"COp") is inferred from how the result is used.
  const char *fragType = inferFragType(op);
  auto vecType = cast<VectorType>(op.getType());
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, scalarConstant);
  valueMapping[op.getResult()] = matrix;
  return success();
}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
/// The broadcast's scalar source directly becomes the constant-matrix value.
static LogicalResult
convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
                   llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  // Callers are expected to have filtered ops via getOpToConvert.
  assert(broadcastSupportsMMAMatrixType(op));

  const char *fragType = inferFragType(op);
  auto vecType = op.getResultVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, op.getSource());
  valueMapping[op.getResult()] = matrix;
  return success();
}

// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
                                               scf::ForOp loop,
                                               ValueRange newInitArgs) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loop);

  // Create a new loop before the existing one, with the extra operands.
  rewriter.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getInitArgs());
  llvm::append_range(operands, newInitArgs);
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
      operands);
  // Drop the auto-created empty body: the old loop's body is moved over
  // wholesale below instead of being cloned.
  rewriter.eraseBlock(newLoop.getBody());

  newLoop.getRegion().getBlocks().splice(
      newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
  // The spliced body only has block arguments for the original iter_args;
  // append one argument per extra init value.
  for (Value operand : newInitArgs)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  // Route users of the old loop's results to the leading results of the new
  // loop (extra results are left for the caller to wire up).
  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                               loop.getNumResults())))
    rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));

  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
  LLVM_DEBUG(DBGS() << "erase: " << loop);

  rewriter.eraseOp(loop);
  return newLoop;
}

/// Convert an scf.for op: any init arg that has a converted counterpart in
/// `valueMapping` is threaded through the loop as an extra iter_arg, and the
/// corresponding new results/block arguments are recorded in `valueMapping`.
/// The loop's scf.yield is fixed up separately by convertYieldOp.
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  SmallVector<Value> newOperands;
  // Pairs of (original iter_arg index, index of the appended counterpart).
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end()) {
      LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
      continue;
    }
    argMapping.push_back(std::make_pair(
        operand.index(), op.getInitArgs().size() + newOperands.size()));
    newOperands.push_back(it->second);
  }

  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  // Map both the loop results and the region iter_args of the original values
  // to their newly-appended counterparts.
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }

  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
  return success();
}

/// Convert the scf.yield of a loop rewritten by convertForOp: append the
/// mapped values as extra yield operands and neutralize the original yields
/// of converted values so the old computation becomes dead.
static LogicalResult
convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
               llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of old value with the for op argument to make it easier
    // to remove the dead code.
    yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
/// All operands must already be mapped to !gpu.mma_matrix values in
/// `valueMapping`; the result type is taken from the first operand's matrix
/// type, except for floating-point extension where the wider element type
/// comes from the op's vector result type.
static LogicalResult
convertElementwiseOp(RewriterBase &rewriter, Operation *op,
                     gpu::MMAElementwiseOp opType,
                     llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands()) {
    auto it = valueMapping.find(operand);
    if (it == valueMapping.end())
      return rewriter.notifyMatchFailure(op, "no mapping");
    matrixOperands.push_back(it->second);
  }
  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
  if (opType == gpu::MMAElementwiseOp::EXTF) {
    // The floating point extension case has a different result type.
    auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
    resultType = gpu::MMAMatrixType::get(resultType.getShape(),
                                         vectorType.getElementType(),
                                         resultType.getOperand());
  }

  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), resultType, matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
  return success();
}

// Populate the patterns that massage vector ops into a shape the MMA
// conversion can handle. The nvgpu path canonicalizes matmuls to MMT form
// instead of using the wmma-oriented contract preparation pattern.
void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
                                              bool useNvGpu) {
  if (!useNvGpu) {
    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
        patterns.getContext());
    return;
  }
  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
}

// Convert, best-effort, every candidate op under `rootOp` to the gpu-dialect
// WMMA ops. Individual op failures do not stop the walk; they only flip the
// aggregate result to failure at the end.
LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
                                          Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
  llvm::DenseMap<Value, Value> valueMapping;

  auto globalRes = LogicalResult::success();
  for (Operation *op : ops) {
    LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
    // Apparently callers do not want to early exit on failure here.
    auto res = LogicalResult::success();
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      res = convertTransferReadOp(rewriter, transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      res = convertContractOp(rewriter, contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      res = convertConstantOp(rewriter, constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      res = convertForOp(rewriter, forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      res = convertYieldOp(rewriter, yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
    }
    if (failed(res))
      globalRes = failure();
  }
  return globalRes;
}

// Convert every candidate op under `rootOp` to nvgpu mma.sync-compatible ops.
// Unlike convertVectorToMMAOps above, any single failure aborts the whole
// conversion with an error on the offending op.
LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
                                                         Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
            .Case([&](vector::TransferReadOp transferReadOp) {
              return convertTransferReadToLoads(rewriter, transferReadOp,
                                                valueMapping);
            })
            .Case([&](vector::TransferWriteOp transferWriteOp) {
              return convertTransferWriteToStores(rewriter, transferWriteOp,
                                                  valueMapping);
            })
            .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
              return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
                                                valueMapping);
            })
            .Case([&](vector::ContractionOp contractionOp) {
              return convertContractOpToMmaSync(rewriter, contractionOp,
                                                valueMapping);
            })
            .Case([&](scf::ForOp forOp) {
              return convertForOp(rewriter, forOp, valueMapping);
            })
            .Case([&](scf::YieldOp yieldOp) {
              return convertYieldOp(rewriter, yieldOp, valueMapping);
            })
            .Case([&](arith::ConstantOp constOp) {
              return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
            })
            .Default([&](Operation *op) {
              return op->emitError() << "unhandled vector to mma type: " << *op;
            })
            .failed()) {
      return op->emitOpError()
             << "failed to convert op during vector-to-nvgpu conversion";
    }
  }
  return success();
}

namespace {

/// Pass wrapper: first runs the preparation patterns greedily, then performs
/// the one-shot conversion (wmma or mma.sync depending on the `useNvGpu`
/// option).
struct ConvertVectorToGPUPass
    : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    IRRewriter rewriter(&getContext());
    if (useNvGpu) {
      if (failed(
              convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
        return signalPassFailure();
      return;
    }
    // The wmma path is best-effort; its failures are deliberately not
    // propagated to the pass result.
    (void)convertVectorToMMAOps(rewriter, getOperation());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}