From b89d9beadfa5850e1484789a3f5859540ae39ce1 Mon Sep 17 00:00:00 2001 From: Jade Wang Date: Mon, 31 Mar 2025 14:30:20 -0700 Subject: [PATCH] Update DatabricksConnection.cs Update DatabricksConnection.cs Update RetryHttpHandlerTest.cs Add retry after handling in ADBC Spark driver fix pre commit check failures fix build error address PR comments fix linter lint Update DatabricksConnection.cs address comments --- .../Apache/Spark/SparkHttpConnection.cs | 9 +- .../Databricks/DatabricksConnection.cs | 51 ++++ .../Drivers/Databricks/DatabricksException.cs | 69 +++++ .../Databricks/DatabricksParameters.cs | 10 + .../Drivers/Databricks/RetryHttpHandler.cs | 147 +++++++++++ .../Databricks/DatabricksConnectionTest.cs | 5 + .../Databricks/RetryHttpHandlerTest.cs | 244 ++++++++++++++++++ 7 files changed, 533 insertions(+), 2 deletions(-) create mode 100644 csharp/src/Drivers/Databricks/DatabricksException.cs create mode 100644 csharp/src/Drivers/Databricks/RetryHttpHandler.cs create mode 100644 csharp/test/Drivers/Databricks/RetryHttpHandlerTest.cs diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs index 9cabd1ac41..6806d0c801 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs @@ -135,11 +135,17 @@ protected override void ValidateOptions() ? connectTimeoutMsValue : throw new ArgumentOutOfRangeException(SparkParameters.ConnectTimeoutMilliseconds, connectTimeoutMs, $"must be a value of 0 (infinite) or between 1 .. {int.MaxValue}. default is 30000 milliseconds."); } + TlsOptions = HiveServer2TlsImpl.GetHttpTlsOptions(Properties); } internal override IArrowArrayStream NewReader(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion); + protected virtual HttpMessageHandler CreateHttpHandler() + { + return HiveServer2TlsImpl.NewHttpClientHandler(TlsOptions); + } + protected override TTransport CreateTransport() { // Assumption: parameters have already been validated. @@ -160,8 +166,7 @@ protected override TTransport CreateTransport() Uri baseAddress = GetBaseAddress(uri, hostName, path, port, SparkParameters.HostName, TlsOptions.IsTlsEnabled); AuthenticationHeaderValue? authenticationHeaderValue = GetAuthenticationHeaderValue(authTypeValue, token, username, password, access_token); - HttpClientHandler httpClientHandler = HiveServer2TlsImpl.NewHttpClientHandler(TlsOptions); - HttpClient httpClient = new(httpClientHandler); + HttpClient httpClient = new(CreateHttpHandler()); httpClient.BaseAddress = baseAddress; httpClient.DefaultRequestHeaders.Authorization = authenticationHeaderValue; httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(s_userAgent); diff --git a/csharp/src/Drivers/Databricks/DatabricksConnection.cs b/csharp/src/Drivers/Databricks/DatabricksConnection.cs index 45a496b64e..aefe2df89e 100644 --- a/csharp/src/Drivers/Databricks/DatabricksConnection.cs +++ b/csharp/src/Drivers/Databricks/DatabricksConnection.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Net.Http; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache; @@ -39,6 +40,8 @@ internal class DatabricksConnection : SparkHttpConnection private bool _useCloudFetch = true; private bool _canDecompressLz4 = true; private long _maxBytesPerFile = DefaultMaxBytesPerFile; + private const bool DefaultRetryOnUnavailable= true; + private const int DefaultTemporarilyUnavailableRetryTimeout = 500; public DatabricksConnection(IReadOnlyDictionary properties) : base(properties) { @@ -122,6 +125,26 @@ private void ValidateProperties() /// internal long MaxBytesPerFile => _maxBytesPerFile; + /// + /// Gets a value indicating whether to retry requests that receive a 503 response with a Retry-After header. + /// + protected bool TemporarilyUnavailableRetry { get; private set; } = DefaultRetryOnUnavailable; + + /// + /// Gets the maximum total time in seconds to retry 503 responses before failing. + /// + protected int TemporarilyUnavailableRetryTimeout { get; private set; } = DefaultTemporarilyUnavailableRetryTimeout; + + protected override HttpMessageHandler CreateHttpHandler() + { + var baseHandler = base.CreateHttpHandler(); + if (TemporarilyUnavailableRetry) + { + return new RetryHttpHandler(baseHandler, TemporarilyUnavailableRetryTimeout); + } + return baseHandler; + } + internal override IArrowArrayStream NewReader(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) { // Get result format from metadata response if available @@ -259,6 +282,34 @@ private string EscapeSqlString(string value) return "`" + value.Replace("`", "``") + "`"; } + protected override void ValidateOptions() + { + base.ValidateOptions(); + + if (Properties.TryGetValue(DatabricksParameters.TemporarilyUnavailableRetry, out string? tempUnavailableRetryStr)) + { + if (!bool.TryParse(tempUnavailableRetryStr, out bool tempUnavailableRetryValue)) + { + throw new ArgumentOutOfRangeException(DatabricksParameters.TemporarilyUnavailableRetry, tempUnavailableRetryStr, + $"must be a value of false (disabled) or true (enabled). Default is true."); + } + + TemporarilyUnavailableRetry = tempUnavailableRetryValue; + } + + + if(Properties.TryGetValue(DatabricksParameters.TemporarilyUnavailableRetryTimeout, out string? tempUnavailableRetryTimeoutStr)) + { + if (!int.TryParse(tempUnavailableRetryTimeoutStr, out int tempUnavailableRetryTimeoutValue) || + tempUnavailableRetryTimeoutValue < 0) + { + throw new ArgumentOutOfRangeException(DatabricksParameters.TemporarilyUnavailableRetryTimeout, tempUnavailableRetryTimeoutStr, + $"must be a value of 0 (retry indefinitely) or a positive integer representing seconds. Default is 900 seconds (15 minutes)."); + } + TemporarilyUnavailableRetryTimeout = tempUnavailableRetryTimeoutValue; + } + } + protected override Task GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => Task.FromResult(response.DirectResults.ResultSetMetadata); protected override Task GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default) => diff --git a/csharp/src/Drivers/Databricks/DatabricksException.cs b/csharp/src/Drivers/Databricks/DatabricksException.cs new file mode 100644 index 0000000000..c6c3d203a2 --- /dev/null +++ b/csharp/src/Drivers/Databricks/DatabricksException.cs @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; + +namespace Apache.Arrow.Adbc.Drivers.Databricks +{ + public class DatabricksException : AdbcException + { + private string? _sqlState; + private int _nativeError; + + public DatabricksException() + { + } + + public DatabricksException(string message) : base(message) + { + } + + public DatabricksException(string message, AdbcStatusCode statusCode) : base(message, statusCode) + { + } + + public DatabricksException(string message, Exception innerException) : base(message, innerException) + { + } + + public DatabricksException(string message, AdbcStatusCode statusCode, Exception innerException) : base(message, statusCode, innerException) + { + } + + public override string? SqlState + { + get { return _sqlState; } + } + + public override int NativeError + { + get { return _nativeError; } + } + + internal DatabricksException SetSqlState(string sqlState) + { + _sqlState = sqlState; + return this; + } + + internal DatabricksException SetNativeError(int nativeError) + { + _nativeError = nativeError; + return this; + } + } +} diff --git a/csharp/src/Drivers/Databricks/DatabricksParameters.cs b/csharp/src/Drivers/Databricks/DatabricksParameters.cs index 1c963b4d7b..c99fbde3ee 100644 --- a/csharp/src/Drivers/Databricks/DatabricksParameters.cs +++ b/csharp/src/Drivers/Databricks/DatabricksParameters.cs @@ -75,6 +75,16 @@ public class DatabricksParameters : SparkParameters /// and value "true" will result in executing "set use_cached_result=true" on the server. /// public const string ServerSidePropertyPrefix = "adbc.databricks.SSP_"; + /// Controls whether to retry requests that receive a 503 response with a Retry-After header. + /// Default value is true (enabled). Set to false to disable retry behavior. + /// + public const string TemporarilyUnavailableRetry = "adbc.spark.temporarily_unavailable_retry"; + + /// + /// Maximum total time in seconds to retry 503 responses before failing. + /// Default value is 900 seconds (15 minutes). Set to 0 to retry indefinitely. + /// + public const string TemporarilyUnavailableRetryTimeout = "adbc.spark.temporarily_unavailable_retry_timeout"; } /// diff --git a/csharp/src/Drivers/Databricks/RetryHttpHandler.cs b/csharp/src/Drivers/Databricks/RetryHttpHandler.cs new file mode 100644 index 0000000000..2a9f281613 --- /dev/null +++ b/csharp/src/Drivers/Databricks/RetryHttpHandler.cs @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using System.IO; + +namespace Apache.Arrow.Adbc.Drivers.Databricks +{ + /// + /// HTTP handler that implements retry behavior for 503 responses with Retry-After headers. + /// + internal class RetryHttpHandler : DelegatingHandler + { + private readonly int _retryTimeoutSeconds; + + /// + /// Initializes a new instance of the class. + /// + /// The inner handler to delegate to. + /// Whether retry behavior is enabled. + /// Maximum total time in seconds to retry before failing. + public RetryHttpHandler(HttpMessageHandler innerHandler, int retryTimeoutSeconds) + : base(innerHandler) + { + _retryTimeoutSeconds = retryTimeoutSeconds; + } + + /// + /// Sends an HTTP request to the inner handler with retry logic for 503 responses. + /// + protected override async Task SendAsync( + HttpRequestMessage request, + CancellationToken cancellationToken) + { + // Clone the request content if it's not null so we can reuse it for retries + var requestContentClone = request.Content != null + ? await CloneHttpContentAsync(request.Content) + : null; + + HttpResponseMessage response; + string? lastErrorMessage = null; + DateTime startTime = DateTime.UtcNow; + int totalRetrySeconds = 0; + + do + { + // Set the content for each attempt (if needed) + if (requestContentClone != null && request.Content == null) + { + request.Content = await CloneHttpContentAsync(requestContentClone); + } + + response = await base.SendAsync(request, cancellationToken); + + // If it's not a 503 response, return immediately + if (response.StatusCode != HttpStatusCode.ServiceUnavailable) + { + return response; + } + + // Check for Retry-After header + if (!response.Headers.TryGetValues("Retry-After", out var retryAfterValues)) + { + // No Retry-After header, so return the response as is + return response; + } + + // Parse the Retry-After value + string retryAfterValue = string.Join(",", retryAfterValues); + if (!int.TryParse(retryAfterValue, out int retryAfterSeconds) || retryAfterSeconds <= 0) + { + // Invalid Retry-After value, return the response as is + return response; + } + + lastErrorMessage = $"Service temporarily unavailable (HTTP 503). Retry after {retryAfterSeconds} seconds."; + + // Dispose the response before retrying + response.Dispose(); + + // Reset the request content for the next attempt + request.Content = null; + + // Check if we've exceeded the timeout + totalRetrySeconds += retryAfterSeconds; + if (_retryTimeoutSeconds > 0 && totalRetrySeconds > _retryTimeoutSeconds) + { + // We've exceeded the timeout, so break out of the loop + break; + } + + // Wait for the specified retry time + await Task.Delay(TimeSpan.FromSeconds(retryAfterSeconds), cancellationToken); + } while (!cancellationToken.IsCancellationRequested); + + // If we get here, we've either exceeded the timeout or been cancelled + if (cancellationToken.IsCancellationRequested) + { + throw new OperationCanceledException("Request cancelled during retry wait", cancellationToken); + } + + throw new DatabricksException( + lastErrorMessage ?? "Service temporarily unavailable and retry timeout exceeded", + AdbcStatusCode.IOError); + } + + /// + /// Clones an HttpContent object so it can be reused for retries. + /// per .net guidance, we should not reuse the http content across multiple + /// request, as it maybe disposed. + /// + private static async Task CloneHttpContentAsync(HttpContent content) + { + var ms = new MemoryStream(); + await content.CopyToAsync(ms); + ms.Position = 0; + + var clone = new StreamContent(ms); + if (content.Headers != null) + { + foreach (var header in content.Headers) + { + clone.Headers.Add(header.Key, header.Value); + } + } + return clone; + } + } +} diff --git a/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs b/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs index 3462f15058..859ee7e849 100644 --- a/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs +++ b/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs @@ -307,6 +307,11 @@ public InvalidConnectionParametersTestData() Add(new(new() { /*[SparkParameters.Type] = SparkServerTypeConstants.Databricks,*/ [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "http-//hostname.com" }, typeof(ArgumentException))); Add(new(new() { /*[SparkParameters.Type] = SparkServerTypeConstants.Databricks,*/ [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "httpxxz://hostname.com:1234567890" }, typeof(ArgumentException))); Add(new(new() { /*[SparkParameters.Type] = SparkServerTypeConstants.Databricks,*/ [SparkParameters.Token] = "abcdef", [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Uri] = "http://valid.hostname.com" }, typeof(ArgumentOutOfRangeException))); + + // Tests for the new retry configuration parameters + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [DatabricksParameters.TemporarilyUnavailableRetry] = "invalid" }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [DatabricksParameters.TemporarilyUnavailableRetryTimeout] = "invalid" }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [DatabricksParameters.TemporarilyUnavailableRetryTimeout] = "-1" }, typeof(ArgumentOutOfRangeException))); } } } diff --git a/csharp/test/Drivers/Databricks/RetryHttpHandlerTest.cs b/csharp/test/Drivers/Databricks/RetryHttpHandlerTest.cs new file mode 100644 index 0000000000..fad2ad9eb7 --- /dev/null +++ b/csharp/test/Drivers/Databricks/RetryHttpHandlerTest.cs @@ -0,0 +1,244 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Databricks; +using Xunit; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks +{ + /// + /// Tests for the RetryHttpHandler class. + /// + public class RetryHttpHandlerTest + { + /// + /// Tests that the RetryHttpHandler properly processes 503 responses with Retry-After headers. + /// + [Fact] + public async Task RetryAfterHandlerProcesses503Response() + { + // Create a mock handler that returns a 503 response with a Retry-After header + var mockHandler = new MockHttpMessageHandler( + new HttpResponseMessage(HttpStatusCode.ServiceUnavailable) + { + Headers = { { "Retry-After", "1" } }, + Content = new StringContent("Service Unavailable") + }); + + // Create the RetryHttpHandler with retry enabled and a 5-second timeout + var retryHandler = new RetryHttpHandler(mockHandler, 5); + + // Create an HttpClient with our handler + var httpClient = new HttpClient(retryHandler); + + // Set the mock handler to return a success response after the first retry + mockHandler.SetResponseAfterRetryCount(1, new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent("Success") + }); + + // Send a request + var response = await httpClient.GetAsync("http://test.com"); + + // Verify the response is OK + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("Success", await response.Content.ReadAsStringAsync()); + Assert.Equal(2, mockHandler.RequestCount); // Initial request + 1 retry + } + + /// + /// Tests that the RetryHttpHandler throws an exception when the retry timeout is exceeded. + /// + [Fact] + public async Task RetryAfterHandlerThrowsWhenTimeoutExceeded() + { + // Create a mock handler that always returns a 503 response with a Retry-After header + var mockHandler = new MockHttpMessageHandler( + new HttpResponseMessage(HttpStatusCode.ServiceUnavailable) + { + Headers = { { "Retry-After", "2" } }, + Content = new StringContent("Service Unavailable") + }); + + // Create the RetryHttpHandler with retry enabled and a 1-second timeout + var retryHandler = new RetryHttpHandler(mockHandler, 1); + + // Create an HttpClient with our handler + var httpClient = new HttpClient(retryHandler); + + // Send a request and expect an AdbcException + var exception = await Assert.ThrowsAsync(async () => + await httpClient.GetAsync("http://test.com")); + + // Verify the exception has the correct SQL state in the message + Assert.Contains("[SQLState: 08001]", exception.Message); + Assert.Equal(AdbcStatusCode.IOError, exception.Status); + + // Verify we only tried once (since the Retry-After value of 2 exceeds our timeout of 1) + Assert.Equal(1, mockHandler.RequestCount); + } + + /// + /// Tests that the RetryHttpHandler handles non-503 responses correctly. + /// + [Fact] + public async Task RetryAfterHandlerHandlesNon503Response() + { + // Create a mock handler that returns a 404 response + var mockHandler = new MockHttpMessageHandler( + new HttpResponseMessage(HttpStatusCode.NotFound) + { + Content = new StringContent("Not Found") + }); + + // Create the RetryHttpHandler with retry enabled + var retryHandler = new RetryHttpHandler(mockHandler, 5); + + // Create an HttpClient with our handler + var httpClient = new HttpClient(retryHandler); + + // Send a request + var response = await httpClient.GetAsync("http://test.com"); + + // Verify the response is 404 + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + Assert.Equal("Not Found", await response.Content.ReadAsStringAsync()); + Assert.Equal(1, mockHandler.RequestCount); // Only the initial request, no retries + } + + /// + /// Tests that the RetryHttpHandler handles 503 responses without Retry-After headers correctly. + /// + [Fact] + public async Task RetryAfterHandlerHandles503WithoutRetryAfterHeader() + { + // Create a mock handler that returns a 503 response without a Retry-After header + var mockHandler = new MockHttpMessageHandler( + new HttpResponseMessage(HttpStatusCode.ServiceUnavailable) + { + Content = new StringContent("Service Unavailable") + }); + + // Create the RetryHttpHandler with retry enabled + var retryHandler = new RetryHttpHandler(mockHandler, 5); + + // Create an HttpClient with our handler + var httpClient = new HttpClient(retryHandler); + + // Send a request + var response = await httpClient.GetAsync("http://test.com"); + + // Verify the response is 503 + Assert.Equal(HttpStatusCode.ServiceUnavailable, response.StatusCode); + Assert.Equal("Service Unavailable", await response.Content.ReadAsStringAsync()); + Assert.Equal(1, mockHandler.RequestCount); // Only the initial request, no retries + } + + /// + /// Tests that the RetryHttpHandler handles invalid Retry-After headers correctly. + /// + [Fact] + public async Task RetryAfterHandlerHandlesInvalidRetryAfterHeader() + { + // Create a mock handler that returns a 503 response with an invalid Retry-After header + var mockHandler = new MockHttpMessageHandler( + new HttpResponseMessage(HttpStatusCode.ServiceUnavailable) + { + Content = new StringContent("Service Unavailable") + }); + + // Add the invalid Retry-After header directly in the test + var response = new HttpResponseMessage(HttpStatusCode.ServiceUnavailable) + { + Content = new StringContent("Service Unavailable") + }; + response.Headers.TryAddWithoutValidation("Retry-After", "invalid"); + mockHandler.SetResponseAfterRetryCount(0, response); + + // Create the RetryHttpHandler with retry enabled + var retryHandler = new RetryHttpHandler(mockHandler, 5); + + // Create an HttpClient with our handler + var httpClient = new HttpClient(retryHandler); + + // Send a request + response = await httpClient.GetAsync("http://test.com"); + + // Verify the response is 503 + Assert.Equal(HttpStatusCode.ServiceUnavailable, response.StatusCode); + Assert.Equal("Service Unavailable", await response.Content.ReadAsStringAsync()); + Assert.Equal(1, mockHandler.RequestCount); // Only the initial request, no retries + } + + /// + /// Mock HttpMessageHandler for testing the RetryHttpHandler. + /// + private class MockHttpMessageHandler : HttpMessageHandler + { + private readonly HttpResponseMessage _defaultResponse; + private HttpResponseMessage? _responseAfterRetryCount; + private int _retryCountForResponse; + + public int RequestCount { get; private set; } + + public MockHttpMessageHandler(HttpResponseMessage defaultResponse) + { + _defaultResponse = defaultResponse; + } + + public void SetResponseAfterRetryCount(int retryCount, HttpResponseMessage response) + { + _retryCountForResponse = retryCount; + _responseAfterRetryCount = response; + } + + protected override Task SendAsync( + HttpRequestMessage request, + CancellationToken cancellationToken) + { + RequestCount++; + + if (_responseAfterRetryCount != null && RequestCount > _retryCountForResponse) + { + return Task.FromResult(_responseAfterRetryCount); + } + + // Create a new response instance to avoid modifying the original + var response = new HttpResponseMessage + { + StatusCode = _defaultResponse.StatusCode, + Content = _defaultResponse.Content + }; + + // Copy headers only if they exist + if (_defaultResponse.Headers.Contains("Retry-After")) + { + foreach (var value in _defaultResponse.Headers.GetValues("Retry-After")) + { + response.Headers.Add("Retry-After", value); + } + } + + return Task.FromResult(response); + } + } + } +}