Skip to content

Commit 4e0f5d5

Browse files
authored
fix: use semantic version comparison for PyTorch scheduler compatibility (#2094)
1 parent 50c32ac commit 4e0f5d5

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

qlib/contrib/model/pytorch_nn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import gc
1111
import numpy as np
1212
import pandas as pd
13+
from packaging import version
1314
from typing import Callable, Optional, Text, Union
1415
from sklearn.metrics import roc_auc_score, mean_squared_error
1516

@@ -148,7 +149,7 @@ def __init__(
148149
if scheduler == "default":
149150
# In torch version 2.7.0, the verbose parameter has been removed. Reference Link:
150151
# https://github.com/pytorch/pytorch/pull/147301/files#diff-036a7470d5307f13c9a6a51c3a65dd014f00ca02f476c545488cd856bea9bcf2L1313
151-
if str(torch.__version__).split("+", maxsplit=1)[0] <= "2.6.0":
152+
if version.parse(str(torch.__version__).split("+", maxsplit=1)[0]) <= version.parse("2.6.0"):
152153
# Reduce learning rate when loss has stopped decrease
153154
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( # pylint: disable=E1123
154155
self.train_optimizer,

0 commit comments

Comments
 (0)