1 //===- TestTransformDialectInterpreter.cpp --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines a test pass that interprets Transform dialect operations in 10 // the module. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestTransformDialectExtension.h" 15 #include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.h" 16 #include "mlir/Dialect/Transform/IR/TransformOps.h" 17 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 18 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/BuiltinOps.h" 21 #include "mlir/Pass/Pass.h" 22 23 using namespace mlir; 24 25 namespace { 26 /// Simple pass that applies transform dialect ops directly contained in a 27 /// module. 28 29 template <typename Derived> 30 class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {}; 31 32 class TestTransformDialectInterpreterPass 33 : public transform::TransformInterpreterPassBase< 34 TestTransformDialectInterpreterPass, OpPassWrapper> { 35 public: 36 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 37 TestTransformDialectInterpreterPass) 38 39 TestTransformDialectInterpreterPass() = default; 40 TestTransformDialectInterpreterPass( 41 const TestTransformDialectInterpreterPass &pass) 42 : TransformInterpreterPassBase(pass) {} 43 44 StringRef getArgument() const override { 45 return "test-transform-dialect-interpreter"; 46 } 47 48 StringRef getDescription() const override { 49 return "apply transform dialect operations one by one"; 50 } 51 52 void getDependentDialects(DialectRegistry ®istry) const override { 53 registry.insert<transform::TransformDialect>(); 54 } 55 56 void findOperationsByName(Operation *root, StringRef name, 57 SmallVectorImpl<Operation *> &operations) { 58 root->walk([&](Operation *op) { 59 if (op->getName().getStringRef() == name) { 60 operations.push_back(op); 61 } 62 }); 63 } 64 65 void createParameterMapping(MLIRContext &context, ArrayRef<int> values, 66 RaggedArray<transform::MappedValue> &result) { 67 SmallVector<transform::MappedValue> storage = 68 llvm::to_vector(llvm::map_range(values, [&](int v) { 69 Builder b(&context); 70 return transform::MappedValue(b.getI64IntegerAttr(v)); 71 })); 72 result.push_back(std::move(storage)); 73 } 74 75 void 76 createOpResultMapping(Operation *root, StringRef name, 77 RaggedArray<transform::MappedValue> &extraMapping) { 78 SmallVector<Operation *> operations; 79 findOperationsByName(root, name, operations); 80 SmallVector<Value> results; 81 for (Operation *op : operations) 82 llvm::append_range(results, op->getResults()); 83 extraMapping.push_back(results); 84 } 85 86 unsigned numberOfSetOptions(const Option<std::string> &ops, 87 const ListOption<int> ¶ms, 88 const Option<std::string> &values) { 89 unsigned numSetValues = 0; 90 numSetValues += !ops.empty(); 91 numSetValues += !params.empty(); 92 numSetValues += !values.empty(); 93 return numSetValues; 94 } 95 96 std::optional<LogicalResult> constructTransformModule(OpBuilder &builder, 97 Location loc) { 98 if (!testModuleGeneration) 99 return std::nullopt; 100 101 builder.create<transform::SequenceOp>( 102 loc, TypeRange(), transform::FailurePropagationMode::Propagate, 103 builder.getType<transform::AnyOpType>(), 104 [](OpBuilder &b, Location nested, Value rootH) { 105 b.create<transform::DebugEmitRemarkAtOp>(nested, rootH, 106 "remark from generated"); 107 b.create<transform::YieldOp>(nested, ValueRange()); 108 }); 109 return success(); 110 } 111 112 void runOnOperation() override { 113 unsigned firstSetOptions = 114 numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams, 115 bindFirstExtraToResultsOfOps); 116 unsigned secondSetOptions = 117 numberOfSetOptions(bindSecondExtraToOps, bindSecondExtraToParams, 118 bindSecondExtraToResultsOfOps); 119 auto loc = UnknownLoc::get(&getContext()); 120 if (firstSetOptions > 1) { 121 emitError(loc) << "cannot bind the first extra top-level argument to " 122 "multiple entities"; 123 return signalPassFailure(); 124 } 125 if (secondSetOptions > 1) { 126 emitError(loc) << "cannot bind the second extra top-level argument to " 127 "multiple entities"; 128 return signalPassFailure(); 129 } 130 if (firstSetOptions == 0 && secondSetOptions != 0) { 131 emitError(loc) << "cannot bind the second extra top-level argument " 132 "without bindings the first"; 133 } 134 135 RaggedArray<transform::MappedValue> extraMapping; 136 if (!bindFirstExtraToOps.empty()) { 137 SmallVector<Operation *> operations; 138 findOperationsByName(getOperation(), bindFirstExtraToOps.getValue(), 139 operations); 140 extraMapping.push_back(operations); 141 } else if (!bindFirstExtraToParams.empty()) { 142 createParameterMapping(getContext(), bindFirstExtraToParams, 143 extraMapping); 144 } else if (!bindFirstExtraToResultsOfOps.empty()) { 145 createOpResultMapping(getOperation(), bindFirstExtraToResultsOfOps, 146 extraMapping); 147 } 148 149 if (!bindSecondExtraToOps.empty()) { 150 SmallVector<Operation *> operations; 151 findOperationsByName(getOperation(), bindSecondExtraToOps, operations); 152 extraMapping.push_back(operations); 153 } else if (!bindSecondExtraToParams.empty()) { 154 createParameterMapping(getContext(), bindSecondExtraToParams, 155 extraMapping); 156 } else if (!bindSecondExtraToResultsOfOps.empty()) { 157 createOpResultMapping(getOperation(), bindSecondExtraToResultsOfOps, 158 extraMapping); 159 } 160 161 options = options.enableExpensiveChecks(enableExpensiveChecks); 162 options = options.enableEnforceSingleToplevelTransformOp( 163 enforceSingleToplevelTransformOp); 164 if (failed(transform::detail::interpreterBaseRunOnOperationImpl( 165 getOperation(), getArgument(), getSharedTransformModule(), 166 getTransformLibraryModule(), extraMapping, options, 167 transformFileName, transformLibraryPaths, debugPayloadRootTag, 168 debugTransformRootTag, getBinaryName()))) 169 return signalPassFailure(); 170 } 171 172 Option<bool> enableExpensiveChecks{ 173 *this, "enable-expensive-checks", llvm::cl::init(false), 174 llvm::cl::desc("perform expensive checks to better report errors in the " 175 "transform IR")}; 176 Option<bool> enforceSingleToplevelTransformOp{ 177 *this, "enforce-single-top-level-transform-op", llvm::cl::init(true), 178 llvm::cl::desc("Ensure that only a single top-level transform op is " 179 "present in the IR.")}; 180 181 Option<std::string> bindFirstExtraToOps{ 182 *this, "bind-first-extra-to-ops", 183 llvm::cl::desc("bind the first extra argument of the top-level op to " 184 "payload operations of the given kind")}; 185 ListOption<int> bindFirstExtraToParams{ 186 *this, "bind-first-extra-to-params", 187 llvm::cl::desc("bind the first extra argument of the top-level op to " 188 "the given integer parameters")}; 189 Option<std::string> bindFirstExtraToResultsOfOps{ 190 *this, "bind-first-extra-to-results-of-ops", 191 llvm::cl::desc("bind the first extra argument of the top-level op to " 192 "results of payload operations of the given kind")}; 193 194 Option<std::string> bindSecondExtraToOps{ 195 *this, "bind-second-extra-to-ops", 196 llvm::cl::desc("bind the second extra argument of the top-level op to " 197 "payload operations of the given kind")}; 198 ListOption<int> bindSecondExtraToParams{ 199 *this, "bind-second-extra-to-params", 200 llvm::cl::desc("bind the second extra argument of the top-level op to " 201 "the given integer parameters")}; 202 Option<std::string> bindSecondExtraToResultsOfOps{ 203 *this, "bind-second-extra-to-results-of-ops", 204 llvm::cl::desc("bind the second extra argument of the top-level op to " 205 "results of payload operations of the given kind")}; 206 207 Option<std::string> transformFileName{ 208 *this, "transform-file-name", llvm::cl::init(""), 209 llvm::cl::desc( 210 "Optional filename containing a transform dialect specification to " 211 "apply. If left empty, the IR is assumed to contain one top-level " 212 "transform dialect operation somewhere in the module.")}; 213 Option<std::string> debugPayloadRootTag{ 214 *this, "debug-payload-root-tag", llvm::cl::init(""), 215 llvm::cl::desc( 216 "Select the operation with 'transform.target_tag' attribute having " 217 "the given value as payload IR root. If empty select the pass anchor " 218 "operation as the payload IR root.")}; 219 Option<std::string> debugTransformRootTag{ 220 *this, "debug-transform-root-tag", llvm::cl::init(""), 221 llvm::cl::desc( 222 "Select the operation with 'transform.target_tag' attribute having " 223 "the given value as container IR for top-level transform ops. This " 224 "allows user control on what transformation to apply. If empty, " 225 "select the container of the top-level transform op.")}; 226 ListOption<std::string> transformLibraryPaths{ 227 *this, "transform-library-paths", llvm::cl::ZeroOrMore, 228 llvm::cl::desc("Optional paths to files with modules that should be " 229 "merged into the transform module to provide the " 230 "definitions of external named sequences.")}; 231 232 Option<bool> testModuleGeneration{ 233 *this, "test-module-generation", llvm::cl::init(false), 234 llvm::cl::desc("test the generation of the transform module during pass " 235 "initialization, overridden by parsing")}; 236 }; 237 238 struct TestTransformDialectEraseSchedulePass 239 : public PassWrapper<TestTransformDialectEraseSchedulePass, 240 OperationPass<ModuleOp>> { 241 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 242 TestTransformDialectEraseSchedulePass) 243 244 StringRef getArgument() const final { 245 return "test-transform-dialect-erase-schedule"; 246 } 247 248 StringRef getDescription() const final { 249 return "erase transform dialect schedule from the IR"; 250 } 251 252 void runOnOperation() override { 253 getOperation()->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) { 254 if (isa<transform::TransformOpInterface>(nestedOp)) { 255 nestedOp->erase(); 256 return WalkResult::skip(); 257 } 258 return WalkResult::advance(); 259 }); 260 } 261 }; 262 } // namespace 263 264 namespace mlir { 265 namespace test { 266 /// Registers the test pass for erasing transform dialect ops. 267 void registerTestTransformDialectEraseSchedulePass() { 268 PassRegistration<TestTransformDialectEraseSchedulePass> reg; 269 } 270 /// Registers the test pass for applying transform dialect ops. 271 void registerTestTransformDialectInterpreterPass() { 272 PassRegistration<TestTransformDialectInterpreterPass> reg; 273 } 274 } // namespace test 275 } // namespace mlir 276