//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <utility>

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::gpu;

/// Currently the distribution map is implicit based on the vector shape. In the
/// future it will be part of the op.
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   gpu.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type are different from the
  // dimensions of the distributed type: these are the distributed dimensions.
  // Then associate each distributed dimension with an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                            distributedType.getContext());
  return map;
}

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
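/// For example (an illustrative instance of this convention): a sequential
/// vector<32xf32> distributed as vector<1xf32> across 32 lanes transits
/// through a memref<32xf32>, with lane `laneId` accessing its slice at offset
/// `laneId * 1`, as computed by `buildDistributedOffset` below.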
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

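  /// Build the offset `laneId * distributedSize` along dimension `index` of
  /// the distributed vector type; this addresses the slice of the transit
  /// buffer owned by the current lane.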
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    //   vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
  ///
  /// When the requested type is the sequential (non-distributed) type, the
  /// load is not distributed, to account for the broadcast semantics of the
  /// `gpu.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  ///   %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
  ///     gpu.yield %cst : f32
  ///   }
  ///   // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // Other cases must be vector atm.
    // Vector case must use vector::TransferReadOp which will later lower to
    //   vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
///   - the IR within the scf.if op can be thought of as executing sequentially
///     (from the point of view of threads along `laneId`).
///   - the IR outside of the scf.if op can be thought of as executing in
///     parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
///   1. Create the scf.if op.
///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///      within the scf.if to transit the values captured from above.
///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///      consistent within the scf.if.
///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
///   5. Insert appropriate writes within scf.if and reads after the scf.if to
///      transit the values returned by the op.
///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///      consistent after the scf.if.
///   7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
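///
/// As an illustrative sketch (the alloc and synchronization shown as comments
/// are whatever `options.warpAllocationFn` and `options.warpSyncronizationFn`
/// produce):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %v = "some_def"() : () -> vector<32xf32>
///   gpu.yield %v : vector<32xf32>
/// }
/// ```
/// roughly becomes
/// ```
/// // %buffer = ... allocated by options.warpAllocationFn ...
/// %c0 = arith.constant 0 : index
/// %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
/// scf.if %is_lane_0 {
///   %v = "some_def"() : () -> vector<32xf32>
///   vector.transfer_write %v, %buffer[%c0] : vector<32xf32>, memref<32xf32>
/// }
/// // ... synchronization from options.warpSyncronizationFn ...
/// %r = vector.transfer_read %buffer[%laneid], %pad
///     : memref<32xf32>, vector<1xf32>
/// ```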
struct WarpOpToScfIfPattern : public WarpDistributionPattern {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load distributed value from buffer, after the ifOp.
      rewriter.setInsertionPointAfter(ifOp);
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
      //   gpu.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a number of dimensions equal
/// to the original type rank and should be a projection where the results are
/// the distributed dimensions. The number of results should be equal to the
/// number of warp sizes, which is currently limited to 1.
/// Example: a vector<16x32x64> distributed with the map (d0, d1, d2) -> (d1)
/// and a warp size of 16 distributes the second dimension (associated with
/// d1) and returns vector<16x2x64>.
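/// As a second worked example of the logic below (illustrative): a
/// vector<4x64> with the map (d0, d1) -> (d0, d1) and a warp size of 32 first
/// folds d0 into the warp (4 lanes, leaving 8), then distributes d1 over the
/// remaining 8 lanes, returning vector<1x8>.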
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0) {
        return VectorType();
      }
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1) {
    return VectorType();
  }
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract` elements
/// will not be extracted (this threshold should be smaller than the warp
/// size).
///
/// Example:
/// ```
/// gpu.warp_execute_on_lane_0(%id) {
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   gpu.yield
/// }
/// ```
/// To
/// ```
/// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   gpu.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public WarpDistributionPattern {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. Bail on 0-D writes; they are handled by the extraction path
    // (`tryExtractOp`) instead.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5 Compute the distributed type for the new mask.
    VectorType maskType;
    if (writeOp.getMask()) {
      // TODO: Distribution of masked writes with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector. To get the number of threads per vector dimension, divide the
    // sequential size by the distributed size along each dim.
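    // E.g. (illustrative): a sequential vector<4x64> distributed as
    // vector<1x8> yields per-dimension thread counts [4, 8]; a lane id in
    // [0, 32) then delinearizes to (id / 8, id % 8).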
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = rewriter
                         .create<mlir::affine::AffineDelinearizeIndexOp>(
                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
                             delinearizedIdSizes)
                         .getResults();
    } else {
      // If there is only one map result, we can elide the delinearization
      // op and use the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
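    // E.g. (illustrative): for a 1-D write of vector<32xf32> distributed as
    // vector<1xf32>, scale is 1 and the new index becomes
    // `%old_index + %laneid`, matching the `%A[%id]` access in the example
    // above.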
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of at most `maxNumElementsToExtract` elements
  /// into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
  /// execute op with the proper return type. The new write op is updated to
  /// write the result of the new warp execute op. The old `writeOp` is deleted.
  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                       WarpExecuteOnLane0Op warpOp,
                                       vector::TransferWriteOp writeOp,
                                       VectorType targetType,
                                       VectorType maybeMaskType) const {
    assert(writeOp->getParentOp() == warpOp &&
           "write must be nested immediately under warp");
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp;
    if (maybeMaskType) {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
          TypeRange{targetType, maybeMaskType}, newRetIndices);
    } else {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
          TypeRange{targetType}, newRetIndices);
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    if (maybeMaskType)
      newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
    return newWriteOp;
  }

  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   gpu.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///   vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};

/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   gpu.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read. Note that this pattern can
    // still fail at the end after distribution, in which case this might have
    // missed another distributable read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Distribute indices and the mask (if present).
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: Distribution of masked reads with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted read vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Distributed padding value was appended right after the indices.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    // Distributed mask value was added at the end (if the op has a mask).
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());

    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    //   1. recording the unique first position at which the value is yielded.
    //   2. recording for the result, the first position at which the dedup'ed
    //      value is yielded.
    //   3. skipping from the new result types / new yielded values any result
    //      that has no use or whose yielded value has already been seen.
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};

// If an operand is directly yielded out of the region, we can forward it to
// the corresponding result; it does not need to go through the region.
struct WarpOpForwardOperand : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

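/// Sink out a broadcast feeding into a warp op yield, e.g. (illustrative):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %b = vector.broadcast %src : f32 to vector<32xf32>
///   gpu.yield %b : vector<32xf32>
/// }
/// ```
/// becomes a warp op yielding %src, followed by
/// `vector.broadcast %src : f32 to vector<1xf32>` outside the region. This
/// only applies when every lane can rebuild its own broadcast, which is
/// checked via `vector::isBroadcastableTo` below.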
struct WarpOpBroadcast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct
    // their own broadcast.
    // For that we simply check that the broadcast we want to build makes sense.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};

/// Pattern to move a shape cast out of the warp op. A shape cast is
/// essentially a no-op for warp distribution; only the shapes need to be
/// adjusted.
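/// E.g. (illustrative): a `vector.shape_cast %v : vector<1x32xf32> to
/// vector<32xf32>` whose result is distributed as vector<1xf32> makes the
/// warp op yield the cast source as vector<1x1xf32> (the distributed shape
/// prepended with size-one dims), and the cast to vector<1xf32> is rebuilt
/// after the warp op.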
struct WarpOpShapeCast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // We expect the distributed type to have a smaller rank than the original
    // type. Prepend with size-one dimensions to make them the same.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};

/// Sink out vector.create_mask op feeding into a warp op yield.
/// ```
/// %0 = ...
/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %mask = vector.create_mask %0 : vector<32xi1>
///   gpu.yield %mask : vector<32xi1>
/// }
/// ```
/// To
/// ```
/// %0 = ...
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %cmp = arith.cmpi ult, %laneid, %0
/// %ub = arith.select %cmp, %c0, %c1
/// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
struct WarpOpCreateMask : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();

    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    rewriter.setInsertionPointAfter(warpOp);

    // Delinearize the lane ID for constructing the distributed mask sizes.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);

    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
      // find the distance from the largest mask index owned by this lane to the
      // original mask size. `vector.create_mask` implicitly clamps mask
      // operands to the range [0, mask_vector_size[i]], or in other words, the
      // mask sizes are always in the range [0, mask_vector_size[i]].
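      // E.g. (illustrative): with distShape[i] == 2 and a mask bound of 5,
      // lanes 0..3 compute 5, 3, 1, -1, which create_mask clamps to 2, 2, 1,
      // and 0 set elements, i.e. 5 in total, as in the sequential mask.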
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Pattern to move out vector.extract of single element vector. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtract : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // For 1-d or 0-d source cases, we rely on the WarpOpExtractScalar pattern.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the extract
      // out of the warp op.
      // TODO: This could be optimized. E.g., in case of a scalar result, let
      // one lane extract and shuffle the result to all other lanes (same as
      // the 1d case).
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from distributed vector.
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distributedDim == -1 && "found multiple distributed dims");
        distributedDim = i;
      }
    }
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield source vector from warp op.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};
1158 
1159 /// Pattern to move out vector.extract with a scalar result.
1160 /// Only supports 1-D and 0-D sources for now.
struct WarpOpExtractScalar : public WarpDistributionPattern {
  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                      PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    // Only supports 1-D or 0-D sources for now.
    if (extractSrcType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          extractOp, "only 0-D or 1-D source supported for now");
    }
    // TODO: Supported shuffle types should be parameterizable, similar to
    // `WarpShuffleFromIdxFn`.
    if (!extractSrcType.getElementType().isF32() &&
        !extractSrcType.getElementType().isInteger(32))
      return rewriter.notifyMatchFailure(
          extractOp, "only f32/i32 element types are supported");
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extract src rank is 0 or 1");
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
    // Yield source vector and position (if present) from warp op.
    SmallVector<Value> additionalResults{extractOp.getVector()};
    SmallVector<Type> additionalResultTypes{distributedVecType};
    additionalResults.append(
        SmallVector<Value>(extractOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));

    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0-d and single-element extracts: The new warp op broadcasts the source
    // vector to all lanes; every lane extracts the scalar.
    if (is0dOrVec1Extract) {
      Value newExtract;
      SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
      newExtract =
          rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    int64_t staticPos = extractOp.getStaticPosition()[0];
    OpFoldResult pos = ShapedType::isDynamic(staticPos)
                           ? (newWarpOp->getResult(newRetIndices[1]))
                           : OpFoldResult(rewriter.getIndexAttr(staticPos));
    // 1-d extract: Distribute the source vector. One lane extracts and
    // shuffles the value to all other lanes.
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of extracting thread: pos / elementsPerLane
    Value broadcastFromTid = affine::makeComposedAffineApply(
        rewriter, loc, sym0.floorDiv(elementsPerLane), pos);
    // Extract at position: pos % elementsPerLane
    Value newPos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : affine::makeComposedAffineApply(rewriter, loc,
                                              sym0 % elementsPerLane, pos);
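    // E.g., with elementsPerLane = 2 and pos = 9, lane 4 owns the element
    // and extracts it from its local vector at position 1.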
    Value extracted =
        rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);

    // Shuffle the extracted value to all lanes.
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

private:
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};

/// Pattern to convert vector.extractelement to vector.extract.
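/// For example (a sketch):
/// ```
/// %1 = vector.extractelement %0[%idx : index] : vector<16xf32>
/// ```
/// becomes
/// ```
/// %1 = vector.extract %0[%idx] : f32 from vector<16xf32>
/// ```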
struct WarpOpExtractElement : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
    if (!operand)
      return failure();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
    SmallVector<OpFoldResult> indices;
    if (auto pos = extractOp.getPosition()) {
      indices.push_back(pos);
    }
    rewriter.setInsertionPoint(extractOp);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
        extractOp, extractOp.getVector(), indices);
    return success();
  }
};

/// Pattern to move out vector.insert with a scalar input.
/// Only supports 1-D and 0-D destinations for now.
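/// For example (an illustrative sketch, assuming a warp size of 32, so each
/// lane owns 2 of the 64 elements and lane 4 owns element 9; %c4 is a
/// hypothetical constant holding that lane id):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
///   ...
///   %1 = vector.insert %f, %v [9] : f32 into vector<64xf32>
///   gpu.yield %1 : vector<64xf32>
/// }
/// ```
/// is distributed to:
/// ```
/// %w:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>, f32) {
///   ...
///   gpu.yield %v, %f : vector<64xf32>, f32
/// }
/// %cond = arith.cmpi eq, %laneid, %c4 : index
/// %r = scf.if %cond -> (vector<2xf32>) {
///   %i = vector.insert %w#1, %w#0 [1] : f32 into vector<2xf32>
///   scf.yield %i : vector<2xf32>
/// } else {
///   scf.yield %w#0 : vector<2xf32>
/// }
/// ```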
struct WarpOpInsertScalar : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());

    // Only supports 1-D or 0-D destinations for now.
    if (vecType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          insertOp, "only 0-D or 1-D destination supported for now");
    }

    // Yield destination vector, source scalar and position from warp op.
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getSource()};
    SmallVector<Type> additionalResultTypes{distrType,
                                            insertOp.getSource().getType()};
    additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));

    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);

    OpFoldResult pos;
    if (vecType.getRank() != 0) {
      int64_t staticPos = insertOp.getStaticPosition()[0];
      pos = ShapedType::isDynamic(staticPos)
                ? (newWarpOp->getResult(newRetIndices[2]))
                : OpFoldResult(rewriter.getIndexAttr(staticPos));
    }

    // This condition is always true for 0-d vectors.
    if (vecType == distrType) {
      Value newInsert;
      SmallVector<OpFoldResult> indices;
      if (pos) {
        indices.push_back(pos);
      }
      newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
                                                    distributedVec, indices);
      // Broadcast: Simply move the vector.insert op out.
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // This is a distribution. Only one lane should insert.
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of inserting lane: pos / elementsPerLane
    Value insertingLane = affine::makeComposedAffineApply(
        rewriter, loc, sym0.floorDiv(elementsPerLane), pos);
    // Insert position: pos % elementsPerLane
    OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
        rewriter, loc, sym0 % elementsPerLane, pos);
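    // E.g., with elementsPerLane = 2 and pos = 9, lane 4 inserts the scalar
    // into its local vector at position 1; all other lanes keep their local
    // vector unchanged.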
    Value isInsertingLane = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        rewriter
            .create<scf::IfOp>(
                loc, isInsertingLane,
                /*thenBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  Value newInsert = builder.create<vector::InsertOp>(
                      loc, newSource, distributedVec, newPos);
                  builder.create<scf::YieldOp>(loc, newInsert);
                },
                /*elseBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  builder.create<scf::YieldOp>(loc, distributedVec);
                })
            .getResult(0);
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

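/// Pattern to move out vector.insert with a 2-D or higher-dimensional
/// destination. Depending on which dimension is distributed, either every
/// lane inserts a smaller slice of the source, or a single lane (guarded by
/// scf.if) inserts the entire source vector; see the in-code example below.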
struct WarpOpInsert : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // For 1-d or 0-d destination cases, we rely on the WarpOpInsertScalar
    // pattern.
    if (insertOp.getDestVectorType().getRank() <= 1) {
      return failure();
    }

    // All following cases are 2-d or higher dimensional destination vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the insert
      // out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
          {insertOp.getSourceType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type.
    VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
    // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
    // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
    //         insert a smaller vector<3xf32>.
    // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
    //         case, one lane will insert the source vector<96xf32>. The other
    //         lanes will not do anything.
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield source and dest vectors from warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    // Insert into the distributed vector.
    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts a small piece.
      newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
    } else {
      // One lane inserts the entire source vector.
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
      SmallVector<int64_t> newPos = getAsIntegers(pos);
      // tid of inserting lane: pos / elementsPerLane
      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
          loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
      // Insert position: pos % elementsPerLane
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = builder.create<vector::InsertOp>(
            loc, distributedSrc, distributedDest, newPos);
        builder.create<scf::YieldOp>(loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        builder.create<scf::YieldOp>(loc, distributedDest);
      };
      newResult = rewriter
                      .create<scf::IfOp>(loc, isInsertingLane,
                                         /*thenBuilder=*/insertingBuilder,
                                         /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }

    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

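/// Pattern to convert vector.insertelement to vector.insert. For example (a
/// sketch):
/// ```
/// %1 = vector.insertelement %f, %0[%idx : index] : vector<16xf32>
/// ```
/// becomes
/// ```
/// %1 = vector.insert %f, %0 [%idx] : f32 into vector<16xf32>
/// ```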
struct WarpOpInsertElement : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
    if (!operand)
      return failure();
    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
    SmallVector<OpFoldResult> indices;
    if (auto pos = insertOp.getPosition()) {
      indices.push_back(pos);
    }
    rewriter.setInsertionPoint(insertOp);
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        insertOp, insertOp.getSource(), insertOp.getDest(), indices);
    return success();
  }
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for op after the
/// WarpExecuteOnLane0Op; the body of the new scf.for contains a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///   -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   gpu.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   gpu.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///   -> (vector<4xf32>) {
///     %iw = gpu.warp_execute_on_lane_0(%laneid)
///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       gpu.yield %ir : vector<128xf32>
///     }
///     scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public WarpDistributionPattern {

  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    // Collect values defined inside the warp op region and used by the forOp.
    // Those values need to be returned by the original warpOp and passed on
    // to the new op.
    llvm::SmallSetVector<Value, 32> escapingValues;
    SmallVector<Type> inputTypes;
    SmallVector<Type> distTypes;
    mlir::visitUsedValuesDefinedAbove(
        forOp.getBodyRegion(), [&](OpOperand *operand) {
          Operation *parent = operand->get().getParentRegion()->getParentOp();
          if (warpOp->isAncestor(parent)) {
            if (!escapingValues.insert(operand->get()))
              return;
            Type distType = operand->get().getType();
            if (auto vecType = dyn_cast<VectorType>(distType)) {
              AffineMap map = distributionMapFn(operand->get());
              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
            }
            inputTypes.push_back(operand->get().getType());
            distTypes.push_back(distType);
          }
        });

    if (llvm::is_contained(distTypes, Type{}))
      return failure();

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
        newRetIndices);
    yield = cast<gpu::YieldOp>(
        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = cast<OpResult>(yieldOperand.get());
      newOperands.push_back(
          newWarpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPointToStart(newForOp.getBody());

    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
                                 newForOp.getRegionIterArgs().end());
    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
                                    forOp.getResultTypes().end());
    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
      warpInput.push_back(newWarpOp.getResult(retIdx));
      argIndexMapping[escapingValues[i]] = warpInputType.size();
      warpInputType.push_back(inputTypes[i]);
    }
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
        newWarpOp.getWarpSize(), warpInput, warpInputType);

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    argMapping.resize(forOp.getBody()->getNumArguments());
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPointToEnd(innerWarp.getBody());
    rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp results coming from the original forOp.
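    // Note: `res.index() + 3` skips the three scf.for control operands
    // (lower bound, upper bound, step) to reach the corresponding init arg.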
    for (const auto &res : llvm::enumerate(resultIdx)) {
      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
                                  newForOp.getResult(res.index()));
      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
    }
    newForOp.walk([&](Operation *op) {
      for (OpOperand &operand : op->getOpOperands()) {
        auto it = argIndexMapping.find(operand.get());
        if (it == argIndexMapping.end())
          continue;
        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
      }
    });

    // Finally, hoist out any now uniform code from the inner warp op.
    mlir::vector::moveScalarUniformCode(innerWarp);
    return success();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector sizes that
/// are a multiple of the warp size. E.g.:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   gpu.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   gpu.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : f32 from vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public WarpDistributionPattern {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit),
        distributedReductionFn(std::move(distributedReductionFn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only vectors whose size is a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must be a multiple of the warp "
                  "size.");
    if (!reductionOp.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction distribution currently only supports floats and "
                  "integer types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
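    // E.g., for vector<64xf32> and a warp size of 32, each lane reduces a
    // local vector<2xf32> before the cross-lane reduction.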
    // Yield the vector that will be reduced from the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Obtain data to reduce for a single lane.
    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    // Distribute and reduce across threads.
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
                                    maxNumElementsToExtract, benefit);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
    PatternBenefit readBenefit) {
  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
      patterns.getContext(), benefit);
  patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                    benefit);
  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                               benefit);
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
                                benefit);
}

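// A minimal usage sketch (illustrative only; `funcOp`, `myDistributionMapFn`,
// `myWarpShuffleFromIdxFn`, and `myDistributedReductionFn` are placeholders a
// client would supply; the greedy driver is declared in
// mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populatePropagateWarpVectorDistributionPatterns(
//       patterns, myDistributionMapFn, myWarpShuffleFromIdxFn);
//   vector::populateDistributeReduction(patterns, myDistributedReductionFn);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
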
/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}

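// For example (a sketch, assuming %idx is defined outside the region):
//   gpu.warp_execute_on_lane_0(%laneid)[32] {
//     %c1 = arith.constant 1 : index
//     %i = arith.addi %idx, %c1 : index          // scalar, uniform: hoisted
//     %v = "some_def"() : () -> (vector<4xf32>)  // vector result: kept
//     ...
//   }
// The scalar index computations are moved before the warp op; ops producing
// vector results stay inside.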
void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return isa<VectorType>(result.getType());
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}