//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <utility>

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::gpu;

/// Currently the distribution map is implicit based on the vector shape. In the
/// future it will be part of the op.
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   gpu.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type are different from the
  // dimensions of the distributed type to know the distributed dimensions.
  // Then associate each distributed dimension with an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                            distributedType.getContext());
  return map;
}

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing the rewrite in `WarpOpToScfIfPattern`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
  ///
  /// When the requested type is not distributed (i.e. it matches the
  /// sequential type), the load is not distributed, to account for the
  /// broadcast semantics of the `gpu.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
  ///   gpu.yield %cst : f32
  /// }
  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // Other cases must be vector atm.
    // Vector case must use vector::TransferReadOp which will later lower to
    // vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
///   - the IR within the scf.if op can be thought of as executing sequentially
///     (from the point of view of threads along `laneId`).
///   - the IR outside of the scf.if op can be thought of as executing in
///     parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
///   1. Create the scf.if op.
///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///      within the scf.if to transit the values captured from above.
///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///      consistent within the scf.if.
///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
///   5. Insert appropriate writes within scf.if and reads after the scf.if to
///      transit the values returned by the op.
///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///      consistent after the scf.if.
///   7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
struct WarpOpToScfIfPattern : public WarpDistributionPattern {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load distributed value from buffer, after the warpOp.
      rewriter.setInsertionPointAfter(ifOp);
      // When the result type and the yielded value type are the same, this is
      // a broadcast. E.g.:
      // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
      //   gpu.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a number of dimensions equal
/// to the original type rank and should be a projection where the results are
/// the distributed dimensions. The number of results should be equal to the
/// number of warp sizes, which is currently limited to 1.
/// Example: a vector<16x32x64> distributed with the map (d0, d1, d2) -> (d1)
/// and a warp size of 16 would distribute the second dimension (associated
/// with d1) and return vector<16x2x64>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0) {
        return VectorType();
      }
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1) {
    return VectorType();
  }
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract` elements
/// will not be distributed (this limit should be less than the warp size).
///
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   gpu.yield
/// }
/// ```
/// To
/// ```
/// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   gpu.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public WarpDistributionPattern {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. If the write is 0-D, bail out: `tryExtractOp` handles it by cloning
    // the write into a separate WarpExecuteOnLane0Op.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5 Compute the distributed type for the new mask.
    VectorType maskType;
    if (writeOp.getMask()) {
      // TODO: Distribution of masked writes with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element is captured strictly
      // by shape information on the warp op, and thus requires materializing
      // the permutation in IR.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector. To get the number of threads per vector dimension, divide the
    // sequential size by the distributed size along each dim.
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = rewriter
                         .create<mlir::affine::AffineDelinearizeIndexOp>(
                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
                             delinearizedIdSizes)
                         .getResults();
    } else {
      // If there is only one map result, we can elide the delinearization
      // op and use the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
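  /// The write must be small (at most `maxNumElementsToExtract` elements): the
  /// written vector is yielded from the original warp op and the write itself
  /// is cloned into a trailing warp op of its own.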
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
  /// execute op with the proper return type. The new write op is updated to
  /// write the result of the new warp execute op. The old `writeOp` is
  /// deleted.
  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                       WarpExecuteOnLane0Op warpOp,
                                       vector::TransferWriteOp writeOp,
                                       VectorType targetType,
                                       VectorType maybeMaskType) const {
    assert(writeOp->getParentOp() == warpOp &&
           "write must be nested immediately under warp");
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp;
    if (maybeMaskType) {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
          TypeRange{targetType, maybeMaskType}, newRetIndices);
    } else {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
          TypeRange{targetType}, newRetIndices);
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    if (maybeMaskType)
      newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
    return newWriteOp;
  }

  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   gpu.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///   vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///     vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};

/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   gpu.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///   vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read. Note that this pattern can
    // still fail at the end after distribution, in which case this might have
    // missed another distributable read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Distribute indices and the mask (if present).
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: Distribution of masked reads with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element is captured strictly
      // by shape information on the warp op, and thus requires materializing
      // the permutation in IR.
      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Distributed padding value was appended right after the indices.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    // Distributed mask value was added at the end (if the op has a mask).
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());

    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this into WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    //   1. recording the unique first position at which the value is yielded.
    //   2. recording for the result, the first position at which the dedup'ed
    //      value is yielded.
    //   3. skipping from the new result types / new yielded values any result
    //      that has no use or whose yielded value has already been seen.
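    // For example (hypothetical): if %v is yielded as results #0 and #2 while
    // result #1 is unused, only one copy of %v is kept; results #0 and #2 both
    // map to new result #0, and result #1 is dropped entirely.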
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};

// If an operand is directly yielded out of the region, we can forward it
// without it needing to go through the region.
struct WarpOpForwardOperand : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

struct WarpOpBroadcast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct
    // its own broadcast.
    // For that we simply check that the broadcast we want to build makes sense.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};

/// Pattern to move a shape cast out of the warp op. A shape cast is basically
/// a no-op for warp distribution; only the shape needs to be adjusted.
struct WarpOpShapeCast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // We expect the distributed type to have a smaller rank than the original
    // type. Prepend with size-one dimensions to make them the same.
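    // E.g. (hypothetical): a shape cast from vector<32x4xf32> to
    // vector<128xf32> whose result is distributed as vector<4xf32> over 32
    // lanes yields the source as vector<1x4xf32>, and the new shape cast
    // built below produces the distributed vector<4xf32>.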
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};

/// Sink out vector.create_mask op feeding into a warp op yield.
/// ```
/// %0 = ...
/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xi1>) {
///   ...
///   %mask = vector.create_mask %0 : vector<32xi1>
///   gpu.yield %mask : vector<32xi1>
/// }
/// ```
/// To
/// ```
/// %0 = ...
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %cmp = arith.cmpi ult, %laneid, %0
/// %ub = arith.select %cmp, %c0, %c1
/// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
struct WarpOpCreateMask : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();

    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    rewriter.setInsertionPointAfter(warpOp);

    // Delinearize the lane ID for constructing the distributed mask sizes.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);

    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
      // find the distance from the largest mask index owned by this lane to
      // the original mask size.
      // `vector.create_mask` implicitly clamps mask operands to the range
      // [0, mask_vector_size[i]]; in other words, the mask sizes are always
      // in the range [0, mask_vector_size[i]].
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Pattern to move out vector.extract from a 2-D or higher rank source vector.
/// 0-D and 1-D sources are handled by the WarpOpExtractScalar pattern instead.
struct WarpOpExtract : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // For 1-d or 0-d source cases, we rely on the WarpOpExtractScalar pattern.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the extract
      // out of the warp op.
      // TODO: This could be optimized. E.g., in case of a scalar result, let
      // one lane extract and shuffle the result to all other lanes (same as
      // the 1d case).
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from distributed vector.
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distributedDim == -1 && "found multiple distributed dims");
        distributedDim = i;
      }
    }
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield source vector from warp op.
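    // E.g. (hypothetical): extracting position [1] from a vector<4x96xf32>
    // source whose result is distributed as vector<3xf32> over 32 lanes: the
    // source is yielded as vector<4x3xf32> and each lane extracts its own
    // vector<3xf32> slice below.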
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};

/// Pattern to move out vector.extract with a scalar result.
/// Only supports 1-D and 0-D sources for now.
struct WarpOpExtractScalar : public WarpDistributionPattern {
  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                      PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    // Only supports 1-D or 0-D sources for now.
    if (extractSrcType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          extractOp, "only 0-D or 1-D source supported for now");
    }
    // TODO: Supported shuffle types should be parameterizable, similar to
    // `WarpShuffleFromIdxFn`.
    if (!extractSrcType.getElementType().isF32() &&
        !extractSrcType.getElementType().isInteger(32))
      return rewriter.notifyMatchFailure(
          extractOp, "only f32/i32 element types are supported");
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extract src rank is 0 or 1");
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
    // Yield source vector and position (if present) from warp op.
    SmallVector<Value> additionalResults{extractOp.getVector()};
    SmallVector<Type> additionalResultTypes{distributedVecType};
    additionalResults.append(
        SmallVector<Value>(extractOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));

    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0d extract: The new warp op broadcasts the source vector to all lanes.
    // All lanes extract the scalar.
    if (is0dOrVec1Extract) {
      Value newExtract;
      SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
      newExtract =
          rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    int64_t staticPos = extractOp.getStaticPosition()[0];
    OpFoldResult pos = ShapedType::isDynamic(staticPos)
                           ? (newWarpOp->getResult(newRetIndices[1]))
                           : OpFoldResult(rewriter.getIndexAttr(staticPos));
    // 1d extract: Distribute the source vector. One lane extracts and shuffles
    // the value to all other lanes.
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of extracting thread: pos / elementsPerLane
    Value broadcastFromTid = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Extract at position: pos % elementsPerLane
    Value newPos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : affine::makeComposedAffineApply(rewriter, loc,
                                              sym0 % elementsPerLane, pos);
    Value extracted =
        rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);

    // Shuffle the extracted value to all lanes.
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

private:
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};

/// Pattern to convert vector.extractelement to vector.extract.
struct WarpOpExtractElement : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
    if (!operand)
      return failure();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
    SmallVector<OpFoldResult> indices;
    if (auto pos = extractOp.getPosition()) {
      indices.push_back(pos);
    }
    rewriter.setInsertionPoint(extractOp);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
        extractOp, extractOp.getVector(), indices);
    return success();
  }
};

/// Pattern to move out vector.insert with a scalar input.
/// Only supports 1-D and 0-D destinations for now.
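/// Illustrative example (hypothetical shapes; assumes a warp size of 32):
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<1xf32>) {
///   ...
///   %1 = vector.insert %f, %v[7] : f32 into vector<32xf32>
///   gpu.yield %1 : vector<32xf32>
/// }
/// ```
/// becomes a warp op yielding the distributed destination (and %f), followed
/// by an scf.if in which only the lane owning position 7 performs the insert
/// into its vector<1xf32> slice.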
struct WarpOpInsertScalar : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());

    // Only supports 1-D or 0-D destinations for now.
    if (vecType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          insertOp, "only 0-D or 1-D destination supported for now");
    }

    // Yield destination vector, source scalar and position from warp op.
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getSource()};
    SmallVector<Type> additionalResultTypes{distrType,
                                            insertOp.getSource().getType()};
    additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));

    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);
    rewriter.setInsertionPointAfter(newWarpOp);

    OpFoldResult pos;
    if (vecType.getRank() != 0) {
      int64_t staticPos = insertOp.getStaticPosition()[0];
      pos = ShapedType::isDynamic(staticPos)
                ? (newWarpOp->getResult(newRetIndices[2]))
                : OpFoldResult(rewriter.getIndexAttr(staticPos));
    }

    // This condition is always true for 0-d vectors.
    if (vecType == distrType) {
      Value newInsert;
      SmallVector<OpFoldResult> indices;
      if (pos) {
        indices.push_back(pos);
      }
      newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
                                                    distributedVec, indices);
      // Broadcast: Simply move the vector.insert op out.
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // This is a distribution. Only one lane should insert.
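    // For example (hypothetical): a sequential destination vector<64xf32>
    // distributed as vector<2xf32> over 32 lanes with an insert at position 4:
    // elementsPerLane is 2, so the lane owning position 4 inserts at offset
    // 4 % 2 = 0 within its slice; all other lanes keep their slice unchanged.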
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of inserting thread: pos / elementsPerLane
    Value insertingLane = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Insert position: pos % elementsPerLane
    OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
        rewriter, loc, sym0 % elementsPerLane, pos);
    Value isInsertingLane = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        rewriter
            .create<scf::IfOp>(
                loc, isInsertingLane,
                /*thenBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  Value newInsert = builder.create<vector::InsertOp>(
                      loc, newSource, distributedVec, newPos);
                  builder.create<scf::YieldOp>(loc, newInsert);
                },
                /*elseBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  builder.create<scf::YieldOp>(loc, distributedVec);
                })
            .getResult(0);
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

struct WarpOpInsert : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // For 1-d or 0-d destination cases, we rely on the WarpOpInsertScalar
    // pattern.
    if (insertOp.getDestVectorType().getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional destination vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the insert
      // out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
          {insertOp.getSourceType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type.
    VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
    // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
    // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
    //         insert a smaller vector<3xf32>.
    // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
    //         case, one lane will insert the source vector<96xf32>. The other
    //         lanes will not do anything.
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield source and dest vectors from warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    // Insert into the distributed vector.
    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts a small piece.
      newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
    } else {
      // One lane inserts the entire source vector.
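      // Worked example (illustrative only, for Case 2 above with a warp size
      // of 32): elementsPerLane = 128 / 32 = 4, so the insert at [2] is
      // performed only by lane 2 / 4 = 0, at local position 2 % 4 = 2.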
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
      SmallVector<int64_t> newPos = getAsIntegers(pos);
      // tid of inserting lane: pos / elementsPerLane
      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
          loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
      // Insert position: pos % elementsPerLane
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = builder.create<vector::InsertOp>(
            loc, distributedSrc, distributedDest, newPos);
        builder.create<scf::YieldOp>(loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        builder.create<scf::YieldOp>(loc, distributedDest);
      };
      newResult = rewriter
                      .create<scf::IfOp>(loc, isInsertingLane,
                                         /*thenBuilder=*/insertingBuilder,
                                         /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }

    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

struct WarpOpInsertElement : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
    if (!operand)
      return failure();
    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
    SmallVector<OpFoldResult> indices;
    if (auto pos = insertOp.getPosition()) {
      indices.push_back(pos);
    }
    rewriter.setInsertionPoint(insertOp);
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        insertOp, insertOp.getSource(), insertOp.getDest(), indices);
    return success();
  }
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///       -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   gpu.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   gpu.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///     -> (vector<4xf32>) {
///   %iw = gpu.warp_execute_on_lane_0(%laneid)
///       args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       gpu.yield %ir : vector<128xf32>
///   }
///   scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public WarpDistributionPattern {

  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    // Collect Values that come from the warp op but are outside the forOp.
    // Those Values need to be returned by the original warpOp and passed to
    // the new op.
    llvm::SmallSetVector<Value, 32> escapingValues;
    SmallVector<Type> inputTypes;
    SmallVector<Type> distTypes;
    mlir::visitUsedValuesDefinedAbove(
        forOp.getBodyRegion(), [&](OpOperand *operand) {
          Operation *parent = operand->get().getParentRegion()->getParentOp();
          if (warpOp->isAncestor(parent)) {
            if (!escapingValues.insert(operand->get()))
              return;
            Type distType = operand->get().getType();
            if (auto vecType = dyn_cast<VectorType>(distType)) {
              AffineMap map = distributionMapFn(operand->get());
              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
            }
            inputTypes.push_back(operand->get().getType());
            distTypes.push_back(distType);
          }
        });

    if (llvm::is_contained(distTypes, Type{}))
      return failure();

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
        newRetIndices);
    yield = cast<gpu::YieldOp>(
        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = cast<OpResult>(yieldOperand.get());
      newOperands.push_back(
          newWarpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
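    // The new loop carries the distributed (per-lane) values as iter_args; the
    // inner warp op created below re-expands them to the original sequential
    // types through its block arguments.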
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPointToStart(newForOp.getBody());

    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
                                 newForOp.getRegionIterArgs().end());
    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
                                    forOp.getResultTypes().end());
    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
      warpInput.push_back(newWarpOp.getResult(retIdx));
      argIndexMapping[escapingValues[i]] = warpInputType.size();
      warpInputType.push_back(inputTypes[i]);
    }
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
        newWarpOp.getWarpSize(), warpInput, warpInputType);

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    argMapping.resize(forOp.getBody()->getNumArguments());
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPointToEnd(innerWarp.getBody());
    rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp result coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
                                  newForOp.getResult(res.index()));
      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
    }
    newForOp.walk([&](Operation *op) {
      for (OpOperand &operand : op->getOpOperands()) {
        auto it = argIndexMapping.find(operand.get());
        if (it == argIndexMapping.end())
          continue;
        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
      }
    });

    // Finally, hoist out any now uniform code from the inner warp op.
    mlir::vector::moveScalarUniformCode(innerWarp);
    return success();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector sizes that
/// are a multiple of the warp size. E.g.:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   gpu.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   gpu.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : f32 from vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public WarpDistributionPattern {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit),
        distributedReductionFn(std::move(distributedReductionFn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only vector sizes that are a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must be a multiple of the warp "
                  "size.");
    if (!reductionOp.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction distribution currently only supports floats and "
                  "integer types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
    // Return vector that will be reduced from the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Obtain data to reduce for a single lane.
    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    // Distribute and reduce across threads.
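    // The combining strategy is supplied by the caller via
    // distributedReductionFn; a common choice (an assumption here, not
    // mandated by this pattern) reduces the per-lane vector to a scalar and
    // then combines the lane-local scalars with log2(warpSize) butterfly
    // shuffles.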
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
                                    maxNumElementsToExtract, benefit);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
    PatternBenefit readBenefit) {
  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
      patterns.getContext(), benefit);
  patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                    benefit);
  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                               benefit);
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
                                benefit);
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
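  // For example, a scalar arith.constant or an index computation whose
  // operands are all defined above the warp op is uniform across lanes and can
  // be hoisted; any op producing a vector result is assumed distributed and is
  // left in place.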
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return isa<VectorType>(result.getType());
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}
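
// A minimal usage sketch (hypothetical driver code, not part of this file;
// the my* callbacks are placeholders that a pass would supply):
//
//   RewritePatternSet patterns(ctx);
//   vector::populatePropagateWarpVectorDistributionPatterns(
//       patterns, myDistributionMapFn, myWarpShuffleFromIdxFn,
//       /*benefit=*/1, /*readBenefit=*/0);
//   vector::populateDistributeReduction(patterns, myDistributedReductionFn,
//                                       /*benefit=*/1);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));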