|
11 | 11 | import uuid |
12 | 12 | from io import BytesIO, SEEK_SET, UnsupportedOperation |
13 | 13 | from time import time |
14 | | -from typing import Any, Dict, Optional, TYPE_CHECKING, Union |
| 14 | +from typing import Any, Dict, List, Optional, TypeVar, TYPE_CHECKING, Union |
15 | 15 | from urllib.parse import ( |
16 | 16 | parse_qsl, |
17 | 17 | urlencode, |
|
29 | 29 | RequestHistory, |
30 | 30 | SansIOHTTPPolicy, |
31 | 31 | ) |
| 32 | +from azure.core.pipeline import PipelineRequest |
| 33 | +from azure.core.pipeline.transport import ( |
| 34 | + HttpRequest as LegacyHttpRequest, |
| 35 | + HttpResponse as LegacyHttpResponse, |
| 36 | +) |
| 37 | +from azure.core.rest import HttpRequest, HttpResponse |
32 | 38 |
|
33 | 39 | from .authentication import AzureSigningError, StorageHttpChallenge |
34 | 40 | from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE |
|
54 | 60 | ) |
55 | 61 |
|
56 | 62 |
|
| 63 | +HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse) |
| 64 | +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) |
| 65 | + |
| 66 | + |
57 | 67 | _LOGGER = logging.getLogger(__name__) |
58 | 68 | CONTENT_LENGTH_HEADER = "Content-Length" |
59 | 69 | MD5_HEADER = "Content-MD5" |
@@ -843,3 +853,70 @@ def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") |
843 | 853 | self.authorize_request(request, scope, tenant_id=challenge.tenant_id) |
844 | 854 |
|
845 | 855 | return True |
| 856 | + |
| 857 | + |
| 858 | +class StorageSensitiveHeaderCleanupPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): |
| 859 | + """A simple policy that cleans up sensitive headers |
| 860 | +
|
| 861 | + :keyword list[str] blocked_redirect_headers: The headers to clean up when redirecting to another domain. |
| 862 | + :keyword bool disable_redirect_cleanup: Opt out cleaning up sensitive headers when redirecting to another domain. |
| 863 | + """ |
| 864 | + |
| 865 | + DEFAULT_SENSITIVE_HEADERS = set( |
| 866 | + [ |
| 867 | + "Authorization", |
| 868 | + "x-ms-authorization-auxiliary", |
| 869 | + "x-ms-copy-source", |
| 870 | + "x-ms-copy-source-authorization", |
| 871 | + "x-ms-rename-source", |
| 872 | + ] |
| 873 | + ) |
| 874 | + |
| 875 | + DEFAULT_SENSITIVE_QUERY_PARAMS = set( |
| 876 | + [ |
| 877 | + "sig", |
| 878 | + ] |
| 879 | + ) |
| 880 | + |
| 881 | + def __init__( |
| 882 | + self, |
| 883 | + *, |
| 884 | + blocked_redirect_headers: Optional[List[str]] = None, |
| 885 | + blocked_query_params: Optional[List[str]] = None, |
| 886 | + disable_redirect_cleanup: bool = False, |
| 887 | + **kwargs: Any |
| 888 | + ) -> None: |
| 889 | + self._disable_redirect_cleanup = disable_redirect_cleanup |
| 890 | + self._blocked_redirect_headers = ( |
| 891 | + StorageSensitiveHeaderCleanupPolicy.DEFAULT_SENSITIVE_HEADERS |
| 892 | + if blocked_redirect_headers is None |
| 893 | + else blocked_redirect_headers |
| 894 | + ) |
| 895 | + self._blocked_query_params = ( |
| 896 | + StorageSensitiveHeaderCleanupPolicy.DEFAULT_SENSITIVE_QUERY_PARAMS |
| 897 | + if blocked_query_params is None |
| 898 | + else blocked_query_params |
| 899 | + ) |
| 900 | + |
| 901 | + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: |
| 902 | + """This is executed before sending the request to the next policy. |
| 903 | +
|
| 904 | + :param request: The PipelineRequest object. |
| 905 | + :type request: ~azure.core.pipeline.PipelineRequest |
| 906 | + """ |
| 907 | + # "insecure_domain_change" is used to indicate that a redirect |
| 908 | + # has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy |
| 909 | + # to clean up sensitive headers. |
| 910 | + insecure_domain_change = request.context.get("insecure_domain_change", False) |
| 911 | + if not self._disable_redirect_cleanup and insecure_domain_change: |
| 912 | + # Clean up request query parameters |
| 913 | + parsed = urlparse(request.http_request.url) |
| 914 | + kept = [ |
| 915 | + pair for pair in parsed.query.split("&") |
| 916 | + if pair and pair.split("=", 1)[0] not in self._blocked_query_params |
| 917 | + ] |
| 918 | + request.http_request.url = urlunparse(parsed._replace(query="&".join(kept))) |
| 919 | + |
| 920 | + # Clean up request headers |
| 921 | + for header in self._blocked_redirect_headers: |
| 922 | + request.http_request.headers.pop(header, None) |
0 commit comments