File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 164164}
165165
166166_TEST_CHAT_GENERATION_PREDICTION1 = {
167- "safetyAttributes" : {
168- "scores" : [],
169- "blocked" : False ,
170- "categories" : [],
171- },
167+ "safetyAttributes" : [
168+ {
169+ "scores" : [],
170+ "blocked" : False ,
171+ "categories" : [],
172+ }
173+ ],
172174 "candidates" : [
173175 {
174176 "author" : "1" ,
177179 ],
178180}
179181_TEST_CHAT_GENERATION_PREDICTION2 = {
180- "safetyAttributes" : {
181- "scores" : [],
182- "blocked" : False ,
183- "categories" : [],
184- },
182+ "safetyAttributes" : [
183+ {
184+ "scores" : [],
185+ "blocked" : False ,
186+ "categories" : [],
187+ }
188+ ],
185189 "candidates" : [
186190 {
187191 "author" : "1" ,
Original file line number Diff line number Diff line change @@ -799,7 +799,8 @@ def send_message(
799799 )
800800
801801 prediction = prediction_response .predictions [0 ]
802- safety_attributes = prediction ["safetyAttributes" ]
802+ # ! Note: For chat models, the safetyAttributes is a list.
803+ safety_attributes = prediction ["safetyAttributes" ][0 ]
803804 response_obj = TextGenerationResponse (
804805 text = prediction ["candidates" ][0 ]["content" ]
805806 if prediction .get ("candidates" )
You can’t perform that action at this time.
0 commit comments