Skip to content

Commit 35ef191

Browse files
committed
Fix ruff warnings in SOFIE tutorials
In particular, replace `raiseError` which is only available in pytest fixtures.
1 parent 1db9f94 commit 35ef191

2 files changed

Lines changed: 4 additions & 7 deletions

File tree

tutorials/machine_learning/TMVA_SOFIE_Keras.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
import contextlib
1313
import warnings
1414

15+
import numpy as np
1516
import ROOT
17+
from tensorflow.keras.layers import Activation, Dense, Input, Softmax
18+
from tensorflow.keras.models import Model
1619

1720
# Enable ROOT in batch mode (same effect as -nodraw)
1821
ROOT.gROOT.SetBatch(True)
@@ -44,10 +47,6 @@ def expect_warning(category, message):
4447
# Step 1: Create and train a simple Keras model (via embedded Python)
4548
# -----------------------------------------------------------------------------
4649

47-
import numpy as np
48-
from tensorflow.keras.layers import Activation, Dense, Input, Softmax
49-
from tensorflow.keras.models import Model
50-
5150
input = Input(shape=(4,), batch_size=2)
5251
x = Dense(32)(input)
5352
x = Activation("relu")(x)
@@ -81,8 +80,6 @@ def expect_warning(category, message):
8180
# Step 2: Use TMVA::SOFIE to parse the ONNX model
8281
# -----------------------------------------------------------------------------
8382

84-
import ROOT
85-
8683
# Parse the ONNX model
8784

8885
model = ROOT.TMVA.Experimental.SOFIE.PyKeras.Parse("KerasModel.keras")

tutorials/machine_learning/TMVA_SOFIE_ONNX.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def ParseModel(modelFile, verbose=False):
179179

180180
testFailed = abs(y_sofie - y_ort) > 0.01
181181
if np.any(testFailed):
182-
raiseError("Result is different between SOFIE and ONNXRT")
182+
raise RuntimeError("Result is different between SOFIE and ONNXRT")
183183
else:
184184
print("OK")
185185

0 commit comments

Comments
 (0)