Skip to content

Commit 99803ca

Browse files
[7.0.2 Cherry-pick] Add TDS token data length bounds checks (#4340) (#4358)
Co-authored-by: Paul Medynski <31868385+paulmedynski@users.noreply.github.com>
1 parent bacb2d9 commit 99803ca

13 files changed

Lines changed: 1334 additions & 38 deletions

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,6 +1354,11 @@ internal void OnFeatureExtAck(int featureId, byte[] data)
13541354
len = bLen;
13551355
}
13561356

1357+
if (len < 0 || len > data.Length - i)
1358+
{
1359+
throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, len);
1360+
}
1361+
13571362
byte[] stateData = new byte[len];
13581363
Buffer.BlockCopy(data, i, stateData, 0, len);
13591364
i += len;

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,10 @@ internal static Exception ParsingErrorLength(ParsingErrorState state, int length
746746
{
747747
return ADP.InvalidOperation(StringsHelper.GetString(Strings.SQL_ParsingErrorLength, ((int)state).ToString(CultureInfo.InvariantCulture), length));
748748
}
749+
internal static Exception ParsingErrorLength(ParsingErrorState state, uint length)
750+
{
751+
return ADP.InvalidOperation(StringsHelper.GetString(Strings.SQL_ParsingErrorLength, ((int)state).ToString(CultureInfo.InvariantCulture), length));
752+
}
749753
internal static Exception ParsingErrorStatus(ParsingErrorState state, int status)
750754
{
751755
return ADP.InvalidOperation(StringsHelper.GetString(Strings.SQL_ParsingErrorStatus, ((int)state).ToString(CultureInfo.InvariantCulture), status));

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,17 @@ internal static class TdsEnums
8484
public const int MAX_PACKET_SIZE = 32768;
8585
public const int MAX_SERVER_USER_NAME = 256; // obtained from luxor
8686

87+
// Maximum allowed data length for token payloads (feature ext ack,
88+
// session state, fedauth info). Prevents a malicious server from causing
89+
// unbounded memory allocation via spoofed token length fields.
90+
internal const int MaxTokenDataLength = 1 << 20; // 1 MB
91+
92+
// Maximum allowed data length for a DTC promote transaction propagation token.
93+
internal const int MaxPromoteTransactionLength = 1 << 16; // 64 KB
94+
95+
// Maximum valid wire size for datetime types (DateTimeOffset = 5 time + 3 date + 2 offset).
96+
internal const int MaxDateTimeLength = 10;
97+
8798
// Severity 0 - 10 indicates informational (non-error) messages
8899
// Severity 11 - 16 indicates errors that can be corrected by user (syntax errors, etc...)
89100
// Severity 17 - 19 indicates failure due to insufficient resources in the server

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs

Lines changed: 69 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ static TdsParser()
208208
{
209209
// For CoreCLR, we need to register the ANSI Code Page encoding provider before attempting to get an Encoding from a CodePage
210210
// For a default installation of SqlServer the encoding exchanged during Login is 1252. This encoding is not loaded by default
211-
// See Remarks at https://msdn.microsoft.com/en-us/library/system.text.encodingprovider(v=vs.110).aspx
211+
// See Remarks at https://msdn.microsoft.com/en-us/library/system.text.encodingprovider(v=vs.110).aspx
212212
// SqlClient needs to register the encoding providers to make sure that even basic scenarios work with Sql Server.
213213
Encoding.RegisterProvider(CodePagesEncodingProvider.Instance);
214214
}
@@ -683,7 +683,7 @@ internal void RemoveEncryption()
683683

684684
// create a new packet encryption changes the internal packet size Bug# 228403
685685
_physicalStateObj.ClearAllWritePackets();
686-
}
686+
}
687687

688688
internal void EnableMars()
689689
{
@@ -1376,11 +1376,11 @@ internal void TdsLogin(
13761376
int feOffset = length;
13771377
// calculate and reserve the required bytes for the featureEx
13781378
length = ApplyFeatureExData(
1379-
requestedFeatures,
1380-
recoverySessionData,
1379+
requestedFeatures,
1380+
recoverySessionData,
13811381
fedAuthFeatureExtensionData,
13821382
UserAgent.Ucs2Bytes,
1383-
useFeatureExt,
1383+
useFeatureExt,
13841384
length
13851385
);
13861386

@@ -2792,7 +2792,7 @@ internal TdsOperationStatus TryRun(RunBehavior runBehavior, SqlCommand cmdHandle
27922792
{
27932793
_connHandler._federatedAuthenticationInfoReceived = true;
27942794
SqlFedAuthInfo info;
2795-
2795+
27962796
result = TryProcessFedAuthInfo(stateObj, tokenLength, out info);
27972797
if (result != TdsOperationStatus.Done)
27982798
{
@@ -3348,6 +3348,10 @@ private TdsOperationStatus TryProcessEnvChange(int tokenLength, TdsParserStateOb
33483348
// new value has 4 byte length
33493349
return result;
33503350
}
3351+
if (env._newLength < 0 || env._newLength > TdsEnums.MaxPromoteTransactionLength)
3352+
{
3353+
throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, env._newLength);
3354+
}
33513355
// read new value with 4 byte length
33523356
env._newBinValue = new byte[env._newLength];
33533357
result = stateObj.TryReadByteArray(env._newBinValue, env._newLength);
@@ -3846,10 +3850,15 @@ private TdsOperationStatus TryProcessFeatureExtAck(TdsParserStateObject stateObj
38463850
{
38473851
return result;
38483852
}
3849-
byte[] data = new byte[dataLen];
3850-
if (dataLen > 0)
3853+
if (dataLen > (uint)TdsEnums.MaxTokenDataLength)
38513854
{
3852-
result = stateObj.TryReadByteArray(data, checked((int)dataLen));
3855+
throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, dataLen);
3856+
}
3857+
int dataLength = (int)dataLen;
3858+
byte[] data = new byte[dataLength];
3859+
if (dataLength > 0)
3860+
{
3861+
result = stateObj.TryReadByteArray(data, dataLength);
38533862
if (result != TdsOperationStatus.Done)
38543863
{
38553864
return result;
@@ -4169,6 +4178,10 @@ private TdsOperationStatus TryProcessSessionState(TdsParserStateObject stateObj,
41694178
{
41704179
throw SQL.ParsingErrorLength(ParsingErrorState.SessionStateLengthTooShort, length);
41714180
}
4181+
if (length > TdsEnums.MaxTokenDataLength)
4182+
{
4183+
throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, length);
4184+
}
41724185
uint seqNum;
41734186
TdsOperationStatus result = stateObj.TryReadUInt32(out seqNum);
41744187
if (result != TdsOperationStatus.Done)
@@ -4218,6 +4231,10 @@ private TdsOperationStatus TryProcessSessionState(TdsParserStateObject stateObj,
42184231
return result;
42194232
}
42204233
}
4234+
if (stateLen < 0 || stateLen > TdsEnums.MaxTokenDataLength)
4235+
{
4236+
throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, stateLen);
4237+
}
42214238
byte[] buffer = null;
42224239
lock (sdata._delta)
42234240
{
@@ -4435,6 +4452,10 @@ private TdsOperationStatus TryProcessFedAuthInfo(TdsParserStateObject stateObj,
44354452
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.TryProcessFedAuthInfo|ERR> FEDAUTHINFO token stream length too short for CountOfInfoIDs.");
44364453
throw SQL.ParsingErrorLength(ParsingErrorState.FedAuthInfoLengthTooShortForCountOfInfoIds, tokenLen);
44374454
}
4455+
if (tokenLen > TdsEnums.MaxTokenDataLength)
4456+
{
4457+
throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, tokenLen);
4458+
}
44384459

44394460
// read how many FedAuthInfo options there are
44404461
uint optionsCount;
@@ -4912,14 +4933,20 @@ internal TdsOperationStatus TryProcessReturnValue(int length,
49124933
}
49134934

49144935
// always read as sql types
4915-
Debug.Assert(valLen < (ulong)(int.MaxValue), "ProcessReturnValue received data size > 2Gb");
4916-
4917-
int intlen = valLen > (ulong)(int.MaxValue) ? int.MaxValue : (int)valLen;
4936+
int intlen;
49184937

49194938
if (rec.metaType.IsPlp)
49204939
{
49214940
intlen = int.MaxValue; // If plp data, read it all
49224941
}
4942+
else if (valLen > (ulong)int.MaxValue)
4943+
{
4944+
throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, unchecked((int)valLen));
4945+
}
4946+
else
4947+
{
4948+
intlen = (int)valLen;
4949+
}
49234950

49244951
if (rec.type == SqlDbTypeExtensions.Vector)
49254952
{
@@ -5790,7 +5817,7 @@ private TdsOperationStatus TryCommonProcessMetaData(TdsParserStateObject stateOb
57905817
{
57915818
return result;
57925819
}
5793-
5820+
57945821
// read flags and set appropriate flags in structure
57955822
byte flags;
57965823
result = stateObj.TryReadByte(out flags);
@@ -7119,7 +7146,7 @@ internal TdsOperationStatus TryReadSqlValue(SqlBuffer value,
71197146
return result;
71207147
}
71217148

7122-
// Internally, we use Sqlbinary to deal with varbinary data and store it in
7149+
// Internally, we use Sqlbinary to deal with varbinary data and store it in
71237150
// SqlBuffer as SqlBinary value.
71247151
#if NET
71257152
value.SqlBinary = SqlBinary.WrapBytes(b);
@@ -7188,9 +7215,20 @@ internal TdsOperationStatus TryReadSqlValue(SqlBuffer value,
71887215
return TdsOperationStatus.Done;
71897216
}
71907217

7218+
// length originates as a single byte on the wire (nullable datetime length prefix),
7219+
// but is kept as int to match the TDS parsing API surface where all lengths are int.
7220+
// Using byte here would require casts at all call sites and silently truncate values
7221+
// from the sql_variant path where lenData is computed arithmetically.
71917222
private TdsOperationStatus TryReadSqlDateTime(SqlBuffer value, byte tdsType, int length, byte scale, TdsParserStateObject stateObj)
71927223
{
7193-
Span<byte> datetimeBuffer = ((uint)length <= 16) ? stackalloc byte[16] : new byte[length];
7224+
// DateTimeOffset is the largest datetime type at 10 bytes (5 time + 3 date + 2 offset).
7225+
// Reject anything larger to prevent heap allocation from spoofed metadata.
7226+
if (length < 0 || length > TdsEnums.MaxDateTimeLength)
7227+
{
7228+
throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, length);
7229+
}
7230+
7231+
Span<byte> datetimeBuffer = stackalloc byte[TdsEnums.MaxDateTimeLength];
71947232

71957233
TdsOperationStatus result = stateObj.TryReadByteArray(datetimeBuffer, length);
71967234
if (result != TdsOperationStatus.Done)
@@ -7446,9 +7484,11 @@ internal TdsOperationStatus TryReadSqlValueInternal(SqlBuffer value, byte tdsTyp
74467484
case TdsEnums.SQLVECTOR:
74477485
{
74487486
// Note: Better not come here with plp data!!
7449-
Debug.Assert(length <= TdsEnums.MAXSIZE);
7450-
byte[] b = new byte[length];
7451-
result = stateObj.TryReadByteArrayWithContinue(length, isPlp: false, out b);
7487+
if (length < 0 || length > TdsEnums.MAXSIZE)
7488+
{
7489+
throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, length);
7490+
}
7491+
result = stateObj.TryReadByteArrayWithContinue(length, isPlp: false, out byte[] b);
74527492
if (result != TdsOperationStatus.Done)
74537493
{
74547494
return result;
@@ -9278,7 +9318,7 @@ internal int WriteFedAuthFeatureRequest(FederatedAuthenticationFeatureExtensionD
92789318
/// </remarks>
92799319
internal int WriteVectorSupportFeatureRequest(bool write)
92809320
{
9281-
const int len = 6;
9321+
const int len = 6;
92829322

92839323
if (write)
92849324
{
@@ -10476,7 +10516,7 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet
1047610516
{
1047710517
isSqlVal = param.ParameterIsSqlType; // We have to forward the TYPE info, we need to know what type we are returning. Once we null the parameter we will no longer be able to distinguish what type were seeing.
1047810518

10479-
// Output parameter of SqlDbType vector are defined through SqlParameter.Value.
10519+
// Output parameter of SqlDbType vector are defined through SqlParameter.Value.
1048010520
// This check ensures that we do not discard the parameter value when SqlDbType is vector.
1048110521
if (mt.SqlDbType != SqlDbTypeExtensions.Vector)
1048210522
{
@@ -10761,7 +10801,7 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet
1076110801

1076210802
Debug.Assert(udtVal != null, "GetBytes returned null instance. Make sure that it always returns non-null value");
1076310803
size = udtVal.Length;
10764-
10804+
1076510805
if (size >= maxSupportedSize && maxsize != -1)
1076610806
{
1076710807
throw SQL.UDTInvalidSize(maxsize, maxSupportedSize);
@@ -13263,7 +13303,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
1326313303
{
1326413304
if (type.NullableType == TdsEnums.SQLJSON)
1326513305
{
13266-
// TODO : Performance and BOM check. Saurabh
13306+
// TODO : Performance and BOM check. Saurabh
1326713307
byte[] jsonAsBytes = Encoding.UTF8.GetBytes((string)value);
1326813308
WriteInt(jsonAsBytes.Length, stateObj);
1326913309
return stateObj.WriteByteArray(jsonAsBytes, jsonAsBytes.Length, 0, canAccumulate: false);
@@ -13921,13 +13961,13 @@ internal TdsOperationStatus TryReadPlpUnicodeCharsWithContinue(TdsParserStateObj
1392113961
}
1392213962

1392313963
TdsOperationStatus result = TryReadPlpUnicodeChars(
13924-
ref temp,
13925-
0,
13926-
length >> 1,
13927-
stateObj,
13928-
out length,
13964+
ref temp,
13965+
0,
13966+
length >> 1,
13967+
stateObj,
13968+
out length,
1392913969
supportRentedBuff: !canContinue, // do not use the arraypool if we are going to keep the buffer in the snapshot
13930-
rentedBuff: ref buffIsRented,
13970+
rentedBuff: ref buffIsRented,
1393113971
startOffset,
1393213972
canContinue
1393313973
);
@@ -14137,7 +14177,7 @@ bool writeDataSizeToSnapshot
1413714177
stateObj._longlenleft--;
1413814178
if (writeDataSizeToSnapshot)
1413914179
{
14140-
// we need to write the single b1 byte to the array because we may run out of data
14180+
// we need to write the single b1 byte to the array because we may run out of data
1414114181
// and need to wait for another packet
1414214182
buff[offst] = (char)((b1 & 0xff));
1414314183
currentPacketId = IncrementSnapshotDataSize(stateObj, restartingDataSizeCount, currentPacketId, 1);

src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionEnhancedRoutingTests.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests;
1414
/// <summary>
1515
/// Tests connection routing using the enhanced routing feature extension and envchange token
1616
/// </summary>
17+
// TODO: Do we need this collection? It serializes all tests within it, which we probably don't
18+
// need since each test uses its own TDS Server with ephemeral listen port.
1719
[Collection("SimulatedServerTests")]
1820
public class ConnectionEnhancedRoutingTests
1921
{

src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests
1313
{
1414
[Trait("Category", "flaky")]
15+
// TODO: Do we need this collection? It serializes all tests within it, which we probably don't
16+
// need since each test uses its own TDS Server with ephemeral listen port.
1517
[Collection("SimulatedServerTests")]
1618
public class ConnectionFailoverTests
1719
{
@@ -173,7 +175,7 @@ public void NetworkTimeout_ShouldFail()
173175
InitialCatalog = "master",// Required for failover partner to work
174176
ConnectTimeout = 1,
175177
ConnectRetryInterval = 1,
176-
ConnectRetryCount = 0, // Disable retry
178+
ConnectRetryCount = 0, // Disable retry
177179
Encrypt = false,
178180
MultiSubnetFailover = false,
179181
#if NETFRAMEWORK
@@ -460,7 +462,7 @@ public void TransientFault_WithUserProvidedPartner_ShouldConnectToPrimary(uint e
460462
FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", // User provided failover partner
461463
};
462464
using SqlConnection connection = new(builder.ConnectionString);
463-
465+
464466
// Act
465467
connection.Open();
466468

src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionReadOnlyRoutingTests.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests
1313
{
14+
// TODO: Do we need this collection? It serializes all tests within it, which we probably don't
15+
// need since each test uses its own TDS Server with ephemeral listen port.
1416
[Collection("SimulatedServerTests")]
1517
public class ConnectionReadOnlyRoutingTests
1618
{
@@ -71,7 +73,7 @@ public void RecursivelyRoutedConnection(int layers)
7173
router.Start();
7274
routingLayers.Push(router);
7375
lastEndpoint = router.EndPoint;
74-
lastConnectionString = (new SqlConnectionStringBuilder() {
76+
lastConnectionString = (new SqlConnectionStringBuilder() {
7577
DataSource = $"localhost,{lastEndpoint.Port}",
7678
ApplicationIntent = ApplicationIntent.ReadOnly,
7779
Encrypt = false
@@ -114,8 +116,8 @@ public async Task RecursivelyRoutedAsyncConnection(int layers)
114116
router.Start();
115117
routingLayers.Push(router);
116118
lastEndpoint = router.EndPoint;
117-
lastConnectionString = (new SqlConnectionStringBuilder() {
118-
DataSource = $"localhost,{lastEndpoint.Port}",
119+
lastConnectionString = (new SqlConnectionStringBuilder() {
120+
DataSource = $"localhost,{lastEndpoint.Port}",
119121
ApplicationIntent = ApplicationIntent.ReadOnly,
120122
Encrypt = false
121123
}).ConnectionString;

src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTests.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests
1212
{
1313
[Trait("Category", "flaky")]
14+
// TODO: Do we need this collection? It serializes all tests within it, which we probably don't
15+
// need since each test uses its own TDS Server with ephemeral listen port.
1416
[Collection("SimulatedServerTests")]
1517
public class ConnectionRoutingTests
1618
{
@@ -195,7 +197,7 @@ public void NetworkTimeoutAtRoutedLocation_RetryDisabled_ShouldFail()
195197
// Act
196198
var e = Assert.Throws<SqlException>(connection.Open);
197199

198-
// Assert
200+
// Assert
199201
Assert.Equal(ConnectionState.Closed, connection.State);
200202
Assert.Contains("Connection Timeout Expired", e.Message);
201203
}

0 commit comments

Comments
 (0)