You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Hi everyone, I'm training UNETR from scratch with 25 train and 5 val micro-CT datasets and segmentations of flowers. It's a binary problem with only 1 label per image (flower) and the background is not labelled. I've checked the BTCV and Spleen tutorials for parameter setup, but I'm still not convinced I chose all the right settings. My loss function can be volatile and very negative (ranging from 1.5 to -30) for certain images during training. The volatility of my loss function makes me unsure if I can trust the DICE scores I'm getting out, which have been okay so far (~0.70 or so after a few hundred or thousand epochs, depending on parameters). When I run novel images through the trained models, the predictions also look okay, but I want to make sure this isn't purely coincidence.
See minimal example of my code below. I suspect the strange outputs for loss may be an issue with one of the following functions: loss_function, post_label, post_pred, dice_metric.
#@title Get dataset filepaths for train and validation
data_root = os.path.join(root_dir, "input")
train_images = sorted(glob.glob(os.path.join(data_root, "train/images", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(data_root, "train/labels", "*.nii.gz")))
val_images = sorted(glob.glob(os.path.join(data_root, "val/images", "*.nii.gz")))
val_labels = sorted(glob.glob(os.path.join(data_root, "val/labels", "*.nii.gz")))
train_files = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
val_files = [{"image": image_name, "label": label_name} for image_name, label_name in zip(val_images, val_labels)]
num_train = len(train_files)
num_val = len(val_files)
print("Training model with {} train images and {} validation images\n".format(num_train, num_val))
print(" Train images loaded from {} \n Train labels loaded from {} \n Validation images loaded from {} \n Validation labels loaded from {}".format(train_images, train_labels, val_images, val_labels))
#@title Dataset Transforms
# use clip with ScaleIntensityRangePercentilesd?
clip = True # @param {"type":"boolean"}
# Weights of pos vs neg labels for RandCropByPosNegLabeld
pos = 1 #@param
neg = 1 #@param
train_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
Orientationd(keys=["image", "label"], axcodes="RAS"),
Spacingd(
keys=["image", "label"],
pixdim=(.05, .05, .05),
mode=("bilinear", "nearest"),
),
ScaleIntensityRangePercentilesd(
keys="image",
lower=10, upper=90, b_min=0, b_max=200,
clip=clip,
relative=False,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
RandCropByPosNegLabeld(
keys=["image", "label"],
label_key="label",
spatial_size=(96, 96, 96),
pos=pos, # 1 to 1 ratio pos:neg applies equal weight; since binary problem, apply more weight to pos
neg=neg,
num_samples=4,
image_key="image",
image_threshold=0,
),
RandFlipd(
keys=["image", "label"],
spatial_axis=[0],
prob=0.10,
),
RandFlipd(
keys=["image", "label"],
spatial_axis=[1],
prob=0.10,
),
RandFlipd(
keys=["image", "label"],
spatial_axis=[2],
prob=0.10,
),
RandRotate90d(
keys=["image", "label"],
prob=0.10,
max_k=3,
),
RandShiftIntensityd(
keys=["image"],
offsets=0.10,
prob=0.50,
),
]
)
val_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
Orientationd(keys=["image", "label"], axcodes="RAS"),
Spacingd(
keys=["image", "label"],
pixdim=(.05, .05, .05),
mode=("bilinear", "nearest"),
),
ScaleIntensityRangePercentilesd(
keys="image",
lower=10, upper=90, b_min=0, b_max=200,
clip=clip,
relative=False,),
CropForegroundd(keys=["image", "label"], source_key="image"),
]
)
#@title Transform datasets and read into cache for training
# Train Dataset
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_num=num_train, cache_rate=1.0, num_workers=12)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=6, pin_memory=True)
# Validation Dataset
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=num_val, cache_rate=1.0, num_workers=8)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
print("\n\nSample Validation image shape after transform: {}, label shape: {}".format(first(val_loader)['image'].shape, first(val_loader)['label'].shape))
print("\n\nSample Train image shape after transform: {}, label shape: {}".format(first(train_loader)['image'].shape, first(train_loader)['label'].shape))
#@title Set model parameters
# Use cuda GPU if available, else use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Parameters
out_channels = 1 #@param
img_size = 96 #@param
feature_size = 32 # @param ["16","32","48","64","70","96"] {"type":"raw"}
hidden_size = 768 #@param
mlp_dim = 3072 #@param
num_heads = 12 #@param
norm_name = "instance" # @param ["instance","batch"]
res_block = True # @param {"type":"boolean"}
conv_block = True # @param {"type":"boolean"}
dropout_rate = 0.0 #@param
pos_embed = "perceptron" # @param ["conv","perceptron"]
# Define the model
model = UNETR(
in_channels=1,
out_channels=out_channels,
img_size=(img_size, img_size, img_size),
feature_size=feature_size,
hidden_size=hidden_size,
mlp_dim=mlp_dim,
num_heads=num_heads,
norm_name=norm_name,
conv_block=conv_block,
res_block=res_block,
dropout_rate=dropout_rate,
proj_type=pos_embed,
).to(device)
# Define the loss function
if out_channels > 1:
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
else:
loss_function = DiceCELoss(include_background=True, sigmoid=True)
# Algorithmically optimize GPU training and speed things up, could be variable for val images since size varies
torch.backends.cudnn.benchmark = True #@param
# Set optimizer parameters
optimizer = "AdamW" # @param ["AdamW","Novograd"] {"allow-input":true}
lr = 1e-3 # @param ["1.5e-3","1e-3"] {"type":"raw","allow-input":true}
weight_decay = 1e-4 # @param ["1e-2","0","1e-3","1e-4","1e-1"] {"type":"raw"}
amsgrad = False # @param {"type":"boolean"}
# look into monai.optimizers.WarmupCosineSchedule to schedule lr across epochs after tune other params
if "Adam" in optimizer:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, amsgrad=amsgrad)
else:
optimizer = Novograd(model.parameters(), lr=lr, weight_decay=weight_decay, amsgrad=amsgrad)
#@title Train pipeline
# Modified from : Copyright (c) MONAI Consortium and [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0)
# Save each new training attempt results by name (TRAIN_SES_NUM)
saved_models_dir = "saved_models" # @param ["saved_models"] {"allow-input":true}
saved_model_path, TRAIN_SESS_NUM = set_saved_model_path(out_dir, saved_models_dir)
def validation(epoch_iterator_val):
model.eval()
with torch.no_grad():
for batch in epoch_iterator_val:
val_inputs, val_labels = (batch["image"].to(device), batch["label"].to(device))
#print("\nValidation input shape: ", val_inputs.shape)
#print("\nValidation ground truth shape: ", val_labels.shape)
val_outputs = sliding_window_inference(val_inputs, roi_size=(img_size, img_size, img_size), sw_batch_size=4, predictor=model)
#print("\nValidation output shape: ", val_outputs.shape)
val_labels_list = decollate_batch(val_labels)
val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list]
val_outputs_list = decollate_batch(val_outputs)
val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list]
#print("Val output convert: ", val_output_convert)
#print("Val labels convert: ", val_labels_convert)
dice_metric(y_pred=val_output_convert, y=val_labels_convert)
epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0)) # noqa: B038
mean_dice_val = dice_metric.aggregate().item()
dice_metric.reset()
return mean_dice_val
def train(global_step, train_loader, dice_val_best, global_step_best):
model.train()
epoch_loss = 0
step = 0
epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
for step, batch in enumerate(epoch_iterator):
step += 1
x, y = (batch["image"].to(device), batch["label"].to(device))
#print("\nModel input shape: ", x.shape)
logit_map = model(x)
#print("\nModel output shape: ", logit_map.shape)
#print("\nGround truth shape: ", y.shape)
loss = loss_function(logit_map, y)
loss.backward()
epoch_loss += loss.item()
optimizer.step()
optimizer.zero_grad()
epoch_iterator.set_description( # noqa: B038
"Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss)
)
if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
dice_val = validation(epoch_iterator_val)
epoch_loss /= step
epoch_loss_values.append(epoch_loss)
metric_values.append(dice_val)
if dice_val > dice_val_best:
dice_val_best = dice_val
global_step_best = global_step
torch.save(model.state_dict(), saved_model_path)
print(
"\033[32m Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}\033[0m".format(dice_val_best, dice_val)
)
else:
print(
"\033[33mModel Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}\033[0m".format(
dice_val_best, dice_val
)
)
global_step += 1
return global_step, dice_val_best, global_step_best
max_iterations = 500 #@param
eval_num = 10 #@param
post_label = Compose([AsDiscrete(n_classes=out_channels)])
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
dice_metric = DiceMetric(include_background=True, reduction="mean")
global_step = 0
dice_val_best = -0.1
global_step_best = 0
epoch_loss_values = []
metric_values = []
while global_step < max_iterations:
global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
model.load_state_dict(torch.load(os.path.join(saved_models_dir, saved_model_path)))
print(f"\n\n\n train completed, best_metric: {dice_val_best:.4f} " f"at iteration: {global_step_best}")
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
-
Hi everyone, I'm training UNETR from scratch with 25 train and 5 val micro-CT datasets and segmentations of flowers. It's a binary problem with only 1 label per image (flower) and the background is not labelled. I've checked the BTCV and Spleen tutorials for parameter setup, but I'm still not convinced I chose all the right settings. My loss function can be volatile and very negative (ranging from 1.5 to -30) for certain images during training. The volatility of my loss function makes me unsure if I can trust the DICE scores I'm getting out, which have been okay so far (~0.70 or so after a few hundred or thousand epochs, depending on parameters). When I run novel images through the trained models, the predictions also look okay, but I want to make sure this isn't purely coincidence.
See minimal example of my code below. I suspect the strange outputs for loss may be an issue with one of the following functions: loss_function, post_label, post_pred, dice_metric.
Thanks in advance for the feedback!!
Beta Was this translation helpful? Give feedback.
All reactions