@@ -1299,16 +1299,17 @@ static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
12991299// Holds "is at end of stream" state for each query.
13001300class TokenStreamer {
13011301 public:
1302- explicit TokenStreamer (const RuntimeConfig& runtime_config)
1303- : runtime_config_(runtime_config) {}
1302+ explicit TokenStreamer (const RuntimeConfig& runtime_config,
1303+ const ModelConfig& model_config)
1304+ : runtime_config_(runtime_config), model_config_(model_config) {}
13041305
13051306 // Returns whether the query was already at, or has just reached, the end of
13061307 // the stream: either via model_config_.IsEOS(token), or StreamToken returning false.
13071308 bool operator ()(size_t query_idx, size_t pos, int token, float prob) {
13081309 if (HWY_UNLIKELY (is_eos_.Get (query_idx))) return true ;
13091310
13101311 if (!runtime_config_.StreamToken (query_idx, pos, token, prob) ||
1311- token == runtime_config_. eos_id ) {
1312+ model_config_. IsEOS (token) ) {
13121313 is_eos_.Set (query_idx);
13131314 return true ;
13141315 }
@@ -1318,6 +1319,7 @@ class TokenStreamer {
13181319
13191320 private:
13201321 const RuntimeConfig& runtime_config_;
1322+ const ModelConfig& model_config_;
13211323 hwy::BitSet4096<> is_eos_;
13221324};
13231325
@@ -1425,7 +1427,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
14251427 // Sanity check: prompts should not be empty, nor start with EOS.
14261428 for (size_t query_idx = 0 ; query_idx < queries_prompt.size (); ++query_idx) {
14271429 const PromptTokens& prompt = queries_prompt[query_idx];
1428- HWY_ASSERT (prompt.size () != 0 && prompt[0 ] != runtime_config. eos_id );
1430+ HWY_ASSERT (prompt.size () != 0 && !model. Config (). IsEOS ( prompt[0 ]) );
14291431 }
14301432
14311433 const size_t num_queries = queries_prompt.size ();
@@ -1469,7 +1471,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
14691471 std::vector<int > gen_tokens (num_queries);
14701472
14711473 // Stream the last prompt token from each query and fill gen_tokens.
1472- TokenStreamer token_streamer (runtime_config);
1474+ TokenStreamer token_streamer (runtime_config, model. Config () );
14731475 for (size_t query_idx = 0 ; query_idx < num_queries; ++query_idx) {
14741476 size_t last_token_pos_in_prompt =
14751477 queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
0 commit comments