wmv_decoder/wma/
bitstream.rs

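//! Bitstream reader.
//!
//! A simplified translation of upstream's `GetBitContext` (MSB-first bit
//! order), intended to match it for in-bounds reads. Every read is
//! bounds-checked: a read past the end of the buffer returns
//! `DecoderError::InvalidData` instead of panicking.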
5
6use crate::error::{DecoderError, Result};
7
8#[derive(Clone)]
9pub struct GetBitContext<'a> {
10    buf: &'a [u8],
11    size_in_bits: usize,
12    bit_pos: usize,
13}
14
15impl<'a> GetBitContext<'a> {
16    pub fn new(buf: &'a [u8]) -> Self {
17        Self {
18            buf,
19            size_in_bits: buf.len() * 8,
20            bit_pos: 0,
21        }
22    }
23
24    #[inline]
25    pub fn bits_left(&self) -> isize {
26        self.size_in_bits as isize - self.bit_pos as isize
27    }
28
29    #[inline]
30    pub fn bits_read(&self) -> usize {
31        self.bit_pos
32    }
33
34    #[inline]
35    pub fn align_to_byte(&mut self) {
36        self.bit_pos = (self.bit_pos + 7) & !7;
37    }
38
39    #[inline]
40    pub fn skip_bits(&mut self, n: usize) -> Result<()> {
41        if self.bit_pos + n > self.size_in_bits {
42            return Err(DecoderError::InvalidData("bitstream overflow".into()));
43        }
44        self.bit_pos += n;
45        Ok(())
46    }
47
48    #[inline]
49    pub fn get_bits1(&mut self) -> Result<u32> {
50        self.get_bits(1)
51    }
52
53    /// Read up to 32 bits.
54    #[inline]
55    pub fn get_bits(&mut self, n: usize) -> Result<u32> {
56        if n == 0 {
57            return Ok(0);
58        }
59        if n > 32 {
60            return Err(DecoderError::InvalidData("get_bits > 32".into()));
61        }
62        if self.bit_pos + n > self.size_in_bits {
63            return Err(DecoderError::InvalidData("bitstream overflow".into()));
64        }
65
66        let mut out: u32 = 0;
67        let mut remaining = n;
68        while remaining > 0 {
69            let byte_idx = self.bit_pos >> 3;
70            let bit_in_byte = self.bit_pos & 7; // 0..7, MSB-first
71            let avail = 8 - bit_in_byte;
72            let take = remaining.min(avail);
73
74            let byte = self.buf[byte_idx] as u32;
75            let shift = (avail - take) as u32;
76            let mask = (1u32 << take) - 1;
77            let bits = (byte >> shift) & mask;
78
79            out = (out << take) | bits;
80            self.bit_pos += take;
81            remaining -= take;
82        }
83
84        Ok(out)
85    }
86
87    #[inline]
88    pub fn show_bits(&self, n: usize) -> Result<u32> {
89        let mut tmp = self.clone();
90        tmp.get_bits(n)
91    }
92
93    #[inline]
94    pub fn get_bits_long(&mut self, n: usize) -> Result<u32> {
95        self.get_bits(n)
96    }
97}