Skip to content

Commit 312b9b8

Browse files
author
Paul Prescod
committed
Make failure to close a stream an error, as it would be by default.
1 parent ca96441 commit 312b9b8

3 files changed

Lines changed: 25 additions & 25 deletions

File tree

snowfakery/api.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -224,16 +224,9 @@ def configure_output_stream(
224224
try:
225225
yield output_stream
226226
finally:
227-
try:
228-
messages = output_stream.close()
229-
except Exception as e:
230-
messages = None
231-
parent_application.echo(
232-
f"Could not close {output_stream}: {str(e)}", err=True
233-
)
234-
if messages:
235-
for message in messages:
236-
parent_application.echo(message)
227+
messages = output_stream.close() or []
228+
for message in messages:
229+
parent_application.echo(message)
237230

238231

239232
@contextmanager

snowfakery/output_streams.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def close(self) -> Optional[Sequence[str]]:
125125
126126
Return a list of messages to print out.
127127
"""
128-
return super().close()
128+
raise NotImplementedError()
129129

130130
def __enter__(self, *args):
131131
return self
@@ -578,8 +578,20 @@ def write_row(self, tablename: str, row_with_references: Dict) -> None:
578578
stream.write_row(tablename, row_with_references)
579579

580580
def close(self) -> Optional[Sequence[str]]:
581+
all_messages = []
582+
closing_errors = []
581583
for stream in self.outputstreams:
582-
stream.close()
584+
try:
585+
messages = stream.close() or []
586+
all_messages.extend(messages)
587+
except Exception as e:
588+
closing_errors.append(e)
589+
590+
if len(closing_errors) == 1:
591+
raise closing_errors[1]
592+
elif closing_errors:
593+
raise IOError(f"Could not close streams: {closing_errors}")
594+
return all_messages
583595

584596
def write_single_row(self, tablename: str, row: Dict) -> None:
585597
return super().write_single_row(tablename, row)

tests/test_embedding.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_continuation_as_open_file(self):
6363
with mapping_file.open() as f:
6464
assert yaml.safe_load(f)
6565

66-
def test_parent_application__echo(self):
66+
def test_parent_application__exception_raised(self):
6767
called = False
6868

6969
class MyEmbedder(SnowfakeryApplication):
@@ -74,10 +74,10 @@ def echo(self, *args, **kwargs):
7474
meth = "snowfakery.output_streams.DebugOutputStream.close"
7575
with mock.patch(meth) as close:
7676
close.side_effect = AssertionError
77-
generate_data(
78-
yaml_file="examples/company.yml", parent_application=MyEmbedder()
79-
)
80-
assert called
77+
with pytest.raises(AssertionError):
78+
generate_data(
79+
yaml_file="examples/company.yml", parent_application=MyEmbedder()
80+
)
8181

8282
def test_parent_application__early_finish(self, generated_rows):
8383
class MyEmbedder(SnowfakeryApplication):
@@ -89,14 +89,9 @@ def check_if_finished(self, idmanager):
8989
assert self.__class__.count < 100, "Runaway recipe!"
9090
return idmanager["Employee"] >= 10
9191

92-
meth = "snowfakery.output_streams.DebugOutputStream.close"
93-
with mock.patch(meth) as close:
94-
close.side_effect = AssertionError
95-
generate_data(
96-
yaml_file="examples/company.yml", parent_application=MyEmbedder()
97-
)
98-
# called 5 times, after generating 2 employees each
99-
assert MyEmbedder.count == 5
92+
generate_data(yaml_file="examples/company.yml", parent_application=MyEmbedder())
93+
# called 5 times, after generating 2 employees each
94+
assert MyEmbedder.count == 5
10095

10196
def test_embedding__cannot_infer_output_format(self):
10297
with pytest.raises(exc.DataGenError, match="No format"):

0 commit comments

Comments
 (0)