From 6ef59a4068194d154338fcb30db0514dd26555dd Mon Sep 17 00:00:00 2001 From: Leopold Talirz Date: Mon, 25 Nov 2024 12:17:16 +0100 Subject: [PATCH 1/2] fix: keep inferring n_elements When n_elements was set to 'infer', the first call to `load_data` would overwrite the value of `n_elements` to a concrete integer, fixing the number of elements for all subsequent calls. --- crabnet/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crabnet/model.py b/crabnet/model.py index 8842e82..c63c92f 100644 --- a/crabnet/model.py +++ b/crabnet/model.py @@ -73,7 +73,8 @@ def load_data(self, file_name, batch_size=2**9, train=False): f'elements in the formula') # update n_elements after loading dataset - self.n_elements = data_loaders.n_elements + if self.n_elements != 'infer': + self.n_elements = data_loaders.n_elements data_loader = data_loaders.get_data_loaders(inference=inference) y = data_loader.dataset.data[1] From 623ff6fd32005d3e4bf5e4c341f5752ff208716f Mon Sep 17 00:00:00 2001 From: Leopold Talirz Date: Mon, 25 Nov 2024 12:25:42 +0100 Subject: [PATCH 2/2] fix: don't update n_elements on load_data --- crabnet/model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/crabnet/model.py b/crabnet/model.py index c63c92f..6403a47 100644 --- a/crabnet/model.py +++ b/crabnet/model.py @@ -72,10 +72,6 @@ def load_data(self, file_name, batch_size=2**9, train=False): print(f'loading data with up to {data_loaders.n_elements:0.0f} ' f'elements in the formula') - # update n_elements after loading dataset - if self.n_elements != 'infer': - self.n_elements = data_loaders.n_elements - data_loader = data_loaders.get_data_loaders(inference=inference) y = data_loader.dataset.data[1] if train: