Skip to content

Commit 31806a0

Browse files
authored
Merge pull request #1383 from eyumboo/dev-postgresql
Add the implementation for downloading mnist
2 parents c3f0fb5 + e5fdbd9 commit 31806a0

1 file changed

Lines changed: 59 additions & 0 deletions

File tree

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
import argparse
21+
import os
22+
import urllib.request
23+
24+
25+
def check_exist_or_download(url, download_dir):
26+
download_dir = os.path.abspath(download_dir)
27+
os.makedirs(download_dir, exist_ok=True)
28+
29+
name = url.rsplit('/', 1)[-1]
30+
filename = os.path.join(download_dir, name)
31+
32+
if not os.path.isfile(filename):
33+
print("Downloading %s -> %s" % (url, filename))
34+
urllib.request.urlretrieve(url, filename)
35+
else:
36+
print("Already Downloaded: %s -> %s" % (url, filename))
37+
38+
if __name__ == '__main__':
39+
parser = argparse.ArgumentParser(
40+
description='Download the MNIST dataset.'
41+
)
42+
parser.add_argument(
43+
'-dir',
44+
'--dir-path',
45+
dest='dir_path',
46+
default='/tmp/mnist',
47+
help='Directory to save the MNIST dataset.'
48+
)
49+
args = parser.parse_args()
50+
51+
train_x_url = 'https://github.com/fgnt/mnist/raw/master/train-images-idx3-ubyte.gz'
52+
train_y_url = 'https://github.com/fgnt/mnist/raw/master/train-labels-idx1-ubyte.gz'
53+
valid_x_url = 'https://github.com/fgnt/mnist/raw/master/t10k-images-idx3-ubyte.gz'
54+
valid_y_url = 'https://github.com/fgnt/mnist/raw/master/t10k-labels-idx1-ubyte.gz'
55+
56+
check_exist_or_download(train_x_url, args.dir_path)
57+
check_exist_or_download(train_y_url, args.dir_path)
58+
check_exist_or_download(valid_x_url, args.dir_path)
59+
check_exist_or_download(valid_y_url, args.dir_path)

0 commit comments

Comments
 (0)