wmv_decoder/
vc1.rs

//! VC-1 / WMV9 sequence and picture header parser, plus bitplane decoder.
3
4use crate::bitreader::BitReader;
5use crate::error::{DecoderError, Result};
6
7// ─── Enums ───────────────────────────────────────────────────────────────────
8
/// Coding profile reported by the 2-bit PROFILE field of the sequence header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Profile {
    Simple,
    Main,
    Advanced,
}
11
/// Picture coding type of a frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameType {
    I,
    P,
    B,
    BI,
    Skipped,
}

impl std::fmt::Display for FrameType {
    /// Writes the short label (`I`, `P`, `B`, `BI`, `skip`).
    ///
    /// Width, fill and alignment flags in the format spec are honoured, so
    /// `format!("{:>4}", FrameType::I)` yields `"   I"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            FrameType::I => "I",
            FrameType::P => "P",
            FrameType::B => "B",
            FrameType::BI => "BI",
            FrameType::Skipped => "skip",
        };
        // `pad` applies the caller's formatting flags; `write!` would ignore them.
        f.pad(label)
    }
}
26
/// Quantizer mode selected by the 2-bit QUANTIZER field of the sequence header
/// (0 = `Implicit`, 1 = `Explicit`, 2 = `NonUniform`, 3 = `Uniform`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizerMode {
    /// PQUANT is derived from PQINDEX through a translation table.
    Implicit,
    Explicit,
    NonUniform,
    Uniform,
}
29
30// ─── Sequence Header ─────────────────────────────────────────────────────────
31
32#[derive(Debug, Clone)]
33pub struct SequenceHeader {
34    pub profile:        Profile,
35    pub max_b_frames:   u8,
36    pub frame_rate_num: u32,
37    pub frame_rate_den: u32,
38    pub loop_filter:    bool,
39    pub multires:       bool,
40    pub fastuvmc:       bool,
41    pub extended_mv:    bool,
42    pub dquant:         u8,
43    pub vstransform:    bool,
44    pub overlap:        bool,
45    pub syncmarker:     bool,
46    pub rangered:       bool,
47    pub quantizer_mode: QuantizerMode,
48    pub finterpflag:    bool,
49    pub transacfrm:     u8,   // inter AC table index 0-3 (TRANSACFRM)
50    pub transacfrm2:    u8,   // intra AC table index 0-3 (TRANSACFRM2)
51    pub mvtab:          u8,   // MV table index 0-3 (MVTAB)
52    pub cbptab:         u8,   // CBP table index 0-3 (CBPTAB)
53    pub dctab:          bool, // DC table select (DCTAB)
54    pub width:          u32,
55    pub height:         u32,
56    pub display_width:  u32,
57    pub display_height: u32,
58}
59
60impl SequenceHeader {
61    pub fn parse(data: &[u8]) -> Result<Self> {
62        if data.len() < 4 {
63            return Err(DecoderError::InvalidData("Sequence header too short".into()));
64        }
65        let mut br = BitReader::new(data);
66
67        // WMV9: first 2 bits = profile
68        let profile_bits = br.read_bits(2).unwrap_or(0) as u8;
69        let profile = match profile_bits {
70            0 => Profile::Simple,
71            1 => Profile::Main,
72            3 => Profile::Advanced,
73            _ => Profile::Main,
74        };
75
76        br.read_bits(2); // reserved
77
78        let frmrtq_postproc = br.read_bits(3).unwrap_or(0);
79        let _bitrtq_postproc= br.read_bits(5).unwrap_or(0);
80        let loop_filter     = br.read_bit().unwrap_or(false);
81        let _res_sm         = br.read_bit().unwrap_or(false);
82        let multires        = br.read_bit().unwrap_or(false);
83        let _res_fasttx     = br.read_bit().unwrap_or(true);
84        let fastuvmc        = br.read_bit().unwrap_or(false);
85        let extended_mv     = br.read_bit().unwrap_or(false);
86        let dquant          = br.read_bits(2).unwrap_or(0) as u8;
87        let vstransform     = br.read_bit().unwrap_or(false);
88        let _res_transtab   = br.read_bit().unwrap_or(false);
89        let overlap         = br.read_bit().unwrap_or(false);
90        let _resync_marker  = br.read_bit().unwrap_or(false);
91        let rangered        = br.read_bit().unwrap_or(false);
92        let max_b_frames    = br.read_bits(3).unwrap_or(0) as u8;
93        let quant_bits      = br.read_bits(2).unwrap_or(0) as u8;
94        let finterpflag     = br.read_bit().unwrap_or(false);
95        let syncmarker      = br.read_bit().unwrap_or(false);
96        // Additional fields per SMPTE 421M §8.1.1 (Simple/Main)
97        let transacfrm      = br.read_bits(2).unwrap_or(0) as u8;
98        let transacfrm2     = br.read_bits(2).unwrap_or(0) as u8;
99        let mvtab           = br.read_bits(2).unwrap_or(0) as u8;
100        let cbptab          = br.read_bits(2).unwrap_or(0) as u8;
101        let dctab           = br.read_bit().unwrap_or(false);
102
103        let quantizer_mode = match quant_bits {
104            0 => QuantizerMode::Implicit,
105            1 => QuantizerMode::Explicit,
106            2 => QuantizerMode::NonUniform,
107            _ => QuantizerMode::Uniform,
108        };
109
110        let (frame_rate_num, frame_rate_den) = match frmrtq_postproc {
111            0 => (6, 1), 1 => (8, 1), 2 => (10, 1), 3 => (12, 1),
112            4 => (15, 1), 5 => (24000, 1001), 6 => (24, 1), 7 => (25, 1),
113            _ => (30, 1),
114        };
115
116        Ok(SequenceHeader {
117            profile, max_b_frames, frame_rate_num, frame_rate_den,
118            loop_filter, multires, fastuvmc, extended_mv, dquant,
119            vstransform, overlap, syncmarker, rangered,
120            quantizer_mode, finterpflag,
121            transacfrm, transacfrm2, mvtab, cbptab, dctab,
122            width: 0, height: 0, display_width: 0, display_height: 0,
123        })
124    }
125}
126
127// ─── Bitplane ────────────────────────────────────────────────────────────────
128// SMPTE 421M §8.7.  Used to signal skipped MBs and direct-mode flags.
129
/// Bitplane coding modes, selected by the 3-bit mode code that
/// `Bitplane::decode` reads (values 6 and 7 mean raw coding).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BitplaneMode {
    /// Two MB flags per 2-bit symbol.
    Norm2,
    /// As `Norm2`, preceded by an invert bit.
    Diff2,
    /// Up to six MB flags per 6-bit symbol.
    Norm6,
    /// As `Norm6`, with its invert bit applied.
    Diff6,
    /// One flag per MB row, optionally followed by that row's MB bits.
    RowSkip,
    /// One flag per MB column, optionally followed by that column's MB bits.
    ColSkip,
}
134
/// A decoded bitplane: one flag per macroblock.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Bitplane {
    /// One byte per macroblock (0 or 1), in raster order (`row * mb_w + col`).
    pub data: Vec<u8>,
    /// `true` when the plane was coded in raw mode (one bit per MB).
    pub is_raw: bool,
}
139
140impl Bitplane {
141    pub fn decode(br: &mut BitReader<'_>, mb_w: usize, mb_h: usize) -> Option<Self> {
142        let n_mb = mb_w * mb_h;
143        let mut data = vec![0u8; n_mb];
144
145        // 3-bit mode code
146        let mode_bits = br.read_bits(3)?;
147        let mode = match mode_bits {
148            0 => BitplaneMode::Norm2,
149            1 => BitplaneMode::Norm6,
150            2 => BitplaneMode::Diff2,
151            3 => BitplaneMode::Diff6,
152            4 => BitplaneMode::RowSkip,
153            5 => BitplaneMode::ColSkip,
154            _ => {
155                // Raw: one bit per MB
156                for i in 0..n_mb {
157                    data[i] = br.read_bit()? as u8;
158                }
159                return Some(Bitplane { data, is_raw: true });
160            }
161        };
162
163        match mode {
164            BitplaneMode::Norm6 | BitplaneMode::Diff6 => {
165                // Tile-coded 6 MBs per codeword
166                let tile_size = 6usize;
167                let mut inv = br.read_bit()? as u8; // invert flag for Diff modes
168                if !matches!(mode, BitplaneMode::Diff2 | BitplaneMode::Diff6) { inv = 0; }
169                let mut i = 0;
170                while i < n_mb {
171                    let tile = br.read_bits(tile_size as u8)? as usize;
172                    for b in 0..tile_size.min(n_mb - i) {
173                        data[i + b] = (((tile >> (tile_size - 1 - b)) & 1) as u8) ^ inv;
174                    }
175                    i += tile_size;
176                }
177            }
178            BitplaneMode::Norm2 | BitplaneMode::Diff2 => {
179                let inv = if matches!(mode, BitplaneMode::Diff2) {
180                    br.read_bit()? as u8
181                } else { 0 };
182                let mut i = 0;
183                while i < n_mb {
184                    let pair = br.read_bits(2)? as u8;
185                    data[i    ] = ((pair >> 1) & 1) ^ inv;
186                    if i + 1 < n_mb { data[i+1] = (pair & 1) ^ inv; }
187                    i += 2;
188                }
189            }
190            BitplaneMode::RowSkip => {
191                for row in 0..mb_h {
192                    if br.read_bit()? { continue; }
193                    for col in 0..mb_w {
194                        data[row * mb_w + col] = br.read_bit()? as u8;
195                    }
196                }
197            }
198            BitplaneMode::ColSkip => {
199                for col in 0..mb_w {
200                    if br.read_bit()? { continue; }
201                    for row in 0..mb_h {
202                        data[row * mb_w + col] = br.read_bit()? as u8;
203                    }
204                }
205            }
206        }
207
208        Some(Bitplane { data, is_raw: false })
209    }
210}
211
212// ─── Picture Header ───────────────────────────────────────────────────────────
213
214#[derive(Debug, Clone)]
215pub struct PictureHeader {
216    pub frame_type:   FrameType,
217    pub pqindex:      u8,
218    pub pquant:       u8,
219    pub halfqp:       bool,
220    pub pqual_mode:   u8,
221    pub mvrange:      u8,
222    pub rptfrm:       u8,
223    pub pts_ms:       u32,
224    pub rangeredfrm:  bool,
225    /// Bit offset where the macroblock layer starts (from beginning of the frame payload).
226    ///
227    /// This includes the full picture header and any bitplanes decoded from it.
228    pub header_bits:  usize,
229    /// Skipped-MB bitplane (None if not present or raw-mode)
230    pub skipmb_plane: Option<Vec<u8>>,
231    /// Direct-mode bitplane for B-frames
232    pub directmb_plane: Option<Vec<u8>>,
233    /// B-frame temporal fraction from SMPTE 421M §7.1.3.6 Table 40.
234    pub bfrac_num: i32,
235    pub bfrac_den: i32,
236}
237
238impl PictureHeader {
239    pub fn parse(data: &[u8], seq: &SequenceHeader, pts_ms: u32,
240                 mb_w: usize, mb_h: usize) -> Result<Self> {
241        let mut br = BitReader::new(data);
242
243        // ── frame type ──────────────────────────────────────────────────────
244        let frame_type = if seq.max_b_frames > 0 {
245            match br.read_bits(2).unwrap_or(0xFF) {
246                0b11 => FrameType::I,
247                0b10 => FrameType::P,
248                0b00 => FrameType::B,
249                0b01 => FrameType::BI,
250                _    => return Err(DecoderError::InvalidData("Unknown frame type".into())),
251            }
252        } else {
253            match br.read_bit().unwrap_or(false) {
254                false => FrameType::P,
255                true  => FrameType::I,
256            }
257        };
258
259        // ── range reduction ─────────────────────────────────────────────────
260        let rangeredfrm = seq.rangered && br.read_bit().unwrap_or(false);
261
262        // ── quantizer ───────────────────────────────────────────────────────
263        let pqindex = br.read_bits(5).unwrap_or(1) as u8;
264        let (pquant, halfqp, pqual_mode) = Self::decode_quantizer(pqindex, seq);
265
266        // ── MV range ────────────────────────────────────────────────────────
267        let mvrange = if seq.extended_mv {
268            let mut r = 0u8;
269            while br.read_bit().unwrap_or(false) {
270                r += 1;
271                if r >= 3 { break; }
272            }
273            r
274        } else { 0 };
275
276        // ── repeat frame count (I-frame) ─────────────────────────────────
277        let rptfrm = if frame_type == FrameType::I {
278            br.read_bits(2).unwrap_or(0) as u8
279        } else { 0 };
280
281        // ── bitplanes ───────────────────────────────────────────────────────
282        // P-frame: skipped-MB bitplane
283        let skipmb_plane = if frame_type == FrameType::P {
284            Bitplane::decode(&mut br, mb_w, mb_h).map(|bp| bp.data)
285        } else { None };
286
287        // B-frame: direct-mode bitplane + skipped-MB bitplane
288        let directmb_plane = if frame_type == FrameType::B {
289            Bitplane::decode(&mut br, mb_w, mb_h).map(|bp| bp.data)
290        } else { None };
291
292        let skipmb_plane = if frame_type == FrameType::B {
293            Bitplane::decode(&mut br, mb_w, mb_h).map(|bp| bp.data)
294        } else { skipmb_plane };
295
296        // ── BFRACTION (B-frames only, SMPTE 421M §7.1.3.6 Table 40) ──────────
297        const BFRAC: [(i32,i32); 8] = [
298            (1,2),(1,3),(2,3),(1,4),(3,4),(1,5),(2,5),(1,2),
299        ];
300        let (bfrac_num, bfrac_den) = if frame_type == FrameType::B {
301            let idx = br.read_bits(3).unwrap_or(0) as usize;
302            BFRAC[idx.min(7)]
303        } else { (1, 2) };
304
305        let header_bits = br.bits_read();
306
307        Ok(PictureHeader {
308            frame_type, pqindex, pquant, halfqp, pqual_mode,
309            mvrange, rptfrm, pts_ms, rangeredfrm,
310            header_bits,
311            skipmb_plane, directmb_plane,
312            bfrac_num, bfrac_den,
313        })
314    }
315
316    // ── Legacy parse (no bitplane, backward compat) ─────────────────────────
317    pub fn parse_simple(data: &[u8], seq: &SequenceHeader, pts_ms: u32) -> Result<Self> {
318        let mb_w = ((seq.width + 15) / 16).max(1) as usize;
319        let mb_h = ((seq.height + 15) / 16).max(1) as usize;
320        Self::parse(data, seq, pts_ms, mb_w, mb_h)
321    }
322
323    fn decode_quantizer(pqindex: u8, seq: &SequenceHeader) -> (u8, bool, u8) {
324        match seq.quantizer_mode {
325            QuantizerMode::Implicit => {
326                // SMPTE 421M Table 5
327                let pquant = if pqindex <= 8 { pqindex }
328                else {
329                    const MAP: [u8; 23] = [
330                        9,10,11,12,13,14,15,16,17,18,
331                        19,20,21,22,23,24,25,27,29,31,33,63,0,
332                    ];
333                    MAP.get(pqindex as usize - 9).copied().unwrap_or(pqindex)
334                };
335                let halfqp = pqindex >= 9 && pquant == 0;
336                (pquant, halfqp, 0)
337            }
338            _ => (pqindex, false, 0),
339        }
340    }
341}