1212import pathlib
1313import pickle
1414import tempfile
15+ import unittest
16+ from unittest .mock import patch
1517
1618import h5py
1719import hdf5storage
2729
2830__test_functions = []
2931__test_functions_error = []
32+ __test_functions_module_not_found = []
3033
3134
3235def _skip_hdf5storage (* args , ** kwargs ):
@@ -42,7 +45,7 @@ def test_imports():
4245 assert hasattr (cebra , "load_data" )
4346
4447
45- def register (* file_endings ):
48+ def register (* file_endings , requires = () ):
4649 # for each file format
4750 def _register (f ):
4851 # f is the filename
@@ -53,6 +56,12 @@ def _register(f):
5356 lambda filename : f (filename + "." + file_ending )
5457 for file_ending in file_endings
5558 ])
59+ if len (requires ) > 0 :
60+ __test_functions_module_not_found .extend ([
61+ (requires , lambda filename : filename + "." + file_ending ,
62+ lambda filename : f (filename + "." + file_ending ))
63+ for file_ending in file_endings
64+ ])
5665 return f
5766
5867 return _register
@@ -152,7 +161,7 @@ def generate_numpy_no_array(filename):
152161# TODO: test raise ModuleFoundError for h5py
153162
154163
155- @register ("h5" , "hdf" , "hdf5" , "h" )
164+ @register ("h5" , "hdf" , "hdf5" , "h" , requires = ( "h5py" ,) )
156165def generate_h5 (filename ):
157166 A = np .arange (1000 ).reshape (10 , 100 )
158167 with h5py .File (filename , "w" ) as hf :
@@ -380,7 +389,7 @@ def generate_wrong_key(filename):
380389
381390
382391#### .CSV ####
383- @register ("csv" )
392+ @register ("csv" , requires = ( "pandas" ,) )
384393def generate_csv (filename ):
385394 A = np .arange (1000 ).reshape (10 , 100 )
386395 pd .DataFrame (A ).to_csv (filename , header = False , index = False , sep = "," )
@@ -404,7 +413,7 @@ def generate_csv_empty_file(filename):
404413
405414
406415#### EXCEL ####
407- @register ("xls" , "xlsx" , "xlsm" )
416+ @register ("xls" , "xlsx" , "xlsm" , requires = ( "pandas" , "pd" ) )
408417# TODO(celia): add the following extension: "xlsb", "odf", "ods", "odt",
409418# issue to create the files
410419def generate_excel (filename ):
@@ -777,3 +786,23 @@ def test_load_error(save_data):
777786
778787 with pytest .raises ((AttributeError , TypeError )):
779788 save_data (filename )
789+
790+
791+ @pytest .mark .parametrize ("module_names,get_path,save_data" ,
792+ __test_functions_module_not_found )
793+ def test_module_not_installed (module_names , get_path , save_data ):
794+
795+ assert len (module_names ) > 0
796+ assert isinstance (module_names , tuple )
797+
798+ with tempfile .NamedTemporaryFile () as tf :
799+ filename = tf .name
800+
801+ saved_array , loaded_array = save_data (filename )
802+ assert np .allclose (saved_array , loaded_array )
803+
804+ # TODO(stes): Sketch for a test --- needs additional work.
805+ # with patch.dict('sys.modules', {module: None for module in module_names}):
806+ # path = get_path(filename)
807+ # with pytest.raises(ModuleNotFoundError, match="cebra[datasets]"):
808+ # cebra.data.load.load(path)
0 commit comments