66|
77"""
88import base64
9+ import csv
910import contextlib
11+ import io
1012import multiprocessing .dummy
1113import os
1214from packaging .version import Version
@@ -491,9 +493,6 @@ def _import_media_and_labels_inputs(ctx, inputs):
491493 ),
492494 view = file_explorer ,
493495 )
494- data_path = _parse_path (ctx , "data_path" )
495- if data_path is None :
496- return False
497496
498497 labels_path_type = _get_labels_path_type (dataset_type )
499498
@@ -537,7 +536,11 @@ def _import_media_and_labels_inputs(ctx, inputs):
537536 prop .error_message = f"Please provide a { ext } path"
538537 return False
539538
540- _add_label_types (ctx , inputs , dataset_type )
539+ data_path = _parse_path (ctx , "data_path" )
540+ if data_path is None :
541+ return False
542+
543+ _add_importer_extras (ctx , inputs , dataset_type )
541544
542545 inputs .bool (
543546 "dynamic" ,
@@ -694,34 +697,92 @@ def _import_labels_only_inputs(ctx, inputs):
694697 if dataset_dir is None :
695698 return False
696699
697- _add_label_types (ctx , inputs , dataset_type )
700+ _add_importer_extras (ctx , inputs , dataset_type )
698701
699702 # Don't allow delegation when uploading files
700703 return tab != "UPLOAD"
701704
702705
703- def _add_label_types (ctx , inputs , dataset_type ):
704- supported_types = _get_dataset_type (dataset_type ).get ("label_types" , None )
706+ def _add_importer_extras (ctx , inputs , dataset_type ):
707+ d = _get_dataset_type (dataset_type )
708+ dataset_type = d ["dataset_type" ]
709+ supported_types = d .get ("label_types" , None )
705710
706- if supported_types is None or len (supported_types ) <= 1 :
707- return
711+ if dataset_type == fot .CSVDataset :
712+ dataset_dir = _parse_path (ctx , "dataset_dir" )
713+ labels_path = _parse_path (ctx , "labels_path" )
714+ _ , labels_bytes = _parse_file (ctx , "labels_file" )
715+
716+ if dataset_dir is not None :
717+ if labels_path is not None :
718+ labels_path = fos .join (dataset_dir , labels_path )
719+ else :
720+ labels_path = fos .join (dataset_dir , "labels.csv" )
721+
722+ if labels_path is not None :
723+ _get_csv_import_fields (ctx , inputs , csv_path = labels_path )
724+
725+ if labels_bytes is not None :
726+ _get_csv_import_fields (ctx , inputs , csv_bytes = labels_bytes )
727+
728+ if supported_types is not None and len (supported_types ) > 1 :
729+ label_type_choices = types .DropdownView (multiple = True )
730+ for label_type in supported_types :
731+ label_type_choices .add_choice (label_type , label = label_type )
708732
709- label_type_choices = types .DropdownView (multiple = True )
710- for label_type in supported_types :
711- label_type_choices .add_choice (label_type , label = label_type )
733+ inputs .list (
734+ "label_types" ,
735+ types .String (),
736+ default = None ,
737+ label = "Label types" ,
738+ description = (
739+ "The label type(s) to load. By default, all label types are "
740+ "loaded"
741+ ),
742+ view = label_type_choices ,
743+ )
744+
745+
746+ def _get_csv_import_fields (ctx , inputs , csv_path = None , csv_bytes = None ):
747+ fieldnames = _get_csv_fieldnames (csv_path = csv_path , csv_bytes = csv_bytes )
748+
749+ field_choices = types .DropdownView (multiple = True )
750+ for field in fieldnames :
751+ field_choices .add_choice (field , label = field )
712752
713753 inputs .list (
714- "label_types " ,
754+ "csv_fields " ,
715755 types .String (),
756+ required = False ,
716757 default = None ,
717- label = "Label types" ,
758+ label = "Fields" ,
759+ description = "An optional subset of column(s) to import" ,
760+ view = field_choices ,
761+ )
762+
763+ inputs .str (
764+ "media_field" ,
765+ required = True ,
766+ default = "filepath" if "filepath" in fieldnames else None ,
767+ label = "Media field" ,
718768 description = (
719- "The label type(s) to load. By default, all label types are loaded "
769+ "The name of the column containing the media path for each column "
720770 ),
721- view = label_type_choices ,
771+ view = field_choices ,
722772 )
723773
724774
775+ def _get_csv_fieldnames (csv_path = None , csv_bytes = None ):
776+ if csv_path is not None :
777+ f = fos .open_file (csv_path , "r" )
778+ else :
779+ f = io .StringIO (csv_bytes .decode ("utf-8" ))
780+
781+ with f :
782+ reader = csv .DictReader (f )
783+ return list (reader .fieldnames )
784+
785+
725786def _upload_media_inputs (ctx , inputs ):
726787 style = ctx .params .get ("style" , None )
727788
@@ -776,11 +837,9 @@ def _upload_media_inputs(ctx, inputs):
776837
777838
778839def _upload_media_bytes (ctx ):
779- media_obj = ctx . params [ "media_file" ]
840+ filename , content = _parse_file ( ctx , "media_file" )
780841 upload_dir = _parse_path (ctx , "upload_dir" )
781842 overwrite = ctx .params ["overwrite" ]
782- filename = media_obj ["name" ]
783- content = base64 .b64decode (media_obj ["content" ])
784843
785844 if overwrite :
786845 outpath = fos .join (upload_dir , filename )
@@ -864,6 +923,15 @@ def _import_media_and_labels(ctx):
864923 if label_types is not None :
865924 kwargs ["label_types" ] = label_types
866925
926+ if dataset_type == fot .CSVDataset :
927+ csv_fields = ctx .params .get ("csv_fields" , None )
928+ if csv_fields is not None :
929+ kwargs ["fields" ] = csv_fields
930+
931+ media_field = ctx .params .get ("media_field" , None )
932+ if media_field is not None :
933+ kwargs ["media_field" ] = media_field
934+
867935 # @todo can remove version check if we require `fiftyone>=1.6.0`
868936 if ctx .delegated and Version (foc .VERSION ) >= Version ("1.6.0" ):
869937 progress = lambda pb : ctx .set_progress (progress = pb .progress )
@@ -885,9 +953,7 @@ def _import_media_and_labels(ctx):
885953
886954
887955def _upload_labels_bytes (ctx , tmp_dir ):
888- labels_obj = ctx .params ["labels_file" ]
889- filename = labels_obj ["name" ]
890- content = base64 .b64decode (labels_obj ["content" ])
956+ filename , content = _parse_file (ctx , "labels_file" )
891957
892958 outpath = fos .join (tmp_dir , filename )
893959 fos .write_file (content , outpath )
@@ -909,6 +975,15 @@ def _import_labels_only(ctx):
909975 if label_types is not None :
910976 kwargs ["label_types" ] = label_types
911977
978+ if dataset_type == fot .CSVDataset :
979+ csv_fields = ctx .params .get ("csv_fields" , None )
980+ if csv_fields is not None :
981+ kwargs ["fields" ] = csv_fields
982+
983+ media_field = ctx .params .get ("media_field" , None )
984+ if media_field is not None :
985+ kwargs ["media_field" ] = media_field
986+
912987 # @todo can remove version check if we require `fiftyone>=1.6.0`
913988 if ctx .delegated and Version (foc .VERSION ) >= Version ("1.6.0" ):
914989 progress = lambda pb : ctx .set_progress (progress = pb .progress )
@@ -2820,6 +2895,16 @@ def _to_path(value):
28202895 return {"absolute_path" : value }
28212896
28222897
2898+ def _parse_file (ctx , key ):
2899+ file_obj = ctx .params .get (key , None )
2900+ if file_obj is None :
2901+ return None , None
2902+
2903+ filename = file_obj ["name" ]
2904+ content = base64 .b64decode (file_obj ["content" ])
2905+ return filename , content
2906+
2907+
28232908def _to_list (value ):
28242909 if value is None :
28252910 return None
0 commit comments