Skip to content

Commit db88be7

Browse files
authored
Support MNIST download directory argument
1 parent 3c00259 commit db88be7

1 file changed

Lines changed: 21 additions & 11 deletions

File tree

examples/singa_peft/examples/data/download_mnist.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,33 +17,43 @@
1717
# under the License.
1818
#
1919

20+
import argparse
2021
import os
2122
import urllib.request
2223

2324

24-
def check_exist_or_download(url):
25+
def check_exist_or_download(url, download_dir):
26+
os.makedirs(download_dir, exist_ok=True)
2527

26-
download_dir = '/tmp/' # downloaded to the /tmp/ folder
2728
name = url.rsplit('/', 1)[-1]
2829
filename = os.path.join(download_dir, name)
2930

3031
if not os.path.isfile(filename):
31-
print("Downloading %s" % url)
32+
print("Downloading %s to %s" % (url, filename))
3233
urllib.request.urlretrieve(url, filename)
3334
else:
34-
print("Already Downloaded: %s" % url)
35+
print("Already Downloaded: %s" % filename)
3536

3637

3738
if __name__ == '__main__':
39+
parser = argparse.ArgumentParser(
40+
description='Download the MNIST dataset.'
41+
)
42+
parser.add_argument(
43+
'-dir',
44+
'--dir-path',
45+
dest='dir_path',
46+
default='/tmp/mnist',
47+
help='Directory to save the MNIST dataset.'
48+
)
49+
args = parser.parse_args()
3850

39-
# List urls of the mnist dataset
4051
train_x_url = 'https://github.com/fgnt/mnist/raw/master/train-images-idx3-ubyte.gz'
4152
train_y_url = 'https://github.com/fgnt/mnist/raw/master/train-labels-idx1-ubyte.gz'
4253
valid_x_url = 'https://github.com/fgnt/mnist/raw/master/t10k-images-idx3-ubyte.gz'
4354
valid_y_url = 'https://github.com/fgnt/mnist/raw/master/t10k-labels-idx1-ubyte.gz'
44-
45-
# Download the mnist dataset
46-
check_exist_or_download(train_x_url)
47-
check_exist_or_download(train_y_url)
48-
check_exist_or_download(valid_x_url)
49-
check_exist_or_download(valid_y_url)
55+
56+
check_exist_or_download(train_x_url, args.dir_path)
57+
check_exist_or_download(train_y_url, args.dir_path)
58+
check_exist_or_download(valid_x_url, args.dir_path)
59+
check_exist_or_download(valid_y_url, args.dir_path)

0 commit comments

Comments
 (0)