1 //===-- PreCGRewrite.cpp --------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CGOps.h" 14 #include "PassDetail.h" 15 #include "flang/Optimizer/CodeGen/CodeGen.h" 16 #include "flang/Optimizer/Dialect/FIRDialect.h" 17 #include "flang/Optimizer/Dialect/FIROps.h" 18 #include "flang/Optimizer/Dialect/FIRType.h" 19 #include "flang/Optimizer/Support/FIRContext.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/STLExtras.h" 22 #include "llvm/Support/Debug.h" 23 24 //===----------------------------------------------------------------------===// 25 // Codegen rewrite: rewriting of subgraphs of ops 26 //===----------------------------------------------------------------------===// 27 28 #define DEBUG_TYPE "flang-codegen-rewrite" 29 30 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec, 31 fir::ShapeOp shape) { 32 vec.append(shape.getExtents().begin(), shape.getExtents().end()); 33 } 34 35 // Operands of fir.shape_shift split into two vectors. 36 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec, 37 llvm::SmallVectorImpl<mlir::Value> &shiftVec, 38 fir::ShapeShiftOp shift) { 39 for (auto i = shift.getPairs().begin(), endIter = shift.getPairs().end(); 40 i != endIter;) { 41 shiftVec.push_back(*i++); 42 shapeVec.push_back(*i++); 43 } 44 } 45 46 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec, 47 fir::ShiftOp shift) { 48 vec.append(shift.getOrigins().begin(), shift.getOrigins().end()); 49 } 50 51 namespace { 52 53 /// Convert fir.embox to the extended form where necessary. 54 /// 55 /// The embox operation can take arguments that specify multidimensional array 56 /// properties at runtime. These properties may be shared between distinct 57 /// objects that have the same properties. Before we lower these small DAGs to 58 /// LLVM-IR, we gather all the information into a single extended operation. For 59 /// example, 60 /// ``` 61 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1> 62 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1> 63 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, 64 /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>> 65 /// ``` 66 /// can be rewritten as 67 /// ``` 68 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : 69 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> 70 /// !fir.box<!fir.array<?xi32>> 71 /// ``` 72 class EmboxConversion : public mlir::OpRewritePattern<fir::EmboxOp> { 73 public: 74 using OpRewritePattern::OpRewritePattern; 75 76 mlir::LogicalResult 77 matchAndRewrite(fir::EmboxOp embox, 78 mlir::PatternRewriter &rewriter) const override { 79 // If the embox does not include a shape, then do not convert it 80 if (auto shapeVal = embox.getShape()) 81 return rewriteDynamicShape(embox, rewriter, shapeVal); 82 if (auto boxTy = embox.getType().dyn_cast<fir::BoxType>()) 83 if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>()) 84 if (!seqTy.hasDynamicExtents()) 85 return rewriteStaticShape(embox, rewriter, seqTy); 86 return mlir::failure(); 87 } 88 89 mlir::LogicalResult rewriteStaticShape(fir::EmboxOp embox, 90 mlir::PatternRewriter &rewriter, 91 fir::SequenceType seqTy) const { 92 auto loc = embox.getLoc(); 93 llvm::SmallVector<mlir::Value> shapeOpers; 94 auto idxTy = rewriter.getIndexType(); 95 for (auto ext : seqTy.getShape()) { 96 auto iAttr = rewriter.getIndexAttr(ext); 97 auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr); 98 shapeOpers.push_back(extVal); 99 } 100 auto xbox = rewriter.create<fir::cg::XEmboxOp>( 101 loc, embox.getType(), embox.getMemref(), shapeOpers, llvm::None, 102 llvm::None, llvm::None, llvm::None, embox.getTypeparams()); 103 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 104 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 105 return mlir::success(); 106 } 107 108 mlir::LogicalResult rewriteDynamicShape(fir::EmboxOp embox, 109 mlir::PatternRewriter &rewriter, 110 mlir::Value shapeVal) const { 111 auto loc = embox.getLoc(); 112 llvm::SmallVector<mlir::Value> shapeOpers; 113 llvm::SmallVector<mlir::Value> shiftOpers; 114 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) { 115 populateShape(shapeOpers, shapeOp); 116 } else { 117 auto shiftOp = 118 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()); 119 assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift"); 120 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 121 } 122 llvm::SmallVector<mlir::Value> sliceOpers; 123 llvm::SmallVector<mlir::Value> subcompOpers; 124 llvm::SmallVector<mlir::Value> substrOpers; 125 if (auto s = embox.getSlice()) 126 if (auto sliceOp = 127 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) { 128 sliceOpers.assign(sliceOp.getTriples().begin(), 129 sliceOp.getTriples().end()); 130 subcompOpers.assign(sliceOp.getFields().begin(), 131 sliceOp.getFields().end()); 132 substrOpers.assign(sliceOp.getSubstr().begin(), 133 sliceOp.getSubstr().end()); 134 } 135 auto xbox = rewriter.create<fir::cg::XEmboxOp>( 136 loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers, 137 sliceOpers, subcompOpers, substrOpers, embox.getTypeparams()); 138 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 139 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 140 return mlir::success(); 141 } 142 }; 143 144 /// Convert fir.rebox to the extended form where necessary. 145 /// 146 /// For example, 147 /// ``` 148 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> 149 /// !fir.box<!fir.array<?xi32>> 150 /// ``` 151 /// converted to 152 /// ``` 153 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>, 154 /// index, index) -> !fir.box<!fir.array<?xi32>> 155 /// ``` 156 class ReboxConversion : public mlir::OpRewritePattern<fir::ReboxOp> { 157 public: 158 using OpRewritePattern::OpRewritePattern; 159 160 mlir::LogicalResult 161 matchAndRewrite(fir::ReboxOp rebox, 162 mlir::PatternRewriter &rewriter) const override { 163 auto loc = rebox.getLoc(); 164 llvm::SmallVector<mlir::Value> shapeOpers; 165 llvm::SmallVector<mlir::Value> shiftOpers; 166 if (auto shapeVal = rebox.getShape()) { 167 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) 168 populateShape(shapeOpers, shapeOp); 169 else if (auto shiftOp = 170 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp())) 171 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 172 else if (auto shiftOp = 173 mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp())) 174 populateShift(shiftOpers, shiftOp); 175 else 176 return mlir::failure(); 177 } 178 llvm::SmallVector<mlir::Value> sliceOpers; 179 llvm::SmallVector<mlir::Value> subcompOpers; 180 llvm::SmallVector<mlir::Value> substrOpers; 181 if (auto s = rebox.getSlice()) 182 if (auto sliceOp = 183 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) { 184 sliceOpers.append(sliceOp.getTriples().begin(), 185 sliceOp.getTriples().end()); 186 subcompOpers.append(sliceOp.getFields().begin(), 187 sliceOp.getFields().end()); 188 substrOpers.append(sliceOp.getSubstr().begin(), 189 sliceOp.getSubstr().end()); 190 } 191 192 auto xRebox = rewriter.create<fir::cg::XReboxOp>( 193 loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers, 194 sliceOpers, subcompOpers, substrOpers); 195 LLVM_DEBUG(llvm::dbgs() 196 << "rewriting " << rebox << " to " << xRebox << '\n'); 197 rewriter.replaceOp(rebox, xRebox.getOperation()->getResults()); 198 return mlir::success(); 199 } 200 }; 201 202 /// Convert all fir.array_coor to the extended form. 203 /// 204 /// For example, 205 /// ``` 206 /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>, 207 /// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32> 208 /// ``` 209 /// converted to 210 /// ``` 211 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : 212 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) -> 213 /// !fir.ref<i32> 214 /// ``` 215 class ArrayCoorConversion : public mlir::OpRewritePattern<fir::ArrayCoorOp> { 216 public: 217 using OpRewritePattern::OpRewritePattern; 218 219 mlir::LogicalResult 220 matchAndRewrite(fir::ArrayCoorOp arrCoor, 221 mlir::PatternRewriter &rewriter) const override { 222 auto loc = arrCoor.getLoc(); 223 llvm::SmallVector<mlir::Value> shapeOpers; 224 llvm::SmallVector<mlir::Value> shiftOpers; 225 if (auto shapeVal = arrCoor.getShape()) { 226 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) 227 populateShape(shapeOpers, shapeOp); 228 else if (auto shiftOp = 229 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp())) 230 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 231 else if (auto shiftOp = 232 mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp())) 233 populateShift(shiftOpers, shiftOp); 234 else 235 return mlir::failure(); 236 } 237 llvm::SmallVector<mlir::Value> sliceOpers; 238 llvm::SmallVector<mlir::Value> subcompOpers; 239 if (auto s = arrCoor.getSlice()) 240 if (auto sliceOp = 241 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) { 242 sliceOpers.append(sliceOp.getTriples().begin(), 243 sliceOp.getTriples().end()); 244 subcompOpers.append(sliceOp.getFields().begin(), 245 sliceOp.getFields().end()); 246 assert(sliceOp.getSubstr().empty() && 247 "Don't allow substring operations on array_coor. This " 248 "restriction may be lifted in the future."); 249 } 250 auto xArrCoor = rewriter.create<fir::cg::XArrayCoorOp>( 251 loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers, 252 sliceOpers, subcompOpers, arrCoor.getIndices(), 253 arrCoor.getTypeparams()); 254 LLVM_DEBUG(llvm::dbgs() 255 << "rewriting " << arrCoor << " to " << xArrCoor << '\n'); 256 rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults()); 257 return mlir::success(); 258 } 259 }; 260 261 class CodeGenRewrite : public fir::CodeGenRewriteBase<CodeGenRewrite> { 262 public: 263 void runOn(mlir::Operation *op, mlir::Region ®ion) { 264 auto &context = getContext(); 265 mlir::OpBuilder rewriter(&context); 266 mlir::ConversionTarget target(context); 267 target.addLegalDialect<mlir::arith::ArithmeticDialect, fir::FIROpsDialect, 268 fir::FIRCodeGenDialect, mlir::func::FuncDialect>(); 269 target.addIllegalOp<fir::ArrayCoorOp>(); 270 target.addIllegalOp<fir::ReboxOp>(); 271 target.addDynamicallyLegalOp<fir::EmboxOp>([](fir::EmboxOp embox) { 272 return !(embox.getShape() || embox.getType() 273 .cast<fir::BoxType>() 274 .getEleTy() 275 .isa<fir::SequenceType>()); 276 }); 277 mlir::RewritePatternSet patterns(&context); 278 patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>( 279 &context); 280 if (mlir::failed( 281 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 282 mlir::emitError(mlir::UnknownLoc::get(&context), 283 "error in running the pre-codegen conversions"); 284 signalPassFailure(); 285 } 286 // Erase any residual. 287 simplifyRegion(region); 288 } 289 290 void runOnOperation() override final { 291 // Call runOn on all top level regions that may contain emboxOp/arrayCoorOp. 292 auto mod = getOperation(); 293 for (auto func : mod.getOps<mlir::func::FuncOp>()) 294 runOn(func, func.getBody()); 295 for (auto global : mod.getOps<fir::GlobalOp>()) 296 runOn(global, global.getRegion()); 297 } 298 299 // Clean up the region. 300 void simplifyRegion(mlir::Region ®ion) { 301 for (auto &block : region.getBlocks()) 302 for (auto &op : block.getOperations()) { 303 for (auto ® : op.getRegions()) 304 simplifyRegion(reg); 305 maybeEraseOp(&op); 306 } 307 doDCE(); 308 } 309 310 /// Run a simple DCE cleanup to remove any dead code after the rewrites. 311 void doDCE() { 312 std::vector<mlir::Operation *> workList; 313 workList.swap(opsToErase); 314 while (!workList.empty()) { 315 for (auto *op : workList) { 316 std::vector<mlir::Value> opOperands(op->operand_begin(), 317 op->operand_end()); 318 LLVM_DEBUG(llvm::dbgs() << "DCE on " << *op << '\n'); 319 ++numDCE; 320 op->erase(); 321 for (auto opnd : opOperands) 322 maybeEraseOp(opnd.getDefiningOp()); 323 } 324 workList.clear(); 325 workList.swap(opsToErase); 326 } 327 } 328 329 void maybeEraseOp(mlir::Operation *op) { 330 if (!op) 331 return; 332 if (op->hasTrait<mlir::OpTrait::IsTerminator>()) 333 return; 334 if (mlir::isOpTriviallyDead(op)) 335 opsToErase.push_back(op); 336 } 337 338 private: 339 std::vector<mlir::Operation *> opsToErase; 340 }; 341 342 } // namespace 343 344 std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() { 345 return std::make_unique<CodeGenRewrite>(); 346 } 347