1 //===- DialectExtension.cpp - Linalg transform dialect extension ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" 10 #include "mlir/Dialect/Affine/IR/AffineOps.h" 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 13 #include "mlir/Dialect/Index/IR/IndexDialect.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h" 16 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/SCF/IR/SCF.h" 19 #include "mlir/Dialect/Tensor/IR/Tensor.h" 20 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 21 #include "mlir/Dialect/Transform/IR/TransformOps.h" 22 #include "mlir/Dialect/Vector/IR/VectorOps.h" 23 24 using namespace mlir; 25 26 namespace { 27 /// Registers new ops and declares PDL as dependent dialect since the 28 /// additional ops are using PDL types for operands and results. 29 class LinalgTransformDialectExtension 30 : public transform::TransformDialectExtension< 31 LinalgTransformDialectExtension> { 32 public: 33 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LinalgTransformDialectExtension) 34 35 using Base::Base; 36 37 void init() { 38 declareDependentDialect<linalg::LinalgDialect>(); 39 40 declareGeneratedDialect<affine::AffineDialect>(); 41 declareGeneratedDialect<arith::ArithDialect>(); 42 declareGeneratedDialect<index::IndexDialect>(); 43 declareGeneratedDialect<scf::SCFDialect>(); 44 declareGeneratedDialect<vector::VectorDialect>(); 45 declareGeneratedDialect<gpu::GPUDialect>(); 46 declareGeneratedDialect<tensor::TensorDialect>(); 47 48 registerTransformOps< 49 #define GET_OP_LIST 50 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 51 >(); 52 registerTransformOps< 53 #define GET_OP_LIST 54 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc" 55 >(); 56 } 57 }; 58 } // namespace 59 60 void mlir::linalg::registerTransformDialectExtension( 61 DialectRegistry ®istry) { 62 registry.addExtensions<LinalgTransformDialectExtension>(); 63 } 64