Skip to content

Commit 860277d

Browse files
authored
Fix TaskInstance crash with non-serialized operators missing get_weight (#64557)
* Fix TaskInstance crash with non-serialized operators missing get_weight
1 parent 20553e6 commit 860277d

2 files changed

Lines changed: 29 additions & 2 deletions

File tree

airflow-core/src/airflow/models/taskinstance.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
from airflow.models.taskreschedule import TaskReschedule
8989
from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComSelectSequence, XComModel
9090
from airflow.settings import task_instance_mutation_hook
91+
from airflow.task.priority_strategy import validate_and_load_priority_weight_strategy
9192
from airflow.ti_deps.dep_context import DepContext
9293
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
9394
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
@@ -693,7 +694,10 @@ def insert_mapping(
693694
694695
:meta private:
695696
"""
696-
priority_weight = task.weight_rule.get_weight(
697+
weight_rule = task.weight_rule
698+
if not hasattr(weight_rule, "get_weight"):
699+
weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
700+
priority_weight = weight_rule.get_weight(
697701
TaskInstance(task=task, run_id=run_id, map_index=map_index, dag_version_id=dag_version_id)
698702
)
699703
context_carrier = new_task_run_carrier(dag_run.context_carrier)
@@ -874,7 +878,10 @@ def refresh_from_task(self, task: Operator, pool_override: str | None = None) ->
874878
self.queue = task.queue
875879
self.pool = pool_override or task.pool
876880
self.pool_slots = task.pool_slots
877-
self.priority_weight = self.task.weight_rule.get_weight(self)
881+
weight_rule = self.task.weight_rule
882+
if not hasattr(weight_rule, "get_weight"):
883+
weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
884+
self.priority_weight = weight_rule.get_weight(self)
878885
self.run_as_user = task.run_as_user
879886
# Do not set max_tries to task.retries here because max_tries is a cumulative
880887
# value that needs to be stored in the db.

airflow-core/tests/unit/models/test_taskinstance.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2653,6 +2653,26 @@ def mock_policy(task_instance: TaskInstance):
26532653
assert ti.max_tries == expected_max_tries
26542654

26552655

2656+
@pytest.mark.parametrize(
2657+
("weight_rule", "expected_weight"),
2658+
[
2659+
pytest.param("downstream", 10 + 5, id="downstream-sums-descendants"),
2660+
pytest.param("upstream", 10, id="upstream-no-ancestors"),
2661+
pytest.param("absolute", 10, id="absolute-self-only"),
2662+
],
2663+
)
2664+
def test_refresh_from_task_with_non_serialized_operator(weight_rule, expected_weight):
2665+
"""Regression: TaskInstance must work with non-serialized operators whose weight_rule is a WeightRule enum."""
2666+
with DAG(dag_id="test_dag"):
2667+
root = EmptyOperator(task_id="root", priority_weight=10, weight_rule=weight_rule)
2668+
child = EmptyOperator(task_id="child", priority_weight=5)
2669+
root >> child
2670+
2671+
ti = TI(root, run_id=None, dag_version_id=mock.MagicMock())
2672+
2673+
assert ti.priority_weight == expected_weight
2674+
2675+
26562676
def test_defer_task_returns_false_when_no_start_from_trigger(create_task_instance):
26572677
session = mock.MagicMock()
26582678
ti = create_task_instance(

0 commit comments

Comments
 (0)