Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changes/150.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add four missing antennas for MeerKat layout
1 change: 1 addition & 0 deletions docs/changes/150.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
allow sinlge values for simulation config (bypass random value drawing)
1 change: 1 addition & 0 deletions docs/changes/150.optimization.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add channel bandwidths to UVH5Writer
4 changes: 4 additions & 0 deletions resources/layouts/meerkat.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,7 @@ m060 5107254.013471882 2009699.3572179652 -3240542.587340528 13.5 15.0 85.0 110.
m061 5108278.559161539 2006410.136906058 -3240956.885313587 13.5 15.0 85.0 110.0 1000.0
m062 5108713.98241022 2005051.0165491276 -3241111.829132157 13.5 15.0 85.0 110.0 1000.0
m063 5109748.526320712 2003331.232038675 -3240538.853735712 13.5 15.0 85.0 110.0 1000.0
m008 5109148.3563610995 2006668.9215486543 -3239413.4521834562 13.5 15.0 85.0 110.0 1000.0
m021 5109319.2750222068 2006518.5606004531 -3239233.6195771112 13.5 15.0 85.0 110.0 1000.0
m022 5109501.2780305194 2006507.2823303600 -3238950.5447669746 13.5 15.0 85.0 110.0 1000.0
m023 5109415.8381663673 2006528.1891329922 -3239073.8564673522 13.5 15.0 85.0 110.0 1000.0
68 changes: 36 additions & 32 deletions src/pyvisgen/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def _run(self) -> None:
index=i * bundle_length + j,
sky=SIs[j],
overwrite=True,
normalize=self.conf.sampling.normalize,
)

if fits_writer is not None:
Expand Down Expand Up @@ -472,46 +473,49 @@ def draw_sampling_opts(self, size: int) -> dict:
samp_opts : dict
Sampling options/parameters stored inside a dictionary.
"""
ra = self.rng.uniform(
self.conf.sampling.fov_center_ra[0],
self.conf.sampling.fov_center_ra[1],
size,
ra_cfg = self.conf.sampling.fov_center_ra
ra = (
np.full(size, ra_cfg[0])
if len(ra_cfg) == 1
else self.rng.uniform(ra_cfg[0], ra_cfg[1], size)
)
dec = self.rng.uniform(
self.conf.sampling.fov_center_dec[0],
self.conf.sampling.fov_center_dec[1],
size,

dec_cfg = self.conf.sampling.fov_center_dec
dec = (
np.full(size, dec_cfg[0])
if len(dec_cfg) == 1
else self.rng.uniform(dec_cfg[0], dec_cfg[1], size)
)

start_time_l = datetime.strptime(
self.conf.sampling.scan_start[0], self.date_fmt
)
start_time_h = datetime.strptime(
self.conf.sampling.scan_start[1], self.date_fmt
)
start_times = np.arange(
start_time_l,
start_time_h,
timedelta(hours=1),
).astype(datetime)

scan_start = self.rng.choice(start_times, size)
scan_duration = self.rng.integers(
self.conf.sampling.scan_duration[0],
self.conf.sampling.scan_duration[1],
size,
)
num_scans = self.rng.integers(
self.conf.sampling.num_scans[0],
self.conf.sampling.num_scans[1],
size,
if len(self.conf.sampling.scan_start) == 1:
scan_start = np.full(size, start_time_l)
else:
start_time_h = datetime.strptime(
self.conf.sampling.scan_start[1], self.date_fmt
)
start_times = np.arange(
start_time_l,
start_time_h,
timedelta(hours=1),
).astype(datetime)
scan_start = self.rng.choice(start_times, size)

dur_cfg = self.conf.sampling.scan_duration
scan_duration = (
np.full(size, dur_cfg[0], dtype=int)
if len(dur_cfg) == 1
else self.rng.integers(dur_cfg[0], dur_cfg[1], size)
)

if scan_duration.size == 1:
scan_duration = scan_duration.astype(int)

if num_scans.size == 1:
num_scans = num_scans.astype(int)
ns_cfg = self.conf.sampling.num_scans
num_scans = (
np.full(size, ns_cfg[0], dtype=int)
if len(ns_cfg) == 1
else self.rng.integers(ns_cfg[0], ns_cfg[1], size)
)

# if polarization is None, we don't need to enter the
# conditional below, so we set delta, amp_ratio, field_order,
Expand Down
4 changes: 2 additions & 2 deletions src/pyvisgen/io/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def validate_layout(cls, layout: str) -> None:
@field_validator("scan_start")
@classmethod
def validate_dates(cls, v: list[str]) -> None:
if len(v) != 2:
raise ValueError("expected 'scan_start' to be a list of len 2")
if len(v) not in (1, 2):
raise ValueError("expected 'scan_start' to be a list of len 1 or 2")

return v

Expand Down
4 changes: 4 additions & 0 deletions src/pyvisgen/io/datawriters.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ class UVH5Writer(DataWriter):
│ ├── m
│ └── n
├── frequency_bands
├── channel_widths
└── sky/
└── SI

Expand Down Expand Up @@ -466,6 +467,7 @@ def write(
obs,
index: int,
sky=None,
normalize: bool = True,
**kwargs,
) -> None:
"""Write simulation data to an HDF5 file.
Expand Down Expand Up @@ -523,6 +525,8 @@ def write(

freq_bands = self.__to_numpy(obs.ref_frequency + obs.frequency_offsets)
f.create_dataset("frequency_bands", data=freq_bands)
f.create_dataset("channel_widths", data=self.__to_numpy(obs.bandwidths))
f.create_dataset("normalize", data=np.bool_(normalize))

if sky is not None:
sky_grp = f.create_group("sky")
Expand Down
1 change: 1 addition & 0 deletions src/pyvisgen/simulation/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import astropy.units as un
import numpy as np
import numpy.typing
import torch
from astropy.constants import c
from astropy.coordinates import AltAz, EarthLocation, Longitude, SkyCoord
Expand Down
42 changes: 42 additions & 0 deletions tests/dataset/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch

from pyvisgen.dataset import SimulateDataSet
from pyvisgen.dataset.dataset import DATEFMT
from pyvisgen.io import Config
from pyvisgen.layouts import Stations

Expand Down Expand Up @@ -238,6 +239,47 @@ def test_polarization(self, pol_mode: str, sd_sampling: SimulateDataSet) -> None
assert samp_opts["order"].shape == (size, 2)
assert samp_opts["scale"].shape == (size, 2)

@pytest.mark.parametrize(
"field,value",
[
("fov_center_ra", [-173.867]),
("fov_center_dec", [6.474]),
("scan_duration", [272]),
("num_scans", [9]),
],
)
def test_fixed_scalar_fields(
self, field: str, value: list, sd_sampling: SimulateDataSet
) -> None:
"""Single-value list bypasses random draw and repeats the fixed value."""
setattr(sd_sampling.conf.sampling, field, value)

size = 5
samp_opts = sd_sampling.draw_sampling_opts(size)

key_map = {
"fov_center_ra": "src_ra",
"fov_center_dec": "src_dec",
"scan_duration": "scan_duration",
"num_scans": "num_scans",
}
result = samp_opts[key_map[field]]
assert result.shape == (size,)
assert (result == value[0]).all()

def test_fixed_scan_start(self, sd_sampling: SimulateDataSet) -> None:
"""Single-value scan_start bypasses random draw and
repeats the fixed datetime."""
date_str = "22-04-2023 17:21:11"
sd_sampling.conf.sampling.scan_start = [date_str]

size = 5
samp_opts = sd_sampling.draw_sampling_opts(size)

expected = datetime.strptime(date_str, DATEFMT)
assert samp_opts["start_time"].shape == (size,)
assert all(t == expected for t in samp_opts["start_time"])

def test_polarization_kwargs_none(self, sd_sampling: SimulateDataSet) -> None:
sd_sampling.conf.polarization.mode = "linear"
sd_sampling.conf.polarization.delta = None
Expand Down
1 change: 1 addition & 0 deletions tests/io/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def uvh5_obs() -> SimpleNamespace:
lm=lm,
ref_frequency=torch.tensor(15.7e9),
frequency_offsets=torch.tensor([0.0, 1.0e6]),
bandwidths=torch.tensor([1.0e6, 1.0e6]),
)


Expand Down
16 changes: 14 additions & 2 deletions tests/io/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,23 @@ def test_validate_dates(self) -> None:

assert cfg.scan_start == dates

def test_validate_dates_single(self) -> None:
dates = ["2024-06-15 10:00:00"]
cfg = SamplingConfig(scan_start=dates)

assert cfg.scan_start == dates

def test_validate_dates_invalid(self) -> None:
with pytest.raises(ValueError) as excinfo:
SamplingConfig(scan_start=["2025-01-01 12:00:00"])
SamplingConfig(
scan_start=[
"2024-01-01 12:00:00",
"2025-01-01 12:00:00",
"2026-01-01 12:00:00",
]
)

assert "expected 'scan_start' to be a list of len 2" in str(excinfo.value)
assert "expected 'scan_start' to be a list of len 1 or 2" in str(excinfo.value)

@pytest.mark.parametrize(
"seed,expected", [(42, 42), (None, None), ("none", None), (False, None)]
Expand Down