Skip to content

Commit 06b4981

Browse files
better callback
1 parent a29a422 commit 06b4981

1 file changed

Lines changed: 40 additions & 35 deletions

File tree

WDoc/utils/llm.py

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -169,77 +169,82 @@ def on_llm_start(
169169
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
170170
) -> Any:
171171
"""Run when LLM starts running."""
172-
if self.verbose:
173-
print("Callback method: on_llm_start")
174-
print(serialized)
175-
print(prompts)
176-
print(kwargs)
177-
print("Callback method end: on_llm_start")
178172
self.methods_called.append("on_llm_start")
173+
if self.verbose:
174+
yel("Callback method: on_llm_start")
175+
yel(serialized)
176+
yel(prompts)
177+
yel(kwargs)
178+
yel("Callback method end: on_llm_start")
179179
self._check_methods_called()
180180

181181
def on_chat_model_start(
182182
self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any
183183
) -> Any:
184184
"""Run when Chat Model starts running."""
185-
if self.verbose:
186-
print("Callback method: on_chat_model_start")
187-
print(serialized)
188-
print(messages)
189-
print(kwargs)
190-
print("Callback method end: on_chat_model_start")
191185
self.methods_called.append("on_chat_model_start")
186+
if self.verbose:
187+
yel("Callback method: on_chat_model_start")
188+
yel(serialized)
189+
yel(messages)
190+
yel(kwargs)
191+
yel("Callback method end: on_chat_model_start")
192192
self._check_methods_called()
193193

194194
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
195195
"""Run when LLM ends running."""
196+
self.methods_called.append("on_llm_end")
196197
if self.verbose:
197-
print("Callback method: on_llm_end")
198-
print(response)
199-
print(kwargs)
200-
print("Callback method end: on_llm_end")
198+
yel("Callback method: on_llm_end")
199+
yel(response)
200+
yel(kwargs)
201+
yel("Callback method end: on_llm_end")
202+
203+
if response.llm_output is None or response.llm_output["token_usage"] is None:
204+
if self.verbose:
205+
yel("None llm_output, returning.")
206+
return
201207

202208
new_p = response.llm_output["token_usage"]["prompt_tokens"]
203209
new_c = response.llm_output["token_usage"]["completion_tokens"]
204210
self.prompt_tokens += new_p
205211
self.completion_tokens += new_c
206212
self.total_tokens += new_p + new_c
207213
assert self.total_tokens == self.prompt_tokens + self.completion_tokens
208-
self.methods_called.append("on_llm_end")
209214
self._check_methods_called()
210215

211216
def on_llm_error(
212217
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
213218
) -> Any:
214219
"""Run when LLM errors."""
215-
if self.verbose:
216-
print("Callback method: on_llm_error")
217-
print(error)
218-
print(kwargs)
219-
print("Callback method end: on_llm_error")
220220
self.methods_called.append("on_llm_error")
221+
if self.verbose:
222+
yel("Callback method: on_llm_error")
223+
yel(error)
224+
yel(kwargs)
225+
yel("Callback method end: on_llm_error")
221226
self._check_methods_called()
222227

223228
def on_chain_start(
224229
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
225230
) -> Any:
226231
"""Run when chain starts running."""
227232
if self.verbose:
228-
print("Callback method: on_chain_start")
229-
print(serialized)
230-
print(inputs)
231-
print(kwargs)
232-
print("Callback method end: on_chain_start")
233+
yel("Callback method: on_chain_start")
234+
yel(serialized)
235+
yel(inputs)
236+
yel(kwargs)
237+
yel("Callback method end: on_chain_start")
233238
self.methods_called.append("on_chain_start")
234239
self._check_methods_called()
235240

236241
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
237242
"""Run when chain ends running."""
238243
if self.verbose:
239-
print("Callback method: on_chain_end")
240-
print(outputs)
241-
print(kwargs)
242-
print("Callback method end: on_chain_end")
244+
yel("Callback method: on_chain_end")
245+
yel(outputs)
246+
yel(kwargs)
247+
yel("Callback method end: on_chain_end")
243248
self.methods_called.append("on_chain_end")
244249
self._check_methods_called()
245250
if self.pbar:
@@ -250,10 +255,10 @@ def on_chain_error(
250255
) -> Any:
251256
"""Run when chain errors."""
252257
if self.verbose:
253-
print("Callback method: on_chain_error")
254-
print(error)
255-
print(kwargs)
256-
print("Callback method end: on_chain_error")
258+
yel("Callback method: on_chain_error")
259+
yel(error)
260+
yel(kwargs)
261+
yel("Callback method end: on_chain_error")
257262
self.methods_called.append("on_chain_error")
258263
self._check_methods_called()
259264

0 commit comments

Comments
 (0)