77import os
88import tarfile
99import tempfile
10+ from contextlib import contextmanager
1011from pathlib import Path
11- from typing import Literal
12+ from typing import BinaryIO , Literal
1213
1314import httpx
15+ import zstandard
1416
1517from diracx .client .aio import AsyncDiracClient
1618from diracx .client .models import SandboxInfo
2022logger = logging .getLogger (__name__ )
2123
2224SANDBOX_CHECKSUM_ALGORITHM = "sha256"
23- SANDBOX_COMPRESSION : Literal ["bz2" ] = "bz2"
24- SANDBOX_OPEN_MODE : Literal ["w|bz2" ] = "w|bz2"
25+ SANDBOX_COMPRESSION : Literal ["zst" ] = "zst"
26+
27+
@contextmanager
def tarfile_open(fileobj: BinaryIO):
    """Open a tar archive for reading, adding support for zstd compression.

    ``tarfile.open`` cannot read zstd-compressed archives on Python <=3.13, so
    the first bytes of ``fileobj`` are inspected. When they match the zstd
    frame magic number the stream is decompressed with ``zstandard`` before
    being handed to ``tarfile``; otherwise ``tarfile.open`` performs its own
    format detection.

    Args:
        fileobj: Seekable binary file object positioned at the start of the
            archive. Its position is restored after the magic bytes are read.

    Yields:
        The opened ``tarfile.TarFile``. For zstd input the archive is opened
        in streaming mode (``"r|"``), so members should be consumed in order.

    Note:
        ``fileobj`` is left open on exit in both branches; closing it remains
        the caller's responsibility.
    """
    # zstd frame magic number (0xFD2FB528, little-endian on disk).
    zstd_magic = b"\x28\xb5\x2f\xfd"

    # Peek at the leading bytes without consuming them.
    start = fileobj.tell()
    header = fileobj.read(len(zstd_magic))
    fileobj.seek(start)

    if header.startswith(zstd_magic):
        dctx = zstandard.ZstdDecompressor()
        # closefd=False: this helper does not own ``fileobj``, so it must not
        # close it when the decompression reader is closed (this matches the
        # behaviour of the plain ``tarfile.open`` branch below).
        with dctx.stream_reader(fileobj, closefd=False) as reader:
            # The decompressed stream is not seekable, hence streaming mode.
            with tarfile.open(fileobj=reader, mode="r|") as tf:
                yield tf
    else:
        # Mode "r" lets tarfile transparently detect the formats it supports.
        with tarfile.open(fileobj=fileobj, mode="r") as tf:
            yield tf
2548
2649
2750@with_client
@@ -33,10 +56,18 @@ async def create_sandbox(paths: list[Path], *, client: AsyncDiracClient) -> str:
3356 be used to submit jobs.
3457 """
3558 with tempfile .TemporaryFile (mode = "w+b" ) as tar_fh :
36- with tarfile .open (fileobj = tar_fh , mode = SANDBOX_OPEN_MODE ) as tf :
37- for path in paths :
38- logger .debug ("Adding %s to sandbox as %s" , path .resolve (), path .name )
39- tf .add (path .resolve (), path .name , recursive = True )
59+ # Create zstd compressed tar with level 18 and long matching enabled
60+ compression_params = zstandard .ZstdCompressionParameters .from_level (
61+ 18 , enable_ldm = 1
62+ )
63+ cctx = zstandard .ZstdCompressor (compression_params = compression_params )
64+ with cctx .stream_writer (tar_fh , closefd = False ) as compressor :
65+ with tarfile .open (fileobj = compressor , mode = "w|" ) as tf :
66+ for path in paths :
67+ logger .debug (
68+ "Adding %s to sandbox as %s" , path .resolve (), path .name
69+ )
70+ tf .add (path .resolve (), path .name , recursive = True )
4071 tar_fh .seek (0 )
4172
4273 hasher = getattr (hashlib , SANDBOX_CHECKSUM_ALGORITHM )()
@@ -89,6 +120,6 @@ async def download_sandbox(pfn: str, destination: Path, *, client: AsyncDiracCli
89120 fh .seek (0 )
90121 logger .debug ("Sandbox downloaded for %s" , pfn )
91122
92- with tarfile . open ( fileobj = fh ) as tf :
123+ with tarfile_open ( fh ) as tf :
93124 tf .extractall (path = destination , filter = "data" )
94125 logger .debug ("Extracted %s to %s" , pfn , destination )
0 commit comments