diff --git a/singlestoredb/config.py b/singlestoredb/config.py index 044af2a5..594b06cf 100644 --- a/singlestoredb/config.py +++ b/singlestoredb/config.py @@ -9,6 +9,7 @@ from .utils.config import check_float # noqa: F401 from .utils.config import check_int # noqa: F401 from .utils.config import check_optional_bool # noqa: F401 +from .utils.config import check_socket_options # noqa: F401 from .utils.config import check_str # noqa: F401 from .utils.config import check_url # noqa: F401 from .utils.config import describe_option # noqa: F401 @@ -263,6 +264,11 @@ environ='SINGLESTOREDB_FUSION_ENABLED', ) +register_option( + 'socket_options', 'dict', check_socket_options, None, + 'Format for socket options', +) + # # Query results options # diff --git a/singlestoredb/connection.py b/singlestoredb/connection.py index 942b2feb..9e42c7da 100644 --- a/singlestoredb/connection.py +++ b/singlestoredb/connection.py @@ -1340,6 +1340,7 @@ def connect( vector_data_format: Optional[str] = None, parse_json: Optional[bool] = None, interpolate_query_with_empty_args: Optional[bool] = None, + socket_options: Optional[Dict[int, Dict[int, Any]]] = None, ) -> Connection: """ Return a SingleStoreDB connection. @@ -1428,6 +1429,11 @@ def connect( interpolate_query_with_empty_args : bool, optional Should the connector apply parameter interpolation even when the parameters are empty? This corresponds to pymysql/mysqlclient's handling + socket_options : dict, optional + Socket options to set on the underlying socket. The keys should be + socket level constants (e.g., socket.SOL_SOCKET) and the values should be + dictionaries mapping socket option constants (e.g., socket.SO_KEEPALIVE) to + the desired value for that option. Examples -------- diff --git a/singlestoredb/mysql/connection.py b/singlestoredb/mysql/connection.py index 094fad68..c4fc6417 100644 --- a/singlestoredb/mysql/connection.py +++ b/singlestoredb/mysql/connection.py @@ -230,6 +230,10 @@ class Connection(BaseConnection): Set to true to check the server's identity. tls_sni_servername: str, optional Set server host name for TLS connection + socket_options: Dict[int, Dict[int, any]], optional + A dictionary of socket options to set on the connection. + The keys are the socket level constants (e.g., socket.SOL_SOCKET), + and the values are dictionaries mapping option names to values. read_default_group : str, optional Group to read from in the configuration file. autocommit : bool, optional @@ -341,6 +345,7 @@ def __init__( # noqa: C901 ssl_verify_cert=None, ssl_verify_identity=None, tls_sni_servername=None, + socket_options=None, parse_json=True, invalid_values=None, pure_python=None, @@ -477,7 +482,7 @@ def _config(key, arg): self.collation = collation self.use_unicode = use_unicode self.encoding_errors = encoding_errors - + self._socket_options = socket_options or {} self.encoding = charset_by_name(self.charset).encoding client_flag |= CLIENT.CAPABILITIES @@ -1107,6 +1112,11 @@ def connect(self, sock=None): print('connected using socket') sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + + for level, options in self._socket_options.items(): + for opt, value in options.items(): + sock.setsockopt(level, opt, value) + sock.settimeout(None) self._sock = sock diff --git a/singlestoredb/utils/config.py b/singlestoredb/utils/config.py index 6b7cdd9f..bd2cc232 100644 --- a/singlestoredb/utils/config.py +++ b/singlestoredb/utils/config.py @@ -646,6 +646,55 @@ def check_str( return out +def check_socket_options( + value: Any, +) -> Optional[Dict[int, Dict[int, Any]]]: + """ + Validate socket options. + + Parameters + ---------- + value : dict + The value to validate. It must be a dictionary where the keys are + socket level constants (e.g., socket.SOL_SOCKET) and the values are + dictionaries mapping socket option constants (e.g., socket.SO_KEEPALIVE) + to the desired value for that option. + + Returns + ------- + dict + The validated socket options + + """ + if value is None: + return None + + if not isinstance(value, Mapping): + raise ValueError( + 'value {} must be of type dict'.format(value), + ) + + out: dict[int, dict[int, Any]] = {} + for level, options in value.items(): + if not isinstance(level, int): + raise ValueError( + f'keys in {value} must be integers corresponding to socket levels', + ) + if not isinstance(options, Mapping): + raise ValueError( + f'values in {value} must be dicts.', + ) + out[level] = {} + for opt, val in options.items(): + if not isinstance(opt, int): + raise ValueError( + f'keys in sub-dicts of {value} must be integers.', + ) + out[level][opt] = val + + return out + + def check_dict_str_str( value: Any, ) -> Optional[Dict[str, str]]: