1 //===- TestTransformDialectInterpreter.cpp --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines a test pass that interprets Transform dialect operations in 10 // the module. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 15 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/Pass/Pass.h" 19 20 using namespace mlir; 21 22 namespace { 23 /// Simple pass that applies transform dialect ops directly contained in a 24 /// module. 25 26 template <typename Derived> 27 class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {}; 28 29 class TestTransformDialectInterpreterPass 30 : public transform::TransformInterpreterPassBase< 31 TestTransformDialectInterpreterPass, OpPassWrapper> { 32 public: 33 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 34 TestTransformDialectInterpreterPass) 35 36 TestTransformDialectInterpreterPass() = default; 37 TestTransformDialectInterpreterPass( 38 const TestTransformDialectInterpreterPass &pass) 39 : TransformInterpreterPassBase(pass) {} 40 41 StringRef getArgument() const override { 42 return "test-transform-dialect-interpreter"; 43 } 44 45 StringRef getDescription() const override { 46 return "apply transform dialect operations one by one"; 47 } 48 49 void findOperationsByName(Operation *root, StringRef name, 50 SmallVectorImpl<Operation *> &operations) { 51 root->walk([&](Operation *op) { 52 if (op->getName().getStringRef() == name) { 53 operations.push_back(op); 54 } 55 }); 56 } 57 58 void createParameterMapping(MLIRContext &context, ArrayRef<int> values, 59 RaggedArray<transform::MappedValue> &result) { 60 SmallVector<transform::MappedValue> storage = 61 llvm::to_vector(llvm::map_range(values, [&](int v) { 62 Builder b(&context); 63 return transform::MappedValue(b.getI64IntegerAttr(v)); 64 })); 65 result.push_back(std::move(storage)); 66 } 67 68 void 69 createOpResultMapping(Operation *root, StringRef name, 70 RaggedArray<transform::MappedValue> &extraMapping) { 71 SmallVector<Operation *> operations; 72 findOperationsByName(root, name, operations); 73 SmallVector<Value> results; 74 for (Operation *op : operations) 75 llvm::append_range(results, op->getResults()); 76 extraMapping.push_back(results); 77 } 78 79 unsigned numberOfSetOptions(const Option<std::string> &ops, 80 const ListOption<int> ¶ms, 81 const Option<std::string> &values) { 82 unsigned numSetValues = 0; 83 numSetValues += !ops.empty(); 84 numSetValues += !params.empty(); 85 numSetValues += !values.empty(); 86 return numSetValues; 87 } 88 89 void runOnOperation() override { 90 unsigned firstSetOptions = 91 numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams, 92 bindFirstExtraToResultsOfOps); 93 unsigned secondSetOptions = 94 numberOfSetOptions(bindSecondExtraToOps, bindSecondExtraToParams, 95 bindSecondExtraToResultsOfOps); 96 auto loc = UnknownLoc::get(&getContext()); 97 if (firstSetOptions > 1) { 98 emitError(loc) << "cannot bind the first extra top-level argument to " 99 "multiple entities"; 100 return signalPassFailure(); 101 } 102 if (secondSetOptions > 1) { 103 emitError(loc) << "cannot bind the second extra top-level argument to " 104 "multiple entities"; 105 return signalPassFailure(); 106 } 107 if (firstSetOptions == 0 && secondSetOptions != 0) { 108 emitError(loc) << "cannot bind the second extra top-level argument " 109 "without bindings the first"; 110 } 111 112 RaggedArray<transform::MappedValue> extraMapping; 113 if (!bindFirstExtraToOps.empty()) { 114 SmallVector<Operation *> operations; 115 findOperationsByName(getOperation(), bindFirstExtraToOps.getValue(), 116 operations); 117 extraMapping.push_back(operations); 118 } else if (!bindFirstExtraToParams.empty()) { 119 createParameterMapping(getContext(), bindFirstExtraToParams, 120 extraMapping); 121 } else if (!bindFirstExtraToResultsOfOps.empty()) { 122 createOpResultMapping(getOperation(), bindFirstExtraToResultsOfOps, 123 extraMapping); 124 } 125 126 if (!bindSecondExtraToOps.empty()) { 127 SmallVector<Operation *> operations; 128 findOperationsByName(getOperation(), bindSecondExtraToOps, operations); 129 extraMapping.push_back(operations); 130 } else if (!bindSecondExtraToParams.empty()) { 131 createParameterMapping(getContext(), bindSecondExtraToParams, 132 extraMapping); 133 } else if (!bindSecondExtraToResultsOfOps.empty()) { 134 createOpResultMapping(getOperation(), bindSecondExtraToResultsOfOps, 135 extraMapping); 136 } 137 138 options = options.enableExpensiveChecks(enableExpensiveChecks); 139 if (failed(transform::detail::interpreterBaseRunOnOperationImpl( 140 getOperation(), getArgument(), getSharedTransformModule(), 141 extraMapping, options, transformFileName, debugPayloadRootTag, 142 debugTransformRootTag, getBinaryName()))) 143 return signalPassFailure(); 144 } 145 146 Option<bool> enableExpensiveChecks{ 147 *this, "enable-expensive-checks", llvm::cl::init(false), 148 llvm::cl::desc("perform expensive checks to better report errors in the " 149 "transform IR")}; 150 151 Option<std::string> bindFirstExtraToOps{ 152 *this, "bind-first-extra-to-ops", 153 llvm::cl::desc("bind the first extra argument of the top-level op to " 154 "payload operations of the given kind")}; 155 ListOption<int> bindFirstExtraToParams{ 156 *this, "bind-first-extra-to-params", 157 llvm::cl::desc("bind the first extra argument of the top-level op to " 158 "the given integer parameters")}; 159 Option<std::string> bindFirstExtraToResultsOfOps{ 160 *this, "bind-first-extra-to-results-of-ops", 161 llvm::cl::desc("bind the first extra argument of the top-level op to " 162 "results of payload operations of the given kind")}; 163 164 Option<std::string> bindSecondExtraToOps{ 165 *this, "bind-second-extra-to-ops", 166 llvm::cl::desc("bind the second extra argument of the top-level op to " 167 "payload operations of the given kind")}; 168 ListOption<int> bindSecondExtraToParams{ 169 *this, "bind-second-extra-to-params", 170 llvm::cl::desc("bind the second extra argument of the top-level op to " 171 "the given integer parameters")}; 172 Option<std::string> bindSecondExtraToResultsOfOps{ 173 *this, "bind-second-extra-to-results-of-ops", 174 llvm::cl::desc("bind the second extra argument of the top-level op to " 175 "results of payload operations of the given kind")}; 176 177 Option<std::string> transformFileName{ 178 *this, "transform-file-name", llvm::cl::init(""), 179 llvm::cl::desc( 180 "Optional filename containing a transform dialect specification to " 181 "apply. If left empty, the IR is assumed to contain one top-level " 182 "transform dialect operation somewhere in the module.")}; 183 Option<std::string> debugPayloadRootTag{ 184 *this, "debug-payload-root-tag", llvm::cl::init(""), 185 llvm::cl::desc( 186 "Select the operation with 'transform.target_tag' attribute having " 187 "the given value as payload IR root. If empty select the pass anchor " 188 "operation as the payload IR root.")}; 189 Option<std::string> debugTransformRootTag{ 190 *this, "debug-transform-root-tag", llvm::cl::init(""), 191 llvm::cl::desc( 192 "Select the operation with 'transform.target_tag' attribute having " 193 "the given value as container IR for top-level transform ops. This " 194 "allows user control on what transformation to apply. If empty, " 195 "select the container of the top-level transform op.")}; 196 }; 197 198 struct TestTransformDialectEraseSchedulePass 199 : public PassWrapper<TestTransformDialectEraseSchedulePass, 200 OperationPass<ModuleOp>> { 201 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 202 TestTransformDialectEraseSchedulePass) 203 204 StringRef getArgument() const final { 205 return "test-transform-dialect-erase-schedule"; 206 } 207 208 StringRef getDescription() const final { 209 return "erase transform dialect schedule from the IR"; 210 } 211 212 void runOnOperation() override { 213 getOperation()->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) { 214 if (isa<transform::TransformOpInterface>(nestedOp)) { 215 nestedOp->erase(); 216 return WalkResult::skip(); 217 } 218 return WalkResult::advance(); 219 }); 220 } 221 }; 222 } // namespace 223 224 namespace mlir { 225 namespace test { 226 /// Registers the test pass for erasing transform dialect ops. 227 void registerTestTransformDialectEraseSchedulePass() { 228 PassRegistration<TestTransformDialectEraseSchedulePass> reg; 229 } 230 /// Registers the test pass for applying transform dialect ops. 231 void registerTestTransformDialectInterpreterPass() { 232 PassRegistration<TestTransformDialectInterpreterPass> reg; 233 } 234 } // namespace test 235 } // namespace mlir 236