diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 6e88dc93f1..4751a43e25 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -37,6 +37,7 @@ jobs: - test_whitelist.py - test_arp.py - test_arp_poisoner.py + - test_arp_filter.py - test_blocking.py - test_unblocker.py - test_flow_handler.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3567eb858d..31cc6c5f75 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,3 +43,11 @@ repos: - id: yamllint args: ["-d", "{rules: {line-length: {max: 100}}}"] files: "slips.yaml" + +- repo: local + hooks: + - id: vulture + name: vulture dead code check + entry: bash -c 'files=$(git diff --cached --name-only --diff-filter=ACM | grep -E "\.py$" | grep -vE "^(tests/|migrations/)"); [ -n "$files" ] && vulture --exclude "tests/*,venv/*" $files || true' + language: system + types: [python] diff --git a/conftest.py b/conftest.py index 3ce9d71fcf..b0178d1527 100644 --- a/conftest.py +++ b/conftest.py @@ -67,14 +67,6 @@ def profiler_queue(): profiler_queue.put = do_nothing return profiler_queue - -@pytest.fixture -def database(): - db = DBManager(Output(), "output/", 6379) - db.print = do_nothing - return db - - @pytest.fixture def flow(): """returns a dummy flow for testing""" diff --git a/docker/Dockerfile b/docker/Dockerfile index 4482e786e9..a6731bce61 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -59,6 +59,7 @@ RUN apt update && apt install -y --no-install-recommends \ nano \ tree \ tmux \ + arp-scan \ && echo 'deb http://download.opensuse.org/repositories/security:/zeek/xUbuntu_22.04/ /' | tee /etc/apt/sources.list.d/security:zeek.list \ && curl -fsSL https://download.opensuse.org/repositories/security:zeek/xUbuntu_22.04/Release.key | gpg --dearmor | tee /etc/apt/trusted.gpg.d/security_zeek.gpg > /dev/null \ && apt update \ diff --git a/docs/create_new_module.md b/docs/create_new_module.md index 3646c96cbb..2df387b024 100644 --- a/docs/create_new_module.md +++ b/docs/create_new_module.md @@ -365,7 +365,8 @@ import json from slips_files.common.flow_classifier import FlowClassifier from slips_files.core.structures.evidence import - ( + +( Evidence, ProfileID, TimeWindow, @@ -378,7 +379,7 @@ from slips_files.core.structures.evidence import ) from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule class LocalConnectionDetector( diff --git a/install/requirements.txt b/install/requirements.txt index 47e1db470d..d920233c6d 100644 --- a/install/requirements.txt +++ b/install/requirements.txt @@ -41,4 +41,5 @@ netifaces==0.11.0 scapy==2.6.1 pyyaml pytest-asyncio +vulture git+https://github.com/SECEF/python-idmefv2.git diff --git a/managers/process_manager.py b/managers/process_manager.py index 5bd3067ccf..be02f42c0d 100644 --- a/managers/process_manager.py +++ b/managers/process_manager.py @@ -33,7 +33,7 @@ import modules from modules.update_manager.update_manager import UpdateManager from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import ( +from slips_files.common.abstracts.imodule import ( IModule, ) @@ -106,6 +106,7 @@ def start_profiler_process(self): self.main.redis_port, self.termination_event, self.main.args, + self.main.conf, is_profiler_done=self.is_profiler_done, profiler_queue=self.profiler_queue, is_profiler_done_event=self.is_profiler_done_event, @@ -127,6 +128,7 @@ def start_evidence_process(self): self.main.redis_port, self.evidence_handler_termination_event, self.main.args, + self.main.conf, ) evidence_process.start() self.main.print( @@ -145,6 +147,7 @@ def start_input_process(self): self.main.redis_port, self.termination_event, self.main.args, + self.main.conf, is_input_done=self.is_input_done, profiler_queue=self.profiler_queue, input_type=self.main.input_type, @@ -389,6 +392,7 @@ def load_modules(self): self.main.redis_port, self.termination_event, self.main.args, + self.main.conf, ) module.start() self.main.db.store_pid(module_name, int(module.pid)) @@ -444,6 +448,7 @@ def start_update_manager(self, local_files=False, ti_feeds=False): self.main.redis_port, multiprocessing.Event(), self.main.args, + self.main.conf, ) if local_files: @@ -821,7 +826,7 @@ def shutdown_gracefully(self): self.main.profilers_manager.cpu_profiler_release() self.main.profilers_manager.memory_profiler_release() - self.main.db.close_redis_and_sqlite() + self.main.db.close_all_dbs() if graceful_shutdown: print( "[Process Manager] Slips shutdown gracefully\n", diff --git a/modules/arp/arp.py b/modules/arp/arp.py index 8b4bdde752..f6ae2567aa 100644 --- a/modules/arp/arp.py +++ b/modules/arp/arp.py @@ -8,10 +8,11 @@ from multiprocessing import Queue from typing import List +from modules.arp.filter import ARPEvidenceFilter from slips_files.common.flow_classifier import FlowClassifier from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule from slips_files.core.structures.evidence import ( Evidence, ProfileID, @@ -26,7 +27,7 @@ class ARP(IModule): - # Name: short name of the module. Do not use spaces + name = "ARP" description = "Detect ARP attacks" authors = ["Alya Gomaa"] @@ -64,6 +65,7 @@ def init(self): # wait 10s for mmore arp scan evidence to come self.time_to_wait = 10 self.is_zeek_running: bool = self.is_running_zeek() + self.evidence_filter = ARPEvidenceFilter(self.conf, self.args, self.db) def read_configuration(self): conf = ConfigParser() @@ -247,7 +249,7 @@ def set_evidence_arp_scan(self, ts, profileid, twid, uids: List[str]): timestamp=ts, ) - self.db.set_evidence(evidence) + self.set_evidence(evidence) # after we set evidence, clear the dict so we can detect if it # does another scan try: @@ -315,7 +317,7 @@ def check_dstip_outside_localnet(self, twid, flow): timestamp=flow.starttime, victim=victim, ) - self.db.set_evidence(evidence) + self.set_evidence(evidence) return True return False @@ -357,7 +359,7 @@ def detect_unsolicited_arp(self, twid: str, flow): timestamp=flow.starttime, ) - self.db.set_evidence(evidence) + self.set_evidence(evidence) return True def detect_mitm_arp_attack(self, twid: str, flow): @@ -460,7 +462,7 @@ def detect_mitm_arp_attack(self, twid: str, flow): victim=victim, ) - self.db.set_evidence(evidence) + self.set_evidence(evidence) return True def check_if_gratutitous_arp(self, flow): @@ -514,6 +516,13 @@ def clear_arp_logfile(self): # update ts of the new arp.log self.arp_log_creation_time = time.time() + def set_evidence(self, evidence: Evidence): + """the goal of this function is to discard evidence of other slips + peers doing arp scans because that's slips attacking back attackers""" + if self.evidence_filter.should_discard_evidence(evidence.profile.ip): + return + self.db.set_evidence(evidence) + def pre_main(self): """runs once before the main() is executed in a loop""" utils.drop_root_privs() diff --git a/modules/arp/filter.py b/modules/arp/filter.py new file mode 100644 index 0000000000..5e67e79123 --- /dev/null +++ b/modules/arp/filter.py @@ -0,0 +1,58 @@ +from typing import List + +from slips_files.common.slips_utils import utils +from slips_files.core.database.database_manager import DBManager + + +class ARPEvidenceFilter: + """ + A class to filter ARP evidence coming from a peer slips. + Slips uses arp poisoning, arp spoofing, and arp scans to discover + attackers and isolate them from the network, we don't want this + instance of Slips to block other Slips instances, so we discard + evidence about other slips attacking. + """ + + def __init__(self, conf, slips_args, db: DBManager): + self.db = db + self.conf = conf + self.args = slips_args + # p2p needs to be enabled for slips to be able to recognize slips peers + self.p2p_enabled = False + if self.conf.use_local_p2p(): + self.p2p_enabled = True + self.our_ips: List[str] = utils.get_own_ips(ret="List") + + def should_discard_evidence(self, ip: str) -> bool: + return self.is_slips_peer(ip) or self.is_self_defense(ip) + + def is_self_defense(self, ip: str): + """ + slips uses arp poison to defend itself and th enetwork, + check arp_poison.py for more details. + goal of this function is to discard evidence about slips doing arp + attacks when it's just attacking attackers + """ + loaded_modules = self.db.get_pids().keys() + return ( + ip in self.our_ips + and self.args.blocking + and "ARP Poisoner" in loaded_modules + ) + + def is_slips_peer(self, ip: str) -> bool: + """ + Check if the given IP address is a trusted Slips peer. + Trust here is defined from the p2p network (trust model). + Only works if the local p2p is enabled. + + :param ip: The IP address to check. + """ + if not self.p2p_enabled or not utils.is_private_ip(ip): + return False + + trust = self.db.get_peer_trust(ip) + if not trust: + return False + + return trust >= 0.3 diff --git a/modules/arp_poisoner/arp_poisoner.py b/modules/arp_poisoner/arp_poisoner.py index bc14402ba9..dc72bfb8d5 100644 --- a/modules/arp_poisoner/arp_poisoner.py +++ b/modules/arp_poisoner/arp_poisoner.py @@ -12,7 +12,7 @@ from scapy.all import ARP, Ether from scapy.sendrecv import sendp, srp -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule from modules.arp_poisoner.unblocker import ARPUnblocker from slips_files.common.slips_utils import utils @@ -195,6 +195,7 @@ def _arp_poison(self, target_ip: str, first_time=False): repoisoning every x seconds. """ fake_mac = "aa:aa:aa:aa:aa:aa" + # it makes sense here to get the mac using cache, because if we # reached this function, means there's an alert, means slips seen # traffic from that target_ip and has itsmac in the arp cache. diff --git a/modules/arp_poisoner/unblocker.py b/modules/arp_poisoner/unblocker.py index ba02ed32b1..ee3defc0ed 100644 --- a/modules/arp_poisoner/unblocker.py +++ b/modules/arp_poisoner/unblocker.py @@ -1,7 +1,7 @@ from threading import Lock import time from typing import Callable, Optional -from slips_files.common.abstracts.unblocker import IUnblocker +from slips_files.common.abstracts.iunblocker import IUnblocker from slips_files.common.printer import Printer from slips_files.common.slips_utils import utils from slips_files.core.structures.evidence import TimeWindow diff --git a/modules/blocking/blocking.py b/modules/blocking/blocking.py index 376e6620d9..35aaa3b70b 100644 --- a/modules/blocking/blocking.py +++ b/modules/blocking/blocking.py @@ -10,7 +10,7 @@ import time from threading import Lock -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule from slips_files.common.slips_utils import utils from .exec_iptables_cmd import exec_iptables_command from modules.blocking.unblocker import Unblocker diff --git a/modules/blocking/unblocker.py b/modules/blocking/unblocker.py index e39d195d91..68f6bcfcf1 100644 --- a/modules/blocking/unblocker.py +++ b/modules/blocking/unblocker.py @@ -2,7 +2,7 @@ import time import threading from typing import Dict, Callable -from slips_files.common.abstracts.unblocker import IUnblocker +from slips_files.common.abstracts.iunblocker import IUnblocker from slips_files.common.printer import Printer from slips_files.common.slips_utils import utils from slips_files.core.structures.evidence import TimeWindow diff --git a/modules/cesnet/cesnet.py b/modules/cesnet/cesnet.py index 399763fd52..55db0796b8 100644 --- a/modules/cesnet/cesnet.py +++ b/modules/cesnet/cesnet.py @@ -9,7 +9,7 @@ import validators from slips_files.common.parsers.config_parser import ConfigParser -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule from slips_files.core.structures.evidence import ( ThreatLevel, Evidence, diff --git a/modules/cyst/cyst.py b/modules/cyst/cyst.py index da10bb3017..d43631f155 100644 --- a/modules/cyst/cyst.py +++ b/modules/cyst/cyst.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule import socket import json import os diff --git a/modules/exporting_alerts/exporting_alerts.py b/modules/exporting_alerts/exporting_alerts.py index 94ebff6260..3390b4408b 100644 --- a/modules/exporting_alerts/exporting_alerts.py +++ b/modules/exporting_alerts/exporting_alerts.py @@ -5,7 +5,7 @@ from modules.exporting_alerts.slack_exporter import SlackExporter from modules.exporting_alerts.stix_exporter import StixExporter from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule class ExportingAlerts(IModule): diff --git a/modules/exporting_alerts/slack_exporter.py b/modules/exporting_alerts/slack_exporter.py index 595362c74f..4ac48d6dc7 100644 --- a/modules/exporting_alerts/slack_exporter.py +++ b/modules/exporting_alerts/slack_exporter.py @@ -3,7 +3,7 @@ from slack import WebClient from slack.errors import SlackApiError from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.exporter import IExporter +from slips_files.common.abstracts.iexporter import IExporter from slips_files.common.parsers.config_parser import ConfigParser diff --git a/modules/exporting_alerts/stix_exporter.py b/modules/exporting_alerts/stix_exporter.py index 457ac1abd7..1833258983 100644 --- a/modules/exporting_alerts/stix_exporter.py +++ b/modules/exporting_alerts/stix_exporter.py @@ -6,7 +6,7 @@ import threading import os -from slips_files.common.abstracts.exporter import IExporter +from slips_files.common.abstracts.iexporter import IExporter from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils diff --git a/modules/fidesModule/fidesModule.py b/modules/fidesModule/fidesModule.py index a4817f6f97..54da822267 100644 --- a/modules/fidesModule/fidesModule.py +++ b/modules/fidesModule/fidesModule.py @@ -1,10 +1,9 @@ import os import json -from dataclasses import asdict from pathlib import Path from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule from slips_files.common.parsers.config_parser import ( ConfigParser, ) @@ -12,11 +11,10 @@ dict_to_alert, Alert, ) -from .messaging.model import NetworkMessage from ..fidesModule.messaging.message_handler import MessageHandler from ..fidesModule.messaging.network_bridge import NetworkBridge from ..fidesModule.model.configuration import load_configuration -from ..fidesModule.model.threat_intelligence import SlipsThreatIntelligence, ThreatIntelligence +from ..fidesModule.model.threat_intelligence import SlipsThreatIntelligence from ..fidesModule.protocols.alert import AlertProtocol from ..fidesModule.protocols.initial_trusl import InitialTrustProtocol from ..fidesModule.protocols.opinion import OpinionAggregator @@ -26,15 +24,13 @@ ThreatIntelligenceProtocol, ) from ..fidesModule.utils.logger import LoggerPrintCallbacks -from ..fidesModule.messaging.redis_simplex_queue import RedisSimplexQueue, RedisDuplexQueue +from ..fidesModule.messaging.redis_simplex_queue import RedisSimplexQueue from ..fidesModule.persistence.threat_intelligence_db import ( SlipsThreatIntelligenceDatabase, ) from ..fidesModule.persistence.trust_db import SlipsTrustDatabase from ..fidesModule.persistence.sqlite_db import SQLiteDB -from ..fidesModule.model.alert import Alert as FidesAlert - class FidesModule(IModule): """ diff --git a/modules/flowalerts/conn.py b/modules/flowalerts/conn.py index bf5a428a12..1519601c06 100644 --- a/modules/flowalerts/conn.py +++ b/modules/flowalerts/conn.py @@ -9,7 +9,7 @@ import validators from modules.flowalerts.dns import DNS -from slips_files.common.abstracts.flowalerts_analyzer import ( +from slips_files.common.abstracts.iflowalerts_analyzer import ( IFlowalertsAnalyzer, ) from slips_files.common.parsers.config_parser import ConfigParser @@ -40,7 +40,7 @@ def init(self): self.dns_analyzer = DNS(self.db, flowalerts=self) self.is_running_non_stop: bool = self.db.is_running_non_stop() self.classifier = FlowClassifier() - self.our_ips: List[str] = utils.get_own_ips(ret=List) + self.our_ips: List[str] = utils.get_own_ips(ret="List") self.input_type: str = self.db.get_input_type() self.multiple_reconnection_attempts_threshold = 5 # we use this to try to detect if there's dns server that has a diff --git a/modules/flowalerts/dns.py b/modules/flowalerts/dns.py index b4cca77289..bc8a42ce37 100644 --- a/modules/flowalerts/dns.py +++ b/modules/flowalerts/dns.py @@ -16,7 +16,7 @@ from multiprocessing import Queue from threading import Thread, Event -from slips_files.common.abstracts.flowalerts_analyzer import ( +from slips_files.common.abstracts.iflowalerts_analyzer import ( IFlowalertsAnalyzer, ) from slips_files.common.flow_classifier import FlowClassifier @@ -40,7 +40,7 @@ def init(self): self.arpa_scan_threshold = 10 self.is_running_non_stop: bool = self.db.is_running_non_stop() self.classifier = FlowClassifier() - self.our_ips: List[str] = utils.get_own_ips(ret=List) + self.our_ips: List[str] = utils.get_own_ips(ret="List") # In mins self.dns_without_conn_interface_wait_time = 30 # to store dns queries that we should check later. the purpose of diff --git a/modules/flowalerts/downloaded_file.py b/modules/flowalerts/downloaded_file.py index 8a14eb4151..066057cb6e 100644 --- a/modules/flowalerts/downloaded_file.py +++ b/modules/flowalerts/downloaded_file.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: GPL-2.0-only import json -from slips_files.common.abstracts.flowalerts_analyzer import ( +from slips_files.common.abstracts.iflowalerts_analyzer import ( IFlowalertsAnalyzer, ) from slips_files.common.flow_classifier import FlowClassifier diff --git a/modules/flowalerts/flowalerts.py b/modules/flowalerts/flowalerts.py index dc132058dd..92368027b5 100644 --- a/modules/flowalerts/flowalerts.py +++ b/modules/flowalerts/flowalerts.py @@ -6,7 +6,7 @@ from typing import List from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.async_module import AsyncModule +from slips_files.common.abstracts.iasync_module import AsyncModule from .conn import Conn from .dns import DNS from .downloaded_file import DownloadedFile diff --git a/modules/flowalerts/notice.py b/modules/flowalerts/notice.py index 959a6ed4d9..ff5259a638 100644 --- a/modules/flowalerts/notice.py +++ b/modules/flowalerts/notice.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: GPL-2.0-only import json -from slips_files.common.abstracts.flowalerts_analyzer import ( +from slips_files.common.abstracts.iflowalerts_analyzer import ( IFlowalertsAnalyzer, ) from slips_files.common.flow_classifier import FlowClassifier diff --git a/modules/flowalerts/smtp.py b/modules/flowalerts/smtp.py index a73ecb9308..0511682044 100644 --- a/modules/flowalerts/smtp.py +++ b/modules/flowalerts/smtp.py @@ -3,7 +3,7 @@ import json -from slips_files.common.abstracts.flowalerts_analyzer import ( +from slips_files.common.abstracts.iflowalerts_analyzer import ( IFlowalertsAnalyzer, ) from slips_files.common.slips_utils import utils diff --git a/modules/flowalerts/software.py b/modules/flowalerts/software.py index 2274a6bd8b..8258e599f3 100644 --- a/modules/flowalerts/software.py +++ b/modules/flowalerts/software.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: GPL-2.0-only import json -from slips_files.common.abstracts.flowalerts_analyzer import ( +from slips_files.common.abstracts.iflowalerts_analyzer import ( IFlowalertsAnalyzer, ) from slips_files.common.flow_classifier import FlowClassifier diff --git a/modules/flowalerts/ssh.py b/modules/flowalerts/ssh.py index f134806e6b..f526a16ace 100644 --- a/modules/flowalerts/ssh.py +++ b/modules/flowalerts/ssh.py @@ -3,7 +3,7 @@ import asyncio import json -from slips_files.common.abstracts.flowalerts_analyzer import ( +from slips_files.common.abstracts.iflowalerts_analyzer import ( IFlowalertsAnalyzer, ) from slips_files.common.flow_classifier import FlowClassifier diff --git a/modules/flowalerts/ssl.py b/modules/flowalerts/ssl.py index e4bb7cec2a..e71850e677 100644 --- a/modules/flowalerts/ssl.py +++ b/modules/flowalerts/ssl.py @@ -9,7 +9,7 @@ import time from multiprocessing import Lock import tldextract -from slips_files.common.abstracts.flowalerts_analyzer import ( +from slips_files.common.abstracts.iflowalerts_analyzer import ( IFlowalertsAnalyzer, ) from slips_files.common.flow_classifier import FlowClassifier diff --git a/modules/flowalerts/tunnel.py b/modules/flowalerts/tunnel.py index 340a32dbb4..346e6f23f4 100644 --- a/modules/flowalerts/tunnel.py +++ b/modules/flowalerts/tunnel.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: GPL-2.0-only import json -from slips_files.common.abstracts.flowalerts_analyzer import ( +from slips_files.common.abstracts.iflowalerts_analyzer import ( IFlowalertsAnalyzer, ) from slips_files.common.slips_utils import utils diff --git a/modules/flowmldetection/flowmldetection.py b/modules/flowmldetection/flowmldetection.py index e44ac83f4d..5ca5a2425b 100644 --- a/modules/flowmldetection/flowmldetection.py +++ b/modules/flowmldetection/flowmldetection.py @@ -13,7 +13,7 @@ from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule from slips_files.core.structures.evidence import ( Evidence, ProfileID, diff --git a/modules/http_analyzer/http_analyzer.py b/modules/http_analyzer/http_analyzer.py index 9772e733ee..e9423b378a 100644 --- a/modules/http_analyzer/http_analyzer.py +++ b/modules/http_analyzer/http_analyzer.py @@ -13,7 +13,7 @@ from slips_files.common.flow_classifier import FlowClassifier from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.async_module import AsyncModule +from slips_files.common.abstracts.iasync_module import AsyncModule ESTAB = "Established" diff --git a/modules/ip_info/ip_info.py b/modules/ip_info/ip_info.py index 3a4a276ccb..9ec5d6ab89 100644 --- a/modules/ip_info/ip_info.py +++ b/modules/ip_info/ip_info.py @@ -26,7 +26,7 @@ from slips_files.common.flow_classifier import FlowClassifier from slips_files.core.helpers.whitelist.whitelist import Whitelist from .asn_info import ASN -from slips_files.common.abstracts.async_module import AsyncModule +from slips_files.common.abstracts.iasync_module import AsyncModule from slips_files.common.slips_utils import utils from slips_files.core.structures.evidence import ( Evidence, diff --git a/modules/irisModule/irisModule.py b/modules/irisModule/irisModule.py index eed75559d5..2e766e4b21 100644 --- a/modules/irisModule/irisModule.py +++ b/modules/irisModule/irisModule.py @@ -4,7 +4,7 @@ from slips_files.common.parsers.config_parser import ConfigParser -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule import json import os import subprocess @@ -61,7 +61,7 @@ def _iris_configurator(self, config_path: str, redis_port: int): "Tl2NlChannel": "iris_internal", } if "Server" in config: - #config["Server"]["Port"] = 9010 + # config["Server"]["Port"] = 9010 config["Server"]["Host"] = self.db.get_host_ip() config["Server"]["DhtServerMode"] = "true" else: diff --git a/modules/leak_detector/leak_detector.py b/modules/leak_detector/leak_detector.py index fa202f70bf..0b72a0443d 100644 --- a/modules/leak_detector/leak_detector.py +++ b/modules/leak_detector/leak_detector.py @@ -15,7 +15,7 @@ from uuid import uuid4 from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule from slips_files.core.structures.evidence import ( Evidence, ProfileID, diff --git a/modules/network_discovery/network_discovery.py b/modules/network_discovery/network_discovery.py index 8afa80d692..da2f1c0c9a 100644 --- a/modules/network_discovery/network_discovery.py +++ b/modules/network_discovery/network_discovery.py @@ -5,7 +5,7 @@ from slips_files.common.flow_classifier import FlowClassifier from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule from modules.network_discovery.horizontal_portscan import HorizontalPortscan from modules.network_discovery.vertical_portscan import VerticalPortscan from slips_files.core.structures.evidence import ( diff --git a/modules/p2ptrust/p2ptrust.py b/modules/p2ptrust/p2ptrust.py index ef357db919..9935cae8b4 100644 --- a/modules/p2ptrust/p2ptrust.py +++ b/modules/p2ptrust/p2ptrust.py @@ -6,14 +6,13 @@ import signal import subprocess import time -from pathlib import Path from typing import Dict, Optional, Tuple import json import socket from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule import modules.p2ptrust.trust.base_model as reputation_model import modules.p2ptrust.trust.trustdb as trustdb import modules.p2ptrust.utils.utils as p2p_utils @@ -83,25 +82,21 @@ class Trust(IModule): pigeon_binary = os.path.join(os.getcwd(), "p2p4slips/p2p4slips") pigeon_key_file = "pigeon.keys" rename_redis_ip_info = False - rename_sql_db_file = False override_p2p = False def init(self, *args, **kwargs): - output_dir = self.db.get_output_dir() - # flag to ensure slips prints multiaddress only once - self.mutliaddress_printed = False - self.pigeon_logfile_raw = os.path.join(output_dir, "p2p.log") - - self.p2p_reports_logfile = os.path.join(output_dir, "p2p_reports.log") - # pigeon generate keys and stores them in the following dir, if this is placed in the dir, - # when restarting slips, it will look for the old keys in the new output dir! so it wont find them and will + self.pigeon_logfile_raw = os.path.join(self.output_dir, "p2p.log") + self.p2p_reports_logfile = os.path.join( + self.output_dir, "p2p_reports.log" + ) + # pigeon generate keys and stores them in the following dir, if this + # is placed in the dir, + # when restarting slips, it will look for the old keys in the new + # output dir! so it wont find them and will # generate new keys, and therefore new peerid! # store the keys in slips main dir so they don't change every run - data_dir = os.path.join(os.getcwd(), "p2ptrust_runtime/") - # data_dir = f'./output/{used_interface}/p2ptrust_runtime/' - # create data folder - Path(data_dir).mkdir(parents=True, exist_ok=True) - self.data_dir = data_dir + self.p2ptrust_runtime_dir = self.db.get_p2ptrust_dir() + self.sql_db_name = self.db.get_p2ptrust_db_path() self.port = self.get_available_port() self.host = self.get_local_IP() @@ -137,8 +132,10 @@ def init(self, *args, **kwargs): self.gopy_channel: self.c3, } - # they have to be defined here because the variable name utils is already taken - # TODO rename one of them + # todo don't duplicate this dict, move it to slips_utils + # all evidence slips detects has threat levels of strings + # each string should have a corresponding int value to be able to calculate + # the accumulated threat level and alert self.threat_levels = { "info": 0, "low": 0.2, @@ -147,13 +144,8 @@ def init(self, *args, **kwargs): "critical": 1, } - self.sql_db_name = f"{self.data_dir}trustdb.db" - if self.rename_sql_db_file: - self.sql_db_name += str(self.pigeon_port) - # todo don't duplicate this dict, move it to slips_utils - # all evidence slips detects has threat levels of strings - # each string should have a corresponding int value to be able to calculate - # the accumulated threat level and alert + # flag to ensure slips prints multiaddress only once + self.mutliaddress_printed = False def read_configuration(self): conf = ConfigParser() @@ -192,12 +184,13 @@ def get_available_port(self): continue def _configure(self): - # TODO: do not drop tables on startup - self.trust_db = trustdb.TrustDB( - self.logger, self.sql_db_name, drop_tables_on_startup=True + self.trust_db = trustdb.TrustDB( # + self.logger, + self.sql_db_name, + drop_tables_on_startup=False, ) self.reputation_model = reputation_model.BaseModel( - self.logger, self.trust_db + self.logger, self.trust_db, self.db ) # print(f"[DEBUGGING] Starting godirector with # pygo_channel: {self.pygo_channel}") @@ -252,7 +245,7 @@ def _configure(self): outfile = open(os.devnull, "+w") self.pigeon = subprocess.Popen( - executable, cwd=self.data_dir, stdout=outfile + executable, cwd=self.p2ptrust_runtime_dir, stdout=outfile ) # print(f"[debugging] runnning pigeon: {executable}") @@ -300,9 +293,12 @@ def extract_threat_level( def should_share(self, evidence: Evidence) -> bool: """ - decides whether or not to report the given evidence to other + decides whether to report the given evidence to other peers """ + if evidence.profile.ip in utils.get_own_ips(): + return False + if evidence.evidence_type == EvidenceType.P2P_REPORT: # we shouldn't re-share evidence reported by other peers return False @@ -407,10 +403,14 @@ def data_request_callback(self, msg: Dict): if msg and not isinstance(msg["data"], int): self.handle_data_request(msg["data"]) except Exception as e: - self.print(f"Exception {e} in data_request_callback", 0, 1) + self.print( + f"Exception: {e} .. msg: {msg} in " f"data_request_callback()", + 0, + 1, + ) def set_evidence_malicious_ip( - self, ip_info: dict, threat_level: str, confidence: float + self, ip_info: dict, threat_level: float, confidence: float ): """ Set an evidence for a malicious IP met in the timewindow @@ -430,13 +430,12 @@ def set_evidence_malicious_ip( :param threat_level: the threat level we learned form the network :param confidence: how confident the network opinion is about this opinion """ - attacker_ip: str = ip_info.get("ip") profileid = ip_info.get("profileid") saddr = profileid.split("_")[-1] - threat_level = utils.threat_level_to_string(threat_level) - threat_level = ThreatLevel[threat_level.upper()] + threat_level: str = utils.threat_level_to_string(threat_level) + threat_level: ThreatLevel = ThreatLevel[threat_level.upper()] twid_int = int(ip_info.get("twid").replace("timewindow", "")) if "src" in ip_info.get("ip_state"): @@ -495,7 +494,6 @@ def handle_data_request(self, message_data: str) -> None: :return: None, the result is saved into the redis database under key `p2p4slips` """ - # make sure that IP address is valid # and cache age is a valid timestamp from the past ip_info = validate_slips_data(message_data) @@ -504,7 +502,6 @@ def handle_data_request(self, message_data: str) -> None: # print(f"DEBUGGING: IP address is not valid: # {ip_info}, not asking the network") return - # ip_info is { # 'ip': str(saddr), # 'profileid' : str(profileid), @@ -520,12 +517,15 @@ def handle_data_request(self, message_data: str) -> None: # if data is in cache and is recent enough, # nothing happens and Slips should just check the database ( - score, - confidence, - network_score, - timestamp, + cached_score, + cached_confidence, + cached_network_score, + cached_timestamp, ) = self.trust_db.get_cached_network_opinion("ip", ip_address) - if score is not None and time.time() - timestamp < cache_age: + if ( + cached_score is not None + and time.time() - cached_timestamp < cache_age + ): # cached value is ok, do nothing # print("DEBUGGING: cached value is ok, not asking the network.") return @@ -547,6 +547,7 @@ def handle_data_request(self, message_data: str) -> None: time.sleep(2) # get data from db, processed by the trust model + # this score and confidence are the network's opinion ( combined_score, combined_confidence, @@ -556,8 +557,6 @@ def handle_data_request(self, message_data: str) -> None: ip_address, combined_score, combined_confidence, - network_score, - confidence, ip_info, ) @@ -566,42 +565,30 @@ def process_network_response( ip, combined_score, combined_confidence, - network_score, - confidence, ip_info, ): """ stores the reported score and confidence about the ip and adds an - evidence if necessary + evidence if necessary like when the peers report a malicious ip """ # no data in db - this happens when testing, # if there is not enough data on peers - if combined_score is None: - self.print( - f"No data received from the" f" network about {ip}\n", 0, 2 - ) + if combined_score is None or combined_confidence is None: + self.print(f"No data received from the network about {ip}\n", 0, 2) return self.print( f"The Network shared some data about {ip}, " f"Shared data: score={combined_score}, " - f"confidence={combined_confidence} saving it to now!\n", + f"confidence={combined_confidence} saving it now!\n", 0, 2, ) - # save it to IPsInfo hash in p2p4slips key in the db - # AND p2p_reports key - p2p_utils.save_ip_report_to_db( - ip, - combined_score, - combined_confidence, - network_score, - self.db, - self.storage_name, - ) - if int(combined_score) * int(confidence) > 0: - self.set_evidence_malicious_ip(ip_info, combined_score, confidence) + if combined_score * combined_confidence > 0: + self.set_evidence_malicious_ip( + ip_info, combined_score, combined_confidence + ) def respond_to_message_request(self, key, reporter): # todo do you mean another peer is asking me about diff --git a/modules/p2ptrust/trust/base_model.py b/modules/p2ptrust/trust/base_model.py index 769229093b..e7c62595d5 100644 --- a/modules/p2ptrust/trust/base_model.py +++ b/modules/p2ptrust/trust/base_model.py @@ -1,6 +1,9 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only +from typing_extensions import List + from slips_files.common.printer import Printer +from slips_files.core.database.database_manager import DBManager from slips_files.core.output import Output @@ -17,15 +20,16 @@ class BaseModel: name = "P2P Base Model" - def __init__(self, logger: Output, trustdb): + def __init__(self, logger: Output, trustdb, main_slips_db: DBManager): self.trustdb = trustdb + self.main_slips_db = main_slips_db self.printer = Printer(logger, self.name) self.reliability_weight = 0.7 def print(self, *args, **kwargs): return self.printer.print(*args, **kwargs) - def get_opinion_on_ip(self, ipaddr: str) -> (float, float, float): + def get_opinion_on_ip(self, ipaddr: str) -> (float, float): """ Compute the network's opinion for a given IP @@ -39,8 +43,10 @@ def get_opinion_on_ip(self, ipaddr: str) -> (float, float, float): # get report on that ip that is at most max_age old # if no such report is found: - - reports_on_ip = self.trustdb.get_opinion_on_ip(ipaddr) + # reports_on_ip looks like this: + # [(report_score, report_confidence, reporter_reliability, + # reporter_score, reporter_confidence, reporter_ipaddress), ...] + reports_on_ip: List[tuple] = self.trustdb.get_opinion_on_ip(ipaddr) if len(reports_on_ip) == 0: return None, None combined_score, combined_confidence = self.assemble_peer_opinion( @@ -106,9 +112,17 @@ def assemble_peer_opinion(self, data: list) -> (float, float, float): :param data: a list of peers and their reports, in the format given by TrustDB.get_opinion_on_ip() + ( + report_score, + report_confidence, + reporter_reliability, + reporter_score, # what does slips think about the reporter's ip + # how confident slips is about the reporter's ip's score + reporter_confidence, + reporter_ipaddress, + ) :return: average peer reputation, final score and final confidence """ - reports = [] reporters = [] @@ -117,18 +131,26 @@ def assemble_peer_opinion(self, data: list) -> (float, float, float): report_score, report_confidence, reporter_reliability, + # what does slips think about the reporter's ip reporter_score, + # how confident slips is about the reporter's ip's score reporter_confidence, + reporter_ipaddress, ) = peer_report + reports.append((report_score, report_confidence)) - reporters.append( - self.compute_peer_trust( - reporter_reliability, reporter_score, reporter_confidence - ) + # here reporter_score, reporter_confidence are the local ips + # detection of this peer + peer_trust = self.compute_peer_trust( + reporter_reliability, reporter_score, reporter_confidence ) + reporters.append(peer_trust) + self.main_slips_db.set_peer_trust(reporter_ipaddress, peer_trust) weighted_reporters = self.normalize_peer_reputations(reporters) - + # peers we trust more will contribute more to the final score. + # r[0] → the score from each peer's report. + # w → the normalized trust weight for that peer combined_score = sum( r[0] * w for r, w, in zip(reports, weighted_reporters) ) @@ -136,4 +158,10 @@ def assemble_peer_opinion(self, data: list) -> (float, float, float): [max(0, r[1] * w) for r, w, in zip(reports, reporters)] ) / len(reporters) + # to ensure the score and confidence are within the range [0, 1] + # this avoids python issues with negative values or values + # slightly above 1.0 + combined_score = min(1.0, max(0.0, combined_score)) + combined_confidence = min(1.0, max(0.0, combined_confidence)) + return combined_score, combined_confidence diff --git a/modules/p2ptrust/trust/trustdb.py b/modules/p2ptrust/trust/trustdb.py index 0bdb7b8c6f..acae817a53 100644 --- a/modules/p2ptrust/trust/trustdb.py +++ b/modules/p2ptrust/trust/trustdb.py @@ -3,11 +3,13 @@ import sqlite3 import datetime import time + +from slips_files.common.abstracts.isqlite import ISQLite from slips_files.common.printer import Printer from slips_files.core.output import Output -class TrustDB: +class TrustDB(ISQLite): name = "P2P Trust DB" def __init__( @@ -18,87 +20,86 @@ def __init__( ): """create a database connection to a SQLite database""" self.printer = Printer(logger, self.name) - self.conn = sqlite3.connect(db_file) + self.conn = sqlite3.connect( + db_file, check_same_thread=False, timeout=20 + ) + self.cursor = self.conn.cursor() + super().__init__(self.name.replace(" ", "_").lower()) if drop_tables_on_startup: self.print("Dropping tables") self.delete_tables() self.create_tables() - # self.insert_slips_score("8.8.8.8", 0.0, 0.9) - # self.get_opinion_on_ip("zzz") def __del__(self): self.conn.close() - def print(self, *args, **kwargs): - return self.printer.print(*args, **kwargs) - def create_tables(self): - self.conn.execute( - "CREATE TABLE IF NOT EXISTS slips_reputation (" - "id INTEGER PRIMARY KEY NOT NULL, " - "ipaddress TEXT NOT NULL, " - "score REAL NOT NULL, " - "confidence REAL NOT NULL, " - "update_time REAL NOT NULL);" - ) - - self.conn.execute( - "CREATE TABLE IF NOT EXISTS go_reliability (" - "id INTEGER PRIMARY KEY NOT NULL, " - "peerid TEXT NOT NULL, " - "reliability REAL NOT NULL, " - "update_time REAL NOT NULL);" - ) - - self.conn.execute( - "CREATE TABLE IF NOT EXISTS peer_ips (" - "id INTEGER PRIMARY KEY NOT NULL, " - "ipaddress TEXT NOT NULL, " - "peerid TEXT NOT NULL, " - "update_time REAL NOT NULL);" - ) - - self.conn.execute( - "CREATE TABLE IF NOT EXISTS reports (" - "id INTEGER PRIMARY KEY NOT NULL, " - "reporter_peerid TEXT NOT NULL, " - "key_type TEXT NOT NULL, " - "reported_key TEXT NOT NULL, " - "score REAL NOT NULL, " - "confidence REAL NOT NULL, " - "update_time REAL NOT NULL);" - ) - - self.conn.execute( - "CREATE TABLE IF NOT EXISTS opinion_cache (" - "key_type TEXT NOT NULL, " - "reported_key TEXT NOT NULL PRIMARY KEY, " - "score REAL NOT NULL, " - "confidence REAL NOT NULL, " - "network_score REAL NOT NULL, " - "update_time DATE NOT NULL);" - ) + table_schema = { + "slips_reputation": ( + "id INTEGER PRIMARY KEY NOT NULL, " + "ipaddress TEXT NOT NULL, " + "score REAL NOT NULL, " + "confidence REAL NOT NULL, " + "update_time REAL NOT NULL" + ), + "go_reliability": ( + "id INTEGER PRIMARY KEY NOT NULL, " + "peerid TEXT NOT NULL, " + "reliability REAL NOT NULL, " + "update_time REAL NOT NULL" + ), + "peer_ips": ( + "id INTEGER PRIMARY KEY NOT NULL, " + "ipaddress TEXT NOT NULL, " + "peerid TEXT NOT NULL, " + "update_time REAL NOT NULL" + ), + "reports": ( + "id INTEGER PRIMARY KEY NOT NULL, " + "reporter_peerid TEXT NOT NULL, " + "key_type TEXT NOT NULL, " + "reported_key TEXT NOT NULL, " + "score REAL NOT NULL, " + "confidence REAL NOT NULL, " + "update_time REAL NOT NULL" + ), + "opinion_cache": ( + "key_type TEXT NOT NULL, " + "reported_key TEXT NOT NULL PRIMARY KEY, " + "score REAL NOT NULL, " + "confidence REAL NOT NULL, " + "network_score REAL NOT NULL, " + "update_time DATE NOT NULL" + ), + } + + for table, schema in table_schema.items(): + self.create_table(table, schema) def delete_tables(self): - self.conn.execute("DROP TABLE IF EXISTS opinion_cache;") - self.conn.execute("DROP TABLE IF EXISTS slips_reputation;") - self.conn.execute("DROP TABLE IF EXISTS go_reliability;") - self.conn.execute("DROP TABLE IF EXISTS peer_ips;") - self.conn.execute("DROP TABLE IF EXISTS reports;") + tables = [ + "opinion_cache", + "slips_reputation", + "go_reliability", + "peer_ips", + "reports", + ] + for table in tables: + self.execute(f"DROP TABLE IF EXISTS {table};") def insert_slips_score( self, ip: str, score: float, confidence: float, timestamp: int = None ): if timestamp is None: timestamp = time.time() - parameters = (ip, score, confidence, timestamp) - self.conn.execute( - "INSERT INTO slips_reputation (ipaddress, score, confidence, update_time) " - "VALUES (?, ?, ?, ?);", - parameters, - ) - self.conn.commit() + + query = """ + INSERT OR REPLACE INTO slips_reputation + (ipaddress, score, confidence, update_time) + VALUES (?, ?, ?, ?) + """ + self.execute(query, (ip, score, confidence, timestamp)) def insert_go_reliability( self, peerid: str, reliability: float, timestamp: int = None @@ -106,13 +107,10 @@ def insert_go_reliability( if timestamp is None: timestamp = datetime.datetime.now() - parameters = (peerid, reliability, timestamp) - self.conn.execute( - "INSERT INTO go_reliability (peerid, reliability, update_time) " - "VALUES (?, ?, ?);", - parameters, + values = (peerid, reliability, timestamp) + self.insert( + "go_reliability", values, "peerid, reliability, update_time" ) - self.conn.commit() def insert_go_ip_pairing( self, peerid: str, ip: str, timestamp: int = None @@ -120,22 +118,8 @@ def insert_go_ip_pairing( if timestamp is None: timestamp = datetime.datetime.now() - parameters = (ip, peerid, timestamp) - self.conn.execute( - "INSERT INTO peer_ips (ipaddress, peerid, update_time) " - "VALUES (?, ?, ?);", - parameters, - ) - self.conn.commit() - - def insert_new_go_data(self, reports: list): - self.conn.executemany( - "INSERT INTO reports " - "(reporter_peerid, key_type, reported_key, score, confidence, update_time) " - "VALUES (?, ?, ?, ?, ?, ?)", - reports, - ) - self.conn.commit() + values = (ip, peerid, timestamp) + self.insert("peer_ips", values, "ipaddress, peerid, update_time") def insert_new_go_report( self, @@ -151,8 +135,7 @@ def insert_new_go_report( # f"score: {score} confidence: {confidence} timestamp: {timestamp} ") if timestamp is None: - timestamp = datetime.datetime.now() - timestamp = time.time() + timestamp = time.time() parameters = ( reporter_peerid, @@ -162,13 +145,12 @@ def insert_new_go_report( confidence, timestamp, ) - self.conn.execute( - "INSERT INTO reports " - "(reporter_peerid, key_type, reported_key, score, confidence, update_time) " - "VALUES (?, ?, ?, ?, ?, ?)", + self.insert( + "reports", parameters, + "reporter_peerid, key_type, reported_key, score, " + "confidence, update_time", ) - self.conn.commit() def update_cached_network_opinion( self, @@ -178,103 +160,101 @@ def update_cached_network_opinion( confidence: float, network_score: float, ): - self.conn.execute( + self.execute( "REPLACE INTO" - " opinion_cache (key_type, reported_key, score, confidence, network_score, update_time)" + " opinion_cache (key_type, reported_key, " + "score, confidence, network_score, update_time)" "VALUES (?, ?, ?, ?, ?, strftime('%s','now'));", (key_type, reported_key, score, confidence, network_score), ) - self.conn.commit() def get_cached_network_opinion(self, key_type: str, reported_key: str): - cache_cur = self.conn.execute( - "SELECT score, confidence, network_score, update_time " - "FROM opinion_cache " - "WHERE key_type = ? " - " AND reported_key = ? " - "ORDER BY update_time LIMIT 1;", - (key_type, reported_key), + res = self.select( + table_name="opinion_cache", + columns="score, confidence, network_score, update_time", + condition="key_type = ? AND reported_key = ?", + params=(key_type, reported_key), + order_by="update_time", + limit=1, ) - result = cache_cur.fetchone() - if result is None: - result = None, None, None, None - return result + if res is None: + return None, None, None, None + return res def get_ip_of_peer(self, peerid): """ Returns the latest IP seen associated with the given peerid :param peerid: the id of the peer we want the ip of + returns a tuple with (last_update_time, ip) """ - cache_cur = self.conn.execute( - "SELECT MAX(update_time) AS ip_update_time, ipaddress FROM peer_ips WHERE peerid = ?;", - ((peerid),), + res = self.select( + table_name="peer_ips", + columns="MAX(update_time) AS ip_update_time, ipaddress", + condition="peerid = ?", + params=(peerid,), + limit=1, ) - if res := cache_cur.fetchone(): - last_update_time, ip = res - return last_update_time, ip - return False, False + return res if res else (False, False) def get_reports_for_ip(self, ipaddress): """ Returns a list of all reports for the given IP address. """ - reports_cur = self.conn.execute( - "SELECT reports.reporter_peerid, reports.update_time, reports.score, " - " reports.confidence, reports.reported_key " - "FROM reports " - "WHERE reports.reported_key = ? AND reports.key_type = 'ip'" - "ORDER BY reports.update_time DESC;", - (ipaddress,), + return self.select( + table_name="reports", + columns="reporter_peerid, update_time, score, confidence, reported_key", + condition="reported_key = ? AND key_type = ?", + params=(ipaddress, "ip"), ) - return reports_cur.fetchall() - def get_reporter_ip(self, reporter_peerid, report_timestamp): + def get_reporter_ip(self, reporter_peerid, report_timestamp) -> str: """ Returns the IP address of the reporter at the time of the report. """ - ip_cur = self.conn.execute( - "SELECT MAX(update_time), ipaddress " - "FROM peer_ips " - "WHERE update_time <= ? AND peerid = ? " - "ORDER BY update_time DESC " - "LIMIT 1;", - (report_timestamp, reporter_peerid), + res = self.select( + table_name="peer_ips", + columns="MAX(update_time), ipaddress", + condition="update_time <= ? AND peerid = ?", + params=(report_timestamp, reporter_peerid), + limit=1, ) - if res := ip_cur.fetchone(): - return res[1] + + if res: + return res[1] # Return the IP address return None def get_reporter_reliability(self, reporter_peerid): """ Returns the latest reliability score for the given peer. """ - go_reliability_cur = self.conn.execute( - "SELECT reliability " - "FROM go_reliability " - "WHERE peerid = ? " - "ORDER BY update_time DESC " - "LIMIT 1;" + res = self.select( + table_name="go_reliability", + columns="reliability", + condition="peerid = ?", + params=(reporter_peerid,), + limit=1, ) - if res := go_reliability_cur.fetchone(): + + try: return res[0] - return None + except IndexError: + return None def get_reporter_reputation(self, reporter_ipaddress): """ - Returns the latest reputation score and confidence for the given IP address. + returns the latest reputation score and confidence for the given IP address. """ - slips_reputation_cur = self.conn.execute( - "SELECT score, confidence " - "FROM slips_reputation " - "WHERE ipaddress = ? " - "ORDER BY update_time DESC " - "LIMIT 1;", - (reporter_ipaddress,), + res = self.select( + table_name="slips_reputation", + columns="score, confidence", + condition="ipaddress = ?", + params=(reporter_ipaddress,), + order_by="update_time DESC", + limit=1, ) - if res := slips_reputation_cur.fetchone(): - return res - return None, None + + return res or (None, None) def get_opinion_on_ip(self, ipaddress): """ @@ -282,6 +262,7 @@ def get_opinion_on_ip(self, ipaddress): reporter reliability, reporter score, and reporter confidence for a given IP address. """ reports = self.get_reports_for_ip(ipaddress) + reporters_scores = [] for ( @@ -309,18 +290,21 @@ def get_opinion_on_ip(self, ipaddress): if reporter_score is None or reporter_confidence is None: continue + # TODO update the docs in assemble_peer_opinion() when the + # format of this list changes:D reporters_scores.append( ( report_score, report_confidence, reporter_reliability, - reporter_score, + reporter_score, # what does slips think about the reporter's ip + # how confident slips is about the reporter's ip's score reporter_confidence, + reporter_ipaddress, ) ) - return reporters_scores if __name__ == "__main__": - trustDB = TrustDB(r"trustdb.db") + trustDB = TrustDB("trustdb.db") diff --git a/modules/p2ptrust/utils/go_director.py b/modules/p2ptrust/utils/go_director.py index a7ffc05199..fd7e5a02b5 100644 --- a/modules/p2ptrust/utils/go_director.py +++ b/modules/p2ptrust/utils/go_director.py @@ -116,7 +116,8 @@ def handle_gopy_data(self, data_dict: dict): except json.decoder.JSONDecodeError: self.print( - f"Couldn't load message from pigeon - invalid Json from the pigeon: {data_dict}", + f"Couldn't load message from pigeon - invalid Json from" + f" the pigeon: {data_dict}", 0, 1, ) @@ -134,7 +135,8 @@ def process_go_data(self, report: dict) -> None: The data is expected to be a list of messages received from go peers. They are parsed and inserted into the database. - If a message does not comply with the format, the reporter's reputation is lowered. + If a message does not comply with the format, the reporter's + reputation is lowered. """ # "message_type":"go_data", # "message_contents":{"reporter":"aconcagua","report_time":1649445643,"message": @@ -398,7 +400,7 @@ def process_evaluation_score_confidence( :param report_time: Time of receiving the data, provided by the go part :param key_type: The type of key the peer is reporting (only "ip" is supported now) - :param key: The key itself + :param key: The key itself, aka the ip :param evaluation: Dictionary containing score and confidence values :return: None, data is saved to the database """ @@ -437,7 +439,6 @@ def process_evaluation_score_confidence( self.print("Confidence value is out of bounds", 0, 2) # TODO: lower reputation return - self.trustdb.insert_new_go_report( reporter, key_type, key, score, confidence, report_time ) @@ -447,19 +448,7 @@ def process_evaluation_score_confidence( f"score {score}, confidence {confidence}" ) self.print(result, 2, 0) - # print(f"*** [debugging p2p] *** stored a report about about - # {key} from {reporter} in p2p_reports key in the db ") - # save all report info in the db - # convert ts to human readable format - report_info = { - "reporter": reporter, - "report_time": utils.convert_ts_format( - report_time, utils.alerts_format - ), - } - report_info.update(evaluation) - self.db.store_p2p_report(key, report_info) - + report_time = time.time() # create a new profile for the reported ip # with the width from slips.yaml and the starttime as the report time if key_type == "ip": diff --git a/modules/p2ptrust/utils/utils.py b/modules/p2ptrust/utils/utils.py index 172e5b4dec..9ff0eedb11 100644 --- a/modules/p2ptrust/utils/utils.py +++ b/modules/p2ptrust/utils/utils.py @@ -144,33 +144,10 @@ def read_data_from_ip_info(ip_info: dict) -> (float, float): confidence = float(confidence.split()[-1]) return float(score), confidence - except KeyError: + except (KeyError, TypeError): return None, None -def save_ip_report_to_db( - ip, score, confidence, network_trust, db, timestamp=None -): - if timestamp is None: - timestamp = time.time() - - report_data = { - "score": score, - "confidence": confidence, - "network_score": network_trust, - "timestamp": timestamp, - } - - # store it in p2p_reports key - # print(f"*** [debugging p2p] *** stored a report about - # {ip} in p2p_Reports and IPsInfo keys") - db.store_p2p_report(ip, report_data) - - # store it in IPsInfo key - wrapped_data = {"p2p4slips": report_data} - db.set_ip_info(ip, wrapped_data) - - # # SEND COMMUNICATION TO GO # @@ -182,13 +159,16 @@ def build_go_message( evaluation=None, ) -> dict: """ - Assemble parameters to one dictionary, with keys that are expected by the remote peer. + Assemble parameters to one dictionary, with keys that are expected by the + remote peer. :param message_type: Type of message (request, report, blame...) :param key_type: Type of key, usually "ip" :param key: The key the message is about - :param evaluation_type: Type of evaluation that is reported (for report and blame) or expected (for request message) - :param evaluation: The score that is being reported (for report and blame). This can be left out for request message + :param evaluation_type: Type of evaluation that is reported (for report + and blame) or expected (for request message) + :param evaluation: The score that is being reported (for report and + blame). This can be left out for request message :return: A dictionary with proper values set. """ diff --git a/modules/riskiq/riskiq.py b/modules/riskiq/riskiq.py index 5abf2ddb19..d2d83bdb61 100644 --- a/modules/riskiq/riskiq.py +++ b/modules/riskiq/riskiq.py @@ -7,7 +7,7 @@ from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule class RiskIQ(IModule): diff --git a/modules/rnn_cc_detection/rnn_cc_detection.py b/modules/rnn_cc_detection/rnn_cc_detection.py index 127f861752..aea2fef51e 100644 --- a/modules/rnn_cc_detection/rnn_cc_detection.py +++ b/modules/rnn_cc_detection/rnn_cc_detection.py @@ -9,7 +9,7 @@ from tensorflow.keras.models import load_model from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule from slips_files.core.structures.evidence import ( Evidence, ProfileID, diff --git a/modules/template/template.py b/modules/template/template.py index a675ba6cd9..45072f7ab6 100644 --- a/modules/template/template.py +++ b/modules/template/template.py @@ -14,7 +14,7 @@ from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule import json diff --git a/modules/threat_intelligence/threat_intelligence.py b/modules/threat_intelligence/threat_intelligence.py index 3d662b0b7a..04c76dfa5d 100644 --- a/modules/threat_intelligence/threat_intelligence.py +++ b/modules/threat_intelligence/threat_intelligence.py @@ -21,7 +21,7 @@ from modules.threat_intelligence.spamhaus import Spamhaus from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule from modules.threat_intelligence.urlhaus import URLhaus from slips_files.core.structures.evidence import ( Evidence, diff --git a/modules/timeline/timeline.py b/modules/timeline/timeline.py index 45f8603c8c..545307b4cd 100644 --- a/modules/timeline/timeline.py +++ b/modules/timeline/timeline.py @@ -12,7 +12,7 @@ from slips_files.common.flow_classifier import FlowClassifier from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule class Timeline(IModule): diff --git a/modules/update_manager/update_manager.py b/modules/update_manager/update_manager.py index ba8106aa5c..70173b7e28 100644 --- a/modules/update_manager/update_manager.py +++ b/modules/update_manager/update_manager.py @@ -24,7 +24,7 @@ from modules.update_manager.timer_manager import InfiniteTimer from slips_files.common.parsers.config_parser import ConfigParser -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule from slips_files.common.slips_utils import utils from slips_files.core.helpers.whitelist.whitelist import Whitelist diff --git a/modules/virustotal/virustotal.py b/modules/virustotal/virustotal.py index 61e0bf1feb..538d62d316 100644 --- a/modules/virustotal/virustotal.py +++ b/modules/virustotal/virustotal.py @@ -12,7 +12,7 @@ from slips_files.common.flow_classifier import FlowClassifier from slips_files.common.parsers.config_parser import ConfigParser -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule from slips_files.common.slips_utils import utils diff --git a/slips/main.py b/slips/main.py index e68c2903c9..0d8422dc77 100644 --- a/slips/main.py +++ b/slips/main.py @@ -477,6 +477,7 @@ def start(self): self.logger, self.args.output, self.redis_port, + self.conf, start_redis_server=start_redis_server, ) except RuntimeError as e: diff --git a/slips_files/common/abstracts/async_module.py b/slips_files/common/abstracts/iasync_module.py similarity index 98% rename from slips_files/common/abstracts/async_module.py rename to slips_files/common/abstracts/iasync_module.py index 15685d5ae3..c2315d4732 100644 --- a/slips_files/common/abstracts/async_module.py +++ b/slips_files/common/abstracts/iasync_module.py @@ -6,7 +6,7 @@ Callable, List, ) -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule class AsyncModule(IModule): diff --git a/slips_files/common/abstracts/core.py b/slips_files/common/abstracts/icore.py similarity index 96% rename from slips_files/common/abstracts/core.py rename to slips_files/common/abstracts/icore.py index 30b27f751e..954782bf53 100644 --- a/slips_files/common/abstracts/core.py +++ b/slips_files/common/abstracts/icore.py @@ -3,7 +3,7 @@ import traceback from multiprocessing import Process -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule class ICore(IModule, Process): diff --git a/slips_files/common/abstracts/exporter.py b/slips_files/common/abstracts/iexporter.py similarity index 100% rename from slips_files/common/abstracts/exporter.py rename to slips_files/common/abstracts/iexporter.py diff --git a/slips_files/common/abstracts/flowalerts_analyzer.py b/slips_files/common/abstracts/iflowalerts_analyzer.py similarity index 100% rename from slips_files/common/abstracts/flowalerts_analyzer.py rename to slips_files/common/abstracts/iflowalerts_analyzer.py diff --git a/slips_files/common/abstracts/input_type.py b/slips_files/common/abstracts/iinput_type.py similarity index 100% rename from slips_files/common/abstracts/input_type.py rename to slips_files/common/abstracts/iinput_type.py diff --git a/slips_files/common/abstracts/module.py b/slips_files/common/abstracts/imodule.py similarity index 96% rename from slips_files/common/abstracts/module.py rename to slips_files/common/abstracts/imodule.py index 88117cb6b1..844372813a 100644 --- a/slips_files/common/abstracts/module.py +++ b/slips_files/common/abstracts/imodule.py @@ -35,7 +35,8 @@ def __init__( output_dir, redis_port, termination_event, - args, + slips_args, + conf, **kwargs, ): Process.__init__(self) @@ -43,12 +44,16 @@ def __init__( self.output_dir = output_dir self.msg_received = False # as parsed by arg_parser, these are the cli args - self.args: Namespace = args + self.args: Namespace = slips_args + # to be able to access the configuration file + self.conf = conf # used to tell all slips.py children to stop self.termination_event: Event = termination_event self.logger = logger self.printer = Printer(self.logger, self.name) - self.db = DBManager(self.logger, self.output_dir, self.redis_port) + self.db = DBManager( + self.logger, self.output_dir, self.redis_port, self.conf + ) self.keyboard_int_ctr = 0 self.init(**kwargs) # should after the module's init() so the module has a chance to diff --git a/slips_files/common/abstracts/observer.py b/slips_files/common/abstracts/iobserver.py similarity index 100% rename from slips_files/common/abstracts/observer.py rename to slips_files/common/abstracts/iobserver.py diff --git a/slips_files/common/abstracts/performance_profiler.py b/slips_files/common/abstracts/iperformance_profiler.py similarity index 100% rename from slips_files/common/abstracts/performance_profiler.py rename to slips_files/common/abstracts/iperformance_profiler.py diff --git a/slips_files/common/abstracts/isqlite.py b/slips_files/common/abstracts/isqlite.py new file mode 100644 index 0000000000..b911a76feb --- /dev/null +++ b/slips_files/common/abstracts/isqlite.py @@ -0,0 +1,202 @@ +import fcntl +import sqlite3 +from abc import ABC +from threading import Lock +from time import sleep + + +class ISQLite(ABC): + """ + Interface for SQLite database operations. + Any sqlite db that slips connects to should use thisinterface for + avoiding common sqlite errors + """ + + # to avoid multi threading errors where multiple threads try to write to + # the same sqlite db at the same time + cursor_lock = Lock() + + def __init__(self, name): + """ + :param name: the name of the sqlite db, used to create a lock file + """ + # enable write-ahead logging for concurrent reads and writes to + # avoid the "DB is locked" error + # to avoid multi processing errors where multiple processes + # try to write to the same sqlite db at the same time + # this name needs to change per sqlite db, meaning trustb should have + # its own lock file that is different from slips' main sqlite db lockfile + self.lockfile_name = f"/tmp/slips_{name}.lock" + # important: do not use self.execute here because this query + # shouldnt be wrapped in a transaction, which is what self.execute( + # ) does + self.conn.execute("PRAGMA journal_mode=WAL;") + + def _acquire_flock(self): + """to avoid multiprocess issues with sqlite, + we use a lock file, if the lock file is acquired by a different + proc, the current proc will wait until the lock is released""" + self.lockfile_fd = open(self.lockfile_name, "w") + fcntl.flock(self.lockfile_fd, fcntl.LOCK_EX) + + def _release_flock(self): + try: + fcntl.flock(self.lockfile_fd, fcntl.LOCK_UN) + self.lockfile_fd.close() + except ValueError: + # to handle trying to release an already released + # lock "ValueError: I/O operation on closed file" + pass + + def print(self, *args, **kwargs): + return self.printer.print(*args, **kwargs) + + def get_number_of_tables(self): + """ + returns the number of tables in the current db + """ + query = "SELECT count(*) FROM sqlite_master WHERE type='table';" + self.execute(query) + x = self.fetchone() + return x[0] + + def create_table(self, table_name, schema): + query = f"CREATE TABLE IF NOT EXISTS {table_name} ({schema})" + self.execute(query) + + def insert(self, table_name, values: tuple, columns: str = None): + if columns: + placeholders = ", ".join(["?"] * len(values)) + query = ( + f"INSERT INTO {table_name} ({columns}) " + f"VALUES ({placeholders})" + ) + self.execute(query, values) + else: + query = f"INSERT INTO {table_name} VALUES {values}" # fallback + self.execute(query) + + def update(self, table_name, set_clause, condition): + query = f"UPDATE {table_name} SET {set_clause} WHERE {condition}" + self.execute(query) + + def delete(self, table_name, condition): + query = f"DELETE FROM {table_name} WHERE {condition}" + self.execute(query) + + def select( + self, + table_name, + columns="*", + condition=None, + params=(), + order_by=None, + limit: int = None, + ): + query = f"SELECT {columns} FROM {table_name} " + if condition: + query += f" WHERE {condition}" + if order_by: + query += f" ORDER BY {order_by}" + + self.execute(query, params) + if limit == 1: + result = self.fetchone() + else: + result = self.fetchall() + return result + + def get_count(self, table, condition=None): + """ + returns th enumber of matching rows in the given table + based on a specific contioins + """ + query = f"SELECT COUNT(*) FROM {table}" + + if condition: + query += f" WHERE {condition}" + + self.execute(query) + return self.fetchone()[0] + + def close(self): + self.cursor.close() + self.conn.close() + + def fetchall(self): + """ + wrapper for sqlite fetchall to be able to use a lock + """ + with self.cursor_lock: + res = self.cursor.fetchall() + return res + + def fetchone(self): + """ + wrapper for sqlite fetchone to be able to use a lock + """ + with self.cursor_lock: + res = self.cursor.fetchone() + return res + + def execute(self, query: str, params=None) -> None: + """ + wrapper for sqlite execute() To avoid + 'Recursive use of cursors not allowed' error + and to be able to use a Lock() + + since sqlite is terrible with multi-process applications + this function should be used instead of all calls to commit() and + execute() + + using transactions here is a must. + Since slips uses python3.10, we can't use autocommit here. we have + to do it manually + any conn other than the current one will not see the changes this + conn did unless they're committed. + + Each call to this function results in 1 sqlite transaction + """ + trial = 0 + max_trials = 5 + while trial < max_trials: + try: + # note that self.conn.in_transaction is not reliable + # sqlite may change the state internally, on errors for + # example. + # if no errors occur, this will be the only transaction in + # the conn + with self.cursor_lock: + if self.conn.in_transaction is False: + self.cursor.execute("BEGIN") + self._acquire_flock() + if params is None: + self.cursor.execute(query) + else: + self.cursor.execute(query, params) + self._release_flock() + + # aka END TRANSACTION + if self.conn.in_transaction: + self.conn.commit() + + return + + except sqlite3.Error as err: + self._release_flock() + # no need to manually rollback here + # sqlite automatically rolls back the tx if an error occurs + trial += 1 + if trial >= max_trials: + self.print( + f"Error executing query: " + f"'{query}'. Params: {params}. Error: {err}. " + f"Retried executing {trial} times but failed. " + f"Query discarded.", + 0, + 1, + ) + return + + elif "database is locked" in str(err): + sleep(5) diff --git a/slips_files/common/abstracts/unblocker.py b/slips_files/common/abstracts/iunblocker.py similarity index 100% rename from slips_files/common/abstracts/unblocker.py rename to slips_files/common/abstracts/iunblocker.py diff --git a/slips_files/common/abstracts/whitelist_analyzer.py b/slips_files/common/abstracts/iwhitelist_analyzer.py similarity index 100% rename from slips_files/common/abstracts/whitelist_analyzer.py rename to slips_files/common/abstracts/iwhitelist_analyzer.py diff --git a/slips_files/common/imports.py b/slips_files/common/imports.py index 2d428f5f40..3f9cef27b6 100644 --- a/slips_files/common/imports.py +++ b/slips_files/common/imports.py @@ -4,4 +4,4 @@ from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils import multiprocessing -from slips_files.common.abstracts.module import IModule +from slips_files.common.abstracts.imodule import IModule diff --git a/slips_files/common/performance_profilers/_memory_profiler_example_no_import.py b/slips_files/common/performance_profilers/_memory_profiler_example_no_import.py index 1ad48a03c8..26b079a7b5 100644 --- a/slips_files/common/performance_profilers/_memory_profiler_example_no_import.py +++ b/slips_files/common/performance_profilers/_memory_profiler_example_no_import.py @@ -433,7 +433,7 @@ def _release_lock(self): def _cleanup(self): self._pop_map() - self._release_lock() + self._release_flock() cls._cleanup = _cleanup diff --git a/slips_files/common/performance_profilers/cpu_profiler.py b/slips_files/common/performance_profilers/cpu_profiler.py index 858dd3e538..12bf8cf10b 100644 --- a/slips_files/common/performance_profilers/cpu_profiler.py +++ b/slips_files/common/performance_profilers/cpu_profiler.py @@ -8,7 +8,7 @@ import pstats import os -from slips_files.common.abstracts.performance_profiler import ( +from slips_files.common.abstracts.iperformance_profiler import ( IPerformanceProfiler, ) diff --git a/slips_files/common/performance_profilers/memory_profiler.py b/slips_files/common/performance_profilers/memory_profiler.py index ea1fe96aa2..30d1b39789 100644 --- a/slips_files/common/performance_profilers/memory_profiler.py +++ b/slips_files/common/performance_profilers/memory_profiler.py @@ -5,7 +5,7 @@ import os import subprocess from termcolor import colored -from slips_files.common.abstracts.performance_profiler import ( +from slips_files.common.abstracts.iperformance_profiler import ( IPerformanceProfiler, ) import time @@ -433,7 +433,7 @@ def _release_lock(self): def _cleanup(self): self._pop_map() - self._release_lock() + self._release_flock() cls._cleanup = _cleanup diff --git a/slips_files/common/printer.py b/slips_files/common/printer.py index 69bc84db8f..26f081bb50 100644 --- a/slips_files/common/printer.py +++ b/slips_files/common/printer.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only -from slips_files.common.abstracts.observer import IObservable +from slips_files.common.abstracts.iobserver import IObservable from slips_files.core.output import Output diff --git a/slips_files/common/slips_utils.py b/slips_files/common/slips_utils.py index 706931cb73..708b278bbc 100644 --- a/slips_files/common/slips_utils.py +++ b/slips_files/common/slips_utils.py @@ -19,7 +19,7 @@ import sys import ipaddress import aid_hash -from typing import Any, Optional, Union, List, Dict +from typing import Any, Optional, Union, List from ipaddress import IPv4Network, IPv6Network, IPv4Address, IPv6Address from dataclasses import is_dataclass, asdict from enum import Enum @@ -426,11 +426,31 @@ def get_mac_for_ip_using_cache(self, ip: str) -> str | None: pass return None - def get_own_ips(self, ret=Dict) -> Dict[str, List[str]] | List[str]: + def get_public_ip(self) -> str | None: """ - Returns a dict of our private IPs from all interfaces and our public + fetch public IP from ipinfo.io + returns either an IPv4 or IPv6 address as a string, or None if unavailable + """ + try: + response = requests.get("http://ipinfo.io/json", timeout=5) + if response.status_code == 200: + data = json.loads(response.text) + if "ip" in data: + return data["ip"] + except ( + requests.exceptions.ConnectionError, + requests.exceptions.ChunkedEncodingError, + requests.exceptions.ReadTimeout, + json.decoder.JSONDecodeError, + ): + return None + + def get_own_ips(self, ret="Dict") -> dict[str, list[str]] | list[str]: + """ + returns a dict of our private IPs from all interfaces and our public IPs. return a dict by default e.g. { "ipv4": [..], "ipv6": [..] } + :kwarg ret: "Dict" or "List" and returns a list of all the ips combined if ret=List is given """ if "-i" not in sys.argv and "-g" not in sys.argv: @@ -439,18 +459,14 @@ def get_own_ips(self, ret=Dict) -> Dict[str, List[str]] | List[str]: ips = {"ipv4": [], "ipv6": []} - interfaces = netifaces.interfaces() - - for interface in interfaces: + for interface in netifaces.interfaces(): try: addrs = netifaces.ifaddresses(interface) - # get IPv4 addresses if netifaces.AF_INET in addrs: for addr in addrs[netifaces.AF_INET]: ips["ipv4"].append(addr["addr"]) - # get IPv6 addresses if netifaces.AF_INET6 in addrs: for addr in addrs[netifaces.AF_INET6]: # remove interface suffix @@ -460,37 +476,16 @@ def get_own_ips(self, ret=Dict) -> Dict[str, List[str]] | List[str]: except Exception as e: print(f"Error processing interface {interface}: {e}") - # get public ip - try: - response = requests.get( - "http://ipinfo.io/json", - timeout=5, - ) - except ( - requests.exceptions.ConnectionError, - requests.exceptions.ChunkedEncodingError, - requests.exceptions.ReadTimeout, - ): - return ips - - if response.status_code != 200: - return ips - if "Connection timed out" in response.text: - return ips - try: - response = json.loads(response.text) - except json.decoder.JSONDecodeError: - return ips - - public_ip = response["ip"] - if validators.ipv4(public_ip): - ips["ipv4"].append(public_ip) - elif validators.ipv6(public_ip): - ips["ipv6"].append(public_ip) + public_ip = self.get_public_ip() + if public_ip: + if validators.ipv4(public_ip): + ips["ipv4"].append(public_ip) + elif validators.ipv6(public_ip): + ips["ipv6"].append(public_ip) - if ret == Dict: + if ret == "Dict": return ips - elif ret == List: + elif ret == "List": return [ip for sublist in ips.values() for ip in sublist] def convert_to_mb(self, bytes): diff --git a/slips_files/core/database/database_manager.py b/slips_files/core/database/database_manager.py index f2ccd1d12c..12aeee9034 100644 --- a/slips_files/core/database/database_manager.py +++ b/slips_files/core/database/database_manager.py @@ -1,11 +1,15 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only +import os +from pathlib import Path from typing import ( List, Dict, ) +from modules.p2ptrust.trust.trustdb import TrustDB from slips_files.common.printer import Printer +from slips_files.common.slips_utils import utils from slips_files.core.database.redis_db.database import RedisDB from slips_files.core.database.sqlite_db.database import SQLiteDB from slips_files.common.parsers.config_parser import ConfigParser @@ -28,10 +32,12 @@ def __init__( logger: Output, output_dir, redis_port, + conf, start_sqlite=True, start_redis_server=True, **kwargs, ): + self.conf = conf self.output_dir = output_dir self.redis_port = redis_port self.logger = logger @@ -40,6 +46,15 @@ def __init__( self.logger, redis_port, start_redis_server, **kwargs ) + self.trust_db = None + if self.conf.use_local_p2p(): + self.trust_db_path: str = self.init_p2ptrust_db() + self.trust_db = TrustDB( + self.logger, + self.trust_db_path, + drop_tables_on_startup=False, + ) + # in some rare cases we don't wanna create the sqlite db from scratch, # like when using -S to stop the daemon, we just wanna connect to # the existing one @@ -47,6 +62,19 @@ def __init__( if start_sqlite: self.sqlite = SQLiteDB(self.logger, output_dir) + def init_p2ptrust_db(self) -> str: + """returns the path of the trustdb inside the p2ptrust_runtime_dir""" + p2ptrust_runtime_dir = os.path.join(os.getcwd(), "p2ptrust_runtime/") + Path(p2ptrust_runtime_dir).mkdir(parents=True, exist_ok=True) + self.p2ptrust_runtime_dir = p2ptrust_runtime_dir + return os.path.join(p2ptrust_runtime_dir, "trustdb.db") + + def get_p2ptrust_dir(self) -> str: + return self.p2ptrust_runtime_dir + + def get_p2ptrust_db_path(self) -> str: + return self.trust_db_path + def print(self, *args, **kwargs): return self.printer.print(*args, **kwargs) @@ -174,11 +202,11 @@ def update_accumulated_threat_level(self, *args, **kwargs): def set_ip_info(self, *args, **kwargs): return self.rdb.set_ip_info(*args, **kwargs) - def get_p2p_reports_about_ip(self, *args, **kwargs): - return self.rdb.get_p2p_reports_about_ip(*args, **kwargs) + def set_peer_trust(self, *args, **kwargs): + return self.rdb.set_peer_trust(*args, **kwargs) - def store_p2p_report(self, *args, **kwargs): - return self.rdb.store_p2p_report(*args, **kwargs) + def get_peer_trust(self, *args, **kwargs): + return self.rdb.get_peer_trust(*args, **kwargs) def get_dns_resolution(self, *args, **kwargs): return self.rdb.get_dns_resolution(*args, **kwargs) @@ -421,8 +449,17 @@ def get_flows_causing_evidence(self, *args, **kwargs): """returns the list of uids of the flows causing evidence""" return self.rdb.get_flows_causing_evidence(*args, **kwargs) - def set_evidence(self, *args, **kwargs): - return self.rdb.set_evidence(*args, **kwargs) + def set_evidence(self, evidence: Evidence): + evidence_set = self.rdb.set_evidence(evidence) + if evidence_set: + # an evidence is generated for this profile + # update the threat level of this profile + self.update_threat_level( + str(evidence.attacker.profile), + str(evidence.threat_level), + evidence.confidence, + ) + return evidence_set def set_alert( self, alert: Alert, evidence_causing_the_alert: Dict[str, Evidence] @@ -434,6 +471,11 @@ def set_alert( self.rdb.set_alert(alert) self.sqlite.add_alert(alert) + # when an alert is generated , we should set the threat level of the + # attacker's profile to 1(critical) and confidence 1 + # so that it gets reported to other peers with these numbers + self.update_threat_level(str(alert.profile), "critical", 1) + for evidence_id in evidence_causing_the_alert.keys(): uids: List[str] = self.rdb.get_flows_causing_evidence(evidence_id) self.set_flow_label(uids, "malicious") @@ -472,8 +514,20 @@ def get_profileid_twid_alerts(self, *args, **kwargs): def get_twid_evidence(self, *args, **kwargs): return self.rdb.get_twid_evidence(*args, **kwargs) - def update_threat_level(self, *args, **kwargs): - return self.rdb.update_threat_level(*args, **kwargs) + def update_threat_level( + self, profileid: str, threat_level: str, confidence: float + ): + """updates the threat level and confidence of an ip in redis and + trust db for other peers to use it""" + if self.trust_db: + ip = profileid.split("_")[-1] + float_threat_level = utils.threat_levels[threat_level] + self.trust_db.insert_slips_score( + ip, float_threat_level, confidence + ) + return self.rdb.update_threat_level( + profileid, threat_level, confidence + ) def set_loaded_ti_files(self, *args, **kwargs): return self.rdb.set_loaded_ti_files(*args, **kwargs) @@ -723,8 +777,12 @@ def get_first_twid_for_profile(self, *args, **kwargs): def get_tw_of_ts(self, *args, **kwargs): return self.rdb.get_tw_of_ts(*args, **kwargs) - def add_new_tw(self, *args, **kwargs): - return self.rdb.add_new_tw(*args, **kwargs) + def add_new_tw(self, profileid, timewindow: str, startoftw: float): + self.rdb.add_new_tw(profileid, timewindow, startoftw) + # When a new TW is created for this profile, + # change the threat level of the profile to 0(info) + # and confidence to 0.05 + self.update_threat_level(profileid, "info", 0.5) def get_tw_start_time(self, *args, **kwargs): return self.rdb.get_tw_start_time(*args, **kwargs) @@ -762,8 +820,10 @@ def get_user_agent_from_profile(self, *args, **kwargs): def mark_profile_as_dhcp(self, *args, **kwargs): return self.rdb.mark_profile_as_dhcp(*args, **kwargs) - def add_profile(self, *args, **kwargs): - return self.rdb.add_profile(*args, **kwargs) + def add_profile(self, profileid, starttime): + confidence = 0.05 + self.update_threat_level(profileid, "info", confidence) + return self.rdb.add_profile(profileid, starttime, confidence) def set_module_label_for_profile(self, *args, **kwargs): return self.rdb.set_module_label_for_profile(*args, **kwargs) @@ -969,10 +1029,12 @@ def close_sqlite(self, *args, **kwargs): if self.sqlite: self.sqlite.close(*args, **kwargs) - def close_redis_and_sqlite(self, *args, **kwargs): + def close_all_dbs(self, *args, **kwargs): self.rdb.r.close() self.rdb.rcache.close() self.close_sqlite() + if self.trust_db: + self.trust_db.close() def get_fides_ti(self, target: str): return self.rdb.get_fides_ti(target) diff --git a/slips_files/core/database/redis_db/alert_handler.py b/slips_files/core/database/redis_db/alert_handler.py index 234ccb55b1..435de45d8c 100644 --- a/slips_files/core/database/redis_db/alert_handler.py +++ b/slips_files/core/database/redis_db/alert_handler.py @@ -271,26 +271,12 @@ def set_evidence(self, evidence: Evidence): self.r.hset(evidence_hash, evidence.id, evidence_to_send) self.r.incr(self.constants.NUMBER_OF_EVIDENCE, 1) self.publish("evidence_added", evidence_to_send) - - # an evidence is generated for this profile - # update the threat level of this profile - self.update_threat_level( - str(evidence.attacker.profile), - str(evidence.threat_level), - evidence.confidence, - ) - return True return False def set_alert(self, alert: Alert): self.set_evidence_causing_alert(alert) - # when an alert is generated , we should set the threat level of the - # attacker's profile to 1(critical) and confidence 1 - # so that it gets reported to other peers with these numbers - self.update_threat_level(str(alert.profile), "critical", 1) - # reset the accumulated threat level now that an alert is generated self._set_accumulated_threat_level(alert, 0) self.mark_profile_as_malicious(alert.profile) @@ -488,9 +474,11 @@ def update_past_threat_levels(self, profileid, threat_level, confidence): self.r.hset(profileid, "past_threat_levels", past_threat_levels) def update_ips_info(self, profileid, max_threat_lvl, confidence): - # set the score and confidence of the given ip in the db - # when it causes an evidence - # these 2 values will be needed when sharing with peers + """ + sets the score and confidence of the given ip in the db + when it causes an evidence + these 2 values will be needed when sharing with peers + """ score_confidence = {"score": max_threat_lvl, "confidence": confidence} ip = profileid.split("_")[-1] @@ -512,6 +500,8 @@ def update_threat_level( in IPsInfo :param threat_level: available options are 'low', 'medium' 'critical' etc + Do not call this function directy from the db, always call it user + dbmanager.update_threat_level() to update the trustdb too:D """ self.r.hset(profileid, "threat_level", threat_level) diff --git a/slips_files/core/database/redis_db/database.py b/slips_files/core/database/redis_db/database.py index 580372f015..35daa93360 100644 --- a/slips_files/core/database/redis_db/database.py +++ b/slips_files/core/database/redis_db/database.py @@ -115,7 +115,7 @@ class RedisDB(IoCHandler, AlertHandler, ProfileHandler, P2PHandler): # flag to know if we found the gateway MAC using the most seen MAC method _gateway_MAC_found = False _conf_file = "config/redis.conf" - our_ips: List[str] = utils.get_own_ips(ret=List) + our_ips: List[str] = utils.get_own_ips(ret="List") # to make sure we only detect and store the user's localnet once is_localnet_set = False # in case of redis ConnectionErrors, this is how long we'll wait in @@ -835,63 +835,6 @@ def get_redis_pid(self): """returns the pid of the current redis server""" return int(self.r.info()["process_id"]) - def get_p2p_reports_about_ip(self, ip) -> dict: - """ - returns a dict of all p2p past reports about the given ip - """ - # p2p_reports key is basically - # { ip: { reporter1: [report1, report2, report3]} } - if reports := self.rcache.hget(self.constants.P2P_REPORTS, ip): - return json.loads(reports) - return {} - - def store_p2p_report(self, ip: str, report_data: dict): - """ - stores answers about IPs slips asked other peers for. - updates the p2p_reports key only - """ - # reports in the db are sorted by reporter by default - reporter = report_data["reporter"] - del report_data["reporter"] - - # if we have old reports about this ip, append this one to them - # cached_p2p_reports is a dict - cached_p2p_reports: Dict[str, List[dict]] = ( - self.get_p2p_reports_about_ip(ip) - ) - if cached_p2p_reports: - # was this ip reported by the same peer before? - if reporter in cached_p2p_reports: - # ip was reported before, by the same peer - # did the same peer report the same score and - # confidence about the same ip twice in a row? - last_report_about_this_ip = cached_p2p_reports[reporter][-1] - score = report_data["score"] - confidence = report_data["confidence"] - if ( - last_report_about_this_ip["score"] == score - and last_report_about_this_ip["confidence"] == confidence - ): - report_time = report_data["report_time"] - # score and confidence are the same as the last report, - # only update the time - last_report_about_this_ip["report_time"] = report_time - else: - # score and confidence are the different from the last - # report, add report to the list - cached_p2p_reports[reporter].append(report_data) - else: - # ip was reported before, but not by the same peer - cached_p2p_reports[reporter] = [report_data] - report_data = cached_p2p_reports - else: - # no old reports about this ip - report_data = {reporter: [report_data]} - - self.rcache.hset( - self.constants.P2P_REPORTS, ip, json.dumps(report_data) - ) - def get_dns_resolution(self, ip: str): """ IF this IP was resolved by slips diff --git a/slips_files/core/database/redis_db/p2p_handler.py b/slips_files/core/database/redis_db/p2p_handler.py index 6804e1a5a9..e966a33389 100644 --- a/slips_files/core/database/redis_db/p2p_handler.py +++ b/slips_files/core/database/redis_db/p2p_handler.py @@ -104,3 +104,30 @@ def get_cached_network_opinion( k: v for k, v in cache_data.items() if k != "created_seconds" } return opinion + + def get_p2p_reports_about_ip(self, ip) -> dict: + """ + returns a dict of all p2p past reports about the given ip + """ + # p2p_reports key is basically + # { ip: { reporter1: [report1, report2, report3]} } + if reports := self.rcache.hget(self.constants.P2P_REPORTS, ip): + return json.loads(reports) + return {} + + def set_peer_trust(self, peer_ip, peer_trust): + """ + Set the trust value for a peer in the database. + :param peer_ip: IP address of the peer + :param peer_trust: Trust value to be set as determined by the + trust model + For now, this is only for local peers + """ + + self.r.hset("peer_trust", peer_ip, peer_trust) + + def get_peer_trust(self, peer_ip): + trust = self.r.hget("peer_trust", peer_ip) + if trust: + return float(trust) + return None diff --git a/slips_files/core/database/redis_db/profile_handler.py b/slips_files/core/database/redis_db/profile_handler.py index cd12f974d2..b2411233e5 100644 --- a/slips_files/core/database/redis_db/profile_handler.py +++ b/slips_files/core/database/redis_db/profile_handler.py @@ -1147,14 +1147,8 @@ def add_new_tw(self, profileid, timewindow: str, startoftw: float): 0, 4, ) - # The creation of a TW now does not imply that it was modified. # You need to put data to mark is at modified. - - # When a new TW is created for this profile, - # change the threat level of the profile to 0(info) - # and confidence to 0.05 - self.update_threat_level(profileid, "info", 0.5) except redis.exceptions.ResponseError: self.print("Error in addNewTW", 0, 1) self.print(traceback.format_exc(), 0, 1) @@ -1455,7 +1449,7 @@ def mark_profile_as_dhcp(self, profileid): if not is_dhcp_set: self.r.hset(profileid, "dhcp", "true") - def add_profile(self, profileid, starttime): + def add_profile(self, profileid, starttime, confidence=0.05): """ Add a new profile to the DB. Both the list of profiles and the hashmap of profile data @@ -1477,8 +1471,7 @@ def add_profile(self, profileid, starttime): self.r.hset(profileid, "duration", self.width) # When a new profiled is created assign threat level = 0 # and confidence = 0.05 - confidence = 0.05 - self.update_threat_level(profileid, "info", confidence) + self.r.hset(profileid, "confidence", confidence) # The IP of the profile should also be added as a new IP # we know about. diff --git a/slips_files/core/database/sqlite_db/database.py b/slips_files/core/database/sqlite_db/database.py index 6832846758..3c42cb3132 100644 --- a/slips_files/core/database/sqlite_db/database.py +++ b/slips_files/core/database/sqlite_db/database.py @@ -7,35 +7,26 @@ import json import csv from dataclasses import asdict -from threading import Lock -from time import sleep +from slips_files.common.abstracts.isqlite import ISQLite from slips_files.common.printer import Printer from slips_files.common.slips_utils import utils from slips_files.core.structures.alerts import Alert from slips_files.core.output import Output -class SQLiteDB: +class SQLiteDB(ISQLite): """ Stores all the flows slips reads and handles labeling them Creates a new db and connects to it if there's none in the given output_dir """ name = "SQLiteDB" - # used to lock operations using the self.cursor - cursor_lock = Lock() def __init__(self, logger: Output, output_dir: str): self.printer = Printer(logger, self.name) self._flows_db = os.path.join(output_dir, "flows.sqlite") - self.connect() - def connect(self): - """ - Creates the db if it doesn't exist and connects to it. - OR connects to the existing db if it's there. - """ db_newly_created = False if not os.path.exists(self._flows_db): # db not created, mark it as first time accessing it so we can @@ -43,34 +34,34 @@ def connect(self): db_newly_created = True self._init_db() - # you can get multithreaded access on a single pysqlite connection by - # passing "check_same_thread=False" - self.conn = sqlite3.connect( - self._flows_db, check_same_thread=False, timeout=20 - ) - # enable write-ahead logging for concurrent reads and writes to - # avoid the "DB is locked" error - self.conn.execute("PRAGMA journal_mode=WAL;") - self.cursor = self.conn.cursor() + self.connect() + super().__init__(self.name.lower()) if db_newly_created: # only init tables if the db is newly created self.init_tables() - def get_number_of_tables(self): + def connect(self): """ - returns the number of tables in the current db + Creates the db if it doesn't exist and connects to it. + OR connects to the existing db if it's there. """ - query = "SELECT count(*) FROM sqlite_master WHERE type='table';" - self.execute(query) - x = self.fetchone() - return x[0] + # you can get multithreaded access on a single pysqlite connection by + # passing "check_same_thread=False" + self.conn = sqlite3.connect( + self._flows_db, check_same_thread=False, timeout=20 + ) + self.cursor = self.conn.cursor() def init_tables(self): """creates the tables we're gonna use""" table_schema = { - "flows": "uid TEXT PRIMARY KEY, flow TEXT, label TEXT, profileid TEXT, twid TEXT, aid TEXT", - "altflows": "uid TEXT PRIMARY KEY, flow TEXT, label TEXT, profileid TEXT, twid TEXT, flow_type TEXT", - "alerts": "alert_id TEXT PRIMARY KEY, alert_time TEXT, ip_alerted TEXT, timewindow TEXT, tw_start TEXT, tw_end TEXT, label TEXT", + "flows": "uid TEXT PRIMARY KEY, flow TEXT, label TEXT, profileid " + "TEXT, twid TEXT, aid TEXT", + "altflows": "uid TEXT PRIMARY KEY, flow TEXT, label TEXT, " + "profileid TEXT, twid TEXT, flow_type TEXT", + "alerts": "alert_id TEXT PRIMARY KEY, alert_time TEXT, ip_alerted " + "TEXT, timewindow TEXT, tw_start TEXT, tw_end TEXT, " + "label TEXT", } for table_name, schema in table_schema.items(): self.create_table(table_name, schema) @@ -81,13 +72,6 @@ def _init_db(self): """ open(self._flows_db, "w").close() - def create_table(self, table_name, schema): - query = f"CREATE TABLE IF NOT EXISTS {table_name} ({schema})" - self.execute(query) - - def print(self, *args, **kwargs): - return self.printer.print(*args, **kwargs) - def get_db_path(self) -> str: """ returns the path of the sqlite flows db placed in the output dir @@ -214,7 +198,8 @@ def iterate_flows(self): def row_generator(): # select all flows and altflows self.execute( - "SELECT * FROM flows UNION SELECT uid, flow, label, profileid, twid FROM altflows" + "SELECT * FROM flows UNION SELECT uid, flow, label, profileid," + " twid FROM altflows" ) while True: @@ -296,7 +281,8 @@ def add_altflow(self, flow, profileid: str, twid: str, label="benign"): flow.type_, ) self.execute( - "INSERT OR REPLACE INTO altflows (profileid, twid, uid, flow, label, flow_type) " + "INSERT OR REPLACE INTO altflows (profileid, twid, uid, " + "flow, label, flow_type) " "VALUES (?, ?, ?, ?, ?, ?);", parameters, ) @@ -322,116 +308,3 @@ def add_alert(self, alert: Alert): now, ), ) - - def insert(self, table_name, values): - query = f"INSERT INTO {table_name} VALUES ({values})" - self.execute(query) - - def update(self, table_name, set_clause, condition): - query = f"UPDATE {table_name} SET {set_clause} WHERE {condition}" - self.execute(query) - - def delete(self, table_name, condition): - query = f"DELETE FROM {table_name} WHERE {condition}" - self.execute(query) - - def select(self, table_name, columns="*", condition=None): - query = f"SELECT {columns} FROM {table_name}" - if condition: - query += f" WHERE {condition}" - self.execute(query) - result = self.fetchall() - return result - - def get_count(self, table, condition=None): - """ - returns th enumber of matching rows in the given table based on a specific contioins - """ - query = f"SELECT COUNT(*) FROM {table}" - - if condition: - query += f" WHERE {condition}" - - self.execute(query) - return self.fetchone()[0] - - def close(self): - self.cursor.close() - self.conn.close() - - def fetchall(self): - """ - wrapper for sqlite fetchall to be able to use a lock - """ - with self.cursor_lock: - res = self.cursor.fetchall() - return res - - def fetchone(self): - """ - wrapper for sqlite fetchone to be able to use a lock - """ - with self.cursor_lock: - res = self.cursor.fetchone() - return res - - def execute(self, query: str, params=None) -> None: - """ - wrapper for sqlite execute() To avoid - 'Recursive use of cursors not allowed' error - and to be able to use a Lock() - - since sqlite is terrible with multi-process applications - this function should be used instead of all calls to commit() and - execute() - - using transactions here is a must. - Since slips uses python3.10, we can't use autocommit here. we have - to do it manually - any conn other than the current one will not see the changes this - conn did unless they're committed. - - Each call to this function results in 1 sqlite transaction - """ - trial = 0 - max_trials = 5 - while trial < max_trials: - try: - # note that self.conn.in_transaction is not reliable - # sqlite may change the state internally, on errors for - # example. - # if no errors occur, this will be the only transaction in - # the conn - with self.cursor_lock: - if self.conn.in_transaction is False: - self.cursor.execute("BEGIN") - - if params is None: - self.cursor.execute(query) - else: - self.cursor.execute(query, params) - - # aka END TRANSACTION - if self.conn.in_transaction: - self.conn.commit() - - return - - except sqlite3.Error as err: - # no need to manually rollback here - # sqlite automatically rolls back the tx if an error occurs - trial += 1 - - if trial >= max_trials: - self.print( - f"Error executing query: " - f"({query} {params}) - {err}. " - f"Retried executing {trial} times but failed. " - f"Query discarded.", - 0, - 1, - ) - return - - elif "database is locked" in str(err): - sleep(5) diff --git a/slips_files/core/evidence_handler.py b/slips_files/core/evidence_handler.py index 199177c222..8f7271f2b8 100644 --- a/slips_files/core/evidence_handler.py +++ b/slips_files/core/evidence_handler.py @@ -41,7 +41,7 @@ from slips_files.common.slips_utils import utils from slips_files.core.helpers.whitelist.whitelist import Whitelist from slips_files.core.helpers.notify import Notify -from slips_files.common.abstracts.core import ICore +from slips_files.common.abstracts.icore import ICore from slips_files.core.structures.evidence import ( dict_to_evidence, Evidence, @@ -98,7 +98,7 @@ def init(self): self.jsonfile = self.clean_file(self.output_dir, "alerts.json") utils.change_logfiles_ownership(self.jsonfile.name, self.UID, self.GID) # this list will have our local and public ips when using -i - self.our_ips: List[str] = utils.get_own_ips(ret=List) + self.our_ips: List[str] = utils.get_own_ips(ret="List") self.formatter = EvidenceFormatter(self.db) # thats just a tmp value, this variable will be set and used when # the diff --git a/slips_files/core/helpers/whitelist/domain_whitelist.py b/slips_files/core/helpers/whitelist/domain_whitelist.py index 731df4b327..1cd086455a 100644 --- a/slips_files/core/helpers/whitelist/domain_whitelist.py +++ b/slips_files/core/helpers/whitelist/domain_whitelist.py @@ -3,7 +3,7 @@ from typing import List, Dict import tldextract -from slips_files.common.abstracts.whitelist_analyzer import IWhitelistAnalyzer +from slips_files.common.abstracts.iwhitelist_analyzer import IWhitelistAnalyzer from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils from slips_files.core.structures.evidence import ( diff --git a/slips_files/core/helpers/whitelist/ip_whitelist.py b/slips_files/core/helpers/whitelist/ip_whitelist.py index 17fc13edc2..22d2dcd46d 100644 --- a/slips_files/core/helpers/whitelist/ip_whitelist.py +++ b/slips_files/core/helpers/whitelist/ip_whitelist.py @@ -3,7 +3,7 @@ import ipaddress from typing import List, Dict -from slips_files.common.abstracts.whitelist_analyzer import IWhitelistAnalyzer +from slips_files.common.abstracts.iwhitelist_analyzer import IWhitelistAnalyzer from slips_files.common.parsers.config_parser import ConfigParser from slips_files.core.structures.evidence import ( Direction, diff --git a/slips_files/core/helpers/whitelist/mac_whitelist.py b/slips_files/core/helpers/whitelist/mac_whitelist.py index 0a3b672eee..126c6a2cc5 100644 --- a/slips_files/core/helpers/whitelist/mac_whitelist.py +++ b/slips_files/core/helpers/whitelist/mac_whitelist.py @@ -4,7 +4,7 @@ import validators -from slips_files.common.abstracts.whitelist_analyzer import IWhitelistAnalyzer +from slips_files.common.abstracts.iwhitelist_analyzer import IWhitelistAnalyzer from slips_files.common.parsers.config_parser import ConfigParser from slips_files.core.structures.evidence import ( Direction, diff --git a/slips_files/core/helpers/whitelist/organization_whitelist.py b/slips_files/core/helpers/whitelist/organization_whitelist.py index 3aa86c08f2..c2612238d7 100644 --- a/slips_files/core/helpers/whitelist/organization_whitelist.py +++ b/slips_files/core/helpers/whitelist/organization_whitelist.py @@ -8,7 +8,7 @@ Union, ) -from slips_files.common.abstracts.whitelist_analyzer import IWhitelistAnalyzer +from slips_files.common.abstracts.iwhitelist_analyzer import IWhitelistAnalyzer from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils from slips_files.core.structures.evidence import ( diff --git a/slips_files/core/input.py b/slips_files/core/input.py index 36d646d5e7..0e4563dd59 100644 --- a/slips_files/core/input.py +++ b/slips_files/core/input.py @@ -30,7 +30,7 @@ from watchdog.observers import Observer -from slips_files.common.abstracts.core import ICore +from slips_files.common.abstracts.icore import ICore # common imports for all modules from slips_files.common.parsers.config_parser import ConfigParser @@ -142,7 +142,7 @@ def read_nfdump_output(self) -> int: """ A binary file generated by nfcapd can be read by nfdump. The task for this function is to send nfdump output line by line to - performance_profiler.py for processing + iperformance_profiler.py for processing """ if not self.nfdump_output: # The nfdump command returned nothing diff --git a/slips_files/core/input_profilers/argus.py b/slips_files/core/input_profilers/argus.py index f8cc2d458a..27ea9d7273 100644 --- a/slips_files/core/input_profilers/argus.py +++ b/slips_files/core/input_profilers/argus.py @@ -3,7 +3,7 @@ import sys import traceback -from slips_files.common.abstracts.input_type import IInputType +from slips_files.common.abstracts.iinput_type import IInputType from slips_files.common.slips_utils import utils from slips_files.core.flows.argus import ArgusConn diff --git a/slips_files/core/input_profilers/nfdump.py b/slips_files/core/input_profilers/nfdump.py index 016077430c..80d9533d94 100644 --- a/slips_files/core/input_profilers/nfdump.py +++ b/slips_files/core/input_profilers/nfdump.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only -from slips_files.common.abstracts.input_type import IInputType +from slips_files.common.abstracts.iinput_type import IInputType from slips_files.common.slips_utils import utils from slips_files.core.flows.nfdump import NfdumpConn diff --git a/slips_files/core/input_profilers/suricata.py b/slips_files/core/input_profilers/suricata.py index e5ed50c77d..ca2eeb592c 100644 --- a/slips_files/core/input_profilers/suricata.py +++ b/slips_files/core/input_profilers/suricata.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: GPL-2.0-only import json -from slips_files.common.abstracts.input_type import IInputType +from slips_files.common.abstracts.iinput_type import IInputType from slips_files.common.slips_utils import utils from slips_files.core.flows.suricata import ( SuricataFlow, diff --git a/slips_files/core/input_profilers/zeek.py b/slips_files/core/input_profilers/zeek.py index 54dab22a1c..21b661957c 100644 --- a/slips_files/core/input_profilers/zeek.py +++ b/slips_files/core/input_profilers/zeek.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: GPL-2.0-only from re import split from typing import Dict -from slips_files.common.abstracts.input_type import IInputType +from slips_files.common.abstracts.iinput_type import IInputType from slips_files.common.slips_utils import utils from slips_files.core.flows.zeek import ( Conn, diff --git a/slips_files/core/output.py b/slips_files/core/output.py index 32c657ceef..edfd8d3029 100644 --- a/slips_files/core/output.py +++ b/slips_files/core/output.py @@ -21,7 +21,7 @@ from datetime import datetime import os -from slips_files.common.abstracts.observer import IObserver +from slips_files.common.abstracts.iobserver import IObserver from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils from slips_files.common.style import red, yellow diff --git a/slips_files/core/profiler.py b/slips_files/core/profiler.py index 5ff049c8fb..4619f940b1 100644 --- a/slips_files/core/profiler.py +++ b/slips_files/core/profiler.py @@ -33,10 +33,10 @@ import netifaces import validators from ipaddress import IPv4Network, IPv6Network, IPv4Address, IPv6Address -from slips_files.common.abstracts.observer import IObservable +from slips_files.common.abstracts.iobserver import IObservable from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils -from slips_files.common.abstracts.core import ICore +from slips_files.common.abstracts.icore import ICore from slips_files.common.style import green from slips_files.core.helpers.flow_handler import FlowHandler from slips_files.core.helpers.symbols_handler import SymbolHandler diff --git a/tests/module_factory.py b/tests/module_factory.py index 2ce30d091f..5493cf71d0 100644 --- a/tests/module_factory.py +++ b/tests/module_factory.py @@ -13,6 +13,7 @@ from managers.host_ip_manager import HostIPManager from managers.metadata_manager import MetadataManager from managers.profilers_manager import ProfilersManager +from modules.arp.filter import ARPEvidenceFilter from modules.arp_poisoner.arp_poisoner import ARPPoisoner from modules.blocking.unblocker import Unblocker from modules.flowalerts.conn import Conn @@ -100,8 +101,7 @@ def check_zeek_or_bro(): return False -MODULE_DB_MANAGER = "slips_files.common.abstracts.module.DBManager" -# CORE_DB_MANAGER = "slips_files.common.abstracts.core.DBManager" +MODULE_DB_MANAGER = "slips_files.common.abstracts.imodule.DBManager" DB_MANAGER = "slips_files.core.database.database_manager.DBManager" @@ -132,10 +132,12 @@ def create_db_manager_obj( "RedisDB._set_redis_options", return_value=Mock(), ): + conf = Mock() db = DBManager( self.logger, output_dir, port, + conf, flush_db=flush_db, start_sqlite=False, start_redis_server=start_redis_server, @@ -163,6 +165,7 @@ def create_http_analyzer_obj(self, mock_db): 6379, Mock(), # termination event Mock(), # args + Mock(), # conf ) # override the self.print function to avoid broken pipes @@ -177,6 +180,7 @@ def create_fidesModule_obj(self, mock_db): 6379, Mock(), # termination event Mock(), # args + Mock(), # conf ) # override the self.print function @@ -191,6 +195,7 @@ def create_virustotal_obj(self, mock_db): 6379, Mock(), # termination event Mock(), # args + Mock(), # conf ) virustotal.print = Mock() virustotal.__read_configuration = Mock() @@ -207,10 +212,17 @@ def create_arp_obj(self, mock_db): 6379, Mock(), # termination event Mock(), # args + Mock(), # conf ) arp.print = Mock() + arp.evidence_filter.is_slips_peer = Mock(return_value=False) return arp + @patch(MODULE_DB_MANAGER, name="mock_db") + def create_arp_filter_obj(self, mock_db): + filter = ARPEvidenceFilter(Mock(), Mock(), mock_db) # conf # args + return filter + @patch(MODULE_DB_MANAGER, name="mock_db") def create_blocking_obj(self, mock_db): blocking = Blocking( @@ -219,6 +231,7 @@ def create_blocking_obj(self, mock_db): 6379, Mock(), # termination event Mock(), # args + Mock(), # conf ) # override the print function to avoid broken pipes blocking.print = Mock() @@ -246,6 +259,7 @@ def create_flowalerts_obj(self, mock_db): 6379, Mock(), # termination event Mock(), # args + Mock(), # conf ) # override the self.print function to avoid broken pipes @@ -306,8 +320,9 @@ def create_input_obj( Output(), "dummy_output_dir", 6379, - Mock(), # args Mock(), # termination event + Mock(), # args + Mock(), # args is_input_done=Mock(), profiler_queue=self.profiler_queue, input_type=input_type, @@ -334,8 +349,9 @@ def create_ip_info_obj(self, mock_db): self.logger, "dummy_output_dir", 6379, - Mock(), # args Mock(), # termination event + Mock(), # args + Mock(), # conf ) # override the self.print function to avoid broken pipes ip_info.print = Mock() @@ -356,8 +372,9 @@ def create_leak_detector_obj(self, mock_db): self.logger, "dummy_output_dir", 6379, - Mock(), # args Mock(), # termination event + Mock(), # args + Mock(), # conf ) leak_detector.print = Mock() # this is the path containing 1 yara rule for testing, @@ -373,8 +390,9 @@ def create_profiler_obj(self, mock_db): self.logger, "output/", 6379, - Mock(), # args Mock(), # termination event + Mock(), # args + Mock(), # conf is_profiler_done=Mock(), profiler_queue=self.input_queue, is_profiler_done_event=Mock(), @@ -415,8 +433,9 @@ def create_threatintel_obj(self, mock_db): self.logger, "dummy_output_dir", 6379, - Mock(), # args Mock(), # termination event + Mock(), # args + Mock(), # conf ) # override the self.print function to avoid broken pipes @@ -433,8 +452,9 @@ def create_update_manager_obj(self, mock_db): self.logger, "dummy_output_dir", 6379, - Mock(), # args Mock(), # termination event + Mock(), # args + Mock(), # conf ) # override the self.print function to avoid broken pipes update_manager.print = Mock() @@ -549,8 +569,9 @@ def create_network_discovery_obj(self, mock_db): self.logger, "dummy_output_dir", 6379, - Mock(), # args Mock(), # termination event + Mock(), # args + Mock(), # conf ) return network_discovery @@ -562,6 +583,7 @@ def create_arp_poisoner_obj(self, mock_db): 6379, Mock(), # termination event Mock(), # args + Mock(), # conf ) return poisoner @@ -625,10 +647,11 @@ def create_trust_db_obj(self, sqlite_mock): trust_db.print = Mock() return trust_db - def create_base_model_obj(self): + @patch(MODULE_DB_MANAGER, name="mock_db") + def create_base_model_obj(self, mock_db): logger = Mock(spec=Output) trustdb = Mock() - return BaseModel(logger, trustdb) + return BaseModel(logger, trustdb, mock_db) def create_notify_obj(self): notify = Notify() @@ -652,6 +675,7 @@ def create_cesnet_obj(self, mock_db): redis_port, Mock(), # termination event Mock(), # args + Mock(), # conf ) cesnet.db = mock_db cesnet.wclient = MagicMock() @@ -673,6 +697,7 @@ def create_evidence_handler_obj(self, mock_db): redis_port, Mock(), # termination event Mock(), # args + Mock(), # conf ) handler.db = mock_db return handler @@ -695,6 +720,7 @@ def create_riskiq_obj(self, mock_db): 6379, Mock(), # termination event Mock(), # args + Mock(), # conf ) riskiq.db = mock_db return riskiq @@ -710,6 +736,7 @@ def create_timeline_object(self, mock_db): redis_port, Mock(), # termination event Mock(), # args + Mock(), # conf ) tl.db = mock_db return tl diff --git a/tests/test_arp_filter.py b/tests/test_arp_filter.py new file mode 100644 index 0000000000..b0ec4d44df --- /dev/null +++ b/tests/test_arp_filter.py @@ -0,0 +1,71 @@ +import pytest +from unittest.mock import Mock, patch + +from modules.arp.filter import ARPEvidenceFilter +from tests.module_factory import ModuleFactory + + +@pytest.mark.parametrize( + "p2p_enabled, is_private, peer_trust, expected", + [ + (True, True, 0.5, True), + (True, True, 0.2, False), + (True, False, 0.5, False), + (False, True, 0.5, False), + (True, True, None, False), + ], +) +def test_is_slips_peer(p2p_enabled, is_private, peer_trust, expected): + arp = ModuleFactory().create_arp_filter_obj() + with patch( + "slips_files.core.profiler.utils.is_private_ip", + return_value=is_private, + ), patch.object( + arp.db, + "get_peer_trust", + return_value=peer_trust, + ): + arp.p2p_enabled = p2p_enabled + assert arp.is_slips_peer("192.168.1.100") == expected + + +@pytest.mark.parametrize( + "ip, our_ips, blocking, has_poisoner, expected", + [ + ("192.168.1.10", ["192.168.1.10"], True, True, True), + ("192.168.1.10", ["192.168.1.10"], True, False, False), + ("192.168.1.10", ["192.168.1.20"], True, True, False), + ("192.168.1.10", ["192.168.1.10"], False, True, False), + ], +) +def test_is_self_defense(ip, our_ips, blocking, has_poisoner, expected): + db = Mock() + db.get_pids.return_value = {"ARP Poisoner": 123} if has_poisoner else {} + args = Mock() + args.blocking = blocking + + arp = ARPEvidenceFilter(conf=Mock(), slips_args=args, db=db) + arp.our_ips = our_ips + + assert arp.is_self_defense(ip) == expected + + +@pytest.mark.parametrize( + "is_slips_peer_return, is_self_defense_return, expected_result", + [ + (False, False, False), + (True, False, True), + (False, True, True), + ], +) +def test_should_discard_evidence_combines_both_checks( + is_slips_peer_return, is_self_defense_return, expected_result +): + arp = ModuleFactory().create_arp_filter_obj() + with patch.object( + arp, "is_slips_peer", return_value=is_slips_peer_return + ), patch.object( + arp, "is_self_defense", return_value=is_self_defense_return + ): + result = arp.should_discard_evidence("1.2.3.4") + assert result == expected_result diff --git a/tests/test_base_model.py b/tests/test_base_model.py index 8bbe9a1f9c..091e4db26e 100644 --- a/tests/test_base_model.py +++ b/tests/test_base_model.py @@ -78,19 +78,22 @@ def test_compute_peer_trust(reliability, score, confidence, expected_trust): "data, expected_score, expected_confidence", [ # testcase1: assemble opinion with one report - ([(0.8, 0.9, 0.7, 0.8, 0.9)], 0.8, 0.5445), + ([(0.8, 0.9, 0.7, 0.8, 0.9, "192.168.1.2")], 0.8, 0.5445), # testcase2: assemble opinion with multiple reports ( - [(0.6, 0.7, 0.8, 0.7, 0.8), (0.7, 0.8, 0.9, 0.8, 0.9)], + [ + (0.6, 0.7, 0.8, 0.7, 0.8, "192.168.1.2"), + (0.7, 0.8, 0.9, 0.8, 0.9, "192.168.1.3"), + ], 0.6517774343122101, 0.46599999999999997, ), # testcase3: assemble opinion with diverse reports ( [ - (0.9, 0.8, 0.6, 0.7, 0.8), - (0.5, 0.6, 0.9, 0.8, 0.7), - (0.3, 0.4, 0.7, 0.6, 0.5), + (0.9, 0.8, 0.6, 0.7, 0.8, "192.168.1.2"), + (0.5, 0.6, 0.9, 0.8, 0.7, "192.168.1.3"), + (0.3, 0.4, 0.7, 0.6, 0.5, "192.168.1.4"), ], 0.5707589285714285, 0.30233333333333334, diff --git a/tests/test_go_director.py b/tests/test_go_director.py index 8230d79e1e..eb8ed9e92c 100644 --- a/tests/test_go_director.py +++ b/tests/test_go_director.py @@ -484,8 +484,6 @@ def test_process_evaluation_score_confidence_valid(): with patch.object(go_director, "print") as mock_print, patch.object( go_director.trustdb, "insert_new_go_report" ) as mock_insert, patch.object( - go_director.db, "store_p2p_report" - ) as mock_store, patch.object( go_director.db, "add_profile" ) as mock_add_profile, patch.object( go_director, "set_evidence_p2p_report" @@ -496,7 +494,6 @@ def test_process_evaluation_score_confidence_valid(): mock_print.assert_called_with(expected_result, 2, 0) mock_insert.assert_called_once() - mock_store.assert_called_once() mock_add_profile.assert_called_once() mock_set_evidence.assert_called_once() diff --git a/tests/test_process_manager.py b/tests/test_process_manager.py index bdc0b08a74..a8fe335c9e 100644 --- a/tests/test_process_manager.py +++ b/tests/test_process_manager.py @@ -49,6 +49,7 @@ def test_start_input_process( process_manager.main.redis_port, process_manager.termination_event, process_manager.main.args, + process_manager.main.conf, is_input_done=process_manager.is_input_done, profiler_queue=process_manager.profiler_queue, input_type=input_type, @@ -404,6 +405,7 @@ def test_start_profiler_process(): process_manager.main.redis_port, process_manager.termination_event, process_manager.main.args, + process_manager.main.conf, is_profiler_done=process_manager.is_profiler_done, profiler_queue=process_manager.profiler_queue, is_profiler_done_event=process_manager.is_profiler_done_event, @@ -443,6 +445,7 @@ def test_start_evidence_process(output_dir, redis_port): redis_port, process_manager.evidence_handler_termination_event, process_manager.main.args, + process_manager.main.conf, ) mock_evidence_process.start.assert_called_once() process_manager.main.print.assert_called_once() diff --git a/tests/test_profile_handler.py b/tests/test_profile_handler.py index e2a7625eb3..222c4fdd1c 100644 --- a/tests/test_profile_handler.py +++ b/tests/test_profile_handler.py @@ -801,14 +801,9 @@ def test_add_new_tw( expected_update_threat_level_call, ): handler = ModuleFactory().create_profile_handler_obj() - handler.update_threat_level = MagicMock() - handler.add_new_tw(profileid, timewindow, startoftw) handler.r.zadd.assert_called_once_with(*expected_zadd_call.args) - handler.update_threat_level.assert_called_once_with( - *expected_update_threat_level_call.args - ) @pytest.mark.parametrize( @@ -2408,7 +2403,6 @@ def test_add_profile_new_profile(): handler.set_new_ip = MagicMock() handler.publish = MagicMock() - handler.update_threat_level = MagicMock() profileid = "profile_1" starttime = 1678886400.0 @@ -2430,9 +2424,6 @@ def test_add_profile_new_profile(): ip = profileid.split(handler.separator)[1] handler.set_new_ip.assert_called_once_with(ip) handler.publish.assert_called_once_with("new_profile", ip) - handler.update_threat_level.assert_called_once_with( - profileid, "info", 0.05 - ) def test_add_profile_existing_profile(): diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 0369d5b717..ec086e4b41 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only -"""Unit test for slips_files/core/performance_profiler.py""" +"""Unit test for slips_files/core/iperformance_profiler.py""" from unittest.mock import Mock diff --git a/tests/test_slips_utils.py b/tests/test_slips_utils.py index 9c467da833..e29154c05a 100644 --- a/tests/test_slips_utils.py +++ b/tests/test_slips_utils.py @@ -10,7 +10,6 @@ import pytz import json from collections import namedtuple -from typing import List def test_get_sha256_hash(): @@ -293,14 +292,14 @@ def _check_ip_presence(utils, expected_ip): in the list of own IPs. """ return ( - expected_ip in utils.get_own_ips(ret=List) or not utils.get_own_ips() + expected_ip in utils.get_own_ips(ret="List") or not utils.get_own_ips() ) def test_get_own_ips_success(): """Test that the function returns a list when successful.""" utils = ModuleFactory().create_utils_obj() - ips = utils.get_own_ips(ret=List) + ips = utils.get_own_ips(ret="List") assert isinstance(ips, list), "Should return a list of IPs" diff --git a/tests/test_trustdb.py b/tests/test_trustdb.py index 2ee9040490..d4ad2bb24b 100644 --- a/tests/test_trustdb.py +++ b/tests/test_trustdb.py @@ -9,7 +9,10 @@ ) from tests.module_factory import ModuleFactory import datetime -import time + + +def normalize_sql(sql): + return " ".join(sql.strip().split()) @pytest.mark.parametrize( @@ -33,8 +36,8 @@ ) def test_delete_tables(existing_tables): trust_db = ModuleFactory().create_trust_db_obj() - trust_db.conn.execute = Mock() - trust_db.conn.execute.side_effect = lambda query: ( + trust_db.execute = Mock() + trust_db.execute.side_effect = lambda query: ( None if query.startswith("DROP TABLE") else ["table"] ) trust_db.conn.fetchall = Mock() @@ -49,7 +52,7 @@ def test_delete_tables(existing_tables): ] trust_db.delete_tables() - assert trust_db.conn.execute.call_args_list == expected_calls + assert trust_db.execute.call_args_list == expected_calls @pytest.mark.parametrize( @@ -78,8 +81,8 @@ def test_get_cached_network_opinion( expected_result, ): trust_db = ModuleFactory().create_trust_db_obj() - trust_db.conn.execute = Mock() - trust_db.conn.execute.return_value.fetchone.return_value = fetchone_result + trust_db.fetchone = Mock() + trust_db.fetchone.return_value = fetchone_result result = trust_db.get_cached_network_opinion(key_type, reported_key) assert result == expected_result @@ -123,106 +126,11 @@ def test_update_cached_network_opinion( expected_params, ): trust_db = ModuleFactory().create_trust_db_obj() - trust_db.conn.execute = Mock() - trust_db.conn.commit = Mock() + trust_db.execute = Mock() trust_db.update_cached_network_opinion( key_type, reported_key, score, confidence, network_score ) - trust_db.conn.execute.assert_called_once_with( - expected_query, expected_params - ) - trust_db.conn.commit.assert_called_once() - - -@pytest.mark.parametrize( - "reports, expected_calls", - [ - ( - # Testcase 1: Single report - [ - ( - "reporter_1", - "ip", - "192.168.1.1", - 0.5, - 0.8, - 1678886400, # Fixed timestamp - ) - ], - [ - call( - "INSERT INTO reports " - "(reporter_peerid, key_type, reported_key, " - "score, confidence, update_time) " - "VALUES (?, ?, ?, ?, ?, ?)", - [ - ( - "reporter_1", - "ip", - "192.168.1.1", - 0.5, - 0.8, - 1678886400, - ) - ], - ) - ], - ), - ( - # Testcase 2: Multiple reports - [ - ( - "reporter_1", - "ip", - "192.168.1.1", - 0.5, - 0.8, - 1678886400, - ), - ( - "reporter_2", - "peerid", - "another_peer", - 0.3, - 0.6, - 1678886500, - ), - ], - [ - call( - "INSERT INTO reports " - "(reporter_peerid, key_type, reported_key, " - "score, confidence, update_time) " - "VALUES (?, ?, ?, ?, ?, ?)", - [ - ( - "reporter_1", - "ip", - "192.168.1.1", - 0.5, - 0.8, - 1678886400, - ), - ( - "reporter_2", - "peerid", - "another_peer", - 0.3, - 0.6, - 1678886500, - ), - ], - ) - ], - ), - ], -) -def test_insert_new_go_data(reports, expected_calls): - trust_db = ModuleFactory().create_trust_db_obj() - trust_db.conn.executemany = Mock() - trust_db.insert_new_go_data(reports) - trust_db.conn.executemany.assert_has_calls(expected_calls) - assert trust_db.conn.executemany.call_count == len(expected_calls) + trust_db.execute.assert_called_once_with(expected_query, expected_params) @pytest.mark.parametrize( @@ -249,15 +157,11 @@ def test_insert_new_go_data(reports, expected_calls): ) def test_insert_go_ip_pairing(peerid, ip, timestamp, expected_params): trust_db = ModuleFactory().create_trust_db_obj() - trust_db.conn.execute = Mock() - trust_db.conn.commit = Mock() + trust_db.insert = Mock() trust_db.insert_go_ip_pairing(peerid, ip, timestamp) - trust_db.conn.execute.assert_called_once_with( - "INSERT INTO peer_ips (ipaddress, peerid, " - "update_time) VALUES (?, ?, ?);", - expected_params, + trust_db.insert.assert_called_once_with( + "peer_ips", (ip, peerid, timestamp), "ipaddress, peerid, update_time" ) - trust_db.conn.commit.assert_called_once() @pytest.mark.parametrize( @@ -273,21 +177,15 @@ def test_insert_slips_score( ip, score, confidence, timestamp, expected_timestamp ): trust_db = ModuleFactory().create_trust_db_obj() - with patch.object(time, "time", return_value=time.time()) as mock_time: - trust_db.insert_slips_score(ip, score, confidence, timestamp) - expected_params = ( - ip, - score, - confidence, - timestamp or mock_time.return_value, - ) - - trust_db.conn.execute.assert_called_once_with( - "INSERT INTO slips_reputation (ipaddress, score, confidence, " - "update_time) VALUES (?, ?, ?, ?);", - expected_params, - ) - trust_db.conn.commit.assert_called_once() + trust_db.execute = Mock() + trust_db.insert_slips_score(ip, score, confidence, timestamp) + actual_call = trust_db.execute.call_args + actual_sql = normalize_sql(actual_call[0][0]) + expected_sql = normalize_sql( + "INSERT OR REPLACE INTO slips_reputation (ipaddress, score, " + "confidence, update_time) VALUES (?, ?, ?, ?)" + ) + assert actual_sql == expected_sql @pytest.mark.parametrize( @@ -307,20 +205,13 @@ def test_insert_go_reliability( datetime, "datetime", wraps=datetime.datetime ) as mock_datetime: mock_datetime.now.return_value = expected_timestamp + trust_db.insert = Mock() trust_db.insert_go_reliability(peerid, reliability, timestamp) - - expected_params = ( - peerid, - reliability, - timestamp or expected_timestamp, - ) - - trust_db.conn.execute.assert_called_once_with( - "INSERT INTO go_reliability (peerid, reliability, " - "update_time) VALUES (?, ?, ?);", - expected_params, + trust_db.insert.assert_called_once_with( + "go_reliability", + (peerid, reliability, expected_timestamp), + "peerid, reliability, update_time", ) - trust_db.conn.commit.assert_called_once() @pytest.mark.parametrize( @@ -338,68 +229,51 @@ def test_insert_go_reliability( ) def test_get_ip_of_peer(peerid, fetchone_result, expected_result): trust_db = ModuleFactory().create_trust_db_obj() - trust_db.conn.execute = Mock() - trust_db.conn.execute.return_value.fetchone.return_value = fetchone_result + trust_db.select = Mock() + trust_db.select.return_value = fetchone_result result = trust_db.get_ip_of_peer(peerid) assert result == expected_result def test_create_tables(): trust_db = ModuleFactory().create_trust_db_obj() + trust_db.create_table = Mock() + + trust_db.create_tables() expected_calls = [ - call( - "CREATE TABLE IF NOT EXISTS slips_reputation (" - "id INTEGER PRIMARY KEY NOT NULL, " - "ipaddress TEXT NOT NULL, " - "score REAL NOT NULL, " - "confidence REAL NOT NULL, " - "update_time REAL NOT NULL);" + ( + "slips_reputation", + "id INTEGER PRIMARY KEY NOT NULL, ipaddress TEXT NOT NULL, score REAL NOT NULL, confidence REAL NOT NULL, update_time REAL NOT NULL", ), - call( - "CREATE TABLE IF NOT EXISTS go_reliability (" - "id INTEGER PRIMARY KEY NOT NULL, " - "peerid TEXT NOT NULL, " - "reliability REAL NOT NULL, " - "update_time REAL NOT NULL);" + ( + "go_reliability", + "id INTEGER PRIMARY KEY NOT NULL, peerid TEXT NOT NULL, reliability REAL NOT NULL, update_time REAL NOT NULL", ), - call( - "CREATE TABLE IF NOT EXISTS peer_ips (" - "id INTEGER PRIMARY KEY NOT NULL, " - "ipaddress TEXT NOT NULL, " - "peerid TEXT NOT NULL, " - "update_time REAL NOT NULL);" + ( + "peer_ips", + "id INTEGER PRIMARY KEY NOT NULL, ipaddress TEXT NOT NULL, peerid TEXT NOT NULL, update_time REAL NOT NULL", ), - call( - "CREATE TABLE IF NOT EXISTS reports (" - "id INTEGER PRIMARY KEY NOT NULL, " - "reporter_peerid TEXT NOT NULL, " - "key_type TEXT NOT NULL, " - "reported_key TEXT NOT NULL, " - "score REAL NOT NULL, " - "confidence REAL NOT NULL, " - "update_time REAL NOT NULL);" + ( + "reports", + "id INTEGER PRIMARY KEY NOT NULL, reporter_peerid TEXT NOT NULL, key_type TEXT NOT NULL, reported_key TEXT NOT NULL, score REAL NOT NULL, confidence REAL NOT NULL, update_time REAL NOT NULL", ), - call( - "CREATE TABLE IF NOT EXISTS opinion_cache (" - "key_type TEXT NOT NULL, " - "reported_key TEXT NOT NULL PRIMARY KEY, " - "score REAL NOT NULL, " - "confidence REAL NOT NULL, " - "network_score REAL NOT NULL, " - "update_time DATE NOT NULL);" + ( + "opinion_cache", + "key_type TEXT NOT NULL, reported_key TEXT NOT NULL PRIMARY KEY, score REAL NOT NULL, confidence REAL NOT NULL, network_score REAL NOT NULL, update_time DATE NOT NULL", ), ] - trust_db.conn.execute = Mock() - trust_db.create_tables() - trust_db.conn.execute.assert_has_calls(expected_calls, any_order=True) + + for table, schema in expected_calls: + trust_db.create_table.assert_any_call(table, schema) + + assert trust_db.create_table.call_count == len(expected_calls) @pytest.mark.parametrize( "reporter_peerid, key_type, reported_key, score, confidence, " "timestamp, expected_query, expected_params", [ - # Testcase 1: Using provided timestamp ( "peer_123", "ip", @@ -411,7 +285,6 @@ def test_create_tables(): "score, confidence, update_time) VALUES (?, ?, ?, ?, ?, ?)", ("peer_123", "ip", "192.168.1.1", 0.8, 0.9, 1678887000), ), - # Testcase 2: Using current time as timestamp ( "another_peer", "peerid", @@ -436,9 +309,19 @@ def test_insert_new_go_report( expected_params, ): trust_db = ModuleFactory().create_trust_db_obj() - trust_db.conn.execute = Mock() - trust_db.conn.commit = Mock() - with patch("time.time", return_value=1678887000.0): + trust_db.insert = Mock() + + if timestamp is None: + with patch("time.time", return_value=1678887000.0): + trust_db.insert_new_go_report( + reporter_peerid, + key_type, + reported_key, + score, + confidence, + timestamp, + ) + else: trust_db.insert_new_go_report( reporter_peerid, key_type, @@ -447,21 +330,18 @@ def test_insert_new_go_report( confidence, timestamp, ) - trust_db.conn.execute.assert_called_once() - actual_query, actual_params = trust_db.conn.execute.call_args[0] - assert actual_query == expected_query - assert actual_params[:-1] == expected_params[:-1] - assert isinstance(actual_params[-1], (float, int)) - assert abs(actual_params[-1] - expected_params[-1]) < 0.001 - trust_db.conn.commit.assert_called_once() + + trust_db.insert.assert_called_once_with( + "reports", + expected_params, + "reporter_peerid, key_type, reported_key, score, confidence, update_time", + ) @pytest.mark.parametrize( "ipaddress, expected_reports", [ - # Testcase 1: No reports for the IP ("192.168.1.1", []), - # Testcase 2: One report ( "192.168.1.1", [ @@ -474,7 +354,6 @@ def test_insert_new_go_report( ) ], ), - # Testcase 3: Multiple reports ( "192.168.1.1", [ @@ -498,26 +377,36 @@ def test_insert_new_go_report( ) def test_get_reports_for_ip(ipaddress, expected_reports): trust_db = ModuleFactory().create_trust_db_obj() - trust_db.conn.execute = Mock() - trust_db.conn.execute.return_value.fetchall.return_value = expected_reports - reports = trust_db.get_reports_for_ip(ipaddress) - assert reports == expected_reports + trust_db.select = Mock(return_value=expected_reports) + + result = trust_db.get_reports_for_ip(ipaddress) + + trust_db.select.assert_called_once_with( + table_name="reports", + columns="reporter_peerid, update_time, score, confidence, reported_key", + condition="reported_key = ? AND key_type = ?", + params=(ipaddress, "ip"), + ) + + assert result == expected_reports @pytest.mark.parametrize( "reporter_peerid, expected_reliability", [ - # Testcase 1: Reliability found for reporter ("reporter_1", 0.7), - # Testcase 2: No reliability found for reporter ("unknown_reporter", None), ], ) def test_get_reporter_reliability(reporter_peerid, expected_reliability): trust_db = ModuleFactory().create_trust_db_obj() - trust_db.conn.execute.return_value.fetchone.return_value = ( - expected_reliability, - ) + trust_db.select = Mock() + + if expected_reliability is not None: + trust_db.select.return_value = (expected_reliability,) + else: + trust_db.select.return_value = [] + reliability = trust_db.get_reporter_reliability(reporter_peerid) assert reliability == expected_reliability @@ -525,9 +414,7 @@ def test_get_reporter_reliability(reporter_peerid, expected_reliability): @pytest.mark.parametrize( "reporter_ipaddress, expected_score, expected_confidence", [ - # Testcase 1: Reputation found for reporter ("192.168.1.2", 0.6, 0.9), - # Testcase 2: No reputation found for reporter ("unknown_ip", None, None), ], ) @@ -535,13 +422,18 @@ def test_get_reporter_reputation( reporter_ipaddress, expected_score, expected_confidence ): trust_db = ModuleFactory().create_trust_db_obj() - trust_db.conn.execute.return_value.fetchone.return_value = ( - expected_score, - expected_confidence, - ) - score, confidence = trust_db.get_reporter_reputation(reporter_ipaddress) - assert score == expected_score - assert confidence == expected_confidence + + with patch.object(trust_db, "select") as mock_select: + if expected_score is not None: + mock_select.return_value = (expected_score, expected_confidence) + else: + mock_select.return_value = None + + score, confidence = trust_db.get_reporter_reputation( + reporter_ipaddress + ) + assert score == expected_score + assert confidence == expected_confidence @pytest.mark.parametrize( @@ -557,8 +449,8 @@ def test_get_reporter_ip( reporter_peerid, report_timestamp, fetchone_result, expected_ip ): trust_db = ModuleFactory().create_trust_db_obj() - trust_db.conn.execute = Mock() - trust_db.conn.execute.return_value.fetchone.return_value = fetchone_result + trust_db.fetchone = Mock() + trust_db.fetchone.return_value = fetchone_result ip = trust_db.get_reporter_ip(reporter_peerid, report_timestamp) assert ip == expected_ip @@ -568,20 +460,26 @@ def test_get_reporter_ip( [ # Testcase 1: No reports for the IP ("192.168.1.1", [], []), - # Testcase 2: One report with valid reporter data + # Testcase 2: One report with valid reporter data, but + # reporter_ipaddress == ipaddress: ( "192.168.1.1", + # peerid, ts, score, conf, reported_ip [("reporter_1", 1678886400, 0.5, 0.8, "192.168.1.1")], - [(0.5, 0.8, 0.7, 0.6, 0.9)], + [], ), # Testcase 3: Multiple reports with valid reporter data ( - "192.168.1.1", + "192.168.1.7", + [ + # these 2 ips shouldnt be the same as the reports ips + ("reporter_1", 1678886400, 0.5, 0.8, "192.168.1.3"), + ("reporter_2", 1678886500, 0.3, 0.6, "192.168.1.4"), + ], [ - ("reporter_1", 1678886400, 0.5, 0.8, "192.168.1.1"), - ("reporter_2", 1678886500, 0.3, 0.6, "192.168.1.1"), + (0.5, 0.8, 0.7, 0.6, 0.9, "192.168.1.1"), + (0.3, 0.6, 0.8, 0.4, 0.7, "192.168.1.2"), ], - [(0.5, 0.8, 0.7, 0.6, 0.9), (0.3, 0.6, 0.8, 0.4, 0.7)], ), ], ) @@ -590,7 +488,7 @@ def test_get_opinion_on_ip(ipaddress, reports, expected_result): trust_db.get_reports_for_ip = MagicMock(return_value=reports) trust_db.get_reporter_ip = MagicMock( - side_effect=["192.168.1.2", "192.168.1.3", "192.168.1.2"] + side_effect=["192.168.1.1", "192.168.1.2"] ) trust_db.get_reporter_reliability = MagicMock(side_effect=[0.7, 0.8, 0.7]) trust_db.get_reporter_reputation = MagicMock( diff --git a/tests/test_update_file_manager.py b/tests/test_update_file_manager.py index 0f04435935..faa7bc84aa 100644 --- a/tests/test_update_file_manager.py +++ b/tests/test_update_file_manager.py @@ -44,7 +44,6 @@ def test_check_if_update_based_on_e_tag(mocker): def test_check_if_update_based_on_last_modified( - database, mocker, ): update_manager = ModuleFactory().create_update_manager_obj()