xref: /llvm-project/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp (revision 5c1752e368585e55c0335a7d7651fe43d42af282)
1 //===- DataLayoutInterfacesTest.cpp - Unit Tests for Data Layouts ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/DataLayoutInterfaces.h"
10 #include "mlir/Dialect/DLTI/DLTI.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/Parser/Parser.h"
18 
19 #include <gtest/gtest.h>
20 
21 using namespace mlir;
22 
23 namespace {
24 constexpr static llvm::StringLiteral kAttrName = "dltest.layout";
25 constexpr static llvm::StringLiteral kEndiannesKeyName = "dltest.endianness";
26 constexpr static llvm::StringLiteral kAllocaKeyName =
27     "dltest.alloca_memory_space";
28 constexpr static llvm::StringLiteral kProgramKeyName =
29     "dltest.program_memory_space";
30 constexpr static llvm::StringLiteral kGlobalKeyName =
31     "dltest.global_memory_space";
32 constexpr static llvm::StringLiteral kStackAlignmentKeyName =
33     "dltest.stack_alignment";
34 
35 constexpr static llvm::StringLiteral kTargetSystemDescAttrName =
36     "dl_target_sys_desc_test.target_system_spec";
37 
38 /// Trivial array storage for the custom data layout spec attribute, just a list
39 /// of entries.
40 class DataLayoutSpecStorage : public AttributeStorage {
41 public:
42   using KeyTy = ArrayRef<DataLayoutEntryInterface>;
43 
44   DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries)
45       : entries(entries) {}
46 
47   bool operator==(const KeyTy &key) const { return key == entries; }
48 
49   static DataLayoutSpecStorage *construct(AttributeStorageAllocator &allocator,
50                                           const KeyTy &key) {
51     return new (allocator.allocate<DataLayoutSpecStorage>())
52         DataLayoutSpecStorage(allocator.copyInto(key));
53   }
54 
55   ArrayRef<DataLayoutEntryInterface> entries;
56 };
57 
58 /// Simple data layout spec containing a list of entries that always verifies
59 /// as valid.
60 struct CustomDataLayoutSpec
61     : public Attribute::AttrBase<
62           CustomDataLayoutSpec, Attribute, DataLayoutSpecStorage,
63           DLTIQueryInterface::Trait, DataLayoutSpecInterface::Trait> {
64   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec)
65 
66   using Base::Base;
67 
68   static constexpr StringLiteral name = "test.custom_data_layout_spec";
69 
70   static CustomDataLayoutSpec get(MLIRContext *ctx,
71                                   ArrayRef<DataLayoutEntryInterface> entries) {
72     return Base::get(ctx, entries);
73   }
74   CustomDataLayoutSpec
75   combineWith(ArrayRef<DataLayoutSpecInterface> specs) const {
76     return *this;
77   }
78   DataLayoutEntryListRef getEntries() const { return getImpl()->entries; }
79   LogicalResult verifySpec(Location loc) { return success(); }
80   StringAttr getEndiannessIdentifier(MLIRContext *context) const {
81     return Builder(context).getStringAttr(kEndiannesKeyName);
82   }
83   StringAttr getAllocaMemorySpaceIdentifier(MLIRContext *context) const {
84     return Builder(context).getStringAttr(kAllocaKeyName);
85   }
86   StringAttr getProgramMemorySpaceIdentifier(MLIRContext *context) const {
87     return Builder(context).getStringAttr(kProgramKeyName);
88   }
89   StringAttr getGlobalMemorySpaceIdentifier(MLIRContext *context) const {
90     return Builder(context).getStringAttr(kGlobalKeyName);
91   }
92   StringAttr getStackAlignmentIdentifier(MLIRContext *context) const {
93     return Builder(context).getStringAttr(kStackAlignmentKeyName);
94   }
95   FailureOr<Attribute> query(DataLayoutEntryKey key) const {
96     return llvm::cast<mlir::DataLayoutSpecInterface>(*this).queryHelper(key);
97   }
98 };
99 
100 class TargetSystemSpecStorage : public AttributeStorage {
101 public:
102   using KeyTy = ArrayRef<DataLayoutEntryInterface>;
103 
104   TargetSystemSpecStorage(ArrayRef<DataLayoutEntryInterface> entries)
105       : entries(entries) {}
106 
107   bool operator==(const KeyTy &key) const { return key == entries; }
108 
109   static TargetSystemSpecStorage *
110   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
111     return new (allocator.allocate<TargetSystemSpecStorage>())
112         TargetSystemSpecStorage(allocator.copyInto(key));
113   }
114 
115   ArrayRef<DataLayoutEntryInterface> entries;
116 };
117 
118 struct CustomTargetSystemSpec
119     : public Attribute::AttrBase<
120           CustomTargetSystemSpec, Attribute, TargetSystemSpecStorage,
121           DLTIQueryInterface::Trait, TargetSystemSpecInterface::Trait> {
122   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec)
123 
124   using Base::Base;
125 
126   static constexpr StringLiteral name = "test.custom_target_system_spec";
127 
128   static CustomTargetSystemSpec
129   get(MLIRContext *ctx, ArrayRef<DataLayoutEntryInterface> entries) {
130     return Base::get(ctx, entries);
131   }
132   ArrayRef<DataLayoutEntryInterface> getEntries() const {
133     return getImpl()->entries;
134   }
135   LogicalResult verifySpec(Location loc) { return success(); }
136   std::optional<TargetDeviceSpecInterface>
137   getDeviceSpecForDeviceID(TargetSystemSpecInterface::DeviceID deviceID) {
138     for (const auto &entry : getEntries()) {
139       if (entry.getKey() == DataLayoutEntryKey(deviceID))
140         if (auto deviceSpec =
141                 llvm::dyn_cast<TargetDeviceSpecInterface>(entry.getValue()))
142           return deviceSpec;
143     }
144     return std::nullopt;
145   }
146   FailureOr<Attribute> query(DataLayoutEntryKey key) const {
147     return llvm::cast<mlir::TargetSystemSpecInterface>(*this).queryHelper(key);
148   }
149 };
150 
151 /// A type subject to data layout that exits the program if it is queried more
152 /// than once. Handy to check if the cache works.
153 struct SingleQueryType
154     : public Type::TypeBase<SingleQueryType, Type, TypeStorage,
155                             DataLayoutTypeInterface::Trait> {
156   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SingleQueryType)
157 
158   using Base::Base;
159 
160   static constexpr StringLiteral name = "test.single_query";
161 
162   static SingleQueryType get(MLIRContext *ctx) { return Base::get(ctx); }
163 
164   llvm::TypeSize getTypeSizeInBits(const DataLayout &layout,
165                                    DataLayoutEntryListRef params) const {
166     static bool executed = false;
167     if (executed)
168       llvm::report_fatal_error("repeated call");
169 
170     executed = true;
171     return llvm::TypeSize::getFixed(1);
172   }
173 
174   uint64_t getABIAlignment(const DataLayout &layout,
175                            DataLayoutEntryListRef params) {
176     static bool executed = false;
177     if (executed)
178       llvm::report_fatal_error("repeated call");
179 
180     executed = true;
181     return 2;
182   }
183 
184   uint64_t getPreferredAlignment(const DataLayout &layout,
185                                  DataLayoutEntryListRef params) {
186     static bool executed = false;
187     if (executed)
188       llvm::report_fatal_error("repeated call");
189 
190     executed = true;
191     return 4;
192   }
193 
194   Attribute getEndianness(DataLayoutEntryInterface entry) {
195     static bool executed = false;
196     if (executed)
197       llvm::report_fatal_error("repeated call");
198 
199     executed = true;
200     return Attribute();
201   }
202 
203   Attribute getAllocaMemorySpace(DataLayoutEntryInterface entry) {
204     static bool executed = false;
205     if (executed)
206       llvm::report_fatal_error("repeated call");
207 
208     executed = true;
209     return Attribute();
210   }
211 
212   Attribute getProgramMemorySpace(DataLayoutEntryInterface entry) {
213     static bool executed = false;
214     if (executed)
215       llvm::report_fatal_error("repeated call");
216 
217     executed = true;
218     return Attribute();
219   }
220 
221   Attribute getGlobalMemorySpace(DataLayoutEntryInterface entry) {
222     static bool executed = false;
223     if (executed)
224       llvm::report_fatal_error("repeated call");
225 
226     executed = true;
227     return Attribute();
228   }
229 };
230 
231 /// A types that is not subject to data layout.
232 struct TypeNoLayout : public Type::TypeBase<TypeNoLayout, Type, TypeStorage> {
233   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TypeNoLayout)
234 
235   using Base::Base;
236 
237   static constexpr StringLiteral name = "test.no_layout";
238 
239   static TypeNoLayout get(MLIRContext *ctx) { return Base::get(ctx); }
240 };
241 
242 /// An op that serves as scope for data layout queries with the relevant
243 /// attribute attached. This can handle data layout requests for the built-in
244 /// types itself.
245 struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
246   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithLayout)
247 
248   using Op::Op;
249   static ArrayRef<StringRef> getAttributeNames() { return {}; }
250 
251   static StringRef getOperationName() { return "dltest.op_with_layout"; }
252 
253   DataLayoutSpecInterface getDataLayoutSpec() {
254     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
255   }
256 
257   TargetSystemSpecInterface getTargetSystemSpec() {
258     return getOperation()->getAttrOfType<TargetSystemSpecInterface>(
259         kTargetSystemDescAttrName);
260   }
261 
262   static llvm::TypeSize getTypeSizeInBits(Type type,
263                                           const DataLayout &dataLayout,
264                                           DataLayoutEntryListRef params) {
265     // Make a recursive query.
266     if (isa<FloatType>(type))
267       return dataLayout.getTypeSizeInBits(
268           IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth()));
269 
270     // Handle built-in types that are not handled by the default process.
271     if (auto iType = dyn_cast<IntegerType>(type)) {
272       for (DataLayoutEntryInterface entry : params)
273         if (llvm::dyn_cast_if_present<Type>(entry.getKey()) == type)
274           return llvm::TypeSize::getFixed(
275               8 *
276               cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue());
277       return llvm::TypeSize::getFixed(8 * iType.getIntOrFloatBitWidth());
278     }
279 
280     // Use the default process for everything else.
281     return detail::getDefaultTypeSize(type, dataLayout, params);
282   }
283 
284   static uint64_t getTypeABIAlignment(Type type, const DataLayout &dataLayout,
285                                       DataLayoutEntryListRef params) {
286     return llvm::PowerOf2Ceil(getTypeSize(type, dataLayout, params));
287   }
288 
289   static uint64_t getTypePreferredAlignment(Type type,
290                                             const DataLayout &dataLayout,
291                                             DataLayoutEntryListRef params) {
292     return 2 * getTypeABIAlignment(type, dataLayout, params);
293   }
294 };
295 
296 struct OpWith7BitByte
297     : public Op<OpWith7BitByte, DataLayoutOpInterface::Trait> {
298   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWith7BitByte)
299 
300   using Op::Op;
301   static ArrayRef<StringRef> getAttributeNames() { return {}; }
302 
303   static StringRef getOperationName() { return "dltest.op_with_7bit_byte"; }
304 
305   DataLayoutSpecInterface getDataLayoutSpec() {
306     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
307   }
308 
309   TargetSystemSpecInterface getTargetSystemSpec() {
310     return getOperation()->getAttrOfType<TargetSystemSpecInterface>(
311         kTargetSystemDescAttrName);
312   }
313 
314   // Bytes are assumed to be 7-bit here.
315   static llvm::TypeSize getTypeSize(Type type, const DataLayout &dataLayout,
316                                     DataLayoutEntryListRef params) {
317     return mlir::detail::divideCeil(dataLayout.getTypeSizeInBits(type), 7);
318   }
319 };
320 
321 /// A dialect putting all the above together.
322 struct DLTestDialect : Dialect {
323   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTestDialect)
324 
325   explicit DLTestDialect(MLIRContext *ctx)
326       : Dialect(getDialectNamespace(), ctx, TypeID::get<DLTestDialect>()) {
327     ctx->getOrLoadDialect<DLTIDialect>();
328     addAttributes<CustomDataLayoutSpec>();
329     addOperations<OpWithLayout, OpWith7BitByte>();
330     addTypes<SingleQueryType, TypeNoLayout>();
331   }
332   static StringRef getDialectNamespace() { return "dltest"; }
333 
334   void printAttribute(Attribute attr,
335                       DialectAsmPrinter &printer) const override {
336     printer << "spec<";
337     llvm::interleaveComma(cast<CustomDataLayoutSpec>(attr).getEntries(),
338                           printer);
339     printer << ">";
340   }
341 
342   Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
343     bool ok =
344         succeeded(parser.parseKeyword("spec")) && succeeded(parser.parseLess());
345     (void)ok;
346     assert(ok);
347     if (succeeded(parser.parseOptionalGreater()))
348       return CustomDataLayoutSpec::get(parser.getContext(), {});
349 
350     SmallVector<DataLayoutEntryInterface> entries;
351     ok = succeeded(parser.parseCommaSeparatedList([&]() {
352       entries.emplace_back();
353       ok = succeeded(parser.parseAttribute(entries.back()));
354       assert(ok);
355       return success();
356     }));
357     assert(ok);
358     ok = succeeded(parser.parseGreater());
359     assert(ok);
360     return CustomDataLayoutSpec::get(parser.getContext(), entries);
361   }
362 
363   void printType(Type type, DialectAsmPrinter &printer) const override {
364     if (isa<SingleQueryType>(type))
365       printer << "single_query";
366     else
367       printer << "no_layout";
368   }
369 
370   Type parseType(DialectAsmParser &parser) const override {
371     bool ok = succeeded(parser.parseKeyword("single_query"));
372     (void)ok;
373     assert(ok);
374     return SingleQueryType::get(parser.getContext());
375   }
376 };
377 
378 /// A dialect to test DLTI's target system spec and related attributes
379 struct DLTargetSystemDescTestDialect : public Dialect {
380   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTargetSystemDescTestDialect)
381 
382   explicit DLTargetSystemDescTestDialect(MLIRContext *ctx)
383       : Dialect(getDialectNamespace(), ctx,
384                 TypeID::get<DLTargetSystemDescTestDialect>()) {
385     ctx->getOrLoadDialect<DLTIDialect>();
386     addAttributes<CustomTargetSystemSpec>();
387   }
388   static StringRef getDialectNamespace() { return "dl_target_sys_desc_test"; }
389 
390   void printAttribute(Attribute attr,
391                       DialectAsmPrinter &printer) const override {
392     printer << "target_system_spec<";
393     llvm::interleaveComma(cast<CustomTargetSystemSpec>(attr).getEntries(),
394                           printer, [&](const auto &it) {
395                             printer << dyn_cast<StringAttr>(it.getKey()) << ":"
396                                     << it.getValue();
397                           });
398     printer << ">";
399   }
400 
401   Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
402     bool ok = succeeded(parser.parseKeyword("target_system_spec")) &&
403               succeeded(parser.parseLess());
404     (void)ok;
405     assert(ok);
406     if (succeeded(parser.parseOptionalGreater()))
407       return CustomTargetSystemSpec::get(parser.getContext(), {});
408 
409     auto parseTargetDeviceSpecEntry =
410         [&](AsmParser &parser) -> FailureOr<TargetDeviceSpecEntry> {
411       std::string deviceID;
412       if (failed(parser.parseString(&deviceID))) {
413         parser.emitError(parser.getCurrentLocation())
414             << "DeviceID is missing, or is not of string type";
415         return failure();
416       }
417       if (failed(parser.parseColon())) {
418         parser.emitError(parser.getCurrentLocation()) << "Missing colon";
419         return failure();
420       }
421 
422       TargetDeviceSpecInterface targetDeviceSpec;
423       if (failed(parser.parseAttribute(targetDeviceSpec))) {
424         parser.emitError(parser.getCurrentLocation())
425             << "Error in parsing target device spec";
426         return failure();
427       }
428       return std::make_pair(parser.getBuilder().getStringAttr(deviceID),
429                             targetDeviceSpec);
430     };
431 
432     SmallVector<DataLayoutEntryInterface> entries;
433     ok = succeeded(parser.parseCommaSeparatedList([&]() {
434       auto deviceIDAndTargetDeviceSpecPair = parseTargetDeviceSpecEntry(parser);
435       ok = succeeded(deviceIDAndTargetDeviceSpecPair);
436       assert(ok);
437       auto entry =
438           DataLayoutEntryAttr::get(deviceIDAndTargetDeviceSpecPair->first,
439                                    deviceIDAndTargetDeviceSpecPair->second);
440       entries.push_back(entry);
441       return success();
442     }));
443     assert(ok);
444     ok = succeeded(parser.parseGreater());
445     assert(ok);
446     return CustomTargetSystemSpec::get(parser.getContext(), entries);
447   }
448 };
449 
450 } // namespace
451 
452 TEST(DataLayout, FallbackDefault) {
453   const char *ir = R"MLIR(
454 module {}
455   )MLIR";
456 
457   DialectRegistry registry;
458   registry.insert<DLTIDialect, DLTestDialect>();
459   MLIRContext ctx(registry);
460 
461   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
462   DataLayout layout(module.get());
463   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
464   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 2u);
465   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
466   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 16u);
467   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u);
468   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 2u);
469   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 8u);
470   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u);
471 
472   EXPECT_EQ(layout.getEndianness(), Attribute());
473   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
474   EXPECT_EQ(layout.getProgramMemorySpace(), Attribute());
475   EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute());
476   EXPECT_EQ(layout.getStackAlignment(), 0u);
477 }
478 
479 TEST(DataLayout, NullSpec) {
480   const char *ir = R"MLIR(
481 "dltest.op_with_layout"() : () -> ()
482   )MLIR";
483 
484   DialectRegistry registry;
485   registry.insert<DLTIDialect, DLTestDialect>();
486   MLIRContext ctx(registry);
487 
488   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
489   auto op =
490       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
491   DataLayout layout(op);
492 
493   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
494   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
495   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
496   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
497   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
498   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
499   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
500   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
501   EXPECT_EQ(layout.getTypeIndexBitwidth(Float16Type::get(&ctx)), std::nullopt);
502   EXPECT_EQ(layout.getTypeIndexBitwidth(IndexType::get(&ctx)), 64u);
503 
504   EXPECT_EQ(layout.getEndianness(), Attribute());
505   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
506   EXPECT_EQ(layout.getProgramMemorySpace(), Attribute());
507   EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute());
508   EXPECT_EQ(layout.getStackAlignment(), 0u);
509 
510   EXPECT_EQ(layout.getDevicePropertyValue(
511                 Builder(&ctx).getStringAttr("CPU" /* device ID*/),
512                 Builder(&ctx).getStringAttr("L1_cache_size_in_bytes")),
513             std::nullopt);
514   EXPECT_EQ(layout.getDevicePropertyValue(
515                 Builder(&ctx).getStringAttr("CPU" /* device ID*/),
516                 Builder(&ctx).getStringAttr("max_vector_width")),
517             std::nullopt);
518 }
519 
520 TEST(DataLayout, EmptySpec) {
521   const char *ir = R"MLIR(
522 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< > } : () -> ()
523   )MLIR";
524 
525   DialectRegistry registry;
526   registry.insert<DLTIDialect, DLTestDialect>();
527   MLIRContext ctx(registry);
528 
529   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
530   auto op =
531       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
532   DataLayout layout(op);
533   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
534   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
535   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
536   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
537   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
538   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
539   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
540   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
541   EXPECT_EQ(layout.getTypeIndexBitwidth(Float16Type::get(&ctx)), std::nullopt);
542   EXPECT_EQ(layout.getTypeIndexBitwidth(IndexType::get(&ctx)), 64u);
543 
544   EXPECT_EQ(layout.getEndianness(), Attribute());
545   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
546   EXPECT_EQ(layout.getProgramMemorySpace(), Attribute());
547   EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute());
548   EXPECT_EQ(layout.getStackAlignment(), 0u);
549 
550   EXPECT_EQ(layout.getDevicePropertyValue(
551                 Builder(&ctx).getStringAttr("CPU" /* device ID*/),
552                 Builder(&ctx).getStringAttr("L1_cache_size_in_bytes")),
553             std::nullopt);
554   EXPECT_EQ(layout.getDevicePropertyValue(
555                 Builder(&ctx).getStringAttr("CPU" /* device ID*/),
556                 Builder(&ctx).getStringAttr("max_vector_width")),
557             std::nullopt);
558 }
559 
560 TEST(DataLayout, SpecWithEntries) {
561   const char *ir = R"MLIR(
562 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<
563   #dlti.dl_entry<i42, 5>,
564   #dlti.dl_entry<i16, 6>,
565   #dlti.dl_entry<index, 42>,
566   #dlti.dl_entry<"dltest.endianness", "little">,
567   #dlti.dl_entry<"dltest.alloca_memory_space", 5 : i32>,
568   #dlti.dl_entry<"dltest.program_memory_space", 3 : i32>,
569   #dlti.dl_entry<"dltest.global_memory_space", 2 : i32>,
570   #dlti.dl_entry<"dltest.stack_alignment", 128 : i32>
571 > } : () -> ()
572   )MLIR";
573 
574   DialectRegistry registry;
575   registry.insert<DLTIDialect, DLTestDialect>();
576   MLIRContext ctx(registry);
577 
578   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
579   auto op =
580       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
581   DataLayout layout(op);
582   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 5u);
583   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u);
584   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 40u);
585   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 48u);
586   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u);
587   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 8u);
588   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 16u);
589   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 16u);
590   EXPECT_EQ(layout.getTypeIndexBitwidth(Float16Type::get(&ctx)), std::nullopt);
591   EXPECT_EQ(layout.getTypeIndexBitwidth(IndexType::get(&ctx)), 42u);
592 
593   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 32u);
594   EXPECT_EQ(layout.getTypeSize(Float32Type::get(&ctx)), 32u);
595   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 256u);
596   EXPECT_EQ(layout.getTypeSizeInBits(Float32Type::get(&ctx)), 256u);
597   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 32)), 32u);
598   EXPECT_EQ(layout.getTypeABIAlignment(Float32Type::get(&ctx)), 32u);
599   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 32)), 64u);
600   EXPECT_EQ(layout.getTypePreferredAlignment(Float32Type::get(&ctx)), 64u);
601 
602   EXPECT_EQ(layout.getEndianness(), Builder(&ctx).getStringAttr("little"));
603   EXPECT_EQ(layout.getAllocaMemorySpace(), Builder(&ctx).getI32IntegerAttr(5));
604   EXPECT_EQ(layout.getProgramMemorySpace(), Builder(&ctx).getI32IntegerAttr(3));
605   EXPECT_EQ(layout.getGlobalMemorySpace(), Builder(&ctx).getI32IntegerAttr(2));
606   EXPECT_EQ(layout.getStackAlignment(), 128u);
607 }
608 
609 TEST(DataLayout, SpecWithTargetSystemDescEntries) {
610   const char *ir = R"MLIR(
611   module attributes { dl_target_sys_desc_test.target_system_spec =
612     #dl_target_sys_desc_test.target_system_spec<
613       "CPU": #dlti.target_device_spec<
614               #dlti.dl_entry<"L1_cache_size_in_bytes", "4096">,
615               #dlti.dl_entry<"max_vector_op_width", "128">>
616     > } {}
617   )MLIR";
618 
619   DialectRegistry registry;
620   registry.insert<DLTIDialect, DLTargetSystemDescTestDialect>();
621   MLIRContext ctx(registry);
622 
623   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
624   DataLayout layout(*module);
625   EXPECT_EQ(layout.getDevicePropertyValue(
626                 Builder(&ctx).getStringAttr("CPU") /* device ID*/,
627                 Builder(&ctx).getStringAttr("L1_cache_size_in_bytes")),
628             std::optional<Attribute>(Builder(&ctx).getStringAttr("4096")));
629   EXPECT_EQ(layout.getDevicePropertyValue(
630                 Builder(&ctx).getStringAttr("CPU") /* device ID*/,
631                 Builder(&ctx).getStringAttr("max_vector_op_width")),
632             std::optional<Attribute>(Builder(&ctx).getStringAttr("128")));
633 }
634 
635 TEST(DataLayout, Caching) {
636   const char *ir = R"MLIR(
637 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
638   )MLIR";
639 
640   DialectRegistry registry;
641   registry.insert<DLTIDialect, DLTestDialect>();
642   MLIRContext ctx(registry);
643 
644   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
645   auto op =
646       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
647   DataLayout layout(op);
648 
649   unsigned sum = 0;
650   sum += layout.getTypeSize(SingleQueryType::get(&ctx));
651   // The second call should hit the cache. If it does not, the function in
652   // SingleQueryType will be called and will abort the process.
653   sum += layout.getTypeSize(SingleQueryType::get(&ctx));
654   // Make sure the complier doesn't optimize away the query code.
655   EXPECT_EQ(sum, 2u);
656 
657   // A fresh data layout has a new cache, so the call to it should be dispatched
658   // down to the type and abort the process.
659   DataLayout second(op);
660   ASSERT_DEATH(second.getTypeSize(SingleQueryType::get(&ctx)), "repeated call");
661 }
662 
663 TEST(DataLayout, CacheInvalidation) {
664   const char *ir = R"MLIR(
665 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<
666   #dlti.dl_entry<i42, 5>,
667   #dlti.dl_entry<i16, 6>
668 > } : () -> ()
669   )MLIR";
670 
671   DialectRegistry registry;
672   registry.insert<DLTIDialect, DLTestDialect>();
673   MLIRContext ctx(registry);
674 
675   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
676   auto op =
677       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
678   DataLayout layout(op);
679 
680   // Normal query is fine.
681   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u);
682 
683   // Replace the data layout spec with a new, empty spec.
684   op->setAttr(kAttrName, CustomDataLayoutSpec::get(&ctx, {}));
685 
686   // Data layout is no longer valid and should trigger assertion when queried.
687 #ifndef NDEBUG
688   ASSERT_DEATH(layout.getTypeSize(Float16Type::get(&ctx)), "no longer valid");
689 #endif
690 }
691 
692 TEST(DataLayout, UnimplementedTypeInterface) {
693   const char *ir = R"MLIR(
694 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
695   )MLIR";
696 
697   DialectRegistry registry;
698   registry.insert<DLTIDialect, DLTestDialect>();
699   MLIRContext ctx(registry);
700 
701   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
702   auto op =
703       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
704   DataLayout layout(op);
705 
706   ASSERT_DEATH(layout.getTypeSize(TypeNoLayout::get(&ctx)),
707                "neither the scoping op nor the type class provide data layout "
708                "information");
709 }
710 
711 TEST(DataLayout, SevenBitByte) {
712   const char *ir = R"MLIR(
713 "dltest.op_with_7bit_byte"() { dltest.layout = #dltest.spec<> } : () -> ()
714   )MLIR";
715 
716   DialectRegistry registry;
717   registry.insert<DLTIDialect, DLTestDialect>();
718   MLIRContext ctx(registry);
719 
720   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
721   auto op =
722       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
723   DataLayout layout(op);
724 
725   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
726   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 32u);
727   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
728   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 5u);
729 }
730