Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion basic_pitch/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
SEMITONES_PER_OCTAVE = 12 # for frequency bin calculations

FFT_HOP = 256
N_FFT = 8 * FFT_HOP

NOTES_BINS_PER_SEMITONE = 1
CONTOURS_BINS_PER_SEMITONE = 3
Expand Down
14 changes: 11 additions & 3 deletions basic_pitch/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
AUDIO_N_SAMPLES,
ANNOTATIONS_FPS,
FFT_HOP,
AUDIO_WINDOW_LENGTH,
)
from basic_pitch.commandline_printing import (
generating_file_message,
Expand Down Expand Up @@ -247,13 +248,15 @@ def unwrap_output(
output: npt.NDArray[np.float32],
audio_original_length: int,
n_overlapping_frames: int,
hop_size: int,
) -> np.array:
"""Unwrap batched model predictions to a single matrix.

Args:
output: array (n_batches, n_times_short, n_freqs)
audio_original_length: length of original audio signal (in samples)
n_overlapping_frames: number of overlapping frames in the output
hop_size: size of the hop used when scanning the input audio

Returns:
array (n_times, n_freqs)
Expand All @@ -266,10 +269,14 @@ def unwrap_output(
# remove half of the overlapping frames from beginning and end
output = output[:, n_olap:-n_olap, :]

# Concatenate the per-window frame outputs (overlapping frames removed) into a single time dimension
output_shape = output.shape
n_output_frames_original = int(np.floor(audio_original_length * (ANNOTATIONS_FPS / AUDIO_SAMPLE_RATE)))
unwrapped_output = output.reshape(output_shape[0] * output_shape[1], output_shape[2])
return unwrapped_output[:n_output_frames_original, :] # trim to original audio length

# trim to number of expected windows in output
n_expected_windows = audio_original_length / hop_size
n_frames_per_window = (AUDIO_WINDOW_LENGTH * ANNOTATIONS_FPS) - n_overlapping_frames
return unwrapped_output[: int(n_expected_windows * n_frames_per_window), :]


def run_inference(
Expand Down Expand Up @@ -303,7 +310,8 @@ def run_inference(
output[k].append(v)

unwrapped_output = {
k: unwrap_output(np.concatenate(output[k]), audio_original_length, n_overlapping_frames) for k in output
k: unwrap_output(np.concatenate(output[k]), audio_original_length, n_overlapping_frames, hop_size)
for k in output
}

if debug_file:
Expand Down
Binary file modified tests/resources/vocadito_10/model_output.npz
Binary file not shown.
10 changes: 9 additions & 1 deletion tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
import numpy as np
import numpy.typing as npt

from basic_pitch import ICASSP_2022_MODEL_PATH, inference
from basic_pitch import ICASSP_2022_MODEL_PATH, inference, note_creation
from basic_pitch.constants import (
AUDIO_SAMPLE_RATE,
AUDIO_N_SAMPLES,
ANNOTATIONS_N_SEMITONES,
FFT_HOP,
ANNOTATION_HOP,
)

RESOURCES_PATH = pathlib.Path(__file__).parent / "resources"
Expand All @@ -55,6 +56,13 @@ def test_predict() -> None:
assert all(note_pitch_max)
assert isinstance(note_events, list)

# Check that the model output has the expected length: the time of the last frame, as computed downstream
# (via model_frames_to_time), should match the audio duration to within a few frames of tolerance
audio_length_s = librosa.get_duration(filename=test_audio_path)
n_model_output_frames = model_output["note"].shape[0]
last_frame_s = note_creation.model_frames_to_time(n_model_output_frames)[-1]
np.testing.assert_allclose(last_frame_s, audio_length_s, atol=2 * ANNOTATION_HOP)

expected_model_output = np.load(RESOURCES_PATH / "vocadito_10" / "model_output.npz", allow_pickle=True)[
"arr_0"
].item()
Expand Down