xref: /llvm-project/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp (revision 46e41c8631bd6c1a6c91d6cc4a5e4f1671078ccd)
1 //===- TestDialectInterfaces.cpp - Test dialect interface definitions -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "mlir/Interfaces/FoldInterfaces.h"
12 #include "mlir/Reducer/ReductionPatternInterface.h"
13 #include "mlir/Transforms/InliningUtils.h"
14 
15 using namespace mlir;
16 using namespace test;
17 
18 //===----------------------------------------------------------------------===//
19 // TestDialect Interfaces
20 //===----------------------------------------------------------------------===//
21 
22 namespace {
23 
24 /// Testing the correctness of some traits.
25 static_assert(
26     llvm::is_detected<OpTrait::has_implicit_terminator_t,
27                       SingleBlockImplicitTerminatorOp>::value,
28     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
29 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
30                   SingleBlockImplicitTerminatorOp>::value,
31               "hasSingleBlockImplicitTerminator does not match "
32               "SingleBlockImplicitTerminatorOp");
33 
34 struct TestResourceBlobManagerInterface
35     : public ResourceBlobManagerDialectInterfaceBase<
36           TestDialectResourceBlobHandle> {
37   using ResourceBlobManagerDialectInterfaceBase<
38       TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase;
39 };
40 
/// Tags written as the leading varint of each test-dialect type/attribute in
/// bytecode. The underlying type matches the `uint64_t` varint wire type. No
/// nested anonymous namespace is needed: the enclosing one already gives these
/// names internal linkage.
enum test_encoding : uint64_t { k_attr_params = 0, k_test_i32 = 99 };
44 
45 // Test support for interacting with the Bytecode reader/writer.
46 struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
47   using BytecodeDialectInterface::BytecodeDialectInterface;
TestBytecodeDialectInterface__anon2bddc4b90111::TestBytecodeDialectInterface48   TestBytecodeDialectInterface(Dialect *dialect)
49       : BytecodeDialectInterface(dialect) {}
50 
writeType__anon2bddc4b90111::TestBytecodeDialectInterface51   LogicalResult writeType(Type type,
52                           DialectBytecodeWriter &writer) const final {
53     if (auto concreteType = llvm::dyn_cast<TestI32Type>(type)) {
54       writer.writeVarInt(test_encoding::k_test_i32);
55       return success();
56     }
57     return failure();
58   }
59 
readType__anon2bddc4b90111::TestBytecodeDialectInterface60   Type readType(DialectBytecodeReader &reader) const final {
61     uint64_t encoding;
62     if (failed(reader.readVarInt(encoding)))
63       return Type();
64     if (encoding == test_encoding::k_test_i32)
65       return TestI32Type::get(getContext());
66     return Type();
67   }
68 
writeAttribute__anon2bddc4b90111::TestBytecodeDialectInterface69   LogicalResult writeAttribute(Attribute attr,
70                                DialectBytecodeWriter &writer) const final {
71     if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
72       writer.writeVarInt(test_encoding::k_attr_params);
73       writer.writeVarInt(concreteAttr.getV0());
74       writer.writeVarInt(concreteAttr.getV1());
75       return success();
76     }
77     return failure();
78   }
79 
readAttribute__anon2bddc4b90111::TestBytecodeDialectInterface80   Attribute readAttribute(DialectBytecodeReader &reader) const final {
81     auto versionOr = reader.getDialectVersion<test::TestDialect>();
82     // Assume current version if not available through the reader.
83     const auto version =
84         (succeeded(versionOr))
85             ? *reinterpret_cast<const TestDialectVersion *>(*versionOr)
86             : TestDialectVersion();
87     if (version.major_ < 2)
88       return readAttrOldEncoding(reader);
89     if (version.major_ == 2 && version.minor_ == 0)
90       return readAttrNewEncoding(reader);
91     // Forbid reading future versions by returning nullptr.
92     return Attribute();
93   }
94 
95   // Emit a specific version of the dialect.
writeVersion__anon2bddc4b90111::TestBytecodeDialectInterface96   void writeVersion(DialectBytecodeWriter &writer) const final {
97     // Construct the current dialect version.
98     test::TestDialectVersion versionToEmit;
99 
100     // Check if a target version to emit was specified on the writer configs.
101     auto versionOr = writer.getDialectVersion<test::TestDialect>();
102     if (succeeded(versionOr))
103       versionToEmit =
104           *reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
105     writer.writeVarInt(versionToEmit.major_); // major
106     writer.writeVarInt(versionToEmit.minor_); // minor
107   }
108 
109   std::unique_ptr<DialectVersion>
readVersion__anon2bddc4b90111::TestBytecodeDialectInterface110   readVersion(DialectBytecodeReader &reader) const final {
111     uint64_t major_, minor_;
112     if (failed(reader.readVarInt(major_)) || failed(reader.readVarInt(minor_)))
113       return nullptr;
114     auto version = std::make_unique<TestDialectVersion>();
115     version->major_ = major_;
116     version->minor_ = minor_;
117     return version;
118   }
119 
upgradeFromVersion__anon2bddc4b90111::TestBytecodeDialectInterface120   LogicalResult upgradeFromVersion(Operation *topLevelOp,
121                                    const DialectVersion &version_) const final {
122     const auto &version = static_cast<const TestDialectVersion &>(version_);
123     if ((version.major_ == 2) && (version.minor_ == 0))
124       return success();
125     if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) {
126       return topLevelOp->emitError()
127              << "current test dialect version is 2.0, can't parse version: "
128              << version.major_ << "." << version.minor_;
129     }
130     // Prior version 2.0, the old op supported only a single attribute called
131     // "dimensions". We can perform the upgrade.
132     topLevelOp->walk([](TestVersionedOpA op) {
133       // Prior version 2.0, `readProperties` did not process the modifier
134       // attribute. Handle that according to the version here.
135       auto &prop = op.getProperties();
136       prop.modifier = BoolAttr::get(op->getContext(), false);
137     });
138     return success();
139   }
140 
141 private:
readAttrNewEncoding__anon2bddc4b90111::TestBytecodeDialectInterface142   Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const {
143     uint64_t encoding;
144     if (failed(reader.readVarInt(encoding)) ||
145         encoding != test_encoding::k_attr_params)
146       return Attribute();
147     // The new encoding has v0 first, v1 second.
148     uint64_t v0, v1;
149     if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1)))
150       return Attribute();
151     return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
152                                    static_cast<int>(v1));
153   }
154 
readAttrOldEncoding__anon2bddc4b90111::TestBytecodeDialectInterface155   Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const {
156     uint64_t encoding;
157     if (failed(reader.readVarInt(encoding)) ||
158         encoding != test_encoding::k_attr_params)
159       return Attribute();
160     // The old encoding has v1 first, v0 second.
161     uint64_t v0, v1;
162     if (failed(reader.readVarInt(v1)) || failed(reader.readVarInt(v0)))
163       return Attribute();
164     return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
165                                    static_cast<int>(v1));
166   }
167 };
168 
169 // Test support for interacting with the AsmPrinter.
170 struct TestOpAsmInterface : public OpAsmDialectInterface {
171   using OpAsmDialectInterface::OpAsmDialectInterface;
TestOpAsmInterface__anon2bddc4b90111::TestOpAsmInterface172   TestOpAsmInterface(Dialect *dialect, TestResourceBlobManagerInterface &mgr)
173       : OpAsmDialectInterface(dialect), blobManager(mgr) {}
174 
175   //===------------------------------------------------------------------===//
176   // Aliases
177   //===------------------------------------------------------------------===//
178 
getAlias__anon2bddc4b90111::TestOpAsmInterface179   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
180     StringAttr strAttr = dyn_cast<StringAttr>(attr);
181     if (!strAttr)
182       return AliasResult::NoAlias;
183 
184     // Check the contents of the string attribute to see what the test alias
185     // should be named.
186     std::optional<StringRef> aliasName =
187         StringSwitch<std::optional<StringRef>>(strAttr.getValue())
188             .Case("alias_test:dot_in_name", StringRef("test.alias"))
189             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
190             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
191             .Case("alias_test:prefixed_symbol", StringRef("%test"))
192             .Case("alias_test:sanitize_conflict_a",
193                   StringRef("test_alias_conflict0"))
194             .Case("alias_test:sanitize_conflict_b",
195                   StringRef("test_alias_conflict0_"))
196             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
197             .Default(std::nullopt);
198     if (!aliasName)
199       return AliasResult::NoAlias;
200 
201     os << *aliasName;
202     return AliasResult::FinalAlias;
203   }
204 
getAlias__anon2bddc4b90111::TestOpAsmInterface205   AliasResult getAlias(Type type, raw_ostream &os) const final {
206     if (auto tupleType = dyn_cast<TupleType>(type)) {
207       if (tupleType.size() > 0 &&
208           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
209             return isa<SimpleAType>(elemType);
210           })) {
211         os << "test_tuple";
212         return AliasResult::FinalAlias;
213       }
214     }
215     if (auto intType = dyn_cast<TestIntegerType>(type)) {
216       if (intType.getSignedness() ==
217               TestIntegerType::SignednessSemantics::Unsigned &&
218           intType.getWidth() == 8) {
219         os << "test_ui8";
220         return AliasResult::FinalAlias;
221       }
222     }
223     if (auto recType = dyn_cast<TestRecursiveType>(type)) {
224       if (recType.getName() == "type_to_alias") {
225         // We only make alias for a specific recursive type.
226         os << "testrec";
227         return AliasResult::FinalAlias;
228       }
229     }
230     if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
231       os << recAliasType.getName();
232       return AliasResult::FinalAlias;
233     }
234     return AliasResult::NoAlias;
235   }
236 
237   //===------------------------------------------------------------------===//
238   // Resources
239   //===------------------------------------------------------------------===//
240 
241   std::string
getResourceKey__anon2bddc4b90111::TestOpAsmInterface242   getResourceKey(const AsmDialectResourceHandle &handle) const override {
243     return cast<TestDialectResourceBlobHandle>(handle).getKey().str();
244   }
245 
246   FailureOr<AsmDialectResourceHandle>
declareResource__anon2bddc4b90111::TestOpAsmInterface247   declareResource(StringRef key) const final {
248     return blobManager.insert(key);
249   }
250 
parseResource__anon2bddc4b90111::TestOpAsmInterface251   LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
252     FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
253     if (failed(blob))
254       return failure();
255 
256     // Update the blob for this entry.
257     blobManager.update(entry.getKey(), std::move(*blob));
258     return success();
259   }
260 
261   void
buildResources__anon2bddc4b90111::TestOpAsmInterface262   buildResources(Operation *op,
263                  const SetVector<AsmDialectResourceHandle> &referencedResources,
264                  AsmResourceBuilder &provider) const final {
265     blobManager.buildResources(provider, referencedResources.getArrayRef());
266   }
267 
268 private:
269   /// The blob manager for the dialect.
270   TestResourceBlobManagerInterface &blobManager;
271 };
272 
273 struct TestDialectFoldInterface : public DialectFoldInterface {
274   using DialectFoldInterface::DialectFoldInterface;
275 
276   /// Registered hook to check if the given region, which is attached to an
277   /// operation that is *not* isolated from above, should be used when
278   /// materializing constants.
shouldMaterializeInto__anon2bddc4b90111::TestDialectFoldInterface279   bool shouldMaterializeInto(Region *region) const final {
280     // If this is a one region operation, then insert into it.
281     return isa<OneRegionOp>(region->getParentOp());
282   }
283 };
284 
285 /// This class defines the interface for handling inlining with standard
286 /// operations.
287 struct TestInlinerInterface : public DialectInlinerInterface {
288   using DialectInlinerInterface::DialectInlinerInterface;
289 
290   //===--------------------------------------------------------------------===//
291   // Analysis Hooks
292   //===--------------------------------------------------------------------===//
293 
isLegalToInline__anon2bddc4b90111::TestInlinerInterface294   bool isLegalToInline(Operation *call, Operation *callable,
295                        bool wouldBeCloned) const final {
296     // Don't allow inlining calls that are marked `noinline`.
297     return !call->hasAttr("noinline");
298   }
isLegalToInline__anon2bddc4b90111::TestInlinerInterface299   bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
300     // Inlining into test dialect regions is legal.
301     return true;
302   }
isLegalToInline__anon2bddc4b90111::TestInlinerInterface303   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
304     return true;
305   }
306 
shouldAnalyzeRecursively__anon2bddc4b90111::TestInlinerInterface307   bool shouldAnalyzeRecursively(Operation *op) const final {
308     // Analyze recursively if this is not a functional region operation, it
309     // froms a separate functional scope.
310     return !isa<FunctionalRegionOp>(op);
311   }
312 
313   //===--------------------------------------------------------------------===//
314   // Transformation Hooks
315   //===--------------------------------------------------------------------===//
316 
317   /// Handle the given inlined terminator by replacing it with a new operation
318   /// as necessary.
handleTerminator__anon2bddc4b90111::TestInlinerInterface319   void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
320     // Only handle "test.return" here.
321     auto returnOp = dyn_cast<TestReturnOp>(op);
322     if (!returnOp)
323       return;
324 
325     // Replace the values directly with the return operands.
326     assert(returnOp.getNumOperands() == valuesToRepl.size());
327     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
328       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
329   }
330 
331   /// Attempt to materialize a conversion for a type mismatch between a call
332   /// from this dialect, and a callable region. This method should generate an
333   /// operation that takes 'input' as the only operand, and produces a single
334   /// result of 'resultType'. If a conversion can not be generated, nullptr
335   /// should be returned.
materializeCallConversion__anon2bddc4b90111::TestInlinerInterface336   Operation *materializeCallConversion(OpBuilder &builder, Value input,
337                                        Type resultType,
338                                        Location conversionLoc) const final {
339     // Only allow conversion for i16/i32 types.
340     if (!(resultType.isSignlessInteger(16) ||
341           resultType.isSignlessInteger(32)) ||
342         !(input.getType().isSignlessInteger(16) ||
343           input.getType().isSignlessInteger(32)))
344       return nullptr;
345     return builder.create<TestCastOp>(conversionLoc, resultType, input);
346   }
347 
handleArgument__anon2bddc4b90111::TestInlinerInterface348   Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
349                        Value argument,
350                        DictionaryAttr argumentAttrs) const final {
351     if (!argumentAttrs.contains("test.handle_argument"))
352       return argument;
353     return builder.create<TestTypeChangerOp>(call->getLoc(), argument.getType(),
354                                              argument);
355   }
356 
handleResult__anon2bddc4b90111::TestInlinerInterface357   Value handleResult(OpBuilder &builder, Operation *call, Operation *callable,
358                      Value result, DictionaryAttr resultAttrs) const final {
359     if (!resultAttrs.contains("test.handle_result"))
360       return result;
361     return builder.create<TestTypeChangerOp>(call->getLoc(), result.getType(),
362                                              result);
363   }
364 
processInlinedCallBlocks__anon2bddc4b90111::TestInlinerInterface365   void processInlinedCallBlocks(
366       Operation *call,
367       iterator_range<Region::iterator> inlinedBlocks) const final {
368     if (!isa<ConversionCallOp>(call))
369       return;
370 
371     // Set attributed on all ops in the inlined blocks.
372     for (Block &block : inlinedBlocks) {
373       block.walk([&](Operation *op) {
374         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
375       });
376     }
377   }
378 };
379 
380 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
381 public:
TestReductionPatternInterface__anon2bddc4b90111::TestReductionPatternInterface382   TestReductionPatternInterface(Dialect *dialect)
383       : DialectReductionPatternInterface(dialect) {}
384 
populateReductionPatterns__anon2bddc4b90111::TestReductionPatternInterface385   void populateReductionPatterns(RewritePatternSet &patterns) const final {
386     populateTestReductionPatterns(patterns);
387   }
388 };
389 
390 } // namespace
391 
registerInterfaces()392 void TestDialect::registerInterfaces() {
393   auto &blobInterface = addInterface<TestResourceBlobManagerInterface>();
394   addInterface<TestOpAsmInterface>(blobInterface);
395 
396   addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
397                 TestReductionPatternInterface, TestBytecodeDialectInterface>();
398 }
399