Skip to content

Commit 8275477

Browse files
committed
implement __repr__ and more testing
1 parent 076c1b5 commit 8275477

2 files changed

Lines changed: 95 additions & 0 deletions

File tree

open-codegen/opengen/templates/python/python_bindings.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
///
44
use optimization_engine::alm::*;
55

6+
use pyo3::class::basic::PyObjectProtocol;
67
use pyo3::prelude::*;
78
use pyo3::wrap_pyfunction;
89

@@ -114,6 +115,24 @@ impl From<SolverStatusData> for SolverStatus {
114115
}
115116
}
116117

118+
#[pyproto]
119+
impl PyObjectProtocol for SolverStatus {
120+
fn __repr__(&self) -> PyResult<String> {
121+
Ok(format!(
122+
"SolverStatus(exit_status={:?}, num_outer_iterations={}, num_inner_iterations={}, last_problem_norm_fpr={}, f1_infeasibility={}, f2_norm={}, solve_time_ms={}, penalty={}, cost={})",
123+
self.exit_status,
124+
self.num_outer_iterations,
125+
self.num_inner_iterations,
126+
self.last_problem_norm_fpr,
127+
self.f1_infeasibility,
128+
self.f2_norm,
129+
self.solve_time_ms,
130+
self.penalty,
131+
self.cost
132+
))
133+
}
134+
}
135+
117136
#[pyclass]
118137
struct SolverError {
119138
#[pyo3(get)]
@@ -131,11 +150,41 @@ impl From<SolverErrorData> for SolverError {
131150
}
132151
}
133152

153+
#[pyproto]
154+
impl PyObjectProtocol for SolverError {
155+
fn __repr__(&self) -> PyResult<String> {
156+
Ok(format!(
157+
"SolverError(code={}, message={:?})",
158+
self.code,
159+
self.message
160+
))
161+
}
162+
}
163+
134164
#[pyclass]
135165
struct SolverResponse {
136166
payload: SolverResponsePayload,
137167
}
138168

169+
#[pyproto]
170+
impl PyObjectProtocol for SolverResponse {
171+
fn __repr__(&self) -> PyResult<String> {
172+
match &self.payload {
173+
SolverResponsePayload::Ok(status) => Ok(format!(
174+
"SolverResponse(ok=True, exit_status={:?}, num_outer_iterations={}, num_inner_iterations={})",
175+
status.exit_status,
176+
status.num_outer_iterations,
177+
status.num_inner_iterations
178+
)),
179+
SolverResponsePayload::Err(error) => Ok(format!(
180+
"SolverResponse(ok=False, code={}, message={:?})",
181+
error.code,
182+
error.message
183+
)),
184+
}
185+
}
186+
}
187+
139188
#[pymethods]
140189
impl SolverResponse {
141190
fn is_ok(&self) -> bool {

open-codegen/test/test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,34 @@ def test_python_bindings_error_details(self):
409409
error.message
410410
)
411411

412+
def test_python_bindings_initial_guess_error_details(self):
413+
python_bindings = RustBuildTestCase.import_generated_module(
414+
"python_bindings", "python_bindings")
415+
416+
solver = python_bindings.solver()
417+
result = solver.run([1., 2.], initial_guess=[0.0])
418+
self.assertFalse(result.is_ok())
419+
error = result.get()
420+
self.assertEqual(1600, error.code)
421+
self.assertEqual(
422+
"initial guess has incompatible dimensions: provided 1, expected 5",
423+
error.message
424+
)
425+
426+
def test_python_bindings_initial_lagrange_multipliers_error_details(self):
427+
python_bindings = RustBuildTestCase.import_generated_module(
428+
"python_bindings", "python_bindings")
429+
430+
solver = python_bindings.solver()
431+
result = solver.run([1., 2.], initial_lagrange_multipliers=[0.1])
432+
self.assertFalse(result.is_ok())
433+
error = result.get()
434+
self.assertEqual(1700, error.code)
435+
self.assertEqual(
436+
"wrong dimension of Langrange multipliers: provided 1, expected 0",
437+
error.message
438+
)
439+
412440
def test_python_bindings_solver_error_details(self):
413441
python_bindings = RustBuildTestCase.import_generated_module(
414442
"python_bindings", "python_bindings")
@@ -423,6 +451,24 @@ def test_python_bindings_solver_error_details(self):
423451
error.message
424452
)
425453

454+
def test_python_bindings_repr(self):
455+
python_bindings = RustBuildTestCase.import_generated_module(
456+
"python_bindings", "python_bindings")
457+
458+
solver = python_bindings.solver()
459+
460+
ok_response = solver.run([1., 2.])
461+
self.assertIn("SolverResponse(ok=True", repr(ok_response))
462+
ok_status = ok_response.get()
463+
self.assertIn("SolverStatus(", repr(ok_status))
464+
self.assertIn('exit_status="Converged"', repr(ok_status))
465+
466+
error_response = solver.run([1., 2., 3.])
467+
self.assertIn("SolverResponse(ok=False", repr(error_response))
468+
error = error_response.get()
469+
self.assertIn("SolverError(", repr(error))
470+
self.assertIn("code=3003", repr(error))
471+
426472
def test_rectangle_empty(self):
427473
xmin = [-1, 2]
428474
xmax = [-2, 4]

0 commit comments

Comments
 (0)