xref: /llvm-project/mlir/unittests/IR/InterfaceAttachmentTest.cpp (revision d0e6fd99aa95ff61372ea328e9f89da2ee39c49c)
19b2a1bcfSAlex Zinenko //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===//
29b2a1bcfSAlex Zinenko //
39b2a1bcfSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49b2a1bcfSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
59b2a1bcfSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69b2a1bcfSAlex Zinenko //
79b2a1bcfSAlex Zinenko //===----------------------------------------------------------------------===//
89b2a1bcfSAlex Zinenko //
99b2a1bcfSAlex Zinenko // This implements the tests for attaching interfaces to attributes and types
109b2a1bcfSAlex Zinenko // without having to specify them on the attribute or type class directly.
119b2a1bcfSAlex Zinenko //
129b2a1bcfSAlex Zinenko //===----------------------------------------------------------------------===//
139b2a1bcfSAlex Zinenko 
149b2a1bcfSAlex Zinenko #include "mlir/IR/BuiltinAttributes.h"
15d7e89121SAlex Zinenko #include "mlir/IR/BuiltinDialect.h"
1623cdf7b6SAlex Zinenko #include "mlir/IR/BuiltinOps.h"
179b2a1bcfSAlex Zinenko #include "mlir/IR/BuiltinTypes.h"
189b2a1bcfSAlex Zinenko #include "gtest/gtest.h"
199b2a1bcfSAlex Zinenko 
209b2a1bcfSAlex Zinenko #include "../../test/lib/Dialect/Test/TestAttributes.h"
2123cdf7b6SAlex Zinenko #include "../../test/lib/Dialect/Test/TestDialect.h"
229b2a1bcfSAlex Zinenko #include "../../test/lib/Dialect/Test/TestTypes.h"
23db79f4a2SMehdi Amini #include "mlir/IR/OwningOpRef.h"
249b2a1bcfSAlex Zinenko 
259b2a1bcfSAlex Zinenko using namespace mlir;
267776b19eSStephen Neuendorffer using namespace test;
279b2a1bcfSAlex Zinenko 
289b2a1bcfSAlex Zinenko namespace {
299b2a1bcfSAlex Zinenko 
309b2a1bcfSAlex Zinenko /// External interface model for the integer type. Only provides non-default
319b2a1bcfSAlex Zinenko /// methods.
329b2a1bcfSAlex Zinenko struct Model
339b2a1bcfSAlex Zinenko     : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> {
349b2a1bcfSAlex Zinenko   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
359b2a1bcfSAlex Zinenko     return type.getIntOrFloatBitWidth() + arg;
369b2a1bcfSAlex Zinenko   }
379b2a1bcfSAlex Zinenko 
389b2a1bcfSAlex Zinenko   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
399b2a1bcfSAlex Zinenko };
409b2a1bcfSAlex Zinenko 
419b2a1bcfSAlex Zinenko /// External interface model for the float type. Provides non-deafult and
429b2a1bcfSAlex Zinenko /// overrides default methods.
439b2a1bcfSAlex Zinenko struct OverridingModel
449b2a1bcfSAlex Zinenko     : public TestExternalTypeInterface::ExternalModel<OverridingModel,
459b2a1bcfSAlex Zinenko                                                       FloatType> {
469b2a1bcfSAlex Zinenko   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
479b2a1bcfSAlex Zinenko     return type.getIntOrFloatBitWidth() + arg;
489b2a1bcfSAlex Zinenko   }
499b2a1bcfSAlex Zinenko 
509b2a1bcfSAlex Zinenko   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
519b2a1bcfSAlex Zinenko 
529b2a1bcfSAlex Zinenko   unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const {
539b2a1bcfSAlex Zinenko     return 128;
549b2a1bcfSAlex Zinenko   }
559b2a1bcfSAlex Zinenko 
569b2a1bcfSAlex Zinenko   static unsigned staticGetArgument(unsigned arg) { return 420; }
579b2a1bcfSAlex Zinenko };
589b2a1bcfSAlex Zinenko 
599b2a1bcfSAlex Zinenko TEST(InterfaceAttachment, Type) {
609b2a1bcfSAlex Zinenko   MLIRContext context;
619b2a1bcfSAlex Zinenko 
629b2a1bcfSAlex Zinenko   // Check that the type has no interface.
639b2a1bcfSAlex Zinenko   IntegerType i8 = IntegerType::get(&context, 8);
645550c821STres Popp   ASSERT_FALSE(isa<TestExternalTypeInterface>(i8));
659b2a1bcfSAlex Zinenko 
669b2a1bcfSAlex Zinenko   // Attach an interface and check that the type now has the interface.
679b2a1bcfSAlex Zinenko   IntegerType::attachInterface<Model>(context);
685550c821STres Popp   TestExternalTypeInterface iface = dyn_cast<TestExternalTypeInterface>(i8);
699b2a1bcfSAlex Zinenko   ASSERT_TRUE(iface != nullptr);
709b2a1bcfSAlex Zinenko   EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
719b2a1bcfSAlex Zinenko   EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
729b2a1bcfSAlex Zinenko   EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u);
739b2a1bcfSAlex Zinenko   EXPECT_EQ(iface.staticGetArgument(17), 17u);
749b2a1bcfSAlex Zinenko 
759b2a1bcfSAlex Zinenko   // Same, but with the default implementation overridden.
769b2a1bcfSAlex Zinenko   FloatType flt = Float32Type::get(&context);
775550c821STres Popp   ASSERT_FALSE(isa<TestExternalTypeInterface>(flt));
789b2a1bcfSAlex Zinenko   Float32Type::attachInterface<OverridingModel>(context);
795550c821STres Popp   iface = dyn_cast<TestExternalTypeInterface>(flt);
809b2a1bcfSAlex Zinenko   ASSERT_TRUE(iface != nullptr);
819b2a1bcfSAlex Zinenko   EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
829b2a1bcfSAlex Zinenko   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
839b2a1bcfSAlex Zinenko   EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u);
849b2a1bcfSAlex Zinenko   EXPECT_EQ(iface.staticGetArgument(17), 420u);
859b2a1bcfSAlex Zinenko 
869b2a1bcfSAlex Zinenko   // Other contexts shouldn't have the attribute attached.
879b2a1bcfSAlex Zinenko   MLIRContext other;
889b2a1bcfSAlex Zinenko   IntegerType i8other = IntegerType::get(&other, 8);
895550c821STres Popp   EXPECT_FALSE(isa<TestExternalTypeInterface>(i8other));
909b2a1bcfSAlex Zinenko }
919b2a1bcfSAlex Zinenko 
92d7e89121SAlex Zinenko /// External interface model for the test type from the test dialect.
93d7e89121SAlex Zinenko struct TestTypeModel
94d7e89121SAlex Zinenko     : public TestExternalTypeInterface::ExternalModel<TestTypeModel,
95d7e89121SAlex Zinenko                                                       test::TestType> {
96d7e89121SAlex Zinenko   unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; }
97d7e89121SAlex Zinenko 
98d7e89121SAlex Zinenko   static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; }
99d7e89121SAlex Zinenko };
100d7e89121SAlex Zinenko 
101d7e89121SAlex Zinenko TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
102d7e89121SAlex Zinenko   // Put the interface in the registry.
103d7e89121SAlex Zinenko   DialectRegistry registry;
104d7e89121SAlex Zinenko   registry.insert<test::TestDialect>();
10577eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
10677eee579SRiver Riddle     test::TestType::attachInterface<TestTypeModel>(*ctx);
10777eee579SRiver Riddle   });
108d7e89121SAlex Zinenko 
109d7e89121SAlex Zinenko   // Check that when a context is constructed with the given registry, the type
110d7e89121SAlex Zinenko   // interface gets registered.
111d7e89121SAlex Zinenko   MLIRContext context(registry);
112d7e89121SAlex Zinenko   context.loadDialect<test::TestDialect>();
113d7e89121SAlex Zinenko   test::TestType testType = test::TestType::get(&context);
1145550c821STres Popp   auto iface = dyn_cast<TestExternalTypeInterface>(testType);
115d7e89121SAlex Zinenko   ASSERT_TRUE(iface != nullptr);
116d7e89121SAlex Zinenko   EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
117d7e89121SAlex Zinenko   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
118d7e89121SAlex Zinenko }
119d7e89121SAlex Zinenko 
120d7e89121SAlex Zinenko TEST(InterfaceAttachment, TypeDelayedContextAppend) {
121d7e89121SAlex Zinenko   // Put the interface in the registry.
122d7e89121SAlex Zinenko   DialectRegistry registry;
123d7e89121SAlex Zinenko   registry.insert<test::TestDialect>();
12477eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
12577eee579SRiver Riddle     test::TestType::attachInterface<TestTypeModel>(*ctx);
12677eee579SRiver Riddle   });
127d7e89121SAlex Zinenko 
128d7e89121SAlex Zinenko   // Check that when the registry gets appended to the context, the interface
129d7e89121SAlex Zinenko   // becomes available for objects in loaded dialects.
130d7e89121SAlex Zinenko   MLIRContext context;
131d7e89121SAlex Zinenko   context.loadDialect<test::TestDialect>();
132d7e89121SAlex Zinenko   test::TestType testType = test::TestType::get(&context);
1335550c821STres Popp   EXPECT_FALSE(isa<TestExternalTypeInterface>(testType));
134d7e89121SAlex Zinenko   context.appendDialectRegistry(registry);
1355550c821STres Popp   EXPECT_TRUE(isa<TestExternalTypeInterface>(testType));
136d7e89121SAlex Zinenko }
137d7e89121SAlex Zinenko 
138d7e89121SAlex Zinenko TEST(InterfaceAttachment, RepeatedRegistration) {
139d7e89121SAlex Zinenko   DialectRegistry registry;
14077eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
14177eee579SRiver Riddle     IntegerType::attachInterface<Model>(*ctx);
14277eee579SRiver Riddle   });
143d7e89121SAlex Zinenko   MLIRContext context(registry);
144d7e89121SAlex Zinenko 
145d7e89121SAlex Zinenko   // Should't fail on repeated registration through the dialect registry.
146d7e89121SAlex Zinenko   context.appendDialectRegistry(registry);
147d7e89121SAlex Zinenko }
148d7e89121SAlex Zinenko 
149d7e89121SAlex Zinenko TEST(InterfaceAttachment, TypeBuiltinDelayed) {
150d7e89121SAlex Zinenko   // Builtin dialect needs to registration or loading, but delayed interface
151d7e89121SAlex Zinenko   // registration must still work.
152d7e89121SAlex Zinenko   DialectRegistry registry;
15377eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
15477eee579SRiver Riddle     IntegerType::attachInterface<Model>(*ctx);
15577eee579SRiver Riddle   });
156d7e89121SAlex Zinenko 
157d7e89121SAlex Zinenko   MLIRContext context(registry);
158d7e89121SAlex Zinenko   IntegerType i16 = IntegerType::get(&context, 16);
1595550c821STres Popp   EXPECT_TRUE(isa<TestExternalTypeInterface>(i16));
160d7e89121SAlex Zinenko 
161d7e89121SAlex Zinenko   MLIRContext initiallyEmpty;
162d7e89121SAlex Zinenko   IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
1635550c821STres Popp   EXPECT_FALSE(isa<TestExternalTypeInterface>(i32));
164d7e89121SAlex Zinenko   initiallyEmpty.appendDialectRegistry(registry);
1655550c821STres Popp   EXPECT_TRUE(isa<TestExternalTypeInterface>(i32));
166d7e89121SAlex Zinenko }
167d7e89121SAlex Zinenko 
1689b2a1bcfSAlex Zinenko /// The interface provides a default implementation that expects
1699b2a1bcfSAlex Zinenko /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
1709b2a1bcfSAlex Zinenko /// just derives from the ExternalModel.
1719b2a1bcfSAlex Zinenko struct TestExternalFallbackTypeIntegerModel
1729b2a1bcfSAlex Zinenko     : public TestExternalFallbackTypeInterface::ExternalModel<
1739b2a1bcfSAlex Zinenko           TestExternalFallbackTypeIntegerModel, IntegerType> {};
1749b2a1bcfSAlex Zinenko 
1759b2a1bcfSAlex Zinenko /// The interface provides a default implementation that expects
1769b2a1bcfSAlex Zinenko /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use
1779b2a1bcfSAlex Zinenko /// FallbackModel instead to override this and make sure the code still compiles
1789b2a1bcfSAlex Zinenko /// because we never instantiate the ExternalModel class template with a
1799b2a1bcfSAlex Zinenko /// template argument that would have led to compilation failures.
1809b2a1bcfSAlex Zinenko struct TestExternalFallbackTypeVectorModel
1819b2a1bcfSAlex Zinenko     : public TestExternalFallbackTypeInterface::FallbackModel<
1829b2a1bcfSAlex Zinenko           TestExternalFallbackTypeVectorModel> {
1839b2a1bcfSAlex Zinenko   unsigned getBitwidth(Type type) const {
1845550c821STres Popp     IntegerType elementType =
1855550c821STres Popp         dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType());
1869b2a1bcfSAlex Zinenko     return elementType ? elementType.getWidth() : 0;
1879b2a1bcfSAlex Zinenko   }
1889b2a1bcfSAlex Zinenko };
1899b2a1bcfSAlex Zinenko 
1909b2a1bcfSAlex Zinenko TEST(InterfaceAttachment, Fallback) {
1919b2a1bcfSAlex Zinenko   MLIRContext context;
1929b2a1bcfSAlex Zinenko 
1939b2a1bcfSAlex Zinenko   // Just check that we can attach the interface.
1949b2a1bcfSAlex Zinenko   IntegerType i8 = IntegerType::get(&context, 8);
1955550c821STres Popp   ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(i8));
1969b2a1bcfSAlex Zinenko   IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
1975550c821STres Popp   ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(i8));
1989b2a1bcfSAlex Zinenko 
1999b2a1bcfSAlex Zinenko   // Call the method so it is guaranteed not to be instantiated.
2009b2a1bcfSAlex Zinenko   VectorType vec = VectorType::get({42}, i8);
2015550c821STres Popp   ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(vec));
2029b2a1bcfSAlex Zinenko   VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
2035550c821STres Popp   ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(vec));
2045550c821STres Popp   EXPECT_EQ(cast<TestExternalFallbackTypeInterface>(vec).getBitwidth(), 8u);
2059b2a1bcfSAlex Zinenko }
2069b2a1bcfSAlex Zinenko 
2079b2a1bcfSAlex Zinenko /// External model for attribute interfaces.
208d7e89121SAlex Zinenko struct TestExternalIntegerAttrModel
2099b2a1bcfSAlex Zinenko     : public TestExternalAttrInterface::ExternalModel<
210d7e89121SAlex Zinenko           TestExternalIntegerAttrModel, IntegerAttr> {
2119b2a1bcfSAlex Zinenko   const Dialect *getDialectPtr(Attribute attr) const {
2125550c821STres Popp     return &cast<IntegerAttr>(attr).getDialect();
2139b2a1bcfSAlex Zinenko   }
2149b2a1bcfSAlex Zinenko 
2159b2a1bcfSAlex Zinenko   static int getSomeNumber() { return 42; }
2169b2a1bcfSAlex Zinenko };
2179b2a1bcfSAlex Zinenko 
2189b2a1bcfSAlex Zinenko TEST(InterfaceAttachment, Attribute) {
2199b2a1bcfSAlex Zinenko   MLIRContext context;
2209b2a1bcfSAlex Zinenko 
2219b2a1bcfSAlex Zinenko   // Attribute interfaces use the exact same mechanism as types, so just check
2229b2a1bcfSAlex Zinenko   // that the basics work for attributes.
2239b2a1bcfSAlex Zinenko   IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
2245550c821STres Popp   ASSERT_FALSE(isa<TestExternalAttrInterface>(attr));
225d7e89121SAlex Zinenko   IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
2265550c821STres Popp   auto iface = dyn_cast<TestExternalAttrInterface>(attr);
2279b2a1bcfSAlex Zinenko   ASSERT_TRUE(iface != nullptr);
2289b2a1bcfSAlex Zinenko   EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
2299b2a1bcfSAlex Zinenko   EXPECT_EQ(iface.getSomeNumber(), 42);
2309b2a1bcfSAlex Zinenko }
2319b2a1bcfSAlex Zinenko 
232d7e89121SAlex Zinenko /// External model for an interface attachable to a non-builtin attribute.
233d7e89121SAlex Zinenko struct TestExternalSimpleAAttrModel
234d7e89121SAlex Zinenko     : public TestExternalAttrInterface::ExternalModel<
235d7e89121SAlex Zinenko           TestExternalSimpleAAttrModel, test::SimpleAAttr> {
236d7e89121SAlex Zinenko   const Dialect *getDialectPtr(Attribute attr) const {
237d7e89121SAlex Zinenko     return &attr.getDialect();
238d7e89121SAlex Zinenko   }
239d7e89121SAlex Zinenko 
240d7e89121SAlex Zinenko   static int getSomeNumber() { return 21; }
241d7e89121SAlex Zinenko };
242d7e89121SAlex Zinenko 
243d7e89121SAlex Zinenko TEST(InterfaceAttachmentTest, AttributeDelayed) {
244d7e89121SAlex Zinenko   // Attribute interfaces use the exact same mechanism as types, so just check
245d7e89121SAlex Zinenko   // that the delayed registration work for attributes.
246d7e89121SAlex Zinenko   DialectRegistry registry;
247d7e89121SAlex Zinenko   registry.insert<test::TestDialect>();
24877eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
24977eee579SRiver Riddle     test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx);
25077eee579SRiver Riddle   });
251d7e89121SAlex Zinenko 
252d7e89121SAlex Zinenko   MLIRContext context(registry);
253d7e89121SAlex Zinenko   context.loadDialect<test::TestDialect>();
254d7e89121SAlex Zinenko   auto attr = test::SimpleAAttr::get(&context);
2555550c821STres Popp   EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
256d7e89121SAlex Zinenko 
257d7e89121SAlex Zinenko   MLIRContext initiallyEmpty;
258d7e89121SAlex Zinenko   initiallyEmpty.loadDialect<test::TestDialect>();
259d7e89121SAlex Zinenko   attr = test::SimpleAAttr::get(&initiallyEmpty);
2605550c821STres Popp   EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
261d7e89121SAlex Zinenko   initiallyEmpty.appendDialectRegistry(registry);
2625550c821STres Popp   EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
263d7e89121SAlex Zinenko }
264d7e89121SAlex Zinenko 
26523cdf7b6SAlex Zinenko /// External interface model for the module operation. Only provides non-default
26623cdf7b6SAlex Zinenko /// methods.
26723cdf7b6SAlex Zinenko struct TestExternalOpModel
26823cdf7b6SAlex Zinenko     : public TestExternalOpInterface::ExternalModel<TestExternalOpModel,
26923cdf7b6SAlex Zinenko                                                     ModuleOp> {
27023cdf7b6SAlex Zinenko   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
27123cdf7b6SAlex Zinenko     return op->getName().getStringRef().size() + arg;
27223cdf7b6SAlex Zinenko   }
27323cdf7b6SAlex Zinenko 
27423cdf7b6SAlex Zinenko   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
27523cdf7b6SAlex Zinenko     return ModuleOp::getOperationName().size() + 2 * arg;
27623cdf7b6SAlex Zinenko   }
27723cdf7b6SAlex Zinenko };
27823cdf7b6SAlex Zinenko 
27923cdf7b6SAlex Zinenko /// External interface model for the func operation. Provides non-deafult and
28023cdf7b6SAlex Zinenko /// overrides default methods.
28123cdf7b6SAlex Zinenko struct TestExternalOpOverridingModel
28223cdf7b6SAlex Zinenko     : public TestExternalOpInterface::FallbackModel<
28323cdf7b6SAlex Zinenko           TestExternalOpOverridingModel> {
28423cdf7b6SAlex Zinenko   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
28523cdf7b6SAlex Zinenko     return op->getName().getStringRef().size() + arg;
28623cdf7b6SAlex Zinenko   }
28723cdf7b6SAlex Zinenko 
28823cdf7b6SAlex Zinenko   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
28936550692SRiver Riddle     return UnrealizedConversionCastOp::getOperationName().size() + 2 * arg;
29023cdf7b6SAlex Zinenko   }
29123cdf7b6SAlex Zinenko 
29223cdf7b6SAlex Zinenko   unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const {
29323cdf7b6SAlex Zinenko     return 42;
29423cdf7b6SAlex Zinenko   }
29523cdf7b6SAlex Zinenko 
29623cdf7b6SAlex Zinenko   static unsigned getNameLengthMinusArg(unsigned arg) { return 21; }
29723cdf7b6SAlex Zinenko };
29823cdf7b6SAlex Zinenko 
29923cdf7b6SAlex Zinenko TEST(InterfaceAttachment, Operation) {
30023cdf7b6SAlex Zinenko   MLIRContext context;
30136550692SRiver Riddle   OpBuilder builder(&context);
30223cdf7b6SAlex Zinenko 
30323cdf7b6SAlex Zinenko   // Initially, the operation doesn't have the interface.
30436550692SRiver Riddle   OwningOpRef<ModuleOp> moduleOp =
30536550692SRiver Riddle       builder.create<ModuleOp>(UnknownLoc::get(&context));
306db79f4a2SMehdi Amini   ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation()));
30723cdf7b6SAlex Zinenko 
30823cdf7b6SAlex Zinenko   // We can attach an external interface and now the operaiton has it.
30923cdf7b6SAlex Zinenko   ModuleOp::attachInterface<TestExternalOpModel>(context);
310db79f4a2SMehdi Amini   auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation());
31123cdf7b6SAlex Zinenko   ASSERT_TRUE(iface != nullptr);
312f8479d9dSRiver Riddle   EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u);
313f8479d9dSRiver Riddle   EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u);
314f8479d9dSRiver Riddle   EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u);
315f8479d9dSRiver Riddle   EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u);
31623cdf7b6SAlex Zinenko 
31723cdf7b6SAlex Zinenko   // Default implementation can be overridden.
31836550692SRiver Riddle   OwningOpRef<UnrealizedConversionCastOp> castOp =
31936550692SRiver Riddle       builder.create<UnrealizedConversionCastOp>(UnknownLoc::get(&context),
32036550692SRiver Riddle                                                  TypeRange(), ValueRange());
32136550692SRiver Riddle   ASSERT_FALSE(isa<TestExternalOpInterface>(castOp->getOperation()));
32236550692SRiver Riddle   UnrealizedConversionCastOp::attachInterface<TestExternalOpOverridingModel>(
32336550692SRiver Riddle       context);
32436550692SRiver Riddle   iface = dyn_cast<TestExternalOpInterface>(castOp->getOperation());
32523cdf7b6SAlex Zinenko   ASSERT_TRUE(iface != nullptr);
32636550692SRiver Riddle   EXPECT_EQ(iface.getNameLengthPlusArg(10), 44u);
32723cdf7b6SAlex Zinenko   EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u);
32836550692SRiver Riddle   EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 50u);
32923cdf7b6SAlex Zinenko   EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u);
33023cdf7b6SAlex Zinenko 
33123cdf7b6SAlex Zinenko   // Another context doesn't have the interfaces registered.
33223cdf7b6SAlex Zinenko   MLIRContext other;
333db79f4a2SMehdi Amini   OwningOpRef<ModuleOp> otherModuleOp =
334db79f4a2SMehdi Amini       ModuleOp::create(UnknownLoc::get(&other));
335db79f4a2SMehdi Amini   ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation()));
33623cdf7b6SAlex Zinenko }
33723cdf7b6SAlex Zinenko 
3389b50844fSVladislav Vinogradov template <class ConcreteOp>
339d7e89121SAlex Zinenko struct TestExternalTestOpModel
3409b50844fSVladislav Vinogradov     : public TestExternalOpInterface::ExternalModel<
3419b50844fSVladislav Vinogradov           TestExternalTestOpModel<ConcreteOp>, ConcreteOp> {
342d7e89121SAlex Zinenko   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
343d7e89121SAlex Zinenko     return op->getName().getStringRef().size() + arg;
344d7e89121SAlex Zinenko   }
345d7e89121SAlex Zinenko 
346d7e89121SAlex Zinenko   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
3479b50844fSVladislav Vinogradov     return ConcreteOp::getOperationName().size() + 2 * arg;
348d7e89121SAlex Zinenko   }
349d7e89121SAlex Zinenko };
350d7e89121SAlex Zinenko 
351d7e89121SAlex Zinenko TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
352d7e89121SAlex Zinenko   DialectRegistry registry;
353d7e89121SAlex Zinenko   registry.insert<test::TestDialect>();
35477eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
35577eee579SRiver Riddle     ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
35677eee579SRiver Riddle   });
35777eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
35877eee579SRiver Riddle     test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
35977eee579SRiver Riddle     test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
36077eee579SRiver Riddle   });
361d7e89121SAlex Zinenko 
36277eee579SRiver Riddle   // Construct the context directly from a registry. The interfaces are
36377eee579SRiver Riddle   // expected to be readily available on operations.
364d7e89121SAlex Zinenko   MLIRContext context(registry);
365d7e89121SAlex Zinenko   context.loadDialect<test::TestDialect>();
3669b50844fSVladislav Vinogradov 
367db79f4a2SMehdi Amini   OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
368db79f4a2SMehdi Amini   OpBuilder builder(module->getBody(), module->getBody()->begin());
3699b50844fSVladislav Vinogradov   auto opJ =
370d7e89121SAlex Zinenko       builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
3719b50844fSVladislav Vinogradov   auto opH =
3729b50844fSVladislav Vinogradov       builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
3739b50844fSVladislav Vinogradov   auto opI =
3749b50844fSVladislav Vinogradov       builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
3759b50844fSVladislav Vinogradov 
376db79f4a2SMehdi Amini   EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
3779b50844fSVladislav Vinogradov   EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
3789b50844fSVladislav Vinogradov   EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
3799b50844fSVladislav Vinogradov   EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
380d7e89121SAlex Zinenko }
381d7e89121SAlex Zinenko 
382d7e89121SAlex Zinenko TEST(InterfaceAttachment, OperationDelayedContextAppend) {
383d7e89121SAlex Zinenko   DialectRegistry registry;
384d7e89121SAlex Zinenko   registry.insert<test::TestDialect>();
38577eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
38677eee579SRiver Riddle     ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
38777eee579SRiver Riddle   });
38877eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
38977eee579SRiver Riddle     test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
39077eee579SRiver Riddle     test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
39177eee579SRiver Riddle   });
392d7e89121SAlex Zinenko 
393d7e89121SAlex Zinenko   // Construct the context, create ops, and only then append the registry. The
394d7e89121SAlex Zinenko   // interfaces are expected to be available after appending the registry.
395d7e89121SAlex Zinenko   MLIRContext context;
396d7e89121SAlex Zinenko   context.loadDialect<test::TestDialect>();
3979b50844fSVladislav Vinogradov 
398db79f4a2SMehdi Amini   OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
399db79f4a2SMehdi Amini   OpBuilder builder(module->getBody(), module->getBody()->begin());
4009b50844fSVladislav Vinogradov   auto opJ =
401d7e89121SAlex Zinenko       builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
4029b50844fSVladislav Vinogradov   auto opH =
4039b50844fSVladislav Vinogradov       builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
4049b50844fSVladislav Vinogradov   auto opI =
4059b50844fSVladislav Vinogradov       builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
4069b50844fSVladislav Vinogradov 
407db79f4a2SMehdi Amini   EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation()));
4089b50844fSVladislav Vinogradov   EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation()));
4099b50844fSVladislav Vinogradov   EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation()));
4109b50844fSVladislav Vinogradov   EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
4119b50844fSVladislav Vinogradov 
412d7e89121SAlex Zinenko   context.appendDialectRegistry(registry);
4139b50844fSVladislav Vinogradov 
414db79f4a2SMehdi Amini   EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
4159b50844fSVladislav Vinogradov   EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
4169b50844fSVladislav Vinogradov   EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
4179b50844fSVladislav Vinogradov   EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
418d7e89121SAlex Zinenko }
419d7e89121SAlex Zinenko 
420*d0e6fd99SFabian Mora TEST(InterfaceAttachmentTest, PromisedInterfaces) {
421*d0e6fd99SFabian Mora   // Attribute interfaces use the exact same mechanism as types, so just check
422*d0e6fd99SFabian Mora   // that the promise mechanism works for attributes.
423*d0e6fd99SFabian Mora   MLIRContext context;
424*d0e6fd99SFabian Mora   auto testDialect = context.getOrLoadDialect<test::TestDialect>();
425*d0e6fd99SFabian Mora   auto attr = test::SimpleAAttr::get(&context);
426*d0e6fd99SFabian Mora 
427*d0e6fd99SFabian Mora   // `SimpleAAttr` doesn't implement nor promises the
428*d0e6fd99SFabian Mora   // `TestExternalAttrInterface` interface.
429*d0e6fd99SFabian Mora   EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
430*d0e6fd99SFabian Mora   EXPECT_FALSE(
431*d0e6fd99SFabian Mora       attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
432*d0e6fd99SFabian Mora 
433*d0e6fd99SFabian Mora   // Add a promise `TestExternalAttrInterface`.
434*d0e6fd99SFabian Mora   testDialect->declarePromisedInterface<test::SimpleAAttr,
435*d0e6fd99SFabian Mora                                         TestExternalAttrInterface>();
436*d0e6fd99SFabian Mora   EXPECT_TRUE(
437*d0e6fd99SFabian Mora       attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
438*d0e6fd99SFabian Mora 
439*d0e6fd99SFabian Mora   // Attach the interface.
440*d0e6fd99SFabian Mora   test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context);
441*d0e6fd99SFabian Mora   EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
442*d0e6fd99SFabian Mora   EXPECT_TRUE(
443*d0e6fd99SFabian Mora       attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
444*d0e6fd99SFabian Mora }
445*d0e6fd99SFabian Mora 
446be0a7e9fSMehdi Amini } // namespace
447