xref: /llvm-project/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (revision 35ef3994bf738318b59ce640910fb1ccd3bb7dcb)
1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <optional>
10 #include <type_traits>
11 
12 #include "mlir/Analysis/SliceAnalysis.h"
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
23 #include "mlir/Dialect/SCF/IR/SCF.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
27 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
28 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
29 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Pass/PassManager.h"
32 #include "mlir/Support/LLVM.h"
33 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 using namespace mlir::vector;
38 
39 namespace {
40 
41 struct TestVectorToVectorLowering
42     : public PassWrapper<TestVectorToVectorLowering,
43                          OperationPass<func::FuncOp>> {
44   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
45 
46   TestVectorToVectorLowering() = default;
47   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
48       : PassWrapper(pass) {}
49   StringRef getArgument() const final {
50     return "test-vector-to-vector-lowering";
51   }
52   StringRef getDescription() const final {
53     return "Test lowering patterns between ops in the vector dialect";
54   }
55 
56   void getDependentDialects(DialectRegistry &registry) const override {
57     registry.insert<affine::AffineDialect>();
58     registry.insert<vector::VectorDialect>();
59   }
60 
61   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
62                       llvm::cl::init(false)};
63 
64   void runOnOperation() override {
65     auto *ctx = &getContext();
66     RewritePatternSet patterns(ctx);
67     if (unroll) {
68       populateVectorUnrollPatterns(
69           patterns,
70           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
71               filter));
72     }
73     populateVectorToVectorCanonicalizationPatterns(patterns);
74     populateBubbleVectorBitCastOpPatterns(patterns);
75     populateCastAwayVectorLeadingOneDimPatterns(patterns);
76     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
77   }
78 
79 private:
80   // Return the target shape based on op type.
81   static std::optional<SmallVector<int64_t>> getShape(Operation *op) {
82     if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
83       return SmallVector<int64_t>(2, 2);
84     if (isa<vector::ContractionOp>(op))
85       return SmallVector<int64_t>(3, 2);
86     // For transfer ops, just propagate the shape coming from
87     // InsertStridedSlices/ExtractStridedSlices.
88     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
89       VectorType dstVec;
90       for (Operation *users : readOp->getUsers()) {
91         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
92         if (!extract)
93           return std::nullopt;
94         auto vecType = cast<VectorType>(extract.getResult().getType());
95         if (dstVec && dstVec != vecType)
96           return std::nullopt;
97         dstVec = vecType;
98       }
99       return SmallVector<int64_t>(dstVec.getShape().begin(),
100                                   dstVec.getShape().end());
101     }
102     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
103       auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
104       if (!insert)
105         return std::nullopt;
106       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
107       return SmallVector<int64_t>(shape.begin(), shape.end());
108     }
109     return std::nullopt;
110   }
111 
112   static LogicalResult filter(Operation *op) {
113     return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
114                        ContractionOp, TransferReadOp, TransferWriteOp>(op));
115   }
116 };
117 
118 struct TestVectorContractionPrepareForMMTLowering
119     : public PassWrapper<TestVectorContractionPrepareForMMTLowering,
120                          OperationPass<func::FuncOp>> {
121   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
122       TestVectorContractionPrepareForMMTLowering)
123 
124   StringRef getArgument() const final {
125     return "test-vector-contraction-prepare-for-mmt-lowering";
126   }
127   StringRef getDescription() const final {
128     return "Test vector.contraction matmul canonicalization for MMT lowering.";
129   }
130   TestVectorContractionPrepareForMMTLowering() = default;
131 
132   void getDependentDialects(DialectRegistry &registry) const override {
133     registry.insert<affine::AffineDialect, arith::ArithDialect,
134                     vector::VectorDialect>();
135   }
136 
137   void runOnOperation() override {
138     MLIRContext *ctx = &getContext();
139     RewritePatternSet patterns(ctx);
140     vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
141     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
142   }
143 };
144 
/// Test pass exercising generic vector unrolling on elementwise, reduction,
/// transpose, and contraction ops, with configurable native shapes and
/// traversal order.
struct TestVectorUnrollingPatterns
    : public PassWrapper<TestVectorUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)

  StringRef getArgument() const final {
    return "test-vector-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll contract ops in the vector "
           "dialect";
  }
  TestVectorUnrollingPatterns() = default;
  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
      : PassWrapper(pass) {}
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    // Unroll elementwise addf/fma/multi_reduction ops to 2x2 tiles.
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<arith::AddFOp, vector::FMAOp,
                                           vector::MultiDimReductionOp>(op));
                      }));
    // Unroll 1-D vector.reduction ops to native size 2.
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::ReductionOp>(op));
                      }));
    // Unroll 4-D transposes to a fixed 1x3x4x2 native shape.
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::TransposeOp>(op));
                      }));

    if (unrollBasedOnType) {
      // Choose the contraction's native shape from its element type: every
      // iteration dimension unrolls to 4, except the last, which is 4 for
      // f16 LHS and 2 otherwise.
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
        SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(),
                                         4);
        Type lhsType = contractOp.getLhsType().getElementType();
        nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
        return nativeShape;
      };

      UnrollVectorOptions opts;
      opts.setNativeShapeFn(nativeShapeFn)
          .setFilterConstraint(
              [](Operation *op) { return success(isa<ContractionOp>(op)); });

      if (!unrollOrder.empty()) {
        // Honor the user-provided traversal order, but only when it matches
        // the contraction's iteration rank; otherwise fall back to default.
        opts.setUnrollTraversalOrderFn(
            [this](Operation *op) -> std::optional<SmallVector<int64_t>> {
              vector::ContractionOp contractOp =
                  cast<vector::ContractionOp>(op);
              if (contractOp.getIteratorTypes().size() == unrollOrder.size())
                return SmallVector<int64_t>(unrollOrder.begin(),
                                            unrollOrder.end());
              return std::nullopt;
            });
      }
      populateVectorUnrollPatterns(patterns, opts);
    } else {
      // Type-agnostic default: unroll every contraction dimension to 2.
      auto nativeShapeFn =
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        auto contractOp = dyn_cast<ContractionOp>(op);
        if (!contractOp)
          return std::nullopt;
        return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2);
      };
      populateVectorUnrollPatterns(patterns,
                                   UnrollVectorOptions()
                                       .setNativeShapeFn(nativeShapeFn)
                                       .setFilterConstraint([](Operation *op) {
                                         return success(isa<ContractionOp>(op));
                                       }));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  // Explicit traversal order for the unrolled tile loop nest (optional).
  ListOption<int64_t> unrollOrder{*this, "unroll-order",
                                  llvm::cl::desc("set the unroll order")};

  // When set, native shapes for contractions depend on the LHS element type.
  Option<bool> unrollBasedOnType{
      *this, "unroll-based-on-type",
      llvm::cl::desc("Set the unroll factor based on type of the operation"),
      llvm::cl::init(false)};
};
238 
239 struct TestVectorTransferUnrollingPatterns
240     : public PassWrapper<TestVectorTransferUnrollingPatterns,
241                          OperationPass<func::FuncOp>> {
242   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
243       TestVectorTransferUnrollingPatterns)
244 
245   TestVectorTransferUnrollingPatterns() = default;
246   TestVectorTransferUnrollingPatterns(
247       const TestVectorTransferUnrollingPatterns &pass)
248       : PassWrapper(pass) {}
249 
250   void getDependentDialects(DialectRegistry &registry) const override {
251     registry.insert<affine::AffineDialect>();
252   }
253   StringRef getArgument() const final {
254     return "test-vector-transfer-unrolling-patterns";
255   }
256   StringRef getDescription() const final {
257     return "Test lowering patterns to unroll transfer ops in the vector "
258            "dialect";
259   }
260   void runOnOperation() override {
261     MLIRContext *ctx = &getContext();
262     RewritePatternSet patterns(ctx);
263     UnrollVectorOptions opts;
264     opts.setNativeShape(ArrayRef<int64_t>{2, 2})
265         .setFilterConstraint([](Operation *op) {
266           return success(isa<vector::TransferReadOp, vector::TransferWriteOp,
267                              vector::GatherOp>(op));
268         });
269     if (reverseUnrollOrder.getValue()) {
270       opts.setUnrollTraversalOrderFn(
271           [](Operation *op) -> std::optional<SmallVector<int64_t>> {
272             int64_t numLoops = 0;
273             if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
274               numLoops = readOp.getVectorType().getRank();
275             else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
276               numLoops = writeOp.getVectorType().getRank();
277             else if (auto gatherOp = dyn_cast<vector::GatherOp>(op))
278               numLoops = gatherOp.getVectorType().getRank();
279             else
280               return std::nullopt;
281             auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
282             return llvm::to_vector(order);
283           });
284     }
285     populateVectorUnrollPatterns(patterns, opts);
286     populateVectorToVectorCanonicalizationPatterns(patterns);
287     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
288   }
289 
290   Option<bool> reverseUnrollOrder{
291       *this, "reverse-unroll-order",
292       llvm::cl::desc(
293           "reverse the order of unrolling of vector transfer operations"),
294       llvm::cl::init(false)};
295 };
296 
297 struct TestScalarVectorTransferLoweringPatterns
298     : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
299                          OperationPass<func::FuncOp>> {
300   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
301       TestScalarVectorTransferLoweringPatterns)
302 
303   TestScalarVectorTransferLoweringPatterns() = default;
304   TestScalarVectorTransferLoweringPatterns(
305       const TestScalarVectorTransferLoweringPatterns &pass)
306       : PassWrapper(pass) {}
307 
308   StringRef getArgument() const final {
309     return "test-scalar-vector-transfer-lowering";
310   }
311   StringRef getDescription() const final {
312     return "Test lowering of scalar vector transfers to memref loads/stores.";
313   }
314 
315   void getDependentDialects(DialectRegistry &registry) const override {
316     registry.insert<affine::AffineDialect, memref::MemRefDialect,
317                     tensor::TensorDialect, vector::VectorDialect>();
318   }
319 
320   Option<bool> allowMultipleUses{
321       *this, "allow-multiple-uses",
322       llvm::cl::desc("Fold transfer operations with multiple uses"),
323       llvm::cl::init(false)};
324 
325   void runOnOperation() override {
326     MLIRContext *ctx = &getContext();
327     RewritePatternSet patterns(ctx);
328     vector::populateScalarVectorTransferLoweringPatterns(
329         patterns, /*benefit=*/1, allowMultipleUses.getValue());
330     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
331   }
332 };
333 
334 struct TestVectorTransferOpt
335     : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
336   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
337 
338   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
339   StringRef getDescription() const final {
340     return "Test optimization transformations for transfer ops";
341   }
342   void runOnOperation() override {
343     IRRewriter rewriter(&getContext());
344     transferOpflowOpt(rewriter, getOperation());
345   }
346 };
347 
348 struct TestVectorTransferCollapseInnerMostContiguousDims
349     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
350                          OperationPass<func::FuncOp>> {
351   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
352       TestVectorTransferCollapseInnerMostContiguousDims)
353 
354   TestVectorTransferCollapseInnerMostContiguousDims() = default;
355   TestVectorTransferCollapseInnerMostContiguousDims(
356       const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
357 
358   void getDependentDialects(DialectRegistry &registry) const override {
359     registry.insert<memref::MemRefDialect, affine::AffineDialect>();
360   }
361 
362   StringRef getArgument() const final {
363     return "test-vector-transfer-collapse-inner-most-dims";
364   }
365 
366   StringRef getDescription() const final {
367     return "Test lowering patterns that reducedes the rank of the vector "
368            "transfer memory and vector operands.";
369   }
370 
371   void runOnOperation() override {
372     RewritePatternSet patterns(&getContext());
373     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
374     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
375   }
376 };
377 
378 struct TestSinkVectorBroadcast
379     : public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> {
380   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast)
381 
382   TestSinkVectorBroadcast() = default;
383   TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default;
384 
385   void getDependentDialects(DialectRegistry &registry) const override {
386     registry.insert<memref::MemRefDialect, affine::AffineDialect>();
387   }
388 
389   StringRef getArgument() const final { return "test-sink-vector-broadcast"; }
390 
391   StringRef getDescription() const final {
392     return "Test lowering patterns that eliminate redundant brodacast "
393            "operations.";
394   }
395 
396   void runOnOperation() override {
397     RewritePatternSet patterns(&getContext());
398     populateSinkVectorBroadcastPatterns(patterns);
399     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
400   }
401 };
402 
403 struct TestVectorReduceToContractPatternsPatterns
404     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
405                          OperationPass<func::FuncOp>> {
406   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
407       TestVectorReduceToContractPatternsPatterns)
408 
409   StringRef getArgument() const final {
410     return "test-vector-reduction-to-contract-patterns";
411   }
412   StringRef getDescription() const final {
413     return "Test patterns to convert multireduce op to contract and combine "
414            "broadcast/transpose to contract";
415   }
416   void runOnOperation() override {
417     RewritePatternSet patterns(&getContext());
418     populateVectorReductionToContractPatterns(patterns);
419     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
420   }
421 };
422 
423 struct TestVectorChainedReductionFoldingPatterns
424     : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
425                          OperationPass<func::FuncOp>> {
426   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
427       TestVectorChainedReductionFoldingPatterns)
428 
429   StringRef getArgument() const final {
430     return "test-vector-chained-reduction-folding-patterns";
431   }
432   StringRef getDescription() const final {
433     return "Test patterns to fold chained vector reductions";
434   }
435   void runOnOperation() override {
436     RewritePatternSet patterns(&getContext());
437     populateChainedVectorReductionFoldingPatterns(patterns);
438     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
439   }
440 };
441 
442 struct TestVectorBreakDownReductionPatterns
443     : public PassWrapper<TestVectorBreakDownReductionPatterns,
444                          OperationPass<func::FuncOp>> {
445   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
446       TestVectorBreakDownReductionPatterns)
447 
448   StringRef getArgument() const final {
449     return "test-vector-break-down-reduction-patterns";
450   }
451   StringRef getDescription() const final {
452     return "Test patterns to break down vector reductions into arith "
453            "reductions";
454   }
455   void runOnOperation() override {
456     RewritePatternSet patterns(&getContext());
457     populateBreakDownVectorReductionPatterns(patterns,
458                                              /*maxNumElementsToExtract=*/2);
459     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
460   }
461 };
462 
463 struct TestFlattenVectorTransferPatterns
464     : public PassWrapper<TestFlattenVectorTransferPatterns,
465                          OperationPass<func::FuncOp>> {
466   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
467       TestFlattenVectorTransferPatterns)
468 
469   StringRef getArgument() const final {
470     return "test-vector-transfer-flatten-patterns";
471   }
472   StringRef getDescription() const final {
473     return "Test patterns to rewrite contiguous row-major N-dimensional "
474            "vector.transfer_{read,write} ops into 1D transfers";
475   }
476   void getDependentDialects(DialectRegistry &registry) const override {
477     registry.insert<memref::MemRefDialect>();
478     registry.insert<affine::AffineDialect>();
479     registry.insert<vector::VectorDialect>();
480   }
481   void runOnOperation() override {
482     RewritePatternSet patterns(&getContext());
483     populateFlattenVectorTransferPatterns(patterns);
484     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
485   }
486 };
487 
488 struct TestVectorScanLowering
489     : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
490   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
491 
492   StringRef getArgument() const final { return "test-vector-scan-lowering"; }
493   StringRef getDescription() const final {
494     return "Test lowering patterns that lower the scan op in the vector "
495            "dialect";
496   }
497   void runOnOperation() override {
498     RewritePatternSet patterns(&getContext());
499     populateVectorScanLoweringPatterns(patterns);
500     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
501   }
502 };
503 
/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op. Creates (or reuses the name of) a private
/// memref.global at module scope and returns a memref.get_global to it.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
                                        WarpExecuteOnLane0Op warpOp,
                                        Type type) {
  // Numeric memory space of the allocation; presumably GPU shared/workgroup
  // memory (address space 3) -- confirm against the target lowering.
  static constexpr int64_t kSharedMemorySpace = 3;
  // Compute type of shared memory buffer: the vector's shape for vector
  // payloads, a 1-element memref for scalars.
  MemRefType memrefType;
  if (auto vectorType = dyn_cast<VectorType>(type)) {
    memrefType =
        MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
                        kSharedMemorySpace);
  } else {
    memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
  }

  // Get symbol table holding all shared memory globals.
  ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
  SymbolTable symbolTable(moduleOp);

  // Create a pretty name, e.g. "__shared_4x8xf32".
  SmallString<64> buf;
  llvm::raw_svector_ostream os(buf);
  interleave(memrefType.getShape(), os, "x");
  os << "x" << memrefType.getElementType();
  std::string symbolName = (Twine("__shared_") + os.str()).str();

  // The global must be created at module scope; save the current insertion
  // point so it can be restored before emitting the get_global below.
  auto ip = builder.saveInsertionPoint();
  builder.setInsertionPoint(moduleOp);
  auto global = builder.create<memref::GlobalOp>(
      loc,
      /*sym_name=*/symbolName,
      /*sym_visibility=*/builder.getStringAttr("private"),
      /*type=*/memrefType,
      /*initial_value=*/Attribute(),
      /*constant=*/false,
      /*alignment=*/IntegerAttr());
  symbolTable.insert(global);
  // The symbol table inserts at the end of the module, but globals are a bit
  // nicer if they are at the beginning.
  global->moveBefore(&moduleOp.front());

  builder.restoreInsertionPoint(ip);
  return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
}
549 
550 static Value warpReduction(Location loc, OpBuilder &builder, Value input,
551                            CombiningKind kind, uint32_t size) {
552   // First reduce on a single thread to get per lane reduction value.
553   Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
554   // Parallel reduction using butterfly shuffles.
555   for (uint64_t i = 1; i < size; i <<= 1) {
556     Value shuffled = builder
557                          .create<gpu::ShuffleOp>(loc, laneVal, i,
558                                                  /*width=*/size,
559                                                  /*mode=*/gpu::ShuffleMode::XOR)
560                          .getShuffleResult();
561     laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
562   }
563   return laneVal;
564 }
565 
566 struct TestVectorDistribution
567     : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
568   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
569 
570   void getDependentDialects(DialectRegistry &registry) const override {
571     registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
572                     affine::AffineDialect>();
573   }
574 
575   StringRef getArgument() const final { return "test-vector-warp-distribute"; }
576   StringRef getDescription() const final {
577     return "Test vector warp distribute transformation and lowering patterns";
578   }
579   TestVectorDistribution() = default;
580   TestVectorDistribution(const TestVectorDistribution &pass)
581       : PassWrapper(pass) {}
582 
583   Option<bool> warpOpToSCF{
584       *this, "rewrite-warp-ops-to-scf-if",
585       llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
586       llvm::cl::init(false)};
587 
588   Option<bool> distributeTransferWriteOps{
589       *this, "distribute-transfer-write",
590       llvm::cl::desc("Test distribution of transfer write"),
591       llvm::cl::init(false)};
592 
593   Option<unsigned> maxTransferWriteElements{
594       *this, "max-transfer-write-elements",
595       llvm::cl::desc("Maximum number of transfer write elements to distribute"),
596       llvm::cl::init(1)};
597 
598   Option<bool> hoistUniform{*this, "hoist-uniform",
599                             llvm::cl::desc("Test hoist uniform"),
600                             llvm::cl::init(false)};
601 
602   Option<bool> propagateDistribution{
603       *this, "propagate-distribution",
604       llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)};
605 
606   void runOnOperation() override {
607     RewritePatternSet patterns(&getContext());
608 
609     getOperation().walk([&](Operation *op) {
610       if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
611         if (hoistUniform) {
612           moveScalarUniformCode(warpOp);
613         }
614         WalkResult::interrupt();
615       }
616     });
617     MLIRContext *ctx = &getContext();
618     auto distributionFn = [](Value val) {
619       // Create a map (d0, d1) -> (d1) to distribute along the inner
620       // dimension. Once we support n-d distribution we can add more
621       // complex cases.
622       VectorType vecType = dyn_cast<VectorType>(val.getType());
623       int64_t vecRank = vecType ? vecType.getRank() : 0;
624       OpBuilder builder(val.getContext());
625       if (vecRank == 0)
626         return AffineMap::get(val.getContext());
627       return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
628     };
629     auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
630                         Value srcIdx, int64_t warpSz) {
631       assert((val.getType().isF32() || val.getType().isInteger(32)) &&
632              "unsupported shuffle type");
633       Type i32Type = builder.getIntegerType(32);
634       Value srcIdxI32 =
635           builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
636       Value warpSzI32 = builder.create<arith::ConstantOp>(
637           loc, builder.getIntegerAttr(i32Type, warpSz));
638       Value result = builder
639                          .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
640                                                  gpu::ShuffleMode::IDX)
641                          .getResult(0);
642       return result;
643     };
644     if (distributeTransferWriteOps && propagateDistribution) {
645       RewritePatternSet patterns(ctx);
646       vector::populatePropagateWarpVectorDistributionPatterns(
647           patterns, distributionFn, shuffleFn, /*benefit=*/1,
648           /*readBenefit=*/0);
649       vector::populateDistributeReduction(patterns, warpReduction, 1);
650       populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
651       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
652     } else if (distributeTransferWriteOps) {
653       RewritePatternSet patterns(ctx);
654       populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
655                                                 maxTransferWriteElements);
656       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
657     } else if (propagateDistribution) {
658       RewritePatternSet patterns(ctx);
659       vector::populatePropagateWarpVectorDistributionPatterns(
660           patterns, distributionFn, shuffleFn);
661       vector::populateDistributeReduction(patterns, warpReduction);
662       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
663     }
664     WarpExecuteOnLane0LoweringOptions options;
665     options.warpAllocationFn = allocateGlobalSharedMemory;
666     options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
667                                       WarpExecuteOnLane0Op warpOp) {
668       builder.create<gpu::BarrierOp>(loc);
669     };
670     // Test on one pattern in isolation.
671     if (warpOpToSCF) {
672       populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
673       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
674       return;
675     }
676   }
677 };
678 
679 struct TestVectorExtractStridedSliceLowering
680     : public PassWrapper<TestVectorExtractStridedSliceLowering,
681                          OperationPass<func::FuncOp>> {
682   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
683       TestVectorExtractStridedSliceLowering)
684 
685   StringRef getArgument() const final {
686     return "test-vector-extract-strided-slice-lowering";
687   }
688   StringRef getDescription() const final {
689     return "Test lowering patterns that converts vector.extract_strided_slice "
690            "into a chain of vector.extract and vector.insert ops";
691   }
692   void runOnOperation() override {
693     RewritePatternSet patterns(&getContext());
694     populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
695     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
696   }
697 };
698 
699 struct TestVectorBreakDownBitCast
700     : public PassWrapper<TestVectorBreakDownBitCast,
701                          OperationPass<func::FuncOp>> {
702   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast)
703 
704   StringRef getArgument() const final {
705     return "test-vector-break-down-bitcast";
706   }
707   StringRef getDescription() const final {
708     return "Test pattern that breaks down vector.bitcast ops ";
709   }
710   void runOnOperation() override {
711     RewritePatternSet patterns(&getContext());
712     populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) {
713       return op.getSourceVectorType().getShape().back() > 4;
714     });
715     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
716   }
717 };
718 
719 struct TestCreateVectorBroadcast
720     : public PassWrapper<TestCreateVectorBroadcast,
721                          OperationPass<func::FuncOp>> {
722   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)
723 
724   StringRef getArgument() const final { return "test-create-vector-broadcast"; }
725   StringRef getDescription() const final {
726     return "Test optimization transformations for transfer ops";
727   }
728   void getDependentDialects(DialectRegistry &registry) const override {
729     registry.insert<vector::VectorDialect>();
730   }
731 
732   void runOnOperation() override {
733     getOperation()->walk([](Operation *op) {
734       if (op->getName().getStringRef() != "test_create_broadcast")
735         return;
736       auto targetShape =
737           cast<VectorType>(op->getResult(0).getType()).getShape();
738       auto arrayAttr =
739           cast<DenseI64ArrayAttr>(op->getDiscardableAttr("broadcast_dims"))
740               .asArrayRef();
741       llvm::SetVector<int64_t> broadcastedDims;
742       broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
743       OpBuilder b(op);
744       Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp(
745           b, op->getOperand(0), targetShape, broadcastedDims);
746       op->getResult(0).replaceAllUsesWith(bcast);
747       op->erase();
748     });
749   }
750 };
751 
752 struct TestVectorGatherLowering
753     : public PassWrapper<TestVectorGatherLowering,
754                          OperationPass<func::FuncOp>> {
755   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering)
756 
757   StringRef getArgument() const final { return "test-vector-gather-lowering"; }
758   StringRef getDescription() const final {
759     return "Test patterns that lower the gather op in the vector conditional "
760            "loads";
761   }
762   void getDependentDialects(DialectRegistry &registry) const override {
763     registry.insert<arith::ArithDialect, func::FuncDialect,
764                     memref::MemRefDialect, scf::SCFDialect,
765                     tensor::TensorDialect, vector::VectorDialect>();
766   }
767 
768   void runOnOperation() override {
769     RewritePatternSet patterns(&getContext());
770     populateVectorGatherLoweringPatterns(patterns);
771     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
772   }
773 };
774 
775 struct TestFoldArithExtensionIntoVectorContractPatterns
776     : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
777                          OperationPass<func::FuncOp>> {
778   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
779       TestFoldArithExtensionIntoVectorContractPatterns)
780 
781   StringRef getArgument() const final {
782     return "test-fold-arith-extf-into-vector-contract-patterns";
783   }
784   StringRef getDescription() const final {
785     return "Test patterns that fold arithmetic extension ops into vector "
786            "contract ops";
787   }
788 
789   void getDependentDialects(DialectRegistry &registry) const override {
790     registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect,
791                     memref::MemRefDialect, scf::SCFDialect,
792                     tensor::TensorDialect, vector::VectorDialect>();
793   }
794 
795   void runOnOperation() override {
796     RewritePatternSet patterns(&getContext());
797     populateFoldArithExtensionPatterns(patterns);
798     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
799   }
800 };
801 
802 struct TestVectorEmulateMaskedLoadStore final
803     : public PassWrapper<TestVectorEmulateMaskedLoadStore,
804                          OperationPass<func::FuncOp>> {
805   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)
806 
807   StringRef getArgument() const override {
808     return "test-vector-emulate-masked-load-store";
809   }
810   StringRef getDescription() const override {
811     return "Test patterns that emulate the maskedload/maskedstore op by "
812            " memref.load/store and scf.if";
813   }
814   void getDependentDialects(DialectRegistry &registry) const override {
815     registry
816         .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
817                 scf::SCFDialect, vector::VectorDialect>();
818   }
819 
820   void runOnOperation() override {
821     RewritePatternSet patterns(&getContext());
822     populateVectorMaskedLoadStoreEmulationPatterns(patterns);
823     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
824   }
825 };
826 
827 struct TestVectorLinearize final
828     : public PassWrapper<TestVectorLinearize, OperationPass<>> {
829   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
830 
831   StringRef getArgument() const override { return "test-vector-linearize"; }
832   StringRef getDescription() const override {
833     return "Linearizes ND vectors for N >= 2 into 1D vectors";
834   }
835   void getDependentDialects(DialectRegistry &registry) const override {
836     registry.insert<vector::VectorDialect>();
837   }
838 
839   void runOnOperation() override {
840     auto *context = &getContext();
841 
842     TypeConverter typeConverter;
843     RewritePatternSet patterns(context);
844     ConversionTarget target(*context);
845 
846     vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
847                                                               patterns, target);
848     if (failed(applyPartialConversion(getOperation(), target,
849                                       std::move(patterns))))
850       return signalPassFailure();
851   }
852 };
853 } // namespace
854 
855 namespace mlir {
856 namespace test {
857 void registerTestVectorLowerings() {
858   PassRegistration<TestVectorToVectorLowering>();
859 
860   PassRegistration<TestVectorContractionPrepareForMMTLowering>();
861 
862   PassRegistration<TestVectorUnrollingPatterns>();
863 
864   PassRegistration<TestVectorTransferUnrollingPatterns>();
865 
866   PassRegistration<TestScalarVectorTransferLoweringPatterns>();
867 
868   PassRegistration<TestVectorTransferOpt>();
869 
870   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
871 
872   PassRegistration<TestSinkVectorBroadcast>();
873 
874   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
875 
876   PassRegistration<TestVectorChainedReductionFoldingPatterns>();
877 
878   PassRegistration<TestVectorBreakDownReductionPatterns>();
879 
880   PassRegistration<TestFlattenVectorTransferPatterns>();
881 
882   PassRegistration<TestVectorScanLowering>();
883 
884   PassRegistration<TestVectorDistribution>();
885 
886   PassRegistration<TestVectorExtractStridedSliceLowering>();
887 
888   PassRegistration<TestVectorBreakDownBitCast>();
889 
890   PassRegistration<TestCreateVectorBroadcast>();
891 
892   PassRegistration<TestVectorGatherLowering>();
893 
894   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
895 
896   PassRegistration<TestVectorEmulateMaskedLoadStore>();
897 
898   PassRegistration<TestVectorLinearize>();
899 }
900 } // namespace test
901 } // namespace mlir
902