xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp (revision 30916b6942371fc314f3ce1bfa4042cae3e6ff28)
17d0426ddSRiver Riddle //===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
27d0426ddSRiver Riddle //
37d0426ddSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47d0426ddSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
57d0426ddSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67d0426ddSRiver Riddle //
77d0426ddSRiver Riddle //===----------------------------------------------------------------------===//
87d0426ddSRiver Riddle //
97d0426ddSRiver Riddle // This file contains patterns for combining composed subview ops (i.e. subview
107d0426ddSRiver Riddle // of a subview becomes a single subview).
117d0426ddSRiver Riddle //
127d0426ddSRiver Riddle //===----------------------------------------------------------------------===//
137d0426ddSRiver Riddle 
147d0426ddSRiver Riddle #include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h"
157d0426ddSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
167d0426ddSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
177d0426ddSRiver Riddle #include "mlir/IR/BuiltinAttributes.h"
187d0426ddSRiver Riddle #include "mlir/IR/OpDefinition.h"
197d0426ddSRiver Riddle #include "mlir/IR/PatternMatch.h"
207d0426ddSRiver Riddle #include "mlir/Transforms/DialectConversion.h"
217d0426ddSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
227d0426ddSRiver Riddle 
237d0426ddSRiver Riddle using namespace mlir;
247d0426ddSRiver Riddle 
257d0426ddSRiver Riddle namespace {
267d0426ddSRiver Riddle 
272ecf6088Slonely eagle // Replaces a subview of a subview with a single subview(both static and dynamic
287d0426ddSRiver Riddle // offsets are supported).
297d0426ddSRiver Riddle struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
307d0426ddSRiver Riddle   using OpRewritePattern::OpRewritePattern;
317d0426ddSRiver Riddle 
327d0426ddSRiver Riddle   LogicalResult matchAndRewrite(memref::SubViewOp op,
337d0426ddSRiver Riddle                                 PatternRewriter &rewriter) const override {
347d0426ddSRiver Riddle     // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
357d0426ddSRiver Riddle     // produces the input of the op we're rewriting (for 'SubViewOp' the input
367d0426ddSRiver Riddle     // is called the "source" value). We can only combine them if both 'op' and
377d0426ddSRiver Riddle     // 'sourceOp' are 'SubViewOp'.
38136d746eSJacques Pienaar     auto sourceOp = op.getSource().getDefiningOp<memref::SubViewOp>();
397d0426ddSRiver Riddle     if (!sourceOp)
407d0426ddSRiver Riddle       return failure();
417d0426ddSRiver Riddle 
427d0426ddSRiver Riddle     // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
437d0426ddSRiver Riddle     // output memref that are statically known to be equal to 1. We do not
447d0426ddSRiver Riddle     // allow 'sourceOp' to be a rank-reducing subview because then our two
457d0426ddSRiver Riddle     // 'SubViewOp's would have different numbers of offset/size/stride
467d0426ddSRiver Riddle     // parameters (just difficult to deal with, not impossible if we end up
477d0426ddSRiver Riddle     // needing it).
487d0426ddSRiver Riddle     if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) {
497d0426ddSRiver Riddle       return failure();
507d0426ddSRiver Riddle     }
517d0426ddSRiver Riddle 
527d0426ddSRiver Riddle     // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
532ecf6088Slonely eagle     SmallVector<OpFoldResult> offsets, sizes, strides,
542ecf6088Slonely eagle         opStrides = op.getMixedStrides(),
552ecf6088Slonely eagle         sourceStrides = sourceOp.getMixedStrides();
567d0426ddSRiver Riddle 
572ecf6088Slonely eagle     // The output stride in each dimension is equal to the product of the
582ecf6088Slonely eagle     // dimensions corresponding to source and op.
592ecf6088Slonely eagle     int64_t sourceStrideValue;
602ecf6088Slonely eagle     for (auto &&[opStride, sourceStride] :
612ecf6088Slonely eagle          llvm::zip(opStrides, sourceStrides)) {
622ecf6088Slonely eagle       Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
632ecf6088Slonely eagle       Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
642ecf6088Slonely eagle       if (!opStrideAttr || !sourceStrideAttr)
657d0426ddSRiver Riddle         return failure();
662ecf6088Slonely eagle       sourceStrideValue = cast<IntegerAttr>(sourceStrideAttr).getInt();
672ecf6088Slonely eagle       strides.push_back(rewriter.getI64IntegerAttr(
682ecf6088Slonely eagle           cast<IntegerAttr>(opStrideAttr).getInt() * sourceStrideValue));
697d0426ddSRiver Riddle     }
707d0426ddSRiver Riddle 
717d0426ddSRiver Riddle     // The rules for calculating the new offsets and sizes are:
727d0426ddSRiver Riddle     // * Multiple subview offsets for a given dimension compose additively.
732ecf6088Slonely eagle     //   ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by
742ecf6088Slonely eagle     //   m + n * k")
757d0426ddSRiver Riddle     // * Multiple sizes for a given dimension compose by taking the size of the
767d0426ddSRiver Riddle     //   final subview and ignoring the rest. ("Take m values" followed by "Take
777d0426ddSRiver Riddle     //   n values" == "Take n values") This size must also be the smallest one
787d0426ddSRiver Riddle     //   by definition (a subview needs to be the same size as or smaller than
797d0426ddSRiver Riddle     //   its source along each dimension; presumably subviews that are larger
807d0426ddSRiver Riddle     //   than their sources are disallowed by validation).
812ecf6088Slonely eagle     for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
822ecf6088Slonely eagle          llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
832ecf6088Slonely eagle                    sourceOp.getMixedStrides(), op.getMixedSizes())) {
847d0426ddSRiver Riddle       // We only support static sizes.
85*30916b69SKazu Hirata       if (isa<Value>(opSize)) {
867d0426ddSRiver Riddle         return failure();
877d0426ddSRiver Riddle       }
887d0426ddSRiver Riddle       sizes.push_back(opSize);
8968f58812STres Popp       Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
9068f58812STres Popp                 sourceOffsetAttr =
912ecf6088Slonely eagle                     llvm::dyn_cast_if_present<Attribute>(sourceOffset),
922ecf6088Slonely eagle                 sourceStrideAttr =
932ecf6088Slonely eagle                     llvm::dyn_cast_if_present<Attribute>(sourceStride);
947d0426ddSRiver Riddle       if (opOffsetAttr && sourceOffsetAttr) {
952ecf6088Slonely eagle 
967d0426ddSRiver Riddle         // If both offsets are static we can simply calculate the combined
977d0426ddSRiver Riddle         // offset statically.
987d0426ddSRiver Riddle         offsets.push_back(rewriter.getI64IntegerAttr(
992ecf6088Slonely eagle             cast<IntegerAttr>(opOffsetAttr).getInt() *
1002ecf6088Slonely eagle                 cast<IntegerAttr>(sourceStrideAttr).getInt() +
1015550c821STres Popp             cast<IntegerAttr>(sourceOffsetAttr).getInt()));
1027d0426ddSRiver Riddle       } else {
1032ecf6088Slonely eagle         AffineExpr expr;
1047d0426ddSRiver Riddle         SmallVector<Value> affineApplyOperands;
1052ecf6088Slonely eagle 
1062ecf6088Slonely eagle         // Make 'expr' add 'sourceOffset'.
1072ecf6088Slonely eagle         if (auto attr = llvm::dyn_cast_if_present<Attribute>(sourceOffset)) {
1082ecf6088Slonely eagle           expr =
1092ecf6088Slonely eagle               rewriter.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt());
1102ecf6088Slonely eagle         } else {
1112ecf6088Slonely eagle           expr = rewriter.getAffineSymbolExpr(affineApplyOperands.size());
112*30916b69SKazu Hirata           affineApplyOperands.push_back(cast<Value>(sourceOffset));
1132ecf6088Slonely eagle         }
1142ecf6088Slonely eagle 
1152ecf6088Slonely eagle         // Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the
1162ecf6088Slonely eagle         // result.
1172ecf6088Slonely eagle         if (auto attr = llvm::dyn_cast_if_present<Attribute>(opOffset)) {
1182ecf6088Slonely eagle           expr = expr + cast<IntegerAttr>(attr).getInt() *
1192ecf6088Slonely eagle                             cast<IntegerAttr>(sourceStrideAttr).getInt();
1207d0426ddSRiver Riddle         } else {
1217d0426ddSRiver Riddle           expr =
1222ecf6088Slonely eagle               expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()) *
1232ecf6088Slonely eagle                          cast<IntegerAttr>(sourceStrideAttr).getInt();
124*30916b69SKazu Hirata           affineApplyOperands.push_back(cast<Value>(opOffset));
1257d0426ddSRiver Riddle         }
1267d0426ddSRiver Riddle 
1277d0426ddSRiver Riddle         AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
1284c48f016SMatthias Springer         Value result = rewriter.create<affine::AffineApplyOp>(
1294c48f016SMatthias Springer             op.getLoc(), map, affineApplyOperands);
1307d0426ddSRiver Riddle         offsets.push_back(result);
1317d0426ddSRiver Riddle       }
1327d0426ddSRiver Riddle     }
1337d0426ddSRiver Riddle 
1347d0426ddSRiver Riddle     // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
1357d0426ddSRiver Riddle     // uses it can be removed by a (separate) dead code elimination pass.
1362ecf6088Slonely eagle     rewriter.replaceOpWithNewOp<memref::SubViewOp>(
1372ecf6088Slonely eagle         op, op.getType(), sourceOp.getSource(), offsets, sizes, strides);
1387d0426ddSRiver Riddle     return success();
1397d0426ddSRiver Riddle   }
1407d0426ddSRiver Riddle };
1417d0426ddSRiver Riddle 
1427d0426ddSRiver Riddle } // namespace
1437d0426ddSRiver Riddle 
144b7f93c28SJeff Niu void mlir::memref::populateComposeSubViewPatterns(RewritePatternSet &patterns,
145b7f93c28SJeff Niu                                                   MLIRContext *context) {
146b4e0507cSTres Popp   patterns.add<ComposeSubViewOpPattern>(context);
1477d0426ddSRiver Riddle }
148