// xref: /llvm-project/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td (revision d48777ece50c39df553ed779d0771bc9ef6747cf)
//===- PolynomialCanonicalization.td - Polynomial patterns -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#ifndef POLYNOMIAL_CANONICALIZATION
10#define POLYNOMIAL_CANONICALIZATION
11
12include "mlir/Dialect/Arith/IR/ArithOps.td"
13include "mlir/Dialect/Polynomial/IR/Polynomial.td"
14include "mlir/IR/OpBase.td"
15include "mlir/IR/PatternBase.td"
16
// Native constraint that holds when its two bound arguments compare equal
// under C++ `operator==`. The NTT/INTT round-trip patterns in this file use it
// to require that both ops were built with matching `$r1` / `$r2` arguments.
def Equal
  : Constraint<CPred<"$0 == $1">>;
18
// Get a -1 integer attribute of the same type as the polynomial SSA value's
// ring coefficient type.
//
// The value may be a single polynomial or a shaped container (e.g. a tensor)
// of polynomials; SubOp operands are polynomial-like (NOTE(review): per the
// operand constraints in Polynomial.td -- confirm). A bare
// `cast<PolynomialType>` on a tensor type asserts, so the element type is
// unwrapped first: `getElementTypeOrSelf` returns the element type of a shaped
// type and the type itself otherwise, leaving the scalar-polynomial case
// unchanged.
//
// NOTE: the generated C++ calls `getElementTypeOrSelf` from
// mlir/IR/TypeUtilities.h; that header must be visible wherever the generated
// patterns are included.
def getMinusOne
  : NativeCodeCall<
      "$_builder.getIntegerAttr("
        "cast<PolynomialType>("
          "getElementTypeOrSelf($0.getType())).getRing().getCoefficientType(), -1)">;
25
// Canonicalize subtraction into addition of a scaled operand:
//   f - g  ==>  f + (g * c)
// where c is an arith.constant whose attribute comes from getMinusOne applied
// to g (a -1 in g's ring coefficient type).
def SubAsAdd
  : Pat<(Polynomial_SubOp $f, $g),
        (Polynomial_AddOp
          $f,
          (Polynomial_MulScalarOp
            $g,
            (Arith_ConstantOp (getMinusOne $g))))>;
31
// Round-trip elimination: an INTT whose input is produced directly by an NTT
// is replaced by the NTT's own input, but only when the two ops' second
// arguments compare equal (presumably the root attribute -- confirm in
// Polynomial.td).
//   intt(ntt(poly, r1), r2)  ==>  poly      if r1 == r2
def INTTAfterNTT
  : Pat<(Polynomial_INTTOp
          (Polynomial_NTTOp $poly, $r1),
          $r2),
        (replaceWithValue $poly),
        [(Equal $r1, $r2)]>;
37
// Round-trip elimination in the opposite direction: an NTT whose input is
// produced directly by an INTT is replaced by the INTT's own input, but only
// when the two ops' second arguments compare equal (presumably the root
// attribute -- confirm in Polynomial.td).
//   ntt(intt(tensor, r1), r2)  ==>  tensor  if r1 == r2
def NTTAfterINTT
  : Pat<(Polynomial_NTTOp
          (Polynomial_INTTOp $tensor, $r1),
          $r2),
        (replaceWithValue $tensor),
        [(Equal $r1, $r2)]>;
43
44#endif  // POLYNOMIAL_CANONICALIZATION
45