|
17 | 17 | # under the License. |
18 | 18 | # |
19 | 19 |
|
| 20 | +import argparse |
20 | 21 | import os |
21 | 22 | import urllib.request |
22 | 23 |
|
23 | 24 |
|
24 | | -def check_exist_or_download(url): |
| 25 | +def check_exist_or_download(url, download_dir): |
| 26 | + os.makedirs(download_dir, exist_ok=True) |
25 | 27 |
|
26 | | - download_dir = '/tmp/' # downloaded to the /tmp/ folder |
27 | 28 | name = url.rsplit('/', 1)[-1] |
28 | 29 | filename = os.path.join(download_dir, name) |
29 | 30 |
|
30 | 31 | if not os.path.isfile(filename): |
31 | | - print("Downloading %s" % url) |
| 32 | + print("Downloading %s to %s" % (url, filename)) |
32 | 33 | urllib.request.urlretrieve(url, filename) |
33 | 34 | else: |
34 | | - print("Already Downloaded: %s" % url) |
| 35 | + print("Already Downloaded: %s" % filename) |
35 | 36 |
|
36 | 37 |
|
37 | 38 | if __name__ == '__main__': |
| 39 | + parser = argparse.ArgumentParser( |
| 40 | + description='Download the MNIST dataset.' |
| 41 | + ) |
| 42 | + parser.add_argument( |
| 43 | + '-dir', |
| 44 | + '--dir-path', |
| 45 | + dest='dir_path', |
| 46 | + default='/tmp/mnist', |
| 47 | + help='Directory to save the MNIST dataset.' |
| 48 | + ) |
| 49 | + args = parser.parse_args() |
38 | 50 |
|
39 | | - # List urls of the mnist dataset |
40 | 51 | train_x_url = 'https://github.com/fgnt/mnist/raw/master/train-images-idx3-ubyte.gz' |
41 | 52 | train_y_url = 'https://github.com/fgnt/mnist/raw/master/train-labels-idx1-ubyte.gz' |
42 | 53 | valid_x_url = 'https://github.com/fgnt/mnist/raw/master/t10k-images-idx3-ubyte.gz' |
43 | 54 | valid_y_url = 'https://github.com/fgnt/mnist/raw/master/t10k-labels-idx1-ubyte.gz' |
44 | | - |
45 | | - # Download the mnist dataset |
46 | | - check_exist_or_download(train_x_url) |
47 | | - check_exist_or_download(train_y_url) |
48 | | - check_exist_or_download(valid_x_url) |
49 | | - check_exist_or_download(valid_y_url) |
| 55 | + |
| 56 | + check_exist_or_download(train_x_url, args.dir_path) |
| 57 | + check_exist_or_download(train_y_url, args.dir_path) |
| 58 | + check_exist_or_download(valid_x_url, args.dir_path) |
| 59 | + check_exist_or_download(valid_y_url, args.dir_path) |
0 commit comments