# doubao_analyzer.py 12.4 KB
# -*- coding: utf-8 -*-
"""
火山引擎豆包音乐分析器实现
"""

import os
import time
import logging
from typing import Dict, Any, Optional
from dotenv import load_dotenv
from pathlib import Path

import httpx

from .base import AudioAnalyzer
from .prompts import build_analyze_prompt, build_lyrics_prompt

# Directory whose `.env` file is loaded at import time.
# NOTE(review): if this file is app/middleware/music_analyze/doubao_analyzer.py,
# parents[2] is the `app` directory, not the repository root — confirm that is
# where `.env` is meant to live.
_ROOT_DIR = Path(__file__).resolve().parents[2]
# Import-time side effect: load `.env` into the process environment so that
# DoubaoAnalyzer.__init__ can pick up DOUBAO_* / ARK_API_KEY via os.getenv.
load_dotenv(_ROOT_DIR / ".env")

logger = logging.getLogger(__name__)


class DoubaoAnalyzer(AudioAnalyzer):
    """Music analyzer backed by the Volcengine Ark ("Doubao") chat-completions API.

    The audio file is passed to the model by URL (as a ``video_url`` content
    part) next to a text prompt. The model's reply text is parsed and normalized
    with ``_parse_response`` / ``_normalize_result``, which are not defined in
    this class (presumably inherited from ``AudioAnalyzer`` — confirm in ``.base``).
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        timeout: float = 60.0,
        max_retries: int = 3,
        temperature: float = 0.7,
        max_tokens: int = 4000,
    ):
        """
        Initialize the Doubao analyzer.

        Args:
            api_key: API key. Defaults to the DOUBAO_API_KEY environment variable,
                then ARK_API_KEY (an empty value is treated as unset).
            base_url: API base URL. Defaults to DOUBAO_BASE_URL, then
                https://ark.cn-beijing.volces.com/api/v3.
            model: Model name. Defaults to DOUBAO_MODEL, then doubao-seed-1-8-251228.
            timeout: HTTP timeout in seconds, applied to every request.
            max_retries: Maximum number of attempts per API call; values below 1
                still make exactly one attempt.
            temperature: Sampling temperature sent with every request.
            max_tokens: Completion-token limit sent with every request.
        """
        self.api_key = (
            api_key or os.getenv("DOUBAO_API_KEY") or os.getenv("ARK_API_KEY")
        )
        if not self.api_key:
            # Requests will go out without an Authorization header and will
            # presumably be rejected by the API; surface this early.
            logger.warning(
                "[DoubaoAnalyzer] No API key configured "
                "(pass api_key or set DOUBAO_API_KEY / ARK_API_KEY)"
            )
        self.base_url = base_url or os.getenv(
            "DOUBAO_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3"
        )
        self.model = model or os.getenv("DOUBAO_MODEL", "doubao-seed-1-8-251228")
        self.timeout = timeout
        self.max_retries = max_retries
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Created lazily by _get_client() and reused until close() is called.
        self._client: Optional[httpx.Client] = None

    def _get_client(self) -> httpx.Client:
        """Return the shared HTTP client, creating it on first use."""
        if self._client is None:
            headers = {"Content-Type": "application/json"}
            # Never send a literal "Bearer None" when no key is configured.
            if self.api_key:
                headers["Authorization"] = f"Bearer {self.api_key}"
            self._client = httpx.Client(
                base_url=self.base_url,
                timeout=self.timeout,
                headers=headers,
            )
        return self._client

    def close(self) -> None:
        """Close the underlying HTTP client if one was created.

        Safe to call repeatedly; a later ``analyze`` call creates a fresh client.
        """
        if self._client is not None:
            self._client.close()
            self._client = None

    def get_provider_name(self) -> str:
        return "doubao"

    def get_model_name(self) -> str:
        return self.model

    def analyze(
        self,
        metadata: Dict[str, Any],
        music_url: str,
        extract_lyrics: bool = False,
        label_level: int = 0,
    ) -> Optional[Dict[str, Any]]:
        """
        Analyze a piece of music.

        Args:
            metadata: Music metadata forwarded to the prompt builder.
            music_url: URL of the audio file; it is sent to the model by URL.
            extract_lyrics: If True, make a second model call to recognize lyrics.
            label_level: Label level forwarded to the prompt builder.

        Returns:
            The normalized analysis result, or None if the API call or the
            response parsing failed.
        """
        client = self._get_client()

        if extract_lyrics:
            return self._analyze_with_lyrics(client, metadata, music_url, label_level)
        return self._analyze_basic(client, metadata, music_url, label_level)

    def _analyze_basic(
        self,
        client: httpx.Client,
        metadata: Dict[str, Any],
        music_url: str,
        label_level: int = 0,
    ) -> Optional[Dict[str, Any]]:
        """Run a single analysis call without lyrics recognition."""
        system_prompt, user_prompt = build_analyze_prompt(
            metadata=metadata,
            include_lyrics=False,
            label_level=label_level,
        )

        logger.info("[DoubaoAnalyzer] System Prompt:\n%s", system_prompt)
        logger.info("[DoubaoAnalyzer] User Prompt:\n%s", user_prompt)

        messages = self._build_messages(system_prompt, user_prompt, music_url)
        response = self._call_with_retry(client, messages)
        if response is None:
            return None

        result = self._parse_response(response.get("content", ""))
        if result is None:
            return None

        return self._normalize_result(result, self.model, response.get("usage"))

    def _analyze_with_lyrics(
        self,
        client: httpx.Client,
        metadata: Dict[str, Any],
        music_url: str,
        label_level: int = 0,
    ) -> Optional[Dict[str, Any]]:
        """Analyze with lyrics recognition, using two model calls.

        The first call produces the basic analysis (without lyrics); if it fails,
        None is returned. The second call asks only for lyrics; its failure is
        tolerated and the basic result is returned without a ``lyrics`` field.
        Token usage from both calls is summed.
        """
        # Call 1: basic information (lyrics excluded from this prompt).
        system_prompt, user_prompt = build_analyze_prompt(
            metadata=metadata,
            include_lyrics=False,
            label_level=label_level,
        )

        logger.info("[DoubaoAnalyzer] System Prompt (with lyrics):\n%s", system_prompt)
        logger.info("[DoubaoAnalyzer] User Prompt (with lyrics):\n%s", user_prompt)

        messages_basic = self._build_messages(system_prompt, user_prompt, music_url)
        response_basic = self._call_with_retry(client, messages_basic)
        if response_basic is None:
            return None

        result = self._parse_response(response_basic.get("content", ""))
        if result is None:
            return None

        # Call 2: lyrics recognition.
        lyrics_prompt = build_lyrics_prompt()
        logger.info("[DoubaoAnalyzer] Lyrics Prompt:\n%s", lyrics_prompt)

        messages_lyrics = self._build_messages(
            "请识别这段音频中的歌词内容", lyrics_prompt, music_url
        )
        response_lyrics = self._call_with_retry(client, messages_lyrics)

        lyrics_result = (
            self._parse_response(response_lyrics.get("content", ""))
            if response_lyrics
            else None
        )
        if lyrics_result and "lyrics" in lyrics_result:
            result["lyrics"] = lyrics_result["lyrics"]
        else:
            # Best effort: keep the basic analysis even when lyrics are missing.
            logger.warning("[DoubaoAnalyzer] Lyrics recognition returned no lyrics")

        usage = self._merge_usage(
            response_basic.get("usage"),
            response_lyrics.get("usage") if response_lyrics else None,
        )
        return self._normalize_result(result, self.model, usage)

    @staticmethod
    def _merge_usage(
        first: Optional[Dict[str, int]],
        second: Optional[Dict[str, int]],
    ) -> Dict[str, int]:
        """Sum the prompt/completion/total token counters of two usage dicts.

        Either argument may be None or lack keys; missing counters count as 0.
        """
        first = first or {}
        second = second or {}
        return {
            key: (first.get(key) or 0) + (second.get(key) or 0)
            for key in ("prompt_tokens", "completion_tokens", "total_tokens")
        }

    def _build_messages(
        self,
        system_prompt: str,
        user_prompt: str,
        music_url: str,
    ) -> list:
        """Build the chat message list: optional system message, then a user
        message carrying the audio URL (as a ``video_url`` part) and the prompt.

        An empty ``system_prompt`` is omitted rather than sent as a blank message.
        """
        messages: list = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "video_url", "video_url": {"url": music_url}},
                    {"type": "text", "text": user_prompt},
                ],
            }
        )
        return messages

    @staticmethod
    def _is_retryable_status(status_code: int) -> bool:
        """Return True for HTTP statuses worth retrying (timeouts, rate limits, 5xx).

        Other 4xx responses (bad key, malformed request) fail identically on
        every attempt, so retrying them only adds latency.
        """
        return status_code in (408, 429) or status_code >= 500

    @staticmethod
    def _extract_completion(payload: Dict[str, Any]) -> Dict[str, Any]:
        """Extract reply text and token usage from a chat-completions response body.

        Missing/empty ``choices`` or a null ``content`` yield ``""`` instead of
        raising; missing or null usage counters default to 0.
        """
        choices = payload.get("choices") or [{}]
        message = (choices[0] or {}).get("message") or {}
        content = message.get("content") or ""
        usage = payload.get("usage") or {}
        return {
            "content": content,
            "usage": {
                key: usage.get(key) or 0
                for key in ("prompt_tokens", "completion_tokens", "total_tokens")
            },
        }

    def _call_with_retry(
        self,
        client: httpx.Client,
        messages: list,
    ) -> Optional[Dict]:
        """Call the chat-completions endpoint with linear back-off retries.

        Returns:
            ``{"content": str, "usage": {...token counters...}}`` on success, or
            None once attempts are exhausted or a non-retryable HTTP status
            (see ``_is_retryable_status``) is returned.
        """
        endpoint = "/chat/completions"
        data = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": False,
        }
        # Always make at least one request, even if max_retries <= 0.
        attempts = max(1, self.max_retries)

        for attempt in range(1, attempts + 1):
            try:
                logger.info("  [Doubao] 调用模型 (尝试 %d/%d)...", attempt, attempts)

                start_time = time.perf_counter()
                response = client.post(endpoint, json=data)
                response.raise_for_status()
                elapsed = time.perf_counter() - start_time
                logger.info("  [Doubao] 响应时间: %.2fs", elapsed)

                completion = self._extract_completion(response.json())
                logger.info("  [Doubao] 响应: %s...", completion["content"][:100])
                return completion

            except httpx.HTTPStatusError as e:
                status_code = e.response.status_code
                logger.warning("  [Doubao] HTTP 错误 (%s): %s", type(e).__name__, e)
                if not self._is_retryable_status(status_code):
                    logger.error(
                        "    Non-retryable HTTP status %d; giving up", status_code
                    )
                    return None

            except httpx.HTTPError as e:
                # Transport-level failures (timeouts, connection errors).
                logger.warning("  [Doubao] HTTP 错误 (%s): %s", type(e).__name__, e)

            except Exception as e:
                # e.g. a non-JSON body or an unexpected payload shape.
                logger.warning("  [Doubao] 错误 (%s): %s", type(e).__name__, e)

            if attempt < attempts:
                wait_time = attempt  # linear back-off: 1s, 2s, ...
                logger.info("    等待 %d 秒后重试...", wait_time)
                time.sleep(wait_time)

        logger.error("    已达到最大重试次数")
        return None


def test_doubao_audio_url_lyrics(audio_url: Optional[str] = None):
    """
    Manual smoke test: check whether Doubao can analyze audio given by URL and
    return recognized lyrics.

    It exercises whether the model can:
    1. accept an audio URL as input,
    2. analyze the audio content,
    3. recognize and return lyrics.

    This makes real API calls through DoubaoAnalyzer (configured from the
    environment / .env); it is not a unit test.

    Usage:
        python -c "from app.middleware.music_analyze.doubao_analyzer import test_doubao_audio_url_lyrics; test_doubao_audio_url_lyrics()"

    or run the module as a package module (a plain file path fails because of
    the module's relative imports):
        python -m app.middleware.music_analyze.doubao_analyzer

    Args:
        audio_url: Publicly reachable audio URL to analyze. Defaults to a
            built-in sample file; pass another URL if that one is unreachable.

    Returns:
        The result of the lyrics-enabled analysis, or None if it failed or an
        exception was raised.
    """
    import json
    import traceback

    print("=" * 80)
    print("测试豆包音频URL歌词解析功能")
    print("=" * 80)

    # Default sample audio; override via the audio_url argument.
    test_audio_url = audio_url or (
        "https://hikoon-ai-test.oss-cn-hangzhou.aliyuncs.com/ai/cache/modelName/"
        "20260114/_s__e_1768376270519_rmab41.mp3"
    )

    print(f"\n测试音频URL: {test_audio_url}")
    print("\n开始测试...")

    try:
        analyzer = DoubaoAnalyzer()

        metadata = {"title": "测试歌曲", "artist": "测试艺术家", "test": True}

        print("\n1. 测试基础分析(不含歌词)...")
        result_basic = analyzer.analyze(
            metadata=metadata,
            music_url=test_audio_url,
            extract_lyrics=False,
            label_level=0,
        )

        if result_basic:
            print("   ✓ 基础分析成功")
            print(f"   - 曲风: {result_basic.get('genre', 'N/A')}")
            print(f"   - 语种: {result_basic.get('language', 'N/A')}")
            print(f"   - 情绪: {result_basic.get('emotion', 'N/A')}")
        else:
            print("   ✗ 基础分析失败")

        print("\n2. 测试歌词识别(含歌词)...")
        result_with_lyrics = analyzer.analyze(
            metadata=metadata,
            music_url=test_audio_url,
            extract_lyrics=True,
            label_level=0,
        )

        if result_with_lyrics:
            print("   ✓ 歌词识别分析成功")
            lyrics = result_with_lyrics.get("lyrics") or []
            if isinstance(lyrics, str):
                # The lyrics schema is model-dependent; accept one block of text.
                lyrics = [line for line in lyrics.splitlines() if line.strip()]
            if lyrics:
                print(f"   ✓ 成功识别歌词,共 {len(lyrics)} 行")
                print("\n   歌词预览(前5行):")
                for i, line in enumerate(lyrics[:5], 1):
                    if isinstance(line, dict):
                        time_str = line.get("time", "N/A")
                        text = line.get("text", "")
                    else:
                        # Plain-string lines carry no timestamp.
                        time_str, text = "N/A", str(line)
                    print(f"      [{i}] {time_str} - {text}")
                if len(lyrics) > 5:
                    print(f"      ... 还有 {len(lyrics) - 5} 行")
                print("\n   ✓ 测试通过:豆包支持音频URL解析歌词")
            else:
                print("   ⚠ 未识别到歌词(可能是纯音乐或无法识别)")
                print("\n   ! 测试结果:豆包支持音频URL解析,但未返回歌词")

            print("\n3. 完整分析结果:")
            # default=str: this dump is display-only, so a non-JSON value must not
            # turn a successful analysis into a reported failure.
            print(
                json.dumps(
                    result_with_lyrics, ensure_ascii=False, indent=2, default=str
                )
            )
        else:
            print("   ✗ 歌词识别分析失败")
            print("\n   ✗ 测试失败:豆包可能不支持音频URL解析")

        print("\n" + "=" * 80)
        print("测试完成")
        print("=" * 80)

        return result_with_lyrics

    except Exception as e:
        # Top-level boundary of a manual script: report and return None.
        print(f"\n✗ 测试过程中发生错误: {e}")
        traceback.print_exc()
        return None


# Script entry point: run the manual Doubao smoke test when executed directly.
# Because this module uses relative imports, it must be launched with
# `python -m <package path>` rather than as a plain file path.
if __name__ == "__main__":
    test_doubao_audio_url_lyrics()