Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 50 additions & 59 deletions dags/import_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@
*DAG ID*: {{ dag.dag_id }}
*Execution Time*: {{ execution_date }}
"""
import_sql_failure_slack_msg = """
:red_circle: Import SQL Failed. Please check the notification file in the Airflow logs.
import_ch_failure_slack_msg = """
:red_circle: ClickHouse Import Failed. Please check the notification file in the Airflow logs.
*DAG ID*: {{ dag.dag_id }}
*Execution Time*: {{ execution_date }}
*Log Url*: {{ import_sql_log_url }}
*Log Url*: {{ import_ch_log_url }}
"""
import_sql_success_slack_msg = """
:large_green_circle: Import SQL Success!
import_ch_success_slack_msg = """
:large_green_circle: ClickHouse Import Success!
*DAG ID*: {{ dag.dag_id }}
*Execution Time*: {{ execution_date }}
*Log Url*: {{ import_sql_log_url }}
*Log Url*: {{ import_ch_log_url }}
"""
dag_failure_slack_webhook_notification = send_slack_webhook_notification(
slack_webhook_conn_id="slack_default", text=fail_slack_msg
Expand Down Expand Up @@ -86,10 +86,13 @@ class ImporterConfig:
schedule_interval: Optional[str] = None


def _script(scripts_dir: str, script_name: str, *args: object) -> str:
def _script(scripts_dir: str, script_name: str, *args: object, source_automation_env: bool = False) -> str:
parts = [f"{scripts_dir}/{script_name}"]
parts.extend(str(arg) for arg in args)
return " ".join(parts)
cmd = " ".join(parts)
if source_automation_env:
return f"source {scripts_dir}/automation-environment.sh && {cmd}"
return cmd


def build_import_dag(config: ImporterConfig) -> DAG:
Expand Down Expand Up @@ -124,33 +127,33 @@ def build_import_dag(config: ImporterConfig) -> DAG:
@task
def get_data_repos(repos: list[str]) -> str:
return " ".join(repos)

# run this task even if import_sql failed
# run this task even if import_direct_to_clickhouse failed
@task(trigger_rule=TriggerRule.ALL_DONE)
def send_update_notification(notification_filepath: str, ssh_conn_id: str) -> None:
"""
Sends a Slack message to the #airflow-logs channel with a link to the import_sql logs URL.
Sends a Slack message to the #airflow-logs channel with a link to the import_direct_to_clickhouse logs URL.
This tells the curators whether there were any studies that succeeded or failed to import during a given run.
To avoid confusion -- we run this task towards the end of the DAG
(e.g. after the transfer_deployment step) because we don't want to
send a success message before the entire import run completes.
"""

# Get the log URL for the import_sql task
# Get the log URL for the import_direct_to_clickhouse task
context = get_current_context()
dag_run = context.get("dag_run")
import_sql_ti = None
import_ch_ti = None
if dag_run is not None:
import_sql_ti = dag_run.get_task_instance("import_sql", map_index=0)
import_sql_log_url = import_sql_ti.log_url if import_sql_ti is not None else ""
if not import_sql_log_url:
logger.warning("Could not determine import_sql log url; skipping Slack notification.")
import_ch_ti = dag_run.get_task_instance("import_direct_to_clickhouse", map_index=0)
import_ch_log_url = import_ch_ti.log_url if import_ch_ti is not None else ""
if not import_ch_log_url:
logger.warning("Could not determine import_direct_to_clickhouse log url; skipping Slack notification.")
raise AirflowSkipException()

import_sql_failed = (
import_sql_ti is not None and import_sql_ti.state == State.FAILED
import_ch_failed = (
import_ch_ti is not None and import_ch_ti.state == State.FAILED
)
if not import_sql_failed:
if not import_ch_failed:
# Read the notification file from the remote node to check if any studies failed
try:
ssh_hook = SSHHook(ssh_conn_id=ssh_conn_id)
Expand All @@ -160,21 +163,21 @@ def send_update_notification(notification_filepath: str, ssh_conn_id: str) -> No
)
if exit_status != 0:
logger.warning("Notification file not found at %s; treating as failure", notification_filepath)
import_sql_failed = True
import_ch_failed = True
else:
notification_content = notif_contents.decode("utf-8")
ERROR_STRING = "The following studies had errors during import"
import_sql_failed = (ERROR_STRING in notification_content)
import_ch_failed = (ERROR_STRING in notification_content)
except Exception as exc:
logger.warning("Could not read notification file from remote node; skipping Slack notification")
logger.warning("Stack trace:")
logger.warning(exc)
raise AirflowSkipException() from exc

# Build the msg and send to Slack
msg_template = import_sql_failure_slack_msg if import_sql_failed else import_sql_success_slack_msg
msg_template = import_ch_failure_slack_msg if import_ch_failed else import_ch_success_slack_msg
rendered_message = Template(msg_template).render(
import_sql_log_url=import_sql_log_url,
import_ch_log_url=import_ch_log_url,
**context,
)
SlackWebhookHook(slack_webhook_conn_id="slack_default").send(text=rendered_message)
Expand All @@ -189,27 +192,35 @@ def send_update_notification(notification_filepath: str, ssh_conn_id: str) -> No
db_properties_filepath,
color_swap_config_filepath,
),
"scale_up_rds_node": _script(
"clone_database": _script(
scripts_dir,
"scale-rds.sh",
"up",
"airflow-clone-db.sh",
importer,
color_swap_config_filepath,
scripts_dir,
db_properties_filepath,
),
"clone_database": _script(
"create_derived_tables": _script(
scripts_dir,
"airflow-clone-db.sh",
"airflow-create-derived-tables.sh",
importer,
scripts_dir,
db_properties_filepath,
),
"set_import_complete": _script(
scripts_dir,
"set_update_process_state.sh",
db_properties_filepath,
"complete",
source_automation_env=True,
),
"fetch_data": _script(
scripts_dir,
"data_source_repo_clone_manager.sh",
data_source_properties_filepath,
"pull",
importer,
data_repos,
source_automation_env=True,
),
"setup_import": _script(
scripts_dir,
Expand All @@ -218,58 +229,35 @@ def send_update_notification(notification_filepath: str, ssh_conn_id: str) -> No
scripts_dir,
db_properties_filepath,
),
"import_sql": _script(
# reuse the old import-sql script for now
"import_direct_to_clickhouse": _script(
scripts_dir,
"airflow-import-sql.sh",
importer,
scripts_dir,
db_properties_filepath,
notification_filepath,
),
"import_clickhouse": _script(
scripts_dir,
"airflow-import-clickhouse.sh",
importer,
scripts_dir,
db_properties_filepath,
),
"scale_down_rds_node": _script(
scripts_dir,
"scale-rds.sh",
"down",
importer,
color_swap_config_filepath,
# Normally, we would verify that we are in a "scaled up" state before trying to scale down.
# However, if the DAG run failed before "scale_up_rds_node" completed successfully,
# we may still be in a "scaled down" state when we run the scale down task
# (which runs regardless of upstream failures).
# In those cases -- skip verifying that we're in a scaled down state
"{{ '' if (dag_run.get_task_instance('scale_up_rds_node', map_index=ti.map_index) and dag_run.get_task_instance('scale_up_rds_node', map_index=ti.map_index).state == 'success') else '--skip-pre-validation' }}",
),
"transfer_deployment": _script(
scripts_dir,
"airflow-transfer-deployment.sh",
scripts_dir,
db_properties_filepath,
color_swap_config_filepath,
),
"clear_persistence_caches": _script(
scripts_dir,
"airflow-clear-persistence-caches.sh",
importer,
scripts_dir,
),
"set_import_running": _script(
scripts_dir,
"set_update_process_state.sh",
db_properties_filepath,
"running",
source_automation_env=True,
),
"set_import_abandoned": _script(
scripts_dir,
"set_update_process_state.sh",
db_properties_filepath,
"abandoned",
source_automation_env=True,
),
"cleanup_data": _script(
scripts_dir,
Expand All @@ -278,6 +266,7 @@ def send_update_notification(notification_filepath: str, ssh_conn_id: str) -> No
"cleanup",
importer,
data_repos,
source_automation_env=True,
),
}

Expand Down Expand Up @@ -312,9 +301,11 @@ def _build_task(name: str) -> object:

return SSHOperator.partial(**params).expand(ssh_conn_id=list(ssh_targets))

tasks: dict[str, object] = {"data_repos": data_repos}
tasks: dict[str, object] = {}
for name in config.task_names:
if name == "send_update_notification":
if name == "data_repos":
tasks[name] = data_repos
elif name == "send_update_notification":
tasks[name] = send_update_notification(
notification_filepath=notification_filepath,
ssh_conn_id=config.target_nodes[0],
Expand Down
Loading