Skip to content

Commit 9f975f6

Browse files
committed
create a ReplicaSettings dataclass to hold any replica-dependent configuration
1 parent 3d653fd commit 9f975f6

8 files changed

Lines changed: 169 additions & 63 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ instance/
352352

353353
# Sphinx documentation
354354
docs/_build/
355+
doc/sphinx/source/theories_central.csv
355356

356357
# PyBuilder
357358
target/

n3fit/src/n3fit/model_gen.py

Lines changed: 152 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Contains:
55
# observable_generator:
66
Generates the output layers as functions
7-
# pdfNN_layer_generator:
7+
# _pdfNN_layer_generator:
88
Generates the PDF NN layer to be fitted
99
1010
@@ -329,6 +329,56 @@ def observable_generator(
329329
return layer_info
330330

331331

332+
@dataclass
333+
class _ReplicaSettings:
334+
"""Dataclass which holds all the necessary replica-dependent information of a PDF.
335+
336+
Parameters
337+
----------
338+
seed: int
339+
seed for the initialization of the neural network
340+
nodes: list[int]
341+
nodes of each of the layers, starting at the first hidden layer
342+
activations: list[str]
343+
list of activation functions, should be of equal length as nodes
344+
architecture: str
345+
select the architecture of the neural network used for the replica,
346+
e.g. ``dense`` or ``dense_per_flavour``
347+
initializer: str
348+
initializer to be used for this replica
349+
dropout_rate: float
350+
rate of dropout for each layer
351+
regularizer: str
352+
name of the regularizer to use for this replica (if any)
353+
regularizer_args: dict
354+
options to pass down to the regularizer (if any)
355+
"""
356+
357+
seed: int
358+
nodes: list[int]
359+
activations: list[str]
360+
architecture: str
361+
initializer: str
362+
dropout_rate: float = 0.0
363+
regularizer: str = None
364+
regularizer_args: dict = None
365+
366+
def __post_init__(self):
367+
"""Apply checks to the input"""
368+
# TODO: this check cannot be enabled yet
369+
# if len(self.nodes) != len(self.activations):
370+
# raise ValueError(
371+
# f"nodes and activations do not match ({self.nodes} vs {self.activations}"
372+
# )
373+
if self.regularizer_args is not None and self.regularizer is None:
374+
raise ValueError(
375+
"Regularizer arguments have been provided but not regularizer is selected"
376+
)
377+
378+
379+
# TODO:
380+
# 1. Decide whether the sampling occurs in this function or before entering here
381+
# at the moment, it occurs here.
332382
def generate_pdf_model(
333383
nodes: list[int] = None,
334384
activations: list[str] = None,
@@ -337,7 +387,7 @@ def generate_pdf_model(
337387
flav_info: dict = None,
338388
fitbasis: str = "NN31IC",
339389
out: int = 14,
340-
seed: int = None,
390+
seed_list: list[int] = None,
341391
dropout: float = 0.0,
342392
regularizer: str = None,
343393
regularizer_args: dict = None,
@@ -347,66 +397,119 @@ def generate_pdf_model(
347397
photons: Photon = None,
348398
):
349399
"""
350-
Wrapper around pdfNN_layer_generator to allow the generation of single replica models.
400+
Generation of the full PDF model which will be used to determine the full PDF.
401+
The full PDF model can have any number of replicas, which can be trained in parallel,
402+
the limitations of the determination means that there are certain traits that all replicas
403+
must share, while others are free per-PDF.
404+
405+
In its most general form, the output of this function is a :py:class:`n3fit.backend.MetaModel`
406+
with the following architecture:
407+
408+
<input layer>
409+
in the standard PDF fit this includes only the (x) grid of the NN
410+
411+
[ list of a separate architecture per replica ]
412+
which can be, but is not necessarily, equal for all replicas
413+
414+
<preprocessing factors>
415+
postprocessing of the network output by a variation x^{alpha}*(1-x)^{beta}
416+
417+
<normalization>
418+
physical sum rules, requires an integral over the PDF
419+
420+
<rotation to FK-basis>
421+
regardless of the physical basis in which the PDF and preprocessing factors are applied
422+
the output is rotated to the 14-flavour general basis used in FkTables following
423+
PineAPPL's convention
424+
425+
[<output layer>]
426+
14 flavours per value of x per replica
427+
note that, depending on the fit basis (and fitting scale)
428+
the output of the PDF will contain repeated values
429+
430+
431+
This function defines how the PDFs will be generated.
432+
In the case of identical PDF models (``identical_models = True``, default) the same
433+
settings will be used for all replicas.
434+
Otherwise, the sampling routines will be used.
435+
351436
352437
Parameters:
353438
-----------
354-
see model_gen.pdfNN_layer_generator
439+
<TODO> to be filled
355440
356441
Returns
357442
-------
358443
pdf_model: MetaModel
359-
pdf model, with `single_replica_generator` attached in a list as an attribute
444+
pdf model, with `single_replica_generator` attached as an attribute
360445
"""
361-
joint_args = {
362-
"nodes": nodes,
363-
"activations": activations,
364-
"initializer_name": initializer_name,
365-
"layer_type": layer_type,
446+
if len(seed_list) != num_replicas:
447+
# TODO: remove this error, remove the num_replicas argument
448+
raise ValueError("This should not happen")
449+
450+
num_replicas = len(seed_list)
451+
452+
# Separate the settings which may be different for each replica
453+
# from those that are guaranteed to be equal for all replicas
454+
455+
all_replicas = []
456+
for seed in seed_list:
457+
tmp = _ReplicaSettings(
458+
seed=seed,
459+
nodes=nodes,
460+
activations=activations,
461+
initializer=initializer_name,
462+
architecture=layer_type,
463+
dropout_rate=dropout,
464+
regularizer=regularizer,
465+
regularizer_args=regularizer_args,
466+
)
467+
all_replicas.append(tmp)
468+
469+
shared_config = {
366470
"flav_info": flav_info,
367471
"fitbasis": fitbasis,
368472
"out": out,
369-
"dropout": dropout,
370-
"regularizer": regularizer,
371-
"regularizer_args": regularizer_args,
372473
"impose_sumrule": impose_sumrule,
373474
"scaler": scaler,
475+
"photons": photons,
374476
}
375477

376-
pdf_model = pdfNN_layer_generator(
377-
**joint_args, seed=seed, num_replicas=num_replicas, photons=photons
378-
)
478+
pdf_model = _pdfNN_layer_generator(all_replicas, **shared_config)
379479

380480
# Note that the photons are passed unchanged to the single replica generator
381481
# computing the photon requires running fiatlux which takes 30' per replica
382482
# and so at the moment parallel photons are disabled with a check in checks.py
383483
# In order to enable it `single_replica_generator` must take the index of the replica
384484
# to select the appropriate photon as all of them will be computed and fixed before the fit
385485

386-
# this is necessary to be able to convert back to single replica models after training
387-
single_replica_generator = lambda: pdfNN_layer_generator(
388-
**joint_args, seed=0, num_replicas=1, photons=photons, replica_axis=False
389-
)
486+
def single_replica_generator(replica_idx=0):
487+
"""Generate one single replica from the entire batch.
488+
The selected index is relative to the batch, not the entire PDF determination.
489+
490+
This function is necessary to separate all the different models after training.
491+
"""
492+
settings = all_replicas[replica_idx]
493+
# TODO:
494+
# In principle we want to recover the initial replica exactly,
495+
# however, for the regression tests to pass
496+
# _in the polarized case and only in the polarized case_ this line is necessary
497+
# it most likely has to do with numerical precision, but panicking might be in order
498+
settings.seed = 0
499+
return _pdfNN_layer_generator([settings], **shared_config, replica_axis=False)
500+
390501
pdf_model.single_replica_generator = single_replica_generator
391502

392503
return pdf_model
393504

394505

395-
def pdfNN_layer_generator(
396-
nodes: list[int] = None,
397-
activations: list[str] = None,
398-
initializer_name: str = "glorot_normal",
399-
layer_type: str = "dense",
506+
def _pdfNN_layer_generator(
507+
replicas_settings: list[_ReplicaSettings],
400508
flav_info: dict = None,
401509
fitbasis: str = "NN31IC",
402510
out: int = 14,
403-
seed: int = None,
404-
dropout: float = 0.0,
405-
regularizer: str = None,
406-
regularizer_args: dict = None,
407511
impose_sumrule: str = None,
408512
scaler: Callable = None,
409-
num_replicas: int = 1,
410513
photons: Photon = None,
411514
replica_axis: bool = True,
412515
): # pylint: disable=too-many-locals
@@ -465,14 +568,16 @@ def pdfNN_layer_generator(
465568
466569
>>> import numpy as np
467570
>>> from n3fit.vpinterface import N3PDF
468-
>>> from n3fit.model_gen import pdfNN_layer_generator
571+
>>> from n3fit.model_gen import _pdfNN_layer_generator, _ReplicaSettings
469572
>>> from validphys.pdfgrids import xplotting_grid
470573
>>> fake_fl = [{'fl' : i, 'largex' : [0,1], 'smallx': [1,2]} for i in ['u', 'ubar', 'd', 'dbar', 'c', 'cbar', 's', 'sbar']]
471574
>>> fake_x = np.linspace(1e-3,0.8,3)
472-
>>> pdf_model = pdfNN_layer_generator(nodes=[8], activations=['linear'], seed=[2,3], flav_info=fake_fl, num_replicas=2)
575+
>>> pdf_model = _pdfNN_layer_generator([_ReplicaSettings(seed=s, nodes=[8], activations=['linear'], architecture='dense', initializer='glorot_normal') for s in [2,3]], flav_info=fake_fl)
473576
474577
Parameters
475578
----------
579+
seed: list(int)
580+
seed for the initialization of the Neural Network
476581
nodes: list(int)
477582
list of the number of nodes per layer of the PDF NN. Default: [15,8]
478583
activation: list
@@ -488,8 +593,6 @@ def pdfNN_layer_generator(
488593
to be used by Preprocessing
489594
out: int
490595
number of output flavours of the model (default 14)
491-
seed: list(int)
492-
seed to initialize the NN
493596
dropout: float
494597
rate of dropout layer by layer
495598
impose_sumrule: str
@@ -513,22 +616,23 @@ def pdfNN_layer_generator(
513616
pdf_model: n3fit.backends.MetaModel
514617
a model f(x) = y where x is a tensor (1, xgrid, 1) and y a tensor (1, replicas, xgrid, out)
515618
"""
516-
# Parse the input configuration
517-
if seed is None:
518-
seed = num_replicas * [None]
519-
elif isinstance(seed, int):
520-
seed = num_replicas * [seed]
521-
522-
if nodes is None:
523-
nodes = [15, 8]
619+
# TODO: at the moment nothing changes, just the signature of the function
620+
seed = [i.seed for i in replicas_settings]
621+
nodes = replicas_settings[0].nodes
622+
activations = replicas_settings[0].activations
623+
layer_type = replicas_settings[0].architecture
624+
initializer_name = replicas_settings[0].initializer
625+
num_replicas = len(replicas_settings)
626+
dropout = replicas_settings[0].dropout_rate
627+
regularizer = replicas_settings[0].regularizer
628+
regularizer_args = replicas_settings[0].regularizer_args
629+
524630
ln = len(nodes)
525631

526632
if impose_sumrule is None:
527633
impose_sumrule = "All"
528634

529-
if activations is None:
530-
activations = ["tanh", "linear"]
531-
elif callable(activations):
635+
if callable(activations):
532636
# hyperopt passes down a function to generate dynamically the list of
533637
# activations functions
534638
activations = activations(ln)
@@ -608,7 +712,7 @@ def pdfNN_layer_generator(
608712
large_x=not subtract_one,
609713
)
610714

611-
nn_replicas = generate_nn(
715+
nn_replicas = _generate_nn(
612716
layer_type=layer_type,
613717
nodes_in=nn_input_dimensions,
614718
nodes=nodes,
@@ -691,14 +795,14 @@ def compute_unnormalized_pdf(x):
691795
return pdf_model
692796

693797

694-
def generate_nn(
798+
def _generate_nn(
695799
layer_type: str,
696800
nodes_in: int,
697801
nodes: list[int],
698802
activations: list[str],
699803
initializer_name: str,
700-
replica_seeds: list[int],
701804
dropout: float,
805+
replica_seeds: list[int],
702806
regularizer: str,
703807
regularizer_args: dict,
704808
last_layer_nodes: int,

n3fit/src/n3fit/model_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ def _generate_pdf(
705705
layer_type=layer_type,
706706
flav_info=self.flavinfo,
707707
fitbasis=self.fitbasis,
708-
seed=seed,
708+
seed_list=seed,
709709
initializer_name=initializer,
710710
dropout=dropout,
711711
regularizer=regularizer,

n3fit/src/n3fit/tests/test_hyperopt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def generate_pdf(seed, num_replicas):
2828
pdf_model = generate_pdf_model(
2929
nodes=[8],
3030
activations=["linear"],
31-
seed=seed,
31+
seed_list=seed,
3232
num_replicas=num_replicas,
3333
flav_info=fake_fl,
3434
fitbasis="FLAVOUR",

n3fit/src/n3fit/tests/test_modelgen.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""
2-
Test for the model generation
2+
Test for the model generation
33
4-
These tests check that the generated NN are as expected
5-
It checks that both the number of layers and the shape
6-
of the weights of the layers are what is expected
4+
These tests check that the generated NN are as expected
5+
It checks that both the number of layers and the shape
6+
of the weights of the layers are what is expected
77
"""
88

99
from n3fit.backends import NN_PREFIX
10-
from n3fit.model_gen import generate_nn
10+
from n3fit.model_gen import _generate_nn
1111

1212
INSIZE = 16
1313
OUT_SIZES = (4, 3)
@@ -27,7 +27,7 @@
2727

2828

2929
def test_generate_dense_network():
30-
nn = generate_nn("dense", **COMMON_ARGS)
30+
nn = _generate_nn("dense", **COMMON_ARGS)
3131

3232
# The number of layers should be input layer + len(OUT_SIZES)
3333
assert len(nn.layers) == len(OUT_SIZES) + 1
@@ -38,7 +38,7 @@ def test_generate_dense_network():
3838

3939

4040
def test_generate_dense_per_flavour_network():
41-
nn = generate_nn("dense_per_flavour", **COMMON_ARGS).get_layer(f"{NN_PREFIX}_0")
41+
nn = _generate_nn("dense_per_flavour", **COMMON_ARGS).get_layer(f"{NN_PREFIX}_0")
4242

4343
# The number of layers should be input + BASIS_SIZE*len(OUT_SIZES) + concatenate
4444
assert len(nn.layers) == BASIS_SIZE * len(OUT_SIZES) + 2

n3fit/src/n3fit/tests/test_multireplica.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_replica_split():
1414
pdf_model = generate_pdf_model(
1515
nodes=[8],
1616
activations=["linear"],
17-
seed=34,
17+
seed_list=[34] * num_replicas,
1818
flav_info=fake_fl,
1919
fitbasis="FLAVOUR",
2020
num_replicas=num_replicas,

0 commit comments

Comments
 (0)