PyTorch新旧版本fft函数的差别
前言
跑代码的时候发现torch新版本和旧版本的fft函数不一样
正文
旧版torch.fft(input,signal_ndim, normalized=False)
参数说明如下
第一维是输入,第二维是需要做fft的数据的维度,第三维是是否进行正则化
新版torch.fft.fft(input,n=None,dim=-1,norm=None,*,out=None)
参数说明如下
input (Tensor) – the input tensor
n (int, optional) – Signal length. If given, the input will either be zero-padded or trimmed to this length before computing the FFT.
dim (int, optional) – The dimension along which to take the one dimensional FFT.
norm (str, optional) –
Normalization mode. For the forward transform (
fft()
), these correspond to:
"forward"
- normalize by1/n
"backward"
- no normalization"ortho"
- normalize by1/sqrt(n)
(making the FFT orthonormal) #保证中间过程中能量不变Calling the backward transform (
ifft()
) with the same normalization mode will apply an overall normalization of1/n
between the two transforms. This is required to makeifft()
the exact inverse.Default is
"backward"
(no normalization).
第一维没变,第二维变成了信号长度,如果设置的的话所有输入会被裁剪或补0到相同的长度,第三维是要做傅里叶变换的维度,因为这个函数实际上是做了一维傅里叶变换,第四维是正则化的方式,相对于旧版函数多了正交(ortho)的选项,从旧函数的布尔类型变量变成了字符类型变量,分别在fft和ifft做$1/\sqrt{n}$的缩放,保证在fft和ifft的中间进行操作的过程中能量不变。后面的参数我也不确定是做什么用的,大概是用于设置输出Tensor的尺寸。
下面来看一个简单的例子
torch.fft(x, 2, normalized=1)
如何改成新版的fft函数呢,我们知道,对于一个张量/矩阵做二维傅里叶变换,等于分别在前两个维度做一维傅里叶变换,旧版函数是在x上做了二维傅里叶变换,那么我们应该使用新版函数分别在x的第一和第二维做傅里叶变换
torch.fft.fft(x, dim=-1, norm='ortho')#这里norm根据实际情况设置
torch.fft.fft(x, dim=-2, norm='ortho')
或者也可以使用另一个函数torch.fft.fft2(input, s=None, dim=(- 2, - 1), norm=None, *, out=None)
,这个函数是二维傅里叶变换,实际上和在两个维度上分别作一维fft相同
参数和一维fft也基本一样
input (Tensor) – the input tensor
s (Tuple**[int]**, optional) – Signal size in the transformed dimensions. If given, each dimension
dim[i]
will either be zero-padded or trimmed to the lengths[i]
before computing the FFT. If a length-1
is specified, no padding is done in that dimension. Default:s = [input.size(d) for d in dim]
dim (Tuple**[int]**, optional) – Dimensions to be transformed. Default: last two dimensions.
norm (str, optional) –
Normalization mode. For the forward transform (
fft2()
), these correspond to:
"forward"
- normalize by1/n
"backward"
- no normalization"ortho"
- normalize by1/sqrt(n)
(making the FFT orthonormal)Where
n = prod(s)
is the logical FFT size. Calling the backward transform (ifft2()
) with the same normalization mode will apply an overall normalization of1/n
between the two transforms. This is required to makeifft2()
the exact inverse.Default is
"backward"
(no normalization).
需要注意的是长度和维度参数从一个数变成了一个元组
使用torch.fft.fft2
替换上述的代码
torch.fft.ff2(x, dim=(-1, -2), norm='orcho')
更高维度的fft可以用fftn函数实现
torch.ifft
在新版中变成了torch.fft.ifft
函数,其他的变化和参数均与fft函数相似,不再赘述。
参考文献
https://pytorch.org/docs/1.12/generated/torch.fft.ifft2.html?highlight=ifft2#torch.fft.ifft2
https://pytorch.org/docs/0.4.0/torch.html?highlight=fft#torch.fft
https://runebook.dev/zh/docs/pytorch/fft
王孜师兄