Skip to content

Commit e161199

Browse files
Fix: hardcode handler_name = "lambda_function.lambda_handler" to match the zip entry name. (#5692)
* Fix lambda function handler name * Add integ test * Update integ test to wait for lambda call
1 parent ee420cc commit e161199

File tree

4 files changed

+61
-1
lines changed

4 files changed

+61
-1
lines changed

sagemaker-train/src/sagemaker/ai_registry/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def _create_lambda_function(cls, name: str, source_file: str, role: Optional[str
382382
# Create Lambda function
383383
lambda_client = boto3.client("lambda")
384384
function_name = f"SageMaker-evaluator-{name}-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
385-
handler_name = f"{os.path.splitext(os.path.basename(source_file))[0]}.lambda_handler"
385+
handler_name = "lambda_function.lambda_handler"
386386

387387
try:
388388
lambda_response = lambda_client.create_function(

sagemaker-train/tests/integ/ai_registry/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,22 @@ def sample_jsonl_file():
5656
os.unlink(f.name)
5757

5858

59+
@pytest.fixture
60+
def sample_lambda_py_file():
61+
"""Create a raw Python Lambda file with a non-default filename to test handler derivation."""
62+
code = '''import json
63+
def lambda_handler(event, context):
64+
return {"statusCode": 200, "body": json.dumps({"score": 0.9})}
65+
'''
66+
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', prefix='my_custom_evaluator_', delete=False) as f:
67+
f.write(code)
68+
f.flush()
69+
os.fsync(f.fileno())
70+
fname = f.name
71+
yield fname
72+
os.unlink(fname)
73+
74+
5975
@pytest.fixture
6076
def sample_lambda_code():
6177
"""Create sample Lambda function code as zip."""

sagemaker-train/tests/integ/ai_registry/test_evaluator.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,47 @@ def test_create_reward_function_from_local_code(self, unique_name, sample_lambda
8181
assert evaluator.method == EvaluatorMethod.BYOC
8282
assert evaluator.reference is not None
8383

84+
def test_create_reward_function_from_local_py_file_and_invoke(
85+
self, unique_name, sample_lambda_py_file, test_role, cleanup_list
86+
):
87+
"""End-to-end test: create evaluator from a raw .py file with non-default name and invoke it.
88+
89+
Regression test for the handler name bug where the Lambda was created with an incorrect
90+
handler derived from the source filename instead of 'lambda_function.lambda_handler'.
91+
"""
92+
import json
93+
import boto3
94+
95+
evaluator = Evaluator.create(
96+
name=unique_name,
97+
type=REWARD_FUNCTION,
98+
source=sample_lambda_py_file,
99+
role=test_role,
100+
wait=True, # wait for Lambda to be active
101+
)
102+
cleanup_list.append(evaluator)
103+
assert evaluator.method == EvaluatorMethod.BYOC
104+
assert evaluator.reference is not None
105+
106+
# Wait for Lambda to become Active before invoking
107+
lambda_client = boto3.client("lambda")
108+
waiter = lambda_client.get_waiter("function_active_v2")
109+
waiter.wait(FunctionName=evaluator.reference)
110+
111+
# Invoke the Lambda directly to verify the handler is correct
112+
lambda_client = boto3.client("lambda")
113+
response = lambda_client.invoke(
114+
FunctionName=evaluator.reference,
115+
InvocationType="RequestResponse",
116+
Payload=json.dumps({"input": "test"}).encode(),
117+
)
118+
assert response["StatusCode"] == 200
119+
assert "FunctionError" not in response, (
120+
f"Lambda invocation failed with error: {response.get('FunctionError')}"
121+
)
122+
result = json.loads(response["Payload"].read())
123+
assert result.get("statusCode") == 200
124+
84125
def test_get_evaluator(self, unique_name, sample_prompt_file, cleanup_list):
85126
"""Test retrieving evaluator by name."""
86127
try:

sagemaker-train/tests/unit/ai_registry/test_evaluator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def test_create_with_byoc(self, mock_air_hub, mock_boto3):
7777

7878
assert evaluator.method == EvaluatorMethod.BYOC
7979
mock_air_hub.upload_to_s3.assert_called_once()
80+
mock_lambda_client.create_function.assert_called_once()
81+
call_kwargs = mock_lambda_client.create_function.call_args[1]
82+
assert call_kwargs["Handler"] == "lambda_function.lambda_handler"
8083

8184
@patch('sagemaker.ai_registry.evaluator.AIRHub')
8285
def test_get_all(self, mock_air_hub):

0 commit comments

Comments
 (0)