1+ import pytest
2+ import torch
3+ from spectpsftoolbox .kernel1d import FunctionKernel1D
4+ from spectpsftoolbox .kernel2d import NGonKernel2D , FunctionalKernel2D
5+ from spectpsftoolbox .utils import get_kernel_meshgrid
6+ from spectpsftoolbox .operator2d import Kernel2DOperator , GaussianOperator
7+ device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
8+
9+ def test_1Dkernel ():
10+ amplitude_fn = lambda a , bs : bs [0 ]* torch .exp (- bs [1 ]* a )
11+ sigma_fn = lambda a , bs : bs [0 ]* (a + 0.1 )
12+ amplitude_params = torch .tensor ([2 ,0.1 ], device = device , dtype = torch .float32 )
13+ sigma_params = torch .tensor ([0.3 ], device = device , dtype = torch .float32 )
14+ kernel_fn = lambda x : torch .exp (- torch .abs (x ))
15+ kernel1D = FunctionKernel1D (kernel_fn , amplitude_fn , sigma_fn , amplitude_params , sigma_params )
16+ x = torch .linspace (- 5 ,5 ,100 ).to (device )
17+ a = torch .linspace (1 ,10 ,5 ).to (device )
18+ kernel_value = kernel1D (x , a )
19+
20+ def test_FunctionalKernel2D ():
21+ Nx0 = 255
22+ dx0 = 0.05
23+ x = y = torch .arange (- (Nx0 - 1 )/ 2 , (Nx0 + 1 )/ 2 , 1 ).to (device ) * dx0
24+ xv , yv = torch .meshgrid (x , y , indexing = 'xy' )
25+ kernel_fn = lambda xv , yv : torch .exp (- torch .abs (xv ))* torch .exp (- torch .abs (yv )) * torch .sin (xv * 3 )** 2 * torch .cos (yv * 3 )** 2
26+ amplitude_fn = lambda a , bs : bs [0 ]* torch .exp (- bs [1 ]* a )
27+ sigma_fn = lambda a , bs : bs [0 ]* (a + 0.1 )
28+ amplitude_params = torch .tensor ([2 ,0.1 ], device = device , dtype = torch .float32 )
29+ sigma_params = torch .tensor ([0.3 ], device = device , dtype = torch .float32 )
30+ # Define the kernel
31+ kernel2D = FunctionalKernel2D (kernel_fn , amplitude_fn , sigma_fn , amplitude_params , sigma_params )
32+ a = torch .linspace (1 ,10 ,5 ).to (device )
33+ kernel = kernel2D (xv , yv , a , normalize = True )
34+
35+ def test_NGonKernel2D ():
36+ collimator_length = 2.405
37+ collimator_width = 0.254 #flat side to flat side
38+ sigma_fn = lambda a , bs : (bs [0 ]+ a ) / bs [0 ]
39+ sigma_params = torch .tensor ([collimator_length ], requires_grad = True , dtype = torch .float32 , device = device )
40+ # Set amplitude to 1
41+ amplitude_fn = lambda a , bs : torch .ones_like (a )
42+ amplitude_params = torch .tensor ([1. ], requires_grad = True , dtype = torch .float32 , device = device )
43+ ngon_kernel = NGonKernel2D (
44+ N_sides = 6 , # sides of polygon
45+ Nx = 255 , # resolution of polygon
46+ collimator_width = collimator_width , # width of polygon
47+ amplitude_fn = amplitude_fn ,
48+ sigma_fn = sigma_fn ,
49+ amplitude_params = amplitude_params ,
50+ sigma_params = sigma_params ,
51+ rot = 90
52+ )
53+ Nx0 = 255
54+ dx0 = 0.048
55+ x = y = torch .arange (- (Nx0 - 1 )/ 2 , (Nx0 + 1 )/ 2 , 1 ).to (device ) * dx0
56+ xv , yv = torch .meshgrid (x , y , indexing = 'xy' )
57+ distances = torch .tensor ([1 ,5 ,10 ,15 ,20 ,25 ], dtype = torch .float32 , device = device )
58+ kernel = ngon_kernel (xv , yv , distances , normalize = True ).cpu ().detach ()
59+
60+ def test_Operator1 ():
61+ # Tests Kernel2DOperator, GaussianOperator, and Operator __mult__
62+ # -------------------
63+ # Collimator Component
64+ # -------------------
65+ collimator_length = 2.405
66+ collimator_width = 0.254 #flat side to flat side
67+ mu = 28.340267562430935
68+ sigma_fn = lambda a , bs : (bs [0 ]+ a ) / bs [0 ]
69+ sigma_params = torch .tensor ([collimator_length - 2 / mu ], requires_grad = True , dtype = torch .float32 , device = device )
70+ # Set amplitude to 1
71+ amplitude_fn = lambda a , bs : torch .ones_like (a )
72+ amplitude_params = torch .tensor ([1. ], requires_grad = True , dtype = torch .float32 , device = device )
73+ ngon_kernel = NGonKernel2D (
74+ N_sides = 6 , # sides of polygon
75+ Nx = 255 , # resolution of polygon
76+ collimator_width = collimator_width , # width of polygon
77+ amplitude_fn = amplitude_fn ,
78+ sigma_fn = sigma_fn ,
79+ amplitude_params = amplitude_params ,
80+ sigma_params = sigma_params ,
81+ rot = 90
82+ )
83+ ngon_operator = Kernel2DOperator (ngon_kernel )
84+ # -------------------
85+ # Detector component
86+ # -------------------
87+ intrinsic_sigma = 0.1614 # typical for NaI 140keV detection
88+ gauss_amplitude_fn = lambda a , bs : torch .ones_like (a )
89+ gauss_sigma_fn = lambda a , bs : bs [0 ]* torch .ones_like (a )
90+ gauss_amplitude_params = torch .tensor ([1. ], requires_grad = True , dtype = torch .float32 , device = device )
91+ gauss_sigma_params = torch .tensor ([intrinsic_sigma ], requires_grad = True , device = device , dtype = torch .float32 )
92+ scint_operator = GaussianOperator (
93+ gauss_amplitude_fn ,
94+ gauss_sigma_fn ,
95+ gauss_amplitude_params ,
96+ gauss_sigma_params ,
97+ )
98+ # Total combined:
99+ psf_operator = scint_operator * ngon_operator
100+ Nx0 = 512
101+ dx0 = 0.24
102+ x = y = torch .arange (- (Nx0 - 1 )/ 2 , (Nx0 + 1 )/ 2 , 1 ).to (device ) * dx0
103+ xv , yv = torch .meshgrid (x , y , indexing = 'xy' )
104+ distances = torch .arange (0.36 , 57.9600 , 0.48 ).to (device )
105+ # Get kernel meshgrid
106+ k_width = 24 #cm
107+ xv_k , yv_k = get_kernel_meshgrid (xv , yv , k_width )
108+ # Create input with point source at origin
109+ input = torch .zeros_like (xv ).unsqueeze (0 ).repeat (distances .shape [0 ], 1 , 1 )
110+ input [:,256 ,256 ] = 1
111+ output = psf_operator (input , xv_k , yv_k , distances , normalize = True )
0 commit comments