from typing import Dict, Optional, Union
import torch
from collie.loss.metadata_utils import ideal_difference_from_metadata
def warp_loss(
    positive_scores: torch.Tensor,
    many_negative_scores: torch.Tensor,
    num_items: int,
    positive_items: Optional[torch.Tensor] = None,
    negative_items: Optional[torch.Tensor] = None,
    # NOTE: the ``dict()`` defaults are kept for signature compatibility; this function only
    # reads them and never mutates them.
    metadata: Optional[Dict[str, torch.Tensor]] = dict(),
    metadata_weights: Optional[Dict[str, float]] = dict(),
) -> torch.Tensor:
    """
    Modified WARP loss function [4]_.

    See http://www.thespermwhale.com/jaseweston/papers/wsabie-ijcai.pdf for loss equation.

    See ``ideal_difference_from_metadata`` docstring for more info on how metadata is used.

    For every element of the batch, the negative samples are scanned in order until the first one
    whose hinge loss (``ideal_difference - positive_score + negative_score``) is strictly
    positive. That hinge loss is weighted by ``log(num_items / number_of_tries)``, where
    ``number_of_tries`` is the 1-based position of that negative sample. Batch elements with no
    violating negative sample contribute zero loss.

    Parameters
    ----------
    positive_scores: torch.tensor, 1-d
        Tensor containing scores for known positive items of shape ``batch_size`` (it is reshaped
        internally to ``batch_size x 1``)
    many_negative_scores: torch.tensor, 2-d
        Tensor containing scores for many (n > 1) sampled negative items of shape
        ``num_negative_samples x batch_size``. More negative samples increase the likelihood of
        finding ranking-violating pairs, but risk overfitting
    num_items: int
        Total number of items in the dataset
    positive_items: torch.tensor, 1-d
        Tensor containing ids for known positive items of shape ``batch_size``. It is repeated
        ``num_negative_samples`` times along a new first dimension before being compared with
        ``negative_items``. This is only needed if ``metadata`` is provided
    negative_items: torch.tensor, 2-d
        Tensor containing ids for sampled negative items of shape
        ``num_negative_samples x batch_size``. This is only needed if ``metadata`` is provided
    metadata: dict
        Keys should be strings identifying each metadata type that match keys in
        ``metadata_weights``. Values should be a ``torch.tensor`` of shape (num_items x 1). Each
        tensor should contain categorical metadata information about items (e.g. a number
        representing the genre of the item)
    metadata_weights: dict
        Keys should be strings identifying each metadata type that match keys in ``metadata``.
        Values should be the amount of weight to place on a match of that type of metadata, with
        the sum of all values ``<= 1``.
        e.g. If ``metadata_weights = {'genre': .3, 'director': .2}``, then an item is:

        * a 100% match if it's the same item,

        * a 50% match if it's a different item with the same genre and same director,

        * a 30% match if it's a different item with the same genre and different director,

        * a 20% match if it's a different item with a different genre and same director,

        * a 0% match if it's a different item with a different genre and different director,
          which is equivalent to the loss without any partial credit

    Returns
    -------
    loss: torch.tensor
        Scalar tensor: ``(sum(loss) + sum(loss ** 2)) / batch_size``

    References
    ----------
    .. [4] Weston et al. WSABIE: Scaling Up To Large Vocabulary Image Annotation.
        www.thespermwhale.com/jaseweston/papers/wsabie-ijcai.pdf.

    """
    if metadata is not None and len(metadata) > 0:
        # the item ids are only consumed by the metadata comparison, so only expand them here
        if negative_items is not None and positive_items is not None:
            positive_items = positive_items.repeat([many_negative_scores.shape[0], 1])

        ideal_difference = ideal_difference_from_metadata(
            positive_items=positive_items,
            negative_items=negative_items,
            metadata=metadata,
            metadata_weights=metadata_weights,
        ).transpose(1, 0)
    else:
        ideal_difference = 1

    # device to put new tensors on
    device = positive_scores.device

    # WARP loss requires a different structure for positive and negative samples: one row per
    # batch element, one column per negative sample
    positive_scores = positive_scores.view(len(positive_scores), 1)
    many_negative_scores = torch.transpose(many_negative_scores, 0, 1)

    batch_size, max_trials = many_negative_scores.size(0), many_negative_scores.size(1)

    # at this point this is essentially just hinge loss
    hinge_loss = ideal_difference - positive_scores + many_negative_scores

    # Add a column of ones so we know when we have used all our attempts: if no real hinge loss
    # in a row is above 0, this sentinel column is the one that gets selected, and the row is
    # masked out by ``should_we_count_loss`` below.
    tensor_of_ones = torch.ones(batch_size, 1, dtype=torch.float32, device=device)
    initial_loss_with_ones = torch.cat([hinge_loss, tensor_of_ones], dim=1)

    # The search for the first violation is not differentiable and the helper is allowed to
    # overwrite its argument, so hand it a detached private copy.
    first_violation_index = _find_first_loss_violation(
        initial_loss_with_ones.detach().clone(), device
    )

    # pick, for every row, the hinge loss at the first violating column
    first_violation_loss = initial_loss_with_ones.gather(
        1, first_violation_index.unsqueeze(1)
    ).squeeze(1)

    # convert the 0-based column index into a 1-based count of attempts
    number_of_tries = (first_violation_index + 1).float()

    # IMPORTANT CHANGE: normal WARP weighting has the numerator set to ``num_items - 1``, but we
    # have found this does not penalize when the last item in a negative item sequence ranks above
    # a positive item score. Adjusting the numerator as below penalizes this correctly.
    # Additionally, adding a floor function to the numerator can also have the same negative
    # effect of not counting loss. The original formulation, for reference, is
    # ``torch.log(torch.floor((num_items - 1) / number_of_tries))``; our modified, harsher
    # calculation is implemented below.
    loss_weights = torch.log((num_items / number_of_tries))

    # don't count loss if we used max number of attempts looking for a violation and didn't find one
    should_we_count_loss = (number_of_tries <= max_trials).float()

    loss = loss_weights * first_violation_loss * should_we_count_loss

    return (loss.sum() + loss.pow(2).sum()) / len(positive_scores)
def _find_first_loss_violation(losses: torch.tensor,
device: Union[str, torch.device, torch.cuda.device]) -> torch.tensor:
"""
Find the index of the first violation where ``1 - positive_score + negative_score`` is greater
than 0.
"""
# set all negative losses to 0 and all positive losses to 1
losses[losses < 0] = 0
losses[losses > 0] = 1
# after this, maximum value will be the first non-zero (bad) loss
reverse_indices = torch.arange(losses.shape[1], 0, -1).to(device)
min_index_of_good_loss = losses * reverse_indices
# report the first loss that is positive here (not 0-based indexed)
number_of_tries = torch.argmax(min_index_of_good_loss, 1, keepdim=True).flatten()
return number_of_tries