xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp (revision 30916b6942371fc314f3ce1bfa4042cae3e6ff28)
1 //===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains patterns for combining composed subview ops (i.e. subview
10 // of a subview becomes a single subview).
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 
27 // Replaces a subview of a subview with a single subview(both static and dynamic
28 // offsets are supported).
29 struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
30   using OpRewritePattern::OpRewritePattern;
31 
32   LogicalResult matchAndRewrite(memref::SubViewOp op,
33                                 PatternRewriter &rewriter) const override {
34     // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
35     // produces the input of the op we're rewriting (for 'SubViewOp' the input
36     // is called the "source" value). We can only combine them if both 'op' and
37     // 'sourceOp' are 'SubViewOp'.
38     auto sourceOp = op.getSource().getDefiningOp<memref::SubViewOp>();
39     if (!sourceOp)
40       return failure();
41 
42     // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
43     // output memref that are statically known to be equal to 1. We do not
44     // allow 'sourceOp' to be a rank-reducing subview because then our two
45     // 'SubViewOp's would have different numbers of offset/size/stride
46     // parameters (just difficult to deal with, not impossible if we end up
47     // needing it).
48     if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) {
49       return failure();
50     }
51 
52     // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
53     SmallVector<OpFoldResult> offsets, sizes, strides,
54         opStrides = op.getMixedStrides(),
55         sourceStrides = sourceOp.getMixedStrides();
56 
57     // The output stride in each dimension is equal to the product of the
58     // dimensions corresponding to source and op.
59     int64_t sourceStrideValue;
60     for (auto &&[opStride, sourceStride] :
61          llvm::zip(opStrides, sourceStrides)) {
62       Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
63       Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
64       if (!opStrideAttr || !sourceStrideAttr)
65         return failure();
66       sourceStrideValue = cast<IntegerAttr>(sourceStrideAttr).getInt();
67       strides.push_back(rewriter.getI64IntegerAttr(
68           cast<IntegerAttr>(opStrideAttr).getInt() * sourceStrideValue));
69     }
70 
71     // The rules for calculating the new offsets and sizes are:
72     // * Multiple subview offsets for a given dimension compose additively.
73     //   ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by
74     //   m + n * k")
75     // * Multiple sizes for a given dimension compose by taking the size of the
76     //   final subview and ignoring the rest. ("Take m values" followed by "Take
77     //   n values" == "Take n values") This size must also be the smallest one
78     //   by definition (a subview needs to be the same size as or smaller than
79     //   its source along each dimension; presumably subviews that are larger
80     //   than their sources are disallowed by validation).
81     for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
82          llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
83                    sourceOp.getMixedStrides(), op.getMixedSizes())) {
84       // We only support static sizes.
85       if (isa<Value>(opSize)) {
86         return failure();
87       }
88       sizes.push_back(opSize);
89       Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
90                 sourceOffsetAttr =
91                     llvm::dyn_cast_if_present<Attribute>(sourceOffset),
92                 sourceStrideAttr =
93                     llvm::dyn_cast_if_present<Attribute>(sourceStride);
94       if (opOffsetAttr && sourceOffsetAttr) {
95 
96         // If both offsets are static we can simply calculate the combined
97         // offset statically.
98         offsets.push_back(rewriter.getI64IntegerAttr(
99             cast<IntegerAttr>(opOffsetAttr).getInt() *
100                 cast<IntegerAttr>(sourceStrideAttr).getInt() +
101             cast<IntegerAttr>(sourceOffsetAttr).getInt()));
102       } else {
103         AffineExpr expr;
104         SmallVector<Value> affineApplyOperands;
105 
106         // Make 'expr' add 'sourceOffset'.
107         if (auto attr = llvm::dyn_cast_if_present<Attribute>(sourceOffset)) {
108           expr =
109               rewriter.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt());
110         } else {
111           expr = rewriter.getAffineSymbolExpr(affineApplyOperands.size());
112           affineApplyOperands.push_back(cast<Value>(sourceOffset));
113         }
114 
115         // Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the
116         // result.
117         if (auto attr = llvm::dyn_cast_if_present<Attribute>(opOffset)) {
118           expr = expr + cast<IntegerAttr>(attr).getInt() *
119                             cast<IntegerAttr>(sourceStrideAttr).getInt();
120         } else {
121           expr =
122               expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()) *
123                          cast<IntegerAttr>(sourceStrideAttr).getInt();
124           affineApplyOperands.push_back(cast<Value>(opOffset));
125         }
126 
127         AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
128         Value result = rewriter.create<affine::AffineApplyOp>(
129             op.getLoc(), map, affineApplyOperands);
130         offsets.push_back(result);
131       }
132     }
133 
134     // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
135     // uses it can be removed by a (separate) dead code elimination pass.
136     rewriter.replaceOpWithNewOp<memref::SubViewOp>(
137         op, op.getType(), sourceOp.getSource(), offsets, sizes, strides);
138     return success();
139   }
140 };
141 
142 } // namespace
143 
144 void mlir::memref::populateComposeSubViewPatterns(RewritePatternSet &patterns,
145                                                   MLIRContext *context) {
146   patterns.add<ComposeSubViewOpPattern>(context);
147 }
148