Skip to content

Commit f83cff6

Browse files
authored
fix: validate model optimization options (#18)
1 parent b3bc7f4 commit f83cff6

9 files changed

Lines changed: 81 additions & 0 deletions

lib/irt_ruby.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
require "irt_ruby/version"
44
require "matrix"
5+
require "irt_ruby/model_options_validator"
56
require "irt_ruby/response_data_validator"
67
require "irt_ruby/rasch_model"
78
require "irt_ruby/two_parameter_model"
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# frozen_string_literal: true
2+
3+
module IrtRuby
4+
# Validates optimization hyperparameters shared by IRT model implementations.
5+
module ModelOptionsValidator
6+
module_function
7+
8+
def validate!(max_iter:, tolerance:, param_tolerance:, learning_rate:, decay_factor:)
9+
validate_positive_integer!(:max_iter, max_iter)
10+
validate_positive_finite_numeric!(:tolerance, tolerance)
11+
validate_positive_finite_numeric!(:param_tolerance, param_tolerance)
12+
validate_positive_finite_numeric!(:learning_rate, learning_rate)
13+
validate_decay_factor!(decay_factor)
14+
end
15+
16+
def validate_positive_integer!(name, value)
17+
return if value.is_a?(Integer) && value.positive?
18+
19+
raise ArgumentError, "#{name} must be a positive Integer"
20+
end
21+
22+
def validate_positive_finite_numeric!(name, value)
23+
return if finite_numeric?(value) && value.positive?
24+
25+
raise ArgumentError, "#{name} must be a positive finite Numeric"
26+
end
27+
28+
def validate_decay_factor!(value)
29+
return if finite_numeric?(value) && value.positive? && value < 1
30+
31+
raise ArgumentError, "decay_factor must be a finite Numeric strictly between 0 and 1"
32+
end
33+
34+
def finite_numeric?(value)
35+
value.is_a?(Numeric) && !value.is_a?(Complex) && value.finite?
36+
end
37+
end
38+
end

lib/irt_ruby/rasch_model.rb

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# frozen_string_literal: true
22

3+
require "irt_ruby/model_options_validator"
34
require "irt_ruby/response_data_validator"
45

56
module IrtRuby
@@ -21,6 +22,12 @@ def initialize(data,
2122
# data: A Matrix or array-of-arrays of responses (0/1 or nil for missing).
2223
# missing_strategy: :ignore (skip), :treat_as_incorrect, :treat_as_correct
2324

25+
ModelOptionsValidator.validate!(max_iter: max_iter,
26+
tolerance: tolerance,
27+
param_tolerance: param_tolerance,
28+
learning_rate: learning_rate,
29+
decay_factor: decay_factor)
30+
2431
@data = data
2532
@data_array = ResponseDataValidator.validate!(data)
2633
num_rows = @data_array.size

lib/irt_ruby/three_parameter_model.rb

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# frozen_string_literal: true
22

3+
require "irt_ruby/model_options_validator"
34
require "irt_ruby/response_data_validator"
45

56
module IrtRuby
@@ -20,6 +21,12 @@ def initialize(data,
2021
learning_rate: 0.01,
2122
decay_factor: 0.5,
2223
missing_strategy: :ignore)
24+
ModelOptionsValidator.validate!(max_iter: max_iter,
25+
tolerance: tolerance,
26+
param_tolerance: param_tolerance,
27+
learning_rate: learning_rate,
28+
decay_factor: decay_factor)
29+
2330
@data = data
2431
@data_array = ResponseDataValidator.validate!(data)
2532
num_rows = @data_array.size

lib/irt_ruby/two_parameter_model.rb

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# frozen_string_literal: true
22

3+
require "irt_ruby/model_options_validator"
34
require "irt_ruby/response_data_validator"
45

56
module IrtRuby
@@ -16,6 +17,12 @@ class TwoParameterModel
1617
def initialize(data, max_iter: 1000, tolerance: 1e-6, param_tolerance: 1e-6,
1718
learning_rate: 0.01, decay_factor: 0.5,
1819
missing_strategy: :ignore)
20+
ModelOptionsValidator.validate!(max_iter: max_iter,
21+
tolerance: tolerance,
22+
param_tolerance: param_tolerance,
23+
learning_rate: learning_rate,
24+
decay_factor: decay_factor)
25+
1926
@data = data
2027
@data_array = ResponseDataValidator.validate!(data)
2128
num_rows = @data_array.size

spec/irt_ruby/rasch_model_spec.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
RSpec.describe IrtRuby::RaschModel do
66
it_behaves_like "response data validation"
7+
it_behaves_like "model optimization option validation"
78

89
let(:data_array) do
910
[

spec/irt_ruby/three_parameter_model_spec.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
RSpec.describe IrtRuby::ThreeParameterModel do
66
it_behaves_like "response data validation"
7+
it_behaves_like "model optimization option validation"
78

89
let(:data_array) do
910
[

spec/irt_ruby/two_parameter_model_spec.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
RSpec.describe IrtRuby::TwoParameterModel do
66
it_behaves_like "response data validation"
7+
it_behaves_like "model optimization option validation"
78

89
let(:data_array) do
910
[

spec/spec_helper.rb

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,24 @@
4141
end
4242
end
4343

44+
RSpec.shared_examples "model optimization option validation" do
45+
let(:valid_data) { [[1, 0], [0, 1]] }
46+
47+
{
48+
max_iter: [0, -1, 1.5, "100", nil],
49+
tolerance: [0, -1e-6, Float::INFINITY, -Float::INFINITY, Float::NAN, Complex(1, 0), "1e-6", nil],
50+
param_tolerance: [0, -1e-6, Float::INFINITY, -Float::INFINITY, Float::NAN, Complex(1, 0), "1e-6", nil],
51+
learning_rate: [0, -0.01, Float::INFINITY, -Float::INFINITY, Float::NAN, Complex(0.01, 0), "0.01", nil],
52+
decay_factor: [0, 1, -0.1, 1.1, Float::INFINITY, -Float::INFINITY, Float::NAN, Complex(0.5, 0), "0.5", nil]
53+
}.each do |option, invalid_values|
54+
invalid_values.each do |value|
55+
it "rejects #{option}=#{value.inspect}" do
56+
expect { described_class.new(valid_data, option => value) }.to raise_error(ArgumentError, /\A#{option} /)
57+
end
58+
end
59+
end
60+
end
61+
4462
RSpec.configure do |config|
4563
# Enable flags like --only-failures and --next-failure
4664
config.example_status_persistence_file_path = ".rspec_status"

0 commit comments

Comments
 (0)