xref: /llvm-project/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp (revision 830b9b072d8458ee89c48f00d4de59456c9f467f)
1e534cee2SJacques Pienaar //===- TestShapeFunctions.cpp - Passes to test shape function  ------------===//
2e534cee2SJacques Pienaar //
3e534cee2SJacques Pienaar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e534cee2SJacques Pienaar // See https://llvm.org/LICENSE.txt for license information.
5e534cee2SJacques Pienaar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e534cee2SJacques Pienaar //
7e534cee2SJacques Pienaar //===----------------------------------------------------------------------===//
8e534cee2SJacques Pienaar 
9e534cee2SJacques Pienaar #include <queue>
10e534cee2SJacques Pienaar 
119bae20b5SJacques Pienaar #include "mlir/Dialect/Func/IR/FuncOps.h"
12e534cee2SJacques Pienaar #include "mlir/Dialect/Shape/IR/Shape.h"
13e534cee2SJacques Pienaar #include "mlir/IR/BuiltinDialect.h"
14e534cee2SJacques Pienaar #include "mlir/Interfaces/InferTypeOpInterface.h"
15e534cee2SJacques Pienaar #include "mlir/Pass/Pass.h"
16e534cee2SJacques Pienaar 
17e534cee2SJacques Pienaar using namespace mlir;
18e534cee2SJacques Pienaar 
19e534cee2SJacques Pienaar namespace {
20e534cee2SJacques Pienaar /// This is a pass that reports shape functions associated with ops.
21e534cee2SJacques Pienaar struct ReportShapeFnPass
22e534cee2SJacques Pienaar     : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> {
235e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReportShapeFnPass)
245e50dd04SRiver Riddle 
25e534cee2SJacques Pienaar   void runOnOperation() override;
getArgument__anon0fb73c7a0111::ReportShapeFnPass26b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-shape-function-report"; }
getDescription__anon0fb73c7a0111::ReportShapeFnPass27b5e22e6dSMehdi Amini   StringRef getDescription() const final {
28b5e22e6dSMehdi Amini     return "Test pass to report associated shape functions";
29b5e22e6dSMehdi Amini   }
30e534cee2SJacques Pienaar };
31be0a7e9fSMehdi Amini } // namespace
32e534cee2SJacques Pienaar 
runOnOperation()33e534cee2SJacques Pienaar void ReportShapeFnPass::runOnOperation() {
34e534cee2SJacques Pienaar   auto module = getOperation();
35e534cee2SJacques Pienaar 
36e534cee2SJacques Pienaar   // Report the shape function available to refine the op.
37195730a6SRiver Riddle   auto shapeFnId = StringAttr::get(&getContext(), "shape.function");
388d541a1fSJacques Pienaar   auto remarkShapeFn = [&](shape::FunctionLibraryOp shapeFnLib, Operation *op) {
39fe7c0d90SRiver Riddle     if (op->hasTrait<OpTrait::IsTerminator>())
408d541a1fSJacques Pienaar       return true;
41e534cee2SJacques Pienaar     if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) {
42e534cee2SJacques Pienaar       op->emitRemark() << "implements InferType op interface";
438d541a1fSJacques Pienaar       return true;
448d541a1fSJacques Pienaar     }
458d541a1fSJacques Pienaar     if (auto fn = shapeFnLib.getShapeFunction(op)) {
46e534cee2SJacques Pienaar       op->emitRemark() << "associated shape function: " << fn.getName();
478d541a1fSJacques Pienaar       return true;
488d541a1fSJacques Pienaar     }
498d541a1fSJacques Pienaar     if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
509bae20b5SJacques Pienaar       auto fn =
519bae20b5SJacques Pienaar           cast<shape::FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
52e534cee2SJacques Pienaar       op->emitRemark() << "associated shape function: " << fn.getName();
538d541a1fSJacques Pienaar       return true;
54e534cee2SJacques Pienaar     }
558d541a1fSJacques Pienaar     return false;
56e534cee2SJacques Pienaar   };
57e534cee2SJacques Pienaar 
588d541a1fSJacques Pienaar   // Lookup shape function library.
598d541a1fSJacques Pienaar   SmallVector<shape::FunctionLibraryOp, 4> libraries;
60*830b9b07SMehdi Amini   auto attr = module->getDiscardableAttr("shape.lib");
618d541a1fSJacques Pienaar   if (attr) {
628d541a1fSJacques Pienaar     auto lookup = [&](Attribute attr) {
638d541a1fSJacques Pienaar       return cast<shape::FunctionLibraryOp>(
645550c821STres Popp           SymbolTable::lookupSymbolIn(module, cast<SymbolRefAttr>(attr)));
658d541a1fSJacques Pienaar     };
665550c821STres Popp     if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
678d541a1fSJacques Pienaar       libraries.reserve(arrayAttr.size());
688d541a1fSJacques Pienaar       for (auto attr : arrayAttr)
698d541a1fSJacques Pienaar         libraries.push_back(lookup(attr));
708d541a1fSJacques Pienaar     } else {
718d541a1fSJacques Pienaar       libraries.reserve(1);
728d541a1fSJacques Pienaar       libraries.push_back(lookup(attr));
738d541a1fSJacques Pienaar     }
748d541a1fSJacques Pienaar   }
758d541a1fSJacques Pienaar 
7658ceae95SRiver Riddle   module.getBodyRegion().walk([&](func::FuncOp func) {
77e534cee2SJacques Pienaar     // Skip ops in the shape function library.
780bf4a82aSChristian Sigg     if (isa<shape::FunctionLibraryOp>(func->getParentOp()))
79e534cee2SJacques Pienaar       return;
80e534cee2SJacques Pienaar 
818d541a1fSJacques Pienaar     func.walk([&](Operation *op) {
828d541a1fSJacques Pienaar       bool found = llvm::any_of(libraries, [&](shape::FunctionLibraryOp lib) {
838d541a1fSJacques Pienaar         return remarkShapeFn(lib, op);
848d541a1fSJacques Pienaar       });
858d541a1fSJacques Pienaar       if (!found)
868d541a1fSJacques Pienaar         op->emitRemark() << "no associated way to refine shape";
878d541a1fSJacques Pienaar     });
88e534cee2SJacques Pienaar   });
89e534cee2SJacques Pienaar }
90e534cee2SJacques Pienaar 
91e534cee2SJacques Pienaar namespace mlir {
registerShapeFunctionTestPasses()92e534cee2SJacques Pienaar void registerShapeFunctionTestPasses() {
93b5e22e6dSMehdi Amini   PassRegistration<ReportShapeFnPass>();
94e534cee2SJacques Pienaar }
95e534cee2SJacques Pienaar } // namespace mlir
96