xref: /llvm-project/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
19f6c0056SKolya Panchenko //===- TestMathToVCIXConversion.cpp - Test conversion to VCIX ops ---------===//
29f6c0056SKolya Panchenko //
39f6c0056SKolya Panchenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49f6c0056SKolya Panchenko // See https://llvm.org/LICENSE.txt for license information.
59f6c0056SKolya Panchenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69f6c0056SKolya Panchenko //
79f6c0056SKolya Panchenko //===----------------------------------------------------------------------===//
89f6c0056SKolya Panchenko 
99f6c0056SKolya Panchenko #include "mlir/Dialect/Arith/IR/Arith.h"
109f6c0056SKolya Panchenko #include "mlir/Dialect/Func/IR/FuncOps.h"
119f6c0056SKolya Panchenko #include "mlir/Dialect/LLVMIR/VCIXDialect.h"
129f6c0056SKolya Panchenko #include "mlir/Dialect/Math/IR/Math.h"
139f6c0056SKolya Panchenko #include "mlir/Dialect/Vector/IR/VectorOps.h"
149f6c0056SKolya Panchenko #include "mlir/IR/PatternMatch.h"
159f6c0056SKolya Panchenko #include "mlir/Pass/Pass.h"
169f6c0056SKolya Panchenko #include "mlir/Pass/PassManager.h"
179f6c0056SKolya Panchenko #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
189f6c0056SKolya Panchenko 
199f6c0056SKolya Panchenko namespace mlir {
209f6c0056SKolya Panchenko namespace {
219f6c0056SKolya Panchenko 
229f6c0056SKolya Panchenko /// Return number of extracts required to make input VectorType \vt legal and
239f6c0056SKolya Panchenko /// also return thatlegal vector type.
249f6c0056SKolya Panchenko /// For fixed vectors nothing special is needed. Scalable vectors are legalizes
259f6c0056SKolya Panchenko /// according to LLVM's encoding:
269f6c0056SKolya Panchenko /// https://lists.llvm.org/pipermail/llvm-dev/2020-October/145850.html
279f6c0056SKolya Panchenko static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) {
28a5757c5bSChristian Sigg   VectorType vt = cast<VectorType>(type);
299f6c0056SKolya Panchenko   // To simplify test pass, avoid multi-dimensional vectors.
309f6c0056SKolya Panchenko   if (!vt || vt.getRank() != 1)
319f6c0056SKolya Panchenko     return {0, nullptr};
329f6c0056SKolya Panchenko 
339f6c0056SKolya Panchenko   if (!vt.isScalable())
349f6c0056SKolya Panchenko     return {1, vt};
359f6c0056SKolya Panchenko 
369f6c0056SKolya Panchenko   Type eltTy = vt.getElementType();
379f6c0056SKolya Panchenko   unsigned sew = 0;
389f6c0056SKolya Panchenko   if (eltTy.isF32())
399f6c0056SKolya Panchenko     sew = 32;
409f6c0056SKolya Panchenko   else if (eltTy.isF64())
419f6c0056SKolya Panchenko     sew = 64;
42a5757c5bSChristian Sigg   else if (auto intTy = dyn_cast<IntegerType>(eltTy))
439f6c0056SKolya Panchenko     sew = intTy.getWidth();
449f6c0056SKolya Panchenko   else
459f6c0056SKolya Panchenko     return {0, nullptr};
469f6c0056SKolya Panchenko 
479f6c0056SKolya Panchenko   unsigned eltCount = vt.getShape()[0];
489f6c0056SKolya Panchenko   const unsigned lmul = eltCount * sew / 64;
499f6c0056SKolya Panchenko 
509f6c0056SKolya Panchenko   unsigned n = lmul > 8 ? llvm::Log2_32(lmul) - 2 : 1;
519f6c0056SKolya Panchenko   return {n, VectorType::get({eltCount >> (n - 1)}, eltTy, {true})};
529f6c0056SKolya Panchenko }
539f6c0056SKolya Panchenko 
549f6c0056SKolya Panchenko /// Replace math.cos(v) operation with vcix.v.iv(v).
559f6c0056SKolya Panchenko struct MathCosToVCIX final : OpRewritePattern<math::CosOp> {
569f6c0056SKolya Panchenko   using OpRewritePattern::OpRewritePattern;
579f6c0056SKolya Panchenko 
589f6c0056SKolya Panchenko   LogicalResult matchAndRewrite(math::CosOp op,
599f6c0056SKolya Panchenko                                 PatternRewriter &rewriter) const override {
609f6c0056SKolya Panchenko     const Type opType = op.getOperand().getType();
619f6c0056SKolya Panchenko     auto [n, legalType] = legalizeVectorType(opType);
629f6c0056SKolya Panchenko     if (!legalType)
639f6c0056SKolya Panchenko       return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");
649f6c0056SKolya Panchenko     Location loc = op.getLoc();
659f6c0056SKolya Panchenko     Value vec = op.getOperand();
669f6c0056SKolya Panchenko     Attribute immAttr = rewriter.getI32IntegerAttr(0);
679f6c0056SKolya Panchenko     Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
689f6c0056SKolya Panchenko     Value rvl = nullptr;
699f6c0056SKolya Panchenko     if (legalType.isScalable())
709f6c0056SKolya Panchenko       // Use arbitrary runtime vector length when vector type is scalable.
719f6c0056SKolya Panchenko       // Proper conversion pass should take it from the IR.
729f6c0056SKolya Panchenko       rvl = rewriter.create<arith::ConstantOp>(loc,
739f6c0056SKolya Panchenko                                                rewriter.getI64IntegerAttr(9));
749f6c0056SKolya Panchenko     Value res;
759f6c0056SKolya Panchenko     if (n == 1) {
769f6c0056SKolya Panchenko       res = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, vec,
779f6c0056SKolya Panchenko                                                immAttr, rvl);
789f6c0056SKolya Panchenko     } else {
799f6c0056SKolya Panchenko       const unsigned eltCount = legalType.getShape()[0];
809f6c0056SKolya Panchenko       Type eltTy = legalType.getElementType();
819f6c0056SKolya Panchenko       Value zero = rewriter.create<arith::ConstantOp>(
829f6c0056SKolya Panchenko           loc, eltTy, rewriter.getZeroAttr(eltTy));
839f6c0056SKolya Panchenko       res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
849f6c0056SKolya Panchenko       for (unsigned i = 0; i < n; ++i) {
859f6c0056SKolya Panchenko         Value extracted = rewriter.create<vector::ScalableExtractOp>(
869f6c0056SKolya Panchenko             loc, legalType, vec, i * eltCount);
879f6c0056SKolya Panchenko         Value v = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr,
889f6c0056SKolya Panchenko                                                      extracted, immAttr, rvl);
899f6c0056SKolya Panchenko         res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
909f6c0056SKolya Panchenko                                                         i * eltCount);
919f6c0056SKolya Panchenko       }
929f6c0056SKolya Panchenko     }
939f6c0056SKolya Panchenko     rewriter.replaceOp(op, res);
949f6c0056SKolya Panchenko     return success();
959f6c0056SKolya Panchenko   }
969f6c0056SKolya Panchenko };
979f6c0056SKolya Panchenko 
989f6c0056SKolya Panchenko // Replace math.sin(v) operation with vcix.v.sv(v, v).
999f6c0056SKolya Panchenko struct MathSinToVCIX final : OpRewritePattern<math::SinOp> {
1009f6c0056SKolya Panchenko   using OpRewritePattern::OpRewritePattern;
1019f6c0056SKolya Panchenko 
1029f6c0056SKolya Panchenko   LogicalResult matchAndRewrite(math::SinOp op,
1039f6c0056SKolya Panchenko                                 PatternRewriter &rewriter) const override {
1049f6c0056SKolya Panchenko     const Type opType = op.getOperand().getType();
1059f6c0056SKolya Panchenko     auto [n, legalType] = legalizeVectorType(opType);
1069f6c0056SKolya Panchenko     if (!legalType)
1079f6c0056SKolya Panchenko       return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");
1089f6c0056SKolya Panchenko     Location loc = op.getLoc();
1099f6c0056SKolya Panchenko     Value vec = op.getOperand();
1109f6c0056SKolya Panchenko     Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
1119f6c0056SKolya Panchenko     Value rvl = nullptr;
1129f6c0056SKolya Panchenko     if (legalType.isScalable())
1139f6c0056SKolya Panchenko       // Use arbitrary runtime vector length when vector type is scalable.
1149f6c0056SKolya Panchenko       // Proper conversion pass should take it from the IR.
1159f6c0056SKolya Panchenko       rvl = rewriter.create<arith::ConstantOp>(loc,
1169f6c0056SKolya Panchenko                                                rewriter.getI64IntegerAttr(9));
1179f6c0056SKolya Panchenko     Value res;
1189f6c0056SKolya Panchenko     if (n == 1) {
1199f6c0056SKolya Panchenko       res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec,
1209f6c0056SKolya Panchenko                                             vec, rvl);
1219f6c0056SKolya Panchenko     } else {
1229f6c0056SKolya Panchenko       const unsigned eltCount = legalType.getShape()[0];
1239f6c0056SKolya Panchenko       Type eltTy = legalType.getElementType();
1249f6c0056SKolya Panchenko       Value zero = rewriter.create<arith::ConstantOp>(
1259f6c0056SKolya Panchenko           loc, eltTy, rewriter.getZeroAttr(eltTy));
1269f6c0056SKolya Panchenko       res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
1279f6c0056SKolya Panchenko       for (unsigned i = 0; i < n; ++i) {
1289f6c0056SKolya Panchenko         Value extracted = rewriter.create<vector::ScalableExtractOp>(
1299f6c0056SKolya Panchenko             loc, legalType, vec, i * eltCount);
1309f6c0056SKolya Panchenko         Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr,
1319f6c0056SKolya Panchenko                                                   extracted, extracted, rvl);
1329f6c0056SKolya Panchenko         res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
1339f6c0056SKolya Panchenko                                                         i * eltCount);
1349f6c0056SKolya Panchenko       }
1359f6c0056SKolya Panchenko     }
1369f6c0056SKolya Panchenko     rewriter.replaceOp(op, res);
1379f6c0056SKolya Panchenko     return success();
1389f6c0056SKolya Panchenko   }
1399f6c0056SKolya Panchenko };
1409f6c0056SKolya Panchenko 
1419f6c0056SKolya Panchenko // Replace math.tan(v) operation with vcix.v.sv(v, 0.0f).
1429f6c0056SKolya Panchenko struct MathTanToVCIX final : OpRewritePattern<math::TanOp> {
1439f6c0056SKolya Panchenko   using OpRewritePattern::OpRewritePattern;
1449f6c0056SKolya Panchenko 
1459f6c0056SKolya Panchenko   LogicalResult matchAndRewrite(math::TanOp op,
1469f6c0056SKolya Panchenko                                 PatternRewriter &rewriter) const override {
1479f6c0056SKolya Panchenko     const Type opType = op.getOperand().getType();
1489f6c0056SKolya Panchenko     auto [n, legalType] = legalizeVectorType(opType);
1499f6c0056SKolya Panchenko     Type eltTy = legalType.getElementType();
1509f6c0056SKolya Panchenko     if (!legalType)
1519f6c0056SKolya Panchenko       return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");
1529f6c0056SKolya Panchenko     Location loc = op.getLoc();
1539f6c0056SKolya Panchenko     Value vec = op.getOperand();
1549f6c0056SKolya Panchenko     Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
1559f6c0056SKolya Panchenko     Value zero = rewriter.create<arith::ConstantOp>(
1569f6c0056SKolya Panchenko         loc, eltTy, rewriter.getZeroAttr(eltTy));
1579f6c0056SKolya Panchenko     Value rvl = nullptr;
1589f6c0056SKolya Panchenko     if (legalType.isScalable())
1599f6c0056SKolya Panchenko       // Use arbitrary runtime vector length when vector type is scalable.
1609f6c0056SKolya Panchenko       // Proper conversion pass should take it from the IR.
1619f6c0056SKolya Panchenko       rvl = rewriter.create<arith::ConstantOp>(loc,
1629f6c0056SKolya Panchenko                                                rewriter.getI64IntegerAttr(9));
1639f6c0056SKolya Panchenko     Value res;
1649f6c0056SKolya Panchenko     if (n == 1) {
1659f6c0056SKolya Panchenko       res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec,
1669f6c0056SKolya Panchenko                                             zero, rvl);
1679f6c0056SKolya Panchenko     } else {
1689f6c0056SKolya Panchenko       const unsigned eltCount = legalType.getShape()[0];
1699f6c0056SKolya Panchenko       res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
1709f6c0056SKolya Panchenko       for (unsigned i = 0; i < n; ++i) {
1719f6c0056SKolya Panchenko         Value extracted = rewriter.create<vector::ScalableExtractOp>(
1729f6c0056SKolya Panchenko             loc, legalType, vec, i * eltCount);
1739f6c0056SKolya Panchenko         Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr,
1749f6c0056SKolya Panchenko                                                   extracted, zero, rvl);
1759f6c0056SKolya Panchenko         res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
1769f6c0056SKolya Panchenko                                                         i * eltCount);
1779f6c0056SKolya Panchenko       }
1789f6c0056SKolya Panchenko     }
1799f6c0056SKolya Panchenko     rewriter.replaceOp(op, res);
1809f6c0056SKolya Panchenko     return success();
1819f6c0056SKolya Panchenko   }
1829f6c0056SKolya Panchenko };
1839f6c0056SKolya Panchenko 
1849f6c0056SKolya Panchenko // Replace math.log(v) operation with vcix.v.sv(v, 0).
1859f6c0056SKolya Panchenko struct MathLogToVCIX final : OpRewritePattern<math::LogOp> {
1869f6c0056SKolya Panchenko   using OpRewritePattern::OpRewritePattern;
1879f6c0056SKolya Panchenko 
1889f6c0056SKolya Panchenko   LogicalResult matchAndRewrite(math::LogOp op,
1899f6c0056SKolya Panchenko                                 PatternRewriter &rewriter) const override {
1909f6c0056SKolya Panchenko     const Type opType = op.getOperand().getType();
1919f6c0056SKolya Panchenko     auto [n, legalType] = legalizeVectorType(opType);
1929f6c0056SKolya Panchenko     if (!legalType)
1939f6c0056SKolya Panchenko       return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");
1949f6c0056SKolya Panchenko     Location loc = op.getLoc();
1959f6c0056SKolya Panchenko     Value vec = op.getOperand();
1969f6c0056SKolya Panchenko     Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
1979f6c0056SKolya Panchenko     Value rvl = nullptr;
1989f6c0056SKolya Panchenko     Value zeroInt = rewriter.create<arith::ConstantOp>(
1999f6c0056SKolya Panchenko         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
2009f6c0056SKolya Panchenko     if (legalType.isScalable())
2019f6c0056SKolya Panchenko       // Use arbitrary runtime vector length when vector type is scalable.
2029f6c0056SKolya Panchenko       // Proper conversion pass should take it from the IR.
2039f6c0056SKolya Panchenko       rvl = rewriter.create<arith::ConstantOp>(loc,
2049f6c0056SKolya Panchenko                                                rewriter.getI64IntegerAttr(9));
2059f6c0056SKolya Panchenko     Value res;
2069f6c0056SKolya Panchenko     if (n == 1) {
2079f6c0056SKolya Panchenko       res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec,
2089f6c0056SKolya Panchenko                                             zeroInt, rvl);
2099f6c0056SKolya Panchenko     } else {
2109f6c0056SKolya Panchenko       const unsigned eltCount = legalType.getShape()[0];
2119f6c0056SKolya Panchenko       Type eltTy = legalType.getElementType();
2129f6c0056SKolya Panchenko       Value zero = rewriter.create<arith::ConstantOp>(
2139f6c0056SKolya Panchenko           loc, eltTy, rewriter.getZeroAttr(eltTy));
2149f6c0056SKolya Panchenko       res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
2159f6c0056SKolya Panchenko       for (unsigned i = 0; i < n; ++i) {
2169f6c0056SKolya Panchenko         Value extracted = rewriter.create<vector::ScalableExtractOp>(
2179f6c0056SKolya Panchenko             loc, legalType, vec, i * eltCount);
2189f6c0056SKolya Panchenko         Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr,
2199f6c0056SKolya Panchenko                                                   extracted, zeroInt, rvl);
2209f6c0056SKolya Panchenko         res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
2219f6c0056SKolya Panchenko                                                         i * eltCount);
2229f6c0056SKolya Panchenko       }
2239f6c0056SKolya Panchenko     }
2249f6c0056SKolya Panchenko     rewriter.replaceOp(op, res);
2259f6c0056SKolya Panchenko     return success();
2269f6c0056SKolya Panchenko   }
2279f6c0056SKolya Panchenko };
2289f6c0056SKolya Panchenko 
2299f6c0056SKolya Panchenko struct TestMathToVCIX
2309f6c0056SKolya Panchenko     : PassWrapper<TestMathToVCIX, OperationPass<func::FuncOp>> {
2319f6c0056SKolya Panchenko   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMathToVCIX)
2329f6c0056SKolya Panchenko 
2339f6c0056SKolya Panchenko   StringRef getArgument() const final { return "test-math-to-vcix"; }
2349f6c0056SKolya Panchenko 
2359f6c0056SKolya Panchenko   StringRef getDescription() const final {
2369f6c0056SKolya Panchenko     return "Test lowering patterns that converts some vector operations to "
2379f6c0056SKolya Panchenko            "VCIX. Since DLA can implement VCIX instructions in completely "
2389f6c0056SKolya Panchenko            "different way, conversions of that test pass only lives here.";
2399f6c0056SKolya Panchenko   }
2409f6c0056SKolya Panchenko 
2419f6c0056SKolya Panchenko   void getDependentDialects(DialectRegistry &registry) const override {
2429f6c0056SKolya Panchenko     registry.insert<arith::ArithDialect, func::FuncDialect, math::MathDialect,
2439f6c0056SKolya Panchenko                     vcix::VCIXDialect, vector::VectorDialect>();
2449f6c0056SKolya Panchenko   }
2459f6c0056SKolya Panchenko 
2469f6c0056SKolya Panchenko   void runOnOperation() override {
2479f6c0056SKolya Panchenko     MLIRContext *ctx = &getContext();
2489f6c0056SKolya Panchenko     RewritePatternSet patterns(ctx);
2499f6c0056SKolya Panchenko     patterns.add<MathCosToVCIX, MathSinToVCIX, MathTanToVCIX, MathLogToVCIX>(
2509f6c0056SKolya Panchenko         ctx);
251*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
2529f6c0056SKolya Panchenko   }
2539f6c0056SKolya Panchenko };
2549f6c0056SKolya Panchenko 
2559f6c0056SKolya Panchenko } // namespace
2569f6c0056SKolya Panchenko 
2579f6c0056SKolya Panchenko namespace test {
2589f6c0056SKolya Panchenko void registerTestMathToVCIXPass() { PassRegistration<TestMathToVCIX>(); }
2599f6c0056SKolya Panchenko } // namespace test
2609f6c0056SKolya Panchenko } // namespace mlir
261