Skip to content

Fix typing for select #7572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Naga now infers the correct binding layout when a resource appears only in an as
- Apply necessary automatic conversions to the `value` argument of `textureStore`. By @jimblandy in [#7567](https://github.com/gfx-rs/wgpu/pull/7567).
- Properly apply WGSL's automatic conversions to the arguments to texture sampling functions. By @jimblandy in [#7548](https://github.com/gfx-rs/wgpu/pull/7548).
- Properly evaluate `abs(most negative abstract int)`. By @jimblandy in [#7507](https://github.com/gfx-rs/wgpu/pull/7507).
- Fix typing for `select`, which had issues particularly with a lack of automatic type conversion. By @ErichDonGubler in [#7572](https://github.com/gfx-rs/wgpu/pull/7572).

#### DX12

Expand Down
28 changes: 28 additions & 0 deletions naga/src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,16 @@ pub(crate) enum Error<'a> {
on_what: DiagnosticAttributeNotSupportedPosition,
spans: Vec<Span>,
},
SelectUnexpectedArgumentType {
arg_span: Span,
arg_type: String,
},
SelectRejectAndAcceptHaveNoCommonType {
reject_span: Span,
reject_type: String,
accept_span: Span,
accept_type: String,
},
}

impl From<ConflictingDiagnosticRuleError> for Error<'_> {
Expand Down Expand Up @@ -1340,6 +1350,24 @@ impl<'a> Error<'a> {
],
}
}
Error::SelectUnexpectedArgumentType { arg_span, ref arg_type } => ParseError {
message: "unexpected argument type for `select` call".into(),
labels: vec![(arg_span, format!("this value of type {arg_type}").into())],
notes: vec!["expected a scalar or a `vecN` of scalars".into()],
},
Error::SelectRejectAndAcceptHaveNoCommonType {
reject_span,
ref reject_type,
accept_span,
ref accept_type,
} => ParseError {
message: "type mismatch for reject and accept values in `select` call".into(),
labels: vec![
(reject_span, format!("reject value of type {reject_type}").into()),
(accept_span, format!("accept value of type {accept_type}").into()),
],
notes: vec![],
},
}
}
}
2 changes: 1 addition & 1 deletion naga/src/front/wgsl/lower/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ impl crate::Scalar {
self.automatic_conversion_combine(goal) == Some(goal)
}

const fn concretize(self) -> Self {
pub(in crate::front::wgsl) const fn concretize(self) -> Self {
use crate::ScalarKind as Sk;
match self.kind {
Sk::Sint | Sk::Uint | Sk::Float | Sk::Bool => self,
Expand Down
63 changes: 61 additions & 2 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use alloc::{
borrow::ToOwned,
boxed::Box,
format,
string::{String, ToString},
vec::Vec,
};
Expand Down Expand Up @@ -2483,12 +2484,70 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
"select" => {
let mut args = ctx.prepare_args(arguments, 3, span);

let reject = self.expression(args.next()?, ctx)?;
let accept = self.expression(args.next()?, ctx)?;
let reject_orig = args.next()?;
let accept_orig = args.next()?;
let mut values = [
self.expression_for_abstract(reject_orig, ctx)?,
self.expression_for_abstract(accept_orig, ctx)?,
];
let condition = self.expression(args.next()?, ctx)?;

args.finish()?;

fn expr_ty<'a>(
ctx: &'a ExpressionContext<'_, '_, '_>,
expr: Handle<ir::Expression>,
) -> &'a ir::TypeInner {
ctx.typifier()[expr].inner_with(&ctx.module.types)
}
let diag_deets =
|module: &ir::Module, ty: &ir::TypeInner, orig_expr| {
(
ctx.ast_expressions.get_span(orig_expr),
format!("`{:?}`", ty.for_debug(&module.types)),
)
};
for (&value, orig_value) in
values.iter().zip([reject_orig, accept_orig])
{
ctx.grow_types(value)?;

let value_type = expr_ty(ctx, value);
if value_type.vector_size_and_scalar().is_none() {
let (arg_span, arg_type) =
diag_deets(ctx.module, value_type, orig_value);
return Err(Box::new(Error::SelectUnexpectedArgumentType {
arg_span,
arg_type,
}));
}
}
let mut consensus_scalar = ctx
.automatic_conversion_consensus(&values)
.map_err(|_idx| {
let [reject, accept] = values;
let [(reject_span, reject_type), (accept_span, accept_type)] =
[(reject_orig, reject), (accept_orig, accept)].map(
|(orig_expr, expr)| {
let ty = expr_ty(ctx, expr);
diag_deets(ctx.module, ty, orig_expr)
},
);
Error::SelectRejectAndAcceptHaveNoCommonType {
reject_span,
reject_type,
accept_span,
accept_type,
}
})?;
if !ctx.is_const(condition) {
consensus_scalar = consensus_scalar.concretize();
}

ctx.convert_slice_to_common_leaf_scalar(&mut values, consensus_scalar)?;

let [reject, accept] = values;

ir::Expression::Select {
reject,
accept,
Expand Down
147 changes: 144 additions & 3 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,27 @@ pub enum ConstantEvaluatorError {
RuntimeExpr,
#[error("Unexpected override-expression")]
OverrideExpr,
#[error("Expected boolean expression for condition argument of `select`, got something else")]
SelectScalarConditionNotABool,
#[error(
"Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
reject,
accept
)]
SelectVecRejectAcceptSizeMismatch {
reject: crate::VectorSize,
accept: crate::VectorSize,
},
#[error("Expected boolean vector for condition arg., got something else")]
SelectConditionNotAVecBool,
#[error(
"Expected same number of vector components between condition, accept, and reject args., got something else",
)]
SelectConditionVecSizeMismatch,
#[error(
"Expected reject and accept args. to be scalars of vectors of the same type, got something else",
)]
SelectAcceptRejectTypeMismatch,
}

impl<'a> ConstantEvaluator<'a> {
Expand Down Expand Up @@ -904,9 +925,19 @@ impl<'a> ConstantEvaluator<'a> {
)),
}
}
Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
"select built-in function".into(),
)),
Expression::Select {
reject,
accept,
condition,
} => {
let mut arg = |expr| self.check_and_get(expr);

let reject = arg(reject)?;
let accept = arg(accept)?;
let condition = arg(condition)?;

self.select(reject, accept, condition, span)
}
Expression::Relational { fun, argument } => {
let argument = self.check_and_get(argument)?;
self.relational(fun, argument, span)
Expand Down Expand Up @@ -2497,6 +2528,116 @@ impl<'a> ConstantEvaluator<'a> {

Ok(resolution)
}

fn select(
&mut self,
reject: Handle<Expression>,
accept: Handle<Expression>,
condition: Handle<Expression>,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);

let reject = arg(reject)?;
let accept = arg(accept)?;
let condition = arg(condition)?;

let select_single_component =
|this: &mut Self, reject_scalar, reject, accept, condition| {
let accept = this.cast(accept, reject_scalar, span)?;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Migrated from #7602 (comment): @jimblandy was wondering if we really needed cast here. So far as I can tell...yes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, another issue: We apparently coerce 0 and non-zero numeric values to bool with cast. That makes select accept strictly more things than it should in constant evaluation; is there a reason we should allow such a cast?

if condition {
Ok(accept)
} else {
Ok(reject)
}
};

match (&self.expressions[reject], &self.expressions[accept]) {
(&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
let reject_scalar = reject_lit.scalar();
let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
else {
return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
};
select_single_component(self, reject_scalar, reject, accept, condition)
}
(
&Expression::Compose {
ty: reject_ty,
components: ref reject_components,
},
&Expression::Compose {
ty: accept_ty,
components: ref accept_components,
},
) => {
let ty_deets = |ty| {
let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
(size.unwrap(), scalar)
};

let expected_vec_size = {
let [(reject_vec_size, _), (accept_vec_size, _)] =
[reject_ty, accept_ty].map(ty_deets);

if reject_vec_size != accept_vec_size {
return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
reject: reject_vec_size,
accept: accept_vec_size,
});
}
reject_vec_size
};

let condition_components = match self.expressions[condition] {
Expression::Literal(Literal::Bool(condition)) => {
vec![condition; (expected_vec_size as u8).into()]
}
Expression::Compose {
ty: condition_ty,
components: ref condition_components,
} => {
let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
if condition_scalar.kind != ScalarKind::Bool {
return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
}
if condition_vec_size != expected_vec_size {
return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
}
condition_components
.iter()
.copied()
.map(|component| match &self.expressions[component] {
&Expression::Literal(Literal::Bool(condition)) => condition,
_ => unreachable!(),
})
.collect()
}

_ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
};

let evaluated = Expression::Compose {
ty: reject_ty,
components: reject_components
.clone()
.into_iter()
.zip(accept_components.clone().into_iter())
.zip(condition_components.into_iter())
.map(|((reject, accept), condition)| {
let reject_scalar = match &self.expressions[reject] {
&Expression::Literal(lit) => lit.scalar(),
_ => unreachable!(),
};
select_single_component(self, reject_scalar, reject, accept, condition)
})
.collect::<Result<_, _>>()?,
};
self.register_evaluated_expr(evaluated, span)
}
_ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
}
}
}

fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
Expand Down
32 changes: 32 additions & 0 deletions naga/tests/in/wgsl/select.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
const_assert select(0xdeadbeef, 42f, false) == 0xdeadbeef;
const_assert select(0xdeadbeefu, 42, false) == 0xdeadbeefu;
const_assert select(0xdeadi, 42, false) == 0xdeadi;

const_assert select(42f, 9001, true) == 9001;
const_assert select(42f, 9001, true) == 9001f;
const_assert select(42, 9001i, true) == 9001;
const_assert select(42, 9001u, true) == 9001;

const_assert !select(false, true, false);
const_assert select(false, true, true);
const_assert select(true, false, false);
const_assert !select(true, false, true);

const_assert all(select(vec2(2f), vec2(), true) == vec2(0));
const_assert all(select(vec2(1), vec2(2f), false) == vec2(1));
const_assert all(select(vec2(1), vec2(2f), false) == vec2(1));
const_assert all(select(vec2(1), vec2(2f), vec2(false, false)) == vec2(1));
const_assert all(select(vec2(1), vec2(2f), vec2(true)) == vec2(2));
const_assert all(select(vec2(1), vec2(2f), vec2(true)) == vec2(2));
const_assert all(select(vec2(1), vec2(2f), vec2(true, false)) == vec2(2, 1));

const_assert all(select(vec3(1), vec3(2f), vec3(true)) == vec3(2));
const_assert all(select(vec4(1), vec4(2f), vec4(true)) == vec4(2));

@compute @workgroup_size(1, 1)
fn main() {
_ = select(1, 2f, false);

var x0 = vec2(1, 2);
var i1: vec2<f32> = select(vec2<f32>(1., 0.), vec2<f32>(0., 1.), (x0.x < x0.y));
}
Loading