Skip to content

[naga hlsl-out, glsl-out] Support atomicCompareExchangeWeak #7658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Bottom level categories:
#### Naga

- When emitting GLSL, Uniform and Storage Buffer memory layouts are now emitted even if no explicit binding is given. By @cloone8 in [#7579](https://github.com/gfx-rs/wgpu/pull/7579).
- Add support for `atomicCompareExchangeWeak` in HLSL and GLSL backends. By @cryvosh in [#7658](https://github.com/gfx-rs/wgpu/pull/7658)

### Bug Fixes

Expand Down
105 changes: 77 additions & 28 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,9 @@ pub struct Writer<'a, W> {
multiview: Option<core::num::NonZeroU32>,
/// Mapping of varying variables to their location. Needed for reflections.
varying: crate::FastHashMap<String, VaryingLocation>,

/// Set of special type names whose definitions have already been written. To prevent duplicates.
written_special_struct_names: crate::FastHashSet<String>,
}

impl<'a, W: Write> Writer<'a, W> {
Expand Down Expand Up @@ -688,6 +691,7 @@ impl<'a, W: Write> Writer<'a, W> {
need_bake_expressions: Default::default(),
continue_ctx: back::continue_forward::ContinueCtx::default(),
varying: Default::default(),
written_special_struct_names: Default::default(),
};

// Find all features required to print this module
Expand Down Expand Up @@ -787,23 +791,36 @@ impl<'a, W: Write> Writer<'a, W> {
// you can't make a struct without adding all of its members first.
for (handle, ty) in self.module.types.iter() {
if let TypeInner::Struct { ref members, .. } = ty.inner {
// Skip special atomic compare exchange result structs (generated in next loop)
let struct_name = &self.names[&NameKey::Type(handle)];
if struct_name.starts_with("_atomic_compare_exchange_result") {
continue;
}

// Structures ending with runtime-sized arrays can only be
// rendered as shader storage blocks in GLSL, not stand-alone
// struct types.
if !self.module.types[members.last().unwrap().ty]
.inner
.is_dynamically_sized(&self.module.types)
{
let name = &self.names[&NameKey::Type(handle)];
write!(self.out, "struct {name} ")?;
write!(self.out, "struct {struct_name} ")?;
self.write_struct_body(handle, members)?;
writeln!(self.out, ";")?;
}
}
}

// Write functions to create special types.
// Write functions and struct definitions for special types.
for (type_key, struct_ty) in self.module.special_types.predeclared_types.iter() {
let struct_name = &self.names[&NameKey::Type(*struct_ty)];
if !self
.written_special_struct_names
.insert(struct_name.clone())
{
continue;
}

match type_key {
&crate::PredeclaredType::ModfResult { size, scalar }
| &crate::PredeclaredType::FrexpResult { size, scalar } => {
Expand Down Expand Up @@ -835,8 +852,6 @@ impl<'a, W: Write> Writer<'a, W> {
(FREXP_FUNCTION, "frexp", other_type_name)
};

let struct_name = &self.names[&NameKey::Type(*struct_ty)];

writeln!(self.out)?;
if !self.options.version.supports_frexp_function()
&& matches!(type_key, &crate::PredeclaredType::FrexpResult { .. })
Expand All @@ -860,7 +875,14 @@ impl<'a, W: Write> Writer<'a, W> {
)?;
}
}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
let scalar_str = glsl_scalar(scalar)?.full;
writeln!(
self.out,
"struct {} {{\n {} old_value;\n bool exchanged;\n}};",
struct_name, scalar_str
)?;
}
}
}

Expand Down Expand Up @@ -1118,6 +1140,17 @@ impl<'a, W: Write> Writer<'a, W> {
/// # Notes
/// Adds no trailing or leading whitespace
fn write_type(&mut self, ty: Handle<crate::Type>) -> BackendResult {
for (key, &handle) in self.module.special_types.predeclared_types.iter() {
if handle == ty {
if let crate::PredeclaredType::AtomicCompareExchangeWeakResult(_) = *key {
let name = &self.names[&NameKey::Type(ty)];
write!(self.out, "{name}")?;
return Ok(());
}
break;
}
}

match self.module.types[ty].inner {
// glsl has no pointer types so just write types as normal and loads are skipped
TypeInner::Pointer { base, .. } => self.write_type(base),
Expand Down Expand Up @@ -2572,33 +2605,49 @@ impl<'a, W: Write> Writer<'a, W> {
result,
} => {
write!(self.out, "{level}")?;
if let Some(result) = result {
let res_name = Baked(result).to_string();
let res_ty = ctx.resolve_type(result, &self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);
}

let fun_str = fun.to_glsl();
write!(self.out, "atomic{fun_str}(")?;
self.write_expr(pointer, ctx)?;
write!(self.out, ", ")?;
// handle the special cases
match *fun {
crate::AtomicFunction::Subtract => {
// we just wrote `InterlockedAdd`, so negate the argument
write!(self.out, "-")?;
crate::AtomicFunction::Exchange {
compare: Some(compare_expr),
} => {
let result_handle = result.expect("CompareExchange must have a result");
let res_name = Baked(result_handle).to_string();
self.write_type(ctx.info[result_handle].ty.handle().unwrap())?;
write!(self.out, " {res_name};")?;
write!(self.out, " {res_name}.old_value = atomicCompSwap(")?;
self.write_expr(pointer, ctx)?;
write!(self.out, ", ")?;
self.write_expr(compare_expr, ctx)?;
write!(self.out, ", ")?;
self.write_expr(value, ctx)?;
writeln!(self.out, ");")?;

write!(
self.out,
"{level}{res_name}.exchanged = ({res_name}.old_value == "
)?;
self.write_expr(compare_expr, ctx)?;
writeln!(self.out, ");")?;
self.named_expressions.insert(result_handle, res_name);
}
crate::AtomicFunction::Exchange { compare: Some(_) } => {
return Err(Error::Custom(
"atomic CompareExchange is not implemented".to_string(),
));
_ => {
if let Some(result) = result {
let res_name = Baked(result).to_string();
self.write_type(ctx.info[result].ty.handle().unwrap())?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);
}
let fun_str = fun.to_glsl();
write!(self.out, "atomic{fun_str}(")?;
self.write_expr(pointer, ctx)?;
write!(self.out, ", ")?;
if let crate::AtomicFunction::Subtract = *fun {
write!(self.out, "-")?;
}
self.write_expr(value, ctx)?;
writeln!(self.out, ");")?;
}
_ => {}
}
self.write_expr(value, ctx)?;
writeln!(self.out, ");")?;
}
// Stores a value into an image.
Statement::ImageAtomic {
Expand Down
2 changes: 1 addition & 1 deletion naga/src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl crate::AtomicFunction {
Self::Min => "Min",
Self::Max => "Max",
Self::Exchange { compare: None } => "Exchange",
Self::Exchange { .. } => "", //TODO
Self::Exchange { .. } => "CompareExchange",
}
}
}
117 changes: 74 additions & 43 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2358,79 +2358,78 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
result,
} => {
write!(self.out, "{level}")?;
let res_name = match result {
None => None,
Some(result) => {
let name = Baked(result).to_string();
match func_ctx.info[result].ty {
proc::TypeResolution::Handle(handle) => {
self.write_type(module, handle)?
}
proc::TypeResolution::Value(ref value) => {
self.write_value_type(module, value)?
}
};
write!(self.out, " {name}; ")?;
Some((result, name))
}
let res_var_info = if let Some(res_handle) = result {
let name = Baked(res_handle).to_string();
match func_ctx.info[res_handle].ty {
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
proc::TypeResolution::Value(ref value) => {
self.write_value_type(module, value)?
}
};
write!(self.out, " {name}; ")?;
Some((res_handle, name))
} else {
None
};

// Validation ensures that `pointer` has a `Pointer` type.
let pointer_space = func_ctx
.resolve_type(pointer, &module.types)
.pointer_space()
.unwrap();

let fun_str = fun.to_hlsl_suffix();
let compare_expr = match *fun {
crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
_ => None,
};
match pointer_space {
crate::AddressSpace::WorkGroup => {
write!(self.out, "Interlocked{fun_str}(")?;
self.write_expr(module, pointer, func_ctx)?;
self.emit_hlsl_atomic_tail(
module,
func_ctx,
fun,
compare_expr,
value,
&res_var_info,
)?;
}
crate::AddressSpace::Storage { .. } => {
let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
// The call to `self.write_storage_address` wants
// mutable access to all of `self`, so temporarily take
// ownership of our reusable access chain buffer.
let chain = mem::take(&mut self.temp_access_chain);
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
let width = match func_ctx.resolve_type(value, &module.types) {
&TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
_ => "",
};
write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
let chain = mem::take(&mut self.temp_access_chain);
self.write_storage_address(module, &chain, func_ctx)?;
self.temp_access_chain = chain;
self.emit_hlsl_atomic_tail(
module,
func_ctx,
fun,
compare_expr,
value,
&res_var_info,
)?;
}
ref other => {
return Err(Error::Custom(format!(
"invalid address space {other:?} for atomic statement"
)))
}
}
write!(self.out, ", ")?;
// handle the special cases
match *fun {
crate::AtomicFunction::Subtract => {
// we just wrote `InterlockedAdd`, so negate the argument
write!(self.out, "-")?;
}
crate::AtomicFunction::Exchange { compare: Some(_) } => {
return Err(Error::Unimplemented("atomic CompareExchange".to_string()));
if let Some(cmp) = compare_expr {
if let Some(&(res_handle, ref res_name)) = res_var_info.as_ref() {
write!(
self.out,
"{level}{res_name}.exchanged = ({res_name}.old_value == "
)?;
self.write_expr(module, cmp, func_ctx)?;
writeln!(self.out, ");")?;
self.named_expressions.insert(res_handle, res_name.clone());
}
_ => {}
}
self.write_expr(module, value, func_ctx)?;

// The `original_value` out parameter is optional for all the
// `Interlocked` functions we generate other than
// `InterlockedExchange`.
if let Some((result, name)) = res_name {
write!(self.out, ", {name}")?;
self.named_expressions.insert(result, name);
}

writeln!(self.out, ");")?;
}
Statement::ImageAtomic {
image,
Expand Down Expand Up @@ -4287,6 +4286,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
Ok(())
}

/// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result)
fn emit_hlsl_atomic_tail(
&mut self,
module: &Module,
func_ctx: &back::FunctionCtx<'_>,
fun: &crate::AtomicFunction,
compare_expr: Option<Handle<crate::Expression>>,
value: Handle<crate::Expression>,
res_var_info: &Option<(Handle<crate::Expression>, String)>,
) -> BackendResult {
if let Some(cmp) = compare_expr {
write!(self.out, ", ")?;
self.write_expr(module, cmp, func_ctx)?;
}
write!(self.out, ", ")?;
if let crate::AtomicFunction::Subtract = *fun {
write!(self.out, "-")?;
}
self.write_expr(module, value, func_ctx)?;
if let Some(&(res_handle, ref res_name)) = res_var_info.as_ref() {
write!(self.out, ", ")?;
if compare_expr.is_some() {
write!(self.out, "{res_name}.old_value")?;
} else {
write!(self.out, "{res_name}")?;
}
self.named_expressions.insert(res_handle, res_name.clone());
}
writeln!(self.out, ");")?;
Ok(())
}
}

pub(super) struct MatrixType {
Expand Down
3 changes: 2 additions & 1 deletion naga/tests/in/wgsl/atomicCompareExchange-int64.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
god_mode = true
targets = "SPIRV | WGSL"
targets = "SPIRV | HLSL | WGSL"

[hlsl]
shader_model = "V6_6"
fake_missing_bindings = true
push_constants_target = { register = 0, space = 0 }
restrict_indexing = true
Expand Down
2 changes: 1 addition & 1 deletion naga/tests/in/wgsl/atomicCompareExchange.toml
Original file line number Diff line number Diff line change
@@ -1 +1 @@
targets = "SPIRV | METAL | WGSL"
targets = "SPIRV | METAL | GLSL | HLSL | WGSL"
17 changes: 8 additions & 9 deletions naga/tests/in/wgsl/atomicOps-int64.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,12 @@ fn cs_main(@builtin(local_invocation_id) id: vec3<u32>) {
atomicExchange(&workgroup_struct.atomic_scalar, 1lu);
atomicExchange(&workgroup_struct.atomic_arr[1], 1li);

// // TODO: https://github.com/gpuweb/gpuweb/issues/2021
// atomicCompareExchangeWeak(&storage_atomic_scalar, 1lu);
// atomicCompareExchangeWeak(&storage_atomic_arr[1], 1li);
// atomicCompareExchangeWeak(&storage_struct.atomic_scalar, 1lu);
// atomicCompareExchangeWeak(&storage_struct.atomic_arr[1], 1li);
// atomicCompareExchangeWeak(&workgroup_atomic_scalar, 1lu);
// atomicCompareExchangeWeak(&workgroup_atomic_arr[1], 1li);
// atomicCompareExchangeWeak(&workgroup_struct.atomic_scalar, 1lu);
// atomicCompareExchangeWeak(&workgroup_struct.atomic_arr[1], 1li);
let cas_res_0 = atomicCompareExchangeWeak(&storage_atomic_scalar, 1lu, 2lu);
let cas_res_1 = atomicCompareExchangeWeak(&storage_atomic_arr[1], 1li, 2li);
let cas_res_2 = atomicCompareExchangeWeak(&storage_struct.atomic_scalar, 1lu, 2lu);
let cas_res_3 = atomicCompareExchangeWeak(&storage_struct.atomic_arr[1], 1li, 2li);
let cas_res_4 = atomicCompareExchangeWeak(&workgroup_atomic_scalar, 1lu, 2lu);
let cas_res_5 = atomicCompareExchangeWeak(&workgroup_atomic_arr[1], 1li, 2li);
let cas_res_6 = atomicCompareExchangeWeak(&workgroup_struct.atomic_scalar, 1lu, 2lu);
let cas_res_7 = atomicCompareExchangeWeak(&workgroup_struct.atomic_arr[1], 1li, 2li);
}
Loading