xref: /llvm-project/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (revision 10054ba4acbc5378d2e2aa869a5bccd88aa4b59e)
1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <optional>
10 #include <type_traits>
11 
12 #include "mlir/Analysis/SliceAnalysis.h"
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
23 #include "mlir/Dialect/SCF/IR/SCF.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
27 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
28 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
29 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Pass/PassManager.h"
32 #include "mlir/Support/LLVM.h"
33 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 using namespace mlir::vector;
38 
39 namespace {
40 
41 struct TestVectorToVectorLowering
42     : public PassWrapper<TestVectorToVectorLowering,
43                          OperationPass<func::FuncOp>> {
44   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
45 
46   TestVectorToVectorLowering() = default;
47   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
48       : PassWrapper(pass) {}
49   StringRef getArgument() const final {
50     return "test-vector-to-vector-lowering";
51   }
52   StringRef getDescription() const final {
53     return "Test lowering patterns between ops in the vector dialect";
54   }
55 
56   void getDependentDialects(DialectRegistry &registry) const override {
57     registry.insert<affine::AffineDialect>();
58     registry.insert<vector::VectorDialect>();
59   }
60 
61   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
62                       llvm::cl::init(false)};
63 
64   void runOnOperation() override {
65     auto *ctx = &getContext();
66     RewritePatternSet patterns(ctx);
67     if (unroll) {
68       populateVectorUnrollPatterns(
69           patterns,
70           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
71               filter));
72     }
73     populateVectorToVectorCanonicalizationPatterns(patterns);
74     populateBubbleVectorBitCastOpPatterns(patterns);
75     populateCastAwayVectorLeadingOneDimPatterns(patterns);
76     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
77   }
78 
79 private:
80   // Return the target shape based on op type.
81   static std::optional<SmallVector<int64_t>> getShape(Operation *op) {
82     if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
83       return SmallVector<int64_t>(2, 2);
84     if (isa<vector::ContractionOp>(op))
85       return SmallVector<int64_t>(3, 2);
86     // For transfer ops, just propagate the shape coming from
87     // InsertStridedSlices/ExtractStridedSlices.
88     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
89       VectorType dstVec;
90       for (Operation *users : readOp->getUsers()) {
91         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
92         if (!extract)
93           return std::nullopt;
94         auto vecType = cast<VectorType>(extract.getResult().getType());
95         if (dstVec && dstVec != vecType)
96           return std::nullopt;
97         dstVec = vecType;
98       }
99       return SmallVector<int64_t>(dstVec.getShape());
100     }
101     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
102       auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
103       if (!insert)
104         return std::nullopt;
105       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
106       return SmallVector<int64_t>(shape);
107     }
108     return std::nullopt;
109   }
110 
111   static LogicalResult filter(Operation *op) {
112     return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
113                        ContractionOp, TransferReadOp, TransferWriteOp>(op));
114   }
115 };
116 
117 struct TestVectorContractionPrepareForMMTLowering
118     : public PassWrapper<TestVectorContractionPrepareForMMTLowering,
119                          OperationPass<func::FuncOp>> {
120   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
121       TestVectorContractionPrepareForMMTLowering)
122 
123   StringRef getArgument() const final {
124     return "test-vector-contraction-prepare-for-mmt-lowering";
125   }
126   StringRef getDescription() const final {
127     return "Test vector.contraction matmul canonicalization for MMT lowering.";
128   }
129   TestVectorContractionPrepareForMMTLowering() = default;
130 
131   void getDependentDialects(DialectRegistry &registry) const override {
132     registry.insert<affine::AffineDialect, arith::ArithDialect,
133                     vector::VectorDialect>();
134   }
135 
136   void runOnOperation() override {
137     MLIRContext *ctx = &getContext();
138     RewritePatternSet patterns(ctx);
139     vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
140     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
141   }
142 };
143 
struct TestVectorUnrollingPatterns
    : public PassWrapper<TestVectorUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)

  StringRef getArgument() const final {
    return "test-vector-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll contract ops in the vector "
           "dialect";
  }
  TestVectorUnrollingPatterns() = default;
  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
      : PassWrapper(pass) {}
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    // Unroll elementwise adds, FMAs and multi-dim reductions to 2x2 tiles.
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<arith::AddFOp, vector::FMAOp,
                                           vector::MultiDimReductionOp>(op));
                      }));
    // Unroll 1-D vector.reduction ops in chunks of 2 elements.
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::ReductionOp>(op));
                      }));
    // Unroll transpose ops to a fixed 1x3x4x2 native tile.
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::TransposeOp>(op));
                      }));

    if (unrollBasedOnType) {
      // Derive the contraction's native shape from its lhs element type:
      // 4 along every iterator except the innermost one, which is 4 for f16
      // and 2 for anything else.
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
        SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(),
                                         4);
        Type lhsType = contractOp.getLhsType().getElementType();
        nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
        return nativeShape;
      };

      UnrollVectorOptions opts;
      opts.setNativeShapeFn(nativeShapeFn)
          .setFilterConstraint(
              [](Operation *op) { return success(isa<ContractionOp>(op)); });

      if (!unrollOrder.empty()) {
        // Honor the user-provided traversal order, but only when it has one
        // entry per contraction iterator; otherwise fall back to the default.
        opts.setUnrollTraversalOrderFn(
            [this](Operation *op) -> std::optional<SmallVector<int64_t>> {
              vector::ContractionOp contractOp =
                  cast<vector::ContractionOp>(op);
              if (contractOp.getIteratorTypes().size() == unrollOrder.size())
                return SmallVector<int64_t>(unrollOrder.begin(),
                                            unrollOrder.end());
              return std::nullopt;
            });
      }
      populateVectorUnrollPatterns(patterns, opts);
    } else {
      // Default mode: unroll contractions with a native size of 2 along
      // every iterator.
      auto nativeShapeFn =
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        auto contractOp = dyn_cast<ContractionOp>(op);
        if (!contractOp)
          return std::nullopt;
        return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2);
      };
      populateVectorUnrollPatterns(patterns,
                                   UnrollVectorOptions()
                                       .setNativeShapeFn(nativeShapeFn)
                                       .setFilterConstraint([](Operation *op) {
                                         return success(isa<ContractionOp>(op));
                                       }));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  // Explicit unroll traversal order for contractions; empty = default order.
  ListOption<int64_t> unrollOrder{*this, "unroll-order",
                                  llvm::cl::desc("set the unroll order")};

  // When set, the contraction native shape depends on the lhs element type
  // (see nativeShapeFn above).
  Option<bool> unrollBasedOnType{
      *this, "unroll-based-on-type",
      llvm::cl::desc("Set the unroll factor based on type of the operation"),
      llvm::cl::init(false)};
};
237 
238 struct TestVectorTransferUnrollingPatterns
239     : public PassWrapper<TestVectorTransferUnrollingPatterns,
240                          OperationPass<func::FuncOp>> {
241   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
242       TestVectorTransferUnrollingPatterns)
243 
244   TestVectorTransferUnrollingPatterns() = default;
245   TestVectorTransferUnrollingPatterns(
246       const TestVectorTransferUnrollingPatterns &pass)
247       : PassWrapper(pass) {}
248 
249   void getDependentDialects(DialectRegistry &registry) const override {
250     registry.insert<affine::AffineDialect>();
251   }
252   StringRef getArgument() const final {
253     return "test-vector-transfer-unrolling-patterns";
254   }
255   StringRef getDescription() const final {
256     return "Test lowering patterns to unroll transfer ops in the vector "
257            "dialect";
258   }
259   void runOnOperation() override {
260     MLIRContext *ctx = &getContext();
261     RewritePatternSet patterns(ctx);
262     UnrollVectorOptions opts;
263     opts.setNativeShape(ArrayRef<int64_t>{2, 2})
264         .setFilterConstraint([](Operation *op) {
265           return success(isa<vector::TransferReadOp, vector::TransferWriteOp,
266                              vector::GatherOp>(op));
267         });
268     if (reverseUnrollOrder.getValue()) {
269       opts.setUnrollTraversalOrderFn(
270           [](Operation *op) -> std::optional<SmallVector<int64_t>> {
271             int64_t numLoops = 0;
272             if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
273               numLoops = readOp.getVectorType().getRank();
274             else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
275               numLoops = writeOp.getVectorType().getRank();
276             else if (auto gatherOp = dyn_cast<vector::GatherOp>(op))
277               numLoops = gatherOp.getVectorType().getRank();
278             else
279               return std::nullopt;
280             auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
281             return llvm::to_vector(order);
282           });
283     }
284     populateVectorUnrollPatterns(patterns, opts);
285     populateVectorToVectorCanonicalizationPatterns(patterns);
286     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
287   }
288 
289   Option<bool> reverseUnrollOrder{
290       *this, "reverse-unroll-order",
291       llvm::cl::desc(
292           "reverse the order of unrolling of vector transfer operations"),
293       llvm::cl::init(false)};
294 };
295 
296 struct TestScalarVectorTransferLoweringPatterns
297     : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
298                          OperationPass<func::FuncOp>> {
299   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
300       TestScalarVectorTransferLoweringPatterns)
301 
302   TestScalarVectorTransferLoweringPatterns() = default;
303   TestScalarVectorTransferLoweringPatterns(
304       const TestScalarVectorTransferLoweringPatterns &pass)
305       : PassWrapper(pass) {}
306 
307   StringRef getArgument() const final {
308     return "test-scalar-vector-transfer-lowering";
309   }
310   StringRef getDescription() const final {
311     return "Test lowering of scalar vector transfers to memref loads/stores.";
312   }
313 
314   void getDependentDialects(DialectRegistry &registry) const override {
315     registry.insert<affine::AffineDialect, memref::MemRefDialect,
316                     tensor::TensorDialect, vector::VectorDialect>();
317   }
318 
319   Option<bool> allowMultipleUses{
320       *this, "allow-multiple-uses",
321       llvm::cl::desc("Fold transfer operations with multiple uses"),
322       llvm::cl::init(false)};
323 
324   void runOnOperation() override {
325     MLIRContext *ctx = &getContext();
326     RewritePatternSet patterns(ctx);
327     vector::populateScalarVectorTransferLoweringPatterns(
328         patterns, /*benefit=*/1, allowMultipleUses.getValue());
329     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
330   }
331 };
332 
333 struct TestVectorTransferOpt
334     : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
335   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
336 
337   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
338   StringRef getDescription() const final {
339     return "Test optimization transformations for transfer ops";
340   }
341   void runOnOperation() override {
342     IRRewriter rewriter(&getContext());
343     transferOpflowOpt(rewriter, getOperation());
344   }
345 };
346 
347 struct TestVectorTransferCollapseInnerMostContiguousDims
348     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
349                          OperationPass<func::FuncOp>> {
350   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
351       TestVectorTransferCollapseInnerMostContiguousDims)
352 
353   TestVectorTransferCollapseInnerMostContiguousDims() = default;
354   TestVectorTransferCollapseInnerMostContiguousDims(
355       const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
356 
357   void getDependentDialects(DialectRegistry &registry) const override {
358     registry.insert<memref::MemRefDialect, affine::AffineDialect>();
359   }
360 
361   StringRef getArgument() const final {
362     return "test-vector-transfer-collapse-inner-most-dims";
363   }
364 
365   StringRef getDescription() const final {
366     return "Test lowering patterns that reducedes the rank of the vector "
367            "transfer memory and vector operands.";
368   }
369 
370   void runOnOperation() override {
371     RewritePatternSet patterns(&getContext());
372     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
373     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
374   }
375 };
376 
377 struct TestVectorSinkPatterns
378     : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
379   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)
380 
381   TestVectorSinkPatterns() = default;
382   TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;
383 
384   void getDependentDialects(DialectRegistry &registry) const override {
385     registry.insert<memref::MemRefDialect, affine::AffineDialect>();
386   }
387 
388   StringRef getArgument() const final { return "test-vector-sink-patterns"; }
389 
390   StringRef getDescription() const final {
391     return "Test lowering patterns that eliminate redundant brodacast "
392            "and transpose operations.";
393   }
394 
395   void runOnOperation() override {
396     RewritePatternSet patterns(&getContext());
397     populateSinkVectorOpsPatterns(patterns);
398     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
399   }
400 };
401 
402 struct TestVectorReduceToContractPatternsPatterns
403     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
404                          OperationPass<func::FuncOp>> {
405   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
406       TestVectorReduceToContractPatternsPatterns)
407 
408   StringRef getArgument() const final {
409     return "test-vector-reduction-to-contract-patterns";
410   }
411   StringRef getDescription() const final {
412     return "Test patterns to convert multireduce op to contract and combine "
413            "broadcast/transpose to contract";
414   }
415   void runOnOperation() override {
416     RewritePatternSet patterns(&getContext());
417     populateVectorReductionToContractPatterns(patterns);
418     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
419   }
420 };
421 
422 struct TestVectorChainedReductionFoldingPatterns
423     : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
424                          OperationPass<func::FuncOp>> {
425   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
426       TestVectorChainedReductionFoldingPatterns)
427 
428   StringRef getArgument() const final {
429     return "test-vector-chained-reduction-folding-patterns";
430   }
431   StringRef getDescription() const final {
432     return "Test patterns to fold chained vector reductions";
433   }
434   void runOnOperation() override {
435     RewritePatternSet patterns(&getContext());
436     populateChainedVectorReductionFoldingPatterns(patterns);
437     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
438   }
439 };
440 
441 struct TestVectorBreakDownReductionPatterns
442     : public PassWrapper<TestVectorBreakDownReductionPatterns,
443                          OperationPass<func::FuncOp>> {
444   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
445       TestVectorBreakDownReductionPatterns)
446 
447   StringRef getArgument() const final {
448     return "test-vector-break-down-reduction-patterns";
449   }
450   StringRef getDescription() const final {
451     return "Test patterns to break down vector reductions into arith "
452            "reductions";
453   }
454   void runOnOperation() override {
455     RewritePatternSet patterns(&getContext());
456     populateBreakDownVectorReductionPatterns(patterns,
457                                              /*maxNumElementsToExtract=*/2);
458     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
459   }
460 };
461 
462 struct TestFlattenVectorTransferPatterns
463     : public PassWrapper<TestFlattenVectorTransferPatterns,
464                          OperationPass<func::FuncOp>> {
465   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
466       TestFlattenVectorTransferPatterns)
467 
468   TestFlattenVectorTransferPatterns() = default;
469   TestFlattenVectorTransferPatterns(
470       const TestFlattenVectorTransferPatterns &pass)
471       : PassWrapper(pass) {}
472 
473   StringRef getArgument() const final {
474     return "test-vector-transfer-flatten-patterns";
475   }
476 
477   StringRef getDescription() const final {
478     return "Test patterns to rewrite contiguous row-major N-dimensional "
479            "vector.transfer_{read,write} ops into 1D transfers";
480   }
481 
482   void getDependentDialects(DialectRegistry &registry) const override {
483     registry.insert<memref::MemRefDialect>();
484     registry.insert<affine::AffineDialect>();
485     registry.insert<vector::VectorDialect>();
486   }
487 
488   Option<unsigned> targetVectorBitwidth{
489       *this, "target-vector-bitwidth",
490       llvm::cl::desc(
491           "Minimum vector bitwidth to enable the flattening transformation. "
492           "For scalable vectors this is the base size, i.e. the size "
493           "corresponding to vscale=1."),
494       llvm::cl::init(std::numeric_limits<unsigned>::max())};
495 
496   void runOnOperation() override {
497     RewritePatternSet patterns(&getContext());
498     populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
499     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
500   }
501 };
502 
503 struct TestVectorScanLowering
504     : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
505   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
506 
507   StringRef getArgument() const final { return "test-vector-scan-lowering"; }
508   StringRef getDescription() const final {
509     return "Test lowering patterns that lower the scan op in the vector "
510            "dialect";
511   }
512   void runOnOperation() override {
513     RewritePatternSet patterns(&getContext());
514     populateVectorScanLoweringPatterns(patterns);
515     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
516   }
517 };
518 
/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op.
///
/// Creates (or rather, always creates a fresh) module-level private
/// memref.global in memory space `kSharedMemorySpace` sized to hold `type`,
/// and returns a memref.get_global referencing it at the current insertion
/// point of `builder`.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
                                        WarpExecuteOnLane0Op warpOp,
                                        Type type) {
  // Numeric memory space used for the shared buffer in this test
  // (presumably the GPU workgroup/shared space -- TODO confirm for target).
  static constexpr int64_t kSharedMemorySpace = 3;
  // Compute type of shared memory buffer.
  MemRefType memrefType;
  if (auto vectorType = dyn_cast<VectorType>(type)) {
    memrefType =
        MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
                        kSharedMemorySpace);
  } else {
    // Scalars get a single-element buffer.
    memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
  }

  // Get symbol table holding all shared memory globals.
  ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
  SymbolTable symbolTable(moduleOp);

  // Create a pretty name, e.g. "__shared_4x4xf32" (the symbol table
  // uniquifies it on insertion if needed).
  SmallString<64> buf;
  llvm::raw_svector_ostream os(buf);
  interleave(memrefType.getShape(), os, "x");
  os << "x" << memrefType.getElementType();
  std::string symbolName = (Twine("__shared_") + os.str()).str();

  // Temporarily move the builder to module scope to create the global.
  auto ip = builder.saveInsertionPoint();
  builder.setInsertionPoint(moduleOp);
  auto global = builder.create<memref::GlobalOp>(
      loc,
      /*sym_name=*/symbolName,
      /*sym_visibility=*/builder.getStringAttr("private"),
      /*type=*/memrefType,
      /*initial_value=*/Attribute(),
      /*constant=*/false,
      /*alignment=*/IntegerAttr());
  symbolTable.insert(global);
  // The symbol table inserts at the end of the module, but globals are a bit
  // nicer if they are at the beginning.
  global->moveBefore(&moduleOp.front());

  // Restore the caller's insertion point before emitting the accessor.
  builder.restoreInsertionPoint(ip);
  return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
}
564 
565 static Value warpReduction(Location loc, OpBuilder &builder, Value input,
566                            CombiningKind kind, uint32_t size) {
567   // First reduce on a single thread to get per lane reduction value.
568   Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
569   // Parallel reduction using butterfly shuffles.
570   for (uint64_t i = 1; i < size; i <<= 1) {
571     Value shuffled = builder
572                          .create<gpu::ShuffleOp>(loc, laneVal, i,
573                                                  /*width=*/size,
574                                                  /*mode=*/gpu::ShuffleMode::XOR)
575                          .getShuffleResult();
576     laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
577   }
578   return laneVal;
579 }
580 
581 struct TestVectorDistribution
582     : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
583   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
584 
585   void getDependentDialects(DialectRegistry &registry) const override {
586     registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
587                     affine::AffineDialect>();
588   }
589 
590   StringRef getArgument() const final { return "test-vector-warp-distribute"; }
591   StringRef getDescription() const final {
592     return "Test vector warp distribute transformation and lowering patterns";
593   }
594   TestVectorDistribution() = default;
595   TestVectorDistribution(const TestVectorDistribution &pass)
596       : PassWrapper(pass) {}
597 
598   Option<bool> warpOpToSCF{
599       *this, "rewrite-warp-ops-to-scf-if",
600       llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
601       llvm::cl::init(false)};
602 
603   Option<bool> distributeTransferWriteOps{
604       *this, "distribute-transfer-write",
605       llvm::cl::desc("Test distribution of transfer write"),
606       llvm::cl::init(false)};
607 
608   Option<unsigned> maxTransferWriteElements{
609       *this, "max-transfer-write-elements",
610       llvm::cl::desc("Maximum number of transfer write elements to distribute"),
611       llvm::cl::init(1)};
612 
613   Option<bool> hoistUniform{*this, "hoist-uniform",
614                             llvm::cl::desc("Test hoist uniform"),
615                             llvm::cl::init(false)};
616 
617   Option<bool> propagateDistribution{
618       *this, "propagate-distribution",
619       llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)};
620 
621   void runOnOperation() override {
622     RewritePatternSet patterns(&getContext());
623 
624     getOperation().walk([&](Operation *op) {
625       if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
626         if (hoistUniform) {
627           moveScalarUniformCode(warpOp);
628         }
629         WalkResult::interrupt();
630       }
631     });
632     MLIRContext *ctx = &getContext();
633     auto distributionFn = [](Value val) {
634       // Create an identity dim map of the same rank as the vector.
635       VectorType vecType = dyn_cast<VectorType>(val.getType());
636       int64_t vecRank = vecType ? vecType.getRank() : 0;
637       OpBuilder builder(val.getContext());
638       if (vecRank == 0)
639         return AffineMap::get(val.getContext());
640       return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
641     };
642     auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
643                         Value srcIdx, int64_t warpSz) {
644       assert((val.getType().isF32() || val.getType().isInteger(32)) &&
645              "unsupported shuffle type");
646       Type i32Type = builder.getIntegerType(32);
647       Value srcIdxI32 =
648           builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
649       Value warpSzI32 = builder.create<arith::ConstantOp>(
650           loc, builder.getIntegerAttr(i32Type, warpSz));
651       Value result = builder
652                          .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
653                                                  gpu::ShuffleMode::IDX)
654                          .getResult(0);
655       return result;
656     };
657     if (distributeTransferWriteOps && propagateDistribution) {
658       RewritePatternSet patterns(ctx);
659       vector::populatePropagateWarpVectorDistributionPatterns(
660           patterns, distributionFn, shuffleFn, /*benefit=*/1,
661           /*readBenefit=*/0);
662       vector::populateDistributeReduction(patterns, warpReduction, 1);
663       populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
664       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
665     } else if (distributeTransferWriteOps) {
666       RewritePatternSet patterns(ctx);
667       populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
668                                                 maxTransferWriteElements);
669       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
670     } else if (propagateDistribution) {
671       RewritePatternSet patterns(ctx);
672       vector::populatePropagateWarpVectorDistributionPatterns(
673           patterns, distributionFn, shuffleFn);
674       vector::populateDistributeReduction(patterns, warpReduction);
675       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
676     }
677     WarpExecuteOnLane0LoweringOptions options;
678     options.warpAllocationFn = allocateGlobalSharedMemory;
679     options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
680                                       WarpExecuteOnLane0Op warpOp) {
681       builder.create<gpu::BarrierOp>(loc);
682     };
683     // Test on one pattern in isolation.
684     if (warpOpToSCF) {
685       populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
686       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
687       return;
688     }
689   }
690 };
691 
692 struct TestVectorExtractStridedSliceLowering
693     : public PassWrapper<TestVectorExtractStridedSliceLowering,
694                          OperationPass<func::FuncOp>> {
695   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
696       TestVectorExtractStridedSliceLowering)
697 
698   StringRef getArgument() const final {
699     return "test-vector-extract-strided-slice-lowering";
700   }
701   StringRef getDescription() const final {
702     return "Test lowering patterns that converts vector.extract_strided_slice "
703            "into a chain of vector.extract and vector.insert ops";
704   }
705   void runOnOperation() override {
706     RewritePatternSet patterns(&getContext());
707     populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
708     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
709   }
710 };
711 
712 struct TestVectorContiguousExtractStridedSliceToExtract
713     : public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
714                          OperationPass<func::FuncOp>> {
715   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
716       TestVectorExtractStridedSliceLowering)
717 
718   StringRef getArgument() const final {
719     return "test-vector-contiguous-extract-strided-slice-to-extract";
720   }
721   StringRef getDescription() const final {
722     return "Test lowering patterns that rewrite simple cases of N-D "
723            "extract_strided_slice, where the slice is contiguous, into extract "
724            "and shape_cast";
725   }
726   void runOnOperation() override {
727     RewritePatternSet patterns(&getContext());
728     populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
729     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
730   }
731 };
732 
733 struct TestVectorBreakDownBitCast
734     : public PassWrapper<TestVectorBreakDownBitCast,
735                          OperationPass<func::FuncOp>> {
736   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast)
737 
738   StringRef getArgument() const final {
739     return "test-vector-break-down-bitcast";
740   }
741   StringRef getDescription() const final {
742     return "Test pattern that breaks down vector.bitcast ops ";
743   }
744   void runOnOperation() override {
745     RewritePatternSet patterns(&getContext());
746     populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) {
747       return op.getSourceVectorType().getShape().back() > 4;
748     });
749     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
750   }
751 };
752 
753 struct TestCreateVectorBroadcast
754     : public PassWrapper<TestCreateVectorBroadcast,
755                          OperationPass<func::FuncOp>> {
756   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)
757 
758   StringRef getArgument() const final { return "test-create-vector-broadcast"; }
759   StringRef getDescription() const final {
760     return "Test optimization transformations for transfer ops";
761   }
762   void getDependentDialects(DialectRegistry &registry) const override {
763     registry.insert<vector::VectorDialect>();
764   }
765 
766   void runOnOperation() override {
767     getOperation()->walk([](Operation *op) {
768       if (op->getName().getStringRef() != "test_create_broadcast")
769         return;
770       auto targetShape =
771           cast<VectorType>(op->getResult(0).getType()).getShape();
772       auto arrayAttr =
773           cast<DenseI64ArrayAttr>(op->getDiscardableAttr("broadcast_dims"))
774               .asArrayRef();
775       llvm::SetVector<int64_t> broadcastedDims;
776       broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
777       OpBuilder b(op);
778       Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp(
779           b, op->getOperand(0), targetShape, broadcastedDims);
780       op->getResult(0).replaceAllUsesWith(bcast);
781       op->erase();
782     });
783   }
784 };
785 
786 struct TestVectorGatherLowering
787     : public PassWrapper<TestVectorGatherLowering,
788                          OperationPass<func::FuncOp>> {
789   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering)
790 
791   StringRef getArgument() const final { return "test-vector-gather-lowering"; }
792   StringRef getDescription() const final {
793     return "Test patterns that lower the gather op in the vector conditional "
794            "loads";
795   }
796   void getDependentDialects(DialectRegistry &registry) const override {
797     registry.insert<arith::ArithDialect, func::FuncDialect,
798                     memref::MemRefDialect, scf::SCFDialect,
799                     tensor::TensorDialect, vector::VectorDialect>();
800   }
801 
802   void runOnOperation() override {
803     RewritePatternSet patterns(&getContext());
804     populateVectorGatherLoweringPatterns(patterns);
805     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
806   }
807 };
808 
809 struct TestFoldArithExtensionIntoVectorContractPatterns
810     : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
811                          OperationPass<func::FuncOp>> {
812   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
813       TestFoldArithExtensionIntoVectorContractPatterns)
814 
815   StringRef getArgument() const final {
816     return "test-fold-arith-extf-into-vector-contract-patterns";
817   }
818   StringRef getDescription() const final {
819     return "Test patterns that fold arithmetic extension ops into vector "
820            "contract ops";
821   }
822 
823   void getDependentDialects(DialectRegistry &registry) const override {
824     registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect,
825                     memref::MemRefDialect, scf::SCFDialect,
826                     tensor::TensorDialect, vector::VectorDialect>();
827   }
828 
829   void runOnOperation() override {
830     RewritePatternSet patterns(&getContext());
831     populateFoldArithExtensionPatterns(patterns);
832     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
833   }
834 };
835 
836 struct TestVectorEmulateMaskedLoadStore final
837     : public PassWrapper<TestVectorEmulateMaskedLoadStore,
838                          OperationPass<func::FuncOp>> {
839   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)
840 
841   StringRef getArgument() const override {
842     return "test-vector-emulate-masked-load-store";
843   }
844   StringRef getDescription() const override {
845     return "Test patterns that emulate the maskedload/maskedstore op by "
846            " memref.load/store and scf.if";
847   }
848   void getDependentDialects(DialectRegistry &registry) const override {
849     registry
850         .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
851                 scf::SCFDialect, vector::VectorDialect>();
852   }
853 
854   void runOnOperation() override {
855     RewritePatternSet patterns(&getContext());
856     populateVectorMaskedLoadStoreEmulationPatterns(patterns);
857     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
858   }
859 };
860 
861 struct TestVectorLinearize final
862     : public PassWrapper<TestVectorLinearize, OperationPass<>> {
863   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
864 
865   TestVectorLinearize() = default;
866   TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
867 
868   StringRef getArgument() const override { return "test-vector-linearize"; }
869   StringRef getDescription() const override {
870     return "Linearizes ND vectors for N >= 2 into 1D vectors";
871   }
872   void getDependentDialects(DialectRegistry &registry) const override {
873     registry.insert<vector::VectorDialect>();
874   }
875 
876   Option<unsigned> targetVectorBitwidth{
877       *this, "target-vector-bitwidth",
878       llvm::cl::desc(
879           "Minimum vector bitwidth to enable the flattening transformation"),
880       llvm::cl::init(std::numeric_limits<unsigned>::max())};
881   void runOnOperation() override {
882     auto *context = &getContext();
883 
884     TypeConverter typeConverter;
885     RewritePatternSet patterns(context);
886     ConversionTarget target(*context);
887 
888     vector::populateVectorLinearizeTypeConversionsAndLegality(
889         typeConverter, patterns, target, targetVectorBitwidth);
890     vector::populateVectorLinearizeShuffleLikeOpsPatterns(
891         typeConverter, patterns, target, targetVectorBitwidth);
892     if (failed(applyPartialConversion(getOperation(), target,
893                                       std::move(patterns))))
894       return signalPassFailure();
895   }
896 };
897 
898 struct TestEliminateVectorMasks
899     : public PassWrapper<TestEliminateVectorMasks,
900                          OperationPass<func::FuncOp>> {
901   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks)
902 
903   TestEliminateVectorMasks() = default;
904   TestEliminateVectorMasks(const TestEliminateVectorMasks &pass)
905       : PassWrapper(pass) {}
906 
907   Option<unsigned> vscaleMin{
908       *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."),
909       llvm::cl::init(1)};
910   Option<unsigned> vscaleMax{
911       *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."),
912       llvm::cl::init(16)};
913 
914   StringRef getArgument() const final { return "test-eliminate-vector-masks"; }
915   StringRef getDescription() const final {
916     return "Test eliminating vector masks";
917   }
918   void runOnOperation() override {
919     IRRewriter rewriter(&getContext());
920     eliminateVectorMasks(rewriter, getOperation(),
921                          VscaleRange{vscaleMin, vscaleMax});
922   }
923 };
924 } // namespace
925 
926 namespace mlir {
927 namespace test {
/// Registers all Vector-dialect test passes defined in this file with the
/// global pass registry so they are available from mlir-opt. Several of the
/// registered pass classes are defined earlier in this file.
void registerTestVectorLowerings() {
  PassRegistration<TestVectorToVectorLowering>();

  PassRegistration<TestVectorContractionPrepareForMMTLowering>();

  PassRegistration<TestVectorUnrollingPatterns>();

  PassRegistration<TestVectorTransferUnrollingPatterns>();

  PassRegistration<TestScalarVectorTransferLoweringPatterns>();

  PassRegistration<TestVectorTransferOpt>();

  PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();

  PassRegistration<TestVectorSinkPatterns>();

  PassRegistration<TestVectorReduceToContractPatternsPatterns>();

  PassRegistration<TestVectorChainedReductionFoldingPatterns>();

  PassRegistration<TestVectorBreakDownReductionPatterns>();

  PassRegistration<TestFlattenVectorTransferPatterns>();

  PassRegistration<TestVectorScanLowering>();

  PassRegistration<TestVectorDistribution>();

  PassRegistration<TestVectorExtractStridedSliceLowering>();

  PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();

  PassRegistration<TestVectorBreakDownBitCast>();

  PassRegistration<TestCreateVectorBroadcast>();

  PassRegistration<TestVectorGatherLowering>();

  PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();

  PassRegistration<TestVectorEmulateMaskedLoadStore>();

  PassRegistration<TestVectorLinearize>();

  PassRegistration<TestEliminateVectorMasks>();
}
975 } // namespace test
976 } // namespace mlir
977