xref: /llvm-project/mlir/test/lib/Dialect/Test/TestOpDefs.cpp (revision 3c64f86314fbf9a3cd578419f16e621a4de57eaa)
1e95e94adSJeff Niu //===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===//
2e95e94adSJeff Niu //
3e95e94adSJeff Niu // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e95e94adSJeff Niu // See https://llvm.org/LICENSE.txt for license information.
5e95e94adSJeff Niu // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e95e94adSJeff Niu //
7e95e94adSJeff Niu //===----------------------------------------------------------------------===//
8e95e94adSJeff Niu 
9e95e94adSJeff Niu #include "TestDialect.h"
10e95e94adSJeff Niu #include "TestOps.h"
11e95e94adSJeff Niu #include "mlir/Dialect/Tensor/IR/Tensor.h"
12e95e94adSJeff Niu #include "mlir/IR/Verifier.h"
13e95e94adSJeff Niu #include "mlir/Interfaces/FunctionImplementation.h"
14eeafc9daSChristian Ulmann #include "mlir/Interfaces/MemorySlotInterfaces.h"
15e95e94adSJeff Niu 
16e95e94adSJeff Niu using namespace mlir;
17e95e94adSJeff Niu using namespace test;
18e95e94adSJeff Niu 
19e95e94adSJeff Niu //===----------------------------------------------------------------------===//
20e95e94adSJeff Niu // TestBranchOp
21e95e94adSJeff Niu //===----------------------------------------------------------------------===//
22e95e94adSJeff Niu 
23e95e94adSJeff Niu SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
24e95e94adSJeff Niu   assert(index == 0 && "invalid successor index");
25e95e94adSJeff Niu   return SuccessorOperands(getTargetOperandsMutable());
26e95e94adSJeff Niu }
27e95e94adSJeff Niu 
28e95e94adSJeff Niu //===----------------------------------------------------------------------===//
29e95e94adSJeff Niu // TestProducingBranchOp
30e95e94adSJeff Niu //===----------------------------------------------------------------------===//
31e95e94adSJeff Niu 
32e95e94adSJeff Niu SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
33e95e94adSJeff Niu   assert(index <= 1 && "invalid successor index");
34e95e94adSJeff Niu   if (index == 1)
35e95e94adSJeff Niu     return SuccessorOperands(getFirstOperandsMutable());
36e95e94adSJeff Niu   return SuccessorOperands(getSecondOperandsMutable());
37e95e94adSJeff Niu }
38e95e94adSJeff Niu 
39e95e94adSJeff Niu //===----------------------------------------------------------------------===//
40e95e94adSJeff Niu // TestInternalBranchOp
41e95e94adSJeff Niu //===----------------------------------------------------------------------===//
42e95e94adSJeff Niu 
43e95e94adSJeff Niu SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
44e95e94adSJeff Niu   assert(index <= 1 && "invalid successor index");
45e95e94adSJeff Niu   if (index == 0)
46e95e94adSJeff Niu     return SuccessorOperands(0, getSuccessOperandsMutable());
47e95e94adSJeff Niu   return SuccessorOperands(1, getErrorOperandsMutable());
48e95e94adSJeff Niu }
49e95e94adSJeff Niu 
50e95e94adSJeff Niu //===----------------------------------------------------------------------===//
51e95e94adSJeff Niu // TestCallOp
52e95e94adSJeff Niu //===----------------------------------------------------------------------===//
53e95e94adSJeff Niu 
54e95e94adSJeff Niu LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
55e95e94adSJeff Niu   // Check that the callee attribute was specified.
56e95e94adSJeff Niu   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
57e95e94adSJeff Niu   if (!fnAttr)
58e95e94adSJeff Niu     return emitOpError("requires a 'callee' symbol reference attribute");
59e95e94adSJeff Niu   if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
60e95e94adSJeff Niu     return emitOpError() << "'" << fnAttr.getValue()
61e95e94adSJeff Niu                          << "' does not reference a valid function";
62e95e94adSJeff Niu   return success();
63e95e94adSJeff Niu }
64e95e94adSJeff Niu 
65e95e94adSJeff Niu //===----------------------------------------------------------------------===//
66e95e94adSJeff Niu // FoldToCallOp
67e95e94adSJeff Niu //===----------------------------------------------------------------------===//
68e95e94adSJeff Niu 
69e95e94adSJeff Niu namespace {
70e95e94adSJeff Niu struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
71e95e94adSJeff Niu   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
72e95e94adSJeff Niu 
73e95e94adSJeff Niu   LogicalResult matchAndRewrite(FoldToCallOp op,
74e95e94adSJeff Niu                                 PatternRewriter &rewriter) const override {
75e95e94adSJeff Niu     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
76e95e94adSJeff Niu                                               op.getCalleeAttr(), ValueRange());
77e95e94adSJeff Niu     return success();
78e95e94adSJeff Niu   }
79e95e94adSJeff Niu };
80e95e94adSJeff Niu } // namespace
81e95e94adSJeff Niu 
82e95e94adSJeff Niu void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
83e95e94adSJeff Niu                                                MLIRContext *context) {
84e95e94adSJeff Niu   results.add<FoldToCallOpPattern>(context);
85e95e94adSJeff Niu }
86e95e94adSJeff Niu 
87e95e94adSJeff Niu //===----------------------------------------------------------------------===//
88e95e94adSJeff Niu // IsolatedRegionOp - test parsing passthrough operands
89e95e94adSJeff Niu //===----------------------------------------------------------------------===//
90e95e94adSJeff Niu 
91e95e94adSJeff Niu ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
92e95e94adSJeff Niu                                     OperationState &result) {
93e95e94adSJeff Niu   // Parse the input operand.
94e95e94adSJeff Niu   OpAsmParser::Argument argInfo;
95e95e94adSJeff Niu   argInfo.type = parser.getBuilder().getIndexType();
96e95e94adSJeff Niu   if (parser.parseOperand(argInfo.ssaName) ||
97e95e94adSJeff Niu       parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
98e95e94adSJeff Niu     return failure();
99e95e94adSJeff Niu 
100e95e94adSJeff Niu   // Parse the body region, and reuse the operand info as the argument info.
101e95e94adSJeff Niu   Region *body = result.addRegion();
102e95e94adSJeff Niu   return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
103e95e94adSJeff Niu }
104e95e94adSJeff Niu 
105e95e94adSJeff Niu void IsolatedRegionOp::print(OpAsmPrinter &p) {
106e95e94adSJeff Niu   p << ' ';
107e95e94adSJeff Niu   p.printOperand(getOperand());
108e95e94adSJeff Niu   p.shadowRegionArgs(getRegion(), getOperand());
109e95e94adSJeff Niu   p << ' ';
110e95e94adSJeff Niu   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
111e95e94adSJeff Niu }
112e95e94adSJeff Niu 
113e95e94adSJeff Niu //===----------------------------------------------------------------------===//
114e95e94adSJeff Niu // SSACFGRegionOp
115e95e94adSJeff Niu //===----------------------------------------------------------------------===//
116e95e94adSJeff Niu 
117e95e94adSJeff Niu RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
118e95e94adSJeff Niu   return RegionKind::SSACFG;
119e95e94adSJeff Niu }
120e95e94adSJeff Niu 
121e95e94adSJeff Niu //===----------------------------------------------------------------------===//
122e95e94adSJeff Niu // GraphRegionOp
123e95e94adSJeff Niu //===----------------------------------------------------------------------===//
124e95e94adSJeff Niu 
125e95e94adSJeff Niu RegionKind GraphRegionOp::getRegionKind(unsigned index) {
126e95e94adSJeff Niu   return RegionKind::Graph;
127e95e94adSJeff Niu }
128e95e94adSJeff Niu 
129e95e94adSJeff Niu //===----------------------------------------------------------------------===//
130b084111cSThéo Degioanni // IsolatedGraphRegionOp
131b084111cSThéo Degioanni //===----------------------------------------------------------------------===//
132b084111cSThéo Degioanni 
133b084111cSThéo Degioanni RegionKind IsolatedGraphRegionOp::getRegionKind(unsigned index) {
134b084111cSThéo Degioanni   return RegionKind::Graph;
135b084111cSThéo Degioanni }
136b084111cSThéo Degioanni 
137b084111cSThéo Degioanni //===----------------------------------------------------------------------===//
138e95e94adSJeff Niu // AffineScopeOp
139e95e94adSJeff Niu //===----------------------------------------------------------------------===//
140e95e94adSJeff Niu 
141e95e94adSJeff Niu ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
142e95e94adSJeff Niu   // Parse the body region, and reuse the operand info as the argument info.
143e95e94adSJeff Niu   Region *body = result.addRegion();
144e95e94adSJeff Niu   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
145e95e94adSJeff Niu }
146e95e94adSJeff Niu 
147e95e94adSJeff Niu void AffineScopeOp::print(OpAsmPrinter &p) {
148e95e94adSJeff Niu   p << " ";
149e95e94adSJeff Niu   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
150e95e94adSJeff Niu }
151e95e94adSJeff Niu 
152e95e94adSJeff Niu //===----------------------------------------------------------------------===//
153e95e94adSJeff Niu // TestRemoveOpWithInnerOps
154e95e94adSJeff Niu //===----------------------------------------------------------------------===//
155e95e94adSJeff Niu 
156e95e94adSJeff Niu namespace {
157e95e94adSJeff Niu struct TestRemoveOpWithInnerOps
158e95e94adSJeff Niu     : public OpRewritePattern<TestOpWithRegionPattern> {
159e95e94adSJeff Niu   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
160e95e94adSJeff Niu 
161e95e94adSJeff Niu   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
162e95e94adSJeff Niu 
163e95e94adSJeff Niu   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
164e95e94adSJeff Niu                                 PatternRewriter &rewriter) const override {
165e95e94adSJeff Niu     rewriter.eraseOp(op);
166e95e94adSJeff Niu     return success();
167e95e94adSJeff Niu   }
168e95e94adSJeff Niu };
169e95e94adSJeff Niu } // namespace
170e95e94adSJeff Niu 
171e95e94adSJeff Niu //===----------------------------------------------------------------------===//
172e95e94adSJeff Niu // TestOpWithRegionPattern
173e95e94adSJeff Niu //===----------------------------------------------------------------------===//
174e95e94adSJeff Niu 
175e95e94adSJeff Niu void TestOpWithRegionPattern::getCanonicalizationPatterns(
176e95e94adSJeff Niu     RewritePatternSet &results, MLIRContext *context) {
177e95e94adSJeff Niu   results.add<TestRemoveOpWithInnerOps>(context);
178e95e94adSJeff Niu }
179e95e94adSJeff Niu 
180e95e94adSJeff Niu //===----------------------------------------------------------------------===//
181e95e94adSJeff Niu // TestOpWithRegionFold
182e95e94adSJeff Niu //===----------------------------------------------------------------------===//
183e95e94adSJeff Niu 
184e95e94adSJeff Niu OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
185e95e94adSJeff Niu   return getOperand();
186e95e94adSJeff Niu }
187e95e94adSJeff Niu 
188e95e94adSJeff Niu //===----------------------------------------------------------------------===//
189e95e94adSJeff Niu // TestOpConstant
190e95e94adSJeff Niu //===----------------------------------------------------------------------===//
191e95e94adSJeff Niu 
192e95e94adSJeff Niu OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
193e95e94adSJeff Niu 
194e95e94adSJeff Niu //===----------------------------------------------------------------------===//
195e95e94adSJeff Niu // TestOpWithVariadicResultsAndFolder
196e95e94adSJeff Niu //===----------------------------------------------------------------------===//
197e95e94adSJeff Niu 
198e95e94adSJeff Niu LogicalResult TestOpWithVariadicResultsAndFolder::fold(
199e95e94adSJeff Niu     FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
200e95e94adSJeff Niu   for (Value input : this->getOperands()) {
201e95e94adSJeff Niu     results.push_back(input);
202e95e94adSJeff Niu   }
203e95e94adSJeff Niu   return success();
204e95e94adSJeff Niu }
205e95e94adSJeff Niu 
206e95e94adSJeff Niu //===----------------------------------------------------------------------===//
207e95e94adSJeff Niu // TestOpInPlaceFold
208e95e94adSJeff Niu //===----------------------------------------------------------------------===//
209e95e94adSJeff Niu 
210e95e94adSJeff Niu OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
211e95e94adSJeff Niu   // Exercise the fact that an operation created with createOrFold should be
212e95e94adSJeff Niu   // allowed to access its parent block.
213e95e94adSJeff Niu   assert(getOperation()->getBlock() &&
214e95e94adSJeff Niu          "expected that operation is not unlinked");
215e95e94adSJeff Niu 
216e95e94adSJeff Niu   if (adaptor.getOp() && !getProperties().attr) {
217e95e94adSJeff Niu     // The folder adds "attr" if not present.
218e95e94adSJeff Niu     getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp());
219e95e94adSJeff Niu     return getResult();
220e95e94adSJeff Niu   }
221e95e94adSJeff Niu   return {};
222e95e94adSJeff Niu }
223e95e94adSJeff Niu 
224e95e94adSJeff Niu //===----------------------------------------------------------------------===//
225e95e94adSJeff Niu // OpWithInferTypeInterfaceOp
226e95e94adSJeff Niu //===----------------------------------------------------------------------===//
227e95e94adSJeff Niu 
228e95e94adSJeff Niu LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
229e95e94adSJeff Niu     MLIRContext *, std::optional<Location> location, ValueRange operands,
230e95e94adSJeff Niu     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
231e95e94adSJeff Niu     SmallVectorImpl<Type> &inferredReturnTypes) {
232e95e94adSJeff Niu   if (operands[0].getType() != operands[1].getType()) {
233e95e94adSJeff Niu     return emitOptionalError(location, "operand type mismatch ",
234e95e94adSJeff Niu                              operands[0].getType(), " vs ",
235e95e94adSJeff Niu                              operands[1].getType());
236e95e94adSJeff Niu   }
237e95e94adSJeff Niu   inferredReturnTypes.assign({operands[0].getType()});
238e95e94adSJeff Niu   return success();
239e95e94adSJeff Niu }
240e95e94adSJeff Niu 
241e95e94adSJeff Niu //===----------------------------------------------------------------------===//
242e95e94adSJeff Niu // OpWithShapedTypeInferTypeInterfaceOp
243e95e94adSJeff Niu //===----------------------------------------------------------------------===//
244e95e94adSJeff Niu 
245e95e94adSJeff Niu LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
246e95e94adSJeff Niu     MLIRContext *context, std::optional<Location> location,
247e95e94adSJeff Niu     ValueShapeRange operands, DictionaryAttr attributes,
248e95e94adSJeff Niu     OpaqueProperties properties, RegionRange regions,
249e95e94adSJeff Niu     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
250e95e94adSJeff Niu   // Create return type consisting of the last element of the first operand.
251e95e94adSJeff Niu   auto operandType = operands.front().getType();
252e95e94adSJeff Niu   auto sval = dyn_cast<ShapedType>(operandType);
253e95e94adSJeff Niu   if (!sval)
254e95e94adSJeff Niu     return emitOptionalError(location, "only shaped type operands allowed");
255e95e94adSJeff Niu   int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
256e95e94adSJeff Niu   auto type = IntegerType::get(context, 17);
257e95e94adSJeff Niu 
258e95e94adSJeff Niu   Attribute encoding;
259e95e94adSJeff Niu   if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
260e95e94adSJeff Niu     encoding = rankedTy.getEncoding();
261e95e94adSJeff Niu   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
262e95e94adSJeff Niu   return success();
263e95e94adSJeff Niu }
264e95e94adSJeff Niu 
265e95e94adSJeff Niu LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
266e95e94adSJeff Niu     OpBuilder &builder, ValueRange operands,
267e95e94adSJeff Niu     llvm::SmallVectorImpl<Value> &shapes) {
268e95e94adSJeff Niu   shapes = SmallVector<Value, 1>{
269e95e94adSJeff Niu       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
270e95e94adSJeff Niu   return success();
271e95e94adSJeff Niu }
272e95e94adSJeff Niu 
273e95e94adSJeff Niu //===----------------------------------------------------------------------===//
274e95e94adSJeff Niu // OpWithResultShapeInterfaceOp
275e95e94adSJeff Niu //===----------------------------------------------------------------------===//
276e95e94adSJeff Niu 
277e95e94adSJeff Niu LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
278e95e94adSJeff Niu     OpBuilder &builder, ValueRange operands,
279e95e94adSJeff Niu     llvm::SmallVectorImpl<Value> &shapes) {
280e95e94adSJeff Niu   Location loc = getLoc();
281e95e94adSJeff Niu   shapes.reserve(operands.size());
282e95e94adSJeff Niu   for (Value operand : llvm::reverse(operands)) {
283e95e94adSJeff Niu     auto rank = cast<RankedTensorType>(operand.getType()).getRank();
284e95e94adSJeff Niu     auto currShape = llvm::to_vector<4>(
285e95e94adSJeff Niu         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
286e95e94adSJeff Niu           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
287e95e94adSJeff Niu         }));
288e95e94adSJeff Niu     shapes.push_back(builder.create<tensor::FromElementsOp>(
289e95e94adSJeff Niu         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
290e95e94adSJeff Niu         currShape));
291e95e94adSJeff Niu   }
292e95e94adSJeff Niu   return success();
293e95e94adSJeff Niu }
294e95e94adSJeff Niu 
295e95e94adSJeff Niu //===----------------------------------------------------------------------===//
296e95e94adSJeff Niu // OpWithResultShapePerDimInterfaceOp
297e95e94adSJeff Niu //===----------------------------------------------------------------------===//
298e95e94adSJeff Niu 
299e95e94adSJeff Niu LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
300e95e94adSJeff Niu     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
301e95e94adSJeff Niu   Location loc = getLoc();
302e95e94adSJeff Niu   shapes.reserve(getNumOperands());
303e95e94adSJeff Niu   for (Value operand : llvm::reverse(getOperands())) {
304e95e94adSJeff Niu     auto tensorType = cast<RankedTensorType>(operand.getType());
305e95e94adSJeff Niu     auto currShape = llvm::to_vector<4>(llvm::map_range(
306e95e94adSJeff Niu         llvm::seq<int64_t>(0, tensorType.getRank()),
307e95e94adSJeff Niu         [&](int64_t dim) -> OpFoldResult {
308e95e94adSJeff Niu           return tensorType.isDynamicDim(dim)
309e95e94adSJeff Niu                      ? static_cast<OpFoldResult>(
310e95e94adSJeff Niu                            builder.createOrFold<tensor::DimOp>(loc, operand,
311e95e94adSJeff Niu                                                                dim))
312e95e94adSJeff Niu                      : static_cast<OpFoldResult>(
313e95e94adSJeff Niu                            builder.getIndexAttr(tensorType.getDimSize(dim)));
314e95e94adSJeff Niu         }));
315e95e94adSJeff Niu     shapes.emplace_back(std::move(currShape));
316e95e94adSJeff Niu   }
317e95e94adSJeff Niu   return success();
318e95e94adSJeff Niu }
319e95e94adSJeff Niu 
320e95e94adSJeff Niu //===----------------------------------------------------------------------===//
321e95e94adSJeff Niu // SideEffectOp
322e95e94adSJeff Niu //===----------------------------------------------------------------------===//
323e95e94adSJeff Niu 
324e95e94adSJeff Niu namespace {
325e95e94adSJeff Niu /// A test resource for side effects.
326e95e94adSJeff Niu struct TestResource : public SideEffects::Resource::Base<TestResource> {
327e95e94adSJeff Niu   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
328e95e94adSJeff Niu 
329e95e94adSJeff Niu   StringRef getName() final { return "<Test>"; }
330e95e94adSJeff Niu };
331e95e94adSJeff Niu } // namespace
332e95e94adSJeff Niu 
333e95e94adSJeff Niu void SideEffectOp::getEffects(
334e95e94adSJeff Niu     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
335e95e94adSJeff Niu   // Check for an effects attribute on the op instance.
336e95e94adSJeff Niu   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
337e95e94adSJeff Niu   if (!effectsAttr)
338e95e94adSJeff Niu     return;
339e95e94adSJeff Niu 
340e95e94adSJeff Niu   for (Attribute element : effectsAttr) {
341e95e94adSJeff Niu     DictionaryAttr effectElement = cast<DictionaryAttr>(element);
342e95e94adSJeff Niu 
343e95e94adSJeff Niu     // Get the specific memory effect.
344e95e94adSJeff Niu     MemoryEffects::Effect *effect =
345e95e94adSJeff Niu         StringSwitch<MemoryEffects::Effect *>(
346e95e94adSJeff Niu             cast<StringAttr>(effectElement.get("effect")).getValue())
347e95e94adSJeff Niu             .Case("allocate", MemoryEffects::Allocate::get())
348e95e94adSJeff Niu             .Case("free", MemoryEffects::Free::get())
349e95e94adSJeff Niu             .Case("read", MemoryEffects::Read::get())
350e95e94adSJeff Niu             .Case("write", MemoryEffects::Write::get());
351e95e94adSJeff Niu 
352e95e94adSJeff Niu     // Check for a non-default resource to use.
353e95e94adSJeff Niu     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
354e95e94adSJeff Niu     if (effectElement.get("test_resource"))
355e95e94adSJeff Niu       resource = TestResource::get();
356e95e94adSJeff Niu 
357e95e94adSJeff Niu     // Check for a result to affect.
358e95e94adSJeff Niu     if (effectElement.get("on_result"))
3592c1ae801Sdonald chen       effects.emplace_back(effect, getOperation()->getOpResults()[0], resource);
360e95e94adSJeff Niu     else if (Attribute ref = effectElement.get("on_reference"))
361e95e94adSJeff Niu       effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
362e95e94adSJeff Niu     else
363e95e94adSJeff Niu       effects.emplace_back(effect, resource);
364e95e94adSJeff Niu   }
365e95e94adSJeff Niu }
366e95e94adSJeff Niu 
367e95e94adSJeff Niu void SideEffectOp::getEffects(
368e95e94adSJeff Niu     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
369e95e94adSJeff Niu   testSideEffectOpGetEffect(getOperation(), effects);
370e95e94adSJeff Niu }
371e95e94adSJeff Niu 
3722c1ae801Sdonald chen void SideEffectWithRegionOp::getEffects(
3732c1ae801Sdonald chen     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3742c1ae801Sdonald chen   // Check for an effects attribute on the op instance.
3752c1ae801Sdonald chen   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
3762c1ae801Sdonald chen   if (!effectsAttr)
3772c1ae801Sdonald chen     return;
3782c1ae801Sdonald chen 
3792c1ae801Sdonald chen   for (Attribute element : effectsAttr) {
3802c1ae801Sdonald chen     DictionaryAttr effectElement = cast<DictionaryAttr>(element);
3812c1ae801Sdonald chen 
3822c1ae801Sdonald chen     // Get the specific memory effect.
3832c1ae801Sdonald chen     MemoryEffects::Effect *effect =
3842c1ae801Sdonald chen         StringSwitch<MemoryEffects::Effect *>(
3852c1ae801Sdonald chen             cast<StringAttr>(effectElement.get("effect")).getValue())
3862c1ae801Sdonald chen             .Case("allocate", MemoryEffects::Allocate::get())
3872c1ae801Sdonald chen             .Case("free", MemoryEffects::Free::get())
3882c1ae801Sdonald chen             .Case("read", MemoryEffects::Read::get())
3892c1ae801Sdonald chen             .Case("write", MemoryEffects::Write::get());
3902c1ae801Sdonald chen 
3912c1ae801Sdonald chen     // Check for a non-default resource to use.
3922c1ae801Sdonald chen     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
3932c1ae801Sdonald chen     if (effectElement.get("test_resource"))
3942c1ae801Sdonald chen       resource = TestResource::get();
3952c1ae801Sdonald chen 
3962c1ae801Sdonald chen     // Check for a result to affect.
3972c1ae801Sdonald chen     if (effectElement.get("on_result"))
3982c1ae801Sdonald chen       effects.emplace_back(effect, getOperation()->getOpResults()[0], resource);
3992c1ae801Sdonald chen     else if (effectElement.get("on_operand"))
4002c1ae801Sdonald chen       effects.emplace_back(effect, &getOperation()->getOpOperands()[0],
4012c1ae801Sdonald chen                            resource);
4022c1ae801Sdonald chen     else if (effectElement.get("on_argument"))
4032c1ae801Sdonald chen       effects.emplace_back(effect, getOperation()->getRegion(0).getArgument(0),
4042c1ae801Sdonald chen                            resource);
4052c1ae801Sdonald chen     else if (Attribute ref = effectElement.get("on_reference"))
4062c1ae801Sdonald chen       effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
4072c1ae801Sdonald chen     else
4082c1ae801Sdonald chen       effects.emplace_back(effect, resource);
4092c1ae801Sdonald chen   }
4102c1ae801Sdonald chen }
4112c1ae801Sdonald chen 
4122c1ae801Sdonald chen void SideEffectWithRegionOp::getEffects(
4132c1ae801Sdonald chen     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
4142c1ae801Sdonald chen   testSideEffectOpGetEffect(getOperation(), effects);
4152c1ae801Sdonald chen }
4162c1ae801Sdonald chen 
417e95e94adSJeff Niu //===----------------------------------------------------------------------===//
418e95e94adSJeff Niu // StringAttrPrettyNameOp
419e95e94adSJeff Niu //===----------------------------------------------------------------------===//
420e95e94adSJeff Niu 
421e95e94adSJeff Niu // This op has fancy handling of its SSA result name.
422e95e94adSJeff Niu ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
423e95e94adSJeff Niu                                           OperationState &result) {
424e95e94adSJeff Niu   // Add the result types.
425e95e94adSJeff Niu   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
426e95e94adSJeff Niu     result.addTypes(parser.getBuilder().getIntegerType(32));
427e95e94adSJeff Niu 
428e95e94adSJeff Niu   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
429e95e94adSJeff Niu     return failure();
430e95e94adSJeff Niu 
431e95e94adSJeff Niu   // If the attribute dictionary contains no 'names' attribute, infer it from
432e95e94adSJeff Niu   // the SSA name (if specified).
433e95e94adSJeff Niu   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
434e95e94adSJeff Niu     return attr.getName() == "names";
435e95e94adSJeff Niu   });
436e95e94adSJeff Niu 
437e95e94adSJeff Niu   // If there was no name specified, check to see if there was a useful name
438e95e94adSJeff Niu   // specified in the asm file.
439e95e94adSJeff Niu   if (hadNames || parser.getNumResults() == 0)
440e95e94adSJeff Niu     return success();
441e95e94adSJeff Niu 
442e95e94adSJeff Niu   SmallVector<StringRef, 4> names;
443e95e94adSJeff Niu   auto *context = result.getContext();
444e95e94adSJeff Niu 
445e95e94adSJeff Niu   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
446e95e94adSJeff Niu     auto resultName = parser.getResultName(i);
447e95e94adSJeff Niu     StringRef nameStr;
448e95e94adSJeff Niu     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
449e95e94adSJeff Niu       nameStr = resultName.first;
450e95e94adSJeff Niu 
451e95e94adSJeff Niu     names.push_back(nameStr);
452e95e94adSJeff Niu   }
453e95e94adSJeff Niu 
454e95e94adSJeff Niu   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
455e95e94adSJeff Niu   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
456e95e94adSJeff Niu   return success();
457e95e94adSJeff Niu }
458e95e94adSJeff Niu 
459e95e94adSJeff Niu void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
460e95e94adSJeff Niu   // Note that we only need to print the "name" attribute if the asmprinter
461e95e94adSJeff Niu   // result name disagrees with it.  This can happen in strange cases, e.g.
462e95e94adSJeff Niu   // when there are conflicts.
463e95e94adSJeff Niu   bool namesDisagree = getNames().size() != getNumResults();
464e95e94adSJeff Niu 
465e95e94adSJeff Niu   SmallString<32> resultNameStr;
466e95e94adSJeff Niu   for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
467e95e94adSJeff Niu     resultNameStr.clear();
468e95e94adSJeff Niu     llvm::raw_svector_ostream tmpStream(resultNameStr);
469e95e94adSJeff Niu     p.printOperand(getResult(i), tmpStream);
470e95e94adSJeff Niu 
471e95e94adSJeff Niu     auto expectedName = dyn_cast<StringAttr>(getNames()[i]);
472e95e94adSJeff Niu     if (!expectedName ||
473e95e94adSJeff Niu         tmpStream.str().drop_front() != expectedName.getValue()) {
474e95e94adSJeff Niu       namesDisagree = true;
475e95e94adSJeff Niu     }
476e95e94adSJeff Niu   }
477e95e94adSJeff Niu 
478e95e94adSJeff Niu   if (namesDisagree)
479e95e94adSJeff Niu     p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
480e95e94adSJeff Niu   else
481e95e94adSJeff Niu     p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
482e95e94adSJeff Niu }
483e95e94adSJeff Niu 
484e95e94adSJeff Niu // We set the SSA name in the asm syntax to the contents of the name
485e95e94adSJeff Niu // attribute.
486e95e94adSJeff Niu void StringAttrPrettyNameOp::getAsmResultNames(
487e95e94adSJeff Niu     function_ref<void(Value, StringRef)> setNameFn) {
488e95e94adSJeff Niu 
489e95e94adSJeff Niu   auto value = getNames();
490e95e94adSJeff Niu   for (size_t i = 0, e = value.size(); i != e; ++i)
491e95e94adSJeff Niu     if (auto str = dyn_cast<StringAttr>(value[i]))
492e95e94adSJeff Niu       if (!str.getValue().empty())
493e95e94adSJeff Niu         setNameFn(getResult(i), str.getValue());
494e95e94adSJeff Niu }
495e95e94adSJeff Niu 
496e95e94adSJeff Niu //===----------------------------------------------------------------------===//
497e95e94adSJeff Niu // CustomResultsNameOp
498e95e94adSJeff Niu //===----------------------------------------------------------------------===//
499e95e94adSJeff Niu 
500e95e94adSJeff Niu void CustomResultsNameOp::getAsmResultNames(
501e95e94adSJeff Niu     function_ref<void(Value, StringRef)> setNameFn) {
502e95e94adSJeff Niu   ArrayAttr value = getNames();
503e95e94adSJeff Niu   for (size_t i = 0, e = value.size(); i != e; ++i)
504e95e94adSJeff Niu     if (auto str = dyn_cast<StringAttr>(value[i]))
505e95e94adSJeff Niu       if (!str.empty())
506e95e94adSJeff Niu         setNameFn(getResult(i), str.getValue());
507e95e94adSJeff Niu }
508e95e94adSJeff Niu 
509e95e94adSJeff Niu //===----------------------------------------------------------------------===//
510*3c64f863SHongren Zheng // ResultNameFromTypeOp
511*3c64f863SHongren Zheng //===----------------------------------------------------------------------===//
512*3c64f863SHongren Zheng 
513*3c64f863SHongren Zheng void ResultNameFromTypeOp::getAsmResultNames(
514*3c64f863SHongren Zheng     function_ref<void(Value, StringRef)> setNameFn) {
515*3c64f863SHongren Zheng   auto result = getResult();
516*3c64f863SHongren Zheng   auto setResultNameFn = [&](::llvm::StringRef name) {
517*3c64f863SHongren Zheng     setNameFn(result, name);
518*3c64f863SHongren Zheng   };
519*3c64f863SHongren Zheng   auto opAsmTypeInterface =
520*3c64f863SHongren Zheng       ::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType());
521*3c64f863SHongren Zheng   opAsmTypeInterface.getAsmName(setResultNameFn);
522*3c64f863SHongren Zheng }
523*3c64f863SHongren Zheng 
524*3c64f863SHongren Zheng //===----------------------------------------------------------------------===//
525*3c64f863SHongren Zheng // BlockArgumentNameFromTypeOp
526*3c64f863SHongren Zheng //===----------------------------------------------------------------------===//
527*3c64f863SHongren Zheng 
528*3c64f863SHongren Zheng void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames(
529*3c64f863SHongren Zheng     ::mlir::Region &region, ::mlir::OpAsmSetValueNameFn setNameFn) {
530*3c64f863SHongren Zheng   for (auto &block : region) {
531*3c64f863SHongren Zheng     for (auto arg : block.getArguments()) {
532*3c64f863SHongren Zheng       if (auto opAsmTypeInterface =
533*3c64f863SHongren Zheng               ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) {
534*3c64f863SHongren Zheng         auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); };
535*3c64f863SHongren Zheng         opAsmTypeInterface.getAsmName(setArgNameFn);
536*3c64f863SHongren Zheng       }
537*3c64f863SHongren Zheng     }
538*3c64f863SHongren Zheng   }
539*3c64f863SHongren Zheng }
540*3c64f863SHongren Zheng 
541*3c64f863SHongren Zheng //===----------------------------------------------------------------------===//
542e95e94adSJeff Niu // ResultTypeWithTraitOp
543e95e94adSJeff Niu //===----------------------------------------------------------------------===//
544e95e94adSJeff Niu 
545e95e94adSJeff Niu LogicalResult ResultTypeWithTraitOp::verify() {
546e95e94adSJeff Niu   if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
547e95e94adSJeff Niu     return success();
548e95e94adSJeff Niu   return emitError("result type should have trait 'TestTypeTrait'");
549e95e94adSJeff Niu }
550e95e94adSJeff Niu 
551e95e94adSJeff Niu //===----------------------------------------------------------------------===//
552e95e94adSJeff Niu // AttrWithTraitOp
553e95e94adSJeff Niu //===----------------------------------------------------------------------===//
554e95e94adSJeff Niu 
555e95e94adSJeff Niu LogicalResult AttrWithTraitOp::verify() {
556e95e94adSJeff Niu   if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
557e95e94adSJeff Niu     return success();
558e95e94adSJeff Niu   return emitError("'attr' attribute should have trait 'TestAttrTrait'");
559e95e94adSJeff Niu }
560e95e94adSJeff Niu 
561e95e94adSJeff Niu //===----------------------------------------------------------------------===//
562e95e94adSJeff Niu // RegionIfOp
563e95e94adSJeff Niu //===----------------------------------------------------------------------===//
564e95e94adSJeff Niu 
565e95e94adSJeff Niu void RegionIfOp::print(OpAsmPrinter &p) {
566e95e94adSJeff Niu   p << " ";
567e95e94adSJeff Niu   p.printOperands(getOperands());
568e95e94adSJeff Niu   p << ": " << getOperandTypes();
569e95e94adSJeff Niu   p.printArrowTypeList(getResultTypes());
570e95e94adSJeff Niu   p << " then ";
571e95e94adSJeff Niu   p.printRegion(getThenRegion(),
572e95e94adSJeff Niu                 /*printEntryBlockArgs=*/true,
573e95e94adSJeff Niu                 /*printBlockTerminators=*/true);
574e95e94adSJeff Niu   p << " else ";
575e95e94adSJeff Niu   p.printRegion(getElseRegion(),
576e95e94adSJeff Niu                 /*printEntryBlockArgs=*/true,
577e95e94adSJeff Niu                 /*printBlockTerminators=*/true);
578e95e94adSJeff Niu   p << " join ";
579e95e94adSJeff Niu   p.printRegion(getJoinRegion(),
580e95e94adSJeff Niu                 /*printEntryBlockArgs=*/true,
581e95e94adSJeff Niu                 /*printBlockTerminators=*/true);
582e95e94adSJeff Niu }
583e95e94adSJeff Niu 
584e95e94adSJeff Niu ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
585e95e94adSJeff Niu   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
586e95e94adSJeff Niu   SmallVector<Type, 2> operandTypes;
587e95e94adSJeff Niu 
588e95e94adSJeff Niu   result.regions.reserve(3);
589e95e94adSJeff Niu   Region *thenRegion = result.addRegion();
590e95e94adSJeff Niu   Region *elseRegion = result.addRegion();
591e95e94adSJeff Niu   Region *joinRegion = result.addRegion();
592e95e94adSJeff Niu 
593e95e94adSJeff Niu   // Parse operand, type and arrow type lists.
594e95e94adSJeff Niu   if (parser.parseOperandList(operandInfos) ||
595e95e94adSJeff Niu       parser.parseColonTypeList(operandTypes) ||
596e95e94adSJeff Niu       parser.parseArrowTypeList(result.types))
597e95e94adSJeff Niu     return failure();
598e95e94adSJeff Niu 
599e95e94adSJeff Niu   // Parse all attached regions.
600e95e94adSJeff Niu   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
601e95e94adSJeff Niu       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
602e95e94adSJeff Niu       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
603e95e94adSJeff Niu     return failure();
604e95e94adSJeff Niu 
605e95e94adSJeff Niu   return parser.resolveOperands(operandInfos, operandTypes,
606e95e94adSJeff Niu                                 parser.getCurrentLocation(), result.operands);
607e95e94adSJeff Niu }
608e95e94adSJeff Niu 
609e95e94adSJeff Niu OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
610e95e94adSJeff Niu   assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
611e95e94adSJeff Niu          "invalid region index");
612e95e94adSJeff Niu   return getOperands();
613e95e94adSJeff Niu }
614e95e94adSJeff Niu 
615e95e94adSJeff Niu void RegionIfOp::getSuccessorRegions(
616e95e94adSJeff Niu     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
617e95e94adSJeff Niu   // We always branch to the join region.
618e95e94adSJeff Niu   if (!point.isParent()) {
619e95e94adSJeff Niu     if (point != getJoinRegion())
620e95e94adSJeff Niu       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
621e95e94adSJeff Niu     else
622e95e94adSJeff Niu       regions.push_back(RegionSuccessor(getResults()));
623e95e94adSJeff Niu     return;
624e95e94adSJeff Niu   }
625e95e94adSJeff Niu 
626e95e94adSJeff Niu   // The then and else regions are the entry regions of this op.
627e95e94adSJeff Niu   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
628e95e94adSJeff Niu   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
629e95e94adSJeff Niu }
630e95e94adSJeff Niu 
631e95e94adSJeff Niu void RegionIfOp::getRegionInvocationBounds(
632e95e94adSJeff Niu     ArrayRef<Attribute> operands,
633e95e94adSJeff Niu     SmallVectorImpl<InvocationBounds> &invocationBounds) {
634e95e94adSJeff Niu   // Each region is invoked at most once.
635e95e94adSJeff Niu   invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
636e95e94adSJeff Niu }
637e95e94adSJeff Niu 
638e95e94adSJeff Niu //===----------------------------------------------------------------------===//
639e95e94adSJeff Niu // AnyCondOp
640e95e94adSJeff Niu //===----------------------------------------------------------------------===//
641e95e94adSJeff Niu 
642e95e94adSJeff Niu void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
643e95e94adSJeff Niu                                     SmallVectorImpl<RegionSuccessor> &regions) {
644e95e94adSJeff Niu   // The parent op branches into the only region, and the region branches back
645e95e94adSJeff Niu   // to the parent op.
646e95e94adSJeff Niu   if (point.isParent())
647e95e94adSJeff Niu     regions.emplace_back(&getRegion());
648e95e94adSJeff Niu   else
649e95e94adSJeff Niu     regions.emplace_back(getResults());
650e95e94adSJeff Niu }
651e95e94adSJeff Niu 
652e95e94adSJeff Niu void AnyCondOp::getRegionInvocationBounds(
653e95e94adSJeff Niu     ArrayRef<Attribute> operands,
654e95e94adSJeff Niu     SmallVectorImpl<InvocationBounds> &invocationBounds) {
655e95e94adSJeff Niu   invocationBounds.emplace_back(1, 1);
656e95e94adSJeff Niu }
657e95e94adSJeff Niu 
658e95e94adSJeff Niu //===----------------------------------------------------------------------===//
659e95e94adSJeff Niu // SingleBlockImplicitTerminatorOp
660e95e94adSJeff Niu //===----------------------------------------------------------------------===//
661e95e94adSJeff Niu 
662e95e94adSJeff Niu /// Testing the correctness of some traits.
663e95e94adSJeff Niu static_assert(
664e95e94adSJeff Niu     llvm::is_detected<OpTrait::has_implicit_terminator_t,
665e95e94adSJeff Niu                       SingleBlockImplicitTerminatorOp>::value,
666e95e94adSJeff Niu     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
667e95e94adSJeff Niu static_assert(OpTrait::hasSingleBlockImplicitTerminator<
668e95e94adSJeff Niu                   SingleBlockImplicitTerminatorOp>::value,
669e95e94adSJeff Niu               "hasSingleBlockImplicitTerminator does not match "
670e95e94adSJeff Niu               "SingleBlockImplicitTerminatorOp");
671e95e94adSJeff Niu 
672e95e94adSJeff Niu //===----------------------------------------------------------------------===//
673e95e94adSJeff Niu // SingleNoTerminatorCustomAsmOp
674e95e94adSJeff Niu //===----------------------------------------------------------------------===//
675e95e94adSJeff Niu 
676e95e94adSJeff Niu ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
677e95e94adSJeff Niu                                                  OperationState &state) {
678e95e94adSJeff Niu   Region *body = state.addRegion();
679e95e94adSJeff Niu   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
680e95e94adSJeff Niu     return failure();
681e95e94adSJeff Niu   return success();
682e95e94adSJeff Niu }
683e95e94adSJeff Niu 
684e95e94adSJeff Niu void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
685e95e94adSJeff Niu   printer.printRegion(
686e95e94adSJeff Niu       getRegion(), /*printEntryBlockArgs=*/false,
687e95e94adSJeff Niu       // This op has a single block without terminators. But explicitly mark
688e95e94adSJeff Niu       // as not printing block terminators for testing.
689e95e94adSJeff Niu       /*printBlockTerminators=*/false);
690e95e94adSJeff Niu }
691e95e94adSJeff Niu 
692e95e94adSJeff Niu //===----------------------------------------------------------------------===//
693e95e94adSJeff Niu // TestVerifiersOp
694e95e94adSJeff Niu //===----------------------------------------------------------------------===//
695e95e94adSJeff Niu 
696e95e94adSJeff Niu LogicalResult TestVerifiersOp::verify() {
697e95e94adSJeff Niu   if (!getRegion().hasOneBlock())
698e95e94adSJeff Niu     return emitOpError("`hasOneBlock` trait hasn't been verified");
699e95e94adSJeff Niu 
700e95e94adSJeff Niu   Operation *definingOp = getInput().getDefiningOp();
701e95e94adSJeff Niu   if (definingOp && failed(mlir::verify(definingOp)))
702e95e94adSJeff Niu     return emitOpError("operand hasn't been verified");
703e95e94adSJeff Niu 
704e95e94adSJeff Niu   // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
705e95e94adSJeff Niu   // loop.
706e95e94adSJeff Niu   mlir::emitRemark(getLoc(), "success run of verifier");
707e95e94adSJeff Niu 
708e95e94adSJeff Niu   return success();
709e95e94adSJeff Niu }
710e95e94adSJeff Niu 
711e95e94adSJeff Niu LogicalResult TestVerifiersOp::verifyRegions() {
712e95e94adSJeff Niu   if (!getRegion().hasOneBlock())
713e95e94adSJeff Niu     return emitOpError("`hasOneBlock` trait hasn't been verified");
714e95e94adSJeff Niu 
715e95e94adSJeff Niu   for (Block &block : getRegion())
716e95e94adSJeff Niu     for (Operation &op : block)
717e95e94adSJeff Niu       if (failed(mlir::verify(&op)))
718e95e94adSJeff Niu         return emitOpError("nested op hasn't been verified");
719e95e94adSJeff Niu 
720e95e94adSJeff Niu   // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
721e95e94adSJeff Niu   // loop.
722e95e94adSJeff Niu   mlir::emitRemark(getLoc(), "success run of region verifier");
723e95e94adSJeff Niu 
724e95e94adSJeff Niu   return success();
725e95e94adSJeff Niu }
726e95e94adSJeff Niu 
727e95e94adSJeff Niu //===----------------------------------------------------------------------===//
728e95e94adSJeff Niu // Test InferIntRangeInterface
729e95e94adSJeff Niu //===----------------------------------------------------------------------===//
730e95e94adSJeff Niu 
731e95e94adSJeff Niu //===----------------------------------------------------------------------===//
732e95e94adSJeff Niu // TestWithBoundsOp
733e95e94adSJeff Niu 
734e95e94adSJeff Niu void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
735e95e94adSJeff Niu                                          SetIntRangeFn setResultRanges) {
736e95e94adSJeff Niu   setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
737e95e94adSJeff Niu }
738e95e94adSJeff Niu 
739e95e94adSJeff Niu //===----------------------------------------------------------------------===//
740e95e94adSJeff Niu // TestWithBoundsRegionOp
741e95e94adSJeff Niu 
742e95e94adSJeff Niu ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
743e95e94adSJeff Niu                                           OperationState &result) {
744e95e94adSJeff Niu   if (parser.parseOptionalAttrDict(result.attributes))
745e95e94adSJeff Niu     return failure();
746e95e94adSJeff Niu 
747e95e94adSJeff Niu   // Parse the input argument
748e95e94adSJeff Niu   OpAsmParser::Argument argInfo;
749acd10074SFelix Schneider   if (failed(parser.parseArgument(argInfo, true)))
750e95e94adSJeff Niu     return failure();
751e95e94adSJeff Niu 
752e95e94adSJeff Niu   // Parse the body region, and reuse the operand info as the argument info.
753e95e94adSJeff Niu   Region *body = result.addRegion();
754e95e94adSJeff Niu   return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
755e95e94adSJeff Niu }
756e95e94adSJeff Niu 
757e95e94adSJeff Niu void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
758e95e94adSJeff Niu   p.printOptionalAttrDict((*this)->getAttrs());
759e95e94adSJeff Niu   p << ' ';
760e95e94adSJeff Niu   p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
761acd10074SFelix Schneider                         /*omitType=*/false);
762e95e94adSJeff Niu   p << ' ';
763e95e94adSJeff Niu   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
764e95e94adSJeff Niu }
765e95e94adSJeff Niu 
766e95e94adSJeff Niu void TestWithBoundsRegionOp::inferResultRanges(
767e95e94adSJeff Niu     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
768e95e94adSJeff Niu   Value arg = getRegion().getArgument(0);
769e95e94adSJeff Niu   setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
770e95e94adSJeff Niu }
771e95e94adSJeff Niu 
772e95e94adSJeff Niu //===----------------------------------------------------------------------===//
773e95e94adSJeff Niu // TestIncrementOp
774e95e94adSJeff Niu 
775e95e94adSJeff Niu void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
776e95e94adSJeff Niu                                         SetIntRangeFn setResultRanges) {
777e95e94adSJeff Niu   const ConstantIntRanges &range = argRanges[0];
778e95e94adSJeff Niu   APInt one(range.umin().getBitWidth(), 1);
779e95e94adSJeff Niu   setResultRanges(getResult(),
780e95e94adSJeff Niu                   {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
781e95e94adSJeff Niu                    range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
782e95e94adSJeff Niu }
783e95e94adSJeff Niu 
784e95e94adSJeff Niu //===----------------------------------------------------------------------===//
785e95e94adSJeff Niu // TestReflectBoundsOp
786e95e94adSJeff Niu 
787e95e94adSJeff Niu void TestReflectBoundsOp::inferResultRanges(
788e95e94adSJeff Niu     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
789e95e94adSJeff Niu   const ConstantIntRanges &range = argRanges[0];
790e95e94adSJeff Niu   MLIRContext *ctx = getContext();
791e95e94adSJeff Niu   Builder b(ctx);
7924636b66dSFelix Schneider   Type sIntTy, uIntTy;
7934636b66dSFelix Schneider   // For plain `IntegerType`s, we can derive the appropriate signed and unsigned
7944636b66dSFelix Schneider   // Types for the Attributes.
795f54cdc5dSIvan Butygin   Type type = getElementTypeOrSelf(getType());
796f54cdc5dSIvan Butygin   if (auto intTy = llvm::dyn_cast<IntegerType>(type)) {
7974636b66dSFelix Schneider     unsigned bitwidth = intTy.getWidth();
7984636b66dSFelix Schneider     sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
7994636b66dSFelix Schneider     uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
8004636b66dSFelix Schneider   } else
801f54cdc5dSIvan Butygin     sIntTy = uIntTy = type;
8024636b66dSFelix Schneider 
8034636b66dSFelix Schneider   setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
8044636b66dSFelix Schneider   setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));
8054636b66dSFelix Schneider   setSminAttr(b.getIntegerAttr(sIntTy, range.smin()));
8064636b66dSFelix Schneider   setSmaxAttr(b.getIntegerAttr(sIntTy, range.smax()));
807e95e94adSJeff Niu   setResultRanges(getResult(), range);
808e95e94adSJeff Niu }
809e95e94adSJeff Niu 
810e95e94adSJeff Niu //===----------------------------------------------------------------------===//
811e95e94adSJeff Niu // ConversionFuncOp
812e95e94adSJeff Niu //===----------------------------------------------------------------------===//
813e95e94adSJeff Niu 
814e95e94adSJeff Niu ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
815e95e94adSJeff Niu                                     OperationState &result) {
816e95e94adSJeff Niu   auto buildFuncType =
817e95e94adSJeff Niu       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
818e95e94adSJeff Niu          function_interface_impl::VariadicFlag,
819e95e94adSJeff Niu          std::string &) { return builder.getFunctionType(argTypes, results); };
820e95e94adSJeff Niu 
821e95e94adSJeff Niu   return function_interface_impl::parseFunctionOp(
822e95e94adSJeff Niu       parser, result, /*allowVariadic=*/false,
823e95e94adSJeff Niu       getFunctionTypeAttrName(result.name), buildFuncType,
824e95e94adSJeff Niu       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
825e95e94adSJeff Niu }
826e95e94adSJeff Niu 
827e95e94adSJeff Niu void ConversionFuncOp::print(OpAsmPrinter &p) {
828e95e94adSJeff Niu   function_interface_impl::printFunctionOp(
829e95e94adSJeff Niu       p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
830e95e94adSJeff Niu       getArgAttrsAttrName(), getResAttrsAttrName());
831e95e94adSJeff Niu }
832e95e94adSJeff Niu 
833e95e94adSJeff Niu //===----------------------------------------------------------------------===//
834e95e94adSJeff Niu // ReifyBoundOp
835e95e94adSJeff Niu //===----------------------------------------------------------------------===//
836e95e94adSJeff Niu 
837e95e94adSJeff Niu mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
838e95e94adSJeff Niu   if (getType() == "EQ")
839e95e94adSJeff Niu     return mlir::presburger::BoundType::EQ;
840e95e94adSJeff Niu   if (getType() == "LB")
841e95e94adSJeff Niu     return mlir::presburger::BoundType::LB;
842e95e94adSJeff Niu   if (getType() == "UB")
843e95e94adSJeff Niu     return mlir::presburger::BoundType::UB;
844e95e94adSJeff Niu   llvm_unreachable("invalid bound type");
845e95e94adSJeff Niu }
846e95e94adSJeff Niu 
847e95e94adSJeff Niu LogicalResult ReifyBoundOp::verify() {
848e95e94adSJeff Niu   if (isa<ShapedType>(getVar().getType())) {
849e95e94adSJeff Niu     if (!getDim().has_value())
850e95e94adSJeff Niu       return emitOpError("expected 'dim' attribute for shaped type variable");
851e95e94adSJeff Niu   } else if (getVar().getType().isIndex()) {
852e95e94adSJeff Niu     if (getDim().has_value())
853e95e94adSJeff Niu       return emitOpError("unexpected 'dim' attribute for index variable");
854e95e94adSJeff Niu   } else {
855e95e94adSJeff Niu     return emitOpError("expected index-typed variable or shape type variable");
856e95e94adSJeff Niu   }
857e95e94adSJeff Niu   if (getConstant() && getScalable())
858e95e94adSJeff Niu     return emitOpError("'scalable' and 'constant' are mutually exlusive");
859e95e94adSJeff Niu   if (getScalable() != getVscaleMin().has_value())
860e95e94adSJeff Niu     return emitOpError("expected 'vscale_min' if and only if 'scalable'");
861e95e94adSJeff Niu   if (getScalable() != getVscaleMax().has_value())
862e95e94adSJeff Niu     return emitOpError("expected 'vscale_min' if and only if 'scalable'");
863e95e94adSJeff Niu   return success();
864e95e94adSJeff Niu }
865e95e94adSJeff Niu 
866e95e94adSJeff Niu ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
867e95e94adSJeff Niu   if (getDim().has_value())
868e95e94adSJeff Niu     return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
869e95e94adSJeff Niu   return ValueBoundsConstraintSet::Variable(getVar());
870e95e94adSJeff Niu }
871e95e94adSJeff Niu 
872e95e94adSJeff Niu //===----------------------------------------------------------------------===//
873e95e94adSJeff Niu // CompareOp
874e95e94adSJeff Niu //===----------------------------------------------------------------------===//
875e95e94adSJeff Niu 
876e95e94adSJeff Niu ValueBoundsConstraintSet::ComparisonOperator
877e95e94adSJeff Niu CompareOp::getComparisonOperator() {
878e95e94adSJeff Niu   if (getCmp() == "EQ")
879e95e94adSJeff Niu     return ValueBoundsConstraintSet::ComparisonOperator::EQ;
880e95e94adSJeff Niu   if (getCmp() == "LT")
881e95e94adSJeff Niu     return ValueBoundsConstraintSet::ComparisonOperator::LT;
882e95e94adSJeff Niu   if (getCmp() == "LE")
883e95e94adSJeff Niu     return ValueBoundsConstraintSet::ComparisonOperator::LE;
884e95e94adSJeff Niu   if (getCmp() == "GT")
885e95e94adSJeff Niu     return ValueBoundsConstraintSet::ComparisonOperator::GT;
886e95e94adSJeff Niu   if (getCmp() == "GE")
887e95e94adSJeff Niu     return ValueBoundsConstraintSet::ComparisonOperator::GE;
888e95e94adSJeff Niu   llvm_unreachable("invalid comparison operator");
889e95e94adSJeff Niu }
890e95e94adSJeff Niu 
891e95e94adSJeff Niu mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
892e95e94adSJeff Niu   if (!getLhsMap())
893e95e94adSJeff Niu     return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
894e95e94adSJeff Niu   SmallVector<Value> mapOperands(
895e95e94adSJeff Niu       getVarOperands().slice(0, getLhsMap()->getNumInputs()));
896e95e94adSJeff Niu   return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
897e95e94adSJeff Niu }
898e95e94adSJeff Niu 
899e95e94adSJeff Niu mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
900e95e94adSJeff Niu   int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
901e95e94adSJeff Niu   if (!getRhsMap())
902e95e94adSJeff Niu     return ValueBoundsConstraintSet::Variable(
903e95e94adSJeff Niu         getVarOperands()[rhsOperandsBegin]);
904e95e94adSJeff Niu   SmallVector<Value> mapOperands(
905e95e94adSJeff Niu       getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
906e95e94adSJeff Niu   return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
907e95e94adSJeff Niu }
908e95e94adSJeff Niu 
909e95e94adSJeff Niu LogicalResult CompareOp::verify() {
910e95e94adSJeff Niu   if (getCompose() && (getLhsMap() || getRhsMap()))
911e95e94adSJeff Niu     return emitOpError(
912e95e94adSJeff Niu         "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
913e95e94adSJeff Niu   int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
914e95e94adSJeff Niu   expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
915e95e94adSJeff Niu   if (getVarOperands().size() != size_t(expectedNumOperands))
916e95e94adSJeff Niu     return emitOpError("expected ")
917e95e94adSJeff Niu            << expectedNumOperands << " operands, but got "
918e95e94adSJeff Niu            << getVarOperands().size();
919e95e94adSJeff Niu   return success();
920e95e94adSJeff Niu }
921e95e94adSJeff Niu 
922e95e94adSJeff Niu //===----------------------------------------------------------------------===//
9234513050fSChristian Ulmann // TestOpInPlaceSelfFold
9244513050fSChristian Ulmann //===----------------------------------------------------------------------===//
9254513050fSChristian Ulmann 
9264513050fSChristian Ulmann OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
9274513050fSChristian Ulmann   if (!getFolded()) {
9284513050fSChristian Ulmann     // The folder adds the "folded" if not present.
9294513050fSChristian Ulmann     setFolded(true);
9304513050fSChristian Ulmann     return getResult();
9314513050fSChristian Ulmann   }
9324513050fSChristian Ulmann   return {};
9334513050fSChristian Ulmann }
9344513050fSChristian Ulmann 
9354513050fSChristian Ulmann //===----------------------------------------------------------------------===//
936e95e94adSJeff Niu // TestOpFoldWithFoldAdaptor
937e95e94adSJeff Niu //===----------------------------------------------------------------------===//
938e95e94adSJeff Niu 
939e95e94adSJeff Niu OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
940e95e94adSJeff Niu   int64_t sum = 0;
941e95e94adSJeff Niu   if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
942e95e94adSJeff Niu     sum += value.getValue().getSExtValue();
943e95e94adSJeff Niu 
944e95e94adSJeff Niu   for (Attribute attr : adaptor.getVariadic())
945e95e94adSJeff Niu     if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
946e95e94adSJeff Niu       sum += 2 * value.getValue().getSExtValue();
947e95e94adSJeff Niu 
948e95e94adSJeff Niu   for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
949e95e94adSJeff Niu     for (Attribute attr : attrs)
950e95e94adSJeff Niu       if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
951e95e94adSJeff Niu         sum += 3 * value.getValue().getSExtValue();
952e95e94adSJeff Niu 
953e95e94adSJeff Niu   sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
954e95e94adSJeff Niu 
955e95e94adSJeff Niu   return IntegerAttr::get(getType(), sum);
956e95e94adSJeff Niu }
957e95e94adSJeff Niu 
958e95e94adSJeff Niu //===----------------------------------------------------------------------===//
959e95e94adSJeff Niu // OpWithInferTypeAdaptorInterfaceOp
960e95e94adSJeff Niu //===----------------------------------------------------------------------===//
961e95e94adSJeff Niu 
962e95e94adSJeff Niu LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
963e95e94adSJeff Niu     MLIRContext *, std::optional<Location> location,
964e95e94adSJeff Niu     OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
965e95e94adSJeff Niu     SmallVectorImpl<Type> &inferredReturnTypes) {
966e95e94adSJeff Niu   if (adaptor.getX().getType() != adaptor.getY().getType()) {
967e95e94adSJeff Niu     return emitOptionalError(location, "operand type mismatch ",
968e95e94adSJeff Niu                              adaptor.getX().getType(), " vs ",
969e95e94adSJeff Niu                              adaptor.getY().getType());
970e95e94adSJeff Niu   }
971e95e94adSJeff Niu   inferredReturnTypes.assign({adaptor.getX().getType()});
972e95e94adSJeff Niu   return success();
973e95e94adSJeff Niu }
974e95e94adSJeff Niu 
975e95e94adSJeff Niu //===----------------------------------------------------------------------===//
976e95e94adSJeff Niu // OpWithRefineTypeInterfaceOp
977e95e94adSJeff Niu //===----------------------------------------------------------------------===//
978e95e94adSJeff Niu 
979e95e94adSJeff Niu // TODO: We should be able to only define either inferReturnType or
980e95e94adSJeff Niu // refineReturnType, currently only refineReturnType can be omitted.
981e95e94adSJeff Niu LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
982e95e94adSJeff Niu     MLIRContext *context, std::optional<Location> location, ValueRange operands,
983e95e94adSJeff Niu     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
984e95e94adSJeff Niu     SmallVectorImpl<Type> &returnTypes) {
985e95e94adSJeff Niu   returnTypes.clear();
986e95e94adSJeff Niu   return OpWithRefineTypeInterfaceOp::refineReturnTypes(
987e95e94adSJeff Niu       context, location, operands, attributes, properties, regions,
988e95e94adSJeff Niu       returnTypes);
989e95e94adSJeff Niu }
990e95e94adSJeff Niu 
991e95e94adSJeff Niu LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
992e95e94adSJeff Niu     MLIRContext *, std::optional<Location> location, ValueRange operands,
993e95e94adSJeff Niu     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
994e95e94adSJeff Niu     SmallVectorImpl<Type> &returnTypes) {
995e95e94adSJeff Niu   if (operands[0].getType() != operands[1].getType()) {
996e95e94adSJeff Niu     return emitOptionalError(location, "operand type mismatch ",
997e95e94adSJeff Niu                              operands[0].getType(), " vs ",
998e95e94adSJeff Niu                              operands[1].getType());
999e95e94adSJeff Niu   }
1000e95e94adSJeff Niu   // TODO: Add helper to make this more concise to write.
1001e95e94adSJeff Niu   if (returnTypes.empty())
1002e95e94adSJeff Niu     returnTypes.resize(1, nullptr);
1003e95e94adSJeff Niu   if (returnTypes[0] && returnTypes[0] != operands[0].getType())
1004e95e94adSJeff Niu     return emitOptionalError(location,
1005e95e94adSJeff Niu                              "required first operand and result to match");
1006e95e94adSJeff Niu   returnTypes[0] = operands[0].getType();
1007e95e94adSJeff Niu   return success();
1008e95e94adSJeff Niu }
1009e95e94adSJeff Niu 
1010e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1011e95e94adSJeff Niu // OpWithShapedTypeInferTypeAdaptorInterfaceOp
1012e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1013e95e94adSJeff Niu 
1014e95e94adSJeff Niu LogicalResult
1015e95e94adSJeff Niu OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
1016e95e94adSJeff Niu     MLIRContext *context, std::optional<Location> location,
1017e95e94adSJeff Niu     OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
1018e95e94adSJeff Niu     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1019e95e94adSJeff Niu   // Create return type consisting of the last element of the first operand.
1020e95e94adSJeff Niu   auto operandType = adaptor.getOperand1().getType();
1021e95e94adSJeff Niu   auto sval = dyn_cast<ShapedType>(operandType);
1022e95e94adSJeff Niu   if (!sval)
1023e95e94adSJeff Niu     return emitOptionalError(location, "only shaped type operands allowed");
1024e95e94adSJeff Niu   int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
1025e95e94adSJeff Niu   auto type = IntegerType::get(context, 17);
1026e95e94adSJeff Niu 
1027e95e94adSJeff Niu   Attribute encoding;
1028e95e94adSJeff Niu   if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
1029e95e94adSJeff Niu     encoding = rankedTy.getEncoding();
1030e95e94adSJeff Niu   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
1031e95e94adSJeff Niu   return success();
1032e95e94adSJeff Niu }
1033e95e94adSJeff Niu 
1034e95e94adSJeff Niu LogicalResult
1035e95e94adSJeff Niu OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
1036e95e94adSJeff Niu     OpBuilder &builder, ValueRange operands,
1037e95e94adSJeff Niu     llvm::SmallVectorImpl<Value> &shapes) {
1038e95e94adSJeff Niu   shapes = SmallVector<Value, 1>{
1039e95e94adSJeff Niu       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
1040e95e94adSJeff Niu   return success();
1041e95e94adSJeff Niu }
1042e95e94adSJeff Niu 
1043e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1044e95e94adSJeff Niu // TestOpWithPropertiesAndInferredType
1045e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1046e95e94adSJeff Niu 
1047e95e94adSJeff Niu LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
1048e95e94adSJeff Niu     MLIRContext *context, std::optional<Location>, ValueRange operands,
1049e95e94adSJeff Niu     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
1050e95e94adSJeff Niu     SmallVectorImpl<Type> &inferredReturnTypes) {
1051e95e94adSJeff Niu 
1052e95e94adSJeff Niu   Adaptor adaptor(operands, attributes, properties, regions);
1053e95e94adSJeff Niu   inferredReturnTypes.push_back(IntegerType::get(
1054e95e94adSJeff Niu       context, adaptor.getLhs() + adaptor.getProperties().rhs));
1055e95e94adSJeff Niu   return success();
1056e95e94adSJeff Niu }
1057e95e94adSJeff Niu 
1058e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1059e95e94adSJeff Niu // LoopBlockOp
1060e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1061e95e94adSJeff Niu 
1062e95e94adSJeff Niu void LoopBlockOp::getSuccessorRegions(
1063e95e94adSJeff Niu     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1064e95e94adSJeff Niu   regions.emplace_back(&getBody(), getBody().getArguments());
1065e95e94adSJeff Niu   if (point.isParent())
1066e95e94adSJeff Niu     return;
1067e95e94adSJeff Niu 
1068e95e94adSJeff Niu   regions.emplace_back((*this)->getResults());
1069e95e94adSJeff Niu }
1070e95e94adSJeff Niu 
1071e95e94adSJeff Niu OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1072e95e94adSJeff Niu   assert(point == getBody());
1073e95e94adSJeff Niu   return MutableOperandRange(getInitMutable());
1074e95e94adSJeff Niu }
1075e95e94adSJeff Niu 
1076e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1077e95e94adSJeff Niu // LoopBlockTerminatorOp
1078e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1079e95e94adSJeff Niu 
1080e95e94adSJeff Niu MutableOperandRange
1081e95e94adSJeff Niu LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
1082e95e94adSJeff Niu   if (point.isParent())
1083e95e94adSJeff Niu     return getExitArgMutable();
1084e95e94adSJeff Niu   return getNextIterArgMutable();
1085e95e94adSJeff Niu }
1086e95e94adSJeff Niu 
1087e95e94adSJeff Niu //===----------------------------------------------------------------------===//
// TestNoTerminatorOp
1089e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1090e95e94adSJeff Niu 
1091e95e94adSJeff Niu void TestNoTerminatorOp::getSuccessorRegions(
1092e95e94adSJeff Niu     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
1093e95e94adSJeff Niu 
1094e95e94adSJeff Niu //===----------------------------------------------------------------------===//
// ManualCppOpWithFold
1096e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1097e95e94adSJeff Niu 
1098e95e94adSJeff Niu OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
1099e95e94adSJeff Niu   // Just a simple fold for testing purposes that reads an operands constant
1100e95e94adSJeff Niu   // value and returns it.
1101e95e94adSJeff Niu   if (!attributes.empty())
1102e95e94adSJeff Niu     return attributes.front();
1103e95e94adSJeff Niu   return nullptr;
1104e95e94adSJeff Niu }
1105e95e94adSJeff Niu 
1106e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1107e95e94adSJeff Niu // Tensor/Buffer Ops
1108e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1109e95e94adSJeff Niu 
1110e95e94adSJeff Niu void ReadBufferOp::getEffects(
1111e95e94adSJeff Niu     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1112e95e94adSJeff Niu         &effects) {
1113e95e94adSJeff Niu   // The buffer operand is read.
11142c1ae801Sdonald chen   effects.emplace_back(MemoryEffects::Read::get(), &getBufferMutable(),
1115e95e94adSJeff Niu                        SideEffects::DefaultResource::get());
1116e95e94adSJeff Niu   // The buffer contents are dumped.
1117e95e94adSJeff Niu   effects.emplace_back(MemoryEffects::Write::get(),
1118e95e94adSJeff Niu                        SideEffects::DefaultResource::get());
1119e95e94adSJeff Niu }
1120e95e94adSJeff Niu 
1121e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1122e95e94adSJeff Niu // Test Dataflow
1123e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1124e95e94adSJeff Niu 
1125e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1126e95e94adSJeff Niu // TestCallAndStoreOp
1127e95e94adSJeff Niu 
1128e95e94adSJeff Niu CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
1129e95e94adSJeff Niu   return getCallee();
1130e95e94adSJeff Niu }
1131e95e94adSJeff Niu 
1132e95e94adSJeff Niu void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
113335e89897SKazu Hirata   setCalleeAttr(cast<SymbolRefAttr>(callee));
1134e95e94adSJeff Niu }
1135e95e94adSJeff Niu 
1136e95e94adSJeff Niu Operation::operand_range TestCallAndStoreOp::getArgOperands() {
1137e95e94adSJeff Niu   return getCalleeOperands();
1138e95e94adSJeff Niu }
1139e95e94adSJeff Niu 
1140e95e94adSJeff Niu MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
1141e95e94adSJeff Niu   return getCalleeOperandsMutable();
1142e95e94adSJeff Niu }
1143e95e94adSJeff Niu 
1144e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1145e95e94adSJeff Niu // TestCallOnDeviceOp
1146e95e94adSJeff Niu 
1147e95e94adSJeff Niu CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
1148e95e94adSJeff Niu   return getCallee();
1149e95e94adSJeff Niu }
1150e95e94adSJeff Niu 
1151e95e94adSJeff Niu void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
115235e89897SKazu Hirata   setCalleeAttr(cast<SymbolRefAttr>(callee));
1153e95e94adSJeff Niu }
1154e95e94adSJeff Niu 
1155e95e94adSJeff Niu Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
1156e95e94adSJeff Niu   return getForwardedOperands();
1157e95e94adSJeff Niu }
1158e95e94adSJeff Niu 
1159e95e94adSJeff Niu MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
1160e95e94adSJeff Niu   return getForwardedOperandsMutable();
1161e95e94adSJeff Niu }
1162e95e94adSJeff Niu 
1163e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1164e95e94adSJeff Niu // TestStoreWithARegion
1165e95e94adSJeff Niu 
1166e95e94adSJeff Niu void TestStoreWithARegion::getSuccessorRegions(
1167e95e94adSJeff Niu     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1168e95e94adSJeff Niu   if (point.isParent())
1169e95e94adSJeff Niu     regions.emplace_back(&getBody(), getBody().front().getArguments());
1170e95e94adSJeff Niu   else
1171e95e94adSJeff Niu     regions.emplace_back();
1172e95e94adSJeff Niu }
1173e95e94adSJeff Niu 
1174e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1175e95e94adSJeff Niu // TestStoreWithALoopRegion
1176e95e94adSJeff Niu 
1177e95e94adSJeff Niu void TestStoreWithALoopRegion::getSuccessorRegions(
1178e95e94adSJeff Niu     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1179e95e94adSJeff Niu   // Both the operation itself and the region may be branching into the body or
1180e95e94adSJeff Niu   // back into the operation itself. It is possible for the operation not to
1181e95e94adSJeff Niu   // enter the body.
1182e95e94adSJeff Niu   regions.emplace_back(
1183e95e94adSJeff Niu       RegionSuccessor(&getBody(), getBody().front().getArguments()));
1184e95e94adSJeff Niu   regions.emplace_back();
1185e95e94adSJeff Niu }
1186e95e94adSJeff Niu 
1187e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1188e95e94adSJeff Niu // TestVersionedOpA
1189e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1190e95e94adSJeff Niu 
1191e95e94adSJeff Niu LogicalResult
1192e95e94adSJeff Niu TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader,
1193e95e94adSJeff Niu                                  mlir::OperationState &state) {
1194e95e94adSJeff Niu   auto &prop = state.getOrAddProperties<Properties>();
1195e95e94adSJeff Niu   if (mlir::failed(reader.readAttribute(prop.dims)))
1196e95e94adSJeff Niu     return mlir::failure();
1197e95e94adSJeff Niu 
1198e95e94adSJeff Niu   // Check if we have a version. If not, assume we are parsing the current
1199e95e94adSJeff Niu   // version.
1200e95e94adSJeff Niu   auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1201e95e94adSJeff Niu   if (succeeded(maybeVersion)) {
1202e95e94adSJeff Niu     // If version is less than 2.0, there is no additional attribute to parse.
1203e95e94adSJeff Niu     // We can materialize missing properties post parsing before verification.
1204e95e94adSJeff Niu     const auto *version =
1205e95e94adSJeff Niu         reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1206e95e94adSJeff Niu     if ((version->major_ < 2)) {
1207e95e94adSJeff Niu       return success();
1208e95e94adSJeff Niu     }
1209e95e94adSJeff Niu   }
1210e95e94adSJeff Niu 
1211e95e94adSJeff Niu   if (mlir::failed(reader.readAttribute(prop.modifier)))
1212e95e94adSJeff Niu     return mlir::failure();
1213e95e94adSJeff Niu   return mlir::success();
1214e95e94adSJeff Niu }
1215e95e94adSJeff Niu 
1216e95e94adSJeff Niu void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
1217e95e94adSJeff Niu   auto &prop = getProperties();
1218e95e94adSJeff Niu   writer.writeAttribute(prop.dims);
1219e95e94adSJeff Niu 
1220e95e94adSJeff Niu   auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
1221e95e94adSJeff Niu   if (succeeded(maybeVersion)) {
1222e95e94adSJeff Niu     // If version is less than 2.0, there is no additional attribute to write.
1223e95e94adSJeff Niu     const auto *version =
1224e95e94adSJeff Niu         reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1225e95e94adSJeff Niu     if ((version->major_ < 2)) {
1226e95e94adSJeff Niu       llvm::outs() << "downgrading op properties...\n";
1227e95e94adSJeff Niu       return;
1228e95e94adSJeff Niu     }
1229e95e94adSJeff Niu   }
1230e95e94adSJeff Niu   writer.writeAttribute(prop.modifier);
1231e95e94adSJeff Niu }
1232e95e94adSJeff Niu 
1233e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1234e95e94adSJeff Niu // TestOpWithVersionedProperties
1235e95e94adSJeff Niu //===----------------------------------------------------------------------===//
1236e95e94adSJeff Niu 
1237db791b27SRamkumar Ramachandra llvm::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
1238e95e94adSJeff Niu     mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
1239e95e94adSJeff Niu   uint64_t value1, value2 = 0;
1240e95e94adSJeff Niu   if (failed(reader.readVarInt(value1)))
1241e95e94adSJeff Niu     return failure();
1242e95e94adSJeff Niu 
1243e95e94adSJeff Niu   // Check if we have a version. If not, assume we are parsing the current
1244e95e94adSJeff Niu   // version.
1245e95e94adSJeff Niu   auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1246e95e94adSJeff Niu   bool needToParseAnotherInt = true;
1247e95e94adSJeff Niu   if (succeeded(maybeVersion)) {
1248e95e94adSJeff Niu     // If version is less than 2.0, there is no additional attribute to parse.
1249e95e94adSJeff Niu     // We can materialize missing properties post parsing before verification.
1250e95e94adSJeff Niu     const auto *version =
1251e95e94adSJeff Niu         reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1252e95e94adSJeff Niu     if ((version->major_ < 2))
1253e95e94adSJeff Niu       needToParseAnotherInt = false;
1254e95e94adSJeff Niu   }
1255e95e94adSJeff Niu   if (needToParseAnotherInt && failed(reader.readVarInt(value2)))
1256e95e94adSJeff Niu     return failure();
1257e95e94adSJeff Niu 
1258e95e94adSJeff Niu   prop.value1 = value1;
1259e95e94adSJeff Niu   prop.value2 = value2;
1260e95e94adSJeff Niu   return success();
1261e95e94adSJeff Niu }
1262e95e94adSJeff Niu 
1263e95e94adSJeff Niu void TestOpWithVersionedProperties::writeToMlirBytecode(
1264e95e94adSJeff Niu     mlir::DialectBytecodeWriter &writer,
1265e95e94adSJeff Niu     const test::VersionedProperties &prop) {
1266e95e94adSJeff Niu   writer.writeVarInt(prop.value1);
1267e95e94adSJeff Niu   writer.writeVarInt(prop.value2);
1268e95e94adSJeff Niu }
1269eeafc9daSChristian Ulmann 
1270eeafc9daSChristian Ulmann //===----------------------------------------------------------------------===//
1271eeafc9daSChristian Ulmann // TestMultiSlotAlloca
1272eeafc9daSChristian Ulmann //===----------------------------------------------------------------------===//
1273eeafc9daSChristian Ulmann 
1274eeafc9daSChristian Ulmann llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
1275eeafc9daSChristian Ulmann   SmallVector<MemorySlot> slots;
1276eeafc9daSChristian Ulmann   for (Value result : getResults()) {
1277eeafc9daSChristian Ulmann     slots.push_back(MemorySlot{
1278eeafc9daSChristian Ulmann         result, cast<MemRefType>(result.getType()).getElementType()});
1279eeafc9daSChristian Ulmann   }
1280eeafc9daSChristian Ulmann   return slots;
1281eeafc9daSChristian Ulmann }
1282eeafc9daSChristian Ulmann 
1283eeafc9daSChristian Ulmann Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
1284eeafc9daSChristian Ulmann                                            OpBuilder &builder) {
1285eeafc9daSChristian Ulmann   return builder.create<TestOpConstant>(getLoc(), slot.elemType,
1286eeafc9daSChristian Ulmann                                         builder.getI32IntegerAttr(42));
1287eeafc9daSChristian Ulmann }
1288eeafc9daSChristian Ulmann 
1289eeafc9daSChristian Ulmann void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
1290eeafc9daSChristian Ulmann                                               BlockArgument argument,
1291eeafc9daSChristian Ulmann                                               OpBuilder &builder) {
1292eeafc9daSChristian Ulmann   // Not relevant for testing.
1293eeafc9daSChristian Ulmann }
1294eeafc9daSChristian Ulmann 
12950b5b2027SChristian Ulmann /// Creates a new TestMultiSlotAlloca operation, just without the `slot`.
12960b5b2027SChristian Ulmann static std::optional<TestMultiSlotAlloca>
12970b5b2027SChristian Ulmann createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
12980b5b2027SChristian Ulmann                                 TestMultiSlotAlloca oldOp) {
1299eeafc9daSChristian Ulmann 
13000b5b2027SChristian Ulmann   if (oldOp.getNumResults() == 1) {
13010b5b2027SChristian Ulmann     oldOp.erase();
1302eeafc9daSChristian Ulmann     return std::nullopt;
1303eeafc9daSChristian Ulmann   }
1304eeafc9daSChristian Ulmann 
1305eeafc9daSChristian Ulmann   SmallVector<Type> newTypes;
1306eeafc9daSChristian Ulmann   SmallVector<Value> remainingValues;
1307eeafc9daSChristian Ulmann 
13080b5b2027SChristian Ulmann   for (Value oldResult : oldOp.getResults()) {
1309eeafc9daSChristian Ulmann     if (oldResult == slot.ptr)
1310eeafc9daSChristian Ulmann       continue;
1311eeafc9daSChristian Ulmann     remainingValues.push_back(oldResult);
1312eeafc9daSChristian Ulmann     newTypes.push_back(oldResult.getType());
1313eeafc9daSChristian Ulmann   }
1314eeafc9daSChristian Ulmann 
1315eeafc9daSChristian Ulmann   OpBuilder::InsertionGuard guard(builder);
13160b5b2027SChristian Ulmann   builder.setInsertionPoint(oldOp);
13170b5b2027SChristian Ulmann   auto replacement =
13180b5b2027SChristian Ulmann       builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes);
1319eeafc9daSChristian Ulmann   for (auto [oldResult, newResult] :
1320eeafc9daSChristian Ulmann        llvm::zip_equal(remainingValues, replacement.getResults()))
1321eeafc9daSChristian Ulmann     oldResult.replaceAllUsesWith(newResult);
1322eeafc9daSChristian Ulmann 
13230b5b2027SChristian Ulmann   oldOp.erase();
1324eeafc9daSChristian Ulmann   return replacement;
1325eeafc9daSChristian Ulmann }
13260b5b2027SChristian Ulmann 
13270b5b2027SChristian Ulmann std::optional<PromotableAllocationOpInterface>
13280b5b2027SChristian Ulmann TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
13290b5b2027SChristian Ulmann                                              Value defaultValue,
13300b5b2027SChristian Ulmann                                              OpBuilder &builder) {
13310b5b2027SChristian Ulmann   if (defaultValue && defaultValue.use_empty())
13320b5b2027SChristian Ulmann     defaultValue.getDefiningOp()->erase();
13330b5b2027SChristian Ulmann   return createNewMultiAllocaWithoutSlot(slot, builder, *this);
13340b5b2027SChristian Ulmann }
13350b5b2027SChristian Ulmann 
13360b5b2027SChristian Ulmann SmallVector<DestructurableMemorySlot>
13370b5b2027SChristian Ulmann TestMultiSlotAlloca::getDestructurableSlots() {
13380b5b2027SChristian Ulmann   SmallVector<DestructurableMemorySlot> slots;
13390b5b2027SChristian Ulmann   for (Value result : getResults()) {
13400b5b2027SChristian Ulmann     auto memrefType = cast<MemRefType>(result.getType());
13410b5b2027SChristian Ulmann     auto destructurable = dyn_cast<DestructurableTypeInterface>(memrefType);
13420b5b2027SChristian Ulmann     if (!destructurable)
13430b5b2027SChristian Ulmann       continue;
13440b5b2027SChristian Ulmann 
13450b5b2027SChristian Ulmann     std::optional<DenseMap<Attribute, Type>> destructuredType =
13460b5b2027SChristian Ulmann         destructurable.getSubelementIndexMap();
13470b5b2027SChristian Ulmann     if (!destructuredType)
13480b5b2027SChristian Ulmann       continue;
13490b5b2027SChristian Ulmann     slots.emplace_back(
13500b5b2027SChristian Ulmann         DestructurableMemorySlot{{result, memrefType}, *destructuredType});
13510b5b2027SChristian Ulmann   }
13520b5b2027SChristian Ulmann   return slots;
13530b5b2027SChristian Ulmann }
13540b5b2027SChristian Ulmann 
13550b5b2027SChristian Ulmann DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
13560b5b2027SChristian Ulmann     const DestructurableMemorySlot &slot,
13570b5b2027SChristian Ulmann     const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
13580b5b2027SChristian Ulmann     SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
13590b5b2027SChristian Ulmann   OpBuilder::InsertionGuard guard(builder);
13600b5b2027SChristian Ulmann   builder.setInsertionPointAfter(*this);
13610b5b2027SChristian Ulmann 
13620b5b2027SChristian Ulmann   DenseMap<Attribute, MemorySlot> slotMap;
13630b5b2027SChristian Ulmann 
13640b5b2027SChristian Ulmann   for (Attribute usedIndex : usedIndices) {
136569d3793fSThéo Degioanni     Type elemType = slot.subelementTypes.lookup(usedIndex);
13660b5b2027SChristian Ulmann     MemRefType elemPtr = MemRefType::get({}, elemType);
13670b5b2027SChristian Ulmann     auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr);
13680b5b2027SChristian Ulmann     newAllocators.push_back(subAlloca);
13690b5b2027SChristian Ulmann     slotMap.try_emplace<MemorySlot>(usedIndex,
13700b5b2027SChristian Ulmann                                     {subAlloca.getResult(0), elemType});
13710b5b2027SChristian Ulmann   }
13720b5b2027SChristian Ulmann 
13730b5b2027SChristian Ulmann   return slotMap;
13740b5b2027SChristian Ulmann }
13750b5b2027SChristian Ulmann 
13760b5b2027SChristian Ulmann std::optional<DestructurableAllocationOpInterface>
13770b5b2027SChristian Ulmann TestMultiSlotAlloca::handleDestructuringComplete(
13780b5b2027SChristian Ulmann     const DestructurableMemorySlot &slot, OpBuilder &builder) {
13790b5b2027SChristian Ulmann   return createNewMultiAllocaWithoutSlot(slot, builder, *this);
13800b5b2027SChristian Ulmann }
1381