xref: /llvm-project/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp (revision e95e94adc6bb748de015ac3053e7f0786b65f351)
1 //===- TestBytecodeRoundtrip.cpp - Pass to test bytecode roundtrips -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "mlir/Bytecode/BytecodeReader.h"
12 #include "mlir/Bytecode/BytecodeWriter.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/OperationSupport.h"
15 #include "mlir/Parser/Parser.h"
16 #include "mlir/Pass/Pass.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/MemoryBufferRef.h"
19 #include "llvm/Support/raw_ostream.h"
20 #include <list>
21 
22 using namespace mlir;
23 using namespace llvm;
24 
25 namespace {
26 class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
27 public:
TestDialectVersionParser(cl::Option & o)28   TestDialectVersionParser(cl::Option &o)
29       : cl::parser<test::TestDialectVersion>(o) {}
30 
parse(cl::Option & o,StringRef,StringRef arg,test::TestDialectVersion & v)31   bool parse(cl::Option &o, StringRef /*argName*/, StringRef arg,
32              test::TestDialectVersion &v) {
33     long long major, minor;
34     if (getAsSignedInteger(arg.split(".").first, 10, major))
35       return o.error("Invalid argument '" + arg);
36     if (getAsSignedInteger(arg.split(".").second, 10, minor))
37       return o.error("Invalid argument '" + arg);
38     v = test::TestDialectVersion(major, minor);
39     // Returns true on error.
40     return false;
41   }
print(raw_ostream & os,const test::TestDialectVersion & v)42   static void print(raw_ostream &os, const test::TestDialectVersion &v) {
43     os << v.major_ << "." << v.minor_;
44   };
45 };
46 
47 /// This is a test pass which uses callbacks to encode attributes and types in a
48 /// custom fashion.
49 struct TestBytecodeRoundtripPass
50     : public PassWrapper<TestBytecodeRoundtripPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon58778feb0111::TestBytecodeRoundtripPass51   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeRoundtripPass)
52 
53   StringRef getArgument() const final { return "test-bytecode-roundtrip"; }
getDescription__anon58778feb0111::TestBytecodeRoundtripPass54   StringRef getDescription() const final {
55     return "Test pass to implement bytecode roundtrip tests.";
56   }
getDependentDialects__anon58778feb0111::TestBytecodeRoundtripPass57   void getDependentDialects(DialectRegistry &registry) const override {
58     registry.insert<test::TestDialect>();
59   }
60   TestBytecodeRoundtripPass() = default;
TestBytecodeRoundtripPass__anon58778feb0111::TestBytecodeRoundtripPass61   TestBytecodeRoundtripPass(const TestBytecodeRoundtripPass &) {}
62 
initialize__anon58778feb0111::TestBytecodeRoundtripPass63   LogicalResult initialize(MLIRContext *context) override {
64     testDialect = context->getOrLoadDialect<test::TestDialect>();
65     return success();
66   }
67 
runOnOperation__anon58778feb0111::TestBytecodeRoundtripPass68   void runOnOperation() override {
69     switch (testKind) {
70       // Tests 0-5 implement a custom roundtrip with callbacks.
71     case (0):
72       return runTest0(getOperation());
73     case (1):
74       return runTest1(getOperation());
75     case (2):
76       return runTest2(getOperation());
77     case (3):
78       return runTest3(getOperation());
79     case (4):
80       return runTest4(getOperation());
81     case (5):
82       return runTest5(getOperation());
83     case (6):
84       // test-kind 6 is a plain roundtrip with downgrade/upgrade to/from
85       // `targetVersion`.
86       return runTest6(getOperation());
87     default:
88       llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass");
89     }
90   }
91 
92   mlir::Pass::Option<test::TestDialectVersion, TestDialectVersionParser>
93       targetVersion{*this, "test-dialect-version",
94                     llvm::cl::desc(
95                         "Specifies the test dialect version to emit and parse"),
96                     cl::init(test::TestDialectVersion())};
97 
98   mlir::Pass::Option<int> testKind{
99       *this, "test-kind", llvm::cl::desc("Specifies the test kind to execute"),
100       cl::init(0)};
101 
102 private:
doRoundtripWithConfigs__anon58778feb0111::TestBytecodeRoundtripPass103   void doRoundtripWithConfigs(Operation *op,
104                               const BytecodeWriterConfig &writeConfig,
105                               const ParserConfig &parseConfig) {
106     std::string bytecode;
107     llvm::raw_string_ostream os(bytecode);
108     if (failed(writeBytecodeToFile(op, os, writeConfig))) {
109       op->emitError() << "failed to write bytecode\n";
110       signalPassFailure();
111       return;
112     }
113     auto newModuleOp = parseSourceString(StringRef(bytecode), parseConfig);
114     if (!newModuleOp.get()) {
115       op->emitError() << "failed to read bytecode\n";
116       signalPassFailure();
117       return;
118     }
119     // Print the module to the output stream, so that we can filecheck the
120     // result.
121     newModuleOp->print(llvm::outs());
122   }
123 
124   // Test0: let's assume that versions older than 2.0 were relying on a special
125   // integer attribute of a deprecated dialect called "funky". Assume that its
126   // encoding was made by two varInts, the first was the ID (999) and the second
127   // contained width and signedness info. We can emit it using a callback
128   // writing a custom encoding for the "funky" dialect group, and parse it back
129   // with a custom parser reading the same encoding in the same dialect group.
130   // Note that the ID 999 does not correspond to a valid integer type in the
131   // current encodings of builtin types.
runTest0__anon58778feb0111::TestBytecodeRoundtripPass132   void runTest0(Operation *op) {
133     auto newCtx = std::make_shared<MLIRContext>();
134     test::TestDialectVersion targetEmissionVersion = targetVersion;
135     BytecodeWriterConfig writeConfig;
136     // Set the emission version for the test dialect.
137     writeConfig.setDialectVersion<test::TestDialect>(
138         std::make_unique<test::TestDialectVersion>(targetEmissionVersion));
139     writeConfig.attachTypeCallback(
140         [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
141             DialectBytecodeWriter &writer) -> LogicalResult {
142           // Do not override anything if version greater than 2.0.
143           auto versionOr = writer.getDialectVersion<test::TestDialect>();
144           assert(succeeded(versionOr) && "expected reader to be able to access "
145                                          "the version for test dialect");
146           const auto *version =
147               reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
148           if (version->major_ >= 2)
149             return failure();
150 
151           // For version less than 2.0, override the encoding of IntegerType.
152           if (auto type = llvm::dyn_cast<IntegerType>(entryValue)) {
153             llvm::outs() << "Overriding IntegerType encoding...\n";
154             dialectGroupName = StringLiteral("funky");
155             writer.writeVarInt(/* IntegerType */ 999);
156             writer.writeVarInt(type.getWidth() << 2 | type.getSignedness());
157             return success();
158           }
159           return failure();
160         });
161     newCtx->appendDialectRegistry(op->getContext()->getDialectRegistry());
162     newCtx->allowUnregisteredDialects();
163     ParserConfig parseConfig(newCtx.get(), /*verifyAfterParse=*/true);
164     parseConfig.getBytecodeReaderConfig().attachTypeCallback(
165         [&](DialectBytecodeReader &reader, StringRef dialectName,
166             Type &entry) -> LogicalResult {
167           // Get test dialect version from the version map.
168           auto versionOr = reader.getDialectVersion<test::TestDialect>();
169           assert(succeeded(versionOr) && "expected reader to be able to access "
170                                          "the version for test dialect");
171           const auto *version =
172               reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
173           if (version->major_ >= 2)
174             return success();
175 
176           // `dialectName` is the name of the group we have the opportunity to
177           // override. In this case, override only the dialect group "funky",
178           // for which does not exist in memory.
179           if (dialectName != StringLiteral("funky"))
180             return success();
181 
182           uint64_t encoding;
183           if (failed(reader.readVarInt(encoding)) || encoding != 999)
184             return success();
185           llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
186           uint64_t widthAndSignedness, width;
187           IntegerType::SignednessSemantics signedness;
188           if (succeeded(reader.readVarInt(widthAndSignedness)) &&
189               ((width = widthAndSignedness >> 2), true) &&
190               ((signedness = static_cast<IntegerType::SignednessSemantics>(
191                     widthAndSignedness & 0x3)),
192                true))
193             entry = IntegerType::get(reader.getContext(), width, signedness);
194           // Return nullopt to fall through the rest of the parsing code path.
195           return success();
196         });
197     doRoundtripWithConfigs(op, writeConfig, parseConfig);
198   }
199 
200   // Test1: When writing bytecode, we override the encoding of TestI32Type with
201   // the encoding of builtin IntegerType. We can natively parse this without
202   // the use of a callback, relying on the existing builtin reader mechanism.
runTest1__anon58778feb0111::TestBytecodeRoundtripPass203   void runTest1(Operation *op) {
204     auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
205     BytecodeDialectInterface *iface =
206         builtin->getRegisteredInterface<BytecodeDialectInterface>();
207     BytecodeWriterConfig writeConfig;
208     writeConfig.attachTypeCallback(
209         [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
210             DialectBytecodeWriter &writer) -> LogicalResult {
211           // Emit TestIntegerType using the builtin dialect encoding.
212           if (llvm::isa<test::TestI32Type>(entryValue)) {
213             llvm::outs() << "Overriding TestI32Type encoding...\n";
214             auto builtinI32Type =
215                 IntegerType::get(op->getContext(), 32,
216                                  IntegerType::SignednessSemantics::Signless);
217             // Specify that this type will need to be written as part of the
218             // builtin group. This will override the default dialect group of
219             // the attribute (test).
220             dialectGroupName = StringLiteral("builtin");
221             if (succeeded(iface->writeType(builtinI32Type, writer)))
222               return success();
223           }
224           return failure();
225         });
226     // We natively parse the attribute as a builtin, so no callback needed.
227     ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
228     doRoundtripWithConfigs(op, writeConfig, parseConfig);
229   }
230 
231   // Test2: When writing bytecode, we write standard builtin IntegerTypes. At
232   // parsing, we use the encoding of IntegerType to intercept all i32. Then,
233   // instead of creating i32s, we assemble TestI32Type and return it.
runTest2__anon58778feb0111::TestBytecodeRoundtripPass234   void runTest2(Operation *op) {
235     auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
236     BytecodeDialectInterface *iface =
237         builtin->getRegisteredInterface<BytecodeDialectInterface>();
238     BytecodeWriterConfig writeConfig;
239     ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
240     parseConfig.getBytecodeReaderConfig().attachTypeCallback(
241         [&](DialectBytecodeReader &reader, StringRef dialectName,
242             Type &entry) -> LogicalResult {
243           if (dialectName != StringLiteral("builtin"))
244             return success();
245           Type builtinAttr = iface->readType(reader);
246           if (auto integerType =
247                   llvm::dyn_cast_or_null<IntegerType>(builtinAttr)) {
248             if (integerType.getWidth() == 32 && integerType.isSignless()) {
249               llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
250               entry = test::TestI32Type::get(reader.getContext());
251             }
252           }
253           return success();
254         });
255     doRoundtripWithConfigs(op, writeConfig, parseConfig);
256   }
257 
258   // Test3: When writing bytecode, we override the encoding of
259   // TestAttrParamsAttr with the encoding of builtin DenseIntElementsAttr. We
260   // can natively parse this without the use of a callback, relying on the
261   // existing builtin reader mechanism.
runTest3__anon58778feb0111::TestBytecodeRoundtripPass262   void runTest3(Operation *op) {
263     auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
264     BytecodeDialectInterface *iface =
265         builtin->getRegisteredInterface<BytecodeDialectInterface>();
266     auto i32Type = IntegerType::get(op->getContext(), 32,
267                                     IntegerType::SignednessSemantics::Signless);
268     BytecodeWriterConfig writeConfig;
269     writeConfig.attachAttributeCallback(
270         [&](Attribute entryValue, std::optional<StringRef> &dialectGroupName,
271             DialectBytecodeWriter &writer) -> LogicalResult {
272           // Emit TestIntegerType using the builtin dialect encoding.
273           if (auto testParamAttrs =
274                   llvm::dyn_cast<test::TestAttrParamsAttr>(entryValue)) {
275             llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n";
276             // Specify that this attribute will need to be written as part of
277             // the builtin group. This will override the default dialect group
278             // of the attribute (test).
279             dialectGroupName = StringLiteral("builtin");
280             auto denseAttr = DenseIntElementsAttr::get(
281                 RankedTensorType::get({2}, i32Type),
282                 {testParamAttrs.getV0(), testParamAttrs.getV1()});
283             if (succeeded(iface->writeAttribute(denseAttr, writer)))
284               return success();
285           }
286           return failure();
287         });
288     // We natively parse the attribute as a builtin, so no callback needed.
289     ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
290     doRoundtripWithConfigs(op, writeConfig, parseConfig);
291   }
292 
293   // Test4: When writing bytecode, we write standard builtin
294   // DenseIntElementsAttr. At parsing, we use the encoding of
295   // DenseIntElementsAttr to intercept all ElementsAttr that have shaped type of
296   // <2xi32>. Instead of assembling a DenseIntElementsAttr, we assemble
297   // TestAttrParamsAttr and return it.
runTest4__anon58778feb0111::TestBytecodeRoundtripPass298   void runTest4(Operation *op) {
299     auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
300     BytecodeDialectInterface *iface =
301         builtin->getRegisteredInterface<BytecodeDialectInterface>();
302     auto i32Type = IntegerType::get(op->getContext(), 32,
303                                     IntegerType::SignednessSemantics::Signless);
304     BytecodeWriterConfig writeConfig;
305     ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
306     parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
307         [&](DialectBytecodeReader &reader, StringRef dialectName,
308             Attribute &entry) -> LogicalResult {
309           // Override only the case where the return type of the builtin reader
310           // is an i32 and fall through on all the other cases, since we want to
311           // still use TestDialect normal codepath to parse the other types.
312           Attribute builtinAttr = iface->readAttribute(reader);
313           if (auto denseAttr =
314                   llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) {
315             if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) &&
316                 denseAttr.getElementType() == i32Type) {
317               llvm::outs()
318                   << "Overriding parsing of TestAttrParamsAttr encoding...\n";
319               int v0 = denseAttr.getValues<IntegerAttr>()[0].getInt();
320               int v1 = denseAttr.getValues<IntegerAttr>()[1].getInt();
321               entry =
322                   test::TestAttrParamsAttr::get(reader.getContext(), v0, v1);
323             }
324           }
325           return success();
326         });
327     doRoundtripWithConfigs(op, writeConfig, parseConfig);
328   }
329 
330   // Test5: When writing bytecode, we want TestDialect to use nothing else than
331   // the builtin types and attributes and take full control of the encoding,
332   // returning failure if any type or attribute is not part of builtin.
runTest5__anon58778feb0111::TestBytecodeRoundtripPass333   void runTest5(Operation *op) {
334     auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
335     BytecodeDialectInterface *iface =
336         builtin->getRegisteredInterface<BytecodeDialectInterface>();
337     BytecodeWriterConfig writeConfig;
338     writeConfig.attachAttributeCallback(
339         [&](Attribute attr, std::optional<StringRef> &dialectGroupName,
340             DialectBytecodeWriter &writer) -> LogicalResult {
341           return iface->writeAttribute(attr, writer);
342         });
343     writeConfig.attachTypeCallback(
344         [&](Type type, std::optional<StringRef> &dialectGroupName,
345             DialectBytecodeWriter &writer) -> LogicalResult {
346           return iface->writeType(type, writer);
347         });
348     ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
349     parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
350         [&](DialectBytecodeReader &reader, StringRef dialectName,
351             Attribute &entry) -> LogicalResult {
352           Attribute builtinAttr = iface->readAttribute(reader);
353           if (!builtinAttr)
354             return failure();
355           entry = builtinAttr;
356           return success();
357         });
358     parseConfig.getBytecodeReaderConfig().attachTypeCallback(
359         [&](DialectBytecodeReader &reader, StringRef dialectName,
360             Type &entry) -> LogicalResult {
361           Type builtinType = iface->readType(reader);
362           if (!builtinType) {
363             return failure();
364           }
365           entry = builtinType;
366           return success();
367         });
368     doRoundtripWithConfigs(op, writeConfig, parseConfig);
369   }
370 
downgradeToVersion__anon58778feb0111::TestBytecodeRoundtripPass371   LogicalResult downgradeToVersion(Operation *op,
372                                    const test::TestDialectVersion &version) {
373     if ((version.major_ == 2) && (version.minor_ == 0))
374       return success();
375     if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) {
376       return op->emitError() << "current test dialect version is 2.0, "
377                                 "can't downgrade to version: "
378                              << version.major_ << "." << version.minor_;
379     }
380     // Prior version 2.0, the old op supported only a single attribute called
381     // "dimensions". We need to check that the modifier is false, otherwise we
382     // can't do the downgrade.
383     auto status = op->walk([&](test::TestVersionedOpA op) {
384       auto &prop = op.getProperties();
385       if (prop.modifier.getValue()) {
386         op->emitOpError() << "cannot downgrade to version " << version.major_
387                           << "." << version.minor_
388                           << " since the modifier is not compatible";
389         return WalkResult::interrupt();
390       }
391       llvm::outs() << "downgrading op...\n";
392       return WalkResult::advance();
393     });
394     return failure(status.wasInterrupted());
395   }
396 
397   // Test6: Downgrade IR to `targetVersion`, write to bytecode. Then, read and
398   // upgrade IR when back in memory. The module is expected to be unmodified at
399   // the end of the function.
runTest6__anon58778feb0111::TestBytecodeRoundtripPass400   void runTest6(Operation *op) {
401     test::TestDialectVersion targetEmissionVersion = targetVersion;
402 
403     // Downgrade IR constructs before writing the IR to bytecode.
404     auto status = downgradeToVersion(op, targetEmissionVersion);
405     assert(succeeded(status) && "expected the downgrade to succeed");
406     (void)status;
407 
408     BytecodeWriterConfig writeConfig;
409     writeConfig.setDialectVersion<test::TestDialect>(
410         std::make_unique<test::TestDialectVersion>(targetEmissionVersion));
411     ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
412     doRoundtripWithConfigs(op, writeConfig, parseConfig);
413   }
414 
415   test::TestDialect *testDialect;
416 };
417 } // namespace
418 
419 namespace mlir {
registerTestBytecodeRoundtripPasses()420 void registerTestBytecodeRoundtripPasses() {
421   PassRegistration<TestBytecodeRoundtripPass>();
422 }
423 } // namespace mlir
424