Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Strict check #179

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ v0.6.0 (Unreleased)
This will be set to the main clock when storing the dataset.
- Changed default ``fill_value`` in the zarr stores to maximum dtype value
for integer dtypes and ``np.nan`` for floating-point variables.
- Added the option to supply custom dependencies at model creation, e.g.
  ``xs.Model({"a": A, "b": B}, custom_dependencies={"a": "b"})``

v0.5.0 (26 January 2021)
------------------------
Expand Down
244 changes: 222 additions & 22 deletions xsimlab/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def get_processes_to_validate(self):

return {k: list(v) for k, v in processes_to_validate.items()}

def get_process_dependencies(self):
def get_process_dependencies(self, custom_dependencies={}):
"""Return a dictionary where keys are each process of the model and
values are lists of the names of dependent processes (or empty
lists for processes that have no dependencies).
Expand All @@ -423,6 +423,10 @@ def get_process_dependencies(self):
]
)

# actually add custom dependencies
for p_name, deps in custom_dependencies.items():
self._dep_processes[p_name].update(deps)

for p_name, p_obj in self._processes_obj.items():
for var in filter_variables(p_obj, intent=VarIntent.OUT).values():
if var.metadata["var_type"] == VarType.ON_DEMAND:
Expand Down Expand Up @@ -455,6 +459,7 @@ def _sort_processes(self):

"""
ordered = []
self._deps_dict = {p: set() for p in self._dep_processes}

# Nodes whose descendents have been completely explored.
# These nodes are guaranteed to not be part of a cycle.
Expand Down Expand Up @@ -484,18 +489,19 @@ def _sort_processes(self):
# Add direct descendants of cur to nodes stack
next_nodes = []
for nxt in self._dep_processes[cur]:
if nxt not in completed:
if nxt in seen:
# Cycle detected!
cycle = [nxt]
while nodes[-1] != nxt:
cycle.append(nodes.pop())
if nxt in seen:
# Cycle detected!
cycle = [nxt]
while nodes[-1] != nxt:
cycle.append(nodes.pop())
cycle.reverse()
cycle = "->".join(cycle)
raise RuntimeError(
f"Cycle detected in process graph: {cycle}"
)
cycle.append(nodes.pop())
cycle.reverse()
cycle = "->".join(cycle)
raise RuntimeError(f"Cycle detected in process graph: {cycle}")
if nxt in completed:
self._deps_dict[cur].add(nxt)
self._deps_dict[cur].update(self._deps_dict[nxt])
else:
next_nodes.append(nxt)

if next_nodes:
Expand All @@ -507,8 +513,142 @@ def _sort_processes(self):
completed.add(cur)
seen.remove(cur)
nodes.pop()

return ordered

def _strict_order_check(self):
    """Aggressively check that inout processes are explicitly ordered.

    IMPORTANT: ``_sort_processes`` should be run first, so that
    ``self._deps_dict`` (the transitive dependency sets built there) is
    populated.

    Checks that, for every variable with at least one process updating it
    (intent 'inout'), all inout processes and the corresponding 'in'
    processes are totally ordered through explicit dependencies.  Out
    variables always come first, since get_process_dependencies checks for
    that.  A well-behaved graph looks like: ``in0->inout1->in1->inout2->in2``

    Raises
    ------
    RuntimeError
        If the relative order of some inout/in processes for a variable
        cannot be established from the dependency graph.
    """
    # create dictionaries with all inout variables and input variables
    inout_dict = {}  # dict of {key:{p1_name,p2_name}} for inout variables
    # TODO: improve this: the aim is to create a {key:{p1,p2,p3}} dict,
    # where p1,p2,p3 are process names that have the key var as inout,
    # resp. in vars
    # some problems are that we can have on_demand and state variables,
    # and that key can return a tuple or list
    for p_name, p_obj in self._processes_obj.items():
        # create {key:{p1_name,p2_name}} dicts for in and inout vars.
        for var in filter_variables(p_obj, intent=VarIntent.INOUT).values():
            target_keys = self._get_var_key(p_name, var)
            inout_dict.setdefault(target_keys, set()).add(p_name)

    # map each inout-updated key to the set of processes that read it
    in_dict = {key: set() for key in inout_dict}
    for p_name, p_obj in self._processes_obj.items():
        for var in filter_variables(p_obj, intent=VarIntent.IN).values():
            target_keys = self._get_var_key(p_name, var)
            if target_keys in in_dict:
                in_dict[target_keys].add(p_name)

    # filter out variables that do not need to be checked (without inputs):
    # inout_dict = {k: v for k, v in inout_dict.items() if k in in_dict}

    for key, inout_ps in inout_dict.items():
        in_ps = in_dict[key]

        # inout processes confirmed to form a single chain so far
        # (bottom of the chain first)
        verified_ios = []
        # now we only have to search and verify all inout variables
        for io_p in inout_ps:
            # iterative depth-first traversal over inout processes
            io_stack = [io_p]
            while io_stack:
                cur = io_stack[-1]
                if cur in verified_ios:
                    # already proven to lie on the chain
                    io_stack.pop()
                    continue

                # inout processes (for this key) that cur depends on
                child_ios = self._deps_dict[cur].intersection(inout_ps - {cur})
                if child_ios:
                    if child_ios == set(verified_ios):
                        # cur sits directly on top of the verified chain:
                        # now check the 'in' processes it depends on
                        child_ins = in_ps.intersection(self._deps_dict[cur])
                        # verify that all children have the previous io as
                        # dependency
                        problem_children = {}
                        for child_in in child_ins:
                            # we want to list all processes that should
                            # depend on the previous
                            # io-io
                            #  /
                            # in
                            if not verified_ios[-1] in self._deps_dict[child_in]:
                                problem_children[child_in] = [
                                    p
                                    for p in verified_ios
                                    if p not in self._deps_dict[child_in]
                                ]
                        if problem_children:
                            raise RuntimeError(
                                f"While checking {key}, {cur} updates it"
                                f" and depends on some processes that use"
                                f" it, but they do not depend on {verified_ios[-1]}"
                                f". Place them somewhere between or before "
                                f"their values: {problem_children}"
                            )
                        # we can now safely remove these in nodes
                        in_ps -= child_ins
                        verified_ios.append(cur)
                        io_stack.pop()
                    elif child_ios - set(verified_ios):
                        # some inout dependencies of cur are not verified
                        # yet: explore them first (cur stays on the stack)
                        io_stack.extend(child_ios)
                    else:
                        # the problem here is that
                        # io-..-io
                        #  \
                        #   io
                        problem_ios = [
                            p for p in verified_ios if p not in child_ios
                        ]
                        raise RuntimeError(
                            f"while checking {key}, order of inout process "
                            f"{cur} compared to {problem_ios} could not be "
                            f"established. Place it before {problem_ios[0]}"
                            f" in between, or after {problem_ios[-1]}"
                        )
                else:
                    # we are at the bottom inout process: remove in
                    # variables from the set
                    # this can only happen if we are the first process at
                    # the bottom
                    if verified_ios:
                        # the problem here is
                        # io->..->io
                        #  /
                        # io
                        # NOTE(review): problem_ios is computed here but not
                        # used in the message below (which uses
                        # verified_ios[:-1] instead) — looks unintended;
                        # confirm which list the message should show.
                        problem_ios = [
                            p for p in verified_ios if cur not in self._deps_dict[p]
                        ]
                        raise RuntimeError(
                            f"While checking {key}, inout process "
                            f"{verified_ios[-1]} has two branch dependencies."
                            f" Place {cur} before, after or somewhere between "
                            f"{verified_ios[:-1]}"
                        )
                    in_ps -= self._deps_dict[cur]
                    verified_ios.append(cur)
                    io_stack.pop()

        # we finished all inout, and inputs that are descendants of inout
        # vars, so all remaining input vars should depend on the last inout
        # var
        problem_ins = {}
        for p in in_ps:
            if not verified_ios[-1] in self._deps_dict[p]:
                problem_ins[p] = [
                    prob for prob in verified_ios if prob not in self._deps_dict[p]
                ]

        if problem_ins:
            #
            # io->io->io->io
            #       \    \
            #        in  in in
            raise RuntimeError(
                f"while checking {key}, some input processes do not depend "
                f"on {verified_ios[-1]}, with all inout processes {verified_ios}"
                f" place them somewhere in between, before or after their values: {problem_ins}"
            )

def get_sorted_processes(self):
self._sorted_processes = OrderedDict(
[(p_name, self._processes_obj[p_name]) for p_name in self._sort_processes()]
Expand All @@ -523,8 +663,9 @@ class Model(AttrMapping):
This collection is ordered such that the computational flow is
consistent with process inter-dependencies.

Ordering doesn't need to be explicitly provided ; it is dynamically
computed using the processes interfaces.
Ordering doesn't always need to be explicitly provided ; it is dynamically
computed using the processes interfaces. For other cases, custom
dependencies can be supplied.

Processes interfaces are also used for automatically retrieving
the model inputs, i.e., all the variables that require setting a
Expand All @@ -534,17 +675,25 @@ class Model(AttrMapping):

active = []

def __init__(self, processes):
def __init__(self, processes, custom_dependencies={}, strict_order_check=False):
"""
Parameters
----------
processes : dict
Dictionnary with process names as keys and classes (decorated with
Dictionary with process names as keys and classes (decorated with
:func:`process`) as values.
custom_dependencies : dict
Dictionary of custom dependencies.
Keys are process names and values are iterables of the names of the
processes that each one depends on.
strict_order_check : bool
If True, aggressively check for correct ordering. (default: False)
For a variable with processes for which it is an inout variable, it
should look like: ``ins0->inout1->ins1->inout2->ins2``

Raises
------
:exc:`NoteAProcessClassError`
:exc:`NotAProcessClassError`
If values in ``processes`` are not classes decorated with
:func:`process`.

Expand Down Expand Up @@ -572,9 +721,21 @@ def __init__(self, processes):

self._processes_to_validate = builder.get_processes_to_validate()

self._dep_processes = builder.get_process_dependencies()
# clean custom dependencies
self._custom_dependencies = {}
for p_name, c_deps in custom_dependencies.items():
c_deps = {c_deps} if isinstance(c_deps, str) else set(c_deps)
self._custom_dependencies[p_name] = c_deps

self._dep_processes = builder.get_process_dependencies(
self._custom_dependencies
)
self._processes = builder.get_sorted_processes()

self._strict_order_check = strict_order_check
if self._strict_order_check:
builder._strict_order_check()

super(Model, self).__init__(self._processes)
self._initialized = True

Expand Down Expand Up @@ -1065,7 +1226,7 @@ def drop_processes(self, keys):

Parameters
----------
keys : str or list of str
keys : str or iterable of str
Name(s) of the processes to drop.

Returns
Expand All @@ -1074,13 +1235,52 @@ def drop_processes(self, keys):
New Model instance with dropped processes.

"""
if isinstance(keys, str):
keys = [keys]
keys = {keys} if isinstance(keys, str) else set(keys)

processes_cls = {
k: type(obj) for k, obj in self._processes.items() if k not in keys
}
return type(self)(processes_cls)

# we also should check for chains of deps e.g.
# a->b->c->d->e where {b,c,d} are removed
# then we have a->e left over.
# perform a depth-first search on custom dependencies
# and let the custom deps propagate forward
completed = set()
for key in self._custom_dependencies:
if key in completed:
continue
key_stack = [key]
while key_stack:
cur = key_stack[-1]
if cur in completed:
key_stack.pop()
continue

# if we have custom dependencies that are removed
# and are fully traversed, add their deps to the current
child_keys = keys.intersection(self._custom_dependencies[cur])
if child_keys.issubset(completed):
# all children are added, so we are safe
self._custom_dependencies[cur].update(
*[
self._custom_dependencies[child_key]
for child_key in child_keys
]
)
self._custom_dependencies[cur] -= child_keys
completed.add(cur)
key_stack.pop()
else: # if child_keys - completed:
# we need to search deeper: add to the stack.
key_stack.extend([k for k in child_keys - completed])

# now also remove keys from custom deps
for key in keys:
if key in self._custom_dependencies:
del self._custom_dependencies[key]

return type(self)(processes_cls, self._custom_dependencies)

def __eq__(self, other):
if not isinstance(other, self.__class__):
Expand Down
Loading