diff --git a/src/abacusagent/constant.py b/src/abacusagent/constant.py index 96dff74..e531c26 100644 --- a/src/abacusagent/constant.py +++ b/src/abacusagent/constant.py @@ -1,2 +1,15 @@ RY_TO_EV = 13.60569253 THZ_TO_K = 47.9924 + +# Physical constants for effective mass calculation +HBAR_EV_S = 6.582119569e-16 # ℏ in eV·s +HBAR_J_S = 1.054571817e-34 # ℏ in J·s +ELECTRON_MASS_KG = 9.1093837015e-31 # m_e in kg +ANGSTROM_TO_M = 1e-10 # Å to m conversion +EV_TO_J = 1.602176634e-19 # eV to J conversion + +# Derived constant: m*/m_e = EFFECTIVE_MASS_FACTOR / curvature +# where curvature is d²E/dk² in eV/Å⁻² +# Formula: m* = ℏ² / (m_e * d²E/dk²) +# Converting units: ℏ in J·s, m_e in kg, curvature in eV/Å⁻² +EFFECTIVE_MASS_FACTOR = (HBAR_J_S**2) / (ELECTRON_MASS_KG * EV_TO_J * ANGSTROM_TO_M**2) diff --git a/src/abacusagent/modules/band.py b/src/abacusagent/modules/band.py index 71ab934..5cd8bf8 100644 --- a/src/abacusagent/modules/band.py +++ b/src/abacusagent/modules/band.py @@ -1,8 +1,9 @@ from pathlib import Path -from typing import Literal, Dict, List, Union +from typing import Literal, Dict, List, Union, Optional from abacusagent.init_mcp import mcp from abacusagent.modules.submodules.band import abacus_cal_band as _abacus_cal_band +from abacusagent.modules.submodules.band import abacus_cal_effective_mass as _abacus_cal_effective_mass @mcp.tool() def abacus_cal_band(abacus_inputs_dir: Path, @@ -43,3 +44,109 @@ def abacus_cal_band(abacus_inputs_dir: Path, """ return _abacus_cal_band(abacus_inputs_dir, mode, kpath, high_symm_points, energy_min, energy_max, insert_point_nums) + +@mcp.tool() +def abacus_cal_effective_mass( + band_calc_dir: Path, + calculation_points: Union[Literal["auto", "extrema"], List[Dict[str, Union[str, List[float], int]]]] = "auto", + fitting_window: int = 5, + directions: List[str] = ["kx", "ky", "kz"], + band_indices: Optional[List[int]] = None, + energy_range: Optional[List[float]] = None, + output_dir: Optional[Path] = None +) -> Dict[str, Union[List, Path, Dict, str]]: + """ + Calculate effective mass from band structure using parabolic fitting. + + This function analyzes band structure data to compute effective masses at band extrema + (VBM/CBM) or user-specified k-points. It uses parabolic fitting E(k) = E₀ + a(k-k₀)² + to extract the band curvature d²E/dk², from which the effective mass is calculated + using m* = ℏ²/(d²E/dk²). + + Args: + band_calc_dir: Path to directory containing band calculation results. Must have + BANDS_*.dat files from ABACUS NSCF or PYATB calculation. + calculation_points: Where to calculate effective mass. Options: + - "auto": Automatically detect VBM and CBM (default) + - "extrema": Find all local extrema within energy_range + - List of dicts: User-specified points, each dict contains: + * "type": "kpoint" or "high_symmetry" + * "coords": [kx, ky, kz] for "kpoint" type + * "label": "G", "M", "K", etc. for "high_symmetry" type + * "band_index": (optional) specific band index + fitting_window: Number of k-points on each side of extremum for parabolic + fitting. Larger values give smoother fits but may miss non-parabolic + behavior. Default: 5 (total 11 points used). + directions: Directions for effective mass calculation. Currently only "kpath" + direction (along the band path) is implemented. Default: ["kx", "ky", "kz"]. + band_indices: Specific band indices to analyze. If None, analyzes all bands + near Fermi level. Default: None. + energy_range: [E_min, E_max] in eV relative to Fermi level for extrema + detection. Only used when calculation_points="extrema". Default: [-2, 2]. + output_dir: Directory for output files (plots and JSON). If None, uses + band_calc_dir. Default: None. + + Returns: + Dict containing: + - effective_mass_results: List of dicts with effective mass data for each point. + Each dict contains: + * point_info: Band index, k-point index, coordinates, energy, extrema type + * effective_masses: m* values in units of electron mass (m_e) for each direction + * fitting_data: k-distances, energies, and fitted values for plotting + - effective_mass_json: Path to JSON file with detailed results + - effective_mass_plots: List of paths to visualization plots (one per point + summary) + - summary: Dict with statistics: + * electron_effective_mass: {average, min, max, std} for CBM + * hole_effective_mass: {average, min, max, std} for VBM + - message: Success or error message + + Raises: + RuntimeError: If band data files are not found or cannot be read + ValueError: If calculation_points format is invalid + + Examples: + # Calculate effective mass at VBM and CBM automatically + result = abacus_cal_effective_mass("/path/to/band/calc") + print(f"Electron m* = {result['summary']['electron_effective_mass']['average']:.3f} m_e") + + # Calculate at all extrema within ±3 eV of Fermi level + result = abacus_cal_effective_mass( + "/path/to/band/calc", + calculation_points="extrema", + energy_range=[-3.0, 3.0] + ) + + # Calculate at specific k-point + result = abacus_cal_effective_mass( + "/path/to/band/calc", + calculation_points=[ + {"type": "kpoint", "coords": [0.0, 0.0, 0.0], "band_index": 10} + ] + ) + + # Calculate at high symmetry point + result = abacus_cal_effective_mass( + "/path/to/band/calc", + calculation_points=[ + {"type": "high_symmetry", "label": "G"} + ] + ) + + Notes: + - Effective mass is reported in units of electron mass (m_e = 9.109×10⁻³¹ kg) + - Positive curvature → positive effective mass (electron-like) + - Negative curvature → negative effective mass (hole-like) + - Reported values are absolute values |m*| + - R² < 0.90 indicates poor parabolic fit; consider denser k-mesh + - For accurate results, use insert_point_nums ≥ 30 in band calculation + """ + return _abacus_cal_effective_mass( + band_calc_dir, + calculation_points, + fitting_window, + directions, + band_indices, + energy_range, + output_dir + ) + diff --git a/src/abacusagent/modules/submodules/band.py b/src/abacusagent/modules/submodules/band.py index 250ffbd..2b9c860 100644 --- a/src/abacusagent/modules/submodules/band.py +++ b/src/abacusagent/modules/submodules/band.py @@ -426,3 +426,789 @@ def abacus_cal_band(abacus_inputs_dir: Path, raise ValueError(f"Calculation mode {mode} not in ('pyatb', 'nscf', 'auto')") except Exception as e: return {'message': f"Calculating band failed: {e}"} + + +# ============================================================================ +# Effective Mass Calculation Functions +# ============================================================================ + +def load_band_data_for_effective_mass(band_calc_dir: Path) -> Dict[str, Any]: + """ + Load band data from BANDS_*.dat files and extract k-point information. + + Args: + band_calc_dir: Path to directory containing band calculation results + + Returns: + Dict containing: + - bands: List[List[float]] - bands[i][j] = energy of band i at k-point j + - kline: List[float] - cumulative k-distances + - kpoints: List[List[float]] - actual k-point coordinates in reciprocal space + - efermi: float - Fermi energy + - nspin: int - spin polarization + - bands_dw: Optional[List[List[float]]] - spin-down bands if nspin=2 + """ + import numpy as np + + input_params = ReadInput(os.path.join(band_calc_dir, "INPUT")) + suffix = input_params.get('suffix', 'ABACUS') + nspin = input_params.get('nspin', 1) + + # Get Fermi energy + metrics = collect_metrics(band_calc_dir, ['efermi']) + efermi = metrics['efermi'] + + # Read band data + band_file = os.path.join(band_calc_dir, f"OUT.{suffix}/BANDS_1.dat") + bands, kline, nbands = read_band_data(band_file, efermi) + + bands_dw = None + if nspin == 2: + band_file_dw = os.path.join(band_calc_dir, f"OUT.{suffix}/BANDS_2.dat") + bands_dw, _, _ = read_band_data(band_file_dw, efermi) + + # Reconstruct k-point coordinates + kpoints = reconstruct_kpoint_coords(band_calc_dir, kline) + + return { + 'bands': bands, + 'kline': kline, + 'kpoints': kpoints, + 'efermi': efermi, + 'nspin': nspin, + 'bands_dw': bands_dw, + 'nbands': nbands + } + + +def reconstruct_kpoint_coords(band_calc_dir: Path, kline: List[float]) -> List[List[float]]: + """ + Reconstruct actual k-point coordinates from KPT file and k-line distances. + + Args: + band_calc_dir: Path to band calculation directory + kline: List of cumulative k-distances + + Returns: + List of k-point coordinates [[kx, ky, kz], ...] + """ + import numpy as np + + # Read KPT_band file to get high symmetry points + kpt_file = os.path.join(band_calc_dir, "KPT_band") + if not os.path.exists(kpt_file): + kpt_file = os.path.join(band_calc_dir, "KPT") + + high_symm_kpoints = [] + insert_nums = [] + + with open(kpt_file) as fin: + lines = fin.readlines() + for line in lines: + words = line.split() + if len(words) >= 4: + try: + kx, ky, kz = float(words[0]), float(words[1]), float(words[2]) + num = int(words[3]) + high_symm_kpoints.append([kx, ky, kz]) + insert_nums.append(num) + except ValueError: + continue + + # Interpolate k-points along the path + kpoints = [] + for i in range(len(high_symm_kpoints) - 1): + k_start = np.array(high_symm_kpoints[i]) + k_end = np.array(high_symm_kpoints[i + 1]) + num_points = insert_nums[i] + + for j in range(num_points): + t = j / num_points if num_points > 1 else 0 + k_interp = k_start + t * (k_end - k_start) + kpoints.append(k_interp.tolist()) + + # Add the last point + kpoints.append(high_symm_kpoints[-1]) + + return kpoints + + +def find_band_extrema( + bands: List[List[float]], + kpoints: List[List[float]], + kline: List[float], + efermi: float, + energy_range: List[float], + band_indices: Optional[List[int]] = None +) -> List[Dict[str, Any]]: + """ + Automatically detect band extrema (VBM, CBM, and other local extrema). + + Args: + bands: Band energies + kpoints: K-point coordinates + kline: K-line distances + efermi: Fermi energy (already subtracted from bands) + energy_range: [E_min, E_max] relative to Fermi level for extrema detection + band_indices: Specific band indices to analyze + + Returns: + List of dicts, each containing extrema information + """ + import numpy as np + + extrema = [] + + # Find VBM and CBM + vbm_info = None + cbm_info = None + vbm_energy = -float('inf') + cbm_energy = float('inf') + + for band_idx, band in enumerate(bands): + if band_indices is not None and band_idx not in band_indices: + continue + + for k_idx, energy in enumerate(band): + # Find VBM (highest occupied state below Fermi level) + if energy < 0 and energy > vbm_energy: + vbm_energy = energy + vbm_info = { + 'band_index': band_idx, + 'kpoint_index': k_idx, + 'kpoint_coords': kpoints[k_idx], + 'energy': energy, + 'extrema_type': 'VBM', + 'is_degenerate': False + } + + # Find CBM (lowest unoccupied state above Fermi level) + if energy > 0 and energy < cbm_energy: + cbm_energy = energy + cbm_info = { + 'band_index': band_idx, + 'kpoint_index': k_idx, + 'kpoint_coords': kpoints[k_idx], + 'energy': energy, + 'extrema_type': 'CBM', + 'is_degenerate': False + } + + if vbm_info: + extrema.append(vbm_info) + if cbm_info: + extrema.append(cbm_info) + + # Find other local extrema within energy range + for band_idx, band in enumerate(bands): + if band_indices is not None and band_idx not in band_indices: + continue + + for k_idx in range(1, len(band) - 1): + energy = band[k_idx] + + # Check if within energy range + if energy < energy_range[0] or energy > energy_range[1]: + continue + + # Check for local maximum + if band[k_idx] > band[k_idx - 1] and band[k_idx] > band[k_idx + 1]: + extrema_type = 'local_max' + if abs(energy - vbm_energy) < 0.01: # Same as VBM + continue + + extrema.append({ + 'band_index': band_idx, + 'kpoint_index': k_idx, + 'kpoint_coords': kpoints[k_idx], + 'energy': energy, + 'extrema_type': extrema_type, + 'is_degenerate': False + }) + + # Check for local minimum + elif band[k_idx] < band[k_idx - 1] and band[k_idx] < band[k_idx + 1]: + extrema_type = 'local_min' + if abs(energy - cbm_energy) < 0.01: # Same as CBM + continue + + extrema.append({ + 'band_index': band_idx, + 'kpoint_index': k_idx, + 'kpoint_coords': kpoints[k_idx], + 'energy': energy, + 'extrema_type': extrema_type, + 'is_degenerate': False + }) + + return extrema + + +def fit_parabola_1d( + k_distances: List[float], + energies: List[float], + k0: float = 0.0 +) -> Dict[str, Any]: + """ + Fit parabola E(k) = E0 + a*(k-k0)^2 to band data. + + Args: + k_distances: K-point distances along direction + energies: Band energies at those k-points + k0: Center k-point (default: 0.0) + + Returns: + Dict containing fit parameters and quality metrics + """ + import numpy as np + + if len(k_distances) < 3: + return { + 'E0': None, + 'a': None, + 'curvature': None, + 'r_squared': 0.0, + 'fit_energies': [], + 'residuals': [], + 'error': 'Insufficient data points for fitting' + } + + # Shift k-points to center at k0 + k_shifted = np.array(k_distances) - k0 + e_array = np.array(energies) + + # Fit parabola: E = E0 + a*k^2 + # Using polyfit with degree 2 + try: + coeffs = np.polyfit(k_shifted, e_array, 2) + a, b, E0 = coeffs[0], coeffs[1], coeffs[2] + + # Calculate fitted energies + fit_energies = np.polyval(coeffs, k_shifted) + + # Calculate R² + ss_res = np.sum((e_array - fit_energies) ** 2) + ss_tot = np.sum((e_array - np.mean(e_array)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0.0 + + # Curvature is d²E/dk² = 2*a + curvature = 2 * a + + return { + 'E0': E0, + 'a': a, + 'b': b, + 'curvature': curvature, + 'r_squared': r_squared, + 'fit_energies': fit_energies.tolist(), + 'residuals': (e_array - fit_energies).tolist(), + 'num_points': len(k_distances) + } + except Exception as e: + return { + 'E0': None, + 'a': None, + 'curvature': None, + 'r_squared': 0.0, + 'fit_energies': [], + 'residuals': [], + 'error': f'Fitting failed: {str(e)}' + } + + +def calculate_effective_mass_from_curvature( + curvature: float, + direction: str +) -> Dict[str, float]: + """ + Calculate effective mass from band curvature using m* = ℏ²/(d²E/dk²). + + Args: + curvature: d²E/dk² in eV/Å^-2 + direction: "kx", "ky", or "kz" + + Returns: + Dict containing effective mass values + """ + from abacusagent.constant import EFFECTIVE_MASS_FACTOR + + if curvature is None or abs(curvature) < 1e-10: + return { + 'm_star': float('inf'), + 'm_star_kg': float('inf'), + 'curvature': curvature, + 'is_flat_band': True + } + + # m*/m_e = EFFECTIVE_MASS_FACTOR / curvature + m_star = EFFECTIVE_MASS_FACTOR / curvature + + # Convert to kg + from abacusagent.constant import ELECTRON_MASS_KG + m_star_kg = m_star * ELECTRON_MASS_KG + + return { + 'm_star': m_star, + 'm_star_kg': m_star_kg, + 'curvature': curvature, + 'is_flat_band': False + } + + +def calculate_effective_mass_tensor( + band_index: int, + kpoint_index: int, + kpoints: List[List[float]], + bands: List[List[float]], + kline: List[float], + fitting_window: int, + directions: List[str] +) -> Dict[str, Any]: + """ + Calculate effective mass in multiple directions around a k-point. + + Args: + band_index: Band index + kpoint_index: K-point index + kpoints: All k-point coordinates + bands: All band energies + kline: K-line distances + fitting_window: Number of points on each side for fitting + directions: List of directions to calculate + + Returns: + Dict containing effective mass tensor components + """ + import numpy as np + + result = { + 'effective_masses': {}, + 'fitting_data': {}, + 'anisotropy_ratio': None, + 'm_star_avg': None + } + + k_center = np.array(kpoints[kpoint_index]) + band = bands[band_index] + + # Determine available k-points within window + start_idx = max(0, kpoint_index - fitting_window) + end_idx = min(len(kpoints), kpoint_index + fitting_window + 1) + + # Extract k-points and energies in window + k_window = np.array([kpoints[i] for i in range(start_idx, end_idx)]) + e_window = [band[i] for i in range(start_idx, end_idx)] + kline_window = [kline[i] for i in range(start_idx, end_idx)] + + # Calculate effective mass along k-path direction + # Use kline distances directly + k_distances = np.array(kline_window) - kline[kpoint_index] + + fit_result = fit_parabola_1d(k_distances.tolist(), e_window, 0.0) + + if fit_result['curvature'] is not None: + mass_result = calculate_effective_mass_from_curvature( + fit_result['curvature'], + 'kpath' + ) + + result['effective_masses']['kpath'] = { + 'm_star': mass_result['m_star'], + 'curvature': fit_result['curvature'], + 'r_squared': fit_result['r_squared'], + 'num_points': fit_result['num_points'] + } + + result['fitting_data']['kpath'] = { + 'k_distances': k_distances.tolist(), + 'energies': e_window, + 'fit_energies': fit_result['fit_energies'] + } + + result['m_star_avg'] = mass_result['m_star'] + + return result + + +def find_high_symmetry_kpoint( + label: str, + band_calc_dir: Path, + kpoints: List[List[float]], + kline: List[float] +) -> Optional[int]: + """ + Find k-point index corresponding to a high symmetry point label. + + Args: + label: High symmetry point label (e.g., "G", "M", "K") + band_calc_dir: Path to band calculation directory + kpoints: All k-point coordinates + kline: K-line distances + + Returns: + K-point index or None if not found + """ + try: + high_symm_labels, band_point_nums = read_high_symmetry_labels(band_calc_dir) + + # Normalize label (G -> Γ) + if label == 'G' or label == 'Gamma': + label = r'$\Gamma$' + + for i, symm_label in enumerate(high_symm_labels): + if symm_label == label or symm_label.replace('$', '').replace('\\', '') == label: + return band_point_nums[i] + + return None + except Exception as e: + print(f"Warning: Could not find high symmetry point {label}: {e}") + return None + + +def plot_effective_mass_fit( + result: Dict[str, Any], + output_path: Path +) -> Path: + """ + Create a plot showing parabolic fit for effective mass calculation. + + Args: + result: Effective mass result dictionary + output_path: Path to save the plot + + Returns: + Path to saved plot + """ + import matplotlib.pyplot as plt + import numpy as np + + point_info = result['point_info'] + eff_masses = result['effective_masses'] + fitting_data = result['fitting_data'] + + # Create figure + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + + # Plot for kpath direction + if 'kpath' in fitting_data: + data = fitting_data['kpath'] + mass_data = eff_masses['kpath'] + + # Scatter plot of actual data + ax.scatter(data['k_distances'], data['energies'], + color='blue', s=50, label='Band data', zorder=3) + + # Plot fitted parabola + ax.plot(data['k_distances'], data['fit_energies'], + 'r-', linewidth=2, label='Parabolic fit', zorder=2) + + # Add text with effective mass info + textstr = f"m* = {mass_data['m_star']:.3f} $m_e$\n" + textstr += f"R² = {mass_data['r_squared']:.4f}\n" + textstr += f"Curvature = {mass_data['curvature']:.4f} eV/Ų" + + ax.text(0.05, 0.95, textstr, transform=ax.transAxes, + verticalalignment='top', bbox=dict(boxstyle='round', + facecolor='wheat', alpha=0.5)) + + ax.set_xlabel('k distance (Å⁻¹)') + ax.set_ylabel('E - E$_F$ (eV)') + ax.legend() + ax.grid(True, alpha=0.3) + + title = f"Effective Mass at {point_info['extrema_type']}" + if point_info.get('high_symmetry_label'): + title += f" ({point_info['high_symmetry_label']})" + ax.set_title(title) + + plt.tight_layout() + plt.savefig(output_path, dpi=300) + plt.close() + + return Path(output_path).absolute() + + +def plot_effective_mass_summary( + results: List[Dict[str, Any]], + output_path: Path +) -> Path: + """ + Create summary plot showing effective masses at different points. + + Args: + results: List of effective mass results + output_path: Path to save the plot + + Returns: + Path to saved plot + """ + import matplotlib.pyplot as plt + import numpy as np + + if not results: + return None + + # Extract data + labels = [] + masses = [] + colors = [] + + for res in results: + point_info = res['point_info'] + eff_mass = res['effective_masses'].get('kpath', {}).get('m_star') + + if eff_mass is not None and not np.isinf(eff_mass): + label = point_info['extrema_type'] + if point_info.get('high_symmetry_label'): + label += f"\n{point_info['high_symmetry_label']}" + + labels.append(label) + masses.append(abs(eff_mass)) + + # Color code by type + if 'VBM' in point_info['extrema_type']: + colors.append('blue') + elif 'CBM' in point_info['extrema_type']: + colors.append('red') + else: + colors.append('gray') + + if not masses: + return None + + # Create bar plot + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + + x_pos = np.arange(len(labels)) + bars = ax.bar(x_pos, masses, color=colors, alpha=0.7, edgecolor='black') + + ax.set_xlabel('Band Extrema') + ax.set_ylabel('|m*| / $m_e$') + ax.set_title('Effective Mass Summary') + ax.set_xticks(x_pos) + ax.set_xticklabels(labels, rotation=45, ha='right') + ax.grid(True, alpha=0.3, axis='y') + + # Add value labels on bars + for i, (bar, mass) in enumerate(zip(bars, masses)): + height = bar.get_height() + ax.text(bar.get_x() + bar.get_width()/2., height, + f'{mass:.3f}', + ha='center', va='bottom', fontsize=9) + + plt.tight_layout() + plt.savefig(output_path, dpi=300) + plt.close() + + return Path(output_path).absolute() + + +def abacus_cal_effective_mass( + band_calc_dir: Path, + calculation_points: Union[str, List[Dict]], + fitting_window: int, + directions: List[str], + band_indices: Optional[List[int]], + energy_range: Optional[List[float]], + output_dir: Optional[Path] +) -> Dict[str, Any]: + """ + Calculate effective mass from band structure using parabolic fitting. + + Args: + band_calc_dir: Path to directory containing band calculation results + calculation_points: Where to calculate effective mass + fitting_window: Number of k-points on each side for fitting + directions: Directions for effective mass calculation + band_indices: Specific band indices to analyze + energy_range: Energy range for extrema detection + output_dir: Directory for output files + + Returns: + Dict containing effective mass results, plots, and summary + """ + import json + from datetime import datetime + import numpy as np + + try: + # Set default values + if energy_range is None: + energy_range = [-2.0, 2.0] + + if output_dir is None: + output_dir = band_calc_dir + + # Load band data + print("Loading band data...") + band_data = load_band_data_for_effective_mass(band_calc_dir) + bands = band_data['bands'] + kpoints = band_data['kpoints'] + kline = band_data['kline'] + efermi = band_data['efermi'] + nspin = band_data['nspin'] + bands_dw = band_data['bands_dw'] + + # Determine calculation points + calc_points = [] + + if calculation_points == "auto": + # Find only VBM and CBM + extrema = find_band_extrema(bands, kpoints, kline, efermi, + energy_range, band_indices) + calc_points = [e for e in extrema if e['extrema_type'] in ['VBM', 'CBM']] + + elif calculation_points == "extrema": + # Find all extrema + calc_points = find_band_extrema(bands, kpoints, kline, efermi, + energy_range, band_indices) + + elif isinstance(calculation_points, list): + # User-specified points + for point_spec in calculation_points: + if point_spec['type'] == 'kpoint': + # Find nearest k-point to specified coordinates + target_k = np.array(point_spec['coords']) + k_array = np.array(kpoints) + distances = np.linalg.norm(k_array - target_k, axis=1) + k_idx = np.argmin(distances) + + band_idx = point_spec.get('band_index') + if band_idx is None: + # Find band closest to Fermi level at this k-point + energies = [bands[i][k_idx] for i in range(len(bands))] + band_idx = np.argmin(np.abs(energies)) + + calc_points.append({ + 'band_index': band_idx, + 'kpoint_index': k_idx, + 'kpoint_coords': kpoints[k_idx], + 'energy': bands[band_idx][k_idx], + 'extrema_type': 'user_specified', + 'is_degenerate': False + }) + + elif point_spec['type'] == 'high_symmetry': + label = point_spec['label'] + k_idx = find_high_symmetry_kpoint(label, band_calc_dir, + kpoints, kline) + if k_idx is not None: + band_idx = point_spec.get('band_index') + if band_idx is None: + energies = [bands[i][k_idx] for i in range(len(bands))] + band_idx = np.argmin(np.abs(energies)) + + calc_points.append({ + 'band_index': band_idx, + 'kpoint_index': k_idx, + 'kpoint_coords': kpoints[k_idx], + 'energy': bands[band_idx][k_idx], + 'extrema_type': 'high_symmetry', + 'high_symmetry_label': label, + 'is_degenerate': False + }) + + if not calc_points: + return {'message': 'No calculation points found'} + + print(f"Calculating effective mass at {len(calc_points)} points...") + + # Calculate effective mass at each point + results = [] + plot_paths = [] + + for i, point in enumerate(calc_points): + print(f"Processing point {i+1}/{len(calc_points)}: {point['extrema_type']}") + + # Calculate effective mass tensor + eff_mass_result = calculate_effective_mass_tensor( + point['band_index'], + point['kpoint_index'], + kpoints, + bands, + kline, + fitting_window, + directions + ) + + # Compile result + result = { + 'point_info': point, + 'effective_masses': eff_mass_result['effective_masses'], + 'fitting_data': eff_mass_result['fitting_data'] + } + + results.append(result) + + # Generate plot for this point + plot_filename = f"effective_mass_{point['extrema_type']}_{i}.png" + plot_path = os.path.join(output_dir, plot_filename) + plot_effective_mass_fit(result, plot_path) + plot_paths.append(Path(plot_path).absolute()) + + # Generate summary statistics + electron_masses = [] + hole_masses = [] + + for res in results: + point_type = res['point_info']['extrema_type'] + m_star = res['effective_masses'].get('kpath', {}).get('m_star') + + if m_star is not None and not np.isinf(m_star): + if 'CBM' in point_type or 'local_min' in point_type: + electron_masses.append(abs(m_star)) + elif 'VBM' in point_type or 'local_max' in point_type: + hole_masses.append(abs(m_star)) + + summary = {} + if electron_masses: + summary['electron_effective_mass'] = { + 'average': float(np.mean(electron_masses)), + 'min': float(np.min(electron_masses)), + 'max': float(np.max(electron_masses)), + 'std': float(np.std(electron_masses)) + } + + if hole_masses: + summary['hole_effective_mass'] = { + 'average': float(np.mean(hole_masses)), + 'min': float(np.min(hole_masses)), + 'max': float(np.max(hole_masses)), + 'std': float(np.std(hole_masses)) + } + + # Generate summary plot + summary_plot_path = os.path.join(output_dir, "effective_mass_summary.png") + plot_effective_mass_summary(results, summary_plot_path) + if os.path.exists(summary_plot_path): + plot_paths.append(Path(summary_plot_path).absolute()) + + # Write JSON output + json_output = { + 'metadata': { + 'band_calc_dir': str(band_calc_dir), + 'calculation_date': datetime.now().isoformat(), + 'nspin': nspin, + 'efermi': efermi, + 'fitting_window': fitting_window, + 'directions': directions + }, + 'results': results, + 'summary': summary + } + + json_path = os.path.join(output_dir, "effective_mass_results.json") + with open(json_path, 'w') as f: + json.dump(json_output, f, indent=2, default=str) + + return { + 'effective_mass_results': results, + 'effective_mass_json': Path(json_path).absolute(), + 'effective_mass_plots': plot_paths, + 'summary': summary, + 'message': 'Effective mass calculation completed successfully' + } + + except Exception as e: + import traceback + return {'message': f"Effective mass calculation failed: {e}\n{traceback.format_exc()}"} diff --git a/tests/test_band_effective_mass.py b/tests/test_band_effective_mass.py new file mode 100644 index 0000000..e57cc1b --- /dev/null +++ b/tests/test_band_effective_mass.py @@ -0,0 +1,534 @@ +""" +Unit tests for effective mass calculation functions. +""" +import os +import sys +import pytest +import numpy as np +from pathlib import Path +from unittest.mock import MagicMock + +# Set test mode to avoid MCP server initialization +os.environ["ABACUSAGENT_MODEL"] = "test" + +# Mock MPI-related imports to avoid MPI dependency in tests +sys.modules['mpi4py'] = MagicMock() +sys.modules['mpi4py.MPI'] = MagicMock() +sys.modules['pyatb'] = MagicMock() +sys.modules['pyatb.easy_use'] = MagicMock() +sys.modules['pyatb.easy_use.input_generator'] = MagicMock() +sys.modules['pyatb.easy_use.stru_analyzer'] = MagicMock() +sys.modules['pyatb.parallel'] = MagicMock() + +from abacusagent.modules.submodules.band import ( + fit_parabola_1d, + calculate_effective_mass_from_curvature, + find_band_extrema, +) +from abacusagent.constant import EFFECTIVE_MASS_FACTOR + + +class TestParabolicFitting: + """Test parabolic fitting function.""" + + def test_fit_parabola_perfect_fit(self): + """Test fitting with perfect parabolic data.""" + # Generate perfect parabola: E = 1.0 + 2.0*k^2 + k_distances = np.linspace(-0.5, 0.5, 11) + energies = 1.0 + 2.0 * k_distances**2 + + result = fit_parabola_1d(k_distances.tolist(), energies.tolist(), 0.0) + + assert result['E0'] == pytest.approx(1.0, abs=1e-10) + assert result['a'] == pytest.approx(2.0, abs=1e-10) + assert result['curvature'] == pytest.approx(4.0, abs=1e-10) + assert result['r_squared'] == pytest.approx(1.0, abs=1e-10) + + def test_fit_parabola_with_noise(self): + """Test fitting with noisy data.""" + np.random.seed(42) + k_distances = np.linspace(-0.5, 0.5, 11) + energies = 1.0 + 2.0 * k_distances**2 + np.random.normal(0, 0.01, 11) + + result = fit_parabola_1d(k_distances.tolist(), energies.tolist(), 0.0) + + # Should still be close to true values + assert result['E0'] == pytest.approx(1.0, abs=0.1) + assert result['a'] == pytest.approx(2.0, abs=0.1) + assert result['r_squared'] > 0.95 + + def test_fit_parabola_insufficient_points(self): + """Test fitting with insufficient data points.""" + k_distances = [0.0, 0.1] + energies = [1.0, 1.02] + + result = fit_parabola_1d(k_distances, energies, 0.0) + + assert result['curvature'] is None + assert 'error' in result + + def test_fit_parabola_negative_curvature(self): + """Test fitting with negative curvature (band maximum).""" + k_distances = np.linspace(-0.5, 0.5, 11) + energies = 2.0 - 3.0 * k_distances**2 # Negative curvature + + result = fit_parabola_1d(k_distances.tolist(), energies.tolist(), 0.0) + + assert result['a'] == pytest.approx(-3.0, abs=1e-10) + assert result['curvature'] == pytest.approx(-6.0, abs=1e-10) + + +class TestEffectiveMassCalculation: + """Test effective mass calculation from curvature.""" + + def test_calculate_effective_mass_positive_curvature(self): + """Test effective mass calculation with positive curvature.""" + # Typical semiconductor electron effective mass + # For Si, m* ≈ 0.26 m_e, curvature ≈ EFFECTIVE_MASS_FACTOR / 0.26 + curvature = EFFECTIVE_MASS_FACTOR / 0.26 + + result = calculate_effective_mass_from_curvature(curvature, 'kx') + + assert result['m_star'] == pytest.approx(0.26, abs=1e-6) + assert result['is_flat_band'] is False + + def test_calculate_effective_mass_negative_curvature(self): + """Test effective mass calculation with negative curvature (hole).""" + # Typical hole effective mass + curvature = -EFFECTIVE_MASS_FACTOR / 0.5 + + result = calculate_effective_mass_from_curvature(curvature, 'kx') + + assert result['m_star'] == pytest.approx(-0.5, abs=1e-6) + assert result['is_flat_band'] is False + + def test_calculate_effective_mass_flat_band(self): + """Test effective mass calculation for flat band.""" + curvature = 1e-12 # Nearly zero curvature + + result = calculate_effective_mass_from_curvature(curvature, 'kx') + + assert np.isinf(result['m_star']) + assert result['is_flat_band'] is True + + def test_calculate_effective_mass_zero_curvature(self): + """Test effective mass calculation with exactly zero curvature.""" + curvature = 0.0 + + result = calculate_effective_mass_from_curvature(curvature, 'kx') + + assert np.isinf(result['m_star']) + assert result['is_flat_band'] is True + + +class TestBandExtremaDetection: + """Test band extrema detection function.""" + + def test_find_vbm_cbm_simple(self): + """Test VBM and CBM detection in simple band structure.""" + # Create simple band structure with clear VBM and CBM + # 3 bands, 10 k-points + bands = [ + [-2.0, -1.9, -1.8, -1.7, -1.6, -1.5, -1.4, -1.3, -1.2, -1.1], # Valence band + [-0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3, 0.4], # Band crossing Fermi + [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9] # Conduction band + ] + kpoints = [[i*0.1, 0, 0] for i in range(10)] + kline = [i*0.1 for i in range(10)] + efermi = 0.0 + energy_range = [-3.0, 3.0] + + extrema = find_band_extrema(bands, kpoints, kline, efermi, energy_range) + + # Should find VBM and CBM + vbm = [e for e in extrema if e['extrema_type'] == 'VBM'] + cbm = [e for e in extrema if e['extrema_type'] == 'CBM'] + + assert len(vbm) == 1 + assert len(cbm) == 1 + assert vbm[0]['energy'] == pytest.approx(-0.1, abs=1e-6) + assert cbm[0]['energy'] == pytest.approx(0.1, abs=1e-6) + + def test_find_local_extrema(self): + """Test detection of local extrema.""" + # Create band with local maximum + bands = [ + [-1.0, -0.5, 0.0, -0.5, -1.0, -1.5, -2.0, -2.5, -3.0, -3.5] # Local max at k=2 + ] + kpoints = [[i*0.1, 0, 0] for i in range(10)] + kline = [i*0.1 for i in range(10)] + efermi = 0.0 + energy_range = [-4.0, 1.0] + + extrema = find_band_extrema(bands, kpoints, kline, efermi, energy_range) + + # Should find VBM (which is also local max at k=2) + # But VBM is the highest energy below Fermi, which is at k=1 (-0.5 eV) + assert len(extrema) >= 1 + vbm = [e for e in extrema if e['extrema_type'] == 'VBM'][0] + # VBM is at k=1 with energy -0.5 eV (highest below Fermi) + assert vbm['kpoint_index'] == 1 + + def test_find_extrema_with_band_indices(self): + """Test extrema detection with specific band indices.""" + bands = [ + [-2.0, -1.9, -1.8, -1.7, -1.6, -1.5, -1.4, -1.3, -1.2, -1.1], + [-0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3, 0.4], + [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9] + ] + kpoints = [[i*0.1, 0, 0] for i in range(10)] + kline = [i*0.1 for i in range(10)] + efermi = 0.0 + energy_range = [-3.0, 3.0] + + # Only analyze band 1 + extrema = find_band_extrema(bands, kpoints, kline, efermi, energy_range, + band_indices=[1]) + + # Should only find extrema in band 1 + for e in extrema: + assert e['band_index'] == 1 + + +class TestIntegration: + """Integration tests combining multiple functions.""" + + def test_full_effective_mass_workflow(self): + """Test complete workflow from band data to effective mass.""" + # Generate parabolic band around k=0 + k_distances = np.linspace(-0.5, 0.5, 11) + curvature_true = 4.0 # eV/Å^2 + energies = 1.0 + 0.5 * curvature_true * k_distances**2 + + # Fit parabola + fit_result = fit_parabola_1d(k_distances.tolist(), energies.tolist(), 0.0) + + # Calculate effective mass + mass_result = calculate_effective_mass_from_curvature( + fit_result['curvature'], + 'kpath' + ) + + # Verify + assert fit_result['curvature'] == pytest.approx(curvature_true, abs=1e-10) + expected_mass = EFFECTIVE_MASS_FACTOR / curvature_true + assert mass_result['m_star'] == pytest.approx(expected_mass, rel=1e-10) + + def test_effective_mass_for_silicon_like_band(self): + """Test effective mass calculation for Si-like band structure.""" + # Si electron effective mass at Γ point: m* ≈ 0.26 m_e + # This corresponds to curvature ≈ EFFECTIVE_MASS_FACTOR / 0.26 + expected_mass = 0.26 + curvature = EFFECTIVE_MASS_FACTOR / expected_mass + + # Generate band data + k_distances = np.linspace(-0.2, 0.2, 21) + energies = 0.5 + 0.5 * curvature * k_distances**2 + + # Fit and calculate + fit_result = fit_parabola_1d(k_distances.tolist(), energies.tolist(), 0.0) + mass_result = calculate_effective_mass_from_curvature( + fit_result['curvature'], + 'kpath' + ) + + assert mass_result['m_star'] == pytest.approx(expected_mass, abs=1e-3) + + +class TestParabolicFittingEdgeCases: + """Additional edge case tests for parabolic fitting.""" + + def test_fit_parabola_asymmetric_window(self): + """Test fitting with asymmetric data around extremum.""" + # More points on one side + k_distances = np.concatenate([ + np.linspace(-0.5, 0.0, 3), + np.linspace(0.0, 0.5, 8) + ]) + energies = 1.0 + 2.0 * k_distances**2 + + result = fit_parabola_1d(k_distances.tolist(), energies.tolist(), 0.0) + + # Should still fit reasonably well + assert result['a'] == pytest.approx(2.0, abs=0.1) + assert result['r_squared'] > 0.99 + + def test_fit_parabola_off_center(self): + """Test fitting with extremum not at k=0.""" + k0 = 0.3 + k_distances = np.linspace(0.0, 0.6, 11) + energies = 1.5 + 3.0 * (k_distances - k0)**2 + + result = fit_parabola_1d(k_distances.tolist(), energies.tolist(), k0) + + assert result['E0'] == pytest.approx(1.5, abs=1e-2) + assert result['a'] == pytest.approx(3.0, abs=1e-2) + + def test_fit_parabola_very_flat_band(self): + """Test fitting with very small curvature.""" + k_distances = np.linspace(-0.5, 0.5, 11) + energies = 1.0 + 1e-6 * k_distances**2 + + result = fit_parabola_1d(k_distances.tolist(), energies.tolist(), 0.0) + + assert result['curvature'] == pytest.approx(2e-6, abs=1e-8) + assert result['r_squared'] > 0.5 + + def test_fit_parabola_linear_component(self): + """Test fitting with linear component (non-extremum point).""" + k_distances = np.linspace(-0.5, 0.5, 11) + # E = 1.0 + 0.5*k + 2.0*k^2 (has linear term) + energies = 1.0 + 0.5 * k_distances + 2.0 * k_distances**2 + + result = fit_parabola_1d(k_distances.tolist(), energies.tolist(), 0.0) + + # Should still extract curvature correctly + assert result['curvature'] == pytest.approx(4.0, abs=1e-10) + assert result['b'] == pytest.approx(0.5, abs=1e-10) + + +class TestEffectiveMassEdgeCases: + """Additional edge case tests for effective mass calculation.""" + + def test_very_light_effective_mass(self): + """Test calculation with very light effective mass (large curvature).""" + # m* = 0.01 m_e (very light) + curvature = EFFECTIVE_MASS_FACTOR / 0.01 + + result = calculate_effective_mass_from_curvature(curvature, 'kx') + + assert result['m_star'] == pytest.approx(0.01, abs=1e-6) + assert result['is_flat_band'] is False + + def test_very_heavy_effective_mass(self): + """Test calculation with very heavy effective mass (small curvature).""" + # m* = 10.0 m_e (very heavy) + curvature = EFFECTIVE_MASS_FACTOR / 10.0 + + result = calculate_effective_mass_from_curvature(curvature, 'kx') + + assert result['m_star'] == pytest.approx(10.0, abs=1e-6) + assert result['is_flat_band'] is False + + def test_effective_mass_sign_preservation(self): + """Test that sign of effective mass is preserved.""" + # Positive curvature (electron-like) + curvature_pos = EFFECTIVE_MASS_FACTOR / 0.5 + result_pos = calculate_effective_mass_from_curvature(curvature_pos, 'kx') + assert result_pos['m_star'] > 0 + + # Negative curvature (hole-like) + curvature_neg = -EFFECTIVE_MASS_FACTOR / 0.5 + result_neg = calculate_effective_mass_from_curvature(curvature_neg, 'kx') + assert result_neg['m_star'] < 0 + + def test_effective_mass_near_zero_curvature(self): + """Test behavior near flat band threshold.""" + # Just above threshold (threshold is 1e-10) + curvature = 1e-11 + result = calculate_effective_mass_from_curvature(curvature, 'kx') + assert result['is_flat_band'] is True + assert np.isinf(result['m_star']) + + +class TestBandExtremaEdgeCases: + """Additional edge case tests for band extrema detection.""" + + def test_find_extrema_no_gap(self): + """Test extrema detection in metallic system (no gap).""" + # Band crossing Fermi level + bands = [ + [-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5] + ] + kpoints = [[i*0.1, 0, 0] for i in range(10)] + kline = [i*0.1 for i in range(10)] + efermi = 0.0 + energy_range = [-2.0, 4.0] + + extrema = find_band_extrema(bands, kpoints, kline, efermi, energy_range) + + # Should find VBM and CBM at adjacent points + vbm = [e for e in extrema if e['extrema_type'] == 'VBM'] + cbm = [e for e in extrema if e['extrema_type'] == 'CBM'] + + assert len(vbm) == 1 + assert len(cbm) == 1 + + def test_find_extrema_multiple_bands(self): + """Test extrema detection with multiple bands.""" + bands = [ + [-3.0, -2.5, -2.0, -1.5, -1.0, -0.5, -0.4, -0.3, -0.2, -0.1], # VB + [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1], # CB1 + [1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4] # CB2 + ] + kpoints = [[i*0.1, 0, 0] for i in range(10)] + kline = [i*0.1 for i in range(10)] + efermi = 0.0 + energy_range = [-4.0, 3.0] + + extrema = find_band_extrema(bands, kpoints, kline, efermi, energy_range) + + # Should find VBM in band 0 and CBM in band 1 + vbm = [e for e in extrema if e['extrema_type'] == 'VBM'] + cbm = [e for e in extrema if e['extrema_type'] == 'CBM'] + + assert len(vbm) == 1 + assert len(cbm) == 1 + assert vbm[0]['band_index'] == 0 + assert cbm[0]['band_index'] == 1 + + def test_find_extrema_at_boundary(self): + """Test extrema detection when extremum is at k-path boundary.""" + # Band with maximum at first point + bands = [ + [0.0, -0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8, -0.9] + ] + kpoints = [[i*0.1, 0, 0] for i in range(10)] + kline = [i*0.1 for i in range(10)] + efermi = 0.0 + energy_range = [-1.0, 0.5] + + extrema = find_band_extrema(bands, kpoints, kline, efermi, energy_range) + + # Should find VBM - but it's at the boundary (k=0) + # The algorithm looks for local extrema (comparing with neighbors) + # So boundary points may not be detected as local extrema + # VBM will be the highest energy below Fermi, which is at k=0 (0.0 eV) + # But since 0.0 is not < 0, it won't be VBM. Next is k=1 (-0.1 eV) + vbm = [e for e in extrema if e['extrema_type'] == 'VBM'] + assert len(vbm) == 1 + assert vbm[0]['kpoint_index'] == 1 # -0.1 eV is the highest below Fermi + + def test_find_extrema_with_plateau(self): + """Test extrema detection with flat region (plateau).""" + bands = [ + [-1.0, -0.5, 0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 1.5, 2.0] + ] + kpoints = [[i*0.1, 0, 0] for i in range(10)] + kline = [i*0.1 for i in range(10)] + efermi = 0.0 + energy_range = [-2.0, 3.0] + + extrema = find_band_extrema(bands, kpoints, kline, efermi, energy_range) + + # Should find CBM (one of the plateau points) + cbm = [e for e in extrema if e['extrema_type'] == 'CBM'] + assert len(cbm) >= 1 + + +class TestFittingQuality: + """Tests for fitting quality assessment.""" + + def test_good_fit_high_r_squared(self): + """Test that good parabolic data gives high R².""" + k_distances = np.linspace(-0.5, 0.5, 21) + energies = 1.0 + 2.0 * k_distances**2 + + result = fit_parabola_1d(k_distances.tolist(), energies.tolist(), 0.0) + + assert result['r_squared'] > 0.999 + + def test_poor_fit_low_r_squared(self): + """Test that non-parabolic data gives lower R².""" + k_distances = np.linspace(-0.5, 0.5, 21) + # Quartic function (not parabolic) + energies = 1.0 + 2.0 * k_distances**4 + + result = fit_parabola_1d(k_distances.tolist(), energies.tolist(), 0.0) + + # R² should be lower for non-parabolic data + assert result['r_squared'] < 0.99 + + def test_residuals_calculation(self): + """Test that residuals are calculated correctly.""" + k_distances = np.linspace(-0.5, 0.5, 11) + energies = 1.0 + 2.0 * k_distances**2 + + result = fit_parabola_1d(k_distances.tolist(), energies.tolist(), 0.0) + + # For perfect fit, residuals should be near zero + assert all(abs(r) < 1e-10 for r in result['residuals']) + + def test_num_points_in_result(self): + """Test that number of points is correctly reported.""" + k_distances = np.linspace(-0.5, 0.5, 15) + energies = 1.0 + 2.0 * k_distances**2 + + result = fit_parabola_1d(k_distances.tolist(), energies.tolist(), 0.0) + + assert result['num_points'] == 15 + + +class TestPhysicalConstants: + """Tests to verify physical constants are correct.""" + + def test_effective_mass_factor_value(self): + """Test that EFFECTIVE_MASS_FACTOR has reasonable value.""" + # Should be around 7.62 for the given units + assert EFFECTIVE_MASS_FACTOR > 7.0 + assert EFFECTIVE_MASS_FACTOR < 8.0 + + def test_effective_mass_units_consistency(self): + """Test unit consistency in effective mass calculation.""" + # For GaAs: m* ≈ 0.067 m_e + # Typical curvature for GaAs CBM + m_star_expected = 0.067 + curvature = EFFECTIVE_MASS_FACTOR / m_star_expected + + result = calculate_effective_mass_from_curvature(curvature, 'kx') + + assert result['m_star'] == pytest.approx(m_star_expected, abs=1e-6) + + +class TestRobustness: + """Tests for robustness and error handling.""" + + def test_fit_with_nan_values(self): + """Test fitting behavior with NaN values.""" + k_distances = [0.0, 0.1, 0.2, 0.3, 0.4] + energies = [1.0, 1.02, np.nan, 1.08, 1.16] + + result = fit_parabola_1d(k_distances, energies, 0.0) + + # Should handle NaN gracefully + assert 'error' in result or result['curvature'] is None + + def test_fit_with_inf_values(self): + """Test fitting behavior with infinite values.""" + k_distances = [0.0, 0.1, 0.2, 0.3, 0.4] + energies = [1.0, 1.02, np.inf, 1.08, 1.16] + + result = fit_parabola_1d(k_distances, energies, 0.0) + + # Should handle inf gracefully - polyfit returns NaN + assert 'error' in result or result['curvature'] is None or np.isnan(result['curvature']) + + def test_empty_band_list(self): + """Test extrema detection with empty band list.""" + bands = [] + kpoints = [] + kline = [] + efermi = 0.0 + energy_range = [-1.0, 1.0] + + extrema = find_band_extrema(bands, kpoints, kline, efermi, energy_range) + + assert len(extrema) == 0 + + def test_single_kpoint(self): + """Test extrema detection with single k-point.""" + bands = [[0.5]] + kpoints = [[0.0, 0.0, 0.0]] + kline = [0.0] + efermi = 0.0 + energy_range = [-1.0, 1.0] + + extrema = find_band_extrema(bands, kpoints, kline, efermi, energy_range) + + # Cannot find local extrema with single point + # But should find VBM or CBM if energy is appropriate + assert len(extrema) <= 1 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])