-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathmain_em.py
More file actions
executable file
·178 lines (152 loc) · 6.13 KB
/
main_em.py
File metadata and controls
executable file
·178 lines (152 loc) · 6.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
import logging
import os
from base64 import b64encode
from time import sleep
from urllib.parse import quote_plus, urlparse
import niquests
import uvicorn
from context_chat_backend.types import DEFAULT_EM_MODEL_ALIAS # isort: skip
from context_chat_backend.config_parser import get_config # isort: skip
from context_chat_backend.logger import get_logging_config, setup_logging # isort: skip
from context_chat_backend.setup_functions import ensure_config_file, setup_env_vars # isort: skip
from context_chat_backend.utils import redact_config # isort: skip
# Logging config file (relative name) used for this embedding-server process.
LOGGER_CONFIG_NAME = 'logger_config_em.yaml'
# Seconds to wait between polls of the main app's /enabled endpoint.
STARTUP_CHECK_SEC = 10
MAX_TRIES = 180 # 180*10 secs = 30 minutes max
def _get_main_app_client() -> niquests.Session:
	"""
	Build a niquests Session for talking to the main app, depending on the deployment type:
	a unix-domain-socket base URL when HP_SHARED_KEY is set, otherwise plain HTTP
	derived from APP_HOST/APP_PORT.

	Returns
	-------
	niquests.Session: The niquests Session.
	"""
	if os.getenv('HP_SHARED_KEY'):
		sock_path = os.getenv('HP_EXAPP_SOCK', '/tmp/exapp.sock')  # noqa: S108
		target = 'http+unix://' + quote_plus(sock_path)
	else:
		app_host = os.environ['APP_HOST']
		# wildcard bind addresses are not connectable; dial loopback instead
		if app_host in ('0.0.0.0', '::'):  # noqa: S104
			app_host = 'localhost'
		target = f'http://{app_host}:{os.environ["APP_PORT"]}'

	# AppAPI basic-auth token: empty user, APP_SECRET as password, base64-encoded
	auth_token = b64encode(f':{os.getenv("APP_SECRET", "")}'.encode()).decode('utf-8')
	default_headers = {
		'EX-APP-ID': os.getenv('APP_ID', 'context_chat_backend'),
		'EX-APP-VERSION': os.getenv('APP_VERSION', ''),
		'OCS-APIRequest': 'true',
		'AUTHORIZATION-APP-API': auth_token,
	}
	return niquests.Session(base_url=target, headers=default_headers)
def _wait_main_app_enabled() -> None:
	'''
	Block until the main app reports {"enabled": true} on its /enabled endpoint,
	polling once every STARTUP_CHECK_SEC seconds for at most MAX_TRIES attempts.

	Raises
	------
	RuntimeError: If the main app is not enabled/ready within the expected time.
	niquests.RequestException: If there is an error while making the request to the main app
	'''
	_last_err = None
	client = _get_main_app_client()
	# wait for the main process to be ready
	for attempt in range(1, MAX_TRIES + 1):
		try:
			response = client.get('/enabled')
			response.raise_for_status()
			enabled = response.json().get('enabled', False)
			if enabled:
				# Return immediately on success. The previous version slept in a
				# `finally` block, which ran even on this return path and added a
				# pointless extra STARTUP_CHECK_SEC delay to every startup.
				return
			print(
				f'{attempt*STARTUP_CHECK_SEC}/{MAX_TRIES*STARTUP_CHECK_SEC} secs:'
				f' [Embedding server] Waiting for the main app to be enabled/ready. Current enabled state: {enabled}',
				flush=True,
			)
		except niquests.RequestException as e:
			print(
				f'{attempt*STARTUP_CHECK_SEC}/{MAX_TRIES*STARTUP_CHECK_SEC} secs:'
				f' [Embedding server] Waiting for the main app to be enabled/ready, errors are expected initially: {e}',
				flush=True,
			)
			if attempt == MAX_TRIES:
				# only the final attempt's request error is re-raised below
				_last_err = e
		except Exception as e:
			raise RuntimeError('Unexpected error while waiting for the main app to be enabled/ready') from e
		# sleep only between attempts — not after success (returned above) and
		# not after the final attempt (about to give up and raise)
		if attempt < MAX_TRIES:
			sleep(STARTUP_CHECK_SEC)
	# if we exhausted all tries
	raise _last_err or RuntimeError('Timed out waiting for the main app to be enabled/ready.')
if __name__ == '__main__':
	# initial buffer: give the main app process a head start before polling it
	print(
		f"Waiting for {STARTUP_CHECK_SEC} seconds before starting embedding server to allow main app to start",
		flush=True,
	)
	sleep(STARTUP_CHECK_SEC)
	# set up env vars and make sure a config file exists before parsing it
	setup_env_vars()
	ensure_config_file()
	app_config = get_config(os.environ['CC_CONFIG_PATH'])
	em_conf = app_config.embedding
	# this process only serves a *local* embedding model; exit cleanly when a
	# remote service is configured or no workers are requested
	if em_conf.workers <= 0 or em_conf.remote_service:
		print('Exiting embedding server as it is not configured to run locally.', flush=True)
		exit(0)
	# redact sensitive info before logging, although no api key or password should be present
	# in local embedding server config
	print('Embedder config:\n' + redact_config(em_conf).model_dump_json(indent=2), flush=True)
	logging_config = get_logging_config(LOGGER_CONFIG_NAME)
	setup_logging(logging_config)
	logger = logging.getLogger('emserver')
	if app_config.debug:
		logger.setLevel(logging.DEBUG)
	# block until the main app is enabled — presumably it is responsible for
	# downloading the model files loaded below (see error text); TODO confirm
	try:
		_wait_main_app_enabled()
	except Exception as e:
		logger.error(
			'Failed waiting for the main app to be enabled. This could indicate an issue with the AppAPI'
			' Deploy Daemon setup or some issue in the main app setup. Some common causes of the latter'
			' could be no/no stable internet connection to download the required models, disk space full,'
			' or this app not being able to contact the Nextcloud server to report progress of the model'
			' download.',
			exc_info=e,
		)
		exit(1)
	# update model path to be in the persistent storage if it is not already valid
	if 'model' not in em_conf.llama:
		raise ValueError('Error: "model" key not found in embedding->llama\'s config')
	if not os.path.isfile(em_conf.llama['model']):
		# fall back to <persistent storage>/model_files/<model name>
		em_conf.llama['model'] = os.path.join(
			os.getenv('APP_PERSISTENT_STORAGE', 'persistent_storage'),
			'model_files',
			em_conf.llama['model'],
		)
		logger.debug(f'Trying model path: {em_conf.llama["model"]}')
		# if the model file is still not found, raise an error
		if not os.path.isfile(em_conf.llama['model']):
			raise ValueError('Error: Model file not found at the updated path')
	# delayed import for libcuda.so.1 to be available
	from llama_cpp.server.app import create_app
	from llama_cpp.server.settings import ModelSettings, ServerSettings
	# bind host/port come from the configured embedding base_url
	base_url = urlparse(em_conf.base_url)
	host = base_url.hostname or '127.0.0.1'
	port = base_url.port or 5000
	server_settings = ServerSettings(host=host, port=port)
	# single model definition, exposed under the default embedding model alias
	model_settings = [ModelSettings(model_alias=DEFAULT_EM_MODEL_ALIAS, embedding=True, **em_conf.llama)]
	app = create_app(
		server_settings=server_settings,
		model_settings=model_settings,
	)
	# graft this app's JSON file handler onto uvicorn's default logging config
	# so uvicorn/access logs also land in the structured log file
	uv_log_config = uvicorn.config.LOGGING_CONFIG # pyright: ignore[reportAttributeAccessIssue]
	uv_log_config['formatters']['json'] = logging_config['formatters']['json']
	uv_log_config['handlers']['file_json'] = logging_config['handlers']['file_json']
	uv_log_config['loggers']['uvicorn']['handlers'].append('file_json')
	uv_log_config['loggers']['uvicorn.access']['handlers'].append('file_json')
	uvicorn.run(
		# todo: use string import of the app
		app=app,
		host=host,
		port=port,
		http='h11',
		interface='asgi3',
		log_config=uv_log_config,
		log_level=app_config.uvicorn_log_level,
		# colors only when enabled in config and not running in CI
		use_colors=bool(app_config.use_colors and os.getenv('CI', 'false') == 'false'),
		workers=em_conf.workers,
	)