PyTorch新旧版本fft函数的差别

PyTorch新旧版本fft函数的差别

前言

跑代码的时候发现torch新版本和旧版本的fft函数不一样

正文

旧版torch.fft(input,signal_ndim, normalized=False)参数说明如下

  • input (Tensor) – the input tensor of at least signal_ndim + 1 dimensions
  • signal_ndim (int) – the number of dimensions in each signal. signal_ndim can only be 1, 2 or 3
  • normalized (bool, optional) – controls whether to return normalized results. Default: False

第一维是输入,第二维是需要做fft的数据的维度,第三维是是否进行正则化

新版torch.fft.fft(input,n=None,dim=-1,norm=None,*,out=None)

参数说明如下

  • input (Tensor) – the input tensor

  • n (int, optional) – Signal length. If given, the input will either be zero-padded or trimmed to this length before computing the FFT.

  • dim (int, optional) – The dimension along which to take the one dimensional FFT.

  • norm (str, optional) –

    Normalization mode. For the forward transform (fft()), these correspond to:

    • "forward" - normalize by 1/n
    • "backward" - no normalization
    • "ortho" - normalize by 1/sqrt(n) (making the FFT orthonormal) #保证中间过程中能量不变

    Calling the backward transform (ifft()) with the same normalization mode will apply an overall normalization of 1/n between the two transforms. This is required to make ifft() the exact inverse.

    Default is "backward" (no normalization).

第一维没变,第二维变成了信号长度,如果设置的的话所有输入会被裁剪或补0到相同的长度,第三维是要做傅里叶变换的维度,因为这个函数实际上是做了一维傅里叶变换,第四维是正则化的方式,相对于旧版函数多了正交(ortho)的选项,从旧函数的布尔类型变量变成了字符类型变量,分别在fft和ifft做$1/\sqrt{n}$的缩放,保证在fft和ifft的中间进行操作的过程中能量不变。后面的参数我也不确定是做什么用的,大概是用于设置输出Tensor的尺寸。

下面来看一个简单的例子

torch.fft(x, 2, normalized=1)

如何改成新版的fft函数呢,我们知道,对于一个张量/矩阵做二维傅里叶变换,等于分别在前两个维度做一维傅里叶变换,旧版函数是在x上做了二维傅里叶变换,那么我们应该使用新版函数分别在x的第一和第二维做傅里叶变换

torch.fft.fft(x, dim=-1, norm='ortho')#这里norm根据实际情况设置
torch.fft.fft(x, dim=-2, norm='ortho')

或者也可以使用另一个函数torch.fft.fft2(input, s=None, dim=(- 2, - 1), norm=None, *, out=None),这个函数是二维傅里叶变换,实际上和在两个维度上分别作一维fft相同

参数和一维fft也基本一样

  • input (Tensor) – the input tensor

  • s (Tuple**[int]**, optional) – Signal size in the transformed dimensions. If given, each dimension dim[i] will either be zero-padded or trimmed to the length s[i] before computing the FFT. If a length -1 is specified, no padding is done in that dimension. Default: s = [input.size(d) for d in dim]

  • dim (Tuple**[int]**, optional) – Dimensions to be transformed. Default: last two dimensions.

  • norm (str, optional) –

    Normalization mode. For the forward transform (fft2()), these correspond to:

    • "forward" - normalize by 1/n
    • "backward" - no normalization
    • "ortho" - normalize by 1/sqrt(n) (making the FFT orthonormal)

    Where n = prod(s) is the logical FFT size. Calling the backward transform (ifft2()) with the same normalization mode will apply an overall normalization of 1/n between the two transforms. This is required to make ifft2() the exact inverse.

    Default is "backward" (no normalization).

需要注意的是长度和维度参数从一个数变成了一个元组

使用torch.fft.fft2替换上述的代码

torch.fft.ff2(x, dim=(-1, -2), norm='orcho')

更高维度的fft可以用fftn函数实现

torch.ifft在新版中变成了torch.fft.ifft函数,其他的变化和参数均与fft函数相似,不再赘述。

参考文献

https://pytorch.org/docs/1.12/generated/torch.fft.ifft2.html?highlight=ifft2#torch.fft.ifft2

https://pytorch.org/docs/0.4.0/torch.html?highlight=fft#torch.fft

https://runebook.dev/zh/docs/pytorch/fft

王孜师兄


文章作者: keevinzha
版权声明: 咳咳想白嫖文章?本文章著作权归作者所有,任何形式的转载都请注明出处。 https://www.keevinzha.com !
  目录