
Error: Error in gradient for op BatchMatMul. The gradient of input 'b' has shape 'b,s,h', which does not match the shape of the input 's,h' #8548


Open
warfrogsdf opened this issue Apr 17, 2025 · 3 comments

Comments


warfrogsdf commented Apr 17, 2025

Describe the current behavior
I implemented a custom layer and model with TensorFlow.js and hit a problem during training. The code is below. When it runs, it throws the following error.

throw new Error("Error in gradient for op ".concat(node.kernelName, ". The gradient of input ") +
^

Error: Error in gradient for op BatchMatMul. The gradient of input 'b' has shape '4,8,8', which does not match the shape of the input '8,8'

Describe the expected behavior

no error

my code

import * as tf from '@tensorflow/tfjs-node';

class MyLayer extends tf.layers.Layer {
  constructor(units) {
    super({});
    this.units = units;
  }

  build(inputShape) {
    this.w1 = this.addWeight(
      'w1',
      [this.units, this.units],
      'float32',
      tf.initializers.glorotNormal({}),
      undefined,
      true
    );
    super.build(inputShape);
  }

  call(inputs) {
    const input = Array.isArray(inputs) ? inputs[0] : inputs;
    return  tf.matMul(input, this.w1.read());
  }

  computeOutputShape(inputShape) {
    return [null, inputShape[inputShape.length - 2], this.units];
  }

  static get className() {
    return 'MyLayer';
  }
}
tf.serialization.registerClass(MyLayer);

const input = tf.input({shape: [4, 8]});
const layer1 = new MyLayer(8)
const output = layer1.apply(input)

const model = tf.model({inputs: input, outputs: output});
model.compile({
  optimizer: 'adam',
  loss: tf.losses.softmaxCrossEntropy,
});
const _input = tf.ones([40000, 4, 8])
const _output = tf.ones([40000, 4, 8])

model.fit(_input, _output, {batchSize: 4}).then(()=>{
  let x = tf.ones([1, 4, 8])
  const y = model.predict(x)
});

@warfrogsdf warfrogsdf added the type:bug Something isn't working label Apr 17, 2025
@gaikwadrahul8 gaikwadrahul8 self-assigned this Apr 18, 2025
@shreyvegad

This happens during training, particularly when calculating gradients of the matMul operation in your custom layer. Let's break it down.

In your custom layer, you define:
return tf.matMul(input, this.w1.read());

And the input shape passed to the model is [batchSize, 4, 8] (i.e., 3D tensor), and this.w1.read() is [8, 8] (i.e., 2D tensor). So you're trying to do:
matMul([batchSize, 4, 8], [8, 8])

This works in the forward pass because tf.matMul supports broadcasting over the batch dimension when one operand is 2D. But the gradient computation fails because it's trying to compute the gradient with respect to b (the 2D matrix), and it expects a broadcasted version of b with matching shape [batchSize, 8, 8].
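The shape bookkeeping can be sketched in plain JavaScript without tfjs. The helper below is hypothetical (not a tfjs API); it only mirrors the broadcasting rule that makes the forward pass succeed while the gradient shape for b comes back mismatched:

```javascript
// Hypothetical helper mirroring batched matMul shape rules (not a tfjs API).
// Forward: a [batch, 4, 8] x b [8, 8] broadcasts b to [batch, 8, 8].
function batchMatMulShape(aShape, bShape) {
  // Broadcast the 2-D operand over the batch dimension of the 3-D one.
  const a = aShape.length === 2 ? [1, ...aShape] : aShape;
  const b = bShape.length === 2 ? [a[0], ...bShape] : bShape;
  if (a[2] !== b[1]) throw new Error('inner dimensions do not match');
  return [a[0], a[1], b[2]];
}

// Forward pass succeeds:
console.log(batchMatMulShape([4, 4, 8], [8, 8])); // [ 4, 4, 8 ]

// Backward pass: the gradient w.r.t. b carries the broadcast batch
// dimension, so it no longer matches the stored 2-D variable.
const gradBShape = [4, 8, 8]; // what the BatchMatMul gradient produces
const varBShape = [8, 8];     // what the layer weight actually is
console.log(gradBShape.length === varBShape.length); // false -> the reported error
```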

### Solution:

You need to explicitly broadcast your weight tensor this.w1 to match the batch dimension during training. Note that tf.tile requires the reps array to have the same length as the tensor's rank, so the 2-D weight must first gain a batch axis via expandDims before it can be tiled. Modify the call() method like this:

call(inputs) {
  const input = Array.isArray(inputs) ? inputs[0] : inputs;
  const batchSize = input.shape[0];

  // Add a batch axis, then tile to shape [batchSize, 8, 8]
  const wExpanded = this.w1.read().expandDims(0).tile([batchSize, 1, 1]);

  return tf.matMul(input, wExpanded);
}

@shmishra99
Contributor

Hi @warfrogsdf ,

As @shreyvegad explained, the tf.matMul(input, this.w1.read()); operation in your custom layer fails during the backward pass because the operands differ in rank: the 3-D input broadcasts against the 2-D weight in the forward pass, but the resulting gradient keeps the broadcast batch dimension and no longer matches the weight's 2-D shape.

You can also resolve this by reshaping the input so that both operands of tf.matMul are 2-D. Use tf.reshape to collapse the batch and feature dimensions into one, i.e. [batch * 4, 8] in your case:

tf.reshape(input, [-1, inputShape[inputShape.length - 1]])

then matMul with the [8, 8] weight and reshape the result back to [batch, 4, 8].
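The shape arithmetic behind this reshape workaround can be sketched in plain JS (hypothetical helper, no tfjs): collapsing [batch, 4, 8] to [batch * 4, 8] turns the batched matMul into an ordinary 2-D matMul, so the gradient of the [8, 8] weight stays 2-D and no broadcast mismatch can occur.

```javascript
// Hypothetical helper: compute the shape tf.reshape(input, [-1, last]) produces.
function flattenBatch(shape) {
  const last = shape[shape.length - 1];
  const rest = shape.slice(0, -1).reduce((a, b) => a * b, 1);
  return [rest, last];
}

console.log(flattenBatch([4, 4, 8])); // [ 16, 8 ]
// After matMul with the [8, 8] weight the result is [16, 8];
// reshape it back to [4, 4, 8] to restore the batch dimension.
```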

Let me know if it helps. Thank You!!


github-actions bot commented May 1, 2025

This issue has been marked stale because it has had no activity in the last 7 days. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale label May 1, 2025