|
1 | | -import logging |
2 | | -import random |
3 | | -from enum import Enum |
| 1 | +from typing import Any, Literal |
4 | 2 |
|
5 | | -from django.conf import settings |
6 | | -from django.core.cache import cache |
7 | | -from django.db import connections |
| 3 | +from django.db.models import Model |
8 | 4 |
|
9 | | -from .exceptions import ImproperlyConfiguredError |
10 | | - |
11 | | -logger = logging.getLogger(__name__) |
12 | | - |
13 | | -CONNECTION_CHECK_CACHE_TTL = 2 |
14 | | - |
15 | | - |
16 | | -class ReplicaReadStrategy(Enum): |
17 | | - DISTRIBUTED = "DISTRIBUTED" |
18 | | - SEQUENTIAL = "SEQUENTIAL" |
19 | | - |
20 | | - |
21 | | -def connection_check(database: str) -> bool: |
22 | | - try: |
23 | | - conn = connections.create_connection(database) |
24 | | - conn.connect() |
25 | | - usable = conn.is_usable() |
26 | | - if not usable: |
27 | | - logger.warning( |
28 | | - f"Unable to access database {database} during connection check" |
29 | | - ) |
30 | | - except Exception: |
31 | | - usable = False |
32 | | - logger.error( |
33 | | - "Encountered exception during connection", |
34 | | - exc_info=True, |
35 | | - ) |
36 | | - |
37 | | - if usable: |
38 | | - cache.set( |
39 | | - f"db_connection_active.{database}", "online", CONNECTION_CHECK_CACHE_TTL |
40 | | - ) |
41 | | - else: |
42 | | - cache.set( |
43 | | - f"db_connection_active.{database}", "offline", CONNECTION_CHECK_CACHE_TTL |
44 | | - ) |
45 | | - |
46 | | - return usable |
47 | | - |
48 | | - |
49 | | -class PrimaryReplicaRouter: |
50 | | - def db_for_read(self, model, **hints): # type: ignore[no-untyped-def] |
51 | | - if settings.NUM_DB_REPLICAS == 0: |
52 | | - return "default" |
53 | | - |
54 | | - replicas = [f"replica_{i}" for i in range(1, settings.NUM_DB_REPLICAS + 1)] |
55 | | - replica = self._get_replica(replicas) |
56 | | - if replica: |
57 | | - # This return is the most likely as replicas should be |
58 | | - # online and properly functioning. |
59 | | - return replica |
60 | | - |
61 | | - # Since no replicas are available, fall back to the cross |
62 | | - # region replicas which have worse availability. |
63 | | - cross_region_replicas = [ |
64 | | - f"cross_region_replica_{i}" |
65 | | - for i in range(1, settings.NUM_CROSS_REGION_DB_REPLICAS + 1) |
66 | | - ] |
67 | | - |
68 | | - cross_region_replica = self._get_replica(cross_region_replicas) |
69 | | - if cross_region_replica: |
70 | | - return cross_region_replica |
71 | | - |
72 | | - # No available replicas, so fallback to the default. |
73 | | - logger.warning( |
74 | | - "Unable to serve any available replicas, falling back to default database" |
75 | | - ) |
76 | | - return "default" |
77 | | - |
78 | | - def db_for_write(self, model, **hints): # type: ignore[no-untyped-def] |
79 | | - return "default" |
80 | | - |
81 | | - def allow_relation(self, obj1, obj2, **hints): # type: ignore[no-untyped-def] |
82 | | - """ |
83 | | - Relations between objects are allowed if both objects are |
84 | | - in the primary/replica pool. |
85 | | - """ |
86 | | - db_set = { |
87 | | - "default", |
88 | | - *[f"replica_{i}" for i in range(1, settings.NUM_DB_REPLICAS + 1)], |
89 | | - *[ |
90 | | - f"cross_region_replica_{i}" |
91 | | - for i in range(1, settings.NUM_CROSS_REGION_DB_REPLICAS + 1) |
92 | | - ], |
93 | | - } |
94 | | - if obj1._state.db in db_set and obj2._state.db in db_set: |
95 | | - return True |
96 | | - return None |
97 | | - |
98 | | - def allow_migrate(self, db, app_label, model_name=None, **hints): # type: ignore[no-untyped-def] |
99 | | - return db == "default" |
100 | | - |
101 | | - def _get_replica(self, replicas: list[str]) -> None | str: # type: ignore[return] |
102 | | - while replicas: |
103 | | - if settings.REPLICA_READ_STRATEGY == ReplicaReadStrategy.DISTRIBUTED: |
104 | | - database = random.choice(replicas) |
105 | | - elif settings.REPLICA_READ_STRATEGY == ReplicaReadStrategy.SEQUENTIAL: |
106 | | - database = replicas[0] |
107 | | - else: |
108 | | - raise ImproperlyConfiguredError( |
109 | | - f"Unknown REPLICA_READ_STRATEGY {settings.REPLICA_READ_STRATEGY}" |
110 | | - ) |
111 | | - |
112 | | - replicas.remove(database) |
113 | | - db_cache = cache.get(f"db_connection_active.{database}") |
114 | | - if db_cache == "online": |
115 | | - return database |
116 | | - if db_cache == "offline": |
117 | | - continue |
118 | | - |
119 | | - if connection_check(database): |
120 | | - return database |
| 5 | +AnalyticsDatabaseName = Literal["analytics"] |
121 | 6 |
|
122 | 7 |
|
123 | 8 | class AnalyticsRouter: |
124 | 9 | route_app_labels = ["app_analytics"] |
125 | 10 |
|
126 | | - def db_for_read(self, model, **hints): # type: ignore[no-untyped-def] |
127 | | - """ |
128 | | - Attempts to read analytics models go to 'analytics' database. |
129 | | - """ |
| 11 | + def db_for_read( |
| 12 | + self, model: type[Model], **hints: Any |
| 13 | + ) -> AnalyticsDatabaseName | None: |
| 14 | + """Route read queries to the 'analytics' database""" |
130 | 15 | if model._meta.app_label in self.route_app_labels: |
131 | 16 | return "analytics" |
132 | 17 | return None |
133 | 18 |
|
134 | | - def db_for_write(self, model, **hints): # type: ignore[no-untyped-def] |
135 | | - """ |
136 | | - Attempts to write analytics models go to 'analytics' database. |
137 | | - """ |
| 19 | + def db_for_write( |
| 20 | + self, model: type[Model], **hints: Any |
| 21 | + ) -> AnalyticsDatabaseName | None: |
| 22 | + """Route write queries to the 'analytics' database""" |
138 | 23 | if model._meta.app_label in self.route_app_labels: |
139 | 24 | return "analytics" |
140 | 25 | return None |
141 | 26 |
|
142 | | - def allow_relation(self, obj1, obj2, **hints): # type: ignore[no-untyped-def] |
143 | | - """ |
144 | | - Relations between objects are allowed if both objects are |
145 | | - in the analytics database. |
146 | | - """ |
| 27 | + def allow_relation(self, obj1: Model, obj2: Model, **hints: Any) -> bool | None: |
| 28 | + """Allow relations between analytics models""" |
147 | 29 | if ( |
148 | 30 | obj1._meta.app_label in self.route_app_labels |
149 | 31 | and obj2._meta.app_label in self.route_app_labels |
150 | 32 | ): |
151 | 33 | return True |
152 | 34 | return None |
153 | 35 |
|
154 | | - def allow_migrate(self, db, app_label, model_name=None, **hints): # type: ignore[no-untyped-def] |
155 | | - """ |
156 | | - Make sure the analytics app only appears in the 'analytics' database |
157 | | - """ |
158 | | - if app_label in self.route_app_labels: |
159 | | - if db != "default": |
160 | | - return db == "analytics" |
| 36 | + def allow_migrate(self, db: str, app_label: str, **hints: Any) -> bool | None: |
| 37 | + """Ensure the analytics database only gets analytics models""" |
| 38 | + if db == "analytics": |
| 39 | + return app_label in self.route_app_labels |
161 | 40 | return None |
0 commit comments