//===- CreateAsyncGroups.cpp - Create async device copies -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

/// Return "true" if the given vector transfer op is contiguous and suitable
/// for replacement with an async copy.
template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
  return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
         op.hasPureBufferSemantics() &&
         cast<MemRefType>(nvgpu::getMemrefOperand(op).getType())
             .isLastDimUnitStride();
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_write or vector.store op.
static bool isContiguousStore(Operation *write) {
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write))
    return isContiguousXferOp(transferWrite) && !transferWrite.getMask();
  // vector.store ops are always contiguous.
  return isa<vector::StoreOp>(write);
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_read or vector.load op.
static bool isContiguousRead(Operation *read) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(read))
    return isContiguousXferOp(transferRead);
  // vector.load ops are always contiguous.
  return isa<vector::LoadOp>(read);
}

namespace {
/// A vector.create_mask op and extract position.
struct TransferMask {
  vector::CreateMaskOp createMaskOp;
  SmallVector<int64_t> extractPosition;
};
} // namespace

/// If the given vector load op has a mask that is defined by
/// vector.create_mask (possibly behind a vector.extract), return that op.
/// Return an empty TransferMask if the op has no mask and failure if the mask
/// is not supported.
static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
  if (!transferRead || !transferRead.getMask())
    return TransferMask{{}, {}};
  assert(transferRead.getMask().getType().getRank() == 1 &&
         "expected 1-D mask");

  // Case 1: Mask is the result of a vector.create_mask.
  if (auto maskOp =
          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
    return TransferMask{maskOp, {}};

  // Case 2: Mask is the result of a vector.extract(vector.create_mask).
  if (auto extractOp =
          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
    if (auto maskOp =
            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
      return TransferMask{maskOp,
                          SmallVector<int64_t>(extractOp.getStaticPosition())};

  // All other cases: not supported.
  return failure();
}

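// Illustrative IR for the two supported mask cases above (a sketch; the exact
// vector dialect assembly syntax is abbreviated):
//
//   // Case 1: 1-D mask produced directly by vector.create_mask.
//   %m = vector.create_mask %sz : vector<4xi1>
//   %v = vector.transfer_read %src[%i], %pad, %m ... : memref<?xf32>, vector<4xf32>
//
//   // Case 2: 1-D mask extracted from an N-D vector.create_mask.
//   %m2  = vector.create_mask %d0, %sz : vector<2x4xi1>
//   %row = vector.extract %m2[0] : vector<4xi1> from vector<2x4xi1>
//   %v2  = vector.transfer_read %src[%i], %pad, %row ... : memref<?xf32>, vector<4xf32>
//
// In Case 2, getMaskOp returns the vector.create_mask op together with the
// extract position [0].
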
/// Build an SSA value that represents the number of read elements.
static Value buildNumReadElements(OpBuilder &b, Location loc,
                                  Operation *readOp) {
  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
  assert(succeeded(transferMask) && "invalid transfer mask");

  // No mask => no num_read_elements.
  if (!transferMask->createMaskOp)
    return Value();

  // No extract: return size of "ones" segment in the mask.
  if (transferMask->extractPosition.empty()) {
    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
           "expected single operand");
    return transferMask->createMaskOp.getOperand(0);
  }

  // vector.extract(vector.create_mask).
  // If extract_pos < num_ones, take number of elements from the least
  // significant dimension. (Do this for all dimensions and bit-AND the
  // conditions.)
  assert(transferMask->createMaskOp.getVectorType().getRank() -
                 transferMask->extractPosition.size() ==
             1 &&
         "expected N-D -> (N-1)-D extract");
  Value cond;
  // Note: There is one more `sz` than `pos`. The loop ends with the last `pos`.
  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
                                  transferMask->createMaskOp->getOperands())) {
    Value cmp =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                b.create<arith::ConstantIndexOp>(loc, pos), sz);
    if (!cond) {
      cond = cmp;
      continue;
    }
    cond = b.create<arith::AndIOp>(loc, cmp, cond);
  }
  return b.create<arith::SelectOp>(
      loc, cond, transferMask->createMaskOp->getOperands().back(),
      b.create<arith::ConstantIndexOp>(loc, 0));
}

/// Return "true" if lowering the given copy (described by the memref and
/// vector types) to an async copy is supported.
static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
                                        VectorType vecType) {
  assert(vecType.getRank() == 1 && "expected 1-D vector");
  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};

  // Condition 1: the copy size must be supported.
  bool supportedCopySize = false;
  int64_t numElements = vecType.getNumElements();
  Type elementType = vecType.getElementType();
  for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) {
    if (alignmentInBytes * 8 ==
        numElements * elementType.getIntOrFloatBitWidth()) {
      supportedCopySize = true;
      break;
    }
  }
  if (!supportedCopySize)
    return false;

  // TODO: Condition 2: the alignments must be supported. For cp.async the
  // NVIDIA doc (section 6.4.1) says: "The address must be naturally aligned to
  // a multiple of the access size. If an address is not properly aligned, the
  // resulting behavior is undefined.".
  return true;
}

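// Worked example for the size check above (illustrative only): a 1-D
// vector<4xf32> copy is 4 * 32 bits = 128 bits = 16 bytes and is accepted;
// vector<2xf16> is 4 bytes and is also accepted; vector<3xf32> is 12 bytes,
// which matches none of {4, 8, 16} and is therefore rejected.
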
void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
                              bool bypassL1) {
  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;

  // Look for all the copies that can be converted to async copy ops.
  op->walk([&](Operation *writeOp) {
    // Look for a contiguous 1-D vector store into shared memory.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
    if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(storeBase.getType())))
      return;

    // The stored vector must originate from a contiguous 1-D vector load.
    Operation *readOp = vectorVal.getDefiningOp();
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
    // Should be reading from global memory (not shared memory).
    if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(loadBase.getType())))
      return;

    // Look for compatible mask and padding.
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
      if (Value mask = transferRead.getMask()) {
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
          return;
        if (failed(getMaskOp(readOp)))
          return;
      }
    }

    // Check whether both accesses are supported before we emit: this is
    // necessary to ensure the correctness of DeviceAsyncCopyOp.
    VectorType vecType = cast<VectorType>(vectorVal.getType());

    if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
                                     vecType) ||
        !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
                                     vecType))
      return;

    copyToSharedMem.insert(writeOp);
    return;
  });

  while (!copyToSharedMem.empty()) {
    // Start a group with the first write.
    SmallVector<Operation *> group;
    Operation *writeOp = *copyToSharedMem.begin();
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
    Operation *nextNode = writeOp;

    // Look in the next nodes for more copies to add to the same group.
    while ((nextNode = nextNode->getNextNode())) {
      // Ignore ops without side effects.
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
      if (memInterface && memInterface.hasNoEffect() &&
          !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
        continue;
      // Ignore reads from a different address space.
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Operation *readOp = nextNode;
        Value memrefOperand = nvgpu::getMemrefOperand(readOp);
        if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
                cast<MemRefType>(memrefOperand.getType()))) {
          continue;
        }
      }
      if (copyToSharedMem.count(nextNode)) {
        // Found another copy, add it to the group.
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
        continue;
      }
      // If the op is anything else, stop accumulating ops into the group.
      break;
    }

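    // Schematic result of the emission below (nvgpu assembly syntax
    // abbreviated): every copy in `group` becomes one nvgpu.device_async_copy,
    // and the group is closed by a single device_async_create_group followed
    // immediately by a device_async_wait:
    //
    //   %t0 = nvgpu.device_async_copy %global[...], %shared[...], <N>
    //   %t1 = nvgpu.device_async_copy %global[...], %shared[...], <N>
    //   %g  = nvgpu.device_async_create_group %t0, %t1
    //   nvgpu.device_async_wait %g
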
    // Emit the group.
    SmallVector<Value> tokens;
    for (Operation *writeOp : group) {
      rewriter.setInsertionPoint(writeOp);
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
      auto dstMemref = cast<MemRefType>(storeBase.getType());
      int64_t sizeInBytes =
          (dstMemref.getElementTypeBitWidth() * numElements) / 8;
      // bypass_l1 is only possible with 16-byte transfers.
      Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
          writeOp->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
          /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
          /*src=*/loadBase,
          /*srcIndices=*/nvgpu::getIndices(readOp),
          /*dstElements=*/rewriter.getIndexAttr(numElements),
          /*srcElements=*/numReadElements,
          /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
                                                     : UnitAttr());
      tokens.push_back(token);
    }

    // Create the group and wait for it right after.
    Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>(
        op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
        tokens);
    rewriter.create<nvgpu::DeviceAsyncWaitOp>(op->getLoc(), groupToken,
                                              nullptr);
    // Clean up old stores.
    for (Operation *writeOp : group)
      rewriter.eraseOp(writeOp);
  }
}

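// Usage sketch (hypothetical caller, not part of this transform): a pass would
// typically run the rewrite over a function body, e.g.:
//
//   IRRewriter rewriter(funcOp->getContext());
//   nvgpu::createAsyncGroups(rewriter, funcOp, /*bypassL1=*/true);
//
// where `funcOp` is the operation whose body should be scanned for copies.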