xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- NamedOpConversions.cpp - Implements conversions between named ops --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements conversions between named ops that can be seen as
10 // canonicalizations of named ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Passes.h"
15 
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS
25 #include "mlir/Dialect/Linalg/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
31 static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
32   return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
33 }
34 
35 static LogicalResult
36 matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
37                              Value iZp, Value kZp, Value init, Attribute stride,
38                              Attribute dilation, PatternRewriter &rewriter) {
39   Location loc = operation->getLoc();
40   auto linalgOp = dyn_cast<LinalgOp>(operation);
41   // Exit out on the memref version of this operation.
42   if (!linalgOp || !linalgOp.hasPureTensorSemantics())
43     return failure();
44 
45   auto result = operation->getResult(0);
46 
47   auto kernelTy = dyn_cast<RankedTensorType>(kernel.getType());
48   auto initTy = dyn_cast<RankedTensorType>(init.getType());
49   auto resultTy = dyn_cast<RankedTensorType>(result.getType());
50   if (!kernelTy || !initTy || !resultTy)
51     return failure();
52 
53   if (kernelTy.getDimSize(3) != 1)
54     return failure();
55 
56   // Collapse kernel dims.
57   SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
58       getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
59   auto newKernelTy = RankedTensorType::get(
60       {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
61       kernelTy.getElementType());
62   auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
63       loc, newKernelTy, kernel, collapsedKernelDims);
64 
65   // Collapse init dims.
66   SmallVector<ReassociationIndices, 4> collapsedInitDims = {
67       getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
68       getIndicesVector(3, 5)};
69   auto newInitTy =
70       RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
71                              initTy.getDimSize(2), initTy.getDimSize(3)},
72                             initTy.getElementType());
73   auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
74       loc, newInitTy, init, collapsedInitDims);
75 
76   SmallVector<NamedAttribute> preservedAttrs;
77   Operation *newConv =
78       TypeSwitch<Operation *, Operation *>(operation)
79           .Case<DepthwiseConv2DNhwcHwcmOp>([&](auto op) {
80             preservedAttrs = getPrunedAttributeList(op);
81             return rewriter.create<DepthwiseConv2DNhwcHwcOp>(
82                 loc, newInitTy, ValueRange{input, collapsedKernel},
83                 ValueRange{collapsedInit}, stride, dilation);
84           })
85           .Case<DepthwiseConv2DNhwcHwcmQOp>([&](auto op) {
86             preservedAttrs = getPrunedAttributeList(op);
87             return rewriter.create<DepthwiseConv2DNhwcHwcQOp>(
88                 loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
89                 ValueRange{collapsedInit}, stride, dilation);
90           })
91           .Default([](Operation *op) { return nullptr; });
92   if (!newConv)
93     return failure();
94   for (auto attr : preservedAttrs)
95     newConv->setAttr(attr.getName(), attr.getValue());
96 
97   // Expand dimensions back out to
98   rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
99       operation, resultTy, newConv->getResult(0), collapsedInitDims);
100   return success();
101 }
102 
103 namespace {
104 struct SimplifyDepthwiseConvOp
105     : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
106   using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
107 
108   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
109                                 PatternRewriter &rewriter) const override {
110     Operation *operation = op.getOperation();
111     Value input = op.getDpsInputOperand(0)->get();
112     Value kernel = op.getDpsInputOperand(1)->get();
113     Value init = op.getDpsInitOperand(0)->get();
114 
115     auto stride = op.getStrides();
116     auto dilation = op.getDilations();
117 
118     return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
119                                         nullptr, init, stride, dilation,
120                                         rewriter);
121   }
122 };
123 
124 struct SimplifyDepthwiseConvQOp
125     : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
126   using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
127 
128   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
129                                 PatternRewriter &rewriter) const override {
130     Operation *operation = op.getOperation();
131     Value input = op.getDpsInputOperand(0)->get();
132     Value kernel = op.getDpsInputOperand(1)->get();
133     Value iZp = op.getDpsInputOperand(2)->get();
134     Value kZp = op.getDpsInputOperand(3)->get();
135     Value init = op.getDpsInitOperand(0)->get();
136 
137     auto stride = op.getStrides();
138     auto dilation = op.getDilations();
139 
140     return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
141                                         init, stride, dilation, rewriter);
142   }
143 };
144 
145 struct LinalgNamedOpConversionPass
146     : public impl::LinalgNamedOpConversionPassBase<
147           LinalgNamedOpConversionPass> {
148   using impl::LinalgNamedOpConversionPassBase<
149       LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase;
150 
151   void runOnOperation() override {
152     Operation *op = getOperation();
153     RewritePatternSet patterns(op->getContext());
154     populateLinalgNamedOpConversionPatterns(patterns);
155     if (failed(applyPatternsGreedily(op, std::move(patterns))))
156       return signalPassFailure();
157   }
158 };
159 } // namespace
160 
161 void mlir::linalg::populateLinalgNamedOpConversionPatterns(
162     RewritePatternSet &patterns) {
163   patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
164       patterns.getContext());
165 }
166