Skip to content

Commit e4fbf22

Browse files
committed
fix(safety_checker): drop DummySafetyChecker and just use None
Originally we had to instantiate a DummySafetyChecker() to replace the built in safety checker if we wanted to disable it. However, diffusers (for a while now) supports safety_checker=None, which is also less brittle.
1 parent 7a64846 commit e4fbf22

1 file changed

Lines changed: 1 addition & 13 deletions

File tree

api/app.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,10 @@
4949
always_normalize_model_id = None
5050

5151

52-
class DummySafetyChecker:
53-
@staticmethod
54-
def __call__(images, clip_input):
55-
return images, False
56-
57-
5852
# Init is ran on server startup
5953
# Load your model to GPU as a global variable here using the variable name "model"
6054
def init():
6155
global model # needed for bananna optimizations
62-
global dummy_safety_checker
6356
global always_normalize_model_id
6457

6558
asyncio.run(
@@ -75,8 +68,6 @@ def init():
7568
)
7669
)
7770

78-
dummy_safety_checker = DummySafetyChecker()
79-
8071
if MODEL_ID == "ALL" or RUNTIME_DOWNLOADS:
8172
global last_model_id
8273
last_model_id = None
@@ -140,7 +131,6 @@ async def inference(all_inputs: dict, response) -> dict:
140131
global pipelines
141132
global last_model_id
142133
global schedulers
143-
global dummy_safety_checker
144134
global last_xformers_memory_efficient_attention
145135
global always_normalize_model_id
146136
global last_attn_procs
@@ -310,9 +300,7 @@ def sendStatus():
310300
}
311301

312302
safety_checker = call_inputs.get("safety_checker", True)
313-
pipeline.safety_checker = (
314-
model.safety_checker if safety_checker else dummy_safety_checker
315-
)
303+
pipeline.safety_checker = model.safety_checker if safety_checker else None
316304
is_url = call_inputs.get("is_url", False)
317305
image_decoder = getFromUrl if is_url else decodeBase64Image
318306

0 commit comments

Comments
 (0)