66use anyhow:: { Result , bail} ;
77use log:: { debug, error, info, warn} ;
88use serde_json:: Value ;
9+ use std:: collections:: HashMap ;
910use std:: path:: Path ;
1011
1112use crate :: ndjson:: { self , SAFE_OUTPUT_FILENAME } ;
1213use crate :: tools:: {
13- CommentOnWorkItemConfig , CommentOnWorkItemResult , CreatePrResult , CreateWikiPageResult , CreateWorkItemResult , ExecutionContext , ExecutionResult ,
14- Executor , UpdateWikiPageResult , UpdateWorkItemConfig , UpdateWorkItemResult ,
14+ CreatePrResult , CreateWikiPageResult , CreateWorkItemResult , CommentOnWorkItemResult ,
15+ ExecutionContext , ExecutionResult , Executor , ToolResult ,
16+ UpdateWikiPageResult , UpdateWorkItemResult ,
1517} ;
1618
1719// Re-export memory types for use by main.rs
@@ -87,15 +89,28 @@ pub async fn execute_safe_outputs(
8789 }
8890 }
8991
90- // Fetch the update-work-item max once; used to skip excess entries without aborting the batch
91- let update_wi_config: UpdateWorkItemConfig = ctx. get_tool_config ( "update-work-item" ) ;
92- let max_update_wi = update_wi_config. max as usize ;
93- let mut update_wi_executed: usize = 0 ;
94-
95- // Fetch the comment-on-work-item max once; same skip-and-continue pattern
96- let comment_wi_config: CommentOnWorkItemConfig = ctx. get_tool_config ( "comment-on-work-item" ) ;
97- let max_comment_wi = comment_wi_config. max as usize ;
98- let mut comment_wi_executed: usize = 0 ;
92+ // Build budget map: tool_name → (executed_count, max_allowed).
93+ // Each tool declares its DEFAULT_MAX via the ToolResult trait; the operator can
94+ // override it with `max` in the front-matter config JSON.
95+ let mut budgets: HashMap < & str , ( usize , usize ) > = HashMap :: new ( ) ;
96+ macro_rules! register_budgets {
97+ ( $( $tool: ty) ,+ $( , ) ?) => {
98+ $( {
99+ let name = <$tool>:: NAME ;
100+ let default = <$tool>:: DEFAULT_MAX ;
101+ let max = resolve_max( ctx, name, default ) ;
102+ budgets. insert( name, ( 0 , max) ) ;
103+ } ) +
104+ } ;
105+ }
106+ register_budgets ! (
107+ CreateWorkItemResult ,
108+ CreatePrResult ,
109+ UpdateWorkItemResult ,
110+ CommentOnWorkItemResult ,
111+ CreateWikiPageResult ,
112+ UpdateWikiPageResult ,
113+ ) ;
99114
100115 let mut results = Vec :: new ( ) ;
101116 for ( i, entry) in entries. iter ( ) . enumerate ( ) {
@@ -107,35 +122,18 @@ pub async fn execute_safe_outputs(
107122 entry_json
108123 ) ;
109124
110- // Enforce update-work-item max : skip excess entries rather than aborting the whole batch.
125+ // Generic budget enforcement : skip excess entries rather than aborting the whole batch.
111126 // Budget is consumed before execution so that failed attempts (target policy rejection,
112127 // network errors) still count — this prevents unbounded retries against a failing endpoint.
113- if entry. get ( "name" ) . and_then ( |n| n. as_str ( ) ) == Some ( "update-work-item" ) {
114- let wi_id = entry
115- . get ( "id" )
116- . and_then ( |v| v. as_u64 ( ) )
117- . map ( |id| format ! ( " (work item #{})" , id) )
118- . unwrap_or_default ( ) ;
119- if let Some ( result) = check_budget ( entries. len ( ) , i, "update-work-item" , & wi_id, update_wi_executed, max_update_wi) {
120- results. push ( result) ;
121- continue ;
122- }
123- update_wi_executed += 1 ;
124- }
125-
126- // Enforce comment-on-work-item max: same skip-and-continue pattern as update-work-item.
127- // Budget is consumed before execution so that failed attempts still count.
128- if entry. get ( "name" ) . and_then ( |n| n. as_str ( ) ) == Some ( "comment-on-work-item" ) {
129- let wi_id = entry
130- . get ( "work_item_id" )
131- . and_then ( |v| v. as_i64 ( ) )
132- . map ( |id| format ! ( " (work item #{})" , id) )
133- . unwrap_or_default ( ) ;
134- if let Some ( result) = check_budget ( entries. len ( ) , i, "comment-on-work-item" , & wi_id, comment_wi_executed, max_comment_wi) {
135- results. push ( result) ;
136- continue ;
128+ if let Some ( tool_name) = entry. get ( "name" ) . and_then ( |n| n. as_str ( ) ) {
129+ if let Some ( ( executed, max) ) = budgets. get_mut ( tool_name) {
130+ let context_id = extract_entry_context ( entry) ;
131+ if let Some ( result) = check_budget ( entries. len ( ) , i, tool_name, & context_id, * executed, * max) {
132+ results. push ( result) ;
133+ continue ;
134+ }
135+ * executed += 1 ;
137136 }
138- comment_wi_executed += 1 ;
139137 }
140138
141139 match execute_safe_output ( entry, ctx) . await {
@@ -284,6 +282,35 @@ pub async fn execute_safe_output(
284282 Ok ( ( tool_name. to_string ( ) , result) )
285283}
286284
285+ /// Read the operator's `max` override from the tool's config JSON, falling back to the
286+ /// tool's `DEFAULT_MAX` (declared on the `ToolResult` trait) when not configured.
287+ fn resolve_max ( ctx : & ExecutionContext , tool_name : & str , default_max : u32 ) -> usize {
288+ ctx. tool_configs
289+ . get ( tool_name)
290+ . and_then ( |v| v. get ( "max" ) )
291+ . and_then ( |v| v. as_u64 ( ) )
292+ . map ( |v| v as usize )
293+ . unwrap_or ( default_max as usize )
294+ }
295+
296+ /// Extract a human-readable context identifier from a safe-output entry for log messages.
297+ fn extract_entry_context ( entry : & Value ) -> String {
298+ if let Some ( id) = entry. get ( "id" ) . and_then ( |v| v. as_u64 ( ) ) {
299+ return format ! ( " (work item #{})" , id) ;
300+ }
301+ if let Some ( id) = entry. get ( "work_item_id" ) . and_then ( |v| v. as_i64 ( ) ) {
302+ return format ! ( " (work item #{})" , id) ;
303+ }
304+ if let Some ( title) = entry. get ( "title" ) . and_then ( |v| v. as_str ( ) ) {
305+ let truncated = if title. len ( ) > 40 { & title[ ..40 ] } else { title } ;
306+ return format ! ( " (\" {}\" )" , truncated) ;
307+ }
308+ if let Some ( path) = entry. get ( "path" ) . and_then ( |v| v. as_str ( ) ) {
309+ return format ! ( " (path: {})" , path) ;
310+ }
311+ String :: new ( )
312+ }
313+
287314/// Returns `Some(result)` when the budget for `tool_name` is exhausted so the caller can push the
288315/// result and `continue` to the next entry. Returns `None` when a budget slot is still available
289316/// and the caller should proceed with execution.
@@ -735,4 +762,171 @@ mod tests {
735762 let r = result. unwrap ( ) ;
736763 assert ! ( r. message. contains( "(work item #42)" ) ) ;
737764 }
765+
766+ // --- extract_entry_context unit tests ---
767+
768+ #[ test]
769+ fn test_extract_entry_context_with_id ( ) {
770+ let entry = serde_json:: json!( { "name" : "update-work-item" , "id" : 42 } ) ;
771+ assert_eq ! ( extract_entry_context( & entry) , " (work item #42)" ) ;
772+ }
773+
774+ #[ test]
775+ fn test_extract_entry_context_with_work_item_id ( ) {
776+ let entry = serde_json:: json!( { "name" : "comment-on-work-item" , "work_item_id" : 99 } ) ;
777+ assert_eq ! ( extract_entry_context( & entry) , " (work item #99)" ) ;
778+ }
779+
780+ #[ test]
781+ fn test_extract_entry_context_with_title ( ) {
782+ let entry = serde_json:: json!( { "name" : "create-work-item" , "title" : "Fix the bug" } ) ;
783+ assert_eq ! ( extract_entry_context( & entry) , " (\" Fix the bug\" )" ) ;
784+ }
785+
786+ #[ test]
787+ fn test_extract_entry_context_with_path ( ) {
788+ let entry = serde_json:: json!( { "name" : "create-wiki-page" , "path" : "/Overview/NewPage" } ) ;
789+ assert_eq ! ( extract_entry_context( & entry) , " (path: /Overview/NewPage)" ) ;
790+ }
791+
792+ #[ test]
793+ fn test_extract_entry_context_empty ( ) {
794+ let entry = serde_json:: json!( { "name" : "noop" } ) ;
795+ assert_eq ! ( extract_entry_context( & entry) , "" ) ;
796+ }
797+
798+ // --- resolve_max and DEFAULT_MAX unit tests ---
799+
800+ #[ test]
801+ fn test_default_max_trait_constant ( ) {
802+ assert_eq ! ( CreateWorkItemResult :: DEFAULT_MAX , 1 ) ;
803+ assert_eq ! ( CreatePrResult :: DEFAULT_MAX , 1 ) ;
804+ assert_eq ! ( UpdateWorkItemResult :: DEFAULT_MAX , 1 ) ;
805+ assert_eq ! ( CommentOnWorkItemResult :: DEFAULT_MAX , 1 ) ;
806+ assert_eq ! ( CreateWikiPageResult :: DEFAULT_MAX , 1 ) ;
807+ assert_eq ! ( UpdateWikiPageResult :: DEFAULT_MAX , 1 ) ;
808+ }
809+
810+ #[ test]
811+ fn test_resolve_max_uses_config_override ( ) {
812+ let mut tool_configs = HashMap :: new ( ) ;
813+ tool_configs. insert ( "test-tool" . to_string ( ) , serde_json:: json!( { "max" : 5 } ) ) ;
814+ let ctx = ExecutionContext {
815+ tool_configs,
816+ ..ExecutionContext :: default ( )
817+ } ;
818+ assert_eq ! ( resolve_max( & ctx, "test-tool" , 1 ) , 5 ) ;
819+ }
820+
821+ #[ test]
822+ fn test_resolve_max_falls_back_to_default ( ) {
823+ let ctx = ExecutionContext :: default ( ) ;
824+ assert_eq ! ( resolve_max( & ctx, "nonexistent-tool" , 3 ) , 3 ) ;
825+ }
826+
827+ #[ test]
828+ fn test_resolve_max_uses_default_when_no_max_in_config ( ) {
829+ let mut tool_configs = HashMap :: new ( ) ;
830+ tool_configs. insert ( "test-tool" . to_string ( ) , serde_json:: json!( { "other" : true } ) ) ;
831+ let ctx = ExecutionContext {
832+ tool_configs,
833+ ..ExecutionContext :: default ( )
834+ } ;
835+ assert_eq ! ( resolve_max( & ctx, "test-tool" , 7 ) , 7 ) ;
836+ }
837+
838+ // --- Generic budget enforcement for all tool types ---
839+
840+ #[ tokio:: test]
841+ async fn test_budget_enforcement_create_work_item_max ( ) {
842+ let temp_dir = tempfile:: tempdir ( ) . unwrap ( ) ;
843+ let safe_output_path = temp_dir. path ( ) . join ( SAFE_OUTPUT_FILENAME ) ;
844+
845+ // Write 3 create-work-item entries + 1 noop; max set to 2
846+ let ndjson = r#"{"name":"create-work-item","title":"First item","description":"A description that is definitely longer than thirty characters."}
847+ {"name":"create-work-item","title":"Second item","description":"A description that is definitely longer than thirty characters."}
848+ {"name":"create-work-item","title":"Third item","description":"A description that is definitely longer than thirty characters."}
849+ {"name":"noop","context":"still runs"}
850+ "# ;
851+ tokio:: fs:: write ( & safe_output_path, ndjson) . await . unwrap ( ) ;
852+
853+ let mut tool_configs = HashMap :: new ( ) ;
854+ tool_configs. insert ( "create-work-item" . to_string ( ) , serde_json:: json!( { "max" : 2 } ) ) ;
855+
856+ let ctx = ExecutionContext {
857+ ado_org_url : Some ( "https://dev.azure.com/org" . to_string ( ) ) ,
858+ ado_organization : Some ( "org" . to_string ( ) ) ,
859+ ado_project : Some ( "Proj" . to_string ( ) ) ,
860+ access_token : Some ( "token" . to_string ( ) ) ,
861+ working_directory : PathBuf :: from ( "." ) ,
862+ source_directory : PathBuf :: from ( "." ) ,
863+ tool_configs,
864+ repository_id : None ,
865+ repository_name : None ,
866+ allowed_repositories : HashMap :: new ( ) ,
867+ } ;
868+
869+ let results = execute_safe_outputs ( temp_dir. path ( ) , & ctx) . await ;
870+ assert ! ( results. is_ok( ) , "Batch should not abort when max is exceeded" ) ;
871+ let results = results. unwrap ( ) ;
872+ assert_eq ! ( results. len( ) , 4 , "Expected 4 results" ) ;
873+
874+ // Only 1 should be skipped (max=2 allows first 2, third is skipped)
875+ let skipped: Vec < _ > = results
876+ . iter ( )
877+ . filter ( |r| r. message . contains ( "maximum create-work-item count" ) )
878+ . collect ( ) ;
879+ assert_eq ! ( skipped. len( ) , 1 , "Expected 1 skipped entry, got: {:?}" , skipped) ;
880+
881+ // noop still runs
882+ assert ! ( results[ 3 ] . success, "noop should still succeed" ) ;
883+ }
884+
885+ #[ tokio:: test]
886+ async fn test_budget_enforcement_mixed_tools_independent_budgets ( ) {
887+ let temp_dir = tempfile:: tempdir ( ) . unwrap ( ) ;
888+ let safe_output_path = temp_dir. path ( ) . join ( SAFE_OUTPUT_FILENAME ) ;
889+
890+ // Mix of tools: each has max=1 (default), so only the first of each type should pass budget
891+ let ndjson = r#"{"name":"create-work-item","title":"WI 1","description":"A description that is definitely longer than thirty characters."}
892+ {"name":"create-work-item","title":"WI 2","description":"A description that is definitely longer than thirty characters."}
893+ {"name":"create-wiki-page","path":"/Page1","content":"Some valid wiki content here."}
894+ {"name":"create-wiki-page","path":"/Page2","content":"Some valid wiki content here."}
895+ {"name":"noop","context":"always runs"}
896+ "# ;
897+ tokio:: fs:: write ( & safe_output_path, ndjson) . await . unwrap ( ) ;
898+
899+ let ctx = ExecutionContext {
900+ ado_org_url : Some ( "https://dev.azure.com/org" . to_string ( ) ) ,
901+ ado_organization : Some ( "org" . to_string ( ) ) ,
902+ ado_project : Some ( "Proj" . to_string ( ) ) ,
903+ access_token : Some ( "token" . to_string ( ) ) ,
904+ working_directory : PathBuf :: from ( "." ) ,
905+ source_directory : PathBuf :: from ( "." ) ,
906+ tool_configs : HashMap :: new ( ) , // defaults: max=1 for all
907+ repository_id : None ,
908+ repository_name : None ,
909+ allowed_repositories : HashMap :: new ( ) ,
910+ } ;
911+
912+ let results = execute_safe_outputs ( temp_dir. path ( ) , & ctx) . await . unwrap ( ) ;
913+ assert_eq ! ( results. len( ) , 5 ) ;
914+
915+ // Second create-work-item should be skipped
916+ let cwi_skipped: Vec < _ > = results
917+ . iter ( )
918+ . filter ( |r| r. message . contains ( "maximum create-work-item count" ) )
919+ . collect ( ) ;
920+ assert_eq ! ( cwi_skipped. len( ) , 1 , "Expected 1 skipped create-work-item" ) ;
921+
922+ // Second create-wiki-page should be skipped
923+ let cwp_skipped: Vec < _ > = results
924+ . iter ( )
925+ . filter ( |r| r. message . contains ( "maximum create-wiki-page count" ) )
926+ . collect ( ) ;
927+ assert_eq ! ( cwp_skipped. len( ) , 1 , "Expected 1 skipped create-wiki-page" ) ;
928+
929+ // noop always runs
930+ assert ! ( results[ 4 ] . success, "noop should still succeed" ) ;
931+ }
738932}
0 commit comments