2525
2626import com .datastax .oss .driver .api .core .ConsistencyLevel ;
2727import com .datastax .oss .driver .api .core .CqlIdentifier ;
28+ import com .datastax .oss .driver .api .core .RequestRoutingType ;
2829import com .datastax .oss .driver .api .core .config .DefaultDriverOption ;
2930import com .datastax .oss .driver .api .core .config .DriverExecutionProfile ;
3031import com .datastax .oss .driver .api .core .context .DriverContext ;
6364import edu .umd .cs .findbugs .annotations .NonNull ;
6465import edu .umd .cs .findbugs .annotations .Nullable ;
6566import java .nio .ByteBuffer ;
67+ import java .util .ArrayList ;
68+ import java .util .Collections ;
69+ import java .util .LinkedHashMap ;
6670import java .util .LinkedHashSet ;
6771import java .util .List ;
6872import java .util .Map ;
7175import java .util .Queue ;
7276import java .util .Set ;
7377import java .util .UUID ;
78+ import java .util .concurrent .ThreadLocalRandom ;
7479import java .util .concurrent .atomic .AtomicInteger ;
7580import java .util .function .IntUnaryOperator ;
81+ import java .util .stream .Collectors ;
7682import net .jcip .annotations .ThreadSafe ;
7783import org .slf4j .Logger ;
7884import org .slf4j .LoggerFactory ;
113119@ ThreadSafe
114120public class BasicLoadBalancingPolicy implements LoadBalancingPolicy {
115121
122+ public enum RequestRoutingMethod {
123+ REGULAR ,
124+ PRESERVE_REPLICA_ORDER
125+ }
126+
116127 private static final Logger LOG = LoggerFactory .getLogger (BasicLoadBalancingPolicy .class );
117128
118129 protected static final IntUnaryOperator INCREMENT = i -> (i == Integer .MAX_VALUE ) ? 0 : i + 1 ;
@@ -127,6 +138,7 @@ public class BasicLoadBalancingPolicy implements LoadBalancingPolicy {
127138 private final int maxNodesPerRemoteDc ;
128139 private final boolean allowDcFailoverForLocalCl ;
129140 private final ConsistencyLevel defaultConsistencyLevel ;
141+ private final RequestRoutingMethod lwtRequestRoutingMethod ;
130142
131143 // private because they should be set in init() and never be modified after
132144 private volatile DistanceReporter distanceReporter ;
@@ -154,6 +166,34 @@ public BasicLoadBalancingPolicy(@NonNull DriverContext context, @NonNull String
154166 new LinkedHashSet <>(
155167 profile .getStringList (
156168 DefaultDriverOption .LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS ));
169+ this .lwtRequestRoutingMethod = parseLwtRequestRoutingMethod ();
170+ }
171+
172+ @ NonNull
173+ private RequestRoutingMethod parseLwtRequestRoutingMethod () {
174+ String methodString =
175+ profile .getString (DefaultDriverOption .LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD );
176+ try {
177+ return RequestRoutingMethod .valueOf (methodString .toUpperCase ());
178+ } catch (IllegalArgumentException e ) {
179+ LOG .warn (
180+ "[{}] Unknown request routing method '{}', defaulting to PRESERVE_REPLICA_ORDER" ,
181+ logPrefix ,
182+ methodString );
183+ return RequestRoutingMethod .PRESERVE_REPLICA_ORDER ;
184+ }
185+ }
186+
187+ @ NonNull
188+ public RequestRoutingMethod getRequestRoutingMethod (@ Nullable Request request ) {
189+ if (request == null ) {
190+ return RequestRoutingMethod .REGULAR ;
191+ }
192+ if (request .getRequestRoutingType () == RequestRoutingType .LWT ) {
193+ return lwtRequestRoutingMethod ;
194+ } else {
195+ return RequestRoutingMethod .REGULAR ;
196+ }
157197 }
158198
159199 /**
@@ -260,6 +300,17 @@ protected NodeDistanceEvaluator createNodeDistanceEvaluator(
260300 @ NonNull
261301 @ Override
262302 public Queue <Node > newQueryPlan (@ Nullable Request request , @ Nullable Session session ) {
303+ switch (getRequestRoutingMethod (request )) {
304+ case PRESERVE_REPLICA_ORDER :
305+ return newQueryPlanPreserveReplicas (request , session );
306+ case REGULAR :
307+ default :
308+ return newQueryPlanRegular (request , session );
309+ }
310+ }
311+
312+ @ NonNull
313+ protected Queue <Node > newQueryPlanRegular (@ Nullable Request request , @ Nullable Session session ) {
263314 // Take a snapshot since the set is concurrent:
264315 Object [] currentNodes = liveNodes .dc (localDc ).toArray ();
265316
@@ -294,6 +345,116 @@ public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session ses
294345 return maybeAddDcFailover (request , plan );
295346 }
296347
348+ /**
349+ * Builds a query plan that preserves replica order: local replicas, remote replicas, local
350+ * non-replicas (rotated), remote non-replicas (rotated).
351+ */
352+ @ NonNull
353+ protected Queue <Node > newQueryPlanPreserveReplicas (
354+ @ Nullable Request request , @ Nullable Session session ) {
355+ List <Node > replicas = getReplicas (request , session );
356+ String localDc = getLocalDatacenter ();
357+ List <Node > queryPlan = new ArrayList <>();
358+
359+ if (localDc == null ) {
360+ // No local DC: all replicas first, then rotated non-replicas
361+ List <Node > allNodes = new ArrayList <>();
362+ for (Object obj : getLiveNodes ().dc (null ).toArray ()) {
363+ allNodes .add ((Node ) obj );
364+ }
365+ queryPlan .addAll (replicas );
366+ addRotatedNonReplicas (queryPlan , allNodes , replicas , request );
367+ } else {
368+ // With local DC: prioritize local, then remote
369+ Map <String , List <Node >> nodesByDc = getAllNodesByDc ();
370+ addReplicasByDc (queryPlan , replicas , localDc );
371+ addNonReplicasByDc (queryPlan , nodesByDc , replicas , localDc , request );
372+ }
373+
374+ return new SimpleQueryPlan (queryPlan .toArray ());
375+ }
376+
377+ /** Collect all live nodes grouped by DC, with preferred remote DCs ordered first. */
378+ private Map <String , List <Node >> getAllNodesByDc () {
379+ Map <String , List <Node >> nodesByDc = new LinkedHashMap <>();
380+ Set <String > allDcs = getLiveNodes ().dcs ();
381+ // Add preferred remote DCs first (in configured order)
382+ for (String dc : preferredRemoteDcs ) {
383+ if (allDcs .contains (dc )) {
384+ nodesByDc .put (dc , dcNodeList (dc ));
385+ }
386+ }
387+ // Add remaining DCs (sorted for deterministic ordering)
388+ allDcs .stream ()
389+ .sorted ()
390+ .filter (dc -> !nodesByDc .containsKey (dc ))
391+ .forEach (dc -> nodesByDc .put (dc , dcNodeList (dc )));
392+ return nodesByDc ;
393+ }
394+
395+ private List <Node > dcNodeList (String dc ) {
396+ List <Node > dcNodes = new ArrayList <>();
397+ for (Object obj : getLiveNodes ().dc (dc ).toArray ()) {
398+ dcNodes .add ((Node ) obj );
399+ }
400+ return dcNodes ;
401+ }
402+
403+ /** Add replicas with local DC first, then remote DCs. */
404+ private void addReplicasByDc (List <Node > queryPlan , List <Node > replicas , String localDc ) {
405+ replicas .stream ()
406+ .filter (r -> Objects .equals (r .getDatacenter (), localDc ))
407+ .forEach (queryPlan ::add );
408+ replicas .stream ()
409+ .filter (r -> !Objects .equals (r .getDatacenter (), localDc ))
410+ .forEach (queryPlan ::add );
411+ }
412+
413+ /** Add non-replicas with local DC first, then remote DCs (all rotated). */
414+ private void addNonReplicasByDc (
415+ List <Node > queryPlan ,
416+ Map <String , List <Node >> nodesByDc ,
417+ List <Node > replicas ,
418+ String localDc ,
419+ Request request ) {
420+ // Local DC non-replicas first
421+ List <Node > localNodes = nodesByDc .get (localDc );
422+ if (localNodes != null ) {
423+ addRotatedNonReplicas (queryPlan , localNodes , replicas , request );
424+ }
425+ // Remote DC non-replicas
426+ for (Map .Entry <String , List <Node >> entry : nodesByDc .entrySet ()) {
427+ if (!Objects .equals (entry .getKey (), localDc )) {
428+ addRotatedNonReplicas (queryPlan , entry .getValue (), replicas , request );
429+ }
430+ }
431+ }
432+
433+ /** Add non-replica nodes from given list with rotation. */
434+ private void addRotatedNonReplicas (
435+ List <Node > queryPlan , List <Node > nodes , List <Node > replicas , Request request ) {
436+ List <Node > nonReplicas =
437+ nodes .stream ().filter (n -> !replicas .contains (n )).collect (Collectors .toList ());
438+ if (!nonReplicas .isEmpty ()) {
439+ rotateNonReplicas (nonReplicas , request );
440+ queryPlan .addAll (nonReplicas );
441+ }
442+ }
443+
444+ /** Rotates nodes based on routing key (consistent) or randomly. */
445+ private void rotateNonReplicas (List <Node > nodes , @ Nullable Request request ) {
446+ if (nodes .size () <= 1 ) return ;
447+
448+ int rotationAmount =
449+ (request != null && request .getRoutingKey () != null )
450+ ? (request .getRoutingKey ().hashCode () & 0x7fffffff ) % nodes .size ()
451+ : randomNextInt (nodes .size ());
452+
453+ if (rotationAmount > 0 ) {
454+ Collections .rotate (nodes , -rotationAmount );
455+ }
456+ }
457+
297458 @ NonNull
298459 protected List <Node > getReplicas (@ Nullable Request request , @ Nullable Session session ) {
299460 if (request == null || session == null ) {
@@ -441,6 +602,11 @@ protected Object[] computeNodes() {
441602 return new CompositeQueryPlan (queryPlans );
442603 }
443604
605+ /** Exposed as a protected method so that it can be accessed by tests */
606+ protected int randomNextInt (int bound ) {
607+ return ThreadLocalRandom .current ().nextInt (bound );
608+ }
609+
444610 /** Exposed as a protected method so that it can be accessed by tests */
445611 protected void shuffleHead (Object [] currentNodes , int headLength ) {
446612 ArrayUtils .shuffleHead (currentNodes , headLength );
0 commit comments