1 //===- PolynomialOps.cpp - Polynomial dialect ops ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Polynomial/IR/PolynomialOps.h" 10 #include "mlir/Dialect/Arith/IR/Arith.h" 11 #include "mlir/Dialect/Polynomial/IR/Polynomial.h" 12 #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" 13 #include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "llvm/ADT/APInt.h" 19 20 using namespace mlir; 21 using namespace mlir::polynomial; 22 23 void FromTensorOp::build(OpBuilder &builder, OperationState &result, 24 Value input, RingAttr ring) { 25 TensorType tensorType = dyn_cast<TensorType>(input.getType()); 26 auto bitWidth = tensorType.getElementTypeBitWidth(); 27 APInt cmod(1 + bitWidth, 1); 28 cmod = cmod << bitWidth; 29 Type resultType = PolynomialType::get(builder.getContext(), ring); 30 build(builder, result, resultType, input); 31 } 32 33 LogicalResult FromTensorOp::verify() { 34 ArrayRef<int64_t> tensorShape = getInput().getType().getShape(); 35 RingAttr ring = getOutput().getType().getRing(); 36 IntPolynomialAttr polyMod = ring.getPolynomialModulus(); 37 if (polyMod) { 38 unsigned polyDegree = polyMod.getPolynomial().getDegree(); 39 bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree; 40 if (!compatible) { 41 InFlightDiagnostic diag = emitOpError() 42 << "input type " << getInput().getType() 43 << " does not match output type " 44 << getOutput().getType(); 45 diag.attachNote() 46 << "the input type must be a tensor of shape [d] where d " 47 "is at most the degree of the polynomialModulus of " 48 "the output type's ring attribute"; 49 return diag; 50 } 51 } 52 53 unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth(); 54 if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) { 55 InFlightDiagnostic diag = emitOpError() 56 << "input tensor element type " 57 << getInput().getType().getElementType() 58 << " is too large to fit in the coefficients of " 59 << getOutput().getType(); 60 diag.attachNote() << "the input tensor's elements must be rescaled" 61 " to fit before using from_tensor"; 62 return diag; 63 } 64 65 return success(); 66 } 67 68 LogicalResult ToTensorOp::verify() { 69 ArrayRef<int64_t> tensorShape = getOutput().getType().getShape(); 70 IntPolynomialAttr polyMod = 71 getInput().getType().getRing().getPolynomialModulus(); 72 if (polyMod) { 73 unsigned polyDegree = polyMod.getPolynomial().getDegree(); 74 bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree; 75 76 if (compatible) 77 return success(); 78 79 InFlightDiagnostic diag = emitOpError() 80 << "input type " << getInput().getType() 81 << " does not match output type " 82 << getOutput().getType(); 83 diag.attachNote() 84 << "the output type must be a tensor of shape [d] where d " 85 "is at most the degree of the polynomialModulus of " 86 "the input type's ring attribute"; 87 return diag; 88 } 89 90 return success(); 91 } 92 93 LogicalResult MulScalarOp::verify() { 94 Type argType = getPolynomial().getType(); 95 PolynomialType polyType; 96 97 if (auto shapedPolyType = dyn_cast<ShapedType>(argType)) { 98 polyType = cast<PolynomialType>(shapedPolyType.getElementType()); 99 } else { 100 polyType = cast<PolynomialType>(argType); 101 } 102 103 Type coefficientType = polyType.getRing().getCoefficientType(); 104 105 if (coefficientType != getScalar().getType()) 106 return emitOpError() << "polynomial coefficient type " << coefficientType 107 << " does not match scalar type " 108 << getScalar().getType(); 109 110 return success(); 111 } 112 113 /// Test if a value is a primitive nth root of unity modulo cmod. 114 bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n, 115 const APInt &cmod) { 116 // The first or subsequent multiplications, may overflow the input bit width, 117 // so scale them up to ensure they do not overflow. 118 unsigned requiredBitWidth = 119 std::max(root.getActiveBits() * 2, cmod.getActiveBits() * 2); 120 APInt r = APInt(root).zextOrTrunc(requiredBitWidth); 121 APInt cmodExt = APInt(cmod).zextOrTrunc(requiredBitWidth); 122 assert(r.ule(cmodExt) && "root must be less than cmod"); 123 uint64_t upperBound = n.getZExtValue(); 124 125 APInt a = r; 126 for (size_t k = 1; k < upperBound; k++) { 127 if (a.isOne()) 128 return false; 129 a = (a * r).urem(cmodExt); 130 } 131 return a.isOne(); 132 } 133 134 /// Verify that the types involved in an NTT or INTT operation are 135 /// compatible. 136 static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, 137 RankedTensorType tensorType, 138 std::optional<PrimitiveRootAttr> root) { 139 Attribute encoding = tensorType.getEncoding(); 140 if (!encoding) { 141 return op->emitOpError() 142 << "expects a ring encoding to be provided to the tensor"; 143 } 144 auto encodedRing = dyn_cast<RingAttr>(encoding); 145 if (!encodedRing) { 146 return op->emitOpError() 147 << "the provided tensor encoding is not a ring attribute"; 148 } 149 150 if (encodedRing != ring) { 151 return op->emitOpError() 152 << "encoded ring type " << encodedRing 153 << " is not equivalent to the polynomial ring " << ring; 154 } 155 156 unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree(); 157 ArrayRef<int64_t> tensorShape = tensorType.getShape(); 158 bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree; 159 if (!compatible) { 160 InFlightDiagnostic diag = op->emitOpError() 161 << "tensor type " << tensorType 162 << " does not match output type " << ring; 163 diag.attachNote() << "the tensor must have shape [d] where d " 164 "is exactly the degree of the polynomialModulus of " 165 "the polynomial type's ring attribute"; 166 return diag; 167 } 168 169 if (root.has_value()) { 170 APInt rootValue = root.value().getValue().getValue(); 171 APInt rootDegree = root.value().getDegree().getValue(); 172 APInt cmod = ring.getCoefficientModulus().getValue(); 173 if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) { 174 return op->emitOpError() 175 << "provided root " << rootValue.getZExtValue() 176 << " is not a primitive root " 177 << "of unity mod " << cmod.getZExtValue() 178 << ", with the specified degree " << rootDegree.getZExtValue(); 179 } 180 } 181 182 return success(); 183 } 184 185 LogicalResult NTTOp::verify() { 186 return verifyNTTOp(this->getOperation(), getInput().getType().getRing(), 187 getOutput().getType(), getRoot()); 188 } 189 190 LogicalResult INTTOp::verify() { 191 return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(), 192 getInput().getType(), getRoot()); 193 } 194 195 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { 196 // Using the built-in parser.parseAttribute requires the full 197 // #polynomial.typed_int_polynomial syntax, which is excessive. 198 // Instead we parse a keyword int to signal it's an integer polynomial 199 Type type; 200 if (succeeded(parser.parseOptionalKeyword("float"))) { 201 Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr); 202 if (floatPolyAttr) { 203 if (parser.parseColon() || parser.parseType(type)) 204 return failure(); 205 result.addAttribute("value", 206 TypedFloatPolynomialAttr::get(type, floatPolyAttr)); 207 result.addTypes(type); 208 return success(); 209 } 210 } 211 212 if (succeeded(parser.parseOptionalKeyword("int"))) { 213 Attribute intPolyAttr = IntPolynomialAttr::parse(parser, nullptr); 214 if (intPolyAttr) { 215 if (parser.parseColon() || parser.parseType(type)) 216 return failure(); 217 218 result.addAttribute("value", 219 TypedIntPolynomialAttr::get(type, intPolyAttr)); 220 result.addTypes(type); 221 return success(); 222 } 223 } 224 225 // In the worst case, still accept the verbose versions. 226 TypedIntPolynomialAttr typedIntPolyAttr; 227 OptionalParseResult res = 228 parser.parseOptionalAttribute<TypedIntPolynomialAttr>( 229 typedIntPolyAttr, "value", result.attributes); 230 if (res.has_value() && succeeded(res.value())) { 231 result.addTypes(typedIntPolyAttr.getType()); 232 return success(); 233 } 234 235 TypedFloatPolynomialAttr typedFloatPolyAttr; 236 res = parser.parseAttribute<TypedFloatPolynomialAttr>( 237 typedFloatPolyAttr, "value", result.attributes); 238 if (res.has_value() && succeeded(res.value())) { 239 result.addTypes(typedFloatPolyAttr.getType()); 240 return success(); 241 } 242 243 return failure(); 244 } 245 246 void ConstantOp::print(OpAsmPrinter &p) { 247 p << " "; 248 if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(getValue())) { 249 p << "int"; 250 intPoly.getValue().print(p); 251 } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(getValue())) { 252 p << "float"; 253 floatPoly.getValue().print(p); 254 } else { 255 assert(false && "unexpected attribute type"); 256 } 257 p << " : "; 258 p.printType(getOutput().getType()); 259 } 260 261 LogicalResult ConstantOp::inferReturnTypes( 262 MLIRContext *context, std::optional<mlir::Location> location, 263 ConstantOp::Adaptor adaptor, 264 llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) { 265 Attribute operand = adaptor.getValue(); 266 if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(operand)) { 267 inferredReturnTypes.push_back(intPoly.getType()); 268 } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(operand)) { 269 inferredReturnTypes.push_back(floatPoly.getType()); 270 } else { 271 assert(false && "unexpected attribute type"); 272 return failure(); 273 } 274 return success(); 275 } 276 277 //===----------------------------------------------------------------------===// 278 // TableGen'd canonicalization patterns 279 //===----------------------------------------------------------------------===// 280 281 namespace { 282 #include "PolynomialCanonicalization.inc" 283 } // namespace 284 285 void SubOp::getCanonicalizationPatterns(RewritePatternSet &results, 286 MLIRContext *context) { 287 results.add<SubAsAdd>(context); 288 } 289 290 void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results, 291 MLIRContext *context) { 292 results.add<NTTAfterINTT>(context); 293 } 294 295 void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results, 296 MLIRContext *context) { 297 results.add<INTTAfterNTT>(context); 298 } 299