xref: /llvm-project/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp (revision 830b9b072d8458ee89c48f00d4de59456c9f467f)
1 //===- TestShapeFunctions.cpp - Passes to test shape function  ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <queue>
10 
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/Shape/IR/Shape.h"
13 #include "mlir/IR/BuiltinDialect.h"
14 #include "mlir/Interfaces/InferTypeOpInterface.h"
15 #include "mlir/Pass/Pass.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 /// This is a pass that reports shape functions associated with ops.
21 struct ReportShapeFnPass
22     : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> {
23   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReportShapeFnPass)
24 
25   void runOnOperation() override;
getArgument__anon0fb73c7a0111::ReportShapeFnPass26   StringRef getArgument() const final { return "test-shape-function-report"; }
getDescription__anon0fb73c7a0111::ReportShapeFnPass27   StringRef getDescription() const final {
28     return "Test pass to report associated shape functions";
29   }
30 };
31 } // namespace
32 
runOnOperation()33 void ReportShapeFnPass::runOnOperation() {
34   auto module = getOperation();
35 
36   // Report the shape function available to refine the op.
37   auto shapeFnId = StringAttr::get(&getContext(), "shape.function");
38   auto remarkShapeFn = [&](shape::FunctionLibraryOp shapeFnLib, Operation *op) {
39     if (op->hasTrait<OpTrait::IsTerminator>())
40       return true;
41     if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) {
42       op->emitRemark() << "implements InferType op interface";
43       return true;
44     }
45     if (auto fn = shapeFnLib.getShapeFunction(op)) {
46       op->emitRemark() << "associated shape function: " << fn.getName();
47       return true;
48     }
49     if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
50       auto fn =
51           cast<shape::FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
52       op->emitRemark() << "associated shape function: " << fn.getName();
53       return true;
54     }
55     return false;
56   };
57 
58   // Lookup shape function library.
59   SmallVector<shape::FunctionLibraryOp, 4> libraries;
60   auto attr = module->getDiscardableAttr("shape.lib");
61   if (attr) {
62     auto lookup = [&](Attribute attr) {
63       return cast<shape::FunctionLibraryOp>(
64           SymbolTable::lookupSymbolIn(module, cast<SymbolRefAttr>(attr)));
65     };
66     if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
67       libraries.reserve(arrayAttr.size());
68       for (auto attr : arrayAttr)
69         libraries.push_back(lookup(attr));
70     } else {
71       libraries.reserve(1);
72       libraries.push_back(lookup(attr));
73     }
74   }
75 
76   module.getBodyRegion().walk([&](func::FuncOp func) {
77     // Skip ops in the shape function library.
78     if (isa<shape::FunctionLibraryOp>(func->getParentOp()))
79       return;
80 
81     func.walk([&](Operation *op) {
82       bool found = llvm::any_of(libraries, [&](shape::FunctionLibraryOp lib) {
83         return remarkShapeFn(lib, op);
84       });
85       if (!found)
86         op->emitRemark() << "no associated way to refine shape";
87     });
88   });
89 }
90 
91 namespace mlir {
registerShapeFunctionTestPasses()92 void registerShapeFunctionTestPasses() {
93   PassRegistration<ReportShapeFnPass>();
94 }
95 } // namespace mlir
96