Skip to content

Commit 7082001

Browse files
authored
refactor: Refactor segment cloning (#5898)
1 parent 6153206 commit 7082001

11 files changed

Lines changed: 354 additions & 919 deletions

api/conftest.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@
8181
)
8282
from projects.tags.models import Tag
8383
from segments.models import Condition, Segment, SegmentRule
84-
from segments.services import SegmentCloneService
8584
from tests.test_helpers import fix_issue_3869
8685
from tests.types import (
8786
AdminClientAuthType,
@@ -376,13 +375,9 @@ def project(organisation): # type: ignore[no-untyped-def]
376375

377376

378377
@pytest.fixture()
379-
def segment(project: Project): # type: ignore[no-untyped-def]
380-
_segment = Segment.objects.create(name="segment", project=project)
381-
# Deep clone the segment to ensure that any bugs around
382-
# versioning get bubbled up through the test suite.
383-
SegmentCloneService(_segment).deep_clone()
384-
385-
return _segment
378+
def segment(project: Project) -> Segment:
379+
segment: Segment = Segment.objects.create(name="segment", project=project)
380+
return segment
386381

387382

388383
@pytest.fixture()

api/core/workflows_services.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
from typing import TYPE_CHECKING
22

3+
import structlog
4+
from django.db import transaction
35
from django.utils import timezone
46

57
from environments.tasks import rebuild_environment_document
68
from features.versioning.models import EnvironmentFeatureVersion
79
from features.versioning.signals import environment_feature_version_published
810
from features.versioning.tasks import trigger_update_version_webhooks
911
from features.workflows.core.exceptions import ChangeRequestNotApprovedError
10-
from segments.models import Segment
11-
from segments.services import SegmentCloneService
1212

1313
if TYPE_CHECKING:
1414
from features.workflows.core.models import ChangeRequest
1515
from users.models import FFAdminUser
1616

17+
logger = structlog.get_logger()
18+
1719

1820
class ChangeRequestCommitService:
1921
def __init__(self, change_request: "ChangeRequest") -> None:
@@ -95,26 +97,24 @@ def _publish_change_sets(self, published_by: "FFAdminUser") -> None:
9597
for change_set in self.change_request.change_sets.all():
9698
change_set.publish(user=published_by)
9799

100+
@transaction.atomic
98101
def _publish_segments(self) -> None:
99-
for segment in self.change_request.segments.all():
100-
target_segment: Segment = segment.version_of # type: ignore[assignment]
101-
assert target_segment != segment
102-
103-
# Deep clone the segment to establish historical version this is required
104-
# because the target segment will be altered when the segment is published.
105-
# Think of it like a regular update to a segment where we create the clone
106-
# to create the version, then modifying the new 'draft' version with the
107-
# data from the change request.
108-
SegmentCloneService(target_segment).deep_clone()
109-
110-
# Set the properties of the change request's segment to the properties
111-
# of the target (i.e., canonical) segment.
112-
target_segment.name = segment.name
113-
target_segment.description = segment.description
114-
target_segment.feature = segment.feature
115-
target_segment.save()
116-
117-
# Delete the rules in order to replace them with copies of the segment.
118-
target_segment.rules.all().delete()
119-
for rule in segment.rules.all():
120-
rule.deep_clone(target_segment)
102+
for draft_segment in self.change_request.segments.all():
103+
live_segment = draft_segment.version_of
104+
if not live_segment: # pragma: no cover
105+
logger.warning("missing-live-segment", draft_segment=draft_segment.uuid)
106+
continue
107+
108+
# Make a revision of the live segment
109+
revision = live_segment.clone(is_revision=True)
110+
logger.info(
111+
"segment-revision-created",
112+
segment_id=live_segment.id,
113+
revision_id=revision.id,
114+
)
115+
116+
live_segment.name = draft_segment.name
117+
live_segment.description = draft_segment.description
118+
live_segment.feature = draft_segment.feature
119+
live_segment.save()
120+
live_segment.copy_rules_and_conditions_from(draft_segment)

api/segments/models.py

Lines changed: 86 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from django.conf import settings
77
from django.contrib.contenttypes.fields import GenericRelation
88
from django.core.exceptions import ValidationError
9-
from django.db import models
9+
from django.db import models, transaction
1010
from django_lifecycle import ( # type: ignore[import-untyped]
1111
AFTER_CREATE,
1212
BEFORE_CREATE,
@@ -114,35 +114,72 @@ def set_version_of_to_self_if_none(self): # type: ignore[no-untyped-def]
114114
self.version_of = self
115115
self.save_without_historical_record()
116116

117-
def _clone_segment_rules(self, cloned_segment: "Segment") -> list["SegmentRule"]:
118-
cloned_rules = []
119-
for rule in self.rules.all():
120-
cloned_rule = rule.deep_clone(cloned_segment)
121-
cloned_rules.append(cloned_rule)
122-
cloned_segment.refresh_from_db()
123-
assert (
124-
len(self.rules.all())
125-
== len(cloned_rules)
126-
== len(cloned_segment.rules.all())
127-
), "Mismatch during rules creation"
128-
129-
return cloned_rules
130-
131-
# TODO: To be depreacted in flagsmith-common and flagsmith-workflows
132-
def deep_clone(self) -> "Segment":
117+
@transaction.atomic
118+
def clone(self, is_revision: bool = False, **extra_attrs: typing.Any) -> "Segment":
119+
"""
120+
Create a revision of the segment
121+
"""
133122
cloned_segment = deepcopy(self)
134-
cloned_segment.id = None
123+
cloned_segment.pk = None
135124
cloned_segment.uuid = uuid.uuid4()
136-
cloned_segment.version_of = self
125+
cloned_segment.version_of = None # Unset for now
126+
cloned_segment.version = 0 # Unset for now
127+
for attr_name, value in extra_attrs.items():
128+
setattr(cloned_segment, attr_name, value)
137129
cloned_segment.save()
138130

139-
self.version += 1 # type: ignore[operator]
140-
self.save_without_historical_record()
131+
cloned_segment.copy_rules_and_conditions_from(self)
141132

142-
self._clone_segment_rules(cloned_segment)
133+
# Handle versioning
134+
version_of = self if is_revision else cloned_segment
135+
cloned_segment.version_of = extra_attrs.get("version_of", version_of)
136+
cloned_segment.version = self.version if is_revision else 1
137+
Segment.objects.filter(pk=cloned_segment.pk).update(
138+
version_of=cloned_segment.version_of,
139+
version=cloned_segment.version,
140+
)
141+
142+
# Increase self version
143+
if is_revision:
144+
self.version = (self.version or 1) + 1
145+
Segment.objects.filter(pk=self.pk).update(version=self.version)
143146

144147
return cloned_segment
145148

149+
def copy_rules_and_conditions_from(self, source_segment: "Segment") -> None:
150+
"""
151+
Recursively copy rules and conditions from another segment
152+
"""
153+
assert transaction.get_connection().in_atomic_block, "Must run in a transaction"
154+
155+
# Delete existing rules
156+
SegmentRule.objects.filter(segment=self).delete()
157+
158+
source_rules = SegmentRule.objects.filter(
159+
models.Q(segment=source_segment) | models.Q(rule__segment=source_segment)
160+
)
161+
162+
# Ensure top-level rules are cloned first (for dependencies)
163+
source_rules = source_rules.order_by(models.F("rule").asc(nulls_first=True))
164+
165+
rule_to_cloned_rule_map: dict[SegmentRule, SegmentRule] = {}
166+
for rule in source_rules:
167+
cloned_rule = deepcopy(rule)
168+
cloned_rule.pk = None
169+
cloned_rule.uuid = uuid.uuid4()
170+
cloned_rule.segment = self if rule.segment else None
171+
cloned_rule.rule = rule_to_cloned_rule_map.get(rule.rule)
172+
cloned_rule.save()
173+
rule_to_cloned_rule_map[rule] = cloned_rule
174+
175+
source_conditions = Condition.objects.filter(rule__in=rule_to_cloned_rule_map)
176+
for condition in source_conditions:
177+
cloned_condition = deepcopy(condition)
178+
cloned_condition.pk = None
179+
cloned_condition.uuid = uuid.uuid4()
180+
cloned_condition.rule = rule_to_cloned_rule_map[condition.rule]
181+
cloned_condition.save()
182+
146183
def get_create_log_message(self, history_instance) -> typing.Optional[str]: # type: ignore[no-untyped-def]
147184
return SEGMENT_CREATED_MESSAGE % self.name
148185

@@ -180,94 +217,28 @@ class SegmentRule(
180217

181218
history_record_class_path = "segments.models.HistoricalSegmentRule"
182219

183-
def clean(self): # type: ignore[no-untyped-def]
184-
super().clean()
185-
parents = [self.segment, self.rule]
186-
num_parents = sum(parent is not None for parent in parents)
187-
if num_parents != 1:
188-
raise ValidationError(
189-
"Segment rule must have exactly one parent, %d found",
190-
num_parents, # type: ignore[arg-type]
191-
)
192-
193220
def __str__(self): # type: ignore[no-untyped-def]
194221
return "%s rule for %s" % (
195222
self.type,
196223
str(self.segment) if self.segment else str(self.rule),
197224
)
198225

199-
def get_skip_create_audit_log(self) -> bool:
200-
try:
201-
segment = self.get_segment() # type: ignore[no-untyped-call]
202-
if segment.deleted_at:
203-
return True
204-
return segment.version_of_id != segment.id # type: ignore[no-any-return]
205-
except (Segment.DoesNotExist, SegmentRule.DoesNotExist):
206-
# handle hard delete
207-
return True
208-
209-
def _get_project(self) -> typing.Optional[Project]:
210-
return self.get_segment().project # type: ignore[no-untyped-call,no-any-return]
211-
212-
def get_segment(self): # type: ignore[no-untyped-def]
213-
"""
214-
rules can be a child of a parent rule instead of a segment, this method iterates back up the tree to find the
215-
segment
216-
217-
TODO: denormalise the segment information so that we don't have to make multiple queries here in complex cases
218-
"""
219-
rule = self
220-
while not rule.segment_id:
221-
rule = rule.rule # type: ignore[assignment]
222-
return rule.segment
223-
224-
def deep_clone(self, cloned_segment: Segment) -> "SegmentRule":
225-
if self.rule:
226-
# Since we're expecting a rule that is only belonging to a
227-
# segment, since a rule either belongs to a segment xor belongs
228-
# to a rule, we don't expect there also to be a rule associated.
229-
assert False, "Unexpected rule, expecting segment set not rule"
230-
cloned_rule = deepcopy(self)
231-
cloned_rule.segment = cloned_segment
232-
cloned_rule.uuid = uuid.uuid4()
233-
cloned_rule.id = None
234-
cloned_rule.save()
235-
logger.info(
236-
f"Deep copying rule {self.id} for cloned rule {cloned_rule.id} for cloned segment {cloned_segment.id}"
237-
)
226+
def clean(self) -> None:
227+
super().clean()
228+
self._validate_one_parent()
238229

239-
# Conditions are only part of the sub-rules.
240-
assert self.conditions.exists() is False
241-
242-
for sub_rule in self.rules.all():
243-
if sub_rule.rules.exists():
244-
assert False, "Expected two layers of rules, not more"
245-
246-
cloned_sub_rule = deepcopy(sub_rule)
247-
cloned_sub_rule.rule = cloned_rule
248-
cloned_sub_rule.uuid = uuid.uuid4()
249-
cloned_sub_rule.id = None
250-
cloned_sub_rule.save()
251-
logger.info(
252-
f"Deep copying sub rule {sub_rule.id} for cloned sub rule {cloned_sub_rule.id} "
253-
f"for cloned segment {cloned_segment.id}"
230+
def _validate_one_parent(self) -> None:
231+
parents = [self.segment, self.rule]
232+
if (num_parents := sum(parent is not None for parent in parents)) != 1:
233+
raise ValidationError(
234+
f"SegmentRule must have exactly one parent, {num_parents} found"
254235
)
255236

256-
cloned_conditions = []
257-
for condition in sub_rule.conditions.all():
258-
cloned_condition = deepcopy(condition)
259-
cloned_condition.rule = cloned_sub_rule
260-
cloned_condition.uuid = uuid.uuid4()
261-
cloned_condition.id = None
262-
cloned_conditions.append(cloned_condition)
263-
logger.info(
264-
f"Cloning condition {condition.id} for cloned condition {cloned_condition.uuid} "
265-
f"for cloned segment {cloned_segment.id}"
266-
)
267-
268-
Condition.objects.bulk_create(cloned_conditions)
269-
270-
return cloned_rule
237+
def get_skip_create_audit_log(self) -> bool:
238+
# NOTE: We'll transition to storing rules and conditions in JSON so
239+
# individual audit logs for rules and conditions is irrelevant.
240+
# This model will be deleted as of https://github.com/Flagsmith/flagsmith/issues/5846
241+
return True
271242

272243

273244
class ConditionManager(SoftDeleteExportableManager):
@@ -330,52 +301,31 @@ class Condition(
330301

331302
objects: typing.ClassVar[ConditionManager] = ConditionManager()
332303

333-
def __str__(self): # type: ignore[no-untyped-def]
304+
def __str__(self) -> str:
334305
return "Condition for %s: %s %s %s" % (
335306
str(self.rule),
336307
self.property,
337308
self.operator,
338309
self.value,
339310
)
340311

341-
def get_skip_create_audit_log(self) -> bool:
342-
try:
343-
if self.rule.deleted_at:
344-
return True
345-
346-
segment = self.rule.get_segment() # type: ignore[no-untyped-call]
347-
if segment.deleted_at:
348-
return True
312+
def get_skip_create_audit_log(self) -> bool: # pragma: no cover
313+
# NOTE: We'll transition to storing rules and conditions in JSON so
314+
# individual audit logs for rules and conditions is irrelevant.
315+
# This model will be deleted as of https://github.com/Flagsmith/flagsmith/issues/5846
316+
return True
349317

350-
return segment.version_of_id != segment.id # type: ignore[no-any-return]
351-
except (Segment.DoesNotExist, SegmentRule.DoesNotExist):
352-
# handle hard delete
353-
return True
318+
def get_update_log_message(self, _: typing.Any) -> None: # pragma: no cover
319+
return None
354320

355-
def get_update_log_message(self, history_instance) -> typing.Optional[str]: # type: ignore[no-untyped-def]
356-
return f"Condition updated on segment '{self._get_segment().name}'."
321+
def get_create_log_message(self, _: typing.Any) -> None: # pragma: no cover
322+
return None
357323

358-
def get_create_log_message(self, history_instance) -> typing.Optional[str]: # type: ignore[no-untyped-def,return]
359-
if not self.created_with_segment:
360-
return f"Condition added to segment '{self._get_segment().name}'."
361-
362-
def get_delete_log_message(self, history_instance) -> typing.Optional[str]: # type: ignore[no-untyped-def,return]
363-
if not self._get_segment().deleted_at:
364-
return f"Condition removed from segment '{self._get_segment().name}'."
365-
366-
def get_audit_log_related_object_id(self, history_instance) -> int: # type: ignore[no-untyped-def]
367-
return self._get_segment().id
368-
369-
def _get_segment(self) -> Segment:
370-
"""
371-
Temporarily cache the segment on the condition object to reduce number of queries.
372-
"""
373-
if not hasattr(self, "segment"):
374-
setattr(self, "segment", self.rule.get_segment()) # type: ignore[no-untyped-call]
375-
return self.segment # type: ignore[no-any-return]
324+
def get_delete_log_message(self, _: typing.Any) -> None: # pragma: no cover
325+
return None
376326

377-
def _get_project(self) -> typing.Optional[Project]:
378-
return self.rule.get_segment().project # type: ignore[no-untyped-call,no-any-return]
327+
def get_audit_log_related_object_id(self, _: typing.Any) -> int: # pragma: no cover
328+
raise NotImplementedError("No longer used, will be removed soon.")
379329

380330

381331
class WhitelistedSegment(models.Model):

0 commit comments

Comments
 (0)