1 //===- MathToSPIRVPass.cpp - Math to SPIR-V Passes ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert standard dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h" 14 15 #include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "mlir/Pass/Pass.h" 19 20 namespace mlir { 21 #define GEN_PASS_DEF_CONVERTMATHTOSPIRV 22 #include "mlir/Conversion/Passes.h.inc" 23 } // namespace mlir 24 25 using namespace mlir; 26 27 namespace { 28 /// A pass converting MLIR Math operations into the SPIR-V dialect. 29 class ConvertMathToSPIRVPass 30 : public impl::ConvertMathToSPIRVBase<ConvertMathToSPIRVPass> { 31 void runOnOperation() override; 32 }; 33 } // namespace 34 runOnOperation()35void ConvertMathToSPIRVPass::runOnOperation() { 36 MLIRContext *context = &getContext(); 37 Operation *op = getOperation(); 38 39 auto targetAttr = spirv::lookupTargetEnvOrDefault(op); 40 std::unique_ptr<ConversionTarget> target = 41 SPIRVConversionTarget::get(targetAttr); 42 43 SPIRVTypeConverter typeConverter(targetAttr); 44 45 // Use UnrealizedConversionCast as the bridge so that we don't need to pull 46 // in patterns for other dialects. 47 target->addLegalOp<UnrealizedConversionCastOp>(); 48 49 RewritePatternSet patterns(context); 50 populateMathToSPIRVPatterns(typeConverter, patterns); 51 52 if (failed(applyPartialConversion(op, *target, std::move(patterns)))) 53 return signalPassFailure(); 54 } 55 createConvertMathToSPIRVPass()56std::unique_ptr<OperationPass<>> mlir::createConvertMathToSPIRVPass() { 57 return std::make_unique<ConvertMathToSPIRVPass>(); 58 } 59