diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000..1213a41 --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,46 @@ +name: "Black – code-style check" + +on: + pull_request: + paths: ["**/*.py"] + push: + branches: [main] + paths: ["**/*.py"] + +jobs: + black: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + # ---------- pip wheel cache ---------- + - name: Cache pip + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-black-pip-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-black-pip- + + # ---------- Black cache (formatting state) ---------- + - name: Cache Black .cache + uses: actions/cache@v4 + with: + path: .cache/black + key: ${{ runner.os }}-black-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-black- + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: Install Black (pinned) + run: | + python -m pip install --upgrade pip + pip install black==24.10.0 + + - name: Run Black in check mode + run: black --check --diff src diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml new file mode 100644 index 0000000..3d0ca9f --- /dev/null +++ b/.github/workflows/mypy.yml @@ -0,0 +1,47 @@ +name: "mypy – static type checks" + +on: + pull_request: + paths: ["**/*.py"] + push: + branches: [main] + paths: ["**/*.py"] + +jobs: + mypy: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + # ---------- pip wheel cache ---------- + - name: Cache pip + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-py${{ matrix.python-version }}-pip- + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: Install deps + mypy + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + # ---------- mypy incremental cache ---------- + - name: Cache mypy .mypy_cache + uses: actions/cache@v4 + with: + path: .mypy_cache + key: ${{ runner.os }}-mypy-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-mypy- + + - name: Type‑check + run: | + mypy src diff --git a/README.md b/README.md index 757048b..ee08591 100644 --- a/README.md +++ b/README.md @@ -25,10 +25,10 @@ We are keeping implementing more agents and will open-source them very soon. Uti ## Installation -1. Create and activate a conda environment with Python 3.9.18: +1. Create and activate a conda environment with Python 3.13: ```sh - conda create -n repoaudit python=3.9.18 + conda create -n repoaudit python=3.13 conda activate repoaudit ``` diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2f07553 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,16 @@ +[tool.black] +line-length = 88 # keep the default PEP‑8‑plus setting +target-version = ["py39"] # ensures Python‑3.9 compatible formatting +skip-string-normalization = false +include = '\.pyi?$' + +exclude = ''' +/( + \.git + | \.mypy_cache + | \.pytest_cache + | \.venv + | build + | dist +)/ +''' diff --git a/requirements.txt b/requirements.txt index bd97068..78dfb30 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,8 @@ streamlit botocore boto3 black -anthropic \ No newline at end of file +anthropic +mypy +types-networkx +types-tqdm +boto3-stubs[essential] diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/agent/dfbscan.py b/src/agent/dfbscan.py index 5e266f4..d49073a 100644 --- a/src/agent/dfbscan.py +++ b/src/agent/dfbscan.py @@ -37,15 +37,15 @@ class DFBScanAgent(Agent): def __init__( self, - bug_type, - is_reachable, - project_path, - language, - ts_analyzer, - model_name, - temperature, - call_depth, - max_neural_workers=30, + bug_type: str, + is_reachable: bool, + project_path: str, + language: str, + ts_analyzer: TSAnalyzer, + model_name: str, + temperature: float, + call_depth: int, + max_neural_workers: int = 30, agent_id: int = 0, ) -> None: self.bug_type = bug_type @@ -112,7 +112,9 @@ def __obtain_extractor(self) -> DFBScanExtractor: elif self.language == "Go": if self.bug_type == "NPD": return Go_NPD_Extractor(self.ts_analyzer) - return None + raise NotImplementedError( + f"Unsupported bug type: {self.bug_type} in {self.language}" + ) def __update_worklist( self, @@ -174,21 +176,23 @@ def __update_worklist( if not is_CFL_reachable: continue - for para in callee_function.paras: - if para.index == value.index: - delta_worklist.append( - (para, callee_function, new_call_context) - ) - self.state.update_external_value_match( - (value, call_context), set({(para, new_call_context)}) - ) + if callee_function.paras is not None: + for para in callee_function.paras: + if para.index == value.index: + delta_worklist.append( + (para, callee_function, new_call_context) + ) + self.state.update_external_value_match( + (value, call_context), + set({(para, new_call_context)}), + ) if value.label == ValueLabel.PARA: # Consider side-effect. # Example: the parameter *p is used in the function: p->f = null; # We need to consider the side-effect of p. - caller_function = self.ts_analyzer.get_all_caller_functions(function) - for caller_function in caller_function: + caller_functions = self.ts_analyzer.get_all_caller_functions(function) + for caller_function in caller_functions: new_call_context = copy.deepcopy(call_context) top_unmatched_context_label = ( new_call_context.get_top_unmatched_context_label() @@ -442,9 +446,13 @@ def start_scan_sequential(self) -> None: ret.name, ret.line_number - start_function.start_line_number + 1, ) - for ret in start_function.retvals + for ret in ( + start_function.retvals + if start_function.retvals is not None + else [] + ) ] - input = IntraDataFlowAnalyzerInput( + df_input = IntraDataFlowAnalyzerInput( start_function, start_value, sink_values, @@ -453,20 +461,22 @@ def start_scan_sequential(self) -> None: ) # Invoke the intra-procedural data-flow analysis - output = self.intra_dfa.invoke(input) - if output is None: + df_output = self.intra_dfa.invoke( + df_input, IntraDataFlowAnalyzerOutput + ) + if df_output is None: continue - for path_index in range(len(output.reachable_values)): + for path_index in range(len(df_output.reachable_values)): reachable_values_in_single_path = set([]) - for value in output.reachable_values[path_index]: + for value in df_output.reachable_values[path_index]: reachable_values_in_single_path.add((value, call_context)) self.state.update_reachable_values_per_path( (start_value, call_context), reachable_values_in_single_path ) delta_worklist = self.__update_worklist( - input, output, call_context, path_index + df_input, df_output, call_context, path_index ) worklist.extend(delta_worklist) @@ -479,7 +489,7 @@ def start_scan_sequential(self) -> None: continue for buggy_path in self.state.potential_buggy_paths[src_value].values(): - input = PathValidatorInput( + pv_input = PathValidatorInput( self.bug_type, buggy_path, { @@ -487,12 +497,14 @@ def start_scan_sequential(self) -> None: for value in buggy_path }, ) - output: PathValidatorOutput = self.path_validator.invoke(input) + pv_output = self.path_validator.invoke( + pv_input, PathValidatorOutput + ) - if output is None: + if pv_output is None: continue - if output.is_reachable: + if pv_output.is_reachable: relevant_functions = {} for value in buggy_path: function = self.ts_analyzer.get_function_from_localvalue( @@ -505,7 +517,7 @@ def start_scan_sequential(self) -> None: self.bug_type, src_value, relevant_functions, - output.explanation_str, + pv_output.explanation_str, ) self.state.update_bug_report(bug_report) @@ -606,28 +618,30 @@ def __process_src_value(self, src_value: Value) -> None: ret_values = [ (ret.name, ret.line_number - start_function.start_line_number + 1) - for ret in start_function.retvals + for ret in ( + start_function.retvals if start_function.retvals is not None else [] + ) ] - input = IntraDataFlowAnalyzerInput( + df_input = IntraDataFlowAnalyzerInput( start_function, start_value, sink_values, call_statements, ret_values ) # Invoke the intra-procedural data-flow analysis - output = self.intra_dfa.invoke(input) + df_output = self.intra_dfa.invoke(df_input, IntraDataFlowAnalyzerOutput) - if output is None: + if df_output is None: continue - for path_index in range(len(output.reachable_values)): + for path_index in range(len(df_output.reachable_values)): reachable_values_in_single_path = set([]) - for value in output.reachable_values[path_index]: + for value in df_output.reachable_values[path_index]: reachable_values_in_single_path.add((value, call_context)) self.state.update_reachable_values_per_path( (start_value, call_context), reachable_values_in_single_path ) delta_worklist = self.__update_worklist( - input, output, call_context, path_index + df_input, df_output, call_context, path_index ) worklist.extend(delta_worklist) @@ -645,22 +659,25 @@ def __process_src_value(self, src_value: Value) -> None: for value in buggy_path } - relevant_functions = values_to_functions.values() + functions: Set[Function] = set() + for func in values_to_functions.values(): + if func is not None: + functions.add(func) - if self.state.check_existence(src_value, relevant_functions): + if self.state.check_existence(src_value, functions): continue - input = PathValidatorInput( + pv_input = PathValidatorInput( self.bug_type, buggy_path, values_to_functions, ) - output: PathValidatorOutput = self.path_validator.invoke(input) + pv_output = self.path_validator.invoke(pv_input, PathValidatorOutput) - if output is None: + if pv_output is None: continue - if output.is_reachable: + if pv_output.is_reachable: relevant_functions = {} for value in buggy_path: function = self.ts_analyzer.get_function_from_localvalue(value) @@ -668,7 +685,10 @@ def __process_src_value(self, src_value: Value) -> None: relevant_functions[function.function_id] = function bug_report = BugReport( - self.bug_type, src_value, relevant_functions, output.explanation_str + self.bug_type, + src_value, + relevant_functions, + pv_output.explanation_str, ) self.state.update_bug_report(bug_report) bug_report_dict = { diff --git a/src/agent/metascan.py b/src/agent/metascan.py index c7dd71e..49cd746 100644 --- a/src/agent/metascan.py +++ b/src/agent/metascan.py @@ -18,7 +18,9 @@ class MetaScanAgent(Agent): Used for testing llmtools :) """ - def __init__(self, project_path, language, ts_analyzer) -> None: + def __init__( + self, project_path: str, language: str, ts_analyzer: TSAnalyzer + ) -> None: self.project_path = project_path self.project_name = project_path.split("/")[-1] self.language = language @@ -36,23 +38,31 @@ def start_scan(self) -> None: ) if not os.path.exists(log_dir_path): os.makedirs(log_dir_path) - self.logger = Logger(self.log_dir_path + "/" + "metascan.log") + self.logger = Logger(log_dir_path + "/" + "metascan.log") for function_id in self.ts_analyzer.function_env: - function_meta_data = {} + function_meta_data: Dict = {} function = self.ts_analyzer.function_env[function_id] function_meta_data["function_id"] = function.function_id function_meta_data["function_name"] = function.function_name function_meta_data["function_start_line"] = function.start_line_number function_meta_data["function_end_line"] = function.end_line_number - function_meta_data["parameters"] = [str(para) for para in function.paras] + function_meta_data["parameters"] = ( + [str(para) for para in function.paras] + if function.paras is not None + else [] + ) - function_meta_data["retvals"] = [str(retval) for retval in function.retvals] + function_meta_data["retvals"] = ( + [str(retval) for retval in function.retvals] + if function.retvals is not None + else [] + ) function_meta_data["call_sites"] = [] for call_site in function.function_call_site_nodes: - call_site_info = {} + call_site_info: Dict = {} file_content = self.ts_analyzer.fileContentDic[function.file_path] call_site_info["callee_id"] = ( self.ts_analyzer.get_callee_function_ids_at_callsite( @@ -110,7 +120,7 @@ def start_scan(self) -> None: ) = self.ts_analyzer.function_env[function_id].if_statements[ (if_statement_start_line, if_statement_end_line) ] - if_statement = {} + if_statement: Dict = {} if_statement["condition_str"] = condition_str if_statement["condition_start_line"] = condition_start_line if_statement["condition_end_line"] = condition_end_line @@ -134,7 +144,7 @@ def start_scan(self) -> None: ) = self.ts_analyzer.function_env[function_id].loop_statements[ (loop_statement_start_line, loop_statement_end_line) ] - loop_statement = {} + loop_statement: Dict = {} loop_statement["loop_statement_start_line"] = loop_statement_start_line loop_statement["loop_statement_end_line"] = loop_statement_end_line loop_statement["header_str"] = header_str diff --git a/src/llmtool/LLM_tool.py b/src/llmtool/LLM_tool.py index 4773179..721c2d5 100644 --- a/src/llmtool/LLM_tool.py +++ b/src/llmtool/LLM_tool.py @@ -1,18 +1,18 @@ from llmtool.LLM_utils import * from abc import ABC, abstractmethod -from typing import Dict +from typing import Dict, Optional, Type, TypeVar, cast from ui.logger import Logger class LLMToolInput(ABC): - def __init__(self): - pass + def __init__(self) -> None: + raise NotImplementedError @abstractmethod - def __hash__(self): - pass + def __hash__(self) -> int: + raise NotImplementedError - def __eq__(self, value): + def __eq__(self, value) -> bool: return self.__hash__() == value.__hash__() @@ -21,6 +21,9 @@ def __init__(self): pass +T = TypeVar("T", bound=LLMToolOutput) + + class LLMTool(ABC): def __init__( self, @@ -44,7 +47,23 @@ def __init__( self.output_token_cost = 0 self.total_query_num = 0 - def invoke(self, input: LLMToolInput) -> LLMToolOutput: + def invoke(self, input: LLMToolInput, cls: Type[T]) -> Optional[T]: + """ + Invoke the LLM tool with the given input. + :param input: the input of the LLM tool + :param cls: the class of the output + :return: the output of the LLM tool + """ + output = self._invoke(input) + if output is None: + return None + + if not isinstance(output, cls): + raise TypeError(f"Expected output of type {cls}, but got {type(output)}") + + return cast(T, output) + + def _invoke(self, input: LLMToolInput) -> Optional[LLMToolOutput]: class_name = type(self).__name__ self.logger.print_console(f"The LLM Tool {class_name} is invoked.") if input in self.cache: @@ -82,6 +101,6 @@ def _get_prompt(self, input: LLMToolInput) -> str: @abstractmethod def _parse_response( - self, response: str, input: LLMToolInput = None - ) -> LLMToolOutput: + self, response: str, input: Optional[LLMToolInput] = None + ) -> Optional[LLMToolOutput]: pass diff --git a/src/llmtool/dfbscan/intra_dataflow_analyzer.py b/src/llmtool/dfbscan/intra_dataflow_analyzer.py index 360bebf..b26b54b 100644 --- a/src/llmtool/dfbscan/intra_dataflow_analyzer.py +++ b/src/llmtool/dfbscan/intra_dataflow_analyzer.py @@ -67,7 +67,9 @@ def __init__( ) return - def _get_prompt(self, input: IntraDataFlowAnalyzerInput) -> str: + def _get_prompt(self, input: LLMToolInput) -> str: + if not isinstance(input, IntraDataFlowAnalyzerInput): + raise TypeError("Expect IntraDataFlowAnalyzerInput") with open(self.prompt_file, "r") as f: prompt_template_dict = json.load(f) prompt = prompt_template_dict["task"] @@ -109,8 +111,8 @@ def _get_prompt(self, input: IntraDataFlowAnalyzerInput) -> str: return prompt def _parse_response( - self, response: str, input: IntraDataFlowAnalyzerInput - ) -> IntraDataFlowAnalyzerOutput: + self, response: str, input: Optional[LLMToolInput] = None + ) -> Optional[LLMToolOutput]: """ Parse the LLM response to extract all execution paths and their propagation details. @@ -121,7 +123,7 @@ def _parse_response( Returns: IntraDataFlowAnalyzerOutput: The output containing reachable values for each path. """ - paths = [] + paths: List[Dict] = [] # Regex to match a path header line, e.g., "Path 1: Lines 2 -> 3" path_header_re = re.compile(r"Path\s*(\d+):\s*([^;]+);?$") @@ -172,6 +174,10 @@ def _parse_response( if current_path: paths.append(current_path) + assert input is not None, "input cannot be none" + if not isinstance(input, IntraDataFlowAnalyzerInput): + raise TypeError("Expect IntraDataFlowAnalyzerInput") + # Process paths to extract reachable values reachable_values = [] file_path = input.function.file_path diff --git a/src/llmtool/dfbscan/path_validator.py b/src/llmtool/dfbscan/path_validator.py index 18eb653..069edfc 100644 --- a/src/llmtool/dfbscan/path_validator.py +++ b/src/llmtool/dfbscan/path_validator.py @@ -1,7 +1,6 @@ from os import path import json -import time -from typing import List, Set, Optional, Dict +from typing import List, Dict from llmtool.LLM_utils import * from llmtool.LLM_tool import * from memory.syntactic.function import * @@ -16,7 +15,7 @@ def __init__( self, bug_type: str, values: List[Value], - values_to_functions: Dict[Value, Function], + values_to_functions: Dict[Value, Optional[Function]], ) -> None: self.bug_type = bug_type self.values = values @@ -59,7 +58,9 @@ def __init__( self.prompt_file = f"{BASE_PATH}/prompt/{language}/dfbscan/path_validator.json" return - def _get_prompt(self, input: PathValidatorInput) -> str: + def _get_prompt(self, input: LLMToolInput) -> str: + if not isinstance(input, PathValidatorInput): + raise TypeError("expect PathValidatorInput") with open(self.prompt_file, "r") as f: prompt_template_dict = json.load(f) prompt = prompt_template_dict["task"] @@ -88,7 +89,7 @@ def _get_prompt(self, input: PathValidatorInput) -> str: program = "\n".join( [ - "```\n" + func.lined_code + "\n```\n" + "```\n" + func.lined_code + "\n```\n" if func is not None else "\n" for func in input.values_to_functions.values() ] ) @@ -96,8 +97,8 @@ def _get_prompt(self, input: PathValidatorInput) -> str: return prompt def _parse_response( - self, response: str, input: PathValidatorInput - ) -> PathValidatorOutput: + self, response: str, input: Optional[LLMToolInput] = None + ) -> Optional[LLMToolOutput]: answer_match = re.search(r"Answer:\s*(\w+)", response) if answer_match: answer = answer_match.group(1).strip() diff --git a/src/memory/report/bug_report.py b/src/memory/report/bug_report.py index 62ebc0e..db08e41 100644 --- a/src/memory/report/bug_report.py +++ b/src/memory/report/bug_report.py @@ -10,7 +10,7 @@ def __init__( buggy_value: Value, relevant_functions: Dict[int, Function], explanation: str, - is_human_confirmed_true: bool = None, + is_human_confirmed_true: bool = False, ) -> None: """ :param bug_type: the bug type diff --git a/src/memory/semantic/dfbscan_state.py b/src/memory/semantic/dfbscan_state.py index 1fc1f56..6cfbc87 100644 --- a/src/memory/semantic/dfbscan_state.py +++ b/src/memory/semantic/dfbscan_state.py @@ -26,7 +26,7 @@ def __init__(self, src_values: List[Value], sink_values: List[Value]) -> None: self._potential_buggy_paths: Dict[Value, Dict[str, List[Value]]] = {} # Bug reports - self._bug_reports: dict[int, List[BugReport]] = {} + self._bug_reports: Dict[int, BugReport] = {} self._total_bug_count = 0 # Create locks for each field @@ -113,7 +113,7 @@ def potential_buggy_paths(self) -> Dict[Value, Dict[str, List[Value]]]: return self._potential_buggy_paths.copy() @property - def bug_reports(self) -> Dict[int, List[BugReport]]: + def bug_reports(self) -> Dict[int, BugReport]: """ Get the bug reports """ diff --git a/src/memory/semantic/metascan_state.py b/src/memory/semantic/metascan_state.py index f4af83f..8b58ced 100644 --- a/src/memory/semantic/metascan_state.py +++ b/src/memory/semantic/metascan_state.py @@ -2,12 +2,14 @@ from memory.syntactic.value import * from memory.report.bug_report import * from memory.semantic.state import * -from typing import List, Tuple, Dict +from typing import Dict class MetaScanState(State): def __init__(self) -> None: - self.function_meta_data_dict = {} # function id --> function meta data + self.function_meta_data_dict: Dict[int, Dict] = ( + {} + ) # function id --> function meta data return def update_function_meta_data( diff --git a/src/memory/semantic/state.py b/src/memory/semantic/state.py index fa80941..4f93f34 100644 --- a/src/memory/semantic/state.py +++ b/src/memory/semantic/state.py @@ -1,5 +1,4 @@ -from abc import ABC, abstractmethod -from typing import Dict +from abc import ABC class State(ABC): diff --git a/src/memory/syntactic/function.py b/src/memory/syntactic/function.py index 9ea60df..358a039 100644 --- a/src/memory/syntactic/function.py +++ b/src/memory/syntactic/function.py @@ -1,4 +1,10 @@ -import tree_sitter +from tree_sitter import Node +from typing import List, Optional, Set, Tuple, Dict +from memory.syntactic.value import Value + +LineScope = Tuple[int, int] +IfInfo = Tuple[int, int, str, LineScope, LineScope] +LoopInfo = Tuple[int, int, str, int, int] class Function: @@ -9,7 +15,7 @@ def __init__( function_code: str, start_line_number: int, end_line_number: int, - function_node: tree_sitter.Node, + function_node: Node, file_path: str, ) -> None: """ @@ -31,16 +37,18 @@ def __init__( self.parse_tree_root_node = ( function_node # root node of the parse tree of the current function ) - self.function_call_site_nodes = [] # call site info of user-defined functions - self.api_call_site_nodes = [] # call site info of library APIs + self.function_call_site_nodes: List[Node] = ( + [] + ) # call site info of user-defined functions + self.api_call_site_nodes: List[Node] = [] # call site info of library APIs ## Results of AST node type analysis - self.paras = None # A set of parameters - self.retvals = None # A set of returned values + self.paras: Optional[Set[Value]] = None # A set of parameters + self.retvals: Optional[Set[Value]] = None # A set of returned values ## Results of intraprocedural control flow analysis - self.if_statements = {} # if statement info - self.loop_statements = {} # loop statement info + self.if_statements: Dict[LineScope, IfInfo] = {} # if statement info + self.loop_statements: Dict[LineScope, LoopInfo] = {} # loop statement info def __hash__(self) -> int: return hash( diff --git a/src/memory/syntactic/value.py b/src/memory/syntactic/value.py index 8e41b10..61e439f 100644 --- a/src/memory/syntactic/value.py +++ b/src/memory/syntactic/value.py @@ -86,7 +86,9 @@ def __str__(self) -> str: + ")" ) - def __eq__(self, other: "Value") -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, Value): + return NotImplemented return self.__str__() == other.__str__() def __repr__(self) -> str: diff --git a/src/repoaudit.py b/src/repoaudit.py index a376ba3..24d3639 100644 --- a/src/repoaudit.py +++ b/src/repoaudit.py @@ -39,7 +39,7 @@ def __init__( self.project_path = args.project_path self.language = args.language - self.code_in_files = {} + self.code_in_files: Dict[str, str] = {} self.model_name = args.model_name self.temperature = args.temperature @@ -65,6 +65,7 @@ def __init__( # Load all files with the specified suffix in the project path self.traverse_files(self.project_path, suffixs) + self.ts_analyzer: TSAnalyzer if self.language == "Cpp": self.ts_analyzer = Cpp_TSAnalyzer( self.code_in_files, self.language, self.max_symbolic_workers @@ -92,8 +93,6 @@ def start_repo_auditing(self) -> None: self.project_path, self.language, self.ts_analyzer, - self.model_name, - self.temperature, ) metascan_pipeline.start_scan() diff --git a/src/test/test_dfa.py b/src/test/test_dfa.py deleted file mode 100644 index 9665257..0000000 --- a/src/test/test_dfa.py +++ /dev/null @@ -1,115 +0,0 @@ -import sys -from os import path -from pathlib import Path -import json -import time - -sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) - -from tstool.bugscan_extractor.bugscan_extractor import * -from tstool.bugscan_extractor.Cpp.Cpp_BOF_extractor import * -from tstool.bugscan_extractor.Cpp.Cpp_MLK_extractor import * -from tstool.bugscan_extractor.Cpp.Cpp_NPD_extractor import * -from tstool.bugscan_extractor.Cpp.Cpp_UAF_extractor import * -from tstool.bugscan_extractor.Go.Go_BOF_extractor import * -from tstool.bugscan_extractor.Go.Go_NPD_extractor import * -from tstool.bugscan_extractor.Java.Java_NPD_extractor import * -from tstool.bugscan_extractor.Python.Python_NPD_extractor import * - -from repoaudit import RepoAudit - -BASE_DIR = path.dirname(path.dirname(path.dirname(path.abspath(__file__)))) - - -# TODO: @jinyao: We need to utilize the methods of repoaudit to run the test cases -class TestDFScan: - ############### Test DFScan ############### - def __init__(self, language, bug_type): - self.language = language - self.bug_type = bug_type - self.test_cases = set() - self.test_case_num = 0 - - def run(self): - start_time = time.time() - self.analyze() - self.validate() - print("====================Test Result====================") - print(f"Language: {self.language}") - print(f"Bug Type: {self.bug_type}") - print(f"Execution Time: {time.time() - start_time:.2f} seconds") - print( - f"Detected Test Cases: {self.test_case_num-len(self.test_cases)} / {self.test_case_num}" - ) - print("Missing Test Cases: ", self.test_cases) - - def analyze(self): - seed_path = f"{BASE_DIR}/result/src_extract/{self.bug_type}/{self.language}_toy/seed_result.json" - project_path = f"{BASE_DIR}/benchmark/{self.language}/toy" - - if self.language == "Cpp": - for file in Path(f"{project_path}/{self.bug_type}").rglob("*.cpp"): - self.test_cases.add(str(file)) - self.test_case_num = len(self.test_cases) - if self.bug_type == "NPD": - extractor = Cpp_NPD_Extractor(project_path, self.language, seed_path) - elif self.bug_type == "MLK": - extractor = Cpp_MLK_Extractor(project_path, self.language, seed_path) - elif self.bug_type == "UAF": - extractor = Cpp_UAF_Extractor(project_path, self.language, seed_path) - else: - raise ValueError("Invalid bug type") - else: - raise ValueError("Invalid language") - extractor.run() - - batch_scan = RepoAudit( - seed_spec_file=seed_path, - project_path=project_path, - language=self.language, - inference_model_name="claude-3.7", - temperature=0.0, - scanners=["DFscan"], - bug_type=self.bug_type, - boundary=3, - max_neural_workers=1, - ) - - batch_scan.start_batch_scan() - - def validate(self): - result_dir = ( - f"{BASE_DIR}/result/DFscan-claude-3.7/{self.bug_type}/{self.language}_toy/" - ) - if Path(result_dir).exists(): - timestamps = [d.name for d in Path(result_dir).iterdir() if d.is_dir()] - if not timestamps: - print("No results found.") - return - timestamps.sort(reverse=True) - timestamp = timestamps[0] - - result_path = f"{BASE_DIR}/result/DFscan-claude-3.7/{self.bug_type}/{self.language}_toy/{timestamp}/bug_info.json" - if not Path(result_path).exists(): - print("Result file does not exist.") - return - with open(result_path, "r") as f: - results = json.load(f) - - for _, item in results.items(): - paths = item["Path"] - vali_llm = item["Vali_LLM"] - if vali_llm == "True": - file_name = paths[0]["file_name"] - for path in paths: - if path["file_name"] != file_name: - print( - f"Cross-file Bug Trace: {file_name} -> {path['file_name']}" - ) - break - self.test_cases.discard(file_name) - - -if __name__ == "__main__": - test_Cpp_NPD = TestDFScan("Cpp", "NPD") - test_Cpp_NPD.run() diff --git a/src/test/test_state.py b/src/test/test_state.py deleted file mode 100644 index 3e46b65..0000000 --- a/src/test/test_state.py +++ /dev/null @@ -1,90 +0,0 @@ -import unittest -import sys -from os import path - -sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) - -from memory.semantic.bugscan_state import BugScanState -from memory.syntactic.value import * -from memory.syntactic.function import * - - -class TestState(unittest.TestCase): - def setUp(self): - # Create LocalValue and Function instances - for i in range(1, 5): - function = Function( - function_id=i, - function_name=f"test_function{i}", - function_code="", - start_line_number=1, - end_line_number=10, - function_node=None, - file_path="test_file.py", - ) - local_value = Value( - name=f"test_var{i}", - line_number=i, - v_label=ValueLabel.SRC, - file="test_file.py", - ) - setattr(self, f"state{i}", BugScanState(local_value, function)) - - # Set up callers and callees - self.state2.callers.append(self.state1) - self.state1.callees.append(self.state2) - - self.state2.callers.append(self.state4) - self.state4.callees.append(self.state2) - - self.state3.callers.append(self.state2) - self.state2.callees.append(self.state3) - - # Set up slices - self.state1.slice = "slice1" - self.state2.slice = "slice2" - self.state3.slice = "slice3" - self.state4.slice = "slice4" - - def test_find_root(self): - # Test if state3 can find the root state1 - roots = self.state3.find_root() - self.assertEqual(len(roots), 2) - self.assertEqual(roots, [self.state1, self.state4]) - - # Test if state1 is its own root - roots = self.state1.find_root() - self.assertEqual(len(roots), 1) - self.assertEqual(roots[0], self.state1) - - def test_get_slice_tree(self): - # Test if root functions can get the entire slice tree - roots = self.state3.find_root() - for root in roots: - slice_list = root.get_slice_tree() - slices = set(slice_list) - self.assertEqual(len(slices), 3) - if root == self.state1: - self.assertEqual(slices, {"slice1", "slice2", "slice3"}) - elif root == self.state4: - self.assertEqual(slices, {"slice2", "slice3", "slice4"}) - - # Test if state3 can get the entire slice tree - slices = self.state3.get_slice_tree() - self.assertEqual(len(slices), 1) - self.assertEqual(slices, ["slice3"]) - - def test_get_call_tree(self): - expected_tree_state1 = ( - "test_function1\n" " └── test_function2\n" " └── test_function3\n" - ) - self.assertEqual(self.state1.get_call_tree(), expected_tree_state1) - - expected_tree_state4 = ( - "test_function4\n" " └── test_function2\n" " └── test_function3\n" - ) - self.assertEqual(self.state4.get_call_tree(), expected_tree_state4) - - -if __name__ == "__main__": - unittest.main() diff --git a/src/tstool/analyzer/Cpp_TS_analyzer.py b/src/tstool/analyzer/Cpp_TS_analyzer.py index fb0e5b3..244e82f 100644 --- a/src/tstool/analyzer/Cpp_TS_analyzer.py +++ b/src/tstool/analyzer/Cpp_TS_analyzer.py @@ -180,7 +180,7 @@ def get_arguments_at_callsite( :param call_site_node: the node of the call site :return: the arguments """ - arguments = set([]) + arguments: Set[Value] = set([]) file_name = current_function.file_path source_code = self.code_in_files[file_name] for sub_node in call_site_node.children: diff --git a/src/tstool/analyzer/Go_TS_analyzer.py b/src/tstool/analyzer/Go_TS_analyzer.py index 3a7bc2f..13ceb2a 100644 --- a/src/tstool/analyzer/Go_TS_analyzer.py +++ b/src/tstool/analyzer/Go_TS_analyzer.py @@ -116,7 +116,7 @@ def get_arguments_at_callsite( :param call_site_node: the node of the call site :return: the arguments """ - arguments = set([]) + arguments: Set[Value] = set([]) file_name = current_function.file_path source_code = self.code_in_files[file_name] for sub_node in call_site_node.children: diff --git a/src/tstool/analyzer/Java_TS_analyzer.py b/src/tstool/analyzer/Java_TS_analyzer.py index 91b16a5..4464bea 100644 --- a/src/tstool/analyzer/Java_TS_analyzer.py +++ b/src/tstool/analyzer/Java_TS_analyzer.py @@ -105,7 +105,7 @@ def get_arguments_at_callsite( :param call_site_node: the node of the call site :return: the arguments """ - arguments = set([]) + arguments: Set[Value] = set([]) file_name = current_function.file_path source_code = self.code_in_files[file_name] for sub_node in call_site_node.children: diff --git a/src/tstool/analyzer/Python_TS_analyzer.py b/src/tstool/analyzer/Python_TS_analyzer.py index 09adfe2..c5d0147 100644 --- a/src/tstool/analyzer/Python_TS_analyzer.py +++ b/src/tstool/analyzer/Python_TS_analyzer.py @@ -119,7 +119,7 @@ def get_arguments_at_callsite( :param call_site_node: the node of the call site :return: the arguments """ - arguments = set([]) + arguments: Set[Value] = set([]) file_name = current_function.file_path source_code = self.code_in_files[file_name] for sub_node in call_site_node.children: diff --git a/src/tstool/analyzer/TS_analyzer.py b/src/tstool/analyzer/TS_analyzer.py index 016566d..31118ab 100644 --- a/src/tstool/analyzer/TS_analyzer.py +++ b/src/tstool/analyzer/TS_analyzer.py @@ -3,11 +3,10 @@ from pathlib import Path import copy import concurrent.futures -from typing import List, Tuple, Dict, Set +from typing import List, Optional, Tuple, Dict, Set from abc import ABC, abstractmethod -import tree_sitter -from tree_sitter import Language +from tree_sitter import Language, Node, Tree, Parser from tqdm import tqdm import networkx as nx @@ -65,6 +64,8 @@ def add_and_check_context(self, label: ContextLabel) -> bool: # Get the top element from the context stack top_label = self.get_top_unmatched_context_label() + # TODO (ZZ): this assertion is added to satisfy the mypy check. Update the code to remove this assertion + assert top_label is not None, "Top label should not be None" # Determine which labels to match based on analysis direction first_label = ( @@ -95,7 +96,7 @@ def add_and_check_context(self, label: ContextLabel) -> bool: self.context.append(label) return is_CFL_reachable - def get_top_unmatched_context_label(self) -> ContextLabel: + def get_top_unmatched_context_label(self) -> Optional[ContextLabel]: """ Get the top unmatched context label. :return: The top unmatched context label. @@ -112,7 +113,9 @@ def __str__(self) -> str: [str(label) for label in self.context] ) - def __eq__(self, other: "CallContext") -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, CallContext): + return NotImplemented return self.__str__() == other.__str__() def __hash__(self) -> int: @@ -143,7 +146,7 @@ def __init__( self.max_symbolic_workers_num = max_symbolic_workers_num # Initialize tree-sitter parser - self.parser = tree_sitter.Parser() + self.parser = Parser() self.language_name = language_name if language_name == "C": self.language = Language(str(language_path), "c") @@ -160,23 +163,23 @@ def __init__( self.parser.set_language(self.language) # Results of parsing - self.functionRawDataDic = {} - self.functionNameToId = {} - self.functionToFile = {} - self.fileContentDic = {} - self.glb_var_map = {} # global var info + self.functionRawDataDic: Dict[int, Tuple[str, int, int, Node]] = {} + self.functionNameToId: Dict[str, Set[int]] = {} + self.functionToFile: Dict[int, str] = {} + self.fileContentDic: Dict[str, str] = {} + self.glb_var_map: Dict[str, str] = {} # global var info - self.function_env: dict[int, Function] = {} - self.api_env: dict[int, API] = {} + self.function_env: Dict[int, Function] = {} + self.api_env: Dict[int, API] = {} # Results of call graph analysis ## Caller-callee relationship between user-defined functions - self.function_caller_callee_map = {} - self.function_callee_caller_map = {} + self.function_caller_callee_map: Dict[int, Set[int]] = {} + self.function_callee_caller_map: Dict[int, Set[int]] = {} ## Caller-callee relationship between user-defined functions and library APIs - self.function_caller_api_callee_map = {} - self.api_callee_function_caller_map = {} + self.function_caller_api_callee_map: Dict[int, Set[int]] = {} + self.api_callee_function_caller_map: Dict[int, Set[int]] = {} # Analyze stage I: Project AST parsing self.parse_project() @@ -201,8 +204,8 @@ def _parse_single_file(self, file_path: str, source_code: str) -> Tuple[str, str return file_path, source_code def _analyze_single_function( - self, function_id: int, raw_data: Tuple - ) -> Tuple[int, "Function"]: + self, function_id: int, raw_data: Tuple[str, int, int, Node] + ) -> Tuple[int, Function]: """ Helper function to analyze a single function. """ @@ -229,17 +232,17 @@ def parse_project(self) -> None: with concurrent.futures.ThreadPoolExecutor( max_workers=self.max_symbolic_workers_num ) as executor: - futures = {} + parse_futures: Dict[concurrent.futures.Future[Tuple[str, str]], str] = {} pbar = tqdm(total=len(self.code_in_files), desc="Parsing files") for file_path, source_code in self.code_in_files.items(): # Submit a task for each file. - future = executor.submit( + parse_future = executor.submit( self._parse_single_file, file_path, source_code ) - futures[future] = file_path + parse_futures[parse_future] = file_path # Collect results. - for future in concurrent.futures.as_completed(futures): - file_path, source = future.result() + for parse_future in concurrent.futures.as_completed(parse_futures): + file_path, source = parse_future.result() self.fileContentDic[file_path] = source pbar.update(1) pbar.close() @@ -247,16 +250,18 @@ def parse_project(self) -> None: with concurrent.futures.ThreadPoolExecutor( max_workers=self.max_symbolic_workers_num ) as executor: - futures = {} + analyze_futures: Dict[ + concurrent.futures.Future[Tuple[int, Function]], int + ] = {} pbar = tqdm(total=len(self.functionRawDataDic), desc="Analyzing functions") for function_id, raw_data in self.functionRawDataDic.items(): - future = executor.submit( + analyze_future = executor.submit( self._analyze_single_function, function_id, raw_data ) - futures[future] = function_id + analyze_futures[analyze_future] = function_id - for future in concurrent.futures.as_completed(futures): - func_id, current_function = future.result() + for analyze_future in concurrent.futures.as_completed(analyze_futures): + func_id, current_function = analyze_future.result() self.function_env[func_id] = current_function pbar.update(1) pbar.close() @@ -291,7 +296,7 @@ def analyze_call_graph(self) -> None: ########################################### @abstractmethod def extract_function_info( - self, file_path: str, source_code: str, tree: tree_sitter.Tree + self, file_path: str, source_code: str, tree: Tree ) -> None: """ Parse function information from a source file. @@ -326,9 +331,7 @@ def extract_meta_data_in_single_function( return current_function @abstractmethod - def extract_global_info( - self, file_path: str, source_code: str, tree: tree_sitter.Tree - ) -> None: + def extract_global_info(self, file_path: str, source_code: str, tree: Tree) -> None: """ Parse macro or global variable information from a source file. :param file_path: Path of the source file. @@ -509,14 +512,12 @@ def get_all_callee_apis( """ callee_list = [] for callee_api_id in self.function_caller_api_callee_map[function.function_id]: - if self.api_env[callee_list] == API(-1, callee_name, para_num): - callee_list.append(self.api_env[callee_list]) + if self.api_env[callee_api_id] == API(-1, callee_name, para_num): + callee_list.append(self.api_env[callee_api_id]) return callee_list @abstractmethod - def get_callee_name_at_call_site( - self, node: tree_sitter.Node, source_code: str - ) -> str: + def get_callee_name_at_call_site(self, node: Node, source_code: str) -> str: """ Get the callee name at the call site. :param node: The node of the call site. @@ -526,7 +527,7 @@ def get_callee_name_at_call_site( pass def get_callee_function_ids_at_callsite( - self, current_function: Function, call_site_node: tree_sitter.Node + self, current_function: Function, call_site_node: Node ) -> List[int]: """ Determine the callee function(s) from a call site. @@ -548,12 +549,15 @@ def get_callee_function_ids_at_callsite( for callee_id in temp_callee_ids: callee = self.function_env[callee_id] paras = callee.paras + # TODO (ZZ): this assertion is to make mypy happy + assert paras is not None, "analysis is not done yet" + if len(paras) == len(arguments): callee_ids.append(callee_id) return callee_ids def get_callee_api_ids_at_callsite( - self, current_function: Function, call_site_node: tree_sitter.Node + self, current_function: Function, call_site_node: Node ) -> List[int]: """ Determine the callee api(s) from a call site. @@ -577,7 +581,7 @@ def get_callee_api_ids_at_callsite( @abstractmethod def get_callsites_by_callee_name( self, current_function: Function, callee_name: str - ) -> List[tree_sitter.Node]: + ) -> List[Node]: """ Find the call site nodes by callee name. :param current_function: The function to be analyzed. @@ -589,7 +593,7 @@ def get_callsites_by_callee_name( # Helper functions for arguments @abstractmethod def get_arguments_at_callsite( - self, current_function: Function, call_site_node: tree_sitter.Node + self, current_function: Function, call_site_node: Node ) -> Set[Value]: """ Get arguments from a call site in a function. @@ -613,7 +617,7 @@ def get_parameters_in_single_function( # Helper functions for output values def get_output_value_at_callsite( - self, current_function: Function, call_site_node: tree_sitter.Node + self, current_function: Function, call_site_node: Node ) -> Value: """ Get the output value from a call site. @@ -667,7 +671,7 @@ def get_loop_statements( pass def check_control_order( - self, function: Function, src_line_number: str, sink_line_number: str + self, function: Function, src_line_number: int, sink_line_number: int ) -> bool: """ Check if the source line could execute before the sink line. @@ -720,7 +724,7 @@ def check_control_order( return True def check_control_reachability( - self, function: Function, src_line_number: str, sink_line_number: str + self, function: Function, src_line_number: int, sink_line_number: int ) -> bool: """ Check if control can reach from the source line to the sink line, considering return statements. @@ -732,9 +736,7 @@ def check_control_reachability( # Other helper functions - def get_node_by_line_number( - self, line_number: int - ) -> List[Tuple[str, tree_sitter.Node]]: + def get_node_by_line_number(self, line_number: int) -> List[Tuple[str, Node]]: """ Find nodes that contain a specific line number. """ @@ -759,12 +761,12 @@ def get_node_by_line_number( code_node_list.append((function.function_code, node)) return code_node_list - def get_function_from_localvalue(self, value: Value) -> Function: + def get_function_from_localvalue(self, value: Value) -> Optional[Function]: """ Retrieve the function corresponding to a local value. """ file_name = value.file - for function_id, function in self.function_env.items(): + for _, function in self.function_env.items(): if function.file_path == file_name: if ( function.start_line_number @@ -789,7 +791,7 @@ def get_content_by_line_number(self, line_number: int, file_name: str) -> str: # Utility functions for AST node type maching -def find_all_nodes(root_node: tree_sitter.Node) -> List[tree_sitter.Node]: +def find_all_nodes(root_node: Node) -> List[Node]: """ Recursively find all nodes in the tree starting at root_node. """ @@ -801,9 +803,7 @@ def find_all_nodes(root_node: tree_sitter.Node) -> List[tree_sitter.Node]: return nodes -def find_nodes_by_type( - root_node: tree_sitter.Node, node_type: str, k=0 -) -> List[tree_sitter.Node]: +def find_nodes_by_type(root_node: Node, node_type: str, k=0) -> List[Node]: """ Recursively find all nodes of a given type. """ diff --git a/src/tstool/dfbscan_extractor/Cpp/Cpp_MLK_extractor.py b/src/tstool/dfbscan_extractor/Cpp/Cpp_MLK_extractor.py index 9760693..b3c8367 100644 --- a/src/tstool/dfbscan_extractor/Cpp/Cpp_MLK_extractor.py +++ b/src/tstool/dfbscan_extractor/Cpp/Cpp_MLK_extractor.py @@ -1,8 +1,6 @@ from tstool.analyzer.TS_analyzer import * from tstool.analyzer.Cpp_TS_analyzer import * from ..dfbscan_extractor import * -import tree_sitter -import argparse class Cpp_MLK_Extractor(DFBScanExtractor): @@ -36,7 +34,7 @@ def extract_sources(self, function: Function) -> List[Value]: "vasprintf", "getline", } - spec_apis = {} # specific user-defined APIs that allocate memory + # spec_apis = {} # specific user-defined APIs that allocate memory sources = [] for node in nodes: is_seed_node = False @@ -46,7 +44,7 @@ def extract_sources(self, function: Function) -> List[Value]: for child in node.children: if child.type == "identifier": name = source_code[child.start_byte : child.end_byte] - if name in mem_allocations or name in spec_apis: + if name in mem_allocations: # or name in spec_apis: is_seed_node = True if is_seed_node: @@ -71,7 +69,7 @@ def extract_sinks(self, function: Function) -> List[Value]: """ nodes = find_nodes_by_type(root_node, "call_expression") mem_deallocations = {"free"} - spec_apis = {} # specific user-defined APIs that deallocate memory + # spec_apis = {} # specific user-defined APIs that deallocate memory sinks = [] for node in nodes: is_sink_node = False @@ -80,7 +78,7 @@ def extract_sinks(self, function: Function) -> List[Value]: for child in node.children: if child.type == "identifier": name = source_code[child.start_byte : child.end_byte] - if name in mem_deallocations or name in spec_apis: + if name in mem_deallocations: # or name in spec_apis: is_sink_node = True if is_sink_node: diff --git a/src/tstool/dfbscan_extractor/Cpp/Cpp_NPD_extractor.py b/src/tstool/dfbscan_extractor/Cpp/Cpp_NPD_extractor.py index eb7f261..b7fe94f 100644 --- a/src/tstool/dfbscan_extractor/Cpp/Cpp_NPD_extractor.py +++ b/src/tstool/dfbscan_extractor/Cpp/Cpp_NPD_extractor.py @@ -22,7 +22,7 @@ def extract_sources(self, function: Function) -> List[Value]: nodes.extend(find_nodes_by_type(root_node, "return_statement")) nodes.extend(find_nodes_by_type(root_node, "call_expression")) - spec_apis = {"malloc"} # specific user-defined APIs that can return NULL + # spec_apis = {"malloc"} # specific user-defined APIs that can return NULL sources = [] for node in nodes: is_seed_node = False diff --git a/src/tstool/dfbscan_extractor/Cpp/Cpp_UAF_extractor.py b/src/tstool/dfbscan_extractor/Cpp/Cpp_UAF_extractor.py index 1fdc190..a18bf18 100644 --- a/src/tstool/dfbscan_extractor/Cpp/Cpp_UAF_extractor.py +++ b/src/tstool/dfbscan_extractor/Cpp/Cpp_UAF_extractor.py @@ -6,7 +6,7 @@ class Cpp_UAF_Extractor(DFBScanExtractor): - def extract_sources(self, function: Function) -> List[Tuple[Value, bool]]: + def extract_sources(self, function: Function) -> List[Value]: """ Extract the sources that can cause the use-after-free bugs from C/C++ programs. :param: function: Function object. @@ -24,7 +24,7 @@ def extract_sources(self, function: Function) -> List[Tuple[Value, bool]]: nodes.extend(find_nodes_by_type(root_node, "delete_expression")) free_functions = {"free", "ngx_destroy_black_list_link"} - spec_apis = {} # specific user-defined APIs + # spec_apis = {} # specific user-defined APIs sources = [] for node in nodes: is_seed_node = False diff --git a/src/tstool/dfbscan_extractor/dfbscan_extractor.py b/src/tstool/dfbscan_extractor/dfbscan_extractor.py index c4c82d0..d300cdb 100644 --- a/src/tstool/dfbscan_extractor/dfbscan_extractor.py +++ b/src/tstool/dfbscan_extractor/dfbscan_extractor.py @@ -1,12 +1,8 @@ import sys -import os from os import path -from pathlib import Path from tstool.analyzer.TS_analyzer import * from memory.syntactic.function import * from memory.syntactic.value import * -import tree_sitter -import json from tqdm import tqdm from abc import ABC, abstractmethod @@ -20,11 +16,11 @@ class DFBScanExtractor(ABC): def __init__(self, ts_analyzer: TSAnalyzer): self.ts_analyzer = ts_analyzer - self.sources = [] - self.sinks = [] + self.sources: List[Value] = [] + self.sinks: List[Value] = [] return - def extract_all(self): + def extract_all(self) -> Tuple[List[Value], List[Value]]: """ Start the source/sink extraction process. """