//===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; using namespace mlir::complex; //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } void ConstantOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), "cst"); } bool ConstantOp::isBuildableWith(Attribute value, Type type) { if (auto arrAttr = llvm::dyn_cast(value)) { auto complexTy = llvm::dyn_cast(type); if (!complexTy || arrAttr.size() != 2) return false; auto complexEltTy = complexTy.getElementType(); if (auto fre = llvm::dyn_cast(arrAttr[0])) { auto im = llvm::dyn_cast(arrAttr[1]); return im && fre.getType() == complexEltTy && im.getType() == complexEltTy; } if (auto ire = llvm::dyn_cast(arrAttr[0])) { auto im = llvm::dyn_cast(arrAttr[1]); return im && ire.getType() == complexEltTy && im.getType() == complexEltTy; } } return false; } LogicalResult ConstantOp::verify() { ArrayAttr arrayAttr = getValue(); if (arrayAttr.size() != 2) { return emitOpError( "requires 'value' to be a complex constant, represented as array of " "two values"); } auto complexEltTy = getType().getElementType(); if (!isa(arrayAttr[0]) || !isa(arrayAttr[1])) return emitOpError( "requires attribute's elements to be float or integer attributes"); auto re = llvm::dyn_cast(arrayAttr[0]); auto im = llvm::dyn_cast(arrayAttr[1]); if (complexEltTy != re.getType() || complexEltTy != im.getType()) { return emitOpError() << "requires attribute's element types (" << re.getType() << ", " << im.getType() << ") to match the element type of the op's return type (" << complexEltTy << ")"; } return success(); } //===----------------------------------------------------------------------===// // BitcastOp //===----------------------------------------------------------------------===// OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) { if (getOperand().getType() == getType()) return getOperand(); return {}; } LogicalResult BitcastOp::verify() { auto operandType = getOperand().getType(); auto resultType = getType(); // We allow this to be legal as it can be folded away. if (operandType == resultType) return success(); if (!operandType.isIntOrFloat() && !isa(operandType)) { return emitOpError("operand must be int/float/complex"); } if (!resultType.isIntOrFloat() && !isa(resultType)) { return emitOpError("result must be int/float/complex"); } if (isa(operandType) == isa(resultType)) { return emitOpError( "requires that either input or output has a complex type"); } if (isa(resultType)) std::swap(operandType, resultType); int32_t operandBitwidth = dyn_cast(operandType) .getElementType() .getIntOrFloatBitWidth() * 2; int32_t resultBitwidth = resultType.getIntOrFloatBitWidth(); if (operandBitwidth != resultBitwidth) { return emitOpError("casting bitwidths do not match"); } return success(); } struct MergeComplexBitcast final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(BitcastOp op, PatternRewriter &rewriter) const override { if (auto defining = op.getOperand().getDefiningOp()) { if (isa(op.getType()) || isa(defining.getOperand().getType())) { // complex.bitcast requires that input or output is complex. rewriter.replaceOpWithNewOp(op, op.getType(), defining.getOperand()); } else { rewriter.replaceOpWithNewOp(op, op.getType(), defining.getOperand()); } return success(); } if (auto defining = op.getOperand().getDefiningOp()) { rewriter.replaceOpWithNewOp(op, op.getType(), defining.getOperand()); return success(); } return failure(); } }; struct MergeArithBitcast final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(arith::BitcastOp op, PatternRewriter &rewriter) const override { if (auto defining = op.getOperand().getDefiningOp()) { rewriter.replaceOpWithNewOp(op, op.getType(), defining.getOperand()); return success(); } return failure(); } }; void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); } //===----------------------------------------------------------------------===// // CreateOp //===----------------------------------------------------------------------===// OpFoldResult CreateOp::fold(FoldAdaptor adaptor) { // Fold complex.create(complex.re(op), complex.im(op)). if (auto reOp = getOperand(0).getDefiningOp()) { if (auto imOp = getOperand(1).getDefiningOp()) { if (reOp.getOperand() == imOp.getOperand()) { return reOp.getOperand(); } } } return {}; } //===----------------------------------------------------------------------===// // ImOp //===----------------------------------------------------------------------===// OpFoldResult ImOp::fold(FoldAdaptor adaptor) { ArrayAttr arrayAttr = llvm::dyn_cast_if_present(adaptor.getComplex()); if (arrayAttr && arrayAttr.size() == 2) return arrayAttr[1]; if (auto createOp = getOperand().getDefiningOp()) return createOp.getOperand(1); return {}; } namespace { template struct FoldComponentNeg final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpKind op, PatternRewriter &rewriter) const override { auto negOp = op.getOperand().template getDefiningOp(); if (!negOp) return failure(); auto createOp = negOp.getComplex().template getDefiningOp(); if (!createOp) return failure(); Type elementType = createOp.getType().getElementType(); assert(isa(elementType)); rewriter.replaceOpWithNewOp( op, elementType, createOp.getOperand(ComponentIndex)); return success(); } }; } // namespace void ImOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add>(context); } //===----------------------------------------------------------------------===// // ReOp //===----------------------------------------------------------------------===// OpFoldResult ReOp::fold(FoldAdaptor adaptor) { ArrayAttr arrayAttr = llvm::dyn_cast_if_present(adaptor.getComplex()); if (arrayAttr && arrayAttr.size() == 2) return arrayAttr[0]; if (auto createOp = getOperand().getDefiningOp()) return createOp.getOperand(0); return {}; } void ReOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add>(context); } //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// OpFoldResult AddOp::fold(FoldAdaptor adaptor) { // complex.add(complex.sub(a, b), b) -> a if (auto sub = getLhs().getDefiningOp()) if (getRhs() == sub.getRhs()) return sub.getLhs(); // complex.add(b, complex.sub(a, b)) -> a if (auto sub = getRhs().getDefiningOp()) if (getLhs() == sub.getRhs()) return sub.getLhs(); // complex.add(a, complex.constant<0.0, 0.0>) -> a if (auto constantOp = getRhs().getDefiningOp()) { auto arrayAttr = constantOp.getValue(); if (llvm::cast(arrayAttr[0]).getValue().isZero() && llvm::cast(arrayAttr[1]).getValue().isZero()) { return getLhs(); } } return {}; } //===----------------------------------------------------------------------===// // SubOp //===----------------------------------------------------------------------===// OpFoldResult SubOp::fold(FoldAdaptor adaptor) { // complex.sub(complex.add(a, b), b) -> a if (auto add = getLhs().getDefiningOp()) if (getRhs() == add.getRhs()) return add.getLhs(); // complex.sub(a, complex.constant<0.0, 0.0>) -> a if (auto constantOp = getRhs().getDefiningOp()) { auto arrayAttr = constantOp.getValue(); if (llvm::cast(arrayAttr[0]).getValue().isZero() && llvm::cast(arrayAttr[1]).getValue().isZero()) { return getLhs(); } } return {}; } //===----------------------------------------------------------------------===// // NegOp //===----------------------------------------------------------------------===// OpFoldResult NegOp::fold(FoldAdaptor adaptor) { // complex.neg(complex.neg(a)) -> a if (auto negOp = getOperand().getDefiningOp()) return negOp.getOperand(); return {}; } //===----------------------------------------------------------------------===// // LogOp //===----------------------------------------------------------------------===// OpFoldResult LogOp::fold(FoldAdaptor adaptor) { // complex.log(complex.exp(a)) -> a if (auto expOp = getOperand().getDefiningOp()) return expOp.getOperand(); return {}; } //===----------------------------------------------------------------------===// // ExpOp //===----------------------------------------------------------------------===// OpFoldResult ExpOp::fold(FoldAdaptor adaptor) { // complex.exp(complex.log(a)) -> a if (auto logOp = getOperand().getDefiningOp()) return logOp.getOperand(); return {}; } //===----------------------------------------------------------------------===// // ConjOp //===----------------------------------------------------------------------===// OpFoldResult ConjOp::fold(FoldAdaptor adaptor) { // complex.conj(complex.conj(a)) -> a if (auto conjOp = getOperand().getDefiningOp()) return conjOp.getOperand(); return {}; } //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// OpFoldResult MulOp::fold(FoldAdaptor adaptor) { auto constant = getRhs().getDefiningOp(); if (!constant) return {}; ArrayAttr arrayAttr = constant.getValue(); APFloat real = cast(arrayAttr[0]).getValue(); APFloat imag = cast(arrayAttr[1]).getValue(); if (!imag.isZero()) return {}; // complex.mul(a, complex.constant<1.0, 0.0>) -> a if (real == APFloat(real.getSemantics(), 1)) return getLhs(); return {}; } //===----------------------------------------------------------------------===// // DivOp //===----------------------------------------------------------------------===// OpFoldResult DivOp::fold(FoldAdaptor adaptor) { auto rhs = adaptor.getRhs(); if (!rhs) return {}; ArrayAttr arrayAttr = dyn_cast(rhs); if (!arrayAttr || arrayAttr.size() != 2) return {}; APFloat real = cast(arrayAttr[0]).getValue(); APFloat imag = cast(arrayAttr[1]).getValue(); if (!imag.isZero()) return {}; // complex.div(a, complex.constant<1.0, 0.0>) -> a if (real == APFloat(real.getSemantics(), 1)) return getLhs(); return {}; } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"