|
22 | 22 |
|
23 | 23 |
|
24 | 24 | @dataclass |
25 | | -class _Candidate: |
26 | | - """A single placement candidate produced by the solver.""" |
| 25 | +class PlacementCandidate: |
| 26 | + """A single solver result, ranked and selected in ObjectPlacer.place().""" |
27 | 27 |
|
28 | 28 | loss: float |
29 | 29 | """Loss value returned by the solver.""" |
@@ -118,35 +118,35 @@ def place( |
118 | 118 | assert self._solver.last_loss_per_env is not None |
119 | 119 | all_losses: list[float] = self._solver.last_loss_per_env.cpu().tolist() |
120 | 120 |
|
121 | | - all_candidates: list[_Candidate] = [] |
| 121 | + all_candidates: list[PlacementCandidate] = [] |
122 | 122 | for idx in range(num_candidates): |
123 | 123 | loss = all_losses[idx] |
124 | 124 | is_valid = self._validate_placement(all_positions[idx]) |
125 | | - all_candidates.append(_Candidate(loss, all_positions[idx], is_valid)) |
| 125 | + all_candidates.append(PlacementCandidate(loss, all_positions[idx], is_valid)) |
126 | 126 |
|
127 | 127 | # Sort: valid solutions first (by loss), then invalid (by loss) |
128 | | - all_candidates.sort(key=lambda c: (not c.is_valid, c.loss)) |
| 128 | + all_candidates.sort(key=lambda candidate: (not candidate.is_valid, candidate.loss)) |
129 | 129 | selected = all_candidates[:num_results] |
130 | 130 |
|
131 | | - n_valid = sum(1 for c in selected if c.is_valid) |
| 131 | + n_valid = sum(1 for candidate in selected if candidate.is_valid) |
132 | 132 | if self.params.verbose: |
133 | | - total_valid = sum(1 for c in all_candidates if c.is_valid) |
134 | | - finite_losses = [c.loss for c in all_candidates if math.isfinite(c.loss)] |
| 133 | + total_valid = sum(1 for candidate in all_candidates if candidate.is_valid) |
| 134 | + finite_losses = [candidate.loss for candidate in all_candidates if math.isfinite(candidate.loss)] |
135 | 135 | mean_loss = sum(finite_losses) / len(finite_losses) if finite_losses else float("inf") |
136 | 136 | print( |
137 | 137 | f"Solved {num_candidates} candidates in one batch: mean loss = {mean_loss:.6f}," |
138 | 138 | f" {total_valid} valid, selected best {num_results} ({n_valid} valid)" |
139 | 139 | ) |
140 | 140 |
|
141 | | - final_per_env: list[dict] = [c.positions for c in selected] |
| 141 | + final_per_env: list[dict] = [candidate.positions for candidate in selected] |
142 | 142 | results_per_env = [ |
143 | 143 | PlacementResult( |
144 | | - success=c.is_valid, |
145 | | - positions=c.positions, |
146 | | - final_loss=c.loss, |
| 144 | + success=candidate.is_valid, |
| 145 | + positions=candidate.positions, |
| 146 | + final_loss=candidate.loss, |
147 | 147 | attempts=self.params.max_placement_attempts, |
148 | 148 | ) |
149 | | - for c in selected |
| 149 | + for candidate in selected |
150 | 150 | ] |
151 | 151 |
|
152 | 152 | if self.params.apply_positions_to_objects: |
|
0 commit comments