Skip to content

Commit 5555c2f

Browse files
authored
Merge pull request #663 from hackingmaterials/unity-overlap
Add unity overlaps
2 parents 32939e0 + 80eed6c commit 5555c2f

6 files changed

Lines changed: 48 additions & 11 deletions

File tree

amset/core/data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from amset.electronic_structure.fd import dfdde
2424
from amset.electronic_structure.tetrahedron import TetrahedralBandStructure
2525
from amset.interpolation.momentum import MRTACalculator
26+
from amset.interpolation.wavefunction import UnityWavefunctionOverlap
2627
from amset.io import write_mesh
2728
from amset.log import log_list, log_time_taken
2829
from amset.util import cast_dict_list, groupby, tensor_average
@@ -116,7 +117,9 @@ def ir_kpoints_idx(self):
116117
return self.tetrahedral_band_structure.ir_kpoints_idx
117118

118119
def set_overlap_calculator(self, overlap_calculator):
119-
if overlap_calculator is not None:
120+
if overlap_calculator is not None and not isinstance(
121+
overlap_calculator, UnityWavefunctionOverlap
122+
):
120123
equal = check_nbands_equal(overlap_calculator, self)
121124
if not equal:
122125
raise RuntimeError(

amset/core/run.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
from amset.electronic_structure.common import get_band_structure
2626
from amset.interpolation.bandstructure import Interpolator
2727
from amset.interpolation.projections import ProjectionOverlapCalculator
28-
from amset.interpolation.wavefunction import WavefunctionOverlapCalculator
28+
from amset.interpolation.wavefunction import (
29+
UnityWavefunctionOverlap,
30+
WavefunctionOverlapCalculator,
31+
)
2932
from amset.io import load_settings, write_settings
3033
from amset.log import initialize_amset_logger, log_banner, log_list
3134
from amset.scattering.calculate import ScatteringCalculator, basic_scatterers
@@ -161,9 +164,8 @@ def _do_many_fd_tol(self, amset_data, fd_tols, directory, prefix, timing):
161164
return amset_data, timing
162165

163166
def _check_wavefunction(self):
164-
if (
165-
not Path(self.settings["wavefunction_coefficients"]).exists()
166-
and not self.settings["use_projections"]
167+
if not Path(self.settings["wavefunction_coefficients"]).exists() and not (
168+
self.settings["use_projections"] or self.settings["unity_overlap"]
167169
):
168170
raise ValueError(
169171
"Could not find wavefunction coefficients. To run AMSET, the \n"
@@ -209,6 +211,8 @@ def _do_interpolation(self):
209211

210212
if set(self.settings["scattering_type"]).issubset(set(basic_scatterers)):
211213
overlap_calculator = None
214+
elif self.settings["unity_overlap"]:
215+
overlap_calculator = UnityWavefunctionOverlap()
212216
elif self.settings["use_projections"]:
213217
overlap_calculator = ProjectionOverlapCalculator.from_band_structure(
214218
self._band_structure,

amset/defaults.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ interpolation_factor: 10
2525
# Overlap settings
2626
wavefunction_coefficients: wavefunction.h5 # Path to a wavefunction coefficients file
2727
use_projections: false # use orbital projections for overlap (not recommended)
28+
unity_overlap: false # set all overlaps to unity
2829

2930
# whether free electrons screen the polar optical and piezoelectric scattering rates
3031
free_carrier_screening: false

amset/interpolation/wavefunction.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,22 @@ def _get_overlap_ncl(grid, data, points, n_coeffs):
167167
res[i - 1] = abs(sum_) ** 2
168168

169169
return res
170+
171+
172+
class UnityWavefunctionOverlap:
173+
def __init__(self, *args, **kwargs):
174+
pass
175+
176+
def to_reference(self):
177+
return [1, 2, 3]
178+
179+
@classmethod
180+
def from_reference(cls, *args, **kwargs):
181+
return cls(*args, **kwargs)
182+
183+
@classmethod
184+
def from_data(cls, *args, **kwargs):
185+
return cls(*args, **kwargs)
186+
187+
def get_overlap(self, *args, **kwargs):
188+
return 1

amset/scattering/calculate.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@
4141
transform_quad,
4242
transform_triangle,
4343
)
44-
from amset.interpolation.wavefunction import WavefunctionOverlapCalculator
44+
from amset.interpolation.wavefunction import (
45+
UnityWavefunctionOverlap,
46+
WavefunctionOverlapCalculator,
47+
)
4548
from amset.log import log_list, log_time_taken
4649
from amset.scattering.basic import AbstractBasicScattering
4750
from amset.scattering.elastic import (
@@ -136,6 +139,14 @@ def __init__(
136139
"Caching wavefunction not supported with orbital projection "
137140
"overlaps. Setting cache_wavefunction to False."
138141
)
142+
elif (
143+
isinstance(self.amset_data.overlap_calculator, UnityWavefunctionOverlap)
144+
and cache_wavefunction
145+
):
146+
logger.info(
147+
"Caching wavefunction not supported with unity overlaps. Setting "
148+
"cache_wavefunction to False."
149+
)
139150
elif cache_wavefunction and not self._basic_only:
140151
self._coeffs = {}
141152
self._coeffs_mapping = {}
@@ -200,6 +211,8 @@ def initialize_workers(self):
200211

201212
if isinstance(self.amset_data.overlap_calculator, ProjectionOverlapCalculator):
202213
overlap_type = "projection"
214+
elif isinstance(self.amset_data.overlap_calculator, UnityWavefunctionOverlap):
215+
overlap_type = "unity"
203216
else:
204217
overlap_type = "wavefunction"
205218

@@ -539,6 +552,8 @@ def scattering_worker(
539552
overlap_calculator = WavefunctionOverlapCalculator.from_reference(
540553
*overlap_calculator_reference
541554
)
555+
elif overlap_type == "unity":
556+
overlap_calculator = UnityWavefunctionOverlap()
542557
elif overlap_type == "projection":
543558
overlap_calculator = ProjectionOverlapCalculator.from_reference(
544559
*overlap_calculator_reference

amset/util.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,6 @@ def validate_settings(user_settings: Dict[str, Any]) -> Dict[str, Any]:
7676

7777
for charge_setting in ("donor_charge", "acceptor_charge"):
7878
if charge_setting in settings:
79-
logger.warning(
80-
f"The {charge_setting} option has been renamed to defect_charge and "
81-
"will be removed in June 2021. Please see the documentation for more "
82-
"details."
83-
)
8479
settings["defect_charge"] = settings.pop(charge_setting)
8580

8681
for setting in settings:

0 commit comments

Comments
 (0)