Skip to content

Commit f6d599f

Browse files
committed
Make tools.download_datasets consistent with the runner
closes #454
1 parent 60b9ba6 commit f6d599f

File tree

3 files changed

+141
-6
lines changed

3 files changed

+141
-6
lines changed

khiops/core/internals/runner.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1032,8 +1032,9 @@ def _initialize_default_samples_dir(self):
10321032
# Set the fallback value for the samples directory
10331033
home_samples_dir = Path.home() / "khiops_data" / "samples"
10341034

1035-
# Take the value of an environment variable in priority
1036-
if "KHIOPS_SAMPLES_DIR" in os.environ:
1035+
# Take the value of an environment variable in priority, if set to
1036+
# non-empty string
1037+
if "KHIOPS_SAMPLES_DIR" in os.environ and os.environ["KHIOPS_SAMPLES_DIR"]:
10371038
self._samples_dir = os.environ["KHIOPS_SAMPLES_DIR"]
10381039

10391040
# The samples location of Windows systems is:

khiops/tools.py

Lines changed: 23 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
import argparse
1414
import os
1515
import pathlib
16+
import platform
1617
import shutil
1718
import sys
1819
import tempfile
@@ -115,7 +116,10 @@ def download_datasets(
115116
"""Downloads the Khiops sample datasets for a given version
116117
117118
The datasets are downloaded to:
118-
- Windows: ``%USERPROFILE%\\khiops_data\\samples``
119+
- Windows:
120+
- ``%PUBLIC%\\khiops_data\\samples`` if ``%PUBLIC%`` is defined and
121+
points to a directory
122+
- ``%USERPROFILE%\\khiops_data\\samples`` otherwise
119123
- Linux/macOS: ``$HOME/khiops_data/samples``
120124
121125
Parameters
@@ -128,8 +132,23 @@ def download_datasets(
128132
# Note: The hidden parameter _called_from_shell is just to change the user messages.
129133

130134
# Check if the home sample dataset location is available and build it if necessary
131-
samples_dir = pathlib.Path.home() / "khiops_data" / "samples"
132-
if samples_dir.exists() and not force_overwrite:
135+
home_samples_dir = pathlib.Path.home() / "khiops_data" / "samples"
136+
137+
# Take the value of an environment variable in priority, if set to
138+
# non-empty string
139+
# If the environment variable is not set, samples location is:
140+
# - on Windows systems:
141+
# - %PUBLIC%\khiops_data\samples if %PUBLIC% exists
142+
# - %USERPROFILE%\khiops_data\samples otherwise
143+
# - on Linux / macOS systems:
144+
# - $HOME/khiops_data/samples
145+
if "KHIOPS_SAMPLES_DIR" in os.environ and os.environ["KHIOPS_SAMPLES_DIR"]:
146+
samples_dir = os.environ["KHIOPS_SAMPLES_DIR"]
147+
elif platform.system() == "Windows" and "PUBLIC" in os.environ:
148+
samples_dir = os.path.join(os.environ["PUBLIC"], "khiops_data", "samples")
149+
else:
150+
samples_dir = str(home_samples_dir)
151+
if os.path.exists(samples_dir) and not force_overwrite:
133152
if _called_from_shell:
134153
instructions = "Execute with '--force-overwrite' to overwrite it"
135154
else:
@@ -140,7 +159,7 @@ def download_datasets(
140159
)
141160
else:
142161
# Create the samples dataset directory
143-
if samples_dir.exists():
162+
if os.path.exists(samples_dir):
144163
shutil.rmtree(samples_dir)
145164
os.makedirs(samples_dir, exist_ok=True)
146165

tests/test_khiops_integrations.py

Lines changed: 115 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717

1818
import khiops.core as kh
1919
import khiops.core.internals.filesystems as fs
20+
from khiops import tools
2021
from khiops.core.exceptions import KhiopsEnvironmentError
2122
from khiops.core.internals.runner import KhiopsLocalRunner
2223
from khiops.extras.docker import KhiopsDockerRunner
@@ -30,6 +31,120 @@
3031
class KhiopsRunnerEnvironmentTests(unittest.TestCase):
3132
"""Test that runners in different environments work"""
3233

34+
def _is_samples_dir(self, samples_dir):
35+
expected_dataset_names = [
36+
"Accidents",
37+
"AccidentsSummary",
38+
"Adult",
39+
"CustomerExtended",
40+
"Iris",
41+
"Letter",
42+
"Mushroom",
43+
"SpliceJunction",
44+
]
45+
return os.path.isdir(samples_dir) and all(
46+
ds_name in os.listdir(samples_dir) for ds_name in expected_dataset_names
47+
)
48+
49+
@unittest.skipIf(
50+
os.environ.get("SKIP_EXPENSIVE_TESTS", "false").lower() == "true",
51+
"Skipping expensive test",
52+
)
53+
def test_samples_are_downloaded_according_to_the_runner_setting(self):
54+
"""Test that samples are downloaded to the runner samples directory"""
55+
56+
# Get initial runner
57+
initial_runner = kh.get_runner()
58+
59+
# Get initial Khiops samples dir environment variable
60+
initial_khiops_samples_dir = os.environ.get("KHIOPS_SAMPLES_DIR")
61+
62+
# Test that default samples download location is consistent with the
63+
# KhiopsLocalRunner samples directory
64+
with tempfile.TemporaryDirectory() as tmp_samples_dir:
65+
66+
# Set environment variable to the temporary samples dir
67+
os.environ["KHIOPS_SAMPLES_DIR"] = tmp_samples_dir
68+
69+
# Create test runner to update samples dir to tmp_samples_dir,
70+
# according to the newly-set environment variable
71+
test_runner = KhiopsLocalRunner()
72+
73+
# Set current runner to the test runner
74+
kh.set_runner(test_runner)
75+
76+
# Check that samples are not in tmp_samples_dir
77+
self.assertFalse(self._is_samples_dir(tmp_samples_dir))
78+
79+
# Download samples into existing, but empty, tmp_samples_dir
80+
tools.download_datasets(force_overwrite=True)
81+
82+
# Check that samples have been downloaded to tmp_samples_dir
83+
self.assertTrue(self._is_samples_dir(tmp_samples_dir))
84+
85+
# Remove KHIOPS_SAMPLES_DIR
86+
del os.environ["KHIOPS_SAMPLES_DIR"]
87+
88+
# Create test runner to update samples dir to the default runner samples
89+
# dir, following the deletion of the KHIOPS_SAMPLES_DIR environment
90+
# variable
91+
test_runner = KhiopsLocalRunner()
92+
93+
# Set current runner to the test runner
94+
kh.set_runner(test_runner)
95+
96+
# Get the default runner samples dir
97+
default_runner_samples_dir = kh.get_samples_dir()
98+
99+
# Copy existing default runner samples dir contents to temporary directory
100+
if os.path.isdir(default_runner_samples_dir):
101+
tmp_initial_samples_dir = tempfile.mkdtemp()
102+
shutil.copytree(
103+
default_runner_samples_dir,
104+
tmp_initial_samples_dir,
105+
dirs_exist_ok=True,
106+
)
107+
108+
# Remove default runner samples dir
109+
shutil.rmtree(default_runner_samples_dir)
110+
else:
111+
tmp_initial_samples_dir = None
112+
113+
# Check that the samples are not present in the default runner
114+
# samples dir
115+
self.assertFalse(self._is_samples_dir(default_runner_samples_dir))
116+
117+
# Download datasets to the default runner samples dir (which
118+
# should be created on this occasion)
119+
# Default samples dir does not exist anymore
120+
tools.download_datasets()
121+
122+
# Check that the default samples dir is populated
123+
self.assertTrue(self._is_samples_dir(default_runner_samples_dir))
124+
125+
# Clean-up default samples dir
126+
shutil.rmtree(default_runner_samples_dir)
127+
128+
# Restore initial samples dir contents if previously present
129+
if tmp_initial_samples_dir is not None and os.path.isdir(
130+
tmp_initial_samples_dir
131+
):
132+
shutil.copytree(
133+
tmp_initial_samples_dir,
134+
default_runner_samples_dir,
135+
dirs_exist_ok=True,
136+
)
137+
138+
# Remove temporary directory
139+
shutil.rmtree(tmp_initial_samples_dir)
140+
141+
# Restore initial KHIOPS_SAMPLES_DIR if set
142+
if initial_khiops_samples_dir is not None:
143+
os.environ["KHIOPS_SAMPLES_DIR"] = initial_khiops_samples_dir
144+
145+
# Restore initial runner
146+
kh.set_runner(initial_runner)
147+
33148
@unittest.skipIf(
34149
platform.system() != "Linux", "Skipping test for non-Linux platform"
35150
)

0 commit comments

Comments
 (0)