Skip to content

Commit 765960f

Browse files
author
Your Name
committed
2 parents 75119e5 + 6c5dce5 commit 765960f

1 file changed

Lines changed: 20 additions & 1 deletion

File tree

complexFunctions.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,28 @@
55
@author: spopoff
66
"""
77

8-
from torch.nn.functional import relu, max_pool2d, dropout, dropout2d
8+
from torch.nn.functional import relu, max_pool2d, avg_pool2d, dropout, dropout2d
99
import torch
1010

11+
def complex_matmul(A, B):
12+
'''
13+
Performs the matrix product between two complex matrices
14+
'''
15+
16+
outp_real = torch.matmul(A.real, B.real) - torch.matmul(A.imag, B.imag)
17+
outp_imag = torch.matmul(A.real, B.imag) + torch.matmul(A.imag, B.real)
18+
19+
return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64)
20+
21+
def complex_avg_pool2d(input, *args, **kwargs):
22+
'''
23+
Perform complex average pooling.
24+
'''
25+
absolute_value_real = avg_pool2d(input.real, *args, **kwargs)
26+
absolute_value_imag = avg_pool2d(input.imag, *args, **kwargs)
27+
28+
return absolute_value_real.type(torch.complex64)+1j*absolute_value_imag.type(torch.complex64)
29+
1130
def complex_relu(input):
1231
return relu(input.real).type(torch.complex64)+1j*relu(input.imag).type(torch.complex64)
1332

0 commit comments

Comments
 (0)