Skip to content

Commit 58be453

Browse files
committed
add regression test; fix race
1 parent c6ba594 commit 58be453

3 files changed

Lines changed: 270 additions & 8 deletions

File tree

src/Renci.SshNet/Sftp/Responses/SftpStatusResponse.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public SftpStatusResponse(uint protocolVersion)
1212
{
1313
}
1414

15-
public StatusCodes StatusCode { get; private set; }
15+
public StatusCodes StatusCode { get; set; }
1616

1717
public string ErrorMessage { get; private set; }
1818

@@ -39,5 +39,12 @@ protected override void LoadData()
3939
Language = ReadString(Ascii);
4040
}
4141
}
42+
43+
protected override void SaveData()
44+
{
45+
base.SaveData();
46+
47+
Write((uint)StatusCode);
48+
}
4249
}
4350
}

src/Renci.SshNet/SftpClient.cs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2460,6 +2460,10 @@ private void InternalUploadFile(Stream input, string path, Flags flags, SftpUplo
24602460

24612461
int bytesRead;
24622462
var expectedResponses = 0;
2463+
2464+
// We will send out all the write requests without waiting for each response.
2465+
// Afterwards, we may wait on this handle until all responses are received
2466+
// or an error has occured.
24632467
using var mres = new ManualResetEventSlim(initialState: false);
24642468

24652469
ExceptionDispatchInfo? exception = null;
@@ -2484,11 +2488,6 @@ private void InternalUploadFile(Stream input, string path, Flags flags, SftpUplo
24842488

24852489
try
24862490
{
2487-
if (Interlocked.Decrement(ref expectedResponses) == 0)
2488-
{
2489-
setHandle = true;
2490-
}
2491-
24922491
if (Sftp.SftpSession.GetSftpException(s) is Exception ex)
24932492
{
24942493
exception = ExceptionDispatchInfo.Capture(ex);
@@ -2513,7 +2512,7 @@ private void InternalUploadFile(Stream input, string path, Flags flags, SftpUplo
25132512
}
25142513
finally
25152514
{
2516-
if (setHandle)
2515+
if (Interlocked.Decrement(ref expectedResponses) == 0 || setHandle)
25172516
{
25182517
mres.Set();
25192518
}
@@ -2523,7 +2522,11 @@ private void InternalUploadFile(Stream input, string path, Flags flags, SftpUplo
25232522
offset += (ulong)bytesRead;
25242523
}
25252524

2526-
if (expectedResponses != 0)
2525+
// Make sure the read of exception cannot be executed ahead of
2526+
// the read of expectedResponses so that we do not miss an
2527+
// exception.
2528+
2529+
if (Volatile.Read(ref expectedResponses) != 0)
25272530
{
25282531
_sftpSession.WaitOnHandle(mres.WaitHandle, _operationTimeout);
25292532
}
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
using System;
2+
using System.IO;
3+
using System.Net.Sockets;
4+
using System.Text;
5+
using System.Threading;
6+
using System.Threading.Tasks;
7+
8+
using Microsoft.VisualStudio.TestTools.UnitTesting;
9+
10+
using Moq;
11+
12+
using Renci.SshNet.Channels;
13+
using Renci.SshNet.Common;
14+
using Renci.SshNet.Connection;
15+
using Renci.SshNet.Messages;
16+
using Renci.SshNet.Messages.Authentication;
17+
using Renci.SshNet.Messages.Connection;
18+
using Renci.SshNet.Sftp;
19+
using Renci.SshNet.Sftp.Responses;
20+
21+
namespace Renci.SshNet.Tests.Classes
22+
{
23+
public partial class SftpClientTest
24+
{
25+
[TestMethod]
26+
public void UploadFile_ObservesErrorResponses()
27+
{
28+
// A regression test for UploadFile hanging instead of observing
29+
// error responses from the server.
30+
// https://github.com/sshnet/SSH.NET/issues/957
31+
32+
var serviceFactoryMock = new Mock<IServiceFactory>();
33+
34+
var connInfo = new PasswordConnectionInfo("host", "user", "pwd");
35+
36+
var session = new MySession(connInfo);
37+
38+
var concreteServiceFactory = new ServiceFactory();
39+
40+
serviceFactoryMock
41+
.Setup(p => p.CreateSession(It.IsAny<ConnectionInfo>(), It.IsAny<ISocketFactory>()))
42+
.Returns(session);
43+
44+
serviceFactoryMock
45+
.Setup(p => p.CreateSftpResponseFactory())
46+
.Returns(concreteServiceFactory.CreateSftpResponseFactory);
47+
48+
serviceFactoryMock
49+
.Setup(p => p.CreateSftpSession(session, It.IsAny<int>(), It.IsAny<Encoding>(), It.IsAny<ISftpResponseFactory>()))
50+
.Returns(concreteServiceFactory.CreateSftpSession);
51+
52+
using var client = new SftpClient(connInfo, false, serviceFactoryMock.Object);
53+
client.Connect();
54+
55+
Assert.Throws<SftpPermissionDeniedException>(() => client.UploadFile(
56+
new OneByteStream(new MemoryStream("Hello World"u8.ToArray())),
57+
"path.txt"));
58+
}
59+
60+
#pragma warning disable IDE0022 // Use block body for method
61+
#pragma warning disable IDE0025 // Use block body for property
62+
#pragma warning disable IDE0027 // Use block body for accessor
63+
#pragma warning disable CS0067 // event is unused
64+
65+
private class MySession(ConnectionInfo connectionInfo) : ISession
66+
{
67+
public IConnectionInfo ConnectionInfo => connectionInfo;
68+
69+
public event EventHandler<MessageEventArgs<ChannelCloseMessage>> ChannelCloseReceived;
70+
public event EventHandler<MessageEventArgs<ChannelDataMessage>> ChannelDataReceived;
71+
public event EventHandler<MessageEventArgs<ChannelEofMessage>> ChannelEofReceived;
72+
public event EventHandler<MessageEventArgs<ChannelExtendedDataMessage>> ChannelExtendedDataReceived;
73+
public event EventHandler<MessageEventArgs<ChannelFailureMessage>> ChannelFailureReceived;
74+
public event EventHandler<MessageEventArgs<ChannelOpenConfirmationMessage>> ChannelOpenConfirmationReceived;
75+
public event EventHandler<MessageEventArgs<ChannelOpenFailureMessage>> ChannelOpenFailureReceived;
76+
public event EventHandler<MessageEventArgs<ChannelOpenMessage>> ChannelOpenReceived;
77+
public event EventHandler<MessageEventArgs<ChannelRequestMessage>> ChannelRequestReceived;
78+
public event EventHandler<MessageEventArgs<ChannelSuccessMessage>> ChannelSuccessReceived;
79+
public event EventHandler<MessageEventArgs<ChannelWindowAdjustMessage>> ChannelWindowAdjustReceived;
80+
public event EventHandler<EventArgs> Disconnected;
81+
public event EventHandler<ExceptionEventArgs> ErrorOccured;
82+
public event EventHandler<SshIdentificationEventArgs> ServerIdentificationReceived;
83+
public event EventHandler<HostKeyEventArgs> HostKeyReceived;
84+
public event EventHandler<MessageEventArgs<RequestSuccessMessage>> RequestSuccessReceived;
85+
public event EventHandler<MessageEventArgs<RequestFailureMessage>> RequestFailureReceived;
86+
public event EventHandler<MessageEventArgs<BannerMessage>> UserAuthenticationBannerReceived;
87+
88+
private uint _numRequests;
89+
private int _numWriteRequests;
90+
91+
public void SendMessage(Message message)
92+
{
93+
// Initialisation sequence for SFTP session
94+
95+
if (message is ChannelOpenMessage)
96+
{
97+
ChannelOpenConfirmationReceived?.Invoke(
98+
this,
99+
new MessageEventArgs<ChannelOpenConfirmationMessage>(
100+
new ChannelOpenConfirmationMessage(0, int.MaxValue, int.MaxValue, 0)));
101+
}
102+
else if (message is ChannelRequestMessage)
103+
{
104+
ChannelSuccessReceived?.Invoke(
105+
this,
106+
new MessageEventArgs<ChannelSuccessMessage>(new ChannelSuccessMessage(0)));
107+
}
108+
else if (message is ChannelDataMessage dataMsg)
109+
{
110+
if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.Init)
111+
{
112+
ChannelDataReceived?.Invoke(
113+
this,
114+
new MessageEventArgs<ChannelDataMessage>(
115+
new ChannelDataMessage(0, new SftpVersionResponse() { Version = 3 }.GetBytes())));
116+
}
117+
else if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.RealPath)
118+
{
119+
ChannelDataReceived?.Invoke(
120+
this,
121+
new MessageEventArgs<ChannelDataMessage>(
122+
new ChannelDataMessage(0,
123+
new SftpNameResponse(3, Encoding.UTF8)
124+
{
125+
ResponseId = ++_numRequests,
126+
Files = [new("thepath", new SftpFileAttributes(default, default, default, default, default, default, default))]
127+
}.GetBytes())));
128+
}
129+
else if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.Open)
130+
{
131+
ChannelDataReceived?.Invoke(
132+
this,
133+
new MessageEventArgs<ChannelDataMessage>(
134+
new ChannelDataMessage(0,
135+
new SftpHandleResponse(3)
136+
{
137+
ResponseId = ++_numRequests,
138+
Handle = "file"u8.ToArray()
139+
}.GetBytes())));
140+
}
141+
142+
// --------- The actual interesting part of all of this ---------
143+
//
144+
else if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.Write)
145+
{
146+
// Fail the 5th write request
147+
var statusCode = ++_numWriteRequests == 5 ? StatusCodes.PermissionDenied : StatusCodes.Ok;
148+
var responseId = ++_numRequests;
149+
150+
// Dispatch the responses on a different thread to simulate reality.
151+
_ = Task.Run(() =>
152+
{
153+
ChannelDataReceived?.Invoke(
154+
this,
155+
new MessageEventArgs<ChannelDataMessage>(
156+
new ChannelDataMessage(0,
157+
new SftpStatusResponse(3)
158+
{
159+
ResponseId = responseId,
160+
StatusCode = statusCode
161+
}.GetBytes())));
162+
});
163+
}
164+
//
165+
// --------------------------------------------------------------
166+
}
167+
}
168+
169+
public bool IsConnected => false;
170+
171+
public SemaphoreSlim SessionSemaphore { get; } = new(1);
172+
173+
public IChannelSession CreateChannelSession() => new ChannelSession(this, 0, int.MaxValue, int.MaxValue);
174+
175+
public WaitHandle MessageListenerCompleted => throw new NotImplementedException();
176+
177+
public void Connect()
178+
{
179+
}
180+
181+
public Task ConnectAsync(CancellationToken cancellationToken) => throw new NotImplementedException();
182+
183+
public IChannelDirectTcpip CreateChannelDirectTcpip() => throw new NotImplementedException();
184+
185+
public IChannelForwardedTcpip CreateChannelForwardedTcpip(uint remoteChannelNumber, uint remoteWindowSize, uint remoteChannelDataPacketSize)
186+
=> throw new NotImplementedException();
187+
188+
public void Dispose()
189+
{
190+
}
191+
192+
public void OnDisconnecting()
193+
{
194+
}
195+
196+
public void Disconnect() => throw new NotImplementedException();
197+
198+
public void RegisterMessage(string messageName) => throw new NotImplementedException();
199+
200+
public bool TrySendMessage(Message message) => throw new NotImplementedException();
201+
202+
public WaitResult TryWait(WaitHandle waitHandle, TimeSpan timeout, out Exception exception) => throw new NotImplementedException();
203+
204+
public WaitResult TryWait(WaitHandle waitHandle, TimeSpan timeout) => throw new NotImplementedException();
205+
206+
public void UnRegisterMessage(string messageName) => throw new NotImplementedException();
207+
208+
public void WaitOnHandle(WaitHandle waitHandle)
209+
{
210+
}
211+
212+
public void WaitOnHandle(WaitHandle waitHandle, TimeSpan timeout) => throw new NotImplementedException();
213+
}
214+
215+
private class OneByteStream : Stream
216+
{
217+
private readonly Stream _stream;
218+
219+
public OneByteStream(Stream stream)
220+
{
221+
_stream = stream;
222+
}
223+
224+
public override bool CanRead => _stream.CanRead;
225+
226+
public override bool CanSeek => throw new NotImplementedException();
227+
228+
public override bool CanWrite => throw new NotImplementedException();
229+
230+
public override long Length => _stream.Length;
231+
232+
public override long Position
233+
{
234+
get => throw new NotImplementedException();
235+
set => throw new NotImplementedException();
236+
}
237+
238+
public override void Flush() => throw new NotImplementedException();
239+
240+
public override int Read(byte[] buffer, int offset, int count)
241+
{
242+
return _stream.Read(buffer, offset, Math.Min(1, count));
243+
}
244+
245+
public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException();
246+
247+
public override void SetLength(long value) => throw new NotImplementedException();
248+
249+
public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException();
250+
}
251+
}
252+
}

0 commit comments

Comments
 (0)