1 //===- TestShapeFunctions.cpp - Passes to test shape function ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <queue> 10 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/Shape/IR/Shape.h" 13 #include "mlir/IR/BuiltinDialect.h" 14 #include "mlir/Interfaces/InferTypeOpInterface.h" 15 #include "mlir/Pass/Pass.h" 16 17 using namespace mlir; 18 19 namespace { 20 /// This is a pass that reports shape functions associated with ops. 21 struct ReportShapeFnPass 22 : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> { 23 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReportShapeFnPass) 24 25 void runOnOperation() override; getArgument__anon0fb73c7a0111::ReportShapeFnPass26 StringRef getArgument() const final { return "test-shape-function-report"; } getDescription__anon0fb73c7a0111::ReportShapeFnPass27 StringRef getDescription() const final { 28 return "Test pass to report associated shape functions"; 29 } 30 }; 31 } // namespace 32 runOnOperation()33void ReportShapeFnPass::runOnOperation() { 34 auto module = getOperation(); 35 36 // Report the shape function available to refine the op. 37 auto shapeFnId = StringAttr::get(&getContext(), "shape.function"); 38 auto remarkShapeFn = [&](shape::FunctionLibraryOp shapeFnLib, Operation *op) { 39 if (op->hasTrait<OpTrait::IsTerminator>()) 40 return true; 41 if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) { 42 op->emitRemark() << "implements InferType op interface"; 43 return true; 44 } 45 if (auto fn = shapeFnLib.getShapeFunction(op)) { 46 op->emitRemark() << "associated shape function: " << fn.getName(); 47 return true; 48 } 49 if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) { 50 auto fn = 51 cast<shape::FuncOp>(SymbolTable::lookupSymbolIn(module, symbol)); 52 op->emitRemark() << "associated shape function: " << fn.getName(); 53 return true; 54 } 55 return false; 56 }; 57 58 // Lookup shape function library. 59 SmallVector<shape::FunctionLibraryOp, 4> libraries; 60 auto attr = module->getDiscardableAttr("shape.lib"); 61 if (attr) { 62 auto lookup = [&](Attribute attr) { 63 return cast<shape::FunctionLibraryOp>( 64 SymbolTable::lookupSymbolIn(module, cast<SymbolRefAttr>(attr))); 65 }; 66 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) { 67 libraries.reserve(arrayAttr.size()); 68 for (auto attr : arrayAttr) 69 libraries.push_back(lookup(attr)); 70 } else { 71 libraries.reserve(1); 72 libraries.push_back(lookup(attr)); 73 } 74 } 75 76 module.getBodyRegion().walk([&](func::FuncOp func) { 77 // Skip ops in the shape function library. 78 if (isa<shape::FunctionLibraryOp>(func->getParentOp())) 79 return; 80 81 func.walk([&](Operation *op) { 82 bool found = llvm::any_of(libraries, [&](shape::FunctionLibraryOp lib) { 83 return remarkShapeFn(lib, op); 84 }); 85 if (!found) 86 op->emitRemark() << "no associated way to refine shape"; 87 }); 88 }); 89 } 90 91 namespace mlir { registerShapeFunctionTestPasses()92void registerShapeFunctionTestPasses() { 93 PassRegistration<ReportShapeFnPass>(); 94 } 95 } // namespace mlir 96