2525
2626import com .datastax .oss .driver .api .core .ConsistencyLevel ;
2727import com .datastax .oss .driver .api .core .CqlIdentifier ;
28+ import com .datastax .oss .driver .api .core .RequestRoutingType ;
2829import com .datastax .oss .driver .api .core .config .DefaultDriverOption ;
2930import com .datastax .oss .driver .api .core .config .DriverExecutionProfile ;
3031import com .datastax .oss .driver .api .core .context .DriverContext ;
6364import edu .umd .cs .findbugs .annotations .NonNull ;
6465import edu .umd .cs .findbugs .annotations .Nullable ;
6566import java .nio .ByteBuffer ;
67+ import java .util .ArrayList ;
68+ import java .util .Collections ;
69+ import java .util .LinkedHashMap ;
6670import java .util .LinkedHashSet ;
6771import java .util .List ;
6872import java .util .Map ;
7175import java .util .Queue ;
7276import java .util .Set ;
7377import java .util .UUID ;
78+ import java .util .concurrent .ThreadLocalRandom ;
7479import java .util .concurrent .atomic .AtomicInteger ;
7580import java .util .function .IntUnaryOperator ;
81+ import java .util .stream .Collectors ;
7682import net .jcip .annotations .ThreadSafe ;
7783import org .slf4j .Logger ;
7884import org .slf4j .LoggerFactory ;
113119@ ThreadSafe
114120public class BasicLoadBalancingPolicy implements LoadBalancingPolicy {
115121
122+ public enum RequestRoutingMethod {
123+ REGULAR ,
124+ PRESERVE_REPLICA_ORDER
125+ }
126+
116127 private static final Logger LOG = LoggerFactory .getLogger (BasicLoadBalancingPolicy .class );
117128
118129 protected static final IntUnaryOperator INCREMENT = i -> (i == Integer .MAX_VALUE ) ? 0 : i + 1 ;
@@ -127,6 +138,7 @@ public class BasicLoadBalancingPolicy implements LoadBalancingPolicy {
127138 private final int maxNodesPerRemoteDc ;
128139 private final boolean allowDcFailoverForLocalCl ;
129140 private final ConsistencyLevel defaultConsistencyLevel ;
141+ private final RequestRoutingMethod lwtRequestRoutingMethod ;
130142
131143 // private because they should be set in init() and never be modified after
132144 private volatile DistanceReporter distanceReporter ;
@@ -154,6 +166,34 @@ public BasicLoadBalancingPolicy(@NonNull DriverContext context, @NonNull String
154166 new LinkedHashSet <>(
155167 profile .getStringList (
156168 DefaultDriverOption .LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS ));
169+ this .lwtRequestRoutingMethod = parseLwtRequestRoutingMethod ();
170+ }
171+
172+ @ NonNull
173+ private RequestRoutingMethod parseLwtRequestRoutingMethod () {
174+ String methodString =
175+ profile .getString (DefaultDriverOption .LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD );
176+ try {
177+ return RequestRoutingMethod .valueOf (methodString .toUpperCase ());
178+ } catch (IllegalArgumentException e ) {
179+ LOG .warn (
180+ "[{}] Unknown request routing method '{}', defaulting to PRESERVE_REPLICA_ORDER" ,
181+ logPrefix ,
182+ methodString );
183+ return RequestRoutingMethod .PRESERVE_REPLICA_ORDER ;
184+ }
185+ }
186+
187+ @ NonNull
188+ public RequestRoutingMethod getRequestRoutingMethod (@ Nullable Request request ) {
189+ if (request == null ) {
190+ return RequestRoutingMethod .REGULAR ;
191+ }
192+ if (request .getRequestRoutingType () == RequestRoutingType .LWT ) {
193+ return lwtRequestRoutingMethod ;
194+ } else {
195+ return RequestRoutingMethod .REGULAR ;
196+ }
157197 }
158198
159199 /**
@@ -260,6 +300,17 @@ protected NodeDistanceEvaluator createNodeDistanceEvaluator(
260300 @ NonNull
261301 @ Override
262302 public Queue <Node > newQueryPlan (@ Nullable Request request , @ Nullable Session session ) {
303+ switch (getRequestRoutingMethod (request )) {
304+ case PRESERVE_REPLICA_ORDER :
305+ return newQueryPlanPreserveReplicas (request , session );
306+ case REGULAR :
307+ default :
308+ return newQueryPlanRegular (request , session );
309+ }
310+ }
311+
312+ @ NonNull
313+ protected Queue <Node > newQueryPlanRegular (@ Nullable Request request , @ Nullable Session session ) {
263314 // Take a snapshot since the set is concurrent:
264315 Object [] currentNodes = liveNodes .dc (localDc ).toArray ();
265316
@@ -294,6 +345,117 @@ public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session ses
294345 return maybeAddDcFailover (request , plan );
295346 }
296347
348+ /**
349+ * Builds a query plan that preserves replica order: local replicas, remote replicas, local
350+ * non-replicas (rotated), remote non-replicas (rotated).
351+ */
352+ @ NonNull
353+ protected Queue <Node > newQueryPlanPreserveReplicas (
354+ @ Nullable Request request , @ Nullable Session session ) {
355+ List <Node > replicas = getReplicas (request , session );
356+ String localDc = getLocalDatacenter ();
357+ List <Node > queryPlan = new ArrayList <>();
358+
359+ if (localDc == null ) {
360+ // No local DC: all replicas first, then rotated non-replicas
361+ List <Node > allNodes = new ArrayList <>();
362+ for (Object obj : getLiveNodes ().dc (null ).toArray ()) {
363+ allNodes .add ((Node ) obj );
364+ }
365+ queryPlan .addAll (replicas );
366+ addRotatedNonReplicas (queryPlan , allNodes , replicas , request );
367+ } else {
368+ // With local DC: prioritize local, then remote
369+ Map <String , List <Node >> nodesByDc = getAllNodesByDc ();
370+ addReplicasByDc (queryPlan , replicas , localDc );
371+ addNonReplicasByDc (queryPlan , nodesByDc , replicas , localDc , request );
372+ }
373+
374+ return new SimpleQueryPlan (queryPlan .toArray ());
375+ }
376+
377+ /** Collect all live nodes grouped by DC, with preferred remote DCs ordered first. */
378+ private Map <String , List <Node >> getAllNodesByDc () {
379+ Map <String , List <Node >> nodesByDc = new LinkedHashMap <>();
380+ Set <String > allDcs = getLiveNodes ().dcs ();
381+ // Add preferred remote DCs first (in configured order)
382+ for (String dc : preferredRemoteDcs ) {
383+ if (allDcs .contains (dc )) {
384+ nodesByDc .put (dc , dcNodeList (dc ));
385+ }
386+ }
387+ // Add remaining DCs
388+ for (String dc : allDcs ) {
389+ if (!nodesByDc .containsKey (dc )) {
390+ nodesByDc .put (dc , dcNodeList (dc ));
391+ }
392+ }
393+ return nodesByDc ;
394+ }
395+
396+ private List <Node > dcNodeList (String dc ) {
397+ List <Node > dcNodes = new ArrayList <>();
398+ for (Object obj : getLiveNodes ().dc (dc ).toArray ()) {
399+ dcNodes .add ((Node ) obj );
400+ }
401+ return dcNodes ;
402+ }
403+
404+ /** Add replicas with local DC first, then remote DCs. */
405+ private void addReplicasByDc (List <Node > queryPlan , List <Node > replicas , String localDc ) {
406+ replicas .stream ()
407+ .filter (r -> Objects .equals (r .getDatacenter (), localDc ))
408+ .forEach (queryPlan ::add );
409+ replicas .stream ()
410+ .filter (r -> !Objects .equals (r .getDatacenter (), localDc ))
411+ .forEach (queryPlan ::add );
412+ }
413+
414+ /** Add non-replicas with local DC first, then remote DCs (all rotated). */
415+ private void addNonReplicasByDc (
416+ List <Node > queryPlan ,
417+ Map <String , List <Node >> nodesByDc ,
418+ List <Node > replicas ,
419+ String localDc ,
420+ Request request ) {
421+ // Local DC non-replicas first
422+ List <Node > localNodes = nodesByDc .get (localDc );
423+ if (localNodes != null ) {
424+ addRotatedNonReplicas (queryPlan , localNodes , replicas , request );
425+ }
426+ // Remote DC non-replicas
427+ for (Map .Entry <String , List <Node >> entry : nodesByDc .entrySet ()) {
428+ if (!Objects .equals (entry .getKey (), localDc )) {
429+ addRotatedNonReplicas (queryPlan , entry .getValue (), replicas , request );
430+ }
431+ }
432+ }
433+
434+ /** Add non-replica nodes from given list with rotation. */
435+ private void addRotatedNonReplicas (
436+ List <Node > queryPlan , List <Node > nodes , List <Node > replicas , Request request ) {
437+ List <Node > nonReplicas =
438+ nodes .stream ().filter (n -> !replicas .contains (n )).collect (Collectors .toList ());
439+ if (!nonReplicas .isEmpty ()) {
440+ rotateNonReplicas (nonReplicas , request );
441+ queryPlan .addAll (nonReplicas );
442+ }
443+ }
444+
445+ /** Rotates nodes based on routing key (consistent) or randomly. */
446+ private void rotateNonReplicas (List <Node > nodes , @ Nullable Request request ) {
447+ if (nodes .size () <= 1 ) return ;
448+
449+ int rotationAmount =
450+ (request != null && request .getRoutingKey () != null )
451+ ? (request .getRoutingKey ().hashCode () & 0x7fffffff ) % nodes .size ()
452+ : randomNextInt (nodes .size ());
453+
454+ if (rotationAmount > 0 ) {
455+ Collections .rotate (nodes , -rotationAmount );
456+ }
457+ }
458+
297459 @ NonNull
298460 protected List <Node > getReplicas (@ Nullable Request request , @ Nullable Session session ) {
299461 if (request == null || session == null ) {
@@ -441,6 +603,11 @@ protected Object[] computeNodes() {
441603 return new CompositeQueryPlan (queryPlans );
442604 }
443605
606+ /** Exposed as a protected method so that it can be accessed by tests */
607+ protected int randomNextInt (int bound ) {
608+ return ThreadLocalRandom .current ().nextInt (bound );
609+ }
610+
444611 /** Exposed as a protected method so that it can be accessed by tests */
445612 protected void shuffleHead (Object [] currentNodes , int headLength ) {
446613 ArrayUtils .shuffleHead (currentNodes , headLength );
0 commit comments