11dce51b8STres Popp //===- TensorToLinalgPass.cpp - Tensor to Linalg Passes -------------------===// 21dce51b8STres Popp // 31dce51b8STres Popp // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 41dce51b8STres Popp // See https://llvm.org/LICENSE.txt for license information. 51dce51b8STres Popp // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 61dce51b8STres Popp // 71dce51b8STres Popp //===----------------------------------------------------------------------===// 81dce51b8STres Popp // 91dce51b8STres Popp // This file implements a pass to convert Tensor dialect to Linalg dialect. 101dce51b8STres Popp // 111dce51b8STres Popp //===----------------------------------------------------------------------===// 121dce51b8STres Popp 131dce51b8STres Popp #include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" 1467d0d7acSMichele Scuttari 151dce51b8STres Popp #include "mlir/Conversion/TensorToLinalg/TensorToLinalg.h" 16*abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 171dce51b8STres Popp #include "mlir/Dialect/Linalg/IR/Linalg.h" 181dce51b8STres Popp #include "mlir/Dialect/Tensor/IR/Tensor.h" 191dce51b8STres Popp 2067d0d7acSMichele Scuttari namespace mlir { 2167d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTTENSORTOLINALG 2267d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 2367d0d7acSMichele Scuttari } // namespace mlir 2467d0d7acSMichele Scuttari 251dce51b8STres Popp using namespace mlir; 261dce51b8STres Popp 271dce51b8STres Popp namespace { 281dce51b8STres Popp /// A pass converting MLIR Tensor operations into the Linalg dialect. 291dce51b8STres Popp class ConvertTensorToLinalgPass 3067d0d7acSMichele Scuttari : public impl::ConvertTensorToLinalgBase<ConvertTensorToLinalgPass> { runOnOperation()311dce51b8STres Popp void runOnOperation() override { 321dce51b8STres Popp auto &context = getContext(); 331dce51b8STres Popp ConversionTarget target(context); 34*abc362a1SJakub Kuderski target 35*abc362a1SJakub Kuderski .addLegalDialect<mlir::arith::ArithDialect, mlir::linalg::LinalgDialect, 361dce51b8STres Popp mlir::tensor::TensorDialect>(); 371dce51b8STres Popp target.addIllegalOp<mlir::tensor::PadOp>(); 381dce51b8STres Popp 391dce51b8STres Popp RewritePatternSet patterns(&context); 401dce51b8STres Popp populateTensorToLinalgPatterns(patterns); 411dce51b8STres Popp 421dce51b8STres Popp if (failed(applyPartialConversion(getOperation(), target, 431dce51b8STres Popp std::move(patterns)))) 441dce51b8STres Popp return signalPassFailure(); 451dce51b8STres Popp } 461dce51b8STres Popp }; 471dce51b8STres Popp } // namespace 481dce51b8STres Popp 491dce51b8STres Popp std::unique_ptr<OperationPass<ModuleOp>> createConvertTensorToLinalgPass()501dce51b8STres Poppmlir::createConvertTensorToLinalgPass() { 511dce51b8STres Popp return std::make_unique<ConvertTensorToLinalgPass>(); 521dce51b8STres Popp } 53