factory.py 1.51 KB
# -*- coding: utf-8 -*-
"""
音乐分析器工厂
"""

from typing import Dict, Any, Optional
from .base import AudioAnalyzer
from .qwen_analyzer import QwenAnalyzer


class AnalyzerFactory:
    """Factory that creates and caches music analyzer instances.

    Analyzers are cached per (provider, configuration): repeated calls with
    the same provider and the same keyword arguments return the same object.
    Only the ``"qwen"`` provider is currently supported.
    """

    # Class-level cache shared by all callers; keys come from _make_cache_key.
    _analyzers: Dict[str, AudioAnalyzer] = {}

    @staticmethod
    def _make_cache_key(provider: str, kwargs: Dict[str, Any]) -> str:
        """Build a cache key covering the provider and *every* config kwarg.

        Keying on only part of the configuration (e.g. just ``model``) would
        make calls that differ in ``api_key`` or ``timeout`` share one
        instance. Values are rendered with ``repr`` so unhashable values
        (dicts, lists) are accepted; the ``repr`` of a tuple of strings is
        unambiguous, so distinct configurations cannot collide.
        """
        items = tuple(sorted((name, repr(value)) for name, value in kwargs.items()))
        return repr((provider, items))

    @classmethod
    def get_analyzer(cls, provider: str = "qwen", **kwargs) -> AudioAnalyzer:
        """
        Return an analyzer instance, reusing a cached one when possible.

        Args:
            provider: Provider name (only "qwen" is supported).
            **kwargs: Extra configuration forwarded to the analyzer
                constructor (e.g. api_key, model, timeout). All of them
                take part in the cache key.

        Returns:
            An AudioAnalyzer instance. Calls with the same provider and the
            same kwargs return the same cached instance.

        Raises:
            ValueError: If ``provider`` is not supported.
        """
        cache_key = cls._make_cache_key(provider, kwargs)

        if cache_key in cls._analyzers:
            return cls._analyzers[cache_key]

        if provider == "qwen":
            analyzer = QwenAnalyzer(**kwargs)
        else:
            raise ValueError(f"Unknown provider: {provider}. Only 'qwen' is supported.")

        cls._analyzers[cache_key] = analyzer
        return analyzer

    @classmethod
    def get_default_analyzer(cls) -> AudioAnalyzer:
        """Return the analyzer named by the DEFAULT_MUSIC_ANALYZER env var.

        Surrounding whitespace and letter case in the variable are ignored.
        Falls back to "qwen" when the variable is unset or blank, instead of
        failing with an "Unknown provider" error.
        """
        import os

        provider = os.getenv("DEFAULT_MUSIC_ANALYZER", "").strip().lower() or "qwen"
        return cls.get_analyzer(provider=provider)

    @classmethod
    def list_providers(cls) -> list:
        """Return the names of the supported providers (a new list per call)."""
        return ["qwen"]

    @classmethod
    def clear_cache(cls):
        """Drop every cached analyzer instance."""
        cls._analyzers.clear()