forked from awsdocs/aws-doc-sdk-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcustom_labels_csv_to_manifest.py
More file actions
207 lines (164 loc) · 6.81 KB
/
custom_labels_csv_to_manifest.py
File metadata and controls
207 lines (164 loc) · 6.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# snippet-start:[rekognition.python.custom_labels_csv_manifest.complete]
from datetime import datetime, timezone
import argparse
import logging
import csv
import os
import json
"""
Purpose
Amazon Rekognition Custom Labels model example used in the service documentation.
Shows how to create an image-level (classification) manifest file from a CSV file.
You can specify multiple image level labels per image.
CSV file format is
image,label,label,..
If necessary, use the bucket argument to specify the S3 bucket folder for the images.
https://docs.aws.amazon.com/rekognition/latest/customlabels-dg/md-gt-cl-transform.html
"""
# Module-level logger used by the functions in this script; main() sets the
# logging level and message format via logging.basicConfig.
logger = logging.getLogger(__name__)
def check_duplicates(csv_file, deduplicated_file, duplicates_file):
    """
    Checks for duplicate images in a CSV file. Two rows are duplicates when
    their first column (the image) is identical. If duplicate images are found,
    deduplicated_file is the deduplicated CSV file - only the first
    occurrence of a duplicate is recorded. Other duplicates are recorded in
    duplicates_file. Blank rows in the source file are ignored.
    :param csv_file: The source CSV file.
    :param deduplicated_file: The deduplicated CSV file to create. If no duplicates
    are found this file is removed.
    :param duplicates_file: The duplicate images CSV file to create. If no duplicates
    are found this file is removed.
    :return: True if duplicates are found, otherwise False.
    """
    logger.info("Deduplicating %s", csv_file)

    duplicates_found = False

    # Every file handed to the csv module is opened with newline="", as the csv
    # documentation requires. Without it, csv.writer output gains an extra "\r"
    # per row on platforms that use "\r\n" line endings.
    with open(csv_file, "r", newline="", encoding="UTF-8") as f, open(
        deduplicated_file, "w", newline="", encoding="UTF-8"
    ) as dedup, open(
        duplicates_file, "w", newline="", encoding="UTF-8"
    ) as duplicates:
        reader = csv.reader(f, delimiter=",")
        dedup_writer = csv.writer(dedup)
        duplicates_writer = csv.writer(duplicates)

        # Image names (column 1) that have already been written to dedup.
        seen_images = set()

        for row in reader:
            # Skip empty lines (no columns, or only empty/whitespace columns).
            if not "".join(row).strip():
                continue

            image = row[0]
            if image in seen_images:
                # A later occurrence of an image that was already recorded.
                duplicates_writer.writerow(row)
                duplicates_found = True
            else:
                dedup_writer.writerow(row)
                seen_images.add(image)

    if duplicates_found:
        logger.info("Duplicates found check %s", duplicates_file)
    else:
        # Nothing to report, so don't leave the helper files behind.
        os.remove(duplicates_file)
        os.remove(deduplicated_file)

    return duplicates_found
def create_manifest_file(csv_file, manifest_file, s3_path):
    """
    Reads a CSV file and creates a Custom Labels classification manifest file.
    Each non-blank CSV row produces one JSON Line in the manifest. Column 1 is
    the image, and each remaining non-empty column is an image-level label.
    :param csv_file: The source CSV file.
    :param manifest_file: The name of the manifest file to create.
    :param s3_path: The S3 path to the folder that contains the images. It is
    prepended, unchanged, to the value in column 1.
    :return: A tuple of (image_count, label_count) written to the manifest.
    """
    logger.info("Processing CSV file %s", csv_file)

    image_count = 0
    label_count = 0

    # The with statement closes both files, so no explicit close() is needed.
    with open(csv_file, newline="", encoding="UTF-8") as csvfile, open(
        manifest_file, "w", encoding="UTF-8"
    ) as output_file:
        image_classifications = csv.reader(csvfile, delimiter=",", quotechar="|")

        # Process each row (image) in CSV file.
        for row in image_classifications:
            # Skip empty lines, using the same rule as check_duplicates.
            # Without this, a blank line raises IndexError on row[0].
            if not "".join(row).strip():
                continue

            image_count += 1

            # Create JSON for image source ref.
            json_line = {"source-ref": str(s3_path) + row[0]}

            # Process each image level label.
            for image_level_label in row[1:]:
                # Skip empty columns.
                if image_level_label == "":
                    continue

                label_count += 1

                # Create the JSON line metadata.
                json_line[image_level_label] = 1
                json_line[f"{image_level_label}-metadata"] = {
                    "confidence": 1,
                    "job-name": "labeling-job/" + image_level_label,
                    "class-name": image_level_label,
                    "human-annotated": "yes",
                    "creation-date": datetime.now(timezone.utc).strftime(
                        "%Y-%m-%dT%H:%M:%S.%f"
                    ),
                    "type": "groundtruth/image-classification",
                }

            # Write the image JSON Line.
            output_file.write(json.dumps(json_line))
            output_file.write("\n")

    logger.info(
        "Finished creating manifest file %s\nImages: %s\nLabels: %s",
        manifest_file,
        image_count,
        label_count,
    )

    return image_count, label_count
def add_arguments(parser):
    """
    Registers this script's command line arguments on the supplied parser:
    the positional csv_file and the optional --s3_path.
    :param parser: The command line parser to add the arguments to.
    """
    s3_path_help = (
        "The S3 bucket and folder path for the images."
        " If not supplied, column 1 is assumed to include the S3 path."
    )

    # Positional: the CSV file to convert.
    parser.add_argument("csv_file", help="The CSV file that you want to process.")

    # Optional: prefix for the image names held in column 1.
    parser.add_argument("--s3_path", required=False, help=s3_path_help)
def main():
    """
    Script entry point. Parses the command line, checks the CSV file for
    duplicate images and, when there are none, creates the manifest file
    next to the CSV file. A missing input file is logged and reported
    rather than raised.
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    try:
        # Read the command line.
        parser = argparse.ArgumentParser(usage=argparse.SUPPRESS)
        add_arguments(parser)
        options = parser.parse_args()

        # No --s3_path means column 1 already carries the full image path.
        image_prefix = "" if options.s3_path is None else options.s3_path

        # Derive the output file names from the CSV file name (minus extension).
        source_csv = options.csv_file
        stem, _ = os.path.splitext(source_csv)
        manifest_file = f"{stem}.manifest"
        duplicates_file = f"{stem}-duplicates.csv"
        deduplicated_file = f"{stem}-deduplicated.csv"

        has_duplicates = check_duplicates(
            source_csv, deduplicated_file, duplicates_file
        )

        if not has_duplicates:
            # Clean input: go ahead and build the manifest.
            print("No duplicates found. Creating manifest file.")
            image_count, label_count = create_manifest_file(
                source_csv, manifest_file, image_prefix
            )
            print(
                f"Finished creating manifest file: {manifest_file} \n"
                f"Images: {image_count}\nLabels: {label_count}"
            )
            return

        # Duplicates: tell the user how to resolve them; no manifest is created.
        print(
            f"Duplicates found. Use {duplicates_file} to view duplicates "
            f"and then update {deduplicated_file}. "
        )
        print(
            f"{deduplicated_file} contains the first occurence of a duplicate. "
            "Update as necessary with the correct label information."
        )
        print(f"Re-run the script with {deduplicated_file}")

    except FileNotFoundError as err:
        logger.exception("File not found: %s", err)
        print(f"File not found: {err}. Check your input CSV file.")
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()
# snippet-end:[rekognition.python.custom_labels_csv_manifest.complete]