Accelerated MAISI and make it compatible with previous DDPM #1953

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 69 commits, Mar 19, 2025

Changes from 1 commit

Commits (69)
eb631d1
Update diff_model_train and make it compatible with previous DDPM. …
Can-Zhao Mar 11, 2025
24c8573
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2025
3240045
rm redundant code
Can-Zhao Mar 11, 2025
b1385ba
test and update notebook
Can-Zhao Mar 11, 2025
f1210a9
Merge branch 'maisi' of https://github.com/Can-Zhao/tutorials into maisi
Can-Zhao Mar 11, 2025
10e4bf7
rm old json
Can-Zhao Mar 11, 2025
9be20ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2025
9e6a715
reformat
Can-Zhao Mar 11, 2025
1dff30c
reformat
Can-Zhao Mar 11, 2025
4187565
reformat
Can-Zhao Mar 11, 2025
f448e9e
reformat
Can-Zhao Mar 11, 2025
81ea271
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2025
a90cbd0
reformat
Can-Zhao Mar 11, 2025
a4c8583
reformat
Can-Zhao Mar 11, 2025
198399c
reformat
Can-Zhao Mar 11, 2025
e3eb585
add controlnet notebook
Can-Zhao Mar 11, 2025
a1832b2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2025
1c42a83
add inference notebook
Can-Zhao Mar 11, 2025
d7850db
Merge branch 'maisi' of https://github.com/Can-Zhao/tutorials into maisi
Can-Zhao Mar 11, 2025
e11dabc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2025
de69a55
update code
Can-Zhao Mar 11, 2025
9dcaa42
Merge branch 'maisi' of https://github.com/Can-Zhao/tutorials into maisi
Can-Zhao Mar 11, 2025
f93d2fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2025
adc5292
update readme
Can-Zhao Mar 12, 2025
57fb610
update readme
Can-Zhao Mar 12, 2025
2dc8039
update environment
Can-Zhao Mar 12, 2025
151177d
add modality as input, make inference notebook executable for ddpm and…
Can-Zhao Mar 13, 2025
cfd5636
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2025
f085a28
add modality as input to config
Can-Zhao Mar 13, 2025
0f7debb
add modality as input to config
Can-Zhao Mar 13, 2025
5831f98
add modality as input to config
Can-Zhao Mar 13, 2025
6f1caf9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2025
e2a4874
update readme
Can-Zhao Mar 13, 2025
d07af26
Merge branch 'maisi' of https://github.com/Can-Zhao/tutorials into maisi
Can-Zhao Mar 13, 2025
56da1e4
benchmark gpu memory
Can-Zhao Mar 13, 2025
76d85e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2025
866d75f
pass check
Can-Zhao Mar 13, 2025
c8e730c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2025
8680ae8
readme
Can-Zhao Mar 13, 2025
719d050
train for rflow tested
Can-Zhao Mar 14, 2025
ddcc96f
readme
Can-Zhao Mar 14, 2025
5cff0a4
readme
Can-Zhao Mar 14, 2025
b265a91
train rflow
Can-Zhao Mar 14, 2025
babaf90
train rflow
Can-Zhao Mar 14, 2025
32048e6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2025
7cd0d5a
update fig
Can-Zhao Mar 14, 2025
9415053
update fig
Can-Zhao Mar 14, 2025
a9ce6e6
readme
Can-Zhao Mar 14, 2025
875c78b
readme
Can-Zhao Mar 14, 2025
ce53d1d
readme
Can-Zhao Mar 14, 2025
5b2c4d7
readme
Can-Zhao Mar 14, 2025
7779f0b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2025
f145874
readme
Can-Zhao Mar 14, 2025
437e325
Merge branch 'main' into maisi
KumoLiu Mar 17, 2025
12dfa47
rm MONAI_DATA_DIRECTORY
Can-Zhao Mar 18, 2025
8d33601
Merge branch 'maisi' of https://github.com/Can-Zhao/tutorials into maisi
Can-Zhao Mar 18, 2025
6b7c6fe
Merge branch 'main' into maisi
Can-Zhao Mar 18, 2025
225e8b4
add explanation on tumor
Can-Zhao Mar 18, 2025
4e136d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2025
58a265a
add back monai data dir
Can-Zhao Mar 19, 2025
e01a532
update model
Can-Zhao Mar 19, 2025
9ab0b21
update model
Can-Zhao Mar 19, 2025
210a063
update test cases
Can-Zhao Mar 19, 2025
4b24deb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2025
4c8efcf
reformat
Can-Zhao Mar 19, 2025
c4c80bb
OOM
Can-Zhao Mar 19, 2025
039e6fd
OOM
Can-Zhao Mar 19, 2025
44ecd9b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2025
20a4df6
Merge branch 'main' into maisi
KumoLiu Mar 19, 2025
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Mar 11, 2025

commit 24c8573eaa76903400adbccce99b705b7535f874
24 changes: 16 additions & 8 deletions generation/maisi/scripts/diff_model_infer.py
@@ -149,7 +149,7 @@ def run_inference(
if isinstance(noise_scheduler, RFlowScheduler):
noise_scheduler.set_timesteps(
num_inference_steps=args.diffusion_unet_inference["num_inference_steps"],
input_img_size_numel=torch.prod(torch.tensor(noise.shape[-3:]))
input_img_size_numel=torch.prod(torch.tensor(noise.shape[-3:])),
)
else:
noise_scheduler.set_timesteps(num_inference_steps=args.diffusion_unet_inference["num_inference_steps"])
@@ -161,9 +161,9 @@ def run_inference(
all_timesteps = noise_scheduler.timesteps
all_next_timesteps = torch.cat((all_timesteps[1:], torch.tensor([0], dtype=all_timesteps.dtype)))
progress_bar = tqdm(
zip(all_timesteps, all_next_timesteps),
total=min(len(all_timesteps), len(all_next_timesteps)),
)
zip(all_timesteps, all_next_timesteps),
total=min(len(all_timesteps), len(all_next_timesteps)),
)
with torch.amp.autocast("cuda", enabled=True):
for t, next_t in progress_bar:
model_output = unet(
@@ -178,7 +178,6 @@ def run_inference(
else:
image, _ = noise_scheduler.step(model_output, t, image, next_t) # type: ignore


inferer = SlidingWindowInferer(
roi_size=(
min(output_size[0] // divisor // 4 * 3, 96),
@@ -228,7 +227,9 @@ def save_image(


@torch.inference_mode()
def diff_model_infer(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, include_body_region: bool = False ) -> None:
def diff_model_infer(
env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, include_body_region: bool = False
) -> None:
"""
Main function to run the diffusion model inference.
@@ -335,7 +336,14 @@ def diff_model_infer(env_config_path: str, model_config_path: str, model_def_pat
default=1,
help="Number of GPUs to use for distributed inference",
)
parser.add_argument("--include_body_region", dest="include_body_region", action="store_true", help="Whether to include body region in data")
parser.add_argument(
"--include_body_region",
dest="include_body_region",
action="store_true",
help="Whether to include body region in data",
)

args = parser.parse_args()
diff_model_infer(args.env_config, args.model_config, args.model_def, args.num_gpus, include_body_region=args.include_body_region)
diff_model_infer(
args.env_config, args.model_config, args.model_def, args.num_gpus, include_body_region=args.include_body_region
)
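
The entry point reformatted above keeps the same call pattern; the sketch below shows a direct Python call that exercises the include_body_region option used throughout this PR. This is a minimal illustration, not taken from the PR itself: the config paths are placeholders, and the import assumes the working directory is generation/maisi.

# Minimal sketch (placeholder paths; assumes cwd is generation/maisi)
from scripts.diff_model_infer import diff_model_infer

diff_model_infer(
    env_config_path="configs/environment.json",          # placeholder path
    model_config_path="configs/config_diff_model.json",  # placeholder path
    model_def_path="configs/config_maisi.json",          # placeholder path
    num_gpus=1,
    include_body_region=False,  # defaults to False, as in the signature above
)
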
57 changes: 36 additions & 21 deletions generation/maisi/scripts/diff_model_train.py
@@ -51,12 +51,12 @@ def load_filenames(data_list_path: str) -> list:


def prepare_data(
train_files: list,
device: torch.device,
cache_rate: float,
num_workers: int = 2,
batch_size: int = 1,
include_body_region: bool = False
train_files: list,
device: torch.device,
cache_rate: float,
num_workers: int = 2,
batch_size: int = 1,
include_body_region: bool = False,
) -> DataLoader:
"""
Prepare training data.
@@ -78,11 +78,11 @@ def _load_data_from_file(file_path, key):
return torch.FloatTensor(json.load(f)[key])

train_transforms_list = [
monai.transforms.LoadImaged(keys=["image"]),
monai.transforms.EnsureChannelFirstd(keys=["image"]),
monai.transforms.Lambdad(keys="spacing", func=lambda x: _load_data_from_file(x, "spacing")),
monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
]
monai.transforms.LoadImaged(keys=["image"]),
monai.transforms.EnsureChannelFirstd(keys=["image"]),
monai.transforms.Lambdad(keys="spacing", func=lambda x: _load_data_from_file(x, "spacing")),
monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
]
if include_body_region:
train_transforms_list += [
monai.transforms.Lambdad(
@@ -202,7 +202,7 @@ def train_one_epoch(
logger: logging.Logger,
local_rank: int,
amp: bool = True,
include_body_region: bool = False
include_body_region: bool = False,
) -> torch.Tensor:
"""
Train the model for one epoch.
@@ -284,9 +284,10 @@ def train_one_epoch(
# predict velocity
loss = loss_pt(model_output.float(), (images - noise).float())
else:
raise ValueError("noise scheduler prediction type has to be chosen from ",
f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},{DDPMPredictionType.V_PREDICTION}]"
)
raise ValueError(
"noise scheduler prediction type has to be chosen from ",
f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},{DDPMPredictionType.V_PREDICTION}]",
)

if amp:
scaler.scale(loss).backward()
@@ -349,7 +350,12 @@ def save_checkpoint(


def diff_model_train(
env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, amp: bool = True, include_body_region: bool = False
env_config_path: str,
model_config_path: str,
model_def_path: str,
num_gpus: int,
amp: bool = True,
include_body_region: bool = False,
) -> None:
"""
Main function to train a diffusion model.
@@ -400,9 +406,11 @@ def diff_model_train(
)[local_rank]

train_loader = prepare_data(
train_files, device, args.diffusion_unet_train["cache_rate"],
train_files,
device,
args.diffusion_unet_train["cache_rate"],
batch_size=args.diffusion_unet_train["batch_size"],
include_body_region = include_body_region
include_body_region=include_body_region,
)

unet = load_unet(args, device, logger)
@@ -438,7 +446,7 @@ def diff_model_train(
logger,
local_rank,
amp=amp,
include_body_region=include_body_region
include_body_region=include_body_region,
)

loss_torch = loss_torch.tolist()
@@ -479,7 +487,14 @@ def diff_model_train(
)
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")
parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training")
parser.add_argument("--include_body_region", dest="include_body_region", action="store_true", help="Whether to include body region in data")
parser.add_argument(
"--include_body_region",
dest="include_body_region",
action="store_true",
help="Whether to include body region in data",
)

args = parser.parse_args()
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp, args.include_body_region)
diff_model_train(
args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp, args.include_body_region
)
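
The training entry point has the analogous signature after this commit; a minimal sketch of a direct call, again with placeholder config paths (multi-GPU launch details are omitted):

# Minimal sketch (placeholder paths; assumes cwd is generation/maisi)
from scripts.diff_model_train import diff_model_train

diff_model_train(
    env_config_path="configs/environment.json",          # placeholder path
    model_config_path="configs/config_diff_model.json",  # placeholder path
    model_def_path="configs/config_maisi.json",          # placeholder path
    num_gpus=1,
    amp=True,                   # the --no_amp flag shown above stores False here
    include_body_region=False,  # mirrors the --include_body_region flag shown above
)
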
6 changes: 3 additions & 3 deletions generation/maisi/scripts/utils.py
@@ -712,14 +712,14 @@ def dynamic_infer(inferer, model, images):
# Extract the spatial dimensions from the images tensor (H, W, D)
spatial_dims = images.shape[2:]
orig_roi = inferer.roi_size

# Check that roi has the same number of dimensions as spatial_dims
if len(orig_roi) != len(spatial_dims):
raise ValueError(f"ROI length ({len(orig_roi)}) does not match spatial dimensions ({len(spatial_dims)}).")

# Iterate and adjust each ROI dimension
adjusted_roi = [min(roi_dim, img_dim) for roi_dim, img_dim in zip(orig_roi, spatial_dims)]
inferer.roi_size = adjusted_roi
output = inferer(network=model, inputs=images)
inferer.roi_size = orig_roi
return output
return output
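
The dynamic_infer hunk above wraps SlidingWindowInferer so the ROI never exceeds the image's spatial size: it clamps each ROI dimension to the corresponding image dimension, runs inference, then restores the original ROI. A minimal, self-contained sketch of that idea (not from the PR; the toy network and shapes are illustrative only):

# Minimal sketch of the ROI-clamping idea in dynamic_infer (illustrative stand-ins)
import torch
from monai.inferers import SlidingWindowInferer

model = torch.nn.Conv3d(4, 2, kernel_size=1)  # stand-in for the real network
inferer = SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=1, overlap=0.25)
images = torch.rand(1, 4, 64, 64, 64)  # smaller than the ROI on every axis

orig_roi = inferer.roi_size
adjusted_roi = [min(roi_dim, img_dim) for roi_dim, img_dim in zip(orig_roi, images.shape[2:])]
inferer.roi_size = adjusted_roi  # clamp so the sliding window fits inside the image
output = inferer(inputs=images, network=model)
inferer.roi_size = orig_roi  # restore the configured ROI for later calls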