55
66use anyhow:: { Result , bail} ;
77use log:: { debug, error, info, warn} ;
8+ use serde:: Deserialize ;
89use serde_json:: Value ;
10+ use std:: collections:: HashMap ;
911use std:: path:: Path ;
1012
1113use crate :: ndjson:: { self , SAFE_OUTPUT_FILENAME } ;
1214use crate :: tools:: {
13- CommentOnWorkItemConfig , CommentOnWorkItemResult , CreatePrResult , CreateWikiPageResult , CreateWorkItemResult , ExecutionContext , ExecutionResult ,
14- Executor , UpdateWikiPageResult , UpdateWorkItemConfig , UpdateWorkItemResult ,
15+ CreatePrResult , CreateWikiPageResult , CreateWorkItemResult , CommentOnWorkItemResult ,
16+ ExecutionContext , ExecutionResult , Executor , UpdateWikiPageResult , UpdateWorkItemResult ,
1517} ;
1618
1719// Re-export memory types for use by main.rs
@@ -87,15 +89,22 @@ pub async fn execute_safe_outputs(
8789 }
8890 }
8991
90- // Fetch the update-work-item max once; used to skip excess entries without aborting the batch
91- let update_wi_config: UpdateWorkItemConfig = ctx. get_tool_config ( "update-work-item" ) ;
92- let max_update_wi = update_wi_config. max as usize ;
93- let mut update_wi_executed: usize = 0 ;
94-
95- // Fetch the comment-on-work-item max once; same skip-and-continue pattern
96- let comment_wi_config: CommentOnWorkItemConfig = ctx. get_tool_config ( "comment-on-work-item" ) ;
97- let max_comment_wi = comment_wi_config. max as usize ;
98- let mut comment_wi_executed: usize = 0 ;
92+ // Build budget map: tool_name → (executed_count, max_allowed).
93+ // All safe-output tools that perform side-effects are budgeted. The `max` field
94+ // is extracted generically from each tool's config JSON (defaulting to 1).
95+ let budgeted_tools = [
96+ "create-work-item" ,
97+ "create-pull-request" ,
98+ "update-work-item" ,
99+ "comment-on-work-item" ,
100+ "create-wiki-page" ,
101+ "update-wiki-page" ,
102+ ] ;
103+ let mut budgets: HashMap < & str , ( usize , usize ) > = HashMap :: new ( ) ;
104+ for tool_name in & budgeted_tools {
105+ let max_config: MaxConfig = ctx. get_tool_config ( tool_name) ;
106+ budgets. insert ( tool_name, ( 0 , max_config. max as usize ) ) ;
107+ }
99108
100109 let mut results = Vec :: new ( ) ;
101110 for ( i, entry) in entries. iter ( ) . enumerate ( ) {
@@ -107,35 +116,18 @@ pub async fn execute_safe_outputs(
107116 entry_json
108117 ) ;
109118
110- // Enforce update-work-item max : skip excess entries rather than aborting the whole batch.
119+ // Generic budget enforcement : skip excess entries rather than aborting the whole batch.
111120 // Budget is consumed before execution so that failed attempts (target policy rejection,
112121 // network errors) still count — this prevents unbounded retries against a failing endpoint.
113- if entry. get ( "name" ) . and_then ( |n| n. as_str ( ) ) == Some ( "update-work-item" ) {
114- let wi_id = entry
115- . get ( "id" )
116- . and_then ( |v| v. as_u64 ( ) )
117- . map ( |id| format ! ( " (work item #{})" , id) )
118- . unwrap_or_default ( ) ;
119- if let Some ( result) = check_budget ( entries. len ( ) , i, "update-work-item" , & wi_id, update_wi_executed, max_update_wi) {
120- results. push ( result) ;
121- continue ;
122- }
123- update_wi_executed += 1 ;
124- }
125-
126- // Enforce comment-on-work-item max: same skip-and-continue pattern as update-work-item.
127- // Budget is consumed before execution so that failed attempts still count.
128- if entry. get ( "name" ) . and_then ( |n| n. as_str ( ) ) == Some ( "comment-on-work-item" ) {
129- let wi_id = entry
130- . get ( "work_item_id" )
131- . and_then ( |v| v. as_i64 ( ) )
132- . map ( |id| format ! ( " (work item #{})" , id) )
133- . unwrap_or_default ( ) ;
134- if let Some ( result) = check_budget ( entries. len ( ) , i, "comment-on-work-item" , & wi_id, comment_wi_executed, max_comment_wi) {
135- results. push ( result) ;
136- continue ;
122+ if let Some ( tool_name) = entry. get ( "name" ) . and_then ( |n| n. as_str ( ) ) {
123+ if let Some ( ( executed, max) ) = budgets. get_mut ( tool_name) {
124+ let context_id = extract_entry_context ( entry) ;
125+ if let Some ( result) = check_budget ( entries. len ( ) , i, tool_name, & context_id, * executed, * max) {
126+ results. push ( result) ;
127+ continue ;
128+ }
129+ * executed += 1 ;
137130 }
138- comment_wi_executed += 1 ;
139131 }
140132
141133 match execute_safe_output ( entry, ctx) . await {
@@ -284,6 +276,42 @@ pub async fn execute_safe_output(
284276 Ok ( ( tool_name. to_string ( ) , result) )
285277}
286278
279+ /// Helper struct for extracting the `max` field from any tool's config JSON.
280+ /// All safe-output tool configs use `max` with a default of 1.
281+ #[ derive( Deserialize ) ]
282+ struct MaxConfig {
283+ #[ serde( default = "default_max" ) ]
284+ max : u32 ,
285+ }
286+
287+ fn default_max ( ) -> u32 {
288+ 1
289+ }
290+
291+ impl Default for MaxConfig {
292+ fn default ( ) -> Self {
293+ Self { max : default_max ( ) }
294+ }
295+ }
296+
297+ /// Extract a human-readable context identifier from a safe-output entry for log messages.
298+ fn extract_entry_context ( entry : & Value ) -> String {
299+ if let Some ( id) = entry. get ( "id" ) . and_then ( |v| v. as_u64 ( ) ) {
300+ return format ! ( " (work item #{})" , id) ;
301+ }
302+ if let Some ( id) = entry. get ( "work_item_id" ) . and_then ( |v| v. as_i64 ( ) ) {
303+ return format ! ( " (work item #{})" , id) ;
304+ }
305+ if let Some ( title) = entry. get ( "title" ) . and_then ( |v| v. as_str ( ) ) {
306+ let truncated = if title. len ( ) > 40 { & title[ ..40 ] } else { title } ;
307+ return format ! ( " (\" {}\" )" , truncated) ;
308+ }
309+ if let Some ( path) = entry. get ( "path" ) . and_then ( |v| v. as_str ( ) ) {
310+ return format ! ( " (path: {})" , path) ;
311+ }
312+ String :: new ( )
313+ }
314+
287315/// Returns `Some(result)` when the budget for `tool_name` is exhausted so the caller can push the
288316/// result and `continue` to the next entry. Returns `None` when a budget slot is still available
289317/// and the caller should proceed with execution.
@@ -735,4 +763,153 @@ mod tests {
735763 let r = result. unwrap ( ) ;
736764 assert ! ( r. message. contains( "(work item #42)" ) ) ;
737765 }
766+
767+ // --- extract_entry_context unit tests ---
768+
769+ #[ test]
770+ fn test_extract_entry_context_with_id ( ) {
771+ let entry = serde_json:: json!( { "name" : "update-work-item" , "id" : 42 } ) ;
772+ assert_eq ! ( extract_entry_context( & entry) , " (work item #42)" ) ;
773+ }
774+
775+ #[ test]
776+ fn test_extract_entry_context_with_work_item_id ( ) {
777+ let entry = serde_json:: json!( { "name" : "comment-on-work-item" , "work_item_id" : 99 } ) ;
778+ assert_eq ! ( extract_entry_context( & entry) , " (work item #99)" ) ;
779+ }
780+
781+ #[ test]
782+ fn test_extract_entry_context_with_title ( ) {
783+ let entry = serde_json:: json!( { "name" : "create-work-item" , "title" : "Fix the bug" } ) ;
784+ assert_eq ! ( extract_entry_context( & entry) , " (\" Fix the bug\" )" ) ;
785+ }
786+
787+ #[ test]
788+ fn test_extract_entry_context_with_path ( ) {
789+ let entry = serde_json:: json!( { "name" : "create-wiki-page" , "path" : "/Overview/NewPage" } ) ;
790+ assert_eq ! ( extract_entry_context( & entry) , " (path: /Overview/NewPage)" ) ;
791+ }
792+
793+ #[ test]
794+ fn test_extract_entry_context_empty ( ) {
795+ let entry = serde_json:: json!( { "name" : "noop" } ) ;
796+ assert_eq ! ( extract_entry_context( & entry) , "" ) ;
797+ }
798+
799+ // --- MaxConfig unit tests ---
800+
801+ #[ test]
802+ fn test_max_config_default ( ) {
803+ let config = MaxConfig :: default ( ) ;
804+ assert_eq ! ( config. max, 1 ) ;
805+ }
806+
807+ #[ test]
808+ fn test_max_config_from_json_with_max ( ) {
809+ let json = serde_json:: json!( { "max" : 5 , "other_field" : true } ) ;
810+ let config: MaxConfig = serde_json:: from_value ( json) . unwrap ( ) ;
811+ assert_eq ! ( config. max, 5 ) ;
812+ }
813+
814+ #[ test]
815+ fn test_max_config_from_json_without_max ( ) {
816+ let json = serde_json:: json!( { "other_field" : true } ) ;
817+ let config: MaxConfig = serde_json:: from_value ( json) . unwrap ( ) ;
818+ assert_eq ! ( config. max, 1 ) ;
819+ }
820+
821+ // --- Generic budget enforcement for all tool types ---
822+
823+ #[ tokio:: test]
824+ async fn test_budget_enforcement_create_work_item_max ( ) {
825+ let temp_dir = tempfile:: tempdir ( ) . unwrap ( ) ;
826+ let safe_output_path = temp_dir. path ( ) . join ( SAFE_OUTPUT_FILENAME ) ;
827+
828+ // Write 3 create-work-item entries + 1 noop; max set to 2
829+ let ndjson = r#"{"name":"create-work-item","title":"First item","description":"A description that is definitely longer than thirty characters."}
830+ {"name":"create-work-item","title":"Second item","description":"A description that is definitely longer than thirty characters."}
831+ {"name":"create-work-item","title":"Third item","description":"A description that is definitely longer than thirty characters."}
832+ {"name":"noop","context":"still runs"}
833+ "# ;
834+ tokio:: fs:: write ( & safe_output_path, ndjson) . await . unwrap ( ) ;
835+
836+ let mut tool_configs = HashMap :: new ( ) ;
837+ tool_configs. insert ( "create-work-item" . to_string ( ) , serde_json:: json!( { "max" : 2 } ) ) ;
838+
839+ let ctx = ExecutionContext {
840+ ado_org_url : Some ( "https://dev.azure.com/org" . to_string ( ) ) ,
841+ ado_organization : Some ( "org" . to_string ( ) ) ,
842+ ado_project : Some ( "Proj" . to_string ( ) ) ,
843+ access_token : Some ( "token" . to_string ( ) ) ,
844+ working_directory : PathBuf :: from ( "." ) ,
845+ source_directory : PathBuf :: from ( "." ) ,
846+ tool_configs,
847+ repository_id : None ,
848+ repository_name : None ,
849+ allowed_repositories : HashMap :: new ( ) ,
850+ } ;
851+
852+ let results = execute_safe_outputs ( temp_dir. path ( ) , & ctx) . await ;
853+ assert ! ( results. is_ok( ) , "Batch should not abort when max is exceeded" ) ;
854+ let results = results. unwrap ( ) ;
855+ assert_eq ! ( results. len( ) , 4 , "Expected 4 results" ) ;
856+
857+ // Only 1 should be skipped (max=2 allows first 2, third is skipped)
858+ let skipped: Vec < _ > = results
859+ . iter ( )
860+ . filter ( |r| r. message . contains ( "maximum create-work-item count" ) )
861+ . collect ( ) ;
862+ assert_eq ! ( skipped. len( ) , 1 , "Expected 1 skipped entry, got: {:?}" , skipped) ;
863+
864+ // noop still runs
865+ assert ! ( results[ 3 ] . success, "noop should still succeed" ) ;
866+ }
867+
868+ #[ tokio:: test]
869+ async fn test_budget_enforcement_mixed_tools_independent_budgets ( ) {
870+ let temp_dir = tempfile:: tempdir ( ) . unwrap ( ) ;
871+ let safe_output_path = temp_dir. path ( ) . join ( SAFE_OUTPUT_FILENAME ) ;
872+
873+ // Mix of tools: each has max=1 (default), so only the first of each type should pass budget
874+ let ndjson = r#"{"name":"create-work-item","title":"WI 1","description":"A description that is definitely longer than thirty characters."}
875+ {"name":"create-work-item","title":"WI 2","description":"A description that is definitely longer than thirty characters."}
876+ {"name":"create-wiki-page","path":"/Page1","content":"Some valid wiki content here."}
877+ {"name":"create-wiki-page","path":"/Page2","content":"Some valid wiki content here."}
878+ {"name":"noop","context":"always runs"}
879+ "# ;
880+ tokio:: fs:: write ( & safe_output_path, ndjson) . await . unwrap ( ) ;
881+
882+ let ctx = ExecutionContext {
883+ ado_org_url : Some ( "https://dev.azure.com/org" . to_string ( ) ) ,
884+ ado_organization : Some ( "org" . to_string ( ) ) ,
885+ ado_project : Some ( "Proj" . to_string ( ) ) ,
886+ access_token : Some ( "token" . to_string ( ) ) ,
887+ working_directory : PathBuf :: from ( "." ) ,
888+ source_directory : PathBuf :: from ( "." ) ,
889+ tool_configs : HashMap :: new ( ) , // defaults: max=1 for all
890+ repository_id : None ,
891+ repository_name : None ,
892+ allowed_repositories : HashMap :: new ( ) ,
893+ } ;
894+
895+ let results = execute_safe_outputs ( temp_dir. path ( ) , & ctx) . await . unwrap ( ) ;
896+ assert_eq ! ( results. len( ) , 5 ) ;
897+
898+ // Second create-work-item should be skipped
899+ let cwi_skipped: Vec < _ > = results
900+ . iter ( )
901+ . filter ( |r| r. message . contains ( "maximum create-work-item count" ) )
902+ . collect ( ) ;
903+ assert_eq ! ( cwi_skipped. len( ) , 1 , "Expected 1 skipped create-work-item" ) ;
904+
905+ // Second create-wiki-page should be skipped
906+ let cwp_skipped: Vec < _ > = results
907+ . iter ( )
908+ . filter ( |r| r. message . contains ( "maximum create-wiki-page count" ) )
909+ . collect ( ) ;
910+ assert_eq ! ( cwp_skipped. len( ) , 1 , "Expected 1 skipped create-wiki-page" ) ;
911+
912+ // noop always runs
913+ assert ! ( results[ 4 ] . success, "noop should still succeed" ) ;
914+ }
738915}
0 commit comments