|
| 1 | +import glob |
| 2 | +import os |
| 3 | +import SimpleITK as sitk |
| 4 | +import numpy as np |
| 5 | +import argparse |
| 6 | +from medpy.metric import binary |
| 7 | + |
| 8 | +def read_nii(path): |
| 9 | + return sitk.GetArrayFromImage(sitk.ReadImage(path)) |
| 10 | + |
| 11 | +def new_dice(pred,label): |
| 12 | + tp_hard = np.sum((pred == 1).astype(np.float) * (label == 1).astype(np.float)) |
| 13 | + fp_hard = np.sum((pred == 1).astype(np.float) * (label != 1).astype(np.float)) |
| 14 | + fn_hard = np.sum((pred != 1).astype(np.float) * (label == 1).astype(np.float)) |
| 15 | + return 2*tp_hard/(2*tp_hard+fp_hard+fn_hard) |
| 16 | + |
| 17 | +def dice(pred, label): |
| 18 | + if (pred.sum() + label.sum()) == 0: |
| 19 | + return 1 |
| 20 | + else: |
| 21 | + return 2. * np.logical_and(pred, label).sum() / (pred.sum() + label.sum()) |
| 22 | + |
| 23 | +def hd(pred,gt): |
| 24 | + if pred.sum() > 0 and gt.sum()>0: |
| 25 | + hd95 = binary.hd95(pred, gt) |
| 26 | + return hd95 |
| 27 | + else: |
| 28 | + return 0 |
| 29 | + |
| 30 | +def process_label(label): |
| 31 | + net = label == 2 |
| 32 | + ed = label == 1 |
| 33 | + et = label == 3 |
| 34 | + ET=et |
| 35 | + TC=net+et |
| 36 | + WT=net+et+ed |
| 37 | + ED= ed |
| 38 | + NET=net |
| 39 | + return ET,TC,WT,ED,NET |
| 40 | + |
| 41 | +def test(fold): |
| 42 | + #path='./' |
| 43 | + |
| 44 | + path = None # Replace None by the full path of : unetr_plus_plus/DATASET_Tumor/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor/" |
| 45 | + |
| 46 | + label_list=sorted(glob.glob(os.path.join(path,'labelsTs','*nii.gz'))) |
| 47 | + |
| 48 | + infer_path = None # Replace None by the full path of : unetr_plus_plus/unetr_pp/evaluation/unetr_pp_tumor_checkpoint/" |
| 49 | + |
| 50 | + infer_list=sorted(glob.glob(os.path.join(infer_path,'inferTs','*nii.gz'))) |
| 51 | + print("loading success...") |
| 52 | + Dice_et=[] |
| 53 | + Dice_tc=[] |
| 54 | + Dice_wt=[] |
| 55 | + Dice_ed=[] |
| 56 | + Dice_net=[] |
| 57 | + |
| 58 | + HD_et=[] |
| 59 | + HD_tc=[] |
| 60 | + HD_wt=[] |
| 61 | + HD_ed=[] |
| 62 | + HD_net=[] |
| 63 | + file=infer_path + 'inferTs/'+fold |
| 64 | + if not os.path.exists(file): |
| 65 | + os.makedirs(file) |
| 66 | + fw = open(file+'/dice_five.txt', 'w') |
| 67 | + |
| 68 | + for label_path,infer_path in zip(label_list,infer_list): |
| 69 | + print(label_path.split('/')[-1]) |
| 70 | + print(infer_path.split('/')[-1]) |
| 71 | + label,infer = read_nii(label_path),read_nii(infer_path) |
| 72 | + label_et,label_tc,label_wt,label_ed,label_net=process_label(label) |
| 73 | + infer_et,infer_tc,infer_wt,infer_ed,infer_net=process_label(infer) |
| 74 | + Dice_et.append(dice(infer_et,label_et)) |
| 75 | + Dice_tc.append(dice(infer_tc,label_tc)) |
| 76 | + Dice_wt.append(dice(infer_wt,label_wt)) |
| 77 | + Dice_ed.append(dice(infer_ed,label_ed)) |
| 78 | + Dice_net.append(dice(infer_net,label_net)) |
| 79 | + |
| 80 | + HD_et.append(hd(infer_et,label_et)) |
| 81 | + HD_tc.append(hd(infer_tc,label_tc)) |
| 82 | + HD_wt.append(hd(infer_wt,label_wt)) |
| 83 | + HD_ed.append(hd(infer_ed,label_ed)) |
| 84 | + HD_net.append(hd(infer_net,label_net)) |
| 85 | + |
| 86 | + |
| 87 | + fw.write('*'*20+'\n',) |
| 88 | + fw.write(infer_path.split('/')[-1]+'\n') |
| 89 | + fw.write('hd_et: {:.4f}\n'.format(HD_et[-1])) |
| 90 | + fw.write('hd_tc: {:.4f}\n'.format(HD_tc[-1])) |
| 91 | + fw.write('hd_wt: {:.4f}\n'.format(HD_wt[-1])) |
| 92 | + fw.write('hd_ed: {:.4f}\n'.format(HD_ed[-1])) |
| 93 | + fw.write('hd_net: {:.4f}\n'.format(HD_net[-1])) |
| 94 | + fw.write('*'*20+'\n',) |
| 95 | + fw.write('Dice_et: {:.4f}\n'.format(Dice_et[-1])) |
| 96 | + fw.write('Dice_tc: {:.4f}\n'.format(Dice_tc[-1])) |
| 97 | + fw.write('Dice_wt: {:.4f}\n'.format(Dice_wt[-1])) |
| 98 | + fw.write('Dice_ed: {:.4f}\n'.format(Dice_ed[-1])) |
| 99 | + fw.write('Dice_net: {:.4f}\n'.format(Dice_net[-1])) |
| 100 | + |
| 101 | + #print('dice_et: {:.4f}'.format(np.mean(Dice_et))) |
| 102 | + #print('dice_tc: {:.4f}'.format(np.mean(Dice_tc))) |
| 103 | + #print('dice_wt: {:.4f}'.format(np.mean(Dice_wt))) |
| 104 | + dsc=[] |
| 105 | + avg_hd=[] |
| 106 | + dsc.append(np.mean(Dice_et)) |
| 107 | + dsc.append(np.mean(Dice_tc)) |
| 108 | + dsc.append(np.mean(Dice_wt)) |
| 109 | + dsc.append(np.mean(Dice_ed)) |
| 110 | + dsc.append(np.mean(Dice_net)) |
| 111 | + |
| 112 | + |
| 113 | + avg_hd.append(np.mean(HD_et)) |
| 114 | + avg_hd.append(np.mean(HD_tc)) |
| 115 | + avg_hd.append(np.mean(HD_wt)) |
| 116 | + avg_hd.append(np.mean(HD_ed)) |
| 117 | + avg_hd.append(np.mean(HD_net)) |
| 118 | + |
| 119 | + fw.write('Dice_et'+str(np.mean(Dice_et))+' '+'\n') |
| 120 | + fw.write('Dice_tc'+str(np.mean(Dice_tc))+' '+'\n') |
| 121 | + fw.write('Dice_wt'+str(np.mean(Dice_wt))+' '+'\n') |
| 122 | + fw.write('Dice_ed'+str(np.mean(Dice_ed))+' '+'\n') |
| 123 | + fw.write('Dice_net'+str(np.mean(Dice_net))+' '+'\n') |
| 124 | + |
| 125 | + fw.write('HD_et'+str(np.mean(HD_et))+' '+'\n') |
| 126 | + fw.write('HD_tc'+str(np.mean(HD_tc))+' '+'\n') |
| 127 | + fw.write('HD_wt'+str(np.mean(HD_wt))+' '+'\n') |
| 128 | + fw.write('HD_ed'+str(np.mean(HD_ed))+' '+'\n') |
| 129 | + fw.write('HD_net'+str(np.mean(HD_net))+' '+'\n') |
| 130 | + |
| 131 | + fw.write('Dice'+str(np.mean(dsc))+' '+'\n') |
| 132 | + fw.write('HD'+str(np.mean(avg_hd))+' '+'\n') |
| 133 | + #print('Dice'+str(np.mean(dsc))+' '+'\n') |
| 134 | + #print('HD'+str(np.mean(avg_hd))+' '+'\n') |
| 135 | + |
| 136 | + |
| 137 | + |
| 138 | +if __name__ == '__main__': |
| 139 | + parser = argparse.ArgumentParser() |
| 140 | + parser.add_argument("fold", help="fold name") |
| 141 | + args = parser.parse_args() |
| 142 | + fold=args.fold |
| 143 | + test(fold) |
0 commit comments