factory.py
1.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
"""
音乐分析器工厂
"""
from typing import Dict, Any, Optional
from .base import AudioAnalyzer
from .qwen_analyzer import QwenAnalyzer
class AnalyzerFactory:
    """Factory and instance cache for music analyzers.

    Analyzer instances are cached on the class, keyed by the provider name
    together with *all* keyword arguments used to construct them, so repeated
    calls with an identical configuration return the same instance while
    calls with a different configuration get their own instance.
    """

    # Provider name -> analyzer class. Single source of truth shared by
    # get_analyzer() and list_providers() so the two cannot drift apart.
    _PROVIDERS: Dict[str, type] = {"qwen": QwenAnalyzer}

    # Class-level cache shared by every caller of this factory; keys are
    # built by _make_cache_key().
    _analyzers: Dict[tuple, AudioAnalyzer] = {}

    @staticmethod
    def _make_cache_key(provider: str, kwargs: Dict[str, Any]) -> tuple:
        """Build a hashable cache key from the provider and every kwarg.

        All keyword arguments participate (not only ``model``), so two calls
        that differ only in e.g. ``api_key`` or ``timeout`` never share an
        instance. Kwargs are sorted by name so argument order does not matter.
        Unhashable values are represented by their type name and ``repr``.
        """
        items = []
        for name in sorted(kwargs):
            value = kwargs[name]
            try:
                hash(value)
            except TypeError:
                # e.g. a dict/list config value; fall back to a stable text form.
                value = (type(value).__name__, repr(value))
            items.append((name, value))
        return (provider, tuple(items))

    @classmethod
    def get_analyzer(cls, provider: str = "qwen", **kwargs: Any) -> AudioAnalyzer:
        """Return an analyzer instance for ``provider``, creating it if needed.

        Args:
            provider: Provider name. Only "qwen" is supported.
            **kwargs: Extra configuration forwarded to the analyzer
                constructor (e.g. api_key, model, timeout). Every kwarg is
                part of the cache key.

        Returns:
            An AudioAnalyzer instance, shared with earlier calls that used the
            same provider and the same kwargs.

        Raises:
            ValueError: If ``provider`` is not a supported provider name.
        """
        # Validate first: only supported providers are ever cached, so this
        # is equivalent to checking after the lookup, and it keeps a clean
        # ValueError for unknown or non-string providers.
        if not isinstance(provider, str) or provider not in cls._PROVIDERS:
            raise ValueError(f"Unknown provider: {provider}. Only 'qwen' is supported.")

        cache_key = cls._make_cache_key(provider, kwargs)
        try:
            return cls._analyzers[cache_key]
        except KeyError:
            pass

        analyzer = cls._PROVIDERS[provider](**kwargs)
        cls._analyzers[cache_key] = analyzer
        return analyzer

    @classmethod
    def get_default_analyzer(cls) -> AudioAnalyzer:
        """Return the default analyzer, chosen by the DEFAULT_MUSIC_ANALYZER env var.

        An unset, empty or whitespace-only variable falls back to "qwen".
        """
        import os
        provider = (os.getenv("DEFAULT_MUSIC_ANALYZER") or "").strip() or "qwen"
        return cls.get_analyzer(provider=provider)

    @classmethod
    def list_providers(cls) -> list:
        """Return the names of the supported providers as a new list."""
        return list(cls._PROVIDERS)

    @classmethod
    def clear_cache(cls):
        """Drop every cached analyzer instance."""
        cls._analyzers.clear()