1d0cb0d30SAlexander Belyaev //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2d0cb0d30SAlexander Belyaev //
3d0cb0d30SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d0cb0d30SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5d0cb0d30SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d0cb0d30SAlexander Belyaev //
7d0cb0d30SAlexander Belyaev //===----------------------------------------------------------------------===//
8d0cb0d30SAlexander Belyaev
9a1e78615SLei Zhang #include "mlir/Dialect/Arith/IR/Arith.h"
10d0cb0d30SAlexander Belyaev #include "mlir/Dialect/Complex/IR/Complex.h"
11a28fe17dSAdrian Kuegel #include "mlir/IR/Builders.h"
12a1e78615SLei Zhang #include "mlir/IR/BuiltinTypes.h"
13036a6996Slewuathe #include "mlir/IR/Matchers.h"
14a1e78615SLei Zhang #include "mlir/IR/PatternMatch.h"
15d0cb0d30SAlexander Belyaev
16d0cb0d30SAlexander Belyaev using namespace mlir;
17d0cb0d30SAlexander Belyaev using namespace mlir::complex;
18d0cb0d30SAlexander Belyaev
19d0cb0d30SAlexander Belyaev //===----------------------------------------------------------------------===//
20480cd4cbSRiver Riddle // ConstantOp
21d0cb0d30SAlexander Belyaev //===----------------------------------------------------------------------===//
22d0cb0d30SAlexander Belyaev
fold(FoldAdaptor adaptor)237df76121SMarkus Böck OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
24480cd4cbSRiver Riddle return getValue();
25480cd4cbSRiver Riddle }
26480cd4cbSRiver Riddle
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)27480cd4cbSRiver Riddle void ConstantOp::getAsmResultNames(
28480cd4cbSRiver Riddle function_ref<void(Value, StringRef)> setNameFn) {
29480cd4cbSRiver Riddle setNameFn(getResult(), "cst");
30480cd4cbSRiver Riddle }
31480cd4cbSRiver Riddle
isBuildableWith(Attribute value,Type type)32480cd4cbSRiver Riddle bool ConstantOp::isBuildableWith(Attribute value, Type type) {
33c1fa60b4STres Popp if (auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
34c1fa60b4STres Popp auto complexTy = llvm::dyn_cast<ComplexType>(type);
35e1795322SJeff Niu if (!complexTy || arrAttr.size() != 2)
36480cd4cbSRiver Riddle return false;
37480cd4cbSRiver Riddle auto complexEltTy = complexTy.getElementType();
38c1fa60b4STres Popp if (auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
39c1fa60b4STres Popp auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
40cc4fb583SXiang Li return im && fre.getType() == complexEltTy &&
41e1795322SJeff Niu im.getType() == complexEltTy;
42480cd4cbSRiver Riddle }
43c1fa60b4STres Popp if (auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
44c1fa60b4STres Popp auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
45cc4fb583SXiang Li return im && ire.getType() == complexEltTy &&
46cc4fb583SXiang Li im.getType() == complexEltTy;
47cc4fb583SXiang Li }
48cc4fb583SXiang Li }
49480cd4cbSRiver Riddle return false;
50480cd4cbSRiver Riddle }
51480cd4cbSRiver Riddle
verify()521be88f5aSRiver Riddle LogicalResult ConstantOp::verify() {
531be88f5aSRiver Riddle ArrayAttr arrayAttr = getValue();
54480cd4cbSRiver Riddle if (arrayAttr.size() != 2) {
551be88f5aSRiver Riddle return emitOpError(
56480cd4cbSRiver Riddle "requires 'value' to be a complex constant, represented as array of "
57480cd4cbSRiver Riddle "two values");
58480cd4cbSRiver Riddle }
59480cd4cbSRiver Riddle
601be88f5aSRiver Riddle auto complexEltTy = getType().getElementType();
6116129937SMatthias Springer if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
6216129937SMatthias Springer !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
6316129937SMatthias Springer return emitOpError(
6416129937SMatthias Springer "requires attribute's elements to be float or integer attributes");
6516129937SMatthias Springer auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
6616129937SMatthias Springer auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
67e1795322SJeff Niu if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
681be88f5aSRiver Riddle return emitOpError()
69e1795322SJeff Niu << "requires attribute's element types (" << re.getType() << ", "
70e1795322SJeff Niu << im.getType()
71480cd4cbSRiver Riddle << ") to match the element type of the op's return type ("
72480cd4cbSRiver Riddle << complexEltTy << ")";
73480cd4cbSRiver Riddle }
74480cd4cbSRiver Riddle return success();
75480cd4cbSRiver Riddle }
76480cd4cbSRiver Riddle
77480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
78a8df21f4SRob Suderman // BitcastOp
79a8df21f4SRob Suderman //===----------------------------------------------------------------------===//
80a8df21f4SRob Suderman
fold(FoldAdaptor bitcast)81a8df21f4SRob Suderman OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
82a8df21f4SRob Suderman if (getOperand().getType() == getType())
83a8df21f4SRob Suderman return getOperand();
84a8df21f4SRob Suderman
85a8df21f4SRob Suderman return {};
86a8df21f4SRob Suderman }
87a8df21f4SRob Suderman
verify()88a8df21f4SRob Suderman LogicalResult BitcastOp::verify() {
89a8df21f4SRob Suderman auto operandType = getOperand().getType();
90a8df21f4SRob Suderman auto resultType = getType();
91a8df21f4SRob Suderman
92a8df21f4SRob Suderman // We allow this to be legal as it can be folded away.
93a8df21f4SRob Suderman if (operandType == resultType)
94a8df21f4SRob Suderman return success();
95a8df21f4SRob Suderman
96a8df21f4SRob Suderman if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
97a8df21f4SRob Suderman return emitOpError("operand must be int/float/complex");
98a8df21f4SRob Suderman }
99a8df21f4SRob Suderman
100a8df21f4SRob Suderman if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
101a8df21f4SRob Suderman return emitOpError("result must be int/float/complex");
102a8df21f4SRob Suderman }
103a8df21f4SRob Suderman
104a8df21f4SRob Suderman if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
105192439dbSMatthias Springer return emitOpError(
106192439dbSMatthias Springer "requires that either input or output has a complex type");
107a8df21f4SRob Suderman }
108a8df21f4SRob Suderman
109a8df21f4SRob Suderman if (isa<ComplexType>(resultType))
110a8df21f4SRob Suderman std::swap(operandType, resultType);
111a8df21f4SRob Suderman
112a8df21f4SRob Suderman int32_t operandBitwidth = dyn_cast<ComplexType>(operandType)
113a8df21f4SRob Suderman .getElementType()
114a8df21f4SRob Suderman .getIntOrFloatBitWidth() *
115a8df21f4SRob Suderman 2;
116a8df21f4SRob Suderman int32_t resultBitwidth = resultType.getIntOrFloatBitWidth();
117a8df21f4SRob Suderman
118a8df21f4SRob Suderman if (operandBitwidth != resultBitwidth) {
119a8df21f4SRob Suderman return emitOpError("casting bitwidths do not match");
120a8df21f4SRob Suderman }
121a8df21f4SRob Suderman
122a8df21f4SRob Suderman return success();
123a8df21f4SRob Suderman }
124a8df21f4SRob Suderman
125a8df21f4SRob Suderman struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
126a8df21f4SRob Suderman using OpRewritePattern<BitcastOp>::OpRewritePattern;
127a8df21f4SRob Suderman
matchAndRewriteMergeComplexBitcast128a8df21f4SRob Suderman LogicalResult matchAndRewrite(BitcastOp op,
129a8df21f4SRob Suderman PatternRewriter &rewriter) const override {
130a8df21f4SRob Suderman if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
131192439dbSMatthias Springer if (isa<ComplexType>(op.getType()) ||
132192439dbSMatthias Springer isa<ComplexType>(defining.getOperand().getType())) {
133192439dbSMatthias Springer // complex.bitcast requires that input or output is complex.
134a8df21f4SRob Suderman rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
135a8df21f4SRob Suderman defining.getOperand());
136192439dbSMatthias Springer } else {
137192439dbSMatthias Springer rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
138192439dbSMatthias Springer defining.getOperand());
139192439dbSMatthias Springer }
140a8df21f4SRob Suderman return success();
141a8df21f4SRob Suderman }
142a8df21f4SRob Suderman
143a8df21f4SRob Suderman if (auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
144a8df21f4SRob Suderman rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
145a8df21f4SRob Suderman defining.getOperand());
146a8df21f4SRob Suderman return success();
147a8df21f4SRob Suderman }
148a8df21f4SRob Suderman
149a8df21f4SRob Suderman return failure();
150a8df21f4SRob Suderman }
151a8df21f4SRob Suderman };
152a8df21f4SRob Suderman
153a8df21f4SRob Suderman struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
154a8df21f4SRob Suderman using OpRewritePattern<arith::BitcastOp>::OpRewritePattern;
155a8df21f4SRob Suderman
matchAndRewriteMergeArithBitcast156a8df21f4SRob Suderman LogicalResult matchAndRewrite(arith::BitcastOp op,
157a8df21f4SRob Suderman PatternRewriter &rewriter) const override {
158a8df21f4SRob Suderman if (auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) {
159a8df21f4SRob Suderman rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(),
160a8df21f4SRob Suderman defining.getOperand());
161a8df21f4SRob Suderman return success();
162a8df21f4SRob Suderman }
163a8df21f4SRob Suderman
164a8df21f4SRob Suderman return failure();
165a8df21f4SRob Suderman }
166a8df21f4SRob Suderman };
167a8df21f4SRob Suderman
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)168a8df21f4SRob Suderman void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
169a8df21f4SRob Suderman MLIRContext *context) {
170192439dbSMatthias Springer results.add<MergeComplexBitcast, MergeArithBitcast>(context);
171a8df21f4SRob Suderman }
172a8df21f4SRob Suderman
173a8df21f4SRob Suderman //===----------------------------------------------------------------------===//
174480cd4cbSRiver Riddle // CreateOp
175480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
176fa765a09SAdrian Kuegel
fold(FoldAdaptor adaptor)1777df76121SMarkus Böck OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
178dee46d08SAdrian Kuegel // Fold complex.create(complex.re(op), complex.im(op)).
179dee46d08SAdrian Kuegel if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
180dee46d08SAdrian Kuegel if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
181dee46d08SAdrian Kuegel if (reOp.getOperand() == imOp.getOperand()) {
182dee46d08SAdrian Kuegel return reOp.getOperand();
183dee46d08SAdrian Kuegel }
184dee46d08SAdrian Kuegel }
185dee46d08SAdrian Kuegel }
186fa765a09SAdrian Kuegel return {};
187fa765a09SAdrian Kuegel }
188fa765a09SAdrian Kuegel
189480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
190480cd4cbSRiver Riddle // ImOp
191480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
192480cd4cbSRiver Riddle
fold(FoldAdaptor adaptor)1937df76121SMarkus Böck OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
19468f58812STres Popp ArrayAttr arrayAttr =
19568f58812STres Popp llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
196fa765a09SAdrian Kuegel if (arrayAttr && arrayAttr.size() == 2)
197fa765a09SAdrian Kuegel return arrayAttr[1];
198cb65419bSAdrian Kuegel if (auto createOp = getOperand().getDefiningOp<CreateOp>())
199b99f892bSAdrian Kuegel return createOp.getOperand(1);
200fa765a09SAdrian Kuegel return {};
201fa765a09SAdrian Kuegel }
202dee46d08SAdrian Kuegel
203a1e78615SLei Zhang namespace {
204a1e78615SLei Zhang template <typename OpKind, int ComponentIndex>
205a1e78615SLei Zhang struct FoldComponentNeg final : OpRewritePattern<OpKind> {
206a1e78615SLei Zhang using OpRewritePattern<OpKind>::OpRewritePattern;
207a1e78615SLei Zhang
matchAndRewrite__anon416dd25c0111::FoldComponentNeg208a1e78615SLei Zhang LogicalResult matchAndRewrite(OpKind op,
209a1e78615SLei Zhang PatternRewriter &rewriter) const override {
210a1e78615SLei Zhang auto negOp = op.getOperand().template getDefiningOp<NegOp>();
211a1e78615SLei Zhang if (!negOp)
212a1e78615SLei Zhang return failure();
213a1e78615SLei Zhang
214a1e78615SLei Zhang auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
215a1e78615SLei Zhang if (!createOp)
216a1e78615SLei Zhang return failure();
217a1e78615SLei Zhang
218a1e78615SLei Zhang Type elementType = createOp.getType().getElementType();
219a1e78615SLei Zhang assert(isa<FloatType>(elementType));
220a1e78615SLei Zhang
221a1e78615SLei Zhang rewriter.replaceOpWithNewOp<arith::NegFOp>(
222a1e78615SLei Zhang op, elementType, createOp.getOperand(ComponentIndex));
223a1e78615SLei Zhang return success();
224a1e78615SLei Zhang }
225a1e78615SLei Zhang };
226a1e78615SLei Zhang } // namespace
227a1e78615SLei Zhang
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)228a1e78615SLei Zhang void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
229a1e78615SLei Zhang MLIRContext *context) {
230a1e78615SLei Zhang results.add<FoldComponentNeg<ImOp, 1>>(context);
231a1e78615SLei Zhang }
232a1e78615SLei Zhang
233480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
234480cd4cbSRiver Riddle // ReOp
235480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
236480cd4cbSRiver Riddle
fold(FoldAdaptor adaptor)2377df76121SMarkus Böck OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
23868f58812STres Popp ArrayAttr arrayAttr =
23968f58812STres Popp llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
240dee46d08SAdrian Kuegel if (arrayAttr && arrayAttr.size() == 2)
241dee46d08SAdrian Kuegel return arrayAttr[0];
242dee46d08SAdrian Kuegel if (auto createOp = getOperand().getDefiningOp<CreateOp>())
243dee46d08SAdrian Kuegel return createOp.getOperand(0);
244dee46d08SAdrian Kuegel return {};
245dee46d08SAdrian Kuegel }
246480cd4cbSRiver Riddle
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)247a1e78615SLei Zhang void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
248a1e78615SLei Zhang MLIRContext *context) {
249a1e78615SLei Zhang results.add<FoldComponentNeg<ReOp, 0>>(context);
250a1e78615SLei Zhang }
251a1e78615SLei Zhang
252480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
253036a6996Slewuathe // AddOp
254036a6996Slewuathe //===----------------------------------------------------------------------===//
255036a6996Slewuathe
fold(FoldAdaptor adaptor)2567df76121SMarkus Böck OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
257036a6996Slewuathe // complex.add(complex.sub(a, b), b) -> a
258036a6996Slewuathe if (auto sub = getLhs().getDefiningOp<SubOp>())
259036a6996Slewuathe if (getRhs() == sub.getRhs())
260036a6996Slewuathe return sub.getLhs();
261036a6996Slewuathe
262036a6996Slewuathe // complex.add(b, complex.sub(a, b)) -> a
263036a6996Slewuathe if (auto sub = getRhs().getDefiningOp<SubOp>())
264036a6996Slewuathe if (getLhs() == sub.getRhs())
265036a6996Slewuathe return sub.getLhs();
266036a6996Slewuathe
267730cb822Slewuathe // complex.add(a, complex.constant<0.0, 0.0>) -> a
268730cb822Slewuathe if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
269730cb822Slewuathe auto arrayAttr = constantOp.getValue();
270c1fa60b4STres Popp if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
271c1fa60b4STres Popp llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
272730cb822Slewuathe return getLhs();
273730cb822Slewuathe }
274730cb822Slewuathe }
275730cb822Slewuathe
276036a6996Slewuathe return {};
277036a6996Slewuathe }
278036a6996Slewuathe
279036a6996Slewuathe //===----------------------------------------------------------------------===//
280ccf97505SKai Sasaki // SubOp
281ccf97505SKai Sasaki //===----------------------------------------------------------------------===//
282ccf97505SKai Sasaki
fold(FoldAdaptor adaptor)2837df76121SMarkus Böck OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
284ccf97505SKai Sasaki // complex.sub(complex.add(a, b), b) -> a
285ccf97505SKai Sasaki if (auto add = getLhs().getDefiningOp<AddOp>())
286ccf97505SKai Sasaki if (getRhs() == add.getRhs())
287ccf97505SKai Sasaki return add.getLhs();
288ccf97505SKai Sasaki
289c9741bafSKai Sasaki // complex.sub(a, complex.constant<0.0, 0.0>) -> a
290c9741bafSKai Sasaki if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
291c9741bafSKai Sasaki auto arrayAttr = constantOp.getValue();
292c1fa60b4STres Popp if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
293c1fa60b4STres Popp llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
294c9741bafSKai Sasaki return getLhs();
295c9741bafSKai Sasaki }
296c9741bafSKai Sasaki }
297c9741bafSKai Sasaki
298ccf97505SKai Sasaki return {};
299ccf97505SKai Sasaki }
300ccf97505SKai Sasaki
301ccf97505SKai Sasaki //===----------------------------------------------------------------------===//
30201807095Slewuathe // NegOp
30301807095Slewuathe //===----------------------------------------------------------------------===//
30401807095Slewuathe
fold(FoldAdaptor adaptor)3057df76121SMarkus Böck OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
30601807095Slewuathe // complex.neg(complex.neg(a)) -> a
30701807095Slewuathe if (auto negOp = getOperand().getDefiningOp<NegOp>())
30801807095Slewuathe return negOp.getOperand();
30901807095Slewuathe
31001807095Slewuathe return {};
31101807095Slewuathe }
31201807095Slewuathe
31301807095Slewuathe //===----------------------------------------------------------------------===//
3145148c685Slewuathe // LogOp
3155148c685Slewuathe //===----------------------------------------------------------------------===//
3165148c685Slewuathe
fold(FoldAdaptor adaptor)3177df76121SMarkus Böck OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
3185148c685Slewuathe // complex.log(complex.exp(a)) -> a
3195148c685Slewuathe if (auto expOp = getOperand().getDefiningOp<ExpOp>())
3205148c685Slewuathe return expOp.getOperand();
3215148c685Slewuathe
3225148c685Slewuathe return {};
3235148c685Slewuathe }
3245148c685Slewuathe
3255148c685Slewuathe //===----------------------------------------------------------------------===//
3265148c685Slewuathe // ExpOp
3275148c685Slewuathe //===----------------------------------------------------------------------===//
3285148c685Slewuathe
fold(FoldAdaptor adaptor)3297df76121SMarkus Böck OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
3305148c685Slewuathe // complex.exp(complex.log(a)) -> a
3315148c685Slewuathe if (auto logOp = getOperand().getDefiningOp<LogOp>())
3325148c685Slewuathe return logOp.getOperand();
3335148c685Slewuathe
3345148c685Slewuathe return {};
3355148c685Slewuathe }
3365148c685Slewuathe
3375148c685Slewuathe //===----------------------------------------------------------------------===//
338bcd538abSlewuathe // ConjOp
339bcd538abSlewuathe //===----------------------------------------------------------------------===//
340bcd538abSlewuathe
fold(FoldAdaptor adaptor)3417df76121SMarkus Böck OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
342bcd538abSlewuathe // complex.conj(complex.conj(a)) -> a
343bcd538abSlewuathe if (auto conjOp = getOperand().getDefiningOp<ConjOp>())
344bcd538abSlewuathe return conjOp.getOperand();
345bcd538abSlewuathe
346bcd538abSlewuathe return {};
347bcd538abSlewuathe }
348bcd538abSlewuathe
349bcd538abSlewuathe //===----------------------------------------------------------------------===//
3508d175b35SKai Sasaki // MulOp
3518d175b35SKai Sasaki //===----------------------------------------------------------------------===//
3528d175b35SKai Sasaki
fold(FoldAdaptor adaptor)3538d175b35SKai Sasaki OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
3548d175b35SKai Sasaki auto constant = getRhs().getDefiningOp<ConstantOp>();
3558d175b35SKai Sasaki if (!constant)
3568d175b35SKai Sasaki return {};
3578d175b35SKai Sasaki
3588d175b35SKai Sasaki ArrayAttr arrayAttr = constant.getValue();
3598d175b35SKai Sasaki APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
3608d175b35SKai Sasaki APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
3618d175b35SKai Sasaki
3628d175b35SKai Sasaki if (!imag.isZero())
3638d175b35SKai Sasaki return {};
3648d175b35SKai Sasaki
3658d175b35SKai Sasaki // complex.mul(a, complex.constant<1.0, 0.0>) -> a
3668d175b35SKai Sasaki if (real == APFloat(real.getSemantics(), 1))
3678d175b35SKai Sasaki return getLhs();
3688d175b35SKai Sasaki
3698d175b35SKai Sasaki return {};
3708d175b35SKai Sasaki }
3718d175b35SKai Sasaki
3728d175b35SKai Sasaki //===----------------------------------------------------------------------===//
37308a321e1SKai Sasaki // DivOp
37408a321e1SKai Sasaki //===----------------------------------------------------------------------===//
37508a321e1SKai Sasaki
fold(FoldAdaptor adaptor)37608a321e1SKai Sasaki OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
37708a321e1SKai Sasaki auto rhs = adaptor.getRhs();
37808a321e1SKai Sasaki if (!rhs)
37908a321e1SKai Sasaki return {};
38008a321e1SKai Sasaki
381*a5757c5bSChristian Sigg ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
38208a321e1SKai Sasaki if (!arrayAttr || arrayAttr.size() != 2)
38308a321e1SKai Sasaki return {};
38408a321e1SKai Sasaki
38508a321e1SKai Sasaki APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
38608a321e1SKai Sasaki APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
38708a321e1SKai Sasaki
38808a321e1SKai Sasaki if (!imag.isZero())
38908a321e1SKai Sasaki return {};
39008a321e1SKai Sasaki
39108a321e1SKai Sasaki // complex.div(a, complex.constant<1.0, 0.0>) -> a
39208a321e1SKai Sasaki if (real == APFloat(real.getSemantics(), 1))
39308a321e1SKai Sasaki return getLhs();
39408a321e1SKai Sasaki
39508a321e1SKai Sasaki return {};
39608a321e1SKai Sasaki }
39708a321e1SKai Sasaki
39808a321e1SKai Sasaki //===----------------------------------------------------------------------===//
399480cd4cbSRiver Riddle // TableGen'd op method definitions
400480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
401480cd4cbSRiver Riddle
402480cd4cbSRiver Riddle #define GET_OP_CLASSES
403480cd4cbSRiver Riddle #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
404