-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathtest_base.py
84 lines (62 loc) · 2.69 KB
/
test_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from typing import Any, Dict, Tuple, List, Optional
from pathlib import Path
from proxystore.connectors.file import FileConnector
from proxystore.store import Store
from proxystore.store import register_store
from proxystore.store import unregister_store
from pytest import fixture
from colmena.models import Result, ExecutableTask, SerializationMethod
from colmena.task_server.base import run_and_record_timing
# TODO (wardlt): Figure out how to import this from test_models
class EchoTask(ExecutableTask):
    """Executable task that runs ``echo`` with the task's positional arguments.

    Used by ``test_run_with_executable`` to check that running an
    :class:`ExecutableTask` captures what the program wrote to stdout.
    """

    def __init__(self):
        super().__init__(executable=['echo'])

    def preprocess(self, run_dir: Path, args: Tuple[Any, ...],
                   kwargs: Dict[str, Any]) -> Tuple[List[str], Optional[str]]:
        """Turn the positional arguments into extra command-line arguments.

        Args:
            run_dir: Directory the task runs in (unused here).
            args: Positional task inputs; any number of them, each converted with ``str``.
            kwargs: Keyword task inputs (ignored).
        Returns:
            The stringified arguments, and ``None`` as the second element
            (presumably "no stdin" — confirm against ``ExecutableTask``).
        """
        return [str(arg) for arg in args], None

    def postprocess(self, run_dir: Path) -> Any:
        """Return the text of ``colmena.stdout`` inside ``run_dir``."""
        return (run_dir / 'colmena.stdout').read_text()
class FakeMPITask(ExecutableTask):
    """Executable task configured as MPI, but whose MPI command is itself ``echo``.

    The executable is ``echo -n`` and ``mpi_command_string`` starts with
    ``echo``. Presumably the formatted MPI command is prepended to the
    executable, so the task prints its launch line instead of starting MPI.
    Verify against ``ExecutableTask``.
    """

    def __init__(self):
        super().__init__(executable=['echo', '-n'],
                         mpi=True,
                         # {total_ranks} and {cpu_processes} are template fields filled in by ExecutableTask
                         mpi_command_string='echo -N {total_ranks} -n {cpu_processes} --cc depth')

    def preprocess(self, run_dir: Path, args: Tuple[Any, ...],
                   kwargs: Dict[str, Any]) -> Tuple[List[str], Optional[str]]:
        """Turn the positional arguments into extra command-line arguments.

        Args:
            run_dir: Directory the task runs in (unused here).
            args: Positional task inputs; any number of them, each converted with ``str``.
            kwargs: Keyword task inputs (ignored).
        Returns:
            The stringified arguments, and ``None`` as the second element
            (presumably "no stdin" — confirm against ``ExecutableTask``).
        """
        return [str(arg) for arg in args], None

    def postprocess(self, run_dir: Path) -> Any:
        """Return the text of ``colmena.stdout`` inside ``run_dir``."""
        return (run_dir / 'colmena.stdout').read_text()
def test_run_with_executable():
    """Running an ExecutableTask via run_and_record_timing yields its stdout as the value."""
    outcome = Result(inputs=((1,), {}))
    task = EchoTask()

    run_and_record_timing(task, outcome)

    # The value is only readable after deserialization
    outcome.deserialize()
    assert outcome.value == '1\n'
@fixture
def store(tmpdir):
    """Provide a file-backed ProxyStore ``Store`` with metrics enabled.

    The store is registered globally while the test runs and unregistered
    afterwards, before the ``with`` block closes it.
    """
    connector = FileConnector(tmpdir)
    with Store('store', connector, metrics=True) as file_store:
        register_store(file_store)
        yield file_store
        unregister_store(file_store)
def test_run_function(store):
    """Check that run_and_record_timing behaves as expected.

    - Records runtimes
    - Tracks proxy statistics
    """
    payload = 'a' * 1024
    result = Result(inputs=((payload,), {}))

    # Point the result at the test store; the 1024-char payload exceeds the
    # 128 threshold (presumably a size cutoff for proxying — confirm)
    result.proxystore_name = store.name
    result.proxystore_threshold = 128
    result.proxystore_config = store.config()

    result.serialization_method = SerializationMethod.PICKLE
    result.serialize()

    run_and_record_timing(lambda text: text.upper(), result)

    # Every timing stage should have been recorded
    timings = result.time
    for stage in ('running', 'async_resolve_proxies', 'deserialize_inputs', 'serialize_results'):
        assert getattr(timings, stage) > 0, stage
    stamps = result.timestamp
    assert stamps.compute_ended > stamps.compute_started

    # Two proxies are expected (presumably the input and the output)
    proxy_stats = result.time.proxy
    assert len(proxy_stats) == 2
    assert all('store.proxy' in stats['times'] for stats in proxy_stats.values())