1 //===- TensorToSPIRVPass.cpp - Tensor to SPIR-V Passes ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert Tensor dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h" 14 15 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" 16 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" 17 #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 20 21 namespace mlir { 22 #define GEN_PASS_DEF_CONVERTTENSORTOSPIRV 23 #include "mlir/Conversion/Passes.h.inc" 24 } // namespace mlir 25 26 using namespace mlir; 27 28 namespace { 29 /// A pass converting MLIR Tensor operations into the SPIR-V dialect. 30 class ConvertTensorToSPIRVPass 31 : public impl::ConvertTensorToSPIRVBase<ConvertTensorToSPIRVPass> { runOnOperation()32 void runOnOperation() override { 33 MLIRContext *context = &getContext(); 34 Operation *op = getOperation(); 35 36 auto targetAttr = spirv::lookupTargetEnvOrDefault(op); 37 std::unique_ptr<ConversionTarget> target = 38 SPIRVConversionTarget::get(targetAttr); 39 40 SPIRVConversionOptions options; 41 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; 42 SPIRVTypeConverter typeConverter(targetAttr, options); 43 44 RewritePatternSet patterns(context); 45 arith::populateArithToSPIRVPatterns(typeConverter, patterns); 46 populateFuncToSPIRVPatterns(typeConverter, patterns); 47 populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64, 48 patterns); 49 populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); 50 51 if (failed(applyPartialConversion(op, *target, std::move(patterns)))) 52 return signalPassFailure(); 53 } 54 }; 55 } // namespace 56 createConvertTensorToSPIRVPass()57std::unique_ptr<OperationPass<>> mlir::createConvertTensorToSPIRVPass() { 58 return std::make_unique<ConvertTensorToSPIRVPass>(); 59 } 60