1use std::collections::{BTreeMap, BTreeSet};
2use std::fs;
3use std::path::{Path, PathBuf};
4
5use crate::disasm::ShaderKind;
6
/// One member of an HLSL input/output struct, kept as the raw source tokens.
#[derive(Clone, Debug)]
pub struct StructField {
    /// Declared type token (first word of the member declaration).
    pub ty: String,
    /// Member name (second word of the member declaration).
    pub name: String,
    /// Text that followed the `:` on the member line (the HLSL semantic).
    pub semantic: String,
}
13
14#[derive(Debug, Clone)]
15struct ParsedHlsl {
16 samplers: Vec<String>,
17 consts: Vec<(String, String, String)>,
18 input_name: String,
19 output_name: String,
20 input_fields: Vec<StructField>,
21 output_fields: Vec<StructField>,
22 body_lines: Vec<String>,
23 register_count: usize,
24}
25
/// Collects the directories that may hold reference `.hlsl` files.
///
/// Looks for `reference_hlsl` and `output/hlsl` first under the current
/// working directory and then under the directory containing `input`.
/// Only candidates that exist as directories are returned, in that probing
/// order, and a path that compares equal to an earlier one is not repeated.
pub fn discover_reference_hlsl_roots(input: &Path) -> Vec<PathBuf> {
    // Fall back to "." when the working directory cannot be determined.
    let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
    let input_dir = input.parent().unwrap_or(Path::new("."));

    let mut roots: Vec<PathBuf> = Vec::new();
    for base in [cwd.as_path(), input_dir] {
        let candidates = [base.join("reference_hlsl"), base.join("output").join("hlsl")];
        for candidate in candidates {
            if candidate.is_dir() && !roots.contains(&candidate) {
                roots.push(candidate);
            }
        }
    }
    roots
}
42
/// Returns the contents of the first `<prefix>.hlsl` that can be read from
/// `roots`, probing the roots in slice order.
///
/// Roots where the file is missing or unreadable are skipped; `None` means no
/// root yielded a readable file.
pub fn load_reference_hlsl(roots: &[PathBuf], prefix: &str) -> Option<String> {
    let file_name = format!("{prefix}.hlsl");
    roots
        .iter()
        .find_map(|root| fs::read_to_string(root.join(&file_name)).ok())
}
52
53pub fn transpile_reference_hlsl_to_wgsl(src: &str, kind: ShaderKind) -> Result<String, String> {
54 let parsed = parse_hlsl(src)?;
55 Ok(emit_wgsl(&parsed, kind))
56}
57
58fn parse_hlsl(src: &str) -> Result<ParsedHlsl, String> {
59 let mut samplers = Vec::new();
60 let mut consts = Vec::new();
61 for line in src.lines() {
62 let t = line.trim();
63 if let Some(name) = t.strip_prefix("uniform sampler2D ").and_then(|s| s.strip_suffix(';')) {
64 samplers.push(name.trim().to_string());
65 } else if let Some(rest) = t.strip_prefix("static const ") {
66 if let Some((lhs, rhs)) = rest.split_once('=') {
67 let lhs = lhs.trim();
68 let rhs = rhs.trim().trim_end_matches(';').trim();
69 let parts = lhs.split_whitespace().collect::<Vec<_>>();
70 if parts.len() == 2 {
71 consts.push((parts[1].to_string(), parts[0].to_string(), rhs.to_string()));
72 }
73 }
74 }
75 }
76
77 let (input_name, input_fields) = parse_struct(src, "VS_INPUT")
78 .or_else(|| parse_struct(src, "PS_INPUT"))
79 .ok_or("missing input struct")?;
80 let (output_name, output_fields) = parse_struct(src, "VS_OUTPUT")
81 .or_else(|| parse_struct(src, "PS_OUTPUT"))
82 .ok_or("missing output struct")?;
83
84 let body = extract_main_body(src)?;
85 let mut body_lines = Vec::new();
86 let const_names: BTreeSet<String> = consts.iter().map(|(n, _, _)| n.clone()).collect();
87 let mut max_reg = 0usize;
88 for raw in body.lines() {
89 let t = raw.trim();
90 if t.is_empty() {
91 continue;
92 }
93 if t == format!("{} output;", output_name) || t.starts_with("output.") && t.contains("= ") && t.contains("float4(0.0, 0.0, 0.0, 0.0)") {
94 continue;
95 }
96 let line = translate_statement(t, &const_names, &mut max_reg)?;
97 if !line.is_empty() {
98 body_lines.push(line);
99 }
100 }
101
102 Ok(ParsedHlsl {
103 samplers,
104 consts,
105 input_name,
106 output_name,
107 input_fields,
108 output_fields,
109 body_lines,
110 register_count: max_reg.max(1),
111 })
112}
113
114fn parse_struct(src: &str, name: &str) -> Option<(String, Vec<StructField>)> {
115 let marker = format!("struct {} {{", name);
116 let start = src.find(&marker)?;
117 let rest = &src[start + marker.len()..];
118 let end = rest.find("};")?;
119 let body = &rest[..end];
120 let mut fields = Vec::new();
121 for line in body.lines() {
122 let t = line.trim();
123 if t.is_empty() { continue; }
124 let t = t.trim_end_matches(';');
125 let (lhs, semantic) = t.split_once(':')?;
126 let parts = lhs.split_whitespace().collect::<Vec<_>>();
127 if parts.len() != 2 { continue; }
128 fields.push(StructField {
129 ty: parts[0].to_string(),
130 name: parts[1].to_string(),
131 semantic: semantic.trim().to_string(),
132 });
133 }
134 Some((name.to_string(), fields))
135}
136
/// Returns the text between the braces of the `main` function in `src`.
///
/// The function is located by the first ` main(` (or, failing that, `main(`),
/// and its body by brace counting from the first `{` after that point. The
/// outer braces themselves are not included in the result.
///
/// # Errors
/// `"missing main"`, `"missing main body"` or `"unterminated main body"`.
fn extract_main_body(src: &str) -> Result<String, String> {
    let main_pos = match src.find(" main(") {
        Some(pos) => pos,
        None => src.find("main(").ok_or("missing main")?,
    };
    let rest = &src[main_pos..];
    let open = rest.find('{').ok_or("missing main body")?;

    // Walk forward from the opening brace until nesting returns to zero.
    let mut nesting = 0i32;
    let close = rest[open..]
        .char_indices()
        .find(|&(_, ch)| {
            match ch {
                '{' => nesting += 1,
                '}' => nesting -= 1,
                _ => {}
            }
            ch == '}' && nesting == 0
        })
        .map(|(offset, _)| open + offset)
        .ok_or("unterminated main body")?;

    Ok(rest[open + 1..close].to_string())
}
159
160fn translate_statement(line: &str, const_names: &BTreeSet<String>, max_reg: &mut usize) -> Result<String, String> {
161 let mut out = line.to_string();
162 out = out.trim_end_matches(';').to_string();
163 out = replace_types(&out);
164 out = replace_intrinsics(&out);
165 out = replace_sampler_calls(&out);
166 out = replace_registers(&out, const_names, max_reg);
167 if out.contains('?') {
168 out = replace_ternary_expr(&out)?;
169 }
170 Ok(format!(" {};;", out).replace(";;", ";"))
171}
172
/// Rewrites HLSL scalar/vector type names to their WGSL spelling.
///
/// Handles constructor forms (`float4(` -> `vec4<f32>(`) and declaration
/// forms (`float4 ` -> `vec4<f32> `, `float ` -> `f32 `, `int ` -> `i32 `,
/// `uint ` -> `u32 `).
///
/// A type name is only rewritten when it starts a token, i.e. when it is not
/// directly preceded by a letter, digit or underscore. Without that check an
/// identifier that merely ends in a type name is corrupted (`tint ` would
/// become `ti32 `).
fn replace_types(s: &str) -> String {
    const TYPE_MAP: [(&str, &str); 9] = [
        ("float4(", "vec4<f32>("),
        ("float3(", "vec3<f32>("),
        ("float2(", "vec2<f32>("),
        ("float4 ", "vec4<f32> "),
        ("float3 ", "vec3<f32> "),
        ("float2 ", "vec2<f32> "),
        ("float ", "f32 "),
        ("uint ", "u32 "),
        ("int ", "i32 "),
    ];

    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let mut out = String::with_capacity(s.len());
    let mut rest = s;
    // Last character consumed from the input; decides whether `rest` starts a token.
    let mut prev: Option<char> = None;
    while let Some(ch) = rest.chars().next() {
        let hit = if prev.map_or(false, is_ident_char) {
            None
        } else {
            TYPE_MAP.iter().find(|(from, _)| rest.starts_with(*from))
        };
        match hit {
            Some(&(from, to)) => {
                out.push_str(to);
                prev = from.chars().last();
                rest = &rest[from.len()..];
            }
            None => {
                out.push(ch);
                prev = Some(ch);
                rest = &rest[ch.len_utf8()..];
            }
        }
    }
    out
}
189
190fn replace_intrinsics(s: &str) -> String {
191 let mut out = s.to_string();
192 out = out.replace("lerp(", "mix(");
193 out = out.replace("rsqrt(", "inverseSqrt(");
194 out = out.replace("frac(", "fract(");
195 while let Some(pos) = out.find("saturate(") {
196 let start = pos + "saturate(".len();
197 let end = find_matching_paren(&out, start - 1).unwrap_or(out.len() - 1);
198 let inner = &out[start..end];
199 let repl = format!("clamp({}, 0.0, 1.0)", inner);
200 out.replace_range(pos..=end, &repl);
201 }
202 out
203}
204
205fn replace_sampler_calls(s: &str) -> String {
206 let mut out = String::new();
207 let bytes = s.as_bytes();
208 let mut i = 0usize;
209 while i < bytes.len() {
210 if s[i..].starts_with("tex2D(") {
211 let args_start = i + "tex2D(".len();
212 let end = find_matching_paren(s, args_start - 1).unwrap_or(s.len() - 1);
213 let args = &s[args_start..end];
214 if let Some((sampler, coord)) = split_top_level_once(args, ',') {
215 out.push_str(&format!("textureSample(tex_{}, samp_{}, {})", sampler.trim(), sampler.trim(), coord.trim()));
216 i = end + 1;
217 continue;
218 }
219 }
220 out.push(bytes[i] as char);
221 i += 1;
222 }
223 out
224}
225
/// Rewrites constant-register references `c<N>` to `u.c[<N>]`.
///
/// A reference is only recognised when `c<N>` is a whole token: it must not be
/// preceded by a letter, digit, underscore or `.` and must not be followed by
/// a letter, digit or underscore. This keeps identifiers that merely contain
/// such a sequence intact, most importantly the `vec4<f32>` spelling that the
/// type pass has already produced by the time this runs.
///
/// Names listed in `const_names` are left alone. For every rewritten
/// reference `max_reg` is raised to at least `N + 1`.
fn replace_registers(s: &str, const_names: &BTreeSet<String>, max_reg: &mut usize) -> String {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let chars: Vec<char> = s.chars().collect();
    let mut out = String::with_capacity(s.len());
    let mut i = 0usize;
    while i < chars.len() {
        let starts_token = chars[i] == 'c'
            && (i == 0 || !(is_ident_char(chars[i - 1]) || chars[i - 1] == '.'));
        if starts_token {
            let digit_count = chars[i + 1..].iter().take_while(|c| c.is_ascii_digit()).count();
            let end = i + 1 + digit_count;
            let ends_token = chars.get(end).map_or(true, |&c| !is_ident_char(c));
            if digit_count > 0 && ends_token {
                let name: String = chars[i..end].iter().collect();
                if !const_names.contains(&name) {
                    // An index too large for `usize` falls through and is copied verbatim.
                    if let Ok(idx) = name[1..].parse::<usize>() {
                        *max_reg = (*max_reg).max(idx.saturating_add(1));
                        out.push_str(&format!("u.c[{}]", idx));
                        i = end;
                        continue;
                    }
                }
            }
        }
        out.push(chars[i]);
        i += 1;
    }
    out
}
252
/// Rewrites the first parenthesised ternary `(cond ? a : b)` in `s` to the
/// WGSL call `select(b, a, cond)`.
///
/// The enclosing parentheses, and the `:` that belongs to the `?`, are located
/// with nesting-aware scans, so calls inside the condition or the branches
/// and parentheses elsewhere in the statement do not confuse the rewrite.
/// Only the first `?` is handled; any further ternary is left as written.
///
/// # Errors
/// Returns a `bad ternary ...` message when `s` has no `?`, no matching `:`,
/// or the ternary is not enclosed in parentheses.
fn replace_ternary_expr(s: &str) -> Result<String, String> {
    let q = s.find('?').ok_or_else(|| format!("bad ternary: {s}"))?;

    // Forward from `?`: the first `:` at nesting level 0 separates the
    // branches, and the first `)` at level 0 closes the enclosing parenthesis.
    let mut depth = 0i32;
    let mut colon = None;
    let mut close = None;
    for (offset, ch) in s[q + 1..].char_indices() {
        let idx = q + 1 + offset;
        match ch {
            '(' => depth += 1,
            ')' if depth == 0 => {
                close = Some(idx);
                break;
            }
            ')' => depth -= 1,
            ':' if depth == 0 && colon.is_none() => colon = Some(idx),
            _ => {}
        }
    }
    let c = colon.ok_or_else(|| format!("bad ternary colon: {s}"))?;

    // Backward from `?`: skip balanced groups to reach the enclosing `(`.
    let mut depth = 0i32;
    let mut open = None;
    for (idx, ch) in s[..q].char_indices().rev() {
        match ch {
            ')' => depth += 1,
            '(' if depth == 0 => {
                open = Some(idx);
                break;
            }
            '(' => depth -= 1,
            _ => {}
        }
    }
    let open = open.ok_or_else(|| format!("bad ternary open: {s}"))?;
    let close = close.ok_or_else(|| format!("bad ternary close: {s}"))?;

    let prefix = &s[..open];
    let cond = s[open + 1..q].trim();
    let true_expr = s[q + 1..c].trim();
    let false_expr = s[c + 1..close].trim();
    let suffix = &s[close + 1..];
    // WGSL `select` takes the false value first and the condition last.
    Ok(format!("{}select({}, {}, {}){}", prefix, false_expr, true_expr, cond, suffix))
}
265
/// Returns the byte index of the `)` that brings the parenthesis nesting back
/// to zero, scanning from byte offset `open_idx` (normally the index of a `(`).
///
/// Returns `None` when no such `)` exists at or after `open_idx`.
fn find_matching_paren(s: &str, open_idx: usize) -> Option<usize> {
    let mut nesting = 0i32;
    // Parentheses are ASCII, so scanning raw bytes yields the same offsets.
    s.bytes().enumerate().skip(open_idx).find_map(|(pos, byte)| {
        if byte == b'(' {
            nesting += 1;
            None
        } else if byte == b')' {
            nesting -= 1;
            if nesting == 0 {
                Some(pos)
            } else {
                None
            }
        } else {
            None
        }
    })
}
280
/// Returns the byte index of the first `target` that is not nested inside
/// parentheses, or `None` if there is none.
///
/// `(` and `)` only adjust the nesting level, so they are never reported even
/// when passed as `target`.
fn find_top_level_char(s: &str, target: char) -> Option<usize> {
    let mut nesting = 0i32;
    s.char_indices().find_map(|(pos, ch)| {
        if ch == '(' {
            nesting += 1;
        } else if ch == ')' {
            nesting -= 1;
        } else if ch == target && nesting == 0 {
            return Some(pos);
        }
        None
    })
}
293
/// Returns the byte index of the first `:` located after byte offset
/// `q_index` that is not nested inside parentheses opened after `q_index`.
///
/// Returns `None` when there is no such colon.
fn find_matching_colon(s: &str, q_index: usize) -> Option<usize> {
    let mut nesting = 0i32;
    s.char_indices()
        .filter(|&(pos, _)| pos > q_index)
        .find_map(|(pos, ch)| {
            if ch == '(' {
                nesting += 1;
            } else if ch == ')' {
                nesting -= 1;
            } else if ch == ':' && nesting == 0 {
                return Some(pos);
            }
            None
        })
}
306
/// Splits `s` at the first `needle` that is not nested inside parentheses.
///
/// Returns the text before and after the needle (the needle itself is
/// dropped), or `None` when no top-level needle exists. The split offset
/// accounts for the needle's UTF-8 width, so multi-byte needles are safe.
/// `(` and `)` only adjust nesting and are never treated as the needle.
fn split_top_level_once(s: &str, needle: char) -> Option<(&str, &str)> {
    let mut depth = 0i32;
    let at = s.char_indices().find_map(|(i, ch)| {
        match ch {
            '(' => depth += 1,
            ')' => depth -= 1,
            _ if ch == needle && depth == 0 => return Some(i),
            _ => {}
        }
        None
    })?;
    Some((&s[..at], &s[at + needle.len_utf8()..]))
}
319
320fn emit_wgsl(hlsl: &ParsedHlsl, kind: ShaderKind) -> String {
321 let mut out = String::new();
322 out.push_str(&format!("struct FloatRegs {{\n c: array<vec4<f32>, {}>,\n}};\n\n", hlsl.register_count));
323 out.push_str("@group(0) @binding(0) var<uniform> u: FloatRegs;\n\n");
324
325 let mut sampler_idx = BTreeMap::new();
326 for name in &hlsl.samplers {
327 let idx = name.trim_start_matches('s').parse::<u32>().unwrap_or(0);
328 sampler_idx.insert(name.clone(), idx);
329 }
330 for name in &hlsl.samplers {
331 let idx = sampler_idx[name];
332 out.push_str(&format!("@group(1) @binding({}) var tex_{}: texture_2d<f32>;\n", idx * 2, name));
333 out.push_str(&format!("@group(1) @binding({}) var samp_{}: sampler;\n", idx * 2 + 1, name));
334 }
335 if !hlsl.samplers.is_empty() {
336 out.push('\n');
337 }
338
339 for (name, ty, value) in &hlsl.consts {
340 out.push_str(&format!("const {}: {} = {};;\n", name, wgsl_type(ty), replace_types(value)).replace(";;", ";"));
341 }
342 if !hlsl.consts.is_empty() {
343 out.push('\n');
344 }
345
346 emit_struct(&mut out, &hlsl.input_name, &hlsl.input_fields, kind, true);
347 out.push('\n');
348 emit_struct(&mut out, &hlsl.output_name, &hlsl.output_fields, kind, false);
349 out.push('\n');
350
351 match kind {
352 ShaderKind::Vertex => out.push_str("@vertex\n"),
353 ShaderKind::Pixel => out.push_str("@fragment\n"),
354 }
355 out.push_str(&format!("fn main(input: {}) -> {} {{\n", hlsl.input_name, hlsl.output_name));
356 out.push_str(&format!(" var output: {};\n", hlsl.output_name));
357 for line in &hlsl.body_lines {
358 out.push_str(line);
359 out.push('\n');
360 }
361 out.push_str(" return output;\n}\n");
362 out
363}
364
365fn emit_struct(out: &mut String, name: &str, fields: &[StructField], kind: ShaderKind, is_input: bool) {
366 out.push_str(&format!("struct {} {{\n", name));
367 for field in fields {
368 let attr = semantic_to_wgsl_attribute(&field.semantic, kind, is_input);
369 out.push_str(&format!(" {}{}: {},\n", attr, field.name, wgsl_type(&field.ty)));
370 }
371 out.push_str("};\n");
372}
373
/// Maps an HLSL (or already-translated WGSL) type name to its WGSL spelling.
///
/// Covers the float scalar/vector types plus `int`, `uint` and `bool`, so
/// that scalar constants are not declared as vectors. Any other type name
/// falls back to `vec4<f32>`.
fn wgsl_type(ty: &str) -> &'static str {
    match ty {
        "float4" | "vec4<f32>" => "vec4<f32>",
        "float3" | "vec3<f32>" => "vec3<f32>",
        "float2" | "vec2<f32>" => "vec2<f32>",
        "float" | "f32" => "f32",
        "int" | "i32" => "i32",
        "uint" | "u32" => "u32",
        "bool" => "bool",
        // Unknown types keep the historical default.
        _ => "vec4<f32>",
    }
}
383
384fn semantic_to_wgsl_attribute(semantic: &str, kind: ShaderKind, is_input: bool) -> String {
385 if kind == ShaderKind::Pixel && !is_input {
386 if let Some(n) = semantic.strip_prefix("COLOR").and_then(|s| s.parse::<u32>().ok()) {
387 return format!("@location({}) ", n);
388 }
389 }
390 if semantic.starts_with("POSITION") {
391 if kind == ShaderKind::Vertex && !is_input {
392 return "@builtin(position) ".to_string();
393 }
394 return "@location(0) ".to_string();
395 }
396 if semantic == "NORMAL" {
397 return "@location(1) ".to_string();
398 }
399 if let Some(n) = semantic.strip_prefix("COLOR").and_then(|s| s.parse::<u32>().ok()) {
400 return format!("@location({}) ", 2 + n);
401 }
402 if let Some(n) = semantic.strip_prefix("TEXCOORD").and_then(|s| s.parse::<u32>().ok()) {
403 return format!("@location({}) ", 4 + n);
404 }
405 "@location(15) ".to_string()
406}