2424"""
2525import contextlib
2626import io
27+ from collections .abc import Callable
28+ from collections .abc import Generator
2729from dataclasses import dataclass
2830from itertools import combinations_with_replacement
2931from itertools import permutations
3032from itertools import product
3133from typing import Any
32- from typing import Callable
33- from typing import Dict
34- from typing import Generator
35- from typing import List
36- from typing import Tuple
37- from typing import Union
3834
3935import msprime
4036import numpy as np
@@ -224,7 +220,7 @@ def norm_hap_weighted(
224220 n_a : int ,
225221 n_b : int ,
226222 result : np .ndarray ,
227- params : Dict [str , Any ],
223+ params : dict [str , Any ],
228224) -> None :
229225 """Create a vector of normalizing coefficients, length of the number of
230226 sample sets. In this normalization strategy, we weight each allele's
@@ -250,7 +246,7 @@ def norm_hap_weighted_ij(
250246 n_a : int ,
251247 n_b : int ,
252248 result : np .ndarray ,
253- params : Dict [str , Any ],
249+ params : dict [str , Any ],
254250) -> None :
255251 """
256252 Create a vector of normalizing coefficients, length of the number of
@@ -286,7 +282,7 @@ def norm_total_weighted(
286282 n_a : int ,
287283 n_b : int ,
288284 result : np .ndarray ,
289- params : Dict [str , Any ],
285+ params : dict [str , Any ],
290286) -> None :
291287 """Create a vector of normalizing coefficients, length of the number of
292288 sample sets. In this normalization strategy, we weight each allele's
@@ -332,7 +328,7 @@ def check_order_bounds_dups(values, max_value):
332328
333329def get_site_row_col_indices (
334330 row_sites : np .ndarray , col_sites : np .ndarray
335- ) -> Tuple [ List [int ], List [int ], List [int ]]:
331+ ) -> tuple [ list [int ], list [int ], list [int ]]:
336332 """Co-iterate over the row and column sites, keeping a sorted union of
337333 site values and an index into the unique list of sites for both the row
338334 and column sites. This function produces a list of sites of interest and
@@ -448,8 +444,8 @@ def get_allele_samples(
448444
449445
450446def get_mutation_samples (
451- ts : tskit .TreeSequence , sites : List [int ], sample_index_map : np .ndarray
452- ) -> Tuple [np .ndarray , np .ndarray , BitSet ]:
447+ ts : tskit .TreeSequence , sites : list [int ], sample_index_map : np .ndarray
448+ ) -> tuple [np .ndarray , np .ndarray , BitSet ]:
453449 """For a given set of sites, generate a BitSet of all samples posessing
454450 each allelic state for each site. This includes the ancestral state, along
455451 with any mutations contained in the site.
@@ -507,8 +503,8 @@ def get_mutation_samples(
507503 return num_alleles , site_offsets , allele_samples
508504
509505
510- SummaryFunc = Callable [[int , np .ndarray , int , np .ndarray , Dict [str , Any ]], None ]
511- NormFunc = Callable [[int , np .ndarray , int , int , np .ndarray , Dict [str , Any ]], None ]
506+ SummaryFunc = Callable [[int , np .ndarray , int , np .ndarray , dict [str , Any ]], None ]
507+ NormFunc = Callable [[int , np .ndarray , int , int , np .ndarray , dict [str , Any ]], None ]
512508
513509
514510def compute_general_two_site_stat_result (
@@ -523,7 +519,7 @@ def compute_general_two_site_stat_result(
523519 result_dim : int ,
524520 func : SummaryFunc ,
525521 norm_func : NormFunc ,
526- params : Dict [str , Any ],
522+ params : dict [str , Any ],
527523 polarised : bool ,
528524 result : np .ndarray ,
529525) -> None :
@@ -777,8 +773,8 @@ def two_branch_count_stat(
777773
778774
779775def sample_sets_to_bit_array (
780- ts : tskit .TreeSequence , sample_sets : Union [ List [ List [ int ]], List [np .ndarray ] ]
781- ) -> Tuple [np .ndarray , np .ndarray , BitSet ]:
776+ ts : tskit .TreeSequence , sample_sets : list [ list [ int ]] | list [np .ndarray ]
777+ ) -> tuple [np .ndarray , np .ndarray , BitSet ]:
782778 """Convert the list of sample ids to a bit array. This function takes
783779 sample identifiers and maps them to their enumerated integer values, then
784780 stores these values in a bit array. We produce a BitArray and a numpy
@@ -994,7 +990,7 @@ def r2_summary_func(
994990 state : np .ndarray ,
995991 result_dim : int ,
996992 result : np .ndarray ,
997- params : Dict [str , Any ],
993+ params : dict [str , Any ],
998994) -> None :
999995 """Summary function for the r2 statistic. We first compute the proportion of
1000996 AB, A, and B haplotypes, then we compute the r2 statistic, storing the outputs
@@ -1028,7 +1024,7 @@ def r2_ij_summary_func(
10281024 state : np .ndarray ,
10291025 result_dim : int ,
10301026 result : np .ndarray ,
1031- params : Dict [str , Any ],
1027+ params : dict [str , Any ],
10321028) -> None :
10331029 sample_set_sizes = params ["sample_set_sizes" ]
10341030 set_indexes = params ["set_indexes" ]
@@ -1062,7 +1058,7 @@ def D_summary_func(
10621058 state : np .ndarray ,
10631059 result_dim : int ,
10641060 result : np .ndarray ,
1065- params : Dict [str , Any ],
1061+ params : dict [str , Any ],
10661062) -> None :
10671063 sample_set_sizes = params ["sample_set_sizes" ]
10681064 for k in range (state_dim ):
@@ -1082,7 +1078,7 @@ def D2_summary_func(
10821078 state : np .ndarray ,
10831079 result_dim : int ,
10841080 result : np .ndarray ,
1085- params : Dict [str , Any ],
1081+ params : dict [str , Any ],
10861082) -> None :
10871083 sample_set_sizes = params ["sample_set_sizes" ]
10881084 for k in range (state_dim ):
@@ -1103,7 +1099,7 @@ def D_prime_summary_func(
11031099 state : np .ndarray ,
11041100 result_dim : int ,
11051101 result : np .ndarray ,
1106- params : Dict [str , Any ],
1102+ params : dict [str , Any ],
11071103) -> None :
11081104 sample_set_sizes = params ["sample_set_sizes" ]
11091105 for k in range (state_dim ):
@@ -1128,7 +1124,7 @@ def r_summary_func(
11281124 state : np .ndarray ,
11291125 result_dim : int ,
11301126 result : np .ndarray ,
1131- params : Dict [str , Any ],
1127+ params : dict [str , Any ],
11321128) -> None :
11331129 sample_set_sizes = params ["sample_set_sizes" ]
11341130 for k in range (state_dim ):
@@ -1152,7 +1148,7 @@ def Dz_summary_func(
11521148 state : np .ndarray ,
11531149 result_dim : int ,
11541150 result : np .ndarray ,
1155- params : Dict [str , Any ],
1151+ params : dict [str , Any ],
11561152) -> None :
11571153 sample_set_sizes = params ["sample_set_sizes" ]
11581154 for k in range (state_dim ):
@@ -1174,7 +1170,7 @@ def pi2_summary_func(
11741170 state : np .ndarray ,
11751171 result_dim : int ,
11761172 result : np .ndarray ,
1177- params : Dict [str , Any ],
1173+ params : dict [str , Any ],
11781174) -> None :
11791175 sample_set_sizes = params ["sample_set_sizes" ]
11801176 for k in range (state_dim ):
@@ -1205,7 +1201,7 @@ def pi2_unbiased_summary_func(
12051201 state : np .ndarray ,
12061202 result_dim : int ,
12071203 result : np .ndarray ,
1208- params : Dict [str , Any ],
1204+ params : dict [str , Any ],
12091205):
12101206 sample_set_sizes = params ["sample_set_sizes" ]
12111207 for k in range (state_dim ):
@@ -1227,7 +1223,7 @@ def Dz_unbiased_summary_func(
12271223 state : np .ndarray ,
12281224 result_dim : int ,
12291225 result : np .ndarray ,
1230- params : Dict [str , Any ],
1226+ params : dict [str , Any ],
12311227):
12321228 sample_set_sizes = params ["sample_set_sizes" ]
12331229 for k in range (state_dim ):
@@ -1253,7 +1249,7 @@ def D2_unbiased_summary_func(
12531249 state : np .ndarray ,
12541250 result_dim : int ,
12551251 result : np .ndarray ,
1256- params : Dict [str , Any ],
1252+ params : dict [str , Any ],
12571253):
12581254 sample_set_sizes = params ["sample_set_sizes" ]
12591255 for k in range (state_dim ):
@@ -1275,7 +1271,7 @@ def D2_ij_summary_func(
12751271 state : np .ndarray ,
12761272 result_dim : int ,
12771273 result : np .ndarray ,
1278- params : Dict [str , Any ],
1274+ params : dict [str , Any ],
12791275):
12801276 sample_set_sizes = params ["sample_set_sizes" ]
12811277 set_indexes = params ["set_indexes" ]
@@ -1307,7 +1303,7 @@ def D2_ij_unbiased_summary_func(
13071303 state : np .ndarray ,
13081304 result_dim : int ,
13091305 result : np .ndarray ,
1310- params : Dict [str , Any ],
1306+ params : dict [str , Any ],
13111307):
13121308 sample_set_sizes = params ["sample_set_sizes" ]
13131309 set_indexes = params ["set_indexes" ]
@@ -1831,8 +1827,8 @@ class TreeState:
18311827 # 0 1
18321828 # 1 0
18331829 # 1 1
1834- edges_out : List [int ] # list of edges removed during iteration
1835- edges_in : List [int ] # list of edges added during iteration
1830+ edges_out : list [int ] # list of edges removed during iteration
1831+ edges_in : list [int ] # list of edges added during iteration
18361832
18371833 def __init__ (self , ts , sample_sets , num_sample_sets , sample_index_map ):
18381834 self .pos = tsutil .TreeIndexes (ts )
0 commit comments