Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@
import javax.inject.Named;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;
import static org.opensearch.dataprepper.plugins.source.microsoft_office365.utils.Constants.CONTENT_TYPES;
Expand Down Expand Up @@ -89,59 +92,136 @@ public Office365RestClient(final Office365AuthenticationInterface authConfig,
this.errorTypeMetricCounterMap = getErrorTypeMetricCounterMap(pluginMetrics);
}

/**
* Lists current subscriptions for Office 365 audit logs.
*
* @return List of subscription maps containing contentType, status, and webhook information
*/
private List<Map<String, Object>> listSubscriptions() {
log.info("Listing Office 365 subscriptions");
final String SUBSCRIPTION_LIST_URL = MANAGEMENT_API_BASE_URL + "%s/activity/feed/subscriptions/list";
String listUrl = String.format(SUBSCRIPTION_LIST_URL, authConfig.getTenantId());

HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);

try {
return RetryHandler.executeWithRetry(() -> {
headers.setBearerAuth(authConfig.getAccessToken());
ResponseEntity<List<Map<String, Object>>> response = restTemplate.exchange(
listUrl,
HttpMethod.GET,
new HttpEntity<>(headers),
new ParameterizedTypeReference<>() {}
);
log.debug("Current subscriptions: {}", response.getBody());
return response.getBody();
}, authConfig::renewCredentials);
} catch (HttpClientErrorException | HttpServerErrorException e) {
HttpStatus statusCode = e.getStatusCode();
publishErrorTypeMetricCounter(statusCode.getReasonPhrase(), this.errorTypeMetricCounterMap);
log.error(NOISY, "Failed to list subscriptions with status code {}: {}",
statusCode, e.getMessage());
throw new RuntimeException("Failed to list subscriptions: " + e.getMessage(), e);
} catch (Exception e) {
if (e instanceof SecurityException) {
publishErrorTypeMetricCounter(HttpStatus.FORBIDDEN.getReasonPhrase(), this.errorTypeMetricCounterMap);
}
log.error(NOISY, "Failed to list subscriptions", e);
throw new RuntimeException("Failed to list subscriptions: " + e.getMessage(), e);
}
}

/**
* Starts subscriptions for the specified content types.
*
* @param contentTypesToStart List of content types to start subscriptions for
*/
private void startSubscriptionsForContentTypes(List<String> contentTypesToStart) {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setContentLength(0);

for (String contentType : contentTypesToStart) {
final String SUBSCRIPTION_START_URL = MANAGEMENT_API_BASE_URL + "%s/activity/feed/subscriptions/start?contentType=%s";
String url = String.format(SUBSCRIPTION_START_URL,
authConfig.getTenantId(),
contentType);

RetryHandler.executeWithRetry(() -> {
try {
headers.setBearerAuth(authConfig.getAccessToken());
ResponseEntity<String> response = restTemplate.exchange(
url,
HttpMethod.POST,
new HttpEntity<>(headers),
String.class
);
log.info("Successfully started subscription for {}: {}", contentType, response.getBody());
return response;
} catch (HttpClientErrorException | HttpServerErrorException e) {
if (e.getResponseBodyAsString().contains("AF20024")) {
log.debug("Subscription for {} is already enabled", contentType);
return null;
}
throw e;
}
}, authConfig::renewCredentials);
}

log.info("Successfully started {} subscription(s)", contentTypesToStart.size());
}

/**
* Starts and verifies subscriptions for Office 365 audit logs.
* Only starts subscriptions for content types that are not already enabled.
* If listing subscriptions fails, falls back to starting all content types.
*/
public void startSubscriptions() {
log.info("Starting Office 365 subscriptions for audit logs");

List<String> contentTypesToStart = new ArrayList<>();

// Try to get current subscriptions to determine which need to be started
try {
HttpHeaders headers = new HttpHeaders();

headers.setContentType(MediaType.APPLICATION_JSON);

// TODO: Only start the subscriptions only if the call commented
// out below doesn't return all the audit log types
// Check current subscriptions
// final String SUBSCRIPTION_LIST_URL = MANAGEMENT_API_BASE_URL + "%s/activity/feed/subscriptions/list";
// String listUrl = String.format(SUBSCRIPTION_LIST_URL, authConfig.getTenantId());
//
// ResponseEntity<String> listResponse = restTemplate.exchange(
// listUrl,
// HttpMethod.GET,
// new HttpEntity<>(headers),
// String.class
// );
// log.debug("Current subscriptions: {}", listResponse.getBody());

// Start subscriptions for each content type
headers.setContentLength(0);
List<Map<String, Object>> currentSubscriptions = listSubscriptions();

// Determine which content types are already enabled
Set<String> enabledContentTypes = new HashSet<>();
for (Map<String, Object> subscription : currentSubscriptions) {
String contentType = (String) subscription.get("contentType");
String status = (String) subscription.get("status");

if ("enabled".equalsIgnoreCase(status)) {
enabledContentTypes.add(contentType);
log.info("Content type {} is already enabled", contentType);
}
}

// Identify content types that need to be started
for (String contentType : CONTENT_TYPES) {
final String SUBSCRIPTION_START_URL = MANAGEMENT_API_BASE_URL + "%s/activity/feed/subscriptions/start?contentType=%s";
String url = String.format(SUBSCRIPTION_START_URL,
authConfig.getTenantId(),
contentType);

RetryHandler.executeWithRetry(() -> {
try {
headers.setBearerAuth(authConfig.getAccessToken());
ResponseEntity<String> response = restTemplate.exchange(
url,
HttpMethod.POST,
new HttpEntity<>(headers),
String.class
);
log.debug("Started subscription for {}: {}", contentType, response.getBody());
return response;
} catch (HttpClientErrorException | HttpServerErrorException e) {
if (e.getResponseBodyAsString().contains("AF20024")) {
log.debug("Subscription for {} is already enabled", contentType);
return null;
}
throw e;
}
}, authConfig::renewCredentials);
if (!enabledContentTypes.contains(contentType)) {
contentTypesToStart.add(contentType);
log.info("Content type {} needs to be started", contentType);
}
}

// If all content types are already enabled, we're done
if (contentTypesToStart.isEmpty()) {
log.info("All content types are already enabled. No subscriptions need to be started.");
return;
}
} catch (Exception e) {
// If listing subscriptions fails, fall back to starting all content types
log.warn("Failed to list subscriptions, will attempt to start all content types as fallback: {}", e.getMessage());
for (String contentType : CONTENT_TYPES) {
contentTypesToStart.add(contentType);
}
}

// Start subscriptions for the identified content types
try {
startSubscriptionsForContentTypes(contentTypesToStart);
} catch (HttpClientErrorException | HttpServerErrorException e) {
HttpStatus statusCode = e.getStatusCode();
publishErrorTypeMetricCounter(statusCode.getReasonPhrase(), this.errorTypeMetricCounterMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,201 @@ void setUp() throws NoSuchFieldException, IllegalAccessException{

@Test
void testStartSubscriptionsSuccess() {
// Mock auth config
when(authConfig.getTenantId()).thenReturn("test-tenant-id");
when(authConfig.getAccessToken()).thenReturn("test-access-token");

// Mock listSubscriptions to return all subscriptions as disabled
List<Map<String, Object>> mockSubscriptions = new ArrayList<>();
for (String contentType : CONTENT_TYPES) {
Map<String, Object> subscription = new HashMap<>();
subscription.put("contentType", contentType);
subscription.put("status", "disabled");
mockSubscriptions.add(subscription);
}
ResponseEntity<List<Map<String, Object>>> listResponse = new ResponseEntity<>(mockSubscriptions, HttpStatus.OK);
when(restTemplate.exchange(
anyString(),
eq(HttpMethod.GET),
any(),
any(ParameterizedTypeReference.class)
)).thenReturn(listResponse);

// Mock startSubscription calls
ResponseEntity<String> mockResponse = new ResponseEntity<>("{\"status\":\"enabled\"}", HttpStatus.OK);
when(restTemplate.exchange(anyString(), eq(HttpMethod.POST), any(), eq(String.class)))
.thenReturn(mockResponse);
when(restTemplate.exchange(
anyString(),
eq(HttpMethod.POST),
any(),
eq(String.class)
)).thenReturn(mockResponse);

assertDoesNotThrow(() -> office365RestClient.startSubscriptions());

// Verify list was called once
ArgumentCaptor<String> listUrlCaptor = ArgumentCaptor.forClass(String.class);
verify(restTemplate, times(1)).exchange(
listUrlCaptor.capture(),
eq(HttpMethod.GET),
any(),
any(ParameterizedTypeReference.class)
);
assertTrue(listUrlCaptor.getValue().contains("/subscriptions/list"));

// Verify start was called for all content types
ArgumentCaptor<String> startUrlCaptor = ArgumentCaptor.forClass(String.class);
verify(restTemplate, times(CONTENT_TYPES.length)).exchange(
startUrlCaptor.capture(),
eq(HttpMethod.POST),
any(),
eq(String.class)
);
assertTrue(startUrlCaptor.getAllValues().stream().allMatch(url -> url.contains("/subscriptions/start")));
}

@Test
void testStartSubscriptionsPartiallyEnabled() {
// Mock auth config
when(authConfig.getTenantId()).thenReturn("test-tenant-id");
when(authConfig.getAccessToken()).thenReturn("test-access-token");

// Mock listSubscriptions to return some subscriptions as enabled
List<Map<String, Object>> mockSubscriptions = new ArrayList<>();
for (int i = 0; i < CONTENT_TYPES.length; i++) {
Map<String, Object> subscription = new HashMap<>();
subscription.put("contentType", CONTENT_TYPES[i]);
// First two are enabled, rest are disabled
subscription.put("status", i < 2 ? "enabled" : "disabled");
mockSubscriptions.add(subscription);
}
ResponseEntity<List<Map<String, Object>>> listResponse = new ResponseEntity<>(mockSubscriptions, HttpStatus.OK);
when(restTemplate.exchange(
anyString(),
eq(HttpMethod.GET),
any(),
any(ParameterizedTypeReference.class)
)).thenReturn(listResponse);

// Mock startSubscription calls
ResponseEntity<String> mockResponse = new ResponseEntity<>("{\"status\":\"enabled\"}", HttpStatus.OK);
when(restTemplate.exchange(
anyString(),
eq(HttpMethod.POST),
any(),
eq(String.class)
)).thenReturn(mockResponse);

assertDoesNotThrow(() -> office365RestClient.startSubscriptions());

// Verify list was called once
ArgumentCaptor<String> listUrlCaptor = ArgumentCaptor.forClass(String.class);
verify(restTemplate, times(1)).exchange(
listUrlCaptor.capture(),
eq(HttpMethod.GET),
any(),
any(ParameterizedTypeReference.class)
);
assertTrue(listUrlCaptor.getValue().contains("/subscriptions/list"));

// Verify start was called only for disabled content types (CONTENT_TYPES.length - 2)
ArgumentCaptor<String> startUrlCaptor = ArgumentCaptor.forClass(String.class);
verify(restTemplate, times(CONTENT_TYPES.length - 2)).exchange(
startUrlCaptor.capture(),
eq(HttpMethod.POST),
any(),
eq(String.class)
);
assertTrue(startUrlCaptor.getAllValues().stream().allMatch(url -> url.contains("/subscriptions/start")));
}

@Test
void testStartSubscriptionsAllEnabled() {
// Mock auth config
when(authConfig.getTenantId()).thenReturn("test-tenant-id");
when(authConfig.getAccessToken()).thenReturn("test-access-token");

// Mock listSubscriptions to return all subscriptions as enabled
List<Map<String, Object>> mockSubscriptions = new ArrayList<>();
for (String contentType : CONTENT_TYPES) {
Map<String, Object> subscription = new HashMap<>();
subscription.put("contentType", contentType);
subscription.put("status", "enabled");
mockSubscriptions.add(subscription);
}
ResponseEntity<List<Map<String, Object>>> listResponse = new ResponseEntity<>(mockSubscriptions, HttpStatus.OK);
when(restTemplate.exchange(
anyString(),
eq(HttpMethod.GET),
any(),
any(ParameterizedTypeReference.class)
)).thenReturn(listResponse);

assertDoesNotThrow(() -> office365RestClient.startSubscriptions());

// Verify list was called once
ArgumentCaptor<String> listUrlCaptor = ArgumentCaptor.forClass(String.class);
verify(restTemplate, times(1)).exchange(
listUrlCaptor.capture(),
eq(HttpMethod.GET),
any(),
any(ParameterizedTypeReference.class)
);
assertTrue(listUrlCaptor.getValue().contains("/subscriptions/list"));

// Verify start was never called since all are enabled
verify(restTemplate, never()).exchange(
anyString(),
eq(HttpMethod.POST),
any(),
eq(String.class)
);
}

@Test
void testStartSubscriptionsListFailsFallbackToAll() {
// Mock auth config
when(authConfig.getTenantId()).thenReturn("test-tenant-id");
when(authConfig.getAccessToken()).thenReturn("test-access-token");

// Mock listSubscriptions to throw an exception
when(restTemplate.exchange(
anyString(),
eq(HttpMethod.GET),
any(),
any(ParameterizedTypeReference.class)
)).thenThrow(new HttpClientErrorException(HttpStatus.INTERNAL_SERVER_ERROR));

// Mock startSubscription calls to succeed
ResponseEntity<String> mockResponse = new ResponseEntity<>("{\"status\":\"enabled\"}", HttpStatus.OK);
when(restTemplate.exchange(
anyString(),
eq(HttpMethod.POST),
any(),
eq(String.class)
)).thenReturn(mockResponse);

// Should not throw exception, should fall back to starting all
assertDoesNotThrow(() -> office365RestClient.startSubscriptions());

// Verify list was attempted more than once
ArgumentCaptor<String> listUrlCaptor = ArgumentCaptor.forClass(String.class);
verify(restTemplate, times(6)).exchange(
listUrlCaptor.capture(),
eq(HttpMethod.GET),
any(),
any(ParameterizedTypeReference.class)
);
assertTrue(listUrlCaptor.getValue().contains("/subscriptions/list"));

// Verify start was called for all content types as fallback
ArgumentCaptor<String> startUrlCaptor = ArgumentCaptor.forClass(String.class);
verify(restTemplate, times(CONTENT_TYPES.length)).exchange(
startUrlCaptor.capture(),
eq(HttpMethod.POST),
any(),
eq(String.class)
);
assertTrue(startUrlCaptor.getAllValues().stream().allMatch(url -> url.contains("/subscriptions/start")));
}

@Test
Expand Down
Loading