Skip to content

Commit 49cf5a0

Browse files
shapovalov and facebook-github-bot
authored and committed
Loading fg probability from the alpha channel of image_rgb
Summary: It is often easier to store the mask together with RGB, especially for renders. The logic in this diff:

* if load_mask and mask_path are provided, take the mask from mask_path;
* otherwise, check if the image has an alpha channel and take it as the mask.

Reviewed By: antoinetlc

Differential Revision: D68160212

fbshipit-source-id: d9b6779f90027a4987ba96800983f441edff9c74
1 parent 89b851e commit 49cf5a0

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

pytorch3d/implicitron/dataset/frame_data.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,8 @@ def build(
589589
),
590590
)
591591

592-
fg_mask_np: Optional[np.ndarray] = None
592+
fg_mask_np: np.ndarray | None = None
593+
bbox_xywh: tuple[float, float, float, float] | None = None
593594
mask_annotation = frame_annotation.mask
594595
if mask_annotation is not None:
595596
if load_blobs and self.load_masks:
@@ -598,10 +599,6 @@ def build(
598599
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
599600

600601
bbox_xywh = mask_annotation.bounding_box_xywh
601-
if bbox_xywh is None and fg_mask_np is not None:
602-
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
603-
604-
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
605602

606603
if frame_annotation.image is not None:
607604
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
@@ -618,11 +615,27 @@ def build(
618615
if image_path is None:
619616
raise ValueError("Image path is required to load images.")
620617

621-
image_np = load_image(self._local_path(image_path))
618+
no_mask = fg_mask_np is None # didn’t read the mask file
619+
image_np = load_image(
620+
self._local_path(image_path), try_read_alpha=no_mask
621+
)
622+
if image_np.shape[0] == 4: # RGBA image
623+
if no_mask:
624+
fg_mask_np = image_np[3:]
625+
frame_data.fg_probability = safe_as_tensor(
626+
fg_mask_np, torch.float
627+
)
628+
629+
image_np = image_np[:3]
630+
622631
frame_data.image_rgb = self._postprocess_image(
623632
image_np, frame_annotation.image.size, frame_data.fg_probability
624633
)
625634

635+
if bbox_xywh is None and fg_mask_np is not None:
636+
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
637+
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
638+
626639
depth_annotation = frame_annotation.depth
627640
if (
628641
load_blobs

pytorch3d/implicitron/dataset/utils.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@ def is_train_frame(
8787
def get_bbox_from_mask(
8888
mask: np.ndarray, thr: float, decrease_quant: float = 0.05
8989
) -> Tuple[int, int, int, int]:
90+
# these corner cases need to be handled in order to avoid an infinite loop
91+
if mask.size == 0:
92+
warnings.warn("Empty mask is provided for bbox extraction.", stacklevel=1)
93+
return 0, 0, 1, 1
94+
95+
if not mask.min() >= 0.0:
96+
warnings.warn("Negative values in the mask for bbox extraction.", stacklevel=1)
97+
mask = mask.clip(min=0.0)
98+
9099
# bbox in xywh
91100
masks_for_box = np.zeros_like(mask)
92101
while masks_for_box.sum() <= 1.0:
@@ -229,9 +238,20 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
229238
return im.astype(np.float32) / 255.0
230239

231240

def load_image(path: str, try_read_alpha: bool = False) -> np.ndarray:
    """
    Load an image from *path* and return it as a numpy array.

    If try_read_alpha is True and the file is stored in RGBA mode, all four
    channels are kept, so the alpha channel comes back as the fourth channel.
    In every other case the image is converted to RGB and a three-channel
    array is returned.
    """
    with Image.open(path) as pil_im:
        # Keep the raw RGBA data only when the caller asked for alpha and
        # the file actually carries an alpha channel.
        keep_alpha = try_read_alpha and pil_im.mode == "RGBA"
        im = np.array(pil_im if keep_alpha else pil_im.convert("RGB"))

    return transpose_normalize_image(im)

0 commit comments

Comments (0)