Skip to content

Commit 43f0124

Browse files
committed
init
1 parent e0674fa commit 43f0124

3 files changed

Lines changed: 297 additions & 0 deletions

File tree

prepare-data.ipynb

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"!pip install ibm-cos-sdk"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": null,
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"import os\n",
19+
"import shutil\n",
20+
"import json\n",
21+
"import uuid\n",
22+
"\n",
23+
"import ibm_boto3"
24+
]
25+
},
26+
{
27+
"cell_type": "code",
28+
"execution_count": null,
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"bucket = os.getenv(\"BUCKET\", \"\")\n",
33+
"access_key_id = os.getenv(\"ACCESS_KEY_ID\", \"\")\n",
34+
"secret_access_key = os.getenv(\"SECRET_ACCESS_KEY\", \"\")\n",
35+
"endpoint_url = os.getenv(\"ENDPOINT_URL\", \"\")"
36+
]
37+
},
38+
{
39+
"cell_type": "code",
40+
"execution_count": null,
41+
"metadata": {},
42+
"outputs": [],
43+
"source": [
44+
"cos = ibm_boto3.resource(\"s3\",\n",
45+
" aws_access_key_id=access_key_id,\n",
46+
" aws_secret_access_key=secret_access_key,\n",
47+
" endpoint_url=endpoint_url\n",
48+
")\n",
49+
"\n",
50+
"# load the annotations\n",
51+
"try:\n",
52+
" annotations = json.loads(cos.Object(bucket, \"_annotations.json\").get()[\"Body\"].read())[\"annotations\"]\n",
53+
"except Exception as e:\n",
54+
" print(\"Unable to retrieve annotations: {}\".format(e))"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": null,
60+
"metadata": {},
61+
"outputs": [],
62+
"source": [
63+
"data_dir = \"data\"\n",
64+
"os.makedirs(data_dir)\n",
65+
"\n",
66+
"# create a set of labels and then turn it into a list to remove duplicates\n",
67+
"labels = list({annotation[\"label\"] for image in annotations.values() for annotation in image})\n",
68+
"\n",
69+
"for label in labels:\n",
70+
" # find a list of images with the given label\n",
71+
" image_list = [image_name for image_name in annotations.keys() for annotation in annotations[image_name] if annotation[\"label\"] == label]\n",
72+
"\n",
73+
" # make directory for the label to store images in\n",
74+
" train_label_dir = os.path.join(data_dir, label)\n",
75+
" os.makedirs(train_label_dir)\n",
76+
"\n",
77+
"    # download images into their label folder\n",
78+
" for im in image_list:\n",
79+
" try:\n",
80+
" extension = os.path.splitext(im)[1]\n",
81+
" cos.meta.client.download_file(bucket, im, os.path.join(train_label_dir, str(uuid.uuid4()) + extension))\n",
82+
" except Exception as e:\n",
83+
" print(\"Error: {}, skipping {}...\".format(e, im))"
84+
]
85+
}
86+
],
87+
"metadata": {
88+
"kernelspec": {
89+
"display_name": "Python 3",
90+
"language": "python",
91+
"name": "python3"
92+
},
93+
"language_info": {
94+
"codemirror_mode": {
95+
"name": "ipython",
96+
"version": 3
97+
},
98+
"file_extension": ".py",
99+
"mimetype": "text/x-python",
100+
"name": "python",
101+
"nbconvert_exporter": "python",
102+
"pygments_lexer": "ipython3",
103+
"version": "3.7.9"
104+
}
105+
},
106+
"nbformat": 4,
107+
"nbformat_minor": 4
108+
}

train.ipynb

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"%%bash\n",
10+
"\n",
11+
"pip install -U \"tensorflow~=2.0\" \"tensorflow-hub[make_image_classifier]~=0.6\"\n",
12+
"\n",
13+
"make_image_classifier \\\n",
14+
" --image_dir data \\\n",
15+
" --tfhub_module https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4 \\\n",
16+
" --image_size 224 \\\n",
17+
" --saved_model_dir model \\\n",
18+
" --labels_output_file labels.txt \\\n",
19+
" --train_epochs=5 \\\n",
20+
" --do_fine_tuning \\\n",
21+
" --batch_size=32 \\\n",
22+
" --learning_rate=0.005 \\\n",
23+
" --momentum=0.9 \\\n",
24+
" --dropout_rate=0.2 \\\n",
25+
" --l1_regularizer=0.0 \\\n",
26+
" --l2_regularizer=0.0001 \\\n",
27+
" --label_smoothing=0.1 \\\n",
28+
" --validation_split=0.2\n",
29+
"\n",
30+
"mv labels.txt model"
31+
]
32+
}
33+
],
34+
"metadata": {
35+
"kernelspec": {
36+
"display_name": "Python 3",
37+
"language": "python",
38+
"name": "python3"
39+
},
40+
"language_info": {
41+
"codemirror_mode": {
42+
"name": "ipython",
43+
"version": 3
44+
},
45+
"file_extension": ".py",
46+
"mimetype": "text/x-python",
47+
"name": "python",
48+
"nbconvert_exporter": "python",
49+
"pygments_lexer": "ipython3",
50+
"version": "3.7.9"
51+
}
52+
},
53+
"nbformat": 4,
54+
"nbformat_minor": 4
55+
}

train.pipeline

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
{
2+
"doc_type": "pipeline",
3+
"version": "3.0",
4+
"json_schema": "http://api.dataplatform.ibm.com/schemas/common-pipeline/pipeline-flow/pipeline-flow-v3-schema.json",
5+
"id": "a4ea898f-d773-4f64-850e-861504960b10",
6+
"primary_pipeline": "4851b63f-721c-4d12-b17b-3ccdca7a292f",
7+
"pipelines": [
8+
{
9+
"id": "4851b63f-721c-4d12-b17b-3ccdca7a292f",
10+
"nodes": [
11+
{
12+
"id": "85c9f13e-334a-4d9a-be63-5d7bb0a211cf",
13+
"type": "execution_node",
14+
"op": "execute-notebook-node",
15+
"app_data": {
16+
"filename": "prepare-data.ipynb",
17+
"runtime_image": "tensorflow/tensorflow:2.3.0",
18+
"env_vars": [
19+
"BUCKET=",
20+
"ACCESS_KEY_ID=",
21+
"SECRET_ACCESS_KEY=",
22+
"ENDPOINT_URL="
23+
],
24+
"include_subdirectories": false,
25+
"outputs": [
26+
"data"
27+
],
28+
"invalidNodeError": null,
29+
"ui_data": {
30+
"label": "prepare-data.ipynb",
31+
"image": "data:image/svg+xml;utf8,%3Csvg%20xmlns%3D%22http%3A%2F%2Fwww.w3.org%2F2000%2Fsvg%22%20width%3D%2216%22%20viewBox%3D%220%200%2022%2022%22%3E%0A%20%20%3Cg%20class%3D%22jp-icon-warn0%20jp-icon-selectable%22%20fill%3D%22%23EF6C00%22%3E%0A%20%20%20%20%3Cpath%20d%3D%22M18.7%203.3v15.4H3.3V3.3h15.4m1.5-1.5H1.8v18.3h18.3l.1-18.3z%22%2F%3E%0A%20%20%20%20%3Cpath%20d%3D%22M16.5%2016.5l-5.4-4.3-5.6%204.3v-11h11z%22%2F%3E%0A%20%20%3C%2Fg%3E%0A%3C%2Fsvg%3E%0A",
32+
"x_pos": 278,
33+
"y_pos": 249,
34+
"description": "Notebook file"
35+
}
36+
},
37+
"inputs": [
38+
{
39+
"id": "inPort",
40+
"app_data": {
41+
"ui_data": {
42+
"cardinality": {
43+
"min": 0,
44+
"max": -1
45+
},
46+
"label": "Input Port"
47+
}
48+
}
49+
}
50+
],
51+
"outputs": [
52+
{
53+
"id": "outPort",
54+
"app_data": {
55+
"ui_data": {
56+
"cardinality": {
57+
"min": 0,
58+
"max": -1
59+
},
60+
"label": "Output Port"
61+
}
62+
}
63+
}
64+
]
65+
},
66+
{
67+
"id": "14d59e0b-7262-4f11-b07c-de833edb9d00",
68+
"type": "execution_node",
69+
"op": "execute-notebook-node",
70+
"app_data": {
71+
"filename": "train.ipynb",
72+
"runtime_image": "tensorflow/tensorflow:2.3.0",
73+
"env_vars": [],
74+
"include_subdirectories": false,
75+
"invalidNodeError": null,
76+
"outputs": [
77+
"model"
78+
],
79+
"ui_data": {
80+
"label": "train.ipynb",
81+
"image": "data:image/svg+xml;utf8,%3Csvg%20xmlns%3D%22http%3A%2F%2Fwww.w3.org%2F2000%2Fsvg%22%20width%3D%2216%22%20viewBox%3D%220%200%2022%2022%22%3E%0A%20%20%3Cg%20class%3D%22jp-icon-warn0%20jp-icon-selectable%22%20fill%3D%22%23EF6C00%22%3E%0A%20%20%20%20%3Cpath%20d%3D%22M18.7%203.3v15.4H3.3V3.3h15.4m1.5-1.5H1.8v18.3h18.3l.1-18.3z%22%2F%3E%0A%20%20%20%20%3Cpath%20d%3D%22M16.5%2016.5l-5.4-4.3-5.6%204.3v-11h11z%22%2F%3E%0A%20%20%3C%2Fg%3E%0A%3C%2Fsvg%3E%0A",
82+
"x_pos": 520,
83+
"y_pos": 259,
84+
"description": "Notebook file"
85+
}
86+
},
87+
"inputs": [
88+
{
89+
"id": "inPort",
90+
"app_data": {
91+
"ui_data": {
92+
"cardinality": {
93+
"min": 0,
94+
"max": -1
95+
},
96+
"label": "Input Port"
97+
}
98+
},
99+
"links": [
100+
{
101+
"id": "5eb1d54e-de34-4d1a-8e82-54f15c3800c2",
102+
"node_id_ref": "85c9f13e-334a-4d9a-be63-5d7bb0a211cf",
103+
"port_id_ref": "outPort"
104+
}
105+
]
106+
}
107+
],
108+
"outputs": [
109+
{
110+
"id": "outPort",
111+
"app_data": {
112+
"ui_data": {
113+
"cardinality": {
114+
"min": 0,
115+
"max": -1
116+
},
117+
"label": "Output Port"
118+
}
119+
}
120+
}
121+
]
122+
}
123+
],
124+
"app_data": {
125+
"ui_data": {
126+
"comments": []
127+
},
128+
"version": 3
129+
},
130+
"runtime_ref": ""
131+
}
132+
],
133+
"schemas": []
134+
}

0 commit comments

Comments
 (0)