xref: /llvm-project/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp (revision 5550c821897ab77e664977121a0e90ad5be1ff59)
1 //===-- LayoutUtils.cpp - Decorate composite type with layout information -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Utilities used to get alignment and layout information
10 // for types in SPIR-V dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
17 
18 using namespace mlir;
19 
20 spirv::StructType
decorateType(spirv::StructType structType)21 VulkanLayoutUtils::decorateType(spirv::StructType structType) {
22   Size size = 0;
23   Size alignment = 1;
24   return decorateType(structType, size, alignment);
25 }
26 
27 spirv::StructType
decorateType(spirv::StructType structType,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)28 VulkanLayoutUtils::decorateType(spirv::StructType structType,
29                                 VulkanLayoutUtils::Size &size,
30                                 VulkanLayoutUtils::Size &alignment) {
31   if (structType.getNumElements() == 0) {
32     return structType;
33   }
34 
35   SmallVector<Type, 4> memberTypes;
36   SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo;
37   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
38 
39   Size structMemberOffset = 0;
40   Size maxMemberAlignment = 1;
41 
42   for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) {
43     Size memberSize = 0;
44     Size memberAlignment = 1;
45 
46     auto memberType =
47         decorateType(structType.getElementType(i), memberSize, memberAlignment);
48     structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
49     memberTypes.push_back(memberType);
50     offsetInfo.push_back(
51         static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
52     // If the member's size is the max value, it must be the last member and it
53     // must be a runtime array.
54     assert(memberSize != std::numeric_limits<Size>().max() ||
55            (i + 1 == e &&
56             isa<spirv::RuntimeArrayType>(structType.getElementType(i))));
57     // According to the Vulkan spec:
58     // "A structure has a base alignment equal to the largest base alignment of
59     // any of its members."
60     structMemberOffset += memberSize;
61     maxMemberAlignment = std::max(maxMemberAlignment, memberAlignment);
62   }
63 
64   // According to the Vulkan spec:
65   // "The Offset decoration of a member must not place it between the end of a
66   // structure or an array and the next multiple of the alignment of that
67   // structure or array."
68   size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
69   alignment = maxMemberAlignment;
70   structType.getMemberDecorations(memberDecorations);
71 
72   if (!structType.isIdentified())
73     return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
74 
75   // Identified structs are uniqued by identifier so it is not possible
76   // to create 2 structs with the same name but different decorations.
77   return nullptr;
78 }
79 
decorateType(Type type,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)80 Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
81                                      VulkanLayoutUtils::Size &alignment) {
82   if (isa<spirv::ScalarType>(type)) {
83     alignment = getScalarTypeAlignment(type);
84     // Vulkan spec does not specify any padding for a scalar type.
85     size = alignment;
86     return type;
87   }
88   if (auto structType = dyn_cast<spirv::StructType>(type))
89     return decorateType(structType, size, alignment);
90   if (auto arrayType = dyn_cast<spirv::ArrayType>(type))
91     return decorateType(arrayType, size, alignment);
92   if (auto vectorType = dyn_cast<VectorType>(type))
93     return decorateType(vectorType, size, alignment);
94   if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
95     size = std::numeric_limits<Size>().max();
96     return decorateType(arrayType, alignment);
97   }
98   if (isa<spirv::PointerType>(type)) {
99     // TODO: Add support for `PhysicalStorageBufferAddresses`.
100     return nullptr;
101   }
102   llvm_unreachable("unhandled SPIR-V type");
103 }
104 
decorateType(VectorType vectorType,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)105 Type VulkanLayoutUtils::decorateType(VectorType vectorType,
106                                      VulkanLayoutUtils::Size &size,
107                                      VulkanLayoutUtils::Size &alignment) {
108   const auto numElements = vectorType.getNumElements();
109   auto elementType = vectorType.getElementType();
110   Size elementSize = 0;
111   Size elementAlignment = 1;
112 
113   auto memberType = decorateType(elementType, elementSize, elementAlignment);
114   // According to the Vulkan spec:
115   // 1. "A two-component vector has a base alignment equal to twice its scalar
116   // alignment."
117   // 2. "A three- or four-component vector has a base alignment equal to four
118   // times its scalar alignment."
119   size = elementSize * numElements;
120   alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4;
121   return VectorType::get(numElements, memberType);
122 }
123 
decorateType(spirv::ArrayType arrayType,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)124 Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
125                                      VulkanLayoutUtils::Size &size,
126                                      VulkanLayoutUtils::Size &alignment) {
127   const auto numElements = arrayType.getNumElements();
128   auto elementType = arrayType.getElementType();
129   Size elementSize = 0;
130   Size elementAlignment = 1;
131 
132   auto memberType = decorateType(elementType, elementSize, elementAlignment);
133   // According to the Vulkan spec:
134   // "An array has a base alignment equal to the base alignment of its element
135   // type."
136   size = elementSize * numElements;
137   alignment = elementAlignment;
138   return spirv::ArrayType::get(memberType, numElements, elementSize);
139 }
140 
decorateType(spirv::RuntimeArrayType arrayType,VulkanLayoutUtils::Size & alignment)141 Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
142                                      VulkanLayoutUtils::Size &alignment) {
143   auto elementType = arrayType.getElementType();
144   Size elementSize = 0;
145 
146   auto memberType = decorateType(elementType, elementSize, alignment);
147   return spirv::RuntimeArrayType::get(memberType, elementSize);
148 }
149 
150 VulkanLayoutUtils::Size
getScalarTypeAlignment(Type scalarType)151 VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) {
152   // According to the Vulkan spec:
153   // 1. "A scalar of size N has a scalar alignment of N."
154   // 2. "A scalar has a base alignment equal to its scalar alignment."
155   // 3. "A scalar, vector or matrix type has an extended alignment equal to its
156   // base alignment."
157   auto bitWidth = scalarType.getIntOrFloatBitWidth();
158   if (bitWidth == 1)
159     return 1;
160   return bitWidth / 8;
161 }
162 
isLegalType(Type type)163 bool VulkanLayoutUtils::isLegalType(Type type) {
164   auto ptrType = dyn_cast<spirv::PointerType>(type);
165   if (!ptrType) {
166     return true;
167   }
168 
169   auto storageClass = ptrType.getStorageClass();
170   auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
171   if (!structType) {
172     return true;
173   }
174 
175   switch (storageClass) {
176   case spirv::StorageClass::Uniform:
177   case spirv::StorageClass::StorageBuffer:
178   case spirv::StorageClass::PushConstant:
179   case spirv::StorageClass::PhysicalStorageBuffer:
180     return structType.hasOffset() || !structType.getNumElements();
181   default:
182     return true;
183   }
184 }
185