@@ -76,7 +76,8 @@ def __init__(
7676 self .temperature = temperature
7777 self .top_p = top_p
7878 self .repetition_penalty = repetition_penalty
79- self .generation_batch_size = generation_batch_size
79+ # Default to batch_size when not provided to avoid None in range() calls
80+ self .generation_batch_size = generation_batch_size or batch_size
8081 # self.max_context_items = max_context_items
8182 self .max_input_length = max_input_length
8283 self .extraction_method = extraction_method
@@ -295,8 +296,9 @@ def _repair_generation(self, original_text: str, row: Dict[str, Any], topic: str
295296 return repair_prompt
296297
297298 def _generate_texts (self , prompts : List [str ]) -> List [str ]:
299+ # Route to the correct backend and return its outputs
298300 if self .is_chat_model :
299- self ._generate_texts_chat_llm (prompts = prompts )
301+ return self ._generate_texts_chat_llm (prompts = prompts )
300302 return self ._generate_texts_causal_llm (prompts = prompts )
301303
302304 def _generate_texts_causal_llm (self , prompts : List [str ]) -> List [str ]:
@@ -425,7 +427,7 @@ def generate_documents(self, pseudo_sentences: List[PseudoSentence], topic: str,
425427 raise ValueError ("Generated response is not valid JSON." )
426428 validated = self ._validate_document (response_payload , row = row , topic = topic )
427429 generated_docs .append (
428- Document (id = row ["id" ], title = validated .title , text = validated .fluent_passage_text , row = row ))
430+ Document (id = row ["id" ], title = validated .title , text = validated .fluent_passage_text ))
429431 print (">>> Generation was sucessful for doc: " , row ["id" ])
430432 # except Exception:
431433 # try:
0 commit comments