diff --git a/precise/functions.py b/precise/functions.py index 12fb6de..3bfeec5 100644 --- a/precise/functions.py +++ b/precise/functions.py @@ -25,7 +25,8 @@ def weighted_log_loss(yt, yp) -> Any: pos_loss = -(0 + yt) * K.log(0 + yp + K.epsilon()) neg_loss = -(1 - yt) * K.log(1 - yp + K.epsilon()) - return weight * K.sum(neg_loss) + (1. - weight) * K.sum(pos_loss) + + return weight * K.mean(neg_loss) + (1. - weight) * K.mean(pos_loss) def weighted_mse_loss(yt, yp) -> Any: