17ad9e9dcSMatteo Franciolini //===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks --===//
27ad9e9dcSMatteo Franciolini //
37ad9e9dcSMatteo Franciolini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47ad9e9dcSMatteo Franciolini // See https://llvm.org/LICENSE.txt for license information.
57ad9e9dcSMatteo Franciolini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67ad9e9dcSMatteo Franciolini //
77ad9e9dcSMatteo Franciolini //===----------------------------------------------------------------------===//
87ad9e9dcSMatteo Franciolini
97ad9e9dcSMatteo Franciolini #include "TestDialect.h"
10*e95e94adSJeff Niu #include "TestOps.h"
117ad9e9dcSMatteo Franciolini #include "mlir/Bytecode/BytecodeReader.h"
127ad9e9dcSMatteo Franciolini #include "mlir/Bytecode/BytecodeWriter.h"
137ad9e9dcSMatteo Franciolini #include "mlir/IR/BuiltinOps.h"
147ad9e9dcSMatteo Franciolini #include "mlir/IR/OperationSupport.h"
157ad9e9dcSMatteo Franciolini #include "mlir/Parser/Parser.h"
167ad9e9dcSMatteo Franciolini #include "mlir/Pass/Pass.h"
177ad9e9dcSMatteo Franciolini #include "llvm/Support/CommandLine.h"
187ad9e9dcSMatteo Franciolini #include "llvm/Support/MemoryBufferRef.h"
197ad9e9dcSMatteo Franciolini #include "llvm/Support/raw_ostream.h"
207ad9e9dcSMatteo Franciolini #include <list>
217ad9e9dcSMatteo Franciolini
227ad9e9dcSMatteo Franciolini using namespace mlir;
237ad9e9dcSMatteo Franciolini using namespace llvm;
247ad9e9dcSMatteo Franciolini
257ad9e9dcSMatteo Franciolini namespace {
267ad9e9dcSMatteo Franciolini class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
277ad9e9dcSMatteo Franciolini public:
TestDialectVersionParser(cl::Option & o)281cf4cc00SMehdi Amini TestDialectVersionParser(cl::Option &o)
291cf4cc00SMehdi Amini : cl::parser<test::TestDialectVersion>(o) {}
307ad9e9dcSMatteo Franciolini
parse(cl::Option & o,StringRef,StringRef arg,test::TestDialectVersion & v)311cf4cc00SMehdi Amini bool parse(cl::Option &o, StringRef /*argName*/, StringRef arg,
327ad9e9dcSMatteo Franciolini test::TestDialectVersion &v) {
331cf4cc00SMehdi Amini long long major, minor;
341cf4cc00SMehdi Amini if (getAsSignedInteger(arg.split(".").first, 10, major))
351cf4cc00SMehdi Amini return o.error("Invalid argument '" + arg);
361cf4cc00SMehdi Amini if (getAsSignedInteger(arg.split(".").second, 10, minor))
371cf4cc00SMehdi Amini return o.error("Invalid argument '" + arg);
381cf4cc00SMehdi Amini v = test::TestDialectVersion(major, minor);
397ad9e9dcSMatteo Franciolini // Returns true on error.
407ad9e9dcSMatteo Franciolini return false;
417ad9e9dcSMatteo Franciolini }
print(raw_ostream & os,const test::TestDialectVersion & v)427ad9e9dcSMatteo Franciolini static void print(raw_ostream &os, const test::TestDialectVersion &v) {
437ad9e9dcSMatteo Franciolini os << v.major_ << "." << v.minor_;
447ad9e9dcSMatteo Franciolini };
457ad9e9dcSMatteo Franciolini };
467ad9e9dcSMatteo Franciolini
477ad9e9dcSMatteo Franciolini /// This is a test pass which uses callbacks to encode attributes and types in a
487ad9e9dcSMatteo Franciolini /// custom fashion.
497ad9e9dcSMatteo Franciolini struct TestBytecodeRoundtripPass
507ad9e9dcSMatteo Franciolini : public PassWrapper<TestBytecodeRoundtripPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon58778feb0111::TestBytecodeRoundtripPass517ad9e9dcSMatteo Franciolini MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeRoundtripPass)
527ad9e9dcSMatteo Franciolini
537ad9e9dcSMatteo Franciolini StringRef getArgument() const final { return "test-bytecode-roundtrip"; }
getDescription__anon58778feb0111::TestBytecodeRoundtripPass547ad9e9dcSMatteo Franciolini StringRef getDescription() const final {
557ad9e9dcSMatteo Franciolini return "Test pass to implement bytecode roundtrip tests.";
567ad9e9dcSMatteo Franciolini }
getDependentDialects__anon58778feb0111::TestBytecodeRoundtripPass577ad9e9dcSMatteo Franciolini void getDependentDialects(DialectRegistry ®istry) const override {
587ad9e9dcSMatteo Franciolini registry.insert<test::TestDialect>();
597ad9e9dcSMatteo Franciolini }
607ad9e9dcSMatteo Franciolini TestBytecodeRoundtripPass() = default;
TestBytecodeRoundtripPass__anon58778feb0111::TestBytecodeRoundtripPass617ad9e9dcSMatteo Franciolini TestBytecodeRoundtripPass(const TestBytecodeRoundtripPass &) {}
627ad9e9dcSMatteo Franciolini
initialize__anon58778feb0111::TestBytecodeRoundtripPass637ad9e9dcSMatteo Franciolini LogicalResult initialize(MLIRContext *context) override {
647ad9e9dcSMatteo Franciolini testDialect = context->getOrLoadDialect<test::TestDialect>();
657ad9e9dcSMatteo Franciolini return success();
667ad9e9dcSMatteo Franciolini }
677ad9e9dcSMatteo Franciolini
runOnOperation__anon58778feb0111::TestBytecodeRoundtripPass687ad9e9dcSMatteo Franciolini void runOnOperation() override {
697ad9e9dcSMatteo Franciolini switch (testKind) {
707ad9e9dcSMatteo Franciolini // Tests 0-5 implement a custom roundtrip with callbacks.
717ad9e9dcSMatteo Franciolini case (0):
727ad9e9dcSMatteo Franciolini return runTest0(getOperation());
737ad9e9dcSMatteo Franciolini case (1):
747ad9e9dcSMatteo Franciolini return runTest1(getOperation());
757ad9e9dcSMatteo Franciolini case (2):
767ad9e9dcSMatteo Franciolini return runTest2(getOperation());
777ad9e9dcSMatteo Franciolini case (3):
787ad9e9dcSMatteo Franciolini return runTest3(getOperation());
797ad9e9dcSMatteo Franciolini case (4):
807ad9e9dcSMatteo Franciolini return runTest4(getOperation());
817ad9e9dcSMatteo Franciolini case (5):
827ad9e9dcSMatteo Franciolini return runTest5(getOperation());
837ad9e9dcSMatteo Franciolini case (6):
847ad9e9dcSMatteo Franciolini // test-kind 6 is a plain roundtrip with downgrade/upgrade to/from
857ad9e9dcSMatteo Franciolini // `targetVersion`.
867ad9e9dcSMatteo Franciolini return runTest6(getOperation());
877ad9e9dcSMatteo Franciolini default:
887ad9e9dcSMatteo Franciolini llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass");
897ad9e9dcSMatteo Franciolini }
907ad9e9dcSMatteo Franciolini }
917ad9e9dcSMatteo Franciolini
927ad9e9dcSMatteo Franciolini mlir::Pass::Option<test::TestDialectVersion, TestDialectVersionParser>
937ad9e9dcSMatteo Franciolini targetVersion{*this, "test-dialect-version",
947ad9e9dcSMatteo Franciolini llvm::cl::desc(
957ad9e9dcSMatteo Franciolini "Specifies the test dialect version to emit and parse"),
967ad9e9dcSMatteo Franciolini cl::init(test::TestDialectVersion())};
977ad9e9dcSMatteo Franciolini
987ad9e9dcSMatteo Franciolini mlir::Pass::Option<int> testKind{
997ad9e9dcSMatteo Franciolini *this, "test-kind", llvm::cl::desc("Specifies the test kind to execute"),
1007ad9e9dcSMatteo Franciolini cl::init(0)};
1017ad9e9dcSMatteo Franciolini
1027ad9e9dcSMatteo Franciolini private:
doRoundtripWithConfigs__anon58778feb0111::TestBytecodeRoundtripPass1037ad9e9dcSMatteo Franciolini void doRoundtripWithConfigs(Operation *op,
1047ad9e9dcSMatteo Franciolini const BytecodeWriterConfig &writeConfig,
1057ad9e9dcSMatteo Franciolini const ParserConfig &parseConfig) {
1067ad9e9dcSMatteo Franciolini std::string bytecode;
1077ad9e9dcSMatteo Franciolini llvm::raw_string_ostream os(bytecode);
1087ad9e9dcSMatteo Franciolini if (failed(writeBytecodeToFile(op, os, writeConfig))) {
1097ad9e9dcSMatteo Franciolini op->emitError() << "failed to write bytecode\n";
1107ad9e9dcSMatteo Franciolini signalPassFailure();
1117ad9e9dcSMatteo Franciolini return;
1127ad9e9dcSMatteo Franciolini }
1137ad9e9dcSMatteo Franciolini auto newModuleOp = parseSourceString(StringRef(bytecode), parseConfig);
1147ad9e9dcSMatteo Franciolini if (!newModuleOp.get()) {
1157ad9e9dcSMatteo Franciolini op->emitError() << "failed to read bytecode\n";
1167ad9e9dcSMatteo Franciolini signalPassFailure();
1177ad9e9dcSMatteo Franciolini return;
1187ad9e9dcSMatteo Franciolini }
1197ad9e9dcSMatteo Franciolini // Print the module to the output stream, so that we can filecheck the
1207ad9e9dcSMatteo Franciolini // result.
1217ad9e9dcSMatteo Franciolini newModuleOp->print(llvm::outs());
1227ad9e9dcSMatteo Franciolini }
1237ad9e9dcSMatteo Franciolini
1247ad9e9dcSMatteo Franciolini // Test0: let's assume that versions older than 2.0 were relying on a special
1257ad9e9dcSMatteo Franciolini // integer attribute of a deprecated dialect called "funky". Assume that its
1267ad9e9dcSMatteo Franciolini // encoding was made by two varInts, the first was the ID (999) and the second
1277ad9e9dcSMatteo Franciolini // contained width and signedness info. We can emit it using a callback
1287ad9e9dcSMatteo Franciolini // writing a custom encoding for the "funky" dialect group, and parse it back
1297ad9e9dcSMatteo Franciolini // with a custom parser reading the same encoding in the same dialect group.
1307ad9e9dcSMatteo Franciolini // Note that the ID 999 does not correspond to a valid integer type in the
1317ad9e9dcSMatteo Franciolini // current encodings of builtin types.
runTest0__anon58778feb0111::TestBytecodeRoundtripPass1327ad9e9dcSMatteo Franciolini void runTest0(Operation *op) {
1337ad9e9dcSMatteo Franciolini auto newCtx = std::make_shared<MLIRContext>();
1347ad9e9dcSMatteo Franciolini test::TestDialectVersion targetEmissionVersion = targetVersion;
1357ad9e9dcSMatteo Franciolini BytecodeWriterConfig writeConfig;
1367ad9e9dcSMatteo Franciolini // Set the emission version for the test dialect.
1377ad9e9dcSMatteo Franciolini writeConfig.setDialectVersion<test::TestDialect>(
1387ad9e9dcSMatteo Franciolini std::make_unique<test::TestDialectVersion>(targetEmissionVersion));
1397ad9e9dcSMatteo Franciolini writeConfig.attachTypeCallback(
1407ad9e9dcSMatteo Franciolini [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
1417ad9e9dcSMatteo Franciolini DialectBytecodeWriter &writer) -> LogicalResult {
1427ad9e9dcSMatteo Franciolini // Do not override anything if version greater than 2.0.
1437ad9e9dcSMatteo Franciolini auto versionOr = writer.getDialectVersion<test::TestDialect>();
1447ad9e9dcSMatteo Franciolini assert(succeeded(versionOr) && "expected reader to be able to access "
1457ad9e9dcSMatteo Franciolini "the version for test dialect");
1467ad9e9dcSMatteo Franciolini const auto *version =
1477ad9e9dcSMatteo Franciolini reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
1487ad9e9dcSMatteo Franciolini if (version->major_ >= 2)
1497ad9e9dcSMatteo Franciolini return failure();
1507ad9e9dcSMatteo Franciolini
1517ad9e9dcSMatteo Franciolini // For version less than 2.0, override the encoding of IntegerType.
1527ad9e9dcSMatteo Franciolini if (auto type = llvm::dyn_cast<IntegerType>(entryValue)) {
1537ad9e9dcSMatteo Franciolini llvm::outs() << "Overriding IntegerType encoding...\n";
1547ad9e9dcSMatteo Franciolini dialectGroupName = StringLiteral("funky");
1557ad9e9dcSMatteo Franciolini writer.writeVarInt(/* IntegerType */ 999);
1567ad9e9dcSMatteo Franciolini writer.writeVarInt(type.getWidth() << 2 | type.getSignedness());
1577ad9e9dcSMatteo Franciolini return success();
1587ad9e9dcSMatteo Franciolini }
1597ad9e9dcSMatteo Franciolini return failure();
1607ad9e9dcSMatteo Franciolini });
1617ad9e9dcSMatteo Franciolini newCtx->appendDialectRegistry(op->getContext()->getDialectRegistry());
1627ad9e9dcSMatteo Franciolini newCtx->allowUnregisteredDialects();
1637ad9e9dcSMatteo Franciolini ParserConfig parseConfig(newCtx.get(), /*verifyAfterParse=*/true);
1647ad9e9dcSMatteo Franciolini parseConfig.getBytecodeReaderConfig().attachTypeCallback(
1657ad9e9dcSMatteo Franciolini [&](DialectBytecodeReader &reader, StringRef dialectName,
1667ad9e9dcSMatteo Franciolini Type &entry) -> LogicalResult {
1677ad9e9dcSMatteo Franciolini // Get test dialect version from the version map.
1687ad9e9dcSMatteo Franciolini auto versionOr = reader.getDialectVersion<test::TestDialect>();
1697ad9e9dcSMatteo Franciolini assert(succeeded(versionOr) && "expected reader to be able to access "
1707ad9e9dcSMatteo Franciolini "the version for test dialect");
1717ad9e9dcSMatteo Franciolini const auto *version =
1727ad9e9dcSMatteo Franciolini reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
1737ad9e9dcSMatteo Franciolini if (version->major_ >= 2)
1747ad9e9dcSMatteo Franciolini return success();
1757ad9e9dcSMatteo Franciolini
1767ad9e9dcSMatteo Franciolini // `dialectName` is the name of the group we have the opportunity to
1777ad9e9dcSMatteo Franciolini // override. In this case, override only the dialect group "funky",
1787ad9e9dcSMatteo Franciolini // for which does not exist in memory.
1797ad9e9dcSMatteo Franciolini if (dialectName != StringLiteral("funky"))
1807ad9e9dcSMatteo Franciolini return success();
1817ad9e9dcSMatteo Franciolini
1827ad9e9dcSMatteo Franciolini uint64_t encoding;
1837ad9e9dcSMatteo Franciolini if (failed(reader.readVarInt(encoding)) || encoding != 999)
1847ad9e9dcSMatteo Franciolini return success();
1857ad9e9dcSMatteo Franciolini llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
1861cf4cc00SMehdi Amini uint64_t widthAndSignedness, width;
1877ad9e9dcSMatteo Franciolini IntegerType::SignednessSemantics signedness;
1881cf4cc00SMehdi Amini if (succeeded(reader.readVarInt(widthAndSignedness)) &&
1891cf4cc00SMehdi Amini ((width = widthAndSignedness >> 2), true) &&
1907ad9e9dcSMatteo Franciolini ((signedness = static_cast<IntegerType::SignednessSemantics>(
1911cf4cc00SMehdi Amini widthAndSignedness & 0x3)),
1927ad9e9dcSMatteo Franciolini true))
1937ad9e9dcSMatteo Franciolini entry = IntegerType::get(reader.getContext(), width, signedness);
1947ad9e9dcSMatteo Franciolini // Return nullopt to fall through the rest of the parsing code path.
1957ad9e9dcSMatteo Franciolini return success();
1967ad9e9dcSMatteo Franciolini });
1977ad9e9dcSMatteo Franciolini doRoundtripWithConfigs(op, writeConfig, parseConfig);
1987ad9e9dcSMatteo Franciolini }
1997ad9e9dcSMatteo Franciolini
2007ad9e9dcSMatteo Franciolini // Test1: When writing bytecode, we override the encoding of TestI32Type with
2017ad9e9dcSMatteo Franciolini // the encoding of builtin IntegerType. We can natively parse this without
2027ad9e9dcSMatteo Franciolini // the use of a callback, relying on the existing builtin reader mechanism.
runTest1__anon58778feb0111::TestBytecodeRoundtripPass2037ad9e9dcSMatteo Franciolini void runTest1(Operation *op) {
2043029acb8SMehdi Amini auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
2057ad9e9dcSMatteo Franciolini BytecodeDialectInterface *iface =
2067ad9e9dcSMatteo Franciolini builtin->getRegisteredInterface<BytecodeDialectInterface>();
2077ad9e9dcSMatteo Franciolini BytecodeWriterConfig writeConfig;
2087ad9e9dcSMatteo Franciolini writeConfig.attachTypeCallback(
2097ad9e9dcSMatteo Franciolini [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
2107ad9e9dcSMatteo Franciolini DialectBytecodeWriter &writer) -> LogicalResult {
2117ad9e9dcSMatteo Franciolini // Emit TestIntegerType using the builtin dialect encoding.
2127ad9e9dcSMatteo Franciolini if (llvm::isa<test::TestI32Type>(entryValue)) {
2137ad9e9dcSMatteo Franciolini llvm::outs() << "Overriding TestI32Type encoding...\n";
2147ad9e9dcSMatteo Franciolini auto builtinI32Type =
2157ad9e9dcSMatteo Franciolini IntegerType::get(op->getContext(), 32,
2167ad9e9dcSMatteo Franciolini IntegerType::SignednessSemantics::Signless);
2177ad9e9dcSMatteo Franciolini // Specify that this type will need to be written as part of the
2187ad9e9dcSMatteo Franciolini // builtin group. This will override the default dialect group of
2197ad9e9dcSMatteo Franciolini // the attribute (test).
2207ad9e9dcSMatteo Franciolini dialectGroupName = StringLiteral("builtin");
2217ad9e9dcSMatteo Franciolini if (succeeded(iface->writeType(builtinI32Type, writer)))
2227ad9e9dcSMatteo Franciolini return success();
2237ad9e9dcSMatteo Franciolini }
2247ad9e9dcSMatteo Franciolini return failure();
2257ad9e9dcSMatteo Franciolini });
2267ad9e9dcSMatteo Franciolini // We natively parse the attribute as a builtin, so no callback needed.
2277ad9e9dcSMatteo Franciolini ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
2287ad9e9dcSMatteo Franciolini doRoundtripWithConfigs(op, writeConfig, parseConfig);
2297ad9e9dcSMatteo Franciolini }
2307ad9e9dcSMatteo Franciolini
2317ad9e9dcSMatteo Franciolini // Test2: When writing bytecode, we write standard builtin IntegerTypes. At
2327ad9e9dcSMatteo Franciolini // parsing, we use the encoding of IntegerType to intercept all i32. Then,
2337ad9e9dcSMatteo Franciolini // instead of creating i32s, we assemble TestI32Type and return it.
runTest2__anon58778feb0111::TestBytecodeRoundtripPass2347ad9e9dcSMatteo Franciolini void runTest2(Operation *op) {
2353029acb8SMehdi Amini auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
2367ad9e9dcSMatteo Franciolini BytecodeDialectInterface *iface =
2377ad9e9dcSMatteo Franciolini builtin->getRegisteredInterface<BytecodeDialectInterface>();
2387ad9e9dcSMatteo Franciolini BytecodeWriterConfig writeConfig;
2397ad9e9dcSMatteo Franciolini ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
2407ad9e9dcSMatteo Franciolini parseConfig.getBytecodeReaderConfig().attachTypeCallback(
2417ad9e9dcSMatteo Franciolini [&](DialectBytecodeReader &reader, StringRef dialectName,
2427ad9e9dcSMatteo Franciolini Type &entry) -> LogicalResult {
2437ad9e9dcSMatteo Franciolini if (dialectName != StringLiteral("builtin"))
2447ad9e9dcSMatteo Franciolini return success();
2457ad9e9dcSMatteo Franciolini Type builtinAttr = iface->readType(reader);
2467ad9e9dcSMatteo Franciolini if (auto integerType =
2477ad9e9dcSMatteo Franciolini llvm::dyn_cast_or_null<IntegerType>(builtinAttr)) {
2487ad9e9dcSMatteo Franciolini if (integerType.getWidth() == 32 && integerType.isSignless()) {
2497ad9e9dcSMatteo Franciolini llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
2507ad9e9dcSMatteo Franciolini entry = test::TestI32Type::get(reader.getContext());
2517ad9e9dcSMatteo Franciolini }
2527ad9e9dcSMatteo Franciolini }
2537ad9e9dcSMatteo Franciolini return success();
2547ad9e9dcSMatteo Franciolini });
2557ad9e9dcSMatteo Franciolini doRoundtripWithConfigs(op, writeConfig, parseConfig);
2567ad9e9dcSMatteo Franciolini }
2577ad9e9dcSMatteo Franciolini
2587ad9e9dcSMatteo Franciolini // Test3: When writing bytecode, we override the encoding of
2597ad9e9dcSMatteo Franciolini // TestAttrParamsAttr with the encoding of builtin DenseIntElementsAttr. We
2607ad9e9dcSMatteo Franciolini // can natively parse this without the use of a callback, relying on the
2617ad9e9dcSMatteo Franciolini // existing builtin reader mechanism.
runTest3__anon58778feb0111::TestBytecodeRoundtripPass2627ad9e9dcSMatteo Franciolini void runTest3(Operation *op) {
2633029acb8SMehdi Amini auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
2647ad9e9dcSMatteo Franciolini BytecodeDialectInterface *iface =
2657ad9e9dcSMatteo Franciolini builtin->getRegisteredInterface<BytecodeDialectInterface>();
2667ad9e9dcSMatteo Franciolini auto i32Type = IntegerType::get(op->getContext(), 32,
2677ad9e9dcSMatteo Franciolini IntegerType::SignednessSemantics::Signless);
2687ad9e9dcSMatteo Franciolini BytecodeWriterConfig writeConfig;
2697ad9e9dcSMatteo Franciolini writeConfig.attachAttributeCallback(
2707ad9e9dcSMatteo Franciolini [&](Attribute entryValue, std::optional<StringRef> &dialectGroupName,
2717ad9e9dcSMatteo Franciolini DialectBytecodeWriter &writer) -> LogicalResult {
2727ad9e9dcSMatteo Franciolini // Emit TestIntegerType using the builtin dialect encoding.
2737ad9e9dcSMatteo Franciolini if (auto testParamAttrs =
2747ad9e9dcSMatteo Franciolini llvm::dyn_cast<test::TestAttrParamsAttr>(entryValue)) {
2757ad9e9dcSMatteo Franciolini llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n";
2767ad9e9dcSMatteo Franciolini // Specify that this attribute will need to be written as part of
2777ad9e9dcSMatteo Franciolini // the builtin group. This will override the default dialect group
2787ad9e9dcSMatteo Franciolini // of the attribute (test).
2797ad9e9dcSMatteo Franciolini dialectGroupName = StringLiteral("builtin");
2807ad9e9dcSMatteo Franciolini auto denseAttr = DenseIntElementsAttr::get(
2817ad9e9dcSMatteo Franciolini RankedTensorType::get({2}, i32Type),
2827ad9e9dcSMatteo Franciolini {testParamAttrs.getV0(), testParamAttrs.getV1()});
2837ad9e9dcSMatteo Franciolini if (succeeded(iface->writeAttribute(denseAttr, writer)))
2847ad9e9dcSMatteo Franciolini return success();
2857ad9e9dcSMatteo Franciolini }
2867ad9e9dcSMatteo Franciolini return failure();
2877ad9e9dcSMatteo Franciolini });
2887ad9e9dcSMatteo Franciolini // We natively parse the attribute as a builtin, so no callback needed.
2897ad9e9dcSMatteo Franciolini ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
2907ad9e9dcSMatteo Franciolini doRoundtripWithConfigs(op, writeConfig, parseConfig);
2917ad9e9dcSMatteo Franciolini }
2927ad9e9dcSMatteo Franciolini
2937ad9e9dcSMatteo Franciolini // Test4: When writing bytecode, we write standard builtin
2947ad9e9dcSMatteo Franciolini // DenseIntElementsAttr. At parsing, we use the encoding of
2957ad9e9dcSMatteo Franciolini // DenseIntElementsAttr to intercept all ElementsAttr that have shaped type of
2967ad9e9dcSMatteo Franciolini // <2xi32>. Instead of assembling a DenseIntElementsAttr, we assemble
2977ad9e9dcSMatteo Franciolini // TestAttrParamsAttr and return it.
runTest4__anon58778feb0111::TestBytecodeRoundtripPass2987ad9e9dcSMatteo Franciolini void runTest4(Operation *op) {
2993029acb8SMehdi Amini auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
3007ad9e9dcSMatteo Franciolini BytecodeDialectInterface *iface =
3017ad9e9dcSMatteo Franciolini builtin->getRegisteredInterface<BytecodeDialectInterface>();
3027ad9e9dcSMatteo Franciolini auto i32Type = IntegerType::get(op->getContext(), 32,
3037ad9e9dcSMatteo Franciolini IntegerType::SignednessSemantics::Signless);
3047ad9e9dcSMatteo Franciolini BytecodeWriterConfig writeConfig;
3057ad9e9dcSMatteo Franciolini ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
3067ad9e9dcSMatteo Franciolini parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
3077ad9e9dcSMatteo Franciolini [&](DialectBytecodeReader &reader, StringRef dialectName,
3087ad9e9dcSMatteo Franciolini Attribute &entry) -> LogicalResult {
3097ad9e9dcSMatteo Franciolini // Override only the case where the return type of the builtin reader
3107ad9e9dcSMatteo Franciolini // is an i32 and fall through on all the other cases, since we want to
3117ad9e9dcSMatteo Franciolini // still use TestDialect normal codepath to parse the other types.
3127ad9e9dcSMatteo Franciolini Attribute builtinAttr = iface->readAttribute(reader);
3137ad9e9dcSMatteo Franciolini if (auto denseAttr =
3147ad9e9dcSMatteo Franciolini llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) {
3157ad9e9dcSMatteo Franciolini if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) &&
3167ad9e9dcSMatteo Franciolini denseAttr.getElementType() == i32Type) {
3177ad9e9dcSMatteo Franciolini llvm::outs()
3187ad9e9dcSMatteo Franciolini << "Overriding parsing of TestAttrParamsAttr encoding...\n";
3197ad9e9dcSMatteo Franciolini int v0 = denseAttr.getValues<IntegerAttr>()[0].getInt();
3207ad9e9dcSMatteo Franciolini int v1 = denseAttr.getValues<IntegerAttr>()[1].getInt();
3217ad9e9dcSMatteo Franciolini entry =
3227ad9e9dcSMatteo Franciolini test::TestAttrParamsAttr::get(reader.getContext(), v0, v1);
3237ad9e9dcSMatteo Franciolini }
3247ad9e9dcSMatteo Franciolini }
3257ad9e9dcSMatteo Franciolini return success();
3267ad9e9dcSMatteo Franciolini });
3277ad9e9dcSMatteo Franciolini doRoundtripWithConfigs(op, writeConfig, parseConfig);
3287ad9e9dcSMatteo Franciolini }
3297ad9e9dcSMatteo Franciolini
3307ad9e9dcSMatteo Franciolini // Test5: When writing bytecode, we want TestDialect to use nothing else than
3317ad9e9dcSMatteo Franciolini // the builtin types and attributes and take full control of the encoding,
3327ad9e9dcSMatteo Franciolini // returning failure if any type or attribute is not part of builtin.
runTest5__anon58778feb0111::TestBytecodeRoundtripPass3337ad9e9dcSMatteo Franciolini void runTest5(Operation *op) {
3343029acb8SMehdi Amini auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
3357ad9e9dcSMatteo Franciolini BytecodeDialectInterface *iface =
3367ad9e9dcSMatteo Franciolini builtin->getRegisteredInterface<BytecodeDialectInterface>();
3377ad9e9dcSMatteo Franciolini BytecodeWriterConfig writeConfig;
3387ad9e9dcSMatteo Franciolini writeConfig.attachAttributeCallback(
3397ad9e9dcSMatteo Franciolini [&](Attribute attr, std::optional<StringRef> &dialectGroupName,
3407ad9e9dcSMatteo Franciolini DialectBytecodeWriter &writer) -> LogicalResult {
3417ad9e9dcSMatteo Franciolini return iface->writeAttribute(attr, writer);
3427ad9e9dcSMatteo Franciolini });
3437ad9e9dcSMatteo Franciolini writeConfig.attachTypeCallback(
3447ad9e9dcSMatteo Franciolini [&](Type type, std::optional<StringRef> &dialectGroupName,
3457ad9e9dcSMatteo Franciolini DialectBytecodeWriter &writer) -> LogicalResult {
3467ad9e9dcSMatteo Franciolini return iface->writeType(type, writer);
3477ad9e9dcSMatteo Franciolini });
3487ad9e9dcSMatteo Franciolini ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
3497ad9e9dcSMatteo Franciolini parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
3507ad9e9dcSMatteo Franciolini [&](DialectBytecodeReader &reader, StringRef dialectName,
3517ad9e9dcSMatteo Franciolini Attribute &entry) -> LogicalResult {
3527ad9e9dcSMatteo Franciolini Attribute builtinAttr = iface->readAttribute(reader);
3537ad9e9dcSMatteo Franciolini if (!builtinAttr)
3547ad9e9dcSMatteo Franciolini return failure();
3557ad9e9dcSMatteo Franciolini entry = builtinAttr;
3567ad9e9dcSMatteo Franciolini return success();
3577ad9e9dcSMatteo Franciolini });
3587ad9e9dcSMatteo Franciolini parseConfig.getBytecodeReaderConfig().attachTypeCallback(
3597ad9e9dcSMatteo Franciolini [&](DialectBytecodeReader &reader, StringRef dialectName,
3607ad9e9dcSMatteo Franciolini Type &entry) -> LogicalResult {
3617ad9e9dcSMatteo Franciolini Type builtinType = iface->readType(reader);
3627ad9e9dcSMatteo Franciolini if (!builtinType) {
3637ad9e9dcSMatteo Franciolini return failure();
3647ad9e9dcSMatteo Franciolini }
3657ad9e9dcSMatteo Franciolini entry = builtinType;
3667ad9e9dcSMatteo Franciolini return success();
3677ad9e9dcSMatteo Franciolini });
3687ad9e9dcSMatteo Franciolini doRoundtripWithConfigs(op, writeConfig, parseConfig);
3697ad9e9dcSMatteo Franciolini }
3707ad9e9dcSMatteo Franciolini
downgradeToVersion__anon58778feb0111::TestBytecodeRoundtripPass3717ad9e9dcSMatteo Franciolini LogicalResult downgradeToVersion(Operation *op,
3727ad9e9dcSMatteo Franciolini const test::TestDialectVersion &version) {
3737ad9e9dcSMatteo Franciolini if ((version.major_ == 2) && (version.minor_ == 0))
3747ad9e9dcSMatteo Franciolini return success();
3757ad9e9dcSMatteo Franciolini if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) {
3767ad9e9dcSMatteo Franciolini return op->emitError() << "current test dialect version is 2.0, "
3777ad9e9dcSMatteo Franciolini "can't downgrade to version: "
3787ad9e9dcSMatteo Franciolini << version.major_ << "." << version.minor_;
3797ad9e9dcSMatteo Franciolini }
3807ad9e9dcSMatteo Franciolini // Prior version 2.0, the old op supported only a single attribute called
3817ad9e9dcSMatteo Franciolini // "dimensions". We need to check that the modifier is false, otherwise we
3827ad9e9dcSMatteo Franciolini // can't do the downgrade.
3837ad9e9dcSMatteo Franciolini auto status = op->walk([&](test::TestVersionedOpA op) {
3847ad9e9dcSMatteo Franciolini auto &prop = op.getProperties();
3857ad9e9dcSMatteo Franciolini if (prop.modifier.getValue()) {
3867ad9e9dcSMatteo Franciolini op->emitOpError() << "cannot downgrade to version " << version.major_
3877ad9e9dcSMatteo Franciolini << "." << version.minor_
3887ad9e9dcSMatteo Franciolini << " since the modifier is not compatible";
3897ad9e9dcSMatteo Franciolini return WalkResult::interrupt();
3907ad9e9dcSMatteo Franciolini }
3917ad9e9dcSMatteo Franciolini llvm::outs() << "downgrading op...\n";
3927ad9e9dcSMatteo Franciolini return WalkResult::advance();
3937ad9e9dcSMatteo Franciolini });
3947ad9e9dcSMatteo Franciolini return failure(status.wasInterrupted());
3957ad9e9dcSMatteo Franciolini }
3967ad9e9dcSMatteo Franciolini
3977ad9e9dcSMatteo Franciolini // Test6: Downgrade IR to `targetVersion`, write to bytecode. Then, read and
3987ad9e9dcSMatteo Franciolini // upgrade IR when back in memory. The module is expected to be unmodified at
3997ad9e9dcSMatteo Franciolini // the end of the function.
runTest6__anon58778feb0111::TestBytecodeRoundtripPass4007ad9e9dcSMatteo Franciolini void runTest6(Operation *op) {
4017ad9e9dcSMatteo Franciolini test::TestDialectVersion targetEmissionVersion = targetVersion;
4027ad9e9dcSMatteo Franciolini
4037ad9e9dcSMatteo Franciolini // Downgrade IR constructs before writing the IR to bytecode.
4047ad9e9dcSMatteo Franciolini auto status = downgradeToVersion(op, targetEmissionVersion);
4057ad9e9dcSMatteo Franciolini assert(succeeded(status) && "expected the downgrade to succeed");
4065bfa16cbSKazu Hirata (void)status;
4077ad9e9dcSMatteo Franciolini
4087ad9e9dcSMatteo Franciolini BytecodeWriterConfig writeConfig;
4097ad9e9dcSMatteo Franciolini writeConfig.setDialectVersion<test::TestDialect>(
4107ad9e9dcSMatteo Franciolini std::make_unique<test::TestDialectVersion>(targetEmissionVersion));
4117ad9e9dcSMatteo Franciolini ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
4127ad9e9dcSMatteo Franciolini doRoundtripWithConfigs(op, writeConfig, parseConfig);
4137ad9e9dcSMatteo Franciolini }
4147ad9e9dcSMatteo Franciolini
4157ad9e9dcSMatteo Franciolini test::TestDialect *testDialect;
4167ad9e9dcSMatteo Franciolini };
4177ad9e9dcSMatteo Franciolini } // namespace
4187ad9e9dcSMatteo Franciolini
4197ad9e9dcSMatteo Franciolini namespace mlir {
registerTestBytecodeRoundtripPasses()4207ad9e9dcSMatteo Franciolini void registerTestBytecodeRoundtripPasses() {
4217ad9e9dcSMatteo Franciolini PassRegistration<TestBytecodeRoundtripPass>();
4227ad9e9dcSMatteo Franciolini }
4237ad9e9dcSMatteo Franciolini } // namespace mlir
424