Skip to content

Commit 965ccb0

Browse files
committed
Remove latest_commit mechanism and add cache_full_library
1 parent 81968b8 commit 965ccb0

File tree

3 files changed

+35
-62
lines changed

3 files changed

+35
-62
lines changed

src/probeinterface/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,11 @@
3939
generate_multi_columns_probe,
4040
generate_multi_shank,
4141
)
42-
from .library import get_probe, list_manufacturers_in_library, list_probes_in_library, get_tags_in_library
42+
from .library import (
43+
get_probe,
44+
list_manufacturers_in_library,
45+
list_probes_in_library,
46+
get_tags_in_library,
47+
cache_full_library,
48+
)
4349
from .wiring import get_available_pathways

src/probeinterface/library.py

Lines changed: 28 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -90,25 +90,6 @@ def get_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None
9090
return None
9191
cache_folder = cache_folder_tag
9292
else:
93-
# load latest commit if exists
94-
commit_file = cache_folder / "main" / "latest_commit.txt"
95-
commit = None
96-
if commit_file.is_file():
97-
with open(commit_file, "r") as f:
98-
commit = f.read().strip()
99-
100-
# check against latest commit on github
101-
try:
102-
latest_commit = get_latest_commit("SpikeInterface", "probeinterface_library")["sha"]
103-
if commit is None or commit != latest_commit:
104-
# in this case we need to redownload the file and update the latest_commit.txt
105-
with open(cache_folder / "main" / "latest_commit.txt", "w") as f:
106-
f.write(latest_commit)
107-
return None
108-
except Exception:
109-
warnings.warn("Could not check for latest commit on github. Using local 'main' cache.")
110-
pass
111-
11293
cache_folder_tag = cache_folder / "main"
11394

11495
local_file = cache_folder_tag / manufacturer / (probe_name + ".json")
@@ -153,7 +134,13 @@ def remove_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = N
153134
os.remove(local_file)
154135

155136

156-
def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None, tag: Optional[str] = None) -> "Probe":
137+
def get_probe(
138+
manufacturer: str,
139+
probe_name: str,
140+
name: Optional[str] = None,
141+
tag: Optional[str] = None,
142+
force_download: bool = False,
143+
) -> "Probe":
157144
"""
158145
Get probe from ProbeInterface library
159146
@@ -167,14 +154,18 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None, ta
167154
Optional name for the probe
168155
tag : str | None, default: None
169156
Optional tag for the probe
157+
force_download : bool, default: False
158+
If True, force re-download of the probe file.
170159
171160
Returns
172161
----------
173162
probe : Probe object
174163
175164
"""
176-
177-
probe = get_from_cache(manufacturer, probe_name, tag=tag)
165+
if not force_download:
166+
probe = get_from_cache(manufacturer, probe_name, tag=tag)
167+
else:
168+
probe = None
178169

179170
if probe is None:
180171
download_probeinterface_file(manufacturer, probe_name, tag=tag)
@@ -187,6 +178,21 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None, ta
187178
return probe
188179

189180

181+
def cache_full_library(tag=None) -> None:
182+
"""
183+
Download all probes from the library to the cache directory.
184+
"""
185+
manufacturers = list_manufacturers_in_library(tag=tag)
186+
187+
for manufacturer in manufacturers:
188+
probes = list_probes_in_library(manufacturer, tag=tag)
189+
for probe_name in probes:
190+
try:
191+
download_probeinterface_file(manufacturer, probe_name, tag=tag)
192+
except Exception as e:
193+
warnings.warn(f"Could not download {manufacturer}/{probe_name} (tag: {tag}): {e}")
194+
195+
190196
def list_manufacturers_in_library(tag=None) -> list[str]:
191197
"""
192198
Get the list of available manufacturers in the library
@@ -266,24 +272,3 @@ def list_github_folders(owner: str, repo: str, path: str = "", ref: str = None,
266272
raise RuntimeError(f"GitHub API returned status {resp.status_code}: {resp.text}")
267273
items = resp.json()
268274
return [item["name"] for item in items if item.get("type") == "dir" and item["name"][0] != "."]
269-
270-
271-
def get_latest_commit(owner: str, repo: str, branch: str = "main", token: str = None):
272-
"""
273-
Get the latest commit SHA and message from a given branch (default: main).
274-
"""
275-
url = f"https://api.github.com/repos/{owner}/{repo}/commits/{branch}"
276-
headers = {}
277-
if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"):
278-
token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN")
279-
headers["Authorization"] = f"token {token}"
280-
resp = requests.get(url, headers=headers)
281-
if resp.status_code != 200:
282-
raise RuntimeError(f"GitHub API returned {resp.status_code}: {resp.text}")
283-
data = resp.json()
284-
return {
285-
"sha": data["sha"],
286-
"message": data["commit"]["message"],
287-
"author": data["commit"]["author"]["name"],
288-
"date": data["commit"]["author"]["date"],
289-
}

tests/test_library.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,6 @@ def test_download_probeinterface_file():
1919
download_probeinterface_file(manufacturer, probe_name, tag=None)
2020

2121

22-
def test_latest_commit_mechanism():
23-
_ = get_probe(manufacturer, probe_name)
24-
cache_folder = get_cache_folder()
25-
latest_commit_file = cache_folder / "main" / "latest_commit.txt"
26-
assert latest_commit_file.is_file()
27-
28-
# now we manually change latest_commit.txt to something else
29-
with open(latest_commit_file, "w") as f:
30-
f.write("1234567890123456789012345678901234567890")
31-
32-
# now we get the probe again and make sure the latest_commit.txt file is updated
33-
_ = get_probe(manufacturer, probe_name)
34-
assert latest_commit_file.is_file()
35-
with open(latest_commit_file, "r") as f:
36-
latest_commit = f.read().strip()
37-
assert latest_commit != "123456789012345678901234567890123456789"
38-
39-
4022
def test_get_from_cache():
4123
download_probeinterface_file(manufacturer, probe_name)
4224
probe = get_from_cache(manufacturer, probe_name)

0 commit comments

Comments
 (0)