2323import logging
2424import os
2525import signal
26+ import sys
2627import threading
2728import time
2829import traceback
4142
4243_HOST = "[::]"
4344
45+ # Create seperate logger to log all INFO message for this module. These show
46+ # stages of server startup and inform user if server is ready to take requests.
47+ # The default logger created in orchestrator.py only logs WARNINGs and above
48+ logger = logging .getLogger (__name__ )
49+ logger .propagate = False
50+ logger .setLevel (logging .INFO )
51+ formatter = logging .Formatter (
52+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
53+ )
54+
55+ info_handler = logging .StreamHandler (sys .stdout )
56+ info_handler .setLevel (logging .INFO )
57+ info_handler .setFormatter (formatter )
58+ logger .addHandler (info_handler )
59+
4460
4561class JetStreamServer :
4662 """JetStream grpc server."""
@@ -120,7 +136,7 @@ def create_driver(
120136 prefill_params = [pe .load_params () for pe in engines .prefill_engines ]
121137 generate_params = [ge .load_params () for ge in engines .generate_engines ]
122138 shared_params = [ie .load_params () for ie in engines .interleaved_engines ]
123- logging .info ("Loaded all weights." )
139+ logger .info ("Loaded all weights." )
124140 if metrics_collector :
125141 metrics_collector .get_model_load_time_metric ().set (
126142 time .time () - model_load_start_time
@@ -135,13 +151,13 @@ def create_driver(
135151 generate_params = generate_params + shared_params
136152
137153 if prefill_engines is None :
138- prefill_engines = []
154+ prefill_engines = [] # pragma: no branch
139155 if generate_engines is None :
140- generate_engines = []
156+ generate_engines = [] # pragma: no branch
141157 if prefill_params is None :
142- prefill_params = []
158+ prefill_params = [] # pragma: no branch
143159 if generate_params is None :
144- generate_params = []
160+ generate_params = [] # pragma: no branch
145161
146162 if enable_model_warmup :
147163 prefill_engines = [engine_api .JetStreamEngine (pe ) for pe in prefill_engines ]
@@ -215,19 +231,19 @@ def run(
215231 del lora_input_adapters_path
216232
217233 server_start_time = time .time ()
218- logging .info ("Kicking off gRPC server." )
234+ logger .info ("Kicking off gRPC server." )
219235 # Setup Prometheus server
220236 metrics_collector : JetstreamMetricsCollector = None
221237 if metrics_server_config and metrics_server_config .port :
222- logging .info (
238+ logger .info (
223239 "Starting Prometheus server on port %d" , metrics_server_config .port
224240 )
225241 start_http_server (metrics_server_config .port )
226242 metrics_collector = JetstreamMetricsCollector (
227243 model_name = metrics_server_config .model_name
228244 )
229245 else :
230- logging .info (
246+ logger .info (
231247 "Not starting Prometheus server: --prometheus_port flag not set"
232248 )
233249
@@ -256,7 +272,7 @@ def run(
256272 gc .set_threshold (allocs , gen1 , gen2 )
257273 print ("GC tweaked (allocs, gen1, gen2): " , allocs , gen1 , gen2 )
258274
259- logging .info ("Starting server on port %d with %d threads" , port , threads )
275+ logger .info ("Starting server on port %d with %d threads" , port , threads )
260276 jetstream_server .start ()
261277
262278 if metrics_collector :
@@ -266,10 +282,10 @@ def run(
266282
267283 # Setup Jax Profiler
268284 if enable_jax_profiler :
269- logging .info ("Starting JAX profiler server on port %s" , jax_profiler_port )
285+ logger .info ("Starting JAX profiler server on port %s" , jax_profiler_port )
270286 jax .profiler .start_server (jax_profiler_port )
271287 else :
272- logging .info ("Not starting JAX profiler server: %s" , enable_jax_profiler )
288+ logger .info ("Not starting JAX profiler server: %s" , enable_jax_profiler )
273289
274290 # Start profiling server by default for proxy backend.
275291 if jax .config .jax_platforms and "proxy" in jax .config .jax_platforms :
@@ -279,6 +295,7 @@ def run(
279295 target = proxy_util .start_profiling_server , args = (jax_profiler_port ,)
280296 )
281297 thread .run ()
298+ logger .info ("Server up and ready to process requests on port %s" , port )
282299
283300 return jetstream_server
284301
@@ -287,5 +304,5 @@ def get_devices() -> Any:
287304 """Gets devices."""
288305 # TODO: Add more logs for the devices.
289306 devices = jax .devices ()
290- logging .info ("Using devices: %d" , len (devices ))
307+ logger .info ("Using devices: %d" , len (devices ))
291308 return devices
0 commit comments