Skip to content

Commit 6c24999

Browse files
committed
fix: close zst tar resources safely
1 parent 5a30fdf commit 6c24999

1 file changed

Lines changed: 31 additions & 19 deletions

File tree

biothings/utils/common.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,17 @@
1919
import os.path
2020
import pickle
2121
import random
22+
import shutil
2223
import string
2324
import sys
2425
import tarfile
26+
import tempfile
2527
import time
2628
import types
2729
import urllib.parse
2830
import warnings
2931
from collections import UserDict, UserList
30-
from contextlib import contextmanager
32+
from contextlib import closing, contextmanager
3133
from datetime import date, datetime, timezone
3234
from functools import partial
3335
from itertools import islice
@@ -188,22 +190,35 @@ def anyfile(infile, mode="r"):
188190
# tarfile handling. works for zst in Python >= 3.14
189191
if lower_version_zst or tarfile.is_tarfile(infile):
190192
if lower_version_zst:
191-
f = open(infile, "rb")
192-
dctx = zstd.ZstdDecompressor()
193-
reader = dctx.stream_reader(f)
194-
tar_file = tarfile.open(fileobj=reader, mode="r|") # streaming mode
195-
else:
196-
tar_file = tarfile.open(infile, mode)
197-
198-
extracted = None
193+
with open(infile, "rb") as compressed_file:
194+
dctx = zstd.ZstdDecompressor()
195+
with closing(dctx.stream_reader(compressed_file)) as reader:
196+
with tarfile.open(fileobj=reader, mode="r|") as tar_file:
197+
for member in tar_file:
198+
if member.name == rawfile:
199+
extracted = tar_file.extractfile(member)
200+
break
201+
else:
202+
extracted = None
203+
204+
# Keep the returned file readable after closing the tar and zst streams.
205+
if extracted is not None:
206+
with extracted:
207+
spooled_file = tempfile.SpooledTemporaryFile( # pylint: disable=consider-using-with
208+
max_size=1024 * 1024
209+
)
210+
shutil.copyfileobj(extracted, spooled_file)
211+
spooled_file.seek(0)
212+
213+
# extracted member is not a regular file or link
214+
if extracted is None:
215+
raise Exception("invalid target file: must be a regular file or a link")
216+
217+
return spooled_file
218+
219+
tar_file = tarfile.open(infile, mode) # pylint: disable=consider-using-with
199220
try:
200-
if lower_version_zst:
201-
for member in tar_file:
202-
if member.name == rawfile:
203-
extracted = tar_file.extractfile(member)
204-
break
205-
else:
206-
extracted = tar_file.extractfile(rawfile)
221+
extracted = tar_file.extractfile(rawfile)
207222
except KeyError:
208223
# provided rawfile does not appear in the tarball
209224
tar_file.close()
@@ -214,9 +229,6 @@ def anyfile(infile, mode="r"):
214229
tar_file.close()
215230
raise Exception("invalid target file: must be a regular file or a link")
216231

217-
if lower_version_zst:
218-
return extracted
219-
220232
return io.TextIOWrapper(extracted)
221233

222234
if filetype == ".gz":

0 commit comments

Comments
 (0)