xref: /llvm-project/mlir/lib/IR/Operation.cpp (revision a3ef1b587d7cf88e311d6f17132fa7fc5a6490db)
1 //===- Operation.cpp - Operation support code -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Operation.h"
10 #include "mlir/IR/Attributes.h"
11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/OperationSupport.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Interfaces/FoldInterfaces.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include <numeric>
23 #include <optional>
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Operation
29 //===----------------------------------------------------------------------===//
30 
31 /// Create a new Operation from operation state.
32 Operation *Operation::create(const OperationState &state) {
33   Operation *op =
34       create(state.location, state.name, state.types, state.operands,
35              state.attributes.getDictionary(state.getContext()),
36              state.properties, state.successors, state.regions);
37   if (LLVM_UNLIKELY(state.propertiesAttr)) {
38     assert(!state.properties);
39     LogicalResult result =
40         op->setPropertiesFromAttribute(state.propertiesAttr,
41                                        /*diagnostic=*/nullptr);
42     assert(result.succeeded() && "invalid properties in op creation");
43     (void)result;
44   }
45   return op;
46 }
47 
48 /// Create a new Operation with the specific fields.
49 Operation *Operation::create(Location location, OperationName name,
50                              TypeRange resultTypes, ValueRange operands,
51                              NamedAttrList &&attributes,
52                              OpaqueProperties properties, BlockRange successors,
53                              RegionRange regions) {
54   unsigned numRegions = regions.size();
55   Operation *op =
56       create(location, name, resultTypes, operands, std::move(attributes),
57              properties, successors, numRegions);
58   for (unsigned i = 0; i < numRegions; ++i)
59     if (regions[i])
60       op->getRegion(i).takeBody(*regions[i]);
61   return op;
62 }
63 
64 /// Create a new Operation with the specific fields.
65 Operation *Operation::create(Location location, OperationName name,
66                              TypeRange resultTypes, ValueRange operands,
67                              NamedAttrList &&attributes,
68                              OpaqueProperties properties, BlockRange successors,
69                              unsigned numRegions) {
70   // Populate default attributes.
71   name.populateDefaultAttrs(attributes);
72 
73   return create(location, name, resultTypes, operands,
74                 attributes.getDictionary(location.getContext()), properties,
75                 successors, numRegions);
76 }
77 
/// Overload of create that takes an existing DictionaryAttr to avoid
/// unnecessarily uniquing a list of attributes.
///
/// The whole operation lives in one malloc'd buffer, laid out as:
///   [prefix: result storage][Operation][trailing objects]
/// The trailing objects are: OperandStorage, properties, successor
/// BlockOperands, Regions and OpOperands. destroy() recomputes the prefix size
/// to find the start of this buffer and free() it.
Operation *Operation::create(Location location, OperationName name,
                             TypeRange resultTypes, ValueRange operands,
                             DictionaryAttr attributes,
                             OpaqueProperties properties, BlockRange successors,
                             unsigned numRegions) {
  assert(llvm::all_of(resultTypes, [](Type t) { return t; }) &&
         "unexpected null result type");

  // We only need to allocate additional memory for a subset of results.
  unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size());
  unsigned numInlineResults = OpResult::getNumInline(resultTypes.size());
  unsigned numSuccessors = successors.size();
  unsigned numOperands = operands.size();
  unsigned numResults = resultTypes.size();
  // Property storage is rounded up to a multiple of 8 bytes.
  int opPropertiesAllocSize = llvm::alignTo<8>(name.getOpPropertyByteSize());

  // If the operation is known to have no operands, don't allocate an operand
  // storage.
  bool needsOperandStorage =
      operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true;

  // Compute the byte size for the operation and the operand storage. This takes
  // into account the size of the operation, its trailing objects, and its
  // prefixed objects.
  // NOTE(review): opPropertiesAllocSize is passed as the element count for
  // detail::OpProperties, which assumes that type is one byte wide — confirm.
  size_t byteSize =
      totalSizeToAlloc<detail::OperandStorage, detail::OpProperties,
                       BlockOperand, Region, OpOperand>(
          needsOperandStorage ? 1 : 0, opPropertiesAllocSize, numSuccessors,
          numRegions, numOperands);
  // The result objects are stored in front of the Operation; pad that prefix
  // so the Operation itself stays correctly aligned.
  size_t prefixByteSize = llvm::alignTo(
      Operation::prefixAllocSize(numTrailingResults, numInlineResults),
      alignof(Operation));
  // NOTE(review): the malloc result is not checked for null here.
  char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize));
  void *rawMem = mallocMem + prefixByteSize;

  // Create the new Operation.
  Operation *op = ::new (rawMem) Operation(
      location, name, numResults, numSuccessors, numRegions,
      opPropertiesAllocSize, attributes, properties, needsOperandStorage);

  assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) &&
         "unexpected successors in a non-terminator operation");

  // Initialize the results. The first numInlineResults types go to inline
  // results; the remaining types go to the out-of-line results.
  auto resultTypeIt = resultTypes.begin();
  for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt)
    new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i);
  for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) {
    new (op->getOutOfLineOpResult(i))
        detail::OutOfLineOpResult(*resultTypeIt, i);
  }

  // Initialize the regions.
  for (unsigned i = 0; i != numRegions; ++i)
    new (&op->getRegion(i)) Region(op);

  // Initialize the operands.
  if (needsOperandStorage) {
    new (&op->getOperandStorage()) detail::OperandStorage(
        op, op->getTrailingObjects<OpOperand>(), operands);
  }

  // Initialize the successors.
  auto blockOperands = op->getBlockOperands();
  for (unsigned i = 0; i != numSuccessors; ++i)
    new (&blockOperands[i]) BlockOperand(op, successors[i]);

  // This must be done after properties are initialized (in the constructor
  // above). For ops with properties, setAttrs() goes through setAttr() per
  // attribute, presumably so inherent attributes can land in the properties.
  op->setAttrs(attributes);

  return op;
}
152 
153 Operation::Operation(Location location, OperationName name, unsigned numResults,
154                      unsigned numSuccessors, unsigned numRegions,
155                      int fullPropertiesStorageSize, DictionaryAttr attributes,
156                      OpaqueProperties properties, bool hasOperandStorage)
157     : location(location), numResults(numResults), numSuccs(numSuccessors),
158       numRegions(numRegions), hasOperandStorage(hasOperandStorage),
159       propertiesStorageSize((fullPropertiesStorageSize + 7) / 8), name(name) {
160   assert(attributes && "unexpected null attribute dictionary");
161   assert(fullPropertiesStorageSize <= propertiesCapacity &&
162          "Properties size overflow");
163 #ifndef NDEBUG
164   if (!getDialect() && !getContext()->allowsUnregisteredDialects())
165     llvm::report_fatal_error(
166         name.getStringRef() +
167         " created with unregistered dialect. If this is intended, please call "
168         "allowUnregisteredDialects() on the MLIRContext, or use "
169         "-allow-unregistered-dialect with the MLIR tool used.");
170 #endif
171   if (fullPropertiesStorageSize)
172     name.initOpProperties(getPropertiesStorage(), properties);
173 }
174 
// Operations are deleted through the destroy() member because they are
// allocated via malloc. This destructor only tears down the objects that
// create() placement-new'd into the trailing storage; destroy() releases the
// memory itself.
Operation::~Operation() {
  assert(block == nullptr && "operation destroyed but still in a block");
#ifndef NDEBUG
  if (!use_empty()) {
    // Inner scope: the diagnostic is reported when `diag` goes out of scope,
    // which must happen before report_fatal_error aborts.
    {
      InFlightDiagnostic diag =
          emitOpError("operation destroyed but still has uses");
      for (Operation *user : getUsers())
        diag.attachNote(user->getLoc()) << "- use: " << *user << "\n";
    }
    llvm::report_fatal_error("operation destroyed but still has uses");
  }
#endif
  // Explicitly run the destructors for the operands.
  if (hasOperandStorage)
    getOperandStorage().~OperandStorage();

  // Explicitly run the destructors for the successors.
  for (auto &successor : getBlockOperands())
    successor.~BlockOperand();

  // Explicitly destroy the regions.
  for (auto &region : getRegions())
    region.~Region();
  // Release the properties last, and only if this op has any.
  if (propertiesStorageSize)
    name.destroyOpProperties(getPropertiesStorage());
}
204 
205 /// Destroy this operation or one of its subclasses.
206 void Operation::destroy() {
207   // Operations may have additional prefixed allocation, which needs to be
208   // accounted for here when computing the address to free.
209   char *rawMem = reinterpret_cast<char *>(this) -
210                  llvm::alignTo(prefixAllocSize(), alignof(Operation));
211   this->~Operation();
212   free(rawMem);
213 }
214 
215 /// Return true if this operation is a proper ancestor of the `other`
216 /// operation.
217 bool Operation::isProperAncestor(Operation *other) {
218   while ((other = other->getParentOp()))
219     if (this == other)
220       return true;
221   return false;
222 }
223 
224 /// Replace any uses of 'from' with 'to' within this operation.
225 void Operation::replaceUsesOfWith(Value from, Value to) {
226   if (from == to)
227     return;
228   for (auto &operand : getOpOperands())
229     if (operand.get() == from)
230       operand.set(to);
231 }
232 
233 /// Replace the current operands of this operation with the ones provided in
234 /// 'operands'.
235 void Operation::setOperands(ValueRange operands) {
236   if (LLVM_LIKELY(hasOperandStorage))
237     return getOperandStorage().setOperands(this, operands);
238   assert(operands.empty() && "setting operands without an operand storage");
239 }
240 
241 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
242 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
243 /// than the range pointed to by 'start'+'length'.
244 void Operation::setOperands(unsigned start, unsigned length,
245                             ValueRange operands) {
246   assert((start + length) <= getNumOperands() &&
247          "invalid operand range specified");
248   if (LLVM_LIKELY(hasOperandStorage))
249     return getOperandStorage().setOperands(this, start, length, operands);
250   assert(operands.empty() && "setting operands without an operand storage");
251 }
252 
253 /// Insert the given operands into the operand list at the given 'index'.
254 void Operation::insertOperands(unsigned index, ValueRange operands) {
255   if (LLVM_LIKELY(hasOperandStorage))
256     return setOperands(index, /*length=*/0, operands);
257   assert(operands.empty() && "inserting operands without an operand storage");
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // Diagnostics
262 //===----------------------------------------------------------------------===//
263 
264 /// Emit an error about fatal conditions with this operation, reporting up to
265 /// any diagnostic handlers that may be listening.
266 InFlightDiagnostic Operation::emitError(const Twine &message) {
267   InFlightDiagnostic diag = mlir::emitError(getLoc(), message);
268   if (getContext()->shouldPrintOpOnDiagnostic()) {
269     diag.attachNote(getLoc())
270         .append("see current operation: ")
271         .appendOp(*this, OpPrintingFlags().printGenericOpForm());
272   }
273   return diag;
274 }
275 
276 /// Emit a warning about this operation, reporting up to any diagnostic
277 /// handlers that may be listening.
278 InFlightDiagnostic Operation::emitWarning(const Twine &message) {
279   InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message);
280   if (getContext()->shouldPrintOpOnDiagnostic())
281     diag.attachNote(getLoc()) << "see current operation: " << *this;
282   return diag;
283 }
284 
285 /// Emit a remark about this operation, reporting up to any diagnostic
286 /// handlers that may be listening.
287 InFlightDiagnostic Operation::emitRemark(const Twine &message) {
288   InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message);
289   if (getContext()->shouldPrintOpOnDiagnostic())
290     diag.attachNote(getLoc()) << "see current operation: " << *this;
291   return diag;
292 }
293 
294 DictionaryAttr Operation::getAttrDictionary() {
295   if (getPropertiesStorageSize()) {
296     NamedAttrList attrsList = attrs;
297     getName().populateInherentAttrs(this, attrsList);
298     return attrsList.getDictionary(getContext());
299   }
300   return attrs;
301 }
302 
303 void Operation::setAttrs(DictionaryAttr newAttrs) {
304   assert(newAttrs && "expected valid attribute dictionary");
305   if (getPropertiesStorageSize()) {
306     attrs = DictionaryAttr::get(getContext(), {});
307     for (const NamedAttribute &attr : newAttrs)
308       setAttr(attr.getName(), attr.getValue());
309     return;
310   }
311   attrs = newAttrs;
312 }
313 void Operation::setAttrs(ArrayRef<NamedAttribute> newAttrs) {
314   if (getPropertiesStorageSize()) {
315     setAttrs(DictionaryAttr::get(getContext(), {}));
316     for (const NamedAttribute &attr : newAttrs)
317       setAttr(attr.getName(), attr.getValue());
318     return;
319   }
320   attrs = DictionaryAttr::get(getContext(), newAttrs);
321 }
322 
323 std::optional<Attribute> Operation::getInherentAttr(StringRef name) {
324   return getName().getInherentAttr(this, name);
325 }
326 
327 void Operation::setInherentAttr(StringAttr name, Attribute value) {
328   getName().setInherentAttr(this, name, value);
329 }
330 
331 Attribute Operation::getPropertiesAsAttribute() {
332   std::optional<RegisteredOperationName> info = getRegisteredInfo();
333   if (LLVM_UNLIKELY(!info))
334     return *getPropertiesStorage().as<Attribute *>();
335   return info->getOpPropertiesAsAttribute(this);
336 }
337 LogicalResult
338 Operation::setPropertiesFromAttribute(Attribute attr,
339                                       InFlightDiagnostic *diagnostic) {
340   std::optional<RegisteredOperationName> info = getRegisteredInfo();
341   if (LLVM_UNLIKELY(!info)) {
342     *getPropertiesStorage().as<Attribute *>() = attr;
343     return success();
344   }
345   return info->setOpPropertiesFromAttribute(this, attr, diagnostic);
346 }
347 
348 void Operation::copyProperties(OpaqueProperties rhs) {
349   name.copyOpProperties(getPropertiesStorage(), rhs);
350 }
351 
352 llvm::hash_code Operation::hashProperties() {
353   return name.hashOpProperties(getPropertiesStorage());
354 }
355 
356 //===----------------------------------------------------------------------===//
357 // Operation Ordering
358 //===----------------------------------------------------------------------===//
359 
// Operation::kInvalidOrderIdx and Operation::kOrderStride are `static
// constexpr` data members. Since C++17 such members are implicitly inline, so
// the former out-of-line redeclarations here were redundant and deprecated
// ([depr.static.constexpr]). They have been removed.
362 
363 /// Given an operation 'other' that is within the same parent block, return
364 /// whether the current operation is before 'other' in the operation list
365 /// of the parent block.
366 /// Note: This function has an average complexity of O(1), but worst case may
367 /// take O(N) where N is the number of operations within the parent block.
368 bool Operation::isBeforeInBlock(Operation *other) {
369   assert(block && "Operations without parent blocks have no order.");
370   assert(other && other->block == block &&
371          "Expected other operation to have the same parent block.");
372   // If the order of the block is already invalid, directly recompute the
373   // parent.
374   if (!block->isOpOrderValid()) {
375     block->recomputeOpOrder();
376   } else {
377     // Update the order either operation if necessary.
378     updateOrderIfNecessary();
379     other->updateOrderIfNecessary();
380   }
381 
382   return orderIndex < other->orderIndex;
383 }
384 
/// Update the order index of this operation if necessary, potentially
/// recomputing the order of the parent block.
///
/// The new index is derived from the neighbouring operations' indices where
/// possible:
///  - last op: previous index + kOrderStride;
///  - first op: kOrderStride, or half of the next index if that is not larger
///    than the stride;
///  - otherwise: the midpoint between the previous and next indices.
/// If a neighbour has no valid index, or no gap is left, the whole block is
/// renumbered via recomputeOpOrder().
void Operation::updateOrderIfNecessary() {
  assert(block && "expected valid parent");

  // If the order is valid for this operation there is nothing to do.
  if (hasValidOrder())
    return;
  Operation *blockFront = &block->front();
  Operation *blockBack = &block->back();

  // This method is expected to only be invoked on blocks with more than one
  // operation.
  assert(blockFront != blockBack && "expected more than one operation");

  // If the operation is at the end of the block.
  if (this == blockBack) {
    Operation *prevNode = getPrevNode();
    if (!prevNode->hasValidOrder())
      return block->recomputeOpOrder();

    // Add the stride to the previous operation.
    orderIndex = prevNode->orderIndex + kOrderStride;
    return;
  }

  // If this is the first operation try to use the next operation to compute the
  // ordering.
  if (this == blockFront) {
    Operation *nextNode = getNextNode();
    if (!nextNode->hasValidOrder())
      return block->recomputeOpOrder();
    // There is no order to give this operation.
    if (nextNode->orderIndex == 0)
      return block->recomputeOpOrder();

    // If we can't use the stride, just take the middle value left. This is safe
    // because we know there is at least one valid index to assign to.
    if (nextNode->orderIndex <= kOrderStride)
      orderIndex = (nextNode->orderIndex / 2);
    else
      orderIndex = kOrderStride;
    return;
  }

  // Otherwise, this operation is between two others. Place this operation in
  // the middle of the previous and next if possible.
  Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
  if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
    return block->recomputeOpOrder();
  unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;

  // Check to see if there is a valid order between the two (adjacent indices
  // leave no free slot).
  if (prevOrder + 1 == nextOrder)
    return block->recomputeOpOrder();
  orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}
442 
443 //===----------------------------------------------------------------------===//
444 // ilist_traits for Operation
445 //===----------------------------------------------------------------------===//
446 
// Out-of-line definitions of the ilist node-access hooks for the Operation
// ilist specialization. Each one forwards to the generic NodeAccess helper
// with the Operation list's options (OptionsT). They convert between
// Operation pointers and their embedded list nodes.

/// Operation* -> list node*.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

/// const Operation* -> const list node*.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer n)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

/// List node* -> Operation*.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

/// const list node* -> const Operation*.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *n)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}
472 
/// ilist callback invoked when a list-owned operation is deleted. It goes
/// through Operation::destroy(), which pairs with the malloc in create().
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
  op->destroy();
}

/// Return the Block that owns this operation list.
///
/// The list is a member of Block, so subtracting that member's offset from
/// the list's address yields the Block's address. The offset is computed with
/// a hand-rolled offsetof: a member reference formed through a null Block*.
/// NOTE(review): forming that reference through a null pointer is formally
/// undefined behaviour, even though it is a common idiom.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset);
}
482 
483 /// This is a trait method invoked when an operation is added to a block.  We
484 /// keep the block pointer up to date.
485 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
486   assert(!op->getBlock() && "already in an operation block!");
487   op->block = getContainingBlock();
488 
489   // Invalidate the order on the operation.
490   op->orderIndex = Operation::kInvalidOrderIdx;
491 }
492 
493 /// This is a trait method invoked when an operation is removed from a block.
494 /// We keep the block pointer up to date.
495 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) {
496   assert(op->block && "not already in an operation block!");
497   op->block = nullptr;
498 }
499 
500 /// This is a trait method invoked when an operation is moved from one block
501 /// to another.  We keep the block pointer up to date.
502 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList(
503     ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) {
504   Block *curParent = getContainingBlock();
505 
506   // Invalidate the ordering of the parent block.
507   curParent->invalidateOpOrder();
508 
509   // If we are transferring operations within the same block, the block
510   // pointer doesn't need to be updated.
511   if (curParent == otherList.getContainingBlock())
512     return;
513 
514   // Update the 'block' member of each operation.
515   for (; first != last; ++first)
516     first->block = curParent;
517 }
518 
519 /// Remove this operation (and its descendants) from its Block and delete
520 /// all of them.
521 void Operation::erase() {
522   if (auto *parent = getBlock())
523     parent->getOperations().erase(this);
524   else
525     destroy();
526 }
527 
528 /// Remove the operation from its parent block, but don't delete it.
529 void Operation::remove() {
530   if (Block *parent = getBlock())
531     parent->getOperations().remove(this);
532 }
533 
534 /// Unlink this operation from its current block and insert it right before
535 /// `existingOp` which may be in the same or another block in the same
536 /// function.
537 void Operation::moveBefore(Operation *existingOp) {
538   moveBefore(existingOp->getBlock(), existingOp->getIterator());
539 }
540 
541 /// Unlink this operation from its current basic block and insert it right
542 /// before `iterator` in the specified basic block.
543 void Operation::moveBefore(Block *block,
544                            llvm::iplist<Operation>::iterator iterator) {
545   block->getOperations().splice(iterator, getBlock()->getOperations(),
546                                 getIterator());
547 }
548 
549 /// Unlink this operation from its current block and insert it right after
550 /// `existingOp` which may be in the same or another block in the same function.
551 void Operation::moveAfter(Operation *existingOp) {
552   moveAfter(existingOp->getBlock(), existingOp->getIterator());
553 }
554 
555 /// Unlink this operation from its current block and insert it right after
556 /// `iterator` in the specified block.
557 void Operation::moveAfter(Block *block,
558                           llvm::iplist<Operation>::iterator iterator) {
559   assert(iterator != block->end() && "cannot move after end of block");
560   moveBefore(block, std::next(iterator));
561 }
562 
563 /// This drops all operand uses from this operation, which is an essential
564 /// step in breaking cyclic dependences between references when they are to
565 /// be deleted.
566 void Operation::dropAllReferences() {
567   for (auto &op : getOpOperands())
568     op.drop();
569 
570   for (auto &region : getRegions())
571     region.dropAllReferences();
572 
573   for (auto &dest : getBlockOperands())
574     dest.drop();
575 }
576 
577 /// This drops all uses of any values defined by this operation or its nested
578 /// regions, wherever they are located.
579 void Operation::dropAllDefinedValueUses() {
580   dropAllUses();
581 
582   for (auto &region : getRegions())
583     for (auto &block : region)
584       block.dropAllDefinedValueUses();
585 }
586 
587 void Operation::setSuccessor(Block *block, unsigned index) {
588   assert(index < getNumSuccessors());
589   getBlockOperands()[index].set(block);
590 }
591 
592 /// Attempt to fold this operation using the Op's registered foldHook.
593 LogicalResult Operation::fold(ArrayRef<Attribute> operands,
594                               SmallVectorImpl<OpFoldResult> &results) {
595   // If we have a registered operation definition matching this one, use it to
596   // try to constant fold the operation.
597   if (succeeded(name.foldHook(this, operands, results)))
598     return success();
599 
600   // Otherwise, fall back on the dialect hook to handle it.
601   Dialect *dialect = getDialect();
602   if (!dialect)
603     return failure();
604 
605   auto *interface = dyn_cast<DialectFoldInterface>(dialect);
606   if (!interface)
607     return failure();
608 
609   return interface->fold(this, operands, results);
610 }
611 
612 /// Emit an error with the op name prefixed, like "'dim' op " which is
613 /// convenient for verifiers.
614 InFlightDiagnostic Operation::emitOpError(const Twine &message) {
615   return emitError() << "'" << getName() << "' op " << message;
616 }
617 
618 //===----------------------------------------------------------------------===//
619 // Operation Cloning
620 //===----------------------------------------------------------------------===//
621 
622 Operation::CloneOptions::CloneOptions()
623     : cloneRegionsFlag(false), cloneOperandsFlag(false) {}
624 
625 Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands)
626     : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {}
627 
628 Operation::CloneOptions Operation::CloneOptions::all() {
629   return CloneOptions().cloneRegions().cloneOperands();
630 }
631 
632 Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) {
633   cloneRegionsFlag = enable;
634   return *this;
635 }
636 
637 Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) {
638   cloneOperandsFlag = enable;
639   return *this;
640 }
641 
642 /// Create a deep copy of this operation but keep the operation regions empty.
643 /// Operands are remapped using `mapper` (if present), and `mapper` is updated
644 /// to contain the results. The `mapResults` flag specifies whether the results
645 /// of the cloned operation should be added to the map.
646 Operation *Operation::cloneWithoutRegions(IRMapping &mapper) {
647   return clone(mapper, CloneOptions::all().cloneRegions(false));
648 }
649 
650 Operation *Operation::cloneWithoutRegions() {
651   IRMapping mapper;
652   return cloneWithoutRegions(mapper);
653 }
654 
655 /// Create a deep copy of this operation, remapping any operands that use
656 /// values outside of the operation using the map that is provided (leaving
657 /// them alone if no entry is present).  Replaces references to cloned
658 /// sub-operations to the corresponding operation that is copied, and adds
659 /// those mappings to the map.
660 Operation *Operation::clone(IRMapping &mapper, CloneOptions options) {
661   SmallVector<Value, 8> operands;
662   SmallVector<Block *, 2> successors;
663 
664   // Remap the operands.
665   if (options.shouldCloneOperands()) {
666     operands.reserve(getNumOperands());
667     for (auto opValue : getOperands())
668       operands.push_back(mapper.lookupOrDefault(opValue));
669   }
670 
671   // Remap the successors.
672   successors.reserve(getNumSuccessors());
673   for (Block *successor : getSuccessors())
674     successors.push_back(mapper.lookupOrDefault(successor));
675 
676   // Create the new operation.
677   auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs,
678                        getPropertiesStorage(), successors, getNumRegions());
679   mapper.map(this, newOp);
680 
681   // Clone the regions.
682   if (options.shouldCloneRegions()) {
683     for (unsigned i = 0; i != numRegions; ++i)
684       getRegion(i).cloneInto(&newOp->getRegion(i), mapper);
685   }
686 
687   // Remember the mapping of any results.
688   for (unsigned i = 0, e = getNumResults(); i != e; ++i)
689     mapper.map(getResult(i), newOp->getResult(i));
690 
691   return newOp;
692 }
693 
694 Operation *Operation::clone(CloneOptions options) {
695   IRMapping mapper;
696   return clone(mapper, options);
697 }
698 
699 //===----------------------------------------------------------------------===//
700 // OpState trait class.
701 //===----------------------------------------------------------------------===//
702 
703 // The fallback for the parser is to try for a dialect operation parser.
704 // Otherwise, reject the custom assembly form.
705 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) {
706   if (auto parseFn = result.name.getDialect()->getParseOperationHook(
707           result.name.getStringRef()))
708     return (*parseFn)(parser, result);
709   return parser.emitError(parser.getNameLoc(), "has no custom assembly form");
710 }
711 
712 // The fallback for the printer is to try for a dialect operation printer.
713 // Otherwise, it prints the generic form.
714 void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
715   if (auto printFn = op->getDialect()->getOperationPrinter(op)) {
716     printOpName(op, p, defaultDialect);
717     printFn(op, p);
718   } else {
719     p.printGenericOp(op);
720   }
721 }
722 
723 /// Print an operation name, eliding the dialect prefix if necessary and doesn't
724 /// lead to ambiguities.
725 void OpState::printOpName(Operation *op, OpAsmPrinter &p,
726                           StringRef defaultDialect) {
727   StringRef name = op->getName().getStringRef();
728   if (name.startswith((defaultDialect + ".").str()) && name.count('.') == 1)
729     name = name.drop_front(defaultDialect.size() + 1);
730   p.getStream() << name;
731 }
732 
733 /// Parse properties as a Attribute.
734 ParseResult OpState::genericParseProperties(OpAsmParser &parser,
735                                             Attribute &result) {
736   if (parser.parseLess() || parser.parseAttribute(result) ||
737       parser.parseGreater())
738     return failure();
739   return success();
740 }
741 
742 /// Print the properties as a Attribute.
743 void OpState::genericPrintProperties(OpAsmPrinter &p, Attribute properties) {
744   p << "<" << properties << ">";
745 }
746 
747 /// Emit an error about fatal conditions with this operation, reporting up to
748 /// any diagnostic handlers that may be listening.
749 InFlightDiagnostic OpState::emitError(const Twine &message) {
750   return getOperation()->emitError(message);
751 }
752 
753 /// Emit an error with the op name prefixed, like "'dim' op " which is
754 /// convenient for verifiers.
755 InFlightDiagnostic OpState::emitOpError(const Twine &message) {
756   return getOperation()->emitOpError(message);
757 }
758 
759 /// Emit a warning about this operation, reporting up to any diagnostic
760 /// handlers that may be listening.
761 InFlightDiagnostic OpState::emitWarning(const Twine &message) {
762   return getOperation()->emitWarning(message);
763 }
764 
765 /// Emit a remark about this operation, reporting up to any diagnostic
766 /// handlers that may be listening.
767 InFlightDiagnostic OpState::emitRemark(const Twine &message) {
768   return getOperation()->emitRemark(message);
769 }
770 
771 //===----------------------------------------------------------------------===//
772 // Op Trait implementations
773 //===----------------------------------------------------------------------===//
774 
775 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
776   if (op->getNumOperands() == 1) {
777     auto *argumentOp = op->getOperand(0).getDefiningOp();
778     if (argumentOp && op->getName() == argumentOp->getName()) {
779       // Replace the outer operation output with the inner operation.
780       return op->getOperand(0);
781     }
782   } else if (op->getOperand(0) == op->getOperand(1)) {
783     return op->getOperand(0);
784   }
785 
786   return {};
787 }
788 
789 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) {
790   auto *argumentOp = op->getOperand(0).getDefiningOp();
791   if (argumentOp && op->getName() == argumentOp->getName()) {
792     // Replace the outer involutions output with inner's input.
793     return argumentOp->getOperand(0);
794   }
795 
796   return {};
797 }
798 
799 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) {
800   if (op->getNumOperands() != 0)
801     return op->emitOpError() << "requires zero operands";
802   return success();
803 }
804 
805 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) {
806   if (op->getNumOperands() != 1)
807     return op->emitOpError() << "requires a single operand";
808   return success();
809 }
810 
811 LogicalResult OpTrait::impl::verifyNOperands(Operation *op,
812                                              unsigned numOperands) {
813   if (op->getNumOperands() != numOperands) {
814     return op->emitOpError() << "expected " << numOperands
815                              << " operands, but found " << op->getNumOperands();
816   }
817   return success();
818 }
819 
820 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op,
821                                                     unsigned numOperands) {
822   if (op->getNumOperands() < numOperands)
823     return op->emitOpError()
824            << "expected " << numOperands << " or more operands, but found "
825            << op->getNumOperands();
826   return success();
827 }
828 
829 /// If this is a vector type, or a tensor type, return the scalar element type
830 /// that it is built around, otherwise return the type unmodified.
831 static Type getTensorOrVectorElementType(Type type) {
832   if (auto vec = llvm::dyn_cast<VectorType>(type))
833     return vec.getElementType();
834 
835   // Look through tensor<vector<...>> to find the underlying element type.
836   if (auto tensor = llvm::dyn_cast<TensorType>(type))
837     return getTensorOrVectorElementType(tensor.getElementType());
838   return type;
839 }
840 
841 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) {
842   // FIXME: Add back check for no side effects on operation.
843   // Currently adding it would cause the shared library build
844   // to fail since there would be a dependency of IR on SideEffectInterfaces
845   // which is cyclical.
846   return success();
847 }
848 
849 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) {
850   // FIXME: Add back check for no side effects on operation.
851   // Currently adding it would cause the shared library build
852   // to fail since there would be a dependency of IR on SideEffectInterfaces
853   // which is cyclical.
854   return success();
855 }
856 
857 LogicalResult
858 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) {
859   for (auto opType : op->getOperandTypes()) {
860     auto type = getTensorOrVectorElementType(opType);
861     if (!type.isSignlessIntOrIndex())
862       return op->emitOpError() << "requires an integer or index type";
863   }
864   return success();
865 }
866 
867 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) {
868   for (auto opType : op->getOperandTypes()) {
869     auto type = getTensorOrVectorElementType(opType);
870     if (!llvm::isa<FloatType>(type))
871       return op->emitOpError("requires a float type");
872   }
873   return success();
874 }
875 
876 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) {
877   // Zero or one operand always have the "same" type.
878   unsigned nOperands = op->getNumOperands();
879   if (nOperands < 2)
880     return success();
881 
882   auto type = op->getOperand(0).getType();
883   for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
884     if (opType != type)
885       return op->emitOpError() << "requires all operands to have the same type";
886   return success();
887 }
888 
889 LogicalResult OpTrait::impl::verifyZeroRegions(Operation *op) {
890   if (op->getNumRegions() != 0)
891     return op->emitOpError() << "requires zero regions";
892   return success();
893 }
894 
895 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) {
896   if (op->getNumRegions() != 1)
897     return op->emitOpError() << "requires one region";
898   return success();
899 }
900 
901 LogicalResult OpTrait::impl::verifyNRegions(Operation *op,
902                                             unsigned numRegions) {
903   if (op->getNumRegions() != numRegions)
904     return op->emitOpError() << "expected " << numRegions << " regions";
905   return success();
906 }
907 
908 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op,
909                                                    unsigned numRegions) {
910   if (op->getNumRegions() < numRegions)
911     return op->emitOpError() << "expected " << numRegions << " or more regions";
912   return success();
913 }
914 
915 LogicalResult OpTrait::impl::verifyZeroResults(Operation *op) {
916   if (op->getNumResults() != 0)
917     return op->emitOpError() << "requires zero results";
918   return success();
919 }
920 
921 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) {
922   if (op->getNumResults() != 1)
923     return op->emitOpError() << "requires one result";
924   return success();
925 }
926 
927 LogicalResult OpTrait::impl::verifyNResults(Operation *op,
928                                             unsigned numOperands) {
929   if (op->getNumResults() != numOperands)
930     return op->emitOpError() << "expected " << numOperands << " results";
931   return success();
932 }
933 
934 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op,
935                                                    unsigned numOperands) {
936   if (op->getNumResults() < numOperands)
937     return op->emitOpError()
938            << "expected " << numOperands << " or more results";
939   return success();
940 }
941 
942 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
943   if (failed(verifyAtLeastNOperands(op, 1)))
944     return failure();
945 
946   if (failed(verifyCompatibleShapes(op->getOperandTypes())))
947     return op->emitOpError() << "requires the same shape for all operands";
948 
949   return success();
950 }
951 
952 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
953   if (failed(verifyAtLeastNOperands(op, 1)) ||
954       failed(verifyAtLeastNResults(op, 1)))
955     return failure();
956 
957   SmallVector<Type, 8> types(op->getOperandTypes());
958   types.append(llvm::to_vector<4>(op->getResultTypes()));
959 
960   if (failed(verifyCompatibleShapes(types)))
961     return op->emitOpError()
962            << "requires the same shape for all operands and results";
963 
964   return success();
965 }
966 
967 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) {
968   if (failed(verifyAtLeastNOperands(op, 1)))
969     return failure();
970   auto elementType = getElementTypeOrSelf(op->getOperand(0));
971 
972   for (auto operand : llvm::drop_begin(op->getOperands(), 1)) {
973     if (getElementTypeOrSelf(operand) != elementType)
974       return op->emitOpError("requires the same element type for all operands");
975   }
976 
977   return success();
978 }
979 
980 LogicalResult
981 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
982   if (failed(verifyAtLeastNOperands(op, 1)) ||
983       failed(verifyAtLeastNResults(op, 1)))
984     return failure();
985 
986   auto elementType = getElementTypeOrSelf(op->getResult(0));
987 
988   // Verify result element type matches first result's element type.
989   for (auto result : llvm::drop_begin(op->getResults(), 1)) {
990     if (getElementTypeOrSelf(result) != elementType)
991       return op->emitOpError(
992           "requires the same element type for all operands and results");
993   }
994 
995   // Verify operand's element type matches first result's element type.
996   for (auto operand : op->getOperands()) {
997     if (getElementTypeOrSelf(operand) != elementType)
998       return op->emitOpError(
999           "requires the same element type for all operands and results");
1000   }
1001 
1002   return success();
1003 }
1004 
1005 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
1006   if (failed(verifyAtLeastNOperands(op, 1)) ||
1007       failed(verifyAtLeastNResults(op, 1)))
1008     return failure();
1009 
1010   auto type = op->getResult(0).getType();
1011   auto elementType = getElementTypeOrSelf(type);
1012   Attribute encoding = nullptr;
1013   if (auto rankedType = dyn_cast<RankedTensorType>(type))
1014     encoding = rankedType.getEncoding();
1015   for (auto resultType : llvm::drop_begin(op->getResultTypes())) {
1016     if (getElementTypeOrSelf(resultType) != elementType ||
1017         failed(verifyCompatibleShape(resultType, type)))
1018       return op->emitOpError()
1019              << "requires the same type for all operands and results";
1020     if (encoding)
1021       if (auto rankedType = dyn_cast<RankedTensorType>(resultType);
1022           encoding != rankedType.getEncoding())
1023         return op->emitOpError()
1024                << "requires the same encoding for all operands and results";
1025   }
1026   for (auto opType : op->getOperandTypes()) {
1027     if (getElementTypeOrSelf(opType) != elementType ||
1028         failed(verifyCompatibleShape(opType, type)))
1029       return op->emitOpError()
1030              << "requires the same type for all operands and results";
1031     if (encoding)
1032       if (auto rankedType = dyn_cast<RankedTensorType>(opType);
1033           encoding != rankedType.getEncoding())
1034         return op->emitOpError()
1035                << "requires the same encoding for all operands and results";
1036   }
1037   return success();
1038 }
1039 
1040 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
1041   Block *block = op->getBlock();
1042   // Verify that the operation is at the end of the respective parent block.
1043   if (!block || &block->back() != op)
1044     return op->emitOpError("must be the last operation in the parent block");
1045   return success();
1046 }
1047 
1048 static LogicalResult verifyTerminatorSuccessors(Operation *op) {
1049   auto *parent = op->getParentRegion();
1050 
1051   // Verify that the operands lines up with the BB arguments in the successor.
1052   for (Block *succ : op->getSuccessors())
1053     if (succ->getParent() != parent)
1054       return op->emitError("reference to block defined in another region");
1055   return success();
1056 }
1057 
1058 LogicalResult OpTrait::impl::verifyZeroSuccessors(Operation *op) {
1059   if (op->getNumSuccessors() != 0) {
1060     return op->emitOpError("requires 0 successors but found ")
1061            << op->getNumSuccessors();
1062   }
1063   return success();
1064 }
1065 
1066 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) {
1067   if (op->getNumSuccessors() != 1) {
1068     return op->emitOpError("requires 1 successor but found ")
1069            << op->getNumSuccessors();
1070   }
1071   return verifyTerminatorSuccessors(op);
1072 }
1073 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op,
1074                                                unsigned numSuccessors) {
1075   if (op->getNumSuccessors() != numSuccessors) {
1076     return op->emitOpError("requires ")
1077            << numSuccessors << " successors but found "
1078            << op->getNumSuccessors();
1079   }
1080   return verifyTerminatorSuccessors(op);
1081 }
1082 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op,
1083                                                       unsigned numSuccessors) {
1084   if (op->getNumSuccessors() < numSuccessors) {
1085     return op->emitOpError("requires at least ")
1086            << numSuccessors << " successors but found "
1087            << op->getNumSuccessors();
1088   }
1089   return verifyTerminatorSuccessors(op);
1090 }
1091 
1092 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
1093   for (auto resultType : op->getResultTypes()) {
1094     auto elementType = getTensorOrVectorElementType(resultType);
1095     bool isBoolType = elementType.isInteger(1);
1096     if (!isBoolType)
1097       return op->emitOpError() << "requires a bool result type";
1098   }
1099 
1100   return success();
1101 }
1102 
1103 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) {
1104   for (auto resultType : op->getResultTypes())
1105     if (!llvm::isa<FloatType>(getTensorOrVectorElementType(resultType)))
1106       return op->emitOpError() << "requires a floating point type";
1107 
1108   return success();
1109 }
1110 
1111 LogicalResult
1112 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) {
1113   for (auto resultType : op->getResultTypes())
1114     if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex())
1115       return op->emitOpError() << "requires an integer or index type";
1116   return success();
1117 }
1118 
1119 LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op,
1120                                                  StringRef attrName,
1121                                                  StringRef valueGroupName,
1122                                                  size_t expectedCount) {
1123   auto sizeAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrName);
1124   if (!sizeAttr)
1125     return op->emitOpError("requires dense i32 array attribute '")
1126            << attrName << "'";
1127 
1128   ArrayRef<int32_t> sizes = sizeAttr.asArrayRef();
1129   if (llvm::any_of(sizes, [](int32_t element) { return element < 0; }))
1130     return op->emitOpError("'")
1131            << attrName << "' attribute cannot have negative elements";
1132 
1133   size_t totalCount =
1134       std::accumulate(sizes.begin(), sizes.end(), 0,
1135                       [](unsigned all, int32_t one) { return all + one; });
1136 
1137   if (totalCount != expectedCount)
1138     return op->emitOpError()
1139            << valueGroupName << " count (" << expectedCount
1140            << ") does not match with the total size (" << totalCount
1141            << ") specified in attribute '" << attrName << "'";
1142   return success();
1143 }
1144 
1145 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op,
1146                                                    StringRef attrName) {
1147   return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands());
1148 }
1149 
1150 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op,
1151                                                   StringRef attrName) {
1152   return verifyValueSizeAttr(op, attrName, "result", op->getNumResults());
1153 }
1154 
1155 LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
1156   for (Region &region : op->getRegions()) {
1157     if (region.empty())
1158       continue;
1159 
1160     if (region.getNumArguments() != 0) {
1161       if (op->getNumRegions() > 1)
1162         return op->emitOpError("region #")
1163                << region.getRegionNumber() << " should have no arguments";
1164       return op->emitOpError("region should have no arguments");
1165     }
1166   }
1167   return success();
1168 }
1169 
1170 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
1171   auto isMappableType = [](Type type) {
1172     return llvm::isa<VectorType, TensorType>(type);
1173   };
1174   auto resultMappableTypes = llvm::to_vector<1>(
1175       llvm::make_filter_range(op->getResultTypes(), isMappableType));
1176   auto operandMappableTypes = llvm::to_vector<2>(
1177       llvm::make_filter_range(op->getOperandTypes(), isMappableType));
1178 
1179   // If the op only has scalar operand/result types, then we have nothing to
1180   // check.
1181   if (resultMappableTypes.empty() && operandMappableTypes.empty())
1182     return success();
1183 
1184   if (!resultMappableTypes.empty() && operandMappableTypes.empty())
1185     return op->emitOpError("if a result is non-scalar, then at least one "
1186                            "operand must be non-scalar");
1187 
1188   assert(!operandMappableTypes.empty());
1189 
1190   if (resultMappableTypes.empty())
1191     return op->emitOpError("if an operand is non-scalar, then there must be at "
1192                            "least one non-scalar result");
1193 
1194   if (resultMappableTypes.size() != op->getNumResults())
1195     return op->emitOpError(
1196         "if an operand is non-scalar, then all results must be non-scalar");
1197 
1198   SmallVector<Type, 4> types = llvm::to_vector<2>(
1199       llvm::concat<Type>(operandMappableTypes, resultMappableTypes));
1200   TypeID expectedBaseTy = types.front().getTypeID();
1201   if (!llvm::all_of(types,
1202                     [&](Type t) { return t.getTypeID() == expectedBaseTy; }) ||
1203       failed(verifyCompatibleShapes(types))) {
1204     return op->emitOpError() << "all non-scalar operands/results must have the "
1205                                 "same shape and base type";
1206   }
1207 
1208   return success();
1209 }
1210 
1211 /// Check for any values used by operations regions attached to the
1212 /// specified "IsIsolatedFromAbove" operation defined outside of it.
1213 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) {
1214   assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
1215          "Intended to check IsolatedFromAbove ops");
1216 
1217   // List of regions to analyze.  Each region is processed independently, with
1218   // respect to the common `limit` region, so we can look at them in any order.
1219   // Therefore, use a simple vector and push/pop back the current region.
1220   SmallVector<Region *, 8> pendingRegions;
1221   for (auto &region : isolatedOp->getRegions()) {
1222     pendingRegions.push_back(&region);
1223 
1224     // Traverse all operations in the region.
1225     while (!pendingRegions.empty()) {
1226       for (Operation &op : pendingRegions.pop_back_val()->getOps()) {
1227         for (Value operand : op.getOperands()) {
1228           // Check that any value that is used by an operation is defined in the
1229           // same region as either an operation result.
1230           auto *operandRegion = operand.getParentRegion();
1231           if (!operandRegion)
1232             return op.emitError("operation's operand is unlinked");
1233           if (!region.isAncestor(operandRegion)) {
1234             return op.emitOpError("using value defined outside the region")
1235                        .attachNote(isolatedOp->getLoc())
1236                    << "required by region isolation constraints";
1237           }
1238         }
1239 
1240         // Schedule any regions in the operation for further checking.  Don't
1241         // recurse into other IsolatedFromAbove ops, because they will check
1242         // themselves.
1243         if (op.getNumRegions() &&
1244             !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) {
1245           for (Region &subRegion : op.getRegions())
1246             pendingRegions.push_back(&subRegion);
1247         }
1248       }
1249     }
1250   }
1251 
1252   return success();
1253 }
1254 
1255 bool OpTrait::hasElementwiseMappableTraits(Operation *op) {
1256   return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() &&
1257          op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>();
1258 }
1259 
1260 //===----------------------------------------------------------------------===//
1261 // CastOpInterface
1262 //===----------------------------------------------------------------------===//
1263 
1264 /// Attempt to fold the given cast operation.
1265 LogicalResult
1266 impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands,
1267                           SmallVectorImpl<OpFoldResult> &foldResults) {
1268   OperandRange operands = op->getOperands();
1269   if (operands.empty())
1270     return failure();
1271   ResultRange results = op->getResults();
1272 
1273   // Check for the case where the input and output types match 1-1.
1274   if (operands.getTypes() == results.getTypes()) {
1275     foldResults.append(operands.begin(), operands.end());
1276     return success();
1277   }
1278 
1279   return failure();
1280 }
1281 
1282 /// Attempt to verify the given cast operation.
1283 LogicalResult impl::verifyCastInterfaceOp(
1284     Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) {
1285   auto resultTypes = op->getResultTypes();
1286   if (resultTypes.empty())
1287     return op->emitOpError()
1288            << "expected at least one result for cast operation";
1289 
1290   auto operandTypes = op->getOperandTypes();
1291   if (!areCastCompatible(operandTypes, resultTypes)) {
1292     InFlightDiagnostic diag = op->emitOpError("operand type");
1293     if (operandTypes.empty())
1294       diag << "s []";
1295     else if (llvm::size(operandTypes) == 1)
1296       diag << " " << *operandTypes.begin();
1297     else
1298       diag << "s " << operandTypes;
1299     return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ")
1300                 << resultTypes << " are cast incompatible";
1301   }
1302 
1303   return success();
1304 }
1305 
1306 //===----------------------------------------------------------------------===//
1307 // Misc. utils
1308 //===----------------------------------------------------------------------===//
1309 
1310 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the
1311 /// region's only block if it does not have a terminator already. If the region
1312 /// is empty, insert a new block first. `buildTerminatorOp` should return the
1313 /// terminator operation to insert.
1314 void impl::ensureRegionTerminator(
1315     Region &region, OpBuilder &builder, Location loc,
1316     function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) {
1317   OpBuilder::InsertionGuard guard(builder);
1318   if (region.empty())
1319     builder.createBlock(&region);
1320 
1321   Block &block = region.back();
1322   if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>())
1323     return;
1324 
1325   builder.setInsertionPointToEnd(&block);
1326   builder.insert(buildTerminatorOp(builder, loc));
1327 }
1328 
1329 /// Create a simple OpBuilder and forward to the OpBuilder version of this
1330 /// function.
1331 void impl::ensureRegionTerminator(
1332     Region &region, Builder &builder, Location loc,
1333     function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) {
1334   OpBuilder opBuilder(builder.getContext());
1335   ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp);
1336 }
1337