# test_songid_pgvector_path.py
# 810 Bytes
import unittest
from scripts.evaluate_songid_pgvector_path import aggregate_song_scores, compute_metrics
class SongIdPgvectorPathTests(unittest.TestCase):
    """Tests for ``aggregate_song_scores`` and ``compute_metrics``."""

    def test_aggregate_song_scores_ranks_by_combined_score(self):
        """Song 'b' is expected first and song 'a' second for the sample hits."""
        # One (song id, similarity, idx) triple per hit; song 'a' occurs twice.
        hits = [
            ('a', 0.9, 0),
            ('a', 0.85, 1),
            ('b', 0.95, 2),
            ('c', 0.2, 3),
        ]
        # Transpose the triples into the three parallel lists passed below.
        song_ids, sims, idxs = (list(column) for column in zip(*hits))

        ranked = aggregate_song_scores(song_ids, sims, idxs)

        # Element 0 of each ranked entry is compared against a song id.
        winner = ranked[0]
        self.assertEqual(winner[0], 'b')
        runner_up = ranked[1]
        self.assertEqual(runner_up[0], 'a')

    def test_compute_metrics(self):
        """Check 'count', 'top1', 'top3' and 'top5' for inputs [1, 2, 4] and 5."""
        # NOTE(review): the list presumably holds ranks of the correct song
        # per query -- confirm against compute_metrics.
        metrics = compute_metrics([1, 2, 4], 5)

        # Checked in this order; the float literals are compared exactly.
        expectations = (
            ('count', 3),
            ('top1', 0.333333),
            ('top3', 0.666667),
            ('top5', 1.0),
        )
        for key, expected in expectations:
            self.assertEqual(metrics[key], expected)
# Run the test suite when this file is executed directly rather than imported.
if __name__ == '__main__':
    unittest.main()