forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconftest.py
More file actions
113 lines (93 loc) · 3.33 KB
/
conftest.py
File metadata and controls
113 lines (93 loc) · 3.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Fixtures for AI Registry integration tests."""
import os
import tempfile
import uuid
import zipfile
import pytest
import boto3
from sagemaker.ai_registry.air_utils import _get_default_bucket
from sagemaker.train.defaults import TrainDefaults
@pytest.fixture
def unique_name():
    """Return a random name of the form ``test-<8 hex characters>``."""
    token = uuid.uuid4().hex
    return "test-" + token[:8]
@pytest.fixture
def test_bucket():
    """Provide the S3 bucket name returned by ``_get_default_bucket``."""
    bucket_name = _get_default_bucket()
    return bucket_name
@pytest.fixture
def test_role():
    """Provide the IAM role returned by ``TrainDefaults.get_role``."""
    role = TrainDefaults.get_role()
    return role
@pytest.fixture
def sample_jsonl_file():
    """Create a temporary JSONL dataset file and yield its path.

    The file holds three prompt/completion records, one JSON object per
    line. The handle is closed before the path is handed to the test, so the
    content is fully written and the path can be reopened by name on every
    platform. The file is removed during fixture teardown.

    Yields:
        str: Path of the temporary ``.jsonl`` file.
    """
    content = """{"prompt": "What is AI?", "completion": "AI is artificial intelligence."}
{"prompt": "What is ML?", "completion": "ML is machine learning."}
{"prompt": "What is DL?", "completion": "DL is deep learning."}
"""
    # delete=False: the file has to outlive the ``with`` block; teardown
    # below is responsible for removing it.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        f.write(content)
    path = f.name
    try:
        yield path
    finally:
        try:
            os.unlink(path)
        except FileNotFoundError:
            # The test already removed the file; nothing left to clean up.
            pass
@pytest.fixture
def sample_lambda_code():
    """Create a temporary zip archive with sample Lambda code and yield its path.

    The archive contains a single member, ``lambda_function.py``, defining a
    ``lambda_handler`` that returns a fixed JSON response. The temporary file
    handle is closed before the archive is written, so the path is never
    opened twice at the same time. The archive is removed during fixture
    teardown.

    Yields:
        str: Path of the temporary ``.zip`` file.
    """
    code = '''import json
def lambda_handler(event, context):
    return {"statusCode": 200, "body": json.dumps({"score": 0.8})}
'''
    # Reserve a unique path only; delete=False keeps the (empty) file around
    # after the handle is closed so the zip can be written to it by name.
    with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as zip_f:
        path = zip_f.name
    try:
        with zipfile.ZipFile(path, 'w') as zf:
            zf.writestr('lambda_function.py', code)
        yield path
    finally:
        try:
            os.unlink(path)
        except FileNotFoundError:
            # The test already removed the archive; nothing left to clean up.
            pass
@pytest.fixture
def sample_prompt_file():
    """Create a temporary prompt text file and yield its path.

    The handle is closed before the path is handed to the test, which
    guarantees the buffered content has been written out and lets the test
    reopen the path by name on every platform. The file is removed during
    fixture teardown.

    Yields:
        str: Path of the temporary ``.txt`` file.
    """
    content = "Evaluate the response: {response}"
    # delete=False: the file has to outlive the ``with`` block; teardown
    # below is responsible for removing it.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
        f.write(content)
    path = f.name
    try:
        yield path
    finally:
        try:
            os.unlink(path)
        except FileNotFoundError:
            # The test already removed the file; nothing left to clean up.
            pass
@pytest.fixture
def sample_hub_content_document():
    """Build a ``DataSetHubContentDocument`` and return its ``to_json()`` output.

    The document points at the default bucket under the ``test`` prefix, uses
    the role from ``TrainDefaults.get_role()``, takes its type, conversation
    id and checkpoint id from the ``DATASET_DEFAULT_*`` constants, and has no
    dependencies.
    """
    # Imported lazily, inside the fixture, exactly as the module-level code
    # does not import these names.
    from sagemaker.ai_registry.dataset_utils import DataSetHubContentDocument
    from sagemaker.ai_registry.air_constants import (
        DATASET_DEFAULT_CHECKPOINT_ID,
        DATASET_DEFAULT_CONVERSATION_ID,
        DATASET_DEFAULT_TYPE,
    )

    bucket = _get_default_bucket()
    role_arn = TrainDefaults.get_role()
    fields = {
        "dataset_s3_bucket": bucket,
        "dataset_s3_prefix": "test",
        "dataset_context_s3_uri": "\"\"",
        "dataset_type": DATASET_DEFAULT_TYPE,
        "dataset_role_arn": role_arn,
        "conversation_id": DATASET_DEFAULT_CONVERSATION_ID,
        "conversation_checkpoint_id": DATASET_DEFAULT_CHECKPOINT_ID,
        "dependencies": [],
    }
    return DataSetHubContentDocument(**fields).to_json()
@pytest.fixture
def cleanup_list():
    """Yield a list that tests append resources to for deletion at teardown.

    After the test finishes, ``delete()`` is called on every object in the
    list, in insertion order. Cleanup is best-effort: a failing ``delete()``
    does not stop the remaining deletions and does not fail the test, but it
    is logged as a warning (with traceback) so leaked resources are visible.

    Yields:
        list: Initially empty list; each appended object must expose a
        ``delete()`` method.
    """
    import logging

    resources = []
    yield resources

    logger = logging.getLogger(__name__)
    for resource in resources:
        try:
            resource.delete()
        except Exception:
            # Deliberately broad: one failed deletion must not prevent the
            # rest of the cleanup from running.
            logger.warning(
                "Failed to clean up test resource %r", resource, exc_info=True
            )