Skip to content

Commit 1d4c83f

Browse files
committed
Add scripts for deep learning reconstruction of prostate scans using Varnet
1 parent 00894e8 commit 1d4c83f

23 files changed

+2931
-0
lines changed

.gitignore

+7
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,10 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
131+
# log directory
132+
logs/
133+
varnet/
134+
135+
# bash files
136+
*.sh

DL_reconstruction/README.md

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Prostate Diffusion MRI Reconstruction
2+
3+
This repository contains code for reconstruction accelerated prostate diffusion MRI, derived from the [fastMRI](https://github.com/facebookresearch/fastMRI) repository and modified specifically for the fastmri prostate diffusion data. Our model supports reconstruction for b50 and b1000 diffusion images, outputting b50 or b1000 trace images accordingly.
4+
5+
## Features
6+
7+
- **Data Compatibility**: Designed for reconstruction of prostate diffusion MRI data, specifically compatible with the fastMRI prostate dataset.
8+
- **Model Output**: Generates b50 or b1000 trace images
9+
- **Based on fastMRI**: Leverages the framework and model from the fastMRI repository.
10+
11+
## Usage
12+
13+
To run the model, you can use the following commands for training and testing:
14+
15+
### Training
16+
17+
```bash
18+
python train_varnet_demo.py --mode train --data_path [path_to_data] --bvalue b50 --test_path [path_to_test_data] --state_dict_file [path_to_checkpoint_for_testing] --batch_size 1 --num_workers 4
19+
```
20+
21+
Replace `[script_name]` with the name of your script, `[path_to_data]` with the path to your training data, `[path_to_test_data]` with the path to your testing data, and `[path_to_checkpoint_for_testing]` with the path to a model checkpoint file if you have one for testing.
22+
23+
### Testing
24+
25+
```bash
26+
python train_varnet_demo.py --mode test --data_path [path_to_data] --bvalue b1000 --test_path [path_to_test_data] --state_dict_file [path_to_checkpoint] --batch_size 1 --num_workers 4
27+
```
28+
29+
Ensure you specify the `--state_dict_file` argument during testing to provide the path to your model checkpoint.
30+
31+
## Configuration
32+
33+
The script accepts various command-line arguments to customize the data paths, model parameters, and training settings. Refer to the source code for a complete list of available options.
34+
35+
## Contributing
36+
37+
Contributions to improve this project are welcome. Please consider submitting a pull request or opening an issue for any bugs or feature requests.
38+
39+
## License
40+
41+
This project is licensed under the MIT License - see the LICENSE file for details.
42+

DL_reconstruction/__init__.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""
2+
Copyright (c) Facebook, Inc. and its affiliates.
3+
4+
This source code is licensed under the MIT license found in the
5+
LICENSE file in the root directory of this source tree.
6+
"""
7+
8+
from .losses import SSIMLoss
9+
from .utils import save_reconstructions
10+
from .coil_combine import rss, rss_complex
11+
from .fftc import fft2c_new as fft2c
12+
from .fftc import fftshift
13+
from .fftc import ifft2c_new as ifft2c
14+
from .fftc import ifftshift, roll
15+
from .math_fn import (
16+
complex_abs,
17+
complex_abs_sq,
18+
complex_conj,
19+
complex_mul,
20+
tensor_to_complex_np,
21+
)

DL_reconstruction/coil_combine.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""
2+
Copyright (c) Facebook, Inc. and its affiliates.
3+
4+
This source code is licensed under the MIT license found in the
5+
LICENSE file in the root directory of this source tree.
6+
"""
7+
8+
import torch
9+
10+
from .math_fn import complex_abs_sq
11+
12+
def rss(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
13+
"""
14+
Compute the Root Sum of Squares (RSS).
15+
16+
RSS is computed assuming that dim is the coil dimension.
17+
18+
Args:
19+
data: The input tensor
20+
dim: The dimensions along which to apply the RSS transform
21+
22+
Returns:
23+
The RSS value.
24+
"""
25+
return torch.sqrt((data ** 2).sum(dim))
26+
27+
28+
def rss_complex(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
29+
"""
30+
Compute the Root Sum of Squares (RSS) for complex inputs.
31+
32+
RSS is computed assuming that dim is the coil dimension.
33+
34+
Args:
35+
data: The input tensor
36+
dim: The dimensions along which to apply the RSS transform
37+
38+
Returns:
39+
The RSS value.
40+
"""
41+
return torch.sqrt(complex_abs_sq(data).sum(dim))

DL_reconstruction/data/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""
2+
Copyright (c) Facebook, Inc. and its affiliates.
3+
4+
This source code is licensed under the MIT license found in the
5+
LICENSE file in the root directory of this source tree.
6+
"""
7+
8+
from .mri_data import SliceDataset
9+
from .volume_sampler import VolumeSampler

0 commit comments

Comments
 (0)