//===-- LayoutUtils.cpp - Decorate composite type with layout information -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities used to get alignment and layout information
// for types in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
1301178654SLei Zhang
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;
1901178654SLei Zhang
2001178654SLei Zhang spirv::StructType
decorateType(spirv::StructType structType)2101178654SLei Zhang VulkanLayoutUtils::decorateType(spirv::StructType structType) {
2201178654SLei Zhang Size size = 0;
2301178654SLei Zhang Size alignment = 1;
2401178654SLei Zhang return decorateType(structType, size, alignment);
2501178654SLei Zhang }
2601178654SLei Zhang
2701178654SLei Zhang spirv::StructType
decorateType(spirv::StructType structType,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)2801178654SLei Zhang VulkanLayoutUtils::decorateType(spirv::StructType structType,
2901178654SLei Zhang VulkanLayoutUtils::Size &size,
3001178654SLei Zhang VulkanLayoutUtils::Size &alignment) {
3101178654SLei Zhang if (structType.getNumElements() == 0) {
3201178654SLei Zhang return structType;
3301178654SLei Zhang }
3401178654SLei Zhang
3501178654SLei Zhang SmallVector<Type, 4> memberTypes;
3601178654SLei Zhang SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo;
3701178654SLei Zhang SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
3801178654SLei Zhang
3901178654SLei Zhang Size structMemberOffset = 0;
4001178654SLei Zhang Size maxMemberAlignment = 1;
4101178654SLei Zhang
4201178654SLei Zhang for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) {
4301178654SLei Zhang Size memberSize = 0;
4401178654SLei Zhang Size memberAlignment = 1;
4501178654SLei Zhang
4601178654SLei Zhang auto memberType =
4701178654SLei Zhang decorateType(structType.getElementType(i), memberSize, memberAlignment);
4801178654SLei Zhang structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
4901178654SLei Zhang memberTypes.push_back(memberType);
5001178654SLei Zhang offsetInfo.push_back(
5101178654SLei Zhang static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
5201178654SLei Zhang // If the member's size is the max value, it must be the last member and it
5301178654SLei Zhang // must be a runtime array.
5401178654SLei Zhang assert(memberSize != std::numeric_limits<Size>().max() ||
5501178654SLei Zhang (i + 1 == e &&
56*5550c821STres Popp isa<spirv::RuntimeArrayType>(structType.getElementType(i))));
5701178654SLei Zhang // According to the Vulkan spec:
5801178654SLei Zhang // "A structure has a base alignment equal to the largest base alignment of
5901178654SLei Zhang // any of its members."
6001178654SLei Zhang structMemberOffset += memberSize;
6101178654SLei Zhang maxMemberAlignment = std::max(maxMemberAlignment, memberAlignment);
6201178654SLei Zhang }
6301178654SLei Zhang
6401178654SLei Zhang // According to the Vulkan spec:
6501178654SLei Zhang // "The Offset decoration of a member must not place it between the end of a
6601178654SLei Zhang // structure or an array and the next multiple of the alignment of that
6701178654SLei Zhang // structure or array."
6801178654SLei Zhang size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
6901178654SLei Zhang alignment = maxMemberAlignment;
7001178654SLei Zhang structType.getMemberDecorations(memberDecorations);
7101178654SLei Zhang
7201178654SLei Zhang if (!structType.isIdentified())
7301178654SLei Zhang return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
7401178654SLei Zhang
7501178654SLei Zhang // Identified structs are uniqued by identifier so it is not possible
7601178654SLei Zhang // to create 2 structs with the same name but different decorations.
7701178654SLei Zhang return nullptr;
7801178654SLei Zhang }
7901178654SLei Zhang
decorateType(Type type,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)8001178654SLei Zhang Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
8101178654SLei Zhang VulkanLayoutUtils::Size &alignment) {
82*5550c821STres Popp if (isa<spirv::ScalarType>(type)) {
8301178654SLei Zhang alignment = getScalarTypeAlignment(type);
8401178654SLei Zhang // Vulkan spec does not specify any padding for a scalar type.
8501178654SLei Zhang size = alignment;
8601178654SLei Zhang return type;
8701178654SLei Zhang }
88*5550c821STres Popp if (auto structType = dyn_cast<spirv::StructType>(type))
8901178654SLei Zhang return decorateType(structType, size, alignment);
90*5550c821STres Popp if (auto arrayType = dyn_cast<spirv::ArrayType>(type))
9101178654SLei Zhang return decorateType(arrayType, size, alignment);
92*5550c821STres Popp if (auto vectorType = dyn_cast<VectorType>(type))
9301178654SLei Zhang return decorateType(vectorType, size, alignment);
94*5550c821STres Popp if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
9501178654SLei Zhang size = std::numeric_limits<Size>().max();
9601178654SLei Zhang return decorateType(arrayType, alignment);
9701178654SLei Zhang }
98*5550c821STres Popp if (isa<spirv::PointerType>(type)) {
9931c35285SJakub Kuderski // TODO: Add support for `PhysicalStorageBufferAddresses`.
10031c35285SJakub Kuderski return nullptr;
10131c35285SJakub Kuderski }
10201178654SLei Zhang llvm_unreachable("unhandled SPIR-V type");
10301178654SLei Zhang }
10401178654SLei Zhang
decorateType(VectorType vectorType,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)10501178654SLei Zhang Type VulkanLayoutUtils::decorateType(VectorType vectorType,
10601178654SLei Zhang VulkanLayoutUtils::Size &size,
10701178654SLei Zhang VulkanLayoutUtils::Size &alignment) {
10801178654SLei Zhang const auto numElements = vectorType.getNumElements();
10901178654SLei Zhang auto elementType = vectorType.getElementType();
11001178654SLei Zhang Size elementSize = 0;
11101178654SLei Zhang Size elementAlignment = 1;
11201178654SLei Zhang
11301178654SLei Zhang auto memberType = decorateType(elementType, elementSize, elementAlignment);
11401178654SLei Zhang // According to the Vulkan spec:
11501178654SLei Zhang // 1. "A two-component vector has a base alignment equal to twice its scalar
11601178654SLei Zhang // alignment."
11701178654SLei Zhang // 2. "A three- or four-component vector has a base alignment equal to four
11801178654SLei Zhang // times its scalar alignment."
11901178654SLei Zhang size = elementSize * numElements;
12001178654SLei Zhang alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4;
12101178654SLei Zhang return VectorType::get(numElements, memberType);
12201178654SLei Zhang }
12301178654SLei Zhang
decorateType(spirv::ArrayType arrayType,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)12401178654SLei Zhang Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
12501178654SLei Zhang VulkanLayoutUtils::Size &size,
12601178654SLei Zhang VulkanLayoutUtils::Size &alignment) {
12701178654SLei Zhang const auto numElements = arrayType.getNumElements();
12801178654SLei Zhang auto elementType = arrayType.getElementType();
12901178654SLei Zhang Size elementSize = 0;
13001178654SLei Zhang Size elementAlignment = 1;
13101178654SLei Zhang
13201178654SLei Zhang auto memberType = decorateType(elementType, elementSize, elementAlignment);
13301178654SLei Zhang // According to the Vulkan spec:
13401178654SLei Zhang // "An array has a base alignment equal to the base alignment of its element
13501178654SLei Zhang // type."
13601178654SLei Zhang size = elementSize * numElements;
13701178654SLei Zhang alignment = elementAlignment;
13801178654SLei Zhang return spirv::ArrayType::get(memberType, numElements, elementSize);
13901178654SLei Zhang }
14001178654SLei Zhang
decorateType(spirv::RuntimeArrayType arrayType,VulkanLayoutUtils::Size & alignment)14101178654SLei Zhang Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
14201178654SLei Zhang VulkanLayoutUtils::Size &alignment) {
14301178654SLei Zhang auto elementType = arrayType.getElementType();
14401178654SLei Zhang Size elementSize = 0;
14501178654SLei Zhang
14601178654SLei Zhang auto memberType = decorateType(elementType, elementSize, alignment);
14701178654SLei Zhang return spirv::RuntimeArrayType::get(memberType, elementSize);
14801178654SLei Zhang }
14901178654SLei Zhang
15001178654SLei Zhang VulkanLayoutUtils::Size
getScalarTypeAlignment(Type scalarType)15101178654SLei Zhang VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) {
15201178654SLei Zhang // According to the Vulkan spec:
15301178654SLei Zhang // 1. "A scalar of size N has a scalar alignment of N."
15401178654SLei Zhang // 2. "A scalar has a base alignment equal to its scalar alignment."
15501178654SLei Zhang // 3. "A scalar, vector or matrix type has an extended alignment equal to its
15601178654SLei Zhang // base alignment."
15701178654SLei Zhang auto bitWidth = scalarType.getIntOrFloatBitWidth();
15801178654SLei Zhang if (bitWidth == 1)
15901178654SLei Zhang return 1;
16001178654SLei Zhang return bitWidth / 8;
16101178654SLei Zhang }
16201178654SLei Zhang
isLegalType(Type type)16301178654SLei Zhang bool VulkanLayoutUtils::isLegalType(Type type) {
164*5550c821STres Popp auto ptrType = dyn_cast<spirv::PointerType>(type);
16501178654SLei Zhang if (!ptrType) {
16601178654SLei Zhang return true;
16701178654SLei Zhang }
16801178654SLei Zhang
16901178654SLei Zhang auto storageClass = ptrType.getStorageClass();
170*5550c821STres Popp auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
17101178654SLei Zhang if (!structType) {
17201178654SLei Zhang return true;
17301178654SLei Zhang }
17401178654SLei Zhang
17501178654SLei Zhang switch (storageClass) {
17601178654SLei Zhang case spirv::StorageClass::Uniform:
17701178654SLei Zhang case spirv::StorageClass::StorageBuffer:
17801178654SLei Zhang case spirv::StorageClass::PushConstant:
17901178654SLei Zhang case spirv::StorageClass::PhysicalStorageBuffer:
18001178654SLei Zhang return structType.hasOffset() || !structType.getNumElements();
18101178654SLei Zhang default:
18201178654SLei Zhang return true;
18301178654SLei Zhang }
18401178654SLei Zhang }
185