"""Core data loading classes for manifest-based datasets."""
from __future__ import annotations
# Public API: the names exported by a star-import of this module.
__all__ = [
"Track",
"Dataset",
]
import io
import json
import os
from collections.abc import Iterator
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from typing import Any, Literal, cast
import jams
import pandas as pd
import requests
from .core import MultiSegment, Segment
[docs]
@dataclass
class Track:
    """A single track and its associated data assets.

    Attributes:
        track_id: Identifier of the track.
        manifest_row: The manifest row for this track. Labels of the form
            ``has_<type>_<subtype>`` flag which assets are available.
        dataset: The owning dataset; used to turn asset flags into paths/URLs.
    """

    track_id: str
    manifest_row: pd.Series
    dataset: Dataset

    def __repr__(self) -> str:
        # Count only labels that *start* with "has_", the same rule `info` uses,
        # so the number shown here agrees with the assets `info` considers.
        num_assets = sum(
            bool(flag)
            for label, flag in self.manifest_row.items()
            if str(label).startswith("has_")
        )
        return (
            f"Track(track_id='{self.track_id}', num_assets={num_assets}, "
            f"source='{self.dataset.data_location}')"
        )

    @cached_property
    def info(self) -> dict[str, Any]:
        """Essential track information (cached).

        Returns:
            A dict holding ``"track_id"`` plus one ``"<type>_<subtype>_path"``
            entry for every truthy ``has_<type>_<subtype>`` manifest flag. The
            value is whatever ``dataset._reconstruct_path`` returns for it.
        """
        prefix = "has_"
        info: dict[str, Any] = {"track_id": self.track_id}
        for col_name, has_asset in self.manifest_row.items():
            label = str(col_name)  # labels are not guaranteed to be strings
            if not (label.startswith(prefix) and has_asset):
                continue
            # Strip only the leading prefix; a later "has_" inside the label is
            # part of the asset name and must be preserved.
            asset_type, _, asset_subtype = label[len(prefix):].partition("_")
            if asset_type and asset_subtype:
                info[f"{asset_type}_{asset_subtype}_path"] = self.dataset._reconstruct_path(
                    self.track_id, asset_type, asset_subtype
                )
        return info

    @cached_property
    def refs(self) -> dict[str, MultiSegment]:
        """Returns available reference annotations, keyed by annotator name.

        Annotator names are collected from the ``segment_salami_function``
        annotations of the reference JAMS. The dict is empty when the track has
        no reference JAMS.
        """
        if self.jam is None:
            return {}
        names = [
            ann.annotation_metadata.annotator.name
            for ann in self.jam.search(namespace="segment_salami_function")
        ]
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # result is stable across runs (set ordering of strings is not).
        return {a_id: self.load_annotation("reference", a_id) for a_id in dict.fromkeys(names)}

    @cached_property
    def ests(self) -> dict[str, MultiSegment]:
        """Returns available estimated (Adobe) annotations, keyed by their id."""
        prefix, suffix = "annotation_adobe-", "_path"
        # Slice the fixed prefix/suffix off instead of str.replace, which would
        # also rewrite matches occurring in the middle of the key.
        est_ids = [
            key[len(prefix):-len(suffix)]
            for key in self.info
            if key.startswith(prefix) and key.endswith(suffix)
        ]
        return {est_id: self.load_annotation(f"adobe-{est_id}") for est_id in est_ids}

    @cached_property
    def jam(self) -> jams.JAMS | None:
        """Returns the reference JAMS object for this track, or None if absent."""
        jam_path = self.info.get("annotation_reference_path")
        if jam_path is None:
            return None
        return jams.load(self._fetch_content(jam_path))

    def load_annotation(self, annotation_type: str, annotator: str | None = None) -> MultiSegment:
        """Loads a specific annotation as a MultiSegment.

        Parameters:
            annotation_type (str): One of:
                - 'reference' to load the reference JAMS. Optionally pass
                  `annotator` to select a specific annotator by name.
                - 'adobe-<id>' to load an Adobe JSON (e.g., 'adobe-mu1gamma1').
            annotator (str | None): Name of the annotator in the JAMS file to select.
                Ignored for JSON annotations.

        Raises:
            ValueError: If the requested annotation is unavailable for this track.
            NotImplementedError: If the file type is unsupported (.jams and .json are supported).
        """
        annotation_key = f"annotation_{annotation_type}_path"
        if annotation_key not in self.info:
            raise ValueError(f"Annotation type '{annotation_type}' not available for this track.")

        annotation_path = self.info[annotation_key]
        # Dispatch on the (case-insensitive) file extension.
        lowered = str(annotation_path).lower()
        if lowered.endswith(".jams"):
            return self._load_jams_anno(annotation_path, name=annotator)
        if lowered.endswith(".json"):
            return self._load_json(annotation_path, name=annotation_type)
        raise NotImplementedError(f"Unsupported file type: {annotation_path}")

    def _load_jams_anno(self, path: str | Path, name: str | None = None) -> MultiSegment:
        """
        Find the annotator with name `name` in the JAMS file, and load it as a `MultiSegment`.
        If `name` is None, find the first annotator in the JAMS file.

        Each MultiSegment contains two layers:
        - coarse (`segment_salami_function`)
        - fine (`segment_salami_lower`)

        Raises:
            ValueError: If either namespace has no annotation matching `name`.
        """
        jam = jams.load(self._fetch_content(path))
        # NOTE(review): `name` is handed to jams' search as-is; an empty string is
        # used for "any annotator". Presumably jams treats string queries as
        # patterns -- confirm how names with special characters behave.
        search_name = name if name is not None else ""
        uppers = jam.search(namespace="segment_salami_function").search(name=search_name)
        lowers = jam.search(namespace="segment_salami_lower").search(name=search_name)
        if len(uppers) == 0 or len(lowers) == 0:
            raise ValueError(f"No annotator found for {name}")
        return MultiSegment(
            raw_layers=[
                Segment.from_jams(uppers[0], name="coarse"),
                Segment.from_jams(lowers[0], name="fine"),
            ],
            name=f"annotator-{uppers[0].annotation_metadata.annotator.name}",
        )

    def _load_json(self, path: str | Path, name: str | None = None) -> MultiSegment:
        """Loads a JSON annotation as a MultiSegment (named "JSON Annotation" by default)."""
        json_data = json.load(self._fetch_content(path))
        ms_name = "JSON Annotation" if name is None else name
        return MultiSegment.from_json(json_data, name=ms_name)

    @staticmethod
    def _fetch_content(path: str | Path) -> io.StringIO:
        """Fetches file content into a memory buffer. Works for local files and URLs.

        Strings starting with "http" are downloaded (timeout in seconds taken
        from the ``BNL_HTTP_TIMEOUT`` environment variable, default 10);
        anything else is read from disk as UTF-8 text.

        Raises:
            FileNotFoundError: If a local path does not exist.
            requests.HTTPError: If the server answers with an error status.
        """
        if isinstance(path, str) and path.startswith("http"):
            response = requests.get(
                str(path),
                timeout=float(os.getenv("BNL_HTTP_TIMEOUT", "10")),
                headers={"User-Agent": "bnl"},
            )
            response.raise_for_status()
            return io.StringIO(response.text)
        if Path(path).exists():
            with open(path, encoding="utf-8") as f:
                return io.StringIO(f.read())
        raise FileNotFoundError(f"File not found: {path}")
[docs]
class Dataset:
    """A manifest-based dataset.

    The manifest is a CSV with a ``track_id`` column, read either from a local
    path or from a URL (any string starting with "http"). Only rows for which
    at least one column whose name contains ``has_annotation_reference`` is
    truthy are kept.

    Attributes:
        manifest_path: The path or URL the manifest was loaded from.
        data_location: "cloud" when the manifest came from a URL, else "local".
        dataset_root: Parent directory of a local manifest, or the base URL.
        base_url: URL prefix of a cloud manifest; None for local datasets.
        manifest: The filtered manifest, indexed by ``track_id`` (as strings).
        track_ids: Unique track IDs, sorted numerically when every ID parses
            as an int and lexicographically otherwise.
    """

    track_ids: list[str]
    manifest: pd.DataFrame
    # Plain annotation only: this class is not a dataclass, so the value is
    # assigned per instance in __init__ rather than via dataclasses.field().
    data_location: Literal["local", "cloud"]

    # Allow overriding the public bucket via environment for easy configuration in
    # local development without code changes.
    R2_BUCKET_PUBLIC_URL: str = os.getenv(
        "BNL_R2_BUCKET_PUBLIC_URL",
        "https://pub-05e404c031184ec4bbf69b0c2321b98e.r2.dev",
    )

    def __init__(self, manifest_path: Path | str | None = None):
        """Load and filter the manifest.

        Parameters:
            manifest_path: Local path or URL of the manifest CSV. Defaults to
                ``manifest_cloud_boolean.csv`` in the public bucket.

        Raises:
            FileNotFoundError: If a local manifest does not exist.
            requests.HTTPError: If a cloud manifest request returns an error status.
        """
        if manifest_path is None:
            manifest_path = f"{self.R2_BUCKET_PUBLIC_URL}/manifest_cloud_boolean.csv"
        self.manifest_path = manifest_path
        self.data_location = (
            "cloud"
            if isinstance(manifest_path, str) and manifest_path.startswith("http")
            else "local"
        )

        if self.data_location == "local":
            local_manifest = Path(manifest_path).expanduser()
            self.dataset_root: Path | str = local_manifest.parent
            self.base_url: str | None = None
            try:
                self.manifest = pd.read_csv(local_manifest)
            except FileNotFoundError as e:
                raise FileNotFoundError(f"Manifest not found: {manifest_path}") from e
        else:
            self.base_url = str(manifest_path).rsplit("/", 1)[0]
            self.dataset_root = self.base_url
            # Minimal robustness: use a short timeout and a simple User-Agent.
            response = requests.get(
                str(manifest_path),
                timeout=float(os.getenv("BNL_HTTP_TIMEOUT", "10")),
                headers={"User-Agent": "bnl"},
            )
            response.raise_for_status()
            self.manifest = pd.read_csv(io.StringIO(response.text))

        # IDs are handled as strings everywhere; keep the column as well as the index.
        self.manifest["track_id"] = self.manifest["track_id"].astype(str)
        self.manifest = self.manifest.set_index("track_id", drop=False)

        # Only include tracks that have the reference annotation
        has_reference = (
            self.manifest.filter(like="has_annotation_reference").astype(bool).values.any(axis=1)
        )
        self.manifest = self.manifest[has_reference]

        unique_ids = self.manifest["track_id"].unique()
        try:
            self.track_ids = sorted(unique_ids, key=int)
        except ValueError:
            # At least one ID is not an integer: fall back to string ordering.
            self.track_ids = sorted(unique_ids)

    def __getitem__(self, track_id: str) -> Track:
        """Load a specific track by its ID.

        Raises:
            ValueError: If the ID is not present in the (filtered) manifest.
        """
        track_id = str(track_id)
        # The index holds exactly the IDs in `track_ids`; a hashed index lookup
        # keeps iteration over the dataset linear instead of quadratic.
        if track_id not in self.manifest.index:
            raise ValueError(f"Track ID '{track_id}' not found in manifest.")
        return Track(track_id, self.manifest.loc[track_id], self)

    def __len__(self) -> int:
        """Number of tracks that passed the reference-annotation filter."""
        return len(self.track_ids)

    def __iter__(self) -> Iterator[Track]:
        """Yield a Track for every ID in `track_ids`, in order."""
        for track_id in self.track_ids:
            yield self[track_id]

    @staticmethod
    def _format_adobe_params(asset_subtype: str) -> str:
        """Convert adobe asset subtype to formatted parameters.

        The token after the first "-" (e.g. "mu1gamma1" in "adobe-mu1gamma1") is
        mapped to its folder spelling; unrecognised tokens are returned as-is.

        Raises:
            ValueError: If `asset_subtype` contains no "-" separated token.
        """
        parts = asset_subtype.split("-")
        if len(parts) < 2:
            # Same exception type the path builders use for unknown assets.
            raise ValueError(f"Malformed adobe asset subtype: {asset_subtype}")
        mu_gamma = parts[1]
        known = {
            "mu1gamma1": "mu_0.1_gamma_0.1",
            "mu5gamma5": "mu_0.5_gamma_0.5",
            "mu1gamma9": "mu_0.1_gamma_0.9",
        }
        return known.get(mu_gamma, mu_gamma)

    def _reconstruct_local_path(self, track_id: str, asset_type: str, asset_subtype: str) -> Path:
        """Reconstruct local file path for an asset.

        Raises:
            ValueError: If the asset type/subtype combination is not recognised.
        """
        root = cast(Path, self.dataset_root)
        if asset_type == "audio":
            return root / "audio" / track_id / f"audio.{asset_subtype}"
        if asset_type == "annotation":
            if asset_subtype.startswith("ref_") or asset_subtype == "reference":
                return root / "jams" / f"{track_id}.jams"
            if "adobe" in asset_subtype:
                # Adobe annotations have a specific subfolder structure.
                subfolder = f"adobe/def_{self._format_adobe_params(asset_subtype)}"
                return root / subfolder / f"{track_id}.mp3.msdclasscsnmagic.json"
        raise ValueError(f"Unknown local asset: {asset_type}/{asset_subtype}")

    def _reconstruct_cloud_url(self, track_id: str, asset_type: str, asset_subtype: str) -> str:
        """Reconstruct cloud URL for an asset.

        Raises:
            ValueError: If the asset type/subtype combination is not recognised.
        """
        base = cast(str, self.base_url)
        if asset_type == "audio":
            return f"{base}/slm-dataset/{track_id}/audio.{asset_subtype}"
        if asset_type == "annotation":
            if asset_subtype.startswith("ref_") or asset_subtype == "reference":
                return f"{base}/ref-jams/{track_id}.jams"
            if "adobe" in asset_subtype:
                subfolder = f"adobe21-est/def_{self._format_adobe_params(asset_subtype)}"
                return f"{base}/{subfolder}/{track_id}.mp3.msdclasscsnmagic.json"
        raise ValueError(f"Unknown cloud asset: {asset_type}/{asset_subtype}")

    def _reconstruct_path(self, track_id: str, asset_type: str, asset_subtype: str) -> Path | str:
        """
        Reconstruct the full path or URL for an asset based on the dataset's
        data_location.

        Dispatches to _reconstruct_local_path for local assets or
        _reconstruct_cloud_url for cloud assets.

        Raises:
            ValueError: If `data_location` is neither "local" nor "cloud", or the
                asset is not recognised by the selected builder.
        """
        if self.data_location == "local":
            return self._reconstruct_local_path(track_id, asset_type, asset_subtype)
        if self.data_location == "cloud":
            return self._reconstruct_cloud_url(track_id, asset_type, asset_subtype)
        raise ValueError(f"Unknown data location: {self.data_location}")