Skip to content

Commit 8809c60

Browse files
author
Paul Prescod
committed
Make failure to close a stream an error, as it would be by default.
1 parent ca96441 commit 8809c60

4 files changed

Lines changed: 49 additions & 25 deletions

File tree

snowfakery/api.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -224,16 +224,9 @@ def configure_output_stream(
224224
try:
225225
yield output_stream
226226
finally:
227-
try:
228-
messages = output_stream.close()
229-
except Exception as e:
230-
messages = None
231-
parent_application.echo(
232-
f"Could not close {output_stream}: {str(e)}", err=True
233-
)
234-
if messages:
235-
for message in messages:
236-
parent_application.echo(message)
227+
messages = output_stream.close() or []
228+
for message in messages:
229+
parent_application.echo(message)
237230

238231

239232
@contextmanager

snowfakery/output_streams.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def close(self) -> Optional[Sequence[str]]:
125125
126126
Return a list of messages to print out.
127127
"""
128-
return super().close()
128+
raise NotImplementedError()
129129

130130
def __enter__(self, *args):
131131
return self
@@ -578,8 +578,20 @@ def write_row(self, tablename: str, row_with_references: Dict) -> None:
578578
stream.write_row(tablename, row_with_references)
579579

580580
def close(self) -> Optional[Sequence[str]]:
581+
all_messages = []
582+
closing_errors = []
581583
for stream in self.outputstreams:
582-
stream.close()
584+
try:
585+
messages = stream.close() or []
586+
all_messages.extend(messages)
587+
except Exception as e:
588+
closing_errors.append(e)
589+
590+
if len(closing_errors) == 1:
591+
raise closing_errors[0]
592+
elif closing_errors:
593+
raise IOError(f"Could not close streams: {closing_errors}")
594+
return all_messages
583595

584596
def write_single_row(self, tablename: str, row: Dict) -> None:
585597
return super().write_single_row(tablename, row)

tests/test_embedding.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_continuation_as_open_file(self):
6363
with mapping_file.open() as f:
6464
assert yaml.safe_load(f)
6565

66-
def test_parent_application__echo(self):
66+
def test_parent_application__exception_raised(self):
6767
called = False
6868

6969
class MyEmbedder(SnowfakeryApplication):
@@ -74,10 +74,10 @@ def echo(self, *args, **kwargs):
7474
meth = "snowfakery.output_streams.DebugOutputStream.close"
7575
with mock.patch(meth) as close:
7676
close.side_effect = AssertionError
77-
generate_data(
78-
yaml_file="examples/company.yml", parent_application=MyEmbedder()
79-
)
80-
assert called
77+
with pytest.raises(AssertionError):
78+
generate_data(
79+
yaml_file="examples/company.yml", parent_application=MyEmbedder()
80+
)
8181

8282
def test_parent_application__early_finish(self, generated_rows):
8383
class MyEmbedder(SnowfakeryApplication):
@@ -89,14 +89,9 @@ def check_if_finished(self, idmanager):
8989
assert self.__class__.count < 100, "Runaway recipe!"
9090
return idmanager["Employee"] >= 10
9191

92-
meth = "snowfakery.output_streams.DebugOutputStream.close"
93-
with mock.patch(meth) as close:
94-
close.side_effect = AssertionError
95-
generate_data(
96-
yaml_file="examples/company.yml", parent_application=MyEmbedder()
97-
)
98-
# called 5 times, after generating 2 employees each
99-
assert MyEmbedder.count == 5
92+
generate_data(yaml_file="examples/company.yml", parent_application=MyEmbedder())
93+
# called 5 times, after generating 2 employees each
94+
assert MyEmbedder.count == 5
10095

10196
def test_embedding__cannot_infer_output_format(self):
10297
with pytest.raises(exc.DataGenError, match="No format"):

tests/test_output_streams.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pathlib import Path
77
from tempfile import TemporaryDirectory
88
from contextlib import redirect_stdout
9+
from unittest import mock
910

1011

1112
import pytest
@@ -375,3 +376,26 @@ def test_external_output_stream__failure(self):
375376
generate_cli.callback(
376377
yaml_file=sample_yaml, output_format="no.such.output.Stream"
377378
)
379+
380+
381+
class TestMultiplexOutputStream:
382+
@mock.patch("snowfakery.output_streams.DebugOutputStream.close")
383+
def test_cannot_close_multiple_streams(self, close):
384+
close.side_effect = AssertionError
385+
with TemporaryDirectory() as t:
386+
files = [Path(t) / "1.txt", Path(t) / "2.txt"]
387+
with pytest.raises(IOError) as e:
388+
generate_cli.callback(
389+
yaml_file="examples/company.yml", output_files=files
390+
)
391+
assert "Could not close streams:" in str(e.value)
392+
393+
@mock.patch("snowfakery.output_streams.DebugOutputStream.close")
394+
def test_cannot_close_one_stream(self, close):
395+
close.side_effect = AssertionError
396+
with TemporaryDirectory() as t:
397+
files = [Path(t) / "1.txt", Path(t) / "2.jpg"]
398+
with pytest.raises(AssertionError):
399+
generate_cli.callback(
400+
yaml_file="examples/company.yml", output_files=files
401+
)

0 commit comments

Comments
 (0)