Skip to content

Commit abcfbf0

Browse files
fix(naga): properly impl. auto. type conv. for select
1 parent 196ff98 commit abcfbf0

File tree

6 files changed

+119
-48
lines changed

6 files changed

+119
-48
lines changed

naga/src/front/wgsl/lower/mod.rs

+45-6
Original file line numberDiff line numberDiff line change
@@ -2496,13 +2496,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
24962496
} else {
24972497
match function.name {
24982498
"select" => {
2499-
let mut args = ctx.prepare_args(arguments, 3, span);
2500-
2501-
let reject = self.expression(args.next()?, ctx)?;
2502-
let accept = self.expression(args.next()?, ctx)?;
2503-
let condition = self.expression(args.next()?, ctx)?;
2499+
const NUM_ARGS: usize = 3;
2500+
let mut args_iter = ctx.prepare_args(arguments, NUM_ARGS as u32, span);
2501+
let args: [_; NUM_ARGS] =
2502+
[args_iter.next()?, args_iter.next()?, args_iter.next()?];
2503+
args_iter.finish()?;
25042504

2505-
args.finish()?;
2505+
let [reject, accept, condition] = self.function_helper_array(
2506+
span,
2507+
proc::select::WgslSymbol,
2508+
|_args, _ctx| proc::select::overloads(),
2509+
args,
2510+
ctx,
2511+
)?;
25062512

25072513
ir::Expression::Select {
25082514
reject,
@@ -2999,6 +3005,39 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
29993005
Ok(lowered_arguments)
30003006
}
30013007

3008+
/// A convenience around [`Self::function_helper`] that marshals to and from fixed-size arrays.
3009+
///
3010+
/// Useful for resolving function calls with a fixed number of arguments at parse time, like
3011+
/// [`ir::Expression::Select`].
3012+
fn function_helper_array<const NUM_ARGS: usize, F, O, R>(
3013+
&mut self,
3014+
span: Span,
3015+
fun: F,
3016+
resolve_overloads: R,
3017+
ast_arguments: [Handle<ast::Expression<'source>>; NUM_ARGS],
3018+
ctx: &mut ExpressionContext<'source, '_, '_>,
3019+
) -> Result<'source, [Handle<ir::Expression>; NUM_ARGS]>
3020+
where
3021+
F: TryToWgsl + core::fmt::Debug + Copy,
3022+
O: proc::OverloadSet,
3023+
R: FnOnce(
3024+
&[Handle<ir::Expression>; NUM_ARGS],
3025+
&mut ExpressionContext<'source, '_, '_>,
3026+
) -> O,
3027+
{
3028+
self.function_helper(
3029+
span,
3030+
fun,
3031+
|args, ctx| {
3032+
let args = args.try_into().unwrap();
3033+
resolve_overloads(args, ctx)
3034+
},
3035+
ast_arguments.into(),
3036+
ctx,
3037+
)
3038+
.map(|arr| arr.into_inner().unwrap())
3039+
}
3040+
30023041
/// Choose the right overload for a function call.
30033042
///
30043043
/// Return a [`Rule`] representing the most preferred overload in

naga/src/proc/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pub use emitter::Emitter;
1919
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
2020
pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
2121
pub use namer::{EntryPointIndex, NameKey, Namer};
22-
pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
22+
pub use overloads::{select, Conclusion, MissingSpecialType, OverloadSet, Rule};
2323
pub use terminator::ensure_block_returns;
2424
use thiserror::Error;
2525
pub use type_methods::min_max_float_representable_by;

naga/src/proc/overloads/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mod list;
2121
mod mathfunction;
2222
mod one_bits_iter;
2323
mod rule;
24+
pub mod select;
2425
mod utils;
2526

2627
pub use rule::{Conclusion, MissingSpecialType, Rule};

naga/src/proc/overloads/select.rs

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use crate::common::wgsl::{ToWgsl, TryToWgsl};
2+
use crate::ir;
3+
use crate::proc::overloads::utils::{list, rule, scalar_or_vecn, scalars};
4+
use crate::proc::overloads::OverloadSet;
5+
6+
pub fn overloads() -> impl OverloadSet {
7+
list(scalars().flat_map(|scalar| {
8+
scalar_or_vecn(scalar).map(|input| {
9+
let bool_arg = match input.clone() {
10+
ir::TypeInner::Scalar(_) => ir::TypeInner::Scalar(ir::Scalar::BOOL),
11+
ir::TypeInner::Vector { size, scalar: _ } => ir::TypeInner::Vector {
12+
size,
13+
scalar: ir::Scalar::BOOL,
14+
},
15+
_ => unreachable!(),
16+
};
17+
rule([input.clone(), input.clone(), bool_arg], input)
18+
})
19+
}))
20+
}
21+
22+
#[derive(Clone, Copy)]
23+
pub struct WgslSymbol;
24+
25+
impl ToWgsl for WgslSymbol {
26+
fn to_wgsl(self) -> &'static str {
27+
"select"
28+
}
29+
}
30+
31+
impl TryToWgsl for WgslSymbol {
32+
fn try_to_wgsl(self) -> Option<&'static str> {
33+
Some(self.to_wgsl())
34+
}
35+
36+
const DESCRIPTION: &'static str = "`select` built-in";
37+
}
38+
39+
impl core::fmt::Debug for WgslSymbol {
40+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
41+
f.write_str(Self::DESCRIPTION)
42+
}
43+
}

naga/src/proc/overloads/utils.rs

+19
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,25 @@ pub fn float_scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
3434
.into_iter()
3535
}
3636

37+
/// Produce all [`ir::Scalar`]s.
38+
///
39+
/// Note that `*32` and `F16` must appear before other sizes; this is how we
40+
/// represent conversion rank.
41+
pub fn scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
42+
[
43+
ir::Scalar::ABSTRACT_INT,
44+
ir::Scalar::ABSTRACT_FLOAT,
45+
ir::Scalar::I32,
46+
ir::Scalar::U32,
47+
ir::Scalar::F32,
48+
ir::Scalar::F16,
49+
ir::Scalar::I64,
50+
ir::Scalar::U64,
51+
ir::Scalar::F64,
52+
]
53+
.into_iter()
54+
}
55+
3756
/// Produce all the floating-point [`ir::Scalar`]s, but omit
3857
/// abstract types, for #7405.
3958
pub fn float_scalars_unimplemented_abstract() -> impl Iterator<Item = ir::Scalar> + Clone {

naga/src/valid/expression.rs

+10-41
Original file line numberDiff line numberDiff line change
@@ -984,47 +984,16 @@ impl super::Validator {
984984
accept,
985985
reject,
986986
} => {
987-
let accept_inner = &resolver[accept];
988-
let reject_inner = &resolver[reject];
989-
let condition_ty = &resolver[condition];
990-
let condition_good = match *condition_ty {
991-
Ti::Scalar(Sc {
992-
kind: Sk::Bool,
993-
width: _,
994-
}) => {
995-
// When `condition` is a single boolean, `accept` and
996-
// `reject` can be vectors or scalars.
997-
match *accept_inner {
998-
Ti::Scalar { .. } | Ti::Vector { .. } => true,
999-
_ => false,
1000-
}
1001-
}
1002-
Ti::Vector {
1003-
size,
1004-
scalar:
1005-
Sc {
1006-
kind: Sk::Bool,
1007-
width: _,
1008-
},
1009-
} => match *accept_inner {
1010-
Ti::Vector {
1011-
size: other_size, ..
1012-
} => size == other_size,
1013-
_ => false,
1014-
},
1015-
_ => false,
1016-
};
1017-
if accept_inner != reject_inner {
1018-
return Err(ExpressionError::SelectValuesTypeMismatch {
1019-
accept: accept_inner.clone(),
1020-
reject: reject_inner.clone(),
1021-
});
1022-
}
1023-
if !condition_good {
1024-
return Err(ExpressionError::SelectConditionNotABool {
1025-
actual: condition_ty.clone(),
1026-
});
1027-
}
987+
self.validate_func_call_with_overloads(
988+
module,
989+
proc::select::WgslSymbol,
990+
proc::select::overloads(),
991+
[reject, accept, condition]
992+
.iter()
993+
.copied()
994+
.map(|arg| (arg, &resolver[arg])),
995+
)?;
996+
1028997
ShaderStages::all()
1029998
}
1030999
E::Derivative { expr, .. } => {

0 commit comments

Comments
 (0)