Skip to content

Commit d9fe442

Browse files
Switched to the higher version of MLflow (#36)
1 parent 7f22ad5 commit d9fe442

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

src/pquant/core/hyperparameter_optimization.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def log_model_by_backend(model, name, backend, signature=None, registered_model_
4040
import mlflow
4141

4242
kwargs = {
43-
"artifact_path": name,
43+
"name": name,
4444
"signature": signature,
4545
"registered_model_name": registered_model_name,
4646
}
@@ -249,7 +249,7 @@ def set_hyperparameters(self):
249249

250250
if numerical_params:
251251
self.set_numerical_params(numerical_params)
252-
elif categorical_params:
252+
if categorical_params:
253253
self.set_categorical_params(categorical_params)
254254

255255
def set_numerical_params(self, numerical_params):
@@ -293,8 +293,9 @@ def register_hyperparameter(self, name, optuna_func, *args, **kwargs):
293293

294294
def objective(self, trial, model, train_func, valid_func, **kwargs):
295295
from pquant import add_compression_layers, train_model
296-
296+
297297
config_copy = copy.deepcopy(self.config)
298+
applied_parameters = {}
298299
for param_name, (optuna_func, func_args, func_kwargs) in self.hyperparameters.items():
299300
new_value = optuna_func(trial, *func_args, **func_kwargs)
300301
logging.info(f"Suggested {param_name} = {new_value}")
@@ -308,19 +309,21 @@ def objective(self, trial, model, train_func, valid_func, **kwargs):
308309
]:
309310
if hasattr(sub_config, param_name):
310311
setattr(sub_config, param_name, new_value)
312+
applied_parameters[param_name] = new_value
311313
applied = True
312314
break
313315
if not applied:
314316
logging.error(f"'{param_name}' not found in config: value not applied.")
315317

316318
trainloader = kwargs['trainloader']
317319
raw_input_batch = next(iter(trainloader))
320+
318321
sample_input = raw_input_batch[0]
319322
model_copy = self.adapter.clone_model(model)
320323
model_copy = self.adapter.move_to_device(model_copy)
321324
sample_output = self.adapter.forward(model_copy, sample_input)
322-
323325
input_shape = sample_input.shape
326+
324327
compressed_model = add_compression_layers(model_copy, config_copy, input_shape)
325328
optimizer_func = self.get_optimizer_function()
326329
optimizer = optimizer_func(config_copy, compressed_model)
@@ -350,7 +353,7 @@ def objective(self, trial, model, train_func, valid_func, **kwargs):
350353
from mlflow.models import infer_signature
351354

352355
with mlflow.start_run(nested=True):
353-
mlflow.log_params({param_name: getattr(config_copy, param_name) for param_name in config_copy.model_fields})
356+
mlflow.log_params(applied_parameters)
354357
mlflow.log_metrics({key: val for key, val in zip(self.objectives.keys(), objectives)})
355358
signature = infer_signature(
356359
self.adapter.tensor_to_numpy(sample_input), self.adapter.tensor_to_numpy(sample_output)

0 commit comments

Comments
 (0)