xref: /llvm-project/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp (revision d48777ece50c39df553ed779d0771bc9ef6747cf)
1 //===- PolynomialOps.cpp - Polynomial dialect ops ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/ErrorHandling.h"
19 
20 using namespace mlir;
21 using namespace mlir::polynomial;
22 
23 void FromTensorOp::build(OpBuilder &builder, OperationState &result,
24                          Value input, RingAttr ring) {
25   TensorType tensorType = dyn_cast<TensorType>(input.getType());
26   auto bitWidth = tensorType.getElementTypeBitWidth();
27   APInt cmod(1 + bitWidth, 1);
28   cmod = cmod << bitWidth;
29   Type resultType = PolynomialType::get(builder.getContext(), ring);
30   build(builder, result, resultType, input);
31 }
32 
33 LogicalResult FromTensorOp::verify() {
34   ArrayRef<int64_t> tensorShape = getInput().getType().getShape();
35   RingAttr ring = getOutput().getType().getRing();
36   IntPolynomialAttr polyMod = ring.getPolynomialModulus();
37   if (polyMod) {
38     unsigned polyDegree = polyMod.getPolynomial().getDegree();
39     bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
40     if (!compatible) {
41       InFlightDiagnostic diag = emitOpError()
42                                 << "input type " << getInput().getType()
43                                 << " does not match output type "
44                                 << getOutput().getType();
45       diag.attachNote()
46           << "the input type must be a tensor of shape [d] where d "
47              "is at most the degree of the polynomialModulus of "
48              "the output type's ring attribute";
49       return diag;
50     }
51   }
52 
53   unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
54   if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) {
55     InFlightDiagnostic diag = emitOpError()
56                               << "input tensor element type "
57                               << getInput().getType().getElementType()
58                               << " is too large to fit in the coefficients of "
59                               << getOutput().getType();
60     diag.attachNote() << "the input tensor's elements must be rescaled"
61                          " to fit before using from_tensor";
62     return diag;
63   }
64 
65   return success();
66 }
67 
68 LogicalResult ToTensorOp::verify() {
69   ArrayRef<int64_t> tensorShape = getOutput().getType().getShape();
70   IntPolynomialAttr polyMod =
71       getInput().getType().getRing().getPolynomialModulus();
72   if (polyMod) {
73     unsigned polyDegree = polyMod.getPolynomial().getDegree();
74     bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
75 
76     if (compatible)
77       return success();
78 
79     InFlightDiagnostic diag = emitOpError()
80                               << "input type " << getInput().getType()
81                               << " does not match output type "
82                               << getOutput().getType();
83     diag.attachNote()
84         << "the output type must be a tensor of shape [d] where d "
85            "is at most the degree of the polynomialModulus of "
86            "the input type's ring attribute";
87     return diag;
88   }
89 
90   return success();
91 }
92 
93 LogicalResult MulScalarOp::verify() {
94   Type argType = getPolynomial().getType();
95   PolynomialType polyType;
96 
97   if (auto shapedPolyType = dyn_cast<ShapedType>(argType)) {
98     polyType = cast<PolynomialType>(shapedPolyType.getElementType());
99   } else {
100     polyType = cast<PolynomialType>(argType);
101   }
102 
103   Type coefficientType = polyType.getRing().getCoefficientType();
104 
105   if (coefficientType != getScalar().getType())
106     return emitOpError() << "polynomial coefficient type " << coefficientType
107                          << " does not match scalar type "
108                          << getScalar().getType();
109 
110   return success();
111 }
112 
113 /// Test if a value is a primitive nth root of unity modulo cmod.
114 bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
115                                const APInt &cmod) {
116   // The first or subsequent multiplications, may overflow the input bit width,
117   // so scale them up to ensure they do not overflow.
118   unsigned requiredBitWidth =
119       std::max(root.getActiveBits() * 2, cmod.getActiveBits() * 2);
120   APInt r = APInt(root).zextOrTrunc(requiredBitWidth);
121   APInt cmodExt = APInt(cmod).zextOrTrunc(requiredBitWidth);
122   assert(r.ule(cmodExt) && "root must be less than cmod");
123   uint64_t upperBound = n.getZExtValue();
124 
125   APInt a = r;
126   for (size_t k = 1; k < upperBound; k++) {
127     if (a.isOne())
128       return false;
129     a = (a * r).urem(cmodExt);
130   }
131   return a.isOne();
132 }
133 
134 /// Verify that the types involved in an NTT or INTT operation are
135 /// compatible.
136 static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
137                                  RankedTensorType tensorType,
138                                  std::optional<PrimitiveRootAttr> root) {
139   Attribute encoding = tensorType.getEncoding();
140   if (!encoding) {
141     return op->emitOpError()
142            << "expects a ring encoding to be provided to the tensor";
143   }
144   auto encodedRing = dyn_cast<RingAttr>(encoding);
145   if (!encodedRing) {
146     return op->emitOpError()
147            << "the provided tensor encoding is not a ring attribute";
148   }
149 
150   if (encodedRing != ring) {
151     return op->emitOpError()
152            << "encoded ring type " << encodedRing
153            << " is not equivalent to the polynomial ring " << ring;
154   }
155 
156   unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
157   ArrayRef<int64_t> tensorShape = tensorType.getShape();
158   bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
159   if (!compatible) {
160     InFlightDiagnostic diag = op->emitOpError()
161                               << "tensor type " << tensorType
162                               << " does not match output type " << ring;
163     diag.attachNote() << "the tensor must have shape [d] where d "
164                          "is exactly the degree of the polynomialModulus of "
165                          "the polynomial type's ring attribute";
166     return diag;
167   }
168 
169   if (root.has_value()) {
170     APInt rootValue = root.value().getValue().getValue();
171     APInt rootDegree = root.value().getDegree().getValue();
172     APInt cmod = ring.getCoefficientModulus().getValue();
173     if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
174       return op->emitOpError()
175              << "provided root " << rootValue.getZExtValue()
176              << " is not a primitive root "
177              << "of unity mod " << cmod.getZExtValue()
178              << ", with the specified degree " << rootDegree.getZExtValue();
179     }
180   }
181 
182   return success();
183 }
184 
185 LogicalResult NTTOp::verify() {
186   return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
187                      getOutput().getType(), getRoot());
188 }
189 
190 LogicalResult INTTOp::verify() {
191   return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
192                      getInput().getType(), getRoot());
193 }
194 
/// Custom parser for polynomial.constant.
///
/// Accepts, in order of preference:
///   1. `float <float-poly> : <type>` — shorthand for a float polynomial.
///   2. `int <int-poly> : <type>`     — shorthand for an integer polynomial.
///   3. A fully spelled #polynomial.typed_int_polynomial attribute.
///   4. A fully spelled #polynomial.typed_float_polynomial attribute.
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
  // Using the built-in parser.parseAttribute requires the full
  // #polynomial.typed_int_polynomial syntax, which is excessive.
  // Instead we parse a keyword int to signal it's an integer polynomial
  Type type;
  if (succeeded(parser.parseOptionalKeyword("float"))) {
    Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr);
    if (floatPolyAttr) {
      // Shorthand form: the attribute is followed by `: <result-type>`.
      if (parser.parseColon() || parser.parseType(type))
        return failure();
      result.addAttribute("value",
                          TypedFloatPolynomialAttr::get(type, floatPolyAttr));
      result.addTypes(type);
      return success();
    }
    // NOTE(review): if the `float` keyword was consumed but the attribute
    // failed to parse, control falls through to the branches below with the
    // keyword already eaten — those will almost certainly fail too; confirm
    // this fallthrough is the intended error path.
  }

  if (succeeded(parser.parseOptionalKeyword("int"))) {
    Attribute intPolyAttr = IntPolynomialAttr::parse(parser, nullptr);
    if (intPolyAttr) {
      // Shorthand form: the attribute is followed by `: <result-type>`.
      if (parser.parseColon() || parser.parseType(type))
        return failure();

      result.addAttribute("value",
                          TypedIntPolynomialAttr::get(type, intPolyAttr));
      result.addTypes(type);
      return success();
    }
  }

  // In the worst case, still accept the verbose versions.
  // Typed attributes carry their own type, so no trailing `: <type>` is
  // parsed here; the result type comes from the attribute itself.
  TypedIntPolynomialAttr typedIntPolyAttr;
  OptionalParseResult res =
      parser.parseOptionalAttribute<TypedIntPolynomialAttr>(
          typedIntPolyAttr, "value", result.attributes);
  if (res.has_value() && succeeded(res.value())) {
    result.addTypes(typedIntPolyAttr.getType());
    return success();
  }

  TypedFloatPolynomialAttr typedFloatPolyAttr;
  res = parser.parseAttribute<TypedFloatPolynomialAttr>(
      typedFloatPolyAttr, "value", result.attributes);
  if (res.has_value() && succeeded(res.value())) {
    result.addTypes(typedFloatPolyAttr.getType());
    return success();
  }

  return failure();
}
245 
246 void ConstantOp::print(OpAsmPrinter &p) {
247   p << " ";
248   if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(getValue())) {
249     p << "int";
250     intPoly.getValue().print(p);
251   } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(getValue())) {
252     p << "float";
253     floatPoly.getValue().print(p);
254   } else {
255     assert(false && "unexpected attribute type");
256   }
257   p << " : ";
258   p.printType(getOutput().getType());
259 }
260 
261 LogicalResult ConstantOp::inferReturnTypes(
262     MLIRContext *context, std::optional<mlir::Location> location,
263     ConstantOp::Adaptor adaptor,
264     llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
265   Attribute operand = adaptor.getValue();
266   if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(operand)) {
267     inferredReturnTypes.push_back(intPoly.getType());
268   } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(operand)) {
269     inferredReturnTypes.push_back(floatPoly.getType());
270   } else {
271     assert(false && "unexpected attribute type");
272     return failure();
273   }
274   return success();
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // TableGen'd canonicalization patterns
279 //===----------------------------------------------------------------------===//
280 
281 namespace {
282 #include "PolynomialCanonicalization.inc"
283 } // namespace
284 
void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  // SubAsAdd is a TableGen'd pattern from PolynomialCanonicalization.inc
  // (included above); presumably it rewrites sub in terms of add — see the
  // pattern's .td definition for the exact rewrite.
  results.add<SubAsAdd>(context);
}
289 
void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  // NTTAfterINTT is a TableGen'd pattern from PolynomialCanonicalization.inc
  // (included above); by its name it folds ntt(intt(x)) — confirm against the
  // pattern's .td definition.
  results.add<NTTAfterINTT>(context);
}
294 
void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  // INTTAfterNTT is a TableGen'd pattern from PolynomialCanonicalization.inc
  // (included above); by its name it folds intt(ntt(x)) — confirm against the
  // pattern's .td definition.
  results.add<INTTAfterNTT>(context);
}
299