xref: /llvm-project/mlir/include/mlir/Dialect/IRDL/IR/IRDLTraits.h (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
//===- IRDLTraits.h - IRDL traits definition --------------------*- C++ -*-===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file declares the traits used by the IR Definition Language dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_DIALECT_IRDL_IR_IRDLTRAITS_H_
15 #define MLIR_DIALECT_IRDL_IR_IRDLTRAITS_H_
16 
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/Support/Casting.h"
19 
20 namespace mlir {
21 namespace OpTrait {
22 
23 /// Characterize operations that have at most a single operation of certain
24 /// types in their region.
25 /// This check is only done on the children that are immediate children of the
26 /// operation, and does not recurse into the children's regions.
27 /// This trait expects the Op to satisfy the `OneRegion` trait.
28 template <typename... ChildOps>
29 class AtMostOneChildOf {
30 public:
31   template <typename ConcreteType>
32   class Impl
33       : public TraitBase<ConcreteType, AtMostOneChildOf<ChildOps...>::Impl> {
34   public:
verifyTrait(Operation * op)35     static LogicalResult verifyTrait(Operation *op) {
36       static_assert(
37           ConcreteType::template hasTrait<::mlir::OpTrait::OneRegion>(),
38           "expected operation to have a single region");
39       static_assert(sizeof...(ChildOps) > 0,
40                     "expected at least one child operation type");
41 
42       // Contains `true` if the corresponding child op has been seen.
43       bool satisfiedOps[sizeof...(ChildOps)] = {};
44 
45       for (Operation &child : cast<ConcreteType>(op).getOps()) {
46         int childOpIndex = 0;
47         if (((isa<ChildOps>(child) ? false : (++childOpIndex, true)) && ...))
48           continue;
49 
50         // Check that the operation has not been seen before.
51         if (satisfiedOps[childOpIndex])
52           return op->emitError()
53                  << "failed to verify AtMostOneChildOf trait: the operation "
54                     "contains at least two operations of type "
55                  << child.getName();
56 
57         // Mark the operation as seen.
58         satisfiedOps[childOpIndex] = true;
59       }
60       return success();
61     }
62 
63     /// Get the unique operation of a specific op that is in the operation
64     /// region.
65     template <typename OpT>
66     std::enable_if_t<std::disjunction<std::is_same<OpT, ChildOps>...>::value,
67                      std::optional<OpT>>
getOp()68     getOp() {
69       auto ops =
70           cast<ConcreteType>(this->getOperation()).template getOps<OpT>();
71       if (ops.empty())
72         return {};
73       return {*ops.begin()};
74     }
75   };
76 };
77 } // namespace OpTrait
78 } // namespace mlir
79 
80 #endif // MLIR_DIALECT_IRDL_IR_IRDLTRAITS_H_
81