1 //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Complex/IR/Complex.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/PatternMatch.h"
15
16 using namespace mlir;
17 using namespace mlir::complex;
18
19 //===----------------------------------------------------------------------===//
20 // ConstantOp
21 //===----------------------------------------------------------------------===//
22
fold(FoldAdaptor adaptor)23 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
24 return getValue();
25 }
26
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)27 void ConstantOp::getAsmResultNames(
28 function_ref<void(Value, StringRef)> setNameFn) {
29 setNameFn(getResult(), "cst");
30 }
31
isBuildableWith(Attribute value,Type type)32 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
33 if (auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
34 auto complexTy = llvm::dyn_cast<ComplexType>(type);
35 if (!complexTy || arrAttr.size() != 2)
36 return false;
37 auto complexEltTy = complexTy.getElementType();
38 if (auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
39 auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
40 return im && fre.getType() == complexEltTy &&
41 im.getType() == complexEltTy;
42 }
43 if (auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
44 auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
45 return im && ire.getType() == complexEltTy &&
46 im.getType() == complexEltTy;
47 }
48 }
49 return false;
50 }
51
verify()52 LogicalResult ConstantOp::verify() {
53 ArrayAttr arrayAttr = getValue();
54 if (arrayAttr.size() != 2) {
55 return emitOpError(
56 "requires 'value' to be a complex constant, represented as array of "
57 "two values");
58 }
59
60 auto complexEltTy = getType().getElementType();
61 if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
62 !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
63 return emitOpError(
64 "requires attribute's elements to be float or integer attributes");
65 auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
66 auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
67 if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
68 return emitOpError()
69 << "requires attribute's element types (" << re.getType() << ", "
70 << im.getType()
71 << ") to match the element type of the op's return type ("
72 << complexEltTy << ")";
73 }
74 return success();
75 }
76
77 //===----------------------------------------------------------------------===//
78 // BitcastOp
79 //===----------------------------------------------------------------------===//
80
fold(FoldAdaptor bitcast)81 OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
82 if (getOperand().getType() == getType())
83 return getOperand();
84
85 return {};
86 }
87
verify()88 LogicalResult BitcastOp::verify() {
89 auto operandType = getOperand().getType();
90 auto resultType = getType();
91
92 // We allow this to be legal as it can be folded away.
93 if (operandType == resultType)
94 return success();
95
96 if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
97 return emitOpError("operand must be int/float/complex");
98 }
99
100 if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
101 return emitOpError("result must be int/float/complex");
102 }
103
104 if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
105 return emitOpError(
106 "requires that either input or output has a complex type");
107 }
108
109 if (isa<ComplexType>(resultType))
110 std::swap(operandType, resultType);
111
112 int32_t operandBitwidth = dyn_cast<ComplexType>(operandType)
113 .getElementType()
114 .getIntOrFloatBitWidth() *
115 2;
116 int32_t resultBitwidth = resultType.getIntOrFloatBitWidth();
117
118 if (operandBitwidth != resultBitwidth) {
119 return emitOpError("casting bitwidths do not match");
120 }
121
122 return success();
123 }
124
125 struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
126 using OpRewritePattern<BitcastOp>::OpRewritePattern;
127
matchAndRewriteMergeComplexBitcast128 LogicalResult matchAndRewrite(BitcastOp op,
129 PatternRewriter &rewriter) const override {
130 if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
131 if (isa<ComplexType>(op.getType()) ||
132 isa<ComplexType>(defining.getOperand().getType())) {
133 // complex.bitcast requires that input or output is complex.
134 rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
135 defining.getOperand());
136 } else {
137 rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
138 defining.getOperand());
139 }
140 return success();
141 }
142
143 if (auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
144 rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
145 defining.getOperand());
146 return success();
147 }
148
149 return failure();
150 }
151 };
152
153 struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
154 using OpRewritePattern<arith::BitcastOp>::OpRewritePattern;
155
matchAndRewriteMergeArithBitcast156 LogicalResult matchAndRewrite(arith::BitcastOp op,
157 PatternRewriter &rewriter) const override {
158 if (auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) {
159 rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(),
160 defining.getOperand());
161 return success();
162 }
163
164 return failure();
165 }
166 };
167
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)168 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
169 MLIRContext *context) {
170 results.add<MergeComplexBitcast, MergeArithBitcast>(context);
171 }
172
173 //===----------------------------------------------------------------------===//
174 // CreateOp
175 //===----------------------------------------------------------------------===//
176
fold(FoldAdaptor adaptor)177 OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
178 // Fold complex.create(complex.re(op), complex.im(op)).
179 if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
180 if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
181 if (reOp.getOperand() == imOp.getOperand()) {
182 return reOp.getOperand();
183 }
184 }
185 }
186 return {};
187 }
188
189 //===----------------------------------------------------------------------===//
190 // ImOp
191 //===----------------------------------------------------------------------===//
192
fold(FoldAdaptor adaptor)193 OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
194 ArrayAttr arrayAttr =
195 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
196 if (arrayAttr && arrayAttr.size() == 2)
197 return arrayAttr[1];
198 if (auto createOp = getOperand().getDefiningOp<CreateOp>())
199 return createOp.getOperand(1);
200 return {};
201 }
202
namespace {
/// Rewrites `OpKind(complex.neg(complex.create(re, im)))`, where `OpKind` is
/// complex.re or complex.im, into an `arith.negf` of operand `ComponentIndex`
/// of the complex.create (0 selects the real part, 1 the imaginary part).
template <typename OpKind, int ComponentIndex>
struct FoldComponentNeg final : OpRewritePattern<OpKind> {
  using OpRewritePattern<OpKind>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpKind op,
                                PatternRewriter &rewriter) const override {
    // Walk the chain upward: OpKind <- complex.neg <- complex.create.
    NegOp neg = op.getOperand().template getDefiningOp<NegOp>();
    if (!neg)
      return failure();

    CreateOp create = neg.getComplex().template getDefiningOp<CreateOp>();
    if (!create)
      return failure();

    // The replacement is arith.negf, so the element type is expected to be a
    // float type (enforced by the assert).
    Type elemTy = create.getType().getElementType();
    assert(isa<FloatType>(elemTy));

    Value component = create.getOperand(ComponentIndex);
    rewriter.replaceOpWithNewOp<arith::NegFOp>(op, elemTy, component);
    return success();
  }
};
} // namespace
227
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)228 void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
229 MLIRContext *context) {
230 results.add<FoldComponentNeg<ImOp, 1>>(context);
231 }
232
233 //===----------------------------------------------------------------------===//
234 // ReOp
235 //===----------------------------------------------------------------------===//
236
fold(FoldAdaptor adaptor)237 OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
238 ArrayAttr arrayAttr =
239 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
240 if (arrayAttr && arrayAttr.size() == 2)
241 return arrayAttr[0];
242 if (auto createOp = getOperand().getDefiningOp<CreateOp>())
243 return createOp.getOperand(0);
244 return {};
245 }
246
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)247 void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
248 MLIRContext *context) {
249 results.add<FoldComponentNeg<ReOp, 0>>(context);
250 }
251
252 //===----------------------------------------------------------------------===//
253 // AddOp
254 //===----------------------------------------------------------------------===//
255
fold(FoldAdaptor adaptor)256 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
257 // complex.add(complex.sub(a, b), b) -> a
258 if (auto sub = getLhs().getDefiningOp<SubOp>())
259 if (getRhs() == sub.getRhs())
260 return sub.getLhs();
261
262 // complex.add(b, complex.sub(a, b)) -> a
263 if (auto sub = getRhs().getDefiningOp<SubOp>())
264 if (getLhs() == sub.getRhs())
265 return sub.getLhs();
266
267 // complex.add(a, complex.constant<0.0, 0.0>) -> a
268 if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
269 auto arrayAttr = constantOp.getValue();
270 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
271 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
272 return getLhs();
273 }
274 }
275
276 return {};
277 }
278
279 //===----------------------------------------------------------------------===//
280 // SubOp
281 //===----------------------------------------------------------------------===//
282
fold(FoldAdaptor adaptor)283 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
284 // complex.sub(complex.add(a, b), b) -> a
285 if (auto add = getLhs().getDefiningOp<AddOp>())
286 if (getRhs() == add.getRhs())
287 return add.getLhs();
288
289 // complex.sub(a, complex.constant<0.0, 0.0>) -> a
290 if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
291 auto arrayAttr = constantOp.getValue();
292 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
293 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
294 return getLhs();
295 }
296 }
297
298 return {};
299 }
300
301 //===----------------------------------------------------------------------===//
302 // NegOp
303 //===----------------------------------------------------------------------===//
304
fold(FoldAdaptor adaptor)305 OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
306 // complex.neg(complex.neg(a)) -> a
307 if (auto negOp = getOperand().getDefiningOp<NegOp>())
308 return negOp.getOperand();
309
310 return {};
311 }
312
313 //===----------------------------------------------------------------------===//
314 // LogOp
315 //===----------------------------------------------------------------------===//
316
fold(FoldAdaptor adaptor)317 OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
318 // complex.log(complex.exp(a)) -> a
319 if (auto expOp = getOperand().getDefiningOp<ExpOp>())
320 return expOp.getOperand();
321
322 return {};
323 }
324
325 //===----------------------------------------------------------------------===//
326 // ExpOp
327 //===----------------------------------------------------------------------===//
328
fold(FoldAdaptor adaptor)329 OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
330 // complex.exp(complex.log(a)) -> a
331 if (auto logOp = getOperand().getDefiningOp<LogOp>())
332 return logOp.getOperand();
333
334 return {};
335 }
336
337 //===----------------------------------------------------------------------===//
338 // ConjOp
339 //===----------------------------------------------------------------------===//
340
fold(FoldAdaptor adaptor)341 OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
342 // complex.conj(complex.conj(a)) -> a
343 if (auto conjOp = getOperand().getDefiningOp<ConjOp>())
344 return conjOp.getOperand();
345
346 return {};
347 }
348
349 //===----------------------------------------------------------------------===//
350 // MulOp
351 //===----------------------------------------------------------------------===//
352
fold(FoldAdaptor adaptor)353 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
354 auto constant = getRhs().getDefiningOp<ConstantOp>();
355 if (!constant)
356 return {};
357
358 ArrayAttr arrayAttr = constant.getValue();
359 APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
360 APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
361
362 if (!imag.isZero())
363 return {};
364
365 // complex.mul(a, complex.constant<1.0, 0.0>) -> a
366 if (real == APFloat(real.getSemantics(), 1))
367 return getLhs();
368
369 return {};
370 }
371
372 //===----------------------------------------------------------------------===//
373 // DivOp
374 //===----------------------------------------------------------------------===//
375
fold(FoldAdaptor adaptor)376 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
377 auto rhs = adaptor.getRhs();
378 if (!rhs)
379 return {};
380
381 ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
382 if (!arrayAttr || arrayAttr.size() != 2)
383 return {};
384
385 APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
386 APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
387
388 if (!imag.isZero())
389 return {};
390
391 // complex.div(a, complex.constant<1.0, 0.0>) -> a
392 if (real == APFloat(real.getSemantics(), 1))
393 return getLhs();
394
395 return {};
396 }
397
398 //===----------------------------------------------------------------------===//
399 // TableGen'd op method definitions
400 //===----------------------------------------------------------------------===//
401
402 #define GET_OP_CLASSES
403 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
404