2323from utils .filter import filter_process
2424from utils .shortcuts import shortcuts_process
2525from utils .cancel_inspect import cancel_inspect_process
26+ from utils .custom_header import custom_header_process
2627from utils .logger import configure_file_logger , configure_console_logger
2728
2829class ProxyServer :
@@ -35,7 +36,7 @@ class ProxyServer:
3536 # pylint: disable=too-many-locals
3637 def __init__ (self , host , port , debug , access_log , block_log ,
3738 html_403 , no_filter , filter_mode , no_logging_access , no_logging_block , ssl_inspect ,
38- blocked_sites , blocked_url , shortcuts , inspect_ca_cert ,
39+ blocked_sites , blocked_url , shortcuts , custom_header , inspect_ca_cert ,
3940 inspect_ca_key , inspect_certs_folder , cancel_inspect ):
4041 """
4142 Initializes the ProxyServer instance with the provided configurations.
@@ -57,10 +58,14 @@ def __init__(self, host, port, debug, access_log, block_log,
5758 self .cancel_inspect_proc = None
5859 self .cancel_inspect_queue = multiprocessing .Queue ()
5960 self .cancel_inspect_result_queue = multiprocessing .Queue ()
61+ self .custom_header_proc = None
62+ self .custom_header_queue = multiprocessing .Queue ()
63+ self .custom_header_result_queue = multiprocessing .Queue ()
6064 self .console_logger = configure_console_logger ()
6165 self .config_blocked_sites = blocked_sites
6266 self .config_blocked_url = blocked_url
6367 self .config_shortcuts = shortcuts
68+ self .config_custom_header = custom_header
6469 self .config_cancel_inspect = cancel_inspect
6570 self .config_inspect_cert = inspect_ca_cert
6671 self .config_inspect_key = inspect_ca_key
@@ -91,6 +96,7 @@ def start(self):
9196 self .console_logger .debug ("[*] blocked_sites = %s" , self .config_blocked_sites )
9297 self .console_logger .debug ("[*] blocked_url = %s" , self .config_blocked_url )
9398 self .console_logger .debug ("[*] shortcuts = %s" , self .config_shortcuts )
99+ self .console_logger .debug ("[*] custom_header = %s" , self .config_custom_header )
94100 self .console_logger .debug ("[*] inspect_ca_cert = %s" , self .config_inspect_cert )
95101 self .console_logger .debug ("[*] inspect_ca_key = %s" , self .config_inspect_key )
96102 self .console_logger .debug (
@@ -162,6 +168,18 @@ def start(self):
162168 self .cancel_inspect_proc .start ()
163169 self .console_logger .debug ("[*] Starting the cancel inspection process..." )
164170
171+ if self .config_custom_header and os .path .isfile (self .config_custom_header ):
172+ self .custom_header_proc = multiprocessing .Process (
173+ target = custom_header_process ,
174+ args = (
175+ self .custom_header_queue ,
176+ self .custom_header_result_queue ,
177+ self .config_custom_header
178+ )
179+ )
180+ self .custom_header_proc .start ()
181+ self .console_logger .debug ("[*] Starting the custom header process..." )
182+
165183 server = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
166184 server .bind (self .host_port )
167185 server .listen (10 )
@@ -213,6 +231,13 @@ def handle_http_request(self, client_socket, request):
213231 first_line = request .decode (errors = 'ignore' ).split ("\n " )[0 ]
214232 url = first_line .split (" " )[1 ]
215233
234+ if self .config_custom_header and os .path .isfile (self .config_custom_header ):
235+ headers = self .extract_headers (request .decode (errors = 'ignore' ))
236+ self .custom_header_queue .put (url )
237+ new_headers = self .custom_header_result_queue .get ()
238+ headers .update (new_headers )
239+ print ("headers" , headers )
240+
216241 if self .config_shortcuts :
217242 domain , _ = self .parse_url (url )
218243 self .shortcuts_queue .put (domain )
@@ -259,7 +284,26 @@ def handle_http_request(self, client_socket, request):
259284 f"http://{ server_host } " ,
260285 first_line
261286 )
262- self .forward_request_to_server (client_socket , request , url )
287+
288+ if self .config_custom_header and os .path .isfile (self .config_custom_header ):
289+ request_lines = request .decode (errors = 'ignore' ).split ("\r \n " )
290+ request_line = request_lines [0 ] # GET / HTTP/1.1
291+
292+ header_lines = [f"{ key } : { value } " for key , value in headers .items ()]
293+ reconstructed_headers = "\r \n " .join (header_lines )
294+
295+ if "\r \n \r \n " in request .decode (errors = 'ignore' ):
296+ body = request .decode (errors = 'ignore' ).split ("\r \n \r \n " , 1 )[1 ]
297+ else :
298+ body = ""
299+
300+ modified_request = f"{ request_line } \r \n { reconstructed_headers } \r \n \r \n { body } " .encode ()
301+
302+ # 5. Envoyer au serveur
303+ self .forward_request_to_server (client_socket , modified_request , url )
304+
305+ else :
306+ self .forward_request_to_server (client_socket , request , url )
263307
264308 def forward_request_to_server (self , client_socket , request , url ):
265309 """
@@ -320,6 +364,24 @@ def parse_url(self, url):
320364
321365 return server_host , server_port
322366
367+ def extract_headers (self , request_str ):
368+ """
369+ Extracts the HTTP headers from a raw HTTP request string.
370+
371+ Args:
372+ request_str (str): The full HTTP request as a decoded string.
373+
374+ Returns:
375+ dict: A dictionary containing the HTTP header fields as key-value pairs.
376+ """
377+ headers = {}
378+ lines = request_str .split ("\n " )[1 :]
379+ for line in lines :
380+ if line .strip ():
381+ key , value = line .split (":" , 1 )
382+ headers [key .strip ()] = value .strip ()
383+ return headers
384+
323385 # pylint: disable=too-many-locals,too-many-statements,too-many-branches,too-many-nested-blocks
324386 def handle_https_connection (self , client_socket , first_line ):
325387 """
0 commit comments