1 //===- IndependenceTransforms.cpp - Make ops independent of values --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MemRef/Transforms/Transforms.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Affine/Transforms/Transforms.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/Dialect/Utils/StaticValueUtils.h" 15 #include "mlir/Interfaces/ValueBoundsOpInterface.h" 16 17 using namespace mlir; 18 using namespace mlir::memref; 19 20 /// Make the given OpFoldResult independent of all independencies. 21 static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc, 22 OpFoldResult ofr, 23 ValueRange independencies) { 24 if (isa<Attribute>(ofr)) 25 return ofr; 26 AffineMap boundMap; 27 ValueDimList mapOperands; 28 if (failed(ValueBoundsConstraintSet::computeIndependentBound( 29 boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies, 30 /*closedUB=*/true))) 31 return failure(); 32 return affine::materializeComputedBound(b, loc, boundMap, mapOperands); 33 } 34 35 FailureOr<Value> memref::buildIndependentOp(OpBuilder &b, 36 memref::AllocaOp allocaOp, 37 ValueRange independencies) { 38 OpBuilder::InsertionGuard g(b); 39 b.setInsertionPoint(allocaOp); 40 Location loc = allocaOp.getLoc(); 41 42 SmallVector<OpFoldResult> newSizes; 43 for (OpFoldResult ofr : allocaOp.getMixedSizes()) { 44 auto ub = makeIndependent(b, loc, ofr, independencies); 45 if (failed(ub)) 46 return failure(); 47 newSizes.push_back(*ub); 48 } 49 50 // Return existing memref::AllocaOp if nothing has changed. 51 if (llvm::equal(allocaOp.getMixedSizes(), newSizes)) 52 return allocaOp.getResult(); 53 54 // Create a new memref::AllocaOp. 55 Value newAllocaOp = 56 b.create<AllocaOp>(loc, newSizes, allocaOp.getType().getElementType()); 57 58 // Create a memref::SubViewOp. 59 SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0)); 60 SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1)); 61 return b 62 .create<SubViewOp>(loc, newAllocaOp, offsets, allocaOp.getMixedSizes(), 63 strides) 64 .getResult(); 65 } 66 67 /// Push down an UnrealizedConversionCastOp past a SubViewOp. 68 static UnrealizedConversionCastOp 69 propagateSubViewOp(RewriterBase &rewriter, 70 UnrealizedConversionCastOp conversionOp, SubViewOp op) { 71 OpBuilder::InsertionGuard g(rewriter); 72 rewriter.setInsertionPoint(op); 73 auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType( 74 op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(), 75 op.getMixedSizes(), op.getMixedStrides())); 76 Value newSubview = rewriter.create<SubViewOp>( 77 op.getLoc(), newResultType, conversionOp.getOperand(0), 78 op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()); 79 auto newConversionOp = rewriter.create<UnrealizedConversionCastOp>( 80 op.getLoc(), op.getType(), newSubview); 81 rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0)); 82 return newConversionOp; 83 } 84 85 /// Given an original op and a new, modified op with the same number of results, 86 /// whose memref return types may differ, replace all uses of the original op 87 /// with the new op and propagate the new memref types through the IR. 88 /// 89 /// Example: 90 /// %from = memref.alloca(%sz) : memref<?xf32> 91 /// %to = memref.subview ... : ... to memref<?xf32, strided<[1], offset: ?>> 92 /// memref.store %cst, %from[%c0] : memref<?xf32> 93 /// 94 /// In the above example, all uses of %from are replaced with %to. This can be 95 /// done directly for ops such as memref.store. For ops that have memref results 96 /// (e.g., memref.subview), the result type may depend on the operand type, so 97 /// we cannot just replace all uses. There is special handling for common memref 98 /// ops. For all other ops, unrealized_conversion_cast is inserted. 99 static void replaceAndPropagateMemRefType(RewriterBase &rewriter, 100 Operation *from, Operation *to) { 101 assert(from->getNumResults() == to->getNumResults() && 102 "expected same number of results"); 103 OpBuilder::InsertionGuard g(rewriter); 104 rewriter.setInsertionPointAfter(to); 105 106 // Wrap new results in unrealized_conversion_cast and replace all uses of the 107 // original op. 108 SmallVector<UnrealizedConversionCastOp> unrealizedConversions; 109 for (const auto &it : 110 llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) { 111 unrealizedConversions.push_back(rewriter.create<UnrealizedConversionCastOp>( 112 to->getLoc(), std::get<0>(it.value()).getType(), 113 std::get<1>(it.value()))); 114 rewriter.replaceAllUsesWith(from->getResult(it.index()), 115 unrealizedConversions.back()->getResult(0)); 116 } 117 118 // Push unrealized_conversion_cast ops further down in the IR. I.e., try to 119 // wrap results instead of operands in a cast. 120 for (int i = 0; i < static_cast<int>(unrealizedConversions.size()); ++i) { 121 UnrealizedConversionCastOp conversion = unrealizedConversions[i]; 122 assert(conversion->getNumOperands() == 1 && 123 conversion->getNumResults() == 1 && 124 "expected single operand and single result"); 125 SmallVector<Operation *> users = llvm::to_vector(conversion->getUsers()); 126 for (Operation *user : users) { 127 // Handle common memref dialect ops that produce new memrefs and must 128 // be recreated with the new result type. 129 if (auto subviewOp = dyn_cast<SubViewOp>(user)) { 130 unrealizedConversions.push_back( 131 propagateSubViewOp(rewriter, conversion, subviewOp)); 132 continue; 133 } 134 135 // TODO: Other memref ops such as memref.collapse_shape/expand_shape 136 // should also be handled here. 137 138 // Skip any ops that produce MemRef result or have MemRef region block 139 // arguments. These may need special handling (e.g., scf.for). 140 if (llvm::any_of(user->getResultTypes(), 141 [](Type t) { return isa<MemRefType>(t); })) 142 continue; 143 if (llvm::any_of(user->getRegions(), [](Region &r) { 144 return llvm::any_of(r.getArguments(), [](BlockArgument bbArg) { 145 return isa<MemRefType>(bbArg.getType()); 146 }); 147 })) 148 continue; 149 150 // For all other ops, we assume that we can directly replace the operand. 151 // This may have to be revised in the future; e.g., there may be ops that 152 // do not support non-identity layout maps. 153 for (OpOperand &operand : user->getOpOperands()) { 154 if ([[maybe_unused]] auto castOp = 155 operand.get().getDefiningOp<UnrealizedConversionCastOp>()) { 156 rewriter.modifyOpInPlace( 157 user, [&]() { operand.set(conversion->getOperand(0)); }); 158 } 159 } 160 } 161 } 162 163 // Erase all unrealized_conversion_cast ops without uses. 164 for (auto op : unrealizedConversions) 165 if (op->getUses().empty()) 166 rewriter.eraseOp(op); 167 } 168 169 FailureOr<Value> memref::replaceWithIndependentOp(RewriterBase &rewriter, 170 memref::AllocaOp allocaOp, 171 ValueRange independencies) { 172 auto replacement = 173 memref::buildIndependentOp(rewriter, allocaOp, independencies); 174 if (failed(replacement)) 175 return failure(); 176 replaceAndPropagateMemRefType(rewriter, allocaOp, 177 replacement->getDefiningOp()); 178 return replacement; 179 } 180 181 memref::AllocaOp memref::allocToAlloca( 182 RewriterBase &rewriter, memref::AllocOp alloc, 183 function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) { 184 memref::DeallocOp dealloc = nullptr; 185 for (Operation &candidate : 186 llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) { 187 dealloc = dyn_cast<memref::DeallocOp>(candidate); 188 if (dealloc && dealloc.getMemref() == alloc.getMemref() && 189 (!filter || filter(alloc, dealloc))) { 190 break; 191 } 192 } 193 194 if (!dealloc) 195 return nullptr; 196 197 OpBuilder::InsertionGuard guard(rewriter); 198 rewriter.setInsertionPoint(alloc); 199 auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>( 200 alloc, alloc.getMemref().getType(), alloc.getOperands()); 201 rewriter.eraseOp(dealloc); 202 return alloca; 203 } 204