1 //===- ComposeSubView.cpp - Combining composed subview ops ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains patterns for combining composed subview ops (i.e. subview 10 // of a subview becomes a single subview). 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h" 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/IR/BuiltinAttributes.h" 18 #include "mlir/IR/OpDefinition.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 23 using namespace mlir; 24 25 namespace { 26 27 // Replaces a subview of a subview with a single subview(both static and dynamic 28 // offsets are supported). 29 struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> { 30 using OpRewritePattern::OpRewritePattern; 31 32 LogicalResult matchAndRewrite(memref::SubViewOp op, 33 PatternRewriter &rewriter) const override { 34 // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that 35 // produces the input of the op we're rewriting (for 'SubViewOp' the input 36 // is called the "source" value). We can only combine them if both 'op' and 37 // 'sourceOp' are 'SubViewOp'. 38 auto sourceOp = op.getSource().getDefiningOp<memref::SubViewOp>(); 39 if (!sourceOp) 40 return failure(); 41 42 // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the 43 // output memref that are statically known to be equal to 1. We do not 44 // allow 'sourceOp' to be a rank-reducing subview because then our two 45 // 'SubViewOp's would have different numbers of offset/size/stride 46 // parameters (just difficult to deal with, not impossible if we end up 47 // needing it). 48 if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) { 49 return failure(); 50 } 51 52 // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'. 53 SmallVector<OpFoldResult> offsets, sizes, strides, 54 opStrides = op.getMixedStrides(), 55 sourceStrides = sourceOp.getMixedStrides(); 56 57 // The output stride in each dimension is equal to the product of the 58 // dimensions corresponding to source and op. 59 int64_t sourceStrideValue; 60 for (auto &&[opStride, sourceStride] : 61 llvm::zip(opStrides, sourceStrides)) { 62 Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride); 63 Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride); 64 if (!opStrideAttr || !sourceStrideAttr) 65 return failure(); 66 sourceStrideValue = cast<IntegerAttr>(sourceStrideAttr).getInt(); 67 strides.push_back(rewriter.getI64IntegerAttr( 68 cast<IntegerAttr>(opStrideAttr).getInt() * sourceStrideValue)); 69 } 70 71 // The rules for calculating the new offsets and sizes are: 72 // * Multiple subview offsets for a given dimension compose additively. 73 // ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by 74 // m + n * k") 75 // * Multiple sizes for a given dimension compose by taking the size of the 76 // final subview and ignoring the rest. ("Take m values" followed by "Take 77 // n values" == "Take n values") This size must also be the smallest one 78 // by definition (a subview needs to be the same size as or smaller than 79 // its source along each dimension; presumably subviews that are larger 80 // than their sources are disallowed by validation). 81 for (auto &&[opOffset, sourceOffset, sourceStride, opSize] : 82 llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(), 83 sourceOp.getMixedStrides(), op.getMixedSizes())) { 84 // We only support static sizes. 85 if (isa<Value>(opSize)) { 86 return failure(); 87 } 88 sizes.push_back(opSize); 89 Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset), 90 sourceOffsetAttr = 91 llvm::dyn_cast_if_present<Attribute>(sourceOffset), 92 sourceStrideAttr = 93 llvm::dyn_cast_if_present<Attribute>(sourceStride); 94 if (opOffsetAttr && sourceOffsetAttr) { 95 96 // If both offsets are static we can simply calculate the combined 97 // offset statically. 98 offsets.push_back(rewriter.getI64IntegerAttr( 99 cast<IntegerAttr>(opOffsetAttr).getInt() * 100 cast<IntegerAttr>(sourceStrideAttr).getInt() + 101 cast<IntegerAttr>(sourceOffsetAttr).getInt())); 102 } else { 103 AffineExpr expr; 104 SmallVector<Value> affineApplyOperands; 105 106 // Make 'expr' add 'sourceOffset'. 107 if (auto attr = llvm::dyn_cast_if_present<Attribute>(sourceOffset)) { 108 expr = 109 rewriter.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()); 110 } else { 111 expr = rewriter.getAffineSymbolExpr(affineApplyOperands.size()); 112 affineApplyOperands.push_back(cast<Value>(sourceOffset)); 113 } 114 115 // Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the 116 // result. 117 if (auto attr = llvm::dyn_cast_if_present<Attribute>(opOffset)) { 118 expr = expr + cast<IntegerAttr>(attr).getInt() * 119 cast<IntegerAttr>(sourceStrideAttr).getInt(); 120 } else { 121 expr = 122 expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()) * 123 cast<IntegerAttr>(sourceStrideAttr).getInt(); 124 affineApplyOperands.push_back(cast<Value>(opOffset)); 125 } 126 127 AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr); 128 Value result = rewriter.create<affine::AffineApplyOp>( 129 op.getLoc(), map, affineApplyOperands); 130 offsets.push_back(result); 131 } 132 } 133 134 // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any 135 // uses it can be removed by a (separate) dead code elimination pass. 136 rewriter.replaceOpWithNewOp<memref::SubViewOp>( 137 op, op.getType(), sourceOp.getSource(), offsets, sizes, strides); 138 return success(); 139 } 140 }; 141 142 } // namespace 143 144 void mlir::memref::populateComposeSubViewPatterns(RewritePatternSet &patterns, 145 MLIRContext *context) { 146 patterns.add<ComposeSubViewOpPattern>(context); 147 } 148