diff --git a/yolox/models/yolo_head.py b/yolox/models/yolo_head.py index 3e51768ee7..79e8ecc110 100644 --- a/yolox/models/yolo_head.py +++ b/yolox/models/yolo_head.py @@ -227,7 +227,7 @@ def get_output_and_grid(self, output, k, stride, dtype): output = output.permute(0, 1, 3, 4, 2).reshape( batch_size, hsize * wsize, -1 ) - grid = grid.view(1, -1, 2) + grid = grid.view(1, -1, 2).to(output.device) output[..., :2] = (output[..., :2] + grid) * stride output[..., 2:4] = torch.exp(output[..., 2:4]) * stride return output, grid