1212from functools import lru_cache
1313from pathlib import Path
1414from tempfile import TemporaryDirectory
15+ from typing import TYPE_CHECKING
16+
17+ if TYPE_CHECKING :
18+ from collections .abc import Generator
1519
1620import tomlkit
1721
@@ -112,7 +116,7 @@ def normalize_by_max(values: list[float]) -> list[float]:
112116 return [v / mx for v in values ]
113117
114118
115- def create_score_dictionary_from_metrics (weights : list [float ], * metrics : list [float ]) -> dict [int , int ]:
119+ def create_score_dictionary_from_metrics (weights : list [float ], * metrics : list [float ]) -> dict [int , float ]:
116120 """Combine multiple metrics into a single weighted score dictionary.
117121
118122 Each metric is a list of values (smaller = better).
@@ -208,67 +212,53 @@ def filter_args(addopts_args: list[str]) -> list[str]:
208212def modify_addopts (config_file : Path ) -> tuple [str , bool ]:
209213 file_type = config_file .suffix .lower ()
210214 filename = config_file .name
211- config = None
212215 if file_type not in {".toml" , ".ini" , ".cfg" } or not config_file .exists ():
213216 return "" , False
214217 # Read original file
215218 with Path .open (config_file , encoding = "utf-8" ) as f :
216219 content = f .read ()
217220 try :
218221 if filename == "pyproject.toml" :
219- # use tomlkit
220222 data = tomlkit .parse (content )
221223 original_addopts = data .get ("tool" , {}).get ("pytest" , {}).get ("ini_options" , {}).get ("addopts" , "" )
222- # nothing to do if no addopts present
223224 if original_addopts == "" :
224225 return content , False
225226 if isinstance (original_addopts , list ):
226227 original_addopts = " " .join (original_addopts )
227228 original_addopts = original_addopts .replace ("=" , " " )
228- addopts_args = (
229- original_addopts .split ()
230- ) # any number of space characters as delimiter, doesn't look at = which is fine
231- else :
232- # use configparser
233- config = configparser .ConfigParser ()
234- config .read_string (content )
235- data = {section : dict (config [section ]) for section in config .sections ()}
236- if config_file .name in {"pytest.ini" , ".pytest.ini" , "tox.ini" }:
237- original_addopts = data .get ("pytest" , {}).get ("addopts" , "" ) # should only be a string
238- else :
239- original_addopts = data .get ("tool:pytest" , {}).get ("addopts" , "" ) # should only be a string
240- original_addopts = original_addopts .replace ("=" , " " )
241229 addopts_args = original_addopts .split ()
242- new_addopts_args = filter_args (addopts_args )
243- if new_addopts_args == addopts_args :
244- return content , False
245- # change addopts now
246- if file_type == ".toml" :
247- data ["tool" ]["pytest" ]["ini_options" ]["addopts" ] = " " .join (new_addopts_args )
248- # Write modified file
230+ new_addopts_args = filter_args (addopts_args )
231+ if new_addopts_args == addopts_args :
232+ return content , False
233+ data ["tool" ]["pytest" ]["ini_options" ]["addopts" ] = " " .join (new_addopts_args ) # type: ignore[index]
249234 with Path .open (config_file , "w" , encoding = "utf-8" ) as f :
250235 f .write (tomlkit .dumps (data ))
251- return content , True
252- elif config_file .name in {"pytest.ini" , ".pytest.ini" , "tox.ini" }:
253- config .set ("pytest" , "addopts" , " " .join (new_addopts_args ))
254- # Write modified file
255- with Path .open (config_file , "w" , encoding = "utf-8" ) as f :
256- config .write (f )
257- return content , True
236+ return content , True
237+ config = configparser .ConfigParser ()
238+ config .read_string (content )
239+ ini_data = {section : dict (config [section ]) for section in config .sections ()}
240+ if config_file .name in {"pytest.ini" , ".pytest.ini" , "tox.ini" }:
241+ original_addopts = ini_data .get ("pytest" , {}).get ("addopts" , "" )
258242 else :
259- config .set ("tool:pytest" , "addopts" , " " .join (new_addopts_args ))
260- # Write modified file
261- with Path .open (config_file , "w" , encoding = "utf-8" ) as f :
262- config .write (f )
263- return content , True
243+ original_addopts = ini_data .get ("tool:pytest" , {}).get ("addopts" , "" )
244+ original_addopts = original_addopts .replace ("=" , " " )
245+ addopts_args = original_addopts .split ()
246+ new_addopts_args = filter_args (addopts_args )
247+ if new_addopts_args == addopts_args :
248+ return content , False
249+ section = "pytest" if config_file .name in {"pytest.ini" , ".pytest.ini" , "tox.ini" } else "tool:pytest"
250+ config .set (section , "addopts" , " " .join (new_addopts_args ))
251+ with Path .open (config_file , "w" , encoding = "utf-8" ) as f :
252+ config .write (f )
253+ return content , True
264254
265255 except Exception :
266256 logger .debug ("Trouble parsing" )
267- return content , False # not modified
257+ return content , False
268258
269259
270260@contextmanager
271- def custom_addopts () -> None :
261+ def custom_addopts () -> Generator [ None , None , None ] :
272262 closest_config_files = get_all_closest_config_files ()
273263
274264 original_content = {}
@@ -287,18 +277,17 @@ def custom_addopts() -> None:
287277
288278
289279@contextmanager
290- def add_addopts_to_pyproject () -> None :
280+ def add_addopts_to_pyproject () -> Generator [ None , None , None ] :
291281 pyproject_file = find_pyproject_toml ()
292- original_content = None
282+ original_content : str | None = None
293283 try :
294- # Read original file
295284 if pyproject_file .exists ():
296285 with Path .open (pyproject_file , encoding = "utf-8" ) as f :
297286 original_content = f .read ()
298287 data = tomlkit .parse (original_content )
299- data ["tool" ]["pytest" ] = {}
300- data ["tool" ]["pytest" ]["ini_options" ] = {}
301- data ["tool" ]["pytest" ]["ini_options" ]["addopts" ] = [
288+ data ["tool" ]["pytest" ] = {} # type: ignore[index]
289+ data ["tool" ]["pytest" ]["ini_options" ] = {} # type: ignore[index]
290+ data ["tool" ]["pytest" ]["ini_options" ]["addopts" ] = [ # type: ignore[index]
302291 "-n=auto" ,
303292 "-n" ,
304293 "1" ,
@@ -312,9 +301,9 @@ def add_addopts_to_pyproject() -> None:
312301 yield
313302
314303 finally :
315- # Restore original file
316- with Path .open (pyproject_file , "w" , encoding = "utf-8" ) as f :
317- f .write (original_content )
304+ if original_content is not None :
305+ with Path .open (pyproject_file , "w" , encoding = "utf-8" ) as f :
306+ f .write (original_content )
318307
319308
320309def encoded_tokens_len (s : str ) -> int :
@@ -418,13 +407,18 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]:
418407 return True , function_names
419408
420409
410+ _run_tmpdir : TemporaryDirectory [str ] | None = None
411+ _run_tmpdir_path : Path | None = None
412+
413+
421414def get_run_tmp_file (file_path : Path | str ) -> Path :
415+ global _run_tmpdir , _run_tmpdir_path
422416 if isinstance (file_path , str ):
423417 file_path = Path (file_path )
424- if not hasattr ( get_run_tmp_file , "tmpdir_path" ) :
425- get_run_tmp_file . tmpdir = TemporaryDirectory (prefix = "codeflash_" )
426- get_run_tmp_file . tmpdir_path = Path (get_run_tmp_file . tmpdir .name ).resolve ()
427- return get_run_tmp_file . tmpdir_path / file_path
418+ if _run_tmpdir_path is None :
419+ _run_tmpdir = TemporaryDirectory (prefix = "codeflash_" )
420+ _run_tmpdir_path = Path (_run_tmpdir .name ).resolve ()
421+ return _run_tmpdir_path / file_path
428422
429423
430424@lru_cache (maxsize = 1 )
0 commit comments