//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include <numeric>
#include <utility>

using namespace mlir;
using namespace mlir::vector;

/// Currently the distribution map is implicit based on the vector shape. In the
/// future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type are different from the
  // dimensions of the distributed type to know the distributed dimensions.
  // Then associate each distributed dimension to an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                            distributedType.getContext());
  return map;
}

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

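  /// Build the distributed offset `laneId * distributedSize` along dimension
  /// `index` of the distributed vector. As an illustrative sketch (the shapes
  /// are hypothetical, not taken from this file), distributing
  /// vector<64xf32> into vector<2xf32> over 32 lanes folds to:
  /// ```
  /// %off = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%laneid]
  /// ```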
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `vector.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
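  ///
  /// For instance (an illustrative sketch, not taken from this file), storing
  /// a distributed vector<4xf32> whose sequential type is vector<128xf32>
  /// over 32 lanes produces a lane-offset write:
  /// ```
  /// vector.transfer_write %val, %buffer[%off] {in_bounds = [true]}
  ///     : vector<4xf32>, memref<128xf32>
  /// ```
  /// where `%off` is `%laneid * 4` as built by `buildDistributedOffset`.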
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    //   vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `vector.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
  ///
  /// When the requested type is a scalar or matches the sequential type, the
  /// load is not distributed, to account for the broadcast semantics of the
  /// `vector.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  ///   %r = vector.warp_execute_on_lane_0(...) -> (f32) {
  ///     vector.yield %cst : f32
  ///   }
  ///   // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
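  ///
  /// For the distributed case (an illustrative sketch, not taken from this
  /// file), each lane reloads its own slice at a lane-dependent offset:
  /// ```
  /// %v = vector.transfer_read %buffer[%off], %pad {in_bounds = [true]}
  ///     : memref<128xf32>, vector<4xf32>
  /// ```
  /// where `%off` is `%laneid * 4` as built by `buildDistributedOffset`.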
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // Other cases must be vectors at the moment.
    // Vector case must use vector::TransferReadOp which will later lower to
    //   vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` returns the index of each new output.
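/// For example (an illustrative sketch), appending a yielded `%v :
/// vector<32xf32>` with distributed type vector<1xf32> to a warp op that
/// already returns `%w` gives:
/// ```
/// %r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
///     -> (vector<1xf32>, vector<1xf32>) {
///   ...
///   vector.yield %w, %v : vector<32xf32>, vector<32xf32>
/// }
/// ```
/// and `indices` receives {1}, the position of the newly appended result.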
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
                                              yield.getOperands().end());
  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
    if (yieldValues.insert(std::get<0>(newRet))) {
      types.push_back(std::get<1>(newRet));
      indices.push_back(yieldValues.size() - 1);
    } else {
      // If the value already exits the region, don't create a new output.
      for (auto [idx, yieldOperand] :
           llvm::enumerate(yieldValues.getArrayRef())) {
        if (yieldOperand == std::get<0>(newRet)) {
          indices.push_back(idx);
          break;
        }
      }
    }
  }
  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues.getArrayRef(), types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                const std::function<bool(Operation *)> &fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValues = yieldOperand.get();
    Operation *definedOp = yieldValues.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
///   - the IR within the scf.if op can be thought of as executing sequentially
///     (from the point of view of threads along `laneId`).
///   - the IR outside of the scf.if op can be thought of as executing in
///     parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
///   1. Create the scf.if op.
///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///      within the scf.if to transit the values captured from above.
///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///      consistent within the scf.if.
///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
///   5. Insert appropriate writes within scf.if and reads after the scf.if to
///      transit the values returned by the op.
///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///      consistent after the scf.if.
///   7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
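///
/// A hedged sketch of the rewrite (buffers, shapes, and the barrier are
/// illustrative; they actually come from `options.warpAllocationFn` and
/// `options.warpSyncronizationFn`):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32]
///     args(%v : vector<4xf32>) -> (vector<1xf32>) {
/// ^bb0(%arg: vector<128xf32>):
///   %y = ... : vector<32xf32>
///   vector.yield %y : vector<32xf32>
/// }
/// ```
/// becomes, roughly:
/// ```
/// vector.transfer_write %v, %buf0[%laneid * 4] : vector<4xf32>, ...
/// gpu.barrier
/// scf.if %laneid == %c0 {
///   %arg = vector.transfer_read %buf0[%c0], ... : ..., vector<128xf32>
///   %y = ... : vector<32xf32>
///   vector.transfer_write %y, %buf1[%c0] : vector<32xf32>, ...
/// }
/// gpu.barrier
/// %r = vector.transfer_read %buf1[%laneid], ... : ..., vector<1xf32>
/// ```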
struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);
      // Load distributed value from buffer, after the warpOp.
      rewriter.setInsertionPointAfter(ifOp);
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
/// op with the proper return type.
/// The new write op is updated to write the result of the new warp execute op.
/// The old `writeOp` is deleted.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                            WarpExecuteOnLane0Op warpOp,
                                            vector::TransferWriteOp writeOp,
                                            VectorType targetType,
                                            VectorType maybeMaskType) {
  assert(writeOp->getParentOp() == warpOp &&
         "write must be nested immediately under warp");
  OpBuilder::InsertionGuard g(rewriter);
  SmallVector<size_t> newRetIndices;
  WarpExecuteOnLane0Op newWarpOp;
  if (maybeMaskType) {
    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
        TypeRange{targetType, maybeMaskType}, newRetIndices);
  } else {
    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, ValueRange{{writeOp.getVector()}},
        TypeRange{targetType}, newRetIndices);
  }
  rewriter.setInsertionPointAfter(newWarpOp);
  auto newWriteOp =
      cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
  rewriter.eraseOp(writeOp);
  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
  if (maybeMaskType)
    newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
  return newWriteOp;
}

/// Return the distributed vector type based on the original type and the
/// distribution map. The map's number of dimensions must equal the original
/// type's rank, and it should be a projection where the results are the
/// distributed dimensions. The number of results must equal the number of
/// warp sizes, which is currently limited to 1.
/// Example: for vector<16x32x64>, the map (d0, d1, d2) -> (d1), and a warp
/// size of 16, the second dimension (associated with d1) is distributed,
/// yielding vector<16x2x64>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  if (map.getNumResults() != 1)
    return VectorType();
  SmallVector<int64_t> targetShape(originalType.getShape().begin(),
                                   originalType.getShape().end());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0)
      return VectorType();
    targetShape[position] = targetShape[position] / warpSize;
  }
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
        distributionMapFn(std::move(fn)) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. Bail on 0-D writes: distributing them is not supported for now.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5. Compute the distributed type for the new mask.
    VectorType maskType;
    if (writeOp.getMask()) {
      // TODO: Distribution of masked writes with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
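    // For each distributed dimension, rebase the write index by this lane's
    // offset; e.g. (illustrative) distributing vector<32xf32> into
    // vector<1xf32> over 32 lanes rewrites index %i to `%i + 1 * %laneid`.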
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1,
          {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only sink out vectors of 1 element for now so as not to serialize large
    // vector stores. This can later be controlled by the user.
    if (vecType.getNumElements() != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///   vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};

/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   vector.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); });
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startRootUpdate(warpOp);
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    rewriter.finalizeRootUpdate(warpOp);
    return success();
  }
};

/// Delinearize the given `laneId` into multiple dimensions, where each
/// dimension's size is determined by `originalShape` and `distributedShape`
/// together. This function expects the total number of threads needed for
/// distribution to equal `warpSize`. Returns true and updates
/// `delinearizedIds` if so.
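///
/// A hedged worked example (shapes are illustrative): originalShape =
/// [2, 64], distributedShape = [1, 2], warpSize = 64 gives sizes = [2, 32],
/// and the lane ID is split, innermost dimension first, as:
/// ```
/// delinearizedIds[1] = affine.apply affine_map<()[s0] -> (s0 mod 32)>()[%laneid]
/// delinearizedIds[0] = affine.apply affine_map<()[s0] -> (s0 floordiv 32)>()[%laneid]
/// ```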
bool delinearizeLaneId(OpBuilder &builder, Location loc,
                       ArrayRef<int64_t> originalShape,
                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
                       Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
  // If the original shape and the distributed shape are the same, we don't
  // distribute at all; every thread handles the whole vector. In that case,
  // we should not rely on lane IDs later, so just return an empty lane ID
  // vector.
  if (originalShape == distributedShape) {
    delinearizedIds.clear();
    return true;
  }

  SmallVector<int64_t> sizes;
  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
    if (large % small != 0)
      return false;
    sizes.push_back(large / small);
  }
  if (std::accumulate(sizes.begin(), sizes.end(), 1,
                      std::multiplies<int64_t>()) != warpSize)
    return false;

  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);

  int64_t usedThreads = 1;

  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  delinearizedIds.assign(sizes.size(), zero);

  for (int i = sizes.size() - 1; i >= 0; --i) {
    usedThreads *= sizes[i];
    if (usedThreads == warpSize) {
      // We've used up all available threads. No need to perform modulo
      // anymore, and we can stop calculating further dimensions.
      delinearizedIds[i] = laneId;
      break;
    }
    delinearizedIds[i] =
        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
    laneId = affine::makeComposedAffineApply(
        builder, loc, s0.floorDiv(usedThreads), {laneId});
  }
  return true;
}

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read. Note that this pattern can
    // still fail at the end after distribution, in which case this might have
    // missed another distributable read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Distribute the mask if present.
    OpBuilder::InsertionGuard g(rewriter);
    WarpExecuteOnLane0Op newWarpOp = warpOp;
    Value newMask = read.getMask();
    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: Distribution of masked reads with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted vector. Currently the details of which
      // lane is responsible for which element are captured strictly by shape
      // information on the warp op, and thus require materializing the
      // permutation in IR.
      if (!read.getPermutationMap().isMinorIdentity())
        return failure();
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      SmallVector<size_t> newRetIndices;
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{read.getMask()}, TypeRange{maskType},
          newRetIndices);
      newMask = newWarpOp.getResult(newRetIndices[0]);
      distributedVal = newWarpOp.getResult(operandIndex);
    } else {
      // This pattern does not actually change the warp op directly. Instead it
      // just rewrites a new transfer read (when not masked) outside of the warp
      // op and replaces the corresponding result. There are then follow-up
      // patterns to erase the now dead results of the warp op. This erasure
      // allows propagation to continue, but this pattern on its own never
      // actually tells the pattern rewriter that the warp op "changed." Notify
      // the rewriter here that the warp op is changing. Similar situations are
      // noted in following patterns.
      rewriter.startRootUpdate(warpOp);
    }

    rewriter.setInsertionPointAfter(newWarpOp);

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), newWarpOp.getWarpSize(),
                           newWarpOp.getLaneid(), delinearizedIds)) {
      if (!hasMask)
        rewriter.cancelRootUpdate(warpOp);
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

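    // Rebase the read indices by each lane's delinearized offset along the
    // distributed dimensions; e.g. (illustrative) distributing
    // vector<64xf32> into vector<2xf32> rewrites index %i to
    // `%i + 2 * %laneid`.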
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {indices[indexPos], delinearizedIds[vectorPos]});
    }
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), newMask,
        read.getInBoundsAttr());

    // Check that the produced operation is legal.
    // The transfer op may be reading from values that are defined within
    // warpOp's body, which is illegal.
    // We do the check late because indices may be changed by
    // makeComposedAffineApply. This rewrite may remove dependencies from
    // warpOp's body.
    // E.g., warpop {
    //   %idx = affine.apply...[%outsideDef]
    //   ... = transfer_read ...[%idx]
    // }
    // will be rewritten as:
    // warpop {
    // }
    // %new_idx = affine.apply...[%outsideDef]
    // ... = transfer_read ...[%new_idx]
    if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
          return (newRead.getMask() && value == newRead.getMask()) ||
                 newWarpOp.isDefinedOutsideOfRegion(value);
        })) {
      if (!hasMask)
        rewriter.cancelRootUpdate(warpOp);
      return failure();
    }

    rewriter.replaceAllUsesWith(distributedVal, newRead);
    if (!hasMask)
      rewriter.finalizeRootUpdate(warpOp);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this into WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    //   1. recording the unique first position at which the value is yielded.
    //   2. recording for the result, the first position at which the dedup'ed
    //      value is yielded.
    //   3. skipping from the new result types / new yielded values any result
    //      that has no use or whose yielded value has already been seen.
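    // E.g. (illustrative), for `vector.yield %a, %a, %b` where result #1 is
    // dead, the new op yields `%a, %b` and the old results map to the new
    // positions {0, 0, 1}.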
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};

// If an operand is directly yielded out of the region, we can forward it
// directly; it doesn't need to go through the region.
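// A sketch (illustrative types and names):
// ```
// %0 = vector.warp_execute_on_lane_0(%laneid)
//     args(%v : vector<4xf32>) -> (vector<4xf32>) {
// ^bb0(%arg: vector<4xf32>):
//   vector.yield %arg : vector<4xf32>
// }
// ```
// Here all uses of %0 can be replaced with %v directly.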
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startRootUpdate(warpOp);
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    rewriter.finalizeRootUpdate(warpOp);
    return success();
  }
};

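/// Sink out a broadcast op feeding into a warp op yield. A hedged sketch
/// (types are illustrative):
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xf32>) {
///   ...
///   %1 = vector.broadcast %f : f32 to vector<32xf32>
///   vector.yield %1 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%arg0)[32] -> (f32) {
///   ...
///   vector.yield %f : f32
/// }
/// %0 = vector.broadcast %r : f32 to vector<1xf32>
/// ```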
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct
    // its own broadcast.
    // For that we simply check that the broadcast we want to build makes sense.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};

/// Pattern to move a shape cast out of the warp op. A shape cast is basically
/// a no-op for warp distribution; we only need to handle the shape.
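/// A hedged sketch (types are illustrative):
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xf32>) {
///   %1 = vector.shape_cast %v : vector<1x32xf32> to vector<32xf32>
///   vector.yield %1 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<1x1xf32>) {
///   vector.yield %v : vector<1x32xf32>
/// }
/// %0 = vector.shape_cast %r : vector<1x1xf32> to vector<1xf32>
/// ```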
struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // We expect the distributed type to have a smaller rank than the original
    // type. Prepend with size-one dimensions to make them the same.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};

/// Sink out vector.create_mask op feeding into a warp op yield.
/// ```
/// %0 = ...
/// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %mask = vector.create_mask %0 : vector<32xi1>
///   vector.yield %mask : vector<32xi1>
/// }
/// ```
/// To
/// ```
/// %0 = ...
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %cmp = arith.cmpi ult, %laneid, %0
/// %ub = arith.select %cmp, %c0, %c1
/// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); });
    if (!yieldOperand)
      return failure();

    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    rewriter.setInsertionPointAfter(warpOp);

    // Delinearize the lane ID for constructing the distributed mask sizes.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startRootUpdate(warpOp);

    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
      // find the distance from the largest mask index owned by this lane to
      // the original mask size. `vector.create_mask` implicitly clamps mask
      // operands to the range [0, mask_vector_size[i]], i.e., the mask sizes
      // always end up in that inclusive range.
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    rewriter.finalizeRootUpdate(warpOp);
    return success();
  }
};

/// Pattern to move out a vector.extract of a single-element vector. Such
/// extracts don't need to be distributed and can just be propagated outside
/// of the region.
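///
/// For a 1-D source (a hedged sketch, types illustrative), the extract is
/// rewritten to vector.extractelement so the dedicated extractelement
/// distribution pattern can handle it:
/// ```
/// %1 = vector.extract %v[3] : f32 from vector<32xf32>
/// // becomes
/// %c3 = arith.constant 3 : index
/// %1 = vector.extractelement %v[%c3 : index] : vector<32xf32>
/// ```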
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
    assert(extractSrcType.getRank() > 0 &&
           "vector.extract does not support rank 0 sources");

    // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
    // canonicalized to %v.
    if (extractOp.getNumIndices() == 0)
      return failure();

    // Rewrite vector.extract with 1d source to vector.extractelement.
    if (extractSrcType.getRank() == 1) {
      if (extractOp.hasDynamicPosition())
        // TODO: Dynamic position not supported yet.
1221         return failure();
1222 
1223       assert(extractOp.getNumIndices() == 1 && "expected 1 index");
1224       int64_t pos = extractOp.getStaticPosition()[0];
1225       rewriter.setInsertionPoint(extractOp);
1226       rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
1227           extractOp, extractOp.getVector(),
1228           rewriter.create<arith::ConstantIndexOp>(loc, pos));
1229       return success();
1230     }
1231 
1232     // All following cases are 2d or higher dimensional source vectors.
1233 
1234     if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1235       // There is no distribution, this is a broadcast. Simply move the extract
1236       // out of the warp op.
1237       // TODO: This could be optimized. E.g., in case of a scalar result, let
1238       // one lane extract and shuffle the result to all other lanes (same as
1239       // the 1d case).
1240       SmallVector<size_t> newRetIndices;
1241       WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1242           rewriter, warpOp, {extractOp.getVector()},
1243           {extractOp.getSourceVectorType()}, newRetIndices);
1244       rewriter.setInsertionPointAfter(newWarpOp);
1245       Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1246       // Extract from distributed vector.
1247       Value newExtract = rewriter.create<vector::ExtractOp>(
1248           loc, distributedVec, extractOp.getMixedPosition());
1249       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1250                                   newExtract);
1251       return success();
1252     }
1253 
1254     // Find the distributed dimension. There should be exactly one.
1255     auto distributedType =
1256         cast<VectorType>(warpOp.getResult(operandNumber).getType());
1257     auto yieldedType = cast<VectorType>(operand->get().getType());
1258     int64_t distributedDim = -1;
1259     for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1260       if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
1261         // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1262         // support distributing multiple dimensions in the future.
1263         assert(distributedDim == -1 && "found multiple distributed dims");
1264         distributedDim = i;
1265       }
1266     }
1267     assert(distributedDim != -1 && "could not find distributed dimension");
1268     (void)distributedDim;
1269 
1270     // Yield source vector from warp op.
1271     SmallVector<int64_t> newDistributedShape(extractSrcType.getShape().begin(),
1272                                              extractSrcType.getShape().end());
1273     for (int i = 0; i < distributedType.getRank(); ++i)
1274       newDistributedShape[i + extractOp.getNumIndices()] =
1275           distributedType.getDimSize(i);
1276     auto newDistributedType =
1277         VectorType::get(newDistributedShape, distributedType.getElementType());
1278     SmallVector<size_t> newRetIndices;
1279     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1280         rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1281         newRetIndices);
1282     rewriter.setInsertionPointAfter(newWarpOp);
1283     Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1284     // Extract from distributed vector.
1285     Value newExtract = rewriter.create<vector::ExtractOp>(
1286         loc, distributedVec, extractOp.getMixedPosition());
1287     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1288                                 newExtract);
1289     return success();
1290   }
1291 };
1292 
/// Pattern to move a vector.extractelement out of a WarpExecuteOnLane0Op.
/// Extracts from 0-d and single-element 1-d vectors are broadcast to all
/// lanes and need no distribution. For larger 1-d vectors, the source is
/// distributed, the owning lane extracts the element, and a user-provided
/// warp shuffle function broadcasts the value to the other lanes.
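/// For example (a sketch; the warp size of 32, the `"some_def"` producer and
/// `%pos` being defined above the warp op are assumptions for illustration):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %v = "some_def"() : () -> (vector<64xf32>)
///   %e = vector.extractelement %v[%pos : index] : vector<64xf32>
///   vector.yield %e : f32
/// }
/// ```
/// becomes a warp op yielding the distributed vector<2xf32>; lane
/// `%pos floordiv 2` extracts at position `%pos mod 2` and the value is
/// broadcast to all lanes through the `warpShuffleFromIdxFn` callback.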
1295 struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1296   WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1297                        PatternBenefit b = 1)
1298       : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
1299         warpShuffleFromIdxFn(std::move(fn)) {}
1300   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1301                                 PatternRewriter &rewriter) const override {
1302     OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
1303       return isa<vector::ExtractElementOp>(op);
1304     });
1305     if (!operand)
1306       return failure();
1307     unsigned int operandNumber = operand->getOperandNumber();
1308     auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
1309     VectorType extractSrcType = extractOp.getSourceVectorType();
1310     bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1311     Type elType = extractSrcType.getElementType();
1312     VectorType distributedVecType;
1313     if (!is0dOrVec1Extract) {
1314       assert(extractSrcType.getRank() == 1 &&
1315              "expected that extractelement src rank is 0 or 1");
1316       if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1317         return failure();
1318       int64_t elementsPerLane =
1319           extractSrcType.getShape()[0] / warpOp.getWarpSize();
1320       distributedVecType = VectorType::get({elementsPerLane}, elType);
1321     } else {
1322       distributedVecType = extractSrcType;
1323     }
1324     // Yield source vector from warp op.
1325     Location loc = extractOp.getLoc();
1326     SmallVector<size_t> newRetIndices;
1327     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1328         rewriter, warpOp, {extractOp.getVector()}, {distributedVecType},
1329         newRetIndices);
1330     rewriter.setInsertionPointAfter(newWarpOp);
1331     Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1332 
    // 0-d or single-element 1-d extract: The new warp op broadcasts the
    // source vector to all lanes. All lanes extract the scalar.
1335     if (is0dOrVec1Extract) {
1336       Value newExtract;
1337       if (extractSrcType.getRank() == 1) {
1338         newExtract = rewriter.create<vector::ExtractElementOp>(
1339             loc, distributedVec,
1340             rewriter.create<arith::ConstantIndexOp>(loc, 0));
1341 
1342       } else {
1343         newExtract =
1344             rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
1345       }
1346       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1347                                   newExtract);
1348       return success();
1349     }
1350 
1351     // 1d extract: Distribute the source vector. One lane extracts and shuffles
1352     // the value to all other lanes.
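    // E.g., for vector<64xf32> over 32 lanes (elementsPerLane = 2), the
    // element at position 5 is owned by lane 5 floordiv 2 = 2 and sits at
    // sub-position 5 mod 2 = 1 within that lane's vector<2xf32>.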
1353     int64_t elementsPerLane = distributedVecType.getShape()[0];
1354     AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of extracting thread: pos floordiv elementsPerLane
    Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
        loc, sym0.floorDiv(elementsPerLane), extractOp.getPosition());
1358     // Extract at position: pos % elementsPerLane
1359     Value pos =
1360         elementsPerLane == 1
1361             ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1362             : rewriter
1363                   .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1364                                                  extractOp.getPosition())
1365                   .getResult();
1366     Value extracted =
1367         rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
1368 
1369     // Shuffle the extracted value to all lanes.
1370     Value shuffled = warpShuffleFromIdxFn(
1371         loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1372     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1373     return success();
1374   }
1375 
1376 private:
1377   WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1378 };
1379 
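/// Pattern to move a vector.insertelement out of a WarpExecuteOnLane0Op. In
/// the undistributed (broadcast) case the op is simply moved out of the
/// region; otherwise the destination vector is distributed and only the lane
/// that owns the insertion position performs the insert, guarded by an
/// scf.if.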
1380 struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1381   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
1382 
1383   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1384                                 PatternRewriter &rewriter) const override {
1385     OpOperand *operand = getWarpResult(
1386         warpOp, [](Operation *op) { return isa<vector::InsertElementOp>(op); });
1387     if (!operand)
1388       return failure();
1389     unsigned int operandNumber = operand->getOperandNumber();
1390     auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
1391     VectorType vecType = insertOp.getDestVectorType();
1392     VectorType distrType =
1393         cast<VectorType>(warpOp.getResult(operandNumber).getType());
1394     bool hasPos = static_cast<bool>(insertOp.getPosition());
1395 
1396     // Yield destination vector, source scalar and position from warp op.
1397     SmallVector<Value> additionalResults{insertOp.getDest(),
1398                                          insertOp.getSource()};
1399     SmallVector<Type> additionalResultTypes{distrType,
1400                                             insertOp.getSource().getType()};
1401     if (hasPos) {
1402       additionalResults.push_back(insertOp.getPosition());
1403       additionalResultTypes.push_back(insertOp.getPosition().getType());
1404     }
1405     Location loc = insertOp.getLoc();
1406     SmallVector<size_t> newRetIndices;
1407     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1408         rewriter, warpOp, additionalResults, additionalResultTypes,
1409         newRetIndices);
1410     rewriter.setInsertionPointAfter(newWarpOp);
1411     Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1412     Value newSource = newWarpOp->getResult(newRetIndices[1]);
1413     Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
1414     rewriter.setInsertionPointAfter(newWarpOp);
1415 
1416     if (vecType == distrType) {
      // Broadcast: Simply move the vector.insertelement op out.
1418       Value newInsert = rewriter.create<vector::InsertElementOp>(
1419           loc, newSource, distributedVec, newPos);
1420       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1421                                   newInsert);
1422       return success();
1423     }
1424 
1425     // This is a distribution. Only one lane should insert.
1426     int64_t elementsPerLane = distrType.getShape()[0];
1427     AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of inserting lane: pos floordiv elementsPerLane
    Value insertingLane = rewriter.create<affine::AffineApplyOp>(
        loc, sym0.floorDiv(elementsPerLane), newPos);
1431     // Insert position: pos % elementsPerLane
1432     Value pos =
1433         elementsPerLane == 1
1434             ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1435             : rewriter
1436                   .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1437                                                  newPos)
1438                   .getResult();
1439     Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1440         loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1441     Value newResult =
1442         rewriter
1443             .create<scf::IfOp>(
1444                 loc, isInsertingLane,
1445                 /*thenBuilder=*/
1446                 [&](OpBuilder &builder, Location loc) {
1447                   Value newInsert = builder.create<vector::InsertElementOp>(
1448                       loc, newSource, distributedVec, pos);
1449                   builder.create<scf::YieldOp>(loc, newInsert);
1450                 },
1451                 /*elseBuilder=*/
1452                 [&](OpBuilder &builder, Location loc) {
1453                   builder.create<scf::YieldOp>(loc, distributedVec);
1454                 })
1455             .getResult(0);
1456     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1457     return success();
1458   }
1459 };
1460 
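/// Pattern to move a vector.insert out of a WarpExecuteOnLane0Op. 1-d
/// destinations are rewritten to vector.insertelement. Otherwise, depending
/// on whether the distributed dimension lies inside the inserted slice,
/// either every lane inserts its piece of the distributed source, or a
/// single lane inserts the whole source vector under an scf.if (see the
/// worked example inside matchAndRewrite below).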
1461 struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
1462   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
1463 
1464   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1465                                 PatternRewriter &rewriter) const override {
1466     OpOperand *operand = getWarpResult(
1467         warpOp, [](Operation *op) { return isa<vector::InsertOp>(op); });
1468     if (!operand)
1469       return failure();
1470     unsigned int operandNumber = operand->getOperandNumber();
1471     auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1472     Location loc = insertOp.getLoc();
1473 
    // "vector.insert %src, %dst[] : ..." can be canonicalized to %src.
1475     if (insertOp.getNumIndices() == 0)
1476       return failure();
1477 
1478     // Rewrite vector.insert with 1d dest to vector.insertelement.
1479     if (insertOp.getDestVectorType().getRank() == 1) {
1480       if (insertOp.hasDynamicPosition())
        // TODO: Dynamic position not supported yet.
1482         return failure();
1483 
1484       assert(insertOp.getNumIndices() == 1 && "expected 1 index");
1485       int64_t pos = insertOp.getStaticPosition()[0];
1486       rewriter.setInsertionPoint(insertOp);
1487       rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
1488           insertOp, insertOp.getSource(), insertOp.getDest(),
1489           rewriter.create<arith::ConstantIndexOp>(loc, pos));
1490       return success();
1491     }
1492 
1493     if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1494       // There is no distribution, this is a broadcast. Simply move the insert
1495       // out of the warp op.
1496       SmallVector<size_t> newRetIndices;
1497       WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1498           rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1499           {insertOp.getSourceType(), insertOp.getDestVectorType()},
1500           newRetIndices);
1501       rewriter.setInsertionPointAfter(newWarpOp);
1502       Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1503       Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1504       Value newResult = rewriter.create<vector::InsertOp>(
1505           loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1506       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1507                                   newResult);
1508       return success();
1509     }
1510 
1511     // Find the distributed dimension. There should be exactly one.
1512     auto distrDestType =
1513         cast<VectorType>(warpOp.getResult(operandNumber).getType());
1514     auto yieldedType = cast<VectorType>(operand->get().getType());
1515     int64_t distrDestDim = -1;
1516     for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1517       if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1518         // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1519         // support distributing multiple dimensions in the future.
1520         assert(distrDestDim == -1 && "found multiple distributed dims");
1521         distrDestDim = i;
1522       }
1523     }
1524     assert(distrDestDim != -1 && "could not find distributed dimension");
1525 
1526     // Compute the distributed source vector type.
1527     VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
1528     SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(),
1529                                        srcVecType.getShape().end());
1530     // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1531     // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1532     //         insert a smaller vector<3xf32>.
1533     // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1534     //         case, one lane will insert the source vector<96xf32>. The other
1535     //         lanes will not do anything.
1536     int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1537     if (distrSrcDim >= 0)
1538       distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1539     auto distrSrcType =
1540         VectorType::get(distrSrcShape, distrDestType.getElementType());
1541 
1542     // Yield source and dest vectors from warp op.
1543     SmallVector<size_t> newRetIndices;
1544     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1545         rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1546         {distrSrcType, distrDestType}, newRetIndices);
1547     rewriter.setInsertionPointAfter(newWarpOp);
1548     Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1549     Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1550 
1551     // Insert into the distributed vector.
1552     Value newResult;
1553     if (distrSrcDim >= 0) {
1554       // Every lane inserts a small piece.
1555       newResult = rewriter.create<vector::InsertOp>(
1556           loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1557     } else {
1558       // One lane inserts the entire source vector.
1559       int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1560       SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1561       SmallVector<int64_t> newPos = getAsIntegers(pos);
1562       // tid of inserting lane: pos / elementsPerLane
1563       Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
1564           loc, newPos[distrDestDim] / elementsPerLane);
1565       Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1566           loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1567       // Insert position: pos % elementsPerLane
1568       newPos[distrDestDim] %= elementsPerLane;
1569       auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1570         Value newInsert = builder.create<vector::InsertOp>(
1571             loc, distributedSrc, distributedDest, newPos);
1572         builder.create<scf::YieldOp>(loc, newInsert);
1573       };
1574       auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1575         builder.create<scf::YieldOp>(loc, distributedDest);
1576       };
1577       newResult = rewriter
1578                       .create<scf::IfOp>(loc, isInsertingLane,
1579                                          /*thenBuilder=*/insertingBuilder,
1580                                          /*elseBuilder=*/nonInsertingBuilder)
1581                       .getResult(0);
1582     }
1583 
1584     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1585     return success();
1586   }
1587 };
1588 
1589 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1590 /// the scf.ForOp is the last operation in the region so that it doesn't change
1591 /// the order of execution. This creates a new scf.for region after the
1592 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1593 /// WarpExecuteOnLane0Op region. Example:
1594 /// ```
1595 /// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1596 ///   ...
1597 ///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1598 ///   -> (vector<128xf32>) {
1599 ///     ...
1600 ///     scf.yield %r : vector<128xf32>
1601 ///   }
1602 ///   vector.yield %v1 : vector<128xf32>
1603 /// }
1604 /// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
1611 ///   -> (vector<4xf32>) {
1612 ///     %iw = vector.warp_execute_on_lane_0(%laneid)
1613 ///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
1614 ///     ^bb0(%arg: vector<128xf32>):
1615 ///       ...
1616 ///       vector.yield %ir : vector<128xf32>
1617 ///     }
1618 ///     scf.yield %iw : vector<4xf32>
1619 ///  }
1620 /// ```
1621 struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
1622 
1623   WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1624       : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
1625         distributionMapFn(std::move(fn)) {}
1626   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
1627   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1628                                 PatternRewriter &rewriter) const override {
1629     auto yield = cast<vector::YieldOp>(
1630         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1631     // Only pick up forOp if it is the last op in the region.
1632     Operation *lastNode = yield->getPrevNode();
1633     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1634     if (!forOp)
1635       return failure();
    // Collect values used inside the forOp but defined outside of it (while
    // still inside the warp op region). Those values need to be returned by
    // the rewritten warpOp and passed on to the new forOp.
1639     llvm::SmallSetVector<Value, 32> escapingValues;
1640     SmallVector<Type> inputTypes;
1641     SmallVector<Type> distTypes;
1642     mlir::visitUsedValuesDefinedAbove(
1643         forOp.getBodyRegion(), [&](OpOperand *operand) {
1644           Operation *parent = operand->get().getParentRegion()->getParentOp();
1645           if (warpOp->isAncestor(parent)) {
1646             if (!escapingValues.insert(operand->get()))
1647               return;
1648             Type distType = operand->get().getType();
1649             if (auto vecType = dyn_cast<VectorType>(distType)) {
1650               AffineMap map = distributionMapFn(operand->get());
1651               distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1652             }
1653             inputTypes.push_back(operand->get().getType());
1654             distTypes.push_back(distType);
1655           }
1656         });
1657 
1658     SmallVector<size_t> newRetIndices;
1659     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1660         rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1661         newRetIndices);
1662     yield = cast<vector::YieldOp>(
1663         newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1664 
1665     SmallVector<Value> newOperands;
1666     SmallVector<unsigned> resultIdx;
1667     // Collect all the outputs coming from the forOp.
1668     for (OpOperand &yieldOperand : yield->getOpOperands()) {
1669       if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1670         continue;
1671       auto forResult = cast<OpResult>(yieldOperand.get());
1672       newOperands.push_back(
1673           newWarpOp.getResult(yieldOperand.getOperandNumber()));
1674       yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1675       resultIdx.push_back(yieldOperand.getOperandNumber());
1676     }
1677 
1678     OpBuilder::InsertionGuard g(rewriter);
1679     rewriter.setInsertionPointAfter(newWarpOp);
1680 
1681     // Create a new for op outside the region with a WarpExecuteOnLane0Op region
1682     // inside.
1683     auto newForOp = rewriter.create<scf::ForOp>(
1684         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1685         forOp.getStep(), newOperands);
1686     rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
1687 
1688     SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
1689                                  newForOp.getRegionIterArgs().end());
1690     SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
1691                                     forOp.getResultTypes().end());
1692     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1693     for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1694       warpInput.push_back(newWarpOp.getResult(retIdx));
1695       argIndexMapping[escapingValues[i]] = warpInputType.size();
1696       warpInputType.push_back(inputTypes[i]);
1697     }
1698     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
1699         newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1700         newWarpOp.getWarpSize(), warpInput, warpInputType);
1701 
1702     SmallVector<Value> argMapping;
1703     argMapping.push_back(newForOp.getInductionVar());
1704     for (Value args : innerWarp.getBody()->getArguments()) {
1705       argMapping.push_back(args);
1706     }
1707     argMapping.resize(forOp.getBody()->getNumArguments());
1708     SmallVector<Value> yieldOperands;
1709     for (Value operand : forOp.getBody()->getTerminator()->getOperands())
1710       yieldOperands.push_back(operand);
1711     rewriter.eraseOp(forOp.getBody()->getTerminator());
1712     rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1713     rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
1714     rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
1715     rewriter.setInsertionPointAfter(innerWarp);
1716     if (!innerWarp.getResults().empty())
1717       rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1718     rewriter.eraseOp(forOp);
1719     // Replace the warpOp result coming from the original ForOp.
1720     for (const auto &res : llvm::enumerate(resultIdx)) {
1721       rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
1722                                   newForOp.getResult(res.index()));
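      // Offset by 3 to skip the scf.for control operands (lower bound, upper
      // bound, step) and reach the corresponding iter_args operand.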
1723       newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1724     }
1725     newForOp.walk([&](Operation *op) {
1726       for (OpOperand &operand : op->getOpOperands()) {
1727         auto it = argIndexMapping.find(operand.get());
1728         if (it == argIndexMapping.end())
1729           continue;
1730         operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1731       }
1732     });
1733 
1734     // Finally, hoist out any now uniform code from the inner warp op.
1735     mlir::vector::moveScalarUniformCode(innerWarp);
1736     return success();
1737   }
1738 
1739 private:
1740   DistributionMapFn distributionMapFn;
1741 };
1742 
/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector sizes that
/// are a multiple of the warp size. E.g.:
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction <add>, %0 : vector<32xf32> into f32
///   vector.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : f32 from vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
1762 struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
1763   WarpOpReduction(MLIRContext *context,
1764                   DistributedReductionFn distributedReductionFn,
1765                   PatternBenefit benefit = 1)
1766       : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
1767         distributedReductionFn(std::move(distributedReductionFn)) {}
1768 
1769   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1770                                 PatternRewriter &rewriter) const override {
1771     OpOperand *yieldOperand = getWarpResult(
1772         warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
1773     if (!yieldOperand)
1774       return failure();
1775 
1776     auto reductionOp =
1777         cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
1778     auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1779     // Only rank 1 vectors supported.
1780     if (vectorType.getRank() != 1)
1781       return rewriter.notifyMatchFailure(
1782           warpOp, "Only rank 1 reductions can be distributed.");
    // Only vector sizes that are a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must be a multiple of the warp "
                  "size.");
1787     if (!reductionOp.getType().isIntOrFloat())
1788       return rewriter.notifyMatchFailure(
1789           warpOp, "Reduction distribution currently only supports floats and "
1790                   "integer types.");
1791 
1792     int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1793     // Return vector that will be reduced from the WarpExecuteOnLane0Op.
1794     unsigned operandIndex = yieldOperand->getOperandNumber();
1795     SmallVector<Value> yieldValues = {reductionOp.getVector()};
1796     SmallVector<Type> retTypes = {
1797         VectorType::get({numElements}, reductionOp.getType())};
1798     if (reductionOp.getAcc()) {
1799       yieldValues.push_back(reductionOp.getAcc());
1800       retTypes.push_back(reductionOp.getAcc().getType());
1801     }
1802     SmallVector<size_t> newRetIndices;
1803     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1804         rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1805     rewriter.setInsertionPointAfter(newWarpOp);
1806 
1807     // Obtain data to reduce for a single lane.
1808     Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1809     // Distribute and reduce across threads.
1810     Value fullReduce =
1811         distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1812                                reductionOp.getKind(), newWarpOp.getWarpSize());
1813     if (reductionOp.getAcc()) {
1814       fullReduce = vector::makeArithReduction(
1815           rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1816           newWarpOp.getResult(newRetIndices[1]));
1817     }
1818     rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
1819     return success();
1820   }
1821 
1822 private:
1823   DistributedReductionFn distributedReductionFn;
1824 };
1825 
1826 } // namespace
1827 
1828 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
1829     RewritePatternSet &patterns,
1830     const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
1831   patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1832 }
1833 
1834 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1835     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1836     PatternBenefit benefit) {
1837   patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
1838                                     benefit);
1839 }
1840 
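// A typical way to wire up these populate functions (a sketch; the callbacks
// and the greedy driver invocation are illustrative, not prescribed):
//
//   RewritePatternSet patterns(ctx);
//   vector::populatePropagateWarpVectorDistributionPatterns(
//       patterns, distributionMapFn, warpShuffleFromIdxFn);
//   vector::populateDistributeReduction(patterns, distributedReductionFn);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));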
1841 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1842     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1843     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
1844     PatternBenefit readBenefit) {
1845   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1846   patterns
1847       .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1848            WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
1849            WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
1850           patterns.getContext(), benefit);
1851   patterns.add<WarpOpExtractElement>(patterns.getContext(),
1852                                      warpShuffleFromIdxFn, benefit);
1853   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
1854                                benefit);
1855 }
1856 
1857 void mlir::vector::populateDistributeReduction(
1858     RewritePatternSet &patterns,
1859     const DistributedReductionFn &distributedReductionFn,
1860     PatternBenefit benefit) {
1861   patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
1862                                 benefit);
1863 }
1864 
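/// Hoist ops without vector results (e.g. scalar index arithmetic) out of the
/// warp op when all of their operands are defined outside the region or are
/// themselves hoisted; such ops compute the same value on every lane.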
1865 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1866   Block *body = warpOp.getBody();
1867 
1868   // Keep track of the ops we want to hoist.
1869   llvm::SmallSetVector<Operation *, 8> opsToMove;
1870 
1871   // Helper to check if a value is or will be defined outside of the region.
1872   auto isDefinedOutsideOfBody = [&](Value value) {
1873     auto *definingOp = value.getDefiningOp();
1874     return (definingOp && opsToMove.count(definingOp)) ||
1875            warpOp.isDefinedOutsideOfRegion(value);
1876   };
1877 
1878   // Do not use walk here, as we do not want to go into nested regions and hoist
1879   // operations from there.
1880   for (auto &op : body->without_terminator()) {
1881     bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
1882       return isa<VectorType>(result.getType());
1883     });
1884     if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
1885       opsToMove.insert(&op);
1886   }
1887 
1888   // Move all the ops marked as uniform outside of the region.
1889   for (Operation *op : opsToMove)
1890     op->moveBefore(warpOp);
1891 }
1892