Skip to content

Commit 5257c5a

Browse files
better way to check for testing model
Signed-off-by: thiswillbeyourgithub <26625900+thiswillbeyourgithub@users.noreply.github.com>
1 parent 73924e1 commit 5257c5a

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

wdoc/utils/llm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
from wdoc.utils.env import env
1717
from wdoc.utils.misc import ModelName, get_model_max_tokens, langfuse_callback_holder
1818

19-
TESTING_LLM = "testing/testing"
20-
21-
# lorem ipsum is output by TESTING_LLM
19+
# lorem ipsum is output by "testing" llm
2220
LOREM_IPSUM = (
2321
"Lorem ipsum dolor sit amet, consectetur adipiscing "
2422
"elit, sed do eiusmod tempor incididunt ut labore et "

wdoc/utils/misc.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,9 @@ def model_name_matcher(model: str) -> str:
498498
model has a known cost and print the matched name)
499499
Bypassed if env variable WDOC_NO_MODELNAME_MATCHING is 'true'
500500
"""
501-
assert "testing" not in model
501+
assert (
502+
"testing" not in model.lower()
503+
), f"Found 'testing' in model, this should not happen"
502504
assert "/" in model, f"expected / in model '{model}'"
503505
if env.WDOC_NO_MODELNAME_MATCHING:
504506
# logger.debug(f"Bypassing model name matching for model '{model}'")
@@ -617,7 +619,7 @@ def is_testing(self) -> bool:
617619
return False
618620

619621
def __hash__(self):
620-
# necessary for memoizing
622+
"necessary for memoizing"
621623
return (str(self.original.__hash__()) + str("ModelName".__hash__())).__hash__()
622624

623625

@@ -730,7 +732,7 @@ def get_splitter(
730732
length_function=get_tkn_length,
731733
)
732734

733-
if modelname.original == "testing/testing":
735+
if modelname.is_testing():
734736
return get_splitter(task=task, modelname=DEFAULT_SPLITTER_MODELNAME)
735737

736738
try:
@@ -1218,7 +1220,7 @@ def create_langfuse_callback(version: str) -> None:
12181220

12191221
@memoize
12201222
def get_supported_model_params(modelname: ModelName) -> list:
1221-
if modelname.backend == "testing":
1223+
if modelname.is_testing():
12221224
return []
12231225
if modelname.backend == "openrouter":
12241226
metadata = get_openrouter_metadata()

wdoc/wdoc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
ShouldIncreaseTopKAfterLLMEvalFiltering,
4040
)
4141

42-
from wdoc.utils.llm import TESTING_LLM, load_llm
42+
from wdoc.utils.llm import load_llm
4343

4444
from wdoc.utils.misc import ( # debug_chain,
4545
cache_dir,
@@ -186,6 +186,7 @@ def print_exception(exc_type, exc_value, exc_traceback):
186186
f"Cli_kwargs '{k}' is of type '{type(val)}' instead of '{expected_type}'"
187187
)
188188

189+
TESTING_LLM = "testing/testing"
189190
if (
190191
model == TESTING_LLM
191192
or model == "testing"

0 commit comments

Comments
 (0)