//===- TestTransformDialectInterpreter.cpp --------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines a test pass that interprets Transform dialect operations in // the module. // //===----------------------------------------------------------------------===// #include "TestTransformDialectExtension.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" using namespace mlir; namespace { /// Simple pass that applies transform dialect ops directly contained in a /// module. template class OpPassWrapper : public PassWrapper> {}; class TestTransformDialectInterpreterPass : public transform::TransformInterpreterPassBase< TestTransformDialectInterpreterPass, OpPassWrapper> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestTransformDialectInterpreterPass) TestTransformDialectInterpreterPass() = default; TestTransformDialectInterpreterPass( const TestTransformDialectInterpreterPass &pass) : TransformInterpreterPassBase(pass) {} StringRef getArgument() const override { return "test-transform-dialect-interpreter"; } StringRef getDescription() const override { return "apply transform dialect operations one by one"; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void findOperationsByName(Operation *root, StringRef name, SmallVectorImpl &operations) { root->walk([&](Operation *op) { if (op->getName().getStringRef() == name) { operations.push_back(op); } }); } void createParameterMapping(MLIRContext &context, ArrayRef values, RaggedArray &result) { SmallVector storage = llvm::to_vector(llvm::map_range(values, [&](int v) { Builder b(&context); return transform::MappedValue(b.getI64IntegerAttr(v)); })); result.push_back(std::move(storage)); } void createOpResultMapping(Operation *root, StringRef name, RaggedArray &extraMapping) { SmallVector operations; findOperationsByName(root, name, operations); SmallVector results; for (Operation *op : operations) llvm::append_range(results, op->getResults()); extraMapping.push_back(results); } unsigned numberOfSetOptions(const Option &ops, const ListOption ¶ms, const Option &values) { unsigned numSetValues = 0; numSetValues += !ops.empty(); numSetValues += !params.empty(); numSetValues += !values.empty(); return numSetValues; } std::optional constructTransformModule(OpBuilder &builder, Location loc) { if (!testModuleGeneration) return std::nullopt; builder.create( loc, TypeRange(), transform::FailurePropagationMode::Propagate, builder.getType(), [](OpBuilder &b, Location nested, Value rootH) { b.create( nested, rootH, "remark from generated"); b.create(nested, ValueRange()); }); return success(); } void runOnOperation() override { unsigned firstSetOptions = numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams, bindFirstExtraToResultsOfOps); unsigned secondSetOptions = numberOfSetOptions(bindSecondExtraToOps, bindSecondExtraToParams, bindSecondExtraToResultsOfOps); auto loc = UnknownLoc::get(&getContext()); if (firstSetOptions > 1) { emitError(loc) << "cannot bind the first extra top-level argument to " "multiple entities"; return signalPassFailure(); } if (secondSetOptions > 1) { emitError(loc) << "cannot bind the second extra top-level argument to " "multiple entities"; return signalPassFailure(); } if (firstSetOptions == 0 && secondSetOptions != 0) { emitError(loc) << "cannot bind the second extra top-level argument " "without bindings the first"; } RaggedArray extraMapping; if (!bindFirstExtraToOps.empty()) { SmallVector operations; findOperationsByName(getOperation(), bindFirstExtraToOps.getValue(), operations); extraMapping.push_back(operations); } else if (!bindFirstExtraToParams.empty()) { createParameterMapping(getContext(), bindFirstExtraToParams, extraMapping); } else if (!bindFirstExtraToResultsOfOps.empty()) { createOpResultMapping(getOperation(), bindFirstExtraToResultsOfOps, extraMapping); } if (!bindSecondExtraToOps.empty()) { SmallVector operations; findOperationsByName(getOperation(), bindSecondExtraToOps, operations); extraMapping.push_back(operations); } else if (!bindSecondExtraToParams.empty()) { createParameterMapping(getContext(), bindSecondExtraToParams, extraMapping); } else if (!bindSecondExtraToResultsOfOps.empty()) { createOpResultMapping(getOperation(), bindSecondExtraToResultsOfOps, extraMapping); } options = options.enableExpensiveChecks(enableExpensiveChecks); options = options.enableEnforceSingleToplevelTransformOp( enforceSingleToplevelTransformOp); if (failed(transform::detail::interpreterBaseRunOnOperationImpl( getOperation(), getArgument(), getSharedTransformModule(), getTransformLibraryModule(), extraMapping, options, transformFileName, transformLibraryPaths, debugPayloadRootTag, debugTransformRootTag, getBinaryName()))) return signalPassFailure(); } Option enableExpensiveChecks{ *this, "enable-expensive-checks", llvm::cl::init(false), llvm::cl::desc("perform expensive checks to better report errors in the " "transform IR")}; Option enforceSingleToplevelTransformOp{ *this, "enforce-single-top-level-transform-op", llvm::cl::init(true), llvm::cl::desc("Ensure that only a single top-level transform op is " "present in the IR.")}; Option bindFirstExtraToOps{ *this, "bind-first-extra-to-ops", llvm::cl::desc("bind the first extra argument of the top-level op to " "payload operations of the given kind")}; ListOption bindFirstExtraToParams{ *this, "bind-first-extra-to-params", llvm::cl::desc("bind the first extra argument of the top-level op to " "the given integer parameters")}; Option bindFirstExtraToResultsOfOps{ *this, "bind-first-extra-to-results-of-ops", llvm::cl::desc("bind the first extra argument of the top-level op to " "results of payload operations of the given kind")}; Option bindSecondExtraToOps{ *this, "bind-second-extra-to-ops", llvm::cl::desc("bind the second extra argument of the top-level op to " "payload operations of the given kind")}; ListOption bindSecondExtraToParams{ *this, "bind-second-extra-to-params", llvm::cl::desc("bind the second extra argument of the top-level op to " "the given integer parameters")}; Option bindSecondExtraToResultsOfOps{ *this, "bind-second-extra-to-results-of-ops", llvm::cl::desc("bind the second extra argument of the top-level op to " "results of payload operations of the given kind")}; Option transformFileName{ *this, "transform-file-name", llvm::cl::init(""), llvm::cl::desc( "Optional filename containing a transform dialect specification to " "apply. If left empty, the IR is assumed to contain one top-level " "transform dialect operation somewhere in the module.")}; Option debugPayloadRootTag{ *this, "debug-payload-root-tag", llvm::cl::init(""), llvm::cl::desc( "Select the operation with 'transform.target_tag' attribute having " "the given value as payload IR root. If empty select the pass anchor " "operation as the payload IR root.")}; Option debugTransformRootTag{ *this, "debug-transform-root-tag", llvm::cl::init(""), llvm::cl::desc( "Select the operation with 'transform.target_tag' attribute having " "the given value as container IR for top-level transform ops. This " "allows user control on what transformation to apply. If empty, " "select the container of the top-level transform op.")}; ListOption transformLibraryPaths{ *this, "transform-library-paths", llvm::cl::ZeroOrMore, llvm::cl::desc("Optional paths to files with modules that should be " "merged into the transform module to provide the " "definitions of external named sequences.")}; Option testModuleGeneration{ *this, "test-module-generation", llvm::cl::init(false), llvm::cl::desc("test the generation of the transform module during pass " "initialization, overridden by parsing")}; }; struct TestTransformDialectEraseSchedulePass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestTransformDialectEraseSchedulePass) StringRef getArgument() const final { return "test-transform-dialect-erase-schedule"; } StringRef getDescription() const final { return "erase transform dialect schedule from the IR"; } void runOnOperation() override { getOperation()->walk([&](Operation *nestedOp) { if (isa(nestedOp)) { nestedOp->erase(); return WalkResult::skip(); } return WalkResult::advance(); }); } }; } // namespace namespace mlir { namespace test { /// Registers the test pass for erasing transform dialect ops. void registerTestTransformDialectEraseSchedulePass() { PassRegistration reg; } /// Registers the test pass for applying transform dialect ops. void registerTestTransformDialectInterpreterPass() { PassRegistration reg; } } // namespace test } // namespace mlir