|
| 1 | +# Copyright (c) MONAI Consortium |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
1 | 12 | import os
|
2 | 13 | import glob
|
3 | 14 | import logging
|
|
8 | 19 | from monai.engines import SupervisedEvaluator
|
9 | 20 | from monai.transforms import (
|
10 | 21 | LoadImaged,
|
| 22 | + EnsureChannelFirstd, |
11 | 23 | Lambdad,
|
| 24 | + AsDiscreted, |
12 | 25 | Activationsd,
|
13 | 26 | Compose,
|
14 | 27 | CastToTyped,
|
@@ -54,6 +67,7 @@ def run(cfg):
|
54 | 67 | val_transforms = Compose(
|
55 | 68 | [
|
56 | 69 | LoadImaged(keys=["image", "label_inst", "label_type"], image_only=True),
|
| 70 | + EnsureChannelFirstd(keys=["image", "label_inst", "label_type"], channel_dim=-1), |
57 | 71 | Lambdad(keys="label_inst", func=lambda x: measure.label(x)),
|
58 | 72 | CastToTyped(keys=["image", "label_inst"], dtype=torch.int),
|
59 | 73 | CenterSpatialCropd(
|
@@ -92,14 +106,14 @@ def run(cfg):
|
92 | 106 | post_process_np = Compose(
|
93 | 107 | [
|
94 | 108 | Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
|
95 |
| - Lambdad(keys=HoVerNetBranch.NP.value, func=lambda x: x[1:2, ...] > 0.5), |
| 109 | + AsDiscreted(keys=HoVerNetBranch.NP.value, argmax=True), |
96 | 110 | ]
|
97 | 111 | )
|
98 | 112 | post_process = Lambdad(keys="pred", func=post_process_np)
|
99 | 113 |
|
100 | 114 | # Evaluator
|
101 | 115 | val_handlers = [
|
102 |
| - CheckpointLoader(load_path=cfg["ckpt"], load_dict={"net": model}), |
| 116 | + CheckpointLoader(load_path=cfg["ckpt"], load_dict={"model": model}), |
103 | 117 | StatsHandler(output_transform=lambda x: None),
|
104 | 118 | ]
|
105 | 119 | evaluator = SupervisedEvaluator(
|
|
0 commit comments