16ccf2d62SRiver Riddle //===- CastInterfaces.cpp -------------------------------------------------===//
26ccf2d62SRiver Riddle //
36ccf2d62SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
46ccf2d62SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
56ccf2d62SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66ccf2d62SRiver Riddle //
76ccf2d62SRiver Riddle //===----------------------------------------------------------------------===//
86ccf2d62SRiver Riddle
96ccf2d62SRiver Riddle #include "mlir/Interfaces/CastInterfaces.h"
106ccf2d62SRiver Riddle
11*e66f2bebSMatthias Springer #include "mlir/IR/BuiltinDialect.h"
12*e66f2bebSMatthias Springer #include "mlir/IR/BuiltinOps.h"
13*e66f2bebSMatthias Springer
146ccf2d62SRiver Riddle using namespace mlir;
156ccf2d62SRiver Riddle
166ccf2d62SRiver Riddle //===----------------------------------------------------------------------===//
17*e66f2bebSMatthias Springer // Helper functions for CastOpInterface
18*e66f2bebSMatthias Springer //===----------------------------------------------------------------------===//
19*e66f2bebSMatthias Springer
20*e66f2bebSMatthias Springer /// Attempt to fold the given cast operation.
21*e66f2bebSMatthias Springer LogicalResult
foldCastInterfaceOp(Operation * op,ArrayRef<Attribute> attrOperands,SmallVectorImpl<OpFoldResult> & foldResults)22*e66f2bebSMatthias Springer impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands,
23*e66f2bebSMatthias Springer SmallVectorImpl<OpFoldResult> &foldResults) {
24*e66f2bebSMatthias Springer OperandRange operands = op->getOperands();
25*e66f2bebSMatthias Springer if (operands.empty())
26*e66f2bebSMatthias Springer return failure();
27*e66f2bebSMatthias Springer ResultRange results = op->getResults();
28*e66f2bebSMatthias Springer
29*e66f2bebSMatthias Springer // Check for the case where the input and output types match 1-1.
30*e66f2bebSMatthias Springer if (operands.getTypes() == results.getTypes()) {
31*e66f2bebSMatthias Springer foldResults.append(operands.begin(), operands.end());
32*e66f2bebSMatthias Springer return success();
33*e66f2bebSMatthias Springer }
34*e66f2bebSMatthias Springer
35*e66f2bebSMatthias Springer return failure();
36*e66f2bebSMatthias Springer }
37*e66f2bebSMatthias Springer
38*e66f2bebSMatthias Springer /// Attempt to verify the given cast operation.
verifyCastInterfaceOp(Operation * op)39*e66f2bebSMatthias Springer LogicalResult impl::verifyCastInterfaceOp(Operation *op) {
40*e66f2bebSMatthias Springer auto resultTypes = op->getResultTypes();
41*e66f2bebSMatthias Springer if (resultTypes.empty())
42*e66f2bebSMatthias Springer return op->emitOpError()
43*e66f2bebSMatthias Springer << "expected at least one result for cast operation";
44*e66f2bebSMatthias Springer
45*e66f2bebSMatthias Springer auto operandTypes = op->getOperandTypes();
46*e66f2bebSMatthias Springer if (!cast<CastOpInterface>(op).areCastCompatible(operandTypes, resultTypes)) {
47*e66f2bebSMatthias Springer InFlightDiagnostic diag = op->emitOpError("operand type");
48*e66f2bebSMatthias Springer if (operandTypes.empty())
49*e66f2bebSMatthias Springer diag << "s []";
50*e66f2bebSMatthias Springer else if (llvm::size(operandTypes) == 1)
51*e66f2bebSMatthias Springer diag << " " << *operandTypes.begin();
52*e66f2bebSMatthias Springer else
53*e66f2bebSMatthias Springer diag << "s " << operandTypes;
54*e66f2bebSMatthias Springer return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ")
55*e66f2bebSMatthias Springer << resultTypes << " are cast incompatible";
56*e66f2bebSMatthias Springer }
57*e66f2bebSMatthias Springer
58*e66f2bebSMatthias Springer return success();
59*e66f2bebSMatthias Springer }
60*e66f2bebSMatthias Springer
61*e66f2bebSMatthias Springer //===----------------------------------------------------------------------===//
62*e66f2bebSMatthias Springer // External model for BuiltinDialect ops
63*e66f2bebSMatthias Springer //===----------------------------------------------------------------------===//
64*e66f2bebSMatthias Springer
65*e66f2bebSMatthias Springer namespace mlir {
66*e66f2bebSMatthias Springer namespace {
67*e66f2bebSMatthias Springer // This interface cannot be implemented directly on the op because the IR build
68*e66f2bebSMatthias Springer // unit cannot depend on the Interfaces build unit.
69*e66f2bebSMatthias Springer struct UnrealizedConversionCastOpInterface
70*e66f2bebSMatthias Springer : CastOpInterface::ExternalModel<UnrealizedConversionCastOpInterface,
71*e66f2bebSMatthias Springer UnrealizedConversionCastOp> {
areCastCompatiblemlir::__anond689e21e0111::UnrealizedConversionCastOpInterface72*e66f2bebSMatthias Springer static bool areCastCompatible(TypeRange inputs, TypeRange outputs) {
73*e66f2bebSMatthias Springer // `UnrealizedConversionCastOp` is agnostic of the input/output types.
74*e66f2bebSMatthias Springer return true;
75*e66f2bebSMatthias Springer }
76*e66f2bebSMatthias Springer };
77*e66f2bebSMatthias Springer } // namespace
78*e66f2bebSMatthias Springer } // namespace mlir
79*e66f2bebSMatthias Springer
registerCastOpInterfaceExternalModels(DialectRegistry & registry)80*e66f2bebSMatthias Springer void mlir::builtin::registerCastOpInterfaceExternalModels(
81*e66f2bebSMatthias Springer DialectRegistry ®istry) {
82*e66f2bebSMatthias Springer registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
83*e66f2bebSMatthias Springer UnrealizedConversionCastOp::attachInterface<
84*e66f2bebSMatthias Springer UnrealizedConversionCastOpInterface>(*ctx);
85*e66f2bebSMatthias Springer });
86*e66f2bebSMatthias Springer }
87*e66f2bebSMatthias Springer
88*e66f2bebSMatthias Springer //===----------------------------------------------------------------------===//
896ccf2d62SRiver Riddle // Table-generated class definitions
906ccf2d62SRiver Riddle //===----------------------------------------------------------------------===//
916ccf2d62SRiver Riddle
926ccf2d62SRiver Riddle #include "mlir/Interfaces/CastInterfaces.cpp.inc"
93