1 //===- VectorToSPIRVPass.cpp - Vector to SPIR-V Passes --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert Vector dialect to SPIRV dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h" 14 15 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "mlir/Dialect/UB/IR/UBOps.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 namespace mlir { 23 #define GEN_PASS_DEF_CONVERTVECTORTOSPIRV 24 #include "mlir/Conversion/Passes.h.inc" 25 } // namespace mlir 26 27 using namespace mlir; 28 29 namespace { 30 struct ConvertVectorToSPIRVPass 31 : public impl::ConvertVectorToSPIRVBase<ConvertVectorToSPIRVPass> { 32 void runOnOperation() override; 33 }; 34 } // namespace 35 36 void ConvertVectorToSPIRVPass::runOnOperation() { 37 MLIRContext *context = &getContext(); 38 Operation *op = getOperation(); 39 40 auto targetAttr = spirv::lookupTargetEnvOrDefault(op); 41 std::unique_ptr<ConversionTarget> target = 42 SPIRVConversionTarget::get(targetAttr); 43 44 SPIRVTypeConverter typeConverter(targetAttr); 45 46 // Use UnrealizedConversionCast as the bridge so that we don't need to pull in 47 // patterns for other dialects. 48 target->addLegalOp<UnrealizedConversionCastOp>(); 49 50 RewritePatternSet patterns(context); 51 populateVectorToSPIRVPatterns(typeConverter, patterns); 52 53 if (failed(applyPartialConversion(op, *target, std::move(patterns)))) 54 return signalPassFailure(); 55 } 56 57 std::unique_ptr<OperationPass<>> mlir::createConvertVectorToSPIRVPass() { 58 return std::make_unique<ConvertVectorToSPIRVPass>(); 59 } 60