xref: /llvm-project/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
1 //===- DialectExtension.cpp - Linalg transform dialect extension ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
13 #include "mlir/Dialect/Index/IR/IndexDialect.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
16 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/SCF/IR/SCF.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
21 #include "mlir/Dialect/Transform/IR/TransformOps.h"
22 #include "mlir/Dialect/Vector/IR/VectorOps.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 /// Transform dialect extension for Linalg: registers the Linalg transform and
28 /// match ops, declares Linalg as dependent and several dialects as generated.
29 class LinalgTransformDialectExtension
30     : public transform::TransformDialectExtension<
31           LinalgTransformDialectExtension> {
32 public:
33   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LinalgTransformDialectExtension)
34 
35   using Base::Base; // Inherit the constructors of `Base`.
36   /// Declares dialects and registers ops (presumably run by the base; confirm).
37   void init() {
38     declareDependentDialect<linalg::LinalgDialect>();
39     // Generated dialects (presumably ones the transforms may emit; confirm).
40     declareGeneratedDialect<affine::AffineDialect>();
41     declareGeneratedDialect<arith::ArithDialect>();
42     declareGeneratedDialect<index::IndexDialect>();
43     declareGeneratedDialect<scf::SCFDialect>();
44     declareGeneratedDialect<vector::VectorDialect>();
45     declareGeneratedDialect<gpu::GPUDialect>();
46     declareGeneratedDialect<tensor::TensorDialect>();
47     // Op lists come from the .cpp.inc files, included with GET_OP_LIST defined.
48     registerTransformOps<
49 #define GET_OP_LIST
50 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
51         >();
52     registerTransformOps<
53 #define GET_OP_LIST
54 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
55         >();
56   }
57 };
58 } // namespace
59 
60 /// Adds LinalgTransformDialectExtension to the extensions of `registry`.
61 void linalg::registerTransformDialectExtension(DialectRegistry &registry) {
62   registry.addExtensions<LinalgTransformDialectExtension>();
63 }
64