1 //===- MemRefToSPIRVPass.cpp - MemRef to SPIR-V Passes ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert standard dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h" 14 15 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 19 namespace mlir { 20 #define GEN_PASS_DEF_CONVERTMEMREFTOSPIRV 21 #include "mlir/Conversion/Passes.h.inc" 22 } // namespace mlir 23 24 using namespace mlir; 25 26 namespace { 27 /// A pass converting MLIR MemRef operations into the SPIR-V dialect. 28 class ConvertMemRefToSPIRVPass 29 : public impl::ConvertMemRefToSPIRVBase<ConvertMemRefToSPIRVPass> { 30 void runOnOperation() override; 31 }; 32 } // namespace 33 runOnOperation()34void ConvertMemRefToSPIRVPass::runOnOperation() { 35 MLIRContext *context = &getContext(); 36 Operation *op = getOperation(); 37 38 auto targetAttr = spirv::lookupTargetEnvOrDefault(op); 39 std::unique_ptr<ConversionTarget> target = 40 SPIRVConversionTarget::get(targetAttr); 41 42 SPIRVConversionOptions options; 43 options.boolNumBits = this->boolNumBits; 44 options.use64bitIndex = this->use64bitIndex; 45 SPIRVTypeConverter typeConverter(targetAttr, options); 46 47 // Use UnrealizedConversionCast as the bridge so that we don't need to pull in 48 // patterns for other dialects. 49 target->addLegalOp<UnrealizedConversionCastOp>(); 50 51 RewritePatternSet patterns(context); 52 populateMemRefToSPIRVPatterns(typeConverter, patterns); 53 54 if (failed(applyPartialConversion(op, *target, std::move(patterns)))) 55 return signalPassFailure(); 56 } 57 createConvertMemRefToSPIRVPass()58std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToSPIRVPass() { 59 return std::make_unique<ConvertMemRefToSPIRVPass>(); 60 } 61