1use std::collections::{BTreeMap, BTreeSet};
2
3use crate::ctab::{ConstantInfo, ConstantTable, RegisterSet, ValueType};
4use crate::disasm::{
5 mask_len, DeclUsage, Instruction, Opcode, RegisterKey, RegisterType, ResultModifier,
6 SamplerTextureType, ShaderKind, SourceModifier,
7};
8use crate::disasm::{parse_shader, ShaderModel};
9
/// A register declaration together with the semantic string assigned to it.
#[derive(Debug, Clone)]
struct DeclInfo {
    /// Register the declaration applies to.
    reg: RegisterKey,
    /// Semantic string for the register (for example `TEXCOORD0`).
    semantic: String,
}
15
/// Four float values recorded from a `Def` instruction for a float constant register.
#[derive(Debug, Clone)]
struct DefFloat {
    /// The x, y, z, w components, in that order.
    values: [f32; 4],
}
20
/// Four integer values recorded from a `DefI` instruction for an integer constant register.
#[derive(Debug, Clone)]
struct DefInt {
    /// The x, y, z, w components, in that order.
    values: [i32; 4],
}
25
/// Boolean value recorded from a `DefB` instruction for a boolean constant register.
#[derive(Debug, Clone)]
struct DefBool {
    /// The literal value of the constant.
    value: bool,
}
30
/// Everything the WGSL emitters need to know about one parsed shader.
///
/// Built by `analyze` and then only read by the `emit_*` functions.
#[derive(Debug, Clone)]
struct Context<'a> {
    /// The parsed shader being decompiled.
    shader: &'a ShaderModel,
    /// Optional constant table supplied by the caller.
    // NOTE(review): not read by any code visible in this part of the file.
    ctab: Option<&'a ConstantTable>,
    /// Non-sampler `Dcl` declarations plus semantics inferred for undeclared inputs.
    decls: BTreeMap<RegisterKey, DeclInfo>,
    /// Texture type per sampler index (declared, or defaulted to 2D).
    sampler_decls: BTreeMap<u16, SamplerTextureType>,
    /// Float constants defined inline by `Def`, keyed by register number.
    def_float: BTreeMap<u16, DefFloat>,
    /// Integer constants defined inline by `DefI`, keyed by register number.
    def_int: BTreeMap<u16, DefInt>,
    /// Boolean constants defined inline by `DefB`, keyed by register number.
    def_bool: BTreeMap<u16, DefBool>,
    /// Input registers read by at least one instruction.
    used_inputs: BTreeSet<RegisterKey>,
    /// Output registers written (or read) by at least one instruction.
    used_outputs: BTreeSet<RegisterKey>,
    /// Temporary-like registers that need a local `var` in the entry point.
    used_temps: BTreeSet<RegisterKey>,
    /// Sampler indices that are declared or referenced.
    used_samplers: BTreeSet<u16>,
    /// Float constant register numbers referenced as sources.
    used_consts: BTreeSet<u16>,
    /// Integer constant register numbers referenced as sources.
    used_int_consts: BTreeSet<u16>,
    /// Boolean constant register numbers referenced as sources.
    used_bool_consts: BTreeSet<u16>,
    /// True when a `Lit` instruction is present, so `sm2_lit` must be emitted.
    uses_lit: bool,
    /// True when a `Dst` instruction is present, so `sm2_dst` must be emitted.
    uses_dst: bool,
}
50
51pub fn decompile_wgsl(data: &[u8], ctab: Option<&ConstantTable>) -> String {
52 match parse_shader(data) {
53 Ok(shader) => {
54 let ctx = analyze(&shader, ctab);
55 emit_wgsl(&ctx)
56 }
57 Err(_) => String::new(),
58 }
59}
60
61fn emit_wgsl(ctx: &Context<'_>) -> String {
62 let mut out = String::new();
63 emit_register_bindings(&mut out, ctx);
64 emit_texture_bindings(&mut out, ctx);
65 emit_def_constants(&mut out, ctx);
66 emit_helpers(&mut out, ctx);
67 emit_structs(&mut out, ctx);
68 emit_main(&mut out, ctx);
69 out
70}
71
72fn emit_register_bindings(out: &mut String, ctx: &Context<'_>) {
73 if !ctx.used_consts.is_empty() {
74 out.push_str("struct FloatRegs {\n");
75 out.push_str(" c: array<vec4<f32>, 256>,\n");
76 out.push_str("};\n");
77 out.push_str("@group(0) @binding(0) var<uniform> float_regs: FloatRegs;\n\n");
78 }
79 if !ctx.used_int_consts.is_empty() {
80 out.push_str("struct IntRegs {\n");
81 out.push_str(" i: array<vec4<i32>, 16>,\n");
82 out.push_str("};\n");
83 out.push_str("@group(0) @binding(1) var<uniform> int_regs: IntRegs;\n\n");
84 }
85 if !ctx.used_bool_consts.is_empty() {
86 out.push_str("struct BoolRegs {\n");
87 out.push_str(" b: array<u32, 16>,\n");
88 out.push_str("};\n");
89 out.push_str("@group(0) @binding(2) var<uniform> bool_regs: BoolRegs;\n\n");
90 }
91}
92
93fn emit_texture_bindings(out: &mut String, ctx: &Context<'_>) {
94 for sampler in &ctx.used_samplers {
95 let tex_binding = (*sampler as u32) * 2;
96 let samp_binding = tex_binding + 1;
97 let tex_ty = wgsl_texture_type(sampler_texture_type(ctx, *sampler));
98 out.push_str(&format!("@group(1) @binding({}) var tex_s{}: {};\n", tex_binding, sampler, tex_ty));
99 out.push_str(&format!("@group(1) @binding({}) var samp_s{}: sampler;\n", samp_binding, sampler));
100 }
101 if !ctx.used_samplers.is_empty() {
102 out.push('\n');
103 }
104}
105
106fn emit_def_constants(out: &mut String, ctx: &Context<'_>) {
107 let mut any = false;
108 for (idx, def) in &ctx.def_float {
109 out.push_str(&format!(
110 "const c{}: vec4<f32> = vec4<f32>({}, {}, {}, {});\n",
111 idx,
112 fmt_f32(def.values[0]),
113 fmt_f32(def.values[1]),
114 fmt_f32(def.values[2]),
115 fmt_f32(def.values[3])
116 ));
117 any = true;
118 }
119 for (idx, def) in &ctx.def_int {
120 out.push_str(&format!(
121 "const i{}: vec4<i32> = vec4<i32>({}, {}, {}, {});\n",
122 idx, def.values[0], def.values[1], def.values[2], def.values[3]
123 ));
124 any = true;
125 }
126 for (idx, def) in &ctx.def_bool {
127 out.push_str(&format!("const b{}: bool = {};\n", idx, if def.value { "true" } else { "false" }));
128 any = true;
129 }
130 if any {
131 out.push('\n');
132 }
133}
134
135fn emit_helpers(out: &mut String, ctx: &Context<'_>) {
136 if ctx.uses_lit {
137 out.push_str("fn sm2_lit(v: vec4<f32>) -> vec4<f32> {\n");
138 out.push_str(" let y = max(v.x, 0.0);\n");
139 out.push_str(" let z = select(0.0, pow(max(v.y, 0.0), v.w), v.x > 0.0 && v.y > 0.0);\n");
140 out.push_str(" return vec4<f32>(1.0, y, z, 1.0);\n");
141 out.push_str("}\n\n");
142 }
143 if ctx.uses_dst {
144 out.push_str("fn sm2_dst(a: vec4<f32>, b: vec4<f32>) -> vec4<f32> {\n");
145 out.push_str(" return vec4<f32>(1.0, a.y * b.y, a.z, b.w);\n");
146 out.push_str("}\n\n");
147 }
148}
149
150fn emit_structs(out: &mut String, ctx: &Context<'_>) {
151 let input_name = input_struct_name(ctx.shader.kind);
152 let output_name = output_struct_name(ctx.shader.kind);
153
154 out.push_str(&format!("struct {} {{\n", input_name));
155 for decl in ctx.decls.values() {
156 let field = input_field_name(decl.reg);
157 let ty = wgsl_input_field_type(ctx, decl.reg);
158 let attr = input_attr(ctx.shader.kind, &decl.semantic);
159 out.push_str(&format!(" {} {}: {},\n", attr, field, ty));
160 }
161 out.push_str("};\n\n");
162
163 out.push_str(&format!("struct {} {{\n", output_name));
164 for reg in &ctx.used_outputs {
165 let field = output_field_name(*reg);
166 let sem = output_semantic(ctx, *reg);
167 let ty = wgsl_output_field_type(*reg);
168 let attr = output_attr(ctx.shader.kind, &sem, *reg);
169 out.push_str(&format!(" {} {}: {},\n", attr, field, ty));
170 }
171 out.push_str("};\n\n");
172}
173
174fn emit_main(out: &mut String, ctx: &Context<'_>) {
175 let stage = match ctx.shader.kind {
176 ShaderKind::Vertex => "vertex",
177 ShaderKind::Pixel => "fragment",
178 };
179 let input_name = input_struct_name(ctx.shader.kind);
180 let output_name = output_struct_name(ctx.shader.kind);
181 out.push_str(&format!("@{}\n", stage));
182 out.push_str(&format!("fn main(input: {}) -> {} {{\n", input_name, output_name));
183 out.push_str(&format!(" var output: {};\n", output_name));
184
185 for reg in &ctx.used_outputs {
186 out.push_str(&format!(" output.{} = {};\n", output_field_name(*reg), zero_value(wgsl_output_field_type(*reg))));
187 }
188 for reg in &ctx.used_temps {
189 out.push_str(&format!(" var {}: {} = {};\n", temp_name(*reg, ctx.shader.kind), temp_type(*reg), zero_value(temp_type(*reg))));
190 }
191 if !ctx.used_outputs.is_empty() || !ctx.used_temps.is_empty() {
192 out.push('\n');
193 }
194
195 let mut indent = 1usize;
196 for inst in &ctx.shader.instructions {
197 emit_instruction(out, ctx, inst, &mut indent);
198 }
199
200 out.push_str(" return output;\n");
201 out.push_str("}\n");
202}
203
204fn emit_instruction(out: &mut String, ctx: &Context<'_>, inst: &Instruction, indent: &mut usize) {
205 match inst.opcode {
206 Opcode::Comment | Opcode::Dcl | Opcode::Def | Opcode::DefI | Opcode::DefB | Opcode::Nop | Opcode::End => {}
207 Opcode::TexKill => {
208 if let Some(reg) = inst.dest_register() {
209 let expr = format!("{}{}", register_base(ctx, reg), wgsl_mask_suffix(inst.dest_write_mask()));
210 let n = mask_len(inst.dest_write_mask());
211 if n == 1 {
212 line(out, *indent, &format!("if ({} < 0.0) {{ discard; }}", expr));
213 } else {
214 line(out, *indent, &format!("if (any({} < {})) {{ discard; }}", expr, zero_vector(n)));
215 }
216 }
217 }
218 Opcode::If => {
219 let cond = source_expr(ctx, inst, 0, 1);
220 line(out, *indent, &format!("if ({}) {{", scalar_bool_expr(cond)));
221 *indent += 1;
222 }
223 Opcode::IfC => {
224 let a = source_expr(ctx, inst, 0, 4);
225 let b = source_expr(ctx, inst, 1, 4);
226 line(out, *indent, &format!("if ({}) {{", compare_all_expr(a, cmp_op(inst.comparison()), b, 4)));
227 *indent += 1;
228 }
229 Opcode::Else => {
230 if *indent > 0 { *indent -= 1; }
231 line(out, *indent, "} else {");
232 *indent += 1;
233 }
234 Opcode::EndIf => {
235 if *indent > 0 { *indent -= 1; }
236 line(out, *indent, "}");
237 }
238 Opcode::Break => line(out, *indent, "break;"),
239 Opcode::BreakC => {
240 let a = source_expr(ctx, inst, 0, 4);
241 let b = source_expr(ctx, inst, 1, 4);
242 line(out, *indent, &format!("if ({}) {{ break; }}", compare_all_expr(a, cmp_op(inst.comparison()), b, 4)));
243 }
244 Opcode::Rep => {
245 let n = source_expr(ctx, inst, 0, 1);
246 line(out, *indent, &format!("for (var _rep{}: i32 = 0; _rep{} < i32({}); _rep{} = _rep{} + 1) {{", inst.offset, inst.offset, n, inst.offset, inst.offset));
247 *indent += 1;
248 }
249 Opcode::Loop => {
250 line(out, *indent, &format!("for (var _loop{}: i32 = 0; ; _loop{} = _loop{} + 1) {{", inst.offset, inst.offset, inst.offset));
251 *indent += 1;
252 }
253 Opcode::EndRep | Opcode::EndLoop => {
254 if *indent > 0 { *indent -= 1; }
255 line(out, *indent, "}");
256 }
257 Opcode::Ret => line(out, *indent, "return output;"),
258 _ if inst.opcode.has_destination() => {
259 if let Some((dst, rhs)) = assignment_expr(ctx, inst) {
260 line(out, *indent, &format!("{} = {};", dst, rhs));
261 }
262 }
263 _ => {}
264 }
265}
266
267fn assignment_expr(ctx: &Context<'_>, inst: &Instruction) -> Option<(String, String)> {
268 let reg = inst.dest_register()?;
269 let mask = inst.dest_write_mask();
270 let dst_count = mask_len(mask);
271 let dst = format!("{}{}", register_base(ctx, reg), wgsl_mask_suffix(mask));
272
273 let (raw_rhs, raw_width) = match inst.opcode {
274 Opcode::Mov | Opcode::MovA => (source_expr(ctx, inst, 1, dst_count), dst_count),
275 Opcode::Add => (bin(ctx, inst, dst_count, "+"), dst_count),
276 Opcode::Sub => (bin(ctx, inst, dst_count, "-"), dst_count),
277 Opcode::Mul => (bin(ctx, inst, dst_count, "*"), dst_count),
278 Opcode::Mad => (format!("({} * {} + {})", source_expr(ctx, inst, 1, dst_count), source_expr(ctx, inst, 2, dst_count), source_expr(ctx, inst, 3, dst_count)), dst_count),
279 Opcode::Rcp => (format!("(1.0 / {})", source_expr(ctx, inst, 1, 1)), 1),
280 Opcode::Rsq => (format!("inverseSqrt({})", source_expr(ctx, inst, 1, 1)), 1),
281 Opcode::Dp3 => (format!("dot({}, {})", source_expr(ctx, inst, 1, 3), source_expr(ctx, inst, 2, 3)), 1),
282 Opcode::Dp4 => (format!("dot({}, {})", source_expr(ctx, inst, 1, 4), source_expr(ctx, inst, 2, 4)), 1),
283 Opcode::Min => (format!("min({}, {})", source_expr(ctx, inst, 1, dst_count), source_expr(ctx, inst, 2, dst_count)), dst_count),
284 Opcode::Max => (format!("max({}, {})", source_expr(ctx, inst, 1, dst_count), source_expr(ctx, inst, 2, dst_count)), dst_count),
285 Opcode::Slt => (select_float(ctx, inst, dst_count, "<"), dst_count),
286 Opcode::Sge => (select_float(ctx, inst, dst_count, ">="), dst_count),
287 Opcode::Exp | Opcode::ExpP => (format!("exp2({})", source_expr(ctx, inst, 1, 1)), 1),
288 Opcode::Log | Opcode::LogP => (format!("log2({})", source_expr(ctx, inst, 1, 1)), 1),
289 Opcode::Lit => (format!("sm2_lit({})", source_expr(ctx, inst, 1, 4)), 4),
290 Opcode::Dst => (format!("sm2_dst({}, {})", source_expr(ctx, inst, 1, 4), source_expr(ctx, inst, 2, 4)), 4),
291 Opcode::Lrp => (format!("mix({}, {}, {})", source_expr(ctx, inst, 3, dst_count), source_expr(ctx, inst, 2, dst_count), source_expr(ctx, inst, 1, dst_count)), dst_count),
292 Opcode::Frc => (format!("fract({})", source_expr(ctx, inst, 1, dst_count)), dst_count),
293 Opcode::Pow => (format!("pow({}, {})", source_expr(ctx, inst, 1, 1), source_expr(ctx, inst, 2, 1)), 1),
294 Opcode::Crs => (format!("cross({}, {})", source_expr(ctx, inst, 1, 3), source_expr(ctx, inst, 2, 3)), 3),
295 Opcode::Sgn => (format!("sign({})", source_expr(ctx, inst, 1, dst_count)), dst_count),
296 Opcode::Abs => (format!("abs({})", source_expr(ctx, inst, 1, dst_count)), dst_count),
297 Opcode::Nrm => (format!("normalize({})", source_expr(ctx, inst, 1, 3)), 3),
298 Opcode::SinCos => {
299 let width = if dst_count <= 1 { 1 } else { 2 };
300 (sincos_expr(ctx, inst, dst_count), width)
301 }
302 Opcode::Cmp => (select_expr(ctx, inst, dst_count, ">="), dst_count),
303 Opcode::Cnd => (select_expr(ctx, inst, dst_count, ">"), dst_count),
304 Opcode::Dp2Add => (format!("(dot({}, {}) + {})", source_expr(ctx, inst, 1, 2), source_expr(ctx, inst, 2, 2), source_expr(ctx, inst, 3, 1)), 1),
305 Opcode::M4x4 => (matrix_mul_expr(ctx, inst, 4, 4), 4),
306 Opcode::M4x3 => (matrix_mul_expr(ctx, inst, 4, 3), 3),
307 Opcode::M3x4 => (matrix_mul_expr(ctx, inst, 3, 4), 4),
308 Opcode::M3x3 => (matrix_mul_expr(ctx, inst, 3, 3), 3),
309 Opcode::M3x2 => (matrix_mul_expr(ctx, inst, 3, 2), 2),
310 Opcode::Tex | Opcode::TexLdl | Opcode::TexLdd => (texture_expr(ctx, inst), 4),
311 Opcode::TexCoord => (source_expr(ctx, inst, 1, dst_count), dst_count),
312 Opcode::Dsx => (format!("dpdx({})", source_expr(ctx, inst, 1, dst_count)), dst_count),
313 Opcode::Dsy => (format!("dpdy({})", source_expr(ctx, inst, 1, dst_count)), dst_count),
314 _ => return None,
315 };
316
317 let mut rhs = coerce_expr_width(raw_rhs, raw_width, dst_count);
318 rhs = apply_result_modifier(rhs, inst.dest_modifier(), dst_count);
319 Some((dst, rhs))
320}
321fn bin(ctx: &Context<'_>, inst: &Instruction, n: usize, op: &str) -> String {
322 format!("({} {} {})", source_expr(ctx, inst, 1, n), op, source_expr(ctx, inst, 2, n))
323}
324
325fn select_float(ctx: &Context<'_>, inst: &Instruction, n: usize, op: &str) -> String {
326 let a = source_expr(ctx, inst, 1, n);
327 let b = source_expr(ctx, inst, 2, n);
328 let zero = zero_vector(n);
329 let one = one_vector(n);
330 format!("select({}, {}, {} {} {})", zero, one, a, op, b)
331}
332
333fn select_expr(ctx: &Context<'_>, inst: &Instruction, n: usize, op: &str) -> String {
334 let a = source_expr(ctx, inst, 1, n);
335 let yes = source_expr(ctx, inst, 2, n);
336 let no = source_expr(ctx, inst, 3, n);
337 let pivot = if op == ">" { half_vector(n) } else { zero_vector(n) };
338 format!("select({}, {}, {} {} {})", no, yes, a, op, pivot)
339}
340
341fn sincos_expr(ctx: &Context<'_>, inst: &Instruction, n: usize) -> String {
342 if n <= 1 {
343 format!("cos({})", source_expr(ctx, inst, 1, 1))
344 } else {
345 format!("vec2<f32>(cos({}), sin({}))", source_expr(ctx, inst, 1, 1), source_expr(ctx, inst, 1, 1))
346 }
347}
348
349fn matrix_mul_expr(ctx: &Context<'_>, inst: &Instruction, vec_len: usize, out_len: usize) -> String {
350 let v = source_expr(ctx, inst, 1, vec_len);
351 let mut rows = Vec::new();
352 if let Some(reg) = inst.source_register(2) {
353 if reg.ty == RegisterType::Const {
354 for i in 0..out_len {
355 rows.push(format!("dot({}, {})", v, const_row_expr(ctx, reg.number + i as u16)));
356 }
357 } else {
358 for _ in 0..out_len {
359 rows.push(format!("dot({}, {})", v, source_expr(ctx, inst, 2, vec_len)));
360 }
361 }
362 }
363 vector_constructor(&rows)
364}
365
366fn texture_expr(ctx: &Context<'_>, inst: &Instruction) -> String {
367 let sampler = inst.source_register(2).map(|r| r.number).unwrap_or(0);
368 let sampler_ty = sampler_texture_type(ctx, sampler);
369 let dim = sampler_ty.hlsl_dim();
370 let tex = format!("tex_s{}", sampler);
371 let samp = format!("samp_s{}", sampler);
372 let controls = inst.texld_controls();
373 match inst.opcode {
374 Opcode::TexLdl => {
375 let coord = source_expr(ctx, inst, 1, dim);
376 let lod = source_component_expr(ctx, inst, 1, 3);
377 format!("textureSampleLevel({}, {}, {}, {})", tex, samp, coord, lod)
378 }
379 Opcode::TexLdd => {
380 let coord = source_expr(ctx, inst, 1, dim);
381 let ddx = source_expr(ctx, inst, 3, dim);
382 let ddy = source_expr(ctx, inst, 4, dim);
383 format!("textureSampleGrad({}, {}, {}, {}, {})", tex, samp, coord, ddx, ddy)
384 }
385 Opcode::Tex if controls == 1 => {
386 let coord = source_expr(ctx, inst, 1, dim);
387 let w = source_component_expr(ctx, inst, 1, 3);
388 format!("textureSample({}, {}, ({} / {}))", tex, samp, coord, w)
389 }
390 Opcode::Tex if controls == 2 => {
391 let coord = source_expr(ctx, inst, 1, dim);
392 let bias = source_component_expr(ctx, inst, 1, 3);
393 format!("textureSampleBias({}, {}, {}, {})", tex, samp, coord, bias)
394 }
395 _ => {
396 let coord = source_expr(ctx, inst, 1, dim);
397 format!("textureSample({}, {}, {})", tex, samp, coord)
398 }
399 }
400}
401fn source_expr(ctx: &Context<'_>, inst: &Instruction, param_index: usize, count: usize) -> String {
402 let Some(reg) = inst.source_register(param_index) else { return zero_vector(count); };
403 let count = count.clamp(1, 4);
404 let mut expr = register_base(ctx, reg);
405 if !register_is_scalar_source(ctx, reg) {
406 let swz = inst.source_swizzle(param_index);
407 expr.push_str(&wgsl_source_swizzle_suffix(swz, count));
408 }
409 apply_source_modifier(expr, inst.source_modifier(param_index), count)
410}
411
412fn source_component_expr(ctx: &Context<'_>, inst: &Instruction, param_index: usize, component: usize) -> String {
413 let Some(reg) = inst.source_register(param_index) else { return "0.0".to_string(); };
414 let mut expr = register_base(ctx, reg);
415 if !register_is_scalar_source(ctx, reg) {
416 let swz = inst.source_swizzle(param_index);
417 let comp = swz[component.clamp(0, 3)];
418 expr.push_str(component_suffix(comp));
419 }
420 apply_source_modifier(expr, inst.source_modifier(param_index), 1)
421}
422
423fn register_is_scalar_source(ctx: &Context<'_>, reg: RegisterKey) -> bool {
424 match reg.ty {
425 RegisterType::ConstBool | RegisterType::Loop | RegisterType::Label => true,
426 RegisterType::MiscType if ctx.shader.kind == ShaderKind::Pixel && reg.number == 1 => true,
427 _ => false,
428 }
429}
430fn register_base(ctx: &Context<'_>, reg: RegisterKey) -> String {
431 match reg.ty {
432 RegisterType::Temp | RegisterType::TempFloat16 | RegisterType::Predicate => temp_name(reg, ctx.shader.kind),
433 RegisterType::Texture if ctx.shader.kind == ShaderKind::Vertex => temp_name(reg, ctx.shader.kind),
434 RegisterType::Texture | RegisterType::Input | RegisterType::MiscType => format!("input.{}", input_field_name(reg)),
435 RegisterType::Const => const_row_expr(ctx, reg.number),
436 RegisterType::ConstInt => int_const_expr(ctx, reg.number),
437 RegisterType::ConstBool => bool_const_expr(ctx, reg.number),
438 RegisterType::Sampler => format!("samp_s{}", reg.number),
439 RegisterType::ColorOut | RegisterType::DepthOut | RegisterType::RastOut | RegisterType::AttrOut | RegisterType::Output => format!("output.{}", output_field_name(reg)),
440 RegisterType::Loop => "_loop".to_string(),
441 RegisterType::Label => format!("label{}", reg.number),
442 _ => format!("u{}", reg.number),
443 }
444}
445
446fn const_row_expr(ctx: &Context<'_>, index: u16) -> String {
447 if ctx.def_float.contains_key(&index) {
448 return format!("c{}", index);
449 }
450 format!("float_regs.c[{}]", index)
451}
452
453fn int_const_expr(ctx: &Context<'_>, index: u16) -> String {
454 if ctx.def_int.contains_key(&index) {
455 return format!("i{}", index);
456 }
457 format!("int_regs.i[{}]", index)
458}
459
460fn bool_const_expr(ctx: &Context<'_>, index: u16) -> String {
461 if ctx.def_bool.contains_key(&index) {
462 return format!("b{}", index);
463 }
464 format!("(bool_regs.b[{}] != 0u)", index)
465}
466
467fn analyze<'a>(shader: &'a ShaderModel, ctab: Option<&'a ConstantTable>) -> Context<'a> {
468 let mut ctx = Context {
469 shader,
470 ctab,
471 decls: BTreeMap::new(),
472 sampler_decls: BTreeMap::new(),
473 def_float: BTreeMap::new(),
474 def_int: BTreeMap::new(),
475 def_bool: BTreeMap::new(),
476 used_inputs: BTreeSet::new(),
477 used_outputs: BTreeSet::new(),
478 used_temps: BTreeSet::new(),
479 used_samplers: BTreeSet::new(),
480 used_consts: BTreeSet::new(),
481 used_int_consts: BTreeSet::new(),
482 used_bool_consts: BTreeSet::new(),
483 uses_lit: false,
484 uses_dst: false,
485 };
486
487 for inst in &shader.instructions {
488 match inst.opcode {
489 Opcode::Dcl => {
490 if let Some(reg) = inst.dest_register() {
491 if reg.ty == RegisterType::Sampler {
492 ctx.sampler_decls.insert(reg.number, inst.decl_sampler_type());
493 ctx.used_samplers.insert(reg.number);
494 } else {
495 let semantic = semantic_from_decl(shader.kind, inst);
496 ctx.decls.insert(reg, DeclInfo { reg, semantic });
497 }
498 }
499 }
500 Opcode::Def => {
501 if let Some(reg) = inst.dest_register() {
502 if reg.ty == RegisterType::Const {
503 ctx.def_float.insert(reg.number, DefFloat {
504 values: [inst.get_float_param(1), inst.get_float_param(2), inst.get_float_param(3), inst.get_float_param(4)],
505 });
506 }
507 }
508 }
509 Opcode::DefI => {
510 if let Some(reg) = inst.dest_register() {
511 if reg.ty == RegisterType::ConstInt {
512 ctx.def_int.insert(reg.number, DefInt {
513 values: [inst.get_int_param(1), inst.get_int_param(2), inst.get_int_param(3), inst.get_int_param(4)],
514 });
515 }
516 }
517 }
518 Opcode::DefB => {
519 if let Some(reg) = inst.dest_register() {
520 if reg.ty == RegisterType::ConstBool {
521 ctx.def_bool.insert(reg.number, DefBool { value: inst.get_int_param(1) != 0 });
522 }
523 }
524 }
525 Opcode::Lit => ctx.uses_lit = true,
526 Opcode::Dst => ctx.uses_dst = true,
527 _ => {}
528 }
529
530 if inst.opcode == Opcode::End || inst.opcode == Opcode::Comment {
531 continue;
532 }
533
534 if inst.opcode == Opcode::TexKill {
535 if let Some(reg) = inst.dest_register() {
536 classify_source_register(&mut ctx, reg);
537 }
538 continue;
539 }
540
541 if inst.opcode.has_destination() {
542 if let Some(reg) = inst.dest_register() {
543 classify_dest_register(&mut ctx, reg);
544 }
545 }
546
547 let first_src = if inst.opcode.has_destination() { 1 } else { 0 };
548 for pi in first_src..inst.params.len() {
549 if matches!(inst.opcode, Opcode::Dcl | Opcode::Def | Opcode::DefI | Opcode::DefB) {
550 continue;
551 }
552 if let Some(reg) = inst.source_register(pi) {
553 classify_source_register(&mut ctx, reg);
554 }
555 }
556 }
557
558 infer_missing_decls(&mut ctx);
559 ctx
560}
561
562fn classify_dest_register(ctx: &mut Context<'_>, reg: RegisterKey) {
563 match reg.ty {
564 RegisterType::Temp | RegisterType::TempFloat16 | RegisterType::Texture | RegisterType::Predicate => {
565 if !(ctx.shader.kind == ShaderKind::Pixel && reg.ty == RegisterType::Texture) {
566 ctx.used_temps.insert(reg);
567 }
568 }
569 RegisterType::RastOut | RegisterType::AttrOut | RegisterType::Output | RegisterType::ColorOut | RegisterType::DepthOut => {
570 ctx.used_outputs.insert(reg);
571 }
572 _ => {}
573 }
574}
575
576fn classify_source_register(ctx: &mut Context<'_>, reg: RegisterKey) {
577 match reg.ty {
578 RegisterType::Input | RegisterType::MiscType => {
579 ctx.used_inputs.insert(reg);
580 }
581 RegisterType::Texture => {
582 if ctx.shader.kind == ShaderKind::Pixel {
583 ctx.used_inputs.insert(reg);
584 } else {
585 ctx.used_temps.insert(reg);
586 }
587 }
588 RegisterType::Temp | RegisterType::TempFloat16 | RegisterType::Predicate => {
589 ctx.used_temps.insert(reg);
590 }
591 RegisterType::Const => {
592 ctx.used_consts.insert(reg.number);
593 }
594 RegisterType::ConstInt => {
595 ctx.used_int_consts.insert(reg.number);
596 }
597 RegisterType::ConstBool => {
598 ctx.used_bool_consts.insert(reg.number);
599 }
600 RegisterType::Sampler => {
601 ctx.used_samplers.insert(reg.number);
602 }
603 RegisterType::RastOut | RegisterType::AttrOut | RegisterType::Output | RegisterType::ColorOut | RegisterType::DepthOut => {
604 ctx.used_outputs.insert(reg);
605 }
606 _ => {}
607 }
608}
609
610fn infer_missing_decls(ctx: &mut Context<'_>) {
611 for reg in ctx.used_inputs.clone() {
612 if ctx.decls.contains_key(®) {
613 continue;
614 }
615 let semantic = inferred_input_semantic(ctx.shader.kind, reg);
616 ctx.decls.insert(reg, DeclInfo { reg, semantic });
617 }
618 for sampler in ctx.used_samplers.clone() {
619 ctx.sampler_decls.entry(sampler).or_insert(SamplerTextureType::TwoD);
620 }
621 if ctx.shader.kind == ShaderKind::Pixel && ctx.used_outputs.is_empty() {
622 ctx.used_outputs.insert(RegisterKey { ty: RegisterType::ColorOut, number: 0 });
623 }
624}
625
626fn semantic_from_decl(kind: ShaderKind, inst: &Instruction) -> String {
627 if let Some(reg) = inst.dest_register() {
628 if reg.ty == RegisterType::MiscType {
629 return match reg.number {
630 0 => "VPOS".to_string(),
631 1 => "VFACE".to_string(),
632 _ => format!("TEXCOORD{}", reg.number),
633 };
634 }
635 if kind == ShaderKind::Pixel {
636 match reg.ty {
637 RegisterType::Texture => return format!("TEXCOORD{}", reg.number),
638 RegisterType::Input => return format!("COLOR{}", reg.number),
639 _ => {}
640 }
641 }
642 }
643 let usage = inst.decl_usage();
644 let index = inst.decl_index();
645 let prefix = usage.semantic_prefix();
646 if index == 0 && !matches!(usage, DeclUsage::TexCoord | DeclUsage::Color | DeclUsage::Position) {
647 prefix.to_string()
648 } else {
649 format!("{}{}", prefix, index)
650 }
651}
652
653fn inferred_input_semantic(kind: ShaderKind, reg: RegisterKey) -> String {
654 match (kind, reg.ty) {
655 (ShaderKind::Pixel, RegisterType::Texture) => format!("TEXCOORD{}", reg.number),
656 (ShaderKind::Pixel, RegisterType::Input) => format!("COLOR{}", reg.number),
657 (_, RegisterType::Input) => match reg.number {
658 0 => "POSITION0".to_string(),
659 1 => "NORMAL0".to_string(),
660 n => format!("TEXCOORD{}", n.saturating_sub(2)),
661 },
662 (_, RegisterType::MiscType) => match reg.number {
663 0 => "VPOS".to_string(),
664 1 => "VFACE".to_string(),
665 _ => format!("TEXCOORD{}", reg.number),
666 },
667 _ => format!("TEXCOORD{}", reg.number),
668 }
669}
670
671fn output_semantic(ctx: &Context<'_>, reg: RegisterKey) -> String {
672 if let Some(decl) = ctx.decls.get(®) {
673 return decl.semantic.clone();
674 }
675 match reg.ty {
676 RegisterType::ColorOut => format!("COLOR{}", reg.number),
677 RegisterType::DepthOut => "DEPTH".to_string(),
678 RegisterType::RastOut => match reg.number {
679 0 => "POSITION0".to_string(),
680 1 => "FOG".to_string(),
681 2 => "PSIZE".to_string(),
682 n => format!("TEXCOORD{}", n),
683 },
684 RegisterType::AttrOut => format!("COLOR{}", reg.number),
685 RegisterType::Output => format!("TEXCOORD{}", reg.number),
686 _ => format!("TEXCOORD{}", reg.number),
687 }
688}
689
690fn input_attr(kind: ShaderKind, semantic: &str) -> String {
691 let upper = semantic.to_ascii_uppercase();
692 if kind == ShaderKind::Pixel && (upper == "VPOS" || upper == "POSITION" || upper == "POSITION0") {
693 return "@builtin(position)".to_string();
694 }
695 if kind == ShaderKind::Pixel && upper == "VFACE" {
696 return "@builtin(front_facing)".to_string();
697 }
698 format!("@location({})", semantic_location(semantic))
699}
700fn output_attr(kind: ShaderKind, semantic: &str, reg: RegisterKey) -> String {
701 if kind == ShaderKind::Vertex && is_position_semantic(semantic) {
702 "@builtin(position)".to_string()
703 } else if reg.ty == RegisterType::DepthOut || semantic == "DEPTH" {
704 "@builtin(frag_depth)".to_string()
705 } else {
706 format!("@location({})", semantic_location(semantic))
707 }
708}
709
710fn semantic_location(semantic: &str) -> u32 {
711 let upper = semantic.to_ascii_uppercase();
712 if let Some(n) = parse_semantic_index(&upper, "POSITION") {
713 return n;
714 }
715 if let Some(n) = parse_semantic_index(&upper, "NORMAL") {
716 return 1 + n;
717 }
718 if let Some(n) = parse_semantic_index(&upper, "COLOR") {
719 return 2 + n;
720 }
721 if let Some(n) = parse_semantic_index(&upper, "TEXCOORD") {
722 return 4 + n;
723 }
724 if let Some(n) = parse_semantic_index(&upper, "BLENDWEIGHT") {
725 return 12 + n;
726 }
727 if let Some(n) = parse_semantic_index(&upper, "BLENDINDICES") {
728 return 14 + n;
729 }
730 if let Some(n) = parse_semantic_index(&upper, "TANGENT") {
731 return 16 + n;
732 }
733 if let Some(n) = parse_semantic_index(&upper, "BINORMAL") {
734 return 18 + n;
735 }
736 if upper == "FOG" {
737 return 20;
738 }
739 if upper == "PSIZE" {
740 return 21;
741 }
742 31
743}
744
/// Extracts the numeric index that follows `prefix` in a semantic string.
///
/// Returns `Some(0)` when nothing follows the prefix, and `None` when `s`
/// does not start with `prefix` or the remainder is not a valid `u32`.
fn parse_semantic_index(s: &str, prefix: &str) -> Option<u32> {
    match s.strip_prefix(prefix)? {
        "" => Some(0),
        digits => digits.parse().ok(),
    }
}
756
/// Reports whether a semantic names the position (`POSITION`, `POSITIONT`,
/// each with or without a trailing `0`), ignoring ASCII case.
fn is_position_semantic(s: &str) -> bool {
    matches!(
        s.to_ascii_uppercase().as_str(),
        "POSITION" | "POSITION0" | "POSITIONT" | "POSITIONT0"
    )
}
761
762fn wgsl_input_field_type(ctx: &Context<'_>, reg: RegisterKey) -> &'static str {
763 match reg.ty {
764 RegisterType::MiscType if reg.number == 1 => "bool",
765 _ => {
766 if let Some(decl) = ctx.decls.get(®) {
767 if decl.semantic.starts_with("TEXCOORD") || decl.semantic.starts_with("COLOR") {
768 return "vec4<f32>";
769 }
770 }
771 "vec4<f32>"
772 }
773 }
774}
775
776fn wgsl_output_field_type(reg: RegisterKey) -> &'static str {
777 match reg.ty {
778 RegisterType::DepthOut => "f32",
779 RegisterType::RastOut if reg.number == 1 || reg.number == 2 => "f32",
780 _ => "vec4<f32>",
781 }
782}
783
784fn input_struct_name(kind: ShaderKind) -> &'static str {
785 match kind {
786 ShaderKind::Vertex => "VSInput",
787 ShaderKind::Pixel => "PSInput",
788 }
789}
790
791fn output_struct_name(kind: ShaderKind) -> &'static str {
792 match kind {
793 ShaderKind::Vertex => "VSOutput",
794 ShaderKind::Pixel => "PSOutput",
795 }
796}
797
798fn input_field_name(reg: RegisterKey) -> String {
799 match reg.ty {
800 RegisterType::Input => format!("v{}", reg.number),
801 RegisterType::Texture => format!("t{}", reg.number),
802 RegisterType::MiscType => match reg.number {
803 0 => "vPos".to_string(),
804 1 => "vFace".to_string(),
805 _ => format!("vMisc{}", reg.number),
806 },
807 _ => format!("in{}", reg.number),
808 }
809}
810
811fn output_field_name(reg: RegisterKey) -> String {
812 match reg.ty {
813 RegisterType::ColorOut => format!("oC{}", reg.number),
814 RegisterType::DepthOut => "oDepth".to_string(),
815 RegisterType::RastOut => match reg.number {
816 0 => "oPos".to_string(),
817 1 => "oFog".to_string(),
818 2 => "oPts".to_string(),
819 _ => format!("o{}", reg.number),
820 },
821 RegisterType::AttrOut => format!("oD{}", reg.number),
822 RegisterType::Output => format!("o{}", reg.number),
823 _ => format!("out{}", reg.number),
824 }
825}
826
827fn temp_name(reg: RegisterKey, kind: ShaderKind) -> String {
828 match reg.ty {
829 RegisterType::Texture if kind == ShaderKind::Vertex => format!("a{}", reg.number),
830 RegisterType::Predicate => format!("p{}", reg.number),
831 RegisterType::TempFloat16 => format!("h{}", reg.number),
832 _ => format!("r{}", reg.number),
833 }
834}
835
836fn temp_type(reg: RegisterKey) -> &'static str {
837 match reg.ty {
838 RegisterType::Predicate => "vec4<bool>",
839 RegisterType::Texture => "vec4<f32>",
840 RegisterType::TempFloat16 => "vec4<f32>",
841 _ => "vec4<f32>",
842 }
843}
844
/// Returns the WGSL zero literal for a type name.
///
/// Any type name that is not listed falls back to the `vec4<f32>` zero.
fn zero_value(ty: &str) -> &'static str {
    const ZEROS: [(&str, &str); 5] = [
        ("f32", "0.0"),
        ("bool", "false"),
        ("vec2<f32>", "vec2<f32>(0.0)"),
        ("vec3<f32>", "vec3<f32>(0.0)"),
        ("vec4<bool>", "vec4<bool>(false)"),
    ];
    ZEROS
        .iter()
        .find(|(name, _)| *name == ty)
        .map_or("vec4<f32>(0.0)", |&(_, zero)| zero)
}
855
/// WGSL expression for an all-zero float value with `n` lanes.
///
/// `n == 1` yields a scalar; any `n` outside 1..=3 yields a `vec4<f32>`.
fn zero_vector(n: usize) -> String {
    match n {
        1 => String::from("0.0"),
        2 | 3 => format!("vec{}<f32>(0.0)", n),
        _ => String::from("vec4<f32>(0.0)"),
    }
}
864
/// WGSL expression for an all-one float value with `n` lanes.
///
/// `n == 1` yields a scalar; any `n` outside 1..=3 yields a `vec4<f32>`.
fn one_vector(n: usize) -> String {
    match n {
        1 => String::from("1.0"),
        2 | 3 => format!("vec{}<f32>(1.0)", n),
        _ => String::from("vec4<f32>(1.0)"),
    }
}
873
/// WGSL expression for a float value of 0.5 in each of `n` lanes.
///
/// `n == 1` yields a scalar; any `n` outside 1..=3 yields a `vec4<f32>`.
fn half_vector(n: usize) -> String {
    match n {
        1 => String::from("0.5"),
        2 | 3 => format!("vec{}<f32>(0.5)", n),
        _ => String::from("vec4<f32>(0.5)"),
    }
}
882
/// Wraps up to four lane expressions in the matching WGSL constructor.
///
/// No values yields `"0.0"`, a single value is returned as is, and any
/// values past the fourth are ignored.
fn vector_constructor(values: &[String]) -> String {
    match values {
        [] => "0.0".to_string(),
        [only] => only.clone(),
        [x, y] => format!("vec2<f32>({}, {})", x, y),
        [x, y, z] => format!("vec3<f32>({}, {}, {})", x, y, z),
        [x, y, z, w, ..] => format!("vec4<f32>({}, {}, {}, {})", x, y, z, w),
    }
}
892
/// Converts a destination write mask into a WGSL member suffix.
///
/// Bit 0 selects `x`, bit 1 `y`, bit 2 `z`, bit 3 `w`. A mask of zero is
/// treated as all four lanes, and the full mask produces no suffix.
fn wgsl_mask_suffix(mask: u8) -> String {
    let effective = if mask == 0 { 0xf } else { mask };
    if effective == 0xf {
        return String::new();
    }
    let lanes: String = "xyzw"
        .chars()
        .enumerate()
        .filter(|&(bit, _)| effective & (1 << bit) != 0)
        .map(|(_, lane)| lane)
        .collect();
    format!(".{}", lanes)
}
907
/// Converts the first `count` entries of a source swizzle into a WGSL suffix.
///
/// `count` is clamped to `1..=4`. The identity swizzle at full width yields
/// an empty suffix; otherwise the result is `.` followed by one component
/// letter per selected entry.
///
/// # Panics
///
/// Panics if any of the first `count` swizzle entries is greater than 3.
fn wgsl_source_swizzle_suffix(swizzle: [usize; 4], count: usize) -> String {
    static LANES: [char; 4] = ['x', 'y', 'z', 'w'];
    let lanes = count.clamp(1, 4);
    if lanes == 4 && swizzle == [0, 1, 2, 3] {
        return String::new();
    }
    let mut suffix = String::with_capacity(lanes + 1);
    suffix.push('.');
    suffix.extend(swizzle.iter().take(lanes).map(|&lane| LANES[lane]));
    suffix
}
921
/// Returns the WGSL single-component accessor for `component` (0 = `.x`).
///
/// Indices outside `0..=3` fall back to `.x`.
fn component_suffix(component: usize) -> &'static str {
    static SUFFIXES: [&str; 4] = [".x", ".y", ".z", ".w"];
    SUFFIXES.get(component).copied().unwrap_or(".x")
}
931
/// Adapts a float expression of `src_width` components to `dst_width`
/// components.
///
/// Both widths are clamped to `1..=4`. The conversion rules are:
///
/// * equal widths: `expr` is returned unchanged;
/// * scalar source: the value is splatted with a `vecN<f32>(expr)` constructor;
/// * wider source: the leading components are selected (`.x`, `.xy`, `.xyz`);
/// * narrower vector source: the expression is extended with one `0.0` per
///   missing component, so the constructor always receives exactly
///   `dst_width` components (e.g. a `vec2` widened to `vec4` becomes
///   `vec4<f32>(expr, 0.0, 0.0)`).
fn coerce_expr_width(expr: String, src_width: usize, dst_width: usize) -> String {
    let src_width = src_width.clamp(1, 4);
    let dst_width = dst_width.clamp(1, 4);
    if src_width == dst_width {
        return expr;
    }
    if src_width == 1 {
        // Single-argument vector constructors replicate the scalar.
        return format!("vec{}<f32>({})", dst_width, expr);
    }
    if src_width > dst_width {
        // dst_width is at most 3 here because src_width is at most 4.
        let suffix = match dst_width {
            1 => ".x",
            2 => ".xy",
            _ => ".xyz",
        };
        return format!("({}){}", expr, suffix);
    }
    // A WGSL vector constructor must be given exactly as many components as
    // the target type has, so pad with one zero per missing lane.
    let padding = ", 0.0".repeat(dst_width - src_width);
    format!("vec{}<f32>({}{})", dst_width, expr, padding)
}
962fn apply_source_modifier(expr: String, modifier: SourceModifier, count: usize) -> String {
963 match modifier {
964 SourceModifier::None => expr,
965 SourceModifier::Negate => format!("-({})", expr),
966 SourceModifier::Bias => format!("(({}) - {})", expr, half_vector(count)),
967 SourceModifier::BiasAndNegate => format!("-(({}) - {})", expr, half_vector(count)),
968 SourceModifier::Sign => format!("((({}) - {}) * {})", expr, half_vector(count), one_vector(count)),
969 SourceModifier::SignAndNegate => format!("-((({}) - {}) * {})", expr, half_vector(count), one_vector(count)),
970 SourceModifier::Complement => format!("({} - ({}))", one_vector(count), expr),
971 SourceModifier::X2 => format!("(({}) * 2.0)", expr),
972 SourceModifier::X2AndNegate => format!("-(({}) * 2.0)", expr),
973 SourceModifier::DivideByZ => format!("(({}) / ({}).z)", expr, expr),
974 SourceModifier::DivideByW => format!("(({}) / ({}).w)", expr, expr),
975 SourceModifier::Abs => format!("abs({})", expr),
976 SourceModifier::AbsAndNegate => format!("-abs({})", expr),
977 SourceModifier::Not => format!("!({})", expr),
978 SourceModifier::Unknown(_) => expr,
979 }
980}
981
982fn apply_result_modifier(expr: String, modifier: ResultModifier, count: usize) -> String {
983 if modifier.saturate {
984 format!("clamp({}, {}, {})", expr, zero_vector(count), one_vector(count))
985 } else {
986 expr
987 }
988}
989
/// Turns a scalar float expression into a boolean one by comparing it
/// against `0.0`.
fn scalar_bool_expr(expr: String) -> String {
    ["(", expr.as_str(), " != 0.0)"].concat()
}
993
/// Appends `s` to `out` as one line: `indent` copies of the indentation
/// unit, then the text, then a newline.
fn line(out: &mut String, indent: usize, s: &str) {
    out.extend(std::iter::repeat(" ").take(indent));
    out.push_str(s);
    out.push('\n');
}
1001
/// Builds the comparison `a op b`, wrapping it in `all(...)` when the
/// operands are vectors (`width > 1`) so the result is a single boolean.
fn compare_all_expr(a: String, op: &str, b: String, width: usize) -> String {
    let comparison = format!("{} {} {}", a, op, b);
    if width > 1 {
        format!("all({})", comparison)
    } else {
        comparison
    }
}
1009
/// Maps a numeric comparison code to its WGSL operator.
///
/// Codes 1 through 6 map to `>`, `==`, `>=`, `<`, `!=` and `<=`; every other
/// code (including 0) maps to `!=`.
fn cmp_op(code: u8) -> &'static str {
    // Indexed directly by the code; slot 0 holds the fallback operator.
    static OPERATORS: [&str; 7] = ["!=", ">", "==", ">=", "<", "!=", "<="];
    OPERATORS.get(usize::from(code)).copied().unwrap_or("!=")
}
1021
1022fn sampler_texture_type(ctx: &Context<'_>, sampler: u16) -> SamplerTextureType {
1023 if let Some(c) = sampler_constant(ctx, sampler) {
1024 if let Some(t) = &c.type_info {
1025 return match t.value_type {
1026 ValueType::SamplerCube => SamplerTextureType::Cube,
1027 ValueType::Sampler3D => SamplerTextureType::Volume,
1028 _ => SamplerTextureType::TwoD,
1029 };
1030 }
1031 }
1032 ctx.sampler_decls.get(&sampler).copied().unwrap_or(SamplerTextureType::TwoD)
1033}
1034
1035fn wgsl_texture_type(ty: SamplerTextureType) -> &'static str {
1036 match ty {
1037 SamplerTextureType::Cube => "texture_cube<f32>",
1038 SamplerTextureType::Volume => "texture_3d<f32>",
1039 SamplerTextureType::TwoD | SamplerTextureType::Unknown => "texture_2d<f32>",
1040 }
1041}
1042
1043fn sampler_constant<'a>(ctx: &'a Context<'_>, sampler: u16) -> Option<&'a ConstantInfo> {
1044 ctx.ctab?.constants.iter().find(|c| c.register_set == RegisterSet::Sampler && c.register_index == sampler)
1045}
1046
/// Formats an `f32` as a WGSL float expression.
///
/// * NaN and the infinities are written as division expressions
///   (`(0.0 / 0.0)`, `(1.0 / 0.0)`, `(-1.0 / 0.0)`).
/// * Finite values are written in fixed notation with up to nine decimals,
///   trailing zeros trimmed but at least one digit kept after the point
///   (`1.0`, `0.25`).
/// * If that fixed form does not parse back to exactly `v` (nine decimals are
///   not enough for small magnitudes, which would otherwise collapse to
///   `0.0`), the shortest round-tripping exponent form is used instead
///   (e.g. `1e-10`).
fn fmt_f32(v: f32) -> String {
    if v.is_nan() {
        return "(0.0 / 0.0)".to_string();
    }
    if v.is_infinite() {
        let text = if v.is_sign_positive() { "(1.0 / 0.0)" } else { "(-1.0 / 0.0)" };
        return text.to_string();
    }

    // `{:.9}` always emits a decimal point, so trimming zeros stops at the
    // point or at the last significant decimal.
    let mut fixed = format!("{:.9}", v);
    let kept = fixed.trim_end_matches('0').len();
    fixed.truncate(kept);
    if fixed.ends_with('.') {
        fixed.push('0');
    }

    if matches!(fixed.parse::<f32>(), Ok(parsed) if parsed == v) {
        fixed
    } else {
        // Without an explicit precision, `{:e}` prints the shortest digits
        // that round-trip.
        format!("{:e}", v)
    }
}