1 //===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" 10 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/ArmSME/IR/ArmSME.h" 13 #include "mlir/Dialect/ArmSME/Utils/Utils.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 16 17 namespace mlir { 18 #define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS 19 #include "mlir/Conversion/Passes.h.inc" 20 } // namespace mlir 21 22 #define DEBUG_TYPE "arith-to-arm-sme" 23 24 using namespace mlir; 25 26 //===----------------------------------------------------------------------===// 27 // Conversion helpers 28 //===----------------------------------------------------------------------===// 29 30 /// Returns true if 'val' is a splat of zero, false otherwise. 31 static bool isSplatZero(Type elemType, DenseElementsAttr val) { 32 if (llvm::isa<FloatType>(elemType)) 33 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero(); 34 if (llvm::isa<IntegerType>(elemType)) 35 return val && val.isSplat() && val.getSplatValue<APInt>().isZero(); 36 return false; 37 } 38 39 namespace { 40 41 //===----------------------------------------------------------------------===// 42 // ConstantOp 43 //===----------------------------------------------------------------------===// 44 45 /// Conversion pattern for dense arith.constant. 46 struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> { 47 using OpRewritePattern<arith::ConstantOp>::OpRewritePattern; 48 49 LogicalResult matchAndRewrite(arith::ConstantOp constantOp, 50 PatternRewriter &rewriter) const final { 51 auto tileType = dyn_cast<VectorType>(constantOp.getType()); 52 if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) 53 return failure(); 54 55 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr()); 56 if (!denseAttr || !denseAttr.isSplat()) 57 return failure(); 58 59 auto tileElementType = tileType.getElementType(); 60 61 // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op. 62 if (isSplatZero(tileElementType, denseAttr)) { 63 rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType); 64 return success(); 65 } 66 67 // Lower non-zero constants to a loop of 'arm_sme.insert_tile_slice' 68 // ops that broadcast the constant to each tile slice. 69 auto loc = constantOp.getLoc(); 70 71 // To fill a tile with a constant, we create a 1-D splat of the constant, 72 // then move that into each tile slice (the largest unit we can set at once, 73 // outside of operations like the outerproduct). 74 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); 75 auto denseAttr1D = DenseElementsAttr::get( 76 tileSliceType, denseAttr.getSplatValue<Attribute>()); 77 auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D); 78 79 auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); 80 auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, 81 Value currentTile) { 82 // Create 'arm_sme.insert_tile_slice' to write vector to tile 83 // slice. 84 auto nextTile = b.create<arm_sme::InsertTileSliceOp>( 85 loc, tileType, constantOp1D, currentTile, tileSliceIndex); 86 return nextTile.getResult(); 87 }; 88 auto forOp = mlir::arm_sme::createLoopOverTileSlices( 89 rewriter, loc, initTile, makeLoopBody); 90 rewriter.replaceOp(constantOp, forOp.getResult(0)); 91 92 return success(); 93 } 94 }; 95 96 } // namespace 97 98 //===----------------------------------------------------------------------===// 99 // Pattern population 100 //===----------------------------------------------------------------------===// 101 102 void mlir::arith::populateArithToArmSMEConversionPatterns( 103 RewritePatternSet &patterns) { 104 patterns.add<ConstantOpToArmSMELowering>(patterns.getContext()); 105 } 106 107 //===----------------------------------------------------------------------===// 108 // Pass definition 109 //===----------------------------------------------------------------------===// 110 111 namespace { 112 struct ArithToArmSMEConversionPass final 113 : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> { 114 using impl::ArithToArmSMEConversionPassBase< 115 ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase; 116 117 void runOnOperation() override { 118 RewritePatternSet patterns(&getContext()); 119 arith::populateArithToArmSMEConversionPatterns(patterns); 120 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 121 return signalPassFailure(); 122 } 123 }; 124 } // namespace 125