11import asyncio
22import paramiko
3+ import shlex
34import time
45import re
56from typing import Tuple
@@ -14,7 +15,8 @@ class SSHInteractiveSession:
1415 # ps1_label = "SSHInteractiveSession CLI>"
1516
1617 def __init__ (
17- self , logger : Log , hostname : str , port : int , username : str , password : str , cwd : str | None = None
18+ self , logger : Log , hostname : str , port : int , username : str , password : str ,
19+ cwd : str | None = None , extra_env : dict | None = None
1820 ):
1921 self .logger = logger
2022 self .hostname = hostname
@@ -28,6 +30,7 @@ def __init__(
2830 self .last_command = b""
2931 self .trimmed_command_length = 0 # Initialize trimmed_command_length
3032 self .cwd = cwd
33+ self .extra_env = extra_env
3134
3235 async def connect (self , keepalive_interval : int = 5 ):
3336 """
@@ -37,7 +40,7 @@ async def connect(self, keepalive_interval: int = 5):
3740 ----------
3841 keepalive_interval : int
3942 Interval in **seconds** between keep-alive packets sent by Paramiko.
40- A value ≤ 0 disables Paramiko’ s keep-alive feature.
43+ A value ≤ 0 disables Paramiko' s keep-alive feature.
4144 """
4245 errors = 0
4346 while True :
@@ -66,6 +69,17 @@ async def connect(self, keepalive_interval: int = 5):
6669 initial_command = "unset PROMPT_COMMAND PS0; stty -echo"
6770 if self .cwd :
6871 initial_command = f"cd { self .cwd } ; { initial_command } "
72+
73+ # When extra_env is provided, prepend export statements so the
74+ # variables are available for the entire session. Values are
75+ # shell-quoted via shlex.quote to prevent injection.
76+ if self .extra_env :
77+ exports = "; " .join (
78+ f"export { k } ={ shlex .quote (str (v ))} "
79+ for k , v in self .extra_env .items ()
80+ )
81+ initial_command = f"{ exports } ; { initial_command } "
82+
6983 self .shell .send (f"{ initial_command } \n " .encode ())
7084
7185 # wait for initial prompt/output to settle
@@ -104,7 +118,7 @@ async def send_command(self, command: str):
104118 self .last_command = command .encode ()
105119 self .trimmed_command_length = 0
106120 self .shell .send (self .last_command )
107-
121+
108122 async def read_output (
109123 self , timeout : float = 0 , reset_full_output : bool = False
110124 ) -> Tuple [str , str ]:
@@ -138,7 +152,7 @@ async def read_output(
138152 # deviation_threshold=8,
139153 # deviation_reset=2,
140154 # ignore_patterns=[
141- # rb"\x1b\ [\?\d{4}[a-zA-Z](?:> )?", # ANSI escape sequences
155+ # rb"\[\?\d{4}[a-zA-Z](?:> )?", # ANSI escape sequences
142156 # rb"\r", # Carriage return
143157 # rb">\s", # Greater-than symbol
144158 # ],
@@ -212,13 +226,14 @@ def recv_all(num_bytes):
212226
213227 return data
214228
229+
215230def clean_string (input_string ):
216231 # Remove ANSI escape codes
217- ansi_escape = re .compile (r"\x1B (?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])" )
232+ ansi_escape = re .compile (r"\x1b (?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])" )
218233 cleaned = ansi_escape .sub ("" , input_string )
219234
220235 # remove null bytes
221- cleaned = cleaned .replace ("\x00 " , "" )
236+ cleaned = cleaned .replace ("" , "" )
222237
223238 # remove ipython \r\r\n> sequences from the start
224239 cleaned = re .sub (r'^[ \r]*(?:\r*\n>[ \r]*)*' , '' , cleaned )
0 commit comments