# Source code for bnl.core

"""Core data structures for monotonic boundary casting."""

from __future__ import annotations

from collections import Counter
from numbers import Number

__all__ = [
    "Boundary",
    "RatedBoundary",
    "LeveledBoundary",
    "TimeSpan",
    "Segment",
    "MultiSegment",
    "BoundaryContour",
    "BoundaryHierarchy",
]

from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field, replace
from functools import cached_property
from typing import Any

import jams
import numpy as np
import plotly.graph_objects as go

# region: Boundary Objects


@dataclass(frozen=True, order=True)
class Boundary:
    """A point that divides the flow of time, expressed in seconds.

    The stored ``time`` is rounded to 5 decimal places (1e-5 seconds) when
    the instance is created.
    """

    time: float

    def __post_init__(self) -> None:
        # The dataclass is frozen, so the rounded value has to be written
        # through object.__setattr__ instead of a plain assignment.
        object.__setattr__(self, "time", round(self.time, 5))

    def __repr__(self) -> str:
        return f"B({self.time:.1f})"
@dataclass(frozen=True, order=True)
class RatedBoundary(Boundary):
    """A :class:`Boundary` that also carries a continuous salience score.

    ``salience`` is a float measuring how important the boundary is.
    """

    salience: float

    def __repr__(self) -> str:
        return f"RB({self.time:.1f}, {self.salience:.2f})"
@dataclass(frozen=True, order=True, init=False)
class LeveledBoundary(RatedBoundary):
    """A boundary placed at a discrete level of a monotonic hierarchy.

    The caller supplies only ``time`` and ``level``; ``salience`` is always
    stored as ``float(level)``.
    """

    #: The discrete hierarchical level of the boundary.
    level: int

    def __init__(self, time: float, level: int):
        """Create a boundary at ``time`` with hierarchical ``level``.

        Args:
            time (float): The time of the boundary in seconds.
            level (int): The discrete hierarchical level of the boundary.
                Must be a positive integer.

        Raises:
            ValueError: If `level` is not an ``int`` or is not positive.
        """
        is_positive_int = isinstance(level, int) and level > 0
        if not is_positive_int:
            raise ValueError(
                f"`level` must be a positive integer, got {level}, with type {type(level)}."
            )

        # The instance is frozen, so every field is written via object.__setattr__.
        for attr, value in (("time", time), ("level", level), ("salience", float(level))):
            object.__setattr__(self, attr, value)

        # With a hand-written __init__ the dataclass machinery does not call
        # __post_init__, so invoke Boundary's explicitly to round `time`.
        super().__post_init__()

    def __repr__(self) -> str:
        return f"LB({self.time:.1f}, {self.level})"
# endregion

# region: TimeSpan and Segment
[docs] @dataclass(frozen=True) class TimeSpan: """An abstract base class for objects that represent a span of time. This class defines the interface for all time-spanned objects, ensuring they have `start`, `end`, `duration`, and `name` properties. """ start: Boundary end: Boundary name: str | None = field(default=None, kw_only=True) def __post_init__(self): self._validate_timespan() if self.name is None: object.__setattr__(self, "name", self._interval_str()) @property def duration(self) -> float: """Duration in seconds.""" return self.end.time - self.start.time def _validate_timespan(self): """Helper method to validate that the timespan is valid.""" if self.end.time <= self.start.time: raise ValueError("TimeSpan must have a non-zero, positive duration.") def _interval_str(self) -> str: """Helper method to generate a default name from the boundaries.""" return f"[{self.start.time:.2f}-{self.end.time:.2f}]" def __repr__(self) -> str: return f"TS({self._interval_str()}, {self.name})" def __str__(self) -> str: return self.name
@dataclass(frozen=True)
class Segment(TimeSpan):
    """An ordered sequence of boundaries that partition a span into labeled sections.

    Represents one layer of annotation. While it inherits from `TimeSpan`, its
    `start` and `end` attributes are not passed in: they are derived from the
    first and last entries of `bs`.

    Attributes:
        bs: The boundaries, sorted by time; at least two are required.
        labels: One label per section (``len(bs) - 1`` entries). When empty,
            every section gets the label ``None``.
    """

    bs: Sequence[Boundary]
    labels: Sequence[str] = field(default_factory=list)
    start: Boundary = field(init=False)
    end: Boundary = field(init=False)

    def __post_init__(self) -> None:
        """Validate the boundaries/labels and derive `start` and `end`.

        Raises:
            ValueError: If fewer than two boundaries are given, if the number
                of labels is not ``len(bs) - 1``, or if the boundaries are not
                sorted by time.
        """
        if not self.bs or len(self.bs) < 2:
            raise ValueError("A Segment requires at least two boundaries.")
        if len(self.labels) == 0:
            object.__setattr__(self, "labels", [None] * (len(self.bs) - 1))
        if len(self.labels) != len(self.bs) - 1:
            raise ValueError(
                f"Number of labels ({len(self.labels)}) must be one less than "
                f"the number of boundaries ({len(self.bs)})"
            )
        # Compare the raw times rather than the boundary objects: dataclass
        # ordering is only defined between instances of the exact same class,
        # so a mix of Boundary and its subclasses (e.g. plain span ends spliced
        # around rated boundaries by `align`) would raise TypeError.
        if any(self.bs[i].time > self.bs[i + 1].time for i in range(len(self.bs) - 1)):
            raise ValueError(f"Boundaries must be sorted. {self.bs}")

        # Use object.__setattr__ to assign to the init=False fields.
        object.__setattr__(self, "start", self.bs[0])
        object.__setattr__(self, "end", self.bs[-1])
        super().__post_init__()

    @cached_property
    def sections(self) -> Sequence[TimeSpan]:
        """A list of all the labeled time spans that compose the segment."""
        return [
            TimeSpan(name=label, start=Boundary(itvl[0]), end=Boundary(itvl[1]))
            for itvl, label in zip(self.itvls, self.labels)
        ]

    @cached_property
    def itvls(self) -> np.ndarray:
        """The ``[start, end]`` times of consecutive boundary pairs, one row per section."""
        itvls = [[b.time, e.time] for b, e in zip(self.bs[:-1], self.bs[1:])]
        return np.array(itvls)

    @property
    def lam(self) -> np.ndarray:
        """Label Agreement Matrix.

        Returns:
            np.ndarray: The outer equality comparison of `labels` with itself;
                entry ``(i, j)`` tells whether sections ``i`` and ``j`` share
                a label.
        """
        return np.equal.outer(self.labels, self.labels)

    def __len__(self) -> int:
        return len(self.sections)

    def __getitem__(self, key: int) -> TimeSpan:
        return self.sections[key]

    def __iter__(self) -> Iterator[TimeSpan]:
        return iter(self.sections)

    def __repr__(self) -> str:
        return f"S({self._interval_str()}, {self.name})"

    def __str__(self) -> str:
        return self.name

    @classmethod
    def from_jams(cls, segment_annotation: jams.Annotation, name: str | None = None) -> Segment:
        """Data ingestion from the jams format.

        The annotation's ``to_interval_values()`` result (intervals, labels)
        is forwarded to :meth:`from_itvls`.
        """
        itvls, labels = segment_annotation.to_interval_values()
        return cls.from_itvls(itvls, labels, name=name)

    @classmethod
    def from_itvls(
        cls,
        itvls: Sequence[Sequence[float]],
        labels: Sequence[str],
        name: str | None = None,
    ) -> Segment:
        """Data ingestion from the `mir_eval` format of intervals and labels.

        The intervals are assumed to be contiguous (no overlaps or gaps): only
        the start of each interval and the end of the last one are used.
        """
        bs = [Boundary(itvl[0]) for itvl in itvls]
        # tag on the end time of the last interval
        bs.append(Boundary(itvls[-1][1]))
        return cls(bs=bs, labels=labels, name=name)

    @classmethod
    def from_bs(
        cls,
        bs: Sequence[Boundary | Number],
        labels: Sequence[str] | None = None,
        name: str | None = None,
    ) -> Segment:
        """Create a Segment from a sequence of boundaries and labels.

        Plain numbers in `bs` are wrapped in `Boundary`; other entries are
        used as given. ``labels=None`` is treated as "no labels".
        """
        bs = [Boundary(b) if isinstance(b, Number) else b for b in bs]
        if labels is None:
            labels = []
        return cls(bs=bs, labels=labels, name=name)

    def plot(
        self,
        colorscale: str | list[str] = "D3",
        hatch: bool = True,
    ) -> go.Figure:
        """Plot the segment on a plotly figure by wrapping it in a MultiSegment.

        The y axis is hidden on the returned figure.
        """
        ms = MultiSegment(raw_layers=[self], name=str(self))
        fig = ms.plot(colorscale=colorscale, hatch=hatch)
        fig.update_layout(yaxis_visible=False)
        return fig

    def scrub_labels(self, replace_with: str | None = "") -> Segment:
        """Return a copy whose labels are all `replace_with` (empty string by default)."""
        return replace(self, labels=[replace_with] * len(self.labels))

    def align(self, span: TimeSpan) -> Segment:
        """Return a copy whose first and last boundaries are moved to `span`.

        Inner boundaries are kept as they are.

        Raises:
            ValueError: If `span` does not strictly contain the inner boundaries.
        """
        if len(self.bs) == 2:
            return replace(self, bs=[span.start, span.end])

        inner_bs = self.bs[1:-1]
        if span.start.time >= inner_bs[0].time or span.end.time <= inner_bs[-1].time:
            raise ValueError(f"New span {span} does not contain the inner boundaries.")
        new_bs = [span.start] + list(inner_bs) + [span.end]
        return replace(self, bs=new_bs)
# endregion: Segment

# region: MultiSegment
@dataclass(frozen=True)
class MultiSegment(TimeSpan):
    """The primary input object for analysis, containing multiple Segment layers.

    `start` and `end` are not passed in: they are the union span of all
    `raw_layers`.
    """

    raw_layers: Sequence[Segment] = field(default_factory=list)
    """A sequence of `Segment` objects representing different layers of annotation."""

    start: Boundary = field(init=False)
    end: Boundary = field(init=False)

    def __post_init__(self):
        """Validate the layers and derive the unified `start` and `end`.

        Raises:
            ValueError: If `raw_layers` is empty.
        """
        if not self.raw_layers:
            raise ValueError("MultiSegment must contain at least one Segment layer.")

        # Calculate the unified span and set the start/end boundaries.
        unified_span = self.find_span(self.raw_layers, mode="union")
        object.__setattr__(self, "start", unified_span.start)
        object.__setattr__(self, "end", unified_span.end)
        super().__post_init__()

    @cached_property
    def layers(self) -> Sequence[Segment]:
        """The raw layers aligned to the unified time span, with distinct names.

        A layer whose name was already used by an earlier layer gets the
        suffix ``_<k>``, where ``k`` counts the earlier layers with that name.
        """
        # put all raw_layers on a unified time span
        unified_span = TimeSpan(self.start, self.end)
        aligned_layers = [layer.align(unified_span) for layer in self.raw_layers]

        seen_names_count = Counter()
        processed_layers = []
        for layer in aligned_layers:
            # Count occurrences under the ORIGINAL name. Counting under the
            # renamed one would leave the original's count stuck at 1, so a
            # third duplicate would receive the same suffix as the second.
            original_name = layer.name
            count = seen_names_count[original_name]
            if count:
                layer = replace(layer, name=f"{original_name}_{count}")
            seen_names_count[original_name] += 1
            processed_layers.append(layer)
        return processed_layers

    def __len__(self) -> int:
        return len(self.raw_layers)

    def __getitem__(self, key: int) -> Segment:
        return self.layers[key]

    def __iter__(self) -> Iterator[Segment]:
        return iter(self.layers)

    @property
    def itvls(self) -> Sequence[np.ndarray]:
        """Returns a list of all the intervals for each layer in the MultiSegment."""
        return [layer.itvls for layer in self]

    @property
    def labels(self) -> Sequence[Sequence[str]]:
        """Returns a list of all the labels for each layer in the MultiSegment."""
        return [layer.labels for layer in self]

    @classmethod
    def from_json(cls, json_data: list, name: str | None = None) -> MultiSegment:
        """Data ingestion from the adobe json format.

        Args:
            json_data (list): A list of layers, where each layer is a tuple of
                (intervals, labels). `intervals` is a list of [start, end]
                times, and `labels` is a list of strings.
            name (str, optional): Name for the created MultiSegment.

        Returns:
            MultiSegment: One layer per entry, named ``L01``, ``L02``, ...
        """
        layers = []
        for i, layer in enumerate(json_data, start=1):
            itvls, labels = layer
            layers.append(Segment.from_itvls(itvls, labels, name=f"L{i:02d}"))
        return cls(raw_layers=layers, name=name)

    @classmethod
    def from_itvls(
        cls, itvls: Sequence[Sequence[float]], labels: Sequence[str], name: str | None = None
    ) -> MultiSegment:
        """Build a MultiSegment from per-layer intervals and per-layer labels.

        ``itvls[i]`` and ``labels[i]`` are passed to `Segment.from_itvls` to
        create layer ``i``, which is named ``L01``, ``L02``, ...
        """
        layers = [
            Segment.from_itvls(layer_itvls, labels[i], name=f"L{i + 1:02d}")
            for i, layer_itvls in enumerate(itvls)
        ]
        return cls(raw_layers=layers, name=name)

    def plot(self, colorscale: str | list[str] = "D3", hatch: bool = True) -> go.Figure:
        """Plots the MultiSegment on a Plotly figure.

        Args:
            colorscale (str | list[str], optional): Plotly colorscale to use.
                Can be a qualitative scale name (e.g., "Set3", "Pastel") or a
                list of colors.
            hatch (bool, optional): Whether to use hatch patterns for
                different labels. Defaults to True.
        """
        from . import viz

        return viz.plot_multisegment(ms=self, colorscale=colorscale, hatch=hatch)

    def contour(self, strategy: str = "depth", **kwargs: Any) -> BoundaryContour:
        """Calculates boundary salience and converts to a BoundaryContour.

        Args:
            strategy (str): Key of the salience strategy in
                `ops.SalienceStrategy._registry`.
            **kwargs: Passed to the strategy class constructor.

        Raises:
            ValueError: If `strategy` is not a registered strategy.
        """
        from . import ops

        if strategy not in ops.SalienceStrategy._registry:
            raise ValueError(f"Unknown salience strategy: {strategy}")

        strategy_class = ops.SalienceStrategy._registry[strategy]
        contour_strategy = strategy_class(**kwargs)
        return contour_strategy(self)

    def scrub_labels(self) -> MultiSegment:
        """Scrubs the labels of the MultiSegment by replacing them with empty strings."""
        return MultiSegment(raw_layers=[layer.scrub_labels() for layer in self], name=self.name)

    @staticmethod
    def find_span(
        layers: Sequence[Segment],
        mode: str = "common",
    ) -> TimeSpan:
        """Finds the span of a list of Segment layers.

        Args:
            layers: The layers to inspect.
            mode (str, optional): The alignment mode. "union" spans from the
                earliest start to the latest end; "common" spans from the
                latest start to the earliest end. Defaults to "common".

        Raises:
            ValueError: If `mode` is neither "union" nor "common".
        """
        if mode == "union":
            inc_start_time = min(layer.start.time for layer in layers)
            inc_end_time = max(layer.end.time for layer in layers)
        elif mode == "common":
            inc_start_time = max(layer.start.time for layer in layers)
            inc_end_time = min(layer.end.time for layer in layers)
        else:
            raise ValueError(f"Unknown alignment mode: {mode}. Must be 'union' or 'common'.")
        return TimeSpan(
            start=Boundary(inc_start_time),
            end=Boundary(inc_end_time),
            name=f"{mode} span",
        )

    def align(self, span: TimeSpan) -> MultiSegment:
        """Align every layer with a TimeSpan object."""
        return MultiSegment(raw_layers=[layer.align(span) for layer in self], name=self.name)

    def prune_layers(self, relabel: bool = True) -> MultiSegment:
        """Prunes identical layers from the MultiSegment.

        A layer is dropped when it has the same boundaries and the same label
        agreement matrix as the previously kept layer. This also gets rid of
        layers with no inner boundaries.

        Args:
            relabel (bool): When True, the kept layers are renamed ``L01``,
                ``L02``, ... in order.
        """
        pruned_layers = []
        for layer in self:
            if len(layer) <= 1:
                continue  # skip layers with no inner boundaries

            # The first valid layer is always added.
            if not pruned_layers:
                pruned_layers.append(layer)
                continue

            # Subsequent layers are added only if they differ from the previous one.
            same_boundaries = np.array_equal(layer.bs, pruned_layers[-1].bs)
            same_labeling = np.array_equal(layer.lam, pruned_layers[-1].lam)
            if not (same_boundaries and same_labeling):
                pruned_layers.append(layer)

        final_layers = pruned_layers
        if relabel:
            final_layers = [
                replace(layer, name=f"L{i:02d}") for i, layer in enumerate(pruned_layers, start=1)
            ]

        return replace(self, raw_layers=final_layers)

    def squeeze_layers(self, times: int = 1, relabel: bool = True) -> MultiSegment:
        """Remove the least informative layer from the MultiSegment according to vmeasure.

        Args:
            times (int): How many layers to remove, one at a time.
            relabel (bool): When True, the remaining layers are renamed
                ``L01``, ``L02``, ... after each removal.

        Returns:
            MultiSegment: A new MultiSegment with the most redundant layer(s)
                removed, or `self` unchanged when ``times <= 0`` or only one
                layer is left.
        """
        from mir_eval.segment import vmeasure

        if times <= 0 or len(self) <= 1:
            return self
        elif times == 1:
            # base case:
            # get rid of the level that adds the least information
            # look at vmeasure between all consecutive levels,
            # get the one with the highest vmeasure with the next level
            v_f1 = [
                vmeasure(lv1.itvls, lv1.labels, lv2.itvls, lv2.labels)[2]
                for lv1, lv2 in zip(self, self[1:])
            ]
            idx_to_pop = np.argmax(v_f1)
            new_layers = [layer for i, layer in enumerate(self) if i != idx_to_pop]
            if relabel:
                new_layers = [
                    replace(layer, name=f"L{i:02d}") for i, layer in enumerate(new_layers, start=1)
                ]
            return replace(self, raw_layers=new_layers)
        else:
            # Recurse
            return self.squeeze_layers(times - 1, relabel=relabel).squeeze_layers(
                times=1, relabel=relabel
            )
# endregion: MultiSegment

# region: Monotonic Boundary
@dataclass(frozen=True)
class BoundaryContour(TimeSpan):
    """An intermediate, purely structural representation of boundary salience over time.

    `start` and `end` are derived from the first and last entries of `bs`.
    Length, indexing and iteration cover only the inner boundaries, i.e.
    everything except that first and last entry.
    """

    bs: Sequence[RatedBoundary]
    start: Boundary = field(init=False)
    end: Boundary = field(init=False)

    def __post_init__(self):
        """Validate the boundaries and derive `start` and `end`.

        Raises:
            ValueError: If fewer than two boundaries are given or they are
                not sorted by time.
        """
        if not self.bs or len(self.bs) < 2:
            raise ValueError("A BoundaryContour requires at least two boundaries.")
        # Sortedness is judged on time alone. Comparing the boundary objects
        # would also compare salience on equal times, and would raise
        # TypeError for a mix of RatedBoundary and LeveledBoundary, because
        # dataclass ordering requires both operands to be the same class.
        if any(self.bs[i].time > self.bs[i + 1].time for i in range(len(self.bs) - 1)):
            raise ValueError(f"Boundaries must be sorted. {self.bs}")

        # Use object.__setattr__ to assign to the init=False fields.
        object.__setattr__(self, "start", self.bs[0])
        object.__setattr__(self, "end", self.bs[-1])
        super().__post_init__()

    def __len__(self) -> int:
        # Number of inner boundaries: the first and last ones are excluded.
        return len(self.bs) - 2

    def __getitem__(self, key: int) -> RatedBoundary:
        return self.bs[1:-1][key]

    def __iter__(self) -> Iterator[RatedBoundary]:
        return iter(self.bs[1:-1])

    def plot(self, **kwargs: Any) -> go.Figure:
        """Plots the BoundaryContour on a Plotly figure.

        Args:
            **kwargs: Additional keyword arguments to pass to
                `viz.plot_boundary_contour`.

        Returns:
            The figure returned by `viz.plot_boundary_contour`.
        """
        from . import viz

        return viz.plot_boundary_contour(self, **kwargs)

    def clean(self, strategy: str = "absorb", **kwargs: Any) -> BoundaryContour:
        """Cleans up the boundary contour using a specified strategy.

        Args:
            strategy (str): Key of the cleaning strategy in
                `ops.CleanStrategy._registry`. Defaults to 'absorb'.
            **kwargs: Additional keyword arguments to pass to the strategy
                constructor (e.g., `window`).

        Returns:
            BoundaryContour: A new, cleaned BoundaryContour.

        Raises:
            ValueError: If `strategy` is not a registered strategy.
        """
        from . import ops

        if strategy not in ops.CleanStrategy._registry:
            raise ValueError(f"Unknown boundary cleaning strategy: {strategy}")

        # Retrieve the class from the registry and instantiate it with the provided arguments.
        strategy_class = ops.CleanStrategy._registry[strategy]
        clean_strategy = strategy_class(**kwargs)
        return clean_strategy(self)

    def level(self, strategy: str = "unique", **kwargs: Any) -> BoundaryHierarchy:
        """Converts the BoundaryContour to a BoundaryHierarchy by quantizing salience.

        Args:
            strategy (str): Key of the leveling strategy in
                `ops.LevelStrategy._registry`. Defaults to 'unique'.
            **kwargs: Additional keyword arguments to pass to the strategy
                constructor.

        Raises:
            ValueError: If `strategy` is not a registered strategy.
        """
        from . import ops

        if strategy not in ops.LevelStrategy._registry:
            raise ValueError(f"Unknown boundary level strategy: {strategy}")

        strategy_class = ops.LevelStrategy._registry[strategy]
        level_strategy = strategy_class(**kwargs)
        return level_strategy(self)
@dataclass(frozen=True)
class BoundaryHierarchy(BoundaryContour):
    """The structural output of the monotonic casting process.

    A `BoundaryContour` whose boundaries are all `LeveledBoundary` instances.
    """

    bs: Sequence[LeveledBoundary]
    start: Boundary = field(init=False)
    end: Boundary = field(init=False)

    def __post_init__(self):
        """Check the boundary types, then run the BoundaryContour validation.

        Raises:
            TypeError: If any boundary is not a `LeveledBoundary`.
            ValueError: If fewer than two boundaries are given or they are
                not sorted (raised by `BoundaryContour.__post_init__`).
        """
        for boundary in self.bs:
            if not isinstance(boundary, LeveledBoundary):
                raise TypeError("All boundaries must be LeveledBoundary instances")
        # `start` and `end` are assigned by BoundaryContour.__post_init__ after
        # it has checked the length. Indexing `bs` here first would turn an
        # empty input into an IndexError instead of the intended ValueError.
        super().__post_init__()

    def to_ms(self) -> MultiSegment:
        """Convert the BoundaryHierarchy to a MultiSegment.

        The MultiSegment will have layers from coarsest (highest level) to
        finest (lowest level), with empty strings for all labels. The layer
        for a given level holds every boundary whose level is at least that
        level, and layers are named ``L01``, ``L02``, ... starting from the
        coarsest.

        Returns:
            MultiSegment: The resulting MultiSegment object.
        """
        layers = []
        max_level = max(b.level for b in self.bs)
        for level in range(max_level, 0, -1):
            level_boundaries = [Boundary(b.time) for b in self.bs if b.level >= level]
            labels = [""] * (len(level_boundaries) - 1)
            layers.append(
                Segment(
                    bs=level_boundaries,
                    labels=labels,
                    name=f"L{max_level - level + 1:02d}",
                )
            )
        return MultiSegment(raw_layers=layers, name=f"{self.name} Monotonic MS")
# endregion