44import matplotlib .pyplot as plt
55from scipy .sparse import csr_matrix
66
7+
78def diagnose_qc (h5ad_path , output_path , min_umis = 500 , min_genes = 200 , max_mt = 0.15 ):
89 print (f"Loading { h5ad_path } ..." )
9-
10+
1011 with h5py .File (h5ad_path , "r" ) as f :
1112 # Load barcodes and gene names
12- barcodes = [b .decode ("utf-8" ) if isinstance (b , bytes ) else b for b in f ["obs" ]["_index" ][:]]
13- gene_names = [g .decode ("utf-8" ) if isinstance (g , bytes ) else g for g in f ["var" ]["_index" ][:]]
14-
13+ barcodes = [
14+ b .decode ("utf-8" ) if isinstance (b , bytes ) else b
15+ for b in f ["obs" ]["_index" ][:]
16+ ]
17+ gene_names = [
18+ g .decode ("utf-8" ) if isinstance (g , bytes ) else g
19+ for g in f ["var" ]["_index" ][:]
20+ ]
21+
1522 # Load expression matrix X
1623 X_group = f ["X" ]
1724 if isinstance (X_group , h5py .Group ):
1825 data = X_group ["data" ][:]
1926 indices = X_group ["indices" ][:]
2027 indptr = X_group ["indptr" ][:]
21- X = csr_matrix ((data , indices , indptr ), shape = (len (barcodes ), len (gene_names )))
28+ X = csr_matrix (
29+ (data , indices , indptr ), shape = (len (barcodes ), len (gene_names ))
30+ )
2231 else :
2332 X = f ["X" ][:]
24-
33+
2534 # Spatial coordinates (usually in obsm/spatial)
2635 if "obsm" in f and "spatial" in f ["obsm" ]:
2736 coords = f ["obsm" ]["spatial" ][:]
@@ -33,10 +42,14 @@ def diagnose_qc(h5ad_path, output_path, min_umis=500, min_genes=200, max_mt=0.15
3342 # Calculate QC metrics
3443 n_counts = np .array (X .sum (axis = 1 )).flatten ()
3544 n_genes = np .array ((X > 0 ).sum (axis = 1 )).flatten ()
36-
45+
3746 # MT fraction
3847 # Robust MT detection: check for mt- or MT- anywhere, but prioritize common patterns
39- mt_genes = [i for i , name in enumerate (gene_names ) if "mt-" in name .lower () or "mt:" in name .lower ()]
48+ mt_genes = [
49+ i
50+ for i , name in enumerate (gene_names )
51+ if "mt-" in name .lower () or "mt:" in name .lower ()
52+ ]
4053 if mt_genes :
4154 print (f"Found { len (mt_genes )} mitochondrial genes." )
4255 mt_counts = np .array (X [:, mt_genes ].sum (axis = 1 )).flatten ()
@@ -49,12 +62,12 @@ def diagnose_qc(h5ad_path, output_path, min_umis=500, min_genes=200, max_mt=0.15
4962 pass_umi = n_counts >= min_umis
5063 pass_gene = n_genes >= min_genes
5164 pass_mt = pct_counts_mt <= max_mt
52-
65+
5366 keep_mask = pass_umi & pass_gene & pass_mt
54-
67+
5568 total_spots = len (barcodes )
5669 kept_spots = np .sum (keep_mask )
57-
70+
5871 print (f"Total spots: { total_spots } " )
5972 print (f"Kept spots: { kept_spots } ({ kept_spots / total_spots :.1%} )" )
6073 print (f"Filtered: { total_spots - kept_spots } " )
@@ -64,31 +77,41 @@ def diagnose_qc(h5ad_path, output_path, min_umis=500, min_genes=200, max_mt=0.15
6477
6578 # Plotting
6679 fig , axes = plt .subplots (2 , 2 , figsize = (15 , 12 ))
67-
80+
6881 # 1. UMI vs Genes
6982 axes [0 , 0 ].scatter (n_counts , n_genes , c = keep_mask , cmap = "RdYlGn" , alpha = 0.5 , s = 10 )
70- axes [0 , 0 ].axvline (min_umis , color = "red" , linestyle = "--" , label = f"Min UMI={ min_umis } " )
71- axes [0 , 0 ].axhline (min_genes , color = "blue" , linestyle = "--" , label = f"Min Genes={ min_genes } " )
83+ axes [0 , 0 ].axvline (
84+ min_umis , color = "red" , linestyle = "--" , label = f"Min UMI={ min_umis } "
85+ )
86+ axes [0 , 0 ].axhline (
87+ min_genes , color = "blue" , linestyle = "--" , label = f"Min Genes={ min_genes } "
88+ )
7289 axes [0 , 0 ].set_xlabel ("Total UMI counts" )
7390 axes [0 , 0 ].set_ylabel ("Number of detected genes" )
7491 axes [0 , 0 ].set_title ("QC: UMI vs Genes" )
7592 axes [0 , 0 ].legend ()
76-
93+
7794 # 2. MT Fraction distribution
7895 axes [0 , 1 ].hist (pct_counts_mt , bins = 50 , color = "gray" , alpha = 0.7 )
79- axes [0 , 1 ].axvline (max_mt , color = "red" , linestyle = "--" , label = f"Max MT={ max_mt :.0%} " )
96+ axes [0 , 1 ].axvline (
97+ max_mt , color = "red" , linestyle = "--" , label = f"Max MT={ max_mt :.0%} "
98+ )
8099 axes [0 , 1 ].set_xlabel ("Mitochondrial Fraction" )
81100 axes [0 , 1 ].set_ylabel ("Count" )
82101 axes [0 , 1 ].set_title ("QC: MT Fraction Distribution" )
83102 axes [0 , 1 ].legend ()
84-
103+
85104 # 3. Spatial: Before (Total)
86- axes [1 , 0 ].scatter (coords [:, 0 ], coords [:, 1 ], c = "lightgray" , s = 15 , label = "All Spots" )
87- axes [1 , 0 ].scatter (coords [keep_mask , 0 ], coords [keep_mask , 1 ], c = "green" , s = 15 , label = "Pass QC" )
105+ axes [1 , 0 ].scatter (
106+ coords [:, 0 ], coords [:, 1 ], c = "lightgray" , s = 15 , label = "All Spots"
107+ )
108+ axes [1 , 0 ].scatter (
109+ coords [keep_mask , 0 ], coords [keep_mask , 1 ], c = "green" , s = 15 , label = "Pass QC"
110+ )
88111 axes [1 , 0 ].set_title ("Spatial Distribution: Kept vs Filtered" )
89112 axes [1 , 0 ].set_aspect ("equal" )
90113 axes [1 , 0 ].legend ()
91-
114+
92115 # 4. Summary Table (as text)
93116 stats_text = (
94117 f"Sample: { os .path .basename (h5ad_path )} \n \n "
@@ -102,21 +125,23 @@ def diagnose_qc(h5ad_path, output_path, min_umis=500, min_genes=200, max_mt=0.15
102125 )
103126 axes [1 , 1 ].text (0.1 , 0.5 , stats_text , fontsize = 14 , family = "monospace" )
104127 axes [1 , 1 ].axis ("off" )
105-
128+
106129 plt .tight_layout ()
107130 plt .savefig (output_path )
108131 print (f"Plot saved to { output_path } " )
109132
133+
110134if __name__ == "__main__" :
111135 import argparse
136+
112137 parser = argparse .ArgumentParser ()
113138 parser .add_argument ("--sample" , type = str , default = "MEND29" )
114139 args = parser .parse_args ()
115-
140+
116141 sample_id = args .sample
117142 h5ad_file = f"A:\\ hest_data\\ st\\ { sample_id } .h5ad"
118143 output_file = f"qc_diagnosis_{ sample_id } .png"
119-
144+
120145 if not os .path .exists (h5ad_file ):
121146 print (f"Error: { h5ad_file } not found." )
122147 else :
0 commit comments