Skip to content

Commit 5a3bfe8

Browse files
committed
df/io: Add support for torchaudio 2.9.
torchaudio drops support for metadata functionality in favor of using torchcodec. See pytorch/audio#3902. This patch implements a new path for torchaudio versions 2.9+ by checking the version number. It uses the new torchcodec API for retrieving metadata and decoding. Signed-off-by: Lubosz Sarnecki <lubosz@gmail.com>
1 parent d375b2d commit 5a3bfe8

1 file changed

Lines changed: 34 additions & 11 deletions

File tree

DeepFilterNet/df/io.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,38 @@
66
from loguru import logger
77
from numpy import ndarray
88
from torch import Tensor
9+
from packaging import version
910

10-
try:
11-
from torchaudio import AudioMetaData
11+
12+
if version.parse(ta.__version__) >= version.parse("2.9.0"):
13+
from torchcodec.decoders import AudioDecoder
14+
from torchcodec._core import AudioStreamMetadata
15+
16+
AudioStreamMetadataType = AudioStreamMetadata
1217

1318
TA_RESAMPLE_SINC = "sinc_interp_hann"
1419
TA_RESAMPLE_KAISER = "sinc_interp_kaiser"
15-
except ImportError:
16-
from torchaudio.backend.common import AudioMetaData
20+
else:
21+
try:
22+
from torchaudio import AudioMetaData
23+
24+
TA_RESAMPLE_SINC = "sinc_interp_hann"
25+
TA_RESAMPLE_KAISER = "sinc_interp_kaiser"
26+
except ImportError:
27+
from torchaudio.backend.common import AudioMetaData
1728

18-
TA_RESAMPLE_SINC = "sinc_interpolation"
19-
TA_RESAMPLE_KAISER = "kaiser_window"
29+
TA_RESAMPLE_SINC = "sinc_interpolation"
30+
TA_RESAMPLE_KAISER = "kaiser_window"
31+
32+
AudioStreamMetadataType = AudioMetaData
2033

2134
from df.logger import warn_once
2235
from df.utils import download_file, get_cache_dir, get_git_root
2336

2437

2538
def load_audio(
2639
file: str, sr: Optional[int] = None, verbose=True, **kwargs
27-
) -> Tuple[Tensor, AudioMetaData]:
40+
) -> Tuple[Tensor, AudioStreamMetadataType]:
2841
"""Loads an audio file using torchaudio.
2942
3043
Args:
@@ -43,10 +56,20 @@ def load_audio(
4356
rkwargs = {}
4457
if "method" in kwargs:
4558
rkwargs["method"] = kwargs.pop("method")
46-
info: AudioMetaData = ta.info(file, **ikwargs)
47-
if "num_frames" in kwargs and sr is not None:
48-
kwargs["num_frames"] *= info.sample_rate // sr
49-
audio, orig_sr = ta.load(file, **kwargs)
59+
60+
if version.parse(ta.__version__) >= version.parse("2.9.0"):
61+
decoder = AudioDecoder(file)
62+
info: AudioStreamMetadata = decoder.metadata
63+
samples = decoder.get_all_samples()
64+
audio = samples.data
65+
orig_sr = samples.sample_rate
66+
else:
67+
info: AudioMetaData = ta.info(file, **ikwargs)
68+
69+
if "num_frames" in kwargs and sr is not None:
70+
kwargs["num_frames"] *= info.sample_rate // sr
71+
audio, orig_sr = ta.load(file, **kwargs)
72+
5073
if sr is not None and orig_sr != sr:
5174
if verbose:
5275
warn_once(

0 commit comments

Comments
 (0)