155b6f170SJeremy Kun //===- PolynomialOps.cpp - Polynomial dialect ops ---------------*- C++ -*-===// 255b6f170SJeremy Kun // 355b6f170SJeremy Kun // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 455b6f170SJeremy Kun // See https://llvm.org/LICENSE.txt for license information. 555b6f170SJeremy Kun // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 655b6f170SJeremy Kun // 755b6f170SJeremy Kun //===----------------------------------------------------------------------===// 855b6f170SJeremy Kun 9145176dcSJeremy Kun #include "mlir/Dialect/Polynomial/IR/PolynomialOps.h" 10932bef23SJeremy Kun #include "mlir/Dialect/Arith/IR/Arith.h" 1155b6f170SJeremy Kun #include "mlir/Dialect/Polynomial/IR/Polynomial.h" 12145176dcSJeremy Kun #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" 13145176dcSJeremy Kun #include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h" 14145176dcSJeremy Kun #include "mlir/IR/Builders.h" 15145176dcSJeremy Kun #include "mlir/IR/BuiltinTypes.h" 16145176dcSJeremy Kun #include "mlir/IR/Dialect.h" 17932bef23SJeremy Kun #include "mlir/IR/PatternMatch.h" 18145176dcSJeremy Kun #include "llvm/ADT/APInt.h" 1955b6f170SJeremy Kun 2055b6f170SJeremy Kun using namespace mlir; 2155b6f170SJeremy Kun using namespace mlir::polynomial; 2255b6f170SJeremy Kun 23145176dcSJeremy Kun void FromTensorOp::build(OpBuilder &builder, OperationState &result, 24145176dcSJeremy Kun Value input, RingAttr ring) { 25145176dcSJeremy Kun TensorType tensorType = dyn_cast<TensorType>(input.getType()); 26145176dcSJeremy Kun auto bitWidth = tensorType.getElementTypeBitWidth(); 27145176dcSJeremy Kun APInt cmod(1 + bitWidth, 1); 28145176dcSJeremy Kun cmod = cmod << bitWidth; 29145176dcSJeremy Kun Type resultType = PolynomialType::get(builder.getContext(), ring); 30145176dcSJeremy Kun build(builder, result, resultType, input); 31145176dcSJeremy Kun } 32145176dcSJeremy Kun 33145176dcSJeremy Kun LogicalResult FromTensorOp::verify() { 34145176dcSJeremy Kun ArrayRef<int64_t> tensorShape = getInput().getType().getShape(); 35145176dcSJeremy Kun RingAttr ring = getOutput().getType().getRing(); 36692ae544SJeremy Kun IntPolynomialAttr polyMod = ring.getPolynomialModulus(); 37692ae544SJeremy Kun if (polyMod) { 38692ae544SJeremy Kun unsigned polyDegree = polyMod.getPolynomial().getDegree(); 39145176dcSJeremy Kun bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree; 40145176dcSJeremy Kun if (!compatible) { 41145176dcSJeremy Kun InFlightDiagnostic diag = emitOpError() 42145176dcSJeremy Kun << "input type " << getInput().getType() 43145176dcSJeremy Kun << " does not match output type " 44145176dcSJeremy Kun << getOutput().getType(); 45692ae544SJeremy Kun diag.attachNote() 46692ae544SJeremy Kun << "the input type must be a tensor of shape [d] where d " 47145176dcSJeremy Kun "is at most the degree of the polynomialModulus of " 48145176dcSJeremy Kun "the output type's ring attribute"; 49145176dcSJeremy Kun return diag; 50145176dcSJeremy Kun } 51692ae544SJeremy Kun } 52145176dcSJeremy Kun 53145176dcSJeremy Kun unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth(); 54692ae544SJeremy Kun if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) { 55145176dcSJeremy Kun InFlightDiagnostic diag = emitOpError() 56145176dcSJeremy Kun << "input tensor element type " 57145176dcSJeremy Kun << getInput().getType().getElementType() 58145176dcSJeremy Kun << " is too large to fit in the coefficients of " 59145176dcSJeremy Kun << getOutput().getType(); 60145176dcSJeremy Kun diag.attachNote() << "the input tensor's elements must be rescaled" 61145176dcSJeremy Kun " to fit before using from_tensor"; 62145176dcSJeremy Kun return diag; 63145176dcSJeremy Kun } 64145176dcSJeremy Kun 65145176dcSJeremy Kun return success(); 66145176dcSJeremy Kun } 67145176dcSJeremy Kun 68145176dcSJeremy Kun LogicalResult ToTensorOp::verify() { 69145176dcSJeremy Kun ArrayRef<int64_t> tensorShape = getOutput().getType().getShape(); 70692ae544SJeremy Kun IntPolynomialAttr polyMod = 71692ae544SJeremy Kun getInput().getType().getRing().getPolynomialModulus(); 72692ae544SJeremy Kun if (polyMod) { 73692ae544SJeremy Kun unsigned polyDegree = polyMod.getPolynomial().getDegree(); 74145176dcSJeremy Kun bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree; 75145176dcSJeremy Kun 76145176dcSJeremy Kun if (compatible) 77145176dcSJeremy Kun return success(); 78145176dcSJeremy Kun 79692ae544SJeremy Kun InFlightDiagnostic diag = emitOpError() 80692ae544SJeremy Kun << "input type " << getInput().getType() 81692ae544SJeremy Kun << " does not match output type " 82692ae544SJeremy Kun << getOutput().getType(); 83692ae544SJeremy Kun diag.attachNote() 84692ae544SJeremy Kun << "the output type must be a tensor of shape [d] where d " 85145176dcSJeremy Kun "is at most the degree of the polynomialModulus of " 86145176dcSJeremy Kun "the input type's ring attribute"; 87145176dcSJeremy Kun return diag; 88145176dcSJeremy Kun } 89145176dcSJeremy Kun 90692ae544SJeremy Kun return success(); 91692ae544SJeremy Kun } 92692ae544SJeremy Kun 93145176dcSJeremy Kun LogicalResult MulScalarOp::verify() { 94145176dcSJeremy Kun Type argType = getPolynomial().getType(); 95145176dcSJeremy Kun PolynomialType polyType; 96145176dcSJeremy Kun 97145176dcSJeremy Kun if (auto shapedPolyType = dyn_cast<ShapedType>(argType)) { 98145176dcSJeremy Kun polyType = cast<PolynomialType>(shapedPolyType.getElementType()); 99145176dcSJeremy Kun } else { 100145176dcSJeremy Kun polyType = cast<PolynomialType>(argType); 101145176dcSJeremy Kun } 102145176dcSJeremy Kun 103145176dcSJeremy Kun Type coefficientType = polyType.getRing().getCoefficientType(); 104145176dcSJeremy Kun 105145176dcSJeremy Kun if (coefficientType != getScalar().getType()) 106145176dcSJeremy Kun return emitOpError() << "polynomial coefficient type " << coefficientType 107145176dcSJeremy Kun << " does not match scalar type " 108145176dcSJeremy Kun << getScalar().getType(); 109145176dcSJeremy Kun 110145176dcSJeremy Kun return success(); 111145176dcSJeremy Kun } 112624c9fc8SJeremy Kun 113624c9fc8SJeremy Kun /// Test if a value is a primitive nth root of unity modulo cmod. 1141f46729aSJeremy Kun bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n, 115624c9fc8SJeremy Kun const APInt &cmod) { 116f2f6569eSJeremy Kun // The first or subsequent multiplications, may overflow the input bit width, 117f2f6569eSJeremy Kun // so scale them up to ensure they do not overflow. 118f2f6569eSJeremy Kun unsigned requiredBitWidth = 119f2f6569eSJeremy Kun std::max(root.getActiveBits() * 2, cmod.getActiveBits() * 2); 120f2f6569eSJeremy Kun APInt r = APInt(root).zextOrTrunc(requiredBitWidth); 121f2f6569eSJeremy Kun APInt cmodExt = APInt(cmod).zextOrTrunc(requiredBitWidth); 122f2f6569eSJeremy Kun assert(r.ule(cmodExt) && "root must be less than cmod"); 123f2f6569eSJeremy Kun uint64_t upperBound = n.getZExtValue(); 124624c9fc8SJeremy Kun 125624c9fc8SJeremy Kun APInt a = r; 1261f46729aSJeremy Kun for (size_t k = 1; k < upperBound; k++) { 127624c9fc8SJeremy Kun if (a.isOne()) 128624c9fc8SJeremy Kun return false; 129f2f6569eSJeremy Kun a = (a * r).urem(cmodExt); 130624c9fc8SJeremy Kun } 131624c9fc8SJeremy Kun return a.isOne(); 132624c9fc8SJeremy Kun } 133624c9fc8SJeremy Kun 134624c9fc8SJeremy Kun /// Verify that the types involved in an NTT or INTT operation are 135624c9fc8SJeremy Kun /// compatible. 136624c9fc8SJeremy Kun static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, 1371f46729aSJeremy Kun RankedTensorType tensorType, 1381f46729aSJeremy Kun std::optional<PrimitiveRootAttr> root) { 139624c9fc8SJeremy Kun Attribute encoding = tensorType.getEncoding(); 140624c9fc8SJeremy Kun if (!encoding) { 141624c9fc8SJeremy Kun return op->emitOpError() 142624c9fc8SJeremy Kun << "expects a ring encoding to be provided to the tensor"; 143624c9fc8SJeremy Kun } 144624c9fc8SJeremy Kun auto encodedRing = dyn_cast<RingAttr>(encoding); 145624c9fc8SJeremy Kun if (!encodedRing) { 146624c9fc8SJeremy Kun return op->emitOpError() 147624c9fc8SJeremy Kun << "the provided tensor encoding is not a ring attribute"; 148624c9fc8SJeremy Kun } 149624c9fc8SJeremy Kun 150624c9fc8SJeremy Kun if (encodedRing != ring) { 151624c9fc8SJeremy Kun return op->emitOpError() 152624c9fc8SJeremy Kun << "encoded ring type " << encodedRing 153624c9fc8SJeremy Kun << " is not equivalent to the polynomial ring " << ring; 154624c9fc8SJeremy Kun } 155624c9fc8SJeremy Kun 156624c9fc8SJeremy Kun unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree(); 157624c9fc8SJeremy Kun ArrayRef<int64_t> tensorShape = tensorType.getShape(); 158624c9fc8SJeremy Kun bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree; 159624c9fc8SJeremy Kun if (!compatible) { 160624c9fc8SJeremy Kun InFlightDiagnostic diag = op->emitOpError() 161624c9fc8SJeremy Kun << "tensor type " << tensorType 162624c9fc8SJeremy Kun << " does not match output type " << ring; 163624c9fc8SJeremy Kun diag.attachNote() << "the tensor must have shape [d] where d " 164624c9fc8SJeremy Kun "is exactly the degree of the polynomialModulus of " 165624c9fc8SJeremy Kun "the polynomial type's ring attribute"; 166624c9fc8SJeremy Kun return diag; 167624c9fc8SJeremy Kun } 168624c9fc8SJeremy Kun 1691f46729aSJeremy Kun if (root.has_value()) { 1701f46729aSJeremy Kun APInt rootValue = root.value().getValue().getValue(); 1711f46729aSJeremy Kun APInt rootDegree = root.value().getDegree().getValue(); 1721f46729aSJeremy Kun APInt cmod = ring.getCoefficientModulus().getValue(); 1731f46729aSJeremy Kun if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) { 174624c9fc8SJeremy Kun return op->emitOpError() 1751f46729aSJeremy Kun << "provided root " << rootValue.getZExtValue() 1761f46729aSJeremy Kun << " is not a primitive root " 1771f46729aSJeremy Kun << "of unity mod " << cmod.getZExtValue() 1781f46729aSJeremy Kun << ", with the specified degree " << rootDegree.getZExtValue(); 179624c9fc8SJeremy Kun } 180624c9fc8SJeremy Kun } 181624c9fc8SJeremy Kun 182624c9fc8SJeremy Kun return success(); 183624c9fc8SJeremy Kun } 184624c9fc8SJeremy Kun 185624c9fc8SJeremy Kun LogicalResult NTTOp::verify() { 1861f46729aSJeremy Kun return verifyNTTOp(this->getOperation(), getInput().getType().getRing(), 1871f46729aSJeremy Kun getOutput().getType(), getRoot()); 188624c9fc8SJeremy Kun } 189624c9fc8SJeremy Kun 190624c9fc8SJeremy Kun LogicalResult INTTOp::verify() { 1911f46729aSJeremy Kun return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(), 1921f46729aSJeremy Kun getInput().getType(), getRoot()); 193624c9fc8SJeremy Kun } 194932bef23SJeremy Kun 195ab29203eSJeremy Kun ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { 196ab29203eSJeremy Kun // Using the built-in parser.parseAttribute requires the full 197ab29203eSJeremy Kun // #polynomial.typed_int_polynomial syntax, which is excessive. 198ab29203eSJeremy Kun // Instead we parse a keyword int to signal it's an integer polynomial 199ab29203eSJeremy Kun Type type; 200ab29203eSJeremy Kun if (succeeded(parser.parseOptionalKeyword("float"))) { 201ab29203eSJeremy Kun Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr); 202ab29203eSJeremy Kun if (floatPolyAttr) { 203ab29203eSJeremy Kun if (parser.parseColon() || parser.parseType(type)) 204ab29203eSJeremy Kun return failure(); 205ab29203eSJeremy Kun result.addAttribute("value", 206ab29203eSJeremy Kun TypedFloatPolynomialAttr::get(type, floatPolyAttr)); 207ab29203eSJeremy Kun result.addTypes(type); 208ab29203eSJeremy Kun return success(); 209ab29203eSJeremy Kun } 210ab29203eSJeremy Kun } 211ab29203eSJeremy Kun 212ab29203eSJeremy Kun if (succeeded(parser.parseOptionalKeyword("int"))) { 213ab29203eSJeremy Kun Attribute intPolyAttr = IntPolynomialAttr::parse(parser, nullptr); 214ab29203eSJeremy Kun if (intPolyAttr) { 215ab29203eSJeremy Kun if (parser.parseColon() || parser.parseType(type)) 216ab29203eSJeremy Kun return failure(); 217ab29203eSJeremy Kun 218ab29203eSJeremy Kun result.addAttribute("value", 219ab29203eSJeremy Kun TypedIntPolynomialAttr::get(type, intPolyAttr)); 220ab29203eSJeremy Kun result.addTypes(type); 221ab29203eSJeremy Kun return success(); 222ab29203eSJeremy Kun } 223ab29203eSJeremy Kun } 224ab29203eSJeremy Kun 225ab29203eSJeremy Kun // In the worst case, still accept the verbose versions. 226ab29203eSJeremy Kun TypedIntPolynomialAttr typedIntPolyAttr; 227ab29203eSJeremy Kun OptionalParseResult res = 228ab29203eSJeremy Kun parser.parseOptionalAttribute<TypedIntPolynomialAttr>( 229ab29203eSJeremy Kun typedIntPolyAttr, "value", result.attributes); 230ab29203eSJeremy Kun if (res.has_value() && succeeded(res.value())) { 231ab29203eSJeremy Kun result.addTypes(typedIntPolyAttr.getType()); 232ab29203eSJeremy Kun return success(); 233ab29203eSJeremy Kun } 234ab29203eSJeremy Kun 235ab29203eSJeremy Kun TypedFloatPolynomialAttr typedFloatPolyAttr; 236ab29203eSJeremy Kun res = parser.parseAttribute<TypedFloatPolynomialAttr>( 237ab29203eSJeremy Kun typedFloatPolyAttr, "value", result.attributes); 238ab29203eSJeremy Kun if (res.has_value() && succeeded(res.value())) { 239ab29203eSJeremy Kun result.addTypes(typedFloatPolyAttr.getType()); 240ab29203eSJeremy Kun return success(); 241ab29203eSJeremy Kun } 242ab29203eSJeremy Kun 243ab29203eSJeremy Kun return failure(); 244ab29203eSJeremy Kun } 245ab29203eSJeremy Kun 246ab29203eSJeremy Kun void ConstantOp::print(OpAsmPrinter &p) { 247ab29203eSJeremy Kun p << " "; 248ab29203eSJeremy Kun if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(getValue())) { 249ab29203eSJeremy Kun p << "int"; 250ab29203eSJeremy Kun intPoly.getValue().print(p); 251ab29203eSJeremy Kun } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(getValue())) { 252ab29203eSJeremy Kun p << "float"; 253ab29203eSJeremy Kun floatPoly.getValue().print(p); 254ab29203eSJeremy Kun } else { 255ab29203eSJeremy Kun assert(false && "unexpected attribute type"); 256ab29203eSJeremy Kun } 257ab29203eSJeremy Kun p << " : "; 258ab29203eSJeremy Kun p.printType(getOutput().getType()); 259ab29203eSJeremy Kun } 260ab29203eSJeremy Kun 261ab29203eSJeremy Kun LogicalResult ConstantOp::inferReturnTypes( 262ab29203eSJeremy Kun MLIRContext *context, std::optional<mlir::Location> location, 263ab29203eSJeremy Kun ConstantOp::Adaptor adaptor, 264ab29203eSJeremy Kun llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) { 265ab29203eSJeremy Kun Attribute operand = adaptor.getValue(); 266ab29203eSJeremy Kun if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(operand)) { 267ab29203eSJeremy Kun inferredReturnTypes.push_back(intPoly.getType()); 268ab29203eSJeremy Kun } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(operand)) { 269ab29203eSJeremy Kun inferredReturnTypes.push_back(floatPoly.getType()); 270ab29203eSJeremy Kun } else { 271ab29203eSJeremy Kun assert(false && "unexpected attribute type"); 272ab29203eSJeremy Kun return failure(); 273ab29203eSJeremy Kun } 274ab29203eSJeremy Kun return success(); 275ab29203eSJeremy Kun } 276ab29203eSJeremy Kun 277932bef23SJeremy Kun //===----------------------------------------------------------------------===// 278932bef23SJeremy Kun // TableGen'd canonicalization patterns 279932bef23SJeremy Kun //===----------------------------------------------------------------------===// 280932bef23SJeremy Kun 281932bef23SJeremy Kun namespace { 282932bef23SJeremy Kun #include "PolynomialCanonicalization.inc" 283932bef23SJeremy Kun } // namespace 284932bef23SJeremy Kun 285932bef23SJeremy Kun void SubOp::getCanonicalizationPatterns(RewritePatternSet &results, 286932bef23SJeremy Kun MLIRContext *context) { 287932bef23SJeremy Kun results.add<SubAsAdd>(context); 288932bef23SJeremy Kun } 289932bef23SJeremy Kun 290932bef23SJeremy Kun void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results, 291932bef23SJeremy Kun MLIRContext *context) { 292*d48777ecSHongren Zheng results.add<NTTAfterINTT>(context); 293932bef23SJeremy Kun } 294932bef23SJeremy Kun 295932bef23SJeremy Kun void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results, 296932bef23SJeremy Kun MLIRContext *context) { 297*d48777ecSHongren Zheng results.add<INTTAfterNTT>(context); 298932bef23SJeremy Kun } 299