// Written in the D programming language.

/**
This module is a port of a growing fragment of the $(D_PARAM numeric)
header in Alexander Stepanov's $(LINK2 http://sgi.com/tech/stl,
Standard Template Library), with a few additions.

Macros:
Copyright: Copyright Andrei Alexandrescu 2008 - 2009.
License:   $(HTTP www.boost.org/LICENSE_1_0.txt, Boost License 1.0).
Authors:   $(HTTP erdani.org, Andrei Alexandrescu),
           Don Clugston, Robert Jacques, Ilya Yaroshenko
Source:    $(PHOBOSSRC std/_numeric.d)
*/
/*
         Copyright Andrei Alexandrescu 2008 - 2009.
Distributed under the Boost Software License, Version 1.0.
   (See accompanying file LICENSE_1_0.txt or copy at
         http://www.boost.org/LICENSE_1_0.txt)
*/
module std.numeric;

import std.complex;
import std.math;
import std.range.primitives;
import std.traits;
import std.typecons;

version (unittest)
{
    import std.stdio;
}

/// Format flags for CustomFloat.
public enum CustomFloatFlags
{
    /// Adds a sign bit to allow for signed numbers.
    signed = 1,

    /**
     * Store values in normalized form by default. The actual precision of the
     * significand is extended by 1 bit by assuming an implicit leading bit of 1
     * instead of 0. i.e. $(D 1.nnnn) instead of $(D 0.nnnn).
     * True for all $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) types.
     */
    storeNormalized = 2,

    /**
     * Stores the significand in $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
     * IEEE754 denormalized) form when the exponent is 0. Required to express the value 0.
     */
    allowDenorm = 4,

    /**
     * Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Positive_and_negative_infinity,
     * IEEE754 _infinity) values.
     */
    infinity = 8,

    /// Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/NaN, IEEE754 Not a Number) values.
    nan = 16,

    /**
     * If set, select an exponent bias such that max_exp = 1.
     * i.e. so that the maximum value is >= 1.0 and < 2.0.
     * Ignored if the exponent bias is manually specified.
     */
    probability = 32,

    /// If set, unsigned custom floats are assumed to be negative.
    negativeUnsigned = 64,

    /**
     * If set, 0 is the only allowed $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
     * IEEE754 denormalized) number.
     * Requires allowDenorm and storeNormalized.
     */
    allowDenormZeroOnly = 128 | allowDenorm | storeNormalized,

    /// Include _all of the $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) options.
    ieee = signed | storeNormalized | allowDenorm | infinity | nan,

    /// Include none of the above options.
    none = 0
}

private template CustomFloatParams(uint bits)
{
    enum CustomFloatFlags flags = CustomFloatFlags.ieee
        ^ ((bits == 80) ? CustomFloatFlags.storeNormalized : CustomFloatFlags.none);
    static if (bits ==  8) alias CustomFloatParams = CustomFloatParams!( 4,  3, flags);
    static if (bits == 16) alias CustomFloatParams = CustomFloatParams!(10,  5, flags);
    static if (bits == 32) alias CustomFloatParams = CustomFloatParams!(23,  8, flags);
    static if (bits == 64) alias CustomFloatParams = CustomFloatParams!(52, 11, flags);
    static if (bits == 80) alias CustomFloatParams = CustomFloatParams!(64, 15, flags);
}

private template CustomFloatParams(uint precision, uint exponentWidth, CustomFloatFlags flags)
{
    import std.meta : AliasSeq;
    alias CustomFloatParams =
        AliasSeq!(
            precision,
            exponentWidth,
            flags,
            (1 << (exponentWidth - ((flags & flags.probability) == 0)))
             - ((flags & (flags.nan | flags.infinity)) != 0) - ((flags & flags.probability) != 0)
        ); // ((flags & CustomFloatFlags.probability) == 0)
}

/**
 * Allows user code to define custom floating-point formats. These formats are
 * for storage only; all operations on them are performed by first implicitly
 * extracting them to $(D real). After the operation is completed, the
 * result can be stored in a custom floating-point value via assignment.
 */
template CustomFloat(uint bits)
if (bits == 8 || bits == 16 || bits == 32 || bits == 64 || bits == 80)
{
    alias CustomFloat = CustomFloat!(CustomFloatParams!(bits));
}

/// ditto
template CustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags = CustomFloatFlags.ieee)
if (((flags & flags.signed) + precision + exponentWidth) % 8 == 0 && precision + exponentWidth > 0)
{
    alias CustomFloat = CustomFloat!(CustomFloatParams!(precision, exponentWidth, flags));
}

///
@safe unittest
{
    import std.math : sin, cos;

    // Define 16-bit floating-point values
    CustomFloat!16 x;                                  // Using the number of bits
    CustomFloat!(10, 5) y;                             // Using the precision and exponent width
    CustomFloat!(10, 5, CustomFloatFlags.ieee) z;      // Using the precision, exponent width and format flags
    CustomFloat!(10, 5, CustomFloatFlags.ieee, 15) w;  // Using the precision, exponent width, format flags and exponent offset bias

    // Use the 16-bit floats mostly like normal numbers
    w = x*y - 1;

    // Function calls require conversion
    z = sin(+x)            + cos(+y);             // Use unary plus to concisely convert to a real
    z = sin(x.get!float)   + cos(y.get!float);    // Or use get!T
    z = sin(cast(float) x) + cos(cast(float) y);  // Or use cast(T) to explicitly convert

    // Define an 8-bit custom float for storing probabilities
    alias Probability = CustomFloat!(4, 4, CustomFloatFlags.ieee ^ CustomFloatFlags.probability ^ CustomFloatFlags.signed);
    auto p = Probability(0.5);
}

/// ditto
struct CustomFloat(uint             precision,      // fraction bits (23 for float)
                   uint             exponentWidth,  // exponent bits (8 for float)
                   CustomFloatFlags flags,
                   uint             bias)
if (((flags & flags.signed) + precision + exponentWidth) % 8 == 0 &&
    precision + exponentWidth > 0)
{
    import std.bitmanip : bitfields;
    import std.meta : staticIndexOf;
private:
    // get the correct unsigned bitfield type to support > 32 bits
    template uType(uint bits)
    {
        static if (bits <= size_t.sizeof*8) alias uType = size_t;
        else                                alias uType = ulong;
    }

    // get the correct signed bitfield type to support > 32 bits
    template sType(uint bits)
    {
static if (bits <= ptrdiff_t.sizeof*8-1) alias sType = ptrdiff_t; 174 else alias sType = long; 175 } 176 177 alias T_sig = uType!precision; 178 alias T_exp = uType!exponentWidth; 179 alias T_signed_exp = sType!exponentWidth; 180 181 alias Flags = CustomFloatFlags; 182 183 // Facilitate converting numeric types to custom float 184 union ToBinary(F) 185 if (is(typeof(CustomFloatParams!(F.sizeof*8))) || is(F == real)) 186 { 187 F set; 188 189 // If on Linux or Mac, where 80-bit reals are padded, ignore the 190 // padding. 191 import std.algorithm.comparison : min; 192 CustomFloat!(CustomFloatParams!(min(F.sizeof*8, 80))) get; 193 194 // Convert F to the correct binary type. 195 static typeof(get) opCall(F value) 196 { 197 ToBinary r; 198 r.set = value; 199 return r.get; 200 } 201 alias get this; 202 } 203 204 // Perform IEEE rounding with round to nearest detection 205 void roundedShift(T,U)(ref T sig, U shift) 206 { 207 if (sig << (T.sizeof*8 - shift) == cast(T) 1uL << (T.sizeof*8 - 1)) 208 { 209 // round to even 210 sig >>= shift; 211 sig += sig & 1; 212 } 213 else 214 { 215 sig >>= shift - 1; 216 sig += sig & 1; 217 // Perform standard rounding 218 sig >>= 1; 219 } 220 } 221 222 // Convert the current value to signed exponent, normalized form 223 void toNormalized(T,U)(ref T sig, ref U exp) 224 { 225 sig = significand; 226 auto shift = (T.sizeof*8) - precision; 227 exp = exponent; 228 static if (flags&(Flags.infinity|Flags.nan)) 229 { 230 // Handle inf or nan 231 if (exp == exponent_max) 232 { 233 exp = exp.max; 234 sig <<= shift; 235 static if (flags&Flags.storeNormalized) 236 { 237 // Save inf/nan in denormalized format 238 sig >>= 1; 239 sig += cast(T) 1uL << (T.sizeof*8 - 1); 240 } 241 return; 242 } 243 } 244 if ((~flags&Flags.storeNormalized) || 245 // Convert denormalized form to normalized form 246 ((flags&Flags.allowDenorm) && exp == 0)) 247 { 248 if (sig > 0) 249 { 250 import core.bitop : bsr; 251 auto shift2 = precision - bsr(sig); 252 exp -= shift2-1; 253 shift += shift2; 254 } 255 else // value = 0.0 256 { 257 exp = exp.min; 258 return; 259 } 260 } 261 sig <<= shift; 262 exp -= bias; 263 } 264 265 // Set the current value from signed exponent, normalized form 266 void fromNormalized(T,U)(ref T sig, ref U exp) 267 { 268 auto shift = (T.sizeof*8) - precision; 269 if (exp == exp.max) 270 { 271 // infinity or nan 272 exp = exponent_max; 273 static if (flags & Flags.storeNormalized) 274 sig <<= 1; 275 276 // convert back to normalized form 277 static if (~flags & Flags.infinity) 278 // No infinity support? 279 assert(sig != 0, "Infinity floating point value assigned to a " 280 ~ typeof(this).stringof ~ " (no infinity support)."); 281 282 static if (~flags & Flags.nan) // No NaN support? 
                assert(sig == 0, "NaN floating point value assigned to a " ~
                    typeof(this).stringof ~ " (no nan support).");
            sig >>= shift;
            return;
        }
        if (exp == exp.min)     // 0.0
        {
            exp = 0;
            sig = 0;
            return;
        }

        exp += bias;
        if (exp <= 0)
        {
            static if ((flags&Flags.allowDenorm) ||
                // Convert from normalized form to denormalized
                (~flags&Flags.storeNormalized))
            {
                shift += -exp;
                roundedShift(sig,1);
                sig += cast(T) 1uL << (T.sizeof*8 - 1);
                // Add the leading 1
                exp = 0;
            }
            else
                assert((flags&Flags.storeNormalized) && exp == 0,
                    "Underflow occurred assigning to a " ~
                    typeof(this).stringof ~ " (no denormal support).");
        }
        else
        {
            static if (~flags&Flags.storeNormalized)
            {
                // Convert from normalized form to denormalized
                roundedShift(sig,1);
                sig += cast(T) 1uL << (T.sizeof*8 - 1);
                // Add the leading 1
            }
        }

        if (shift > 0)
            roundedShift(sig,shift);
        if (sig > significand_max)
        {
            // handle significand overflow (should only be 1 bit)
            static if (~flags&Flags.storeNormalized)
            {
                sig >>= 1;
            }
            else
                sig &= significand_max;
            exp++;
        }
        static if ((flags&Flags.allowDenormZeroOnly)==Flags.allowDenormZeroOnly)
        {
            // disallow non-zero denormals
            if (exp == 0)
            {
                sig <<= 1;
                if (sig > significand_max && (sig&significand_max) > 0)
                    // Check and round to even
                    exp++;
                sig = 0;
            }
        }

        if (exp >= exponent_max)
        {
            static if (flags&(Flags.infinity|Flags.nan))
            {
                sig = 0;
                exp = exponent_max;
                static if (~flags&(Flags.infinity))
                    assert(0, "Overflow occurred assigning to a " ~
                        typeof(this).stringof ~ " (no infinity support).");
            }
            else
                assert(exp == exponent_max, "Overflow occurred assigning to a "
                    ~ typeof(this).stringof ~ " (no infinity support).");
        }
    }

public:
    static if (precision == 64) // CustomFloat!80 support hack
    {
        ulong significand;
        enum ulong significand_max = ulong.max;
        mixin(bitfields!(
            T_exp, "exponent", exponentWidth,
            bool , "sign"    , flags & flags.signed ));
    }
    else
    {
        mixin(bitfields!(
            T_sig, "significand", precision,
            T_exp, "exponent"   , exponentWidth,
            bool , "sign"       , flags & flags.signed ));
    }

    /// Returns: infinity value
    static if (flags & Flags.infinity)
        static @property CustomFloat infinity()
        {
            CustomFloat value;
            static if (flags & Flags.signed)
                value.sign = 0;
            value.significand = 0;
            value.exponent    = exponent_max;
            return value;
        }

    /// Returns: NaN value
    static if (flags & Flags.nan)
        static @property CustomFloat nan()
        {
            CustomFloat value;
            static if (flags & Flags.signed)
                value.sign = 0;
            value.significand = cast(typeof(significand_max)) 1L << (precision-1);
            value.exponent    = exponent_max;
            return value;
        }

    /// Returns: number of decimal digits of precision
    static @property size_t dig()
    {
        auto shiftcnt = precision - ((flags&Flags.storeNormalized) != 0);
        immutable x = (shiftcnt == 64) ?
0 : 1uL << shiftcnt; 412 return cast(size_t) log10(x); 413 } 414 415 /// Returns: smallest increment to the value 1 416 static @property CustomFloat epsilon() 417 { 418 CustomFloat value; 419 static if (flags & Flags.signed) 420 value.sign = 0; 421 T_signed_exp exp = -precision; 422 T_sig sig = 0; 423 424 value.fromNormalized(sig,exp); 425 if (exp == 0 && sig == 0) // underflowed to zero 426 { 427 static if ((flags&Flags.allowDenorm) || 428 (~flags&Flags.storeNormalized)) 429 sig = 1; 430 else 431 sig = cast(T) 1uL << (precision - 1); 432 } 433 value.exponent = cast(value.T_exp) exp; 434 value.significand = cast(value.T_sig) sig; 435 return value; 436 } 437 438 /// the number of bits in mantissa 439 enum mant_dig = precision + ((flags&Flags.storeNormalized) != 0); 440 441 /// Returns: maximum int value such that 10<sup>max_10_exp</sup> is representable 442 static @property int max_10_exp(){ return cast(int) log10( +max ); } 443 444 /// maximum int value such that 2<sup>max_exp-1</sup> is representable 445 enum max_exp = exponent_max-bias+((~flags&(Flags.infinity|flags.nan))!=0); 446 447 /// Returns: minimum int value such that 10<sup>min_10_exp</sup> is representable 448 static @property int min_10_exp(){ return cast(int) log10( +min_normal ); } 449 450 /// minimum int value such that 2<sup>min_exp-1</sup> is representable as a normalized value 451 enum min_exp = cast(T_signed_exp)-bias +1+ ((flags&Flags.allowDenorm)!=0); 452 453 /// Returns: largest representable value that's not infinity 454 static @property CustomFloat max() 455 { 456 CustomFloat value; 457 static if (flags & Flags.signed) 458 value.sign = 0; 459 value.exponent = exponent_max - ((flags&(flags.infinity|flags.nan)) != 0); 460 value.significand = significand_max; 461 return value; 462 } 463 464 /// Returns: smallest representable normalized value that's not 0 465 static @property CustomFloat min_normal() { 466 CustomFloat value; 467 static if (flags & Flags.signed) 468 value.sign = 0; 469 value.exponent = 1; 470 static if (flags&Flags.storeNormalized) 471 value.significand = 0; 472 else 473 value.significand = cast(T_sig) 1uL << (precision - 1); 474 return value; 475 } 476 477 /// Returns: real part 478 @property CustomFloat re() { return this; } 479 480 /// Returns: imaginary part 481 static @property CustomFloat im() { return CustomFloat(0.0f); } 482 483 /// Initialize from any $(D real) compatible type. 484 this(F)(F input) if (__traits(compiles, cast(real) input )) 485 { 486 this = input; 487 } 488 489 /// Self assignment 490 void opAssign(F:CustomFloat)(F input) 491 { 492 static if (flags & Flags.signed) 493 sign = input.sign; 494 exponent = input.exponent; 495 significand = input.significand; 496 } 497 498 /// Assigns from any $(D real) compatible type. 
    void opAssign(F)(F input)
    if (__traits(compiles, cast(real) input))
    {
        import std.conv : text;

        static if (staticIndexOf!(Unqual!F, float, double, real) >= 0)
            auto value = ToBinary!(Unqual!F)(input);
        else
            auto value = ToBinary!(real)(input);

        // Assign the sign bit
        static if (~flags & Flags.signed)
            assert((!value.sign) ^ ((flags&flags.negativeUnsigned) > 0),
                "Incorrectly signed floating point value assigned to a " ~
                typeof(this).stringof ~ " (no sign support).");
        else
            sign = value.sign;

        CommonType!(T_signed_exp, value.T_signed_exp) exp = value.exponent;
        CommonType!(T_sig, value.T_sig) sig = value.significand;

        value.toNormalized(sig, exp);
        fromNormalized(sig, exp);

        assert(exp <= exponent_max, text(typeof(this).stringof ~
            " exponent too large: ", exp, " > ", exponent_max, "\t", input, "\t", sig));
        assert(sig <= significand_max, text(typeof(this).stringof ~
            " significand too large: ", sig, " > ", significand_max,
            "\t", input, "\t", exp, " ", exponent_max));
        exponent = cast(T_exp) exp;
        significand = cast(T_sig) sig;
    }

    /// Fetches the stored value either as a $(D float), $(D double) or $(D real).
    @property F get(F)()
    if (staticIndexOf!(Unqual!F, float, double, real) >= 0)
    {
        import std.conv : text;

        ToBinary!F result;

        static if (flags&Flags.signed)
            result.sign = sign;
        else
            result.sign = (flags&flags.negativeUnsigned) > 0;

        CommonType!(T_signed_exp, result.get.T_signed_exp) exp = exponent; // Assign the exponent and fraction
        CommonType!(T_sig, result.get.T_sig) sig = significand;

        toNormalized(sig, exp);
        result.fromNormalized(sig, exp);
        assert(exp <= result.exponent_max, text("get exponent too large: ", exp, " > ", result.exponent_max));
        assert(sig <= result.significand_max, text("get significand too large: ", sig, " > ", result.significand_max));
        result.exponent = cast(result.get.T_exp) exp;
        result.significand = cast(result.get.T_sig) sig;
        return result.set;
    }

    ///ditto
    T opCast(T)() if (__traits(compiles, get!T)) { return get!T; }

    /// Convert the CustomFloat to a real and perform the relevant operator on the result
    real opUnary(string op)()
    if (__traits(compiles, mixin(op~`(get!real)`)) || op=="++" || op=="--")
    {
        static if (op=="++" || op=="--")
        {
            auto result = get!real;
            this = mixin(op~`result`);
            return result;
        }
        else
            return mixin(op~`get!real`);
    }

    /// ditto
    real opBinary(string op,T)(T b)
    if (__traits(compiles, mixin(`get!real`~op~`b`)))
    {
        return mixin(`get!real`~op~`b`);
    }

    /// ditto
    real opBinaryRight(string op,T)(T a)
    if ( __traits(compiles, mixin(`a`~op~`get!real`)) &&
        !__traits(compiles, mixin(`get!real`~op~`b`)))
    {
        return mixin(`a`~op~`get!real`);
    }

    /// ditto
    int opCmp(T)(auto ref T b)
    if (__traits(compiles, cast(real) b))
    {
        auto x = get!real;
        auto y = cast(real) b;
        return (x >= y) - (x <= y);
    }

    /// ditto
    void opOpAssign(string op, T)(auto ref T b)
    if (__traits(compiles, mixin(`get!real`~op~`cast(real) b`)))
    {
        return mixin(`this = this `~op~` cast(real) b`);
    }

    /// ditto
    template toString()
    {
        import std.format : FormatSpec, formatValue;
        // Needs to be a template because of DMD @@BUG@@ 13737.
        void toString()(scope void delegate(const(char)[]) sink, FormatSpec!char fmt)
        {
            sink.formatValue(get!real, fmt);
        }
    }
}
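
// A small usage sketch: CustomFloat!16 as compact storage for float data,
// read back out through get!T or cast(T). Illustrative only; the values are
// chosen to be exactly representable in 10 significand bits.
@safe unittest
{
    alias Half = CustomFloat!16;

    float[] data = [0.25f, 1.5f, -3.0f];
    Half[] packed;
    foreach (x; data)
        packed ~= Half(x);      // construct from any real-compatible type

    assert(packed[0].get!float == 0.25f);
    assert(packed[1].get!float == 1.5f);
    assert(cast(float) packed[2] == -3.0f);
}
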
@safe unittest
{
    import std.meta;
    alias FPTypes =
        AliasSeq!(
            CustomFloat!(5, 10),
            CustomFloat!(5, 11, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
            CustomFloat!(1, 15, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
            CustomFloat!(4, 3, CustomFloatFlags.ieee | CustomFloatFlags.probability ^ CustomFloatFlags.signed)
        );

    foreach (F; FPTypes)
    {
        auto x = F(0.125);
        assert(x.get!float == 0.125F);
        assert(x.get!double == 0.125);

        x -= 0.0625;
        assert(x.get!float == 0.0625F);
        assert(x.get!double == 0.0625);

        x *= 2;
        assert(x.get!float == 0.125F);
        assert(x.get!double == 0.125);

        x /= 4;
        assert(x.get!float == 0.03125);
        assert(x.get!double == 0.03125);

        x = 0.5;
        x ^^= 4;
        assert(x.get!float == 1 / 16.0F);
        assert(x.get!double == 1 / 16.0);
    }
}

@system unittest
{
    // @system due to to!string(CustomFloat)
    import std.conv;
    CustomFloat!(5, 10) y = CustomFloat!(5, 10)(0.125);
    assert(y.to!string == "0.125");
}

/**
Defines the fastest type to use when storing temporaries of a
calculation intended to ultimately yield a result of type $(D F)
(where $(D F) must be one of $(D float), $(D double), or $(D
real)). When doing a multi-step computation, you may want to store
intermediate results as $(D FPTemporary!F).

The necessity of $(D FPTemporary) stems from the optimized
floating-point operations and registers present in virtually all
processors. When adding numbers in the example below, the addition may
in fact be done in $(D real) precision internally. In that case,
storing the intermediate $(D result) in $(D double) format is not only
less precise, it is also (surprisingly) slower, because a conversion
from $(D real) to $(D double) is performed every pass through the
loop. This being a lose-lose situation, $(D FPTemporary!F) has been
defined as the $(I fastest) type to use for calculations at precision
$(D F). There is no need to define a type for the $(I most accurate)
calculations, as that is always $(D real).

Finally, there is no guarantee that using $(D FPTemporary!F) will
always be fastest, as the speed of floating-point calculations depends
on very many factors.
*/
template FPTemporary(F)
if (isFloatingPoint!F)
{
    version (X86)
        alias FPTemporary = real;
    else
        alias FPTemporary = Unqual!F;
}

///
@safe unittest
{
    import std.math : approxEqual;

    // Average numbers in an array
    double avg(in double[] a)
    {
        if (a.length == 0) return 0;
        FPTemporary!double result = 0;
        foreach (e; a) result += e;
        return result / a.length;
    }

    auto a = [1.0, 2.0, 3.0];
    assert(approxEqual(avg(a), 2));
}
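
// A further hedged sketch along the same lines: accumulate a dot product in
// FPTemporary!float. The helper dotAccum below is illustrative only and not
// part of this module's API.
@safe unittest
{
    import std.math : approxEqual;

    float dotAccum(in float[] a, in float[] b)
    {
        assert(a.length == b.length);
        FPTemporary!float sum = 0;
        foreach (i; 0 .. a.length)
            sum += a[i] * b[i];
        return sum;  // any narrowing back to float happens only once, here
    }

    assert(approxEqual(dotAccum([1.0f, 2.0f, 3.0f], [4.0f, 5.0f, 6.0f]), 32.0f));
}
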
/**
Implements the $(HTTP tinyurl.com/2zb9yr, secant method) for finding a
root of the function $(D fun) starting from points $(D [xn_1, xn])
(ideally close to the root). $(D Num) may be $(D float), $(D double),
or $(D real).
*/
template secantMethod(alias fun)
{
    import std.functional : unaryFun;
    Num secantMethod(Num)(Num xn_1, Num xn)
    {
        auto fxn = unaryFun!(fun)(xn_1), d = xn_1 - xn;
        typeof(fxn) fxn_1;

        xn = xn_1;
        while (!approxEqual(d, 0) && isFinite(d))
        {
            xn_1 = xn;
            xn -= d;
            fxn_1 = fxn;
            fxn = unaryFun!(fun)(xn);
            d *= -fxn / (fxn - fxn_1);
        }
        return xn;
    }
}

///
@safe unittest
{
    import std.math : approxEqual, cos;

    float f(float x)
    {
        return cos(x) - x*x*x;
    }
    auto x = secantMethod!(f)(0f, 1f);
    assert(approxEqual(x, 0.865474));
}

@system unittest
{
    // @system because of __gshared stderr
    scope(failure) stderr.writeln("Failure testing secantMethod");
    float f(float x)
    {
        return cos(x) - x*x*x;
    }
    immutable x = secantMethod!(f)(0f, 1f);
    assert(approxEqual(x, 0.865474));
    auto d = &f;
    immutable y = secantMethod!(d)(0f, 1f);
    assert(approxEqual(y, 0.865474));
}


/**
 * Return true if a and b have opposite sign.
 */
private bool oppositeSigns(T1, T2)(T1 a, T2 b)
{
    return signbit(a) != signbit(b);
}

public:

/** Find a real root of a real function f(x) via bracketing.
 *
 * Given a function `f` and a range `[a .. b]` such that `f(a)`
 * and `f(b)` have opposite signs or at least one of them equals ±0,
 * returns the value of `x` in
 * the range which is closest to a root of `f(x)`. If `f(x)`
 * has more than one root in the range, one will be chosen
 * arbitrarily. If `f(x)` returns NaN, NaN will be returned;
 * otherwise, this algorithm is guaranteed to succeed.
 *
 * Uses an algorithm based on TOMS748, which uses inverse cubic
 * interpolation whenever possible, otherwise reverting to parabolic
 * or secant interpolation. Compared to TOMS748, this implementation
 * improves worst-case performance by a factor of more than 100, and
 * typical performance by a factor of 2. For 80-bit reals, most
 * problems require 8 to 15 calls to `f(x)` to achieve full machine
 * precision. The worst-case performance (pathological cases) is
 * approximately twice the number of bits.
 *
 * References: "On Enclosing Simple Roots of Nonlinear Equations",
 * G. Alefeld, F.A. Potra, Yixun Shi, Mathematics of Computation 61,
 * pp733-744 (1993). Fortran code available from $(HTTP
 * www.netlib.org, www.netlib.org) as algorithm TOMS748.
 */
T findRoot(T, DF, DT)(scope DF f, in T a, in T b,
    scope DT tolerance) //= (T a, T b) => false)
if (
    isFloatingPoint!T &&
    is(typeof(tolerance(T.init, T.init)) : bool) &&
    is(typeof(f(T.init)) == R, R) && isFloatingPoint!R
    )
{
    immutable fa = f(a);
    if (fa == 0)
        return a;
    immutable fb = f(b);
    if (fb == 0)
        return b;
    immutable r = findRoot(f, a, b, fa, fb, tolerance);
    // Return the first value if it is smaller or NaN
    return !(fabs(r[2]) > fabs(r[3])) ? r[0] : r[1];
}

///ditto
T findRoot(T, DF)(scope DF f, in T a, in T b)
{
    return findRoot(f, a, b, (T a, T b) => false);
}
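
// A hedged usage sketch for the two-argument overload above: bracket the root
// of cos(x) - x in [0, 1]. Not part of the original documented tests.
@safe unittest
{
    import std.math : approxEqual, cos;

    immutable r = findRoot((real x) => cos(x) - x, 0.0L, 1.0L);
    assert(approxEqual(r, 0.739085L));
}
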
/** Find root of a real function f(x) by bracketing, allowing the
 * termination condition to be specified.
 *
 * Params:
 *
 * f = Function to be analyzed
 *
 * ax = Left bound of initial range of `f` known to contain the
 * root.
 *
 * bx = Right bound of initial range of `f` known to contain the
 * root.
 *
 * fax = Value of $(D f(ax)).
 *
 * fbx = Value of $(D f(bx)). $(D fax) and $(D fbx) should have opposite signs.
 * ($(D f(ax)) and $(D f(bx)) are commonly known in advance.)
 *
 * tolerance = Defines an early termination condition. Receives the
 *             current upper and lower bounds on the root. The
 *             delegate must return $(D true) when these bounds are
 *             acceptable. If this function always returns $(D false),
 *             full machine precision will be achieved.
 *
 * Returns:
 *
 * A tuple consisting of two ranges. The first two elements are the
 * range (in `x`) of the root, while the second pair of elements
 * are the corresponding function values at those points. If an exact
 * root was found, both of the first two elements will contain the
 * root, and the second pair of elements will be 0.
 */
Tuple!(T, T, R, R) findRoot(T, R, DF, DT)(scope DF f, in T ax, in T bx, in R fax, in R fbx,
    scope DT tolerance) // = (T a, T b) => false)
if (
    isFloatingPoint!T &&
    is(typeof(tolerance(T.init, T.init)) : bool) &&
    is(typeof(f(T.init)) == R) && isFloatingPoint!R
    )
in
{
    assert(!ax.isNaN() && !bx.isNaN(), "Limits must not be NaN");
    assert(signbit(fax) != signbit(fbx), "Parameters must bracket the root.");
}
body
{
    // Author: Don Clugston. This code is (heavily) modified from TOMS748
    // (www.netlib.org). The changes to improve the worst-case performance are
    // entirely original.

    T a, b, d;    // [a .. b] is our current bracket. d is the third best guess.
    R fa, fb, fd; // Values of f at a, b, d.
    bool done = false; // Has a root been found?

    // Allow ax and bx to be provided in reverse order
    if (ax <= bx)
    {
        a = ax; fa = fax;
        b = bx; fb = fbx;
    }
    else
    {
        a = bx; fa = fbx;
        b = ax; fb = fax;
    }

    // Test the function at point c; update brackets accordingly
    void bracket(T c)
    {
        R fc = f(c);
        if (fc == 0 || fc.isNaN()) // Exact solution, or NaN
        {
            a = c;
            fa = fc;
            d = c;
            fd = fc;
            done = true;
            return;
        }

        // Determine new enclosing interval
        if (signbit(fa) != signbit(fc))
        {
            d = b;
            fd = fb;
            b = c;
            fb = fc;
        }
        else
        {
            d = a;
            fd = fa;
            a = c;
            fa = fc;
        }
    }

    /* Perform a secant interpolation. If the result would lie on a or b, or if
       a and b differ so wildly in magnitude that the result would be meaningless,
       perform a bisection instead.
    */
    static T secant_interpolate(T a, T b, R fa, R fb)
    {
        if (( ((a - b) == a) && b != 0) || (a != 0 && ((b - a) == b)))
        {
            // Catastrophic cancellation
            if (a == 0)
                a = copysign(T(0), b);
            else if (b == 0)
                b = copysign(T(0), a);
            else if (signbit(a) != signbit(b))
                return 0;
            T c = ieeeMean(a, b);
            return c;
        }
        // avoid overflow
        if (b - a > T.max)
            return b / 2 + a / 2;
        if (fb - fa > R.max)
            return a - (b - a) / 2;
        T c = a - (fa / (fb - fa)) * (b - a);
        if (c == a || c == b)
            return (a + b) / 2;
        return c;
    }

    /* Uses 'numsteps' newton steps to approximate the zero in [a .. b] of the
       quadratic polynomial interpolating f(x) at a, b, and d.
       Returns:
       The approximate zero in [a .. b] of the quadratic polynomial.
958 */ 959 T newtonQuadratic(int numsteps) 960 { 961 // Find the coefficients of the quadratic polynomial. 962 immutable T a0 = fa; 963 immutable T a1 = (fb - fa)/(b - a); 964 immutable T a2 = ((fd - fb)/(d - b) - a1)/(d - a); 965 966 // Determine the starting point of newton steps. 967 T c = oppositeSigns(a2, fa) ? a : b; 968 969 // start the safeguarded newton steps. 970 foreach (int i; 0 .. numsteps) 971 { 972 immutable T pc = a0 + (a1 + a2 * (c - b))*(c - a); 973 immutable T pdc = a1 + a2*((2 * c) - (a + b)); 974 if (pdc == 0) 975 return a - a0 / a1; 976 else 977 c = c - pc / pdc; 978 } 979 return c; 980 } 981 982 // On the first iteration we take a secant step: 983 if (fa == 0 || fa.isNaN()) 984 { 985 done = true; 986 b = a; 987 fb = fa; 988 } 989 else if (fb == 0 || fb.isNaN()) 990 { 991 done = true; 992 a = b; 993 fa = fb; 994 } 995 else 996 { 997 bracket(secant_interpolate(a, b, fa, fb)); 998 } 999 1000 // Starting with the second iteration, higher-order interpolation can 1001 // be used. 1002 int itnum = 1; // Iteration number 1003 int baditer = 1; // Num bisections to take if an iteration is bad. 1004 T c, e; // e is our fourth best guess 1005 R fe; 1006 1007 whileloop: 1008 while (!done && (b != nextUp(a)) && !tolerance(a, b)) 1009 { 1010 T a0 = a, b0 = b; // record the brackets 1011 1012 // Do two higher-order (cubic or parabolic) interpolation steps. 1013 foreach (int QQ; 0 .. 2) 1014 { 1015 // Cubic inverse interpolation requires that 1016 // all four function values fa, fb, fd, and fe are distinct; 1017 // otherwise use quadratic interpolation. 1018 bool distinct = (fa != fb) && (fa != fd) && (fa != fe) 1019 && (fb != fd) && (fb != fe) && (fd != fe); 1020 // The first time, cubic interpolation is impossible. 1021 if (itnum<2) distinct = false; 1022 bool ok = distinct; 1023 if (distinct) 1024 { 1025 // Cubic inverse interpolation of f(x) at a, b, d, and e 1026 immutable q11 = (d - e) * fd / (fe - fd); 1027 immutable q21 = (b - d) * fb / (fd - fb); 1028 immutable q31 = (a - b) * fa / (fb - fa); 1029 immutable d21 = (b - d) * fd / (fd - fb); 1030 immutable d31 = (a - b) * fb / (fb - fa); 1031 1032 immutable q22 = (d21 - q11) * fb / (fe - fb); 1033 immutable q32 = (d31 - q21) * fa / (fd - fa); 1034 immutable d32 = (d31 - q21) * fd / (fd - fa); 1035 immutable q33 = (d32 - q22) * fa / (fe - fa); 1036 c = a + (q31 + q32 + q33); 1037 if (c.isNaN() || (c <= a) || (c >= b)) 1038 { 1039 // DAC: If the interpolation predicts a or b, it's 1040 // probable that it's the actual root. Only allow this if 1041 // we're already close to the root. 1042 if (c == a && a - b != a) 1043 { 1044 c = nextUp(a); 1045 } 1046 else if (c == b && a - b != -b) 1047 { 1048 c = nextDown(b); 1049 } 1050 else 1051 { 1052 ok = false; 1053 } 1054 } 1055 } 1056 if (!ok) 1057 { 1058 // DAC: Alefeld doesn't explain why the number of newton steps 1059 // should vary. 1060 c = newtonQuadratic(distinct ? 
3 : 2); 1061 if (c.isNaN() || (c <= a) || (c >= b)) 1062 { 1063 // Failure, try a secant step: 1064 c = secant_interpolate(a, b, fa, fb); 1065 } 1066 } 1067 ++itnum; 1068 e = d; 1069 fe = fd; 1070 bracket(c); 1071 if (done || ( b == nextUp(a)) || tolerance(a, b)) 1072 break whileloop; 1073 if (itnum == 2) 1074 continue whileloop; 1075 } 1076 1077 // Now we take a double-length secant step: 1078 T u; 1079 R fu; 1080 if (fabs(fa) < fabs(fb)) 1081 { 1082 u = a; 1083 fu = fa; 1084 } 1085 else 1086 { 1087 u = b; 1088 fu = fb; 1089 } 1090 c = u - 2 * (fu / (fb - fa)) * (b - a); 1091 1092 // DAC: If the secant predicts a value equal to an endpoint, it's 1093 // probably false. 1094 if (c == a || c == b || c.isNaN() || fabs(c - u) > (b - a) / 2) 1095 { 1096 if ((a-b) == a || (b-a) == b) 1097 { 1098 if ((a>0 && b<0) || (a<0 && b>0)) 1099 c = 0; 1100 else 1101 { 1102 if (a == 0) 1103 c = ieeeMean(copysign(T(0), b), b); 1104 else if (b == 0) 1105 c = ieeeMean(copysign(T(0), a), a); 1106 else 1107 c = ieeeMean(a, b); 1108 } 1109 } 1110 else 1111 { 1112 c = a + (b - a) / 2; 1113 } 1114 } 1115 e = d; 1116 fe = fd; 1117 bracket(c); 1118 if (done || (b == nextUp(a)) || tolerance(a, b)) 1119 break; 1120 1121 // IMPROVE THE WORST-CASE PERFORMANCE 1122 // We must ensure that the bounds reduce by a factor of 2 1123 // in binary space! every iteration. If we haven't achieved this 1124 // yet, or if we don't yet know what the exponent is, 1125 // perform a binary chop. 1126 1127 if ((a == 0 || b == 0 || 1128 (fabs(a) >= T(0.5) * fabs(b) && fabs(b) >= T(0.5) * fabs(a))) 1129 && (b - a) < T(0.25) * (b0 - a0)) 1130 { 1131 baditer = 1; 1132 continue; 1133 } 1134 1135 // DAC: If this happens on consecutive iterations, we probably have a 1136 // pathological function. Perform a number of bisections equal to the 1137 // total number of consecutive bad iterations. 1138 1139 if ((b - a) < T(0.25) * (b0 - a0)) 1140 baditer = 1; 1141 foreach (int QQ; 0 .. 
baditer) 1142 { 1143 e = d; 1144 fe = fd; 1145 1146 T w; 1147 if ((a>0 && b<0) || (a<0 && b>0)) 1148 w = 0; 1149 else 1150 { 1151 T usea = a; 1152 T useb = b; 1153 if (a == 0) 1154 usea = copysign(T(0), b); 1155 else if (b == 0) 1156 useb = copysign(T(0), a); 1157 w = ieeeMean(usea, useb); 1158 } 1159 bracket(w); 1160 } 1161 ++baditer; 1162 } 1163 return Tuple!(T, T, R, R)(a, b, fa, fb); 1164 } 1165 1166 ///ditto 1167 Tuple!(T, T, R, R) findRoot(T, R, DF)(scope DF f, in T ax, in T bx, in R fax, in R fbx) 1168 { 1169 return findRoot(f, ax, bx, fax, fbx, (T a, T b) => false); 1170 } 1171 1172 ///ditto 1173 T findRoot(T, R)(scope R delegate(T) f, in T a, in T b, 1174 scope bool delegate(T lo, T hi) tolerance = (T a, T b) => false) 1175 { 1176 return findRoot!(T, R delegate(T), bool delegate(T lo, T hi))(f, a, b, tolerance); 1177 } 1178 1179 @safe nothrow unittest 1180 { 1181 int numProblems = 0; 1182 int numCalls; 1183 1184 void testFindRoot(real delegate(real) @nogc @safe nothrow pure f , real x1, real x2) @nogc @safe nothrow pure 1185 { 1186 //numCalls=0; 1187 //++numProblems; 1188 assert(!x1.isNaN() && !x2.isNaN()); 1189 assert(signbit(x1) != signbit(x2)); 1190 auto result = findRoot(f, x1, x2, f(x1), f(x2), 1191 (real lo, real hi) { return false; }); 1192 1193 auto flo = f(result[0]); 1194 auto fhi = f(result[1]); 1195 if (flo != 0) 1196 { 1197 assert(oppositeSigns(flo, fhi)); 1198 } 1199 } 1200 1201 // Test functions 1202 real cubicfn(real x) @nogc @safe nothrow pure 1203 { 1204 //++numCalls; 1205 if (x>float.max) 1206 x = float.max; 1207 if (x<-double.max) 1208 x = -double.max; 1209 // This has a single real root at -59.286543284815 1210 return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2; 1211 } 1212 // Test a function with more than one root. 1213 real multisine(real x) { ++numCalls; return sin(x); } 1214 //testFindRoot( &multisine, 6, 90); 1215 //testFindRoot(&cubicfn, -100, 100); 1216 //testFindRoot( &cubicfn, -double.max, real.max); 1217 1218 1219 /* Tests from the paper: 1220 * "On Enclosing Simple Roots of Nonlinear Equations", G. Alefeld, F.A. Potra, 1221 * Yixun Shi, Mathematics of Computation 61, pp733-744 (1993). 1222 */ 1223 // Parameters common to many alefeld tests. 1224 int n; 1225 real ale_a, ale_b; 1226 1227 int powercalls = 0; 1228 1229 real power(real x) 1230 { 1231 ++powercalls; 1232 ++numCalls; 1233 return pow(x, n) + double.min_normal; 1234 } 1235 int [] power_nvals = [3, 5, 7, 9, 19, 25]; 1236 // Alefeld paper states that pow(x,n) is a very poor case, where bisection 1237 // outperforms his method, and gives total numcalls = 1238 // 921 for bisection (2.4 calls per bit), 1830 for Alefeld (4.76/bit), 1239 // 2624 for brent (6.8/bit) 1240 // ... but that is for double, not real80. 1241 // This poor performance seems mainly due to catastrophic cancellation, 1242 // which is avoided here by the use of ieeeMean(). 1243 // I get: 231 (0.48/bit). 
1244 // IE this is 10X faster in Alefeld's worst case 1245 numProblems=0; 1246 foreach (k; power_nvals) 1247 { 1248 n = k; 1249 //testFindRoot(&power, -1, 10); 1250 } 1251 1252 int powerProblems = numProblems; 1253 1254 // Tests from Alefeld paper 1255 1256 int [9] alefeldSums; 1257 real alefeld0(real x) 1258 { 1259 ++alefeldSums[0]; 1260 ++numCalls; 1261 real q = sin(x) - x/2; 1262 for (int i=1; i<20; ++i) 1263 q+=(2*i-5.0)*(2*i-5.0)/((x-i*i)*(x-i*i)*(x-i*i)); 1264 return q; 1265 } 1266 real alefeld1(real x) 1267 { 1268 ++numCalls; 1269 ++alefeldSums[1]; 1270 return ale_a*x + exp(ale_b * x); 1271 } 1272 real alefeld2(real x) 1273 { 1274 ++numCalls; 1275 ++alefeldSums[2]; 1276 return pow(x, n) - ale_a; 1277 } 1278 real alefeld3(real x) 1279 { 1280 ++numCalls; 1281 ++alefeldSums[3]; 1282 return (1.0 +pow(1.0L-n, 2))*x - pow(1.0L-n*x, 2); 1283 } 1284 real alefeld4(real x) 1285 { 1286 ++numCalls; 1287 ++alefeldSums[4]; 1288 return x*x - pow(1-x, n); 1289 } 1290 real alefeld5(real x) 1291 { 1292 ++numCalls; 1293 ++alefeldSums[5]; 1294 return (1+pow(1.0L-n, 4))*x - pow(1.0L-n*x, 4); 1295 } 1296 real alefeld6(real x) 1297 { 1298 ++numCalls; 1299 ++alefeldSums[6]; 1300 return exp(-n*x)*(x-1.01L) + pow(x, n); 1301 } 1302 real alefeld7(real x) 1303 { 1304 ++numCalls; 1305 ++alefeldSums[7]; 1306 return (n*x-1)/((n-1)*x); 1307 } 1308 1309 numProblems=0; 1310 //testFindRoot(&alefeld0, PI_2, PI); 1311 for (n=1; n <= 10; ++n) 1312 { 1313 //testFindRoot(&alefeld0, n*n+1e-9L, (n+1)*(n+1)-1e-9L); 1314 } 1315 ale_a = -40; ale_b = -1; 1316 //testFindRoot(&alefeld1, -9, 31); 1317 ale_a = -100; ale_b = -2; 1318 //testFindRoot(&alefeld1, -9, 31); 1319 ale_a = -200; ale_b = -3; 1320 //testFindRoot(&alefeld1, -9, 31); 1321 int [] nvals_3 = [1, 2, 5, 10, 15, 20]; 1322 int [] nvals_5 = [1, 2, 4, 5, 8, 15, 20]; 1323 int [] nvals_6 = [1, 5, 10, 15, 20]; 1324 int [] nvals_7 = [2, 5, 15, 20]; 1325 1326 for (int i=4; i<12; i+=2) 1327 { 1328 n = i; 1329 ale_a = 0.2; 1330 //testFindRoot(&alefeld2, 0, 5); 1331 ale_a=1; 1332 //testFindRoot(&alefeld2, 0.95, 4.05); 1333 //testFindRoot(&alefeld2, 0, 1.5); 1334 } 1335 foreach (i; nvals_3) 1336 { 1337 n=i; 1338 //testFindRoot(&alefeld3, 0, 1); 1339 } 1340 foreach (i; nvals_3) 1341 { 1342 n=i; 1343 //testFindRoot(&alefeld4, 0, 1); 1344 } 1345 foreach (i; nvals_5) 1346 { 1347 n=i; 1348 //testFindRoot(&alefeld5, 0, 1); 1349 } 1350 foreach (i; nvals_6) 1351 { 1352 n=i; 1353 //testFindRoot(&alefeld6, 0, 1); 1354 } 1355 foreach (i; nvals_7) 1356 { 1357 n=i; 1358 //testFindRoot(&alefeld7, 0.01L, 1); 1359 } 1360 real worstcase(real x) 1361 { 1362 ++numCalls; 1363 return x<0.3*real.max? 
            -0.999e-3 : 1.0;
    }
    //testFindRoot(&worstcase, -real.max, real.max);

    // just check that the double + float cases compile
    //findRoot((double x){ return 0.0; }, -double.max, double.max);
    //findRoot((float x){ return 0.0f; }, -float.max, float.max);

    /*
    int grandtotal=0;
    foreach (calls; alefeldSums)
    {
        grandtotal+=calls;
    }
    grandtotal-=2*numProblems;
    printf("\nALEFELD TOTAL = %d avg = %f (alefeld avg=19.3 for double)\n",
        grandtotal, (1.0*grandtotal)/numProblems);
    powercalls -= 2*powerProblems;
    printf("POWER TOTAL = %d avg = %f ", powercalls,
        (1.0*powercalls)/powerProblems);
    */
    //Issue 14231
    auto xp = findRoot((float x) => x, 0f, 1f);
    auto xn = findRoot((float x) => x, -1f, -0f);
}

//regression control
@system unittest
{
    // @system due to the case in the 2nd line
    static assert(__traits(compiles, findRoot((float x)=>cast(real) x, float.init, float.init)));
    static assert(__traits(compiles, findRoot!real((x)=>cast(double) x, real.init, real.init)));
    static assert(__traits(compiles, findRoot((real x)=>cast(double) x, real.init, real.init)));
}
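
// A hedged sketch of the tolerance delegate described above: stop refining as
// soon as the bracket is narrower than 1e-3. Illustrative only.
@safe unittest
{
    import std.math : approxEqual, cos;

    auto r = findRoot((real x) => cos(x) - x, 0.0L, 1.0L,
        (real lo, real hi) => hi - lo < 1e-3L);
    assert(approxEqual(r, 0.739085L));
}
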
/++
Find a real minimum of a real function `f(x)` via bracketing.
Given a function `f` and a range `(ax .. bx)`,
returns the value of `x` in the range which is closest to a minimum of `f(x)`.
`f` is never evaluated at the endpoints of `ax` and `bx`.
If `f(x)` has more than one minimum in the range, one will be chosen arbitrarily.
If `f(x)` returns NaN or -Infinity, `(x, f(x), NaN)` will be returned;
otherwise, this algorithm is guaranteed to succeed.

Params:
    f = Function to be analyzed
    ax = Left bound of initial range of f known to contain the minimum.
    bx = Right bound of initial range of f known to contain the minimum.
    relTolerance = Relative tolerance.
    absTolerance = Absolute tolerance.

Preconditions:
    `ax` and `bx` shall be finite reals. $(BR)
    $(D relTolerance) shall be normal positive real. $(BR)
    $(D absTolerance) shall be normal positive real no less than $(D T.epsilon*2).

Returns:
    A tuple consisting of `x`, `y = f(x)` and `error = 3 * (absTolerance * fabs(x) + relTolerance)`.

The method used is a combination of golden section search and
successive parabolic interpolation. Convergence is never much slower
than that for a Fibonacci search.

References:
    "Algorithms for Minimization without Derivatives", Richard Brent, Prentice-Hall, Inc. (1973)

See_Also: $(LREF findRoot), $(REF isNormal, std,math)
+/
Tuple!(T, "x", Unqual!(ReturnType!DF), "y", T, "error")
findLocalMin(T, DF)(
    scope DF f,
    in T ax,
    in T bx,
    in T relTolerance = sqrt(T.epsilon),
    in T absTolerance = sqrt(T.epsilon),
    )
if (isFloatingPoint!T
    && __traits(compiles, {T _ = DF.init(T.init);}))
in
{
    assert(isFinite(ax), "ax is not finite");
    assert(isFinite(bx), "bx is not finite");
    assert(isNormal(relTolerance), "relTolerance is not normal floating point number");
    assert(isNormal(absTolerance), "absTolerance is not normal floating point number");
    assert(relTolerance >= 0, "relTolerance is not positive");
    assert(absTolerance >= T.epsilon*2, "absTolerance is not greater than `2*T.epsilon`");
}
out (result)
{
    assert(isFinite(result.x));
}
body
{
    alias R = Unqual!(CommonType!(ReturnType!DF, T));
    // c is the squared inverse of the golden ratio
    // (3 - sqrt(5))/2
    // Value obtained from Wolfram Alpha.
    enum T c = 0x0.61c8864680b583ea0c633f9fa31237p+0L;
    enum T cm1 = 0x0.9e3779b97f4a7c15f39cc0605cedc8p+0L;
    R tolerance;
    T a = ax > bx ? bx : ax;
    T b = ax > bx ? ax : bx;
    // sequence of declarations suitable for SIMD instructions
    T v = a * cm1 + b * c;
    assert(isFinite(v));
    R fv = f(v);
    if (isNaN(fv) || fv == -T.infinity)
    {
        return typeof(return)(v, fv, T.init);
    }
    T w = v;
    R fw = fv;
    T x = v;
    R fx = fv;
    size_t i;
    for (R d = 0, e = 0;;)
    {
        i++;
        T m = (a + b) / 2;
        // This fix is not part of the original algorithm
        if (!isFinite(m)) // fix infinite loop. Issue can be reproduced in R.
        {
            m = a / 2 + b / 2;
            if (!isFinite(m)) // fast-math compiler switch is enabled
            {
                //SIMD instructions can be used by compiler, do not reduce declarations
                int a_exp = void;
                int b_exp = void;
                immutable an = frexp(a, a_exp);
                immutable bn = frexp(b, b_exp);
                immutable am = ldexp(an, a_exp-1);
                immutable bm = ldexp(bn, b_exp-1);
                m = am + bm;
                if (!isFinite(m)) // wrong input: constraints are disabled in release mode
                {
                    return typeof(return).init;
                }
            }
        }
        tolerance = absTolerance * fabs(x) + relTolerance;
        immutable t2 = tolerance * 2;
        // check stopping criterion
        if (!(fabs(x - m) > t2 - (b - a) / 2))
        {
            break;
        }
        R p = 0;
        R q = 0;
        R r = 0;
        // fit parabola
        if (fabs(e) > tolerance)
        {
            immutable xw = x - w;
            immutable fxw = fx - fw;
            immutable xv = x - v;
            immutable fxv = fx - fv;
            immutable xwfxv = xw * fxv;
            immutable xvfxw = xv * fxw;
            p = xv * xvfxw - xw * xwfxv;
            q = (xvfxw - xwfxv) * 2;
            if (q > 0)
                p = -p;
            else
                q = -q;
            r = e;
            e = d;
        }
        T u;
        // a parabolic-interpolation step
        if (fabs(p) < fabs(q * r / 2) && p > q * (a - x) && p < q * (b - x))
        {
            d = p / q;
            u = x + d;
            // f must not be evaluated too close to a or b
            if (u - a < t2 || b - u < t2)
                d = x < m ? tolerance : -tolerance;
        }
        // a golden-section step
        else
        {
            e = (x < m ? b : a) - x;
            d = c * e;
        }
        // f must not be evaluated too close to x
        u = x + (fabs(d) >= tolerance ? d : d > 0 ? tolerance : -tolerance);
        immutable fu = f(u);
        if (isNaN(fu) || fu == -T.infinity)
        {
            return typeof(return)(u, fu, T.init);
        }
        // update a, b, v, w, and x
        if (fu <= fx)
        {
            (u < x ? b : a) = x;
            v = w; fv = fw;
            w = x; fw = fx;
            x = u; fx = fu;
        }
        else
        {
            (u < x ? a : b) = u;
            if (fu <= fw || w == x)
            {
                v = w; fv = fw;
                w = u; fw = fu;
            }
            else if (fu <= fv || v == x || v == w)
            { // do not remove these braces
                v = u; fv = fu;
            }
        }
    }
    return typeof(return)(x, fx, tolerance * 3);
}

///
@safe unittest
{
    import std.math : approxEqual;

    auto ret = findLocalMin((double x) => (x-4)^^2, -1e7, 1e7);
    assert(ret.x.approxEqual(4.0));
    assert(ret.y.approxEqual(0.0));
}

@safe unittest
{
    import std.meta : AliasSeq;
    foreach (T; AliasSeq!(double, float, real))
    {
        {
            auto ret = findLocalMin!T((T x) => (x-4)^^2, T.min_normal, 1e7);
            assert(ret.x.approxEqual(T(4)));
            assert(ret.y.approxEqual(T(0)));
        }
        {
            auto ret = findLocalMin!T((T x) => fabs(x-1), -T.max/4, T.max/4, T.min_normal, 2*T.epsilon);
            assert(approxEqual(ret.x, T(1)));
            assert(approxEqual(ret.y, T(0)));
            assert(ret.error <= 10 * T.epsilon);
        }
        {
            auto ret = findLocalMin!T((T x) => T.init, 0, 1, T.min_normal, 2*T.epsilon);
            assert(!ret.x.isNaN);
            assert(ret.y.isNaN);
            assert(ret.error.isNaN);
        }
        {
            auto ret = findLocalMin!T((T x) => log(x), 0, 1, T.min_normal, 2*T.epsilon);
            assert(ret.error < 3.00001 * ((2*T.epsilon)*fabs(ret.x) + T.min_normal));
            assert(ret.x >= 0 && ret.x <= ret.error);
        }
        {
            auto ret = findLocalMin!T((T x) => log(x), 0, T.max, T.min_normal, 2*T.epsilon);
            assert(ret.y < -18);
            assert(ret.error < 5e-08);
            assert(ret.x >= 0 && ret.x <= ret.error);
        }
        {
            auto ret = findLocalMin!T((T x) => -fabs(x), -1, 1, T.min_normal, 2*T.epsilon);
            assert(ret.x.fabs.approxEqual(T(1)));
            assert(ret.y.fabs.approxEqual(T(1)));
            assert(ret.error.approxEqual(T(0)));
        }
    }
}

/**
Computes $(LINK2 https://en.wikipedia.org/wiki/Euclidean_distance,
Euclidean distance) between input ranges $(D a) and
$(D b). The two ranges must have the same length. The three-parameter
version stops computation as soon as the distance is greater than or
equal to $(D limit) (this is useful to save computation if a small
distance is sought).
1637 */ 1638 CommonType!(ElementType!(Range1), ElementType!(Range2)) 1639 euclideanDistance(Range1, Range2)(Range1 a, Range2 b) 1640 if (isInputRange!(Range1) && isInputRange!(Range2)) 1641 { 1642 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 1643 static if (haveLen) assert(a.length == b.length); 1644 Unqual!(typeof(return)) result = 0; 1645 for (; !a.empty; a.popFront(), b.popFront()) 1646 { 1647 immutable t = a.front - b.front; 1648 result += t * t; 1649 } 1650 static if (!haveLen) assert(b.empty); 1651 return sqrt(result); 1652 } 1653 1654 /// Ditto 1655 CommonType!(ElementType!(Range1), ElementType!(Range2)) 1656 euclideanDistance(Range1, Range2, F)(Range1 a, Range2 b, F limit) 1657 if (isInputRange!(Range1) && isInputRange!(Range2)) 1658 { 1659 limit *= limit; 1660 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 1661 static if (haveLen) assert(a.length == b.length); 1662 Unqual!(typeof(return)) result = 0; 1663 for (; ; a.popFront(), b.popFront()) 1664 { 1665 if (a.empty) 1666 { 1667 static if (!haveLen) assert(b.empty); 1668 break; 1669 } 1670 immutable t = a.front - b.front; 1671 result += t * t; 1672 if (result >= limit) break; 1673 } 1674 return sqrt(result); 1675 } 1676 1677 @safe unittest 1678 { 1679 import std.meta : AliasSeq; 1680 foreach (T; AliasSeq!(double, const double, immutable double)) 1681 { 1682 T[] a = [ 1.0, 2.0, ]; 1683 T[] b = [ 4.0, 6.0, ]; 1684 assert(euclideanDistance(a, b) == 5); 1685 assert(euclideanDistance(a, b, 5) == 5); 1686 assert(euclideanDistance(a, b, 4) == 5); 1687 assert(euclideanDistance(a, b, 2) == 3); 1688 } 1689 } 1690 1691 /** 1692 Computes the $(LINK2 https://en.wikipedia.org/wiki/Dot_product, 1693 dot product) of input ranges $(D a) and $(D 1694 b). The two ranges must have the same length. If both ranges define 1695 length, the check is done once; otherwise, it is done at each 1696 iteration. 
1697 */ 1698 CommonType!(ElementType!(Range1), ElementType!(Range2)) 1699 dotProduct(Range1, Range2)(Range1 a, Range2 b) 1700 if (isInputRange!(Range1) && isInputRange!(Range2) && 1701 !(isArray!(Range1) && isArray!(Range2))) 1702 { 1703 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 1704 static if (haveLen) assert(a.length == b.length); 1705 Unqual!(typeof(return)) result = 0; 1706 for (; !a.empty; a.popFront(), b.popFront()) 1707 { 1708 result += a.front * b.front; 1709 } 1710 static if (!haveLen) assert(b.empty); 1711 return result; 1712 } 1713 1714 /// Ditto 1715 CommonType!(F1, F2) 1716 dotProduct(F1, F2)(in F1[] avector, in F2[] bvector) 1717 { 1718 immutable n = avector.length; 1719 assert(n == bvector.length); 1720 auto avec = avector.ptr, bvec = bvector.ptr; 1721 Unqual!(typeof(return)) sum0 = 0, sum1 = 0; 1722 1723 const all_endp = avec + n; 1724 const smallblock_endp = avec + (n & ~3); 1725 const bigblock_endp = avec + (n & ~15); 1726 1727 for (; avec != bigblock_endp; avec += 16, bvec += 16) 1728 { 1729 sum0 += avec[0] * bvec[0]; 1730 sum1 += avec[1] * bvec[1]; 1731 sum0 += avec[2] * bvec[2]; 1732 sum1 += avec[3] * bvec[3]; 1733 sum0 += avec[4] * bvec[4]; 1734 sum1 += avec[5] * bvec[5]; 1735 sum0 += avec[6] * bvec[6]; 1736 sum1 += avec[7] * bvec[7]; 1737 sum0 += avec[8] * bvec[8]; 1738 sum1 += avec[9] * bvec[9]; 1739 sum0 += avec[10] * bvec[10]; 1740 sum1 += avec[11] * bvec[11]; 1741 sum0 += avec[12] * bvec[12]; 1742 sum1 += avec[13] * bvec[13]; 1743 sum0 += avec[14] * bvec[14]; 1744 sum1 += avec[15] * bvec[15]; 1745 } 1746 1747 for (; avec != smallblock_endp; avec += 4, bvec += 4) 1748 { 1749 sum0 += avec[0] * bvec[0]; 1750 sum1 += avec[1] * bvec[1]; 1751 sum0 += avec[2] * bvec[2]; 1752 sum1 += avec[3] * bvec[3]; 1753 } 1754 1755 sum0 += sum1; 1756 1757 /* Do trailing portion in naive loop. */ 1758 while (avec != all_endp) 1759 { 1760 sum0 += *avec * *bvec; 1761 ++avec; 1762 ++bvec; 1763 } 1764 1765 return sum0; 1766 } 1767 1768 @system unittest 1769 { 1770 // @system due to dotProduct and assertCTFEable 1771 import std.exception : assertCTFEable; 1772 import std.meta : AliasSeq; 1773 foreach (T; AliasSeq!(double, const double, immutable double)) 1774 { 1775 T[] a = [ 1.0, 2.0, ]; 1776 T[] b = [ 4.0, 6.0, ]; 1777 assert(dotProduct(a, b) == 16); 1778 assert(dotProduct([1, 3, -5], [4, -2, -1]) == 3); 1779 } 1780 1781 // Make sure the unrolled loop codepath gets tested. 1782 static const x = 1783 [1.0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]; 1784 static const y = 1785 [2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]; 1786 assertCTFEable!({ assert(dotProduct(x, y) == 2280); }); 1787 } 1788 1789 /** 1790 Computes the $(LINK2 https://en.wikipedia.org/wiki/Cosine_similarity, 1791 cosine similarity) of input ranges $(D a) and $(D 1792 b). The two ranges must have the same length. If both ranges define 1793 length, the check is done once; otherwise, it is done at each 1794 iteration. If either range has all-zero elements, return 0. 
1795 */ 1796 CommonType!(ElementType!(Range1), ElementType!(Range2)) 1797 cosineSimilarity(Range1, Range2)(Range1 a, Range2 b) 1798 if (isInputRange!(Range1) && isInputRange!(Range2)) 1799 { 1800 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 1801 static if (haveLen) assert(a.length == b.length); 1802 Unqual!(typeof(return)) norma = 0, normb = 0, dotprod = 0; 1803 for (; !a.empty; a.popFront(), b.popFront()) 1804 { 1805 immutable t1 = a.front, t2 = b.front; 1806 norma += t1 * t1; 1807 normb += t2 * t2; 1808 dotprod += t1 * t2; 1809 } 1810 static if (!haveLen) assert(b.empty); 1811 if (norma == 0 || normb == 0) return 0; 1812 return dotprod / sqrt(norma * normb); 1813 } 1814 1815 @safe unittest 1816 { 1817 import std.meta : AliasSeq; 1818 foreach (T; AliasSeq!(double, const double, immutable double)) 1819 { 1820 T[] a = [ 1.0, 2.0, ]; 1821 T[] b = [ 4.0, 3.0, ]; 1822 assert(approxEqual( 1823 cosineSimilarity(a, b), 10.0 / sqrt(5.0 * 25), 1824 0.01)); 1825 } 1826 } 1827 1828 /** 1829 Normalizes values in $(D range) by multiplying each element with a 1830 number chosen such that values sum up to $(D sum). If elements in $(D 1831 range) sum to zero, assigns $(D sum / range.length) to 1832 all. Normalization makes sense only if all elements in $(D range) are 1833 positive. $(D normalize) assumes that is the case without checking it. 1834 1835 Returns: $(D true) if normalization completed normally, $(D false) if 1836 all elements in $(D range) were zero or if $(D range) is empty. 1837 */ 1838 bool normalize(R)(R range, ElementType!(R) sum = 1) 1839 if (isForwardRange!(R)) 1840 { 1841 ElementType!(R) s = 0; 1842 // Step 1: Compute sum and length of the range 1843 static if (hasLength!(R)) 1844 { 1845 const length = range.length; 1846 foreach (e; range) 1847 { 1848 s += e; 1849 } 1850 } 1851 else 1852 { 1853 uint length = 0; 1854 foreach (e; range) 1855 { 1856 s += e; 1857 ++length; 1858 } 1859 } 1860 // Step 2: perform normalization 1861 if (s == 0) 1862 { 1863 if (length) 1864 { 1865 immutable f = sum / range.length; 1866 foreach (ref e; range) e = f; 1867 } 1868 return false; 1869 } 1870 // The path most traveled 1871 assert(s >= 0); 1872 immutable f = sum / s; 1873 foreach (ref e; range) 1874 e *= f; 1875 return true; 1876 } 1877 1878 /// 1879 @safe unittest 1880 { 1881 double[] a = []; 1882 assert(!normalize(a)); 1883 a = [ 1.0, 3.0 ]; 1884 assert(normalize(a)); 1885 assert(a == [ 0.25, 0.75 ]); 1886 a = [ 0.0, 0.0 ]; 1887 assert(!normalize(a)); 1888 assert(a == [ 0.5, 0.5 ]); 1889 } 1890 1891 /** 1892 Compute the sum of binary logarithms of the input range $(D r). 1893 The error of this method is much smaller than with a naive sum of log2. 
1894 */ 1895 ElementType!Range sumOfLog2s(Range)(Range r) 1896 if (isInputRange!Range && isFloatingPoint!(ElementType!Range)) 1897 { 1898 long exp = 0; 1899 Unqual!(typeof(return)) x = 1; 1900 foreach (e; r) 1901 { 1902 if (e < 0) 1903 return typeof(return).nan; 1904 int lexp = void; 1905 x *= frexp(e, lexp); 1906 exp += lexp; 1907 if (x < 0.5) 1908 { 1909 x *= 2; 1910 exp--; 1911 } 1912 } 1913 return exp + log2(x); 1914 } 1915 1916 /// 1917 @safe unittest 1918 { 1919 import std.math : isNaN; 1920 1921 assert(sumOfLog2s(new double[0]) == 0); 1922 assert(sumOfLog2s([0.0L]) == -real.infinity); 1923 assert(sumOfLog2s([-0.0L]) == -real.infinity); 1924 assert(sumOfLog2s([2.0L]) == 1); 1925 assert(sumOfLog2s([-2.0L]).isNaN()); 1926 assert(sumOfLog2s([real.nan]).isNaN()); 1927 assert(sumOfLog2s([-real.nan]).isNaN()); 1928 assert(sumOfLog2s([real.infinity]) == real.infinity); 1929 assert(sumOfLog2s([-real.infinity]).isNaN()); 1930 assert(sumOfLog2s([ 0.25, 0.25, 0.25, 0.125 ]) == -9); 1931 } 1932 1933 /** 1934 Computes $(LINK2 https://en.wikipedia.org/wiki/Entropy_(information_theory), 1935 _entropy) of input range $(D r) in bits. This 1936 function assumes (without checking) that the values in $(D r) are all 1937 in $(D [0, 1]). For the entropy to be meaningful, often $(D r) should 1938 be normalized too (i.e., its values should sum to 1). The 1939 two-parameter version stops evaluating as soon as the intermediate 1940 result is greater than or equal to $(D max). 1941 */ 1942 ElementType!Range entropy(Range)(Range r) 1943 if (isInputRange!Range) 1944 { 1945 Unqual!(typeof(return)) result = 0.0; 1946 for (;!r.empty; r.popFront) 1947 { 1948 if (!r.front) continue; 1949 result -= r.front * log2(r.front); 1950 } 1951 return result; 1952 } 1953 1954 /// Ditto 1955 ElementType!Range entropy(Range, F)(Range r, F max) 1956 if (isInputRange!Range && 1957 !is(CommonType!(ElementType!Range, F) == void)) 1958 { 1959 Unqual!(typeof(return)) result = 0.0; 1960 for (;!r.empty; r.popFront) 1961 { 1962 if (!r.front) continue; 1963 result -= r.front * log2(r.front); 1964 if (result >= max) break; 1965 } 1966 return result; 1967 } 1968 1969 @safe unittest 1970 { 1971 import std.meta : AliasSeq; 1972 foreach (T; AliasSeq!(double, const double, immutable double)) 1973 { 1974 T[] p = [ 0.0, 0, 0, 1 ]; 1975 assert(entropy(p) == 0); 1976 p = [ 0.25, 0.25, 0.25, 0.25 ]; 1977 assert(entropy(p) == 2); 1978 assert(entropy(p, 1) == 1); 1979 } 1980 } 1981 1982 /** 1983 Computes the $(LINK2 https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence, 1984 Kullback-Leibler divergence) between input ranges 1985 $(D a) and $(D b), which is the sum $(D ai * log(ai / bi)). The base 1986 of logarithm is 2. The ranges are assumed to contain elements in $(D 1987 [0, 1]). Usually the ranges are normalized probability distributions, 1988 but this is not required or checked by $(D 1989 kullbackLeiblerDivergence). If any element $(D bi) is zero and the 1990 corresponding element $(D ai) nonzero, returns infinity. (Otherwise, 1991 if $(D ai == 0 && bi == 0), the term $(D ai * log(ai / bi)) is 1992 considered zero.) If the inputs are normalized, the result is 1993 positive. 
1994 */ 1995 CommonType!(ElementType!Range1, ElementType!Range2) 1996 kullbackLeiblerDivergence(Range1, Range2)(Range1 a, Range2 b) 1997 if (isInputRange!(Range1) && isInputRange!(Range2)) 1998 { 1999 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2000 static if (haveLen) assert(a.length == b.length); 2001 Unqual!(typeof(return)) result = 0; 2002 for (; !a.empty; a.popFront(), b.popFront()) 2003 { 2004 immutable t1 = a.front; 2005 if (t1 == 0) continue; 2006 immutable t2 = b.front; 2007 if (t2 == 0) return result.infinity; 2008 assert(t1 > 0 && t2 > 0); 2009 result += t1 * log2(t1 / t2); 2010 } 2011 static if (!haveLen) assert(b.empty); 2012 return result; 2013 } 2014 2015 /// 2016 @safe unittest 2017 { 2018 import std.math : approxEqual; 2019 2020 double[] p = [ 0.0, 0, 0, 1 ]; 2021 assert(kullbackLeiblerDivergence(p, p) == 0); 2022 double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ]; 2023 assert(kullbackLeiblerDivergence(p1, p1) == 0); 2024 assert(kullbackLeiblerDivergence(p, p1) == 2); 2025 assert(kullbackLeiblerDivergence(p1, p) == double.infinity); 2026 double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ]; 2027 assert(approxEqual(kullbackLeiblerDivergence(p1, p2), 0.0719281)); 2028 assert(approxEqual(kullbackLeiblerDivergence(p2, p1), 0.0780719)); 2029 } 2030 2031 /** 2032 Computes the $(LINK2 https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence, 2033 Jensen-Shannon divergence) between $(D a) and $(D 2034 b), which is the sum $(D (ai * log(2 * ai / (ai + bi)) + bi * log(2 * 2035 bi / (ai + bi))) / 2). The base of logarithm is 2. The ranges are 2036 assumed to contain elements in $(D [0, 1]). Usually the ranges are 2037 normalized probability distributions, but this is not required or 2038 checked by $(D jensenShannonDivergence). If the inputs are normalized, 2039 the result is bounded within $(D [0, 1]). The three-parameter version 2040 stops evaluations as soon as the intermediate result is greater than 2041 or equal to $(D limit). 
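Equivalently (an illustrative sketch), the result is the average of the two
Kullback-Leibler divergences of $(D a) and $(D b) from their pointwise
average:

----
double[] p = [ 0.25, 0.75 ];
double[] q = [ 0.5, 0.5 ];
double[] m = [ 0.375, 0.625 ]; // pointwise average of p and q
assert(approxEqual(jensenShannonDivergence(p, q),
    (kullbackLeiblerDivergence(p, m) + kullbackLeiblerDivergence(q, m)) / 2));
----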
2042 */ 2043 CommonType!(ElementType!Range1, ElementType!Range2) 2044 jensenShannonDivergence(Range1, Range2)(Range1 a, Range2 b) 2045 if (isInputRange!Range1 && isInputRange!Range2 && 2046 is(CommonType!(ElementType!Range1, ElementType!Range2))) 2047 { 2048 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2049 static if (haveLen) assert(a.length == b.length); 2050 Unqual!(typeof(return)) result = 0; 2051 for (; !a.empty; a.popFront(), b.popFront()) 2052 { 2053 immutable t1 = a.front; 2054 immutable t2 = b.front; 2055 immutable avg = (t1 + t2) / 2; 2056 if (t1 != 0) 2057 { 2058 result += t1 * log2(t1 / avg); 2059 } 2060 if (t2 != 0) 2061 { 2062 result += t2 * log2(t2 / avg); 2063 } 2064 } 2065 static if (!haveLen) assert(b.empty); 2066 return result / 2; 2067 } 2068 2069 /// Ditto 2070 CommonType!(ElementType!Range1, ElementType!Range2) 2071 jensenShannonDivergence(Range1, Range2, F)(Range1 a, Range2 b, F limit) 2072 if (isInputRange!Range1 && isInputRange!Range2 && 2073 is(typeof(CommonType!(ElementType!Range1, ElementType!Range2).init 2074 >= F.init) : bool)) 2075 { 2076 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2077 static if (haveLen) assert(a.length == b.length); 2078 Unqual!(typeof(return)) result = 0; 2079 limit *= 2; 2080 for (; !a.empty; a.popFront(), b.popFront()) 2081 { 2082 immutable t1 = a.front; 2083 immutable t2 = b.front; 2084 immutable avg = (t1 + t2) / 2; 2085 if (t1 != 0) 2086 { 2087 result += t1 * log2(t1 / avg); 2088 } 2089 if (t2 != 0) 2090 { 2091 result += t2 * log2(t2 / avg); 2092 } 2093 if (result >= limit) break; 2094 } 2095 static if (!haveLen) assert(b.empty); 2096 return result / 2; 2097 } 2098 2099 /// 2100 @safe unittest 2101 { 2102 import std.math : approxEqual; 2103 2104 double[] p = [ 0.0, 0, 0, 1 ]; 2105 assert(jensenShannonDivergence(p, p) == 0); 2106 double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ]; 2107 assert(jensenShannonDivergence(p1, p1) == 0); 2108 assert(approxEqual(jensenShannonDivergence(p1, p), 0.548795)); 2109 double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ]; 2110 assert(approxEqual(jensenShannonDivergence(p1, p2), 0.0186218)); 2111 assert(approxEqual(jensenShannonDivergence(p2, p1), 0.0186218)); 2112 assert(approxEqual(jensenShannonDivergence(p2, p1, 0.005), 0.00602366)); 2113 } 2114 2115 /** 2116 The so-called "all-lengths gap-weighted string kernel" computes a 2117 similarity measure between $(D s) and $(D t) based on all of their 2118 common subsequences of all lengths. Gapped subsequences are also 2119 included. 2120 2121 To understand what $(D gapWeightedSimilarity(s, t, lambda)) computes, 2122 consider first the case $(D lambda = 1) and the strings $(D s = 2123 ["Hello", "brave", "new", "world"]) and $(D t = ["Hello", "new", 2124 "world"]). In that case, $(D gapWeightedSimilarity) counts the 2125 following matches: 2126 2127 $(OL $(LI three matches of length 1, namely $(D "Hello"), $(D "new"), 2128 and $(D "world");) $(LI three matches of length 2, namely ($(D 2129 "Hello", "new")), ($(D "Hello", "world")), and ($(D "new", "world"));) 2130 $(LI one match of length 3, namely ($(D "Hello", "new", "world")).)) 2131 2132 The call $(D gapWeightedSimilarity(s, t, 1)) simply counts all of 2133 these matches and adds them up, returning 7. 
2134 2135 ---- 2136 string[] s = ["Hello", "brave", "new", "world"]; 2137 string[] t = ["Hello", "new", "world"]; 2138 assert(gapWeightedSimilarity(s, t, 1) == 7); 2139 ---- 2140 2141 Note how the gaps in matching are simply ignored, for example ($(D 2142 "Hello", "new")) is deemed as good a match as ($(D "new", 2143 "world")). This may be too permissive for some applications. To 2144 eliminate gapped matches entirely, use $(D lambda = 0): 2145 2146 ---- 2147 string[] s = ["Hello", "brave", "new", "world"]; 2148 string[] t = ["Hello", "new", "world"]; 2149 assert(gapWeightedSimilarity(s, t, 0) == 4); 2150 ---- 2151 2152 The call above eliminated the gapped matches ($(D "Hello", "new")), 2153 ($(D "Hello", "world")), and ($(D "Hello", "new", "world")) from the 2154 tally. That leaves only 4 matches. 2155 2156 The most interesting case is when gapped matches still participate in 2157 the result, but not as strongly as ungapped matches. The result will 2158 be a smooth, fine-grained similarity measure between the input 2159 strings. This is where values of $(D lambda) between 0 and 1 enter 2160 into play: gapped matches are $(I exponentially penalized with the 2161 number of gaps) with base $(D lambda). This means that an ungapped 2162 match adds 1 to the return value; a match with one gap in either 2163 string adds $(D lambda) to the return value; ...; a match with a total 2164 of $(D n) gaps in both strings adds $(D pow(lambda, n)) to the return 2165 value. In the example above, we have 4 matches without gaps, 2 matches 2166 with one gap, and 1 match with three gaps. The latter match is ($(D 2167 "Hello", "world")), which has two gaps in the first string and one gap 2168 in the second string, totaling to three gaps. Summing these up we get 2169 $(D 4 + 2 * lambda + pow(lambda, 3)). 2170 2171 ---- 2172 string[] s = ["Hello", "brave", "new", "world"]; 2173 string[] t = ["Hello", "new", "world"]; 2174 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 0.5 * 2 + 0.125); 2175 ---- 2176 2177 $(D gapWeightedSimilarity) is useful wherever a smooth similarity 2178 measure between sequences allowing for approximate matches is 2179 needed. The examples above are given with words, but any sequences 2180 with elements comparable for equality are allowed, e.g. characters or 2181 numbers. $(D gapWeightedSimilarity) uses a highly optimized dynamic 2182 programming implementation that needs $(D 16 * min(s.length, 2183 t.length)) extra bytes of memory and $(BIGOH s.length * t.length) time 2184 to complete. 2185 */ 2186 F gapWeightedSimilarity(alias comp = "a == b", R1, R2, F)(R1 s, R2 t, F lambda) 2187 if (isRandomAccessRange!(R1) && hasLength!(R1) && 2188 isRandomAccessRange!(R2) && hasLength!(R2)) 2189 { 2190 import core.exception : onOutOfMemoryError; 2191 import core.stdc.stdlib : malloc, free; 2192 import std.algorithm.mutation : swap; 2193 import std.functional : binaryFun; 2194 2195 if (s.length < t.length) return gapWeightedSimilarity(t, s, lambda); 2196 if (!t.length) return 0; 2197 2198 auto dpvi = cast(F*) malloc(F.sizeof * 2 * t.length); 2199 if (!dpvi) 2200 onOutOfMemoryError(); 2201 2202 auto dpvi1 = dpvi + t.length; 2203 scope(exit) free(dpvi < dpvi1 ? dpvi : dpvi1); 2204 dpvi[0 .. t.length] = 0; 2205 dpvi1[0] = 0; 2206 immutable lambda2 = lambda * lambda; 2207 2208 F result = 0; 2209 foreach (i; 0 .. 
s.length) 2210 { 2211 const si = s[i]; 2212 for (size_t j = 0;;) 2213 { 2214 F dpsij = void; 2215 if (binaryFun!(comp)(si, t[j])) 2216 { 2217 dpsij = 1 + dpvi[j]; 2218 result += dpsij; 2219 } 2220 else 2221 { 2222 dpsij = 0; 2223 } 2224 immutable j1 = j + 1; 2225 if (j1 == t.length) break; 2226 dpvi1[j1] = dpsij + lambda * (dpvi1[j] + dpvi[j1]) - 2227 lambda2 * dpvi[j]; 2228 j = j1; 2229 } 2230 swap(dpvi, dpvi1); 2231 } 2232 return result; 2233 } 2234 2235 @system unittest 2236 { 2237 string[] s = ["Hello", "brave", "new", "world"]; 2238 string[] t = ["Hello", "new", "world"]; 2239 assert(gapWeightedSimilarity(s, t, 1) == 7); 2240 assert(gapWeightedSimilarity(s, t, 0) == 4); 2241 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 2 * 0.5 + 0.125); 2242 } 2243 2244 /** 2245 The similarity computed by $(D gapWeightedSimilarity) has an issue in that it 2246 grows with the lengths of the two strings, even though the strings are 2247 not actually very similar. For example, the range $(D ["Hello", 2248 "world"]) is increasingly similar to the range $(D ["Hello", 2249 "world", "world", "world",...]) as more instances of $(D "world") are 2250 appended. To prevent that, $(D gapWeightedSimilarityNormalized) 2251 computes a normalized version of the similarity, defined as 2252 $(D gapWeightedSimilarity(s, t, lambda) / 2253 sqrt(gapWeightedSimilarity(s, s, lambda) * gapWeightedSimilarity(t, t, 2254 lambda))). The function $(D gapWeightedSimilarityNormalized) (a 2255 so-called normalized kernel) is bounded in $(D [0, 1]), reaches $(D 0) 2256 only for ranges that don't match in any position, and $(D 1) only for 2257 identical ranges. 2258 2259 The optional parameters $(D sSelfSim) and $(D tSelfSim) are meant for 2260 avoiding duplicate computation. Many applications may have already 2261 computed $(D gapWeightedSimilarity(s, s, lambda)) and/or $(D 2262 gapWeightedSimilarity(t, t, lambda)). In that case, they can be passed 2263 as $(D sSelfSim) and $(D tSelfSim), respectively. 2264 */ 2265 Select!(isFloatingPoint!(F), F, double) 2266 gapWeightedSimilarityNormalized(alias comp = "a == b", R1, R2, F) 2267 (R1 s, R2 t, F lambda, F sSelfSim = F.init, F tSelfSim = F.init) 2268 if (isRandomAccessRange!(R1) && hasLength!(R1) && 2269 isRandomAccessRange!(R2) && hasLength!(R2)) 2270 { 2271 static bool uncomputed(F n) 2272 { 2273 static if (isFloatingPoint!(F)) 2274 return isNaN(n); 2275 else 2276 return n == n.init; 2277 } 2278 if (uncomputed(sSelfSim)) 2279 sSelfSim = gapWeightedSimilarity!(comp)(s, s, lambda); 2280 if (sSelfSim == 0) return 0; 2281 if (uncomputed(tSelfSim)) 2282 tSelfSim = gapWeightedSimilarity!(comp)(t, t, lambda); 2283 if (tSelfSim == 0) return 0; 2284 2285 return gapWeightedSimilarity!(comp)(s, t, lambda) / 2286 sqrt(cast(typeof(return)) sSelfSim * tSelfSim); 2287 } 2288 2289 /// 2290 @system unittest 2291 { 2292 import std.math : approxEqual, sqrt; 2293 2294 string[] s = ["Hello", "brave", "new", "world"]; 2295 string[] t = ["Hello", "new", "world"]; 2296 assert(gapWeightedSimilarity(s, s, 1) == 15); 2297 assert(gapWeightedSimilarity(t, t, 1) == 7); 2298 assert(gapWeightedSimilarity(s, t, 1) == 7); 2299 assert(approxEqual(gapWeightedSimilarityNormalized(s, t, 1), 2300 7.0 / sqrt(15.0 * 7), 0.01)); 2301 } 2302 2303 /** 2304 Similar to $(D gapWeightedSimilarity), but works in an incremental 2305 manner by first revealing the matches of length 1, then gapped matches 2306 of length 2, and so on. The memory requirement is $(BIGOH s.length * 2307 t.length). The time complexity is $(BIGOH s.length * t.length) 2308 for computing each step. The example following $(D gapWeightedSimilarityIncremental) below continues the previous example. 2309 2310 The implementation is based on the pseudocode in Fig. 4 of the paper 2311 $(HTTP jmlr.csail.mit.edu/papers/volume6/rousu05a/rousu05a.pdf, 2312 "Efficient Computation of Gapped Substring Kernels on Large Alphabets") 2313 by Rousu et al., with additional algorithmic and systems-level 2314 optimizations. 2315 */ 2316 struct GapWeightedSimilarityIncremental(Range, F = double) 2317 if (isRandomAccessRange!(Range) && hasLength!(Range)) 2318 { 2319 import core.stdc.stdlib : malloc, realloc, alloca, free; 2320 2321 private: 2322 Range s, t; 2323 F currentValue = 0; 2324 F* kl; 2325 size_t gram = void; 2326 F lambda = void, lambda2 = void; 2327 2328 public: 2329 /** 2330 Constructs an object given two ranges $(D s) and $(D t) and a penalty 2331 $(D lambda). Constructor completes in $(BIGOH s.length * t.length) 2332 time and computes all matches of length 1. 2333 */ 2334 this(Range s, Range t, F lambda) 2335 { 2336 import core.exception : onOutOfMemoryError; 2337 2338 assert(lambda > 0); 2339 this.gram = 0; 2340 this.lambda = lambda; 2341 this.lambda2 = lambda * lambda; // for efficiency only 2342 2343 size_t iMin = size_t.max, jMin = size_t.max, 2344 iMax = 0, jMax = 0; 2345 /* initialize */ 2346 Tuple!(size_t, size_t) * k0; 2347 size_t k0len; 2348 scope(exit) free(k0); 2349 currentValue = 0; 2350 foreach (i, si; s) 2351 { 2352 foreach (j; 0 .. t.length) 2353 { 2354 if (si != t[j]) continue; 2355 k0 = cast(typeof(k0)) realloc(k0, ++k0len * (*k0).sizeof); 2356 with (k0[k0len - 1]) 2357 { 2358 field[0] = i; 2359 field[1] = j; 2360 } 2361 // Maintain the minimum and maximum i and j 2362 if (iMin > i) iMin = i; 2363 if (iMax < i) iMax = i; 2364 if (jMin > j) jMin = j; 2365 if (jMax < j) jMax = j; 2366 } 2367 } 2368 2369 if (iMin > iMax) return; 2370 assert(k0len); 2371 2372 currentValue = k0len; 2373 // Chop strings down to the useful sizes 2374 s = s[iMin .. iMax + 1]; 2375 t = t[jMin .. jMax + 1]; 2376 this.s = s; 2377 this.t = t; 2378 2379 kl = cast(F*) malloc(s.length * t.length * F.sizeof); 2380 if (!kl) 2381 onOutOfMemoryError(); 2382 2383 kl[0 .. s.length * t.length] = 0; 2384 foreach (pos; 0 .. k0len) 2385 { 2386 with (k0[pos]) 2387 { 2388 kl[(field[0] - iMin) * t.length + field[1] - jMin] = lambda2; 2389 } 2390 } 2391 } 2392 2393 /** 2394 Returns: $(D this). 2395 */ 2396 ref GapWeightedSimilarityIncremental opSlice() 2397 { 2398 return this; 2399 } 2400 2401 /** 2402 Computes the matches of the next length. Completes in $(BIGOH s.length * 2403 t.length) time. 2404 */ 2405 void popFront() 2406 { 2407 import std.algorithm.mutation : swap; 2408 2409 // This is a large source of optimization: if similarity at 2410 // the gram-1 level was 0, then we can safely assume 2411 // similarity at the gram level is 0 as well. 2412 if (empty) return; 2413 2414 // Now attempt to match gapped substrings of length `gram' 2415 ++gram; 2416 currentValue = 0; 2417 2418 auto Si = cast(F*) alloca(t.length * F.sizeof); 2419 Si[0 .. t.length] = 0; 2420 foreach (i; 0 ..
s.length) 2421 { 2422 const si = s[i]; 2423 F Sij_1 = 0; 2424 F Si_1j_1 = 0; 2425 auto kli = kl + i * t.length; 2426 for (size_t j = 0;;) 2427 { 2428 const klij = kli[j]; 2429 const Si_1j = Si[j]; 2430 const tmp = klij + lambda * (Si_1j + Sij_1) - lambda2 * Si_1j_1; 2431 // now update kl and currentValue 2432 if (si == t[j]) 2433 currentValue += kli[j] = lambda2 * Si_1j_1; 2434 else 2435 kli[j] = 0; 2436 // commit to Si 2437 Si[j] = tmp; 2438 if (++j == t.length) break; 2439 // get ready for the popFront step; virtually increment j, 2440 // so essentially stuffj_1 <-- stuffj 2441 Si_1j_1 = Si_1j; 2442 Sij_1 = tmp; 2443 } 2444 } 2445 currentValue /= pow(lambda, 2 * (gram + 1)); 2446 2447 version (none) 2448 { 2449 Si_1[0 .. t.length] = 0; 2450 kl[0 .. min(t.length, maxPerimeter + 1)] = 0; 2451 foreach (i; 1 .. min(s.length, maxPerimeter + 1)) 2452 { 2453 auto kli = kl + i * t.length; 2454 assert(s.length > i); 2455 const si = s[i]; 2456 auto kl_1i_1 = kl_1 + (i - 1) * t.length; 2457 kli[0] = 0; 2458 F lastS = 0; 2459 foreach (j; 1 .. min(maxPerimeter - i + 1, t.length)) 2460 { 2461 immutable j_1 = j - 1; 2462 immutable tmp = kl_1i_1[j_1] 2463 + lambda * (Si_1[j] + lastS) 2464 - lambda2 * Si_1[j_1]; 2465 kl_1i_1[j_1] = float.nan; 2466 Si_1[j_1] = lastS; 2467 lastS = tmp; 2468 if (si == t[j]) 2469 { 2470 currentValue += kli[j] = lambda2 * lastS; 2471 } 2472 else 2473 { 2474 kli[j] = 0; 2475 } 2476 } 2477 Si_1[t.length - 1] = lastS; 2478 } 2479 currentValue /= pow(lambda, 2 * (gram + 1)); 2480 // get ready for the popFront computation 2481 swap(kl, kl_1); 2482 } 2483 } 2484 2485 /** 2486 Returns: The gapped similarity at the current match length (initially 2487 1, grows with each call to $(D popFront)). 2488 */ 2489 @property F front() { return currentValue; } 2490 2491 /** 2492 Returns: Whether there are more matches. 
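Note: the first call that returns $(D true) also releases the object's
internal buffer.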
2493 */ 2494 @property bool empty() 2495 { 2496 if (currentValue) return false; 2497 if (kl) 2498 { 2499 free(kl); 2500 kl = null; 2501 } 2502 return true; 2503 } 2504 } 2505 2506 /** 2507 Ditto 2508 */ 2509 GapWeightedSimilarityIncremental!(R, F) gapWeightedSimilarityIncremental(R, F) 2510 (R r1, R r2, F penalty) 2511 { 2512 return typeof(return)(r1, r2, penalty); 2513 } 2514 2515 /// 2516 @system unittest 2517 { 2518 string[] s = ["Hello", "brave", "new", "world"]; 2519 string[] t = ["Hello", "new", "world"]; 2520 auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0); 2521 assert(simIter.front == 3); // three 1-length matches 2522 simIter.popFront(); 2523 assert(simIter.front == 3); // three 2-length matches 2524 simIter.popFront(); 2525 assert(simIter.front == 1); // one 3-length match 2526 simIter.popFront(); 2527 assert(simIter.empty); // no more match 2528 } 2529 2530 @system unittest 2531 { 2532 import std.conv : text; 2533 string[] s = ["Hello", "brave", "new", "world"]; 2534 string[] t = ["Hello", "new", "world"]; 2535 auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0); 2536 //foreach (e; simIter) writeln(e); 2537 assert(simIter.front == 3); // three 1-length matches 2538 simIter.popFront(); 2539 assert(simIter.front == 3, text(simIter.front)); // three 2-length matches 2540 simIter.popFront(); 2541 assert(simIter.front == 1); // one 3-length matches 2542 simIter.popFront(); 2543 assert(simIter.empty); // no more match 2544 2545 s = ["Hello"]; 2546 t = ["bye"]; 2547 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2548 assert(simIter.empty); 2549 2550 s = ["Hello"]; 2551 t = ["Hello"]; 2552 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2553 assert(simIter.front == 1); // one match 2554 simIter.popFront(); 2555 assert(simIter.empty); 2556 2557 s = ["Hello", "world"]; 2558 t = ["Hello"]; 2559 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2560 assert(simIter.front == 1); // one match 2561 simIter.popFront(); 2562 assert(simIter.empty); 2563 2564 s = ["Hello", "world"]; 2565 t = ["Hello", "yah", "world"]; 2566 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2567 assert(simIter.front == 2); // two 1-gram matches 2568 simIter.popFront(); 2569 assert(simIter.front == 0.5, text(simIter.front)); // one 2-gram match, 1 gap 2570 } 2571 2572 @system unittest 2573 { 2574 GapWeightedSimilarityIncremental!(string[]) sim = 2575 GapWeightedSimilarityIncremental!(string[])( 2576 ["nyuk", "I", "have", "no", "chocolate", "giba"], 2577 ["wyda", "I", "have", "I", "have", "have", "I", "have", "hehe"], 2578 0.5); 2579 double[] witness = [ 7.0, 4.03125, 0, 0 ]; 2580 foreach (e; sim) 2581 { 2582 //writeln(e); 2583 assert(e == witness.front); 2584 witness.popFront(); 2585 } 2586 witness = [ 3.0, 1.3125, 0.25 ]; 2587 sim = GapWeightedSimilarityIncremental!(string[])( 2588 ["I", "have", "no", "chocolate"], 2589 ["I", "have", "some", "chocolate"], 2590 0.5); 2591 foreach (e; sim) 2592 { 2593 //writeln(e); 2594 assert(e == witness.front); 2595 witness.popFront(); 2596 } 2597 assert(witness.empty); 2598 } 2599 2600 /** 2601 Computes the greatest common divisor of $(D a) and $(D b) by using 2602 an efficient algorithm such as $(HTTPS en.wikipedia.org/wiki/Euclidean_algorithm, Euclid's) 2603 or $(HTTPS en.wikipedia.org/wiki/Binary_GCD_algorithm, Stein's) algorithm. 2604 2605 Params: 2606 T = Any numerical type that supports the modulo operator `%`. 
If 2607 bit-shifting `<<` and `>>` are also supported, Stein's algorithm will 2608 be used; otherwise, Euclid's algorithm is used as _a fallback. 2609 Returns: 2610 The greatest common divisor of the given arguments. 2611 */ 2612 T gcd(T)(T a, T b) 2613 if (isIntegral!T) 2614 { 2615 static if (is(T == const) || is(T == immutable)) 2616 { 2617 return gcd!(Unqual!T)(a, b); 2618 } 2619 else version (DigitalMars) 2620 { 2621 static if (T.min < 0) 2622 { 2623 assert(a >= 0 && b >= 0); 2624 } 2625 while (b) 2626 { 2627 immutable t = b; 2628 b = a % b; 2629 a = t; 2630 } 2631 return a; 2632 } 2633 else 2634 { 2635 if (a == 0) 2636 return b; 2637 if (b == 0) 2638 return a; 2639 2640 import core.bitop : bsf; 2641 import std.algorithm.mutation : swap; 2642 2643 immutable uint shift = bsf(a | b); 2644 a >>= a.bsf; 2645 2646 do 2647 { 2648 b >>= b.bsf; 2649 if (a > b) 2650 swap(a, b); 2651 b -= a; 2652 } while (b); 2653 2654 return a << shift; 2655 } 2656 } 2657 2658 /// 2659 @safe unittest 2660 { 2661 assert(gcd(2 * 5 * 7 * 7, 5 * 7 * 11) == 5 * 7); 2662 const int a = 5 * 13 * 23 * 23, b = 13 * 59; 2663 assert(gcd(a, b) == 13); 2664 } 2665 2666 // This overload is for non-builtin numerical types like BigInt or 2667 // user-defined types. 2668 /// ditto 2669 T gcd(T)(T a, T b) 2670 if (!isIntegral!T && 2671 is(typeof(T.init % T.init)) && 2672 is(typeof(T.init == 0 || T.init > 0))) 2673 { 2674 import std.algorithm.mutation : swap; 2675 2676 enum canUseBinaryGcd = is(typeof(() { 2677 T t, u; 2678 t <<= 1; 2679 t >>= 1; 2680 t -= u; 2681 bool b = (t & 1) == 0; 2682 swap(t, u); 2683 })); 2684 2685 assert(a >= 0 && b >= 0); 2686 2687 static if (canUseBinaryGcd) 2688 { 2689 uint shift = 0; 2690 while ((a & 1) == 0 && (b & 1) == 0) 2691 { 2692 a >>= 1; 2693 b >>= 1; 2694 shift++; 2695 } 2696 2697 do 2698 { 2699 assert((a & 1) != 0); 2700 while ((b & 1) == 0) 2701 b >>= 1; 2702 if (a > b) 2703 swap(a, b); 2704 b -= a; 2705 } while (b); 2706 2707 return a << shift; 2708 } 2709 else 2710 { 2711 // The only thing we have is %; fallback to Euclidean algorithm. 2712 while (b != 0) 2713 { 2714 auto t = b; 2715 b = a % b; 2716 a = t; 2717 } 2718 return a; 2719 } 2720 } 2721 2722 // Issue 7102 2723 @system pure unittest 2724 { 2725 import std.bigint : BigInt; 2726 assert(gcd(BigInt("71_000_000_000_000_000_000"), 2727 BigInt("31_000_000_000_000_000_000")) == 2728 BigInt("1_000_000_000_000_000_000")); 2729 } 2730 2731 @safe pure nothrow unittest 2732 { 2733 // A numerical type that only supports % and - (to force gcd implementation 2734 // to use Euclidean algorithm). 2735 struct CrippledInt 2736 { 2737 int impl; 2738 CrippledInt opBinary(string op : "%")(CrippledInt i) 2739 { 2740 return CrippledInt(impl % i.impl); 2741 } 2742 int opEquals(CrippledInt i) { return impl == i.impl; } 2743 int opEquals(int i) { return impl == i; } 2744 int opCmp(int i) { return (impl < i) ? -1 : (impl > i) ? 1 : 0; } 2745 } 2746 assert(gcd(CrippledInt(2310), CrippledInt(1309)) == CrippledInt(77)); 2747 } 2748 2749 // This is to make tweaking the speed/size vs. accuracy tradeoff easy, 2750 // though floats seem accurate enough for all practical purposes, since 2751 // they pass the "approxEqual(inverseFft(fft(arr)), arr)" test even for 2752 // size 2 ^^ 22. 2753 private alias lookup_t = float; 2754 2755 /**A class for performing fast Fourier transforms of power of two sizes. 
2756 * This class encapsulates a large amount of state that is reusable when 2757 * performing multiple FFTs of sizes smaller than or equal to that specified 2758 * in the constructor. This results in substantial speedups when performing 2759 * multiple FFTs with a known maximum size. However, 2760 * a free function API is provided for convenience if you need to perform a 2761 * one-off FFT. 2762 * 2763 * References: 2764 * $(HTTP en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm) 2765 */ 2766 final class Fft 2767 { 2768 import core.bitop : bsf; 2769 import std.algorithm.iteration : map; 2770 import std.array : uninitializedArray; 2771 2772 private: 2773 immutable lookup_t[][] negSinLookup; 2774 2775 void enforceSize(R)(R range) const 2776 { 2777 import std.conv : text; 2778 assert(range.length <= size, text( 2779 "FFT size mismatch. Expected ", size, ", got ", range.length)); 2780 } 2781 2782 void fftImpl(Ret, R)(Stride!R range, Ret buf) const 2783 in 2784 { 2785 assert(range.length >= 4); 2786 assert(isPowerOf2(range.length)); 2787 } 2788 body 2789 { 2790 auto recurseRange = range; 2791 recurseRange.doubleSteps(); 2792 2793 if (buf.length > 4) 2794 { 2795 fftImpl(recurseRange, buf[0..$ / 2]); 2796 recurseRange.popHalf(); 2797 fftImpl(recurseRange, buf[$ / 2..$]); 2798 } 2799 else 2800 { 2801 // Do this here instead of in another recursion to save on 2802 // recursion overhead. 2803 slowFourier2(recurseRange, buf[0..$ / 2]); 2804 recurseRange.popHalf(); 2805 slowFourier2(recurseRange, buf[$ / 2..$]); 2806 } 2807 2808 butterfly(buf); 2809 } 2810 2811 // This algorithm works by performing the even and odd parts of our FFT 2812 // using the "two for the price of one" method mentioned at 2813 // http://www.engineeringproductivitytools.com/stuff/T0001/PT10.HTM#Head521 2814 // by making the odd terms into the imaginary components of our new FFT, 2815 // and then using symmetry to recombine them. 2816 void fftImplPureReal(Ret, R)(R range, Ret buf) const 2817 in 2818 { 2819 assert(range.length >= 4); 2820 assert(isPowerOf2(range.length)); 2821 } 2822 body 2823 { 2824 alias E = ElementType!R; 2825 2826 // Converts odd indices of range to the imaginary components of 2827 // a range half the size. The even indices become the real components. 2828 static if (isArray!R && isFloatingPoint!E) 2829 { 2830 // Then the memory layout of complex numbers provides a dirt 2831 // cheap way to convert. This is a common case, so take advantage. 2832 auto oddsImag = cast(Complex!E[]) range; 2833 } 2834 else 2835 { 2836 // General case: Use a higher order range. We can assume 2837 // source.length is even because it has to be a power of 2. 
2838 static struct OddToImaginary 2839 { 2840 R source; 2841 alias C = Complex!(CommonType!(E, typeof(buf[0].re))); 2842 2843 @property 2844 { 2845 C front() 2846 { 2847 return C(source[0], source[1]); 2848 } 2849 2850 C back() 2851 { 2852 immutable n = source.length; 2853 return C(source[n - 2], source[n - 1]); 2854 } 2855 2856 typeof(this) save() 2857 { 2858 return typeof(this)(source.save); 2859 } 2860 2861 bool empty() 2862 { 2863 return source.empty; 2864 } 2865 2866 size_t length() 2867 { 2868 return source.length / 2; 2869 } 2870 } 2871 2872 void popFront() 2873 { 2874 source.popFront(); 2875 source.popFront(); 2876 } 2877 2878 void popBack() 2879 { 2880 source.popBack(); 2881 source.popBack(); 2882 } 2883 2884 C opIndex(size_t index) 2885 { 2886 return C(source[index * 2], source[index * 2 + 1]); 2887 } 2888 2889 typeof(this) opSlice(size_t lower, size_t upper) 2890 { 2891 return typeof(this)(source[lower * 2 .. upper * 2]); 2892 } 2893 } 2894 2895 auto oddsImag = OddToImaginary(range); 2896 } 2897 2898 fft(oddsImag, buf[0..$ / 2]); 2899 auto evenFft = buf[0..$ / 2]; 2900 auto oddFft = buf[$ / 2..$]; 2901 immutable halfN = evenFft.length; 2902 oddFft[0].re = buf[0].im; 2903 oddFft[0].im = 0; 2904 evenFft[0].im = 0; 2905 // evenFft[0].re is already right b/c it's aliased with buf[0].re. 2906 2907 foreach (k; 1 .. halfN / 2 + 1) 2908 { 2909 immutable bufk = buf[k]; 2910 immutable bufnk = buf[buf.length / 2 - k]; 2911 evenFft[k].re = 0.5 * (bufk.re + bufnk.re); 2912 evenFft[halfN - k].re = evenFft[k].re; 2913 evenFft[k].im = 0.5 * (bufk.im - bufnk.im); 2914 evenFft[halfN - k].im = -evenFft[k].im; 2915 2916 oddFft[k].re = 0.5 * (bufk.im + bufnk.im); 2917 oddFft[halfN - k].re = oddFft[k].re; 2918 oddFft[k].im = 0.5 * (bufnk.re - bufk.re); 2919 oddFft[halfN - k].im = -oddFft[k].im; 2920 } 2921 2922 butterfly(buf); 2923 } 2924 2925 void butterfly(R)(R buf) const 2926 in 2927 { 2928 assert(isPowerOf2(buf.length)); 2929 } 2930 body 2931 { 2932 immutable n = buf.length; 2933 immutable localLookup = negSinLookup[bsf(n)]; 2934 assert(localLookup.length == n); 2935 2936 immutable cosMask = n - 1; 2937 immutable cosAdd = n / 4 * 3; 2938 2939 lookup_t negSinFromLookup(size_t index) pure nothrow 2940 { 2941 return localLookup[index]; 2942 } 2943 2944 lookup_t cosFromLookup(size_t index) pure nothrow 2945 { 2946 // cos is just -sin shifted by PI * 3 / 2. 2947 return localLookup[(index + cosAdd) & cosMask]; 2948 } 2949 2950 immutable halfLen = n / 2; 2951 2952 // This loop is unrolled and the two iterations are interleaved 2953 // relative to the textbook FFT to increase ILP. This gives roughly 5% 2954 // speedups on DMD. 
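        // In complex terms, each butterfly below computes
        //     buf[k]           = lower + w * upper
        //     buf[k + halfLen] = lower - w * upper
        // where w = cos(2 * PI * k / n) - i * sin(2 * PI * k / n) is the
        // twiddle factor, with cos and -sin read from the lookup table.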
2955 for (size_t k = 0; k < halfLen; k += 2) 2956 { 2957 immutable cosTwiddle1 = cosFromLookup(k); 2958 immutable sinTwiddle1 = negSinFromLookup(k); 2959 immutable cosTwiddle2 = cosFromLookup(k + 1); 2960 immutable sinTwiddle2 = negSinFromLookup(k + 1); 2961 2962 immutable realLower1 = buf[k].re; 2963 immutable imagLower1 = buf[k].im; 2964 immutable realLower2 = buf[k + 1].re; 2965 immutable imagLower2 = buf[k + 1].im; 2966 2967 immutable upperIndex1 = k + halfLen; 2968 immutable upperIndex2 = upperIndex1 + 1; 2969 immutable realUpper1 = buf[upperIndex1].re; 2970 immutable imagUpper1 = buf[upperIndex1].im; 2971 immutable realUpper2 = buf[upperIndex2].re; 2972 immutable imagUpper2 = buf[upperIndex2].im; 2973 2974 immutable realAdd1 = cosTwiddle1 * realUpper1 2975 - sinTwiddle1 * imagUpper1; 2976 immutable imagAdd1 = sinTwiddle1 * realUpper1 2977 + cosTwiddle1 * imagUpper1; 2978 immutable realAdd2 = cosTwiddle2 * realUpper2 2979 - sinTwiddle2 * imagUpper2; 2980 immutable imagAdd2 = sinTwiddle2 * realUpper2 2981 + cosTwiddle2 * imagUpper2; 2982 2983 buf[k].re += realAdd1; 2984 buf[k].im += imagAdd1; 2985 buf[k + 1].re += realAdd2; 2986 buf[k + 1].im += imagAdd2; 2987 2988 buf[upperIndex1].re = realLower1 - realAdd1; 2989 buf[upperIndex1].im = imagLower1 - imagAdd1; 2990 buf[upperIndex2].re = realLower2 - realAdd2; 2991 buf[upperIndex2].im = imagLower2 - imagAdd2; 2992 } 2993 } 2994 2995 // This constructor is used within this module for allocating the 2996 // buffer space elsewhere besides the GC heap. It's definitely **NOT** 2997 // part of the public API and definitely **IS** subject to change. 2998 // 2999 // Also, this is unsafe because the memSpace buffer will be cast 3000 // to immutable. 3001 public this(lookup_t[] memSpace) // Public b/c of bug 4636. 3002 { 3003 immutable size = memSpace.length / 2; 3004 3005 /* Create a lookup table of all negative sine values at a resolution of 3006 * size and all smaller power of two resolutions. This may seem 3007 * inefficient, but having all the lookups be next to each other in 3008 * memory at every level of iteration is a huge win performance-wise. 3009 */ 3010 if (size == 0) 3011 { 3012 return; 3013 } 3014 3015 assert(isPowerOf2(size), 3016 "Can only do FFTs on ranges with a size that is a power of two."); 3017 3018 auto table = new lookup_t[][bsf(size) + 1]; 3019 3020 table[$ - 1] = memSpace[$ - size..$]; 3021 memSpace = memSpace[0 .. size]; 3022 3023 auto lastRow = table[$ - 1]; 3024 lastRow[0] = 0; // -sin(0) == 0. 3025 foreach (ptrdiff_t i; 1 .. size) 3026 { 3027 // The hard coded cases are for improved accuracy and to prevent 3028 // annoying non-zeroness when stuff should be zero. 3029 3030 if (i == size / 4) 3031 lastRow[i] = -1; // -sin(pi / 2) == -1. 3032 else if (i == size / 2) 3033 lastRow[i] = 0; // -sin(pi) == 0. 3034 else if (i == size * 3 / 4) 3035 lastRow[i] = 1; // -sin(pi * 3 / 2) == 1 3036 else 3037 lastRow[i] = -sin(i * 2.0L * PI / size); 3038 } 3039 3040 // Fill in all the other rows with strided versions. 3041 foreach (i; 1 .. 
table.length - 1) 3042 { 3043 immutable strideLength = size / (2 ^^ i); 3044 auto strided = Stride!(lookup_t[])(lastRow, strideLength); 3045 table[i] = memSpace[$ - strided.length..$]; 3046 memSpace = memSpace[0..$ - strided.length]; 3047 3048 size_t copyIndex; 3049 foreach (elem; strided) 3050 { 3051 table[i][copyIndex++] = elem; 3052 } 3053 } 3054 3055 negSinLookup = cast(immutable) table; 3056 } 3057 3058 public: 3059 /**Create an $(D Fft) object for computing fast Fourier transforms of 3060 * power of two sizes of $(D size) or smaller. $(D size) must be a 3061 * power of two. 3062 */ 3063 this(size_t size) 3064 { 3065 // Allocate all twiddle factor buffers in one contiguous block so that, 3066 // when one is done being used, the next one is next in cache. 3067 auto memSpace = uninitializedArray!(lookup_t[])(2 * size); 3068 this(memSpace); 3069 } 3070 3071 @property size_t size() const 3072 { 3073 return (negSinLookup is null) ? 0 : negSinLookup[$ - 1].length; 3074 } 3075 3076 /**Compute the Fourier transform of range using the $(BIGOH N log N) 3077 * Cooley-Tukey Algorithm. $(D range) must be a random-access range with 3078 * slicing and a length equal to $(D size) as provided at the construction of 3079 * this object. The contents of range can be either numeric types, 3080 * which will be interpreted as pure real values, or complex types with 3081 * properties or members $(D .re) and $(D .im) that can be read. 3082 * 3083 * Note: Pure real FFTs are automatically detected and the relevant 3084 * optimizations are performed. 3085 * 3086 * Returns: An array of complex numbers representing the transformed data in 3087 * the frequency domain. 3088 * 3089 * Conventions: The exponent is negative and the factor is one, 3090 * i.e., output[j] := sum[ exp(-2 PI i j k / N) input[k] ]. 3091 */ 3092 Complex!F[] fft(F = double, R)(R range) const 3093 if (isFloatingPoint!F && isRandomAccessRange!R) 3094 { 3095 enforceSize(range); 3096 Complex!F[] ret; 3097 if (range.length == 0) 3098 { 3099 return ret; 3100 } 3101 3102 // Don't waste time initializing the memory for ret. 3103 ret = uninitializedArray!(Complex!F[])(range.length); 3104 3105 fft(range, ret); 3106 return ret; 3107 } 3108 3109 /**Same as the overload, but allows for the results to be stored in a user- 3110 * provided buffer. The buffer must be of the same length as range, must be 3111 * a random-access range, must have slicing, and must contain elements that are 3112 * complex-like. This means that they must have a .re and a .im member or 3113 * property that can be both read and written and are floating point numbers. 3114 */ 3115 void fft(Ret, R)(R range, Ret buf) const 3116 if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret) 3117 { 3118 assert(buf.length == range.length); 3119 enforceSize(range); 3120 3121 if (range.length == 0) 3122 { 3123 return; 3124 } 3125 else if (range.length == 1) 3126 { 3127 buf[0] = range[0]; 3128 return; 3129 } 3130 else if (range.length == 2) 3131 { 3132 slowFourier2(range, buf); 3133 return; 3134 } 3135 else 3136 { 3137 alias E = ElementType!R; 3138 static if (is(E : real)) 3139 { 3140 return fftImplPureReal(range, buf); 3141 } 3142 else 3143 { 3144 static if (is(R : Stride!R)) 3145 return fftImpl(range, buf); 3146 else 3147 return fftImpl(Stride!R(range, 1), buf); 3148 } 3149 } 3150 } 3151 3152 /** 3153 * Computes the inverse Fourier transform of a range. 
The range must be a 3154 * random access range with slicing, have a length equal to the size 3155 * provided at construction of this object, and contain elements that are 3156 * either of type std.complex.Complex or have essentially 3157 * the same compile-time interface. 3158 * 3159 * Returns: The time-domain signal. 3160 * 3161 * Conventions: The exponent is positive and the factor is 1/N, i.e., 3162 * output[j] := (1 / N) sum[ exp(+2 PI i j k / N) input[k] ]. 3163 */ 3164 Complex!F[] inverseFft(F = double, R)(R range) const 3165 if (isRandomAccessRange!R && isComplexLike!(ElementType!R) && isFloatingPoint!F) 3166 { 3167 enforceSize(range); 3168 Complex!F[] ret; 3169 if (range.length == 0) 3170 { 3171 return ret; 3172 } 3173 3174 // Don't waste time initializing the memory for ret. 3175 ret = uninitializedArray!(Complex!F[])(range.length); 3176 3177 inverseFft(range, ret); 3178 return ret; 3179 } 3180 3181 /** 3182 * Inverse FFT that allows a user-supplied buffer to be provided. The buffer 3183 * must be a random access range with slicing, and its elements 3184 * must be some complex-like type. 3185 */ 3186 void inverseFft(Ret, R)(R range, Ret buf) const 3187 if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret) 3188 { 3189 enforceSize(range); 3190 3191 auto swapped = map!swapRealImag(range); 3192 fft(swapped, buf); 3193 3194 immutable lenNeg1 = 1.0 / buf.length; 3195 foreach (ref elem; buf) 3196 { 3197 immutable temp = elem.re * lenNeg1; 3198 elem.re = elem.im * lenNeg1; 3199 elem.im = temp; 3200 } 3201 } 3202 } 3203 3204 // This mixin creates an Fft object in the scope it's mixed into such that all 3205 // memory owned by the object is deterministically destroyed at the end of that 3206 // scope. 3207 private enum string MakeLocalFft = q{ 3208 import core.stdc.stdlib; 3209 import core.exception : onOutOfMemoryError; 3210 3211 auto lookupBuf = (cast(lookup_t*) malloc(range.length * 2 * lookup_t.sizeof)) 3212 [0 .. 2 * range.length]; 3213 if (!lookupBuf.ptr) 3214 onOutOfMemoryError(); 3215 3216 scope(exit) free(cast(void*) lookupBuf.ptr); 3217 auto fftObj = scoped!Fft(lookupBuf); 3218 }; 3219 3220 /**Convenience functions that create an $(D Fft) object, run the FFT or inverse 3221 * FFT and return the result. Useful for one-off FFTs. 3222 * 3223 * Note: In addition to convenience, these functions are slightly more 3224 * efficient than manually creating an Fft object for a single use, 3225 * as the Fft object is deterministically destroyed before these 3226 * functions return. 3227 */ 3228 Complex!F[] fft(F = double, R)(R range) 3229 { 3230 mixin(MakeLocalFft); 3231 return fftObj.fft!(F, R)(range); 3232 } 3233 3234 /// ditto 3235 void fft(Ret, R)(R range, Ret buf) 3236 { 3237 mixin(MakeLocalFft); 3238 return fftObj.fft!(Ret, R)(range, buf); 3239 } 3240 3241 /// ditto 3242 Complex!F[] inverseFft(F = double, R)(R range) 3243 { 3244 mixin(MakeLocalFft); 3245 return fftObj.inverseFft!(F, R)(range); 3246 } 3247 3248 /// ditto 3249 void inverseFft(Ret, R)(R range, Ret buf) 3250 { 3251 mixin(MakeLocalFft); 3252 return fftObj.inverseFft!(Ret, R)(range, buf); 3253 } 3254 3255 @system unittest 3256 { 3257 import std.algorithm; 3258 import std.conv; 3259 import std.range; 3260 // Test values from R and Octave. 
3261 auto arr = [1,2,3,4,5,6,7,8]; 3262 auto fft1 = fft(arr); 3263 assert(approxEqual(map!"a.re"(fft1), 3264 [36.0, -4, -4, -4, -4, -4, -4, -4])); 3265 assert(approxEqual(map!"a.im"(fft1), 3266 [0, 9.6568, 4, 1.6568, 0, -1.6568, -4, -9.6568])); 3267 3268 auto fft1Retro = fft(retro(arr)); 3269 assert(approxEqual(map!"a.re"(fft1Retro), 3270 [36.0, 4, 4, 4, 4, 4, 4, 4])); 3271 assert(approxEqual(map!"a.im"(fft1Retro), 3272 [0, -9.6568, -4, -1.6568, 0, 1.6568, 4, 9.6568])); 3273 3274 auto fft1Float = fft(to!(float[])(arr)); 3275 assert(approxEqual(map!"a.re"(fft1), map!"a.re"(fft1Float))); 3276 assert(approxEqual(map!"a.im"(fft1), map!"a.im"(fft1Float))); 3277 3278 alias C = Complex!float; 3279 auto arr2 = [C(1,2), C(3,4), C(5,6), C(7,8), C(9,10), 3280 C(11,12), C(13,14), C(15,16)]; 3281 auto fft2 = fft(arr2); 3282 assert(approxEqual(map!"a.re"(fft2), 3283 [64.0, -27.3137, -16, -11.3137, -8, -4.6862, 0, 11.3137])); 3284 assert(approxEqual(map!"a.im"(fft2), 3285 [72, 11.3137, 0, -4.686, -8, -11.3137, -16, -27.3137])); 3286 3287 auto inv1 = inverseFft(fft1); 3288 assert(approxEqual(map!"a.re"(inv1), arr)); 3289 assert(reduce!max(map!"a.im"(inv1)) < 1e-10); 3290 3291 auto inv2 = inverseFft(fft2); 3292 assert(approxEqual(map!"a.re"(inv2), map!"a.re"(arr2))); 3293 assert(approxEqual(map!"a.im"(inv2), map!"a.im"(arr2))); 3294 3295 // FFTs of size 0, 1 and 2 are handled as special cases. Test them here. 3296 ushort[] empty; 3297 assert(fft(empty) == null); 3298 assert(inverseFft(fft(empty)) == null); 3299 3300 real[] oneElem = [4.5L]; 3301 auto oneFft = fft(oneElem); 3302 assert(oneFft.length == 1); 3303 assert(oneFft[0].re == 4.5L); 3304 assert(oneFft[0].im == 0); 3305 3306 auto oneInv = inverseFft(oneFft); 3307 assert(oneInv.length == 1); 3308 assert(approxEqual(oneInv[0].re, 4.5)); 3309 assert(approxEqual(oneInv[0].im, 0)); 3310 3311 long[2] twoElems = [8, 4]; 3312 auto twoFft = fft(twoElems[]); 3313 assert(twoFft.length == 2); 3314 assert(approxEqual(twoFft[0].re, 12)); 3315 assert(approxEqual(twoFft[0].im, 0)); 3316 assert(approxEqual(twoFft[1].re, 4)); 3317 assert(approxEqual(twoFft[1].im, 0)); 3318 auto twoInv = inverseFft(twoFft); 3319 assert(approxEqual(twoInv[0].re, 8)); 3320 assert(approxEqual(twoInv[0].im, 0)); 3321 assert(approxEqual(twoInv[1].re, 4)); 3322 assert(approxEqual(twoInv[1].im, 0)); 3323 } 3324 3325 // Swaps the real and imaginary parts of a complex number. This is useful 3326 // for inverse FFTs. 3327 C swapRealImag(C)(C input) 3328 { 3329 return C(input.im, input.re); 3330 } 3331 3332 private: 3333 // The reasons I couldn't use std.algorithm were b/c its stride length isn't 3334 // modifiable on the fly and because range has grown some performance hacks 3335 // for powers of 2. 3336 struct Stride(R) 3337 { 3338 import core.bitop : bsf; 3339 Unqual!R range; 3340 size_t _nSteps; 3341 size_t _length; 3342 alias E = ElementType!(R); 3343 3344 this(R range, size_t nStepsIn) 3345 { 3346 this.range = range; 3347 _nSteps = nStepsIn; 3348 _length = (range.length + _nSteps - 1) / nSteps; 3349 } 3350 3351 size_t length() const @property 3352 { 3353 return _length; 3354 } 3355 3356 typeof(this) save() @property 3357 { 3358 auto ret = this; 3359 ret.range = ret.range.save; 3360 return ret; 3361 } 3362 3363 E opIndex(size_t index) 3364 { 3365 return range[index * _nSteps]; 3366 } 3367 3368 E front() @property 3369 { 3370 return range[0]; 3371 } 3372 3373 void popFront() 3374 { 3375 if (range.length >= _nSteps) 3376 { 3377 range = range[_nSteps .. 
range.length]; 3378 _length--; 3379 } 3380 else 3381 { 3382 range = range[0 .. 0]; 3383 _length = 0; 3384 } 3385 } 3386 3387 // Pops half the range's stride. 3388 void popHalf() 3389 { 3390 range = range[_nSteps / 2 .. range.length]; 3391 } 3392 3393 bool empty() const @property 3394 { 3395 return length == 0; 3396 } 3397 3398 size_t nSteps() const @property 3399 { 3400 return _nSteps; 3401 } 3402 3403 void doubleSteps() 3404 { 3405 _nSteps *= 2; 3406 _length /= 2; 3407 } 3408 3409 size_t nSteps(size_t newVal) @property 3410 { 3411 _nSteps = newVal; 3412 3413 // Using >> bsf(nSteps) is a few cycles faster than / nSteps. 3414 _length = (range.length + _nSteps - 1) >> bsf(nSteps); 3415 return newVal; 3416 } 3417 } 3418 3419 // Hard-coded base case for FFT of size 2. This is actually a TON faster than 3420 // using a generic slow DFT. This seems to be the best base case. (Size 1 3421 // can be coded inline as buf[0] = range[0]). 3422 void slowFourier2(Ret, R)(R range, Ret buf) 3423 { 3424 assert(range.length == 2); 3425 assert(buf.length == 2); 3426 buf[0] = range[0] + range[1]; 3427 buf[1] = range[0] - range[1]; 3428 } 3429 3430 // Hard-coded base case for FFT of size 4. Doesn't work as well as the size 3431 // 2 case. 3432 void slowFourier4(Ret, R)(R range, Ret buf) 3433 { 3434 alias C = ElementType!Ret; 3435 3436 assert(range.length == 4); 3437 assert(buf.length == 4); 3438 buf[0] = range[0] + range[1] + range[2] + range[3]; 3439 buf[1] = range[0] - range[1] * C(0, 1) - range[2] + range[3] * C(0, 1); 3440 buf[2] = range[0] - range[1] + range[2] - range[3]; 3441 buf[3] = range[0] + range[1] * C(0, 1) - range[2] - range[3] * C(0, 1); 3442 } 3443 3444 N roundDownToPowerOf2(N)(N num) 3445 if (isScalarType!N && !isFloatingPoint!N) 3446 { 3447 import core.bitop : bsr; 3448 return num & (cast(N) 1 << bsr(num)); 3449 } 3450 3451 @safe unittest 3452 { 3453 assert(roundDownToPowerOf2(7) == 4); 3454 assert(roundDownToPowerOf2(4) == 4); 3455 } 3456 3457 template isComplexLike(T) 3458 { 3459 enum bool isComplexLike = is(typeof(T.init.re)) && 3460 is(typeof(T.init.im)); 3461 } 3462 3463 @safe unittest 3464 { 3465 static assert(isComplexLike!(Complex!double)); 3466 static assert(!isComplexLike!(uint)); 3467 } 3468
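// A small consistency check (an illustrative sketch, not exhaustive): the
// hard-coded size-4 base case should agree with the generic FFT on a
// real-valued input of length 4.
@system unittest
{
    auto data = [ 1.0, 2.0, 3.0, 4.0 ];
    auto buf = new Complex!double[4];
    slowFourier4(data, buf);
    auto reference = fft(data);
    foreach (i; 0 .. 4)
    {
        assert(approxEqual(buf[i].re, reference[i].re));
        assert(approxEqual(buf[i].im, reference[i].im));
    }
}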