1e534cee2SJacques Pienaar //===- TestShapeFunctions.cpp - Passes to test shape function ------------===// 2e534cee2SJacques Pienaar // 3e534cee2SJacques Pienaar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4e534cee2SJacques Pienaar // See https://llvm.org/LICENSE.txt for license information. 5e534cee2SJacques Pienaar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6e534cee2SJacques Pienaar // 7e534cee2SJacques Pienaar //===----------------------------------------------------------------------===// 8e534cee2SJacques Pienaar 9e534cee2SJacques Pienaar #include <queue> 10e534cee2SJacques Pienaar 119bae20b5SJacques Pienaar #include "mlir/Dialect/Func/IR/FuncOps.h" 12e534cee2SJacques Pienaar #include "mlir/Dialect/Shape/IR/Shape.h" 13e534cee2SJacques Pienaar #include "mlir/IR/BuiltinDialect.h" 14e534cee2SJacques Pienaar #include "mlir/Interfaces/InferTypeOpInterface.h" 15e534cee2SJacques Pienaar #include "mlir/Pass/Pass.h" 16e534cee2SJacques Pienaar 17e534cee2SJacques Pienaar using namespace mlir; 18e534cee2SJacques Pienaar 19e534cee2SJacques Pienaar namespace { 20e534cee2SJacques Pienaar /// This is a pass that reports shape functions associated with ops. 21e534cee2SJacques Pienaar struct ReportShapeFnPass 22e534cee2SJacques Pienaar : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> { 235e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReportShapeFnPass) 245e50dd04SRiver Riddle 25e534cee2SJacques Pienaar void runOnOperation() override; getArgument__anon0fb73c7a0111::ReportShapeFnPass26b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-shape-function-report"; } getDescription__anon0fb73c7a0111::ReportShapeFnPass27b5e22e6dSMehdi Amini StringRef getDescription() const final { 28b5e22e6dSMehdi Amini return "Test pass to report associated shape functions"; 29b5e22e6dSMehdi Amini } 30e534cee2SJacques Pienaar }; 31be0a7e9fSMehdi Amini } // namespace 32e534cee2SJacques Pienaar runOnOperation()33e534cee2SJacques Pienaarvoid ReportShapeFnPass::runOnOperation() { 34e534cee2SJacques Pienaar auto module = getOperation(); 35e534cee2SJacques Pienaar 36e534cee2SJacques Pienaar // Report the shape function available to refine the op. 37195730a6SRiver Riddle auto shapeFnId = StringAttr::get(&getContext(), "shape.function"); 388d541a1fSJacques Pienaar auto remarkShapeFn = [&](shape::FunctionLibraryOp shapeFnLib, Operation *op) { 39fe7c0d90SRiver Riddle if (op->hasTrait<OpTrait::IsTerminator>()) 408d541a1fSJacques Pienaar return true; 41e534cee2SJacques Pienaar if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) { 42e534cee2SJacques Pienaar op->emitRemark() << "implements InferType op interface"; 438d541a1fSJacques Pienaar return true; 448d541a1fSJacques Pienaar } 458d541a1fSJacques Pienaar if (auto fn = shapeFnLib.getShapeFunction(op)) { 46e534cee2SJacques Pienaar op->emitRemark() << "associated shape function: " << fn.getName(); 478d541a1fSJacques Pienaar return true; 488d541a1fSJacques Pienaar } 498d541a1fSJacques Pienaar if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) { 509bae20b5SJacques Pienaar auto fn = 519bae20b5SJacques Pienaar cast<shape::FuncOp>(SymbolTable::lookupSymbolIn(module, symbol)); 52e534cee2SJacques Pienaar op->emitRemark() << "associated shape function: " << fn.getName(); 538d541a1fSJacques Pienaar return true; 54e534cee2SJacques Pienaar } 558d541a1fSJacques Pienaar return false; 56e534cee2SJacques Pienaar }; 57e534cee2SJacques Pienaar 588d541a1fSJacques Pienaar // Lookup shape function library. 598d541a1fSJacques Pienaar SmallVector<shape::FunctionLibraryOp, 4> libraries; 60*830b9b07SMehdi Amini auto attr = module->getDiscardableAttr("shape.lib"); 618d541a1fSJacques Pienaar if (attr) { 628d541a1fSJacques Pienaar auto lookup = [&](Attribute attr) { 638d541a1fSJacques Pienaar return cast<shape::FunctionLibraryOp>( 645550c821STres Popp SymbolTable::lookupSymbolIn(module, cast<SymbolRefAttr>(attr))); 658d541a1fSJacques Pienaar }; 665550c821STres Popp if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) { 678d541a1fSJacques Pienaar libraries.reserve(arrayAttr.size()); 688d541a1fSJacques Pienaar for (auto attr : arrayAttr) 698d541a1fSJacques Pienaar libraries.push_back(lookup(attr)); 708d541a1fSJacques Pienaar } else { 718d541a1fSJacques Pienaar libraries.reserve(1); 728d541a1fSJacques Pienaar libraries.push_back(lookup(attr)); 738d541a1fSJacques Pienaar } 748d541a1fSJacques Pienaar } 758d541a1fSJacques Pienaar 7658ceae95SRiver Riddle module.getBodyRegion().walk([&](func::FuncOp func) { 77e534cee2SJacques Pienaar // Skip ops in the shape function library. 780bf4a82aSChristian Sigg if (isa<shape::FunctionLibraryOp>(func->getParentOp())) 79e534cee2SJacques Pienaar return; 80e534cee2SJacques Pienaar 818d541a1fSJacques Pienaar func.walk([&](Operation *op) { 828d541a1fSJacques Pienaar bool found = llvm::any_of(libraries, [&](shape::FunctionLibraryOp lib) { 838d541a1fSJacques Pienaar return remarkShapeFn(lib, op); 848d541a1fSJacques Pienaar }); 858d541a1fSJacques Pienaar if (!found) 868d541a1fSJacques Pienaar op->emitRemark() << "no associated way to refine shape"; 878d541a1fSJacques Pienaar }); 88e534cee2SJacques Pienaar }); 89e534cee2SJacques Pienaar } 90e534cee2SJacques Pienaar 91e534cee2SJacques Pienaar namespace mlir { registerShapeFunctionTestPasses()92e534cee2SJacques Pienaarvoid registerShapeFunctionTestPasses() { 93b5e22e6dSMehdi Amini PassRegistration<ReportShapeFnPass>(); 94e534cee2SJacques Pienaar } 95e534cee2SJacques Pienaar } // namespace mlir 96