2323import logging
2424import os
2525import signal
26+ import sys
2627import threading
2728import time
2829import traceback
4142
4243_HOST = "[::]"
4344
# Create a separate logger that emits all INFO messages for this module. These
# messages show the stages of server startup and tell the user when the server
# is ready to take requests. The default logger created in orchestrator.py only
# logs WARNINGs and above.
logger = logging.getLogger(__name__)
# Keep records away from ancestor loggers' handlers; the stdout handler
# attached below is this logger's only sink.
logger.propagate = False
# Without an explicit level, this logger inherits the effective level of its
# ancestors (WARNING unless configured otherwise), which would discard INFO
# records before they ever reach the handler below.
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

info_handler = logging.StreamHandler(sys.stdout)
info_handler.setLevel(logging.INFO)
info_handler.setFormatter(formatter)
logger.addHandler(info_handler)
58+
4459
4560class JetStreamServer :
4661 """JetStream grpc server."""
@@ -120,7 +135,7 @@ def create_driver(
120135 prefill_params = [pe .load_params () for pe in engines .prefill_engines ]
121136 generate_params = [ge .load_params () for ge in engines .generate_engines ]
122137 shared_params = [ie .load_params () for ie in engines .interleaved_engines ]
123- logging .info ("Loaded all weights." )
# Use the module-level `logger` (the original `logge` is an undefined name and
# would raise NameError as soon as the weights finished loading).
logger.info("Loaded all weights.")
124139 if metrics_collector :
125140 metrics_collector .get_model_load_time_metric ().set (
126141 time .time () - model_load_start_time
@@ -215,19 +230,19 @@ def run(
215230 del lora_input_adapters_path
216231
217232 server_start_time = time .time ()
218- logging .info ("Kicking off gRPC server." )
233+ logger .info ("Kicking off gRPC server." )
219234 # Setup Prometheus server
220235 metrics_collector : JetstreamMetricsCollector = None
221236 if metrics_server_config and metrics_server_config .port :
222- logging .info (
237+ logger .info (
223238 "Starting Prometheus server on port %d" , metrics_server_config .port
224239 )
225240 start_http_server (metrics_server_config .port )
226241 metrics_collector = JetstreamMetricsCollector (
227242 model_name = metrics_server_config .model_name
228243 )
229244 else :
230- logging .info (
245+ logger .info (
231246 "Not starting Prometheus server: --prometheus_port flag not set"
232247 )
233248
@@ -256,7 +271,7 @@ def run(
256271 gc .set_threshold (allocs , gen1 , gen2 )
257272 print ("GC tweaked (allocs, gen1, gen2): " , allocs , gen1 , gen2 )
258273
259- logging .info ("Starting server on port %d with %d threads" , port , threads )
274+ logger .info ("Starting server on port %d with %d threads" , port , threads )
260275 jetstream_server .start ()
261276
262277 if metrics_collector :
@@ -266,10 +281,10 @@ def run(
266281
267282 # Setup Jax Profiler
268283 if enable_jax_profiler :
269- logging .info ("Starting JAX profiler server on port %s" , jax_profiler_port )
284+ logger .info ("Starting JAX profiler server on port %s" , jax_profiler_port )
270285 jax .profiler .start_server (jax_profiler_port )
271286 else :
272- logging .info ("Not starting JAX profiler server: %s" , enable_jax_profiler )
287+ logger .info ("Not starting JAX profiler server: %s" , enable_jax_profiler )
273288
274289 # Start profiling server by default for proxy backend.
275290 if jax .config .jax_platforms and "proxy" in jax .config .jax_platforms :
def get_devices() -> Any:
    """Return the devices reported by `jax.devices()`, logging their count."""
    # TODO: Add more logs for the devices.
    found = jax.devices()
    logger.info("Using devices: %d", len(found))
    return found
0 commit comments