Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

Added Cholskey NotImplementedError function and test #885

Closed
wants to merge 12 commits into from
7 changes: 7 additions & 0 deletions tensornetwork/backends/abstract_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ def qr(self,
raise NotImplementedError("Backend '{}' has not implemented qr.".format(
self.name))

def chsky(self,
tensor: Tensor,
pivot_axis: int = -1,
non_negative_diagonal: bool = False) -> Tuple[Tensor, Tensor]:
"""Computes the Cholskey decomposition of a tensor."""
raise NotImplementedError("Backend '{}' has not implemented chsky.".format(self.name))

def rq(self,
tensor: Tensor,
pivot_axis: int = -1,
Expand Down
6 changes: 6 additions & 0 deletions tensornetwork/backends/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ def test_abstract_backend_qr_decompositon_not_implemented():
backend.qr(np.ones((2, 2)), 0)


def test_abstract_backend_chsky_decompositon_not_implemented():
backend = AbstractBackend()
with pytest.raises(NotImplementedError):
backend.chsky(np.ones((2, 2)), 0)


def test_abstract_backend_rq_decompositon_not_implemented():
backend = AbstractBackend()
with pytest.raises(NotImplementedError):
Expand Down
14 changes: 14 additions & 0 deletions tensornetwork/backends/jax/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,3 +874,17 @@ def sign(self, tensor: Tensor) -> Tensor:

def item(self, tensor):
return tensor.item()

def power(self, a: Tensor, b: Union[Tensor, int]) -> Tensor:
"""
Returns the power of tensor a to the value of b.
In the case b is a tensor, then the power is by element
with a as the base and b as the exponent.
In the case b is a scalar, then the power of each value in a
is raised to the exponent of b.

Args:
a: The tensor that contains the base.
b: The tensor that contains the exponent or a single scalar.
"""
return jnp.power(a, b)
16 changes: 15 additions & 1 deletion tensornetwork/backends/jax/jax_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,6 @@ def matvec_jax(vector, matrix):
num_krylov_vecs=100,
tol=0.0001)


def test_sum():
np.random.seed(10)
backend = jax_backend.JaxBackend()
Expand Down Expand Up @@ -1240,3 +1239,18 @@ def test_item(dtype):
backend = jax_backend.JaxBackend()
tensor = backend.randn((1,), dtype=dtype, seed=10)
assert backend.item(tensor) == tensor.item()

@pytest.mark.parametrize("dtype", np_dtypes)
def test_power(dtype):
shape = (4, 3, 2)
backend = jax_backend.JaxBackend()
base_tensor = backend.randn(shape, dtype=dtype, seed=10)
power_tensor = backend.randn(shape, dtype=dtype, seed=10)
actual = backend.power(base_tensor, power_tensor)
expected = jax.numpy.power(base_tensor, power_tensor)
np.testing.assert_allclose(expected, actual)

power = np.random.rand(1)[0]
actual = backend.power(base_tensor, power)
expected = jax.numpy.power(base_tensor, power)
np.testing.assert_allclose(expected, actual)