xref: /llvm-project/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp (revision 5550c821897ab77e664977121a0e90ad5be1ff59)
101178654SLei Zhang //===-- LayoutUtils.cpp - Decorate composite type with layout information -===//
201178654SLei Zhang //
301178654SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
401178654SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
501178654SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
601178654SLei Zhang //
701178654SLei Zhang //===----------------------------------------------------------------------===//
801178654SLei Zhang //
// This file implements utilities used to compute alignment and layout
// information for types in the SPIR-V dialect.
1101178654SLei Zhang //
1201178654SLei Zhang //===----------------------------------------------------------------------===//
1301178654SLei Zhang 
1401178654SLei Zhang #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
1501178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1601178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1701178654SLei Zhang 
1801178654SLei Zhang using namespace mlir;
1901178654SLei Zhang 
2001178654SLei Zhang spirv::StructType
decorateType(spirv::StructType structType)2101178654SLei Zhang VulkanLayoutUtils::decorateType(spirv::StructType structType) {
2201178654SLei Zhang   Size size = 0;
2301178654SLei Zhang   Size alignment = 1;
2401178654SLei Zhang   return decorateType(structType, size, alignment);
2501178654SLei Zhang }
2601178654SLei Zhang 
2701178654SLei Zhang spirv::StructType
decorateType(spirv::StructType structType,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)2801178654SLei Zhang VulkanLayoutUtils::decorateType(spirv::StructType structType,
2901178654SLei Zhang                                 VulkanLayoutUtils::Size &size,
3001178654SLei Zhang                                 VulkanLayoutUtils::Size &alignment) {
3101178654SLei Zhang   if (structType.getNumElements() == 0) {
3201178654SLei Zhang     return structType;
3301178654SLei Zhang   }
3401178654SLei Zhang 
3501178654SLei Zhang   SmallVector<Type, 4> memberTypes;
3601178654SLei Zhang   SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo;
3701178654SLei Zhang   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
3801178654SLei Zhang 
3901178654SLei Zhang   Size structMemberOffset = 0;
4001178654SLei Zhang   Size maxMemberAlignment = 1;
4101178654SLei Zhang 
4201178654SLei Zhang   for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) {
4301178654SLei Zhang     Size memberSize = 0;
4401178654SLei Zhang     Size memberAlignment = 1;
4501178654SLei Zhang 
4601178654SLei Zhang     auto memberType =
4701178654SLei Zhang         decorateType(structType.getElementType(i), memberSize, memberAlignment);
4801178654SLei Zhang     structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
4901178654SLei Zhang     memberTypes.push_back(memberType);
5001178654SLei Zhang     offsetInfo.push_back(
5101178654SLei Zhang         static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
5201178654SLei Zhang     // If the member's size is the max value, it must be the last member and it
5301178654SLei Zhang     // must be a runtime array.
5401178654SLei Zhang     assert(memberSize != std::numeric_limits<Size>().max() ||
5501178654SLei Zhang            (i + 1 == e &&
56*5550c821STres Popp             isa<spirv::RuntimeArrayType>(structType.getElementType(i))));
5701178654SLei Zhang     // According to the Vulkan spec:
5801178654SLei Zhang     // "A structure has a base alignment equal to the largest base alignment of
5901178654SLei Zhang     // any of its members."
6001178654SLei Zhang     structMemberOffset += memberSize;
6101178654SLei Zhang     maxMemberAlignment = std::max(maxMemberAlignment, memberAlignment);
6201178654SLei Zhang   }
6301178654SLei Zhang 
6401178654SLei Zhang   // According to the Vulkan spec:
6501178654SLei Zhang   // "The Offset decoration of a member must not place it between the end of a
6601178654SLei Zhang   // structure or an array and the next multiple of the alignment of that
6701178654SLei Zhang   // structure or array."
6801178654SLei Zhang   size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
6901178654SLei Zhang   alignment = maxMemberAlignment;
7001178654SLei Zhang   structType.getMemberDecorations(memberDecorations);
7101178654SLei Zhang 
7201178654SLei Zhang   if (!structType.isIdentified())
7301178654SLei Zhang     return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
7401178654SLei Zhang 
7501178654SLei Zhang   // Identified structs are uniqued by identifier so it is not possible
7601178654SLei Zhang   // to create 2 structs with the same name but different decorations.
7701178654SLei Zhang   return nullptr;
7801178654SLei Zhang }
7901178654SLei Zhang 
decorateType(Type type,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)8001178654SLei Zhang Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
8101178654SLei Zhang                                      VulkanLayoutUtils::Size &alignment) {
82*5550c821STres Popp   if (isa<spirv::ScalarType>(type)) {
8301178654SLei Zhang     alignment = getScalarTypeAlignment(type);
8401178654SLei Zhang     // Vulkan spec does not specify any padding for a scalar type.
8501178654SLei Zhang     size = alignment;
8601178654SLei Zhang     return type;
8701178654SLei Zhang   }
88*5550c821STres Popp   if (auto structType = dyn_cast<spirv::StructType>(type))
8901178654SLei Zhang     return decorateType(structType, size, alignment);
90*5550c821STres Popp   if (auto arrayType = dyn_cast<spirv::ArrayType>(type))
9101178654SLei Zhang     return decorateType(arrayType, size, alignment);
92*5550c821STres Popp   if (auto vectorType = dyn_cast<VectorType>(type))
9301178654SLei Zhang     return decorateType(vectorType, size, alignment);
94*5550c821STres Popp   if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
9501178654SLei Zhang     size = std::numeric_limits<Size>().max();
9601178654SLei Zhang     return decorateType(arrayType, alignment);
9701178654SLei Zhang   }
98*5550c821STres Popp   if (isa<spirv::PointerType>(type)) {
9931c35285SJakub Kuderski     // TODO: Add support for `PhysicalStorageBufferAddresses`.
10031c35285SJakub Kuderski     return nullptr;
10131c35285SJakub Kuderski   }
10201178654SLei Zhang   llvm_unreachable("unhandled SPIR-V type");
10301178654SLei Zhang }
10401178654SLei Zhang 
decorateType(VectorType vectorType,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)10501178654SLei Zhang Type VulkanLayoutUtils::decorateType(VectorType vectorType,
10601178654SLei Zhang                                      VulkanLayoutUtils::Size &size,
10701178654SLei Zhang                                      VulkanLayoutUtils::Size &alignment) {
10801178654SLei Zhang   const auto numElements = vectorType.getNumElements();
10901178654SLei Zhang   auto elementType = vectorType.getElementType();
11001178654SLei Zhang   Size elementSize = 0;
11101178654SLei Zhang   Size elementAlignment = 1;
11201178654SLei Zhang 
11301178654SLei Zhang   auto memberType = decorateType(elementType, elementSize, elementAlignment);
11401178654SLei Zhang   // According to the Vulkan spec:
11501178654SLei Zhang   // 1. "A two-component vector has a base alignment equal to twice its scalar
11601178654SLei Zhang   // alignment."
11701178654SLei Zhang   // 2. "A three- or four-component vector has a base alignment equal to four
11801178654SLei Zhang   // times its scalar alignment."
11901178654SLei Zhang   size = elementSize * numElements;
12001178654SLei Zhang   alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4;
12101178654SLei Zhang   return VectorType::get(numElements, memberType);
12201178654SLei Zhang }
12301178654SLei Zhang 
decorateType(spirv::ArrayType arrayType,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)12401178654SLei Zhang Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
12501178654SLei Zhang                                      VulkanLayoutUtils::Size &size,
12601178654SLei Zhang                                      VulkanLayoutUtils::Size &alignment) {
12701178654SLei Zhang   const auto numElements = arrayType.getNumElements();
12801178654SLei Zhang   auto elementType = arrayType.getElementType();
12901178654SLei Zhang   Size elementSize = 0;
13001178654SLei Zhang   Size elementAlignment = 1;
13101178654SLei Zhang 
13201178654SLei Zhang   auto memberType = decorateType(elementType, elementSize, elementAlignment);
13301178654SLei Zhang   // According to the Vulkan spec:
13401178654SLei Zhang   // "An array has a base alignment equal to the base alignment of its element
13501178654SLei Zhang   // type."
13601178654SLei Zhang   size = elementSize * numElements;
13701178654SLei Zhang   alignment = elementAlignment;
13801178654SLei Zhang   return spirv::ArrayType::get(memberType, numElements, elementSize);
13901178654SLei Zhang }
14001178654SLei Zhang 
decorateType(spirv::RuntimeArrayType arrayType,VulkanLayoutUtils::Size & alignment)14101178654SLei Zhang Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
14201178654SLei Zhang                                      VulkanLayoutUtils::Size &alignment) {
14301178654SLei Zhang   auto elementType = arrayType.getElementType();
14401178654SLei Zhang   Size elementSize = 0;
14501178654SLei Zhang 
14601178654SLei Zhang   auto memberType = decorateType(elementType, elementSize, alignment);
14701178654SLei Zhang   return spirv::RuntimeArrayType::get(memberType, elementSize);
14801178654SLei Zhang }
14901178654SLei Zhang 
15001178654SLei Zhang VulkanLayoutUtils::Size
getScalarTypeAlignment(Type scalarType)15101178654SLei Zhang VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) {
15201178654SLei Zhang   // According to the Vulkan spec:
15301178654SLei Zhang   // 1. "A scalar of size N has a scalar alignment of N."
15401178654SLei Zhang   // 2. "A scalar has a base alignment equal to its scalar alignment."
15501178654SLei Zhang   // 3. "A scalar, vector or matrix type has an extended alignment equal to its
15601178654SLei Zhang   // base alignment."
15701178654SLei Zhang   auto bitWidth = scalarType.getIntOrFloatBitWidth();
15801178654SLei Zhang   if (bitWidth == 1)
15901178654SLei Zhang     return 1;
16001178654SLei Zhang   return bitWidth / 8;
16101178654SLei Zhang }
16201178654SLei Zhang 
isLegalType(Type type)16301178654SLei Zhang bool VulkanLayoutUtils::isLegalType(Type type) {
164*5550c821STres Popp   auto ptrType = dyn_cast<spirv::PointerType>(type);
16501178654SLei Zhang   if (!ptrType) {
16601178654SLei Zhang     return true;
16701178654SLei Zhang   }
16801178654SLei Zhang 
16901178654SLei Zhang   auto storageClass = ptrType.getStorageClass();
170*5550c821STres Popp   auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
17101178654SLei Zhang   if (!structType) {
17201178654SLei Zhang     return true;
17301178654SLei Zhang   }
17401178654SLei Zhang 
17501178654SLei Zhang   switch (storageClass) {
17601178654SLei Zhang   case spirv::StorageClass::Uniform:
17701178654SLei Zhang   case spirv::StorageClass::StorageBuffer:
17801178654SLei Zhang   case spirv::StorageClass::PushConstant:
17901178654SLei Zhang   case spirv::StorageClass::PhysicalStorageBuffer:
18001178654SLei Zhang     return structType.hasOffset() || !structType.getNumElements();
18101178654SLei Zhang   default:
18201178654SLei Zhang     return true;
18301178654SLei Zhang   }
18401178654SLei Zhang }
185