Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some utility decorators for Model/States methods #247

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dwave/optimization/_model.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import collections.abc
import contextlib
import os
import tempfile
import typing

Expand Down Expand Up @@ -41,14 +42,14 @@ class _Graph:
@classmethod
def from_file(
cls: typing.Type[_GraphSubclass],
file: typing.Union[typing.BinaryIO, collections.abc.ByteString, str],
file: typing.Union[bytes, os.PathLike, str, typing.BinaryIO],
*,
check_header: bool = True,
) -> _GraphSubclass: ...

def into_file(
self,
file: typing.Union[typing.BinaryIO, collections.abc.ByteString, str],
file: typing.Union[bytes, os.PathLike, str, typing.BinaryIO],
*,
max_num_states: int = 0,
only_decision: bool = False,
Expand Down
27 changes: 4 additions & 23 deletions dwave/optimization/_model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ from dwave.optimization.libcpp.graph cimport DecisionNode as cppDecisionNode
from dwave.optimization.states cimport States
from dwave.optimization.states import StateView
from dwave.optimization.symbols cimport symbol_from_ptr
from dwave.optimization.utilities import _file_object_arg, _lock

__all__ = []

Expand Down Expand Up @@ -139,6 +140,7 @@ cdef class _Graph:
return sum(sym.state_size() for sym in self.iter_decisions())

@classmethod
@_file_object_arg("rb") # translate str/bytes file inputs into file objects
def from_file(cls, file, *,
check_header = True,
):
Expand All @@ -157,10 +159,6 @@ cdef class _Graph:
"""
import dwave.optimization.symbols as symbols

if isinstance(file, str):
with open(file, "rb") as f:
return cls.from_file(f)

version, header_data = _Graph._from_file_header(file)

cdef _Graph model = cls()
Expand Down Expand Up @@ -282,6 +280,8 @@ cdef class _Graph:
num_states=num_states,
)

@_file_object_arg("wb") # translate str/bytes file inputs into file objects
@_lock
def into_file(self, file, *,
Py_ssize_t max_num_states = 0,
bool only_decision = False,
Expand Down Expand Up @@ -309,25 +309,6 @@ cdef class _Graph:

TODO: describe the format
"""
if not self.is_locked():
# lock for the duration of the method
with self.lock():
return self.into_file(
file,
max_num_states=max_num_states,
only_decision=only_decision,
version=version,
)

if isinstance(file, str):
with open(file, "wb") as f:
return self.into_file(
f,
max_num_states=max_num_states,
only_decision=only_decision,
version=version,
)

version, model_info = self._into_file_header(
file,
version=version,
Expand Down
5 changes: 3 additions & 2 deletions dwave/optimization/states.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import collections.abc
import os
import typing

from dwave.optimization.model import Model
Expand All @@ -25,7 +26,7 @@ class States:

def from_file(
self,
file: typing.Union[typing.BinaryIO, collections.abc.ByteString, str],
file: typing.Union[bytes, os.PathLike, str, typing.BinaryIO],
*,
replace: bool = True,
check_header: bool = True,
Expand All @@ -36,7 +37,7 @@ class States:

def into_file(
self,
file: typing.Union[typing.BinaryIO, collections.abc.ByteString, str],
file: typing.Union[bytes, os.PathLike, str, typing.BinaryIO],
version: typing.Optional[tuple[int, int]] = None,
): ...

Expand Down
3 changes: 3 additions & 0 deletions dwave/optimization/states.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ from libcpp.utility cimport move
from dwave.optimization.libcpp.array cimport Array as cppArray
from dwave.optimization.model cimport ArraySymbol, _Graph
from dwave.optimization.model import Model
from dwave.optimization.utilities import _file_object_arg

__all__ = ["States"]

Expand Down Expand Up @@ -137,6 +138,7 @@ cdef class States:
self._states.swap(states)
return move(states)

@_file_object_arg("rb") # translate str/bytes file inputs into file objects
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are redundant for now, but will be used in #241

def from_file(self, file, *, replace = True, check_header = True):
"""Construct states from the given file.

Expand Down Expand Up @@ -200,6 +202,7 @@ cdef class States:
self._states[i].resize(model.num_nodes())
model._graph.initialize_state(self._states[i])

@_file_object_arg("wb") # translate str/bytes file inputs into file objects
def into_file(self, file, *,
version = None,
):
Expand Down
50 changes: 50 additions & 0 deletions dwave/optimization/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import functools

__all__ = []


def _file_object_arg(mode: str):
"""In several methods we want to accept a file name or a file-like object.

The ``mode`` argument is the same as for ``open()``.

This method assumes that the file argument is the first one. We could
generalize if we need to.
"""
def decorator(method):
@functools.wraps(method)
def _method(cls_or_self, file, *args, **kwargs):
if isinstance(file, (str, bytes, os.PathLike)):
with open(os.fspath(file), mode) as fobj:
return method(cls_or_self, fobj, *args, **kwargs)
else:
return method(cls_or_self, file, *args, **kwargs)
return _method
return decorator


def _lock(method):
"""Decorator for Model methods that lock the model for the duration."""
@functools.wraps(method)
def _method(obj, *args, **kwargs):
if not obj.is_locked():
with obj.lock():
return method(obj, *args, **kwargs)
else:
return method(obj, *args, **kwargs)
return _method
58 changes: 47 additions & 11 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io
import operator
import os.path
import pathlib
import tempfile
import unittest

Expand Down Expand Up @@ -454,21 +455,56 @@ def test(self):
np.testing.assert_array_equal(a.state(2), range(5))

def test_by_filename(self):
model = Model()
c = model.constant([0, 1, 2, 3, 4])
x = model.list(5)
model.minimize(c[x].sum())
with self.subTest("bytes"):
model = Model()
c = model.constant([0, 1, 2, 3, 4])
x = model.list(5)
model.minimize(c[x].sum())

with tempfile.TemporaryDirectory() as dirname:
fname = os.path.join(dirname, "temp.nl")

model.into_file(fname.encode("ascii"))

new = Model.from_file(fname)

# todo: use a model equality test function once we have it
for n0, n1 in zip(model.iter_symbols(), new.iter_symbols()):
self.assertIs(type(n0), type(n1))

with self.subTest("path-like"):
model = Model()
c = model.constant([0, 1, 2, 3, 4])
x = model.list(5)
model.minimize(c[x].sum())

with tempfile.TemporaryDirectory() as dirname:
fname = pathlib.PurePath(os.path.join(dirname, "temp.nl"))

model.into_file(fname)

new = Model.from_file(fname)

# todo: use a model equality test function once we have it
for n0, n1 in zip(model.iter_symbols(), new.iter_symbols()):
self.assertIs(type(n0), type(n1))

with self.subTest("str"):
model = Model()
c = model.constant([0, 1, 2, 3, 4])
x = model.list(5)
model.minimize(c[x].sum())

with tempfile.TemporaryDirectory() as dirname:
fname = os.path.join(dirname, "temp.nl")
with tempfile.TemporaryDirectory() as dirname:
fname = os.path.join(dirname, "temp.nl")

model.into_file(fname)
model.into_file(fname)

new = Model.from_file(fname)
new = Model.from_file(fname)

# todo: use a model equality test function once we have it
for n0, n1 in zip(model.iter_symbols(), new.iter_symbols()):
self.assertIs(type(n0), type(n1))
# todo: use a model equality test function once we have it
for n0, n1 in zip(model.iter_symbols(), new.iter_symbols()):
self.assertIs(type(n0), type(n1))

def test_invalid_version_from_file(self):
from dwave.optimization._model import DEFAULT_SERIALIZATION_VERSION
Expand Down