xref: /llvm-project/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp (revision 1079fc4f543c42bb09a33d2d79d90edd9c0bac91)
1 //===- MapMemRefStorageClassPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to map numeric MemRef memory spaces to
10 // symbolic ones defined in the SPIR-V specification.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
15 
16 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21 #include "mlir/IR/Attributes.h"
22 #include "mlir/IR/BuiltinAttributes.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/IR/Visitors.h"
26 #include "mlir/Interfaces/FunctionInterfaces.h"
27 #include "llvm/ADT/SmallVectorExtras.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/Support/Debug.h"
30 #include <optional>
31 
32 namespace mlir {
33 #define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
34 #include "mlir/Conversion/Passes.h.inc"
35 } // namespace mlir
36 
37 #define DEBUG_TYPE "mlir-map-memref-storage-class"
38 
39 using namespace mlir;
40 
41 //===----------------------------------------------------------------------===//
42 // Mappings
43 //===----------------------------------------------------------------------===//
44 
/// Vulkan table of (SPIR-V storage class, memref memory space) pairs. Each
/// pair is expanded as `ENTRY(storageClass, memorySpace)`, which lets the same
/// table drive the mapping in both directions.
///
/// Note: memref does not assign fixed semantics to memory spaces; their
/// meaning depends on the context where they are used. The numbers below are
/// not significant: they loosely follow NVVM conventions and give commonly
/// used storage classes smaller numbers. Memory space 2 is left unassigned.
#define VULKAN_STORAGE_SPACE_MAP_LIST(ENTRY)                                   \
  ENTRY(spirv::StorageClass::StorageBuffer, 0)                                 \
  ENTRY(spirv::StorageClass::Generic, 1)                                       \
  ENTRY(spirv::StorageClass::Workgroup, 3)                                     \
  ENTRY(spirv::StorageClass::Uniform, 4)                                       \
  ENTRY(spirv::StorageClass::Private, 5)                                       \
  ENTRY(spirv::StorageClass::Function, 6)                                      \
  ENTRY(spirv::StorageClass::PushConstant, 7)                                  \
  ENTRY(spirv::StorageClass::UniformConstant, 8)                               \
  ENTRY(spirv::StorageClass::Input, 9)                                         \
  ENTRY(spirv::StorageClass::Output, 10)                                       \
  ENTRY(spirv::StorageClass::PhysicalStorageBuffer, 11)
63 
64 std::optional<spirv::StorageClass>
mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr)65 spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
66   // Handle null memory space attribute specially.
67   if (!memorySpaceAttr)
68     return spirv::StorageClass::StorageBuffer;
69 
70   // Unknown dialect custom attributes are not supported by default.
71   // Downstream callers should plug in more specialized ones.
72   auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
73   if (!intAttr)
74     return std::nullopt;
75   unsigned memorySpace = intAttr.getInt();
76 
77 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
78   case space:                                                                  \
79     return storage;
80 
81   switch (memorySpace) {
82     VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
83   default:
84     break;
85   }
86   return std::nullopt;
87 
88 #undef STORAGE_SPACE_MAP_FN
89 }
90 
91 std::optional<unsigned>
mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass)92 spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
93 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
94   case storage:                                                                \
95     return space;
96 
97   switch (storageClass) {
98     VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
99   default:
100     break;
101   }
102   return std::nullopt;
103 
104 #undef STORAGE_SPACE_MAP_FN
105 }
106 
107 #undef VULKAN_STORAGE_SPACE_MAP_LIST
108 
/// OpenCL table of (SPIR-V storage class, memref memory space) pairs. Each
/// pair is expanded as `ENTRY(storageClass, memorySpace)` so that the same
/// table drives the mapping in both directions. Memory space 2 is left
/// unassigned.
#define OPENCL_STORAGE_SPACE_MAP_LIST(ENTRY)                                   \
  ENTRY(spirv::StorageClass::CrossWorkgroup, 0)                                \
  ENTRY(spirv::StorageClass::Generic, 1)                                       \
  ENTRY(spirv::StorageClass::Workgroup, 3)                                     \
  ENTRY(spirv::StorageClass::UniformConstant, 4)                               \
  ENTRY(spirv::StorageClass::Private, 5)                                       \
  ENTRY(spirv::StorageClass::Function, 6)                                      \
  ENTRY(spirv::StorageClass::Image, 7)
117 
118 std::optional<spirv::StorageClass>
mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr)119 spirv::mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr) {
120   // Handle null memory space attribute specially.
121   if (!memorySpaceAttr)
122     return spirv::StorageClass::CrossWorkgroup;
123 
124   // Unknown dialect custom attributes are not supported by default.
125   // Downstream callers should plug in more specialized ones.
126   auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
127   if (!intAttr)
128     return std::nullopt;
129   unsigned memorySpace = intAttr.getInt();
130 
131 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
132   case space:                                                                  \
133     return storage;
134 
135   switch (memorySpace) {
136     OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
137   default:
138     break;
139   }
140   return std::nullopt;
141 
142 #undef STORAGE_SPACE_MAP_FN
143 }
144 
145 std::optional<unsigned>
mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass)146 spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) {
147 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
148   case storage:                                                                \
149     return space;
150 
151   switch (storageClass) {
152     OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
153   default:
154     break;
155   }
156   return std::nullopt;
157 
158 #undef STORAGE_SPACE_MAP_FN
159 }
160 
161 #undef OPENCL_STORAGE_SPACE_MAP_LIST
162 
163 //===----------------------------------------------------------------------===//
164 // Type Converter
165 //===----------------------------------------------------------------------===//
166 
/// Builds a type converter that rewrites memref memory spaces into SPIR-V
/// storage class attributes using `memorySpaceMap`; function types are
/// converted element-wise and every other type is passed through unchanged.
spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
    const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
    : memorySpaceMap(memorySpaceMap) {
  // Pass through for all other types. TypeConverter tries the most recently
  // added conversion first, so this identity conversion is the fallback.
  addConversion([](Type type) { return type; });

  addConversion([this](BaseMemRefType memRefType) -> std::optional<Type> {
    std::optional<spirv::StorageClass> storage =
        this->memorySpaceMap(memRefType.getMemorySpace());
    if (!storage) {
      LLVM_DEBUG(llvm::dbgs()
                 << "cannot convert " << memRefType
                 << " due to being unable to find memory space in map\n");
      // Decline (rather than fail) so the identity fallback applies.
      return std::nullopt;
    }

    auto storageAttr =
        spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
    if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
      return MemRefType::get(rankedType.getShape(), rankedType.getElementType(),
                             rankedType.getLayout(), storageAttr);
    }
    return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr);
  });

  addConversion([this](FunctionType type) -> std::optional<Type> {
    auto inputs = llvm::map_to_vector(
        type.getInputs(), [this](Type ty) { return convertType(ty); });
    auto results = llvm::map_to_vector(
        type.getResults(), [this](Type ty) { return convertType(ty); });

    // convertType yields a null Type when an element cannot be converted.
    // Never build a FunctionType containing null types; decline instead, the
    // same way the memref conversion above does.
    auto isNull = [](Type ty) { return !ty; };
    if (llvm::any_of(inputs, isNull) || llvm::any_of(results, isNull))
      return std::nullopt;

    return FunctionType::get(type.getContext(), inputs, results);
  });
}
200 
201 //===----------------------------------------------------------------------===//
202 // Conversion Target
203 //===----------------------------------------------------------------------===//
204 
205 /// Returns true if the given `type` is considered as legal for SPIR-V
206 /// conversion.
isLegalType(Type type)207 static bool isLegalType(Type type) {
208   if (auto memRefType = dyn_cast<BaseMemRefType>(type)) {
209     Attribute spaceAttr = memRefType.getMemorySpace();
210     return isa_and_nonnull<spirv::StorageClassAttr>(spaceAttr);
211   }
212   return true;
213 }
214 
215 /// Returns true if the given `attr` is considered as legal for SPIR-V
216 /// conversion.
isLegalAttr(Attribute attr)217 static bool isLegalAttr(Attribute attr) {
218   if (auto typeAttr = dyn_cast<TypeAttr>(attr))
219     return isLegalType(typeAttr.getValue());
220   return true;
221 }
222 
223 /// Returns true if the given `op` is considered as legal for SPIR-V conversion.
isLegalOp(Operation * op)224 static bool isLegalOp(Operation *op) {
225   if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
226     return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
227            llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
228            llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
229                         isLegalType);
230   }
231 
232   auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
233     return attr.getValue();
234   });
235 
236   return llvm::all_of(op->getOperandTypes(), isLegalType) &&
237          llvm::all_of(op->getResultTypes(), isLegalType) &&
238          llvm::all_of(attrs, isLegalAttr);
239 }
240 
241 std::unique_ptr<ConversionTarget>
getMemorySpaceToStorageClassTarget(MLIRContext & context)242 spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
243   auto target = std::make_unique<ConversionTarget>(context);
244   target->markUnknownOpDynamicallyLegal(isLegalOp);
245   return target;
246 }
247 
convertMemRefTypesAndAttrs(Operation * op,MemorySpaceToStorageClassConverter & typeConverter)248 void spirv::convertMemRefTypesAndAttrs(
249     Operation *op, MemorySpaceToStorageClassConverter &typeConverter) {
250   AttrTypeReplacer replacer;
251   replacer.addReplacement([&typeConverter](BaseMemRefType origType)
252                               -> std::optional<BaseMemRefType> {
253     return typeConverter.convertType<BaseMemRefType>(origType);
254   });
255 
256   replacer.recursivelyReplaceElementsIn(op, /*replaceAttrs=*/true,
257                                         /*replaceLocs=*/false,
258                                         /*replaceTypes=*/true);
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // Conversion Pass
263 //===----------------------------------------------------------------------===//
264 
265 namespace {
266 class MapMemRefStorageClassPass final
267     : public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
268 public:
269   MapMemRefStorageClassPass() = default;
270 
MapMemRefStorageClassPass(const spirv::MemorySpaceToStorageClassMap & memorySpaceMap)271   explicit MapMemRefStorageClassPass(
272       const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
273       : memorySpaceMap(memorySpaceMap) {}
274 
initializeOptions(StringRef options,function_ref<LogicalResult (const Twine &)> errorHandler)275   LogicalResult initializeOptions(
276       StringRef options,
277       function_ref<LogicalResult(const Twine &)> errorHandler) override {
278     if (failed(Pass::initializeOptions(options, errorHandler)))
279       return failure();
280 
281     if (clientAPI == "opencl")
282       memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
283     else if (clientAPI != "vulkan")
284       return errorHandler(llvm::Twine("Invalid clienAPI: ") + clientAPI);
285 
286     return success();
287   }
288 
runOnOperation()289   void runOnOperation() override {
290     MLIRContext *context = &getContext();
291     Operation *op = getOperation();
292 
293     spirv::MemorySpaceToStorageClassMap spaceToStorage = memorySpaceMap;
294     if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
295       spirv::TargetEnv targetEnv(attr);
296       if (targetEnv.allows(spirv::Capability::Kernel)) {
297         spaceToStorage = spirv::mapMemorySpaceToOpenCLStorageClass;
298       } else if (targetEnv.allows(spirv::Capability::Shader)) {
299         spaceToStorage = spirv::mapMemorySpaceToVulkanStorageClass;
300       }
301     }
302 
303     spirv::MemorySpaceToStorageClassConverter converter(spaceToStorage);
304     // Perform the replacement.
305     spirv::convertMemRefTypesAndAttrs(op, converter);
306 
307     // Check if there are any illegal ops remaining.
308     std::unique_ptr<ConversionTarget> target =
309         spirv::getMemorySpaceToStorageClassTarget(*context);
310     op->walk([&target, this](Operation *childOp) {
311       if (target->isIllegal(childOp)) {
312         childOp->emitOpError("failed to legalize memory space");
313         signalPassFailure();
314         return WalkResult::interrupt();
315       }
316       return WalkResult::advance();
317     });
318   }
319 
320 private:
321   spirv::MemorySpaceToStorageClassMap memorySpaceMap =
322       spirv::mapMemorySpaceToVulkanStorageClass;
323 };
324 } // namespace
325 
createMapMemRefStorageClassPass()326 std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
327   return std::make_unique<MapMemRefStorageClassPass>();
328 }
329