xref: /llvm-project/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp (revision d48777ece50c39df553ed779d0771bc9ef6747cf)
155b6f170SJeremy Kun //===- PolynomialOps.cpp - Polynomial dialect ops ---------------*- C++ -*-===//
255b6f170SJeremy Kun //
355b6f170SJeremy Kun // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
455b6f170SJeremy Kun // See https://llvm.org/LICENSE.txt for license information.
555b6f170SJeremy Kun // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
655b6f170SJeremy Kun //
755b6f170SJeremy Kun //===----------------------------------------------------------------------===//
855b6f170SJeremy Kun 
9145176dcSJeremy Kun #include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
10932bef23SJeremy Kun #include "mlir/Dialect/Arith/IR/Arith.h"
1155b6f170SJeremy Kun #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
12145176dcSJeremy Kun #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
13145176dcSJeremy Kun #include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
14145176dcSJeremy Kun #include "mlir/IR/Builders.h"
15145176dcSJeremy Kun #include "mlir/IR/BuiltinTypes.h"
16145176dcSJeremy Kun #include "mlir/IR/Dialect.h"
17932bef23SJeremy Kun #include "mlir/IR/PatternMatch.h"
18145176dcSJeremy Kun #include "llvm/ADT/APInt.h"
1955b6f170SJeremy Kun 
2055b6f170SJeremy Kun using namespace mlir;
2155b6f170SJeremy Kun using namespace mlir::polynomial;
2255b6f170SJeremy Kun 
23145176dcSJeremy Kun void FromTensorOp::build(OpBuilder &builder, OperationState &result,
24145176dcSJeremy Kun                          Value input, RingAttr ring) {
25145176dcSJeremy Kun   TensorType tensorType = dyn_cast<TensorType>(input.getType());
26145176dcSJeremy Kun   auto bitWidth = tensorType.getElementTypeBitWidth();
27145176dcSJeremy Kun   APInt cmod(1 + bitWidth, 1);
28145176dcSJeremy Kun   cmod = cmod << bitWidth;
29145176dcSJeremy Kun   Type resultType = PolynomialType::get(builder.getContext(), ring);
30145176dcSJeremy Kun   build(builder, result, resultType, input);
31145176dcSJeremy Kun }
32145176dcSJeremy Kun 
33145176dcSJeremy Kun LogicalResult FromTensorOp::verify() {
34145176dcSJeremy Kun   ArrayRef<int64_t> tensorShape = getInput().getType().getShape();
35145176dcSJeremy Kun   RingAttr ring = getOutput().getType().getRing();
36692ae544SJeremy Kun   IntPolynomialAttr polyMod = ring.getPolynomialModulus();
37692ae544SJeremy Kun   if (polyMod) {
38692ae544SJeremy Kun     unsigned polyDegree = polyMod.getPolynomial().getDegree();
39145176dcSJeremy Kun     bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
40145176dcSJeremy Kun     if (!compatible) {
41145176dcSJeremy Kun       InFlightDiagnostic diag = emitOpError()
42145176dcSJeremy Kun                                 << "input type " << getInput().getType()
43145176dcSJeremy Kun                                 << " does not match output type "
44145176dcSJeremy Kun                                 << getOutput().getType();
45692ae544SJeremy Kun       diag.attachNote()
46692ae544SJeremy Kun           << "the input type must be a tensor of shape [d] where d "
47145176dcSJeremy Kun              "is at most the degree of the polynomialModulus of "
48145176dcSJeremy Kun              "the output type's ring attribute";
49145176dcSJeremy Kun       return diag;
50145176dcSJeremy Kun     }
51692ae544SJeremy Kun   }
52145176dcSJeremy Kun 
53145176dcSJeremy Kun   unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
54692ae544SJeremy Kun   if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) {
55145176dcSJeremy Kun     InFlightDiagnostic diag = emitOpError()
56145176dcSJeremy Kun                               << "input tensor element type "
57145176dcSJeremy Kun                               << getInput().getType().getElementType()
58145176dcSJeremy Kun                               << " is too large to fit in the coefficients of "
59145176dcSJeremy Kun                               << getOutput().getType();
60145176dcSJeremy Kun     diag.attachNote() << "the input tensor's elements must be rescaled"
61145176dcSJeremy Kun                          " to fit before using from_tensor";
62145176dcSJeremy Kun     return diag;
63145176dcSJeremy Kun   }
64145176dcSJeremy Kun 
65145176dcSJeremy Kun   return success();
66145176dcSJeremy Kun }
67145176dcSJeremy Kun 
68145176dcSJeremy Kun LogicalResult ToTensorOp::verify() {
69145176dcSJeremy Kun   ArrayRef<int64_t> tensorShape = getOutput().getType().getShape();
70692ae544SJeremy Kun   IntPolynomialAttr polyMod =
71692ae544SJeremy Kun       getInput().getType().getRing().getPolynomialModulus();
72692ae544SJeremy Kun   if (polyMod) {
73692ae544SJeremy Kun     unsigned polyDegree = polyMod.getPolynomial().getDegree();
74145176dcSJeremy Kun     bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
75145176dcSJeremy Kun 
76145176dcSJeremy Kun     if (compatible)
77145176dcSJeremy Kun       return success();
78145176dcSJeremy Kun 
79692ae544SJeremy Kun     InFlightDiagnostic diag = emitOpError()
80692ae544SJeremy Kun                               << "input type " << getInput().getType()
81692ae544SJeremy Kun                               << " does not match output type "
82692ae544SJeremy Kun                               << getOutput().getType();
83692ae544SJeremy Kun     diag.attachNote()
84692ae544SJeremy Kun         << "the output type must be a tensor of shape [d] where d "
85145176dcSJeremy Kun            "is at most the degree of the polynomialModulus of "
86145176dcSJeremy Kun            "the input type's ring attribute";
87145176dcSJeremy Kun     return diag;
88145176dcSJeremy Kun   }
89145176dcSJeremy Kun 
90692ae544SJeremy Kun   return success();
91692ae544SJeremy Kun }
92692ae544SJeremy Kun 
93145176dcSJeremy Kun LogicalResult MulScalarOp::verify() {
94145176dcSJeremy Kun   Type argType = getPolynomial().getType();
95145176dcSJeremy Kun   PolynomialType polyType;
96145176dcSJeremy Kun 
97145176dcSJeremy Kun   if (auto shapedPolyType = dyn_cast<ShapedType>(argType)) {
98145176dcSJeremy Kun     polyType = cast<PolynomialType>(shapedPolyType.getElementType());
99145176dcSJeremy Kun   } else {
100145176dcSJeremy Kun     polyType = cast<PolynomialType>(argType);
101145176dcSJeremy Kun   }
102145176dcSJeremy Kun 
103145176dcSJeremy Kun   Type coefficientType = polyType.getRing().getCoefficientType();
104145176dcSJeremy Kun 
105145176dcSJeremy Kun   if (coefficientType != getScalar().getType())
106145176dcSJeremy Kun     return emitOpError() << "polynomial coefficient type " << coefficientType
107145176dcSJeremy Kun                          << " does not match scalar type "
108145176dcSJeremy Kun                          << getScalar().getType();
109145176dcSJeremy Kun 
110145176dcSJeremy Kun   return success();
111145176dcSJeremy Kun }
112624c9fc8SJeremy Kun 
113624c9fc8SJeremy Kun /// Test if a value is a primitive nth root of unity modulo cmod.
1141f46729aSJeremy Kun bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
115624c9fc8SJeremy Kun                                const APInt &cmod) {
116f2f6569eSJeremy Kun   // The first or subsequent multiplications, may overflow the input bit width,
117f2f6569eSJeremy Kun   // so scale them up to ensure they do not overflow.
118f2f6569eSJeremy Kun   unsigned requiredBitWidth =
119f2f6569eSJeremy Kun       std::max(root.getActiveBits() * 2, cmod.getActiveBits() * 2);
120f2f6569eSJeremy Kun   APInt r = APInt(root).zextOrTrunc(requiredBitWidth);
121f2f6569eSJeremy Kun   APInt cmodExt = APInt(cmod).zextOrTrunc(requiredBitWidth);
122f2f6569eSJeremy Kun   assert(r.ule(cmodExt) && "root must be less than cmod");
123f2f6569eSJeremy Kun   uint64_t upperBound = n.getZExtValue();
124624c9fc8SJeremy Kun 
125624c9fc8SJeremy Kun   APInt a = r;
1261f46729aSJeremy Kun   for (size_t k = 1; k < upperBound; k++) {
127624c9fc8SJeremy Kun     if (a.isOne())
128624c9fc8SJeremy Kun       return false;
129f2f6569eSJeremy Kun     a = (a * r).urem(cmodExt);
130624c9fc8SJeremy Kun   }
131624c9fc8SJeremy Kun   return a.isOne();
132624c9fc8SJeremy Kun }
133624c9fc8SJeremy Kun 
134624c9fc8SJeremy Kun /// Verify that the types involved in an NTT or INTT operation are
135624c9fc8SJeremy Kun /// compatible.
136624c9fc8SJeremy Kun static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
1371f46729aSJeremy Kun                                  RankedTensorType tensorType,
1381f46729aSJeremy Kun                                  std::optional<PrimitiveRootAttr> root) {
139624c9fc8SJeremy Kun   Attribute encoding = tensorType.getEncoding();
140624c9fc8SJeremy Kun   if (!encoding) {
141624c9fc8SJeremy Kun     return op->emitOpError()
142624c9fc8SJeremy Kun            << "expects a ring encoding to be provided to the tensor";
143624c9fc8SJeremy Kun   }
144624c9fc8SJeremy Kun   auto encodedRing = dyn_cast<RingAttr>(encoding);
145624c9fc8SJeremy Kun   if (!encodedRing) {
146624c9fc8SJeremy Kun     return op->emitOpError()
147624c9fc8SJeremy Kun            << "the provided tensor encoding is not a ring attribute";
148624c9fc8SJeremy Kun   }
149624c9fc8SJeremy Kun 
150624c9fc8SJeremy Kun   if (encodedRing != ring) {
151624c9fc8SJeremy Kun     return op->emitOpError()
152624c9fc8SJeremy Kun            << "encoded ring type " << encodedRing
153624c9fc8SJeremy Kun            << " is not equivalent to the polynomial ring " << ring;
154624c9fc8SJeremy Kun   }
155624c9fc8SJeremy Kun 
156624c9fc8SJeremy Kun   unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
157624c9fc8SJeremy Kun   ArrayRef<int64_t> tensorShape = tensorType.getShape();
158624c9fc8SJeremy Kun   bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
159624c9fc8SJeremy Kun   if (!compatible) {
160624c9fc8SJeremy Kun     InFlightDiagnostic diag = op->emitOpError()
161624c9fc8SJeremy Kun                               << "tensor type " << tensorType
162624c9fc8SJeremy Kun                               << " does not match output type " << ring;
163624c9fc8SJeremy Kun     diag.attachNote() << "the tensor must have shape [d] where d "
164624c9fc8SJeremy Kun                          "is exactly the degree of the polynomialModulus of "
165624c9fc8SJeremy Kun                          "the polynomial type's ring attribute";
166624c9fc8SJeremy Kun     return diag;
167624c9fc8SJeremy Kun   }
168624c9fc8SJeremy Kun 
1691f46729aSJeremy Kun   if (root.has_value()) {
1701f46729aSJeremy Kun     APInt rootValue = root.value().getValue().getValue();
1711f46729aSJeremy Kun     APInt rootDegree = root.value().getDegree().getValue();
1721f46729aSJeremy Kun     APInt cmod = ring.getCoefficientModulus().getValue();
1731f46729aSJeremy Kun     if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
174624c9fc8SJeremy Kun       return op->emitOpError()
1751f46729aSJeremy Kun              << "provided root " << rootValue.getZExtValue()
1761f46729aSJeremy Kun              << " is not a primitive root "
1771f46729aSJeremy Kun              << "of unity mod " << cmod.getZExtValue()
1781f46729aSJeremy Kun              << ", with the specified degree " << rootDegree.getZExtValue();
179624c9fc8SJeremy Kun     }
180624c9fc8SJeremy Kun   }
181624c9fc8SJeremy Kun 
182624c9fc8SJeremy Kun   return success();
183624c9fc8SJeremy Kun }
184624c9fc8SJeremy Kun 
185624c9fc8SJeremy Kun LogicalResult NTTOp::verify() {
1861f46729aSJeremy Kun   return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
1871f46729aSJeremy Kun                      getOutput().getType(), getRoot());
188624c9fc8SJeremy Kun }
189624c9fc8SJeremy Kun 
190624c9fc8SJeremy Kun LogicalResult INTTOp::verify() {
1911f46729aSJeremy Kun   return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
1921f46729aSJeremy Kun                      getInput().getType(), getRoot());
193624c9fc8SJeremy Kun }
194932bef23SJeremy Kun 
195ab29203eSJeremy Kun ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
196ab29203eSJeremy Kun   // Using the built-in parser.parseAttribute requires the full
197ab29203eSJeremy Kun   // #polynomial.typed_int_polynomial syntax, which is excessive.
198ab29203eSJeremy Kun   // Instead we parse a keyword int to signal it's an integer polynomial
199ab29203eSJeremy Kun   Type type;
200ab29203eSJeremy Kun   if (succeeded(parser.parseOptionalKeyword("float"))) {
201ab29203eSJeremy Kun     Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr);
202ab29203eSJeremy Kun     if (floatPolyAttr) {
203ab29203eSJeremy Kun       if (parser.parseColon() || parser.parseType(type))
204ab29203eSJeremy Kun         return failure();
205ab29203eSJeremy Kun       result.addAttribute("value",
206ab29203eSJeremy Kun                           TypedFloatPolynomialAttr::get(type, floatPolyAttr));
207ab29203eSJeremy Kun       result.addTypes(type);
208ab29203eSJeremy Kun       return success();
209ab29203eSJeremy Kun     }
210ab29203eSJeremy Kun   }
211ab29203eSJeremy Kun 
212ab29203eSJeremy Kun   if (succeeded(parser.parseOptionalKeyword("int"))) {
213ab29203eSJeremy Kun     Attribute intPolyAttr = IntPolynomialAttr::parse(parser, nullptr);
214ab29203eSJeremy Kun     if (intPolyAttr) {
215ab29203eSJeremy Kun       if (parser.parseColon() || parser.parseType(type))
216ab29203eSJeremy Kun         return failure();
217ab29203eSJeremy Kun 
218ab29203eSJeremy Kun       result.addAttribute("value",
219ab29203eSJeremy Kun                           TypedIntPolynomialAttr::get(type, intPolyAttr));
220ab29203eSJeremy Kun       result.addTypes(type);
221ab29203eSJeremy Kun       return success();
222ab29203eSJeremy Kun     }
223ab29203eSJeremy Kun   }
224ab29203eSJeremy Kun 
225ab29203eSJeremy Kun   // In the worst case, still accept the verbose versions.
226ab29203eSJeremy Kun   TypedIntPolynomialAttr typedIntPolyAttr;
227ab29203eSJeremy Kun   OptionalParseResult res =
228ab29203eSJeremy Kun       parser.parseOptionalAttribute<TypedIntPolynomialAttr>(
229ab29203eSJeremy Kun           typedIntPolyAttr, "value", result.attributes);
230ab29203eSJeremy Kun   if (res.has_value() && succeeded(res.value())) {
231ab29203eSJeremy Kun     result.addTypes(typedIntPolyAttr.getType());
232ab29203eSJeremy Kun     return success();
233ab29203eSJeremy Kun   }
234ab29203eSJeremy Kun 
235ab29203eSJeremy Kun   TypedFloatPolynomialAttr typedFloatPolyAttr;
236ab29203eSJeremy Kun   res = parser.parseAttribute<TypedFloatPolynomialAttr>(
237ab29203eSJeremy Kun       typedFloatPolyAttr, "value", result.attributes);
238ab29203eSJeremy Kun   if (res.has_value() && succeeded(res.value())) {
239ab29203eSJeremy Kun     result.addTypes(typedFloatPolyAttr.getType());
240ab29203eSJeremy Kun     return success();
241ab29203eSJeremy Kun   }
242ab29203eSJeremy Kun 
243ab29203eSJeremy Kun   return failure();
244ab29203eSJeremy Kun }
245ab29203eSJeremy Kun 
246ab29203eSJeremy Kun void ConstantOp::print(OpAsmPrinter &p) {
247ab29203eSJeremy Kun   p << " ";
248ab29203eSJeremy Kun   if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(getValue())) {
249ab29203eSJeremy Kun     p << "int";
250ab29203eSJeremy Kun     intPoly.getValue().print(p);
251ab29203eSJeremy Kun   } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(getValue())) {
252ab29203eSJeremy Kun     p << "float";
253ab29203eSJeremy Kun     floatPoly.getValue().print(p);
254ab29203eSJeremy Kun   } else {
255ab29203eSJeremy Kun     assert(false && "unexpected attribute type");
256ab29203eSJeremy Kun   }
257ab29203eSJeremy Kun   p << " : ";
258ab29203eSJeremy Kun   p.printType(getOutput().getType());
259ab29203eSJeremy Kun }
260ab29203eSJeremy Kun 
261ab29203eSJeremy Kun LogicalResult ConstantOp::inferReturnTypes(
262ab29203eSJeremy Kun     MLIRContext *context, std::optional<mlir::Location> location,
263ab29203eSJeremy Kun     ConstantOp::Adaptor adaptor,
264ab29203eSJeremy Kun     llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
265ab29203eSJeremy Kun   Attribute operand = adaptor.getValue();
266ab29203eSJeremy Kun   if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(operand)) {
267ab29203eSJeremy Kun     inferredReturnTypes.push_back(intPoly.getType());
268ab29203eSJeremy Kun   } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(operand)) {
269ab29203eSJeremy Kun     inferredReturnTypes.push_back(floatPoly.getType());
270ab29203eSJeremy Kun   } else {
271ab29203eSJeremy Kun     assert(false && "unexpected attribute type");
272ab29203eSJeremy Kun     return failure();
273ab29203eSJeremy Kun   }
274ab29203eSJeremy Kun   return success();
275ab29203eSJeremy Kun }
276ab29203eSJeremy Kun 
277932bef23SJeremy Kun //===----------------------------------------------------------------------===//
278932bef23SJeremy Kun // TableGen'd canonicalization patterns
279932bef23SJeremy Kun //===----------------------------------------------------------------------===//
280932bef23SJeremy Kun 
281932bef23SJeremy Kun namespace {
282932bef23SJeremy Kun #include "PolynomialCanonicalization.inc"
283932bef23SJeremy Kun } // namespace
284932bef23SJeremy Kun 
285932bef23SJeremy Kun void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
286932bef23SJeremy Kun                                         MLIRContext *context) {
287932bef23SJeremy Kun   results.add<SubAsAdd>(context);
288932bef23SJeremy Kun }
289932bef23SJeremy Kun 
290932bef23SJeremy Kun void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
291932bef23SJeremy Kun                                         MLIRContext *context) {
292*d48777ecSHongren Zheng   results.add<NTTAfterINTT>(context);
293932bef23SJeremy Kun }
294932bef23SJeremy Kun 
295932bef23SJeremy Kun void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
296932bef23SJeremy Kun                                          MLIRContext *context) {
297*d48777ecSHongren Zheng   results.add<INTTAfterNTT>(context);
298932bef23SJeremy Kun }
299