xref: /llvm-project/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp (revision 129f1001c3b1b5200de43917d53c0efbdf08f11f)
1 //===- PtrTypes.cpp - Pointer dialect types ---------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Ptr dialect types.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Ptr/IR/PtrTypes.h"
14 #include "mlir/Dialect/Ptr/IR/PtrAttrs.h"
15 #include "llvm/ADT/TypeSwitch.h"
16 
17 using namespace mlir;
18 using namespace mlir::ptr;
19 
20 //===----------------------------------------------------------------------===//
21 // Pointer type
22 //===----------------------------------------------------------------------===//
23 
/// Pointer size, in bits, assumed for the default memory space when the data
/// layout provides no matching entry.
static constexpr unsigned kDefaultPointerSizeBits = 64;
/// Divisor used by the alignment queries in this file to turn the values
/// stored in a pointer spec into the values they return.
static constexpr unsigned kBitsInByte = 8;
/// Default ABI and preferred alignment stored in the fallback pointer spec.
/// It is stored in the same unit as the spec values, i.e. the alignment
/// queries divide it by `kBitsInByte` before returning it.
static constexpr unsigned kDefaultPointerAlignment = 8;
27 
28 static Attribute getDefaultMemorySpace(PtrType ptr) { return nullptr; }
29 
30 /// Searches the data layout for the pointer spec, returns nullptr if it is not
31 /// found.
32 static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type) {
33   for (DataLayoutEntryInterface entry : params) {
34     if (!entry.isTypeEntry())
35       continue;
36     if (cast<PtrType>(cast<Type>(entry.getKey())).getMemorySpace() ==
37         type.getMemorySpace()) {
38       if (auto spec = dyn_cast<SpecAttr>(entry.getValue()))
39         return spec;
40     }
41   }
42   // If not found, and this is the pointer to the default memory space, assume
43   // 64-bit pointers.
44   if (type.getMemorySpace() == getDefaultMemorySpace(type))
45     return SpecAttr::get(type.getContext(), kDefaultPointerSizeBits,
46                          kDefaultPointerAlignment, kDefaultPointerAlignment,
47                          kDefaultPointerSizeBits);
48   return nullptr;
49 }
50 
51 bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,
52                             DataLayoutEntryListRef newLayout) const {
53   for (DataLayoutEntryInterface newEntry : newLayout) {
54     if (!newEntry.isTypeEntry())
55       continue;
56     uint32_t size = kDefaultPointerSizeBits;
57     uint32_t abi = kDefaultPointerAlignment;
58     auto newType = llvm::cast<PtrType>(llvm::cast<Type>(newEntry.getKey()));
59     const auto *it =
60         llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
61           if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
62             return llvm::cast<PtrType>(type).getMemorySpace() ==
63                    newType.getMemorySpace();
64           }
65           return false;
66         });
67     if (it == oldLayout.end()) {
68       it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
69         if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
70           auto ptrTy = llvm::cast<PtrType>(type);
71           return ptrTy.getMemorySpace() == getDefaultMemorySpace(ptrTy);
72         }
73         return false;
74       });
75     }
76     if (it != oldLayout.end()) {
77       auto spec = llvm::cast<SpecAttr>(*it);
78       size = spec.getSize();
79       abi = spec.getAbi();
80     }
81 
82     auto newSpec = llvm::cast<SpecAttr>(newEntry.getValue());
83     uint32_t newSize = newSpec.getSize();
84     uint32_t newAbi = newSpec.getAbi();
85     if (size != newSize || abi < newAbi || abi % newAbi != 0)
86       return false;
87   }
88   return true;
89 }
90 
91 uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout,
92                                   DataLayoutEntryListRef params) const {
93   if (SpecAttr spec = getPointerSpec(params, *this))
94     return spec.getAbi() / kBitsInByte;
95 
96   return dataLayout.getTypeABIAlignment(
97       get(getContext(), getDefaultMemorySpace(*this)));
98 }
99 
100 std::optional<uint64_t>
101 PtrType::getIndexBitwidth(const DataLayout &dataLayout,
102                           DataLayoutEntryListRef params) const {
103   if (SpecAttr spec = getPointerSpec(params, *this)) {
104     return spec.getIndex() == SpecAttr::kOptionalSpecValue ? spec.getSize()
105                                                            : spec.getIndex();
106   }
107 
108   return dataLayout.getTypeIndexBitwidth(
109       get(getContext(), getDefaultMemorySpace(*this)));
110 }
111 
112 llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout,
113                                           DataLayoutEntryListRef params) const {
114   if (SpecAttr spec = getPointerSpec(params, *this))
115     return llvm::TypeSize::getFixed(spec.getSize());
116 
117   // For other memory spaces, use the size of the pointer to the default memory
118   // space.
119   return dataLayout.getTypeSizeInBits(
120       get(getContext(), getDefaultMemorySpace(*this)));
121 }
122 
123 uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout,
124                                         DataLayoutEntryListRef params) const {
125   if (SpecAttr spec = getPointerSpec(params, *this))
126     return spec.getPreferred() / kBitsInByte;
127 
128   return dataLayout.getTypePreferredAlignment(
129       get(getContext(), getDefaultMemorySpace(*this)));
130 }
131 
132 LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
133                                      Location loc) const {
134   for (DataLayoutEntryInterface entry : entries) {
135     if (!entry.isTypeEntry())
136       continue;
137     auto key = llvm::cast<Type>(entry.getKey());
138     if (!llvm::isa<SpecAttr>(entry.getValue())) {
139       return emitError(loc) << "expected layout attribute for " << key
140                             << " to be a #ptr.spec attribute";
141     }
142   }
143   return success();
144 }
145