Skip to content

Commit cfd29df

Browse files
WIP: fix(naga): properly impl. auto. type conv. for select
1 parent ca4e365 commit cfd29df

File tree

5 files changed

+174
-45
lines changed

5 files changed

+174
-45
lines changed

naga/src/front/wgsl/lower/mod.rs

+105-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use alloc::{
44
string::{String, ToString},
55
vec::Vec,
66
};
7+
use arrayvec::ArrayVec;
78
use core::num::NonZeroU32;
89

910
use crate::common::wgsl::{TryToWgsl, TypeContext};
@@ -2481,13 +2482,46 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
24812482
} else {
24822483
match function.name {
24832484
"select" => {
2484-
let mut args = ctx.prepare_args(arguments, 3, span);
2485+
const NUM_ARGS: usize = 3;
2486+
2487+
// TODO: dedupe with `math_function_helper`
24852488

2486-
let reject = self.expression(args.next()?, ctx)?;
2487-
let accept = self.expression(args.next()?, ctx)?;
2488-
let condition = self.expression(args.next()?, ctx)?;
2489+
let mut lowered_arguments = ArrayVec::<_, NUM_ARGS>::new();
2490+
let mut args = ctx.prepare_args(arguments, NUM_ARGS as u32, span);
2491+
2492+
for _ in 0..lowered_arguments.capacity() {
2493+
let lowered = self.expression_for_abstract(args.next()?, ctx)?;
2494+
ctx.grow_types(lowered)?;
2495+
lowered_arguments.push(lowered);
2496+
}
24892497

24902498
args.finish()?;
2499+
let mut lowered_arguments = lowered_arguments.into_inner().unwrap();
2500+
2501+
let fun_overloads = proc::select::overloads();
2502+
2503+
let rule = self.resolve_overloads(
2504+
span,
2505+
proc::select::WgslSymbol,
2506+
fun_overloads,
2507+
&lowered_arguments,
2508+
ctx,
2509+
)?;
2510+
2511+
self.apply_automatic_conversions_for_call(
2512+
&rule,
2513+
&mut lowered_arguments,
2514+
ctx,
2515+
)?;
2516+
2517+
// If this function returns a predeclared type, register it
2518+
// in `Module::special_types`. The typifier will expect to
2519+
// be able to find it there.
2520+
if let proc::Conclusion::Predeclared(predeclared) = rule.conclusion {
2521+
ctx.module.generate_predeclared_type(predeclared);
2522+
}
2523+
2524+
let [reject, accept, condition] = lowered_arguments;
24912525

24922526
ir::Expression::Select {
24932527
reject,
@@ -2988,6 +3022,70 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
29883022
})
29893023
}
29903024

3025+
fn resolve_arr_things<const NUM_ARGS: usize, F, O, R>(
3026+
&mut self,
3027+
span: Span,
3028+
fun: F,
3029+
resolve_overloads: R,
3030+
ast_arguments: [Handle<ast::Expression<'source>>; NUM_ARGS],
3031+
ctx: &mut ExpressionContext<'source, '_, '_>,
3032+
) -> Result<'source, [Handle<ir::Expression>; NUM_ARGS]>
3033+
where
3034+
F: TryToWgsl + core::fmt::Debug + Copy,
3035+
O: proc::OverloadSet,
3036+
R: FnOnce(
3037+
&[Handle<ir::Expression>; NUM_ARGS],
3038+
&mut ExpressionContext<'source, '_, '_>,
3039+
) -> O,
3040+
{
3041+
self.resolve_const_things(
3042+
span,
3043+
fun,
3044+
|args, ctx| {
3045+
let args = args.try_into().unwrap();
3046+
resolve_overloads(args, ctx)
3047+
},
3048+
ast_arguments.into(),
3049+
ctx,
3050+
)
3051+
.map(|arr| arr.into_inner().unwrap())
3052+
}
3053+
3054+
fn resolve_const_things<const NUM_ARGS: usize, F, O, R>(
3055+
&mut self,
3056+
span: Span,
3057+
fun: F,
3058+
resolve_overloads: R,
3059+
ast_arguments: ArrayVec<Handle<ast::Expression<'source>>, { NUM_ARGS }>,
3060+
ctx: &mut ExpressionContext<'source, '_, '_>,
3061+
) -> Result<'source, ArrayVec<Handle<ir::Expression>, { NUM_ARGS }>>
3062+
where
3063+
F: TryToWgsl + core::fmt::Debug + Copy,
3064+
O: proc::OverloadSet,
3065+
R: FnOnce(&[Handle<ir::Expression>], &mut ExpressionContext<'source, '_, '_>) -> O,
3066+
{
3067+
let mut lowered_arguments = ArrayVec::<_, { NUM_ARGS }>::new();
3068+
3069+
for &arg in ast_arguments.iter() {
3070+
let lowered = self.expression_for_abstract(arg, ctx)?;
3071+
ctx.grow_types(lowered)?;
3072+
lowered_arguments.push(lowered);
3073+
}
3074+
3075+
let fun_overloads = resolve_overloads(&lowered_arguments, ctx);
3076+
let rule = self.resolve_overloads(span, fun, fun_overloads, &lowered_arguments, ctx)?;
3077+
self.apply_automatic_conversions_for_call(&rule, &mut lowered_arguments, ctx)?;
3078+
3079+
// If this function returns a predeclared type, register it
3080+
// in `Module::special_types`. The typifier will expect to
3081+
// be able to find it there.
3082+
if let proc::Conclusion::Predeclared(predeclared) = rule.conclusion {
3083+
ctx.module.generate_predeclared_type(predeclared);
3084+
}
3085+
3086+
Ok(lowered_arguments)
3087+
}
3088+
29913089
/// Choose the right overload for a function call.
29923090
///
29933091
/// Return a [`Rule`] representing the most preferred overload in
@@ -3826,6 +3924,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
38263924
}
38273925
}
38283926

3927+
// TODO: Surely this already exists! Find it.
3928+
const MAX_BUILTIN_ARGS: usize = 4;
3929+
38293930
impl ir::AtomicFunction {
38303931
pub fn map(word: &str) -> Option<Self> {
38313932
Some(match word {

naga/src/proc/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pub use emitter::Emitter;
1919
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
2020
pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
2121
pub use namer::{EntryPointIndex, NameKey, Namer};
22-
pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
22+
pub use overloads::{select, Conclusion, MissingSpecialType, OverloadSet, Rule};
2323
pub use terminator::ensure_block_returns;
2424
use thiserror::Error;
2525
pub use type_methods::min_max_float_representable_by;

naga/src/proc/overloads/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mod list;
2121
mod mathfunction;
2222
mod one_bits_iter;
2323
mod rule;
24+
pub mod select;
2425
mod utils;
2526

2627
pub use rule::{Conclusion, MissingSpecialType, Rule};

naga/src/proc/overloads/select.rs

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use crate::common::wgsl::{ToWgsl, TryToWgsl};
2+
use crate::ir;
3+
use crate::proc::overloads::utils::{list, rule, scalar_or_vecn, scalars};
4+
use crate::proc::overloads::OverloadSet;
5+
6+
pub fn overloads() -> impl OverloadSet {
7+
list(scalars().flat_map(|scalar| {
8+
scalar_or_vecn(scalar).map(|input| {
9+
let bool_arg = match input.clone() {
10+
ir::TypeInner::Scalar(_) => ir::TypeInner::Scalar(ir::Scalar::BOOL),
11+
ir::TypeInner::Vector { size, scalar: _ } => ir::TypeInner::Vector {
12+
size,
13+
scalar: ir::Scalar::BOOL,
14+
},
15+
_ => unreachable!(),
16+
};
17+
rule([input.clone(), input.clone(), bool_arg], input)
18+
})
19+
}))
20+
}
21+
22+
#[derive(Debug, Clone, Copy)]
23+
pub struct WgslSymbol;
24+
25+
impl ToWgsl for WgslSymbol {
26+
fn to_wgsl(self) -> &'static str {
27+
"select"
28+
}
29+
}
30+
31+
impl TryToWgsl for WgslSymbol {
32+
fn try_to_wgsl(self) -> Option<&'static str> {
33+
Some(self.to_wgsl())
34+
}
35+
36+
const DESCRIPTION: &'static str = "`select` built-in";
37+
}

naga/src/valid/expression.rs

+30-40
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
use alloc::{format, string::String};
1+
use alloc::{borrow::ToOwned, format, string::String};
22

33
use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags};
44
use crate::arena::UniqueArena;
5+
use crate::common::wgsl::TryToWgsl as _;
56
use crate::{
67
arena::Handle,
78
proc,
@@ -929,47 +930,36 @@ impl super::Validator {
929930
accept,
930931
reject,
931932
} => {
932-
let accept_inner = &resolver[accept];
933-
let reject_inner = &resolver[reject];
934-
let condition_ty = &resolver[condition];
935-
let condition_good = match *condition_ty {
936-
Ti::Scalar(Sc {
937-
kind: Sk::Bool,
938-
width: _,
939-
}) => {
940-
// When `condition` is a single boolean, `accept` and
941-
// `reject` can be vectors or scalars.
942-
match *accept_inner {
943-
Ti::Scalar { .. } | Ti::Vector { .. } => true,
944-
_ => false,
945-
}
933+
// TODO: dedupe with math functions
934+
935+
let mut overloads = proc::select::overloads();
936+
log::debug!(
937+
"initial overloads for `select`: {:#?}",
938+
overloads.for_debug(&module.types)
939+
);
940+
941+
for (i, (expr, ty)) in [reject, accept, condition]
942+
.iter()
943+
.copied()
944+
.map(|arg| (arg, &resolver[arg]))
945+
.enumerate()
946+
{
947+
overloads = overloads.arg(i, ty, &module.types);
948+
log::debug!(
949+
"overloads after arg {i}: {:#?}",
950+
overloads.for_debug(&module.types)
951+
);
952+
953+
if overloads.is_empty() {
954+
log::debug!("all overloads eliminated");
955+
return Err(ExpressionError::InvalidArgumentType(
956+
proc::select::WgslSymbol::DESCRIPTION.to_owned(),
957+
i as u32,
958+
expr,
959+
));
946960
}
947-
Ti::Vector {
948-
size,
949-
scalar:
950-
Sc {
951-
kind: Sk::Bool,
952-
width: _,
953-
},
954-
} => match *accept_inner {
955-
Ti::Vector {
956-
size: other_size, ..
957-
} => size == other_size,
958-
_ => false,
959-
},
960-
_ => false,
961-
};
962-
if accept_inner != reject_inner {
963-
return Err(ExpressionError::SelectValuesTypeMismatch {
964-
accept: accept_inner.clone(),
965-
reject: reject_inner.clone(),
966-
});
967-
}
968-
if !condition_good {
969-
return Err(ExpressionError::SelectConditionNotABool {
970-
actual: condition_ty.clone(),
971-
});
972961
}
962+
973963
ShaderStages::all()
974964
}
975965
E::Derivative { expr, .. } => {

0 commit comments

Comments
 (0)