Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Meshlet fill cluster buffers rewritten #15955

Merged
merged 38 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
063a8b1
WIP: Proper error projection
JMS55 Oct 9, 2024
d8fd980
WIP: LOD selection (not working), some cleanup
JMS55 Oct 10, 2024
718e663
Fix: Forgot to set parent_group_error for children after building par…
JMS55 Oct 10, 2024
1ddf173
Misc
JMS55 Oct 10, 2024
8b71b24
Use zeux's error projection code
JMS55 Oct 11, 2024
4f96881
Misc
JMS55 Oct 11, 2024
cac11bc
Pack group errors 2x16float
JMS55 Oct 11, 2024
b5d8a0c
Clippy
JMS55 Oct 11, 2024
634eaa2
Try to fix ortho, use 0.5px threshold
JMS55 Oct 12, 2024
b093698
Fixes for error projection
JMS55 Oct 12, 2024
31dd6f6
Go back to previous codegen, slightly faster
JMS55 Oct 12, 2024
3a48456
Update bunny URL
JMS55 Oct 12, 2024
27c03f8
Add missing doc comment
JMS55 Oct 12, 2024
d9f5b75
Remove perspective comment, I don't actually like how it behaves
JMS55 Oct 12, 2024
2fa9731
Larger meshlet group size, constant seed for metis, smallvec optimiza…
JMS55 Oct 12, 2024
a88a5a7
Factor out constant
JMS55 Oct 12, 2024
8db943e
Misc refactor
JMS55 Oct 12, 2024
0abbb47
Small rename
JMS55 Oct 12, 2024
a1c2ee3
Use position-only vertices instead of triangle edges to determine mes…
JMS55 Oct 12, 2024
62d3605
Misc
JMS55 Oct 12, 2024
99722b0
Use position-only vertex remap count to save a bit of memory when all…
JMS55 Oct 12, 2024
86f0b70
Fix segfault when generating vertex remap
JMS55 Oct 12, 2024
24187a4
Lock group borders manually for better simplification
JMS55 Oct 13, 2024
bad03cf
Small doc fix
JMS55 Oct 13, 2024
6eda2aa
Comment missing a word
JMS55 Oct 13, 2024
aff12ab
Retry stuck meshlets when building the DAG
JMS55 Oct 13, 2024
d65879c
Update bunny asset url
JMS55 Oct 13, 2024
1ca60e7
CI lints
JMS55 Oct 13, 2024
38644da
WIP: Faster fill_cluster_buffers, but glitchy
JMS55 Oct 14, 2024
acc0baa
Require PreviousGlobalTransform to avoid archetype moves
JMS55 Oct 14, 2024
e41adad
Add required Default derive to PreviousGlobalTransform
JMS55 Oct 16, 2024
895d2ba
Check cluster_buffer_slots constraint
JMS55 Oct 16, 2024
efa2b67
Fix compile issue
JMS55 Oct 16, 2024
0c0165d
Remap 1d->2d dispatch only if needed
JMS55 Oct 19, 2024
1c71836
Fix borrow
JMS55 Oct 19, 2024
8aed635
Remove early submit that is not worth the time spent
JMS55 Oct 23, 2024
cd84144
Merge commit '6d42830c7f2a5583df7a1bffadbfb1cef6eef3c3' into meshlet-…
JMS55 Oct 23, 2024
3305a20
Fix comment merge
JMS55 Oct 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/bevy_pbr/src/meshlet/cull_clusters.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
meshlet_software_raster_indirect_args,
meshlet_hardware_raster_indirect_args,
meshlet_raster_clusters,
meshlet_raster_cluster_rightmost_slot,
constants,
MeshletBoundingSphere,
}
#import bevy_render::maths::affine3_to_square
Expand All @@ -32,7 +32,7 @@ fn cull_clusters(
) {
// Calculate the cluster ID for this thread
let cluster_id = local_invocation_index + 128u * dot(workgroup_id, vec3(num_workgroups.x * num_workgroups.x, num_workgroups.x, 1u));
if cluster_id >= arrayLength(&meshlet_cluster_meshlet_ids) { return; }
if cluster_id >= constants.scene_cluster_count { return; }

#ifdef MESHLET_SECOND_CULLING_PASS
if !cluster_is_second_pass_candidate(cluster_id) { return; }
Expand Down Expand Up @@ -138,7 +138,7 @@ fn cull_clusters(
} else {
// Append this cluster to the list for hardware rasterization
buffer_slot = atomicAdd(&meshlet_hardware_raster_indirect_args.instance_count, 1u);
buffer_slot = meshlet_raster_cluster_rightmost_slot - buffer_slot;
buffer_slot = constants.meshlet_raster_cluster_rightmost_slot - buffer_slot;
}
meshlet_raster_clusters[buffer_slot] = cluster_id;
}
Expand Down
56 changes: 31 additions & 25 deletions crates/bevy_pbr/src/meshlet/fill_cluster_buffers.wgsl
Original file line number Diff line number Diff line change
@@ -1,44 +1,50 @@
#import bevy_pbr::meshlet_bindings::{
cluster_count,
meshlet_instance_meshlet_counts_prefix_sum,
scene_instance_count,
meshlet_global_cluster_count,
meshlet_instance_meshlet_counts,
meshlet_instance_meshlet_slice_starts,
meshlet_cluster_instance_ids,
meshlet_cluster_meshlet_ids,
}

/// Writes out instance_id and meshlet_id to the global buffers for each cluster in the scene.

var<workgroup> cluster_slice_start_workgroup: u32;

@compute
@workgroup_size(128, 1, 1) // 128 threads per workgroup, 1 cluster per thread
@workgroup_size(1024, 1, 1) // 1024 threads per workgroup, 1 instance per workgroup
fn fill_cluster_buffers(
@builtin(workgroup_id) workgroup_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
@builtin(local_invocation_index) local_invocation_index: u32,
) {
// Calculate the cluster ID for this thread
let cluster_id = local_invocation_index + 128u * dot(workgroup_id, vec3(num_workgroups.x * num_workgroups.x, num_workgroups.x, 1u));
if cluster_id >= cluster_count { return; } // TODO: Could be an arrayLength?

// Binary search to find the instance this cluster belongs to
var left = 0u;
var right = arrayLength(&meshlet_instance_meshlet_counts_prefix_sum) - 1u;
while left <= right {
let mid = (left + right) / 2u;
if meshlet_instance_meshlet_counts_prefix_sum[mid] <= cluster_id {
left = mid + 1u;
} else {
right = mid - 1u;
}
// Calculate the instance ID for this workgroup
var instance_id = workgroup_id.x + (workgroup_id.y * num_workgroups.x);
if instance_id >= scene_instance_count { return; }

let instance_meshlet_count = meshlet_instance_meshlet_counts[instance_id];
let instance_meshlet_slice_start = meshlet_instance_meshlet_slice_starts[instance_id];

// Reserve cluster slots for the instance and broadcast to the workgroup
if local_invocation_index == 0u {
cluster_slice_start_workgroup = atomicAdd(&meshlet_global_cluster_count, instance_meshlet_count);
}
let instance_id = right;
let cluster_slice_start = workgroupUniformLoad(&cluster_slice_start_workgroup);

// Find the meshlet ID for this cluster within the instance's MeshletMesh
let meshlet_id_local = cluster_id - meshlet_instance_meshlet_counts_prefix_sum[instance_id];
// Loop enough times to write out all the meshlets for the instance given that each thread writes 1 meshlet in each iteration
for (var clusters_written = 0u; clusters_written < instance_meshlet_count; clusters_written += 1024u) {
// Calculate meshlet ID within this instance's MeshletMesh to process for this thread
let meshlet_id_local = clusters_written + local_invocation_index;
if meshlet_id_local >= instance_meshlet_count { return; }

// Find the overall meshlet ID in the global meshlet buffer
let meshlet_id = meshlet_id_local + meshlet_instance_meshlet_slice_starts[instance_id];
// Find the overall cluster ID in the global cluster buffer
let cluster_id = cluster_slice_start + meshlet_id_local;

// Write results to buffers
meshlet_cluster_instance_ids[cluster_id] = instance_id;
meshlet_cluster_meshlet_ids[cluster_id] = meshlet_id;
// Find the overall meshlet ID in the global meshlet buffer
let meshlet_id = instance_meshlet_slice_start + meshlet_id_local;

// Write results to buffers
meshlet_cluster_instance_ids[cluster_id] = instance_id;
meshlet_cluster_meshlet_ids[cluster_id] = meshlet_id;
}
}
48 changes: 27 additions & 21 deletions crates/bevy_pbr/src/meshlet/instance_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,46 @@ use bevy_ecs::{
query::Has,
system::{Local, Query, Res, ResMut, Resource, SystemState},
};
use bevy_render::sync_world::MainEntity;
use bevy_render::{render_resource::StorageBuffer, view::RenderLayers, MainWorld};
use bevy_render::{
render_resource::StorageBuffer, sync_world::MainEntity, view::RenderLayers, MainWorld,
};
use bevy_transform::components::GlobalTransform;
use bevy_utils::{HashMap, HashSet};
use core::ops::{DerefMut, Range};

/// Manages data for each entity with a [`MeshletMesh`].
#[derive(Resource)]
pub struct InstanceManager {
/// Amount of clusters in the scene (sum of all meshlet counts across all instances)
/// Amount of instances in the scene.
pub scene_instance_count: u32,
/// Amount of clusters in the scene.
pub scene_cluster_count: u32,

/// Per-instance [`MainEntity`], [`RenderLayers`], and [`NotShadowCaster`]
/// Per-instance [`MainEntity`], [`RenderLayers`], and [`NotShadowCaster`].
pub instances: Vec<(MainEntity, RenderLayers, bool)>,
/// Per-instance [`MeshUniform`]
/// Per-instance [`MeshUniform`].
pub instance_uniforms: StorageBuffer<Vec<MeshUniform>>,
/// Per-instance material ID
/// Per-instance material ID.
pub instance_material_ids: StorageBuffer<Vec<u32>>,
/// Prefix-sum of meshlet counts per instance
pub instance_meshlet_counts_prefix_sum: StorageBuffer<Vec<u32>>,
/// Per-instance index to the start of the instance's slice of the meshlets buffer
/// Per-instance count of meshlets in the instance's [`MeshletMesh`].
pub instance_meshlet_counts: StorageBuffer<Vec<u32>>,
/// Per-instance index to the start of the instance's slice of the meshlets buffer.
pub instance_meshlet_slice_starts: StorageBuffer<Vec<u32>>,
/// Per-view per-instance visibility bit. Used for [`RenderLayers`] and [`NotShadowCaster`] support.
pub view_instance_visibility: EntityHashMap<StorageBuffer<Vec<u32>>>,

/// Next material ID available for a [`Material`]
/// Next material ID available for a [`Material`].
next_material_id: u32,
/// Map of [`Material`] to material ID
/// Map of [`Material`] to material ID.
material_id_lookup: HashMap<UntypedAssetId, u32>,
/// Set of material IDs used in the scene
/// Set of material IDs used in the scene.
material_ids_present_in_scene: HashSet<u32>,
}

impl InstanceManager {
pub fn new() -> Self {
Self {
scene_instance_count: 0,
scene_cluster_count: 0,

instances: Vec::new(),
Expand All @@ -59,9 +63,9 @@ impl InstanceManager {
buffer.set_label(Some("meshlet_instance_material_ids"));
buffer
},
instance_meshlet_counts_prefix_sum: {
instance_meshlet_counts: {
let mut buffer = StorageBuffer::default();
buffer.set_label(Some("meshlet_instance_meshlet_counts_prefix_sum"));
buffer.set_label(Some("meshlet_instance_meshlet_counts"));
buffer
},
instance_meshlet_slice_starts: {
Expand All @@ -80,7 +84,7 @@ impl InstanceManager {
#[allow(clippy::too_many_arguments)]
pub fn add_instance(
&mut self,
instance: Entity,
instance: MainEntity,
meshlets_slice: Range<u32>,
transform: &GlobalTransform,
previous_transform: Option<&PreviousGlobalTransform>,
Expand Down Expand Up @@ -108,20 +112,21 @@ impl InstanceManager {

// Append instance data
self.instances.push((
instance.into(),
instance,
render_layers.cloned().unwrap_or(RenderLayers::default()),
not_shadow_caster,
));
self.instance_uniforms.get_mut().push(mesh_uniform);
self.instance_material_ids.get_mut().push(0);
self.instance_meshlet_counts_prefix_sum
self.instance_meshlet_counts
.get_mut()
.push(self.scene_cluster_count);
.push(meshlets_slice.len() as u32);
self.instance_meshlet_slice_starts
.get_mut()
.push(meshlets_slice.start);

self.scene_cluster_count += meshlets_slice.end - meshlets_slice.start;
self.scene_instance_count += 1;
self.scene_cluster_count += meshlets_slice.len() as u32;
}

/// Get the material ID for a [`crate::Material`].
Expand All @@ -140,12 +145,13 @@ impl InstanceManager {
}

pub fn reset(&mut self, entities: &Entities) {
self.scene_instance_count = 0;
self.scene_cluster_count = 0;

self.instances.clear();
self.instance_uniforms.get_mut().clear();
self.instance_material_ids.get_mut().clear();
self.instance_meshlet_counts_prefix_sum.get_mut().clear();
self.instance_meshlet_counts.get_mut().clear();
self.instance_meshlet_slice_starts.get_mut().clear();
self.view_instance_visibility
.retain(|view_entity, _| entities.contains(*view_entity));
Expand Down Expand Up @@ -227,7 +233,7 @@ pub fn extract_meshlet_mesh_entities(

// Add the instance's data to the instance manager
instance_manager.add_instance(
instance,
instance.into(),
meshlets_slice,
transform,
previous_transform,
Expand Down
16 changes: 9 additions & 7 deletions crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,27 @@ struct DrawIndirectArgs {
const CENTIMETERS_PER_METER = 100.0;

#ifdef MESHLET_FILL_CLUSTER_BUFFERS_PASS
var<push_constant> cluster_count: u32;
@group(0) @binding(0) var<storage, read> meshlet_instance_meshlet_counts_prefix_sum: array<u32>; // Per entity instance
var<push_constant> scene_instance_count: u32;
@group(0) @binding(0) var<storage, read> meshlet_instance_meshlet_counts: array<u32>; // Per entity instance
@group(0) @binding(1) var<storage, read> meshlet_instance_meshlet_slice_starts: array<u32>; // Per entity instance
@group(0) @binding(2) var<storage, read_write> meshlet_cluster_instance_ids: array<u32>; // Per cluster
@group(0) @binding(3) var<storage, read_write> meshlet_cluster_meshlet_ids: array<u32>; // Per cluster
@group(0) @binding(4) var<storage, read_write> meshlet_global_cluster_count: atomic<u32>; // Single object shared between all workgroups
#endif

#ifdef MESHLET_CULLING_PASS
var<push_constant> meshlet_raster_cluster_rightmost_slot: u32;
struct Constants { scene_cluster_count: u32, meshlet_raster_cluster_rightmost_slot: u32 }
var<push_constant> constants: Constants;
@group(0) @binding(0) var<storage, read> meshlet_cluster_meshlet_ids: array<u32>; // Per cluster
@group(0) @binding(1) var<storage, read> meshlet_bounding_spheres: array<MeshletBoundingSpheres>; // Per meshlet
@group(0) @binding(2) var<storage, read> meshlet_simplification_errors: array<u32>; // Per meshlet
@group(0) @binding(3) var<storage, read> meshlet_cluster_instance_ids: array<u32>; // Per cluster
@group(0) @binding(4) var<storage, read> meshlet_instance_uniforms: array<Mesh>; // Per entity instance
@group(0) @binding(5) var<storage, read> meshlet_view_instance_visibility: array<u32>; // 1 bit per entity instance, packed as a bitmask
@group(0) @binding(6) var<storage, read_write> meshlet_second_pass_candidates: array<atomic<u32>>; // 1 bit per cluster , packed as a bitmask
@group(0) @binding(7) var<storage, read_write> meshlet_software_raster_indirect_args: DispatchIndirectArgs; // Single object shared between all workgroups/clusters/triangles
@group(0) @binding(8) var<storage, read_write> meshlet_hardware_raster_indirect_args: DrawIndirectArgs; // Single object shared between all workgroups/clusters/triangles
@group(0) @binding(9) var<storage, read_write> meshlet_raster_clusters: array<u32>; // Single object shared between all workgroups/clusters/triangles
@group(0) @binding(7) var<storage, read_write> meshlet_software_raster_indirect_args: DispatchIndirectArgs; // Single object shared between all workgroups
@group(0) @binding(8) var<storage, read_write> meshlet_hardware_raster_indirect_args: DrawIndirectArgs; // Single object shared between all workgroups
@group(0) @binding(9) var<storage, read_write> meshlet_raster_clusters: array<u32>; // Single object shared between all workgroups
@group(0) @binding(10) var depth_pyramid: texture_2d<f32>; // From the end of the last frame for the first culling pass, and from the first raster pass for the second culling pass
@group(0) @binding(11) var<uniform> view: View;
@group(0) @binding(12) var<uniform> previous_view: PreviousViewUniforms;
Expand All @@ -95,7 +97,7 @@ fn cluster_is_second_pass_candidate(cluster_id: u32) -> bool {
@group(0) @binding(3) var<storage, read> meshlet_vertex_positions: array<u32>; // Many per meshlet
@group(0) @binding(4) var<storage, read> meshlet_cluster_instance_ids: array<u32>; // Per cluster
@group(0) @binding(5) var<storage, read> meshlet_instance_uniforms: array<Mesh>; // Per entity instance
@group(0) @binding(6) var<storage, read> meshlet_raster_clusters: array<u32>; // Single object shared between all workgroups/clusters/triangles
@group(0) @binding(6) var<storage, read> meshlet_raster_clusters: array<u32>; // Single object shared between all workgroups
@group(0) @binding(7) var<storage, read> meshlet_software_raster_cluster_count: u32;
#ifdef MESHLET_VISIBILITY_BUFFER_RASTER_PASS_OUTPUT
@group(0) @binding(8) var<storage, read_write> meshlet_visibility_buffer: array<atomic<u64>>; // Per pixel
Expand Down
11 changes: 9 additions & 2 deletions crates/bevy_pbr/src/meshlet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ use self::{
},
visibility_buffer_raster_node::MeshletVisibilityBufferRasterPassNode,
};
use crate::{graph::NodePbr, Material, MeshMaterial3d};
use crate::{graph::NodePbr, Material, MeshMaterial3d, PreviousGlobalTransform};
use bevy_app::{App, Plugin, PostUpdate};
use bevy_asset::{load_internal_asset, AssetApp, AssetId, Handle};
use bevy_core_pipeline::{
Expand Down Expand Up @@ -129,6 +129,8 @@ pub struct MeshletPlugin {
/// If this number is too low, you'll see rendering artifacts like missing or blinking meshes.
///
/// Each cluster slot costs 4 bytes of VRAM.
///
/// Must not be greater than 2^25.
pub cluster_buffer_slots: u32,
}

Expand All @@ -147,6 +149,11 @@ impl Plugin for MeshletPlugin {
#[cfg(target_endian = "big")]
compile_error!("MeshletPlugin is only supported on little-endian processors.");

if self.cluster_buffer_slots > 2_u32.pow(25) {
error!("MeshletPlugin::cluster_buffer_slots must not be greater than 2^25.");
std::process::exit(1);
}

load_internal_asset!(
app,
MESHLET_BINDINGS_SHADER_HANDLE,
Expand Down Expand Up @@ -293,7 +300,7 @@ impl Plugin for MeshletPlugin {
/// The meshlet mesh equivalent of [`bevy_render::mesh::Mesh3d`].
#[derive(Component, Clone, Debug, Default, Deref, DerefMut, Reflect, PartialEq, Eq, From)]
#[reflect(Component, Default)]
#[require(Transform, Visibility)]
#[require(Transform, PreviousGlobalTransform, Visibility)]
pub struct MeshletMesh3d(pub Handle<MeshletMesh>);

impl From<MeshletMesh3d> for AssetId<MeshletMesh> {
Expand Down
9 changes: 6 additions & 3 deletions crates/bevy_pbr/src/meshlet/pipelines.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl FromWorld for MeshletPipelines {
layout: vec![cull_layout.clone()],
push_constant_ranges: vec![PushConstantRange {
stages: ShaderStages::COMPUTE,
range: 0..4,
range: 0..8,
}],
shader: MESHLET_CULLING_SHADER_HANDLE,
shader_defs: vec![
Expand All @@ -99,7 +99,7 @@ impl FromWorld for MeshletPipelines {
layout: vec![cull_layout],
push_constant_ranges: vec![PushConstantRange {
stages: ShaderStages::COMPUTE,
range: 0..4,
range: 0..8,
}],
shader: MESHLET_CULLING_SHADER_HANDLE,
shader_defs: vec![
Expand Down Expand Up @@ -441,7 +441,10 @@ impl FromWorld for MeshletPipelines {
pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("meshlet_remap_1d_to_2d_dispatch_pipeline".into()),
layout: vec![layout],
push_constant_ranges: vec![],
push_constant_ranges: vec![PushConstantRange {
stages: ShaderStages::COMPUTE,
range: 0..4,
}],
shader: MESHLET_REMAP_1D_TO_2D_DISPATCH_SHADER_HANDLE,
shader_defs: vec![],
entry_point: "remap_dispatch".into(),
Expand Down
9 changes: 6 additions & 3 deletions crates/bevy_pbr/src/meshlet/remap_1d_to_2d_dispatch.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ struct DispatchIndirectArgs {

@group(0) @binding(0) var<storage, read_write> meshlet_software_raster_indirect_args: DispatchIndirectArgs;
@group(0) @binding(1) var<storage, read_write> meshlet_software_raster_cluster_count: u32;
var<push_constant> max_compute_workgroups_per_dimension: u32;

@compute
@workgroup_size(1, 1, 1)
fn remap_dispatch() {
meshlet_software_raster_cluster_count = meshlet_software_raster_indirect_args.x;

let n = u32(ceil(sqrt(f32(meshlet_software_raster_indirect_args.x))));
meshlet_software_raster_indirect_args.x = n;
meshlet_software_raster_indirect_args.y = n;
if meshlet_software_raster_cluster_count > max_compute_workgroups_per_dimension {
let n = u32(ceil(sqrt(f32(meshlet_software_raster_cluster_count))));
meshlet_software_raster_indirect_args.x = n;
meshlet_software_raster_indirect_args.y = n;
}
}
Loading