Skip to content

Commit b8bd9dc

Browse files
authored
Modify ingest script to check a local directory (#443)
* Modify ingest script to check a local directory
* Check data_dir existence instead of assert
1 parent 735da19 commit b8bd9dc

1 file changed

Lines changed: 43 additions & 8 deletions

File tree

bats_ai/core/management/commands/load_public_dataset.py

Lines changed: 43 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,7 @@ def _ingest_files_from_manifest(
116116
offset: int | None,
117117
file_key: str = "file_key",
118118
tag_keys: list[str] | None = None,
119+
data_dir: Path | None = None,
119120
):
120121
if tag_keys is None:
121122
tag_keys = []
@@ -137,6 +138,7 @@ def _ingest_files_from_manifest(
137138
filename = None
138139

139140
try:
141+
local = False
140142
s3_key = line[file_key]
141143
existing_recording = Recording.objects.filter(name=s3_key).first()
142144
if existing_recording:
@@ -146,12 +148,31 @@ def _ingest_files_from_manifest(
146148
logger.info("Ingesting %s...", s3_key)
147149
object_exists = _try_head_s3_object(s3_client, bucket, s3_key)
148150
if not object_exists:
149-
logger.warning("Could not HEAD object with key %s. Skipping...", s3_key)
150-
continue
151-
filename = _create_filename(s3_key)
152-
logger.info("Downloading to temporary file %s...", filename)
153-
s3_client.download_file(bucket, s3_key, filename)
154-
logger.info("Creating recording for %s", s3_key)
151+
if not data_dir:
152+
logger.warning("Could not HEAD object with key %s. Skipping...", s3_key)
153+
continue
154+
else:
155+
logger.info(
156+
"Could not HEAD object with key %s. Checking local directory %s",
157+
s3_key,
158+
data_dir,
159+
)
160+
local = True
161+
if not local:
162+
filename = _create_filename(s3_key)
163+
logger.info("Downloading to temporary file %s...", filename)
164+
s3_client.download_file(bucket, s3_key, filename)
165+
logger.info("Creating recording for %s", s3_key)
166+
else:
167+
if not data_dir:
168+
logger.warning("No local data directory specified. Skipping...")
169+
continue
170+
filename = str(data_dir / s3_key)
171+
if Path(filename).exists():
172+
logger.info("Found local file match for %s.", s3_key)
173+
else:
174+
logger.warning("Could not find a local match for %s, skipping...", s3_key)
175+
continue
155176
metadata = _get_metadata(filename, line)
156177
with open(filename, "rb") as f:
157178
recording = Recording.objects.create(
@@ -188,7 +209,7 @@ def _ingest_files_from_manifest(
188209
)
189210
recording_compute_spectrogram.delay(recording.pk)
190211
finally:
191-
if filename:
212+
if not local and filename:
192213
# Delete the file (this may run on a machine with limited resources)
193214
try:
194215
logger.info("Cleaning up by removing temporary file %s...", filename)
@@ -198,7 +219,7 @@ def _ingest_files_from_manifest(
198219

199220

200221
class Command(BaseCommand):
201-
help = "Create recordings and spectrograms from WAV files in a public s3 bucket"
222+
help = "Ingest recordings from local filesystem and public s3 according to a manifest file."
202223

203224
def add_arguments(self, parser):
204225
parser.add_argument(
@@ -212,6 +233,9 @@ def add_arguments(self, parser):
212233
# Assume columns "Key" and "Tags"
213234
help="Manifest CSV file with file keys and tags",
214235
)
236+
parser.add_argument(
237+
"--data-dir", type=str, help="The directory where local recordings are located"
238+
)
215239
parser.add_argument(
216240
"--owner",
217241
type=str,
@@ -253,6 +277,16 @@ def handle(self, *args, **options):
253277
except ClientError:
254278
self.stdout.write(self.style.ERROR(f"Could not access bucket {bucket}"))
255279
return
280+
281+
data_dir = options.get("data_dir")
282+
if data_dir:
283+
data_dir = Path(data_dir)
284+
if not data_dir.exists():
285+
self.stdout.write(
286+
self.style.ERROR(f"Specified data directory {data_dir} does not exist")
287+
)
288+
return
289+
256290
manifest = Path(options["manifest"])
257291
if not manifest.exists():
258292
self.stdout.write(self.style.ERROR(f"Could not find manifest file {manifest}"))
@@ -290,4 +324,5 @@ def handle(self, *args, **options):
290324
offset=offset,
291325
file_key=file_key,
292326
tag_keys=tag_keys,
327+
data_dir=data_dir,
293328
)

0 commit comments

Comments
 (0)