1d02f10d9SThomas Raoux //===- VectorDistribute.cpp - patterns to do vector distribution ----------===// 2d02f10d9SThomas Raoux // 3d02f10d9SThomas Raoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4d02f10d9SThomas Raoux // See https://llvm.org/LICENSE.txt for license information. 5d02f10d9SThomas Raoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6d02f10d9SThomas Raoux // 7d02f10d9SThomas Raoux //===----------------------------------------------------------------------===// 8d02f10d9SThomas Raoux 9ed0288f7SThomas Raoux #include "mlir/Dialect/Affine/IR/AffineOps.h" 10abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 11ecaf2c33SPetr Kurapov #include "mlir/Dialect/GPU/IR/GPUDialect.h" 12*bc29fc93SPetr Kurapov #include "mlir/Dialect/GPU/Utils/DistributionUtils.h" 13d02f10d9SThomas Raoux #include "mlir/Dialect/MemRef/IR/MemRef.h" 148b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 15fa8a10a1SNicolas Vasilache #include "mlir/Dialect/Vector/IR/VectorOps.h" 16d02f10d9SThomas Raoux #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" 17fa8a10a1SNicolas Vasilache #include "mlir/IR/AffineExpr.h" 18fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h" 1991f62f0eSThomas Raoux #include "mlir/Transforms/RegionUtils.h" 20d7d6443dSThomas Raoux #include "llvm/ADT/SetVector.h" 2180636227SJakub Kuderski #include "llvm/Support/FormatVariadic.h" 2208d651d7SMehdi Amini #include <utility> 2308d651d7SMehdi Amini 24d02f10d9SThomas Raoux using namespace mlir; 25d02f10d9SThomas Raoux using namespace mlir::vector; 26ecaf2c33SPetr Kurapov using namespace mlir::gpu; 27d02f10d9SThomas Raoux 284abb9e5dSThomas Raoux /// Currently the distribution map is implicit based on the vector shape. In the 294abb9e5dSThomas Raoux /// future it will be part of the op. 304abb9e5dSThomas Raoux /// Example: 314abb9e5dSThomas Raoux /// ``` 32ecaf2c33SPetr Kurapov /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) { 334abb9e5dSThomas Raoux /// ... 34ecaf2c33SPetr Kurapov /// gpu.yield %3 : vector<32x16x64xf32> 354abb9e5dSThomas Raoux /// } 364abb9e5dSThomas Raoux /// ``` 374abb9e5dSThomas Raoux /// Would have an implicit map of: 384abb9e5dSThomas Raoux /// `(d0, d1, d2) -> (d0, d2)` 394abb9e5dSThomas Raoux static AffineMap calculateImplicitMap(VectorType sequentialType, 404abb9e5dSThomas Raoux VectorType distributedType) { 414abb9e5dSThomas Raoux SmallVector<AffineExpr> perm; 424abb9e5dSThomas Raoux perm.reserve(1); 434abb9e5dSThomas Raoux // Check which dimensions of the sequential type are different than the 444abb9e5dSThomas Raoux // dimensions of the distributed type to know the distributed dimensions. Then 454abb9e5dSThomas Raoux // associate each distributed dimension to an ID in order. 464abb9e5dSThomas Raoux for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) { 474abb9e5dSThomas Raoux if (sequentialType.getDimSize(i) != distributedType.getDimSize(i)) 484abb9e5dSThomas Raoux perm.push_back(getAffineDimExpr(i, distributedType.getContext())); 49d02f10d9SThomas Raoux } 504abb9e5dSThomas Raoux auto map = AffineMap::get(sequentialType.getRank(), 0, perm, 514abb9e5dSThomas Raoux distributedType.getContext()); 524abb9e5dSThomas Raoux return map; 53d02f10d9SThomas Raoux } 54d02f10d9SThomas Raoux 55fa8a10a1SNicolas Vasilache namespace { 56d02f10d9SThomas Raoux 57fa8a10a1SNicolas Vasilache /// Helper struct to create the load / store operations that permit transit 58fa8a10a1SNicolas Vasilache /// through the parallel / sequential and the sequential / parallel boundaries 59fa8a10a1SNicolas Vasilache /// when performing `rewriteWarpOpToScfFor`. 60fa8a10a1SNicolas Vasilache /// 614abb9e5dSThomas Raoux /// The vector distribution dimension is inferred from the vector types. 62fa8a10a1SNicolas Vasilache struct DistributedLoadStoreHelper { 63fa8a10a1SNicolas Vasilache DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal, 64fa8a10a1SNicolas Vasilache Value laneId, Value zero) 65fa8a10a1SNicolas Vasilache : sequentialVal(sequentialVal), distributedVal(distributedVal), 66fa8a10a1SNicolas Vasilache laneId(laneId), zero(zero) { 675550c821STres Popp sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType()); 685550c821STres Popp distributedVectorType = dyn_cast<VectorType>(distributedVal.getType()); 694abb9e5dSThomas Raoux if (sequentialVectorType && distributedVectorType) 704abb9e5dSThomas Raoux distributionMap = 714abb9e5dSThomas Raoux calculateImplicitMap(sequentialVectorType, distributedVectorType); 72fa8a10a1SNicolas Vasilache } 73d02f10d9SThomas Raoux 744abb9e5dSThomas Raoux Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) { 754abb9e5dSThomas Raoux int64_t distributedSize = distributedVectorType.getDimSize(index); 76fa8a10a1SNicolas Vasilache AffineExpr tid = getAffineSymbolExpr(0, b.getContext()); 774c48f016SMatthias Springer return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize, 78fa8a10a1SNicolas Vasilache ArrayRef<Value>{laneId}); 79fa8a10a1SNicolas Vasilache } 80d02f10d9SThomas Raoux 81845dc178SNicolas Vasilache /// Create a store during the process of distributing the 82845dc178SNicolas Vasilache /// `vector.warp_execute_on_thread_0` op. 83845dc178SNicolas Vasilache /// Vector distribution assumes the following convention regarding the 84845dc178SNicolas Vasilache /// temporary buffers that are created to transition values. This **must** 85845dc178SNicolas Vasilache /// be properly specified in the `options.warpAllocationFn`: 86845dc178SNicolas Vasilache /// 1. scalars of type T transit through a memref<1xT>. 87845dc178SNicolas Vasilache /// 2. vectors of type V<shapexT> transit through a memref<shapexT> 88fa8a10a1SNicolas Vasilache Operation *buildStore(RewriterBase &b, Location loc, Value val, 89fa8a10a1SNicolas Vasilache Value buffer) { 90fa8a10a1SNicolas Vasilache assert((val == distributedVal || val == sequentialVal) && 91fa8a10a1SNicolas Vasilache "Must store either the preregistered distributed or the " 92fa8a10a1SNicolas Vasilache "preregistered sequential value."); 934abb9e5dSThomas Raoux // Scalar case can directly use memref.store. 945550c821STres Popp if (!isa<VectorType>(val.getType())) 954abb9e5dSThomas Raoux return b.create<memref::StoreOp>(loc, val, buffer, zero); 964abb9e5dSThomas Raoux 97fa8a10a1SNicolas Vasilache // Vector case must use vector::TransferWriteOp which will later lower to 98fa8a10a1SNicolas Vasilache // vector.store of memref.store depending on further lowerings. 99845dc178SNicolas Vasilache int64_t rank = sequentialVectorType.getRank(); 100845dc178SNicolas Vasilache SmallVector<Value> indices(rank, zero); 1014abb9e5dSThomas Raoux if (val == distributedVal) { 1024abb9e5dSThomas Raoux for (auto dimExpr : distributionMap.getResults()) { 1031609f1c2Slong.chen int64_t index = cast<AffineDimExpr>(dimExpr).getPosition(); 1044abb9e5dSThomas Raoux indices[index] = buildDistributedOffset(b, loc, index); 1054abb9e5dSThomas Raoux } 1064abb9e5dSThomas Raoux } 107fa8a10a1SNicolas Vasilache SmallVector<bool> inBounds(indices.size(), true); 108fa8a10a1SNicolas Vasilache return b.create<vector::TransferWriteOp>( 109fa8a10a1SNicolas Vasilache loc, val, buffer, indices, 110fa8a10a1SNicolas Vasilache ArrayRef<bool>(inBounds.begin(), inBounds.end())); 111fa8a10a1SNicolas Vasilache } 112fa8a10a1SNicolas Vasilache 113845dc178SNicolas Vasilache /// Create a load during the process of distributing the 114845dc178SNicolas Vasilache /// `vector.warp_execute_on_thread_0` op. 115845dc178SNicolas Vasilache /// Vector distribution assumes the following convention regarding the 116845dc178SNicolas Vasilache /// temporary buffers that are created to transition values. This **must** 117845dc178SNicolas Vasilache /// be properly specified in the `options.warpAllocationFn`: 118845dc178SNicolas Vasilache /// 1. scalars of type T transit through a memref<1xT>. 119845dc178SNicolas Vasilache /// 2. vectors of type V<shapexT> transit through a memref<shapexT> 120845dc178SNicolas Vasilache /// 121845dc178SNicolas Vasilache /// When broadcastMode is true, the load is not distributed to account for 122ecaf2c33SPetr Kurapov /// the broadcast semantics of the `gpu.warp_execute_on_lane_0` op. 123845dc178SNicolas Vasilache /// 124845dc178SNicolas Vasilache /// Example: 125845dc178SNicolas Vasilache /// 126845dc178SNicolas Vasilache /// ``` 127ecaf2c33SPetr Kurapov /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) { 128ecaf2c33SPetr Kurapov /// gpu.yield %cst : f32 129845dc178SNicolas Vasilache /// } 130845dc178SNicolas Vasilache /// // Both types are f32. The constant %cst is broadcasted to all lanes. 131845dc178SNicolas Vasilache /// ``` 132845dc178SNicolas Vasilache /// This behavior described in more detail in the documentation of the op. 1334abb9e5dSThomas Raoux Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) { 1344abb9e5dSThomas Raoux 1354abb9e5dSThomas Raoux // Scalar case can directly use memref.store. 1365550c821STres Popp if (!isa<VectorType>(type)) 137fa8a10a1SNicolas Vasilache return b.create<memref::LoadOp>(loc, buffer, zero); 138fa8a10a1SNicolas Vasilache 139fa8a10a1SNicolas Vasilache // Other cases must be vector atm. 140fa8a10a1SNicolas Vasilache // Vector case must use vector::TransferReadOp which will later lower to 141fa8a10a1SNicolas Vasilache // vector.read of memref.read depending on further lowerings. 142fa8a10a1SNicolas Vasilache assert((type == distributedVectorType || type == sequentialVectorType) && 143fa8a10a1SNicolas Vasilache "Must store either the preregistered distributed or the " 144fa8a10a1SNicolas Vasilache "preregistered sequential type."); 145fa8a10a1SNicolas Vasilache SmallVector<Value> indices(sequentialVectorType.getRank(), zero); 146fa8a10a1SNicolas Vasilache if (type == distributedVectorType) { 1474abb9e5dSThomas Raoux for (auto dimExpr : distributionMap.getResults()) { 1481609f1c2Slong.chen int64_t index = cast<AffineDimExpr>(dimExpr).getPosition(); 1494abb9e5dSThomas Raoux indices[index] = buildDistributedOffset(b, loc, index); 1504abb9e5dSThomas Raoux } 151d02f10d9SThomas Raoux } 152fa8a10a1SNicolas Vasilache SmallVector<bool> inBounds(indices.size(), true); 153fa8a10a1SNicolas Vasilache return b.create<vector::TransferReadOp>( 1545550c821STres Popp loc, cast<VectorType>(type), buffer, indices, 155fa8a10a1SNicolas Vasilache ArrayRef<bool>(inBounds.begin(), inBounds.end())); 156d02f10d9SThomas Raoux } 157d02f10d9SThomas Raoux 158fa8a10a1SNicolas Vasilache Value sequentialVal, distributedVal, laneId, zero; 159fa8a10a1SNicolas Vasilache VectorType sequentialVectorType, distributedVectorType; 1604abb9e5dSThomas Raoux AffineMap distributionMap; 161fa8a10a1SNicolas Vasilache }; 162d02f10d9SThomas Raoux 163fa8a10a1SNicolas Vasilache } // namespace 164d02f10d9SThomas Raoux 16576cf33daSThomas Raoux // Clones `op` into a new operation that takes `operands` and returns 16676cf33daSThomas Raoux // `resultTypes`. 16776cf33daSThomas Raoux static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter, 16876cf33daSThomas Raoux Location loc, Operation *op, 16976cf33daSThomas Raoux ArrayRef<Value> operands, 17076cf33daSThomas Raoux ArrayRef<Type> resultTypes) { 17176cf33daSThomas Raoux OperationState res(loc, op->getName().getStringRef(), operands, resultTypes, 17276cf33daSThomas Raoux op->getAttrs()); 17376cf33daSThomas Raoux return rewriter.create(res); 17476cf33daSThomas Raoux } 17576cf33daSThomas Raoux 176d02f10d9SThomas Raoux namespace { 177d02f10d9SThomas Raoux 178fa8a10a1SNicolas Vasilache /// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single 179fa8a10a1SNicolas Vasilache /// thread `laneId` executes the entirety of the computation. 180fa8a10a1SNicolas Vasilache /// 181fa8a10a1SNicolas Vasilache /// After the transformation: 182fa8a10a1SNicolas Vasilache /// - the IR within the scf.if op can be thought of as executing sequentially 183fa8a10a1SNicolas Vasilache /// (from the point of view of threads along `laneId`). 184fa8a10a1SNicolas Vasilache /// - the IR outside of the scf.if op can be thought of as executing in 185fa8a10a1SNicolas Vasilache /// parallel (from the point of view of threads along `laneId`). 186fa8a10a1SNicolas Vasilache /// 187fa8a10a1SNicolas Vasilache /// Values that need to transit through the parallel / sequential and the 188fa8a10a1SNicolas Vasilache /// sequential / parallel boundaries do so via reads and writes to a temporary 189fa8a10a1SNicolas Vasilache /// memory location. 190fa8a10a1SNicolas Vasilache /// 191fa8a10a1SNicolas Vasilache /// The transformation proceeds in multiple steps: 192fa8a10a1SNicolas Vasilache /// 1. Create the scf.if op. 193fa8a10a1SNicolas Vasilache /// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads 194fa8a10a1SNicolas Vasilache /// within the scf.if to transit the values captured from above. 195fa8a10a1SNicolas Vasilache /// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are 196fa8a10a1SNicolas Vasilache /// consistent within the scf.if. 197fa8a10a1SNicolas Vasilache /// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if. 198fa8a10a1SNicolas Vasilache /// 5. Insert appropriate writes within scf.if and reads after the scf.if to 199fa8a10a1SNicolas Vasilache /// transit the values returned by the op. 200fa8a10a1SNicolas Vasilache /// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are 201fa8a10a1SNicolas Vasilache /// consistent after the scf.if. 202fa8a10a1SNicolas Vasilache /// 7. Perform late cleanups. 203fa8a10a1SNicolas Vasilache /// 204fa8a10a1SNicolas Vasilache /// All this assumes the vector distribution occurs along the most minor 205fa8a10a1SNicolas Vasilache /// distributed vector dimension. 206*bc29fc93SPetr Kurapov struct WarpOpToScfIfPattern : public WarpDistributionPattern { 2074abb9e5dSThomas Raoux WarpOpToScfIfPattern(MLIRContext *context, 208d02f10d9SThomas Raoux const WarpExecuteOnLane0LoweringOptions &options, 209d02f10d9SThomas Raoux PatternBenefit benefit = 1) 210*bc29fc93SPetr Kurapov : WarpDistributionPattern(context, benefit), options(options) {} 211d02f10d9SThomas Raoux 212d02f10d9SThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 213d02f10d9SThomas Raoux PatternRewriter &rewriter) const override { 214fa8a10a1SNicolas Vasilache assert(warpOp.getBodyRegion().hasOneBlock() && 215fa8a10a1SNicolas Vasilache "expected WarpOp with single block"); 216fa8a10a1SNicolas Vasilache Block *warpOpBody = &warpOp.getBodyRegion().front(); 217fa8a10a1SNicolas Vasilache Location loc = warpOp.getLoc(); 218fa8a10a1SNicolas Vasilache 219fa8a10a1SNicolas Vasilache // Passed all checks. Start rewriting. 220fa8a10a1SNicolas Vasilache OpBuilder::InsertionGuard g(rewriter); 221fa8a10a1SNicolas Vasilache rewriter.setInsertionPoint(warpOp); 222fa8a10a1SNicolas Vasilache 223fa8a10a1SNicolas Vasilache // Step 1: Create scf.if op. 224fa8a10a1SNicolas Vasilache Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0); 225fa8a10a1SNicolas Vasilache Value isLane0 = rewriter.create<arith::CmpIOp>( 226fa8a10a1SNicolas Vasilache loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0); 227fa8a10a1SNicolas Vasilache auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0, 228fa8a10a1SNicolas Vasilache /*withElseRegion=*/false); 229fa8a10a1SNicolas Vasilache rewriter.eraseOp(ifOp.thenBlock()->getTerminator()); 230fa8a10a1SNicolas Vasilache 231fa8a10a1SNicolas Vasilache // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and 232fa8a10a1SNicolas Vasilache // reads within the scf.if to transit the values captured from above. 233fa8a10a1SNicolas Vasilache SmallVector<Value> bbArgReplacements; 234fa8a10a1SNicolas Vasilache for (const auto &it : llvm::enumerate(warpOp.getArgs())) { 235fa8a10a1SNicolas Vasilache Value sequentialVal = warpOpBody->getArgument(it.index()); 236fa8a10a1SNicolas Vasilache Value distributedVal = it.value(); 237fa8a10a1SNicolas Vasilache DistributedLoadStoreHelper helper(sequentialVal, distributedVal, 238fa8a10a1SNicolas Vasilache warpOp.getLaneid(), c0); 239fa8a10a1SNicolas Vasilache 240fa8a10a1SNicolas Vasilache // Create buffer before the ifOp. 241fa8a10a1SNicolas Vasilache rewriter.setInsertionPoint(ifOp); 242fa8a10a1SNicolas Vasilache Value buffer = options.warpAllocationFn(loc, rewriter, warpOp, 243fa8a10a1SNicolas Vasilache sequentialVal.getType()); 244fa8a10a1SNicolas Vasilache // Store distributed vector into buffer, before the ifOp. 245fa8a10a1SNicolas Vasilache helper.buildStore(rewriter, loc, distributedVal, buffer); 246fa8a10a1SNicolas Vasilache // Load sequential vector from buffer, inside the ifOp. 247fa8a10a1SNicolas Vasilache rewriter.setInsertionPointToStart(ifOp.thenBlock()); 2484abb9e5dSThomas Raoux bbArgReplacements.push_back( 2494abb9e5dSThomas Raoux helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer)); 250fa8a10a1SNicolas Vasilache } 251fa8a10a1SNicolas Vasilache 252fa8a10a1SNicolas Vasilache // Step 3. Insert sync after all the stores and before all the loads. 253fa8a10a1SNicolas Vasilache if (!warpOp.getArgs().empty()) { 254fa8a10a1SNicolas Vasilache rewriter.setInsertionPoint(ifOp); 255fa8a10a1SNicolas Vasilache options.warpSyncronizationFn(loc, rewriter, warpOp); 256fa8a10a1SNicolas Vasilache } 257fa8a10a1SNicolas Vasilache 258fa8a10a1SNicolas Vasilache // Step 4. Move body of warpOp to ifOp. 259fa8a10a1SNicolas Vasilache rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements); 260fa8a10a1SNicolas Vasilache 261fa8a10a1SNicolas Vasilache // Step 5. Insert appropriate writes within scf.if and reads after the 262fa8a10a1SNicolas Vasilache // scf.if to transit the values returned by the op. 263fa8a10a1SNicolas Vasilache // TODO: at this point, we can reuse the shared memory from previous 264fa8a10a1SNicolas Vasilache // buffers. 265fa8a10a1SNicolas Vasilache SmallVector<Value> replacements; 266ecaf2c33SPetr Kurapov auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator()); 267fa8a10a1SNicolas Vasilache Location yieldLoc = yieldOp.getLoc(); 268b74192b7SRiver Riddle for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { 269fa8a10a1SNicolas Vasilache Value sequentialVal = it.value(); 270fa8a10a1SNicolas Vasilache Value distributedVal = warpOp->getResult(it.index()); 271fa8a10a1SNicolas Vasilache DistributedLoadStoreHelper helper(sequentialVal, distributedVal, 272fa8a10a1SNicolas Vasilache warpOp.getLaneid(), c0); 273fa8a10a1SNicolas Vasilache 274fa8a10a1SNicolas Vasilache // Create buffer before the ifOp. 275fa8a10a1SNicolas Vasilache rewriter.setInsertionPoint(ifOp); 276fa8a10a1SNicolas Vasilache Value buffer = options.warpAllocationFn(loc, rewriter, warpOp, 277fa8a10a1SNicolas Vasilache sequentialVal.getType()); 278fa8a10a1SNicolas Vasilache 279fa8a10a1SNicolas Vasilache // Store yielded value into buffer, inside the ifOp, before the 280fa8a10a1SNicolas Vasilache // terminator. 281fa8a10a1SNicolas Vasilache rewriter.setInsertionPoint(yieldOp); 282fa8a10a1SNicolas Vasilache helper.buildStore(rewriter, loc, sequentialVal, buffer); 283fa8a10a1SNicolas Vasilache 284fa8a10a1SNicolas Vasilache // Load distributed value from buffer, after the warpOp. 285fa8a10a1SNicolas Vasilache rewriter.setInsertionPointAfter(ifOp); 286fa8a10a1SNicolas Vasilache // Result type and yielded value type are the same. This is a broadcast. 287fa8a10a1SNicolas Vasilache // E.g.: 288ecaf2c33SPetr Kurapov // %r = gpu.warp_execute_on_lane_0(...) -> (f32) { 289ecaf2c33SPetr Kurapov // gpu.yield %cst : f32 290fa8a10a1SNicolas Vasilache // } 291fa8a10a1SNicolas Vasilache // Both types are f32. The constant %cst is broadcasted to all lanes. 292fa8a10a1SNicolas Vasilache // This is described in more detail in the documentation of the op. 2934abb9e5dSThomas Raoux replacements.push_back( 2944abb9e5dSThomas Raoux helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer)); 295fa8a10a1SNicolas Vasilache } 296fa8a10a1SNicolas Vasilache 297fa8a10a1SNicolas Vasilache // Step 6. Insert sync after all the stores and before all the loads. 298b74192b7SRiver Riddle if (!yieldOp.getOperands().empty()) { 299fa8a10a1SNicolas Vasilache rewriter.setInsertionPointAfter(ifOp); 300fa8a10a1SNicolas Vasilache options.warpSyncronizationFn(loc, rewriter, warpOp); 301fa8a10a1SNicolas Vasilache } 302fa8a10a1SNicolas Vasilache 303fa8a10a1SNicolas Vasilache // Step 7. Delete terminator and add empty scf.yield. 304fa8a10a1SNicolas Vasilache rewriter.eraseOp(yieldOp); 305fa8a10a1SNicolas Vasilache rewriter.setInsertionPointToEnd(ifOp.thenBlock()); 306fa8a10a1SNicolas Vasilache rewriter.create<scf::YieldOp>(yieldLoc); 307fa8a10a1SNicolas Vasilache 308fa8a10a1SNicolas Vasilache // Compute replacements for WarpOp results. 309fa8a10a1SNicolas Vasilache rewriter.replaceOp(warpOp, replacements); 310fa8a10a1SNicolas Vasilache 311fa8a10a1SNicolas Vasilache return success(); 312d02f10d9SThomas Raoux } 313d02f10d9SThomas Raoux 314d02f10d9SThomas Raoux private: 315d02f10d9SThomas Raoux const WarpExecuteOnLane0LoweringOptions &options; 316d02f10d9SThomas Raoux }; 317d02f10d9SThomas Raoux 31891f62f0eSThomas Raoux /// Return the distributed vector type based on the original type and the 31991f62f0eSThomas Raoux /// distribution map. The map is expected to have a dimension equal to the 32091f62f0eSThomas Raoux /// original type rank and should be a projection where the results are the 32191f62f0eSThomas Raoux /// distributed dimensions. The number of results should be equal to the number 32291f62f0eSThomas Raoux /// of warp sizes which is currently limited to 1. 32391f62f0eSThomas Raoux /// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1) 32491f62f0eSThomas Raoux /// and a warp size of 16 would distribute the second dimension (associated to 32591f62f0eSThomas Raoux /// d1) and return vector<16x2x64> 32691f62f0eSThomas Raoux static VectorType getDistributedType(VectorType originalType, AffineMap map, 32791f62f0eSThomas Raoux int64_t warpSize) { 3285262865aSKazu Hirata SmallVector<int64_t> targetShape(originalType.getShape()); 32991f62f0eSThomas Raoux for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { 33091f62f0eSThomas Raoux unsigned position = map.getDimPosition(i); 331c2b95292SQuinn Dawkins if (targetShape[position] % warpSize != 0) { 332c2b95292SQuinn Dawkins if (warpSize % targetShape[position] != 0) { 33391f62f0eSThomas Raoux return VectorType(); 334c2b95292SQuinn Dawkins } 335c2b95292SQuinn Dawkins warpSize /= targetShape[position]; 336c2b95292SQuinn Dawkins targetShape[position] = 1; 337c2b95292SQuinn Dawkins continue; 338c2b95292SQuinn Dawkins } 33991f62f0eSThomas Raoux targetShape[position] = targetShape[position] / warpSize; 340c2b95292SQuinn Dawkins warpSize = 1; 341c2b95292SQuinn Dawkins break; 342c2b95292SQuinn Dawkins } 343c2b95292SQuinn Dawkins if (warpSize != 1) { 344c2b95292SQuinn Dawkins return VectorType(); 34591f62f0eSThomas Raoux } 34691f62f0eSThomas Raoux VectorType targetType = 34791f62f0eSThomas Raoux VectorType::get(targetShape, originalType.getElementType()); 34891f62f0eSThomas Raoux return targetType; 34991f62f0eSThomas Raoux } 35091f62f0eSThomas Raoux 351ed0288f7SThomas Raoux /// Distribute transfer_write ops based on the affine map returned by 35280636227SJakub Kuderski /// `distributionMapFn`. Writes of size more than `maxNumElementToExtract` 35380636227SJakub Kuderski /// will not be distributed (it should be less than the warp size). 35480636227SJakub Kuderski /// 355ed0288f7SThomas Raoux /// Example: 356ed0288f7SThomas Raoux /// ``` 357ecaf2c33SPetr Kurapov /// %0 = gpu.warp_execute_on_lane_0(%id){ 358ed0288f7SThomas Raoux /// ... 359ed0288f7SThomas Raoux /// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32> 360ecaf2c33SPetr Kurapov /// gpu.yield 361ed0288f7SThomas Raoux /// } 362ed0288f7SThomas Raoux /// ``` 363ed0288f7SThomas Raoux /// To 364ed0288f7SThomas Raoux /// ``` 365ecaf2c33SPetr Kurapov /// %r:3 = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) { 366ed0288f7SThomas Raoux /// ... 367ecaf2c33SPetr Kurapov /// gpu.yield %v : vector<32xf32> 368ed0288f7SThomas Raoux /// } 369ed0288f7SThomas Raoux /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32> 370*bc29fc93SPetr Kurapov struct WarpOpTransferWrite : public WarpDistributionPattern { 371ed0288f7SThomas Raoux WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn, 37280636227SJakub Kuderski unsigned maxNumElementsToExtract, PatternBenefit b = 1) 373*bc29fc93SPetr Kurapov : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)), 37480636227SJakub Kuderski maxNumElementsToExtract(maxNumElementsToExtract) {} 375ed0288f7SThomas Raoux 376ed0288f7SThomas Raoux /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that 377ed0288f7SThomas Raoux /// are multiples of the distribution ratio are supported at the moment. 378ed0288f7SThomas Raoux LogicalResult tryDistributeOp(RewriterBase &rewriter, 379ed0288f7SThomas Raoux vector::TransferWriteOp writeOp, 380ed0288f7SThomas Raoux WarpExecuteOnLane0Op warpOp) const { 3816a57d8fbSNicolas Vasilache VectorType writtenVectorType = writeOp.getVectorType(); 3826a57d8fbSNicolas Vasilache 3836a57d8fbSNicolas Vasilache // 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op 3846a57d8fbSNicolas Vasilache // to separate it from the rest. 3856a57d8fbSNicolas Vasilache if (writtenVectorType.getRank() == 0) 3866a57d8fbSNicolas Vasilache return failure(); 3876a57d8fbSNicolas Vasilache 38891f62f0eSThomas Raoux // 2. Compute the distributed type. 38991f62f0eSThomas Raoux AffineMap map = distributionMapFn(writeOp.getVector()); 390ed0288f7SThomas Raoux VectorType targetType = 39191f62f0eSThomas Raoux getDistributedType(writtenVectorType, map, warpOp.getWarpSize()); 39291f62f0eSThomas Raoux if (!targetType) 39391f62f0eSThomas Raoux return failure(); 394ed0288f7SThomas Raoux 39525ec1fa9SQuinn Dawkins // 2.5 Compute the distributed type for the new mask; 39625ec1fa9SQuinn Dawkins VectorType maskType; 39725ec1fa9SQuinn Dawkins if (writeOp.getMask()) { 39825ec1fa9SQuinn Dawkins // TODO: Distribution of masked writes with non-trivial permutation maps 39925ec1fa9SQuinn Dawkins // requires the distribution of the mask to elementwise match the 40025ec1fa9SQuinn Dawkins // distribution of the permuted written vector. Currently the details 40125ec1fa9SQuinn Dawkins // of which lane is responsible for which element is captured strictly 40225ec1fa9SQuinn Dawkins // by shape information on the warp op, and thus requires materializing 40325ec1fa9SQuinn Dawkins // the permutation in IR. 40425ec1fa9SQuinn Dawkins if (!writeOp.getPermutationMap().isMinorIdentity()) 40525ec1fa9SQuinn Dawkins return failure(); 40625ec1fa9SQuinn Dawkins maskType = 40725ec1fa9SQuinn Dawkins getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize()); 40825ec1fa9SQuinn Dawkins } 40925ec1fa9SQuinn Dawkins 41091f62f0eSThomas Raoux // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from 4116a57d8fbSNicolas Vasilache // the rest. 4126a57d8fbSNicolas Vasilache vector::TransferWriteOp newWriteOp = 41325ec1fa9SQuinn Dawkins cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType); 414ed0288f7SThomas Raoux 41591f62f0eSThomas Raoux // 4. Reindex the write using the distribution map. 4166a57d8fbSNicolas Vasilache auto newWarpOp = 4176a57d8fbSNicolas Vasilache newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>(); 418c2b95292SQuinn Dawkins 419c2b95292SQuinn Dawkins // Delinearize the lane id based on the way threads are divided across the 420c2b95292SQuinn Dawkins // vector. To get the number of threads per vector dimension, divide the 421c2b95292SQuinn Dawkins // sequential size by the distributed size along each dim. 422ed0288f7SThomas Raoux rewriter.setInsertionPoint(newWriteOp); 423c2b95292SQuinn Dawkins SmallVector<OpFoldResult> delinearizedIdSizes; 424c2b95292SQuinn Dawkins for (auto [seqSize, distSize] : 425c2b95292SQuinn Dawkins llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) { 426c2b95292SQuinn Dawkins assert(seqSize % distSize == 0 && "Invalid distributed vector shape"); 427c2b95292SQuinn Dawkins delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize)); 428c2b95292SQuinn Dawkins } 429c2b95292SQuinn Dawkins SmallVector<Value> delinearized; 430c2b95292SQuinn Dawkins if (map.getNumResults() > 1) { 431c2b95292SQuinn Dawkins delinearized = rewriter 432c2b95292SQuinn Dawkins .create<mlir::affine::AffineDelinearizeIndexOp>( 433c2b95292SQuinn Dawkins newWarpOp.getLoc(), newWarpOp.getLaneid(), 434c2b95292SQuinn Dawkins delinearizedIdSizes) 435c2b95292SQuinn Dawkins .getResults(); 436c2b95292SQuinn Dawkins } else { 437c2b95292SQuinn Dawkins // If there is only one map result, we can elide the delinearization 438c2b95292SQuinn Dawkins // op and use the lane id directly. 439c2b95292SQuinn Dawkins delinearized.append(targetType.getRank(), newWarpOp.getLaneid()); 440c2b95292SQuinn Dawkins } 441c2b95292SQuinn Dawkins 442ed0288f7SThomas Raoux AffineMap indexMap = map.compose(newWriteOp.getPermutationMap()); 443ed0288f7SThomas Raoux Location loc = newWriteOp.getLoc(); 444ed0288f7SThomas Raoux SmallVector<Value> indices(newWriteOp.getIndices().begin(), 445ed0288f7SThomas Raoux newWriteOp.getIndices().end()); 446ed0288f7SThomas Raoux for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) { 447ed0288f7SThomas Raoux AffineExpr d0, d1; 448ed0288f7SThomas Raoux bindDims(newWarpOp.getContext(), d0, d1); 4491609f1c2Slong.chen auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it)); 450ed0288f7SThomas Raoux if (!indexExpr) 451ed0288f7SThomas Raoux continue; 452ed0288f7SThomas Raoux unsigned indexPos = indexExpr.getPosition(); 4531609f1c2Slong.chen unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition(); 454c2b95292SQuinn Dawkins Value laneId = delinearized[vectorPos]; 45591f62f0eSThomas Raoux auto scale = 45691f62f0eSThomas Raoux rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos)); 4574c48f016SMatthias Springer indices[indexPos] = affine::makeComposedAffineApply( 458c2b95292SQuinn Dawkins rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId}); 459ed0288f7SThomas Raoux } 460ed0288f7SThomas Raoux newWriteOp.getIndicesMutable().assign(indices); 461ed0288f7SThomas Raoux 462ed0288f7SThomas Raoux return success(); 463ed0288f7SThomas Raoux } 464ed0288f7SThomas Raoux 465ed0288f7SThomas Raoux /// Extract TransferWriteOps of vector<1x> into a separate warp op. 466ed0288f7SThomas Raoux LogicalResult tryExtractOp(RewriterBase &rewriter, 467ed0288f7SThomas Raoux vector::TransferWriteOp writeOp, 468ed0288f7SThomas Raoux WarpExecuteOnLane0Op warpOp) const { 469ed0288f7SThomas Raoux Location loc = writeOp.getLoc(); 470ed0288f7SThomas Raoux VectorType vecType = writeOp.getVectorType(); 471ed0288f7SThomas Raoux 47280636227SJakub Kuderski if (vecType.getNumElements() > maxNumElementsToExtract) { 47380636227SJakub Kuderski return rewriter.notifyMatchFailure( 47480636227SJakub Kuderski warpOp, 47580636227SJakub Kuderski llvm::formatv( 47680636227SJakub Kuderski "writes more elements ({0}) than allowed to extract ({1})", 47780636227SJakub Kuderski vecType.getNumElements(), maxNumElementsToExtract)); 47880636227SJakub Kuderski } 479ed0288f7SThomas Raoux 480ed0288f7SThomas Raoux // Do not process warp ops that contain only TransferWriteOps. 481971b8525SJakub Kuderski if (llvm::all_of(warpOp.getOps(), 482ecaf2c33SPetr Kurapov llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>)) 483ed0288f7SThomas Raoux return failure(); 484ed0288f7SThomas Raoux 485ed0288f7SThomas Raoux SmallVector<Value> yieldValues = {writeOp.getVector()}; 486ed0288f7SThomas Raoux SmallVector<Type> retTypes = {vecType}; 487d7d6443dSThomas Raoux SmallVector<size_t> newRetIndices; 488ed0288f7SThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 489d7d6443dSThomas Raoux rewriter, warpOp, yieldValues, retTypes, newRetIndices); 490ed0288f7SThomas Raoux rewriter.setInsertionPointAfter(newWarpOp); 491ed0288f7SThomas Raoux 492ed0288f7SThomas Raoux // Create a second warp op that contains only writeOp. 493ed0288f7SThomas Raoux auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>( 494ed0288f7SThomas Raoux loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize()); 495ed0288f7SThomas Raoux Block &body = secondWarpOp.getBodyRegion().front(); 496ed0288f7SThomas Raoux rewriter.setInsertionPointToStart(&body); 497ed0288f7SThomas Raoux auto newWriteOp = 498ed0288f7SThomas Raoux cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation())); 499d7d6443dSThomas Raoux newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0])); 500ed0288f7SThomas Raoux rewriter.eraseOp(writeOp); 501ecaf2c33SPetr Kurapov rewriter.create<gpu::YieldOp>(newWarpOp.getLoc()); 502ed0288f7SThomas Raoux return success(); 503ed0288f7SThomas Raoux } 504ed0288f7SThomas Raoux 505df49a97aSQuinn Dawkins LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 506ed0288f7SThomas Raoux PatternRewriter &rewriter) const override { 507ecaf2c33SPetr Kurapov auto yield = cast<gpu::YieldOp>( 508df49a97aSQuinn Dawkins warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); 509df49a97aSQuinn Dawkins Operation *lastNode = yield->getPrevNode(); 510df49a97aSQuinn Dawkins auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode); 511df49a97aSQuinn Dawkins if (!writeOp) 512ed0288f7SThomas Raoux return failure(); 513ed0288f7SThomas Raoux 51425ec1fa9SQuinn Dawkins Value maybeMask = writeOp.getMask(); 515ed0288f7SThomas Raoux if (!llvm::all_of(writeOp->getOperands(), [&](Value value) { 516ed0288f7SThomas Raoux return writeOp.getVector() == value || 51725ec1fa9SQuinn Dawkins (maybeMask && maybeMask == value) || 518ed0288f7SThomas Raoux warpOp.isDefinedOutsideOfRegion(value); 519ed0288f7SThomas Raoux })) 520ed0288f7SThomas Raoux return failure(); 521ed0288f7SThomas Raoux 522ed0288f7SThomas Raoux if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp))) 523ed0288f7SThomas Raoux return success(); 524ed0288f7SThomas Raoux 52525ec1fa9SQuinn Dawkins // Masked writes not supported for extraction. 52625ec1fa9SQuinn Dawkins if (writeOp.getMask()) 52725ec1fa9SQuinn Dawkins return failure(); 52825ec1fa9SQuinn Dawkins 529ed0288f7SThomas Raoux if (succeeded(tryExtractOp(rewriter, writeOp, warpOp))) 530ed0288f7SThomas Raoux return success(); 531ed0288f7SThomas Raoux 532ed0288f7SThomas Raoux return failure(); 533ed0288f7SThomas Raoux } 534ed0288f7SThomas Raoux 535ed0288f7SThomas Raoux private: 536*bc29fc93SPetr Kurapov /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp 537*bc29fc93SPetr Kurapov /// execute op with the proper return type. The new write op is updated to 538*bc29fc93SPetr Kurapov /// write the result of the new warp execute op. The old `writeOp` is deleted. 539*bc29fc93SPetr Kurapov vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter, 540*bc29fc93SPetr Kurapov WarpExecuteOnLane0Op warpOp, 541*bc29fc93SPetr Kurapov vector::TransferWriteOp writeOp, 542*bc29fc93SPetr Kurapov VectorType targetType, 543*bc29fc93SPetr Kurapov VectorType maybeMaskType) const { 544*bc29fc93SPetr Kurapov assert(writeOp->getParentOp() == warpOp && 545*bc29fc93SPetr Kurapov "write must be nested immediately under warp"); 546*bc29fc93SPetr Kurapov OpBuilder::InsertionGuard g(rewriter); 547*bc29fc93SPetr Kurapov SmallVector<size_t> newRetIndices; 548*bc29fc93SPetr Kurapov WarpExecuteOnLane0Op newWarpOp; 549*bc29fc93SPetr Kurapov if (maybeMaskType) { 550*bc29fc93SPetr Kurapov newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 551*bc29fc93SPetr Kurapov rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()}, 552*bc29fc93SPetr Kurapov TypeRange{targetType, maybeMaskType}, newRetIndices); 553*bc29fc93SPetr Kurapov } else { 554*bc29fc93SPetr Kurapov newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 555*bc29fc93SPetr Kurapov rewriter, warpOp, ValueRange{{writeOp.getVector()}}, 556*bc29fc93SPetr Kurapov TypeRange{targetType}, newRetIndices); 557*bc29fc93SPetr Kurapov } 558*bc29fc93SPetr Kurapov rewriter.setInsertionPointAfter(newWarpOp); 559*bc29fc93SPetr Kurapov auto newWriteOp = 560*bc29fc93SPetr Kurapov cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation())); 561*bc29fc93SPetr Kurapov rewriter.eraseOp(writeOp); 562*bc29fc93SPetr Kurapov newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0])); 563*bc29fc93SPetr Kurapov if (maybeMaskType) 564*bc29fc93SPetr Kurapov newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1])); 565*bc29fc93SPetr Kurapov return newWriteOp; 566*bc29fc93SPetr Kurapov } 567*bc29fc93SPetr Kurapov 568ed0288f7SThomas Raoux DistributionMapFn distributionMapFn; 56980636227SJakub Kuderski unsigned maxNumElementsToExtract = 1; 570ed0288f7SThomas Raoux }; 571ed0288f7SThomas Raoux 57276cf33daSThomas Raoux /// Sink out elementwise op feeding into a warp op yield. 57376cf33daSThomas Raoux /// ``` 574ecaf2c33SPetr Kurapov /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { 57576cf33daSThomas Raoux /// ... 57676cf33daSThomas Raoux /// %3 = arith.addf %1, %2 : vector<32xf32> 577ecaf2c33SPetr Kurapov /// gpu.yield %3 : vector<32xf32> 57876cf33daSThomas Raoux /// } 57976cf33daSThomas Raoux /// ``` 58076cf33daSThomas Raoux /// To 58176cf33daSThomas Raoux /// ``` 582ecaf2c33SPetr Kurapov /// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>, 58376cf33daSThomas Raoux /// vector<1xf32>, vector<1xf32>) { 58476cf33daSThomas Raoux /// ... 58576cf33daSThomas Raoux /// %4 = arith.addf %2, %3 : vector<32xf32> 586ecaf2c33SPetr Kurapov /// gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>, 58776cf33daSThomas Raoux /// vector<32xf32> 58876cf33daSThomas Raoux /// } 58976cf33daSThomas Raoux /// %0 = arith.addf %r#1, %r#2 : vector<1xf32> 590*bc29fc93SPetr Kurapov struct WarpOpElementwise : public WarpDistributionPattern { 591*bc29fc93SPetr Kurapov using Base::Base; 59276cf33daSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 59376cf33daSThomas Raoux PatternRewriter &rewriter) const override { 59476cf33daSThomas Raoux OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) { 59576cf33daSThomas Raoux return OpTrait::hasElementwiseMappableTraits(op); 59676cf33daSThomas Raoux }); 59776cf33daSThomas Raoux if (!yieldOperand) 59876cf33daSThomas Raoux return failure(); 599aa2376a0SQuinn Dawkins 60076cf33daSThomas Raoux Operation *elementWise = yieldOperand->get().getDefiningOp(); 60176cf33daSThomas Raoux unsigned operandIndex = yieldOperand->getOperandNumber(); 60276cf33daSThomas Raoux Value distributedVal = warpOp.getResult(operandIndex); 60376cf33daSThomas Raoux SmallVector<Value> yieldValues; 60476cf33daSThomas Raoux SmallVector<Type> retTypes; 60576cf33daSThomas Raoux Location loc = warpOp.getLoc(); 60676cf33daSThomas Raoux for (OpOperand &operand : elementWise->getOpOperands()) { 60776cf33daSThomas Raoux Type targetType; 6085550c821STres Popp if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) { 60976cf33daSThomas Raoux // If the result type is a vector, the operands must also be vectors. 6105550c821STres Popp auto operandType = cast<VectorType>(operand.get().getType()); 61176cf33daSThomas Raoux targetType = 61276cf33daSThomas Raoux VectorType::get(vecType.getShape(), operandType.getElementType()); 61376cf33daSThomas Raoux } else { 61476cf33daSThomas Raoux auto operandType = operand.get().getType(); 6155550c821STres Popp assert(!isa<VectorType>(operandType) && 61676cf33daSThomas Raoux "unexpected yield of vector from op with scalar result type"); 61776cf33daSThomas Raoux targetType = operandType; 61876cf33daSThomas Raoux } 61976cf33daSThomas Raoux retTypes.push_back(targetType); 62076cf33daSThomas Raoux yieldValues.push_back(operand.get()); 62176cf33daSThomas Raoux } 622d7d6443dSThomas Raoux SmallVector<size_t> newRetIndices; 62376cf33daSThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 624d7d6443dSThomas Raoux rewriter, warpOp, yieldValues, retTypes, newRetIndices); 62576cf33daSThomas Raoux rewriter.setInsertionPointAfter(newWarpOp); 62676cf33daSThomas Raoux SmallVector<Value> newOperands(elementWise->getOperands().begin(), 62776cf33daSThomas Raoux elementWise->getOperands().end()); 62876cf33daSThomas Raoux for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) { 629d7d6443dSThomas Raoux newOperands[i] = newWarpOp.getResult(newRetIndices[i]); 63076cf33daSThomas Raoux } 63176cf33daSThomas Raoux OpBuilder::InsertionGuard g(rewriter); 63276cf33daSThomas Raoux rewriter.setInsertionPointAfter(newWarpOp); 63376cf33daSThomas Raoux Operation *newOp = cloneOpWithOperandsAndTypes( 63476cf33daSThomas Raoux rewriter, loc, elementWise, newOperands, 63576cf33daSThomas Raoux {newWarpOp.getResult(operandIndex).getType()}); 6367ecc921dSMatthias Springer rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), 6377ecc921dSMatthias Springer newOp->getResult(0)); 63876cf33daSThomas Raoux return success(); 63976cf33daSThomas Raoux } 64076cf33daSThomas Raoux }; 64176cf33daSThomas Raoux 6420af26805SThomas Raoux /// Sink out splat constant op feeding into a warp op yield. 6430af26805SThomas Raoux /// ``` 644ecaf2c33SPetr Kurapov /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { 6450af26805SThomas Raoux /// ... 6460af26805SThomas Raoux /// %cst = arith.constant dense<2.0> : vector<32xf32> 647ecaf2c33SPetr Kurapov /// gpu.yield %cst : vector<32xf32> 6480af26805SThomas Raoux /// } 6490af26805SThomas Raoux /// ``` 6500af26805SThomas Raoux /// To 6510af26805SThomas Raoux /// ``` 652ecaf2c33SPetr Kurapov /// gpu.warp_execute_on_lane_0(%arg0 { 6530af26805SThomas Raoux /// ... 6540af26805SThomas Raoux /// } 6550af26805SThomas Raoux /// %0 = arith.constant dense<2.0> : vector<1xf32> 656*bc29fc93SPetr Kurapov struct WarpOpConstant : public WarpDistributionPattern { 657*bc29fc93SPetr Kurapov using Base::Base; 6580af26805SThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 6590af26805SThomas Raoux PatternRewriter &rewriter) const override { 660971b8525SJakub Kuderski OpOperand *yieldOperand = 661971b8525SJakub Kuderski getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>); 6620af26805SThomas Raoux if (!yieldOperand) 6630af26805SThomas Raoux return failure(); 6640af26805SThomas Raoux auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>(); 6655550c821STres Popp auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue()); 6660af26805SThomas Raoux if (!dense) 6670af26805SThomas Raoux return failure(); 668aa2376a0SQuinn Dawkins // Notify the rewriter that the warp op is changing (see the comment on 669aa2376a0SQuinn Dawkins // the WarpOpTransferRead pattern). 6705fcf907bSMatthias Springer rewriter.startOpModification(warpOp); 6710af26805SThomas Raoux unsigned operandIndex = yieldOperand->getOperandNumber(); 6720af26805SThomas Raoux Attribute scalarAttr = dense.getSplatValue<Attribute>(); 6736089d612SRahul Kayaith auto newAttr = DenseElementsAttr::get( 6746089d612SRahul Kayaith cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr); 6750af26805SThomas Raoux Location loc = warpOp.getLoc(); 6760af26805SThomas Raoux rewriter.setInsertionPointAfter(warpOp); 6770af26805SThomas Raoux Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr); 6787ecc921dSMatthias Springer rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant); 6795fcf907bSMatthias Springer rewriter.finalizeOpModification(warpOp); 6800af26805SThomas Raoux return success(); 6810af26805SThomas Raoux } 6820af26805SThomas Raoux }; 6830af26805SThomas Raoux 68476cf33daSThomas Raoux /// Sink out transfer_read op feeding into a warp op yield. 68576cf33daSThomas Raoux /// ``` 686ecaf2c33SPetr Kurapov /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { 68776cf33daSThomas Raoux /// ... 68876cf33daSThomas Raoux // %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, 68976cf33daSThomas Raoux // vector<32xf32> 690ecaf2c33SPetr Kurapov /// gpu.yield %2 : vector<32xf32> 69176cf33daSThomas Raoux /// } 69276cf33daSThomas Raoux /// ``` 69376cf33daSThomas Raoux /// To 69476cf33daSThomas Raoux /// ``` 695ecaf2c33SPetr Kurapov /// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>, 69676cf33daSThomas Raoux /// vector<1xf32>, vector<1xf32>) { 69776cf33daSThomas Raoux /// ... 69876cf33daSThomas Raoux /// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, 699ecaf2c33SPetr Kurapov /// vector<32xf32> gpu.yield %2 : vector<32xf32> 70076cf33daSThomas Raoux /// } 70176cf33daSThomas Raoux /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32> 702*bc29fc93SPetr Kurapov struct WarpOpTransferRead : public WarpDistributionPattern { 703*bc29fc93SPetr Kurapov using Base::Base; 70476cf33daSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 70576cf33daSThomas Raoux PatternRewriter &rewriter) const override { 7067360d5d3SQuinn Dawkins // Try to find a distributable yielded read. Note that this pattern can 7077360d5d3SQuinn Dawkins // still fail at the end after distribution, in which case this might have 7087360d5d3SQuinn Dawkins // missed another distributable read. 7097360d5d3SQuinn Dawkins OpOperand *operand = getWarpResult(warpOp, [](Operation *op) { 7107360d5d3SQuinn Dawkins // Don't duplicate transfer_read ops when distributing. 7117360d5d3SQuinn Dawkins return isa<vector::TransferReadOp>(op) && op->hasOneUse(); 7127360d5d3SQuinn Dawkins }); 71376cf33daSThomas Raoux if (!operand) 71435c19fddSMatthias Springer return rewriter.notifyMatchFailure( 71535c19fddSMatthias Springer warpOp, "warp result is not a vector.transfer_read op"); 71676cf33daSThomas Raoux auto read = operand->get().getDefiningOp<vector::TransferReadOp>(); 7177360d5d3SQuinn Dawkins 71835c19fddSMatthias Springer // Source must be defined outside of the region. 71935c19fddSMatthias Springer if (!warpOp.isDefinedOutsideOfRegion(read.getSource())) 72035c19fddSMatthias Springer return rewriter.notifyMatchFailure( 72135c19fddSMatthias Springer read, "source must be defined outside of the region"); 72235c19fddSMatthias Springer 72376cf33daSThomas Raoux unsigned operandIndex = operand->getOperandNumber(); 72476cf33daSThomas Raoux Value distributedVal = warpOp.getResult(operandIndex); 72576cf33daSThomas Raoux 72676cf33daSThomas Raoux SmallVector<Value, 4> indices(read.getIndices().begin(), 72776cf33daSThomas Raoux read.getIndices().end()); 7285550c821STres Popp auto sequentialType = cast<VectorType>(read.getResult().getType()); 7295550c821STres Popp auto distributedType = cast<VectorType>(distributedVal.getType()); 7304abb9e5dSThomas Raoux AffineMap map = calculateImplicitMap(sequentialType, distributedType); 73176cf33daSThomas Raoux AffineMap indexMap = map.compose(read.getPermutationMap()); 732771f5759SQuinn Dawkins 73335c19fddSMatthias Springer // Try to delinearize the lane ID to match the rank expected for 73435c19fddSMatthias Springer // distribution. 73535c19fddSMatthias Springer SmallVector<Value> delinearizedIds; 73635c19fddSMatthias Springer if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(), 73735c19fddSMatthias Springer distributedType.getShape(), warpOp.getWarpSize(), 73835c19fddSMatthias Springer warpOp.getLaneid(), delinearizedIds)) { 73935c19fddSMatthias Springer return rewriter.notifyMatchFailure( 74035c19fddSMatthias Springer read, "cannot delinearize lane ID for distribution"); 74135c19fddSMatthias Springer } 74235c19fddSMatthias Springer assert(!delinearizedIds.empty() || map.getNumResults() == 0); 74335c19fddSMatthias Springer 74435c19fddSMatthias Springer // Distribute indices and the mask (if present). 74576cf33daSThomas Raoux OpBuilder::InsertionGuard g(rewriter); 74635c19fddSMatthias Springer SmallVector<Value> additionalResults(indices.begin(), indices.end()); 74735c19fddSMatthias Springer SmallVector<Type> additionalResultTypes(indices.size(), 74835c19fddSMatthias Springer rewriter.getIndexType()); 74935c19fddSMatthias Springer additionalResults.push_back(read.getPadding()); 75035c19fddSMatthias Springer additionalResultTypes.push_back(read.getPadding().getType()); 75135c19fddSMatthias Springer 752aa2376a0SQuinn Dawkins bool hasMask = false; 753771f5759SQuinn Dawkins if (read.getMask()) { 754aa2376a0SQuinn Dawkins hasMask = true; 755771f5759SQuinn Dawkins // TODO: Distribution of masked reads with non-trivial permutation maps 756771f5759SQuinn Dawkins // requires the distribution of the mask to elementwise match the 757771f5759SQuinn Dawkins // distribution of the permuted written vector. Currently the details 758771f5759SQuinn Dawkins // of which lane is responsible for which element is captured strictly 759771f5759SQuinn Dawkins // by shape information on the warp op, and thus requires materializing 760771f5759SQuinn Dawkins // the permutation in IR. 761f385f6c9SQuinn Dawkins if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity()) 76235c19fddSMatthias Springer return rewriter.notifyMatchFailure( 76335c19fddSMatthias Springer read, "non-trivial permutation maps not supported"); 764771f5759SQuinn Dawkins VectorType maskType = 765771f5759SQuinn Dawkins getDistributedType(read.getMaskType(), map, warpOp.getWarpSize()); 76635c19fddSMatthias Springer additionalResults.push_back(read.getMask()); 76735c19fddSMatthias Springer additionalResultTypes.push_back(maskType); 768771f5759SQuinn Dawkins } 769771f5759SQuinn Dawkins 77035c19fddSMatthias Springer SmallVector<size_t> newRetIndices; 77135c19fddSMatthias Springer WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 77235c19fddSMatthias Springer rewriter, warpOp, additionalResults, additionalResultTypes, 77335c19fddSMatthias Springer newRetIndices); 77435c19fddSMatthias Springer distributedVal = newWarpOp.getResult(operandIndex); 77535c19fddSMatthias Springer 77635c19fddSMatthias Springer // Distributed indices were appended first. 77735c19fddSMatthias Springer SmallVector<Value> newIndices; 77835c19fddSMatthias Springer for (int64_t i = 0, e = indices.size(); i < e; ++i) 77935c19fddSMatthias Springer newIndices.push_back(newWarpOp.getResult(newRetIndices[i])); 78035c19fddSMatthias Springer 781771f5759SQuinn Dawkins rewriter.setInsertionPointAfter(newWarpOp); 782199442eaSLei Zhang for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) { 78376cf33daSThomas Raoux AffineExpr d0, d1; 78476cf33daSThomas Raoux bindDims(read.getContext(), d0, d1); 7851609f1c2Slong.chen auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it)); 78676cf33daSThomas Raoux if (!indexExpr) 78776cf33daSThomas Raoux continue; 78876cf33daSThomas Raoux unsigned indexPos = indexExpr.getPosition(); 7891609f1c2Slong.chen unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition(); 79073ddc447SLei Zhang int64_t scale = distributedType.getDimSize(vectorPos); 79135c19fddSMatthias Springer newIndices[indexPos] = affine::makeComposedAffineApply( 7924c48f016SMatthias Springer rewriter, read.getLoc(), d0 + scale * d1, 79335c19fddSMatthias Springer {newIndices[indexPos], delinearizedIds[vectorPos]}); 79476cf33daSThomas Raoux } 79535c19fddSMatthias Springer 79635c19fddSMatthias Springer // Distributed padding value was appended right after the indices. 79735c19fddSMatthias Springer Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]); 79835c19fddSMatthias Springer // Distributed mask value was added at the end (if the op has a mask). 79935c19fddSMatthias Springer Value newMask = 80035c19fddSMatthias Springer hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1]) 80135c19fddSMatthias Springer : Value(); 802018d8ac9SQuentin Colombet auto newRead = rewriter.create<vector::TransferReadOp>( 80335c19fddSMatthias Springer read.getLoc(), distributedVal.getType(), read.getSource(), newIndices, 80435c19fddSMatthias Springer read.getPermutationMapAttr(), newPadding, newMask, 80576cf33daSThomas Raoux read.getInBoundsAttr()); 806018d8ac9SQuentin Colombet 8077ecc921dSMatthias Springer rewriter.replaceAllUsesWith(distributedVal, newRead); 80876cf33daSThomas Raoux return success(); 80976cf33daSThomas Raoux } 81076cf33daSThomas Raoux }; 81176cf33daSThomas Raoux 81276cf33daSThomas Raoux /// Remove any result that has no use along with the matching yieldOp operand. 81376cf33daSThomas Raoux // TODO: Move this in WarpExecuteOnLane0Op canonicalization. 814*bc29fc93SPetr Kurapov struct WarpOpDeadResult : public WarpDistributionPattern { 815*bc29fc93SPetr Kurapov using Base::Base; 81676cf33daSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 81776cf33daSThomas Raoux PatternRewriter &rewriter) const override { 81820df17fdSNicolas Vasilache SmallVector<Type> newResultTypes; 81920df17fdSNicolas Vasilache newResultTypes.reserve(warpOp->getNumResults()); 82020df17fdSNicolas Vasilache SmallVector<Value> newYieldValues; 82120df17fdSNicolas Vasilache newYieldValues.reserve(warpOp->getNumResults()); 82220df17fdSNicolas Vasilache DenseMap<Value, int64_t> dedupYieldOperandPositionMap; 82320df17fdSNicolas Vasilache DenseMap<OpResult, int64_t> dedupResultPositionMap; 824ecaf2c33SPetr Kurapov auto yield = cast<gpu::YieldOp>( 82576cf33daSThomas Raoux warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); 82620df17fdSNicolas Vasilache 82720df17fdSNicolas Vasilache // Some values may be yielded multiple times and correspond to multiple 82820df17fdSNicolas Vasilache // results. Deduplicating occurs by taking each result with its matching 82920df17fdSNicolas Vasilache // yielded value, and: 83020df17fdSNicolas Vasilache // 1. recording the unique first position at which the value is yielded. 83120df17fdSNicolas Vasilache // 2. recording for the result, the first position at which the dedup'ed 83220df17fdSNicolas Vasilache // value is yielded. 83320df17fdSNicolas Vasilache // 3. skipping from the new result types / new yielded values any result 83420df17fdSNicolas Vasilache // that has no use or whose yielded value has already been seen. 83576cf33daSThomas Raoux for (OpResult result : warpOp.getResults()) { 83620df17fdSNicolas Vasilache Value yieldOperand = yield.getOperand(result.getResultNumber()); 83720df17fdSNicolas Vasilache auto it = dedupYieldOperandPositionMap.insert( 83820df17fdSNicolas Vasilache std::make_pair(yieldOperand, newResultTypes.size())); 83920df17fdSNicolas Vasilache dedupResultPositionMap.insert(std::make_pair(result, it.first->second)); 84020df17fdSNicolas Vasilache if (result.use_empty() || !it.second) 84176cf33daSThomas Raoux continue; 84220df17fdSNicolas Vasilache newResultTypes.push_back(result.getType()); 84320df17fdSNicolas Vasilache newYieldValues.push_back(yieldOperand); 84476cf33daSThomas Raoux } 84520df17fdSNicolas Vasilache // No modification, exit early. 84620df17fdSNicolas Vasilache if (yield.getNumOperands() == newYieldValues.size()) 84776cf33daSThomas Raoux return failure(); 84820df17fdSNicolas Vasilache // Move the body of the old warpOp to a new warpOp. 84976cf33daSThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( 85020df17fdSNicolas Vasilache rewriter, warpOp, newYieldValues, newResultTypes); 8517360d5d3SQuinn Dawkins 8527360d5d3SQuinn Dawkins // Simplify the new warp op after dropping dead results. 8537360d5d3SQuinn Dawkins newWarpOp.getBody()->walk([&](Operation *op) { 8547360d5d3SQuinn Dawkins if (isOpTriviallyDead(op)) 8557360d5d3SQuinn Dawkins rewriter.eraseOp(op); 8567360d5d3SQuinn Dawkins }); 8577360d5d3SQuinn Dawkins 85820df17fdSNicolas Vasilache // Replace results of the old warpOp by the new, deduplicated results. 85920df17fdSNicolas Vasilache SmallVector<Value> newValues; 86020df17fdSNicolas Vasilache newValues.reserve(warpOp->getNumResults()); 86176cf33daSThomas Raoux for (OpResult result : warpOp.getResults()) { 86276cf33daSThomas Raoux if (result.use_empty()) 86320df17fdSNicolas Vasilache newValues.push_back(Value()); 86420df17fdSNicolas Vasilache else 86520df17fdSNicolas Vasilache newValues.push_back( 86620df17fdSNicolas Vasilache newWarpOp.getResult(dedupResultPositionMap.lookup(result))); 86776cf33daSThomas Raoux } 86820df17fdSNicolas Vasilache rewriter.replaceOp(warpOp, newValues); 86976cf33daSThomas Raoux return success(); 87076cf33daSThomas Raoux } 87176cf33daSThomas Raoux }; 87276cf33daSThomas Raoux 87376cf33daSThomas Raoux // If an operand is directly yielded out of the region we can forward it 87476cf33daSThomas Raoux // directly and it doesn't need to go through the region. 875*bc29fc93SPetr Kurapov struct WarpOpForwardOperand : public WarpDistributionPattern { 876*bc29fc93SPetr Kurapov using Base::Base; 87776cf33daSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 87876cf33daSThomas Raoux PatternRewriter &rewriter) const override { 87976cf33daSThomas Raoux SmallVector<Type> resultTypes; 88076cf33daSThomas Raoux SmallVector<Value> yieldValues; 881ecaf2c33SPetr Kurapov auto yield = cast<gpu::YieldOp>( 88276cf33daSThomas Raoux warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); 88376cf33daSThomas Raoux Value valForwarded; 88476cf33daSThomas Raoux unsigned resultIndex; 88576cf33daSThomas Raoux for (OpOperand &operand : yield->getOpOperands()) { 88676cf33daSThomas Raoux Value result = warpOp.getResult(operand.getOperandNumber()); 88776cf33daSThomas Raoux if (result.use_empty()) 88876cf33daSThomas Raoux continue; 88976cf33daSThomas Raoux 89076cf33daSThomas Raoux // Assume all the values coming from above are uniform. 89176cf33daSThomas Raoux if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) { 89276cf33daSThomas Raoux if (result.getType() != operand.get().getType()) 89376cf33daSThomas Raoux continue; 89476cf33daSThomas Raoux valForwarded = operand.get(); 89576cf33daSThomas Raoux resultIndex = operand.getOperandNumber(); 89676cf33daSThomas Raoux break; 89776cf33daSThomas Raoux } 8985550c821STres Popp auto arg = dyn_cast<BlockArgument>(operand.get()); 89976cf33daSThomas Raoux if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation()) 90076cf33daSThomas Raoux continue; 90176cf33daSThomas Raoux Value warpOperand = warpOp.getArgs()[arg.getArgNumber()]; 90276cf33daSThomas Raoux if (result.getType() != warpOperand.getType()) 90376cf33daSThomas Raoux continue; 90476cf33daSThomas Raoux valForwarded = warpOperand; 90576cf33daSThomas Raoux resultIndex = operand.getOperandNumber(); 90676cf33daSThomas Raoux break; 90776cf33daSThomas Raoux } 90876cf33daSThomas Raoux if (!valForwarded) 90976cf33daSThomas Raoux return failure(); 910aa2376a0SQuinn Dawkins // Notify the rewriter that the warp op is changing (see the comment on 911aa2376a0SQuinn Dawkins // the WarpOpTransferRead pattern). 9125fcf907bSMatthias Springer rewriter.startOpModification(warpOp); 9137ecc921dSMatthias Springer rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded); 9145fcf907bSMatthias Springer rewriter.finalizeOpModification(warpOp); 91576cf33daSThomas Raoux return success(); 91676cf33daSThomas Raoux } 91776cf33daSThomas Raoux }; 91876cf33daSThomas Raoux 919*bc29fc93SPetr Kurapov struct WarpOpBroadcast : public WarpDistributionPattern { 920*bc29fc93SPetr Kurapov using Base::Base; 92176cf33daSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 92276cf33daSThomas Raoux PatternRewriter &rewriter) const override { 923971b8525SJakub Kuderski OpOperand *operand = 924971b8525SJakub Kuderski getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>); 92576cf33daSThomas Raoux if (!operand) 92676cf33daSThomas Raoux return failure(); 92776cf33daSThomas Raoux unsigned int operandNumber = operand->getOperandNumber(); 92876cf33daSThomas Raoux auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>(); 92976cf33daSThomas Raoux Location loc = broadcastOp.getLoc(); 93076cf33daSThomas Raoux auto destVecType = 9315550c821STres Popp cast<VectorType>(warpOp->getResultTypes()[operandNumber]); 9321dd00d39SQuentin Colombet Value broadcastSrc = broadcastOp.getSource(); 9331dd00d39SQuentin Colombet Type broadcastSrcType = broadcastSrc.getType(); 9341dd00d39SQuentin Colombet 9351dd00d39SQuentin Colombet // Check that the broadcast actually spans a set of values uniformly across 9361dd00d39SQuentin Colombet // all threads. In other words, check that each thread can reconstruct 9371dd00d39SQuentin Colombet // their own broadcast. 9381dd00d39SQuentin Colombet // For that we simply check that the broadcast we want to build makes sense. 9391dd00d39SQuentin Colombet if (vector::isBroadcastableTo(broadcastSrcType, destVecType) != 9401dd00d39SQuentin Colombet vector::BroadcastableToResult::Success) 9411dd00d39SQuentin Colombet return failure(); 942d7d6443dSThomas Raoux SmallVector<size_t> newRetIndices; 94376cf33daSThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 9441dd00d39SQuentin Colombet rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices); 94576cf33daSThomas Raoux rewriter.setInsertionPointAfter(newWarpOp); 94676cf33daSThomas Raoux Value broadcasted = rewriter.create<vector::BroadcastOp>( 947d7d6443dSThomas Raoux loc, destVecType, newWarpOp->getResult(newRetIndices[0])); 9487ecc921dSMatthias Springer rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), 9497ecc921dSMatthias Springer broadcasted); 95076cf33daSThomas Raoux return success(); 95176cf33daSThomas Raoux } 95276cf33daSThomas Raoux }; 95376cf33daSThomas Raoux 95473ddc447SLei Zhang /// Pattern to move shape cast out of the warp op. shape cast is basically a 95573ddc447SLei Zhang /// no-op for warp distribution; we need to handle the shape though. 956*bc29fc93SPetr Kurapov struct WarpOpShapeCast : public WarpDistributionPattern { 957*bc29fc93SPetr Kurapov using Base::Base; 95873ddc447SLei Zhang LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 95973ddc447SLei Zhang PatternRewriter &rewriter) const override { 960971b8525SJakub Kuderski OpOperand *operand = 961971b8525SJakub Kuderski getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>); 96273ddc447SLei Zhang if (!operand) 96373ddc447SLei Zhang return failure(); 964aa2376a0SQuinn Dawkins 96573ddc447SLei Zhang auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>(); 96673ddc447SLei Zhang 96773ddc447SLei Zhang unsigned int operandNumber = operand->getOperandNumber(); 96873ddc447SLei Zhang auto castDistributedType = 96973ddc447SLei Zhang cast<VectorType>(warpOp->getResultTypes()[operandNumber]); 97073ddc447SLei Zhang VectorType castOriginalType = oldCastOp.getSourceVectorType(); 97173ddc447SLei Zhang VectorType castResultType = castDistributedType; 97273ddc447SLei Zhang 97373ddc447SLei Zhang // We expect the distributed type to have a smaller rank than the original 97473ddc447SLei Zhang // type. Prepend with size-one dimensions to make them the same. 97573ddc447SLei Zhang unsigned castDistributedRank = castDistributedType.getRank(); 97673ddc447SLei Zhang unsigned castOriginalRank = castOriginalType.getRank(); 97773ddc447SLei Zhang if (castDistributedRank < castOriginalRank) { 97873ddc447SLei Zhang SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1); 97973ddc447SLei Zhang llvm::append_range(shape, castDistributedType.getShape()); 98073ddc447SLei Zhang castDistributedType = 98173ddc447SLei Zhang VectorType::get(shape, castDistributedType.getElementType()); 98273ddc447SLei Zhang } 98373ddc447SLei Zhang 98473ddc447SLei Zhang SmallVector<size_t> newRetIndices; 98573ddc447SLei Zhang WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 98673ddc447SLei Zhang rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType}, 98773ddc447SLei Zhang newRetIndices); 98873ddc447SLei Zhang rewriter.setInsertionPointAfter(newWarpOp); 98973ddc447SLei Zhang Value newCast = rewriter.create<vector::ShapeCastOp>( 99073ddc447SLei Zhang oldCastOp.getLoc(), castResultType, 99173ddc447SLei Zhang newWarpOp->getResult(newRetIndices[0])); 99273ddc447SLei Zhang rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast); 99373ddc447SLei Zhang return success(); 99473ddc447SLei Zhang } 99573ddc447SLei Zhang }; 99673ddc447SLei Zhang 997d4d28914SQuinn Dawkins /// Sink out vector.create_mask op feeding into a warp op yield. 998d4d28914SQuinn Dawkins /// ``` 999d4d28914SQuinn Dawkins /// %0 = ... 1000ecaf2c33SPetr Kurapov /// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { 1001d4d28914SQuinn Dawkins /// ... 1002d4d28914SQuinn Dawkins /// %mask = vector.create_mask %0 : vector<32xi1> 1003ecaf2c33SPetr Kurapov /// gpu.yield %mask : vector<32xi1> 1004d4d28914SQuinn Dawkins /// } 1005d4d28914SQuinn Dawkins /// ``` 1006d4d28914SQuinn Dawkins /// To 1007d4d28914SQuinn Dawkins /// ``` 1008d4d28914SQuinn Dawkins /// %0 = ... 1009ecaf2c33SPetr Kurapov /// gpu.warp_execute_on_lane_0(%arg0) { 1010d4d28914SQuinn Dawkins /// ... 1011d4d28914SQuinn Dawkins /// } 1012d4d28914SQuinn Dawkins /// %cmp = arith.cmpi ult, %laneid, %0 1013d4d28914SQuinn Dawkins /// %ub = arith.select %cmp, %c0, %c1 1014d4d28914SQuinn Dawkins /// %1 = vector.create_mask %ub : vector<1xi1> 1015*bc29fc93SPetr Kurapov struct WarpOpCreateMask : public WarpDistributionPattern { 1016*bc29fc93SPetr Kurapov using Base::Base; 1017d4d28914SQuinn Dawkins LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 1018d4d28914SQuinn Dawkins PatternRewriter &rewriter) const override { 1019971b8525SJakub Kuderski OpOperand *yieldOperand = 1020971b8525SJakub Kuderski getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>); 1021d4d28914SQuinn Dawkins if (!yieldOperand) 1022d4d28914SQuinn Dawkins return failure(); 1023d4d28914SQuinn Dawkins 1024d4d28914SQuinn Dawkins auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>(); 1025d4d28914SQuinn Dawkins 1026d4d28914SQuinn Dawkins // Early exit if any values needed for calculating the new mask indices 1027d4d28914SQuinn Dawkins // are defined inside the warp op. 1028d4d28914SQuinn Dawkins if (!llvm::all_of(mask->getOperands(), [&](Value value) { 1029d4d28914SQuinn Dawkins return warpOp.isDefinedOutsideOfRegion(value); 1030d4d28914SQuinn Dawkins })) 1031d4d28914SQuinn Dawkins return failure(); 1032d4d28914SQuinn Dawkins 1033d4d28914SQuinn Dawkins Location loc = mask.getLoc(); 1034d4d28914SQuinn Dawkins unsigned operandIndex = yieldOperand->getOperandNumber(); 1035d4d28914SQuinn Dawkins 1036d4d28914SQuinn Dawkins auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType()); 1037d4d28914SQuinn Dawkins VectorType seqType = mask.getVectorType(); 1038d4d28914SQuinn Dawkins ArrayRef<int64_t> seqShape = seqType.getShape(); 1039d4d28914SQuinn Dawkins ArrayRef<int64_t> distShape = distType.getShape(); 1040d4d28914SQuinn Dawkins 1041d4d28914SQuinn Dawkins rewriter.setInsertionPointAfter(warpOp); 1042d4d28914SQuinn Dawkins 1043d4d28914SQuinn Dawkins // Delinearize the lane ID for constructing the distributed mask sizes. 1044d4d28914SQuinn Dawkins SmallVector<Value> delinearizedIds; 1045d4d28914SQuinn Dawkins if (!delinearizeLaneId(rewriter, loc, seqShape, distShape, 1046d4d28914SQuinn Dawkins warpOp.getWarpSize(), warpOp.getLaneid(), 1047d4d28914SQuinn Dawkins delinearizedIds)) 1048d4d28914SQuinn Dawkins return rewriter.notifyMatchFailure( 1049d4d28914SQuinn Dawkins mask, "cannot delinearize lane ID for distribution"); 1050d4d28914SQuinn Dawkins assert(!delinearizedIds.empty()); 1051d4d28914SQuinn Dawkins 1052aa2376a0SQuinn Dawkins // Notify the rewriter that the warp op is changing (see the comment on 1053aa2376a0SQuinn Dawkins // the WarpOpTransferRead pattern). 10545fcf907bSMatthias Springer rewriter.startOpModification(warpOp); 1055aa2376a0SQuinn Dawkins 1056d4d28914SQuinn Dawkins AffineExpr s0, s1; 1057d4d28914SQuinn Dawkins bindSymbols(rewriter.getContext(), s0, s1); 1058d4d28914SQuinn Dawkins SmallVector<Value> newOperands; 1059d4d28914SQuinn Dawkins for (int i = 0, e = distShape.size(); i < e; ++i) { 1060d4d28914SQuinn Dawkins // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to 1061d4d28914SQuinn Dawkins // find the distance from the largest mask index owned by this lane to the 1062d4d28914SQuinn Dawkins // original mask size. `vector.create_mask` implicitly clamps mask 1063d4d28914SQuinn Dawkins // operands to the range [0, mask_vector_size[i]], or in other words, the 1064d4d28914SQuinn Dawkins // mask sizes are always in the range [0, mask_vector_size[i]). 1065d4d28914SQuinn Dawkins Value maskDimIdx = affine::makeComposedAffineApply( 1066d4d28914SQuinn Dawkins rewriter, loc, s1 - s0 * distShape[i], 1067d4d28914SQuinn Dawkins {delinearizedIds[i], mask.getOperand(i)}); 1068d4d28914SQuinn Dawkins newOperands.push_back(maskDimIdx); 1069d4d28914SQuinn Dawkins } 1070d4d28914SQuinn Dawkins 1071d4d28914SQuinn Dawkins auto newMask = 1072d4d28914SQuinn Dawkins rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands); 1073d4d28914SQuinn Dawkins rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask); 10745fcf907bSMatthias Springer rewriter.finalizeOpModification(warpOp); 1075d4d28914SQuinn Dawkins return success(); 1076d4d28914SQuinn Dawkins } 1077d4d28914SQuinn Dawkins }; 1078d4d28914SQuinn Dawkins 1079f48ce52cSThomas Raoux /// Pattern to move out vector.extract of single element vector. Those don't 1080f48ce52cSThomas Raoux /// need to be distributed and can just be propagated outside of the region. 1081*bc29fc93SPetr Kurapov struct WarpOpExtract : public WarpDistributionPattern { 1082*bc29fc93SPetr Kurapov using Base::Base; 1083f48ce52cSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 1084f48ce52cSThomas Raoux PatternRewriter &rewriter) const override { 1085971b8525SJakub Kuderski OpOperand *operand = 1086971b8525SJakub Kuderski getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>); 1087f48ce52cSThomas Raoux if (!operand) 1088f48ce52cSThomas Raoux return failure(); 1089f48ce52cSThomas Raoux unsigned int operandNumber = operand->getOperandNumber(); 1090f48ce52cSThomas Raoux auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>(); 1091a1aad28dSLei Zhang VectorType extractSrcType = extractOp.getSourceVectorType(); 1092f48ce52cSThomas Raoux Location loc = extractOp.getLoc(); 10939085f00bSMatthias Springer 10942f925d75SKunwar Grover // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern. 10952f925d75SKunwar Grover if (extractSrcType.getRank() <= 1) { 10969085f00bSMatthias Springer return failure(); 10979085f00bSMatthias Springer } 10989085f00bSMatthias Springer 10999085f00bSMatthias Springer // All following cases are 2d or higher dimensional source vectors. 11009085f00bSMatthias Springer 11019085f00bSMatthias Springer if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) { 11029085f00bSMatthias Springer // There is no distribution, this is a broadcast. Simply move the extract 11039085f00bSMatthias Springer // out of the warp op. 11049085f00bSMatthias Springer // TODO: This could be optimized. E.g., in case of a scalar result, let 11059085f00bSMatthias Springer // one lane extract and shuffle the result to all other lanes (same as 11069085f00bSMatthias Springer // the 1d case). 1107f48ce52cSThomas Raoux SmallVector<size_t> newRetIndices; 1108f48ce52cSThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 11099085f00bSMatthias Springer rewriter, warpOp, {extractOp.getVector()}, 1110a1aad28dSLei Zhang {extractOp.getSourceVectorType()}, newRetIndices); 11119085f00bSMatthias Springer rewriter.setInsertionPointAfter(newWarpOp); 11129085f00bSMatthias Springer Value distributedVec = newWarpOp->getResult(newRetIndices[0]); 11139085f00bSMatthias Springer // Extract from distributed vector. 11149085f00bSMatthias Springer Value newExtract = rewriter.create<vector::ExtractOp>( 111598f6289aSDiego Caballero loc, distributedVec, extractOp.getMixedPosition()); 11167ecc921dSMatthias Springer rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), 11177ecc921dSMatthias Springer newExtract); 11189085f00bSMatthias Springer return success(); 11199085f00bSMatthias Springer } 11209085f00bSMatthias Springer 11219085f00bSMatthias Springer // Find the distributed dimension. There should be exactly one. 11229085f00bSMatthias Springer auto distributedType = 11235550c821STres Popp cast<VectorType>(warpOp.getResult(operandNumber).getType()); 11245550c821STres Popp auto yieldedType = cast<VectorType>(operand->get().getType()); 11259085f00bSMatthias Springer int64_t distributedDim = -1; 11269085f00bSMatthias Springer for (int64_t i = 0; i < yieldedType.getRank(); ++i) { 11279085f00bSMatthias Springer if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) { 11289085f00bSMatthias Springer // Keep this assert here in case WarpExecuteOnLane0Op gets extended to 11299085f00bSMatthias Springer // support distributing multiple dimensions in the future. 11309085f00bSMatthias Springer assert(distributedDim == -1 && "found multiple distributed dims"); 11319085f00bSMatthias Springer distributedDim = i; 11329085f00bSMatthias Springer } 11339085f00bSMatthias Springer } 11349085f00bSMatthias Springer assert(distributedDim != -1 && "could not find distributed dimension"); 113551ddfd76SKazu Hirata (void)distributedDim; 11369085f00bSMatthias Springer 11379085f00bSMatthias Springer // Yield source vector from warp op. 11385262865aSKazu Hirata SmallVector<int64_t> newDistributedShape(extractSrcType.getShape()); 11399085f00bSMatthias Springer for (int i = 0; i < distributedType.getRank(); ++i) 114098f6289aSDiego Caballero newDistributedShape[i + extractOp.getNumIndices()] = 11419085f00bSMatthias Springer distributedType.getDimSize(i); 11429085f00bSMatthias Springer auto newDistributedType = 11439085f00bSMatthias Springer VectorType::get(newDistributedShape, distributedType.getElementType()); 11449085f00bSMatthias Springer SmallVector<size_t> newRetIndices; 11459085f00bSMatthias Springer WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 11469085f00bSMatthias Springer rewriter, warpOp, {extractOp.getVector()}, {newDistributedType}, 1147f48ce52cSThomas Raoux newRetIndices); 1148f48ce52cSThomas Raoux rewriter.setInsertionPointAfter(newWarpOp); 11499085f00bSMatthias Springer Value distributedVec = newWarpOp->getResult(newRetIndices[0]); 11509085f00bSMatthias Springer // Extract from distributed vector. 1151f48ce52cSThomas Raoux Value newExtract = rewriter.create<vector::ExtractOp>( 115298f6289aSDiego Caballero loc, distributedVec, extractOp.getMixedPosition()); 11537ecc921dSMatthias Springer rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), 11547ecc921dSMatthias Springer newExtract); 1155f48ce52cSThomas Raoux return success(); 1156f48ce52cSThomas Raoux } 1157f48ce52cSThomas Raoux }; 1158f48ce52cSThomas Raoux 11592f925d75SKunwar Grover /// Pattern to move out vector.extract with a scalar result. 11602f925d75SKunwar Grover /// Only supports 1-D and 0-D sources for now. 1161*bc29fc93SPetr Kurapov struct WarpOpExtractScalar : public WarpDistributionPattern { 11622f925d75SKunwar Grover WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn, 11639d51b4e4SMatthias Springer PatternBenefit b = 1) 1164*bc29fc93SPetr Kurapov : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {} 11651757164eSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 11661757164eSThomas Raoux PatternRewriter &rewriter) const override { 1167971b8525SJakub Kuderski OpOperand *operand = 11682f925d75SKunwar Grover getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>); 11691757164eSThomas Raoux if (!operand) 11701757164eSThomas Raoux return failure(); 11711757164eSThomas Raoux unsigned int operandNumber = operand->getOperandNumber(); 11722f925d75SKunwar Grover auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>(); 1173a1aad28dSLei Zhang VectorType extractSrcType = extractOp.getSourceVectorType(); 11742f925d75SKunwar Grover // Only supports 1-D or 0-D sources for now. 11752f925d75SKunwar Grover if (extractSrcType.getRank() > 1) { 11762f925d75SKunwar Grover return rewriter.notifyMatchFailure( 11772f925d75SKunwar Grover extractOp, "only 0-D or 1-D source supported for now"); 11782f925d75SKunwar Grover } 117935c19fddSMatthias Springer // TODO: Supported shuffle types should be parameterizable, similar to 118035c19fddSMatthias Springer // `WarpShuffleFromIdxFn`. 118135c19fddSMatthias Springer if (!extractSrcType.getElementType().isF32() && 118235c19fddSMatthias Springer !extractSrcType.getElementType().isInteger(32)) 118335c19fddSMatthias Springer return rewriter.notifyMatchFailure( 118435c19fddSMatthias Springer extractOp, "only f32/i32 element types are supported"); 1185069d7d7eSThomas Raoux bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1; 11869d51b4e4SMatthias Springer Type elType = extractSrcType.getElementType(); 11879d51b4e4SMatthias Springer VectorType distributedVecType; 1188069d7d7eSThomas Raoux if (!is0dOrVec1Extract) { 11899d51b4e4SMatthias Springer assert(extractSrcType.getRank() == 1 && 11902f925d75SKunwar Grover "expected that extract src rank is 0 or 1"); 1191069d7d7eSThomas Raoux if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0) 1192069d7d7eSThomas Raoux return failure(); 11939d51b4e4SMatthias Springer int64_t elementsPerLane = 11949d51b4e4SMatthias Springer extractSrcType.getShape()[0] / warpOp.getWarpSize(); 11959d51b4e4SMatthias Springer distributedVecType = VectorType::get({elementsPerLane}, elType); 11969d51b4e4SMatthias Springer } else { 11979d51b4e4SMatthias Springer distributedVecType = extractSrcType; 11989d51b4e4SMatthias Springer } 1199ad100b36SMatthias Springer // Yield source vector and position (if present) from warp op. 1200ad100b36SMatthias Springer SmallVector<Value> additionalResults{extractOp.getVector()}; 1201ad100b36SMatthias Springer SmallVector<Type> additionalResultTypes{distributedVecType}; 12022f925d75SKunwar Grover additionalResults.append( 12032f925d75SKunwar Grover SmallVector<Value>(extractOp.getDynamicPosition())); 12042f925d75SKunwar Grover additionalResultTypes.append( 12052f925d75SKunwar Grover SmallVector<Type>(extractOp.getDynamicPosition().getTypes())); 12062f925d75SKunwar Grover 12071757164eSThomas Raoux Location loc = extractOp.getLoc(); 12081757164eSThomas Raoux SmallVector<size_t> newRetIndices; 12091757164eSThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 1210ad100b36SMatthias Springer rewriter, warpOp, additionalResults, additionalResultTypes, 12111757164eSThomas Raoux newRetIndices); 12121757164eSThomas Raoux rewriter.setInsertionPointAfter(newWarpOp); 12139d51b4e4SMatthias Springer Value distributedVec = newWarpOp->getResult(newRetIndices[0]); 12149d51b4e4SMatthias Springer 12159d51b4e4SMatthias Springer // 0d extract: The new warp op broadcasts the source vector to all lanes. 12169d51b4e4SMatthias Springer // All lanes extract the scalar. 1217069d7d7eSThomas Raoux if (is0dOrVec1Extract) { 1218069d7d7eSThomas Raoux Value newExtract; 12192f925d75SKunwar Grover SmallVector<int64_t> indices(extractSrcType.getRank(), 0); 1220069d7d7eSThomas Raoux newExtract = 12212f925d75SKunwar Grover rewriter.create<vector::ExtractOp>(loc, distributedVec, indices); 12227ecc921dSMatthias Springer rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), 12237ecc921dSMatthias Springer newExtract); 12241757164eSThomas Raoux return success(); 12251757164eSThomas Raoux } 12269d51b4e4SMatthias Springer 12272f925d75SKunwar Grover int64_t staticPos = extractOp.getStaticPosition()[0]; 12282f925d75SKunwar Grover OpFoldResult pos = ShapedType::isDynamic(staticPos) 12292f925d75SKunwar Grover ? (newWarpOp->getResult(newRetIndices[1])) 12302f925d75SKunwar Grover : OpFoldResult(rewriter.getIndexAttr(staticPos)); 12319d51b4e4SMatthias Springer // 1d extract: Distribute the source vector. One lane extracts and shuffles 12329d51b4e4SMatthias Springer // the value to all other lanes. 12339d51b4e4SMatthias Springer int64_t elementsPerLane = distributedVecType.getShape()[0]; 12349d51b4e4SMatthias Springer AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext()); 12359d51b4e4SMatthias Springer // tid of extracting thread: pos / elementsPerLane 12362f925d75SKunwar Grover Value broadcastFromTid = affine::makeComposedAffineApply( 12372f925d75SKunwar Grover rewriter, loc, sym0.ceilDiv(elementsPerLane), pos); 12389d51b4e4SMatthias Springer // Extract at position: pos % elementsPerLane 12392f925d75SKunwar Grover Value newPos = 124073ce971cSMatthias Springer elementsPerLane == 1 124173ce971cSMatthias Springer ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult() 12422f925d75SKunwar Grover : affine::makeComposedAffineApply(rewriter, loc, 12432f925d75SKunwar Grover sym0 % elementsPerLane, pos); 12449d51b4e4SMatthias Springer Value extracted = 12452f925d75SKunwar Grover rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos); 12469d51b4e4SMatthias Springer 12479d51b4e4SMatthias Springer // Shuffle the extracted value to all lanes. 12489d51b4e4SMatthias Springer Value shuffled = warpShuffleFromIdxFn( 12499d51b4e4SMatthias Springer loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize()); 12507ecc921dSMatthias Springer rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled); 12519d51b4e4SMatthias Springer return success(); 12529d51b4e4SMatthias Springer } 12539d51b4e4SMatthias Springer 12549d51b4e4SMatthias Springer private: 12559d51b4e4SMatthias Springer WarpShuffleFromIdxFn warpShuffleFromIdxFn; 12561757164eSThomas Raoux }; 12571757164eSThomas Raoux 12582f925d75SKunwar Grover /// Pattern to convert vector.extractelement to vector.extract. 1259*bc29fc93SPetr Kurapov struct WarpOpExtractElement : public WarpDistributionPattern { 1260*bc29fc93SPetr Kurapov using Base::Base; 12612f925d75SKunwar Grover LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 12622f925d75SKunwar Grover PatternRewriter &rewriter) const override { 12632f925d75SKunwar Grover OpOperand *operand = 12642f925d75SKunwar Grover getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>); 12652f925d75SKunwar Grover if (!operand) 12662f925d75SKunwar Grover return failure(); 12672f925d75SKunwar Grover auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>(); 12682f925d75SKunwar Grover SmallVector<OpFoldResult> indices; 12692f925d75SKunwar Grover if (auto pos = extractOp.getPosition()) { 12702f925d75SKunwar Grover indices.push_back(pos); 12712f925d75SKunwar Grover } 12722f925d75SKunwar Grover rewriter.setInsertionPoint(extractOp); 12732f925d75SKunwar Grover rewriter.replaceOpWithNewOp<vector::ExtractOp>( 12742f925d75SKunwar Grover extractOp, extractOp.getVector(), indices); 12752f925d75SKunwar Grover return success(); 12762f925d75SKunwar Grover } 12772f925d75SKunwar Grover }; 12782f925d75SKunwar Grover 12792f925d75SKunwar Grover /// Pattern to move out vector.insert with a scalar input. 12802f925d75SKunwar Grover /// Only supports 1-D and 0-D destinations for now. 1281*bc29fc93SPetr Kurapov struct WarpOpInsertScalar : public WarpDistributionPattern { 1282*bc29fc93SPetr Kurapov using Base::Base; 128373ce971cSMatthias Springer LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 128473ce971cSMatthias Springer PatternRewriter &rewriter) const override { 12852f925d75SKunwar Grover OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>); 128673ce971cSMatthias Springer if (!operand) 128773ce971cSMatthias Springer return failure(); 128873ce971cSMatthias Springer unsigned int operandNumber = operand->getOperandNumber(); 12892f925d75SKunwar Grover auto insertOp = operand->get().getDefiningOp<vector::InsertOp>(); 129073ce971cSMatthias Springer VectorType vecType = insertOp.getDestVectorType(); 129173ce971cSMatthias Springer VectorType distrType = 12925550c821STres Popp cast<VectorType>(warpOp.getResult(operandNumber).getType()); 12932f925d75SKunwar Grover 12942f925d75SKunwar Grover // Only supports 1-D or 0-D destinations for now. 12952f925d75SKunwar Grover if (vecType.getRank() > 1) { 12962f925d75SKunwar Grover return rewriter.notifyMatchFailure( 12972f925d75SKunwar Grover insertOp, "only 0-D or 1-D source supported for now"); 12982f925d75SKunwar Grover } 129973ce971cSMatthias Springer 130073ce971cSMatthias Springer // Yield destination vector, source scalar and position from warp op. 130173ce971cSMatthias Springer SmallVector<Value> additionalResults{insertOp.getDest(), 130273ce971cSMatthias Springer insertOp.getSource()}; 130373ce971cSMatthias Springer SmallVector<Type> additionalResultTypes{distrType, 130473ce971cSMatthias Springer insertOp.getSource().getType()}; 13052f925d75SKunwar Grover additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition())); 13062f925d75SKunwar Grover additionalResultTypes.append( 13072f925d75SKunwar Grover SmallVector<Type>(insertOp.getDynamicPosition().getTypes())); 13082f925d75SKunwar Grover 130973ce971cSMatthias Springer Location loc = insertOp.getLoc(); 131073ce971cSMatthias Springer SmallVector<size_t> newRetIndices; 131173ce971cSMatthias Springer WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 131273ce971cSMatthias Springer rewriter, warpOp, additionalResults, additionalResultTypes, 131373ce971cSMatthias Springer newRetIndices); 131473ce971cSMatthias Springer rewriter.setInsertionPointAfter(newWarpOp); 131573ce971cSMatthias Springer Value distributedVec = newWarpOp->getResult(newRetIndices[0]); 131673ce971cSMatthias Springer Value newSource = newWarpOp->getResult(newRetIndices[1]); 131773ce971cSMatthias Springer rewriter.setInsertionPointAfter(newWarpOp); 131873ce971cSMatthias Springer 13192f925d75SKunwar Grover OpFoldResult pos; 13202f925d75SKunwar Grover if (vecType.getRank() != 0) { 13212f925d75SKunwar Grover int64_t staticPos = insertOp.getStaticPosition()[0]; 13222f925d75SKunwar Grover pos = ShapedType::isDynamic(staticPos) 13232f925d75SKunwar Grover ? (newWarpOp->getResult(newRetIndices[2])) 13242f925d75SKunwar Grover : OpFoldResult(rewriter.getIndexAttr(staticPos)); 13252f925d75SKunwar Grover } 13262f925d75SKunwar Grover 13272f925d75SKunwar Grover // This condition is always true for 0-d vectors. 132873ce971cSMatthias Springer if (vecType == distrType) { 13292f925d75SKunwar Grover Value newInsert; 13302f925d75SKunwar Grover SmallVector<OpFoldResult> indices; 13312f925d75SKunwar Grover if (pos) { 13322f925d75SKunwar Grover indices.push_back(pos); 13332f925d75SKunwar Grover } 13342f925d75SKunwar Grover newInsert = rewriter.create<vector::InsertOp>(loc, newSource, 13352f925d75SKunwar Grover distributedVec, indices); 13362f925d75SKunwar Grover // Broadcast: Simply move the vector.insert op out. 13377ecc921dSMatthias Springer rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), 13387ecc921dSMatthias Springer newInsert); 133973ce971cSMatthias Springer return success(); 134073ce971cSMatthias Springer } 134173ce971cSMatthias Springer 134273ce971cSMatthias Springer // This is a distribution. Only one lane should insert. 134373ce971cSMatthias Springer int64_t elementsPerLane = distrType.getShape()[0]; 134473ce971cSMatthias Springer AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext()); 134573ce971cSMatthias Springer // tid of extracting thread: pos / elementsPerLane 13462f925d75SKunwar Grover Value insertingLane = affine::makeComposedAffineApply( 13472f925d75SKunwar Grover rewriter, loc, sym0.ceilDiv(elementsPerLane), pos); 134873ce971cSMatthias Springer // Insert position: pos % elementsPerLane 13492f925d75SKunwar Grover OpFoldResult newPos = affine::makeComposedFoldedAffineApply( 13502f925d75SKunwar Grover rewriter, loc, sym0 % elementsPerLane, pos); 135173ce971cSMatthias Springer Value isInsertingLane = rewriter.create<arith::CmpIOp>( 135273ce971cSMatthias Springer loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); 135373ce971cSMatthias Springer Value newResult = 135473ce971cSMatthias Springer rewriter 135573ce971cSMatthias Springer .create<scf::IfOp>( 13561125c5c0SFrederik Gossen loc, isInsertingLane, 135773ce971cSMatthias Springer /*thenBuilder=*/ 135873ce971cSMatthias Springer [&](OpBuilder &builder, Location loc) { 13592f925d75SKunwar Grover Value newInsert = builder.create<vector::InsertOp>( 13602f925d75SKunwar Grover loc, newSource, distributedVec, newPos); 136173ce971cSMatthias Springer builder.create<scf::YieldOp>(loc, newInsert); 136273ce971cSMatthias Springer }, 136373ce971cSMatthias Springer /*elseBuilder=*/ 136473ce971cSMatthias Springer [&](OpBuilder &builder, Location loc) { 136573ce971cSMatthias Springer builder.create<scf::YieldOp>(loc, distributedVec); 136673ce971cSMatthias Springer }) 136773ce971cSMatthias Springer .getResult(0); 13687ecc921dSMatthias Springer rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); 136973ce971cSMatthias Springer return success(); 137073ce971cSMatthias Springer } 137173ce971cSMatthias Springer }; 137273ce971cSMatthias Springer 1373*bc29fc93SPetr Kurapov struct WarpOpInsert : public WarpDistributionPattern { 1374*bc29fc93SPetr Kurapov using Base::Base; 13751523b729SMatthias Springer LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 13761523b729SMatthias Springer PatternRewriter &rewriter) const override { 1377971b8525SJakub Kuderski OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>); 13781523b729SMatthias Springer if (!operand) 13791523b729SMatthias Springer return failure(); 13801523b729SMatthias Springer unsigned int operandNumber = operand->getOperandNumber(); 13811523b729SMatthias Springer auto insertOp = operand->get().getDefiningOp<vector::InsertOp>(); 13821523b729SMatthias Springer Location loc = insertOp.getLoc(); 13831523b729SMatthias Springer 13842f925d75SKunwar Grover // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern. 13852f925d75SKunwar Grover if (insertOp.getDestVectorType().getRank() <= 1) { 13861523b729SMatthias Springer return failure(); 13871523b729SMatthias Springer } 13881523b729SMatthias Springer 13892f925d75SKunwar Grover // All following cases are 2d or higher dimensional source vectors. 13902f925d75SKunwar Grover 13911523b729SMatthias Springer if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) { 13921523b729SMatthias Springer // There is no distribution, this is a broadcast. Simply move the insert 13931523b729SMatthias Springer // out of the warp op. 13941523b729SMatthias Springer SmallVector<size_t> newRetIndices; 13951523b729SMatthias Springer WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 13961523b729SMatthias Springer rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()}, 13971523b729SMatthias Springer {insertOp.getSourceType(), insertOp.getDestVectorType()}, 13981523b729SMatthias Springer newRetIndices); 13991523b729SMatthias Springer rewriter.setInsertionPointAfter(newWarpOp); 14001523b729SMatthias Springer Value distributedSrc = newWarpOp->getResult(newRetIndices[0]); 14011523b729SMatthias Springer Value distributedDest = newWarpOp->getResult(newRetIndices[1]); 14021523b729SMatthias Springer Value newResult = rewriter.create<vector::InsertOp>( 140398f6289aSDiego Caballero loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); 14047ecc921dSMatthias Springer rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), 14057ecc921dSMatthias Springer newResult); 14061523b729SMatthias Springer return success(); 14071523b729SMatthias Springer } 14081523b729SMatthias Springer 14091523b729SMatthias Springer // Find the distributed dimension. There should be exactly one. 14101523b729SMatthias Springer auto distrDestType = 14115550c821STres Popp cast<VectorType>(warpOp.getResult(operandNumber).getType()); 14125550c821STres Popp auto yieldedType = cast<VectorType>(operand->get().getType()); 14131523b729SMatthias Springer int64_t distrDestDim = -1; 14141523b729SMatthias Springer for (int64_t i = 0; i < yieldedType.getRank(); ++i) { 14151523b729SMatthias Springer if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) { 14161523b729SMatthias Springer // Keep this assert here in case WarpExecuteOnLane0Op gets extended to 14171523b729SMatthias Springer // support distributing multiple dimensions in the future. 14181523b729SMatthias Springer assert(distrDestDim == -1 && "found multiple distributed dims"); 14191523b729SMatthias Springer distrDestDim = i; 14201523b729SMatthias Springer } 14211523b729SMatthias Springer } 14221523b729SMatthias Springer assert(distrDestDim != -1 && "could not find distributed dimension"); 14231523b729SMatthias Springer 14241523b729SMatthias Springer // Compute the distributed source vector type. 14255550c821STres Popp VectorType srcVecType = cast<VectorType>(insertOp.getSourceType()); 14265262865aSKazu Hirata SmallVector<int64_t> distrSrcShape(srcVecType.getShape()); 14271523b729SMatthias Springer // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32> 14281523b729SMatthias Springer // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will 14291523b729SMatthias Springer // insert a smaller vector<3xf32>. 14301523b729SMatthias Springer // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that 14311523b729SMatthias Springer // case, one lane will insert the source vector<96xf32>. The other 14321523b729SMatthias Springer // lanes will not do anything. 143398f6289aSDiego Caballero int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices(); 14341523b729SMatthias Springer if (distrSrcDim >= 0) 14351523b729SMatthias Springer distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim); 14361523b729SMatthias Springer auto distrSrcType = 14371523b729SMatthias Springer VectorType::get(distrSrcShape, distrDestType.getElementType()); 14381523b729SMatthias Springer 14391523b729SMatthias Springer // Yield source and dest vectors from warp op. 14401523b729SMatthias Springer SmallVector<size_t> newRetIndices; 14411523b729SMatthias Springer WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 14421523b729SMatthias Springer rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()}, 14431523b729SMatthias Springer {distrSrcType, distrDestType}, newRetIndices); 14441523b729SMatthias Springer rewriter.setInsertionPointAfter(newWarpOp); 14451523b729SMatthias Springer Value distributedSrc = newWarpOp->getResult(newRetIndices[0]); 14461523b729SMatthias Springer Value distributedDest = newWarpOp->getResult(newRetIndices[1]); 14471523b729SMatthias Springer 14481523b729SMatthias Springer // Insert into the distributed vector. 14491523b729SMatthias Springer Value newResult; 14501523b729SMatthias Springer if (distrSrcDim >= 0) { 14511523b729SMatthias Springer // Every lane inserts a small piece. 14521523b729SMatthias Springer newResult = rewriter.create<vector::InsertOp>( 145398f6289aSDiego Caballero loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); 14541523b729SMatthias Springer } else { 14551523b729SMatthias Springer // One lane inserts the entire source vector. 14561523b729SMatthias Springer int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim); 145798f6289aSDiego Caballero SmallVector<OpFoldResult> pos = insertOp.getMixedPosition(); 145898f6289aSDiego Caballero SmallVector<int64_t> newPos = getAsIntegers(pos); 14591523b729SMatthias Springer // tid of inserting lane: pos / elementsPerLane 14601523b729SMatthias Springer Value insertingLane = rewriter.create<arith::ConstantIndexOp>( 14611523b729SMatthias Springer loc, newPos[distrDestDim] / elementsPerLane); 14621523b729SMatthias Springer Value isInsertingLane = rewriter.create<arith::CmpIOp>( 14631523b729SMatthias Springer loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); 14641523b729SMatthias Springer // Insert position: pos % elementsPerLane 14651523b729SMatthias Springer newPos[distrDestDim] %= elementsPerLane; 14661523b729SMatthias Springer auto insertingBuilder = [&](OpBuilder &builder, Location loc) { 14671523b729SMatthias Springer Value newInsert = builder.create<vector::InsertOp>( 14681523b729SMatthias Springer loc, distributedSrc, distributedDest, newPos); 14691523b729SMatthias Springer builder.create<scf::YieldOp>(loc, newInsert); 14701523b729SMatthias Springer }; 14711523b729SMatthias Springer auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) { 14721523b729SMatthias Springer builder.create<scf::YieldOp>(loc, distributedDest); 14731523b729SMatthias Springer }; 14741523b729SMatthias Springer newResult = rewriter 14751125c5c0SFrederik Gossen .create<scf::IfOp>(loc, isInsertingLane, 14761523b729SMatthias Springer /*thenBuilder=*/insertingBuilder, 14771523b729SMatthias Springer /*elseBuilder=*/nonInsertingBuilder) 14781523b729SMatthias Springer .getResult(0); 14791523b729SMatthias Springer } 14801523b729SMatthias Springer 14817ecc921dSMatthias Springer rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); 14821523b729SMatthias Springer return success(); 14831523b729SMatthias Springer } 14841523b729SMatthias Springer }; 14851523b729SMatthias Springer 1486*bc29fc93SPetr Kurapov struct WarpOpInsertElement : public WarpDistributionPattern { 1487*bc29fc93SPetr Kurapov using Base::Base; 14882f925d75SKunwar Grover LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 14892f925d75SKunwar Grover PatternRewriter &rewriter) const override { 14902f925d75SKunwar Grover OpOperand *operand = 14912f925d75SKunwar Grover getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>); 14922f925d75SKunwar Grover if (!operand) 14932f925d75SKunwar Grover return failure(); 14942f925d75SKunwar Grover auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>(); 14952f925d75SKunwar Grover SmallVector<OpFoldResult> indices; 14962f925d75SKunwar Grover if (auto pos = insertOp.getPosition()) { 14972f925d75SKunwar Grover indices.push_back(pos); 14982f925d75SKunwar Grover } 14992f925d75SKunwar Grover rewriter.setInsertionPoint(insertOp); 15002f925d75SKunwar Grover rewriter.replaceOpWithNewOp<vector::InsertOp>( 15012f925d75SKunwar Grover insertOp, insertOp.getSource(), insertOp.getDest(), indices); 15022f925d75SKunwar Grover return success(); 15032f925d75SKunwar Grover } 15042f925d75SKunwar Grover }; 15052f925d75SKunwar Grover 150676cf33daSThomas Raoux /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if 15072f925d75SKunwar Grover /// the scf.ForOp is the last operation in the region so that it doesn't 15082f925d75SKunwar Grover /// change the order of execution. This creates a new scf.for region after the 150976cf33daSThomas Raoux /// WarpExecuteOnLane0Op. The new scf.for region will contain a new 151076cf33daSThomas Raoux /// WarpExecuteOnLane0Op region. Example: 151176cf33daSThomas Raoux /// ``` 1512ecaf2c33SPetr Kurapov /// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) { 151376cf33daSThomas Raoux /// ... 151476cf33daSThomas Raoux /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v) 151576cf33daSThomas Raoux /// -> (vector<128xf32>) { 151676cf33daSThomas Raoux /// ... 151776cf33daSThomas Raoux /// scf.yield %r : vector<128xf32> 151876cf33daSThomas Raoux /// } 1519ecaf2c33SPetr Kurapov /// gpu.yield %v1 : vector<128xf32> 152076cf33daSThomas Raoux /// } 152176cf33daSThomas Raoux /// ``` 152276cf33daSThomas Raoux /// To: 1523ecaf2c33SPetr Kurapov /// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) { 152476cf33daSThomas Raoux /// ... 1525ecaf2c33SPetr Kurapov /// gpu.yield %v : vector<128xf32> 152676cf33daSThomas Raoux /// } 152776cf33daSThomas Raoux /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %q0) 152876cf33daSThomas Raoux /// -> (vector<4xf32>) { 1529ecaf2c33SPetr Kurapov /// %iw = gpu.warp_execute_on_lane_0(%laneid) 153076cf33daSThomas Raoux /// args(%varg : vector<4xf32>) -> (vector<4xf32>) { 153176cf33daSThomas Raoux /// ^bb0(%arg: vector<128xf32>): 153276cf33daSThomas Raoux /// ... 1533ecaf2c33SPetr Kurapov /// gpu.yield %ir : vector<128xf32> 153476cf33daSThomas Raoux /// } 153576cf33daSThomas Raoux /// scf.yield %iw : vector<4xf32> 153676cf33daSThomas Raoux /// } 153776cf33daSThomas Raoux /// ``` 1538*bc29fc93SPetr Kurapov struct WarpOpScfForOp : public WarpDistributionPattern { 153991f62f0eSThomas Raoux 154091f62f0eSThomas Raoux WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) 1541*bc29fc93SPetr Kurapov : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} 154276cf33daSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 154376cf33daSThomas Raoux PatternRewriter &rewriter) const override { 1544ecaf2c33SPetr Kurapov auto yield = cast<gpu::YieldOp>( 154576cf33daSThomas Raoux warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); 154676cf33daSThomas Raoux // Only pick up forOp if it is the last op in the region. 154776cf33daSThomas Raoux Operation *lastNode = yield->getPrevNode(); 154876cf33daSThomas Raoux auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode); 154976cf33daSThomas Raoux if (!forOp) 155076cf33daSThomas Raoux return failure(); 155191f62f0eSThomas Raoux // Collect Values that come from the warp op but are outside the forOp. 15522f925d75SKunwar Grover // Those Value needs to be returned by the original warpOp and passed to 15532f925d75SKunwar Grover // the new op. 155491f62f0eSThomas Raoux llvm::SmallSetVector<Value, 32> escapingValues; 155591f62f0eSThomas Raoux SmallVector<Type> inputTypes; 155691f62f0eSThomas Raoux SmallVector<Type> distTypes; 155791f62f0eSThomas Raoux mlir::visitUsedValuesDefinedAbove( 155891f62f0eSThomas Raoux forOp.getBodyRegion(), [&](OpOperand *operand) { 155991f62f0eSThomas Raoux Operation *parent = operand->get().getParentRegion()->getParentOp(); 156091f62f0eSThomas Raoux if (warpOp->isAncestor(parent)) { 156191f62f0eSThomas Raoux if (!escapingValues.insert(operand->get())) 156291f62f0eSThomas Raoux return; 156391f62f0eSThomas Raoux Type distType = operand->get().getType(); 1564d2433787SLei Zhang if (auto vecType = dyn_cast<VectorType>(distType)) { 156591f62f0eSThomas Raoux AffineMap map = distributionMapFn(operand->get()); 156691f62f0eSThomas Raoux distType = getDistributedType(vecType, map, warpOp.getWarpSize()); 156791f62f0eSThomas Raoux } 156891f62f0eSThomas Raoux inputTypes.push_back(operand->get().getType()); 156991f62f0eSThomas Raoux distTypes.push_back(distType); 157091f62f0eSThomas Raoux } 157191f62f0eSThomas Raoux }); 157291f62f0eSThomas Raoux 1573b5e47d2eSBangtian Liu if (llvm::is_contained(distTypes, Type{})) 1574b5e47d2eSBangtian Liu return failure(); 1575b5e47d2eSBangtian Liu 157691f62f0eSThomas Raoux SmallVector<size_t> newRetIndices; 157791f62f0eSThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 157891f62f0eSThomas Raoux rewriter, warpOp, escapingValues.getArrayRef(), distTypes, 157991f62f0eSThomas Raoux newRetIndices); 1580ecaf2c33SPetr Kurapov yield = cast<gpu::YieldOp>( 158191f62f0eSThomas Raoux newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); 158291f62f0eSThomas Raoux 158376cf33daSThomas Raoux SmallVector<Value> newOperands; 158476cf33daSThomas Raoux SmallVector<unsigned> resultIdx; 158576cf33daSThomas Raoux // Collect all the outputs coming from the forOp. 158676cf33daSThomas Raoux for (OpOperand &yieldOperand : yield->getOpOperands()) { 158776cf33daSThomas Raoux if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) 158876cf33daSThomas Raoux continue; 15895550c821STres Popp auto forResult = cast<OpResult>(yieldOperand.get()); 159091f62f0eSThomas Raoux newOperands.push_back( 159191f62f0eSThomas Raoux newWarpOp.getResult(yieldOperand.getOperandNumber())); 15925cf714bbSMatthias Springer yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); 159376cf33daSThomas Raoux resultIdx.push_back(yieldOperand.getOperandNumber()); 159476cf33daSThomas Raoux } 159591f62f0eSThomas Raoux 159676cf33daSThomas Raoux OpBuilder::InsertionGuard g(rewriter); 159791f62f0eSThomas Raoux rewriter.setInsertionPointAfter(newWarpOp); 159891f62f0eSThomas Raoux 15992f925d75SKunwar Grover // Create a new for op outside the region with a WarpExecuteOnLane0Op 16002f925d75SKunwar Grover // region inside. 160176cf33daSThomas Raoux auto newForOp = rewriter.create<scf::ForOp>( 160276cf33daSThomas Raoux forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), 160376cf33daSThomas Raoux forOp.getStep(), newOperands); 1604b613a540SMatthias Springer rewriter.setInsertionPointToStart(newForOp.getBody()); 160591f62f0eSThomas Raoux 160691f62f0eSThomas Raoux SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(), 160791f62f0eSThomas Raoux newForOp.getRegionIterArgs().end()); 160891f62f0eSThomas Raoux SmallVector<Type> warpInputType(forOp.getResultTypes().begin(), 160991f62f0eSThomas Raoux forOp.getResultTypes().end()); 161091f62f0eSThomas Raoux llvm::SmallDenseMap<Value, int64_t> argIndexMapping; 161191f62f0eSThomas Raoux for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) { 161291f62f0eSThomas Raoux warpInput.push_back(newWarpOp.getResult(retIdx)); 161391f62f0eSThomas Raoux argIndexMapping[escapingValues[i]] = warpInputType.size(); 161491f62f0eSThomas Raoux warpInputType.push_back(inputTypes[i]); 161591f62f0eSThomas Raoux } 161676cf33daSThomas Raoux auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>( 161791f62f0eSThomas Raoux newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(), 161891f62f0eSThomas Raoux newWarpOp.getWarpSize(), warpInput, warpInputType); 161976cf33daSThomas Raoux 162076cf33daSThomas Raoux SmallVector<Value> argMapping; 162176cf33daSThomas Raoux argMapping.push_back(newForOp.getInductionVar()); 162276cf33daSThomas Raoux for (Value args : innerWarp.getBody()->getArguments()) { 162376cf33daSThomas Raoux argMapping.push_back(args); 162476cf33daSThomas Raoux } 162591f62f0eSThomas Raoux argMapping.resize(forOp.getBody()->getNumArguments()); 162676cf33daSThomas Raoux SmallVector<Value> yieldOperands; 162776cf33daSThomas Raoux for (Value operand : forOp.getBody()->getTerminator()->getOperands()) 162876cf33daSThomas Raoux yieldOperands.push_back(operand); 162976cf33daSThomas Raoux rewriter.eraseOp(forOp.getBody()->getTerminator()); 163076cf33daSThomas Raoux rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping); 1631b613a540SMatthias Springer rewriter.setInsertionPointToEnd(innerWarp.getBody()); 1632ecaf2c33SPetr Kurapov rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands); 163376cf33daSThomas Raoux rewriter.setInsertionPointAfter(innerWarp); 1634d343cdd5SThomas Raoux if (!innerWarp.getResults().empty()) 163576cf33daSThomas Raoux rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults()); 163676cf33daSThomas Raoux rewriter.eraseOp(forOp); 163776cf33daSThomas Raoux // Replace the warpOp result coming from the original ForOp. 163876cf33daSThomas Raoux for (const auto &res : llvm::enumerate(resultIdx)) { 16397ecc921dSMatthias Springer rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()), 16407ecc921dSMatthias Springer newForOp.getResult(res.index())); 164191f62f0eSThomas Raoux newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value())); 164276cf33daSThomas Raoux } 164391f62f0eSThomas Raoux newForOp.walk([&](Operation *op) { 164491f62f0eSThomas Raoux for (OpOperand &operand : op->getOpOperands()) { 164591f62f0eSThomas Raoux auto it = argIndexMapping.find(operand.get()); 164691f62f0eSThomas Raoux if (it == argIndexMapping.end()) 164791f62f0eSThomas Raoux continue; 164891f62f0eSThomas Raoux operand.set(innerWarp.getBodyRegion().getArgument(it->second)); 164991f62f0eSThomas Raoux } 165091f62f0eSThomas Raoux }); 165198dcd98aSQuinn Dawkins 165298dcd98aSQuinn Dawkins // Finally, hoist out any now uniform code from the inner warp op. 165398dcd98aSQuinn Dawkins mlir::vector::moveScalarUniformCode(innerWarp); 165476cf33daSThomas Raoux return success(); 165576cf33daSThomas Raoux } 165691f62f0eSThomas Raoux 165791f62f0eSThomas Raoux private: 165891f62f0eSThomas Raoux DistributionMapFn distributionMapFn; 165976cf33daSThomas Raoux }; 166076cf33daSThomas Raoux 1661087aba4fSThomas Raoux /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op. 16622f925d75SKunwar Grover /// The vector is reduced in parallel. Currently limited to vector size 16632f925d75SKunwar Grover /// matching the warpOp size. E.g.: 1664087aba4fSThomas Raoux /// ``` 1665ecaf2c33SPetr Kurapov /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) { 1666087aba4fSThomas Raoux /// %0 = "some_def"() : () -> (vector<32xf32>) 1667087aba4fSThomas Raoux /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32 1668ecaf2c33SPetr Kurapov /// gpu.yield %1 : f32 1669087aba4fSThomas Raoux /// } 1670087aba4fSThomas Raoux /// ``` 1671087aba4fSThomas Raoux /// is lowered to: 1672087aba4fSThomas Raoux /// ``` 1673ecaf2c33SPetr Kurapov /// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { 1674087aba4fSThomas Raoux /// %1 = "some_def"() : () -> (vector<32xf32>) 1675ecaf2c33SPetr Kurapov /// gpu.yield %1 : vector<32xf32> 1676087aba4fSThomas Raoux /// } 16779816edc9SCullen Rhodes /// %a = vector.extract %0[0] : f32 from vector<1xf32> 16786834803cSThomas Raoux /// %r = ("warp.reduction %a") 1679087aba4fSThomas Raoux /// ``` 1680*bc29fc93SPetr Kurapov struct WarpOpReduction : public WarpDistributionPattern { 16816834803cSThomas Raoux WarpOpReduction(MLIRContext *context, 16826834803cSThomas Raoux DistributedReductionFn distributedReductionFn, 16836834803cSThomas Raoux PatternBenefit benefit = 1) 1684*bc29fc93SPetr Kurapov : WarpDistributionPattern(context, benefit), 168561f06774SMehdi Amini distributedReductionFn(std::move(distributedReductionFn)) {} 1686087aba4fSThomas Raoux 1687087aba4fSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 1688087aba4fSThomas Raoux PatternRewriter &rewriter) const override { 1689971b8525SJakub Kuderski OpOperand *yieldOperand = 1690971b8525SJakub Kuderski getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>); 1691087aba4fSThomas Raoux if (!yieldOperand) 1692087aba4fSThomas Raoux return failure(); 1693087aba4fSThomas Raoux 1694087aba4fSThomas Raoux auto reductionOp = 1695087aba4fSThomas Raoux cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp()); 16965550c821STres Popp auto vectorType = cast<VectorType>(reductionOp.getVector().getType()); 1697087aba4fSThomas Raoux // Only rank 1 vectors supported. 1698087aba4fSThomas Raoux if (vectorType.getRank() != 1) 1699087aba4fSThomas Raoux return rewriter.notifyMatchFailure( 1700087aba4fSThomas Raoux warpOp, "Only rank 1 reductions can be distributed."); 1701087aba4fSThomas Raoux // Only warp_size-sized vectors supported. 17020660f3c5SThomas Raoux if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0) 1703087aba4fSThomas Raoux return rewriter.notifyMatchFailure( 1704087aba4fSThomas Raoux warpOp, "Reduction vector dimension must match was size."); 1705f41abcdaSThomas Raoux if (!reductionOp.getType().isIntOrFloat()) 1706087aba4fSThomas Raoux return rewriter.notifyMatchFailure( 1707f41abcdaSThomas Raoux warpOp, "Reduction distribution currently only supports floats and " 1708f41abcdaSThomas Raoux "integer types."); 1709087aba4fSThomas Raoux 17100660f3c5SThomas Raoux int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize(); 1711087aba4fSThomas Raoux // Return vector that will be reduced from the WarpExecuteOnLane0Op. 1712087aba4fSThomas Raoux unsigned operandIndex = yieldOperand->getOperandNumber(); 1713087aba4fSThomas Raoux SmallVector<Value> yieldValues = {reductionOp.getVector()}; 17140660f3c5SThomas Raoux SmallVector<Type> retTypes = { 17150660f3c5SThomas Raoux VectorType::get({numElements}, reductionOp.getType())}; 1716ffa7384fSThomas Raoux if (reductionOp.getAcc()) { 1717ffa7384fSThomas Raoux yieldValues.push_back(reductionOp.getAcc()); 1718ffa7384fSThomas Raoux retTypes.push_back(reductionOp.getAcc().getType()); 1719ffa7384fSThomas Raoux } 1720d7d6443dSThomas Raoux SmallVector<size_t> newRetIndices; 1721087aba4fSThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 1722d7d6443dSThomas Raoux rewriter, warpOp, yieldValues, retTypes, newRetIndices); 1723087aba4fSThomas Raoux rewriter.setInsertionPointAfter(newWarpOp); 1724087aba4fSThomas Raoux 1725d2061530Sstanley-nod // Obtain data to reduce for a single lane. 1726d7d6443dSThomas Raoux Value laneValVec = newWarpOp.getResult(newRetIndices[0]); 1727d2061530Sstanley-nod // Distribute and reduce across threads. 17280660f3c5SThomas Raoux Value fullReduce = 1729d2061530Sstanley-nod distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec, 17306834803cSThomas Raoux reductionOp.getKind(), newWarpOp.getWarpSize()); 1731ffa7384fSThomas Raoux if (reductionOp.getAcc()) { 1732ffa7384fSThomas Raoux fullReduce = vector::makeArithReduction( 1733ffa7384fSThomas Raoux rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce, 1734ffa7384fSThomas Raoux newWarpOp.getResult(newRetIndices[1])); 1735ffa7384fSThomas Raoux } 17367ecc921dSMatthias Springer rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce); 1737087aba4fSThomas Raoux return success(); 1738087aba4fSThomas Raoux } 17396834803cSThomas Raoux 17406834803cSThomas Raoux private: 17416834803cSThomas Raoux DistributedReductionFn distributedReductionFn; 1742087aba4fSThomas Raoux }; 1743087aba4fSThomas Raoux 1744d02f10d9SThomas Raoux } // namespace 1745d02f10d9SThomas Raoux 1746d02f10d9SThomas Raoux void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern( 1747d02f10d9SThomas Raoux RewritePatternSet &patterns, 174827cc31b6SNicolas Vasilache const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) { 17494abb9e5dSThomas Raoux patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit); 1750d02f10d9SThomas Raoux } 1751ed0288f7SThomas Raoux 1752ed0288f7SThomas Raoux void mlir::vector::populateDistributeTransferWriteOpPatterns( 175327cc31b6SNicolas Vasilache RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn, 175480636227SJakub Kuderski unsigned maxNumElementsToExtract, PatternBenefit benefit) { 175527cc31b6SNicolas Vasilache patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn, 175680636227SJakub Kuderski maxNumElementsToExtract, benefit); 1757ed0288f7SThomas Raoux } 1758ed0288f7SThomas Raoux 175976cf33daSThomas Raoux void mlir::vector::populatePropagateWarpVectorDistributionPatterns( 176091f62f0eSThomas Raoux RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn, 1761df49a97aSQuinn Dawkins const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit, 1762df49a97aSQuinn Dawkins PatternBenefit readBenefit) { 1763df49a97aSQuinn Dawkins patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit); 17642f925d75SKunwar Grover patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, 17652f925d75SKunwar Grover WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, 17662f925d75SKunwar Grover WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement, 17672f925d75SKunwar Grover WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>( 1768df49a97aSQuinn Dawkins patterns.getContext(), benefit); 17692f925d75SKunwar Grover patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn, 17702f925d75SKunwar Grover benefit); 177191f62f0eSThomas Raoux patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn, 177291f62f0eSThomas Raoux benefit); 177376cf33daSThomas Raoux } 177476cf33daSThomas Raoux 17756834803cSThomas Raoux void mlir::vector::populateDistributeReduction( 17766834803cSThomas Raoux RewritePatternSet &patterns, 177727cc31b6SNicolas Vasilache const DistributedReductionFn &distributedReductionFn, 177827cc31b6SNicolas Vasilache PatternBenefit benefit) { 177927cc31b6SNicolas Vasilache patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn, 178027cc31b6SNicolas Vasilache benefit); 1781087aba4fSThomas Raoux } 1782087aba4fSThomas Raoux 1783*bc29fc93SPetr Kurapov /// Helper to know if an op can be hoisted out of the region. 1784*bc29fc93SPetr Kurapov static bool canBeHoisted(Operation *op, 1785*bc29fc93SPetr Kurapov function_ref<bool(Value)> definedOutside) { 1786*bc29fc93SPetr Kurapov return llvm::all_of(op->getOperands(), definedOutside) && 1787*bc29fc93SPetr Kurapov isMemoryEffectFree(op) && op->getNumRegions() == 0; 1788*bc29fc93SPetr Kurapov } 1789*bc29fc93SPetr Kurapov 1790ed0288f7SThomas Raoux void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) { 1791ed0288f7SThomas Raoux Block *body = warpOp.getBody(); 1792ed0288f7SThomas Raoux 1793ed0288f7SThomas Raoux // Keep track of the ops we want to hoist. 1794ed0288f7SThomas Raoux llvm::SmallSetVector<Operation *, 8> opsToMove; 1795ed0288f7SThomas Raoux 1796ed0288f7SThomas Raoux // Helper to check if a value is or will be defined outside of the region. 1797ed0288f7SThomas Raoux auto isDefinedOutsideOfBody = [&](Value value) { 1798ed0288f7SThomas Raoux auto *definingOp = value.getDefiningOp(); 1799ed0288f7SThomas Raoux return (definingOp && opsToMove.count(definingOp)) || 1800ed0288f7SThomas Raoux warpOp.isDefinedOutsideOfRegion(value); 1801ed0288f7SThomas Raoux }; 1802ed0288f7SThomas Raoux 1803ed0288f7SThomas Raoux // Do not use walk here, as we do not want to go into nested regions and hoist 1804ed0288f7SThomas Raoux // operations from there. 1805ed0288f7SThomas Raoux for (auto &op : body->without_terminator()) { 1806ed0288f7SThomas Raoux bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) { 18075550c821STres Popp return isa<VectorType>(result.getType()); 1808ed0288f7SThomas Raoux }); 1809ed0288f7SThomas Raoux if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody)) 1810ed0288f7SThomas Raoux opsToMove.insert(&op); 1811ed0288f7SThomas Raoux } 1812ed0288f7SThomas Raoux 1813ed0288f7SThomas Raoux // Move all the ops marked as uniform outside of the region. 1814ed0288f7SThomas Raoux for (Operation *op : opsToMove) 1815ed0288f7SThomas Raoux op->moveBefore(warpOp); 1816ed0288f7SThomas Raoux } 1817