Skip to content

Commit 6c5dce5

Browse files
author
Octave Guinebretiere
committed
add complex_matmul
1 parent d391a4f commit 6c5dce5

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

complexFunctions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@
88
from torch.nn.functional import relu, max_pool2d, avg_pool2d, dropout, dropout2d
99
import torch
1010

11+
def complex_matmul(A, B):
12+
'''
13+
Performs the matrix product between two complex matrices
14+
'''
15+
16+
outp_real = torch.matmul(A.real, B.real) - torch.matmul(A.imag, B.imag)
17+
outp_imag = torch.matmul(A.real, B.imag) + torch.matmul(A.imag, B.real)
18+
19+
return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64)
1120

1221
def complex_avg_pool2d(input, *args, **kwargs):
1322
'''

0 commit comments

Comments
 (0)