@@ -157,6 +157,18 @@ def make_query_plan(self, working_keyspace=None, query=None):
157157 """
158158 raise NotImplementedError ()
159159
160+ def make_query_plan_with_exclusion (self , working_keyspace = None , query = None , excluded = ()):
161+ """
162+ Same as :meth:`make_query_plan`, but with an additional `excluded` parameter.
163+ `excluded` should be a container (set, list, etc.) of hosts to skip.
164+
165+ The default implementation simply delegates to `make_query_plan` and filters the result.
166+ Subclasses may override this for performance.
167+ """
168+ for host in self .make_query_plan (working_keyspace , query ):
169+ if host not in excluded :
170+ yield host
171+
160172 def check_supported (self ):
161173 """
162174 This will be called after the cluster Metadata has been initialized.
@@ -198,6 +210,20 @@ def make_query_plan(self, working_keyspace=None, query=None):
198210 else :
199211 return []
200212
213+ def make_query_plan_with_exclusion (self , working_keyspace = None , query = None , excluded = ()):
214+ pos = self ._position
215+ self ._position += 1
216+
217+ hosts = self ._live_hosts
218+ length = len (hosts )
219+ if length :
220+ pos %= length
221+ for host in islice (cycle (hosts ), pos , pos + length ):
222+ if host not in excluded :
223+ yield host
224+ else :
225+ return
226+
201227 def on_up (self , host ):
202228 with self ._hosts_lock :
203229 self ._live_hosts = self ._live_hosts .union ((host , ))
@@ -297,6 +323,40 @@ def make_query_plan(self, working_keyspace=None, query=None):
297323 for host in remote_hosts :
298324 yield host
299325
326+ def make_query_plan_with_exclusion (self , working_keyspace = None , query = None , excluded = ()):
327+ # not thread-safe, but we don't care much about lost increments
328+ # for the purposes of load balancing
329+ pos = self ._position
330+ self ._position += 1
331+
332+ local_live = self ._dc_live_hosts .get (self .local_dc , ())
333+ length = len (local_live )
334+ remote_hosts = self ._remote_hosts
335+ if not excluded :
336+ if length :
337+ pos %= length
338+ for i in range (length ):
339+ yield local_live [(pos + i ) % length ]
340+ for host in remote_hosts :
341+ yield host
342+ return
343+
344+ if not isinstance (excluded , set ):
345+ excluded = set (excluded )
346+
347+ if length :
348+ pos %= length
349+ for i in range (length ):
350+ host = local_live [(pos + i ) % length ]
351+ if host in excluded :
352+ continue
353+ yield host
354+
355+ for host in remote_hosts :
356+ if host in excluded :
357+ continue
358+ yield host
359+
300360 def on_up (self , host ):
301361 # not worrying about threads because this will happen during
302362 # control connection startup/refresh
0 commit comments