Skip to main content

siglus_cfx_decompiler/
wgsl.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use crate::ctab::{ConstantInfo, ConstantTable, RegisterSet, ValueType};
4use crate::disasm::{
5    mask_len, DeclUsage, Instruction, Opcode, RegisterKey, RegisterType, ResultModifier,
6    SamplerTextureType, ShaderKind, SourceModifier,
7};
8use crate::disasm::{parse_shader, ShaderModel};
9
/// One non-sampler register declaration: the register and the semantic name
/// attached to it.
#[derive(Debug, Clone)]
struct DeclInfo {
    /// The declared register; the same value is used as the key in `Context::decls`.
    reg: RegisterKey,
    /// Semantic string later mapped to a WGSL `@builtin`/`@location` attribute.
    semantic: String,
}
15
/// Literal value recorded from a `Def` instruction that targets a float
/// constant register.
#[derive(Debug, Clone)]
struct DefFloat {
    /// The instruction's four float parameters, in parameter order.
    values: [f32; 4],
}
20
/// Literal value recorded from a `DefI` instruction that targets an integer
/// constant register.
#[derive(Debug, Clone)]
struct DefInt {
    /// The instruction's four integer parameters, in parameter order.
    values: [i32; 4],
}
25
/// Literal value recorded from a `DefB` instruction that targets a boolean
/// constant register.
#[derive(Debug, Clone)]
struct DefBool {
    /// `true` when the instruction's first integer parameter is non-zero.
    value: bool,
}
30
/// Everything the emitter needs to know about one shader, gathered by `analyze`
/// in a single pass over the instruction list.
#[derive(Debug, Clone)]
struct Context<'a> {
    /// The parsed shader being decompiled.
    shader: &'a ShaderModel,
    /// Constant table supplied by the caller, if any.
    ctab: Option<&'a ConstantTable>,
    /// Non-sampler `Dcl` declarations by register, plus entries inferred for
    /// used inputs that had no declaration.
    decls: BTreeMap<RegisterKey, DeclInfo>,
    /// Texture type per sampler number; undeclared but used samplers default to `TwoD`.
    sampler_decls: BTreeMap<u16, SamplerTextureType>,
    /// Float constants defined inline by `Def`, keyed by register number.
    def_float: BTreeMap<u16, DefFloat>,
    /// Integer constants defined inline by `DefI`, keyed by register number.
    def_int: BTreeMap<u16, DefInt>,
    /// Boolean constants defined inline by `DefB`, keyed by register number.
    def_bool: BTreeMap<u16, DefBool>,
    /// Input registers referenced by at least one instruction.
    used_inputs: BTreeSet<RegisterKey>,
    /// Output registers referenced by at least one instruction.
    used_outputs: BTreeSet<RegisterKey>,
    /// Temporary-like registers that become local `var`s in the entry point.
    used_temps: BTreeSet<RegisterKey>,
    /// Sampler numbers that were declared or read.
    used_samplers: BTreeSet<u16>,
    /// Float constant register numbers that were read.
    used_consts: BTreeSet<u16>,
    /// Integer constant register numbers that were read.
    used_int_consts: BTreeSet<u16>,
    /// Boolean constant register numbers that were read.
    used_bool_consts: BTreeSet<u16>,
    /// Set when a `Lit` opcode is present; gates emission of the `sm2_lit` helper.
    uses_lit: bool,
    /// Set when a `Dst` opcode is present; gates emission of the `sm2_dst` helper.
    uses_dst: bool,
}
50
51pub fn decompile_wgsl(data: &[u8], ctab: Option<&ConstantTable>) -> String {
52    match parse_shader(data) {
53        Ok(shader) => {
54            let ctx = analyze(&shader, ctab);
55            emit_wgsl(&ctx)
56        }
57        Err(_) => String::new(),
58    }
59}
60
61fn emit_wgsl(ctx: &Context<'_>) -> String {
62    let mut out = String::new();
63    emit_register_bindings(&mut out, ctx);
64    emit_texture_bindings(&mut out, ctx);
65    emit_def_constants(&mut out, ctx);
66    emit_helpers(&mut out, ctx);
67    emit_structs(&mut out, ctx);
68    emit_main(&mut out, ctx);
69    out
70}
71
72fn emit_register_bindings(out: &mut String, ctx: &Context<'_>) {
73    if !ctx.used_consts.is_empty() {
74        out.push_str("struct FloatRegs {\n");
75        out.push_str("    c: array<vec4<f32>, 256>,\n");
76        out.push_str("};\n");
77        out.push_str("@group(0) @binding(0) var<uniform> float_regs: FloatRegs;\n\n");
78    }
79    if !ctx.used_int_consts.is_empty() {
80        out.push_str("struct IntRegs {\n");
81        out.push_str("    i: array<vec4<i32>, 16>,\n");
82        out.push_str("};\n");
83        out.push_str("@group(0) @binding(1) var<uniform> int_regs: IntRegs;\n\n");
84    }
85    if !ctx.used_bool_consts.is_empty() {
86        out.push_str("struct BoolRegs {\n");
87        out.push_str("    b: array<u32, 16>,\n");
88        out.push_str("};\n");
89        out.push_str("@group(0) @binding(2) var<uniform> bool_regs: BoolRegs;\n\n");
90    }
91}
92
93fn emit_texture_bindings(out: &mut String, ctx: &Context<'_>) {
94    for sampler in &ctx.used_samplers {
95        let tex_binding = (*sampler as u32) * 2;
96        let samp_binding = tex_binding + 1;
97        let tex_ty = wgsl_texture_type(sampler_texture_type(ctx, *sampler));
98        out.push_str(&format!("@group(1) @binding({}) var tex_s{}: {};\n", tex_binding, sampler, tex_ty));
99        out.push_str(&format!("@group(1) @binding({}) var samp_s{}: sampler;\n", samp_binding, sampler));
100    }
101    if !ctx.used_samplers.is_empty() {
102        out.push('\n');
103    }
104}
105
106fn emit_def_constants(out: &mut String, ctx: &Context<'_>) {
107    let mut any = false;
108    for (idx, def) in &ctx.def_float {
109        out.push_str(&format!(
110            "const c{}: vec4<f32> = vec4<f32>({}, {}, {}, {});\n",
111            idx,
112            fmt_f32(def.values[0]),
113            fmt_f32(def.values[1]),
114            fmt_f32(def.values[2]),
115            fmt_f32(def.values[3])
116        ));
117        any = true;
118    }
119    for (idx, def) in &ctx.def_int {
120        out.push_str(&format!(
121            "const i{}: vec4<i32> = vec4<i32>({}, {}, {}, {});\n",
122            idx, def.values[0], def.values[1], def.values[2], def.values[3]
123        ));
124        any = true;
125    }
126    for (idx, def) in &ctx.def_bool {
127        out.push_str(&format!("const b{}: bool = {};\n", idx, if def.value { "true" } else { "false" }));
128        any = true;
129    }
130    if any {
131        out.push('\n');
132    }
133}
134
135fn emit_helpers(out: &mut String, ctx: &Context<'_>) {
136    if ctx.uses_lit {
137        out.push_str("fn sm2_lit(v: vec4<f32>) -> vec4<f32> {\n");
138        out.push_str("    let y = max(v.x, 0.0);\n");
139        out.push_str("    let z = select(0.0, pow(max(v.y, 0.0), v.w), v.x > 0.0 && v.y > 0.0);\n");
140        out.push_str("    return vec4<f32>(1.0, y, z, 1.0);\n");
141        out.push_str("}\n\n");
142    }
143    if ctx.uses_dst {
144        out.push_str("fn sm2_dst(a: vec4<f32>, b: vec4<f32>) -> vec4<f32> {\n");
145        out.push_str("    return vec4<f32>(1.0, a.y * b.y, a.z, b.w);\n");
146        out.push_str("}\n\n");
147    }
148}
149
150fn emit_structs(out: &mut String, ctx: &Context<'_>) {
151    let input_name = input_struct_name(ctx.shader.kind);
152    let output_name = output_struct_name(ctx.shader.kind);
153
154    out.push_str(&format!("struct {} {{\n", input_name));
155    for decl in ctx.decls.values() {
156        let field = input_field_name(decl.reg);
157        let ty = wgsl_input_field_type(ctx, decl.reg);
158        let attr = input_attr(ctx.shader.kind, &decl.semantic);
159        out.push_str(&format!("    {} {}: {},\n", attr, field, ty));
160    }
161    out.push_str("};\n\n");
162
163    out.push_str(&format!("struct {} {{\n", output_name));
164    for reg in &ctx.used_outputs {
165        let field = output_field_name(*reg);
166        let sem = output_semantic(ctx, *reg);
167        let ty = wgsl_output_field_type(*reg);
168        let attr = output_attr(ctx.shader.kind, &sem, *reg);
169        out.push_str(&format!("    {} {}: {},\n", attr, field, ty));
170    }
171    out.push_str("};\n\n");
172}
173
174fn emit_main(out: &mut String, ctx: &Context<'_>) {
175    let stage = match ctx.shader.kind {
176        ShaderKind::Vertex => "vertex",
177        ShaderKind::Pixel => "fragment",
178    };
179    let input_name = input_struct_name(ctx.shader.kind);
180    let output_name = output_struct_name(ctx.shader.kind);
181    out.push_str(&format!("@{}\n", stage));
182    out.push_str(&format!("fn main(input: {}) -> {} {{\n", input_name, output_name));
183    out.push_str(&format!("    var output: {};\n", output_name));
184
185    for reg in &ctx.used_outputs {
186        out.push_str(&format!("    output.{} = {};\n", output_field_name(*reg), zero_value(wgsl_output_field_type(*reg))));
187    }
188    for reg in &ctx.used_temps {
189        out.push_str(&format!("    var {}: {} = {};\n", temp_name(*reg, ctx.shader.kind), temp_type(*reg), zero_value(temp_type(*reg))));
190    }
191    if !ctx.used_outputs.is_empty() || !ctx.used_temps.is_empty() {
192        out.push('\n');
193    }
194
195    let mut indent = 1usize;
196    for inst in &ctx.shader.instructions {
197        emit_instruction(out, ctx, inst, &mut indent);
198    }
199
200    out.push_str("    return output;\n");
201    out.push_str("}\n");
202}
203
204fn emit_instruction(out: &mut String, ctx: &Context<'_>, inst: &Instruction, indent: &mut usize) {
205    match inst.opcode {
206        Opcode::Comment | Opcode::Dcl | Opcode::Def | Opcode::DefI | Opcode::DefB | Opcode::Nop | Opcode::End => {}
207        Opcode::TexKill => {
208            if let Some(reg) = inst.dest_register() {
209                let expr = format!("{}{}", register_base(ctx, reg), wgsl_mask_suffix(inst.dest_write_mask()));
210                let n = mask_len(inst.dest_write_mask());
211                if n == 1 {
212                    line(out, *indent, &format!("if ({} < 0.0) {{ discard; }}", expr));
213                } else {
214                    line(out, *indent, &format!("if (any({} < {})) {{ discard; }}", expr, zero_vector(n)));
215                }
216            }
217        }
218        Opcode::If => {
219            let cond = source_expr(ctx, inst, 0, 1);
220            line(out, *indent, &format!("if ({}) {{", scalar_bool_expr(cond)));
221            *indent += 1;
222        }
223        Opcode::IfC => {
224            let a = source_expr(ctx, inst, 0, 4);
225            let b = source_expr(ctx, inst, 1, 4);
226            line(out, *indent, &format!("if ({}) {{", compare_all_expr(a, cmp_op(inst.comparison()), b, 4)));
227            *indent += 1;
228        }
229        Opcode::Else => {
230            if *indent > 0 { *indent -= 1; }
231            line(out, *indent, "} else {");
232            *indent += 1;
233        }
234        Opcode::EndIf => {
235            if *indent > 0 { *indent -= 1; }
236            line(out, *indent, "}");
237        }
238        Opcode::Break => line(out, *indent, "break;"),
239        Opcode::BreakC => {
240            let a = source_expr(ctx, inst, 0, 4);
241            let b = source_expr(ctx, inst, 1, 4);
242            line(out, *indent, &format!("if ({}) {{ break; }}", compare_all_expr(a, cmp_op(inst.comparison()), b, 4)));
243        }
244        Opcode::Rep => {
245            let n = source_expr(ctx, inst, 0, 1);
246            line(out, *indent, &format!("for (var _rep{}: i32 = 0; _rep{} < i32({}); _rep{} = _rep{} + 1) {{", inst.offset, inst.offset, n, inst.offset, inst.offset));
247            *indent += 1;
248        }
249        Opcode::Loop => {
250            line(out, *indent, &format!("for (var _loop{}: i32 = 0; ; _loop{} = _loop{} + 1) {{", inst.offset, inst.offset, inst.offset));
251            *indent += 1;
252        }
253        Opcode::EndRep | Opcode::EndLoop => {
254            if *indent > 0 { *indent -= 1; }
255            line(out, *indent, "}");
256        }
257        Opcode::Ret => line(out, *indent, "return output;"),
258        _ if inst.opcode.has_destination() => {
259            if let Some((dst, rhs)) = assignment_expr(ctx, inst) {
260                line(out, *indent, &format!("{} = {};", dst, rhs));
261            }
262        }
263        _ => {}
264    }
265}
266
267fn assignment_expr(ctx: &Context<'_>, inst: &Instruction) -> Option<(String, String)> {
268    let reg = inst.dest_register()?;
269    let mask = inst.dest_write_mask();
270    let dst_count = mask_len(mask);
271    let dst = format!("{}{}", register_base(ctx, reg), wgsl_mask_suffix(mask));
272
273    let (raw_rhs, raw_width) = match inst.opcode {
274        Opcode::Mov | Opcode::MovA => (source_expr(ctx, inst, 1, dst_count), dst_count),
275        Opcode::Add => (bin(ctx, inst, dst_count, "+"), dst_count),
276        Opcode::Sub => (bin(ctx, inst, dst_count, "-"), dst_count),
277        Opcode::Mul => (bin(ctx, inst, dst_count, "*"), dst_count),
278        Opcode::Mad => (format!("({} * {} + {})", source_expr(ctx, inst, 1, dst_count), source_expr(ctx, inst, 2, dst_count), source_expr(ctx, inst, 3, dst_count)), dst_count),
279        Opcode::Rcp => (format!("(1.0 / {})", source_expr(ctx, inst, 1, 1)), 1),
280        Opcode::Rsq => (format!("inverseSqrt({})", source_expr(ctx, inst, 1, 1)), 1),
281        Opcode::Dp3 => (format!("dot({}, {})", source_expr(ctx, inst, 1, 3), source_expr(ctx, inst, 2, 3)), 1),
282        Opcode::Dp4 => (format!("dot({}, {})", source_expr(ctx, inst, 1, 4), source_expr(ctx, inst, 2, 4)), 1),
283        Opcode::Min => (format!("min({}, {})", source_expr(ctx, inst, 1, dst_count), source_expr(ctx, inst, 2, dst_count)), dst_count),
284        Opcode::Max => (format!("max({}, {})", source_expr(ctx, inst, 1, dst_count), source_expr(ctx, inst, 2, dst_count)), dst_count),
285        Opcode::Slt => (select_float(ctx, inst, dst_count, "<"), dst_count),
286        Opcode::Sge => (select_float(ctx, inst, dst_count, ">="), dst_count),
287        Opcode::Exp | Opcode::ExpP => (format!("exp2({})", source_expr(ctx, inst, 1, 1)), 1),
288        Opcode::Log | Opcode::LogP => (format!("log2({})", source_expr(ctx, inst, 1, 1)), 1),
289        Opcode::Lit => (format!("sm2_lit({})", source_expr(ctx, inst, 1, 4)), 4),
290        Opcode::Dst => (format!("sm2_dst({}, {})", source_expr(ctx, inst, 1, 4), source_expr(ctx, inst, 2, 4)), 4),
291        Opcode::Lrp => (format!("mix({}, {}, {})", source_expr(ctx, inst, 3, dst_count), source_expr(ctx, inst, 2, dst_count), source_expr(ctx, inst, 1, dst_count)), dst_count),
292        Opcode::Frc => (format!("fract({})", source_expr(ctx, inst, 1, dst_count)), dst_count),
293        Opcode::Pow => (format!("pow({}, {})", source_expr(ctx, inst, 1, 1), source_expr(ctx, inst, 2, 1)), 1),
294        Opcode::Crs => (format!("cross({}, {})", source_expr(ctx, inst, 1, 3), source_expr(ctx, inst, 2, 3)), 3),
295        Opcode::Sgn => (format!("sign({})", source_expr(ctx, inst, 1, dst_count)), dst_count),
296        Opcode::Abs => (format!("abs({})", source_expr(ctx, inst, 1, dst_count)), dst_count),
297        Opcode::Nrm => (format!("normalize({})", source_expr(ctx, inst, 1, 3)), 3),
298        Opcode::SinCos => {
299            let width = if dst_count <= 1 { 1 } else { 2 };
300            (sincos_expr(ctx, inst, dst_count), width)
301        }
302        Opcode::Cmp => (select_expr(ctx, inst, dst_count, ">="), dst_count),
303        Opcode::Cnd => (select_expr(ctx, inst, dst_count, ">"), dst_count),
304        Opcode::Dp2Add => (format!("(dot({}, {}) + {})", source_expr(ctx, inst, 1, 2), source_expr(ctx, inst, 2, 2), source_expr(ctx, inst, 3, 1)), 1),
305        Opcode::M4x4 => (matrix_mul_expr(ctx, inst, 4, 4), 4),
306        Opcode::M4x3 => (matrix_mul_expr(ctx, inst, 4, 3), 3),
307        Opcode::M3x4 => (matrix_mul_expr(ctx, inst, 3, 4), 4),
308        Opcode::M3x3 => (matrix_mul_expr(ctx, inst, 3, 3), 3),
309        Opcode::M3x2 => (matrix_mul_expr(ctx, inst, 3, 2), 2),
310        Opcode::Tex | Opcode::TexLdl | Opcode::TexLdd => (texture_expr(ctx, inst), 4),
311        Opcode::TexCoord => (source_expr(ctx, inst, 1, dst_count), dst_count),
312        Opcode::Dsx => (format!("dpdx({})", source_expr(ctx, inst, 1, dst_count)), dst_count),
313        Opcode::Dsy => (format!("dpdy({})", source_expr(ctx, inst, 1, dst_count)), dst_count),
314        _ => return None,
315    };
316
317    let mut rhs = coerce_expr_width(raw_rhs, raw_width, dst_count);
318    rhs = apply_result_modifier(rhs, inst.dest_modifier(), dst_count);
319    Some((dst, rhs))
320}
321fn bin(ctx: &Context<'_>, inst: &Instruction, n: usize, op: &str) -> String {
322    format!("({} {} {})", source_expr(ctx, inst, 1, n), op, source_expr(ctx, inst, 2, n))
323}
324
325fn select_float(ctx: &Context<'_>, inst: &Instruction, n: usize, op: &str) -> String {
326    let a = source_expr(ctx, inst, 1, n);
327    let b = source_expr(ctx, inst, 2, n);
328    let zero = zero_vector(n);
329    let one = one_vector(n);
330    format!("select({}, {}, {} {} {})", zero, one, a, op, b)
331}
332
333fn select_expr(ctx: &Context<'_>, inst: &Instruction, n: usize, op: &str) -> String {
334    let a = source_expr(ctx, inst, 1, n);
335    let yes = source_expr(ctx, inst, 2, n);
336    let no = source_expr(ctx, inst, 3, n);
337    let pivot = if op == ">" { half_vector(n) } else { zero_vector(n) };
338    format!("select({}, {}, {} {} {})", no, yes, a, op, pivot)
339}
340
341fn sincos_expr(ctx: &Context<'_>, inst: &Instruction, n: usize) -> String {
342    if n <= 1 {
343        format!("cos({})", source_expr(ctx, inst, 1, 1))
344    } else {
345        format!("vec2<f32>(cos({}), sin({}))", source_expr(ctx, inst, 1, 1), source_expr(ctx, inst, 1, 1))
346    }
347}
348
349fn matrix_mul_expr(ctx: &Context<'_>, inst: &Instruction, vec_len: usize, out_len: usize) -> String {
350    let v = source_expr(ctx, inst, 1, vec_len);
351    let mut rows = Vec::new();
352    if let Some(reg) = inst.source_register(2) {
353        if reg.ty == RegisterType::Const {
354            for i in 0..out_len {
355                rows.push(format!("dot({}, {})", v, const_row_expr(ctx, reg.number + i as u16)));
356            }
357        } else {
358            for _ in 0..out_len {
359                rows.push(format!("dot({}, {})", v, source_expr(ctx, inst, 2, vec_len)));
360            }
361        }
362    }
363    vector_constructor(&rows)
364}
365
366fn texture_expr(ctx: &Context<'_>, inst: &Instruction) -> String {
367    let sampler = inst.source_register(2).map(|r| r.number).unwrap_or(0);
368    let sampler_ty = sampler_texture_type(ctx, sampler);
369    let dim = sampler_ty.hlsl_dim();
370    let tex = format!("tex_s{}", sampler);
371    let samp = format!("samp_s{}", sampler);
372    let controls = inst.texld_controls();
373    match inst.opcode {
374        Opcode::TexLdl => {
375            let coord = source_expr(ctx, inst, 1, dim);
376            let lod = source_component_expr(ctx, inst, 1, 3);
377            format!("textureSampleLevel({}, {}, {}, {})", tex, samp, coord, lod)
378        }
379        Opcode::TexLdd => {
380            let coord = source_expr(ctx, inst, 1, dim);
381            let ddx = source_expr(ctx, inst, 3, dim);
382            let ddy = source_expr(ctx, inst, 4, dim);
383            format!("textureSampleGrad({}, {}, {}, {}, {})", tex, samp, coord, ddx, ddy)
384        }
385        Opcode::Tex if controls == 1 => {
386            let coord = source_expr(ctx, inst, 1, dim);
387            let w = source_component_expr(ctx, inst, 1, 3);
388            format!("textureSample({}, {}, ({} / {}))", tex, samp, coord, w)
389        }
390        Opcode::Tex if controls == 2 => {
391            let coord = source_expr(ctx, inst, 1, dim);
392            let bias = source_component_expr(ctx, inst, 1, 3);
393            format!("textureSampleBias({}, {}, {}, {})", tex, samp, coord, bias)
394        }
395        _ => {
396            let coord = source_expr(ctx, inst, 1, dim);
397            format!("textureSample({}, {}, {})", tex, samp, coord)
398        }
399    }
400}
401fn source_expr(ctx: &Context<'_>, inst: &Instruction, param_index: usize, count: usize) -> String {
402    let Some(reg) = inst.source_register(param_index) else { return zero_vector(count); };
403    let count = count.clamp(1, 4);
404    let mut expr = register_base(ctx, reg);
405    if !register_is_scalar_source(ctx, reg) {
406        let swz = inst.source_swizzle(param_index);
407        expr.push_str(&wgsl_source_swizzle_suffix(swz, count));
408    }
409    apply_source_modifier(expr, inst.source_modifier(param_index), count)
410}
411
412fn source_component_expr(ctx: &Context<'_>, inst: &Instruction, param_index: usize, component: usize) -> String {
413    let Some(reg) = inst.source_register(param_index) else { return "0.0".to_string(); };
414    let mut expr = register_base(ctx, reg);
415    if !register_is_scalar_source(ctx, reg) {
416        let swz = inst.source_swizzle(param_index);
417        let comp = swz[component.clamp(0, 3)];
418        expr.push_str(component_suffix(comp));
419    }
420    apply_source_modifier(expr, inst.source_modifier(param_index), 1)
421}
422
423fn register_is_scalar_source(ctx: &Context<'_>, reg: RegisterKey) -> bool {
424    match reg.ty {
425        RegisterType::ConstBool | RegisterType::Loop | RegisterType::Label => true,
426        RegisterType::MiscType if ctx.shader.kind == ShaderKind::Pixel && reg.number == 1 => true,
427        _ => false,
428    }
429}
430fn register_base(ctx: &Context<'_>, reg: RegisterKey) -> String {
431    match reg.ty {
432        RegisterType::Temp | RegisterType::TempFloat16 | RegisterType::Predicate => temp_name(reg, ctx.shader.kind),
433        RegisterType::Texture if ctx.shader.kind == ShaderKind::Vertex => temp_name(reg, ctx.shader.kind),
434        RegisterType::Texture | RegisterType::Input | RegisterType::MiscType => format!("input.{}", input_field_name(reg)),
435        RegisterType::Const => const_row_expr(ctx, reg.number),
436        RegisterType::ConstInt => int_const_expr(ctx, reg.number),
437        RegisterType::ConstBool => bool_const_expr(ctx, reg.number),
438        RegisterType::Sampler => format!("samp_s{}", reg.number),
439        RegisterType::ColorOut | RegisterType::DepthOut | RegisterType::RastOut | RegisterType::AttrOut | RegisterType::Output => format!("output.{}", output_field_name(reg)),
440        RegisterType::Loop => "_loop".to_string(),
441        RegisterType::Label => format!("label{}", reg.number),
442        _ => format!("u{}", reg.number),
443    }
444}
445
446fn const_row_expr(ctx: &Context<'_>, index: u16) -> String {
447    if ctx.def_float.contains_key(&index) {
448        return format!("c{}", index);
449    }
450    format!("float_regs.c[{}]", index)
451}
452
453fn int_const_expr(ctx: &Context<'_>, index: u16) -> String {
454    if ctx.def_int.contains_key(&index) {
455        return format!("i{}", index);
456    }
457    format!("int_regs.i[{}]", index)
458}
459
460fn bool_const_expr(ctx: &Context<'_>, index: u16) -> String {
461    if ctx.def_bool.contains_key(&index) {
462        return format!("b{}", index);
463    }
464    format!("(bool_regs.b[{}] != 0u)", index)
465}
466
467fn analyze<'a>(shader: &'a ShaderModel, ctab: Option<&'a ConstantTable>) -> Context<'a> {
468    let mut ctx = Context {
469        shader,
470        ctab,
471        decls: BTreeMap::new(),
472        sampler_decls: BTreeMap::new(),
473        def_float: BTreeMap::new(),
474        def_int: BTreeMap::new(),
475        def_bool: BTreeMap::new(),
476        used_inputs: BTreeSet::new(),
477        used_outputs: BTreeSet::new(),
478        used_temps: BTreeSet::new(),
479        used_samplers: BTreeSet::new(),
480        used_consts: BTreeSet::new(),
481        used_int_consts: BTreeSet::new(),
482        used_bool_consts: BTreeSet::new(),
483        uses_lit: false,
484        uses_dst: false,
485    };
486
487    for inst in &shader.instructions {
488        match inst.opcode {
489            Opcode::Dcl => {
490                if let Some(reg) = inst.dest_register() {
491                    if reg.ty == RegisterType::Sampler {
492                        ctx.sampler_decls.insert(reg.number, inst.decl_sampler_type());
493                        ctx.used_samplers.insert(reg.number);
494                    } else {
495                        let semantic = semantic_from_decl(shader.kind, inst);
496                        ctx.decls.insert(reg, DeclInfo { reg, semantic });
497                    }
498                }
499            }
500            Opcode::Def => {
501                if let Some(reg) = inst.dest_register() {
502                    if reg.ty == RegisterType::Const {
503                        ctx.def_float.insert(reg.number, DefFloat {
504                            values: [inst.get_float_param(1), inst.get_float_param(2), inst.get_float_param(3), inst.get_float_param(4)],
505                        });
506                    }
507                }
508            }
509            Opcode::DefI => {
510                if let Some(reg) = inst.dest_register() {
511                    if reg.ty == RegisterType::ConstInt {
512                        ctx.def_int.insert(reg.number, DefInt {
513                            values: [inst.get_int_param(1), inst.get_int_param(2), inst.get_int_param(3), inst.get_int_param(4)],
514                        });
515                    }
516                }
517            }
518            Opcode::DefB => {
519                if let Some(reg) = inst.dest_register() {
520                    if reg.ty == RegisterType::ConstBool {
521                        ctx.def_bool.insert(reg.number, DefBool { value: inst.get_int_param(1) != 0 });
522                    }
523                }
524            }
525            Opcode::Lit => ctx.uses_lit = true,
526            Opcode::Dst => ctx.uses_dst = true,
527            _ => {}
528        }
529
530        if inst.opcode == Opcode::End || inst.opcode == Opcode::Comment {
531            continue;
532        }
533
534        if inst.opcode == Opcode::TexKill {
535            if let Some(reg) = inst.dest_register() {
536                classify_source_register(&mut ctx, reg);
537            }
538            continue;
539        }
540
541        if inst.opcode.has_destination() {
542            if let Some(reg) = inst.dest_register() {
543                classify_dest_register(&mut ctx, reg);
544            }
545        }
546
547        let first_src = if inst.opcode.has_destination() { 1 } else { 0 };
548        for pi in first_src..inst.params.len() {
549            if matches!(inst.opcode, Opcode::Dcl | Opcode::Def | Opcode::DefI | Opcode::DefB) {
550                continue;
551            }
552            if let Some(reg) = inst.source_register(pi) {
553                classify_source_register(&mut ctx, reg);
554            }
555        }
556    }
557
558    infer_missing_decls(&mut ctx);
559    ctx
560}
561
562fn classify_dest_register(ctx: &mut Context<'_>, reg: RegisterKey) {
563    match reg.ty {
564        RegisterType::Temp | RegisterType::TempFloat16 | RegisterType::Texture | RegisterType::Predicate => {
565            if !(ctx.shader.kind == ShaderKind::Pixel && reg.ty == RegisterType::Texture) {
566                ctx.used_temps.insert(reg);
567            }
568        }
569        RegisterType::RastOut | RegisterType::AttrOut | RegisterType::Output | RegisterType::ColorOut | RegisterType::DepthOut => {
570            ctx.used_outputs.insert(reg);
571        }
572        _ => {}
573    }
574}
575
576fn classify_source_register(ctx: &mut Context<'_>, reg: RegisterKey) {
577    match reg.ty {
578        RegisterType::Input | RegisterType::MiscType => {
579            ctx.used_inputs.insert(reg);
580        }
581        RegisterType::Texture => {
582            if ctx.shader.kind == ShaderKind::Pixel {
583                ctx.used_inputs.insert(reg);
584            } else {
585                ctx.used_temps.insert(reg);
586            }
587        }
588        RegisterType::Temp | RegisterType::TempFloat16 | RegisterType::Predicate => {
589            ctx.used_temps.insert(reg);
590        }
591        RegisterType::Const => {
592            ctx.used_consts.insert(reg.number);
593        }
594        RegisterType::ConstInt => {
595            ctx.used_int_consts.insert(reg.number);
596        }
597        RegisterType::ConstBool => {
598            ctx.used_bool_consts.insert(reg.number);
599        }
600        RegisterType::Sampler => {
601            ctx.used_samplers.insert(reg.number);
602        }
603        RegisterType::RastOut | RegisterType::AttrOut | RegisterType::Output | RegisterType::ColorOut | RegisterType::DepthOut => {
604            ctx.used_outputs.insert(reg);
605        }
606        _ => {}
607    }
608}
609
610fn infer_missing_decls(ctx: &mut Context<'_>) {
611    for reg in ctx.used_inputs.clone() {
612        if ctx.decls.contains_key(&reg) {
613            continue;
614        }
615        let semantic = inferred_input_semantic(ctx.shader.kind, reg);
616        ctx.decls.insert(reg, DeclInfo { reg, semantic });
617    }
618    for sampler in ctx.used_samplers.clone() {
619        ctx.sampler_decls.entry(sampler).or_insert(SamplerTextureType::TwoD);
620    }
621    if ctx.shader.kind == ShaderKind::Pixel && ctx.used_outputs.is_empty() {
622        ctx.used_outputs.insert(RegisterKey { ty: RegisterType::ColorOut, number: 0 });
623    }
624}
625
626fn semantic_from_decl(kind: ShaderKind, inst: &Instruction) -> String {
627    if let Some(reg) = inst.dest_register() {
628        if reg.ty == RegisterType::MiscType {
629            return match reg.number {
630                0 => "VPOS".to_string(),
631                1 => "VFACE".to_string(),
632                _ => format!("TEXCOORD{}", reg.number),
633            };
634        }
635        if kind == ShaderKind::Pixel {
636            match reg.ty {
637                RegisterType::Texture => return format!("TEXCOORD{}", reg.number),
638                RegisterType::Input => return format!("COLOR{}", reg.number),
639                _ => {}
640            }
641        }
642    }
643    let usage = inst.decl_usage();
644    let index = inst.decl_index();
645    let prefix = usage.semantic_prefix();
646    if index == 0 && !matches!(usage, DeclUsage::TexCoord | DeclUsage::Color | DeclUsage::Position) {
647        prefix.to_string()
648    } else {
649        format!("{}{}", prefix, index)
650    }
651}
652
653fn inferred_input_semantic(kind: ShaderKind, reg: RegisterKey) -> String {
654    match (kind, reg.ty) {
655        (ShaderKind::Pixel, RegisterType::Texture) => format!("TEXCOORD{}", reg.number),
656        (ShaderKind::Pixel, RegisterType::Input) => format!("COLOR{}", reg.number),
657        (_, RegisterType::Input) => match reg.number {
658            0 => "POSITION0".to_string(),
659            1 => "NORMAL0".to_string(),
660            n => format!("TEXCOORD{}", n.saturating_sub(2)),
661        },
662        (_, RegisterType::MiscType) => match reg.number {
663            0 => "VPOS".to_string(),
664            1 => "VFACE".to_string(),
665            _ => format!("TEXCOORD{}", reg.number),
666        },
667        _ => format!("TEXCOORD{}", reg.number),
668    }
669}
670
671fn output_semantic(ctx: &Context<'_>, reg: RegisterKey) -> String {
672    if let Some(decl) = ctx.decls.get(&reg) {
673        return decl.semantic.clone();
674    }
675    match reg.ty {
676        RegisterType::ColorOut => format!("COLOR{}", reg.number),
677        RegisterType::DepthOut => "DEPTH".to_string(),
678        RegisterType::RastOut => match reg.number {
679            0 => "POSITION0".to_string(),
680            1 => "FOG".to_string(),
681            2 => "PSIZE".to_string(),
682            n => format!("TEXCOORD{}", n),
683        },
684        RegisterType::AttrOut => format!("COLOR{}", reg.number),
685        RegisterType::Output => format!("TEXCOORD{}", reg.number),
686        _ => format!("TEXCOORD{}", reg.number),
687    }
688}
689
690fn input_attr(kind: ShaderKind, semantic: &str) -> String {
691    let upper = semantic.to_ascii_uppercase();
692    if kind == ShaderKind::Pixel && (upper == "VPOS" || upper == "POSITION" || upper == "POSITION0") {
693        return "@builtin(position)".to_string();
694    }
695    if kind == ShaderKind::Pixel && upper == "VFACE" {
696        return "@builtin(front_facing)".to_string();
697    }
698    format!("@location({})", semantic_location(semantic))
699}
700fn output_attr(kind: ShaderKind, semantic: &str, reg: RegisterKey) -> String {
701    if kind == ShaderKind::Vertex && is_position_semantic(semantic) {
702        "@builtin(position)".to_string()
703    } else if reg.ty == RegisterType::DepthOut || semantic == "DEPTH" {
704        "@builtin(frag_depth)".to_string()
705    } else {
706        format!("@location({})", semantic_location(semantic))
707    }
708}
709
710fn semantic_location(semantic: &str) -> u32 {
711    let upper = semantic.to_ascii_uppercase();
712    if let Some(n) = parse_semantic_index(&upper, "POSITION") {
713        return n;
714    }
715    if let Some(n) = parse_semantic_index(&upper, "NORMAL") {
716        return 1 + n;
717    }
718    if let Some(n) = parse_semantic_index(&upper, "COLOR") {
719        return 2 + n;
720    }
721    if let Some(n) = parse_semantic_index(&upper, "TEXCOORD") {
722        return 4 + n;
723    }
724    if let Some(n) = parse_semantic_index(&upper, "BLENDWEIGHT") {
725        return 12 + n;
726    }
727    if let Some(n) = parse_semantic_index(&upper, "BLENDINDICES") {
728        return 14 + n;
729    }
730    if let Some(n) = parse_semantic_index(&upper, "TANGENT") {
731        return 16 + n;
732    }
733    if let Some(n) = parse_semantic_index(&upper, "BINORMAL") {
734        return 18 + n;
735    }
736    if upper == "FOG" {
737        return 20;
738    }
739    if upper == "PSIZE" {
740        return 21;
741    }
742    31
743}
744
/// Extracts the numeric index that follows `prefix` in `s`.
///
/// Returns `None` when `s` does not start with `prefix` or when the
/// remainder is not a valid `u32`. A bare prefix with nothing after it
/// counts as index 0.
fn parse_semantic_index(s: &str, prefix: &str) -> Option<u32> {
    match s.strip_prefix(prefix)? {
        "" => Some(0),
        digits => digits.parse().ok(),
    }
}
756
/// Reports whether `s` names a position semantic: `POSITION`, `POSITION0`,
/// `POSITIONT` or `POSITIONT0`, compared ASCII case-insensitively.
fn is_position_semantic(s: &str) -> bool {
    matches!(
        s.to_ascii_uppercase().as_str(),
        "POSITION" | "POSITION0" | "POSITIONT" | "POSITIONT0"
    )
}
761
762fn wgsl_input_field_type(ctx: &Context<'_>, reg: RegisterKey) -> &'static str {
763    match reg.ty {
764        RegisterType::MiscType if reg.number == 1 => "bool",
765        _ => {
766            if let Some(decl) = ctx.decls.get(&reg) {
767                if decl.semantic.starts_with("TEXCOORD") || decl.semantic.starts_with("COLOR") {
768                    return "vec4<f32>";
769                }
770            }
771            "vec4<f32>"
772        }
773    }
774}
775
776fn wgsl_output_field_type(reg: RegisterKey) -> &'static str {
777    match reg.ty {
778        RegisterType::DepthOut => "f32",
779        RegisterType::RastOut if reg.number == 1 || reg.number == 2 => "f32",
780        _ => "vec4<f32>",
781    }
782}
783
784fn input_struct_name(kind: ShaderKind) -> &'static str {
785    match kind {
786        ShaderKind::Vertex => "VSInput",
787        ShaderKind::Pixel => "PSInput",
788    }
789}
790
791fn output_struct_name(kind: ShaderKind) -> &'static str {
792    match kind {
793        ShaderKind::Vertex => "VSOutput",
794        ShaderKind::Pixel => "PSOutput",
795    }
796}
797
798fn input_field_name(reg: RegisterKey) -> String {
799    match reg.ty {
800        RegisterType::Input => format!("v{}", reg.number),
801        RegisterType::Texture => format!("t{}", reg.number),
802        RegisterType::MiscType => match reg.number {
803            0 => "vPos".to_string(),
804            1 => "vFace".to_string(),
805            _ => format!("vMisc{}", reg.number),
806        },
807        _ => format!("in{}", reg.number),
808    }
809}
810
811fn output_field_name(reg: RegisterKey) -> String {
812    match reg.ty {
813        RegisterType::ColorOut => format!("oC{}", reg.number),
814        RegisterType::DepthOut => "oDepth".to_string(),
815        RegisterType::RastOut => match reg.number {
816            0 => "oPos".to_string(),
817            1 => "oFog".to_string(),
818            2 => "oPts".to_string(),
819            _ => format!("o{}", reg.number),
820        },
821        RegisterType::AttrOut => format!("oD{}", reg.number),
822        RegisterType::Output => format!("o{}", reg.number),
823        _ => format!("out{}", reg.number),
824    }
825}
826
827fn temp_name(reg: RegisterKey, kind: ShaderKind) -> String {
828    match reg.ty {
829        RegisterType::Texture if kind == ShaderKind::Vertex => format!("a{}", reg.number),
830        RegisterType::Predicate => format!("p{}", reg.number),
831        RegisterType::TempFloat16 => format!("h{}", reg.number),
832        _ => format!("r{}", reg.number),
833    }
834}
835
836fn temp_type(reg: RegisterKey) -> &'static str {
837    match reg.ty {
838        RegisterType::Predicate => "vec4<bool>",
839        RegisterType::Texture => "vec4<f32>",
840        RegisterType::TempFloat16 => "vec4<f32>",
841        _ => "vec4<f32>",
842    }
843}
844
/// Returns the WGSL zero-initialiser expression for the type named `ty`.
///
/// Unrecognised type names are treated as `vec4<f32>`.
fn zero_value(ty: &str) -> &'static str {
    // (type name, zero expression) pairs for every explicitly handled type.
    const ZEROS: [(&str, &str); 5] = [
        ("f32", "0.0"),
        ("bool", "false"),
        ("vec2<f32>", "vec2<f32>(0.0)"),
        ("vec3<f32>", "vec3<f32>(0.0)"),
        ("vec4<bool>", "vec4<bool>(false)"),
    ];
    ZEROS
        .iter()
        .find(|entry| entry.0 == ty)
        .map_or("vec4<f32>(0.0)", |entry| entry.1)
}
855
/// Returns a WGSL expression for an all-zero float value of width `n`.
///
/// Width 1 yields the scalar `0.0`; widths 2 and 3 yield splat constructors;
/// any other width (including 0) yields a `vec4<f32>`.
fn zero_vector(n: usize) -> String {
    let literal = match n {
        1 => "0.0",
        2 => "vec2<f32>(0.0)",
        3 => "vec3<f32>(0.0)",
        _ => "vec4<f32>(0.0)",
    };
    String::from(literal)
}
864
/// Returns a WGSL expression for an all-ones float value of width `n`.
///
/// Width 1 yields the scalar `1.0`; widths 2 and 3 yield splat constructors;
/// any other width (including 0) yields a `vec4<f32>`.
fn one_vector(n: usize) -> String {
    match n {
        1 => String::from("1.0"),
        2 | 3 => format!("vec{}<f32>(1.0)", n),
        _ => String::from("vec4<f32>(1.0)"),
    }
}
873
/// Returns a WGSL expression for a float value of width `n` with every lane
/// set to 0.5.
///
/// Width 1 yields the scalar `0.5`; widths 2 and 3 yield splat constructors;
/// any other width (including 0) yields a `vec4<f32>`.
fn half_vector(n: usize) -> String {
    if n == 1 {
        return "0.5".to_owned();
    }
    // Only 2 and 3 keep their own width; everything else is treated as 4.
    let lanes = if n == 2 || n == 3 { n } else { 4 };
    format!("vec{}<f32>(0.5)", lanes)
}
882
/// Combines per-component expressions into a single WGSL value expression.
///
/// An empty slice yields `0.0`, a single element is returned unchanged, and
/// two or three elements produce `vec2`/`vec3` constructors. Four or more
/// elements produce a `vec4` built from the first four; extras are ignored.
fn vector_constructor(values: &[String]) -> String {
    match values {
        [] => String::from("0.0"),
        [x] => x.clone(),
        [x, y] => format!("vec2<f32>({}, {})", x, y),
        [x, y, z] => format!("vec3<f32>({}, {}, {})", x, y, z),
        [x, y, z, w, ..] => format!("vec4<f32>({}, {}, {}, {})", x, y, z, w),
    }
}
892
/// Converts a write mask into a WGSL swizzle suffix.
///
/// Bit `i` of `mask` selects component `i` of `xyzw`. A mask of 0 is treated
/// as a full mask, and a full mask (`0xf`) produces an empty suffix because
/// no swizzle is needed.
fn wgsl_mask_suffix(mask: u8) -> String {
    let effective = match mask {
        0 => 0xf,
        other => other,
    };
    if effective == 0xf {
        return String::new();
    }

    let mut suffix = String::from(".");
    suffix.extend(
        "xyzw"
            .chars()
            .enumerate()
            .filter(|&(bit, _)| effective & (1u8 << bit) != 0)
            .map(|(_, lane)| lane),
    );
    suffix
}
907
/// Converts a source swizzle into a WGSL swizzle suffix of `count` lanes.
///
/// `count` is clamped to `1..=4`. A full-width identity swizzle needs no
/// suffix and yields an empty string; otherwise the first `count` entries of
/// `swizzle` are rendered as `x`/`y`/`z`/`w`.
///
/// # Panics
///
/// Panics if one of the used swizzle entries is greater than 3.
fn wgsl_source_swizzle_suffix(swizzle: [usize; 4], count: usize) -> String {
    const LANES: [char; 4] = ['x', 'y', 'z', 'w'];

    let width = count.clamp(1, 4);
    if width == 4 && swizzle == [0, 1, 2, 3] {
        return String::new();
    }

    std::iter::once('.')
        .chain(swizzle.iter().take(width).map(|&lane| LANES[lane]))
        .collect()
}
921
/// Returns the single-lane WGSL swizzle suffix for a component index.
///
/// Indices 0 to 3 map to `.x`, `.y`, `.z` and `.w`; any out-of-range index
/// falls back to `.x`.
fn component_suffix(component: usize) -> &'static str {
    const SUFFIXES: [&str; 4] = [".x", ".y", ".z", ".w"];
    SUFFIXES.get(component).copied().unwrap_or(".x")
}
931
/// Adapts a float expression of `src_width` lanes to `dst_width` lanes.
///
/// Both widths are clamped to `1..=4`. The result is:
/// * unchanged when the widths already match;
/// * a splat constructor (`vecN<f32>(expr)`) when the source is a scalar;
/// * a leading-lane swizzle (`(expr).x`, `.xy`, `.xyz`) when narrowing;
/// * a constructor padded with one `0.0` per missing lane when widening a
///   vector.
///
/// The padding count matters for the 2-to-4 case: `vec4<f32>(v2, 0.0)` only
/// supplies three components and is rejected by WGSL, so two zeros are
/// emitted there.
fn coerce_expr_width(expr: String, src_width: usize, dst_width: usize) -> String {
    let src_width = src_width.clamp(1, 4);
    let dst_width = dst_width.clamp(1, 4);

    if src_width == dst_width {
        return expr;
    }

    if src_width == 1 {
        // Scalar source: a single-argument vector constructor splats it.
        return format!("vec{}<f32>({})", dst_width, expr);
    }

    if src_width > dst_width {
        // Narrowing keeps the leading lanes; dst_width is 1..=3 here.
        let suffix = match dst_width {
            1 => ".x",
            2 => ".xy",
            _ => ".xyz",
        };
        return format!("({}){}", expr, suffix);
    }

    // Widening a vector: append exactly as many zeros as lanes are missing.
    let padding = ", 0.0".repeat(dst_width - src_width);
    format!("vec{}<f32>({}{})", dst_width, expr, padding)
}
962fn apply_source_modifier(expr: String, modifier: SourceModifier, count: usize) -> String {
963    match modifier {
964        SourceModifier::None => expr,
965        SourceModifier::Negate => format!("-({})", expr),
966        SourceModifier::Bias => format!("(({}) - {})", expr, half_vector(count)),
967        SourceModifier::BiasAndNegate => format!("-(({}) - {})", expr, half_vector(count)),
968        SourceModifier::Sign => format!("((({}) - {}) * {})", expr, half_vector(count), one_vector(count)),
969        SourceModifier::SignAndNegate => format!("-((({}) - {}) * {})", expr, half_vector(count), one_vector(count)),
970        SourceModifier::Complement => format!("({} - ({}))", one_vector(count), expr),
971        SourceModifier::X2 => format!("(({}) * 2.0)", expr),
972        SourceModifier::X2AndNegate => format!("-(({}) * 2.0)", expr),
973        SourceModifier::DivideByZ => format!("(({}) / ({}).z)", expr, expr),
974        SourceModifier::DivideByW => format!("(({}) / ({}).w)", expr, expr),
975        SourceModifier::Abs => format!("abs({})", expr),
976        SourceModifier::AbsAndNegate => format!("-abs({})", expr),
977        SourceModifier::Not => format!("!({})", expr),
978        SourceModifier::Unknown(_) => expr,
979    }
980}
981
982fn apply_result_modifier(expr: String, modifier: ResultModifier, count: usize) -> String {
983    if modifier.saturate {
984        format!("clamp({}, {}, {})", expr, zero_vector(count), one_vector(count))
985    } else {
986        expr
987    }
988}
989
/// Turns a scalar float expression into a WGSL boolean by comparing it
/// against zero: `(expr != 0.0)`.
fn scalar_bool_expr(expr: String) -> String {
    let mut wrapped = String::with_capacity(expr.len() + "( != 0.0)".len());
    wrapped.push('(');
    wrapped.push_str(&expr);
    wrapped.push_str(" != 0.0)");
    wrapped
}
993
/// Appends `s` to `out` as one line, preceded by `indent` levels of
/// four-space indentation and followed by a newline.
fn line(out: &mut String, indent: usize, s: &str) {
    out.extend(std::iter::repeat("    ").take(indent));
    out.push_str(s);
    out.push('\n');
}
1001
/// Builds a comparison expression `a op b`.
///
/// For widths above 1 the comparison is wrapped in `all(...)` so the
/// component-wise result collapses to a single boolean.
fn compare_all_expr(a: String, op: &str, b: String, width: usize) -> String {
    let comparison = format!("{} {} {}", a, op, b);
    if width > 1 {
        format!("all({})", comparison)
    } else {
        comparison
    }
}
1009
/// Maps a numeric comparison code to its WGSL operator.
///
/// Codes 1 to 6 are `>`, `==`, `>=`, `<`, `!=` and `<=` respectively; any
/// other code falls back to `!=`.
fn cmp_op(code: u8) -> &'static str {
    match code {
        1 => ">",
        2 => "==",
        3 => ">=",
        4 => "<",
        6 => "<=",
        // Code 5 is "not equal", which is also the fallback for unknown codes.
        _ => "!=",
    }
}
1021
1022fn sampler_texture_type(ctx: &Context<'_>, sampler: u16) -> SamplerTextureType {
1023    if let Some(c) = sampler_constant(ctx, sampler) {
1024        if let Some(t) = &c.type_info {
1025            return match t.value_type {
1026                ValueType::SamplerCube => SamplerTextureType::Cube,
1027                ValueType::Sampler3D => SamplerTextureType::Volume,
1028                _ => SamplerTextureType::TwoD,
1029            };
1030        }
1031    }
1032    ctx.sampler_decls.get(&sampler).copied().unwrap_or(SamplerTextureType::TwoD)
1033}
1034
1035fn wgsl_texture_type(ty: SamplerTextureType) -> &'static str {
1036    match ty {
1037        SamplerTextureType::Cube => "texture_cube<f32>",
1038        SamplerTextureType::Volume => "texture_3d<f32>",
1039        SamplerTextureType::TwoD | SamplerTextureType::Unknown => "texture_2d<f32>",
1040    }
1041}
1042
1043fn sampler_constant<'a>(ctx: &'a Context<'_>, sampler: u16) -> Option<&'a ConstantInfo> {
1044    ctx.ctab?.constants.iter().find(|c| c.register_set == RegisterSet::Sampler && c.register_index == sampler)
1045}
1046
/// Formats an `f32` as a WGSL float literal expression.
///
/// NaN and the infinities are emitted as division expressions. Finite values
/// use Rust's shortest round-trip decimal representation, so the literal
/// parses back to exactly the same `f32`, and a `.0` is appended when the
/// text has no fractional part so it is always a float literal.
///
/// The previous fixed nine-decimal-place formatting flushed magnitudes below
/// roughly 5e-10 to `0.0` and dropped significant digits of small values.
fn fmt_f32(v: f32) -> String {
    if v.is_nan() {
        return "(0.0 / 0.0)".to_string();
    }
    if v.is_infinite() {
        let text = if v.is_sign_positive() {
            "(1.0 / 0.0)"
        } else {
            "(-1.0 / 0.0)"
        };
        return text.to_string();
    }

    // `Display` for finite floats never uses exponent notation, so the only
    // thing that can be missing is the fractional part.
    let mut s = v.to_string();
    if !s.contains('.') {
        s.push_str(".0");
    }
    s
}
1065}