wmv_decoder/wma/
vlc.rs

1//! VLC (Huffman) table builder and decoder.
2//!
3
4use crate::error::{DecoderError, Result};
5use crate::wma::bitstream::GetBitContext;
6
/// Init flag: `vlc.table` is caller-provided storage; it is neither cleared nor grown.
pub const VLC_INIT_USE_STATIC: i32 = 1 << 0;
/// Init flag: static storage variant that also sets [`VLC_INIT_USE_STATIC`].
/// NOTE(review): upstream FFmpeg uses the extra bit to allow an oversized static table;
/// code in this file only tests the `VLC_INIT_USE_STATIC` bit of it.
pub const VLC_INIT_STATIC_OVERLONG: i32 = (1 << 1) | VLC_INIT_USE_STATIC;
/// Init flag: input codewords are stored with their first bit in the LSB.
pub const VLC_INIT_INPUT_LE: i32 = 1 << 2;
/// Init flag: build tables indexed by bit-reversed (LSB-first) codewords.
pub const VLC_INIT_OUTPUT_LE: i32 = 1 << 3;

/// Integer type of the symbol/length fields of a [`VlcElem`].
pub type VlcBaseType = i16;
13
14#[derive(Clone, Copy, Default)]
15pub struct VlcElem {
16    pub sym: VlcBaseType,
17    pub len: VlcBaseType,
18}
19
20#[derive(Default)]
21pub struct Vlc {
22    pub bits: i32,
23    pub table: Vec<VlcElem>,
24    pub table_size: i32,
25    pub table_allocated: i32,
26}
27
28#[derive(Clone, Copy)]
29struct VlcCode {
30    bits: u8,
31    symbol: VlcBaseType,
32    /// Codeword with the first bit-to-be-read in the MSB.
33    code: u32,
34}
35
/// Mirrors a 32-bit word bit-wise (bit 0 <-> bit 31, bit 1 <-> bit 30, ...).
fn bitswap_32(x: u32) -> u32 {
    u32::reverse_bits(x)
}
39
40fn alloc_table(vlc: &mut Vlc, size: i32, use_static: bool) -> Result<i32> {
41    let index = vlc.table_size;
42    vlc.table_size += size;
43    if vlc.table_size > vlc.table_allocated {
44        if use_static {
45            return Err(DecoderError::InvalidData("static VLC table too small".into()));
46        }
47        vlc.table_allocated += 1 << vlc.bits;
48        let new_len = vlc.table_allocated as usize;
49        if new_len > vlc.table.len() {
50            vlc.table.resize(new_len, VlcElem { sym: 0, len: 0 });
51        }
52    }
53    Ok(index)
54}
55
56fn build_table(vlc: &mut Vlc, table_nb_bits: i32, nb_codes: usize, codes: &mut [VlcCode], flags: i32) -> Result<i32> {
57    if table_nb_bits > 30 {
58        return Err(DecoderError::InvalidData("table_nb_bits > 30".into()));
59    }
60    let table_size = 1 << table_nb_bits;
61    let table_index = alloc_table(vlc, table_size, (flags & VLC_INIT_USE_STATIC) != 0)?;
62
63    let base = table_index as usize;
64
65    for i in 0..nb_codes {
66        let mut n = codes[i].bits as i32;
67        let mut code = codes[i].code;
68        let symbol = codes[i].symbol;
69
70        if n <= table_nb_bits {
71            let mut j = (code >> (32 - table_nb_bits)) as i32;
72            let nb = 1 << (table_nb_bits - n);
73            let mut inc = 1;
74            if (flags & VLC_INIT_OUTPUT_LE) != 0 {
75                j = (bitswap_32(code) >> (32 - table_nb_bits)) as i32;
76                inc = 1 << n;
77            }
78            for _k in 0..nb {
79                let idx = base + j as usize;
80                let bits = vlc.table[idx].len;
81                let oldsym = vlc.table[idx].sym;
82                if (bits != 0 || oldsym != 0) && (bits != n as i16 || oldsym != symbol) {
83                    return Err(DecoderError::InvalidData("incorrect VLC codes".into()));
84                }
85                vlc.table[idx].len = n as i16;
86                vlc.table[idx].sym = symbol;
87                j += inc;
88            }
89        } else {
90            // Subtable.
91            n -= table_nb_bits;
92            let code_prefix = code >> (32 - table_nb_bits);
93            let mut subtable_bits = n;
94            codes[i].bits = n as u8;
95            codes[i].code = code << table_nb_bits;
96
97            let mut k = i + 1;
98            while k < nb_codes {
99                let nn = codes[k].bits as i32 - table_nb_bits;
100                if nn <= 0 {
101                    break;
102                }
103                let cc = codes[k].code;
104                if (cc >> (32 - table_nb_bits)) != code_prefix {
105                    break;
106                }
107                codes[k].bits = nn as u8;
108                codes[k].code = cc << table_nb_bits;
109                if nn > subtable_bits {
110                    subtable_bits = nn;
111                }
112                k += 1;
113            }
114            if subtable_bits > table_nb_bits {
115                subtable_bits = table_nb_bits;
116            }
117
118            let j = if (flags & VLC_INIT_OUTPUT_LE) != 0 {
119                (bitswap_32(code_prefix) >> (32 - table_nb_bits)) as i32
120            } else {
121                code_prefix as i32
122            };
123
124            let idx = base + j as usize;
125            vlc.table[idx].len = -(subtable_bits as i16);
126
127            let sub_index = build_table(vlc, subtable_bits, k - i, &mut codes[i..k], flags)?;
128
129            // Rebase after possible resize.
130            let base2 = table_index as usize;
131            let idx2 = base2 + j as usize;
132            vlc.table[idx2].sym = sub_index as i16;
133
134            // Skip processed range.
135            // Equivalent to `i = k - 1` in C loop.
136            // We cannot easily modify `i` in Rust for-loop, so handle via while in caller.
137        }
138    }
139
140    // Mark empty entries.
141    let base3 = table_index as usize;
142    for i in 0..table_size {
143        let idx = base3 + i as usize;
144        if vlc.table[idx].len == 0 {
145            vlc.table[idx].sym = -1;
146        }
147    }
148
149    Ok(table_index)
150}
151
152fn vlc_common_init(vlc: &mut Vlc, nb_bits: i32, flags: i32) {
153    vlc.bits = nb_bits;
154    vlc.table_size = 0;
155    if (flags & VLC_INIT_USE_STATIC) == 0 {
156        vlc.table.clear();
157        vlc.table_allocated = 0;
158    }
159}
160
161fn vlc_common_end(vlc: &mut Vlc, nb_bits: i32, codes: &mut [VlcCode], flags: i32) -> Result<()> {
162    // upstream's build_table expects codes grouped; for sparse init it sorts.
163    // We use a while loop in order to emulate the C for-loop that updates `i`.
164    let nb_codes = codes.len();
165
166    // Build table.
167    // Our build_table implementation above uses recursion but does not update outer loop index.
168    // To preserve upstream semantics, we rebuild using a local recursive builder that uses slices.
169
170    // Re-implement build_table logic with slice recursion, closer to C.
171    fn build(vlc: &mut Vlc, table_nb_bits: i32, codes: &mut [VlcCode], flags: i32) -> Result<i32> {
172        if table_nb_bits > 30 {
173            return Err(DecoderError::InvalidData("table_nb_bits > 30".into()));
174        }
175        let table_size = 1 << table_nb_bits;
176        let table_index = alloc_table(vlc, table_size, (flags & VLC_INIT_USE_STATIC) != 0)?;
177        let mut i: usize = 0;
178        while i < codes.len() {
179            let mut n = codes[i].bits as i32;
180            let mut code = codes[i].code;
181            let symbol = codes[i].symbol;
182
183            let base = table_index as usize;
184
185            if n <= table_nb_bits {
186                let mut j = (code >> (32 - table_nb_bits)) as i32;
187                let nb = 1 << (table_nb_bits - n);
188                let mut inc = 1;
189                if (flags & VLC_INIT_OUTPUT_LE) != 0 {
190                    j = (bitswap_32(code) >> (32 - table_nb_bits)) as i32;
191                    inc = 1 << n;
192                }
193                for _ in 0..nb {
194                    let idx = base + j as usize;
195                    let bits = vlc.table[idx].len;
196                    let oldsym = vlc.table[idx].sym;
197                    if (bits != 0 || oldsym != 0) && (bits != n as i16 || oldsym != symbol) {
198                        return Err(DecoderError::InvalidData("incorrect VLC codes".into()));
199                    }
200                    vlc.table[idx].len = n as i16;
201                    vlc.table[idx].sym = symbol;
202                    j += inc;
203                }
204                i += 1;
205            } else {
206                // Subtable.
207                n -= table_nb_bits;
208                let code_prefix = code >> (32 - table_nb_bits);
209                let mut subtable_bits = n;
210
211                codes[i].bits = n as u8;
212                codes[i].code = code << table_nb_bits;
213
214                let mut k = i + 1;
215                while k < codes.len() {
216                    let nn = codes[k].bits as i32 - table_nb_bits;
217                    if nn <= 0 {
218                        break;
219                    }
220                    let cc = codes[k].code;
221                    if (cc >> (32 - table_nb_bits)) != code_prefix {
222                        break;
223                    }
224                    codes[k].bits = nn as u8;
225                    codes[k].code = cc << table_nb_bits;
226                    if nn > subtable_bits {
227                        subtable_bits = nn;
228                    }
229                    k += 1;
230                }
231
232                if subtable_bits > table_nb_bits {
233                    subtable_bits = table_nb_bits;
234                }
235
236                let j = if (flags & VLC_INIT_OUTPUT_LE) != 0 {
237                    (bitswap_32(code_prefix) >> (32 - table_nb_bits)) as i32
238                } else {
239                    code_prefix as i32
240                };
241
242                {
243                    let idx = base + j as usize;
244                    vlc.table[idx].len = -(subtable_bits as i16);
245                }
246
247                let sub_index = build(vlc, subtable_bits, &mut codes[i..k], flags)?;
248
249                // Reload base after possible resize.
250                let base2 = table_index as usize;
251                let idx2 = base2 + j as usize;
252                vlc.table[idx2].sym = sub_index as i16;
253
254                i = k;
255            }
256        }
257
258        // Mark empty.
259        let base = table_index as usize;
260        for t in 0..table_size {
261            let idx = base + t as usize;
262            if vlc.table[idx].len == 0 {
263                vlc.table[idx].sym = -1;
264            }
265        }
266
267        Ok(table_index)
268    }
269
270    build(vlc, nb_bits, codes, flags)?;
271
272    if (flags & VLC_INIT_USE_STATIC) != 0 {
273        // Nothing.
274        let _ = nb_codes;
275    }
276
277    Ok(())
278}
279
/// Reads element `i` of a strided byte table as an unsigned integer of `size` bytes.
///
/// Element `i` begins at byte `i * wrap`; 2- and 4-byte values are native-endian. Any other
/// `size` yields 0 without touching `table`.
fn get_data_u32(table: &[u8], wrap: i32, i: usize, size: i32) -> u32 {
    let start = i * wrap as usize;
    if size == 1 {
        u32::from(table[start])
    } else if size == 2 {
        let bytes = &table[start..];
        u32::from(u16::from_ne_bytes([bytes[0], bytes[1]]))
    } else if size == 4 {
        let bytes = &table[start..];
        u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
    } else {
        0
    }
}
289
/// Reads element `i` of a strided byte table as a 1- or 2-byte unsigned value.
///
/// Element `i` begins at byte `i * wrap`; 2-byte values are native-endian. Any other `size`
/// yields 0 without touching `table`.
fn get_data_u16(table: &[u8], wrap: i32, i: usize, size: i32) -> u16 {
    let start = i * wrap as usize;
    match size {
        1 => u16::from(table[start]),
        2 => {
            let bytes = &table[start..];
            u16::from_ne_bytes([bytes[0], bytes[1]])
        }
        _ => 0,
    }
}
298
299/// Equivalent to upstream `ff_vlc_init_sparse()`.
300#[allow(clippy::too_many_arguments)]
301pub fn ff_vlc_init_sparse(
302    vlc: &mut Vlc,
303    nb_bits: i32,
304    nb_codes: usize,
305    bits: &[u8],
306    bits_wrap: i32,
307    bits_size: i32,
308    codes: &[u8],
309    codes_wrap: i32,
310    codes_size: i32,
311    symbols: Option<&[u8]>,
312    symbols_wrap: i32,
313    symbols_size: i32,
314    flags: i32,
315) -> Result<()> {
316    vlc_common_init(vlc, nb_bits, flags);
317
318    let mut buf: Vec<VlcCode> = Vec::with_capacity(nb_codes);
319
320    // Copy entries with len > nb_bits first.
321    for pass in 0..2 {
322        for i in 0..nb_codes {
323            let len = get_data_u32(bits, bits_wrap, i, bits_size) as u32;
324            let cond = if pass == 0 { len > nb_bits as u32 } else { len != 0 && len <= nb_bits as u32 };
325            if !cond {
326                continue;
327            }
328            if len > (3 * nb_bits) as u32 || len > 32 {
329                return Err(DecoderError::InvalidData(format!("Too long VLC ({len})")));
330            }
331            let mut code = get_data_u32(codes, codes_wrap, i, codes_size);
332            if code as u64 >= (1u64 << len) {
333                return Err(DecoderError::InvalidData(format!("Invalid code {code:x} for {i}")));
334            }
335            if (flags & VLC_INIT_INPUT_LE) != 0 {
336                code = bitswap_32(code);
337            } else {
338                code <<= 32 - len;
339            }
340            let sym: i16 = if let Some(symtab) = symbols {
341                get_data_u16(symtab, symbols_wrap, i, symbols_size) as i16
342            } else {
343                i as i16
344            };
345            buf.push(VlcCode { bits: len as u8, symbol: sym, code });
346        }
347        if pass == 0 {
348            buf.sort_by_key(|c| c.code >> 1);
349        }
350    }
351
352    vlc_common_end(vlc, nb_bits, &mut buf, flags)
353}
354
355/// Equivalent to upstream `ff_vlc_init_from_lengths()`.
356#[allow(clippy::too_many_arguments)]
357pub fn ff_vlc_init_from_lengths(
358    vlc: &mut Vlc,
359    nb_bits: i32,
360    nb_codes: usize,
361    lens: &[i8],
362    lens_wrap: i32,
363    symbols: Option<&[u8]>,
364    symbols_wrap: i32,
365    symbols_size: i32,
366    offset: i32,
367    flags: i32,
368) -> Result<()> {
369    vlc_common_init(vlc, nb_bits, flags);
370
371    let mut buf: Vec<VlcCode> = Vec::with_capacity(nb_codes);
372    let mut code: u64 = 0;
373    let len_max: i32 = 32.min(3 * nb_bits);
374
375    for i in 0..nb_codes {
376        let len = lens[(i * lens_wrap as usize)] as i32;
377        if len > 0 {
378            let sym_u = if let Some(symtab) = symbols {
379                get_data_u16(symtab, symbols_wrap, i, symbols_size) as u32
380            } else {
381                i as u32
382            };
383            let sym = (sym_u as i32 + offset) as i16;
384            buf.push(VlcCode {
385                bits: len as u8,
386                symbol: sym,
387                code: code as u32,
388            });
389        } else if len < 0 {
390            // Incomplete tree marker.
391        } else {
392            continue;
393        }
394
395        let mut abs_len = len;
396        if abs_len < 0 {
397            abs_len = -abs_len;
398        }
399        if abs_len > len_max || (code & ((1u64 << (32 - abs_len)) - 1)) != 0 {
400            return Err(DecoderError::InvalidData(format!("Invalid VLC (length {abs_len})")));
401        }
402        code += 1u64 << (32 - abs_len);
403        if code > (u32::MAX as u64) + 1 {
404            return Err(DecoderError::InvalidData("Overdetermined VLC tree".into()));
405        }
406    }
407
408    vlc_common_end(vlc, nb_bits, &mut buf, flags)
409}
410
411/// Equivalent to `get_vlc2()`.
412#[inline]
413pub fn get_vlc2(gb: &mut GetBitContext<'_>, table: &[VlcElem], bits: i32, max_depth: i32) -> Result<i32> {
414    let mut code: i32;
415    let mut index = gb.show_bits(bits as usize)? as usize;
416    let mut n = table[index].len as i32;
417    code = table[index].sym as i32;
418
419    if max_depth > 1 && n < 0 {
420        gb.skip_bits(bits as usize)?;
421        let mut nb_bits = -n;
422        index = (gb.show_bits(nb_bits as usize)? as usize) + code as usize;
423        n = table[index].len as i32;
424        code = table[index].sym as i32;
425        if max_depth > 2 && n < 0 {
426            gb.skip_bits(nb_bits as usize)?;
427            nb_bits = -n;
428            index = (gb.show_bits(nb_bits as usize)? as usize) + code as usize;
429            n = table[index].len as i32;
430            code = table[index].sym as i32;
431        }
432    }
433
434    gb.skip_bits(n as usize)?;
435    Ok(code)
436}
437