177124386SMatthias Springer //===- IndependenceTransforms.cpp - Make ops independent of values --------===// 277124386SMatthias Springer // 377124386SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 477124386SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 577124386SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 677124386SMatthias Springer // 777124386SMatthias Springer //===----------------------------------------------------------------------===// 877124386SMatthias Springer 977124386SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 1077124386SMatthias Springer 1177124386SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h" 1277124386SMatthias Springer #include "mlir/Dialect/Affine/Transforms/Transforms.h" 1377124386SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 1477124386SMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h" 1577124386SMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h" 1677124386SMatthias Springer 1777124386SMatthias Springer using namespace mlir; 1877124386SMatthias Springer using namespace mlir::tensor; 1977124386SMatthias Springer 2077124386SMatthias Springer /// Make the given OpFoldResult independent of all independencies. 2177124386SMatthias Springer static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc, 2277124386SMatthias Springer OpFoldResult ofr, 2377124386SMatthias Springer ValueRange independencies) { 24*fecf1397SKazu Hirata if (isa<Attribute>(ofr)) 2577124386SMatthias Springer return ofr; 26*fecf1397SKazu Hirata Value value = cast<Value>(ofr); 2777124386SMatthias Springer AffineMap boundMap; 2877124386SMatthias Springer ValueDimList mapOperands; 2977124386SMatthias Springer if (failed(ValueBoundsConstraintSet::computeIndependentBound( 3077124386SMatthias Springer boundMap, mapOperands, presburger::BoundType::UB, value, 3140dd3aa9SMatthias Springer independencies, 3240dd3aa9SMatthias Springer /*closedUB=*/true))) 3377124386SMatthias Springer return failure(); 3477124386SMatthias Springer return mlir::affine::materializeComputedBound(b, loc, boundMap, mapOperands); 3577124386SMatthias Springer } 3677124386SMatthias Springer 3777124386SMatthias Springer FailureOr<Value> tensor::buildIndependentOp(OpBuilder &b, tensor::PadOp padOp, 3877124386SMatthias Springer ValueRange independencies) { 3977124386SMatthias Springer OpBuilder::InsertionGuard g(b); 4077124386SMatthias Springer b.setInsertionPoint(padOp); 4177124386SMatthias Springer Location loc = padOp.getLoc(); 4277124386SMatthias Springer 4377124386SMatthias Springer // Non-constant padding not supported. 4477124386SMatthias Springer Value constantPadding = padOp.getConstantPaddingValue(); 4577124386SMatthias Springer if (!constantPadding) 4677124386SMatthias Springer return failure(); 4777124386SMatthias Springer 4877124386SMatthias Springer SmallVector<OpFoldResult> newMixedLow, newMixedHigh; 4977124386SMatthias Springer for (OpFoldResult ofr : padOp.getMixedLowPad()) { 5077124386SMatthias Springer auto ub = makeIndependent(b, loc, ofr, independencies); 5177124386SMatthias Springer if (failed(ub)) 5277124386SMatthias Springer return failure(); 5377124386SMatthias Springer newMixedLow.push_back(*ub); 5477124386SMatthias Springer } 5577124386SMatthias Springer for (OpFoldResult ofr : padOp.getMixedHighPad()) { 5677124386SMatthias Springer auto ub = makeIndependent(b, loc, ofr, independencies); 5777124386SMatthias Springer if (failed(ub)) 5877124386SMatthias Springer return failure(); 5977124386SMatthias Springer newMixedHigh.push_back(*ub); 6077124386SMatthias Springer } 6177124386SMatthias Springer 6277124386SMatthias Springer // Return existing tensor::PadOp if nothing has changed. 6377124386SMatthias Springer if (llvm::equal(padOp.getMixedLowPad(), newMixedLow) && 6477124386SMatthias Springer llvm::equal(padOp.getMixedHighPad(), newMixedHigh)) 6577124386SMatthias Springer return padOp.getResult(); 6677124386SMatthias Springer 6777124386SMatthias Springer // Create a new tensor::PadOp. 6877124386SMatthias Springer auto newPadOp = b.create<PadOp>( 6977124386SMatthias Springer loc, padOp.getResultType(), padOp.getSource(), newMixedLow, newMixedHigh, 7077124386SMatthias Springer constantPadding, padOp.getNofold(), /*attrs=*/ArrayRef<NamedAttribute>{}); 7177124386SMatthias Springer 7277124386SMatthias Springer // Create a tensor::ExtractSliceOp. 7377124386SMatthias Springer // Reify the result sizes of the old tensor::PadOp. 7477124386SMatthias Springer ReifiedRankedShapedTypeDims reifiedSizes; 7577124386SMatthias Springer ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = 7677124386SMatthias Springer dyn_cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation()); 7777124386SMatthias Springer if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedSizes))) 7877124386SMatthias Springer return failure(); 7977124386SMatthias Springer SmallVector<OpFoldResult> offsets, sizes, strides; 8077124386SMatthias Springer for (int64_t i = 0, e = padOp.getResultType().getRank(); i < e; ++i) { 8177124386SMatthias Springer // offset = ub(low_padding) - low_padding 8277124386SMatthias Springer OpFoldResult prevLow = padOp.getMixedLowPad()[i]; 83*fecf1397SKazu Hirata if (isa<Attribute>(prevLow)) { 8477124386SMatthias Springer offsets.push_back(b.getIndexAttr(0)); 8577124386SMatthias Springer } else { 8677124386SMatthias Springer offsets.push_back( 8777124386SMatthias Springer b.create<affine::AffineApplyOp>( 8877124386SMatthias Springer loc, b.getAffineDimExpr(0) - b.getAffineDimExpr(1), 89*fecf1397SKazu Hirata std::initializer_list<Value>{cast<Value>(newMixedLow[i]), 90*fecf1397SKazu Hirata cast<Value>(prevLow)}) 9177124386SMatthias Springer .getResult()); 9277124386SMatthias Springer } 9377124386SMatthias Springer // size = reified result size 9477124386SMatthias Springer if (!padOp.getResultType().isDynamicDim(i)) { 9577124386SMatthias Springer sizes.push_back(b.getIndexAttr(padOp.getResultType().getDimSize(i))); 9677124386SMatthias Springer } else { 9777124386SMatthias Springer sizes.push_back(reifiedSizes[0][i]); 9877124386SMatthias Springer } 9977124386SMatthias Springer // stride = 1 10077124386SMatthias Springer strides.push_back(b.getIndexAttr(1)); 10177124386SMatthias Springer } 10277124386SMatthias Springer 10377124386SMatthias Springer return b.create<ExtractSliceOp>(loc, newPadOp, offsets, sizes, strides) 10477124386SMatthias Springer .getResult(); 10577124386SMatthias Springer } 10677124386SMatthias Springer 10777124386SMatthias Springer FailureOr<Value> tensor::buildIndependentOp(OpBuilder &b, 10877124386SMatthias Springer tensor::EmptyOp emptyOp, 10977124386SMatthias Springer ValueRange independencies) { 11077124386SMatthias Springer OpBuilder::InsertionGuard g(b); 11177124386SMatthias Springer b.setInsertionPoint(emptyOp); 11277124386SMatthias Springer Location loc = emptyOp.getLoc(); 11377124386SMatthias Springer 11477124386SMatthias Springer SmallVector<OpFoldResult> newSizes; 11577124386SMatthias Springer for (OpFoldResult ofr : emptyOp.getMixedSizes()) { 11677124386SMatthias Springer auto ub = makeIndependent(b, loc, ofr, independencies); 11777124386SMatthias Springer if (failed(ub)) 11877124386SMatthias Springer return failure(); 11977124386SMatthias Springer newSizes.push_back(*ub); 12077124386SMatthias Springer } 12177124386SMatthias Springer 12277124386SMatthias Springer // Return existing tensor::EmptyOp if nothing has changed. 12377124386SMatthias Springer if (llvm::equal(emptyOp.getMixedSizes(), newSizes)) 12477124386SMatthias Springer return emptyOp.getResult(); 12577124386SMatthias Springer 12677124386SMatthias Springer // Create a new tensor::EmptyOp. 12777124386SMatthias Springer Value newEmptyOp = 12877124386SMatthias Springer b.create<EmptyOp>(loc, newSizes, emptyOp.getType().getElementType()); 12977124386SMatthias Springer 13077124386SMatthias Springer // Create a tensor::ExtractSliceOp. 13177124386SMatthias Springer SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0)); 13277124386SMatthias Springer SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1)); 13377124386SMatthias Springer return b 13477124386SMatthias Springer .create<ExtractSliceOp>(loc, newEmptyOp, offsets, emptyOp.getMixedSizes(), 13577124386SMatthias Springer strides) 13677124386SMatthias Springer .getResult(); 13777124386SMatthias Springer } 138