Skip to content

Commit e6b2f19

Browse files
add buffer and checks on eta
1 parent 81f2cd4 commit e6b2f19

2 files changed

Lines changed: 13 additions & 13 deletions

File tree

pina/solver/physics_informed_solver/rba_pinn.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __init__(
9494
of the residuals. Must be between ``0`` and ``1``.
9595
Default is ``0.999``.
9696
:raises: ValueError if `gamma` is not in the range (0, 1).
97+
:raises: ValueError if `eta` is not greater than 0.
9798
"""
9899
super().__init__(
99100
model=model,
@@ -114,24 +115,20 @@ def __init__(
114115
f"Invalid range: expected 0 < gamma < 1, but got {gamma}"
115116
)
116117

118+
# Validate range for eta
119+
if eta <= 0:
120+
raise ValueError(f"Invalid range: expected eta > 0, but got {eta}")
121+
117122
# Initialize parameters
118123
self.eta = eta
119124
self.gamma = gamma
120125

121126
# Initialize the weight of each point to 0
122-
self.weights = {
123-
cond: torch.zeros((len(data), 1), device=self.device)
124-
for cond, data in self.problem.input_pts.items()
125-
}
126-
127-
def on_train_start(self):
128-
"""
129-
Hook method called at the beginning of training.
130-
"""
131-
device = self.trainer.strategy.root_device
132-
for cond in self.weights:
133-
self.weights[cond] = self.weights[cond].to(device)
134-
return super().on_train_start()
127+
self.weights = {}
128+
for cond, data in self.problem.input_pts.items():
129+
buffer_tensor = torch.zeros((len(data), 1), device=self.device)
130+
self.register_buffer(f"weight_{cond}", buffer_tensor)
131+
self.weights[cond] = getattr(self, f"weight_{cond}")
135132

136133
def training_step(self, batch, batch_idx, **kwargs):
137134
"""

tests/test_solver/test_rba_pinn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def test_constructor(problem, eta, gamma):
4747
with pytest.raises(ValueError):
4848
solver = RBAPINN(model=model, problem=problem, gamma=1.5)
4949

50+
with pytest.raises(ValueError):
51+
solver = RBAPINN(model=model, problem=problem, eta=-0.1)
52+
5053
assert solver.accepted_conditions_types == (
5154
InputTargetCondition,
5255
InputEquationCondition,

0 commit comments

Comments
 (0)