xref: /llvm-project/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
2713f85d5SLei Zhang //===- MapMemRefStorageClassPass.cpp --------------------------------------===//
2713f85d5SLei Zhang //
3713f85d5SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4713f85d5SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
5713f85d5SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6713f85d5SLei Zhang //
7713f85d5SLei Zhang //===----------------------------------------------------------------------===//
8713f85d5SLei Zhang //
9713f85d5SLei Zhang // This file implements a pass to map numeric MemRef memory spaces to
10713f85d5SLei Zhang // symbolic ones defined in the SPIR-V specification.
11713f85d5SLei Zhang //
12713f85d5SLei Zhang //===----------------------------------------------------------------------===//
13713f85d5SLei Zhang 
14*039b969bSMichele Scuttari #include "../PassDetail.h"
152be8af8fSMichele Scuttari #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
16*039b969bSMichele Scuttari #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
17a29fffc4SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
18713f85d5SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1915135553SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20713f85d5SLei Zhang #include "mlir/IR/BuiltinTypes.h"
21b83d0d46SLei Zhang #include "mlir/IR/FunctionInterfaces.h"
22713f85d5SLei Zhang #include "mlir/Transforms/DialectConversion.h"
23713f85d5SLei Zhang #include "llvm/ADT/StringExtras.h"
24713f85d5SLei Zhang #include "llvm/Support/Debug.h"
25713f85d5SLei Zhang 
26713f85d5SLei Zhang #define DEBUG_TYPE "mlir-map-memref-storage-class"
27713f85d5SLei Zhang 
28713f85d5SLei Zhang using namespace mlir;
29713f85d5SLei Zhang 
30713f85d5SLei Zhang //===----------------------------------------------------------------------===//
311f7544a6SLei Zhang // Mappings
32713f85d5SLei Zhang //===----------------------------------------------------------------------===//
33713f85d5SLei Zhang 
/// Mapping between SPIR-V storage classes and memref memory spaces for the
/// Vulkan client API.
///
/// Note: memref does not define semantics for individual memory spaces; their
/// meaning depends on the context where they are used. The numbers below have
/// no particular significance: we try to follow NVVM conventions and largely
/// give common storage classes smaller numbers.
///
/// Each entry is expanded as `ENTRY(storageClass, memorySpace)`.
#define VULKAN_STORAGE_SPACE_MAP_LIST(ENTRY)                                   \
  ENTRY(spirv::StorageClass::StorageBuffer, 0)                                 \
  ENTRY(spirv::StorageClass::Generic, 1)                                       \
  ENTRY(spirv::StorageClass::Workgroup, 3)                                     \
  ENTRY(spirv::StorageClass::Uniform, 4)                                       \
  ENTRY(spirv::StorageClass::Private, 5)                                       \
  ENTRY(spirv::StorageClass::Function, 6)                                      \
  ENTRY(spirv::StorageClass::PushConstant, 7)                                  \
  ENTRY(spirv::StorageClass::UniformConstant, 8)                               \
  ENTRY(spirv::StorageClass::Input, 9)                                         \
  ENTRY(spirv::StorageClass::Output, 10)
511f7544a6SLei Zhang 
5215135553SLei Zhang Optional<spirv::StorageClass>
5315135553SLei Zhang spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) {
5415135553SLei Zhang #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
5515135553SLei Zhang   case space:                                                                  \
5615135553SLei Zhang     return storage;
571f7544a6SLei Zhang 
5815135553SLei Zhang   switch (memorySpace) {
5935a56e5dSStanley Winata     VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
6015135553SLei Zhang   default:
6115135553SLei Zhang     break;
6215135553SLei Zhang   }
6315135553SLei Zhang   return llvm::None;
641f7544a6SLei Zhang 
651f7544a6SLei Zhang #undef STORAGE_SPACE_MAP_FN
66713f85d5SLei Zhang }
67713f85d5SLei Zhang 
6815135553SLei Zhang Optional<unsigned>
6915135553SLei Zhang spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
7015135553SLei Zhang #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
7115135553SLei Zhang   case storage:                                                                \
7215135553SLei Zhang     return space;
7315135553SLei Zhang 
7415135553SLei Zhang   switch (storageClass) {
7535a56e5dSStanley Winata     VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
7615135553SLei Zhang   default:
7715135553SLei Zhang     break;
7815135553SLei Zhang   }
7915135553SLei Zhang   return llvm::None;
8015135553SLei Zhang 
8115135553SLei Zhang #undef STORAGE_SPACE_MAP_FN
8215135553SLei Zhang }
8315135553SLei Zhang 
8435a56e5dSStanley Winata #undef VULKAN_STORAGE_SPACE_MAP_LIST
8535a56e5dSStanley Winata 
/// Mapping between SPIR-V storage classes and memref memory spaces for the
/// OpenCL client API.
///
/// Each entry is expanded as `ENTRY(storageClass, memorySpace)`.
#define OPENCL_STORAGE_SPACE_MAP_LIST(ENTRY)                                   \
  ENTRY(spirv::StorageClass::CrossWorkgroup, 0)                                \
  ENTRY(spirv::StorageClass::Generic, 1)                                       \
  ENTRY(spirv::StorageClass::Workgroup, 3)                                     \
  ENTRY(spirv::StorageClass::UniformConstant, 4)                               \
  ENTRY(spirv::StorageClass::Private, 5)                                       \
  ENTRY(spirv::StorageClass::Function, 6)                                      \
  ENTRY(spirv::StorageClass::Image, 7)
9435a56e5dSStanley Winata 
9535a56e5dSStanley Winata Optional<spirv::StorageClass>
9635a56e5dSStanley Winata spirv::mapMemorySpaceToOpenCLStorageClass(unsigned memorySpace) {
9735a56e5dSStanley Winata #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
9835a56e5dSStanley Winata   case space:                                                                  \
9935a56e5dSStanley Winata     return storage;
10035a56e5dSStanley Winata 
10135a56e5dSStanley Winata   switch (memorySpace) {
10235a56e5dSStanley Winata     OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
10335a56e5dSStanley Winata   default:
10435a56e5dSStanley Winata     break;
10535a56e5dSStanley Winata   }
10635a56e5dSStanley Winata   return llvm::None;
10735a56e5dSStanley Winata 
10835a56e5dSStanley Winata #undef STORAGE_SPACE_MAP_FN
10935a56e5dSStanley Winata }
11035a56e5dSStanley Winata 
11135a56e5dSStanley Winata Optional<unsigned>
11235a56e5dSStanley Winata spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) {
11335a56e5dSStanley Winata #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
11435a56e5dSStanley Winata   case storage:                                                                \
11535a56e5dSStanley Winata     return space;
11635a56e5dSStanley Winata 
11735a56e5dSStanley Winata   switch (storageClass) {
11835a56e5dSStanley Winata     OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
11935a56e5dSStanley Winata   default:
12035a56e5dSStanley Winata     break;
12135a56e5dSStanley Winata   }
12235a56e5dSStanley Winata   return llvm::None;
12335a56e5dSStanley Winata 
12435a56e5dSStanley Winata #undef STORAGE_SPACE_MAP_FN
12535a56e5dSStanley Winata }
12635a56e5dSStanley Winata 
12735a56e5dSStanley Winata #undef OPENCL_STORAGE_SPACE_MAP_LIST
12815135553SLei Zhang 
129713f85d5SLei Zhang //===----------------------------------------------------------------------===//
130713f85d5SLei Zhang // Type Converter
131713f85d5SLei Zhang //===----------------------------------------------------------------------===//
132713f85d5SLei Zhang 
133713f85d5SLei Zhang spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
134713f85d5SLei Zhang     const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
135713f85d5SLei Zhang     : memorySpaceMap(memorySpaceMap) {
136713f85d5SLei Zhang   // Pass through for all other types.
137713f85d5SLei Zhang   addConversion([](Type type) { return type; });
138713f85d5SLei Zhang 
139713f85d5SLei Zhang   addConversion([this](BaseMemRefType memRefType) -> Optional<Type> {
140713f85d5SLei Zhang     // Expect IntegerAttr memory spaces. The attribute can be missing for the
141713f85d5SLei Zhang     // case of memory space == 0.
142713f85d5SLei Zhang     Attribute spaceAttr = memRefType.getMemorySpace();
143713f85d5SLei Zhang     if (spaceAttr && !spaceAttr.isa<IntegerAttr>()) {
144713f85d5SLei Zhang       LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType
145b83d0d46SLei Zhang                               << " due to non-IntegerAttr memory space\n");
146713f85d5SLei Zhang       return llvm::None;
147713f85d5SLei Zhang     }
148713f85d5SLei Zhang 
149713f85d5SLei Zhang     unsigned space = memRefType.getMemorySpaceAsInt();
15015135553SLei Zhang     auto storage = this->memorySpaceMap(space);
15115135553SLei Zhang     if (!storage) {
152b83d0d46SLei Zhang       LLVM_DEBUG(llvm::dbgs()
153b83d0d46SLei Zhang                  << "cannot convert " << memRefType
154b83d0d46SLei Zhang                  << " due to being unable to find memory space in map\n");
155713f85d5SLei Zhang       return llvm::None;
156713f85d5SLei Zhang     }
157713f85d5SLei Zhang 
158713f85d5SLei Zhang     auto storageAttr =
15915135553SLei Zhang         spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
160713f85d5SLei Zhang     if (auto rankedType = memRefType.dyn_cast<MemRefType>()) {
161713f85d5SLei Zhang       return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
162713f85d5SLei Zhang                              rankedType.getLayout(), storageAttr);
163713f85d5SLei Zhang     }
164713f85d5SLei Zhang     return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr);
165713f85d5SLei Zhang   });
166713f85d5SLei Zhang 
167713f85d5SLei Zhang   addConversion([this](FunctionType type) {
168713f85d5SLei Zhang     SmallVector<Type> inputs, results;
169713f85d5SLei Zhang     inputs.reserve(type.getNumInputs());
170713f85d5SLei Zhang     results.reserve(type.getNumResults());
171713f85d5SLei Zhang     for (Type input : type.getInputs())
172713f85d5SLei Zhang       inputs.push_back(convertType(input));
173713f85d5SLei Zhang     for (Type result : type.getResults())
174713f85d5SLei Zhang       results.push_back(convertType(result));
175713f85d5SLei Zhang     return FunctionType::get(type.getContext(), inputs, results);
176713f85d5SLei Zhang   });
177713f85d5SLei Zhang }
178713f85d5SLei Zhang 
179713f85d5SLei Zhang //===----------------------------------------------------------------------===//
180713f85d5SLei Zhang // Conversion Target
181713f85d5SLei Zhang //===----------------------------------------------------------------------===//
182713f85d5SLei Zhang 
183713f85d5SLei Zhang /// Returns true if the given `type` is considered as legal for SPIR-V
184713f85d5SLei Zhang /// conversion.
185713f85d5SLei Zhang static bool isLegalType(Type type) {
186713f85d5SLei Zhang   if (auto memRefType = type.dyn_cast<BaseMemRefType>()) {
187713f85d5SLei Zhang     Attribute spaceAttr = memRefType.getMemorySpace();
188713f85d5SLei Zhang     return spaceAttr && spaceAttr.isa<spirv::StorageClassAttr>();
189713f85d5SLei Zhang   }
190713f85d5SLei Zhang   return true;
191713f85d5SLei Zhang }
192713f85d5SLei Zhang 
193713f85d5SLei Zhang /// Returns true if the given `attr` is considered as legal for SPIR-V
194713f85d5SLei Zhang /// conversion.
195713f85d5SLei Zhang static bool isLegalAttr(Attribute attr) {
196713f85d5SLei Zhang   if (auto typeAttr = attr.dyn_cast<TypeAttr>())
197713f85d5SLei Zhang     return isLegalType(typeAttr.getValue());
198713f85d5SLei Zhang   return true;
199713f85d5SLei Zhang }
200713f85d5SLei Zhang 
201713f85d5SLei Zhang /// Returns true if the given `op` is considered as legal for SPIR-V conversion.
202713f85d5SLei Zhang static bool isLegalOp(Operation *op) {
203b83d0d46SLei Zhang   if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
204b83d0d46SLei Zhang     return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
205b83d0d46SLei Zhang            llvm::all_of(funcOp.getResultTypes(), isLegalType);
206713f85d5SLei Zhang   }
207713f85d5SLei Zhang 
208713f85d5SLei Zhang   auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
209713f85d5SLei Zhang     return attr.getValue();
210713f85d5SLei Zhang   });
211713f85d5SLei Zhang 
212713f85d5SLei Zhang   return llvm::all_of(op->getOperandTypes(), isLegalType) &&
213713f85d5SLei Zhang          llvm::all_of(op->getResultTypes(), isLegalType) &&
214713f85d5SLei Zhang          llvm::all_of(attrs, isLegalAttr);
215713f85d5SLei Zhang }
216713f85d5SLei Zhang 
217713f85d5SLei Zhang std::unique_ptr<ConversionTarget>
218713f85d5SLei Zhang spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
219713f85d5SLei Zhang   auto target = std::make_unique<ConversionTarget>(context);
220713f85d5SLei Zhang   target->markUnknownOpDynamicallyLegal(isLegalOp);
221713f85d5SLei Zhang   return target;
222713f85d5SLei Zhang }
223713f85d5SLei Zhang 
224713f85d5SLei Zhang //===----------------------------------------------------------------------===//
225713f85d5SLei Zhang // Conversion Pattern
226713f85d5SLei Zhang //===----------------------------------------------------------------------===//
227713f85d5SLei Zhang 
228713f85d5SLei Zhang namespace {
229713f85d5SLei Zhang /// Converts any op that has operands/results/attributes with numeric MemRef
230713f85d5SLei Zhang /// memory spaces.
231713f85d5SLei Zhang struct MapMemRefStoragePattern final : public ConversionPattern {
232713f85d5SLei Zhang   MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
233713f85d5SLei Zhang       : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
234713f85d5SLei Zhang 
235713f85d5SLei Zhang   LogicalResult
236713f85d5SLei Zhang   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
237713f85d5SLei Zhang                   ConversionPatternRewriter &rewriter) const override;
238713f85d5SLei Zhang };
239713f85d5SLei Zhang } // namespace
240713f85d5SLei Zhang 
241713f85d5SLei Zhang LogicalResult MapMemRefStoragePattern::matchAndRewrite(
242713f85d5SLei Zhang     Operation *op, ArrayRef<Value> operands,
243713f85d5SLei Zhang     ConversionPatternRewriter &rewriter) const {
244713f85d5SLei Zhang   llvm::SmallVector<NamedAttribute, 4> newAttrs;
245713f85d5SLei Zhang   newAttrs.reserve(op->getAttrs().size());
246713f85d5SLei Zhang   for (auto attr : op->getAttrs()) {
247713f85d5SLei Zhang     if (auto typeAttr = attr.getValue().dyn_cast<TypeAttr>()) {
248713f85d5SLei Zhang       auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
249713f85d5SLei Zhang       newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
250713f85d5SLei Zhang     } else {
251713f85d5SLei Zhang       newAttrs.push_back(attr);
252713f85d5SLei Zhang     }
253713f85d5SLei Zhang   }
254713f85d5SLei Zhang 
255713f85d5SLei Zhang   llvm::SmallVector<Type, 4> newResults;
256713f85d5SLei Zhang   (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
257713f85d5SLei Zhang 
258713f85d5SLei Zhang   OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
259713f85d5SLei Zhang                        newResults, newAttrs, op->getSuccessors());
260713f85d5SLei Zhang 
261713f85d5SLei Zhang   for (Region &region : op->getRegions()) {
262713f85d5SLei Zhang     Region *newRegion = state.addRegion();
263713f85d5SLei Zhang     rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
264713f85d5SLei Zhang     TypeConverter::SignatureConversion result(newRegion->getNumArguments());
265713f85d5SLei Zhang     (void)getTypeConverter()->convertSignatureArgs(
266713f85d5SLei Zhang         newRegion->getArgumentTypes(), result);
267713f85d5SLei Zhang     rewriter.applySignatureConversion(newRegion, result);
268713f85d5SLei Zhang   }
269713f85d5SLei Zhang 
270713f85d5SLei Zhang   Operation *newOp = rewriter.create(state);
271713f85d5SLei Zhang   rewriter.replaceOp(op, newOp->getResults());
272713f85d5SLei Zhang   return success();
273713f85d5SLei Zhang }
274713f85d5SLei Zhang 
275713f85d5SLei Zhang void spirv::populateMemorySpaceToStorageClassPatterns(
276713f85d5SLei Zhang     spirv::MemorySpaceToStorageClassConverter &typeConverter,
277713f85d5SLei Zhang     RewritePatternSet &patterns) {
278713f85d5SLei Zhang   patterns.add<MapMemRefStoragePattern>(patterns.getContext(), typeConverter);
279713f85d5SLei Zhang }
280713f85d5SLei Zhang 
281713f85d5SLei Zhang //===----------------------------------------------------------------------===//
282713f85d5SLei Zhang // Conversion Pass
283713f85d5SLei Zhang //===----------------------------------------------------------------------===//
284713f85d5SLei Zhang 
285713f85d5SLei Zhang namespace {
286713f85d5SLei Zhang class MapMemRefStorageClassPass final
287*039b969bSMichele Scuttari     : public MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
288713f85d5SLei Zhang public:
289b83d0d46SLei Zhang   explicit MapMemRefStorageClassPass() {
29015135553SLei Zhang     memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
291b83d0d46SLei Zhang   }
292713f85d5SLei Zhang   explicit MapMemRefStorageClassPass(
293713f85d5SLei Zhang       const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
294713f85d5SLei Zhang       : memorySpaceMap(memorySpaceMap) {}
295713f85d5SLei Zhang 
296713f85d5SLei Zhang   LogicalResult initializeOptions(StringRef options) override;
297713f85d5SLei Zhang 
298713f85d5SLei Zhang   void runOnOperation() override;
299713f85d5SLei Zhang 
300713f85d5SLei Zhang private:
301713f85d5SLei Zhang   spirv::MemorySpaceToStorageClassMap memorySpaceMap;
302713f85d5SLei Zhang };
303713f85d5SLei Zhang } // namespace
304713f85d5SLei Zhang 
305713f85d5SLei Zhang LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) {
306713f85d5SLei Zhang   if (failed(Pass::initializeOptions(options)))
307713f85d5SLei Zhang     return failure();
308713f85d5SLei Zhang 
30935a56e5dSStanley Winata   if (clientAPI == "opencl") {
31035a56e5dSStanley Winata     memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
31135a56e5dSStanley Winata   }
31235a56e5dSStanley Winata 
31335a56e5dSStanley Winata   if (clientAPI != "vulkan" && clientAPI != "opencl")
314713f85d5SLei Zhang     return failure();
315713f85d5SLei Zhang 
316713f85d5SLei Zhang   return success();
317713f85d5SLei Zhang }
318713f85d5SLei Zhang 
319713f85d5SLei Zhang void MapMemRefStorageClassPass::runOnOperation() {
320713f85d5SLei Zhang   MLIRContext *context = &getContext();
321b83d0d46SLei Zhang   Operation *op = getOperation();
322713f85d5SLei Zhang 
323713f85d5SLei Zhang   auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
324713f85d5SLei Zhang   spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
325713f85d5SLei Zhang 
326713f85d5SLei Zhang   RewritePatternSet patterns(context);
327713f85d5SLei Zhang   spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
328713f85d5SLei Zhang 
329b83d0d46SLei Zhang   if (failed(applyFullConversion(op, *target, std::move(patterns))))
330713f85d5SLei Zhang     return signalPassFailure();
331713f85d5SLei Zhang }
332713f85d5SLei Zhang 
333b83d0d46SLei Zhang std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
334713f85d5SLei Zhang   return std::make_unique<MapMemRefStorageClassPass>();
335713f85d5SLei Zhang }
336