@@ -1089,6 +1089,36 @@ defmodule Bumblebee do
10891089 end
10901090 end
10911091
1092+ @ doc """
1093+ Initializes state for a new logits processor.
1094+
1095+ Returns `state`, which is an opaque `Nx.Container`, and it is then
1096+ passed to and returned from `process/4`.
1097+ """
1098+ @ doc type: :logits_processor
1099+ @ spec logits_processor_init (
1100+ Bumblebee.LogitsProcessor . t ( ) ,
1101+ context :: Bumblebee.LogitsProcessor . init_context ( )
1102+ ) :: Bumblebee.LogitsProcessor . state ( )
1103+ def logits_processor_init ( % module { } = logits_processor , context ) do
1104+ module . init ( logits_processor , context )
1105+ end
1106+
1107+ @ doc """
1108+ Processes logits, applying specific rules. Receives context, state and
1109+ logits, and returns updated logits and state.
1110+ """
1111+ @ doc type: :logits_processor
1112+ @ spec logits_processor_process (
1113+ Bumblebee.LogitsProcessor . t ( ) ,
1114+ Bumblebee.LogitsProcessor . state ( ) ,
1115+ logits :: Nx.Tensor . t ( ) ,
1116+ context :: Bumblebee.LogitsProcessor . process_context ( )
1117+ ) :: { Bumblebee.LogitsProcessor . state ( ) , logits :: Nx.Tensor . t ( ) }
1118+ def logits_processor_process ( % module { } = logits_processor , state , logits , context ) do
1119+ module . process ( logits_processor , state , logits , context )
1120+ end
1121+
10921122 @ doc """
10931123 Initializes state for a new scheduler loop.
10941124
0 commit comments