Skip to content

add ciou and diou #914

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
13 changes: 13 additions & 0 deletions tensorflow_addons/image/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ py_library(
"connected_components.py",
"resampler_ops.py",
"compose_ops.py",
"iou_ops.py",
]),
data = [
":sparse_image_warp_test_data",
Expand Down Expand Up @@ -177,3 +178,15 @@ py_test(
":image",
],
)

# Runs iou_ops_test.py against the ":image" py_library target.
py_test(
name = "iou_ops_test",
size = "medium",
srcs = [
"iou_ops_test.py",
],
main = "iou_ops_test.py",
deps = [
":image",
],
)
4 changes: 4 additions & 0 deletions tensorflow_addons/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,7 @@
from tensorflow_addons.image.translate_ops import translate
from tensorflow_addons.image.translate_ops import translate_xy
from tensorflow_addons.image.compose_ops import blend
from tensorflow_addons.image.iou_ops import iou
from tensorflow_addons.image.iou_ops import ciou
from tensorflow_addons.image.iou_ops import diou
from tensorflow_addons.image.iou_ops import giou
183 changes: 183 additions & 0 deletions tensorflow_addons/image/iou_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements IoUs."""

import tensorflow as tf
import math
import numpy as np
from typing import Union

# Annotation alias for the box arguments and box extents used in this module:
# a tf.Tensor or a Python / NumPy floating-point scalar.
CompatibleFloatTensorLike = Union[tf.Tensor, float, np.float32, np.float64]


def _get_v(
    b1_height: CompatibleFloatTensorLike,
    b1_width: CompatibleFloatTensorLike,
    b2_height: CompatibleFloatTensorLike,
    b2_width: CompatibleFloatTensorLike,
) -> tf.Tensor:
    """Computes the CIoU aspect-ratio consistency term `v`.

    `v = 4 / pi**2 * (atan(b1_width / b1_height) - atan(b2_width / b2_height))**2`

    Both ratios use `tf.math.divide_no_nan`, so a zero height yields a ratio
    of 0 instead of inf/NaN.

    The value is wrapped in `tf.custom_gradient`; the backward function only
    returns gradients for `b2_height` and `b2_width`. `b1_height` and
    `b1_width` are captured from the enclosing scope rather than passed as
    arguments of the differentiated function.

    Args:
      b1_height: height of the first box(es).
      b1_width: width of the first box(es).
      b2_height: height of the second box(es); differentiated.
      b2_width: width of the second box(es); differentiated.

    Returns:
      A `Tensor` holding `v`, with the broadcast shape of the inputs.
    """

    @tf.custom_gradient
    def _aspect_ratio_term(height, width):
        angle_gap = tf.atan(tf.math.divide_no_nan(b1_width, b1_height)) - tf.atan(
            tf.math.divide_no_nan(width, height)
        )
        v = 4 * ((angle_gap / math.pi) ** 2)

        def _backward(upstream):
            # Exact derivatives of v:
            #   dv/dwidth  = -8 / pi**2 * angle_gap * height / (width**2 + height**2)
            #   dv/dheight = +8 / pi**2 * angle_gap * width  / (width**2 + height**2)
            # The positive 1 / (width**2 + height**2) factor is left out, as in
            # the original code (presumably for numerical stability with small
            # boxes, following the CIoU paper -- TODO confirm). Dropping it only
            # rescales the step, so the signs below must match the exact
            # derivative; the previous code had both signs inverted, which
            # would drive the aspect ratio away from the target.
            grad_width = -upstream * 8 * angle_gap * height / (math.pi ** 2)
            grad_height = upstream * 8 * angle_gap * width / (math.pi ** 2)
            # Order must match the (height, width) argument order.
            return [grad_height, grad_width]

        return v, _backward

    return _aspect_ratio_term(b2_height, b2_width)


def _common_iou(
    b1: CompatibleFloatTensorLike, b2: CompatibleFloatTensorLike, mode: str = "iou"
) -> tf.Tensor:
    """Computes IoU, GIoU, DIoU or CIoU between two sets of boxes.

    Note that this returns the overlap metric itself, not a loss.

    Args:
      b1: bounding box(es). The last axis holds the coordinates encoded as
        [y_min, x_min, y_max, x_max]. Non-floating inputs are cast to
        `tf.float32`.
      b2: the other bounding box(es), encoded like `b1`. Cast to the dtype
        of `b1`. Leading dimensions are combined with those of `b1` by the
        usual element-wise broadcasting of the arithmetic ops.
      mode: one of ['iou', 'ciou', 'diou', 'giou'], selecting which metric
        is computed.

    Returns:
      A float `Tensor` with the requested metric. The result is passed
      through `tf.squeeze`, so every size-1 dimension is removed (a single
      pair of boxes yields a scalar).

    Raises:
      ValueError: if `mode` is not one of the supported values.
    """
    # Fail fast, before any ops are created.
    if mode not in ("iou", "giou", "ciou", "diou"):
        raise ValueError(
            "Value of mode should be one of ['iou','giou','ciou','diou']"
        )

    b1 = tf.convert_to_tensor(b1)
    if not b1.dtype.is_floating:
        b1 = tf.cast(b1, tf.float32)
    b2 = tf.cast(b2, b1.dtype)

    zero = tf.convert_to_tensor(0.0, b1.dtype)
    b1_ymin, b1_xmin, b1_ymax, b1_xmax = tf.unstack(b1, 4, axis=-1)
    b2_ymin, b2_xmin, b2_ymax, b2_xmax = tf.unstack(b2, 4, axis=-1)

    # Extents are clamped at zero so inverted boxes count as empty.
    b1_width = tf.maximum(zero, b1_xmax - b1_xmin)
    b1_height = tf.maximum(zero, b1_ymax - b1_ymin)
    b2_width = tf.maximum(zero, b2_xmax - b2_xmin)
    b2_height = tf.maximum(zero, b2_ymax - b2_ymin)
    b1_area = b1_width * b1_height
    b2_area = b2_width * b2_height

    intersect_ymin = tf.maximum(b1_ymin, b2_ymin)
    intersect_xmin = tf.maximum(b1_xmin, b2_xmin)
    intersect_ymax = tf.minimum(b1_ymax, b2_ymax)
    intersect_xmax = tf.minimum(b1_xmax, b2_xmax)
    intersect_width = tf.maximum(zero, intersect_xmax - intersect_xmin)
    intersect_height = tf.maximum(zero, intersect_ymax - intersect_ymin)
    intersect_area = intersect_width * intersect_height

    union_area = b1_area + b2_area - intersect_area
    iou = tf.math.divide_no_nan(intersect_area, union_area)
    if mode == "iou":
        return tf.squeeze(iou)

    # Smallest axis-aligned box enclosing both inputs (shared by the
    # GIoU, DIoU and CIoU penalties).
    enclose_ymin = tf.minimum(b1_ymin, b2_ymin)
    enclose_xmin = tf.minimum(b1_xmin, b2_xmin)
    enclose_ymax = tf.maximum(b1_ymax, b2_ymax)
    enclose_xmax = tf.maximum(b1_xmax, b2_xmax)

    if mode == "giou":
        enclose_width = tf.maximum(zero, enclose_xmax - enclose_xmin)
        enclose_height = tf.maximum(zero, enclose_ymax - enclose_ymin)
        enclose_area = enclose_width * enclose_height
        giou = iou - tf.math.divide_no_nan(
            (enclose_area - union_area), enclose_area
        )
        return tf.squeeze(giou)

    # DIoU / CIoU: penalise the squared distance between box centres,
    # normalised by the squared diagonal of the enclosing box. The squares
    # are computed directly instead of squaring a norm: the gradient of a
    # norm is undefined at zero (coincident centres), whereas the squared
    # distance is smooth everywhere.
    center_dy = (b2_ymin + b2_ymax) / 2 - (b1_ymin + b1_ymax) / 2
    center_dx = (b2_xmin + b2_xmax) / 2 - (b1_xmin + b1_xmax) / 2
    center_dist_sq = center_dy ** 2 + center_dx ** 2
    diag_sq = (enclose_ymax - enclose_ymin) ** 2 + (enclose_xmax - enclose_xmin) ** 2
    # divide_no_nan keeps the result finite when the enclosing box
    # degenerates to a point, consistent with the other ratios above.
    diou = iou - tf.math.divide_no_nan(center_dist_sq, diag_sq)
    if mode == "diou":
        return tf.squeeze(diou)

    # CIoU additionally penalises aspect-ratio mismatch.
    v = _get_v(b1_height, b1_width, b2_height, b2_width)
    alpha = tf.math.divide_no_nan(v, ((1 - iou) + v))
    return tf.squeeze(diou - alpha * v)


def iou(b1: CompatibleFloatTensorLike, b2: CompatibleFloatTensorLike) -> tf.Tensor:
    """Computes the intersection-over-union (IoU) of two sets of boxes.

    Thin wrapper that delegates to `_common_iou` with `mode="iou"`.

    Args:
      b1: bounding box(es); the last axis is encoded as
        [y_min, x_min, y_max, x_max].
      b2: the other bounding box(es), with the same encoding as `b1`.

    Returns:
      Float `Tensor` holding the IoU value(s).
    """
    return _common_iou(b1, b2, mode="iou")


def ciou(b1: CompatibleFloatTensorLike, b2: CompatibleFloatTensorLike) -> tf.Tensor:
    """Computes the complete IoU (CIoU) of two sets of boxes.

    Thin wrapper that delegates to `_common_iou` with `mode="ciou"`.

    Args:
      b1: bounding box(es); the last axis is encoded as
        [y_min, x_min, y_max, x_max].
      b2: the other bounding box(es), with the same encoding as `b1`.

    Returns:
      Float `Tensor` holding the CIoU value(s).
    """
    return _common_iou(b1, b2, mode="ciou")


def diou(b1: CompatibleFloatTensorLike, b2: CompatibleFloatTensorLike) -> tf.Tensor:
    """Computes the distance IoU (DIoU) of two sets of boxes.

    Thin wrapper that delegates to `_common_iou` with `mode="diou"`.

    Args:
      b1: bounding box(es); the last axis is encoded as
        [y_min, x_min, y_max, x_max].
      b2: the other bounding box(es), with the same encoding as `b1`.

    Returns:
      Float `Tensor` holding the DIoU value(s).
    """
    return _common_iou(b1, b2, mode="diou")


def giou(b1: CompatibleFloatTensorLike, b2: CompatibleFloatTensorLike) -> tf.Tensor:
    """Computes the generalized IoU (GIoU) of two sets of boxes.

    Thin wrapper that delegates to `_common_iou` with `mode="giou"`.

    Args:
      b1: bounding box(es); the last axis is encoded as
        [y_min, x_min, y_max, x_max].
      b2: the other bounding box(es), with the same encoding as `b1`.

    Returns:
      Float `Tensor` holding the GIoU value(s).
    """
    return _common_iou(b1, b2, mode="giou")
87 changes: 87 additions & 0 deletions tensorflow_addons/image/iou_ops_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for IoU losses."""

from absl.testing import parameterized

import numpy as np
import tensorflow as tf
from tensorflow_addons.utils import test_utils
from tensorflow_addons.image import iou, ciou, diou, giou


@test_utils.run_all_in_graph_and_eager_modes
class IoUTest(tf.test.TestCase, parameterized.TestCase):
    """Checks iou / ciou / diou / giou against hand-computed per-box values.

    Reference arithmetic for the box pairs used below
    (boxes are [y_min, x_min, y_max, x_max]):

      A: [4, 3, 7, 5] vs [3, 4, 6, 8]
         IoU = 2 / 16, centre distance^2 = 5, enclosing diagonal^2 = 41
         DIoU = 0.125 - 5 / 41
      B: [5, 6, 10, 7] vs [14, 14, 15, 15]
         IoU = 0, centre distance^2 = 113, enclosing diagonal^2 = 181
         DIoU = -113 / 181
      C: [5, 6, 10, 7] vs [3, 4, 6, 8]
         IoU = 1 / 16, centre distance^2 = 9.25, enclosing diagonal^2 = 65
         DIoU = 0.0625 - 9.25 / 65

    CIoU subtracts the per-box aspect-ratio penalty alpha * v from DIoU.
    The distance penalty is evaluated independently for every box pair; it
    is not pooled over the batch.
    """

    @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
    def test_ious_loss(self, dtype):
        # Pairs A and B, matched element-wise.
        boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], dtype=dtype)
        boxes2 = tf.constant(
            [[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]], dtype=dtype
        )
        cases = [
            ("iou", iou, [0.125, 0.0]),
            ("ciou", ciou, [0.000686947503852, -0.6415314339487965]),
            ("diou", diou, [0.0030487804878049, -0.6243093922651934]),
            ("giou", giou, [-0.07500000298023224, -0.9333333373069763]),
        ]
        for name, iou_loss_imp, expected in cases:
            expected_result = tf.constant(expected, dtype=dtype)
            with self.subTest(metric=name):
                loss = iou_loss_imp(boxes1, boxes2)
                self.assertAllCloseAccordingToType(loss, expected_result)

    @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
    def test_different_shapes(self, dtype):
        # Pairs A and C: boxes1 of shape [2, 1, 4] broadcast against a single
        # box of shape [1, 1, 4]; the result is squeezed to shape [2].
        boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], dtype=dtype)
        boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0]], dtype=dtype)
        expand_boxes1 = tf.expand_dims(boxes1, -2)
        expand_boxes2 = tf.expand_dims(boxes2, 0)
        cases = [
            ("iou", iou, [0.125, 0.0625]),
            ("ciou", ciou, [0.000686947503852, -0.1202268105964954]),
            ("diou", diou, [0.0030487804878049, -0.0798076923076923]),
            ("giou", giou, [-0.075, -0.3660714285714286]),
        ]
        for name, iou_loss_imp, expected in cases:
            expected_result = tf.constant(expected, dtype=dtype)
            with self.subTest(metric=name):
                loss = iou_loss_imp(expand_boxes1, expand_boxes2)
                self.assertAllCloseAccordingToType(loss, expected_result)

    @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
    def test_one_bbox(self, dtype):
        # Pair A as rank-1 inputs; every metric collapses to a scalar.
        boxes1 = tf.constant([4.0, 3.0, 7.0, 5.0], dtype=dtype)
        boxes2 = tf.constant([3.0, 4.0, 6.0, 8.0], dtype=dtype)
        cases = [
            ("iou", iou, 0.125),
            ("ciou", ciou, 0.000686947503852),
            ("diou", diou, 0.0030487804878),
            ("giou", giou, -0.075),
        ]
        for name, iou_loss_imp, expected in cases:
            expected_result = tf.constant(expected, dtype=dtype)
            with self.subTest(metric=name):
                loss = iou_loss_imp(boxes1, boxes2)
                self.assertAllCloseAccordingToType(loss, expected_result)


# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    tf.test.main()
9 changes: 5 additions & 4 deletions tensorflow_addons/losses/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ py_library(
"__init__.py",
"contrastive.py",
"focal_loss.py",
"giou_loss.py",
"iou_loss.py",
"lifted.py",
"metric_learning.py",
"npairs.py",
Expand All @@ -18,6 +18,7 @@ py_library(
],
deps = [
"//tensorflow_addons/activations",
"//tensorflow_addons/image",
"//tensorflow_addons/utils",
],
)
Expand Down Expand Up @@ -47,12 +48,12 @@ py_test(
)

py_test(
name = "giou_loss_test",
name = "iou_loss_test",
size = "small",
srcs = [
"giou_loss_test.py",
"iou_loss_test.py",
],
main = "giou_loss_test.py",
main = "iou_loss_test.py",
deps = [
":losses",
],
Expand Down
5 changes: 4 additions & 1 deletion tensorflow_addons/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
sigmoid_focal_crossentropy,
SigmoidFocalCrossEntropy,
)
from tensorflow_addons.losses.giou_loss import giou_loss, GIoULoss
from tensorflow_addons.losses.iou_loss import iou_loss, IoULoss
from tensorflow_addons.losses.iou_loss import ciou_loss, CIoULoss
from tensorflow_addons.losses.iou_loss import diou_loss, DIoULoss
from tensorflow_addons.losses.iou_loss import giou_loss, GIoULoss
from tensorflow_addons.losses.lifted import lifted_struct_loss, LiftedStructLoss
from tensorflow_addons.losses.sparsemax_loss import sparsemax_loss, SparsemaxLoss
from tensorflow_addons.losses.triplet import (
Expand Down
Loading