|
4 | 4 |
|
5 | 5 | import json |
6 | 6 | import logging |
| 7 | +import threading |
| 8 | +import time |
7 | 9 | from datetime import timedelta |
8 | 10 | from typing import Optional, Union |
9 | 11 | from unittest.mock import Mock |
@@ -785,3 +787,132 @@ def mock_request(method, url, data, headers): |
785 | 787 | raise Exception( |
786 | 788 | f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}" |
787 | 789 | ) |
| 790 | + |
| 791 | + |
| 792 | +class TestConcurrentTokenRefresh: |
| 793 | + """ |
| 794 | + Test class for verifying thread-safe token refresh behavior. |
| 795 | +
|
| 796 | + These tests ensure that when multiple threads (streams) attempt to refresh |
| 797 | + an expired token simultaneously, only one refresh actually occurs and |
| 798 | + others wait and use the refreshed token. |
| 799 | + """ |
| 800 | + |
| 801 | + def test_concurrent_token_refresh_only_refreshes_once(self, mocker): |
| 802 | + """ |
| 803 | + When multiple threads detect an expired token and try to refresh simultaneously, |
| 804 | + only one thread should actually perform the refresh. Others should wait and |
| 805 | + use the newly refreshed token. |
| 806 | + """ |
| 807 | + refresh_call_count = 0 |
| 808 | + refresh_call_lock = threading.Lock() |
| 809 | + |
| 810 | + def mock_refresh_access_token(self): |
| 811 | + nonlocal refresh_call_count |
| 812 | + with refresh_call_lock: |
| 813 | + refresh_call_count += 1 |
| 814 | + time.sleep(0.1) |
| 815 | + return ("new_access_token", ab_datetime_now() + timedelta(hours=1)) |
| 816 | + |
| 817 | + mocker.patch.object( |
| 818 | + Oauth2Authenticator, |
| 819 | + "refresh_access_token", |
| 820 | + mock_refresh_access_token, |
| 821 | + ) |
| 822 | + |
| 823 | + oauth = Oauth2Authenticator( |
| 824 | + token_refresh_endpoint="https://refresh_endpoint.com", |
| 825 | + client_id="client_id", |
| 826 | + client_secret="client_secret", |
| 827 | + refresh_token="refresh_token", |
| 828 | + token_expiry_date=ab_datetime_now() - timedelta(hours=1), |
| 829 | + ) |
| 830 | + |
| 831 | + results = [] |
| 832 | + errors = [] |
| 833 | + |
| 834 | + def get_token(): |
| 835 | + try: |
| 836 | + token = oauth.get_access_token() |
| 837 | + results.append(token) |
| 838 | + except Exception as e: |
| 839 | + errors.append(e) |
| 840 | + |
| 841 | + threads = [threading.Thread(target=get_token) for _ in range(5)] |
| 842 | + for t in threads: |
| 843 | + t.start() |
| 844 | + for t in threads: |
| 845 | + t.join() |
| 846 | + |
| 847 | + assert len(errors) == 0, f"Unexpected errors: {errors}" |
| 848 | + assert len(results) == 5 |
| 849 | + assert all(token == "new_access_token" for token in results) |
| 850 | + assert refresh_call_count == 1, f"Expected 1 refresh call, got {refresh_call_count}" |
| 851 | + |
| 852 | + def test_single_use_refresh_token_concurrent_refresh_only_refreshes_once(self, mocker): |
| 853 | + """ |
| 854 | + For SingleUseRefreshTokenOauth2Authenticator, concurrent refresh attempts |
| 855 | + should also only result in one actual refresh to prevent invalidating |
| 856 | + the single-use refresh token. |
| 857 | + """ |
| 858 | + refresh_call_count = 0 |
| 859 | + refresh_call_lock = threading.Lock() |
| 860 | + |
| 861 | + connector_config = { |
| 862 | + "credentials": { |
| 863 | + "client_id": "client_id", |
| 864 | + "client_secret": "client_secret", |
| 865 | + "refresh_token": "refresh_token", |
| 866 | + "access_token": "old_access_token", |
| 867 | + "token_expiry_date": str(ab_datetime_now() - timedelta(hours=1)), |
| 868 | + } |
| 869 | + } |
| 870 | + |
| 871 | + def mock_refresh_access_token(self): |
| 872 | + nonlocal refresh_call_count |
| 873 | + with refresh_call_lock: |
| 874 | + refresh_call_count += 1 |
| 875 | + time.sleep(0.1) |
| 876 | + return ( |
| 877 | + "new_access_token", |
| 878 | + ab_datetime_now() + timedelta(hours=1), |
| 879 | + "new_refresh_token", |
| 880 | + ) |
| 881 | + |
| 882 | + mocker.patch.object( |
| 883 | + SingleUseRefreshTokenOauth2Authenticator, |
| 884 | + "refresh_access_token", |
| 885 | + mock_refresh_access_token, |
| 886 | + ) |
| 887 | + |
| 888 | + mocker.patch.object( |
| 889 | + SingleUseRefreshTokenOauth2Authenticator, |
| 890 | + "_emit_control_message", |
| 891 | + lambda self: None, |
| 892 | + ) |
| 893 | + |
| 894 | + oauth = SingleUseRefreshTokenOauth2Authenticator( |
| 895 | + connector_config=connector_config, |
| 896 | + token_refresh_endpoint="https://refresh_endpoint.com", |
| 897 | + ) |
| 898 | + |
| 899 | + results = [] |
| 900 | + errors = [] |
| 901 | + |
| 902 | + def get_token(): |
| 903 | + try: |
| 904 | + token = oauth.get_access_token() |
| 905 | + results.append(token) |
| 906 | + except Exception as e: |
| 907 | + errors.append(e) |
| 908 | + |
| 909 | + threads = [threading.Thread(target=get_token) for _ in range(5)] |
| 910 | + for t in threads: |
| 911 | + t.start() |
| 912 | + for t in threads: |
| 913 | + t.join() |
| 914 | + |
| 915 | + assert len(errors) == 0, f"Unexpected errors: {errors}" |
| 916 | + assert len(results) == 5 |
| 917 | + assert all(token == "new_access_token" for token in results) |
| 918 | + assert refresh_call_count == 1, f"Expected 1 refresh call, got {refresh_call_count}" |
0 commit comments