xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp (revision fecf1397e32a89a928ebeeab07bfc7e38a318827)
177124386SMatthias Springer //===- IndependenceTransforms.cpp - Make ops independent of values --------===//
277124386SMatthias Springer //
377124386SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
477124386SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
577124386SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
677124386SMatthias Springer //
777124386SMatthias Springer //===----------------------------------------------------------------------===//
877124386SMatthias Springer 
977124386SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1077124386SMatthias Springer 
1177124386SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
1277124386SMatthias Springer #include "mlir/Dialect/Affine/Transforms/Transforms.h"
1377124386SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1477124386SMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h"
1577124386SMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h"
1677124386SMatthias Springer 
1777124386SMatthias Springer using namespace mlir;
1877124386SMatthias Springer using namespace mlir::tensor;
1977124386SMatthias Springer 
2077124386SMatthias Springer /// Make the given OpFoldResult independent of all independencies.
2177124386SMatthias Springer static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
2277124386SMatthias Springer                                                OpFoldResult ofr,
2377124386SMatthias Springer                                                ValueRange independencies) {
24*fecf1397SKazu Hirata   if (isa<Attribute>(ofr))
2577124386SMatthias Springer     return ofr;
26*fecf1397SKazu Hirata   Value value = cast<Value>(ofr);
2777124386SMatthias Springer   AffineMap boundMap;
2877124386SMatthias Springer   ValueDimList mapOperands;
2977124386SMatthias Springer   if (failed(ValueBoundsConstraintSet::computeIndependentBound(
3077124386SMatthias Springer           boundMap, mapOperands, presburger::BoundType::UB, value,
3140dd3aa9SMatthias Springer           independencies,
3240dd3aa9SMatthias Springer           /*closedUB=*/true)))
3377124386SMatthias Springer     return failure();
3477124386SMatthias Springer   return mlir::affine::materializeComputedBound(b, loc, boundMap, mapOperands);
3577124386SMatthias Springer }
3677124386SMatthias Springer 
3777124386SMatthias Springer FailureOr<Value> tensor::buildIndependentOp(OpBuilder &b, tensor::PadOp padOp,
3877124386SMatthias Springer                                             ValueRange independencies) {
3977124386SMatthias Springer   OpBuilder::InsertionGuard g(b);
4077124386SMatthias Springer   b.setInsertionPoint(padOp);
4177124386SMatthias Springer   Location loc = padOp.getLoc();
4277124386SMatthias Springer 
4377124386SMatthias Springer   // Non-constant padding not supported.
4477124386SMatthias Springer   Value constantPadding = padOp.getConstantPaddingValue();
4577124386SMatthias Springer   if (!constantPadding)
4677124386SMatthias Springer     return failure();
4777124386SMatthias Springer 
4877124386SMatthias Springer   SmallVector<OpFoldResult> newMixedLow, newMixedHigh;
4977124386SMatthias Springer   for (OpFoldResult ofr : padOp.getMixedLowPad()) {
5077124386SMatthias Springer     auto ub = makeIndependent(b, loc, ofr, independencies);
5177124386SMatthias Springer     if (failed(ub))
5277124386SMatthias Springer       return failure();
5377124386SMatthias Springer     newMixedLow.push_back(*ub);
5477124386SMatthias Springer   }
5577124386SMatthias Springer   for (OpFoldResult ofr : padOp.getMixedHighPad()) {
5677124386SMatthias Springer     auto ub = makeIndependent(b, loc, ofr, independencies);
5777124386SMatthias Springer     if (failed(ub))
5877124386SMatthias Springer       return failure();
5977124386SMatthias Springer     newMixedHigh.push_back(*ub);
6077124386SMatthias Springer   }
6177124386SMatthias Springer 
6277124386SMatthias Springer   // Return existing tensor::PadOp if nothing has changed.
6377124386SMatthias Springer   if (llvm::equal(padOp.getMixedLowPad(), newMixedLow) &&
6477124386SMatthias Springer       llvm::equal(padOp.getMixedHighPad(), newMixedHigh))
6577124386SMatthias Springer     return padOp.getResult();
6677124386SMatthias Springer 
6777124386SMatthias Springer   // Create a new tensor::PadOp.
6877124386SMatthias Springer   auto newPadOp = b.create<PadOp>(
6977124386SMatthias Springer       loc, padOp.getResultType(), padOp.getSource(), newMixedLow, newMixedHigh,
7077124386SMatthias Springer       constantPadding, padOp.getNofold(), /*attrs=*/ArrayRef<NamedAttribute>{});
7177124386SMatthias Springer 
7277124386SMatthias Springer   // Create a tensor::ExtractSliceOp.
7377124386SMatthias Springer   // Reify the result sizes of the old tensor::PadOp.
7477124386SMatthias Springer   ReifiedRankedShapedTypeDims reifiedSizes;
7577124386SMatthias Springer   ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
7677124386SMatthias Springer       dyn_cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
7777124386SMatthias Springer   if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedSizes)))
7877124386SMatthias Springer     return failure();
7977124386SMatthias Springer   SmallVector<OpFoldResult> offsets, sizes, strides;
8077124386SMatthias Springer   for (int64_t i = 0, e = padOp.getResultType().getRank(); i < e; ++i) {
8177124386SMatthias Springer     // offset = ub(low_padding) - low_padding
8277124386SMatthias Springer     OpFoldResult prevLow = padOp.getMixedLowPad()[i];
83*fecf1397SKazu Hirata     if (isa<Attribute>(prevLow)) {
8477124386SMatthias Springer       offsets.push_back(b.getIndexAttr(0));
8577124386SMatthias Springer     } else {
8677124386SMatthias Springer       offsets.push_back(
8777124386SMatthias Springer           b.create<affine::AffineApplyOp>(
8877124386SMatthias Springer                loc, b.getAffineDimExpr(0) - b.getAffineDimExpr(1),
89*fecf1397SKazu Hirata                std::initializer_list<Value>{cast<Value>(newMixedLow[i]),
90*fecf1397SKazu Hirata                                             cast<Value>(prevLow)})
9177124386SMatthias Springer               .getResult());
9277124386SMatthias Springer     }
9377124386SMatthias Springer     // size = reified result size
9477124386SMatthias Springer     if (!padOp.getResultType().isDynamicDim(i)) {
9577124386SMatthias Springer       sizes.push_back(b.getIndexAttr(padOp.getResultType().getDimSize(i)));
9677124386SMatthias Springer     } else {
9777124386SMatthias Springer       sizes.push_back(reifiedSizes[0][i]);
9877124386SMatthias Springer     }
9977124386SMatthias Springer     // stride = 1
10077124386SMatthias Springer     strides.push_back(b.getIndexAttr(1));
10177124386SMatthias Springer   }
10277124386SMatthias Springer 
10377124386SMatthias Springer   return b.create<ExtractSliceOp>(loc, newPadOp, offsets, sizes, strides)
10477124386SMatthias Springer       .getResult();
10577124386SMatthias Springer }
10677124386SMatthias Springer 
10777124386SMatthias Springer FailureOr<Value> tensor::buildIndependentOp(OpBuilder &b,
10877124386SMatthias Springer                                             tensor::EmptyOp emptyOp,
10977124386SMatthias Springer                                             ValueRange independencies) {
11077124386SMatthias Springer   OpBuilder::InsertionGuard g(b);
11177124386SMatthias Springer   b.setInsertionPoint(emptyOp);
11277124386SMatthias Springer   Location loc = emptyOp.getLoc();
11377124386SMatthias Springer 
11477124386SMatthias Springer   SmallVector<OpFoldResult> newSizes;
11577124386SMatthias Springer   for (OpFoldResult ofr : emptyOp.getMixedSizes()) {
11677124386SMatthias Springer     auto ub = makeIndependent(b, loc, ofr, independencies);
11777124386SMatthias Springer     if (failed(ub))
11877124386SMatthias Springer       return failure();
11977124386SMatthias Springer     newSizes.push_back(*ub);
12077124386SMatthias Springer   }
12177124386SMatthias Springer 
12277124386SMatthias Springer   // Return existing tensor::EmptyOp if nothing has changed.
12377124386SMatthias Springer   if (llvm::equal(emptyOp.getMixedSizes(), newSizes))
12477124386SMatthias Springer     return emptyOp.getResult();
12577124386SMatthias Springer 
12677124386SMatthias Springer   // Create a new tensor::EmptyOp.
12777124386SMatthias Springer   Value newEmptyOp =
12877124386SMatthias Springer       b.create<EmptyOp>(loc, newSizes, emptyOp.getType().getElementType());
12977124386SMatthias Springer 
13077124386SMatthias Springer   // Create a tensor::ExtractSliceOp.
13177124386SMatthias Springer   SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0));
13277124386SMatthias Springer   SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1));
13377124386SMatthias Springer   return b
13477124386SMatthias Springer       .create<ExtractSliceOp>(loc, newEmptyOp, offsets, emptyOp.getMixedSizes(),
13577124386SMatthias Springer                               strides)
13677124386SMatthias Springer       .getResult();
13777124386SMatthias Springer }
138