Source code for sorbetto.tile.tile

# Copyright (c) 2025-2025, Sebastien Pierard et al.
# SPDX-License-Identifier: Apache-2.0

import io
import logging
from typing import Iterator, SupportsIndex, cast

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable

from sorbetto.annotation.abstract_annotation import AbstractAnnotation
from sorbetto.core.types import Extent
from sorbetto.flavor.abstract_flavor import AbstractFlavor
from sorbetto.parameterization.abstract_parameterization import AbstractParameterization


class Tile:
    """
    Base class for all Tiles.

    A Tile is a graphical representation with:

    - a parameterization;
    - a flavor (may be ``None``);  # TODO: not always because of EmptyTile
    - and some annotations.

    TODO: state precisely what a Tile is a graphical representation of.

    Tiles with the default parameterization are studied in detail in
    :cite:t:`Pierard2024TheTile-arxiv`. Various flavors of Tiles are described
    in :cite:t:`Halin2024AHitchhikers-arxiv` and
    :cite:t:`Pierard2025AMethodology`.
    """

    def __init__(
        self,
        parameterization: AbstractParameterization,
        flavor: AbstractFlavor | None = None,
        name: str = "Tile",
        resolution: int = 1001,
        disable_colorbar: bool = True,
    ):
        """
        Args:
            parameterization (AbstractParameterization): The parameterization
                to be used for the tile.
            flavor (AbstractFlavor | None, optional): The flavor to use.
                Defaults to None.
            name (str | None, optional): Name of the tile. Defaults to
                "Tile". ``None`` is replaced by "Unnamed Tile" and any other
                non-string value is converted with ``str()``, exactly as the
                ``name`` setter does.
            resolution (int, optional): Number of grid points along each of
                the two parameter axes. Defaults to 1001.
            disable_colorbar (bool, optional): If True, :meth:`draw` does not
                add a colorbar. Defaults to True.

        Raises:
            TypeError: If ``parameterization`` or ``flavor`` has the wrong
                type, or if ``resolution`` is not a strictly positive integer.
        """
        if not isinstance(parameterization, AbstractParameterization):
            raise TypeError(
                f"parameterization must be an instance of AbstractParameterization, got {type(parameterization)}"
            )
        self._parameterization = parameterization

        if flavor is not None and not isinstance(flavor, AbstractFlavor):
            raise TypeError(
                f"flavor must be an instance of AbstractFlavor, got {type(flavor)}"
            )
        self._flavor = flavor

        # Go through the property setter so that None / non-str names are
        # normalized the same way at construction time and afterwards.
        self.name = name

        self._resolution = self._check_resolution(resolution)

        # The initial zoom is the full extent of the parameterization.
        self._zoom = self._parameterization.getExtent()
        self._mat_value: np.ndarray | None = None
        self._update_grid()

        self._annotations: list[AbstractAnnotation] = list()
        self._disable_colorbar = disable_colorbar

    @staticmethod
    def _check_resolution(resolution: int) -> int:
        """Return ``resolution`` unchanged if it is a strictly positive int.

        Raises:
            TypeError: If ``resolution`` is not an int or is not strictly
                positive (TypeError is kept in both cases for backward
                compatibility).
        """
        if (not isinstance(resolution, int)) or resolution <= 0:
            raise TypeError(
                f"resolution must be a strictly positive integer, got {resolution!r}"
            )
        return resolution

    @property
    def resolution(self) -> int:
        """Number of grid points along each parameter axis."""
        return self._resolution

    @resolution.setter
    def resolution(self, resolution: int):
        self._resolution = self._check_resolution(resolution)
        # The sampling grid depends on the resolution; rebuild it and drop
        # the cached values.
        self._update_grid()

    @property
    def zoom(self) -> Extent:
        """Displayed extent, as ``(x_min, x_max, y_min, y_max)``."""
        return self._zoom

    @zoom.setter
    def zoom(self, zoom: Extent):
        """Set the zoom to the intersection of ``zoom`` and the extent of the
        parameterization, then rebuild the sampling grid.

        Validation is done with ``assert`` statements (AssertionError), as in
        the original implementation: ``zoom`` must be a 4-tuple of floats
        describing a non-empty rectangle that overlaps the parameterization
        extent.
        """

        def intersection(extent_1: Extent, extent_2: Extent) -> Extent:
            min_x_1, max_x_1, min_y_1, max_y_1 = extent_1
            assert min_x_1 < max_x_1
            assert min_y_1 < max_y_1
            min_x_2, max_x_2, min_y_2, max_y_2 = extent_2
            assert min_x_2 < max_x_2
            assert min_y_2 < max_y_2
            min_x = max(min_x_1, min_x_2)
            max_x = min(max_x_1, max_x_2)
            min_y = max(min_y_1, min_y_2)
            max_y = min(max_y_1, max_y_2)
            # An empty intersection is rejected.
            assert min_x < max_x
            assert min_y < max_y
            return (min_x, max_x, min_y, max_y)

        assert isinstance(zoom, tuple)
        assert len(zoom) == 4
        assert all(isinstance(v, float) for v in zoom)
        extent = self._parameterization.getExtent()
        self._zoom = intersection(zoom, extent)
        # The grid (and the cached values) are derived from the zoom, so they
        # must be refreshed; otherwise they keep describing the old extent.
        self._update_grid()

    @property
    def name(self) -> str:
        """Name of the tile (used as the title in :meth:`draw`)."""
        return self._name

    @name.setter
    def name(self, value):
        if value is None:
            value = "Unnamed Tile"
        elif not isinstance(value, str):
            value = str(value)
        self._name = value

    @property
    def parameterization(self) -> AbstractParameterization:
        """The parameterization given at construction time."""
        return self._parameterization

    @property
    def importances(self):
        """Canonical importances on the current grid, as returned by
        ``parameterization.getCanonicalImportanceVectorized``."""
        return self.parameterization.getCanonicalImportanceVectorized(
            self._mat_x, self._mat_y
        )

    @property
    def flavor(self) -> AbstractFlavor | None:
        """The flavor given at construction time, or None."""
        return self._flavor

    @property
    def disable_colorbar(self) -> bool:
        """Whether :meth:`draw` skips the colorbar."""
        return self._disable_colorbar

    @disable_colorbar.setter
    def disable_colorbar(self, value: bool):
        if not isinstance(value, bool):
            raise TypeError(f"disable_colorbar must be a bool, got {type(value)}")
        self._disable_colorbar = value

    def _update_grid(self):
        """Rebuild the sampling grid from the current zoom and resolution and
        invalidate the cached value matrix."""
        x_min, x_max, y_min, y_max = self._zoom
        assert x_min < x_max
        assert y_min < y_max
        self._vec_x = np.linspace(x_min, x_max, self.resolution)
        self._vec_y = np.linspace(y_min, y_max, self.resolution)
        self._mat_x, self._mat_y = np.meshgrid(
            self._vec_x, self._vec_y, indexing="xy"
        )
        self._mat_value = None

    def getExplanation(self) -> str:
        """Return a textual description of the tile (same as ``str(self)``)."""
        return str(self)

    @property
    def mat_value(self) -> np.ndarray:
        """Values of the flavor on the current grid (computed lazily, cached).

        Without a flavor, a ``(resolution, resolution)`` matrix filled with
        NaN is returned.
        """
        if self._flavor is None:
            return np.full([self.resolution, self.resolution], np.nan)
        if self._mat_value is None:
            self._mat_value = self._compute_mat_value(self._mat_x, self._mat_y)
        return cast(np.ndarray, self._mat_value)

    def _compute_mat_value(
        self,
        param1: list[float] | np.ndarray | None = None,  # TODO: or float ?
        param2: list[float] | np.ndarray | None = None,  # TODO: or float ?
    ):
        """Evaluate ``flavor(importances)`` at the given parameter values.

        Args:
            param1: Values of the first parameter; converted to an array if
                it is not one already.
            param2: Values of the second parameter; converted likewise.

        Returns:
            The result of calling the flavor on the canonical importances.
            Without a flavor, an array of zeros whose shape is the shape of
            the importances without their last axis.
        """
        if not isinstance(param1, np.ndarray):
            param1 = np.array(param1)
        if not isinstance(param2, np.ndarray):
            param2 = np.array(param2)

        importance = self.parameterization.getCanonicalImportanceVectorized(
            param1, param2
        )
        if self.flavor is None:
            # np.zeros, not np.zeros_like: zeros_like applied to the shape
            # tuple would yield a 1-D array as long as the tuple instead of
            # an array having that shape.
            return np.zeros(importance.shape[:-1])
        return self.flavor(importance=importance)

    def __call__(self, param1: np.ndarray, param2: np.ndarray) -> np.ndarray:
        """Evaluate the tile at arbitrary parameter values (not cached)."""
        return self._compute_mat_value(param1, param2)

    # TODO: add an abstract ``getColormap(self) -> np.ndarray`` method.

    def genAnnotations(self) -> Iterator[AbstractAnnotation]:
        """Yield the annotations of the tile, in insertion order."""
        yield from self._annotations

    def appendAnnotation(self, annotation):
        """Add ``annotation`` (an AbstractAnnotation) at the end of the list."""
        assert isinstance(annotation, AbstractAnnotation)
        self._annotations.append(annotation)

    def removeAnnotation(self, annotation):
        """Remove the first occurrence of ``annotation``.

        Raises:
            ValueError: If the annotation is not attached to this tile.
        """
        assert isinstance(annotation, AbstractAnnotation)
        self._annotations.remove(annotation)

    def popAnnotation(self, index: SupportsIndex = -1) -> AbstractAnnotation:
        """
        Remove and return an annotation at index (default last).

        Args:
            index (SupportsIndex, optional): index of the annotation to pop.
                Defaults to -1.

        Returns:
            AbstractAnnotation: the annotation at the specified index.

        Raises:
            IndexError: if the index is out of range.
        """
        return self._annotations.pop(index)

    def clearAnnotations(self):
        """Remove all annotations."""
        self._annotations.clear()

    def draw(
        self, fig: Figure | None = None, ax: Axes | None = None
    ) -> tuple[Figure, Axes]:
        """Draws the Tile in the given figure and axes.

        Every annotation is drawn in turn; an annotation that fails is
        reported with a warning and does not prevent the others from being
        drawn. The axes limits, labels and title are then configured, and a
        colorbar is added unless ``disable_colorbar`` is set.

        Args:
            fig (Figure | None, optional): The figure to draw in. If None, a
                new figure is created. Defaults to None.
            ax (Axes | None, optional): The axes to draw in. If None, a new
                axis is created. Defaults to None. Note that this argument is
                ignored if fig is None.

        Returns:
            The figure and axes used for drawing.
        """
        if fig is None:
            fig = plt.figure()
            ax = fig.gca()
        elif ax is None:
            ax = fig.gca()

        # Draw all annotations (best effort: one failure must not abort the
        # whole drawing).
        for annotation in self.genAnnotations():
            assert isinstance(annotation, AbstractAnnotation)
            try:
                annotation.draw(self, fig, ax)
            except Exception as e:
                logging.warning(
                    "Something went wrong while drawing annotation %r, got %s",
                    annotation.name,
                    e,
                )

        # Configure the limits, axes labels, and title.
        x_min, x_max, y_min, y_max = self._zoom
        assert x_min < x_max
        assert y_min < y_max
        ax.set_xlim((x_min, x_max))
        ax.set_ylim((y_min, y_max))
        ax.set_aspect("equal")

        parameterization = self.parameterization
        ax.set_xlabel(parameterization.getNameParameter1())
        ax.set_ylabel(parameterization.getNameParameter2())
        ax.set_title(self.name)

        if not self.disable_colorbar:
            if ax.images:
                # Create a subdivision of the axis to add a colorbar of same
                # height.
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("right", size="5%", pad="5%")
                fig.colorbar(ax.images[0], cax)
            else:
                # Nothing drew an image on these axes, so there is no
                # mappable to build a colorbar from.
                logging.warning(
                    "Cannot draw a colorbar for tile %r: the axes contain no image",
                    self.name,
                )

        return fig, ax

    def __str__(self) -> str:
        parts = [
            f'This Tile is named "{self.name}".',
            f"\nIt uses the flavor: {self.flavor}.",
            f"\nIt uses the parameterization: {self.parameterization}.",
        ]
        if self._annotations:
            parts.append("\nIt shows the following annotations:\n")
            parts.extend(
                f"- {annotation.name}\n" for annotation in self._annotations
            )
        return "".join(parts)