1 //===- TestLinalgDropUnitDims.cpp - Test Linalg drop unit dims -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass for testing the transformation to drop unit 10 // extent dimensions from `linalg.generic` operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 20 using namespace mlir; 21 22 namespace { 23 24 LogicalResult dropOutermostUnitDims(RewriterBase &rewriter, 25 linalg::GenericOp genericOp) { 26 linalg::ControlDropUnitDims options; 27 options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; }; 28 FailureOr<linalg::DropUnitDimsResult> result = 29 linalg::dropUnitDims(rewriter, genericOp, options); 30 if (failed(result)) { 31 return failure(); 32 } 33 rewriter.replaceOp(genericOp, result->replacements); 34 return success(); 35 } 36 37 struct TestLinalgDropUnitDims 38 : public PassWrapper<TestLinalgDropUnitDims, OperationPass<func::FuncOp>> { 39 40 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDropUnitDims) 41 42 TestLinalgDropUnitDims() = default; 43 TestLinalgDropUnitDims(const TestLinalgDropUnitDims &pass) = default; 44 45 void getDependentDialects(DialectRegistry ®istry) const override { 46 registry.insert<linalg::LinalgDialect>(); 47 } 48 49 StringRef getArgument() const final { return "test-linalg-drop-unit-dims"; } 50 51 StringRef getDescriptions() const { 52 return "Test transformation to drop unit-extent dims from Linalg " 53 "operations"; 54 } 55 56 void runOnOperation() override { 57 MLIRContext *context = &this->getContext(); 58 func::FuncOp funcOp = this->getOperation(); 59 IRRewriter rewriter(context); 60 SmallVector<linalg::GenericOp> genericOps; 61 funcOp.walk( 62 [&](linalg::GenericOp genericOp) { genericOps.push_back(genericOp); }); 63 64 for (auto genericOp : genericOps) { 65 rewriter.setInsertionPoint(genericOp); 66 (void)dropOutermostUnitDims(rewriter, genericOp); 67 } 68 } 69 }; 70 } // namespace 71 72 namespace mlir { 73 namespace test { 74 void registerTestLinalgDropUnitDims() { 75 PassRegistration<TestLinalgDropUnitDims>(); 76 } 77 } // namespace test 78 } // namespace mlir 79