Skip to content

Commit 94f9594

Browse files
authored
Add {Frozen}Circuit.from_moments to construct circuit by moments. (#5805)
In our internal code almost all circuits are constructed with explicit moments rather than using an insertion strategy because we want to control the alignment of ops. Here we add a `from_moments` classmethod on `Circuit` and `FrozenCircuit` which takes any number of `OP_TREE` args and converts each one to a `Moment` in the resulting circuit. This gives a convenient way to construct circuits without cluttering the code with multiple calls to `Moment`. Also adds special cases in both the `Circuit` and `FrozenCircuit` constructors to handle constructing from a single circuit or a sequence of moments.
1 parent 3d61112 commit 94f9594

File tree

4 files changed

+91
-31
lines changed

4 files changed

+91
-31
lines changed

cirq-core/cirq/circuits/circuit.py

+40-25
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,30 @@ class AbstractCircuit(abc.ABC):
143143
* get_independent_qubit_sets
144144
"""
145145

146+
@classmethod
147+
def from_moments(cls: Type[CIRCUIT_TYPE], *moments: 'cirq.OP_TREE') -> CIRCUIT_TYPE:
148+
"""Create a circuit from moment op trees.
149+
150+
Args:
151+
*moments: Op tree for each moment.
152+
"""
153+
return cls._from_moments(
154+
moment if isinstance(moment, Moment) else Moment(moment) for moment in moments
155+
)
156+
157+
@classmethod
158+
@abc.abstractmethod
159+
def _from_moments(cls: Type[CIRCUIT_TYPE], moments: Iterable['cirq.Moment']) -> CIRCUIT_TYPE:
160+
"""Create a circuit from moments.
161+
162+
This must be implemented by subclasses. It provides a more efficient way
163+
to construct a circuit instance since we already have the moments and so
164+
can skip the analysis required to implement various insert strategies.
165+
166+
Args:
167+
moments: Moments of the circuit.
168+
"""
169+
146170
@property
147171
@abc.abstractmethod
148172
def moments(self) -> Sequence['cirq.Moment']:
@@ -225,8 +249,7 @@ def __getitem__(self: CIRCUIT_TYPE, key: Tuple[slice, Iterable['cirq.Qid']]) ->
225249

226250
def __getitem__(self, key):
227251
if isinstance(key, slice):
228-
sliced_moments = self.moments[key]
229-
return self._with_sliced_moments(sliced_moments)
252+
return self._from_moments(self.moments[key])
230253
if hasattr(key, '__index__'):
231254
return self.moments[key]
232255
if isinstance(key, tuple):
@@ -239,17 +262,12 @@ def __getitem__(self, key):
239262
return selected_moments[qubit_idx]
240263
if isinstance(qubit_idx, ops.Qid):
241264
qubit_idx = [qubit_idx]
242-
sliced_moments = [moment[qubit_idx] for moment in selected_moments]
243-
return self._with_sliced_moments(sliced_moments)
265+
return self._from_moments(moment[qubit_idx] for moment in selected_moments)
244266

245267
raise TypeError('__getitem__ called with key not of type slice, int, or tuple.')
246268

247269
# pylint: enable=function-redefined
248270

249-
@abc.abstractmethod
250-
def _with_sliced_moments(self: CIRCUIT_TYPE, moments: Iterable['cirq.Moment']) -> CIRCUIT_TYPE:
251-
"""Helper method for constructing circuits from __getitem__."""
252-
253271
def __str__(self) -> str:
254272
return self.to_text_diagram()
255273

@@ -909,7 +927,7 @@ def map_moment(moment: 'cirq.Moment') -> 'cirq.Circuit':
909927
"""Apply func to expand each op into a circuit, then zip up the circuits."""
910928
return Circuit.zip(*[Circuit(func(op)) for op in moment])
911929

912-
return self._with_sliced_moments(m for moment in self for m in map_moment(moment))
930+
return self._from_moments(m for moment in self for m in map_moment(moment))
913931

914932
def qid_shape(
915933
self, qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT
@@ -949,18 +967,16 @@ def _measurement_key_names_(self) -> FrozenSet[str]:
949967
return self.all_measurement_key_names()
950968

951969
def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
952-
return self._with_sliced_moments(
953-
[protocols.with_measurement_key_mapping(moment, key_map) for moment in self.moments]
970+
return self._from_moments(
971+
protocols.with_measurement_key_mapping(moment, key_map) for moment in self.moments
954972
)
955973

956974
def _with_key_path_(self, path: Tuple[str, ...]):
957-
return self._with_sliced_moments(
958-
[protocols.with_key_path(moment, path) for moment in self.moments]
959-
)
975+
return self._from_moments(protocols.with_key_path(moment, path) for moment in self.moments)
960976

961977
def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
962-
return self._with_sliced_moments(
963-
[protocols.with_key_path_prefix(moment, prefix) for moment in self.moments]
978+
return self._from_moments(
979+
protocols.with_key_path_prefix(moment, prefix) for moment in self.moments
964980
)
965981

966982
def _with_rescoped_keys_(
@@ -971,7 +987,7 @@ def _with_rescoped_keys_(
971987
new_moment = protocols.with_rescoped_keys(moment, path, bindable_keys)
972988
moments.append(new_moment)
973989
bindable_keys |= protocols.measurement_key_objs(new_moment)
974-
return self._with_sliced_moments(moments)
990+
return self._from_moments(moments)
975991

976992
def _qid_shape_(self) -> Tuple[int, ...]:
977993
return self.qid_shape()
@@ -1552,9 +1568,7 @@ def factorize(self: CIRCUIT_TYPE) -> Iterable[CIRCUIT_TYPE]:
15521568
# the qubits from one factor belong to a specific independent qubit set.
15531569
# This makes it possible to create independent circuits based on these
15541570
# moments.
1555-
return (
1556-
self._with_sliced_moments([m[qubits] for m in self.moments]) for qubits in qubit_factors
1557-
)
1571+
return (self._from_moments(m[qubits] for m in self.moments) for qubits in qubit_factors)
15581572

15591573
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
15601574
controls = frozenset(k for op in self.all_operations() for k in protocols.control_keys(op))
@@ -1719,6 +1733,12 @@ def __init__(
17191733
else:
17201734
self.append(contents, strategy=strategy)
17211735

1736+
@classmethod
1737+
def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'Circuit':
1738+
new_circuit = Circuit()
1739+
new_circuit._moments[:] = moments
1740+
return new_circuit
1741+
17221742
def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
17231743
"""Optimized algorithm to load contents quickly.
17241744
@@ -1813,11 +1833,6 @@ def copy(self) -> 'Circuit':
18131833
copied_circuit._moments = self._moments[:]
18141834
return copied_circuit
18151835

1816-
def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'Circuit':
1817-
new_circuit = Circuit()
1818-
new_circuit._moments = list(moments)
1819-
return new_circuit
1820-
18211836
# pylint: disable=function-redefined
18221837
@overload
18231838
def __setitem__(self, key: int, value: 'cirq.Moment'):

cirq-core/cirq/circuits/circuit_test.py

+27
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,23 @@ def validate_moment(self, moment):
7070
moment_and_op_type_validating_device = _MomentAndOpTypeValidatingDeviceType()
7171

7272

73+
def test_from_moments():
74+
a, b, c, d = cirq.LineQubit.range(4)
75+
assert cirq.Circuit.from_moments(
76+
[cirq.X(a), cirq.Y(b)],
77+
[cirq.X(c)],
78+
[],
79+
cirq.Z(d),
80+
[cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')],
81+
) == cirq.Circuit(
82+
cirq.Moment(cirq.X(a), cirq.Y(b)),
83+
cirq.Moment(cirq.X(c)),
84+
cirq.Moment(),
85+
cirq.Moment(cirq.Z(d)),
86+
cirq.Moment(cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')),
87+
)
88+
89+
7390
def test_alignment():
7491
assert repr(cirq.Alignment.LEFT) == 'cirq.Alignment.LEFT'
7592
assert repr(cirq.Alignment.RIGHT) == 'cirq.Alignment.RIGHT'
@@ -269,6 +286,16 @@ def test_append_control_key_subcircuit():
269286
assert len(c) == 1
270287

271288

289+
def test_measurement_key_paths():
290+
a = cirq.LineQubit(0)
291+
circuit1 = cirq.Circuit(cirq.measure(a, key='A'))
292+
assert cirq.measurement_key_names(circuit1) == {'A'}
293+
circuit2 = cirq.with_key_path(circuit1, ('B',))
294+
assert cirq.measurement_key_names(circuit2) == {'B:A'}
295+
circuit3 = cirq.with_key_path_prefix(circuit2, ('C',))
296+
assert cirq.measurement_key_names(circuit3) == {'C:B:A'}
297+
298+
272299
def test_append_moments():
273300
a = cirq.NamedQubit('a')
274301
b = cirq.NamedQubit('b')

cirq-core/cirq/circuits/frozen_circuit.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""An immutable version of the Circuit data structure."""
15-
from typing import TYPE_CHECKING, FrozenSet, Iterable, Iterator, Sequence, Tuple, Union
15+
from typing import FrozenSet, Iterable, Iterator, Sequence, Tuple, TYPE_CHECKING, Union
1616

1717
import numpy as np
1818

@@ -51,6 +51,12 @@ def __init__(
5151
base = Circuit(contents, strategy=strategy)
5252
self._moments = tuple(base.moments)
5353

54+
@classmethod
55+
def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
56+
new_circuit = FrozenCircuit()
57+
new_circuit._moments = tuple(moments)
58+
return new_circuit
59+
5460
@property
5561
def moments(self) -> Sequence['cirq.Moment']:
5662
return self._moments
@@ -143,11 +149,6 @@ def __pow__(self, other) -> 'cirq.FrozenCircuit':
143149
except:
144150
return NotImplemented
145151

146-
def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
147-
new_circuit = FrozenCircuit()
148-
new_circuit._moments = tuple(moments)
149-
return new_circuit
150-
151152
def _resolve_parameters_(
152153
self, resolver: 'cirq.ParamResolver', recursive: bool
153154
) -> 'cirq.FrozenCircuit':

cirq-core/cirq/circuits/frozen_circuit_test.py

+17
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,23 @@
2121
import cirq
2222

2323

24+
def test_from_moments():
25+
a, b, c, d = cirq.LineQubit.range(4)
26+
assert cirq.FrozenCircuit.from_moments(
27+
[cirq.X(a), cirq.Y(b)],
28+
[cirq.X(c)],
29+
[],
30+
cirq.Z(d),
31+
[cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')],
32+
) == cirq.FrozenCircuit(
33+
cirq.Moment(cirq.X(a), cirq.Y(b)),
34+
cirq.Moment(cirq.X(c)),
35+
cirq.Moment(),
36+
cirq.Moment(cirq.Z(d)),
37+
cirq.Moment(cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')),
38+
)
39+
40+
2441
def test_freeze_and_unfreeze():
2542
a, b = cirq.LineQubit.range(2)
2643
c = cirq.Circuit(cirq.X(a), cirq.H(b))

0 commit comments

Comments
 (0)