
kaldi.fbank does not work with non-contiguous input when snip_edges=False #3856

gau-nernst opened this issue Nov 25, 2024 · 0 comments

🐛 Describe the bug

from torchaudio.compliance.kaldi import fbank
import torch


x = torch.rand(1, 16_000 * 2) * (1 << 15)
x = x[:, ::2]  # slicing with a step of 2 makes x non-contiguous
torch.testing.assert_close(fbank(x.contiguous(), snip_edges=False), fbank(x, snip_edges=False))
File ~/miniconda3/envs/vas_2.4/lib/python3.10/site-packages/torchaudio/compliance/kaldi.py:177, in _get_window(waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff, snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
    174 epsilon = _get_epsilon(device, dtype)
    176 # size (m, window_size)
--> 177 strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
    179 if dither != 0.0:
    180     rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)

File ~/miniconda3/envs/vas_2.4/lib/python3.10/site-packages/torchaudio/compliance/kaldi.py:83, in _get_strided(waveform, window_size, window_shift, snip_edges)
     80         waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
     82 sizes = (m, window_size)
---> 83 return waveform.as_strided(sizes, strides)

RuntimeError: setStorage: sizes [100, 400], strides [320, 2], storage offset 0, and itemsize 4 requiring a storage size of 129916 are out of bounds for storage of size 128480

I encountered this problem while implementing a batched version of kaldi.fbank (by the way, I'm also willing to contribute my batch support back to torchaudio if the maintainers are interested). The problem lies in the _get_strided() function. It first computes the strides from the original waveform:

strides = (window_shift * waveform.stride(0), waveform.stride(0))

However, when snip_edges=False, padding is applied via torch.cat(), which returns a contiguous tensor even if waveform was not contiguous to begin with:

if pad > 0:
    # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
    # but we want [2, 1, 0, 0, 1, 2]
    pad_left = reversed_waveform[-pad:]
    waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
else:
    # pad is negative so we want to trim the waveform at the front
    waveform = torch.cat((waveform[-pad:], pad_right), dim=0)

Hence, the strides computed from the original (non-contiguous) waveform no longer match the actual strides of the padded, contiguous waveform, so as_strided() reads past the end of the storage.
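The stride mismatch can be reproduced with plain torch, no torchaudio needed (the sizes below are illustrative, chosen to mirror the traceback, not copied from kaldi.py):

```python
import torch

waveform = torch.rand(32_000)[::2]  # non-contiguous view: 16_000 elements, stride 2
window_size, window_shift = 400, 160

# strides computed BEFORE padding, as _get_strided() currently does
stale_strides = (window_shift * waveform.stride(0), waveform.stride(0))  # (320, 2)

# torch.cat copies into a fresh, contiguous tensor whose stride is 1
padded = torch.cat((waveform, torch.zeros(240)), dim=0)
print(waveform.stride(0), padded.stride(0))  # 2 1
```

Applying `stale_strides` to `padded` via as_strided() then asks for element strides of 2 on a buffer laid out with stride 1, which is exactly the out-of-bounds RuntimeError above.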

The solution is to move the stride calculation to after the padding step:

    # padding...

    strides = (window_shift * waveform.stride(0), waveform.stride(0))
    sizes = (m, window_size)
    return waveform.as_strided(sizes, strides)
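A self-contained sketch of the corrected logic (simplified: it zero-pads instead of using kaldi's reflect padding, and `get_strided_fixed` is a hypothetical name, not torchaudio API):

```python
import torch

def get_strided_fixed(waveform, window_size, window_shift):
    """Sketch of _get_strided with strides computed AFTER padding."""
    num_samples = waveform.size(0)
    # frame count for snip_edges=False (round to nearest)
    m = (num_samples + window_shift // 2) // window_shift
    needed = (m - 1) * window_shift + window_size
    pad = max(needed - num_samples, 0)
    # cat allocates a contiguous tensor, regardless of the input's layout
    waveform = torch.cat((waveform, waveform.new_zeros(pad)), dim=0)
    # strides taken from the padded tensor, so they are always valid
    strides = (window_shift * waveform.stride(0), waveform.stride(0))
    return waveform.as_strided((m, window_size), strides)

x = torch.rand(32_000)[::2]  # non-contiguous input
frames = get_strided_fixed(x, 400, 160)
print(frames.shape)  # torch.Size([100, 400])
```

Until the fix lands, calling `fbank(x.contiguous(), snip_edges=False)` works around the error, at the cost of one extra copy.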

Versions

PyTorch 2.4, torchaudio 2.4
