Skip to content

Commit dd40727

Browse files
fix(naga): properly impl. auto. type conv. for select
1 parent bb554c9 commit dd40727

18 files changed

+1013
-619
lines changed

naga/src/front/wgsl/error.rs

+28
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,16 @@ pub(crate) enum Error<'a> {
399399
on_what: DiagnosticAttributeNotSupportedPosition,
400400
spans: Vec<Span>,
401401
},
402+
SelectUnexpectedArgumentType {
403+
arg_span: Span,
404+
arg_type: String,
405+
},
406+
SelectRejectAndAcceptHaveNoCommonType {
407+
reject_span: Span,
408+
reject_type: String,
409+
accept_span: Span,
410+
accept_type: String,
411+
},
402412
}
403413

404414
impl From<ConflictingDiagnosticRuleError> for Error<'_> {
@@ -1340,6 +1350,24 @@ impl<'a> Error<'a> {
13401350
],
13411351
}
13421352
}
1353+
Error::SelectUnexpectedArgumentType { arg_span, ref arg_type } => ParseError {
1354+
message: "unexpected argument type for `select` call".into(),
1355+
labels: vec![(arg_span, format!("this value of type {arg_type}").into())],
1356+
notes: vec!["expected a scalar or a `vecN` of scalars".into()],
1357+
},
1358+
Error::SelectRejectAndAcceptHaveNoCommonType {
1359+
reject_span,
1360+
ref reject_type,
1361+
accept_span,
1362+
ref accept_type,
1363+
} => ParseError {
1364+
message: "type mismatch for reject and accept values in `select` call".into(),
1365+
labels: vec![
1366+
(reject_span, format!("reject value of type {reject_type}").into()),
1367+
(accept_span, format!("accept value of type {accept_type}").into()),
1368+
],
1369+
notes: vec![],
1370+
},
13431371
}
13441372
}
13451373
}

naga/src/front/wgsl/lower/mod.rs

+61-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use alloc::{
22
borrow::ToOwned,
33
boxed::Box,
4+
format,
45
string::{String, ToString},
56
vec::Vec,
67
};
@@ -2483,12 +2484,70 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
24832484
"select" => {
24842485
let mut args = ctx.prepare_args(arguments, 3, span);
24852486

2486-
let reject = self.expression(args.next()?, ctx)?;
2487-
let accept = self.expression(args.next()?, ctx)?;
2487+
let reject_orig = args.next()?;
2488+
let accept_orig = args.next()?;
2489+
let mut values = [
2490+
self.expression_for_abstract(reject_orig, ctx)?,
2491+
self.expression_for_abstract(accept_orig, ctx)?,
2492+
];
24882493
let condition = self.expression(args.next()?, ctx)?;
24892494

24902495
args.finish()?;
24912496

2497+
fn expr_ty<'a>(
2498+
ctx: &'a ExpressionContext<'_, '_, '_>,
2499+
expr: Handle<ir::Expression>,
2500+
) -> &'a ir::TypeInner {
2501+
ctx.typifier()[expr].inner_with(&ctx.module.types)
2502+
}
2503+
let diag_deets =
2504+
|module: &ir::Module, ty: &ir::TypeInner, orig_expr| {
2505+
(
2506+
ctx.ast_expressions.get_span(orig_expr),
2507+
format!("`{:?}`", ty.for_debug(&module.types)),
2508+
)
2509+
};
2510+
for (&value, orig_value) in
2511+
values.iter().zip([reject_orig, accept_orig])
2512+
{
2513+
ctx.grow_types(value)?;
2514+
2515+
let value_type = expr_ty(ctx, value);
2516+
if value_type.vector_size_and_scalar().is_none() {
2517+
let (arg_span, arg_type) =
2518+
diag_deets(ctx.module, value_type, orig_value);
2519+
return Err(Box::new(Error::SelectUnexpectedArgumentType {
2520+
arg_span,
2521+
arg_type,
2522+
}));
2523+
}
2524+
}
2525+
let mut consensus_scalar = ctx
2526+
.automatic_conversion_consensus(&values)
2527+
.map_err(|_idx| {
2528+
let [reject, accept] = values;
2529+
let [(reject_span, reject_type), (accept_span, accept_type)] =
2530+
[(reject_orig, reject), (accept_orig, accept)].map(
2531+
|(orig_expr, expr)| {
2532+
let ty = expr_ty(ctx, expr);
2533+
diag_deets(ctx.module, ty, orig_expr)
2534+
},
2535+
);
2536+
Error::SelectRejectAndAcceptHaveNoCommonType {
2537+
reject_span,
2538+
reject_type,
2539+
accept_span,
2540+
accept_type,
2541+
}
2542+
})?;
2543+
if !ctx.is_const(condition) {
2544+
consensus_scalar = consensus_scalar.concretize();
2545+
}
2546+
2547+
ctx.convert_slice_to_common_leaf_scalar(&mut values, consensus_scalar)?;
2548+
2549+
let [reject, accept] = values;
2550+
24922551
ir::Expression::Select {
24932552
reject,
24942553
accept,

naga/src/proc/constant_evaluator.rs

+144-3
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,27 @@ pub enum ConstantEvaluatorError {
563563
RuntimeExpr,
564564
#[error("Unexpected override-expression")]
565565
OverrideExpr,
566+
#[error("Expected boolean expression for condition argument of `select`, got something else")]
567+
SelectScalarConditionNotABool,
568+
#[error(
569+
"Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
570+
reject,
571+
accept
572+
)]
573+
SelectVecRejectAcceptSizeMismatch {
574+
reject: crate::VectorSize,
575+
accept: crate::VectorSize,
576+
},
577+
#[error("Expected boolean vector for condition arg., got something else")]
578+
SelectConditionNotAVecBool,
579+
#[error(
580+
"Expected same number of vector components between condition, accept, and reject args., got something else",
581+
)]
582+
SelectConditionVecSizeMismatch,
583+
#[error(
584+
"Expected reject and accept args. to be scalars of vectors of the same type, got something else",
585+
)]
586+
SelectAcceptRejectTypeMismatch,
566587
}
567588

568589
impl<'a> ConstantEvaluator<'a> {
@@ -904,9 +925,19 @@ impl<'a> ConstantEvaluator<'a> {
904925
)),
905926
}
906927
}
907-
Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
908-
"select built-in function".into(),
909-
)),
928+
Expression::Select {
929+
reject,
930+
accept,
931+
condition,
932+
} => {
933+
let mut arg = |expr| self.check_and_get(expr);
934+
935+
let reject = arg(reject)?;
936+
let accept = arg(accept)?;
937+
let condition = arg(condition)?;
938+
939+
self.select(reject, accept, condition, span)
940+
}
910941
Expression::Relational { fun, argument } => {
911942
let argument = self.check_and_get(argument)?;
912943
self.relational(fun, argument, span)
@@ -2497,6 +2528,116 @@ impl<'a> ConstantEvaluator<'a> {
24972528

24982529
Ok(resolution)
24992530
}
2531+
2532+
fn select(
2533+
&mut self,
2534+
reject: Handle<Expression>,
2535+
accept: Handle<Expression>,
2536+
condition: Handle<Expression>,
2537+
span: Span,
2538+
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2539+
let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
2540+
2541+
let reject = arg(reject)?;
2542+
let accept = arg(accept)?;
2543+
let condition = arg(condition)?;
2544+
2545+
let select_single_component =
2546+
|this: &mut Self, reject_scalar, reject, accept, condition| {
2547+
let accept = this.cast(accept, reject_scalar, span)?;
2548+
if condition {
2549+
Ok(accept)
2550+
} else {
2551+
Ok(reject)
2552+
}
2553+
};
2554+
2555+
match (&self.expressions[reject], &self.expressions[accept]) {
2556+
(&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
2557+
let reject_scalar = reject_lit.scalar();
2558+
let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
2559+
else {
2560+
return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
2561+
};
2562+
select_single_component(self, reject_scalar, reject, accept, condition)
2563+
}
2564+
(
2565+
&Expression::Compose {
2566+
ty: reject_ty,
2567+
components: ref reject_components,
2568+
},
2569+
&Expression::Compose {
2570+
ty: accept_ty,
2571+
components: ref accept_components,
2572+
},
2573+
) => {
2574+
let ty_deets = |ty| {
2575+
let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
2576+
(size.unwrap(), scalar)
2577+
};
2578+
2579+
let expected_vec_size = {
2580+
let [(reject_vec_size, _), (accept_vec_size, _)] =
2581+
[reject_ty, accept_ty].map(ty_deets);
2582+
2583+
if reject_vec_size != accept_vec_size {
2584+
return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
2585+
reject: reject_vec_size,
2586+
accept: accept_vec_size,
2587+
});
2588+
}
2589+
reject_vec_size
2590+
};
2591+
2592+
let condition_components = match self.expressions[condition] {
2593+
Expression::Literal(Literal::Bool(condition)) => {
2594+
vec![condition; (expected_vec_size as u8).into()]
2595+
}
2596+
Expression::Compose {
2597+
ty: condition_ty,
2598+
components: ref condition_components,
2599+
} => {
2600+
let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
2601+
if condition_scalar.kind != ScalarKind::Bool {
2602+
return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
2603+
}
2604+
if condition_vec_size != expected_vec_size {
2605+
return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
2606+
}
2607+
condition_components
2608+
.iter()
2609+
.copied()
2610+
.map(|component| match &self.expressions[component] {
2611+
&Expression::Literal(Literal::Bool(condition)) => condition,
2612+
_ => unreachable!(),
2613+
})
2614+
.collect()
2615+
}
2616+
2617+
_ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
2618+
};
2619+
2620+
let evaluated = Expression::Compose {
2621+
ty: reject_ty,
2622+
components: reject_components
2623+
.clone()
2624+
.into_iter()
2625+
.zip(accept_components.clone().into_iter())
2626+
.zip(condition_components.into_iter())
2627+
.map(|((reject, accept), condition)| {
2628+
let reject_scalar = match &self.expressions[reject] {
2629+
&Expression::Literal(lit) => lit.scalar(),
2630+
_ => unreachable!(),
2631+
};
2632+
select_single_component(self, reject_scalar, reject, accept, condition)
2633+
})
2634+
.collect::<Result<_, _>>()?,
2635+
};
2636+
self.register_evaluated_expr(evaluated, span)
2637+
}
2638+
_ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
2639+
}
2640+
}
25002641
}
25012642

25022643
fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {

naga/tests/in/wgsl/select.toml

Whitespace-only changes.

naga/tests/in/wgsl/select.wgsl

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
const_assert select(0xdeadbeef, 42f, false) == 0xdeadbeef;
2+
const_assert select(0xdeadbeefu, 42, false) == 0xdeadbeefu;
3+
const_assert select(0xdeadi, 42, false) == 0xdeadi;
4+
5+
const_assert select(42f, 9001, true) == 9001;
6+
const_assert select(42f, 9001, true) == 9001f;
7+
const_assert select(42, 9001i, true) == 9001;
8+
const_assert select(42, 9001u, true) == 9001;
9+
10+
const_assert !select(false, true, false);
11+
const_assert select(false, true, true);
12+
const_assert select(true, false, false);
13+
const_assert !select(true, false, true);
14+
15+
const_assert all(select(vec2(2f), vec2(), true) == vec2(0));
16+
const_assert all(select(vec2(1), vec2(2f), false) == vec2(1));
17+
const_assert all(select(vec2(1), vec2(2f), false) == vec2(1));
18+
const_assert all(select(vec2(1), vec2(2f), vec2(false, false)) == vec2(1));
19+
const_assert all(select(vec2(1), vec2(2f), vec2(true)) == vec2(2));
20+
const_assert all(select(vec2(1), vec2(2f), vec2(true)) == vec2(2));
21+
const_assert all(select(vec2(1), vec2(2f), vec2(true, false)) == vec2(2, 1));
22+
23+
const_assert all(select(vec3(1), vec3(2f), vec3(true)) == vec3(2));
24+
const_assert all(select(vec4(1), vec4(2f), vec4(true)) == vec4(2));
25+
26+
@compute @workgroup_size(1, 1)
27+
fn main() {
28+
_ = select(1, 2f, false);
29+
30+
var x0 = vec2(1, 2);
31+
var i1: vec2<f32> = select(vec2<f32>(1., 0.), vec2<f32>(0., 1.), (x0.x < x0.y));
32+
}

naga/tests/naga/validation.rs

+4-46
Original file line numberDiff line numberDiff line change
@@ -648,9 +648,8 @@ fn binding_arrays_cannot_hold_scalars() {
648648
#[cfg(feature = "wgsl-in")]
649649
#[test]
650650
fn validation_error_messages() {
651-
let cases = [
652-
(
653-
r#"@group(0) @binding(0) var my_sampler: sampler;
651+
let cases = [(
652+
r#"@group(0) @binding(0) var my_sampler: sampler;
654653
655654
fn foo(tex: texture_2d<f32>) -> vec4<f32> {
656655
return textureSampleLevel(tex, my_sampler, vec2f(0, 0), 0.0);
@@ -660,7 +659,7 @@ fn validation_error_messages() {
660659
foo();
661660
}
662661
"#,
663-
"\
662+
"\
664663
error: Function [1] 'main' is invalid
665664
┌─ wgsl:7:17
666665
\n7 │ ╭ fn main() {
@@ -671,48 +670,7 @@ error: Function [1] 'main' is invalid
671670
= Requires 1 arguments, but 0 are provided
672671
673672
",
674-
),
675-
(
676-
"\
677-
@compute @workgroup_size(1, 1)
678-
fn main() {
679-
// Bad: `9001` isn't a `bool`.
680-
_ = select(1, 2, 9001);
681-
}
682-
",
683-
"\
684-
error: Entry point main at Compute is invalid
685-
┌─ wgsl:4:9
686-
687-
4 │ _ = select(1, 2, 9001);
688-
│ ^^^^^^ naga::ir::Expression [3]
689-
690-
= Expression [3] is invalid
691-
= Expected selection condition to be a boolean value, got Scalar(Scalar { kind: Sint, width: 4 })
692-
693-
",
694-
),
695-
(
696-
"\
697-
@compute @workgroup_size(1, 1)
698-
fn main() {
699-
// Bad: `bool` and abstract int args. don't match.
700-
_ = select(true, 1, false);
701-
}
702-
",
703-
"\
704-
error: Entry point main at Compute is invalid
705-
┌─ wgsl:4:9
706-
707-
4 │ _ = select(true, 1, false);
708-
│ ^^^^^^ naga::ir::Expression [3]
709-
710-
= Expression [3] is invalid
711-
= Expected selection argument types to match, but reject value of type Scalar(Scalar { kind: Bool, width: 1 }) does not match accept value of value Scalar(Scalar { kind: Sint, width: 4 })
712-
713-
",
714-
),
715-
];
673+
)];
716674

717675
for (source, expected_err) in cases {
718676
let module = naga::front::wgsl::parse_str(source).unwrap();

0 commit comments

Comments
 (0)