xref: /llvm-project/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp (revision 21646789497346a1a8dabb4b369e12db482b4daa)
1 //===- DataLayoutInterfacesTest.cpp - Unit Tests for Data Layouts ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/DataLayoutInterfaces.h"
10 #include "mlir/Dialect/DLTI/DLTI.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/Parser/Parser.h"
18 
19 #include <gtest/gtest.h>
20 
21 using namespace mlir;
22 
23 namespace {
24 constexpr static llvm::StringLiteral kAttrName = "dltest.layout";
25 constexpr static llvm::StringLiteral kAllocaKeyName =
26     "dltest.alloca_memory_space";
27 constexpr static llvm::StringLiteral kStackAlignmentKeyName =
28     "dltest.stack_alignment";
29 
30 /// Trivial array storage for the custom data layout spec attribute, just a list
31 /// of entries.
32 class DataLayoutSpecStorage : public AttributeStorage {
33 public:
34   using KeyTy = ArrayRef<DataLayoutEntryInterface>;
35 
36   DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries)
37       : entries(entries) {}
38 
39   bool operator==(const KeyTy &key) const { return key == entries; }
40 
41   static DataLayoutSpecStorage *construct(AttributeStorageAllocator &allocator,
42                                           const KeyTy &key) {
43     return new (allocator.allocate<DataLayoutSpecStorage>())
44         DataLayoutSpecStorage(allocator.copyInto(key));
45   }
46 
47   ArrayRef<DataLayoutEntryInterface> entries;
48 };
49 
50 /// Simple data layout spec containing a list of entries that always verifies
51 /// as valid.
52 struct CustomDataLayoutSpec
53     : public Attribute::AttrBase<CustomDataLayoutSpec, Attribute,
54                                  DataLayoutSpecStorage,
55                                  DataLayoutSpecInterface::Trait> {
56   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec)
57 
58   using Base::Base;
59   static CustomDataLayoutSpec get(MLIRContext *ctx,
60                                   ArrayRef<DataLayoutEntryInterface> entries) {
61     return Base::get(ctx, entries);
62   }
63   CustomDataLayoutSpec
64   combineWith(ArrayRef<DataLayoutSpecInterface> specs) const {
65     return *this;
66   }
67   DataLayoutEntryListRef getEntries() const { return getImpl()->entries; }
68   LogicalResult verifySpec(Location loc) { return success(); }
69   StringAttr getAllocaMemorySpaceIdentifier(MLIRContext *context) const {
70     return Builder(context).getStringAttr(kAllocaKeyName);
71   }
72   StringAttr getStackAlignmentIdentifier(MLIRContext *context) const {
73     return Builder(context).getStringAttr(kStackAlignmentKeyName);
74   }
75 };
76 
77 /// A type subject to data layout that exits the program if it is queried more
78 /// than once. Handy to check if the cache works.
79 struct SingleQueryType
80     : public Type::TypeBase<SingleQueryType, Type, TypeStorage,
81                             DataLayoutTypeInterface::Trait> {
82   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SingleQueryType)
83 
84   using Base::Base;
85 
86   static SingleQueryType get(MLIRContext *ctx) { return Base::get(ctx); }
87 
88   llvm::TypeSize getTypeSizeInBits(const DataLayout &layout,
89                                    DataLayoutEntryListRef params) const {
90     static bool executed = false;
91     if (executed)
92       llvm::report_fatal_error("repeated call");
93 
94     executed = true;
95     return llvm::TypeSize::getFixed(1);
96   }
97 
98   uint64_t getABIAlignment(const DataLayout &layout,
99                            DataLayoutEntryListRef params) {
100     static bool executed = false;
101     if (executed)
102       llvm::report_fatal_error("repeated call");
103 
104     executed = true;
105     return 2;
106   }
107 
108   uint64_t getPreferredAlignment(const DataLayout &layout,
109                                  DataLayoutEntryListRef params) {
110     static bool executed = false;
111     if (executed)
112       llvm::report_fatal_error("repeated call");
113 
114     executed = true;
115     return 4;
116   }
117 
118   Attribute getAllocaMemorySpace(DataLayoutEntryInterface entry) {
119     static bool executed = false;
120     if (executed)
121       llvm::report_fatal_error("repeated call");
122 
123     executed = true;
124     return Attribute();
125   }
126 };
127 
128 /// A types that is not subject to data layout.
129 struct TypeNoLayout : public Type::TypeBase<TypeNoLayout, Type, TypeStorage> {
130   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TypeNoLayout)
131 
132   using Base::Base;
133 
134   static TypeNoLayout get(MLIRContext *ctx) { return Base::get(ctx); }
135 };
136 
137 /// An op that serves as scope for data layout queries with the relevant
138 /// attribute attached. This can handle data layout requests for the built-in
139 /// types itself.
140 struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
141   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithLayout)
142 
143   using Op::Op;
144   static ArrayRef<StringRef> getAttributeNames() { return {}; }
145 
146   static StringRef getOperationName() { return "dltest.op_with_layout"; }
147 
148   DataLayoutSpecInterface getDataLayoutSpec() {
149     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
150   }
151 
152   static llvm::TypeSize getTypeSizeInBits(Type type,
153                                           const DataLayout &dataLayout,
154                                           DataLayoutEntryListRef params) {
155     // Make a recursive query.
156     if (isa<FloatType>(type))
157       return dataLayout.getTypeSizeInBits(
158           IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth()));
159 
160     // Handle built-in types that are not handled by the default process.
161     if (auto iType = dyn_cast<IntegerType>(type)) {
162       for (DataLayoutEntryInterface entry : params)
163         if (llvm::dyn_cast_if_present<Type>(entry.getKey()) == type)
164           return llvm::TypeSize::getFixed(
165               8 *
166               cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue());
167       return llvm::TypeSize::getFixed(8 * iType.getIntOrFloatBitWidth());
168     }
169 
170     // Use the default process for everything else.
171     return detail::getDefaultTypeSize(type, dataLayout, params);
172   }
173 
174   static uint64_t getTypeABIAlignment(Type type, const DataLayout &dataLayout,
175                                       DataLayoutEntryListRef params) {
176     return llvm::PowerOf2Ceil(getTypeSize(type, dataLayout, params));
177   }
178 
179   static uint64_t getTypePreferredAlignment(Type type,
180                                             const DataLayout &dataLayout,
181                                             DataLayoutEntryListRef params) {
182     return 2 * getTypeABIAlignment(type, dataLayout, params);
183   }
184 };
185 
186 struct OpWith7BitByte
187     : public Op<OpWith7BitByte, DataLayoutOpInterface::Trait> {
188   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWith7BitByte)
189 
190   using Op::Op;
191   static ArrayRef<StringRef> getAttributeNames() { return {}; }
192 
193   static StringRef getOperationName() { return "dltest.op_with_7bit_byte"; }
194 
195   DataLayoutSpecInterface getDataLayoutSpec() {
196     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
197   }
198 
199   // Bytes are assumed to be 7-bit here.
200   static llvm::TypeSize getTypeSize(Type type, const DataLayout &dataLayout,
201                                     DataLayoutEntryListRef params) {
202     return mlir::detail::divideCeil(dataLayout.getTypeSizeInBits(type), 7);
203   }
204 };
205 
206 /// A dialect putting all the above together.
207 struct DLTestDialect : Dialect {
208   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTestDialect)
209 
210   explicit DLTestDialect(MLIRContext *ctx)
211       : Dialect(getDialectNamespace(), ctx, TypeID::get<DLTestDialect>()) {
212     ctx->getOrLoadDialect<DLTIDialect>();
213     addAttributes<CustomDataLayoutSpec>();
214     addOperations<OpWithLayout, OpWith7BitByte>();
215     addTypes<SingleQueryType, TypeNoLayout>();
216   }
217   static StringRef getDialectNamespace() { return "dltest"; }
218 
219   void printAttribute(Attribute attr,
220                       DialectAsmPrinter &printer) const override {
221     printer << "spec<";
222     llvm::interleaveComma(cast<CustomDataLayoutSpec>(attr).getEntries(),
223                           printer);
224     printer << ">";
225   }
226 
227   Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
228     bool ok =
229         succeeded(parser.parseKeyword("spec")) && succeeded(parser.parseLess());
230     (void)ok;
231     assert(ok);
232     if (succeeded(parser.parseOptionalGreater()))
233       return CustomDataLayoutSpec::get(parser.getContext(), {});
234 
235     SmallVector<DataLayoutEntryInterface> entries;
236     ok = succeeded(parser.parseCommaSeparatedList([&]() {
237       entries.emplace_back();
238       ok = succeeded(parser.parseAttribute(entries.back()));
239       assert(ok);
240       return success();
241     }));
242     assert(ok);
243     ok = succeeded(parser.parseGreater());
244     assert(ok);
245     return CustomDataLayoutSpec::get(parser.getContext(), entries);
246   }
247 
248   void printType(Type type, DialectAsmPrinter &printer) const override {
249     if (isa<SingleQueryType>(type))
250       printer << "single_query";
251     else
252       printer << "no_layout";
253   }
254 
255   Type parseType(DialectAsmParser &parser) const override {
256     bool ok = succeeded(parser.parseKeyword("single_query"));
257     (void)ok;
258     assert(ok);
259     return SingleQueryType::get(parser.getContext());
260   }
261 };
262 
263 } // namespace
264 
265 TEST(DataLayout, FallbackDefault) {
266   const char *ir = R"MLIR(
267 module {}
268   )MLIR";
269 
270   DialectRegistry registry;
271   registry.insert<DLTIDialect, DLTestDialect>();
272   MLIRContext ctx(registry);
273 
274   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
275   DataLayout layout(module.get());
276   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
277   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 2u);
278   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
279   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 16u);
280   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u);
281   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 2u);
282   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 8u);
283   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u);
284 
285   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
286   EXPECT_EQ(layout.getStackAlignment(), 0u);
287 }
288 
289 TEST(DataLayout, NullSpec) {
290   const char *ir = R"MLIR(
291 "dltest.op_with_layout"() : () -> ()
292   )MLIR";
293 
294   DialectRegistry registry;
295   registry.insert<DLTIDialect, DLTestDialect>();
296   MLIRContext ctx(registry);
297 
298   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
299   auto op =
300       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
301   DataLayout layout(op);
302 
303   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
304   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
305   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
306   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
307   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
308   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
309   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
310   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
311 
312   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
313   EXPECT_EQ(layout.getStackAlignment(), 0u);
314 }
315 
316 TEST(DataLayout, EmptySpec) {
317   const char *ir = R"MLIR(
318 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< > } : () -> ()
319   )MLIR";
320 
321   DialectRegistry registry;
322   registry.insert<DLTIDialect, DLTestDialect>();
323   MLIRContext ctx(registry);
324 
325   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
326   auto op =
327       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
328   DataLayout layout(op);
329   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
330   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
331   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
332   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
333   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
334   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
335   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
336   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
337 
338   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
339   EXPECT_EQ(layout.getStackAlignment(), 0u);
340 }
341 
342 TEST(DataLayout, SpecWithEntries) {
343   const char *ir = R"MLIR(
344 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<
345   #dlti.dl_entry<i42, 5>,
346   #dlti.dl_entry<i16, 6>,
347   #dlti.dl_entry<"dltest.alloca_memory_space", 5 : i32>,
348   #dlti.dl_entry<"dltest.stack_alignment", 128 : i32>
349 > } : () -> ()
350   )MLIR";
351 
352   DialectRegistry registry;
353   registry.insert<DLTIDialect, DLTestDialect>();
354   MLIRContext ctx(registry);
355 
356   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
357   auto op =
358       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
359   DataLayout layout(op);
360   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 5u);
361   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u);
362   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 40u);
363   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 48u);
364   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u);
365   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 8u);
366   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 16u);
367   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 16u);
368 
369   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 32u);
370   EXPECT_EQ(layout.getTypeSize(Float32Type::get(&ctx)), 32u);
371   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 256u);
372   EXPECT_EQ(layout.getTypeSizeInBits(Float32Type::get(&ctx)), 256u);
373   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 32)), 32u);
374   EXPECT_EQ(layout.getTypeABIAlignment(Float32Type::get(&ctx)), 32u);
375   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 32)), 64u);
376   EXPECT_EQ(layout.getTypePreferredAlignment(Float32Type::get(&ctx)), 64u);
377 
378   EXPECT_EQ(layout.getAllocaMemorySpace(), Builder(&ctx).getI32IntegerAttr(5));
379   EXPECT_EQ(layout.getStackAlignment(), 128u);
380 }
381 
382 TEST(DataLayout, Caching) {
383   const char *ir = R"MLIR(
384 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
385   )MLIR";
386 
387   DialectRegistry registry;
388   registry.insert<DLTIDialect, DLTestDialect>();
389   MLIRContext ctx(registry);
390 
391   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
392   auto op =
393       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
394   DataLayout layout(op);
395 
396   unsigned sum = 0;
397   sum += layout.getTypeSize(SingleQueryType::get(&ctx));
398   // The second call should hit the cache. If it does not, the function in
399   // SingleQueryType will be called and will abort the process.
400   sum += layout.getTypeSize(SingleQueryType::get(&ctx));
401   // Make sure the complier doesn't optimize away the query code.
402   EXPECT_EQ(sum, 2u);
403 
404   // A fresh data layout has a new cache, so the call to it should be dispatched
405   // down to the type and abort the process.
406   DataLayout second(op);
407   ASSERT_DEATH(second.getTypeSize(SingleQueryType::get(&ctx)), "repeated call");
408 }
409 
410 TEST(DataLayout, CacheInvalidation) {
411   const char *ir = R"MLIR(
412 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<
413   #dlti.dl_entry<i42, 5>,
414   #dlti.dl_entry<i16, 6>
415 > } : () -> ()
416   )MLIR";
417 
418   DialectRegistry registry;
419   registry.insert<DLTIDialect, DLTestDialect>();
420   MLIRContext ctx(registry);
421 
422   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
423   auto op =
424       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
425   DataLayout layout(op);
426 
427   // Normal query is fine.
428   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u);
429 
430   // Replace the data layout spec with a new, empty spec.
431   op->setAttr(kAttrName, CustomDataLayoutSpec::get(&ctx, {}));
432 
433   // Data layout is no longer valid and should trigger assertion when queried.
434 #ifndef NDEBUG
435   ASSERT_DEATH(layout.getTypeSize(Float16Type::get(&ctx)), "no longer valid");
436 #endif
437 }
438 
439 TEST(DataLayout, UnimplementedTypeInterface) {
440   const char *ir = R"MLIR(
441 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
442   )MLIR";
443 
444   DialectRegistry registry;
445   registry.insert<DLTIDialect, DLTestDialect>();
446   MLIRContext ctx(registry);
447 
448   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
449   auto op =
450       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
451   DataLayout layout(op);
452 
453   ASSERT_DEATH(layout.getTypeSize(TypeNoLayout::get(&ctx)),
454                "neither the scoping op nor the type class provide data layout "
455                "information");
456 }
457 
458 TEST(DataLayout, SevenBitByte) {
459   const char *ir = R"MLIR(
460 "dltest.op_with_7bit_byte"() { dltest.layout = #dltest.spec<> } : () -> ()
461   )MLIR";
462 
463   DialectRegistry registry;
464   registry.insert<DLTIDialect, DLTestDialect>();
465   MLIRContext ctx(registry);
466 
467   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
468   auto op =
469       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
470   DataLayout layout(op);
471 
472   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
473   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 32u);
474   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
475   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 5u);
476 }
477