1010from collections import deque
1111
1212import numpy as np
13- import scipy
14- import scipy .stats
1513from loguru import logger
1614
1715from vimms .ChemicalSamplers import (
2624 PROTON_MASS ,
2725 POSITIVE ,
2826 NEGATIVE ,
29- C12_PROPORTION ,
3027 C13_MZ_DIFF ,
31- C ,
3228 MONO ,
33- C13 ,
3429 load_obj ,
3530 ADDUCT_NAMES_POS ,
3631 ADDUCT_NAMES_NEG ,
32+ ADDUCT_PRIOR_POS ,
33+ ADDUCT_PRIOR_NEG ,
34+ NATURAL_ISOTOPES ,
3735)
3836from vimms .Noise import GaussianPeakNoise
3937from vimms .Roi import make_roi , RoiBuilderParams
@@ -70,15 +68,21 @@ class Isotopes:
7068 A class to represent an isotope of a chemical
7169 """
7270
73- def __init__ (self , formula ):
71+ def __init__ (self , formula , min_prob = 1e-12 , max_peaks = 20 , max_states = 4000 , mass_precision = 8 ):
7472 """
7573 Create an Isotope object
7674 Args:
7775 formula: the formula for the given isotope
7876 """
7977 self .formula = formula
78+ self .min_prob = min_prob
79+ self .max_peaks = max_peaks
80+ self .max_states = max_states
81+ self .mass_precision = mass_precision
8082
81- def get_isotopes (self , total_proportion ):
83+ def get_isotopes (
84+ self , total_proportion , min_prob = None , max_peaks = None , max_states = None , mass_precision = None
85+ ):
8286 """
8387 Gets the isotope total proportion
8488
@@ -87,93 +91,151 @@ def get_isotopes(self, total_proportion):
8791
8892 Returns: the computed isotope total proportion
8993
90- TODO: Add functionality for elements other than Carbon
9194 """
92- peaks = [() for i in range (len (self ._get_isotope_proportions (total_proportion )))]
93- for i in range (len (peaks )):
94- peaks [i ] += (self ._get_isotope_mz (self ._get_isotope_names (i )),)
95- peaks [i ] += (self ._get_isotope_proportions (total_proportion )[i ],)
96- peaks [i ] += (self ._get_isotope_names (i ),)
95+ peaks = []
96+ distributions = self ._get_isotope_distribution (
97+ total_proportion = total_proportion ,
98+ min_prob = self .min_prob if min_prob is None else min_prob ,
99+ max_peaks = self .max_peaks if max_peaks is None else max_peaks ,
100+ max_states = self .max_states if max_states is None else max_states ,
101+ mass_precision = self .mass_precision if mass_precision is None else mass_precision ,
102+ )
103+ base_mz = self .formula ._get_mz ()
104+ for idx , (mass_shift , proportion ) in enumerate (distributions ):
105+ name = MONO if idx == 0 else f"M+{ idx } "
106+ peaks .append ((base_mz + mass_shift , proportion , name ))
97107 return peaks
98108
99- def _get_isotope_proportions (self , total_proportion ):
100- """
101- Get isotope proportion by sampling from a binomial pmf
102-
103- Args:
104- total_proportion: the total proportion to compute
105-
106- Returns: the computed isotope total proportion
107-
108- """
109- proportions = []
110- while sum (proportions ) < total_proportion :
111- proportions .extend (
112- [
113- scipy .stats .binom .pmf (
114- len (proportions ), self .formula ._get_n_element (C ), 1 - C12_PROPORTION
115- )
116- ]
109+ def _get_isotope_distribution (
110+ self , total_proportion , min_prob = 1e-12 , max_peaks = 20 , max_states = 4000 , mass_precision = 8
111+ ):
112+ distribution = [(0.0 , 1.0 )]
113+ for element , count in self .formula .atoms .items ():
114+ if count <= 0 :
115+ continue
116+ isotopes = NATURAL_ISOTOPES .get (element )
117+ if not isotopes or len (isotopes ) == 1 :
118+ continue
119+ mono_mass = isotopes [0 ][0 ]
120+ base_distribution = [(mass - mono_mass , abundance ) for mass , abundance in isotopes ]
121+ element_distribution = self ._power_distribution (
122+ base_distribution ,
123+ count ,
124+ min_prob = min_prob ,
125+ max_states = max_states ,
126+ mass_precision = mass_precision ,
127+ )
128+ distribution = self ._convolve_distributions (
129+ distribution ,
130+ element_distribution ,
131+ min_prob = min_prob ,
132+ max_states = max_states ,
133+ mass_precision = mass_precision ,
117134 )
118- normalised_proportions = [
119- proportions [i ] / sum (proportions ) for i in range (len (proportions ))
120- ]
121- return normalised_proportions
122-
123- def _get_isotope_names (self , isotope_number ):
124- """
125- Get the isotope name given the number, e.g. 0 is the monoisotope
126- Args:
127- isotope_number: the isotope number
128-
129- Returns: the isotope name
130-
131- """
132- if isotope_number == 0 :
133- return MONO
134- else :
135- return str (isotope_number ) + C13
136-
137- def _get_isotope_mz (self , isotope ):
138- """
139- Get the isotope m/z value
140- Args:
141- isotope: the isotope name
142-
143- Returns: the isotope m/z value
144135
145- """
146- if isotope == MONO :
147- return self .formula ._get_mz ()
148- elif isotope [- 3 :] == C13 :
149- return self .formula ._get_mz () + float (isotope .split (C13 )[0 ]) * C13_MZ_DIFF
150- else :
151- return None
136+ distribution = [(shift , prob ) for shift , prob in distribution if prob >= min_prob ]
137+ distribution .sort (key = lambda x : x [0 ])
138+
139+ selected = []
140+ cumulative = 0.0
141+ for mass_shift , prob in distribution :
142+ selected .append ((mass_shift , prob ))
143+ cumulative += prob
144+ if cumulative >= total_proportion or len (selected ) >= max_peaks :
145+ break
146+
147+ total = sum (prob for _ , prob in selected )
148+ if total == 0 :
149+ return [(0.0 , 1.0 )]
150+ return [(shift , prob / total ) for shift , prob in selected ]
151+
152+ def _power_distribution (self , base_distribution , count , min_prob , max_states , mass_precision ):
153+ if count == 1 :
154+ return base_distribution
155+ result = [(0.0 , 1.0 )]
156+ power = base_distribution
157+ remaining = count
158+ while remaining > 0 :
159+ if remaining % 2 == 1 :
160+ result = self ._convolve_distributions (
161+ result ,
162+ power ,
163+ min_prob = min_prob ,
164+ max_states = max_states ,
165+ mass_precision = mass_precision ,
166+ )
167+ remaining //= 2
168+ if remaining :
169+ power = self ._convolve_distributions (
170+ power ,
171+ power ,
172+ min_prob = min_prob ,
173+ max_states = max_states ,
174+ mass_precision = mass_precision ,
175+ )
176+ return result
177+
178+ def _convolve_distributions (self , left , right , min_prob , max_states , mass_precision ):
179+ new_distribution = {}
180+ for left_shift , left_prob in left :
181+ for right_shift , right_prob in right :
182+ prob = left_prob * right_prob
183+ if prob < min_prob :
184+ continue
185+ shift = left_shift + right_shift
186+ key = round (shift , mass_precision )
187+ new_distribution [key ] = new_distribution .get (key , 0.0 ) + prob
188+ if not new_distribution :
189+ return []
190+ distribution = list (new_distribution .items ())
191+ if len (distribution ) > max_states :
192+ distribution .sort (key = lambda x : x [1 ], reverse = True )
193+ distribution = distribution [:max_states ]
194+ return distribution
152195
153196
154197class Adducts :
155198 """
156199 A class to represent an adduct of a chemical
157200 """
158201
159- def __init__ (self , formula , adduct_proportion_cutoff = 0.05 , adduct_prior_dict = None ):
202+ def __init__ (
203+ self ,
204+ formula ,
205+ adduct_proportion_cutoff = 0.05 ,
206+ adduct_prior_dict = None ,
207+ adduct_profile = None ,
208+ adduct_concentration = 15.0 ,
209+ ):
160210 """
161211 Create an Adduct class
162212
163213 Args:
164214 formula: the formula of this adduct
165215 adduct_proportion_cutoff: proportion cut-off of the adduct
166- adduct_prior_dict: custom adduct dictionary, if any
216+ adduct_prior_dict: custom adduct dictionary or callable, if any
217+ adduct_profile: preset profile name or dict of adduct priors
218+ adduct_concentration: dirichlet concentration for adduct sampling
167219 """
220+ if callable (adduct_prior_dict ):
221+ adduct_prior_dict = adduct_prior_dict (formula )
222+
223+ if adduct_prior_dict is None and adduct_profile is not None :
224+ from vimms .Common import ADDUCT_PROFILE_PRESETS
225+
226+ if isinstance (adduct_profile , str ):
227+ adduct_prior_dict = ADDUCT_PROFILE_PRESETS .get (adduct_profile )
228+ if adduct_prior_dict is None :
229+ raise ValueError (f"Unknown adduct profile '{ adduct_profile } '" )
230+ else :
231+ adduct_prior_dict = adduct_profile
232+
168233 if adduct_prior_dict is None :
169234 self .adduct_names = {POSITIVE : ADDUCT_NAMES_POS , NEGATIVE : ADDUCT_NAMES_NEG }
170235 self .adduct_prior = {
171- POSITIVE : np .ones ( len ( self . adduct_names [ POSITIVE ])) * 0.1 ,
172- NEGATIVE : np .ones ( len ( self . adduct_names [ NEGATIVE ])) * 0.1 ,
236+ POSITIVE : np .array ([ ADDUCT_PRIOR_POS . get ( name , 0.05 ) for name in ADDUCT_NAMES_POS ]) ,
237+ NEGATIVE : np .array ([ ADDUCT_PRIOR_NEG . get ( name , 0.05 ) for name in ADDUCT_NAMES_NEG ]) ,
173238 }
174- # give more weight to the first one, i.e. M+H
175- self .adduct_prior [POSITIVE ][0 ] = 1.0
176- self .adduct_prior [NEGATIVE ][0 ] = 1.0
177239 else :
178240 assert POSITIVE in adduct_prior_dict or NEGATIVE in adduct_prior_dict
179241 self .adduct_names = {k : list (adduct_prior_dict [k ].keys ()) for k in adduct_prior_dict }
@@ -182,6 +244,7 @@ def __init__(self, formula, adduct_proportion_cutoff=0.05, adduct_prior_dict=Non
182244 }
183245 self .formula = formula
184246 self .adduct_proportion_cutoff = adduct_proportion_cutoff
247+ self .adduct_concentration = adduct_concentration
185248
186249 def get_adducts (self ):
187250 """
@@ -204,15 +267,17 @@ def _get_adduct_proportions(self):
204267 Returns: adduct proportion after sampling
205268
206269 """
207- # TODO: replace this with something proper
208270 proportions = {}
209271 for k in self .adduct_prior :
210- proportions [ k ] = np . random . dirichlet ( self .adduct_prior [k ])
211- while max ( proportions [ k ]) < 0.2 :
212- proportions [k ] = np .random .dirichlet (self . adduct_prior [ k ] )
272+ alpha = self .adduct_prior [k ] * self . adduct_concentration
273+ alpha = np . where ( alpha > 0 , alpha , 0.001 )
274+ proportions [k ] = np .random .dirichlet (alpha )
213275 proportions [k ][np .where (proportions [k ] < self .adduct_proportion_cutoff )] = 0
214- proportions [k ] = proportions [k ] / max (proportions [k ])
215- proportions [k ].tolist ()
276+ if proportions [k ].sum () == 0 :
277+ proportions [k ] = np .zeros_like (proportions [k ])
278+ proportions [k ][np .argmax (alpha )] = 1.0
279+ else :
280+ proportions [k ] = proportions [k ] / proportions [k ].sum ()
216281 assert len (proportions [k ]) == len (self .adduct_names [k ])
217282 return proportions
218283
@@ -625,6 +690,8 @@ def __init__(
625690 ms2_sampler = UniformMS2Sampler (),
626691 adduct_proportion_cutoff = 0.05 ,
627692 adduct_prior_dict = None ,
693+ adduct_profile = None ,
694+ adduct_concentration = 15.0 ,
628695 ):
629696 """
630697 Create a mixture of [vimms.Chemicals.KnownChemical][] objects.
@@ -642,13 +709,17 @@ def __init__(
642709 fragmentation spectra.
643710 adduct_proportion_cutoff: proportion of adduct cut-off
644711 adduct_prior_dict: custom adduct dictionary
712+ adduct_profile: preset name or dict of adduct priors
713+ adduct_concentration: dirichlet concentration for adduct sampling
645714 """
646715 self .formula_sampler = formula_sampler
647716 self .rt_and_intensity_sampler = rt_and_intensity_sampler
648717 self .chromatogram_sampler = chromatogram_sampler
649718 self .ms2_sampler = ms2_sampler
650719 self .adduct_proportion_cutoff = adduct_proportion_cutoff
651720 self .adduct_prior_dict = adduct_prior_dict
721+ self .adduct_profile = adduct_profile
722+ self .adduct_concentration = adduct_concentration
652723
653724 # if self.database is not None:
654725 # logger.debug('Sorting database compounds by masses')
@@ -691,6 +762,8 @@ def sample(self, n_chemicals, ms_levels, include_adducts_isotopes=True):
691762 formula ,
692763 self .adduct_proportion_cutoff ,
693764 adduct_prior_dict = self .adduct_prior_dict ,
765+ adduct_profile = self .adduct_profile ,
766+ adduct_concentration = self .adduct_concentration ,
694767 )
695768
696769 chemicals .append (
0 commit comments