xref: /llvm-project/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td (revision a4c3683b665c6ac875b4821f5c6a881fdf5fef70)
//===- DestinationStyleOpInterface.td ----------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DESTINATIONSTYLEOPINTERFACE
10#define MLIR_DESTINATIONSTYLEOPINTERFACE
11
12include "mlir/IR/OpBase.td"
13
def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
  let description = [{
    Ops that are in destination style have designated "init" operands, which act
    as initial tensor values for the results of the operation or the init
    buffers to which the results of the op will be written.

    Init operands must be tensors or memrefs. Input operands can have any type.
    All non-init operands are DPS inputs.

    The init operands of this op are specified by the MutableOperandRange that
    the `getDpsInitsMutable` interface method returns. This implies that the
    init operands must be a consecutive range of operands.

    Each tensor init operand is tied to a corresponding tensor OpResult in a
    1-to-1 fashion. The i-th init tensor is tied to the i-th OpResult. The op
    may not have any additional OpResults. Init operands and their tied
    OpResults have the same type. Dynamic dimension sizes also match at runtime.

    Note: This implies that a destination style op without any tensor inits must
    not have any OpResults.

    An op has "pure tensor semantics" if it has at least one tensor operand and
    no buffer (memref) operands. It has "pure buffer semantics" if it has at
    least one buffer (memref) operand and no tensor operands.

    Destination-passing style abstraction makes certain transformations easier.
    For example, a tiling implementation can extract/insert slices from/into
    the destination of an op and use the resulting shaped value as an iter_arg
    in the surrounding loop structure. As another example, bufferization does
    not have to allocate new buffers for destinations (in case of in-place
    bufferization) and can directly reuse the existing destination buffer.

    Example of a destination style op: `%r = tensor.insert_slice %t into %d`,
    where `%t` is the single input and `%d` is the single init. `%d` is tied
    to `%r`.

    Example of an op that is not in destination style: `%r = tensor.pad %t`.
    This op is not in destination style because `%r` and `%t` have different
    shapes.
  }];

  let cppNamespace = "::mlir";

  let methods = [
    InterfaceMethod<
      /*desc=*/"Return the init operands as a mutable operand range.",
      /*retTy=*/"::mlir::MutableOperandRange",
      /*methodName=*/"getDpsInitsMutable",
      /*args=*/(ins)
    >,
  ];

  let extraSharedClassDeclaration = [{
    /// Return the DPS init operands (immutable view of `getDpsInitsMutable`).
    ::mlir::OperandRange getDpsInits() {
      return $_op.getDpsInitsMutable();
    }

    /// Return the number of DPS inits.
    int64_t getNumDpsInits() { return $_op.getDpsInits().size(); }

    /// Return the `i`-th DPS init.
    ::mlir::OpOperand *getDpsInitOperand(int64_t i) {
      assert(i >= 0 && i < $_op.getNumDpsInits() && "invalid index");
      return &$_op.getDpsInitsMutable()[i];
    }

    /// Set the `i`-th DPS init.
    void setDpsInitOperand(int64_t i, ::mlir::Value value) {
      assert(i >= 0 && i < $_op.getNumDpsInits() && "invalid index");
      $_op->setOperand($_op.getDpsInits().getBeginOperandIndex() + i, value);
    }

    /// Return the number of DPS inputs (all operands that are not inits).
    int64_t getNumDpsInputs() {
      return $_op->getNumOperands() - $_op.getNumDpsInits();
    }

    /// Return the DPS input operands, in operand order, skipping the init
    /// operand range.
    ::llvm::SmallVector<::mlir::OpOperand *> getDpsInputOperands() {
      ::llvm::SmallVector<::mlir::OpOperand *> result;
      int64_t numOperands = $_op->getNumOperands();
      ::mlir::OperandRange range = $_op.getDpsInits();
      // `getBeginOperandIndex` must not be queried on an empty range, so the
      // no-inits case is handled separately: every operand is an input.
      if (range.empty()) {
        result.reserve(numOperands);
        for (int64_t i = 0; i < numOperands; ++i)
          result.push_back(&$_op->getOpOperand(i));
        return result;
      }
      int64_t firstInitPos = range.getBeginOperandIndex();
      int64_t numInits = range.size();
      result.reserve(numOperands - numInits);
      // Inputs located before the init range.
      for (int64_t i = 0; i < firstInitPos; ++i)
        result.push_back(&$_op->getOpOperand(i));
      // Inputs located after the init range.
      for (int64_t i = firstInitPos + numInits; i < numOperands; ++i)
        result.push_back(&$_op->getOpOperand(i));
      return result;
    }

    /// Return the values of the DPS input operands, in the same order as
    /// `getDpsInputOperands`.
    ::llvm::SmallVector<::mlir::Value> getDpsInputs() {
      return ::llvm::to_vector(::llvm::map_range(
          $_op.getDpsInputOperands(),
          [](::mlir::OpOperand *o) { return o->get(); }));
    }

    /// Return the `i`-th DPS input operand. Inputs are numbered in operand
    /// order with the init operand range skipped, i.e. this returns the same
    /// operand as `getDpsInputOperands()[i]`.
    ::mlir::OpOperand *getDpsInputOperand(int64_t i) {
      ::mlir::OperandRange range = $_op.getDpsInits();
      int64_t numInits = range.size();
      assert(i >= 0 && i < $_op->getNumOperands() - numInits
             && "invalid index");
      if (range.empty())
        return &$_op->getOpOperand(i);
      int64_t firstInitPos = range.getBeginOperandIndex();
      // Inputs that come after the init range are shifted only by the number
      // of inits: the inputs before the range are already counted in `i`.
      return &$_op->getOpOperand(i < firstInitPos ? i : i + numInits);
    }

    /// Return "true" if `opOperand` is an "input".
    bool isDpsInput(::mlir::OpOperand *opOperand) {
      assert(opOperand->getOwner() == $_op && "invalid operand");
      return !$_op.isDpsInit(opOperand);
    }

    /// Return "true" if `opOperand` is an "init".
    bool isDpsInit(::mlir::OpOperand *opOperand) {
      assert(opOperand->getOwner() == $_op && "invalid operand");
      ::mlir::OperandRange range = $_op.getDpsInits();
      if (range.empty())
        return false;
      auto operandNumber = opOperand->getOperandNumber();
      return operandNumber >= range.getBeginOperandIndex()
          && operandNumber < range.getBeginOperandIndex() + range.size();
    }

    /// Return "true" if `opOperand` is a scalar value. A scalar is defined as
    /// neither a MemRef nor a tensor value.
    bool isScalar(::mlir::OpOperand *opOperand) {
      assert(opOperand->getOwner() == $_op && "invalid operand");
      return !::llvm::isa<::mlir::BaseMemRefType, ::mlir::TensorType>(
          opOperand->get().getType());
    }

    /// Return the OpResult that is tied to the given OpOperand. The i-th init
    /// operand is tied to the i-th result; `opOperand` must be a DPS init.
    ::mlir::OpResult getTiedOpResult(::mlir::OpOperand *opOperand) {
      assert(opOperand->getOwner() == $_op && "invalid operand");
      assert($_op.isDpsInit(opOperand) && "expected a DPS init operand");
      ::mlir::OperandRange range = $_op.getDpsInits();
      int64_t resultIndex =
          opOperand->getOperandNumber() - range.getBeginOperandIndex();
      assert(resultIndex >= 0 && resultIndex < $_op->getNumResults() &&
             "init operand has no tied result");
      return $_op->getResult(resultIndex);
    }

    /// Return the OpOperand that is tied to the given OpResult.
    ::mlir::OpOperand *getTiedOpOperand(::mlir::OpResult opResult) {
      assert(opResult.getDefiningOp() == $_op && "invalid opresult");
      return $_op.getDpsInitOperand(opResult.getResultNumber());
    }

    /// Return whether the op has pure buffer semantics. That is the case if the
    /// op has no tensor operands and at least one memref operand.
    bool hasPureBufferSemantics() {
      // No tensors.
      auto isTensor = [](::mlir::Value v) {
        return ::llvm::isa<::mlir::TensorType>(v.getType());
      };
      if (::llvm::any_of($_op->getOperands(), isTensor))
        return false;
      // At least one memref.
      auto isMemref = [](::mlir::Value v) {
        return ::llvm::isa<::mlir::BaseMemRefType>(v.getType());
      };
      return ::llvm::any_of($_op->getOperands(), isMemref);
    }

    /// Return whether the op has pure tensor semantics. That is the case if the
    /// op has no memref operands and at least one tensor operand.
    bool hasPureTensorSemantics() {
      // No memrefs.
      auto isMemref = [](::mlir::Value v) {
        return ::llvm::isa<::mlir::BaseMemRefType>(v.getType());
      };
      if (::llvm::any_of($_op->getOperands(), isMemref))
        return false;
      // At least one tensor.
      auto isTensor = [](::mlir::Value v) {
        return ::llvm::isa<::mlir::TensorType>(v.getType());
      };
      return ::llvm::any_of($_op->getOperands(), isTensor);
    }
  }];

  let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }];
  let verifyWithRegions = 1;
}
208
209
210#endif // MLIR_DESTINATIONSTYLEOPINTERFACE
211