Skip to content

Commit da6afc4

Browse files
authored
Merge pull request #60 from PlasmaControl/main
Update the test structure in dev branch.
2 parents 979681f + 4534751 commit da6afc4

File tree

9 files changed

+87
-45
lines changed

9 files changed

+87
-45
lines changed

.github/workflows/python-test-push.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jobs:
88
runs-on: ubuntu-latest
99
strategy:
1010
matrix:
11-
python-version: [3.7, 3.8, 3.9] # 3
11+
python-version: [3.9, 3.11, 3.12] # 3
1212
steps:
1313
- uses: actions/checkout@v2
1414
- name: Set up Python ${{ matrix.python-version }}
@@ -18,8 +18,8 @@ jobs:
1818
- name: Install dependencies
1919
run: |
2020
python -m pip install --upgrade pip
21-
pip install pytest==5.4.1 flake8==4.0.1 pytest-flake8 mypy pytest-mypy pytest-cov \
22-
pytest-pep257 types-setuptools
21+
pip install pytest flake8 pytest-flake8-v2 mypy pytest-mypy pytest-cov \
22+
types-setuptools
2323
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
2424
- name: Lint with flake8
2525
run: |

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ Please see the [References section](#references) for more information. For code
3939

4040
### Prerequisites
4141

42-
PyRCN is developed using Python 3.7 or newer. It depends on the following packages:
42+
PyRCN is developed using Python 3.9 or newer. It depends on the following packages:
4343

4444
- [numpy>=1.18.1](https://numpy.org/)
45-
- [scikit-learn>=1.0](https://scikit-learn.org/stable/)
45+
- [scikit-learn>=1.4](https://scikit-learn.org/stable/)
4646
- [joblib>=0.13.2](https://joblib.readthedocs.io)
4747
- [pandas>=1.0.0](https://pandas.pydata.org/)
4848
- [matplotlib](https://matplotlib.org/)

pytest.ini

+9-8
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
# pytest.ini
22
[pytest]
3-
minversion = 5.0
3+
minversion = 6.5.0
44
addopts =
55
-ra -q -v
66
--doctest-modules
7-
# --junitxml=junit/test-results.xml
8-
# --cov=src/pyrcn
9-
# --cov-branch
10-
# --cov-report=xml
11-
# --cov-report=html
12-
# --cov-report=term-missing
7+
--junitxml=junit/test-results.xml
8+
--cov=src/pyrcn
9+
--cov-branch
10+
--cov-report=xml
11+
--cov-report=html
12+
--cov-report=term-missing
13+
# --pep257
1314
# --flake8
14-
# --mypy
15+
--mypy
1516

1617
testpaths =
1718
src/pyrcn

requirements.txt

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
setuptools>=46.4.0
2+
scikit-learn>=1.4
3+
numpy>=1.18.1
4+
scipy>=1.4.0
5+
joblib>=0.13.2
6+
pandas>=1.0.0
7+
typing-extensions
8+
requests[matplotlib]
9+
requests[seaborn]
10+
requests[ipywidgets]
11+
requests[ipympl]
12+
requests[tqdm]

setup.cfg

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
[metadata]
2+
name = PyRCN
3+
version = 0.0.18
4+
author = Peter Steiner
5+
author_email = [email protected]
6+
description = A scikit-learn-compatible framework for Reservoir Computing in Python
7+
long_description = file: README.md
8+
long_description_content_type = text/markdown
9+
url = https://github.com/TUD-STKS/PyRCN
10+
project_urls =
11+
Documentation = https://pyrcn.readthedocs.io/
12+
Funding = https://pyrcn.net/
13+
Source = https://github.com/TUD-STKS/PyRCN/
14+
Bug Tracker = https://github.com/TUD-STKS/PyRCN/issues
15+
classifiers =
16+
Programming Language :: Python :: 3
17+
Development Status :: 2 - Pre-Alpha
18+
License :: OSI Approved :: BSD License
19+
Operating System :: OS Independent
20+
Intended Audience :: Science/Research
21+
22+
[options]
23+
package_dir =
24+
= src
25+
packages = find:
26+
python_requires = >=3.9
27+
28+
[options.packages.find]
29+
where = src
30+
31+
[tool:pytest]
32+
testpaths = tests

src/pyrcn/base/_base.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,14 @@ def _make_sparse(k_in: int, dense_weights: np.ndarray,
9191
The sparse layer weights
9292
"""
9393
n_inputs, n_outputs = dense_weights.shape
94-
nr_entries = int(n_inputs * k_in)
95-
indices = np.zeros(shape=nr_entries, dtype=int)
96-
indptr = np.arange(start=0, stop=(n_inputs + 1) * k_in, step=k_in)
97-
dense_weights = dense_weights.flatten()[:nr_entries]
98-
99-
for en in range(0, n_inputs * k_in, k_in):
100-
indices[en: en + k_in] = \
101-
random_state.permutation(n_outputs)[:k_in].astype(int)
102-
return scipy.sparse.csr_matrix(
103-
(dense_weights, indices, indptr), shape=(n_inputs, n_outputs),
104-
dtype='float64')
94+
95+
for neuron in range(n_outputs):
96+
all_indices = np.arange(n_inputs)
97+
keep_indices = np.random.choice(n_inputs, k_in, replace=False)
98+
zero_indices = np.setdiff1d(all_indices, keep_indices)
99+
dense_weights[zero_indices, neuron] = 0
100+
101+
return scipy.sparse.csr_matrix(dense_weights, dtype='float64')
105102

106103

107104
def _normal_random_weights(

src/pyrcn/model_selection/_search.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from sklearn.base import BaseEstimator, is_classifier, clone
1010
from sklearn.model_selection._search import BaseSearchCV
11-
from sklearn.utils.validation import indexable, _check_fit_params
11+
from sklearn.utils.validation import indexable, _check_method_params
1212
from sklearn.model_selection._split import check_cv
1313

1414
import numpy as np
@@ -510,7 +510,7 @@ def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None, *,
510510
func = self.func
511511

512512
X, y, groups = indexable(X, y, groups)
513-
fit_params = _check_fit_params(X, fit_params)
513+
fit_params = _check_method_params(X, fit_params)
514514

515515
cv_orig = check_cv(self.cv, y, classifier=is_classifier(estimator))
516516
n_splits = cv_orig.get_n_splits(X, y, groups)

tests/test_esn.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -127,38 +127,38 @@ def test_esn_output_unchanged() -> None:
127127
esn = ESNClassifier(hidden_layer_size=50).fit(X, y)
128128
print(esn)
129129
shape2 = y[0].shape
130-
assert(shape1 == shape2)
130+
assert (shape1 == shape2)
131131

132132

133133
def test_esn_classifier_sequence_to_value() -> None:
134134
X, y = load_digits(return_X_y=True, as_sequence=True)
135135
esn = ESNClassifier(hidden_layer_size=50).fit(X, y)
136136
y_pred = esn.predict(X)
137-
assert(len(y) == len(y_pred))
138-
assert(len(y_pred[0]) == 1)
139-
assert(esn.sequence_to_value is True)
140-
assert(esn.decision_strategy == "winner_takes_all")
137+
assert (len(y) == len(y_pred))
138+
assert (len(y_pred[0]) == 1)
139+
assert (esn.sequence_to_value is True)
140+
assert (esn.decision_strategy == "winner_takes_all")
141141
y_pred = esn.predict_proba(X)
142-
assert(y_pred[0].ndim == 1)
142+
assert (y_pred[0].ndim == 1)
143143
y_pred = esn.predict_log_proba(X)
144-
assert(y_pred[0].ndim == 1)
144+
assert (y_pred[0].ndim == 1)
145145
esn.sequence_to_value = False
146146
y_pred = esn.predict(X)
147-
assert(len(y_pred[0]) == 8)
147+
assert (len(y_pred[0]) == 8)
148148
y_pred = esn.predict_proba(X)
149-
assert(y_pred[0].ndim == 2)
149+
assert (y_pred[0].ndim == 2)
150150
y_pred = esn.predict_log_proba(X)
151-
assert(y_pred[0].ndim == 2)
151+
assert (y_pred[0].ndim == 2)
152152

153153

154154
def test_esn_classifier_instance_fit() -> None:
155155
X, y = load_digits(return_X_y=True, as_sequence=True)
156156
esn = ESNClassifier(hidden_layer_size=50).fit(X[0], np.repeat(y[0], 8))
157-
assert(esn.sequence_to_value is False)
157+
assert (esn.sequence_to_value is False)
158158
y_pred = esn.predict_proba(X[0])
159-
assert(y_pred.ndim == 2)
159+
assert (y_pred.ndim == 2)
160160
y_pred = esn.predict_log_proba(X[0])
161-
assert(y_pred.ndim == 2)
161+
assert (y_pred.ndim == 2)
162162

163163

164164
def test_esn_classifier_partial_fit() -> None:

tests/test_model_selection.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ def test_sequentialSearchCV_equivalence() -> None:
3131
('gs3', GridSearchCV, param_grid1)]).fit(X, y)
3232
assert gs1.best_params_ == ss.all_best_params_['gs1']
3333
assert gs2.best_params_ == ss.all_best_params_['gs2']
34-
assert(isinstance(ss.cv_results_, dict))
35-
assert(ss.best_estimator_ is not None)
36-
assert(isinstance(ss.best_score_, float))
34+
assert (isinstance(ss.cv_results_, dict))
35+
assert (ss.best_estimator_ is not None)
36+
assert (isinstance(ss.best_score_, float))
3737
print(ss.best_index_)
38-
assert(isinstance(ss.n_splits_, int))
39-
assert(isinstance(ss.refit_time_, float))
40-
assert(isinstance(ss.multimetric, bool))
38+
assert (isinstance(ss.n_splits_, int))
39+
assert (isinstance(ss.refit_time_, float))
40+
assert (isinstance(ss.multimetric, bool))
4141

4242

4343
@pytest.mark.skip(reason="no way of currently testing this")

0 commit comments

Comments
 (0)