Source code for sorbetto.flavor.ranking_flavor

from typing import Any

import matplotlib.pyplot as plt
import numpy as np

from sorbetto.core.entity import Entity
from sorbetto.core.importance import Importance
from sorbetto.flavor.abstract_numeric_flavor import AbstractNumericFlavor
from sorbetto.performance.finite_set_of_two_class_classification_performances import (
    FiniteSetOfTwoClassClassificationPerformances,
)
from sorbetto.ranking.ranking_score import RankingScore


class RankingFlavor(AbstractNumericFlavor):
    """
    For a given entity :math:`e`, the *Ranking Flavor* is the mathematical
    function that gives, to any Importance :math:`I` (that is, some
    application-specific preferences), the rank of :math:`e` within a given
    list of entities, according to the ordering of performances induced by the
    Ranking Score :math:`R_I` corresponding to the importance :math:`I`.
    Rank 1 corresponds to the highest ranking score.

    Entities sharing the same ranking score are handled according to
    ``tie_method``:

    * ``"ordinal"`` (default): tied entities are ranked in the order in which
      they appear in ``entity_list`` (deterministic tie-breaking);
    * ``"min"``: every tied entity receives the best (smallest) rank of the tie;
    * ``"max"``: every tied entity receives the worst (largest) rank of the tie;
    * ``"mean"``: every tied entity receives the average rank of the tie.
    """

    # Accepted values for the ``tie_method`` constructor argument.
    _TIE_METHODS = ("ordinal", "min", "max", "mean")

    def __init__(
        self,
        entity: Entity,
        entity_list: list[Entity],
        name: str = "Unnamed Ranking Flavor",
        colormap: Any = None,
        tie_method: str = "ordinal",
    ):
        """
        Build the flavor giving the rank of ``entity`` among ``entity_list``.

        Args:
            entity: the entity whose rank is returned; must be an element of
                ``entity_list`` (looked up with ``list.index``, i.e. ``==``).
            entity_list: the entities being ranked; their ``performance``
                attributes are gathered into a
                ``FiniteSetOfTwoClassClassificationPerformances``.
            name: display name forwarded to the base class.
            colormap: colormap forwarded to the base class.
            tie_method: how entities with equal ranking scores are ranked; one
                of ``"ordinal"``, ``"min"``, ``"max"``, ``"mean"``.

        Raises:
            ValueError: if ``tie_method`` is unknown, or if ``entity`` is not
                found in ``entity_list``.
        """
        if tie_method not in self._TIE_METHODS:
            raise ValueError(
                f"Unknown tie method {tie_method!r}; "
                f"expected one of {self._TIE_METHODS}."
            )
        super().__init__(name=name, colormap=colormap)
        self._entity = entity
        self._entity_list = entity_list
        self._nb_entities = len(entity_list)
        self._tie_method = tie_method
        self._performances = FiniteSetOfTwoClassClassificationPerformances(
            [e.performance for e in entity_list]
        )
        try:
            self._id_entity = self._entity_list.index(entity)
        except ValueError as exc:
            raise ValueError(
                "The given entity was not found in the given entity list."
            ) from exc

    @property
    def entity(self) -> Entity:
        """The entity whose rank this flavor returns."""
        return self._entity

    @property
    def entity_list(self) -> list[Entity]:
        """The list of entities being ranked."""
        return self._entity_list

    @property
    def nb_entities(self) -> int:
        """Number of entities in ``entity_list``."""
        return self._nb_entities

    @property
    def performances(self) -> FiniteSetOfTwoClassClassificationPerformances:
        """Performances of the entities, in ``entity_list`` order."""
        return self._performances

    @property
    def id_entity(self) -> int:
        """Index of ``entity`` within ``entity_list``."""
        return self._id_entity

    @property
    def tie_method(self) -> str:
        """Tie-handling strategy used when ranking equal scores."""
        return self._tie_method

    def __call__(
        self,
        importance: Importance | np.ndarray,
    ) -> float | np.ndarray:
        """
        Return the rank (1 = best) of ``entity`` for the given importance(s).

        Args:
            importance: an Importance, or an array of importances, forwarded
                to ``RankingScore._compute``.

        Returns:
            The rank of ``entity``. Axis 0 of the score array returned by
            ``RankingScore._compute`` is treated as the entity axis, so the
            result has the shape of the remaining axes (a scalar for a single
            importance). Ranks are integer-valued, except for
            ``tie_method="mean"`` where half-integers may occur.
        """
        values = np.asarray(
            RankingScore._compute(
                importance=importance, performance=self._performances
            )
        )

        if self._tie_method == "ordinal":
            # A stable sort keeps tied entities in entity_list order, so the
            # result is deterministic (NumPy's default sort is not stable).
            order = np.argsort(-values, axis=0, kind="stable")
            positions = np.argsort(order, axis=0) + 1
            return positions[self._id_entity]

        # Compare the entity's own score with every entity's score; the
        # entity's row broadcasts against the full array along axis 0.
        own = values[self._id_entity]
        best_rank = 1 + np.sum(values > own, axis=0)
        worst_rank = np.sum(values >= own, axis=0)
        if self._tie_method == "min":
            return best_rank
        if self._tie_method == "max":
            return worst_rank
        return (best_rank + worst_rank) / 2

    def getDefaultColormap(self):
        """Return the "rainbow" colormap resampled to one color per entity."""
        return plt.get_cmap("rainbow", self.nb_entities)

    def getLowerBound(self):
        """Return the best possible rank, 1."""
        return 1.0

    def getUpperBound(self):
        """Return the worst possible rank, i.e. the number of entities."""
        return self._nb_entities