xref: /llvm-project/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp (revision a5757c5b65f1894de16f549212b1c37793312703)
//===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Complex/IR/Complex.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/PatternMatch.h"
15 
16 using namespace mlir;
17 using namespace mlir::complex;
18 
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
22 
fold(FoldAdaptor adaptor)23 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
24   return getValue();
25 }
26 
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)27 void ConstantOp::getAsmResultNames(
28     function_ref<void(Value, StringRef)> setNameFn) {
29   setNameFn(getResult(), "cst");
30 }
31 
isBuildableWith(Attribute value,Type type)32 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
33   if (auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
34     auto complexTy = llvm::dyn_cast<ComplexType>(type);
35     if (!complexTy || arrAttr.size() != 2)
36       return false;
37     auto complexEltTy = complexTy.getElementType();
38     if (auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
39       auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
40       return im && fre.getType() == complexEltTy &&
41              im.getType() == complexEltTy;
42     }
43     if (auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
44       auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
45       return im && ire.getType() == complexEltTy &&
46              im.getType() == complexEltTy;
47     }
48   }
49   return false;
50 }
51 
verify()52 LogicalResult ConstantOp::verify() {
53   ArrayAttr arrayAttr = getValue();
54   if (arrayAttr.size() != 2) {
55     return emitOpError(
56         "requires 'value' to be a complex constant, represented as array of "
57         "two values");
58   }
59 
60   auto complexEltTy = getType().getElementType();
61   if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
62       !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
63     return emitOpError(
64         "requires attribute's elements to be float or integer attributes");
65   auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
66   auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
67   if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
68     return emitOpError()
69            << "requires attribute's element types (" << re.getType() << ", "
70            << im.getType()
71            << ") to match the element type of the op's return type ("
72            << complexEltTy << ")";
73   }
74   return success();
75 }
76 
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//
80 
fold(FoldAdaptor bitcast)81 OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
82   if (getOperand().getType() == getType())
83     return getOperand();
84 
85   return {};
86 }
87 
verify()88 LogicalResult BitcastOp::verify() {
89   auto operandType = getOperand().getType();
90   auto resultType = getType();
91 
92   // We allow this to be legal as it can be folded away.
93   if (operandType == resultType)
94     return success();
95 
96   if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
97     return emitOpError("operand must be int/float/complex");
98   }
99 
100   if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
101     return emitOpError("result must be int/float/complex");
102   }
103 
104   if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
105     return emitOpError(
106         "requires that either input or output has a complex type");
107   }
108 
109   if (isa<ComplexType>(resultType))
110     std::swap(operandType, resultType);
111 
112   int32_t operandBitwidth = dyn_cast<ComplexType>(operandType)
113                                 .getElementType()
114                                 .getIntOrFloatBitWidth() *
115                             2;
116   int32_t resultBitwidth = resultType.getIntOrFloatBitWidth();
117 
118   if (operandBitwidth != resultBitwidth) {
119     return emitOpError("casting bitwidths do not match");
120   }
121 
122   return success();
123 }
124 
125 struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
126   using OpRewritePattern<BitcastOp>::OpRewritePattern;
127 
matchAndRewriteMergeComplexBitcast128   LogicalResult matchAndRewrite(BitcastOp op,
129                                 PatternRewriter &rewriter) const override {
130     if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
131       if (isa<ComplexType>(op.getType()) ||
132           isa<ComplexType>(defining.getOperand().getType())) {
133         // complex.bitcast requires that input or output is complex.
134         rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
135                                                defining.getOperand());
136       } else {
137         rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
138                                                       defining.getOperand());
139       }
140       return success();
141     }
142 
143     if (auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
144       rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
145                                              defining.getOperand());
146       return success();
147     }
148 
149     return failure();
150   }
151 };
152 
153 struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
154   using OpRewritePattern<arith::BitcastOp>::OpRewritePattern;
155 
matchAndRewriteMergeArithBitcast156   LogicalResult matchAndRewrite(arith::BitcastOp op,
157                                 PatternRewriter &rewriter) const override {
158     if (auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) {
159       rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(),
160                                                       defining.getOperand());
161       return success();
162     }
163 
164     return failure();
165   }
166 };
167 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)168 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
169                                             MLIRContext *context) {
170   results.add<MergeComplexBitcast, MergeArithBitcast>(context);
171 }
172 
//===----------------------------------------------------------------------===//
// CreateOp
//===----------------------------------------------------------------------===//
176 
fold(FoldAdaptor adaptor)177 OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
178   // Fold complex.create(complex.re(op), complex.im(op)).
179   if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
180     if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
181       if (reOp.getOperand() == imOp.getOperand()) {
182         return reOp.getOperand();
183       }
184     }
185   }
186   return {};
187 }
188 
//===----------------------------------------------------------------------===//
// ImOp
//===----------------------------------------------------------------------===//
192 
fold(FoldAdaptor adaptor)193 OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
194   ArrayAttr arrayAttr =
195       llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
196   if (arrayAttr && arrayAttr.size() == 2)
197     return arrayAttr[1];
198   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
199     return createOp.getOperand(1);
200   return {};
201 }
202 
203 namespace {
204 template <typename OpKind, int ComponentIndex>
205 struct FoldComponentNeg final : OpRewritePattern<OpKind> {
206   using OpRewritePattern<OpKind>::OpRewritePattern;
207 
matchAndRewrite__anon416dd25c0111::FoldComponentNeg208   LogicalResult matchAndRewrite(OpKind op,
209                                 PatternRewriter &rewriter) const override {
210     auto negOp = op.getOperand().template getDefiningOp<NegOp>();
211     if (!negOp)
212       return failure();
213 
214     auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
215     if (!createOp)
216       return failure();
217 
218     Type elementType = createOp.getType().getElementType();
219     assert(isa<FloatType>(elementType));
220 
221     rewriter.replaceOpWithNewOp<arith::NegFOp>(
222         op, elementType, createOp.getOperand(ComponentIndex));
223     return success();
224   }
225 };
226 } // namespace
227 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)228 void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
229                                        MLIRContext *context) {
230   results.add<FoldComponentNeg<ImOp, 1>>(context);
231 }
232 
//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//
236 
fold(FoldAdaptor adaptor)237 OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
238   ArrayAttr arrayAttr =
239       llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
240   if (arrayAttr && arrayAttr.size() == 2)
241     return arrayAttr[0];
242   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
243     return createOp.getOperand(0);
244   return {};
245 }
246 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)247 void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
248                                        MLIRContext *context) {
249   results.add<FoldComponentNeg<ReOp, 0>>(context);
250 }
251 
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
255 
fold(FoldAdaptor adaptor)256 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
257   // complex.add(complex.sub(a, b), b) -> a
258   if (auto sub = getLhs().getDefiningOp<SubOp>())
259     if (getRhs() == sub.getRhs())
260       return sub.getLhs();
261 
262   // complex.add(b, complex.sub(a, b)) -> a
263   if (auto sub = getRhs().getDefiningOp<SubOp>())
264     if (getLhs() == sub.getRhs())
265       return sub.getLhs();
266 
267   // complex.add(a, complex.constant<0.0, 0.0>) -> a
268   if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
269     auto arrayAttr = constantOp.getValue();
270     if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
271         llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
272       return getLhs();
273     }
274   }
275 
276   return {};
277 }
278 
//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
282 
fold(FoldAdaptor adaptor)283 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
284   // complex.sub(complex.add(a, b), b) -> a
285   if (auto add = getLhs().getDefiningOp<AddOp>())
286     if (getRhs() == add.getRhs())
287       return add.getLhs();
288 
289   // complex.sub(a, complex.constant<0.0, 0.0>) -> a
290   if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
291     auto arrayAttr = constantOp.getValue();
292     if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
293         llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
294       return getLhs();
295     }
296   }
297 
298   return {};
299 }
300 
//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//
304 
fold(FoldAdaptor adaptor)305 OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
306   // complex.neg(complex.neg(a)) -> a
307   if (auto negOp = getOperand().getDefiningOp<NegOp>())
308     return negOp.getOperand();
309 
310   return {};
311 }
312 
//===----------------------------------------------------------------------===//
// LogOp
//===----------------------------------------------------------------------===//
316 
fold(FoldAdaptor adaptor)317 OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
318   // complex.log(complex.exp(a)) -> a
319   if (auto expOp = getOperand().getDefiningOp<ExpOp>())
320     return expOp.getOperand();
321 
322   return {};
323 }
324 
//===----------------------------------------------------------------------===//
// ExpOp
//===----------------------------------------------------------------------===//
328 
fold(FoldAdaptor adaptor)329 OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
330   // complex.exp(complex.log(a)) -> a
331   if (auto logOp = getOperand().getDefiningOp<LogOp>())
332     return logOp.getOperand();
333 
334   return {};
335 }
336 
//===----------------------------------------------------------------------===//
// ConjOp
//===----------------------------------------------------------------------===//
340 
fold(FoldAdaptor adaptor)341 OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
342   // complex.conj(complex.conj(a)) -> a
343   if (auto conjOp = getOperand().getDefiningOp<ConjOp>())
344     return conjOp.getOperand();
345 
346   return {};
347 }
348 
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//
352 
fold(FoldAdaptor adaptor)353 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
354   auto constant = getRhs().getDefiningOp<ConstantOp>();
355   if (!constant)
356     return {};
357 
358   ArrayAttr arrayAttr = constant.getValue();
359   APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
360   APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
361 
362   if (!imag.isZero())
363     return {};
364 
365   // complex.mul(a, complex.constant<1.0, 0.0>) -> a
366   if (real == APFloat(real.getSemantics(), 1))
367     return getLhs();
368 
369   return {};
370 }
371 
//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//
375 
fold(FoldAdaptor adaptor)376 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
377   auto rhs = adaptor.getRhs();
378   if (!rhs)
379     return {};
380 
381   ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
382   if (!arrayAttr || arrayAttr.size() != 2)
383     return {};
384 
385   APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
386   APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
387 
388   if (!imag.isZero())
389     return {};
390 
391   // complex.div(a, complex.constant<1.0, 0.0>) -> a
392   if (real == APFloat(real.getSemantics(), 1))
393     return getLhs();
394 
395   return {};
396 }
397 
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
401 
402 #define GET_OP_CLASSES
403 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
404