//===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::nvgpu;

/// There are always 4 threads per [128|256|512] bit row.
static constexpr int64_t kThreadsPerRow = 4;
static constexpr int64_t kNumRowsPerTile = 8;

static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType,
                                           int64_t lineSizeBits) {
  // For each 8x128bit square, a thread is responsible for one 32bit register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}

/// Returns the first user of the `op` that is vector.contract. If no
/// vector.contract user exists, return failure.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
  for (Operation *user : op->getUsers()) {
    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
      return contractOp;
  }
  return failure();
}

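/// Collects the warp-level matrix operand info for `op`: the vector type it
/// produces or consumes and its role in the matmul. The role defaults to
/// accumulator/result (C) and is refined to A or B when the value directly
/// feeds a `vector.contract`.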
FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type at warp-level.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
    info.vectorType = cast<VectorType>(op->getResult(0).getType());
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume it is an accumulator/result unless
  // it is directly consumed by a `vector.contract` op.
  info.operandRole = MatMulOperandRole::C;
  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
  if (failed(contractOp))
    return info;

  if ((*contractOp).getLhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::A;
  else if ((*contractOp).getRhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::B;

  return info;
}

/// Returns the width in bits of a single tile row for the given operand
/// (128, 256, or 512 depending on the element width and operand role).
int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}

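/// Returns how a single thread holds its part of the given fragment: the
/// packed register type, the number of elements per register, the register
/// width in bits, and the number of registers per fragment. Fails for
/// unsupported element types.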
FailureOr<FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operand
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operand
  if (elType.isInteger(8)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // int4 operand
  if (elType.isInteger(4)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Integer 32bit acc operands
  if (elType.isInteger(32)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Floating point 32bit operands
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}

/// Returns an affine map from the logical value id of a thread's fragment to
/// the (row, col) offset of the 8-row tile containing the register that holds
/// that value.
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}

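/// Returns an affine map from (laneId, logicalValueId) to the (row, col)
/// coordinate, within the warp-level operand, of the fragment element that
/// the lane holds at that logical position.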
FailureOr<AffineMap>
nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                         const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}

/// Computes the ldmatrix parameters (fragment type, target layout, contiguous
/// dimension, and number of 8x128-bit tiles) for loading the given warp-level
/// operand, or returns failure if no tiles would be loaded.
FailureOr<nvgpu::LdMatrixParams>
nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType = transpose ? vector::IteratorType::parallel
                                       : vector::IteratorType::reduction;

  if (params.contiguousDimType == vector::IteratorType::reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}

/// Returns an affine map from the lane id to the (row, col) coordinate of the
/// element in the source memref that the lane's ldmatrix address points to.
FailureOr<AffineMap>
nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                      const LdMatrixParams &params) {
  // One thread per 128b row.
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // Index `idx` in vectorType `operandShape` maps to the strided dimension of
  // the `srcMemref` memory of the LdMatrixOp.
  int idx =
      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;

  // Affine exprs in the strided and contiguous dimensions encode the
  // coordinate mapping for the element a thread points to for the warp-wide
  // LdMatrixOp.
  AffineExpr strided = d0 % (operandShape[idx]);
  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);

  // This case corresponds to row-major matrixA or col-major matrixB or
  // row-major matrixC. This is when the memory layout in `srcMemref`
  // matches the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::reduction)
    return makeMap({strided, contiguous});

  // This case corresponds to col-major matrixA or row-major matrixB or
  // col-major matrixC. This is when the memory layout in `srcMemref` does not
  // match the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::parallel)
    return makeMap({contiguous, strided});

  return failure();
}

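/// Returns true if the given `vector.transfer_read` can be lowered to a warp
/// matrix (mma.sync) operand load: it must be unmasked and in-bounds, produce
/// a statically shaped 2D vector, and read from a memref whose innermost
/// dimension is contiguous.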
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim())
    return false;
  VectorType type = op.getType();
  // The result type should be 2D. Note that it is possible to expand support
  // so that we are robust to extra unit dimensions that failed to fold, but
  // that would significantly increase downstream code complexity in the
  // conversion step. For now, we rely on other patterns to ensure canonical
  // 2D form is used when targeting the `nvgpu.mma.sync` lowering path.
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;

  // Currently we can't support reads on tensor types because we need stride
  // information to ensure correctness of downstream assumptions. It is
  // possible to enable this if the caller can assert that the tensor will be
  // lowered in a particular manner.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the read is contiguous. Note that it is
  // possible to expand support for this by scalarizing all the loads during
  // conversion.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}

/// Returns true if the given `vector.transfer_write` can be lowered to a warp
/// matrix (mma.sync) result store: it must be unmasked and in-bounds, have
/// non-zero transfer rank, use a minor identity permutation map, write a
/// statically shaped 2D vector, and target a memref whose innermost dimension
/// is contiguous.
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
    return false;
  VectorType type = op.getVectorType();
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;
  // TODO: Currently we rely on lowering to a `vector.store` operation. We
  // could support the transposed write case by lowering to scalarized
  // `memref.store` operations.
  if (!op.getPermutationMap().isMinorIdentity())
    return false;
  // Currently we can't support writes to tensor types because we need stride
  // information to ensure correctness of downstream assumptions.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the target memref is contiguous. Note
  // that it is possible to expand support for this by scalarizing all the
  // stores during conversion.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}