109 changes: 78 additions & 31 deletions providers/amazon/src/airflow/providers/amazon/aws/operators/batch.py
@@ -222,12 +222,18 @@ def execute(self, context: Context) -> str | None:
if not self.job_id:
raise AirflowException("AWS Batch job - job_id was not found")

job = self.hook.get_job_description(self.job_id)
# Persist operator links before deferring so they're available in the UI
# Reuse job description to reduce API calls
job = self._persist_links(context)
job_status = job.get("status")
if job_status == self.hook.SUCCESS_STATE:
# Job already completed - persist CloudWatch logs
self._persist_cloudwatch_link(context)
self.log.info("Job completed.")
return self.job_id
if job_status == self.hook.FAILURE_STATE:
# Job already failed - persist CloudWatch logs
self._persist_cloudwatch_link(context)
raise AirflowException(f"Error while running job: {self.job_id} is in {job_status} state")
if job_status in self.hook.INTERMEDIATE_STATES:
self.defer(
@@ -252,14 +258,17 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
validated_event = validate_execute_complete_event(event)

# Set job_id first so CloudWatch link can be persisted even on failure
self.job_id = validated_event["job_id"]

# Persist CloudWatch logs for both success and failure
self._persist_cloudwatch_link(context)

if validated_event["status"] != "success":
raise AirflowException(f"Error while running job: {validated_event}")

self.job_id = validated_event["job_id"]

# Fetch logs if awslogs_enabled
if self.awslogs_enabled:
self.monitor_job(context) # fetch logs, no need to return
# Check job success (already know status is "success" from above)
self.hook.check_job_success(self.job_id)

self.log.info("Job completed successfully for job_id: %s", self.job_id)
return self.job_id
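
For reference, a minimal sketch of the resume event that execute_complete consumes; the keys match the ones read above, and the values mirror the unit tests further down in this diff. The "failure" status string is a placeholder, not taken from this PR.

# Minimal sketch of trigger resume events; only "status" and "job_id" are read.
success_event = {"status": "success", "job_id": "12345"}
# Any non-"success" status raises AirflowException, but only after job_id is
# set and the CloudWatch link is persisted, which is the point of the reorder.
failure_event = {"status": "failure", "job_id": "12345"}  # placeholder status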
@@ -330,18 +339,25 @@ def submit_job(self, context: Context):
job_id=self.job_id,
)

def monitor_job(self, context: Context):
def _persist_links(
self,
context: Context,
job_description: dict | None = None,
) -> dict:
"""
Monitor an AWS Batch job.
Persist job definition and queue links for UI display.

This can raise an exception or an AirflowTaskTimeout if the task was
created with ``execution_timeout``.
:param context: Task context
:param job_description: Optional pre-fetched job description to avoid redundant API calls
:return: Job description dict
"""
if not self.job_id:
raise AirflowException("AWS Batch job - job_id was not found")

# Fetch job description (needed for return value and link persistence)
job_desc = job_description or self.hook.get_job_description(job_id=self.job_id)

try:
job_desc = self.hook.get_job_description(self.job_id)
job_definition_arn = job_desc["jobDefinition"]
job_queue_arn = job_desc["jobQueue"]
self.log.info(
@@ -368,33 +384,37 @@ def monitor_job(self, context: Context):
job_queue_arn=job_queue_arn,
)

if self.awslogs_enabled:
if self.waiters:
self.waiters.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
else:
self.hook.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
else:
if self.waiters:
self.waiters.wait_for_job(self.job_id)
else:
self.hook.wait_for_job(self.job_id)
return job_desc

def _persist_cloudwatch_link(self, context: Context) -> None:
"""
Persist CloudWatch logs link if available.

:param context: Task context
"""
if not self.do_xcom_push:
return

if not context or "ti" not in context:
return

if not self.job_id:
return

awslogs = []
try:
awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
except AirflowException as ae:
self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae)
self.log.warning(
"Unable to retrieve CloudWatch log information for AWS Batch job (%s): %s",
self.job_id,
ae,
)
return

if awslogs:
self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
link_builder = CloudWatchEventsLink()
for log in awslogs:
self.log.info(link_builder.format_link(**log))
self.log.info("AWS Batch job (%s) CloudWatch Events details found.", self.job_id)
if len(awslogs) > 1:
# there can be several log streams on multi-node jobs
self.log.warning(
"out of all those logs, we can only link to one in the UI. Using the first one."
)
self.log.warning("Multiple log streams found. Linking to the first one in the UI.")

CloudWatchEventsLink.persist(
context=context,
@@ -404,6 +424,33 @@ def monitor_job(self, context: Context):
**awslogs[0],
)
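
For context on the guard and the multi-stream warning above, this is a sketch of the list shape get_job_all_awslogs_info returns, matching the fixture in the unit tests below; the second entry is a hypothetical multi-node example.

# Each entry carries the kwargs that CloudWatchEventsLink.format_link/persist
# accept; only the first entry is linked in the UI.
awslogs = [
    {"awslogs_group": "/aws/batch/job", "awslogs_stream_name": "stream1"},
    # hypothetical second stream from a multi-node job
    {"awslogs_group": "/aws/batch/job", "awslogs_stream_name": "stream2"},
]
link_kwargs = awslogs[0]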

def monitor_job(self, context: Context):
"""
Monitor an AWS Batch job.

This can raise an exception or an AirflowTaskTimeout if the task was
created with ``execution_timeout``.
"""
if not self.job_id:
raise ValueError("AWS Batch job - job_id was not found")

# Persist job definition and queue links
self._persist_links(context)

if self.awslogs_enabled:
if self.waiters:
self.waiters.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
else:
self.hook.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
else:
if self.waiters:
self.waiters.wait_for_job(self.job_id)
else:
self.hook.wait_for_job(self.job_id)

# After job completes, persist CloudWatch logs
self._persist_cloudwatch_link(context)

self.hook.check_job_success(self.job_id)
self.log.info("AWS Batch job (%s) succeeded", self.job_id)

55 changes: 42 additions & 13 deletions providers/amazon/tests/unit/amazon/aws/operators/test_batch.py
@@ -149,16 +149,24 @@ def test_init_defaults(self):
def test_template_fields_overrides(self):
validate_template_fields(self.batch)

@mock.patch.object(BatchClientHook, "get_job_all_awslogs_info")
@mock.patch.object(BatchClientHook, "get_job_description")
@mock.patch.object(BatchClientHook, "wait_for_job")
@mock.patch.object(BatchClientHook, "check_job_success")
def test_execute_without_failures(self, check_mock, wait_mock, job_description_mock):
def test_execute_without_failures(
self, check_mock, wait_mock, job_description_mock, get_job_all_awslogs_info_mock
):
# JOB_ID is in RESPONSE_WITHOUT_FAILURES
self.client_mock.submit_job.return_value = RESPONSE_WITHOUT_FAILURES
self.batch.job_id = None
self.batch.waiters = None # use default wait
get_job_all_awslogs_info_mock.return_value = []
# Enable xcom push so _persist_cloudwatch_link actually runs
self.batch.do_xcom_push = True
# Use a real dict so "ti" in context works correctly
context = {"ti": mock.MagicMock()}

self.batch.execute(self.mock_context)
self.batch.execute(context)

self.client_mock.submit_job.assert_called_once_with(
jobQueue="queue",
@@ -175,9 +183,10 @@ def test_execute_without_failures(self, check_mock, wait_mock, job_description_m
wait_mock.assert_called_once_with(JOB_ID)
check_mock.assert_called_once_with(JOB_ID)

# First Call: Retrieve Batch Queue and Job Definition
# Second Call: Retrieve CloudWatch information
assert job_description_mock.call_count == 2
# get_job_description called once in _persist_links
assert job_description_mock.call_count == 1
# get_job_all_awslogs_info called once in _persist_cloudwatch_link
get_job_all_awslogs_info_mock.assert_called_once_with(JOB_ID)

def test_execute_with_failures(self):
self.client_mock.submit_job.side_effect = Exception()
@@ -545,46 +554,66 @@ def test_monitor_job_with_logs(

@patch.object(BatchOperator, "log", new_callable=MagicMock)
@patch("airflow.providers.amazon.aws.operators.batch.validate_execute_complete_event")
@patch.object(BatchOperator, "monitor_job")
def test_execute_complete_success_with_logs(self, mock_monitor_job, mock_validate, mock_log):
@patch.object(BatchClientHook, "check_job_success")
@patch.object(BatchClientHook, "get_job_all_awslogs_info")
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
def test_execute_complete_success_with_logs(
self, mock_client, mock_get_job_all_awslogs_info, mock_check_job_success, mock_validate, mock_log
):
# Setup
mock_validate.return_value = {"status": "success", "job_id": "12345"}
mock_get_job_all_awslogs_info.return_value = [
{"awslogs_group": "/aws/batch/job", "awslogs_stream_name": "stream1"}
]
batch = BatchOperator(
task_id="test_task",
job_name=JOB_NAME,
job_queue="dummy_queue",
job_definition="dummy_definition",
deferrable=True,
awslogs_enabled=True,
do_xcom_push=True, # Enable xcom push so _persist_cloudwatch_link runs
)
# Add task instance to context
context = {"ti": mock.MagicMock()}

result = batch.execute_complete(context={}, event={"dummy": "event"})
result = batch.execute_complete(context=context, event={"dummy": "event"})

# Assertion
assert result == "12345"
mock_monitor_job.assert_called_once_with({})
mock_get_job_all_awslogs_info.assert_called_once_with("12345")
mock_check_job_success.assert_called_once_with("12345")
mock_log.info.assert_called_with("Job completed successfully for job_id: %s", "12345")

@patch.object(BatchOperator, "log", new_callable=MagicMock)
@patch("airflow.providers.amazon.aws.operators.batch.validate_execute_complete_event")
@patch.object(BatchOperator, "monitor_job")
def test_execute_complete_success_without_logs(self, mock_monitor_job, mock_validate, mock_log):
@patch.object(BatchClientHook, "check_job_success")
@patch.object(BatchClientHook, "get_job_all_awslogs_info")
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
def test_execute_complete_success_without_logs(
self, mock_client, mock_get_job_all_awslogs_info, mock_check_job_success, mock_validate, mock_log
):
# Setup
mock_validate.return_value = {"status": "success", "job_id": "12345"}
mock_get_job_all_awslogs_info.return_value = []
batch = BatchOperator(
task_id="test_task",
job_name=JOB_NAME,
job_queue="dummy_queue",
job_definition="dummy_definition",
deferrable=True,
awslogs_enabled=False,
do_xcom_push=True, # Enable xcom push so _persist_cloudwatch_link runs
)
# Add task instance to context
context = {"ti": mock.MagicMock()}

result = batch.execute_complete(context={}, event={"dummy": "event"})
result = batch.execute_complete(context=context, event={"dummy": "event"})

# Assertions
assert result == "12345"
mock_monitor_job.assert_not_called()
mock_get_job_all_awslogs_info.assert_called_once_with("12345")
mock_check_job_success.assert_called_once_with("12345")
mock_log.info.assert_called_with("Job completed successfully for job_id: %s", "12345")
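
A natural companion test, sketched here rather than taken from this diff: with the reordering in execute_complete, a failure event should still persist the CloudWatch link before raising. Names reuse the mocks and operator from the tests above; pytest and AirflowException imports are assumed.

# Hypothetical sketch (not part of this PR): a failure event still sets job_id
# and persists the CloudWatch link before AirflowException is raised.
mock_validate.return_value = {"status": "failure", "job_id": "12345"}
with pytest.raises(AirflowException):
    batch.execute_complete(context={"ti": mock.MagicMock()}, event={"dummy": "event"})
mock_get_job_all_awslogs_info.assert_called_once_with("12345")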

@patch("airflow.providers.amazon.aws.operators.batch.validate_execute_complete_event")