Skip to content

Commit 5e02637

Browse files
author
Marian Rassat
committed
fix: better handling of disparity grids, test fixes, disparity range inversion
1 parent 9cadb1d commit 5e02637

File tree

9 files changed

+173
-58
lines changed

9 files changed

+173
-58
lines changed

src/pandora/matching_cost/cpp/includes/matching_cost.hpp

+12
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,16 @@ py::array_t<float> reverse_cost_volume(
4141
int min_disp
4242
);
4343

44+
/**
45+
* @brief Create the right disp ranges from the left disp ranges
46+
*
47+
* @param left_min: the 2D left disp min array, with dimensions row, col
48+
* @param left_max: the 2D left disp min array, with dimensions row, col
49+
* @return: The min and max disp ranges for the right image
50+
*/
51+
std::tuple<py::array_t<float>, py::array_t<float>> reverse_disp_range(
52+
py::array_t<float> left_min,
53+
py::array_t<float> left_max
54+
);
55+
4456
#endif // MATCHING_COST_HPP

src/pandora/matching_cost/cpp/matching_cost_cpp.pyi

+12
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,15 @@ def reverse_cost_volume(left_cv, disp_min):
4949
:rtype: 3D np.ndarray of type float32
5050
"""
5151
...
52+
53+
def reverse_disp_range(left_min, left_max):
54+
"""
55+
Create the right disp ranges from the left disp ranges
56+
:param left_min: the 2D left disp min array, with dimensions row, col
57+
:type left_min: np.ndarray(dtype=float32)
58+
:param left_max: the 2D left disp max array, with dimensions row, col
59+
:type left_max: np.ndarray(dtype=float32)
60+
:return: The min and max disp ranges for the right image
61+
:rtype: Tuple[np.ndarray(dtype=float32), np.ndarray(dtype=float32)]
62+
"""
63+
return None, None

src/pandora/matching_cost/cpp/src/bindings.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,10 @@ PYBIND11_MODULE(matching_cost_cpp, m) {
3838
"Computes right cost volume from left cost volume."
3939
);
4040

41+
m.def(
42+
"reverse_disp_range",
43+
&reverse_disp_range,
44+
"Computes the right disp range from the left one."
45+
);
46+
4147
}

src/pandora/matching_cost/cpp/src/matching_cost.cpp

+78
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
*/
1919

2020
#include "matching_cost.hpp"
21+
#include <cmath>
22+
#include <limits>
2123

2224
namespace py = pybind11;
2325

@@ -51,4 +53,80 @@ py::array_t<float> reverse_cost_volume(
5153
}
5254

5355
return right_cv;
56+
}
57+
58+
59+
std::tuple<py::array_t<float>, py::array_t<float>> reverse_disp_range(
60+
py::array_t<float> left_min,
61+
py::array_t<float> left_max
62+
) {
63+
auto r_left_min = left_min.unchecked<2>();
64+
auto r_left_max = left_max.unchecked<2>();
65+
66+
size_t n_row = r_left_min.shape(0);
67+
size_t n_col = r_left_min.shape(1);
68+
69+
py::array_t<float> right_min = py::array_t<float>({n_row, n_col});
70+
py::array_t<float> right_max = py::array_t<float>({n_row, n_col});
71+
auto rw_right_min = right_min.mutable_unchecked<2>();
72+
auto rw_right_max = right_max.mutable_unchecked<2>();
73+
74+
// init the min and max values at inf
75+
for (size_t row = 0; row < n_row; ++row) {
76+
for (size_t col = 0; col < n_col; ++col) {
77+
rw_right_min(row, col) = std::numeric_limits<float>::infinity();
78+
rw_right_max(row, col) = -std::numeric_limits<float>::infinity();
79+
}
80+
}
81+
82+
for (size_t row = 0; row < n_row; ++row) {
83+
for (size_t col = 0; col < n_col; ++col) {
84+
85+
float d_min_raw = r_left_min(row, col);
86+
float d_max_raw = r_left_max(row, col);
87+
88+
// skip nans
89+
if (std::isnan(d_min_raw))
90+
continue;
91+
if (std::isnan(d_max_raw))
92+
continue;
93+
94+
int d_min = static_cast<int>(d_min_raw);
95+
int d_max = static_cast<int>(d_max_raw);
96+
97+
for (int d = d_min; d <= d_max; d++) {
98+
99+
int right_col = static_cast<int>(col) + d;
100+
101+
// increment d when right_col is too low, break when too high
102+
if (right_col < 0)
103+
continue;
104+
if (right_col >= static_cast<int>(n_col))
105+
break;
106+
107+
// update mins and maxs with -d to reach left_img(row, col) from
108+
// right_img(row, right_col)
109+
rw_right_min(row, right_col) = std::min(
110+
rw_right_min(row, right_col), static_cast<float>(-d)
111+
);
112+
rw_right_max(row, right_col) = std::max(
113+
rw_right_max(row, right_col), static_cast<float>(-d)
114+
);
115+
116+
}
117+
118+
}
119+
}
120+
121+
// set the disp ranges that have not been filled to nan
122+
for (size_t row = 0; row < n_row; ++row) {
123+
for (size_t col = 0; col < n_col; ++col) {
124+
if ( std::isinf(rw_right_min(row, col)) ) {
125+
rw_right_min(row, col) = std::nanf("");
126+
rw_right_max(row, col) = std::nanf("");
127+
}
128+
}
129+
}
130+
131+
return {right_min, right_max};
54132
}

src/pandora/matching_cost/matching_cost.py

+14
Original file line numberDiff line numberDiff line change
@@ -915,3 +915,17 @@ def reverse_cost_volume(left_cv: np.ndarray, disp_min: int) -> np.ndarray:
915915
:rtype: 3D np.ndarray of type float32
916916
"""
917917
return matching_cost_cpp.reverse_cost_volume(left_cv, disp_min)
918+
919+
@staticmethod
920+
def reverse_disp_range(left_min: np.ndarray, left_max: np.ndarray) -> np.ndarray:
921+
"""
922+
Create the right disp ranges from the left disp ranges
923+
924+
:param left_min: the 2D left disp min array, with dimensions row, col
925+
:type left_min: np.ndarray(dtype=float32)
926+
:param left_max: the 2D left disp max array, with dimensions row, col
927+
:type left_max: np.ndarray(dtype=float32)
928+
:return: The min and max disp ranges for the right image
929+
:rtype: Tuple[np.ndarray(dtype=float32), np.ndarray(dtype=float32)]
930+
"""
931+
return matching_cost_cpp.reverse_disp_range(left_min, left_max)

src/pandora/state_machine.py

+21-30
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
semantic_segmentation,
5858
)
5959
from pandora.margins import GlobalMargins
60-
from pandora.img_tools import rasterio_open
6160

6261
from pandora.criteria import validity_mask
6362

@@ -315,9 +314,19 @@ def matching_cost_prepare(self, cfg: Dict[str, dict], input_step: str) -> None:
315314
self.right_disp_min = self.right_disp_min * self.scale_factor
316315
self.right_disp_max = self.right_disp_max * self.scale_factor
317316

318-
self.right_cv = self.matching_cost_.allocate_cost_volume(
319-
self.right_img, (self.right_disp_min, self.right_disp_max), cfg
320-
)
317+
if self.right_disp_map == "cross_checking_accurate":
318+
# allocate the cost volume the standard way
319+
self.right_cv = self.matching_cost_.allocate_cost_volume(
320+
self.right_img, (self.right_disp_min, self.right_disp_max), cfg
321+
)
322+
323+
elif self.right_disp_map == "cross_checking_fast":
324+
# the right cv may have a different size from left cv if created from its disp range
325+
# (ex: with the test disparity grids)
326+
# create it from the left disps instead, to match the left cv's size
327+
self.right_cv = self.matching_cost_.allocate_cost_volume(
328+
self.right_img, (-self.disp_max, -self.disp_min), cfg
329+
)
321330

322331
# Compute validity mask to identify invalid points in cost volume
323332
self.right_cv = validity_mask(self.right_img, self.left_img, self.right_cv)
@@ -426,7 +435,7 @@ def disparity_run(self, cfg: Dict[str, dict], input_step: str) -> None:
426435
# Fast cross checking is used, compute right cv at wta time
427436
# Compute right cost volume and mask it
428437
self.right_cv["cost_volume"].data = matching_cost.AbstractMatchingCost.reverse_cost_volume(
429-
self.left_cv["cost_volume"].data, self.right_disp_min.min()
438+
self.left_cv["cost_volume"].data, np.nanmin(self.right_disp_min)
430439
)
431440

432441
self.right_cv.attrs["type_measure"] = self.left_cv.attrs["type_measure"]
@@ -638,8 +647,10 @@ def run_prepare(
638647
self.right_disp_min = right_img["disparity"].sel(band_disp="min").data
639648
self.right_disp_max = right_img["disparity"].sel(band_disp="max").data
640649
else:
641-
self.right_disp_min = -left_img["disparity"].sel(band_disp="max").data
642-
self.right_disp_max = -left_img["disparity"].sel(band_disp="min").data
650+
# Right disparities : always infered from left disparities
651+
self.right_disp_min, self.right_disp_max = matching_cost.AbstractMatchingCost.reverse_disp_range(
652+
self.disp_min, self.disp_max
653+
)
643654

644655
# Initiate output disparity datasets
645656
self.left_disparity = xr.Dataset()
@@ -875,31 +886,11 @@ def validation_check_conf(self, cfg: Dict[str, dict], input_step: str) -> None:
875886
# If both disp bounds are lists, check that they add up
876887
if isinstance(ds_left, list) and isinstance(ds_right, list):
877888
if ds_left[0] != -ds_right[1] or ds_left[1] != -ds_right[0]:
878-
raise AttributeError(
879-
"The cross-checking step can't be processed if disp_min, disp_max, disp_right_min, disp_right_max "
880-
"are all ints and disp_min != -disp_right_max or disp_max != -disp_right_min"
881-
)
889+
raise AttributeError("disp_min != -disp_right_max or disp_max != -disp_right_min")
882890

883-
# If both disp bounds are strs, check that their global min/max adds up
891+
# If both disp bounds are strs, warn that the right disp will be ignored
884892
elif isinstance(ds_left, str) and isinstance(ds_right, str):
885-
left_img = rasterio_open(ds_left).read()
886-
right_img = rasterio_open(ds_right).read()
887-
888-
left_min = left_img.min()
889-
left_max = left_img.max()
890-
right_min = right_img.min()
891-
right_max = right_img.max()
892-
893-
if left_min != -right_max or left_max != -right_min:
894-
raise AttributeError(
895-
"The cross-checking step can't be processed if "
896-
"disp_min != -disp_right_max or disp_max != -disp_right_min"
897-
)
898-
899-
elif type(ds_left) != type(ds_right) and ds_left is not None and ds_right is not None:
900-
raise AttributeError(
901-
"The cross-checking step does not support left and right disparities of different kinds at this time."
902-
)
893+
logging.warning("The right disp will be ignored, and instead computed from the left disp.")
903894

904895
def multiscale_check_conf(self, cfg: Dict[str, dict], input_step: str) -> None:
905896
"""

tests/functional_tests/test_validation.py

-26
Original file line numberDiff line numberDiff line change
@@ -158,29 +158,3 @@ def test_validation_fast(self, tmp_path, user_cfg):
158158

159159
# Check they are *strictly* equal
160160
assert error(result_fast, result_accurate, threshold=0) == 0
161-
162-
def test_validation_method_is_mandatory(self):
163-
"""
164-
Test that there's a crash when not giving validation_method
165-
"""
166-
with pytest.raises(KeyError):
167-
validation_ = validation.AbstractValidation(**{"something's wrong": "b"}) # type: ignore
168-
169-
def test_fails_with_invalid_method(self):
170-
"""
171-
Test that there's a crash when giving a validation method that doesn't exist
172-
"""
173-
with pytest.raises(KeyError):
174-
validation_ = validation.AbstractValidation(**{"validation_method": "hello"}) # type: ignore
175-
176-
def test_right_instance_created(self):
177-
"""
178-
Test that the right instance is created when creating specific validation method objects
179-
"""
180-
validation_ = validation.AbstractValidation(**{"validation_method": "cross_checking_fast"}) # type: ignore
181-
assert isinstance(validation_, validation.AbstractValidation)
182-
assert isinstance(validation_, CrossCheckingAccurate) # the instance is the same for both methods
183-
184-
validation_ = validation.AbstractValidation(**{"validation_method": "cross_checking_accurate"}) # type: ignore
185-
assert isinstance(validation_, validation.AbstractValidation)
186-
assert isinstance(validation_, CrossCheckingAccurate)

tests/test_config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -522,8 +522,8 @@ def test_check_conf(self):
522522
"pipeline": copy.deepcopy(common.validation_pipeline_cfg),
523523
}
524524

525-
# When left disparities are grids and right are none, cross checking method cannot be used : the program exits
526-
self.assertRaises(MachineError, check_configuration.check_conf, cfg, pandora_machine)
525+
# When left disparities are grids and right are none, cross checking should succeed
526+
check_configuration.check_conf(cfg, pandora_machine)
527527

528528
# Check the configuration returned with left and right disparity grids and cross checking method
529529
cfg = {

tests/test_validation.py

+28
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525
"""
2626

2727
import unittest
28+
import pytest
2829

2930
import numpy as np
3031
import xarray as xr
3132

3233
from tests import common
3334
import pandora.constants as cst
3435
from pandora import validation
36+
from pandora.validation.validation import CrossCheckingAccurate
3537

3638

3739
class TestValidation(unittest.TestCase):
@@ -73,6 +75,32 @@ def setUp(self):
7375
)
7476
self.right.attrs["offset_row_col"] = 0
7577

78+
def test_validation_method_is_mandatory(self):
79+
"""
80+
Test that there's a crash when not giving validation_method
81+
"""
82+
with pytest.raises(KeyError):
83+
validation.AbstractValidation(**{"something's wrong": "b"}) # type: ignore
84+
85+
def test_fails_with_invalid_method(self):
86+
"""
87+
Test that there's a crash when giving a validation method that doesn't exist
88+
"""
89+
with pytest.raises(KeyError):
90+
validation.AbstractValidation(**{"validation_method": "hello"}) # type: ignore
91+
92+
def test_right_instance_created(self):
93+
"""
94+
Test that the right instance is created when creating specific validation method objects
95+
"""
96+
validation_ = validation.AbstractValidation(**{"validation_method": "cross_checking_fast"}) # type: ignore
97+
assert isinstance(validation_, validation.AbstractValidation)
98+
assert isinstance(validation_, CrossCheckingAccurate) # the instance is the same for both methods
99+
100+
validation_ = validation.AbstractValidation(**{"validation_method": "cross_checking_accurate"}) # type: ignore
101+
assert isinstance(validation_, validation.AbstractValidation)
102+
assert isinstance(validation_, CrossCheckingAccurate)
103+
76104
def test_cross_checking(self):
77105
"""
78106
Test the confidence measure and the validity_mask for the cross checking method,

0 commit comments

Comments
 (0)