// Source: /llvm-project/mlir/test/Dialect/Polynomial/ops.mlir (revision f2f6569ecabd54cc7d26bf77424c0b8b674bf14d)
// RUN: mlir-opt %s | FileCheck %s

// This file only tests syntax: every attribute, type, and op below must parse
// and be printed back by mlir-opt.

// Integer-polynomial attributes covering several spellings: a sum of terms,
// a bare constant, a single term with a coefficient, and an indeterminate
// named something other than `x`.
#my_poly = #polynomial.int_polynomial<1 + x**1024>
#my_poly_2 = #polynomial.int_polynomial<2>
#my_poly_3 = #polynomial.int_polynomial<3x>
#my_poly_4 = #polynomial.int_polynomial<t**3 + 4t + 2>
// #ring1 spells out coefficient type, coefficient modulus, and polynomial
// modulus; #ring2 gives only a (floating-point) coefficient type.
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
#ring2 = #polynomial.ring<coefficientType=f32>
#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>

// A polynomial modulus with a negative constant term, a ring built on it, and
// a type alias for polynomials over that ring.
#ideal = #polynomial.int_polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal>
!poly_ty = !polynomial.polynomial<ring=#ring>

// Degree-8 ring used by @test_ntt and @test_intt.
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>

// Degree-65536 ring with an explicitly typed coefficient modulus, plus a
// primitive-root attribute alias, used by @test_ntt_with_overflowing_root.
#ntt_poly_2 = #polynomial.int_polynomial<1 + x**65536>
#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2>
#ntt_ring_2_root = #polynomial.primitive_root<value=283965:i32, degree=131072:i32>
!ntt_poly_ty_2 = !polynomial.polynomial<ring=#ntt_ring_2>

// Each function below exercises the textual form of one or more polynomial
// dialect ops. Most results are intentionally unused: the test only requires
// that the ops parse and print.
module {
  // from_tensor followed by mul on two polynomials over #ring1.
  func.func @test_multiply() -> !polynomial.polynomial<ring=#ring1> {
    %two = arith.constant 2 : i16
    %five = arith.constant 5 : i16
    %coeffs1 = tensor.from_elements %two, %two, %five : tensor<3xi16>
    %coeffs2 = tensor.from_elements %five, %five, %two : tensor<3xi16>

    %poly1 = polynomial.from_tensor %coeffs1 : tensor<3xi16> -> !polynomial.polynomial<ring=#ring1>
    %poly2 = polynomial.from_tensor %coeffs2 : tensor<3xi16> -> !polynomial.polynomial<ring=#ring1>

    %3 = polynomial.mul %poly1, %poly2 : !polynomial.polynomial<ring=#ring1>

    return %3 : !polynomial.polynomial<ring=#ring1>
  }

  // mul_scalar, add, sub, and mul applied to tensors of polynomials rather
  // than to single polynomial values.
  func.func @test_elementwise(%p0 : !polynomial.polynomial<ring=#ring1>, %p1: !polynomial.polynomial<ring=#ring1>) {
    %tp0 = tensor.from_elements %p0, %p1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
    %tp1 = tensor.from_elements %p1, %p0 : tensor<2x!polynomial.polynomial<ring=#ring1>>

    %c = arith.constant 2 : i32
    %mul_const_sclr = polynomial.mul_scalar %tp0, %c : tensor<2x!polynomial.polynomial<ring=#ring1>>, i32

    %add = polynomial.add %tp0, %tp1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
    %sub = polynomial.sub %tp0, %tp1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
    %mul = polynomial.mul %tp0, %tp1 : tensor<2x!polynomial.polynomial<ring=#ring1>>

    return
  }

  // Round trip tensor -> polynomial -> tensor. The label anchors the two
  // checks below to this function; without it the first one is already
  // satisfied by the from_tensor ops printed for @test_multiply.
  // CHECK-LABEL: @test_to_from_tensor
  func.func @test_to_from_tensor(%p0 : !polynomial.polynomial<ring=#ring1>) {
    %two = arith.constant 2 : i16
    %coeffs1 = tensor.from_elements %two, %two : tensor<2xi16>
    // CHECK: from_tensor
    %poly = polynomial.from_tensor %coeffs1 : tensor<2xi16> -> !polynomial.polynomial<ring=#ring1>
    // CHECK: to_tensor
    %tensor = polynomial.to_tensor %poly : !polynomial.polynomial<ring=#ring1> -> tensor<1024xi16>

    return
  }

  // leading_term, which yields an (index, i32) pair for a polynomial over #ring1.
  func.func @test_degree(%p0 : !polynomial.polynomial<ring=#ring1>) {
    %0, %1 = polynomial.leading_term %p0 : !polynomial.polynomial<ring=#ring1> -> (index, i32)
    return
  }

  // monomial built from an i16 coefficient and an index degree.
  func.func @test_monomial() {
    %deg = arith.constant 1023 : index
    %five = arith.constant 5 : i16
    %0 = polynomial.monomial %five, %deg : (i16, index) -> !polynomial.polynomial<ring=#ring1>
    return
  }

  // monic_monomial_mul taking a polynomial and an index operand.
  func.func @test_monic_monomial_mul() {
    %five = arith.constant 5 : index
    %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
    %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<ring=#ring1>, index) -> !polynomial.polynomial<ring=#ring1>
    return
  }

  // polynomial.constant in its short int<...>/float<...> form and in the
  // fully spelled-out attribute form.
  func.func @test_constant() {
    %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
    %1 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
    %2 = polynomial.constant float<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>

    // Test verbose fallbacks
    %verb0 = polynomial.constant #polynomial.typed_int_polynomial<1 + x**2> : !polynomial.polynomial<ring=#ring1>
    %verb2 = polynomial.constant #polynomial.typed_float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
    return
  }

  // ntt with the primitive root written inline; the result tensor carries
  // #ntt_ring as its encoding.
  func.func @test_ntt(%0 : !ntt_poly_ty) {
    %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
    return
  }

  // ntt over the 65536-coefficient ring with the root supplied through the
  // #ntt_ring_2_root alias.
  // NOTE(review): judging by the name, this presumably guards against integer
  // overflow while the root is validated -- confirm against the op verifier.
  func.func @test_ntt_with_overflowing_root(%0 : !ntt_poly_ty_2) {
    %1 = polynomial.ntt %0 {root=#ntt_ring_2_root} : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2>
    return
  }

  // intt, the counterpart of @test_ntt: ring-encoded tensor in, polynomial out.
  func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
    %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
    return
  }
}