xref: /llvm-project/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp (revision 051612c0180e4e5a9ba750a994a91d2c1b05b00c)
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
13 
14 using namespace mlir;
15 
16 namespace mlir {
17 namespace memref {
18 namespace {
19 
20 template <typename OpTy>
21 struct AllocOpInterface
22     : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>,
23                                                    OpTy> {
24   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
25                                        ValueBoundsConstraintSet &cstr) const {
26     auto allocOp = cast<OpTy>(op);
27     assert(value == allocOp.getResult() && "invalid value");
28 
29     cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim];
30   }
31 };
32 
33 struct CastOpInterface
34     : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
35   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
36                                        ValueBoundsConstraintSet &cstr) const {
37     auto castOp = cast<CastOp>(op);
38     assert(value == castOp.getResult() && "invalid value");
39 
40     if (llvm::isa<MemRefType>(castOp.getResult().getType()) &&
41         llvm::isa<MemRefType>(castOp.getSource().getType())) {
42       cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
43     }
44   }
45 };
46 
47 struct DimOpInterface
48     : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
49   void populateBoundsForIndexValue(Operation *op, Value value,
50                                    ValueBoundsConstraintSet &cstr) const {
51     auto dimOp = cast<DimOp>(op);
52     assert(value == dimOp.getResult() && "invalid value");
53 
54     cstr.bound(value) >= 0;
55     auto constIndex = dimOp.getConstantIndex();
56     if (!constIndex.has_value())
57       return;
58     cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
59   }
60 };
61 
62 struct GetGlobalOpInterface
63     : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
64                                                    GetGlobalOp> {
65   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
66                                        ValueBoundsConstraintSet &cstr) const {
67     auto getGlobalOp = cast<GetGlobalOp>(op);
68     assert(value == getGlobalOp.getResult() && "invalid value");
69 
70     auto type = getGlobalOp.getType();
71     assert(!type.isDynamicDim(dim) && "expected static dim");
72     cstr.bound(value)[dim] == type.getDimSize(dim);
73   }
74 };
75 
76 struct RankOpInterface
77     : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
78   void populateBoundsForIndexValue(Operation *op, Value value,
79                                    ValueBoundsConstraintSet &cstr) const {
80     auto rankOp = cast<RankOp>(op);
81     assert(value == rankOp.getResult() && "invalid value");
82 
83     auto memrefType = llvm::dyn_cast<MemRefType>(rankOp.getMemref().getType());
84     if (!memrefType)
85       return;
86     cstr.bound(value) == memrefType.getRank();
87   }
88 };
89 
90 struct SubViewOpInterface
91     : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
92                                                    SubViewOp> {
93   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
94                                        ValueBoundsConstraintSet &cstr) const {
95     auto subViewOp = cast<SubViewOp>(op);
96     assert(value == subViewOp.getResult() && "invalid value");
97 
98     llvm::SmallBitVector dropped = subViewOp.getDroppedDims();
99     int64_t ctr = -1;
100     for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) {
101       // Skip over rank-reduced dimensions.
102       if (!dropped.test(i))
103         ++ctr;
104       if (ctr == dim) {
105         cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i];
106         return;
107       }
108     }
109     llvm_unreachable("could not find non-rank-reduced dim");
110   }
111 };
112 
113 } // namespace
114 } // namespace memref
115 } // namespace mlir
116 
117 void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
118     DialectRegistry &registry) {
119   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
120     memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>(
121         *ctx);
122     memref::AllocaOp::attachInterface<
123         memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
124     memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
125     memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
126     memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
127     memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
128     memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
129   });
130 }
131