linear_congruential_engine: add using more precision to prevent overflow (#81583)

This PR is a followup to #81080.

This PR makes two major changes to how the LCG operation is computed:

The first is that I added an additional case where `ax + c` might
overflow the intermediate variable, but `ax` by itself won't. In this
case, it's much better to use `(ax mod m) + c mod m` than the previous
behavior of falling back to Schrage's algorithm. The addition modulo is
done in the same way as when using Schrage's algorithm (i.e. `x += c -
(x >= m - c)*m`), but the multiplication modulo is calculated directly,
which is faster.

The second is that I added handling for the case where the `ax`
intermediate might overflow, but Schrage's algorithm doesn't apply (i.e.
r > q). In this case, the only real option is to increase the precision
of the intermediate values. The good news is that - for `x`, `a`, and
`c` being n-bit values - `ax + c` will never overflow a 2n-bit
intermediary, meaning this promotion can only happen once, and will
always be able to use the simplest implementation. This is already the
case for 16-bit LCGs, as libcxx chooses to compute them with 32-bit
intermediate values. For 32-bit LCGs, I simply added code similar to the
16-bit case to use the existing 64-bit implementations. Lastly, for
64-bit LCGs, I wrote a case that calculates it using `unsigned __int128`
if it is available to use.

While this implementation covers a *lot* of the missing cases from
#81080, this still won't compile **every** possible
`linear_congruential_engine`. Specifically, if `a`, `c`, and `m` are
chosen such that it needs 128-bit integers, but the platform doesn't
support `__int128` (eg. 32-bit x86), then it will fail to compile.
However, this is a fairly rare case to see actually used, and libcxx
would be in good company with this, as [libstdc++ also fails to compile
under these
circumstances](https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87744).
Fixing **this** gap would require even **more** work of further
complexity, so that would probably be best handled by a different PR
(I'll put more details on what that PR would entail in a comment).
This commit is contained in:
LRFLEW 2024-04-19 11:58:18 -05:00 committed by GitHub
parent 82c320ca59
commit 41e696291c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 192 additions and 58 deletions

View File

@ -26,32 +26,60 @@ _LIBCPP_PUSH_MACROS
_LIBCPP_BEGIN_NAMESPACE_STD
template <unsigned long long __a,
unsigned long long __c,
unsigned long long __m,
unsigned long long _Mp,
bool _MightOverflow = (__a != 0 && __m != 0 && __m - 1 > (_Mp - __c) / __a),
bool _OverflowOK = ((__m & (__m - 1)) == 0ull), // m = 2^n
bool _SchrageOK = (__a != 0 && __m != 0 && __m % __a <= __m / __a)> // r <= q
struct __lce_alg_picker {
static_assert(!_MightOverflow || _OverflowOK || _SchrageOK,
"The current values of a, c, and m cannot generate a number "
"within bounds of linear_congruential_engine.");
static _LIBCPP_CONSTEXPR const bool __use_schrage = _MightOverflow && !_OverflowOK && _SchrageOK;
enum __lce_alg_type {
_LCE_Full,
_LCE_Part,
_LCE_Schrage,
_LCE_Promote,
};
template <unsigned long long __a,
unsigned long long __c,
unsigned long long __m,
unsigned long long _Mp,
bool _UseSchrage = __lce_alg_picker<__a, __c, __m, _Mp>::__use_schrage>
bool _HasOverflow = (__a != 0ull && (__m & (__m - 1ull)) != 0ull), // a != 0, m != 0, m != 2^n
bool _Full = (!_HasOverflow || __m - 1ull <= (_Mp - __c) / __a), // (a * x + c) % m works
bool _Part = (!_HasOverflow || __m - 1ull <= _Mp / __a), // (a * x) % m works
bool _Schrage = (_HasOverflow && __m % __a <= __m / __a)> // r <= q
struct __lce_alg_picker {
static _LIBCPP_CONSTEXPR const __lce_alg_type __mode =
_Full ? _LCE_Full
: _Part ? _LCE_Part
: _Schrage ? _LCE_Schrage
: _LCE_Promote;
#ifdef _LIBCPP_HAS_NO_INT128
static_assert(_Mp != (unsigned long long)(-1) || _Full || _Part || _Schrage,
"The current values for a, c, and m are not currently supported on platforms without __int128");
#endif
};
template <unsigned long long __a,
unsigned long long __c,
unsigned long long __m,
unsigned long long _Mp,
__lce_alg_type _Mode = __lce_alg_picker<__a, __c, __m, _Mp>::__mode>
struct __lce_ta;
// 64
#ifndef _LIBCPP_HAS_NO_INT128
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
struct __lce_ta<_Ap, _Cp, _Mp, (unsigned long long)(-1), _LCE_Promote> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __xp) {
__extension__ using __calc_type = unsigned __int128;
const __calc_type __a = static_cast<__calc_type>(_Ap);
const __calc_type __c = static_cast<__calc_type>(_Cp);
const __calc_type __m = static_cast<__calc_type>(_Mp);
const __calc_type __x = static_cast<__calc_type>(__xp);
return static_cast<result_type>((__a * __x + __c) % __m);
}
};
#endif
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
struct __lce_ta<__a, __c, __m, (unsigned long long)(-1), _LCE_Schrage> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
// Schrage's algorithm
@ -66,7 +94,7 @@ struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
};
template <unsigned long long __a, unsigned long long __m>
struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
struct __lce_ta<__a, 0ull, __m, (unsigned long long)(-1), _LCE_Schrage> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
// Schrage's algorithm
@ -80,21 +108,40 @@ struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
};
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), false> {
struct __lce_ta<__a, __c, __m, (unsigned long long)(-1), _LCE_Part> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
// Use (((a*x) % m) + c) % m
__x = (__a * __x) % __m;
__x += __c - (__x >= __m - __c) * __m;
return __x;
}
};
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
struct __lce_ta<__a, __c, __m, (unsigned long long)(-1), _LCE_Full> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return (__a * __x + __c) % __m; }
};
template <unsigned long long __a, unsigned long long __c>
struct __lce_ta<__a, __c, 0, (unsigned long long)(~0), false> {
struct __lce_ta<__a, __c, 0ull, (unsigned long long)(-1), _LCE_Full> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return __a * __x + __c; }
};
// 32
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
struct __lce_ta<__a, __c, __m, unsigned(-1), _LCE_Promote> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
return static_cast<result_type>(__lce_ta<__a, __c, __m, (unsigned long long)(-1)>::next(__x));
}
};
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(-1), _LCE_Schrage> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
const result_type __a = static_cast<result_type>(_Ap);
@ -112,7 +159,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
};
template <unsigned long long _Ap, unsigned long long _Mp>
struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
struct __lce_ta<_Ap, 0ull, _Mp, unsigned(-1), _LCE_Schrage> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
const result_type __a = static_cast<result_type>(_Ap);
@ -128,7 +175,21 @@ struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
};
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(-1), _LCE_Part> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
const result_type __a = static_cast<result_type>(_Ap);
const result_type __c = static_cast<result_type>(_Cp);
const result_type __m = static_cast<result_type>(_Mp);
// Use (((a*x) % m) + c) % m
__x = (__a * __x) % __m;
__x += __c - (__x >= __m - __c) * __m;
return __x;
}
};
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(-1), _LCE_Full> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
const result_type __a = static_cast<result_type>(_Ap);
@ -139,7 +200,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
};
template <unsigned long long _Ap, unsigned long long _Cp>
struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {
struct __lce_ta<_Ap, _Cp, 0ull, unsigned(-1), _LCE_Full> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
const result_type __a = static_cast<result_type>(_Ap);
@ -150,11 +211,11 @@ struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {
// 16
template <unsigned long long __a, unsigned long long __c, unsigned long long __m, bool __b>
struct __lce_ta<__a, __c, __m, (unsigned short)(~0), __b> {
template <unsigned long long __a, unsigned long long __c, unsigned long long __m, __lce_alg_type __mode>
struct __lce_ta<__a, __c, __m, (unsigned short)(-1), __mode> {
typedef unsigned short result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
return static_cast<result_type>(__lce_ta<__a, __c, __m, unsigned(~0)>::next(__x));
return static_cast<result_type>(__lce_ta<__a, __c, __m, unsigned(-1)>::next(__x));
}
};
@ -178,7 +239,7 @@ public:
private:
result_type __x_;
static _LIBCPP_CONSTEXPR const result_type _Mp = result_type(~0);
static _LIBCPP_CONSTEXPR const result_type _Mp = result_type(-1);
static_assert(__m == 0 || __a < __m, "linear_congruential_engine invalid parameters");
static_assert(__m == 0 || __c < __m, "linear_congruential_engine invalid parameters");

View File

@ -38,12 +38,12 @@ int main(int, char**)
// m might overflow. The overflow is not OK and result will be in bounds
// so we should use Schrage's algorithm
typedef std::linear_congruential_engine<T, (1ull << 32), 0, (1ull << 63) + 1> E2;
typedef std::linear_congruential_engine<T, (1ull << 32), 0, (1ull << 63) + 1ull> E2;
E2 e2;
// make sure Schrage's algorithm is used (it would be 0s after the first otherwise)
assert(e2() == (1ull << 32));
assert(e2() == (1ull << 63) - 1ull);
assert(e2() == (1ull << 63) - (1ull << 33) + 1ull);
assert(e2() == (1ull << 63) - 0x1ffffffffull);
// make sure result is in bounds
assert(e2() < (1ull << 63) + 1);
assert(e2() < (1ull << 63) + 1);
@ -56,9 +56,9 @@ int main(int, char**)
typedef std::linear_congruential_engine<T, 0x18000001ull, 0x12347ull, (3ull << 56)> E3;
E3 e3;
// make sure Schrage's algorithm is used
assert(e3() == 402727752ull);
assert(e3() == 162159612030764687ull);
assert(e3() == 108176466184989142ull);
assert(e3() == 0x18012348ull);
assert(e3() == 0x2401b4ed802468full);
assert(e3() == 0x18051ec400369d6ull);
// make sure result is in bounds
assert(e3() < (3ull << 56));
assert(e3() < (3ull << 56));
@ -66,19 +66,52 @@ int main(int, char**)
assert(e3() < (3ull << 56));
assert(e3() < (3ull << 56));
// 32-bit case:
// m might overflow. The overflow is not OK, result will be in bounds,
// and Schrage's algorithm is incompatible here. Need to use 64 bit arithmetic.
typedef std::linear_congruential_engine<unsigned, 0x10009u, 0u, 0x7fffffffu> E4;
E4 e4;
// make sure enough precision is used
assert(e4() == 0x10009u);
assert(e4() == 0x120053u);
assert(e4() == 0xf5030fu);
// make sure result is in bounds
assert(e4() < 0x7fffffffu);
assert(e4() < 0x7fffffffu);
assert(e4() < 0x7fffffffu);
assert(e4() < 0x7fffffffu);
assert(e4() < 0x7fffffffu);
#ifndef _LIBCPP_HAS_NO_INT128
// m might overflow. The overflow is not OK, result will be in bounds,
// and Schrage's algorithm is incompatible here. Need to use 128 bit arithmetic.
typedef std::linear_congruential_engine<T, 0x100000001ull, 0ull, (1ull << 61) - 1ull> E5;
E5 e5;
// make sure enough precision is used
assert(e5() == 0x100000001ull);
assert(e5() == 0x200000009ull);
assert(e5() == 0xb00000019ull);
// make sure result is in bounds
assert(e5() < (1ull << 61) - 1ull);
assert(e5() < (1ull << 61) - 1ull);
assert(e5() < (1ull << 61) - 1ull);
assert(e5() < (1ull << 61) - 1ull);
assert(e5() < (1ull << 61) - 1ull);
#endif
// m will not overflow so we should not use Schrage's algorithm
typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E4;
E4 e4;
typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E6;
E6 e6;
// make sure the correct algorithm was used
assert(e4() == 2ull);
assert(e4() == 3ull);
assert(e4() == 4ull);
assert(e6() == 2ull);
assert(e6() == 3ull);
assert(e6() == 4ull);
// make sure result is in bounds
assert(e4() < (1ull << 48));
assert(e4() < (1ull << 48));
assert(e4() < (1ull << 48));
assert(e4() < (1ull << 48));
assert(e4() < (1ull << 48));
assert(e6() < (1ull << 48));
assert(e6() < (1ull << 48));
assert(e6() < (1ull << 48));
assert(e6() < (1ull << 48));
assert(e6() < (1ull << 48));
return 0;
}
}

View File

@ -61,24 +61,34 @@ test()
test1<T, A, 0, M>();
test1<T, A, M - 2, M>();
test1<T, A, M - 1, M>();
}
/*
// Cases where m is odd and m % a > m / a (not implemented)
template <class T>
void test_ext() {
const T M(static_cast<T>(-1));
// Cases where m is odd and m % a > m / a
test1<T, M - 2, 0, M>();
test1<T, M - 2, M - 2, M>();
test1<T, M - 2, M - 1, M>();
test1<T, M - 1, 0, M>();
test1<T, M - 1, M - 2, M>();
test1<T, M - 1, M - 1, M>();
*/
}
int main(int, char**)
{
test<unsigned short>();
test_ext<unsigned short>();
test<unsigned int>();
test_ext<unsigned int>();
test<unsigned long>();
test_ext<unsigned long>();
test<unsigned long long>();
// This isn't implemented on platforms without __int128
#ifndef _LIBCPP_HAS_NO_INT128
test_ext<unsigned long long>();
#endif
return 0;
return 0;
}

View File

@ -60,24 +60,34 @@ test()
test1<T, A, 0, M>();
test1<T, A, M - 2, M>();
test1<T, A, M - 1, M>();
}
/*
// Cases where m is odd and m % a > m / a (not implemented)
template <class T>
void test_ext() {
const T M(static_cast<T>(-1));
// Cases where m is odd and m % a > m / a
test1<T, M - 2, 0, M>();
test1<T, M - 2, M - 2, M>();
test1<T, M - 2, M - 1, M>();
test1<T, M - 1, 0, M>();
test1<T, M - 1, M - 2, M>();
test1<T, M - 1, M - 1, M>();
*/
}
int main(int, char**)
{
test<unsigned short>();
test_ext<unsigned short>();
test<unsigned int>();
test_ext<unsigned int>();
test<unsigned long>();
test_ext<unsigned long>();
test<unsigned long long>();
// This isn't implemented on platforms without __int128
#ifndef _LIBCPP_HAS_NO_INT128
test_ext<unsigned long long>();
#endif
return 0;
return 0;
}

View File

@ -58,24 +58,34 @@ test()
test1<T, A, 0, M>();
test1<T, A, M - 2, M>();
test1<T, A, M - 1, M>();
}
/*
// Cases where m is odd and m % a > m / a (not implemented)
template <class T>
void test_ext() {
const T M(static_cast<T>(-1));
// Cases where m is odd and m % a > m / a
test1<T, M - 2, 0, M>();
test1<T, M - 2, M - 2, M>();
test1<T, M - 2, M - 1, M>();
test1<T, M - 1, 0, M>();
test1<T, M - 1, M - 2, M>();
test1<T, M - 1, M - 1, M>();
*/
}
int main(int, char**)
{
test<unsigned short>();
test_ext<unsigned short>();
test<unsigned int>();
test_ext<unsigned int>();
test<unsigned long>();
test_ext<unsigned long>();
test<unsigned long long>();
// This isn't implemented on platforms without __int128
#ifndef _LIBCPP_HAS_NO_INT128
test_ext<unsigned long long>();
#endif
return 0;
return 0;
}

View File

@ -91,24 +91,34 @@ test()
test1<T, A, 0, M>();
test1<T, A, M - 2, M>();
test1<T, A, M - 1, M>();
}
/*
// Cases where m is odd and m % a > m / a (not implemented)
template <class T>
void test_ext() {
const T M(static_cast<T>(-1));
// Cases where m is odd and m % a > m / a
test1<T, M - 2, 0, M>();
test1<T, M - 2, M - 2, M>();
test1<T, M - 2, M - 1, M>();
test1<T, M - 1, 0, M>();
test1<T, M - 1, M - 2, M>();
test1<T, M - 1, M - 1, M>();
*/
}
int main(int, char**)
{
test<unsigned short>();
test_ext<unsigned short>();
test<unsigned int>();
test_ext<unsigned int>();
test<unsigned long>();
test_ext<unsigned long>();
test<unsigned long long>();
// This isn't implemented on platforms without __int128
#ifndef _LIBCPP_HAS_NO_INT128
test_ext<unsigned long long>();
#endif
return 0;
return 0;
}