Skip to content

Commit 2671f2d

Browse files
authored
Merge pull request #273 from asogaard/fix-torch-in-data
Fix torch import in graphnet.data
2 parents dd390bd + 3ddce57 commit 2671f2d

8 files changed

Lines changed: 54 additions & 36 deletions

File tree

src/graphnet/data/dataconverter.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,11 @@ def final(f): # Identity decorator
2727
I3TruthExtractor,
2828
)
2929
from graphnet.utilities.filesys import find_i3_files
30-
from graphnet.utilities.logging import LoggerMixin, get_logger
30+
from graphnet.utilities.imports import has_icecube_package
31+
from graphnet.utilities.logging import LoggerMixin
3132

32-
logger = get_logger()
33-
34-
try:
33+
if has_icecube_package():
3534
from icecube import icetray, dataio # pyright: reportMissingImports=false
36-
except ImportError:
37-
logger.warning("icecube package not available.")
3835

3936

4037
SAVE_STRATEGIES = [

src/graphnet/data/extractors/i3extractor.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
11
from abc import ABC, abstractmethod
22
from typing import List
33

4-
from graphnet.utilities.logging import LoggerMixin, get_logger
4+
from graphnet.utilities.imports import has_icecube_package
5+
from graphnet.utilities.logging import LoggerMixin
56

6-
logger = get_logger()
7-
8-
try:
9-
from icecube import (
10-
icetray,
11-
dataio,
12-
) # pyright: reportMissingImports=false
13-
except ImportError:
14-
logger.warning("icecube package not available.")
7+
if has_icecube_package():
8+
from icecube import icetray, dataio # pyright: reportMissingImports=false
159

1610

1711
class I3Extractor(ABC, LoggerMixin):

src/graphnet/data/extractors/i3featureextractor.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
from graphnet.data.extractors.i3extractor import I3Extractor
2-
from graphnet.utilities.logging import get_logger
3-
4-
logger = get_logger()
5-
try:
6-
from icecube import (
7-
dataclasses,
8-
) # pyright: reportMissingImports=false
9-
except ImportError:
10-
logger.warning("icecube package not available.")
2+
from graphnet.utilities.imports import has_icecube_package
3+
4+
if has_icecube_package():
5+
from icecube import dataclasses # pyright: reportMissingImports=false
116

127

138
class I3FeatureExtractor(I3Extractor):

src/graphnet/data/extractors/i3truthextractor.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,14 @@
77
frame_is_montecarlo,
88
frame_is_noise,
99
)
10-
from graphnet.utilities.logging import get_logger
10+
from graphnet.utilities.imports import has_icecube_package
1111

12-
logger = get_logger()
13-
14-
try:
12+
if has_icecube_package():
1513
from icecube import (
1614
dataclasses,
1715
icetray,
1816
phys_services,
1917
) # pyright: reportMissingImports=false
20-
except ImportError:
21-
logger.warning("icecube package not available.")
2218

2319

2420
class I3TruthExtractor(I3Extractor):
@@ -385,5 +381,5 @@ def _find_data_type(self, mc, input_file):
385381
if "L2" in input_file: # not robust
386382
sim_type = "dbang"
387383
if sim_type == "lol":
388-
logger.info("SIM TYPE NOT FOUND!")
384+
self.logger.info("SIM TYPE NOT FOUND!")
389385
return sim_type
Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
1+
from graphnet.utilities.imports import has_torch_package
2+
13
from .parquet_dataconverter import ParquetDataConverter
2-
from .parquet_dataset import ParquetDataset
4+
5+
if has_torch_package():
6+
from .parquet_dataset import ParquetDataset
7+
8+
del has_torch_package
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1+
from graphnet.utilities.imports import has_torch_package
2+
13
from .sqlite_dataconverter import SQLiteDataConverter
2-
from .sqlite_dataset import SQLiteDataset
3-
from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed
44
from .sqlite_utilities import run_sql_code, save_to_sql
5+
6+
if has_torch_package():
7+
from .sqlite_dataset import SQLiteDataset
8+
from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed
9+
10+
del has_torch_package

src/graphnet/utilities/imports.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from functools import wraps
44

5-
from graphnet.utilities.logging import get_logger
5+
from graphnet.utilities.logging import get_logger, warn_once
66

77

88
logger = get_logger()
@@ -15,6 +15,23 @@ def has_icecube_package() -> bool:
1515

1616
return True
1717
except ImportError:
18+
warn_once(
19+
logger,
20+
"`icecube` not available. Some functionality may be missing.",
21+
)
22+
return False
23+
24+
25+
def has_torch_package() -> bool:
26+
"""Check whether the `torch` package is available."""
27+
try:
28+
import torch
29+
30+
return True
31+
except ImportError:
32+
warn_once(
33+
logger, "`torch` not available. Some functionality may be missing."
34+
)
1835
return False
1936

2037

src/graphnet/utilities/logging.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Consistent and configurable logging across the project."""
22

33
from collections import Counter
4+
from functools import lru_cache
45
import re
56
from typing import Optional
67
import colorlog
@@ -53,6 +54,12 @@ def get_formatters() -> Tuple[logging.Formatter, colorlog.ColoredFormatter]:
5354
return basic_formatter, colored_formatter
5455

5556

57+
@lru_cache(1)
58+
def warn_once(logger: logging.Logger, message: str):
59+
"""Print `message` as warning exactly once."""
60+
logger.warn(message)
61+
62+
5663
class RepeatFilter(object):
5764
"""Filter out repeat messages."""
5865

0 commit comments

Comments
 (0)