//===- CreateAsyncGroups.cpp - Create async device copies -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

/// Return "true" if the given vector transfer op is contiguous and suitable
/// for replacement with an async copy.
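///
/// For illustration, a 1-D transfer of the following (schematic) form
/// satisfies these conditions: minor-identity permutation map, in-bounds
/// access, buffer semantics, and a unit stride in the innermost memref
/// dimension. Operand names here are hypothetical.
///
///   %v = vector.transfer_read %memref[%i, %j], %pad {in_bounds = [true]}
///       : memref<?x128xf32>, vector<4xf32>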
template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
  return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
         op.hasPureBufferSemantics() &&
         cast<MemRefType>(nvgpu::getMemrefOperand(op).getType())
             .isLastDimUnitStride();
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_write or vector.store op.
static bool isContiguousStore(Operation *write) {
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write))
    return isContiguousXferOp(transferWrite) && !transferWrite.getMask();
  // vector.store ops are always contiguous.
  return isa<vector::StoreOp>(write);
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_read or vector.load op.
static bool isContiguousRead(Operation *read) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(read))
    return isContiguousXferOp(transferRead);
  // vector.load ops are always contiguous.
  return isa<vector::LoadOp>(read);
}

namespace {
/// A vector.create_mask op and extract position.
struct TransferMask {
  vector::CreateMaskOp createMaskOp;
  SmallVector<int64_t> extractPosition;
};
} // namespace

/// If the given vector load op has a mask that is defined by a
/// vector.create_mask op (possibly extracted via vector.extract), return that
/// op together with the extract position. Return "failure" for unsupported
/// mask patterns.
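///
/// The two supported patterns, shown schematically with hypothetical value
/// names and types:
///
///   %mask = vector.create_mask %sz : vector<4xi1>                    (case 1)
///
///   %m    = vector.create_mask %sz0, %sz1 : vector<2x4xi1>
///   %mask = vector.extract %m[0] : vector<4xi1> from vector<2x4xi1>  (case 2)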
static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
  if (!transferRead || !transferRead.getMask())
    return TransferMask{{}, {}};
  assert(transferRead.getMask().getType().getRank() == 1 &&
         "expected 1-D mask");

  // Case 1: Mask is the result of a vector.create_mask.
  if (auto maskOp =
          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
    return TransferMask{maskOp, {}};

  // Case 2: Mask is the result of a vector.extract(vector.create_mask).
  if (auto extractOp =
          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
    if (auto maskOp =
            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
      return TransferMask{maskOp,
                          SmallVector<int64_t>(extractOp.getStaticPosition())};

  // All other cases: not supported.
  return failure();
}

/// Build an SSA value that represents the number of read elements.
static Value buildNumReadElements(OpBuilder &b, Location loc,
                                  Operation *readOp) {
  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
  assert(succeeded(transferMask) && "invalid transfer mask");

  // No mask => no num_read_elements.
  if (!transferMask->createMaskOp)
    return Value();

  // No extract: return size of "ones" segment in the mask.
  if (transferMask->extractPosition.empty()) {
    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
           "expected single operand");
    return transferMask->createMaskOp.getOperand(0);
  }

  // vector.extract(vector.create_mask).
  // If extract_pos < num_ones, take number of elements from the least
  // significant dimension. (Do this for all dimensions and bit-AND the
  // conditions.)
  assert(transferMask->createMaskOp.getVectorType().getRank() -
                 transferMask->extractPosition.size() ==
             1 &&
         "expected N-D -> (N-1)-D extract");
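  // For example (hypothetical values): with a static extract position [p] and
  // create_mask operands (%s0, %s1), the value built below is
  //   select(p < %s0, %s1, 0)
  // i.e. the full row length %s1 if row p lies within the "ones" rows of the
  // 2-D mask, and 0 otherwise.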
  Value cond;
  // Note: There is one more `sz` than `pos`. The loop ends with the last
  // `pos`.
  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
                                  transferMask->createMaskOp->getOperands())) {
    Value cmp =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                b.create<arith::ConstantIndexOp>(loc, pos), sz);
    if (!cond) {
      cond = cmp;
      continue;
    }
    cond = b.create<arith::AndIOp>(loc, cmp, cond);
  }
  return b.create<arith::SelectOp>(
      loc, cond, transferMask->createMaskOp->getOperands().back(),
      b.create<arith::ConstantIndexOp>(loc, 0));
}

/// Return "true" if an access of the given vector type into the given memref
/// type can be lowered to a supported async copy (cp.async).
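/// For example, a vector<4xf32> (16 bytes) or vector<2xf16> (4 bytes) access
/// has a supported copy size, whereas a vector<3xf32> (12 bytes) access does
/// not.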
static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
                                        VectorType vecType) {
  assert(vecType.getRank() == 1 && "expected 1-D vector");
  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};

  // Condition 1: the copy size must be supported.
  bool supportedCopySize = false;
  int64_t numElements = vecType.getNumElements();
  Type elementType = vecType.getElementType();
  for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) {
    if (alignmentInBytes * 8 ==
        numElements * elementType.getIntOrFloatBitWidth()) {
      supportedCopySize = true;
      break;
    }
  }
  if (!supportedCopySize)
    return false;

  // TODO: Condition 2: the alignments must be supported. For cp.async the
  // NVIDIA doc (section 6.4.1) says: "The address must be naturally aligned to
  // a multiple of the access size. If an address is not properly aligned, the
  // resulting behavior is undefined.".
  return true;
}

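/// A schematic example of the rewrite performed below (operand names and
/// types are hypothetical, shapes abbreviated): a contiguous 1-D copy from
/// global to shared memory such as
///
///   %v = vector.transfer_read %global[%i], %pad {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>
///   vector.transfer_write %v, %shared[%j] {in_bounds = [true]}
///       : vector<4xf32>, memref<1024xf32, 3>
///
/// is rewritten into an async copy group of the form
///
///   %t = nvgpu.device_async_copy %global[%i], %shared[%j], 4
///       : memref<?xf32> to memref<1024xf32, 3>
///   %g = nvgpu.device_async_create_group %t
///   nvgpu.device_async_wait %g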
void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
                              bool bypassL1) {
  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;

  // Look for all the copies that can be converted to async copy ops.
  op->walk([&](Operation *writeOp) {
    // Look for contiguous 1D vector store into shared memory.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
    if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(storeBase.getType())))
      return;

    // The stored vector must originate from a contiguous 1D vector load.
    Operation *readOp = vectorVal.getDefiningOp();
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
    // Should be reading from global memory (not shared memory).
    if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(loadBase.getType())))
      return;

    // Look for a compatible mask and padding. Only a constant zero padding is
    // supported: the async copy zero-fills the elements beyond the mask.
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
      if (Value mask = transferRead.getMask()) {
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
          return;
        if (failed(getMaskOp(readOp)))
          return;
      }
    }

    // Check whether both accesses are supported before we emit: this is
    // necessary to ensure the correctness of DeviceAsyncCopyOp.
    VectorType vecType = cast<VectorType>(vectorVal.getType());

    if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
                                     vecType) ||
        !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
                                     vecType))
      return;

    copyToSharedMem.insert(writeOp);
    return;
  });

  while (!copyToSharedMem.empty()) {
    // Start a group with the first write.
    SmallVector<Operation *> group;
    Operation *writeOp = *copyToSharedMem.begin();
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
    Operation *nextNode = writeOp;

    // Look in the next nodes for more copies to add to the same group.
    while ((nextNode = nextNode->getNextNode())) {
      // Ignore ops without side effects.
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
      if (memInterface && memInterface.hasNoEffect() &&
          !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
        continue;
      // Ignore reads from other address spaces: they cannot depend on the
      // copies into shared memory.
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Operation *readOp = nextNode;
        Value memrefOperand = nvgpu::getMemrefOperand(readOp);
        if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
                cast<MemRefType>(memrefOperand.getType()))) {
          continue;
        }
      }
      if (copyToSharedMem.count(nextNode)) {
        // Found another copy, add it to the group.
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
        continue;
      }
      // If the op is anything else, stop accumulating ops into the group.
      break;
    }

    // Emit the group.
    SmallVector<Value> tokens;
    for (Operation *writeOp : group) {
      rewriter.setInsertionPoint(writeOp);
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
      auto dstMemref = cast<MemRefType>(storeBase.getType());
      int64_t sizeInBytes =
          (dstMemref.getElementTypeBitWidth() * numElements) / 8;
      // Bypassing L1 is only possible for 16-byte transfers.
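      // For example, a vector<4xf32> copy (4 x 4 bytes = 16 bytes) may set
      // bypassL1, whereas a vector<2xf32> copy (8 bytes) may not.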
      Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
          writeOp->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
          /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
          /*src=*/loadBase,
          /*srcIndices=*/nvgpu::getIndices(readOp),
          /*dstElements=*/rewriter.getIndexAttr(numElements),
          /*srcElements=*/numReadElements,
          /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
                                                     : UnitAttr());
      tokens.push_back(token);
    }

    // Create the group and wait for it right after.
    Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>(
        op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
        tokens);
    rewriter.create<nvgpu::DeviceAsyncWaitOp>(op->getLoc(), groupToken,
                                              nullptr);
    // Clean up old stores.
    for (Operation *writeOp : group)
      rewriter.eraseOp(writeOp);
  }
}