diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index e47602ac942..278e33038e1 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -159,6 +159,17 @@ def _decompose_with_context_( self, qubits: Tuple['cirq.Qid', ...], context: Optional['cirq.DecompositionContext'] = None ) -> Union[None, NotImplementedType, 'cirq.OP_TREE']: control_qubits = list(qubits[: self.num_controls()]) + controlled_sub_gate = self.sub_gate.controlled( + self.num_controls(), self.control_values, self.control_qid_shape + ) + # Prefer the subgate controlled version if available + if self != controlled_sub_gate: + # Prevent 2-cycle from appearing in the recursive decomposition + # TODO: Remove after #7241 is resolved + if not isinstance(controlled_sub_gate, ControlledGate) or not isinstance( + controlled_sub_gate.sub_gate, common_gates.CZPowGate + ): + return controlled_sub_gate.on(*qubits) if ( protocols.has_unitary(self.sub_gate) and protocols.num_qubits(self.sub_gate) == 1 diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index ebff6b9c709..ac591bbf2d8 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -494,6 +494,36 @@ def _test_controlled_gate_is_consistent( np.testing.assert_allclose(cirq.unitary(cgate), cirq.unitary(circuit), atol=1e-13) +@pytest.mark.parametrize( + 'sub_gate, expected_decomposition', + [ + (cirq.X, [cirq.CX]), + (cirq.CX, [cirq.CCX]), + (cirq.XPowGate(), [cirq.CXPowGate()]), + (cirq.CXPowGate(), [cirq.CCXPowGate()]), + (cirq.Z, [cirq.CZ]), + (cirq.CZ, [cirq.CCZ]), + (cirq.ZPowGate(), [cirq.CZPowGate()]), + (cirq.CZPowGate(), [cirq.CCZPowGate()]), + ], +) +def test_controlled_gate_decomposition_uses_canonical_version(sub_gate, expected_decomposition): + cgate = cirq.ControlledGate(sub_gate, num_controls=1) + qubits = cirq.LineQubit.range(1 + sub_gate.num_qubits()) + dec = cirq.decompose_once(cgate.on(*qubits)) + assert [op.gate for op in dec] == expected_decomposition + + +@pytest.mark.parametrize( + 'sub_gate, expected_decomposition', [(cirq.Z, [cirq.CZ]), (cirq.ZPowGate(), [cirq.CZPowGate()])] +) +def test_controlled_gate_full_decomposition(sub_gate, expected_decomposition): + cgate = cirq.ControlledGate(sub_gate, num_controls=1) + qubits = cirq.LineQubit.range(1 + sub_gate.num_qubits()) + dec = cirq.decompose(cgate.on(*qubits)) + assert [op.gate for op in dec] == expected_decomposition + + def test_pow_inverse(): assert cirq.inverse(CRestricted, None) is None assert cirq.pow(CRestricted, 1.5, None) is None