Skip to main content

siglus_cfx_decompiler/
hlsl_ref.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::fs;
3use std::path::{Path, PathBuf};
4
5use crate::disasm::ShaderKind;
6
/// One field of an HLSL input/output struct, e.g. `float4 pos : POSITION;`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructField {
    /// HLSL type token as written in the source (e.g. `float4`).
    pub ty: String,
    /// Field identifier (e.g. `pos`).
    pub name: String,
    /// Semantic text after the `:`, trimmed (e.g. `POSITION`, `TEXCOORD0`).
    pub semantic: String,
}
13
/// Intermediate form of a reference HLSL shader, built by `parse_hlsl` and
/// rendered to WGSL by `emit_wgsl`.
#[derive(Debug, Clone)]
struct ParsedHlsl {
    /// Names from `uniform sampler2D <name>;` lines, in source order.
    samplers: Vec<String>,
    /// `static const` declarations as `(name, hlsl_type, initializer)`.
    consts: Vec<(String, String, String)>,
    /// Name of the input struct that was found (`VS_INPUT` or `PS_INPUT`).
    input_name: String,
    /// Name of the output struct that was found (`VS_OUTPUT` or `PS_OUTPUT`).
    output_name: String,
    /// Fields of the input struct.
    input_fields: Vec<StructField>,
    /// Fields of the output struct.
    output_fields: Vec<StructField>,
    /// Statements of `main`, already translated to WGSL, one per entry.
    body_lines: Vec<String>,
    /// Length of the `u.c` register array: highest `cN` index referenced + 1, at least 1.
    register_count: usize,
}
25
/// Lists the existing directories that may hold reference `.hlsl` files.
///
/// Candidates are probed in priority order: `reference_hlsl/` and `output/hlsl/`
/// under the current working directory (or `.` if it cannot be determined), then
/// the same two under the parent directory of `input`. Only directories that exist
/// are returned, and each appears at most once.
pub fn discover_reference_hlsl_roots(input: &Path) -> Vec<PathBuf> {
    let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
    let input_dir = input.parent().unwrap_or(Path::new("."));

    let mut roots: Vec<PathBuf> = Vec::new();
    for base in [cwd.as_path(), input_dir] {
        for candidate in [base.join("reference_hlsl"), base.join("output").join("hlsl")] {
            if candidate.is_dir() && !roots.contains(&candidate) {
                roots.push(candidate);
            }
        }
    }
    roots
}
42
/// Reads `<prefix>.hlsl` from the first root directory that yields a readable file.
///
/// Roots are tried in order; missing or unreadable files are skipped silently.
/// Returns `None` when no root provides the file.
pub fn load_reference_hlsl(roots: &[PathBuf], prefix: &str) -> Option<String> {
    let file_name = format!("{prefix}.hlsl");
    roots
        .iter()
        .find_map(|root| fs::read_to_string(root.join(&file_name)).ok())
}
52
53pub fn transpile_reference_hlsl_to_wgsl(src: &str, kind: ShaderKind) -> Result<String, String> {
54    let parsed = parse_hlsl(src)?;
55    Ok(emit_wgsl(&parsed, kind))
56}
57
58fn parse_hlsl(src: &str) -> Result<ParsedHlsl, String> {
59    let mut samplers = Vec::new();
60    let mut consts = Vec::new();
61    for line in src.lines() {
62        let t = line.trim();
63        if let Some(name) = t.strip_prefix("uniform sampler2D ").and_then(|s| s.strip_suffix(';')) {
64            samplers.push(name.trim().to_string());
65        } else if let Some(rest) = t.strip_prefix("static const ") {
66            if let Some((lhs, rhs)) = rest.split_once('=') {
67                let lhs = lhs.trim();
68                let rhs = rhs.trim().trim_end_matches(';').trim();
69                let parts = lhs.split_whitespace().collect::<Vec<_>>();
70                if parts.len() == 2 {
71                    consts.push((parts[1].to_string(), parts[0].to_string(), rhs.to_string()));
72                }
73            }
74        }
75    }
76
77    let (input_name, input_fields) = parse_struct(src, "VS_INPUT")
78        .or_else(|| parse_struct(src, "PS_INPUT"))
79        .ok_or("missing input struct")?;
80    let (output_name, output_fields) = parse_struct(src, "VS_OUTPUT")
81        .or_else(|| parse_struct(src, "PS_OUTPUT"))
82        .ok_or("missing output struct")?;
83
84    let body = extract_main_body(src)?;
85    let mut body_lines = Vec::new();
86    let const_names: BTreeSet<String> = consts.iter().map(|(n, _, _)| n.clone()).collect();
87    let mut max_reg = 0usize;
88    for raw in body.lines() {
89        let t = raw.trim();
90        if t.is_empty() {
91            continue;
92        }
93        if t == format!("{} output;", output_name) || t.starts_with("output.") && t.contains("= ") && t.contains("float4(0.0, 0.0, 0.0, 0.0)") {
94            continue;
95        }
96        let line = translate_statement(t, &const_names, &mut max_reg)?;
97        if !line.is_empty() {
98            body_lines.push(line);
99        }
100    }
101
102    Ok(ParsedHlsl {
103        samplers,
104        consts,
105        input_name,
106        output_name,
107        input_fields,
108        output_fields,
109        body_lines,
110        register_count: max_reg.max(1),
111    })
112}
113
114fn parse_struct(src: &str, name: &str) -> Option<(String, Vec<StructField>)> {
115    let marker = format!("struct {} {{", name);
116    let start = src.find(&marker)?;
117    let rest = &src[start + marker.len()..];
118    let end = rest.find("};")?;
119    let body = &rest[..end];
120    let mut fields = Vec::new();
121    for line in body.lines() {
122        let t = line.trim();
123        if t.is_empty() { continue; }
124        let t = t.trim_end_matches(';');
125        let (lhs, semantic) = t.split_once(':')?;
126        let parts = lhs.split_whitespace().collect::<Vec<_>>();
127        if parts.len() != 2 { continue; }
128        fields.push(StructField {
129            ty: parts[0].to_string(),
130            name: parts[1].to_string(),
131            semantic: semantic.trim().to_string(),
132        });
133    }
134    Some((name.to_string(), fields))
135}
136
/// Returns the text strictly between the braces of the `main` function.
///
/// The first ` main(` is used (falling back to the first `main(` anywhere), then
/// the first `{` after it; braces are counted from there until the matching `}`.
///
/// # Errors
/// `"missing main"`, `"missing main body"` or `"unterminated main body"` when the
/// corresponding piece cannot be found.
fn extract_main_body(src: &str) -> Result<String, String> {
    let main_at = match src.find(" main(") {
        Some(pos) => pos,
        None => src.find("main(").ok_or("missing main")?,
    };
    let from_main = &src[main_at..];
    let open = from_main.find('{').ok_or("missing main body")?;

    let mut depth = 0i32;
    let close = from_main[open..]
        .char_indices()
        .find_map(|(offset, ch)| {
            depth += match ch {
                '{' => 1,
                '}' => -1,
                _ => return None,
            };
            if depth == 0 {
                Some(open + offset)
            } else {
                None
            }
        })
        .ok_or("unterminated main body")?;

    Ok(from_main[open + 1..close].to_string())
}
159
160fn translate_statement(line: &str, const_names: &BTreeSet<String>, max_reg: &mut usize) -> Result<String, String> {
161    let mut out = line.to_string();
162    out = out.trim_end_matches(';').to_string();
163    out = replace_types(&out);
164    out = replace_intrinsics(&out);
165    out = replace_sampler_calls(&out);
166    out = replace_registers(&out, const_names, max_reg);
167    if out.contains('?') {
168        out = replace_ternary_expr(&out)?;
169    }
170    Ok(format!("    {};;", out).replace(";;", ";"))
171}
172
/// Rewrites HLSL scalar/vector type names (declarations and constructors) to WGSL.
///
/// Only whole words are replaced: a pattern matches when it starts at the
/// beginning of `s` or right after a character that cannot be part of an
/// identifier. This leaves names such as `tint `, `point ` or `myfloat2(` intact;
/// plain substring replacement used to turn `tint ` into `ti32 `. Scalar casts
/// `float(`/`int(` are also mapped to `f32(`/`i32(`.
fn replace_types(s: &str) -> String {
    /// Replaces `from` with `to` only where `from` begins a word.
    fn replace_at_word_start(text: &str, from: &str, to: &str) -> String {
        let mut result = String::with_capacity(text.len());
        let mut prev_char: Option<char> = None;
        let mut rest = text;
        while let Some(pos) = rest.find(from) {
            let before = rest[..pos].chars().next_back().or(prev_char);
            let at_word_start = !before.map_or(false, |c| c.is_alphanumeric() || c == '_');
            result.push_str(&rest[..pos]);
            result.push_str(if at_word_start { to } else { from });
            prev_char = from.chars().next_back();
            rest = &rest[pos + from.len()..];
        }
        result.push_str(rest);
        result
    }

    const RENAMES: [(&str, &str); 10] = [
        ("float4(", "vec4<f32>("),
        ("float3(", "vec3<f32>("),
        ("float2(", "vec2<f32>("),
        ("float(", "f32("),
        ("int(", "i32("),
        ("float4 ", "vec4<f32> "),
        ("float3 ", "vec3<f32> "),
        ("float2 ", "vec2<f32> "),
        ("float ", "f32 "),
        ("int ", "i32 "),
    ];
    RENAMES
        .iter()
        .fold(s.to_string(), |acc, &(from, to)| replace_at_word_start(&acc, from, to))
}
189
/// Renames HLSL intrinsics whose WGSL counterparts differ only in name.
///
/// `lerp`→`mix`, `rsqrt`→`inverseSqrt`, `frac`→`fract`, `ddx`→`dpdx`, `ddy`→`dpdy`.
/// `saturate(...)` is deliberately left as-is because WGSL has a `saturate`
/// builtin for scalars and vectors. The previous rewrite to `clamp(x, 0.0, 1.0)`
/// does not type-check when `x` is a vector (WGSL `clamp` needs all operands of
/// one type). It also panicked on an unbalanced `saturate(`.
fn replace_intrinsics(s: &str) -> String {
    const RENAMES: [(&str, &str); 5] = [
        ("lerp(", "mix("),
        ("rsqrt(", "inverseSqrt("),
        ("frac(", "fract("),
        ("ddx(", "dpdx("),
        ("ddy(", "dpdy("),
    ];
    RENAMES
        .iter()
        .fold(s.to_string(), |acc, &(hlsl, wgsl)| acc.replace(hlsl, wgsl))
}
204
205fn replace_sampler_calls(s: &str) -> String {
206    let mut out = String::new();
207    let bytes = s.as_bytes();
208    let mut i = 0usize;
209    while i < bytes.len() {
210        if s[i..].starts_with("tex2D(") {
211            let args_start = i + "tex2D(".len();
212            let end = find_matching_paren(s, args_start - 1).unwrap_or(s.len() - 1);
213            let args = &s[args_start..end];
214            if let Some((sampler, coord)) = split_top_level_once(args, ',') {
215                out.push_str(&format!("textureSample(tex_{}, samp_{}, {})", sampler.trim(), sampler.trim(), coord.trim()));
216                i = end + 1;
217                continue;
218            }
219        }
220        out.push(bytes[i] as char);
221        i += 1;
222    }
223    out
224}
225
/// Maps bare constant-register tokens `c<N>` to `u.c[N]` and raises `max_reg` to `N + 1`.
///
/// A token qualifies only when it is a whole identifier. It must not be preceded
/// by an identifier character or `.`, and must not be followed by one. This means
/// `spec1`, `c1_tmp`, `input.c2` and `c3x` are left alone, whereas the old scan
/// rewrote any `c` followed by digits. Names listed in `const_names` (HLSL
/// `static const`s) are never mapped.
fn replace_registers(s: &str, const_names: &BTreeSet<String>, max_reg: &mut usize) -> String {
    let is_ident = |c: char| c.is_alphanumeric() || c == '_';
    let chars: Vec<char> = s.chars().collect();
    let mut out = String::with_capacity(s.len());
    let mut i = 0usize;
    while i < chars.len() {
        let starts_token = i == 0 || !(is_ident(chars[i - 1]) || chars[i - 1] == '.');
        if chars[i] == 'c' && starts_token {
            let mut j = i + 1;
            while j < chars.len() && chars[j].is_ascii_digit() {
                j += 1;
            }
            let ends_token = j == chars.len() || !is_ident(chars[j]);
            if j > i + 1 && ends_token {
                let name: String = chars[i..j].iter().collect();
                if !const_names.contains(&name) {
                    if let Ok(idx) = name[1..].parse::<usize>() {
                        *max_reg = (*max_reg).max(idx.saturating_add(1));
                        out.push_str(&format!("u.c[{}]", idx));
                        i = j;
                        continue;
                    }
                }
            }
        }
        out.push(chars[i]);
        i += 1;
    }
    out
}
252
253fn replace_ternary_expr(s: &str) -> Result<String, String> {
254    let q = s.find('?').ok_or_else(|| format!("bad ternary: {s}"))?;
255    let c = s[q + 1..].find(':').map(|x| q + 1 + x).ok_or_else(|| format!("bad ternary colon: {s}"))?;
256    let open = s[..q].rfind('(').ok_or_else(|| format!("bad ternary open: {s}"))?;
257    let close = s[c + 1..].rfind(')').map(|x| c + 1 + x).ok_or_else(|| format!("bad ternary close: {s}"))?;
258    let prefix = &s[..open];
259    let cond = s[open + 1..q].trim();
260    let true_expr = s[q + 1..c].trim();
261    let false_expr = s[c + 1..close].trim();
262    let suffix = &s[close + 1..];
263    Ok(format!("{}select({}, {}, {}){}", prefix, false_expr, true_expr, cond, suffix))
264}
265
/// Returns the byte index of the `)` that balances the parentheses counted from `open_idx`.
///
/// Characters before `open_idx` are ignored. Nesting depth is tracked from there,
/// and the first `)` that brings it back to zero is returned. Returns `None` if
/// the parentheses never balance.
fn find_matching_paren(s: &str, open_idx: usize) -> Option<usize> {
    let mut level = 0i32;
    s.char_indices()
        .filter(|&(pos, _)| pos >= open_idx)
        .find_map(|(pos, ch)| {
            match ch {
                '(' => level += 1,
                ')' => {
                    level -= 1;
                    if level == 0 {
                        return Some(pos);
                    }
                }
                _ => {}
            }
            None
        })
}
280
/// Returns the byte index of the first `target` that is not nested inside parentheses.
///
/// Parentheses themselves are only used for depth tracking and are never reported
/// as matches. After an unmatched `)` the depth goes negative, so later
/// occurrences are not at depth 0.
fn find_top_level_char(s: &str, target: char) -> Option<usize> {
    let mut nesting = 0i32;
    for (pos, ch) in s.char_indices() {
        if ch == '(' {
            nesting += 1;
        } else if ch == ')' {
            nesting -= 1;
        } else if ch == target && nesting == 0 {
            return Some(pos);
        }
    }
    None
}
293
/// Returns the byte index of the `:` that pairs with the `?` at `q_index`.
///
/// Scans strictly after `q_index`, tracking parenthesis depth relative to the `?`.
/// The first `:` at depth 0 wins. Once the enclosing group closes (depth < 0),
/// no `:` can match.
fn find_matching_colon(s: &str, q_index: usize) -> Option<usize> {
    let mut nesting = 0i32;
    s.char_indices()
        .filter(|&(pos, _)| pos > q_index)
        .find(|&(_, ch)| {
            match ch {
                '(' => nesting += 1,
                ')' => nesting -= 1,
                ':' => return nesting == 0,
                _ => {}
            }
            false
        })
        .map(|(pos, _)| pos)
}
306
/// Splits `s` at the first `needle` that is not nested inside parentheses.
///
/// Returns the text before and after that character, excluding the character
/// itself (one byte is skipped, so callers pass an ASCII `needle`). Returns `None`
/// when there is no top-level occurrence.
fn split_top_level_once(s: &str, needle: char) -> Option<(&str, &str)> {
    let mut nesting = 0i32;
    let split_at = s.char_indices().find_map(|(pos, ch)| {
        match ch {
            '(' => nesting += 1,
            ')' => nesting -= 1,
            other if other == needle && nesting == 0 => return Some(pos),
            _ => {}
        }
        None
    })?;
    let (head, tail) = s.split_at(split_at);
    Some((head, &tail[1..]))
}
319
320fn emit_wgsl(hlsl: &ParsedHlsl, kind: ShaderKind) -> String {
321    let mut out = String::new();
322    out.push_str(&format!("struct FloatRegs {{\n    c: array<vec4<f32>, {}>,\n}};\n\n", hlsl.register_count));
323    out.push_str("@group(0) @binding(0) var<uniform> u: FloatRegs;\n\n");
324
325    let mut sampler_idx = BTreeMap::new();
326    for name in &hlsl.samplers {
327        let idx = name.trim_start_matches('s').parse::<u32>().unwrap_or(0);
328        sampler_idx.insert(name.clone(), idx);
329    }
330    for name in &hlsl.samplers {
331        let idx = sampler_idx[name];
332        out.push_str(&format!("@group(1) @binding({}) var tex_{}: texture_2d<f32>;\n", idx * 2, name));
333        out.push_str(&format!("@group(1) @binding({}) var samp_{}: sampler;\n", idx * 2 + 1, name));
334    }
335    if !hlsl.samplers.is_empty() {
336        out.push('\n');
337    }
338
339    for (name, ty, value) in &hlsl.consts {
340        out.push_str(&format!("const {}: {} = {};;\n", name, wgsl_type(ty), replace_types(value)).replace(";;", ";"));
341    }
342    if !hlsl.consts.is_empty() {
343        out.push('\n');
344    }
345
346    emit_struct(&mut out, &hlsl.input_name, &hlsl.input_fields, kind, true);
347    out.push('\n');
348    emit_struct(&mut out, &hlsl.output_name, &hlsl.output_fields, kind, false);
349    out.push('\n');
350
351    match kind {
352        ShaderKind::Vertex => out.push_str("@vertex\n"),
353        ShaderKind::Pixel => out.push_str("@fragment\n"),
354    }
355    out.push_str(&format!("fn main(input: {}) -> {} {{\n", hlsl.input_name, hlsl.output_name));
356    out.push_str(&format!("    var output: {};\n", hlsl.output_name));
357    for line in &hlsl.body_lines {
358        out.push_str(line);
359        out.push('\n');
360    }
361    out.push_str("    return output;\n}\n");
362    out
363}
364
365fn emit_struct(out: &mut String, name: &str, fields: &[StructField], kind: ShaderKind, is_input: bool) {
366    out.push_str(&format!("struct {} {{\n", name));
367    for field in fields {
368        let attr = semantic_to_wgsl_attribute(&field.semantic, kind, is_input);
369        out.push_str(&format!("    {}{}: {},\n", attr, field.name, wgsl_type(&field.ty)));
370    }
371    out.push_str("};\n");
372}
373
/// Maps an HLSL (or already-translated WGSL) type name to a WGSL type.
///
/// Float and `half` vectors/scalars map to their `f32` forms. `int`, `uint` and
/// `bool` map to `i32`, `u32` and `bool`; they used to fall through to
/// `vec4<f32>`, which produced invalid code such as `const N: vec4<f32> = 3;`.
/// Any other type still falls back to `vec4<f32>`.
fn wgsl_type(ty: &str) -> &'static str {
    match ty {
        "float4" | "half4" | "vec4<f32>" => "vec4<f32>",
        "float3" | "half3" | "vec3<f32>" => "vec3<f32>",
        "float2" | "half2" | "vec2<f32>" => "vec2<f32>",
        "float" | "half" | "f32" => "f32",
        "int" | "i32" => "i32",
        "uint" | "u32" => "u32",
        "bool" => "bool",
        _ => "vec4<f32>",
    }
}
383
384fn semantic_to_wgsl_attribute(semantic: &str, kind: ShaderKind, is_input: bool) -> String {
385    if kind == ShaderKind::Pixel && !is_input {
386        if let Some(n) = semantic.strip_prefix("COLOR").and_then(|s| s.parse::<u32>().ok()) {
387            return format!("@location({}) ", n);
388        }
389    }
390    if semantic.starts_with("POSITION") {
391        if kind == ShaderKind::Vertex && !is_input {
392            return "@builtin(position) ".to_string();
393        }
394        return "@location(0) ".to_string();
395    }
396    if semantic == "NORMAL" {
397        return "@location(1) ".to_string();
398    }
399    if let Some(n) = semantic.strip_prefix("COLOR").and_then(|s| s.parse::<u32>().ok()) {
400        return format!("@location({}) ", 2 + n);
401    }
402    if let Some(n) = semantic.strip_prefix("TEXCOORD").and_then(|s| s.parse::<u32>().ok()) {
403        return format!("@location({}) ", 4 + n);
404    }
405    "@location(15) ".to_string()
406}