1 //===- TensorToLinalgPass.cpp - Tensor to Linalg Passes -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert Tensor dialect to Linalg dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" 14 15 #include "mlir/Conversion/TensorToLinalg/TensorToLinalg.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/Tensor/IR/Tensor.h" 19 20 namespace mlir { 21 #define GEN_PASS_DEF_CONVERTTENSORTOLINALG 22 #include "mlir/Conversion/Passes.h.inc" 23 } // namespace mlir 24 25 using namespace mlir; 26 27 namespace { 28 /// A pass converting MLIR Tensor operations into the Linalg dialect. 29 class ConvertTensorToLinalgPass 30 : public impl::ConvertTensorToLinalgBase<ConvertTensorToLinalgPass> { runOnOperation()31 void runOnOperation() override { 32 auto &context = getContext(); 33 ConversionTarget target(context); 34 target 35 .addLegalDialect<mlir::arith::ArithDialect, mlir::linalg::LinalgDialect, 36 mlir::tensor::TensorDialect>(); 37 target.addIllegalOp<mlir::tensor::PadOp>(); 38 39 RewritePatternSet patterns(&context); 40 populateTensorToLinalgPatterns(patterns); 41 42 if (failed(applyPartialConversion(getOperation(), target, 43 std::move(patterns)))) 44 return signalPassFailure(); 45 } 46 }; 47 } // namespace 48 49 std::unique_ptr<OperationPass<ModuleOp>> createConvertTensorToLinalgPass()50mlir::createConvertTensorToLinalgPass() { 51 return std::make_unique<ConvertTensorToLinalgPass>(); 52 } 53