@@ -93,23 +93,9 @@ def wait_for_complete_response(context: Context) -> None:
9393 """Wait for the response to be complete."""
9494 context .response_data = _parse_streaming_response (context .response .text )
9595 context .response .raise_for_status ()
96- assert context .response_data ["finished" ] is True , f"Response is not finished: { context .response_data } "
97-
98-
99- @step ('I use "{endpoint}" to ask question' )
100- def ask_question (context : Context , endpoint : str ) -> None :
101- """Call the service REST API endpoint with question."""
102- base = f"http://{ context .hostname } :{ context .port } "
103- path = f"{ context .api_prefix } /{ endpoint } " .replace ("//" , "/" )
104- url = base + path
105-
106- # Replace {MODEL} and {PROVIDER} placeholders with actual values
107- json_str = replace_placeholders (context , context .text or "{}" )
108-
109- data = json .loads (json_str )
110- context .response = request_with_transient_retry (
111- method = "POST" , url = url , json = data , timeout = DEFAULT_LLM_TIMEOUT
112- )
96+ assert (
97+ context .response_data ["finished" ] is True
98+ ), f"Response is not finished: { context .response_data } "
11399
114100
115101def _read_streamed_response (response : requests .Response ) -> str :
@@ -124,41 +110,72 @@ def _read_streamed_response(response: requests.Response) -> str:
124110 return "" .join (chunks )
125111
126112
127- @step ('I use "{endpoint}" to ask question with authorization header' )
128- def ask_question_authorized (context : Context , endpoint : str ) -> None :
129- """Call the service REST API endpoint with question."""
113+ def _uses_sse (endpoint : str , data : dict [str , Any ]) -> bool :
114+ """Return whether the endpoint delivers an SSE stream for the given payload."""
115+ return endpoint == "streaming_query" or (
116+ endpoint == "responses" and bool (data .get ("stream" ))
117+ )
118+
119+
120+ def _post_question (
121+ context : Context ,
122+ endpoint : str ,
123+ headers : dict [str , str ] | None = None ,
124+ extra_data : dict [str , Any ] | None = None ,
125+ ) -> requests .Response :
126+ """POST a question to the service REST API endpoint.
127+
128+ Parameters:
129+ context: Behave context with hostname, port, and request body text.
130+ endpoint: API endpoint name (e.g. ``query``, ``streaming_query``).
131+ headers: Optional HTTP headers (e.g. authorization).
132+ extra_data: Optional fields merged into the JSON request body.
133+
134+ Returns:
135+ The HTTP response, with streamed bodies fully consumed when applicable.
136+ """
130137 base = f"http://{ context .hostname } :{ context .port } "
131138 path = f"{ context .api_prefix } /{ endpoint } " .replace ("//" , "/" )
132139 url = base + path
133140
134- # Replace {MODEL} and {PROVIDER} placeholders with actual values
135141 json_str = replace_placeholders (context , context .text or "{}" )
136-
137142 data = json .loads (json_str )
138- use_sse = endpoint == "streaming_query" or (
139- endpoint == "responses" and bool ( data .get ( "stream" ) )
140- )
141- if use_sse :
143+ if extra_data :
144+ data .update ( extra_data )
145+
146+ if _uses_sse ( endpoint , data ) :
142147 resp = request_with_transient_retry (
143148 method = "POST" ,
144149 url = url ,
145150 json = data ,
146- headers = context . auth_headers ,
151+ headers = headers ,
147152 timeout = DEFAULT_LLM_TIMEOUT ,
148153 stream = True ,
149154 )
150155 # Consume stream so server close after error event does not raise
151156 body = _read_streamed_response (resp )
152157 resp ._content = body .encode (resp .encoding or "utf-8" )
153- context .response = resp
154- else :
155- context .response = request_with_transient_retry (
156- method = "POST" ,
157- url = url ,
158- json = data ,
159- headers = context .auth_headers ,
160- timeout = DEFAULT_LLM_TIMEOUT ,
161- )
158+ return resp
159+
160+ return request_with_transient_retry (
161+ method = "POST" ,
162+ url = url ,
163+ json = data ,
164+ headers = headers ,
165+ timeout = DEFAULT_LLM_TIMEOUT ,
166+ )
167+
168+
169+ @step ('I use "{endpoint}" to ask question' )
170+ def ask_question (context : Context , endpoint : str ) -> None :
171+ """Call the service REST API endpoint with question."""
172+ context .response = _post_question (context , endpoint )
173+
174+
175+ @step ('I use "{endpoint}" to ask question with authorization header' )
176+ def ask_question_authorized (context : Context , endpoint : str ) -> None :
177+ """Call the service REST API endpoint with question."""
178+ context .response = _post_question (context , endpoint , headers = context .auth_headers )
162179
163180
164181# Query length chosen to exceed typical model context windows (e.g. 128k tokens)
@@ -188,19 +205,12 @@ def store_conversation_details(context: Context) -> None:
188205@step ('I use "{endpoint}" to ask question with same conversation_id' )
189206def ask_question_in_same_conversation (context : Context , endpoint : str ) -> None :
190207 """Call the service REST API endpoint with question, but use the existing conversation id."""
191- base = f"http://{ context .hostname } :{ context .port } "
192- path = f"{ context .api_prefix } /{ endpoint } " .replace ("//" , "/" )
193- url = base + path
194-
195- # Replace {MODEL} and {PROVIDER} placeholders with actual values
196- json_str = replace_placeholders (context , context .text or "{}" )
197-
198- data = json .loads (json_str )
199- headers = context .auth_headers if hasattr (context , "auth_headers" ) else {}
200- data ["conversation_id" ] = context .response_data ["conversation_id" ]
201-
202- context .response = request_with_transient_retry (
203- method = "POST" , url = url , json = data , headers = headers , timeout = DEFAULT_LLM_TIMEOUT
208+ headers = context .auth_headers if hasattr (context , "auth_headers" ) else None
209+ context .response = _post_question (
210+ context ,
211+ endpoint ,
212+ headers = headers ,
213+ extra_data = {"conversation_id" : context .response_data ["conversation_id" ]},
204214 )
205215
206216
@@ -366,12 +376,12 @@ def _parse_streaming_response(response_text: str) -> dict:
366376 full_response = ""
367377 full_response_split = []
368378 finished = False
369- first_token = True
370379 stream_error = (
371380 None # {"status_code": int, "response": str, "cause": str} if event "error"
372381 )
373382
374383 for line in lines :
384+ print (f"line: { line } " )
375385 if line .startswith ("data: " ):
376386 try :
377387 data = json .loads (line [6 :]) # Remove 'data: ' prefix
@@ -380,10 +390,6 @@ def _parse_streaming_response(response_text: str) -> dict:
380390 if event == "start" :
381391 conversation_id = data ["data" ]["conversation_id" ]
382392 elif event == "token" :
383- # Skip the first token (shield status message)
384- if first_token :
385- first_token = False
386- continue
387393 full_response_split .append (data ["data" ]["token" ])
388394 elif event == "turn_complete" :
389395 full_response = data ["data" ]["token" ]
0 commit comments