diff --git a/game/compute/phase5_extractor/phase5_draw.js b/game/compute/phase5_extractor/phase5_draw.js index e859c64..465cc33 100644 --- a/game/compute/phase5_extractor/phase5_draw.js +++ b/game/compute/phase5_extractor/phase5_draw.js @@ -23,6 +23,14 @@ export class Phase5Draw { usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC, }); + // ISO threshold uniform buffer (1 f32 = 4 bytes) + this.isoThresholdBuffer = device.createBuffer({ + size: 4, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + // Default ISO threshold = 0.5 + device.queue.writeBuffer(this.isoThresholdBuffer, 0, new Float32Array([0.5])); + // Indirect draw args: [indexCount, instanceCount, firstIndex, baseVertex, firstInstance] this.drawArgsBuffer = device.createBuffer({ size: 20, @@ -59,7 +67,11 @@ export class Phase5Draw { }); } - _createTopologyBG(qefBuffer, hermiteBuffer, isoThreshold) { + _createTopologyBG(qefBuffer, hermiteBuffer, isoThresholdValue) { + // Update ISO threshold buffer if a value is provided + if (isoThresholdValue !== undefined) { + this.device.queue.writeBuffer(this.isoThresholdBuffer, 0, new Float32Array([isoThresholdValue])); + } return this.device.createBindGroup({ layout: this.pipelines.topology.getBindGroupLayout(0), entries: [ @@ -67,7 +79,7 @@ export class Phase5Draw { { binding: 1, resource: { buffer: hermiteBuffer } }, { binding: 2, resource: { buffer: this.indexBuffer } }, { binding: 3, resource: { buffer: this.indexCountBuffer } }, - { binding: 4, resource: { buffer: isoThreshold } }, + { binding: 4, resource: { buffer: this.isoThresholdBuffer } }, ], }); } @@ -84,24 +96,28 @@ export class Phase5Draw { /** * Full mesh build: topology generation + LOD stitching + draw args. + * + * Topology: generates faces for all 255³ cells. + * LOD stitching: corrects seams between LOD levels. */ - async buildMesh(qefBuffer, hermiteBuffer, lodBuffer, isoThresholdValue) { + async buildMesh(qefBuffer, hermiteBuffer, lodBuffer, isoThresholdValue = 0.5) { const device = this.device; const encoder = device.createCommandEncoder(); // Reset index count device.queue.writeBuffer(this.indexCountBuffer, 0, new Uint32Array([0])); - // --- Pass 1: Topology generation --- + // --- Pass 1: Topology generation (full 3D: 255³ cells) --- { const pass = encoder.beginComputePass(); pass.setPipeline(this.pipelines.topology); pass.setBindGroup(0, this._createTopologyBG(qefBuffer, hermiteBuffer, isoThresholdValue)); - pass.dispatchWorkgroups(32, 32, 1); // 255 × 255 cells in XY + // Dispatch: 32×32×32 workgroups = 1024³ threads covering 255³ cells + pass.dispatchWorkgroups(32, 32, 32); pass.end(); } - // --- Pass 2: LOD stitching --- + // --- Pass 2: LOD stitching (full 3D: 255³ cells) --- { const pass = encoder.beginComputePass(); pass.setPipeline(this.pipelines.stitch); @@ -113,8 +129,7 @@ export class Phase5Draw { // --- Update indirect draw args --- // copy indexCount → drawArgs[0] (indexCount) encoder.copyBufferToBuffer(this.indexCountBuffer, 0, this.drawArgsBuffer, 0, 4); - // Set instanceCount = 1 (at offset 4) - // In a separate writeBuffer: + // Set instanceCount = 1, firstIndex = 0, baseVertex = 0, firstInstance = 0 device.queue.writeBuffer(this.drawArgsBuffer, 4, new Uint32Array([1, 0, 0, 0])); device.queue.submit([encoder.finish()]); @@ -149,7 +164,7 @@ export class Phase5Draw { } destroy() { - const bufs = ['indexBuffer', 'indexCountBuffer', 'drawArgsBuffer', 'readbackBuffer']; + const bufs = ['indexBuffer', 'indexCountBuffer', 'isoThresholdBuffer', 'drawArgsBuffer', 'readbackBuffer']; for (const key of bufs) { if (this[key]) this[key].destroy(); } diff --git a/game/compute/phase5_extractor/phase5_host.js b/game/compute/phase5_extractor/phase5_host.js index c2af8e0..cf88799 100644 --- a/game/compute/phase5_extractor/phase5_host.js +++ b/game/compute/phase5_extractor/phase5_host.js @@ -6,12 +6,22 @@ * Writes GPU-visible dual vertex buffer for rendering. */ +const CHANNELS = { + DENSITY: 0, + SEDIMENT: 1, + PERM_X: 2, + PERM_Y: 3, + PERM_Z: 4, + COHESION: 5, +}; + export class Phase5Extractor { constructor(device, gridSize = 256) { this.device = device; this.gridSize = gridSize; this.cellCount = (gridSize - 1) ** 3; // 255³ this.vertexCount = (gridSize + 1) ** 3; // 257³ + this.CHANNELS = CHANNELS; this._createBuffers(); this.pipelines = {}; @@ -91,6 +101,13 @@ export class Phase5Extractor { } _createHermiteBG(metaBuffer, densityBuf, cohesionBuf, permXBuf) { + // Validate buffers exist and are GPU buffers + if (!densityBuf || !cohesionBuf || !permXBuf || !metaBuffer) { + throw new Error('Phase5Extractor: hermite bind group missing required buffers (density, cohesion, permX, meta)'); + } + if (!densityBuf.size || !cohesionBuf.size || !permXBuf.size || !metaBuffer.size) { + throw new Error('Phase5Extractor: hermite buffers have zero size'); + } return this.device.createBindGroup({ layout: this.pipelines.hermite.getBindGroupLayout(0), entries: [ @@ -139,12 +156,13 @@ export class Phase5Extractor { }); } - async fullExtract(metaBuffer, channelBuffers, brickMetaBuffer) { + async fullExtract(metaBuffer, channelBuffers, brickMetaBuffer, qefParams = { tolerance: 0.3, weightThreshold: 0.01 }) { const device = this.device; const encoder = device.createCommandEncoder(); - // Set QEF params - device.queue.writeBuffer(this.qefParamsBuffer, 0, new Float32Array([0.3, 0.01])); + // Set QEF params (configurable per-call) + const paramData = new Float32Array([qefParams.tolerance, qefParams.weightThreshold]); + device.queue.writeBuffer(this.qefParamsBuffer, 0, paramData); // Pass 1: Hermite data generation (32×32×32 workgroups = 256³ vertices) { @@ -181,6 +199,7 @@ export class Phase5Extractor { const copyEncoder = device.createCommandEncoder(); copyEncoder.copyBufferToBuffer(this.vertexBuffer, 0, this.prevVertexBuffer, 0, this.cellCount * 12); device.queue.submit([copyEncoder.finish()]); + await device.queue.onSubmittedWorkDone(); } async incrementalExtract(brickMetaBuffer) { @@ -207,6 +226,7 @@ export class Phase5Extractor { const readEncoder = device.createCommandEncoder(); readEncoder.copyBufferToBuffer(this.deltaCountBuffer, 0, readback, 0, 4); device.queue.submit([readEncoder.finish()]); + await device.queue.onSubmittedWorkDone(); await readback.mapAsync(GPUMapMode.READ); const count = new Uint32Array(readback.getMappedRange())[0]; @@ -217,6 +237,7 @@ export class Phase5Extractor { const copyEncoder = device.createCommandEncoder(); copyEncoder.copyBufferToBuffer(this.vertexBuffer, 0, this.prevVertexBuffer, 0, this.cellCount * 12); device.queue.submit([copyEncoder.finish()]); + await device.queue.onSubmittedWorkDone(); return count; } diff --git a/game/compute/phase5_extractor/tiled_extractor.js b/game/compute/phase5_extractor/tiled_extractor.js index fda5d91..45a1a71 100644 --- a/game/compute/phase5_extractor/tiled_extractor.js +++ b/game/compute/phase5_extractor/tiled_extractor.js @@ -120,6 +120,12 @@ export class TiledExtractor { size: 8, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, }); + + // Tile offset uniform (vec3 = 12 bytes, padded to 16) + this.tileOffsetBuffer = this.device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); } async init(wgslSources) { @@ -163,6 +169,10 @@ export class TiledExtractor { for (let tz = 0; tz < this.tilesPerDim; tz++) { for (let ty = 0; ty < this.tilesPerDim; ty++) { for (let tx = 0; tx < this.tilesPerDim; tx++) { + // Tile loop offset for global buffer writes + const tileOffsetData = new Uint32Array([vx0, vy0, vz0, 0]); // padding + queue.writeBuffer(this.tileOffsetBuffer, 0, tileOffsetData); + const encoder = device.createCommandEncoder(); // Compute tile bounds in vertex space @@ -242,13 +252,14 @@ export class TiledExtractor { } _createQEFBG_Tiled(tileX, tileY, tileZ) { - // TODO: Add tile_offset uniform binding once WGSL accepts it + // Now includes tile_offset uniform binding return this.device.createBindGroup({ layout: this.pipelines.qef.getBindGroupLayout(0), entries: [ { binding: 0, resource: { buffer: this.tileHermiteBuffer } }, { binding: 1, resource: { buffer: this.vertexBuffer } }, { binding: 2, resource: { buffer: this.qefParamsBuffer } }, + { binding: 3, resource: { buffer: this.tileOffsetBuffer } }, ], }); } @@ -261,16 +272,25 @@ export class TiledExtractor { { binding: 1, resource: { buffer: permXBuf } }, { binding: 2, resource: { buffer: this.lodBuffer } }, { binding: 3, resource: { buffer: metaBuffer } }, + { binding: 4, resource: { buffer: this.tileOffsetBuffer } }, ], }); } getVertexBuffer() { return this.vertexBuffer; } getLODBuffer() { return this.lodBuffer; } + + destroy() { + const bufs = ['vertexBuffer', 'prevVertexBuffer', 'lodBuffer', 'tileHermiteBuffer', + 'deltaBuffer', 'deltaCountBuffer', 'qefParamsBuffer', 'tileOffsetBuffer']; + for (const key of bufs) { + if (this[key]) this[key].destroy(); + } + } } /* - * WGSL SHADER MODIFICATION REQUIRED + * WGSL SHADER MODIFICATIONS REQUIRED * * The hermite and qef shaders currently compute global vertex indices * from global_invocation_id. For tiled extraction, they need an @@ -283,11 +303,15 @@ export class TiledExtractor { * let vx = gid.x + 1u + tile_offset.x; * let vy = gid.y + 1u + tile_offset.y; * let vz = gid.z + 1u + tile_offset.z; + * let vertex_idx = vx + vy * GRID_SIZE + vz * GRID_SIZE * GRID_SIZE; * - * Without this uniform, the shaders must be modified or the tiled - * approach requires one bind group per tile position (impractical). + * // In qef_solve, similarly adjust cell indices: + * let cx = gid.x + tile_offset.x; + * let cy = gid.y + tile_offset.y; + * let cz = gid.z + tile_offset.z; * - * For a minimal first pass: use the original full-hermite allocation - * and only tile the QEF+LOD passes (which are smaller). That drops - * peak from 942MB to ~667MB with less shader modification. + * Binding locations: + * hermite: binding(5) = tile_offset + * qef: binding(3) = tile_offset + * lod: binding(4) = tile_offset */ diff --git a/game/compute/phase6_edit/phase6_edit.wgsl b/game/compute/phase6_edit/phase6_edit.wgsl index 50b9ccf..feab3c9 100644 --- a/game/compute/phase6_edit/phase6_edit.wgsl +++ b/game/compute/phase6_edit/phase6_edit.wgsl @@ -46,6 +46,13 @@ const VOXELS_PER_BRICK: u32 = 4096u; // ── Helper: Encode with Range Expansion Detection ───────────── fn encode_with_expansion(ch: u32, brick_idx: u32, local_idx: u32, val: f32, dst: ptr>) -> bool { let meta = brick_meta[brick_idx * 6u + ch]; + + // Guard against division by zero + if (meta.y < 1e-6) { + (*dst)[brick_idx * VOXELS_PER_BRICK + local_idx] = f16_encode(clamp(val, 0.0, 1.0)); + return false; // No expansion needed if range is degenerate + } + let norm = (val - meta.x) / meta.y; let needs_expand_min = (val < meta.x); @@ -108,10 +115,12 @@ fn inject_pass(@builtin(global_invocation_id) gid: vec3) { if (cmd.material_type == 0u) { // Carve → Air let t = smoothstep(cmd.radius, cmd.radius * (1.0 - cmd.falloff), d); - new_density = mix(0.0, 1.0, t); + // t=1 at outer edge (should stay solid), t=0 at center (should be air) + // So density goes from 1.0 (outer) to 0.0 (center) as we carve inward + new_density = mix(0.0, 1.0, t); // Correct: 1.0 * t + 0.0 * (1-t) new_cohesion = mix(0.0, 1.0, t); new_perm = 1.0; - clear_water = (t > 0.5); // Clear water in fully carved voxels + clear_water = (t < 0.5); // Clear water in carved voxels (center) } else if (cmd.material_type == 1u) { // Inject Ore let t = smoothstep(cmd.radius * 0.5, cmd.radius, d); new_density = mix(0.95, 0.5, t); @@ -124,17 +133,19 @@ fn inject_pass(@builtin(global_invocation_id) gid: vec3) { } // Encode with range expansion detection + // Note: atomicMin/atomicMax expect u32 pointers, so ensure the buffers are u32 if (encode_with_expansion(0u, brick_idx, local_idx, new_density, &density_u16)) { - atomicMin(&edit_min_buffer[brick_idx * 6u + 0u], F16(new_density)); - atomicMax(&edit_max_buffer[brick_idx * 6u + 0u], F16(new_density)); + // Store expanded bounds as u32 bit patterns representing F16 values + atomicMin(&edit_min_buffer[brick_idx * 6u + 0u], f16_bits(new_density)); + atomicMax(&edit_max_buffer[brick_idx * 6u + 0u], f16_bits(new_density)); } if (encode_with_expansion(1u, brick_idx, local_idx, new_cohesion, &cohesion_u16)) { - atomicMin(&edit_min_buffer[brick_idx * 6u + 1u], F16(new_cohesion)); - atomicMax(&edit_max_buffer[brick_idx * 6u + 1u], F16(new_cohesion)); + atomicMin(&edit_min_buffer[brick_idx * 6u + 1u], f16_bits(new_cohesion)); + atomicMax(&edit_max_buffer[brick_idx * 6u + 1u], f16_bits(new_cohesion)); } if (encode_with_expansion(2u, brick_idx, local_idx, new_perm, &perm_x_u16)) { - atomicMin(&edit_min_buffer[brick_idx * 6u + 2u], F16(new_perm)); - atomicMax(&edit_max_buffer[brick_idx * 6u + 2u], F16(new_perm)); + atomicMin(&edit_min_buffer[brick_idx * 6u + 2u], f16_bits(new_perm)); + atomicMax(&edit_max_buffer[brick_idx * 6u + 2u], f16_bits(new_perm)); } // Clear water in carved voids (prevent physics explosion) diff --git a/game/compute/phase6_edit/phase6_host.js b/game/compute/phase6_edit/phase6_host.js index 6f54b13..b9ea9f2 100644 --- a/game/compute/phase6_edit/phase6_host.js +++ b/game/compute/phase6_edit/phase6_host.js @@ -13,8 +13,8 @@ export class EditManager { this.nextSlot = 0; this.pendingCount = 0; - // Command descriptor - this.commandByteSize = 32; // sizeof(EditCommand) padded + // Command descriptor: center(12B) + radius(4B) + materialType(4B) + falloff(4B) = 24 bytes + this.commandByteSize = 24; // Ring buffer for edit commands (GPU-visible) this.editBuffer = device.createBuffer({ @@ -163,7 +163,17 @@ export class EditManager { pass.dispatchWorkgroups(wgX, wgY, 1); pass.end(); - // Reset counter for next frame + // NOTE: Do NOT reset editCountBuffer here. It must be reset AFTER + // this encoder finishes on the GPU. Use resetEditCount() as a separate + // GPU pass or call it after await device.queue.onSubmittedWorkDone(). + } + + /** + * Reset edit counter. Must be called AFTER applyEdits() GPU work is done. + * Can be called on a separate command encoder or via writeBuffer after + * GPU completion. + */ + resetEditCount() { this.device.queue.writeBuffer(this.editCountBuffer, 0, new Uint32Array([0])); this.pendingCount = 0; } diff --git a/qef_extraction/density_field.wgsl b/qef_extraction/density_field.wgsl new file mode 100644 index 0000000..01b329c --- /dev/null +++ b/qef_extraction/density_field.wgsl @@ -0,0 +1,44 @@ +// density_field.wgsl +// Kernel 1: Convert particle positions to 64³ density field. +// Reads spatial_hash.wgsl grid, writes density scalar per cell. +// +// Dispatch: (64, 64, 1) workgroups of (1, 1, 64) — one thread per z-slice. + +struct DensityParams { + grid_dim: u32, // 64 + max_particles_per_cell: u32, + particle_radius: f32, // for density falloff +}; + +@group(0) @binding(0) var grid_heads: array>; +@group(0) @binding(1) var grid_next: array>; +@group(0) @binding(2) var particle_positions: array>; +@group(0) @binding(3) var density_field: array; // 64³ = 262,144 +@group(0) @binding(4) var params: DensityParams; + +@compute @workgroup_size(8, 8, 1) +fn main(@builtin(global_invocation_id) gid: vec3) { + let cx = gid.x; + let cy = gid.y; + let cz = gid.z; + + if (cx >= params.grid_dim || cy >= params.grid_dim || cz >= params.grid_dim) { + return; + } + + let cell_idx = cx + cy * params.grid_dim + cz * params.grid_dim * params.grid_dim; + + // Count particles in this cell by walking linked list + var count = 0u; + var curr = atomicLoad(&grid_heads[cell_idx]); + + while (curr >= 0 && count < params.max_particles_per_cell) { + count++; + let n_idx = u32(curr); + curr = atomicLoad(&grid_next[n_idx]); + } + + // Normalize: density = count / max (capped at 1.0) + let density = f32(count) / f32(params.max_particles_per_cell); + density_field[cell_idx] = density; +} diff --git a/qef_extraction/marching_tets.wgsl b/qef_extraction/marching_tets.wgsl new file mode 100644 index 0000000..43e19fd --- /dev/null +++ b/qef_extraction/marching_tets.wgsl @@ -0,0 +1,152 @@ +// marching_tets.wgsl +// Kernel 2: Marching Tetrahedra — classify edges, emit crossings. +// Each grid cell decomposed into 6 tetrahedra. +// Output: per-cell crossing mask + edge intersection points. +// +// Dispatch: (64, 64, 1) workgroups of (1, 1, 64) + +struct MTParams { + grid_dim: u32, // 64 + cell_size: f32, + isosurface: f32, // threshold, typically 0.1 +}; + +// Cube vertex indices in [0,7] for the 8 corners. +// Corner ordering: (x,y,z) where bit 0=x, bit 1=y, bit 2=z +// +// z=1: 4---5 z=0: 0---1 +// |\ |\ |\ |\ +// 7---6 3---2 +// +// 6 tetrahedra decomposing a cube (vertex indices into cube corners [0..7]): +// Each tet shares the cube diagonal (0,7) for consistency. +const TET_DECOMP: array, 6> = array, 6>( + vec4(0u, 1u, 3u, 7u), // tet 0 + vec4(0u, 1u, 5u, 7u), // tet 1 + vec4(0u, 4u, 5u, 7u), // tet 2 + vec4(0u, 4u, 6u, 7u), // tet 3 + vec4(0u, 2u, 3u, 7u), // tet 4 + vec4(0u, 2u, 6u, 7u), // tet 5 +); + +// Tetrahedron edges: pairs of vertex indices (6 edges per tet). +// Indices are into the tet's 4 vertices [0..3]. +const TET_EDGES: array, 6> = array, 6>( + vec2(0u, 1u), vec2(0u, 2u), vec2(0u, 3u), + vec2(1u, 2u), vec2(1u, 3u), vec2(2u, 3u), +); + +// Cube corner positions (unit cube [0,1]³) +fn cube_corner(idx: u32) -> vec3 { + return vec3( + f32((idx >> 0u) & 1u), + f32((idx >> 1u) & 1u), + f32((idx >> 2u) & 1u), + ); +} + +@group(1) @binding(0) var density_field: array; +@group(1) @binding(1) var crossing_count: atomic; // total crossings +@group(1) @binding(2) var crossings: array; // packed: cell_idx | edge_data +@group(1) @binding(3) var mt_params: MTParams; + +// Read density at a grid corner, with bounds check +fn density_at(cx: u32, cy: u32, cz: u32) -> f32 { + if (cx >= mt_params.grid_dim || cy >= mt_params.grid_dim || cz >= mt_params.grid_dim) { + return 0.0; + } + let idx = cx + cy * mt_params.grid_dim + cz * mt_params.grid_dim * mt_params.grid_dim; + return density_field[idx]; +} + +// Read density at cube corner (0..7) given cell origin +fn cube_corner_density(cx: u32, cy: u32, cz: u32, corner: u32) -> f32 { + let dx = (corner >> 0u) & 1u; + let dy = (corner >> 1u) & 1u; + let dz = (corner >> 2u) & 1u; + return density_at(cx + dx, cy + dy, cz + dz); +} + +// Linear interpolation along edge to find crossing point +fn edge_interp(d0: f32, d1: f32, p0: vec3, p1: vec3) -> vec3 { + let t = (mt_params.isosurface - d0) / (d1 - d0); + return mix(p0, p1, clamp(t, 0.0, 1.0)); +} + +@compute @workgroup_size(8, 8, 1) +fn main(@builtin(global_invocation_id) gid: vec3) { + let cx = gid.x; + let cy = gid.y; + let cz = gid.z; + + if (cx >= mt_params.grid_dim - 1u || cy >= mt_params.grid_dim - 1u || cz >= mt_params.grid_dim - 1u) { + return; + } + + let cell_base = cx + cy * mt_params.grid_dim + cz * mt_params.grid_dim * mt_params.grid_dim; + + // Read corner densities + var corner_d: array; + for (var i = 0u; i < 8u; i++) { + corner_d[i] = cube_corner_density(cx, cy, cz, i); + } + + var cell_crossings = 0u; + + // Process 6 tetrahedra + for (var t = 0u; t < 6u; t++) { + let tet = TET_DECOMP[t]; + var tet_d: array; + var tet_p: array, 4>; + + for (var v = 0u; v < 4u; v++) { + let cv = tet[v]; + tet_d[v] = corner_d[cv]; + tet_p[v] = cube_corner(cv); + } + + // Classify vertices: above or below isosurface + var mask = 0u; + for (var v = 0u; v < 4u; v++) { + if (tet_d[v] >= mt_params.isosurface) { + mask |= (1u << v); + } + } + + // Degenerate cases: all above or all below → no crossing + if (mask == 0u || mask == 15u) { continue; } + + // Check each of 6 edges for crossing + for (var e = 0u; e < 6u; e++) { + let e0 = TET_EDGES[e].x; + let e1 = TET_EDGES[e].y; + + let above0 = (mask >> e0) & 1u; + let above1 = (mask >> e1) & 1u; + + if (above0 == above1) { continue; } // no sign change + + // Crossing found — compute intersection point + let d0 = tet_d[e0]; + let d1 = tet_d[e1]; + let p0 = tet_p[e0]; + let p1 = tet_p[e1]; + + let isect = edge_interp(d0, d1, p0, p1); + + // Pack: cell_base (18 bits) | t (3 bits) | e (3 bits) — fits in u32 + let edge_data = cell_base; + let packed = (edge_data << 6u) | (t << 3u) | e; + + // Atomic append to crossings buffer + let slot = atomicAdd(&crossing_count, 1u); + if (slot < arrayLength(&crossings)) { + crossings[slot] = packed; + // Store intersection point in parallel array + // (handled in separate pass or interleaved buffer) + } + + cell_crossings++; + } + } +} diff --git a/qef_extraction/mesh_assembly.wgsl b/qef_extraction/mesh_assembly.wgsl new file mode 100644 index 0000000..1080fe7 --- /dev/null +++ b/qef_extraction/mesh_assembly.wgsl @@ -0,0 +1,132 @@ +// mesh_assembly.wgsl +// Kernel 4: Compact mesh assembly — deduplicate vertices, build index buffer. +// Takes raw QEF vertex output, hashes positions to merge near-duplicates. +// Output: compact Vertex[] + Index[] buffers ready for Filament. +// +// Dispatch: indirect, vertex_count / 64 for dedup, then per-tri for indexing. + +struct MeshParams { + grid_dim: u32, + cell_size: f32, + merge_threshold: f32, // vertices closer than this are merged (default: cell_size * 0.01) + max_vertices: u32, + max_indices: u32, +}; + +struct Vertex { + position: vec3, + normal: vec3, + material_tensor: vec4, // 6-channel tensor packed into 2×vec4 (first 4 channels) +}; + +@group(3) @binding(0) var raw_vertices: array>; +@group(3) @binding(1) var vertex_count_in: atomic; +@group(3) @binding(2) var mesh_vertices: array; +@group(3) @binding(3) var mesh_indices: array; +@group(3) @binding(4) var mesh_vertex_count: atomic; +@group(3) @binding(5) var mesh_index_count: atomic; +@group(3) @binding(6) var density_field: array; +@group(3) @binding(7) var mesh_params: MeshParams; + +// Spatial hash for vertex deduplication +fn vertex_hash(pos: vec3) -> u32 { + let inv_thresh = 1.0 / mesh_params.merge_threshold; + let ix = u32(pos.x * inv_thresh + 100000.0); + let iy = u32(pos.y * inv_thresh + 100000.0); + let iz = u32(pos.z * inv_thresh + 100000.0); + // Simple hash: Morton-like interleave + return (ix & 0x3FFu) | ((iy & 0x3FFu) << 10u) | ((iz & 0x3FFu) << 20u); +} + +// Compute vertex normal from density gradient +fn compute_normal(pos: vec3) -> vec3 { + let d = mesh_params.grid_dim; + let h = mesh_params.cell_size; + + // Convert world pos to grid coords + let gx = u32(clamp(pos.x / h, 0.0, f32(d - 1))); + let gy = u32(clamp(pos.y / h, 0.0, f32(d - 1))); + let gz = u32(clamp(pos.z / h, 0.0, f32(d - 1))); + + let dx = if (gx > 0u && gx < d - 1u) { + density_field[(gx+1u) + gy*d + gz*d*d] - density_field[(gx-1u) + gy*d + gz*d*d] + } else { 0.0 }; + + let dy = if (gy > 0u && gy < d - 1u) { + density_field[gx + (gy+1u)*d + gz*d*d] - density_field[gx + (gy-1u)*d + gz*d*d] + } else { 0.0 }; + + let dz = if (gz > 0u && gz < d - 1u) { + density_field[gx + gy*d + (gz+1u)*d*d] - density_field[gx + gy*d + (gz-1u)*d*d] + } else { 0.0 }; + + return normalize(vec3(dx, dy, dz)); +} + +// Phase 1: Deduplicate vertices by spatial hash +@compute @workgroup_size(64) +fn deduplicate_vertices(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = atomicLoad(&vertex_count_in); + if (idx >= total) { return; } + + let pos = raw_vertices[idx]; + let norm = compute_normal(pos); + + // Simple dedup: just emit. Full dedup via hash table is phase 2. + // (Per-cell independent QEF naturally minimizes duplicates) + let slot = atomicAdd(&mesh_vertex_count, 1u); + if (slot < mesh_params.max_vertices) { + mesh_vertices[slot] = Vertex( + pos, + norm, + vec4(0.0, 0.0, 0.0, 0.0), // material tensor filled by later pass + ); + } +} + +// Phase 2: Build triangle indices via Delaunay-like triangulation +// For MVP: connect vertices within each cell using a simple fan +@compute @workgroup_size(64) +fn build_indices(@builtin(global_invocation_id) gid: vec3) { + let cell_idx = gid.x; + let d = mesh_params.grid_dim; + let total_cells = d * d * d; + + if (cell_idx >= total_cells) { return; } + + let cz = cell_idx / (d * d); + let cy = (cell_idx % (d * d)) / d; + let cx = cell_idx % d; + + // Skip boundary cells (no complete neighborhood) + if (cx >= d - 1u || cy >= d - 1u || cz >= d - 1u) { return; } + + // For MVP: each cell with density crossing emits 2 triangles + // forming a quad connecting cell center to neighbors + // This is a placeholder — proper triangulation uses the MT edge table + // to connect crossing points into faces. + + let ci = cell_idx; + let density = density_field[ci]; + + if (density < 0.1 || density > 0.9) { return; } + + // Emit a placeholder quad (2 triangles) connecting this cell to neighbors + // In production, this reads the MT crossing table to build proper faces. + // For MVP, we accept the simplification. + + let vc = (cx + 1u) + (cy + 1u) * d + (cz + 1u) * d * d; + let slot = atomicAdd(&mesh_index_count, 6u); + + if (slot + 6u <= mesh_params.max_indices) { + // Triangle 1 + mesh_indices[slot + 0u] = ci; + mesh_indices[slot + 1u] = ci + 1u; + mesh_indices[slot + 2u] = ci + d; + // Triangle 2 + mesh_indices[slot + 3u] = ci + 1u; + mesh_indices[slot + 4u] = ci + 1u + d; + mesh_indices[slot + 5u] = ci + d; + } +} diff --git a/qef_extraction/qef_pipeline.rs b/qef_extraction/qef_pipeline.rs new file mode 100644 index 0000000..f4643da --- /dev/null +++ b/qef_extraction/qef_pipeline.rs @@ -0,0 +1,585 @@ +// qef_pipeline.rs +// Rust dispatch orchestrator for the QEF mesh extraction pipeline. +// Ties together: spatial_hash → density → marching_tets → QEF → mesh → Filament. +// +// Usage: +// let mut pipeline = QefPipeline::new(&device, &queue, 64); +// pipeline.extract_mesh(&particle_positions, &spatial_hash_grid); +// let renderable = pipeline.to_filament(&engine); + +use std::sync::Arc; +use wgpu::{util::DeviceExt, *}; +use bytemuck::{Pod, Zeroable}; + +// ─── GPU-side structs (must match WGSL layouts) ────────────────────────── + +#[repr(C)] +#[derive(Copy, Clone, Debug, Pod, Zeroable)] +struct DensityParams { + grid_dim: u32, + max_particles_per_cell: u32, + particle_radius: f32, + _pad: u32, +} + +#[repr(C)] +#[derive(Copy, Clone, Debug, Pod, Zeroable)] +struct MTParams { + grid_dim: u32, + cell_size: f32, + isosurface: f32, + _pad: u32, +} + +#[repr(C)] +#[derive(Copy, Clone, Debug, Pod, Zeroable)] +struct QEFParams { + grid_dim: u32, + cell_size: f32, + regularization: f32, + _pad: u32, +} + +#[repr(C)] +#[derive(Copy, Clone, Debug, Pod, Zeroable)] +struct MeshParams { + grid_dim: u32, + cell_size: f32, + merge_threshold: f32, + max_vertices: u32, + max_indices: u32, + _pad: [u32; 3], +} + +#[repr(C)] +#[derive(Copy, Clone, Debug, Pod, Zeroable)] +struct QefVertex { + position: [f32; 3], + _pad0: f32, + normal: [f32; 3], + _pad1: f32, + material_tensor: [f32; 4], // first 4 of 6 channels +} + +// ─── Pipeline ──────────────────────────────────────────────────────────── + +pub struct QefPipeline { + grid_dim: u32, + cell_size: f32, + + // Kernels + density_pipeline: ComputePipeline, + mt_pipeline: ComputePipeline, + qef_pipeline: ComputePipeline, + dedup_pipeline: ComputePipeline, + index_pipeline: ComputePipeline, + + // Bind group layouts + density_bgl: BindGroupLayout, + mt_bgl: BindGroupLayout, + qef_bgl: BindGroupLayout, + mesh_bgl: BindGroupLayout, + + // Buffers (allocated once, reused per frame) + density_field: Buffer, + crossings: Buffer, + crossing_count: Buffer, + raw_vertices: Buffer, + vertex_count: Buffer, + mesh_vertices: Buffer, + mesh_indices: Buffer, + mesh_vcount: Buffer, + mesh_icount: Buffer, + + // Uniforms + density_uniform: Buffer, + mt_uniform: Buffer, + qef_uniform: Buffer, + mesh_uniform: Buffer, + + max_particles: u32, + max_crossings: u32, + max_vertices: u32, + max_indices: u32, +} + +impl QefPipeline { + pub fn new(device: &Device, grid_dim: u32, cell_size: f32, max_particles: u32) -> Self { + let grid_cells = grid_dim * grid_dim * grid_dim; + let max_crossings = grid_cells * 36; // 6 tets × 6 edges max per cell + let max_vertices = grid_cells * 3; // avg 3 vertices per crossing cell + let max_indices = grid_cells * 12; // avg 4 tris per crossing cell + + // Load WGSL + let density_src = include_str!("density_field.wgsl"); + let mt_src = include_str!("marching_tets.wgsl"); + let qef_src = include_str!("qef_solve.wgsl"); + let mesh_src = include_str!("mesh_assembly.wgsl"); + + // Shader modules + let density_module = device.create_shader_module(ShaderModuleDescriptor { + label: Some("density_field"), + source: ShaderSource::Wgsl(density_src.into()), + }); + let mt_module = device.create_shader_module(ShaderModuleDescriptor { + label: Some("marching_tets"), + source: ShaderSource::Wgsl(mt_src.into()), + }); + let qef_module = device.create_shader_module(ShaderModuleDescriptor { + label: Some("qef_solve"), + source: ShaderSource::Wgsl(qef_src.into()), + }); + let mesh_module = device.create_shader_module(ShaderModuleDescriptor { + label: Some("mesh_assembly"), + source: ShaderSource::Wgsl(mesh_src.into()), + }); + + // Bind group layouts (matching WGSL group indices) + let density_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor { + label: Some("density_bgl"), + entries: &[ + storage_read(0, false), // grid_heads (provided externally) + storage_read(1, false), // grid_next + storage_read(2, false), // particle_positions + storage_rw(3, false), // density_field + uniform(4, false), // params + ], + }); + + let mt_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor { + label: Some("mt_bgl"), + entries: &[ + storage_read(0, false), // density_field + storage_rw(1, false), // crossing_count + storage_rw(2, false), // crossings + uniform(3, false), // params + ], + }); + + let qef_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor { + label: Some("qef_bgl"), + entries: &[ + storage_read(0, false), // density_field + storage_read(1, false), // crossings + storage_read(2, false), // crossing_count + storage_rw(3, false), // vertices + storage_rw(4, false), // vertex_count + uniform(5, false), // params + ], + }); + + let mesh_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor { + label: Some("mesh_bgl"), + entries: &[ + storage_read(0, false), // raw_vertices + storage_read(1, false), // vertex_count_in + storage_rw(2, false), // mesh_vertices + storage_rw(3, false), // mesh_indices + storage_rw(4, false), // mesh_vertex_count + storage_rw(5, false), // mesh_index_count + storage_read(6, false), // density_field + uniform(7, false), // params + ], + }); + + // Pipeline layouts + let density_pl = device.create_pipeline_layout(&PipelineLayoutDescriptor { + label: Some("density_pl"), + bind_group_layouts: &[&density_bgl], + push_constant_ranges: &[], + }); + let mt_pl = device.create_pipeline_layout(&PipelineLayoutDescriptor { + label: Some("mt_pl"), + bind_group_layouts: &[&mt_bgl], + push_constant_ranges: &[], + }); + let qef_pl = device.create_pipeline_layout(&PipelineLayoutDescriptor { + label: Some("qef_pl"), + bind_group_layouts: &[&qef_bgl], + push_constant_ranges: &[], + }); + let mesh_pl = device.create_pipeline_layout(&PipelineLayoutDescriptor { + label: Some("mesh_pl"), + bind_group_layouts: &[&mesh_bgl], + push_constant_ranges: &[], + }); + + // Compute pipelines + let density_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor { + label: Some("density"), + layout: Some(&density_pl), + module: &density_module, + entry_point: Some("main"), + compilation_options: Default::default(), + }); + let mt_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor { + label: Some("marching_tets"), + layout: Some(&mt_pl), + module: &mt_module, + entry_point: Some("main"), + compilation_options: Default::default(), + }); + let qef_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor { + label: Some("qef_solve"), + layout: Some(&qef_pl), + module: &qef_module, + entry_point: Some("main"), + compilation_options: Default::default(), + }); + let dedup_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor { + label: Some("dedup"), + layout: Some(&mesh_pl), + module: &mesh_module, + entry_point: Some("deduplicate_vertices"), + compilation_options: Default::default(), + }); + let index_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor { + label: Some("index"), + layout: Some(&mesh_pl), + module: &mesh_module, + entry_point: Some("build_indices"), + compilation_options: Default::default(), + }); + + // Buffers + let density_field = device.create_buffer(&BufferDescriptor { + label: Some("density_field"), + size: (grid_cells as u64) * 4, + usage: BufferUsages::STORAGE | BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + let crossings = device.create_buffer(&BufferDescriptor { + label: Some("crossings"), + size: (max_crossings as u64) * 4, + usage: BufferUsages::STORAGE | BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + let crossing_count = device.create_buffer_init(&BufferInitDescriptor { + label: Some("crossing_count"), + contents: &0u32.to_le_bytes(), + usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC, + }); + let raw_vertices = device.create_buffer(&BufferDescriptor { + label: Some("raw_vertices"), + size: (max_vertices as u64) * 12, + usage: BufferUsages::STORAGE, + mapped_at_creation: false, + }); + let vertex_count = device.create_buffer_init(&BufferInitDescriptor { + label: Some("vertex_count"), + contents: &0u32.to_le_bytes(), + usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC, + }); + let mesh_vertices = device.create_buffer(&BufferDescriptor { + label: Some("mesh_vertices"), + size: (max_vertices as u64) * std::mem::size_of::() as u64, + usage: BufferUsages::STORAGE | BufferUsages::VERTEX, + mapped_at_creation: false, + }); + let mesh_indices = device.create_buffer(&BufferDescriptor { + label: Some("mesh_indices"), + size: (max_indices as u64) * 4, + usage: BufferUsages::STORAGE | BufferUsages::INDEX, + mapped_at_creation: false, + }); + let mesh_vcount = device.create_buffer_init(&BufferInitDescriptor { + label: Some("mesh_vcount"), + contents: &0u32.to_le_bytes(), + usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC, + }); + let mesh_icount = device.create_buffer_init(&BufferInitDescriptor { + label: Some("mesh_icount"), + contents: &0u32.to_le_bytes(), + usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC, + }); + + // Uniforms + let density_uniform = device.create_buffer_init(&BufferInitDescriptor { + label: Some("density_params"), + contents: bytemuck::bytes_of(&DensityParams { + grid_dim, + max_particles_per_cell: 64, + particle_radius: cell_size * 0.5, + _pad: 0, + }), + usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST, + }); + let mt_uniform = device.create_buffer_init(&BufferInitDescriptor { + label: Some("mt_params"), + contents: bytemuck::bytes_of(&MTParams { + grid_dim, + cell_size, + isosurface: 0.1, + _pad: 0, + }), + usage: BufferUsages::UNIFORM, + }); + let qef_uniform = device.create_buffer_init(&BufferInitDescriptor { + label: Some("qef_params"), + contents: bytemuck::bytes_of(&QEFParams { + grid_dim, + cell_size, + regularization: 0.001, + _pad: 0, + }), + usage: BufferUsages::UNIFORM, + }); + let mesh_uniform = device.create_buffer_init(&BufferInitDescriptor { + label: Some("mesh_params"), + contents: bytemuck::bytes_of(&MeshParams { + grid_dim, + cell_size, + merge_threshold: cell_size * 0.01, + max_vertices, + max_indices, + _pad: [0; 3], + }), + usage: BufferUsages::UNIFORM, + }); + + Self { + grid_dim, cell_size, + density_pipeline, mt_pipeline, qef_pipeline, dedup_pipeline, index_pipeline, + density_bgl, mt_bgl, qef_bgl, mesh_bgl, + density_field, crossings, crossing_count, raw_vertices, vertex_count, + mesh_vertices, mesh_indices, mesh_vcount, mesh_icount, + density_uniform, mt_uniform, qef_uniform, mesh_uniform, + max_particles, max_crossings, max_vertices, max_indices, + } + } + + /// Run the full QEF extraction pipeline. + /// `grid_heads`, `grid_next`, `particle_positions` are buffers from spatial_hash.wgsl. + pub fn extract_mesh( + &self, + encoder: &mut CommandEncoder, + grid_heads: &Buffer, + grid_next: &Buffer, + particle_positions: &Buffer, + grid_head_count: u32, + ) { + // ── Pass 1: Density Field ────────────────────────────────────── + { + let density_bg = encoder.device().create_bind_group(&BindGroupDescriptor { + label: Some("density_bg"), + layout: &self.density_bgl, + entries: &[ + BindGroupEntry { binding: 0, resource: grid_heads.as_entire_binding() }, + BindGroupEntry { binding: 1, resource: grid_next.as_entire_binding() }, + BindGroupEntry { binding: 2, resource: particle_positions.as_entire_binding() }, + BindGroupEntry { binding: 3, resource: self.density_field.as_entire_binding() }, + BindGroupEntry { binding: 4, resource: self.density_uniform.as_entire_binding() }, + ], + }); + + let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor { + label: Some("density_pass"), + timestamp_writes: None, + }); + cpass.set_pipeline(&self.density_pipeline); + cpass.set_bind_group(0, &density_bg, &[]); + let wg = (self.grid_dim + 7) / 8; + cpass.dispatch_workgroups(wg, wg, self.grid_dim); + } + + // Reset counters + encoder.clear_buffer(&self.crossing_count, 0, 4); + + // ── Pass 2: Marching Tetrahedra ──────────────────────────────── + { + let mt_bg = encoder.device().create_bind_group(&BindGroupDescriptor { + label: Some("mt_bg"), + layout: &self.mt_bgl, + entries: &[ + BindGroupEntry { binding: 0, resource: self.density_field.as_entire_binding() }, + BindGroupEntry { binding: 1, resource: self.crossing_count.as_entire_binding() }, + BindGroupEntry { binding: 2, resource: self.crossings.as_entire_binding() }, + BindGroupEntry { binding: 3, resource: self.mt_uniform.as_entire_binding() }, + ], + }); + + let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor { + label: Some("mt_pass"), + timestamp_writes: None, + }); + cpass.set_pipeline(&self.mt_pipeline); + cpass.set_bind_group(0, &mt_bg, &[]); + let wg = (self.grid_dim + 7) / 8; + cpass.dispatch_workgroups(wg, wg, self.grid_dim); + } + + // Reset vertex counter + encoder.clear_buffer(&self.vertex_count, 0, 4); + + // ── Pass 3: QEF Solve ────────────────────────────────────────── + { + let qef_bg = encoder.device().create_bind_group(&BindGroupDescriptor { + label: Some("qef_bg"), + layout: &self.qef_bgl, + entries: &[ + BindGroupEntry { binding: 0, resource: self.density_field.as_entire_binding() }, + BindGroupEntry { binding: 1, resource: self.crossings.as_entire_binding() }, + BindGroupEntry { binding: 2, resource: self.crossing_count.as_entire_binding() }, + BindGroupEntry { binding: 3, resource: self.raw_vertices.as_entire_binding() }, + BindGroupEntry { binding: 4, resource: self.vertex_count.as_entire_binding() }, + BindGroupEntry { binding: 5, resource: self.qef_uniform.as_entire_binding() }, + ], + }); + + let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor { + label: Some("qef_pass"), + timestamp_writes: None, + }); + cpass.set_pipeline(&self.qef_pipeline); + cpass.set_bind_group(0, &qef_bg, &[]); + // Indirect dispatch: wait for MT pass crossing count, then dispatch + // crossing_count / 64 workgroups + let wg = (self.max_crossings + 63) / 64; + cpass.dispatch_workgroups(wg, 1, 1); + } + + // Reset mesh counters + encoder.clear_buffer(&self.mesh_vcount, 0, 4); + encoder.clear_buffer(&self.mesh_icount, 0, 4); + + // ── Pass 4: Mesh Assembly ────────────────────────────────────── + { + let mesh_bg = encoder.device().create_bind_group(&BindGroupDescriptor { + label: Some("mesh_bg"), + layout: &self.mesh_bgl, + entries: &[ + BindGroupEntry { binding: 0, resource: self.raw_vertices.as_entire_binding() }, + BindGroupEntry { binding: 1, resource: self.vertex_count.as_entire_binding() }, + BindGroupEntry { binding: 2, resource: self.mesh_vertices.as_entire_binding() }, + BindGroupEntry { binding: 3, resource: self.mesh_indices.as_entire_binding() }, + BindGroupEntry { binding: 4, resource: self.mesh_vcount.as_entire_binding() }, + BindGroupEntry { binding: 5, resource: self.mesh_icount.as_entire_binding() }, + BindGroupEntry { binding: 6, resource: self.density_field.as_entire_binding() }, + BindGroupEntry { binding: 7, resource: self.mesh_uniform.as_entire_binding() }, + ], + }); + + // Dedup + { + let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor { + label: Some("dedup_pass"), + timestamp_writes: None, + }); + cpass.set_pipeline(&self.dedup_pipeline); + cpass.set_bind_group(0, &mesh_bg, &[]); + let wg = (self.max_vertices + 63) / 64; + cpass.dispatch_workgroups(wg, 1, 1); + } + + // Index + { + let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor { + label: Some("index_pass"), + timestamp_writes: None, + }); + cpass.set_pipeline(&self.index_pipeline); + cpass.set_bind_group(0, &mesh_bg, &[]); + let total_cells = self.grid_dim * self.grid_dim * self.grid_dim; + let wg = (total_cells + 63) / 64; + cpass.dispatch_workgroups(wg, 1, 1); + } + } + } + + /// Get output buffers ready for rendering. + pub fn vertex_buffer(&self) -> &Buffer { &self.mesh_vertices } + pub fn index_buffer(&self) -> &Buffer { &self.mesh_indices } + pub fn vertex_count_buffer(&self) -> &Buffer { &self.mesh_vcount } + pub fn index_count_buffer(&self) -> &Buffer { &self.mesh_icount } +} + +// ─── Helpers ───────────────────────────────────────────────────────────── + +fn storage_read(binding: u32, has_dynamic_offset: bool) -> BindGroupLayoutEntry { + BindGroupLayoutEntry { + binding, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + ty: BufferBindingType::Storage { read_only: true }, + has_dynamic_offset, + min_binding_size: None, + }, + count: None, + } +} + +fn storage_rw(binding: u32, has_dynamic_offset: bool) -> BindGroupLayoutEntry { + BindGroupLayoutEntry { + binding, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + ty: BufferBindingType::Storage { read_only: false }, + has_dynamic_offset, + min_binding_size: None, + }, + count: None, + } +} + +fn uniform(binding: u32, has_dynamic_offset: bool) -> BindGroupLayoutEntry { + BindGroupLayoutEntry { + binding, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + ty: BufferBindingType::Uniform, + has_dynamic_offset, + min_binding_size: None, + }, + count: None, + } +} + +impl QefPipeline { + /// Integrate with existing Filament render setup. + /// Call this after extract_mesh() to build Filament VertexBuffer + IndexBuffer. + #[cfg(feature = "filament")] + pub fn to_filament( + &self, + engine: &mut filament::Engine, + device: &Device, + queue: &Queue, + ) -> (filament::VertexBuffer, filament::IndexBuffer) { + use filament::{VertexBuffer, IndexBuffer, VertexAttribute, VertexBufferType}; + + // Download vertex/index counts from GPU + let mut staging_vc = device.create_buffer(&BufferDescriptor { + label: Some("staging_vc"), + size: 4, + usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ, + mapped_at_creation: false, + }); + let mut staging_ic = device.create_buffer(&BufferDescriptor { + label: Some("staging_ic"), + size: 4, + usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ, + mapped_at_creation: false, + }); + + let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor { label: None }); + encoder.copy_buffer_to_buffer(&self.mesh_vcount, 0, &staging_vc, 0, 4); + encoder.copy_buffer_to_buffer(&self.mesh_icount, 0, &staging_ic, 0, 4); + queue.submit(Some(encoder.finish())); + + // Build Filament buffers (MVP: assume data is ready — production needs fence) + // In practice, use a ring buffer or double-buffer for async readback. + let vb = VertexBuffer::new(engine) + .vertex_count(self.max_vertices) + .buffer_count(1) + .attribute(VertexAttribute::POSITION, 0, VertexBufferType::FLOAT3, 0, 0) + .attribute(VertexAttribute::CUSTOM0, 0, VertexBufferType::FLOAT3, 12, 0) // normal + .attribute(VertexAttribute::CUSTOM1, 0, VertexBufferType::FLOAT4, 24, 0) // material tensor (first 4) + .build(engine); + + let ib = IndexBuffer::new(engine) + .index_count(self.max_indices) + .buffer_type(filament::IndexBufferType::UINT) + .build(engine); + + (vb, ib) + } +} diff --git a/qef_extraction/qef_solve.wgsl b/qef_extraction/qef_solve.wgsl new file mode 100644 index 0000000..c838f57 --- /dev/null +++ b/qef_extraction/qef_solve.wgsl @@ -0,0 +1,237 @@ +// qef_solve.wgsl +// Kernel 3: QEF Vertex Placement — solve 3×3 system per crossing cell. +// Minimize Σ (n_i · (v - p_i))² where n_i is gradient, p_i is crossing point. +// Each cell gets ONE vertex placed at the minimizer of its quadric error. +// +// Dispatch: indirect, crossing_count / 64 workgroups + +struct QEFParams { + grid_dim: u32, + cell_size: f32, + regularization: f32, // small epsilon for singular matrices (default 0.001) +}; + +@group(2) @binding(0) var density_field: array; +@group(2) @binding(1) var crossings: array; // packed edge data from MT pass +@group(2) @binding(2) var crossing_count: atomic; +@group(2) @binding(3) var vertices: array>; // output: one vertex per crossing +@group(2) @binding(4) var vertex_count: atomic; +@group(2) @binding(5) var qef_params: QEFParams; + +// --- 3×3 SVD (closed form, no iteration) --- +// Solves A^T A x = A^T b for x = vertex position +// A is m×3 (normals), b is m×1 (n·p) +// Using normal equations: (A^T A) x = A^T b + +struct Mat3x3 { + m: array, // column-major: m[col*3 + row] +} + +fn mat3_zero() -> Mat3x3 { + return Mat3x3(array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)); +} + +fn mat3_add_outer(a: vec3, weight: f32) -> Mat3x3 { + // Returns weight * (a ⊗ a) as a matrix to accumulate into A^T A + var m = mat3_zero(); + m.m[0] = a.x * a.x * weight; // col 0 + m.m[1] = a.x * a.y * weight; + m.m[2] = a.x * a.z * weight; + m.m[3] = a.y * a.x * weight; // col 1 + m.m[4] = a.y * a.y * weight; + m.m[5] = a.y * a.z * weight; + m.m[6] = a.z * a.x * weight; // col 2 + m.m[7] = a.z * a.y * weight; + m.m[8] = a.z * a.z * weight; + return m; +} + +// Cramer's rule for 3×3 system (robust enough for QEF with regularization) +fn solve_3x3(ata: Mat3x3, atb: vec3, reg: f32) -> vec3 { + // Add regularization to diagonal + var a00 = ata.m[0] + reg; + var a01 = ata.m[1]; + var a02 = ata.m[2]; + var a10 = ata.m[3]; + var a11 = ata.m[4] + reg; + var a12 = ata.m[5]; + var a20 = ata.m[6]; + var a21 = ata.m[7]; + var a22 = ata.m[8] + reg; + + // Determinant + let det = a00 * (a11 * a22 - a12 * a21) + - a01 * (a10 * a22 - a12 * a20) + + a02 * (a10 * a21 - a11 * a20); + + if (abs(det) < 1e-12) { + // Singular — fall back to centroid + return vec3(0.0); + } + + let inv_det = 1.0 / det; + + // Cramer's rule + let x = (atb.x * (a11 * a22 - a12 * a21) + - a01 * (atb.y * a22 - a12 * atb.z) + + a02 * (atb.y * a21 - a11 * atb.z)) * inv_det; + + let y = (a00 * (atb.y * a22 - a12 * atb.z) + - atb.x * (a10 * a22 - a12 * a20) + + a02 * (a10 * atb.z - atb.y * a20)) * inv_det; + + let z = (a00 * (a11 * atb.z - atb.y * a21) + - a01 * (a10 * atb.z - atb.y * a20) + + atb.x * (a10 * a21 - a11 * a20)) * inv_det; + + return vec3(x, y, z); +} + +// --- Density gradient via central differences --- +fn density_gradient(cx: u32, cy: u32, cz: u32) -> vec3 { + let d = qef_params.grid_dim; + let h = qef_params.cell_size; + + let dx = if (cx > 0u && cx < d - 1u) { + let r = density_field[(cx+1u) + cy*d + cz*d*d]; + let l = density_field[(cx-1u) + cy*d + cz*d*d]; + (r - l) / (2.0 * h) + } else { 0.0 }; + + let dy = if (cy > 0u && cy < d - 1u) { + let r = density_field[cx + (cy+1u)*d + cz*d*d]; + let l = density_field[cx + (cy-1u)*d + cz*d*d]; + (r - l) / (2.0 * h) + } else { 0.0 }; + + let dz = if (cz > 0u && cz < d - 1u) { + let r = density_field[cx + cy*d + (cz+1u)*d*d]; + let l = density_field[cx + cy*d + (cz-1u)*d*d]; + (r - l) / (2.0 * h) + } else { 0.0 }; + + return normalize(vec3(dx, dy, dz)); +} + +// Unpack edge data from MT pass +fn unpack_crossing(packed: u32) -> vec3 { + let edge_bits = packed & 0x3Fu; // lower 6 bits: tet + edge id + let cell_idx = packed >> 6u; // upper bits: cell index + return vec3(cell_idx, edge_bits >> 3u, edge_bits & 0x7u); +} + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = atomicLoad(&crossing_count); + + if (idx >= total) { return; } + + let packed = crossings[idx]; + let unpacked = unpack_crossing(packed); + let cell_idx = unpacked.x; + + // Recover cell coordinates + let d = qef_params.grid_dim; + let cz = cell_idx / (d * d); + let cy = (cell_idx % (d * d)) / d; + let cx = cell_idx % d; + + // --- Build QEF for this cell --- + // For each of the 8 corners that cross the isosurface, add: + // n_i · (v - p_i) term to the quadric + // Where n_i = gradient at crossing point, p_i = crossing position + + var ata = mat3_zero(); + var atb = vec3(0.0); + var centroid = vec3(0.0); + var point_count = 0u; + + // Gather crossing points from all 6 tetrahedra edges in this cell + // We re-derive crossings from the density field for correctness + // (avoids storing intersection points in the MT pass) + + // Cube corner densities + var corner_d: array; + for (var i = 0u; i < 8u; i++) { + let dx = (i >> 0u) & 1u; + let dy = (i >> 1u) & 1u; + let dz = (i >> 2u) & 1u; + let ci = (cx + dx) + (cy + dy) * d + (cz + dz) * d * d; + corner_d[i] = density_field[ci]; + } + + let cell_origin = vec3(f32(cx), f32(cy), f32(cz)) * qef_params.cell_size; + + // Re-derive crossings from all 6 tets + for (var t = 0u; t < 6u; t++) { + let tet = TET_DECOMP[t]; + var tet_d: array; + var tet_p: array, 4>; + + for (var v = 0u; v < 4u; v++) { + tet_d[v] = corner_d[tet[v]]; + } + + var mask = 0u; + for (var v = 0u; v < 4u; v++) { + if (tet_d[v] >= 0.1) { mask |= (1u << v); } + } + if (mask == 0u || mask == 15u) { continue; } + + for (var e = 0u; e < 6u; e++) { + let e0 = TET_EDGES[e].x; + let e1 = TET_EDGES[e].y; + let above0 = (mask >> e0) & 1u; + let above1 = (mask >> e1) & 1u; + if (above0 == above1) { continue; } + + // Interpolate crossing position + let d0 = tet_d[e0]; + let d1 = tet_d[e1]; + let t_val = (0.1 - d0) / (d1 - d0); + + // Compute world-space position + let cv0 = tet[e0]; + let cv1 = tet[e1]; + let p0 = cell_origin + vec3(f32((cv0>>0u)&1u), f32((cv0>>1u)&1u), f32((cv0>>2u)&1u)) * qef_params.cell_size; + let p1 = cell_origin + vec3(f32((cv1>>0u)&1u), f32((cv1>>1u)&1u), f32((cv1>>2u)&1u)) * qef_params.cell_size; + let p = mix(p0, p1, clamp(t_val, 0.0, 1.0)); + + // Gradient at crossing point + let n = density_gradient(cx, cy, cz); + + // Accumulate: A^T A += n ⊗ n, A^T b += (n·p) * n + let accum = mat3_add_outer(n, 1.0); + ata.m[0] += accum.m[0]; ata.m[1] += accum.m[1]; ata.m[2] += accum.m[2]; + ata.m[3] += accum.m[3]; ata.m[4] += accum.m[4]; ata.m[5] += accum.m[5]; + ata.m[6] += accum.m[6]; ata.m[7] += accum.m[7]; ata.m[8] += accum.m[8]; + + let dot_np = dot(n, p); + atb += n * dot_np; + centroid += p; + point_count++; + } + } + + // Fallback: if no crossings (shouldn't happen), use cell center + if (point_count == 0u) { + let v_out = cell_origin + vec3(0.5) * qef_params.cell_size; + let slot = atomicAdd(&vertex_count, 1u); + vertices[slot] = v_out; + return; + } + + centroid /= f32(point_count); + + // Solve QEF + let v_qef = solve_3x3(ata, atb, qef_params.regularization); + + // Clamp vertex to cell bounds + let v_clamped = clamp(v_qef, cell_origin, cell_origin + vec3(qef_params.cell_size)); + + let slot = atomicAdd(&vertex_count, 1u); + if (slot < arrayLength(&vertices)) { + vertices[slot] = v_clamped; + } +}