1 //===- TestMathToVCIXConversion.cpp - Test conversion to VCIX ops ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arith/IR/Arith.h" 10 #include "mlir/Dialect/Func/IR/FuncOps.h" 11 #include "mlir/Dialect/LLVMIR/VCIXDialect.h" 12 #include "mlir/Dialect/Math/IR/Math.h" 13 #include "mlir/Dialect/Vector/IR/VectorOps.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Pass/PassManager.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 namespace mlir { 20 namespace { 21 22 /// Return number of extracts required to make input VectorType \vt legal and 23 /// also return thatlegal vector type. 24 /// For fixed vectors nothing special is needed. Scalable vectors are legalizes 25 /// according to LLVM's encoding: 26 /// https://lists.llvm.org/pipermail/llvm-dev/2020-October/145850.html 27 static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) { 28 VectorType vt = cast<VectorType>(type); 29 // To simplify test pass, avoid multi-dimensional vectors. 30 if (!vt || vt.getRank() != 1) 31 return {0, nullptr}; 32 33 if (!vt.isScalable()) 34 return {1, vt}; 35 36 Type eltTy = vt.getElementType(); 37 unsigned sew = 0; 38 if (eltTy.isF32()) 39 sew = 32; 40 else if (eltTy.isF64()) 41 sew = 64; 42 else if (auto intTy = dyn_cast<IntegerType>(eltTy)) 43 sew = intTy.getWidth(); 44 else 45 return {0, nullptr}; 46 47 unsigned eltCount = vt.getShape()[0]; 48 const unsigned lmul = eltCount * sew / 64; 49 50 unsigned n = lmul > 8 ? llvm::Log2_32(lmul) - 2 : 1; 51 return {n, VectorType::get({eltCount >> (n - 1)}, eltTy, {true})}; 52 } 53 54 /// Replace math.cos(v) operation with vcix.v.iv(v). 55 struct MathCosToVCIX final : OpRewritePattern<math::CosOp> { 56 using OpRewritePattern::OpRewritePattern; 57 58 LogicalResult matchAndRewrite(math::CosOp op, 59 PatternRewriter &rewriter) const override { 60 const Type opType = op.getOperand().getType(); 61 auto [n, legalType] = legalizeVectorType(opType); 62 if (!legalType) 63 return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV"); 64 Location loc = op.getLoc(); 65 Value vec = op.getOperand(); 66 Attribute immAttr = rewriter.getI32IntegerAttr(0); 67 Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); 68 Value rvl = nullptr; 69 if (legalType.isScalable()) 70 // Use arbitrary runtime vector length when vector type is scalable. 71 // Proper conversion pass should take it from the IR. 72 rvl = rewriter.create<arith::ConstantOp>(loc, 73 rewriter.getI64IntegerAttr(9)); 74 Value res; 75 if (n == 1) { 76 res = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, vec, 77 immAttr, rvl); 78 } else { 79 const unsigned eltCount = legalType.getShape()[0]; 80 Type eltTy = legalType.getElementType(); 81 Value zero = rewriter.create<arith::ConstantOp>( 82 loc, eltTy, rewriter.getZeroAttr(eltTy)); 83 res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); 84 for (unsigned i = 0; i < n; ++i) { 85 Value extracted = rewriter.create<vector::ScalableExtractOp>( 86 loc, legalType, vec, i * eltCount); 87 Value v = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, 88 extracted, immAttr, rvl); 89 res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, 90 i * eltCount); 91 } 92 } 93 rewriter.replaceOp(op, res); 94 return success(); 95 } 96 }; 97 98 // Replace math.sin(v) operation with vcix.v.sv(v, v). 99 struct MathSinToVCIX final : OpRewritePattern<math::SinOp> { 100 using OpRewritePattern::OpRewritePattern; 101 102 LogicalResult matchAndRewrite(math::SinOp op, 103 PatternRewriter &rewriter) const override { 104 const Type opType = op.getOperand().getType(); 105 auto [n, legalType] = legalizeVectorType(opType); 106 if (!legalType) 107 return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV"); 108 Location loc = op.getLoc(); 109 Value vec = op.getOperand(); 110 Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); 111 Value rvl = nullptr; 112 if (legalType.isScalable()) 113 // Use arbitrary runtime vector length when vector type is scalable. 114 // Proper conversion pass should take it from the IR. 115 rvl = rewriter.create<arith::ConstantOp>(loc, 116 rewriter.getI64IntegerAttr(9)); 117 Value res; 118 if (n == 1) { 119 res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, 120 vec, rvl); 121 } else { 122 const unsigned eltCount = legalType.getShape()[0]; 123 Type eltTy = legalType.getElementType(); 124 Value zero = rewriter.create<arith::ConstantOp>( 125 loc, eltTy, rewriter.getZeroAttr(eltTy)); 126 res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); 127 for (unsigned i = 0; i < n; ++i) { 128 Value extracted = rewriter.create<vector::ScalableExtractOp>( 129 loc, legalType, vec, i * eltCount); 130 Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, 131 extracted, extracted, rvl); 132 res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, 133 i * eltCount); 134 } 135 } 136 rewriter.replaceOp(op, res); 137 return success(); 138 } 139 }; 140 141 // Replace math.tan(v) operation with vcix.v.sv(v, 0.0f). 142 struct MathTanToVCIX final : OpRewritePattern<math::TanOp> { 143 using OpRewritePattern::OpRewritePattern; 144 145 LogicalResult matchAndRewrite(math::TanOp op, 146 PatternRewriter &rewriter) const override { 147 const Type opType = op.getOperand().getType(); 148 auto [n, legalType] = legalizeVectorType(opType); 149 Type eltTy = legalType.getElementType(); 150 if (!legalType) 151 return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV"); 152 Location loc = op.getLoc(); 153 Value vec = op.getOperand(); 154 Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); 155 Value zero = rewriter.create<arith::ConstantOp>( 156 loc, eltTy, rewriter.getZeroAttr(eltTy)); 157 Value rvl = nullptr; 158 if (legalType.isScalable()) 159 // Use arbitrary runtime vector length when vector type is scalable. 160 // Proper conversion pass should take it from the IR. 161 rvl = rewriter.create<arith::ConstantOp>(loc, 162 rewriter.getI64IntegerAttr(9)); 163 Value res; 164 if (n == 1) { 165 res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, 166 zero, rvl); 167 } else { 168 const unsigned eltCount = legalType.getShape()[0]; 169 res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); 170 for (unsigned i = 0; i < n; ++i) { 171 Value extracted = rewriter.create<vector::ScalableExtractOp>( 172 loc, legalType, vec, i * eltCount); 173 Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, 174 extracted, zero, rvl); 175 res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, 176 i * eltCount); 177 } 178 } 179 rewriter.replaceOp(op, res); 180 return success(); 181 } 182 }; 183 184 // Replace math.log(v) operation with vcix.v.sv(v, 0). 185 struct MathLogToVCIX final : OpRewritePattern<math::LogOp> { 186 using OpRewritePattern::OpRewritePattern; 187 188 LogicalResult matchAndRewrite(math::LogOp op, 189 PatternRewriter &rewriter) const override { 190 const Type opType = op.getOperand().getType(); 191 auto [n, legalType] = legalizeVectorType(opType); 192 if (!legalType) 193 return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV"); 194 Location loc = op.getLoc(); 195 Value vec = op.getOperand(); 196 Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); 197 Value rvl = nullptr; 198 Value zeroInt = rewriter.create<arith::ConstantOp>( 199 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 200 if (legalType.isScalable()) 201 // Use arbitrary runtime vector length when vector type is scalable. 202 // Proper conversion pass should take it from the IR. 203 rvl = rewriter.create<arith::ConstantOp>(loc, 204 rewriter.getI64IntegerAttr(9)); 205 Value res; 206 if (n == 1) { 207 res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, 208 zeroInt, rvl); 209 } else { 210 const unsigned eltCount = legalType.getShape()[0]; 211 Type eltTy = legalType.getElementType(); 212 Value zero = rewriter.create<arith::ConstantOp>( 213 loc, eltTy, rewriter.getZeroAttr(eltTy)); 214 res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); 215 for (unsigned i = 0; i < n; ++i) { 216 Value extracted = rewriter.create<vector::ScalableExtractOp>( 217 loc, legalType, vec, i * eltCount); 218 Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, 219 extracted, zeroInt, rvl); 220 res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, 221 i * eltCount); 222 } 223 } 224 rewriter.replaceOp(op, res); 225 return success(); 226 } 227 }; 228 229 struct TestMathToVCIX 230 : PassWrapper<TestMathToVCIX, OperationPass<func::FuncOp>> { 231 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMathToVCIX) 232 233 StringRef getArgument() const final { return "test-math-to-vcix"; } 234 235 StringRef getDescription() const final { 236 return "Test lowering patterns that converts some vector operations to " 237 "VCIX. Since DLA can implement VCIX instructions in completely " 238 "different way, conversions of that test pass only lives here."; 239 } 240 241 void getDependentDialects(DialectRegistry ®istry) const override { 242 registry.insert<arith::ArithDialect, func::FuncDialect, math::MathDialect, 243 vcix::VCIXDialect, vector::VectorDialect>(); 244 } 245 246 void runOnOperation() override { 247 MLIRContext *ctx = &getContext(); 248 RewritePatternSet patterns(ctx); 249 patterns.add<MathCosToVCIX, MathSinToVCIX, MathTanToVCIX, MathLogToVCIX>( 250 ctx); 251 (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 252 } 253 }; 254 255 } // namespace 256 257 namespace test { 258 void registerTestMathToVCIXPass() { PassRegistration<TestMathToVCIX>(); } 259 } // namespace test 260 } // namespace mlir 261