// wmv_decoder/wma/decoder.rs

1use std::f32;
2
3use crate::asf::AudioStreamInfo;
4use crate::error::{DecoderError, Result};
5
6use super::bitstream::GetBitContext;
7use super::common::ff_wma_get_frame_len_bits;
8use super::mdct::MdctNaive;
9use super::tables;
10use super::vlc::{ff_vlc_init_from_lengths, ff_vlc_init_sparse, get_vlc2, Vlc, VlcElem};
11
/// log2 of the shortest MDCT block (128 samples).
const BLOCK_MIN_BITS: i32 = 7;
/// log2 of the longest MDCT block (2048 samples).
const BLOCK_MAX_BITS: i32 = 11;
/// Longest MDCT block, in samples.
const BLOCK_MAX_SIZE: usize = 1 << BLOCK_MAX_BITS;
/// Count of block sizes from `BLOCK_MIN_BITS` to `BLOCK_MAX_BITS`, inclusive.
const BLOCK_NB_SIZES: usize = (BLOCK_MAX_BITS - BLOCK_MIN_BITS + 1) as usize;

/// Capacity of the high-band tables.
const HIGH_BAND_MAX_SIZE: usize = 16;
/// Number of LSP coefficients (consumed in pairs by `wma_lsp_to_curve_tables`).
const NB_LSP_COEFS: usize = 10;

/// Maximum number of bytes buffered in the bit reservoir.
const MAX_CODED_SUPERFRAME_SIZE: usize = 32768;
/// Mono or stereo only.
const MAX_CHANNELS: usize = 2;

/// Entries in the pseudo-random noise table.
const NOISE_TAB_SIZE: usize = 8192;
/// Mantissa bits selecting an interpolation segment in `pow_m1_4_tables`.
const LSP_POW_BITS: usize = 7;

/// `ceil(max_code_len / table_bits)`: how many `table_bits`-wide lookups cover a code of
/// `max_code_len` bits (presumably the `max_depth` handed to `get_vlc2` — confirm).
const fn vlc_lookup_depth(max_code_len: i32, table_bits: i32) -> i32 {
    (max_code_len + table_bits - 1) / table_bits
}

/// Coefficient VLC first-level table width (22 is presumably the longest code, in bits).
const VLCBITS: i32 = 9;
const VLCMAX: i32 = vlc_lookup_depth(22, VLCBITS);

/// Exponent VLC first-level table width (19 is presumably the longest code, in bits).
const EXPVLCBITS: i32 = 8;
const EXPMAX: i32 = vlc_lookup_depth(19, EXPVLCBITS);

/// High-band gain VLC first-level table width (13 is presumably the longest code, in bits).
const HGAINVLCBITS: i32 = 9;
const HGAINMAX: i32 = vlc_lookup_depth(13, HGAINVLCBITS);
34
/// A decoded PCM chunk.
///
/// Produced by `WmaDecoder::decode_packet`; `samples.len()` is a multiple of `channels`.
#[derive(Debug, Clone)]
pub struct PcmFrameF32 {
    /// Timestamp in milliseconds, copied unchanged from the `pts_ms` passed to `decode_packet`.
    pub pts_ms: u32,
    /// Sample rate of the stream.
    pub sample_rate: u32,
    /// Number of interleaved channels (1 or 2).
    pub channels: u16,
    /// Interleaved samples.
    pub samples: Vec<f32>,
}
44
/// WMA bitstream generation handled by this decoder.
#[derive(Clone, Copy, Debug)]
enum WmaVersion {
    /// WMAv1 (ASF format tag 0x0160).
    V1,
    /// WMAv2 (ASF format tag 0x0161).
    V2,
}

impl WmaVersion {
    /// Numeric version (1 or 2), as passed to `ff_wma_get_frame_len_bits` and compared
    /// against in `ff_wma_init`.
    fn id(&self) -> i32 {
        match *self {
            Self::V1 => 1,
            Self::V2 => 2,
        }
    }
}
59
/// Direct translation of upstream `WMACodecContext` for WMAv1/2.
///
/// Holds the stream parameters, the frame/block geometry derived in `ff_wma_init`, the
/// VLC tables, per-channel working buffers and the bit reservoir carried between packets.
pub struct WmaDecoder {
    // Codec generation, chosen from the ASF format tag in `new`.
    version: WmaVersion,

    // Stream parameters copied from `AudioStreamInfo`; `new` only accepts 1 or 2 channels.
    channels: usize,
    sample_rate: u32,
    bit_rate: u32,
    // Packet size in bytes: `decode_packet` rejects shorter input and ignores extra bytes.
    block_align: u16,

    // Flags derived from `flags2`.
    // Bits 0x0001 / 0x0002 / 0x0004 respectively; `use_variable_block_len` may be cleared
    // by the issue1503 quirk, and `use_noise_coding` is decided in `ff_wma_init`.
    use_exp_vlc: bool,
    use_bit_reservoir: bool,
    use_variable_block_len: bool,
    use_noise_coding: bool,

    // `byte_offset_bits + 3` is the width of the last-frame bit offset read in `decode_packet`.
    byte_offset_bits: i32,

    // VLC tables.
    exp_vlc: Vlc,
    hgain_vlc: Vlc,
    coef_vlc: [Vlc; 2],
    // Per coefficient VLC: run and level of each symbol, filled by `init_coef_vlc`.
    run_table: [Vec<u16>; 2],
    level_table: [Vec<f32>; 2],

    // Frame / block config.
    // `frame_len == 1 << frame_len_bits`; `nb_block_sizes` counts the block lengths
    // `frame_len >> k` in use (1 when variable block length is disabled).
    frame_len_bits: i32,
    frame_len: usize,
    nb_block_sizes: usize,

    // Block sequencing state: log2 lengths of the current/next/previous block plus
    // position; set to `frame_len_bits` by `ff_wma_init`.
    reset_block_lengths: bool,
    block_len_bits: i32,
    next_block_len_bits: i32,
    prev_block_len_bits: i32,
    block_len: usize,
    block_num: i32,
    block_pos: usize,

    ms_stereo: bool,
    channel_coded: [bool; MAX_CHANNELS],

    // Exponent bands.
    // Indexed by block-size index `k` (block length `frame_len >> k`); computed in `ff_wma_init`.
    exponent_sizes: [usize; BLOCK_NB_SIZES],
    exponent_bands: [[u16; 25]; BLOCK_NB_SIZES],
    high_band_start: [usize; BLOCK_NB_SIZES],
    coefs_start: usize,
    coefs_end: [usize; BLOCK_NB_SIZES],
    exponent_high_sizes: [usize; BLOCK_NB_SIZES],
    exponent_high_bands: [[u16; HIGH_BAND_MAX_SIZE]; BLOCK_NB_SIZES],

    high_band_coded: [[bool; HIGH_BAND_MAX_SIZE]; MAX_CHANNELS],
    high_band_values: [[i32; HIGH_BAND_MAX_SIZE]; MAX_CHANNELS],

    // Exponents and coefficients.
    exponents_bsize: [usize; MAX_CHANNELS],
    exponents: [Vec<f32>; MAX_CHANNELS],
    max_exponent: [f32; MAX_CHANNELS],
    coefs1: [Vec<f32>; MAX_CHANNELS],
    coefs: [Vec<f32>; MAX_CHANNELS],

    // MDCT.
    // One transform per block size, largest first (built in `wma_decode_init`).
    mdct: Vec<MdctNaive>,
    windows: Vec<Vec<f32>>, // per block-size: half window of length block_len
    output: Vec<f32>,       // 2*BLOCK_MAX_SIZE
    // Per-channel overlap buffers; the first `frame_len` samples are emitted on flush.
    frame_out: [Vec<f32>; MAX_CHANNELS],

    // Bit reservoir.
    // Bytes of a frame that continues into the next packet.
    last_superframe: Vec<u8>,
    // Bit position inside the first reservoir byte where that frame starts.
    last_bitoffset: usize,
    // Fill level of `last_superframe`, used as a byte index by `decode_packet`.
    last_superframe_len: usize,
    // Set once the end-of-stream flush has been returned.
    eof_done: bool,

    // Noise.
    // Filled in `ff_wma_init` when noise coding is enabled.
    noise_table: Vec<f32>,
    noise_index: usize,
    noise_mult: f32,

    // LSP to curve.
    // Inputs of `wma_lsp_to_curve_tables` / `pow_m1_4_tables`; presumably filled by
    // `wma_lsp_to_curve_init` (not visible here — confirm).
    lsp_cos_table: Vec<f32>,
    lsp_pow_e_table: [f32; 256],
    lsp_pow_m_table1: [f32; 1 << LSP_POW_BITS],
    lsp_pow_m_table2: [f32; 1 << LSP_POW_BITS],

    exponents_initialized: [bool; MAX_CHANNELS],
}
144
/// Floor of `log2(x)`; yields `-1` for `x == 0`.
fn ilog2_u32(x: u32) -> i32 {
    let top_bit_index = u32::BITS as i32 - 1;
    top_bit_index - x.leading_zeros() as i32
}
148
/// `10^x` in single precision, evaluated as `2^(x * log2(10))` like upstream `ff_exp10f`.
fn ff_exp10f(x: f32) -> f32 {
    let power_of_two = x * std::f32::consts::LOG2_10;
    power_of_two.exp2()
}
153
/// Builds the `n`-point sine window `w[i] = sin((i + 0.5) * PI / (2n))` (upstream
/// `ff_sine_window_init`); values rise towards 1 as `i` grows.
fn sine_window_init(n: usize) -> Vec<f32> {
    let step = std::f32::consts::PI / (2.0f32 * n as f32);
    (0..n).map(|k| ((k as f32 + 0.5) * step).sin()).collect()
}
163
164
/// `dst[i] = src0[i] * win[len - 1 - i]` with `len = dst.len()` (upstream
/// `vector_fmul_reverse`): multiplies by the first `len` window values read backwards.
///
/// Panics if `src0` or `win` holds fewer than `dst.len()` values.
fn vector_fmul_reverse(dst: &mut [f32], src0: &[f32], win: &[f32]) {
    let len = dst.len();
    let backwards = win[..len].iter().rev();
    for ((d, &s), &w) in dst.iter_mut().zip(&src0[..len]).zip(backwards) {
        *d = s * w;
    }
}
171
/// In-place butterfly over the first `v1.len()` elements:
/// `(v1[i], v2[i]) <- (v1[i] + v2[i], v1[i] - v2[i])`.
///
/// Panics if `v2` is shorter than `v1`.
fn butterflies_float(v1: &mut [f32], v2: &mut [f32]) {
    let len = v1.len();
    for (sum, diff) in v1.iter_mut().zip(&mut v2[..len]) {
        let d = *sum - *diff;
        *sum += *diff;
        *diff = d;
    }
}
179
180
181fn pow_m1_4_tables(
182    x: f32,
183    lsp_pow_e_table: &[f32; 256],
184    lsp_pow_m_table1: &[f32; 1 << LSP_POW_BITS],
185    lsp_pow_m_table2: &[f32; 1 << LSP_POW_BITS],
186) -> f32 {
187    // Direct translation of `pow_m1_4` from upstream wmadec.c, but parameterized to avoid borrowing `self`.
188    let u = x.to_bits();
189    let e = (u >> 23) as usize;
190    let m = ((u >> (23 - LSP_POW_BITS)) & ((1 << LSP_POW_BITS) - 1) as u32) as usize;
191    let t_bits = ((u << LSP_POW_BITS) & ((1 << 23) - 1)) | (127 << 23);
192    let t = f32::from_bits(t_bits);
193    let a = lsp_pow_m_table1[m];
194    let b = lsp_pow_m_table2[m];
195    lsp_pow_e_table[e] * (a + b * t)
196}
197
198fn wma_lsp_to_curve_tables(
199    out: &mut [f32],
200    n: usize,
201    lsp: &[f32; NB_LSP_COEFS],
202    lsp_cos_table: &[f32],
203    lsp_pow_e_table: &[f32; 256],
204    lsp_pow_m_table1: &[f32; 1 << LSP_POW_BITS],
205    lsp_pow_m_table2: &[f32; 1 << LSP_POW_BITS],
206) -> f32 {
207    // Direct translation of `wma_lsp_to_curve` from upstream wmadec.c, parameterized to avoid borrowing `self`.
208    let mut val_max = 0.0f32;
209    for i in 0..n {
210        let mut p = 0.5f32;
211        let mut q = 0.5f32;
212        let w = lsp_cos_table[i];
213        let mut j = 1usize;
214        while j < NB_LSP_COEFS {
215            q *= w - lsp[j - 1];
216            p *= w - lsp[j];
217            j += 2;
218        }
219        p *= p * (2.0f32 - w);
220        q *= q * (2.0f32 + w);
221        let mut v = p + q;
222        v = pow_m1_4_tables(v, lsp_pow_e_table, lsp_pow_m_table1, lsp_pow_m_table2);
223        if v > val_max {
224            val_max = v;
225        }
226        out[i] = v;
227    }
228    val_max
229}
230
231
232
233
234fn wma_window_apply(
235    out: &mut [f32],
236    output: &[f32],
237    windows: &[Vec<f32>],
238    frame_len_bits: i32,
239    block_len_bits: i32,
240    prev_block_len_bits: i32,
241    next_block_len_bits: i32,
242    block_len: usize,
243) {
244    // Direct translation of upstream `wma_window`, but parameterized to avoid borrowing `self`.
245    let mut in_buf: &[f32] = output;
246
247    // Left part.
248    if block_len_bits <= prev_block_len_bits {
249        let bsize = (frame_len_bits - block_len_bits) as usize;
250        let win = &windows[bsize];
251        for i in 0..block_len {
252            out[i] = in_buf[i] * win[i] + out[i];
253        }
254    } else {
255        let prev_len = 1usize << prev_block_len_bits;
256        let n = (block_len - prev_len) / 2;
257        let bsize = (frame_len_bits - prev_block_len_bits) as usize;
258        let win = &windows[bsize];
259        for i in 0..prev_len {
260            let idx = n + i;
261            out[idx] = in_buf[idx] * win[i] + out[idx];
262        }
263        out[n + prev_len..n + prev_len + n].copy_from_slice(&in_buf[n + prev_len..n + prev_len + n]);
264    }
265
266    // Right part.
267    let out2 = &mut out[block_len..];
268    in_buf = &in_buf[block_len..];
269
270    if block_len_bits <= next_block_len_bits {
271        let bsize = (frame_len_bits - block_len_bits) as usize;
272        vector_fmul_reverse(&mut out2[..block_len], &in_buf[..block_len], &windows[bsize]);
273    } else {
274        let next_len = 1usize << next_block_len_bits;
275        let n = (block_len - next_len) / 2;
276        let bsize = (frame_len_bits - next_block_len_bits) as usize;
277        out2[n + next_len..n + next_len + n].copy_from_slice(&in_buf[n + next_len..n + next_len + n]);
278        vector_fmul_reverse(&mut out2[n..n + next_len], &in_buf[n..n + next_len], &windows[bsize]);
279    }
280}
281
282
283
284impl WmaDecoder {
285    pub fn new(info: &AudioStreamInfo) -> Result<Self> {
286        let version = match info.format_tag {
287            0x0160 => WmaVersion::V1,
288            0x0161 => WmaVersion::V2,
289            _ => return Err(DecoderError::Unsupported(format!("unsupported WMA format tag: 0x{:04x}", info.format_tag))),
290        };
291
292        if info.block_align == 0 {
293            return Err(DecoderError::InvalidData("block_align is not set".into()));
294        }
295
296        let channels = info.channels as usize;
297        if channels == 0 || channels > MAX_CHANNELS {
298            return Err(DecoderError::Unsupported("only mono/stereo supported".into()));
299        }
300
301        // Extract flags2 like upstream.
302        let mut flags2: u16 = 0;
303        let extradata = &info.extra_data;
304        match version {
305            WmaVersion::V1 => {
306                if extradata.len() >= 4 {
307                    flags2 = u16::from_le_bytes([extradata[2], extradata[3]]);
308                }
309            }
310            WmaVersion::V2 => {
311                if extradata.len() >= 6 {
312                    flags2 = u16::from_le_bytes([extradata[4], extradata[5]]);
313                }
314            }
315        }
316
317        let mut use_variable_block_len = (flags2 & 0x0004) != 0;
318        let use_exp_vlc = (flags2 & 0x0001) != 0;
319        let use_bit_reservoir = (flags2 & 0x0002) != 0;
320
321        // upstream quirk (issue1503).
322        if let WmaVersion::V2 = version {
323            if extradata.len() >= 8 {
324                let v = u16::from_le_bytes([extradata[4], extradata[5]]);
325                if v == 0x000d && use_variable_block_len {
326                    use_variable_block_len = false;
327                }
328            }
329        }
330
331        // Pre-init fixed fields.
332        let mut dec = Self {
333            version,
334            channels,
335            sample_rate: info.sample_rate,
336            bit_rate: info.bit_rate,
337            block_align: info.block_align,
338
339            use_exp_vlc,
340            use_bit_reservoir,
341            use_variable_block_len,
342            use_noise_coding: true,
343
344            byte_offset_bits: 0,
345
346            exp_vlc: Vlc::default(),
347            hgain_vlc: Vlc::default(),
348            coef_vlc: [Vlc::default(), Vlc::default()],
349            run_table: [Vec::new(), Vec::new()],
350            level_table: [Vec::new(), Vec::new()],
351
352            frame_len_bits: 0,
353            frame_len: 0,
354            nb_block_sizes: 0,
355
356            reset_block_lengths: true,
357            block_len_bits: 0,
358            next_block_len_bits: 0,
359            prev_block_len_bits: 0,
360            block_len: 0,
361            block_num: 0,
362            block_pos: 0,
363
364            ms_stereo: false,
365            channel_coded: [false; MAX_CHANNELS],
366
367            exponent_sizes: [0usize; BLOCK_NB_SIZES],
368            exponent_bands: [[0u16; 25]; BLOCK_NB_SIZES],
369            high_band_start: [0usize; BLOCK_NB_SIZES],
370            coefs_start: 0,
371            coefs_end: [0usize; BLOCK_NB_SIZES],
372            exponent_high_sizes: [0usize; BLOCK_NB_SIZES],
373            exponent_high_bands: [[0u16; HIGH_BAND_MAX_SIZE]; BLOCK_NB_SIZES],
374
375            high_band_coded: [[false; HIGH_BAND_MAX_SIZE]; MAX_CHANNELS],
376            high_band_values: [[0i32; HIGH_BAND_MAX_SIZE]; MAX_CHANNELS],
377
378            exponents_bsize: [0usize; MAX_CHANNELS],
379            exponents: [vec![0f32; BLOCK_MAX_SIZE], vec![0f32; BLOCK_MAX_SIZE]],
380            max_exponent: [1.0f32; MAX_CHANNELS],
381            coefs1: [vec![0f32; BLOCK_MAX_SIZE], vec![0f32; BLOCK_MAX_SIZE]],
382            coefs: [vec![0f32; BLOCK_MAX_SIZE], vec![0f32; BLOCK_MAX_SIZE]],
383
384            mdct: Vec::new(),
385            windows: Vec::new(),
386            output: vec![0f32; BLOCK_MAX_SIZE * 2],
387            frame_out: [vec![0f32; BLOCK_MAX_SIZE * 2], vec![0f32; BLOCK_MAX_SIZE * 2]],
388
389            last_superframe: vec![0u8; MAX_CODED_SUPERFRAME_SIZE + 64],
390            last_bitoffset: 0,
391            last_superframe_len: 0,
392            eof_done: false,
393
394            noise_table: vec![0f32; NOISE_TAB_SIZE],
395            noise_index: 0,
396            noise_mult: 0.0,
397
398            lsp_cos_table: vec![0f32; BLOCK_MAX_SIZE],
399            lsp_pow_e_table: [0f32; 256],
400            lsp_pow_m_table1: [0f32; 1 << LSP_POW_BITS],
401            lsp_pow_m_table2: [0f32; 1 << LSP_POW_BITS],
402
403            exponents_initialized: [false; MAX_CHANNELS],
404        };
405
406        // Full init = ff_wma_init + wma_decode_init bits.
407        dec.ff_wma_init(flags2 as i32)?;
408        dec.wma_decode_init(flags2 as i32)?;
409
410        Ok(dec)
411    }
412
413    pub fn sample_rate(&self) -> u32 {
414        self.sample_rate
415    }
416
417    pub fn channels(&self) -> u16 {
418        self.channels as u16
419    }
420
421    pub fn frame_len(&self) -> usize {
422        self.frame_len
423    }
424
425    /// Decode one ASF packet payload (usually `block_align` bytes).
426    pub fn decode_packet(&mut self, pkt: &[u8], pts_ms: u32) -> Result<Option<PcmFrameF32>> {
427        if pkt.is_empty() {
428            if self.eof_done {
429                return Ok(None);
430            }
431            // Flush delayed samples.
432            self.eof_done = true;
433            let mut out = Vec::with_capacity(self.frame_len * self.channels);
434            for i in 0..self.frame_len {
435                for ch in 0..self.channels {
436                    out.push(self.frame_out[ch][i]);
437                }
438            }
439            self.last_superframe_len = 0;
440            return Ok(Some(PcmFrameF32 {
441                pts_ms,
442                sample_rate: self.sample_rate,
443                channels: self.channels as u16,
444                samples: out,
445            }));
446        }
447
448        if pkt.len() < self.block_align as usize {
449            return Err(DecoderError::InvalidData(format!(
450                "Input packet size too small ({} < {})",
451                pkt.len(),
452                self.block_align
453            )));
454        }
455
456        let buf = &pkt[..self.block_align as usize];
457
458        let mut gb = GetBitContext::new(buf);
459
460        let mut nb_frames: i32;
461
462        if self.use_bit_reservoir {
463            // super frame header
464            gb.skip_bits(4)?; // super frame index
465            let mut nf = gb.get_bits(4)? as i32;
466            nf -= if self.last_superframe_len <= 0 { 1 } else { 0 };
467            nb_frames = nf;
468            if nb_frames <= 0 {
469                let is_error = nb_frames < 0 || gb.bits_left() <= 8;
470                if is_error {
471                    return Err(DecoderError::InvalidData(format!(
472                        "nb_frames is {nb_frames} bits left {}",
473                        gb.bits_left()
474                    )));
475                }
476
477                if self.last_superframe_len + buf.len() - 1 > MAX_CODED_SUPERFRAME_SIZE {
478                    return Err(DecoderError::InvalidData("bit reservoir overflow".into()));
479                }
480
481                let mut q = self.last_superframe_len;
482                let mut len = buf.len() - 1;
483                while len > 0 {
484                    let b = gb.get_bits(8)? as u8;
485                    self.last_superframe[q] = b;
486                    q += 1;
487                    len -= 1;
488                }
489
490                self.last_superframe_len += 8 * buf.len() - 8;
491                return Ok(None);
492            }
493        } else {
494            nb_frames = 1;
495        }
496
497        // Planar output like upstream, then interleave.
498        let mut samples: [Vec<f32>; MAX_CHANNELS] = [Vec::new(), Vec::new()];
499        for ch in 0..self.channels {
500            samples[ch].resize(nb_frames as usize * self.frame_len, 0f32);
501        }
502        let mut samples_offset: usize = 0;
503
504        if self.use_bit_reservoir {
505            let bit_offset = gb.get_bits((self.byte_offset_bits + 3) as usize)? as usize;
506            if bit_offset as isize > gb.bits_left() {
507                return Err(DecoderError::InvalidData("Invalid last frame bit offset".into()));
508            }
509
510            if self.last_superframe_len > 0 {
511                // Add `bit_offset` bits to last frame.
512                let add_bytes = (bit_offset + 7) >> 3;
513                if self.last_superframe_len + add_bytes > MAX_CODED_SUPERFRAME_SIZE {
514                    return Err(DecoderError::InvalidData("bit reservoir overflow".into()));
515                }
516
517                let mut q = self.last_superframe_len;
518                let mut len = bit_offset;
519                while len > 7 {
520                    self.last_superframe[q] = gb.get_bits(8)? as u8;
521                    q += 1;
522                    len -= 8;
523                }
524                if len > 0 {
525                    self.last_superframe[q] = (gb.get_bits(len)? as u8) << (8 - len);
526                }
527
528                // Decode the previous frame.
529                let total_bits = self.last_superframe_len * 8 + bit_offset;
530                let need_bytes = (total_bits + 7) / 8;
531                // Avoid borrowing `self` across the decode call.
532                let sf_bytes: Vec<u8> = self.last_superframe[..need_bytes].to_vec();
533                let mut gb2 = GetBitContext::new(&sf_bytes);
534                if self.last_bitoffset > 0 {
535                    gb2.skip_bits(self.last_bitoffset)?;
536                }
537                self.reset_block_lengths = true;
538                self.wma_decode_frame(&mut gb2, &mut samples, samples_offset)?;
539                samples_offset += self.frame_len;
540                nb_frames -= 1;
541            }
542
543            // Read each frame starting from bit_offset.
544            let pos = bit_offset + 4 + 4 + (self.byte_offset_bits as usize) + 3;
545            if pos >= MAX_CODED_SUPERFRAME_SIZE * 8 || pos > buf.len() * 8 {
546                return Err(DecoderError::InvalidData("invalid superframe pos".into()));
547            }
548
549            let start_byte = pos >> 3;
550            let mut gb3 = GetBitContext::new(&buf[start_byte..]);
551            let rem = pos & 7;
552            if rem > 0 {
553                gb3.skip_bits(rem)?;
554            }
555
556            self.reset_block_lengths = true;
557            for _ in 0..nb_frames {
558                self.wma_decode_frame(&mut gb3, &mut samples, samples_offset)?;
559                samples_offset += self.frame_len;
560            }
561
562            // Copy end of frame into last frame buffer.
563            let consumed_bits = gb3.bits_read();
564            let mut pos2 = consumed_bits + ((bit_offset + 4 + 4 + (self.byte_offset_bits as usize) + 3) & !7);
565            self.last_bitoffset = pos2 & 7;
566            pos2 >>= 3;
567            let len = buf.len().saturating_sub(pos2);
568            if len > MAX_CODED_SUPERFRAME_SIZE {
569                return Err(DecoderError::InvalidData("invalid reservoir len".into()));
570            }
571            self.last_superframe_len = len;
572            self.last_superframe[..len].copy_from_slice(&buf[pos2..pos2 + len]);
573        } else {
574            self.reset_block_lengths = true;
575            self.wma_decode_frame(&mut gb, &mut samples, samples_offset)?;
576            samples_offset += self.frame_len;
577        }
578
579        // Interleave.
580        let total_samples = samples_offset * self.channels;
581        let mut out = Vec::with_capacity(total_samples);
582        for i in 0..samples_offset {
583            for ch in 0..self.channels {
584                out.push(samples[ch][i]);
585            }
586        }
587
588        Ok(Some(PcmFrameF32 {
589            pts_ms,
590            sample_rate: self.sample_rate,
591            channels: self.channels as u16,
592            samples: out,
593        }))
594    }
595
596    fn wma_decode_init(&mut self, flags2: i32) -> Result<()> {
597        // Initialize MDCT contexts (naive) like wma_decode_init.
598        let scale = 1.0f64 / 32768.0f64;
599        self.mdct.clear();
600        for i in 0..self.nb_block_sizes {
601            let len = 1usize << (self.frame_len_bits - i as i32);
602            self.mdct.push(MdctNaive::new(len, scale));
603        }
604
605        // Noise/hgain VLC.
606        if self.use_noise_coding {
607            let flat: &[u8] = unsafe {
608                std::slice::from_raw_parts(
609                    tables::FF_WMA_HGAIN_HUFFTAB.as_ptr() as *const u8,
610                    tables::FF_WMA_HGAIN_HUFFTAB.len() * 2,
611                )
612            };
613            let lens: &[i8] = unsafe {
614                std::slice::from_raw_parts(flat.as_ptr().add(1) as *const i8, flat.len() - 1)
615            };
616            ff_vlc_init_from_lengths(
617                &mut self.hgain_vlc,
618                HGAINVLCBITS,
619                tables::FF_WMA_HGAIN_HUFFTAB.len(),
620                lens,
621                2,
622                Some(flat),
623                2,
624                1,
625                -18,
626                0,
627            )?;
628        }
629
630        // Exponent VLC.
631        if self.use_exp_vlc {
632            let bits = &tables::FF_AAC_SCALEFACTOR_BITS;
633            let codes_u32 = &tables::FF_AAC_SCALEFACTOR_CODE;
634            let codes_bytes: &[u8] = unsafe {
635                std::slice::from_raw_parts(codes_u32.as_ptr() as *const u8, codes_u32.len() * 4)
636            };
637
638            ff_vlc_init_sparse(
639                &mut self.exp_vlc,
640                EXPVLCBITS,
641                bits.len(),
642                bits,
643                1,
644                1,
645                codes_bytes,
646                4,
647                4,
648                None,
649                0,
650                0,
651                0,
652            )?;
653        } else {
654            self.wma_lsp_to_curve_init(self.frame_len);
655        }
656
657        // Flags and defaults.
658        let _ = flags2;
659        Ok(())
660    }
661
662    fn ff_wma_init(&mut self, flags2: i32) -> Result<()> {
663        // Validate stream params.
664        if self.sample_rate > 50000 || self.channels > 2 || self.bit_rate == 0 {
665            return Err(DecoderError::InvalidData("invalid audio params".into()));
666        }
667
668        let version_id = self.version.id();
669
670        // Compute MDCT block size.
671        self.frame_len_bits = ff_wma_get_frame_len_bits(self.sample_rate as i32, version_id, 0);
672        self.next_block_len_bits = self.frame_len_bits;
673        self.prev_block_len_bits = self.frame_len_bits;
674        self.block_len_bits = self.frame_len_bits;
675
676        self.frame_len = 1usize << self.frame_len_bits;
677        if self.use_variable_block_len {
678            let mut nb = ((flags2 >> 3) & 3) + 1;
679            if (self.bit_rate / self.channels as u32) >= 32000 {
680                nb += 2;
681            }
682            let nb_max = self.frame_len_bits - BLOCK_MIN_BITS;
683            if nb > nb_max {
684                nb = nb_max;
685            }
686            self.nb_block_sizes = (nb + 1) as usize;
687        } else {
688            self.nb_block_sizes = 1;
689        }
690
691        // Rate dependent params.
692        self.use_noise_coding = true;
693        let mut high_freq = self.sample_rate as f32 * 0.5f32;
694
695        // Version 2 normalized rates.
696        let mut sample_rate1 = self.sample_rate as i32;
697        if version_id == 2 {
698            if sample_rate1 >= 44100 {
699                sample_rate1 = 44100;
700            } else if sample_rate1 >= 22050 {
701                sample_rate1 = 22050;
702            } else if sample_rate1 >= 16000 {
703                sample_rate1 = 16000;
704            } else if sample_rate1 >= 11025 {
705                sample_rate1 = 11025;
706            } else if sample_rate1 >= 8000 {
707                sample_rate1 = 8000;
708            }
709        }
710
711        let bps = (self.bit_rate as f32) / ((self.channels as f32) * (self.sample_rate as f32));
712        let mut bps1 = bps;
713        if self.channels == 2 {
714            bps1 = bps * 1.6f32;
715        }
716
717        let x = (bps * (self.frame_len as f32) / 8.0 + 0.5) as u32;
718        self.byte_offset_bits = ilog2_u32(x.max(1)) + 2;
719
720        // Compute high frequency and noise coding.
721        if sample_rate1 == 44100 {
722            if bps1 >= 0.61 {
723                self.use_noise_coding = false;
724            } else {
725                high_freq *= 0.4;
726            }
727        } else if sample_rate1 == 22050 {
728            if bps1 >= 1.16 {
729                self.use_noise_coding = false;
730            } else if bps1 >= 0.72 {
731                high_freq *= 0.7;
732            } else {
733                high_freq *= 0.6;
734            }
735        } else if sample_rate1 == 16000 {
736            if bps > 0.5 {
737                high_freq *= 0.5;
738            } else {
739                high_freq *= 0.3;
740            }
741        } else if sample_rate1 == 11025 {
742            high_freq *= 0.7;
743        } else if sample_rate1 == 8000 {
744            if bps <= 0.625 {
745                high_freq *= 0.5;
746            } else if bps > 0.75 {
747                self.use_noise_coding = false;
748            } else {
749                high_freq *= 0.65;
750            }
751        } else {
752            if bps >= 0.8 {
753                high_freq *= 0.75;
754            } else if bps >= 0.6 {
755                high_freq *= 0.6;
756            } else {
757                high_freq *= 0.5;
758            }
759        }
760
761        // Compute scale factor band sizes.
762        self.coefs_start = if version_id == 1 { 3 } else { 0 };
763
764        for k in 0..self.nb_block_sizes {
765            let block_len = self.frame_len >> k;
766
767            if version_id == 1 {
768                let mut lpos = 0usize;
769                let mut i = 0usize;
770                for idx in 0..25 {
771                    let a = tables::FF_WMA_CRITICAL_FREQS[idx] as usize;
772                    let b = self.sample_rate as usize;
773                    let mut pos = ((block_len * 2 * a) + (b >> 1)) / b;
774                    if pos > block_len {
775                        pos = block_len;
776                    }
777                    self.exponent_bands[0][idx] = (pos - lpos) as u16;
778                    if pos >= block_len {
779                        i = idx + 1;
780                        break;
781                    }
782                    lpos = pos;
783                    i = idx + 1;
784                }
785                self.exponent_sizes[0] = i;
786            } else {
787                // Hardcoded tables.
788                let a = self.frame_len_bits - BLOCK_MIN_BITS - (k as i32);
789                let mut table_row: Option<&[u8; 25]> = None;
790                if a < 3 {
791                    if self.sample_rate >= 44100 {
792                        table_row = Some(&tables::EXPONENT_BAND_44100[a as usize]);
793                    } else if self.sample_rate >= 32000 {
794                        table_row = Some(&tables::EXPONENT_BAND_32000[a as usize]);
795                    } else if self.sample_rate >= 22050 {
796                        table_row = Some(&tables::EXPONENT_BAND_22050[a as usize]);
797                    }
798                }
799
800                if let Some(row) = table_row {
801                    let n = row[0] as usize;
802                    for i in 0..n {
803                        self.exponent_bands[k][i] = row[1 + i] as u16;
804                    }
805                    self.exponent_sizes[k] = n;
806                } else {
807                    let mut j = 0usize;
808                    let mut lpos = 0usize;
809                    for idx in 0..25 {
810                        let a = tables::FF_WMA_CRITICAL_FREQS[idx] as usize;
811                        let b = self.sample_rate as usize;
812                        let mut pos = ((block_len * 2 * a) + (b << 1)) / (4 * b);
813                        pos <<= 2;
814                        if pos > block_len {
815                            pos = block_len;
816                        }
817                        if pos > lpos {
818                            self.exponent_bands[k][j] = (pos - lpos) as u16;
819                            j += 1;
820                        }
821                        if pos >= block_len {
822                            break;
823                        }
824                        lpos = pos;
825                    }
826                    self.exponent_sizes[k] = j;
827                }
828            }
829
830            self.coefs_end[k] = (self.frame_len - ((self.frame_len * 9) / 100)) >> k;
831            self.high_band_start[k] = (((block_len as f32) * 2.0 * high_freq) / (self.sample_rate as f32) + 0.5) as usize;
832
833            let n = self.exponent_sizes[k];
834            let mut j = 0usize;
835            let mut pos = 0usize;
836            for i in 0..n {
837                let start0 = pos;
838                pos += self.exponent_bands[k][i] as usize;
839                let end0 = pos;
840                let mut start = start0;
841                let mut end = end0;
842                if start < self.high_band_start[k] {
843                    start = self.high_band_start[k];
844                }
845                if end > self.coefs_end[k] {
846                    end = self.coefs_end[k];
847                }
848                if end > start {
849                    self.exponent_high_bands[k][j] = (end - start) as u16;
850                    j += 1;
851                }
852            }
853            self.exponent_high_sizes[k] = j;
854        }
855
856        // Init MDCT windows.
857        self.windows.clear();
858        for i in 0..self.nb_block_sizes {
859            let half = 1usize << (self.frame_len_bits - i as i32);
860            self.windows.push(sine_window_init(half));
861        }
862
863        self.reset_block_lengths = true;
864
865        // Noise table.
866        if self.use_noise_coding {
867            self.noise_mult = if self.use_exp_vlc { 0.02 } else { 0.04 };
868            let mut seed: u32 = 1;
869            let norm = (1.0 / ((1u64 << 31) as f32)) * 3.0f32.sqrt() * self.noise_mult;
870            for i in 0..NOISE_TAB_SIZE {
871                seed = seed.wrapping_mul(314159).wrapping_add(1);
872                self.noise_table[i] = (seed as i32 as f32) * norm;
873            }
874        }
875
876        // Choose coef VLC tables.
877        let mut coef_vlc_table = 2;
878        if self.sample_rate >= 32000 {
879            if bps1 < 0.72 {
880                coef_vlc_table = 0;
881            } else if bps1 < 1.16 {
882                coef_vlc_table = 1;
883            }
884        }
885        let t0 = &tables::COEF_VLCS[coef_vlc_table * 2];
886        let t1 = &tables::COEF_VLCS[coef_vlc_table * 2 + 1];
887
888        self.init_coef_vlc(0, t0)?;
889        self.init_coef_vlc(1, t1)?;
890
891        Ok(())
892    }
893
894    fn init_coef_vlc(&mut self, idx: usize, tbl: &tables::CoefVlcTable) -> Result<()> {
895        // vlc_init(vlc, VLCBITS, n, table_bits, 1, 1, table_codes, 4, 4, 0)
896        let bits = tbl.huffbits;
897        let codes_u32 = tbl.huffcodes;
898        let codes_bytes: &[u8] = unsafe {
899            std::slice::from_raw_parts(codes_u32.as_ptr() as *const u8, codes_u32.len() * 4)
900        };
901        ff_vlc_init_sparse(
902            &mut self.coef_vlc[idx],
903            VLCBITS,
904            tbl.n,
905            bits,
906            1,
907            1,
908            codes_bytes,
909            4,
910            4,
911            None,
912            0,
913            0,
914            0,
915        )?;
916
917        // Build run/level tables like init_coef_vlc.
918        let n = tbl.n;
919        let levels_table = tbl.levels;
920
921        let mut run_table = vec![0u16; n];
922        let mut flevel_table = vec![0f32; n];
923        let mut int_table = vec![0u16; n];
924
925        let mut i = 2usize;
926        let mut level = 1usize;
927        let mut k = 0usize;
928        while i < n {
929            int_table[k] = i as u16;
930            let l = levels_table[k] as usize;
931            k += 1;
932            for j in 0..l {
933                run_table[i] = j as u16;
934                flevel_table[i] = level as f32;
935                i += 1;
936            }
937            level += 1;
938        }
939
940        self.run_table[idx] = run_table;
941        self.level_table[idx] = flevel_table;
942
943        Ok(())
944    }
945
946    fn ff_wma_total_gain_to_bits(total_gain: i32) -> i32 {
947        if total_gain < 15 {
948            13
949        } else if total_gain < 32 {
950            12
951        } else if total_gain < 40 {
952            11
953        } else if total_gain < 45 {
954            10
955        } else {
956            9
957        }
958    }
959
960    fn ff_wma_get_large_val(gb: &mut GetBitContext<'_>) -> Result<u32> {
961        let mut n_bits: usize = 8;
962        if gb.get_bits1()? != 0 {
963            n_bits += 8;
964            if gb.get_bits1()? != 0 {
965                n_bits += 8;
966                if gb.get_bits1()? != 0 {
967                    n_bits += 7;
968                }
969            }
970        }
971        gb.get_bits_long(n_bits)
972    }
973
974    #[allow(clippy::too_many_arguments)]
975    fn ff_wma_run_level_decode(
976        gb: &mut GetBitContext<'_>,
977        vlc: &[VlcElem],
978        level_table: &[f32],
979        run_table: &[u16],
980        version: i32,
981        ptr: &mut [f32],
982        mut offset: i32,
983        num_coefs: i32,
984        block_len: usize,
985        frame_len_bits: i32,
986        coef_nb_bits: i32,
987    ) -> Result<()> {
988        let coef_mask = (block_len as i32) - 1;
989        while offset < num_coefs {
990            let code = get_vlc2(gb, vlc, VLCBITS, VLCMAX)?;
991            if code > 1 {
992                offset += run_table[code as usize] as i32;
993                let sign = gb.get_bits1()? as i32 - 1;
994                let lvl_bits = level_table[code as usize].to_bits();
995                let signed_bits = lvl_bits ^ ((sign as u32) & 0x8000_0000);
996                ptr[(offset & coef_mask) as usize] = f32::from_bits(signed_bits);
997            } else if code == 1 {
998                break;
999            } else {
1000                let level: i32;
1001                if version == 0 {
1002                    level = gb.get_bits(coef_nb_bits as usize)? as i32;
1003                    offset += gb.get_bits(frame_len_bits as usize)? as i32;
1004                } else {
1005                    level = Self::ff_wma_get_large_val(gb)? as i32;
1006                    if gb.get_bits1()? != 0 {
1007                        if gb.get_bits1()? != 0 {
1008                            if gb.get_bits1()? != 0 {
1009                                return Err(DecoderError::InvalidData("broken escape sequence".into()));
1010                            } else {
1011                                offset += gb.get_bits(frame_len_bits as usize)? as i32 + 4;
1012                            }
1013                        } else {
1014                            offset += gb.get_bits(2)? as i32 + 1;
1015                        }
1016                    }
1017                }
1018                let sign = gb.get_bits1()? as i32 - 1;
1019                let v = (level ^ sign) - sign;
1020                ptr[(offset & coef_mask) as usize] = v as f32;
1021            }
1022            offset += 1;
1023        }
1024
1025        if offset > num_coefs {
1026            return Err(DecoderError::InvalidData("overflow in spectral RLE".into()));
1027        }
1028
1029        Ok(())
1030    }
1031
1032    fn wma_lsp_to_curve_init(&mut self, frame_len: usize) {
1033        let wdel = std::f32::consts::PI / (frame_len as f32);
1034        for i in 0..frame_len {
1035            self.lsp_cos_table[i] = 2.0f32 * (wdel * (i as f32)).cos();
1036        }
1037
1038        for i in 0..256 {
1039            let e = (i as i32) - 126;
1040            self.lsp_pow_e_table[i] = (e as f32 * -0.25).exp2();
1041        }
1042
1043        let mut b = 1.0f32;
1044        for i in (0..(1 << LSP_POW_BITS)).rev() {
1045            let m = (1 << LSP_POW_BITS) + i;
1046            let mut a = (m as f32) * (0.5f32 / (1 << LSP_POW_BITS) as f32);
1047            a = 1.0f32 / a.sqrt().sqrt();
1048            self.lsp_pow_m_table1[i] = 2.0f32 * a - b;
1049            self.lsp_pow_m_table2[i] = b - a;
1050            b = a;
1051        }
1052    }
1053
1054    fn decode_exp_lsp(&mut self, gb: &mut GetBitContext<'_>, ch: usize) -> Result<()> {
1055        // upstream wmadec.c: decode_exp_lsp()
1056        let mut lsp: [f32; NB_LSP_COEFS] = [0.0; NB_LSP_COEFS];
1057        for i in 0..NB_LSP_COEFS {
1058            let val = if i == 0 || i >= 8 {
1059                gb.get_bits(3)? as usize
1060            } else {
1061                gb.get_bits(4)? as usize
1062            };
1063            lsp[i] = tables::FF_WMA_LSP_CODEBOOK[i][val];
1064        }
1065
1066        let cos = &self.lsp_cos_table;
1067        let e = &self.lsp_pow_e_table;
1068        let m1 = &self.lsp_pow_m_table1;
1069        let m2 = &self.lsp_pow_m_table2;
1070        let out = &mut self.exponents[ch];
1071        let vmax = wma_lsp_to_curve_tables(out, self.block_len, &lsp, cos, e, m1, m2);
1072        self.max_exponent[ch] = vmax;
1073        Ok(())
1074    }
1075
1076    fn decode_exp_vlc(&mut self, gb: &mut GetBitContext<'_>, ch: usize) -> Result<()> {
1077        let mut last_exp: i32;
1078        let mut max_scale: f32 = 0.0;
1079        let ptab = &tables::POW_TAB[60..];
1080
1081        let bsize = (self.frame_len_bits - self.block_len_bits) as usize;
1082        let bands = &self.exponent_bands[bsize];
1083
1084        let mut q = 0usize;
1085        let q_end = self.block_len;
1086
1087        if self.version.id() == 1 {
1088            last_exp = gb.get_bits(5)? as i32 + 10;
1089            let v = ptab[last_exp as usize];
1090            max_scale = v;
1091            let n = bands[0] as usize;
1092            for _ in 0..n {
1093                self.exponents[ch][q] = v;
1094                q += 1;
1095            }
1096        } else {
1097            last_exp = 36;
1098        }
1099
1100        let mut ptr_idx = 0usize;
1101        if self.version.id() == 1 {
1102            ptr_idx = 1;
1103        }
1104
1105        while q < q_end {
1106            let code = get_vlc2(gb, &self.exp_vlc.table, EXPVLCBITS, EXPMAX)?;
1107            last_exp += code - 60;
1108            if (last_exp as i32 + 60) as usize >= tables::POW_TAB.len() {
1109                return Err(DecoderError::InvalidData(format!("Exponent out of range: {last_exp}")));
1110            }
1111            let v = ptab[last_exp as usize];
1112            if v > max_scale {
1113                max_scale = v;
1114            }
1115            let n = bands[ptr_idx] as usize;
1116            ptr_idx += 1;
1117            for _ in 0..n {
1118                self.exponents[ch][q] = v;
1119                q += 1;
1120            }
1121        }
1122
1123        self.max_exponent[ch] = max_scale;
1124        Ok(())
1125    }
1126
1127
1128    fn wma_decode_block(&mut self, gb: &mut GetBitContext<'_>) -> Result<bool> {
1129        // Returns Ok(true) if last block of frame.
1130        // Translated from wma_decode_block.
1131
1132        // Compute current block length.
1133        if self.use_variable_block_len {
1134            let n = ilog2_u32((self.nb_block_sizes - 1) as u32) + 1;
1135            if self.reset_block_lengths {
1136                self.reset_block_lengths = false;
1137                let v = gb.get_bits(n as usize)? as usize;
1138                if v >= self.nb_block_sizes {
1139                    return Err(DecoderError::InvalidData("prev_block_len_bits out of range".into()));
1140                }
1141                self.prev_block_len_bits = self.frame_len_bits - v as i32;
1142                let v = gb.get_bits(n as usize)? as usize;
1143                if v >= self.nb_block_sizes {
1144                    return Err(DecoderError::InvalidData("block_len_bits out of range".into()));
1145                }
1146                self.block_len_bits = self.frame_len_bits - v as i32;
1147            } else {
1148                self.prev_block_len_bits = self.block_len_bits;
1149                self.block_len_bits = self.next_block_len_bits;
1150            }
1151            let v = gb.get_bits(n as usize)? as usize;
1152            if v >= self.nb_block_sizes {
1153                return Err(DecoderError::InvalidData("next_block_len_bits out of range".into()));
1154            }
1155            self.next_block_len_bits = self.frame_len_bits - v as i32;
1156        } else {
1157            self.next_block_len_bits = self.frame_len_bits;
1158            self.prev_block_len_bits = self.frame_len_bits;
1159            self.block_len_bits = self.frame_len_bits;
1160        }
1161
1162        let bsize = (self.frame_len_bits - self.block_len_bits) as usize;
1163        if (self.frame_len_bits - self.block_len_bits) as usize >= self.nb_block_sizes {
1164            return Err(DecoderError::InvalidData("block_len_bits not initialized".into()));
1165        }
1166
1167        self.block_len = 1usize << self.block_len_bits;
1168        if self.block_pos + self.block_len > self.frame_len {
1169            return Err(DecoderError::InvalidData("frame_len overflow".into()));
1170        }
1171
1172        if self.channels == 2 {
1173            self.ms_stereo = gb.get_bits1()? != 0;
1174        }
1175
1176        let mut v_any = false;
1177        for ch in 0..self.channels {
1178            let a = gb.get_bits1()? != 0;
1179            self.channel_coded[ch] = a;
1180            v_any |= a;
1181        }
1182
1183        if !v_any {
1184            return self.wma_decode_block_next(gb, bsize);
1185        }
1186
1187        // Total gain.
1188        let mut total_gain: i32 = 1;
1189        loop {
1190            if gb.bits_left() < 7 {
1191                return Err(DecoderError::InvalidData("total_gain overread".into()));
1192            }
1193            let a = gb.get_bits(7)? as i32;
1194            total_gain += a;
1195            if a != 127 {
1196                break;
1197            }
1198        }
1199
1200        let coef_nb_bits = Self::ff_wma_total_gain_to_bits(total_gain);
1201
1202        // Number of coefficients.
1203        let ncoefs = (self.coefs_end[bsize] as i32) - (self.coefs_start as i32);
1204        let mut nb_coefs = [0i32; MAX_CHANNELS];
1205        for ch in 0..self.channels {
1206            nb_coefs[ch] = ncoefs;
1207        }
1208
1209        // Noise coding.
1210        if self.use_noise_coding {
1211            for ch in 0..self.channels {
1212                if self.channel_coded[ch] {
1213                    let n1 = self.exponent_high_sizes[bsize];
1214                    for i in 0..n1 {
1215                        let a = gb.get_bits1()? != 0;
1216                        self.high_band_coded[ch][i] = a;
1217                        if a {
1218                            nb_coefs[ch] -= self.exponent_high_bands[bsize][i] as i32;
1219                        }
1220                    }
1221                }
1222            }
1223            for ch in 0..self.channels {
1224                if self.channel_coded[ch] {
1225                    let n1 = self.exponent_high_sizes[bsize];
1226                    let mut val: i32 = 0x8000_0000u32 as i32;
1227                    for i in 0..n1 {
1228                        if self.high_band_coded[ch][i] {
1229                            if val == (0x8000_0000u32 as i32) {
1230                                val = gb.get_bits(7)? as i32 - 19;
1231                            } else {
1232                                val += get_vlc2(gb, &self.hgain_vlc.table, HGAINVLCBITS, HGAINMAX)?;
1233                            }
1234                            self.high_band_values[ch][i] = val;
1235                        }
1236                    }
1237                }
1238            }
1239        }
1240
1241        // Exponents can be reused in short blocks.
1242        let reuse = (self.block_len_bits == self.frame_len_bits) || (gb.get_bits1()? != 0);
1243        if reuse {
1244            for ch in 0..self.channels {
1245                if self.channel_coded[ch] {
1246                    if self.use_exp_vlc {
1247                        self.decode_exp_vlc(gb, ch)?;
1248                    } else {
1249                        self.decode_exp_lsp(gb, ch)?;
1250                    }
1251                    self.exponents_bsize[ch] = bsize;
1252                    self.exponents_initialized[ch] = true;
1253                }
1254            }
1255        }
1256
1257        for ch in 0..self.channels {
1258            if self.channel_coded[ch] && !self.exponents_initialized[ch] {
1259                return Err(DecoderError::InvalidData("exponents not initialized".into()));
1260            }
1261        }
1262
1263        // Parse spectral coefficients.
1264        for ch in 0..self.channels {
1265            if self.channel_coded[ch] {
1266                let tindex = (ch == 1 && self.ms_stereo) as usize;
1267                for v in &mut self.coefs1[ch][..self.block_len] {
1268                    *v = 0.0;
1269                }
1270                // Decode into coefs1 (upstream WMACoef).
1271                Self::ff_wma_run_level_decode(
1272                    gb,
1273                    &self.coef_vlc[tindex].table,
1274                    &self.level_table[tindex],
1275                    &self.run_table[tindex],
1276                    0,
1277                    &mut self.coefs1[ch],
1278                    0,
1279                    nb_coefs[ch],
1280                    self.block_len,
1281                    self.frame_len_bits,
1282                    coef_nb_bits,
1283                )?;
1284            }
1285            if self.version.id() == 1 && self.channels >= 2 {
1286                gb.align_to_byte();
1287            }
1288        }
1289
1290        // Normalize.
1291        let n4 = self.block_len / 2;
1292        let mut mdct_norm = 1.0f32 / (n4 as f32);
1293        if self.version.id() == 1 {
1294            mdct_norm *= (n4 as f32).sqrt();
1295        }
1296
1297        // Compute MDCT coefficients.
1298        for ch in 0..self.channels {
1299            if !self.channel_coded[ch] {
1300                continue;
1301            }
1302
1303            let esize = self.exponents_bsize[ch];
1304            let mult = ff_exp10f(total_gain as f32 * 0.05f32) / self.max_exponent[ch] * mdct_norm;
1305
1306            let mut coefs_pos = 0usize;
1307
1308            if self.use_noise_coding {
1309                // very low freqs: noise
1310                for i in 0..self.coefs_start {
1311                    let exp_idx = ((i << bsize) >> esize) as usize;
1312                    let noise = self.noise_table[self.noise_index];
1313                    self.noise_index = (self.noise_index + 1) & (NOISE_TAB_SIZE - 1);
1314                    self.coefs[ch][coefs_pos] = noise * self.exponents[ch][exp_idx] * mult;
1315                    coefs_pos += 1;
1316                }
1317
1318                let n1 = self.exponent_high_sizes[bsize];
1319
1320                // compute power of high bands
1321                let mut exp_power = [0f32; HIGH_BAND_MAX_SIZE];
1322                let mut exponents_ptr = (self.high_band_start[bsize] << bsize) >> esize;
1323                let mut last_high_band: usize = 0;
1324                for j in 0..n1 {
1325                    let n = self.exponent_high_bands[bsize][j] as usize;
1326                    if self.high_band_coded[ch][j] {
1327                        let mut e2: f32 = 0.0;
1328                        for i in 0..n {
1329                            let v = self.exponents[ch][exponents_ptr + ((i << bsize) >> esize)];
1330                            e2 += v * v;
1331                        }
1332                        exp_power[j] = e2 / (n as f32);
1333                        last_high_band = j;
1334                    }
1335                    exponents_ptr += (n << bsize) >> esize;
1336                }
1337
1338                // main freqs and high freqs
1339                let mut exponents_ptr = (self.coefs_start << bsize) >> esize;
1340                let mut coef1_idx = 0usize;
1341
1342                for j in (-1i32)..(n1 as i32) {
1343                    let n = if j < 0 {
1344                        self.high_band_start[bsize].saturating_sub(self.coefs_start)
1345                    } else {
1346                        self.exponent_high_bands[bsize][j as usize] as usize
1347                    };
1348
1349                    if j >= 0 && self.high_band_coded[ch][j as usize] {
1350                        let mut mult1 = (exp_power[j as usize] / exp_power[last_high_band]).sqrt();
1351                        mult1 *= ff_exp10f(self.high_band_values[ch][j as usize] as f32 * 0.05f32);
1352                        mult1 /= self.max_exponent[ch] * self.noise_mult;
1353                        mult1 *= mdct_norm;
1354
1355                        for i in 0..n {
1356                            let noise = self.noise_table[self.noise_index];
1357                            self.noise_index = (self.noise_index + 1) & (NOISE_TAB_SIZE - 1);
1358                            let exp = self.exponents[ch][exponents_ptr + ((i << bsize) >> esize)];
1359                            self.coefs[ch][coefs_pos] = noise * exp * mult1;
1360                            coefs_pos += 1;
1361                        }
1362                        exponents_ptr += (n << bsize) >> esize;
1363                    } else {
1364                        for i in 0..n {
1365                            let noise = self.noise_table[self.noise_index];
1366                            self.noise_index = (self.noise_index + 1) & (NOISE_TAB_SIZE - 1);
1367                            let exp = self.exponents[ch][exponents_ptr + ((i << bsize) >> esize)];
1368                            let coef1 = self.coefs1[ch][coef1_idx];
1369                            coef1_idx += 1;
1370                            self.coefs[ch][coefs_pos] = (coef1 + noise) * exp * mult;
1371                            coefs_pos += 1;
1372                        }
1373                        exponents_ptr += (n << bsize) >> esize;
1374                    }
1375                }
1376
1377                // very high freqs: noise
1378                let n = self.block_len - self.coefs_end[bsize];
1379                let exp_last = self.exponents[ch][((exponents_ptr as i32 - (1 << bsize)) >> esize) as usize];
1380                let mult1 = mult * exp_last;
1381                for _ in 0..n {
1382                    let noise = self.noise_table[self.noise_index];
1383                    self.noise_index = (self.noise_index + 1) & (NOISE_TAB_SIZE - 1);
1384                    self.coefs[ch][coefs_pos] = noise * mult1;
1385                    coefs_pos += 1;
1386                }
1387            } else {
1388                for _ in 0..self.coefs_start {
1389                    self.coefs[ch][coefs_pos] = 0.0;
1390                    coefs_pos += 1;
1391                }
1392
1393                let n = nb_coefs[ch] as usize;
1394                for i in 0..n {
1395                    let exp = self.exponents[ch][((i << bsize) >> esize)];
1396                    let coef1 = self.coefs1[ch][i];
1397                    self.coefs[ch][coefs_pos] = coef1 * exp * mult;
1398                    coefs_pos += 1;
1399                }
1400                let tail = self.block_len - self.coefs_end[bsize];
1401                for _ in 0..tail {
1402                    self.coefs[ch][coefs_pos] = 0.0;
1403                    coefs_pos += 1;
1404                }
1405            }
1406        }
1407
1408        if self.ms_stereo && self.channel_coded[1] {
1409            if !self.channel_coded[0] {
1410                for v in &mut self.coefs[0][..self.block_len] {
1411                    *v = 0.0;
1412                }
1413                self.channel_coded[0] = true;
1414            }
1415            let (c0, c1) = self.coefs.split_at_mut(1);
1416            let v0 = &mut c0[0][..self.block_len];
1417            let v1 = &mut c1[0][..self.block_len];
1418            butterflies_float(v0, v1);
1419        }
1420
1421        self.wma_decode_block_next(gb, bsize)
1422    }
1423
1424    fn wma_decode_block_next(&mut self, _gb: &mut GetBitContext<'_>, bsize: usize) -> Result<bool> {
1425        // MDCT + window add.
1426        for ch in 0..self.channels {
1427            let n4 = self.block_len / 2;
1428            if self.channel_coded[ch] {
1429                self.mdct[bsize].imdct_full(&mut self.output[..self.block_len * 2], &self.coefs[ch][..self.block_len]);
1430            } else if !(self.ms_stereo && ch == 1) {
1431                for v in &mut self.output[..self.block_len * 2] {
1432                    *v = 0.0;
1433                }
1434            }
1435
1436            let index = (self.frame_len / 2) + self.block_pos - n4;
1437            // frame_out has length 2*BLOCK_MAX_SIZE.
1438            let frame_len_bits = self.frame_len_bits;
1439            let block_len_bits = self.block_len_bits;
1440            let prev_block_len_bits = self.prev_block_len_bits;
1441            let next_block_len_bits = self.next_block_len_bits;
1442            let block_len = self.block_len;
1443            let windows = &self.windows;
1444            let output = &self.output;
1445            let out_slice = &mut self.frame_out[ch][index..index + block_len * 2];
1446            wma_window_apply(out_slice, output, windows, frame_len_bits, block_len_bits, prev_block_len_bits, next_block_len_bits, block_len);
1447        }
1448
1449        self.block_num += 1;
1450        self.block_pos += self.block_len;
1451        Ok(self.block_pos >= self.frame_len)
1452    }
1453
1454    fn wma_decode_frame(&mut self, gb: &mut GetBitContext<'_>, samples: &mut [Vec<f32>; MAX_CHANNELS], samples_offset: usize) -> Result<()> {
1455        self.block_num = 0;
1456        self.block_pos = 0;
1457        loop {
1458            let last = self.wma_decode_block(gb)?;
1459            if last {
1460                break;
1461            }
1462        }
1463
1464        for ch in 0..self.channels {
1465            samples[ch][samples_offset..samples_offset + self.frame_len]
1466                .copy_from_slice(&self.frame_out[ch][..self.frame_len]);
1467            // Shift for overlap.
1468            let tail = self.frame_out[ch][self.frame_len..self.frame_len * 2].to_vec();
1469            self.frame_out[ch][..self.frame_len].copy_from_slice(&tail);
1470        }
1471
1472        Ok(())
1473    }
1474}