1 //===- StdExpandDivs.cpp - Code to prepare Std for lowering Divs to LLVM -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file Std transformations to expand Divs operation to help for the 10 // lowering to LLVM. Currently implemented transformations are Ceil and Floor 11 // for Signed Integers. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/MemRef/Transforms/Passes.h" 16 17 #include "mlir/Dialect/Arith/IR/Arith.h" 18 #include "mlir/Dialect/Arith/Transforms/Passes.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/MemRef/Transforms/Transforms.h" 21 #include "mlir/IR/TypeUtilities.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 #include "llvm/ADT/STLExtras.h" 24 25 namespace mlir { 26 namespace memref { 27 #define GEN_PASS_DEF_EXPANDOPS 28 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" 29 } // namespace memref 30 } // namespace mlir 31 32 using namespace mlir; 33 34 namespace { 35 36 /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with 37 /// AtomicRMWOpLowering pattern, such as minimum and maximum operations for 38 /// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded 39 /// code. 40 /// 41 /// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32 42 /// 43 /// will be lowered to 44 /// 45 /// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> { 46 /// ^bb0(%current: f32): 47 /// %1 = arith.maximumf %current, %fval : f32 48 /// memref.atomic_yield %1 : f32 49 /// } 50 struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> { 51 public: 52 using OpRewritePattern::OpRewritePattern; 53 54 LogicalResult matchAndRewrite(memref::AtomicRMWOp op, 55 PatternRewriter &rewriter) const final { 56 auto loc = op.getLoc(); 57 auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>( 58 loc, op.getMemref(), op.getIndices()); 59 OpBuilder bodyBuilder = 60 OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener()); 61 62 Value lhs = genericOp.getCurrentValue(); 63 Value rhs = op.getValue(); 64 65 Value arithOp = 66 mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs); 67 bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp); 68 69 rewriter.replaceOp(op, genericOp.getResult()); 70 return success(); 71 } 72 }; 73 74 /// Converts `memref.reshape` that has a target shape of a statically-known 75 /// size to `memref.reinterpret_cast`. 76 struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> { 77 public: 78 using OpRewritePattern::OpRewritePattern; 79 80 LogicalResult matchAndRewrite(memref::ReshapeOp op, 81 PatternRewriter &rewriter) const final { 82 auto shapeType = cast<MemRefType>(op.getShape().getType()); 83 if (!shapeType.hasStaticShape()) 84 return failure(); 85 86 int64_t rank = cast<MemRefType>(shapeType).getDimSize(0); 87 SmallVector<OpFoldResult, 4> sizes, strides; 88 sizes.resize(rank); 89 strides.resize(rank); 90 91 Location loc = op.getLoc(); 92 Value stride = nullptr; 93 int64_t staticStride = 1; 94 for (int i = rank - 1; i >= 0; --i) { 95 Value size; 96 // Load dynamic sizes from the shape input, use constants for static dims. 97 if (op.getType().isDynamicDim(i)) { 98 Value index = rewriter.create<arith::ConstantIndexOp>(loc, i); 99 size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index); 100 if (!isa<IndexType>(size.getType())) 101 size = rewriter.create<arith::IndexCastOp>( 102 loc, rewriter.getIndexType(), size); 103 sizes[i] = size; 104 } else { 105 auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i)); 106 size = rewriter.create<arith::ConstantOp>(loc, sizeAttr); 107 sizes[i] = sizeAttr; 108 } 109 if (stride) 110 strides[i] = stride; 111 else 112 strides[i] = rewriter.getIndexAttr(staticStride); 113 114 if (i > 0) { 115 if (stride) { 116 stride = rewriter.create<arith::MulIOp>(loc, stride, size); 117 } else if (op.getType().isDynamicDim(i)) { 118 stride = rewriter.create<arith::MulIOp>( 119 loc, rewriter.create<arith::ConstantIndexOp>(loc, staticStride), 120 size); 121 } else { 122 staticStride *= op.getType().getDimSize(i); 123 } 124 } 125 } 126 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>( 127 op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0), 128 sizes, strides); 129 return success(); 130 } 131 }; 132 133 struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> { 134 void runOnOperation() override { 135 MLIRContext &ctx = getContext(); 136 137 RewritePatternSet patterns(&ctx); 138 memref::populateExpandOpsPatterns(patterns); 139 ConversionTarget target(ctx); 140 141 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>(); 142 target.addDynamicallyLegalOp<memref::AtomicRMWOp>( 143 [](memref::AtomicRMWOp op) { 144 constexpr std::array shouldBeExpandedKinds = { 145 arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf, 146 arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf}; 147 return !llvm::is_contained(shouldBeExpandedKinds, op.getKind()); 148 }); 149 target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) { 150 return !cast<MemRefType>(op.getShape().getType()).hasStaticShape(); 151 }); 152 if (failed(applyPartialConversion(getOperation(), target, 153 std::move(patterns)))) 154 signalPassFailure(); 155 } 156 }; 157 158 } // namespace 159 160 void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) { 161 patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>( 162 patterns.getContext()); 163 } 164 165 std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() { 166 return std::make_unique<ExpandOpsPass>(); 167 } 168