1 //===- TestTransformDialectInterpreter.cpp --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines a test pass that interprets Transform dialect operations in 10 // the module. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestTransformDialectExtension.h" 15 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 16 #include "mlir/Dialect/Transform/IR/TransformOps.h" 17 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinOps.h" 20 #include "mlir/Pass/Pass.h" 21 22 using namespace mlir; 23 24 namespace { 25 /// Simple pass that applies transform dialect ops directly contained in a 26 /// module. 27 28 template <typename Derived> 29 class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {}; 30 31 class TestTransformDialectInterpreterPass 32 : public transform::TransformInterpreterPassBase< 33 TestTransformDialectInterpreterPass, OpPassWrapper> { 34 public: 35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 36 TestTransformDialectInterpreterPass) 37 38 TestTransformDialectInterpreterPass() = default; 39 TestTransformDialectInterpreterPass( 40 const TestTransformDialectInterpreterPass &pass) 41 : TransformInterpreterPassBase(pass) {} 42 43 StringRef getArgument() const override { 44 return "test-transform-dialect-interpreter"; 45 } 46 47 StringRef getDescription() const override { 48 return "apply transform dialect operations one by one"; 49 } 50 51 void getDependentDialects(DialectRegistry ®istry) const override { 52 registry.insert<transform::TransformDialect>(); 53 } 54 55 void findOperationsByName(Operation *root, StringRef name, 56 SmallVectorImpl<Operation *> &operations) { 57 root->walk([&](Operation *op) { 58 if (op->getName().getStringRef() == name) { 59 operations.push_back(op); 60 } 61 }); 62 } 63 64 void createParameterMapping(MLIRContext &context, ArrayRef<int> values, 65 RaggedArray<transform::MappedValue> &result) { 66 SmallVector<transform::MappedValue> storage = 67 llvm::to_vector(llvm::map_range(values, [&](int v) { 68 Builder b(&context); 69 return transform::MappedValue(b.getI64IntegerAttr(v)); 70 })); 71 result.push_back(std::move(storage)); 72 } 73 74 void 75 createOpResultMapping(Operation *root, StringRef name, 76 RaggedArray<transform::MappedValue> &extraMapping) { 77 SmallVector<Operation *> operations; 78 findOperationsByName(root, name, operations); 79 SmallVector<Value> results; 80 for (Operation *op : operations) 81 llvm::append_range(results, op->getResults()); 82 extraMapping.push_back(results); 83 } 84 85 unsigned numberOfSetOptions(const Option<std::string> &ops, 86 const ListOption<int> ¶ms, 87 const Option<std::string> &values) { 88 unsigned numSetValues = 0; 89 numSetValues += !ops.empty(); 90 numSetValues += !params.empty(); 91 numSetValues += !values.empty(); 92 return numSetValues; 93 } 94 95 std::optional<LogicalResult> constructTransformModule(OpBuilder &builder, 96 Location loc) { 97 if (!testModuleGeneration) 98 return std::nullopt; 99 100 builder.create<transform::SequenceOp>( 101 loc, TypeRange(), transform::FailurePropagationMode::Propagate, 102 builder.getType<transform::AnyOpType>(), 103 [](OpBuilder &b, Location nested, Value rootH) { 104 b.create<mlir::test::TestPrintRemarkAtOperandOp>( 105 nested, rootH, "remark from generated"); 106 b.create<transform::YieldOp>(nested, ValueRange()); 107 }); 108 return success(); 109 } 110 111 void runOnOperation() override { 112 unsigned firstSetOptions = 113 numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams, 114 bindFirstExtraToResultsOfOps); 115 unsigned secondSetOptions = 116 numberOfSetOptions(bindSecondExtraToOps, bindSecondExtraToParams, 117 bindSecondExtraToResultsOfOps); 118 auto loc = UnknownLoc::get(&getContext()); 119 if (firstSetOptions > 1) { 120 emitError(loc) << "cannot bind the first extra top-level argument to " 121 "multiple entities"; 122 return signalPassFailure(); 123 } 124 if (secondSetOptions > 1) { 125 emitError(loc) << "cannot bind the second extra top-level argument to " 126 "multiple entities"; 127 return signalPassFailure(); 128 } 129 if (firstSetOptions == 0 && secondSetOptions != 0) { 130 emitError(loc) << "cannot bind the second extra top-level argument " 131 "without bindings the first"; 132 } 133 134 RaggedArray<transform::MappedValue> extraMapping; 135 if (!bindFirstExtraToOps.empty()) { 136 SmallVector<Operation *> operations; 137 findOperationsByName(getOperation(), bindFirstExtraToOps.getValue(), 138 operations); 139 extraMapping.push_back(operations); 140 } else if (!bindFirstExtraToParams.empty()) { 141 createParameterMapping(getContext(), bindFirstExtraToParams, 142 extraMapping); 143 } else if (!bindFirstExtraToResultsOfOps.empty()) { 144 createOpResultMapping(getOperation(), bindFirstExtraToResultsOfOps, 145 extraMapping); 146 } 147 148 if (!bindSecondExtraToOps.empty()) { 149 SmallVector<Operation *> operations; 150 findOperationsByName(getOperation(), bindSecondExtraToOps, operations); 151 extraMapping.push_back(operations); 152 } else if (!bindSecondExtraToParams.empty()) { 153 createParameterMapping(getContext(), bindSecondExtraToParams, 154 extraMapping); 155 } else if (!bindSecondExtraToResultsOfOps.empty()) { 156 createOpResultMapping(getOperation(), bindSecondExtraToResultsOfOps, 157 extraMapping); 158 } 159 160 options = options.enableExpensiveChecks(enableExpensiveChecks); 161 options = options.enableEnforceSingleToplevelTransformOp( 162 enforceSingleToplevelTransformOp); 163 if (failed(transform::detail::interpreterBaseRunOnOperationImpl( 164 getOperation(), getArgument(), getSharedTransformModule(), 165 getTransformLibraryModule(), extraMapping, options, 166 transformFileName, transformLibraryPaths, debugPayloadRootTag, 167 debugTransformRootTag, getBinaryName()))) 168 return signalPassFailure(); 169 } 170 171 Option<bool> enableExpensiveChecks{ 172 *this, "enable-expensive-checks", llvm::cl::init(false), 173 llvm::cl::desc("perform expensive checks to better report errors in the " 174 "transform IR")}; 175 Option<bool> enforceSingleToplevelTransformOp{ 176 *this, "enforce-single-top-level-transform-op", llvm::cl::init(true), 177 llvm::cl::desc("Ensure that only a single top-level transform op is " 178 "present in the IR.")}; 179 180 Option<std::string> bindFirstExtraToOps{ 181 *this, "bind-first-extra-to-ops", 182 llvm::cl::desc("bind the first extra argument of the top-level op to " 183 "payload operations of the given kind")}; 184 ListOption<int> bindFirstExtraToParams{ 185 *this, "bind-first-extra-to-params", 186 llvm::cl::desc("bind the first extra argument of the top-level op to " 187 "the given integer parameters")}; 188 Option<std::string> bindFirstExtraToResultsOfOps{ 189 *this, "bind-first-extra-to-results-of-ops", 190 llvm::cl::desc("bind the first extra argument of the top-level op to " 191 "results of payload operations of the given kind")}; 192 193 Option<std::string> bindSecondExtraToOps{ 194 *this, "bind-second-extra-to-ops", 195 llvm::cl::desc("bind the second extra argument of the top-level op to " 196 "payload operations of the given kind")}; 197 ListOption<int> bindSecondExtraToParams{ 198 *this, "bind-second-extra-to-params", 199 llvm::cl::desc("bind the second extra argument of the top-level op to " 200 "the given integer parameters")}; 201 Option<std::string> bindSecondExtraToResultsOfOps{ 202 *this, "bind-second-extra-to-results-of-ops", 203 llvm::cl::desc("bind the second extra argument of the top-level op to " 204 "results of payload operations of the given kind")}; 205 206 Option<std::string> transformFileName{ 207 *this, "transform-file-name", llvm::cl::init(""), 208 llvm::cl::desc( 209 "Optional filename containing a transform dialect specification to " 210 "apply. If left empty, the IR is assumed to contain one top-level " 211 "transform dialect operation somewhere in the module.")}; 212 Option<std::string> debugPayloadRootTag{ 213 *this, "debug-payload-root-tag", llvm::cl::init(""), 214 llvm::cl::desc( 215 "Select the operation with 'transform.target_tag' attribute having " 216 "the given value as payload IR root. If empty select the pass anchor " 217 "operation as the payload IR root.")}; 218 Option<std::string> debugTransformRootTag{ 219 *this, "debug-transform-root-tag", llvm::cl::init(""), 220 llvm::cl::desc( 221 "Select the operation with 'transform.target_tag' attribute having " 222 "the given value as container IR for top-level transform ops. This " 223 "allows user control on what transformation to apply. If empty, " 224 "select the container of the top-level transform op.")}; 225 ListOption<std::string> transformLibraryPaths{ 226 *this, "transform-library-paths", llvm::cl::ZeroOrMore, 227 llvm::cl::desc("Optional paths to files with modules that should be " 228 "merged into the transform module to provide the " 229 "definitions of external named sequences.")}; 230 231 Option<bool> testModuleGeneration{ 232 *this, "test-module-generation", llvm::cl::init(false), 233 llvm::cl::desc("test the generation of the transform module during pass " 234 "initialization, overridden by parsing")}; 235 }; 236 237 struct TestTransformDialectEraseSchedulePass 238 : public PassWrapper<TestTransformDialectEraseSchedulePass, 239 OperationPass<ModuleOp>> { 240 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 241 TestTransformDialectEraseSchedulePass) 242 243 StringRef getArgument() const final { 244 return "test-transform-dialect-erase-schedule"; 245 } 246 247 StringRef getDescription() const final { 248 return "erase transform dialect schedule from the IR"; 249 } 250 251 void runOnOperation() override { 252 getOperation()->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) { 253 if (isa<transform::TransformOpInterface>(nestedOp)) { 254 nestedOp->erase(); 255 return WalkResult::skip(); 256 } 257 return WalkResult::advance(); 258 }); 259 } 260 }; 261 } // namespace 262 263 namespace mlir { 264 namespace test { 265 /// Registers the test pass for erasing transform dialect ops. 266 void registerTestTransformDialectEraseSchedulePass() { 267 PassRegistration<TestTransformDialectEraseSchedulePass> reg; 268 } 269 /// Registers the test pass for applying transform dialect ops. 270 void registerTestTransformDialectInterpreterPass() { 271 PassRegistration<TestTransformDialectInterpreterPass> reg; 272 } 273 } // namespace test 274 } // namespace mlir 275