1 // Written in the D programming language. 2 3 /** 4 This module is a port of a growing fragment of the $(D_PARAM numeric) 5 header in Alexander Stepanov's $(LINK2 https://en.wikipedia.org/wiki/Standard_Template_Library, 6 Standard Template Library), with a few additions. 7 8 Macros: 9 Copyright: Copyright Andrei Alexandrescu 2008 - 2009. 10 License: $(HTTP www.boost.org/LICENSE_1_0.txt, Boost License 1.0). 11 Authors: $(HTTP erdani.org, Andrei Alexandrescu), 12 Don Clugston, Robert Jacques, Ilya Yaroshenko 13 Source: $(PHOBOSSRC std/numeric.d) 14 */ 15 /* 16 Copyright Andrei Alexandrescu 2008 - 2009. 17 Distributed under the Boost Software License, Version 1.0. 18 (See accompanying file LICENSE_1_0.txt or copy at 19 http://www.boost.org/LICENSE_1_0.txt) 20 */ 21 module std.numeric; 22 23 import std.complex; 24 import std.math; 25 import core.math : fabs, ldexp, sin, sqrt; 26 import std.range.primitives; 27 import std.traits; 28 import std.typecons; 29 30 /// Format flags for CustomFloat. 31 public enum CustomFloatFlags 32 { 33 /// Adds a sign bit to allow for signed numbers. 34 signed = 1, 35 36 /** 37 * Store values in normalized form by default. The actual precision of the 38 * significand is extended by 1 bit by assuming an implicit leading bit of 1 39 * instead of 0. i.e. `1.nnnn` instead of `0.nnnn`. 40 * True for all $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEE754) types 41 */ 42 storeNormalized = 2, 43 44 /** 45 * Stores the significand in $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers, 46 * IEEE754 denormalized) form when the exponent is 0. Required to express the value 0. 47 */ 48 allowDenorm = 4, 49 50 /** 51 * Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Positive_and_negative_infinity, 52 * IEEE754 _infinity) values. 53 */ 54 infinity = 8, 55 56 /// Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/NaN, IEEE754 Not a Number) values. 57 nan = 16, 58 59 /** 60 * If set, select an exponent bias such that max_exp = 1. 61 * i.e. so that the maximum value is >= 1.0 and < 2.0. 62 * Ignored if the exponent bias is manually specified. 63 */ 64 probability = 32, 65 66 /// If set, unsigned custom floats are assumed to be negative. 67 negativeUnsigned = 64, 68 69 /**If set, 0 is the only allowed $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers, 70 * IEEE754 denormalized) number. 71 * Requires allowDenorm and storeNormalized. 72 */ 73 allowDenormZeroOnly = 128 | allowDenorm | storeNormalized, 74 75 /// Include _all of the $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) options. 76 ieee = signed | storeNormalized | allowDenorm | infinity | nan , 77 78 /// Include none of the above options. 79 none = 0 80 } 81 82 private template CustomFloatParams(uint bits) 83 { 84 enum CustomFloatFlags flags = CustomFloatFlags.ieee 85 ^ ((bits == 80) ? CustomFloatFlags.storeNormalized : CustomFloatFlags.none); 86 static if (bits == 8) alias CustomFloatParams = CustomFloatParams!( 4, 3, flags); 87 static if (bits == 16) alias CustomFloatParams = CustomFloatParams!(10, 5, flags); 88 static if (bits == 32) alias CustomFloatParams = CustomFloatParams!(23, 8, flags); 89 static if (bits == 64) alias CustomFloatParams = CustomFloatParams!(52, 11, flags); 90 static if (bits == 80) alias CustomFloatParams = CustomFloatParams!(64, 15, flags); 91 } 92 93 private template CustomFloatParams(uint precision, uint exponentWidth, CustomFloatFlags flags) 94 { 95 import std.meta : AliasSeq; 96 alias CustomFloatParams = 97 AliasSeq!( 98 precision, 99 exponentWidth, 100 flags, 101 (1 << (exponentWidth - ((flags & flags.probability) == 0))) 102 - ((flags & (flags.nan | flags.infinity)) != 0) - ((flags & flags.probability) != 0) 103 ); // ((flags & CustomFloatFlags.probability) == 0) 104 } 105 106 /** 107 * Allows user code to define custom floating-point formats. These formats are 108 * for storage only; all operations on them are performed by first implicitly 109 * extracting them to `real` first. After the operation is completed the 110 * result can be stored in a custom floating-point value via assignment. 111 */ 112 template CustomFloat(uint bits) 113 if (bits == 8 || bits == 16 || bits == 32 || bits == 64 || bits == 80) 114 { 115 alias CustomFloat = CustomFloat!(CustomFloatParams!(bits)); 116 } 117 118 /// ditto 119 template CustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags = CustomFloatFlags.ieee) 120 if (((flags & flags.signed) + precision + exponentWidth) % 8 == 0 && precision + exponentWidth > 0) 121 { 122 alias CustomFloat = CustomFloat!(CustomFloatParams!(precision, exponentWidth, flags)); 123 } 124 125 /// 126 @safe unittest 127 { 128 import std.math.trigonometry : sin, cos; 129 130 // Define a 16-bit floating point values 131 CustomFloat!16 x; // Using the number of bits 132 CustomFloat!(10, 5) y; // Using the precision and exponent width 133 CustomFloat!(10, 5,CustomFloatFlags.ieee) z; // Using the precision, exponent width and format flags 134 CustomFloat!(10, 5,CustomFloatFlags.ieee, 15) w; // Using the precision, exponent width, format flags and exponent offset bias 135 136 // Use the 16-bit floats mostly like normal numbers 137 w = x*y - 1; 138 139 // Functions calls require conversion 140 z = sin(+x) + cos(+y); // Use unary plus to concisely convert to a real 141 z = sin(x.get!float) + cos(y.get!float); // Or use get!T 142 z = sin(cast(float) x) + cos(cast(float) y); // Or use cast(T) to explicitly convert 143 144 // Define a 8-bit custom float for storing probabilities 145 alias Probability = CustomFloat!(4, 4, CustomFloatFlags.ieee^CustomFloatFlags.probability^CustomFloatFlags.signed ); 146 auto p = Probability(0.5); 147 } 148 149 // Facilitate converting numeric types to custom float 150 private union ToBinary(F) 151 if (is(typeof(CustomFloatParams!(F.sizeof*8))) || is(F == real)) 152 { 153 F set; 154 155 // If on Linux or Mac, where 80-bit reals are padded, ignore the 156 // padding. 157 import std.algorithm.comparison : min; 158 CustomFloat!(CustomFloatParams!(min(F.sizeof*8, 80))) get; 159 160 // Convert F to the correct binary type. 161 static typeof(get) opCall(F value) 162 { 163 ToBinary r; 164 r.set = value; 165 return r.get; 166 } 167 alias get this; 168 } 169 170 /// ditto 171 struct CustomFloat(uint precision, // fraction bits (23 for float) 172 uint exponentWidth, // exponent bits (8 for float) Exponent width 173 CustomFloatFlags flags, 174 uint bias) 175 if (isCorrectCustomFloat(precision, exponentWidth, flags)) 176 { 177 import std.bitmanip : bitfields; 178 import std.meta : staticIndexOf; 179 private: 180 // get the correct unsigned bitfield type to support > 32 bits 181 template uType(uint bits) 182 { 183 static if (bits <= size_t.sizeof*8) alias uType = size_t; 184 else alias uType = ulong ; 185 } 186 187 // get the correct signed bitfield type to support > 32 bits 188 template sType(uint bits) 189 { 190 static if (bits <= ptrdiff_t.sizeof*8-1) alias sType = ptrdiff_t; 191 else alias sType = long; 192 } 193 194 alias T_sig = uType!precision; 195 alias T_exp = uType!exponentWidth; 196 alias T_signed_exp = sType!exponentWidth; 197 198 alias Flags = CustomFloatFlags; 199 200 // Perform IEEE rounding with round to nearest detection 201 void roundedShift(T,U)(ref T sig, U shift) 202 { 203 if (shift >= T.sizeof*8) 204 { 205 // avoid illegal shift 206 sig = 0; 207 } 208 else if (sig << (T.sizeof*8 - shift) == cast(T) 1uL << (T.sizeof*8 - 1)) 209 { 210 // round to even 211 sig >>= shift; 212 sig += sig & 1; 213 } 214 else 215 { 216 sig >>= shift - 1; 217 sig += sig & 1; 218 // Perform standard rounding 219 sig >>= 1; 220 } 221 } 222 223 // Convert the current value to signed exponent, normalized form 224 void toNormalized(T,U)(ref T sig, ref U exp) 225 { 226 sig = significand; 227 auto shift = (T.sizeof*8) - precision; 228 exp = exponent; 229 static if (flags&(Flags.infinity|Flags.nan)) 230 { 231 // Handle inf or nan 232 if (exp == exponent_max) 233 { 234 exp = exp.max; 235 sig <<= shift; 236 static if (flags&Flags.storeNormalized) 237 { 238 // Save inf/nan in denormalized format 239 sig >>= 1; 240 sig += cast(T) 1uL << (T.sizeof*8 - 1); 241 } 242 return; 243 } 244 } 245 if ((~flags&Flags.storeNormalized) || 246 // Convert denormalized form to normalized form 247 ((flags&Flags.allowDenorm) && exp == 0)) 248 { 249 if (sig > 0) 250 { 251 import core.bitop : bsr; 252 auto shift2 = precision - bsr(sig); 253 exp -= shift2-1; 254 shift += shift2; 255 } 256 else // value = 0.0 257 { 258 exp = exp.min; 259 return; 260 } 261 } 262 sig <<= shift; 263 exp -= bias; 264 } 265 266 // Set the current value from signed exponent, normalized form 267 void fromNormalized(T,U)(ref T sig, ref U exp) 268 { 269 auto shift = (T.sizeof*8) - precision; 270 if (exp == exp.max) 271 { 272 // infinity or nan 273 exp = exponent_max; 274 static if (flags & Flags.storeNormalized) 275 sig <<= 1; 276 277 // convert back to normalized form 278 static if (~flags & Flags.infinity) 279 // No infinity support? 280 assert(sig != 0, "Infinity floating point value assigned to a " 281 ~ typeof(this).stringof ~ " (no infinity support)."); 282 283 static if (~flags & Flags.nan) // No NaN support? 284 assert(sig == 0, "NaN floating point value assigned to a " ~ 285 typeof(this).stringof ~ " (no nan support)."); 286 sig >>= shift; 287 return; 288 } 289 if (exp == exp.min) // 0.0 290 { 291 exp = 0; 292 sig = 0; 293 return; 294 } 295 296 exp += bias; 297 if (exp <= 0) 298 { 299 static if ((flags&Flags.allowDenorm) || 300 // Convert from normalized form to denormalized 301 (~flags&Flags.storeNormalized)) 302 { 303 shift += -exp; 304 roundedShift(sig,1); 305 sig += cast(T) 1uL << (T.sizeof*8 - 1); 306 // Add the leading 1 307 exp = 0; 308 } 309 else 310 assert((flags&Flags.storeNormalized) && exp == 0, 311 "Underflow occured assigning to a " ~ 312 typeof(this).stringof ~ " (no denormal support)."); 313 } 314 else 315 { 316 static if (~flags&Flags.storeNormalized) 317 { 318 // Convert from normalized form to denormalized 319 roundedShift(sig,1); 320 sig += cast(T) 1uL << (T.sizeof*8 - 1); 321 // Add the leading 1 322 } 323 } 324 325 if (shift > 0) 326 roundedShift(sig,shift); 327 if (sig > significand_max) 328 { 329 // handle significand overflow (should only be 1 bit) 330 static if (~flags&Flags.storeNormalized) 331 { 332 sig >>= 1; 333 } 334 else 335 sig &= significand_max; 336 exp++; 337 } 338 static if ((flags&Flags.allowDenormZeroOnly)==Flags.allowDenormZeroOnly) 339 { 340 // disallow non-zero denormals 341 if (exp == 0) 342 { 343 sig <<= 1; 344 if (sig > significand_max && (sig&significand_max) > 0) 345 // Check and round to even 346 exp++; 347 sig = 0; 348 } 349 } 350 351 if (exp >= exponent_max) 352 { 353 static if (flags&(Flags.infinity|Flags.nan)) 354 { 355 sig = 0; 356 exp = exponent_max; 357 static if (~flags&(Flags.infinity)) 358 assert(0, "Overflow occured assigning to a " ~ 359 typeof(this).stringof ~ " (no infinity support)."); 360 } 361 else 362 assert(exp == exponent_max, "Overflow occured assigning to a " 363 ~ typeof(this).stringof ~ " (no infinity support)."); 364 } 365 } 366 367 public: 368 static if (precision == 64) // CustomFloat!80 support hack 369 { 370 ulong significand; 371 enum ulong significand_max = ulong.max; 372 mixin(bitfields!( 373 T_exp , "exponent", exponentWidth, 374 bool , "sign" , flags & flags.signed )); 375 } 376 else 377 { 378 mixin(bitfields!( 379 T_sig, "significand", precision, 380 T_exp, "exponent" , exponentWidth, 381 bool , "sign" , flags & flags.signed )); 382 } 383 384 /// Returns: infinity value 385 static if (flags & Flags.infinity) 386 static @property CustomFloat infinity() 387 { 388 CustomFloat value; 389 static if (flags & Flags.signed) 390 value.sign = 0; 391 value.significand = 0; 392 value.exponent = exponent_max; 393 return value; 394 } 395 396 /// Returns: NaN value 397 static if (flags & Flags.nan) 398 static @property CustomFloat nan() 399 { 400 CustomFloat value; 401 static if (flags & Flags.signed) 402 value.sign = 0; 403 value.significand = cast(typeof(significand_max)) 1L << (precision-1); 404 value.exponent = exponent_max; 405 return value; 406 } 407 408 /// Returns: number of decimal digits of precision 409 static @property size_t dig() 410 { 411 auto shiftcnt = precision - ((flags&Flags.storeNormalized) == 0); 412 return shiftcnt == 64 ? 19 : cast(size_t) log10(1uL << shiftcnt); 413 } 414 415 /// Returns: smallest increment to the value 1 416 static @property CustomFloat epsilon() 417 { 418 CustomFloat one = CustomFloat(1); 419 CustomFloat onePlusEpsilon = one; 420 onePlusEpsilon.significand = onePlusEpsilon.significand | 1; // |= does not work here 421 422 return CustomFloat(onePlusEpsilon - one); 423 } 424 425 /// the number of bits in mantissa 426 enum mant_dig = precision + ((flags&Flags.storeNormalized) != 0); 427 428 /// Returns: maximum int value such that 10<sup>max_10_exp</sup> is representable 429 static @property int max_10_exp(){ return cast(int) log10( +max ); } 430 431 /// maximum int value such that 2<sup>max_exp-1</sup> is representable 432 enum max_exp = exponent_max - bias - ((flags & (Flags.infinity | Flags.nan)) != 0) + 1; 433 434 /// Returns: minimum int value such that 10<sup>min_10_exp</sup> is representable 435 static @property int min_10_exp(){ return cast(int) log10( +min_normal ); } 436 437 /// minimum int value such that 2<sup>min_exp-1</sup> is representable as a normalized value 438 enum min_exp = cast(T_signed_exp) -(cast(long) bias) + 1 + ((flags & Flags.allowDenorm) != 0); 439 440 /// Returns: largest representable value that's not infinity 441 static @property CustomFloat max() 442 { 443 CustomFloat value; 444 static if (flags & Flags.signed) 445 value.sign = 0; 446 value.exponent = exponent_max - ((flags&(flags.infinity|flags.nan)) != 0); 447 value.significand = significand_max; 448 return value; 449 } 450 451 /// Returns: smallest representable normalized value that's not 0 452 static @property CustomFloat min_normal() 453 { 454 CustomFloat value; 455 static if (flags & Flags.signed) 456 value.sign = 0; 457 value.exponent = (flags & Flags.allowDenorm) != 0; 458 static if (flags & Flags.storeNormalized) 459 value.significand = 0; 460 else 461 value.significand = cast(T_sig) 1uL << (precision - 1); 462 return value; 463 } 464 465 /// Returns: real part 466 @property CustomFloat re() { return this; } 467 468 /// Returns: imaginary part 469 static @property CustomFloat im() { return CustomFloat(0.0f); } 470 471 /// Initialize from any `real` compatible type. 472 this(F)(F input) if (__traits(compiles, cast(real) input )) 473 { 474 this = input; 475 } 476 477 /// Self assignment 478 void opAssign(F:CustomFloat)(F input) 479 { 480 static if (flags & Flags.signed) 481 sign = input.sign; 482 exponent = input.exponent; 483 significand = input.significand; 484 } 485 486 /// Assigns from any `real` compatible type. 487 void opAssign(F)(F input) 488 if (__traits(compiles, cast(real) input)) 489 { 490 import std.conv : text; 491 492 static if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0) 493 auto value = ToBinary!(Unqual!F)(input); 494 else 495 auto value = ToBinary!(real )(input); 496 497 // Assign the sign bit 498 static if (~flags & Flags.signed) 499 assert((!value.sign) ^ ((flags&flags.negativeUnsigned) > 0), 500 "Incorrectly signed floating point value assigned to a " ~ 501 typeof(this).stringof ~ " (no sign support)."); 502 else 503 sign = value.sign; 504 505 CommonType!(T_signed_exp ,value.T_signed_exp) exp = value.exponent; 506 CommonType!(T_sig, value.T_sig ) sig = value.significand; 507 508 value.toNormalized(sig,exp); 509 fromNormalized(sig,exp); 510 511 assert(exp <= exponent_max, text(typeof(this).stringof ~ 512 " exponent too large: " ,exp," > ",exponent_max, "\t",input,"\t",sig)); 513 assert(sig <= significand_max, text(typeof(this).stringof ~ 514 " significand too large: ",sig," > ",significand_max, 515 "\t",input,"\t",exp," ",exponent_max)); 516 exponent = cast(T_exp) exp; 517 significand = cast(T_sig) sig; 518 } 519 520 /// Fetches the stored value either as a `float`, `double` or `real`. 521 @property F get(F)() 522 if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0) 523 { 524 import std.conv : text; 525 526 ToBinary!F result; 527 528 static if (flags&Flags.signed) 529 result.sign = sign; 530 else 531 result.sign = (flags&flags.negativeUnsigned) > 0; 532 533 CommonType!(T_signed_exp ,result.get.T_signed_exp ) exp = exponent; // Assign the exponent and fraction 534 CommonType!(T_sig, result.get.T_sig ) sig = significand; 535 536 toNormalized(sig,exp); 537 result.fromNormalized(sig,exp); 538 assert(exp <= result.exponent_max, text("get exponent too large: " ,exp," > ",result.exponent_max) ); 539 assert(sig <= result.significand_max, text("get significand too large: ",sig," > ",result.significand_max) ); 540 result.exponent = cast(result.get.T_exp) exp; 541 result.significand = cast(result.get.T_sig) sig; 542 return result.set; 543 } 544 545 ///ditto 546 alias opCast = get; 547 548 /// Convert the CustomFloat to a real and perform the relevant operator on the result 549 real opUnary(string op)() 550 if (__traits(compiles, mixin(op~`(get!real)`)) || op=="++" || op=="--") 551 { 552 static if (op=="++" || op=="--") 553 { 554 auto result = get!real; 555 this = mixin(op~`result`); 556 return result; 557 } 558 else 559 return mixin(op~`get!real`); 560 } 561 562 /// ditto 563 // Define an opBinary `CustomFloat op CustomFloat` so that those below 564 // do not match equally, which is disallowed by the spec: 565 // https://dlang.org/spec/operatoroverloading.html#binary 566 real opBinary(string op,T)(T b) 567 if (__traits(compiles, mixin(`get!real`~op~`b.get!real`))) 568 { 569 return mixin(`get!real`~op~`b.get!real`); 570 } 571 572 /// ditto 573 real opBinary(string op,T)(T b) 574 if ( __traits(compiles, mixin(`get!real`~op~`b`)) && 575 !__traits(compiles, mixin(`get!real`~op~`b.get!real`))) 576 { 577 return mixin(`get!real`~op~`b`); 578 } 579 580 /// ditto 581 real opBinaryRight(string op,T)(T a) 582 if ( __traits(compiles, mixin(`a`~op~`get!real`)) && 583 !__traits(compiles, mixin(`get!real`~op~`b`)) && 584 !__traits(compiles, mixin(`get!real`~op~`b.get!real`))) 585 { 586 return mixin(`a`~op~`get!real`); 587 } 588 589 /// ditto 590 int opCmp(T)(auto ref T b) 591 if (__traits(compiles, cast(real) b)) 592 { 593 auto x = get!real; 594 auto y = cast(real) b; 595 return (x >= y)-(x <= y); 596 } 597 598 /// ditto 599 void opOpAssign(string op, T)(auto ref T b) 600 if (__traits(compiles, mixin(`get!real`~op~`cast(real) b`))) 601 { 602 return mixin(`this = this `~op~` cast(real) b`); 603 } 604 605 /// ditto 606 template toString() 607 { 608 import std.format.spec : FormatSpec; 609 import std.format.write : formatValue; 610 // Needs to be a template because of https://issues.dlang.org/show_bug.cgi?id=13737. 611 void toString()(scope void delegate(const(char)[]) sink, scope const ref FormatSpec!char fmt) 612 { 613 sink.formatValue(get!real, fmt); 614 } 615 } 616 } 617 618 @safe unittest 619 { 620 import std.meta; 621 alias FPTypes = 622 AliasSeq!( 623 CustomFloat!(5, 10), 624 CustomFloat!(5, 11, CustomFloatFlags.ieee ^ CustomFloatFlags.signed), 625 CustomFloat!(1, 7, CustomFloatFlags.ieee ^ CustomFloatFlags.signed), 626 CustomFloat!(4, 3, CustomFloatFlags.ieee | CustomFloatFlags.probability ^ CustomFloatFlags.signed) 627 ); 628 629 foreach (F; FPTypes) 630 { 631 auto x = F(0.125); 632 assert(x.get!float == 0.125F); 633 assert(x.get!double == 0.125); 634 635 x -= 0.0625; 636 assert(x.get!float == 0.0625F); 637 assert(x.get!double == 0.0625); 638 639 x *= 2; 640 assert(x.get!float == 0.125F); 641 assert(x.get!double == 0.125); 642 643 x /= 4; 644 assert(x.get!float == 0.03125); 645 assert(x.get!double == 0.03125); 646 647 x = 0.5; 648 x ^^= 4; 649 assert(x.get!float == 1 / 16.0F); 650 assert(x.get!double == 1 / 16.0); 651 } 652 } 653 654 @system unittest 655 { 656 // @system due to to!string(CustomFloat) 657 import std.conv; 658 CustomFloat!(5, 10) y = CustomFloat!(5, 10)(0.125); 659 assert(y.to!string == "0.125"); 660 } 661 662 @safe unittest 663 { 664 alias cf = CustomFloat!(5, 2); 665 666 auto a = cf.infinity; 667 assert(a.sign == 0); 668 assert(a.exponent == 3); 669 assert(a.significand == 0); 670 671 auto b = cf.nan; 672 assert(b.exponent == 3); 673 assert(b.significand != 0); 674 675 assert(cf.dig == 1); 676 677 auto c = cf.epsilon; 678 assert(c.sign == 0); 679 assert(c.exponent == 0); 680 assert(c.significand == 1); 681 682 assert(cf.mant_dig == 6); 683 684 assert(cf.max_10_exp == 0); 685 assert(cf.max_exp == 2); 686 assert(cf.min_10_exp == 0); 687 assert(cf.min_exp == 1); 688 689 auto d = cf.max; 690 assert(d.sign == 0); 691 assert(d.exponent == 2); 692 assert(d.significand == 31); 693 694 auto e = cf.min_normal; 695 assert(e.sign == 0); 696 assert(e.exponent == 1); 697 assert(e.significand == 0); 698 699 assert(e.re == e); 700 assert(e.im == cf(0.0)); 701 } 702 703 // check whether CustomFloats identical to float/double behave like float/double 704 @safe unittest 705 { 706 import std.conv : to; 707 708 alias myFloat = CustomFloat!(23, 8); 709 710 static assert(myFloat.dig == float.dig); 711 static assert(myFloat.mant_dig == float.mant_dig); 712 assert(myFloat.max_10_exp == float.max_10_exp); 713 static assert(myFloat.max_exp == float.max_exp); 714 assert(myFloat.min_10_exp == float.min_10_exp); 715 static assert(myFloat.min_exp == float.min_exp); 716 assert(to!float(myFloat.epsilon) == float.epsilon); 717 assert(to!float(myFloat.max) == float.max); 718 assert(to!float(myFloat.min_normal) == float.min_normal); 719 720 alias myDouble = CustomFloat!(52, 11); 721 722 static assert(myDouble.dig == double.dig); 723 static assert(myDouble.mant_dig == double.mant_dig); 724 assert(myDouble.max_10_exp == double.max_10_exp); 725 static assert(myDouble.max_exp == double.max_exp); 726 assert(myDouble.min_10_exp == double.min_10_exp); 727 static assert(myDouble.min_exp == double.min_exp); 728 assert(to!double(myDouble.epsilon) == double.epsilon); 729 assert(to!double(myDouble.max) == double.max); 730 assert(to!double(myDouble.min_normal) == double.min_normal); 731 } 732 733 // testing .dig 734 @safe unittest 735 { 736 static assert(CustomFloat!(1, 6).dig == 0); 737 static assert(CustomFloat!(9, 6).dig == 2); 738 static assert(CustomFloat!(10, 5).dig == 3); 739 static assert(CustomFloat!(10, 6, CustomFloatFlags.none).dig == 2); 740 static assert(CustomFloat!(11, 5, CustomFloatFlags.none).dig == 3); 741 static assert(CustomFloat!(64, 7).dig == 19); 742 } 743 744 // testing .mant_dig 745 @safe unittest 746 { 747 static assert(CustomFloat!(10, 5).mant_dig == 11); 748 static assert(CustomFloat!(10, 6, CustomFloatFlags.none).mant_dig == 10); 749 } 750 751 // testing .max_exp 752 @safe unittest 753 { 754 static assert(CustomFloat!(1, 6).max_exp == 2^^5); 755 static assert(CustomFloat!(2, 6, CustomFloatFlags.none).max_exp == 2^^5); 756 static assert(CustomFloat!(5, 10).max_exp == 2^^9); 757 static assert(CustomFloat!(6, 10, CustomFloatFlags.none).max_exp == 2^^9); 758 static assert(CustomFloat!(2, 6, CustomFloatFlags.nan).max_exp == 2^^5); 759 static assert(CustomFloat!(6, 10, CustomFloatFlags.nan).max_exp == 2^^9); 760 } 761 762 // testing .min_exp 763 @safe unittest 764 { 765 static assert(CustomFloat!(1, 6).min_exp == -2^^5+3); 766 static assert(CustomFloat!(5, 10).min_exp == -2^^9+3); 767 static assert(CustomFloat!(2, 6, CustomFloatFlags.none).min_exp == -2^^5+1); 768 static assert(CustomFloat!(6, 10, CustomFloatFlags.none).min_exp == -2^^9+1); 769 static assert(CustomFloat!(2, 6, CustomFloatFlags.nan).min_exp == -2^^5+2); 770 static assert(CustomFloat!(6, 10, CustomFloatFlags.nan).min_exp == -2^^9+2); 771 static assert(CustomFloat!(2, 6, CustomFloatFlags.allowDenorm).min_exp == -2^^5+2); 772 static assert(CustomFloat!(6, 10, CustomFloatFlags.allowDenorm).min_exp == -2^^9+2); 773 } 774 775 // testing .max_10_exp 776 @safe unittest 777 { 778 assert(CustomFloat!(1, 6).max_10_exp == 9); 779 assert(CustomFloat!(5, 10).max_10_exp == 154); 780 assert(CustomFloat!(2, 6, CustomFloatFlags.none).max_10_exp == 9); 781 assert(CustomFloat!(6, 10, CustomFloatFlags.none).max_10_exp == 154); 782 assert(CustomFloat!(2, 6, CustomFloatFlags.nan).max_10_exp == 9); 783 assert(CustomFloat!(6, 10, CustomFloatFlags.nan).max_10_exp == 154); 784 } 785 786 // testing .min_10_exp 787 @safe unittest 788 { 789 assert(CustomFloat!(1, 6).min_10_exp == -9); 790 assert(CustomFloat!(5, 10).min_10_exp == -153); 791 assert(CustomFloat!(2, 6, CustomFloatFlags.none).min_10_exp == -9); 792 assert(CustomFloat!(6, 10, CustomFloatFlags.none).min_10_exp == -154); 793 assert(CustomFloat!(2, 6, CustomFloatFlags.nan).min_10_exp == -9); 794 assert(CustomFloat!(6, 10, CustomFloatFlags.nan).min_10_exp == -153); 795 assert(CustomFloat!(2, 6, CustomFloatFlags.allowDenorm).min_10_exp == -9); 796 assert(CustomFloat!(6, 10, CustomFloatFlags.allowDenorm).min_10_exp == -153); 797 } 798 799 // testing .epsilon 800 @safe unittest 801 { 802 assert(CustomFloat!(1,6).epsilon.sign == 0); 803 assert(CustomFloat!(1,6).epsilon.exponent == 30); 804 assert(CustomFloat!(1,6).epsilon.significand == 0); 805 assert(CustomFloat!(2,5).epsilon.sign == 0); 806 assert(CustomFloat!(2,5).epsilon.exponent == 13); 807 assert(CustomFloat!(2,5).epsilon.significand == 0); 808 assert(CustomFloat!(3,4).epsilon.sign == 0); 809 assert(CustomFloat!(3,4).epsilon.exponent == 4); 810 assert(CustomFloat!(3,4).epsilon.significand == 0); 811 // the following epsilons are only available, when denormalized numbers are allowed: 812 assert(CustomFloat!(4,3).epsilon.sign == 0); 813 assert(CustomFloat!(4,3).epsilon.exponent == 0); 814 assert(CustomFloat!(4,3).epsilon.significand == 4); 815 assert(CustomFloat!(5,2).epsilon.sign == 0); 816 assert(CustomFloat!(5,2).epsilon.exponent == 0); 817 assert(CustomFloat!(5,2).epsilon.significand == 1); 818 } 819 820 // testing .max 821 @safe unittest 822 { 823 static assert(CustomFloat!(5,2).max.sign == 0); 824 static assert(CustomFloat!(5,2).max.exponent == 2); 825 static assert(CustomFloat!(5,2).max.significand == 31); 826 static assert(CustomFloat!(4,3).max.sign == 0); 827 static assert(CustomFloat!(4,3).max.exponent == 6); 828 static assert(CustomFloat!(4,3).max.significand == 15); 829 static assert(CustomFloat!(3,4).max.sign == 0); 830 static assert(CustomFloat!(3,4).max.exponent == 14); 831 static assert(CustomFloat!(3,4).max.significand == 7); 832 static assert(CustomFloat!(2,5).max.sign == 0); 833 static assert(CustomFloat!(2,5).max.exponent == 30); 834 static assert(CustomFloat!(2,5).max.significand == 3); 835 static assert(CustomFloat!(1,6).max.sign == 0); 836 static assert(CustomFloat!(1,6).max.exponent == 62); 837 static assert(CustomFloat!(1,6).max.significand == 1); 838 static assert(CustomFloat!(3,5, CustomFloatFlags.none).max.exponent == 31); 839 static assert(CustomFloat!(3,5, CustomFloatFlags.none).max.significand == 7); 840 } 841 842 // testing .min_normal 843 @safe unittest 844 { 845 static assert(CustomFloat!(5,2).min_normal.sign == 0); 846 static assert(CustomFloat!(5,2).min_normal.exponent == 1); 847 static assert(CustomFloat!(5,2).min_normal.significand == 0); 848 static assert(CustomFloat!(4,3).min_normal.sign == 0); 849 static assert(CustomFloat!(4,3).min_normal.exponent == 1); 850 static assert(CustomFloat!(4,3).min_normal.significand == 0); 851 static assert(CustomFloat!(3,4).min_normal.sign == 0); 852 static assert(CustomFloat!(3,4).min_normal.exponent == 1); 853 static assert(CustomFloat!(3,4).min_normal.significand == 0); 854 static assert(CustomFloat!(2,5).min_normal.sign == 0); 855 static assert(CustomFloat!(2,5).min_normal.exponent == 1); 856 static assert(CustomFloat!(2,5).min_normal.significand == 0); 857 static assert(CustomFloat!(1,6).min_normal.sign == 0); 858 static assert(CustomFloat!(1,6).min_normal.exponent == 1); 859 static assert(CustomFloat!(1,6).min_normal.significand == 0); 860 static assert(CustomFloat!(3,5, CustomFloatFlags.none).min_normal.exponent == 0); 861 static assert(CustomFloat!(3,5, CustomFloatFlags.none).min_normal.significand == 4); 862 } 863 864 @safe unittest 865 { 866 import std.math.traits : isNaN; 867 868 alias cf = CustomFloat!(5, 2); 869 870 auto f = cf.nan.get!float(); 871 assert(isNaN(f)); 872 873 cf a; 874 a = real.max; 875 assert(a == cf.infinity); 876 877 a = 0.015625; 878 assert(a.exponent == 0); 879 assert(a.significand == 0); 880 881 a = 0.984375; 882 assert(a.exponent == 1); 883 assert(a.significand == 0); 884 } 885 886 @system unittest 887 { 888 import std.exception : assertThrown; 889 import core.exception : AssertError; 890 891 alias cf = CustomFloat!(3, 5, CustomFloatFlags.none); 892 893 cf a; 894 assertThrown!AssertError(a = real.max); 895 } 896 897 @system unittest 898 { 899 import std.exception : assertThrown; 900 import core.exception : AssertError; 901 902 alias cf = CustomFloat!(3, 5, CustomFloatFlags.nan); 903 904 cf a; 905 assertThrown!AssertError(a = real.max); 906 } 907 908 @system unittest 909 { 910 import std.exception : assertThrown; 911 import core.exception : AssertError; 912 913 alias cf = CustomFloat!(24, 8, CustomFloatFlags.none); 914 915 cf a; 916 assertThrown!AssertError(a = float.infinity); 917 } 918 919 private bool isCorrectCustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags) @safe pure nothrow @nogc 920 { 921 // Restrictions from bitfield 922 // due to CustomFloat!80 support hack precision with 64 bits is handled specially 923 auto length = (flags & flags.signed) + exponentWidth + ((precision == 64) ? 0 : precision); 924 if (length != 8 && length != 16 && length != 32 && length != 64) return false; 925 926 // mantissa needs to fit into real mantissa 927 if (precision > real.mant_dig - 1 && precision != 64) return false; 928 929 // exponent needs to fit into real exponent 930 if (1L << exponentWidth - 1 > real.max_exp) return false; 931 932 // mantissa should have at least one bit 933 if (precision == 0) return false; 934 935 // exponent should have at least one bit, in some cases two 936 if (exponentWidth <= ((flags & (flags.allowDenorm | flags.infinity | flags.nan)) != 0)) return false; 937 938 return true; 939 } 940 941 @safe pure nothrow @nogc unittest 942 { 943 assert(isCorrectCustomFloat(3,4,CustomFloatFlags.ieee)); 944 assert(isCorrectCustomFloat(3,5,CustomFloatFlags.none)); 945 assert(!isCorrectCustomFloat(3,3,CustomFloatFlags.ieee)); 946 assert(isCorrectCustomFloat(64,7,CustomFloatFlags.ieee)); 947 assert(!isCorrectCustomFloat(64,4,CustomFloatFlags.ieee)); 948 assert(!isCorrectCustomFloat(508,3,CustomFloatFlags.ieee)); 949 assert(!isCorrectCustomFloat(3,100,CustomFloatFlags.ieee)); 950 assert(!isCorrectCustomFloat(0,7,CustomFloatFlags.ieee)); 951 assert(!isCorrectCustomFloat(6,1,CustomFloatFlags.ieee)); 952 assert(isCorrectCustomFloat(7,1,CustomFloatFlags.none)); 953 assert(!isCorrectCustomFloat(8,0,CustomFloatFlags.none)); 954 } 955 956 /** 957 Defines the fastest type to use when storing temporaries of a 958 calculation intended to ultimately yield a result of type `F` 959 (where `F` must be one of `float`, `double`, or $(D 960 real)). When doing a multi-step computation, you may want to store 961 intermediate results as `FPTemporary!F`. 962 963 The necessity of `FPTemporary` stems from the optimized 964 floating-point operations and registers present in virtually all 965 processors. When adding numbers in the example above, the addition may 966 in fact be done in `real` precision internally. In that case, 967 storing the intermediate `result` in $(D double format) is not only 968 less precise, it is also (surprisingly) slower, because a conversion 969 from `real` to `double` is performed every pass through the 970 loop. This being a lose-lose situation, `FPTemporary!F` has been 971 defined as the $(I fastest) type to use for calculations at precision 972 `F`. There is no need to define a type for the $(I most accurate) 973 calculations, as that is always `real`. 974 975 Finally, there is no guarantee that using `FPTemporary!F` will 976 always be fastest, as the speed of floating-point calculations depends 977 on very many factors. 978 */ 979 template FPTemporary(F) 980 if (isFloatingPoint!F) 981 { 982 version (X86) 983 alias FPTemporary = real; 984 else 985 alias FPTemporary = Unqual!F; 986 } 987 988 /// 989 @safe unittest 990 { 991 import std.math.operations : isClose; 992 993 // Average numbers in an array 994 double avg(in double[] a) 995 { 996 if (a.length == 0) return 0; 997 FPTemporary!double result = 0; 998 foreach (e; a) result += e; 999 return result / a.length; 1000 } 1001 1002 auto a = [1.0, 2.0, 3.0]; 1003 assert(isClose(avg(a), 2)); 1004 } 1005 1006 /** 1007 Implements the $(HTTP tinyurl.com/2zb9yr, secant method) for finding a 1008 root of the function `fun` starting from points $(D [xn_1, x_n]) 1009 (ideally close to the root). `Num` may be `float`, `double`, 1010 or `real`. 1011 */ 1012 template secantMethod(alias fun) 1013 { 1014 import std.functional : unaryFun; 1015 Num secantMethod(Num)(Num xn_1, Num xn) 1016 { 1017 auto fxn = unaryFun!(fun)(xn_1), d = xn_1 - xn; 1018 typeof(fxn) fxn_1; 1019 1020 xn = xn_1; 1021 while (!isClose(d, 0, 0.0, 1e-5) && isFinite(d)) 1022 { 1023 xn_1 = xn; 1024 xn -= d; 1025 fxn_1 = fxn; 1026 fxn = unaryFun!(fun)(xn); 1027 d *= -fxn / (fxn - fxn_1); 1028 } 1029 return xn; 1030 } 1031 } 1032 1033 /// 1034 @safe unittest 1035 { 1036 import std.math.operations : isClose; 1037 import std.math.trigonometry : cos; 1038 1039 float f(float x) 1040 { 1041 return cos(x) - x*x*x; 1042 } 1043 auto x = secantMethod!(f)(0f, 1f); 1044 assert(isClose(x, 0.865474)); 1045 } 1046 1047 @system unittest 1048 { 1049 // @system because of __gshared stderr 1050 import std.stdio; 1051 scope(failure) stderr.writeln("Failure testing secantMethod"); 1052 float f(float x) 1053 { 1054 return cos(x) - x*x*x; 1055 } 1056 immutable x = secantMethod!(f)(0f, 1f); 1057 assert(isClose(x, 0.865474)); 1058 auto d = &f; 1059 immutable y = secantMethod!(d)(0f, 1f); 1060 assert(isClose(y, 0.865474)); 1061 } 1062 1063 1064 /** 1065 * Return true if a and b have opposite sign. 1066 */ 1067 private bool oppositeSigns(T1, T2)(T1 a, T2 b) 1068 { 1069 return signbit(a) != signbit(b); 1070 } 1071 1072 public: 1073 1074 /** Find a real root of a real function f(x) via bracketing. 1075 * 1076 * Given a function `f` and a range `[a .. b]` such that `f(a)` 1077 * and `f(b)` have opposite signs or at least one of them equals ±0, 1078 * returns the value of `x` in 1079 * the range which is closest to a root of `f(x)`. If `f(x)` 1080 * has more than one root in the range, one will be chosen 1081 * arbitrarily. If `f(x)` returns NaN, NaN will be returned; 1082 * otherwise, this algorithm is guaranteed to succeed. 1083 * 1084 * Uses an algorithm based on TOMS748, which uses inverse cubic 1085 * interpolation whenever possible, otherwise reverting to parabolic 1086 * or secant interpolation. Compared to TOMS748, this implementation 1087 * improves worst-case performance by a factor of more than 100, and 1088 * typical performance by a factor of 2. For 80-bit reals, most 1089 * problems require 8 to 15 calls to `f(x)` to achieve full machine 1090 * precision. The worst-case performance (pathological cases) is 1091 * approximately twice the number of bits. 1092 * 1093 * References: "On Enclosing Simple Roots of Nonlinear Equations", 1094 * G. Alefeld, F.A. Potra, Yixun Shi, Mathematics of Computation 61, 1095 * pp733-744 (1993). Fortran code available from $(HTTP 1096 * www.netlib.org,www.netlib.org) as algorithm TOMS478. 1097 * 1098 */ 1099 T findRoot(T, DF, DT)(scope DF f, const T a, const T b, 1100 scope DT tolerance) //= (T a, T b) => false) 1101 if ( 1102 isFloatingPoint!T && 1103 is(typeof(tolerance(T.init, T.init)) : bool) && 1104 is(typeof(f(T.init)) == R, R) && isFloatingPoint!R 1105 ) 1106 { 1107 immutable fa = f(a); 1108 if (fa == 0) 1109 return a; 1110 immutable fb = f(b); 1111 if (fb == 0) 1112 return b; 1113 immutable r = findRoot(f, a, b, fa, fb, tolerance); 1114 // Return the first value if it is smaller or NaN 1115 return !(fabs(r[2]) > fabs(r[3])) ? r[0] : r[1]; 1116 } 1117 1118 ///ditto 1119 T findRoot(T, DF)(scope DF f, const T a, const T b) 1120 { 1121 return findRoot(f, a, b, (T a, T b) => false); 1122 } 1123 1124 /** Find root of a real function f(x) by bracketing, allowing the 1125 * termination condition to be specified. 1126 * 1127 * Params: 1128 * 1129 * f = Function to be analyzed 1130 * 1131 * ax = Left bound of initial range of `f` known to contain the 1132 * root. 1133 * 1134 * bx = Right bound of initial range of `f` known to contain the 1135 * root. 1136 * 1137 * fax = Value of `f(ax)`. 1138 * 1139 * fbx = Value of `f(bx)`. `fax` and `fbx` should have opposite signs. 1140 * (`f(ax)` and `f(bx)` are commonly known in advance.) 1141 * 1142 * 1143 * tolerance = Defines an early termination condition. Receives the 1144 * current upper and lower bounds on the root. The 1145 * delegate must return `true` when these bounds are 1146 * acceptable. If this function always returns `false`, 1147 * full machine precision will be achieved. 1148 * 1149 * Returns: 1150 * 1151 * A tuple consisting of two ranges. The first two elements are the 1152 * range (in `x`) of the root, while the second pair of elements 1153 * are the corresponding function values at those points. If an exact 1154 * root was found, both of the first two elements will contain the 1155 * root, and the second pair of elements will be 0. 1156 */ 1157 Tuple!(T, T, R, R) findRoot(T, R, DF, DT)(scope DF f, 1158 const T ax, const T bx, const R fax, const R fbx, 1159 scope DT tolerance) // = (T a, T b) => false) 1160 if ( 1161 isFloatingPoint!T && 1162 is(typeof(tolerance(T.init, T.init)) : bool) && 1163 is(typeof(f(T.init)) == R) && isFloatingPoint!R 1164 ) 1165 in 1166 { 1167 assert(!ax.isNaN() && !bx.isNaN(), "Limits must not be NaN"); 1168 assert(signbit(fax) != signbit(fbx), "Parameters must bracket the root."); 1169 } 1170 do 1171 { 1172 // Author: Don Clugston. This code is (heavily) modified from TOMS748 1173 // (www.netlib.org). The changes to improve the worst-cast performance are 1174 // entirely original. 1175 1176 T a, b, d; // [a .. b] is our current bracket. d is the third best guess. 1177 R fa, fb, fd; // Values of f at a, b, d. 1178 bool done = false; // Has a root been found? 1179 1180 // Allow ax and bx to be provided in reverse order 1181 if (ax <= bx) 1182 { 1183 a = ax; fa = fax; 1184 b = bx; fb = fbx; 1185 } 1186 else 1187 { 1188 a = bx; fa = fbx; 1189 b = ax; fb = fax; 1190 } 1191 1192 // Test the function at point c; update brackets accordingly 1193 void bracket(T c) 1194 { 1195 R fc = f(c); 1196 if (fc == 0 || fc.isNaN()) // Exact solution, or NaN 1197 { 1198 a = c; 1199 fa = fc; 1200 d = c; 1201 fd = fc; 1202 done = true; 1203 return; 1204 } 1205 1206 // Determine new enclosing interval 1207 if (signbit(fa) != signbit(fc)) 1208 { 1209 d = b; 1210 fd = fb; 1211 b = c; 1212 fb = fc; 1213 } 1214 else 1215 { 1216 d = a; 1217 fd = fa; 1218 a = c; 1219 fa = fc; 1220 } 1221 } 1222 1223 /* Perform a secant interpolation. If the result would lie on a or b, or if 1224 a and b differ so wildly in magnitude that the result would be meaningless, 1225 perform a bisection instead. 1226 */ 1227 static T secant_interpolate(T a, T b, R fa, R fb) 1228 { 1229 if (( ((a - b) == a) && b != 0) || (a != 0 && ((b - a) == b))) 1230 { 1231 // Catastrophic cancellation 1232 if (a == 0) 1233 a = copysign(T(0), b); 1234 else if (b == 0) 1235 b = copysign(T(0), a); 1236 else if (signbit(a) != signbit(b)) 1237 return 0; 1238 T c = ieeeMean(a, b); 1239 return c; 1240 } 1241 // avoid overflow 1242 if (b - a > T.max) 1243 return b / 2 + a / 2; 1244 if (fb - fa > R.max) 1245 return a - (b - a) / 2; 1246 T c = a - (fa / (fb - fa)) * (b - a); 1247 if (c == a || c == b) 1248 return (a + b) / 2; 1249 return c; 1250 } 1251 1252 /* Uses 'numsteps' newton steps to approximate the zero in [a .. b] of the 1253 quadratic polynomial interpolating f(x) at a, b, and d. 1254 Returns: 1255 The approximate zero in [a .. b] of the quadratic polynomial. 1256 */ 1257 T newtonQuadratic(int numsteps) 1258 { 1259 // Find the coefficients of the quadratic polynomial. 1260 immutable T a0 = fa; 1261 immutable T a1 = (fb - fa)/(b - a); 1262 immutable T a2 = ((fd - fb)/(d - b) - a1)/(d - a); 1263 1264 // Determine the starting point of newton steps. 1265 T c = oppositeSigns(a2, fa) ? a : b; 1266 1267 // start the safeguarded newton steps. 1268 foreach (int i; 0 .. numsteps) 1269 { 1270 immutable T pc = a0 + (a1 + a2 * (c - b))*(c - a); 1271 immutable T pdc = a1 + a2*((2 * c) - (a + b)); 1272 if (pdc == 0) 1273 return a - a0 / a1; 1274 else 1275 c = c - pc / pdc; 1276 } 1277 return c; 1278 } 1279 1280 // On the first iteration we take a secant step: 1281 if (fa == 0 || fa.isNaN()) 1282 { 1283 done = true; 1284 b = a; 1285 fb = fa; 1286 } 1287 else if (fb == 0 || fb.isNaN()) 1288 { 1289 done = true; 1290 a = b; 1291 fa = fb; 1292 } 1293 else 1294 { 1295 bracket(secant_interpolate(a, b, fa, fb)); 1296 } 1297 1298 // Starting with the second iteration, higher-order interpolation can 1299 // be used. 1300 int itnum = 1; // Iteration number 1301 int baditer = 1; // Num bisections to take if an iteration is bad. 1302 T c, e; // e is our fourth best guess 1303 R fe; 1304 1305 whileloop: 1306 while (!done && (b != nextUp(a)) && !tolerance(a, b)) 1307 { 1308 T a0 = a, b0 = b; // record the brackets 1309 1310 // Do two higher-order (cubic or parabolic) interpolation steps. 1311 foreach (int QQ; 0 .. 2) 1312 { 1313 // Cubic inverse interpolation requires that 1314 // all four function values fa, fb, fd, and fe are distinct; 1315 // otherwise use quadratic interpolation. 1316 bool distinct = (fa != fb) && (fa != fd) && (fa != fe) 1317 && (fb != fd) && (fb != fe) && (fd != fe); 1318 // The first time, cubic interpolation is impossible. 1319 if (itnum<2) distinct = false; 1320 bool ok = distinct; 1321 if (distinct) 1322 { 1323 // Cubic inverse interpolation of f(x) at a, b, d, and e 1324 immutable q11 = (d - e) * fd / (fe - fd); 1325 immutable q21 = (b - d) * fb / (fd - fb); 1326 immutable q31 = (a - b) * fa / (fb - fa); 1327 immutable d21 = (b - d) * fd / (fd - fb); 1328 immutable d31 = (a - b) * fb / (fb - fa); 1329 1330 immutable q22 = (d21 - q11) * fb / (fe - fb); 1331 immutable q32 = (d31 - q21) * fa / (fd - fa); 1332 immutable d32 = (d31 - q21) * fd / (fd - fa); 1333 immutable q33 = (d32 - q22) * fa / (fe - fa); 1334 c = a + (q31 + q32 + q33); 1335 if (c.isNaN() || (c <= a) || (c >= b)) 1336 { 1337 // DAC: If the interpolation predicts a or b, it's 1338 // probable that it's the actual root. Only allow this if 1339 // we're already close to the root. 1340 if (c == a && a - b != a) 1341 { 1342 c = nextUp(a); 1343 } 1344 else if (c == b && a - b != -b) 1345 { 1346 c = nextDown(b); 1347 } 1348 else 1349 { 1350 ok = false; 1351 } 1352 } 1353 } 1354 if (!ok) 1355 { 1356 // DAC: Alefeld doesn't explain why the number of newton steps 1357 // should vary. 1358 c = newtonQuadratic(distinct ? 3 : 2); 1359 if (c.isNaN() || (c <= a) || (c >= b)) 1360 { 1361 // Failure, try a secant step: 1362 c = secant_interpolate(a, b, fa, fb); 1363 } 1364 } 1365 ++itnum; 1366 e = d; 1367 fe = fd; 1368 bracket(c); 1369 if (done || ( b == nextUp(a)) || tolerance(a, b)) 1370 break whileloop; 1371 if (itnum == 2) 1372 continue whileloop; 1373 } 1374 1375 // Now we take a double-length secant step: 1376 T u; 1377 R fu; 1378 if (fabs(fa) < fabs(fb)) 1379 { 1380 u = a; 1381 fu = fa; 1382 } 1383 else 1384 { 1385 u = b; 1386 fu = fb; 1387 } 1388 c = u - 2 * (fu / (fb - fa)) * (b - a); 1389 1390 // DAC: If the secant predicts a value equal to an endpoint, it's 1391 // probably false. 1392 if (c == a || c == b || c.isNaN() || fabs(c - u) > (b - a) / 2) 1393 { 1394 if ((a-b) == a || (b-a) == b) 1395 { 1396 if ((a>0 && b<0) || (a<0 && b>0)) 1397 c = 0; 1398 else 1399 { 1400 if (a == 0) 1401 c = ieeeMean(copysign(T(0), b), b); 1402 else if (b == 0) 1403 c = ieeeMean(copysign(T(0), a), a); 1404 else 1405 c = ieeeMean(a, b); 1406 } 1407 } 1408 else 1409 { 1410 c = a + (b - a) / 2; 1411 } 1412 } 1413 e = d; 1414 fe = fd; 1415 bracket(c); 1416 if (done || (b == nextUp(a)) || tolerance(a, b)) 1417 break; 1418 1419 // IMPROVE THE WORST-CASE PERFORMANCE 1420 // We must ensure that the bounds reduce by a factor of 2 1421 // in binary space! every iteration. If we haven't achieved this 1422 // yet, or if we don't yet know what the exponent is, 1423 // perform a binary chop. 1424 1425 if ((a == 0 || b == 0 || 1426 (fabs(a) >= T(0.5) * fabs(b) && fabs(b) >= T(0.5) * fabs(a))) 1427 && (b - a) < T(0.25) * (b0 - a0)) 1428 { 1429 baditer = 1; 1430 continue; 1431 } 1432 1433 // DAC: If this happens on consecutive iterations, we probably have a 1434 // pathological function. Perform a number of bisections equal to the 1435 // total number of consecutive bad iterations. 1436 1437 if ((b - a) < T(0.25) * (b0 - a0)) 1438 baditer = 1; 1439 foreach (int QQ; 0 .. baditer) 1440 { 1441 e = d; 1442 fe = fd; 1443 1444 T w; 1445 if ((a>0 && b<0) || (a<0 && b>0)) 1446 w = 0; 1447 else 1448 { 1449 T usea = a; 1450 T useb = b; 1451 if (a == 0) 1452 usea = copysign(T(0), b); 1453 else if (b == 0) 1454 useb = copysign(T(0), a); 1455 w = ieeeMean(usea, useb); 1456 } 1457 bracket(w); 1458 } 1459 ++baditer; 1460 } 1461 return Tuple!(T, T, R, R)(a, b, fa, fb); 1462 } 1463 1464 ///ditto 1465 Tuple!(T, T, R, R) findRoot(T, R, DF)(scope DF f, 1466 const T ax, const T bx, const R fax, const R fbx) 1467 { 1468 return findRoot(f, ax, bx, fax, fbx, (T a, T b) => false); 1469 } 1470 1471 ///ditto 1472 T findRoot(T, R)(scope R delegate(T) f, const T a, const T b, 1473 scope bool delegate(T lo, T hi) tolerance = (T a, T b) => false) 1474 { 1475 return findRoot!(T, R delegate(T), bool delegate(T lo, T hi))(f, a, b, tolerance); 1476 } 1477 1478 @safe nothrow unittest 1479 { 1480 int numProblems = 0; 1481 int numCalls; 1482 1483 void testFindRoot(real delegate(real) @nogc @safe nothrow pure f , real x1, real x2) @nogc @safe nothrow pure 1484 { 1485 //numCalls=0; 1486 //++numProblems; 1487 assert(!x1.isNaN() && !x2.isNaN()); 1488 assert(signbit(f(x1)) != signbit(f(x2))); 1489 auto result = findRoot(f, x1, x2, f(x1), f(x2), 1490 (real lo, real hi) { return false; }); 1491 1492 auto flo = f(result[0]); 1493 auto fhi = f(result[1]); 1494 if (flo != 0) 1495 { 1496 assert(oppositeSigns(flo, fhi)); 1497 } 1498 } 1499 1500 // Test functions 1501 real cubicfn(real x) @nogc @safe nothrow pure 1502 { 1503 //++numCalls; 1504 if (x>float.max) 1505 x = float.max; 1506 if (x<-float.max) 1507 x = -float.max; 1508 // This has a single real root at -59.286543284815 1509 return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2; 1510 } 1511 // Test a function with more than one root. 1512 real multisine(real x) { ++numCalls; return sin(x); } 1513 testFindRoot( &multisine, 6, 90); 1514 testFindRoot(&cubicfn, -100, 100); 1515 testFindRoot( &cubicfn, -double.max, real.max); 1516 1517 1518 /* Tests from the paper: 1519 * "On Enclosing Simple Roots of Nonlinear Equations", G. Alefeld, F.A. Potra, 1520 * Yixun Shi, Mathematics of Computation 61, pp733-744 (1993). 1521 */ 1522 // Parameters common to many alefeld tests. 1523 int n; 1524 real ale_a, ale_b; 1525 1526 int powercalls = 0; 1527 1528 real power(real x) 1529 { 1530 ++powercalls; 1531 ++numCalls; 1532 return pow(x, n) + double.min_normal; 1533 } 1534 int [] power_nvals = [3, 5, 7, 9, 19, 25]; 1535 // Alefeld paper states that pow(x,n) is a very poor case, where bisection 1536 // outperforms his method, and gives total numcalls = 1537 // 921 for bisection (2.4 calls per bit), 1830 for Alefeld (4.76/bit), 1538 // 2624 for brent (6.8/bit) 1539 // ... but that is for double, not real80. 1540 // This poor performance seems mainly due to catastrophic cancellation, 1541 // which is avoided here by the use of ieeeMean(). 1542 // I get: 231 (0.48/bit). 1543 // IE this is 10X faster in Alefeld's worst case 1544 numProblems=0; 1545 foreach (k; power_nvals) 1546 { 1547 n = k; 1548 testFindRoot(&power, -1, 10); 1549 } 1550 1551 int powerProblems = numProblems; 1552 1553 // Tests from Alefeld paper 1554 1555 int [9] alefeldSums; 1556 real alefeld0(real x) 1557 { 1558 ++alefeldSums[0]; 1559 ++numCalls; 1560 real q = sin(x) - x/2; 1561 for (int i=1; i<20; ++i) 1562 q+=(2*i-5.0)*(2*i-5.0)/((x-i*i)*(x-i*i)*(x-i*i)); 1563 return q; 1564 } 1565 real alefeld1(real x) 1566 { 1567 ++numCalls; 1568 ++alefeldSums[1]; 1569 return ale_a*x + exp(ale_b * x); 1570 } 1571 real alefeld2(real x) 1572 { 1573 ++numCalls; 1574 ++alefeldSums[2]; 1575 return pow(x, n) - ale_a; 1576 } 1577 real alefeld3(real x) 1578 { 1579 ++numCalls; 1580 ++alefeldSums[3]; 1581 return (1.0 +pow(1.0L-n, 2))*x - pow(1.0L-n*x, 2); 1582 } 1583 real alefeld4(real x) 1584 { 1585 ++numCalls; 1586 ++alefeldSums[4]; 1587 return x*x - pow(1-x, n); 1588 } 1589 real alefeld5(real x) 1590 { 1591 ++numCalls; 1592 ++alefeldSums[5]; 1593 return (1+pow(1.0L-n, 4))*x - pow(1.0L-n*x, 4); 1594 } 1595 real alefeld6(real x) 1596 { 1597 ++numCalls; 1598 ++alefeldSums[6]; 1599 return exp(-n*x)*(x-1.01L) + pow(x, n); 1600 } 1601 real alefeld7(real x) 1602 { 1603 ++numCalls; 1604 ++alefeldSums[7]; 1605 return (n*x-1)/((n-1)*x); 1606 } 1607 1608 numProblems=0; 1609 testFindRoot(&alefeld0, PI_2, PI); 1610 for (n=1; n <= 10; ++n) 1611 { 1612 testFindRoot(&alefeld0, n*n+1e-9L, (n+1)*(n+1)-1e-9L); 1613 } 1614 ale_a = -40; ale_b = -1; 1615 testFindRoot(&alefeld1, -9, 31); 1616 ale_a = -100; ale_b = -2; 1617 testFindRoot(&alefeld1, -9, 31); 1618 ale_a = -200; ale_b = -3; 1619 testFindRoot(&alefeld1, -9, 31); 1620 int [] nvals_3 = [1, 2, 5, 10, 15, 20]; 1621 int [] nvals_5 = [1, 2, 4, 5, 8, 15, 20]; 1622 int [] nvals_6 = [1, 5, 10, 15, 20]; 1623 int [] nvals_7 = [2, 5, 15, 20]; 1624 1625 for (int i=4; i<12; i+=2) 1626 { 1627 n = i; 1628 ale_a = 0.2; 1629 testFindRoot(&alefeld2, 0, 5); 1630 ale_a=1; 1631 testFindRoot(&alefeld2, 0.95, 4.05); 1632 testFindRoot(&alefeld2, 0, 1.5); 1633 } 1634 foreach (i; nvals_3) 1635 { 1636 n=i; 1637 testFindRoot(&alefeld3, 0, 1); 1638 } 1639 foreach (i; nvals_3) 1640 { 1641 n=i; 1642 testFindRoot(&alefeld4, 0, 1); 1643 } 1644 foreach (i; nvals_5) 1645 { 1646 n=i; 1647 testFindRoot(&alefeld5, 0, 1); 1648 } 1649 foreach (i; nvals_6) 1650 { 1651 n=i; 1652 testFindRoot(&alefeld6, 0, 1); 1653 } 1654 foreach (i; nvals_7) 1655 { 1656 n=i; 1657 testFindRoot(&alefeld7, 0.01L, 1); 1658 } 1659 real worstcase(real x) 1660 { 1661 ++numCalls; 1662 return x<0.3*real.max? -0.999e-3 : 1.0; 1663 } 1664 testFindRoot(&worstcase, -real.max, real.max); 1665 1666 // just check that the double + float cases compile 1667 findRoot((double x){ return 0.0; }, -double.max, double.max); 1668 findRoot((float x){ return 0.0f; }, -float.max, float.max); 1669 1670 /* 1671 int grandtotal=0; 1672 foreach (calls; alefeldSums) 1673 { 1674 grandtotal+=calls; 1675 } 1676 grandtotal-=2*numProblems; 1677 printf("\nALEFELD TOTAL = %d avg = %f (alefeld avg=19.3 for double)\n", 1678 grandtotal, (1.0*grandtotal)/numProblems); 1679 powercalls -= 2*powerProblems; 1680 printf("POWER TOTAL = %d avg = %f ", powercalls, 1681 (1.0*powercalls)/powerProblems); 1682 */ 1683 // https://issues.dlang.org/show_bug.cgi?id=14231 1684 auto xp = findRoot((float x) => x, 0f, 1f); 1685 auto xn = findRoot((float x) => x, -1f, -0f); 1686 } 1687 1688 //regression control 1689 @system unittest 1690 { 1691 // @system due to the case in the 2nd line 1692 static assert(__traits(compiles, findRoot((float x)=>cast(real) x, float.init, float.init))); 1693 static assert(__traits(compiles, findRoot!real((x)=>cast(double) x, real.init, real.init))); 1694 static assert(__traits(compiles, findRoot((real x)=>cast(double) x, real.init, real.init))); 1695 } 1696 1697 /++ 1698 Find a real minimum of a real function `f(x)` via bracketing. 1699 Given a function `f` and a range `(ax .. bx)`, 1700 returns the value of `x` in the range which is closest to a minimum of `f(x)`. 1701 `f` is never evaluted at the endpoints of `ax` and `bx`. 1702 If `f(x)` has more than one minimum in the range, one will be chosen arbitrarily. 1703 If `f(x)` returns NaN or -Infinity, `(x, f(x), NaN)` will be returned; 1704 otherwise, this algorithm is guaranteed to succeed. 1705 1706 Params: 1707 f = Function to be analyzed 1708 ax = Left bound of initial range of f known to contain the minimum. 1709 bx = Right bound of initial range of f known to contain the minimum. 1710 relTolerance = Relative tolerance. 1711 absTolerance = Absolute tolerance. 1712 1713 Preconditions: 1714 `ax` and `bx` shall be finite reals. $(BR) 1715 `relTolerance` shall be normal positive real. $(BR) 1716 `absTolerance` shall be normal positive real no less then `T.epsilon*2`. 1717 1718 Returns: 1719 A tuple consisting of `x`, `y = f(x)` and `error = 3 * (absTolerance * fabs(x) + relTolerance)`. 1720 1721 The method used is a combination of golden section search and 1722 successive parabolic interpolation. Convergence is never much slower 1723 than that for a Fibonacci search. 1724 1725 References: 1726 "Algorithms for Minimization without Derivatives", Richard Brent, Prentice-Hall, Inc. (1973) 1727 1728 See_Also: $(LREF findRoot), $(REF isNormal, std,math) 1729 +/ 1730 Tuple!(T, "x", Unqual!(ReturnType!DF), "y", T, "error") 1731 findLocalMin(T, DF)( 1732 scope DF f, 1733 const T ax, 1734 const T bx, 1735 const T relTolerance = sqrt(T.epsilon), 1736 const T absTolerance = sqrt(T.epsilon), 1737 ) 1738 if (isFloatingPoint!T 1739 && __traits(compiles, {T _ = DF.init(T.init);})) 1740 in 1741 { 1742 assert(isFinite(ax), "ax is not finite"); 1743 assert(isFinite(bx), "bx is not finite"); 1744 assert(isNormal(relTolerance), "relTolerance is not normal floating point number"); 1745 assert(isNormal(absTolerance), "absTolerance is not normal floating point number"); 1746 assert(relTolerance >= 0, "absTolerance is not positive"); 1747 assert(absTolerance >= T.epsilon*2, "absTolerance is not greater then `2*T.epsilon`"); 1748 } 1749 out (result) 1750 { 1751 assert(isFinite(result.x)); 1752 } 1753 do 1754 { 1755 alias R = Unqual!(CommonType!(ReturnType!DF, T)); 1756 // c is the squared inverse of the golden ratio 1757 // (3 - sqrt(5))/2 1758 // Value obtained from Wolfram Alpha. 1759 enum T c = 0x0.61c8864680b583ea0c633f9fa31237p+0L; 1760 enum T cm1 = 0x0.9e3779b97f4a7c15f39cc0605cedc8p+0L; 1761 R tolerance; 1762 T a = ax > bx ? bx : ax; 1763 T b = ax > bx ? ax : bx; 1764 // sequence of declarations suitable for SIMD instructions 1765 T v = a * cm1 + b * c; 1766 assert(isFinite(v)); 1767 R fv = f(v); 1768 if (isNaN(fv) || fv == -T.infinity) 1769 { 1770 return typeof(return)(v, fv, T.init); 1771 } 1772 T w = v; 1773 R fw = fv; 1774 T x = v; 1775 R fx = fv; 1776 size_t i; 1777 for (R d = 0, e = 0;;) 1778 { 1779 i++; 1780 T m = (a + b) / 2; 1781 // This fix is not part of the original algorithm 1782 if (!isFinite(m)) // fix infinity loop. Issue can be reproduced in R. 1783 { 1784 m = a / 2 + b / 2; 1785 if (!isFinite(m)) // fast-math compiler switch is enabled 1786 { 1787 //SIMD instructions can be used by compiler, do not reduce declarations 1788 int a_exp = void; 1789 int b_exp = void; 1790 immutable an = frexp(a, a_exp); 1791 immutable bn = frexp(b, b_exp); 1792 immutable am = ldexp(an, a_exp-1); 1793 immutable bm = ldexp(bn, b_exp-1); 1794 m = am + bm; 1795 if (!isFinite(m)) // wrong input: constraints are disabled in release mode 1796 { 1797 return typeof(return).init; 1798 } 1799 } 1800 } 1801 tolerance = absTolerance * fabs(x) + relTolerance; 1802 immutable t2 = tolerance * 2; 1803 // check stopping criterion 1804 if (!(fabs(x - m) > t2 - (b - a) / 2)) 1805 { 1806 break; 1807 } 1808 R p = 0; 1809 R q = 0; 1810 R r = 0; 1811 // fit parabola 1812 if (fabs(e) > tolerance) 1813 { 1814 immutable xw = x - w; 1815 immutable fxw = fx - fw; 1816 immutable xv = x - v; 1817 immutable fxv = fx - fv; 1818 immutable xwfxv = xw * fxv; 1819 immutable xvfxw = xv * fxw; 1820 p = xv * xvfxw - xw * xwfxv; 1821 q = (xvfxw - xwfxv) * 2; 1822 if (q > 0) 1823 p = -p; 1824 else 1825 q = -q; 1826 r = e; 1827 e = d; 1828 } 1829 T u; 1830 // a parabolic-interpolation step 1831 if (fabs(p) < fabs(q * r / 2) && p > q * (a - x) && p < q * (b - x)) 1832 { 1833 d = p / q; 1834 u = x + d; 1835 // f must not be evaluated too close to a or b 1836 if (u - a < t2 || b - u < t2) 1837 d = x < m ? tolerance : -tolerance; 1838 } 1839 // a golden-section step 1840 else 1841 { 1842 e = (x < m ? b : a) - x; 1843 d = c * e; 1844 } 1845 // f must not be evaluated too close to x 1846 u = x + (fabs(d) >= tolerance ? d : d > 0 ? tolerance : -tolerance); 1847 immutable fu = f(u); 1848 if (isNaN(fu) || fu == -T.infinity) 1849 { 1850 return typeof(return)(u, fu, T.init); 1851 } 1852 // update a, b, v, w, and x 1853 if (fu <= fx) 1854 { 1855 (u < x ? b : a) = x; 1856 v = w; fv = fw; 1857 w = x; fw = fx; 1858 x = u; fx = fu; 1859 } 1860 else 1861 { 1862 (u < x ? a : b) = u; 1863 if (fu <= fw || w == x) 1864 { 1865 v = w; fv = fw; 1866 w = u; fw = fu; 1867 } 1868 else if (fu <= fv || v == x || v == w) 1869 { // do not remove this braces 1870 v = u; fv = fu; 1871 } 1872 } 1873 } 1874 return typeof(return)(x, fx, tolerance * 3); 1875 } 1876 1877 /// 1878 @safe unittest 1879 { 1880 import std.math.operations : isClose; 1881 1882 auto ret = findLocalMin((double x) => (x-4)^^2, -1e7, 1e7); 1883 assert(ret.x.isClose(4.0)); 1884 assert(ret.y.isClose(0.0, 0.0, 1e-10)); 1885 } 1886 1887 @safe unittest 1888 { 1889 import std.meta : AliasSeq; 1890 static foreach (T; AliasSeq!(double, float, real)) 1891 { 1892 { 1893 auto ret = findLocalMin!T((T x) => (x-4)^^2, T.min_normal, 1e7); 1894 assert(ret.x.isClose(T(4))); 1895 assert(ret.y.isClose(T(0), 0.0, T.epsilon)); 1896 } 1897 { 1898 auto ret = findLocalMin!T((T x) => fabs(x-1), -T.max/4, T.max/4, T.min_normal, 2*T.epsilon); 1899 assert(isClose(ret.x, T(1))); 1900 assert(isClose(ret.y, T(0), 0.0, T.epsilon)); 1901 assert(ret.error <= 10 * T.epsilon); 1902 } 1903 { 1904 auto ret = findLocalMin!T((T x) => T.init, 0, 1, T.min_normal, 2*T.epsilon); 1905 assert(!ret.x.isNaN); 1906 assert(ret.y.isNaN); 1907 assert(ret.error.isNaN); 1908 } 1909 { 1910 auto ret = findLocalMin!T((T x) => log(x), 0, 1, T.min_normal, 2*T.epsilon); 1911 assert(ret.error < 3.00001 * ((2*T.epsilon)*fabs(ret.x)+ T.min_normal)); 1912 assert(ret.x >= 0 && ret.x <= ret.error); 1913 } 1914 { 1915 auto ret = findLocalMin!T((T x) => log(x), 0, T.max, T.min_normal, 2*T.epsilon); 1916 assert(ret.y < -18); 1917 assert(ret.error < 5e-08); 1918 assert(ret.x >= 0 && ret.x <= ret.error); 1919 } 1920 { 1921 auto ret = findLocalMin!T((T x) => -fabs(x), -1, 1, T.min_normal, 2*T.epsilon); 1922 assert(ret.x.fabs.isClose(T(1))); 1923 assert(ret.y.fabs.isClose(T(1))); 1924 assert(ret.error.isClose(T(0), 0.0, 100*T.epsilon)); 1925 } 1926 } 1927 } 1928 1929 /** 1930 Computes $(LINK2 https://en.wikipedia.org/wiki/Euclidean_distance, 1931 Euclidean distance) between input ranges `a` and 1932 `b`. The two ranges must have the same length. The three-parameter 1933 version stops computation as soon as the distance is greater than or 1934 equal to `limit` (this is useful to save computation if a small 1935 distance is sought). 1936 */ 1937 CommonType!(ElementType!(Range1), ElementType!(Range2)) 1938 euclideanDistance(Range1, Range2)(Range1 a, Range2 b) 1939 if (isInputRange!(Range1) && isInputRange!(Range2)) 1940 { 1941 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 1942 static if (haveLen) assert(a.length == b.length); 1943 Unqual!(typeof(return)) result = 0; 1944 for (; !a.empty; a.popFront(), b.popFront()) 1945 { 1946 immutable t = a.front - b.front; 1947 result += t * t; 1948 } 1949 static if (!haveLen) assert(b.empty); 1950 return sqrt(result); 1951 } 1952 1953 /// Ditto 1954 CommonType!(ElementType!(Range1), ElementType!(Range2)) 1955 euclideanDistance(Range1, Range2, F)(Range1 a, Range2 b, F limit) 1956 if (isInputRange!(Range1) && isInputRange!(Range2)) 1957 { 1958 limit *= limit; 1959 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 1960 static if (haveLen) assert(a.length == b.length); 1961 Unqual!(typeof(return)) result = 0; 1962 for (; ; a.popFront(), b.popFront()) 1963 { 1964 if (a.empty) 1965 { 1966 static if (!haveLen) assert(b.empty); 1967 break; 1968 } 1969 immutable t = a.front - b.front; 1970 result += t * t; 1971 if (result >= limit) break; 1972 } 1973 return sqrt(result); 1974 } 1975 1976 @safe unittest 1977 { 1978 import std.meta : AliasSeq; 1979 static foreach (T; AliasSeq!(double, const double, immutable double)) 1980 {{ 1981 T[] a = [ 1.0, 2.0, ]; 1982 T[] b = [ 4.0, 6.0, ]; 1983 assert(euclideanDistance(a, b) == 5); 1984 assert(euclideanDistance(a, b, 6) == 5); 1985 assert(euclideanDistance(a, b, 5) == 5); 1986 assert(euclideanDistance(a, b, 4) == 5); 1987 assert(euclideanDistance(a, b, 2) == 3); 1988 }} 1989 } 1990 1991 /** 1992 Computes the $(LINK2 https://en.wikipedia.org/wiki/Dot_product, 1993 dot product) of input ranges `a` and $(D 1994 b). The two ranges must have the same length. If both ranges define 1995 length, the check is done once; otherwise, it is done at each 1996 iteration. 1997 */ 1998 CommonType!(ElementType!(Range1), ElementType!(Range2)) 1999 dotProduct(Range1, Range2)(Range1 a, Range2 b) 2000 if (isInputRange!(Range1) && isInputRange!(Range2) && 2001 !(isArray!(Range1) && isArray!(Range2))) 2002 { 2003 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2004 static if (haveLen) assert(a.length == b.length); 2005 Unqual!(typeof(return)) result = 0; 2006 for (; !a.empty; a.popFront(), b.popFront()) 2007 { 2008 result += a.front * b.front; 2009 } 2010 static if (!haveLen) assert(b.empty); 2011 return result; 2012 } 2013 2014 /// Ditto 2015 CommonType!(F1, F2) 2016 dotProduct(F1, F2)(in F1[] avector, in F2[] bvector) 2017 { 2018 immutable n = avector.length; 2019 assert(n == bvector.length); 2020 auto avec = avector.ptr, bvec = bvector.ptr; 2021 Unqual!(typeof(return)) sum0 = 0, sum1 = 0; 2022 2023 const all_endp = avec + n; 2024 const smallblock_endp = avec + (n & ~3); 2025 const bigblock_endp = avec + (n & ~15); 2026 2027 for (; avec != bigblock_endp; avec += 16, bvec += 16) 2028 { 2029 sum0 += avec[0] * bvec[0]; 2030 sum1 += avec[1] * bvec[1]; 2031 sum0 += avec[2] * bvec[2]; 2032 sum1 += avec[3] * bvec[3]; 2033 sum0 += avec[4] * bvec[4]; 2034 sum1 += avec[5] * bvec[5]; 2035 sum0 += avec[6] * bvec[6]; 2036 sum1 += avec[7] * bvec[7]; 2037 sum0 += avec[8] * bvec[8]; 2038 sum1 += avec[9] * bvec[9]; 2039 sum0 += avec[10] * bvec[10]; 2040 sum1 += avec[11] * bvec[11]; 2041 sum0 += avec[12] * bvec[12]; 2042 sum1 += avec[13] * bvec[13]; 2043 sum0 += avec[14] * bvec[14]; 2044 sum1 += avec[15] * bvec[15]; 2045 } 2046 2047 for (; avec != smallblock_endp; avec += 4, bvec += 4) 2048 { 2049 sum0 += avec[0] * bvec[0]; 2050 sum1 += avec[1] * bvec[1]; 2051 sum0 += avec[2] * bvec[2]; 2052 sum1 += avec[3] * bvec[3]; 2053 } 2054 2055 sum0 += sum1; 2056 2057 /* Do trailing portion in naive loop. */ 2058 while (avec != all_endp) 2059 { 2060 sum0 += *avec * *bvec; 2061 ++avec; 2062 ++bvec; 2063 } 2064 2065 return sum0; 2066 } 2067 2068 /// ditto 2069 F dotProduct(F, uint N)(const ref scope F[N] a, const ref scope F[N] b) 2070 if (N <= 16) 2071 { 2072 F sum0 = 0; 2073 F sum1 = 0; 2074 static foreach (i; 0 .. N / 2) 2075 { 2076 sum0 += a[i*2] * b[i*2]; 2077 sum1 += a[i*2+1] * b[i*2+1]; 2078 } 2079 static if (N % 2 == 1) 2080 { 2081 sum0 += a[N-1] * b[N-1]; 2082 } 2083 return sum0 + sum1; 2084 } 2085 2086 @system unittest 2087 { 2088 // @system due to dotProduct and assertCTFEable 2089 import std.exception : assertCTFEable; 2090 import std.meta : AliasSeq; 2091 static foreach (T; AliasSeq!(double, const double, immutable double)) 2092 {{ 2093 T[] a = [ 1.0, 2.0, ]; 2094 T[] b = [ 4.0, 6.0, ]; 2095 assert(dotProduct(a, b) == 16); 2096 assert(dotProduct([1, 3, -5], [4, -2, -1]) == 3); 2097 // Test with fixed-length arrays. 2098 T[2] c = [ 1.0, 2.0, ]; 2099 T[2] d = [ 4.0, 6.0, ]; 2100 assert(dotProduct(c, d) == 16); 2101 T[3] e = [1, 3, -5]; 2102 T[3] f = [4, -2, -1]; 2103 assert(dotProduct(e, f) == 3); 2104 }} 2105 2106 // Make sure the unrolled loop codepath gets tested. 2107 static const x = 2108 [1.0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]; 2109 static const y = 2110 [2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]; 2111 assertCTFEable!({ assert(dotProduct(x, y) == 4048); }); 2112 } 2113 2114 /** 2115 Computes the $(LINK2 https://en.wikipedia.org/wiki/Cosine_similarity, 2116 cosine similarity) of input ranges `a` and $(D 2117 b). The two ranges must have the same length. If both ranges define 2118 length, the check is done once; otherwise, it is done at each 2119 iteration. If either range has all-zero elements, return 0. 2120 */ 2121 CommonType!(ElementType!(Range1), ElementType!(Range2)) 2122 cosineSimilarity(Range1, Range2)(Range1 a, Range2 b) 2123 if (isInputRange!(Range1) && isInputRange!(Range2)) 2124 { 2125 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2126 static if (haveLen) assert(a.length == b.length); 2127 Unqual!(typeof(return)) norma = 0, normb = 0, dotprod = 0; 2128 for (; !a.empty; a.popFront(), b.popFront()) 2129 { 2130 immutable t1 = a.front, t2 = b.front; 2131 norma += t1 * t1; 2132 normb += t2 * t2; 2133 dotprod += t1 * t2; 2134 } 2135 static if (!haveLen) assert(b.empty); 2136 if (norma == 0 || normb == 0) return 0; 2137 return dotprod / sqrt(norma * normb); 2138 } 2139 2140 @safe unittest 2141 { 2142 import std.meta : AliasSeq; 2143 static foreach (T; AliasSeq!(double, const double, immutable double)) 2144 {{ 2145 T[] a = [ 1.0, 2.0, ]; 2146 T[] b = [ 4.0, 3.0, ]; 2147 assert(isClose( 2148 cosineSimilarity(a, b), 10.0 / sqrt(5.0 * 25), 2149 0.01)); 2150 }} 2151 } 2152 2153 /** 2154 Normalizes values in `range` by multiplying each element with a 2155 number chosen such that values sum up to `sum`. If elements in $(D 2156 range) sum to zero, assigns $(D sum / range.length) to 2157 all. Normalization makes sense only if all elements in `range` are 2158 positive. `normalize` assumes that is the case without checking it. 2159 2160 Returns: `true` if normalization completed normally, `false` if 2161 all elements in `range` were zero or if `range` is empty. 2162 */ 2163 bool normalize(R)(R range, ElementType!(R) sum = 1) 2164 if (isForwardRange!(R)) 2165 { 2166 ElementType!(R) s = 0; 2167 // Step 1: Compute sum and length of the range 2168 static if (hasLength!(R)) 2169 { 2170 const length = range.length; 2171 foreach (e; range) 2172 { 2173 s += e; 2174 } 2175 } 2176 else 2177 { 2178 uint length = 0; 2179 foreach (e; range) 2180 { 2181 s += e; 2182 ++length; 2183 } 2184 } 2185 // Step 2: perform normalization 2186 if (s == 0) 2187 { 2188 if (length) 2189 { 2190 immutable f = sum / range.length; 2191 foreach (ref e; range) e = f; 2192 } 2193 return false; 2194 } 2195 // The path most traveled 2196 assert(s >= 0); 2197 immutable f = sum / s; 2198 foreach (ref e; range) 2199 e *= f; 2200 return true; 2201 } 2202 2203 /// 2204 @safe unittest 2205 { 2206 double[] a = []; 2207 assert(!normalize(a)); 2208 a = [ 1.0, 3.0 ]; 2209 assert(normalize(a)); 2210 assert(a == [ 0.25, 0.75 ]); 2211 assert(normalize!(typeof(a))(a, 50)); // a = [12.5, 37.5] 2212 a = [ 0.0, 0.0 ]; 2213 assert(!normalize(a)); 2214 assert(a == [ 0.5, 0.5 ]); 2215 } 2216 2217 /** 2218 Compute the sum of binary logarithms of the input range `r`. 2219 The error of this method is much smaller than with a naive sum of log2. 2220 */ 2221 ElementType!Range sumOfLog2s(Range)(Range r) 2222 if (isInputRange!Range && isFloatingPoint!(ElementType!Range)) 2223 { 2224 long exp = 0; 2225 Unqual!(typeof(return)) x = 1; 2226 foreach (e; r) 2227 { 2228 if (e < 0) 2229 return typeof(return).nan; 2230 int lexp = void; 2231 x *= frexp(e, lexp); 2232 exp += lexp; 2233 if (x < 0.5) 2234 { 2235 x *= 2; 2236 exp--; 2237 } 2238 } 2239 return exp + log2(x); 2240 } 2241 2242 /// 2243 @safe unittest 2244 { 2245 import std.math.traits : isNaN; 2246 2247 assert(sumOfLog2s(new double[0]) == 0); 2248 assert(sumOfLog2s([0.0L]) == -real.infinity); 2249 assert(sumOfLog2s([-0.0L]) == -real.infinity); 2250 assert(sumOfLog2s([2.0L]) == 1); 2251 assert(sumOfLog2s([-2.0L]).isNaN()); 2252 assert(sumOfLog2s([real.nan]).isNaN()); 2253 assert(sumOfLog2s([-real.nan]).isNaN()); 2254 assert(sumOfLog2s([real.infinity]) == real.infinity); 2255 assert(sumOfLog2s([-real.infinity]).isNaN()); 2256 assert(sumOfLog2s([ 0.25, 0.25, 0.25, 0.125 ]) == -9); 2257 } 2258 2259 /** 2260 Computes $(LINK2 https://en.wikipedia.org/wiki/Entropy_(information_theory), 2261 _entropy) of input range `r` in bits. This 2262 function assumes (without checking) that the values in `r` are all 2263 in $(D [0, 1]). For the entropy to be meaningful, often `r` should 2264 be normalized too (i.e., its values should sum to 1). The 2265 two-parameter version stops evaluating as soon as the intermediate 2266 result is greater than or equal to `max`. 2267 */ 2268 ElementType!Range entropy(Range)(Range r) 2269 if (isInputRange!Range) 2270 { 2271 Unqual!(typeof(return)) result = 0.0; 2272 for (;!r.empty; r.popFront) 2273 { 2274 if (!r.front) continue; 2275 result -= r.front * log2(r.front); 2276 } 2277 return result; 2278 } 2279 2280 /// Ditto 2281 ElementType!Range entropy(Range, F)(Range r, F max) 2282 if (isInputRange!Range && 2283 !is(CommonType!(ElementType!Range, F) == void)) 2284 { 2285 Unqual!(typeof(return)) result = 0.0; 2286 for (;!r.empty; r.popFront) 2287 { 2288 if (!r.front) continue; 2289 result -= r.front * log2(r.front); 2290 if (result >= max) break; 2291 } 2292 return result; 2293 } 2294 2295 @safe unittest 2296 { 2297 import std.meta : AliasSeq; 2298 static foreach (T; AliasSeq!(double, const double, immutable double)) 2299 {{ 2300 T[] p = [ 0.0, 0, 0, 1 ]; 2301 assert(entropy(p) == 0); 2302 p = [ 0.25, 0.25, 0.25, 0.25 ]; 2303 assert(entropy(p) == 2); 2304 assert(entropy(p, 1) == 1); 2305 }} 2306 } 2307 2308 /** 2309 Computes the $(LINK2 https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence, 2310 Kullback-Leibler divergence) between input ranges 2311 `a` and `b`, which is the sum $(D ai * log(ai / bi)). The base 2312 of logarithm is 2. The ranges are assumed to contain elements in $(D 2313 [0, 1]). Usually the ranges are normalized probability distributions, 2314 but this is not required or checked by $(D 2315 kullbackLeiblerDivergence). If any element `bi` is zero and the 2316 corresponding element `ai` nonzero, returns infinity. (Otherwise, 2317 if $(D ai == 0 && bi == 0), the term $(D ai * log(ai / bi)) is 2318 considered zero.) If the inputs are normalized, the result is 2319 positive. 2320 */ 2321 CommonType!(ElementType!Range1, ElementType!Range2) 2322 kullbackLeiblerDivergence(Range1, Range2)(Range1 a, Range2 b) 2323 if (isInputRange!(Range1) && isInputRange!(Range2)) 2324 { 2325 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2326 static if (haveLen) assert(a.length == b.length); 2327 Unqual!(typeof(return)) result = 0; 2328 for (; !a.empty; a.popFront(), b.popFront()) 2329 { 2330 immutable t1 = a.front; 2331 if (t1 == 0) continue; 2332 immutable t2 = b.front; 2333 if (t2 == 0) return result.infinity; 2334 assert(t1 > 0 && t2 > 0); 2335 result += t1 * log2(t1 / t2); 2336 } 2337 static if (!haveLen) assert(b.empty); 2338 return result; 2339 } 2340 2341 /// 2342 @safe unittest 2343 { 2344 import std.math.operations : isClose; 2345 2346 double[] p = [ 0.0, 0, 0, 1 ]; 2347 assert(kullbackLeiblerDivergence(p, p) == 0); 2348 double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ]; 2349 assert(kullbackLeiblerDivergence(p1, p1) == 0); 2350 assert(kullbackLeiblerDivergence(p, p1) == 2); 2351 assert(kullbackLeiblerDivergence(p1, p) == double.infinity); 2352 double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ]; 2353 assert(isClose(kullbackLeiblerDivergence(p1, p2), 0.0719281, 1e-5)); 2354 assert(isClose(kullbackLeiblerDivergence(p2, p1), 0.0780719, 1e-5)); 2355 } 2356 2357 /** 2358 Computes the $(LINK2 https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence, 2359 Jensen-Shannon divergence) between `a` and $(D 2360 b), which is the sum $(D (ai * log(2 * ai / (ai + bi)) + bi * log(2 * 2361 bi / (ai + bi))) / 2). The base of logarithm is 2. The ranges are 2362 assumed to contain elements in $(D [0, 1]). Usually the ranges are 2363 normalized probability distributions, but this is not required or 2364 checked by `jensenShannonDivergence`. If the inputs are normalized, 2365 the result is bounded within $(D [0, 1]). The three-parameter version 2366 stops evaluations as soon as the intermediate result is greater than 2367 or equal to `limit`. 2368 */ 2369 CommonType!(ElementType!Range1, ElementType!Range2) 2370 jensenShannonDivergence(Range1, Range2)(Range1 a, Range2 b) 2371 if (isInputRange!Range1 && isInputRange!Range2 && 2372 is(CommonType!(ElementType!Range1, ElementType!Range2))) 2373 { 2374 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2375 static if (haveLen) assert(a.length == b.length); 2376 Unqual!(typeof(return)) result = 0; 2377 for (; !a.empty; a.popFront(), b.popFront()) 2378 { 2379 immutable t1 = a.front; 2380 immutable t2 = b.front; 2381 immutable avg = (t1 + t2) / 2; 2382 if (t1 != 0) 2383 { 2384 result += t1 * log2(t1 / avg); 2385 } 2386 if (t2 != 0) 2387 { 2388 result += t2 * log2(t2 / avg); 2389 } 2390 } 2391 static if (!haveLen) assert(b.empty); 2392 return result / 2; 2393 } 2394 2395 /// Ditto 2396 CommonType!(ElementType!Range1, ElementType!Range2) 2397 jensenShannonDivergence(Range1, Range2, F)(Range1 a, Range2 b, F limit) 2398 if (isInputRange!Range1 && isInputRange!Range2 && 2399 is(typeof(CommonType!(ElementType!Range1, ElementType!Range2).init 2400 >= F.init) : bool)) 2401 { 2402 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2403 static if (haveLen) assert(a.length == b.length); 2404 Unqual!(typeof(return)) result = 0; 2405 limit *= 2; 2406 for (; !a.empty; a.popFront(), b.popFront()) 2407 { 2408 immutable t1 = a.front; 2409 immutable t2 = b.front; 2410 immutable avg = (t1 + t2) / 2; 2411 if (t1 != 0) 2412 { 2413 result += t1 * log2(t1 / avg); 2414 } 2415 if (t2 != 0) 2416 { 2417 result += t2 * log2(t2 / avg); 2418 } 2419 if (result >= limit) break; 2420 } 2421 static if (!haveLen) assert(b.empty); 2422 return result / 2; 2423 } 2424 2425 /// 2426 @safe unittest 2427 { 2428 import std.math.operations : isClose; 2429 2430 double[] p = [ 0.0, 0, 0, 1 ]; 2431 assert(jensenShannonDivergence(p, p) == 0); 2432 double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ]; 2433 assert(jensenShannonDivergence(p1, p1) == 0); 2434 assert(isClose(jensenShannonDivergence(p1, p), 0.548795, 1e-5)); 2435 double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ]; 2436 assert(isClose(jensenShannonDivergence(p1, p2), 0.0186218, 1e-5)); 2437 assert(isClose(jensenShannonDivergence(p2, p1), 0.0186218, 1e-5)); 2438 assert(isClose(jensenShannonDivergence(p2, p1, 0.005), 0.00602366, 1e-5)); 2439 } 2440 2441 /** 2442 The so-called "all-lengths gap-weighted string kernel" computes a 2443 similarity measure between `s` and `t` based on all of their 2444 common subsequences of all lengths. Gapped subsequences are also 2445 included. 2446 2447 To understand what $(D gapWeightedSimilarity(s, t, lambda)) computes, 2448 consider first the case $(D lambda = 1) and the strings $(D s = 2449 ["Hello", "brave", "new", "world"]) and $(D t = ["Hello", "new", 2450 "world"]). In that case, `gapWeightedSimilarity` counts the 2451 following matches: 2452 2453 $(OL $(LI three matches of length 1, namely `"Hello"`, `"new"`, 2454 and `"world"`;) $(LI three matches of length 2, namely ($(D 2455 "Hello", "new")), ($(D "Hello", "world")), and ($(D "new", "world"));) 2456 $(LI one match of length 3, namely ($(D "Hello", "new", "world")).)) 2457 2458 The call $(D gapWeightedSimilarity(s, t, 1)) simply counts all of 2459 these matches and adds them up, returning 7. 2460 2461 ---- 2462 string[] s = ["Hello", "brave", "new", "world"]; 2463 string[] t = ["Hello", "new", "world"]; 2464 assert(gapWeightedSimilarity(s, t, 1) == 7); 2465 ---- 2466 2467 Note how the gaps in matching are simply ignored, for example ($(D 2468 "Hello", "new")) is deemed as good a match as ($(D "new", 2469 "world")). This may be too permissive for some applications. To 2470 eliminate gapped matches entirely, use $(D lambda = 0): 2471 2472 ---- 2473 string[] s = ["Hello", "brave", "new", "world"]; 2474 string[] t = ["Hello", "new", "world"]; 2475 assert(gapWeightedSimilarity(s, t, 0) == 4); 2476 ---- 2477 2478 The call above eliminated the gapped matches ($(D "Hello", "new")), 2479 ($(D "Hello", "world")), and ($(D "Hello", "new", "world")) from the 2480 tally. That leaves only 4 matches. 2481 2482 The most interesting case is when gapped matches still participate in 2483 the result, but not as strongly as ungapped matches. The result will 2484 be a smooth, fine-grained similarity measure between the input 2485 strings. This is where values of `lambda` between 0 and 1 enter 2486 into play: gapped matches are $(I exponentially penalized with the 2487 number of gaps) with base `lambda`. This means that an ungapped 2488 match adds 1 to the return value; a match with one gap in either 2489 string adds `lambda` to the return value; ...; a match with a total 2490 of `n` gaps in both strings adds $(D pow(lambda, n)) to the return 2491 value. In the example above, we have 4 matches without gaps, 2 matches 2492 with one gap, and 1 match with three gaps. The latter match is ($(D 2493 "Hello", "world")), which has two gaps in the first string and one gap 2494 in the second string, totaling to three gaps. Summing these up we get 2495 $(D 4 + 2 * lambda + pow(lambda, 3)). 2496 2497 ---- 2498 string[] s = ["Hello", "brave", "new", "world"]; 2499 string[] t = ["Hello", "new", "world"]; 2500 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 0.5 * 2 + 0.125); 2501 ---- 2502 2503 `gapWeightedSimilarity` is useful wherever a smooth similarity 2504 measure between sequences allowing for approximate matches is 2505 needed. The examples above are given with words, but any sequences 2506 with elements comparable for equality are allowed, e.g. characters or 2507 numbers. `gapWeightedSimilarity` uses a highly optimized dynamic 2508 programming implementation that needs $(D 16 * min(s.length, 2509 t.length)) extra bytes of memory and $(BIGOH s.length * t.length) time 2510 to complete. 2511 */ 2512 F gapWeightedSimilarity(alias comp = "a == b", R1, R2, F)(R1 s, R2 t, F lambda) 2513 if (isRandomAccessRange!(R1) && hasLength!(R1) && 2514 isRandomAccessRange!(R2) && hasLength!(R2)) 2515 { 2516 import core.exception : onOutOfMemoryError; 2517 import core.stdc.stdlib : malloc, free; 2518 import std.algorithm.mutation : swap; 2519 import std.functional : binaryFun; 2520 2521 if (s.length < t.length) return gapWeightedSimilarity(t, s, lambda); 2522 if (!t.length) return 0; 2523 2524 auto dpvi = cast(F*) malloc(F.sizeof * 2 * t.length); 2525 if (!dpvi) 2526 onOutOfMemoryError(); 2527 2528 auto dpvi1 = dpvi + t.length; 2529 scope(exit) free(dpvi < dpvi1 ? dpvi : dpvi1); 2530 dpvi[0 .. t.length] = 0; 2531 dpvi1[0] = 0; 2532 immutable lambda2 = lambda * lambda; 2533 2534 F result = 0; 2535 foreach (i; 0 .. s.length) 2536 { 2537 const si = s[i]; 2538 for (size_t j = 0;;) 2539 { 2540 F dpsij = void; 2541 if (binaryFun!(comp)(si, t[j])) 2542 { 2543 dpsij = 1 + dpvi[j]; 2544 result += dpsij; 2545 } 2546 else 2547 { 2548 dpsij = 0; 2549 } 2550 immutable j1 = j + 1; 2551 if (j1 == t.length) break; 2552 dpvi1[j1] = dpsij + lambda * (dpvi1[j] + dpvi[j1]) - 2553 lambda2 * dpvi[j]; 2554 j = j1; 2555 } 2556 swap(dpvi, dpvi1); 2557 } 2558 return result; 2559 } 2560 2561 @system unittest 2562 { 2563 string[] s = ["Hello", "brave", "new", "world"]; 2564 string[] t = ["Hello", "new", "world"]; 2565 assert(gapWeightedSimilarity(s, t, 1) == 7); 2566 assert(gapWeightedSimilarity(s, t, 0) == 4); 2567 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 2 * 0.5 + 0.125); 2568 } 2569 2570 /** 2571 The similarity per `gapWeightedSimilarity` has an issue in that it 2572 grows with the lengths of the two strings, even though the strings are 2573 not actually very similar. For example, the range $(D ["Hello", 2574 "world"]) is increasingly similar with the range $(D ["Hello", 2575 "world", "world", "world",...]) as more instances of `"world"` are 2576 appended. To prevent that, `gapWeightedSimilarityNormalized` 2577 computes a normalized version of the similarity that is computed as 2578 $(D gapWeightedSimilarity(s, t, lambda) / 2579 sqrt(gapWeightedSimilarity(s, t, lambda) * gapWeightedSimilarity(s, t, 2580 lambda))). The function `gapWeightedSimilarityNormalized` (a 2581 so-called normalized kernel) is bounded in $(D [0, 1]), reaches `0` 2582 only for ranges that don't match in any position, and `1` only for 2583 identical ranges. 2584 2585 The optional parameters `sSelfSim` and `tSelfSim` are meant for 2586 avoiding duplicate computation. Many applications may have already 2587 computed $(D gapWeightedSimilarity(s, s, lambda)) and/or $(D 2588 gapWeightedSimilarity(t, t, lambda)). In that case, they can be passed 2589 as `sSelfSim` and `tSelfSim`, respectively. 2590 */ 2591 Select!(isFloatingPoint!(F), F, double) 2592 gapWeightedSimilarityNormalized(alias comp = "a == b", R1, R2, F) 2593 (R1 s, R2 t, F lambda, F sSelfSim = F.init, F tSelfSim = F.init) 2594 if (isRandomAccessRange!(R1) && hasLength!(R1) && 2595 isRandomAccessRange!(R2) && hasLength!(R2)) 2596 { 2597 static bool uncomputed(F n) 2598 { 2599 static if (isFloatingPoint!(F)) 2600 return isNaN(n); 2601 else 2602 return n == n.init; 2603 } 2604 if (uncomputed(sSelfSim)) 2605 sSelfSim = gapWeightedSimilarity!(comp)(s, s, lambda); 2606 if (sSelfSim == 0) return 0; 2607 if (uncomputed(tSelfSim)) 2608 tSelfSim = gapWeightedSimilarity!(comp)(t, t, lambda); 2609 if (tSelfSim == 0) return 0; 2610 2611 return gapWeightedSimilarity!(comp)(s, t, lambda) / 2612 sqrt(cast(typeof(return)) sSelfSim * tSelfSim); 2613 } 2614 2615 /// 2616 @system unittest 2617 { 2618 import std.math.operations : isClose; 2619 import std.math.algebraic : sqrt; 2620 2621 string[] s = ["Hello", "brave", "new", "world"]; 2622 string[] t = ["Hello", "new", "world"]; 2623 assert(gapWeightedSimilarity(s, s, 1) == 15); 2624 assert(gapWeightedSimilarity(t, t, 1) == 7); 2625 assert(gapWeightedSimilarity(s, t, 1) == 7); 2626 assert(isClose(gapWeightedSimilarityNormalized(s, t, 1), 2627 7.0 / sqrt(15.0 * 7), 0.01)); 2628 } 2629 2630 /** 2631 Similar to `gapWeightedSimilarity`, just works in an incremental 2632 manner by first revealing the matches of length 1, then gapped matches 2633 of length 2, and so on. The memory requirement is $(BIGOH s.length * 2634 t.length). The time complexity is $(BIGOH s.length * t.length) time 2635 for computing each step. Continuing on the previous example: 2636 2637 The implementation is based on the pseudocode in Fig. 4 of the paper 2638 $(HTTP jmlr.csail.mit.edu/papers/volume6/rousu05a/rousu05a.pdf, 2639 "Efficient Computation of Gapped Substring Kernels on Large Alphabets") 2640 by Rousu et al., with additional algorithmic and systems-level 2641 optimizations. 2642 */ 2643 struct GapWeightedSimilarityIncremental(Range, F = double) 2644 if (isRandomAccessRange!(Range) && hasLength!(Range)) 2645 { 2646 import core.stdc.stdlib : malloc, realloc, alloca, free; 2647 2648 private: 2649 Range s, t; 2650 F currentValue = 0; 2651 F* kl; 2652 size_t gram = void; 2653 F lambda = void, lambda2 = void; 2654 2655 public: 2656 /** 2657 Constructs an object given two ranges `s` and `t` and a penalty 2658 `lambda`. Constructor completes in $(BIGOH s.length * t.length) 2659 time and computes all matches of length 1. 2660 */ 2661 this(Range s, Range t, F lambda) 2662 { 2663 import core.exception : onOutOfMemoryError; 2664 2665 assert(lambda > 0); 2666 this.gram = 0; 2667 this.lambda = lambda; 2668 this.lambda2 = lambda * lambda; // for efficiency only 2669 2670 size_t iMin = size_t.max, jMin = size_t.max, 2671 iMax = 0, jMax = 0; 2672 /* initialize */ 2673 Tuple!(size_t, size_t) * k0; 2674 size_t k0len; 2675 scope(exit) free(k0); 2676 currentValue = 0; 2677 foreach (i, si; s) 2678 { 2679 foreach (j; 0 .. t.length) 2680 { 2681 if (si != t[j]) continue; 2682 k0 = cast(typeof(k0)) realloc(k0, ++k0len * (*k0).sizeof); 2683 with (k0[k0len - 1]) 2684 { 2685 field[0] = i; 2686 field[1] = j; 2687 } 2688 // Maintain the minimum and maximum i and j 2689 if (iMin > i) iMin = i; 2690 if (iMax < i) iMax = i; 2691 if (jMin > j) jMin = j; 2692 if (jMax < j) jMax = j; 2693 } 2694 } 2695 2696 if (iMin > iMax) return; 2697 assert(k0len); 2698 2699 currentValue = k0len; 2700 // Chop strings down to the useful sizes 2701 s = s[iMin .. iMax + 1]; 2702 t = t[jMin .. jMax + 1]; 2703 this.s = s; 2704 this.t = t; 2705 2706 kl = cast(F*) malloc(s.length * t.length * F.sizeof); 2707 if (!kl) 2708 onOutOfMemoryError(); 2709 2710 kl[0 .. s.length * t.length] = 0; 2711 foreach (pos; 0 .. k0len) 2712 { 2713 with (k0[pos]) 2714 { 2715 kl[(field[0] - iMin) * t.length + field[1] -jMin] = lambda2; 2716 } 2717 } 2718 } 2719 2720 /** 2721 Returns: `this`. 2722 */ 2723 ref GapWeightedSimilarityIncremental opSlice() 2724 { 2725 return this; 2726 } 2727 2728 /** 2729 Computes the match of the popFront length. Completes in $(BIGOH s.length * 2730 t.length) time. 2731 */ 2732 void popFront() 2733 { 2734 import std.algorithm.mutation : swap; 2735 2736 // This is a large source of optimization: if similarity at 2737 // the gram-1 level was 0, then we can safely assume 2738 // similarity at the gram level is 0 as well. 2739 if (empty) return; 2740 2741 // Now attempt to match gapped substrings of length `gram' 2742 ++gram; 2743 currentValue = 0; 2744 2745 auto Si = cast(F*) alloca(t.length * F.sizeof); 2746 Si[0 .. t.length] = 0; 2747 foreach (i; 0 .. s.length) 2748 { 2749 const si = s[i]; 2750 F Sij_1 = 0; 2751 F Si_1j_1 = 0; 2752 auto kli = kl + i * t.length; 2753 for (size_t j = 0;;) 2754 { 2755 const klij = kli[j]; 2756 const Si_1j = Si[j]; 2757 const tmp = klij + lambda * (Si_1j + Sij_1) - lambda2 * Si_1j_1; 2758 // now update kl and currentValue 2759 if (si == t[j]) 2760 currentValue += kli[j] = lambda2 * Si_1j_1; 2761 else 2762 kli[j] = 0; 2763 // commit to Si 2764 Si[j] = tmp; 2765 if (++j == t.length) break; 2766 // get ready for the popFront step; virtually increment j, 2767 // so essentially stuffj_1 <-- stuffj 2768 Si_1j_1 = Si_1j; 2769 Sij_1 = tmp; 2770 } 2771 } 2772 currentValue /= pow(lambda, 2 * (gram + 1)); 2773 2774 version (none) 2775 { 2776 Si_1[0 .. t.length] = 0; 2777 kl[0 .. min(t.length, maxPerimeter + 1)] = 0; 2778 foreach (i; 1 .. min(s.length, maxPerimeter + 1)) 2779 { 2780 auto kli = kl + i * t.length; 2781 assert(s.length > i); 2782 const si = s[i]; 2783 auto kl_1i_1 = kl_1 + (i - 1) * t.length; 2784 kli[0] = 0; 2785 F lastS = 0; 2786 foreach (j; 1 .. min(maxPerimeter - i + 1, t.length)) 2787 { 2788 immutable j_1 = j - 1; 2789 immutable tmp = kl_1i_1[j_1] 2790 + lambda * (Si_1[j] + lastS) 2791 - lambda2 * Si_1[j_1]; 2792 kl_1i_1[j_1] = float.nan; 2793 Si_1[j_1] = lastS; 2794 lastS = tmp; 2795 if (si == t[j]) 2796 { 2797 currentValue += kli[j] = lambda2 * lastS; 2798 } 2799 else 2800 { 2801 kli[j] = 0; 2802 } 2803 } 2804 Si_1[t.length - 1] = lastS; 2805 } 2806 currentValue /= pow(lambda, 2 * (gram + 1)); 2807 // get ready for the popFront computation 2808 swap(kl, kl_1); 2809 } 2810 } 2811 2812 /** 2813 Returns: The gapped similarity at the current match length (initially 2814 1, grows with each call to `popFront`). 2815 */ 2816 @property F front() { return currentValue; } 2817 2818 /** 2819 Returns: Whether there are more matches. 2820 */ 2821 @property bool empty() 2822 { 2823 if (currentValue) return false; 2824 if (kl) 2825 { 2826 free(kl); 2827 kl = null; 2828 } 2829 return true; 2830 } 2831 } 2832 2833 /** 2834 Ditto 2835 */ 2836 GapWeightedSimilarityIncremental!(R, F) gapWeightedSimilarityIncremental(R, F) 2837 (R r1, R r2, F penalty) 2838 { 2839 return typeof(return)(r1, r2, penalty); 2840 } 2841 2842 /// 2843 @system unittest 2844 { 2845 string[] s = ["Hello", "brave", "new", "world"]; 2846 string[] t = ["Hello", "new", "world"]; 2847 auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0); 2848 assert(simIter.front == 3); // three 1-length matches 2849 simIter.popFront(); 2850 assert(simIter.front == 3); // three 2-length matches 2851 simIter.popFront(); 2852 assert(simIter.front == 1); // one 3-length match 2853 simIter.popFront(); 2854 assert(simIter.empty); // no more match 2855 } 2856 2857 @system unittest 2858 { 2859 import std.conv : text; 2860 string[] s = ["Hello", "brave", "new", "world"]; 2861 string[] t = ["Hello", "new", "world"]; 2862 auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0); 2863 //foreach (e; simIter) writeln(e); 2864 assert(simIter.front == 3); // three 1-length matches 2865 simIter.popFront(); 2866 assert(simIter.front == 3, text(simIter.front)); // three 2-length matches 2867 simIter.popFront(); 2868 assert(simIter.front == 1); // one 3-length matches 2869 simIter.popFront(); 2870 assert(simIter.empty); // no more match 2871 2872 s = ["Hello"]; 2873 t = ["bye"]; 2874 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2875 assert(simIter.empty); 2876 2877 s = ["Hello"]; 2878 t = ["Hello"]; 2879 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2880 assert(simIter.front == 1); // one match 2881 simIter.popFront(); 2882 assert(simIter.empty); 2883 2884 s = ["Hello", "world"]; 2885 t = ["Hello"]; 2886 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2887 assert(simIter.front == 1); // one match 2888 simIter.popFront(); 2889 assert(simIter.empty); 2890 2891 s = ["Hello", "world"]; 2892 t = ["Hello", "yah", "world"]; 2893 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2894 assert(simIter.front == 2); // two 1-gram matches 2895 simIter.popFront(); 2896 assert(simIter.front == 0.5, text(simIter.front)); // one 2-gram match, 1 gap 2897 } 2898 2899 @system unittest 2900 { 2901 GapWeightedSimilarityIncremental!(string[]) sim = 2902 GapWeightedSimilarityIncremental!(string[])( 2903 ["nyuk", "I", "have", "no", "chocolate", "giba"], 2904 ["wyda", "I", "have", "I", "have", "have", "I", "have", "hehe"], 2905 0.5); 2906 double[] witness = [ 7.0, 4.03125, 0, 0 ]; 2907 foreach (e; sim) 2908 { 2909 //writeln(e); 2910 assert(e == witness.front); 2911 witness.popFront(); 2912 } 2913 witness = [ 3.0, 1.3125, 0.25 ]; 2914 sim = GapWeightedSimilarityIncremental!(string[])( 2915 ["I", "have", "no", "chocolate"], 2916 ["I", "have", "some", "chocolate"], 2917 0.5); 2918 foreach (e; sim) 2919 { 2920 //writeln(e); 2921 assert(e == witness.front); 2922 witness.popFront(); 2923 } 2924 assert(witness.empty); 2925 } 2926 2927 /** 2928 Computes the greatest common divisor of `a` and `b` by using 2929 an efficient algorithm such as $(HTTPS en.wikipedia.org/wiki/Euclidean_algorithm, Euclid's) 2930 or $(HTTPS en.wikipedia.org/wiki/Binary_GCD_algorithm, Stein's) algorithm. 2931 2932 Params: 2933 a = Integer value of any numerical type that supports the modulo operator `%`. 2934 If bit-shifting `<<` and `>>` are also supported, Stein's algorithm will 2935 be used; otherwise, Euclid's algorithm is used as _a fallback. 2936 b = Integer value of any equivalent numerical type. 2937 2938 Returns: 2939 The greatest common divisor of the given arguments. 2940 */ 2941 typeof(Unqual!(T).init % Unqual!(U).init) gcd(T, U)(T a, U b) 2942 if (isIntegral!T && isIntegral!U) 2943 { 2944 // Operate on a common type between the two arguments. 2945 alias UCT = Unsigned!(CommonType!(Unqual!T, Unqual!U)); 2946 2947 // `std.math.abs` doesn't support unsigned integers, and `T.min` is undefined. 2948 static if (is(T : immutable short) || is(T : immutable byte)) 2949 UCT ax = (isUnsigned!T || a >= 0) ? a : cast(UCT) -int(a); 2950 else 2951 UCT ax = (isUnsigned!T || a >= 0) ? a : -UCT(a); 2952 2953 static if (is(U : immutable short) || is(U : immutable byte)) 2954 UCT bx = (isUnsigned!U || b >= 0) ? b : cast(UCT) -int(b); 2955 else 2956 UCT bx = (isUnsigned!U || b >= 0) ? b : -UCT(b); 2957 2958 // Special cases. 2959 if (ax == 0) 2960 return bx; 2961 if (bx == 0) 2962 return ax; 2963 2964 return gcdImpl(ax, bx); 2965 } 2966 2967 private typeof(T.init % T.init) gcdImpl(T)(T a, T b) 2968 if (isIntegral!T) 2969 { 2970 pragma(inline, true); 2971 import core.bitop : bsf; 2972 import std.algorithm.mutation : swap; 2973 2974 immutable uint shift = bsf(a | b); 2975 a >>= a.bsf; 2976 do 2977 { 2978 b >>= b.bsf; 2979 if (a > b) 2980 swap(a, b); 2981 b -= a; 2982 } while (b); 2983 2984 return a << shift; 2985 } 2986 2987 /// 2988 @safe unittest 2989 { 2990 assert(gcd(2 * 5 * 7 * 7, 5 * 7 * 11) == 5 * 7); 2991 const int a = 5 * 13 * 23 * 23, b = 13 * 59; 2992 assert(gcd(a, b) == 13); 2993 } 2994 2995 @safe unittest 2996 { 2997 import std.meta : AliasSeq; 2998 static foreach (T; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong, 2999 const byte, const short, const int, const long, 3000 immutable ubyte, immutable ushort, immutable uint, immutable ulong)) 3001 { 3002 static foreach (U; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong, 3003 const ubyte, const ushort, const uint, const ulong, 3004 immutable byte, immutable short, immutable int, immutable long)) 3005 { 3006 // Signed and unsigned tests. 3007 static if (T.max > byte.max && U.max > byte.max) 3008 assert(gcd(T(200), U(200)) == 200); 3009 static if (T.max > ubyte.max) 3010 { 3011 assert(gcd(T(2000), U(20)) == 20); 3012 assert(gcd(T(2011), U(17)) == 1); 3013 } 3014 static if (T.max > ubyte.max && U.max > ubyte.max) 3015 assert(gcd(T(1071), U(462)) == 21); 3016 3017 assert(gcd(T(0), U(13)) == 13); 3018 assert(gcd(T(29), U(0)) == 29); 3019 assert(gcd(T(0), U(0)) == 0); 3020 assert(gcd(T(1), U(2)) == 1); 3021 assert(gcd(T(9), U(6)) == 3); 3022 assert(gcd(T(3), U(4)) == 1); 3023 assert(gcd(T(32), U(24)) == 8); 3024 assert(gcd(T(5), U(6)) == 1); 3025 assert(gcd(T(54), U(36)) == 18); 3026 3027 // Int and Long tests. 3028 static if (T.max > short.max && U.max > short.max) 3029 assert(gcd(T(46391), U(62527)) == 2017); 3030 static if (T.max > ushort.max && U.max > ushort.max) 3031 assert(gcd(T(63245986), U(39088169)) == 1); 3032 static if (T.max > uint.max && U.max > uint.max) 3033 { 3034 assert(gcd(T(77160074263), U(47687519812)) == 1); 3035 assert(gcd(T(77160074264), U(47687519812)) == 4); 3036 } 3037 3038 // Negative tests. 3039 static if (T.min < 0) 3040 { 3041 assert(gcd(T(-21), U(28)) == 7); 3042 assert(gcd(T(-3), U(4)) == 1); 3043 } 3044 static if (U.min < 0) 3045 { 3046 assert(gcd(T(1), U(-2)) == 1); 3047 assert(gcd(T(33), U(-44)) == 11); 3048 } 3049 static if (T.min < 0 && U.min < 0) 3050 { 3051 assert(gcd(T(-5), U(-6)) == 1); 3052 assert(gcd(T(-50), U(-60)) == 10); 3053 } 3054 } 3055 } 3056 } 3057 3058 // https://issues.dlang.org/show_bug.cgi?id=21834 3059 @safe unittest 3060 { 3061 assert(gcd(-120, 10U) == 10); 3062 assert(gcd(120U, -10) == 10); 3063 assert(gcd(int.min, 0L) == 1L + int.max); 3064 assert(gcd(0L, int.min) == 1L + int.max); 3065 assert(gcd(int.min, 0L + int.min) == 1L + int.max); 3066 assert(gcd(int.min, 1L + int.max) == 1L + int.max); 3067 assert(gcd(short.min, 1U + short.max) == 1U + short.max); 3068 } 3069 3070 // This overload is for non-builtin numerical types like BigInt or 3071 // user-defined types. 3072 /// ditto 3073 auto gcd(T)(T a, T b) 3074 if (!isIntegral!T && 3075 is(typeof(T.init % T.init)) && 3076 is(typeof(T.init == 0 || T.init > 0))) 3077 { 3078 static if (!is(T == Unqual!T)) 3079 { 3080 return gcd!(Unqual!T)(a, b); 3081 } 3082 else 3083 { 3084 // Ensure arguments are unsigned. 3085 a = a >= 0 ? a : -a; 3086 b = b >= 0 ? b : -b; 3087 3088 // Special cases. 3089 if (a == 0) 3090 return b; 3091 if (b == 0) 3092 return a; 3093 3094 return gcdImpl(a, b); 3095 } 3096 } 3097 3098 private auto gcdImpl(T)(T a, T b) 3099 if (!isIntegral!T) 3100 { 3101 pragma(inline, true); 3102 import std.algorithm.mutation : swap; 3103 enum canUseBinaryGcd = is(typeof(() { 3104 T t, u; 3105 t <<= 1; 3106 t >>= 1; 3107 t -= u; 3108 bool b = (t & 1) == 0; 3109 swap(t, u); 3110 })); 3111 3112 static if (canUseBinaryGcd) 3113 { 3114 uint shift = 0; 3115 while ((a & 1) == 0 && (b & 1) == 0) 3116 { 3117 a >>= 1; 3118 b >>= 1; 3119 shift++; 3120 } 3121 3122 if ((a & 1) == 0) swap(a, b); 3123 3124 do 3125 { 3126 assert((a & 1) != 0); 3127 while ((b & 1) == 0) 3128 b >>= 1; 3129 if (a > b) 3130 swap(a, b); 3131 b -= a; 3132 } while (b); 3133 3134 return a << shift; 3135 } 3136 else 3137 { 3138 // The only thing we have is %; fallback to Euclidean algorithm. 3139 while (b != 0) 3140 { 3141 auto t = b; 3142 b = a % b; 3143 a = t; 3144 } 3145 return a; 3146 } 3147 } 3148 3149 // https://issues.dlang.org/show_bug.cgi?id=7102 3150 @system pure unittest 3151 { 3152 import std.bigint : BigInt; 3153 assert(gcd(BigInt("71_000_000_000_000_000_000"), 3154 BigInt("31_000_000_000_000_000_000")) == 3155 BigInt("1_000_000_000_000_000_000")); 3156 3157 assert(gcd(BigInt(0), BigInt(1234567)) == BigInt(1234567)); 3158 assert(gcd(BigInt(1234567), BigInt(0)) == BigInt(1234567)); 3159 } 3160 3161 @safe pure nothrow unittest 3162 { 3163 // A numerical type that only supports % and - (to force gcd implementation 3164 // to use Euclidean algorithm). 3165 struct CrippledInt 3166 { 3167 int impl; 3168 CrippledInt opBinary(string op : "%")(CrippledInt i) 3169 { 3170 return CrippledInt(impl % i.impl); 3171 } 3172 CrippledInt opUnary(string op : "-")() 3173 { 3174 return CrippledInt(-impl); 3175 } 3176 int opEquals(CrippledInt i) { return impl == i.impl; } 3177 int opEquals(int i) { return impl == i; } 3178 int opCmp(int i) { return (impl < i) ? -1 : (impl > i) ? 1 : 0; } 3179 } 3180 assert(gcd(CrippledInt(2310), CrippledInt(1309)) == CrippledInt(77)); 3181 assert(gcd(CrippledInt(-120), CrippledInt(10U)) == CrippledInt(10)); 3182 assert(gcd(CrippledInt(120U), CrippledInt(-10)) == CrippledInt(10)); 3183 } 3184 3185 // https://issues.dlang.org/show_bug.cgi?id=19514 3186 @system pure unittest 3187 { 3188 import std.bigint : BigInt; 3189 assert(gcd(BigInt(2), BigInt(1)) == BigInt(1)); 3190 } 3191 3192 // Issue 20924 3193 @safe unittest 3194 { 3195 import std.bigint : BigInt; 3196 const a = BigInt("123143238472389492934020"); 3197 const b = BigInt("902380489324729338420924"); 3198 assert(__traits(compiles, gcd(a, b))); 3199 } 3200 3201 // https://issues.dlang.org/show_bug.cgi?id=21834 3202 @safe unittest 3203 { 3204 import std.bigint : BigInt; 3205 assert(gcd(BigInt(-120), BigInt(10U)) == BigInt(10)); 3206 assert(gcd(BigInt(120U), BigInt(-10)) == BigInt(10)); 3207 assert(gcd(BigInt(int.min), BigInt(0L)) == BigInt(1L + int.max)); 3208 assert(gcd(BigInt(0L), BigInt(int.min)) == BigInt(1L + int.max)); 3209 assert(gcd(BigInt(int.min), BigInt(0L + int.min)) == BigInt(1L + int.max)); 3210 assert(gcd(BigInt(int.min), BigInt(1L + int.max)) == BigInt(1L + int.max)); 3211 assert(gcd(BigInt(short.min), BigInt(1U + short.max)) == BigInt(1U + short.max)); 3212 } 3213 3214 3215 /** 3216 Computes the least common multiple of `a` and `b`. 3217 Arguments are the same as $(MYREF gcd). 3218 3219 Returns: 3220 The least common multiple of the given arguments. 3221 */ 3222 typeof(Unqual!(T).init % Unqual!(U).init) lcm(T, U)(T a, U b) 3223 if (isIntegral!T && isIntegral!U) 3224 { 3225 // Operate on a common type between the two arguments. 3226 alias UCT = Unsigned!(CommonType!(Unqual!T, Unqual!U)); 3227 3228 // `std.math.abs` doesn't support unsigned integers, and `T.min` is undefined. 3229 static if (is(T : immutable short) || is(T : immutable byte)) 3230 UCT ax = (isUnsigned!T || a >= 0) ? a : cast(UCT) -int(a); 3231 else 3232 UCT ax = (isUnsigned!T || a >= 0) ? a : -UCT(a); 3233 3234 static if (is(U : immutable short) || is(U : immutable byte)) 3235 UCT bx = (isUnsigned!U || b >= 0) ? b : cast(UCT) -int(b); 3236 else 3237 UCT bx = (isUnsigned!U || b >= 0) ? b : -UCT(b); 3238 3239 // Special cases. 3240 if (ax == 0) 3241 return ax; 3242 if (bx == 0) 3243 return bx; 3244 3245 return (ax / gcdImpl(ax, bx)) * bx; 3246 } 3247 3248 /// 3249 @safe unittest 3250 { 3251 assert(lcm(1, 2) == 2); 3252 assert(lcm(3, 4) == 12); 3253 assert(lcm(5, 6) == 30); 3254 } 3255 3256 @safe unittest 3257 { 3258 import std.meta : AliasSeq; 3259 static foreach (T; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong, 3260 const byte, const short, const int, const long, 3261 immutable ubyte, immutable ushort, immutable uint, immutable ulong)) 3262 { 3263 static foreach (U; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong, 3264 const ubyte, const ushort, const uint, const ulong, 3265 immutable byte, immutable short, immutable int, immutable long)) 3266 { 3267 assert(lcm(T(21), U(6)) == 42); 3268 assert(lcm(T(41), U(0)) == 0); 3269 assert(lcm(T(0), U(7)) == 0); 3270 assert(lcm(T(0), U(0)) == 0); 3271 assert(lcm(T(1U), U(2)) == 2); 3272 assert(lcm(T(3), U(4U)) == 12); 3273 assert(lcm(T(5U), U(6U)) == 30); 3274 static if (T.min < 0) 3275 assert(lcm(T(-42), U(21U)) == 42); 3276 } 3277 } 3278 } 3279 3280 /// ditto 3281 auto lcm(T)(T a, T b) 3282 if (!isIntegral!T && 3283 is(typeof(T.init % T.init)) && 3284 is(typeof(T.init == 0 || T.init > 0))) 3285 { 3286 // Ensure arguments are unsigned. 3287 a = a >= 0 ? a : -a; 3288 b = b >= 0 ? b : -b; 3289 3290 // Special cases. 3291 if (a == 0) 3292 return a; 3293 if (b == 0) 3294 return b; 3295 3296 return (a / gcdImpl(a, b)) * b; 3297 } 3298 3299 @safe unittest 3300 { 3301 import std.bigint : BigInt; 3302 assert(lcm(BigInt(21), BigInt(6)) == BigInt(42)); 3303 assert(lcm(BigInt(41), BigInt(0)) == BigInt(0)); 3304 assert(lcm(BigInt(0), BigInt(7)) == BigInt(0)); 3305 assert(lcm(BigInt(0), BigInt(0)) == BigInt(0)); 3306 assert(lcm(BigInt(1U), BigInt(2)) == BigInt(2)); 3307 assert(lcm(BigInt(3), BigInt(4U)) == BigInt(12)); 3308 assert(lcm(BigInt(5U), BigInt(6U)) == BigInt(30)); 3309 assert(lcm(BigInt(-42), BigInt(21U)) == BigInt(42)); 3310 } 3311 3312 // This is to make tweaking the speed/size vs. accuracy tradeoff easy, 3313 // though floats seem accurate enough for all practical purposes, since 3314 // they pass the "isClose(inverseFft(fft(arr)), arr)" test even for 3315 // size 2 ^^ 22. 3316 private alias lookup_t = float; 3317 3318 /**A class for performing fast Fourier transforms of power of two sizes. 3319 * This class encapsulates a large amount of state that is reusable when 3320 * performing multiple FFTs of sizes smaller than or equal to that specified 3321 * in the constructor. This results in substantial speedups when performing 3322 * multiple FFTs with a known maximum size. However, 3323 * a free function API is provided for convenience if you need to perform a 3324 * one-off FFT. 3325 * 3326 * References: 3327 * $(HTTP en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm) 3328 */ 3329 final class Fft 3330 { 3331 import core.bitop : bsf; 3332 import std.algorithm.iteration : map; 3333 import std.array : uninitializedArray; 3334 3335 private: 3336 immutable lookup_t[][] negSinLookup; 3337 3338 void enforceSize(R)(R range) const 3339 { 3340 import std.conv : text; 3341 assert(range.length <= size, text( 3342 "FFT size mismatch. Expected ", size, ", got ", range.length)); 3343 } 3344 3345 void fftImpl(Ret, R)(Stride!R range, Ret buf) const 3346 in 3347 { 3348 assert(range.length >= 4); 3349 assert(isPowerOf2(range.length)); 3350 } 3351 do 3352 { 3353 auto recurseRange = range; 3354 recurseRange.doubleSteps(); 3355 3356 if (buf.length > 4) 3357 { 3358 fftImpl(recurseRange, buf[0..$ / 2]); 3359 recurseRange.popHalf(); 3360 fftImpl(recurseRange, buf[$ / 2..$]); 3361 } 3362 else 3363 { 3364 // Do this here instead of in another recursion to save on 3365 // recursion overhead. 3366 slowFourier2(recurseRange, buf[0..$ / 2]); 3367 recurseRange.popHalf(); 3368 slowFourier2(recurseRange, buf[$ / 2..$]); 3369 } 3370 3371 butterfly(buf); 3372 } 3373 3374 // This algorithm works by performing the even and odd parts of our FFT 3375 // using the "two for the price of one" method mentioned at 3376 // http://www.engineeringproductivitytools.com/stuff/T0001/PT10.HTM#Head521 3377 // by making the odd terms into the imaginary components of our new FFT, 3378 // and then using symmetry to recombine them. 3379 void fftImplPureReal(Ret, R)(R range, Ret buf) const 3380 in 3381 { 3382 assert(range.length >= 4); 3383 assert(isPowerOf2(range.length)); 3384 } 3385 do 3386 { 3387 alias E = ElementType!R; 3388 3389 // Converts odd indices of range to the imaginary components of 3390 // a range half the size. The even indices become the real components. 3391 static if (isArray!R && isFloatingPoint!E) 3392 { 3393 // Then the memory layout of complex numbers provides a dirt 3394 // cheap way to convert. This is a common case, so take advantage. 3395 auto oddsImag = cast(Complex!E[]) range; 3396 } 3397 else 3398 { 3399 // General case: Use a higher order range. We can assume 3400 // source.length is even because it has to be a power of 2. 3401 static struct OddToImaginary 3402 { 3403 R source; 3404 alias C = Complex!(CommonType!(E, typeof(buf[0].re))); 3405 3406 @property 3407 { 3408 C front() 3409 { 3410 return C(source[0], source[1]); 3411 } 3412 3413 C back() 3414 { 3415 immutable n = source.length; 3416 return C(source[n - 2], source[n - 1]); 3417 } 3418 3419 typeof(this) save() 3420 { 3421 return typeof(this)(source.save); 3422 } 3423 3424 bool empty() 3425 { 3426 return source.empty; 3427 } 3428 3429 size_t length() 3430 { 3431 return source.length / 2; 3432 } 3433 } 3434 3435 void popFront() 3436 { 3437 source.popFront(); 3438 source.popFront(); 3439 } 3440 3441 void popBack() 3442 { 3443 source.popBack(); 3444 source.popBack(); 3445 } 3446 3447 C opIndex(size_t index) 3448 { 3449 return C(source[index * 2], source[index * 2 + 1]); 3450 } 3451 3452 typeof(this) opSlice(size_t lower, size_t upper) 3453 { 3454 return typeof(this)(source[lower * 2 .. upper * 2]); 3455 } 3456 } 3457 3458 auto oddsImag = OddToImaginary(range); 3459 } 3460 3461 fft(oddsImag, buf[0..$ / 2]); 3462 auto evenFft = buf[0..$ / 2]; 3463 auto oddFft = buf[$ / 2..$]; 3464 immutable halfN = evenFft.length; 3465 oddFft[0].re = buf[0].im; 3466 oddFft[0].im = 0; 3467 evenFft[0].im = 0; 3468 // evenFft[0].re is already right b/c it's aliased with buf[0].re. 3469 3470 foreach (k; 1 .. halfN / 2 + 1) 3471 { 3472 immutable bufk = buf[k]; 3473 immutable bufnk = buf[buf.length / 2 - k]; 3474 evenFft[k].re = 0.5 * (bufk.re + bufnk.re); 3475 evenFft[halfN - k].re = evenFft[k].re; 3476 evenFft[k].im = 0.5 * (bufk.im - bufnk.im); 3477 evenFft[halfN - k].im = -evenFft[k].im; 3478 3479 oddFft[k].re = 0.5 * (bufk.im + bufnk.im); 3480 oddFft[halfN - k].re = oddFft[k].re; 3481 oddFft[k].im = 0.5 * (bufnk.re - bufk.re); 3482 oddFft[halfN - k].im = -oddFft[k].im; 3483 } 3484 3485 butterfly(buf); 3486 } 3487 3488 void butterfly(R)(R buf) const 3489 in 3490 { 3491 assert(isPowerOf2(buf.length)); 3492 } 3493 do 3494 { 3495 immutable n = buf.length; 3496 immutable localLookup = negSinLookup[bsf(n)]; 3497 assert(localLookup.length == n); 3498 3499 immutable cosMask = n - 1; 3500 immutable cosAdd = n / 4 * 3; 3501 3502 lookup_t negSinFromLookup(size_t index) pure nothrow 3503 { 3504 return localLookup[index]; 3505 } 3506 3507 lookup_t cosFromLookup(size_t index) pure nothrow 3508 { 3509 // cos is just -sin shifted by PI * 3 / 2. 3510 return localLookup[(index + cosAdd) & cosMask]; 3511 } 3512 3513 immutable halfLen = n / 2; 3514 3515 // This loop is unrolled and the two iterations are interleaved 3516 // relative to the textbook FFT to increase ILP. This gives roughly 5% 3517 // speedups on DMD. 3518 for (size_t k = 0; k < halfLen; k += 2) 3519 { 3520 immutable cosTwiddle1 = cosFromLookup(k); 3521 immutable sinTwiddle1 = negSinFromLookup(k); 3522 immutable cosTwiddle2 = cosFromLookup(k + 1); 3523 immutable sinTwiddle2 = negSinFromLookup(k + 1); 3524 3525 immutable realLower1 = buf[k].re; 3526 immutable imagLower1 = buf[k].im; 3527 immutable realLower2 = buf[k + 1].re; 3528 immutable imagLower2 = buf[k + 1].im; 3529 3530 immutable upperIndex1 = k + halfLen; 3531 immutable upperIndex2 = upperIndex1 + 1; 3532 immutable realUpper1 = buf[upperIndex1].re; 3533 immutable imagUpper1 = buf[upperIndex1].im; 3534 immutable realUpper2 = buf[upperIndex2].re; 3535 immutable imagUpper2 = buf[upperIndex2].im; 3536 3537 immutable realAdd1 = cosTwiddle1 * realUpper1 3538 - sinTwiddle1 * imagUpper1; 3539 immutable imagAdd1 = sinTwiddle1 * realUpper1 3540 + cosTwiddle1 * imagUpper1; 3541 immutable realAdd2 = cosTwiddle2 * realUpper2 3542 - sinTwiddle2 * imagUpper2; 3543 immutable imagAdd2 = sinTwiddle2 * realUpper2 3544 + cosTwiddle2 * imagUpper2; 3545 3546 buf[k].re += realAdd1; 3547 buf[k].im += imagAdd1; 3548 buf[k + 1].re += realAdd2; 3549 buf[k + 1].im += imagAdd2; 3550 3551 buf[upperIndex1].re = realLower1 - realAdd1; 3552 buf[upperIndex1].im = imagLower1 - imagAdd1; 3553 buf[upperIndex2].re = realLower2 - realAdd2; 3554 buf[upperIndex2].im = imagLower2 - imagAdd2; 3555 } 3556 } 3557 3558 // This constructor is used within this module for allocating the 3559 // buffer space elsewhere besides the GC heap. It's definitely **NOT** 3560 // part of the public API and definitely **IS** subject to change. 3561 // 3562 // Also, this is unsafe because the memSpace buffer will be cast 3563 // to immutable. 3564 // 3565 // Public b/c of https://issues.dlang.org/show_bug.cgi?id=4636. 3566 public this(lookup_t[] memSpace) 3567 { 3568 immutable size = memSpace.length / 2; 3569 3570 /* Create a lookup table of all negative sine values at a resolution of 3571 * size and all smaller power of two resolutions. This may seem 3572 * inefficient, but having all the lookups be next to each other in 3573 * memory at every level of iteration is a huge win performance-wise. 3574 */ 3575 if (size == 0) 3576 { 3577 return; 3578 } 3579 3580 assert(isPowerOf2(size), 3581 "Can only do FFTs on ranges with a size that is a power of two."); 3582 3583 auto table = new lookup_t[][bsf(size) + 1]; 3584 3585 table[$ - 1] = memSpace[$ - size..$]; 3586 memSpace = memSpace[0 .. size]; 3587 3588 auto lastRow = table[$ - 1]; 3589 lastRow[0] = 0; // -sin(0) == 0. 3590 foreach (ptrdiff_t i; 1 .. size) 3591 { 3592 // The hard coded cases are for improved accuracy and to prevent 3593 // annoying non-zeroness when stuff should be zero. 3594 3595 if (i == size / 4) 3596 lastRow[i] = -1; // -sin(pi / 2) == -1. 3597 else if (i == size / 2) 3598 lastRow[i] = 0; // -sin(pi) == 0. 3599 else if (i == size * 3 / 4) 3600 lastRow[i] = 1; // -sin(pi * 3 / 2) == 1 3601 else 3602 lastRow[i] = -sin(i * 2.0L * PI / size); 3603 } 3604 3605 // Fill in all the other rows with strided versions. 3606 foreach (i; 1 .. table.length - 1) 3607 { 3608 immutable strideLength = size / (2 ^^ i); 3609 auto strided = Stride!(lookup_t[])(lastRow, strideLength); 3610 table[i] = memSpace[$ - strided.length..$]; 3611 memSpace = memSpace[0..$ - strided.length]; 3612 3613 size_t copyIndex; 3614 foreach (elem; strided) 3615 { 3616 table[i][copyIndex++] = elem; 3617 } 3618 } 3619 3620 negSinLookup = cast(immutable) table; 3621 } 3622 3623 public: 3624 /**Create an `Fft` object for computing fast Fourier transforms of 3625 * power of two sizes of `size` or smaller. `size` must be a 3626 * power of two. 3627 */ 3628 this(size_t size) 3629 { 3630 // Allocate all twiddle factor buffers in one contiguous block so that, 3631 // when one is done being used, the next one is next in cache. 3632 auto memSpace = uninitializedArray!(lookup_t[])(2 * size); 3633 this(memSpace); 3634 } 3635 3636 @property size_t size() const 3637 { 3638 return (negSinLookup is null) ? 0 : negSinLookup[$ - 1].length; 3639 } 3640 3641 /**Compute the Fourier transform of range using the $(BIGOH N log N) 3642 * Cooley-Tukey Algorithm. `range` must be a random-access range with 3643 * slicing and a length equal to `size` as provided at the construction of 3644 * this object. The contents of range can be either numeric types, 3645 * which will be interpreted as pure real values, or complex types with 3646 * properties or members `.re` and `.im` that can be read. 3647 * 3648 * Note: Pure real FFTs are automatically detected and the relevant 3649 * optimizations are performed. 3650 * 3651 * Returns: An array of complex numbers representing the transformed data in 3652 * the frequency domain. 3653 * 3654 * Conventions: The exponent is negative and the factor is one, 3655 * i.e., output[j] := sum[ exp(-2 PI i j k / N) input[k] ]. 3656 */ 3657 Complex!F[] fft(F = double, R)(R range) const 3658 if (isFloatingPoint!F && isRandomAccessRange!R) 3659 { 3660 enforceSize(range); 3661 Complex!F[] ret; 3662 if (range.length == 0) 3663 { 3664 return ret; 3665 } 3666 3667 // Don't waste time initializing the memory for ret. 3668 ret = uninitializedArray!(Complex!F[])(range.length); 3669 3670 fft(range, ret); 3671 return ret; 3672 } 3673 3674 /**Same as the overload, but allows for the results to be stored in a user- 3675 * provided buffer. The buffer must be of the same length as range, must be 3676 * a random-access range, must have slicing, and must contain elements that are 3677 * complex-like. This means that they must have a .re and a .im member or 3678 * property that can be both read and written and are floating point numbers. 3679 */ 3680 void fft(Ret, R)(R range, Ret buf) const 3681 if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret) 3682 { 3683 assert(buf.length == range.length); 3684 enforceSize(range); 3685 3686 if (range.length == 0) 3687 { 3688 return; 3689 } 3690 else if (range.length == 1) 3691 { 3692 buf[0] = range[0]; 3693 return; 3694 } 3695 else if (range.length == 2) 3696 { 3697 slowFourier2(range, buf); 3698 return; 3699 } 3700 else 3701 { 3702 alias E = ElementType!R; 3703 static if (is(E : real)) 3704 { 3705 return fftImplPureReal(range, buf); 3706 } 3707 else 3708 { 3709 static if (is(R : Stride!R)) 3710 return fftImpl(range, buf); 3711 else 3712 return fftImpl(Stride!R(range, 1), buf); 3713 } 3714 } 3715 } 3716 3717 /** 3718 * Computes the inverse Fourier transform of a range. The range must be a 3719 * random access range with slicing, have a length equal to the size 3720 * provided at construction of this object, and contain elements that are 3721 * either of type std.complex.Complex or have essentially 3722 * the same compile-time interface. 3723 * 3724 * Returns: The time-domain signal. 3725 * 3726 * Conventions: The exponent is positive and the factor is 1/N, i.e., 3727 * output[j] := (1 / N) sum[ exp(+2 PI i j k / N) input[k] ]. 3728 */ 3729 Complex!F[] inverseFft(F = double, R)(R range) const 3730 if (isRandomAccessRange!R && isComplexLike!(ElementType!R) && isFloatingPoint!F) 3731 { 3732 enforceSize(range); 3733 Complex!F[] ret; 3734 if (range.length == 0) 3735 { 3736 return ret; 3737 } 3738 3739 // Don't waste time initializing the memory for ret. 3740 ret = uninitializedArray!(Complex!F[])(range.length); 3741 3742 inverseFft(range, ret); 3743 return ret; 3744 } 3745 3746 /** 3747 * Inverse FFT that allows a user-supplied buffer to be provided. The buffer 3748 * must be a random access range with slicing, and its elements 3749 * must be some complex-like type. 3750 */ 3751 void inverseFft(Ret, R)(R range, Ret buf) const 3752 if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret) 3753 { 3754 enforceSize(range); 3755 3756 auto swapped = map!swapRealImag(range); 3757 fft(swapped, buf); 3758 3759 immutable lenNeg1 = 1.0 / buf.length; 3760 foreach (ref elem; buf) 3761 { 3762 immutable temp = elem.re * lenNeg1; 3763 elem.re = elem.im * lenNeg1; 3764 elem.im = temp; 3765 } 3766 } 3767 } 3768 3769 // This mixin creates an Fft object in the scope it's mixed into such that all 3770 // memory owned by the object is deterministically destroyed at the end of that 3771 // scope. 3772 private enum string MakeLocalFft = q{ 3773 import core.stdc.stdlib; 3774 import core.exception : onOutOfMemoryError; 3775 3776 auto lookupBuf = (cast(lookup_t*) malloc(range.length * 2 * lookup_t.sizeof)) 3777 [0 .. 2 * range.length]; 3778 if (!lookupBuf.ptr) 3779 onOutOfMemoryError(); 3780 3781 scope(exit) free(cast(void*) lookupBuf.ptr); 3782 auto fftObj = scoped!Fft(lookupBuf); 3783 }; 3784 3785 /**Convenience functions that create an `Fft` object, run the FFT or inverse 3786 * FFT and return the result. Useful for one-off FFTs. 3787 * 3788 * Note: In addition to convenience, these functions are slightly more 3789 * efficient than manually creating an Fft object for a single use, 3790 * as the Fft object is deterministically destroyed before these 3791 * functions return. 3792 */ 3793 Complex!F[] fft(F = double, R)(R range) 3794 { 3795 mixin(MakeLocalFft); 3796 return fftObj.fft!(F, R)(range); 3797 } 3798 3799 /// ditto 3800 void fft(Ret, R)(R range, Ret buf) 3801 { 3802 mixin(MakeLocalFft); 3803 return fftObj.fft!(Ret, R)(range, buf); 3804 } 3805 3806 /// ditto 3807 Complex!F[] inverseFft(F = double, R)(R range) 3808 { 3809 mixin(MakeLocalFft); 3810 return fftObj.inverseFft!(F, R)(range); 3811 } 3812 3813 /// ditto 3814 void inverseFft(Ret, R)(R range, Ret buf) 3815 { 3816 mixin(MakeLocalFft); 3817 return fftObj.inverseFft!(Ret, R)(range, buf); 3818 } 3819 3820 @system unittest 3821 { 3822 import std.algorithm; 3823 import std.conv; 3824 import std.range; 3825 // Test values from R and Octave. 3826 auto arr = [1,2,3,4,5,6,7,8]; 3827 auto fft1 = fft(arr); 3828 assert(isClose(map!"a.re"(fft1), 3829 [36.0, -4, -4, -4, -4, -4, -4, -4], 1e-4)); 3830 assert(isClose(map!"a.im"(fft1), 3831 [0, 9.6568, 4, 1.6568, 0, -1.6568, -4, -9.6568], 1e-4)); 3832 3833 auto fft1Retro = fft(retro(arr)); 3834 assert(isClose(map!"a.re"(fft1Retro), 3835 [36.0, 4, 4, 4, 4, 4, 4, 4], 1e-4)); 3836 assert(isClose(map!"a.im"(fft1Retro), 3837 [0, -9.6568, -4, -1.6568, 0, 1.6568, 4, 9.6568], 1e-4)); 3838 3839 auto fft1Float = fft(to!(float[])(arr)); 3840 assert(isClose(map!"a.re"(fft1), map!"a.re"(fft1Float))); 3841 assert(isClose(map!"a.im"(fft1), map!"a.im"(fft1Float))); 3842 3843 alias C = Complex!float; 3844 auto arr2 = [C(1,2), C(3,4), C(5,6), C(7,8), C(9,10), 3845 C(11,12), C(13,14), C(15,16)]; 3846 auto fft2 = fft(arr2); 3847 assert(isClose(map!"a.re"(fft2), 3848 [64.0, -27.3137, -16, -11.3137, -8, -4.6862, 0, 11.3137], 1e-4)); 3849 assert(isClose(map!"a.im"(fft2), 3850 [72, 11.3137, 0, -4.686, -8, -11.3137, -16, -27.3137], 1e-4)); 3851 3852 auto inv1 = inverseFft(fft1); 3853 assert(isClose(map!"a.re"(inv1), arr, 1e-6)); 3854 assert(reduce!max(map!"a.im"(inv1)) < 1e-10); 3855 3856 auto inv2 = inverseFft(fft2); 3857 assert(isClose(map!"a.re"(inv2), map!"a.re"(arr2))); 3858 assert(isClose(map!"a.im"(inv2), map!"a.im"(arr2))); 3859 3860 // FFTs of size 0, 1 and 2 are handled as special cases. Test them here. 3861 ushort[] empty; 3862 assert(fft(empty) == null); 3863 assert(inverseFft(fft(empty)) == null); 3864 3865 real[] oneElem = [4.5L]; 3866 auto oneFft = fft(oneElem); 3867 assert(oneFft.length == 1); 3868 assert(oneFft[0].re == 4.5L); 3869 assert(oneFft[0].im == 0); 3870 3871 auto oneInv = inverseFft(oneFft); 3872 assert(oneInv.length == 1); 3873 assert(isClose(oneInv[0].re, 4.5)); 3874 assert(isClose(oneInv[0].im, 0, 0.0, 1e-10)); 3875 3876 long[2] twoElems = [8, 4]; 3877 auto twoFft = fft(twoElems[]); 3878 assert(twoFft.length == 2); 3879 assert(isClose(twoFft[0].re, 12)); 3880 assert(isClose(twoFft[0].im, 0, 0.0, 1e-10)); 3881 assert(isClose(twoFft[1].re, 4)); 3882 assert(isClose(twoFft[1].im, 0, 0.0, 1e-10)); 3883 auto twoInv = inverseFft(twoFft); 3884 assert(isClose(twoInv[0].re, 8)); 3885 assert(isClose(twoInv[0].im, 0, 0.0, 1e-10)); 3886 assert(isClose(twoInv[1].re, 4)); 3887 assert(isClose(twoInv[1].im, 0, 0.0, 1e-10)); 3888 } 3889 3890 // Swaps the real and imaginary parts of a complex number. This is useful 3891 // for inverse FFTs. 3892 C swapRealImag(C)(C input) 3893 { 3894 return C(input.im, input.re); 3895 } 3896 3897 /** This function transforms `decimal` value into a value in the factorial number 3898 system stored in `fac`. 3899 3900 A factorial number is constructed as: 3901 $(D fac[0] * 0! + fac[1] * 1! + ... fac[20] * 20!) 3902 3903 Params: 3904 decimal = The decimal value to convert into the factorial number system. 3905 fac = The array to store the factorial number. The array is of size 21 as 3906 `ulong.max` requires 21 digits in the factorial number system. 3907 Returns: 3908 A variable storing the number of digits of the factorial number stored in 3909 `fac`. 3910 */ 3911 size_t decimalToFactorial(ulong decimal, ref ubyte[21] fac) 3912 @safe pure nothrow @nogc 3913 { 3914 import std.algorithm.mutation : reverse; 3915 size_t idx; 3916 3917 for (ulong i = 1; decimal != 0; ++i) 3918 { 3919 auto temp = decimal % i; 3920 decimal /= i; 3921 fac[idx++] = cast(ubyte)(temp); 3922 } 3923 3924 if (idx == 0) 3925 { 3926 fac[idx++] = cast(ubyte) 0; 3927 } 3928 3929 reverse(fac[0 .. idx]); 3930 3931 // first digit of the number in factorial will always be zero 3932 assert(fac[idx - 1] == 0); 3933 3934 return idx; 3935 } 3936 3937 /// 3938 @safe pure @nogc unittest 3939 { 3940 ubyte[21] fac; 3941 size_t idx = decimalToFactorial(2982, fac); 3942 3943 assert(fac[0] == 4); 3944 assert(fac[1] == 0); 3945 assert(fac[2] == 4); 3946 assert(fac[3] == 1); 3947 assert(fac[4] == 0); 3948 assert(fac[5] == 0); 3949 assert(fac[6] == 0); 3950 } 3951 3952 @safe pure unittest 3953 { 3954 ubyte[21] fac; 3955 size_t idx = decimalToFactorial(0UL, fac); 3956 assert(idx == 1); 3957 assert(fac[0] == 0); 3958 3959 fac[] = 0; 3960 idx = 0; 3961 idx = decimalToFactorial(ulong.max, fac); 3962 assert(idx == 21); 3963 auto t = [7, 11, 12, 4, 3, 15, 3, 5, 3, 5, 0, 8, 3, 5, 0, 0, 0, 2, 1, 1, 0]; 3964 foreach (i, it; fac[0 .. 21]) 3965 { 3966 assert(it == t[i]); 3967 } 3968 3969 fac[] = 0; 3970 idx = decimalToFactorial(2982, fac); 3971 3972 assert(idx == 7); 3973 t = [4, 0, 4, 1, 0, 0, 0]; 3974 foreach (i, it; fac[0 .. idx]) 3975 { 3976 assert(it == t[i]); 3977 } 3978 } 3979 3980 private: 3981 // The reasons I couldn't use std.algorithm were b/c its stride length isn't 3982 // modifiable on the fly and because range has grown some performance hacks 3983 // for powers of 2. 3984 struct Stride(R) 3985 { 3986 import core.bitop : bsf; 3987 Unqual!R range; 3988 size_t _nSteps; 3989 size_t _length; 3990 alias E = ElementType!(R); 3991 3992 this(R range, size_t nStepsIn) 3993 { 3994 this.range = range; 3995 _nSteps = nStepsIn; 3996 _length = (range.length + _nSteps - 1) / nSteps; 3997 } 3998 3999 size_t length() const @property 4000 { 4001 return _length; 4002 } 4003 4004 typeof(this) save() @property 4005 { 4006 auto ret = this; 4007 ret.range = ret.range.save; 4008 return ret; 4009 } 4010 4011 E opIndex(size_t index) 4012 { 4013 return range[index * _nSteps]; 4014 } 4015 4016 E front() @property 4017 { 4018 return range[0]; 4019 } 4020 4021 void popFront() 4022 { 4023 if (range.length >= _nSteps) 4024 { 4025 range = range[_nSteps .. range.length]; 4026 _length--; 4027 } 4028 else 4029 { 4030 range = range[0 .. 0]; 4031 _length = 0; 4032 } 4033 } 4034 4035 // Pops half the range's stride. 4036 void popHalf() 4037 { 4038 range = range[_nSteps / 2 .. range.length]; 4039 } 4040 4041 bool empty() const @property 4042 { 4043 return length == 0; 4044 } 4045 4046 size_t nSteps() const @property 4047 { 4048 return _nSteps; 4049 } 4050 4051 void doubleSteps() 4052 { 4053 _nSteps *= 2; 4054 _length /= 2; 4055 } 4056 4057 size_t nSteps(size_t newVal) @property 4058 { 4059 _nSteps = newVal; 4060 4061 // Using >> bsf(nSteps) is a few cycles faster than / nSteps. 4062 _length = (range.length + _nSteps - 1) >> bsf(nSteps); 4063 return newVal; 4064 } 4065 } 4066 4067 // Hard-coded base case for FFT of size 2. This is actually a TON faster than 4068 // using a generic slow DFT. This seems to be the best base case. (Size 1 4069 // can be coded inline as buf[0] = range[0]). 4070 void slowFourier2(Ret, R)(R range, Ret buf) 4071 { 4072 assert(range.length == 2); 4073 assert(buf.length == 2); 4074 buf[0] = range[0] + range[1]; 4075 buf[1] = range[0] - range[1]; 4076 } 4077 4078 // Hard-coded base case for FFT of size 4. Doesn't work as well as the size 4079 // 2 case. 4080 void slowFourier4(Ret, R)(R range, Ret buf) 4081 { 4082 alias C = ElementType!Ret; 4083 4084 assert(range.length == 4); 4085 assert(buf.length == 4); 4086 buf[0] = range[0] + range[1] + range[2] + range[3]; 4087 buf[1] = range[0] - range[1] * C(0, 1) - range[2] + range[3] * C(0, 1); 4088 buf[2] = range[0] - range[1] + range[2] - range[3]; 4089 buf[3] = range[0] + range[1] * C(0, 1) - range[2] - range[3] * C(0, 1); 4090 } 4091 4092 N roundDownToPowerOf2(N)(N num) 4093 if (isScalarType!N && !isFloatingPoint!N) 4094 { 4095 import core.bitop : bsr; 4096 return num & (cast(N) 1 << bsr(num)); 4097 } 4098 4099 @safe unittest 4100 { 4101 assert(roundDownToPowerOf2(7) == 4); 4102 assert(roundDownToPowerOf2(4) == 4); 4103 } 4104 4105 template isComplexLike(T) 4106 { 4107 enum bool isComplexLike = is(typeof(T.init.re)) && 4108 is(typeof(T.init.im)); 4109 } 4110 4111 @safe unittest 4112 { 4113 static assert(isComplexLike!(Complex!double)); 4114 static assert(!isComplexLike!(uint)); 4115 } 4116