1111from twisted .internet .protocol import connectionDone , Protocol
1212
1313from pyrdp .core import ObservedBy
14+ from pyrdp .core .ssl import ServerTLSContext
1415from pyrdp .layer .layer import IntermediateLayer , LayerObserver
1516from pyrdp .logging import LOGGER_NAMES , getSSLLogger
1617from pyrdp .parser .tcp import TCPParser
1718from pyrdp .pdu import PDU
19+ from pyrdp .mitm import MITMConfig
20+
21+
22+ TLS_RECORD = 0x16
1823
1924
2025class TCPObserver (LayerObserver ):
@@ -41,17 +46,21 @@ class TwistedTCPLayer(IntermediateLayer, Protocol):
4146 TCP observers are notified when a connection is made.
4247 """
4348
44- def __init__ (self ):
49+ def __init__ (self , config : MITMConfig ):
4550 self .log = logging .getLogger (LOGGER_NAMES .PYRDP )
4651 super ().__init__ (TCPParser ())
4752 self .connectedEvent = asyncio .Event ()
4853 self .logSSLRequired = False
4954
55+ self .new = True
56+ self .config = config
57+
5058 def logSSLParameters (self ):
5159 """
5260 Log the SSL parameters of the connection in a format suitable for decryption by Wireshark.
5361 """
54- getSSLLogger ().info (self .transport .protocol ._tlsConnection .client_random (), self .transport .protocol ._tlsConnection .master_key ())
62+ getSSLLogger ().info (self .transport .protocol ._tlsConnection .client_random (),
63+ self .transport .protocol ._tlsConnection .master_key ())
5564
5665 def connectionMade (self ):
5766 """
@@ -66,7 +75,7 @@ def connectionLost(self, reason=connectionDone):
6675 """
6776 self .observer .onDisconnection (reason )
6877
69- def disconnect (self , abort = False ):
78+ def disconnect (self , abort = False ):
7079 """
7180 Close the TCP connection.
7281 :param abort: True to force close the connection, False to end gracefully.
@@ -83,6 +92,20 @@ def dataReceived(self, data: bytes):
8392 Called whenever data is received.
8493 :param data: bytes received.
8594 """
95+
96+ # Check if the client is sending us a TLS record immediately.
97+ # and start the TLS handshake early.
98+ if self .new and data [0 ] == TLS_RECORD :
99+ self .startTLS (ServerTLSContext (self .config .privateKeyFileName ,
100+ self .config .certificateFileName ))
101+ # Resend the ClientHello to Twisted for handshake processing.
102+ # WARNING: This is using a private Twisted API which could change in the future.
103+ self .transport ._dataReceived (data )
104+ self .new = False
105+ return
106+
107+ self .new = False # First packet was not a TLS record.
108+
86109 try :
87110 if self .logSSLRequired :
88111 self .logSSLParameters ()
@@ -93,7 +116,7 @@ def dataReceived(self, data: bytes):
93116 raise
94117 except Exception as e :
95118 self .log .exception (e )
96- self .log .error ("Exception occurred when receiving: %(data)s" , {"data" : hexlify (data ).decode ()})
119+ self .log .error ("Exception occurred when receiving: %(data)s" , {"data" : hexlify (data ).decode ()})
97120 raise
98121
99122 def sendBytes (self , data : bytes ):
@@ -143,7 +166,7 @@ def connection_lost(self, exception=connectionDone):
143166 """
144167 self .observer .onDisconnection (exception )
145168
146- def disconnect (self , abort = False ):
169+ def disconnect (self , abort = False ):
147170 """
148171 Close the TCP connection.
149172 :param abort: True to force close the connection, False to end gracefully.
0 commit comments