Initial commit · 0 parents · Showing 26 changed files with 8160 additions and 0 deletions
.env.example
0 → 100644
| 1 | # Required for qwen | ||
| 2 | QWEN_API_KEY=your_api_key_here | ||
| 3 | QWEN_DASHSCOPE_API_KEY= | ||
| 4 | QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 | ||
| 5 | QWEN_MODEL=qwen3-omni-flash | ||
| 6 | QWEN_TIMEOUT=15 | ||
| 7 | QWEN_LYRICS_TIMEOUT=90 | ||
| 8 | QWEN_MAX_RETRIES=3 | ||
| 9 | MUSIC_ANALYZE_LIGHT_MODE=true | ||
| 10 | MUSIC_DOWNLOAD_DIR=music | ||
| 11 | MUSIC_MAPPING_FILE=music/music_file_mapping.csv | ||
| 12 | |||
| 13 | # Optional song structure service | ||
| 14 | SONGFORMER_URL= | ||
| 15 | |||
| 16 | # Optional ASR backend for lyrics_only path | ||
| 17 | MUSIC_LYRICS_ASR_BACKEND=funasr | ||
| 18 | DASHSCOPE_FUNASR_MODEL=fun-asr | ||
| 19 | DASHSCOPE_BASE_HTTP_API_URL=https://dashscope.aliyuncs.com/api/v1 | ||
| 20 | DASHSCOPE_ASR_POLL_INTERVAL=1 | ||
| 21 | DASHSCOPE_ASR_POLL_TIMEOUT=120 | ||
| 22 | DASHSCOPE_ASR_SUBMIT_URL=https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription | ||
| 23 | DASHSCOPE_ASR_MODEL=qwen3-asr-flash-filetrans | ||
| 24 | DASHSCOPE_TASK_STATUS_BASE_URL=https://dashscope.aliyuncs.com/api/v1/tasks |
.gitignore
0 → 100644
| 1 | .DS_Store | ||
| 2 | |||
| 3 | # Python cache | ||
| 4 | __pycache__/ | ||
| 5 | *.py[cod] | ||
| 6 | *.so | ||
| 7 | .pytest_cache/ | ||
| 8 | .mypy_cache/ | ||
| 9 | |||
| 10 | # Virtual env | ||
| 11 | .venv/ | ||
| 12 | venv/ | ||
| 13 | |||
| 14 | # Local env | ||
| 15 | .env | ||
| 16 | |||
| 17 | # Logs | ||
| 18 | logs/ | ||
| 19 | *.log | ||
| 20 | |||
| 21 | # Runtime outputs | ||
| 22 | outputs/ | ||
| 23 | music/ | ||
| 24 | *.checkpoint.json | ||
| 25 | |||
| 26 | # Local test/sample data | ||
| 27 | *.xlsx | ||
| 28 | *.xls | ||
| 29 | *.csv | ||
| 30 | |||
| 31 | # Keep env template and source files | ||
| 32 | !.env.example |
README.md
0 → 100644
| 1 | # music_analyze_v2 | ||
| 2 | |||
| 3 | This project is a standalone pipeline that batch-runs audio-tag analysis driven by an Excel file. | ||
| 4 | |||
| 5 | The actual main flow: | ||
| 6 | |||
| 7 | 1. Read the input `xlsx` | ||
| 8 | 2. Take audio URLs from the designated URL column | ||
| 9 | 3. Pass selected metadata through to the music analyzer | ||
| 10 | 4. Call `app.middleware.music_analyze.analyze_music(...)` | ||
| 11 | 5. Normalize results into fixed delivery columns and continuously write them back to the output `xlsx` | ||
| 12 | 6. Support resumable runs via the existing output file and checkpoint | ||
| 13 | |||
| 14 | The current batch entry point is [`pipeline/batch_analyze_xlsx.py`](pipeline/batch_analyze_xlsx.py). | ||
| 15 | |||
| 16 | ## Current Status | ||
| 17 | |||
| 18 | - Directly runnable main entry point: [`pipeline/batch_analyze_xlsx.py`](pipeline/batch_analyze_xlsx.py) | ||
| 19 | - Current default analysis chain: `QwenAnalyzer` | ||
| 20 | - Currently usable provider: `qwen` | ||
| 21 | - Prompt source: [`app/prompts/step2_music_decode`](app/prompts/step2_music_decode) | ||
| 22 | - Output format: fixed delivery columns; the original input columns are not all preserved | ||
| 23 | |||
| 24 | Notes: | ||
| 25 | |||
| 26 | - The CLI still keeps the `--provider doubao` option, but the current [`factory.py`](app/middleware/music_analyze/factory.py) only instantiates `qwen`; passing `doubao` fails at runtime. | ||
| 27 | - The README below describes the code's actual current behavior, not historical plans. | ||
| 28 | |||
| 29 | ## 安装 | ||
| 30 | |||
| 31 | ```bash | ||
| 32 | python3.10 -m venv .venv | ||
| 33 | source .venv/bin/activate | ||
| 34 | pip install -r requirements.txt | ||
| 35 | cp .env.example .env | ||
| 36 | ``` | ||
| 37 | |||
| 38 | ## Environment Variables | ||
| 39 | |||
| 40 | The minimal required configuration is usually: | ||
| 41 | |||
| 42 | ```env | ||
| 43 | QWEN_API_KEY=your_api_key | ||
| 44 | QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 | ||
| 45 | QWEN_MODEL=qwen3-omni-flash | ||
| 46 | QWEN_TIMEOUT=15 | ||
| 47 | QWEN_LYRICS_TIMEOUT=90 | ||
| 48 | QWEN_MAX_RETRIES=3 | ||
| 49 | ``` | ||
| 50 | |||
| 51 | The project also supports the following optional capabilities: | ||
| 52 | |||
| 53 | - `QWEN_DASHSCOPE_API_KEY`: used by some DashScope/ASR paths | ||
| 54 | - `SONGFORMER_URL`: enables extra audio-structure features | ||
| 55 | - `MUSIC_LYRICS_ASR_BACKEND`, `DASHSCOPE_*`: lyrics-extraction settings | ||
| 56 | - `OSS_*`: fallback upload via OSS when the audio file is too large | ||
| 57 | |||
| 58 | Configuration definitions live in [`app/core/config.py`](app/core/config.py). | ||
| 59 | |||
| 60 | ## Input Requirements | ||
| 61 | |||
| 62 | The input file must be `xlsx`. | ||
| 63 | |||
| 64 | At least one audio-URL column is required. The script resolves the URL column in this order: | ||
| 65 | |||
| 66 | - The explicitly passed `--url-column` | ||
| 67 | - `URL` | ||
| 68 | - `url` | ||
| 69 | - `cos访问地址` | ||
| 70 | - `cos_url` | ||
| 71 | - `audio_url` | ||
| 72 | |||
| 73 | If a row's URL is empty: | ||
| 74 | |||
| 75 | - No analysis request is made | ||
| 76 | - The row is skipped outright | ||
| 77 | - It is treated as already processed when resuming | ||
| 78 | |||
| 79 | Metadata is optional but recommended. The script preferentially recognizes these fields: | ||
| 80 | |||
| 81 | - `歌曲ID` / `song_id` / `id` | ||
| 82 | - `tmeid` / `tmeID` / `TMEID` | ||
| 83 | - `歌曲名` / `歌曲名称` / `title` | ||
| 84 | - `表演者` / `歌手` / `artist` | ||
| 85 | - `歌曲时长` / `duration` | ||
| 86 | |||
| 87 | By default, these columns are additionally passed through to the model as metadata: | ||
| 88 | |||
| 89 | - `tmeID,歌曲名称,歌曲名,歌手,表演者,版本,词作者,曲作者` | ||
| 90 | |||
| 91 | This can be overridden with `--metadata-columns`. | ||
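The URL-column fallback described above can be sketched as a small helper (the function name `resolve_url_column` is illustrative, not the real one; the actual logic lives in `pipeline/batch_analyze_xlsx.py`):

```python
# Illustrative sketch of the URL-column resolution order described above.
# The candidate list mirrors the README; the helper name is hypothetical.
URL_COLUMN_CANDIDATES = ["URL", "url", "cos访问地址", "cos_url", "audio_url"]

def resolve_url_column(columns, explicit=None):
    """Return the first usable URL column, honoring an explicit --url-column."""
    if explicit and explicit in columns:
        return explicit
    for candidate in URL_COLUMN_CANDIDATES:
        if candidate in columns:
            return candidate
    raise ValueError(f"no URL column found in {list(columns)!r}")
```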
| 92 | |||
| 93 | ## Quick Start | ||
| 94 | |||
| 95 | A regular batch run: | ||
| 96 | |||
| 97 | ```bash | ||
| 98 | python pipeline/batch_analyze_xlsx.py \ | ||
| 99 | --input 待分析.xlsx \ | ||
| 100 | --output outputs/标签交付结果.xlsx \ | ||
| 101 | --url-column URL \ | ||
| 102 | --provider qwen \ | ||
| 103 | --workers 3 | ||
| 104 | ``` | ||
| 105 | |||
| 106 | Extracting lyrics as well: | ||
| 107 | |||
| 108 | ```bash | ||
| 109 | python pipeline/batch_analyze_xlsx.py \ | ||
| 110 | --input 待分析.xlsx \ | ||
| 111 | --output outputs/标签交付结果.xlsx \ | ||
| 112 | --url-column URL \ | ||
| 113 | --provider qwen \ | ||
| 114 | --workers 3 \ | ||
| 115 | --extract-lyrics | ||
| 116 | ``` | ||
| 117 | |||
| 118 | Rerunning from scratch, without reusing previous output or checkpoint: | ||
| 119 | |||
| 120 | ```bash | ||
| 121 | python pipeline/batch_analyze_xlsx.py \ | ||
| 122 | --input 待分析.xlsx \ | ||
| 123 | --output outputs/标签交付结果.xlsx \ | ||
| 124 | --provider qwen \ | ||
| 125 | --no-resume | ||
| 126 | ``` | ||
| 127 | |||
| 128 | |||
| 129 | ## Command-Line Arguments | ||
| 130 | |||
| 131 | | Flag | Description | Current behavior | | ||
| 132 | |------|------|-------------| | ||
| 133 | | `--input` | Input Excel path | Required | | ||
| 134 | | `--output` | Output Excel path | Required | | ||
| 135 | | `--checkpoint` | Checkpoint file path | Defaults to `<output>.checkpoint.json` | | ||
| 136 | | `--url-column` | URL column name | Defaults to `URL`; falls back automatically when absent | | ||
| 137 | | `--provider` | Analysis provider | Accepts `qwen`/`doubao`, but only `qwen` should be used today | | ||
| 138 | | `--extract-lyrics` | Whether to extract lyrics | Enables the lyrics-aware analysis path | | ||
| 139 | | `--label-level` | Label level | `0` or `1` | | ||
| 140 | | `--metadata-columns` | Extra columns passed through to the model | Comma-separated | | ||
| 141 | | `--workers` | Number of worker threads | Defaults to `3` | | ||
| 142 | | `--checkpoint-every` | Save after every N processed rows | Defaults to `10` | | ||
| 143 | | `--no-resume` | Disable resuming | Off by default | | ||
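For reference, the flag surface in the table can be mirrored with a minimal `argparse` sketch. Defaults are taken from the table; anything beyond it, such as the `--label-level` default, is an assumption, so consult `pipeline/batch_analyze_xlsx.py` for the authoritative definition.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Illustrative sketch of the CLI described in the table above.
    p = argparse.ArgumentParser(description="Batch audio-tag analysis over xlsx")
    p.add_argument("--input", required=True, help="input Excel path")
    p.add_argument("--output", required=True, help="output Excel path")
    p.add_argument("--checkpoint", default=None,
                   help="defaults to <output>.checkpoint.json when omitted")
    p.add_argument("--url-column", default="URL")
    p.add_argument("--provider", choices=["qwen", "doubao"], default="qwen")
    p.add_argument("--extract-lyrics", action="store_true")
    p.add_argument("--label-level", type=int, choices=[0, 1], default=0)  # default assumed
    p.add_argument("--metadata-columns", default=None, help="comma-separated")
    p.add_argument("--workers", type=int, default=3)
    p.add_argument("--checkpoint-every", type=int, default=10)
    p.add_argument("--no-resume", action="store_true")
    return p
```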
| 144 | |||
| 145 | ## Output Structure | ||
| 146 | |||
| 147 | The script emits a fixed delivery table, not a full write-back of "all original input columns plus analysis columns". | ||
| 148 | |||
| 149 | The current output columns are defined in `DEFAULT_OUTPUT_COLUMNS` in [`batch_analyze_xlsx.py`](pipeline/batch_analyze_xlsx.py): | ||
| 150 | |||
| 151 | - `tmeid` | ||
| 152 | - `歌曲ID` | ||
| 153 | - `歌曲名` | ||
| 154 | - `表演者` | ||
| 155 | - `歌曲时长` | ||
| 156 | - `表演者类型` | ||
| 157 | - `语种` | ||
| 158 | - `BPM速度` | ||
| 159 | - `情绪` | ||
| 160 | - `网络/抖音歌曲` | ||
| 161 | - `音乐风格` | ||
| 162 | - `配器` | ||
| 163 | - `场景` | ||
| 164 | |||
| 165 | Result field mapping rules: | ||
| 166 | |||
| 167 | - `表演者类型` <- `performer_type`, or `vocal_texture` | ||
| 168 | - `语种` <- `language` | ||
| 169 | - `BPM速度` <- `bpm` | ||
| 170 | - `情绪` <- `emotion` | ||
| 171 | - `网络/抖音歌曲` <- `douyin_tags` | ||
| 172 | - `音乐风格` <- `music_style_tags`, falling back to `genre/sub_genre` | ||
| 173 | - `配器` <- `instrument_tags` | ||
| 174 | - `场景` <- `scene` | ||
| 175 | |||
| 176 | List-valued fields are joined into `、`-separated strings. | ||
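The mapping and joining rules above reduce to something like the sketch below (the `FIELD_MAP` subset and the helper name are illustrative; the real column set comes from `DEFAULT_OUTPUT_COLUMNS` in `batch_analyze_xlsx.py`):

```python
# Illustrative sketch of the field mapping and 、-joining described above.
FIELD_MAP = {
    "语种": "language",
    "BPM速度": "bpm",
    "情绪": "emotion",
    "网络/抖音歌曲": "douyin_tags",
    "配器": "instrument_tags",
    "场景": "scene",
}

def to_delivery_row(result: dict) -> dict:
    row = {}
    for column, key in FIELD_MAP.items():
        value = result.get(key)
        if isinstance(value, list):
            value = "、".join(str(v) for v in value)  # lists become 、-separated strings
        row[column] = value
    # 表演者类型 falls back from performer_type to vocal_texture
    row["表演者类型"] = result.get("performer_type") or result.get("vocal_texture")
    return row
```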
| 177 | |||
| 178 | ## Resumable Runs | ||
| 179 | |||
| 180 | The current resume logic is more specific than older README descriptions; the actual behavior: | ||
| 181 | |||
| 182 | - If the output file exists and its row count matches this run's input: | ||
| 183 | previous output is reused directly by row index | ||
| 184 | - If the output file exists but the row count differs: | ||
| 185 | the script tries to reuse old results keyed by `歌曲ID` and `tmeid` | ||
| 186 | - If a checkpoint exists: | ||
| 187 | its completion state is merged, provided the output is index-aligned | ||
| 188 | - Rows with an empty URL go straight into the completed set | ||
| 189 | - Progress is flushed periodically per `--checkpoint-every` | ||
| 190 | - On `Ctrl+C`, current progress is saved first, then the process force-exits to avoid hanging threads | ||
| 191 | |||
| 192 | Default checkpoint file name: | ||
| 193 | |||
| 194 | ```text | ||
| 195 | <output>.checkpoint.json | ||
| 196 | ``` | ||
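A checkpoint round-trip could look roughly like the sketch below. The `{"completed": [...]}` schema is an assumption for illustration; only the file-name convention comes from the README. Writing atomically via `os.replace` matches the goal of surviving `Ctrl+C` mid-save.

```python
import json
import os
import tempfile

def save_checkpoint(path: str, completed: set[int]) -> None:
    # Write to a temp file, then atomically replace, so an interrupt
    # during save cannot leave a half-written checkpoint behind.
    payload = {"completed": sorted(completed)}  # schema assumed for illustration
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path) or ".")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        json.dump(payload, f)
    os.replace(tmp, path)

def load_checkpoint(path: str) -> set[int]:
    if not os.path.exists(path):
        return set()
    with open(path, encoding="utf-8") as f:
        return set(json.load(f).get("completed", []))
```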
| 197 | |||
| 198 | ## Prompts and Analysis Chain | ||
| 199 | |||
| 200 | The batch script does not read prompt files directly; it goes through the unified analysis entry point: | ||
| 201 | |||
| 202 | [`pipeline/batch_analyze_xlsx.py`](pipeline/batch_analyze_xlsx.py) | ||
| 203 | -> [`app/middleware/music_analyze/__init__.py`](app/middleware/music_analyze/__init__.py) | ||
| 204 | -> [`app/middleware/music_analyze/music_analyzer.py`](app/middleware/music_analyze/music_analyzer.py) | ||
| 205 | -> [`app/middleware/music_analyze/factory.py`](app/middleware/music_analyze/factory.py) | ||
| 206 | -> [`app/middleware/music_analyze/qwen_analyzer.py`](app/middleware/music_analyze/qwen_analyzer.py) | ||
| 207 | -> [`app/middleware/music_analyze/prompts.py`](app/middleware/music_analyze/prompts.py) | ||
| 208 | |||
| 209 | The prompt directory is currently fixed to: | ||
| 210 | |||
| 211 | - [`music_analyze_system_prompt.md`](app/prompts/step2_music_decode/music_analyze_system_prompt.md) | ||
| 212 | - [`music_analyze_system_prompt_part_a.md`](app/prompts/step2_music_decode/music_analyze_system_prompt_part_a.md) | ||
| 213 | - [`music_analyze_system_prompt_part_b.md`](app/prompts/step2_music_decode/music_analyze_system_prompt_part_b.md) | ||
| 214 | - [`music_analyze_user_prompt.md`](app/prompts/step2_music_decode/music_analyze_user_prompt.md) | ||
| 215 | - [`music_lyrics_only_prompt.md`](app/prompts/step2_music_decode/music_lyrics_only_prompt.md) | ||
| 216 | |||
| 217 | ## Project Structure | ||
| 218 | |||
| 219 | ```text | ||
| 220 | music_analyze_v2/ | ||
| 221 | ├── app/ | ||
| 222 | │ ├── core/ | ||
| 223 | │ │ └── config.py | ||
| 224 | │ ├── middleware/ | ||
| 225 | │ │ └── music_analyze/ | ||
| 226 | │ │ ├── __init__.py | ||
| 227 | │ │ ├── base.py | ||
| 228 | │ │ ├── factory.py | ||
| 229 | │ │ ├── music_analyzer.py | ||
| 230 | │ │ ├── prompts.py | ||
| 231 | │ │ ├── qwen_analyzer.py | ||
| 232 | │ │ ├── doubao_analyzer.py | ||
| 233 | │ │ ├── audio_features.py | ||
| 234 | │ │ └── bpm_analyzer_tools.py | ||
| 235 | │ ├── prompts/ | ||
| 236 | │ │ └── step2_music_decode/ | ||
| 237 | │ └── utils/ | ||
| 238 | ├── pipeline/ | ||
| 239 | │ └── batch_analyze_xlsx.py | ||
| 240 | ├── outputs/ | ||
| 241 | ├── requirements.txt | ||
| 242 | ├── .env | ||
| 243 | ├── .env.example | ||
| 244 | └── README.md | ||
| 245 | ``` | ||
| 246 | |||
| 247 | ## Dependencies | ||
| 248 | |||
| 249 | Base dependencies are listed in [`requirements.txt`](requirements.txt). | ||
| 250 | |||
| 251 | Currently explicitly included: | ||
| 252 | |||
| 253 | - `openai` | ||
| 254 | - `requests` | ||
| 255 | - `httpx` | ||
| 256 | - `python-dotenv` | ||
| 257 | - `pydantic-settings` | ||
| 258 | - `numpy` | ||
| 259 | - `scipy` | ||
| 260 | - `librosa` | ||
| 261 | - `soundfile` | ||
| 262 | - `pandas` | ||
| 263 | - `openpyxl` | ||
| 264 | |||
| 265 | `dashscope` is still commented out in `requirements.txt`; to run the lyrics path that depends on that SDK, install it yourself and verify the corresponding code branch. | ||
| 266 | |||
| 267 | ## FAQ | ||
| 268 | |||
| 269 | ### Why does `--provider doubao` still fail? | ||
| 270 | |||
| 271 | Because the CLI still keeps the `doubao` option while the analyzer factory only supports `qwen`. This is the current state of the code, not a usage error. | ||
| 272 | |||
| 273 | ### Why doesn't the output keep all of the original Excel columns? | ||
| 274 | |||
| 275 | Because the script only writes `DEFAULT_OUTPUT_COLUMNS` when saving; this is fixed behavior. | ||
| 276 | |||
| 277 | ### Where do I edit the prompts? | ||
| 278 | |||
| 279 | Edit the template files under [`app/prompts/step2_music_decode`](app/prompts/step2_music_decode). | ||
| 280 | |||
| 281 | ### Can I still resume after the row count changed? | ||
| 282 | |||
| 283 | Partially. The script tries to match previous output by `歌曲ID` and `tmeid`. | ||
| 284 | |||
| 285 | ### How do I rerun completely from scratch? | ||
| 286 | |||
| 287 | Pass `--no-resume` and delete the old output and checkpoint; that is the cleanest approach. |
app/__init__.py
0 → 100644
| 1 | """Standalone audio analysis package.""" |
app/core/__init__.py
0 → 100644
app/core/config.py
0 → 100644
| 1 | """Minimal settings for standalone audio analysis pipeline.""" | ||
| 2 | |||
| 3 | from pydantic_settings import BaseSettings, SettingsConfigDict | ||
| 4 | |||
| 5 | |||
| 6 | class Settings(BaseSettings): | ||
| 7 | model_config = SettingsConfigDict( | ||
| 8 | env_file=".env", | ||
| 9 | env_file_encoding="utf-8", | ||
| 10 | extra="ignore", | ||
| 11 | ) | ||
| 12 | |||
| 13 | # Qwen | ||
| 14 | QWEN_API_KEY: str | None = None | ||
| 15 | QWEN_DASHSCOPE_API_KEY: str | None = None | ||
| 16 | QWEN_BASE_URL: str | None = "https://dashscope.aliyuncs.com/compatible-mode/v1" | ||
| 17 | QWEN_MODEL: str | None = "qwen3-omni-flash" | ||
| 18 | QWEN_TIMEOUT: float = 15.0 | ||
| 19 | QWEN_LYRICS_TIMEOUT: float = 90.0 | ||
| 20 | QWEN_MAX_RETRIES: int = 3 | ||
| 21 | MUSIC_ANALYZE_LIGHT_MODE: bool = True | ||
| 22 | MUSIC_DOWNLOAD_DIR: str = "music" | ||
| 23 | MUSIC_MAPPING_FILE: str = "music/music_file_mapping.csv" | ||
| 24 | |||
| 25 | # Optional features | ||
| 26 | SONGFORMER_URL: str | None = None | ||
| 27 | |||
| 28 | # DashScope ASR | ||
| 29 | DASHSCOPE_FUNASR_MODEL: str = "fun-asr" | ||
| 30 | DASHSCOPE_BASE_HTTP_API_URL: str = "https://dashscope.aliyuncs.com/api/v1" | ||
| 31 | DASHSCOPE_ASR_POLL_INTERVAL: float = 1.0 | ||
| 32 | DASHSCOPE_ASR_POLL_TIMEOUT: float = 120.0 | ||
| 33 | DASHSCOPE_ASR_SUBMIT_URL: str = ( | ||
| 34 | "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription" | ||
| 35 | ) | ||
| 36 | DASHSCOPE_ASR_MODEL: str = "qwen3-asr-flash-filetrans" | ||
| 37 | DASHSCOPE_TASK_STATUS_BASE_URL: str = "https://dashscope.aliyuncs.com/api/v1/tasks" | ||
| 38 | |||
| 39 | # OSS | ||
| 40 | OSS_ACCESS_KEY_ID: str | None = None | ||
| 41 | OSS_ACCESS_KEY_SECRET: str | None = None | ||
| 42 | OSS_ENDPOINT: str | None = None | ||
| 43 | OSS_BUCKET_NAME: str | None = None | ||
| 44 | OSS_ENDPOINT_INTERNAL: str | None = None | ||
| 45 | |||
| 46 | |||
| 47 | settings = Settings() |
app/core/exceptions.py
0 → 100644
| 1 | """ | ||
| 2 | Custom exception definitions. | ||
| 3 | |||
| 4 | All business exceptions should inherit from APIException and be handled | ||
| 5 | uniformly by the global exception handler, which returns a standard error response. | ||
| 6 | """ | ||
| 7 | from fastapi import HTTPException, status | ||
| 8 | from typing import Optional, Any | ||
| 9 | |||
| 10 | |||
| 11 | class APIException(HTTPException): | ||
| 12 | """ | ||
| 13 | Base API exception. | ||
| 14 | |||
| 15 | Base class for all business exceptions; caught and handled uniformly by the global exception handler. | ||
| 16 | """ | ||
| 17 | |||
| 18 | def __init__( | ||
| 19 | self, | ||
| 20 | status_code: int = status.HTTP_400_BAD_REQUEST, | ||
| 21 | detail: Optional[str] = None, | ||
| 22 | error_code: Optional[str] = None, | ||
| 23 | data: Any = None, | ||
| 24 | headers: Optional[dict] = None, | ||
| 25 | ): | ||
| 26 | super().__init__(status_code=status_code, detail=detail, headers=headers) | ||
| 27 | self.error_code = error_code or "UNKNOWN_ERROR" | ||
| 28 | self.data = data | ||
| 29 | |||
| 30 | |||
| 31 | class UnauthorizedException(APIException): | ||
| 32 | """Unauthorized exception - authentication failed.""" | ||
| 33 | |||
| 34 | def __init__(self, detail: str = "未授权", error_code: str = "UNAUTHORIZED"): | ||
| 35 | super().__init__( | ||
| 36 | status_code=status.HTTP_401_UNAUTHORIZED, | ||
| 37 | detail=detail, | ||
| 38 | error_code=error_code | ||
| 39 | ) | ||
| 40 | |||
| 41 | |||
| 42 | class ForbiddenException(APIException): | ||
| 43 | """Forbidden exception - insufficient permissions.""" | ||
| 44 | |||
| 45 | def __init__(self, detail: str = "禁止访问", error_code: str = "FORBIDDEN"): | ||
| 46 | super().__init__( | ||
| 47 | status_code=status.HTTP_403_FORBIDDEN, | ||
| 48 | detail=detail, | ||
| 49 | error_code=error_code | ||
| 50 | ) | ||
| 51 | |||
| 52 | |||
| 53 | class NotFoundException(APIException): | ||
| 54 | """Resource-not-found exception.""" | ||
| 55 | |||
| 56 | def __init__(self, detail: str = "资源不存在", error_code: str = "NOT_FOUND"): | ||
| 57 | super().__init__( | ||
| 58 | status_code=status.HTTP_404_NOT_FOUND, | ||
| 59 | detail=detail, | ||
| 60 | error_code=error_code | ||
| 61 | ) | ||
| 62 | |||
| 63 | |||
| 64 | class ConflictException(APIException): | ||
| 65 | """Conflict exception - resource already exists.""" | ||
| 66 | |||
| 67 | def __init__(self, detail: str = "资源已存在", error_code: str = "CONFLICT"): | ||
| 68 | super().__init__( | ||
| 69 | status_code=status.HTTP_409_CONFLICT, | ||
| 70 | detail=detail, | ||
| 71 | error_code=error_code | ||
| 72 | ) | ||
| 73 | |||
| 74 | |||
| 75 | class ValidationException(APIException): | ||
| 76 | """Validation exception - input validation failed.""" | ||
| 77 | |||
| 78 | def __init__(self, detail: str = "验证失败", error_code: str = "VALIDATION_ERROR"): | ||
| 79 | super().__init__( | ||
| 80 | status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, | ||
| 81 | detail=detail, | ||
| 82 | error_code=error_code | ||
| 83 | ) | ||
| 84 | |||
| 85 | |||
| 86 | class BusinessException(APIException): | ||
| 87 | """Business exception - business-rule validation failed.""" | ||
| 88 | |||
| 89 | def __init__( | ||
| 90 | self, | ||
| 91 | detail: str = "业务操作失败", | ||
| 92 | error_code: str = "BUSINESS_ERROR", | ||
| 93 | status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
| 94 | ): | ||
| 95 | super().__init__( | ||
| 96 | status_code=status_code, | ||
| 97 | detail=detail, | ||
| 98 | error_code=error_code | ||
| 99 | ) | ||
| 100 | |||
| 101 | |||
| 102 | class InternalServerException(APIException): | ||
| 103 | """Internal server exception.""" | ||
| 104 | |||
| 105 | def __init__( | ||
| 106 | self, | ||
| 107 | detail: str = "内部服务器错误", | ||
| 108 | error_code: str = "INTERNAL_SERVER_ERROR", | ||
| 109 | ): | ||
| 110 | super().__init__( | ||
| 111 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
| 112 | detail=detail, | ||
| 113 | error_code=error_code | ||
| 114 | ) | ||
| 115 | |||
| 116 | |||
| 117 | class DatabaseException(APIException): | ||
| 118 | """Database exception.""" | ||
| 119 | |||
| 120 | def __init__( | ||
| 121 | self, | ||
| 122 | detail: str = "数据库操作失败", | ||
| 123 | error_code: str = "DATABASE_ERROR", | ||
| 124 | ): | ||
| 125 | super().__init__( | ||
| 126 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
| 127 | detail=detail, | ||
| 128 | error_code=error_code | ||
| 129 | ) | ||
| 130 | |||
| 131 | |||
| 132 | class ExternalServiceException(APIException): | ||
| 133 | """External-service exception - third-party call failed.""" | ||
| 134 | |||
| 135 | def __init__( | ||
| 136 | self, | ||
| 137 | detail: str = "外部服务调用失败", | ||
| 138 | error_code: str = "EXTERNAL_SERVICE_ERROR", | ||
| 139 | ): | ||
| 140 | super().__init__( | ||
| 141 | status_code=status.HTTP_502_BAD_GATEWAY, | ||
| 142 | detail=detail, | ||
| 143 | error_code=error_code | ||
| 144 | ) | ||
| 145 | |||
| 146 | |||
| 147 | class RateLimitException(APIException): | ||
| 148 | """Rate-limit exception - too many requests.""" | ||
| 149 | |||
| 150 | def __init__( | ||
| 151 | self, | ||
| 152 | detail: str = "请求过于频繁,请稍后再试", | ||
| 153 | error_code: str = "RATE_LIMIT_EXCEEDED", | ||
| 154 | ): | ||
| 155 | super().__init__( | ||
| 156 | status_code=status.HTTP_429_TOO_MANY_REQUESTS, | ||
| 157 | detail=detail, | ||
| 158 | error_code=error_code | ||
| 159 | ) |
app/middleware/__init__.py
0 → 100644
| 1 | """Middleware package.""" |
app/middleware/music_analyze/__init__.py
0 → 100644
| 1 | """ | ||
| 2 | Music analysis module. | ||
| 3 | Provides unified music-tag analysis, supporting Tongyi Qianwen (qwen) and Volcengine Doubao. | ||
| 4 | |||
| 5 | Main features: | ||
| 6 | - Music style recognition (aligned with international genre taxonomies) | ||
| 7 | - Emotion recognition | ||
| 8 | - Vocal texture recognition | ||
| 9 | - Language recognition | ||
| 10 | - Rhythm intensity analysis (1-5, used to guide video editing) | ||
| 11 | - Climax detection | ||
| 12 | - Visual concept generation (for MV creation) | ||
| 13 | - Lyrics recognition (optional) | ||
| 14 | |||
| 15 | Supported providers: | ||
| 16 | - qwen: Tongyi Qianwen (qwen3-omni-flash) | ||
| 17 | - doubao: Volcengine Doubao (doubao-seed-1-8-251228) | ||
| 18 | |||
| 19 | Usage examples: | ||
| 20 | from app.middleware.music_analyze import analyze_music | ||
| 21 | |||
| 22 | # Basic analysis | ||
| 23 | result = analyze_music( | ||
| 24 | metadata={"title": "稻香", "artist": "周杰伦"}, | ||
| 25 | music_url="https://example.com/music.mp3", | ||
| 26 | provider="qwen", | ||
| 27 | ) | ||
| 28 | |||
| 29 | # With lyrics recognition | ||
| 30 | result = analyze_music( | ||
| 31 | metadata={"title": "稻香"}, | ||
| 32 | music_url="https://example.com/music.mp3", | ||
| 33 | provider="qwen", | ||
| 34 | extract_lyrics=True, | ||
| 35 | ) | ||
| 36 | """ | ||
| 37 | |||
| 38 | # Main function exports | ||
| 39 | from .music_analyzer import ( | ||
| 40 | analyze_music, | ||
| 41 | analyze_music_lyrics_only, | ||
| 42 | analyze_music_with_qwen, | ||
| 43 | analyze_music_with_doubao, | ||
| 44 | get_available_providers, | ||
| 45 | ) | ||
| 46 | |||
| 47 | # Class exports | ||
| 48 | from .base import AudioAnalyzer | ||
| 49 | from .qwen_analyzer import QwenAnalyzer | ||
| 50 | from .doubao_analyzer import DoubaoAnalyzer | ||
| 51 | from .factory import AnalyzerFactory | ||
| 52 | |||
| 53 | __version__ = "1.0.0" |
app/middleware/music_analyze/audio_features.py
0 → 100644
| 1 | """ | ||
| 2 | Audio feature extraction module. | ||
| 3 | Provides audio feature extraction plus rhythm-intensity and energy-level computation. | ||
| 4 | """ | ||
| 5 | |||
| 6 | import os | ||
| 7 | import warnings | ||
| 8 | import numpy as np | ||
| 9 | import librosa | ||
| 10 | from typing import Any, Dict, List, Optional, Tuple | ||
| 11 | from dataclasses import dataclass | ||
| 12 | |||
| 13 | from .bpm_analyzer_tools import RealtimeBPMAnalyzerTest | ||
| 14 | |||
| 15 | # Suppress librosa's audioread deprecation warnings | ||
| 16 | warnings.filterwarnings("ignore", category=FutureWarning, module="librosa") | ||
| 17 | |||
| 18 | |||
| 19 | @dataclass | ||
| 20 | class AudioFeatures: | ||
| 21 | """Audio feature data.""" | ||
| 22 | |||
| 23 | # Time-domain features | ||
| 24 | rms_energy: np.ndarray # RMS energy (frame-level) | ||
| 25 | rms_times: np.ndarray # corresponding timestamps | ||
| 26 | |||
| 27 | # Frequency-domain features | ||
| 28 | spectral_centroid: np.ndarray # spectral centroid (brightness) | ||
| 29 | spectral_rolloff: np.ndarray # spectral rolloff (frequency below which most energy lies) | ||
| 30 | spectral_bandwidth: np.ndarray # spectral bandwidth | ||
| 31 | |||
| 32 | # Rhythm features | ||
| 33 | onset_strength: np.ndarray # onset strength | ||
| 34 | tempo: float # BPM | ||
| 35 | |||
| 36 | # Summary info | ||
| 37 | duration: float | ||
| 38 | sr: int | ||
| 39 | |||
| 40 | |||
| 41 | def extract_audio_features(audio_path: str, hop_length: int = 512) -> AudioFeatures: | ||
| 42 | """ | ||
| 43 | Extract audio features. | ||
| 44 | |||
| 45 | Args: | ||
| 46 | audio_path: path to the audio file | ||
| 47 | hop_length: hop length (default 512 samples ≈ 11.6ms @ 44.1kHz) | ||
| 48 | |||
| 49 | Returns: | ||
| 50 | AudioFeatures: the extracted feature object | ||
| 51 | """ | ||
| 52 | # Load the audio | ||
| 53 | y, sr = librosa.load(audio_path, sr=None, mono=True) | ||
| 54 | duration = librosa.get_duration(y=y, sr=sr) | ||
| 55 | |||
| 56 | # 1. RMS energy (time-domain loudness) | ||
| 57 | rms = librosa.feature.rms(y=y, hop_length=hop_length)[0] | ||
| 58 | rms_db = librosa.amplitude_to_db(rms, ref=np.max) | ||
| 59 | rms_times = librosa.frames_to_time( | ||
| 60 | np.arange(len(rms)), sr=sr, hop_length=hop_length | ||
| 61 | ) | ||
| 62 | |||
| 63 | # 2. Spectral features | ||
| 64 | spectral_centroid = librosa.feature.spectral_centroid( | ||
| 65 | y=y, sr=sr, hop_length=hop_length | ||
| 66 | )[0] | ||
| 67 | spectral_rolloff = librosa.feature.spectral_rolloff( | ||
| 68 | y=y, sr=sr, hop_length=hop_length | ||
| 69 | )[0] | ||
| 70 | spectral_bandwidth = librosa.feature.spectral_bandwidth( | ||
| 71 | y=y, sr=sr, hop_length=hop_length | ||
| 72 | )[0] | ||
| 73 | |||
| 74 | # 3. Rhythm features | ||
| 75 | onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length) | ||
| 76 | |||
| 77 | # Use the unified BPM analysis entry point (with octave-error correction) | ||
| 78 | bpm_analyzer = RealtimeBPMAnalyzerTest(verbose=False) | ||
| 79 | bpm_result = bpm_analyzer.analyze_bpm(y=y, sr=sr) | ||
| 80 | corrected_tempo = bpm_result.get('bpm', 120.0) | ||
| 81 | |||
| 82 | return AudioFeatures( | ||
| 83 | rms_energy=rms_db, | ||
| 84 | rms_times=rms_times, | ||
| 85 | spectral_centroid=spectral_centroid, | ||
| 86 | spectral_rolloff=spectral_rolloff, | ||
| 87 | spectral_bandwidth=spectral_bandwidth, | ||
| 88 | onset_strength=onset_env, | ||
| 89 | tempo=corrected_tempo, | ||
| 90 | duration=duration, | ||
| 91 | sr=int(sr), | ||
| 92 | ) | ||
| 93 | |||
| 94 | |||
| 95 | def calculate_rhythm_intensity(features: AudioFeatures) -> int: | ||
| 96 | """ | ||
| 97 | Compute rhythm intensity (1-5) from audio features. | ||
| 98 | |||
| 99 | Combines the following factors: | ||
| 100 | - BPM (speed) | ||
| 101 | - Onset strength (rhythmic density) | ||
| 102 | - Energy variation (dynamic range) | ||
| 103 | |||
| 104 | Args: | ||
| 105 | features: audio feature object | ||
| 106 | |||
| 107 | Returns: | ||
| 108 | int: rhythm intensity (1-5) | ||
| 109 | """ | ||
| 110 | tempo = features.tempo | ||
| 111 | onset = features.onset_strength | ||
| 112 | rms = features.rms_energy | ||
| 113 | |||
| 114 | # 1. BPM score (40-200 BPM mapped to 1-5) | ||
| 115 | if tempo >= 160: | ||
| 116 | tempo_score = 5 | ||
| 117 | elif tempo >= 130: | ||
| 118 | tempo_score = 4 | ||
| 119 | elif tempo >= 100: | ||
| 120 | tempo_score = 3 | ||
| 121 | elif tempo >= 70: | ||
| 122 | tempo_score = 2 | ||
| 123 | else: | ||
| 124 | tempo_score = 1 | ||
| 125 | |||
| 126 | # 2. Onset density score | ||
| 127 | onset_mean = np.mean(onset) | ||
| 128 | onset_max = np.max(onset) if len(onset) > 0 else 1 | ||
| 129 | onset_density = onset_mean / onset_max if onset_max > 0 else 0 | ||
| 130 | |||
| 131 | if onset_density >= 0.5: | ||
| 132 | density_score = 5 | ||
| 133 | elif onset_density >= 0.4: | ||
| 134 | density_score = 4 | ||
| 135 | elif onset_density >= 0.3: | ||
| 136 | density_score = 3 | ||
| 137 | elif onset_density >= 0.2: | ||
| 138 | density_score = 2 | ||
| 139 | else: | ||
| 140 | density_score = 1 | ||
| 141 | |||
| 142 | # 3. Energy dynamics score | ||
| 143 | rms_std = np.std(rms) | ||
| 144 | if rms_std >= 15: | ||
| 145 | dynamic_score = 5 | ||
| 146 | elif rms_std >= 12: | ||
| 147 | dynamic_score = 4 | ||
| 148 | elif rms_std >= 9: | ||
| 149 | dynamic_score = 3 | ||
| 150 | elif rms_std >= 6: | ||
| 151 | dynamic_score = 2 | ||
| 152 | else: | ||
| 153 | dynamic_score = 1 | ||
| 154 | |||
| 155 | # Weighted average (BPM 40%, density 35%, dynamics 25%) | ||
| 156 | final_score = tempo_score * 0.4 + density_score * 0.35 + dynamic_score * 0.25 | ||
| 157 | |||
| 158 | return int(round(final_score)) | ||
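As a worked example of the weighting above, sub-scores of 4, 3, and 2 combine to an intensity of 3:

```python
# Worked example of the weighting in calculate_rhythm_intensity:
# 40% BPM score, 35% onset-density score, 25% dynamics score.
tempo_score, density_score, dynamic_score = 4, 3, 2
final_score = tempo_score * 0.4 + density_score * 0.35 + dynamic_score * 0.25
# 1.6 + 1.05 + 0.5 = 3.15, which rounds to an intensity of 3
intensity = int(round(final_score))
```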
| 159 | |||
| 160 | |||
| 161 | def calculate_energy_level( | ||
| 162 | features: AudioFeatures, | ||
| 163 | ) -> Tuple[int, Dict[str, float]]: | ||
| 164 | """ | ||
| 165 | Compute the energy level (1-5) plus detail scores. | ||
| 166 | |||
| 167 | Args: | ||
| 168 | features: audio feature object | ||
| 169 | |||
| 170 | Returns: | ||
| 171 | Tuple[int, Dict]: (energy level 1-5, detail dict) | ||
| 172 | """ | ||
| 173 | # 1. Loudness score (from RMS energy) | ||
| 174 | rms_db = features.rms_energy | ||
| 175 | loudness_normalized = np.clip((rms_db + 60) / 10, 0, 5) | ||
| 176 | loudness_score = float(np.percentile(loudness_normalized, 75)) | ||
| 177 | |||
| 178 | # 2. Brightness score (from spectral centroid) | ||
| 179 | centroid = features.spectral_centroid | ||
| 180 | centroid_normalized = np.clip(centroid / 4000, 0, 1) | ||
| 181 | brightness_score = float(np.mean(centroid_normalized)) * 5 | ||
| 182 | |||
| 183 | # 3. Rhythm score (from onset strength) | ||
| 184 | onset = features.onset_strength | ||
| 185 | onset_normalized = np.clip(onset / np.percentile(onset, 90), 0, 1) | ||
| 186 | rhythm_score = float(np.mean(onset_normalized)) * 5 | ||
| 187 | |||
| 188 | # 4. BPM factor | ||
| 189 | tempo = features.tempo | ||
| 190 | if tempo > 140: | ||
| 191 | tempo_factor = 1.3 | ||
| 192 | elif tempo > 120: | ||
| 193 | tempo_factor = 1.15 | ||
| 194 | elif tempo > 100: | ||
| 195 | tempo_factor = 1.0 | ||
| 196 | elif tempo > 80: | ||
| 197 | tempo_factor = 0.9 | ||
| 198 | else: | ||
| 199 | tempo_factor = 0.8 | ||
| 200 | |||
| 201 | # Composite computation | ||
| 202 | weights = {"loudness": 0.40, "brightness": 0.25, "rhythm": 0.35} | ||
| 203 | |||
| 204 | composite_score = ( | ||
| 205 | weights["loudness"] * loudness_score | ||
| 206 | + weights["brightness"] * brightness_score | ||
| 207 | + weights["rhythm"] * rhythm_score | ||
| 208 | ) * tempo_factor | ||
| 209 | |||
| 210 | # Map to levels 1-5 | ||
| 211 | if composite_score < 1.5: | ||
| 212 | energy_level = 1 | ||
| 213 | elif composite_score < 2.5: | ||
| 214 | energy_level = 2 | ||
| 215 | elif composite_score < 3.5: | ||
| 216 | energy_level = 3 | ||
| 217 | elif composite_score < 4.5: | ||
| 218 | energy_level = 4 | ||
| 219 | else: | ||
| 220 | energy_level = 5 | ||
| 221 | |||
| 222 | details = { | ||
| 223 | "loudness_score": round(loudness_score, 2), | ||
| 224 | "brightness_score": round(brightness_score, 2), | ||
| 225 | "rhythm_score": round(rhythm_score, 2), | ||
| 226 | "tempo_factor": tempo_factor, | ||
| 227 | "composite_score": round(composite_score, 2), | ||
| 228 | } | ||
| 229 | |||
| 230 | return energy_level, details | ||
| 231 | |||
| 232 | |||
| 233 | def energy_level_to_string(level: int) -> str: | ||
| 234 | """ | ||
| 235 | Convert a numeric energy level to a descriptive string. | ||
| 236 | |||
| 237 | Args: | ||
| 238 | level: energy level (1-5) | ||
| 239 | |||
| 240 | Returns: | ||
| 241 | str: energy-density description | ||
| 242 | """ | ||
| 243 | mapping = { | ||
| 244 | 1: "舒缓", | ||
| 245 | 2: "柔和", | ||
| 246 | 3: "律动", | ||
| 247 | 4: "强烈", | ||
| 248 | 5: "爆发", | ||
| 249 | } | ||
| 250 | return mapping.get(level, "律动") | ||
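Pulled out on their own, the composite-score thresholds and the label mapping behave like this (threshold and label values copied from the two functions above; the standalone helper is for illustration only):

```python
# Standalone sketch of the composite-score → level → label mapping above.
def score_to_level(composite_score: float) -> int:
    thresholds = [1.5, 2.5, 3.5, 4.5]  # same cut points as calculate_energy_level
    for level, upper in enumerate(thresholds, start=1):
        if composite_score < upper:
            return level
    return 5

# Same labels as energy_level_to_string (kept in Chinese: they are delivery values)
LEVEL_LABELS = {1: "舒缓", 2: "柔和", 3: "律动", 4: "强烈", 5: "爆发"}
```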
| 251 | |||
| 252 | |||
| 253 | @dataclass | ||
| 254 | class BeatInfo: | ||
| 255 | """Beat information.""" | ||
| 256 | beat_timestamps: List[float] # all beat time points | ||
| 257 | downbeat_timestamps: List[float] # downbeat time points (first beat of each bar) | ||
| 258 | tempo: float # BPM | ||
| 259 | beat_intervals: List[float] # beat intervals (used to detect tempo changes) | ||
| 260 | |||
| 261 | |||
| 262 | @dataclass | ||
| 263 | class EmotionCurve: | ||
| 264 | """Emotion curve data.""" | ||
| 265 | timestamps: List[float] # time points | ||
| 266 | energy_values: List[float] # energy values (0-1) | ||
| 267 | valence_values: List[float] # emotional valence (0-1, low=sad, high=cheerful) | ||
| 268 | arousal_values: List[float] # emotional arousal (0-1, low=calm, high=excited) | ||
| 269 | smoothed_curve: List[float] # smoothed composite emotion curve | ||
| 270 | |||
| 271 | |||
| 272 | @dataclass | ||
| 273 | class SegmentEmotion: | ||
| 274 | """Per-segment emotion data (aligned with songformer segments).""" | ||
| 275 | start: float # segment start time | ||
| 276 | end: float # segment end time | ||
| 277 | label: str # segment label (intro/verse/chorus/bridge/outro) | ||
| 278 | intensity: float # emotional intensity (0-1) | ||
| 279 | energy: float # energy value (0-1) | ||
| 280 | valence: float # valence (0-1) | ||
| 281 | arousal: float # arousal (0-1) | ||
| 282 | trend: str # emotion trend (rising/falling/stable/peak) | ||
| 283 | |||
| 284 | |||
| 285 | @dataclass | ||
| 286 | class BeatDensityInfo: | ||
| 287 | """Beat density info (for shot-duration planning).""" | ||
| 288 | segment_label: str # segment label | ||
| 289 | start: float # start time | ||
| 290 | end: float # end time | ||
| 291 | beat_count: int # number of beats | ||
| 292 | avg_interval: float # average interval (seconds) | ||
| 293 | density_level: str # sparse/normal/dense/very_dense | ||
| 294 | recommended_shot_duration: str # recommended shot duration | ||
| 295 | |||
| 296 | |||
| 297 | @dataclass | ||
| 298 | class EnhancedClimaxInfo: | ||
| 299 | """Enhanced climax info (includes build-up/sustain/wind-down durations).""" | ||
| 300 | time: float # climax time point | ||
| 301 | intensity: str # strong/strongest | ||
| 302 | buildup_start: float # build-up start time | ||
| 303 | buildup_duration: float # build-up duration (seconds) | ||
| 304 | climax_duration: float # climax duration (seconds) | ||
| 305 | winddown_duration: float # wind-down duration (seconds) | ||
| 306 | |||
| 307 | |||
| 308 | def extract_beat_timestamps(audio_path: str) -> BeatInfo: | ||
| 309 | """ | ||
| 310 | Extract beat timestamps (cut points). | ||
| 311 | |||
| 312 | Uses smart BPM detection (with octave-error correction). | ||
| 313 | |||
| 314 | Args: | ||
| 315 | audio_path: path to the audio file | ||
| 316 | |||
| 317 | Returns: | ||
| 318 | BeatInfo: beat information object | ||
| 319 | """ | ||
| 320 | y, sr = librosa.load(audio_path, sr=22050, mono=True) | ||
| 321 | |||
| 322 | # Use the unified BPM analysis entry point (octave correction + beat_times) | ||
| 323 | bpm_analyzer = RealtimeBPMAnalyzerTest(verbose=False) | ||
| 324 | bpm_result = bpm_analyzer.analyze_bpm(y=y, sr=sr) | ||
| 325 | corrected_tempo = bpm_result.get('bpm', 120.0) | ||
| 326 | |||
| 327 | # beat_times has already been subsampled by analyze_bpm when the BPM was halved | ||
| 328 | beat_times = np.array(bpm_result.get('beat_times', [])) | ||
| 329 | |||
| 330 | # Downbeat detection (take every 4th beat, assuming 4/4 time) | ||
| 331 | downbeat_times = beat_times[::4].tolist() if len(beat_times) > 0 else [] | ||
| 332 | |||
| 333 | # Compute beat intervals | ||
| 334 | beat_intervals = np.diff(beat_times).tolist() if len(beat_times) > 1 else [] | ||
| 335 | |||
| 336 | return BeatInfo( | ||
| 337 | beat_timestamps=beat_times.tolist(), | ||
| 338 | downbeat_timestamps=downbeat_times, | ||
| 339 | tempo=corrected_tempo, | ||
| 340 | beat_intervals=beat_intervals, | ||
| 341 | ) | ||
| 342 | |||
| 343 | |||
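The downbeat and interval derivation above is plain NumPy slicing; a minimal standalone sketch (the helper name `summarize_beats` and the hard-coded beat times are illustrative, not part of the module):

```python
import numpy as np

def summarize_beats(beat_times):
    """Derive downbeats (every 4th beat, assuming 4/4) and beat intervals."""
    beats = np.asarray(beat_times, dtype=float)
    downbeats = beats[::4].tolist() if len(beats) > 0 else []
    intervals = np.diff(beats).tolist() if len(beats) > 1 else []
    return {"downbeats": downbeats, "intervals": intervals}

info = summarize_beats([0.0, 0.5, 1.0, 1.5, 2.0, 2.5])
# downbeats fall at 0.0s and 2.0s; every interval is 0.5s
```

Note the `[::4]` shortcut treats the first detected beat as a downbeat, which drifts if the track does not start on beat one of a bar.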
| 344 | def extract_emotion_curve( | ||
| 345 | audio_path: str, | ||
| 346 | window_size: float = 2.0, # window size (seconds) | ||
| 347 | hop_size: float = 0.5 # hop size (seconds) | ||
| 348 | ) -> EmotionCurve: | ||
| 349 | """ | ||
| 350 | Extract the emotion curve | ||
| 351 | |||
| 352 | Emotion is inferred from audio features: | ||
| 353 | - Energy: RMS energy → emotional intensity | ||
| 354 | - Valence: spectral centroid + major/minor mode → positive/negative emotion | ||
| 355 | - Arousal: rhythmic density + energy variation → excited/calm | ||
| 356 | |||
| 357 | Args: | ||
| 358 | audio_path: path to the audio file | ||
| 359 | window_size: sliding-window size (seconds) | ||
| 360 | hop_size: sliding hop size (seconds) | ||
| 361 | |||
| 362 | Returns: | ||
| 363 | EmotionCurve: emotion curve data object | ||
| 364 | """ | ||
| 365 | y, sr = librosa.load(audio_path, sr=None, mono=True) | ||
| 366 | |||
| 367 | timestamps: List[float] = [] | ||
| 368 | energy_values: List[float] = [] | ||
| 369 | valence_values: List[float] = [] | ||
| 370 | arousal_values: List[float] = [] | ||
| 371 | |||
| 372 | # sliding-window analysis | ||
| 373 | window_samples = int(window_size * sr) | ||
| 374 | hop_samples = int(hop_size * sr) | ||
| 375 | |||
| 376 | for start_sample in range(0, len(y) - window_samples, hop_samples): | ||
| 377 | end_sample = start_sample + window_samples | ||
| 378 | y_window = y[start_sample:end_sample] | ||
| 379 | t = start_sample / sr | ||
| 380 | |||
| 381 | timestamps.append(t) | ||
| 382 | |||
| 383 | # 1. Energy: normalized RMS energy | ||
| 384 | rms = np.sqrt(np.mean(y_window ** 2)) | ||
| 385 | energy = min(rms / 0.1, 1.0) # normalize to 0-1 | ||
| 386 | energy_values.append(float(energy)) | ||
| 387 | |||
| 388 | # 2. Valence: based on spectral centroid (high = bright = positive) | ||
| 389 | centroid = librosa.feature.spectral_centroid(y=y_window, sr=sr)[0] | ||
| 390 | valence = min(np.mean(centroid) / 4000, 1.0) | ||
| 391 | valence_values.append(float(valence)) | ||
| 392 | |||
| 393 | # 3. Arousal: based on onset density and energy variation | ||
| 394 | onset_env = librosa.onset.onset_strength(y=y_window, sr=sr) | ||
| 395 | arousal = min(np.mean(onset_env) / 2.0, 1.0) | ||
| 396 | arousal_values.append(float(arousal)) | ||
| 397 | |||
| 398 | # 4. Combined emotion curve (weighted average) | ||
| 399 | smoothed: List[float] = [] | ||
| 400 | for i in range(len(timestamps)): | ||
| 401 | # weights: energy 40%, arousal 40%, valence 20% | ||
| 402 | combined = ( | ||
| 403 | energy_values[i] * 0.4 + | ||
| 404 | arousal_values[i] * 0.4 + | ||
| 405 | valence_values[i] * 0.2 | ||
| 406 | ) | ||
| 407 | smoothed.append(combined) | ||
| 408 | |||
| 409 | # smoothing (moving average) | ||
| 410 | if len(smoothed) >= 3: | ||
| 411 | smoothed = np.convolve(smoothed, np.ones(3)/3, mode='same').tolist() | ||
| 412 | |||
| 413 | return EmotionCurve( | ||
| 414 | timestamps=timestamps, | ||
| 415 | energy_values=energy_values, | ||
| 416 | valence_values=valence_values, | ||
| 417 | arousal_values=arousal_values, | ||
| 418 | smoothed_curve=smoothed, | ||
| 419 | ) | ||
| 420 | |||
| 421 | |||
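The sliding-window energy feature above can be isolated from librosa for a quick check; this sketch reuses the same window/hop arithmetic and the 0.1 RMS normalization reference (the synthetic signal and the helper name `rms_energy_curve` are made up for illustration):

```python
import numpy as np

def rms_energy_curve(y, sr, window_size=2.0, hop_size=0.5):
    """Sliding-window RMS energy, normalized to [0, 1] against a 0.1 reference."""
    win, hop = int(window_size * sr), int(hop_size * sr)
    times, energies = [], []
    for start in range(0, len(y) - win, hop):
        frame = y[start:start + win]
        rms = np.sqrt(np.mean(frame ** 2))
        times.append(start / sr)
        energies.append(min(rms / 0.1, 1.0))
    return times, energies

sr = 1000
y = np.concatenate([np.zeros(2000), 0.2 * np.ones(2000)])  # 2s silence, 2s loud
t, e = rms_energy_curve(y, sr)
# first window is silent (energy 0.0), last window saturates at 1.0
```

As in the original loop, a clip shorter than one window yields an empty curve, since `range(0, len(y) - win, hop)` never iterates.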
| 422 | def aggregate_emotion_by_segments( | ||
| 423 | emotion_curve: EmotionCurve, | ||
| 424 | segments: List[Dict[str, Any]], | ||
| 425 | ) -> List[SegmentEmotion]: | ||
| 426 | """ | ||
| 427 | 将情绪曲线按 songformer 段落结构聚合 | ||
| 428 | |||
| 429 | Args: | ||
| 430 | emotion_curve: 原始情绪曲线数据 | ||
| 431 | segments: songformer 返回的段落列表,格式为: | ||
| 432 | [{"start": 0.0, "end": 30.5, "label": "intro"}, ...] | ||
| 433 | |||
| 434 | Returns: | ||
| 435 | List[SegmentEmotion]: 按段落聚合的情绪数据 | ||
| 436 | """ | ||
| 437 | if not segments or not emotion_curve.timestamps: | ||
| 438 | return [] | ||
| 439 | |||
| 440 | result: List[SegmentEmotion] = [] | ||
| 441 | timestamps = np.array(emotion_curve.timestamps) | ||
| 442 | energy_values = np.array(emotion_curve.energy_values) | ||
| 443 | valence_values = np.array(emotion_curve.valence_values) | ||
| 444 | arousal_values = np.array(emotion_curve.arousal_values) | ||
| 445 | smoothed_values = np.array(emotion_curve.smoothed_curve) | ||
| 446 | |||
| 447 | for seg in segments: | ||
| 448 | start = float(seg.get("start", 0)) | ||
| 449 | end = float(seg.get("end", 0)) | ||
| 450 | label = str(seg.get("label", "unknown")) | ||
| 451 | |||
| 452 | # find the indices of data points inside this segment | ||
| 453 | mask = (timestamps >= start) & (timestamps < end) | ||
| 454 | indices = np.where(mask)[0] | ||
| 455 | |||
| 456 | if len(indices) == 0: | ||
| 457 | # no data points fall inside this segment; use defaults | ||
| 458 | result.append(SegmentEmotion( | ||
| 459 | start=start, | ||
| 460 | end=end, | ||
| 461 | label=label, | ||
| 462 | intensity=0.5, | ||
| 463 | energy=0.5, | ||
| 464 | valence=0.5, | ||
| 465 | arousal=0.5, | ||
| 466 | trend="stable", | ||
| 467 | )) | ||
| 468 | continue | ||
| 469 | |||
| 470 | # compute averages for this segment | ||
| 471 | seg_energy = float(np.mean(energy_values[indices])) | ||
| 472 | seg_valence = float(np.mean(valence_values[indices])) | ||
| 473 | seg_arousal = float(np.mean(arousal_values[indices])) | ||
| 474 | seg_intensity = float(np.mean(smoothed_values[indices])) | ||
| 475 | |||
| 476 | # compute the emotion trend | ||
| 477 | seg_smoothed = smoothed_values[indices] | ||
| 478 | trend = _calculate_trend(seg_smoothed, seg_intensity) | ||
| 479 | |||
| 480 | result.append(SegmentEmotion( | ||
| 481 | start=start, | ||
| 482 | end=end, | ||
| 483 | label=label, | ||
| 484 | intensity=round(seg_intensity, 3), | ||
| 485 | energy=round(seg_energy, 3), | ||
| 486 | valence=round(seg_valence, 3), | ||
| 487 | arousal=round(seg_arousal, 3), | ||
| 488 | trend=trend, | ||
| 489 | )) | ||
| 490 | |||
| 491 | return result | ||
| 492 | |||
| 493 | |||
| 494 | def _calculate_trend(values: np.ndarray, avg_intensity: float) -> str: | ||
| 495 | """ | ||
| 496 | Compute the emotion trend | ||
| 497 | |||
| 498 | Args: | ||
| 499 | values: array of emotion values within the segment | ||
| 500 | avg_intensity: average emotional intensity | ||
| 501 | |||
| 502 | Returns: | ||
| 503 | str: rising/falling/stable/peak | ||
| 504 | """ | ||
| 505 | if len(values) < 3: | ||
| 506 | return "stable" | ||
| 507 | |||
| 508 | # split the segment into first and second halves | ||
| 509 | mid = len(values) // 2 | ||
| 510 | first_half_avg = float(np.mean(values[:mid])) | ||
| 511 | second_half_avg = float(np.mean(values[mid:])) | ||
| 512 | |||
| 513 | diff = second_half_avg - first_half_avg | ||
| 514 | threshold = 0.05 # 5% change threshold | ||
| 515 | |||
| 516 | # check for a peak (high average intensity with little variation) | ||
| 517 | if avg_intensity > 0.7 and abs(diff) < threshold: | ||
| 518 | return "peak" | ||
| 519 | |||
| 520 | if diff > threshold: | ||
| 521 | return "rising" | ||
| 522 | elif diff < -threshold: | ||
| 523 | return "falling" | ||
| 524 | else: | ||
| 525 | return "stable" | ||
| 526 | |||
| 527 | |||
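The trend classifier above is pure NumPy, so it can be exercised standalone; `calculate_trend` below mirrors `_calculate_trend` (first-half vs. second-half averages against a 0.05 threshold):

```python
import numpy as np

def calculate_trend(values, avg_intensity, threshold=0.05):
    """Classify a segment's emotion trajectory as rising/falling/stable/peak."""
    if len(values) < 3:
        return "stable"
    mid = len(values) // 2
    diff = float(np.mean(values[mid:])) - float(np.mean(values[:mid]))
    if avg_intensity > 0.7 and abs(diff) < threshold:
        return "peak"  # high and flat
    if diff > threshold:
        return "rising"
    if diff < -threshold:
        return "falling"
    return "stable"

calculate_trend(np.array([0.2, 0.4, 0.6, 0.8]), 0.5)  # "rising"
calculate_trend(np.array([0.8, 0.8, 0.8, 0.8]), 0.8)  # "peak"
```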
| 528 | def extract_segment_emotions( | ||
| 529 | audio_path: str, | ||
| 530 | segments: List[Dict[str, Any]], | ||
| 531 | ) -> List[SegmentEmotion]: | ||
| 532 | """ | ||
| 533 | One-stop extraction of per-segment aggregated emotion data | ||
| 534 | |||
| 535 | Args: | ||
| 536 | audio_path: path to the audio file | ||
| 537 | segments: segment list returned by songformer | ||
| 538 | |||
| 539 | Returns: | ||
| 540 | List[SegmentEmotion]: emotion data aggregated per segment | ||
| 541 | """ | ||
| 542 | emotion_curve = extract_emotion_curve(audio_path) | ||
| 543 | return aggregate_emotion_by_segments(emotion_curve, segments) | ||
| 544 | |||
| 545 | |||
| 546 | def calculate_beat_density_by_segments( | ||
| 547 | beat_timestamps: List[float], | ||
| 548 | segments: List[Dict[str, Any]], | ||
| 549 | tempo: float = 120.0, | ||
| 550 | ) -> List[BeatDensityInfo]: | ||
| 551 | """ | ||
| 552 | 按段落计算节拍密度,用于指导分镜时长规划 | ||
| 553 | |||
| 554 | Args: | ||
| 555 | beat_timestamps: 节拍时间戳列表 | ||
| 556 | segments: songformer 返回的段落列表,格式为: | ||
| 557 | [{"start": 0.0, "end": 30.5, "label": "intro"}, ...] | ||
| 558 | tempo: BPM(用于辅助判断密度级别) | ||
| 559 | |||
| 560 | Returns: | ||
| 561 | List[BeatDensityInfo]: 按段落的节拍密度信息 | ||
| 562 | """ | ||
| 563 | if not segments or not beat_timestamps: | ||
| 564 | return [] | ||
| 565 | |||
| 566 | result: List[BeatDensityInfo] = [] | ||
| 567 | beat_array = np.array(beat_timestamps) | ||
| 568 | |||
| 569 | for seg in segments: | ||
| 570 | start = float(seg.get("start", 0)) | ||
| 571 | end = float(seg.get("end", 0)) | ||
| 572 | label = str(seg.get("label", "unknown")) | ||
| 573 | |||
| 574 | # find the beats inside this segment | ||
| 575 | mask = (beat_array >= start) & (beat_array < end) | ||
| 576 | segment_beats = beat_array[mask] | ||
| 577 | beat_count = len(segment_beats) | ||
| 578 | |||
| 579 | # compute the average interval | ||
| 580 | if beat_count >= 2: | ||
| 581 | intervals = np.diff(segment_beats) | ||
| 582 | avg_interval = float(np.mean(intervals)) | ||
| 583 | elif beat_count == 1: | ||
| 584 | # only one beat; estimate from the BPM | ||
| 585 | avg_interval = 60.0 / tempo | ||
| 586 | else: | ||
| 587 | # no beats; fall back to the BPM default | ||
| 588 | avg_interval = 60.0 / tempo | ||
| 589 | |||
| 590 | # decide the density level from the average interval and BPM | ||
| 591 | # smaller interval = higher density | ||
| 592 | if avg_interval <= 0.3 or tempo >= 160: | ||
| 593 | density_level = "very_dense" | ||
| 594 | recommended_shot_duration = "2-4秒" | ||
| 595 | elif avg_interval <= 0.45 or tempo >= 130: | ||
| 596 | density_level = "dense" | ||
| 597 | recommended_shot_duration = "3-5秒" | ||
| 598 | elif avg_interval <= 0.6 or tempo >= 100: | ||
| 599 | density_level = "normal" | ||
| 600 | recommended_shot_duration = "4-6秒" | ||
| 601 | else: | ||
| 602 | density_level = "sparse" | ||
| 603 | recommended_shot_duration = "6-10秒" | ||
| 604 | |||
| 605 | result.append(BeatDensityInfo( | ||
| 606 | segment_label=label, | ||
| 607 | start=round(start, 2), | ||
| 608 | end=round(end, 2), | ||
| 609 | beat_count=beat_count, | ||
| 610 | avg_interval=round(avg_interval, 3), | ||
| 611 | density_level=density_level, | ||
| 612 | recommended_shot_duration=recommended_shot_duration, | ||
| 613 | )) | ||
| 614 | |||
| 615 | return result | ||
| 616 | |||
| 617 | |||
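The interval/BPM thresholds above reduce to a small decision ladder; a standalone mirror for checking the boundaries (`density_level` is a hypothetical name, and shot durations are rendered in English here, where the module emits Chinese prompt text like "4-6秒"):

```python
def density_level(avg_interval, tempo):
    """Map average beat interval (s) and BPM to a density level + shot duration."""
    if avg_interval <= 0.3 or tempo >= 160:
        return "very_dense", "2-4s"
    if avg_interval <= 0.45 or tempo >= 130:
        return "dense", "3-5s"
    if avg_interval <= 0.6 or tempo >= 100:
        return "normal", "4-6s"
    return "sparse", "6-10s"

density_level(0.5, 90)    # ("normal", "4-6s")
density_level(0.25, 170)  # ("very_dense", "2-4s")
```

Because the conditions are OR-ed, a fast BPM can promote a segment to a denser level even when its local interval is wide, and vice versa.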
| 618 | def enhance_climax_points( | ||
| 619 | climax_points: List[Dict[str, Any]], | ||
| 620 | segments: List[Dict[str, Any]], | ||
| 621 | music_duration: float, | ||
| 622 | ) -> List[EnhancedClimaxInfo]: | ||
| 623 | """ | ||
| 624 | Enhance climax points with buildup/hold/wind-down duration guidance | ||
| 625 | |||
| 626 | Args: | ||
| 627 | climax_points: raw climax point list, formatted as: | ||
| 628 | [{"time": 60.0, "intensity": "strong"}, ...] | ||
| 629 | segments: segment list returned by songformer | ||
| 630 | music_duration: total music duration (seconds) | ||
| 631 | |||
| 632 | Returns: | ||
| 633 | List[EnhancedClimaxInfo]: enhanced climax point info | ||
| 634 | """ | ||
| 635 | if not climax_points: | ||
| 636 | return [] | ||
| 637 | |||
| 638 | result: List[EnhancedClimaxInfo] = [] | ||
| 639 | |||
| 640 | # sort climax points by time | ||
| 641 | sorted_climax = sorted(climax_points, key=lambda x: float(x.get("time", 0))) | ||
| 642 | |||
| 643 | for i, climax in enumerate(sorted_climax): | ||
| 644 | time = float(climax.get("time", 0)) | ||
| 645 | intensity = str(climax.get("intensity", "strong")) | ||
| 646 | |||
| 647 | # pick duration parameters based on intensity | ||
| 648 | if intensity == "strongest": | ||
| 649 | buildup_duration = 10.0 # strongest climax: longer buildup | ||
| 650 | climax_duration = 20.0 # longer climax hold | ||
| 651 | winddown_duration = 10.0 # longer wind-down | ||
| 652 | else: | ||
| 653 | buildup_duration = 5.0 # regular climax | ||
| 654 | climax_duration = 10.0 | ||
| 655 | winddown_duration = 5.0 | ||
| 656 | |||
| 657 | # compute the buildup start (not before 0 or the previous climax's end) | ||
| 658 | buildup_start = max(0, time - buildup_duration) | ||
| 659 | |||
| 660 | # if there is a previous climax point, make sure the windows do not overlap | ||
| 661 | if i > 0: | ||
| 662 | prev_climax_time = float(sorted_climax[i - 1].get("time", 0)) | ||
| 663 | prev_intensity = str(sorted_climax[i - 1].get("intensity", "strong")) | ||
| 664 | prev_winddown = 10.0 if prev_intensity == "strongest" else 5.0 | ||
| 665 | prev_end = prev_climax_time + prev_winddown | ||
| 666 | |||
| 667 | if buildup_start < prev_end: | ||
| 668 | # shift the buildup start past the previous wind-down; clamp so the | ||
| 669 | # duration cannot go negative when prev_end is later than this climax | ||
| 670 | buildup_start = prev_end | ||
| 671 | buildup_duration = max(0.0, time - buildup_start) | ||
| 671 | |||
| 672 | # keep climax hold + wind-down within the end of the track | ||
| 673 | if time + climax_duration + winddown_duration > music_duration: | ||
| 674 | # scale both phases down proportionally | ||
| 675 | remaining = music_duration - time | ||
| 676 | if remaining > 0: | ||
| 677 | ratio = remaining / (climax_duration + winddown_duration) | ||
| 678 | climax_duration = climax_duration * ratio | ||
| 679 | winddown_duration = winddown_duration * ratio | ||
| 680 | |||
| 681 | result.append(EnhancedClimaxInfo( | ||
| 682 | time=round(time, 2), | ||
| 683 | intensity=intensity, | ||
| 684 | buildup_start=round(buildup_start, 2), | ||
| 685 | buildup_duration=round(buildup_duration, 2), | ||
| 686 | climax_duration=round(climax_duration, 2), | ||
| 687 | winddown_duration=round(winddown_duration, 2), | ||
| 688 | )) | ||
| 689 | |||
| 690 | return result | ||
| 691 | |||
| 692 | |||
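The scheduling rules above (5 s/10 s buildup depending on intensity, overlap pushed past the previous wind-down, tail scaled to fit the track) can be condensed into a standalone sketch; `plan_climax_windows` is a hypothetical name and returns only the buildup fields:

```python
def plan_climax_windows(climaxes, duration):
    """Assign buildup windows to sorted climax points, avoiding overlap."""
    out = []
    prev_end = 0.0  # end of the previous climax's wind-down
    for c in sorted(climaxes, key=lambda x: x["time"]):
        t = float(c["time"])
        strongest = c.get("intensity") == "strongest"
        buildup = 10.0 if strongest else 5.0
        hold = 20.0 if strongest else 10.0
        wind = 10.0 if strongest else 5.0
        # buildup may not start before 0 or before the previous wind-down ends
        start = max(0.0, t - buildup, prev_end)
        # shrink hold + wind-down proportionally if they would overrun the track
        remaining = duration - t
        if 0 < remaining < hold + wind:
            ratio = remaining / (hold + wind)
            hold, wind = hold * ratio, wind * ratio
        out.append({"time": t, "buildup_start": round(start, 2),
                    "buildup": round(max(0.0, t - start), 2)})
        prev_end = t + wind
    return out

plan = plan_climax_windows([{"time": 60.0, "intensity": "strong"},
                            {"time": 67.0, "intensity": "strong"}], 180.0)
# the second buildup is squeezed to 2s because the first wind-down runs to 65s
```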
| 693 | def format_beat_density_for_prompt(beat_density_list: List[BeatDensityInfo]) -> str: | ||
| 694 | """ | ||
| 695 | Format beat-density info as prompt text | ||
| 696 | |||
| 697 | Args: | ||
| 698 | beat_density_list: list of beat-density info | ||
| 699 | |||
| 700 | Returns: | ||
| 701 | str: formatted text | ||
| 702 | """ | ||
| 703 | if not beat_density_list: | ||
| 704 | return "(无节拍密度数据)" | ||
| 705 | |||
| 706 | lines = [] | ||
| 707 | for info in beat_density_list: | ||
| 708 | lines.append( | ||
| 709 | f"- [{info.segment_label}] {info.start:.1f}s-{info.end:.1f}s: " | ||
| 710 | f"节拍数={info.beat_count}, 平均间隔={info.avg_interval:.2f}s, " | ||
| 711 | f"密度={info.density_level}, 推荐分镜时长={info.recommended_shot_duration}" | ||
| 712 | ) | ||
| 713 | return "\n".join(lines) | ||
| 714 | |||
| 715 | |||
| 716 | def format_enhanced_climax_for_prompt(enhanced_climax_list: List[EnhancedClimaxInfo]) -> str: | ||
| 717 | """ | ||
| 718 | Format enhanced climax info as prompt text | ||
| 719 | |||
| 720 | Args: | ||
| 721 | enhanced_climax_list: list of enhanced climax info | ||
| 722 | |||
| 723 | Returns: | ||
| 724 | str: formatted text | ||
| 725 | """ | ||
| 726 | if not enhanced_climax_list: | ||
| 727 | return "(无高潮点数据)" | ||
| 728 | |||
| 729 | lines = [] | ||
| 730 | for info in enhanced_climax_list: | ||
| 731 | lines.append( | ||
| 732 | f"- 高潮点 {info.time:.1f}s ({info.intensity}):\n" | ||
| 733 | f" · 铺垫阶段: {info.buildup_start:.1f}s - {info.time:.1f}s (约{info.buildup_duration:.1f}秒)\n" | ||
| 734 | f" · 高潮阶段: {info.time:.1f}s - {info.time + info.climax_duration:.1f}s (约{info.climax_duration:.1f}秒)\n" | ||
| 735 | f" · 缓冲阶段: {info.time + info.climax_duration:.1f}s - {info.time + info.climax_duration + info.winddown_duration:.1f}s (约{info.winddown_duration:.1f}秒)" | ||
| 736 | ) | ||
| 737 | return "\n".join(lines) |
app/middleware/music_analyze/base.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | """ | ||
| 3 | Abstract base class for music analyzers | ||
| 4 | Defines the unified analyzer interface | ||
| 5 | """ | ||
| 6 | |||
| 7 | from abc import ABC, abstractmethod | ||
| 8 | from typing import Dict, Optional, Any, List, Set | ||
| 9 | |||
| 10 | |||
| 11 | # Vocabulary definitions: all valid field values | ||
| 12 | VALID_GENRES: Set[str] = { | ||
| 13 | "流行", | ||
| 14 | "电子/舞曲", | ||
| 15 | "摇滚/金属", | ||
| 16 | "说唱", | ||
| 17 | "民谣/原声", | ||
| 18 | "国风", | ||
| 19 | "爵士/Soul", | ||
| 20 | "古典", | ||
| 21 | "轻音乐/Ambient", | ||
| 22 | "二次元/ACG", | ||
| 23 | "其它", | ||
| 24 | } | ||
| 25 | |||
| 26 | VALID_SUB_GENRES: Dict[str, Set[str]] = { | ||
| 27 | "流行": {"华语流行", "欧美流行", "日韩流行", "R&B", "抒情"}, | ||
| 28 | "电子/舞曲": {"House", "Future Bass", "Dubstep", "Synthwave", "Trance", "Techno"}, | ||
| 29 | "摇滚/金属": {"流行摇滚", "独立摇滚", "重金属", "朋克", "后摇"}, | ||
| 30 | "说唱": {"Trap", "Old School", "Boombap", "Melodic Rap", "中文说唱"}, | ||
| 31 | "民谣/原声": {"城市民谣", "校园民谣", "故事民谣", "乡村", "Indie Folk"}, | ||
| 32 | "国风": {"古风", "戏腔", "新中式", "水墨风", "国潮"}, | ||
| 33 | "爵士/Soul": {"传统爵士", "Smooth Jazz", "Fusion", "Neo-Soul", "Blues"}, | ||
| 34 | "古典": {"管弦乐", "钢琴曲", "协奏曲", "室内乐", "歌剧"}, | ||
| 35 | "轻音乐/Ambient": {"钢琴独奏", "Lo-fi", "冥想音乐", "氛围电子", "白噪音"}, | ||
| 36 | "二次元/ACG": {"动画OST", "Vocaloid", "游戏音乐", "萌系", "燃系"}, | ||
| 37 | "其它": {"世界音乐", "实验音乐", "儿歌", "戏曲", "网络热歌"}, | ||
| 38 | } | ||
| 39 | |||
| 40 | VALID_LANGUAGES: Set[str] = { | ||
| 41 | "普通话", | ||
| 42 | "粤语", | ||
| 43 | "英语", | ||
| 44 | "韩语", | ||
| 45 | "闽南语", | ||
| 46 | "蒙语", | ||
| 47 | "俄语", | ||
| 48 | "藏语", | ||
| 49 | "其他", | ||
| 50 | } | ||
| 51 | |||
| 52 | LANGUAGE_MAPPING: Dict[str, str] = { | ||
| 53 | "国语": "普通话", | ||
| 54 | "中文": "普通话", | ||
| 55 | "汉语": "普通话", | ||
| 56 | "普通话": "普通话", | ||
| 57 | "广东话": "粤语", | ||
| 58 | "粤语": "粤语", | ||
| 59 | "英文": "英语", | ||
| 60 | "英语": "英语", | ||
| 61 | "韩文": "韩语", | ||
| 62 | "朝鲜语": "韩语", | ||
| 63 | "韩语": "韩语", | ||
| 64 | "闽南话": "闽南语", | ||
| 65 | "台语": "闽南语", | ||
| 66 | "闽南语": "闽南语", | ||
| 67 | "蒙语": "蒙语", | ||
| 68 | "蒙古语": "蒙语", | ||
| 69 | "俄文": "俄语", | ||
| 70 | "俄语": "俄语", | ||
| 71 | "藏文": "藏语", | ||
| 72 | "藏语": "藏语", | ||
| 73 | "其它": "其他", | ||
| 74 | "其他": "其他", | ||
| 75 | "地方语言": "其他", | ||
| 76 | "日语": "其他", | ||
| 77 | } | ||
| 78 | |||
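Normalization against this vocabulary is a two-step lookup (alias mapping, then membership check), matching the `_validate_language` logic defined further down; a minimal sketch with abridged copies of the tables above:

```python
# abridged copies of LANGUAGE_MAPPING / VALID_LANGUAGES for illustration
LANGUAGE_MAPPING = {"国语": "普通话", "英文": "英语", "日语": "其他"}
VALID_LANGUAGES = {"普通话", "粤语", "英语", "其他"}

def normalize_language(raw):
    """Map a free-form language label onto the controlled vocabulary."""
    label = raw.strip()
    mapped = LANGUAGE_MAPPING.get(label, label)   # resolve aliases
    return mapped if mapped in VALID_LANGUAGES else ""  # drop unknowns

normalize_language("国语")  # "普通话"
normalize_language("法语")  # "" (not in the vocabulary)
```

The alias table deliberately maps some real languages (e.g. 日语) to 其他, so the output space stays closed over the tag dictionary.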
| 79 | VALID_EMOTIONS: Set[str] = { | ||
| 80 | "喜庆", | ||
| 81 | "浪漫", | ||
| 82 | "雄壮", | ||
| 83 | "庄重", | ||
| 84 | "激情", | ||
| 85 | "快乐", | ||
| 86 | "励志", | ||
| 87 | "期待", | ||
| 88 | "甜蜜", | ||
| 89 | "感动", | ||
| 90 | "搞笑", | ||
| 91 | "祝福", | ||
| 92 | "温暖", | ||
| 93 | "宣泄", | ||
| 94 | "悲壮", | ||
| 95 | "愤怒", | ||
| 96 | "沉重", | ||
| 97 | "思念", | ||
| 98 | "紧张", | ||
| 99 | "恐怖", | ||
| 100 | "孤独", | ||
| 101 | "伤感", | ||
| 102 | "忧郁", | ||
| 103 | "蛊惑", | ||
| 104 | "恶搞", | ||
| 105 | "怀念", | ||
| 106 | "悬疑", | ||
| 107 | "佛系", | ||
| 108 | "舒缓", | ||
| 109 | "悠扬", | ||
| 110 | } | ||
| 111 | |||
| 112 | VALID_SCENES: Set[str] = { | ||
| 113 | "餐厅", | ||
| 114 | "汽车", | ||
| 115 | "跳舞", | ||
| 116 | "旅行", | ||
| 117 | "工作", | ||
| 118 | "校园", | ||
| 119 | "夜店", | ||
| 120 | "运动", | ||
| 121 | "休闲", | ||
| 122 | "live house", | ||
| 123 | "广场舞", | ||
| 124 | "抖音", | ||
| 125 | "婚礼", | ||
| 126 | "约会", | ||
| 127 | } | ||
| 128 | |||
| 129 | VALID_DOUYIN_TAGS: Set[str] = { | ||
| 130 | "草原", | ||
| 131 | "故乡", | ||
| 132 | "神曲", | ||
| 133 | "文艺", | ||
| 134 | "青春", | ||
| 135 | "治愈系", | ||
| 136 | "清新", | ||
| 137 | "奇幻", | ||
| 138 | } | ||
| 139 | |||
| 140 | VALID_MUSIC_STYLE_TAGS: Set[str] = { | ||
| 141 | "世界音乐", | ||
| 142 | "雷鬼", | ||
| 143 | "R&B/Soul", | ||
| 144 | "MC喊麦", | ||
| 145 | "另类音乐", | ||
| 146 | "民歌", | ||
| 147 | "戏曲", | ||
| 148 | "古风", | ||
| 149 | "古典音乐", | ||
| 150 | "HipHop", | ||
| 151 | "Rap", | ||
| 152 | "摇滚", | ||
| 153 | "DJ嗨曲", | ||
| 154 | "布鲁斯/蓝调", | ||
| 155 | "拉丁", | ||
| 156 | "舞曲", | ||
| 157 | "爵士", | ||
| 158 | "乡村", | ||
| 159 | "民谣", | ||
| 160 | "流行", | ||
| 161 | "轻音乐", | ||
| 162 | "国风", | ||
| 163 | "儿歌", | ||
| 164 | } | ||
| 165 | |||
| 166 | VALID_INSTRUMENT_TAGS: Set[str] = { | ||
| 167 | "二胡", | ||
| 168 | "竹笛", | ||
| 169 | "琵琶", | ||
| 170 | "音效", | ||
| 171 | "口琴", | ||
| 172 | "电子", | ||
| 173 | "木吉他", | ||
| 174 | "鼓组", | ||
| 175 | "弦乐", | ||
| 176 | "电吉他", | ||
| 177 | "古筝", | ||
| 178 | "钢琴", | ||
| 179 | } | ||
| 180 | |||
| 181 | VALID_AGES: Set[str] = {"少年", "青年", "中年", "老年", "全年龄段"} | ||
| 182 | |||
| 183 | VALID_RHYTHM_INTENSITIES: Set[str] = {"极慢", "慢", "中", "快", "极速"} | ||
| 184 | |||
| 185 | VALID_EMOTIONAL_INTENSITIES: Set[str] = {"平缓", "中等", "强烈"} | ||
| 186 | |||
| 187 | VALID_VOICE_TYPES: Set[str] = {"男声", "女声", "童声", "合唱", "无人声"} | ||
| 188 | VALID_PERFORMER_TYPES: Set[str] = {"男声", "女声", "童声", "合唱"} | ||
| 189 | |||
| 190 | # common sub_genre variant mapping | ||
| 191 | SUB_GENRE_MAPPING: Dict[str, str] = { | ||
| 192 | "韩语流行": "日韩流行", | ||
| 193 | "韩国流行": "日韩流行", | ||
| 194 | "K-Pop": "日韩流行", | ||
| 195 | "K-pop": "日韩流行", | ||
| 196 | "Kpop": "日韩流行", | ||
| 197 | "韩流": "日韩流行", | ||
| 198 | "日语流行": "日韩流行", | ||
| 199 | "日本流行": "日韩流行", | ||
| 200 | "J-Pop": "日韩流行", | ||
| 201 | "J-pop": "日韩流行", | ||
| 202 | "Jpop": "日韩流行", | ||
| 203 | "中文流行": "华语流行", | ||
| 204 | "国语流行": "华语流行", | ||
| 205 | "中国流行": "华语流行", | ||
| 206 | "英语流行": "欧美流行", | ||
| 207 | "英文流行": "欧美流行", | ||
| 208 | "西方流行": "欧美流行", | ||
| 209 | "Pop": "欧美流行", | ||
| 210 | } | ||
| 211 | |||
| 212 | |||
| 213 | class AudioAnalyzer(ABC): | ||
| 214 | """Abstract base class for music audio analyzers""" | ||
| 215 | |||
| 216 | @abstractmethod | ||
| 217 | def get_provider_name(self) -> str: | ||
| 218 | """Get the provider name (e.g. qwen, doubao)""" | ||
| 219 | pass | ||
| 220 | |||
| 221 | @abstractmethod | ||
| 222 | def get_model_name(self) -> str: | ||
| 223 | """Get the model name""" | ||
| 224 | pass | ||
| 225 | |||
| 226 | @abstractmethod | ||
| 227 | def analyze( | ||
| 228 | self, | ||
| 229 | metadata: Dict[str, Any], | ||
| 230 | music_url: str, | ||
| 231 | extract_lyrics: bool = False, | ||
| 232 | label_level: int = 0, | ||
| 233 | ) -> Optional[Dict[str, Any]]: | ||
| 234 | """ | ||
| 235 | 分析音乐并返回标签结果 | ||
| 236 | |||
| 237 | Args: | ||
| 238 | metadata: 音乐元数据字典 | ||
| 239 | music_url: 音乐文件 URL(支持音频 URL 或 Base64 编码) | ||
| 240 | extract_lyrics: 是否识别歌词 | ||
| 241 | label_level: 标签级别(0: 一级标签, 1: 一级+二级标签) | ||
| 242 | |||
| 243 | Returns: | ||
| 244 | 标准化分析结果字典,包含以下字段: | ||
| 245 | - genre: 音乐风格(一级风格,如:流行、摇滚) | ||
| 246 | - emotion: 情绪列表 | ||
| 247 | - emotional_intensity: 情绪强度 | ||
| 248 | - vocal_texture: 人声质感 | ||
| 249 | - vocal_description: 人声质感描述 | ||
| 250 | - visual_concept: 视觉概念 | ||
| 251 | - language: 语种 | ||
| 252 | - bpm: 节拍数(可选) | ||
| 253 | - lyrics: 歌词列表(可选,仅当 extract_lyrics=True 时) | ||
| 254 | - _model: 使用的模型名称 | ||
| 255 | - _token_info: Token 使用信息 | ||
| 256 | """ | ||
| 257 | pass | ||
| 258 | |||
| 259 | def _parse_response(self, response_text: str) -> Optional[Dict[str, Any]]: | ||
| 260 | """ | ||
| 261 | Parse the LLM response text into JSON | ||
| 262 | |||
| 263 | Args: | ||
| 264 | response_text: raw text returned by the LLM | ||
| 265 | |||
| 266 | Returns: | ||
| 267 | Parsed dict, or None if parsing fails | ||
| 268 | """ | ||
| 269 | import re | ||
| 270 | import json | ||
| 271 | import logging | ||
| 272 | |||
| 273 | logger = logging.getLogger(__name__) | ||
| 274 | |||
| 275 | if not response_text: | ||
| 276 | return None | ||
| 277 | |||
| 278 | # log the raw response for debugging | ||
| 279 | logger.info(f"[_parse_response] 原始响应文本:\n{response_text[:500]}...") | ||
| 280 | |||
| 281 | cleaned_text = response_text.strip() | ||
| 282 | |||
| 283 | # strip markdown code-fence markers | ||
| 284 | if cleaned_text.startswith("```json"): | ||
| 285 | cleaned_text = cleaned_text[7:] | ||
| 286 | elif cleaned_text.startswith("```"): | ||
| 287 | cleaned_text = cleaned_text[3:] | ||
| 288 | |||
| 289 | if cleaned_text.endswith("```"): | ||
| 290 | cleaned_text = cleaned_text[:-3] | ||
| 291 | |||
| 292 | cleaned_text = cleaned_text.strip() | ||
| 293 | |||
| 294 | # extract the JSON object | ||
| 295 | try: | ||
| 296 | # try parsing directly | ||
| 297 | result = json.loads(cleaned_text) | ||
| 298 | if isinstance(result, dict): | ||
| 299 | logger.info(f"[_parse_response] 解析成功,字段: {list(result.keys())}") | ||
| 300 | elif isinstance(result, list): | ||
| 301 | logger.info(f"[_parse_response] 解析成功,列表长度: {len(result)}") | ||
| 302 | else: | ||
| 303 | logger.info( | ||
| 304 | f"[_parse_response] 解析成功,类型: {type(result).__name__}" | ||
| 305 | ) | ||
| 306 | return result | ||
| 307 | except json.JSONDecodeError: | ||
| 308 | pass | ||
| 309 | |||
| 310 | # try extracting the content inside {...} | ||
| 311 | try: | ||
| 312 | match = re.search(r"\{.*\}", cleaned_text, re.DOTALL) | ||
| 313 | if match: | ||
| 314 | json_str = match.group() | ||
| 315 | result = json.loads(json_str) | ||
| 316 | if isinstance(result, dict): | ||
| 317 | logger.info( | ||
| 318 | f"[_parse_response] 正则提取解析成功,字段: {list(result.keys())}" | ||
| 319 | ) | ||
| 320 | elif isinstance(result, list): | ||
| 321 | logger.info( | ||
| 322 | f"[_parse_response] 正则提取解析成功,列表长度: {len(result)}" | ||
| 323 | ) | ||
| 324 | else: | ||
| 325 | logger.info( | ||
| 326 | "[_parse_response] 正则提取解析成功,类型: %s", | ||
| 327 | type(result).__name__, | ||
| 328 | ) | ||
| 329 | return result | ||
| 330 | except (re.error, json.JSONDecodeError): | ||
| 331 | pass | ||
| 332 | |||
| 333 | # try fixing common JSON formatting issues (trailing commas) | ||
| 334 | try: | ||
| 335 | fixed_text = re.sub(r",(\s*})", r"\1", cleaned_text) | ||
| 336 | fixed_text = re.sub(r",(\s*])", r"\1", fixed_text) | ||
| 337 | result = json.loads(fixed_text) | ||
| 338 | if isinstance(result, dict): | ||
| 339 | logger.info( | ||
| 340 | f"[_parse_response] 修复后解析成功,字段: {list(result.keys())}" | ||
| 341 | ) | ||
| 342 | elif isinstance(result, list): | ||
| 343 | logger.info( | ||
| 344 | f"[_parse_response] 修复后解析成功,列表长度: {len(result)}" | ||
| 345 | ) | ||
| 346 | else: | ||
| 347 | logger.info( | ||
| 348 | "[_parse_response] 修复后解析成功,类型: %s", | ||
| 349 | type(result).__name__, | ||
| 350 | ) | ||
| 351 | return result | ||
| 352 | except (re.error, json.JSONDecodeError): | ||
| 353 | pass | ||
| 354 | |||
| 355 | logger.warning(f"[_parse_response] 所有解析方法都失败") | ||
| 356 | return None | ||
| 357 | |||
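The fallback chain above (strip fences, parse directly, then grab the first `{...}` span) can be sketched standalone; `parse_llm_json` is a hypothetical condensed version of the method, not its exact implementation:

```python
import json
import re

def parse_llm_json(text):
    """Strip ```json fences, then try a direct parse, then a {...} regex grab."""
    cleaned = text.strip()
    cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)  # opening fence
    cleaned = re.sub(r"\s*```$", "", cleaned)           # closing fence
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        match = re.search(r"\{.*\}", cleaned, re.DOTALL)
        if match:
            try:
                return json.loads(match.group())
            except json.JSONDecodeError:
                pass
    return None

parse_llm_json('```json\n{"genre": "流行"}\n```')        # {'genre': '流行'}
parse_llm_json('Sure! {"bpm": 120} Hope that helps.')   # {'bpm': 120}
```

The greedy `\{.*\}` with `re.DOTALL` grabs from the first `{` to the last `}`, which is the right behavior when the model wraps a single JSON object in chatty prose.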
| 358 | def _normalize_result( | ||
| 359 | self, | ||
| 360 | raw_result: Dict[str, Any], | ||
| 361 | model_name: str, | ||
| 362 | token_info: Optional[Dict[str, int]] = None, | ||
| 363 | ) -> Dict[str, Any]: | ||
| 364 | """ | ||
| 365 | Normalize the analysis result | ||
| 366 | |||
| 367 | Args: | ||
| 368 | raw_result: raw parsed result | ||
| 369 | model_name: name of the model used | ||
| 370 | token_info: token usage info | ||
| 371 | |||
| 372 | Returns: | ||
| 373 | Normalized result dict | ||
| 374 | """ | ||
| 375 | import logging | ||
| 376 | |||
| 377 | logger = logging.getLogger(__name__) | ||
| 378 | |||
| 379 | if not isinstance(raw_result, dict): | ||
| 380 | if ( | ||
| 381 | isinstance(raw_result, list) | ||
| 382 | and raw_result | ||
| 383 | and isinstance(raw_result[0], dict) | ||
| 384 | ): | ||
| 385 | raw_result = raw_result[0] | ||
| 386 | else: | ||
| 387 | logger.warning( | ||
| 388 | f"[_normalize_result] 原始结果类型异常: {type(raw_result).__name__}" | ||
| 389 | ) | ||
| 390 | return {"_model": model_name, "_raw": raw_result} | ||
| 391 | |||
| 392 | logger.info(f"[_normalize_result] 原始结果字段: {list(raw_result.keys())}") | ||
| 393 | logger.info(f"[_normalize_result] genre: {raw_result.get('genre')}") | ||
| 394 | logger.info(f"[_normalize_result] emotion: {raw_result.get('emotion')}") | ||
| 395 | logger.info(f"[_normalize_result] scene: {raw_result.get('scene')}") | ||
| 396 | logger.info(f"[_normalize_result] token_info 参数: {token_info}") | ||
| 397 | |||
| 398 | def _extract_style(raw_style) -> Optional[Dict[str, str]]: | ||
| 399 | """Extract the music style into the standard format""" | ||
| 400 | if isinstance(raw_style, dict): | ||
| 401 | return {"zh": raw_style.get("zh", ""), "en": raw_style.get("en", "")} | ||
| 402 | elif isinstance(raw_style, str): | ||
| 403 | # plain string: use it as the Chinese name, leave the English name empty | ||
| 404 | return {"zh": raw_style, "en": ""} | ||
| 405 | return None | ||
| 406 | |||
| 407 | def _extract_list_field(raw_value) -> list: | ||
| 408 | """提取列表字段""" | ||
| 409 | if isinstance(raw_value, list): | ||
| 410 | return [v for v in raw_value if v] | ||
| 411 | elif isinstance(raw_value, str): | ||
| 412 | import re | ||
| 413 | |||
| 414 | return [ | ||
| 415 | v.strip() | ||
| 416 | for v in re.split(r"[,,、/|]+", raw_value) | ||
| 417 | if v and v.strip() | ||
| 418 | ] | ||
| 419 | return [] | ||
| 420 | |||
| 421 | def _extract_single_field(raw_value) -> str: | ||
| 422 | """提取单值字段""" | ||
| 423 | if raw_value and isinstance(raw_value, str): | ||
| 424 | return raw_value | ||
| 425 | return "" | ||
| 426 | |||
| 427 | def _validate_and_map_sub_genre(sub_genre: str, genre: str) -> str: | ||
| 428 | """验证并映射 sub_genre 到有效值""" | ||
| 429 | if not sub_genre: | ||
| 430 | return "" | ||
| 431 | |||
| 432 | sub_genre = sub_genre.strip() | ||
| 433 | |||
| 434 | if sub_genre in SUB_GENRE_MAPPING: | ||
| 435 | mapped = SUB_GENRE_MAPPING[sub_genre] | ||
| 436 | logger.info( | ||
| 437 | f"[_validate_and_map_sub_genre] 映射 '{sub_genre}' -> '{mapped}'" | ||
| 438 | ) | ||
| 439 | return mapped | ||
| 440 | |||
| 441 | if genre in VALID_SUB_GENRES: | ||
| 442 | if sub_genre in VALID_SUB_GENRES[genre]: | ||
| 443 | return sub_genre | ||
| 444 | |||
| 445 | for valid_subs in VALID_SUB_GENRES.values(): | ||
| 446 | if sub_genre in valid_subs: | ||
| 447 | return sub_genre | ||
| 448 | |||
| 449 | logger.warning( | ||
| 450 | f"[_validate_and_map_sub_genre] 无法映射 sub_genre: '{sub_genre}' (genre: '{genre}')" | ||
| 451 | ) | ||
| 452 | return sub_genre | ||
| 453 | |||
| 454 | def _validate_list_field( | ||
| 455 | values: List[str], valid_set: Set[str], field_name: str | ||
| 456 | ) -> List[str]: | ||
| 457 | """严格验证列表字段中的值:仅保留字典内标签""" | ||
| 458 | result = [] | ||
| 459 | for v in values: | ||
| 460 | if v in valid_set: | ||
| 461 | result.append(v) | ||
| 462 | else: | ||
| 463 | logger.warning( | ||
| 464 | f"[_validate_list_field] {field_name} 值 '{v}' 不在字典中,已过滤" | ||
| 465 | ) | ||
| 466 | return result | ||
| 467 | |||
| 468 | def _validate_language(raw_value: Any) -> str: | ||
| 469 | language = _extract_single_field(raw_value).strip() | ||
| 470 | if not language: | ||
| 471 | return "" | ||
| 472 | mapped = LANGUAGE_MAPPING.get(language, language) | ||
| 473 | if mapped in VALID_LANGUAGES: | ||
| 474 | return mapped | ||
| 475 | logger.warning( | ||
| 476 | f"[_normalize_result] language '{language}' 不在字典中,已归并为空" | ||
| 477 | ) | ||
| 478 | return "" | ||
| 479 | |||
| 480 | result = { | ||
| 481 | "genre": "", | ||
| 482 | "sub_genre": "", | ||
| 483 | "emotion": [], | ||
| 484 | "voice_type": "", | ||
| 485 | "vocal_texture": "", | ||
| 486 | "vocal_description": "", | ||
| 487 | "visual_concept": "", | ||
| 488 | "language": "", | ||
| 489 | "scene": [], | ||
| 490 | "age": "", | ||
| 491 | "is_sinking": None, | ||
| 492 | "song_description": "", | ||
| 493 | "performer_type": "", | ||
| 494 | "music_style_tags": [], | ||
| 495 | "douyin_tags": [], | ||
| 496 | "instrument_tags": [], | ||
| 497 | } | ||
| 498 | |||
| 499 | # music genre (level-1 and level-2) | ||
| 500 | # prefer the new genre/sub_genre format; stay compatible with the legacy music_style format | ||
| 501 | raw_genre = raw_result.get("genre", "") | ||
| 502 | raw_sub_genre = raw_result.get("sub_genre", "") | ||
| 503 | raw_music_style = raw_result.get("music_style", []) | ||
| 504 | |||
| 505 | # prefer the genre field for the level-1 genre | ||
| 506 | if isinstance(raw_genre, str) and raw_genre.strip(): | ||
| 507 | result["genre"] = raw_genre.strip() | ||
| 508 | elif isinstance(raw_genre, dict): | ||
| 509 | result["genre"] = raw_genre.get("zh", "") or raw_genre.get("en", "") | ||
| 510 | # legacy-format fallback: extract from the music_style array | ||
| 511 | elif ( | ||
| 512 | raw_music_style | ||
| 513 | and isinstance(raw_music_style, list) | ||
| 514 | and len(raw_music_style) > 0 | ||
| 515 | ): | ||
| 516 | first_style = raw_music_style[0] | ||
| 517 | if isinstance(first_style, dict): | ||
| 518 | result["genre"] = first_style.get("zh", "") or first_style.get("en", "") | ||
| 519 | elif isinstance(first_style, str): | ||
| 520 | result["genre"] = first_style.strip() | ||
| 521 | |||
| 522 | # prefer the sub_genre field for the level-2 genre | ||
| 523 | if isinstance(raw_sub_genre, str) and raw_sub_genre.strip(): | ||
| 524 | result["sub_genre"] = raw_sub_genre.strip() | ||
| 525 | elif isinstance(raw_sub_genre, dict): | ||
| 526 | result["sub_genre"] = raw_sub_genre.get("zh", "") or raw_sub_genre.get( | ||
| 527 | "en", "" | ||
| 528 | ) | ||
| 529 | # legacy-format fallback: extract from the second element of the music_style array | ||
| 530 | elif ( | ||
| 531 | raw_music_style | ||
| 532 | and isinstance(raw_music_style, list) | ||
| 533 | and len(raw_music_style) > 1 | ||
| 534 | ): | ||
| 535 | second_style = raw_music_style[1] | ||
| 536 | if isinstance(second_style, dict): | ||
| 537 | result["sub_genre"] = second_style.get("zh", "") or second_style.get( | ||
| 538 | "en", "" | ||
| 539 | ) | ||
| 540 | elif isinstance(second_style, str): | ||
| 541 | result["sub_genre"] = second_style.strip() | ||
| 542 | |||
| 543 | result["sub_genre"] = _validate_and_map_sub_genre( | ||
| 544 | result["sub_genre"], result["genre"] | ||
| 545 | ) | ||
| 546 | |||
| 547 | # emotions | ||
| 548 | raw_emotion = raw_result.get("emotion", []) | ||
| 549 | if isinstance(raw_emotion, str): | ||
| 550 | raw_emotion = [raw_emotion] | ||
| 551 | result["emotion"] = _validate_list_field( | ||
| 552 | _extract_list_field(raw_emotion), VALID_EMOTIONS, "emotion" | ||
| 553 | ) | ||
| 554 | |||
| 555 | # voice type | ||
| 556 | raw_voice_type = raw_result.get("voice_type", "") | ||
| 557 | if raw_voice_type and isinstance(raw_voice_type, str): | ||
| 558 | voice_type = raw_voice_type.strip() | ||
| 559 | if voice_type in VALID_VOICE_TYPES: | ||
| 560 | result["voice_type"] = voice_type | ||
| 561 | else: | ||
| 562 | logger.warning( | ||
| 563 | f"[_normalize_result] voice_type '{voice_type}' is not a valid value; keeping it as-is" | ||
| 564 | ) | ||
| 565 | result["voice_type"] = voice_type | ||
| 566 | else: | ||
| 567 | result["voice_type"] = "" | ||
| 568 | |||
| 569 | # Vocal texture (the LLM returns this as vocal_type) | ||
| 570 | result["vocal_texture"] = _extract_single_field( | ||
| 571 | raw_result.get("vocal_type", "") | ||
| 572 | ) | ||
| 573 | |||
| 574 | # Vocal texture description | ||
| 575 | result["vocal_description"] = raw_result.get("vocal_description", "") | ||
| 576 | |||
| 577 | # 聚音 performer type (prefer performer_type, fall back to vocal_type) | ||
| 578 | raw_performer_type = raw_result.get("performer_type", raw_result.get("vocal_type", "")) | ||
| 579 | if isinstance(raw_performer_type, str): | ||
| 580 | performer_type = raw_performer_type.strip() | ||
| 581 | if performer_type in VALID_PERFORMER_TYPES: | ||
| 582 | result["performer_type"] = performer_type | ||
| 583 | elif performer_type in VALID_VOICE_TYPES: | ||
| 584 | result["performer_type"] = performer_type | ||
| 585 | |||
| 586 | # 聚音 tags: music style / Douyin trending / instrumentation | ||
| 587 | result["music_style_tags"] = _extract_list_field( | ||
| 588 | raw_result.get("music_style_tags", raw_result.get("music_style", [])) | ||
| 589 | ) | ||
| 590 | result["douyin_tags"] = _extract_list_field( | ||
| 591 | raw_result.get("douyin_tags", raw_result.get("network_douyin_tags", [])) | ||
| 592 | ) | ||
| 593 | result["instrument_tags"] = _extract_list_field( | ||
| 594 | raw_result.get("instrument_tags", raw_result.get("instruments", [])) | ||
| 595 | ) | ||
| 596 | result["music_style_tags"] = _validate_list_field( | ||
| 597 | result["music_style_tags"], VALID_MUSIC_STYLE_TAGS, "music_style_tags" | ||
| 598 | ) | ||
| 599 | result["douyin_tags"] = _validate_list_field( | ||
| 600 | result["douyin_tags"], VALID_DOUYIN_TAGS, "douyin_tags" | ||
| 601 | ) | ||
| 602 | result["instrument_tags"] = _validate_list_field( | ||
| 603 | result["instrument_tags"], VALID_INSTRUMENT_TAGS, "instrument_tags" | ||
| 604 | ) | ||
| 605 | |||
| 606 | # Visual concept | ||
| 607 | result["visual_concept"] = raw_result.get("visual_concept", "") | ||
| 608 | |||
| 609 | # Language | ||
| 610 | result["language"] = _validate_language(raw_result.get("language", "")) | ||
| 611 | |||
| 612 | # Scene (multiple selections allowed) | ||
| 613 | raw_scene = raw_result.get("scene", []) | ||
| 614 | if isinstance(raw_scene, str): | ||
| 615 | raw_scene = [raw_scene] | ||
| 616 | if isinstance(raw_scene, list): | ||
| 617 | scene_list = [s.strip() for s in raw_scene if s and isinstance(s, str)] | ||
| 618 | result["scene"] = _validate_list_field(scene_list, VALID_SCENES, "scene") | ||
| 619 | |||
| 620 | # Target listener age range | ||
| 621 | raw_age = raw_result.get("age", "") | ||
| 622 | if raw_age and isinstance(raw_age, str): | ||
| 623 | result["age"] = raw_age.strip() | ||
| 624 | |||
| 625 | # Lower-tier-market flag (is_sinking) | ||
| 626 | raw_is_sinking = raw_result.get("is_sinking") | ||
| 627 | if isinstance(raw_is_sinking, bool): | ||
| 628 | result["is_sinking"] = raw_is_sinking | ||
| 629 | elif isinstance(raw_is_sinking, str): | ||
| 630 | is_sinking_lower = raw_is_sinking.strip().lower() | ||
| 631 | if is_sinking_lower in ("是", "true", "1", "yes"): | ||
| 632 | result["is_sinking"] = True | ||
| 633 | elif is_sinking_lower in ("否", "false", "0", "no"): | ||
| 634 | result["is_sinking"] = False | ||
| 635 | |||
| 636 | # Song description | ||
| 637 | raw_song_desc = raw_result.get("song_description", "") | ||
| 638 | if raw_song_desc and isinstance(raw_song_desc, str): | ||
| 639 | result["song_description"] = raw_song_desc.strip() | ||
| 640 | |||
| 641 | # Emotional intensity | ||
| 642 | raw_emotional_intensity = raw_result.get("emotional_intensity", "") | ||
| 643 | if raw_emotional_intensity and isinstance(raw_emotional_intensity, str): | ||
| 644 | result["emotional_intensity"] = raw_emotional_intensity.strip() | ||
| 645 | |||
| 646 | # Rhythm intensity | ||
| 647 | raw_rhythm_intensity = raw_result.get("rhythm_intensity", "") | ||
| 648 | if raw_rhythm_intensity and isinstance(raw_rhythm_intensity, str): | ||
| 649 | result["rhythm_intensity"] = raw_rhythm_intensity.strip() | ||
| 650 | |||
| 651 | # BPM is not taken from the LLM result; it is always provided by the local bpm_analyzer_tools | ||
| 652 | |||
| 653 | # Lyrics (optional) | ||
| 654 | if "lyrics" in raw_result: | ||
| 655 | result["lyrics"] = raw_result["lyrics"] | ||
| 656 | |||
| 657 | # Attach model info | ||
| 658 | result["_model"] = model_name | ||
| 659 | if token_info: | ||
| 660 | result["_token_info"] = token_info | ||
| 661 | if "_token_info_parts" in raw_result and isinstance( | ||
| 662 | raw_result["_token_info_parts"], dict | ||
| 663 | ): | ||
| 664 | result["_token_info_parts"] = raw_result["_token_info_parts"] | ||
| 665 | if "_timing" in raw_result and isinstance(raw_result["_timing"], dict): | ||
| 666 | result["_timing"] = raw_result["_timing"] | ||
| 667 | |||
| 668 | return result |
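The `is_sinking` coercion above (diff lines 626–634) accepts real booleans plus a handful of string spellings. A standalone sketch of that rule — the helper name `parse_is_sinking` is ours, not from the codebase:

```python
def parse_is_sinking(raw):
    """Mirror the is_sinking normalization: booleans pass through,
    known string spellings map to True/False, anything else is None."""
    if isinstance(raw, bool):
        return raw
    if isinstance(raw, str):
        v = raw.strip().lower()
        if v in ("是", "true", "1", "yes"):
            return True
        if v in ("否", "false", "0", "no"):
            return False
    return None  # upstream leaves result["is_sinking"] untouched

print(parse_is_sinking("Yes"), parse_is_sinking("0"), parse_is_sinking(3))  # → True False None
```

A `None` return matches the original behaviour of simply not setting the field when the value is unrecognized.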
| 1 | #!/usr/bin/env python3 | ||
| 2 | """ | ||
| 3 | Realtime BPM Analyzer - Python test program | ||
| 4 | |||
| 5 | A Python implementation based on realtime-bpm-analyzer (https://github.com/dlepaux/realtime-bpm-analyzer), | ||
| 6 | for quickly testing the BPM of audio files. | ||
| 7 | |||
| 8 | Features: | ||
| 9 | 1. Fast BPM detection | ||
| 10 | 2. Real-time feature extraction | ||
| 11 | 3. Multi-algorithm fusion | ||
| 12 | 4. Detailed result export | ||
| 13 | |||
| 14 | Usage: | ||
| 15 | python bpm_analyzer_test.py --file music.mp3 | ||
| 16 | python bpm_analyzer_test.py --file music.mp3 --output result.json | ||
| 17 | python bpm_analyzer_test.py --file music.mp3 --verbose | ||
| 18 | python bpm_analyzer_test.py --dir /path/to/music/folder | ||
| 19 | """ | ||
| 20 | |||
| 21 | import os | ||
| 22 | import sys | ||
| 23 | import json | ||
| 24 | import logging | ||
| 25 | import argparse | ||
| 26 | from pathlib import Path | ||
| 27 | from typing import Dict, List, Any, Optional, Tuple | ||
| 28 | from datetime import datetime | ||
| 29 | import numpy as np | ||
| 30 | |||
| 31 | # Audio processing libraries | ||
| 32 | try: | ||
| 33 | import librosa | ||
| 34 | import librosa.beat | ||
| 35 | import librosa.feature | ||
| 36 | import librosa.onset | ||
| 37 | except ImportError: | ||
| 38 | print("❌ librosa is not installed; run: pip install librosa") | ||
| 39 | sys.exit(1) | ||
| 40 | |||
| 41 | from scipy.signal import find_peaks, correlate | ||
| 42 | |||
| 43 | # Logging configuration | ||
| 44 | logging.basicConfig( | ||
| 45 | level=logging.INFO, | ||
| 46 | format='%(asctime)s - %(levelname)s - %(message)s' | ||
| 47 | ) | ||
| 48 | logger = logging.getLogger(__name__) | ||
| 49 | |||
| 50 | |||
| 51 | class RealtimeBPMAnalyzerTest: | ||
| 52 | """Realtime BPM Analyzer - Python port""" | ||
| 53 | |||
| 54 | # BPM range (following realtime-bpm-analyzer) | ||
| 55 | BPM_MIN = 30.0 | ||
| 56 | BPM_MAX = 200.0 | ||
| 57 | |||
| 58 | # Confidence threshold | ||
| 59 | CONFIDENCE_THRESHOLD = 0.5 | ||
| 60 | |||
| 61 | def __init__(self, verbose: bool = False): | ||
| 62 | """ | ||
| 63 | Initialize the analyzer | ||
| 64 | |||
| 65 | Args: | ||
| 66 | verbose: whether to print verbose output | ||
| 67 | """ | ||
| 68 | self.verbose = verbose | ||
| 69 | self.sr = 22050 # sample rate | ||
| 70 | self.hop_length = 512 | ||
| 71 | |||
| 72 | if verbose: | ||
| 73 | logger.setLevel(logging.DEBUG) | ||
| 74 | |||
| 75 | logger.info("✓ Realtime BPM Analyzer Test initialized") | ||
| 76 | |||
| 77 | def print_header(self, title: str, width: int = 80): | ||
| 78 | """Print a section header""" | ||
| 79 | print("\n" + "=" * width) | ||
| 80 | print(f" {title}") | ||
| 81 | print("=" * width) | ||
| 82 | |||
| 83 | def analyze_file(self, file_path: str) -> Dict[str, Any]: | ||
| 84 | """ | ||
| 85 | Analyze a single audio file | ||
| 86 | |||
| 87 | Args: | ||
| 88 | file_path: path to the audio file | ||
| 89 | |||
| 90 | Returns: | ||
| 91 | dictionary of analysis results | ||
| 92 | """ | ||
| 93 | self.print_header("🎵 Realtime BPM Analyzer - Test Program") | ||
| 94 | |||
| 95 | # Validate the file | ||
| 96 | if not os.path.exists(file_path): | ||
| 97 | logger.error(f"❌ File not found: {file_path}") | ||
| 98 | return {'success': False, 'error': 'file not found'} | ||
| 99 | |||
| 100 | file_size_mb = os.path.getsize(file_path) / (1024 * 1024) | ||
| 101 | logger.info(f"📄 Audio file: {Path(file_path).name}") | ||
| 102 | logger.info(f"📊 File size: {file_size_mb:.2f} MB") | ||
| 103 | logger.info(f"📁 File path: {Path(file_path).absolute()}") | ||
| 104 | |||
| 105 | self.print_header("📊 Analysis Process", 80) | ||
| 106 | |||
| 107 | try: | ||
| 108 | # Load the audio | ||
| 109 | logger.info("🔄 Loading audio file...") | ||
| 110 | y, sr = librosa.load(file_path, sr=self.sr, mono=True) | ||
| 111 | duration = len(y) / sr | ||
| 112 | logger.info(f"✓ Audio loaded, duration: {duration:.2f} s") | ||
| 113 | |||
| 114 | # Fast analysis | ||
| 115 | logger.info("📈 Fast BPM detection...") | ||
| 116 | fast_result = self._fast_bpm_detection(y, sr) | ||
| 117 | |||
| 118 | # Detailed analysis | ||
| 119 | logger.info("📊 Detailed BPM analysis...") | ||
| 120 | detailed_result = self._detailed_bpm_analysis(y, sr) | ||
| 121 | |||
| 122 | # Fuse the results | ||
| 123 | logger.info("🔀 Fusing analysis results...") | ||
| 124 | final_result = self._fuse_results(fast_result, detailed_result, y=y) | ||
| 125 | |||
| 126 | result = { | ||
| 127 | 'success': True, | ||
| 128 | 'file_path': str(Path(file_path).absolute()), | ||
| 129 | 'file_name': Path(file_path).name, | ||
| 130 | 'file_size_mb': round(file_size_mb, 2), | ||
| 131 | 'duration_seconds': round(duration, 2), | ||
| 132 | 'sample_rate': sr, | ||
| 133 | 'timestamp': datetime.now().isoformat(), | ||
| 134 | 'fast_detection': fast_result, | ||
| 135 | 'detailed_analysis': detailed_result, | ||
| 136 | 'final_result': final_result | ||
| 137 | } | ||
| 138 | |||
| 139 | self.print_header("📈 Analysis Results", 80) | ||
| 140 | self._display_results(result) | ||
| 141 | |||
| 142 | return result | ||
| 143 | |||
| 144 | except Exception as e: | ||
| 145 | logger.error(f"❌ Analysis failed: {str(e)}") | ||
| 146 | if self.verbose: | ||
| 147 | import traceback | ||
| 148 | traceback.print_exc() | ||
| 149 | return {'success': False, 'error': str(e)} | ||
| 150 | |||
| 151 | def analyze_directory(self, dir_path: str) -> List[Dict[str, Any]]: | ||
| 152 | """ | ||
| 153 | Analyze all audio files in a directory | ||
| 154 | |||
| 155 | Args: | ||
| 156 | dir_path: directory path | ||
| 157 | |||
| 158 | Returns: | ||
| 159 | list of analysis results | ||
| 160 | """ | ||
| 161 | self.print_header("🎵 Realtime BPM Analyzer - Batch Analysis", 80) | ||
| 162 | |||
| 163 | if not os.path.isdir(dir_path): | ||
| 164 | logger.error(f"❌ Directory not found: {dir_path}") | ||
| 165 | return [] | ||
| 166 | |||
| 167 | # Collect all audio files | ||
| 168 | audio_extensions = ('.mp3', '.wav', '.flac', '.m4a', '.aac', '.ogg') | ||
| 169 | audio_files = [] | ||
| 170 | |||
| 171 | for root, dirs, files in os.walk(dir_path): | ||
| 172 | for file in files: | ||
| 173 | if file.lower().endswith(audio_extensions): | ||
| 174 | audio_files.append(os.path.join(root, file)) | ||
| 175 | |||
| 176 | logger.info(f"📂 Found {len(audio_files)} audio files") | ||
| 177 | |||
| 178 | results = [] | ||
| 179 | for i, file_path in enumerate(audio_files, 1): | ||
| 180 | logger.info(f"\n[{i}/{len(audio_files)}] Analyzing...") | ||
| 181 | result = self.analyze_file(file_path) | ||
| 182 | results.append(result) | ||
| 183 | |||
| 184 | return results | ||
| 185 | |||
| 186 | def analyze_bpm( | ||
| 187 | self, | ||
| 188 | file_path: Optional[str] = None, | ||
| 189 | y: Optional[np.ndarray] = None, | ||
| 190 | sr: Optional[int] = None, | ||
| 191 | ) -> Dict[str, Any]: | ||
| 192 | """ | ||
| 193 | Unified BPM analysis entry point (for use by other modules) | ||
| 194 | |||
| 195 | Two calling conventions are supported: | ||
| 196 | 1. Pass file_path; the audio is loaded internally at sr=22050 | ||
| 197 | 2. Pass already-loaded y, sr (avoids loading twice) | ||
| 198 | |||
| 199 | Returns: | ||
| 200 | { | ||
| 201 | 'bpm': float, # final BPM (after fusion + correction) | ||
| 202 | 'original_bpm': float, # raw BPM from fast detection | ||
| 203 | 'confidence': float, | ||
| 204 | 'beat_times': list, # list of beat timestamps | ||
| 205 | } | ||
| 205 | } | ||
| 206 | """ | ||
| 207 | try: | ||
| 208 | if y is None and file_path is not None: | ||
| 209 | if not os.path.exists(file_path): | ||
| 210 | return {'bpm': 120.0, 'original_bpm': 120.0, | ||
| 211 | 'confidence': 0.0, 'beat_times': []} | ||
| 212 | y, sr = librosa.load(file_path, sr=self.sr, mono=True) | ||
| 213 | elif y is None: | ||
| 214 | return {'bpm': 120.0, 'original_bpm': 120.0, | ||
| 215 | 'confidence': 0.0, 'beat_times': []} | ||
| 216 | |||
| 217 | # Fast detection | ||
| 218 | fast_result = self._fast_bpm_detection(y, sr) | ||
| 219 | |||
| 220 | # Detailed analysis | ||
| 221 | detailed_result = self._detailed_bpm_analysis(y, sr) | ||
| 222 | |||
| 223 | # Fusion | ||
| 224 | final_result = self._fuse_results(fast_result, detailed_result, y=y) | ||
| 225 | |||
| 226 | final_bpm = final_result.get('bpm', 120.0) | ||
| 227 | original_bpm = fast_result.get('original_bpm', final_bpm) | ||
| 228 | |||
| 229 | # Get beat_times from the same beat_track call used inside _fast_bpm_detection | ||
| 230 | _, beat_frames = librosa.beat.beat_track( | ||
| 231 | y=y, sr=sr, hop_length=self.hop_length | ||
| 232 | ) | ||
| 233 | if isinstance(beat_frames, np.ndarray) and beat_frames.size > 0: | ||
| 234 | beat_times = librosa.frames_to_time( | ||
| 235 | beat_frames, sr=sr, hop_length=self.hop_length | ||
| 236 | ).tolist() | ||
| 237 | else: | ||
| 238 | beat_times = [] | ||
| 239 | |||
| 240 | # If the BPM was halved, also keep only every other beat timestamp | ||
| 241 | if final_bpm < original_bpm * 0.75: | ||
| 242 | beat_times = beat_times[::2] | ||
| 243 | |||
| 244 | return { | ||
| 245 | 'bpm': final_bpm, | ||
| 246 | 'original_bpm': original_bpm, | ||
| 247 | 'confidence': final_result.get('confidence', 0.0), | ||
| 248 | 'beat_times': beat_times, | ||
| 249 | } | ||
| 250 | except Exception as e: | ||
| 251 | logger.warning(f"analyze_bpm failed: {e}") | ||
| 252 | return {'bpm': 120.0, 'original_bpm': 120.0, | ||
| 253 | 'confidence': 0.0, 'beat_times': []} | ||
| 254 | |||
| 255 | def _fast_bpm_detection(self, y: np.ndarray, sr: int) -> Dict[str, Any]: | ||
| 256 | """Fast BPM detection (via librosa.beat.beat_track) + smart beat-level correction""" | ||
| 257 | try: | ||
| 258 | # Get BPM and beat times | ||
| 259 | tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr, hop_length=self.hop_length) | ||
| 260 | |||
| 261 | # tempo may come back as an ndarray | ||
| 262 | if isinstance(tempo, np.ndarray): | ||
| 263 | bpm = float(tempo[0]) if tempo.size > 0 else 120.0 | ||
| 264 | else: | ||
| 265 | bpm = float(tempo) | ||
| 266 | |||
| 267 | # beat_frames may come back as an ndarray | ||
| 268 | if isinstance(beat_frames, np.ndarray) and beat_frames.size > 0: | ||
| 269 | beat_times = librosa.frames_to_time(beat_frames, sr=sr, hop_length=self.hop_length) | ||
| 270 | beat_times = beat_times.tolist() if isinstance(beat_times, np.ndarray) else list(beat_times) | ||
| 271 | else: | ||
| 272 | beat_times = [] | ||
| 273 | |||
| 274 | # Smart beat-level detection and correction (audio data passed in for onset analysis) | ||
| 275 | corrected_bpm, correction_reason = self._detect_beat_level_errors(beat_times, bpm, y) | ||
| 276 | |||
| 277 | return { | ||
| 278 | 'bpm': round(corrected_bpm, 1), | ||
| 279 | 'original_bpm': round(bpm, 1), | ||
| 280 | 'confidence': 0.85, | ||
| 281 | 'method': 'librosa.beat.beat_track()', | ||
| 282 | 'beat_count': len(beat_times), | ||
| 283 | 'beat_level_correction': correction_reason if correction_reason != 'beat_level_ok' else None, | ||
| 284 | 'duration_ms': 100 | ||
| 285 | } | ||
| 286 | except Exception as e: | ||
| 287 | logger.warning(f"⚠️ Fast detection failed: {str(e)}") | ||
| 288 | return { | ||
| 289 | 'bpm': 0, | ||
| 290 | 'confidence': 0, | ||
| 291 | 'method': 'librosa.beat.beat_track()', | ||
| 292 | 'error': str(e) | ||
| 293 | } | ||
| 294 | |||
| 295 | def _detect_beat_level_errors(self, beat_times: list, bpm: float, y: Optional[np.ndarray] = None) -> Tuple[float, str]: | ||
| 296 | """ | ||
| 297 | Detect and correct beat-level errors (e.g. tracking the 8th-note grid instead of quarter notes) | ||
| 298 | |||
| 299 | Improved version: combines several features | ||
| 300 | 1. Alternating-strength pattern (ratio) | ||
| 301 | 2. Raw BPM range (100-150 is the range most likely to need halving) | ||
| 302 | 3. Spectral-centroid analysis (slow songs usually have a lower centroid) | ||
| 303 | 4. Comparison of onset-alignment scores | ||
| 304 | """ | ||
| 305 | if not beat_times or len(beat_times) < 2: | ||
| 306 | return bpm, "insufficient_beats" | ||
| 307 | |||
| 308 | beat_intervals = np.diff(beat_times) | ||
| 309 | mean_interval = np.mean(beat_intervals) | ||
| 310 | std_interval = np.std(beat_intervals) | ||
| 311 | coeff_variation = std_interval / mean_interval if mean_interval > 0 else 1.0 | ||
| 312 | |||
| 313 | beat_count = len(beat_times) | ||
| 314 | |||
| 315 | if self.verbose: | ||
| 316 | logger.debug(f"Beat level analysis: {beat_count} beats, CV={coeff_variation:.3f}, BPM={bpm:.1f}") | ||
| 317 | |||
| 318 | # Condition 1: very regular intervals + BPM > 100 + beat count > 20 (threshold lowered to support short clips) | ||
| 319 | if not (coeff_variation < 0.15 and bpm > 100 and beat_count > 20): | ||
| 320 | return bpm, "beat_level_ok" | ||
| 321 | |||
| 322 | # Without audio data, fall back to the conservative choice | ||
| 323 | if y is None: | ||
| 324 | return bpm, "beat_level_ok" | ||
| 325 | |||
| 326 | halved_bpm = bpm / 2 | ||
| 327 | if not (40 < halved_bpm < 160): | ||
| 328 | return bpm, "beat_level_ok" | ||
| 329 | |||
| 330 | # Compute onset strength | ||
| 331 | onset_env = librosa.onset.onset_strength(y=y, sr=self.sr, hop_length=self.hop_length) | ||
| 332 | |||
| 333 | # Onset strength at each beat position | ||
| 334 | beat_frames = librosa.time_to_frames(beat_times, sr=self.sr, hop_length=self.hop_length) | ||
| 335 | beat_strengths = [] | ||
| 336 | window = 3 | ||
| 337 | |||
| 338 | for frame in beat_frames: | ||
| 339 | if frame < len(onset_env): | ||
| 340 | start = max(0, frame - window) | ||
| 341 | end = min(len(onset_env), frame + window + 1) | ||
| 342 | beat_strengths.append(np.max(onset_env[start:end])) | ||
| 343 | |||
| 344 | if len(beat_strengths) < 10: | ||
| 345 | return bpm, "beat_level_ok" | ||
| 346 | |||
| 347 | beat_strengths = np.array(beat_strengths) | ||
| 348 | |||
| 349 | # Detect an alternating-strength pattern | ||
| 350 | odd_beats = beat_strengths[::2] | ||
| 351 | even_beats = beat_strengths[1::2] | ||
| 352 | mean_odd = np.mean(odd_beats) | ||
| 353 | mean_even = np.mean(even_beats) | ||
| 354 | strength_ratio = mean_odd / mean_even if mean_even > 0 else 1.0 | ||
| 355 | |||
| 356 | # Spectral centroid - used to tell fast songs from slow ones | ||
| 357 | spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=self.sr, hop_length=self.hop_length) | ||
| 358 | mean_centroid = np.mean(spectral_centroid) | ||
| 359 | |||
| 360 | if self.verbose: | ||
| 361 | logger.debug(f"Beat strength ratio={strength_ratio:.3f}, spectral_centroid={mean_centroid:.1f}") | ||
| 362 | |||
| 363 | # Combined decision logic | ||
| 364 | should_halve = False | ||
| 365 | reason = "" | ||
| 366 | |||
| 367 | # Rule 1: very pronounced alternating pattern (ratio > 1.8 or < 0.55) | ||
| 368 | if strength_ratio > 1.8 or strength_ratio < 0.55: | ||
| 369 | should_halve = True | ||
| 370 | reason = f"strong_alternating_pattern (ratio={strength_ratio:.2f})" | ||
| 371 | |||
| 372 | # Rule 1b: BPM > 150 + moderate alternating pattern → halve | ||
| 373 | # e.g. "春娇与志明" (172.3 BPM, ratio=1.406, ref=85) | ||
| 374 | # Home - Headhunterz (152 BPM, ratio=1.098) does not trigger this | ||
| 375 | elif bpm > 150 and (strength_ratio > 1.25 or strength_ratio < 0.8): | ||
| 376 | should_halve = True | ||
| 377 | reason = f"very_high_bpm_with_alternating (bpm={bpm:.1f}, ratio={strength_ratio:.2f})" | ||
| 378 | |||
| 379 | # Rule 2: BPM in the 125-150 range + strong alternating pattern (ratio > 1.25) | ||
| 380 | # High onset density (>=3.0/s) + high spectral centroid (>=2200) means a genuinely fast song; do not halve | ||
| 381 | # e.g. "爱在西元前" (129.2 BPM, centroid=2527, onset_density=3.8, ratio=1.29) | ||
| 382 | # Otherwise correct with bpm*2/3 (for songs with a 3:2 rhythmic relationship) | ||
| 383 | # e.g. "该死的爱情" (129.2 BPM, ratio=1.668, centroid=1986, ref=84) → 2/3=86.1 | ||
| 384 | # e.g. "你要的全拿走" (136.0 BPM, ratio=1.485, centroid=2678, ref=76) → 2/3=90.7 | ||
| 385 | elif 125 <= bpm <= 150 and strength_ratio > 1.25: | ||
| 386 | onset_frames = librosa.onset.onset_detect(onset_envelope=onset_env, sr=self.sr, hop_length=self.hop_length) | ||
| 387 | duration = len(y) / self.sr | ||
| 388 | onset_density = len(onset_frames) / duration if duration > 0 else 0 | ||
| 389 | if onset_density >= 3.0 and mean_centroid >= 2200: | ||
| 390 | if self.verbose: | ||
| 391 | logger.debug(f"Rule 2 skipped: high onset density ({onset_density:.1f}/s) + high centroid ({mean_centroid:.0f}); treated as a fast song") | ||
| 392 | else: | ||
| 393 | # Pick the correction strategy from the spectral centroid: | ||
| 394 | # Low centroid (<2200): dark-sounding slow song; librosa locks onto 3/2x, correct with *2/3 | ||
| 395 | # e.g. "该死的爱情" (129.2, centroid=1986, ratio=1.67) → 86.1 (ref=84) | ||
| 396 | # High centroid (>=2200) + low onset density (<3.0): brightly produced slow song; librosa locks onto 2x, correct with /2 | ||
| 397 | # e.g. "你要的全拿走" (136.0, centroid=2678, density=2.17, ratio=1.49) → 68.0 (ref=76) | ||
| 398 | if mean_centroid >= 2200: | ||
| 399 | # Bright but rhythmically sparse → simply halve | ||
| 400 | should_halve = True | ||
| 401 | reason = f"rule2_bright_slow (bpm={bpm:.1f}, ratio={strength_ratio:.2f}, centroid={mean_centroid:.0f}, density={onset_density:.1f})" | ||
| 402 | else: | ||
| 403 | # Dark timbre → correct with 2/3 | ||
| 404 | two_thirds_bpm = round(bpm * 2 / 3, 1) | ||
| 405 | should_halve = False | ||
| 406 | logger.info( | ||
| 407 | f"🔧 Beat-level correction (2/3): {bpm:.1f} BPM → {two_thirds_bpm:.1f} BPM " | ||
| 408 | f"(ratio={strength_ratio:.2f}, centroid={mean_centroid:.0f})" | ||
| 409 | ) | ||
| 410 | return two_thirds_bpm, f"rule2_two_thirds (bpm={bpm:.1f}, result={two_thirds_bpm:.1f}, ratio={strength_ratio:.2f})" | ||
| 411 | |||
| 412 | elif 125 <= bpm <= 150 and strength_ratio < 0.8 and mean_centroid < 2200: | ||
| 413 | should_halve = True | ||
| 414 | reason = f"mid_bpm_low_ratio_low_centroid (bpm={bpm:.1f}, ratio={strength_ratio:.2f}, centroid={mean_centroid:.0f})" | ||
| 415 | |||
| 416 | # Rule 2b: BPM > 130 + low spectral centroid (< 1800) indicates slow-song traits despite a high detected BPM | ||
| 417 | # Catches songs like "嚣张": BPM=136 but centroid=1653 | ||
| 418 | elif bpm > 130 and mean_centroid < 1800: | ||
| 419 | should_halve = True | ||
| 420 | reason = f"high_bpm_low_centroid (bpm={bpm:.1f}, centroid={mean_centroid:.0f})" | ||
| 421 | |||
| 422 | # Rule 3: the 115-125 BPM range needs stricter conditions | ||
| 423 | elif 115 <= bpm < 125: | ||
| 424 | # Rule 3a: very strong alternating pattern (ratio > 1.5) → halve regardless of centroid | ||
| 425 | # Catches songs like "想你的夜" with strong alternation but a highish centroid | ||
| 426 | if strength_ratio > 1.5 or strength_ratio < 0.65: | ||
| 427 | should_halve = True | ||
| 428 | reason = f"strong_alternating_in_mid_bpm (bpm={bpm:.1f}, ratio={strength_ratio:.2f})" | ||
| 429 | # Rule 3b: moderate alternating pattern + low spectral centroid (slow-song traits) | ||
| 430 | elif mean_centroid < 2000 and (strength_ratio > 1.4 or strength_ratio < 0.7): | ||
| 431 | should_halve = True | ||
| 432 | reason = f"slow_song_detected (centroid={mean_centroid:.0f}, ratio={strength_ratio:.2f})" | ||
| 433 | # Otherwise leave it alone (may be a genuine mid-tempo song, e.g. 有什么奇怪, 中巴车) | ||
| 434 | |||
| 435 | # Rule 3c: BPM in the 100-115 range (a slow song may have been detected at 2x, e.g. 嘉禾望岗 56 BPM → 112 BPM) | ||
| 436 | # Decide using onset alignment | ||
| 437 | elif 100 <= bpm < 115: | ||
| 438 | score_detected = self._compute_onset_alignment_score(onset_env, bpm) | ||
| 439 | score_halved = self._compute_onset_alignment_score(onset_env, halved_bpm) | ||
| 440 | |||
| 441 | if score_detected > 0 and score_halved > 0: | ||
| 442 | alignment_ratio = score_halved / score_detected | ||
| 443 | if self.verbose: | ||
| 444 | logger.debug(f"Onset alignment (100-115 BPM): detected={score_detected:.3f}, halved={score_halved:.3f}, ratio={alignment_ratio:.3f}") | ||
| 445 | |||
| 446 | # If the halved BPM aligns better (ratio > 1.0), the true BPM is half | ||
| 447 | # Also check the alternating pattern as supporting evidence | ||
| 448 | if alignment_ratio > 1.0 and (strength_ratio > 1.2 or strength_ratio < 0.83): | ||
| 449 | should_halve = True | ||
| 450 | reason = f"slow_song_100_115_range (alignment_ratio={alignment_ratio:.3f}, strength_ratio={strength_ratio:.2f})" | ||
| 451 | # Even without a clear alternating pattern, halve if the alignment score is clearly better | ||
| 452 | elif alignment_ratio > 1.08: | ||
| 453 | should_halve = True | ||
| 454 | reason = f"onset_alignment_strongly_favors_half (ratio={alignment_ratio:.3f})" | ||
| 455 | |||
| 456 | # Rule 4: compare BPM vs BPM/2 with onset alignment (only for high BPM > 130) | ||
| 457 | # If BPM/2 aligns clearly better, a half-beat grid was detected | ||
| 458 | # Restricted to BPM > 130 to avoid harming mid-tempo songs such as "中巴车" (117.5 BPM) | ||
| 459 | if not should_halve and bpm > 130: | ||
| 460 | score_detected = self._compute_onset_alignment_score(onset_env, bpm) | ||
| 461 | score_halved = self._compute_onset_alignment_score(onset_env, halved_bpm) | ||
| 462 | |||
| 463 | if score_detected > 0 and score_halved > 0: | ||
| 464 | alignment_ratio = score_halved / score_detected | ||
| 465 | if self.verbose: | ||
| 466 | logger.debug(f"Onset alignment: detected={score_detected:.3f}, halved={score_halved:.3f}, ratio={alignment_ratio:.3f}") | ||
| 467 | |||
| 468 | # A high centroid (>=2000) suggests fast/electronic material; require a higher alignment ratio before halving | ||
| 469 | # Avoids harming e.g. "Home - Headhunterz" (152 BPM, centroid=2290, ratio=1.102) | ||
| 470 | ratio_threshold = 1.15 if mean_centroid >= 2000 else 1.04 | ||
| 471 | if alignment_ratio > ratio_threshold and 40 < halved_bpm < 160: | ||
| 472 | should_halve = True | ||
| 473 | reason = f"onset_alignment_favors_half (ratio={alignment_ratio:.3f})" | ||
| 474 | |||
| 475 | if should_halve: | ||
| 476 | logger.info(f"🔧 Beat-level correction: {bpm:.1f} BPM → {halved_bpm:.1f} BPM ({reason})") | ||
| 477 | return halved_bpm, reason | ||
| 478 | |||
| 479 | return bpm, "beat_level_ok" | ||
| 480 | |||
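The alternating-strength heuristic behind rules 1–3 reduces to comparing mean onset strength on odd vs. even beats. A tiny synthetic check of that ratio (the numbers are illustrative, not from a real track):

```python
import numpy as np

# Synthetic beat strengths: a strong/weak alternation, as produced when the
# tracker locks onto the 8th-note grid of a song with clear backbeats.
beat_strengths = np.array([1.0, 0.3] * 20)

odd = beat_strengths[::2].mean()    # odd (strong) beats → 1.0
even = beat_strengths[1::2].mean()  # even (weak) beats → 0.3
ratio = odd / even
print(round(ratio, 2))  # → 3.33, well past the rule-1 threshold of 1.8, so halve
```

A flat grid (all beats equally strong) would give a ratio near 1.0 and leave the detected BPM untouched.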
| 481 | def _compute_onset_alignment_score(self, onset_env: np.ndarray, bpm: float) -> float: | ||
| 482 | """ | ||
| 483 | Compute an alignment score between a candidate BPM and the onset strength | ||
| 484 | |||
| 485 | Rationale: true beats should coincide with onset-strength peaks | ||
| 486 | A higher score means better alignment | ||
| 487 | """ | ||
| 488 | frame_rate = self.sr / self.hop_length | ||
| 489 | beat_interval_frames = int((60.0 / bpm) * frame_rate) | ||
| 490 | |||
| 491 | if beat_interval_frames < 1 or beat_interval_frames > len(onset_env): | ||
| 492 | return 0.0 | ||
| 493 | |||
| 494 | # Sample onset strength at each beat position | ||
| 495 | beat_strengths = [] | ||
| 496 | off_beat_strengths = [] | ||
| 497 | |||
| 498 | for i in range(0, len(onset_env) - beat_interval_frames, beat_interval_frames): | ||
| 499 | # Beat position (take the max within a small window) | ||
| 500 | window_size = max(1, beat_interval_frames // 8) | ||
| 501 | start = max(0, i - window_size) | ||
| 502 | end = min(len(onset_env), i + window_size) | ||
| 503 | beat_strengths.append(np.max(onset_env[start:end])) | ||
| 504 | |||
| 505 | # Off-beat position (midpoint between beats) | ||
| 506 | mid_point = i + beat_interval_frames // 2 | ||
| 507 | if mid_point < len(onset_env): | ||
| 508 | start_off = max(0, mid_point - window_size) | ||
| 509 | end_off = min(len(onset_env), mid_point + window_size) | ||
| 510 | off_beat_strengths.append(np.max(onset_env[start_off:end_off])) | ||
| 511 | |||
| 512 | if not beat_strengths or not off_beat_strengths: | ||
| 513 | return 0.0 | ||
| 514 | |||
| 515 | # Score = mean strength at beat positions / mean strength at off-beat positions | ||
| 516 | # The higher the ratio, the more pronounced the onsets at beat positions | ||
| 517 | mean_beat = np.mean(beat_strengths) | ||
| 518 | mean_off_beat = np.mean(off_beat_strengths) | ||
| 519 | |||
| 520 | if mean_off_beat < 1e-6: | ||
| 521 | return mean_beat | ||
| 522 | |||
| 523 | score = mean_beat / mean_off_beat | ||
| 524 | return float(score) | ||
| 525 | |||
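The scoring idea can be exercised without librosa: the sketch below re-implements the same windowed-max contrast (the helper name `alignment_score` is ours) and runs it on a synthetic click track at 60 BPM over a small noise floor; the true tempo scores far higher than an off-tempo hypothesis.

```python
import numpy as np

def alignment_score(onset_env, bpm, frame_rate):
    """Mean onset strength at hypothesised beat positions divided by the
    strength at the midpoints between them (windowed max, as in the analyzer)."""
    interval = int((60.0 / bpm) * frame_rate)
    if interval < 1 or interval > len(onset_env):
        return 0.0
    window = max(1, interval // 8)
    on_beat, off_beat = [], []
    for i in range(0, len(onset_env) - interval, interval):
        on_beat.append(onset_env[max(0, i - window):i + window + 1].max())
        mid = i + interval // 2
        if mid < len(onset_env):
            off_beat.append(onset_env[max(0, mid - window):mid + window + 1].max())
    if not on_beat or not off_beat:
        return 0.0
    mean_off = float(np.mean(off_beat))
    return float(np.mean(on_beat)) if mean_off < 1e-6 else float(np.mean(on_beat)) / mean_off

frame_rate = 22050 / 512              # ~43.07 frames per second, matching the class defaults
env = np.full(1300, 0.1)              # ~30 s of noise floor
env[::43] = 1.1                       # clicks every 43 frames ≈ 60 BPM
print(round(alignment_score(env, 60, frame_rate), 1))  # → 11.0, sharp contrast at the true tempo
print(alignment_score(env, 90, frame_rate))            # much lower off tempo
```

This is why the production code compares the score at the detected BPM against the score at half that BPM: the hypothesis whose beat grid lands on the clicks wins.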
| 526 | def _detailed_bpm_analysis(self, y: np.ndarray, sr: int) -> Dict[str, Any]: | ||
| 527 | """Detailed BPM analysis""" | ||
| 528 | try: | ||
| 529 | # Compute onset strength | ||
| 530 | onset_env = librosa.onset.onset_strength( | ||
| 531 | y=y, sr=sr, hop_length=self.hop_length | ||
| 532 | ) | ||
| 533 | |||
| 534 | # Compute the tempogram | ||
| 535 | tempogram = librosa.feature.tempogram( | ||
| 536 | y=y, sr=sr, hop_length=self.hop_length | ||
| 537 | ) | ||
| 538 | |||
| 539 | # Compute the autocorrelation | ||
| 540 | tempogram_flat = tempogram.flatten() | ||
| 541 | acf = correlate(tempogram_flat, tempogram_flat, mode='full') | ||
| 542 | acf = acf[len(acf)//2:] | ||
| 543 | acf = acf / (acf[0] + 1e-8) | ||
| 544 | |||
| 545 | # Find peaks | ||
| 546 | peaks, properties = find_peaks(acf[1:], height=0.2, distance=5) | ||
| 547 | peaks = peaks + 1 | ||
| 548 | |||
| 549 | if len(peaks) > 0: | ||
| 550 | frame_rate = sr / self.hop_length | ||
| 551 | best_peak_idx = peaks[np.argmax(acf[peaks])] | ||
| 552 | bpm = 60.0 * frame_rate / best_peak_idx | ||
| 553 | confidence = float(np.max(acf[peaks])) | ||
| 554 | else: | ||
| 555 | bpm = 120.0 | ||
| 556 | confidence = 0.3 | ||
| 557 | |||
| 558 | # Clamp to a sensible range | ||
| 559 | bpm = np.clip(bpm, self.BPM_MIN, self.BPM_MAX) | ||
| 560 | |||
| 561 | return { | ||
| 562 | 'bpm': round(bpm, 1), | ||
| 563 | 'confidence': round(float(np.clip(confidence, 0, 1)), 2), | ||
| 564 | 'method': 'Tempogram Autocorrelation', | ||
| 565 | 'peaks_count': int(len(peaks)) | ||
| 566 | } | ||
| 567 | except Exception as e: | ||
| 568 | logger.warning(f"⚠️ Detailed analysis failed: {str(e)}") | ||
| 569 | return { | ||
| 570 | 'bpm': 0, | ||
| 571 | 'confidence': 0, | ||
| 572 | 'method': 'Tempogram Autocorrelation', | ||
| 573 | 'error': str(e) | ||
| 574 | } | ||
| 575 | |||
| 576 | def _fuse_results( | ||
| 577 | self, | ||
| 578 | fast_result: Dict[str, Any], | ||
| 579 | detailed_result: Dict[str, Any], | ||
| 580 | y: Optional[np.ndarray] = None, | ||
| 581 | ) -> Dict[str, Any]: | ||
| 582 | """Fuse the fast and detailed analysis results, with octave-error detection and correction""" | ||
| 583 | results = [] | ||
| 584 | |||
| 585 | if fast_result.get('bpm', 0) > 0: | ||
| 586 | results.append({ | ||
| 587 | 'bpm': fast_result['bpm'], | ||
| 588 | 'original_bpm': fast_result.get('original_bpm', fast_result['bpm']), | ||
| 589 | 'confidence': fast_result['confidence'], | ||
| 590 | 'method': fast_result['method'], | ||
| 591 | 'beat_level_correction': fast_result.get('beat_level_correction') | ||
| 592 | }) | ||
| 593 | |||
| 594 | if detailed_result.get('bpm', 0) > 0: | ||
| 595 | results.append({ | ||
| 596 | 'bpm': detailed_result['bpm'], | ||
| 597 | 'confidence': detailed_result['confidence'], | ||
| 598 | 'method': detailed_result['method'] | ||
| 599 | }) | ||
| 600 | |||
| 601 | if not results: | ||
| 602 | return { | ||
| 603 | 'bpm': 120.0, | ||
| 604 | 'confidence': 0.0, | ||
| 605 | 'note': 'BPM could not be detected; using the default' | ||
| 606 | } | ||
| 607 | |||
| 608 | # If the fast pass already applied a beat-level correction, use the corrected result directly | ||
| 609 | beat_level_correction = results[0].get('beat_level_correction') if results else None | ||
| 610 | if beat_level_correction: | ||
| 611 | original_bpm = results[0].get('original_bpm', results[0]['bpm']) | ||
| 612 | corrected_bpm = results[0]['bpm'] | ||
| 613 | return { | ||
| 614 | 'bpm': corrected_bpm, | ||
| 615 | 'confidence': results[0]['confidence'], | ||
| 616 | 'primary_method': results[0]['method'], | ||
| 617 | 'supporting_methods': len(results) - 1, | ||
| 618 | 'all_candidates': results, | ||
| 619 | 'octave_correction': { | ||
| 620 | 'from': original_bpm, | ||
| 621 | 'to': corrected_bpm, | ||
| 622 | 'reason': f'Beat-level correction: {original_bpm:.1f} → {corrected_bpm:.1f} ({beat_level_correction})' | ||
| 623 | } | ||
| 624 | } | ||
| 625 | |||
| 626 | # Only one result available | ||
| 627 | if len(results) == 1: | ||
| 628 | best = results[0] | ||
| 629 | return { | ||
| 630 | 'bpm': best['bpm'], | ||
| 631 | 'confidence': best['confidence'], | ||
| 632 | 'primary_method': best['method'], | ||
| 633 | 'supporting_methods': 0, | ||
| 634 | 'all_candidates': results, | ||
| 635 | 'octave_correction': None | ||
| 636 | } | ||
| 637 | |||
| 638 | # Detect octave relationships | ||
| 639 | fast_bpm = results[0]['bpm'] # librosa.beat.beat_track is usually more accurate | ||
| 640 | detailed_bpm = results[1]['bpm'] if len(results) > 1 else None | ||
| 641 | |||
| 642 | if detailed_bpm and fast_bpm > 0: | ||
| 643 | ratio = max(fast_bpm, detailed_bpm) / min(fast_bpm, detailed_bpm) | ||
| 644 | |||
| 645 | # Check for an octave relationship (1/2, 1/3, 1/4, 2x, 3x, 4x, etc.) | ||
| 646 | octave_correction = None | ||
| 647 | is_octave = False | ||
| 648 | chosen_bpm = fast_bpm # default to the fast-detection result | ||
| 649 | |||
| 650 | # Special case: when detailed_bpm is very low (< 40) and fast_bpm is in the 100-120 range, | ||
| 651 | # a slow song may have been detected at 2x, so detailed_bpm × 2 may be the right answer | ||
| 652 | # Example: 嘉禾望岗 is actually 56 BPM; fast=112.3, detailed=30, and 30×2=60 is closer | ||
| 653 | # Note: must exclude mid/fast songs being mis-corrected (e.g. 中巴车带我回家, fast=117.5, detailed=30, ref=115) | ||
| 654 | # Verify with onset alignment: only correct if the halved BPM aligns clearly better than fast BPM | ||
| 655 | if detailed_bpm < 40 and 100 <= fast_bpm <= 120 and y is not None: | ||
| 656 | # Use the spectral centroid to check whether it really is a slow song | ||
| 657 | spectral_centroid = librosa.feature.spectral_centroid( | ||
| 658 | y=y, sr=self.sr, hop_length=self.hop_length | ||
| 659 | ) | ||
| 660 | mean_centroid = float(np.mean(spectral_centroid)) | ||
| 661 | |||
| 662 | doubled_detailed = detailed_bpm * 2 | ||
| 663 | # Check that doubled_detailed falls in a plausible slow-song range (50-70 BPM) | ||
| 664 | # and that the spectral centroid is low (< 2200), confirming slow-song traits | ||
| 665 | if 50 <= doubled_detailed <= 70 and mean_centroid < 2200: | ||
| 666 | # Check whether fast_bpm ≈ doubled_detailed × 2 | ||
| 667 | if abs(fast_bpm - doubled_detailed * 2) / fast_bpm < 0.1: | ||
| 668 | # Extra check: confirm with onset alignment that the halved BPM really is better | ||
| 669 | # Avoids mis-correcting e.g. "中巴车带我回家" (fast=117.5, ref=115) | ||
| 670 | onset_env = librosa.onset.onset_strength( | ||
| 671 | y=y, sr=self.sr, hop_length=self.hop_length | ||
| 672 | ) | ||
| 673 | score_fast = self._compute_onset_alignment_score(onset_env, fast_bpm) | ||
| 674 | score_halved = self._compute_onset_alignment_score(onset_env, doubled_detailed) | ||
| 675 | alignment_ratio = score_halved / score_fast if score_fast > 0 else 0 | ||
| 676 | |||
| 677 | if self.verbose: | ||
| 678 | logger.debug( | ||
| 679 | f"慢歌倍频验证: score_fast={score_fast:.3f}, " | ||
| 680 | f"score_halved={score_halved:.3f}, ratio={alignment_ratio:.3f}" | ||
| 681 | ) | ||
| 682 | |||
| 683 | # Correct only when the halved BPM aligns clearly better | ||
| 684 | # 中巴车: alignment_ratio=1.042, does not trigger (actual BPM=115) | ||
| 685 | # 嘉禾望岗: halved=56 should align clearly better and does trigger | ||
| 686 | if alignment_ratio > 1.08: | ||
| 687 | chosen_bpm = doubled_detailed | ||
| 688 | is_octave = True | ||
| 689 | octave_correction = { | ||
| 690 | 'from': fast_bpm, | ||
| 691 | 'to': doubled_detailed, | ||
| 692 | 'reason': f'慢歌倍频纠正: fast={fast_bpm:.1f} ≈ detailed×4={detailed_bpm:.1f}×4,使用 detailed×2={doubled_detailed:.1f} (alignment={alignment_ratio:.3f})' | ||
| 693 | } | ||
| 694 | logger.info(f"\n🔧 慢歌倍频纠正: {fast_bpm:.1f} BPM → {doubled_detailed:.1f} BPM") | ||
| 695 | logger.info(f" 原因: {octave_correction['reason']}") | ||
| 696 | return { | ||
| 697 | 'bpm': chosen_bpm, | ||
| 698 | 'confidence': results[1]['confidence'], | ||
| 699 | 'primary_method': 'Tempogram + 倍频纠正', | ||
| 700 | 'supporting_methods': 1, | ||
| 701 | 'all_candidates': results, | ||
| 702 | 'octave_correction': octave_correction | ||
| 703 | } | ||
| 704 | else: | ||
| 705 | if self.verbose: | ||
| 706 | logger.debug( | ||
| 707 | f"慢歌倍频纠正跳过: fast BPM({fast_bpm:.1f})对齐度更好,保持原值" | ||
| 708 | ) | ||
| 709 | |||
| 710 | # Check common octave relationships: detailed_bpm should be ≈ fast_bpm * multiplier | ||
| 711 | for multiplier in [0.25, 0.33, 0.5, 1.0, 2.0, 3.0, 4.0]: | ||
| 712 | expected_bpm = fast_bpm * multiplier | ||
| 713 | # Check whether detailed_bpm is close to expected_bpm (10% tolerance) | ||
| 714 | if abs(detailed_bpm - expected_bpm) / expected_bpm < 0.1: | ||
| 715 | is_octave = True | ||
| 716 | if multiplier != 1.0: # a non-1x ratio indicates an octave relationship | ||
| 717 | # use the fast-detection result | ||
| 718 | corrected_bpm = fast_bpm | ||
| 719 | octave_correction = { | ||
| 720 | 'from': detailed_bpm, | ||
| 721 | 'to': corrected_bpm, | ||
| 722 | 'reason': f'倍频关系检测: {detailed_bpm:.1f} ≈ {fast_bpm:.1f} × {multiplier},使用快速检测结果' | ||
| 723 | } | ||
| 724 | break | ||
| 725 | |||
| 726 | # If an octave relationship was detected, use the fast-detection result (usually more accurate) | ||
| 727 | if is_octave and octave_correction: | ||
| 728 | logger.info(f"\n🔧 倍频纠正: {octave_correction['from']:.1f} BPM → {octave_correction['to']:.1f} BPM") | ||
| 729 | logger.info(f" 原因: {octave_correction['reason']}") | ||
| 730 | return { | ||
| 731 | 'bpm': fast_bpm, | ||
| 732 | 'confidence': results[0]['confidence'], | ||
| 733 | 'primary_method': results[0]['method'], | ||
| 734 | 'supporting_methods': 1, | ||
| 735 | 'all_candidates': results, | ||
| 736 | 'octave_correction': octave_correction | ||
| 737 | } | ||
| 738 | |||
| 739 | # No octave relationship: prefer fast detection (librosa.beat.tempo is the gold standard here) | ||
| 740 | # and is usually more accurate than the detailed analysis | ||
| 741 | best = results[0] # fast detection | ||
| 742 | |||
| 743 | return { | ||
| 744 | 'bpm': best['bpm'], | ||
| 745 | 'confidence': best['confidence'], | ||
| 746 | 'primary_method': best['method'], | ||
| 747 | 'supporting_methods': len(results) - 1, | ||
| 748 | 'all_candidates': results, | ||
| 749 | 'octave_correction': None | ||
| 750 | } | ||
| 751 | |||
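The octave-relationship test above (detailed ≈ fast × multiplier, within a 10% tolerance) can be sketched in isolation. This is a standalone illustration; `find_octave_multiplier` is a hypothetical helper name, not part of the analyzer:

```python
def find_octave_multiplier(fast_bpm: float, detailed_bpm: float,
                           multipliers=(0.25, 0.33, 0.5, 1.0, 2.0, 3.0, 4.0),
                           tolerance: float = 0.1):
    """Return the multiplier m such that detailed_bpm ≈ fast_bpm * m, or None."""
    for m in multipliers:
        expected = fast_bpm * m
        # relative error against the expected value, same as the analyzer's check
        if expected > 0 and abs(detailed_bpm - expected) / expected < tolerance:
            return m
    return None

# A detailed estimate of 60 against a fast estimate of 120 is a 0.5x octave error
print(find_octave_multiplier(120.0, 60.0))   # 0.5
# 117.5 vs 115 is within 10% of 1.0x: the two estimates simply agree
print(find_octave_multiplier(117.5, 115.0))  # 1.0
```

Returning the matched multiplier (rather than a bool) also tells the caller whether the candidates merely agree (`1.0`) or disagree by a true octave factor.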
| 752 | def _display_results(self, result: Dict[str, Any]): | ||
| 753 | """Display the analysis result""" | ||
| 754 | if not result['success']: | ||
| 755 | logger.error(f"❌ 分析失败: {result.get('error')}") | ||
| 756 | return | ||
| 757 | |||
| 758 | file_info = ( | ||
| 759 | f"文件: {result['file_name']} " | ||
| 760 | f"({result['file_size_mb']} MB) " | ||
| 761 | f"时长: {result['duration_seconds']} 秒" | ||
| 762 | ) | ||
| 763 | logger.info(file_info) | ||
| 764 | |||
| 765 | final = result['final_result'] | ||
| 766 | logger.info(f"\n🎵 最终结果:") | ||
| 767 | logger.info(f" BPM: {final['bpm']}") | ||
| 768 | logger.info(f" 置信度: {final['confidence']:.0%}") | ||
| 769 | logger.info(f" 主要方法: {final['primary_method']}") | ||
| 770 | logger.info(f" 支持方法数: {final['supporting_methods']}") | ||
| 771 | |||
| 772 | # Show octave-correction info | ||
| 773 | if final.get('octave_correction'): | ||
| 774 | correction = final['octave_correction'] | ||
| 775 | logger.info(f"\n🔧 倍频纠正:") | ||
| 776 | logger.info(f" 原始检测: {correction['from']:.1f} BPM") | ||
| 777 | logger.info(f" 纠正后: {correction['to']:.1f} BPM") | ||
| 778 | logger.info(f" 原因: {correction['reason']}") | ||
| 779 | |||
| 780 | if self.verbose: | ||
| 781 | logger.debug(f"\n📊 快速检测: {result['fast_detection']['bpm']} BPM") | ||
| 782 | logger.debug(f"📊 详细分析: {result['detailed_analysis']['bpm']} BPM") | ||
| 783 | |||
| 784 | def export_results( | ||
| 785 | self, | ||
| 786 | results: Any, | ||
| 787 | output_path: str | ||
| 788 | ): | ||
| 789 | """Export results as JSON""" | ||
| 790 | try: | ||
| 791 | # Convert numpy types to native Python types | ||
| 792 | def convert_numpy(obj): | ||
| 793 | if isinstance(obj, np.ndarray): | ||
| 794 | return obj.tolist() | ||
| 795 | elif isinstance(obj, np.integer): | ||
| 796 | return int(obj) | ||
| 797 | elif isinstance(obj, np.floating): | ||
| 798 | return float(obj) | ||
| 799 | elif isinstance(obj, dict): | ||
| 800 | return {k: convert_numpy(v) for k, v in obj.items()} | ||
| 801 | elif isinstance(obj, (list, tuple)): | ||
| 802 | return [convert_numpy(v) for v in obj] | ||
| 803 | return obj | ||
| 804 | |||
| 805 | results_converted = convert_numpy(results) | ||
| 806 | |||
| 807 | with open(output_path, 'w', encoding='utf-8') as f: | ||
| 808 | json.dump(results_converted, f, ensure_ascii=False, indent=2) | ||
| 809 | |||
| 810 | logger.info(f"✓ 结果已导出到: {Path(output_path).absolute()}") | ||
| 811 | |||
| 812 | except Exception as e: | ||
| 813 | logger.error(f"❌ 导出失败: {str(e)}") | ||
| 814 | |||
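An alternative to the recursive `convert_numpy` helper above is a `json.JSONEncoder` subclass with a `default` hook: the hook only runs for values `json` cannot serialize, so native types pass through untouched. The sketch below duck-types on `tolist`/`item` so it stays runnable even without numpy installed; `FakeArray` is a stand-in for `np.ndarray`:

```python
import json

class NumpySafeEncoder(json.JSONEncoder):
    """Fall back to .tolist()/.item() for numpy arrays and scalars."""
    def default(self, obj):
        if hasattr(obj, "tolist"):   # np.ndarray and numpy scalars
            return obj.tolist()
        if hasattr(obj, "item"):     # numpy scalar fallback
            return obj.item()
        return super().default(obj)

class FakeArray:  # stand-in for np.ndarray in this sketch
    def tolist(self):
        return [1, 2, 3]

print(json.dumps({"bpm": 120, "beats": FakeArray()}, cls=NumpySafeEncoder))
# {"bpm": 120, "beats": [1, 2, 3]}
```

The explicit recursive converter in `export_results` does the same job eagerly; the encoder hook defers conversion to serialization time and avoids building a second copy of the result tree.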
| 815 | |||
| 816 | def main(): | ||
| 817 | """CLI entry point""" | ||
| 818 | parser = argparse.ArgumentParser( | ||
| 819 | description='Realtime BPM Analyzer - Python 测试程序', | ||
| 820 | formatter_class=argparse.RawDescriptionHelpFormatter, | ||
| 821 | epilog=""" | ||
| 822 | 示例用法: | ||
| 823 | # 分析单个文件 | ||
| 824 | python bpm_analyzer_test.py --file music.mp3 | ||
| 825 | |||
| 826 | # 分析并输出结果 | ||
| 827 | python bpm_analyzer_test.py --file music.mp3 --output result.json | ||
| 828 | |||
| 829 | # 显示详细信息 | ||
| 830 | python bpm_analyzer_test.py --file music.mp3 --verbose | ||
| 831 | |||
| 832 | # 批量分析文件夹 | ||
| 833 | python bpm_analyzer_test.py --dir /path/to/music | ||
| 834 | """ | ||
| 835 | ) | ||
| 836 | |||
| 837 | parser.add_argument('--file', type=str, help='音频文件路径') | ||
| 838 | parser.add_argument('--dir', type=str, help='音频文件夹路径(批量分析)') | ||
| 839 | parser.add_argument('-o', '--output', type=str, help='输出 JSON 文件路径') | ||
| 840 | parser.add_argument('-v', '--verbose', action='store_true', help='显示详细信息') | ||
| 841 | |||
| 842 | args = parser.parse_args() | ||
| 843 | |||
| 844 | # Validate arguments | ||
| 845 | if not args.file and not args.dir: | ||
| 846 | parser.print_help() | ||
| 847 | sys.exit(1) | ||
| 848 | |||
| 849 | # Initialize the analyzer | ||
| 850 | analyzer = RealtimeBPMAnalyzerTest(verbose=args.verbose) | ||
| 851 | |||
| 852 | # Run the analysis | ||
| 853 | try: | ||
| 854 | if args.file: | ||
| 855 | result = analyzer.analyze_file(args.file) | ||
| 856 | results = result | ||
| 857 | else: | ||
| 858 | results_list = analyzer.analyze_directory(args.dir) | ||
| 859 | results = { | ||
| 860 | 'success': True, | ||
| 861 | 'total_files': len(results_list), | ||
| 862 | 'results': results_list | ||
| 863 | } | ||
| 864 | |||
| 865 | # Export results | ||
| 866 | if args.output: | ||
| 867 | analyzer.export_results(results, args.output) | ||
| 868 | else: | ||
| 869 | # Default output filename | ||
| 870 | if args.file: | ||
| 871 | default_output = f"bpm_result_{Path(args.file).stem}.json" | ||
| 872 | else: | ||
| 873 | default_output = "bpm_results.json" | ||
| 874 | analyzer.export_results(results, default_output) | ||
| 875 | |||
| 876 | print("\n" + "=" * 80) | ||
| 877 | print("✅ 分析完成!") | ||
| 878 | print("=" * 80 + "\n") | ||
| 879 | |||
| 880 | except Exception as e: | ||
| 881 | logger.error(f"❌ 执行失败: {str(e)}") | ||
| 882 | if args.verbose: | ||
| 883 | import traceback | ||
| 884 | traceback.print_exc() | ||
| 885 | sys.exit(1) | ||
| 886 | |||
| 887 | |||
| 888 | if __name__ == '__main__': | ||
| 889 | main() |
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | """ | ||
| 3 | Volcengine Doubao (豆包) music analyzer implementation | ||
| 4 | """ | ||
| 5 | |||
| 6 | import os | ||
| 7 | import time | ||
| 8 | import logging | ||
| 9 | from typing import Dict, Any, Optional | ||
| 10 | from dotenv import load_dotenv | ||
| 11 | from pathlib import Path | ||
| 12 | |||
| 13 | import httpx | ||
| 14 | |||
| 15 | from .base import AudioAnalyzer | ||
| 16 | from .prompts import build_analyze_prompt, build_lyrics_prompt | ||
| 17 | |||
| 18 | _ROOT_DIR = Path(__file__).resolve().parents[2] | ||
| 19 | load_dotenv(_ROOT_DIR / ".env") | ||
| 20 | |||
| 21 | logger = logging.getLogger(__name__) | ||
| 22 | |||
| 23 | |||
| 24 | class DoubaoAnalyzer(AudioAnalyzer): | ||
| 25 | """Volcengine Doubao music analyzer""" | ||
| 26 | |||
| 27 | def __init__( | ||
| 28 | self, | ||
| 29 | api_key: Optional[str] = None, | ||
| 30 | base_url: Optional[str] = None, | ||
| 31 | model: Optional[str] = None, | ||
| 32 | timeout: float = 60.0, | ||
| 33 | max_retries: int = 3, | ||
| 34 | ): | ||
| 35 | """ | ||
| 36 | Initialize the Doubao analyzer | ||
| 37 | |||
| 38 | Args: | ||
| 39 | api_key: API key (defaults to env var DOUBAO_API_KEY or ARK_API_KEY) | ||
| 40 | base_url: API base URL (default: https://ark.cn-beijing.volces.com/api/v3) | ||
| 41 | model: model name (default: doubao-seed-1-8-251228) | ||
| 42 | timeout: request timeout in seconds | ||
| 43 | max_retries: maximum number of retries | ||
| 44 | """ | ||
| 45 | self.api_key = api_key or os.getenv("DOUBAO_API_KEY", os.getenv("ARK_API_KEY")) | ||
| 46 | self.base_url = base_url or os.getenv( | ||
| 47 | "DOUBAO_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3" | ||
| 48 | ) | ||
| 49 | self.model = model or os.getenv("DOUBAO_MODEL", "doubao-seed-1-8-251228") | ||
| 50 | self.timeout = timeout | ||
| 51 | self.max_retries = max_retries | ||
| 52 | |||
| 53 | self._client = None | ||
| 54 | |||
| 55 | def _get_client(self) -> httpx.Client: | ||
| 56 | """Get (and lazily create) the HTTP client""" | ||
| 57 | if self._client is None: | ||
| 58 | self._client = httpx.Client( | ||
| 59 | base_url=self.base_url, | ||
| 60 | timeout=self.timeout, | ||
| 61 | headers={ | ||
| 62 | "Authorization": f"Bearer {self.api_key}", | ||
| 63 | "Content-Type": "application/json", | ||
| 64 | }, | ||
| 65 | ) | ||
| 66 | return self._client | ||
| 67 | |||
| 68 | def get_provider_name(self) -> str: | ||
| 69 | return "doubao" | ||
| 70 | |||
| 71 | def get_model_name(self) -> str: | ||
| 72 | return self.model | ||
| 73 | |||
| 74 | def analyze( | ||
| 75 | self, | ||
| 76 | metadata: Dict[str, Any], | ||
| 77 | music_url: str, | ||
| 78 | extract_lyrics: bool = False, | ||
| 79 | label_level: int = 0, | ||
| 80 | ) -> Optional[Dict[str, Any]]: | ||
| 81 | """ | ||
| 82 | Analyze a music track | ||
| 83 | |||
| 84 | Args: | ||
| 85 | metadata: music metadata | ||
| 86 | music_url: music file URL | ||
| 87 | extract_lyrics: whether to transcribe lyrics | ||
| 88 | label_level: label level | ||
| 89 | |||
| 90 | Returns: | ||
| 91 | analysis result dict | ||
| 92 | """ | ||
| 93 | client = self._get_client() | ||
| 94 | |||
| 95 | if extract_lyrics: | ||
| 96 | return self._analyze_with_lyrics(client, metadata, music_url, label_level) | ||
| 97 | else: | ||
| 98 | return self._analyze_basic(client, metadata, music_url, label_level) | ||
| 99 | |||
| 100 | def _analyze_basic( | ||
| 101 | self, | ||
| 102 | client: httpx.Client, | ||
| 103 | metadata: Dict[str, Any], | ||
| 104 | music_url: str, | ||
| 105 | label_level: int = 0, | ||
| 106 | ) -> Optional[Dict[str, Any]]: | ||
| 107 | """Basic analysis (no lyrics)""" | ||
| 108 | system_prompt, user_prompt = build_analyze_prompt( | ||
| 109 | metadata=metadata, | ||
| 110 | include_lyrics=False, | ||
| 111 | label_level=label_level, | ||
| 112 | ) | ||
| 113 | |||
| 114 | # Log the prompts | ||
| 115 | logger.info(f"[DoubaoAnalyzer] System Prompt:\n{system_prompt}") | ||
| 116 | logger.info(f"[DoubaoAnalyzer] User Prompt:\n{user_prompt}") | ||
| 117 | |||
| 118 | messages = self._build_messages(system_prompt, user_prompt, music_url) | ||
| 119 | |||
| 120 | response = self._call_with_retry(client, messages) | ||
| 121 | |||
| 122 | if response is None: | ||
| 123 | return None | ||
| 124 | |||
| 125 | result = self._parse_response(response.get("content", "")) | ||
| 126 | if result is None: | ||
| 127 | return None | ||
| 128 | |||
| 129 | return self._normalize_result(result, self.model, response.get("usage")) | ||
| 130 | |||
| 131 | def _analyze_with_lyrics( | ||
| 132 | self, | ||
| 133 | client: httpx.Client, | ||
| 134 | metadata: Dict[str, Any], | ||
| 135 | music_url: str, | ||
| 136 | label_level: int = 0, | ||
| 137 | ) -> Optional[Dict[str, Any]]: | ||
| 138 | """Analysis with lyrics transcription (requires two API calls)""" | ||
| 139 | # First call: basic labels (no lyrics) | ||
| 140 | system_prompt, user_prompt = build_analyze_prompt( | ||
| 141 | metadata=metadata, | ||
| 142 | include_lyrics=False, | ||
| 143 | label_level=label_level, | ||
| 144 | ) | ||
| 145 | |||
| 146 | # Log the prompts | ||
| 147 | logger.info(f"[DoubaoAnalyzer] System Prompt (with lyrics):\n{system_prompt}") | ||
| 148 | logger.info(f"[DoubaoAnalyzer] User Prompt (with lyrics):\n{user_prompt}") | ||
| 149 | |||
| 150 | messages_basic = self._build_messages(system_prompt, user_prompt, music_url) | ||
| 151 | response_basic = self._call_with_retry(client, messages_basic) | ||
| 152 | |||
| 153 | if response_basic is None: | ||
| 154 | return None | ||
| 155 | |||
| 156 | result = self._parse_response(response_basic.get("content", "")) | ||
| 157 | if result is None: | ||
| 158 | return None | ||
| 159 | |||
| 160 | # Second call: lyrics transcription | ||
| 161 | lyrics_prompt = build_lyrics_prompt() | ||
| 162 | |||
| 163 | # Log the lyrics prompt | ||
| 164 | logger.info(f"[DoubaoAnalyzer] Lyrics Prompt:\n{lyrics_prompt}") | ||
| 165 | |||
| 166 | messages_lyrics = self._build_messages( | ||
| 167 | "请识别这段音频中的歌词内容", lyrics_prompt, music_url | ||
| 168 | ) | ||
| 169 | response_lyrics = self._call_with_retry(client, messages_lyrics) | ||
| 170 | |||
| 171 | lyrics_result = None | ||
| 172 | if response_lyrics: | ||
| 173 | lyrics_result = self._parse_response(response_lyrics.get("content", "")) | ||
| 174 | if lyrics_result and "lyrics" in lyrics_result: | ||
| 175 | result["lyrics"] = lyrics_result["lyrics"] | ||
| 176 | |||
| 177 | # Merge token usage from both calls | ||
| 178 | usage = response_basic.get("usage", {}) | ||
| 179 | if response_lyrics and response_lyrics.get("usage"): | ||
| 180 | usage_lyrics = response_lyrics["usage"] | ||
| 181 | usage = { | ||
| 182 | "prompt_tokens": usage.get("prompt_tokens", 0) | ||
| 183 | + usage_lyrics.get("prompt_tokens", 0), | ||
| 184 | "completion_tokens": usage.get("completion_tokens", 0) | ||
| 185 | + usage_lyrics.get("completion_tokens", 0), | ||
| 186 | "total_tokens": usage.get("total_tokens", 0) | ||
| 187 | + usage_lyrics.get("total_tokens", 0), | ||
| 188 | } | ||
| 189 | |||
| 190 | return self._normalize_result(result, self.model, usage) | ||
| 191 | |||
| 192 | def _build_messages( | ||
| 193 | self, | ||
| 194 | system_prompt: str, | ||
| 195 | user_prompt: str, | ||
| 196 | music_url: str, | ||
| 197 | ) -> list: | ||
| 198 | """Build the chat message list; a system message is prepended when system_prompt is non-empty""" | ||
| 199 | messages = [] | ||
| 200 | if system_prompt: | ||
| 201 | messages.append({"role": "system", "content": system_prompt}) | ||
| 202 | messages.append( | ||
| 203 | { | ||
| 204 | "role": "user", | ||
| 205 | "content": [ | ||
| 206 | {"type": "video_url", "video_url": {"url": music_url}}, | ||
| 207 | {"type": "text", "text": user_prompt}, | ||
| 208 | ], | ||
| 209 | } | ||
| 210 | ) | ||
| 211 | return messages | ||
| 208 | |||
| 209 | def _call_with_retry( | ||
| 210 | self, | ||
| 211 | client: httpx.Client, | ||
| 212 | messages: list, | ||
| 213 | ) -> Optional[Dict]: | ||
| 214 | """API call with retries""" | ||
| 215 | endpoint = "/chat/completions" | ||
| 216 | data = { | ||
| 217 | "model": self.model, | ||
| 218 | "messages": messages, | ||
| 219 | "temperature": 0.7, | ||
| 220 | "max_tokens": 4000, | ||
| 221 | "stream": False, | ||
| 222 | } | ||
| 223 | |||
| 224 | for attempt in range(1, self.max_retries + 1): | ||
| 225 | try: | ||
| 226 | print(f" [Doubao] 调用模型 (尝试 {attempt}/{self.max_retries})...") | ||
| 227 | |||
| 228 | start_time = time.time() | ||
| 229 | |||
| 230 | response = client.post(endpoint, json=data) | ||
| 231 | response.raise_for_status() | ||
| 232 | |||
| 233 | end_time = time.time() | ||
| 234 | elapsed = end_time - start_time | ||
| 235 | print(f" [Doubao] 响应时间: {elapsed:.2f}s") | ||
| 236 | |||
| 237 | result = response.json() | ||
| 238 | content = ( | ||
| 239 | result.get("choices", [{}])[0].get("message", {}).get("content", "") | ||
| 240 | ) | ||
| 241 | usage = result.get("usage", {}) | ||
| 242 | |||
| 243 | print(f" [Doubao] 响应: {content[:100]}...") | ||
| 244 | |||
| 245 | return { | ||
| 246 | "content": content, | ||
| 247 | "usage": { | ||
| 248 | "prompt_tokens": usage.get("prompt_tokens", 0), | ||
| 249 | "completion_tokens": usage.get("completion_tokens", 0), | ||
| 250 | "total_tokens": usage.get("total_tokens", 0), | ||
| 251 | }, | ||
| 252 | } | ||
| 253 | |||
| 254 | except httpx.HTTPError as e: | ||
| 255 | error_type = type(e).__name__ | ||
| 256 | print(f" [Doubao] HTTP 错误 ({error_type}): {e}") | ||
| 257 | |||
| 258 | if attempt < self.max_retries: | ||
| 259 | wait_time = attempt | ||
| 260 | print(f" 等待 {wait_time} 秒后重试...") | ||
| 261 | time.sleep(wait_time) | ||
| 262 | else: | ||
| 263 | print(f" 已达到最大重试次数") | ||
| 264 | return None | ||
| 265 | |||
| 266 | except Exception as e: | ||
| 267 | error_type = type(e).__name__ | ||
| 268 | print(f" [Doubao] 错误 ({error_type}): {e}") | ||
| 269 | |||
| 270 | if attempt < self.max_retries: | ||
| 271 | wait_time = attempt | ||
| 272 | print(f" 等待 {wait_time} 秒后重试...") | ||
| 273 | time.sleep(wait_time) | ||
| 274 | else: | ||
| 275 | print(f" 已达到最大重试次数") | ||
| 276 | return None | ||
| 277 | |||
| 278 | return None | ||
| 279 | |||
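The retry loop in `_call_with_retry` (wait `attempt` seconds between tries, i.e. linear backoff, and `None` after exhaustion) generalizes to a small helper. This is an illustrative sketch, not code from the module; the injectable `sleep` parameter exists so tests can run without real delays:

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")

def call_with_retry(fn: Callable[[], T], max_retries: int = 3,
                    sleep: Callable[[float], None] = time.sleep) -> Optional[T]:
    """Run fn(); on exception wait `attempt` seconds (linear backoff) and retry."""
    for attempt in range(1, max_retries + 1):
        try:
            return fn()
        except Exception as e:
            if attempt < max_retries:
                sleep(attempt)  # 1s, 2s, ... between attempts
            else:
                print(f"giving up after {max_retries} attempts: {e}")
    return None

calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient error")
    return "ok"

print(call_with_retry(flaky, max_retries=3, sleep=lambda s: None))  # ok
```

Linear backoff keeps worst-case added latency small (1+2 s for three attempts); for rate-limited APIs an exponential schedule is the more common choice.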
| 280 | |||
| 281 | def test_doubao_audio_url_lyrics(): | ||
| 282 | """ | ||
| 283 | Test whether Doubao can transcribe lyrics from an audio URL | ||
| 284 | |||
| 285 | This test verifies that the Doubao model can: | ||
| 286 | 1. accept an audio URL as input | ||
| 287 | 2. parse the audio content | ||
| 288 | 3. recognize and return the lyrics | ||
| 289 | |||
| 290 | Usage: | ||
| 291 | python -c "from app.middleware.music_analyze.doubao_analyzer import test_doubao_audio_url_lyrics; test_doubao_audio_url_lyrics()" | ||
| 292 | |||
| 293 | or run the module directly: | ||
| 294 | python app/middleware/music_analyze/doubao_analyzer.py | ||
| 295 | """ | ||
| 296 | import json | ||
| 297 | |||
| 298 | print("=" * 80) | ||
| 299 | print("测试豆包音频URL歌词解析功能") | ||
| 300 | print("=" * 80) | ||
| 301 | |||
| 302 | # Test audio URL (a publicly accessible audio file) | ||
| 303 | # NOTE: replace with an actually reachable audio URL | ||
| 304 | test_audio_url = "https://hikoon-ai-test.oss-cn-hangzhou.aliyuncs.com/ai/cache/modelName/20260114/_s__e_1768376270519_rmab41.mp3" | ||
| 305 | |||
| 306 | print(f"\n测试音频URL: {test_audio_url}") | ||
| 307 | print("\n开始测试...") | ||
| 308 | |||
| 309 | try: | ||
| 310 | # Initialize the analyzer | ||
| 311 | analyzer = DoubaoAnalyzer() | ||
| 312 | |||
| 313 | # Test metadata | ||
| 314 | metadata = {"title": "测试歌曲", "artist": "测试艺术家", "test": True} | ||
| 315 | |||
| 316 | print("\n1. 测试基础分析(不含歌词)...") | ||
| 317 | result_basic = analyzer.analyze( | ||
| 318 | metadata=metadata, | ||
| 319 | music_url=test_audio_url, | ||
| 320 | extract_lyrics=False, | ||
| 321 | label_level=0, | ||
| 322 | ) | ||
| 323 | |||
| 324 | if result_basic: | ||
| 325 | print(" ✓ 基础分析成功") | ||
| 326 | print(f" - 曲风: {result_basic.get('genre', 'N/A')}") | ||
| 327 | print(f" - 语种: {result_basic.get('language', 'N/A')}") | ||
| 328 | print(f" - 情绪: {result_basic.get('emotion', 'N/A')}") | ||
| 329 | else: | ||
| 330 | print(" ✗ 基础分析失败") | ||
| 331 | |||
| 332 | print("\n2. 测试歌词识别(含歌词)...") | ||
| 333 | result_with_lyrics = analyzer.analyze( | ||
| 334 | metadata=metadata, | ||
| 335 | music_url=test_audio_url, | ||
| 336 | extract_lyrics=True, | ||
| 337 | label_level=0, | ||
| 338 | ) | ||
| 339 | |||
| 340 | if result_with_lyrics: | ||
| 341 | print(" ✓ 歌词识别分析成功") | ||
| 342 | lyrics = result_with_lyrics.get("lyrics", []) | ||
| 343 | if lyrics: | ||
| 344 | print(f" ✓ 成功识别歌词,共 {len(lyrics)} 行") | ||
| 345 | print("\n 歌词预览(前5行):") | ||
| 346 | for i, line in enumerate(lyrics[:5], 1): | ||
| 347 | time_str = line.get("time", "N/A") | ||
| 348 | text = line.get("text", "") | ||
| 349 | print(f" [{i}] {time_str} - {text}") | ||
| 350 | if len(lyrics) > 5: | ||
| 351 | print(f" ... 还有 {len(lyrics) - 5} 行") | ||
| 352 | print("\n ✓ 测试通过:豆包支持音频URL解析歌词") | ||
| 353 | else: | ||
| 354 | print(" ⚠ 未识别到歌词(可能是纯音乐或无法识别)") | ||
| 355 | print("\n ! 测试结果:豆包支持音频URL解析,但未返回歌词") | ||
| 356 | |||
| 357 | # Print the full result | ||
| 358 | print("\n3. 完整分析结果:") | ||
| 359 | print(json.dumps(result_with_lyrics, ensure_ascii=False, indent=2)) | ||
| 360 | else: | ||
| 361 | print(" ✗ 歌词识别分析失败") | ||
| 362 | print("\n ✗ 测试失败:豆包可能不支持音频URL解析") | ||
| 363 | |||
| 364 | print("\n" + "=" * 80) | ||
| 365 | print("测试完成") | ||
| 366 | print("=" * 80) | ||
| 367 | |||
| 368 | return result_with_lyrics | ||
| 369 | |||
| 370 | except Exception as e: | ||
| 371 | print(f"\n✗ 测试过程中发生错误: {e}") | ||
| 372 | import traceback | ||
| 373 | |||
| 374 | traceback.print_exc() | ||
| 375 | return None | ||
| 376 | |||
| 377 | |||
| 378 | if __name__ == "__main__": | ||
| 379 | test_doubao_audio_url_lyrics() |
app/middleware/music_analyze/factory.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | """ | ||
| 3 | Music analyzer factory | ||
| 4 | """ | ||
| 5 | |||
| 6 | from typing import Dict | ||
| 7 | from .base import AudioAnalyzer | ||
| 8 | from .qwen_analyzer import QwenAnalyzer | ||
| 9 | |||
| 10 | |||
| 11 | class AnalyzerFactory: | ||
| 12 | """音乐分析器工厂""" | ||
| 13 | |||
| 14 | _analyzers: Dict[str, AudioAnalyzer] = {} | ||
| 15 | |||
| 16 | @classmethod | ||
| 17 | def get_analyzer(cls, provider: str = "qwen", **kwargs) -> AudioAnalyzer: | ||
| 18 | """ | ||
| 19 | Get an analyzer instance (cached per provider/model) | ||
| 20 | |||
| 21 | Args: | ||
| 22 | provider: provider name (e.g. "qwen") | ||
| 23 | **kwargs: extra configuration (e.g. api_key, model, timeout) | ||
| 24 | |||
| 25 | Returns: | ||
| 26 | an AudioAnalyzer instance | ||
| 27 | """ | ||
| 28 | cache_key = f"{provider}_{kwargs.get('model', '')}" | ||
| 30 | |||
| 31 | if cache_key in cls._analyzers: | ||
| 32 | return cls._analyzers[cache_key] | ||
| 33 | |||
| 34 | if provider == "qwen": | ||
| 35 | analyzer = QwenAnalyzer(**kwargs) | ||
| 36 | elif provider == "doubao": | ||
| 37 | # Imported lazily so the qwen-only path has no hard dependency on the doubao module | ||
| 38 | from .doubao_analyzer import DoubaoAnalyzer | ||
| 39 | analyzer = DoubaoAnalyzer(**kwargs) | ||
| 40 | else: | ||
| 41 | raise ValueError(f"Unknown provider: {provider}. Supported providers: 'qwen', 'doubao'.") | ||
| 38 | |||
| 39 | cls._analyzers[cache_key] = analyzer | ||
| 40 | return analyzer | ||
| 41 | |||
| 42 | @classmethod | ||
| 43 | def get_default_analyzer(cls) -> AudioAnalyzer: | ||
| 44 | """Get the default analyzer (provider read from the environment)""" | ||
| 45 | import os | ||
| 46 | |||
| 47 | provider = os.getenv("DEFAULT_MUSIC_ANALYZER", "qwen") | ||
| 48 | return cls.get_analyzer(provider=provider) | ||
| 49 | |||
| 50 | @classmethod | ||
| 51 | def list_providers(cls) -> list: | ||
| 52 | """List available providers""" | ||
| 53 | return ["qwen", "doubao"] | ||
| 54 | |||
| 55 | @classmethod | ||
| 56 | def clear_cache(cls): | ||
| 57 | """Clear cached analyzer instances""" | ||
| 58 | cls._analyzers.clear() |
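`AnalyzerFactory` caches one instance per `provider_model` key, so repeated `get_analyzer` calls with the same arguments reuse the same object. A minimal standalone sketch of this memoizing-factory pattern (`MiniFactory` and `DummyAnalyzer` are illustrative stand-ins, since the real analyzers need API keys):

```python
from typing import Dict

class DummyAnalyzer:
    def __init__(self, model: str = ""):
        self.model = model

class MiniFactory:
    _cache: Dict[str, DummyAnalyzer] = {}

    @classmethod
    def get(cls, provider: str, model: str = "") -> DummyAnalyzer:
        # Same cache-key scheme as the factory above: provider + model
        cache_key = f"{provider}_{model}"
        if cache_key not in cls._cache:
            cls._cache[cache_key] = DummyAnalyzer(model)
        return cls._cache[cache_key]

a = MiniFactory.get("qwen", "qwen3-omni-flash")
b = MiniFactory.get("qwen", "qwen3-omni-flash")
c = MiniFactory.get("qwen", "other-model")
print(a is b, a is c)  # True False
```

One consequence of keying on `model` alone: two `get_analyzer` calls that differ only in other kwargs (e.g. `timeout`) share one cached instance, and the second call's kwargs are silently ignored.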
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | """ | ||
| 3 | Unified entry point for music analysis | ||
| 4 | Provides a simplified analyze_music() function | ||
| 5 | """ | ||
| 6 | |||
| 7 | from typing import Dict, Any, Optional | ||
| 8 | import os | ||
| 9 | |||
| 10 | from .factory import AnalyzerFactory | ||
| 11 | |||
| 12 | |||
| 13 | def analyze_music( | ||
| 14 | metadata: Dict[str, Any], | ||
| 15 | music_url: str, | ||
| 16 | provider: Optional[str] = None, | ||
| 17 | extract_lyrics: bool = False, | ||
| 18 | label_level: int = 0, | ||
| 19 | ) -> Optional[Dict[str, Any]]: | ||
| 20 | """ | ||
| 21 | Unified music-analysis entry point | ||
| 22 | |||
| 23 | Args: | ||
| 24 | metadata: music metadata dict (e.g. title, artist) | ||
| 25 | music_url: music file URL | ||
| 26 | provider: provider (qwen | doubao); defaults to the DEFAULT_MUSIC_ANALYZER env var | ||
| 27 | extract_lyrics: whether to transcribe lyrics | ||
| 28 | label_level: label level (0: level-1 labels, 1: level-1 + level-2 labels) | ||
| 29 | |||
| 30 | Returns: | ||
| 31 | result dict with the following fields: | ||
| 32 | - genre: primary music genre (e.g. 流行, 摇滚) | ||
| 33 | - emotion: list of emotions | ||
| 34 | - emotional_intensity: emotional intensity | ||
| 35 | - vocal_texture: vocal texture | ||
| 36 | - vocal_description: vocal texture description | ||
| 37 | - visual_concept: visual concept | ||
| 38 | - language: language | ||
| 39 | - bpm: beats per minute (optional) | ||
| 40 | - lyrics: lyrics list (optional) | ||
| 41 | - _model: model name used | ||
| 42 | - _token_info: token usage info | ||
| 43 | |||
| 44 | Example: | ||
| 45 | >>> result = analyze_music( | ||
| 46 | ... metadata={"title": "稻香", "artist": "周杰伦"}, | ||
| 47 | ... music_url="https://example.com/music.mp3", | ||
| 48 | ... provider="qwen", | ||
| 49 | ... extract_lyrics=False, | ||
| 50 | ... ) | ||
| 51 | >>> print(result["genre"]) | ||
| 52 | 流行 | ||
| 53 | """ | ||
| 54 | if provider is None: | ||
| 55 | provider = os.getenv("DEFAULT_MUSIC_ANALYZER", "qwen") | ||
| 56 | |||
| 57 | analyzer = AnalyzerFactory.get_analyzer(provider=provider) | ||
| 58 | |||
| 59 | return analyzer.analyze( | ||
| 60 | metadata=metadata, | ||
| 61 | music_url=music_url, | ||
| 62 | extract_lyrics=extract_lyrics, | ||
| 63 | label_level=label_level, | ||
| 64 | ) | ||
| 65 | |||
| 66 | |||
| 67 | def analyze_music_with_qwen( | ||
| 68 | metadata: Dict[str, Any], | ||
| 69 | music_url: str, | ||
| 70 | extract_lyrics: bool = False, | ||
| 71 | label_level: int = 0, | ||
| 72 | ) -> Optional[Dict[str, Any]]: | ||
| 73 | """Analyze music with Qwen (通义千问)""" | ||
| 74 | return analyze_music( | ||
| 75 | metadata=metadata, | ||
| 76 | music_url=music_url, | ||
| 77 | provider="qwen", | ||
| 78 | extract_lyrics=extract_lyrics, | ||
| 79 | label_level=label_level, | ||
| 80 | ) | ||
| 81 | |||
| 82 | |||
| 83 | def analyze_music_with_doubao( | ||
| 84 | metadata: Dict[str, Any], | ||
| 85 | music_url: str, | ||
| 86 | extract_lyrics: bool = False, | ||
| 87 | label_level: int = 0, | ||
| 88 | ) -> Optional[Dict[str, Any]]: | ||
| 89 | """Analyze music with Volcengine Doubao (豆包)""" | ||
| 90 | return analyze_music( | ||
| 91 | metadata=metadata, | ||
| 92 | music_url=music_url, | ||
| 93 | provider="doubao", | ||
| 94 | extract_lyrics=extract_lyrics, | ||
| 95 | label_level=label_level, | ||
| 96 | ) | ||
| 97 | |||
| 98 | |||
| 99 | def analyze_music_lyrics_only( | ||
| 100 | metadata: Dict[str, Any], | ||
| 101 | music_url: str, | ||
| 102 | provider: Optional[str] = None, | ||
| 103 | ) -> Optional[Dict[str, Any]]: | ||
| 104 | """Transcribe lyrics only, without repeating the basic label analysis""" | ||
| 105 | if provider is None: | ||
| 106 | provider = os.getenv("DEFAULT_MUSIC_ANALYZER", "qwen") | ||
| 107 | |||
| 108 | analyzer = AnalyzerFactory.get_analyzer(provider=provider) | ||
| 109 | if hasattr(analyzer, "analyze_lyrics_only"): | ||
| 110 | return analyzer.analyze_lyrics_only(metadata=metadata, music_url=music_url) | ||
| 111 | |||
| 112 | # Fallback for providers that do not implement lyrics_only | ||
| 113 | result = analyzer.analyze( | ||
| 114 | metadata=metadata, | ||
| 115 | music_url=music_url, | ||
| 116 | extract_lyrics=True, | ||
| 117 | label_level=0, | ||
| 118 | ) | ||
| 119 | if isinstance(result, dict): | ||
| 120 | lyrics = result.get("lyrics", []) | ||
| 121 | return { | ||
| 122 | "lyrics": lyrics if isinstance(lyrics, list) else [], | ||
| 123 | "_model": result.get("_model"), | ||
| 124 | "_token_info": result.get("_token_info"), | ||
| 125 | } | ||
| 126 | return None | ||
| 127 | |||
| 128 | |||
| 129 | def get_available_providers() -> list: | ||
| 130 | """List the available providers""" | ||
| 131 | return AnalyzerFactory.list_providers() |
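`analyze_music_lyrics_only` uses `hasattr` capability detection: providers implementing `analyze_lyrics_only` get the cheap dedicated path, while the rest fall back to a full `analyze(extract_lyrics=True)` call whose lyrics are then extracted. A standalone sketch of the pattern (the provider classes here are illustrative, not the real analyzers):

```python
class FastProvider:
    def analyze_lyrics_only(self, url):
        return {"lyrics": [{"time": "00:01", "text": "hello"}]}

class BasicProvider:
    def analyze(self, url, extract_lyrics=True):
        return {"genre": "流行", "lyrics": [{"time": "00:01", "text": "hello"}]}

def lyrics_only(provider, url):
    # Prefer the dedicated lyrics path when the provider implements it
    if hasattr(provider, "analyze_lyrics_only"):
        return provider.analyze_lyrics_only(url)
    # Fallback: run the full analysis and keep only the lyrics
    result = provider.analyze(url, extract_lyrics=True)
    lyrics = result.get("lyrics", []) if isinstance(result, dict) else []
    return {"lyrics": lyrics if isinstance(lyrics, list) else []}

print(lyrics_only(FastProvider(), "u")["lyrics"][0]["text"])   # hello
print(sorted(lyrics_only(BasicProvider(), "u").keys()))
```

The isinstance checks mirror the defensive normalization in the helper: a provider returning `None` or a malformed `lyrics` field still yields a well-typed `{"lyrics": [...]}`.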
app/middleware/music_analyze/prompts.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | """ | ||
| 3 | Prompt-template builder for music analysis | ||
| 4 | Prompts are loaded from external template files so they can be changed without touching code | ||
| 5 | """ | ||
| 6 | |||
| 7 | import os | ||
| 8 | from pathlib import Path | ||
| 9 | from typing import Dict, Any, Optional | ||
| 10 | |||
| 11 | |||
| 12 | # Template file paths (moved to app/prompts/step2_music_decode) | ||
| 13 | PROMPTS_DIR = Path(__file__).parent.parent.parent / "prompts" / "step2_music_decode" | ||
| 14 | SYSTEM_PROMPT_FILE = PROMPTS_DIR / "music_analyze_system_prompt.md" | ||
| 15 | SYSTEM_PROMPT_PART_A_FILE = PROMPTS_DIR / "music_analyze_system_prompt_part_a.md" | ||
| 16 | SYSTEM_PROMPT_PART_B_FILE = PROMPTS_DIR / "music_analyze_system_prompt_part_b.md" | ||
| 17 | USER_PROMPT_FILE = PROMPTS_DIR / "music_analyze_user_prompt.md" | ||
| 18 | LYRICS_ONLY_PROMPT_FILE = PROMPTS_DIR / "music_lyrics_only_prompt.md" | ||
| 19 | |||
| 20 | |||
| 21 | def load_template(template_path: Path) -> str: | ||
| 22 | """ | ||
| 23 | Load a template from a file | ||
| 24 | |||
| 25 | Args: | ||
| 26 | template_path: path to the template file | ||
| 27 | |||
| 28 | Returns: | ||
| 29 | the template content as a string | ||
| 30 | """ | ||
| 31 | if not template_path.exists(): | ||
| 32 | raise FileNotFoundError(f"模板文件不存在: {template_path}") | ||
| 33 | |||
| 34 | with open(template_path, "r", encoding="utf-8") as f: | ||
| 35 | content = f.read() | ||
| 36 | |||
| 37 | # Strip only top-of-file Markdown comments (lines starting with a single #) | ||
| 38 | # Keep ## heading lines and the body content | ||
| 39 | lines = content.split("\n") | ||
| 40 | filtered_lines = [] | ||
| 41 | in_header = True | ||
| 42 | |||
| 43 | for line in lines: | ||
| 44 | stripped = line.strip() | ||
| 45 | # Keep blank lines | ||
| 46 | if not stripped: | ||
| 47 | filtered_lines.append(line) | ||
| 48 | continue | ||
| 49 | |||
| 50 | # While still in the file header, skip single-line comments (# but not ##) | ||
| 51 | if in_header and stripped.startswith("#") and not stripped.startswith("##"): | ||
| 52 | continue | ||
| 53 | |||
| 54 | # A ## heading or any body line ends the header | ||
| 55 | in_header = False | ||
| 56 | filtered_lines.append(line) | ||
| 57 | |||
| 58 | return "\n".join(filtered_lines) | ||
| 59 | |||
| 60 | |||
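The header-stripping rule in `load_template` keeps blank lines and `##` headings but drops leading single-`#` comment lines until the first heading or body line; after that point even single-`#` lines survive. A condensed standalone version showing the behavior (`strip_header_comments` is an illustrative name):

```python
def strip_header_comments(content: str) -> str:
    lines, out, in_header = content.split("\n"), [], True
    for line in lines:
        s = line.strip()
        if not s:
            out.append(line)   # blank lines are kept, even in the header
            continue
        if in_header and s.startswith("#") and not s.startswith("##"):
            continue           # drop top-of-file single-# comments
        in_header = False      # first ## heading or body line ends the header
        out.append(line)
    return "\n".join(out)

template = "# internal note\n\n## Role\nYou are an analyst.\n# this line is kept"
print(strip_header_comments(template))
```

This lets template authors keep editorial notes at the top of a `.md` prompt file without leaking them into the prompt sent to the model.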
| 61 | class PromptBuilder: | ||
| 62 | """Prompt-template builder for music analysis""" | ||
| 63 | |||
| 64 | def __init__(self, label_level: int = 0): | ||
| 65 | """ | ||
| 66 | Initialize the prompt builder | ||
| 67 | |||
| 68 | Args: | ||
| 69 | label_level: label level (0: level-1 labels, 1: level-1 + level-2 labels) | ||
| 70 | """ | ||
| 71 | self.label_level = label_level | ||
| 72 | |||
| 73 | def build_system_prompt(self) -> str: | ||
| 74 | """Build the system prompt - loads the static template directly""" | ||
| 75 | return load_template(SYSTEM_PROMPT_FILE) | ||
| 76 | |||
| 77 | def build_system_prompt_part_a(self) -> str: | ||
| 78 | """Build system prompt part A""" | ||
| 79 | return load_template(SYSTEM_PROMPT_PART_A_FILE) | ||
| 80 | |||
| 81 | def build_system_prompt_part_b(self) -> str: | ||
| 82 | """Build system prompt part B""" | ||
| 83 | return load_template(SYSTEM_PROMPT_PART_B_FILE) | ||
| 84 | |||
| 85 | def build_metadata_section(self, metadata: Optional[Dict[str, Any]] = None) -> str: | ||
| 86 | """Build the metadata section""" | ||
| 87 | if not metadata: | ||
| 88 | return "" | ||
| 89 | |||
| 90 | sections = ["## 音乐元数据"] | ||
| 91 | for key, value in metadata.items(): | ||
| 92 | if key.startswith("_"): | ||
| 93 | continue | ||
| 94 | if value and str(value).strip(): | ||
| 95 | sections.append(f"- {key}: {value}") | ||
| 96 | sections.append("") | ||
| 97 | return "\n".join(sections) | ||
| 98 | |||
| 99 | def build_output_format( | ||
| 100 | self, | ||
| 101 | include_lyrics: bool = False, | ||
| 102 | include_bpm: bool = True, | ||
| 103 | ) -> str: | ||
| 104 | """Build the output-format specification""" | ||
| 105 | if include_lyrics and include_bpm: | ||
| 106 | format_spec = """{ | ||
| 107 | "genre": "", | ||
| 108 | "sub_genre": "", | ||
| 109 | "language": "", | ||
| 110 | "vocal_type": "", | ||
| 111 | "vocal_description": "", | ||
| 112 | "emotion": [""], | ||
| 113 | "scene": [""], | ||
| 114 | "age": "", | ||
| 115 | "rhythm_intensity": "", | ||
| 116 | "is_sinking": false, | ||
| 117 | "song_description": "", | ||
| 118 | "visual_concept": "", | ||
| 119 | "emotional_intensity": "", | ||
| 120 | "bpm": 0, | ||
| 121 | "lyrics": [{"time": "", "text": ""}] | ||
| 122 | }""" | ||
| 123 | elif include_bpm: | ||
| 124 | format_spec = """{ | ||
| 125 | "genre": "", | ||
| 126 | "sub_genre": "", | ||
| 127 | "language": "", | ||
| 128 | "vocal_type": "", | ||
| 129 | "vocal_description": "", | ||
| 130 | "emotion": [""], | ||
| 131 | "scene": [""], | ||
| 132 | "age": "", | ||
| 133 | "rhythm_intensity": "", | ||
| 134 | "is_sinking": false, | ||
| 135 | "song_description": "", | ||
| 136 | "visual_concept": "", | ||
| 137 | "emotional_intensity": "", | ||
| 138 | "bpm": 0 | ||
| 139 | }""" | ||
| 140 | elif include_lyrics: | ||
| 141 | format_spec = """{ | ||
| 142 | "genre": "", | ||
| 143 | "sub_genre": "", | ||
| 144 | "language": "", | ||
| 145 | "vocal_type": "", | ||
| 146 | "vocal_description": "", | ||
| 147 | "emotion": [""], | ||
| 148 | "scene": [""], | ||
| 149 | "age": "", | ||
| 150 | "rhythm_intensity": "", | ||
| 151 | "is_sinking": false, | ||
| 152 | "song_description": "", | ||
| 153 | "visual_concept": "", | ||
| 154 | "emotional_intensity": "", | ||
| 155 | "lyrics": [{"time": "", "text": ""}] | ||
| 156 | }""" | ||
| 157 | else: | ||
| 158 | format_spec = """{ | ||
| 159 | "genre": "", | ||
| 160 | "sub_genre": "", | ||
| 161 | "language": "", | ||
| 162 | "vocal_type": "", | ||
| 163 | "vocal_description": "", | ||
| 164 | "emotion": [""], | ||
| 165 | "scene": [""], | ||
| 166 | "age": "", | ||
| 167 | "rhythm_intensity": "", | ||
| 168 | "is_sinking": false, | ||
| 169 | "song_description": "", | ||
| 170 | "visual_concept": "", | ||
| 171 | "emotional_intensity": "" | ||
| 172 | }""" | ||
| 173 | |||
| 174 | return format_spec | ||
| 175 | |||
| 176 | def build_user_prompt( | ||
| 177 | self, | ||
| 178 | metadata: Optional[Dict[str, Any]] = None, | ||
| 179 | include_lyrics: bool = False, | ||
| 180 | include_bpm: bool = True, | ||
| 181 | ) -> str: | ||
| 182 | """ | ||
| 183 | Build the complete user prompt | ||
| 184 | by loading the template file and substituting placeholders | ||
| 185 | |||
| 186 | Args: | ||
| 187 | metadata: music metadata dict (optional) | ||
| 188 | include_lyrics: whether to transcribe lyrics (kept for compatibility with existing callers) | ||
| 189 | include_bpm: whether to include BPM detection (kept for compatibility with existing callers) | ||
| 190 | |||
| 191 | Returns: | ||
| 192 | The complete user prompt | ||
| 193 | """ | ||
| 194 | # Load the template | ||
| 195 | template = load_template(USER_PROMPT_FILE) | ||
| 196 | |||
| 197 | # Prepare the replacement dict - only the metadata section is substituted | ||
| 198 | # The output format is already defined in the system prompt, so it is not repeated here | ||
| 199 | replacements = { | ||
| 200 | "{{METADATA_SECTION}}": self.build_metadata_section(metadata), | ||
| 201 | } | ||
| 202 | |||
| 203 | # Substitute the placeholders | ||
| 204 | result = template | ||
| 205 | for placeholder, value in replacements.items(): | ||
| 206 | result = result.replace(placeholder, value) | ||
| 207 | |||
| 208 | return result | ||
| 209 | |||
| 210 | def build_lyrics_only_prompt(self) -> str: | ||
| 211 | """Build the lyrics-only transcription prompt.""" | ||
| 212 | return load_template(LYRICS_ONLY_PROMPT_FILE) | ||
| 213 | |||
| 214 | |||
| 215 | def build_analyze_prompt( | ||
| 216 | metadata: Optional[Dict[str, Any]] = None, | ||
| 217 | include_lyrics: bool = False, | ||
| 218 | label_level: int = 0, | ||
| 219 | ) -> tuple[str, str]: | ||
| 220 | """ | ||
| 221 | Build the complete analysis prompt | ||
| 222 | |||
| 223 | Args: | ||
| 224 | metadata: music metadata dict (optional) | ||
| 225 | include_lyrics: whether to transcribe lyrics | ||
| 226 | label_level: label depth (0: level-1 labels only, 1: level-1 + level-2 labels) | ||
| 227 | |||
| 228 | Returns: | ||
| 229 | A (system_prompt, user_prompt) tuple | ||
| 230 | """ | ||
| 231 | builder = PromptBuilder(label_level=label_level) | ||
| 232 | system_prompt = builder.build_system_prompt() | ||
| 233 | user_prompt = builder.build_user_prompt( | ||
| 234 | metadata=metadata, | ||
| 235 | include_lyrics=include_lyrics, | ||
| 236 | include_bpm=True, | ||
| 237 | ) | ||
| 238 | return system_prompt, user_prompt | ||
| 239 | |||
| 240 | |||
| 241 | def build_analyze_prompt_part_a( | ||
| 242 | metadata: Optional[Dict[str, Any]] = None, | ||
| 243 | include_lyrics: bool = False, | ||
| 244 | label_level: int = 0, | ||
| 245 | ) -> tuple[str, str]: | ||
| 246 | """ | ||
| 247 | Build the Part A analysis prompt (labels and basic info) | ||
| 248 | """ | ||
| 249 | builder = PromptBuilder(label_level=label_level) | ||
| 250 | system_prompt = builder.build_system_prompt_part_a() | ||
| 251 | user_prompt = builder.build_user_prompt( | ||
| 252 | metadata=metadata, | ||
| 253 | include_lyrics=include_lyrics, | ||
| 254 | include_bpm=True, | ||
| 255 | ) | ||
| 256 | return system_prompt, user_prompt | ||
| 257 | |||
| 258 | |||
| 259 | def build_analyze_prompt_part_b( | ||
| 260 | metadata: Optional[Dict[str, Any]] = None, | ||
| 261 | include_lyrics: bool = False, | ||
| 262 | label_level: int = 0, | ||
| 263 | ) -> tuple[str, str]: | ||
| 264 | """ | ||
| 265 | Build the Part B analysis prompt (rhythm and visual description) | ||
| 266 | """ | ||
| 267 | builder = PromptBuilder(label_level=label_level) | ||
| 268 | system_prompt = builder.build_system_prompt_part_b() | ||
| 269 | user_prompt = builder.build_user_prompt( | ||
| 270 | metadata=metadata, | ||
| 271 | include_lyrics=include_lyrics, | ||
| 272 | include_bpm=True, | ||
| 273 | ) | ||
| 274 | return system_prompt, user_prompt | ||
| 275 | |||
| 276 | |||
| 277 | def build_lyrics_prompt() -> str: | ||
| 278 | """Build the lyrics-only transcription prompt.""" | ||
| 279 | builder = PromptBuilder() | ||
| 280 | return builder.build_lyrics_only_prompt() | ||
| 281 | |||
| 282 | |||
| 283 | # Backward compatibility: keep the original builder function | ||
| 284 | def build_user_prompt( | ||
| 285 | metadata: Optional[Dict[str, Any]] = None, | ||
| 286 | include_lyrics: bool = False, | ||
| 287 | label_level: int = 0, | ||
| 288 | ) -> str: | ||
| 289 | """Build the user prompt (compatibility wrapper).""" | ||
| 290 | builder = PromptBuilder(label_level=label_level) | ||
| 291 | return builder.build_user_prompt( | ||
| 292 | metadata=metadata, | ||
| 293 | include_lyrics=include_lyrics, | ||
| 294 | include_bpm=True, | ||
| 295 | ) |
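Between the two files, the placeholder flow that `build_user_prompt` performs is small enough to sketch standalone. The template text and metadata keys below are illustrative stand-ins (the real template lives in `USER_PROMPT_FILE`); only the `{{METADATA_SECTION}}` placeholder and the skip rules mirror the code above.

```python
TEMPLATE = "Analyze this track.\n{{METADATA_SECTION}}\nReturn JSON only."

def build_metadata_section(metadata):
    # Mirrors PromptBuilder.build_metadata_section: skip keys starting
    # with "_" and blank values; render the rest as a bullet list.
    if not metadata:
        return ""
    lines = ["## Music metadata"]
    for key, value in metadata.items():
        if key.startswith("_"):
            continue
        if value and str(value).strip():
            lines.append(f"- {key}: {value}")
    lines.append("")
    return "\n".join(lines)

metadata = {"artist": "Jay Chou", "_row_index": 7, "album": ""}
prompt = TEMPLATE.replace("{{METADATA_SECTION}}",
                          build_metadata_section(metadata))
print(prompt)
```

Keys beginning with an underscore and blank values never reach the prompt, which keeps internal bookkeeping fields out of the model input.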
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | """ | ||
| 3 | Tongyi Qianwen (Qwen) music analyzer implementation | ||
| 4 | """ | ||
| 5 | |||
| 6 | import os | ||
| 7 | import time | ||
| 8 | import tempfile | ||
| 9 | import subprocess | ||
| 10 | import threading | ||
| 11 | import hashlib | ||
| 12 | import csv | ||
| 13 | from datetime import datetime | ||
| 14 | from pathlib import Path | ||
| 15 | import requests | ||
| 16 | import logging | ||
| 17 | from typing import Dict, Any, Optional, Tuple, List | ||
| 18 | from concurrent.futures import ThreadPoolExecutor | ||
| 19 | |||
| 20 | from .base import AudioAnalyzer | ||
| 21 | from .prompts import ( | ||
| 22 | build_analyze_prompt, | ||
| 23 | build_lyrics_prompt, | ||
| 24 | ) | ||
| 25 | from .audio_features import ( | ||
| 26 | extract_audio_features, | ||
| 27 | extract_beat_timestamps, | ||
| 28 | extract_emotion_curve, | ||
| 29 | aggregate_emotion_by_segments, | ||
| 30 | ) | ||
| 31 | |||
| 32 | # Use the project-wide configuration | ||
| 33 | from app.core.config import settings | ||
| 34 | |||
| 35 | logger = logging.getLogger(__name__) | ||
| 36 | |||
| 37 | MUSIC_MAPPING_HEADERS = [ | ||
| 38 | "song_id", | ||
| 39 | "audio_file_name", | ||
| 40 | "audio_file_path", | ||
| 41 | "source_url", | ||
| 42 | "updated_at", | ||
| 43 | ] | ||
| 44 | |||
| 45 | MUSIC_MAPPING_HEADER_ALIASES = { | ||
| 46 | "song_id": ("song_id", "歌曲ID"), | ||
| 47 | "audio_file_name": ("audio_file_name", "音频文件名"), | ||
| 48 | "audio_file_path": ("audio_file_path", "音频文件路径"), | ||
| 49 | "source_url": ("source_url", "原始URL"), | ||
| 50 | "updated_at": ("updated_at", "更新时间"), | ||
| 51 | } | ||
| 52 | |||
| 53 | |||
| 54 | class QwenAnalyzer(AudioAnalyzer): | ||
| 55 | """Tongyi Qianwen (Qwen) music analyzer""" | ||
| 56 | |||
| 57 | def __init__( | ||
| 58 | self, | ||
| 59 | api_key: Optional[str] = None, | ||
| 60 | base_url: Optional[str] = None, | ||
| 61 | model: Optional[str] = None, | ||
| 62 | max_retries: int = 3, | ||
| 63 | ): | ||
| 64 | """ | ||
| 65 | Initialize the Qwen analyzer | ||
| 66 | |||
| 67 | Args: | ||
| 68 | api_key: API key (defaults to the QWEN_API_KEY environment variable) | ||
| 69 | base_url: API base URL (defaults to the configured value) | ||
| 70 | model: model name (default: qwen3-omni-flash) | ||
| 71 | max_retries: maximum number of retries | ||
| 72 | (timeouts are read from settings: QWEN_TIMEOUT / QWEN_LYRICS_TIMEOUT) | ||
| 73 | """ | ||
| 74 | # Prefer explicit arguments, then fall back to the shared settings | ||
| 75 | if api_key is None: | ||
| 76 | # Priority: QWEN_API_KEY -> QWEN_DASHSCOPE_API_KEY | ||
| 77 | api_key = settings.QWEN_API_KEY or settings.QWEN_DASHSCOPE_API_KEY | ||
| 78 | self.api_key = api_key | ||
| 79 | self.base_url = ( | ||
| 80 | base_url | ||
| 81 | or settings.QWEN_BASE_URL | ||
| 82 | or "https://dashscope.aliyuncs.com/compatible-mode/v1" | ||
| 83 | ) | ||
| 84 | self.model = model or settings.QWEN_MODEL or "qwen3-omni-flash" | ||
| 85 | self.timeout = settings.QWEN_TIMEOUT or 15.0 | ||
| 86 | self.lyrics_timeout = settings.QWEN_LYRICS_TIMEOUT or 90.0 | ||
| 87 | self.max_retries = max_retries or settings.QWEN_MAX_RETRIES or 3 | ||
| 88 | |||
| 89 | self._client = None | ||
| 90 | self._project_root = Path(__file__).resolve().parents[3] | ||
| 91 | self._music_dir = self._resolve_music_dir() | ||
| 92 | self._music_mapping_path = self._resolve_music_mapping_path() | ||
| 93 | self._mapping_lock = threading.Lock() | ||
| 94 | self._mapping_seen: set[tuple[str, str]] = self._load_existing_mapping_keys() | ||
| 95 | |||
| 96 | def _resolve_music_dir(self) -> Path: | ||
| 97 | raw_dir = str(getattr(settings, "MUSIC_DOWNLOAD_DIR", "music") or "music").strip() | ||
| 98 | path = Path(raw_dir) | ||
| 99 | if not path.is_absolute(): | ||
| 100 | path = self._project_root / path | ||
| 101 | path.mkdir(parents=True, exist_ok=True) | ||
| 102 | return path | ||
| 103 | |||
| 104 | def _resolve_music_mapping_path(self) -> Path: | ||
| 105 | raw_file = str( | ||
| 106 | getattr(settings, "MUSIC_MAPPING_FILE", "music/music_file_mapping.csv") | ||
| 107 | or "music/music_file_mapping.csv" | ||
| 108 | ).strip() | ||
| 109 | path = Path(raw_file) | ||
| 110 | if not path.is_absolute(): | ||
| 111 | path = self._project_root / path | ||
| 112 | path.parent.mkdir(parents=True, exist_ok=True) | ||
| 113 | return path | ||
| 114 | |||
| 115 | def _load_existing_mapping_keys(self) -> set[tuple[str, str]]: | ||
| 116 | if not self._music_mapping_path.exists(): | ||
| 117 | return set() | ||
| 118 | seen: set[tuple[str, str]] = set() | ||
| 119 | try: | ||
| 120 | with open(self._music_mapping_path, "r", encoding="utf-8-sig", newline="") as f: | ||
| 121 | reader = csv.DictReader(f) | ||
| 122 | for row in reader: | ||
| 123 | song_id = self._get_mapping_value(row, "song_id") | ||
| 124 | file_path = self._get_mapping_value(row, "audio_file_path") | ||
| 125 | if file_path: | ||
| 126 | try: | ||
| 127 | file_path = str(Path(file_path).resolve()) | ||
| 128 | except Exception: | ||
| 129 | pass | ||
| 130 | seen.add((song_id, file_path)) | ||
| 131 | except Exception: | ||
| 132 | return set() | ||
| 133 | return seen | ||
| 134 | |||
| 135 | def _get_mapping_value(self, row: Dict[str, Any], field: str) -> str: | ||
| 136 | for alias in MUSIC_MAPPING_HEADER_ALIASES.get(field, (field,)): | ||
| 137 | value = row.get(alias) | ||
| 138 | if value is not None and str(value).strip(): | ||
| 139 | return str(value).strip() | ||
| 140 | return "" | ||
| 141 | |||
| 142 | def _extract_song_id(self, metadata: Optional[Dict[str, Any]]) -> str: | ||
| 143 | if not metadata: | ||
| 144 | return "" | ||
| 145 | for key in ("歌曲ID", "song_id", "id", "track_id", "tmeid", "tmeID", "TMEID"): | ||
| 146 | value = metadata.get(key) | ||
| 147 | if value is not None and str(value).strip(): | ||
| 148 | return str(value).strip() | ||
| 149 | return "" | ||
| 150 | |||
| 151 | def _sanitize_filename_part(self, value: str) -> str: | ||
| 152 | safe_chars = [] | ||
| 153 | for ch in value: | ||
| 154 | if ch.isalnum() or ch in {"-", "_", "."}: | ||
| 155 | safe_chars.append(ch) | ||
| 156 | else: | ||
| 157 | safe_chars.append("_") | ||
| 158 | cleaned = "".join(safe_chars).strip("._") | ||
| 159 | return cleaned[:80] if cleaned else "unknown" | ||
| 160 | |||
| 161 | def _build_music_file_path( | ||
| 162 | self, | ||
| 163 | music_url: str, | ||
| 164 | ext: str, | ||
| 165 | metadata: Optional[Dict[str, Any]] = None, | ||
| 166 | ) -> Path: | ||
| 167 | song_id = self._extract_song_id(metadata) | ||
| 168 | song_part = self._sanitize_filename_part(song_id or "unknown") | ||
| 169 | url_hash = hashlib.md5(music_url.encode("utf-8")).hexdigest()[:12] | ||
| 170 | return self._music_dir / f"{song_part}_{url_hash}{ext}" | ||
| 171 | |||
| 172 | def _append_music_mapping( | ||
| 173 | self, | ||
| 174 | file_path: Path, | ||
| 175 | music_url: str, | ||
| 176 | metadata: Optional[Dict[str, Any]] = None, | ||
| 177 | ) -> None: | ||
| 178 | song_id = self._extract_song_id(metadata) | ||
| 179 | mapping_key = (song_id, str(file_path.resolve())) | ||
| 180 | |||
| 181 | with self._mapping_lock: | ||
| 182 | if mapping_key in self._mapping_seen: | ||
| 183 | return | ||
| 184 | |||
| 185 | write_header = not self._music_mapping_path.exists() | ||
| 186 | encoding = "utf-8-sig" if write_header else "utf-8" | ||
| 187 | with open(self._music_mapping_path, "a", encoding=encoding, newline="") as f: | ||
| 188 | writer = csv.DictWriter( | ||
| 189 | f, | ||
| 190 | fieldnames=MUSIC_MAPPING_HEADERS, | ||
| 191 | ) | ||
| 192 | if write_header: | ||
| 193 | writer.writeheader() | ||
| 194 | writer.writerow( | ||
| 195 | { | ||
| 196 | "song_id": song_id, | ||
| 197 | "audio_file_name": file_path.name, | ||
| 198 | "audio_file_path": str(file_path.resolve()), | ||
| 199 | "source_url": music_url, | ||
| 200 | "updated_at": datetime.now().isoformat(timespec="seconds"), | ||
| 201 | } | ||
| 202 | ) | ||
| 203 | self._mapping_seen.add(mapping_key) | ||
| 204 | |||
| 205 | def _is_persisted_music_file(self, file_path: str) -> bool: | ||
| 206 | try: | ||
| 207 | candidate = Path(file_path).resolve() | ||
| 208 | return candidate.parent == self._music_dir.resolve() | ||
| 209 | except Exception: | ||
| 210 | return False | ||
| 211 | |||
| 212 | def _get_client(self): | ||
| 213 | """Get the OpenAI-compatible client.""" | ||
| 214 | if self._client is None: | ||
| 215 | from openai import OpenAI | ||
| 216 | |||
| 217 | self._client = OpenAI( | ||
| 218 | api_key=self.api_key, | ||
| 219 | base_url=self.base_url, | ||
| 220 | timeout=self.timeout, | ||
| 221 | max_retries=0, | ||
| 222 | ) | ||
| 223 | return self._client | ||
| 224 | |||
| 225 | def get_provider_name(self) -> str: | ||
| 226 | return "qwen" | ||
| 227 | |||
| 228 | def get_model_name(self) -> str: | ||
| 229 | return self.model | ||
| 230 | |||
| 231 | def _call_songformer(self, music_url: str) -> Optional[Dict]: | ||
| 232 | """ | ||
| 233 | Call the SongFormer service to get the song structure and climax points | ||
| 234 | |||
| 235 | Args: | ||
| 236 | music_url: music file URL | ||
| 237 | |||
| 238 | Returns: | ||
| 239 | The full data dict returned by SongFormer | ||
| 240 | """ | ||
| 241 | songformer_url = getattr(settings, "SONGFORMER_URL", None) | ||
| 242 | if not songformer_url: | ||
| 243 | print(" [Qwen] SongFormer URL not configured; skipping climax-point analysis") | ||
| 244 | return None | ||
| 245 | |||
| 246 | try: | ||
| 247 | print(" [Qwen] Calling SongFormer service...") | ||
| 248 | resp = requests.post( | ||
| 249 | songformer_url, | ||
| 250 | json={"url": music_url, "chorus_k": 3}, | ||
| 251 | timeout=60, | ||
| 252 | ) | ||
| 253 | resp.raise_for_status() | ||
| 254 | data = resp.json() | ||
| 255 | print(" [Qwen] SongFormer call succeeded") | ||
| 256 | return data | ||
| 257 | except Exception as e: | ||
| 258 | print(f" [Qwen] SongFormer call failed: {e}") | ||
| 259 | return None | ||
| 260 | |||
| 261 | def _extract_climax_point(self, songformer_data: Optional[Dict]) -> str: | ||
| 262 | """ | ||
| 263 | Extract the climax point from SongFormer data | ||
| 264 | |||
| 265 | Args: | ||
| 266 | songformer_data: data returned by SongFormer | ||
| 267 | |||
| 268 | Returns: | ||
| 269 | str: "最强" (strongest), "强" (strong), or "" | ||
| 270 | """ | ||
| 271 | if not songformer_data: | ||
| 272 | return "" | ||
| 273 | |||
| 274 | # First try the climax_points field (legacy format) | ||
| 275 | climax_points = songformer_data.get("climax_points", {}) | ||
| 276 | if climax_points: | ||
| 277 | # Check for a strongest climax | ||
| 278 | if climax_points.get("strongest_climax"): | ||
| 279 | return "最强" | ||
| 280 | # Check for a strong climax | ||
| 281 | if climax_points.get("strong_climax"): | ||
| 282 | return "强" | ||
| 283 | |||
| 284 | # Otherwise read the top_k_chorus field (new format) | ||
| 285 | top_k_chorus = songformer_data.get("top_k_chorus", []) | ||
| 286 | if isinstance(top_k_chorus, list) and len(top_k_chorus) > 0: | ||
| 287 | # Sort by score; the highest score is treated as the strongest climax | ||
| 288 | sorted_chorus = sorted( | ||
| 289 | [ | ||
| 290 | c | ||
| 291 | for c in top_k_chorus | ||
| 292 | if isinstance(c, dict) and c.get("score") is not None | ||
| 293 | ], | ||
| 294 | key=lambda x: x.get("score", 0), | ||
| 295 | reverse=True, | ||
| 296 | ) | ||
| 297 | if sorted_chorus: | ||
| 298 | # A top score > 7.0 counts as "最强" (strongest), otherwise "强" (strong) | ||
| 299 | highest_score = sorted_chorus[0].get("score", 0) | ||
| 300 | if highest_score > 7.0: | ||
| 301 | return "最强" | ||
| 302 | else: | ||
| 303 | return "强" | ||
| 304 | |||
| 305 | return "" | ||
| 306 | |||
| 307 | def _build_climax_points(self, songformer_data: Optional[Dict]) -> Dict[str, Any]: | ||
| 308 | """ | ||
| 309 | Build the full climax_points object from SongFormer data | ||
| 310 | |||
| 311 | Args: | ||
| 312 | songformer_data: data returned by SongFormer | ||
| 313 | |||
| 314 | Returns: | ||
| 315 | A dict containing strong_climax and strongest_climax | ||
| 316 | """ | ||
| 317 | if not songformer_data: | ||
| 318 | return { | ||
| 319 | "strong_climax": None, | ||
| 320 | "strongest_climax": None, | ||
| 321 | "analysis_time": 0.0, | ||
| 322 | } | ||
| 323 | |||
| 324 | # First try the climax_points field (legacy format) | ||
| 325 | climax_points = songformer_data.get("climax_points", {}) | ||
| 326 | if climax_points and ( | ||
| 327 | climax_points.get("strong_climax") or climax_points.get("strongest_climax") | ||
| 328 | ): | ||
| 329 | return { | ||
| 330 | "strong_climax": climax_points.get("strong_climax"), | ||
| 331 | "strongest_climax": climax_points.get("strongest_climax"), | ||
| 332 | "analysis_time": climax_points.get("analysis_time", 0.0), | ||
| 333 | } | ||
| 334 | |||
| 335 | # Otherwise build from the top_k_chorus field (new format) | ||
| 336 | top_k_chorus = songformer_data.get("top_k_chorus", []) | ||
| 337 | segments = songformer_data.get("segments", []) | ||
| 338 | |||
| 339 | if isinstance(top_k_chorus, list) and len(top_k_chorus) > 0: | ||
| 340 | # Sort by score | ||
| 341 | sorted_chorus = sorted( | ||
| 342 | [ | ||
| 343 | c | ||
| 344 | for c in top_k_chorus | ||
| 345 | if isinstance(c, dict) and c.get("score") is not None | ||
| 346 | ], | ||
| 347 | key=lambda x: x.get("score", 0), | ||
| 348 | reverse=True, | ||
| 349 | ) | ||
| 350 | |||
| 351 | if sorted_chorus: | ||
| 352 | # The highest score becomes strongest_climax | ||
| 353 | highest = sorted_chorus[0] | ||
| 354 | highest_score = highest.get("score", 0) | ||
| 355 | |||
| 356 | # Find the matching section label | ||
| 357 | start_time = highest.get("start", 0) | ||
| 358 | section_label = "chorus" | ||
| 359 | for seg in segments: | ||
| 360 | if isinstance(seg, dict): | ||
| 361 | seg_start = seg.get("start", 0) | ||
| 362 | seg_end = seg.get("end", 0) | ||
| 363 | if seg_start <= start_time < seg_end: | ||
| 364 | section_label = seg.get("label", "chorus") | ||
| 365 | break | ||
| 366 | |||
| 367 | strongest_climax = { | ||
| 368 | "time": start_time, | ||
| 369 | "intensity": "strongest", | ||
| 370 | "section_label": section_label, | ||
| 371 | "reason": f"Highest chorus score: {highest_score:.2f}", | ||
| 372 | } | ||
| 373 | |||
| 374 | # The second highest becomes strong_climax, if present | ||
| 375 | strong_climax = None | ||
| 376 | if len(sorted_chorus) > 1: | ||
| 377 | second = sorted_chorus[1] | ||
| 378 | second_score = second.get("score", 0) | ||
| 379 | second_start = second.get("start", 0) | ||
| 380 | |||
| 381 | # Find the matching section label | ||
| 382 | second_section_label = "chorus" | ||
| 383 | for seg in segments: | ||
| 384 | if isinstance(seg, dict): | ||
| 385 | seg_start = seg.get("start", 0) | ||
| 386 | seg_end = seg.get("end", 0) | ||
| 387 | if seg_start <= second_start < seg_end: | ||
| 388 | second_section_label = seg.get("label", "chorus") | ||
| 389 | break | ||
| 390 | |||
| 391 | strong_climax = { | ||
| 392 | "time": second_start, | ||
| 393 | "intensity": "strong", | ||
| 394 | "section_label": second_section_label, | ||
| 395 | "reason": f"Second highest chorus score: {second_score:.2f}", | ||
| 396 | } | ||
| 397 | |||
| 398 | return { | ||
| 399 | "strong_climax": strong_climax, | ||
| 400 | "strongest_climax": strongest_climax, | ||
| 401 | "analysis_time": 0.0, | ||
| 402 | } | ||
| 403 | |||
| 404 | return { | ||
| 405 | "strong_climax": None, | ||
| 406 | "strongest_climax": None, | ||
| 407 | "analysis_time": 0.0, | ||
| 408 | } | ||
| 409 | |||
| 410 | def analyze( | ||
| 411 | self, | ||
| 412 | metadata: Dict[str, Any], | ||
| 413 | music_url: str, | ||
| 414 | extract_lyrics: bool = False, | ||
| 415 | label_level: int = 0, | ||
| 416 | ) -> Optional[Dict[str, Any]]: | ||
| 417 | """ | ||
| 418 | Analyze the music | ||
| 419 | |||
| 420 | Args: | ||
| 421 | metadata: music metadata | ||
| 422 | music_url: music file URL | ||
| 423 | extract_lyrics: whether to transcribe lyrics | ||
| 424 | label_level: label depth | ||
| 425 | |||
| 426 | Returns: | ||
| 427 | Analysis result dict | ||
| 428 | """ | ||
| 429 | client = self._get_client() | ||
| 430 | |||
| 431 | light_mode = bool(getattr(settings, "MUSIC_ANALYZE_LIGHT_MODE", True)) | ||
| 432 | songformer_data = None if light_mode else self._call_songformer(music_url) | ||
| 433 | |||
| 434 | # Download the audio and extract local features | ||
| 435 | local_features = {} | ||
| 436 | tmp_file_path = None | ||
| 437 | try: | ||
| 438 | if light_mode: | ||
| 439 | print(" [Qwen] Light mode: extracting BPM only") | ||
| 440 | tmp_file_path, _ = self._download_audio(music_url, metadata=metadata) | ||
| 441 | beat_info = extract_beat_timestamps(tmp_file_path) | ||
| 442 | local_features = {"bpm": round(beat_info.tempo)} | ||
| 443 | print(f" [Qwen] Local features: BPM={local_features.get('bpm')}") | ||
| 444 | else: | ||
| 445 | print(" [Qwen] Downloading audio and extracting local features...") | ||
| 446 | tmp_file_path, _ = self._download_audio(music_url, metadata=metadata) | ||
| 447 | |||
| 448 | # Use the section structure from songformer to aggregate emotions | ||
| 449 | segments = songformer_data.get("segments") if songformer_data else None | ||
| 450 | local_features = self._extract_local_features(tmp_file_path, segments=segments) | ||
| 451 | |||
| 452 | # Extract the climax point from the SongFormer data | ||
| 453 | climax_point = self._extract_climax_point(songformer_data) | ||
| 454 | local_features["climax_point"] = climax_point | ||
| 455 | |||
| 456 | # Build the full climax_points object | ||
| 457 | climax_points = self._build_climax_points(songformer_data) | ||
| 458 | local_features["climax_points"] = climax_points | ||
| 459 | |||
| 460 | print( | ||
| 461 | f" [Qwen] Local features: BPM={local_features.get('bpm')}, " | ||
| 462 | f"segment emotions={len(local_features.get('segment_emotions', []))}, " | ||
| 463 | f"climax point={climax_point}" | ||
| 464 | ) | ||
| 465 | except Exception as e: | ||
| 466 | print(f" [Qwen] Local feature extraction failed; falling back to LLM estimates: {e}") | ||
| 467 | finally: | ||
| 468 | # Clean up the temporary file | ||
| 469 | if ( | ||
| 470 | tmp_file_path | ||
| 471 | and os.path.exists(tmp_file_path) | ||
| 472 | and not self._is_persisted_music_file(tmp_file_path) | ||
| 473 | ): | ||
| 474 | try: | ||
| 475 | os.unlink(tmp_file_path) | ||
| 476 | except OSError: | ||
| 477 | pass | ||
| 478 | |||
| 479 | # Run the LLM analysis | ||
| 480 | if extract_lyrics: | ||
| 481 | result = self._analyze_with_lyrics(client, metadata, music_url, label_level) | ||
| 482 | else: | ||
| 483 | result = self._analyze_basic(client, metadata, music_url, label_level) | ||
| 484 | |||
| 485 | # Merge the local features into the result | ||
| 486 | if result and local_features: | ||
| 487 | # Local values override the LLM estimates | ||
| 488 | result.update(local_features) | ||
| 489 | |||
| 490 | return result | ||
| 491 | |||
| 492 | def _analyze_basic( | ||
| 493 | self, | ||
| 494 | client, | ||
| 495 | metadata: Dict[str, Any], | ||
| 496 | music_url: str, | ||
| 497 | label_level: int = 0, | ||
| 498 | ) -> Optional[Dict[str, Any]]: | ||
| 499 | """Basic analysis (no lyrics; single-round label analysis)""" | ||
| 500 | # Extract the song ID for error tracing | ||
| 501 | song_id = self._extract_song_id(metadata) | ||
| 502 | print(f" [Qwen] Analyzing audio: song_id={song_id}") | ||
| 503 | |||
| 504 | system_prompt, user_prompt = build_analyze_prompt( | ||
| 505 | metadata=metadata, | ||
| 506 | include_lyrics=False, | ||
| 507 | label_level=label_level, | ||
| 508 | ) | ||
| 509 | |||
| 510 | prompt = self._build_dashscope_prompt(system_prompt, user_prompt) | ||
| 511 | response = self._call_with_retry_dashscope(music_url, prompt, song_id=song_id, metadata=metadata) | ||
| 512 | if response is None: | ||
| 513 | return None | ||
| 514 | |||
| 515 | raw_content = response.get("content", "") | ||
| 516 | parsed = self._parse_response(raw_content) | ||
| 517 | if parsed is None: | ||
| 518 | return None | ||
| 519 | if isinstance(parsed, list): | ||
| 520 | if parsed and isinstance(parsed[0], dict): | ||
| 521 | parsed = parsed[0] | ||
| 522 | else: | ||
| 523 | return None | ||
| 524 | if not isinstance(parsed, dict): | ||
| 525 | return None | ||
| 526 | |||
| 527 | return self._normalize_result(parsed, self.model, response.get("usage")) | ||
| 528 | |||
| 529 | def _download_audio( | ||
| 530 | self, music_url: str, metadata: Optional[Dict[str, Any]] = None | ||
| 531 | ) -> Tuple[str, str]: | ||
| 532 | """ | ||
| 533 | Download the audio file into the music directory (named by URL + song ID, reusing cached copies) | ||
| 534 | |||
| 535 | Args: | ||
| 536 | music_url: audio URL | ||
| 537 | metadata: music metadata (the song ID is extracted for the mapping file) | ||
| 538 | |||
| 539 | Returns: | ||
| 540 | (local file path, file extension) | ||
| 541 | """ | ||
| 542 | # Determine the file extension | ||
| 543 | ext = ".mp3" | ||
| 544 | if "." in music_url: | ||
| 545 | url_ext = music_url.split(".")[-1].split("?")[0].lower() | ||
| 546 | if url_ext in ["mp3", "wav", "flac", "aac", "m4a", "ogg"]: | ||
| 547 | ext = f".{url_ext}" | ||
| 548 | |||
| 549 | target_path = self._build_music_file_path(music_url, ext, metadata=metadata) | ||
| 550 | |||
| 551 | if not target_path.exists(): | ||
| 552 | response = requests.get(music_url, timeout=60) | ||
| 553 | response.raise_for_status() | ||
| 554 | with open(target_path, "wb") as f: | ||
| 555 | f.write(response.content) | ||
| 556 | print(f" [Qwen] Audio saved: {target_path}") | ||
| 557 | |||
| 558 | self._append_music_mapping(target_path, music_url, metadata=metadata) | ||
| 559 | return str(target_path), ext | ||
| 560 | |||
| 561 | def _extract_local_features( | ||
| 562 | self, | ||
| 563 | audio_path: str, | ||
| 564 | segments: Optional[List[Dict[str, Any]]] = None, | ||
| 565 | ) -> Dict[str, Any]: | ||
| 566 | """ | ||
| 567 | Extract local audio features | ||
| 568 | |||
| 569 | Args: | ||
| 570 | audio_path: local audio file path | ||
| 571 | segments: section structure returned by songformer (optional), used to aggregate the emotion curve | ||
| 572 | |||
| 573 | Returns: | ||
| 574 | Dict containing the BPM, beat timestamps, and emotion curve | ||
| 575 | """ | ||
| 576 | try: | ||
| 577 | features = extract_audio_features(audio_path) | ||
| 578 | |||
| 579 | # Beat-sync point detection | ||
| 580 | beat_info = extract_beat_timestamps(audio_path) | ||
| 581 | |||
| 582 | # Emotion curve | ||
| 583 | emotion_curve = extract_emotion_curve(audio_path) | ||
| 584 | |||
| 585 | # beat_info.tempo is corrected for beat-level octave errors, so it is more accurate than features.tempo | ||
| 586 | result = { | ||
| 587 | "bpm": round(beat_info.tempo), | ||
| 588 | # Beat-sync point info | ||
| 589 | "beat_timestamps": beat_info.beat_timestamps, | ||
| 590 | "downbeat_timestamps": beat_info.downbeat_timestamps, | ||
| 591 | "beat_intervals": beat_info.beat_intervals, | ||
| 592 | } | ||
| 593 | |||
| 594 | # If a section structure is available, return emotion data aggregated per section | ||
| 595 | if segments: | ||
| 596 | segment_emotions = aggregate_emotion_by_segments(emotion_curve, segments) | ||
| 597 | result["segment_emotions"] = [ | ||
| 598 | { | ||
| 599 | "start": se.start, | ||
| 600 | "end": se.end, | ||
| 601 | "label": se.label, | ||
| 602 | "intensity": se.intensity, | ||
| 603 | "energy": se.energy, | ||
| 604 | "valence": se.valence, | ||
| 605 | "arousal": se.arousal, | ||
| 606 | "trend": se.trend, | ||
| 607 | } | ||
| 608 | for se in segment_emotions | ||
| 609 | ] | ||
| 610 | else: | ||
| 611 | # Without a section structure, return the raw emotion curve | ||
| 612 | result["emotion_curve"] = { | ||
| 613 | "timestamps": emotion_curve.timestamps, | ||
| 614 | "energy_values": emotion_curve.energy_values, | ||
| 615 | "valence_values": emotion_curve.valence_values, | ||
| 616 | "arousal_values": emotion_curve.arousal_values, | ||
| 617 | "values": emotion_curve.smoothed_curve, | ||
| 618 | } | ||
| 619 | |||
| 620 | return result | ||
| 621 | except Exception as e: | ||
| 622 | print(f" [Qwen] Local feature extraction failed: {e}") | ||
| 623 | return {} | ||
| 624 | |||
| 625 | def _analyze_with_lyrics( | ||
| 626 | self, | ||
| 627 | client, | ||
| 628 | metadata: Dict[str, Any], | ||
| 629 | music_url: str, | ||
| 630 | label_level: int = 0, | ||
| 631 | ) -> Optional[Dict[str, Any]]: | ||
| 632 | """Analysis with lyrics (single-round label analysis + concurrent lyrics transcription)""" | ||
| 633 | # Extract the song ID for error tracing | ||
| 634 | song_id = self._extract_song_id(metadata) | ||
| 635 | print(f" [Qwen] Analyzing audio: song_id={song_id}") | ||
| 636 | |||
| 637 | system_prompt, user_prompt = build_analyze_prompt( | ||
| 638 | metadata=metadata, | ||
| 639 | include_lyrics=False, | ||
| 640 | label_level=label_level, | ||
| 641 | ) | ||
| 642 | |||
| 643 | prompt = self._build_dashscope_prompt(system_prompt, user_prompt) | ||
| 644 | |||
| 645 | lyrics_prompt = build_lyrics_prompt() | ||
| 646 | |||
| 647 | messages_lyrics = self._build_messages( | ||
| 648 | "请识别这段音频中的歌词内容", lyrics_prompt, music_url | ||
| 649 | ) | ||
| 650 | |||
| 651 | print(" [Qwen] Running label analysis and lyrics transcription concurrently...") | ||
| 652 | start_time = time.time() | ||
| 653 | |||
| 654 | result_main: Optional[Dict[str, Any]] = None | ||
| 655 | usage_main: Optional[Dict[str, Any]] = None | ||
| 656 | response_lyrics = None | ||
| 657 | timing: Dict[str, float] = {} | ||
| 658 | |||
| 659 | def _timed_call_dashscope(prompt_text: str) -> tuple[Optional[Dict], float]: | ||
| 660 | call_start = time.time() | ||
| 661 | resp = self._call_with_retry_dashscope(music_url, prompt_text, song_id=song_id, metadata=metadata) | ||
| 662 | return resp, round(time.time() - call_start, 2) | ||
| 663 | |||
| 664 | futures = {} | ||
| 665 | with ThreadPoolExecutor(max_workers=2) as executor: | ||
| 666 | futures[executor.submit(_timed_call_dashscope, prompt)] = "main" | ||
| 667 | futures[executor.submit(self._timed_call_openai, client, messages_lyrics)] = "lyrics" | ||
| 668 | |||
| 669 | for future in futures: | ||
| 670 | part = futures[future] | ||
| 671 | response, part_elapsed = future.result() | ||
| 672 | if part == "lyrics": | ||
| 673 | timing["lyrics"] = part_elapsed | ||
| 674 | response_lyrics = response | ||
| 675 | continue | ||
| 676 | timing["analysis"] = part_elapsed | ||
| 677 | if response is None: | ||
| 678 | continue | ||
| 679 | raw_content = response.get("content", "") | ||
| 680 | parsed = self._parse_response(raw_content) | ||
| 681 | if parsed is None: | ||
| 682 | continue | ||
| 683 | if isinstance(parsed, list): | ||
| 684 | if parsed and isinstance(parsed[0], dict): | ||
| 685 | parsed = parsed[0] | ||
| 686 | else: | ||
| 687 | continue | ||
| 688 | if not isinstance(parsed, dict): | ||
| 689 | continue | ||
| 690 | result_main = parsed | ||
| 691 | usage_main = response.get("usage") | ||
| 692 | |||
| 693 | elapsed = time.time() - start_time | ||
| 694 | print(f" [Qwen] Concurrent calls finished, total elapsed: {elapsed:.2f}s") | ||
| 695 | |||
| 696 | if result_main is None: | ||
| 697 | return None | ||
| 698 | if not isinstance(result_main, dict): | ||
| 699 | return None | ||
| 700 | |||
| 701 | result: Dict[str, Any] = dict(result_main) | ||
| 702 | |||
| 703 | # Handle the lyrics transcription result | ||
| 704 | if response_lyrics: | ||
| 705 | raw_lyrics = response_lyrics.get("content", "") | ||
| 706 | lyrics_result = self._parse_response(raw_lyrics) | ||
| 707 | if isinstance(lyrics_result, list): | ||
| 708 | if lyrics_result and isinstance(lyrics_result[0], dict): | ||
| 709 | lyrics_result = lyrics_result[0] | ||
| 710 | if lyrics_result and "lyrics" in lyrics_result: | ||
| 711 | result["lyrics"] = lyrics_result["lyrics"] | ||
| 712 | result["_timing"] = timing | ||
| 713 | |||
| 714 | # Merge token usage info | ||
| 715 | usage: Dict[str, Any] = {} | ||
| 716 | if usage_main: | ||
| 717 | usage.update(usage_main) | ||
| 718 | if response_lyrics and response_lyrics.get("usage"): | ||
| 719 | usage_lyrics = response_lyrics["usage"] | ||
| 720 | usage = { | ||
| 721 | "prompt_tokens": usage.get("prompt_tokens", 0) | ||
| 722 | + usage_lyrics.get("prompt_tokens", 0), | ||
| 723 | "completion_tokens": usage.get("completion_tokens", 0) | ||
| 724 | + usage_lyrics.get("completion_tokens", 0), | ||
| 725 | "total_tokens": usage.get("total_tokens", 0) | ||
| 726 | + usage_lyrics.get("total_tokens", 0), | ||
| 727 | } | ||
| 728 | |||
| 729 | result["_token_info_parts"] = { | ||
| 730 | "main": usage_main, | ||
| 731 | "lyrics": response_lyrics.get("usage") if response_lyrics else None, | ||
| 732 | } | ||
| 733 | |||
| 734 | return self._normalize_result(result, self.model, usage) | ||
| 735 | |||
| 736 | def analyze_lyrics_only( | ||
| 737 | self, | ||
| 738 | metadata: Dict[str, Any], | ||
| 739 | music_url: str, | ||
| 740 | ) -> Optional[Dict[str, Any]]: | ||
| 741 | """Run lyrics recognition only, without base tag analysis (async ASR task).""" | ||
| 742 | backend = ( | ||
| 743 | str( | ||
| 744 | os.getenv("MUSIC_LYRICS_ASR_BACKEND") | ||
| 745 | or getattr(settings, "MUSIC_LYRICS_ASR_BACKEND", "funasr") | ||
| 746 | ) | ||
| 747 | .strip() | ||
| 748 | .lower() | ||
| 749 | ) | ||
| 750 | |||
| 751 | if backend == "whisper": | ||
| 752 | analyze_fn = self._analyze_lyrics_only_whisper | ||
| 753 | elif backend in {"omni", "qwen-omni", "qwen_omni"}: | ||
| 754 | # qwen-omni: at most 3 requests within one pass; on failure, fall straight back to funasr | ||
| 755 | omni_result = self._analyze_lyrics_only_qwen_omni(music_url) | ||
| 756 | if omni_result: | ||
| 757 | return omni_result | ||
| 758 | logger.warning( | ||
| 759 | "qwen-omni lyrics recognition failed; falling back to funasr (lyrics_timeout=%ss)", | ||
| 760 | self.lyrics_timeout, | ||
| 761 | ) | ||
| 762 | |||
| 763 | fallback_retry_count = 1 | ||
| 764 | fallback_retry_delay_seconds = 2.0 | ||
| 765 | for attempt in range(1, fallback_retry_count + 2): | ||
| 766 | fallback_result = self._analyze_lyrics_only_funasr(music_url) | ||
| 767 | if fallback_result: | ||
| 768 | logger.info( | ||
| 769 | "funasr fallback succeeded: attempt=%s/%s", | ||
| 770 | attempt, | ||
| 771 | fallback_retry_count + 1, | ||
| 772 | ) | ||
| 773 | return fallback_result | ||
| 774 | |||
| 775 | if attempt <= fallback_retry_count: | ||
| 776 | logger.warning( | ||
| 777 | "funasr fallback failed; retrying in %s s (%s/%s)", | ||
| 778 | fallback_retry_delay_seconds, | ||
| 779 | attempt, | ||
| 780 | fallback_retry_count + 1, | ||
| 781 | ) | ||
| 782 | time.sleep(fallback_retry_delay_seconds) | ||
| 783 | |||
| 784 | logger.warning("funasr fallback failed; falling back further to whisper") | ||
| 785 | whisper_result = self._analyze_lyrics_only_whisper(music_url) | ||
| 786 | if whisper_result: | ||
| 787 | logger.info("whisper fallback succeeded") | ||
| 788 | return whisper_result | ||
| 789 | |||
| 790 | logger.error("All lyrics-recognition fallbacks failed: qwen-omni -> funasr -> whisper") | ||
| 791 | return None | ||
| 792 | elif backend in {"fun", "funasr", "fun-asr"}: | ||
| 793 | analyze_fn = self._analyze_lyrics_only_funasr | ||
| 794 | else: | ||
| 795 | logger.error( | ||
| 796 | "Unsupported lyrics ASR backend: %s; only whisper/funasr/qwen-omni are supported", | ||
| 797 | backend, | ||
| 798 | ) | ||
| 799 | return None | ||
| 800 | |||
| 801 | retry_count = 2 | ||
| 802 | retry_delay_seconds = 2.0 | ||
| 803 | for attempt in range(1, retry_count + 2): | ||
| 804 | result = analyze_fn(music_url) | ||
| 805 | if result: | ||
| 806 | return result | ||
| 807 | if attempt <= retry_count: | ||
| 808 | logger.warning( | ||
| 809 | "Lyrics recognition failed; retrying in %s s (%d/%d): backend=%s", | ||
| 810 | retry_delay_seconds, | ||
| 811 | attempt, | ||
| 812 | retry_count, | ||
| 813 | backend, | ||
| 814 | ) | ||
| 815 | time.sleep(retry_delay_seconds) | ||
| 816 | return None | ||
| 817 | |||
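The backend dispatch above chains qwen-omni → funasr → whisper, giving each backend its own retry budget before moving on. The pattern can be sketched as a small standalone helper (names such as `run_with_fallbacks` are illustrative, not part of this module):

```python
import time
from typing import Callable, List, Optional


def run_with_fallbacks(
    backends: List[Callable[[str], Optional[dict]]],
    url: str,
    retries: int = 1,
    delay: float = 0.0,
) -> Optional[dict]:
    """Try each backend in order; retry each up to `retries` extra times."""
    for backend in backends:
        for _ in range(retries + 1):
            result = backend(url)
            if result:
                return result
            if delay:
                time.sleep(delay)
    return None
```

Each backend here returns a truthy dict on success and `None` on failure, mirroring the `_analyze_lyrics_only_*` contract.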
| 818 | def _analyze_lyrics_only_qwen_omni(self, music_url: str) -> Optional[Dict[str, Any]]: | ||
| 819 | """qwen-omni (V2) lyrics recognition flow.""" | ||
| 820 | client = self._get_client() | ||
| 821 | logger.info( | ||
| 822 | "Starting qwen-omni lyrics recognition: timeout=%ss, max_retries=%s", | ||
| 823 | self.lyrics_timeout, | ||
| 824 | 3, | ||
| 825 | ) | ||
| 826 | lyrics_prompt = build_lyrics_prompt() | ||
| 827 | messages = self._build_messages( | ||
| 828 | "请识别这段音频中的歌词内容", | ||
| 829 | lyrics_prompt, | ||
| 830 | music_url, | ||
| 831 | ) | ||
| 832 | |||
| 833 | response = self._call_with_retry(client, messages, max_retries=3) | ||
| 834 | if response is None: | ||
| 835 | return None | ||
| 836 | |||
| 837 | parsed = self._parse_response(response.get("content", "")) | ||
| 838 | payload: Any = parsed | ||
| 839 | if isinstance(parsed, dict): | ||
| 840 | payload = ( | ||
| 841 | parsed.get("lyrics") | ||
| 842 | or parsed.get("lyric") | ||
| 843 | or parsed.get("歌词") | ||
| 844 | or parsed | ||
| 845 | ) | ||
| 846 | |||
| 847 | lyrics = self._convert_qwen_omni_payload_to_lyrics(payload) | ||
| 848 | |||
| 849 | return { | ||
| 850 | "lyrics": lyrics, | ||
| 851 | "_model": self.model, | ||
| 852 | "_token_info": response.get("usage"), | ||
| 853 | "_transcription_url": None, | ||
| 854 | "_asr_task_id": None, | ||
| 855 | "_asr_backend": "qwen-omni", | ||
| 856 | } | ||
| 857 | |||
| 858 | def _convert_qwen_omni_payload_to_lyrics(self, payload: Any) -> List[Dict[str, Any]]: | ||
| 859 | """Normalize the lyric structure returned by qwen-omni into [{time, text}].""" | ||
| 860 | if payload is None: | ||
| 861 | return [] | ||
| 862 | |||
| 863 | if isinstance(payload, str): | ||
| 864 | lines = [line.strip() for line in payload.splitlines() if line.strip()] | ||
| 865 | return [{"time": None, "text": line} for line in lines] | ||
| 866 | |||
| 867 | if isinstance(payload, dict): | ||
| 868 | candidate = ( | ||
| 869 | payload.get("lyrics") | ||
| 870 | or payload.get("lines") | ||
| 871 | or payload.get("歌词") | ||
| 872 | or payload.get("lyric") | ||
| 873 | ) | ||
| 874 | return self._convert_qwen_omni_payload_to_lyrics(candidate) | ||
| 875 | |||
| 876 | if isinstance(payload, list): | ||
| 877 | lyrics: List[Dict[str, Any]] = [] | ||
| 878 | for item in payload: | ||
| 879 | if isinstance(item, str): | ||
| 880 | line = item.strip() | ||
| 881 | if line: | ||
| 882 | lyrics.append({"time": None, "text": line}) | ||
| 883 | continue | ||
| 884 | |||
| 885 | if not isinstance(item, dict): | ||
| 886 | continue | ||
| 887 | |||
| 888 | text = item.get("text") or item.get("lyric") or item.get("歌词") | ||
| 889 | if not isinstance(text, str): | ||
| 890 | text = str(text) if text is not None else "" | ||
| 891 | text = text.strip() | ||
| 892 | if not text: | ||
| 893 | continue | ||
| 894 | |||
| 895 | time_str = item.get("time") | ||
| 896 | if not isinstance(time_str, str): | ||
| 897 | time_str = None | ||
| 898 | lyrics.append({"time": time_str, "text": text}) | ||
| 899 | return lyrics | ||
| 900 | |||
| 901 | return [] | ||
| 902 | |||
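`_convert_qwen_omni_payload_to_lyrics` accepts strings, dicts, and mixed lists because the model's output shape varies. A condensed standalone sketch of that normalization (the key chain is simplified; `to_lyrics` is an illustrative name):

```python
from typing import Any, Dict, List


def to_lyrics(payload: Any) -> List[Dict[str, Any]]:
    """Normalize a lyric payload into [{time, text}]."""
    # Plain strings become one untimed line per non-empty row.
    if isinstance(payload, str):
        return [{"time": None, "text": ln.strip()} for ln in payload.splitlines() if ln.strip()]
    # Dicts are unwrapped via common keys, then recursed.
    if isinstance(payload, dict):
        return to_lyrics(payload.get("lyrics") or payload.get("lines"))
    # Lists may mix bare strings and {time, text} dicts.
    if isinstance(payload, list):
        out: List[Dict[str, Any]] = []
        for item in payload:
            if isinstance(item, str) and item.strip():
                out.append({"time": None, "text": item.strip()})
            elif isinstance(item, dict):
                text = str(item.get("text") or "").strip()
                if text:
                    t = item.get("time")
                    out.append({"time": t if isinstance(t, str) else None, "text": text})
        return out
    # Anything else (None, numbers, ...) yields no lyrics.
    return []
```

The recursion bottoms out on `None`, so a dict without any recognized key safely yields `[]`.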
| 903 | def _analyze_lyrics_only_whisper(self, music_url: str) -> Optional[Dict[str, Any]]: | ||
| 904 | """whisper-1 lyrics recognition flow (91 API).""" | ||
| 905 | try: | ||
| 906 | from dotenv import load_dotenv | ||
| 907 | |||
| 908 | load_dotenv() | ||
| 909 | except Exception: | ||
| 910 | pass | ||
| 911 | |||
| 912 | api_key = (os.getenv("API_KEY_whisper") or os.getenv("91API_KEY") or "").strip() | ||
| 913 | if not api_key: | ||
| 914 | logger.error("whisper call failed: missing env var API_KEY_whisper/91API_KEY") | ||
| 915 | return None | ||
| 916 | |||
| 917 | api_url = os.getenv( | ||
| 918 | "WHISPER_API_URL", | ||
| 919 | "https://xuedingmao.top/v1/audio/transcriptions", | ||
| 920 | ).strip() | ||
| 921 | headers = {"Authorization": f"Bearer {api_key}"} | ||
| 922 | |||
| 923 | tmp_file_path = None | ||
| 924 | upload_file_path = None | ||
| 925 | ext = ".mp3" | ||
| 926 | try: | ||
| 927 | tmp_file_path, ext = self._download_audio(music_url, metadata=None) | ||
| 928 | upload_file_path = tmp_file_path | ||
| 929 | upload_ext = ext | ||
| 930 | if ext.lower() == ".flac": | ||
| 931 | converted_wav = self._convert_audio_to_wav_for_whisper(tmp_file_path) | ||
| 932 | if converted_wav: | ||
| 933 | upload_file_path = converted_wav | ||
| 934 | upload_ext = ".wav" | ||
| 935 | logger.info("whisper upload converted from flac to wav") | ||
| 936 | |||
| 937 | filename = f"audio{upload_ext}" | ||
| 938 | print(f"Download complete: {filename}") | ||
| 939 | |||
| 940 | content_type = "audio/wav" if upload_ext == ".wav" else "audio/mpeg" | ||
| 941 | |||
| 942 | with open(upload_file_path, "rb") as audio_file: | ||
| 943 | files = { | ||
| 944 | "file": (filename, audio_file, content_type), | ||
| 945 | } | ||
| 946 | data = { | ||
| 947 | "model": "whisper-1", | ||
| 948 | "response_format": "verbose_json", | ||
| 949 | "timestamp_granularities": ["segment"], | ||
| 950 | "prompt": "没有歌词的片段用...代替,时间戳需要精准与每句歌词进行对应,对于纯音乐直接输出'纯音乐'。禁止输出歌名、作词/作曲等元数据,仅输出歌词与时间戳", | ||
| 951 | } | ||
| 952 | response = requests.post( | ||
| 953 | api_url, | ||
| 954 | headers=headers, | ||
| 955 | data=data, | ||
| 956 | files=files, | ||
| 957 | timeout=300, | ||
| 958 | ) | ||
| 959 | if response.status_code >= 400: | ||
| 960 | logger.error( | ||
| 961 | "whisper API 返回错误: status=%s, body=%s", | ||
| 962 | response.status_code, | ||
| 963 | response.text, | ||
| 964 | ) | ||
| 965 | response.raise_for_status() | ||
| 966 | payload = response.json() | ||
| 967 | except Exception as exc: | ||
| 968 | logger.exception("whisper API 调用失败: %s", exc) | ||
| 969 | return None | ||
| 970 | finally: | ||
| 971 | if ( | ||
| 972 | tmp_file_path | ||
| 973 | and os.path.exists(tmp_file_path) | ||
| 974 | and not self._is_persisted_music_file(tmp_file_path) | ||
| 975 | ): | ||
| 976 | try: | ||
| 977 | os.unlink(tmp_file_path) | ||
| 978 | except Exception: | ||
| 979 | pass | ||
| 980 | if ( | ||
| 981 | upload_file_path | ||
| 982 | and upload_file_path != tmp_file_path | ||
| 983 | and os.path.exists(upload_file_path) | ||
| 984 | ): | ||
| 985 | try: | ||
| 986 | os.unlink(upload_file_path) | ||
| 987 | except Exception: | ||
| 988 | pass | ||
| 989 | |||
| 990 | lyrics = self._convert_whisper_payload_to_lyrics(payload) | ||
| 991 | return { | ||
| 992 | "lyrics": lyrics, | ||
| 993 | "_model": "whisper-1", | ||
| 994 | "_token_info": None, | ||
| 995 | "_transcription_url": None, | ||
| 996 | "_asr_task_id": None, | ||
| 997 | "_asr_backend": "whisper", | ||
| 998 | } | ||
| 999 | |||
| 1000 | def _convert_whisper_payload_to_lyrics( | ||
| 1001 | self, payload: Any | ||
| 1002 | ) -> List[Dict[str, Any]]: | ||
| 1003 | """Convert a whisper API response into lyrics: [{time, text}].""" | ||
| 1004 | if not isinstance(payload, dict): | ||
| 1005 | return [] | ||
| 1006 | |||
| 1007 | segments = payload.get("segments") | ||
| 1008 | if isinstance(segments, list): | ||
| 1009 | lyrics: List[Dict[str, Any]] = [] | ||
| 1010 | for seg in segments: | ||
| 1011 | if not isinstance(seg, dict): | ||
| 1012 | continue | ||
| 1013 | text = seg.get("text") | ||
| 1014 | if not isinstance(text, str): | ||
| 1015 | continue | ||
| 1016 | text = text.strip() | ||
| 1017 | if not text: | ||
| 1018 | continue | ||
| 1019 | |||
| 1020 | start = seg.get("start") | ||
| 1021 | if not isinstance(start, (int, float)): | ||
| 1022 | # Some endpoints return begin_time in milliseconds instead of start | ||
| 1023 | begin_time = seg.get("begin_time") | ||
| 1024 | if isinstance(begin_time, (int, float)): | ||
| 1025 | start = float(begin_time) / 1000.0 | ||
| 1026 | |||
| 1027 | time_str = None | ||
| 1028 | if isinstance(start, (int, float)): | ||
| 1029 | try: | ||
| 1030 | time_str = self._format_asr_time_ms(float(start) * 1000) | ||
| 1031 | except (TypeError, ValueError, OverflowError): | ||
| 1032 | time_str = None | ||
| 1033 | lyrics.append({"time": time_str, "text": text}) | ||
| 1034 | if lyrics: | ||
| 1035 | return lyrics | ||
| 1036 | |||
| 1037 | text = payload.get("text") | ||
| 1038 | if isinstance(text, str) and text.strip(): | ||
| 1039 | return [{"time": None, "text": text.strip()}] | ||
| 1040 | return [] | ||
| 1041 | |||
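The conversion above prefers `segments` (with `start` in seconds, per whisper's `verbose_json` format) and falls back to the flat `text` field. A standalone sketch of that mapping, reusing the same mm:ss.xxx timestamp format (`whisper_segments_to_lyrics` and `format_ms` are illustrative names):

```python
from typing import Any, Dict, List


def format_ms(ms: float) -> str:
    """Milliseconds -> mm:ss.xxx, clamped at zero."""
    total = int(max(0, ms))
    return f"{total // 60000:02d}:{(total % 60000) // 1000:02d}.{total % 1000:03d}"


def whisper_segments_to_lyrics(payload: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Map verbose_json segments to [{time, text}]; fall back to flat text."""
    lyrics: List[Dict[str, Any]] = []
    for seg in payload.get("segments") or []:
        text = (seg.get("text") or "").strip()
        if not text:
            continue
        start = seg.get("start")  # seconds in verbose_json
        time_str = format_ms(start * 1000) if isinstance(start, (int, float)) else None
        lyrics.append({"time": time_str, "text": text})
    if lyrics:
        return lyrics
    text = (payload.get("text") or "").strip()
    return [{"time": None, "text": text}] if text else []
```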
| 1042 | def _convert_audio_to_wav_for_whisper(self, source_audio_path: str) -> Optional[str]: | ||
| 1043 | """ | ||
| 1044 | Convert audio to WAV (16 kHz mono PCM), which whisper handles more reliably. | ||
| 1045 | """ | ||
| 1046 | try: | ||
| 1047 | with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as wav_tmp: | ||
| 1048 | wav_path = wav_tmp.name | ||
| 1049 | |||
| 1050 | cmd = [ | ||
| 1051 | "ffmpeg", | ||
| 1052 | "-y", | ||
| 1053 | "-i", | ||
| 1054 | source_audio_path, | ||
| 1055 | "-acodec", | ||
| 1056 | "pcm_s16le", | ||
| 1057 | "-ac", | ||
| 1058 | "1", | ||
| 1059 | "-ar", | ||
| 1060 | "16000", | ||
| 1061 | wav_path, | ||
| 1062 | ] | ||
| 1063 | subprocess.run(cmd, check=True, capture_output=True, text=True) | ||
| 1064 | return wav_path | ||
| 1065 | except Exception as exc: | ||
| 1066 | logger.warning("Audio-to-wav conversion failed; continuing with the original file: %s", exc) | ||
| 1067 | return None | ||
| 1068 | |||
| 1069 | def _analyze_lyrics_only_funasr(self, music_url: str) -> Optional[Dict[str, Any]]: | ||
| 1070 | """Async ASR flow via the fun-asr SDK.""" | ||
| 1071 | try: | ||
| 1072 | from http import HTTPStatus | ||
| 1073 | import dashscope | ||
| 1074 | from dashscope.audio.asr import Transcription | ||
| 1075 | except Exception as exc: | ||
| 1076 | logger.exception("Failed to import dashscope.audio.asr.Transcription: %s", exc) | ||
| 1077 | return None | ||
| 1078 | |||
| 1079 | api_key = self._get_dashscope_api_key() | ||
| 1080 | if not api_key: | ||
| 1081 | logger.error("funasr call failed: missing DashScope API key") | ||
| 1082 | return None | ||
| 1083 | |||
| 1084 | asr_model = getattr(settings, "DASHSCOPE_FUNASR_MODEL", "fun-asr") | ||
| 1085 | dashscope.base_http_api_url = getattr( | ||
| 1086 | settings, | ||
| 1087 | "DASHSCOPE_BASE_HTTP_API_URL", | ||
| 1088 | "https://dashscope.aliyuncs.com/api/v1", | ||
| 1089 | ) | ||
| 1090 | dashscope.api_key = api_key | ||
| 1091 | poll_interval = float(getattr(settings, "DASHSCOPE_ASR_POLL_INTERVAL", 1.0)) | ||
| 1092 | poll_timeout = float(getattr(settings, "DASHSCOPE_ASR_POLL_TIMEOUT", 120.0)) | ||
| 1093 | |||
| 1094 | try: | ||
| 1095 | task_resp = Transcription.async_call( | ||
| 1096 | model=asr_model, | ||
| 1097 | file_urls=[music_url], | ||
| 1098 | ) | ||
| 1099 | except Exception as exc: | ||
| 1100 | logger.exception("funasr async_call failed: %s", exc) | ||
| 1101 | return None | ||
| 1102 | |||
| 1103 | task_id = self._extract_task_id_from_asr_response(task_resp) | ||
| 1104 | latest_resp: Any = task_resp | ||
| 1105 | |||
| 1106 | deadline = time.time() + poll_timeout | ||
| 1107 | while time.time() < deadline: | ||
| 1108 | task_status = self._extract_task_status_from_asr_response(latest_resp) | ||
| 1109 | if task_status == "SUCCEEDED": | ||
| 1110 | break | ||
| 1111 | if task_status in {"FAILED", "CANCELED"}: | ||
| 1112 | logger.error( | ||
| 1113 | "funasr task failed: task_id=%s, status=%s", | ||
| 1114 | task_id, | ||
| 1115 | task_status, | ||
| 1116 | ) | ||
| 1117 | return None | ||
| 1118 | try: | ||
| 1119 | latest_resp = Transcription.fetch( | ||
| 1120 | task=latest_resp, | ||
| 1121 | ) | ||
| 1122 | except Exception as exc: | ||
| 1123 | logger.exception("funasr fetch failed: %s", exc) | ||
| 1124 | return None | ||
| 1125 | time.sleep(poll_interval) | ||
| 1126 | else: | ||
| 1127 | logger.error("funasr polling timed out: task_id=%s", task_id) | ||
| 1128 | return None | ||
| 1129 | |||
| 1130 | status_code = getattr(latest_resp, "status_code", None) | ||
| 1131 | if status_code is not None and status_code != HTTPStatus.OK: | ||
| 1132 | logger.error( | ||
| 1133 | "funasr returned non-OK status: task_id=%s, status_code=%s", | ||
| 1134 | task_id, | ||
| 1135 | status_code, | ||
| 1136 | ) | ||
| 1137 | return None | ||
| 1138 | |||
| 1139 | transcription_url = self._extract_transcription_url_from_asr_response(latest_resp) | ||
| 1140 | if not transcription_url: | ||
| 1141 | logger.error("funasr result missing transcription_url: task_id=%s", task_id) | ||
| 1142 | return None | ||
| 1143 | |||
| 1144 | transcript_data = self._fetch_asr_transcription(transcription_url) | ||
| 1145 | if not transcript_data: | ||
| 1146 | return None | ||
| 1147 | |||
| 1148 | lyrics = self._convert_asr_transcription_to_lyrics(transcript_data) | ||
| 1149 | token_info = self._extract_usage_from_asr_response(latest_resp) | ||
| 1150 | |||
| 1151 | return { | ||
| 1152 | "lyrics": lyrics, | ||
| 1153 | "_model": asr_model, | ||
| 1154 | "_token_info": token_info, | ||
| 1155 | "_transcription_url": transcription_url, | ||
| 1156 | "_asr_task_id": task_id, | ||
| 1157 | "_asr_backend": "funasr", | ||
| 1158 | } | ||
| 1159 | |||
| 1160 | def _submit_asr_transcription_task(self, music_url: str) -> Optional[str]: | ||
| 1161 | """Submit a DashScope async ASR task and return its task_id.""" | ||
| 1162 | api_key = self._get_dashscope_api_key() | ||
| 1163 | if not api_key: | ||
| 1164 | logger.error("Failed to submit ASR task: missing DashScope API key") | ||
| 1165 | return None | ||
| 1166 | |||
| 1167 | submit_url = getattr( | ||
| 1168 | settings, | ||
| 1169 | "DASHSCOPE_ASR_SUBMIT_URL", | ||
| 1170 | "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription", | ||
| 1171 | ) | ||
| 1172 | asr_model = getattr(settings, "DASHSCOPE_ASR_MODEL", "qwen3-asr-flash-filetrans") | ||
| 1173 | |||
| 1174 | headers = { | ||
| 1175 | "Authorization": f"Bearer {api_key}", | ||
| 1176 | "Content-Type": "application/json", | ||
| 1177 | "X-DashScope-Async": "enable", | ||
| 1178 | } | ||
| 1179 | payload = { | ||
| 1180 | "model": asr_model, | ||
| 1181 | "input": {"file_url": music_url}, | ||
| 1182 | "parameters": { | ||
| 1183 | "channel_id": [0], | ||
| 1184 | "enable_itn": False, | ||
| 1185 | "enable_words": False, | ||
| 1186 | }, | ||
| 1187 | } | ||
| 1188 | |||
| 1189 | try: | ||
| 1190 | response = requests.post( | ||
| 1191 | submit_url, | ||
| 1192 | headers=headers, | ||
| 1193 | json=payload, | ||
| 1194 | timeout=self.timeout, | ||
| 1195 | ) | ||
| 1196 | response.raise_for_status() | ||
| 1197 | data = response.json() | ||
| 1198 | except Exception as exc: | ||
| 1199 | logger.exception("Error submitting ASR task: %s", exc) | ||
| 1200 | return None | ||
| 1201 | |||
| 1202 | output = data.get("output") if isinstance(data, dict) else None | ||
| 1203 | if not isinstance(output, dict): | ||
| 1204 | logger.error("Failed to submit ASR task: response missing 'output' field") | ||
| 1205 | return None | ||
| 1206 | |||
| 1207 | task_id = output.get("task_id") | ||
| 1208 | if not isinstance(task_id, str) or not task_id.strip(): | ||
| 1209 | logger.error("Failed to submit ASR task: response missing 'task_id'") | ||
| 1210 | return None | ||
| 1211 | return task_id.strip() | ||
| 1212 | |||
| 1213 | def _poll_asr_task_result(self, task_id: str) -> Optional[Dict[str, Any]]: | ||
| 1214 | """Poll a DashScope task until it finishes or times out.""" | ||
| 1215 | api_key = self._get_dashscope_api_key() | ||
| 1216 | if not api_key: | ||
| 1217 | logger.error("Failed to poll ASR task: missing DashScope API key") | ||
| 1218 | return None | ||
| 1219 | |||
| 1220 | task_base_url = getattr( | ||
| 1221 | settings, | ||
| 1222 | "DASHSCOPE_TASK_STATUS_BASE_URL", | ||
| 1223 | "https://dashscope.aliyuncs.com/api/v1/tasks", | ||
| 1224 | ).rstrip("/") | ||
| 1225 | task_url = f"{task_base_url}/{task_id}" | ||
| 1226 | |||
| 1227 | headers = { | ||
| 1228 | "Authorization": f"Bearer {api_key}", | ||
| 1229 | "X-DashScope-Async": "enable", | ||
| 1230 | "Content-Type": "application/json", | ||
| 1231 | } | ||
| 1232 | |||
| 1233 | poll_interval = float(getattr(settings, "DASHSCOPE_ASR_POLL_INTERVAL", 1.0)) | ||
| 1234 | poll_timeout = float(getattr(settings, "DASHSCOPE_ASR_POLL_TIMEOUT", 120.0)) | ||
| 1235 | deadline = time.time() + poll_timeout | ||
| 1236 | |||
| 1237 | while time.time() < deadline: | ||
| 1238 | try: | ||
| 1239 | response = requests.get(task_url, headers=headers, timeout=self.timeout) | ||
| 1240 | response.raise_for_status() | ||
| 1241 | data = response.json() | ||
| 1242 | except Exception as exc: | ||
| 1243 | logger.exception("Error polling ASR task: task_id=%s, error=%s", task_id, exc) | ||
| 1244 | return None | ||
| 1245 | |||
| 1246 | output = data.get("output") if isinstance(data, dict) else None | ||
| 1247 | task_status = output.get("task_status") if isinstance(output, dict) else None | ||
| 1248 | if task_status == "SUCCEEDED": | ||
| 1249 | return data | ||
| 1250 | if task_status in {"FAILED", "CANCELED"}: | ||
| 1251 | logger.error( | ||
| 1252 | "ASR task failed: task_id=%s, status=%s, data=%s", | ||
| 1253 | task_id, | ||
| 1254 | task_status, | ||
| 1255 | data, | ||
| 1256 | ) | ||
| 1257 | return None | ||
| 1258 | |||
| 1259 | time.sleep(poll_interval) | ||
| 1260 | |||
| 1261 | logger.error("ASR task polling timed out: task_id=%s", task_id) | ||
| 1262 | return None | ||
| 1263 | |||
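Both polling paths in this class follow the same deadline-plus-interval pattern: fetch status, stop on a terminal state, sleep, give up at the deadline. A generic sketch with an injectable `sleep` so the loop is testable without real delays (`poll_task` is an illustrative name):

```python
import time
from typing import Any, Callable, Optional, Tuple


def poll_task(
    fetch_status: Callable[[], Tuple[str, Any]],
    interval: float,
    timeout: float,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[Any]:
    """Poll until SUCCEEDED/FAILED/CANCELED or the deadline passes."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status, data = fetch_status()
        if status == "SUCCEEDED":
            return data
        if status in {"FAILED", "CANCELED"}:
            return None
        sleep(interval)
    return None
```

Using `time.monotonic()` for the deadline avoids surprises if the wall clock jumps; the methods above use `time.time()` instead, which is fine for coarse timeouts.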
| 1264 | def _fetch_asr_transcription(self, transcription_url: str) -> Optional[Dict[str, Any]]: | ||
| 1265 | """Download the transcription JSON behind transcription_url.""" | ||
| 1266 | try: | ||
| 1267 | response = requests.get(transcription_url, timeout=self.timeout) | ||
| 1268 | response.raise_for_status() | ||
| 1269 | data = response.json() | ||
| 1270 | return data if isinstance(data, dict) else None | ||
| 1271 | except Exception as exc: | ||
| 1272 | logger.exception("Failed to download ASR transcription: %s", exc) | ||
| 1273 | return None | ||
| 1274 | |||
| 1275 | def _convert_asr_transcription_to_lyrics( | ||
| 1276 | self, transcript_data: Dict[str, Any] | ||
| 1277 | ) -> List[Dict[str, Any]]: | ||
| 1278 | """Convert an ASR transcription into lyrics: [{time, text}].""" | ||
| 1279 | transcripts = transcript_data.get("transcripts") | ||
| 1280 | if not isinstance(transcripts, list): | ||
| 1281 | return [] | ||
| 1282 | |||
| 1283 | lyrics: List[Dict[str, Any]] = [] | ||
| 1284 | for transcript in transcripts: | ||
| 1285 | if not isinstance(transcript, dict): | ||
| 1286 | continue | ||
| 1287 | |||
| 1288 | sentences = transcript.get("sentences") | ||
| 1289 | if not isinstance(sentences, list): | ||
| 1290 | continue | ||
| 1291 | |||
| 1292 | for sentence in sentences: | ||
| 1293 | if not isinstance(sentence, dict): | ||
| 1294 | continue | ||
| 1295 | |||
| 1296 | text = sentence.get("text") | ||
| 1297 | if not isinstance(text, str): | ||
| 1298 | continue | ||
| 1299 | text = text.strip() | ||
| 1300 | if not text: | ||
| 1301 | continue | ||
| 1302 | |||
| 1303 | begin_time = sentence.get("begin_time") | ||
| 1304 | time_str = ( | ||
| 1305 | self._format_asr_time_ms(begin_time) | ||
| 1306 | if isinstance(begin_time, (int, float)) | ||
| 1307 | else None | ||
| 1308 | ) | ||
| 1309 | lyrics.append( | ||
| 1310 | { | ||
| 1311 | "time": time_str, | ||
| 1312 | "text": text, | ||
| 1313 | } | ||
| 1314 | ) | ||
| 1315 | |||
| 1316 | return lyrics | ||
| 1317 | |||
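The transcription JSON consumed above nests sentences under `transcripts[].sentences[]`, with `begin_time` in milliseconds. A standalone sketch of the flattening (the sample payload shape is taken from the code paths above, not from any spec; `asr_to_lyrics` is an illustrative name):

```python
from typing import Any, Dict, List


def asr_to_lyrics(doc: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten transcripts[].sentences[] into [{time, text}]."""
    lyrics: List[Dict[str, Any]] = []
    for transcript in doc.get("transcripts") or []:
        for sentence in transcript.get("sentences") or []:
            text = (sentence.get("text") or "").strip()
            if not text:
                continue
            begin = sentence.get("begin_time")  # milliseconds
            time_str = None
            if isinstance(begin, (int, float)):
                ms = int(max(0, begin))
                time_str = f"{ms // 60000:02d}:{(ms % 60000) // 1000:02d}.{ms % 1000:03d}"
            lyrics.append({"time": time_str, "text": text})
    return lyrics
```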
| 1318 | @staticmethod | ||
| 1319 | def _format_asr_time_ms(ms_value: float) -> str: | ||
| 1320 | """Convert milliseconds to mm:ss.xxx.""" | ||
| 1321 | total_ms = int(max(0, ms_value)) | ||
| 1322 | minutes = total_ms // 60000 | ||
| 1323 | seconds = (total_ms % 60000) // 1000 | ||
| 1324 | milliseconds = total_ms % 1000 | ||
| 1325 | return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}" | ||
| 1326 | |||
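Two edge behaviors of `_format_asr_time_ms` are worth noting: negative inputs clamp to zero, and because there is no hour field the minutes component can exceed 59 for tracks over an hour. A standalone copy of the formatter demonstrating both:

```python
def format_asr_time_ms(ms_value: float) -> str:
    """Milliseconds -> mm:ss.xxx; negatives clamp to zero, minutes may exceed 59."""
    total_ms = int(max(0, ms_value))
    minutes = total_ms // 60000
    seconds = (total_ms % 60000) // 1000
    milliseconds = total_ms % 1000
    return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
```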
| 1327 | def _get_dashscope_api_key(self) -> Optional[str]: | ||
| 1328 | """Resolve the DashScope API key (ASR-specific).""" | ||
| 1329 | return ( | ||
| 1330 | self.api_key | ||
| 1331 | or settings.QWEN_DASHSCOPE_API_KEY | ||
| 1332 | or settings.QWEN_API_KEY | ||
| 1333 | or os.getenv("DASHSCOPE_API_KEY") | ||
| 1334 | or os.getenv("QWEN_DASHSCOPE_API_KEY") | ||
| 1335 | or os.getenv("QWEN_API_KEY") | ||
| 1336 | ) | ||
| 1337 | |||
| 1338 | @staticmethod | ||
| 1339 | def _as_dict(response_obj: Any) -> Dict[str, Any]: | ||
| 1340 | """Best-effort conversion of an SDK response object into a dict.""" | ||
| 1341 | if isinstance(response_obj, dict): | ||
| 1342 | return response_obj | ||
| 1343 | if response_obj is None: | ||
| 1344 | return {} | ||
| 1345 | |||
| 1346 | for attr in ("to_dict", "as_dict", "dict"): | ||
| 1347 | fn = getattr(response_obj, attr, None) | ||
| 1348 | if callable(fn): | ||
| 1349 | try: | ||
| 1350 | value = fn() | ||
| 1351 | if isinstance(value, dict): | ||
| 1352 | return value | ||
| 1353 | except Exception: | ||
| 1354 | pass | ||
| 1355 | |||
| 1356 | data: Dict[str, Any] = {} | ||
| 1357 | for key in ("request_id", "output", "usage"): | ||
| 1358 | val = getattr(response_obj, key, None) | ||
| 1359 | if val is not None: | ||
| 1360 | if key in ("output", "usage") and not isinstance(val, dict): | ||
| 1361 | nested = QwenAnalyzer._as_dict(val) | ||
| 1362 | data[key] = nested if nested else val | ||
| 1363 | else: | ||
| 1364 | data[key] = val | ||
| 1365 | return data | ||
| 1366 | |||
| 1367 | def _extract_task_id_from_asr_response(self, response_obj: Any) -> Optional[str]: | ||
| 1368 | data = self._as_dict(response_obj) | ||
| 1369 | output = data.get("output") | ||
| 1370 | if isinstance(output, dict): | ||
| 1371 | task_id = output.get("task_id") | ||
| 1372 | if isinstance(task_id, str) and task_id.strip(): | ||
| 1373 | return task_id.strip() | ||
| 1374 | return None | ||
| 1375 | |||
| 1376 | def _extract_task_status_from_asr_response(self, response_obj: Any) -> Optional[str]: | ||
| 1377 | data = self._as_dict(response_obj) | ||
| 1378 | output = data.get("output") | ||
| 1379 | if isinstance(output, dict): | ||
| 1380 | task_status = output.get("task_status") | ||
| 1381 | if isinstance(task_status, str): | ||
| 1382 | return task_status | ||
| 1383 | return None | ||
| 1384 | |||
| 1385 | def _extract_transcription_url_from_asr_response( | ||
| 1386 | self, response_obj: Any | ||
| 1387 | ) -> Optional[str]: | ||
| 1388 | data = self._as_dict(response_obj) | ||
| 1389 | output = data.get("output") | ||
| 1390 | if not isinstance(output, dict): | ||
| 1391 | return None | ||
| 1392 | |||
| 1393 | # Handle output.results: [{transcription_url: ...}] | ||
| 1394 | results = output.get("results") | ||
| 1395 | if isinstance(results, list) and results: | ||
| 1396 | first = results[0] | ||
| 1397 | if isinstance(first, dict): | ||
| 1398 | transcription_url = first.get("transcription_url") | ||
| 1399 | if isinstance(transcription_url, str) and transcription_url.strip(): | ||
| 1400 | return transcription_url.strip() | ||
| 1401 | |||
| 1402 | result = output.get("result") | ||
| 1403 | if not isinstance(result, dict): | ||
| 1404 | # Last resort: output.transcription_url at the top level | ||
| 1405 | transcription_url = output.get("transcription_url") | ||
| 1406 | if isinstance(transcription_url, str) and transcription_url.strip(): | ||
| 1407 | return transcription_url.strip() | ||
| 1408 | return None | ||
| 1409 | transcription_url = result.get("transcription_url") | ||
| 1410 | if isinstance(transcription_url, str) and transcription_url.strip(): | ||
| 1411 | return transcription_url.strip() | ||
| 1412 | return None | ||
| 1413 | |||
| 1414 | def _extract_usage_from_asr_response( | ||
| 1415 | self, response_obj: Any | ||
| 1416 | ) -> Optional[Dict[str, Any]]: | ||
| 1417 | data = self._as_dict(response_obj) | ||
| 1418 | usage = data.get("usage") | ||
| 1419 | return usage if isinstance(usage, dict) else None | ||
| 1420 | |||
| 1421 | def _build_messages( | ||
| 1422 | self, | ||
| 1423 | system_prompt: str, | ||
| 1424 | user_prompt: str, | ||
| 1425 | music_url: str, | ||
| 1426 | ) -> list: | ||
| 1427 | """Build the OpenAI-compatible message list.""" | ||
| 1428 | messages = [] | ||
| 1429 | |||
| 1430 | # System prompt, if provided | ||
| 1431 | if system_prompt: | ||
| 1432 | messages.append( | ||
| 1433 | { | ||
| 1434 | "role": "system", | ||
| 1435 | "content": system_prompt, | ||
| 1436 | } | ||
| 1437 | ) | ||
| 1438 | |||
| 1439 | # User message carrying both the audio and the text prompt | ||
| 1440 | messages.append( | ||
| 1441 | { | ||
| 1442 | "role": "user", | ||
| 1443 | "content": [ | ||
| 1444 | { | ||
| 1445 | "type": "input_audio", | ||
| 1446 | "input_audio": {"data": music_url, "format": "mp3"}, | ||
| 1447 | }, | ||
| 1448 | {"type": "text", "text": user_prompt}, | ||
| 1449 | ], | ||
| 1450 | } | ||
| 1451 | ) | ||
| 1452 | |||
| 1453 | return messages | ||
| 1454 | |||
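The message list built above follows the OpenAI-compatible multimodal shape, with the audio part first and the text prompt second; note the `format` field is fixed to `"mp3"` regardless of the actual file extension. A standalone sketch of the same structure:

```python
from typing import Any, Dict, List


def build_messages(system_prompt: str, user_prompt: str, music_url: str) -> List[Dict[str, Any]]:
    """Build an OpenAI-compatible message list with an input_audio part."""
    messages: List[Dict[str, Any]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append(
        {
            "role": "user",
            "content": [
                # Audio is passed by URL; format is hardcoded to mp3 here,
                # mirroring the method above.
                {"type": "input_audio", "input_audio": {"data": music_url, "format": "mp3"}},
                {"type": "text", "text": user_prompt},
            ],
        }
    )
    return messages
```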
| 1455 | def _build_dashscope_prompt(self, system_prompt: str, user_prompt: str) -> str: | ||
| 1456 | """Build the text prompt for a DashScope call.""" | ||
| 1457 | if system_prompt and system_prompt.strip(): | ||
| 1458 | return f"{system_prompt.strip()}\n\n{user_prompt}".strip() | ||
| 1459 | return user_prompt.strip() | ||
| 1460 | |||
| 1461 | def _timed_call_openai( | ||
| 1462 | self, client, messages: list | ||
| 1463 | ) -> tuple[Optional[Dict], float]: | ||
| 1464 | """Wrap the OpenAI-compatible call with elapsed-time measurement.""" | ||
| 1465 | call_start = time.time() | ||
| 1466 | resp = self._call_with_retry(client, messages) | ||
| 1467 | return resp, round(time.time() - call_start, 2) | ||
| 1468 | |||
| 1469 | def _call_with_retry_dashscope( | ||
| 1470 | self, music_url: str, prompt: str, timeout: Optional[float] = None, song_id: str = "", metadata: Optional[Dict[str, Any]] = None | ||
| 1471 | ) -> Optional[Dict]: | ||
| 1472 | """Multimodal call via the DashScope SDK (with retries; oversized files fall back to OSS upload).""" | ||
| 1473 | import dashscope | ||
| 1474 | |||
| 1475 | dashscope_key = ( | ||
| 1476 | self.api_key | ||
| 1477 | or settings.QWEN_DASHSCOPE_API_KEY | ||
| 1478 | or os.getenv("QWEN_OMNI_API_KEY") | ||
| 1479 | or os.getenv("DASHSCOPE_API_KEY") | ||
| 1480 | ) | ||
| 1481 | if not dashscope_key: | ||
| 1482 | print(" ⚠ No DashScope API key configured; set DASHSCOPE_API_KEY first") | ||
| 1483 | return None | ||
| 1484 | |||
| 1485 | messages = [ | ||
| 1486 | { | ||
| 1487 | "role": "user", | ||
| 1488 | "content": [ | ||
| 1489 | {"audio": music_url}, | ||
| 1490 | {"text": prompt}, | ||
| 1491 | ], | ||
| 1492 | } | ||
| 1493 | ] | ||
| 1494 | |||
| 1495 | timeout = timeout or self.timeout | ||
| 1496 | |||
| 1497 | for attempt in range(1, self.max_retries + 1): | ||
| 1498 | try: | ||
| 1499 | print( | ||
| 1500 | f" [{self.model}] Analyzing (DashScope attempt {attempt}/{self.max_retries}, timeout={timeout}s)..." | ||
| 1501 | ) | ||
| 1502 | response = self._dashscope_call_with_hard_timeout( | ||
| 1503 | dashscope=dashscope, | ||
| 1504 | api_key=dashscope_key, | ||
| 1505 | model=self.model, | ||
| 1506 | messages=messages, | ||
| 1507 | timeout=timeout, | ||
| 1508 | ) | ||
| 1509 | |||
| 1510 | if response.status_code != 200: | ||
| 1511 | error_msg = getattr(response, "message", "") | ||
| 1512 | error_code = getattr(response, "code", "") | ||
| 1513 | error_output = getattr(response, "output", {}) | ||
| 1514 | print( | ||
| 1515 | f" ✗ [{self.model}] API call failed, status code: {response.status_code}" | ||
| 1516 | ) | ||
| 1517 | if song_id: | ||
| 1518 | print(f" Song ID: {song_id}") | ||
| 1519 | if error_code: | ||
| 1520 | print(f" Error code: {error_code}") | ||
| 1521 | if error_msg: | ||
| 1522 | print(f" Error message: {error_msg}") | ||
| 1523 | if error_output: | ||
| 1524 | print(f" Response body: {error_output}") | ||
| 1525 | |||
| 1526 | # Detect the file-too-large error and automatically fall back to OSS | ||
| 1527 | if "file size is too large" in str(error_msg).lower() or "file size is too large" in str(error_output).lower(): | ||
| 1528 | print(" [Qwen] File too large; falling back to OSS upload...") | ||
| 1529 | try: | ||
| 1530 | temp_audio_path = self._download_audio_temp(music_url) | ||
| 1531 | if temp_audio_path: | ||
| 1532 | mono_path = self._convert_to_mono(temp_audio_path) | ||
| 1533 | oss_url = self._upload_audio_to_oss(mono_path) | ||
| 1534 | # Delete only the converted mono file; keep the original download | ||
| 1535 | self._cleanup_temp_audio(mono_path) | ||
| 1536 | if oss_url: | ||
| 1537 | print(f" [Qwen] Retrying with OSS URL: {oss_url[:60]}...") | ||
| 1538 | return self._call_with_retry_dashscope(oss_url, prompt, timeout=timeout, song_id=song_id, metadata=metadata) | ||
| 1539 | except Exception as e: | ||
| 1540 | print(f" [Qwen] OSS fallback failed: {e}") | ||
| 1541 | return None | ||
| 1542 | |||
| 1543 | if attempt < self.max_retries: | ||
| 1544 | time.sleep(attempt) | ||
| 1545 | continue | ||
| 1546 | return None | ||
| 1547 | |||
| 1548 | content = response.output.choices[0].message.content | ||
| 1549 | if isinstance(content, list): | ||
| 1550 | if content and isinstance(content[0], dict) and "text" in content[0]: | ||
| 1551 | result_text = content[0]["text"] | ||
| 1552 | else: | ||
| 1553 | result_text = "" | ||
| 1554 | else: | ||
| 1555 | result_text = content | ||
| 1556 | |||
| 1557 | usage = None | ||
| 1558 | resp_usage = getattr(response, "usage", None) | ||
| 1559 | if isinstance(resp_usage, dict): | ||
| 1560 | input_tokens = resp_usage.get( | ||
| 1561 | "input_tokens", resp_usage.get("prompt_tokens", 0) | ||
| 1562 | ) | ||
| 1563 | output_tokens = resp_usage.get( | ||
| 1564 | "output_tokens", resp_usage.get("completion_tokens", 0) | ||
| 1565 | ) | ||
| 1566 | total_tokens = resp_usage.get("total_tokens") | ||
| 1567 | usage = { | ||
| 1568 | "prompt_tokens": input_tokens or 0, | ||
| 1569 | "completion_tokens": output_tokens or 0, | ||
| 1570 | "total_tokens": total_tokens | ||
| 1571 | if total_tokens is not None | ||
| 1572 | else (input_tokens or 0) + (output_tokens or 0), | ||
| 1573 | } | ||
| 1574 | elif resp_usage is not None: | ||
| 1575 | input_tokens = getattr(resp_usage, "input_tokens", None) | ||
| 1576 | output_tokens = getattr(resp_usage, "output_tokens", None) | ||
| 1577 | total_tokens = getattr(resp_usage, "total_tokens", None) | ||
| 1578 | usage = { | ||
| 1579 | "prompt_tokens": input_tokens or 0, | ||
| 1580 | "completion_tokens": output_tokens or 0, | ||
| 1581 | "total_tokens": total_tokens | ||
| 1582 | if total_tokens is not None | ||
| 1583 | else (input_tokens or 0) + (output_tokens or 0), | ||
| 1584 | } | ||
| 1585 | |||
| 1586 | return {"content": result_text, "usage": usage} | ||
| 1587 | |||
| 1588 | except TimeoutError: | ||
| 1589 | print(f" ✗ [{self.model}] API 调用超时 (尝试 {attempt}/{self.max_retries})") | ||
| 1590 | if attempt < self.max_retries: | ||
| 1591 | time.sleep(attempt) | ||
| 1592 | continue | ||
| 1593 | return None | ||
| 1594 | except Exception as e: | ||
| 1595 | print(f" ✗ [{self.model}] API 调用异常: {e}") | ||
| 1596 | if attempt < self.max_retries: | ||
| 1597 | time.sleep(attempt) | ||
| 1598 | continue | ||
| 1599 | return None | ||
| 1600 | |||
| 1601 | return None | ||
| 1602 | |||
| 1603 | def _download_audio_temp(self, music_url: str) -> Optional[str]: | ||
| 1604 | """ | ||
| 1605 | 临时下载音频文件到系统临时目录 | ||
| 1606 | |||
| 1607 | Args: | ||
| 1608 | music_url: 音频URL | ||
| 1609 | |||
| 1610 | Returns: | ||
| 1611 | 临时文件路径,如果下载失败返回 None | ||
| 1612 | """ | ||
| 1613 | try: | ||
| 1614 | # 确定文件扩展名 | ||
| 1615 | ext = ".mp3" | ||
| 1616 | url_path = music_url.split("?")[0]  # 先去掉查询串,避免 ?x=1.2 之类参数影响扩展名判断 | ||
| 1617 | if "." in url_path: | ||
| 1618 | url_ext = url_path.split(".")[-1].lower() | ||
| 1619 | if url_ext in ["mp3", "wav", "flac", "aac", "m4a", "ogg"]: | ||
| 1620 | ext = f".{url_ext}" | ||
| 1620 | |||
| 1621 | # 下载到系统临时目录 | ||
| 1622 | temp_dir = tempfile.gettempdir() | ||
| 1623 | url_hash = hashlib.md5(music_url.encode("utf-8")).hexdigest()[:12] | ||
| 1624 | temp_path = os.path.join(temp_dir, f"qwen_audio_{url_hash}{ext}") | ||
| 1625 | |||
| 1626 | if not os.path.exists(temp_path): | ||
| 1627 | response = requests.get(music_url, timeout=60) | ||
| 1628 | response.raise_for_status() | ||
| 1629 | with open(temp_path, "wb") as f: | ||
| 1630 | f.write(response.content) | ||
| 1631 | print(f" [Qwen] 临时音频已下载: {temp_path}") | ||
| 1632 | else: | ||
| 1633 | print(f" [Qwen] 使用缓存的临时音频") | ||
| 1634 | |||
| 1635 | return temp_path | ||
| 1636 | except Exception as e: | ||
| 1637 | print(f" [Qwen] 临时音频下载失败: {e}") | ||
| 1638 | return None | ||
| 1639 | |||
| 1640 | def _convert_to_mono(self, audio_path: str) -> str: | ||
| 1641 | """ | ||
| 1642 | 将音频转换为单声道 | ||
| 1643 | |||
| 1644 | Args: | ||
| 1645 | audio_path: 原始音频文件路径 | ||
| 1646 | |||
| 1647 | Returns: | ||
| 1648 | 转换后的音频文件路径 | ||
| 1649 | """ | ||
| 1651 | timestamp = int(time.time() * 1000) | ||
| 1652 | base_name = os.path.basename(audio_path) | ||
| 1653 | name_parts = base_name.rsplit(".", 1) | ||
| 1654 | if len(name_parts) == 2: | ||
| 1655 | mono_path = os.path.join( | ||
| 1656 | os.path.dirname(audio_path), | ||
| 1657 | f"{name_parts[0]}_mono_{timestamp}.{name_parts[1]}" | ||
| 1658 | ) | ||
| 1659 | else: | ||
| 1660 | mono_path = f"{audio_path}_mono_{timestamp}" | ||
| 1661 | |||
| 1662 | try: | ||
| 1663 | cmd = [ | ||
| 1664 | "ffmpeg", | ||
| 1665 | "-i", audio_path, | ||
| 1666 | "-ac", "1", # 转为单声道 | ||
| 1667 | "-y", | ||
| 1668 | mono_path | ||
| 1669 | ] | ||
| 1670 | print(f" [Qwen] 转换为单声道: ffmpeg -i ... -ac 1") | ||
| 1671 | subprocess.run(cmd, capture_output=True, timeout=60, check=True) | ||
| 1672 | original_size = os.path.getsize(audio_path) | ||
| 1673 | mono_size = os.path.getsize(mono_path) | ||
| 1674 | ratio = (1 - mono_size / original_size) * 100 | ||
| 1675 | print(f" [Qwen] 音频已转换: {original_size/1024/1024:.1f}MB -> {mono_size/1024/1024:.1f}MB (压缩率: {ratio:.1f}%)") | ||
| 1676 | return mono_path | ||
| 1677 | except Exception as e: | ||
| 1678 | print(f" [Qwen] 音频转换失败: {e},将使用原文件") | ||
| 1679 | return audio_path | ||
| 1680 | |||
| 1681 | def _upload_audio_to_oss(self, audio_path: str) -> Optional[str]: | ||
| 1682 | """ | ||
| 1683 | 将音频文件上传到 OSS | ||
| 1684 | |||
| 1685 | Args: | ||
| 1686 | audio_path: 音频文件路径 | ||
| 1687 | |||
| 1688 | Returns: | ||
| 1689 | OSS 文件 URL,如果上传失败返回 None | ||
| 1690 | """ | ||
| 1691 | try: | ||
| 1692 | from app.utils.oss_uploader import oss_uploader | ||
| 1693 | |||
| 1694 | success, result = oss_uploader.upload_file(audio_path) | ||
| 1695 | if not success: | ||
| 1696 | print(f" [Qwen] 音频上传到 OSS 失败: {result}") | ||
| 1697 | return None | ||
| 1698 | |||
| 1699 | oss_url = result | ||
| 1700 | print(f" [Qwen] 音频已上传到 OSS: {oss_url}") | ||
| 1701 | return oss_url | ||
| 1702 | except Exception as e: | ||
| 1703 | print(f" [Qwen] 音频上传到 OSS 失败: {e}") | ||
| 1704 | return None | ||
| 1705 | |||
| 1706 | def _cleanup_temp_audio(self, temp_path: str) -> None: | ||
| 1707 | """清理临时音频文件""" | ||
| 1708 | if temp_path and os.path.exists(temp_path): | ||
| 1709 | try: | ||
| 1710 | os.unlink(temp_path) | ||
| 1711 | print(f" [Qwen] 已清理临时音频文件") | ||
| 1712 | except OSError: | ||
| 1713 | pass | ||
| 1714 | |||
| 1715 | def _dashscope_call_with_hard_timeout( | ||
| 1716 | self, | ||
| 1717 | dashscope, | ||
| 1718 | api_key: str, | ||
| 1719 | model: str, | ||
| 1720 | messages: list, | ||
| 1721 | timeout: float, | ||
| 1722 | ): | ||
| 1723 | """ | ||
| 1724 | DashScope SDK 某些版本下 request_timeout 可能无法稳定生效。 | ||
| 1725 | 这里增加线程级硬超时,避免单次调用无限阻塞。 | ||
| 1726 | """ | ||
| 1727 | box: Dict[str, Any] = {} | ||
| 1728 | done = threading.Event() | ||
| 1729 | |||
| 1730 | def _target() -> None: | ||
| 1731 | try: | ||
| 1732 | box["response"] = dashscope.MultiModalConversation.call( | ||
| 1733 | api_key=api_key, | ||
| 1734 | model=model, | ||
| 1735 | messages=messages, | ||
| 1736 | request_timeout=timeout, | ||
| 1737 | ) | ||
| 1738 | except Exception as exc: | ||
| 1739 | box["error"] = exc | ||
| 1740 | finally: | ||
| 1741 | done.set() | ||
| 1742 | |||
| 1743 | worker = threading.Thread(target=_target, daemon=True) | ||
| 1744 | worker.start() | ||
| 1745 | hard_timeout = max(float(timeout), 1.0) + 3.0 | ||
| 1746 | if not done.wait(hard_timeout): | ||
| 1747 | raise TimeoutError(f"DashScope hard timeout after {hard_timeout:.1f}s") | ||
| 1748 | if "error" in box: | ||
| 1749 | raise box["error"] | ||
| 1750 | return box.get("response") | ||
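上面 `_dashscope_call_with_hard_timeout` 的线程级硬超时思路可以抽象为一个通用小工具。下面是一个可独立运行的最小示意(与 DashScope 无关,`call_with_hard_timeout` 为假设的函数名):

```python
import threading
from typing import Any, Callable


def call_with_hard_timeout(fn: Callable[[], Any], timeout: float) -> Any:
    """在子线程中执行 fn,超过 timeout 秒仍未返回则抛出 TimeoutError。"""
    box: dict = {}
    done = threading.Event()

    def _target() -> None:
        try:
            box["result"] = fn()
        except Exception as exc:  # 记录子线程异常,回到主线程再抛出
            box["error"] = exc
        finally:
            done.set()

    threading.Thread(target=_target, daemon=True).start()
    if not done.wait(timeout):
        raise TimeoutError(f"hard timeout after {timeout:.1f}s")
    if "error" in box:
        raise box["error"]
    return box["result"]
```

注意:守护线程在超时后仍可能继续运行,只是结果被丢弃,因此 fn 不应持有不可重入的资源。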
| 1751 | |||
| 1752 | def _call_with_retry( | ||
| 1753 | self, | ||
| 1754 | client, | ||
| 1755 | messages: list, | ||
| 1756 | timeout: Optional[float] = None, | ||
| 1757 | max_retries: Optional[int] = None, | ||
| 1758 | ) -> Optional[Dict]: | ||
| 1759 | """带重试的 API 调用(非流式)""" | ||
| 1760 | timeout = timeout or self.lyrics_timeout | ||
| 1761 | retries = max_retries or self.max_retries | ||
| 1762 | |||
| 1763 | for attempt in range(1, retries + 1): | ||
| 1764 | try: | ||
| 1765 | print( | ||
| 1766 | f" [Qwen] 调用模型 (尝试 {attempt}/{retries}, timeout={timeout}s)..." | ||
| 1767 | ) | ||
| 1768 | |||
| 1769 | response = client.chat.completions.create( | ||
| 1770 | model=self.model, | ||
| 1771 | messages=messages, | ||
| 1772 | modalities=["text"], | ||
| 1773 | stream=False, | ||
| 1774 | timeout=timeout, | ||
| 1775 | extra_body={"enable_thinking": False}, | ||
| 1776 | ) | ||
| 1777 | |||
| 1778 | content = ( | ||
| 1779 | response.choices[0].message.content if response.choices else "" | ||
| 1780 | ) | ||
| 1781 | usage = { | ||
| 1782 | "prompt_tokens": response.usage.prompt_tokens | ||
| 1783 | if response.usage | ||
| 1784 | else 0, | ||
| 1785 | "completion_tokens": response.usage.completion_tokens | ||
| 1786 | if response.usage | ||
| 1787 | else 0, | ||
| 1788 | "total_tokens": response.usage.total_tokens | ||
| 1789 | if response.usage | ||
| 1790 | else 0, | ||
| 1791 | } | ||
| 1792 | |||
| 1793 | print(f" [Qwen] 响应: {content[:100]}...") | ||
| 1794 | |||
| 1795 | return {"content": content, "usage": usage} | ||
| 1796 | |||
| 1797 | except Exception as e: | ||
| 1798 | error_type = type(e).__name__ | ||
| 1799 | print(f" [Qwen] 错误 ({error_type}): {e}") | ||
| 1800 | |||
| 1801 | if attempt < retries: | ||
| 1802 | wait_time = attempt | ||
| 1803 | print(f" 等待 {wait_time} 秒后重试...") | ||
| 1804 | time.sleep(wait_time) | ||
| 1805 | else: | ||
| 1806 | print(f" 已达到最大重试次数") | ||
| 1807 | return None | ||
| 1808 | |||
| 1809 | return None |
| 1 | # 聚音标签识别助手 - 系统角色定义 | ||
| 2 | |||
| 3 | ## 角色定位 | ||
| 4 | |||
| 5 | 你是音乐内容标签标注助手。 | ||
| 6 | 你的任务是基于输入的歌曲信息(如歌词、标题、风格描述、音频特征等),严格按照「聚音标签字典」输出标准化标签字段。 | ||
| 7 | |||
| 8 | 只输出标签结果,不做解释,不做分析,不添加任何多余文本。 | ||
| 9 | |||
| 10 | ------ | ||
| 11 | |||
| 12 | ## 输出格式 | ||
| 13 | |||
| 14 | 仅输出 JSON 纯文本,结构如下: | ||
| 15 | |||
| 16 | { | ||
| 17 | "performer_type": "", | ||
| 18 | "language": "", | ||
| 19 | "emotion": [], | ||
| 20 | "douyin_tags": [], | ||
| 21 | "music_style_tags": [], | ||
| 22 | "instrument_tags": [], | ||
| 23 | "scene": [] | ||
| 24 | } | ||
| 25 | |||
| 26 | 禁止输出任何解释性文字、注释或额外字段。 | ||
| 27 | |||
| 28 | ------ | ||
| 29 | |||
| 30 | ## 全局约束规则 | ||
| 31 | |||
| 32 | 1. 所有标签必须严格从下方字典中选择,禁止自造词。 | ||
| 33 | 2. 不允许基于刻板印象猜测(如仅凭曲风推断情绪)。 | ||
| 34 | 3. 标签必须基于明确特征: | ||
| 35 | - 歌词内容 | ||
| 36 | - 音乐风格特征 | ||
| 37 | - 明确出现的配器 | ||
| 38 | - 明确使用场景 | ||
| 39 | 4. 多选字段仅选择高度确定且核心表达的标签,避免过度打标。 | ||
| 40 | 5. 注意!除 performer_type、language 无法判断时可输出 ""、douyin_tags 无明显网络属性时可为空外,其余字段至少选择一个标签。 | ||
| 41 | |||
| 42 | ------ | ||
| 43 | |||
| 44 | # 字段判定标准说明 | ||
| 45 | |||
| 46 | ## 一、演唱者类型 performer_type(单选) | ||
| 47 | |||
| 48 | 用于标注主要人声类型,仅根据实际听感或明确描述判断: | ||
| 49 | |||
| 50 | - 男声:主要为男性声线 | ||
| 51 | - 女声:主要为女性声线 | ||
| 52 | - 童声:明显儿童声线 | ||
| 53 | - 合唱:多人群体演唱为主(非简单和声) | ||
| 54 | |||
| 55 | 不确定时输出 ""。 | ||
| 56 | |||
| 57 | ------ | ||
| 58 | |||
| 59 | ## 二、情绪 emotion(多选) | ||
| 60 | |||
| 61 | 必须基于歌曲整体情绪表达判断,而非个别词语。 | ||
| 62 | |||
| 63 | - 喜庆:节日、庆典氛围明显 | ||
| 64 | - 浪漫:爱情氛围浓厚 | ||
| 65 | - 雄壮:宏大、史诗、气势恢宏 | ||
| 66 | - 蛊惑:迷幻、魅惑、暧昧 | ||
| 67 | - 宣泄:情绪爆发、释放 | ||
| 68 | - 悲壮:悲情但具有力量感 | ||
| 69 | - 愤怒:强烈对抗或激烈表达 | ||
| 70 | - 庄重:正式、肃穆 | ||
| 71 | - 激情:热烈高昂 | ||
| 72 | - 沉重:压抑、厚重 | ||
| 73 | - 快乐:轻松开心 | ||
| 74 | - 励志:奋斗、成长、自我激励 | ||
| 75 | - 思念:想念某人或过往 | ||
| 76 | - 紧张:悬而未决、焦虑 | ||
| 77 | - 恐怖:惊悚氛围 | ||
| 78 | - 感动:温情催泪 | ||
| 79 | - 恶搞:刻意夸张调侃 | ||
| 80 | - 搞笑:明显幽默表达 | ||
| 81 | - 期待:盼望未来 | ||
| 82 | - 怀念:回忆过去 | ||
| 83 | - 甜蜜:恋爱甜感 | ||
| 84 | - 孤独:孤单、自我独白 | ||
| 85 | - 伤感:悲伤低落 | ||
| 86 | - 悬疑:神秘未知感 | ||
| 87 | - 祝福:祝愿表达 | ||
| 88 | - 佛系:平淡随性 | ||
| 89 | - 舒缓:节奏慢、平稳 | ||
| 90 | - 悠扬:旋律流畅优美 | ||
| 91 | - 温暖:柔和治愈 | ||
| 92 | - 忧郁:带有阴郁气质 | ||
| 93 | |||
| 94 | 避免: | ||
| 95 | |||
| 96 | - 同时选择强烈对立情绪(如 快乐 与 伤感) | ||
| 97 | - 同类标签堆叠(如 伤感 + 忧郁 + 孤独 需明确区分) | ||
| 98 | |||
| 99 | ------ | ||
| 100 | |||
| 101 | ## 三、语种 language(单选) | ||
| 102 | |||
| 103 | 仅从下列标签中选择一个最主要的演唱语种: | ||
| 104 | |||
| 105 | - 普通话 | ||
| 106 | - 粤语 | ||
| 107 | - 藏语 | ||
| 108 | - 英语 | ||
| 109 | - 韩语 | ||
| 110 | - 闽南语 | ||
| 111 | - 蒙语 | ||
| 112 | - 俄语 | ||
| 113 | - 其他 | ||
| 114 | |||
| 115 | 规则: | ||
| 116 | |||
| 117 | - 只输出一个语种标签 | ||
| 118 | - 依据实际演唱语言判断,不根据歌手国籍或曲风猜测 | ||
| 119 | - 纯音乐或无法判断时输出 "" | ||
| 120 | |||
| 121 | ------ | ||
| 122 | |||
| 123 | ## 四、网络/抖音歌曲 douyin_tags(可多选) | ||
| 124 | |||
| 125 | 仅当歌曲具备明显网络传播特征或主题风格时选择: | ||
| 126 | |||
| 127 | - 草原:草原文化、民族草原元素 | ||
| 128 | - 故乡:思乡主题 | ||
| 129 | - 神曲:洗脑旋律、强节奏重复 | ||
| 130 | - 文艺:小众表达、诗性表达 | ||
| 131 | - 青春:校园或成长主题 | ||
| 132 | - 治愈系:温暖疗愈风格 | ||
| 133 | - 清新:轻快自然风格 | ||
| 134 | - 奇幻:幻想、魔幻元素 | ||
| 135 | |||
| 136 | 非明显网络属性不要强行标注。 | ||
| 137 | |||
| 138 | ------ | ||
| 139 | |||
| 140 | ## 五、音乐风格 music_style_tags(多选) | ||
| 141 | |||
| 142 | 必须根据音乐结构与风格特征判断,不根据歌词主题判断。 | ||
| 143 | |||
| 144 | - 世界音乐 | ||
| 145 | - 雷鬼 | ||
| 146 | - R&B/Soul | ||
| 147 | - MC喊麦 | ||
| 148 | - 另类音乐 | ||
| 149 | - 民歌 | ||
| 150 | - 戏曲 | ||
| 151 | - 古风 | ||
| 152 | - 古典音乐 | ||
| 153 | - HipHop | ||
| 154 | - Rap | ||
| 155 | - 摇滚 | ||
| 156 | - DJ嗨曲 | ||
| 157 | - 布鲁斯/蓝调 | ||
| 158 | - 拉丁 | ||
| 159 | - 舞曲 | ||
| 160 | - 爵士 | ||
| 161 | - 乡村 | ||
| 162 | - 民谣 | ||
| 163 | - 流行 | ||
| 164 | - 轻音乐 | ||
| 165 | - 国风 | ||
| 166 | - 儿歌 | ||
| 167 | |||
| 168 | 规则: | ||
| 169 | |||
| 170 | - 只选核心风格,不叠加相似风格 | ||
| 171 | - 不因使用某个乐器就推断整体风格 | ||
| 172 | - 无明显风格时可只选“流行” | ||
| 173 | |||
| 174 | ------ | ||
| 175 | |||
| 176 | ## 六、配器 instrument_tags(多选) | ||
| 177 | |||
| 178 | 仅在明确可识别时选择: | ||
| 179 | |||
| 180 | - 二胡 | ||
| 181 | - 竹笛 | ||
| 182 | - 琵琶 | ||
| 183 | - 音效 | ||
| 184 | - 口琴 | ||
| 185 | - 电子 | ||
| 186 | - 木吉他 | ||
| 187 | - 鼓组 | ||
| 188 | - 弦乐 | ||
| 189 | - 电吉他 | ||
| 190 | - 古筝 | ||
| 191 | - 钢琴 | ||
| 192 | |||
| 193 | 规则: | ||
| 194 | |||
| 195 | - 必须为明显主导或突出配器 | ||
| 196 | - 不因常规伴奏默认存在而标注 | ||
| 197 | - 不确定不要猜 | ||
| 198 | |||
| 199 | ------ | ||
| 200 | |||
| 201 | ## 七、场景 scene(多选) | ||
| 202 | |||
| 203 | 根据歌曲使用场景或明显氛围判断: | ||
| 204 | |||
| 205 | - 餐厅 | ||
| 206 | - 汽车 | ||
| 207 | - 跳舞 | ||
| 208 | - 旅行 | ||
| 209 | - 工作 | ||
| 210 | - 校园 | ||
| 211 | - 夜店 | ||
| 212 | - 运动 | ||
| 213 | - 休闲 | ||
| 214 | - live house | ||
| 215 | - 广场舞 | ||
| 216 | - 抖音 | ||
| 217 | - 婚礼 | ||
| 218 | - 约会 | ||
| 219 | |||
| 220 | 规则: | ||
| 221 | |||
| 222 | - 仅当歌曲明显适配该场景时标注 | ||
| 223 | - 避免泛化场景(如所有慢歌都标“休闲”) | ||
| 224 | |||
| 225 | ------ | ||
| 226 | |||
| 227 | ## 最终执行要求 | ||
| 228 | |||
| 229 | - 只输出 JSON | ||
| 230 | - 不解释 | ||
| 231 | - 不补充说明 | ||
| 232 | - 不输出字典内容 | ||
| 233 | - 不输出“分析如下”之类文字 | ||
| 234 | - 不添加未定义字段 | ||
| 235 | |||
| 236 | 严格遵守字段范围与空值规则。 | ||
| 237 | |||
| 238 | ## 输出格式 | ||
| 239 | 必须严格输出以下 JSON 结构,字段名不能改: | ||
| 240 | |||
| 241 | ```json | ||
| 242 | { | ||
| 243 | "performer_type": "", | ||
| 244 | "language": "", | ||
| 245 | "emotion": [], | ||
| 246 | "douyin_tags": [], | ||
| 247 | "music_style_tags": [], | ||
| 248 | "instrument_tags": [], | ||
| 249 | "scene": [] | ||
| 250 | } | ||
| 251 | ``` |
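下游解析模型输出时,可以按上面的字段约束做一层字典校验。下面是一个最小示意(标签集合仅为节选,`validate_tags` 为假设的辅助函数):

```python
import json
from typing import Any, Dict

# 节选的合法标签集合(完整集合见上文字典)
ALLOWED = {
    "performer_type": {"男声", "女声", "童声", "合唱", ""},
    "language": {"普通话", "粤语", "英语", "其他", ""},
    "emotion": {"快乐", "伤感", "励志", "温暖"},
}


def validate_tags(raw: str) -> Dict[str, Any]:
    """解析模型输出并剔除字典之外的标签;单选字段非法时回退为空串。"""
    data = json.loads(raw)
    if data.get("performer_type") not in ALLOWED["performer_type"]:
        data["performer_type"] = ""
    if data.get("language") not in ALLOWED["language"]:
        data["language"] = ""
    data["emotion"] = [t for t in data.get("emotion", []) if t in ALLOWED["emotion"]]
    return data
```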
| 1 | # 歌词识别提示词模板 | ||
| 2 | # 仅识别歌词内容,不包含其他音乐分析 | ||
| 3 | 请识别并转录音频中的完整歌词。 | ||
| 4 | |||
| 5 | ## 核心任务 | ||
| 6 | 1. **逐句识别**:按时间顺序输出每一句歌词,每句作为一条独立记录。 | ||
| 7 | 2. **字段要求**:每条记录必须包含 `time` (格式 "mm:ss.xxx",无法确定则为 null) 和 `text` (歌词内容)。 | ||
| 8 | 3. **无语义音节压缩**:对于“啊/呜/哦/嗯/啦”等辅助音节,禁止逐字展示,统一使用 `...` 缩略(例:把“啊啊啊啊”识别为“啊...”)。 | ||
| 9 | 4. **完整性**:必须转录包括重复段落在内的全曲内容。 | ||
| 10 | 5. **静默与纯音乐**:若为纯音乐或无歌词,仅返回空数组 `[]`。 | ||
| 11 | 6. **重复歌词**:不同段落之间重复出现的歌词也必须逐条完整输出,不得省略。 | ||
| 12 | |||
| 13 | ## 输出格式规范 | ||
| 14 | - 严格输出 JSON,不得包含任何 Markdown 代码围栏(如 ```json)或解释性文字。 | ||
| 15 | - 字段统一为: {"lyrics": [{"time": "00:00.000", "text": "内容"}]} | ||
| 16 | |||
| 17 | ## 质量控制 | ||
| 18 | - 遇到合唱/重叠时,以主旋律为主。 | ||
| 19 | - 严禁自行脑补不存在的歌词。 | ||
| 20 | - 不要返回任何其他无关内容 |
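下游可以用正则校验上面约定的 `mm:ss.xxx` 时间格式,并容忍 `time` 为 null 的记录。最小示意如下(`parse_lyrics` 为假设的辅助函数):

```python
import json
import re

TIME_RE = re.compile(r"^\d{2}:\d{2}\.\d{3}$")  # 对应 "mm:ss.xxx"


def parse_lyrics(raw: str) -> list:
    """解析歌词 JSON;time 为 null 或格式不合法时统一置为 None。"""
    data = json.loads(raw)
    lines = []
    for item in data.get("lyrics", []):
        t = item.get("time")
        if not (isinstance(t, str) and TIME_RE.match(t)):
            t = None
        lines.append({"time": t, "text": item.get("text", "")})
    return lines
```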
app/utils/oss_service.py
0 → 100644
| 1 | """ | ||
| 2 | 阿里云OSS服务 | ||
| 3 | 提供文件上传、下载、删除等功能 | ||
| 4 | """ | ||
| 5 | import oss2 | ||
| 6 | from typing import Optional, Dict, Any, BinaryIO | ||
| 7 | from datetime import datetime, timedelta | ||
| 8 | from pathlib import Path | ||
| 9 | import hashlib | ||
| 10 | import mimetypes | ||
| 11 | import logging | ||
| 12 | import time | ||
| 13 | import json | ||
| 14 | from functools import wraps | ||
| 15 | |||
| 16 | from aliyunsdkcore.client import AcsClient | ||
| 17 | from aliyunsdksts.request.v20150401.AssumeRoleRequest import AssumeRoleRequest | ||
| 18 | |||
| 19 | from app.core.config import settings | ||
| 20 | from app.core.exceptions import ( | ||
| 21 | BusinessException, | ||
| 22 | ValidationException, | ||
| 23 | ExternalServiceException, | ||
| 24 | NotFoundException | ||
| 25 | ) | ||
| 26 | |||
| 27 | logger = logging.getLogger(__name__) | ||
| 28 | |||
| 29 | |||
| 30 | def retry_on_connection_error(max_retries: int = 3, delay: float = 1.0): | ||
| 31 | """ | ||
| 32 | 重试装饰器 - 在连接错误时重试 | ||
| 33 | |||
| 34 | Args: | ||
| 35 | max_retries: 最大重试次数 | ||
| 36 | delay: 重试延迟(秒) | ||
| 37 | """ | ||
| 38 | def decorator(func): | ||
| 39 | @wraps(func) | ||
| 40 | def wrapper(*args, **kwargs): | ||
| 41 | last_exception = None | ||
| 42 | for attempt in range(max_retries): | ||
| 43 | try: | ||
| 44 | return func(*args, **kwargs) | ||
| 45 | except Exception as e:  # Exception 已覆盖 RequestError/ConnectionError/TimeoutError | ||
| 49 | # 只重试连接相关的错误 | ||
| 50 | error_str = str(e) | ||
| 51 | if ('Connection' in error_str or | ||
| 52 | 'timeout' in error_str.lower() or | ||
| 53 | 'closed' in error_str.lower() or | ||
| 54 | 'RemoteDisconnected' in error_str): | ||
| 55 | last_exception = e | ||
| 56 | if attempt < max_retries - 1: | ||
| 57 | wait_time = delay * (2 ** attempt) # 指数退避 | ||
| 58 | logger.warning( | ||
| 59 | f"[{func.__name__}] 连接错误,{wait_time}秒后重试 " | ||
| 60 | f"({attempt + 1}/{max_retries}): {error_str}" | ||
| 61 | ) | ||
| 62 | time.sleep(wait_time) | ||
| 63 | continue | ||
| 64 | else: | ||
| 65 | logger.error(f"[{func.__name__}] 重试次数已用尽: {error_str}") | ||
| 66 | raise | ||
| 67 | else: | ||
| 68 | # 非连接错误直接抛出 | ||
| 69 | raise | ||
| 70 | |||
| 71 | if last_exception: | ||
| 72 | raise last_exception | ||
| 73 | return wrapper | ||
| 74 | return decorator | ||
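`retry_on_connection_error` 中 `delay * (2 ** attempt)` 的指数退避节奏,可以用一个注入了假 sleep 的简化版单独验证(以下均为示意代码,`flaky` 为模拟的不稳定调用):

```python
from functools import wraps
from typing import Callable, List


def retry_with_backoff(max_retries: int = 3, delay: float = 1.0,
                       sleep: Callable[[float], None] = lambda s: None):
    """简化版重试装饰器:失败后按 delay * 2**attempt 退避;sleep 可注入以便测试。"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except ConnectionError:
                    if attempt == max_retries - 1:
                        raise
                    sleep(delay * (2 ** attempt))  # 指数退避:1s, 2s, 4s...
        return wrapper
    return decorator


waits: List[float] = []
calls = {"n": 0}


@retry_with_backoff(max_retries=3, delay=1.0, sleep=waits.append)
def flaky() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("simulated")
    return "ok"
```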
| 75 | |||
| 76 | |||
| 77 | class OSSService: | ||
| 78 | """阿里云OSS服务类""" | ||
| 79 | |||
| 80 | def __init__(self): | ||
| 81 | """初始化OSS客户端""" | ||
| 82 | if not all([ | ||
| 83 | settings.OSS_ACCESS_KEY_ID, | ||
| 84 | settings.OSS_ACCESS_KEY_SECRET, | ||
| 85 | settings.OSS_ENDPOINT, | ||
| 86 | settings.OSS_BUCKET_NAME | ||
| 87 | ]): | ||
| 88 | raise BusinessException("OSS配置不完整,请检查环境变量") | ||
| 89 | |||
| 90 | # 创建认证对象 | ||
| 91 | auth = oss2.Auth( | ||
| 92 | settings.OSS_ACCESS_KEY_ID, | ||
| 93 | settings.OSS_ACCESS_KEY_SECRET | ||
| 94 | ) | ||
| 95 | |||
| 96 | # 确定Endpoint | ||
| 97 | endpoint = settings.OSS_ENDPOINT | ||
| 98 | if settings.OSS_INTERNAL_ENDPOINT and settings.OSS_REGION: | ||
| 99 | endpoint = f"oss-{settings.OSS_REGION}-internal.aliyuncs.com" | ||
| 100 | |||
| 101 | # 创建Bucket对象 | ||
| 102 | self.bucket = oss2.Bucket( | ||
| 103 | auth, | ||
| 104 | endpoint, | ||
| 105 | settings.OSS_BUCKET_NAME | ||
| 106 | ) | ||
| 107 | |||
| 108 | # 设置超时参数(在 bucket 对象上设置) | ||
| 109 | self.bucket.timeout = ( | ||
| 110 | settings.OSS_CONNECTION_TIMEOUT, | ||
| 111 | settings.OSS_READ_TIMEOUT | ||
| 112 | ) | ||
| 113 | |||
| 114 | # 配置重试参数 | ||
| 115 | self.max_retries = settings.OSS_REQUEST_RETRY_TIMES | ||
| 116 | self.retry_delay = settings.OSS_RETRY_DELAY | ||
| 117 | |||
| 118 | logger.info( | ||
| 119 | f"OSS服务已初始化: endpoint={endpoint}, " | ||
| 120 | f"timeout={settings.OSS_CONNECTION_TIMEOUT}/{settings.OSS_READ_TIMEOUT}s, " | ||
| 121 | f"retries={self.max_retries}" | ||
| 122 | ) | ||
| 123 | |||
| 124 | @staticmethod | ||
| 125 | def _validate_file_extension(filename: str) -> bool: | ||
| 126 | """验证文件扩展名""" | ||
| 127 | ext = Path(filename).suffix.lower() | ||
| 128 | if not ext: | ||
| 129 | return False | ||
| 130 | return ext in [e.lower() for e in settings.OSS_ALLOWED_EXTENSIONS] | ||
| 131 | |||
| 132 | @staticmethod | ||
| 133 | def _validate_file_size(file_size: int) -> bool: | ||
| 134 | """验证文件大小""" | ||
| 135 | return 0 < file_size <= settings.OSS_MAX_FILE_SIZE | ||
| 136 | |||
| 137 | @staticmethod | ||
| 138 | def _generate_object_key( | ||
| 139 | filename: str, | ||
| 140 | user_id: int, | ||
| 141 | prefix: Optional[str] = None | ||
| 142 | ) -> str: | ||
| 143 | """ | ||
| 144 | 生成OSS对象键(路径) | ||
| 145 | 格式: {prefix}/{user_id}/{date}/{hash}_{filename} | ||
| 146 | """ | ||
| 147 | # 获取文件扩展名 | ||
| 148 | ext = Path(filename).suffix.lower() | ||
| 149 | name = Path(filename).stem | ||
| 150 | |||
| 151 | # 生成时间戳目录 | ||
| 152 | now = datetime.now() | ||
| 153 | date_path = now.strftime("%Y%m%d") | ||
| 154 | |||
| 155 | # 生成唯一标识(用户ID + 时间戳 + 文件名的 MD5 前 8 位) | ||
| 156 | timestamp = now.strftime("%H%M%S%f") | ||
| 157 | unique_id = hashlib.md5(f"{user_id}{timestamp}{name}".encode()).hexdigest()[:8] | ||
| 158 | |||
| 159 | # 组合文件名 | ||
| 160 | new_filename = f"{unique_id}_{name}{ext}" | ||
| 161 | |||
| 162 | # 确定前缀 | ||
| 163 | base_prefix = prefix or settings.OSS_UPLOAD_PATH_PREFIX | ||
| 164 | |||
| 165 | # 组合完整路径(加上全局前缀) | ||
| 166 | if settings.OSS_GLOBAL_PREFIX: | ||
| 167 | return f"{settings.OSS_GLOBAL_PREFIX}/{base_prefix}/{user_id}/{date_path}/{new_filename}" | ||
| 168 | return f"{base_prefix}/{user_id}/{date_path}/{new_filename}" | ||
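`_generate_object_key` 生成的目录结构(前缀/用户/日期/哈希_原名)可以这样示意(独立的示例函数,`uploads` 前缀为示例值):

```python
import hashlib
from datetime import datetime
from pathlib import Path


def make_object_key(filename: str, user_id: int, prefix: str = "uploads") -> str:
    """仿照 _generate_object_key 的路径方案:{prefix}/{user_id}/{日期}/{hash8}_{原名}。"""
    ext = Path(filename).suffix.lower()
    name = Path(filename).stem
    now = datetime.now()
    uid = hashlib.md5(
        f"{user_id}{now.strftime('%H%M%S%f')}{name}".encode()
    ).hexdigest()[:8]
    return f"{prefix}/{user_id}/{now.strftime('%Y%m%d')}/{uid}_{name}{ext}"
```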
| 169 | |||
| 170 | @staticmethod | ||
| 171 | def _generate_object_key_simple( | ||
| 172 | filename: str, | ||
| 173 | entity_type: str | ||
| 174 | ) -> str: | ||
| 175 | """ | ||
| 176 | 生成OSS对象键(路径)- 简化版本,用于角色/场景图片转存 | ||
| 177 | 格式: {entity_type}/{date}/{hash}{ext} | ||
| 178 | """ | ||
| 179 | # 获取文件扩展名 | ||
| 180 | ext = Path(filename).suffix.lower() | ||
| 181 | name = Path(filename).stem | ||
| 182 | |||
| 183 | # 生成时间戳目录 | ||
| 184 | now = datetime.now() | ||
| 185 | date_path = now.strftime("%Y%m%d") | ||
| 186 | |||
| 187 | # 生成唯一标识(类型 + 时间戳 + 文件名的 MD5 前 8 位) | ||
| 188 | timestamp = now.strftime("%H%M%S%f") | ||
| 189 | unique_id = hashlib.md5(f"{entity_type}{timestamp}{name}".encode()).hexdigest()[:8] | ||
| 190 | |||
| 191 | # 组合文件名 | ||
| 192 | new_filename = f"{unique_id}{ext}" | ||
| 193 | |||
| 194 | # 组合完整路径:类型/日期/文件名.后缀 | ||
| 195 | if settings.OSS_GLOBAL_PREFIX: | ||
| 196 | return f"{settings.OSS_GLOBAL_PREFIX}/{entity_type}/{date_path}/{new_filename}" | ||
| 197 | return f"{entity_type}/{date_path}/{new_filename}" | ||
| 198 | |||
| 199 | @staticmethod | ||
| 200 | def _get_content_type(filename: str) -> str: | ||
| 201 | """获取文件MIME类型""" | ||
| 202 | content_type, _ = mimetypes.guess_type(filename) | ||
| 203 | return content_type or 'application/octet-stream' | ||
| 204 | |||
| 205 | @staticmethod | ||
| 206 | def _get_extension_from_mime(mime_type: str) -> str: | ||
| 207 | """ | ||
| 208 | 根据MIME类型获取文件扩展名 | ||
| 209 | |||
| 210 | 参数: | ||
| 211 | mime_type: MIME类型(如 image/jpeg, video/mp4) | ||
| 212 | |||
| 213 | 返回: | ||
| 214 | 文件扩展名(包含点号,如 .jpg, .mp4) | ||
| 215 | """ | ||
| 216 | # MIME类型到扩展名的映射 | ||
| 217 | mime_to_ext = { | ||
| 218 | # 图片 | ||
| 219 | 'image/jpeg': '.jpg', | ||
| 220 | 'image/jpg': '.jpg', | ||
| 221 | 'image/png': '.png', | ||
| 222 | 'image/gif': '.gif', | ||
| 223 | 'image/webp': '.webp', | ||
| 224 | 'image/svg+xml': '.svg', | ||
| 225 | 'image/bmp': '.bmp', | ||
| 226 | 'image/x-icon': '.ico', | ||
| 227 | # 视频 | ||
| 228 | 'video/mp4': '.mp4', | ||
| 229 | 'video/mpeg': '.mpeg', | ||
| 230 | 'video/webm': '.webm', | ||
| 231 | 'video/quicktime': '.mov', | ||
| 232 | 'video/x-msvideo': '.avi', | ||
| 233 | 'video/x-matroska': '.mkv', | ||
| 234 | # 音频 | ||
| 235 | 'audio/mpeg': '.mp3', | ||
| 236 | 'audio/wav': '.wav', | ||
| 237 | 'audio/ogg': '.ogg', | ||
| 238 | 'audio/webm': '.weba', | ||
| 239 | # 其他 | ||
| 240 | 'application/pdf': '.pdf', | ||
| 241 | 'text/plain': '.txt', | ||
| 242 | 'application/json': '.json', | ||
| 243 | } | ||
| 244 | |||
| 245 | # 处理带参数的 MIME 类型(如 image/jpeg; charset=utf-8) | ||
| 246 | base_mime = mime_type.split(';')[0].strip().lower() | ||
| 247 | |||
| 248 | return mime_to_ext.get(base_mime, '') | ||
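`_get_extension_from_mime` 先截断 `;` 之后的参数再查表,这一行为可以独立演示(映射表仅为节选,`ext_from_mime` 为示例函数):

```python
# 节选的 MIME 类型到扩展名映射
MIME_TO_EXT = {"image/jpeg": ".jpg", "video/mp4": ".mp4", "audio/mpeg": ".mp3"}


def ext_from_mime(mime_type: str) -> str:
    """去掉 ';' 之后的参数部分并统一小写,再查表;未知类型返回空串。"""
    base = mime_type.split(";")[0].strip().lower()
    return MIME_TO_EXT.get(base, "")
```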
| 249 | |||
| 250 | def upload_file( | ||
| 251 | self, | ||
| 252 | file_data: BinaryIO, | ||
| 253 | filename: str, | ||
| 254 | user_id: int, | ||
| 255 | prefix: Optional[str] = None, | ||
| 256 | validate_extension: bool = True | ||
| 257 | ) -> Dict[str, Any]: | ||
| 258 | """ | ||
| 259 | 直接上传文件到OSS(适合小文件) | ||
| 260 | |||
| 261 | 参数: | ||
| 262 | file_data: 文件二进制流 | ||
| 263 | filename: 原始文件名 | ||
| 264 | user_id: 用户ID | ||
| 265 | prefix: 路径前缀(可选) | ||
| 266 | validate_extension: 是否验证文件扩展名 | ||
| 267 | |||
| 268 | 返回: | ||
| 269 | 上传结果信息 | ||
| 270 | """ | ||
| 271 | # 使用重试装饰器包装实际上传操作 | ||
| 272 | @retry_on_connection_error( | ||
| 273 | max_retries=self.max_retries, | ||
| 274 | delay=self.retry_delay | ||
| 275 | ) | ||
| 276 | def _do_upload(): | ||
| 277 | # 验证文件扩展名 | ||
| 278 | if validate_extension and not self._validate_file_extension(filename): | ||
| 279 | raise ValidationException( | ||
| 280 | f"不支持的文件类型,允许的类型: {', '.join(settings.OSS_ALLOWED_EXTENSIONS)}" | ||
| 281 | ) | ||
| 282 | |||
| 283 | # 生成OSS对象键 | ||
| 284 | object_key = self._generate_object_key(filename, user_id, prefix) | ||
| 285 | |||
| 286 | # 获取Content-Type | ||
| 287 | content_type = self._get_content_type(filename) | ||
| 288 | |||
| 289 | # 上传文件 | ||
| 290 | result = self.bucket.put_object( | ||
| 291 | object_key, | ||
| 292 | file_data, | ||
| 293 | headers={'Content-Type': content_type} | ||
| 294 | ) | ||
| 295 | |||
| 296 | # 构建文件URL | ||
| 297 | file_url = self._build_file_url(object_key) | ||
| 298 | |||
| 299 | return { | ||
| 300 | "object_key": object_key, | ||
| 301 | "filename": filename, | ||
| 302 | "url": file_url, | ||
| 303 | "content_type": content_type, | ||
| 304 | "etag": result.etag, | ||
| 305 | "uploaded_at": datetime.now().isoformat() | ||
| 306 | } | ||
| 307 | |||
| 308 | try: | ||
| 309 | return _do_upload() | ||
| 310 | except oss2.exceptions.OssError as e: | ||
| 311 | logger.error(f"OSS上传失败: {e.message}") | ||
| 312 | raise ExternalServiceException(f"OSS上传失败: {e.message}") | ||
| 313 | except ValidationException: | ||
| 314 | raise | ||
| 315 | except Exception as e: | ||
| 316 | logger.error(f"文件上传失败: {str(e)}", exc_info=True) | ||
| 317 | raise BusinessException(f"文件上传失败: {str(e)}") | ||
| 318 | |||
| 319 | def upload_file_with_size( | ||
| 320 | self, | ||
| 321 | file_data: BinaryIO, | ||
| 322 | filename: str, | ||
| 323 | file_size: int, | ||
| 324 | user_id: int, | ||
| 325 | prefix: Optional[str] = None | ||
| 326 | ) -> Dict[str, Any]: | ||
| 327 | """ | ||
| 328 | 根据文件大小智能选择上传方式(带文件大小验证) | ||
| 329 | |||
| 330 | 参数: | ||
| 331 | file_data: 文件二进制流 | ||
| 332 | filename: 原始文件名 | ||
| 333 | file_size: 文件大小(字节) | ||
| 334 | user_id: 用户ID | ||
| 335 | prefix: 路径前缀(可选) | ||
| 336 | |||
| 337 | 返回: | ||
| 338 | 上传结果信息 | ||
| 339 | """ | ||
| 340 | # 验证文件大小 | ||
| 341 | if not self._validate_file_size(file_size): | ||
| 342 | raise ValidationException( | ||
| 343 | f"文件大小超出限制,最大允许 {settings.OSS_MAX_FILE_SIZE / 1024 / 1024:.2f}MB" | ||
| 344 | ) | ||
| 345 | |||
| 346 | # 根据文件大小选择上传方式 | ||
| 347 | if file_size > settings.OSS_MULTIPART_THRESHOLD: | ||
| 348 | return self.multipart_upload(file_data, filename, file_size, user_id, prefix) | ||
| 349 | else: | ||
| 350 | result = self.upload_file(file_data, filename, user_id, prefix) | ||
| 351 | result["size"] = file_size | ||
| 352 | return result | ||
| 353 | |||
| 354 | def multipart_upload( | ||
| 355 | self, | ||
| 356 | file_data: BinaryIO, | ||
| 357 | filename: str, | ||
| 358 | file_size: int, | ||
| 359 | user_id: int, | ||
| 360 | prefix: Optional[str] = None, | ||
| 361 | part_size: Optional[int] = None, | ||
| 362 | ) -> Dict[str, Any]: | ||
| 363 | """ | ||
| 364 | 分片上传文件到OSS(适合大文件) | ||
| 365 | |||
| 366 | 参数: | ||
| 367 | file_data: 文件二进制流 | ||
| 368 | filename: 原始文件名 | ||
| 369 | file_size: 文件大小(字节) | ||
| 370 | user_id: 用户ID | ||
| 371 | prefix: 路径前缀(可选) | ||
| 372 | |||
| 373 | 返回: | ||
| 374 | 上传结果信息 | ||
| 375 | """ | ||
| 376 | upload_id = None | ||
| 377 | object_key = None | ||
| 378 | effective_part_size = part_size or settings.OSS_PART_SIZE | ||
| 379 | |||
| 380 | @retry_on_connection_error( | ||
| 381 | max_retries=self.max_retries, | ||
| 382 | delay=self.retry_delay | ||
| 383 | ) | ||
| 384 | def _do_multipart_upload(): | ||
| 385 | nonlocal upload_id, object_key | ||
| 386 | |||
| 387 | # 验证文件扩展名 | ||
| 388 | if not self._validate_file_extension(filename): | ||
| 389 | raise ValidationException( | ||
| 390 | f"不支持的文件类型,允许的类型: {', '.join(settings.OSS_ALLOWED_EXTENSIONS)}" | ||
| 391 | ) | ||
| 392 | |||
| 393 | # 生成OSS对象键 | ||
| 394 | object_key = self._generate_object_key(filename, user_id, prefix) | ||
| 395 | |||
| 396 | # 获取Content-Type | ||
| 397 | content_type = self._get_content_type(filename) | ||
| 398 | |||
| 399 | # 初始化分片上传 | ||
| 400 | upload_id = self.bucket.init_multipart_upload( | ||
| 401 | object_key, | ||
| 402 | headers={'Content-Type': content_type} | ||
| 403 | ).upload_id | ||
| 404 | |||
| 405 | logger.info(f"[multipart_upload] 开始分片上传: {object_key}, upload_id={upload_id}") | ||
| 406 | |||
| 407 | # 计算分片数量 | ||
| 408 | part_size = effective_part_size | ||
| 409 | part_count = (file_size + part_size - 1) // part_size | ||
| 410 | |||
| 411 | # 上传所有分片 | ||
| 412 | parts = [] | ||
| 413 | for part_number in range(1, part_count + 1): | ||
| 414 | # 读取分片数据 | ||
| 415 | offset = (part_number - 1) * part_size | ||
| 416 | size = min(part_size, file_size - offset) | ||
| 417 | file_data.seek(offset) | ||
| 418 | part_data = file_data.read(size) | ||
| 419 | |||
| 420 | # 上传分片(每个分片也使用重试) | ||
| 421 | @retry_on_connection_error( | ||
| 422 | max_retries=self.max_retries, | ||
| 423 | delay=self.retry_delay | ||
| 424 | ) | ||
| 425 | def _upload_part(): | ||
| 426 | result = self.bucket.upload_part( | ||
| 427 | object_key, | ||
| 428 | upload_id, | ||
| 429 | part_number, | ||
| 430 | part_data | ||
| 431 | ) | ||
| 432 | logger.debug(f"分片 {part_number}/{part_count} 上传成功") | ||
| 433 | return result | ||
| 434 | |||
| 435 | result = _upload_part() | ||
| 436 | parts.append(oss2.models.PartInfo(part_number, result.etag)) | ||
| 437 | |||
| 438 | # 完成分片上传 | ||
| 439 | @retry_on_connection_error( | ||
| 440 | max_retries=self.max_retries, | ||
| 441 | delay=self.retry_delay | ||
| 442 | ) | ||
| 443 | def _complete_upload(): | ||
| 444 | return self.bucket.complete_multipart_upload( | ||
| 445 | object_key, | ||
| 446 | upload_id, | ||
| 447 | parts | ||
| 448 | ) | ||
| 449 | |||
| 450 | result = _complete_upload() | ||
| 451 | |||
| 452 | # 构建文件URL | ||
| 453 | file_url = self._build_file_url(object_key) | ||
| 454 | |||
| 455 | return { | ||
| 456 | "object_key": object_key, | ||
| 457 | "filename": filename, | ||
| 458 | "url": file_url, | ||
| 459 | "size": file_size, | ||
| 460 | "content_type": content_type, | ||
| 461 | "etag": result.etag, | ||
| 462 | "upload_id": upload_id, | ||
| 463 | "part_count": part_count, | ||
| 464 | "uploaded_at": datetime.now().isoformat() | ||
| 465 | } | ||
| 466 | |||
| 467 | try: | ||
| 468 | return _do_multipart_upload() | ||
| 469 | except oss2.exceptions.OssError as e: | ||
| 470 | logger.error(f"OSS分片上传失败: {e.message}") | ||
| 471 | # 上传失败,尝试取消分片上传 | ||
| 472 | if upload_id and object_key: | ||
| 473 | try: | ||
| 474 | self.bucket.abort_multipart_upload( | ||
| 475 | object_key, | ||
| 476 | upload_id | ||
| 477 | ) | ||
| 478 | logger.info(f"已取消分片上传: {object_key}, upload_id={upload_id}") | ||
| 479 | except Exception as cancel_error: | ||
| 480 | logger.warning(f"取消分片上传失败: {cancel_error}") | ||
| 481 | raise ExternalServiceException(f"OSS分片上传失败: {e.message}") | ||
| 482 | except (ValidationException, ExternalServiceException): | ||
| 483 | raise | ||
| 484 | except Exception as e: | ||
| 485 | logger.error(f"分片上传失败: {str(e)}", exc_info=True) | ||
| 486 | # 上传失败,尝试取消分片上传 | ||
| 487 | if upload_id and object_key: | ||
| 488 | try: | ||
| 489 | self.bucket.abort_multipart_upload( | ||
| 490 | object_key, | ||
| 491 | upload_id | ||
| 492 | ) | ||
| 493 | except Exception: | ||
| 494 | pass | ||
| 495 | raise BusinessException(f"分片上传失败: {str(e)}") | ||
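`multipart_upload` 中 `(file_size + part_size - 1) // part_size` 的向上取整分片算法可以单独验证(示意代码,`split_parts` 为示例函数):

```python
from typing import List


def split_parts(file_size: int, part_size: int) -> List[int]:
    """按 multipart_upload 中的算法计算每个分片的字节数。"""
    part_count = (file_size + part_size - 1) // part_size  # 向上取整
    sizes = []
    for part_number in range(1, part_count + 1):
        offset = (part_number - 1) * part_size
        sizes.append(min(part_size, file_size - offset))  # 末片可能不足 part_size
    return sizes
```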
| 496 | |||
| 497 | def generate_presigned_url( | ||
| 498 | self, | ||
| 499 | filename: str, | ||
| 500 | user_id: int, | ||
| 501 | expires: Optional[int] = None, | ||
| 502 | prefix: Optional[str] = None | ||
| 503 | ) -> Dict[str, Any]: | ||
| 504 | """ | ||
| 505 | 生成预签名上传URL(前端直传) | ||
| 506 | |||
| 507 | 参数: | ||
| 508 | filename: 文件名 | ||
| 509 | user_id: 用户ID | ||
| 510 | expires: 过期时间(秒),默认使用配置值 | ||
| 511 | prefix: 路径前缀(可选) | ||
| 512 | |||
| 513 | 返回: | ||
| 514 | 预签名URL信息 | ||
| 515 | """ | ||
| 516 | try: | ||
| 517 | # 验证文件扩展名 | ||
| 518 | if not self._validate_file_extension(filename): | ||
| 519 | raise ValidationException( | ||
| 520 | f"不支持的文件类型,允许的类型: {', '.join(settings.OSS_ALLOWED_EXTENSIONS)}" | ||
| 521 | ) | ||
| 522 | |||
| 523 | # 生成OSS对象键 | ||
| 524 | object_key = self._generate_object_key(filename, user_id, prefix) | ||
| 525 | |||
| 526 | # 设置过期时间 | ||
| 527 | expires = expires or settings.OSS_SIGNED_URL_EXPIRE | ||
| 528 | |||
| 529 | # 生成预签名URL | ||
| 530 | signed_url = self.bucket.sign_url( | ||
| 531 | 'PUT', | ||
| 532 | object_key, | ||
| 533 | expires, | ||
| 534 | headers={'Content-Type': self._get_content_type(filename)} | ||
| 535 | ) | ||
| 536 | |||
| 537 | # 构建文件URL(上传后的访问URL) | ||
| 538 | file_url = self._build_file_url(object_key) | ||
| 539 | |||
| 540 | return { | ||
| 541 | "upload_url": signed_url, | ||
| 542 | "object_key": object_key, | ||
| 543 | "file_url": file_url, | ||
| 544 | "expires_in": expires, | ||
| 545 | "expires_at": (datetime.now() + timedelta(seconds=expires)).isoformat(), | ||
| 546 | "method": "PUT", | ||
| 547 | "headers": { | ||
| 548 | "Content-Type": self._get_content_type(filename) | ||
| 549 | } | ||
| 550 | } | ||
| 551 | |||
| 552 | except oss2.exceptions.OssError as e: | ||
| 553 | raise ExternalServiceException(f"生成签名URL失败: {e.message}") | ||
| 554 | except ValidationException: | ||
| 555 | raise | ||
| 556 | except Exception as e: | ||
| 557 | raise BusinessException(f"生成签名URL失败: {str(e)}") | ||
| 558 | |||
| 559 | def generate_multipart_presigned_urls( | ||
| 560 | self, | ||
| 561 | filename: str, | ||
| 562 | file_size: int, | ||
| 563 | user_id: int, | ||
| 564 | expires: Optional[int] = None, | ||
| 565 | prefix: Optional[str] = None | ||
| 566 | ) -> Dict[str, Any]: | ||
| 567 | """ | ||
| 568 | 生成分片上传的预签名URL(前端分片直传) | ||
| 569 | |||
| 570 | 参数: | ||
| 571 | filename: 文件名 | ||
| 572 | file_size: 文件大小(字节) | ||
| 573 | user_id: 用户ID | ||
| 574 | expires: 过期时间(秒) | ||
| 575 | prefix: 路径前缀(可选) | ||
| 576 | |||
| 577 | 返回: | ||
| 578 | 分片上传信息和预签名URL列表 | ||
| 579 | """ | ||
| 580 | try: | ||
| 581 | # 验证文件大小 | ||
| 582 | if not self._validate_file_size(file_size): | ||
| 583 | raise ValidationException( | ||
| 584 | f"文件大小超出限制,最大允许 {settings.OSS_MAX_FILE_SIZE / 1024 / 1024:.2f}MB" | ||
| 585 | ) | ||
| 586 | |||
| 587 | # 验证文件扩展名 | ||
| 588 | if not self._validate_file_extension(filename): | ||
| 589 | raise ValidationException( | ||
| 590 | f"不支持的文件类型,允许的类型: {', '.join(settings.OSS_ALLOWED_EXTENSIONS)}" | ||
| 591 | ) | ||
| 592 | |||
| 593 | # 生成OSS对象键 | ||
| 594 | object_key = self._generate_object_key(filename, user_id, prefix) | ||
| 595 | |||
| 596 | # 初始化分片上传 | ||
| 597 | upload_id = self.bucket.init_multipart_upload( | ||
| 598 | object_key, | ||
| 599 | headers={'Content-Type': self._get_content_type(filename)} | ||
| 600 | ).upload_id | ||
| 601 | |||
| 602 | # 计算分片数量 | ||
| 603 | part_size = settings.OSS_PART_SIZE | ||
| 604 | part_count = (file_size + part_size - 1) // part_size | ||
| 605 | |||
| 606 | # 设置过期时间 | ||
| 607 | expires = expires or settings.OSS_SIGNED_URL_EXPIRE | ||
| 608 | |||
| 609 | # 为每个分片生成预签名URL | ||
| 610 | part_urls = [] | ||
| 611 | for part_number in range(1, part_count + 1): | ||
| 612 | params = { | ||
| 613 | 'uploadId': upload_id, | ||
| 614 | 'partNumber': str(part_number) | ||
| 615 | } | ||
| 616 | signed_url = self.bucket.sign_url( | ||
| 617 | 'PUT', | ||
| 618 | object_key, | ||
| 619 | expires, | ||
| 620 | params=params | ||
| 621 | ) | ||
| 622 | part_urls.append({ | ||
| 623 | "part_number": part_number, | ||
| 624 | "upload_url": signed_url | ||
| 625 | }) | ||
| 626 | |||
| 627 | # 构建文件URL(完成后的访问URL) | ||
| 628 | file_url = self._build_file_url(object_key) | ||
| 629 | |||
| 630 | return { | ||
| 631 | "upload_id": upload_id, | ||
| 632 | "object_key": object_key, | ||
| 633 | "file_url": file_url, | ||
| 634 | "part_size": part_size, | ||
| 635 | "part_count": part_count, | ||
| 636 | "part_urls": part_urls, | ||
| 637 | "expires_in": expires, | ||
| 638 | "expires_at": (datetime.now() + timedelta(seconds=expires)).isoformat() | ||
| 639 | } | ||
| 640 | |||
| 641 | except oss2.exceptions.OssError as e: | ||
| 642 | raise ExternalServiceException(f"初始化分片上传失败: {e.message}") | ||
| 643 | except (ValidationException, ExternalServiceException): | ||
| 644 | raise | ||
| 645 | except Exception as e: | ||
| 646 | raise BusinessException(f"初始化分片上传失败: {str(e)}") | ||
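上面的 `part_count = (file_size + part_size - 1) // part_size` 是按分片大小向上取整。下面是这一算法的独立示意(10MB 分片大小为假设值,实际以 `settings.OSS_PART_SIZE` 为准):

```python
# 分片数量 = 文件大小按分片大小向上取整,与上文的整除写法一致
PART_SIZE = 10 * 1024 * 1024  # 假设值,仅作示意

def calc_part_count(file_size: int, part_size: int = PART_SIZE) -> int:
    if file_size <= 0:
        return 0
    return (file_size + part_size - 1) // part_size

print(calc_part_count(25 * 1024 * 1024))  # 3 个分片:10MB + 10MB + 5MB
```

前端按 `part_urls` 中的序号逐片 PUT,每片大小为 `part_size`(最后一片可能不足)。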
| 647 | |||
| 648 | def complete_multipart_upload_by_client( | ||
| 649 | self, | ||
| 650 | object_key: str, | ||
| 651 | upload_id: str, | ||
| 652 | parts: list | ||
| 653 | ) -> Dict[str, Any]: | ||
| 654 | """ | ||
| 655 | 完成客户端分片上传 | ||
| 656 | |||
| 657 | 参数: | ||
| 658 | object_key: OSS对象键 | ||
| 659 | upload_id: 上传ID | ||
| 660 | parts: 分片信息列表 [{"part_number": 1, "etag": "xxx"}, ...] | ||
| 661 | |||
| 662 | 返回: | ||
| 663 | 完成结果 | ||
| 664 | """ | ||
| 665 | try: | ||
| 666 | # 构建分片信息 | ||
| 667 | part_info_list = [ | ||
| 668 | oss2.models.PartInfo(part["part_number"], part["etag"]) | ||
| 669 | for part in parts | ||
| 670 | ] | ||
| 671 | |||
| 672 | # 完成分片上传 | ||
| 673 | result = self.bucket.complete_multipart_upload( | ||
| 674 | object_key, | ||
| 675 | upload_id, | ||
| 676 | part_info_list | ||
| 677 | ) | ||
| 678 | |||
| 679 | # 构建文件URL | ||
| 680 | file_url = self._build_file_url(object_key) | ||
| 681 | |||
| 682 | return { | ||
| 683 | "object_key": object_key, | ||
| 684 | "url": file_url, | ||
| 685 | "etag": result.etag, | ||
| 686 | "completed_at": datetime.now().isoformat() | ||
| 687 | } | ||
| 688 | |||
| 689 | except oss2.exceptions.OssError as e: | ||
| 690 | raise ExternalServiceException(f"完成分片上传失败: {e.message}") | ||
| 691 | except Exception as e: | ||
| 692 | raise BusinessException(f"完成分片上传失败: {str(e)}") | ||
| 693 | |||
| 694 | def abort_multipart_upload( | ||
| 695 | self, | ||
| 696 | object_key: str, | ||
| 697 | upload_id: str | ||
| 698 | ) -> bool: | ||
| 699 | """ | ||
| 700 | 取消分片上传 | ||
| 701 | |||
| 702 | 参数: | ||
| 703 | object_key: OSS对象键 | ||
| 704 | upload_id: 上传ID | ||
| 705 | |||
| 706 | 返回: | ||
| 707 | 是否成功 | ||
| 708 | """ | ||
| 709 | @retry_on_connection_error( | ||
| 710 | max_retries=self.max_retries, | ||
| 711 | delay=self.retry_delay | ||
| 712 | ) | ||
| 713 | def _do_abort(): | ||
| 714 | self.bucket.abort_multipart_upload( | ||
| 715 | object_key, | ||
| 716 | upload_id | ||
| 717 | ) | ||
| 718 | |||
| 719 | try: | ||
| 720 | _do_abort() | ||
| 721 | return True | ||
| 722 | except oss2.exceptions.OssError as e: | ||
| 723 | raise ExternalServiceException(f"取消分片上传失败: {e.message}") | ||
| 724 | except Exception as e: | ||
| 725 | raise BusinessException(f"取消分片上传失败: {str(e)}") | ||
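多个方法都套用了 `retry_on_connection_error` 装饰器。下面是这类重试装饰器典型结构的最小示意(非项目内实现,参数名与异常类型均为假设):

```python
import functools
import time

def retry_on_error(max_retries=3, delay=0.01, exceptions=(ConnectionError,)):
    """示意:命中指定异常时重试,超过次数后把最后一次异常抛出。"""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    if attempt == max_retries:
                        raise
                    time.sleep(delay)  # 实际实现中常用指数退避
        return wrapper
    return decorator

calls = {"n": 0}

@retry_on_error(max_retries=3)
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient")
    return "ok"

print(flaky(), calls["n"])  # ok 3
```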
| 726 | |||
| 727 | def delete_file(self, object_key: str) -> bool: | ||
| 728 | """ | ||
| 729 | 删除OSS文件 | ||
| 730 | |||
| 731 | 参数: | ||
| 732 | object_key: OSS对象键 | ||
| 733 | |||
| 734 | 返回: | ||
| 735 | 是否成功 | ||
| 736 | """ | ||
| 737 | try: | ||
| 738 | self.bucket.delete_object(object_key) | ||
| 739 | return True | ||
| 740 | except oss2.exceptions.OssError as e: | ||
| 741 | raise ExternalServiceException(f"删除文件失败: {e.message}") | ||
| 742 | except Exception as e: | ||
| 743 | raise BusinessException(f"删除文件失败: {str(e)}") | ||
| 744 | |||
| 745 | def delete_files_batch(self, object_keys: list) -> Dict[str, Any]: | ||
| 746 | """ | ||
| 747 | 批量删除OSS文件 | ||
| 748 | |||
| 749 | 参数: | ||
| 750 | object_keys: OSS对象键列表 | ||
| 751 | |||
| 752 | 返回: | ||
| 753 | 删除结果 | ||
| 754 | """ | ||
| 755 | try: | ||
| 756 | result = self.bucket.batch_delete_objects(object_keys) | ||
| 757 | return { | ||
| 758 | "deleted_count": len(result.deleted_keys), | ||
| 759 | "deleted_keys": result.deleted_keys | ||
| 760 | } | ||
| 761 | except oss2.exceptions.OssError as e: | ||
| 762 | raise ExternalServiceException(f"批量删除文件失败: {e.message}") | ||
| 763 | except Exception as e: | ||
| 764 | raise BusinessException(f"批量删除文件失败: {str(e)}") | ||
| 765 | |||
| 766 | def upload_from_url( | ||
| 767 | self, | ||
| 768 | url: str, | ||
| 769 | entity_type: str, | ||
| 770 | filename: Optional[str] = None | ||
| 771 | ) -> Dict[str, Any]: | ||
| 772 | """ | ||
| 773 | 从URL下载文件并上传到OSS(用于转存外部生成的图片或视频) | ||
| 774 | |||
| 775 | 参数: | ||
| 776 | url: 文件URL(图片或视频) | ||
| 777 | entity_type: 实体类型(character/scene),用于构建存储路径 | ||
| 778 | filename: 自定义文件名(可选),如果未提供则从URL提取或根据Content-Type推断 | ||
| 779 | |||
| 780 | 返回: | ||
| 781 | 上传结果信息 | ||
| 782 | """ | ||
| 783 | import requests | ||
| 784 | from urllib.parse import urlparse | ||
| 785 | |||
| 786 | @retry_on_connection_error( | ||
| 787 | max_retries=self.max_retries, | ||
| 788 | delay=self.retry_delay | ||
| 789 | ) | ||
| 790 | def _do_upload(): | ||
| 791 | # 下载文件 | ||
| 792 | try: | ||
| 793 | response = requests.get(url, timeout=30, stream=True) | ||
| 794 | response.raise_for_status() | ||
| 795 | except Exception as e: | ||
| 796 | logger.error(f"从URL下载文件失败: {url}, error: {e}") | ||
| 797 | raise ExternalServiceException(f"从URL下载文件失败: {str(e)}") | ||
| 798 | |||
| 799 | # 获取文件内容和Content-Type | ||
| 800 | file_data = response.content | ||
| 801 | response_content_type = response.headers.get('Content-Type', '') | ||
| 802 | |||
| 803 | # 确定文件名 | ||
| 804 | final_filename = filename | ||
| 805 | if not final_filename: | ||
| 806 | # 尝试从URL提取文件名 | ||
| 807 | parsed_url = urlparse(url) | ||
| 808 | path = parsed_url.path | ||
| 809 | final_filename = Path(path).name | ||
| 810 | |||
| 811 | # 如果URL中没有文件名或没有扩展名,根据Content-Type推断 | ||
| 812 | if not final_filename or '.' not in final_filename: | ||
| 813 | ext = self._get_extension_from_mime(response_content_type) | ||
| 814 | if ext: | ||
| 815 | # 根据扩展名确定文件类型前缀(不能用 startswith('.mp') 判断,否则 '.mp3' 音频会被误判为视频) | ||
| 816 | type_prefix = 'video' if ext in ('.mp4', '.webm', '.mov', '.avi', '.mkv', '.mpeg') else 'file' | ||
| 817 | final_filename = f"{type_prefix}_{int(datetime.now().timestamp())}{ext}" | ||
| 818 | else: | ||
| 819 | # 如果无法推断,使用通用名称 | ||
| 820 | final_filename = f"file_{int(datetime.now().timestamp())}.bin" | ||
| 821 | |||
| 822 | # 生成OSS对象键 | ||
| 823 | object_key = self._generate_object_key_simple(final_filename, entity_type) | ||
| 824 | |||
| 825 | # 从文件名获取Content-Type(优先使用响应头的Content-Type) | ||
| 826 | content_type = response_content_type or self._get_content_type(final_filename) | ||
| 827 | |||
| 828 | # 上传文件 | ||
| 829 | result = self.bucket.put_object( | ||
| 830 | object_key, | ||
| 831 | file_data, | ||
| 832 | headers={'Content-Type': content_type} | ||
| 833 | ) | ||
| 834 | |||
| 835 | # 构建文件URL | ||
| 836 | file_url = self._build_file_url(object_key) | ||
| 837 | |||
| 838 | logger.info(f"文件转存成功: {url} -> {file_url}") | ||
| 839 | |||
| 840 | return { | ||
| 841 | "object_key": object_key, | ||
| 842 | "filename": final_filename, | ||
| 843 | "url": file_url, | ||
| 844 | "content_type": content_type, | ||
| 845 | "size": len(file_data), | ||
| 846 | "etag": result.etag, | ||
| 847 | "uploaded_at": datetime.now().isoformat() | ||
| 848 | } | ||
| 849 | |||
| 850 | try: | ||
| 851 | return _do_upload() | ||
| 852 | except oss2.exceptions.OssError as e: | ||
| 853 | logger.error(f"OSS转存失败: {e.message}") | ||
| 854 | raise ExternalServiceException(f"OSS转存失败: {e.message}") | ||
| 855 | except ExternalServiceException: | ||
| 856 | raise | ||
| 857 | except Exception as e: | ||
| 858 | logger.error(f"文件转存失败: {str(e)}", exc_info=True) | ||
| 859 | raise BusinessException(f"文件转存失败: {str(e)}") | ||
| 860 | |||
| 861 | def upload_from_base64( | ||
| 862 | self, | ||
| 863 | base64_data_url: str, | ||
| 864 | entity_type: str, | ||
| 865 | filename: Optional[str] = None | ||
| 866 | ) -> Dict[str, Any]: | ||
| 867 | """ | ||
| 868 | 从 Base64 data URL 解码并上传到 OSS | ||
| 869 | |||
| 870 | 参数: | ||
| 871 | base64_data_url: Base64 data URL (格式: data:image/png;base64,...) | ||
| 872 | entity_type: 实体类型(image/character/scene),用于构建存储路径 | ||
| 873 | filename: 自定义文件名(可选),如果未提供则自动生成 | ||
| 874 | |||
| 875 | 返回: | ||
| 876 | 上传结果信息 | ||
| 877 | """ | ||
| 878 | import base64 | ||
| 879 | import re | ||
| 880 | |||
| 881 | @retry_on_connection_error( | ||
| 882 | max_retries=self.max_retries, | ||
| 883 | delay=self.retry_delay | ||
| 884 | ) | ||
| 885 | def _do_upload(): | ||
| 886 | # 解析 Base64 data URL | ||
| 887 | # 格式: data:image/png;base64,iVBORw0KGgo... | ||
| 888 | if not base64_data_url.startswith('data:'): | ||
| 889 | raise ValueError("无效的 Base64 data URL 格式") | ||
| 890 | |||
| 891 | # 提取 MIME 类型和 Base64 数据 | ||
| 892 | match = re.match(r'data:([^;]+);base64,(.+)', base64_data_url) | ||
| 893 | if not match: | ||
| 894 | raise ValueError("无法解析 Base64 data URL") | ||
| 895 | |||
| 896 | mime_type = match.group(1) # 例如: image/png | ||
| 897 | base64_data = match.group(2) | ||
| 898 | |||
| 899 | # 根据 MIME 类型确定文件扩展名 | ||
| 900 | ext_map = { | ||
| 901 | 'image/png': '.png', | ||
| 902 | 'image/jpeg': '.jpg', | ||
| 903 | 'image/jpg': '.jpg', | ||
| 904 | 'image/gif': '.gif', | ||
| 905 | 'image/webp': '.webp', | ||
| 906 | } | ||
| 907 | ext = ext_map.get(mime_type.lower(), '.png') | ||
| 908 | |||
| 909 | # 确定文件名 | ||
| 910 | if filename: | ||
| 911 | final_filename = filename | ||
| 912 | # 确保文件名有正确的扩展名 | ||
| 913 | if not final_filename.endswith(ext): | ||
| 914 | final_filename = Path(filename).stem + ext | ||
| 915 | else: | ||
| 916 | # 生成默认文件名 | ||
| 917 | final_filename = f"image_{int(datetime.now().timestamp())}{ext}" | ||
| 918 | |||
| 919 | # 解码 Base64 数据 | ||
| 920 | try: | ||
| 921 | file_data = base64.b64decode(base64_data) | ||
| 922 | except Exception as e: | ||
| 923 | logger.error(f"Base64 解码失败: {e}") | ||
| 924 | raise ExternalServiceException(f"Base64 解码失败: {str(e)}") | ||
| 925 | |||
| 926 | # 生成 OSS 对象键 | ||
| 927 | object_key = self._generate_object_key_simple(final_filename, entity_type) | ||
| 928 | |||
| 929 | # 上传文件 | ||
| 930 | result = self.bucket.put_object( | ||
| 931 | object_key, | ||
| 932 | file_data, | ||
| 933 | headers={'Content-Type': mime_type} | ||
| 934 | ) | ||
| 935 | |||
| 936 | # 构建文件 URL | ||
| 937 | file_url = self._build_file_url(object_key) | ||
| 938 | |||
| 939 | logger.info(f"Base64 图片上传成功: size={len(file_data)} -> {file_url}") | ||
| 940 | |||
| 941 | return { | ||
| 942 | "object_key": object_key, | ||
| 943 | "filename": final_filename, | ||
| 944 | "url": file_url, | ||
| 945 | "content_type": mime_type, | ||
| 946 | "size": len(file_data), | ||
| 947 | "etag": result.etag, | ||
| 948 | "uploaded_at": datetime.now().isoformat() | ||
| 949 | } | ||
| 950 | |||
| 951 | try: | ||
| 952 | return _do_upload() | ||
| 953 | except oss2.exceptions.OssError as e: | ||
| 954 | logger.error(f"OSS 上传失败: {e.message}") | ||
| 955 | raise ExternalServiceException(f"OSS 上传失败: {e.message}") | ||
| 956 | except (ValueError, ExternalServiceException): | ||
| 957 | raise | ||
| 958 | except Exception as e: | ||
| 959 | logger.error(f"Base64 图片上传失败: {str(e)}", exc_info=True) | ||
| 960 | raise BusinessException(f"Base64 图片上传失败: {str(e)}") | ||
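`upload_from_base64` 用正则把 data URL 拆成 MIME 类型和 Base64 负载。核心解析逻辑可以抽成纯函数单独验证(示意):

```python
import base64
import re

def parse_data_url(data_url: str):
    """与上文相同的解析方式:data:<mime>;base64,<payload>"""
    match = re.match(r'data:([^;]+);base64,(.+)', data_url)
    if not match:
        raise ValueError("无法解析 Base64 data URL")
    mime_type, payload = match.group(1), match.group(2)
    return mime_type, base64.b64decode(payload)

mime, data = parse_data_url("data:image/png;base64,aGVsbG8=")
print(mime, data)  # image/png b'hello'
```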
| 961 | |||
| 962 | def get_file_info(self, object_key: str) -> Dict[str, Any]: | ||
| 963 | """ | ||
| 964 | 获取文件信息 | ||
| 965 | |||
| 966 | 参数: | ||
| 967 | object_key: OSS对象键 | ||
| 968 | |||
| 969 | 返回: | ||
| 970 | 文件信息 | ||
| 971 | """ | ||
| 972 | try: | ||
| 973 | meta = self.bucket.get_object_meta(object_key) | ||
| 974 | return { | ||
| 975 | "object_key": object_key, | ||
| 976 | "size": meta.headers.get('Content-Length'), | ||
| 977 | "content_type": meta.headers.get('Content-Type'), | ||
| 978 | "etag": meta.headers.get('ETag'), | ||
| 979 | "last_modified": meta.headers.get('Last-Modified') | ||
| 980 | } | ||
| 981 | except oss2.exceptions.NoSuchKey: | ||
| 982 | raise NotFoundException(f"文件不存在: {object_key}") | ||
| 983 | except oss2.exceptions.OssError as e: | ||
| 984 | raise ExternalServiceException(f"获取文件信息失败: {e.message}") | ||
| 985 | except Exception as e: | ||
| 986 | raise BusinessException(f"获取文件信息失败: {str(e)}") | ||
| 987 | |||
| 988 | def file_exists(self, object_key: str) -> bool: | ||
| 989 | """ | ||
| 990 | 检查文件是否存在 | ||
| 991 | |||
| 992 | 参数: | ||
| 993 | object_key: OSS对象键 | ||
| 994 | |||
| 995 | 返回: | ||
| 996 | 是否存在 | ||
| 997 | """ | ||
| 998 | try: | ||
| 999 | return self.bucket.object_exists(object_key) | ||
| 1000 | except Exception: | ||
| 1001 | return False | ||
| 1002 | |||
| 1003 | def generate_download_url( | ||
| 1004 | self, | ||
| 1005 | object_key: str, | ||
| 1006 | expires: Optional[int] = None, | ||
| 1007 | filename: Optional[str] = None | ||
| 1008 | ) -> str: | ||
| 1009 | """ | ||
| 1010 | 生成文件下载URL(临时访问) | ||
| 1011 | |||
| 1012 | 参数: | ||
| 1013 | object_key: OSS对象键 | ||
| 1014 | expires: 过期时间(秒) | ||
| 1015 | filename: 下载时的文件名(可选) | ||
| 1016 | |||
| 1017 | 返回: | ||
| 1018 | 下载URL | ||
| 1019 | """ | ||
| 1020 | try: | ||
| 1021 | expires = expires or settings.OSS_SIGNED_URL_EXPIRE | ||
| 1022 | |||
| 1023 | params = {} | ||
| 1024 | if filename: | ||
| 1025 | params['response-content-disposition'] = f'attachment; filename="{filename}"' | ||
| 1026 | |||
| 1027 | return self.bucket.sign_url( | ||
| 1028 | 'GET', | ||
| 1029 | object_key, | ||
| 1030 | expires, | ||
| 1031 | params=params | ||
| 1032 | ) | ||
| 1033 | except oss2.exceptions.OssError as e: | ||
| 1034 | raise ExternalServiceException(f"生成下载URL失败: {e.message}") | ||
| 1035 | except Exception as e: | ||
| 1036 | raise BusinessException(f"生成下载URL失败: {str(e)}") | ||
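`generate_download_url` 把文件名直接放进 `response-content-disposition`。对含中文等非 ASCII 字符的文件名,裸的 `filename="..."` 不一定被所有客户端接受;常见做法是按 RFC 5987 附加 `filename*=UTF-8''...`。下面是该编码方式的一个示意(非项目现有实现):

```python
from urllib.parse import quote

def content_disposition(filename: str) -> str:
    """示意:同时给出 ASCII 回退名与 RFC 5987 编码名。"""
    ascii_fallback = filename.encode('ascii', 'ignore').decode() or 'download'
    return (
        f'attachment; filename="{ascii_fallback}"; '
        f"filename*=UTF-8''{quote(filename)}"
    )

print(content_disposition("想-dj-片段.mp3"))
```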
| 1037 | |||
| 1038 | def _build_file_url(self, object_key: str) -> str: | ||
| 1039 | """ | ||
| 1040 | 构建文件访问URL | ||
| 1041 | |||
| 1042 | 参数: | ||
| 1043 | object_key: OSS对象键 | ||
| 1044 | |||
| 1045 | 返回: | ||
| 1046 | 文件URL | ||
| 1047 | """ | ||
| 1048 | # 如果配置了CDN域名,使用CDN域名 | ||
| 1049 | if settings.OSS_CDN_DOMAIN: | ||
| 1050 | # 移除可能存在的 http:// 或 https:// 前缀 | ||
| 1051 | cdn_domain = settings.OSS_CDN_DOMAIN | ||
| 1052 | if cdn_domain.startswith("http://"): | ||
| 1053 | cdn_domain = cdn_domain[7:] | ||
| 1054 | elif cdn_domain.startswith("https://"): | ||
| 1055 | cdn_domain = cdn_domain[8:] | ||
| 1056 | |||
| 1057 | # 移除域名开头的 / | ||
| 1058 | if cdn_domain.startswith("/"): | ||
| 1059 | cdn_domain = cdn_domain[1:] | ||
| 1060 | |||
| 1061 | protocol = "https" if settings.OSS_USE_HTTPS else "http" | ||
| 1062 | return f"{protocol}://{cdn_domain}/{object_key}" | ||
| 1063 | |||
| 1064 | # 否则使用OSS域名 | ||
| 1065 | # 移除 endpoint 中可能存在的协议前缀 | ||
| 1066 | endpoint = settings.OSS_ENDPOINT | ||
| 1067 | if endpoint.startswith("http://"): | ||
| 1068 | endpoint = endpoint[7:] | ||
| 1069 | elif endpoint.startswith("https://"): | ||
| 1070 | endpoint = endpoint[8:] | ||
| 1071 | |||
| 1072 | protocol = "https" if settings.OSS_USE_HTTPS else "http" | ||
| 1073 | return f"{protocol}://{settings.OSS_BUCKET_NAME}.{endpoint}/{object_key}" | ||
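`_build_file_url` 的关键步骤是先去掉域名上可能携带的协议前缀,再拼出虚拟主机风格的 URL。等价逻辑的独立示意(bucket/endpoint 均为假设值):

```python
def build_file_url(bucket: str, endpoint: str, object_key: str, use_https: bool = True) -> str:
    """示意:规范化 endpoint 后拼 {protocol}://{bucket}.{endpoint}/{key}"""
    for prefix in ("https://", "http://"):
        if endpoint.startswith(prefix):
            endpoint = endpoint[len(prefix):]
    protocol = "https" if use_https else "http"
    return f"{protocol}://{bucket}.{endpoint}/{object_key}"

print(build_file_url("my-bucket", "https://oss-cn-hangzhou.aliyuncs.com", "a/b.png"))
# https://my-bucket.oss-cn-hangzhou.aliyuncs.com/a/b.png
```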
| 1074 | |||
| 1075 | def generate_sts_credentials( | ||
| 1076 | self, | ||
| 1077 | user_id: int, | ||
| 1078 | duration_seconds: Optional[int] = None, | ||
| 1079 | policy: Optional[str] = None, | ||
| 1080 | role_session_name: Optional[str] = None | ||
| 1081 | ) -> Dict[str, Any]: | ||
| 1082 | """ | ||
| 1083 | 生成OSS STS临时凭证(用于前端直传) | ||
| 1084 | |||
| 1085 | 参数: | ||
| 1086 | user_id: 用户ID(用于构建会话名称) | ||
| 1087 | duration_seconds: 凭证有效期(秒),默认使用配置值 | ||
| 1088 | policy: 自定义策略(可选,JSON字符串) | ||
| 1089 | role_session_name: 角色会话名称(可选) | ||
| 1090 | |||
| 1091 | 返回: | ||
| 1092 | STS临时凭证信息,包括: | ||
| 1093 | - access_key_id: 临时AccessKey ID | ||
| 1094 | - access_key_secret: 临时AccessKey Secret | ||
| 1095 | - security_token: 安全令牌 | ||
| 1096 | - expiration: 过期时间(UTC) | ||
| 1097 | - region: 区域 | ||
| 1098 | - bucket: Bucket名称 | ||
| 1099 | - endpoint: OSS端点 | ||
| 1100 | - upload_path_prefix: 上传路径前缀 | ||
| 1101 | """ | ||
| 1102 | if not settings.OSS_STS_ROLE_ARN: | ||
| 1103 | raise BusinessException( | ||
| 1104 | "STS Role ARN未配置,请检查环境变量 OSS_STS_ROLE_ARN" | ||
| 1105 | ) | ||
| 1106 | |||
| 1107 | if not all([ | ||
| 1108 | settings.OSS_ACCESS_KEY_ID, | ||
| 1109 | settings.OSS_ACCESS_KEY_SECRET, | ||
| 1110 | settings.OSS_ENDPOINT | ||
| 1111 | ]): | ||
| 1112 | raise BusinessException( | ||
| 1113 | "OSS配置不完整,请检查环境变量" | ||
| 1114 | ) | ||
| 1115 | |||
| 1116 | try: | ||
| 1117 | # 确定区域ID | ||
| 1118 | # AcsClient 的第三个参数是 region_id,不是 endpoint | ||
| 1119 | # 格式: cn-hangzhou, cn-beijing 等 | ||
| 1120 | if settings.OSS_REGION: | ||
| 1121 | region = settings.OSS_REGION | ||
| 1122 | else: | ||
| 1123 | # 从OSS端点提取区域 | ||
| 1124 | # 格式: oss-cn-hangzhou.aliyuncs.com -> cn-hangzhou | ||
| 1125 | import re | ||
| 1126 | endpoint = settings.OSS_ENDPOINT | ||
| 1127 | if not endpoint: | ||
| 1128 | raise BusinessException("OSS_ENDPOINT 未配置") | ||
| 1129 | # 兼容 ap-southeast-1 等多段区域 ID 以及 -internal 内网端点 | ||
| 1130 | match = re.search(r'oss-([\w-]+?)(?:-internal)?\.aliyuncs\.com', endpoint) | ||
| 1131 | if match: | ||
| 1132 | region = match.group(1) | ||
| 1132 | else: | ||
| 1133 | raise BusinessException( | ||
| 1134 | "无法从OSS端点确定区域,请配置 OSS_REGION" | ||
| 1135 | ) | ||
| 1136 | |||
| 1137 | # 创建STS客户端(使用正确的 region_id 参数) | ||
| 1138 | client = AcsClient( | ||
| 1139 | settings.OSS_ACCESS_KEY_ID, | ||
| 1140 | settings.OSS_ACCESS_KEY_SECRET, | ||
| 1141 | region | ||
| 1142 | ) | ||
| 1143 | |||
| 1144 | # 创建AssumeRole请求 | ||
| 1145 | request = AssumeRoleRequest.AssumeRoleRequest() | ||
| 1146 | request.set_RoleArn(settings.OSS_STS_ROLE_ARN) | ||
| 1147 | request.set_RoleSessionName( | ||
| 1148 | role_session_name or f"aimv-frontend-upload-user-{user_id}" | ||
| 1149 | ) | ||
| 1150 | request.set_DurationSeconds( | ||
| 1151 | duration_seconds or settings.OSS_STS_DURATION_SECONDS | ||
| 1152 | ) | ||
| 1153 | |||
| 1154 | # 设置策略(如果提供了自定义策略) | ||
| 1155 | if policy: | ||
| 1156 | request.set_Policy(policy) | ||
| 1157 | elif settings.OSS_STS_POLICY: | ||
| 1158 | request.set_Policy(settings.OSS_STS_POLICY) | ||
| 1159 | |||
| 1160 | # 发送请求 | ||
| 1161 | response = client.do_action_with_exception(request) | ||
| 1162 | result = json.loads(response.decode('utf-8')) | ||
| 1163 | |||
| 1164 | # 提取凭证信息 | ||
| 1165 | credentials = result['Credentials'] | ||
| 1166 | |||
| 1167 | # 将过期时间转换为北京时间(UTC+8) | ||
| 1168 | from app.schemas.common import convert_datetime_to_beijing | ||
| 1169 | expiration_utc = datetime.strptime(credentials['Expiration'], '%Y-%m-%dT%H:%M:%SZ') | ||
| 1170 | expiration_str = convert_datetime_to_beijing(expiration_utc) | ||
| 1171 | |||
| 1172 | # 构建上传路径前缀 | ||
| 1173 | upload_path_prefix = f"{settings.OSS_UPLOAD_PATH_PREFIX}/{user_id}" | ||
| 1174 | if settings.OSS_GLOBAL_PREFIX: | ||
| 1175 | upload_path_prefix = f"{settings.OSS_GLOBAL_PREFIX}/{upload_path_prefix}" | ||
| 1176 | |||
| 1177 | return { | ||
| 1178 | "access_key_id": credentials['AccessKeyId'], | ||
| 1179 | "access_key_secret": credentials['AccessKeySecret'], | ||
| 1180 | "security_token": credentials['SecurityToken'], | ||
| 1181 | "expiration": expiration_str, | ||
| 1182 | "region": region,  # 使用实际推断出的区域,OSS_REGION 未配置时也有值 | ||
| 1183 | "bucket": settings.OSS_BUCKET_NAME, | ||
| 1184 | "endpoint": settings.OSS_ENDPOINT, | ||
| 1185 | "cdn_domain": settings.OSS_CDN_DOMAIN, | ||
| 1186 | "upload_path_prefix": upload_path_prefix, | ||
| 1187 | } | ||
| 1188 | |||
| 1189 | except BusinessException: | ||
| 1190 | raise | ||
| 1191 | except Exception as e: | ||
| 1192 | logger.error(f"生成STS临时凭证失败: {str(e)}", exc_info=True) | ||
| 1193 | raise ExternalServiceException(f"生成STS临时凭证失败: {str(e)}") | ||
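从 endpoint 推断区域这一步可以单独验证。下面的示意同时覆盖 `cn-hangzhou` 这类两段区域、`ap-southeast-1` 这类多段区域以及 `-internal` 内网端点(正则写法是一种可行方案,并非唯一):

```python
import re

def region_from_endpoint(endpoint: str):
    """示意:从 oss-<region>[.internal].aliyuncs.com 中提取 <region>"""
    match = re.search(r'oss-([\w-]+?)(?:-internal)?\.aliyuncs\.com', endpoint)
    return match.group(1) if match else None

print(region_from_endpoint("oss-cn-hangzhou.aliyuncs.com"))     # cn-hangzhou
print(region_from_endpoint("oss-ap-southeast-1.aliyuncs.com"))  # ap-southeast-1
```

注意:形如 `r'oss-(\w+)-(\w+)\.aliyuncs\.com'` 的两段式正则会漏掉 `ap-southeast-1` 这类三段区域。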
| 1194 | |||
| 1195 | def generate_user_upload_policy(self, user_id: int) -> str: | ||
| 1196 | """ | ||
| 1197 | 为用户生成上传策略(限制只能上传到指定用户目录) | ||
| 1198 | |||
| 1199 | 参数: | ||
| 1200 | user_id: 用户ID | ||
| 1201 | |||
| 1202 | 返回: | ||
| 1203 | 策略JSON字符串 | ||
| 1204 | """ | ||
| 1205 | # 构建用户专属路径 | ||
| 1206 | upload_path_prefix = f"{settings.OSS_UPLOAD_PATH_PREFIX}/{user_id}" | ||
| 1207 | if settings.OSS_GLOBAL_PREFIX: | ||
| 1208 | upload_path_prefix = f"{settings.OSS_GLOBAL_PREFIX}/{upload_path_prefix}" | ||
| 1209 | |||
| 1210 | # 构建策略 | ||
| 1211 | policy = { | ||
| 1212 | "Version": "1", | ||
| 1213 | "Statement": [ | ||
| 1214 | { | ||
| 1215 | "Effect": "Allow", | ||
| 1216 | "Action": [ | ||
| 1217 | "oss:PutObject", | ||
| 1218 | "oss:InitiateMultipartUpload", | ||
| 1219 | "oss:UploadPart", | ||
| 1220 | "oss:CompleteMultipartUpload", | ||
| 1221 | "oss:AbortMultipartUpload" | ||
| 1222 | ], | ||
| 1223 | "Resource": [ | ||
| 1224 | f"acs:oss:*:*:{settings.OSS_BUCKET_NAME}/{upload_path_prefix}/*" | ||
| 1225 | ] | ||
| 1226 | }, | ||
| 1227 | { | ||
| 1228 | "Effect": "Allow", | ||
| 1229 | "Action": [ | ||
| 1230 | "oss:ListObjects" | ||
| 1231 | ], | ||
| 1232 | "Resource": [ | ||
| 1233 | f"acs:oss:*:*:{settings.OSS_BUCKET_NAME}", | ||
| 1234 | f"acs:oss:*:*:{settings.OSS_BUCKET_NAME}/{upload_path_prefix}" | ||
| 1235 | ], | ||
| 1236 | "Condition": { | ||
| 1237 | "StringLike": { | ||
| 1238 | "oss:prefix": [f"{upload_path_prefix}/*", f"{upload_path_prefix}"] | ||
| 1239 | } | ||
| 1240 | } | ||
| 1241 | } | ||
| 1242 | ] | ||
| 1243 | } | ||
| 1244 | |||
| 1245 | return json.dumps(policy) | ||
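`generate_user_upload_policy` 的要点是把 Resource ARN 限定在用户专属前缀下。前缀拼接逻辑的独立示意(`uploads`、`prod`、`my-bucket` 均为假设值,实际取自 settings):

```python
def user_upload_prefix(user_id: int, base_prefix: str = "uploads", global_prefix: str = "") -> str:
    """示意:与上文一致,先拼用户目录,再按需加全局前缀。"""
    prefix = f"{base_prefix}/{user_id}"
    if global_prefix:
        prefix = f"{global_prefix}/{prefix}"
    return prefix

# 生成的 Resource ARN 把写入权限限制在该用户目录下
print(f"acs:oss:*:*:my-bucket/{user_upload_prefix(42, global_prefix='prod')}/*")
# acs:oss:*:*:my-bucket/prod/uploads/42/*
```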
| 1246 | |||
| 1247 | |||
| 1248 | # 创建全局OSS服务实例(延迟初始化) | ||
| 1249 | _oss_service_instance = None | ||
| 1250 | |||
| 1251 | |||
| 1252 | def get_oss_service() -> OSSService: | ||
| 1253 | """获取OSS服务实例(单例模式)""" | ||
| 1254 | global _oss_service_instance | ||
| 1255 | if _oss_service_instance is None: | ||
| 1256 | _oss_service_instance = OSSService() | ||
| 1257 | return _oss_service_instance | ||
| 1258 | |||
| 1259 | |||
| 1260 | # 全局OSS服务实例 | ||
| 1261 | oss_service = get_oss_service() |
app/utils/oss_uploader.py
0 → 100644
| 1 | """ | ||
| 2 | 阿里云OSS文件上传模块 | ||
| 3 | """ | ||
| 4 | import os | ||
| 5 | import uuid | ||
| 6 | import logging | ||
| 7 | from datetime import datetime, timedelta | ||
| 8 | |||
| 9 | import oss2 | ||
| 10 | |||
| 11 | from app.core.config import settings | ||
| 12 | |||
| 13 | logger = logging.getLogger(__name__) | ||
| 14 | |||
| 15 | |||
| 16 | class OSSUploader: | ||
| 17 | """阿里云OSS上传器""" | ||
| 18 | |||
| 19 | def __init__(self): | ||
| 20 | """初始化OSS客户端""" | ||
| 21 | self.access_key_id = settings.OSS_ACCESS_KEY_ID | ||
| 22 | self.access_key_secret = settings.OSS_ACCESS_KEY_SECRET | ||
| 23 | self.endpoint = (settings.OSS_ENDPOINT or "").replace("https://", "").replace("http://", "")  # 去掉协议前缀,便于后续拼接URL | ||
| 24 | self.bucket_name = settings.OSS_BUCKET_NAME | ||
| 25 | |||
| 26 | if not all([ | ||
| 27 | self.access_key_id, | ||
| 28 | self.access_key_secret, | ||
| 29 | self.endpoint, | ||
| 30 | self.bucket_name, | ||
| 31 | ]): | ||
| 32 | raise ValueError("OSS配置不完整,请检查 .env 中的 OSS_ACCESS_KEY_ID/OSS_ACCESS_KEY_SECRET/OSS_ENDPOINT/OSS_BUCKET_NAME") | ||
| 33 | |||
| 34 | logger.info( | ||
| 35 | "OSS配置: endpoint=%s, bucket=%s", | ||
| 36 | self.endpoint, | ||
| 37 | self.bucket_name, | ||
| 38 | ) | ||
| 39 | # 创建认证对象 | ||
| 40 | self.auth = oss2.Auth(self.access_key_id, self.access_key_secret) | ||
| 41 | |||
| 42 | # 默认使用公网 endpoint;非阿里云内网环境下访问 internal endpoint 容易失败。 | ||
| 43 | self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name) | ||
| 44 | |||
| 45 | def upload_file(self, local_file_path, oss_object_name=None): | ||
| 46 | """ | ||
| 47 | 上传文件到OSS | ||
| 48 | |||
| 49 | Args: | ||
| 50 | local_file_path: 本地文件路径 | ||
| 51 | oss_object_name: OSS对象名称,如果不指定则使用时间戳+原文件名 | ||
| 52 | |||
| 53 | Returns: | ||
| 54 | tuple: (success: bool, url: str) 或 (success: bool, error: str) | ||
| 55 | """ | ||
| 56 | try: | ||
| 57 | if not os.path.exists(local_file_path): | ||
| 58 | logger.error(f"本地文件不存在: {local_file_path}") | ||
| 59 | return False, "本地文件不存在" | ||
| 60 | |||
| 61 | # 如果没有指定OSS对象名称,则生成一个随机文件名 | ||
| 62 | if not oss_object_name: | ||
| 63 | _, ext = os.path.splitext(local_file_path) | ||
| 64 | oss_object_name = f"{uuid.uuid4()}{ext}" | ||
| 65 | |||
| 66 | # 统一归档到 temp_ai/当天日期/ 目录下 | ||
| 67 | date = datetime.now().strftime("%Y%m%d") | ||
| 68 | oss_object_name = f"temp_ai/{date}/{oss_object_name}" | ||
| 68 | |||
| 69 | # 上传文件 | ||
| 70 | result = self.bucket.put_object_from_file(oss_object_name, local_file_path) | ||
| 71 | |||
| 72 | # 构建文件URL | ||
| 73 | file_url = f"https://{self.bucket_name}.{self.endpoint}/{oss_object_name}" | ||
| 74 | |||
| 75 | logger.info(f"文件上传成功: {local_file_path} -> {file_url}") | ||
| 76 | return True, file_url | ||
| 77 | |||
| 78 | except Exception as e: | ||
| 79 | logger.error(f"文件上传失败: {local_file_path}, 错误: {e}") | ||
| 80 | return False, str(e) | ||
| 81 | |||
| 82 | def upload_data(self, data, oss_object_name): | ||
| 83 | """ | ||
| 84 | 上传数据到OSS | ||
| 85 | |||
| 86 | Args: | ||
| 87 | data: 要上传的数据(字符串或字节) | ||
| 88 | oss_object_name: OSS对象名称 | ||
| 89 | |||
| 90 | Returns: | ||
| 91 | dict: 包含上传结果的字典 | ||
| 92 | """ | ||
| 93 | try: | ||
| 94 | # 上传数据 | ||
| 95 | result = self.bucket.put_object(oss_object_name, data) | ||
| 96 | |||
| 97 | # 构建文件URL(与 upload_file 一致,使用虚拟主机风格域名) | ||
| 98 | file_url = f"https://{self.bucket_name}.{self.endpoint}/{oss_object_name}" | ||
| 99 | |||
| 100 | return { | ||
| 101 | "success": True, | ||
| 102 | "oss_object_name": oss_object_name, | ||
| 103 | "file_url": file_url, | ||
| 104 | "etag": result.etag, | ||
| 105 | "size": len(data) if isinstance(data, (str, bytes)) else 0 | ||
| 106 | } | ||
| 107 | |||
| 108 | except Exception as e: | ||
| 109 | return {"success": False, "error": str(e)} | ||
| 110 | |||
| 111 | |||
| 112 | def get_bucket(): | ||
| 113 | """获取Bucket对象""" | ||
| 114 | if not all([ | ||
| 115 | settings.OSS_ACCESS_KEY_ID, | ||
| 116 | settings.OSS_ACCESS_KEY_SECRET, | ||
| 117 | settings.OSS_ENDPOINT, | ||
| 118 | settings.OSS_BUCKET_NAME, | ||
| 119 | ]): | ||
| 120 | raise ValueError("OSS配置不完整,请检查 .env 中的 OSS_ACCESS_KEY_ID/OSS_ACCESS_KEY_SECRET/OSS_ENDPOINT/OSS_BUCKET_NAME") | ||
| 121 | |||
| 122 | auth = oss2.Auth(settings.OSS_ACCESS_KEY_ID, settings.OSS_ACCESS_KEY_SECRET) | ||
| 123 | bucket = oss2.Bucket(auth, settings.OSS_ENDPOINT, settings.OSS_BUCKET_NAME) | ||
| 124 | return bucket | ||
| 125 | |||
| 126 | |||
| 127 | def clean_expire_file(): | ||
| 128 | """核心任务函数""" | ||
| 129 | print(f"\n[{datetime.now()}] 开始执行每日清理任务...") | ||
| 130 | ROOT_PREFIX = 'temp_ai/' | ||
| 131 | bucket = get_bucket() | ||
| 132 | |||
| 133 | # 1. 计算时间阈值 | ||
| 134 | now = datetime.now() | ||
| 135 | yesterday_date = (now - timedelta(days=1)).date() | ||
| 136 | print(f"保留阈值: {yesterday_date} (即 {yesterday_date} 之前的数据将被删除)") | ||
| 137 | |||
| 138 | # 2. 遍历目录 | ||
| 139 | try: | ||
| 140 | for obj in oss2.ObjectIterator(bucket, prefix=ROOT_PREFIX, delimiter='/'): | ||
| 141 | path = "" | ||
| 142 | is_directory = False | ||
| 143 | |||
| 144 | # --- [核心修改] 统一路径获取方式 --- | ||
| 145 | |||
| 146 | # 情况 A: 它是虚拟目录 (CommonPrefix) | ||
| 147 | if hasattr(obj, 'prefix'): | ||
| 148 | path = obj.prefix | ||
| 149 | is_directory = True | ||
| 150 | |||
| 151 | # 情况 B: 它是实际对象 (SimplifiedObjectInfo) | ||
| 152 | elif hasattr(obj, 'key'): | ||
| 153 | path = obj.key | ||
| 154 | # 如果 key 以 / 结尾,说明它是一个显式创建的文件夹对象 | ||
| 155 | if path.endswith('/'): | ||
| 156 | is_directory = True | ||
| 157 | else: | ||
| 158 | is_directory = False # 这是一个普通文件 | ||
| 159 | |||
| 160 | # --- 逻辑分流 --- | ||
| 161 | |||
| 162 | if not is_directory: | ||
| 163 | # 这是一个真正的文件(且不是文件夹对象),直接跳过 | ||
| 164 | # print(f"[跳过] 散落文件: {path}") | ||
| 165 | continue | ||
| 166 | |||
| 167 | # 此时 path 必定是目录格式 (如 'temp_ai/20251229/') | ||
| 168 | # 下面开始正常的日期判断逻辑 | ||
| 169 | |||
| 170 | # 防御性判断:跳过根前缀 'temp_ai/' 本身 | ||
| 171 | if path == ROOT_PREFIX: | ||
| 172 | continue | ||
| 173 | |||
| 174 | # 解析目录名(strip('/') 后取最后一个路径段,如 'temp_ai/20251229/' -> '20251229') | ||
| 175 | folder_name_raw = path.strip('/').split('/')[-1] | ||
| 176 | |||
| 177 | try: | ||
| 178 | folder_date_obj = datetime.strptime(folder_name_raw, "%Y%m%d").date() | ||
| 179 | |||
| 180 | if folder_date_obj < yesterday_date: | ||
| 181 | print(f"[删除] 发现过期目录: {path}") | ||
| 182 | # 注意:delete_objects_by_prefix 会删除该前缀下的所有文件 | ||
| 183 | # 如果这个目录本身是个对象,也会被一并删除,无需特殊处理 | ||
| 184 | delete_objects_by_prefix(bucket, path) | ||
| 185 | else: | ||
| 186 | # print(f"[跳过] 目录较新: {path}") | ||
| 187 | pass | ||
| 188 | |||
| 189 | except ValueError: | ||
| 190 | print(f"[跳过] 非日期命名目录: {path}") | ||
| 191 | |||
| 192 | except Exception as e: | ||
| 193 | import traceback | ||
| 194 | print(f"[严重错误] 任务执行失败: {e}") | ||
| 195 | traceback.print_exc() | ||
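`clean_expire_file` 的保留规则是:目录名形如 `YYYYMMDD`,早于“昨天”的删除,昨天及以后的保留,非日期命名跳过。这段判断可以抽成纯函数验证:

```python
from datetime import date, datetime, timedelta

def should_delete(folder_name: str, today: date) -> bool:
    """示意:与上文相同的阈值判断 folder_date < today - 1 天。"""
    try:
        folder_date = datetime.strptime(folder_name, "%Y%m%d").date()
    except ValueError:
        return False  # 非日期命名目录,跳过
    return folder_date < today - timedelta(days=1)

print(should_delete("20250101", date(2025, 1, 3)))  # True:早于 1 月 2 日
print(should_delete("20250102", date(2025, 1, 3)))  # False:恰为昨天,保留
```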
| 196 | |||
| 197 | |||
| 198 | def delete_objects_by_prefix(bucket, prefix): | ||
| 199 | """递归删除指定前缀下的所有文件""" | ||
| 200 | print(f" -> 正在清理目录: {prefix} ...") | ||
| 201 | batch_list = [] | ||
| 202 | try: | ||
| 203 | for obj in oss2.ObjectIterator(bucket, prefix=prefix): | ||
| 204 | batch_list.append(obj.key) | ||
| 205 | if len(batch_list) >= 1000: | ||
| 206 | bucket.batch_delete_objects(batch_list) | ||
| 207 | batch_list = [] | ||
| 208 | |||
| 209 | if batch_list: | ||
| 210 | bucket.batch_delete_objects(batch_list) | ||
| 211 | print(f" -> directory {prefix} cleaned.") | ||
| 212 | except Exception as e: | ||
| 213 | print(f" [error] deletion failed: {e}") | ||
| 214 | |||
| 215 | |||
| 216 | # Create a module-level OSSUploader instance | ||
| 217 | oss_uploader = OSSUploader() | ||
| 218 | |||
| 219 | if __name__ == '__main__': | ||
| 220 | resp = oss_uploader.upload_file('想-dj-片段.mp3') | ||
| 221 | print(resp) |
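The cleanup loop above boils down to a pure decision over each path string: directory-form paths with a `%Y%m%d` name older than the cutoff get deleted, everything else is skipped. A minimal standalone sketch of that decision (the function name `is_expired_dir` is illustrative and not part of the script; no OSS access is needed):

```python
from datetime import date, datetime

ROOT_PREFIX = "temp_ai/"

def is_expired_dir(path: str, cutoff: date) -> bool:
    """Return True when `path` is a date-named directory older than `cutoff`."""
    if not path.endswith("/") or path == ROOT_PREFIX:
        return False  # a regular file, or the root prefix itself
    folder_name = path.strip("/").split("/")[-1]
    try:
        folder_date = datetime.strptime(folder_name, "%Y%m%d").date()
    except ValueError:
        return False  # not a date-named directory
    return folder_date < cutoff

cutoff = date(2025, 12, 30)
print(is_expired_dir("temp_ai/20251229/", cutoff))  # True: older than cutoff
print(is_expired_dir("temp_ai/20251230/", cutoff))  # False: not expired yet
print(is_expired_dir("temp_ai/tmp/", cutoff))       # False: non-date name
print(is_expired_dir("temp_ai/a.mp3", cutoff))      # False: a stray file
```

Keeping the check side-effect-free like this makes it easy to unit-test the date logic without touching a bucket.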
config.py
0 → 100644
| 1 | |||
| 2 | from dashscope.common.constants import DASHSCOPE_API_KEY_ENV | ||
| 3 | |||
| 4 | |||
| 5 | ENV = 'test' | ||
| 6 | # ENV = 'local' | ||
| 7 | |||
| 8 | |||
| 9 | DEBUG = True | ||
| 10 | ### Database | ||
| 11 | # dev | ||
| 12 | DB_USER = 'root' | ||
| 13 | DB_PASSWORD = 'Hikoon123!' | ||
| 14 | DB_HOST = 'rm-bp18h64ad9ak4d7h5do.mysql.rds.aliyuncs.com' | ||
| 15 | DB_DATABASE = 'music_partner' | ||
| 16 | |||
| 17 | # Redis | ||
| 18 | REDIS_HOST = '172.23.209.46' | ||
| 19 | REDIS_PORT = 6379 | ||
| 20 | REDIS_PSW = '1bvvpAmKXFhDDJXb' | ||
| 21 | REDIS_DB = 0 | ||
| 22 | # NewRank (新抖) API key | ||
| 23 | NEW_RANK_KEY = 'vh1gbvynpyegg6gebhgepgvc6' | ||
| 24 | |||
| 25 | BACK_BASE_URL = 'https://ai-test.hikoon.com/api/partner' | ||
| 26 | |||
| 27 | EMAIL_HOST = 'smtp.exmail.qq.com' | ||
| 28 | EMAIL_PORT = 465 | ||
| 29 | EMAIL_HOST_USER = 'bigmusic@hikoon.com' | ||
| 30 | EMAIL_HOST_PASSWORD = 'Music!123' | ||
| 31 | # Email recipient list | ||
| 32 | EMAIL_RECEIVERS = ['1774507011@qq.com','yangsheng@hikoon.com'] | ||
| 33 | |||
| 34 | |||
| 35 | # Tag dictionary: English key -> Chinese display label | ||
| 36 | TAG_DICT = { | ||
| 37 | "viral_song": "网络热歌", | ||
| 38 | "sad_songs": "伤感老歌", | ||
| 39 | "folk_songs": "民谣", | ||
| 40 | "catchy_pop": "口水歌", | ||
| 41 | "kids_songs": "洗脑儿歌", | ||
| 42 | "tk_songs": "抖音热歌", | ||
| 43 | "net_songs": "网络歌曲", | ||
| 44 | |||
| 45 | "dj_remix": "DJ嗨曲", | ||
| 46 | "Cheesy_EDM": "土嗨/慢摇", | ||
| 47 | "car_music": "车载音乐", | ||
| 48 | "shout_rap": "喊麦", | ||
| 49 | "heavy_metal": "重金属/土摇DJ嗨曲", | ||
| 50 | |||
| 51 | "mandarin_pop": "华语流行", | ||
| 52 | "mainstream_pop": "主流Pop", | ||
| 53 | "sweet_songs": "甜歌/校园", | ||
| 54 | "hip_rock": "嘻哈说唱R&B摇滚", | ||
| 55 | "child_songs": "主流儿歌", | ||
| 56 | |||
| 57 | "international_pop": "国外流行", | ||
| 58 | "jp_pop": "日韩流行", | ||
| 59 | "west_pop": "欧美流行", | ||
| 60 | "el_edm": "电音EDM", | ||
| 61 | |||
| 62 | "chinese_style": "国风", | ||
| 63 | "opera_vocal": "戏腔/古韵", | ||
| 64 | "guochao_EDM": "国潮电子", | ||
| 65 | "gufeng_music": "传统器乐古风", | ||
| 66 | |||
| 67 | "soundtrack_instrumental": "影视/纯音", | ||
| 68 | "ys_ost": "影视OST", | ||
| 69 | "pur_music": "纯音乐", | ||
| 70 | "no_lyric": "无词BGM", | ||
| 71 | |||
| 72 | "other_music": "其他", | ||
| 73 | "jazz_blue": "爵士/蓝调", | ||
| 74 | "voice_book": "有声书", | ||
| 75 | "lab_music": "实验音乐", | ||
| 76 | "healing": "治愈", | ||
| 77 | "melancholy": "伤感", | ||
| 78 | "lonely": "孤独", | ||
| 79 | "sweet": "甜蜜", | ||
| 80 | "inspiring": "励志", | ||
| 81 | "missing": "思念", | ||
| 82 | "nostalgic": "怀旧", | ||
| 83 | "angry": "愤怒", | ||
| 84 | "relaxing": "放松", | ||
| 85 | "catchy": "魔性洗脑", | ||
| 86 | "heroic": "悲壮", | ||
| 87 | "calm": "平静", | ||
| 88 | "festive": "喜庆", | ||
| 89 | "romantic": "浪漫", | ||
| 90 | "majestic": "雄壮", | ||
| 91 | "bewitching": "蛊惑", | ||
| 92 | "cathartic": "宣泄", | ||
| 93 | "solemn": "庄重", | ||
| 94 | "passionate": "激情", | ||
| 95 | "heavy": "沉重", | ||
| 96 | "happy": "快乐", | ||
| 97 | "tense": "紧张", | ||
| 98 | "horror": "恐怖", | ||
| 99 | "touching": "感动", | ||
| 100 | "spoof": "恶搞", | ||
| 101 | "funny": "搞笑", | ||
| 102 | "expectation": "期待", | ||
| 103 | "remembrance": "怀念", | ||
| 104 | "mysterious": "悬疑", | ||
| 105 | "blessing": "祝福", | ||
| 106 | "zen": "佛系", | ||
| 107 | "soothing": "舒缓", | ||
| 108 | "melodious": "悠扬", | ||
| 109 | "warm": "温暖", | ||
| 110 | "depressed": "忧郁", | ||
| 111 | "elderly": "老年", | ||
| 112 | "middle_aged": "中年", | ||
| 113 | "young_adult": "青年", | ||
| 114 | "teenager": "少年", | ||
| 115 | "life_scene": "生活场景", | ||
| 116 | "sports": "运动", | ||
| 117 | "driving": "开车", | ||
| 118 | "travel": "旅行", | ||
| 119 | "sleep": "睡前", | ||
| 120 | "study": "学习", | ||
| 121 | "cafe": "咖啡厅", | ||
| 122 | "bar": "酒吧", | ||
| 123 | "douyin":"抖音", | ||
| 124 | "restaurant": "餐厅", | ||
| 125 | "car_scene": "汽车", | ||
| 126 | "dance": "跳舞", | ||
| 127 | "work": "工作", | ||
| 128 | "nightclub": "夜店", | ||
| 129 | "leisure": "休闲", | ||
| 130 | "live_house": "live house", | ||
| 131 | "square_dance": "广场舞", | ||
| 132 | "wedding": "婚礼", | ||
| 133 | "dating": "约会", | ||
| 134 | "festival_scene": "节日场景", | ||
| 135 | "summer": "夏天", | ||
| 136 | "winter": "冬天", | ||
| 137 | "autumn": "秋天", | ||
| 138 | "spring_festival": "春节", | ||
| 139 | "christmas": "圣诞", | ||
| 140 | "valentine": "情人节", | ||
| 141 | "time_scene": "时间场景", | ||
| 142 | "morning": "清晨", | ||
| 143 | "afternoon": "午后", | ||
| 144 | "evening": "夜晚", | ||
| 145 | "midnight": "深夜", | ||
| 146 | "regional_scene": "地域场景", | ||
| 147 | "campus": "校园", | ||
| 148 | "city": "城市", | ||
| 149 | "grassland": "草原", | ||
| 150 | "tibet": "西藏", | ||
| 151 | "xinjiang": "新疆", | ||
| 152 | "transition_style": "转场类", | ||
| 153 | "card_point_switch": "卡点切换画面类", | ||
| 154 | "reverse_suspense": "反转悬念类", | ||
| 155 | "emotion_contrast": "情绪对比类", | ||
| 156 | "mashup_collection": "混剪合集类", | ||
| 157 | "emotional_resonance": "情感共鸣向剪辑", | ||
| 158 | "scene_adaptation": "场景适配剪辑", | ||
| 159 | "highlight_slice": "高光切片剪辑", | ||
| 160 | "live_performance": "现场表演类", | ||
| 161 | "singer_live": "歌手现场演唱", | ||
| 162 | "talent_cover": "达人翻唱表演", | ||
| 163 | "audience_interaction": "观众互动表演", | ||
| 164 | "card_point_speed": "卡点、变速类", | ||
| 165 | "multi_scene_fragment": "多场景碎片化卡点", | ||
| 166 | "tech_effect_speed": "技术流特效变速", | ||
| 167 | "lyric_concrete": "歌词具象化卡点", | ||
| 168 | "loop_speed_brainwash": "循环变速洗脑", | ||
| 169 | "ugc_co_creation": "UGC共创类", | ||
| 170 | "jianying_template": "剪映模板", | ||
| 171 | "ai_singing": "AI唱歌", | ||
| 172 | "emotional_quotes": "情感语录类", | ||
| 173 | "late_night_emo": "深夜emo类", | ||
| 174 | "morning_inspiration": "清晨励志类", | ||
| 175 | "memory_destiny": "回忆杀/宿命感类", | ||
| 176 | "dynamic_lyrics_visual": "动态歌词可视化", | ||
| 177 | "basic_lyrics_effect": "基础歌词动效", | ||
| 178 | "creative_visual_enhance": "创意视觉强化", | ||
| 179 | "adaptation": "改编", | ||
| 180 | "special_effects_interaction": "特效互动类", | ||
| 181 | "gesture_magic_effect": "手势魔法特效互动", | ||
| 182 | "lip_sync_challenge": "对口型挑战", | ||
| 183 | "douyin_effect_show": "抖音特效变装秀", | ||
| 184 | # Listening & performance stream | ||
| 185 | "singing_montage": "演唱混剪", | ||
| 186 | "live_singing": "现场演唱", | ||
| 187 | |||
| 188 | # Visual impact stream | ||
| 189 | "change_transition": "变装转场", | ||
| 190 | "hand_dance": "手势舞", | ||
| 191 | "addictive_dance": "魔性舞蹈", | ||
| 192 | "landscape_account": "风景号", | ||
| 193 | |||
| 194 | # Ambience footage stream | ||
| 195 | "cute_pets": "萌宠", | ||
| 196 | "movie_anime_edit": "影视剧/动漫混剪", | ||
| 197 | "chinese_classical": "古风", | ||
| 198 | "mood_post": "图文心情", | ||
| 199 | |||
| 200 | # Emotional resonance stream | ||
| 201 | "animated_lyrics": "动态歌词", | ||
| 202 | "storytelling": "故事演绎", | ||
| 203 | "beauty_snaps": "颜值随拍" | ||
| 204 | } | ||
| 205 | |||
| 206 | |||
| 207 | # Model configuration | ||
| 208 | BASE_MODEL = "/data/qufeng/models--MIT--ast-finetuned-audioset-10-10-0.4593/snapshots/f826b80d28226b62986cc218e5cec390b1096902" | ||
| 209 | MOE_DIR = "/data/qufeng/moe_outputs" | ||
| 210 | BASELINE_CHECKPOINT = "/data/qufeng/best_epoch_base.pt" | ||
| 211 | LABEL_MAPPING = "/data/qufeng/label_mapping.txt" | ||
| 212 | DEVICE = "cuda" # options: cuda/mps/cpu; auto-selected when empty | ||
| 213 | ROUTER_CHECKPOINT = "" # inferred from moe_dir/joint_train/joint/router_best.pt when empty | ||
| 214 | EXPERTS_DIR = "" # inferred from moe_dir/experts_train/experts when empty | ||
| 215 | |||
| 216 | # Audio processing configuration | ||
| 217 | CHUNK_SECONDS = 10.24 # seconds per inference chunk | ||
| 218 | CROP_SECONDS = 204.8 # if the audio is longer, only the middle section of this length is chunked | ||
| 219 | MAX_CHUNKS = 10 # maximum number of chunks per song used for inference | ||
| 220 | CHUNK_BATCH_SIZE = 8 # batch size for chunk inference | ||
| 221 | ROUTING_THRESHOLD = 0.6 | ||
| 222 | |||
| 223 | API_CONFIG = { | ||
| 224 | "api_key": "sk-d9b4d3581bde47d887354f9160a509a2", | ||
| 225 | "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", | ||
| 226 | "model": "qwen3-omni-flash", | ||
| 227 | "audio_mode": "auto", | ||
| 228 | "timeout": 15, | ||
| 229 | "lyrics_timeout": 60, | ||
| 230 | "lyrics_retries": 2, | ||
| 231 | "max_retries": 5, | ||
| 232 | "retry_delay": 5 | ||
| 233 | } | ||
| 234 | # API_CONFIG_91 = { | ||
| 235 | # "api_key": "sk-E90VNVMyhfk2zDBDoToCXoipzGofD2SobwBqaCzbG3junlob", | ||
| 236 | # "base_url": "https://api.91aopusi.com/v1", | ||
| 237 | # "model": "qwen3-omni-flash", | ||
| 238 | # "audio_mode": "auto", | ||
| 239 | # "timeout": 30, | ||
| 240 | # "lyrics_timeout": 60, | ||
| 241 | # "max_retries": 5, | ||
| 242 | # "retry_delay": 5 | ||
| 243 | # } | ||
| 244 | |||
| 245 | DASHSCOPE_API_KEY = 'sk-d9b4d3581bde47d887354f9160a509a2' | ||
| 246 | |||
| 247 | OSS_ACCESS_KEY_ID='LTAI4G7UvaW2e4UTCb3KCNjN' | ||
| 248 | OSS_ACCESS_KEY_SECRET='ow5hlVMmJAQY9o7nEAtMER6MFkPedm' | ||
| 249 | OSS_ENDPOINT='oss-cn-hangzhou.aliyuncs.com' | ||
| 250 | OSS_ENDPOINT_INTERNAL='oss-cn-hangzhou-internal.aliyuncs.com' | ||
| 251 | OSS_BUCKET_NAME='ai-sound-data-test' | ||
| ... | \ No newline at end of file
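The audio knobs above imply a crop-then-chunk plan: long tracks keep only the middle `CROP_SECONDS` window, which is then split into at most `MAX_CHUNKS` pieces of `CHUNK_SECONDS` each. A small sketch with that semantics inferred from the config comments (the pipeline code that actually consumes these values is not shown in this commit, so treat this as an assumption):

```python
CHUNK_SECONDS = 10.24   # seconds per inference chunk
CROP_SECONDS = 204.8    # crop the middle section when audio is longer
MAX_CHUNKS = 10         # cap on chunks per song

def plan_chunks(duration_s: float) -> tuple[float, float, int]:
    """Return (crop_start, crop_end, n_chunks) for a track of `duration_s` seconds."""
    if duration_s > CROP_SECONDS:
        # Long track: keep only the middle CROP_SECONDS window
        start = (duration_s - CROP_SECONDS) / 2
        end = start + CROP_SECONDS
    else:
        start, end = 0.0, duration_s
    # Count full chunks that fit in the window, capped at MAX_CHUNKS, at least 1
    n_chunks = min(MAX_CHUNKS, max(1, int((end - start) // CHUNK_SECONDS)))
    return start, end, n_chunks

start, end, n = plan_chunks(300.0)
print(f"crop [{start:.1f}s, {end:.1f}s] -> {n} chunks")
```

For a 5-minute track this keeps roughly 47.6 s through 252.4 s and hits the 10-chunk cap; a 30-second track is used whole and yields 2 full chunks.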
logger.py
0 → 100644
| 1 | import logging.handlers | ||
| 2 | import os | ||
| 3 | from config import DEBUG | ||
| 4 | |||
| 5 | log_dir = "./logs" | ||
| 6 | log_max_bytes = 1024 * 1024 * 10 | ||
| 7 | log_backup_count = 5 | ||
| 8 | |||
| 9 | |||
| 10 | def get_logger(name, level=None): | ||
| 11 | if not level: | ||
| 12 | level = logging.DEBUG if DEBUG else logging.INFO | ||
| 13 | |||
| 14 | # Configure the logger | ||
| 15 | logger = logging.getLogger(name) | ||
| 16 | logger.setLevel(level) | ||
| 17 | # Create the log directory if it does not exist | ||
| 18 | if not os.path.exists(log_dir): | ||
| 19 | os.makedirs(log_dir) | ||
| 20 | |||
| 21 | # Create a handler that writes log records to a rotating file | ||
| 22 | file_handler = logging.handlers.RotatingFileHandler(f'./{log_dir}/{name}.log', maxBytes=log_max_bytes, | ||
| 23 | backupCount=log_backup_count,encoding='utf-8') | ||
| 24 | file_handler.setLevel(level) | ||
| 25 | |||
| 26 | # Set the handler's output format | ||
| 27 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') | ||
| 28 | file_handler.setFormatter(formatter) | ||
| 29 | |||
| 30 | # Attach the handler only once: repeated get_logger(name) calls would | ||
| 30 | # otherwise stack duplicate handlers and emit each record multiple times | ||
| 30 | if not logger.handlers: | ||
| 31 | logger.addHandler(file_handler) | ||
| 32 | return logger | ||
| 33 | |||
| 34 | |||
| 35 | # Module-level cache holding the application logger instance | ||
| 36 | _app_logger = None | ||
| 37 | |||
| 38 | |||
| 39 | def get_app_logger(): | ||
| 40 | global _app_logger | ||
| 41 | if _app_logger is None: | ||
| 42 | _app_logger = get_logger("app") | ||
| 43 | return _app_logger |
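A quick usage check of the rotating-file setup above. This standalone sketch inlines the same handler configuration (same format string, `maxBytes`, and `backupCount`) but writes to a temp directory instead of `./logs`, so it runs without the repo's `config` module:

```python
import logging
import logging.handlers
import os
import tempfile

log_dir = tempfile.mkdtemp()  # stand-in for ./logs
logger = logging.getLogger("app")
logger.setLevel(logging.DEBUG)
handler = logging.handlers.RotatingFileHandler(
    os.path.join(log_dir, "app.log"),
    maxBytes=10 * 1024 * 1024,  # rotate after ~10 MiB
    backupCount=5,              # keep at most 5 rotated files
    encoding="utf-8",
)
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(handler)

logger.info("pipeline started")
with open(os.path.join(log_dir, "app.log"), encoding="utf-8") as f:
    print("pipeline started" in f.read())  # True
```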
pipeline/batch_analyze_xlsx.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | """Batch analyze audio URLs from an xlsx file and export results to xlsx.""" | ||
| 3 | |||
| 4 | from __future__ import annotations | ||
| 5 | |||
| 6 | import argparse | ||
| 7 | import json | ||
| 8 | import math | ||
| 9 | import os | ||
| 10 | import sys | ||
| 11 | import traceback | ||
| 12 | from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait | ||
| 13 | from pathlib import Path | ||
| 14 | from typing import Any | ||
| 15 | |||
| 16 | import pandas as pd | ||
| 17 | |||
| 18 | # Allow running directly via `python pipeline/batch_analyze_xlsx.py` | ||
| 19 | PROJECT_ROOT = Path(__file__).resolve().parent.parent | ||
| 20 | if str(PROJECT_ROOT) not in sys.path: | ||
| 21 | sys.path.insert(0, str(PROJECT_ROOT)) | ||
| 22 | |||
| 23 | from app.middleware.music_analyze import analyze_music | ||
| 24 | |||
| 25 | |||
| 26 | DEFAULT_OUTPUT_COLUMNS = [ | ||
| 27 | "tmeid", | ||
| 28 | "歌曲ID", | ||
| 29 | "歌曲名", | ||
| 30 | "表演者", | ||
| 31 | "歌曲时长", | ||
| 32 | "表演者类型", | ||
| 33 | "语种", | ||
| 34 | "BPM速度", | ||
| 35 | "情绪", | ||
| 36 | "网络/抖音歌曲", | ||
| 37 | "音乐风格", | ||
| 38 | "配器", | ||
| 39 | "场景", | ||
| 40 | ] | ||
| 41 | |||
| 42 | ANALYZE_COLUMNS = [ | ||
| 43 | "表演者类型", | ||
| 44 | "语种", | ||
| 45 | "BPM速度", | ||
| 46 | "情绪", | ||
| 47 | "网络/抖音歌曲", | ||
| 48 | "音乐风格", | ||
| 49 | "配器", | ||
| 50 | "场景", | ||
| 51 | ] | ||
| 52 | |||
| 53 | |||
| 54 | def _is_blank(value: Any) -> bool: | ||
| 55 | if value is None: | ||
| 56 | return True | ||
| 57 | if isinstance(value, float) and math.isnan(value): | ||
| 58 | return True | ||
| 59 | return str(value).strip() == "" | ||
| 60 | |||
| 61 | |||
| 62 | def _join_multi_value(value: Any) -> str: | ||
| 63 | if value is None: | ||
| 64 | return "" | ||
| 65 | if isinstance(value, str): | ||
| 66 | return value.strip() | ||
| 67 | if isinstance(value, list): | ||
| 68 | parts = [str(v).strip() for v in value if str(v).strip()] | ||
| 69 | return "、".join(parts) | ||
| 70 | return str(value).strip() | ||
| 71 | |||
| 72 | |||
| 73 | def _pick_first_non_blank(row: pd.Series, candidates: list[str]) -> str: | ||
| 74 | for col in candidates: | ||
| 75 | if col in row.index and not _is_blank(row[col]): | ||
| 76 | value = row[col] | ||
| 77 | if isinstance(value, float) and value.is_integer(): | ||
| 78 | return str(int(value)) | ||
| 79 | return str(value).strip() | ||
| 80 | return "" | ||
| 81 | |||
| 82 | |||
| 83 | def _normalize_key_value(value: Any) -> str: | ||
| 84 | if _is_blank(value): | ||
| 85 | return "" | ||
| 86 | if isinstance(value, float) and value.is_integer(): | ||
| 87 | return str(int(value)) | ||
| 88 | return str(value).strip() | ||
| 89 | |||
| 90 | |||
| 91 | def _resolve_url_column(df: pd.DataFrame, requested_column: str) -> str: | ||
| 92 | if requested_column in df.columns: | ||
| 93 | return requested_column | ||
| 94 | |||
| 95 | candidates = ["URL", "url", "cos访问地址", "cos_url", "audio_url"] | ||
| 96 | for col in candidates: | ||
| 97 | if col in df.columns: | ||
| 98 | print( | ||
| 99 | f"[run] url column `{requested_column}` not found, fallback to `{col}`" | ||
| 100 | ) | ||
| 101 | return col | ||
| 102 | |||
| 103 | raise ValueError( | ||
| 104 | f"column `{requested_column}` not found, available={list(df.columns)}" | ||
| 105 | ) | ||
| 106 | |||
| 107 | |||
| 108 | def _is_row_completed(out_df: pd.DataFrame, idx: int) -> bool: | ||
| 109 | for col in ANALYZE_COLUMNS: | ||
| 110 | if col not in out_df.columns: | ||
| 111 | continue | ||
| 112 | value = out_df.at[idx, col] | ||
| 113 | if not _is_blank(value): | ||
| 114 | return True | ||
| 115 | return False | ||
| 116 | |||
| 117 | |||
| 118 | def _resolve_checkpoint_path(output_path: Path, checkpoint_path: Path | None) -> Path: | ||
| 119 | if checkpoint_path is not None: | ||
| 120 | return checkpoint_path | ||
| 121 | return output_path.with_suffix(output_path.suffix + ".checkpoint.json") | ||
| 122 | |||
| 123 | |||
| 124 | def _save_progress( | ||
| 125 | out_df: pd.DataFrame, | ||
| 126 | output_path: Path, | ||
| 127 | checkpoint_path: Path, | ||
| 128 | completed_indices: set[int], | ||
| 129 | ) -> None: | ||
| 130 | output_path.parent.mkdir(parents=True, exist_ok=True) | ||
| 131 | |||
| 132 | tmp_output = output_path.with_suffix(output_path.suffix + ".tmp") | ||
| 133 | out_df = out_df[DEFAULT_OUTPUT_COLUMNS] | ||
| 134 | out_df.to_excel(tmp_output, index=False) | ||
| 135 | tmp_output.replace(output_path) | ||
| 136 | |||
| 137 | payload = { | ||
| 138 | "completed_indices": sorted(completed_indices), | ||
| 139 | "completed_count": len(completed_indices), | ||
| 140 | "total": int(len(out_df)), | ||
| 141 | } | ||
| 142 | tmp_checkpoint = checkpoint_path.with_suffix(checkpoint_path.suffix + ".tmp") | ||
| 143 | tmp_checkpoint.write_text( | ||
| 144 | json.dumps(payload, ensure_ascii=False, indent=2), | ||
| 145 | encoding="utf-8", | ||
| 146 | ) | ||
| 147 | tmp_checkpoint.replace(checkpoint_path) | ||
| 148 | |||
| 149 | |||
| 150 | def _load_checkpoint(checkpoint_path: Path) -> set[int]: | ||
| 151 | if not checkpoint_path.exists(): | ||
| 152 | return set() | ||
| 153 | try: | ||
| 154 | payload = json.loads(checkpoint_path.read_text(encoding="utf-8")) | ||
| 155 | values = payload.get("completed_indices", []) | ||
| 156 | return {int(v) for v in values if isinstance(v, int) or str(v).isdigit()} | ||
| 157 | except Exception: | ||
| 158 | return set() | ||
| 159 | |||
| 160 | |||
| 161 | def _filter_checkpoint_indices( | ||
| 162 | checkpoint_indices: set[int], | ||
| 163 | out_df: pd.DataFrame, | ||
| 164 | df: pd.DataFrame, | ||
| 165 | url_column: str, | ||
| 166 | ) -> set[int]: | ||
| 167 | """ | ||
| 168 | Filter indices from the checkpoint: | ||
| 169 | - keep rows that already have analysis results (avoid re-analysis) | ||
| 170 | - keep rows whose URL is still blank (keep skipping them) | ||
| 171 | - drop rows whose URL is now filled but that have no results yet (so they can be analyzed later) | ||
| 172 | """ | ||
| 173 | filtered: set[int] = set() | ||
| 174 | for idx in checkpoint_indices: | ||
| 175 | if idx < 0 or idx >= len(out_df): | ||
| 176 | continue | ||
| 177 | if _is_row_completed(out_df, idx): | ||
| 178 | filtered.add(idx) | ||
| 179 | continue | ||
| 180 | url = df.at[idx, url_column] if url_column in df.columns else None | ||
| 181 | if _is_blank(url): | ||
| 182 | filtered.add(idx) | ||
| 183 | return filtered | ||
| 184 | |||
| 185 | |||
| 186 | def _build_metadata(row: pd.Series, metadata_columns: list[str]) -> dict[str, Any]: | ||
| 187 | metadata: dict[str, Any] = {} | ||
| 188 | # Pass key fields through automatically so a missed column never breaks downstream mapping | ||
| 189 | for col in ["歌曲ID", "song_id", "id"]: | ||
| 190 | if col in row.index and not _is_blank(row[col]): | ||
| 191 | metadata[col] = row[col] | ||
| 192 | break | ||
| 193 | for col in ["tmeid", "tmeID", "TMEID"]: | ||
| 194 | if col in row.index and not _is_blank(row[col]): | ||
| 195 | metadata["tmeid"] = row[col] | ||
| 196 | break | ||
| 197 | for col in metadata_columns: | ||
| 198 | if col in row.index and not _is_blank(row[col]): | ||
| 199 | metadata[col] = row[col] | ||
| 200 | return metadata | ||
| 201 | |||
| 202 | |||
| 203 | def _normalize_result(result: dict[str, Any]) -> dict[str, Any]: | ||
| 204 | return { | ||
| 205 | "表演者类型": ( | ||
| 206 | str(result.get("performer_type") or result.get("vocal_texture") or "").strip() | ||
| 207 | ), | ||
| 208 | "语种": str(result.get("language") or "").strip(), | ||
| 209 | "BPM速度": result.get("bpm"), | ||
| 210 | "情绪": _join_multi_value(result.get("emotion", [])), | ||
| 211 | "网络/抖音歌曲": _join_multi_value(result.get("douyin_tags", [])), | ||
| 212 | "音乐风格": _join_multi_value( | ||
| 213 | result.get("music_style_tags", []) | ||
| 214 | or [v for v in [result.get("genre"), result.get("sub_genre")] if v] | ||
| 215 | ), | ||
| 216 | "配器": _join_multi_value(result.get("instrument_tags", [])), | ||
| 217 | "场景": _join_multi_value(result.get("scene", [])), | ||
| 218 | } | ||
| 219 | |||
| 220 | |||
| 221 | def _build_song_tmeid_maps(df: pd.DataFrame) -> tuple[dict[str, int], dict[str, int]]: | ||
| 222 | song_id_map: dict[str, int] = {} | ||
| 223 | tmeid_map: dict[str, int] = {} | ||
| 224 | for idx, row in df.iterrows(): | ||
| 225 | song_id = _pick_first_non_blank(row, ["歌曲ID", "song_id", "id"]) | ||
| 226 | tmeid = _pick_first_non_blank(row, ["tmeid", "tmeID", "TMEID"]) | ||
| 227 | if song_id and song_id not in song_id_map: | ||
| 228 | song_id_map[song_id] = int(idx) | ||
| 229 | if tmeid and tmeid not in tmeid_map: | ||
| 230 | tmeid_map[tmeid] = int(idx) | ||
| 231 | return song_id_map, tmeid_map | ||
| 232 | |||
| 233 | |||
| 234 | def _resume_from_existing_by_keys(out_df: pd.DataFrame, existing: pd.DataFrame) -> set[int]: | ||
| 235 | """Reuse previous results matched by 歌曲ID/tmeid when the input row count changes.""" | ||
| 236 | completed_indices: set[int] = set() | ||
| 237 | if existing.empty: | ||
| 238 | return completed_indices | ||
| 239 | |||
| 240 | old_song_map, old_tmeid_map = _build_song_tmeid_maps(existing) | ||
| 241 | |||
| 242 | reused = 0 | ||
| 243 | reused_by_song = 0 | ||
| 244 | reused_by_tmeid = 0 | ||
| 245 | for idx in out_df.index: | ||
| 246 | song_id = _normalize_key_value(out_df.at[idx, "歌曲ID"]) | ||
| 247 | tmeid = _normalize_key_value(out_df.at[idx, "tmeid"]) | ||
| 248 | |||
| 249 | old_idx = None | ||
| 250 | if song_id and song_id in old_song_map: | ||
| 251 | old_idx = old_song_map[song_id] | ||
| 252 | reused_by_song += 1 | ||
| 253 | elif tmeid and tmeid in old_tmeid_map: | ||
| 254 | old_idx = old_tmeid_map[tmeid] | ||
| 255 | reused_by_tmeid += 1 | ||
| 256 | |||
| 257 | if old_idx is None: | ||
| 258 | continue | ||
| 259 | |||
| 260 | for col in DEFAULT_OUTPUT_COLUMNS: | ||
| 261 | if col in existing.columns: | ||
| 262 | out_df.at[idx, col] = existing.at[old_idx, col] | ||
| 263 | |||
| 264 | if _is_row_completed(out_df, int(idx)): | ||
| 265 | completed_indices.add(int(idx)) | ||
| 266 | reused += 1 | ||
| 267 | |||
| 268 | print( | ||
| 269 | "[resume] row mismatch, reused by key: " | ||
| 270 | f"song_id_match={reused_by_song}, tmeid_match={reused_by_tmeid}, " | ||
| 271 | f"completed={reused}/{len(out_df)}" | ||
| 272 | ) | ||
| 273 | return completed_indices | ||
| 274 | |||
| 275 | |||
| 276 | def _analyze_one( | ||
| 277 | idx: int, | ||
| 278 | row: pd.Series, | ||
| 279 | url_column: str, | ||
| 280 | provider: str, | ||
| 281 | extract_lyrics: bool, | ||
| 282 | label_level: int, | ||
| 283 | metadata_columns: list[str], | ||
| 284 | ) -> tuple[int, dict[str, Any]]: | ||
| 285 | url = row.get(url_column) | ||
| 286 | if _is_blank(url): | ||
| 287 | return idx, {} | ||
| 288 | |||
| 289 | try: | ||
| 290 | metadata = _build_metadata(row, metadata_columns) | ||
| 291 | result = analyze_music( | ||
| 292 | metadata=metadata, | ||
| 293 | music_url=str(url).strip(), | ||
| 294 | provider=provider, | ||
| 295 | extract_lyrics=extract_lyrics, | ||
| 296 | label_level=label_level, | ||
| 297 | ) | ||
| 298 | if not result: | ||
| 299 | return idx, {} | ||
| 300 | return idx, _normalize_result(result) | ||
| 301 | except Exception as exc: | ||
| 302 | print(f"[warn] row={idx} analyze failed: {type(exc).__name__}: {exc}") | ||
| 303 | print(traceback.format_exc(limit=3)) | ||
| 304 | return idx, {} | ||
| 305 | |||
| 306 | |||
| 307 | def run_batch( | ||
| 308 | input_path: Path, | ||
| 309 | output_path: Path, | ||
| 310 | checkpoint_path: Path | None, | ||
| 311 | url_column: str, | ||
| 312 | provider: str, | ||
| 313 | extract_lyrics: bool, | ||
| 314 | label_level: int, | ||
| 315 | metadata_columns: list[str], | ||
| 316 | workers: int, | ||
| 317 | checkpoint_every: int, | ||
| 318 | resume: bool, | ||
| 319 | ) -> None: | ||
| 320 | df = pd.read_excel(input_path) | ||
| 321 | url_column = _resolve_url_column(df, url_column) | ||
| 322 | checkpoint_path = _resolve_checkpoint_path(output_path, checkpoint_path) | ||
| 323 | blank_url_indices = {int(idx) for idx, value in df[url_column].items() if _is_blank(value)} | ||
| 324 | |||
| 325 | # Build the base reference columns first (from the input metadata) | ||
| 326 | out_df = pd.DataFrame(index=df.index) | ||
| 327 | out_df["tmeid"] = [ | ||
| 328 | _pick_first_non_blank(row, ["tmeid", "tmeID", "TMEID"]) for _, row in df.iterrows() | ||
| 329 | ] | ||
| 330 | out_df["歌曲ID"] = [ | ||
| 331 | _pick_first_non_blank(row, ["歌曲ID", "song_id", "id"]) for _, row in df.iterrows() | ||
| 332 | ] | ||
| 333 | out_df["歌曲名"] = [ | ||
| 334 | _pick_first_non_blank(row, ["歌曲名", "歌曲名称", "title"]) for _, row in df.iterrows() | ||
| 335 | ] | ||
| 336 | out_df["表演者"] = [ | ||
| 337 | _pick_first_non_blank(row, ["表演者", "歌手", "artist"]) for _, row in df.iterrows() | ||
| 338 | ] | ||
| 339 | out_df["歌曲时长"] = [ | ||
| 340 | _pick_first_non_blank(row, ["歌曲时长", "duration"]) for _, row in df.iterrows() | ||
| 341 | ] | ||
| 342 | |||
| 343 | for col in DEFAULT_OUTPUT_COLUMNS: | ||
| 344 | if col not in out_df.columns: | ||
| 345 | out_df[col] = "" | ||
| 346 | |||
| 347 | completed_indices: set[int] = set() | ||
| 348 | output_aligned_by_index = False | ||
| 349 | if resume: | ||
| 350 | if output_path.exists(): | ||
| 351 | try: | ||
| 352 | existing = pd.read_excel(output_path) | ||
| 353 | if len(existing) == len(out_df): | ||
| 354 | output_aligned_by_index = True | ||
| 355 | for col in DEFAULT_OUTPUT_COLUMNS: | ||
| 356 | if col in existing.columns: | ||
| 357 | out_df[col] = existing[col] | ||
| 358 | for idx in out_df.index: | ||
| 359 | if _is_row_completed(out_df, idx): | ||
| 360 | completed_indices.add(int(idx)) | ||
| 361 | print( | ||
| 362 | f"[resume] loaded existing output: {len(completed_indices)}/{len(out_df)} completed" | ||
| 363 | ) | ||
| 364 | else: | ||
| 365 | completed_indices |= _resume_from_existing_by_keys(out_df, existing) | ||
| 366 | except Exception as exc: | ||
| 367 | print(f"[resume] failed to read existing output: {type(exc).__name__}: {exc}") | ||
| 368 | |||
| 369 | checkpoint_completed = _load_checkpoint(checkpoint_path) | ||
| 370 | if checkpoint_completed: | ||
| 371 | if output_aligned_by_index: | ||
| 372 | checkpoint_completed = _filter_checkpoint_indices( | ||
| 373 | checkpoint_completed, out_df, df, url_column | ||
| 374 | ) | ||
| 375 | before = len(completed_indices) | ||
| 376 | completed_indices |= {idx for idx in checkpoint_completed if 0 <= idx < len(out_df)} | ||
| 377 | if len(completed_indices) != before: | ||
| 378 | print( | ||
| 379 | f"[resume] loaded checkpoint: {len(completed_indices)}/{len(out_df)} completed" | ||
| 380 | ) | ||
| 381 | else: | ||
| 382 | print("[resume] ignore checkpoint due to row mismatch with previous output") | ||
| 383 | |||
| 384 | # Rows with a blank URL are skipped outright and excluded from analysis | ||
| 385 | if blank_url_indices: | ||
| 386 | completed_indices |= blank_url_indices | ||
| 387 | print(f"[run] skip blank `{url_column}` rows: {len(blank_url_indices)}") | ||
| 388 | |||
| 389 | pending_indices = [int(idx) for idx in out_df.index if int(idx) not in completed_indices] | ||
| 390 | if not pending_indices: | ||
| 391 | print("[resume] no pending rows, nothing to do") | ||
| 392 | _save_progress(out_df, output_path, checkpoint_path, completed_indices) | ||
| 393 | return | ||
| 394 | |||
| 395 | print( | ||
| 396 | f"[run] total={len(out_df)}, completed={len(completed_indices)}, pending={len(pending_indices)}" | ||
| 397 | ) | ||
| 398 | |||
| 399 | workers = max(1, workers) | ||
| 400 | checkpoint_every = max(1, checkpoint_every) | ||
| 401 | processed_since_checkpoint = 0 | ||
| 402 | executor = ThreadPoolExecutor(max_workers=workers) | ||
| 403 | futures = [] | ||
| 404 | try: | ||
| 405 | for idx in pending_indices: | ||
| 406 | row = df.iloc[idx] | ||
| 407 | futures.append( | ||
| 408 | executor.submit( | ||
| 409 | _analyze_one, | ||
| 410 | idx, | ||
| 411 | row, | ||
| 412 | url_column, | ||
| 413 | provider, | ||
| 414 | extract_lyrics, | ||
| 415 | label_level, | ||
| 416 | metadata_columns, | ||
| 417 | ) | ||
| 418 | ) | ||
| 419 | |||
| 420 | pending_futures = set(futures) | ||
| 421 | while pending_futures: | ||
| 422 | done, pending_futures = wait( | ||
| 423 | pending_futures, | ||
| 424 | timeout=1.0, | ||
| 425 | return_when=FIRST_COMPLETED, | ||
| 426 | ) | ||
| 427 | if not done: | ||
| 428 | continue | ||
| 429 | |||
| 430 | for future in done: | ||
| 431 | idx, result = future.result() | ||
| 432 | for k, v in result.items(): | ||
| 433 | out_df.at[idx, k] = v | ||
| 434 | if result: | ||
| 435 | completed_indices.add(int(idx)) | ||
| 436 | |||
| 437 | processed_since_checkpoint += 1 | ||
| 438 | if processed_since_checkpoint >= checkpoint_every: | ||
| 439 | _save_progress(out_df, output_path, checkpoint_path, completed_indices) | ||
| 440 | processed_since_checkpoint = 0 | ||
| 441 | except KeyboardInterrupt: | ||
| 442 | print("[interrupt] received keyboard interrupt, saving checkpoint...") | ||
| 443 | try: | ||
| 444 | _save_progress(out_df, output_path, checkpoint_path, completed_indices) | ||
| 445 | except Exception as exc: | ||
| 446 | print(f"[interrupt] failed to save checkpoint: {type(exc).__name__}: {exc}") | ||
| 447 | |||
| 448 | for future in futures: | ||
| 449 | future.cancel() | ||
| 450 | executor.shutdown(wait=False, cancel_futures=True) | ||
| 451 | print("[interrupt] force exit to avoid blocking on running worker threads") | ||
| 452 | os._exit(130) | ||
| 453 | finally: | ||
| 454 | try: | ||
| 455 | executor.shutdown(wait=True, cancel_futures=False) | ||
| 456 | except Exception: | ||
| 457 | pass | ||
| 458 | |||
| 459 | _save_progress(out_df, output_path, checkpoint_path, completed_indices) | ||
| 460 | |||
| 461 | |||
| 462 | def parse_args() -> argparse.Namespace: | ||
| 463 | parser = argparse.ArgumentParser(description="Batch audio analysis from xlsx") | ||
| 464 | parser.add_argument("--input", required=True, help="input xlsx path") | ||
| 465 | parser.add_argument("--output", required=True, help="output xlsx path") | ||
| 466 | parser.add_argument( | ||
| 467 | "--checkpoint", | ||
| 468 | default="", | ||
| 469 | help="checkpoint json path (default: <output>.checkpoint.json)", | ||
| 470 | ) | ||
| 471 | parser.add_argument("--url-column", default="URL", help="url column name") | ||
| 472 | parser.add_argument("--provider", default="qwen", choices=["qwen", "doubao"]) | ||
| 473 | parser.add_argument("--extract-lyrics", action="store_true", help="enable lyrics extraction") | ||
| 474 | parser.add_argument("--label-level", type=int, default=0, choices=[0, 1]) | ||
| 475 | parser.add_argument( | ||
| 476 | "--metadata-columns", | ||
| 477 | default="tmeID,歌曲名称,歌曲名,歌手,表演者,版本,词作者,曲作者", | ||
| 478 | help="comma separated metadata columns", | ||
| 479 | ) | ||
| 480 | parser.add_argument("--workers", type=int, default=3, help="parallel workers") | ||
| 481 | parser.add_argument( | ||
| 482 | "--checkpoint-every", | ||
| 483 | type=int, | ||
| 484 | default=10, | ||
| 485 | help="save checkpoint every N processed rows", | ||
| 486 | ) | ||
| 487 | parser.add_argument( | ||
| 488 | "--no-resume", | ||
| 489 | action="store_true", | ||
| 490 | help="disable resume from existing output/checkpoint", | ||
| 491 | ) | ||
| 492 | return parser.parse_args() | ||
| 493 | |||
| 494 | |||
| 495 | def main() -> None: | ||
| 496 | args = parse_args() | ||
| 497 | metadata_columns = [c.strip() for c in args.metadata_columns.split(",") if c.strip()] | ||
| 498 | |||
| 499 | run_batch( | ||
| 500 | input_path=Path(args.input), | ||
| 501 | output_path=Path(args.output), | ||
| 502 | checkpoint_path=Path(args.checkpoint) if args.checkpoint.strip() else None, | ||
| 503 | url_column=args.url_column, | ||
| 504 | provider=args.provider, | ||
| 505 | extract_lyrics=args.extract_lyrics, | ||
| 506 | label_level=args.label_level, | ||
| 507 | metadata_columns=metadata_columns, | ||
| 508 | workers=args.workers, | ||
| 509 | checkpoint_every=args.checkpoint_every, | ||
| 510 | resume=not args.no_resume, | ||
| 511 | ) | ||
| 512 | |||
| 513 | |||
| 514 | if __name__ == "__main__": | ||
| 515 | main() |
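Both `_save_progress` and `_load_checkpoint` above rely on a write-temp-then-rename pattern, so an interrupt never leaves a half-written checkpoint behind. A minimal standalone sketch of that pattern (JSON checkpoint only; the xlsx side and the `save_checkpoint`/`load_checkpoint` names are illustrative simplifications):

```python
import json
import tempfile
from pathlib import Path

def save_checkpoint(path: Path, completed: set[int], total: int) -> None:
    """Atomically persist progress: write a temp file, then rename over the target."""
    payload = {
        "completed_indices": sorted(completed),
        "completed_count": len(completed),
        "total": total,
    }
    tmp = path.with_suffix(path.suffix + ".tmp")
    tmp.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    tmp.replace(path)  # atomic on POSIX: readers see either the old file or the new one

def load_checkpoint(path: Path) -> set[int]:
    if not path.exists():
        return set()
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
        return {int(v) for v in data.get("completed_indices", [])}
    except Exception:
        return set()  # corrupt checkpoint: start over rather than crash

ckpt = Path(tempfile.mkdtemp()) / "out.xlsx.checkpoint.json"
save_checkpoint(ckpt, {3, 1, 2}, total=10)
print(load_checkpoint(ckpt))  # {1, 2, 3}
```

Swallowing parse errors in the loader mirrors the script's behavior: a damaged checkpoint degrades to a full re-run instead of aborting the batch.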
requirements.txt
0 → 100644
-