Skip to content

Commit 12a00fb

Browse files
Try to parallelize evolve_slice
1 parent c6f8a5f commit 12a00fb

3 files changed

Lines changed: 164 additions & 3 deletions

File tree

Cargo.lock

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pineappl/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ itertools = "0.10.1"
2727
lz4_flex = "0.9.2"
2828
ndarray = { features = ["serde"], version = "0.15.4" }
2929
pineappl-v0 = { package = "pineappl", version = "0.8.2" }
30+
rayon = "1.10"
3031
rustc-hash = "1.1.0"
3132
serde = { features = ["derive"], version = "1.0.130" }
3233
thiserror = "1.0.30"

pineappl/src/evolution.rs

Lines changed: 160 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ use ndarray::{
1515
s, Array1, Array2, Array3, ArrayD, ArrayView1, ArrayView4, ArrayViewD, ArrayViewMutD, Axis,
1616
Ix1, Ix2,
1717
};
18+
use rayon::prelude::*;
1819
use std::iter;
20+
use std::sync::{Arc, Mutex};
1921

2022
/// This structure captures the information needed to create an evolution kernel operator (EKO) for
2123
/// a specific [`Grid`].
@@ -355,7 +357,7 @@ fn ndarray_from_subgrid_orders_slice(
355357
Ok((x1n, (!zero).then_some(array)))
356358
}
357359

358-
pub(crate) fn evolve_slice(
360+
pub(crate) fn evolve_slice_sequential(
359361
grid: &Grid,
360362
operators: &[ArrayView4<f64>],
361363
infos: &[OperatorSliceInfo],
@@ -494,6 +496,163 @@ pub(crate) fn evolve_slice(
494496
))
495497
}
496498

499+
pub(crate) fn evolve_slice(
500+
grid: &Grid,
501+
operators: &[ArrayView4<f64>],
502+
infos: &[OperatorSliceInfo],
503+
scale_values: &[f64],
504+
order_mask: &[bool],
505+
xi: (f64, f64, f64),
506+
alphas_table: &AlphasTable,
507+
) -> Result<(Array3<SubgridEnum>, Vec<Channel>)> {
508+
let gluon_has_pid_zero = gluon_has_pid_zero(grid);
509+
510+
let mut fac1_scales: Vec<_> = infos.iter().map(|info| info.fac1).collect();
511+
fac1_scales.sort_by(f64::total_cmp);
512+
assert!(fac1_scales
513+
.windows(2)
514+
.all(|scales| subgrid::node_value_eq(scales[0], scales[1])));
515+
let fac1 = fac1_scales[0];
516+
517+
assert_eq!(operators.len(), infos.len());
518+
assert_eq!(operators.len(), grid.convolutions().len());
519+
520+
let (pid_indices, pids01): (Vec<_>, Vec<_>) = izip!(0..infos.len(), operators, infos)
521+
.map(|(d, operator, info)| {
522+
pid_slices(operator, info, gluon_has_pid_zero, &|pid1| {
523+
grid.channels()
524+
.iter()
525+
.flat_map(Channel::entry)
526+
.any(|(pids, _)| pids[d] == pid1)
527+
})
528+
})
529+
.collect::<Result<Vec<_>>>()?
530+
.into_iter()
531+
.unzip();
532+
533+
let mut channels0: Vec<_> = pids01
534+
.iter()
535+
.map(|pids| pids.iter().map(|&(pid0, _)| pid0))
536+
.multi_cartesian_product()
537+
.collect();
538+
channels0.sort_unstable();
539+
channels0.dedup();
540+
let channels0 = channels0;
541+
542+
let mut sub_fk_tables = Vec::with_capacity(grid.bwfl().len() * channels0.len());
543+
let mut last_x1 = vec![Vec::new(); infos.len()];
544+
let mut eko_slices = vec![Vec::new(); infos.len()];
545+
let dim: Vec<_> = infos.iter().map(|info| info.x0.len()).collect();
546+
547+
for subgrids_oc in grid.subgrids().axis_iter(Axis(1)) {
548+
// Use Arc<Mutex<>> for thread-safe access to tables
549+
let tables = Arc::new(Mutex::new(vec![
550+
ArrayD::zeros(dim.clone());
551+
channels0.len()
552+
]));
553+
554+
for (subgrids_o, channel1) in subgrids_oc.axis_iter(Axis(1)).zip(grid.channels()) {
555+
let (x1, array) = ndarray_from_subgrid_orders_slice(
556+
grid,
557+
fac1,
558+
grid.kinematics(),
559+
&subgrids_o,
560+
grid.orders(),
561+
order_mask,
562+
xi,
563+
alphas_table,
564+
)?;
565+
566+
let Some(array) = array else {
567+
continue;
568+
};
569+
570+
// Sequential lazy loading update (must remain sequential)
571+
for (last_x1, x1, pid_indices, slices, operator, info) in izip!(
572+
&mut last_x1,
573+
x1,
574+
&pid_indices,
575+
&mut eko_slices,
576+
operators,
577+
infos
578+
) {
579+
if (last_x1.len() != x1.len())
580+
|| last_x1
581+
.iter()
582+
.zip(x1.iter())
583+
.any(|(&lhs, &rhs)| !subgrid::node_value_eq(lhs, rhs))
584+
{
585+
*slices = operator_slices(operator, info, pid_indices, &x1)?;
586+
*last_x1 = x1;
587+
}
588+
}
589+
590+
// Collect channel entry operations for parallelization
591+
let channel_operations: Vec<_> = channel1
592+
.entry()
593+
.iter()
594+
.enumerate()
595+
.map(|(entry_idx, (pids1, factor))| {
596+
let table_ops: Vec<_> = channels0
597+
.iter()
598+
.enumerate()
599+
.filter_map(|(table_idx, pids0)| {
600+
izip!(pids0, pids1, &pids01, &eko_slices)
601+
.map(|(&pid0, &pid1, pids, slices)| {
602+
pids.iter().zip(slices).find_map(|(&(p0, p1), op)| {
603+
((p0 == pid0) && (p1 == pid1)).then_some(op)
604+
})
605+
})
606+
.collect::<Option<Vec<_>>>()
607+
.map(|ops| (table_idx, ops))
608+
})
609+
.collect();
610+
(entry_idx, *factor, table_ops)
611+
})
612+
.collect();
613+
614+
// Parallelize across channel operations
615+
channel_operations
616+
.into_par_iter()
617+
.for_each(|(_, factor, table_ops)| {
618+
table_ops.into_par_iter().for_each(|(table_idx, ops)| {
619+
// Create temporary result
620+
let mut temp_table = ArrayD::zeros(dim.clone());
621+
general_tensor_mul(factor, array.view(), &ops, temp_table.view_mut());
622+
623+
// Add to main table under lock
624+
let mut tables_guard = tables.lock().unwrap();
625+
tables_guard[table_idx] += &temp_table;
626+
});
627+
});
628+
}
629+
630+
let mut node_values = vec![scale_values.to_vec()];
631+
for info in infos {
632+
node_values.push(info.x0.clone());
633+
}
634+
635+
let tables = Arc::try_unwrap(tables).unwrap().into_inner().unwrap();
636+
sub_fk_tables.extend(tables.into_iter().map(|table| {
637+
ImportSubgridV1::new(
638+
PackedArray::from(table.insert_axis(Axis(0)).view()),
639+
node_values.clone(),
640+
)
641+
.into()
642+
}));
643+
}
644+
645+
Ok((
646+
Array1::from_iter(sub_fk_tables)
647+
.into_shape((1, grid.bwfl().len(), channels0.len()))
648+
.unwrap(),
649+
channels0
650+
.into_iter()
651+
.map(|c| Channel::new(vec![(c, 1.0)]))
652+
.collect(),
653+
))
654+
}
655+
497656
fn general_tensor_mul(
498657
factor: f64,
499658
array: ArrayViewD<f64>,

0 commit comments

Comments
 (0)