diff --git a/pytext/task/tasks.py b/pytext/task/tasks.py index 05636b487..576e3d0f6 100644 --- a/pytext/task/tasks.py +++ b/pytext/task/tasks.py @@ -244,7 +244,7 @@ def torchscript_export(self, model, export_path=None, **kwargs): # noqa model(*inputs) if quantize: model.quantize() - if "half" in accelerate: + if accelerate is not None and "half" in accelerate: model.half() if self.trace_both_encoders: trace = jit.trace(model, inputs) @@ -276,7 +276,7 @@ def torchscript_export(self, model, export_path=None, **kwargs): # noqa "inference_interface not supported by model. Ignoring inference_interface" ) trace.apply(lambda s: s._pack() if s._c._has_method("_pack") else None) - if "nnpi" in accelerate: + if accelerate is not None and "nnpi" in accelerate: trace._c = torch._C._freeze_module( trace._c, preservedAttrs=["make_prediction", "make_batch", "set_padding_control"],