From 41e696291c64fe19629e14887ed1ed9b9c2271f0 Mon Sep 17 00:00:00 2001 From: LRFLEW Date: Fri, 19 Apr 2024 11:58:18 -0500 Subject: [PATCH] linear_congruential_engine: add using more precision to prevent overflow (#81583) This PR is a follow-up to #81080. This PR makes two major changes to how the LCG operation is computed: The first is that I added an additional case where `ax + c` might overflow the intermediate variable, but `ax` by itself won't. In this case, it's much better to use `((ax mod m) + c) mod m` than the previous behavior of falling back to Schrage's algorithm. The modular addition is done in the same way as when using Schrage's algorithm (i.e. `x += c - (x >= m - c)*m`), but the modular multiplication is calculated directly, which is faster. The second is that I added handling for the case where the `ax` intermediate might overflow, but Schrage's algorithm doesn't apply (i.e. `r > q`, where `q = m / a` and `r = m % a`). In this case, the only real option is to increase the precision of the intermediate values. The good news is that - for `x`, `a`, and `c` being n-bit values - `ax + c` will never overflow a 2n-bit intermediary, meaning this promotion can only happen once, and will always be able to use the simplest implementation. This is already the case for 16-bit LCGs, as libcxx chooses to compute them with 32-bit intermediate values. For 32-bit LCGs, I simply added code similar to the 16-bit case to use the existing 64-bit implementations. Lastly, for 64-bit LCGs, I wrote a case that calculates it using `unsigned __int128` if it is available to use. While this implementation covers a *lot* of the missing cases from #81080, this still won't compile **every** possible `linear_congruential_engine`. Specifically, if `a`, `c`, and `m` are chosen such that the computation needs 128-bit integers, but the platform doesn't support `__int128` (e.g. 32-bit x86), then it will fail to compile. 
However, this is a fairly rare case to see actually used, and libcxx would be in good company with this, as [libstdc++ also fails to compile under these circumstances](https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87744). Fixing **this** gap would require even **more** work of further complexity, so that would probably be best handled by a different PR (I'll put more details on what that PR would entail in a comment). --- .../__random/linear_congruential_engine.h | 113 ++++++++++++++---- .../rand/rand.eng/rand.eng.lcong/alg.pass.cpp | 65 +++++++--- .../rand.eng/rand.eng.lcong/assign.pass.cpp | 18 ++- .../rand.eng/rand.eng.lcong/copy.pass.cpp | 18 ++- .../rand.eng/rand.eng.lcong/default.pass.cpp | 18 ++- .../rand.eng/rand.eng.lcong/values.pass.cpp | 18 ++- 6 files changed, 192 insertions(+), 58 deletions(-) diff --git a/libcxx/include/__random/linear_congruential_engine.h b/libcxx/include/__random/linear_congruential_engine.h index fe9cb909b74d..9d77649e9cfc 100644 --- a/libcxx/include/__random/linear_congruential_engine.h +++ b/libcxx/include/__random/linear_congruential_engine.h @@ -26,32 +26,60 @@ _LIBCPP_PUSH_MACROS _LIBCPP_BEGIN_NAMESPACE_STD -template (_Mp - __c) / __a), - bool _OverflowOK = ((__m & (__m - 1)) == 0ull), // m = 2^n - bool _SchrageOK = (__a != 0 && __m != 0 && __m % __a <= __m / __a)> // r <= q -struct __lce_alg_picker { - static_assert(!_MightOverflow || _OverflowOK || _SchrageOK, - "The current values of a, c, and m cannot generate a number " - "within bounds of linear_congruential_engine."); - - static _LIBCPP_CONSTEXPR const bool __use_schrage = _MightOverflow && !_OverflowOK && _SchrageOK; +enum __lce_alg_type { + _LCE_Full, + _LCE_Part, + _LCE_Schrage, + _LCE_Promote, }; template ::__use_schrage> + bool _HasOverflow = (__a != 0ull && (__m & (__m - 1ull)) != 0ull), // a != 0, m != 0, m != 2^n + bool _Full = (!_HasOverflow || __m - 1ull <= (_Mp - __c) / __a), // (a * x + c) % m works + bool _Part = (!_HasOverflow || __m - 1ull <= _Mp / __a), 
// (a * x) % m works + bool _Schrage = (_HasOverflow && __m % __a <= __m / __a)> // r <= q +struct __lce_alg_picker { + static _LIBCPP_CONSTEXPR const __lce_alg_type __mode = + _Full ? _LCE_Full + : _Part ? _LCE_Part + : _Schrage ? _LCE_Schrage + : _LCE_Promote; + +#ifdef _LIBCPP_HAS_NO_INT128 + static_assert(_Mp != (unsigned long long)(-1) || _Full || _Part || _Schrage, + "The current values for a, c, and m are not currently supported on platforms without __int128"); +#endif +}; + +template ::__mode> struct __lce_ta; // 64 +#ifndef _LIBCPP_HAS_NO_INT128 +template +struct __lce_ta<_Ap, _Cp, _Mp, (unsigned long long)(-1), _LCE_Promote> { + typedef unsigned long long result_type; + _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __xp) { + __extension__ using __calc_type = unsigned __int128; + const __calc_type __a = static_cast<__calc_type>(_Ap); + const __calc_type __c = static_cast<__calc_type>(_Cp); + const __calc_type __m = static_cast<__calc_type>(_Mp); + const __calc_type __x = static_cast<__calc_type>(__xp); + return static_cast((__a * __x + __c) % __m); + } +}; +#endif + template -struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> { +struct __lce_ta<__a, __c, __m, (unsigned long long)(-1), _LCE_Schrage> { typedef unsigned long long result_type; _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { // Schrage's algorithm @@ -66,7 +94,7 @@ struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> { }; template -struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> { +struct __lce_ta<__a, 0ull, __m, (unsigned long long)(-1), _LCE_Schrage> { typedef unsigned long long result_type; _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { // Schrage's algorithm @@ -80,21 +108,40 @@ struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> { }; template -struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), false> { +struct __lce_ta<__a, __c, __m, (unsigned long long)(-1), _LCE_Part> { + typedef unsigned 
long long result_type; + _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { + // Use (((a*x) % m) + c) % m + __x = (__a * __x) % __m; + __x += __c - (__x >= __m - __c) * __m; + return __x; + } +}; + +template +struct __lce_ta<__a, __c, __m, (unsigned long long)(-1), _LCE_Full> { typedef unsigned long long result_type; _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return (__a * __x + __c) % __m; } }; template -struct __lce_ta<__a, __c, 0, (unsigned long long)(~0), false> { +struct __lce_ta<__a, __c, 0ull, (unsigned long long)(-1), _LCE_Full> { typedef unsigned long long result_type; _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return __a * __x + __c; } }; // 32 +template +struct __lce_ta<__a, __c, __m, unsigned(-1), _LCE_Promote> { + typedef unsigned result_type; + _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { + return static_cast(__lce_ta<__a, __c, __m, (unsigned long long)(-1)>::next(__x)); + } +}; + template -struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> { +struct __lce_ta<_Ap, _Cp, _Mp, unsigned(-1), _LCE_Schrage> { typedef unsigned result_type; _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { const result_type __a = static_cast(_Ap); @@ -112,7 +159,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> { }; template -struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> { +struct __lce_ta<_Ap, 0ull, _Mp, unsigned(-1), _LCE_Schrage> { typedef unsigned result_type; _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { const result_type __a = static_cast(_Ap); @@ -128,7 +175,21 @@ struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> { }; template -struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> { +struct __lce_ta<_Ap, _Cp, _Mp, unsigned(-1), _LCE_Part> { + typedef unsigned result_type; + _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { + const result_type __a = static_cast(_Ap); + const result_type __c = static_cast(_Cp); + const result_type __m = 
static_cast(_Mp); + // Use (((a*x) % m) + c) % m + __x = (__a * __x) % __m; + __x += __c - (__x >= __m - __c) * __m; + return __x; + } +}; + +template +struct __lce_ta<_Ap, _Cp, _Mp, unsigned(-1), _LCE_Full> { typedef unsigned result_type; _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { const result_type __a = static_cast(_Ap); @@ -139,7 +200,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> { }; template -struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> { +struct __lce_ta<_Ap, _Cp, 0ull, unsigned(-1), _LCE_Full> { typedef unsigned result_type; _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { const result_type __a = static_cast(_Ap); @@ -150,11 +211,11 @@ struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> { // 16 -template -struct __lce_ta<__a, __c, __m, (unsigned short)(~0), __b> { +template +struct __lce_ta<__a, __c, __m, (unsigned short)(-1), __mode> { typedef unsigned short result_type; _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { - return static_cast(__lce_ta<__a, __c, __m, unsigned(~0)>::next(__x)); + return static_cast(__lce_ta<__a, __c, __m, unsigned(-1)>::next(__x)); } }; @@ -178,7 +239,7 @@ public: private: result_type __x_; - static _LIBCPP_CONSTEXPR const result_type _Mp = result_type(~0); + static _LIBCPP_CONSTEXPR const result_type _Mp = result_type(-1); static_assert(__m == 0 || __a < __m, "linear_congruential_engine invalid parameters"); static_assert(__m == 0 || __c < __m, "linear_congruential_engine invalid parameters"); diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp index 8a9cae0e610c..159cb19f6546 100644 --- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp +++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp @@ -38,12 +38,12 @@ int main(int, char**) // m might overflow. 
The overflow is not OK and result will be in bounds // so we should use Schrage's algorithm - typedef std::linear_congruential_engine E2; + typedef std::linear_congruential_engine E2; E2 e2; // make sure Schrage's algorithm is used (it would be 0s after the first otherwise) assert(e2() == (1ull << 32)); assert(e2() == (1ull << 63) - 1ull); - assert(e2() == (1ull << 63) - (1ull << 33) + 1ull); + assert(e2() == (1ull << 63) - 0x1ffffffffull); // make sure result is in bounds assert(e2() < (1ull << 63) + 1); assert(e2() < (1ull << 63) + 1); @@ -56,9 +56,9 @@ int main(int, char**) typedef std::linear_congruential_engine E3; E3 e3; // make sure Schrage's algorithm is used - assert(e3() == 402727752ull); - assert(e3() == 162159612030764687ull); - assert(e3() == 108176466184989142ull); + assert(e3() == 0x18012348ull); + assert(e3() == 0x2401b4ed802468full); + assert(e3() == 0x18051ec400369d6ull); // make sure result is in bounds assert(e3() < (3ull << 56)); assert(e3() < (3ull << 56)); @@ -66,19 +66,52 @@ int main(int, char**) assert(e3() < (3ull << 56)); assert(e3() < (3ull << 56)); + // 32-bit case: + // m might overflow. The overflow is not OK, result will be in bounds, + // and Schrage's algorithm is incompatible here. Need to use 64 bit arithmetic. + typedef std::linear_congruential_engine E4; + E4 e4; + // make sure enough precision is used + assert(e4() == 0x10009u); + assert(e4() == 0x120053u); + assert(e4() == 0xf5030fu); + // make sure result is in bounds + assert(e4() < 0x7fffffffu); + assert(e4() < 0x7fffffffu); + assert(e4() < 0x7fffffffu); + assert(e4() < 0x7fffffffu); + assert(e4() < 0x7fffffffu); + +#ifndef _LIBCPP_HAS_NO_INT128 + // m might overflow. The overflow is not OK, result will be in bounds, + // and Schrage's algorithm is incompatible here. Need to use 128 bit arithmetic. 
+ typedef std::linear_congruential_engine E5; + E5 e5; + // make sure enough precision is used + assert(e5() == 0x100000001ull); + assert(e5() == 0x200000009ull); + assert(e5() == 0xb00000019ull); + // make sure result is in bounds + assert(e5() < (1ull << 61) - 1ull); + assert(e5() < (1ull << 61) - 1ull); + assert(e5() < (1ull << 61) - 1ull); + assert(e5() < (1ull << 61) - 1ull); + assert(e5() < (1ull << 61) - 1ull); +#endif + // m will not overflow so we should not use Schrage's algorithm - typedef std::linear_congruential_engine E4; - E4 e4; + typedef std::linear_congruential_engine E6; + E6 e6; // make sure the correct algorithm was used - assert(e4() == 2ull); - assert(e4() == 3ull); - assert(e4() == 4ull); + assert(e6() == 2ull); + assert(e6() == 3ull); + assert(e6() == 4ull); // make sure result is in bounds - assert(e4() < (1ull << 48)); - assert(e4() < (1ull << 48)); - assert(e4() < (1ull << 48)); - assert(e4() < (1ull << 48)); - assert(e4() < (1ull << 48)); + assert(e6() < (1ull << 48)); + assert(e6() < (1ull << 48)); + assert(e6() < (1ull << 48)); + assert(e6() < (1ull << 48)); + assert(e6() < (1ull << 48)); return 0; -} \ No newline at end of file +} diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp index 5317f171a98a..73829071bd95 100644 --- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp +++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp @@ -61,24 +61,34 @@ test() test1(); test1(); test1(); +} - /* - // Cases where m is odd and m % a > m / a (not implemented) +template +void test_ext() { + const T M(static_cast(-1)); + + // Cases where m is odd and m % a > m / a test1(); test1(); test1(); test1(); test1(); test1(); - */ } int main(int, char**) { test(); + test_ext(); test(); + test_ext(); test(); + test_ext(); test(); + // This isn't implemented on platforms without __int128 +#ifndef 
_LIBCPP_HAS_NO_INT128 + test_ext(); +#endif - return 0; + return 0; } diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp index 8e950043d594..8387a1763714 100644 --- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp +++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp @@ -60,24 +60,34 @@ test() test1(); test1(); test1(); +} - /* - // Cases where m is odd and m % a > m / a (not implemented) +template +void test_ext() { + const T M(static_cast(-1)); + + // Cases where m is odd and m % a > m / a test1(); test1(); test1(); test1(); test1(); test1(); - */ } int main(int, char**) { test(); + test_ext(); test(); + test_ext(); test(); + test_ext(); test(); + // This isn't implemented on platforms without __int128 +#ifndef _LIBCPP_HAS_NO_INT128 + test_ext(); +#endif - return 0; + return 0; } diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp index 52126f7a200d..c59afd7a3eb2 100644 --- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp +++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp @@ -58,24 +58,34 @@ test() test1(); test1(); test1(); +} - /* - // Cases where m is odd and m % a > m / a (not implemented) +template +void test_ext() { + const T M(static_cast(-1)); + + // Cases where m is odd and m % a > m / a test1(); test1(); test1(); test1(); test1(); test1(); - */ } int main(int, char**) { test(); + test_ext(); test(); + test_ext(); test(); + test_ext(); test(); + // This isn't implemented on platforms without __int128 +#ifndef _LIBCPP_HAS_NO_INT128 + test_ext(); +#endif - return 0; + return 0; } diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp index 28d8dfea01fa..98b07e70f247 
100644 --- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp +++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp @@ -91,24 +91,34 @@ test() test1(); test1(); test1(); +} - /* - // Cases where m is odd and m % a > m / a (not implemented) +template +void test_ext() { + const T M(static_cast(-1)); + + // Cases where m is odd and m % a > m / a test1(); test1(); test1(); test1(); test1(); test1(); - */ } int main(int, char**) { test(); + test_ext(); test(); + test_ext(); test(); + test_ext(); test(); + // This isn't implemented on platforms without __int128 +#ifndef _LIBCPP_HAS_NO_INT128 + test_ext(); +#endif - return 0; + return 0; }