Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions basic_pitch/data/datasets/guitarset.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def determine_split(index: int) -> str:
return "test"

guitarset = mirdata.initialize("guitarset")
guitarset.download(["index"])
track_ids = guitarset.track_ids
random.shuffle(track_ids)

Expand Down
2 changes: 2 additions & 0 deletions basic_pitch/data/datasets/ikala.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from basic_pitch.data import commandline, pipeline


# NOTE (Oct 2025): the iKala remote download is broken on the mirdata side.  # TODO: re-evaluate later
class IkalaInvalidTracks(beam.DoFn):
def process(self, element: Tuple[str, str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> Any:
track_id, split = element
Expand Down Expand Up @@ -142,6 +143,7 @@ def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[
random.seed(seed)

ikala = mirdata.initialize("ikala")
ikala.download(["index"])
track_ids = ikala.track_ids
random.shuffle(track_ids)

Expand Down
21 changes: 5 additions & 16 deletions basic_pitch/data/datasets/maestro.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import logging
import os
import sys
import tempfile
import time
from typing import Any, Dict, List, TextIO, Tuple

Expand Down Expand Up @@ -164,20 +163,10 @@ def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str
return [batch]


def create_input_data(source: str) -> List[Tuple[str, str]]:
import apache_beam as beam

filesystem = beam.io.filesystems.FileSystems()

with tempfile.TemporaryDirectory() as tmpdir:
maestro = mirdata.initialize("maestro", data_home=tmpdir)
metadata_path = maestro._index["metadata"]["maestro-v2.0.0"][0]
with filesystem.open(
os.path.join(source, metadata_path),
) as s, open(os.path.join(tmpdir, metadata_path), "wb") as d:
d.write(s.read())

return [(track_id, track.split) for track_id, track in maestro.load_tracks().items()]
def create_input_data() -> List[Tuple[str, str]]:
    """Return ``(track_id, split)`` pairs for every MAESTRO track.

    Only the remote metadata file is downloaded (no audio); the split for
    each track is then read from the loaded track metadata.
    """
    dataset = mirdata.initialize("maestro")
    dataset.download(["metadata"])
    pairs: List[Tuple[str, str]] = []
    for track_id, track in dataset.load_tracks().items():
        pairs.append((track_id, track.split))
    return pairs


def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
Expand All @@ -198,7 +187,7 @@ def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
"environment_type": "DOCKER",
"environment_config": known_args.sdk_container_image,
}
input_data = create_input_data(known_args.source)
input_data = create_input_data()
pipeline.run(
pipeline_options,
pipeline_args,
Expand Down
1 change: 1 addition & 0 deletions basic_pitch/data/datasets/medleydb_pitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[
random.seed(seed)

medleydb_pitch = mirdata.initialize("medleydb_pitch")
medleydb_pitch.download(["index"])
track_ids = medleydb_pitch.track_ids
random.shuffle(track_ids)

Expand Down
1 change: 1 addition & 0 deletions basic_pitch/data/datasets/slakh.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def process(self, element: List[str]) -> List[Any]:

def create_input_data() -> List[Tuple[str, str]]:
    """Return ``(track_id, data_split)`` pairs for every Slakh track.

    Downloads only the dataset index (not the audio) before loading the
    track listing.
    """
    dataset = mirdata.initialize("slakh")
    dataset.download(["index"])
    pairs: List[Tuple[str, str]] = []
    for track_id, track in dataset.load_tracks().items():
        pairs.append((track_id, track.data_split))
    return pairs


Expand Down
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ bp-download = "basic_pitch.data.download:main"
data = [
"basic_pitch[tf,test]",
"apache_beam",
# TODO: mirdata 0.3.9 moves dataset indexes files which breaks our tests
# Adapt our codebase to release that constraint
"mirdata<=0.3.8",
"mirdata>=1.0.0",
"smart_open",
"sox",
"ffmpeg-python"
Expand Down
50 changes: 50 additions & 0 deletions tests/data/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import json
import pathlib
from typing import Any, Iterator
from unittest import mock

import pytest

RESOURCES_PATH = pathlib.Path(__file__).parent.parent / "resources"


def _load_json(path: pathlib.Path) -> Any:
    """Parse a JSON resource file, closing the handle deterministically.

    The previous ``json.load(open(...))`` form leaked file handles (they
    were only closed whenever the GC ran).
    """
    with path.open() as fp:
        return json.load(fp)


# Dummy mirdata indexes committed to the repo so tests never hit the network.
GUITAR_SET_TEST_INDEX = _load_json(RESOURCES_PATH / "data" / "guitarset" / "dummy_index.json")
IKALA_TEST_INDEX = _load_json(RESOURCES_PATH / "data" / "ikala" / "dummy_index.json")
MAESTRO_TEST_INDEX = _load_json(RESOURCES_PATH / "data" / "maestro" / "dummy_index.json")
METADATA_TEST_INDEX = _load_json(RESOURCES_PATH / "data" / "maestro" / "maestro-v2.0.0.json")
MEDLEYDB_PITCH_TEST_INDEX = _load_json(RESOURCES_PATH / "data" / "medleydb_pitch" / "dummy_index.json")
SLAKH_TEST_INDEX = _load_json(RESOURCES_PATH / "data" / "slakh" / "dummy_index.json")


@pytest.fixture  # type: ignore[misc]
def mock_slakh_index() -> Iterator[None]:
    """Patch the slakh mirdata dataset so tests neither download nor need real data.

    ``Dataset.download`` becomes a no-op and the committed dummy index is
    substituted for the remote one. Annotated ``Iterator[None]`` because a
    yield-based fixture is a generator function, not a ``None``-returning one.
    """
    with mock.patch("mirdata.datasets.slakh.Dataset.download"):
        with mock.patch("mirdata.datasets.slakh.Dataset._index", new=SLAKH_TEST_INDEX):
            yield


@pytest.fixture  # type: ignore[misc]
def mock_medleydb_pitch_index() -> Iterator[None]:
    """Patch the medleydb_pitch mirdata dataset so tests neither download nor need real data.

    ``Dataset.download`` becomes a no-op and the committed dummy index is
    substituted for the remote one. Annotated ``Iterator[None]`` because a
    yield-based fixture is a generator function, not a ``None``-returning one.
    """
    with mock.patch("mirdata.datasets.medleydb_pitch.Dataset.download"):
        with mock.patch("mirdata.datasets.medleydb_pitch.Dataset._index", new=MEDLEYDB_PITCH_TEST_INDEX):
            yield


@pytest.fixture  # type: ignore[misc]
def mock_maestro_index() -> Iterator[None]:
    """Patch the maestro mirdata dataset so tests neither download nor need real data.

    ``Dataset.download`` becomes a no-op; the committed dummy index replaces
    the remote one, and ``_metadata`` is rebuilt keyed by the midi filename
    stem (the track id format mirdata uses for maestro). Annotated
    ``Iterator[None]`` because a yield-based fixture is a generator function.
    """
    metadata = {mdata["midi_filename"].split(".")[0]: mdata for mdata in METADATA_TEST_INDEX}
    with mock.patch("mirdata.datasets.maestro.Dataset.download"):
        with mock.patch("mirdata.datasets.maestro.Dataset._metadata", new=metadata):
            with mock.patch("mirdata.datasets.maestro.Dataset._index", new=MAESTRO_TEST_INDEX):
                yield


@pytest.fixture  # type: ignore[misc]
def mock_guitarset_index() -> Iterator[None]:
    """Patch the guitarset mirdata dataset so tests neither download nor need real data.

    ``Dataset.download`` becomes a no-op and the committed dummy index is
    substituted for the remote one. Annotated ``Iterator[None]`` because a
    yield-based fixture is a generator function, not a ``None``-returning one.
    """
    with mock.patch("mirdata.datasets.guitarset.Dataset.download"):
        with mock.patch("mirdata.datasets.guitarset.Dataset._index", new=GUITAR_SET_TEST_INDEX):
            yield


@pytest.fixture  # type: ignore[misc]
def mock_ikala_index() -> Iterator[None]:
    """Patch the ikala mirdata dataset so tests neither download nor need real data.

    ``Dataset.download`` becomes a no-op and the committed dummy index is
    substituted for the remote one. Annotated ``Iterator[None]`` because a
    yield-based fixture is a generator function, not a ``None``-returning one.
    """
    with mock.patch("mirdata.datasets.ikala.Dataset.download"):
        with mock.patch("mirdata.datasets.ikala.Dataset._index", new=IKALA_TEST_INDEX):
            yield
5 changes: 2 additions & 3 deletions tests/data/test_guitarset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import os
import pathlib
import shutil

from apache_beam.testing.test_pipeline import TestPipeline
from typing import List

Expand All @@ -36,7 +35,7 @@
TRACK_ID = "00_BN1-129-Eb_comp"


def test_guitarset_to_tf_example(tmp_path: pathlib.Path) -> None:
def test_guitarset_to_tf_example(tmp_path: pathlib.Path, mock_guitarset_index: None) -> None:
mock_guitarset_home = tmp_path / "guitarset"
mock_guitarset_audio = mock_guitarset_home / "audio_mono-mic"
mock_guitarset_annotations = mock_guitarset_home / "annotation"
Expand Down Expand Up @@ -91,7 +90,7 @@ def test_guitarset_invalid_tracks(tmpdir: str) -> None:
assert fp.read().strip() == str(i)


def test_guitarset_create_input_data() -> None:
def test_guitarset_create_input_data(mock_guitarset_index: None) -> None:
data = create_input_data(train_percent=0.33, validation_percent=0.33)
data.sort(key=lambda el: el[1]) # sort by split
tolerance = 0.1
Expand Down
4 changes: 1 addition & 3 deletions tests/data/test_ikala.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@
import apache_beam as beam
import itertools
import os

from apache_beam.testing.test_pipeline import TestPipeline

from basic_pitch.data.datasets.ikala import (
IkalaInvalidTracks,
create_input_data,
)


# TODO: Create test_ikala_to_tf_example


Expand All @@ -51,7 +49,7 @@ def test_ikala_invalid_tracks(tmpdir: str) -> None:
assert fp.read().strip() == str(i)


def test_ikala_create_input_data() -> None:
def test_ikala_create_input_data(mock_ikala_index: None) -> None:
data = create_input_data(train_percent=0.5)
data.sort(key=lambda el: el[1]) # sort by split
tolerance = 0.1
Expand Down
11 changes: 5 additions & 6 deletions tests/data/test_maestro.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# limitations under the License.
import os
import pathlib

from typing import List

import apache_beam as beam
Expand All @@ -40,7 +39,7 @@
GT_15M_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav"


def test_maestro_to_tf_example(tmp_path: pathlib.Path) -> None:
def test_maestro_to_tf_example(tmp_path: pathlib.Path, mock_maestro_index: None) -> None:
mock_maestro_home = tmp_path / "maestro"
mock_maestro_ext = mock_maestro_home / "2004"
mock_maestro_ext.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -68,7 +67,7 @@ def test_maestro_to_tf_example(tmp_path: pathlib.Path) -> None:
assert len(data) != 0


def test_maestro_invalid_tracks(tmp_path: pathlib.Path) -> None:
def test_maestro_invalid_tracks(tmp_path: pathlib.Path, mock_maestro_index: None) -> None:
mock_maestro_home = tmp_path / "maestro"
mock_maestro_ext = mock_maestro_home / "2004"
mock_maestro_ext.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -98,7 +97,7 @@ def test_maestro_invalid_tracks(tmp_path: pathlib.Path) -> None:
assert fp.read().strip() == track_id


def test_maestro_invalid_tracks_over_15_min(tmp_path: pathlib.Path) -> None:
def test_maestro_invalid_tracks_over_15_min(tmp_path: pathlib.Path, mock_maestro_index: None) -> None:
"""
The track id used here is a real track id in maestro, and it is part of the train split, but we mock the data so as
not to store a large file in git, hence the variable name.
Expand Down Expand Up @@ -131,13 +130,13 @@ def test_maestro_invalid_tracks_over_15_min(tmp_path: pathlib.Path) -> None:
assert fp.read().strip() == ""


def test_maestro_create_input_data() -> None:
def test_maestro_create_input_data(mock_maestro_index: None) -> None:
"""
A committed metadata file is included in the repo for testing. mirdata references the metadata file to
populate the tracklist with metadata. Since the file is trimmed down to only the filenames referenced here,
we only consider these when testing the metadata.
"""
data = create_input_data(str(MAESTRO_TEST_DATA_PATH))
data = create_input_data()
assert len(data)

test_fnames = {TRAIN_TRACK_ID, VALID_TRACK_ID, TEST_TRACK_ID, GT_15M_TRACK_ID}
Expand Down
2 changes: 1 addition & 1 deletion tests/data/test_medleydb_pitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_medleydb_pitch_invalid_tracks(tmpdir: str) -> None:
assert fp.read().strip() == str(i)


def test_medleydb_create_input_data() -> None:
def test_medleydb_create_input_data(mock_medleydb_pitch_index: None) -> None:
data = create_input_data(train_percent=0.5)
data.sort(key=lambda el: el[1]) # sort by split
tolerance = 0.01
Expand Down
10 changes: 5 additions & 5 deletions tests/data/test_slakh.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def create_mock_input_data(data_home: pathlib.Path, input_data: List[Tuple[str,
shutil.copy(SLAKH_PATH / split / track_num / "metadata.yaml", track_dir / "metadata.yaml")


def test_slakh_to_tf_example(tmp_path: pathlib.Path) -> None:
def test_slakh_to_tf_example(tmp_path: pathlib.Path, mock_slakh_index: None) -> None:
mock_slakh_home = tmp_path / "slakh"
mock_slakh_ext = mock_slakh_home / "slakh2100_flac_redux"

Expand All @@ -92,7 +92,7 @@ def test_slakh_to_tf_example(tmp_path: pathlib.Path) -> None:
assert len(data) != 0


def test_slakh_invalid_tracks(tmp_path: pathlib.Path) -> None:
def test_slakh_invalid_tracks(tmp_path: pathlib.Path, mock_slakh_index: None) -> None:
mock_slakh_home = tmp_path / "slakh"
mock_slakh_ext = mock_slakh_home / "slakh2100_flac_redux"

Expand All @@ -119,7 +119,7 @@ def test_slakh_invalid_tracks(tmp_path: pathlib.Path) -> None:
assert fp.read().strip() == track_id


def test_slakh_invalid_tracks_omitted(tmp_path: pathlib.Path) -> None:
def test_slakh_invalid_tracks_omitted(tmp_path: pathlib.Path, mock_slakh_index: None) -> None:
mock_slakh_home = tmp_path / "slakh"
mock_slakh_ext = mock_slakh_home / "slakh2100_flac_redux"

Expand Down Expand Up @@ -148,7 +148,7 @@ def test_slakh_invalid_tracks_omitted(tmp_path: pathlib.Path) -> None:
assert fp.read().strip() == ""


def test_slakh_invalid_tracks_drums(tmp_path: pathlib.Path) -> None:
def test_slakh_invalid_tracks_drums(tmp_path: pathlib.Path, mock_slakh_index: None) -> None:
mock_slakh_home = tmp_path / "slakh"
mock_slakh_ext = mock_slakh_home / "slakh2100_flac_redux"

Expand All @@ -175,7 +175,7 @@ def test_slakh_invalid_tracks_drums(tmp_path: pathlib.Path) -> None:
assert fp.read().strip() == ""


def test_create_input_data() -> None:
def test_create_input_data(mock_slakh_index: None) -> None:
data = create_input_data()
for _, group in itertools.groupby(data, lambda el: el[1]):
assert len(list(group))
7 changes: 4 additions & 3 deletions tests/data/test_tf_example_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def mock_and_process(split: str, track_id: str) -> None:
return output_home


def test_prepare_datasets(tmp_path: pathlib.Path) -> None:
def test_prepare_datasets(tmp_path: pathlib.Path, mock_guitarset_index: None) -> None:
datasets_home = setup_test_resources(tmp_path)

ds_train, ds_valid = prepare_datasets(
Expand All @@ -102,7 +102,7 @@ def test_prepare_datasets(tmp_path: pathlib.Path) -> None:
assert ds_valid is not None and isinstance(ds_valid, tf.data.Dataset)


def test_prepare_visualization_dataset(tmp_path: pathlib.Path) -> None:
def test_prepare_visualization_dataset(tmp_path: pathlib.Path, mock_guitarset_index: None) -> None:
datasets_home = setup_test_resources(tmp_path)

ds_train, ds_valid = prepare_visualization_datasets(
Expand All @@ -117,7 +117,7 @@ def test_prepare_visualization_dataset(tmp_path: pathlib.Path) -> None:
assert ds_valid is not None and isinstance(ds_train, tf.data.Dataset)


def test_sample_datasets(tmp_path: pathlib.Path) -> None:
def test_sample_datasets(tmp_path: pathlib.Path, mock_guitarset_index: None) -> None:
"""touches the following methods:
- transcription_dataset
- parse_transcription_tfexample
Expand All @@ -126,6 +126,7 @@ def test_sample_datasets(tmp_path: pathlib.Path) -> None:
- reduce_transcription_inputs
- get_sample_weights
- _infer_time_size
- get_transcription_chunks
- extract_random_window
- extract_window
Expand Down
Loading