Skip to content

Commit e91025c

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 3694cfc commit e91025c

3 files changed

Lines changed: 34 additions & 13 deletions

File tree

unstract/task-abstraction/src/unstract/task_abstraction/backends/hatchet.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,9 @@ def create_steps(self):
278278

279279
# Capture loop variables in closure by using default arguments
280280
def create_workflow_step(step_obj, parent_list):
281-
@self.hatchet.step(name=step_obj.task_name, parents=parent_list)
281+
@self.hatchet.step(
282+
name=step_obj.task_name, parents=parent_list
283+
)
282284
def workflow_step(context):
283285
# Get the original task function
284286
task_fn = self._tasks[step_obj.task_name]
@@ -289,10 +291,13 @@ def workflow_step(context):
289291
workflow_input = context.step_output(parent_list[0])
290292
else:
291293
# Use initial workflow input
292-
workflow_input = context.workflow_input()["initial_input"]
294+
workflow_input = context.workflow_input()[
295+
"initial_input"
296+
]
293297

294298
# Execute task with input and step kwargs
295299
return task_fn(workflow_input, **step_obj.kwargs)
300+
296301
return workflow_step
297302

298303
workflow_step = create_workflow_step(step, parents)

unstract/task-abstraction/tests/integration/test_backend_selection.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ async def test_backend_selection_by_feature_flags(
134134
def create_mock_flag_response(tc):
135135
def mock_flag_response(flag_key, namespace, entity_id, context=None):
136136
return tc.feature_flags.get(flag_key, False)
137+
137138
return mock_flag_response
138139

139140
mock_flag.side_effect = create_mock_flag_response(test_case)
@@ -169,14 +170,19 @@ async def test_rollout_percentage_distribution(self, backend_selector):
169170
) as mock_flag:
170171
# Mock percentage-based rollout - capture scenario in closure
171172
def create_mock_percentage_rollout(scen):
172-
def mock_percentage_rollout(flag_key, namespace, entity_id, context=None):
173+
def mock_percentage_rollout(
174+
flag_key, namespace, entity_id, context=None
175+
):
173176
if flag_key == "task_abstraction_enabled":
174177
import hashlib
175178

176-
hash_value = int(hashlib.md5(entity_id.encode()).hexdigest(), 16)
179+
hash_value = int(
180+
hashlib.md5(entity_id.encode()).hexdigest(), 16
181+
)
177182
user_bucket = hash_value % 100
178183
return user_bucket < scen["percentage"]
179184
return False
185+
180186
return mock_percentage_rollout
181187

182188
mock_flag.side_effect = create_mock_percentage_rollout(scenario)
@@ -226,21 +232,24 @@ async def test_organization_based_selection(self, backend_selector):
226232
with patch(
227233
"unstract.flags.feature_flag.check_feature_flag_status"
228234
) as mock_flag:
229-
230235
# Capture case in closure
231236
def create_mock_org_based_flags(c):
232-
def mock_org_based_flags(flag_key, namespace, entity_id, context=None):
237+
def mock_org_based_flags(
238+
flag_key, namespace, entity_id, context=None
239+
):
233240
org_id = context.get("organization_id") if context else None
234241

235242
# Organization-specific logic
236243
if org_id == "beta_org" and flag_key == "hatchet_backend_enabled":
237244
return True
238245
elif (
239-
org_id == "stable_org" and flag_key == "task_abstraction_enabled"
246+
org_id == "stable_org"
247+
and flag_key == "task_abstraction_enabled"
240248
):
241249
return True
242250

243251
return c["feature_flags"].get(flag_key, False)
252+
244253
return mock_org_based_flags
245254

246255
mock_flag.side_effect = create_mock_org_based_flags(case)
@@ -269,11 +278,11 @@ async def test_fallback_chain_construction(
269278
with patch(
270279
"unstract.flags.feature_flag.check_feature_flag_status"
271280
) as mock_flag:
272-
273281
# Capture test_case in closure
274282
def create_mock_flag_response(tc):
275283
def mock_flag_response(flag_key, namespace, entity_id, context=None):
276284
return tc.feature_flags.get(flag_key, False)
285+
277286
return mock_flag_response
278287

279288
mock_flag.side_effect = create_mock_flag_response(test_case)
@@ -344,7 +353,6 @@ async def test_user_segment_based_selection(self, backend_selector):
344353
with patch(
345354
"unstract.flags.feature_flag.check_feature_flag_status"
346355
) as mock_flag:
347-
348356
# Capture segment in closure
349357
def create_mock_segment_based_flags(seg):
350358
def mock_segment_based_flags(
@@ -362,6 +370,7 @@ def mock_segment_based_flags(
362370
elif seg["segment"] == "free_users":
363371
return flag_key == "unified_celery_enabled"
364372
return False
373+
365374
return mock_segment_based_flags
366375

367376
mock_flag.side_effect = create_mock_segment_based_flags(segment)
@@ -405,11 +414,11 @@ async def test_workflow_specific_backend_preferences(self, backend_selector):
405414
with patch(
406415
"unstract.flags.feature_flag.check_feature_flag_status"
407416
) as mock_flag:
408-
409417
# Capture preference in closure
410418
def create_mock_flag_response(pref):
411419
def mock_flag_response(flag_key, namespace, entity_id, context=None):
412420
return pref["feature_flags"].get(flag_key, False)
421+
413422
return mock_flag_response
414423

415424
mock_flag.side_effect = create_mock_flag_response(preference)

unstract/task-abstraction/tests/integration/test_feature_flag_rollout.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,17 @@ async def test_percentage_based_rollout(
139139
) as mock_flag:
140140
# Mock percentage-based rollout - capture percentage in closure
141141
def create_mock_percentage_rollout(pct):
142-
def mock_percentage_rollout(flag_key, namespace, entity_id, context=None):
142+
def mock_percentage_rollout(
143+
flag_key, namespace, entity_id, context=None
144+
):
143145
if flag_key == "task_abstraction_enabled":
144-
hash_value = int(hashlib.md5(entity_id.encode()).hexdigest(), 16)
146+
hash_value = int(
147+
hashlib.md5(entity_id.encode()).hexdigest(), 16
148+
)
145149
user_bucket = hash_value % 100
146150
return user_bucket < pct
147151
return False
152+
148153
return mock_percentage_rollout
149154

150155
mock_flag.side_effect = create_mock_percentage_rollout(percentage)
@@ -332,7 +337,9 @@ def mock_progressive_rollout(
332337
first_enabled_index = i
333338
elif first_enabled_index is not None and not enabled:
334339
# User was disabled after being enabled - this shouldn't happen
335-
pytest.fail(f"User {user_id} was disabled after being enabled at stage {rollout_stages[first_enabled_index]}%")
340+
pytest.fail(
341+
f"User {user_id} was disabled after being enabled at stage {rollout_stages[first_enabled_index]}%"
342+
)
336343

337344
@pytest.mark.asyncio
338345
async def test_rollback_scenario(self, feature_flag_manager):

0 commit comments

Comments
 (0)