@@ -823,7 +823,7 @@ def preserve_chunks(
823823 ... typesize=8,
824824 ... previous_chunks=(128, 128, 1),
825825 ... )
826- (512, 256 , 1)
826+ (128, 1024 , 1)
827827
828828 >>> ChunkManagerEntrypoint.preserve_chunks(
829829 ... chunks=("preserve", "preserve", 1),
@@ -861,6 +861,7 @@ def preserve_chunks(
861861 if isinstance (previous_chunk , tuple ):
862862 # For uniform chunks just take the first item
863863 if previous_chunk [1 :- 1 ] == previous_chunk [:- 2 ]:
864+ new_chunks [i ] = previous_chunk [0 ]
864865 previous_chunk = previous_chunk [0 ]
865866 # For non-uniform chunks, leave them alone
866867 else :
@@ -880,14 +881,13 @@ def preserve_chunks(
880881 return chunks
881882
882883 while True :
883- # Repeatedly look for the dim with the most chunks and multiply it by 2.
884+ # Repeatedly look for the last dim with more than one chunk and multiply it by 2.
884885 # Stop when:
885886 # 1a. we are larger than the target chunk size OR
886887 # 1b. we are within 50% of the target chunk size OR
887888 # 2. the chunk covers the entire array
888889
889890 num_chunks = np .array (shape ) / max_chunks * auto_dims
890- idx = np .argmax (num_chunks )
891891 chunk_bytes = np .prod (max_chunks ) * typesize
892892
893893 if chunk_bytes > target or abs (chunk_bytes - target ) / target < 0.5 :
@@ -896,6 +896,8 @@ def preserve_chunks(
896896 if (num_chunks <= 1 ).all ():
897897 break
898898
899+ idx = int (np .nonzero (num_chunks > 1 )[0 ][- 1 ])
900+
899901 new_chunks [idx ] = min (new_chunks [idx ] * 2 , shape [idx ])
900902 max_chunks [idx ] = new_chunks [idx ]
901903
0 commit comments