Skip to content

Commit a2ced54

Browse files
dkurtthewoz
authored andcommitted
Merge pull request opencv#23691 from dkurt:pycv_float16_fixes
Import and export np.float16 in Python opencv#23691 ### Pull Request Readiness Checklist * Also, fixes `cv::norm` with `NORM_INF` and `CV_16F` resolves opencv#23687 See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
1 parent aa536e7 commit a2ced54

File tree

4 files changed

+9
-8
lines changed

4 files changed

+9
-8
lines changed

modules/core/src/norm.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ double norm( InputArray _src, int normType, InputArray _mask )
753753
{
754754
int bsz = std::min(total - j, blockSize);
755755
hal::cvt16f32f((const float16_t*)ptrs[0], data0, bsz * cn);
756-
func((uchar*)data0, ptrs[1], (uchar*)&result.d, bsz, cn);
756+
func((uchar*)data0, ptrs[1], (uchar*)&result.f, bsz, cn);
757757
ptrs[0] += bsz*esz;
758758
if (ptrs[1])
759759
ptrs[1] += bsz;
@@ -771,9 +771,9 @@ double norm( InputArray _src, int normType, InputArray _mask )
771771

772772
if( normType == NORM_INF )
773773
{
774-
if(depth == CV_64F || depth == CV_16F)
774+
if(depth == CV_64F)
775775
return result.d;
776-
else if (depth == CV_32F)
776+
else if (depth == CV_32F || depth == CV_16F)
777777
return result.f;
778778
else
779779
return result.i;
@@ -1224,7 +1224,7 @@ double norm( InputArray _src1, InputArray _src2, int normType, InputArray _mask
12241224
int bsz = std::min(total - j, blockSize);
12251225
hal::cvt16f32f((const float16_t*)ptrs[0], data0, bsz * cn);
12261226
hal::cvt16f32f((const float16_t*)ptrs[1], data1, bsz * cn);
1227-
func((uchar*)data0, (uchar*)data1, ptrs[2], (uchar*)&result.d, bsz, cn);
1227+
func((uchar*)data0, (uchar*)data1, ptrs[2], (uchar*)&result.f, bsz, cn);
12281228
ptrs[0] += bsz*esz;
12291229
ptrs[1] += bsz*esz;
12301230
if (ptrs[2])
@@ -1243,9 +1243,9 @@ double norm( InputArray _src1, InputArray _src2, int normType, InputArray _mask
12431243

12441244
if( normType == NORM_INF )
12451245
{
1246-
if (depth == CV_64F || depth == CV_16F)
1246+
if (depth == CV_64F)
12471247
return result.d;
1248-
else if (depth == CV_32F)
1248+
else if (depth == CV_32F || depth == CV_16F)
12491249
return result.f;
12501250
else
12511251
return result.u;

modules/python/src2/cv2_convert.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ bool pyopencv_to(PyObject* o, Mat& m, const ArgInfo& info)
8888
typenum == NPY_SHORT ? CV_16S :
8989
typenum == NPY_INT ? CV_32S :
9090
typenum == NPY_INT32 ? CV_32S :
91+
typenum == NPY_HALF ? CV_16F :
9192
typenum == NPY_FLOAT ? CV_32F :
9293
typenum == NPY_DOUBLE ? CV_64F : -1;
9394

modules/python/src2/cv2_numpy.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ UMatData* NumpyAllocator::allocate(int dims0, const int* sizes, int type, void*
3939
int typenum = depth == CV_8U ? NPY_UBYTE : depth == CV_8S ? NPY_BYTE :
4040
depth == CV_16U ? NPY_USHORT : depth == CV_16S ? NPY_SHORT :
4141
depth == CV_32S ? NPY_INT : depth == CV_32F ? NPY_FLOAT :
42-
depth == CV_64F ? NPY_DOUBLE : f*NPY_ULONGLONG + (f^1)*NPY_UINT;
42+
depth == CV_64F ? NPY_DOUBLE : depth == CV_16F ? NPY_HALF : f*NPY_ULONGLONG + (f^1)*NPY_UINT;
4343
int i, dims = dims0;
4444
cv::AutoBuffer<npy_intp> _sizes(dims + 1);
4545
for( i = 0; i < dims; i++ )

modules/python/test/test_norm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def get_element_types(norm_type):
8888
return (np.uint8,)
8989
else:
9090
return (np.uint8, np.int8, np.uint16, np.int16, np.int32, np.float32,
91-
np.float64)
91+
np.float64, np.float16)
9292

9393

9494
def generate_vector(shape, dtype):

0 commit comments

Comments
 (0)