1 //===- TestTransformDialectInterpreter.cpp --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines a test pass that interprets Transform dialect operations in 10 // the module. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 15 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/Pass/Pass.h" 19 20 using namespace mlir; 21 22 namespace { 23 /// Simple pass that applies transform dialect ops directly contained in a 24 /// module. 25 26 template <typename Derived> 27 class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {}; 28 29 class TestTransformDialectInterpreterPass 30 : public transform::TransformInterpreterPassBase< 31 TestTransformDialectInterpreterPass, OpPassWrapper> { 32 public: 33 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 34 TestTransformDialectInterpreterPass) 35 36 TestTransformDialectInterpreterPass() = default; 37 TestTransformDialectInterpreterPass( 38 const TestTransformDialectInterpreterPass &pass) 39 : TransformInterpreterPassBase(pass) {} 40 41 StringRef getArgument() const override { 42 return "test-transform-dialect-interpreter"; 43 } 44 45 StringRef getDescription() const override { 46 return "apply transform dialect operations one by one"; 47 } 48 49 void findOperationsByName(Operation *root, StringRef name, 50 SmallVectorImpl<Operation *> &operations) { 51 root->walk([&](Operation *op) { 52 if (op->getName().getStringRef() == name) { 53 operations.push_back(op); 54 } 55 }); 56 } 57 58 void createParameterMapping(MLIRContext &context, ArrayRef<int> values, 59 RaggedArray<transform::MappedValue> &result) { 60 SmallVector<transform::MappedValue> storage = 61 llvm::to_vector(llvm::map_range(values, [&](int v) { 62 Builder b(&context); 63 return transform::MappedValue(b.getI64IntegerAttr(v)); 64 })); 65 result.push_back(std::move(storage)); 66 } 67 68 void 69 createOpResultMapping(Operation *root, StringRef name, 70 RaggedArray<transform::MappedValue> &extraMapping) { 71 SmallVector<Operation *> operations; 72 findOperationsByName(root, name, operations); 73 SmallVector<Value> results; 74 for (Operation *op : operations) 75 llvm::append_range(results, op->getResults()); 76 extraMapping.push_back(results); 77 } 78 79 unsigned numberOfSetOptions(const Option<std::string> &ops, 80 const ListOption<int> ¶ms, 81 const Option<std::string> &values) { 82 unsigned numSetValues = 0; 83 numSetValues += !ops.empty(); 84 numSetValues += !params.empty(); 85 numSetValues += !values.empty(); 86 return numSetValues; 87 } 88 89 void runOnOperation() override { 90 unsigned firstSetOptions = 91 numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams, 92 bindFirstExtraToResultsOfOps); 93 unsigned secondSetOptions = 94 numberOfSetOptions(bindSecondExtraToOps, bindSecondExtraToParams, 95 bindSecondExtraToResultsOfOps); 96 auto loc = UnknownLoc::get(&getContext()); 97 if (firstSetOptions > 1) { 98 emitError(loc) << "cannot bind the first extra top-level argument to " 99 "multiple entities"; 100 return signalPassFailure(); 101 } 102 if (secondSetOptions > 1) { 103 emitError(loc) << "cannot bind the second extra top-level argument to " 104 "multiple entities"; 105 return signalPassFailure(); 106 } 107 if (firstSetOptions == 0 && secondSetOptions != 0) { 108 emitError(loc) << "cannot bind the second extra top-level argument " 109 "without bindings the first"; 110 } 111 112 RaggedArray<transform::MappedValue> extraMapping; 113 if (!bindFirstExtraToOps.empty()) { 114 SmallVector<Operation *> operations; 115 findOperationsByName(getOperation(), bindFirstExtraToOps.getValue(), 116 operations); 117 extraMapping.push_back(operations); 118 } else if (!bindFirstExtraToParams.empty()) { 119 createParameterMapping(getContext(), bindFirstExtraToParams, 120 extraMapping); 121 } else if (!bindFirstExtraToResultsOfOps.empty()) { 122 createOpResultMapping(getOperation(), bindFirstExtraToResultsOfOps, 123 extraMapping); 124 } 125 126 if (!bindSecondExtraToOps.empty()) { 127 SmallVector<Operation *> operations; 128 findOperationsByName(getOperation(), bindSecondExtraToOps, operations); 129 extraMapping.push_back(operations); 130 } else if (!bindSecondExtraToParams.empty()) { 131 createParameterMapping(getContext(), bindSecondExtraToParams, 132 extraMapping); 133 } else if (!bindSecondExtraToResultsOfOps.empty()) { 134 createOpResultMapping(getOperation(), bindSecondExtraToResultsOfOps, 135 extraMapping); 136 } 137 138 options = options.enableExpensiveChecks(enableExpensiveChecks); 139 if (failed(transform::detail::interpreterBaseRunOnOperationImpl( 140 getOperation(), getArgument(), getSharedTransformModule(), 141 getTransformLibraryModule(), extraMapping, options, 142 transformFileName, transformLibraryFileName, debugPayloadRootTag, 143 debugTransformRootTag, getBinaryName()))) 144 return signalPassFailure(); 145 } 146 147 Option<bool> enableExpensiveChecks{ 148 *this, "enable-expensive-checks", llvm::cl::init(false), 149 llvm::cl::desc("perform expensive checks to better report errors in the " 150 "transform IR")}; 151 152 Option<std::string> bindFirstExtraToOps{ 153 *this, "bind-first-extra-to-ops", 154 llvm::cl::desc("bind the first extra argument of the top-level op to " 155 "payload operations of the given kind")}; 156 ListOption<int> bindFirstExtraToParams{ 157 *this, "bind-first-extra-to-params", 158 llvm::cl::desc("bind the first extra argument of the top-level op to " 159 "the given integer parameters")}; 160 Option<std::string> bindFirstExtraToResultsOfOps{ 161 *this, "bind-first-extra-to-results-of-ops", 162 llvm::cl::desc("bind the first extra argument of the top-level op to " 163 "results of payload operations of the given kind")}; 164 165 Option<std::string> bindSecondExtraToOps{ 166 *this, "bind-second-extra-to-ops", 167 llvm::cl::desc("bind the second extra argument of the top-level op to " 168 "payload operations of the given kind")}; 169 ListOption<int> bindSecondExtraToParams{ 170 *this, "bind-second-extra-to-params", 171 llvm::cl::desc("bind the second extra argument of the top-level op to " 172 "the given integer parameters")}; 173 Option<std::string> bindSecondExtraToResultsOfOps{ 174 *this, "bind-second-extra-to-results-of-ops", 175 llvm::cl::desc("bind the second extra argument of the top-level op to " 176 "results of payload operations of the given kind")}; 177 178 Option<std::string> transformFileName{ 179 *this, "transform-file-name", llvm::cl::init(""), 180 llvm::cl::desc( 181 "Optional filename containing a transform dialect specification to " 182 "apply. If left empty, the IR is assumed to contain one top-level " 183 "transform dialect operation somewhere in the module.")}; 184 Option<std::string> debugPayloadRootTag{ 185 *this, "debug-payload-root-tag", llvm::cl::init(""), 186 llvm::cl::desc( 187 "Select the operation with 'transform.target_tag' attribute having " 188 "the given value as payload IR root. If empty select the pass anchor " 189 "operation as the payload IR root.")}; 190 Option<std::string> debugTransformRootTag{ 191 *this, "debug-transform-root-tag", llvm::cl::init(""), 192 llvm::cl::desc( 193 "Select the operation with 'transform.target_tag' attribute having " 194 "the given value as container IR for top-level transform ops. This " 195 "allows user control on what transformation to apply. If empty, " 196 "select the container of the top-level transform op.")}; 197 Option<std::string> transformLibraryFileName{ 198 *this, "transform-library-file-name", llvm::cl::init(""), 199 llvm::cl::desc( 200 "Optional name of the file containing transform dialect symbol " 201 "definitions to be injected into the transform module.")}; 202 }; 203 204 struct TestTransformDialectEraseSchedulePass 205 : public PassWrapper<TestTransformDialectEraseSchedulePass, 206 OperationPass<ModuleOp>> { 207 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 208 TestTransformDialectEraseSchedulePass) 209 210 StringRef getArgument() const final { 211 return "test-transform-dialect-erase-schedule"; 212 } 213 214 StringRef getDescription() const final { 215 return "erase transform dialect schedule from the IR"; 216 } 217 218 void runOnOperation() override { 219 getOperation()->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) { 220 if (isa<transform::TransformOpInterface>(nestedOp)) { 221 nestedOp->erase(); 222 return WalkResult::skip(); 223 } 224 return WalkResult::advance(); 225 }); 226 } 227 }; 228 } // namespace 229 230 namespace mlir { 231 namespace test { 232 /// Registers the test pass for erasing transform dialect ops. 233 void registerTestTransformDialectEraseSchedulePass() { 234 PassRegistration<TestTransformDialectEraseSchedulePass> reg; 235 } 236 /// Registers the test pass for applying transform dialect ops. 237 void registerTestTransformDialectInterpreterPass() { 238 PassRegistration<TestTransformDialectInterpreterPass> reg; 239 } 240 } // namespace test 241 } // namespace mlir 242