Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
819cf0c
Remove requires fitting and finalize routine
c-w-feldmann Apr 24, 2025
c1556c3
linters
c-w-feldmann Apr 24, 2025
a6716fb
move transform_single back in (for now)
c-w-feldmann Apr 24, 2025
1051fec
remove parameter property
c-w-feldmann Apr 24, 2025
e32f72d
rewrite to use mixins
c-w-feldmann Apr 24, 2025
3ae90b4
Merge branch 'refs/heads/development' into instance-routine-mixins
c-w-feldmann May 6, 2025
df1d09e
Merge branch 'development' into instance-routine-mixins
c-w-feldmann May 6, 2025
06f4753
remove unnecessary type check
c-w-feldmann May 7, 2025
8045215
remove Raises from docu
c-w-feldmann May 7, 2025
be4dd08
remove pylint ignore
c-w-feldmann May 7, 2025
d141147
fix unittests and rewrite
c-w-feldmann May 13, 2025
d356b84
Change inheritance
c-w-feldmann May 13, 2025
3b684cd
Type cast
c-w-feldmann May 13, 2025
0826079
type hints
c-w-feldmann May 13, 2025
f1f8350
fix var name
c-w-feldmann May 13, 2025
3d1ed3e
Add type ignore and minor linting
c-w-feldmann May 13, 2025
e240dca
Merge branch 'development' into instance-routine-mixins
c-w-feldmann May 13, 2025
9cd2354
remove final estimator
c-w-feldmann May 13, 2025
7fbc504
Merge branch 'development' into instance-routine-mixins
c-w-feldmann May 13, 2025
b3a5ce1
Merge branch 'development' into instance-routine-mixins
c-w-feldmann May 14, 2025
905bd9f
remove duplicate _estimator_type property
c-w-feldmann May 14, 2025
225858d
add ignore to duplicate code
c-w-feldmann May 14, 2025
577ea38
use sklearn native transform
c-w-feldmann May 14, 2025
dc076cf
delete _can_transform
c-w-feldmann May 14, 2025
b96254c
use sklearn native decision function
c-w-feldmann May 14, 2025
b500f4b
remove duplicate fit_predict function
c-w-feldmann May 14, 2025
d8a717a
rework predict function
c-w-feldmann May 14, 2025
654f549
Remove classes property
c-w-feldmann May 15, 2025
bf63000
Remove Validate steps
c-w-feldmann May 15, 2025
f489304
Switch type casting back and adapt types
c-w-feldmann May 15, 2025
f312483
use super.fit
c-w-feldmann May 15, 2025
763f0dc
remove can decision function
c-w-feldmann May 15, 2025
389c2fa
ignore duplicate code (cannot be inherited)
c-w-feldmann May 15, 2025
1527dc5
pylint ignore
c-w-feldmann May 15, 2025
0f1c587
pylint ignore and move function
c-w-feldmann May 15, 2025
8cbeb9d
change ignore statement
c-w-feldmann May 15, 2025
49a6359
Merge branch 'development' into instance-routine-mixins
c-w-feldmann Jun 6, 2025
c1a7b51
Merge branch 'development' into instance-routine-mixins
c-w-feldmann Feb 12, 2026
f4f0842
Fixes
c-w-feldmann Feb 12, 2026
001806a
Fixes
c-w-feldmann Feb 12, 2026
36917ad
Docsig
c-w-feldmann Feb 12, 2026
d48163a
Merge branch 'development' into instance-routine-mixins
c-w-feldmann Feb 12, 2026
9dd9abb
Merge branch 'development' into instance-routine-mixins
c-w-feldmann Feb 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 91 additions & 153 deletions molpipeline/abstract_pipeline_elements/core.py

Large diffs are not rendered by default.

73 changes: 24 additions & 49 deletions molpipeline/error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
from loguru import logger

from molpipeline.abstract_pipeline_elements.core import (
ABCPipelineElement,
InvalidInstance,
RemovedInstance,
SingleInstanceTransformerMixin,
TransformingPipelineElement,
)
from molpipeline.utils.molpipeline_types import AnyVarSeq, TypeFixedVarSeq
Expand All @@ -21,7 +23,7 @@
_S = TypeVar("_S")


class ErrorFilter(ABCPipelineElement):
class ErrorFilter(SingleInstanceTransformerMixin, TransformingPipelineElement):
"""Collects, tracks, and removes error values."""

element_ids: set[str]
Expand Down Expand Up @@ -71,7 +73,6 @@ def __init__(
self.element_ids = element_ids
self.filter_everything = filter_everything
self.n_total = 0
self._requires_fitting = True

@classmethod
def from_element_list(
Expand Down Expand Up @@ -179,7 +180,11 @@ def check_removal(self, value: Any) -> bool:
return False
return self.filter_everything or value.element_id in self.element_ids

def fit(self, values: AnyVarSeq, labels: Any = None) -> Self: # noqa: ARG002
def fit(
self,
values: AnyVarSeq, # noqa: ARG002
labels: Any = None, # noqa: ARG002
) -> Self:
"""Fit to input values.

Only for compatibility with sklearn Pipelines.
Expand Down Expand Up @@ -212,8 +217,8 @@ def fit_transform(
----------
values: TypeFixedVarSeq
Iterable to which element is fitted and which is subsequently transformed.
labels: Any
Label used for fitting. For compatibility with sklearn, not used.
labels: Any, optional
Label used for fitting.

Returns
-------
Expand Down Expand Up @@ -249,6 +254,9 @@ def co_transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq:

"""
if self.n_total != len(values):
logger.error(
f"Expected {self.n_total} values, but got {len(values)}",
)
raise ValueError("Length of values does not match length of values in fit")
if isinstance(values, list):
out_list = []
Expand Down Expand Up @@ -288,6 +296,8 @@ def transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq:
def transform_single(self, value: Any) -> Any:
"""Transform a single value.

Overrides parent method to not skip InvalidInstances.

Parameters
----------
value: Any
Expand Down Expand Up @@ -688,8 +698,8 @@ def fit_transform(
----------
values: TypeFixedVarSeq
Iterable to which element is fitted and which is subsequently transformed.
labels: Any
Label used for fitting. For compatibility with sklearn, not used.
labels: Any, optional
Label used for fitting.
**params: Any
Additional keyword arguments. For compatibility with sklearn, not used.

Expand All @@ -702,47 +712,10 @@ def fit_transform(
self.fit(values)
return self.transform(values)

def transform_single(self, value: Any) -> Any:
"""Transform a single value.

Parameters
----------
value: Any
Value to be transformed.

Returns
-------
Any
Transformed value.

"""
return self.pretransform_single(value)

def pretransform_single(self, value: Any) -> Any:
"""Transform a single value.

Parameters
----------
value: Any
Value to be transformed.

Returns
-------
Any
Transformed value.

"""
if (
isinstance(value, RemovedInstance)
and value.filter_element_id == self.error_filter.uuid
):
return self.fill_value
return value

def transform(
self,
values: TypeFixedVarSeq,
**_params: Any,
**params: Any,
) -> TypeFixedVarSeq:
"""Transform iterable of values by removing invalid instances.

Expand All @@ -752,7 +725,7 @@ def transform(
----------
values: TypeFixedVarSeq
Iterable of which according invalid instances are removed.
**_params: Any
**params: Any
Additional keyword arguments.

Raises
Expand All @@ -770,12 +743,14 @@ def transform(
if len(values) != self.error_filter.n_total - len(
self.error_filter.error_indices,
):
expected_length = self.error_filter.n_total - len(
self.error_filter.n_total - len(
self.error_filter.error_indices,
)
raise ValueError(
"Length of values does not match length of values in fit. "
f"Expected: {expected_length} - Received :{len(values)}",
f"Length of values does not match length of values in fit. "
f"Expected: "
f"{self.error_filter.n_total - len(self.error_filter.error_indices)} "
f"- Received :{len(values)}",
)
return self.fill_with_dummy(values)

Expand Down
8 changes: 3 additions & 5 deletions molpipeline/estimators/chemprop/lightning_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_trainer_path(trainer: pl.Trainer) -> str | None:
None if the path is the current path.

"""
curr_path = str(Path().resolve())
curr_path = str(Path.cwd())
trainer_path: str | None = trainer.default_root_dir
if trainer_path == curr_path:
trainer_path = None
Expand Down Expand Up @@ -139,7 +139,7 @@ def get_params_trainer(trainer: pl.Trainer) -> dict[str, Any]:
else:
enable_progress_model_summary = False

trainer_dict = {
return {
"accelerator": get_device(trainer),
"strategy": "auto", # trainer.strategy, # collides with accelerator
"devices": "auto", # trainer._accelerator_connector._devices_flag does not work
Expand Down Expand Up @@ -180,7 +180,6 @@ def get_params_trainer(trainer: pl.Trainer) -> dict[str, Any]:
"reload_dataloaders_every_n_epochs": trainer.reload_dataloaders_every_n_epochs, # type: ignore[attr-defined]
"default_root_dir": get_trainer_path(trainer),
}
return trainer_dict


def get_non_default_params_trainer(trainer: pl.Trainer) -> dict[str, Any]:
Expand All @@ -198,9 +197,8 @@ def get_non_default_params_trainer(trainer: pl.Trainer) -> dict[str, Any]:

"""
trainer_dict = get_params_trainer(trainer)
non_default_values = {
return {
key: value
for key, value in trainer_dict.items()
if key not in TRAINER_DEFAULT_PARAMS or value != TRAINER_DEFAULT_PARAMS[key]
}
return non_default_values
11 changes: 7 additions & 4 deletions molpipeline/estimators/murcko_scaffold_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,12 @@ def _generate_pipeline(self) -> Pipeline:
)

# Directly add the error filter and replacer to the pipeline
pipeline_step_list.append(("no_scaffold_filter", no_scaffold_filter))
pipeline_step_list.append(("no_scaffold_replacer", no_scaffold_replacer))
pipeline_step_list.extend(
(
("no_scaffold_filter", no_scaffold_filter),
("no_scaffold_replacer", no_scaffold_replacer),
),
)
else:
raise ValueError(
f"Invalid value for linear_molecules_strategy: "
Expand All @@ -139,11 +143,10 @@ def _generate_pipeline(self) -> Pipeline:
],
)

cluster_pipeline = Pipeline(
return Pipeline(
pipeline_step_list,
n_jobs=self.n_jobs,
)
return cluster_pipeline

# pylint: disable=C0103,W0613
@_fit_context(prefer_skip_nested_validation=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,7 @@ def get_color_normalizer_from_data(

"""
abs_max = np.max(np.abs(values))
normalizer = colors.Normalize(vmin=-abs_max, vmax=abs_max)
return normalizer
return colors.Normalize(vmin=-abs_max, vmax=abs_max)


def color_canvas(canvas: Draw.MolDraw2D, color_grid: ColorGrid) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ def color_tuple_to_colormap(
half_to_one = col1 * linspace4d + col2 * (1 - linspace4d)

# Creating new colormap from
color_map = ListedColormap(np.vstack([zero_to_half, half_to_one]))
return color_map
return ListedColormap(np.vstack([zero_to_half, half_to_one]))


def to_png(data: bytes) -> Image.Image:
Expand All @@ -165,8 +164,7 @@ def to_png(data: bytes) -> Image.Image:

"""
bio = io.BytesIO(data)
img = Image.open(bio)
return img
return Image.open(bio)


def plt_to_pil(figure: plt.Figure) -> Image.Image:
Expand All @@ -186,8 +184,7 @@ def plt_to_pil(figure: plt.Figure) -> Image.Image:
bio = io.BytesIO()
figure.savefig(bio, format="png")
bio.seek(0)
img = Image.open(bio)
return img
return Image.open(bio)


def get_atom_coords_of_bond(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ def _make_grid_from_mol(

xl = list(pad(xl, padding[0])) # Increasing size of x-axis
yl = list(pad(yl, padding[1])) # Increasing size of y-axis
v_map = ValueGrid(xl, yl, grid_resolution[0], grid_resolution[1])
return v_map
return ValueGrid(xl, yl, grid_resolution[0], grid_resolution[1])


def _add_gaussians_for_atoms(
Expand Down Expand Up @@ -537,8 +536,7 @@ def structure_heatmap(
color_limits,
)
figure_bytes = drawer.GetDrawingText()
image = to_png(figure_bytes)
return image
return to_png(figure_bytes)


def structure_heatmap_shap( # pylint: disable=too-many-locals
Expand Down
3 changes: 1 addition & 2 deletions molpipeline/mol2any/mol2bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,4 @@ def transform_single(self, value: Any) -> Any:
Bool representation of the molecule.

"""
pre_value = self.pretransform_single(value)
return self.finalize_single(pre_value)
return self.pretransform_single(value)
Loading
Loading