|
1 | 1 | # SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +import functools |
| 3 | +from collections.abc import ( |
| 4 | + Callable, |
| 5 | +) |
2 | 6 | from copy import ( |
3 | 7 | deepcopy, |
4 | 8 | ) |
@@ -332,6 +336,88 @@ def model_output_def(self) -> ModelOutputDef: |
332 | 336 | backbone_model_atomic_output_def[var_name].magnetic = True |
333 | 337 | return ModelOutputDef(backbone_model_atomic_output_def) |
334 | 338 |
|
| 339 | + def _get_spin_sampled_func( |
| 340 | + self, sampled_func: Callable[[], list[dict]] |
| 341 | + ) -> Callable[[], list[dict]]: |
| 342 | + """Get a spin-aware sampled function that transforms spin data for the backbone model. |
| 343 | +
|
| 344 | + Parameters |
| 345 | + ---------- |
| 346 | + sampled_func |
| 347 | + A callable that returns a list of data dicts containing 'coord', 'atype', 'spin', etc. |
| 348 | +
|
| 349 | + Returns |
| 350 | + ------- |
| 351 | + Callable |
| 352 | + A cached callable that returns spin-preprocessed data dicts. |
| 353 | + """ |
| 354 | + |
| 355 | + @functools.lru_cache |
| 356 | + def spin_sampled_func() -> list[dict]: |
| 357 | + sampled = sampled_func() |
| 358 | + spin_sampled = [] |
| 359 | + for sys in sampled: |
| 360 | + coord_updated, atype_updated = self.process_spin_input( |
| 361 | + sys["coord"], sys["atype"], sys["spin"] |
| 362 | + ) |
| 363 | + tmp_dict = { |
| 364 | + "coord": coord_updated, |
| 365 | + "atype": atype_updated, |
| 366 | + } |
| 367 | + if "natoms" in sys: |
| 368 | + natoms = sys["natoms"] |
| 369 | + tmp_dict["natoms"] = np.concatenate( |
| 370 | + [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], axis=-1 |
| 371 | + ) |
| 372 | + for item_key in sys.keys(): |
| 373 | + if item_key not in ["coord", "atype", "spin", "natoms"]: |
| 374 | + tmp_dict[item_key] = sys[item_key] |
| 375 | + spin_sampled.append(tmp_dict) |
| 376 | + return spin_sampled |
| 377 | + |
| 378 | + return self.backbone_model.atomic_model._make_wrapped_sampler(spin_sampled_func) |
| 379 | + |
| 380 | + def change_out_bias( |
| 381 | + self, |
| 382 | + merged: Callable[[], list[dict]] | list[dict], |
| 383 | + bias_adjust_mode: str = "change-by-statistic", |
| 384 | + ) -> None: |
| 385 | + """Change the output bias of atomic model according to the input data and the pretrained model. |
| 386 | +
|
| 387 | + Parameters |
| 388 | + ---------- |
| 389 | + merged : Union[Callable[[], list[dict]], list[dict]] |
| 390 | + - list[dict]: A list of data samples from various data systems. |
| 391 | + Each element, `merged[i]`, is a data dictionary containing `keys`: `np.ndarray` |
| 392 | + originating from the `i`-th data system. |
| 393 | + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format |
| 394 | + only when needed. Since the sampling process can be slow and memory-intensive, |
| 395 | + the lazy function helps by only sampling once. |
| 396 | + bias_adjust_mode : str |
| 397 | + The mode for changing output bias : ['change-by-statistic', 'set-by-statistic'] |
| 398 | + 'change-by-statistic' : perform predictions on labels of target dataset, |
| 399 | + and do least square on the errors to obtain the target shift as bias. |
| 400 | + 'set-by-statistic' : directly use the statistic output bias in the target dataset. |
| 401 | + """ |
| 402 | + spin_sampled_func = self._get_spin_sampled_func( |
| 403 | + merged if callable(merged) else lambda: merged |
| 404 | + ) |
| 405 | + self.backbone_model.change_out_bias( |
| 406 | + spin_sampled_func, |
| 407 | + bias_adjust_mode=bias_adjust_mode, |
| 408 | + ) |
| 409 | + |
| 410 | + def change_type_map( |
| 411 | + self, type_map: list[str], model_with_new_type_stat: Any = None |
| 412 | + ) -> None: |
| 413 | + """Change the type related params to new ones, according to `type_map` and the original one in the model. |
| 414 | + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. |
| 415 | + """ |
| 416 | + type_map_with_spin = type_map + [item + "_spin" for item in type_map] |
| 417 | + self.backbone_model.change_type_map( |
| 418 | + type_map_with_spin, model_with_new_type_stat |
| 419 | + ) |
| 420 | + |
335 | 421 | def __getattr__(self, name: str) -> Any: |
336 | 422 | """Get attribute from the wrapped model.""" |
337 | 423 | if "backbone_model" not in self.__dict__: |
|
0 commit comments