Skip to content

Commit e6e7734

Browse files
committed
Add math_helpers module and various fixes
1 parent 209836f commit e6e7734

File tree

4 files changed

+189
-164
lines changed

4 files changed

+189
-164
lines changed

src/distributions/float.rs

+22 -80
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
use core::mem;
1414
use Rng;
1515
use distributions::{Distribution, Standard};
16+
use distributions::math_helpers::CastFromInt;
1617
#[cfg(feature="simd_support")]
1718
use core::simd::*;
1819

@@ -84,71 +85,9 @@ pub(crate) trait IntoFloat {
8485
fn into_float_with_exponent(self, exponent: i32) -> Self::F;
8586
}
8687

87-
macro_rules! float_impls {
88-
($ty:ty, $uty:ty, $fraction_bits:expr, $exponent_bias:expr) => {
89-
impl IntoFloat for $uty {
90-
type F = $ty;
91-
#[inline(always)]
92-
fn into_float_with_exponent(self, exponent: i32) -> $ty {
93-
// The exponent is encoded using an offset-binary representation
94-
let exponent_bits =
95-
(($exponent_bias + exponent) as $uty) << $fraction_bits;
96-
unsafe { mem::transmute(self | exponent_bits) }
97-
}
98-
}
99-
100-
impl Distribution<$ty> for Standard {
101-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
102-
// Multiply-based method; 24/53 random bits; [0, 1) interval.
103-
// We use the most significant bits because for simple RNGs
104-
// those are usually more random.
105-
let float_size = mem::size_of::<$ty>() * 8;
106-
let precision = $fraction_bits + 1;
107-
let scale = 1.0 / ((1 as $uty << precision) as $ty);
108-
109-
let value: $uty = rng.gen();
110-
scale * (value >> (float_size - precision)) as $ty
111-
}
112-
}
113-
114-
impl Distribution<$ty> for OpenClosed01 {
115-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
116-
// Multiply-based method; 24/53 random bits; (0, 1] interval.
117-
// We use the most significant bits because for simple RNGs
118-
// those are usually more random.
119-
let float_size = mem::size_of::<$ty>() * 8;
120-
let precision = $fraction_bits + 1;
121-
let scale = 1.0 / ((1 as $uty << precision) as $ty);
122-
123-
let value: $uty = rng.gen();
124-
let value = value >> (float_size - precision);
125-
// Add 1 to shift up; will not overflow because of right-shift:
126-
scale * (value + 1) as $ty
127-
}
128-
}
129-
130-
impl Distribution<$ty> for Open01 {
131-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
132-
// Transmute-based method; 23/52 random bits; (0, 1) interval.
133-
// We use the most significant bits because for simple RNGs
134-
// those are usually more random.
135-
const EPSILON: $ty = 1.0 / (1u64 << $fraction_bits) as $ty;
136-
let float_size = mem::size_of::<$ty>() * 8;
137-
138-
let value: $uty = rng.gen();
139-
let fraction = value >> (float_size - $fraction_bits);
140-
fraction.into_float_with_exponent(0) - (1.0 - EPSILON / 2.0)
141-
}
142-
}
143-
}
144-
}
145-
float_impls! { f32, u32, 23, 127 }
146-
float_impls! { f64, u64, 52, 1023 }
147-
148-
14988
#[cfg(feature="simd_support")]
150-
macro_rules! simd_float_impls {
151-
($ty:ident, $uty:ident, $f_scalar:ty, $u_scalar:ty,
89+
macro_rules! float_impls {
90+
($ty:ident, $uty:ident, $f_scalar:ident, $u_scalar:ty,
15291
$fraction_bits:expr, $exponent_bias:expr) => {
15392
impl IntoFloat for $uty {
15493
type F = $ty;
@@ -157,7 +96,7 @@ macro_rules! simd_float_impls {
15796
// The exponent is encoded using an offset-binary representation
15897
let exponent_bits: $u_scalar =
15998
(($exponent_bias + exponent) as $u_scalar) << $fraction_bits;
160-
unsafe { mem::transmute(self | $uty::splat(exponent_bits)) }
99+
$ty::from_bits(self | exponent_bits)
161100
}
162101
}
163102

@@ -168,11 +107,11 @@ macro_rules! simd_float_impls {
168107
// those are usually more random.
169108
let float_size = mem::size_of::<$f_scalar>() * 8;
170109
let precision = $fraction_bits + 1;
171-
let scale = $ty::splat(1.0 / ((1 as $u_scalar << precision) as $f_scalar));
110+
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);
172111

173112
let value: $uty = rng.gen();
174-
let value = $ty::from(value >> (float_size - precision));
175-
scale * value
113+
let value = value >> (float_size - precision);
114+
scale * $ty::cast_from_int(value)
176115
}
177116
}
178117

@@ -183,12 +122,12 @@ macro_rules! simd_float_impls {
183122
// those are usually more random.
184123
let float_size = mem::size_of::<$f_scalar>() * 8;
185124
let precision = $fraction_bits + 1;
186-
let scale = $ty::splat(1.0 / ((1 as $u_scalar << precision) as $f_scalar));
125+
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);
187126

188127
let value: $uty = rng.gen();
128+
let value = value >> (float_size - precision);
189129
// Add 1 to shift up; will not overflow because of right-shift:
190-
let value = $ty::from((value >> (float_size - precision)) + 1);
191-
scale * value
130+
scale * $ty::cast_from_int(value + 1)
192131
}
193132
}
194133

@@ -197,32 +136,35 @@ macro_rules! simd_float_impls {
197136
// Transmute-based method; 23/52 random bits; (0, 1) interval.
198137
// We use the most significant bits because for simple RNGs
199138
// those are usually more random.
200-
const EPSILON: $f_scalar = 1.0 / (1u64 << $fraction_bits) as $f_scalar;
139+
use core::$f_scalar::EPSILON;
201140
let float_size = mem::size_of::<$f_scalar>() * 8;
202141

203142
let value: $uty = rng.gen();
204143
let fraction = value >> (float_size - $fraction_bits);
205-
fraction.into_float_with_exponent(0) - $ty::splat(1.0 - EPSILON / 2.0)
144+
fraction.into_float_with_exponent(0) - (1.0 - EPSILON / 2.0)
206145
}
207146
}
208147
}
209148
}
210149

150+
float_impls! { f32, u32, f32, u32, 23, 127 }
151+
float_impls! { f64, u64, f64, u64, 52, 1023 }
152+
211153
#[cfg(feature="simd_support")]
212-
simd_float_impls! { f32x2, u32x2, f32, u32, 23, 127 }
154+
float_impls! { f32x2, u32x2, f32, u32, 23, 127 }
213155
#[cfg(feature="simd_support")]
214-
simd_float_impls! { f32x4, u32x4, f32, u32, 23, 127 }
156+
float_impls! { f32x4, u32x4, f32, u32, 23, 127 }
215157
#[cfg(feature="simd_support")]
216-
simd_float_impls! { f32x8, u32x8, f32, u32, 23, 127 }
158+
float_impls! { f32x8, u32x8, f32, u32, 23, 127 }
217159
#[cfg(feature="simd_support")]
218-
simd_float_impls! { f32x16, u32x16, f32, u32, 23, 127 }
160+
float_impls! { f32x16, u32x16, f32, u32, 23, 127 }
219161

220162
#[cfg(feature="simd_support")]
221-
simd_float_impls! { f64x2, u64x2, f64, u64, 52, 1023 }
163+
float_impls! { f64x2, u64x2, f64, u64, 52, 1023 }
222164
#[cfg(feature="simd_support")]
223-
simd_float_impls! { f64x4, u64x4, f64, u64, 52, 1023 }
165+
float_impls! { f64x4, u64x4, f64, u64, 52, 1023 }
224166
#[cfg(feature="simd_support")]
225-
simd_float_impls! { f64x8, u64x8, f64, u64, 52, 1023 }
167+
float_impls! { f64x8, u64x8, f64, u64, 52, 1023 }
226168

227169

228170
#[cfg(test)]

src/distributions/math_helpers.rs

+161
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,161 @@
1+
// Copyright 2017 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// https://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
//! Math helper functions
12+
13+
#[cfg(feature="simd_support")]
14+
use core::simd::*;
15+
16+
17+
pub trait WideningMultiply<RHS = Self> {
18+
type Output;
19+
20+
fn wmul(self, x: RHS) -> Self::Output;
21+
}
22+
23+
macro_rules! wmul_impl {
24+
($ty:ty, $wide:ty, $shift:expr) => {
25+
impl WideningMultiply for $ty {
26+
type Output = ($ty, $ty);
27+
28+
#[inline(always)]
29+
fn wmul(self, x: $ty) -> Self::Output {
30+
let tmp = (self as $wide) * (x as $wide);
31+
((tmp >> $shift) as $ty, tmp as $ty)
32+
}
33+
}
34+
}
35+
}
36+
wmul_impl! { u8, u16, 8 }
37+
wmul_impl! { u16, u32, 16 }
38+
wmul_impl! { u32, u64, 32 }
39+
#[cfg(feature = "i128_support")]
40+
wmul_impl! { u64, u128, 64 }
41+
42+
// This code is a translation of the __mulddi3 function in LLVM's
43+
// compiler-rt. It is an optimised variant of the common method
44+
// `(a + b) * (c + d) = ac + ad + bc + bd`.
45+
//
46+
// For some reason LLVM can optimise the C version very well, but
47+
// keeps shuffling registers in this Rust translation.
48+
macro_rules! wmul_impl_large {
49+
($ty:ty, $half:expr) => {
50+
impl WideningMultiply for $ty {
51+
type Output = ($ty, $ty);
52+
53+
#[inline(always)]
54+
fn wmul(self, b: $ty) -> Self::Output {
55+
const LOWER_MASK: $ty = !0 >> $half;
56+
let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
57+
let mut t = low >> $half;
58+
low &= LOWER_MASK;
59+
t += (self >> $half).wrapping_mul(b & LOWER_MASK);
60+
low += (t & LOWER_MASK) << $half;
61+
let mut high = t >> $half;
62+
t = low >> $half;
63+
low &= LOWER_MASK;
64+
t += (b >> $half).wrapping_mul(self & LOWER_MASK);
65+
low += (t & LOWER_MASK) << $half;
66+
high += t >> $half;
67+
high += (self >> $half).wrapping_mul(b >> $half);
68+
69+
(high, low)
70+
}
71+
}
72+
}
73+
}
74+
#[cfg(not(feature = "i128_support"))]
75+
wmul_impl_large! { u64, 32 }
76+
#[cfg(feature = "i128_support")]
77+
wmul_impl_large! { u128, 64 }
78+
79+
macro_rules! wmul_impl_usize {
80+
($ty:ty) => {
81+
impl WideningMultiply for usize {
82+
type Output = (usize, usize);
83+
84+
#[inline(always)]
85+
fn wmul(self, x: usize) -> Self::Output {
86+
let (high, low) = (self as $ty).wmul(x as $ty);
87+
(high as usize, low as usize)
88+
}
89+
}
90+
}
91+
}
92+
#[cfg(target_pointer_width = "32")]
93+
wmul_impl_usize! { u32 }
94+
#[cfg(target_pointer_width = "64")]
95+
wmul_impl_usize! { u64 }
96+
97+
98+
pub trait CastFromInt<T> {
99+
fn cast_from_int(i: T) -> Self;
100+
}
101+
102+
impl CastFromInt<u32> for f32 {
103+
fn cast_from_int(i: u32) -> Self { i as f32 }
104+
}
105+
106+
impl CastFromInt<u64> for f64 {
107+
fn cast_from_int(i: u64) -> Self { i as f64 }
108+
}
109+
110+
#[cfg(feature="simd_support")]
111+
macro_rules! simd_float_from_int {
112+
($ty:ident, $uty:ident) => {
113+
impl CastFromInt<$uty> for $ty {
114+
fn cast_from_int(i: $uty) -> Self { $ty::from(i) }
115+
}
116+
}
117+
}
118+
119+
#[cfg(feature="simd_support")] simd_float_from_int! { f32x2, u32x2 }
120+
#[cfg(feature="simd_support")] simd_float_from_int! { f32x4, u32x4 }
121+
#[cfg(feature="simd_support")] simd_float_from_int! { f32x8, u32x8 }
122+
#[cfg(feature="simd_support")] simd_float_from_int! { f32x16, u32x16 }
123+
#[cfg(feature="simd_support")] simd_float_from_int! { f64x2, u64x2 }
124+
#[cfg(feature="simd_support")] simd_float_from_int! { f64x4, u64x4 }
125+
#[cfg(feature="simd_support")] simd_float_from_int! { f64x8, u64x8 }
126+
127+
128+
/// `PartialOrd` for vectors compares lexicographically. We want natural order.
129+
/// Only the comparison functions we need are implemented.
130+
pub trait NaturalCompare {
131+
fn cmp_lt(self, other: Self) -> bool;
132+
fn cmp_le(self, other: Self) -> bool;
133+
}
134+
135+
impl NaturalCompare for f32 {
136+
fn cmp_lt(self, other: Self) -> bool { self < other }
137+
fn cmp_le(self, other: Self) -> bool { self <= other }
138+
}
139+
140+
impl NaturalCompare for f64 {
141+
fn cmp_lt(self, other: Self) -> bool { self < other }
142+
fn cmp_le(self, other: Self) -> bool { self <= other }
143+
}
144+
145+
#[cfg(feature="simd_support")]
146+
macro_rules! simd_less_then {
147+
($ty:ident) => {
148+
impl NaturalCompare for $ty {
149+
fn cmp_lt(self, other: Self) -> bool { self.lt(other).all() }
150+
fn cmp_le(self, other: Self) -> bool { self.le(other).all() }
151+
}
152+
}
153+
}
154+
155+
#[cfg(feature="simd_support")] simd_less_then! { f32x2 }
156+
#[cfg(feature="simd_support")] simd_less_then! { f32x4 }
157+
#[cfg(feature="simd_support")] simd_less_then! { f32x8 }
158+
#[cfg(feature="simd_support")] simd_less_then! { f32x16 }
159+
#[cfg(feature="simd_support")] simd_less_then! { f64x2 }
160+
#[cfg(feature="simd_support")] simd_less_then! { f64x4 }
161+
#[cfg(feature="simd_support")] simd_less_then! { f64x8 }

src/distributions/mod.rs

+1
Original file line number | Diff line number | Diff line change
@@ -214,6 +214,7 @@ mod float;
214214
mod integer;
215215
#[cfg(feature="std")]
216216
mod log_gamma;
217+
mod math_helpers;
217218
mod other;
218219
#[cfg(feature="std")]
219220
mod ziggurat_tables;

0 commit comments

Comments
 (0)