1713f85d5SLei Zhang //===- MapMemRefStorageCLassPass.cpp --------------------------------------===// 2713f85d5SLei Zhang // 3713f85d5SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4713f85d5SLei Zhang // See https://llvm.org/LICENSE.txt for license information. 5713f85d5SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6713f85d5SLei Zhang // 7713f85d5SLei Zhang //===----------------------------------------------------------------------===// 8713f85d5SLei Zhang // 9713f85d5SLei Zhang // This file implements a pass to map numeric MemRef memory spaces to 10713f85d5SLei Zhang // symbolic ones defined in the SPIR-V specification. 11713f85d5SLei Zhang // 12713f85d5SLei Zhang //===----------------------------------------------------------------------===// 13713f85d5SLei Zhang 14*039b969bSMichele Scuttari #include "../PassDetail.h" 152be8af8fSMichele Scuttari #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" 16*039b969bSMichele Scuttari #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h" 17a29fffc4SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 18713f85d5SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 1915135553SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 20713f85d5SLei Zhang #include "mlir/IR/BuiltinTypes.h" 21b83d0d46SLei Zhang #include "mlir/IR/FunctionInterfaces.h" 22713f85d5SLei Zhang #include "mlir/Transforms/DialectConversion.h" 23713f85d5SLei Zhang #include "llvm/ADT/StringExtras.h" 24713f85d5SLei Zhang #include "llvm/Support/Debug.h" 25713f85d5SLei Zhang 26713f85d5SLei Zhang #define DEBUG_TYPE "mlir-map-memref-storage-class" 27713f85d5SLei Zhang 28713f85d5SLei Zhang using namespace mlir; 29713f85d5SLei Zhang 30713f85d5SLei Zhang //===----------------------------------------------------------------------===// 311f7544a6SLei Zhang // Mappings 32713f85d5SLei Zhang //===----------------------------------------------------------------------===// 33713f85d5SLei Zhang 341f7544a6SLei Zhang /// Mapping between SPIR-V storage classes to memref memory spaces. 351f7544a6SLei Zhang /// 361f7544a6SLei Zhang /// Note: memref does not have a defined semantics for each memory space; it 371f7544a6SLei Zhang /// depends on the context where it is used. There are no particular reasons 381f7544a6SLei Zhang /// behind the number assignments; we try to follow NVVM conventions and largely 391f7544a6SLei Zhang /// give common storage classes a smaller number. 4035a56e5dSStanley Winata #define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN) \ 411f7544a6SLei Zhang MAP_FN(spirv::StorageClass::StorageBuffer, 0) \ 421f7544a6SLei Zhang MAP_FN(spirv::StorageClass::Generic, 1) \ 431f7544a6SLei Zhang MAP_FN(spirv::StorageClass::Workgroup, 3) \ 441f7544a6SLei Zhang MAP_FN(spirv::StorageClass::Uniform, 4) \ 451f7544a6SLei Zhang MAP_FN(spirv::StorageClass::Private, 5) \ 461f7544a6SLei Zhang MAP_FN(spirv::StorageClass::Function, 6) \ 471f7544a6SLei Zhang MAP_FN(spirv::StorageClass::PushConstant, 7) \ 481f7544a6SLei Zhang MAP_FN(spirv::StorageClass::UniformConstant, 8) \ 491f7544a6SLei Zhang MAP_FN(spirv::StorageClass::Input, 9) \ 5015135553SLei Zhang MAP_FN(spirv::StorageClass::Output, 10) 511f7544a6SLei Zhang 5215135553SLei Zhang Optional<spirv::StorageClass> 5315135553SLei Zhang spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) { 5415135553SLei Zhang #define STORAGE_SPACE_MAP_FN(storage, space) \ 5515135553SLei Zhang case space: \ 5615135553SLei Zhang return storage; 571f7544a6SLei Zhang 5815135553SLei Zhang switch (memorySpace) { 5935a56e5dSStanley Winata VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) 6015135553SLei Zhang default: 6115135553SLei Zhang break; 6215135553SLei Zhang } 6315135553SLei Zhang return llvm::None; 641f7544a6SLei Zhang 651f7544a6SLei Zhang #undef STORAGE_SPACE_MAP_FN 66713f85d5SLei Zhang } 67713f85d5SLei Zhang 6815135553SLei Zhang Optional<unsigned> 6915135553SLei Zhang spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) { 7015135553SLei Zhang #define STORAGE_SPACE_MAP_FN(storage, space) \ 7115135553SLei Zhang case storage: \ 7215135553SLei Zhang return space; 7315135553SLei Zhang 7415135553SLei Zhang switch (storageClass) { 7535a56e5dSStanley Winata VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) 7615135553SLei Zhang default: 7715135553SLei Zhang break; 7815135553SLei Zhang } 7915135553SLei Zhang return llvm::None; 8015135553SLei Zhang 8115135553SLei Zhang #undef STORAGE_SPACE_MAP_FN 8215135553SLei Zhang } 8315135553SLei Zhang 8435a56e5dSStanley Winata #undef VULKAN_STORAGE_SPACE_MAP_LIST 8535a56e5dSStanley Winata 8635a56e5dSStanley Winata #define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN) \ 8735a56e5dSStanley Winata MAP_FN(spirv::StorageClass::CrossWorkgroup, 0) \ 8835a56e5dSStanley Winata MAP_FN(spirv::StorageClass::Generic, 1) \ 8935a56e5dSStanley Winata MAP_FN(spirv::StorageClass::Workgroup, 3) \ 9035a56e5dSStanley Winata MAP_FN(spirv::StorageClass::UniformConstant, 4) \ 9135a56e5dSStanley Winata MAP_FN(spirv::StorageClass::Private, 5) \ 9235a56e5dSStanley Winata MAP_FN(spirv::StorageClass::Function, 6) \ 9335a56e5dSStanley Winata MAP_FN(spirv::StorageClass::Image, 7) 9435a56e5dSStanley Winata 9535a56e5dSStanley Winata Optional<spirv::StorageClass> 9635a56e5dSStanley Winata spirv::mapMemorySpaceToOpenCLStorageClass(unsigned memorySpace) { 9735a56e5dSStanley Winata #define STORAGE_SPACE_MAP_FN(storage, space) \ 9835a56e5dSStanley Winata case space: \ 9935a56e5dSStanley Winata return storage; 10035a56e5dSStanley Winata 10135a56e5dSStanley Winata switch (memorySpace) { 10235a56e5dSStanley Winata OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) 10335a56e5dSStanley Winata default: 10435a56e5dSStanley Winata break; 10535a56e5dSStanley Winata } 10635a56e5dSStanley Winata return llvm::None; 10735a56e5dSStanley Winata 10835a56e5dSStanley Winata #undef STORAGE_SPACE_MAP_FN 10935a56e5dSStanley Winata } 11035a56e5dSStanley Winata 11135a56e5dSStanley Winata Optional<unsigned> 11235a56e5dSStanley Winata spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) { 11335a56e5dSStanley Winata #define STORAGE_SPACE_MAP_FN(storage, space) \ 11435a56e5dSStanley Winata case storage: \ 11535a56e5dSStanley Winata return space; 11635a56e5dSStanley Winata 11735a56e5dSStanley Winata switch (storageClass) { 11835a56e5dSStanley Winata OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) 11935a56e5dSStanley Winata default: 12035a56e5dSStanley Winata break; 12135a56e5dSStanley Winata } 12235a56e5dSStanley Winata return llvm::None; 12335a56e5dSStanley Winata 12435a56e5dSStanley Winata #undef STORAGE_SPACE_MAP_FN 12535a56e5dSStanley Winata } 12635a56e5dSStanley Winata 12735a56e5dSStanley Winata #undef OPENCL_STORAGE_SPACE_MAP_LIST 12815135553SLei Zhang 129713f85d5SLei Zhang //===----------------------------------------------------------------------===// 130713f85d5SLei Zhang // Type Converter 131713f85d5SLei Zhang //===----------------------------------------------------------------------===// 132713f85d5SLei Zhang 133713f85d5SLei Zhang spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter( 134713f85d5SLei Zhang const spirv::MemorySpaceToStorageClassMap &memorySpaceMap) 135713f85d5SLei Zhang : memorySpaceMap(memorySpaceMap) { 136713f85d5SLei Zhang // Pass through for all other types. 137713f85d5SLei Zhang addConversion([](Type type) { return type; }); 138713f85d5SLei Zhang 139713f85d5SLei Zhang addConversion([this](BaseMemRefType memRefType) -> Optional<Type> { 140713f85d5SLei Zhang // Expect IntegerAttr memory spaces. The attribute can be missing for the 141713f85d5SLei Zhang // case of memory space == 0. 142713f85d5SLei Zhang Attribute spaceAttr = memRefType.getMemorySpace(); 143713f85d5SLei Zhang if (spaceAttr && !spaceAttr.isa<IntegerAttr>()) { 144713f85d5SLei Zhang LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType 145b83d0d46SLei Zhang << " due to non-IntegerAttr memory space\n"); 146713f85d5SLei Zhang return llvm::None; 147713f85d5SLei Zhang } 148713f85d5SLei Zhang 149713f85d5SLei Zhang unsigned space = memRefType.getMemorySpaceAsInt(); 15015135553SLei Zhang auto storage = this->memorySpaceMap(space); 15115135553SLei Zhang if (!storage) { 152b83d0d46SLei Zhang LLVM_DEBUG(llvm::dbgs() 153b83d0d46SLei Zhang << "cannot convert " << memRefType 154b83d0d46SLei Zhang << " due to being unable to find memory space in map\n"); 155713f85d5SLei Zhang return llvm::None; 156713f85d5SLei Zhang } 157713f85d5SLei Zhang 158713f85d5SLei Zhang auto storageAttr = 15915135553SLei Zhang spirv::StorageClassAttr::get(memRefType.getContext(), *storage); 160713f85d5SLei Zhang if (auto rankedType = memRefType.dyn_cast<MemRefType>()) { 161713f85d5SLei Zhang return MemRefType::get(memRefType.getShape(), memRefType.getElementType(), 162713f85d5SLei Zhang rankedType.getLayout(), storageAttr); 163713f85d5SLei Zhang } 164713f85d5SLei Zhang return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr); 165713f85d5SLei Zhang }); 166713f85d5SLei Zhang 167713f85d5SLei Zhang addConversion([this](FunctionType type) { 168713f85d5SLei Zhang SmallVector<Type> inputs, results; 169713f85d5SLei Zhang inputs.reserve(type.getNumInputs()); 170713f85d5SLei Zhang results.reserve(type.getNumResults()); 171713f85d5SLei Zhang for (Type input : type.getInputs()) 172713f85d5SLei Zhang inputs.push_back(convertType(input)); 173713f85d5SLei Zhang for (Type result : type.getResults()) 174713f85d5SLei Zhang results.push_back(convertType(result)); 175713f85d5SLei Zhang return FunctionType::get(type.getContext(), inputs, results); 176713f85d5SLei Zhang }); 177713f85d5SLei Zhang } 178713f85d5SLei Zhang 179713f85d5SLei Zhang //===----------------------------------------------------------------------===// 180713f85d5SLei Zhang // Conversion Target 181713f85d5SLei Zhang //===----------------------------------------------------------------------===// 182713f85d5SLei Zhang 183713f85d5SLei Zhang /// Returns true if the given `type` is considered as legal for SPIR-V 184713f85d5SLei Zhang /// conversion. 185713f85d5SLei Zhang static bool isLegalType(Type type) { 186713f85d5SLei Zhang if (auto memRefType = type.dyn_cast<BaseMemRefType>()) { 187713f85d5SLei Zhang Attribute spaceAttr = memRefType.getMemorySpace(); 188713f85d5SLei Zhang return spaceAttr && spaceAttr.isa<spirv::StorageClassAttr>(); 189713f85d5SLei Zhang } 190713f85d5SLei Zhang return true; 191713f85d5SLei Zhang } 192713f85d5SLei Zhang 193713f85d5SLei Zhang /// Returns true if the given `attr` is considered as legal for SPIR-V 194713f85d5SLei Zhang /// conversion. 195713f85d5SLei Zhang static bool isLegalAttr(Attribute attr) { 196713f85d5SLei Zhang if (auto typeAttr = attr.dyn_cast<TypeAttr>()) 197713f85d5SLei Zhang return isLegalType(typeAttr.getValue()); 198713f85d5SLei Zhang return true; 199713f85d5SLei Zhang } 200713f85d5SLei Zhang 201713f85d5SLei Zhang /// Returns true if the given `op` is considered as legal for SPIR-V conversion. 202713f85d5SLei Zhang static bool isLegalOp(Operation *op) { 203b83d0d46SLei Zhang if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) { 204b83d0d46SLei Zhang return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) && 205b83d0d46SLei Zhang llvm::all_of(funcOp.getResultTypes(), isLegalType); 206713f85d5SLei Zhang } 207713f85d5SLei Zhang 208713f85d5SLei Zhang auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) { 209713f85d5SLei Zhang return attr.getValue(); 210713f85d5SLei Zhang }); 211713f85d5SLei Zhang 212713f85d5SLei Zhang return llvm::all_of(op->getOperandTypes(), isLegalType) && 213713f85d5SLei Zhang llvm::all_of(op->getResultTypes(), isLegalType) && 214713f85d5SLei Zhang llvm::all_of(attrs, isLegalAttr); 215713f85d5SLei Zhang } 216713f85d5SLei Zhang 217713f85d5SLei Zhang std::unique_ptr<ConversionTarget> 218713f85d5SLei Zhang spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) { 219713f85d5SLei Zhang auto target = std::make_unique<ConversionTarget>(context); 220713f85d5SLei Zhang target->markUnknownOpDynamicallyLegal(isLegalOp); 221713f85d5SLei Zhang return target; 222713f85d5SLei Zhang } 223713f85d5SLei Zhang 224713f85d5SLei Zhang //===----------------------------------------------------------------------===// 225713f85d5SLei Zhang // Conversion Pattern 226713f85d5SLei Zhang //===----------------------------------------------------------------------===// 227713f85d5SLei Zhang 228713f85d5SLei Zhang namespace { 229713f85d5SLei Zhang /// Converts any op that has operands/results/attributes with numeric MemRef 230713f85d5SLei Zhang /// memory spaces. 231713f85d5SLei Zhang struct MapMemRefStoragePattern final : public ConversionPattern { 232713f85d5SLei Zhang MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter) 233713f85d5SLei Zhang : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {} 234713f85d5SLei Zhang 235713f85d5SLei Zhang LogicalResult 236713f85d5SLei Zhang matchAndRewrite(Operation *op, ArrayRef<Value> operands, 237713f85d5SLei Zhang ConversionPatternRewriter &rewriter) const override; 238713f85d5SLei Zhang }; 239713f85d5SLei Zhang } // namespace 240713f85d5SLei Zhang 241713f85d5SLei Zhang LogicalResult MapMemRefStoragePattern::matchAndRewrite( 242713f85d5SLei Zhang Operation *op, ArrayRef<Value> operands, 243713f85d5SLei Zhang ConversionPatternRewriter &rewriter) const { 244713f85d5SLei Zhang llvm::SmallVector<NamedAttribute, 4> newAttrs; 245713f85d5SLei Zhang newAttrs.reserve(op->getAttrs().size()); 246713f85d5SLei Zhang for (auto attr : op->getAttrs()) { 247713f85d5SLei Zhang if (auto typeAttr = attr.getValue().dyn_cast<TypeAttr>()) { 248713f85d5SLei Zhang auto newAttr = getTypeConverter()->convertType(typeAttr.getValue()); 249713f85d5SLei Zhang newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr)); 250713f85d5SLei Zhang } else { 251713f85d5SLei Zhang newAttrs.push_back(attr); 252713f85d5SLei Zhang } 253713f85d5SLei Zhang } 254713f85d5SLei Zhang 255713f85d5SLei Zhang llvm::SmallVector<Type, 4> newResults; 256713f85d5SLei Zhang (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults); 257713f85d5SLei Zhang 258713f85d5SLei Zhang OperationState state(op->getLoc(), op->getName().getStringRef(), operands, 259713f85d5SLei Zhang newResults, newAttrs, op->getSuccessors()); 260713f85d5SLei Zhang 261713f85d5SLei Zhang for (Region ®ion : op->getRegions()) { 262713f85d5SLei Zhang Region *newRegion = state.addRegion(); 263713f85d5SLei Zhang rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); 264713f85d5SLei Zhang TypeConverter::SignatureConversion result(newRegion->getNumArguments()); 265713f85d5SLei Zhang (void)getTypeConverter()->convertSignatureArgs( 266713f85d5SLei Zhang newRegion->getArgumentTypes(), result); 267713f85d5SLei Zhang rewriter.applySignatureConversion(newRegion, result); 268713f85d5SLei Zhang } 269713f85d5SLei Zhang 270713f85d5SLei Zhang Operation *newOp = rewriter.create(state); 271713f85d5SLei Zhang rewriter.replaceOp(op, newOp->getResults()); 272713f85d5SLei Zhang return success(); 273713f85d5SLei Zhang } 274713f85d5SLei Zhang 275713f85d5SLei Zhang void spirv::populateMemorySpaceToStorageClassPatterns( 276713f85d5SLei Zhang spirv::MemorySpaceToStorageClassConverter &typeConverter, 277713f85d5SLei Zhang RewritePatternSet &patterns) { 278713f85d5SLei Zhang patterns.add<MapMemRefStoragePattern>(patterns.getContext(), typeConverter); 279713f85d5SLei Zhang } 280713f85d5SLei Zhang 281713f85d5SLei Zhang //===----------------------------------------------------------------------===// 282713f85d5SLei Zhang // Conversion Pass 283713f85d5SLei Zhang //===----------------------------------------------------------------------===// 284713f85d5SLei Zhang 285713f85d5SLei Zhang namespace { 286713f85d5SLei Zhang class MapMemRefStorageClassPass final 287*039b969bSMichele Scuttari : public MapMemRefStorageClassBase<MapMemRefStorageClassPass> { 288713f85d5SLei Zhang public: 289b83d0d46SLei Zhang explicit MapMemRefStorageClassPass() { 29015135553SLei Zhang memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass; 291b83d0d46SLei Zhang } 292713f85d5SLei Zhang explicit MapMemRefStorageClassPass( 293713f85d5SLei Zhang const spirv::MemorySpaceToStorageClassMap &memorySpaceMap) 294713f85d5SLei Zhang : memorySpaceMap(memorySpaceMap) {} 295713f85d5SLei Zhang 296713f85d5SLei Zhang LogicalResult initializeOptions(StringRef options) override; 297713f85d5SLei Zhang 298713f85d5SLei Zhang void runOnOperation() override; 299713f85d5SLei Zhang 300713f85d5SLei Zhang private: 301713f85d5SLei Zhang spirv::MemorySpaceToStorageClassMap memorySpaceMap; 302713f85d5SLei Zhang }; 303713f85d5SLei Zhang } // namespace 304713f85d5SLei Zhang 305713f85d5SLei Zhang LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) { 306713f85d5SLei Zhang if (failed(Pass::initializeOptions(options))) 307713f85d5SLei Zhang return failure(); 308713f85d5SLei Zhang 30935a56e5dSStanley Winata if (clientAPI == "opencl") { 31035a56e5dSStanley Winata memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass; 31135a56e5dSStanley Winata } 31235a56e5dSStanley Winata 31335a56e5dSStanley Winata if (clientAPI != "vulkan" && clientAPI != "opencl") 314713f85d5SLei Zhang return failure(); 315713f85d5SLei Zhang 316713f85d5SLei Zhang return success(); 317713f85d5SLei Zhang } 318713f85d5SLei Zhang 319713f85d5SLei Zhang void MapMemRefStorageClassPass::runOnOperation() { 320713f85d5SLei Zhang MLIRContext *context = &getContext(); 321b83d0d46SLei Zhang Operation *op = getOperation(); 322713f85d5SLei Zhang 323713f85d5SLei Zhang auto target = spirv::getMemorySpaceToStorageClassTarget(*context); 324713f85d5SLei Zhang spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap); 325713f85d5SLei Zhang 326713f85d5SLei Zhang RewritePatternSet patterns(context); 327713f85d5SLei Zhang spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns); 328713f85d5SLei Zhang 329b83d0d46SLei Zhang if (failed(applyFullConversion(op, *target, std::move(patterns)))) 330713f85d5SLei Zhang return signalPassFailure(); 331713f85d5SLei Zhang } 332713f85d5SLei Zhang 333b83d0d46SLei Zhang std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() { 334713f85d5SLei Zhang return std::make_unique<MapMemRefStorageClassPass>(); 335713f85d5SLei Zhang } 336