99logger = logging .getLogger (__name__ )
1010
1111
12- import numpy as np
13- import tskit
14-
def individual_nodes(ts):
    """
    Convert a tree sequence with individuals to a 2D array of node IDs.

    Parameters
    ----------
    ts : tskit.TreeSequence
        The tree sequence to convert

    Returns
    -------
    numpy.ndarray
        Array of shape (num_individuals, max_ploidy) containing node IDs.
        Values of -1 indicate unused slots for individuals with ploidy
        less than the maximum.

    Raises
    ------
    ValueError
        If the tree sequence has no individuals, if it has no samples or
        none of its samples refer to an individual, if any sample doesn't
        have an individual, if an individual has no nodes, or if an
        individual has nodes that are both samples and non-samples.
    """
    if ts.num_individuals == 0:
        raise ValueError("Tree sequence has no individuals")

    individuals = np.unique(ts.nodes_individual[ts.samples()])
    # An empty array means the tree sequence has no samples at all; check
    # this before indexing so the caller gets a ValueError, not an IndexError.
    if len(individuals) == 0 or (
        len(individuals) == 1 and individuals[0] == tskit.NULL
    ):
        raise ValueError("No samples refer to individuals")

    # np.unique sorts the argument, so if NULL (-1) is present it will be first
    if individuals[0] == tskit.NULL:
        raise ValueError("Sample nodes must all be associated with individuals")

    # Validate every individual and remember its nodes, so that each
    # individual is only fetched from the tree sequence once.
    nodes_per_individual = []
    for i in range(ts.num_individuals):
        ind = ts.individual(i)
        if len(ind.nodes) == 0:
            raise ValueError(f"Individual {i} not associated with any nodes")

        is_sample = {ts.node(u).is_sample() for u in ind.nodes}
        if len(is_sample) != 1:
            raise ValueError(
                f"Individual {ind.id} has nodes that are sample and non-samples"
            )
        nodes_per_individual.append(ind.nodes)

    max_ploidy = max(len(nodes) for nodes in nodes_per_individual)

    # Initialize output array with -1 (indicating no node)
    result = np.full((ts.num_individuals, max_ploidy), -1, dtype=np.int32)
    for i, nodes in enumerate(nodes_per_individual):
        result[i, : len(nodes)] = nodes

    return result
74-
7512class TskitFormat (vcz .Source ):
def __init__(
    self,
    ts_path,
    individual_nodes,
    sample_ids=None,
    contig_id=None,
    isolated_as_missing=False,
):
    """
    Load a tree sequence and precompute the sample-to-node index mapping.

    Parameters
    ----------
    ts_path
        Path handed to ``tskit.load``.
    individual_nodes : array-like
        2D array with one row per output sample and one column per ploidy
        slot, holding node IDs. Negative entries mark unused slots.
    sample_ids : sequence, optional
        One ID per row of ``individual_nodes``. Defaults to
        ``tsk_0, tsk_1, ...``.
    contig_id : optional
        Contig identifier; defaults to ``"1"``.
    isolated_as_missing : bool, optional
        Stored on the instance and forwarded when iterating variants.

    Raises
    ------
    ValueError
        If ``individual_nodes`` is not 2D, has no rows, or contains no
        non-negative node ID, or if ``sample_ids`` has the wrong length.
    """
    self._path = ts_path
    self.ts = tskit.load(ts_path)
    self.contig_id = contig_id if contig_id is not None else "1"
    self.isolated_as_missing = isolated_as_missing

    self.positions = self.ts.sites_position

    # Accept any array-like (e.g. nested lists) and reject shapes that
    # would otherwise fail with an IndexError on ``shape[1]``.
    individual_nodes = np.asarray(individual_nodes)
    if individual_nodes.ndim != 2:
        raise ValueError("individual_nodes must be a 2D array")

    self._num_samples = individual_nodes.shape[0]
    if self._num_samples < 1:
        raise ValueError("individual_nodes must have at least one sample")
    self.max_ploidy = individual_nodes.shape[1]
    if sample_ids is None:
        sample_ids = [f"tsk_{j}" for j in range(self._num_samples)]
    elif len(sample_ids) != self._num_samples:
        raise ValueError(
            f"Length of sample_ids ({len(sample_ids)}) does not match "
            f"number of samples ({self._num_samples})"
        )

    self._samples = [vcz.Sample(id=sample_id) for sample_id in sample_ids]

    valid_mask = individual_nodes >= 0
    valid_nodes = individual_nodes[valid_mask]
    self.tskit_samples = np.unique(valid_nodes)
    if len(self.tskit_samples) < 1:
        raise ValueError("individual_nodes must have at least one valid sample")
    # Positions in the (samples, ploidy) genotype array that get filled.
    self.sample_indices, self.ploidy_indices = np.where(valid_mask)
    # np.unique returns a sorted array, so searchsorted gives the exact
    # position of each node ID within tskit_samples.
    self.genotype_indices = np.searchsorted(self.tskit_samples, valid_nodes)
51+
@property
def path(self):
    """Return the tree sequence path this source was constructed with."""
    source_path = self._path
    return source_path
@@ -106,21 +73,6 @@ def root_attrs(self):
def contigs(self):
    """Return a one-element list with the contig built from ``self.contig_id``."""
    only_contig = vcz.Contig(id=self.contig_id)
    return [only_contig]
10875
def _make_sample_mapping(self):
    """
    Build the individual-to-node mapping and the sample list from ``self.ts``.

    Sets ``node_ids_array`` (2D array of node IDs per individual),
    ``_num_samples``, ``max_ploidy`` and ``_samples``.

    Raises
    ------
    ValueError
        If ``individual_nodes`` rejects the tree sequence; the original
        error is chained as the cause.
    """
    tree_seq = self.ts

    # Only the mapping call can raise ValueError, so keep the try minimal.
    try:
        node_ids = individual_nodes(tree_seq)
    except ValueError as e:
        raise ValueError(f"Error mapping individuals to nodes: {e}") from e

    self.node_ids_array = node_ids
    self._num_samples = tree_seq.num_individuals
    self.max_ploidy = node_ids.shape[1]

    sample_list = []
    for j in range(self.num_samples):
        sample_list.append(vcz.Sample(id=f"tsk_{j}"))
    self._samples = sample_list
123-
def iter_contig(self, start, stop):
    """Yield the contig index, always 0, once per record in ``[start, stop)``."""
    for _ in range(start, stop):
        yield 0
12678
@@ -139,6 +91,7 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
13991 isolated_as_missing = self .isolated_as_missing ,
14092 left = self .positions [start ],
14193 right = self .positions [stop ] if stop < self .num_records else None ,
94+ samples = self .tskit_samples ,
14295 ):
14396 gt = np .full (shape , constants .INT_FILL , dtype = np .int8 )
14497 alleles = np .full (num_alleles , constants .STR_FILL , dtype = "O" )
@@ -149,12 +102,9 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
149102 assert i < num_alleles
150103 alleles [i ] = allele
151104
152- # For each individual, get genotypes for their nodes
153- for i in range (self .num_samples ):
154- for j in range (self .max_ploidy ):
155- node_id = self .node_ids_array [i , j ]
156- if node_id >= 0 : # Skip -1 entries (unused slots)
157- gt [i , j ] = variant .genotypes [node_id ]
105+ gt [self .sample_indices , self .ploidy_indices ] = variant .genotypes [
106+ self .genotype_indices
107+ ]
158108
159109 yield alleles , (gt , phased )
160110
@@ -253,7 +203,9 @@ def generate_schema(
253203def convert (
254204 ts_path ,
255205 zarr_path ,
206+ individual_nodes ,
256207 * ,
208+ sample_ids = None ,
257209 contig_id = None ,
258210 isolated_as_missing = False ,
259211 variants_chunk_size = None ,
@@ -263,6 +215,8 @@ def convert(
263215):
264216 tskit_format = TskitFormat (
265217 ts_path ,
218+ individual_nodes ,
219+ sample_ids = sample_ids ,
266220 contig_id = contig_id ,
267221 isolated_as_missing = isolated_as_missing ,
268222 )
0 commit comments