因为无信通信中经常用到复数的乘法,pytorch中又没有现成的处理方式,自己懒得写,就在网上搜了下好心人分享的资料。确实是正确的。但是作者采用循环的方式,在处理大批量数据的时候非常慢,我对几十万复数的数据集执行操作时花了好几分钟才完成,故进行改动。改为矩阵处理后,瞬间就可以出结果。并不是什么有难度的操作,只是记录下以后方便使用。
代码如下:
H1 = torch.randn(3,2) // 3代表数据的数量,选取小的数易于观察,2代表复数,分别是实部和虚部 H2 = torch.randn(3,2) print('H1',H1) print('H2',H2) def complexMulti(a,b): // 循环的方式 r = a.shape[0] c = torch.zeros([r,2]) for i in range(r): c[i,0] = a[i,0]*b[i,0]-a[i,1]*b[i,1] c[i,1] = a[i,0]*b[i,1]+a[i,1]*b[i,0] return c
def complexMulti_1(a,b): // 矩阵处理 r = a.shape[0] c = torch.zeros([r,2]) c[:,0] = a[:,0]*b[:,0]-a[:,1]*b[:,1] c[:,1] = a[:,0]*b[:,1]+a[:,1]*b[:,0] return c y = complexMulti(H1,H2) print('y',y) y1 = complexMulti_1(H1,H2) print('y1',y1)
结果如下:
H1 tensor([[-1.0672, -0.9281], [ 0.6703, 0.3585], [ 0.5990, -1.3287]]) H2 tensor([[-0.1537, 1.3645], [ 1.0352, 0.5920], [-0.9543, -0.8222]]) y tensor([[ 1.4305, -1.3135], [ 0.4817, 0.7680], [-1.6641, 0.7754]]) y1 tensor([[ 1.4305, -1.3135], [ 0.4817, 0.7680], [-1.6641, 0.7754]])与python的复数乘法比较:
x = -1.0672-0.9281j y = -0.1537+1.3645j print(x * y)
(1.43042109-1.31354543j)参考的文章如下:
https://blog.csdn.net/Stephanie2014/article/details/105984274