xref: /llvm-project/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp (revision 051612c0180e4e5a9ba750a994a91d2c1b05b00c)
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Tensor/IR/Tensor.h"
12 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
13 
14 using namespace mlir;
15 
16 namespace mlir {
17 namespace tensor {
18 namespace {
19 
20 struct CastOpInterface
21     : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
22   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
23                                        ValueBoundsConstraintSet &cstr) const {
24     auto castOp = cast<CastOp>(op);
25     assert(value == castOp.getResult() && "invalid value");
26 
27     if (llvm::isa<RankedTensorType>(castOp.getResult().getType()) &&
28         llvm::isa<RankedTensorType>(castOp.getSource().getType())) {
29       cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
30     }
31   }
32 };
33 
34 struct DimOpInterface
35     : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
36   void populateBoundsForIndexValue(Operation *op, Value value,
37                                    ValueBoundsConstraintSet &cstr) const {
38     auto dimOp = cast<DimOp>(op);
39     assert(value == dimOp.getResult() && "invalid value");
40 
41     cstr.bound(value) >= 0;
42     auto constIndex = dimOp.getConstantIndex();
43     if (!constIndex.has_value())
44       return;
45     cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
46   }
47 };
48 
49 struct EmptyOpInterface
50     : public ValueBoundsOpInterface::ExternalModel<EmptyOpInterface, EmptyOp> {
51   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
52                                        ValueBoundsConstraintSet &cstr) const {
53     auto emptyOp = cast<EmptyOp>(op);
54     assert(value == emptyOp.getResult() && "invalid value");
55 
56     cstr.bound(value)[dim] == emptyOp.getMixedSizes()[dim];
57   }
58 };
59 
60 struct ExtractSliceOpInterface
61     : public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
62                                                    ExtractSliceOp> {
63   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
64                                        ValueBoundsConstraintSet &cstr) const {
65     auto extractSliceOp = cast<ExtractSliceOp>(op);
66     assert(value == extractSliceOp.getResult() && "invalid value");
67 
68     llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims();
69     int64_t ctr = -1;
70     for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) {
71       // Skip over rank-reduced dimensions.
72       if (!dropped.test(i))
73         ++ctr;
74       if (ctr == dim) {
75         cstr.bound(value)[dim] == extractSliceOp.getMixedSizes()[i];
76         return;
77       }
78     }
79     llvm_unreachable("could not find non-rank-reduced dim");
80   }
81 };
82 
83 struct PadOpInterface
84     : public ValueBoundsOpInterface::ExternalModel<PadOpInterface, PadOp> {
85   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
86                                        ValueBoundsConstraintSet &cstr) const {
87     auto padOp = cast<PadOp>(op);
88     assert(value == padOp.getResult() && "invalid value");
89 
90     AffineExpr srcSize = cstr.getExpr(padOp.getSource(), dim);
91     AffineExpr lowPad = cstr.getExpr(padOp.getMixedLowPad()[dim]);
92     AffineExpr highPad = cstr.getExpr(padOp.getMixedHighPad()[dim]);
93     cstr.bound(value)[dim] == srcSize + lowPad + highPad;
94   }
95 };
96 
97 struct RankOpInterface
98     : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
99   void populateBoundsForIndexValue(Operation *op, Value value,
100                                    ValueBoundsConstraintSet &cstr) const {
101     auto rankOp = cast<RankOp>(op);
102     assert(value == rankOp.getResult() && "invalid value");
103 
104     auto tensorType =
105         llvm::dyn_cast<RankedTensorType>(rankOp.getTensor().getType());
106     if (!tensorType)
107       return;
108     cstr.bound(value) == tensorType.getRank();
109   }
110 };
111 
112 } // namespace
113 } // namespace tensor
114 } // namespace mlir
115 
116 void mlir::tensor::registerValueBoundsOpInterfaceExternalModels(
117     DialectRegistry &registry) {
118   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
119     tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
120     tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
121     tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
122     tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
123         *ctx);
124     tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
125     tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);
126     // Note: ValueBoundsOpInterface implementation is not required for ops that
127     // implement `DestinationStyleOpInterface` (for querying shaped OpResults).
128   });
129 }
130