# test_songid_pgvector_path.py 810 Bytes
import unittest

from scripts.evaluate_songid_pgvector_path import aggregate_song_scores, compute_metrics


class SongIdPgvectorPathTests(unittest.TestCase):
    """Unit tests for ``aggregate_song_scores`` and ``compute_metrics``."""

    def test_aggregate_song_scores_ranks_by_combined_score(self):
        """Check the first two entries returned by ``aggregate_song_scores``.

        Song 'a' appears twice (similarities 0.9 and 0.85), 'b' once with the
        highest similarity (0.95) and 'c' once with a low one (0.2). The
        ranking is expected to place 'b' first and 'a' second.
        """
        song_ids = ['a', 'a', 'b', 'c']
        sims = [0.9, 0.85, 0.95, 0.2]
        # NOTE(review): presumably the position of each hit in the query
        # result -- confirm against aggregate_song_scores.
        idxs = [0, 1, 2, 3]
        ranked = aggregate_song_scores(song_ids, sims, idxs)
        # Element 0 of each ranked entry is compared against a song id.
        self.assertEqual(ranked[0][0], 'b')
        self.assertEqual(ranked[1][0], 'a')

    def test_compute_metrics(self):
        """Check the count and top-k ratios reported by ``compute_metrics``.

        With inputs [1, 2, 4] the expected ratios are 1/3 for ``top1``, 2/3
        for ``top3`` and 1.0 for ``top5``.
        """
        # NOTE(review): the second argument (5) is presumably a cut-off such
        # as the largest k -- confirm against compute_metrics.
        metrics = compute_metrics([1, 2, 4], 5)
        self.assertEqual(metrics['count'], 3)
        # The ratios are floats: compare to six decimal places instead of
        # bit-for-bit so the test does not hinge on exact float representation.
        self.assertAlmostEqual(metrics['top1'], 0.333333, places=6)
        self.assertAlmostEqual(metrics['top3'], 0.666667, places=6)
        self.assertAlmostEqual(metrics['top5'], 1.0, places=6)


# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()