xref: /llvm-project/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp (revision a5757c5b65f1894de16f549212b1c37793312703)
1d0cb0d30SAlexander Belyaev //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2d0cb0d30SAlexander Belyaev //
3d0cb0d30SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d0cb0d30SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5d0cb0d30SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d0cb0d30SAlexander Belyaev //
7d0cb0d30SAlexander Belyaev //===----------------------------------------------------------------------===//
8d0cb0d30SAlexander Belyaev 
9a1e78615SLei Zhang #include "mlir/Dialect/Arith/IR/Arith.h"
10d0cb0d30SAlexander Belyaev #include "mlir/Dialect/Complex/IR/Complex.h"
11a28fe17dSAdrian Kuegel #include "mlir/IR/Builders.h"
12a1e78615SLei Zhang #include "mlir/IR/BuiltinTypes.h"
13036a6996Slewuathe #include "mlir/IR/Matchers.h"
14a1e78615SLei Zhang #include "mlir/IR/PatternMatch.h"
15d0cb0d30SAlexander Belyaev 
16d0cb0d30SAlexander Belyaev using namespace mlir;
17d0cb0d30SAlexander Belyaev using namespace mlir::complex;
18d0cb0d30SAlexander Belyaev 
19d0cb0d30SAlexander Belyaev //===----------------------------------------------------------------------===//
20480cd4cbSRiver Riddle // ConstantOp
21d0cb0d30SAlexander Belyaev //===----------------------------------------------------------------------===//
22d0cb0d30SAlexander Belyaev 
fold(FoldAdaptor adaptor)237df76121SMarkus Böck OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
24480cd4cbSRiver Riddle   return getValue();
25480cd4cbSRiver Riddle }
26480cd4cbSRiver Riddle 
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)27480cd4cbSRiver Riddle void ConstantOp::getAsmResultNames(
28480cd4cbSRiver Riddle     function_ref<void(Value, StringRef)> setNameFn) {
29480cd4cbSRiver Riddle   setNameFn(getResult(), "cst");
30480cd4cbSRiver Riddle }
31480cd4cbSRiver Riddle 
isBuildableWith(Attribute value,Type type)32480cd4cbSRiver Riddle bool ConstantOp::isBuildableWith(Attribute value, Type type) {
33c1fa60b4STres Popp   if (auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
34c1fa60b4STres Popp     auto complexTy = llvm::dyn_cast<ComplexType>(type);
35e1795322SJeff Niu     if (!complexTy || arrAttr.size() != 2)
36480cd4cbSRiver Riddle       return false;
37480cd4cbSRiver Riddle     auto complexEltTy = complexTy.getElementType();
38c1fa60b4STres Popp     if (auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
39c1fa60b4STres Popp       auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
40cc4fb583SXiang Li       return im && fre.getType() == complexEltTy &&
41e1795322SJeff Niu              im.getType() == complexEltTy;
42480cd4cbSRiver Riddle     }
43c1fa60b4STres Popp     if (auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
44c1fa60b4STres Popp       auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
45cc4fb583SXiang Li       return im && ire.getType() == complexEltTy &&
46cc4fb583SXiang Li              im.getType() == complexEltTy;
47cc4fb583SXiang Li     }
48cc4fb583SXiang Li   }
49480cd4cbSRiver Riddle   return false;
50480cd4cbSRiver Riddle }
51480cd4cbSRiver Riddle 
verify()521be88f5aSRiver Riddle LogicalResult ConstantOp::verify() {
531be88f5aSRiver Riddle   ArrayAttr arrayAttr = getValue();
54480cd4cbSRiver Riddle   if (arrayAttr.size() != 2) {
551be88f5aSRiver Riddle     return emitOpError(
56480cd4cbSRiver Riddle         "requires 'value' to be a complex constant, represented as array of "
57480cd4cbSRiver Riddle         "two values");
58480cd4cbSRiver Riddle   }
59480cd4cbSRiver Riddle 
601be88f5aSRiver Riddle   auto complexEltTy = getType().getElementType();
6116129937SMatthias Springer   if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
6216129937SMatthias Springer       !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
6316129937SMatthias Springer     return emitOpError(
6416129937SMatthias Springer         "requires attribute's elements to be float or integer attributes");
6516129937SMatthias Springer   auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
6616129937SMatthias Springer   auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
67e1795322SJeff Niu   if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
681be88f5aSRiver Riddle     return emitOpError()
69e1795322SJeff Niu            << "requires attribute's element types (" << re.getType() << ", "
70e1795322SJeff Niu            << im.getType()
71480cd4cbSRiver Riddle            << ") to match the element type of the op's return type ("
72480cd4cbSRiver Riddle            << complexEltTy << ")";
73480cd4cbSRiver Riddle   }
74480cd4cbSRiver Riddle   return success();
75480cd4cbSRiver Riddle }
76480cd4cbSRiver Riddle 
77480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
78a8df21f4SRob Suderman // BitcastOp
79a8df21f4SRob Suderman //===----------------------------------------------------------------------===//
80a8df21f4SRob Suderman 
fold(FoldAdaptor bitcast)81a8df21f4SRob Suderman OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
82a8df21f4SRob Suderman   if (getOperand().getType() == getType())
83a8df21f4SRob Suderman     return getOperand();
84a8df21f4SRob Suderman 
85a8df21f4SRob Suderman   return {};
86a8df21f4SRob Suderman }
87a8df21f4SRob Suderman 
verify()88a8df21f4SRob Suderman LogicalResult BitcastOp::verify() {
89a8df21f4SRob Suderman   auto operandType = getOperand().getType();
90a8df21f4SRob Suderman   auto resultType = getType();
91a8df21f4SRob Suderman 
92a8df21f4SRob Suderman   // We allow this to be legal as it can be folded away.
93a8df21f4SRob Suderman   if (operandType == resultType)
94a8df21f4SRob Suderman     return success();
95a8df21f4SRob Suderman 
96a8df21f4SRob Suderman   if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
97a8df21f4SRob Suderman     return emitOpError("operand must be int/float/complex");
98a8df21f4SRob Suderman   }
99a8df21f4SRob Suderman 
100a8df21f4SRob Suderman   if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
101a8df21f4SRob Suderman     return emitOpError("result must be int/float/complex");
102a8df21f4SRob Suderman   }
103a8df21f4SRob Suderman 
104a8df21f4SRob Suderman   if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
105192439dbSMatthias Springer     return emitOpError(
106192439dbSMatthias Springer         "requires that either input or output has a complex type");
107a8df21f4SRob Suderman   }
108a8df21f4SRob Suderman 
109a8df21f4SRob Suderman   if (isa<ComplexType>(resultType))
110a8df21f4SRob Suderman     std::swap(operandType, resultType);
111a8df21f4SRob Suderman 
112a8df21f4SRob Suderman   int32_t operandBitwidth = dyn_cast<ComplexType>(operandType)
113a8df21f4SRob Suderman                                 .getElementType()
114a8df21f4SRob Suderman                                 .getIntOrFloatBitWidth() *
115a8df21f4SRob Suderman                             2;
116a8df21f4SRob Suderman   int32_t resultBitwidth = resultType.getIntOrFloatBitWidth();
117a8df21f4SRob Suderman 
118a8df21f4SRob Suderman   if (operandBitwidth != resultBitwidth) {
119a8df21f4SRob Suderman     return emitOpError("casting bitwidths do not match");
120a8df21f4SRob Suderman   }
121a8df21f4SRob Suderman 
122a8df21f4SRob Suderman   return success();
123a8df21f4SRob Suderman }
124a8df21f4SRob Suderman 
125a8df21f4SRob Suderman struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
126a8df21f4SRob Suderman   using OpRewritePattern<BitcastOp>::OpRewritePattern;
127a8df21f4SRob Suderman 
matchAndRewriteMergeComplexBitcast128a8df21f4SRob Suderman   LogicalResult matchAndRewrite(BitcastOp op,
129a8df21f4SRob Suderman                                 PatternRewriter &rewriter) const override {
130a8df21f4SRob Suderman     if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
131192439dbSMatthias Springer       if (isa<ComplexType>(op.getType()) ||
132192439dbSMatthias Springer           isa<ComplexType>(defining.getOperand().getType())) {
133192439dbSMatthias Springer         // complex.bitcast requires that input or output is complex.
134a8df21f4SRob Suderman         rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
135a8df21f4SRob Suderman                                                defining.getOperand());
136192439dbSMatthias Springer       } else {
137192439dbSMatthias Springer         rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
138192439dbSMatthias Springer                                                       defining.getOperand());
139192439dbSMatthias Springer       }
140a8df21f4SRob Suderman       return success();
141a8df21f4SRob Suderman     }
142a8df21f4SRob Suderman 
143a8df21f4SRob Suderman     if (auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
144a8df21f4SRob Suderman       rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
145a8df21f4SRob Suderman                                              defining.getOperand());
146a8df21f4SRob Suderman       return success();
147a8df21f4SRob Suderman     }
148a8df21f4SRob Suderman 
149a8df21f4SRob Suderman     return failure();
150a8df21f4SRob Suderman   }
151a8df21f4SRob Suderman };
152a8df21f4SRob Suderman 
153a8df21f4SRob Suderman struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
154a8df21f4SRob Suderman   using OpRewritePattern<arith::BitcastOp>::OpRewritePattern;
155a8df21f4SRob Suderman 
matchAndRewriteMergeArithBitcast156a8df21f4SRob Suderman   LogicalResult matchAndRewrite(arith::BitcastOp op,
157a8df21f4SRob Suderman                                 PatternRewriter &rewriter) const override {
158a8df21f4SRob Suderman     if (auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) {
159a8df21f4SRob Suderman       rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(),
160a8df21f4SRob Suderman                                                       defining.getOperand());
161a8df21f4SRob Suderman       return success();
162a8df21f4SRob Suderman     }
163a8df21f4SRob Suderman 
164a8df21f4SRob Suderman     return failure();
165a8df21f4SRob Suderman   }
166a8df21f4SRob Suderman };
167a8df21f4SRob Suderman 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)168a8df21f4SRob Suderman void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
169a8df21f4SRob Suderman                                             MLIRContext *context) {
170192439dbSMatthias Springer   results.add<MergeComplexBitcast, MergeArithBitcast>(context);
171a8df21f4SRob Suderman }
172a8df21f4SRob Suderman 
173a8df21f4SRob Suderman //===----------------------------------------------------------------------===//
174480cd4cbSRiver Riddle // CreateOp
175480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
176fa765a09SAdrian Kuegel 
fold(FoldAdaptor adaptor)1777df76121SMarkus Böck OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
178dee46d08SAdrian Kuegel   // Fold complex.create(complex.re(op), complex.im(op)).
179dee46d08SAdrian Kuegel   if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
180dee46d08SAdrian Kuegel     if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
181dee46d08SAdrian Kuegel       if (reOp.getOperand() == imOp.getOperand()) {
182dee46d08SAdrian Kuegel         return reOp.getOperand();
183dee46d08SAdrian Kuegel       }
184dee46d08SAdrian Kuegel     }
185dee46d08SAdrian Kuegel   }
186fa765a09SAdrian Kuegel   return {};
187fa765a09SAdrian Kuegel }
188fa765a09SAdrian Kuegel 
189480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
190480cd4cbSRiver Riddle // ImOp
191480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
192480cd4cbSRiver Riddle 
fold(FoldAdaptor adaptor)1937df76121SMarkus Böck OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
19468f58812STres Popp   ArrayAttr arrayAttr =
19568f58812STres Popp       llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
196fa765a09SAdrian Kuegel   if (arrayAttr && arrayAttr.size() == 2)
197fa765a09SAdrian Kuegel     return arrayAttr[1];
198cb65419bSAdrian Kuegel   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
199b99f892bSAdrian Kuegel     return createOp.getOperand(1);
200fa765a09SAdrian Kuegel   return {};
201fa765a09SAdrian Kuegel }
202dee46d08SAdrian Kuegel 
203a1e78615SLei Zhang namespace {
204a1e78615SLei Zhang template <typename OpKind, int ComponentIndex>
205a1e78615SLei Zhang struct FoldComponentNeg final : OpRewritePattern<OpKind> {
206a1e78615SLei Zhang   using OpRewritePattern<OpKind>::OpRewritePattern;
207a1e78615SLei Zhang 
matchAndRewrite__anon416dd25c0111::FoldComponentNeg208a1e78615SLei Zhang   LogicalResult matchAndRewrite(OpKind op,
209a1e78615SLei Zhang                                 PatternRewriter &rewriter) const override {
210a1e78615SLei Zhang     auto negOp = op.getOperand().template getDefiningOp<NegOp>();
211a1e78615SLei Zhang     if (!negOp)
212a1e78615SLei Zhang       return failure();
213a1e78615SLei Zhang 
214a1e78615SLei Zhang     auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
215a1e78615SLei Zhang     if (!createOp)
216a1e78615SLei Zhang       return failure();
217a1e78615SLei Zhang 
218a1e78615SLei Zhang     Type elementType = createOp.getType().getElementType();
219a1e78615SLei Zhang     assert(isa<FloatType>(elementType));
220a1e78615SLei Zhang 
221a1e78615SLei Zhang     rewriter.replaceOpWithNewOp<arith::NegFOp>(
222a1e78615SLei Zhang         op, elementType, createOp.getOperand(ComponentIndex));
223a1e78615SLei Zhang     return success();
224a1e78615SLei Zhang   }
225a1e78615SLei Zhang };
226a1e78615SLei Zhang } // namespace
227a1e78615SLei Zhang 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)228a1e78615SLei Zhang void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
229a1e78615SLei Zhang                                        MLIRContext *context) {
230a1e78615SLei Zhang   results.add<FoldComponentNeg<ImOp, 1>>(context);
231a1e78615SLei Zhang }
232a1e78615SLei Zhang 
233480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
234480cd4cbSRiver Riddle // ReOp
235480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
236480cd4cbSRiver Riddle 
fold(FoldAdaptor adaptor)2377df76121SMarkus Böck OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
23868f58812STres Popp   ArrayAttr arrayAttr =
23968f58812STres Popp       llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
240dee46d08SAdrian Kuegel   if (arrayAttr && arrayAttr.size() == 2)
241dee46d08SAdrian Kuegel     return arrayAttr[0];
242dee46d08SAdrian Kuegel   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
243dee46d08SAdrian Kuegel     return createOp.getOperand(0);
244dee46d08SAdrian Kuegel   return {};
245dee46d08SAdrian Kuegel }
246480cd4cbSRiver Riddle 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)247a1e78615SLei Zhang void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
248a1e78615SLei Zhang                                        MLIRContext *context) {
249a1e78615SLei Zhang   results.add<FoldComponentNeg<ReOp, 0>>(context);
250a1e78615SLei Zhang }
251a1e78615SLei Zhang 
252480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
253036a6996Slewuathe // AddOp
254036a6996Slewuathe //===----------------------------------------------------------------------===//
255036a6996Slewuathe 
fold(FoldAdaptor adaptor)2567df76121SMarkus Böck OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
257036a6996Slewuathe   // complex.add(complex.sub(a, b), b) -> a
258036a6996Slewuathe   if (auto sub = getLhs().getDefiningOp<SubOp>())
259036a6996Slewuathe     if (getRhs() == sub.getRhs())
260036a6996Slewuathe       return sub.getLhs();
261036a6996Slewuathe 
262036a6996Slewuathe   // complex.add(b, complex.sub(a, b)) -> a
263036a6996Slewuathe   if (auto sub = getRhs().getDefiningOp<SubOp>())
264036a6996Slewuathe     if (getLhs() == sub.getRhs())
265036a6996Slewuathe       return sub.getLhs();
266036a6996Slewuathe 
267730cb822Slewuathe   // complex.add(a, complex.constant<0.0, 0.0>) -> a
268730cb822Slewuathe   if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
269730cb822Slewuathe     auto arrayAttr = constantOp.getValue();
270c1fa60b4STres Popp     if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
271c1fa60b4STres Popp         llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
272730cb822Slewuathe       return getLhs();
273730cb822Slewuathe     }
274730cb822Slewuathe   }
275730cb822Slewuathe 
276036a6996Slewuathe   return {};
277036a6996Slewuathe }
278036a6996Slewuathe 
279036a6996Slewuathe //===----------------------------------------------------------------------===//
280ccf97505SKai Sasaki // SubOp
281ccf97505SKai Sasaki //===----------------------------------------------------------------------===//
282ccf97505SKai Sasaki 
fold(FoldAdaptor adaptor)2837df76121SMarkus Böck OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
284ccf97505SKai Sasaki   // complex.sub(complex.add(a, b), b) -> a
285ccf97505SKai Sasaki   if (auto add = getLhs().getDefiningOp<AddOp>())
286ccf97505SKai Sasaki     if (getRhs() == add.getRhs())
287ccf97505SKai Sasaki       return add.getLhs();
288ccf97505SKai Sasaki 
289c9741bafSKai Sasaki   // complex.sub(a, complex.constant<0.0, 0.0>) -> a
290c9741bafSKai Sasaki   if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
291c9741bafSKai Sasaki     auto arrayAttr = constantOp.getValue();
292c1fa60b4STres Popp     if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
293c1fa60b4STres Popp         llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
294c9741bafSKai Sasaki       return getLhs();
295c9741bafSKai Sasaki     }
296c9741bafSKai Sasaki   }
297c9741bafSKai Sasaki 
298ccf97505SKai Sasaki   return {};
299ccf97505SKai Sasaki }
300ccf97505SKai Sasaki 
301ccf97505SKai Sasaki //===----------------------------------------------------------------------===//
30201807095Slewuathe // NegOp
30301807095Slewuathe //===----------------------------------------------------------------------===//
30401807095Slewuathe 
fold(FoldAdaptor adaptor)3057df76121SMarkus Böck OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
30601807095Slewuathe   // complex.neg(complex.neg(a)) -> a
30701807095Slewuathe   if (auto negOp = getOperand().getDefiningOp<NegOp>())
30801807095Slewuathe     return negOp.getOperand();
30901807095Slewuathe 
31001807095Slewuathe   return {};
31101807095Slewuathe }
31201807095Slewuathe 
31301807095Slewuathe //===----------------------------------------------------------------------===//
3145148c685Slewuathe // LogOp
3155148c685Slewuathe //===----------------------------------------------------------------------===//
3165148c685Slewuathe 
fold(FoldAdaptor adaptor)3177df76121SMarkus Böck OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
3185148c685Slewuathe   // complex.log(complex.exp(a)) -> a
3195148c685Slewuathe   if (auto expOp = getOperand().getDefiningOp<ExpOp>())
3205148c685Slewuathe     return expOp.getOperand();
3215148c685Slewuathe 
3225148c685Slewuathe   return {};
3235148c685Slewuathe }
3245148c685Slewuathe 
3255148c685Slewuathe //===----------------------------------------------------------------------===//
3265148c685Slewuathe // ExpOp
3275148c685Slewuathe //===----------------------------------------------------------------------===//
3285148c685Slewuathe 
fold(FoldAdaptor adaptor)3297df76121SMarkus Böck OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
3305148c685Slewuathe   // complex.exp(complex.log(a)) -> a
3315148c685Slewuathe   if (auto logOp = getOperand().getDefiningOp<LogOp>())
3325148c685Slewuathe     return logOp.getOperand();
3335148c685Slewuathe 
3345148c685Slewuathe   return {};
3355148c685Slewuathe }
3365148c685Slewuathe 
3375148c685Slewuathe //===----------------------------------------------------------------------===//
338bcd538abSlewuathe // ConjOp
339bcd538abSlewuathe //===----------------------------------------------------------------------===//
340bcd538abSlewuathe 
fold(FoldAdaptor adaptor)3417df76121SMarkus Böck OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
342bcd538abSlewuathe   // complex.conj(complex.conj(a)) -> a
343bcd538abSlewuathe   if (auto conjOp = getOperand().getDefiningOp<ConjOp>())
344bcd538abSlewuathe     return conjOp.getOperand();
345bcd538abSlewuathe 
346bcd538abSlewuathe   return {};
347bcd538abSlewuathe }
348bcd538abSlewuathe 
349bcd538abSlewuathe //===----------------------------------------------------------------------===//
3508d175b35SKai Sasaki // MulOp
3518d175b35SKai Sasaki //===----------------------------------------------------------------------===//
3528d175b35SKai Sasaki 
fold(FoldAdaptor adaptor)3538d175b35SKai Sasaki OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
3548d175b35SKai Sasaki   auto constant = getRhs().getDefiningOp<ConstantOp>();
3558d175b35SKai Sasaki   if (!constant)
3568d175b35SKai Sasaki     return {};
3578d175b35SKai Sasaki 
3588d175b35SKai Sasaki   ArrayAttr arrayAttr = constant.getValue();
3598d175b35SKai Sasaki   APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
3608d175b35SKai Sasaki   APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
3618d175b35SKai Sasaki 
3628d175b35SKai Sasaki   if (!imag.isZero())
3638d175b35SKai Sasaki     return {};
3648d175b35SKai Sasaki 
3658d175b35SKai Sasaki   // complex.mul(a, complex.constant<1.0, 0.0>) -> a
3668d175b35SKai Sasaki   if (real == APFloat(real.getSemantics(), 1))
3678d175b35SKai Sasaki     return getLhs();
3688d175b35SKai Sasaki 
3698d175b35SKai Sasaki   return {};
3708d175b35SKai Sasaki }
3718d175b35SKai Sasaki 
3728d175b35SKai Sasaki //===----------------------------------------------------------------------===//
37308a321e1SKai Sasaki // DivOp
37408a321e1SKai Sasaki //===----------------------------------------------------------------------===//
37508a321e1SKai Sasaki 
fold(FoldAdaptor adaptor)37608a321e1SKai Sasaki OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
37708a321e1SKai Sasaki   auto rhs = adaptor.getRhs();
37808a321e1SKai Sasaki   if (!rhs)
37908a321e1SKai Sasaki     return {};
38008a321e1SKai Sasaki 
381*a5757c5bSChristian Sigg   ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
38208a321e1SKai Sasaki   if (!arrayAttr || arrayAttr.size() != 2)
38308a321e1SKai Sasaki     return {};
38408a321e1SKai Sasaki 
38508a321e1SKai Sasaki   APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
38608a321e1SKai Sasaki   APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
38708a321e1SKai Sasaki 
38808a321e1SKai Sasaki   if (!imag.isZero())
38908a321e1SKai Sasaki     return {};
39008a321e1SKai Sasaki 
39108a321e1SKai Sasaki   // complex.div(a, complex.constant<1.0, 0.0>) -> a
39208a321e1SKai Sasaki   if (real == APFloat(real.getSemantics(), 1))
39308a321e1SKai Sasaki     return getLhs();
39408a321e1SKai Sasaki 
39508a321e1SKai Sasaki   return {};
39608a321e1SKai Sasaki }
39708a321e1SKai Sasaki 
39808a321e1SKai Sasaki //===----------------------------------------------------------------------===//
399480cd4cbSRiver Riddle // TableGen'd op method definitions
400480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
401480cd4cbSRiver Riddle 
402480cd4cbSRiver Riddle #define GET_OP_CLASSES
403480cd4cbSRiver Riddle #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
404