xref: /llvm-project/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
//===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::nvgpu;

/// There are always 4 threads per [128|256|512] bit row.
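/// Together with the 8 rows per tile below, this means the 4 x 8 = 32 threads
/// of a warp collectively cover one 8 x [128|256|512] bit tile.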
static constexpr int64_t kThreadsPerRow = 4;
static constexpr int64_t kNumRowsPerTile = 8;

static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
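/// For example (an illustrative shape, not taken from this file), a 16x16 f16
/// non-accumulator fragment has a 128-bit tile width, giving
/// (16 / 8) * (16 * 16) / 128 = 4 registers per thread.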
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
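/// For the same illustrative 16x16 f16 operand with 128-bit lines, this yields
/// {16 / 8, (16 * 16) / 128} = {2, 2} tiles.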
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType,
                                           int64_t lineSizeBits) {
  // For each 8x128bit square, a thread is responsible for one 32bit register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}

/// Returns the first user of `op` that is a vector.contract operation. If no
/// such user exists, returns failure.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
  for (Operation *user : op->getUsers()) {
    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
      return contractOp;
  }
  return failure();
}

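/// Collects the warp-level vector type of `op` and its operand role (A, B, or
/// C/result) in the matmul it feeds, if any.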
FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type at the warp level.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
    info.vectorType = cast<VectorType>(op->getResult(0).getType());
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume it is an accumulator/result unless it
  // is directly consumed by a `vector.contract` op.
  info.operandRole = MatMulOperandRole::C;
  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
  if (failed(contractOp))
    return info;

  if ((*contractOp).getLhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::A;
  else if ((*contractOp).getRhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::B;

  return info;
}

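/// Returns the bit width of one tile row (the "line size"): 256 bits for
/// 32-bit accumulator/result fragments, 512/256 bits for f64
/// accumulator/operand fragments, and 128 bits otherwise.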
int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}

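/// Returns the per-thread register description for the given fragment: the
/// register's vector type, the number of elements and bits per register, and
/// the number of registers per fragment. Fails for unsupported element types.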
FailureOr<FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
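  // f16 operand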
  if (elType.isF16()) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operand
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operand
  if (elType.isInteger(8)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // int4 operand
  if (elType.isInteger(4)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // int32 accumulator operand
  if (elType.isInteger(32)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // f32 operand
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}

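/// Returns a map over (laneId, logicalValueId) that gives the (row, col)
/// element offset of the 8-row tile containing the register that holds
/// `logicalValueId`. The per-lane offset within the tile is added by the
/// caller.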
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}

FailureOr<AffineMap>
nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                         const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine which register `logicalValueId` corresponds to and use that as
  // a linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
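  // Each group of kThreadsPerRow consecutive lanes shares a tile row; within
  // the group, a lane owns `elementsPerRegister` consecutive elements of that
  // row, and `logicalValueId % elementsPerRegister` selects the element inside
  // the register.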
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}

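/// Computes the ldmatrix parameters for loading the given fragment: the target
/// layout, which dimension is contiguous in memory, and how many 8 x 128-bit
/// tiles the fragment spans. Fails if the operand does not fill a whole tile.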
FailureOr<nvgpu::LdMatrixParams>
nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType = transpose ? vector::IteratorType::parallel
                                       : vector::IteratorType::reduction;

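  // Count the 8 x 128-bit tiles the fragment spans; which formula applies
  // depends on which dimension is contiguous in memory. For an illustrative
  // 16x16 f16 fragment this is (16 / 8) * ((16 * 16) / 128) = 4 tiles.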
  if (params.contiguousDimType == vector::IteratorType::reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}

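/// Returns the map from a lane id to the (row, col) coordinate of the 128-bit
/// segment whose address that lane supplies in a warp-wide ldmatrix load.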
FailureOr<AffineMap>
nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                      const LdMatrixParams &params) {
  // One thread per 128b row.
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // Index `idx` selects the dimension of the vector type `operandShape` that
  // maps to the strided dimension of the `srcMemref` memory of the LdMatrixOp.
  int idx =
      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;

  // The affine expressions in the strided and contiguous dimensions encode the
  // coordinate of the element each thread points to for the warp-wide
  // LdMatrixOp.
  AffineExpr strided = d0 % (operandShape[idx]);
  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);

  // This case corresponds to row-major matrixA, col-major matrixB, or
  // row-major matrixC, i.e. when the memory layout in `srcMemref` matches the
  // mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::reduction)
    return makeMap({strided, contiguous});

  // This case corresponds to col-major matrixA, row-major matrixB, or
  // col-major matrixC, i.e. when the memory layout in `srcMemref` does not
  // match the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::parallel)
    return makeMap({contiguous, strided});

  return failure();
}

bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim())
    return false;
  VectorType type = op.getType();
  // The result type should be 2D. Note that it is possible to expand support
  // so that we are robust to extra unit dimensions that failed to fold, but
  // that would significantly increase downstream code complexity in the
  // conversion step. For now, we rely on other patterns to ensure a canonical
  // 2D form is used when targeting the `nvgpu.mma.sync` lowering path.
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;

  // Currently we can't support reads on tensor types because we need stride
  // information to ensure correctness of downstream assumptions. It is
  // possible to enable this if the caller can assert that the tensor will be
  // lowered in a particular manner.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the read is contiguous. Note that it is
  // possible to expand support for this by scalarizing all the loads during
  // conversion.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}

bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
    return false;
  VectorType type = op.getVectorType();
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;
  // TODO: Currently we rely on lowering to a `vector.store` operation. We could
  // support the transposed write case by lowering to scalarized `memref.store`
  // operations.
  if (!op.getPermutationMap().isMinorIdentity())
    return false;
  // Currently we can't support writes to tensor types because we need stride
  // information to ensure correctness of downstream assumptions.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the target memref is contiguous. Note
  // that it is possible to expand support for this by scalarizing all the
  // stores during conversion.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}
326