@@ -78,6 +78,36 @@ const atomicReductionSnippet = (reduction: string, ptr: string, v: string, type:
78
78
}
79
79
} ;
80
80
81
/**
 * Builds the shader-source fragment that folds one index component into the
 * running `data_offset`.
 *
 * The returned text is meant to be pasted inside a loop of the generated shader;
 * it does not declare `index`, `data_offset` or `i` itself, so the surrounding
 * shader code must already have them in scope (and `indices_start` as well when
 * `parallel` is true).
 *
 * The emitted code does three things:
 *  1. Declares `element_count_dim` and `dim_value` from `uniforms.output_strides`
 *     and `uniforms.output_shape`. When `dataRank === 1` the uniforms are used
 *     un-indexed; otherwise they are indexed by the loop variable.
 *  2. Normalises `index` against `dim_value`: a non-negative index that is too
 *     large becomes `dim_value - 1`; a negative index below `-dim_value` becomes
 *     `0`; any other negative index has `dim_value` added to it.
 *  3. Adds `u32(index) * element_count_dim` to `data_offset`.
 *
 * NOTE(review): `dim_value` is read from `output_shape` at the loop position
 * shifted by `uniforms.last_index_dimension`, while `element_count_dim` is read
 * from `output_strides` at the unshifted position — confirm that this shift
 * selects the intended dimension.
 *
 * @param dataRank Rank used to decide whether the uniforms are indexed; the
 *   un-indexed form is emitted only when this is exactly 1.
 * @param parallel When true the uniforms are indexed with `i - indices_start`;
 *   when false they are indexed with `i` directly.
 * @returns The shader source fragment as a string.
 */
+ const calcDataOffsetSnippet = ( dataRank : number , parallel : boolean ) =>
82
+ `${
83
+ dataRank === 1
84
+ ? `
85
+ let element_count_dim = uniforms.output_strides;
86
+ let dim_value = uniforms.output_shape;`
87
+ : `
88
+ let element_count_dim = uniforms.output_strides[${ parallel ? 'i - indices_start' : 'i' } ];
89
+ let dim_value = uniforms.output_shape[${ parallel ? 'i - indices_start' : 'i' } + uniforms.last_index_dimension];`
90
+ }
91
+
92
+ if (index >= 0) {
93
+ if (index >= i32(dim_value)) {
94
+ index = i32(dim_value - 1);
95
+ }
96
+ } else {
97
+ if (index < -i32(dim_value)) {
98
+ index = 0;
99
+ } else {
100
+ index += i32(dim_value);
101
+ }
102
+ }
103
+ data_offset += u32((u32(index) * element_count_dim));` ;
104
+
105
/**
 * Builds the shader-source fragment that applies one run of update values to
 * the output.
 *
 * The emitted code is a loop over `uniforms.num_updates_elements`. Each
 * iteration reads `updates[uniforms.num_updates_elements * <base> + i]` into
 * `value`, then inserts whatever `atomicReductionSnippet` returns for the
 * target expression `output[data_offset + i]` and the operand `value`.
 *
 * The fragment does not declare `data_offset`, `global_idx` or `idx`; the
 * surrounding shader code must already have `data_offset` in scope, plus
 * `global_idx` (parallel) or `idx` (non-parallel).
 *
 * @param attributes Supplies `reduction`, which is forwarded to
 *   `atomicReductionSnippet`.
 * @param outputTypeValue Forwarded unchanged as the last argument of
 *   `atomicReductionSnippet`.
 * @param parallel Selects the `<base>` used to address `updates`: `global_idx`
 *   when true, `idx` when false.
 * @returns The shader source fragment as a string.
 */
+ const updateElementsSnippet = ( attributes : ScatterNDAttributes , outputTypeValue : ReductionType , parallel : boolean ) =>
106
+ `for (var i = 0u; i < uniforms.num_updates_elements; i++) {
107
+ let value = updates[uniforms.num_updates_elements * ${ parallel ? 'global_idx' : 'idx' } + i];
108
+ ${ atomicReductionSnippet ( attributes . reduction , 'output[data_offset + i]' , 'value' , outputTypeValue ) }
109
+ }` ;
110
+
81
111
const createScatterNDProgramInfo = ( inputs : readonly TensorView [ ] , attributes : ScatterNDAttributes ) : ProgramInfo => {
82
112
const inputShape = inputs [ 0 ] . dims ;
83
113
const indicesShape = inputs [ 1 ] . dims ;
@@ -87,6 +117,7 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S
87
117
const outputSize = Math . ceil ( ShapeUtil . size ( indicesShape ) / components ) ;
88
118
const lastIndexDimension = indicesShape [ indicesShape . length - 1 ] ;
89
119
const numUpdatesElements = ShapeUtil . sizeFromDimension ( inputShape , lastIndexDimension ) ;
120
+ const numIndicesElements = ShapeUtil . sizeFromDimension ( indicesShape , 0 ) / lastIndexDimension ;
90
121
91
122
const programUniforms : ProgramUniform [ ] = [
92
123
{ type : DataType . uint32 , data : outputSize } ,
@@ -113,9 +144,8 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S
113
144
${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( 'uniforms.output_size' ) }
114
145
var hasDuplicates = false;
115
146
if (${ attributes . reduction === 'none' } ) {
116
- let n = ${ ShapeUtil . size ( indicesShape ) } ;
117
- for (var i = 0; i < n; i = i + 1) {
118
- for (var j = i + 1; j < n; j = j + 1) {
147
+ for (var i = 0; i < ${ numIndicesElements } ; i = i + 1) {
148
+ for (var j = i + 1; j < ${ numIndicesElements } ; j = j + 1) {
119
149
var index_i = i32(indices[i].x);
120
150
var index_j = i32(indices[j].x);
121
151
if (index_i == index_j) {
@@ -129,51 +159,31 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S
129
159
}
130
160
}
131
161
132
- var data_offset = 0u;
133
- var indices_start = uniforms.last_index_dimension * global_idx;
134
162
if (${ attributes . reduction === 'none' } && hasDuplicates) {
135
163
if (global_idx != 0u) {
136
164
return;
137
165
}
138
- indices_start = 0u;
139
- }
140
- let indices_end = indices_start + uniforms.last_index_dimension;
141
- for (var i = indices_start; i < indices_end; i++) {
142
- var index = i32(indices[i].x);
143
- ${
144
- inputs [ 0 ] . dims . length === 1
145
- ? `
146
- let element_count_dim = uniforms.output_strides;
147
- let dim_value = uniforms.output_shape;`
148
- : `
149
- let element_count_dim = uniforms.output_strides[i - indices_start];
150
- let dim_value = uniforms.output_shape[i - indices_start + uniforms.last_index_dimension];`
151
- }
152
- if (index >= 0) {
153
- if (index >= i32(dim_value)) {
154
- index = i32(dim_value - 1);
155
- }
156
- } else {
157
- if (index < -i32(dim_value)) {
158
- index = 0;
159
- } else {
160
- index += i32(dim_value);
166
+ // Process each index-update pair individually when duplicates exist
167
+ for (var idx = 0u; idx < ${ numIndicesElements } u; idx++) {
168
+ var data_offset = 0u;
169
+ for (var i = 0u; i < uniforms.last_index_dimension; i++) {
170
+ var index = i32(indices[idx * uniforms.last_index_dimension + i].x);
171
+ ${ calcDataOffsetSnippet ( inputShape . length , false ) }
161
172
}
173
+ ${ updateElementsSnippet ( attributes , output . type . value as ReductionType , false ) }
162
174
}
163
- data_offset += u32((u32(index) * element_count_dim)) ;
175
+ return ;
164
176
}
165
177
166
- for (var i = 0u; i < uniforms.num_updates_elements; i++) {
167
- let value = updates[uniforms.num_updates_elements * global_idx + i];
168
- ${ atomicReductionSnippet (
169
- attributes . reduction ,
170
- 'output[data_offset + i]' ,
171
- 'value' ,
172
- output . type . value as ReductionType ,
173
- ) }
178
+ var data_offset = 0u;
179
+ var indices_start = uniforms.last_index_dimension * global_idx;
180
+ var indices_end = indices_start + uniforms.last_index_dimension;
181
+ for (var i = indices_start; i < indices_end; i++) {
182
+ var index = i32(indices[i].x);
183
+ ${ calcDataOffsetSnippet ( inputShape . length , true ) }
174
184
}
175
-
176
- }` ;
185
+ ${ updateElementsSnippet ( attributes , output . type . value as ReductionType , true ) }
186
+ }` ;
177
187
} ;
178
188
return {
179
189
name : 'ScatterND' ,
0 commit comments