|
41 | 41 | transform_quad, |
42 | 42 | transform_triangle, |
43 | 43 | ) |
44 | | -from amset.interpolation.wavefunction import WavefunctionOverlapCalculator |
| 44 | +from amset.interpolation.wavefunction import ( |
| 45 | + UnityWavefunctionOverlap, |
| 46 | + WavefunctionOverlapCalculator, |
| 47 | +) |
45 | 48 | from amset.log import log_list, log_time_taken |
46 | 49 | from amset.scattering.basic import AbstractBasicScattering |
47 | 50 | from amset.scattering.elastic import ( |
@@ -136,6 +139,14 @@ def __init__( |
136 | 139 | "Caching wavefunction not supported with orbital projection " |
137 | 140 | "overlaps. Setting cache_wavefunction to False." |
138 | 141 | ) |
| 142 | + elif ( |
| 143 | + isinstance(self.amset_data.overlap_calculator, UnityWavefunctionOverlap) |
| 144 | + and cache_wavefunction |
| 145 | + ): |
| 146 | + logger.info( |
| 147 | + "Caching wavefunction not supported with unity overlaps. Setting " |
| 148 | + "cache_wavefunction to False." |
| 149 | + ) |
139 | 150 | elif cache_wavefunction and not self._basic_only: |
140 | 151 | self._coeffs = {} |
141 | 152 | self._coeffs_mapping = {} |
@@ -200,6 +211,8 @@ def initialize_workers(self): |
200 | 211 |
|
201 | 212 | if isinstance(self.amset_data.overlap_calculator, ProjectionOverlapCalculator): |
202 | 213 | overlap_type = "projection" |
| 214 | + elif isinstance(self.amset_data.overlap_calculator, UnityWavefunctionOverlap): |
| 215 | + overlap_type = "unity" |
203 | 216 | else: |
204 | 217 | overlap_type = "wavefunction" |
205 | 218 |
|
@@ -539,6 +552,8 @@ def scattering_worker( |
539 | 552 | overlap_calculator = WavefunctionOverlapCalculator.from_reference( |
540 | 553 | *overlap_calculator_reference |
541 | 554 | ) |
| 555 | + elif overlap_type == "unity": |
| 556 | + overlap_calculator = UnityWavefunctionOverlap() |
542 | 557 | elif overlap_type == "projection": |
543 | 558 | overlap_calculator = ProjectionOverlapCalculator.from_reference( |
544 | 559 | *overlap_calculator_reference |
|
0 commit comments