Skip to content

Commit f5ce346

Browse files
author
gabriel.g.robin
committed
nullable list
1 parent 54cbc75 commit f5ce346

3 files changed

Lines changed: 9 additions & 4 deletions

File tree

dataframely/columns/list.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,4 +147,7 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
147147
chain([0], element_lengths.cum_sum()), element_lengths
148148
)
149149
]
150-
return pl.Series(list_elements)
150+
# Finally, apply a null mask
151+
return generator._apply_null_mask(
152+
pl.Series(list_elements), null_probability=self._null_probability
153+
)

dataframely/random.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,10 @@ def sample_duration(
392392
def _apply_null_mask(self, series: pl.Series, null_probability: float) -> pl.Series:
393393
if null_probability == 0:
394394
return series
395-
null_mask = self.numpy_generator.random(series.len()) < null_probability
396-
return series.scatter(np.where(null_mask)[0], None)
395+
null_mask = (
396+
pl.Series(self.numpy_generator.random(series.len())) > null_probability
397+
)
398+
return pl.select(pl.when(null_mask).then(series)).to_series()
397399

398400

399401
# --------------------------------------- UTILS -------------------------------------- #

tests/columns/test_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def test_sample_enum(generator: Generator) -> None:
178178
def test_sample_list(generator: Generator) -> None:
179179
column = dy.List(dy.String(regex="[abc]"), min_length=5, max_length=10)
180180
samples = sample_and_validate(column, generator, n=10_000)
181-
assert set(samples.list.len()) == set(range(5, 11))
181+
assert set(samples.list.len()) == set(range(5, 11)) | {None}
182182

183183

184184
def test_sample_struct(generator: Generator) -> None:

0 commit comments

Comments
 (0)