xref: /llvm-project/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp (revision abd95342f0b94e140b36ac954b8f8c29b1393861)
1 //===- DataLayoutInterfacesTest.cpp - Unit Tests for Data Layouts ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/DataLayoutInterfaces.h"
10 #include "mlir/Dialect/DLTI/DLTI.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/Parser/Parser.h"
18 
19 #include <gtest/gtest.h>
20 
21 using namespace mlir;
22 
23 namespace {
24 constexpr static llvm::StringLiteral kAttrName = "dltest.layout";
25 constexpr static llvm::StringLiteral kEndiannesKeyName = "dltest.endianness";
26 constexpr static llvm::StringLiteral kAllocaKeyName =
27     "dltest.alloca_memory_space";
28 constexpr static llvm::StringLiteral kProgramKeyName =
29     "dltest.program_memory_space";
30 constexpr static llvm::StringLiteral kGlobalKeyName =
31     "dltest.global_memory_space";
32 constexpr static llvm::StringLiteral kStackAlignmentKeyName =
33     "dltest.stack_alignment";
34 
35 constexpr static llvm::StringLiteral kTargetSystemDescAttrName =
36     "dl_target_sys_desc_test.target_system_spec";
37 
38 /// Trivial array storage for the custom data layout spec attribute, just a list
39 /// of entries.
40 class DataLayoutSpecStorage : public AttributeStorage {
41 public:
42   using KeyTy = ArrayRef<DataLayoutEntryInterface>;
43 
44   DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries)
45       : entries(entries) {}
46 
47   bool operator==(const KeyTy &key) const { return key == entries; }
48 
49   static DataLayoutSpecStorage *construct(AttributeStorageAllocator &allocator,
50                                           const KeyTy &key) {
51     return new (allocator.allocate<DataLayoutSpecStorage>())
52         DataLayoutSpecStorage(allocator.copyInto(key));
53   }
54 
55   ArrayRef<DataLayoutEntryInterface> entries;
56 };
57 
58 /// Simple data layout spec containing a list of entries that always verifies
59 /// as valid.
60 struct CustomDataLayoutSpec
61     : public Attribute::AttrBase<CustomDataLayoutSpec, Attribute,
62                                  DataLayoutSpecStorage,
63                                  DataLayoutSpecInterface::Trait> {
64   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec)
65 
66   using Base::Base;
67 
68   static constexpr StringLiteral name = "test.custom_data_layout_spec";
69 
70   static CustomDataLayoutSpec get(MLIRContext *ctx,
71                                   ArrayRef<DataLayoutEntryInterface> entries) {
72     return Base::get(ctx, entries);
73   }
74   CustomDataLayoutSpec
75   combineWith(ArrayRef<DataLayoutSpecInterface> specs) const {
76     return *this;
77   }
78   DataLayoutEntryListRef getEntries() const { return getImpl()->entries; }
79   LogicalResult verifySpec(Location loc) { return success(); }
80   StringAttr getEndiannessIdentifier(MLIRContext *context) const {
81     return Builder(context).getStringAttr(kEndiannesKeyName);
82   }
83   StringAttr getAllocaMemorySpaceIdentifier(MLIRContext *context) const {
84     return Builder(context).getStringAttr(kAllocaKeyName);
85   }
86   StringAttr getProgramMemorySpaceIdentifier(MLIRContext *context) const {
87     return Builder(context).getStringAttr(kProgramKeyName);
88   }
89   StringAttr getGlobalMemorySpaceIdentifier(MLIRContext *context) const {
90     return Builder(context).getStringAttr(kGlobalKeyName);
91   }
92   StringAttr getStackAlignmentIdentifier(MLIRContext *context) const {
93     return Builder(context).getStringAttr(kStackAlignmentKeyName);
94   }
95 };
96 
97 class TargetSystemSpecStorage : public AttributeStorage {
98 public:
99   using KeyTy = ArrayRef<DeviceIDTargetDeviceSpecPair>;
100 
101   TargetSystemSpecStorage(ArrayRef<DeviceIDTargetDeviceSpecPair> entries)
102       : entries(entries) {}
103 
104   bool operator==(const KeyTy &key) const { return key == entries; }
105 
106   static TargetSystemSpecStorage *
107   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
108     return new (allocator.allocate<TargetSystemSpecStorage>())
109         TargetSystemSpecStorage(allocator.copyInto(key));
110   }
111 
112   ArrayRef<DeviceIDTargetDeviceSpecPair> entries;
113 };
114 
115 struct CustomTargetSystemSpec
116     : public Attribute::AttrBase<CustomTargetSystemSpec, Attribute,
117                                  TargetSystemSpecStorage,
118                                  TargetSystemSpecInterface::Trait> {
119   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec)
120 
121   using Base::Base;
122 
123   static constexpr StringLiteral name = "test.custom_target_system_spec";
124 
125   static CustomTargetSystemSpec
126   get(MLIRContext *ctx, ArrayRef<DeviceIDTargetDeviceSpecPair> entries) {
127     return Base::get(ctx, entries);
128   }
129   DeviceIDTargetDeviceSpecPairListRef getEntries() const {
130     return getImpl()->entries;
131   }
132   LogicalResult verifySpec(Location loc) { return success(); }
133   std::optional<TargetDeviceSpecInterface>
134   getDeviceSpecForDeviceID(TargetSystemSpecInterface::DeviceID deviceID) {
135     for (const auto &entry : getEntries()) {
136       if (entry.first == deviceID)
137         return entry.second;
138     }
139     return std::nullopt;
140   }
141 };
142 
143 /// A type subject to data layout that exits the program if it is queried more
144 /// than once. Handy to check if the cache works.
145 struct SingleQueryType
146     : public Type::TypeBase<SingleQueryType, Type, TypeStorage,
147                             DataLayoutTypeInterface::Trait> {
148   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SingleQueryType)
149 
150   using Base::Base;
151 
152   static constexpr StringLiteral name = "test.single_query";
153 
154   static SingleQueryType get(MLIRContext *ctx) { return Base::get(ctx); }
155 
156   llvm::TypeSize getTypeSizeInBits(const DataLayout &layout,
157                                    DataLayoutEntryListRef params) const {
158     static bool executed = false;
159     if (executed)
160       llvm::report_fatal_error("repeated call");
161 
162     executed = true;
163     return llvm::TypeSize::getFixed(1);
164   }
165 
166   uint64_t getABIAlignment(const DataLayout &layout,
167                            DataLayoutEntryListRef params) {
168     static bool executed = false;
169     if (executed)
170       llvm::report_fatal_error("repeated call");
171 
172     executed = true;
173     return 2;
174   }
175 
176   uint64_t getPreferredAlignment(const DataLayout &layout,
177                                  DataLayoutEntryListRef params) {
178     static bool executed = false;
179     if (executed)
180       llvm::report_fatal_error("repeated call");
181 
182     executed = true;
183     return 4;
184   }
185 
186   Attribute getEndianness(DataLayoutEntryInterface entry) {
187     static bool executed = false;
188     if (executed)
189       llvm::report_fatal_error("repeated call");
190 
191     executed = true;
192     return Attribute();
193   }
194 
195   Attribute getAllocaMemorySpace(DataLayoutEntryInterface entry) {
196     static bool executed = false;
197     if (executed)
198       llvm::report_fatal_error("repeated call");
199 
200     executed = true;
201     return Attribute();
202   }
203 
204   Attribute getProgramMemorySpace(DataLayoutEntryInterface entry) {
205     static bool executed = false;
206     if (executed)
207       llvm::report_fatal_error("repeated call");
208 
209     executed = true;
210     return Attribute();
211   }
212 
213   Attribute getGlobalMemorySpace(DataLayoutEntryInterface entry) {
214     static bool executed = false;
215     if (executed)
216       llvm::report_fatal_error("repeated call");
217 
218     executed = true;
219     return Attribute();
220   }
221 };
222 
223 /// A types that is not subject to data layout.
224 struct TypeNoLayout : public Type::TypeBase<TypeNoLayout, Type, TypeStorage> {
225   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TypeNoLayout)
226 
227   using Base::Base;
228 
229   static constexpr StringLiteral name = "test.no_layout";
230 
231   static TypeNoLayout get(MLIRContext *ctx) { return Base::get(ctx); }
232 };
233 
234 /// An op that serves as scope for data layout queries with the relevant
235 /// attribute attached. This can handle data layout requests for the built-in
236 /// types itself.
237 struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
238   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithLayout)
239 
240   using Op::Op;
241   static ArrayRef<StringRef> getAttributeNames() { return {}; }
242 
243   static StringRef getOperationName() { return "dltest.op_with_layout"; }
244 
245   DataLayoutSpecInterface getDataLayoutSpec() {
246     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
247   }
248 
249   TargetSystemSpecInterface getTargetSystemSpec() {
250     return getOperation()->getAttrOfType<TargetSystemSpecInterface>(
251         kTargetSystemDescAttrName);
252   }
253 
254   static llvm::TypeSize getTypeSizeInBits(Type type,
255                                           const DataLayout &dataLayout,
256                                           DataLayoutEntryListRef params) {
257     // Make a recursive query.
258     if (isa<FloatType>(type))
259       return dataLayout.getTypeSizeInBits(
260           IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth()));
261 
262     // Handle built-in types that are not handled by the default process.
263     if (auto iType = dyn_cast<IntegerType>(type)) {
264       for (DataLayoutEntryInterface entry : params)
265         if (llvm::dyn_cast_if_present<Type>(entry.getKey()) == type)
266           return llvm::TypeSize::getFixed(
267               8 *
268               cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue());
269       return llvm::TypeSize::getFixed(8 * iType.getIntOrFloatBitWidth());
270     }
271 
272     // Use the default process for everything else.
273     return detail::getDefaultTypeSize(type, dataLayout, params);
274   }
275 
276   static uint64_t getTypeABIAlignment(Type type, const DataLayout &dataLayout,
277                                       DataLayoutEntryListRef params) {
278     return llvm::PowerOf2Ceil(getTypeSize(type, dataLayout, params));
279   }
280 
281   static uint64_t getTypePreferredAlignment(Type type,
282                                             const DataLayout &dataLayout,
283                                             DataLayoutEntryListRef params) {
284     return 2 * getTypeABIAlignment(type, dataLayout, params);
285   }
286 };
287 
288 struct OpWith7BitByte
289     : public Op<OpWith7BitByte, DataLayoutOpInterface::Trait> {
290   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWith7BitByte)
291 
292   using Op::Op;
293   static ArrayRef<StringRef> getAttributeNames() { return {}; }
294 
295   static StringRef getOperationName() { return "dltest.op_with_7bit_byte"; }
296 
297   DataLayoutSpecInterface getDataLayoutSpec() {
298     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
299   }
300 
301   TargetSystemSpecInterface getTargetSystemSpec() {
302     return getOperation()->getAttrOfType<TargetSystemSpecInterface>(
303         kTargetSystemDescAttrName);
304   }
305 
306   // Bytes are assumed to be 7-bit here.
307   static llvm::TypeSize getTypeSize(Type type, const DataLayout &dataLayout,
308                                     DataLayoutEntryListRef params) {
309     return mlir::detail::divideCeil(dataLayout.getTypeSizeInBits(type), 7);
310   }
311 };
312 
313 /// A dialect putting all the above together.
314 struct DLTestDialect : Dialect {
315   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTestDialect)
316 
317   explicit DLTestDialect(MLIRContext *ctx)
318       : Dialect(getDialectNamespace(), ctx, TypeID::get<DLTestDialect>()) {
319     ctx->getOrLoadDialect<DLTIDialect>();
320     addAttributes<CustomDataLayoutSpec>();
321     addOperations<OpWithLayout, OpWith7BitByte>();
322     addTypes<SingleQueryType, TypeNoLayout>();
323   }
324   static StringRef getDialectNamespace() { return "dltest"; }
325 
326   void printAttribute(Attribute attr,
327                       DialectAsmPrinter &printer) const override {
328     printer << "spec<";
329     llvm::interleaveComma(cast<CustomDataLayoutSpec>(attr).getEntries(),
330                           printer);
331     printer << ">";
332   }
333 
334   Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
335     bool ok =
336         succeeded(parser.parseKeyword("spec")) && succeeded(parser.parseLess());
337     (void)ok;
338     assert(ok);
339     if (succeeded(parser.parseOptionalGreater()))
340       return CustomDataLayoutSpec::get(parser.getContext(), {});
341 
342     SmallVector<DataLayoutEntryInterface> entries;
343     ok = succeeded(parser.parseCommaSeparatedList([&]() {
344       entries.emplace_back();
345       ok = succeeded(parser.parseAttribute(entries.back()));
346       assert(ok);
347       return success();
348     }));
349     assert(ok);
350     ok = succeeded(parser.parseGreater());
351     assert(ok);
352     return CustomDataLayoutSpec::get(parser.getContext(), entries);
353   }
354 
355   void printType(Type type, DialectAsmPrinter &printer) const override {
356     if (isa<SingleQueryType>(type))
357       printer << "single_query";
358     else
359       printer << "no_layout";
360   }
361 
362   Type parseType(DialectAsmParser &parser) const override {
363     bool ok = succeeded(parser.parseKeyword("single_query"));
364     (void)ok;
365     assert(ok);
366     return SingleQueryType::get(parser.getContext());
367   }
368 };
369 
370 /// A dialect to test DLTI's target system spec and related attributes
371 struct DLTargetSystemDescTestDialect : public Dialect {
372   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTargetSystemDescTestDialect)
373 
374   explicit DLTargetSystemDescTestDialect(MLIRContext *ctx)
375       : Dialect(getDialectNamespace(), ctx,
376                 TypeID::get<DLTargetSystemDescTestDialect>()) {
377     ctx->getOrLoadDialect<DLTIDialect>();
378     addAttributes<CustomTargetSystemSpec>();
379   }
380   static StringRef getDialectNamespace() { return "dl_target_sys_desc_test"; }
381 
382   void printAttribute(Attribute attr,
383                       DialectAsmPrinter &printer) const override {
384     printer << "target_system_spec<";
385     llvm::interleaveComma(
386         cast<CustomTargetSystemSpec>(attr).getEntries(), printer,
387         [&](const auto &it) { printer << it.first << ":" << it.second; });
388     printer << ">";
389   }
390 
391   Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
392     bool ok = succeeded(parser.parseKeyword("target_system_spec")) &&
393               succeeded(parser.parseLess());
394     (void)ok;
395     assert(ok);
396     if (succeeded(parser.parseOptionalGreater()))
397       return CustomTargetSystemSpec::get(parser.getContext(), {});
398 
399     auto parseDeviceIDTargetDeviceSpecPair =
400         [&](AsmParser &parser) -> FailureOr<DeviceIDTargetDeviceSpecPair> {
401       std::string deviceID;
402       if (failed(parser.parseString(&deviceID))) {
403         parser.emitError(parser.getCurrentLocation())
404             << "DeviceID is missing, or is not of string type";
405         return failure();
406       }
407       if (failed(parser.parseColon())) {
408         parser.emitError(parser.getCurrentLocation()) << "Missing colon";
409         return failure();
410       }
411 
412       TargetDeviceSpecInterface targetDeviceSpec;
413       if (failed(parser.parseAttribute(targetDeviceSpec))) {
414         parser.emitError(parser.getCurrentLocation())
415             << "Error in parsing target device spec";
416         return failure();
417       }
418       return std::make_pair(parser.getBuilder().getStringAttr(deviceID),
419                             targetDeviceSpec);
420     };
421 
422     SmallVector<DeviceIDTargetDeviceSpecPair> entries;
423     ok = succeeded(parser.parseCommaSeparatedList([&]() {
424       auto deviceIDAndTargetDeviceSpecPair =
425           parseDeviceIDTargetDeviceSpecPair(parser);
426       ok = succeeded(deviceIDAndTargetDeviceSpecPair);
427       assert(ok);
428       entries.push_back(*deviceIDAndTargetDeviceSpecPair);
429       return success();
430     }));
431     assert(ok);
432     ok = succeeded(parser.parseGreater());
433     assert(ok);
434     return CustomTargetSystemSpec::get(parser.getContext(), entries);
435   }
436 };
437 
438 } // namespace
439 
440 TEST(DataLayout, FallbackDefault) {
441   const char *ir = R"MLIR(
442 module {}
443   )MLIR";
444 
445   DialectRegistry registry;
446   registry.insert<DLTIDialect, DLTestDialect>();
447   MLIRContext ctx(registry);
448 
449   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
450   DataLayout layout(module.get());
451   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
452   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 2u);
453   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
454   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 16u);
455   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u);
456   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 2u);
457   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 8u);
458   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u);
459 
460   EXPECT_EQ(layout.getEndianness(), Attribute());
461   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
462   EXPECT_EQ(layout.getProgramMemorySpace(), Attribute());
463   EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute());
464   EXPECT_EQ(layout.getStackAlignment(), 0u);
465 }
466 
467 TEST(DataLayout, NullSpec) {
468   const char *ir = R"MLIR(
469 "dltest.op_with_layout"() : () -> ()
470   )MLIR";
471 
472   DialectRegistry registry;
473   registry.insert<DLTIDialect, DLTestDialect>();
474   MLIRContext ctx(registry);
475 
476   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
477   auto op =
478       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
479   DataLayout layout(op);
480 
481   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
482   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
483   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
484   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
485   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
486   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
487   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
488   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
489   EXPECT_EQ(layout.getTypeIndexBitwidth(Float16Type::get(&ctx)), std::nullopt);
490   EXPECT_EQ(layout.getTypeIndexBitwidth(IndexType::get(&ctx)), 64u);
491 
492   EXPECT_EQ(layout.getEndianness(), Attribute());
493   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
494   EXPECT_EQ(layout.getProgramMemorySpace(), Attribute());
495   EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute());
496   EXPECT_EQ(layout.getStackAlignment(), 0u);
497 
498   EXPECT_EQ(layout.getDevicePropertyValueAsInt(
499                 Builder(&ctx).getStringAttr("CPU" /* device ID*/),
500                 Builder(&ctx).getStringAttr("L1_cache_size_in_bytes")),
501             std::nullopt);
502   EXPECT_EQ(layout.getDevicePropertyValueAsInt(
503                 Builder(&ctx).getStringAttr("CPU" /* device ID*/),
504                 Builder(&ctx).getStringAttr("max_vector_width")),
505             std::nullopt);
506 }
507 
508 TEST(DataLayout, EmptySpec) {
509   const char *ir = R"MLIR(
510 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< > } : () -> ()
511   )MLIR";
512 
513   DialectRegistry registry;
514   registry.insert<DLTIDialect, DLTestDialect>();
515   MLIRContext ctx(registry);
516 
517   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
518   auto op =
519       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
520   DataLayout layout(op);
521   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
522   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
523   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
524   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
525   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
526   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
527   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
528   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
529   EXPECT_EQ(layout.getTypeIndexBitwidth(Float16Type::get(&ctx)), std::nullopt);
530   EXPECT_EQ(layout.getTypeIndexBitwidth(IndexType::get(&ctx)), 64u);
531 
532   EXPECT_EQ(layout.getEndianness(), Attribute());
533   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
534   EXPECT_EQ(layout.getProgramMemorySpace(), Attribute());
535   EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute());
536   EXPECT_EQ(layout.getStackAlignment(), 0u);
537 
538   EXPECT_EQ(layout.getDevicePropertyValueAsInt(
539                 Builder(&ctx).getStringAttr("CPU" /* device ID*/),
540                 Builder(&ctx).getStringAttr("L1_cache_size_in_bytes")),
541             std::nullopt);
542   EXPECT_EQ(layout.getDevicePropertyValueAsInt(
543                 Builder(&ctx).getStringAttr("CPU" /* device ID*/),
544                 Builder(&ctx).getStringAttr("max_vector_width")),
545             std::nullopt);
546 }
547 
548 TEST(DataLayout, SpecWithEntries) {
549   const char *ir = R"MLIR(
550 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<
551   #dlti.dl_entry<i42, 5>,
552   #dlti.dl_entry<i16, 6>,
553   #dlti.dl_entry<index, 42>,
554   #dlti.dl_entry<"dltest.endianness", "little">,
555   #dlti.dl_entry<"dltest.alloca_memory_space", 5 : i32>,
556   #dlti.dl_entry<"dltest.program_memory_space", 3 : i32>,
557   #dlti.dl_entry<"dltest.global_memory_space", 2 : i32>,
558   #dlti.dl_entry<"dltest.stack_alignment", 128 : i32>
559 > } : () -> ()
560   )MLIR";
561 
562   DialectRegistry registry;
563   registry.insert<DLTIDialect, DLTestDialect>();
564   MLIRContext ctx(registry);
565 
566   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
567   auto op =
568       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
569   DataLayout layout(op);
570   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 5u);
571   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u);
572   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 40u);
573   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 48u);
574   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u);
575   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 8u);
576   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 16u);
577   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 16u);
578   EXPECT_EQ(layout.getTypeIndexBitwidth(Float16Type::get(&ctx)), std::nullopt);
579   EXPECT_EQ(layout.getTypeIndexBitwidth(IndexType::get(&ctx)), 42u);
580 
581   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 32u);
582   EXPECT_EQ(layout.getTypeSize(Float32Type::get(&ctx)), 32u);
583   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 256u);
584   EXPECT_EQ(layout.getTypeSizeInBits(Float32Type::get(&ctx)), 256u);
585   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 32)), 32u);
586   EXPECT_EQ(layout.getTypeABIAlignment(Float32Type::get(&ctx)), 32u);
587   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 32)), 64u);
588   EXPECT_EQ(layout.getTypePreferredAlignment(Float32Type::get(&ctx)), 64u);
589 
590   EXPECT_EQ(layout.getEndianness(), Builder(&ctx).getStringAttr("little"));
591   EXPECT_EQ(layout.getAllocaMemorySpace(), Builder(&ctx).getI32IntegerAttr(5));
592   EXPECT_EQ(layout.getProgramMemorySpace(), Builder(&ctx).getI32IntegerAttr(3));
593   EXPECT_EQ(layout.getGlobalMemorySpace(), Builder(&ctx).getI32IntegerAttr(2));
594   EXPECT_EQ(layout.getStackAlignment(), 128u);
595 }
596 
597 TEST(DataLayout, SpecWithTargetSystemDescEntries) {
598   const char *ir = R"MLIR(
599   module attributes { dl_target_sys_desc_test.target_system_spec =
600     #dl_target_sys_desc_test.target_system_spec<
601       "CPU": #dlti.target_device_spec<
602               #dlti.dl_entry<"L1_cache_size_in_bytes", 4096 : ui32>,
603               #dlti.dl_entry<"max_vector_op_width", 128 : ui32>>
604     > } {}
605   )MLIR";
606 
607   DialectRegistry registry;
608   registry.insert<DLTIDialect, DLTargetSystemDescTestDialect>();
609   MLIRContext ctx(registry);
610 
611   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
612   DataLayout layout(*module);
613   EXPECT_EQ(layout.getDevicePropertyValueAsInt(
614                 Builder(&ctx).getStringAttr("CPU") /* device ID*/,
615                 Builder(&ctx).getStringAttr("L1_cache_size_in_bytes")),
616             std::optional<int64_t>(4096));
617   EXPECT_EQ(layout.getDevicePropertyValueAsInt(
618                 Builder(&ctx).getStringAttr("CPU") /* device ID*/,
619                 Builder(&ctx).getStringAttr("max_vector_op_width")),
620             std::optional<int64_t>(128));
621 }
622 
623 TEST(DataLayout, Caching) {
624   const char *ir = R"MLIR(
625 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
626   )MLIR";
627 
628   DialectRegistry registry;
629   registry.insert<DLTIDialect, DLTestDialect>();
630   MLIRContext ctx(registry);
631 
632   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
633   auto op =
634       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
635   DataLayout layout(op);
636 
637   unsigned sum = 0;
638   sum += layout.getTypeSize(SingleQueryType::get(&ctx));
639   // The second call should hit the cache. If it does not, the function in
640   // SingleQueryType will be called and will abort the process.
641   sum += layout.getTypeSize(SingleQueryType::get(&ctx));
642   // Make sure the complier doesn't optimize away the query code.
643   EXPECT_EQ(sum, 2u);
644 
645   // A fresh data layout has a new cache, so the call to it should be dispatched
646   // down to the type and abort the process.
647   DataLayout second(op);
648   ASSERT_DEATH(second.getTypeSize(SingleQueryType::get(&ctx)), "repeated call");
649 }
650 
651 TEST(DataLayout, CacheInvalidation) {
652   const char *ir = R"MLIR(
653 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<
654   #dlti.dl_entry<i42, 5>,
655   #dlti.dl_entry<i16, 6>
656 > } : () -> ()
657   )MLIR";
658 
659   DialectRegistry registry;
660   registry.insert<DLTIDialect, DLTestDialect>();
661   MLIRContext ctx(registry);
662 
663   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
664   auto op =
665       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
666   DataLayout layout(op);
667 
668   // Normal query is fine.
669   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u);
670 
671   // Replace the data layout spec with a new, empty spec.
672   op->setAttr(kAttrName, CustomDataLayoutSpec::get(&ctx, {}));
673 
674   // Data layout is no longer valid and should trigger assertion when queried.
675 #ifndef NDEBUG
676   ASSERT_DEATH(layout.getTypeSize(Float16Type::get(&ctx)), "no longer valid");
677 #endif
678 }
679 
680 TEST(DataLayout, UnimplementedTypeInterface) {
681   const char *ir = R"MLIR(
682 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
683   )MLIR";
684 
685   DialectRegistry registry;
686   registry.insert<DLTIDialect, DLTestDialect>();
687   MLIRContext ctx(registry);
688 
689   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
690   auto op =
691       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
692   DataLayout layout(op);
693 
694   ASSERT_DEATH(layout.getTypeSize(TypeNoLayout::get(&ctx)),
695                "neither the scoping op nor the type class provide data layout "
696                "information");
697 }
698 
699 TEST(DataLayout, SevenBitByte) {
700   const char *ir = R"MLIR(
701 "dltest.op_with_7bit_byte"() { dltest.layout = #dltest.spec<> } : () -> ()
702   )MLIR";
703 
704   DialectRegistry registry;
705   registry.insert<DLTIDialect, DLTestDialect>();
706   MLIRContext ctx(registry);
707 
708   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
709   auto op =
710       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
711   DataLayout layout(op);
712 
713   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
714   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 32u);
715   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
716   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 5u);
717 }
718