diff --git a/components/nimbus/src/stateful/nimbus_client.rs b/components/nimbus/src/stateful/nimbus_client.rs index 6bf64c69c8..05ce8f1201 100644 --- a/components/nimbus/src/stateful/nimbus_client.rs +++ b/components/nimbus/src/stateful/nimbus_client.rs @@ -48,11 +48,9 @@ use crate::stateful::gecko_prefs::{ }; use crate::stateful::matcher::AppContext; use crate::stateful::persistence::{Database, StoreId, Writer}; -use crate::stateful::targeting::{RecordedContext, validate_event_queries}; +use crate::stateful::targeting::{RecordedContext, execute_event_queries, validate_event_queries}; use crate::stateful::updating::{read_and_remove_pending_experiments, write_pending_experiments}; use crate::strings::fmt_with_map; -#[cfg(test)] -use crate::tests::helpers::TestRecordedContext; use crate::{ AvailableExperiment, AvailableRandomizationUnits, EnrolledExperiment, EnrollmentStatus, }; @@ -200,7 +198,7 @@ impl NimbusClient { Ok(v) => v, Err(e) => return Err(NimbusError::JSONError("targeting_helper = nimbus::stateful::nimbus_client::NimbusClient::begin_initialize::serde_json::to_value".into(), e.to_string())) }); - recorded_context.execute_queries(targeting_helper.as_ref())?; + execute_event_queries(&**recorded_context, targeting_helper.as_ref())?; state .targeting_attributes .set_recorded_context(recorded_context.to_json()); @@ -913,23 +911,6 @@ impl NimbusClient { })) } - #[cfg(test)] - pub fn get_recorded_context(&self) -> &&TestRecordedContext { - self.recorded_context - .clone() - .map(|ref recorded_context| - // SAFETY: The cast to TestRecordedContext is safe because the Rust instance is - // guaranteed to be a TestRecordedContext instance. TestRecordedContext is the only - // Rust-implemented version of RecordedContext, and, like this method, is only - // used in tests. - unsafe { - std::mem::transmute::<&&dyn RecordedContext, &&TestRecordedContext>( - &&**recorded_context, - ) - }) - .expect("failed to unwrap RecordedContext object") - } - pub fn set_install_time(&mut self, then: DateTime) { let mut state = self.mutable_state.lock().unwrap(); state.install_date = Some(then); diff --git a/components/nimbus/src/stateful/targeting.rs b/components/nimbus/src/stateful/targeting.rs index dfa805e076..91b3be2ca1 100644 --- a/components/nimbus/src/stateful/targeting.rs +++ b/components/nimbus/src/stateful/targeting.rs @@ -62,55 +62,49 @@ pub trait RecordedContext: Send + Sync { fn record(&self); } -impl dyn RecordedContext { - pub fn execute_queries( - &self, - nimbus_targeting_helper: &NimbusTargetingHelper, - ) -> Result> { - let results: HashMap = - HashMap::from_iter(self.get_event_queries().iter().filter_map(|(key, query)| { - match nimbus_targeting_helper.evaluate_jexl_raw_value(query) { - Ok(result) => match result.as_f64() { - Some(v) => Some((key.clone(), v)), - None => { - warn!( - "Value '{}' for query '{}' was not a string", - result.to_string(), - query - ); - None - } - }, - Err(err) => { - let error_string = format!( - "error during jexl evaluation for query '{}' — {}", - query, err +pub fn execute_event_queries( + recorded_context: &dyn RecordedContext, + nimbus_targeting_helper: &NimbusTargetingHelper, +) -> Result> { + let results: HashMap = + HashMap::from_iter(recorded_context.get_event_queries().iter().filter_map( + |(key, query)| match nimbus_targeting_helper.evaluate_jexl_raw_value(query) { + Ok(result) => match result.as_f64() { + Some(v) => Some((key.clone(), v)), + None => { + warn!( + "Value '{}' for query '{}' was not a string", + result.to_string(), + query ); - warn!("{}", error_string); None } + }, + Err(err) => { + let error_string = format!( + "error during jexl evaluation for query '{}' — {}", + query, err + ); + warn!("{}", error_string); + None } - })); - self.set_event_query_values(results.clone()); - Ok(results) - } + }, + )); + recorded_context.set_event_query_values(results.clone()); + Ok(results) +} - pub fn validate_queries(&self) -> Result<()> { - for query in self.get_event_queries().values() { - match EventQueryType::validate_query(query) { - Ok(true) => continue, - Ok(false) => { - return Err(NimbusError::BehaviorError( - BehaviorError::EventQueryParseError(query.clone()), - )); - } - Err(err) => return Err(err), +pub fn validate_event_queries(recorded_context: Arc) -> Result<()> { + for query in recorded_context.get_event_queries().values() { + match EventQueryType::validate_query(query) { + Ok(true) => continue, + Ok(false) => { + return Err(NimbusError::BehaviorError( + BehaviorError::EventQueryParseError(query.clone()), + )); } + Err(err) => return Err(err), } - Ok(()) } -} - -pub fn validate_event_queries(recorded_context: Arc) -> Result<()> { - recorded_context.validate_queries() + Ok(()) } diff --git a/components/nimbus/src/tests/helpers.rs b/components/nimbus/src/tests/helpers.rs index a9a2ba2068..c2065c0f07 100644 --- a/components/nimbus/src/tests/helpers.rs +++ b/components/nimbus/src/tests/helpers.rs @@ -4,8 +4,6 @@ #![allow(unexpected_cfgs)] -pub use self::detail::*; -use crate::metrics::EnrollmentStatusExtraDef; #[cfg(feature = "stateful")] use std::collections::HashMap; use std::collections::HashSet; @@ -16,12 +14,14 @@ use serde::Serialize; use serde_json::Map; use serde_json::{Value, json}; +pub use self::detail::*; use crate::enrollment::{ EnrolledFeatureConfig, EnrolledReason, EnrollmentChangeEvent, ExperimentEnrollment, NotEnrolledReason, }; #[cfg(feature = "stateful")] use crate::json::JsonObject; +use crate::metrics::EnrollmentStatusExtraDef; #[cfg(feature = "stateful")] use crate::stateful::behavior::EventStore; #[cfg(feature = "stateful")] @@ -80,17 +80,17 @@ struct RecordedContextState { } #[cfg(feature = "stateful")] -#[derive(Clone, Default)] +#[derive(Default)] pub struct TestRecordedContext { - state: Arc>, + state: Mutex, } #[cfg(feature = "stateful")] impl TestRecordedContext { - pub fn new() -> Self { - TestRecordedContext { + pub fn new() -> Arc { + Arc::new(TestRecordedContext { state: Default::default(), - } + }) } pub fn get_record_calls(&self) -> u64 { diff --git a/components/nimbus/src/tests/stateful/test_nimbus.rs b/components/nimbus/src/tests/stateful/test_nimbus.rs index 1a826cc708..e865b33a6f 100644 --- a/components/nimbus/src/tests/stateful/test_nimbus.rs +++ b/components/nimbus/src/tests/stateful/test_nimbus.rs @@ -1680,7 +1680,7 @@ fn test_recorded_context_recorded() -> Result<()> { app_version: Some("124.0.0".to_string()), ..Default::default() }; - let recorded_context = Arc::new(TestRecordedContext::new()); + let recorded_context = TestRecordedContext::new(); recorded_context.set_context(json!({ "app_version": "125.0.0", "other": "stuff", @@ -1688,7 +1688,7 @@ fn test_recorded_context_recorded() -> Result<()> { let metrics = TestMetrics::new(); let client = NimbusClient::new( app_context.clone(), - Some(recorded_context), + Some(recorded_context.clone()), Default::default(), temp_dir.path(), metrics.clone(), @@ -1707,7 +1707,7 @@ fn test_recorded_context_recorded() -> Result<()> { let active_experiments = client.get_active_experiments()?; assert_eq!(active_experiments.len(), 1); - assert_eq!(client.get_recorded_context().get_record_calls(), 1u64); + assert_eq!(recorded_context.get_record_calls(), 1u64); assert_eq!(metrics.get_submit_targeting_context_calls(), 1u64); Ok(()) @@ -1724,7 +1724,7 @@ fn test_recorded_context_event_queries() -> Result<()> { app_version: Some("124.0.0".to_string()), ..Default::default() }; - let recorded_context = Arc::new(TestRecordedContext::new()); + let recorded_context = TestRecordedContext::new(); recorded_context.set_context(json!({ "app_version": "125.0.0", "other": "stuff", @@ -1735,7 +1735,7 @@ fn test_recorded_context_event_queries() -> Result<()> { )])); let client = NimbusClient::new( app_context, - Some(recorded_context), + Some(recorded_context.clone()), Default::default(), temp_dir.path(), TestMetrics::new(), @@ -1754,16 +1754,13 @@ fn test_recorded_context_event_queries() -> Result<()> { info!( "{}", - serde_json::to_string(&client.get_recorded_context().get_event_queries())? + serde_json::to_string(&recorded_context.get_event_queries())? ); let active_experiments = client.get_active_experiments()?; - assert_eq!( - client.get_recorded_context().get_event_query_values()["TEST_QUERY"], - 0.0 - ); + assert_eq!(recorded_context.get_event_query_values()["TEST_QUERY"], 0.0); assert_eq!(active_experiments.len(), 1); - assert_eq!(client.get_recorded_context().get_record_calls(), 1u64); + assert_eq!(recorded_context.get_record_calls(), 1u64); Ok(()) } @@ -1779,7 +1776,6 @@ fn test_gecko_pref_enrollment() -> Result<()> { app_version: Some("124.0.0".to_string()), ..Default::default() }; - let recorded_context = Arc::new(TestRecordedContext::new()); let pref_state = GeckoPrefState::new("test.pref", None) .with_gecko_value(PrefValue::Null) @@ -1792,7 +1788,7 @@ fn test_gecko_pref_enrollment() -> Result<()> { let client = NimbusClient::new( app_context, - Some(recorded_context), + Some(TestRecordedContext::new()), Default::default(), temp_dir.path(), TestMetrics::new(), @@ -1853,7 +1849,6 @@ fn test_gecko_pref_unenrollment() -> Result<()> { app_version: Some("124.0.0".to_string()), ..Default::default() }; - let recorded_context = Arc::new(TestRecordedContext::new()); let pref_state = GeckoPrefState::new("test.pref", None).with_gecko_value(PrefValue::Null); let handler = TestGeckoPrefHandler::new(create_feature_prop_pref_map(vec![( @@ -1864,7 +1859,7 @@ fn test_gecko_pref_unenrollment() -> Result<()> { let client = NimbusClient::new( app_context, - Some(recorded_context), + Some(TestRecordedContext::new()), Default::default(), temp_dir.path(), TestMetrics::new(), @@ -1979,7 +1974,6 @@ fn test_gecko_pref_unenrollment_reverts() -> Result<()> { app_version: Some("124.0.0".to_string()), ..Default::default() }; - let recorded_context = Arc::new(TestRecordedContext::new()); let pref_state_1 = GeckoPrefState::new("test.pref.1", None).with_gecko_value(PrefValue::Null); let pref_state_2 = GeckoPrefState::new("test.pref.2", None).with_gecko_value(PrefValue::Null); @@ -1990,7 +1984,7 @@ fn test_gecko_pref_unenrollment_reverts() -> Result<()> { let client = NimbusClient::new( app_context, - Some(recorded_context), + Some(TestRecordedContext::new()), Default::default(), temp_dir.path(), TestMetrics::new(), @@ -2118,7 +2112,6 @@ fn register_previous_gecko_pref_states() -> Result<()> { app_version: Some("124.0.0".to_string()), ..Default::default() }; - let recorded_context = Arc::new(TestRecordedContext::new()); let pref_state = GeckoPrefState::new("test.pref", None).with_gecko_value(PrefValue::Null); let handler = TestGeckoPrefHandler::new(create_feature_prop_pref_map(vec![( "test_feature", @@ -2127,7 +2120,7 @@ fn register_previous_gecko_pref_states() -> Result<()> { )])); let client = NimbusClient::new( app_context.clone(), - Some(recorded_context), + Some(TestRecordedContext::new()), Default::default(), temp_dir.path(), metrics.clone(), @@ -2315,7 +2308,6 @@ fn test_add_prev_gecko_pref_states_for_experiment() -> Result<()> { app_version: Some("124.0.0".to_string()), ..Default::default() }; - let recorded_context = Arc::new(TestRecordedContext::new()); let pref_state = GeckoPrefState::new("test.pref", None).with_gecko_value(PrefValue::Null); let handler = TestGeckoPrefHandler::new(create_feature_prop_pref_map(vec![( "test_feature", @@ -2324,7 +2316,7 @@ fn test_add_prev_gecko_pref_states_for_experiment() -> Result<()> { )])); let client = NimbusClient::new( app_context.clone(), - Some(recorded_context), + Some(TestRecordedContext::new()), Default::default(), temp_dir.path(), metrics.clone(), diff --git a/components/nimbus/src/tests/stateful/test_targeting.rs b/components/nimbus/src/tests/stateful/test_targeting.rs index 7f18437b17..14db35aef6 100644 --- a/components/nimbus/src/tests/stateful/test_targeting.rs +++ b/components/nimbus/src/tests/stateful/test_targeting.rs @@ -7,7 +7,8 @@ use std::sync::{Arc, Mutex}; use serde_json::Map; -use crate::stateful::{behavior::EventStore, targeting::RecordedContext}; +use crate::stateful::behavior::EventStore; +use crate::stateful::targeting::{execute_event_queries, validate_event_queries}; use crate::tests::helpers::TestRecordedContext; use crate::{NimbusTargetingHelper, Result}; @@ -31,23 +32,14 @@ fn test_recorded_context_execute_queries() -> Result<()> { let recorded_context = TestRecordedContext::new(); recorded_context.set_event_queries(map.clone()); - let recorded_context: Box = Box::new(recorded_context); + execute_event_queries(&*recorded_context, &targeting_helper)?; - recorded_context.execute_queries(&targeting_helper)?; - - // SAFETY: The cast to TestRecordedContext is safe because the Rust instance is - // guaranteed to be a TestRecordedContext instance. TestRecordedContext is the only - // Rust-implemented version of RecordedContext, and, like this method, is only - // used in tests. - let test_recorded_context = unsafe { - std::mem::transmute::<&&dyn RecordedContext, &&TestRecordedContext>(&&*recorded_context) - }; assert_eq!( - test_recorded_context.get_event_query_values()["TEST_QUERY_SUCCESS"], + recorded_context.get_event_query_values()["TEST_QUERY_SUCCESS"], 1.0 ); assert!( - !test_recorded_context + !recorded_context .get_event_query_values() .contains_key("TEST_QUERY_FAIL_NOT_VALID_QUERY") ); @@ -70,9 +62,7 @@ fn test_recorded_context_validate_queries() -> Result<()> { let recorded_context = TestRecordedContext::new(); recorded_context.set_event_queries(map.clone()); - let recorded_context: Box = Box::new(recorded_context); - - let result = recorded_context.validate_queries(); + let result = validate_event_queries(recorded_context); assert!(result.is_err_and(|e| { assert_eq!(e.to_string(), "Behavior error: EventQueryParseError: \"'event'|eventYolo('Days', 1, 0)\" is not a valid EventQuery".to_string()); true