17d0426ddSRiver Riddle //===- ComposeSubView.cpp - Combining composed subview ops ----------------===// 27d0426ddSRiver Riddle // 37d0426ddSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 47d0426ddSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 57d0426ddSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 67d0426ddSRiver Riddle // 77d0426ddSRiver Riddle //===----------------------------------------------------------------------===// 87d0426ddSRiver Riddle // 97d0426ddSRiver Riddle // This file contains patterns for combining composed subview ops (i.e. subview 107d0426ddSRiver Riddle // of a subview becomes a single subview). 117d0426ddSRiver Riddle // 127d0426ddSRiver Riddle //===----------------------------------------------------------------------===// 137d0426ddSRiver Riddle 147d0426ddSRiver Riddle #include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h" 157d0426ddSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h" 167d0426ddSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 177d0426ddSRiver Riddle #include "mlir/IR/BuiltinAttributes.h" 187d0426ddSRiver Riddle #include "mlir/IR/OpDefinition.h" 197d0426ddSRiver Riddle #include "mlir/IR/PatternMatch.h" 207d0426ddSRiver Riddle #include "mlir/Transforms/DialectConversion.h" 217d0426ddSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 227d0426ddSRiver Riddle 237d0426ddSRiver Riddle using namespace mlir; 247d0426ddSRiver Riddle 257d0426ddSRiver Riddle namespace { 267d0426ddSRiver Riddle 272ecf6088Slonely eagle // Replaces a subview of a subview with a single subview(both static and dynamic 287d0426ddSRiver Riddle // offsets are supported). 297d0426ddSRiver Riddle struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> { 307d0426ddSRiver Riddle using OpRewritePattern::OpRewritePattern; 317d0426ddSRiver Riddle 327d0426ddSRiver Riddle LogicalResult matchAndRewrite(memref::SubViewOp op, 337d0426ddSRiver Riddle PatternRewriter &rewriter) const override { 347d0426ddSRiver Riddle // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that 357d0426ddSRiver Riddle // produces the input of the op we're rewriting (for 'SubViewOp' the input 367d0426ddSRiver Riddle // is called the "source" value). We can only combine them if both 'op' and 377d0426ddSRiver Riddle // 'sourceOp' are 'SubViewOp'. 38136d746eSJacques Pienaar auto sourceOp = op.getSource().getDefiningOp<memref::SubViewOp>(); 397d0426ddSRiver Riddle if (!sourceOp) 407d0426ddSRiver Riddle return failure(); 417d0426ddSRiver Riddle 427d0426ddSRiver Riddle // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the 437d0426ddSRiver Riddle // output memref that are statically known to be equal to 1. We do not 447d0426ddSRiver Riddle // allow 'sourceOp' to be a rank-reducing subview because then our two 457d0426ddSRiver Riddle // 'SubViewOp's would have different numbers of offset/size/stride 467d0426ddSRiver Riddle // parameters (just difficult to deal with, not impossible if we end up 477d0426ddSRiver Riddle // needing it). 487d0426ddSRiver Riddle if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) { 497d0426ddSRiver Riddle return failure(); 507d0426ddSRiver Riddle } 517d0426ddSRiver Riddle 527d0426ddSRiver Riddle // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'. 532ecf6088Slonely eagle SmallVector<OpFoldResult> offsets, sizes, strides, 542ecf6088Slonely eagle opStrides = op.getMixedStrides(), 552ecf6088Slonely eagle sourceStrides = sourceOp.getMixedStrides(); 567d0426ddSRiver Riddle 572ecf6088Slonely eagle // The output stride in each dimension is equal to the product of the 582ecf6088Slonely eagle // dimensions corresponding to source and op. 592ecf6088Slonely eagle int64_t sourceStrideValue; 602ecf6088Slonely eagle for (auto &&[opStride, sourceStride] : 612ecf6088Slonely eagle llvm::zip(opStrides, sourceStrides)) { 622ecf6088Slonely eagle Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride); 632ecf6088Slonely eagle Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride); 642ecf6088Slonely eagle if (!opStrideAttr || !sourceStrideAttr) 657d0426ddSRiver Riddle return failure(); 662ecf6088Slonely eagle sourceStrideValue = cast<IntegerAttr>(sourceStrideAttr).getInt(); 672ecf6088Slonely eagle strides.push_back(rewriter.getI64IntegerAttr( 682ecf6088Slonely eagle cast<IntegerAttr>(opStrideAttr).getInt() * sourceStrideValue)); 697d0426ddSRiver Riddle } 707d0426ddSRiver Riddle 717d0426ddSRiver Riddle // The rules for calculating the new offsets and sizes are: 727d0426ddSRiver Riddle // * Multiple subview offsets for a given dimension compose additively. 732ecf6088Slonely eagle // ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by 742ecf6088Slonely eagle // m + n * k") 757d0426ddSRiver Riddle // * Multiple sizes for a given dimension compose by taking the size of the 767d0426ddSRiver Riddle // final subview and ignoring the rest. ("Take m values" followed by "Take 777d0426ddSRiver Riddle // n values" == "Take n values") This size must also be the smallest one 787d0426ddSRiver Riddle // by definition (a subview needs to be the same size as or smaller than 797d0426ddSRiver Riddle // its source along each dimension; presumably subviews that are larger 807d0426ddSRiver Riddle // than their sources are disallowed by validation). 812ecf6088Slonely eagle for (auto &&[opOffset, sourceOffset, sourceStride, opSize] : 822ecf6088Slonely eagle llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(), 832ecf6088Slonely eagle sourceOp.getMixedStrides(), op.getMixedSizes())) { 847d0426ddSRiver Riddle // We only support static sizes. 85*30916b69SKazu Hirata if (isa<Value>(opSize)) { 867d0426ddSRiver Riddle return failure(); 877d0426ddSRiver Riddle } 887d0426ddSRiver Riddle sizes.push_back(opSize); 8968f58812STres Popp Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset), 9068f58812STres Popp sourceOffsetAttr = 912ecf6088Slonely eagle llvm::dyn_cast_if_present<Attribute>(sourceOffset), 922ecf6088Slonely eagle sourceStrideAttr = 932ecf6088Slonely eagle llvm::dyn_cast_if_present<Attribute>(sourceStride); 947d0426ddSRiver Riddle if (opOffsetAttr && sourceOffsetAttr) { 952ecf6088Slonely eagle 967d0426ddSRiver Riddle // If both offsets are static we can simply calculate the combined 977d0426ddSRiver Riddle // offset statically. 987d0426ddSRiver Riddle offsets.push_back(rewriter.getI64IntegerAttr( 992ecf6088Slonely eagle cast<IntegerAttr>(opOffsetAttr).getInt() * 1002ecf6088Slonely eagle cast<IntegerAttr>(sourceStrideAttr).getInt() + 1015550c821STres Popp cast<IntegerAttr>(sourceOffsetAttr).getInt())); 1027d0426ddSRiver Riddle } else { 1032ecf6088Slonely eagle AffineExpr expr; 1047d0426ddSRiver Riddle SmallVector<Value> affineApplyOperands; 1052ecf6088Slonely eagle 1062ecf6088Slonely eagle // Make 'expr' add 'sourceOffset'. 1072ecf6088Slonely eagle if (auto attr = llvm::dyn_cast_if_present<Attribute>(sourceOffset)) { 1082ecf6088Slonely eagle expr = 1092ecf6088Slonely eagle rewriter.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()); 1102ecf6088Slonely eagle } else { 1112ecf6088Slonely eagle expr = rewriter.getAffineSymbolExpr(affineApplyOperands.size()); 112*30916b69SKazu Hirata affineApplyOperands.push_back(cast<Value>(sourceOffset)); 1132ecf6088Slonely eagle } 1142ecf6088Slonely eagle 1152ecf6088Slonely eagle // Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the 1162ecf6088Slonely eagle // result. 1172ecf6088Slonely eagle if (auto attr = llvm::dyn_cast_if_present<Attribute>(opOffset)) { 1182ecf6088Slonely eagle expr = expr + cast<IntegerAttr>(attr).getInt() * 1192ecf6088Slonely eagle cast<IntegerAttr>(sourceStrideAttr).getInt(); 1207d0426ddSRiver Riddle } else { 1217d0426ddSRiver Riddle expr = 1222ecf6088Slonely eagle expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()) * 1232ecf6088Slonely eagle cast<IntegerAttr>(sourceStrideAttr).getInt(); 124*30916b69SKazu Hirata affineApplyOperands.push_back(cast<Value>(opOffset)); 1257d0426ddSRiver Riddle } 1267d0426ddSRiver Riddle 1277d0426ddSRiver Riddle AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr); 1284c48f016SMatthias Springer Value result = rewriter.create<affine::AffineApplyOp>( 1294c48f016SMatthias Springer op.getLoc(), map, affineApplyOperands); 1307d0426ddSRiver Riddle offsets.push_back(result); 1317d0426ddSRiver Riddle } 1327d0426ddSRiver Riddle } 1337d0426ddSRiver Riddle 1347d0426ddSRiver Riddle // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any 1357d0426ddSRiver Riddle // uses it can be removed by a (separate) dead code elimination pass. 1362ecf6088Slonely eagle rewriter.replaceOpWithNewOp<memref::SubViewOp>( 1372ecf6088Slonely eagle op, op.getType(), sourceOp.getSource(), offsets, sizes, strides); 1387d0426ddSRiver Riddle return success(); 1397d0426ddSRiver Riddle } 1407d0426ddSRiver Riddle }; 1417d0426ddSRiver Riddle 1427d0426ddSRiver Riddle } // namespace 1437d0426ddSRiver Riddle 144b7f93c28SJeff Niu void mlir::memref::populateComposeSubViewPatterns(RewritePatternSet &patterns, 145b7f93c28SJeff Niu MLIRContext *context) { 146b4e0507cSTres Popp patterns.add<ComposeSubViewOpPattern>(context); 1477d0426ddSRiver Riddle } 148