diff --git a/dwave/optimization/_model.pyi b/dwave/optimization/_model.pyi index 9199cf1e..d7a97d4e 100644 --- a/dwave/optimization/_model.pyi +++ b/dwave/optimization/_model.pyi @@ -14,6 +14,7 @@ import collections.abc import contextlib +import os import tempfile import typing @@ -41,14 +42,14 @@ class _Graph: @classmethod def from_file( cls: typing.Type[_GraphSubclass], - file: typing.Union[typing.BinaryIO, collections.abc.ByteString, str], + file: typing.Union[bytes, os.PathLike, str, typing.BinaryIO], *, check_header: bool = True, ) -> _GraphSubclass: ... def into_file( self, - file: typing.Union[typing.BinaryIO, collections.abc.ByteString, str], + file: typing.Union[bytes, os.PathLike, str, typing.BinaryIO], *, max_num_states: int = 0, only_decision: bool = False, diff --git a/dwave/optimization/_model.pyx b/dwave/optimization/_model.pyx index 08468640..04f78bea 100644 --- a/dwave/optimization/_model.pyx +++ b/dwave/optimization/_model.pyx @@ -35,6 +35,7 @@ from dwave.optimization.libcpp.graph cimport DecisionNode as cppDecisionNode from dwave.optimization.states cimport States from dwave.optimization.states import StateView from dwave.optimization.symbols cimport symbol_from_ptr +from dwave.optimization.utilities import _file_object_arg, _lock __all__ = [] @@ -139,6 +140,7 @@ cdef class _Graph: return sum(sym.state_size() for sym in self.iter_decisions()) @classmethod + @_file_object_arg("rb") # translate str/bytes file inputs into file objects def from_file(cls, file, *, check_header = True, ): @@ -157,10 +159,6 @@ cdef class _Graph: """ import dwave.optimization.symbols as symbols - if isinstance(file, str): - with open(file, "rb") as f: - return cls.from_file(f) - version, header_data = _Graph._from_file_header(file) cdef _Graph model = cls() @@ -282,6 +280,8 @@ cdef class _Graph: num_states=num_states, ) + @_file_object_arg("wb") # translate str/bytes file inputs into file objects + @_lock def into_file(self, file, *, Py_ssize_t max_num_states = 0, bool only_decision = False, @@ -309,25 +309,6 @@ cdef class _Graph: TODO: describe the format """ - if not self.is_locked(): - # lock for the duration of the method - with self.lock(): - return self.into_file( - file, - max_num_states=max_num_states, - only_decision=only_decision, - version=version, - ) - - if isinstance(file, str): - with open(file, "wb") as f: - return self.into_file( - f, - max_num_states=max_num_states, - only_decision=only_decision, - version=version, - ) - version, model_info = self._into_file_header( file, version=version, diff --git a/dwave/optimization/states.pyi b/dwave/optimization/states.pyi index debdb486..6a342c3f 100644 --- a/dwave/optimization/states.pyi +++ b/dwave/optimization/states.pyi @@ -13,6 +13,7 @@ # limitations under the License. import collections.abc +import os import typing from dwave.optimization.model import Model @@ -25,7 +26,7 @@ class States: def from_file( self, - file: typing.Union[typing.BinaryIO, collections.abc.ByteString, str], + file: typing.Union[bytes, os.PathLike, str, typing.BinaryIO], *, replace: bool = True, check_header: bool = True, @@ -36,7 +37,7 @@ class States: def into_file( self, - file: typing.Union[typing.BinaryIO, collections.abc.ByteString, str], + file: typing.Union[bytes, os.PathLike, str, typing.BinaryIO], version: typing.Optional[tuple[int, int]] = None, ): ... diff --git a/dwave/optimization/states.pyx b/dwave/optimization/states.pyx index c9eaccb9..c7ac4df3 100644 --- a/dwave/optimization/states.pyx +++ b/dwave/optimization/states.pyx @@ -19,6 +19,7 @@ from libcpp.utility cimport move from dwave.optimization.libcpp.array cimport Array as cppArray from dwave.optimization.model cimport ArraySymbol, _Graph from dwave.optimization.model import Model +from dwave.optimization.utilities import _file_object_arg __all__ = ["States"] @@ -137,6 +138,7 @@ cdef class States: self._states.swap(states) return move(states) + @_file_object_arg("rb") # translate str/bytes file inputs into file objects def from_file(self, file, *, replace = True, check_header = True): """Construct states from the given file. @@ -200,6 +202,7 @@ cdef class States: self._states[i].resize(model.num_nodes()) model._graph.initialize_state(self._states[i]) + @_file_object_arg("wb") # translate str/bytes file inputs into file objects def into_file(self, file, *, version = None, ): diff --git a/dwave/optimization/utilities.py b/dwave/optimization/utilities.py new file mode 100644 index 00000000..077d4f79 --- /dev/null +++ b/dwave/optimization/utilities.py @@ -0,0 +1,50 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import functools + +__all__ = [] + + +def _file_object_arg(mode: str): + """In several methods we want to accept a file name or a file-like object. + + The ``mode`` argument is the same as for ``open()``. + + This method assumes that the file argument is the first one. We could + generalize if we need to. + """ + def decorator(method): + @functools.wraps(method) + def _method(cls_or_self, file, *args, **kwargs): + if isinstance(file, (str, bytes, os.PathLike)): + with open(os.fspath(file), mode) as fobj: + return method(cls_or_self, fobj, *args, **kwargs) + else: + return method(cls_or_self, file, *args, **kwargs) + return _method + return decorator + + +def _lock(method): + """Decorator for Model methods that lock the model for the duration.""" + @functools.wraps(method) + def _method(obj, *args, **kwargs): + if not obj.is_locked(): + with obj.lock(): + return method(obj, *args, **kwargs) + else: + return method(obj, *args, **kwargs) + return _method diff --git a/tests/test_model.py b/tests/test_model.py index 0592701f..4d51505f 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -15,6 +15,7 @@ import io import operator import os.path +import pathlib import tempfile import unittest @@ -454,21 +455,56 @@ def test(self): np.testing.assert_array_equal(a.state(2), range(5)) def test_by_filename(self): - model = Model() - c = model.constant([0, 1, 2, 3, 4]) - x = model.list(5) - model.minimize(c[x].sum()) + with self.subTest("bytes"): + model = Model() + c = model.constant([0, 1, 2, 3, 4]) + x = model.list(5) + model.minimize(c[x].sum()) + + with tempfile.TemporaryDirectory() as dirname: + fname = os.path.join(dirname, "temp.nl") + + model.into_file(fname.encode("ascii")) + + new = Model.from_file(fname) + + # todo: use a model equality test function once we have it + for n0, n1 in zip(model.iter_symbols(), new.iter_symbols()): + self.assertIs(type(n0), type(n1)) + + with self.subTest("path-like"): + model = Model() + c = model.constant([0, 1, 2, 3, 4]) + x = model.list(5) + model.minimize(c[x].sum()) + + with tempfile.TemporaryDirectory() as dirname: + fname = pathlib.PurePath(os.path.join(dirname, "temp.nl")) + + model.into_file(fname) + + new = Model.from_file(fname) + + # todo: use a model equality test function once we have it + for n0, n1 in zip(model.iter_symbols(), new.iter_symbols()): + self.assertIs(type(n0), type(n1)) + + with self.subTest("str"): + model = Model() + c = model.constant([0, 1, 2, 3, 4]) + x = model.list(5) + model.minimize(c[x].sum()) - with tempfile.TemporaryDirectory() as dirname: - fname = os.path.join(dirname, "temp.nl") + with tempfile.TemporaryDirectory() as dirname: + fname = os.path.join(dirname, "temp.nl") - model.into_file(fname) + model.into_file(fname) - new = Model.from_file(fname) + new = Model.from_file(fname) - # todo: use a model equality test function once we have it - for n0, n1 in zip(model.iter_symbols(), new.iter_symbols()): - self.assertIs(type(n0), type(n1)) + # todo: use a model equality test function once we have it + for n0, n1 in zip(model.iter_symbols(), new.iter_symbols()): + self.assertIs(type(n0), type(n1)) def test_invalid_version_from_file(self): from dwave.optimization._model import DEFAULT_SERIALIZATION_VERSION