1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/IR/AffineOps.h" 10 #include "mlir/Dialect/Arith/IR/Arith.h" 11 #include "mlir/Dialect/MemRef/IR/MemRef.h" 12 #include "mlir/Dialect/SCF/IR/SCF.h" 13 #include "mlir/Dialect/Vector/IR/VectorOps.h" 14 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" 15 #include "mlir/IR/AffineExpr.h" 16 #include "mlir/Interfaces/SideEffectInterfaces.h" 17 #include "mlir/Transforms/RegionUtils.h" 18 #include "llvm/ADT/SetVector.h" 19 #include <numeric> 20 #include <utility> 21 22 using namespace mlir; 23 using namespace mlir::vector; 24 25 /// Currently the distribution map is implicit based on the vector shape. In the 26 /// future it will be part of the op. 27 /// Example: 28 /// ``` 29 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) { 30 /// ... 31 /// vector.yield %3 : vector<32x16x64xf32> 32 /// } 33 /// ``` 34 /// Would have an implicit map of: 35 /// `(d0, d1, d2) -> (d0, d2)` 36 static AffineMap calculateImplicitMap(VectorType sequentialType, 37 VectorType distributedType) { 38 SmallVector<AffineExpr> perm; 39 perm.reserve(1); 40 // Check which dimensions of the sequential type are different than the 41 // dimensions of the distributed type to know the distributed dimensions. Then 42 // associate each distributed dimension to an ID in order. 43 for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) { 44 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i)) 45 perm.push_back(getAffineDimExpr(i, distributedType.getContext())); 46 } 47 auto map = AffineMap::get(sequentialType.getRank(), 0, perm, 48 distributedType.getContext()); 49 return map; 50 } 51 52 namespace { 53 54 /// Helper struct to create the load / store operations that permit transit 55 /// through the parallel / sequential and the sequential / parallel boundaries 56 /// when performing `rewriteWarpOpToScfFor`. 57 /// 58 /// The vector distribution dimension is inferred from the vector types. 59 struct DistributedLoadStoreHelper { 60 DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal, 61 Value laneId, Value zero) 62 : sequentialVal(sequentialVal), distributedVal(distributedVal), 63 laneId(laneId), zero(zero) { 64 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType()); 65 distributedVectorType = dyn_cast<VectorType>(distributedVal.getType()); 66 if (sequentialVectorType && distributedVectorType) 67 distributionMap = 68 calculateImplicitMap(sequentialVectorType, distributedVectorType); 69 } 70 71 Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) { 72 int64_t distributedSize = distributedVectorType.getDimSize(index); 73 AffineExpr tid = getAffineSymbolExpr(0, b.getContext()); 74 return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize, 75 ArrayRef<Value>{laneId}); 76 } 77 78 /// Create a store during the process of distributing the 79 /// `vector.warp_execute_on_thread_0` op. 80 /// Vector distribution assumes the following convention regarding the 81 /// temporary buffers that are created to transition values. This **must** 82 /// be properly specified in the `options.warpAllocationFn`: 83 /// 1. 
scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `vector.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
  ///
  /// If the requested type is not distributed (e.g. a scalar, or the
  /// sequential vector type itself), the load is not distributed, to account
  /// for the broadcast semantics of the `vector.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  /// %r = vector.warp_execute_on_lane_0(...) -> (f32) {
  ///   vector.yield %cst : f32
  /// }
  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // Other cases must be vector atm.
    // Vector case must use vector::TransferReadOp which will later lower to
    // vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
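/// The body of `warpOp` is moved into the new op and the terminator operands
/// are replaced by `newYieldedValues`; the original op is left in place, so
/// callers are expected to replace or erase it afterwards (as
/// `moveRegionToNewWarpOpAndAppendReturns` below does).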
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` is populated with the position of each requested output in the
/// new op's result list.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
                                              yield.getOperands().end());
  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
    if (yieldValues.insert(std::get<0>(newRet))) {
      types.push_back(std::get<1>(newRet));
      indices.push_back(yieldValues.size() - 1);
    } else {
      // If the value already exits the region don't create a new output.
      for (auto [idx, yieldOperand] :
           llvm::enumerate(yieldValues.getArrayRef())) {
        if (yieldOperand == std::get<0>(newRet)) {
          indices.push_back(idx);
          break;
        }
      }
    }
  }
  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues.getArrayRef(), types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}

/// Return the yield operand of `warpOp` whose defining op satisfies the filter
/// lambda condition and whose corresponding warp op result is not dead.
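/// For instance, the elementwise pattern below locates a yielded value defined
/// by an elementwise op as follows (a sketch of that existing usage):
/// ```
/// OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
///   return OpTrait::hasElementwiseMappableTraits(op);
/// });
/// ```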
233 static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp, 234 const std::function<bool(Operation *)> &fn) { 235 auto yield = cast<vector::YieldOp>( 236 warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); 237 for (OpOperand &yieldOperand : yield->getOpOperands()) { 238 Value yieldValues = yieldOperand.get(); 239 Operation *definedOp = yieldValues.getDefiningOp(); 240 if (definedOp && fn(definedOp)) { 241 if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty()) 242 return &yieldOperand; 243 } 244 } 245 return {}; 246 } 247 248 // Clones `op` into a new operation that takes `operands` and returns 249 // `resultTypes`. 250 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter, 251 Location loc, Operation *op, 252 ArrayRef<Value> operands, 253 ArrayRef<Type> resultTypes) { 254 OperationState res(loc, op->getName().getStringRef(), operands, resultTypes, 255 op->getAttrs()); 256 return rewriter.create(res); 257 } 258 259 namespace { 260 261 /// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single 262 /// thread `laneId` executes the entirety of the computation. 263 /// 264 /// After the transformation: 265 /// - the IR within the scf.if op can be thought of as executing sequentially 266 /// (from the point of view of threads along `laneId`). 267 /// - the IR outside of the scf.if op can be thought of as executing in 268 /// parallel (from the point of view of threads along `laneId`). 269 /// 270 /// Values that need to transit through the parallel / sequential and the 271 /// sequential / parallel boundaries do so via reads and writes to a temporary 272 /// memory location. 273 /// 274 /// The transformation proceeds in multiple steps: 275 /// 1. Create the scf.if op. 276 /// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads 277 /// within the scf.if to transit the values captured from above. 278 /// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are 279 /// consistent within the scf.if. 280 /// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if. 281 /// 5. Insert appropriate writes within scf.if and reads after the scf.if to 282 /// transit the values returned by the op. 283 /// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are 284 /// consistent after the scf.if. 285 /// 7. Perform late cleanups. 286 /// 287 /// All this assumes the vector distribution occurs along the most minor 288 /// distributed vector dimension. 289 struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> { 290 WarpOpToScfIfPattern(MLIRContext *context, 291 const WarpExecuteOnLane0LoweringOptions &options, 292 PatternBenefit benefit = 1) 293 : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit), 294 options(options) {} 295 296 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 297 PatternRewriter &rewriter) const override { 298 assert(warpOp.getBodyRegion().hasOneBlock() && 299 "expected WarpOp with single block"); 300 Block *warpOpBody = &warpOp.getBodyRegion().front(); 301 Location loc = warpOp.getLoc(); 302 303 // Passed all checks. Start rewriting. 304 OpBuilder::InsertionGuard g(rewriter); 305 rewriter.setInsertionPoint(warpOp); 306 307 // Step 1: Create scf.if op. 
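    // The predicate compares the lane id against 0, so the generated IR is
    // roughly (a sketch, names are illustrative):
    //   %c0 = arith.constant 0 : index
    //   %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
    //   scf.if %is_lane_0 { ... }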
308 Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0); 309 Value isLane0 = rewriter.create<arith::CmpIOp>( 310 loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0); 311 auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0, 312 /*withElseRegion=*/false); 313 rewriter.eraseOp(ifOp.thenBlock()->getTerminator()); 314 315 // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and 316 // reads within the scf.if to transit the values captured from above. 317 SmallVector<Value> bbArgReplacements; 318 for (const auto &it : llvm::enumerate(warpOp.getArgs())) { 319 Value sequentialVal = warpOpBody->getArgument(it.index()); 320 Value distributedVal = it.value(); 321 DistributedLoadStoreHelper helper(sequentialVal, distributedVal, 322 warpOp.getLaneid(), c0); 323 324 // Create buffer before the ifOp. 325 rewriter.setInsertionPoint(ifOp); 326 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp, 327 sequentialVal.getType()); 328 // Store distributed vector into buffer, before the ifOp. 329 helper.buildStore(rewriter, loc, distributedVal, buffer); 330 // Load sequential vector from buffer, inside the ifOp. 331 rewriter.setInsertionPointToStart(ifOp.thenBlock()); 332 bbArgReplacements.push_back( 333 helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer)); 334 } 335 336 // Step 3. Insert sync after all the stores and before all the loads. 337 if (!warpOp.getArgs().empty()) { 338 rewriter.setInsertionPoint(ifOp); 339 options.warpSyncronizationFn(loc, rewriter, warpOp); 340 } 341 342 // Step 4. Move body of warpOp to ifOp. 343 rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements); 344 345 // Step 5. Insert appropriate writes within scf.if and reads after the 346 // scf.if to transit the values returned by the op. 347 // TODO: at this point, we can reuse the shared memory from previous 348 // buffers. 349 SmallVector<Value> replacements; 350 auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator()); 351 Location yieldLoc = yieldOp.getLoc(); 352 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { 353 Value sequentialVal = it.value(); 354 Value distributedVal = warpOp->getResult(it.index()); 355 DistributedLoadStoreHelper helper(sequentialVal, distributedVal, 356 warpOp.getLaneid(), c0); 357 358 // Create buffer before the ifOp. 359 rewriter.setInsertionPoint(ifOp); 360 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp, 361 sequentialVal.getType()); 362 363 // Store yielded value into buffer, inside the ifOp, before the 364 // terminator. 365 rewriter.setInsertionPoint(yieldOp); 366 helper.buildStore(rewriter, loc, sequentialVal, buffer); 367 368 // Load distributed value from buffer, after the warpOp. 369 rewriter.setInsertionPointAfter(ifOp); 370 // Result type and yielded value type are the same. This is a broadcast. 371 // E.g.: 372 // %r = vector.warp_execute_on_lane_0(...) -> (f32) { 373 // vector.yield %cst : f32 374 // } 375 // Both types are f32. The constant %cst is broadcasted to all lanes. 376 // This is described in more detail in the documentation of the op. 377 replacements.push_back( 378 helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer)); 379 } 380 381 // Step 6. Insert sync after all the stores and before all the loads. 382 if (!yieldOp.getOperands().empty()) { 383 rewriter.setInsertionPointAfter(ifOp); 384 options.warpSyncronizationFn(loc, rewriter, warpOp); 385 } 386 387 // Step 7. Delete terminator and add empty scf.yield. 
388 rewriter.eraseOp(yieldOp); 389 rewriter.setInsertionPointToEnd(ifOp.thenBlock()); 390 rewriter.create<scf::YieldOp>(yieldLoc); 391 392 // Compute replacements for WarpOp results. 393 rewriter.replaceOp(warpOp, replacements); 394 395 return success(); 396 } 397 398 private: 399 const WarpExecuteOnLane0LoweringOptions &options; 400 }; 401 402 /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute 403 /// op with the proper return type. 404 /// The new write op is updated to write the result of the new warp execute op. 405 /// The old `writeOp` is deleted. 406 static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter, 407 WarpExecuteOnLane0Op warpOp, 408 vector::TransferWriteOp writeOp, 409 VectorType targetType, 410 VectorType maybeMaskType) { 411 assert(writeOp->getParentOp() == warpOp && 412 "write must be nested immediately under warp"); 413 OpBuilder::InsertionGuard g(rewriter); 414 SmallVector<size_t> newRetIndices; 415 WarpExecuteOnLane0Op newWarpOp; 416 if (maybeMaskType) { 417 newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 418 rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()}, 419 TypeRange{targetType, maybeMaskType}, newRetIndices); 420 } else { 421 newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 422 rewriter, warpOp, ValueRange{{writeOp.getVector()}}, 423 TypeRange{targetType}, newRetIndices); 424 } 425 rewriter.setInsertionPointAfter(newWarpOp); 426 auto newWriteOp = 427 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation())); 428 rewriter.eraseOp(writeOp); 429 newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0])); 430 if (maybeMaskType) 431 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1])); 432 return newWriteOp; 433 } 434 435 /// Return the distributed vector type based on the original type and the 436 /// distribution map. The map is expected to have a dimension equal to the 437 /// original type rank and should be a projection where the results are the 438 /// distributed dimensions. The number of results should be equal to the number 439 /// of warp sizes which is currently limited to 1. 440 /// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1) 441 /// and a warp size of 16 would distribute the second dimension (associated to 442 /// d1) and return vector<16x2x64> 443 static VectorType getDistributedType(VectorType originalType, AffineMap map, 444 int64_t warpSize) { 445 if (map.getNumResults() != 1) 446 return VectorType(); 447 SmallVector<int64_t> targetShape(originalType.getShape().begin(), 448 originalType.getShape().end()); 449 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { 450 unsigned position = map.getDimPosition(i); 451 if (targetShape[position] % warpSize != 0) 452 return VectorType(); 453 targetShape[position] = targetShape[position] / warpSize; 454 } 455 VectorType targetType = 456 VectorType::get(targetShape, originalType.getElementType()); 457 return targetType; 458 } 459 460 /// Distribute transfer_write ops based on the affine map returned by 461 /// `distributionMapFn`. 462 /// Example: 463 /// ``` 464 /// %0 = vector.warp_execute_on_lane_0(%id){ 465 /// ... 466 /// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32> 467 /// vector.yield 468 /// } 469 /// ``` 470 /// To 471 /// ``` 472 /// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) { 473 /// ... 
474 /// vector.yield %v : vector<32xf32> 475 /// } 476 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32> 477 struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> { 478 WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn, 479 PatternBenefit b = 1) 480 : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b), 481 distributionMapFn(std::move(fn)) {} 482 483 /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that 484 /// are multiples of the distribution ratio are supported at the moment. 485 LogicalResult tryDistributeOp(RewriterBase &rewriter, 486 vector::TransferWriteOp writeOp, 487 WarpExecuteOnLane0Op warpOp) const { 488 VectorType writtenVectorType = writeOp.getVectorType(); 489 490 // 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op 491 // to separate it from the rest. 492 if (writtenVectorType.getRank() == 0) 493 return failure(); 494 495 // 2. Compute the distributed type. 496 AffineMap map = distributionMapFn(writeOp.getVector()); 497 VectorType targetType = 498 getDistributedType(writtenVectorType, map, warpOp.getWarpSize()); 499 if (!targetType) 500 return failure(); 501 502 // 2.5 Compute the distributed type for the new mask; 503 VectorType maskType; 504 if (writeOp.getMask()) { 505 // TODO: Distribution of masked writes with non-trivial permutation maps 506 // requires the distribution of the mask to elementwise match the 507 // distribution of the permuted written vector. Currently the details 508 // of which lane is responsible for which element is captured strictly 509 // by shape information on the warp op, and thus requires materializing 510 // the permutation in IR. 511 if (!writeOp.getPermutationMap().isMinorIdentity()) 512 return failure(); 513 maskType = 514 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize()); 515 } 516 517 // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from 518 // the rest. 519 vector::TransferWriteOp newWriteOp = 520 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType); 521 522 // 4. Reindex the write using the distribution map. 523 auto newWarpOp = 524 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>(); 525 rewriter.setInsertionPoint(newWriteOp); 526 AffineMap indexMap = map.compose(newWriteOp.getPermutationMap()); 527 Location loc = newWriteOp.getLoc(); 528 SmallVector<Value> indices(newWriteOp.getIndices().begin(), 529 newWriteOp.getIndices().end()); 530 for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) { 531 AffineExpr d0, d1; 532 bindDims(newWarpOp.getContext(), d0, d1); 533 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it)); 534 if (!indexExpr) 535 continue; 536 unsigned indexPos = indexExpr.getPosition(); 537 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition(); 538 auto scale = 539 rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos)); 540 indices[indexPos] = affine::makeComposedAffineApply( 541 rewriter, loc, d0 + scale * d1, 542 {indices[indexPos], newWarpOp.getLaneid()}); 543 } 544 newWriteOp.getIndicesMutable().assign(indices); 545 546 return success(); 547 } 548 549 /// Extract TransferWriteOps of vector<1x> into a separate warp op. 
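  /// For example (an illustrative sketch of the extraction):
  /// ```
  /// vector.warp_execute_on_lane_0(%id){
  ///   ...
  ///   vector.transfer_write %v, %A[%c0] : vector<1xf32>, memref<128xf32>
  ///   vector.yield
  /// }
  /// ```
  /// To
  /// ```
  /// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
  ///   ...
  ///   vector.yield %v : vector<1xf32>
  /// }
  /// vector.warp_execute_on_lane_0(%id){
  ///   vector.transfer_write %r, %A[%c0] : vector<1xf32>, memref<128xf32>
  /// }
  /// ```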
550 LogicalResult tryExtractOp(RewriterBase &rewriter, 551 vector::TransferWriteOp writeOp, 552 WarpExecuteOnLane0Op warpOp) const { 553 Location loc = writeOp.getLoc(); 554 VectorType vecType = writeOp.getVectorType(); 555 556 // Only sink out vector of 1 element for now to not serialize large vector 557 // store. This can later be controlled by user. 558 if (vecType.getNumElements() != 1) 559 return failure(); 560 561 // Do not process warp ops that contain only TransferWriteOps. 562 if (llvm::all_of(warpOp.getOps(), [](Operation &op) { 563 return isa<vector::TransferWriteOp, vector::YieldOp>(&op); 564 })) 565 return failure(); 566 567 SmallVector<Value> yieldValues = {writeOp.getVector()}; 568 SmallVector<Type> retTypes = {vecType}; 569 SmallVector<size_t> newRetIndices; 570 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 571 rewriter, warpOp, yieldValues, retTypes, newRetIndices); 572 rewriter.setInsertionPointAfter(newWarpOp); 573 574 // Create a second warp op that contains only writeOp. 575 auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>( 576 loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize()); 577 Block &body = secondWarpOp.getBodyRegion().front(); 578 rewriter.setInsertionPointToStart(&body); 579 auto newWriteOp = 580 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation())); 581 newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0])); 582 rewriter.eraseOp(writeOp); 583 rewriter.create<vector::YieldOp>(newWarpOp.getLoc()); 584 return success(); 585 } 586 587 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 588 PatternRewriter &rewriter) const override { 589 auto yield = cast<vector::YieldOp>( 590 warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); 591 Operation *lastNode = yield->getPrevNode(); 592 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode); 593 if (!writeOp) 594 return failure(); 595 596 Value maybeMask = writeOp.getMask(); 597 if (!llvm::all_of(writeOp->getOperands(), [&](Value value) { 598 return writeOp.getVector() == value || 599 (maybeMask && maybeMask == value) || 600 warpOp.isDefinedOutsideOfRegion(value); 601 })) 602 return failure(); 603 604 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp))) 605 return success(); 606 607 // Masked writes not supported for extraction. 608 if (writeOp.getMask()) 609 return failure(); 610 611 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp))) 612 return success(); 613 614 return failure(); 615 } 616 617 private: 618 DistributionMapFn distributionMapFn; 619 }; 620 621 /// Sink out elementwise op feeding into a warp op yield. 622 /// ``` 623 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { 624 /// ... 625 /// %3 = arith.addf %1, %2 : vector<32xf32> 626 /// vector.yield %3 : vector<32xf32> 627 /// } 628 /// ``` 629 /// To 630 /// ``` 631 /// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>, 632 /// vector<1xf32>, vector<1xf32>) { 633 /// ... 
634 /// %4 = arith.addf %2, %3 : vector<32xf32> 635 /// vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>, 636 /// vector<32xf32> 637 /// } 638 /// %0 = arith.addf %r#1, %r#2 : vector<1xf32> 639 struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> { 640 using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern; 641 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 642 PatternRewriter &rewriter) const override { 643 OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) { 644 return OpTrait::hasElementwiseMappableTraits(op); 645 }); 646 if (!yieldOperand) 647 return failure(); 648 649 Operation *elementWise = yieldOperand->get().getDefiningOp(); 650 unsigned operandIndex = yieldOperand->getOperandNumber(); 651 Value distributedVal = warpOp.getResult(operandIndex); 652 SmallVector<Value> yieldValues; 653 SmallVector<Type> retTypes; 654 Location loc = warpOp.getLoc(); 655 for (OpOperand &operand : elementWise->getOpOperands()) { 656 Type targetType; 657 if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) { 658 // If the result type is a vector, the operands must also be vectors. 659 auto operandType = cast<VectorType>(operand.get().getType()); 660 targetType = 661 VectorType::get(vecType.getShape(), operandType.getElementType()); 662 } else { 663 auto operandType = operand.get().getType(); 664 assert(!isa<VectorType>(operandType) && 665 "unexpected yield of vector from op with scalar result type"); 666 targetType = operandType; 667 } 668 retTypes.push_back(targetType); 669 yieldValues.push_back(operand.get()); 670 } 671 SmallVector<size_t> newRetIndices; 672 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 673 rewriter, warpOp, yieldValues, retTypes, newRetIndices); 674 rewriter.setInsertionPointAfter(newWarpOp); 675 SmallVector<Value> newOperands(elementWise->getOperands().begin(), 676 elementWise->getOperands().end()); 677 for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) { 678 newOperands[i] = newWarpOp.getResult(newRetIndices[i]); 679 } 680 OpBuilder::InsertionGuard g(rewriter); 681 rewriter.setInsertionPointAfter(newWarpOp); 682 Operation *newOp = cloneOpWithOperandsAndTypes( 683 rewriter, loc, elementWise, newOperands, 684 {newWarpOp.getResult(operandIndex).getType()}); 685 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), 686 newOp->getResult(0)); 687 return success(); 688 } 689 }; 690 691 /// Sink out splat constant op feeding into a warp op yield. 692 /// ``` 693 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { 694 /// ... 695 /// %cst = arith.constant dense<2.0> : vector<32xf32> 696 /// vector.yield %cst : vector<32xf32> 697 /// } 698 /// ``` 699 /// To 700 /// ``` 701 /// vector.warp_execute_on_lane_0(%arg0 { 702 /// ... 
703 /// } 704 /// %0 = arith.constant dense<2.0> : vector<1xf32> 705 struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> { 706 using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern; 707 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 708 PatternRewriter &rewriter) const override { 709 OpOperand *yieldOperand = getWarpResult( 710 warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); }); 711 if (!yieldOperand) 712 return failure(); 713 auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>(); 714 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue()); 715 if (!dense) 716 return failure(); 717 // Notify the rewriter that the warp op is changing (see the comment on 718 // the WarpOpTransferRead pattern). 719 rewriter.startRootUpdate(warpOp); 720 unsigned operandIndex = yieldOperand->getOperandNumber(); 721 Attribute scalarAttr = dense.getSplatValue<Attribute>(); 722 auto newAttr = DenseElementsAttr::get( 723 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr); 724 Location loc = warpOp.getLoc(); 725 rewriter.setInsertionPointAfter(warpOp); 726 Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr); 727 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant); 728 rewriter.finalizeRootUpdate(warpOp); 729 return success(); 730 } 731 }; 732 733 /// Delinearize the given `laneId` into multiple dimensions, where each 734 /// dimension's size is determined by `originalShape` and `distributedShape` 735 /// together. This function expects the total numbers of threads needed for 736 /// distribution is equal to `warpSize`. Returns true and updates 737 /// `delinearizedIds` if so. 738 bool delinearizeLaneId(OpBuilder &builder, Location loc, 739 ArrayRef<int64_t> originalShape, 740 ArrayRef<int64_t> distributedShape, int64_t warpSize, 741 Value laneId, SmallVectorImpl<Value> &delinearizedIds) { 742 // If the original shape and the distributed shape is the same, we don't 743 // distribute at all--every thread is handling the whole. For such case, we 744 // should not rely on lane IDs later. So just return an empty lane ID vector. 745 if (originalShape == distributedShape) { 746 delinearizedIds.clear(); 747 return true; 748 } 749 750 SmallVector<int64_t> sizes; 751 for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) { 752 if (large % small != 0) 753 return false; 754 sizes.push_back(large / small); 755 } 756 if (std::accumulate(sizes.begin(), sizes.end(), 1, 757 std::multiplies<int64_t>()) != warpSize) 758 return false; 759 760 AffineExpr s0, s1; 761 bindSymbols(builder.getContext(), s0, s1); 762 763 int64_t usedThreads = 1; 764 765 Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); 766 delinearizedIds.assign(sizes.size(), zero); 767 768 for (int i = sizes.size() - 1; i >= 0; --i) { 769 usedThreads *= sizes[i]; 770 if (usedThreads == warpSize) { 771 // We've used up all available threads. Don't need to perform modulo 772 // anymore. And we can stop the calculation for further dimensions. 773 delinearizedIds[i] = laneId; 774 break; 775 } 776 delinearizedIds[i] = 777 affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId}); 778 laneId = affine::makeComposedAffineApply( 779 builder, loc, s0.floorDiv(usedThreads), {laneId}); 780 } 781 return true; 782 } 783 784 /// Sink out transfer_read op feeding into a warp op yield. 785 /// ``` 786 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { 787 /// ... 
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read. Note that this pattern can
    // still fail at the end after distribution, in which case this might have
    // missed another distributable read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Distribute the mask if present.
    OpBuilder::InsertionGuard g(rewriter);
    WarpExecuteOnLane0Op newWarpOp = warpOp;
    Value newMask = read.getMask();
    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: Distribution of masked reads with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus requires materializing
      // the permutation in IR.
      if (!read.getPermutationMap().isMinorIdentity())
        return failure();
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      SmallVector<size_t> newRetIndices;
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{read.getMask()}, TypeRange{maskType},
          newRetIndices);
      newMask = newWarpOp.getResult(newRetIndices[0]);
      distributedVal = newWarpOp.getResult(operandIndex);
    } else {
      // This pattern does not actually change the warp op directly. Instead it
      // just rewrites a new transfer read (when not masked) outside of the
      // warp op and replaces the corresponding result. There are then follow
      // up patterns to erase now dead results of the warp op. This erasure
      // allows propagation to continue, but this pattern on its own never
      // actually tells the pattern rewriter that the warp op "changed."
      // Notify the rewriter here that the warp op is changing. Similar
      // situations are noted in following patterns.
859 rewriter.startRootUpdate(warpOp); 860 } 861 862 rewriter.setInsertionPointAfter(newWarpOp); 863 864 // Try to delinearize the lane ID to match the rank expected for 865 // distribution. 866 SmallVector<Value> delinearizedIds; 867 if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(), 868 distributedType.getShape(), newWarpOp.getWarpSize(), 869 newWarpOp.getLaneid(), delinearizedIds)) { 870 if (!hasMask) 871 rewriter.cancelRootUpdate(warpOp); 872 return rewriter.notifyMatchFailure( 873 read, "cannot delinearize lane ID for distribution"); 874 } 875 assert(!delinearizedIds.empty() || map.getNumResults() == 0); 876 877 for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) { 878 AffineExpr d0, d1; 879 bindDims(read.getContext(), d0, d1); 880 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it)); 881 if (!indexExpr) 882 continue; 883 unsigned indexPos = indexExpr.getPosition(); 884 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition(); 885 int64_t scale = distributedType.getDimSize(vectorPos); 886 indices[indexPos] = affine::makeComposedAffineApply( 887 rewriter, read.getLoc(), d0 + scale * d1, 888 {indices[indexPos], delinearizedIds[vectorPos]}); 889 } 890 auto newRead = rewriter.create<vector::TransferReadOp>( 891 read.getLoc(), distributedVal.getType(), read.getSource(), indices, 892 read.getPermutationMapAttr(), read.getPadding(), newMask, 893 read.getInBoundsAttr()); 894 895 // Check that the produced operation is legal. 896 // The transfer op may be reading from values that are defined within 897 // warpOp's body, which is illegal. 898 // We do the check late because incdices may be changed by 899 // makeComposeAffineApply. This rewrite may remove dependencies from 900 // warpOp's body. 901 // E.g., warpop { 902 // %idx = affine.apply...[%outsideDef] 903 // ... = transfer_read ...[%idx] 904 // } 905 // will be rewritten in: 906 // warpop { 907 // } 908 // %new_idx = affine.apply...[%outsideDef] 909 // ... = transfer_read ...[%new_idx] 910 if (!llvm::all_of(newRead->getOperands(), [&](Value value) { 911 return (newRead.getMask() && value == newRead.getMask()) || 912 newWarpOp.isDefinedOutsideOfRegion(value); 913 })) { 914 if (!hasMask) 915 rewriter.cancelRootUpdate(warpOp); 916 return failure(); 917 } 918 919 rewriter.replaceAllUsesWith(distributedVal, newRead); 920 if (!hasMask) 921 rewriter.finalizeRootUpdate(warpOp); 922 return success(); 923 } 924 }; 925 926 /// Remove any result that has no use along with the matching yieldOp operand. 927 // TODO: Move this in WarpExecuteOnLane0Op canonicalization. 928 struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> { 929 using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern; 930 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 931 PatternRewriter &rewriter) const override { 932 SmallVector<Type> newResultTypes; 933 newResultTypes.reserve(warpOp->getNumResults()); 934 SmallVector<Value> newYieldValues; 935 newYieldValues.reserve(warpOp->getNumResults()); 936 DenseMap<Value, int64_t> dedupYieldOperandPositionMap; 937 DenseMap<OpResult, int64_t> dedupResultPositionMap; 938 auto yield = cast<vector::YieldOp>( 939 warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); 940 941 // Some values may be yielded multiple times and correspond to multiple 942 // results. Deduplicating occurs by taking each result with its matching 943 // yielded value, and: 944 // 1. recording the unique first position at which the value is yielded. 945 // 2. 
recording for the result, the first position at which the dedup'ed 946 // value is yielded. 947 // 3. skipping from the new result types / new yielded values any result 948 // that has no use or whose yielded value has already been seen. 949 for (OpResult result : warpOp.getResults()) { 950 Value yieldOperand = yield.getOperand(result.getResultNumber()); 951 auto it = dedupYieldOperandPositionMap.insert( 952 std::make_pair(yieldOperand, newResultTypes.size())); 953 dedupResultPositionMap.insert(std::make_pair(result, it.first->second)); 954 if (result.use_empty() || !it.second) 955 continue; 956 newResultTypes.push_back(result.getType()); 957 newYieldValues.push_back(yieldOperand); 958 } 959 // No modification, exit early. 960 if (yield.getNumOperands() == newYieldValues.size()) 961 return failure(); 962 // Move the body of the old warpOp to a new warpOp. 963 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( 964 rewriter, warpOp, newYieldValues, newResultTypes); 965 966 // Simplify the new warp op after dropping dead results. 967 newWarpOp.getBody()->walk([&](Operation *op) { 968 if (isOpTriviallyDead(op)) 969 rewriter.eraseOp(op); 970 }); 971 972 // Replace results of the old warpOp by the new, deduplicated results. 973 SmallVector<Value> newValues; 974 newValues.reserve(warpOp->getNumResults()); 975 for (OpResult result : warpOp.getResults()) { 976 if (result.use_empty()) 977 newValues.push_back(Value()); 978 else 979 newValues.push_back( 980 newWarpOp.getResult(dedupResultPositionMap.lookup(result))); 981 } 982 rewriter.replaceOp(warpOp, newValues); 983 return success(); 984 } 985 }; 986 987 // If an operand is directly yielded out of the region we can forward it 988 // directly and it doesn't need to go through the region. 989 struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> { 990 using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern; 991 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 992 PatternRewriter &rewriter) const override { 993 SmallVector<Type> resultTypes; 994 SmallVector<Value> yieldValues; 995 auto yield = cast<vector::YieldOp>( 996 warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); 997 Value valForwarded; 998 unsigned resultIndex; 999 for (OpOperand &operand : yield->getOpOperands()) { 1000 Value result = warpOp.getResult(operand.getOperandNumber()); 1001 if (result.use_empty()) 1002 continue; 1003 1004 // Assume all the values coming from above are uniform. 1005 if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) { 1006 if (result.getType() != operand.get().getType()) 1007 continue; 1008 valForwarded = operand.get(); 1009 resultIndex = operand.getOperandNumber(); 1010 break; 1011 } 1012 auto arg = dyn_cast<BlockArgument>(operand.get()); 1013 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation()) 1014 continue; 1015 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()]; 1016 if (result.getType() != warpOperand.getType()) 1017 continue; 1018 valForwarded = warpOperand; 1019 resultIndex = operand.getOperandNumber(); 1020 break; 1021 } 1022 if (!valForwarded) 1023 return failure(); 1024 // Notify the rewriter that the warp op is changing (see the comment on 1025 // the WarpOpTransferRead pattern). 
1026 rewriter.startRootUpdate(warpOp); 1027 rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded); 1028 rewriter.finalizeRootUpdate(warpOp); 1029 return success(); 1030 } 1031 }; 1032 1033 struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> { 1034 using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern; 1035 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 1036 PatternRewriter &rewriter) const override { 1037 OpOperand *operand = getWarpResult( 1038 warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); }); 1039 if (!operand) 1040 return failure(); 1041 unsigned int operandNumber = operand->getOperandNumber(); 1042 auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>(); 1043 Location loc = broadcastOp.getLoc(); 1044 auto destVecType = 1045 cast<VectorType>(warpOp->getResultTypes()[operandNumber]); 1046 Value broadcastSrc = broadcastOp.getSource(); 1047 Type broadcastSrcType = broadcastSrc.getType(); 1048 1049 // Check that the broadcast actually spans a set of values uniformly across 1050 // all threads. In other words, check that each thread can reconstruct 1051 // their own broadcast. 1052 // For that we simply check that the broadcast we want to build makes sense. 1053 if (vector::isBroadcastableTo(broadcastSrcType, destVecType) != 1054 vector::BroadcastableToResult::Success) 1055 return failure(); 1056 SmallVector<size_t> newRetIndices; 1057 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 1058 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices); 1059 rewriter.setInsertionPointAfter(newWarpOp); 1060 Value broadcasted = rewriter.create<vector::BroadcastOp>( 1061 loc, destVecType, newWarpOp->getResult(newRetIndices[0])); 1062 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), 1063 broadcasted); 1064 return success(); 1065 } 1066 }; 1067 1068 /// Pattern to move shape cast out of the warp op. shape cast is basically a 1069 /// no-op for warp distribution; we need to handle the shape though. 1070 struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> { 1071 using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern; 1072 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 1073 PatternRewriter &rewriter) const override { 1074 OpOperand *operand = getWarpResult( 1075 warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); }); 1076 if (!operand) 1077 return failure(); 1078 1079 auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>(); 1080 1081 unsigned int operandNumber = operand->getOperandNumber(); 1082 auto castDistributedType = 1083 cast<VectorType>(warpOp->getResultTypes()[operandNumber]); 1084 VectorType castOriginalType = oldCastOp.getSourceVectorType(); 1085 VectorType castResultType = castDistributedType; 1086 1087 // We expect the distributed type to have a smaller rank than the original 1088 // type. Prepend with size-one dimensions to make them the same. 
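    // E.g. (illustrative): for an original vector<1x32xf32> and a distributed
    // vector<1xf32>, the distributed type is padded to vector<1x1xf32> below.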
1089 unsigned castDistributedRank = castDistributedType.getRank(); 1090 unsigned castOriginalRank = castOriginalType.getRank(); 1091 if (castDistributedRank < castOriginalRank) { 1092 SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1); 1093 llvm::append_range(shape, castDistributedType.getShape()); 1094 castDistributedType = 1095 VectorType::get(shape, castDistributedType.getElementType()); 1096 } 1097 1098 SmallVector<size_t> newRetIndices; 1099 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 1100 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType}, 1101 newRetIndices); 1102 rewriter.setInsertionPointAfter(newWarpOp); 1103 Value newCast = rewriter.create<vector::ShapeCastOp>( 1104 oldCastOp.getLoc(), castResultType, 1105 newWarpOp->getResult(newRetIndices[0])); 1106 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast); 1107 return success(); 1108 } 1109 }; 1110 1111 /// Sink out vector.create_mask op feeding into a warp op yield. 1112 /// ``` 1113 /// %0 = ... 1114 /// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { 1115 /// ... 1116 /// %mask = vector.create_mask %0 : vector<32xi1> 1117 /// vector.yield %mask : vector<32xi1> 1118 /// } 1119 /// ``` 1120 /// To 1121 /// ``` 1122 /// %0 = ... 1123 /// vector.warp_execute_on_lane_0(%arg0) { 1124 /// ... 1125 /// } 1126 /// %cmp = arith.cmpi ult, %laneid, %0 1127 /// %ub = arith.select %cmp, %c0, %c1 1128 /// %1 = vector.create_mask %ub : vector<1xi1> 1129 struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> { 1130 using OpRewritePattern::OpRewritePattern; 1131 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 1132 PatternRewriter &rewriter) const override { 1133 OpOperand *yieldOperand = getWarpResult( 1134 warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); }); 1135 if (!yieldOperand) 1136 return failure(); 1137 1138 auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>(); 1139 1140 // Early exit if any values needed for calculating the new mask indices 1141 // are defined inside the warp op. 1142 if (!llvm::all_of(mask->getOperands(), [&](Value value) { 1143 return warpOp.isDefinedOutsideOfRegion(value); 1144 })) 1145 return failure(); 1146 1147 Location loc = mask.getLoc(); 1148 unsigned operandIndex = yieldOperand->getOperandNumber(); 1149 1150 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType()); 1151 VectorType seqType = mask.getVectorType(); 1152 ArrayRef<int64_t> seqShape = seqType.getShape(); 1153 ArrayRef<int64_t> distShape = distType.getShape(); 1154 1155 rewriter.setInsertionPointAfter(warpOp); 1156 1157 // Delinearize the lane ID for constructing the distributed mask sizes. 1158 SmallVector<Value> delinearizedIds; 1159 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape, 1160 warpOp.getWarpSize(), warpOp.getLaneid(), 1161 delinearizedIds)) 1162 return rewriter.notifyMatchFailure( 1163 mask, "cannot delinearize lane ID for distribution"); 1164 assert(!delinearizedIds.empty()); 1165 1166 // Notify the rewriter that the warp op is changing (see the comment on 1167 // the WarpOpTransferRead pattern). 
1168 rewriter.startRootUpdate(warpOp); 1169 1170 AffineExpr s0, s1; 1171 bindSymbols(rewriter.getContext(), s0, s1); 1172 SmallVector<Value> newOperands; 1173 for (int i = 0, e = distShape.size(); i < e; ++i) { 1174 // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to 1175 // find the distance from the largest mask index owned by this lane to the 1176 // original mask size. `vector.create_mask` implicitly clamps mask 1177 // operands to the range [0, mask_vector_size[i]], or in other words, the 1178 // mask sizes are always in the range [0, mask_vector_size[i]). 1179 Value maskDimIdx = affine::makeComposedAffineApply( 1180 rewriter, loc, s1 - s0 * distShape[i], 1181 {delinearizedIds[i], mask.getOperand(i)}); 1182 newOperands.push_back(maskDimIdx); 1183 } 1184 1185 auto newMask = 1186 rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands); 1187 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask); 1188 rewriter.finalizeRootUpdate(warpOp); 1189 return success(); 1190 } 1191 }; 1192 1193 /// Pattern to move out vector.extract of single element vector. Those don't 1194 /// need to be distributed and can just be propagated outside of the region. 1195 struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> { 1196 using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern; 1197 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 1198 PatternRewriter &rewriter) const override { 1199 OpOperand *operand = getWarpResult( 1200 warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); }); 1201 if (!operand) 1202 return failure(); 1203 unsigned int operandNumber = operand->getOperandNumber(); 1204 auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>(); 1205 VectorType extractSrcType = extractOp.getSourceVectorType(); 1206 Location loc = extractOp.getLoc(); 1207 1208 // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op. 1209 assert(extractSrcType.getRank() > 0 && 1210 "vector.extract does not support rank 0 sources"); 1211 1212 // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be 1213 // canonicalized to %v. 1214 if (extractOp.getNumIndices() == 0) 1215 return failure(); 1216 1217 // Rewrite vector.extract with 1d source to vector.extractelement. 1218 if (extractSrcType.getRank() == 1) { 1219 if (extractOp.hasDynamicPosition()) 1220 // TODO: Dinamic position not supported yet. 1221 return failure(); 1222 1223 assert(extractOp.getNumIndices() == 1 && "expected 1 index"); 1224 int64_t pos = extractOp.getStaticPosition()[0]; 1225 rewriter.setInsertionPoint(extractOp); 1226 rewriter.replaceOpWithNewOp<vector::ExtractElementOp>( 1227 extractOp, extractOp.getVector(), 1228 rewriter.create<arith::ConstantIndexOp>(loc, pos)); 1229 return success(); 1230 } 1231 1232 // All following cases are 2d or higher dimensional source vectors. 1233 1234 if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) { 1235 // There is no distribution, this is a broadcast. Simply move the extract 1236 // out of the warp op. 1237 // TODO: This could be optimized. E.g., in case of a scalar result, let 1238 // one lane extract and shuffle the result to all other lanes (same as 1239 // the 1d case). 
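      // The rewrite below only forwards the yielded source vector and moves
      // the extract after the warp op, e.g. (an illustrative sketch):
      //   %r:2 = vector.warp_execute_on_lane_0(...) -> (f32, vector<4x5xf32>) {
      //     ...
      //     vector.yield %e, %src : f32, vector<4x5xf32>
      //   }
      //   %0 = vector.extract %r#1[1, 2] : f32 from vector<4x5xf32>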
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from distributed vector.
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distributedDim == -1 && "found multiple distributed dims");
        distributedDim = i;
      }
    }
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield source vector from warp op.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape().begin(),
                                             extractSrcType.getShape().end());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};

/// Pattern to move out vector.extractelement of 0-D or single-element vectors,
/// which don't need to be distributed and can just be propagated outside of
/// the region, and to distribute extracts from larger 1-D vectors, where one
/// lane extracts the element and shuffles it to all other lanes.
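/// Example (an illustrative sketch of the single-element case, before
/// dead-result cleanup):
/// ```
/// %r = vector.warp_execute_on_lane_0(%arg0) -> (f32) {
///   ...
///   %1 = vector.extractelement %v[%c0 : index] : vector<1xf32>
///   vector.yield %1 : f32
/// }
/// ```
/// To
/// ```
/// %r:2 = vector.warp_execute_on_lane_0(%arg0) -> (f32, vector<1xf32>) {
///   ...
///   vector.yield %1, %v : f32, vector<1xf32>
/// }
/// %0 = vector.extractelement %r#1[%c0 : index] : vector<1xf32>
/// ```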
1295 struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> { 1296 WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn, 1297 PatternBenefit b = 1) 1298 : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b), 1299 warpShuffleFromIdxFn(std::move(fn)) {} 1300 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 1301 PatternRewriter &rewriter) const override { 1302 OpOperand *operand = getWarpResult(warpOp, [](Operation *op) { 1303 return isa<vector::ExtractElementOp>(op); 1304 }); 1305 if (!operand) 1306 return failure(); 1307 unsigned int operandNumber = operand->getOperandNumber(); 1308 auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>(); 1309 VectorType extractSrcType = extractOp.getSourceVectorType(); 1310 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1; 1311 Type elType = extractSrcType.getElementType(); 1312 VectorType distributedVecType; 1313 if (!is0dOrVec1Extract) { 1314 assert(extractSrcType.getRank() == 1 && 1315 "expected that extractelement src rank is 0 or 1"); 1316 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0) 1317 return failure(); 1318 int64_t elementsPerLane = 1319 extractSrcType.getShape()[0] / warpOp.getWarpSize(); 1320 distributedVecType = VectorType::get({elementsPerLane}, elType); 1321 } else { 1322 distributedVecType = extractSrcType; 1323 } 1324 // Yield source vector from warp op. 1325 Location loc = extractOp.getLoc(); 1326 SmallVector<size_t> newRetIndices; 1327 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 1328 rewriter, warpOp, {extractOp.getVector()}, {distributedVecType}, 1329 newRetIndices); 1330 rewriter.setInsertionPointAfter(newWarpOp); 1331 Value distributedVec = newWarpOp->getResult(newRetIndices[0]); 1332 1333 // 0d extract: The new warp op broadcasts the source vector to all lanes. 1334 // All lanes extract the scalar. 1335 if (is0dOrVec1Extract) { 1336 Value newExtract; 1337 if (extractSrcType.getRank() == 1) { 1338 newExtract = rewriter.create<vector::ExtractElementOp>( 1339 loc, distributedVec, 1340 rewriter.create<arith::ConstantIndexOp>(loc, 0)); 1341 1342 } else { 1343 newExtract = 1344 rewriter.create<vector::ExtractElementOp>(loc, distributedVec); 1345 } 1346 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), 1347 newExtract); 1348 return success(); 1349 } 1350 1351 // 1d extract: Distribute the source vector. One lane extracts and shuffles 1352 // the value to all other lanes. 1353 int64_t elementsPerLane = distributedVecType.getShape()[0]; 1354 AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext()); 1355 // tid of extracting thread: pos / elementsPerLane 1356 Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>( 1357 loc, sym0.ceilDiv(elementsPerLane), extractOp.getPosition()); 1358 // Extract at position: pos % elementsPerLane 1359 Value pos = 1360 elementsPerLane == 1 1361 ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult() 1362 : rewriter 1363 .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane, 1364 extractOp.getPosition()) 1365 .getResult(); 1366 Value extracted = 1367 rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos); 1368 1369 // Shuffle the extracted value to all lanes. 
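    // The concrete shuffle is produced by the caller-provided
    // `warpShuffleFromIdxFn` callback (for instance, a `gpu.shuffle idx` when
    // targeting GPUs).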
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

private:
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};

/// Pattern to sink a vector.insertelement out of a WarpExecuteOnLane0Op. If
/// the destination vector is distributed, only the lane owning the insertion
/// position performs the insert (guarded by an scf.if); otherwise the op is
/// simply moved out of the region.
struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      return isa<vector::InsertElementOp>(op);
    });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    bool hasPos = static_cast<bool>(insertOp.getPosition());

    // Yield destination vector, source scalar and position from warp op.
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getSource()};
    SmallVector<Type> additionalResultTypes{distrType,
                                            insertOp.getSource().getType()};
    if (hasPos) {
      additionalResults.push_back(insertOp.getPosition());
      additionalResultTypes.push_back(insertOp.getPosition().getType());
    }
    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);
    Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
    rewriter.setInsertionPointAfter(newWarpOp);

    if (vecType == distrType) {
      // Broadcast: Simply move the vector.insertelement op out.
      Value newInsert = rewriter.create<vector::InsertElementOp>(
          loc, newSource, distributedVec, newPos);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // This is a distribution. Only one lane should insert.
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of the inserting lane: pos / elementsPerLane
    Value insertingLane = rewriter.create<affine::AffineApplyOp>(
        loc, sym0.ceilDiv(elementsPerLane), newPos);
    // Insert position: pos % elementsPerLane
    Value pos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : rewriter
                  .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
                                                 newPos)
                  .getResult();
    Value isInsertingLane = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        rewriter
            .create<scf::IfOp>(
                loc, isInsertingLane,
                /*thenBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  Value newInsert = builder.create<vector::InsertElementOp>(
                      loc, newSource, distributedVec, pos);
                  builder.create<scf::YieldOp>(loc, newInsert);
                },
                /*elseBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  builder.create<scf::YieldOp>(loc, distributedVec);
                })
            .getResult(0);
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

/// Pattern to sink a vector.insert out of a WarpExecuteOnLane0Op. Depending on
/// which dimension of the destination is distributed, either every lane
/// inserts its own slice of the source, or a single lane (guarded by an
/// scf.if) inserts the full source vector.
struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::InsertOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // "vector.insert %v, %v[] : ..." can be canonicalized to %v.
    if (insertOp.getNumIndices() == 0)
      return failure();

    // Rewrite vector.insert with 1d dest to vector.insertelement.
    if (insertOp.getDestVectorType().getRank() == 1) {
      if (insertOp.hasDynamicPosition())
        // TODO: Dynamic position not supported yet.
        return failure();

      assert(insertOp.getNumIndices() == 1 && "expected 1 index");
      int64_t pos = insertOp.getStaticPosition()[0];
      rewriter.setInsertionPoint(insertOp);
      rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
          insertOp, insertOp.getSource(), insertOp.getDest(),
          rewriter.create<arith::ConstantIndexOp>(loc, pos));
      return success();
    }

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the insert
      // out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
          {insertOp.getSourceType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
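    // E.g. (illustrative, assuming a warp size of 32): a yielded
    // vector<128x96xf32> whose warp result type is vector<128x3xf32> has
    // dimension 1 as its distributed dimension (96 -> 3).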
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type.
    VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(),
                                       srcVecType.getShape().end());
    // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
    // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
    //         insert a smaller vector<3xf32>.
    // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
    //         case, one lane will insert the source vector<96xf32>. The other
    //         lanes will not do anything.
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield source and dest vectors from warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    // Insert into the distributed vector.
    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts a small piece.
      newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
    } else {
      // One lane inserts the entire source vector.
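      // Worked example (illustrative, assuming a warp size of 32): inserting
      // %s : vector<96xf32> at [2] into a dest distributed as vector<4x96xf32>
      // (128 / 32 = 4 rows per lane) gives elementsPerLane = 4, so lane
      // 2 / 4 = 0 performs the insert at local position 2 % 4 = 2; every other
      // lane just forwards its distributed dest unchanged.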
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
      SmallVector<int64_t> newPos = getAsIntegers(pos);
      // tid of the inserting lane: pos / elementsPerLane
      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
          loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
      // Insert position: pos % elementsPerLane
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = builder.create<vector::InsertOp>(
            loc, distributedSrc, distributedDest, newPos);
        builder.create<scf::YieldOp>(loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        builder.create<scf::YieldOp>(loc, distributedDest);
      };
      newResult = rewriter
                      .create<scf::IfOp>(loc, isInsertingLane,
                                         /*thenBuilder=*/insertingBuilder,
                                         /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }

    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

/// Sink an scf.for region out of a WarpExecuteOnLane0Op. This can be done only
/// if the scf.ForOp is the last operation in the region, so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///       -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///     -> (vector<4xf32>) {
///   %iw = vector.warp_execute_on_lane_0(%laneid)
///       args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///   }
///   scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {

  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
        distributionMapFn(std::move(fn)) {}
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up the forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    // Collect values that come from the warp op but are used inside the forOp.
    // Those values need to be returned by the original warpOp and passed to
    // the new op.
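    // For instance (sketch, assuming a warp size of 32 and a suitable
    // distribution map), a vector<128xf32> defined in the warp region but used
    // inside the forOp is yielded from the rewritten warp op as vector<4xf32>,
    // passed to the inner warp op of the new scf.for as an extra operand, and
    // seen again as a vector<128xf32> block argument inside it.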
    llvm::SmallSetVector<Value, 32> escapingValues;
    SmallVector<Type> inputTypes;
    SmallVector<Type> distTypes;
    mlir::visitUsedValuesDefinedAbove(
        forOp.getBodyRegion(), [&](OpOperand *operand) {
          Operation *parent = operand->get().getParentRegion()->getParentOp();
          if (warpOp->isAncestor(parent)) {
            if (!escapingValues.insert(operand->get()))
              return;
            Type distType = operand->get().getType();
            if (auto vecType = dyn_cast<VectorType>(distType)) {
              AffineMap map = distributionMapFn(operand->get());
              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
            }
            inputTypes.push_back(operand->get().getType());
            distTypes.push_back(distType);
          }
        });

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
        newRetIndices);
    yield = cast<vector::YieldOp>(
        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = cast<OpResult>(yieldOperand.get());
      newOperands.push_back(
          newWarpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());

    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
                                 newForOp.getRegionIterArgs().end());
    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
                                    forOp.getResultTypes().end());
    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
      warpInput.push_back(newWarpOp.getResult(retIdx));
      argIndexMapping[escapingValues[i]] = warpInputType.size();
      warpInputType.push_back(inputTypes[i]);
    }
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
        newWarpOp.getWarpSize(), warpInput, warpInputType);

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    argMapping.resize(forOp.getBody()->getNumArguments());
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp results coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
                                  newForOp.getResult(res.index()));
      // Operands 0..2 of scf.for are the lower bound, upper bound and step;
      // the iter_args start at operand index 3.
      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
    }
    newForOp.walk([&](Operation *op) {
      for (OpOperand &operand : op->getOpOperands()) {
        auto it = argIndexMapping.find(operand.get());
        if (it == argIndexMapping.end())
          continue;
        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
      }
    });

    // Finally, hoist out any now uniform code from the inner warp op.
    mlir::vector::moveScalarUniformCode(innerWarp);
    return success();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to 1-D vectors whose
/// size is a multiple of the warp size. E.g.:
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   vector.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : f32 from vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        distributedReductionFn(std::move(distributedReductionFn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only vector sizes that are a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must be a multiple of the warp "
                  "size.");
    if (!reductionOp.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction distribution currently only supports floats and "
                  "integer types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
    // Return vector that will be reduced from the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Obtain data to reduce for a single lane.
    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    // Distribute and reduce across threads.
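    // The distributedReductionFn callback is provided by the caller. It is
    // expected to combine each lane's vector<numElements x T> into a single
    // scalar that is uniform across the warp, typically via log2(warpSize)
    // shuffle-and-combine steps (the exact lowering, e.g. based on
    // gpu.shuffle xor, is target-specific and only assumed here).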
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    PatternBenefit benefit) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
                                    benefit);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
    PatternBenefit readBenefit) {
  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
  patterns
      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
           WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
          patterns.getContext(), benefit);
  patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                     warpShuffleFromIdxFn, benefit);
  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                               benefit);
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
                                benefit);
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return isa<VectorType>(result.getType());
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}
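// Illustrative usage sketch (not part of this file's API surface; the pass,
// the callback bodies and the greedy driver invocation below are assumptions
// a client would provide):
//
//   #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
//
//   // Inside a client pass, after lowering warp ops with
//   // populateWarpExecuteOnLane0OpToScfForPattern or before it, depending on
//   // the chosen pipeline:
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populateDistributeTransferWriteOpPatterns(patterns,
//                                                     distributionMapFn);
//   vector::populatePropagateWarpVectorDistributionPatterns(
//       patterns, distributionMapFn, warpShuffleFromIdxFn);
//   vector::populateDistributeReduction(patterns, distributedReductionFn);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));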