Skip to content

Commit d83af1d

Browse files
authored
feat(model): support user chose random generators (#22)
1 parent bec4ee8 commit d83af1d

3 files changed

Lines changed: 350 additions & 30 deletions

File tree

Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ dyn-clone = { version = "1.0.10", optional = true }
1818
human-bandwidth = { version = "0.1.3", optional = true }
1919
humantime-serde = { version = "1.1.1", optional = true }
2020
itertools = { version = "0.14.0", optional = true }
21-
once_cell = { version = "1.17.0", optional = true }
22-
rand = { version = "0.9.0", optional = true }
23-
rand_distr = { version = "0.5.0", optional = true }
21+
rand = { version = "0.9.1", optional = true }
22+
rand_distr = { version = "0.5.1", optional = true }
2423
serde = { version = "1.0", features = ["derive"], optional = true }
2524
statrs = { version = "0.18.0", optional = true }
2625
typetag = { version = "0.2.5", optional = true }
2726

2827
[dev-dependencies]
2928
figment = { version = "0.10.19", features = ["json"] }
29+
rand_chacha = "0.9"
3030
serde_json = "1.0"
3131

3232
[features]
@@ -38,7 +38,7 @@ model = [
3838
"loss-model",
3939
"duplicate-model"
4040
]
41-
bw-model = ["dep:rand", "dep:rand_distr", "dep:once_cell", "dep:dyn-clone"]
41+
bw-model = ["dep:rand", "dep:rand_distr", "dep:dyn-clone"]
4242
delay-model = ["dep:dyn-clone"]
4343
delay-per-packet-model = ["dep:dyn-clone"]
4444
loss-model = ["dep:dyn-clone"]

src/model/bw.rs

Lines changed: 175 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@
5757
//! ```
5858
use crate::{Bandwidth, BwTrace, Duration};
5959
use dyn_clone::DynClone;
60-
use rand::rngs::StdRng;
61-
use rand::SeedableRng;
60+
use rand::{rngs::StdRng, RngCore, SeedableRng};
6261
use rand_distr::{Distribution, Normal};
6362

6463
const DEFAULT_RNG_SEED: u64 = 42;
@@ -162,15 +161,18 @@ pub struct StaticBwConfig {
162161
/// assert_eq!(normal_bw.next_bw(), Some((Bandwidth::from_bps(12100000), Duration::from_millis(100))));
163162
/// ```
164163
#[derive(Debug, Clone)]
165-
pub struct NormalizedBw {
164+
pub struct NormalizedBw<Rng = StdRng>
165+
where
166+
Rng: RngCore,
167+
{
166168
pub mean: Bandwidth,
167169
pub std_dev: Bandwidth,
168170
pub upper_bound: Option<Bandwidth>,
169171
pub lower_bound: Option<Bandwidth>,
170172
pub duration: Duration,
171173
pub step: Duration,
172174
pub seed: u64,
173-
rng: StdRng,
175+
rng: Rng,
174176
normal: Normal<f64>,
175177
}
176178

@@ -317,7 +319,10 @@ pub struct NormalizedBwConfig {
317319
/// );
318320
/// ```
319321
#[derive(Debug, Clone)]
320-
pub struct SawtoothBw {
322+
pub struct SawtoothBw<Rng = StdRng>
323+
where
324+
Rng: RngCore,
325+
{
321326
pub bottom: Bandwidth,
322327
pub top: Bandwidth,
323328
pub interval: Duration,
@@ -329,7 +334,7 @@ pub struct SawtoothBw {
329334
pub upper_noise_bound: Option<Bandwidth>,
330335
pub lower_noise_bound: Option<Bandwidth>,
331336
current: Duration,
332-
rng: StdRng,
337+
rng: Rng,
333338
noise: Normal<f64>,
334339
}
335340

@@ -768,7 +773,7 @@ impl BwTrace for StaticBw {
768773
}
769774
}
770775

771-
impl BwTrace for NormalizedBw {
776+
impl<Rng: RngCore + Send> BwTrace for NormalizedBw<Rng> {
772777
fn next_bw(&mut self) -> Option<(Bandwidth, Duration)> {
773778
if self.duration.is_zero() {
774779
None
@@ -788,7 +793,7 @@ impl BwTrace for NormalizedBw {
788793
}
789794
}
790795

791-
impl BwTrace for SawtoothBw {
796+
impl<Rng: RngCore + Send> BwTrace for SawtoothBw<Rng> {
792797
fn next_bw(&mut self) -> Option<(Bandwidth, Duration)> {
793798
if self.duration.is_zero() {
794799
None
@@ -870,7 +875,7 @@ impl BwTrace for TraceBw {
870875
}
871876
}
872877

873-
impl NormalizedBw {
878+
impl<Rng: RngCore> NormalizedBw<Rng> {
874879
pub fn sample(&mut self) -> f64 {
875880
self.normal.sample(&mut self.rng)
876881
}
@@ -912,6 +917,7 @@ macro_rules! saturating_bandwidth_as_bps_u64 {
912917
}
913918

914919
impl NormalizedBwConfig {
920+
/// Creates an uninitialized config
915921
pub fn new() -> Self {
916922
Self {
917923
mean: None,
@@ -924,50 +930,127 @@ impl NormalizedBwConfig {
924930
}
925931
}
926932

933+
/// Sets the mean
934+
///
935+
/// If the mean is not set, 12Mbps will be used.
927936
pub fn mean(mut self, mean: Bandwidth) -> Self {
928937
self.mean = Some(mean);
929938
self
930939
}
931940

941+
/// Sets the standard deviation
942+
///
943+
/// If the standard deviation is not set, 0Mbps will be used.
932944
pub fn std_dev(mut self, std_dev: Bandwidth) -> Self {
933945
self.std_dev = Some(std_dev);
934946
self
935947
}
936948

949+
/// Sets the upper bound
950+
///
951+
/// If the upper bound is not set, the upper bound will be the one of [`Bandwidth`].
937952
pub fn upper_bound(mut self, upper_bound: Bandwidth) -> Self {
938953
self.upper_bound = Some(upper_bound);
939954
self
940955
}
941956

957+
/// Sets the lower bound
958+
///
959+
/// If the lower bound is not set, the lower bound will be the one of [`Bandwidth`].
942960
pub fn lower_bound(mut self, lower_bound: Bandwidth) -> Self {
943961
self.lower_bound = Some(lower_bound);
944962
self
945963
}
946964

965+
/// Sets the total duration of the trace
966+
///
967+
/// If the total duration is not set, 1 second will be used.
947968
pub fn duration(mut self, duration: Duration) -> Self {
948969
self.duration = Some(duration);
949970
self
950971
}
951972

973+
/// Sets the duration of each value
974+
///
975+
/// If the step is not set, 1ms will be used.
952976
pub fn step(mut self, step: Duration) -> Self {
953977
self.step = Some(step);
954978
self
955979
}
956980

981+
/// Set the seed for a random generator
982+
///
983+
/// If the seed is not set, `42` will be used.
957984
pub fn seed(mut self, seed: u64) -> Self {
958985
self.seed = Some(seed);
959986
self
960987
}
961988

989+
/// Allows to use a randomly generated seed
990+
///
991+
/// This is equivalent to: `self.seed(rand::random())`
992+
pub fn random_seed(mut self) -> Self {
993+
self.seed = Some(rand::random());
994+
self
995+
}
996+
997+
/// Creates a new [`NormalizedBw`] corresponding to this config.
998+
///
999+
/// The created model will use [`StdRng`] as source of randomness (the call is equivalent to `self.build_with_rng::<StdRng>()`).
1000+
/// It should be sufficient for most cases, but [`StdRng`] is not a portable random number generator,
1001+
/// so one may want to use a portable random number generator like [`ChaCha`](https://crates.io/crates/rand_chacha),
1002+
/// to this end one can use [`build_with_rng`](Self::build_with_rng).
9621003
pub fn build(self) -> NormalizedBw {
1004+
self.build_with_rng()
1005+
}
1006+
1007+
/// Creates a new [`NormalizedBw`] corresponding to this config.
1008+
///
1009+
/// Unlike [`build`](Self::build), this method let you choose the random generator.
1010+
///
1011+
/// # Example
1012+
/// ```rust
1013+
/// # use netem_trace::model::NormalizedBwConfig;
1014+
/// # use netem_trace::{Bandwidth, BwTrace};
1015+
/// # use std::time::Duration;
1016+
/// # use rand::rngs::StdRng;
1017+
/// # use rand_chacha::ChaCha20Rng;
1018+
///
1019+
/// let normal_bw = NormalizedBwConfig::new()
1020+
/// .mean(Bandwidth::from_mbps(12))
1021+
/// .std_dev(Bandwidth::from_mbps(1))
1022+
/// .duration(Duration::from_millis(3))
1023+
/// .seed(42);
1024+
///
1025+
/// let mut default_build = normal_bw.clone().build();
1026+
/// let mut std_build = normal_bw.clone().build_with_rng::<StdRng>();
1027+
/// // ChaCha is deterministic and portable, unlike StdRng
1028+
/// let mut chacha_build = normal_bw.clone().build_with_rng::<ChaCha20Rng>();
1029+
///
1030+
/// for cha in [12044676, 11754367, 11253775] {
1031+
/// let default = default_build.next_bw();
1032+
/// let std = std_build.next_bw();
1033+
/// let chacha = chacha_build.next_bw();
1034+
///
1035+
/// assert!(default.is_some());
1036+
/// assert_eq!(default, std);
1037+
/// assert_ne!(default, chacha);
1038+
/// assert_eq!(chacha, Some((Bandwidth::from_bps(cha), Duration::from_millis(1))));
1039+
/// }
1040+
///
1041+
/// assert_eq!(default_build.next_bw(), None);
1042+
/// assert_eq!(std_build.next_bw(), None);
1043+
/// assert_eq!(chacha_build.next_bw(), None);
1044+
/// ```
1045+
pub fn build_with_rng<Rng: SeedableRng + RngCore>(self) -> NormalizedBw<Rng> {
9631046
let mean = self.mean.unwrap_or_else(|| Bandwidth::from_mbps(12));
9641047
let std_dev = self.std_dev.unwrap_or_else(|| Bandwidth::from_mbps(0));
9651048
let upper_bound = self.upper_bound;
9661049
let lower_bound = self.lower_bound;
9671050
let duration = self.duration.unwrap_or_else(|| Duration::from_secs(1));
9681051
let step = self.step.unwrap_or_else(|| Duration::from_millis(1));
9691052
let seed = self.seed.unwrap_or(DEFAULT_RNG_SEED);
970-
let rng = StdRng::seed_from_u64(seed);
1053+
let rng = Rng::seed_from_u64(seed);
9711054
let bw_mean = saturating_bandwidth_as_bps_u64!(mean) as f64;
9721055
let bw_std_dev = saturating_bandwidth_as_bps_u64!(std_dev) as f64;
9731056
let normal: Normal<f64> = Normal::new(bw_mean, bw_std_dev).unwrap();
@@ -1034,7 +1117,14 @@ impl NormalizedBwConfig {
10341117
/// assert_eq!(avg_mbps(truncate_build), 11.978819427569897);
10351118
///
10361119
/// ```
1037-
pub fn build_truncated(mut self) -> NormalizedBw {
1120+
pub fn build_truncated(self) -> NormalizedBw {
1121+
self.build_truncated_with_rng()
1122+
}
1123+
1124+
/// Similar to [`build_truncated`](Self::build_truncated) but let you choose the random generator.
1125+
///
1126+
/// See [`build`](Self::build) for details about the reason for using another random number generator than [`StdRng`].
1127+
pub fn build_truncated_with_rng<Rng: SeedableRng + RngCore>(mut self) -> NormalizedBw<Rng> {
10381128
let mean = self
10391129
.mean
10401130
.unwrap_or_else(|| Bandwidth::from_mbps(12))
@@ -1052,11 +1142,12 @@ impl NormalizedBwConfig {
10521142
let upper = self.upper_bound.map(|upper| upper.as_gbps_f64() / mean);
10531143
let new_mean = mean * solve(1f64, sigma, Some(lower), upper).unwrap_or(1f64);
10541144
self.mean = Some(Bandwidth::from_gbps_f64(new_mean));
1055-
self.build()
1145+
self.build_with_rng()
10561146
}
10571147
}
10581148

10591149
impl SawtoothBwConfig {
1150+
/// Creates an uninitialized config
10601151
pub fn new() -> Self {
10611152
Self {
10621153
bottom: None,
@@ -1092,21 +1183,41 @@ impl SawtoothBwConfig {
10921183
self
10931184
}
10941185

1186+
/// Sets the total duration of the trace
1187+
///
1188+
/// If the total duration is not set, 1 second will be used.
10951189
pub fn duration(mut self, duration: Duration) -> Self {
10961190
self.duration = Some(duration);
10971191
self
10981192
}
10991193

1194+
/// Sets the duration of each value
1195+
///
1196+
/// If the step is not set, 1ms will be used.
11001197
pub fn step(mut self, step: Duration) -> Self {
11011198
self.step = Some(step);
11021199
self
11031200
}
11041201

1202+
/// Set the seed for a random generator
1203+
///
1204+
/// If the seed is not set, `42` will be used.
11051205
pub fn seed(mut self, seed: u64) -> Self {
11061206
self.seed = Some(seed);
11071207
self
11081208
}
11091209

1210+
/// Allows to use a randomly generated seed
1211+
///
1212+
/// This is equivalent to: `self.seed(rand::random())`
1213+
pub fn random_seed(mut self) -> Self {
1214+
self.seed = Some(rand::random());
1215+
self
1216+
}
1217+
1218+
/// Sets the standard deviation
1219+
///
1220+
/// If the standard deviation is not set, 0Mbps will be used.
11101221
pub fn std_dev(mut self, std_dev: Bandwidth) -> Self {
11111222
self.std_dev = Some(std_dev);
11121223
self
@@ -1122,7 +1233,58 @@ impl SawtoothBwConfig {
11221233
self
11231234
}
11241235

1236+
/// Creates a new [`SawtoothBw`] corresponding to this config.
1237+
///
1238+
/// The created model will use [`StdRng`] as source of randomness (the call is equivalent to `self.build_with_rng::<StdRng>()`).
1239+
/// It should be sufficient for most cases, but [`StdRng`] is not a portable random number generator,
1240+
/// so one may want to use a portable random number generator like [`ChaCha`](https://crates.io/crates/rand_chacha),
1241+
/// to this end one can use [`build_with_rng`](Self::build_with_rng).
11251242
pub fn build(self) -> SawtoothBw {
1243+
self.build_with_rng()
1244+
}
1245+
1246+
/// Creates a new [`SawtoothBw`] corresponding to this config.
1247+
///
1248+
/// Unlike [`build`](Self::build), this method let you choose the random generator.
1249+
///
1250+
/// # Example
1251+
/// ```rust
1252+
/// # use netem_trace::model::SawtoothBwConfig;
1253+
/// # use netem_trace::{Bandwidth, BwTrace};
1254+
/// # use std::time::Duration;
1255+
/// # use rand::rngs::StdRng;
1256+
/// # use rand_chacha::ChaCha20Rng;
1257+
///
1258+
/// let sawtooth_bw = SawtoothBwConfig::new()
1259+
/// .bottom(Bandwidth::from_mbps(12))
1260+
/// .top(Bandwidth::from_mbps(16))
1261+
/// .std_dev(Bandwidth::from_mbps(1))
1262+
/// .duration(Duration::from_millis(3))
1263+
/// .interval(Duration::from_millis(5))
1264+
/// .duty_ratio(0.8)
1265+
/// .seed(42);
1266+
///
1267+
/// let mut default_build = sawtooth_bw.clone().build();
1268+
/// let mut std_build = sawtooth_bw.clone().build_with_rng::<StdRng>();
1269+
/// // ChaCha is deterministic and portable, unlike StdRng
1270+
/// let mut chacha_build = sawtooth_bw.clone().build_with_rng::<ChaCha20Rng>();
1271+
///
1272+
/// for cha in [12044676, 12754367, 13253775] {
1273+
/// let default = default_build.next_bw();
1274+
/// let std = std_build.next_bw();
1275+
/// let chacha = chacha_build.next_bw();
1276+
///
1277+
/// assert!(default.is_some());
1278+
/// assert_eq!(default, std);
1279+
/// assert_ne!(default, chacha);
1280+
/// assert_eq!(chacha, Some((Bandwidth::from_bps(cha), Duration::from_millis(1))));
1281+
/// }
1282+
///
1283+
/// assert_eq!(default_build.next_bw(), None);
1284+
/// assert_eq!(std_build.next_bw(), None);
1285+
/// assert_eq!(chacha_build.next_bw(), None);
1286+
/// ```
1287+
pub fn build_with_rng<Rng: RngCore + SeedableRng>(self) -> SawtoothBw<Rng> {
11261288
let bottom = self.bottom.unwrap_or_else(|| Bandwidth::from_mbps(0));
11271289
let top = self.top.unwrap_or_else(|| Bandwidth::from_mbps(12));
11281290
if bottom > top {
@@ -1133,7 +1295,7 @@ impl SawtoothBwConfig {
11331295
let duration = self.duration.unwrap_or_else(|| Duration::from_secs(1));
11341296
let step = self.step.unwrap_or_else(|| Duration::from_millis(1));
11351297
let seed = self.seed.unwrap_or(DEFAULT_RNG_SEED);
1136-
let rng = StdRng::seed_from_u64(seed);
1298+
let rng = Rng::seed_from_u64(seed);
11371299
let std_dev = self.std_dev.unwrap_or_else(|| Bandwidth::from_mbps(0));
11381300
let upper_noise_bound = self.upper_noise_bound;
11391301
let lower_noise_bound = self.lower_noise_bound;

0 commit comments

Comments
 (0)