2323logger = get_logger (__name__ )
2424
2525
26+ class _RetryState :
27+ """Mutable state for the refinement retry loop."""
28+
29+ __slots__ = ("last_error_type" , "messages_history" , "min_prompt_level" , "same_error_count" )
30+
31+ def __init__ (self ) -> None :
32+ self .last_error_type : str | None = None
33+ self .same_error_count = 0
34+ self .messages_history : list [dict [str , str ]] = []
35+ self .min_prompt_level : int = 0
36+
37+
2638class AISuggestionFailedError (Exception ):
2739 pass
2840
@@ -107,44 +119,37 @@ def _resolve_use_compact(self, use_compact: bool | None) -> bool:
107119 def _try_prompt_levels (
108120 self ,
109121 schema_ctx : Any ,
110- messages_history : list [dict [str , str ]],
111- min_prompt_level : int ,
122+ state : _RetryState ,
112123 use_compact : bool ,
113124 call_fn : Callable [[list [dict [str , str ]]], dict [str , Any ] | None ],
114- ) -> tuple [dict [str , Any ] | None , ErrorSummary | None , int ]:
125+ ) -> tuple [dict [str , Any ] | None , ErrorSummary | None ]:
115126 """Try LLM call across prompt levels with context overflow fallback.
116127
117128 Args:
118129 schema_ctx: Schema context from the orchestrator.
119- messages_history: Accumulated conversation messages for multi-turn refinement.
120- min_prompt_level: Minimum prompt level index to start from (skips levels
121- that already failed with context overflow).
130+ state: Mutable retry state (messages_history, min_prompt_level updated in-place).
122131 use_compact: Whether to force ultra-compact mode.
123132 call_fn: Function to call LLM (non-streaming or streaming variant).
124133
125134 Returns:
126- (config_dict, error, new_min_prompt_level )
135+ (config_dict or None , error or None )
127136 """
128137 prompt_levels = self ._get_prompt_levels (use_compact )
129138 for level_idx , (compact , ultra ) in enumerate (prompt_levels ):
130- if level_idx < min_prompt_level :
139+ if level_idx < state . min_prompt_level :
131140 continue
132141 initial_messages = self ._analyzer .build_initial_messages (schema_ctx , compact = compact , ultra_compact = ultra )
133- messages = initial_messages + messages_history
142+ messages = initial_messages + state . messages_history
134143 try :
135144 config_dict = call_fn (messages )
136145 if not config_dict :
137- return (
138- None ,
139- ErrorSummary (
140- error_type = "empty_config" ,
141- message = "LLM returned empty result" ,
142- column = None ,
143- retryable = True ,
144- ),
145- min_prompt_level ,
146+ return None , ErrorSummary (
147+ error_type = "empty_config" ,
148+ message = "LLM returned empty result" ,
149+ column = None ,
150+ retryable = True ,
146151 )
147- return config_dict , None , min_prompt_level
152+ return config_dict , None
148153 except (ValueError , RuntimeError , OSError ) as e :
149154 err_lower = str (e ).lower ()
150155 if "context" in err_lower and "exceed" in err_lower and not ultra :
@@ -153,10 +158,10 @@ def _try_prompt_levels(
153158 compact = compact ,
154159 ultra_compact = ultra ,
155160 )
156- min_prompt_level = level_idx + 1
161+ state . min_prompt_level = level_idx + 1
157162 continue
158- return None , summarize_error (e ), min_prompt_level
159- return None , None , min_prompt_level
163+ return None , summarize_error (e )
164+ return None , None
160165
161166 def _handle_validation_result (
162167 self ,
@@ -166,15 +171,13 @@ def _handle_validation_result(
166171 config_dict : dict [str , Any ],
167172 attempt : int ,
168173 max_retries : int ,
169- messages_history : list [dict [str , str ]],
170- last_error_type : str | None ,
171- same_error_count : int ,
174+ state : _RetryState ,
172175 on_progress : Callable [[str , dict [str , Any ]], None ] | None = None ,
173- ) -> tuple [ dict [str , Any ] | None , str | None , int ] :
176+ ) -> dict [str , Any ] | None :
174177 """Handle validation result: return config on success, or update retry state.
175178
176179 Returns:
177- ( config_dict if valid else None, updated last_error_type, updated same_error_count)
180+ config_dict if valid, None if validation failed (state updated for next retry).
178181 """
179182 val_error = self ._validate_config (orch , table_name , config_dict )
180183
@@ -183,16 +186,18 @@ def _handle_validation_result(
183186 self ._cache_successful_config (table_name , config_dict , schema_hash )
184187 if on_progress :
185188 on_progress ("done" , {"tokens" : 0 , "model" : "validated" })
186- return config_dict , last_error_type , same_error_count
189+ return config_dict
187190
188- last_error_type , same_error_count = self ._check_repeated_error (val_error , last_error_type , same_error_count )
191+ state .last_error_type , state .same_error_count = self ._check_repeated_error (
192+ val_error , state .last_error_type , state .same_error_count
193+ )
189194 self ._handle_validation_failure (val_error , attempt , max_retries , table_name )
190195
191- messages_history .append ({"role" : "assistant" , "content" : json .dumps (config_dict , ensure_ascii = False )})
192- messages_history .append (
196+ state . messages_history .append ({"role" : "assistant" , "content" : json .dumps (config_dict , ensure_ascii = False )})
197+ state . messages_history .append (
193198 {"role" : "user" , "content" : self ._build_refinement_prompt (val_error , attempt , max_retries )}
194199 )
195- return None , last_error_type , same_error_count
200+ return None
196201
197202 def _refinement_loop (
198203 self ,
@@ -228,24 +233,18 @@ def _refinement_loop(
228233 return cached
229234
230235 resolved_compact = self ._resolve_use_compact (use_compact )
231-
232- last_error_type : str | None = None
233- same_error_count = 0
234- messages_history : list [dict [str , str ]] = []
235- min_prompt_level : int = 0
236+ state = _RetryState ()
236237
237238 for attempt in range (max_retries + 1 ):
238239 if on_progress :
239240 on_progress ("refining" , {"attempt" : attempt , "max_retries" : max_retries })
240241
241- config_dict , error , min_prompt_level = self ._try_prompt_levels (
242- schema_ctx , messages_history , min_prompt_level , resolved_compact , call_fn
243- )
242+ config_dict , error = self ._try_prompt_levels (schema_ctx , state , resolved_compact , call_fn )
244243
245244 if config_dict is None :
246245 if error is not None :
247- last_error_type , same_error_count = self ._check_repeated_error (
248- error , last_error_type , same_error_count
246+ state . last_error_type , state . same_error_count = self ._check_repeated_error (
247+ error , state . last_error_type , state . same_error_count
249248 )
250249 self ._handle_generation_failure (error , attempt , max_retries )
251250 continue
@@ -254,16 +253,14 @@ def _refinement_loop(
254253 if on_progress :
255254 on_progress ("validating" , {"attempt" : attempt })
256255
257- result , last_error_type , same_error_count = self ._handle_validation_result (
256+ result = self ._handle_validation_result (
258257 orch ,
259258 table_name ,
260259 schema_hash ,
261260 config_dict ,
262261 attempt ,
263262 max_retries ,
264- messages_history ,
265- last_error_type ,
266- same_error_count ,
263+ state ,
267264 on_progress ,
268265 )
269266 if result is not None :
0 commit comments