diff --git a/precise/functions.py b/precise/functions.py index 44daa8b..2655353 100644 --- a/precise/functions.py +++ b/precise/functions.py @@ -13,6 +13,8 @@ # limitations under the License. from typing import * +LOSS_BIAS = 0.9 # [0..1] where 1 is inf bias + def weighted_log_loss(yt, yp) -> Any: """ @@ -21,23 +23,21 @@ def weighted_log_loss(yt, yp) -> Any: yp: Prediction """ from keras import backend as K - weight = 0.7 # [0..1] where 1 is inf bias pos_loss = -(0 + yt) * K.log(0 + yp + K.epsilon()) neg_loss = -(1 - yt) * K.log(1 - yp + K.epsilon()) - return weight * K.mean(neg_loss) + (1. - weight) * K.mean(pos_loss) + return LOSS_BIAS * K.mean(neg_loss) + (1. - LOSS_BIAS) * K.mean(pos_loss) def weighted_mse_loss(yt, yp) -> Any: from keras import backend as K - weight = 0.9 # [0..1] where 1 is inf bias total = K.sum(K.ones_like(yt)) neg_loss = total * K.sum(K.square(yp * (1 - yt))) / K.sum(1 - yt) pos_loss = total * K.sum(K.square(1. - (yp * yt))) / K.sum(yt) - return weight * neg_loss + (1. - weight) * pos_loss + return LOSS_BIAS * neg_loss + (1. - LOSS_BIAS) * pos_loss def false_pos(yt, yp) -> Any: