Skip to content

Commit de9ccb9

Browse files
committed
add support for configuring CSV fields during import
1 parent 5d18f19 commit de9ccb9

1 file changed

Lines changed: 107 additions & 22 deletions

File tree

plugins/io/__init__.py

Lines changed: 107 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
|
77
"""
88
import base64
9+
import csv
910
import contextlib
11+
import io
1012
import multiprocessing.dummy
1113
import os
1214
from packaging.version import Version
@@ -491,9 +493,6 @@ def _import_media_and_labels_inputs(ctx, inputs):
491493
),
492494
view=file_explorer,
493495
)
494-
data_path = _parse_path(ctx, "data_path")
495-
if data_path is None:
496-
return False
497496

498497
labels_path_type = _get_labels_path_type(dataset_type)
499498

@@ -537,7 +536,11 @@ def _import_media_and_labels_inputs(ctx, inputs):
537536
prop.error_message = f"Please provide a {ext} path"
538537
return False
539538

540-
_add_label_types(ctx, inputs, dataset_type)
539+
data_path = _parse_path(ctx, "data_path")
540+
if data_path is None:
541+
return False
542+
543+
_add_importer_extras(ctx, inputs, dataset_type)
541544

542545
inputs.bool(
543546
"dynamic",
@@ -694,34 +697,92 @@ def _import_labels_only_inputs(ctx, inputs):
694697
if dataset_dir is None:
695698
return False
696699

697-
_add_label_types(ctx, inputs, dataset_type)
700+
_add_importer_extras(ctx, inputs, dataset_type)
698701

699702
# Don't allow delegation when uploading files
700703
return tab != "UPLOAD"
701704

702705

703-
def _add_label_types(ctx, inputs, dataset_type):
704-
supported_types = _get_dataset_type(dataset_type).get("label_types", None)
706+
def _add_importer_extras(ctx, inputs, dataset_type):
707+
d = _get_dataset_type(dataset_type)
708+
dataset_type = d["dataset_type"]
709+
supported_types = d.get("label_types", None)
705710

706-
if supported_types is None or len(supported_types) <= 1:
707-
return
711+
if dataset_type == fot.CSVDataset:
712+
dataset_dir = _parse_path(ctx, "dataset_dir")
713+
labels_path = _parse_path(ctx, "labels_path")
714+
_, labels_bytes = _parse_file(ctx, "labels_file")
715+
716+
if dataset_dir is not None:
717+
if labels_path is not None:
718+
labels_path = fos.join(dataset_dir, labels_path)
719+
else:
720+
labels_path = fos.join(dataset_dir, "labels.csv")
721+
722+
if labels_path is not None:
723+
_get_csv_import_fields(ctx, inputs, csv_path=labels_path)
724+
725+
if labels_bytes is not None:
726+
_get_csv_import_fields(ctx, inputs, csv_bytes=labels_bytes)
727+
728+
if supported_types is not None and len(supported_types) > 1:
729+
label_type_choices = types.DropdownView(multiple=True)
730+
for label_type in supported_types:
731+
label_type_choices.add_choice(label_type, label=label_type)
708732

709-
label_type_choices = types.DropdownView(multiple=True)
710-
for label_type in supported_types:
711-
label_type_choices.add_choice(label_type, label=label_type)
733+
inputs.list(
734+
"label_types",
735+
types.String(),
736+
default=None,
737+
label="Label types",
738+
description=(
739+
"The label type(s) to load. By default, all label types are "
740+
"loaded"
741+
),
742+
view=label_type_choices,
743+
)
744+
745+
746+
def _get_csv_import_fields(ctx, inputs, csv_path=None, csv_bytes=None):
747+
fieldnames = _get_csv_fieldnames(csv_path=csv_path, csv_bytes=csv_bytes)
748+
749+
field_choices = types.DropdownView(multiple=True)
750+
for field in fieldnames:
751+
field_choices.add_choice(field, label=field)
712752

713753
inputs.list(
714-
"label_types",
754+
"csv_fields",
715755
types.String(),
756+
required=False,
716757
default=None,
717-
label="Label types",
758+
label="Fields",
759+
description="An optional subset of column(s) to import",
760+
view=field_choices,
761+
)
762+
763+
inputs.str(
764+
"media_field",
765+
required=True,
766+
default="filepath" if "filepath" in fieldnames else None,
767+
label="Media field",
718768
description=(
719-
"The label type(s) to load. By default, all label types are loaded"
769+
"The name of the column containing the media path for each column"
720770
),
721-
view=label_type_choices,
771+
view=field_choices,
722772
)
723773

724774

775+
def _get_csv_fieldnames(csv_path=None, csv_bytes=None):
776+
if csv_path is not None:
777+
f = fos.open_file(csv_path, "r")
778+
else:
779+
f = io.StringIO(csv_bytes.decode("utf-8"))
780+
781+
with f:
782+
reader = csv.DictReader(f)
783+
return list(reader.fieldnames)
784+
785+
725786
def _upload_media_inputs(ctx, inputs):
726787
style = ctx.params.get("style", None)
727788

@@ -776,11 +837,9 @@ def _upload_media_inputs(ctx, inputs):
776837

777838

778839
def _upload_media_bytes(ctx):
779-
media_obj = ctx.params["media_file"]
840+
filename, content = _parse_file(ctx, "media_file")
780841
upload_dir = _parse_path(ctx, "upload_dir")
781842
overwrite = ctx.params["overwrite"]
782-
filename = media_obj["name"]
783-
content = base64.b64decode(media_obj["content"])
784843

785844
if overwrite:
786845
outpath = fos.join(upload_dir, filename)
@@ -864,6 +923,15 @@ def _import_media_and_labels(ctx):
864923
if label_types is not None:
865924
kwargs["label_types"] = label_types
866925

926+
if dataset_type == fot.CSVDataset:
927+
csv_fields = ctx.params.get("csv_fields", None)
928+
if csv_fields is not None:
929+
kwargs["fields"] = csv_fields
930+
931+
media_field = ctx.params.get("media_field", None)
932+
if media_field is not None:
933+
kwargs["media_field"] = media_field
934+
867935
# @todo can remove version check if we require `fiftyone>=1.6.0`
868936
if ctx.delegated and Version(foc.VERSION) >= Version("1.6.0"):
869937
progress = lambda pb: ctx.set_progress(progress=pb.progress)
@@ -885,9 +953,7 @@ def _import_media_and_labels(ctx):
885953

886954

887955
def _upload_labels_bytes(ctx, tmp_dir):
888-
labels_obj = ctx.params["labels_file"]
889-
filename = labels_obj["name"]
890-
content = base64.b64decode(labels_obj["content"])
956+
filename, content = _parse_file(ctx, "labels_file")
891957

892958
outpath = fos.join(tmp_dir, filename)
893959
fos.write_file(content, outpath)
@@ -909,6 +975,15 @@ def _import_labels_only(ctx):
909975
if label_types is not None:
910976
kwargs["label_types"] = label_types
911977

978+
if dataset_type == fot.CSVDataset:
979+
csv_fields = ctx.params.get("csv_fields", None)
980+
if csv_fields is not None:
981+
kwargs["fields"] = csv_fields
982+
983+
media_field = ctx.params.get("media_field", None)
984+
if media_field is not None:
985+
kwargs["media_field"] = media_field
986+
912987
# @todo can remove version check if we require `fiftyone>=1.6.0`
913988
if ctx.delegated and Version(foc.VERSION) >= Version("1.6.0"):
914989
progress = lambda pb: ctx.set_progress(progress=pb.progress)
@@ -2820,6 +2895,16 @@ def _to_path(value):
28202895
return {"absolute_path": value}
28212896

28222897

2898+
def _parse_file(ctx, key):
2899+
file_obj = ctx.params.get(key, None)
2900+
if file_obj is None:
2901+
return None, None
2902+
2903+
filename = file_obj["name"]
2904+
content = base64.b64decode(file_obj["content"])
2905+
return filename, content
2906+
2907+
28232908
def _to_list(value):
28242909
if value is None:
28252910
return None

0 commit comments

Comments
 (0)