1- from typing import Optional , TypeVar
1+ from math import floor
22
33import torch
4- from typing_extensions import TypeGuard
4+ import torch .nn .functional as F
5+ from torch import Tensor , nn
56
6- T = TypeVar ("T" )
77
def get_center_frequencies(
    num_octaves: int,  # C
    num_bins_per_octave: int,  # B
    sample_rate: int,  # Xi_s
) -> Tensor:  # Xi_k for the K bins, followed by the Nyquist frequency
    """Compute the log-spaced center frequencies, ending with Nyquist.

    Args:
        num_octaves: Number of octaves below the Nyquist frequency (C).
        num_bins_per_octave: Number of bins in each octave (B).
        sample_rate: Sampling rate (Xi_s).

    Returns:
        1-D tensor with ``num_octaves * num_bins_per_octave + 1`` entries:
        the K geometrically spaced bin frequencies starting at
        ``nyquist / 2 ** num_octaves``, then the Nyquist frequency itself.
    """
    nyquist = sample_rate / 2
    lowest = nyquist / (2 ** num_octaves)
    total_bins = num_octaves * num_bins_per_octave  # K

    # Each bin is a factor 2 ** (1 / B) above the previous one, so the
    # frequency doubles every `num_bins_per_octave` bins.
    octave_offsets = torch.arange(total_bins) / num_bins_per_octave
    bin_frequencies = lowest * (2 ** octave_offsets)

    # Only the lower half of the spectrum is returned; the mirrored
    # frequencies above Nyquist are not included.
    nyquist_entry = torch.tensor([nyquist])
    return torch.cat([bin_frequencies, nyquist_entry], dim=0)
826
9- """
10- Utils
11- """
1227
def get_bandwidths(
    num_octaves: int,  # C
    num_bins_per_octave: int,  # B
    sample_rate: int,  # Xi_s
    frequencies: Tensor,  # Xi_k for the K bins, followed by Nyquist
) -> Tensor:  # Omega_k for k in [1, 2*K+1]
    """Compute the bandwidths tensor from the center frequencies.

    Args:
        num_octaves: Number of octaves (C).
        num_bins_per_octave: Number of bins per octave (B).
        sample_rate: Sampling rate (Xi_s).
        frequencies: 1-D tensor whose first ``K = C * B`` entries are the bin
            center frequencies (0-indexed); any entries after those (e.g. the
            Nyquist frequency) are ignored here.

    Returns:
        1-D tensor of length ``2 * K + 1``: the K bin bandwidths
        ``frequency / Q``, the Nyquist bandwidth
        ``sample_rate - 2 * highest_bin_frequency``, then the K bin
        bandwidths again in reverse order (mirrored half of the spectrum).
    """
    num_bins = num_octaves * num_bins_per_octave  # K
    q_factor = 1.0 / (
        2 ** (1.0 / num_bins_per_octave) - 2 ** (-1.0 / num_bins_per_octave)
    )

    # `frequencies` is 0-indexed: the K bins live at [0, K). Slicing
    # [1 : K + 1] would skip the lowest bin and treat Nyquist as a bin.
    bandwidths = frequencies[:num_bins] / q_factor

    # The Nyquist band spans the gap between the highest bin and its mirror
    # image. Using the Nyquist frequency itself here would always give 0.
    bandwidth_nyquist = (sample_rate - 2 * frequencies[num_bins - 1]).reshape(1)

    return torch.cat(
        [
            bandwidths,
            bandwidth_nyquist.to(bandwidths.dtype),
            torch.flip(bandwidths, dims=[0]),
        ],
        dim=0,
    )
1352
14- def exists (val : Optional [T ]) -> TypeGuard [T ]:
15- return val is not None
1653
def get_windows_range_indices(lengths: Tensor, positions: Tensor) -> Tensor:
    """Compute the tensor of index ranges used to crop one window per bin.

    Args:
        lengths: 1-D tensor of window lengths; only its size (to derive the
            number of bins, ``lengths.shape[0] // 2``) and its maximum are used.
        positions: 1-D tensor with one position per bin; only the first
            ``num_bins`` entries are used.

    Returns:
        Long tensor of shape ``(num_bins, max_length)`` whose row ``i`` is
        ``positions[i] - max_length, ..., positions[i] - 1``. Entries may be
        negative; wrapping them is left to the caller.
    """
    num_bins = lengths.shape[0] // 2
    max_length = float(lengths.max())

    # NOTE(review): each range ends just before the bin position rather than
    # being centered on it -- confirm this is the intended alignment.
    starts = positions[:num_bins] - max_length

    # One broadcasted add replaces a per-bin `torch.arange` call.
    offsets = torch.arange(max_length, device=positions.device)
    return (starts.unsqueeze(1) + offsets).long()
1763
18- """
19- CQT
20- """
2164
22- class CQT (nn .Module ):
def get_windows(lengths: Tensor) -> Tensor:
    """Compute the tensor of stacked (centered) Hann windows.

    Args:
        lengths: 1-D tensor of integer-valued window lengths. Only the first
            ``lengths.shape[0] // 2`` entries get a window, but the common
            padded width is the maximum over *all* entries.

    Returns:
        Tensor of shape ``(num_bins, max_length)``; row ``i`` is a Hann window
        of ``lengths[i]`` samples, zero-padded on both sides to ``max_length``.
        When the padding is odd, the extra zero goes on the right.
    """
    num_bins = lengths.shape[0] // 2
    # Convert to plain ints once so the padding arithmetic and `F.pad` never
    # see 0-dim tensors.
    max_length = int(lengths.max())
    windows = []
    for raw_length in lengths[:num_bins].tolist():
        length = int(raw_length)
        # Pad windows left and right to center them
        pad_left = (max_length - length) // 2
        pad_right = max_length - length - pad_left
        windows.append(F.pad(torch.hann_window(length), pad=(pad_left, pad_right)))
    return torch.stack(windows, dim=0)
76+
77+
def get_windows_inverse(windows: Tensor, lengths: Tensor) -> Tensor:
    """Scale each squared window row by its corresponding length.

    Args:
        windows: 2-D tensor with one window per row.
        lengths: 1-D tensor; its first ``windows.shape[0]`` entries supply the
            per-row scale factors.

    Returns:
        Tensor shaped like ``windows`` where row ``k`` equals
        ``windows[k] ** 2 * lengths[k]``.
    """
    row_count = windows.shape[0]
    squared = windows ** 2
    row_scales = lengths[:row_count]
    return torch.einsum("km,k->km", squared, row_scales)
2381
82+
class CQT(nn.Module):
    """Frequency-domain windowed transform over fixed-length blocks.

    At construction the module derives log-spaced center frequencies and
    their bandwidths, converts both to FFT-bin units for a block of
    ``block_length`` samples, and registers three buffers:

    * ``windows_range_indices`` -- per-bin FFT indices to crop,
    * ``windows`` -- per-bin zero-padded Hann windows,
    * ``windows_inverse`` -- computed but not yet used by ``decode``.

    Presumably implements a constant-Q transform (name and the Xi/Omega
    notation in the helpers suggest so) -- confirm against the design notes.
    """

    def __init__(
        self,
        num_octaves: int,
        num_bins_per_octave: int,
        sample_rate: int,
        block_length: int,
    ):
        """Precompute crop indices and windows.

        Args:
            num_octaves: Number of octaves covered below Nyquist.
            num_bins_per_octave: Number of bins per octave.
            sample_rate: Sampling rate of the waveforms.
            block_length: Length of the spectrum that ``decode`` rebuilds;
                also the scale used to convert frequencies to FFT bins.
        """
        super().__init__()
        self.block_length = block_length

        frequencies = get_center_frequencies(
            num_octaves=num_octaves,
            num_bins_per_octave=num_bins_per_octave,
            sample_rate=sample_rate,
        )

        bandwidths = get_bandwidths(
            num_octaves=num_octaves,
            num_bins_per_octave=num_bins_per_octave,
            sample_rate=sample_rate,
            frequencies=frequencies,
        )

        # Frequency -> FFT-bin units: bin = frequency * block_length / sample_rate
        window_lengths = torch.round(bandwidths * block_length / sample_rate)

        self.register_buffer(
            "windows_range_indices",
            get_windows_range_indices(
                lengths=window_lengths,
                positions=torch.round(frequencies * block_length / sample_rate),
            ),
        )

        self.register_buffer("windows", get_windows(lengths=window_lengths))

        self.register_buffer(
            "windows_inverse",
            get_windows_inverse(windows=self.windows, lengths=window_lengths),  # type: ignore # noqa
        )

    def encode(self, waveform: Tensor) -> Tensor:
        """Transform a waveform into per-bin windowed coefficients.

        Args:
            waveform: 3-D tensor; the FFT is taken over the last dimension
                and the first two dimensions are kept as-is.

        Returns:
            Complex tensor with the first two dimensions of ``waveform``
            followed by the dimensions of ``windows_range_indices``.
        """
        frequencies = torch.fft.fft(waveform)
        # The DFT spectrum is periodic, so wrap indices the same way `decode`
        # does; for in-range negative indices this equals plain indexing.
        indices = self.windows_range_indices % frequencies.shape[-1]  # type: ignore # noqa
        crops = frequencies[:, :, indices]
        # Broadcasted multiply applies window row k to crop k and promotes
        # the real windows to the complex dtype of the crops.
        crops_windowed = crops * self.windows
        return torch.fft.ifft(crops_windowed)

    def decode(self, transform: Tensor) -> Tensor:
        """Rebuild a waveform from coefficients produced by ``encode``.

        The per-bin crops are accumulated back into a spectrum of
        ``block_length`` bins and inverse-transformed. The windows are not
        undone yet (see TODO), so this is not an exact inverse in general.

        Args:
            transform: Tensor whose first two dimensions are kept and whose
                remaining dimensions match ``windows_range_indices``.

        Returns:
            Complex tensor of shape ``(b, c, block_length)``.
        """
        b, c, length = *transform.shape[0:2], self.block_length
        crops_windowed = torch.fft.fft(transform)
        # TODO: undo the windowing with `self.windows_inverse` before
        # accumulating; currently the windowed crops are used unchanged.
        crops_unwindowed = crops_windowed
        frequencies = torch.zeros(
            b, c, length, dtype=transform.dtype, device=transform.device
        )
        # Wrap indices into [0, length) so negative / overflowing positions
        # land on the periodic spectrum; overlapping crops are summed.
        index = self.windows_range_indices.reshape(-1).expand(b, c, -1) % length  # type: ignore # noqa
        frequencies.scatter_add_(
            dim=-1,
            index=index,
            src=crops_unwindowed.reshape(b, c, -1),
        )
        return torch.fft.ifft(frequencies)
0 commit comments