// RUN: mlir-opt -canonicalize %s | FileCheck %s

// Canonicalization tests for the polynomial dialect:
//   - an ntt/intt round trip using the same root (in either order) is expected
//     to fold away entirely;
//   - polynomial.sub is expected to be rewritten as an add of the negated
//     (scalar-multiplied by -1) second operand.

#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
!tensor_ty = tensor<8xi32, #ntt_ring>

// intt(ntt(p)) with a shared root should cancel, so the add reads the
// function argument directly.
// CHECK-LABEL: @test_canonicalize_intt_after_ntt
// CHECK-SAME: (%[[P:.*]]: [[T:.*]]) -> [[T]]
func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty {
  // CHECK-NOT: polynomial.ntt
  // CHECK-NOT: polynomial.intt
  // CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]]
  %t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
  %p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
  %p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
  // CHECK: return %[[RESULT]] : [[T]]
  return %p2 : !ntt_poly_ty
}

// ntt(intt(t)) with a shared root should cancel, so the addi reads the
// function argument directly.
// CHECK-LABEL: @test_canonicalize_ntt_after_intt
// CHECK-SAME: (%[[X:.*]]: [[T:.*]]) -> [[T]]
func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
  // CHECK-NOT: polynomial.intt
  // CHECK-NOT: polynomial.ntt
  // CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
  %p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
  %t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
  %t2 = arith.addi %t1, %t1 : !tensor_ty
  // CHECK: return %[[RESULT]] : [[T]]
  return %t2 : !tensor_ty
}

#cycl_2048 = #polynomial.int_polynomial<1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#cycl_2048>
!sub_ty = !polynomial.polynomial<ring=#ring>

// sub(p0, p1) should become add(p0, mul_scalar(p1, -1)), and that add must be
// the value the function returns; no polynomial.sub may remain.
// CHECK-LABEL: @test_canonicalize_sub
// CHECK-SAME: (%[[p0:.*]]: [[T:.*]], %[[p1:.*]]: [[T]]) -> [[T]] {
func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty {
  // CHECK-NOT: polynomial.sub
  // CHECK: %[[minus_one:.+]] = arith.constant -1 : i32
  // CHECK: %[[p1neg:.+]] = polynomial.mul_scalar %[[p1]], %[[minus_one]]
  // CHECK: %[[ADD:.+]] = polynomial.add %[[p0]], %[[p1neg]]
  // CHECK-NOT: polynomial.sub
  %0 = polynomial.sub %poly0, %poly1 : !sub_ty
  // CHECK: return %[[ADD]] : [[T]]
  return %0 : !sub_ty
}