Skip to content

Commit 33d27c2

Browse files
committedSep 21, 2020
Multi-Input-Output
1 parent 72a01ff commit 33d27c2

File tree

3 files changed

+22
-1
lines changed

3 files changed

+22
-1
lines changed
 

Diff for: ‎.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ add_header.py
1414
*fair_cluster*
1515
fastmri_dirs.yaml
1616
dataset_cache.pkl
17+
devnotes.md

Diff for: ‎experimental/varnet/train_varnet_demo.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def build_args():
4848
# ------------------------
4949
path_config = pathlib.Path.cwd() / ".." / ".." / "fastmri_dirs.yaml"
5050
knee_path = fetch_dir("knee_path", path_config)
51+
brain_path = fetch_dir("brain_path", path_config)
5152
logdir = fetch_dir("log_path", path_config) / "varnet" / "varnet_demo"
5253

5354
parent_parser = ArgumentParser(add_help=False)
@@ -73,7 +74,7 @@ def build_args():
7374
lr_step_size=40,
7475
lr_gamma=0.1,
7576
weight_decay=0.0,
76-
data_path=knee_path,
77+
data_path=brain_path,
7778
challenge="multicoil",
7879
exp_dir=logdir,
7980
exp_name="varnet_demo",
@@ -90,6 +91,7 @@ def build_args():
9091
distributed_backend=backend,
9192
seed=42,
9293
deterministic=True,
94+
overfit_batches=1
9395
)
9496

9597
parser.add_argument("--mode", default="train", type=str)

Diff for: ‎fastmri/mri_module.py

+18
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,27 @@ def validation_step_end(self, val_logs):
183183
val_logs = {key: value.cpu() for key, value in val_logs.items()}
184184
val_logs["device"] = device
185185

186+
print("End of Validation Step has been reached!...Saving this val log now!")
187+
188+
for k, v in val_logs.items():
189+
if k == "device":
190+
continue
191+
print(k, v.ndim)
192+
186193
return val_logs
187194

188195
def validation_epoch_end(self, val_logs):
196+
197+
for k, v in val_logs[0].items():
198+
if k == "device":
199+
continue
200+
print(k, val_logs[0][f"{k}"].ndim)
201+
202+
203+
for output in val_logs[0]["output"]:
204+
print("outputs: ", len(output))
205+
206+
189207
assert val_logs[0]["output"].ndim == 3
190208
device = val_logs[0]["device"]
191209

0 commit comments

Comments
 (0)
Please sign in to comment.