mel-computations.cc 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
/**
 * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// This file is copied/modified from kaldi/src/feat/mel-computations.cc
#include "mel-computations.h"

#include <algorithm>
#include <cstdio>
#include <sstream>
#include <vector>

#include "feature-window.h"
  24. namespace knf {
  25. std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts) {
  26. os << opts.ToString();
  27. return os;
  28. }
  29. float MelBanks::VtlnWarpFreq(
  30. float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN.
  31. float vtln_high_cutoff,
  32. float low_freq, // upper+lower frequency cutoffs in mel computation
  33. float high_freq, float vtln_warp_factor, float freq) {
  34. /// This computes a VTLN warping function that is not the same as HTK's one,
  35. /// but has similar inputs (this function has the advantage of never producing
  36. /// empty bins).
  37. /// This function computes a warp function F(freq), defined between low_freq
  38. /// and high_freq inclusive, with the following properties:
  39. /// F(low_freq) == low_freq
  40. /// F(high_freq) == high_freq
  41. /// The function is continuous and piecewise linear with two inflection
  42. /// points.
  43. /// The lower inflection point (measured in terms of the unwarped
  44. /// frequency) is at frequency l, determined as described below.
  45. /// The higher inflection point is at a frequency h, determined as
  46. /// described below.
  47. /// If l <= f <= h, then F(f) = f/vtln_warp_factor.
  48. /// If the higher inflection point (measured in terms of the unwarped
  49. /// frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
  50. /// Since (by the last point) F(h) == h/vtln_warp_factor, then
  51. /// max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
  52. /// h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
  53. /// = vtln_high_cutoff * min(1, vtln_warp_factor).
  54. /// If the lower inflection point (measured in terms of the unwarped
  55. /// frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
  56. /// This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
  57. /// = vtln_low_cutoff * max(1, vtln_warp_factor)
  58. if (freq < low_freq || freq > high_freq)
  59. return freq; // in case this gets called
  60. // for out-of-range frequencies, just return the freq.
  61. KNF_CHECK_GT(vtln_low_cutoff, low_freq);
  62. KNF_CHECK_LT(vtln_high_cutoff, high_freq);
  63. float one = 1.0f;
  64. float l = vtln_low_cutoff * std::max(one, vtln_warp_factor);
  65. float h = vtln_high_cutoff * std::min(one, vtln_warp_factor);
  66. float scale = 1.0f / vtln_warp_factor;
  67. float Fl = scale * l; // F(l);
  68. float Fh = scale * h; // F(h);
  69. KNF_CHECK(l > low_freq && h < high_freq);
  70. // slope of left part of the 3-piece linear function
  71. float scale_left = (Fl - low_freq) / (l - low_freq);
  72. // [slope of center part is just "scale"]
  73. // slope of right part of the 3-piece linear function
  74. float scale_right = (high_freq - Fh) / (high_freq - h);
  75. if (freq < l) {
  76. return low_freq + scale_left * (freq - low_freq);
  77. } else if (freq < h) {
  78. return scale * freq;
  79. } else { // freq >= h
  80. return high_freq + scale_right * (freq - high_freq);
  81. }
  82. }
  83. float MelBanks::VtlnWarpMelFreq(
  84. float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN.
  85. float vtln_high_cutoff,
  86. float low_freq, // upper+lower frequency cutoffs in mel computation
  87. float high_freq, float vtln_warp_factor, float mel_freq) {
  88. return MelScale(VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, low_freq,
  89. high_freq, vtln_warp_factor,
  90. InverseMelScale(mel_freq)));
  91. }
  92. MelBanks::MelBanks(const MelBanksOptions &opts,
  93. const FrameExtractionOptions &frame_opts,
  94. float vtln_warp_factor)
  95. : htk_mode_(opts.htk_mode) {
  96. int32_t num_bins = opts.num_bins;
  97. if (num_bins < 3) KNF_LOG(FATAL) << "Must have at least 3 mel bins";
  98. float sample_freq = frame_opts.samp_freq;
  99. int32_t window_length_padded = frame_opts.PaddedWindowSize();
  100. KNF_CHECK_EQ(window_length_padded % 2, 0);
  101. int32_t num_fft_bins = window_length_padded / 2;
  102. float nyquist = 0.5f * sample_freq;
  103. float low_freq = opts.low_freq, high_freq;
  104. if (opts.high_freq > 0.0f)
  105. high_freq = opts.high_freq;
  106. else
  107. high_freq = nyquist + opts.high_freq;
  108. if (low_freq < 0.0f || low_freq >= nyquist || high_freq <= 0.0f ||
  109. high_freq > nyquist || high_freq <= low_freq) {
  110. KNF_LOG(FATAL) << "Bad values in options: low-freq " << low_freq
  111. << " and high-freq " << high_freq << " vs. nyquist "
  112. << nyquist;
  113. }
  114. float fft_bin_width = sample_freq / window_length_padded;
  115. // fft-bin width [think of it as Nyquist-freq / half-window-length]
  116. float mel_low_freq = MelScale(low_freq);
  117. float mel_high_freq = MelScale(high_freq);
  118. debug_ = opts.debug_mel;
  119. // divide by num_bins+1 in next line because of end-effects where the bins
  120. // spread out to the sides.
  121. float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);
  122. float vtln_low = opts.vtln_low, vtln_high = opts.vtln_high;
  123. if (vtln_high < 0.0f) {
  124. vtln_high += nyquist;
  125. }
  126. if (vtln_warp_factor != 1.0f &&
  127. (vtln_low < 0.0f || vtln_low <= low_freq || vtln_low >= high_freq ||
  128. vtln_high <= 0.0f || vtln_high >= high_freq || vtln_high <= vtln_low)) {
  129. KNF_LOG(FATAL) << "Bad values in options: vtln-low " << vtln_low
  130. << " and vtln-high " << vtln_high << ", versus "
  131. << "low-freq " << low_freq << " and high-freq " << high_freq;
  132. }
  133. bins_.resize(num_bins);
  134. center_freqs_.resize(num_bins);
  135. for (int32_t bin = 0; bin < num_bins; ++bin) {
  136. float left_mel = mel_low_freq + bin * mel_freq_delta,
  137. center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
  138. right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
  139. if (vtln_warp_factor != 1.0f) {
  140. left_mel = VtlnWarpMelFreq(vtln_low, vtln_high, low_freq, high_freq,
  141. vtln_warp_factor, left_mel);
  142. center_mel = VtlnWarpMelFreq(vtln_low, vtln_high, low_freq, high_freq,
  143. vtln_warp_factor, center_mel);
  144. right_mel = VtlnWarpMelFreq(vtln_low, vtln_high, low_freq, high_freq,
  145. vtln_warp_factor, right_mel);
  146. }
  147. center_freqs_[bin] = InverseMelScale(center_mel);
  148. // this_bin will be a vector of coefficients that is only
  149. // nonzero where this mel bin is active.
  150. std::vector<float> this_bin(num_fft_bins);
  151. int32_t first_index = -1, last_index = -1;
  152. for (int32_t i = 0; i < num_fft_bins; ++i) {
  153. float freq = (fft_bin_width * i); // Center frequency of this fft
  154. // bin.
  155. float mel = MelScale(freq);
  156. if (mel > left_mel && mel < right_mel) {
  157. float weight;
  158. if (mel <= center_mel)
  159. weight = (mel - left_mel) / (center_mel - left_mel);
  160. else
  161. weight = (right_mel - mel) / (right_mel - center_mel);
  162. this_bin[i] = weight;
  163. if (first_index == -1) first_index = i;
  164. last_index = i;
  165. }
  166. }
  167. KNF_CHECK(first_index != -1 && last_index >= first_index &&
  168. "You may have set num_mel_bins too large.");
  169. bins_[bin].first = first_index;
  170. int32_t size = last_index + 1 - first_index;
  171. bins_[bin].second.insert(bins_[bin].second.end(),
  172. this_bin.begin() + first_index,
  173. this_bin.begin() + first_index + size);
  174. // Replicate a bug in HTK, for testing purposes.
  175. if (opts.htk_mode && bin == 0 && mel_low_freq != 0.0f) {
  176. bins_[bin].second[0] = 0.0;
  177. }
  178. } // for (int32_t bin = 0; bin < num_bins; ++bin) {
  179. if (debug_) {
  180. std::ostringstream os;
  181. for (size_t i = 0; i < bins_.size(); i++) {
  182. os << "bin " << i << ", offset = " << bins_[i].first << ", vec = ";
  183. for (auto k : bins_[i].second) os << k << ", ";
  184. os << "\n";
  185. }
  186. KNF_LOG(INFO) << os.str();
  187. }
  188. }
  189. // "power_spectrum" contains fft energies.
  190. void MelBanks::Compute(const float *power_spectrum,
  191. float *mel_energies_out) const {
  192. int32_t num_bins = bins_.size();
  193. for (int32_t i = 0; i < num_bins; i++) {
  194. int32_t offset = bins_[i].first;
  195. const auto &v = bins_[i].second;
  196. float energy = 0;
  197. for (int32_t k = 0; k != v.size(); ++k) {
  198. energy += v[k] * power_spectrum[k + offset];
  199. }
  200. // HTK-like flooring- for testing purposes (we prefer dither)
  201. if (htk_mode_ && energy < 1.0) {
  202. energy = 1.0;
  203. }
  204. mel_energies_out[i] = energy;
  205. // The following assert was added due to a problem with OpenBlas that
  206. // we had at one point (it was a bug in that library). Just to detect
  207. // it early.
  208. KNF_CHECK_EQ(energy, energy); // check that energy is not nan
  209. }
  210. if (debug_) {
  211. fprintf(stderr, "MEL BANKS:\n");
  212. for (int32_t i = 0; i < num_bins; i++)
  213. fprintf(stderr, " %f", mel_energies_out[i]);
  214. fprintf(stderr, "\n");
  215. }
  216. }
  217. } // namespace knf