xref: /llvm-project/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp (revision 1d7b5cd5bf8cfe2593109ace361ad37ec3b54a1f)
1 //===- DataLayoutInterfacesTest.cpp - Unit Tests for Data Layouts ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/DataLayoutInterfaces.h"
10 #include "mlir/Dialect/DLTI/DLTI.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/Parser/Parser.h"
18 
19 #include <gtest/gtest.h>
20 
21 using namespace mlir;
22 
23 namespace {
24 constexpr static llvm::StringLiteral kAttrName = "dltest.layout";
25 
26 /// Trivial array storage for the custom data layout spec attribute, just a list
27 /// of entries.
28 class DataLayoutSpecStorage : public AttributeStorage {
29 public:
30   using KeyTy = ArrayRef<DataLayoutEntryInterface>;
31 
32   DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries)
33       : entries(entries) {}
34 
35   bool operator==(const KeyTy &key) const { return key == entries; }
36 
37   static DataLayoutSpecStorage *construct(AttributeStorageAllocator &allocator,
38                                           const KeyTy &key) {
39     return new (allocator.allocate<DataLayoutSpecStorage>())
40         DataLayoutSpecStorage(allocator.copyInto(key));
41   }
42 
43   ArrayRef<DataLayoutEntryInterface> entries;
44 };
45 
46 /// Simple data layout spec containing a list of entries that always verifies
47 /// as valid.
48 struct CustomDataLayoutSpec
49     : public Attribute::AttrBase<CustomDataLayoutSpec, Attribute,
50                                  DataLayoutSpecStorage,
51                                  DataLayoutSpecInterface::Trait> {
52   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec)
53 
54   using Base::Base;
55   static CustomDataLayoutSpec get(MLIRContext *ctx,
56                                   ArrayRef<DataLayoutEntryInterface> entries) {
57     return Base::get(ctx, entries);
58   }
59   CustomDataLayoutSpec
60   combineWith(ArrayRef<DataLayoutSpecInterface> specs) const {
61     return *this;
62   }
63   DataLayoutEntryListRef getEntries() const { return getImpl()->entries; }
64   LogicalResult verifySpec(Location loc) { return success(); }
65 };
66 
67 /// A type subject to data layout that exits the program if it is queried more
68 /// than once. Handy to check if the cache works.
69 struct SingleQueryType
70     : public Type::TypeBase<SingleQueryType, Type, TypeStorage,
71                             DataLayoutTypeInterface::Trait> {
72   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SingleQueryType)
73 
74   using Base::Base;
75 
76   static SingleQueryType get(MLIRContext *ctx) { return Base::get(ctx); }
77 
78   unsigned getTypeSizeInBits(const DataLayout &layout,
79                              DataLayoutEntryListRef params) const {
80     static bool executed = false;
81     if (executed)
82       llvm::report_fatal_error("repeated call");
83 
84     executed = true;
85     return 1;
86   }
87 
88   unsigned getABIAlignment(const DataLayout &layout,
89                            DataLayoutEntryListRef params) {
90     static bool executed = false;
91     if (executed)
92       llvm::report_fatal_error("repeated call");
93 
94     executed = true;
95     return 2;
96   }
97 
98   unsigned getPreferredAlignment(const DataLayout &layout,
99                                  DataLayoutEntryListRef params) {
100     static bool executed = false;
101     if (executed)
102       llvm::report_fatal_error("repeated call");
103 
104     executed = true;
105     return 4;
106   }
107 };
108 
109 /// A types that is not subject to data layout.
110 struct TypeNoLayout : public Type::TypeBase<TypeNoLayout, Type, TypeStorage> {
111   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TypeNoLayout)
112 
113   using Base::Base;
114 
115   static TypeNoLayout get(MLIRContext *ctx) { return Base::get(ctx); }
116 };
117 
118 /// An op that serves as scope for data layout queries with the relevant
119 /// attribute attached. This can handle data layout requests for the built-in
120 /// types itself.
121 struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
122   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithLayout)
123 
124   using Op::Op;
125   static ArrayRef<StringRef> getAttributeNames() { return {}; }
126 
127   static StringRef getOperationName() { return "dltest.op_with_layout"; }
128 
129   DataLayoutSpecInterface getDataLayoutSpec() {
130     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
131   }
132 
133   static unsigned getTypeSizeInBits(Type type, const DataLayout &dataLayout,
134                                     DataLayoutEntryListRef params) {
135     // Make a recursive query.
136     if (type.isa<FloatType>())
137       return dataLayout.getTypeSizeInBits(
138           IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth()));
139 
140     // Handle built-in types that are not handled by the default process.
141     if (auto iType = type.dyn_cast<IntegerType>()) {
142       for (DataLayoutEntryInterface entry : params)
143         if (entry.getKey().dyn_cast<Type>() == type)
144           return 8 *
145                  entry.getValue().cast<IntegerAttr>().getValue().getZExtValue();
146       return 8 * iType.getIntOrFloatBitWidth();
147     }
148 
149     // Use the default process for everything else.
150     return detail::getDefaultTypeSize(type, dataLayout, params);
151   }
152 
153   static unsigned getTypeABIAlignment(Type type, const DataLayout &dataLayout,
154                                       DataLayoutEntryListRef params) {
155     return llvm::PowerOf2Ceil(getTypeSize(type, dataLayout, params));
156   }
157 
158   static unsigned getTypePreferredAlignment(Type type,
159                                             const DataLayout &dataLayout,
160                                             DataLayoutEntryListRef params) {
161     return 2 * getTypeABIAlignment(type, dataLayout, params);
162   }
163 };
164 
165 struct OpWith7BitByte
166     : public Op<OpWith7BitByte, DataLayoutOpInterface::Trait> {
167   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWith7BitByte)
168 
169   using Op::Op;
170   static ArrayRef<StringRef> getAttributeNames() { return {}; }
171 
172   static StringRef getOperationName() { return "dltest.op_with_7bit_byte"; }
173 
174   DataLayoutSpecInterface getDataLayoutSpec() {
175     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
176   }
177 
178   // Bytes are assumed to be 7-bit here.
179   static unsigned getTypeSize(Type type, const DataLayout &dataLayout,
180                               DataLayoutEntryListRef params) {
181     return llvm::divideCeil(dataLayout.getTypeSizeInBits(type), 7);
182   }
183 };
184 
185 /// A dialect putting all the above together.
186 struct DLTestDialect : Dialect {
187   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTestDialect)
188 
189   explicit DLTestDialect(MLIRContext *ctx)
190       : Dialect(getDialectNamespace(), ctx, TypeID::get<DLTestDialect>()) {
191     ctx->getOrLoadDialect<DLTIDialect>();
192     addAttributes<CustomDataLayoutSpec>();
193     addOperations<OpWithLayout, OpWith7BitByte>();
194     addTypes<SingleQueryType, TypeNoLayout>();
195   }
196   static StringRef getDialectNamespace() { return "dltest"; }
197 
198   void printAttribute(Attribute attr,
199                       DialectAsmPrinter &printer) const override {
200     printer << "spec<";
201     llvm::interleaveComma(attr.cast<CustomDataLayoutSpec>().getEntries(),
202                           printer);
203     printer << ">";
204   }
205 
206   Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
207     bool ok =
208         succeeded(parser.parseKeyword("spec")) && succeeded(parser.parseLess());
209     (void)ok;
210     assert(ok);
211     if (succeeded(parser.parseOptionalGreater()))
212       return CustomDataLayoutSpec::get(parser.getContext(), {});
213 
214     SmallVector<DataLayoutEntryInterface> entries;
215     ok = succeeded(parser.parseCommaSeparatedList([&]() {
216       entries.emplace_back();
217       ok = succeeded(parser.parseAttribute(entries.back()));
218       assert(ok);
219       return success();
220     }));
221     assert(ok);
222     ok = succeeded(parser.parseGreater());
223     assert(ok);
224     return CustomDataLayoutSpec::get(parser.getContext(), entries);
225   }
226 
227   void printType(Type type, DialectAsmPrinter &printer) const override {
228     if (type.isa<SingleQueryType>())
229       printer << "single_query";
230     else
231       printer << "no_layout";
232   }
233 
234   Type parseType(DialectAsmParser &parser) const override {
235     bool ok = succeeded(parser.parseKeyword("single_query"));
236     (void)ok;
237     assert(ok);
238     return SingleQueryType::get(parser.getContext());
239   }
240 };
241 
242 } // namespace
243 
244 TEST(DataLayout, FallbackDefault) {
245   const char *ir = R"MLIR(
246 module {}
247   )MLIR";
248 
249   DialectRegistry registry;
250   registry.insert<DLTIDialect, DLTestDialect>();
251   MLIRContext ctx(registry);
252 
253   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
254   DataLayout layout(module.get());
255   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
256   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 2u);
257   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
258   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 16u);
259   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u);
260   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 2u);
261   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 8u);
262   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u);
263 }
264 
265 TEST(DataLayout, NullSpec) {
266   const char *ir = R"MLIR(
267 "dltest.op_with_layout"() : () -> ()
268   )MLIR";
269 
270   DialectRegistry registry;
271   registry.insert<DLTIDialect, DLTestDialect>();
272   MLIRContext ctx(registry);
273 
274   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
275   auto op =
276       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
277   DataLayout layout(op);
278   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
279   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
280   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
281   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
282   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
283   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
284   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
285   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
286 }
287 
288 TEST(DataLayout, EmptySpec) {
289   const char *ir = R"MLIR(
290 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< > } : () -> ()
291   )MLIR";
292 
293   DialectRegistry registry;
294   registry.insert<DLTIDialect, DLTestDialect>();
295   MLIRContext ctx(registry);
296 
297   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
298   auto op =
299       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
300   DataLayout layout(op);
301   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
302   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
303   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
304   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
305   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
306   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
307   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
308   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
309 }
310 
311 TEST(DataLayout, SpecWithEntries) {
312   const char *ir = R"MLIR(
313 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<
314   #dlti.dl_entry<i42, 5>,
315   #dlti.dl_entry<i16, 6>
316 > } : () -> ()
317   )MLIR";
318 
319   DialectRegistry registry;
320   registry.insert<DLTIDialect, DLTestDialect>();
321   MLIRContext ctx(registry);
322 
323   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
324   auto op =
325       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
326   DataLayout layout(op);
327   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 5u);
328   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u);
329   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 40u);
330   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 48u);
331   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u);
332   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 8u);
333   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 16u);
334   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 16u);
335 
336   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 32u);
337   EXPECT_EQ(layout.getTypeSize(Float32Type::get(&ctx)), 32u);
338   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 256u);
339   EXPECT_EQ(layout.getTypeSizeInBits(Float32Type::get(&ctx)), 256u);
340   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 32)), 32u);
341   EXPECT_EQ(layout.getTypeABIAlignment(Float32Type::get(&ctx)), 32u);
342   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 32)), 64u);
343   EXPECT_EQ(layout.getTypePreferredAlignment(Float32Type::get(&ctx)), 64u);
344 }
345 
346 TEST(DataLayout, Caching) {
347   const char *ir = R"MLIR(
348 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
349   )MLIR";
350 
351   DialectRegistry registry;
352   registry.insert<DLTIDialect, DLTestDialect>();
353   MLIRContext ctx(registry);
354 
355   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
356   auto op =
357       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
358   DataLayout layout(op);
359 
360   unsigned sum = 0;
361   sum += layout.getTypeSize(SingleQueryType::get(&ctx));
362   // The second call should hit the cache. If it does not, the function in
363   // SingleQueryType will be called and will abort the process.
364   sum += layout.getTypeSize(SingleQueryType::get(&ctx));
365   // Make sure the complier doesn't optimize away the query code.
366   EXPECT_EQ(sum, 2u);
367 
368   // A fresh data layout has a new cache, so the call to it should be dispatched
369   // down to the type and abort the proces.
370   DataLayout second(op);
371   ASSERT_DEATH(second.getTypeSize(SingleQueryType::get(&ctx)), "repeated call");
372 }
373 
374 TEST(DataLayout, CacheInvalidation) {
375   const char *ir = R"MLIR(
376 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<
377   #dlti.dl_entry<i42, 5>,
378   #dlti.dl_entry<i16, 6>
379 > } : () -> ()
380   )MLIR";
381 
382   DialectRegistry registry;
383   registry.insert<DLTIDialect, DLTestDialect>();
384   MLIRContext ctx(registry);
385 
386   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
387   auto op =
388       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
389   DataLayout layout(op);
390 
391   // Normal query is fine.
392   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u);
393 
394   // Replace the data layout spec with a new, empty spec.
395   op->setAttr(kAttrName, CustomDataLayoutSpec::get(&ctx, {}));
396 
397   // Data layout is no longer valid and should trigger assertion when queried.
398 #ifndef NDEBUG
399   ASSERT_DEATH(layout.getTypeSize(Float16Type::get(&ctx)), "no longer valid");
400 #endif
401 }
402 
403 TEST(DataLayout, UnimplementedTypeInterface) {
404   const char *ir = R"MLIR(
405 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
406   )MLIR";
407 
408   DialectRegistry registry;
409   registry.insert<DLTIDialect, DLTestDialect>();
410   MLIRContext ctx(registry);
411 
412   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
413   auto op =
414       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
415   DataLayout layout(op);
416 
417   ASSERT_DEATH(layout.getTypeSize(TypeNoLayout::get(&ctx)),
418                "neither the scoping op nor the type class provide data layout "
419                "information");
420 }
421 
422 TEST(DataLayout, SevenBitByte) {
423   const char *ir = R"MLIR(
424 "dltest.op_with_7bit_byte"() { dltest.layout = #dltest.spec<> } : () -> ()
425   )MLIR";
426 
427   DialectRegistry registry;
428   registry.insert<DLTIDialect, DLTestDialect>();
429   MLIRContext ctx(registry);
430 
431   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
432   auto op =
433       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
434   DataLayout layout(op);
435 
436   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
437   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 32u);
438   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
439   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 5u);
440 }
441