1 //===-- PreCGRewrite.cpp --------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/CodeGen/CodeGen.h" 14 15 #include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done 16 #include "flang/Optimizer/CodeGen/CGOps.h" 17 #include "flang/Optimizer/Dialect/FIRDialect.h" 18 #include "flang/Optimizer/Dialect/FIROps.h" 19 #include "flang/Optimizer/Dialect/FIRType.h" 20 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 21 #include "mlir/IR/Iterators.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include "llvm/Support/Debug.h" 25 26 namespace fir { 27 #define GEN_PASS_DEF_CODEGENREWRITE 28 #include "flang/Optimizer/CodeGen/CGPasses.h.inc" 29 } // namespace fir 30 31 //===----------------------------------------------------------------------===// 32 // Codegen rewrite: rewriting of subgraphs of ops 33 //===----------------------------------------------------------------------===// 34 35 #define DEBUG_TYPE "flang-codegen-rewrite" 36 37 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec, 38 fir::ShapeOp shape) { 39 vec.append(shape.getExtents().begin(), shape.getExtents().end()); 40 } 41 42 // Operands of fir.shape_shift split into two vectors. 43 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec, 44 llvm::SmallVectorImpl<mlir::Value> &shiftVec, 45 fir::ShapeShiftOp shift) { 46 for (auto i = shift.getPairs().begin(), endIter = shift.getPairs().end(); 47 i != endIter;) { 48 shiftVec.push_back(*i++); 49 shapeVec.push_back(*i++); 50 } 51 } 52 53 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec, 54 fir::ShiftOp shift) { 55 vec.append(shift.getOrigins().begin(), shift.getOrigins().end()); 56 } 57 58 namespace { 59 60 /// Convert fir.embox to the extended form where necessary. 61 /// 62 /// The embox operation can take arguments that specify multidimensional array 63 /// properties at runtime. These properties may be shared between distinct 64 /// objects that have the same properties. Before we lower these small DAGs to 65 /// LLVM-IR, we gather all the information into a single extended operation. For 66 /// example, 67 /// ``` 68 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1> 69 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1> 70 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, 71 /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>> 72 /// ``` 73 /// can be rewritten as 74 /// ``` 75 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : 76 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> 77 /// !fir.box<!fir.array<?xi32>> 78 /// ``` 79 class EmboxConversion : public mlir::OpRewritePattern<fir::EmboxOp> { 80 public: 81 using OpRewritePattern::OpRewritePattern; 82 83 llvm::LogicalResult 84 matchAndRewrite(fir::EmboxOp embox, 85 mlir::PatternRewriter &rewriter) const override { 86 // If the embox does not include a shape, then do not convert it 87 if (auto shapeVal = embox.getShape()) 88 return rewriteDynamicShape(embox, rewriter, shapeVal); 89 if (mlir::isa<fir::ClassType>(embox.getType())) 90 TODO(embox.getLoc(), "embox conversion for fir.class type"); 91 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(embox.getType())) 92 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy())) 93 if (!seqTy.hasDynamicExtents()) 94 return rewriteStaticShape(embox, rewriter, seqTy); 95 return mlir::failure(); 96 } 97 98 llvm::LogicalResult rewriteStaticShape(fir::EmboxOp embox, 99 mlir::PatternRewriter &rewriter, 100 fir::SequenceType seqTy) const { 101 auto loc = embox.getLoc(); 102 llvm::SmallVector<mlir::Value> shapeOpers; 103 auto idxTy = rewriter.getIndexType(); 104 for (auto ext : seqTy.getShape()) { 105 auto iAttr = rewriter.getIndexAttr(ext); 106 auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr); 107 shapeOpers.push_back(extVal); 108 } 109 auto xbox = rewriter.create<fir::cg::XEmboxOp>( 110 loc, embox.getType(), embox.getMemref(), shapeOpers, std::nullopt, 111 std::nullopt, std::nullopt, std::nullopt, embox.getTypeparams(), 112 embox.getSourceBox(), embox.getAllocatorIdxAttr()); 113 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 114 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 115 return mlir::success(); 116 } 117 118 llvm::LogicalResult rewriteDynamicShape(fir::EmboxOp embox, 119 mlir::PatternRewriter &rewriter, 120 mlir::Value shapeVal) const { 121 auto loc = embox.getLoc(); 122 llvm::SmallVector<mlir::Value> shapeOpers; 123 llvm::SmallVector<mlir::Value> shiftOpers; 124 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) { 125 populateShape(shapeOpers, shapeOp); 126 } else { 127 auto shiftOp = 128 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()); 129 assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift"); 130 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 131 } 132 llvm::SmallVector<mlir::Value> sliceOpers; 133 llvm::SmallVector<mlir::Value> subcompOpers; 134 llvm::SmallVector<mlir::Value> substrOpers; 135 if (auto s = embox.getSlice()) 136 if (auto sliceOp = 137 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) { 138 sliceOpers.assign(sliceOp.getTriples().begin(), 139 sliceOp.getTriples().end()); 140 subcompOpers.assign(sliceOp.getFields().begin(), 141 sliceOp.getFields().end()); 142 substrOpers.assign(sliceOp.getSubstr().begin(), 143 sliceOp.getSubstr().end()); 144 } 145 auto xbox = rewriter.create<fir::cg::XEmboxOp>( 146 loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers, 147 sliceOpers, subcompOpers, substrOpers, embox.getTypeparams(), 148 embox.getSourceBox(), embox.getAllocatorIdxAttr()); 149 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 150 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 151 return mlir::success(); 152 } 153 }; 154 155 /// Convert fir.rebox to the extended form where necessary. 156 /// 157 /// For example, 158 /// ``` 159 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> 160 /// !fir.box<!fir.array<?xi32>> 161 /// ``` 162 /// converted to 163 /// ``` 164 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>, 165 /// index, index) -> !fir.box<!fir.array<?xi32>> 166 /// ``` 167 class ReboxConversion : public mlir::OpRewritePattern<fir::ReboxOp> { 168 public: 169 using OpRewritePattern::OpRewritePattern; 170 171 llvm::LogicalResult 172 matchAndRewrite(fir::ReboxOp rebox, 173 mlir::PatternRewriter &rewriter) const override { 174 auto loc = rebox.getLoc(); 175 llvm::SmallVector<mlir::Value> shapeOpers; 176 llvm::SmallVector<mlir::Value> shiftOpers; 177 if (auto shapeVal = rebox.getShape()) { 178 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) 179 populateShape(shapeOpers, shapeOp); 180 else if (auto shiftOp = 181 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp())) 182 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 183 else if (auto shiftOp = 184 mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp())) 185 populateShift(shiftOpers, shiftOp); 186 else 187 return mlir::failure(); 188 } 189 llvm::SmallVector<mlir::Value> sliceOpers; 190 llvm::SmallVector<mlir::Value> subcompOpers; 191 llvm::SmallVector<mlir::Value> substrOpers; 192 if (auto s = rebox.getSlice()) 193 if (auto sliceOp = 194 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) { 195 sliceOpers.append(sliceOp.getTriples().begin(), 196 sliceOp.getTriples().end()); 197 subcompOpers.append(sliceOp.getFields().begin(), 198 sliceOp.getFields().end()); 199 substrOpers.append(sliceOp.getSubstr().begin(), 200 sliceOp.getSubstr().end()); 201 } 202 203 auto xRebox = rewriter.create<fir::cg::XReboxOp>( 204 loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers, 205 sliceOpers, subcompOpers, substrOpers); 206 LLVM_DEBUG(llvm::dbgs() 207 << "rewriting " << rebox << " to " << xRebox << '\n'); 208 rewriter.replaceOp(rebox, xRebox.getOperation()->getResults()); 209 return mlir::success(); 210 } 211 }; 212 213 /// Convert all fir.array_coor to the extended form. 214 /// 215 /// For example, 216 /// ``` 217 /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>, 218 /// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32> 219 /// ``` 220 /// converted to 221 /// ``` 222 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : 223 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) -> 224 /// !fir.ref<i32> 225 /// ``` 226 class ArrayCoorConversion : public mlir::OpRewritePattern<fir::ArrayCoorOp> { 227 public: 228 using OpRewritePattern::OpRewritePattern; 229 230 llvm::LogicalResult 231 matchAndRewrite(fir::ArrayCoorOp arrCoor, 232 mlir::PatternRewriter &rewriter) const override { 233 auto loc = arrCoor.getLoc(); 234 llvm::SmallVector<mlir::Value> shapeOpers; 235 llvm::SmallVector<mlir::Value> shiftOpers; 236 if (auto shapeVal = arrCoor.getShape()) { 237 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) 238 populateShape(shapeOpers, shapeOp); 239 else if (auto shiftOp = 240 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp())) 241 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 242 else if (auto shiftOp = 243 mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp())) 244 populateShift(shiftOpers, shiftOp); 245 else 246 return mlir::failure(); 247 } 248 llvm::SmallVector<mlir::Value> sliceOpers; 249 llvm::SmallVector<mlir::Value> subcompOpers; 250 if (auto s = arrCoor.getSlice()) 251 if (auto sliceOp = 252 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) { 253 sliceOpers.append(sliceOp.getTriples().begin(), 254 sliceOp.getTriples().end()); 255 subcompOpers.append(sliceOp.getFields().begin(), 256 sliceOp.getFields().end()); 257 assert(sliceOp.getSubstr().empty() && 258 "Don't allow substring operations on array_coor. This " 259 "restriction may be lifted in the future."); 260 } 261 auto xArrCoor = rewriter.create<fir::cg::XArrayCoorOp>( 262 loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers, 263 sliceOpers, subcompOpers, arrCoor.getIndices(), 264 arrCoor.getTypeparams()); 265 LLVM_DEBUG(llvm::dbgs() 266 << "rewriting " << arrCoor << " to " << xArrCoor << '\n'); 267 rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults()); 268 return mlir::success(); 269 } 270 }; 271 272 class DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> { 273 bool preserveDeclare; 274 275 public: 276 using OpRewritePattern::OpRewritePattern; 277 DeclareOpConversion(mlir::MLIRContext *ctx, bool preserveDecl) 278 : OpRewritePattern(ctx), preserveDeclare(preserveDecl) {} 279 280 llvm::LogicalResult 281 matchAndRewrite(fir::DeclareOp declareOp, 282 mlir::PatternRewriter &rewriter) const override { 283 if (!preserveDeclare) { 284 rewriter.replaceOp(declareOp, declareOp.getMemref()); 285 return mlir::success(); 286 } 287 auto loc = declareOp.getLoc(); 288 llvm::SmallVector<mlir::Value> shapeOpers; 289 llvm::SmallVector<mlir::Value> shiftOpers; 290 if (auto shapeVal = declareOp.getShape()) { 291 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) 292 populateShape(shapeOpers, shapeOp); 293 else if (auto shiftOp = 294 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp())) 295 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 296 else if (auto shiftOp = 297 mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp())) 298 populateShift(shiftOpers, shiftOp); 299 else 300 return mlir::failure(); 301 } 302 // FIXME: Add FortranAttrs and CudaAttrs 303 auto xDeclOp = rewriter.create<fir::cg::XDeclareOp>( 304 loc, declareOp.getType(), declareOp.getMemref(), shapeOpers, shiftOpers, 305 declareOp.getTypeparams(), declareOp.getDummyScope(), 306 declareOp.getUniqName()); 307 LLVM_DEBUG(llvm::dbgs() 308 << "rewriting " << declareOp << " to " << xDeclOp << '\n'); 309 rewriter.replaceOp(declareOp, xDeclOp.getOperation()->getResults()); 310 return mlir::success(); 311 } 312 }; 313 314 class DummyScopeOpConversion 315 : public mlir::OpRewritePattern<fir::DummyScopeOp> { 316 public: 317 using OpRewritePattern::OpRewritePattern; 318 319 llvm::LogicalResult 320 matchAndRewrite(fir::DummyScopeOp dummyScopeOp, 321 mlir::PatternRewriter &rewriter) const override { 322 rewriter.replaceOpWithNewOp<fir::UndefOp>(dummyScopeOp, 323 dummyScopeOp.getType()); 324 return mlir::success(); 325 } 326 }; 327 328 /// Simple DCE to erase fir.shape/shift/slice/unused shape operands after this 329 /// pass (fir.shape and like have no codegen). 330 /// mlir::RegionDCE is expensive and requires running 331 /// mlir::eraseUnreachableBlocks. It does things that are not needed here, like 332 /// removing unused block arguments. fir.shape/shift/slice cannot be block 333 /// arguments. 334 /// This helper does a naive backward walk of the IR. It is not even guaranteed 335 /// to walk blocks according to backward dominance, but that is good enough for 336 /// what is done here, fir.shape/shift/slice have no usages anymore. The 337 /// backward walk allows getting rid of most of the unused operands, it is not a 338 /// problem to leave some in the weird cases. 339 static void simpleDCE(mlir::RewriterBase &rewriter, mlir::Operation *op) { 340 op->walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>( 341 [&](mlir::Operation *subOp) { 342 if (mlir::isOpTriviallyDead(subOp)) 343 rewriter.eraseOp(subOp); 344 }); 345 } 346 347 class CodeGenRewrite : public fir::impl::CodeGenRewriteBase<CodeGenRewrite> { 348 public: 349 using CodeGenRewriteBase<CodeGenRewrite>::CodeGenRewriteBase; 350 351 void runOnOperation() override final { 352 mlir::ModuleOp mod = getOperation(); 353 354 auto &context = getContext(); 355 mlir::ConversionTarget target(context); 356 target.addLegalDialect<mlir::arith::ArithDialect, fir::FIROpsDialect, 357 fir::FIRCodeGenDialect, mlir::func::FuncDialect>(); 358 target.addIllegalOp<fir::ArrayCoorOp>(); 359 target.addIllegalOp<fir::ReboxOp>(); 360 target.addIllegalOp<fir::DeclareOp>(); 361 target.addIllegalOp<fir::DummyScopeOp>(); 362 target.addDynamicallyLegalOp<fir::EmboxOp>([](fir::EmboxOp embox) { 363 return !(embox.getShape() || 364 mlir::isa<fir::SequenceType>( 365 mlir::cast<fir::BaseBoxType>(embox.getType()).getEleTy())); 366 }); 367 mlir::RewritePatternSet patterns(&context); 368 fir::populatePreCGRewritePatterns(patterns, preserveDeclare); 369 if (mlir::failed( 370 mlir::applyPartialConversion(mod, target, std::move(patterns)))) { 371 mlir::emitError(mlir::UnknownLoc::get(&context), 372 "error in running the pre-codegen conversions"); 373 signalPassFailure(); 374 return; 375 } 376 // Erase any residual (fir.shape, fir.slice...). 377 mlir::IRRewriter rewriter(&context); 378 simpleDCE(rewriter, mod.getOperation()); 379 } 380 }; 381 382 } // namespace 383 384 void fir::populatePreCGRewritePatterns(mlir::RewritePatternSet &patterns, 385 bool preserveDeclare) { 386 patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion, 387 DummyScopeOpConversion>(patterns.getContext()); 388 patterns.add<DeclareOpConversion>(patterns.getContext(), preserveDeclare); 389 } 390