Skip to content

Commit 158cc23

Browse files
committed
Always print server startup messages
1 parent 8aa6a9e commit 158cc23

1 file changed

Lines changed: 23 additions & 8 deletions

File tree

jetstream/core/server_lib.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import logging
2424
import os
2525
import signal
26+
import sys
2627
import threading
2728
import time
2829
import traceback
@@ -41,6 +42,20 @@
4142

4243
_HOST = "[::]"
4344

45+
# Create a separate logger to log all INFO messages for this module. These show
46+
# stages of server startup and inform user if server is ready to take requests.
47+
# The default logger created in orchestrator.py only logs WARNINGs and above.
48+
logger = logging.getLogger(__name__)
49+
logger.propagate = False
50+
formatter = logging.Formatter(
51+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
52+
)
53+
54+
info_handler = logging.StreamHandler(sys.stdout)
55+
info_handler.setLevel(logging.INFO)
56+
info_handler.setFormatter(formatter)
57+
logger.addHandler(info_handler)
58+
4459

4560
class JetStreamServer:
4661
"""JetStream grpc server."""
@@ -120,7 +135,7 @@ def create_driver(
120135
prefill_params = [pe.load_params() for pe in engines.prefill_engines]
121136
generate_params = [ge.load_params() for ge in engines.generate_engines]
122137
shared_params = [ie.load_params() for ie in engines.interleaved_engines]
123-
logging.info("Loaded all weights.")
138+
logger.info("Loaded all weights.")
124139
if metrics_collector:
125140
metrics_collector.get_model_load_time_metric().set(
126141
time.time() - model_load_start_time
@@ -215,19 +230,19 @@ def run(
215230
del lora_input_adapters_path
216231

217232
server_start_time = time.time()
218-
logging.info("Kicking off gRPC server.")
233+
logger.info("Kicking off gRPC server.")
219234
# Setup Prometheus server
220235
metrics_collector: JetstreamMetricsCollector = None
221236
if metrics_server_config and metrics_server_config.port:
222-
logging.info(
237+
logger.info(
223238
"Starting Prometheus server on port %d", metrics_server_config.port
224239
)
225240
start_http_server(metrics_server_config.port)
226241
metrics_collector = JetstreamMetricsCollector(
227242
model_name=metrics_server_config.model_name
228243
)
229244
else:
230-
logging.info(
245+
logger.info(
231246
"Not starting Prometheus server: --prometheus_port flag not set"
232247
)
233248

@@ -256,7 +271,7 @@ def run(
256271
gc.set_threshold(allocs, gen1, gen2)
257272
print("GC tweaked (allocs, gen1, gen2): ", allocs, gen1, gen2)
258273

259-
logging.info("Starting server on port %d with %d threads", port, threads)
274+
logger.info("Starting server on port %d with %d threads", port, threads)
260275
jetstream_server.start()
261276

262277
if metrics_collector:
@@ -266,10 +281,10 @@ def run(
266281

267282
# Setup Jax Profiler
268283
if enable_jax_profiler:
269-
logging.info("Starting JAX profiler server on port %s", jax_profiler_port)
284+
logger.info("Starting JAX profiler server on port %s", jax_profiler_port)
270285
jax.profiler.start_server(jax_profiler_port)
271286
else:
272-
logging.info("Not starting JAX profiler server: %s", enable_jax_profiler)
287+
logger.info("Not starting JAX profiler server: %s", enable_jax_profiler)
273288

274289
# Start profiling server by default for proxy backend.
275290
if jax.config.jax_platforms and "proxy" in jax.config.jax_platforms:
@@ -287,5 +302,5 @@ def get_devices() -> Any:
287302
"""Gets devices."""
288303
# TODO: Add more logs for the devices.
289304
devices = jax.devices()
290-
logging.info("Using devices: %d", len(devices))
305+
logger.info("Using devices: %d", len(devices))
291306
return devices

0 commit comments

Comments
 (0)