xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (revision bc29fc937c6cb4a210f80c93c79fc6ed97c801f8)
1d02f10d9SThomas Raoux //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2d02f10d9SThomas Raoux //
3d02f10d9SThomas Raoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d02f10d9SThomas Raoux // See https://llvm.org/LICENSE.txt for license information.
5d02f10d9SThomas Raoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d02f10d9SThomas Raoux //
7d02f10d9SThomas Raoux //===----------------------------------------------------------------------===//
8d02f10d9SThomas Raoux 
9ed0288f7SThomas Raoux #include "mlir/Dialect/Affine/IR/AffineOps.h"
10abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
11ecaf2c33SPetr Kurapov #include "mlir/Dialect/GPU/IR/GPUDialect.h"
12*bc29fc93SPetr Kurapov #include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
13d02f10d9SThomas Raoux #include "mlir/Dialect/MemRef/IR/MemRef.h"
148b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
15fa8a10a1SNicolas Vasilache #include "mlir/Dialect/Vector/IR/VectorOps.h"
16d02f10d9SThomas Raoux #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
17fa8a10a1SNicolas Vasilache #include "mlir/IR/AffineExpr.h"
18fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h"
1991f62f0eSThomas Raoux #include "mlir/Transforms/RegionUtils.h"
20d7d6443dSThomas Raoux #include "llvm/ADT/SetVector.h"
2180636227SJakub Kuderski #include "llvm/Support/FormatVariadic.h"
2208d651d7SMehdi Amini #include <utility>
2308d651d7SMehdi Amini 
24d02f10d9SThomas Raoux using namespace mlir;
25d02f10d9SThomas Raoux using namespace mlir::vector;
26ecaf2c33SPetr Kurapov using namespace mlir::gpu;
27d02f10d9SThomas Raoux 
284abb9e5dSThomas Raoux /// Currently the distribution map is implicit based on the vector shape. In the
294abb9e5dSThomas Raoux /// future it will be part of the op.
304abb9e5dSThomas Raoux /// Example:
314abb9e5dSThomas Raoux /// ```
32ecaf2c33SPetr Kurapov /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
334abb9e5dSThomas Raoux ///   ...
34ecaf2c33SPetr Kurapov ///   gpu.yield %3 : vector<32x16x64xf32>
354abb9e5dSThomas Raoux /// }
364abb9e5dSThomas Raoux /// ```
374abb9e5dSThomas Raoux /// Would have an implicit map of:
384abb9e5dSThomas Raoux /// `(d0, d1, d2) -> (d0, d2)`
394abb9e5dSThomas Raoux static AffineMap calculateImplicitMap(VectorType sequentialType,
404abb9e5dSThomas Raoux                                       VectorType distributedType) {
414abb9e5dSThomas Raoux   SmallVector<AffineExpr> perm;
424abb9e5dSThomas Raoux   perm.reserve(1);
434abb9e5dSThomas Raoux   // Check which dimensions of the sequential type are different than the
444abb9e5dSThomas Raoux   // dimensions of the distributed type to know the distributed dimensions. Then
454abb9e5dSThomas Raoux   // associate each distributed dimension to an ID in order.
464abb9e5dSThomas Raoux   for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
474abb9e5dSThomas Raoux     if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
484abb9e5dSThomas Raoux       perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
49d02f10d9SThomas Raoux   }
504abb9e5dSThomas Raoux   auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
514abb9e5dSThomas Raoux                             distributedType.getContext());
524abb9e5dSThomas Raoux   return map;
53d02f10d9SThomas Raoux }
54d02f10d9SThomas Raoux 
55fa8a10a1SNicolas Vasilache namespace {
56d02f10d9SThomas Raoux 
57fa8a10a1SNicolas Vasilache /// Helper struct to create the load / store operations that permit transit
58fa8a10a1SNicolas Vasilache /// through the parallel / sequential and the sequential / parallel boundaries
59fa8a10a1SNicolas Vasilache /// when performing `rewriteWarpOpToScfFor`.
60fa8a10a1SNicolas Vasilache ///
614abb9e5dSThomas Raoux /// The vector distribution dimension is inferred from the vector types.
62fa8a10a1SNicolas Vasilache struct DistributedLoadStoreHelper {
63fa8a10a1SNicolas Vasilache   DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
64fa8a10a1SNicolas Vasilache                              Value laneId, Value zero)
65fa8a10a1SNicolas Vasilache       : sequentialVal(sequentialVal), distributedVal(distributedVal),
66fa8a10a1SNicolas Vasilache         laneId(laneId), zero(zero) {
675550c821STres Popp     sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
685550c821STres Popp     distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
694abb9e5dSThomas Raoux     if (sequentialVectorType && distributedVectorType)
704abb9e5dSThomas Raoux       distributionMap =
714abb9e5dSThomas Raoux           calculateImplicitMap(sequentialVectorType, distributedVectorType);
72fa8a10a1SNicolas Vasilache   }
73d02f10d9SThomas Raoux 
744abb9e5dSThomas Raoux   Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
754abb9e5dSThomas Raoux     int64_t distributedSize = distributedVectorType.getDimSize(index);
76fa8a10a1SNicolas Vasilache     AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
774c48f016SMatthias Springer     return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
78fa8a10a1SNicolas Vasilache                                                  ArrayRef<Value>{laneId});
79fa8a10a1SNicolas Vasilache   }
80d02f10d9SThomas Raoux 
81845dc178SNicolas Vasilache   /// Create a store during the process of distributing the
82845dc178SNicolas Vasilache   /// `vector.warp_execute_on_thread_0` op.
83845dc178SNicolas Vasilache   /// Vector distribution assumes the following convention regarding the
84845dc178SNicolas Vasilache   /// temporary buffers that are created to transition values. This **must**
85845dc178SNicolas Vasilache   /// be properly specified in the `options.warpAllocationFn`:
86845dc178SNicolas Vasilache   ///   1. scalars of type T transit through a memref<1xT>.
87845dc178SNicolas Vasilache   ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
88fa8a10a1SNicolas Vasilache   Operation *buildStore(RewriterBase &b, Location loc, Value val,
89fa8a10a1SNicolas Vasilache                         Value buffer) {
90fa8a10a1SNicolas Vasilache     assert((val == distributedVal || val == sequentialVal) &&
91fa8a10a1SNicolas Vasilache            "Must store either the preregistered distributed or the "
92fa8a10a1SNicolas Vasilache            "preregistered sequential value.");
934abb9e5dSThomas Raoux     // Scalar case can directly use memref.store.
945550c821STres Popp     if (!isa<VectorType>(val.getType()))
954abb9e5dSThomas Raoux       return b.create<memref::StoreOp>(loc, val, buffer, zero);
964abb9e5dSThomas Raoux 
97fa8a10a1SNicolas Vasilache     // Vector case must use vector::TransferWriteOp which will later lower to
98fa8a10a1SNicolas Vasilache     //   vector.store of memref.store depending on further lowerings.
99845dc178SNicolas Vasilache     int64_t rank = sequentialVectorType.getRank();
100845dc178SNicolas Vasilache     SmallVector<Value> indices(rank, zero);
1014abb9e5dSThomas Raoux     if (val == distributedVal) {
1024abb9e5dSThomas Raoux       for (auto dimExpr : distributionMap.getResults()) {
1031609f1c2Slong.chen         int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
1044abb9e5dSThomas Raoux         indices[index] = buildDistributedOffset(b, loc, index);
1054abb9e5dSThomas Raoux       }
1064abb9e5dSThomas Raoux     }
107fa8a10a1SNicolas Vasilache     SmallVector<bool> inBounds(indices.size(), true);
108fa8a10a1SNicolas Vasilache     return b.create<vector::TransferWriteOp>(
109fa8a10a1SNicolas Vasilache         loc, val, buffer, indices,
110fa8a10a1SNicolas Vasilache         ArrayRef<bool>(inBounds.begin(), inBounds.end()));
111fa8a10a1SNicolas Vasilache   }
112fa8a10a1SNicolas Vasilache 
113845dc178SNicolas Vasilache   /// Create a load during the process of distributing the
114845dc178SNicolas Vasilache   /// `vector.warp_execute_on_thread_0` op.
115845dc178SNicolas Vasilache   /// Vector distribution assumes the following convention regarding the
116845dc178SNicolas Vasilache   /// temporary buffers that are created to transition values. This **must**
117845dc178SNicolas Vasilache   /// be properly specified in the `options.warpAllocationFn`:
118845dc178SNicolas Vasilache   ///   1. scalars of type T transit through a memref<1xT>.
119845dc178SNicolas Vasilache   ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
120845dc178SNicolas Vasilache   ///
121845dc178SNicolas Vasilache   /// When broadcastMode is true, the load is not distributed to account for
122ecaf2c33SPetr Kurapov   /// the broadcast semantics of the `gpu.warp_execute_on_lane_0` op.
123845dc178SNicolas Vasilache   ///
124845dc178SNicolas Vasilache   /// Example:
125845dc178SNicolas Vasilache   ///
126845dc178SNicolas Vasilache   /// ```
127ecaf2c33SPetr Kurapov   ///   %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
128ecaf2c33SPetr Kurapov   ///     gpu.yield %cst : f32
129845dc178SNicolas Vasilache   ///   }
130845dc178SNicolas Vasilache   ///   // Both types are f32. The constant %cst is broadcasted to all lanes.
131845dc178SNicolas Vasilache   /// ```
132845dc178SNicolas Vasilache   /// This behavior described in more detail in the documentation of the op.
1334abb9e5dSThomas Raoux   Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
1344abb9e5dSThomas Raoux 
1354abb9e5dSThomas Raoux     // Scalar case can directly use memref.store.
1365550c821STres Popp     if (!isa<VectorType>(type))
137fa8a10a1SNicolas Vasilache       return b.create<memref::LoadOp>(loc, buffer, zero);
138fa8a10a1SNicolas Vasilache 
139fa8a10a1SNicolas Vasilache     // Other cases must be vector atm.
140fa8a10a1SNicolas Vasilache     // Vector case must use vector::TransferReadOp which will later lower to
141fa8a10a1SNicolas Vasilache     //   vector.read of memref.read depending on further lowerings.
142fa8a10a1SNicolas Vasilache     assert((type == distributedVectorType || type == sequentialVectorType) &&
143fa8a10a1SNicolas Vasilache            "Must store either the preregistered distributed or the "
144fa8a10a1SNicolas Vasilache            "preregistered sequential type.");
145fa8a10a1SNicolas Vasilache     SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
146fa8a10a1SNicolas Vasilache     if (type == distributedVectorType) {
1474abb9e5dSThomas Raoux       for (auto dimExpr : distributionMap.getResults()) {
1481609f1c2Slong.chen         int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
1494abb9e5dSThomas Raoux         indices[index] = buildDistributedOffset(b, loc, index);
1504abb9e5dSThomas Raoux       }
151d02f10d9SThomas Raoux     }
152fa8a10a1SNicolas Vasilache     SmallVector<bool> inBounds(indices.size(), true);
153fa8a10a1SNicolas Vasilache     return b.create<vector::TransferReadOp>(
1545550c821STres Popp         loc, cast<VectorType>(type), buffer, indices,
155fa8a10a1SNicolas Vasilache         ArrayRef<bool>(inBounds.begin(), inBounds.end()));
156d02f10d9SThomas Raoux   }
157d02f10d9SThomas Raoux 
158fa8a10a1SNicolas Vasilache   Value sequentialVal, distributedVal, laneId, zero;
159fa8a10a1SNicolas Vasilache   VectorType sequentialVectorType, distributedVectorType;
1604abb9e5dSThomas Raoux   AffineMap distributionMap;
161fa8a10a1SNicolas Vasilache };
162d02f10d9SThomas Raoux 
163fa8a10a1SNicolas Vasilache } // namespace
164d02f10d9SThomas Raoux 
16576cf33daSThomas Raoux // Clones `op` into a new operation that takes `operands` and returns
16676cf33daSThomas Raoux // `resultTypes`.
16776cf33daSThomas Raoux static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
16876cf33daSThomas Raoux                                               Location loc, Operation *op,
16976cf33daSThomas Raoux                                               ArrayRef<Value> operands,
17076cf33daSThomas Raoux                                               ArrayRef<Type> resultTypes) {
17176cf33daSThomas Raoux   OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
17276cf33daSThomas Raoux                      op->getAttrs());
17376cf33daSThomas Raoux   return rewriter.create(res);
17476cf33daSThomas Raoux }
17576cf33daSThomas Raoux 
176d02f10d9SThomas Raoux namespace {
177d02f10d9SThomas Raoux 
178fa8a10a1SNicolas Vasilache /// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
179fa8a10a1SNicolas Vasilache /// thread `laneId` executes the entirety of the computation.
180fa8a10a1SNicolas Vasilache ///
181fa8a10a1SNicolas Vasilache /// After the transformation:
182fa8a10a1SNicolas Vasilache ///   - the IR within the scf.if op can be thought of as executing sequentially
183fa8a10a1SNicolas Vasilache ///     (from the point of view of threads along `laneId`).
184fa8a10a1SNicolas Vasilache ///   - the IR outside of the scf.if op can be thought of as executing in
185fa8a10a1SNicolas Vasilache ///     parallel (from the point of view of threads along `laneId`).
186fa8a10a1SNicolas Vasilache ///
187fa8a10a1SNicolas Vasilache /// Values that need to transit through the parallel / sequential and the
188fa8a10a1SNicolas Vasilache /// sequential / parallel boundaries do so via reads and writes to a temporary
189fa8a10a1SNicolas Vasilache /// memory location.
190fa8a10a1SNicolas Vasilache ///
191fa8a10a1SNicolas Vasilache /// The transformation proceeds in multiple steps:
192fa8a10a1SNicolas Vasilache ///   1. Create the scf.if op.
193fa8a10a1SNicolas Vasilache ///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
194fa8a10a1SNicolas Vasilache ///      within the scf.if to transit the values captured from above.
195fa8a10a1SNicolas Vasilache ///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
196fa8a10a1SNicolas Vasilache ///      consistent within the scf.if.
197fa8a10a1SNicolas Vasilache ///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
198fa8a10a1SNicolas Vasilache ///   5. Insert appropriate writes within scf.if and reads after the scf.if to
199fa8a10a1SNicolas Vasilache ///      transit the values returned by the op.
200fa8a10a1SNicolas Vasilache ///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
201fa8a10a1SNicolas Vasilache ///      consistent after the scf.if.
202fa8a10a1SNicolas Vasilache ///   7. Perform late cleanups.
203fa8a10a1SNicolas Vasilache ///
204fa8a10a1SNicolas Vasilache /// All this assumes the vector distribution occurs along the most minor
205fa8a10a1SNicolas Vasilache /// distributed vector dimension.
206*bc29fc93SPetr Kurapov struct WarpOpToScfIfPattern : public WarpDistributionPattern {
2074abb9e5dSThomas Raoux   WarpOpToScfIfPattern(MLIRContext *context,
208d02f10d9SThomas Raoux                        const WarpExecuteOnLane0LoweringOptions &options,
209d02f10d9SThomas Raoux                        PatternBenefit benefit = 1)
210*bc29fc93SPetr Kurapov       : WarpDistributionPattern(context, benefit), options(options) {}
211d02f10d9SThomas Raoux 
212d02f10d9SThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
213d02f10d9SThomas Raoux                                 PatternRewriter &rewriter) const override {
214fa8a10a1SNicolas Vasilache     assert(warpOp.getBodyRegion().hasOneBlock() &&
215fa8a10a1SNicolas Vasilache            "expected WarpOp with single block");
216fa8a10a1SNicolas Vasilache     Block *warpOpBody = &warpOp.getBodyRegion().front();
217fa8a10a1SNicolas Vasilache     Location loc = warpOp.getLoc();
218fa8a10a1SNicolas Vasilache 
219fa8a10a1SNicolas Vasilache     // Passed all checks. Start rewriting.
220fa8a10a1SNicolas Vasilache     OpBuilder::InsertionGuard g(rewriter);
221fa8a10a1SNicolas Vasilache     rewriter.setInsertionPoint(warpOp);
222fa8a10a1SNicolas Vasilache 
223fa8a10a1SNicolas Vasilache     // Step 1: Create scf.if op.
224fa8a10a1SNicolas Vasilache     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
225fa8a10a1SNicolas Vasilache     Value isLane0 = rewriter.create<arith::CmpIOp>(
226fa8a10a1SNicolas Vasilache         loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
227fa8a10a1SNicolas Vasilache     auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
228fa8a10a1SNicolas Vasilache                                            /*withElseRegion=*/false);
229fa8a10a1SNicolas Vasilache     rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
230fa8a10a1SNicolas Vasilache 
231fa8a10a1SNicolas Vasilache     // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
232fa8a10a1SNicolas Vasilache     // reads within the scf.if to transit the values captured from above.
233fa8a10a1SNicolas Vasilache     SmallVector<Value> bbArgReplacements;
234fa8a10a1SNicolas Vasilache     for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
235fa8a10a1SNicolas Vasilache       Value sequentialVal = warpOpBody->getArgument(it.index());
236fa8a10a1SNicolas Vasilache       Value distributedVal = it.value();
237fa8a10a1SNicolas Vasilache       DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
238fa8a10a1SNicolas Vasilache                                         warpOp.getLaneid(), c0);
239fa8a10a1SNicolas Vasilache 
240fa8a10a1SNicolas Vasilache       // Create buffer before the ifOp.
241fa8a10a1SNicolas Vasilache       rewriter.setInsertionPoint(ifOp);
242fa8a10a1SNicolas Vasilache       Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
243fa8a10a1SNicolas Vasilache                                               sequentialVal.getType());
244fa8a10a1SNicolas Vasilache       // Store distributed vector into buffer, before the ifOp.
245fa8a10a1SNicolas Vasilache       helper.buildStore(rewriter, loc, distributedVal, buffer);
246fa8a10a1SNicolas Vasilache       // Load sequential vector from buffer, inside the ifOp.
247fa8a10a1SNicolas Vasilache       rewriter.setInsertionPointToStart(ifOp.thenBlock());
2484abb9e5dSThomas Raoux       bbArgReplacements.push_back(
2494abb9e5dSThomas Raoux           helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
250fa8a10a1SNicolas Vasilache     }
251fa8a10a1SNicolas Vasilache 
252fa8a10a1SNicolas Vasilache     // Step 3. Insert sync after all the stores and before all the loads.
253fa8a10a1SNicolas Vasilache     if (!warpOp.getArgs().empty()) {
254fa8a10a1SNicolas Vasilache       rewriter.setInsertionPoint(ifOp);
255fa8a10a1SNicolas Vasilache       options.warpSyncronizationFn(loc, rewriter, warpOp);
256fa8a10a1SNicolas Vasilache     }
257fa8a10a1SNicolas Vasilache 
258fa8a10a1SNicolas Vasilache     // Step 4. Move body of warpOp to ifOp.
259fa8a10a1SNicolas Vasilache     rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
260fa8a10a1SNicolas Vasilache 
261fa8a10a1SNicolas Vasilache     // Step 5. Insert appropriate writes within scf.if and reads after the
262fa8a10a1SNicolas Vasilache     // scf.if to transit the values returned by the op.
263fa8a10a1SNicolas Vasilache     // TODO: at this point, we can reuse the shared memory from previous
264fa8a10a1SNicolas Vasilache     // buffers.
265fa8a10a1SNicolas Vasilache     SmallVector<Value> replacements;
266ecaf2c33SPetr Kurapov     auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
267fa8a10a1SNicolas Vasilache     Location yieldLoc = yieldOp.getLoc();
268b74192b7SRiver Riddle     for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
269fa8a10a1SNicolas Vasilache       Value sequentialVal = it.value();
270fa8a10a1SNicolas Vasilache       Value distributedVal = warpOp->getResult(it.index());
271fa8a10a1SNicolas Vasilache       DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
272fa8a10a1SNicolas Vasilache                                         warpOp.getLaneid(), c0);
273fa8a10a1SNicolas Vasilache 
274fa8a10a1SNicolas Vasilache       // Create buffer before the ifOp.
275fa8a10a1SNicolas Vasilache       rewriter.setInsertionPoint(ifOp);
276fa8a10a1SNicolas Vasilache       Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
277fa8a10a1SNicolas Vasilache                                               sequentialVal.getType());
278fa8a10a1SNicolas Vasilache 
279fa8a10a1SNicolas Vasilache       // Store yielded value into buffer, inside the ifOp, before the
280fa8a10a1SNicolas Vasilache       // terminator.
281fa8a10a1SNicolas Vasilache       rewriter.setInsertionPoint(yieldOp);
282fa8a10a1SNicolas Vasilache       helper.buildStore(rewriter, loc, sequentialVal, buffer);
283fa8a10a1SNicolas Vasilache 
284fa8a10a1SNicolas Vasilache       // Load distributed value from buffer, after  the warpOp.
285fa8a10a1SNicolas Vasilache       rewriter.setInsertionPointAfter(ifOp);
286fa8a10a1SNicolas Vasilache       // Result type and yielded value type are the same. This is a broadcast.
287fa8a10a1SNicolas Vasilache       // E.g.:
288ecaf2c33SPetr Kurapov       // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
289ecaf2c33SPetr Kurapov       //   gpu.yield %cst : f32
290fa8a10a1SNicolas Vasilache       // }
291fa8a10a1SNicolas Vasilache       // Both types are f32. The constant %cst is broadcasted to all lanes.
292fa8a10a1SNicolas Vasilache       // This is described in more detail in the documentation of the op.
2934abb9e5dSThomas Raoux       replacements.push_back(
2944abb9e5dSThomas Raoux           helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
295fa8a10a1SNicolas Vasilache     }
296fa8a10a1SNicolas Vasilache 
297fa8a10a1SNicolas Vasilache     // Step 6. Insert sync after all the stores and before all the loads.
298b74192b7SRiver Riddle     if (!yieldOp.getOperands().empty()) {
299fa8a10a1SNicolas Vasilache       rewriter.setInsertionPointAfter(ifOp);
300fa8a10a1SNicolas Vasilache       options.warpSyncronizationFn(loc, rewriter, warpOp);
301fa8a10a1SNicolas Vasilache     }
302fa8a10a1SNicolas Vasilache 
303fa8a10a1SNicolas Vasilache     // Step 7. Delete terminator and add empty scf.yield.
304fa8a10a1SNicolas Vasilache     rewriter.eraseOp(yieldOp);
305fa8a10a1SNicolas Vasilache     rewriter.setInsertionPointToEnd(ifOp.thenBlock());
306fa8a10a1SNicolas Vasilache     rewriter.create<scf::YieldOp>(yieldLoc);
307fa8a10a1SNicolas Vasilache 
308fa8a10a1SNicolas Vasilache     // Compute replacements for WarpOp results.
309fa8a10a1SNicolas Vasilache     rewriter.replaceOp(warpOp, replacements);
310fa8a10a1SNicolas Vasilache 
311fa8a10a1SNicolas Vasilache     return success();
312d02f10d9SThomas Raoux   }
313d02f10d9SThomas Raoux 
314d02f10d9SThomas Raoux private:
315d02f10d9SThomas Raoux   const WarpExecuteOnLane0LoweringOptions &options;
316d02f10d9SThomas Raoux };
317d02f10d9SThomas Raoux 
31891f62f0eSThomas Raoux /// Return the distributed vector type based on the original type and the
31991f62f0eSThomas Raoux /// distribution map. The map is expected to have a dimension equal to the
32091f62f0eSThomas Raoux /// original type rank and should be a projection where the results are the
32191f62f0eSThomas Raoux /// distributed dimensions. The number of results should be equal to the number
32291f62f0eSThomas Raoux /// of warp sizes which is currently limited to 1.
32391f62f0eSThomas Raoux /// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
32491f62f0eSThomas Raoux /// and a warp size of 16 would distribute the second dimension (associated to
32591f62f0eSThomas Raoux /// d1) and return vector<16x2x64>
32691f62f0eSThomas Raoux static VectorType getDistributedType(VectorType originalType, AffineMap map,
32791f62f0eSThomas Raoux                                      int64_t warpSize) {
3285262865aSKazu Hirata   SmallVector<int64_t> targetShape(originalType.getShape());
32991f62f0eSThomas Raoux   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
33091f62f0eSThomas Raoux     unsigned position = map.getDimPosition(i);
331c2b95292SQuinn Dawkins     if (targetShape[position] % warpSize != 0) {
332c2b95292SQuinn Dawkins       if (warpSize % targetShape[position] != 0) {
33391f62f0eSThomas Raoux         return VectorType();
334c2b95292SQuinn Dawkins       }
335c2b95292SQuinn Dawkins       warpSize /= targetShape[position];
336c2b95292SQuinn Dawkins       targetShape[position] = 1;
337c2b95292SQuinn Dawkins       continue;
338c2b95292SQuinn Dawkins     }
33991f62f0eSThomas Raoux     targetShape[position] = targetShape[position] / warpSize;
340c2b95292SQuinn Dawkins     warpSize = 1;
341c2b95292SQuinn Dawkins     break;
342c2b95292SQuinn Dawkins   }
343c2b95292SQuinn Dawkins   if (warpSize != 1) {
344c2b95292SQuinn Dawkins     return VectorType();
34591f62f0eSThomas Raoux   }
34691f62f0eSThomas Raoux   VectorType targetType =
34791f62f0eSThomas Raoux       VectorType::get(targetShape, originalType.getElementType());
34891f62f0eSThomas Raoux   return targetType;
34991f62f0eSThomas Raoux }
35091f62f0eSThomas Raoux 
351ed0288f7SThomas Raoux /// Distribute transfer_write ops based on the affine map returned by
35280636227SJakub Kuderski /// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
35380636227SJakub Kuderski /// will not be distributed (it should be less than the warp size).
35480636227SJakub Kuderski ///
355ed0288f7SThomas Raoux /// Example:
356ed0288f7SThomas Raoux /// ```
357ecaf2c33SPetr Kurapov /// %0 = gpu.warp_execute_on_lane_0(%id){
358ed0288f7SThomas Raoux ///   ...
359ed0288f7SThomas Raoux ///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
360ecaf2c33SPetr Kurapov ///   gpu.yield
361ed0288f7SThomas Raoux /// }
362ed0288f7SThomas Raoux /// ```
363ed0288f7SThomas Raoux /// To
364ed0288f7SThomas Raoux /// ```
365ecaf2c33SPetr Kurapov /// %r:3 = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
366ed0288f7SThomas Raoux ///   ...
367ecaf2c33SPetr Kurapov ///   gpu.yield %v : vector<32xf32>
368ed0288f7SThomas Raoux /// }
369ed0288f7SThomas Raoux /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
370*bc29fc93SPetr Kurapov struct WarpOpTransferWrite : public WarpDistributionPattern {
371ed0288f7SThomas Raoux   WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
37280636227SJakub Kuderski                       unsigned maxNumElementsToExtract, PatternBenefit b = 1)
373*bc29fc93SPetr Kurapov       : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
37480636227SJakub Kuderski         maxNumElementsToExtract(maxNumElementsToExtract) {}
375ed0288f7SThomas Raoux 
376ed0288f7SThomas Raoux   /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
377ed0288f7SThomas Raoux   /// are multiples of the distribution ratio are supported at the moment.
378ed0288f7SThomas Raoux   LogicalResult tryDistributeOp(RewriterBase &rewriter,
379ed0288f7SThomas Raoux                                 vector::TransferWriteOp writeOp,
380ed0288f7SThomas Raoux                                 WarpExecuteOnLane0Op warpOp) const {
3816a57d8fbSNicolas Vasilache     VectorType writtenVectorType = writeOp.getVectorType();
3826a57d8fbSNicolas Vasilache 
3836a57d8fbSNicolas Vasilache     // 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op
3846a57d8fbSNicolas Vasilache     // to separate it from the rest.
3856a57d8fbSNicolas Vasilache     if (writtenVectorType.getRank() == 0)
3866a57d8fbSNicolas Vasilache       return failure();
3876a57d8fbSNicolas Vasilache 
38891f62f0eSThomas Raoux     // 2. Compute the distributed type.
38991f62f0eSThomas Raoux     AffineMap map = distributionMapFn(writeOp.getVector());
390ed0288f7SThomas Raoux     VectorType targetType =
39191f62f0eSThomas Raoux         getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
39291f62f0eSThomas Raoux     if (!targetType)
39391f62f0eSThomas Raoux       return failure();
394ed0288f7SThomas Raoux 
39525ec1fa9SQuinn Dawkins     // 2.5 Compute the distributed type for the new mask;
39625ec1fa9SQuinn Dawkins     VectorType maskType;
39725ec1fa9SQuinn Dawkins     if (writeOp.getMask()) {
39825ec1fa9SQuinn Dawkins       // TODO: Distribution of masked writes with non-trivial permutation maps
39925ec1fa9SQuinn Dawkins       // requires the distribution of the mask to elementwise match the
40025ec1fa9SQuinn Dawkins       // distribution of the permuted written vector. Currently the details
40125ec1fa9SQuinn Dawkins       // of which lane is responsible for which element is captured strictly
40225ec1fa9SQuinn Dawkins       // by shape information on the warp op, and thus requires materializing
40325ec1fa9SQuinn Dawkins       // the permutation in IR.
40425ec1fa9SQuinn Dawkins       if (!writeOp.getPermutationMap().isMinorIdentity())
40525ec1fa9SQuinn Dawkins         return failure();
40625ec1fa9SQuinn Dawkins       maskType =
40725ec1fa9SQuinn Dawkins           getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
40825ec1fa9SQuinn Dawkins     }
40925ec1fa9SQuinn Dawkins 
41091f62f0eSThomas Raoux     // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
4116a57d8fbSNicolas Vasilache     // the rest.
4126a57d8fbSNicolas Vasilache     vector::TransferWriteOp newWriteOp =
41325ec1fa9SQuinn Dawkins         cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
414ed0288f7SThomas Raoux 
41591f62f0eSThomas Raoux     // 4. Reindex the write using the distribution map.
4166a57d8fbSNicolas Vasilache     auto newWarpOp =
4176a57d8fbSNicolas Vasilache         newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
418c2b95292SQuinn Dawkins 
419c2b95292SQuinn Dawkins     // Delinearize the lane id based on the way threads are divided across the
420c2b95292SQuinn Dawkins     // vector. To get the number of threads per vector dimension, divide the
421c2b95292SQuinn Dawkins     // sequential size by the distributed size along each dim.
422ed0288f7SThomas Raoux     rewriter.setInsertionPoint(newWriteOp);
423c2b95292SQuinn Dawkins     SmallVector<OpFoldResult> delinearizedIdSizes;
424c2b95292SQuinn Dawkins     for (auto [seqSize, distSize] :
425c2b95292SQuinn Dawkins          llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
426c2b95292SQuinn Dawkins       assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
427c2b95292SQuinn Dawkins       delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
428c2b95292SQuinn Dawkins     }
429c2b95292SQuinn Dawkins     SmallVector<Value> delinearized;
430c2b95292SQuinn Dawkins     if (map.getNumResults() > 1) {
431c2b95292SQuinn Dawkins       delinearized = rewriter
432c2b95292SQuinn Dawkins                          .create<mlir::affine::AffineDelinearizeIndexOp>(
433c2b95292SQuinn Dawkins                              newWarpOp.getLoc(), newWarpOp.getLaneid(),
434c2b95292SQuinn Dawkins                              delinearizedIdSizes)
435c2b95292SQuinn Dawkins                          .getResults();
436c2b95292SQuinn Dawkins     } else {
437c2b95292SQuinn Dawkins       // If there is only one map result, we can elide the delinearization
438c2b95292SQuinn Dawkins       // op and use the lane id directly.
439c2b95292SQuinn Dawkins       delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
440c2b95292SQuinn Dawkins     }
441c2b95292SQuinn Dawkins 
442ed0288f7SThomas Raoux     AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
443ed0288f7SThomas Raoux     Location loc = newWriteOp.getLoc();
444ed0288f7SThomas Raoux     SmallVector<Value> indices(newWriteOp.getIndices().begin(),
445ed0288f7SThomas Raoux                                newWriteOp.getIndices().end());
446ed0288f7SThomas Raoux     for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
447ed0288f7SThomas Raoux       AffineExpr d0, d1;
448ed0288f7SThomas Raoux       bindDims(newWarpOp.getContext(), d0, d1);
4491609f1c2Slong.chen       auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
450ed0288f7SThomas Raoux       if (!indexExpr)
451ed0288f7SThomas Raoux         continue;
452ed0288f7SThomas Raoux       unsigned indexPos = indexExpr.getPosition();
4531609f1c2Slong.chen       unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
454c2b95292SQuinn Dawkins       Value laneId = delinearized[vectorPos];
45591f62f0eSThomas Raoux       auto scale =
45691f62f0eSThomas Raoux           rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
4574c48f016SMatthias Springer       indices[indexPos] = affine::makeComposedAffineApply(
458c2b95292SQuinn Dawkins           rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
459ed0288f7SThomas Raoux     }
460ed0288f7SThomas Raoux     newWriteOp.getIndicesMutable().assign(indices);
461ed0288f7SThomas Raoux 
462ed0288f7SThomas Raoux     return success();
463ed0288f7SThomas Raoux   }
464ed0288f7SThomas Raoux 
465ed0288f7SThomas Raoux   /// Extract TransferWriteOps of vector<1x> into a separate warp op.
466ed0288f7SThomas Raoux   LogicalResult tryExtractOp(RewriterBase &rewriter,
467ed0288f7SThomas Raoux                              vector::TransferWriteOp writeOp,
468ed0288f7SThomas Raoux                              WarpExecuteOnLane0Op warpOp) const {
469ed0288f7SThomas Raoux     Location loc = writeOp.getLoc();
470ed0288f7SThomas Raoux     VectorType vecType = writeOp.getVectorType();
471ed0288f7SThomas Raoux 
47280636227SJakub Kuderski     if (vecType.getNumElements() > maxNumElementsToExtract) {
47380636227SJakub Kuderski       return rewriter.notifyMatchFailure(
47480636227SJakub Kuderski           warpOp,
47580636227SJakub Kuderski           llvm::formatv(
47680636227SJakub Kuderski               "writes more elements ({0}) than allowed to extract ({1})",
47780636227SJakub Kuderski               vecType.getNumElements(), maxNumElementsToExtract));
47880636227SJakub Kuderski     }
479ed0288f7SThomas Raoux 
480ed0288f7SThomas Raoux     // Do not process warp ops that contain only TransferWriteOps.
481971b8525SJakub Kuderski     if (llvm::all_of(warpOp.getOps(),
482ecaf2c33SPetr Kurapov                      llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
483ed0288f7SThomas Raoux       return failure();
484ed0288f7SThomas Raoux 
485ed0288f7SThomas Raoux     SmallVector<Value> yieldValues = {writeOp.getVector()};
486ed0288f7SThomas Raoux     SmallVector<Type> retTypes = {vecType};
487d7d6443dSThomas Raoux     SmallVector<size_t> newRetIndices;
488ed0288f7SThomas Raoux     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
489d7d6443dSThomas Raoux         rewriter, warpOp, yieldValues, retTypes, newRetIndices);
490ed0288f7SThomas Raoux     rewriter.setInsertionPointAfter(newWarpOp);
491ed0288f7SThomas Raoux 
492ed0288f7SThomas Raoux     // Create a second warp op that contains only writeOp.
493ed0288f7SThomas Raoux     auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
494ed0288f7SThomas Raoux         loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
495ed0288f7SThomas Raoux     Block &body = secondWarpOp.getBodyRegion().front();
496ed0288f7SThomas Raoux     rewriter.setInsertionPointToStart(&body);
497ed0288f7SThomas Raoux     auto newWriteOp =
498ed0288f7SThomas Raoux         cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
499d7d6443dSThomas Raoux     newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
500ed0288f7SThomas Raoux     rewriter.eraseOp(writeOp);
501ecaf2c33SPetr Kurapov     rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
502ed0288f7SThomas Raoux     return success();
503ed0288f7SThomas Raoux   }
504ed0288f7SThomas Raoux 
505df49a97aSQuinn Dawkins   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
506ed0288f7SThomas Raoux                                 PatternRewriter &rewriter) const override {
507ecaf2c33SPetr Kurapov     auto yield = cast<gpu::YieldOp>(
508df49a97aSQuinn Dawkins         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
509df49a97aSQuinn Dawkins     Operation *lastNode = yield->getPrevNode();
510df49a97aSQuinn Dawkins     auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
511df49a97aSQuinn Dawkins     if (!writeOp)
512ed0288f7SThomas Raoux       return failure();
513ed0288f7SThomas Raoux 
51425ec1fa9SQuinn Dawkins     Value maybeMask = writeOp.getMask();
515ed0288f7SThomas Raoux     if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
516ed0288f7SThomas Raoux           return writeOp.getVector() == value ||
51725ec1fa9SQuinn Dawkins                  (maybeMask && maybeMask == value) ||
518ed0288f7SThomas Raoux                  warpOp.isDefinedOutsideOfRegion(value);
519ed0288f7SThomas Raoux         }))
520ed0288f7SThomas Raoux       return failure();
521ed0288f7SThomas Raoux 
522ed0288f7SThomas Raoux     if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
523ed0288f7SThomas Raoux       return success();
524ed0288f7SThomas Raoux 
52525ec1fa9SQuinn Dawkins     // Masked writes not supported for extraction.
52625ec1fa9SQuinn Dawkins     if (writeOp.getMask())
52725ec1fa9SQuinn Dawkins       return failure();
52825ec1fa9SQuinn Dawkins 
529ed0288f7SThomas Raoux     if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
530ed0288f7SThomas Raoux       return success();
531ed0288f7SThomas Raoux 
532ed0288f7SThomas Raoux     return failure();
533ed0288f7SThomas Raoux   }
534ed0288f7SThomas Raoux 
535ed0288f7SThomas Raoux private:
536*bc29fc93SPetr Kurapov   /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
537*bc29fc93SPetr Kurapov   /// execute op with the proper return type. The new write op is updated to
538*bc29fc93SPetr Kurapov   /// write the result of the new warp execute op. The old `writeOp` is deleted.
539*bc29fc93SPetr Kurapov   vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
540*bc29fc93SPetr Kurapov                                        WarpExecuteOnLane0Op warpOp,
541*bc29fc93SPetr Kurapov                                        vector::TransferWriteOp writeOp,
542*bc29fc93SPetr Kurapov                                        VectorType targetType,
543*bc29fc93SPetr Kurapov                                        VectorType maybeMaskType) const {
544*bc29fc93SPetr Kurapov     assert(writeOp->getParentOp() == warpOp &&
545*bc29fc93SPetr Kurapov            "write must be nested immediately under warp");
546*bc29fc93SPetr Kurapov     OpBuilder::InsertionGuard g(rewriter);
547*bc29fc93SPetr Kurapov     SmallVector<size_t> newRetIndices;
548*bc29fc93SPetr Kurapov     WarpExecuteOnLane0Op newWarpOp;
549*bc29fc93SPetr Kurapov     if (maybeMaskType) {
550*bc29fc93SPetr Kurapov       newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
551*bc29fc93SPetr Kurapov           rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
552*bc29fc93SPetr Kurapov           TypeRange{targetType, maybeMaskType}, newRetIndices);
553*bc29fc93SPetr Kurapov     } else {
554*bc29fc93SPetr Kurapov       newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
555*bc29fc93SPetr Kurapov           rewriter, warpOp, ValueRange{{writeOp.getVector()}},
556*bc29fc93SPetr Kurapov           TypeRange{targetType}, newRetIndices);
557*bc29fc93SPetr Kurapov     }
558*bc29fc93SPetr Kurapov     rewriter.setInsertionPointAfter(newWarpOp);
559*bc29fc93SPetr Kurapov     auto newWriteOp =
560*bc29fc93SPetr Kurapov         cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
561*bc29fc93SPetr Kurapov     rewriter.eraseOp(writeOp);
562*bc29fc93SPetr Kurapov     newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
563*bc29fc93SPetr Kurapov     if (maybeMaskType)
564*bc29fc93SPetr Kurapov       newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
565*bc29fc93SPetr Kurapov     return newWriteOp;
566*bc29fc93SPetr Kurapov   }
567*bc29fc93SPetr Kurapov 
568ed0288f7SThomas Raoux   DistributionMapFn distributionMapFn;
56980636227SJakub Kuderski   unsigned maxNumElementsToExtract = 1;
570ed0288f7SThomas Raoux };
571ed0288f7SThomas Raoux 
57276cf33daSThomas Raoux /// Sink out elementwise op feeding into a warp op yield.
57376cf33daSThomas Raoux /// ```
574ecaf2c33SPetr Kurapov /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
57576cf33daSThomas Raoux ///   ...
57676cf33daSThomas Raoux ///   %3 = arith.addf %1, %2 : vector<32xf32>
577ecaf2c33SPetr Kurapov ///   gpu.yield %3 : vector<32xf32>
57876cf33daSThomas Raoux /// }
57976cf33daSThomas Raoux /// ```
58076cf33daSThomas Raoux /// To
58176cf33daSThomas Raoux /// ```
582ecaf2c33SPetr Kurapov /// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
58376cf33daSThomas Raoux /// vector<1xf32>, vector<1xf32>) {
58476cf33daSThomas Raoux ///   ...
58576cf33daSThomas Raoux ///   %4 = arith.addf %2, %3 : vector<32xf32>
586ecaf2c33SPetr Kurapov ///   gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
58776cf33daSThomas Raoux ///   vector<32xf32>
58876cf33daSThomas Raoux /// }
58976cf33daSThomas Raoux /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
590*bc29fc93SPetr Kurapov struct WarpOpElementwise : public WarpDistributionPattern {
591*bc29fc93SPetr Kurapov   using Base::Base;
59276cf33daSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
59376cf33daSThomas Raoux                                 PatternRewriter &rewriter) const override {
59476cf33daSThomas Raoux     OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
59576cf33daSThomas Raoux       return OpTrait::hasElementwiseMappableTraits(op);
59676cf33daSThomas Raoux     });
59776cf33daSThomas Raoux     if (!yieldOperand)
59876cf33daSThomas Raoux       return failure();
599aa2376a0SQuinn Dawkins 
60076cf33daSThomas Raoux     Operation *elementWise = yieldOperand->get().getDefiningOp();
60176cf33daSThomas Raoux     unsigned operandIndex = yieldOperand->getOperandNumber();
60276cf33daSThomas Raoux     Value distributedVal = warpOp.getResult(operandIndex);
60376cf33daSThomas Raoux     SmallVector<Value> yieldValues;
60476cf33daSThomas Raoux     SmallVector<Type> retTypes;
60576cf33daSThomas Raoux     Location loc = warpOp.getLoc();
60676cf33daSThomas Raoux     for (OpOperand &operand : elementWise->getOpOperands()) {
60776cf33daSThomas Raoux       Type targetType;
6085550c821STres Popp       if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
60976cf33daSThomas Raoux         // If the result type is a vector, the operands must also be vectors.
6105550c821STres Popp         auto operandType = cast<VectorType>(operand.get().getType());
61176cf33daSThomas Raoux         targetType =
61276cf33daSThomas Raoux             VectorType::get(vecType.getShape(), operandType.getElementType());
61376cf33daSThomas Raoux       } else {
61476cf33daSThomas Raoux         auto operandType = operand.get().getType();
6155550c821STres Popp         assert(!isa<VectorType>(operandType) &&
61676cf33daSThomas Raoux                "unexpected yield of vector from op with scalar result type");
61776cf33daSThomas Raoux         targetType = operandType;
61876cf33daSThomas Raoux       }
61976cf33daSThomas Raoux       retTypes.push_back(targetType);
62076cf33daSThomas Raoux       yieldValues.push_back(operand.get());
62176cf33daSThomas Raoux     }
622d7d6443dSThomas Raoux     SmallVector<size_t> newRetIndices;
62376cf33daSThomas Raoux     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
624d7d6443dSThomas Raoux         rewriter, warpOp, yieldValues, retTypes, newRetIndices);
62576cf33daSThomas Raoux     rewriter.setInsertionPointAfter(newWarpOp);
62676cf33daSThomas Raoux     SmallVector<Value> newOperands(elementWise->getOperands().begin(),
62776cf33daSThomas Raoux                                    elementWise->getOperands().end());
62876cf33daSThomas Raoux     for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
629d7d6443dSThomas Raoux       newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
63076cf33daSThomas Raoux     }
63176cf33daSThomas Raoux     OpBuilder::InsertionGuard g(rewriter);
63276cf33daSThomas Raoux     rewriter.setInsertionPointAfter(newWarpOp);
63376cf33daSThomas Raoux     Operation *newOp = cloneOpWithOperandsAndTypes(
63476cf33daSThomas Raoux         rewriter, loc, elementWise, newOperands,
63576cf33daSThomas Raoux         {newWarpOp.getResult(operandIndex).getType()});
6367ecc921dSMatthias Springer     rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
6377ecc921dSMatthias Springer                                 newOp->getResult(0));
63876cf33daSThomas Raoux     return success();
63976cf33daSThomas Raoux   }
64076cf33daSThomas Raoux };
64176cf33daSThomas Raoux 
6420af26805SThomas Raoux /// Sink out splat constant op feeding into a warp op yield.
6430af26805SThomas Raoux /// ```
644ecaf2c33SPetr Kurapov /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
6450af26805SThomas Raoux ///   ...
6460af26805SThomas Raoux ///   %cst = arith.constant dense<2.0> : vector<32xf32>
647ecaf2c33SPetr Kurapov ///   gpu.yield %cst : vector<32xf32>
6480af26805SThomas Raoux /// }
6490af26805SThomas Raoux /// ```
6500af26805SThomas Raoux /// To
6510af26805SThomas Raoux /// ```
652ecaf2c33SPetr Kurapov /// gpu.warp_execute_on_lane_0(%arg0 {
6530af26805SThomas Raoux ///   ...
6540af26805SThomas Raoux /// }
6550af26805SThomas Raoux /// %0 = arith.constant dense<2.0> : vector<1xf32>
656*bc29fc93SPetr Kurapov struct WarpOpConstant : public WarpDistributionPattern {
657*bc29fc93SPetr Kurapov   using Base::Base;
6580af26805SThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
6590af26805SThomas Raoux                                 PatternRewriter &rewriter) const override {
660971b8525SJakub Kuderski     OpOperand *yieldOperand =
661971b8525SJakub Kuderski         getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
6620af26805SThomas Raoux     if (!yieldOperand)
6630af26805SThomas Raoux       return failure();
6640af26805SThomas Raoux     auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
6655550c821STres Popp     auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
6660af26805SThomas Raoux     if (!dense)
6670af26805SThomas Raoux       return failure();
668aa2376a0SQuinn Dawkins     // Notify the rewriter that the warp op is changing (see the comment on
669aa2376a0SQuinn Dawkins     // the WarpOpTransferRead pattern).
6705fcf907bSMatthias Springer     rewriter.startOpModification(warpOp);
6710af26805SThomas Raoux     unsigned operandIndex = yieldOperand->getOperandNumber();
6720af26805SThomas Raoux     Attribute scalarAttr = dense.getSplatValue<Attribute>();
6736089d612SRahul Kayaith     auto newAttr = DenseElementsAttr::get(
6746089d612SRahul Kayaith         cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
6750af26805SThomas Raoux     Location loc = warpOp.getLoc();
6760af26805SThomas Raoux     rewriter.setInsertionPointAfter(warpOp);
6770af26805SThomas Raoux     Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
6787ecc921dSMatthias Springer     rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
6795fcf907bSMatthias Springer     rewriter.finalizeOpModification(warpOp);
6800af26805SThomas Raoux     return success();
6810af26805SThomas Raoux   }
6820af26805SThomas Raoux };
6830af26805SThomas Raoux 
68476cf33daSThomas Raoux /// Sink out transfer_read op feeding into a warp op yield.
68576cf33daSThomas Raoux /// ```
686ecaf2c33SPetr Kurapov /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
68776cf33daSThomas Raoux ///   ...
68876cf33daSThomas Raoux //    %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
68976cf33daSThomas Raoux //    vector<32xf32>
690ecaf2c33SPetr Kurapov ///   gpu.yield %2 : vector<32xf32>
69176cf33daSThomas Raoux /// }
69276cf33daSThomas Raoux /// ```
69376cf33daSThomas Raoux /// To
69476cf33daSThomas Raoux /// ```
695ecaf2c33SPetr Kurapov /// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
69676cf33daSThomas Raoux /// vector<1xf32>, vector<1xf32>) {
69776cf33daSThomas Raoux ///   ...
69876cf33daSThomas Raoux ///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
699ecaf2c33SPetr Kurapov ///   vector<32xf32> gpu.yield %2 : vector<32xf32>
70076cf33daSThomas Raoux /// }
70176cf33daSThomas Raoux /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
702*bc29fc93SPetr Kurapov struct WarpOpTransferRead : public WarpDistributionPattern {
703*bc29fc93SPetr Kurapov   using Base::Base;
70476cf33daSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
70576cf33daSThomas Raoux                                 PatternRewriter &rewriter) const override {
7067360d5d3SQuinn Dawkins     // Try to find a distributable yielded read. Note that this pattern can
7077360d5d3SQuinn Dawkins     // still fail at the end after distribution, in which case this might have
7087360d5d3SQuinn Dawkins     // missed another distributable read.
7097360d5d3SQuinn Dawkins     OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
7107360d5d3SQuinn Dawkins       // Don't duplicate transfer_read ops when distributing.
7117360d5d3SQuinn Dawkins       return isa<vector::TransferReadOp>(op) && op->hasOneUse();
7127360d5d3SQuinn Dawkins     });
71376cf33daSThomas Raoux     if (!operand)
71435c19fddSMatthias Springer       return rewriter.notifyMatchFailure(
71535c19fddSMatthias Springer           warpOp, "warp result is not a vector.transfer_read op");
71676cf33daSThomas Raoux     auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
7177360d5d3SQuinn Dawkins 
71835c19fddSMatthias Springer     // Source must be defined outside of the region.
71935c19fddSMatthias Springer     if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
72035c19fddSMatthias Springer       return rewriter.notifyMatchFailure(
72135c19fddSMatthias Springer           read, "source must be defined outside of the region");
72235c19fddSMatthias Springer 
72376cf33daSThomas Raoux     unsigned operandIndex = operand->getOperandNumber();
72476cf33daSThomas Raoux     Value distributedVal = warpOp.getResult(operandIndex);
72576cf33daSThomas Raoux 
72676cf33daSThomas Raoux     SmallVector<Value, 4> indices(read.getIndices().begin(),
72776cf33daSThomas Raoux                                   read.getIndices().end());
7285550c821STres Popp     auto sequentialType = cast<VectorType>(read.getResult().getType());
7295550c821STres Popp     auto distributedType = cast<VectorType>(distributedVal.getType());
7304abb9e5dSThomas Raoux     AffineMap map = calculateImplicitMap(sequentialType, distributedType);
73176cf33daSThomas Raoux     AffineMap indexMap = map.compose(read.getPermutationMap());
732771f5759SQuinn Dawkins 
73335c19fddSMatthias Springer     // Try to delinearize the lane ID to match the rank expected for
73435c19fddSMatthias Springer     // distribution.
73535c19fddSMatthias Springer     SmallVector<Value> delinearizedIds;
73635c19fddSMatthias Springer     if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
73735c19fddSMatthias Springer                            distributedType.getShape(), warpOp.getWarpSize(),
73835c19fddSMatthias Springer                            warpOp.getLaneid(), delinearizedIds)) {
73935c19fddSMatthias Springer       return rewriter.notifyMatchFailure(
74035c19fddSMatthias Springer           read, "cannot delinearize lane ID for distribution");
74135c19fddSMatthias Springer     }
74235c19fddSMatthias Springer     assert(!delinearizedIds.empty() || map.getNumResults() == 0);
74335c19fddSMatthias Springer 
74435c19fddSMatthias Springer     // Distribute indices and the mask (if present).
74576cf33daSThomas Raoux     OpBuilder::InsertionGuard g(rewriter);
74635c19fddSMatthias Springer     SmallVector<Value> additionalResults(indices.begin(), indices.end());
74735c19fddSMatthias Springer     SmallVector<Type> additionalResultTypes(indices.size(),
74835c19fddSMatthias Springer                                             rewriter.getIndexType());
74935c19fddSMatthias Springer     additionalResults.push_back(read.getPadding());
75035c19fddSMatthias Springer     additionalResultTypes.push_back(read.getPadding().getType());
75135c19fddSMatthias Springer 
752aa2376a0SQuinn Dawkins     bool hasMask = false;
753771f5759SQuinn Dawkins     if (read.getMask()) {
754aa2376a0SQuinn Dawkins       hasMask = true;
755771f5759SQuinn Dawkins       // TODO: Distribution of masked reads with non-trivial permutation maps
756771f5759SQuinn Dawkins       // requires the distribution of the mask to elementwise match the
757771f5759SQuinn Dawkins       // distribution of the permuted written vector. Currently the details
758771f5759SQuinn Dawkins       // of which lane is responsible for which element is captured strictly
759771f5759SQuinn Dawkins       // by shape information on the warp op, and thus requires materializing
760771f5759SQuinn Dawkins       // the permutation in IR.
761f385f6c9SQuinn Dawkins       if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
76235c19fddSMatthias Springer         return rewriter.notifyMatchFailure(
76335c19fddSMatthias Springer             read, "non-trivial permutation maps not supported");
764771f5759SQuinn Dawkins       VectorType maskType =
765771f5759SQuinn Dawkins           getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
76635c19fddSMatthias Springer       additionalResults.push_back(read.getMask());
76735c19fddSMatthias Springer       additionalResultTypes.push_back(maskType);
768771f5759SQuinn Dawkins     }
769771f5759SQuinn Dawkins 
77035c19fddSMatthias Springer     SmallVector<size_t> newRetIndices;
77135c19fddSMatthias Springer     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
77235c19fddSMatthias Springer         rewriter, warpOp, additionalResults, additionalResultTypes,
77335c19fddSMatthias Springer         newRetIndices);
77435c19fddSMatthias Springer     distributedVal = newWarpOp.getResult(operandIndex);
77535c19fddSMatthias Springer 
77635c19fddSMatthias Springer     // Distributed indices were appended first.
77735c19fddSMatthias Springer     SmallVector<Value> newIndices;
77835c19fddSMatthias Springer     for (int64_t i = 0, e = indices.size(); i < e; ++i)
77935c19fddSMatthias Springer       newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
78035c19fddSMatthias Springer 
781771f5759SQuinn Dawkins     rewriter.setInsertionPointAfter(newWarpOp);
782199442eaSLei Zhang     for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
78376cf33daSThomas Raoux       AffineExpr d0, d1;
78476cf33daSThomas Raoux       bindDims(read.getContext(), d0, d1);
7851609f1c2Slong.chen       auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
78676cf33daSThomas Raoux       if (!indexExpr)
78776cf33daSThomas Raoux         continue;
78876cf33daSThomas Raoux       unsigned indexPos = indexExpr.getPosition();
7891609f1c2Slong.chen       unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
79073ddc447SLei Zhang       int64_t scale = distributedType.getDimSize(vectorPos);
79135c19fddSMatthias Springer       newIndices[indexPos] = affine::makeComposedAffineApply(
7924c48f016SMatthias Springer           rewriter, read.getLoc(), d0 + scale * d1,
79335c19fddSMatthias Springer           {newIndices[indexPos], delinearizedIds[vectorPos]});
79476cf33daSThomas Raoux     }
79535c19fddSMatthias Springer 
79635c19fddSMatthias Springer     // Distributed padding value was appended right after the indices.
79735c19fddSMatthias Springer     Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
79835c19fddSMatthias Springer     // Distributed mask value was added at the end (if the op has a mask).
79935c19fddSMatthias Springer     Value newMask =
80035c19fddSMatthias Springer         hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
80135c19fddSMatthias Springer                 : Value();
802018d8ac9SQuentin Colombet     auto newRead = rewriter.create<vector::TransferReadOp>(
80335c19fddSMatthias Springer         read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
80435c19fddSMatthias Springer         read.getPermutationMapAttr(), newPadding, newMask,
80576cf33daSThomas Raoux         read.getInBoundsAttr());
806018d8ac9SQuentin Colombet 
8077ecc921dSMatthias Springer     rewriter.replaceAllUsesWith(distributedVal, newRead);
80876cf33daSThomas Raoux     return success();
80976cf33daSThomas Raoux   }
81076cf33daSThomas Raoux };
81176cf33daSThomas Raoux 
81276cf33daSThomas Raoux /// Remove any result that has no use along with the matching yieldOp operand.
81376cf33daSThomas Raoux // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
814*bc29fc93SPetr Kurapov struct WarpOpDeadResult : public WarpDistributionPattern {
815*bc29fc93SPetr Kurapov   using Base::Base;
81676cf33daSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
81776cf33daSThomas Raoux                                 PatternRewriter &rewriter) const override {
81820df17fdSNicolas Vasilache     SmallVector<Type> newResultTypes;
81920df17fdSNicolas Vasilache     newResultTypes.reserve(warpOp->getNumResults());
82020df17fdSNicolas Vasilache     SmallVector<Value> newYieldValues;
82120df17fdSNicolas Vasilache     newYieldValues.reserve(warpOp->getNumResults());
82220df17fdSNicolas Vasilache     DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
82320df17fdSNicolas Vasilache     DenseMap<OpResult, int64_t> dedupResultPositionMap;
824ecaf2c33SPetr Kurapov     auto yield = cast<gpu::YieldOp>(
82576cf33daSThomas Raoux         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
82620df17fdSNicolas Vasilache 
82720df17fdSNicolas Vasilache     // Some values may be yielded multiple times and correspond to multiple
82820df17fdSNicolas Vasilache     // results. Deduplicating occurs by taking each result with its matching
82920df17fdSNicolas Vasilache     // yielded value, and:
83020df17fdSNicolas Vasilache     //   1. recording the unique first position at which the value is yielded.
83120df17fdSNicolas Vasilache     //   2. recording for the result, the first position at which the dedup'ed
83220df17fdSNicolas Vasilache     //      value is yielded.
83320df17fdSNicolas Vasilache     //   3. skipping from the new result types / new yielded values any result
83420df17fdSNicolas Vasilache     //      that has no use or whose yielded value has already been seen.
83576cf33daSThomas Raoux     for (OpResult result : warpOp.getResults()) {
83620df17fdSNicolas Vasilache       Value yieldOperand = yield.getOperand(result.getResultNumber());
83720df17fdSNicolas Vasilache       auto it = dedupYieldOperandPositionMap.insert(
83820df17fdSNicolas Vasilache           std::make_pair(yieldOperand, newResultTypes.size()));
83920df17fdSNicolas Vasilache       dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
84020df17fdSNicolas Vasilache       if (result.use_empty() || !it.second)
84176cf33daSThomas Raoux         continue;
84220df17fdSNicolas Vasilache       newResultTypes.push_back(result.getType());
84320df17fdSNicolas Vasilache       newYieldValues.push_back(yieldOperand);
84476cf33daSThomas Raoux     }
84520df17fdSNicolas Vasilache     // No modification, exit early.
84620df17fdSNicolas Vasilache     if (yield.getNumOperands() == newYieldValues.size())
84776cf33daSThomas Raoux       return failure();
84820df17fdSNicolas Vasilache     // Move the body of the old warpOp to a new warpOp.
84976cf33daSThomas Raoux     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
85020df17fdSNicolas Vasilache         rewriter, warpOp, newYieldValues, newResultTypes);
8517360d5d3SQuinn Dawkins 
8527360d5d3SQuinn Dawkins     // Simplify the new warp op after dropping dead results.
8537360d5d3SQuinn Dawkins     newWarpOp.getBody()->walk([&](Operation *op) {
8547360d5d3SQuinn Dawkins       if (isOpTriviallyDead(op))
8557360d5d3SQuinn Dawkins         rewriter.eraseOp(op);
8567360d5d3SQuinn Dawkins     });
8577360d5d3SQuinn Dawkins 
85820df17fdSNicolas Vasilache     // Replace results of the old warpOp by the new, deduplicated results.
85920df17fdSNicolas Vasilache     SmallVector<Value> newValues;
86020df17fdSNicolas Vasilache     newValues.reserve(warpOp->getNumResults());
86176cf33daSThomas Raoux     for (OpResult result : warpOp.getResults()) {
86276cf33daSThomas Raoux       if (result.use_empty())
86320df17fdSNicolas Vasilache         newValues.push_back(Value());
86420df17fdSNicolas Vasilache       else
86520df17fdSNicolas Vasilache         newValues.push_back(
86620df17fdSNicolas Vasilache             newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
86776cf33daSThomas Raoux     }
86820df17fdSNicolas Vasilache     rewriter.replaceOp(warpOp, newValues);
86976cf33daSThomas Raoux     return success();
87076cf33daSThomas Raoux   }
87176cf33daSThomas Raoux };
87276cf33daSThomas Raoux 
87376cf33daSThomas Raoux // If an operand is directly yielded out of the region we can forward it
87476cf33daSThomas Raoux // directly and it doesn't need to go through the region.
875*bc29fc93SPetr Kurapov struct WarpOpForwardOperand : public WarpDistributionPattern {
876*bc29fc93SPetr Kurapov   using Base::Base;
87776cf33daSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
87876cf33daSThomas Raoux                                 PatternRewriter &rewriter) const override {
87976cf33daSThomas Raoux     SmallVector<Type> resultTypes;
88076cf33daSThomas Raoux     SmallVector<Value> yieldValues;
881ecaf2c33SPetr Kurapov     auto yield = cast<gpu::YieldOp>(
88276cf33daSThomas Raoux         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
88376cf33daSThomas Raoux     Value valForwarded;
88476cf33daSThomas Raoux     unsigned resultIndex;
88576cf33daSThomas Raoux     for (OpOperand &operand : yield->getOpOperands()) {
88676cf33daSThomas Raoux       Value result = warpOp.getResult(operand.getOperandNumber());
88776cf33daSThomas Raoux       if (result.use_empty())
88876cf33daSThomas Raoux         continue;
88976cf33daSThomas Raoux 
89076cf33daSThomas Raoux       // Assume all the values coming from above are uniform.
89176cf33daSThomas Raoux       if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
89276cf33daSThomas Raoux         if (result.getType() != operand.get().getType())
89376cf33daSThomas Raoux           continue;
89476cf33daSThomas Raoux         valForwarded = operand.get();
89576cf33daSThomas Raoux         resultIndex = operand.getOperandNumber();
89676cf33daSThomas Raoux         break;
89776cf33daSThomas Raoux       }
8985550c821STres Popp       auto arg = dyn_cast<BlockArgument>(operand.get());
89976cf33daSThomas Raoux       if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
90076cf33daSThomas Raoux         continue;
90176cf33daSThomas Raoux       Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
90276cf33daSThomas Raoux       if (result.getType() != warpOperand.getType())
90376cf33daSThomas Raoux         continue;
90476cf33daSThomas Raoux       valForwarded = warpOperand;
90576cf33daSThomas Raoux       resultIndex = operand.getOperandNumber();
90676cf33daSThomas Raoux       break;
90776cf33daSThomas Raoux     }
90876cf33daSThomas Raoux     if (!valForwarded)
90976cf33daSThomas Raoux       return failure();
910aa2376a0SQuinn Dawkins     // Notify the rewriter that the warp op is changing (see the comment on
911aa2376a0SQuinn Dawkins     // the WarpOpTransferRead pattern).
9125fcf907bSMatthias Springer     rewriter.startOpModification(warpOp);
9137ecc921dSMatthias Springer     rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
9145fcf907bSMatthias Springer     rewriter.finalizeOpModification(warpOp);
91576cf33daSThomas Raoux     return success();
91676cf33daSThomas Raoux   }
91776cf33daSThomas Raoux };
91876cf33daSThomas Raoux 
919*bc29fc93SPetr Kurapov struct WarpOpBroadcast : public WarpDistributionPattern {
920*bc29fc93SPetr Kurapov   using Base::Base;
92176cf33daSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
92276cf33daSThomas Raoux                                 PatternRewriter &rewriter) const override {
923971b8525SJakub Kuderski     OpOperand *operand =
924971b8525SJakub Kuderski         getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
92576cf33daSThomas Raoux     if (!operand)
92676cf33daSThomas Raoux       return failure();
92776cf33daSThomas Raoux     unsigned int operandNumber = operand->getOperandNumber();
92876cf33daSThomas Raoux     auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
92976cf33daSThomas Raoux     Location loc = broadcastOp.getLoc();
93076cf33daSThomas Raoux     auto destVecType =
9315550c821STres Popp         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
9321dd00d39SQuentin Colombet     Value broadcastSrc = broadcastOp.getSource();
9331dd00d39SQuentin Colombet     Type broadcastSrcType = broadcastSrc.getType();
9341dd00d39SQuentin Colombet 
9351dd00d39SQuentin Colombet     // Check that the broadcast actually spans a set of values uniformly across
9361dd00d39SQuentin Colombet     // all threads. In other words, check that each thread can reconstruct
9371dd00d39SQuentin Colombet     // their own broadcast.
9381dd00d39SQuentin Colombet     // For that we simply check that the broadcast we want to build makes sense.
9391dd00d39SQuentin Colombet     if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
9401dd00d39SQuentin Colombet         vector::BroadcastableToResult::Success)
9411dd00d39SQuentin Colombet       return failure();
942d7d6443dSThomas Raoux     SmallVector<size_t> newRetIndices;
94376cf33daSThomas Raoux     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
9441dd00d39SQuentin Colombet         rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
94576cf33daSThomas Raoux     rewriter.setInsertionPointAfter(newWarpOp);
94676cf33daSThomas Raoux     Value broadcasted = rewriter.create<vector::BroadcastOp>(
947d7d6443dSThomas Raoux         loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
9487ecc921dSMatthias Springer     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
9497ecc921dSMatthias Springer                                 broadcasted);
95076cf33daSThomas Raoux     return success();
95176cf33daSThomas Raoux   }
95276cf33daSThomas Raoux };
95376cf33daSThomas Raoux 
95473ddc447SLei Zhang /// Pattern to move shape cast out of the warp op. shape cast is basically a
95573ddc447SLei Zhang /// no-op for warp distribution; we need to handle the shape though.
956*bc29fc93SPetr Kurapov struct WarpOpShapeCast : public WarpDistributionPattern {
957*bc29fc93SPetr Kurapov   using Base::Base;
95873ddc447SLei Zhang   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
95973ddc447SLei Zhang                                 PatternRewriter &rewriter) const override {
960971b8525SJakub Kuderski     OpOperand *operand =
961971b8525SJakub Kuderski         getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
96273ddc447SLei Zhang     if (!operand)
96373ddc447SLei Zhang       return failure();
964aa2376a0SQuinn Dawkins 
96573ddc447SLei Zhang     auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
96673ddc447SLei Zhang 
96773ddc447SLei Zhang     unsigned int operandNumber = operand->getOperandNumber();
96873ddc447SLei Zhang     auto castDistributedType =
96973ddc447SLei Zhang         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
97073ddc447SLei Zhang     VectorType castOriginalType = oldCastOp.getSourceVectorType();
97173ddc447SLei Zhang     VectorType castResultType = castDistributedType;
97273ddc447SLei Zhang 
97373ddc447SLei Zhang     // We expect the distributed type to have a smaller rank than the original
97473ddc447SLei Zhang     // type. Prepend with size-one dimensions to make them the same.
97573ddc447SLei Zhang     unsigned castDistributedRank = castDistributedType.getRank();
97673ddc447SLei Zhang     unsigned castOriginalRank = castOriginalType.getRank();
97773ddc447SLei Zhang     if (castDistributedRank < castOriginalRank) {
97873ddc447SLei Zhang       SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
97973ddc447SLei Zhang       llvm::append_range(shape, castDistributedType.getShape());
98073ddc447SLei Zhang       castDistributedType =
98173ddc447SLei Zhang           VectorType::get(shape, castDistributedType.getElementType());
98273ddc447SLei Zhang     }
98373ddc447SLei Zhang 
98473ddc447SLei Zhang     SmallVector<size_t> newRetIndices;
98573ddc447SLei Zhang     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
98673ddc447SLei Zhang         rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
98773ddc447SLei Zhang         newRetIndices);
98873ddc447SLei Zhang     rewriter.setInsertionPointAfter(newWarpOp);
98973ddc447SLei Zhang     Value newCast = rewriter.create<vector::ShapeCastOp>(
99073ddc447SLei Zhang         oldCastOp.getLoc(), castResultType,
99173ddc447SLei Zhang         newWarpOp->getResult(newRetIndices[0]));
99273ddc447SLei Zhang     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
99373ddc447SLei Zhang     return success();
99473ddc447SLei Zhang   }
99573ddc447SLei Zhang };
99673ddc447SLei Zhang 
997d4d28914SQuinn Dawkins /// Sink out vector.create_mask op feeding into a warp op yield.
998d4d28914SQuinn Dawkins /// ```
999d4d28914SQuinn Dawkins /// %0 = ...
1000ecaf2c33SPetr Kurapov /// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1001d4d28914SQuinn Dawkins ///   ...
1002d4d28914SQuinn Dawkins ///   %mask = vector.create_mask %0 : vector<32xi1>
1003ecaf2c33SPetr Kurapov ///   gpu.yield %mask : vector<32xi1>
1004d4d28914SQuinn Dawkins /// }
1005d4d28914SQuinn Dawkins /// ```
1006d4d28914SQuinn Dawkins /// To
1007d4d28914SQuinn Dawkins /// ```
1008d4d28914SQuinn Dawkins /// %0 = ...
1009ecaf2c33SPetr Kurapov /// gpu.warp_execute_on_lane_0(%arg0) {
1010d4d28914SQuinn Dawkins ///   ...
1011d4d28914SQuinn Dawkins /// }
1012d4d28914SQuinn Dawkins /// %cmp = arith.cmpi ult, %laneid, %0
1013d4d28914SQuinn Dawkins /// %ub = arith.select %cmp, %c0, %c1
1014d4d28914SQuinn Dawkins /// %1 = vector.create_mask %ub : vector<1xi1>
1015*bc29fc93SPetr Kurapov struct WarpOpCreateMask : public WarpDistributionPattern {
1016*bc29fc93SPetr Kurapov   using Base::Base;
1017d4d28914SQuinn Dawkins   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1018d4d28914SQuinn Dawkins                                 PatternRewriter &rewriter) const override {
1019971b8525SJakub Kuderski     OpOperand *yieldOperand =
1020971b8525SJakub Kuderski         getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
1021d4d28914SQuinn Dawkins     if (!yieldOperand)
1022d4d28914SQuinn Dawkins       return failure();
1023d4d28914SQuinn Dawkins 
1024d4d28914SQuinn Dawkins     auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
1025d4d28914SQuinn Dawkins 
1026d4d28914SQuinn Dawkins     // Early exit if any values needed for calculating the new mask indices
1027d4d28914SQuinn Dawkins     // are defined inside the warp op.
1028d4d28914SQuinn Dawkins     if (!llvm::all_of(mask->getOperands(), [&](Value value) {
1029d4d28914SQuinn Dawkins           return warpOp.isDefinedOutsideOfRegion(value);
1030d4d28914SQuinn Dawkins         }))
1031d4d28914SQuinn Dawkins       return failure();
1032d4d28914SQuinn Dawkins 
1033d4d28914SQuinn Dawkins     Location loc = mask.getLoc();
1034d4d28914SQuinn Dawkins     unsigned operandIndex = yieldOperand->getOperandNumber();
1035d4d28914SQuinn Dawkins 
1036d4d28914SQuinn Dawkins     auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1037d4d28914SQuinn Dawkins     VectorType seqType = mask.getVectorType();
1038d4d28914SQuinn Dawkins     ArrayRef<int64_t> seqShape = seqType.getShape();
1039d4d28914SQuinn Dawkins     ArrayRef<int64_t> distShape = distType.getShape();
1040d4d28914SQuinn Dawkins 
1041d4d28914SQuinn Dawkins     rewriter.setInsertionPointAfter(warpOp);
1042d4d28914SQuinn Dawkins 
1043d4d28914SQuinn Dawkins     // Delinearize the lane ID for constructing the distributed mask sizes.
1044d4d28914SQuinn Dawkins     SmallVector<Value> delinearizedIds;
1045d4d28914SQuinn Dawkins     if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1046d4d28914SQuinn Dawkins                            warpOp.getWarpSize(), warpOp.getLaneid(),
1047d4d28914SQuinn Dawkins                            delinearizedIds))
1048d4d28914SQuinn Dawkins       return rewriter.notifyMatchFailure(
1049d4d28914SQuinn Dawkins           mask, "cannot delinearize lane ID for distribution");
1050d4d28914SQuinn Dawkins     assert(!delinearizedIds.empty());
1051d4d28914SQuinn Dawkins 
1052aa2376a0SQuinn Dawkins     // Notify the rewriter that the warp op is changing (see the comment on
1053aa2376a0SQuinn Dawkins     // the WarpOpTransferRead pattern).
10545fcf907bSMatthias Springer     rewriter.startOpModification(warpOp);
1055aa2376a0SQuinn Dawkins 
1056d4d28914SQuinn Dawkins     AffineExpr s0, s1;
1057d4d28914SQuinn Dawkins     bindSymbols(rewriter.getContext(), s0, s1);
1058d4d28914SQuinn Dawkins     SmallVector<Value> newOperands;
1059d4d28914SQuinn Dawkins     for (int i = 0, e = distShape.size(); i < e; ++i) {
1060d4d28914SQuinn Dawkins       // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
1061d4d28914SQuinn Dawkins       // find the distance from the largest mask index owned by this lane to the
1062d4d28914SQuinn Dawkins       // original mask size. `vector.create_mask` implicitly clamps mask
1063d4d28914SQuinn Dawkins       // operands to the range [0, mask_vector_size[i]], or in other words, the
1064d4d28914SQuinn Dawkins       // mask sizes are always in the range [0, mask_vector_size[i]).
1065d4d28914SQuinn Dawkins       Value maskDimIdx = affine::makeComposedAffineApply(
1066d4d28914SQuinn Dawkins           rewriter, loc, s1 - s0 * distShape[i],
1067d4d28914SQuinn Dawkins           {delinearizedIds[i], mask.getOperand(i)});
1068d4d28914SQuinn Dawkins       newOperands.push_back(maskDimIdx);
1069d4d28914SQuinn Dawkins     }
1070d4d28914SQuinn Dawkins 
1071d4d28914SQuinn Dawkins     auto newMask =
1072d4d28914SQuinn Dawkins         rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
1073d4d28914SQuinn Dawkins     rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
10745fcf907bSMatthias Springer     rewriter.finalizeOpModification(warpOp);
1075d4d28914SQuinn Dawkins     return success();
1076d4d28914SQuinn Dawkins   }
1077d4d28914SQuinn Dawkins };
1078d4d28914SQuinn Dawkins 
1079f48ce52cSThomas Raoux /// Pattern to move out vector.extract of single element vector. Those don't
1080f48ce52cSThomas Raoux /// need to be distributed and can just be propagated outside of the region.
1081*bc29fc93SPetr Kurapov struct WarpOpExtract : public WarpDistributionPattern {
1082*bc29fc93SPetr Kurapov   using Base::Base;
1083f48ce52cSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1084f48ce52cSThomas Raoux                                 PatternRewriter &rewriter) const override {
1085971b8525SJakub Kuderski     OpOperand *operand =
1086971b8525SJakub Kuderski         getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1087f48ce52cSThomas Raoux     if (!operand)
1088f48ce52cSThomas Raoux       return failure();
1089f48ce52cSThomas Raoux     unsigned int operandNumber = operand->getOperandNumber();
1090f48ce52cSThomas Raoux     auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1091a1aad28dSLei Zhang     VectorType extractSrcType = extractOp.getSourceVectorType();
1092f48ce52cSThomas Raoux     Location loc = extractOp.getLoc();
10939085f00bSMatthias Springer 
10942f925d75SKunwar Grover     // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
10952f925d75SKunwar Grover     if (extractSrcType.getRank() <= 1) {
10969085f00bSMatthias Springer       return failure();
10979085f00bSMatthias Springer     }
10989085f00bSMatthias Springer 
10999085f00bSMatthias Springer     // All following cases are 2d or higher dimensional source vectors.
11009085f00bSMatthias Springer 
11019085f00bSMatthias Springer     if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
11029085f00bSMatthias Springer       // There is no distribution, this is a broadcast. Simply move the extract
11039085f00bSMatthias Springer       // out of the warp op.
11049085f00bSMatthias Springer       // TODO: This could be optimized. E.g., in case of a scalar result, let
11059085f00bSMatthias Springer       // one lane extract and shuffle the result to all other lanes (same as
11069085f00bSMatthias Springer       // the 1d case).
1107f48ce52cSThomas Raoux       SmallVector<size_t> newRetIndices;
1108f48ce52cSThomas Raoux       WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
11099085f00bSMatthias Springer           rewriter, warpOp, {extractOp.getVector()},
1110a1aad28dSLei Zhang           {extractOp.getSourceVectorType()}, newRetIndices);
11119085f00bSMatthias Springer       rewriter.setInsertionPointAfter(newWarpOp);
11129085f00bSMatthias Springer       Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
11139085f00bSMatthias Springer       // Extract from distributed vector.
11149085f00bSMatthias Springer       Value newExtract = rewriter.create<vector::ExtractOp>(
111598f6289aSDiego Caballero           loc, distributedVec, extractOp.getMixedPosition());
11167ecc921dSMatthias Springer       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
11177ecc921dSMatthias Springer                                   newExtract);
11189085f00bSMatthias Springer       return success();
11199085f00bSMatthias Springer     }
11209085f00bSMatthias Springer 
11219085f00bSMatthias Springer     // Find the distributed dimension. There should be exactly one.
11229085f00bSMatthias Springer     auto distributedType =
11235550c821STres Popp         cast<VectorType>(warpOp.getResult(operandNumber).getType());
11245550c821STres Popp     auto yieldedType = cast<VectorType>(operand->get().getType());
11259085f00bSMatthias Springer     int64_t distributedDim = -1;
11269085f00bSMatthias Springer     for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
11279085f00bSMatthias Springer       if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
11289085f00bSMatthias Springer         // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
11299085f00bSMatthias Springer         // support distributing multiple dimensions in the future.
11309085f00bSMatthias Springer         assert(distributedDim == -1 && "found multiple distributed dims");
11319085f00bSMatthias Springer         distributedDim = i;
11329085f00bSMatthias Springer       }
11339085f00bSMatthias Springer     }
11349085f00bSMatthias Springer     assert(distributedDim != -1 && "could not find distributed dimension");
113551ddfd76SKazu Hirata     (void)distributedDim;
11369085f00bSMatthias Springer 
11379085f00bSMatthias Springer     // Yield source vector from warp op.
11385262865aSKazu Hirata     SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
11399085f00bSMatthias Springer     for (int i = 0; i < distributedType.getRank(); ++i)
114098f6289aSDiego Caballero       newDistributedShape[i + extractOp.getNumIndices()] =
11419085f00bSMatthias Springer           distributedType.getDimSize(i);
11429085f00bSMatthias Springer     auto newDistributedType =
11439085f00bSMatthias Springer         VectorType::get(newDistributedShape, distributedType.getElementType());
11449085f00bSMatthias Springer     SmallVector<size_t> newRetIndices;
11459085f00bSMatthias Springer     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
11469085f00bSMatthias Springer         rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1147f48ce52cSThomas Raoux         newRetIndices);
1148f48ce52cSThomas Raoux     rewriter.setInsertionPointAfter(newWarpOp);
11499085f00bSMatthias Springer     Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
11509085f00bSMatthias Springer     // Extract from distributed vector.
1151f48ce52cSThomas Raoux     Value newExtract = rewriter.create<vector::ExtractOp>(
115298f6289aSDiego Caballero         loc, distributedVec, extractOp.getMixedPosition());
11537ecc921dSMatthias Springer     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
11547ecc921dSMatthias Springer                                 newExtract);
1155f48ce52cSThomas Raoux     return success();
1156f48ce52cSThomas Raoux   }
1157f48ce52cSThomas Raoux };
1158f48ce52cSThomas Raoux 
11592f925d75SKunwar Grover /// Pattern to move out vector.extract with a scalar result.
11602f925d75SKunwar Grover /// Only supports 1-D and 0-D sources for now.
1161*bc29fc93SPetr Kurapov struct WarpOpExtractScalar : public WarpDistributionPattern {
11622f925d75SKunwar Grover   WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
11639d51b4e4SMatthias Springer                       PatternBenefit b = 1)
1164*bc29fc93SPetr Kurapov       : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
11651757164eSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
11661757164eSThomas Raoux                                 PatternRewriter &rewriter) const override {
1167971b8525SJakub Kuderski     OpOperand *operand =
11682f925d75SKunwar Grover         getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
11691757164eSThomas Raoux     if (!operand)
11701757164eSThomas Raoux       return failure();
11711757164eSThomas Raoux     unsigned int operandNumber = operand->getOperandNumber();
11722f925d75SKunwar Grover     auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1173a1aad28dSLei Zhang     VectorType extractSrcType = extractOp.getSourceVectorType();
11742f925d75SKunwar Grover     // Only supports 1-D or 0-D sources for now.
11752f925d75SKunwar Grover     if (extractSrcType.getRank() > 1) {
11762f925d75SKunwar Grover       return rewriter.notifyMatchFailure(
11772f925d75SKunwar Grover           extractOp, "only 0-D or 1-D source supported for now");
11782f925d75SKunwar Grover     }
117935c19fddSMatthias Springer     // TODO: Supported shuffle types should be parameterizable, similar to
118035c19fddSMatthias Springer     // `WarpShuffleFromIdxFn`.
118135c19fddSMatthias Springer     if (!extractSrcType.getElementType().isF32() &&
118235c19fddSMatthias Springer         !extractSrcType.getElementType().isInteger(32))
118335c19fddSMatthias Springer       return rewriter.notifyMatchFailure(
118435c19fddSMatthias Springer           extractOp, "only f32/i32 element types are supported");
1185069d7d7eSThomas Raoux     bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
11869d51b4e4SMatthias Springer     Type elType = extractSrcType.getElementType();
11879d51b4e4SMatthias Springer     VectorType distributedVecType;
1188069d7d7eSThomas Raoux     if (!is0dOrVec1Extract) {
11899d51b4e4SMatthias Springer       assert(extractSrcType.getRank() == 1 &&
11902f925d75SKunwar Grover              "expected that extract src rank is 0 or 1");
1191069d7d7eSThomas Raoux       if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1192069d7d7eSThomas Raoux         return failure();
11939d51b4e4SMatthias Springer       int64_t elementsPerLane =
11949d51b4e4SMatthias Springer           extractSrcType.getShape()[0] / warpOp.getWarpSize();
11959d51b4e4SMatthias Springer       distributedVecType = VectorType::get({elementsPerLane}, elType);
11969d51b4e4SMatthias Springer     } else {
11979d51b4e4SMatthias Springer       distributedVecType = extractSrcType;
11989d51b4e4SMatthias Springer     }
1199ad100b36SMatthias Springer     // Yield source vector and position (if present) from warp op.
1200ad100b36SMatthias Springer     SmallVector<Value> additionalResults{extractOp.getVector()};
1201ad100b36SMatthias Springer     SmallVector<Type> additionalResultTypes{distributedVecType};
12022f925d75SKunwar Grover     additionalResults.append(
12032f925d75SKunwar Grover         SmallVector<Value>(extractOp.getDynamicPosition()));
12042f925d75SKunwar Grover     additionalResultTypes.append(
12052f925d75SKunwar Grover         SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
12062f925d75SKunwar Grover 
12071757164eSThomas Raoux     Location loc = extractOp.getLoc();
12081757164eSThomas Raoux     SmallVector<size_t> newRetIndices;
12091757164eSThomas Raoux     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1210ad100b36SMatthias Springer         rewriter, warpOp, additionalResults, additionalResultTypes,
12111757164eSThomas Raoux         newRetIndices);
12121757164eSThomas Raoux     rewriter.setInsertionPointAfter(newWarpOp);
12139d51b4e4SMatthias Springer     Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
12149d51b4e4SMatthias Springer 
12159d51b4e4SMatthias Springer     // 0d extract: The new warp op broadcasts the source vector to all lanes.
12169d51b4e4SMatthias Springer     // All lanes extract the scalar.
1217069d7d7eSThomas Raoux     if (is0dOrVec1Extract) {
1218069d7d7eSThomas Raoux       Value newExtract;
12192f925d75SKunwar Grover       SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
1220069d7d7eSThomas Raoux       newExtract =
12212f925d75SKunwar Grover           rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
12227ecc921dSMatthias Springer       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
12237ecc921dSMatthias Springer                                   newExtract);
12241757164eSThomas Raoux       return success();
12251757164eSThomas Raoux     }
12269d51b4e4SMatthias Springer 
12272f925d75SKunwar Grover     int64_t staticPos = extractOp.getStaticPosition()[0];
12282f925d75SKunwar Grover     OpFoldResult pos = ShapedType::isDynamic(staticPos)
12292f925d75SKunwar Grover                            ? (newWarpOp->getResult(newRetIndices[1]))
12302f925d75SKunwar Grover                            : OpFoldResult(rewriter.getIndexAttr(staticPos));
12319d51b4e4SMatthias Springer     // 1d extract: Distribute the source vector. One lane extracts and shuffles
12329d51b4e4SMatthias Springer     // the value to all other lanes.
12339d51b4e4SMatthias Springer     int64_t elementsPerLane = distributedVecType.getShape()[0];
12349d51b4e4SMatthias Springer     AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
12359d51b4e4SMatthias Springer     // tid of extracting thread: pos / elementsPerLane
12362f925d75SKunwar Grover     Value broadcastFromTid = affine::makeComposedAffineApply(
12372f925d75SKunwar Grover         rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
12389d51b4e4SMatthias Springer     // Extract at position: pos % elementsPerLane
12392f925d75SKunwar Grover     Value newPos =
124073ce971cSMatthias Springer         elementsPerLane == 1
124173ce971cSMatthias Springer             ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
12422f925d75SKunwar Grover             : affine::makeComposedAffineApply(rewriter, loc,
12432f925d75SKunwar Grover                                               sym0 % elementsPerLane, pos);
12449d51b4e4SMatthias Springer     Value extracted =
12452f925d75SKunwar Grover         rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);
12469d51b4e4SMatthias Springer 
12479d51b4e4SMatthias Springer     // Shuffle the extracted value to all lanes.
12489d51b4e4SMatthias Springer     Value shuffled = warpShuffleFromIdxFn(
12499d51b4e4SMatthias Springer         loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
12507ecc921dSMatthias Springer     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
12519d51b4e4SMatthias Springer     return success();
12529d51b4e4SMatthias Springer   }
12539d51b4e4SMatthias Springer 
12549d51b4e4SMatthias Springer private:
12559d51b4e4SMatthias Springer   WarpShuffleFromIdxFn warpShuffleFromIdxFn;
12561757164eSThomas Raoux };
12571757164eSThomas Raoux 
12582f925d75SKunwar Grover /// Pattern to convert vector.extractelement to vector.extract.
1259*bc29fc93SPetr Kurapov struct WarpOpExtractElement : public WarpDistributionPattern {
1260*bc29fc93SPetr Kurapov   using Base::Base;
12612f925d75SKunwar Grover   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
12622f925d75SKunwar Grover                                 PatternRewriter &rewriter) const override {
12632f925d75SKunwar Grover     OpOperand *operand =
12642f925d75SKunwar Grover         getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
12652f925d75SKunwar Grover     if (!operand)
12662f925d75SKunwar Grover       return failure();
12672f925d75SKunwar Grover     auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
12682f925d75SKunwar Grover     SmallVector<OpFoldResult> indices;
12692f925d75SKunwar Grover     if (auto pos = extractOp.getPosition()) {
12702f925d75SKunwar Grover       indices.push_back(pos);
12712f925d75SKunwar Grover     }
12722f925d75SKunwar Grover     rewriter.setInsertionPoint(extractOp);
12732f925d75SKunwar Grover     rewriter.replaceOpWithNewOp<vector::ExtractOp>(
12742f925d75SKunwar Grover         extractOp, extractOp.getVector(), indices);
12752f925d75SKunwar Grover     return success();
12762f925d75SKunwar Grover   }
12772f925d75SKunwar Grover };
12782f925d75SKunwar Grover 
12792f925d75SKunwar Grover /// Pattern to move out vector.insert with a scalar input.
12802f925d75SKunwar Grover /// Only supports 1-D and 0-D destinations for now.
1281*bc29fc93SPetr Kurapov struct WarpOpInsertScalar : public WarpDistributionPattern {
1282*bc29fc93SPetr Kurapov   using Base::Base;
128373ce971cSMatthias Springer   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
128473ce971cSMatthias Springer                                 PatternRewriter &rewriter) const override {
12852f925d75SKunwar Grover     OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
128673ce971cSMatthias Springer     if (!operand)
128773ce971cSMatthias Springer       return failure();
128873ce971cSMatthias Springer     unsigned int operandNumber = operand->getOperandNumber();
12892f925d75SKunwar Grover     auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
129073ce971cSMatthias Springer     VectorType vecType = insertOp.getDestVectorType();
129173ce971cSMatthias Springer     VectorType distrType =
12925550c821STres Popp         cast<VectorType>(warpOp.getResult(operandNumber).getType());
12932f925d75SKunwar Grover 
12942f925d75SKunwar Grover     // Only supports 1-D or 0-D destinations for now.
12952f925d75SKunwar Grover     if (vecType.getRank() > 1) {
12962f925d75SKunwar Grover       return rewriter.notifyMatchFailure(
12972f925d75SKunwar Grover           insertOp, "only 0-D or 1-D source supported for now");
12982f925d75SKunwar Grover     }
129973ce971cSMatthias Springer 
130073ce971cSMatthias Springer     // Yield destination vector, source scalar and position from warp op.
130173ce971cSMatthias Springer     SmallVector<Value> additionalResults{insertOp.getDest(),
130273ce971cSMatthias Springer                                          insertOp.getSource()};
130373ce971cSMatthias Springer     SmallVector<Type> additionalResultTypes{distrType,
130473ce971cSMatthias Springer                                             insertOp.getSource().getType()};
13052f925d75SKunwar Grover     additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
13062f925d75SKunwar Grover     additionalResultTypes.append(
13072f925d75SKunwar Grover         SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
13082f925d75SKunwar Grover 
130973ce971cSMatthias Springer     Location loc = insertOp.getLoc();
131073ce971cSMatthias Springer     SmallVector<size_t> newRetIndices;
131173ce971cSMatthias Springer     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
131273ce971cSMatthias Springer         rewriter, warpOp, additionalResults, additionalResultTypes,
131373ce971cSMatthias Springer         newRetIndices);
131473ce971cSMatthias Springer     rewriter.setInsertionPointAfter(newWarpOp);
131573ce971cSMatthias Springer     Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
131673ce971cSMatthias Springer     Value newSource = newWarpOp->getResult(newRetIndices[1]);
131773ce971cSMatthias Springer     rewriter.setInsertionPointAfter(newWarpOp);
131873ce971cSMatthias Springer 
13192f925d75SKunwar Grover     OpFoldResult pos;
13202f925d75SKunwar Grover     if (vecType.getRank() != 0) {
13212f925d75SKunwar Grover       int64_t staticPos = insertOp.getStaticPosition()[0];
13222f925d75SKunwar Grover       pos = ShapedType::isDynamic(staticPos)
13232f925d75SKunwar Grover                 ? (newWarpOp->getResult(newRetIndices[2]))
13242f925d75SKunwar Grover                 : OpFoldResult(rewriter.getIndexAttr(staticPos));
13252f925d75SKunwar Grover     }
13262f925d75SKunwar Grover 
13272f925d75SKunwar Grover     // This condition is always true for 0-d vectors.
132873ce971cSMatthias Springer     if (vecType == distrType) {
13292f925d75SKunwar Grover       Value newInsert;
13302f925d75SKunwar Grover       SmallVector<OpFoldResult> indices;
13312f925d75SKunwar Grover       if (pos) {
13322f925d75SKunwar Grover         indices.push_back(pos);
13332f925d75SKunwar Grover       }
13342f925d75SKunwar Grover       newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
13352f925d75SKunwar Grover                                                     distributedVec, indices);
13362f925d75SKunwar Grover       // Broadcast: Simply move the vector.insert op out.
13377ecc921dSMatthias Springer       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
13387ecc921dSMatthias Springer                                   newInsert);
133973ce971cSMatthias Springer       return success();
134073ce971cSMatthias Springer     }
134173ce971cSMatthias Springer 
134273ce971cSMatthias Springer     // This is a distribution. Only one lane should insert.
134373ce971cSMatthias Springer     int64_t elementsPerLane = distrType.getShape()[0];
134473ce971cSMatthias Springer     AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
134573ce971cSMatthias Springer     // tid of extracting thread: pos / elementsPerLane
13462f925d75SKunwar Grover     Value insertingLane = affine::makeComposedAffineApply(
13472f925d75SKunwar Grover         rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
134873ce971cSMatthias Springer     // Insert position: pos % elementsPerLane
13492f925d75SKunwar Grover     OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
13502f925d75SKunwar Grover         rewriter, loc, sym0 % elementsPerLane, pos);
135173ce971cSMatthias Springer     Value isInsertingLane = rewriter.create<arith::CmpIOp>(
135273ce971cSMatthias Springer         loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
135373ce971cSMatthias Springer     Value newResult =
135473ce971cSMatthias Springer         rewriter
135573ce971cSMatthias Springer             .create<scf::IfOp>(
13561125c5c0SFrederik Gossen                 loc, isInsertingLane,
135773ce971cSMatthias Springer                 /*thenBuilder=*/
135873ce971cSMatthias Springer                 [&](OpBuilder &builder, Location loc) {
13592f925d75SKunwar Grover                   Value newInsert = builder.create<vector::InsertOp>(
13602f925d75SKunwar Grover                       loc, newSource, distributedVec, newPos);
136173ce971cSMatthias Springer                   builder.create<scf::YieldOp>(loc, newInsert);
136273ce971cSMatthias Springer                 },
136373ce971cSMatthias Springer                 /*elseBuilder=*/
136473ce971cSMatthias Springer                 [&](OpBuilder &builder, Location loc) {
136573ce971cSMatthias Springer                   builder.create<scf::YieldOp>(loc, distributedVec);
136673ce971cSMatthias Springer                 })
136773ce971cSMatthias Springer             .getResult(0);
13687ecc921dSMatthias Springer     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
136973ce971cSMatthias Springer     return success();
137073ce971cSMatthias Springer   }
137173ce971cSMatthias Springer };
137273ce971cSMatthias Springer 
1373*bc29fc93SPetr Kurapov struct WarpOpInsert : public WarpDistributionPattern {
1374*bc29fc93SPetr Kurapov   using Base::Base;
13751523b729SMatthias Springer   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
13761523b729SMatthias Springer                                 PatternRewriter &rewriter) const override {
1377971b8525SJakub Kuderski     OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
13781523b729SMatthias Springer     if (!operand)
13791523b729SMatthias Springer       return failure();
13801523b729SMatthias Springer     unsigned int operandNumber = operand->getOperandNumber();
13811523b729SMatthias Springer     auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
13821523b729SMatthias Springer     Location loc = insertOp.getLoc();
13831523b729SMatthias Springer 
13842f925d75SKunwar Grover     // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
13852f925d75SKunwar Grover     if (insertOp.getDestVectorType().getRank() <= 1) {
13861523b729SMatthias Springer       return failure();
13871523b729SMatthias Springer     }
13881523b729SMatthias Springer 
13892f925d75SKunwar Grover     // All following cases are 2d or higher dimensional source vectors.
13902f925d75SKunwar Grover 
13911523b729SMatthias Springer     if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
13921523b729SMatthias Springer       // There is no distribution, this is a broadcast. Simply move the insert
13931523b729SMatthias Springer       // out of the warp op.
13941523b729SMatthias Springer       SmallVector<size_t> newRetIndices;
13951523b729SMatthias Springer       WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
13961523b729SMatthias Springer           rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
13971523b729SMatthias Springer           {insertOp.getSourceType(), insertOp.getDestVectorType()},
13981523b729SMatthias Springer           newRetIndices);
13991523b729SMatthias Springer       rewriter.setInsertionPointAfter(newWarpOp);
14001523b729SMatthias Springer       Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
14011523b729SMatthias Springer       Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
14021523b729SMatthias Springer       Value newResult = rewriter.create<vector::InsertOp>(
140398f6289aSDiego Caballero           loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
14047ecc921dSMatthias Springer       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
14057ecc921dSMatthias Springer                                   newResult);
14061523b729SMatthias Springer       return success();
14071523b729SMatthias Springer     }
14081523b729SMatthias Springer 
14091523b729SMatthias Springer     // Find the distributed dimension. There should be exactly one.
14101523b729SMatthias Springer     auto distrDestType =
14115550c821STres Popp         cast<VectorType>(warpOp.getResult(operandNumber).getType());
14125550c821STres Popp     auto yieldedType = cast<VectorType>(operand->get().getType());
14131523b729SMatthias Springer     int64_t distrDestDim = -1;
14141523b729SMatthias Springer     for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
14151523b729SMatthias Springer       if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
14161523b729SMatthias Springer         // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
14171523b729SMatthias Springer         // support distributing multiple dimensions in the future.
14181523b729SMatthias Springer         assert(distrDestDim == -1 && "found multiple distributed dims");
14191523b729SMatthias Springer         distrDestDim = i;
14201523b729SMatthias Springer       }
14211523b729SMatthias Springer     }
14221523b729SMatthias Springer     assert(distrDestDim != -1 && "could not find distributed dimension");
14231523b729SMatthias Springer 
14241523b729SMatthias Springer     // Compute the distributed source vector type.
14255550c821STres Popp     VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
14265262865aSKazu Hirata     SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
14271523b729SMatthias Springer     // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
14281523b729SMatthias Springer     // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
14291523b729SMatthias Springer     //         insert a smaller vector<3xf32>.
14301523b729SMatthias Springer     // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
14311523b729SMatthias Springer     //         case, one lane will insert the source vector<96xf32>. The other
14321523b729SMatthias Springer     //         lanes will not do anything.
143398f6289aSDiego Caballero     int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
14341523b729SMatthias Springer     if (distrSrcDim >= 0)
14351523b729SMatthias Springer       distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
14361523b729SMatthias Springer     auto distrSrcType =
14371523b729SMatthias Springer         VectorType::get(distrSrcShape, distrDestType.getElementType());
14381523b729SMatthias Springer 
14391523b729SMatthias Springer     // Yield source and dest vectors from warp op.
14401523b729SMatthias Springer     SmallVector<size_t> newRetIndices;
14411523b729SMatthias Springer     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
14421523b729SMatthias Springer         rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
14431523b729SMatthias Springer         {distrSrcType, distrDestType}, newRetIndices);
14441523b729SMatthias Springer     rewriter.setInsertionPointAfter(newWarpOp);
14451523b729SMatthias Springer     Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
14461523b729SMatthias Springer     Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
14471523b729SMatthias Springer 
14481523b729SMatthias Springer     // Insert into the distributed vector.
14491523b729SMatthias Springer     Value newResult;
14501523b729SMatthias Springer     if (distrSrcDim >= 0) {
14511523b729SMatthias Springer       // Every lane inserts a small piece.
14521523b729SMatthias Springer       newResult = rewriter.create<vector::InsertOp>(
145398f6289aSDiego Caballero           loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
14541523b729SMatthias Springer     } else {
14551523b729SMatthias Springer       // One lane inserts the entire source vector.
14561523b729SMatthias Springer       int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
145798f6289aSDiego Caballero       SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
145898f6289aSDiego Caballero       SmallVector<int64_t> newPos = getAsIntegers(pos);
14591523b729SMatthias Springer       // tid of inserting lane: pos / elementsPerLane
14601523b729SMatthias Springer       Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
14611523b729SMatthias Springer           loc, newPos[distrDestDim] / elementsPerLane);
14621523b729SMatthias Springer       Value isInsertingLane = rewriter.create<arith::CmpIOp>(
14631523b729SMatthias Springer           loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
14641523b729SMatthias Springer       // Insert position: pos % elementsPerLane
14651523b729SMatthias Springer       newPos[distrDestDim] %= elementsPerLane;
14661523b729SMatthias Springer       auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
14671523b729SMatthias Springer         Value newInsert = builder.create<vector::InsertOp>(
14681523b729SMatthias Springer             loc, distributedSrc, distributedDest, newPos);
14691523b729SMatthias Springer         builder.create<scf::YieldOp>(loc, newInsert);
14701523b729SMatthias Springer       };
14711523b729SMatthias Springer       auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
14721523b729SMatthias Springer         builder.create<scf::YieldOp>(loc, distributedDest);
14731523b729SMatthias Springer       };
14741523b729SMatthias Springer       newResult = rewriter
14751125c5c0SFrederik Gossen                       .create<scf::IfOp>(loc, isInsertingLane,
14761523b729SMatthias Springer                                          /*thenBuilder=*/insertingBuilder,
14771523b729SMatthias Springer                                          /*elseBuilder=*/nonInsertingBuilder)
14781523b729SMatthias Springer                       .getResult(0);
14791523b729SMatthias Springer     }
14801523b729SMatthias Springer 
14817ecc921dSMatthias Springer     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
14821523b729SMatthias Springer     return success();
14831523b729SMatthias Springer   }
14841523b729SMatthias Springer };
14851523b729SMatthias Springer 
1486*bc29fc93SPetr Kurapov struct WarpOpInsertElement : public WarpDistributionPattern {
1487*bc29fc93SPetr Kurapov   using Base::Base;
14882f925d75SKunwar Grover   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
14892f925d75SKunwar Grover                                 PatternRewriter &rewriter) const override {
14902f925d75SKunwar Grover     OpOperand *operand =
14912f925d75SKunwar Grover         getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
14922f925d75SKunwar Grover     if (!operand)
14932f925d75SKunwar Grover       return failure();
14942f925d75SKunwar Grover     auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
14952f925d75SKunwar Grover     SmallVector<OpFoldResult> indices;
14962f925d75SKunwar Grover     if (auto pos = insertOp.getPosition()) {
14972f925d75SKunwar Grover       indices.push_back(pos);
14982f925d75SKunwar Grover     }
14992f925d75SKunwar Grover     rewriter.setInsertionPoint(insertOp);
15002f925d75SKunwar Grover     rewriter.replaceOpWithNewOp<vector::InsertOp>(
15012f925d75SKunwar Grover         insertOp, insertOp.getSource(), insertOp.getDest(), indices);
15022f925d75SKunwar Grover     return success();
15032f925d75SKunwar Grover   }
15042f925d75SKunwar Grover };
15052f925d75SKunwar Grover 
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Values defined in the warp op and used inside
/// the loop body are additionally yielded from the outer warp op (vector
/// values distributed via `distributionMapFn`) and passed to the inner warp op
/// as extra arguments. Example:
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///   -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   gpu.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   gpu.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///   -> (vector<4xf32>) {
///     %iw = gpu.warp_execute_on_lane_0(%laneid)
///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       gpu.yield %ir : vector<128xf32>
///     }
///     scf.yield %iw : vector<4xf32>
///  }
/// ```
struct WarpOpScfForOp : public WarpDistributionPattern {

  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    // Collect Values that come from the warp op but are outside the forOp.
    // Those values need to be returned by the original warpOp and passed to
    // the new op.
    llvm::SmallSetVector<Value, 32> escapingValues;
    // Original (undistributed) types of the escaping values.
    SmallVector<Type> inputTypes;
    // Types the escaping values are yielded with: vectors are distributed with
    // `distributionMapFn`, any other type is kept as-is.
    SmallVector<Type> distTypes;
    mlir::visitUsedValuesDefinedAbove(
        forOp.getBodyRegion(), [&](OpOperand *operand) {
          Operation *parent = operand->get().getParentRegion()->getParentOp();
          if (warpOp->isAncestor(parent)) {
            // Each value is recorded once, even if it has several uses.
            if (!escapingValues.insert(operand->get()))
              return;
            Type distType = operand->get().getType();
            if (auto vecType = dyn_cast<VectorType>(distType)) {
              AffineMap map = distributionMapFn(operand->get());
              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
            }
            inputTypes.push_back(operand->get().getType());
            distTypes.push_back(distType);
          }
        });

    // A null type means some escaping vector got no distributed type
    // (presumably it cannot be distributed with its map). Bail out before
    // modifying any IR.
    if (llvm::is_contained(distTypes, Type{}))
      return failure();

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
        newRetIndices);
    yield = cast<gpu::YieldOp>(
        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp. The outer warp op now
    // yields the loop's init values instead; its corresponding results become
    // the init operands of the new loop created below.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = cast<OpResult>(yieldOperand.get());
      newOperands.push_back(
          newWarpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPointToStart(newForOp.getBody());

    // Inner warp op operands: the (distributed) loop iter args followed by the
    // escaping values yielded from the outer warp op. Its block arguments use
    // the original, undistributed types.
    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
                                 newForOp.getRegionIterArgs().end());
    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
                                    forOp.getResultTypes().end());
    // Maps each escaping value to the index of its inner warp block argument.
    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
      warpInput.push_back(newWarpOp.getResult(retIdx));
      argIndexMapping[escapingValues[i]] = warpInputType.size();
      warpInputType.push_back(inputTypes[i]);
    }
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
        newWarpOp.getWarpSize(), warpInput, warpInputType);

    // Replacements for the old loop body arguments: the induction variable,
    // then the inner warp block arguments. The resize drops the trailing
    // escaping-value arguments, which have no counterpart in the old loop
    // body; their uses are remapped by the walk below instead.
    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    argMapping.resize(forOp.getBody()->getNumArguments());
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    // Move the old loop body into the inner warp op, replacing its scf.yield
    // with a gpu.yield of the same values.
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPointToEnd(innerWarp.getBody());
    rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp result coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      // This also rewrites newForOp's own init operand (it was this warp
      // result) to the loop result, so restore it right after. The init args
      // follow the lower bound, upper bound and step operands, hence `+ 3`.
      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
                                  newForOp.getResult(res.index()));
      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
    }
    // Uses of escaping values moved into the inner warp op now refer to the
    // matching inner warp block arguments.
    newForOp.walk([&](Operation *op) {
      for (OpOperand &operand : op->getOpOperands()) {
        auto it = argIndexMapping.find(operand.get());
        if (it == argIndexMapping.end())
          continue;
        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
      }
    });

    // Finally, hoist out any now uniform code from the inner warp op.
    mlir::vector::moveScalarUniformCode(innerWarp);
    return success();
  }

private:
  // Provides the distribution map used for vector values escaping into the
  // loop body.
  DistributionMapFn distributionMapFn;
};
166076cf33daSThomas Raoux 
1661087aba4fSThomas Raoux /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
16622f925d75SKunwar Grover /// The vector is reduced in parallel. Currently limited to vector size
16632f925d75SKunwar Grover /// matching the warpOp size. E.g.:
1664087aba4fSThomas Raoux /// ```
1665ecaf2c33SPetr Kurapov /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
1666087aba4fSThomas Raoux ///   %0 = "some_def"() : () -> (vector<32xf32>)
1667087aba4fSThomas Raoux ///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
1668ecaf2c33SPetr Kurapov ///   gpu.yield %1 : f32
1669087aba4fSThomas Raoux /// }
1670087aba4fSThomas Raoux /// ```
1671087aba4fSThomas Raoux /// is lowered to:
1672087aba4fSThomas Raoux /// ```
1673ecaf2c33SPetr Kurapov /// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1674087aba4fSThomas Raoux ///   %1 = "some_def"() : () -> (vector<32xf32>)
1675ecaf2c33SPetr Kurapov ///   gpu.yield %1 : vector<32xf32>
1676087aba4fSThomas Raoux /// }
16779816edc9SCullen Rhodes /// %a = vector.extract %0[0] : f32 from vector<1xf32>
16786834803cSThomas Raoux /// %r = ("warp.reduction %a")
1679087aba4fSThomas Raoux /// ```
1680*bc29fc93SPetr Kurapov struct WarpOpReduction : public WarpDistributionPattern {
16816834803cSThomas Raoux   WarpOpReduction(MLIRContext *context,
16826834803cSThomas Raoux                   DistributedReductionFn distributedReductionFn,
16836834803cSThomas Raoux                   PatternBenefit benefit = 1)
1684*bc29fc93SPetr Kurapov       : WarpDistributionPattern(context, benefit),
168561f06774SMehdi Amini         distributedReductionFn(std::move(distributedReductionFn)) {}
1686087aba4fSThomas Raoux 
1687087aba4fSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1688087aba4fSThomas Raoux                                 PatternRewriter &rewriter) const override {
1689971b8525SJakub Kuderski     OpOperand *yieldOperand =
1690971b8525SJakub Kuderski         getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
1691087aba4fSThomas Raoux     if (!yieldOperand)
1692087aba4fSThomas Raoux       return failure();
1693087aba4fSThomas Raoux 
1694087aba4fSThomas Raoux     auto reductionOp =
1695087aba4fSThomas Raoux         cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
16965550c821STres Popp     auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1697087aba4fSThomas Raoux     // Only rank 1 vectors supported.
1698087aba4fSThomas Raoux     if (vectorType.getRank() != 1)
1699087aba4fSThomas Raoux       return rewriter.notifyMatchFailure(
1700087aba4fSThomas Raoux           warpOp, "Only rank 1 reductions can be distributed.");
1701087aba4fSThomas Raoux     // Only warp_size-sized vectors supported.
17020660f3c5SThomas Raoux     if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1703087aba4fSThomas Raoux       return rewriter.notifyMatchFailure(
1704087aba4fSThomas Raoux           warpOp, "Reduction vector dimension must match was size.");
1705f41abcdaSThomas Raoux     if (!reductionOp.getType().isIntOrFloat())
1706087aba4fSThomas Raoux       return rewriter.notifyMatchFailure(
1707f41abcdaSThomas Raoux           warpOp, "Reduction distribution currently only supports floats and "
1708f41abcdaSThomas Raoux                   "integer types.");
1709087aba4fSThomas Raoux 
17100660f3c5SThomas Raoux     int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1711087aba4fSThomas Raoux     // Return vector that will be reduced from the WarpExecuteOnLane0Op.
1712087aba4fSThomas Raoux     unsigned operandIndex = yieldOperand->getOperandNumber();
1713087aba4fSThomas Raoux     SmallVector<Value> yieldValues = {reductionOp.getVector()};
17140660f3c5SThomas Raoux     SmallVector<Type> retTypes = {
17150660f3c5SThomas Raoux         VectorType::get({numElements}, reductionOp.getType())};
1716ffa7384fSThomas Raoux     if (reductionOp.getAcc()) {
1717ffa7384fSThomas Raoux       yieldValues.push_back(reductionOp.getAcc());
1718ffa7384fSThomas Raoux       retTypes.push_back(reductionOp.getAcc().getType());
1719ffa7384fSThomas Raoux     }
1720d7d6443dSThomas Raoux     SmallVector<size_t> newRetIndices;
1721087aba4fSThomas Raoux     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1722d7d6443dSThomas Raoux         rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1723087aba4fSThomas Raoux     rewriter.setInsertionPointAfter(newWarpOp);
1724087aba4fSThomas Raoux 
1725d2061530Sstanley-nod     // Obtain data to reduce for a single lane.
1726d7d6443dSThomas Raoux     Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1727d2061530Sstanley-nod     // Distribute and reduce across threads.
17280660f3c5SThomas Raoux     Value fullReduce =
1729d2061530Sstanley-nod         distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
17306834803cSThomas Raoux                                reductionOp.getKind(), newWarpOp.getWarpSize());
1731ffa7384fSThomas Raoux     if (reductionOp.getAcc()) {
1732ffa7384fSThomas Raoux       fullReduce = vector::makeArithReduction(
1733ffa7384fSThomas Raoux           rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1734ffa7384fSThomas Raoux           newWarpOp.getResult(newRetIndices[1]));
1735ffa7384fSThomas Raoux     }
17367ecc921dSMatthias Springer     rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
1737087aba4fSThomas Raoux     return success();
1738087aba4fSThomas Raoux   }
17396834803cSThomas Raoux 
17406834803cSThomas Raoux private:
17416834803cSThomas Raoux   DistributedReductionFn distributedReductionFn;
1742087aba4fSThomas Raoux };
1743087aba4fSThomas Raoux 
1744d02f10d9SThomas Raoux } // namespace
1745d02f10d9SThomas Raoux 
1746d02f10d9SThomas Raoux void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
1747d02f10d9SThomas Raoux     RewritePatternSet &patterns,
174827cc31b6SNicolas Vasilache     const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
17494abb9e5dSThomas Raoux   patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1750d02f10d9SThomas Raoux }
1751ed0288f7SThomas Raoux 
1752ed0288f7SThomas Raoux void mlir::vector::populateDistributeTransferWriteOpPatterns(
175327cc31b6SNicolas Vasilache     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
175480636227SJakub Kuderski     unsigned maxNumElementsToExtract, PatternBenefit benefit) {
175527cc31b6SNicolas Vasilache   patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
175680636227SJakub Kuderski                                     maxNumElementsToExtract, benefit);
1757ed0288f7SThomas Raoux }
1758ed0288f7SThomas Raoux 
175976cf33daSThomas Raoux void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
176091f62f0eSThomas Raoux     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1761df49a97aSQuinn Dawkins     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
1762df49a97aSQuinn Dawkins     PatternBenefit readBenefit) {
1763df49a97aSQuinn Dawkins   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
17642f925d75SKunwar Grover   patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
17652f925d75SKunwar Grover                WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
17662f925d75SKunwar Grover                WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
17672f925d75SKunwar Grover                WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
1768df49a97aSQuinn Dawkins       patterns.getContext(), benefit);
17692f925d75SKunwar Grover   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
17702f925d75SKunwar Grover                                     benefit);
177191f62f0eSThomas Raoux   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
177291f62f0eSThomas Raoux                                benefit);
177376cf33daSThomas Raoux }
177476cf33daSThomas Raoux 
17756834803cSThomas Raoux void mlir::vector::populateDistributeReduction(
17766834803cSThomas Raoux     RewritePatternSet &patterns,
177727cc31b6SNicolas Vasilache     const DistributedReductionFn &distributedReductionFn,
177827cc31b6SNicolas Vasilache     PatternBenefit benefit) {
177927cc31b6SNicolas Vasilache   patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
178027cc31b6SNicolas Vasilache                                 benefit);
1781087aba4fSThomas Raoux }
1782087aba4fSThomas Raoux 
1783*bc29fc93SPetr Kurapov /// Helper to know if an op can be hoisted out of the region.
1784*bc29fc93SPetr Kurapov static bool canBeHoisted(Operation *op,
1785*bc29fc93SPetr Kurapov                          function_ref<bool(Value)> definedOutside) {
1786*bc29fc93SPetr Kurapov   return llvm::all_of(op->getOperands(), definedOutside) &&
1787*bc29fc93SPetr Kurapov          isMemoryEffectFree(op) && op->getNumRegions() == 0;
1788*bc29fc93SPetr Kurapov }
1789*bc29fc93SPetr Kurapov 
1790ed0288f7SThomas Raoux void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1791ed0288f7SThomas Raoux   Block *body = warpOp.getBody();
1792ed0288f7SThomas Raoux 
1793ed0288f7SThomas Raoux   // Keep track of the ops we want to hoist.
1794ed0288f7SThomas Raoux   llvm::SmallSetVector<Operation *, 8> opsToMove;
1795ed0288f7SThomas Raoux 
1796ed0288f7SThomas Raoux   // Helper to check if a value is or will be defined outside of the region.
1797ed0288f7SThomas Raoux   auto isDefinedOutsideOfBody = [&](Value value) {
1798ed0288f7SThomas Raoux     auto *definingOp = value.getDefiningOp();
1799ed0288f7SThomas Raoux     return (definingOp && opsToMove.count(definingOp)) ||
1800ed0288f7SThomas Raoux            warpOp.isDefinedOutsideOfRegion(value);
1801ed0288f7SThomas Raoux   };
1802ed0288f7SThomas Raoux 
1803ed0288f7SThomas Raoux   // Do not use walk here, as we do not want to go into nested regions and hoist
1804ed0288f7SThomas Raoux   // operations from there.
1805ed0288f7SThomas Raoux   for (auto &op : body->without_terminator()) {
1806ed0288f7SThomas Raoux     bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
18075550c821STres Popp       return isa<VectorType>(result.getType());
1808ed0288f7SThomas Raoux     });
1809ed0288f7SThomas Raoux     if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
1810ed0288f7SThomas Raoux       opsToMove.insert(&op);
1811ed0288f7SThomas Raoux   }
1812ed0288f7SThomas Raoux 
1813ed0288f7SThomas Raoux   // Move all the ops marked as uniform outside of the region.
1814ed0288f7SThomas Raoux   for (Operation *op : opsToMove)
1815ed0288f7SThomas Raoux     op->moveBefore(warpOp);
1816ed0288f7SThomas Raoux }
1817