-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathlasso_nmr_quantification.py
More file actions
325 lines (258 loc) · 11.8 KB
/
lasso_nmr_quantification.py
File metadata and controls
325 lines (258 loc) · 11.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
#!/usr/bin/env python3
"""
LASSO-based NMR Mixture Quantification with Shift-Tolerant Dictionary
"""
import sys
sys.path.insert(0, '.')
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import Lasso, LassoCV
from sklearn.metrics import r2_score
from quantify_all_metabolites_v3_ref20 import (
read_and_process, find_tsp_peak, integrate_peak, get_metabolite_info_v3
)
class LassoNMRQuantifier:
    """LASSO-based quantification of metabolites in NMR mixture spectra.

    Reference spectra (file ``20.dx`` in each metabolite folder) and mixture
    spectra are processed identically:
    - the ppm axis is shifted so the position returned by ``find_tsp_peak``
      sits at 0;
    - the intensities are divided by ``integrate_peak`` over the -0.2..0.2 ppm
      window;
    - the result is linearly resampled onto ``common_ppm``.

    The mixture is then fitted as a non-negative, sparse (LASSO) combination
    of the reference columns. Metabolites named in ``shift_sensitive`` also
    get ppm-shifted copies of their reference, so that small peak
    displacements can be absorbed.
    """

    def __init__(self, reference_base_dir="raw_data/Reference_Raw_Date_JCAMP-DX"):
        """Load metabolite metadata and every available reference spectrum.

        Parameters:
            reference_base_dir: directory holding one sub-folder per
                metabolite (sub-folder name from ``info['folder']``).
        """
        self.reference_dir = reference_base_dir
        self.metabolite_info = get_metabolite_info_v3()
        # met_name -> {'spectrum', 'concentration', 'shift_sensitive'}
        self.pure_spectra = {}
        self.common_ppm = np.linspace(0.5, 5.5, 1000)  # Common ppm grid
        # Metabolites known to have concentration-dependent shifts
        self.shift_sensitive = ['Asparagine', 'Aspartate', 'Glutamate']
        self._load_references()

    @staticmethod
    def _resample(grid, ppm, values):
        """Linearly interpolate ``values`` sampled at ``ppm`` onto ``grid``.

        ``np.interp`` silently returns nonsense when its sample positions are
        not increasing. The ppm axis coming from ``read_and_process`` may be
        descending (not visible from here), so it is sorted first. Sorting is
        a no-op for an already ascending axis. Grid points outside the sampled
        range take the nearest edge value, as ``np.interp`` does.
        """
        ppm = np.asarray(ppm, dtype=float)
        values = np.asarray(values, dtype=float)
        order = np.argsort(ppm, kind="stable")
        return np.interp(grid, ppm[order], values[order])

    def _load_normalized(self, path):
        """Read ``path``, TSP-reference it, scale by TSP area and resample.

        Returns:
            (spectrum_on_common_ppm, tsp_position, tsp_area)
        Raises:
            ValueError: if the TSP integral is not strictly positive. Dividing
                by it would give inf or a sign-flipped spectrum.
        """
        ppm, spec = read_and_process(path)
        tsp = find_tsp_peak(ppm, spec)
        ppm_corr = ppm - tsp
        tsp_area = integrate_peak(ppm_corr, spec, (-0.2, 0.2))
        if not tsp_area > 0:  # also rejects NaN
            raise ValueError(f"Non-positive TSP area ({tsp_area}) in {path}")
        spectrum = self._resample(self.common_ppm, ppm_corr, spec / tsp_area)
        return spectrum, tsp, tsp_area

    def _load_references(self):
        """Load and preprocess all reference spectra (File 20).

        Missing files and files that fail to process are reported and
        skipped. ``pure_spectra`` may therefore hold fewer entries than
        ``metabolite_info``.
        """
        print("Loading reference spectra...")
        for met_name, info in self.metabolite_info.items():
            ref_file = os.path.join(self.reference_dir, info['folder'], "20.dx")
            if not os.path.exists(ref_file):
                print(f" Warning: {ref_file} not found")
                continue
            try:
                spectrum, _, _ = self._load_normalized(ref_file)
                self.pure_spectra[met_name] = {
                    'spectrum': spectrum,
                    'concentration': info['files'][20],  # Reference concentration
                    'shift_sensitive': met_name in self.shift_sensitive,
                }
            except Exception as e:
                # Best effort: one bad reference must not abort the whole load.
                print(f" Error loading {met_name}: {e}")
        print(f" Loaded {len(self.pure_spectra)} reference spectra")

    def build_dictionary(self, include_shifts=True, shift_range=(-0.15, 0.15), n_shifts=5,
                         normalize=True, return_scales=False):
        """Build the LASSO design matrix from the loaded reference spectra.

        Shift-sensitive metabolites get ``n_shifts`` copies of their
        reference, shifted evenly across ``shift_range`` (ppm), when
        ``include_shifts`` is true. Every other metabolite gets one unshifted
        column.

        Parameters:
            normalize: scale each column to unit L2 norm, so the LASSO penalty
                treats columns evenly.
            return_scales: also return the per-column divisor. A fitted
                coefficient ``c`` on column ``k`` corresponds to
                ``c / scales[k]`` times the TSP-normalized reference spectrum.
        Returns:
            ``(X, labels)``, or ``(X, labels, scales)`` if ``return_scales``.
            ``X`` has shape ``(len(common_ppm), n_columns)``. Labels look like
            ``"<metabolite>_shift<value>"``.
        Raises:
            ValueError: if no reference spectra are loaded.
        """
        if not self.pure_spectra:
            raise ValueError("No reference spectra loaded; cannot build a dictionary")

        columns, labels, scales = [], [], []

        def add_column(column, label):
            # Divide by the L2 norm (when requested and non-zero) and remember
            # the divisor, so coefficients can be mapped back to the reference.
            scale = 1.0
            if normalize:
                norm = np.linalg.norm(column)
                if norm > 0:
                    scale = norm
            columns.append(column / scale)
            labels.append(label)
            scales.append(scale)

        for met_name, data in self.pure_spectra.items():
            spec = np.asarray(data['spectrum'], dtype=float)
            if include_shifts and data['shift_sensitive']:
                for shift in np.linspace(shift_range[0], shift_range[1], n_shifts):
                    # Move every feature by +shift ppm: the value at p is read
                    # from p - shift in the original spectrum.
                    shifted = np.interp(self.common_ppm, self.common_ppm + shift, spec)
                    add_column(shifted, f"{met_name}_shift{shift:+.3f}")
            else:
                add_column(spec, f"{met_name}_shift0.000")

        X = np.column_stack(columns)
        if return_scales:
            return X, labels, np.asarray(scales, dtype=float)
        return X, labels

    @staticmethod
    def _fit_lasso(X, y, alpha):
        """Fit a non-negative, intercept-free LASSO of ``y`` on ``X``.

        If ``alpha`` is None it is chosen by ``LassoCV`` (5 folds over a
        log-spaced grid). Returns ``(fitted_model, alpha_used)``.
        """
        if alpha is None:
            print("\nRunning LASSO with cross-validation...")
            # NOTE: an integer cv for a regressor uses unshuffled KFold, so
            # each fold is a contiguous block of ppm points.
            model = LassoCV(cv=5, fit_intercept=False, positive=True,
                            alphas=np.logspace(-4, -1, 50), max_iter=5000)
        else:
            print(f"\nRunning LASSO with alpha={alpha}...")
            model = Lasso(alpha=alpha, fit_intercept=False, positive=True, max_iter=5000)
        model.fit(X, y)
        if alpha is None:
            alpha = model.alpha_
            print(f"Optimal alpha: {alpha:.6f}")
        return model, alpha

    def _collect_results(self, labels, coefs, scales=None, min_coef=1e-4):
        """Convert fitted column coefficients into per-metabolite concentrations.

        Coefficients below ``min_coef`` (on the normalized columns) are
        ignored. Each remaining coefficient is divided by its column scale to
        get the weight on the TSP-normalized reference. That weight is
        multiplied by the reference concentration. All shifted copies of one
        metabolite are summed, since LASSO may spread the signal over
        neighbouring shifts.

        Returns:
            dict met_name -> {
                'concentration': summed concentration,
                'shift': shift (ppm) of the largest contribution,
                'coefficient': summed weight on the unnormalized reference,
                    so concentration == coefficient * ref_conc,
                'ref_conc': reference concentration,
            }
        """
        if scales is None:
            scales = np.ones(len(labels))
        results = {}
        largest = {}
        for label, coef, scale in zip(labels, coefs, scales):
            if coef < min_coef:  # Skip negligible contributions
                continue
            met_name, _, shift_text = label.rpartition('_shift')
            shift = float(shift_text)
            ref_conc = self.pure_spectra[met_name]['concentration']
            weight = coef / scale  # undo the column normalization
            conc = weight * ref_conc
            entry = results.setdefault(met_name, {
                'concentration': 0.0,
                'shift': shift,
                'coefficient': 0.0,
                'ref_conc': ref_conc,
            })
            entry['concentration'] += conc
            entry['coefficient'] += weight
            if conc > largest.get(met_name, -np.inf):
                largest[met_name] = conc
                entry['shift'] = shift
        return results

    def quantify(self, mixture_file, alpha=None, include_shifts=True):
        """Quantify metabolites in a mixture spectrum using LASSO.

        Parameters:
            mixture_file: path to a .dx file.
            alpha: LASSO regularization strength (None = choose by CV).
            include_shifts: whether to add shifted dictionary entries for
                shift-sensitive metabolites.
        Returns:
            ``(results, r2)``. ``results`` is the per-metabolite dict described
            in ``_collect_results``. ``r2`` is the R² of the reconstruction on
            the common grid.
        Raises:
            ValueError: if the mixture's TSP area is not positive, or no
                reference spectra are loaded.
        """
        print(f"\n{'='*60}")
        print(f"LASSO Quantification: {os.path.basename(mixture_file)}")
        print(f"{'='*60}")

        y, tsp, tsp_area = self._load_normalized(mixture_file)
        print(f"TSP correction: +{-tsp:.4f} ppm")
        print(f"TSP area: {tsp_area:.2e}")

        X, labels, scales = self.build_dictionary(include_shifts=include_shifts,
                                                  return_scales=True)
        print(f"Dictionary size: {X.shape[1]} entries")
        if include_shifts:
            shifted = [name for name, data in self.pure_spectra.items()
                       if data['shift_sensitive']]
            print(f" (includes shifted versions for: {', '.join(shifted)})")

        model, alpha = self._fit_lasso(X, y, alpha)

        y_pred = model.predict(X)
        r2 = r2_score(y, y_pred)
        print(f"R² (reconstruction): {r2:.4f}")

        results = self._collect_results(labels, model.coef_, scales)

        # Keep the fit around for plot_reconstruction()
        self.last_model = model
        self.last_X = X
        self.last_y = y
        self.last_labels = labels
        return results, r2

    def plot_reconstruction(self, mixture_name, save_path=None):
        """Plot the mixture spectrum against the last LASSO reconstruction.

        Draws three panels: full range, the aliphatic 0.5-3.0 ppm region, and
        the residual. The figure is saved only if ``save_path`` is given. It
        is always closed afterwards and never shown.
        """
        if not hasattr(self, 'last_model'):
            print("No model to plot. Run quantify() first.")
            return
        fig, axes = plt.subplots(3, 1, figsize=(16, 12))
        y_pred = self.last_model.predict(self.last_X)

        # Full spectrum
        ax = axes[0]
        ax.plot(self.common_ppm, self.last_y, 'b-', linewidth=0.8, label='Mixture', alpha=0.7)
        ax.plot(self.common_ppm, y_pred, 'r--', linewidth=1.0, label='LASSO Reconstruction')
        ax.set_xlim(5.5, 0.5)  # NMR convention: ppm decreases left to right
        ax.set_xlabel('ppm')
        ax.set_ylabel('Normalized Intensity')
        ax.set_title(f'{mixture_name} - LASSO Reconstruction')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Aliphatic region
        ax = axes[1]
        mask = (self.common_ppm >= 0.5) & (self.common_ppm <= 3.0)
        ax.plot(self.common_ppm[mask], self.last_y[mask], 'b-', linewidth=1.0, label='Mixture')
        ax.plot(self.common_ppm[mask], y_pred[mask], 'r--', linewidth=1.5, label='Reconstruction')
        ax.set_xlim(3.0, 0.5)
        ax.set_xlabel('ppm')
        ax.set_ylabel('Normalized Intensity')
        ax.set_title('Aliphatic Region (0.5-3.0 ppm)')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Residual
        ax = axes[2]
        residual = self.last_y - y_pred
        ax.plot(self.common_ppm, residual, 'g-', linewidth=0.8)
        ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)
        ax.set_xlim(5.5, 0.5)
        ax.set_xlabel('ppm')
        ax.set_ylabel('Residual')
        ax.set_title('Residual (Mixture - Reconstruction)')
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Saved: {save_path}")
        plt.close(fig)

    def compare_with_lorentzian(self, mixture_file, lorentzian_results):
        """Run LASSO on ``mixture_file`` and print it next to Lorentzian results.

        Parameters:
            lorentzian_results: mapping met_name -> dict holding a
                'concentration' entry (mM).
        Returns:
            The LASSO results dict from ``quantify``.
        """
        print(f"\n{'='*70}")
        print("COMPARISON: LASSO vs Lorentzian Fitting")
        print(f"{'='*70}")

        lasso_results, _ = self.quantify(mixture_file, alpha=None)

        print(f"\n{'Metabolite':<15} {'Lorentzian (mM)':>18} {'LASSO (mM)':>18} {'LASSO/Lor':>12}")
        print("-"*70)
        all_mets = set(lorentzian_results) | set(lasso_results)
        for met in sorted(all_mets):
            lor = lorentzian_results.get(met, {}).get('concentration', 0)
            las = lasso_results.get(met, {}).get('concentration', 0)
            # A metabolite absent from the Lorentzian fit has no finite ratio.
            ratio = las / lor if lor > 0 else float('inf')
            print(f"{met:<15} {lor:>18.2f} {las:>18.2f} {ratio:>12.2f}")
        return lasso_results
def main():
    """Quantify mixture File 10 with LASSO and compare against Lorentzian fits.

    Prints a per-metabolite concentration/shift table with a total, saves the
    reconstruction plot to ``lasso_reconstruction_file10.png``, then prints a
    side-by-side comparison with Lorentzian concentrations from an earlier run
    (hard-coded below).
    """
    mixture_file = "raw_data/Model_Mixtures-M1-M6_JCAMP-DX/10.dx"
    rule = '=' * 60

    quantifier = LassoNMRQuantifier()
    results, _r2 = quantifier.quantify(mixture_file, alpha=None, include_shifts=True)

    # Results table
    print(f"\n{rule}")
    print("LASSO Quantification Results - File 10")
    print(rule)
    print(f"{'Metabolite':<15} {'Conc (mM)':>12} {'Shift (ppm)':>12}")
    print("-" * 60)
    running_total = 0
    for name in sorted(results):
        entry = results[name]
        running_total += entry['concentration']
        print(f"{name:<15} {entry['concentration']:>12.2f} {entry['shift']:>12.3f}")
    print(rule)
    print(f"{'Total':<15} {running_total:>12.2f}")
    print(rule)

    quantifier.plot_reconstruction("File 10", "lasso_reconstruction_file10.png")

    # Lorentzian concentrations (mM) from a previous run, for comparison.
    previous_lorentzian_mM = {
        'Alanine': 7.68,
        'Arginine': 7.20,
        'Asparagine': 1.13,
        'Aspartate': 0.53,
        'Glucose': 28.98,
        'Glutamate': 4.31,
        'Glutamine': 5.03,
        'Isoleucine': 3.25,
        'Lactate': 27.81,
        'Leucine': 8.98,
        'Methionine': 1.22,
        'Phenylalanine': 5.99,
        'Tyrosine': 0.76,
        'Valine': 0.61,
    }
    lorentzian_results = {
        name: {'concentration': conc}
        for name, conc in previous_lorentzian_mM.items()
    }
    quantifier.compare_with_lorentzian(mixture_file, lorentzian_results)


if __name__ == "__main__":
    main()