xref: /llvm-project/mlir/lib/Interfaces/CastInterfaces.cpp (revision e66f2beba8b38b148d3a892326a7133c388ffbfb)
1 //===- CastInterfaces.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/CastInterfaces.h"
10 
11 #include "mlir/IR/BuiltinDialect.h"
12 #include "mlir/IR/BuiltinOps.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // Helper functions for CastOpInterface
18 //===----------------------------------------------------------------------===//
19 
20 /// Attempt to fold the given cast operation.
21 LogicalResult
foldCastInterfaceOp(Operation * op,ArrayRef<Attribute> attrOperands,SmallVectorImpl<OpFoldResult> & foldResults)22 impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands,
23                           SmallVectorImpl<OpFoldResult> &foldResults) {
24   OperandRange operands = op->getOperands();
25   if (operands.empty())
26     return failure();
27   ResultRange results = op->getResults();
28 
29   // Check for the case where the input and output types match 1-1.
30   if (operands.getTypes() == results.getTypes()) {
31     foldResults.append(operands.begin(), operands.end());
32     return success();
33   }
34 
35   return failure();
36 }
37 
38 /// Attempt to verify the given cast operation.
verifyCastInterfaceOp(Operation * op)39 LogicalResult impl::verifyCastInterfaceOp(Operation *op) {
40   auto resultTypes = op->getResultTypes();
41   if (resultTypes.empty())
42     return op->emitOpError()
43            << "expected at least one result for cast operation";
44 
45   auto operandTypes = op->getOperandTypes();
46   if (!cast<CastOpInterface>(op).areCastCompatible(operandTypes, resultTypes)) {
47     InFlightDiagnostic diag = op->emitOpError("operand type");
48     if (operandTypes.empty())
49       diag << "s []";
50     else if (llvm::size(operandTypes) == 1)
51       diag << " " << *operandTypes.begin();
52     else
53       diag << "s " << operandTypes;
54     return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ")
55                 << resultTypes << " are cast incompatible";
56   }
57 
58   return success();
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // External model for BuiltinDialect ops
63 //===----------------------------------------------------------------------===//
64 
65 namespace mlir {
66 namespace {
67 // This interface cannot be implemented directly on the op because the IR build
68 // unit cannot depend on the Interfaces build unit.
69 struct UnrealizedConversionCastOpInterface
70     : CastOpInterface::ExternalModel<UnrealizedConversionCastOpInterface,
71                                      UnrealizedConversionCastOp> {
areCastCompatiblemlir::__anond689e21e0111::UnrealizedConversionCastOpInterface72   static bool areCastCompatible(TypeRange inputs, TypeRange outputs) {
73     // `UnrealizedConversionCastOp` is agnostic of the input/output types.
74     return true;
75   }
76 };
77 } // namespace
78 } // namespace mlir
79 
registerCastOpInterfaceExternalModels(DialectRegistry & registry)80 void mlir::builtin::registerCastOpInterfaceExternalModels(
81     DialectRegistry &registry) {
82   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
83     UnrealizedConversionCastOp::attachInterface<
84         UnrealizedConversionCastOpInterface>(*ctx);
85   });
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // Table-generated class definitions
90 //===----------------------------------------------------------------------===//
91 
92 #include "mlir/Interfaces/CastInterfaces.cpp.inc"
93