# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Integration tests for BenchmarkEvaluator"""
from __future__ import absolute_import
import pytest
import logging
from sagemaker.train.evaluate import (
BenchMarkEvaluator,
get_benchmarks,
get_benchmark_properties,
EvaluationPipelineExecution,
)
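
# Typical flow exercised by these tests (orientation sketch only; see the
# individual tests below for the exact calls and arguments):
#   Benchmark = get_benchmarks()                  # enum of available benchmarks
#   evaluator = BenchMarkEvaluator(
#       benchmark=Benchmark.MMLU, model=..., s3_output_path=...,
#   )
#   execution = evaluator.evaluate()              # returns a pipeline execution
#   execution.wait(target_status="Succeeded")     # block until done
#   execution.show_results()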

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Test timeout configuration (in seconds)
EVALUATION_TIMEOUT_SECONDS = 14400  # 4 hours
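# At the 30-second poll interval the tests pass to wait(), this budget allows
# up to 480 status checks (14400 / 30) before a wait times out.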

# Test configuration values from benchmark_demo.ipynb
# TEST_CONFIG = {
#     "model_package_arn": "arn:aws:sagemaker:us-west-2:052150106756:model-package/test-finetuned-models-gamma/28",
#     "dataset_s3_uri": "s3://sagemaker-us-west-2-052150106756/studio-users/d20251107t195443/datasets/2025-11-07T19-55-37-609Z/zc_test.jsonl",
#     "s3_output_path": "s3://mufi-test-serverless-smtj/eval/",
#     "mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:052150106756:mlflow-tracking-server/mmlu-eval-experiment",
#     "model_package_group_arn": "arn:aws:sagemaker:us-west-2:052150106756:model-package-group/example-name-aovqo",
#     "region": "us-west-2",
# }
TEST_CONFIG = {
    "model_package_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/1",
    "dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/zc_test.jsonl",
    "s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/",
    "mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
    "model_package_group_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models",
    "region": "us-west-2",
}
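
# The ARNs, S3 URIs, and region above are account-specific fixtures: the model
# package, model package group, dataset, and MLflow app must already exist in
# us-west-2 for these tests to pass.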

# Base model only evaluation configuration (from commented section in notebook)
BASE_MODEL_ONLY_CONFIG = {
    "base_model_id": "meta-textgeneration-llama-3-2-1b-instruct",
    "dataset_s3_uri": "s3://sagemaker-us-west-2-052150106756/studio-users/d20251107t195443/datasets/2025-11-07T19-55-37-609Z/zc_test.jsonl",
    "s3_output_path": "s3://mufi-test-serverless-smtj/eval/",
    "mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:052150106756:mlflow-tracking-server/mmlu-eval-experiment",
    "region": "us-west-2",
}

# Nova model evaluation configuration (from commented section in notebook)
NOVA_CONFIG = {
    "model_package_arn": "arn:aws:sagemaker:us-east-1:052150106756:model-package/test-nova-finetuned-models/3",
    "dataset_s3_uri": "s3://sagemaker-us-east-1-052150106756/studio-users/d20251107t195443/datasets/2025-11-07T19-55-37-609Z/zc_test.jsonl",
    "s3_output_path": "s3://mufi-test-serverless-iad/eval/",
    "mlflow_tracking_server_arn": "arn:aws:sagemaker:us-east-1:052150106756:mlflow-tracking-server/mlflow-prod-server",
    "model_package_group_arn": "arn:aws:sagemaker:us-east-1:052150106756:model-package-group/test-nova-finetuned-models",
    "region": "us-east-1",
}
@pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
class TestBenchmarkEvaluatorIntegration:
"""Integration tests for BenchmarkEvaluator with fine-tuned model package"""
def test_get_benchmarks_and_properties(self):
"""Test getting available benchmarks and their properties"""
# Get available benchmarks
Benchmark = get_benchmarks()
# Verify it's an enum
assert hasattr(Benchmark, "__members__")
# Verify MMLU is available
assert hasattr(Benchmark, "MMLU")
# Get properties for MMLU benchmark
properties = get_benchmark_properties(benchmark=Benchmark.MMLU)
# Verify properties structure
assert isinstance(properties, dict)
assert "modality" in properties
assert "description" in properties
assert "metrics" in properties
assert "strategy" in properties
logger.info(f"MMLU properties: {properties}")
def test_benchmark_evaluation_full_flow(self):
"""
Test complete benchmark evaluation flow with fine-tuned model package.
This test mirrors the flow from benchmark_demo.ipynb and covers:
1. Creating BenchMarkEvaluator with MMLU benchmark
2. Accessing hyperparameters
3. Starting evaluation
4. Monitoring execution
5. Waiting for completion
6. Viewing results
7. Retrieving execution by ARN
8. Listing all evaluations
Test configuration values are taken directly from the notebook example.
"""
# Get benchmarks
Benchmark = get_benchmarks()
# Step 1: Create BenchmarkEvaluator
logger.info("Creating BenchmarkEvaluator with MMLU benchmark")
# Create evaluator (matching notebook configuration)
evaluator = BenchMarkEvaluator(
benchmark=Benchmark.MMLU,
model=TEST_CONFIG["model_package_arn"],
s3_output_path=TEST_CONFIG["s3_output_path"],
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
model_package_group=TEST_CONFIG["model_package_group_arn"],
base_eval_name="integ-test-gen-qa-eval",
)
# Verify evaluator was created
assert evaluator is not None
assert evaluator.benchmark == Benchmark.MMLU
assert evaluator.model == TEST_CONFIG["model_package_arn"]
logger.info(f"Created evaluator: {evaluator.base_eval_name}")
# Step 2: Access hyperparameters
logger.info("Accessing hyperparameters")
hyperparams = evaluator.hyperparameters.to_dict()
# Verify hyperparameters structure
assert isinstance(hyperparams, dict)
assert "max_new_tokens" in hyperparams
assert "temperature" in hyperparams
logger.info(f"Hyperparameters: {hyperparams}")

        # Step 3: Start evaluation
        logger.info("Starting evaluation execution")
        execution = evaluator.evaluate()

        # Verify execution was created
        assert execution is not None
        assert execution.arn is not None
        assert execution.name is not None
        assert execution.eval_type is not None
        logger.info(f"Pipeline Execution ARN: {execution.arn}")
        logger.info(f"Initial Status: {execution.status.overall_status}")

        # Step 4: Monitor execution
        logger.info("Refreshing execution status")
        execution.refresh()

        # Verify status was updated
        assert execution.status.overall_status is not None

        # Log step details if available
        if execution.status.step_details:
            logger.info("Step Details:")
            for step in execution.status.step_details:
                logger.info(f"  {step.name}: {step.status}")

        # Step 5: Wait for completion
        logger.info(f"Waiting for evaluation to complete (timeout: {EVALUATION_TIMEOUT_SECONDS}s / {EVALUATION_TIMEOUT_SECONDS // 3600}h)")
        try:
            execution.wait(target_status="Succeeded", poll=30, timeout=EVALUATION_TIMEOUT_SECONDS)
            logger.info(f"Final Status: {execution.status.overall_status}")

            # Verify completion
            assert execution.status.overall_status == "Succeeded"

            # Step 6: View results
            logger.info("Displaying results")
            execution.show_results()

            # Verify S3 output path is set
            assert execution.s3_output_path is not None
            logger.info(f"Results stored at: {execution.s3_output_path}")
        except Exception as e:
            logger.error(f"Evaluation failed or timed out: {e}")
            logger.error(f"Final status: {execution.status.overall_status}")
            if execution.status.failure_reason:
                logger.error(f"Failure reason: {execution.status.failure_reason}")
            # Log step failures
            if execution.status.step_details:
                for step in execution.status.step_details:
                    if "failed" in step.status.lower():
                        logger.error(f"Failed step: {step.name}")
                        if step.failure_reason:
                            logger.error(f"  Reason: {step.failure_reason}")
            # Re-raise to fail the test
            raise

        # Step 7: Retrieve execution by ARN
        logger.info("Retrieving execution by ARN")
        retrieved_execution = EvaluationPipelineExecution.get(
            arn=execution.arn,
            region=TEST_CONFIG["region"],
        )

        # Verify retrieved execution matches
        assert retrieved_execution.arn == execution.arn
        logger.info(f"Retrieved execution status: {retrieved_execution.status.overall_status}")

        # Step 8: List all benchmark evaluations
        logger.info("Listing all benchmark evaluations")
        all_executions_iter = BenchMarkEvaluator.get_all(region=TEST_CONFIG["region"])
        all_executions = list(all_executions_iter)
        if all_executions:
            # Verify our execution is in the list
            execution_arns = [item.arn for item in all_executions]
            assert execution.arn in execution_arns

        logger.info("Integration test completed successfully")

    def test_benchmark_evaluator_validation(self):
        """Test BenchmarkEvaluator validation of inputs"""
        Benchmark = get_benchmarks()

        # Test invalid benchmark type
        with pytest.raises(ValueError):
            BenchMarkEvaluator(
                benchmark="invalid_benchmark",
                model=TEST_CONFIG["model_package_arn"],
                s3_output_path=TEST_CONFIG["s3_output_path"],
                # mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
            )

        # Test invalid MLflow ARN format
        with pytest.raises(ValueError, match="Invalid MLFlow resource ARN"):
            BenchMarkEvaluator(
                benchmark=Benchmark.MMLU,
                model=TEST_CONFIG["model_package_arn"],
                s3_output_path=TEST_CONFIG["s3_output_path"],
                mlflow_resource_arn="invalid-arn",
            )

        logger.info("Validation tests passed")

    def test_benchmark_subtasks_validation(self):
        """Test benchmark subtask validation"""
        Benchmark = get_benchmarks()

        # Test valid subtask for MMLU (has subtask support)
        evaluator = BenchMarkEvaluator(
            benchmark=Benchmark.MMLU,
            model=TEST_CONFIG["model_package_arn"],
            s3_output_path=TEST_CONFIG["s3_output_path"],
            # mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
            subtasks="abstract_algebra",
            model_package_group="arn:aws:sagemaker:us-west-2:123456789012:model-package-group/test",
        )
        assert evaluator.subtasks == "abstract_algebra"
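        # subtasks is passed as a single string here and as a list below, so
        # the constructor appears to accept either form.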

        # Test an invalid subtask name for MMLU
        with pytest.raises(ValueError, match="Invalid subtask 'invalid' for benchmark 'mmlu'"):
            BenchMarkEvaluator(
                benchmark=Benchmark.MMLU,
                model=TEST_CONFIG["model_package_arn"],
                s3_output_path=TEST_CONFIG["s3_output_path"],
                # mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
                subtasks=["invalid"],
                model_package_group="arn:aws:sagemaker:us-west-2:123456789012:model-package-group/test",
            )

        logger.info("Subtask validation tests passed")
@pytest.mark.skip(reason="Base model only evaluation - to be enabled when needed")
def test_benchmark_evaluation_base_model_only(self):
"""
Test benchmark evaluation with base model only (no fine-tuned model).
This test uses a JumpStart model ID directly instead of a model package ARN.
Configuration from commented section in benchmark_demo.ipynb.
Note: This test is currently skipped. Remove the @pytest.mark.skip decorator
when you want to enable it.
"""
# Get benchmarks
Benchmark = get_benchmarks()
logger.info("Creating BenchmarkEvaluator with base model only (JumpStart model ID)")
# Create evaluator with JumpStart model ID (no model package)
evaluator = BenchMarkEvaluator(
benchmark=Benchmark.MMLU,
model=BASE_MODEL_ONLY_CONFIG["base_model_id"],
s3_output_path=BASE_MODEL_ONLY_CONFIG["s3_output_path"],
mlflow_resource_arn=BASE_MODEL_ONLY_CONFIG["mlflow_tracking_server_arn"],
base_eval_name="integ-test-base-model-only",
# Note: model_package_group not needed for JumpStart models
)
# Verify evaluator was created
assert evaluator is not None
assert evaluator.benchmark == Benchmark.MMLU
assert evaluator.model == BASE_MODEL_ONLY_CONFIG["base_model_id"]
logger.info(f"Created evaluator: {evaluator.base_eval_name}")
# Start evaluation
logger.info("Starting evaluation execution")
execution = evaluator.evaluate()
# Verify execution was created
assert execution is not None
assert execution.arn is not None
assert execution.name is not None
logger.info(f"Pipeline Execution ARN: {execution.arn}")
logger.info(f"Initial Status: {execution.status.overall_status}")
# Wait for completion
logger.info(f"Waiting for evaluation to complete (timeout: {EVALUATION_TIMEOUT_SECONDS}s / {EVALUATION_TIMEOUT_SECONDS//3600}h)")
execution.wait(target_status="Succeeded", poll=30, timeout=EVALUATION_TIMEOUT_SECONDS)
# Verify completion
assert execution.status.overall_status == "Succeeded"
logger.info("Base model only evaluation completed successfully")
@pytest.mark.skip(reason="Nova model evaluation - to be enabled when needed")
def test_benchmark_evaluation_nova_model(self):
"""
Test benchmark evaluation with Nova model.
This test uses a Nova fine-tuned model package in us-east-1 region.
Configuration from commented section in benchmark_demo.ipynb.
Note: This test is currently skipped. Remove the @pytest.mark.skip decorator
when you want to enable it.
"""
# Get benchmarks
Benchmark = get_benchmarks()
logger.info("Creating BenchmarkEvaluator with Nova model")
# Create evaluator with Nova model package
evaluator = BenchMarkEvaluator(
benchmark=Benchmark.MMLU,
model=NOVA_CONFIG["model_package_arn"],
s3_output_path=NOVA_CONFIG["s3_output_path"],
mlflow_resource_arn=NOVA_CONFIG["mlflow_tracking_server_arn"],
model_package_group=NOVA_CONFIG["model_package_group_arn"],
base_eval_name="integ-test-nova-eval",
region=NOVA_CONFIG["region"],
)
# Verify evaluator was created
assert evaluator is not None
assert evaluator.benchmark == Benchmark.MMLU
assert evaluator.model == NOVA_CONFIG["model_package_arn"]
assert evaluator.region == NOVA_CONFIG["region"]
logger.info(f"Created evaluator: {evaluator.base_eval_name}")

        # Access hyperparameters (Nova models may have different hyperparameters)
        logger.info("Accessing hyperparameters")
        hyperparams = evaluator.hyperparameters.to_dict()

        # Verify hyperparameters structure
        assert isinstance(hyperparams, dict)
        logger.info(f"Hyperparameters: {hyperparams}")

        # Start evaluation
        logger.info("Starting evaluation execution")
        execution = evaluator.evaluate()

        # Verify execution was created
        assert execution is not None
        assert execution.arn is not None
        assert execution.name is not None
        logger.info(f"Pipeline Execution ARN: {execution.arn}")
        logger.info(f"Initial Status: {execution.status.overall_status}")

        # Monitor execution
        execution.refresh()
        logger.info(f"Status after refresh: {execution.status.overall_status}")

        # Wait for completion (this test uses a shorter, 1-hour timeout)
        logger.info("Waiting for evaluation to complete (timeout: 1 hour)")
        execution.wait(target_status="Succeeded", poll=30, timeout=3600)

        # Verify completion
        assert execution.status.overall_status == "Succeeded"

        # View results
        logger.info("Displaying results")
        execution.show_results()

        logger.info("Nova model evaluation completed successfully")