import io
import logging
from typing import Iterator, SupportsIndex, cast
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable
from sorbetto.annotation.abstract_annotation import AbstractAnnotation
from sorbetto.core.types import Extent
from sorbetto.flavor.abstract_flavor import AbstractFlavor
from sorbetto.parameterization.abstract_parameterization import AbstractParameterization
[docs]
class Tile:
"""
This is the base class for all Tiles. A Tile is a graphical representation (of what ????) with:
- a parameterization;
- a flavor; # TDOO: not always because of EmptyTile
- and some annotations.
Tiles with the default parameterization are studied in detail in :cite:t:`Pierard2024TheTile-arxiv`.
Various flavors of Tiles are described in :cite:t:`Halin2024AHitchhikers-arxiv` and :cite:t:`Pierard2025AMethodology`.
"""
def __init__(
self,
parameterization: AbstractParameterization,
flavor: AbstractFlavor | None = None,
name: str = "Tile",
resolution: int = 1001,
disable_colorbar: bool = True,
):
"""
Args:
parameterization (AbstractParameterization): The parameterization to be used
for the tile.
flavor (AbstractFlavor | None, optional): The flavor to use. Defaults to None.
resolution (int, optional): Resolution of the tile. Defaults to 1001.
name (str | None, optional): Name of the tile. Defaults to None.
Raises:
TypeError: If the types of the arguments are incorrect.
"""
if not isinstance(parameterization, AbstractParameterization):
raise TypeError(
f"parameterization must be an instance of AbstractParameterization, got {type(parameterization)}"
)
self._parameterization = parameterization
if flavor is not None:
if not isinstance(flavor, AbstractFlavor):
raise TypeError(
f"flavor must be an instance of AbstractFlavor, got {type(flavor)}"
)
self._flavor = flavor
self._name = name
if (not isinstance(resolution, int)) or resolution <= 0:
raise TypeError(
f"resolution must be a strictly positive integer, got {resolution!r}"
)
self._resolution = resolution
self._zoom = self._parameterization.getExtent()
self._mat_value: np.ndarray | None = None
self._update_grid()
self._annotations: list[AbstractAnnotation] = list()
self._disable_colorbar = disable_colorbar
@property
def resolution(self) -> int:
return self._resolution
@resolution.setter
def resolution(self, resolution: int):
if (not isinstance(resolution, int)) or resolution <= 0:
raise TypeError(
f"resolution must be a strictly positive integer, got {resolution!r}"
)
self._resolution = resolution
self._update_grid()
@property
def zoom(self) -> Extent:
return self._zoom
@zoom.setter
def zoom(self, zoom: Extent):
def intersection(extent_1: Extent, extent_2: Extent) -> Extent:
min_x_1, max_x_1, min_y_1, max_y_1 = extent_1
assert min_x_1 < max_x_1
assert min_y_1 < max_y_1
min_x_2, max_x_2, min_y_2, max_y_2 = extent_2
assert min_x_2 < max_x_2
assert min_y_2 < max_y_2
min_x = max(min_x_1, min_x_2)
max_x = min(max_x_1, max_x_2)
min_y = max(min_y_1, min_y_2)
max_y = min(max_y_1, max_y_2)
assert min_x < max_x
assert min_y < max_y
return (min_x, max_x, min_y, max_y)
assert isinstance(zoom, tuple)
assert len(zoom) == 4
assert all(isinstance(v, float) for v in zoom)
extent = self._parameterization.getExtent()
self._zoom = intersection(zoom, extent)
@property
def name(self) -> str:
return self._name
@name.setter
def name(self, value):
if value is None:
value = "Unnamed Tile"
elif not isinstance(value, str):
value = str(value)
self._name = value
@property
def parameterization(self) -> AbstractParameterization:
return self._parameterization
@property
def importances(self):
return self.parameterization.getCanonicalImportanceVectorized(
self._mat_x, self._mat_y
)
@property
def flavor(self) -> AbstractFlavor | None:
return self._flavor
@property
def disable_colorbar(self) -> bool:
return self._disable_colorbar
@disable_colorbar.setter
def disable_colorbar(self, value: bool):
if not isinstance(value, bool):
raise TypeError(f"disable_colorbar must be a bool, got {type(value)}")
self._disable_colorbar = value
def _update_grid(self):
x_min, x_max, y_min, y_max = self._zoom
assert x_min < x_max
assert y_min < y_max
vec_x = np.linspace(x_min, x_max, self.resolution)
self._vec_x = vec_x
vec_y = np.linspace(y_min, y_max, self.resolution)
self._vec_y = vec_y
self._mat_x, self._mat_y = np.meshgrid(vec_x, vec_y, indexing="xy")
self._mat_value = None
[docs]
def getExplanation(self) -> str:
return self.__str__()
@property
def mat_value(self) -> np.ndarray:
if self._flavor is None:
tmp = np.empty([self.resolution, self.resolution])
tmp[:] = np.nan
return tmp
if self._mat_value is None:
self._mat_value = self._compute_mat_value(self._mat_x, self._mat_y)
return cast(np.ndarray, self._mat_value)
def _compute_mat_value(
self,
param1: list[float] | np.ndarray | None = None, # TODO: or float ?
param2: list[float] | np.ndarray | None = None, # TODO: or float ?
): # uses `flavor ( importances )`.
if not isinstance(param1, (np.ndarray)):
param1 = np.array(param1)
if not isinstance(param2, (np.ndarray)):
param2 = np.array(param2)
importance = self.parameterization.getCanonicalImportanceVectorized(
param1, param2
)
if self.flavor is None:
return np.zeros_like(importance.shape[:-1])
return self.flavor(importance=importance)
def __call__(self, param1: np.ndarray, param2: np.ndarray) -> np.ndarray:
return self._compute_mat_value(param1, param2)
# @abstractmethod
# def getColormap(self) -> np.ndarray: ... # TODO
[docs]
def genAnnotations(self) -> Iterator[AbstractAnnotation]: # Generator
for annotation in self._annotations:
yield annotation
[docs]
def appendAnnotation(self, annotation):
assert isinstance(annotation, AbstractAnnotation)
self._annotations.append(annotation)
[docs]
def removeAnnotation(self, annotation):
assert isinstance(annotation, AbstractAnnotation)
self._annotations.remove(annotation)
[docs]
def popAnnotation(self, index: SupportsIndex = -1) -> AbstractAnnotation:
"""
Remove and return an annotation at index (default last).
Args:
index (SupportsIndex, optional): index of the annotation to pop. Defaults to -1.
Returns:
AbstractAnnotation: the annotation at the specified index.
Raises:
IndexError: if the index is out of range.
"""
return self._annotations.pop(index)
[docs]
def clearAnnotations(self):
self._annotations.clear()
[docs]
def draw(
self, fig: Figure | None = None, ax: Axes | None = None
) -> tuple[Figure, Axes]:
"""Draws the Tile in the given figure and axes.
Args:
fig (Figure | None, optional): The figure to draw in. If None, a new
figure is created. Defaults to None.
ax (Axes | None, optional): The axes to draw in. If None, a new axis is
created. Defaults to None. Note that this argument is ignored if fig
is None.
Returns:
The figure and axes used for drawing.
"""
if fig is None:
fig = plt.figure()
ax = fig.gca()
elif ax is None:
ax = fig.gca()
# Draw all annotations
for annotation in self.genAnnotations():
assert isinstance(annotation, AbstractAnnotation)
tile = self
try:
annotation.draw(tile, fig, ax)
except Exception as e:
message = (
"Something went wrong while drawing annotation {!r}, got {}".format(
annotation.name, e
)
)
logging.warning(message)
# Configure the limits, axes labels, and title.
x_min, x_max, y_min, y_max = self._zoom
assert x_min < x_max
assert y_min < y_max
ax.set_xlim((x_min, x_max))
ax.set_ylim((y_min, y_max))
ax.set_aspect("equal")
parameterization = self.parameterization
ax.set_xlabel(parameterization.getNameParameter1())
ax.set_ylabel(parameterization.getNameParameter2())
ax.set_title(self.name)
if not self.disable_colorbar:
# Create a subdivision of the axis to add a colorbar of same height
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad="5%")
fig.colorbar(ax.images[0], cax)
return fig, ax
def __str__(self) -> str:
buffer = io.StringIO()
buffer.write('This Tile is named "{}".'.format(self.name))
buffer.write("\nIt uses the flavor: {}.".format(self.flavor))
buffer.write(
"\nIt uses the parameterization: {}.".format(self.parameterization)
)
if len(self._annotations) > 0:
buffer.write("\nIt shows the following annotations:\n")
for annotation in self._annotations:
buffer.write("- {}\n".format(annotation.name))
return buffer.getvalue()