xref: /llvm-project/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp (revision 9d69bca18787e0861c1bc6db8e82d4f21b77ae75)
1 //===- DataLayoutInterfacesTest.cpp - Unit Tests for Data Layouts ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/DataLayoutInterfaces.h"
10 #include "mlir/Dialect/DLTI/DLTI.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/Parser/Parser.h"
18 
19 #include <gtest/gtest.h>
20 
21 using namespace mlir;
22 
23 namespace {
24 constexpr static llvm::StringLiteral kAttrName = "dltest.layout";
25 constexpr static llvm::StringLiteral kAllocaKeyName =
26     "dltest.alloca_memory_space";
27 constexpr static llvm::StringLiteral kStackAlignmentKeyName =
28     "dltest.stack_alignment";
29 
30 /// Trivial array storage for the custom data layout spec attribute, just a list
31 /// of entries.
32 class DataLayoutSpecStorage : public AttributeStorage {
33 public:
34   using KeyTy = ArrayRef<DataLayoutEntryInterface>;
35 
36   DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries)
37       : entries(entries) {}
38 
39   bool operator==(const KeyTy &key) const { return key == entries; }
40 
41   static DataLayoutSpecStorage *construct(AttributeStorageAllocator &allocator,
42                                           const KeyTy &key) {
43     return new (allocator.allocate<DataLayoutSpecStorage>())
44         DataLayoutSpecStorage(allocator.copyInto(key));
45   }
46 
47   ArrayRef<DataLayoutEntryInterface> entries;
48 };
49 
50 /// Simple data layout spec containing a list of entries that always verifies
51 /// as valid.
52 struct CustomDataLayoutSpec
53     : public Attribute::AttrBase<CustomDataLayoutSpec, Attribute,
54                                  DataLayoutSpecStorage,
55                                  DataLayoutSpecInterface::Trait> {
56   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec)
57 
58   using Base::Base;
59   static CustomDataLayoutSpec get(MLIRContext *ctx,
60                                   ArrayRef<DataLayoutEntryInterface> entries) {
61     return Base::get(ctx, entries);
62   }
63   CustomDataLayoutSpec
64   combineWith(ArrayRef<DataLayoutSpecInterface> specs) const {
65     return *this;
66   }
67   DataLayoutEntryListRef getEntries() const { return getImpl()->entries; }
68   LogicalResult verifySpec(Location loc) { return success(); }
69   StringAttr getAllocaMemorySpaceIdentifier(MLIRContext *context) const {
70     return Builder(context).getStringAttr(kAllocaKeyName);
71   }
72   StringAttr getStackAlignmentIdentifier(MLIRContext *context) const {
73     return Builder(context).getStringAttr(kStackAlignmentKeyName);
74   }
75 };
76 
77 /// A type subject to data layout that exits the program if it is queried more
78 /// than once. Handy to check if the cache works.
79 struct SingleQueryType
80     : public Type::TypeBase<SingleQueryType, Type, TypeStorage,
81                             DataLayoutTypeInterface::Trait> {
82   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SingleQueryType)
83 
84   using Base::Base;
85 
86   static SingleQueryType get(MLIRContext *ctx) { return Base::get(ctx); }
87 
88   unsigned getTypeSizeInBits(const DataLayout &layout,
89                              DataLayoutEntryListRef params) const {
90     static bool executed = false;
91     if (executed)
92       llvm::report_fatal_error("repeated call");
93 
94     executed = true;
95     return 1;
96   }
97 
98   unsigned getABIAlignment(const DataLayout &layout,
99                            DataLayoutEntryListRef params) {
100     static bool executed = false;
101     if (executed)
102       llvm::report_fatal_error("repeated call");
103 
104     executed = true;
105     return 2;
106   }
107 
108   unsigned getPreferredAlignment(const DataLayout &layout,
109                                  DataLayoutEntryListRef params) {
110     static bool executed = false;
111     if (executed)
112       llvm::report_fatal_error("repeated call");
113 
114     executed = true;
115     return 4;
116   }
117 
118   Attribute getAllocaMemorySpace(DataLayoutEntryInterface entry) {
119     static bool executed = false;
120     if (executed)
121       llvm::report_fatal_error("repeated call");
122 
123     executed = true;
124     return Attribute();
125   }
126 };
127 
128 /// A types that is not subject to data layout.
129 struct TypeNoLayout : public Type::TypeBase<TypeNoLayout, Type, TypeStorage> {
130   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TypeNoLayout)
131 
132   using Base::Base;
133 
134   static TypeNoLayout get(MLIRContext *ctx) { return Base::get(ctx); }
135 };
136 
137 /// An op that serves as scope for data layout queries with the relevant
138 /// attribute attached. This can handle data layout requests for the built-in
139 /// types itself.
140 struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
141   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithLayout)
142 
143   using Op::Op;
144   static ArrayRef<StringRef> getAttributeNames() { return {}; }
145 
146   static StringRef getOperationName() { return "dltest.op_with_layout"; }
147 
148   DataLayoutSpecInterface getDataLayoutSpec() {
149     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
150   }
151 
152   static unsigned getTypeSizeInBits(Type type, const DataLayout &dataLayout,
153                                     DataLayoutEntryListRef params) {
154     // Make a recursive query.
155     if (type.isa<FloatType>())
156       return dataLayout.getTypeSizeInBits(
157           IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth()));
158 
159     // Handle built-in types that are not handled by the default process.
160     if (auto iType = type.dyn_cast<IntegerType>()) {
161       for (DataLayoutEntryInterface entry : params)
162         if (entry.getKey().dyn_cast<Type>() == type)
163           return 8 *
164                  entry.getValue().cast<IntegerAttr>().getValue().getZExtValue();
165       return 8 * iType.getIntOrFloatBitWidth();
166     }
167 
168     // Use the default process for everything else.
169     return detail::getDefaultTypeSize(type, dataLayout, params);
170   }
171 
172   static unsigned getTypeABIAlignment(Type type, const DataLayout &dataLayout,
173                                       DataLayoutEntryListRef params) {
174     return llvm::PowerOf2Ceil(getTypeSize(type, dataLayout, params));
175   }
176 
177   static unsigned getTypePreferredAlignment(Type type,
178                                             const DataLayout &dataLayout,
179                                             DataLayoutEntryListRef params) {
180     return 2 * getTypeABIAlignment(type, dataLayout, params);
181   }
182 };
183 
184 struct OpWith7BitByte
185     : public Op<OpWith7BitByte, DataLayoutOpInterface::Trait> {
186   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWith7BitByte)
187 
188   using Op::Op;
189   static ArrayRef<StringRef> getAttributeNames() { return {}; }
190 
191   static StringRef getOperationName() { return "dltest.op_with_7bit_byte"; }
192 
193   DataLayoutSpecInterface getDataLayoutSpec() {
194     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
195   }
196 
197   // Bytes are assumed to be 7-bit here.
198   static unsigned getTypeSize(Type type, const DataLayout &dataLayout,
199                               DataLayoutEntryListRef params) {
200     return llvm::divideCeil(dataLayout.getTypeSizeInBits(type), 7);
201   }
202 };
203 
204 /// A dialect putting all the above together.
205 struct DLTestDialect : Dialect {
206   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTestDialect)
207 
208   explicit DLTestDialect(MLIRContext *ctx)
209       : Dialect(getDialectNamespace(), ctx, TypeID::get<DLTestDialect>()) {
210     ctx->getOrLoadDialect<DLTIDialect>();
211     addAttributes<CustomDataLayoutSpec>();
212     addOperations<OpWithLayout, OpWith7BitByte>();
213     addTypes<SingleQueryType, TypeNoLayout>();
214   }
215   static StringRef getDialectNamespace() { return "dltest"; }
216 
217   void printAttribute(Attribute attr,
218                       DialectAsmPrinter &printer) const override {
219     printer << "spec<";
220     llvm::interleaveComma(attr.cast<CustomDataLayoutSpec>().getEntries(),
221                           printer);
222     printer << ">";
223   }
224 
225   Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
226     bool ok =
227         succeeded(parser.parseKeyword("spec")) && succeeded(parser.parseLess());
228     (void)ok;
229     assert(ok);
230     if (succeeded(parser.parseOptionalGreater()))
231       return CustomDataLayoutSpec::get(parser.getContext(), {});
232 
233     SmallVector<DataLayoutEntryInterface> entries;
234     ok = succeeded(parser.parseCommaSeparatedList([&]() {
235       entries.emplace_back();
236       ok = succeeded(parser.parseAttribute(entries.back()));
237       assert(ok);
238       return success();
239     }));
240     assert(ok);
241     ok = succeeded(parser.parseGreater());
242     assert(ok);
243     return CustomDataLayoutSpec::get(parser.getContext(), entries);
244   }
245 
246   void printType(Type type, DialectAsmPrinter &printer) const override {
247     if (type.isa<SingleQueryType>())
248       printer << "single_query";
249     else
250       printer << "no_layout";
251   }
252 
253   Type parseType(DialectAsmParser &parser) const override {
254     bool ok = succeeded(parser.parseKeyword("single_query"));
255     (void)ok;
256     assert(ok);
257     return SingleQueryType::get(parser.getContext());
258   }
259 };
260 
261 } // namespace
262 
263 TEST(DataLayout, FallbackDefault) {
264   const char *ir = R"MLIR(
265 module {}
266   )MLIR";
267 
268   DialectRegistry registry;
269   registry.insert<DLTIDialect, DLTestDialect>();
270   MLIRContext ctx(registry);
271 
272   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
273   DataLayout layout(module.get());
274   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
275   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 2u);
276   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
277   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 16u);
278   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u);
279   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 2u);
280   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 8u);
281   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u);
282 
283   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
284   EXPECT_EQ(layout.getStackAlignment(), 0u);
285 }
286 
287 TEST(DataLayout, NullSpec) {
288   const char *ir = R"MLIR(
289 "dltest.op_with_layout"() : () -> ()
290   )MLIR";
291 
292   DialectRegistry registry;
293   registry.insert<DLTIDialect, DLTestDialect>();
294   MLIRContext ctx(registry);
295 
296   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
297   auto op =
298       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
299   DataLayout layout(op);
300 
301   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
302   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
303   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
304   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
305   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
306   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
307   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
308   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
309 
310   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
311   EXPECT_EQ(layout.getStackAlignment(), 0u);
312 }
313 
314 TEST(DataLayout, EmptySpec) {
315   const char *ir = R"MLIR(
316 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< > } : () -> ()
317   )MLIR";
318 
319   DialectRegistry registry;
320   registry.insert<DLTIDialect, DLTestDialect>();
321   MLIRContext ctx(registry);
322 
323   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
324   auto op =
325       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
326   DataLayout layout(op);
327   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
328   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
329   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
330   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
331   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
332   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
333   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
334   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
335 
336   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
337   EXPECT_EQ(layout.getStackAlignment(), 0u);
338 }
339 
340 TEST(DataLayout, SpecWithEntries) {
341   const char *ir = R"MLIR(
342 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<
343   #dlti.dl_entry<i42, 5>,
344   #dlti.dl_entry<i16, 6>,
345   #dlti.dl_entry<"dltest.alloca_memory_space", 5 : i32>,
346   #dlti.dl_entry<"dltest.stack_alignment", 128 : i32>
347 > } : () -> ()
348   )MLIR";
349 
350   DialectRegistry registry;
351   registry.insert<DLTIDialect, DLTestDialect>();
352   MLIRContext ctx(registry);
353 
354   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
355   auto op =
356       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
357   DataLayout layout(op);
358   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 5u);
359   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u);
360   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 40u);
361   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 48u);
362   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u);
363   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 8u);
364   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 16u);
365   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 16u);
366 
367   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 32u);
368   EXPECT_EQ(layout.getTypeSize(Float32Type::get(&ctx)), 32u);
369   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 256u);
370   EXPECT_EQ(layout.getTypeSizeInBits(Float32Type::get(&ctx)), 256u);
371   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 32)), 32u);
372   EXPECT_EQ(layout.getTypeABIAlignment(Float32Type::get(&ctx)), 32u);
373   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 32)), 64u);
374   EXPECT_EQ(layout.getTypePreferredAlignment(Float32Type::get(&ctx)), 64u);
375 
376   EXPECT_EQ(layout.getAllocaMemorySpace(), Builder(&ctx).getI32IntegerAttr(5));
377   EXPECT_EQ(layout.getStackAlignment(), 128u);
378 }
379 
380 TEST(DataLayout, Caching) {
381   const char *ir = R"MLIR(
382 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
383   )MLIR";
384 
385   DialectRegistry registry;
386   registry.insert<DLTIDialect, DLTestDialect>();
387   MLIRContext ctx(registry);
388 
389   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
390   auto op =
391       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
392   DataLayout layout(op);
393 
394   unsigned sum = 0;
395   sum += layout.getTypeSize(SingleQueryType::get(&ctx));
396   // The second call should hit the cache. If it does not, the function in
397   // SingleQueryType will be called and will abort the process.
398   sum += layout.getTypeSize(SingleQueryType::get(&ctx));
399   // Make sure the complier doesn't optimize away the query code.
400   EXPECT_EQ(sum, 2u);
401 
402   // A fresh data layout has a new cache, so the call to it should be dispatched
403   // down to the type and abort the proces.
404   DataLayout second(op);
405   ASSERT_DEATH(second.getTypeSize(SingleQueryType::get(&ctx)), "repeated call");
406 }
407 
408 TEST(DataLayout, CacheInvalidation) {
409   const char *ir = R"MLIR(
410 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<
411   #dlti.dl_entry<i42, 5>,
412   #dlti.dl_entry<i16, 6>
413 > } : () -> ()
414   )MLIR";
415 
416   DialectRegistry registry;
417   registry.insert<DLTIDialect, DLTestDialect>();
418   MLIRContext ctx(registry);
419 
420   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
421   auto op =
422       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
423   DataLayout layout(op);
424 
425   // Normal query is fine.
426   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u);
427 
428   // Replace the data layout spec with a new, empty spec.
429   op->setAttr(kAttrName, CustomDataLayoutSpec::get(&ctx, {}));
430 
431   // Data layout is no longer valid and should trigger assertion when queried.
432 #ifndef NDEBUG
433   ASSERT_DEATH(layout.getTypeSize(Float16Type::get(&ctx)), "no longer valid");
434 #endif
435 }
436 
437 TEST(DataLayout, UnimplementedTypeInterface) {
438   const char *ir = R"MLIR(
439 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
440   )MLIR";
441 
442   DialectRegistry registry;
443   registry.insert<DLTIDialect, DLTestDialect>();
444   MLIRContext ctx(registry);
445 
446   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
447   auto op =
448       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
449   DataLayout layout(op);
450 
451   ASSERT_DEATH(layout.getTypeSize(TypeNoLayout::get(&ctx)),
452                "neither the scoping op nor the type class provide data layout "
453                "information");
454 }
455 
456 TEST(DataLayout, SevenBitByte) {
457   const char *ir = R"MLIR(
458 "dltest.op_with_7bit_byte"() { dltest.layout = #dltest.spec<> } : () -> ()
459   )MLIR";
460 
461   DialectRegistry registry;
462   registry.insert<DLTIDialect, DLTestDialect>();
463   MLIRContext ctx(registry);
464 
465   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
466   auto op =
467       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
468   DataLayout layout(op);
469 
470   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
471   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 32u);
472   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
473   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 5u);
474 }
475