|
25 | 25 | """ |
26 | 26 | import numpy as np |
27 | 27 |
|
28 | | -import tskit |
29 | 28 | from . import provenance |
30 | 29 |
|
31 | 30 |
|
@@ -140,53 +139,32 @@ def __make_sample_mapping(self, ploidy, individuals): |
140 | 139 | raise ValueError( |
141 | 140 | "Cannot specify ploidy when individuals are present in tables " |
142 | 141 | ) |
143 | | - |
144 | | - if individuals is None: |
145 | | - # Find all sample nodes that reference individuals |
146 | | - individuals = np.unique(ts.nodes_individual[ts.samples()]) |
147 | | - if len(individuals) == 1 and individuals[0] == tskit.NULL: |
148 | | - # No samples refer to individuals |
149 | | - individuals = None |
150 | | - else: |
151 | | - # np.unique sorts the argument, so if NULL (-1) is present it |
152 | | - # will be the first value. |
153 | | - if individuals[0] == tskit.NULL: |
154 | | - raise ValueError( |
155 | | - "Sample nodes must either all be associated with individuals " |
156 | | - "or not associated with any individuals" |
157 | | - ) |
| 142 | + # If there are no individuals, or all the individuals are not associated with |
| 143 | + # nodes, then we split by ploidy. |
| 144 | + if ts.num_individuals > 0 and np.any(ts.individuals_nodes != -1): |
| 145 | + individuals_nodes = ts.individuals_nodes |
158 | 146 | else: |
| 147 | + if ploidy is None: |
| 148 | + ploidy = 1 |
| 149 | + individuals_nodes = ts.sample_nodes_by_ploidy(ploidy) |
| 150 | + |
| 151 | + if individuals is not None: |
159 | 152 | individuals = np.array(individuals, dtype=np.int32) |
160 | 153 | if len(individuals) == 0: |
161 | 154 | raise ValueError("List of sample individuals empty") |
| 155 | + if any(individuals < 0) or any(individuals >= ts.num_individuals): |
| 156 | + raise ValueError("Invalid individual IDs provided.") |
| 157 | + individuals_nodes = ts.individuals_nodes[individuals] |
162 | 158 |
|
163 | | - if individuals is not None: |
164 | | - self.samples = [] |
165 | | - # FIXME this could probably be done more efficiently. |
166 | | - for i in individuals: |
167 | | - if i < 0 or i >= self.tree_sequence.num_individuals: |
168 | | - raise ValueError("Invalid individual IDs provided.") |
169 | | - ind = self.tree_sequence.individual(i) |
170 | | - if len(ind.nodes) == 0: |
171 | | - raise ValueError(f"Individual {i} not associated with a node") |
172 | | - is_sample = {ts.node(u).is_sample() for u in ind.nodes} |
173 | | - if len(is_sample) != 1: |
174 | | - raise ValueError( |
175 | | - f"Individual {ind.id} has nodes that are sample and " |
176 | | - "non-samples" |
177 | | - ) |
178 | | - self.samples.extend(ind.nodes) |
179 | | - self.individual_ploidies.append(len(ind.nodes)) |
180 | | - else: |
181 | | - if ploidy is None: |
182 | | - ploidy = 1 |
183 | | - if ploidy < 1: |
184 | | - raise ValueError("Ploidy must be >= 1") |
185 | | - if ts.num_samples % ploidy != 0: |
186 | | - raise ValueError("Sample size must be divisible by ploidy") |
187 | | - self.individual_ploidies = np.full( |
188 | | - ts.sample_size // ploidy, ploidy, dtype=np.int32 |
189 | | - ) |
| 159 | + self.samples = [] |
| 160 | + for i, row in enumerate(individuals_nodes): |
| 161 | + wanted_nodes = row[row != -1] |
| 162 | + # This error only fires if an individual was specifically specified. |
| 163 | + if len(wanted_nodes) == 0 and individuals is not None: |
| 164 | + raise ValueError(f"Individual {i} not associated with a node") |
| 165 | + if len(wanted_nodes) > 0: |
| 166 | + self.samples.extend(wanted_nodes) |
| 167 | + self.individual_ploidies.append(len(wanted_nodes)) |
190 | 168 | self.num_individuals = len(self.individual_ploidies) |
191 | 169 |
|
192 | 170 | def __write_header(self, output): |
|
0 commit comments