Skip to content

Commit 24ece47

Browse files
authored
[JSEP] adjust edge case logic for scatternd (#24172)
Fixes #24070 by explicitly restricting execution to a single thread that applies the updates sequentially in the case where `reduction=none && hasDuplicates`.
1 parent 5d805c2 commit 24ece47

File tree

1 file changed

+49
-39
lines changed

1 file changed

+49
-39
lines changed

js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts

+49-39
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,36 @@ const atomicReductionSnippet = (reduction: string, ptr: string, v: string, type:
7878
}
7979
};
8080

81+
const calcDataOffsetSnippet = (dataRank: number, parallel: boolean) =>
82+
`${
83+
dataRank === 1
84+
? `
85+
let element_count_dim = uniforms.output_strides;
86+
let dim_value = uniforms.output_shape;`
87+
: `
88+
let element_count_dim = uniforms.output_strides[${parallel ? 'i - indices_start' : 'i'}];
89+
let dim_value = uniforms.output_shape[${parallel ? 'i - indices_start' : 'i'} + uniforms.last_index_dimension];`
90+
}
91+
92+
if (index >= 0) {
93+
if (index >= i32(dim_value)) {
94+
index = i32(dim_value - 1);
95+
}
96+
} else {
97+
if (index < -i32(dim_value)) {
98+
index = 0;
99+
} else {
100+
index += i32(dim_value);
101+
}
102+
}
103+
data_offset += u32((u32(index) * element_count_dim));`;
104+
105+
const updateElementsSnippet = (attributes: ScatterNDAttributes, outputTypeValue: ReductionType, parallel: boolean) =>
106+
`for (var i = 0u; i < uniforms.num_updates_elements; i++) {
107+
let value = updates[uniforms.num_updates_elements * ${parallel ? 'global_idx' : 'idx'} + i];
108+
${atomicReductionSnippet(attributes.reduction, 'output[data_offset + i]', 'value', outputTypeValue)}
109+
}`;
110+
81111
const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: ScatterNDAttributes): ProgramInfo => {
82112
const inputShape = inputs[0].dims;
83113
const indicesShape = inputs[1].dims;
@@ -87,6 +117,7 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S
87117
const outputSize = Math.ceil(ShapeUtil.size(indicesShape) / components);
88118
const lastIndexDimension = indicesShape[indicesShape.length - 1];
89119
const numUpdatesElements = ShapeUtil.sizeFromDimension(inputShape, lastIndexDimension);
120+
const numIndicesElements = ShapeUtil.sizeFromDimension(indicesShape, 0) / lastIndexDimension;
90121

91122
const programUniforms: ProgramUniform[] = [
92123
{ type: DataType.uint32, data: outputSize },
@@ -113,9 +144,8 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S
113144
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
114145
var hasDuplicates = false;
115146
if (${attributes.reduction === 'none'}) {
116-
let n = ${ShapeUtil.size(indicesShape)};
117-
for (var i = 0; i < n; i = i + 1) {
118-
for (var j = i + 1; j < n; j = j + 1) {
147+
for (var i = 0; i < ${numIndicesElements}; i = i + 1) {
148+
for (var j = i + 1; j < ${numIndicesElements}; j = j + 1) {
119149
var index_i = i32(indices[i].x);
120150
var index_j = i32(indices[j].x);
121151
if (index_i == index_j) {
@@ -129,51 +159,31 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S
129159
}
130160
}
131161
132-
var data_offset = 0u;
133-
var indices_start = uniforms.last_index_dimension * global_idx;
134162
if (${attributes.reduction === 'none'} && hasDuplicates) {
135163
if (global_idx != 0u) {
136164
return;
137165
}
138-
indices_start = 0u;
139-
}
140-
let indices_end = indices_start + uniforms.last_index_dimension;
141-
for (var i = indices_start; i < indices_end; i++) {
142-
var index = i32(indices[i].x);
143-
${
144-
inputs[0].dims.length === 1
145-
? `
146-
let element_count_dim = uniforms.output_strides;
147-
let dim_value = uniforms.output_shape;`
148-
: `
149-
let element_count_dim = uniforms.output_strides[i - indices_start];
150-
let dim_value = uniforms.output_shape[i - indices_start + uniforms.last_index_dimension];`
151-
}
152-
if (index >= 0) {
153-
if (index >= i32(dim_value)) {
154-
index = i32(dim_value - 1);
155-
}
156-
} else {
157-
if (index < -i32(dim_value)) {
158-
index = 0;
159-
} else {
160-
index += i32(dim_value);
166+
// Process each index-update pair individually when duplicates exist
167+
for (var idx = 0u; idx < ${numIndicesElements}u; idx++) {
168+
var data_offset = 0u;
169+
for (var i = 0u; i < uniforms.last_index_dimension; i++) {
170+
var index = i32(indices[idx * uniforms.last_index_dimension + i].x);
171+
${calcDataOffsetSnippet(inputShape.length, false)}
161172
}
173+
${updateElementsSnippet(attributes, output.type.value as ReductionType, false)}
162174
}
163-
data_offset += u32((u32(index) * element_count_dim));
175+
return;
164176
}
165177
166-
for (var i = 0u; i < uniforms.num_updates_elements; i++) {
167-
let value = updates[uniforms.num_updates_elements * global_idx + i];
168-
${atomicReductionSnippet(
169-
attributes.reduction,
170-
'output[data_offset + i]',
171-
'value',
172-
output.type.value as ReductionType,
173-
)}
178+
var data_offset = 0u;
179+
var indices_start = uniforms.last_index_dimension * global_idx;
180+
var indices_end = indices_start + uniforms.last_index_dimension;
181+
for (var i = indices_start; i < indices_end; i++) {
182+
var index = i32(indices[i].x);
183+
${calcDataOffsetSnippet(inputShape.length, true)}
174184
}
175-
176-
}`;
185+
${updateElementsSnippet(attributes, output.type.value as ReductionType, true)}
186+
}`;
177187
};
178188
return {
179189
name: 'ScatterND',

0 commit comments

Comments
 (0)