Skip to content

Commit aca31a9

Browse files
authored
Merge pull request #114 from pavi2707/main
Add model testing results to pytest xml report
2 parents 45a15f7 + 82e7675 commit aca31a9

2 files changed

Lines changed: 12 additions & 2 deletions

File tree

tests/models/test_decoders.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,12 @@ def persistent_model():
473473
"model_path,batch_size,seq_length,max_new_tokens", common_shapes
474474
)
475475
def test_common_shapes(
476-
model_path, batch_size, seq_length, max_new_tokens, persistent_model
476+
model_path,
477+
batch_size,
478+
seq_length,
479+
max_new_tokens,
480+
persistent_model,
481+
record_property,
477482
):
478483
torch.manual_seed(42)
479484
torch.set_grad_enabled(False)
@@ -735,6 +740,9 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
735740
ce_failure_rate = len(ce_fail_responses_list) / total_tokens
736741
dprint(f"mean diff failure rate: {diff_failure_rate}")
737742
dprint(f"cross entropy loss failure rate: {ce_failure_rate}")
743+
# Add failure rates to xml report
744+
record_property("mean_diff_failure_rate", diff_failure_rate)
745+
record_property("cross_entropy_loss_failure_rate", ce_failure_rate)
738746
if "mean_diff" not in skip_assertions:
739747
assert diff_failure_rate < failure_rate_threshold, (
740748
f"failure rate for mean diff was too high: {diff_failure_rate}"

tests/models/test_encoders.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def reset_compiler():
104104

105105

106106
@pytest.mark.parametrize("model_path,batch_size,seq_length", common_shapes)
107-
def test_common_shapes(model_path, batch_size, seq_length):
107+
def test_common_shapes(model_path, batch_size, seq_length, record_property):
108108
os.environ["COMPILATION_MODE"] = "offline"
109109

110110
dprint(
@@ -187,5 +187,7 @@ def test_common_shapes(model_path, batch_size, seq_length):
187187

188188
abs_mean_diff = sum(diffs) / len(diffs)
189189
print(f"absolute mean diff: {abs_mean_diff}")
190+
# Add value to xml report
191+
record_property("absolute_mean_diff", float(abs_mean_diff))
190192

191193
assert abs_mean_diff < validation_diff_threshold

0 commit comments

Comments
 (0)