Skip to content

Commit cce8506

Browse files
authored
feat: add seeded model initialization (#21)
1 parent f775f11 commit cce8506

8 files changed

Lines changed: 89 additions & 12 deletions

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,20 @@ IrtRuby::TwoParameterModel.new(
123123
decay_factor: 0.5
124124
)
125125
```
126+
127+
### Reproducible Initialization
128+
Each model initializes its parameters randomly. By default, constructors draw from Ruby's global random number generator, which preserves the historical behavior and honors any external `srand` calls. To get reproducible model initialization without resetting or consuming the global RNG state, pass the `seed:` keyword:
129+
130+
```ruby
131+
model_a = IrtRuby::ThreeParameterModel.new(data, seed: 1234)
132+
model_b = IrtRuby::ThreeParameterModel.new(data, seed: 1234)
133+
134+
# Same data, options, and seed produce identical fitted results.
135+
model_a.fit == model_b.fit #=> true
136+
```
137+
138+
The `seed:` keyword is available for `RaschModel`, `TwoParameterModel`, and `ThreeParameterModel`.
139+
126140
### Parameter Clamping
127141
For 2PL and 3PL:
128142

lib/irt_ruby/rasch_model.rb

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ def initialize(data,
1818
param_tolerance: 1e-6,
1919
learning_rate: 0.01,
2020
decay_factor: 0.5,
21-
missing_strategy: :ignore)
21+
missing_strategy: :ignore,
22+
seed: nil)
2223
# data: A Matrix or array-of-arrays of responses (0/1 or nil for missing).
2324
# missing_strategy: :ignore (skip), :treat_as_incorrect, :treat_as_correct
2425

@@ -36,10 +37,11 @@ def initialize(data,
3637
raise ArgumentError, "missing_strategy must be one of #{MISSING_STRATEGIES}" unless MISSING_STRATEGIES.include?(missing_strategy)
3738

3839
@missing_strategy = missing_strategy
40+
@random = seed.nil? ? nil : Random.new(seed)
3941

4042
# Initialize parameters near zero
41-
@abilities = Array.new(num_rows) { rand(-0.25..0.25) }
42-
@difficulties = Array.new(num_cols) { rand(-0.25..0.25) }
43+
@abilities = Array.new(num_rows) { random_between(-0.25..0.25) }
44+
@difficulties = Array.new(num_cols) { random_between(-0.25..0.25) }
4345

4446
@max_iter = max_iter
4547
@tolerance = tolerance
@@ -52,6 +54,11 @@ def sigmoid(x)
5254
1.0 / (1.0 + Math.exp(-x))
5355
end
5456

57+
def random_between(range)
58+
@random ? @random.rand(range) : rand(range)
59+
end
60+
private :random_between
61+
5562
def resolve_missing(resp)
5663
return [resp, false] unless resp.nil?
5764

lib/irt_ruby/three_parameter_model.rb

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def initialize(data,
2020
param_tolerance: 1e-6,
2121
learning_rate: 0.01,
2222
decay_factor: 0.5,
23-
missing_strategy: :ignore)
23+
missing_strategy: :ignore,
24+
seed: nil)
2425
ModelOptionsValidator.validate!(max_iter: max_iter,
2526
tolerance: tolerance,
2627
param_tolerance: param_tolerance,
@@ -35,12 +36,13 @@ def initialize(data,
3536
raise ArgumentError, "missing_strategy must be one of #{MISSING_STRATEGIES}" unless MISSING_STRATEGIES.include?(missing_strategy)
3637

3738
@missing_strategy = missing_strategy
39+
@random = seed.nil? ? nil : Random.new(seed)
3840

3941
# Initialize parameters
40-
@abilities = Array.new(num_rows) { rand(-0.25..0.25) }
41-
@difficulties = Array.new(num_cols) { rand(-0.25..0.25) }
42-
@discriminations = Array.new(num_cols) { rand(0.5..1.5) }
43-
@guessings = Array.new(num_cols) { rand(0.0..0.3) }
42+
@abilities = Array.new(num_rows) { random_between(-0.25..0.25) }
43+
@difficulties = Array.new(num_cols) { random_between(-0.25..0.25) }
44+
@discriminations = Array.new(num_cols) { random_between(0.5..1.5) }
45+
@guessings = Array.new(num_cols) { random_between(0.0..0.3) }
4446

4547
@max_iter = max_iter
4648
@tolerance = tolerance
@@ -53,6 +55,11 @@ def sigmoid(x)
5355
1.0 / (1.0 + Math.exp(-x))
5456
end
5557

58+
def random_between(range)
59+
@random ? @random.rand(range) : rand(range)
60+
end
61+
private :random_between
62+
5663
# Probability for the 3PL model: c + (1-c)*sigmoid(a*(θ - b))
5764
def probability(theta, a, b, c)
5865
c + ((1.0 - c) * sigmoid(a * (theta - b)))

lib/irt_ruby/two_parameter_model.rb

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class TwoParameterModel
1616

1717
def initialize(data, max_iter: 1000, tolerance: 1e-6, param_tolerance: 1e-6,
1818
learning_rate: 0.01, decay_factor: 0.5,
19-
missing_strategy: :ignore)
19+
missing_strategy: :ignore, seed: nil)
2020
ModelOptionsValidator.validate!(max_iter: max_iter,
2121
tolerance: tolerance,
2222
param_tolerance: param_tolerance,
@@ -31,12 +31,13 @@ def initialize(data, max_iter: 1000, tolerance: 1e-6, param_tolerance: 1e-6,
3131
raise ArgumentError, "missing_strategy must be one of #{MISSING_STRATEGIES}" unless MISSING_STRATEGIES.include?(missing_strategy)
3232

3333
@missing_strategy = missing_strategy
34+
@random = seed.nil? ? nil : Random.new(seed)
3435

3536
# Initialize parameters
3637
# Typically: ability ~ 0, difficulty ~ 0, discrimination ~ 1
37-
@abilities = Array.new(num_rows) { rand(-0.25..0.25) }
38-
@difficulties = Array.new(num_cols) { rand(-0.25..0.25) }
39-
@discriminations = Array.new(num_cols) { rand(0.5..1.5) }
38+
@abilities = Array.new(num_rows) { random_between(-0.25..0.25) }
39+
@difficulties = Array.new(num_cols) { random_between(-0.25..0.25) }
40+
@discriminations = Array.new(num_cols) { random_between(0.5..1.5) }
4041

4142
@max_iter = max_iter
4243
@tolerance = tolerance
@@ -49,6 +50,11 @@ def sigmoid(x)
4950
1.0 / (1.0 + Math.exp(-x))
5051
end
5152

53+
def random_between(range)
54+
@random ? @random.rand(range) : rand(range)
55+
end
56+
private :random_between
57+
5258
def resolve_missing(resp)
5359
return [resp, false] unless resp.nil?
5460

spec/irt_ruby/rasch_model_spec.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
RSpec.describe IrtRuby::RaschModel do
66
it_behaves_like "response data validation"
77
it_behaves_like "model optimization option validation"
8+
it_behaves_like "seeded model initialization", %i[abilities difficulties]
89

910
let(:data_array) do
1011
[

spec/irt_ruby/three_parameter_model_spec.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
RSpec.describe IrtRuby::ThreeParameterModel do
66
it_behaves_like "response data validation"
77
it_behaves_like "model optimization option validation"
8+
it_behaves_like "seeded model initialization", %i[abilities difficulties discriminations guessings]
89

910
let(:data_array) do
1011
[

spec/irt_ruby/two_parameter_model_spec.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
RSpec.describe IrtRuby::TwoParameterModel do
66
it_behaves_like "response data validation"
77
it_behaves_like "model optimization option validation"
8+
it_behaves_like "seeded model initialization", %i[abilities difficulties discriminations]
89

910
let(:data_array) do
1011
[

spec/spec_helper.rb

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,46 @@
5959
end
6060
end
6161

62+
RSpec.shared_examples "seeded model initialization" do |parameter_names|
63+
let(:seeded_fit_options) { { max_iter: 50, learning_rate: 0.05 } }
64+
65+
def seeded_parameter_snapshot(model, parameter_names)
66+
parameter_names.to_h do |parameter_name|
67+
[parameter_name, model.instance_variable_get("@#{parameter_name}").dup]
68+
end
69+
end
70+
71+
it "produces identical initial and fitted parameters with the same seed" do
72+
model1 = described_class.new(data_array, **seeded_fit_options, seed: 12_345)
73+
model2 = described_class.new(data_array, **seeded_fit_options, seed: 12_345)
74+
75+
expect(seeded_parameter_snapshot(model1, parameter_names)).to eq(
76+
seeded_parameter_snapshot(model2, parameter_names)
77+
)
78+
expect(model1.fit).to eq(model2.fit)
79+
end
80+
81+
it "produces different initial parameters with different seeds" do
82+
model1 = described_class.new(data_array, **seeded_fit_options, seed: 12_345)
83+
model2 = described_class.new(data_array, **seeded_fit_options, seed: 54_321)
84+
85+
expect(seeded_parameter_snapshot(model1, parameter_names)).not_to eq(
86+
seeded_parameter_snapshot(model2, parameter_names)
87+
)
88+
end
89+
90+
it "does not reset or consume Ruby's global random number generator" do
91+
srand(98_765)
92+
expected_values = Array.new(5) { rand }
93+
94+
srand(98_765)
95+
described_class.new(data_array, **seeded_fit_options, seed: 12_345)
96+
actual_values = Array.new(5) { rand }
97+
98+
expect(actual_values).to eq(expected_values)
99+
end
100+
end
101+
62102
RSpec.configure do |config|
63103
# Enable flags like --only-failures and --next-failure
64104
config.example_status_persistence_file_path = ".rspec_status"

0 commit comments

Comments
 (0)