Skip to content

Commit b432046

Browse files
author
Ubuntu
committed
divide into separate files
1 parent 0d8f84e commit b432046

File tree

6 files changed

+271
-263
lines changed

6 files changed

+271
-263
lines changed

deeplake/integrations/mmlab/mmseg/basedataset.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,25 @@
1717
from mmengine.utils import is_abs
1818

1919

20-
2120
class BaseDataset(Dataset):
2221
r"""
2322
@brief A modified copy of OpenMMLab's BaseDataset.
2423
25-
This class is a direct copy of OpenMMLab's `BaseDataset`, with modifications
26-
to remove forced filesystem initialization (`force_init`) and customize the
27-
dataset length retrieval.
24+
This class is a direct copy of OpenMMLab's `BaseDataset`, with modifications
25+
to remove forced filesystem initialization (`force_init`) and customize the
26+
dataset length retrieval.
2827
2928
@note
30-
- We do not use the original `BaseDataset` because it enforces local filesystem
29+
- We do not use the original `BaseDataset` because it enforces local filesystem
3130
dataset initialization, which is incompatible with our cloud-based dataset.
32-
- Instead of relying on local file scans, this version retrieves dataset size
31+
- Instead of relying on local file scans, this version retrieves dataset size
3332
from a cloud storage backend.
34-
33+
3534
@modifications
3635
- Removed `force_init` to avoid mandatory filesystem checks.
3736
- Overridden `__len__` to use cloud metadata instead of local file counting.
38-
39-
This ensures that the dataset can be loaded dynamically from the cloud without
37+
38+
This ensures that the dataset can be loaded dynamically from the cloud without
4039
unnecessary local file system dependencies.
4140
4241
The annotation format is shown as follows.
@@ -193,8 +192,7 @@ def get_data_info(self, idx: int) -> dict:
193192
return data_info
194193

195194
def full_init(self):
196-
"""Load annotation file and set ``BaseDataset._fully_initialized`` to True.
197-
"""
195+
"""Load annotation file and set ``BaseDataset._fully_initialized`` to True."""
198196
if self._fully_initialized:
199197
return
200198

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import os
2+
import copy
3+
from typing import Any, Dict
4+
5+
from deeplake.integrations.mmlab.mmseg.registry import TRANSFORMS
6+
from deeplake.integrations.mmlab.mmseg.load_annotations import LoadAnnotations
7+
from mmengine.dataset import Compose
8+
9+
from deeplake.client.config import DEEPLAKE_AUTH_TOKEN
10+
11+
from deeplake.util.exceptions import (
12+
EmptyTokenException,
13+
EmptyDeeplakePathException,
14+
ConflictingDatasetParametersError,
15+
MissingTensorMappingError,
16+
)
17+
18+
import mmengine.registry
19+
20+
original_build_func = mmengine.registry.DATASETS.build
21+
22+
23+
def build_transform(steps):
    """Compose a transform pipeline from a list of mmlab step configs.

    ``LoadAnnotations`` steps are instantiated directly with the local
    Deep Lake-aware class, ``LoadImageFromFile`` steps are skipped, and
    every other step is built from the TRANSFORMS registry. The input
    ``steps`` list is deep-copied and never mutated.
    """
    from mmengine.registry.build_functions import build_from_cfg

    pipeline = []
    for cfg in copy.deepcopy(steps):
        step_type = cfg["type"]
        if step_type == "LoadAnnotations":
            # Strip the registry key and construct our LoadAnnotations.
            params = {key: val for key, val in cfg.items() if key != "type"}
            pipeline.append(LoadAnnotations(**params))
        elif step_type != "LoadImageFromFile":
            pipeline.append(build_from_cfg(cfg, TRANSFORMS, None))

    return Compose(pipeline)
41+
42+
43+
def build_func_patch(
    cfg: Dict,
    *args,
    **kwargs,
) -> Any:
    """Build a dataset from *cfg*, backed by a Deep Lake cloud dataset.

    Pops the ``deeplake_*`` keys out of *cfg*, loads (and optionally checks
    out / loads a view of / queries) the Deep Lake dataset, builds the
    transform pipeline, then delegates the remaining config to the original
    ``DATASETS.build`` function.

    Args:
        cfg: Dataset config dict. Consumed keys: ``deeplake_credentials``,
            ``deeplake_path``, ``deeplake_commit``, ``deeplake_view_id``,
            ``deeplake_query``, ``deeplake_tensors``; ``pipeline`` is read
            but left in place.
        *args: Forwarded to the original build function.
        **kwargs: Forwarded to the original build function.

    Returns:
        Tuple of ``(dataset, transform_pipeline)`` where ``dataset`` carries
        ``deeplake_dataset``, ``images_tensor`` and ``masks_tensor``
        attributes, and ``transform_pipeline`` is a ``Compose`` or ``None``.

    Raises:
        EmptyTokenException: no token in credentials or environment.
        EmptyDeeplakePathException: missing or empty ``deeplake_path``.
        ConflictingDatasetParametersError: both view id and query given.
        MissingTensorMappingError: tensor mapping absent or missing a
            required key.
    """
    import deeplake as dp

    creds = cfg.pop("deeplake_credentials", {})
    token = creds.pop("token", None)
    token = token or os.environ.get(DEEPLAKE_AUTH_TOKEN)
    if token is None:
        raise EmptyTokenException()

    ds_path = cfg.pop("deeplake_path", None)
    if ds_path is None or not len(ds_path):
        raise EmptyDeeplakePathException()

    # BUGFIX: the original sliced the loaded dataset with `[0:500:1]`,
    # silently truncating training data to the first 500 samples — almost
    # certainly a leftover debugging artifact. Load the full dataset.
    deeplake_ds = dp.load(ds_path, token=token, read_only=True)
    deeplake_commit = cfg.pop("deeplake_commit", None)
    deeplake_view_id = cfg.pop("deeplake_view_id", None)
    deeplake_query = cfg.pop("deeplake_query", None)

    if deeplake_view_id and deeplake_query:
        raise ConflictingDatasetParametersError()

    if deeplake_commit:
        deeplake_ds.checkout(deeplake_commit)

    if deeplake_view_id:
        deeplake_ds = deeplake_ds.load_view(id=deeplake_view_id)

    if deeplake_query:
        deeplake_ds = deeplake_ds.query(deeplake_query)

    ds_train_tensors = cfg.pop("deeplake_tensors", {})

    if "pipeline" in cfg:
        transform_pipeline = build_transform(cfg.get("pipeline"))
    else:
        transform_pipeline = None

    # BUGFIX: the original used `and`, which made the subset check dead code
    # (an empty dict can never contain the keys, so the second clause was
    # always True when the first was) and silently accepted partial mappings
    # such as {"img": ...}. Require both tensors to be mapped.
    if not ds_train_tensors or not {"img", "gt_semantic_seg"}.issubset(
        ds_train_tensors
    ):
        raise MissingTensorMappingError()

    # Initialization must be eager: lazy init would defer the filesystem
    # work this integration is patching around.
    cfg["lazy_init"] = False
    res = original_build_func(cfg, *args, **kwargs)
    res.deeplake_dataset = deeplake_ds
    res.images_tensor = ds_train_tensors.get("img")
    res.masks_tensor = ds_train_tensors.get("gt_semantic_seg")
    return res, transform_pipeline
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import warnings
2+
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
3+
4+
from deeplake.integrations.mmlab.mmseg.registry import TRANSFORMS
5+
6+
7+
@TRANSFORMS.register_module()
class LoadAnnotations(MMCV_LoadAnnotations):
    """Load annotations for semantic segmentation provided by dataset.

    The annotation format is as the following:

    .. code-block:: python

        {
            # Filename of semantic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # in str
            'seg_fields': List
            # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - seg_map_path (str): Path of semantic segmentation ground truth file.

    Added Keys:

    - seg_fields (List)
    - gt_seg_map (np.uint8)

    Args:
        reduce_zero_label (bool, optional): Whether reduce all label value
            by 1. Usually used for datasets where 0 is background label.
            Defaults to None.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:``mmcv.imfrombytes``.
            See :func:``mmcv.imfrombytes`` for details.
            Defaults to 'pillow'.
        backend_args (dict): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
            for details. Defaults to None.
        Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(
        self,
        reduce_zero_label=None,
        backend_args=None,
        imdecode_backend="pillow",
    ) -> None:
        # Only semantic-segmentation loading is enabled; bbox/label/keypoint
        # loading in the mmcv base class is explicitly switched off.
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            imdecode_backend=imdecode_backend,
            backend_args=backend_args,
        )
        self.reduce_zero_label = reduce_zero_label
        # Passing reduce_zero_label here (even False) is deprecated; warn and
        # keep the value for backward compatibility.
        if self.reduce_zero_label is not None:
            warnings.warn(
                "`reduce_zero_label` will be deprecated, "
                "if you would like to ignore the zero label, please "
                "set `reduce_zero_label=True` when dataset "
                "initialized"
            )
        self.imdecode_backend = imdecode_backend

    def _load_seg_map(self, results: dict) -> None:
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded semantic segmentation annotations.
        """

        # NOTE(review): presumably "dp_seg_map" is injected upstream by the
        # Deep Lake data pipeline; if the key is ever absent this pops None
        # and the indexing below raises TypeError — confirm it is always set.
        gt_semantic_seg = results.pop("dp_seg_map", None)

        # reduce zero_label
        if self.reduce_zero_label:
            # avoid using underflow conversion
            # 0 (background) -> 255 (ignore), then shift every remaining
            # label down by one; 254 would only arise from an original 255,
            # so map it back to the ignore value.
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255
        # modify if custom classes
        if results.get("label_map", None) is not None:
            # Add deep copy to solve bug of repeatedly
            # replace `gt_semantic_seg`, which is reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in results["label_map"].items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        # NOTE(review): the class docstring says `seg_fields` is added, but
        # "gt_seg_map" is never appended to results["seg_fields"] here —
        # verify downstream transforms do not rely on it.
        results["gt_seg_map"] = gt_semantic_seg

    def __repr__(self) -> str:
        # backend_args is set by the mmcv base-class __init__.
        repr_str = self.__class__.__name__
        repr_str += f"(reduce_zero_label={self.reduce_zero_label}, "
        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
        repr_str += f"backend_args={self.backend_args})"
        return repr_str

0 commit comments

Comments
 (0)