Commit 7bf71620 7bf71620f01eb8ff3bc8ab5cdd8d9832a9780575 by 沈秋雨

Initial commit

0 parents
# Required for qwen
QWEN_API_KEY=your_api_key
QWEN_DASHSCOPE_API_KEY=
QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
QWEN_MODEL=qwen3-omni-flash
QWEN_TIMEOUT=15
QWEN_LYRICS_TIMEOUT=90
QWEN_MAX_RETRIES=3
MUSIC_ANALYZE_LIGHT_MODE=true
MUSIC_DOWNLOAD_DIR=music
MUSIC_MAPPING_FILE=music/music_file_mapping.csv

# Optional song structure service
SONGFORMER_URL=

# Optional ASR backend for lyrics_only path
MUSIC_LYRICS_ASR_BACKEND=funasr
DASHSCOPE_FUNASR_MODEL=fun-asr
DASHSCOPE_BASE_HTTP_API_URL=https://dashscope.aliyuncs.com/api/v1
DASHSCOPE_ASR_POLL_INTERVAL=1
DASHSCOPE_ASR_POLL_TIMEOUT=120
DASHSCOPE_ASR_SUBMIT_URL=https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription
DASHSCOPE_ASR_MODEL=qwen3-asr-flash-filetrans
DASHSCOPE_TASK_STATUS_BASE_URL=https://dashscope.aliyuncs.com/api/v1/tasks
.DS_Store

# Python cache
__pycache__/
*.py[cod]
*.so
.pytest_cache/
.mypy_cache/

# Virtual env
.venv/
venv/

# Local env
.env

# Logs
logs/
*.log

# Runtime outputs
outputs/
music/
*.checkpoint.json

# Local test/sample data
*.xlsx
*.xls
*.csv

# Keep env template and source files
!.env.example
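# music_analyze_v2

This project is a standalone pipeline that runs audio tag analysis in batches, driven by an Excel file.

Actual main flow:

1. Read the input `xlsx`
2. Take the audio URL from the specified URL column
3. Pass selected metadata through to the music analyzer
4. Call `app.middleware.music_analyze.analyze_music(...)`
5. Arrange the results into fixed delivery columns and keep writing them back to the output `xlsx`
6. Support resuming an interrupted run through the existing output file and a checkpoint

The current batch entry point is [`pipeline/batch_analyze_xlsx.py`](/Users/sqy/Downloads/music_analyze_v2/pipeline/batch_analyze_xlsx.py)

## Current status

- Directly runnable main entry point: [`pipeline/batch_analyze_xlsx.py`](/Users/sqy/Downloads/music_analyze_v2/pipeline/batch_analyze_xlsx.py)
- Current default analysis path: `QwenAnalyzer`
- Provider that actually works today: `qwen`
- Prompt source: [`app/prompts/step2_music_decode`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode)
- Output format: fixed delivery columns; the original input columns are not all preserved

Notes:

- The command line still keeps a `--provider doubao` option, but [`factory.py`](/Users/sqy/Downloads/music_analyze_v2/app/middleware/music_analyze/factory.py) currently only instantiates `qwen`, so passing `doubao` fails at runtime.
- The rest of this README describes what the current code actually does, not what was historically planned.
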
## Installation

```bash
python3.10 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
cp .env.example .env
```

## Environment variables

The minimum required configuration is usually:

```env
QWEN_API_KEY=your_api_key
QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
QWEN_MODEL=qwen3-omni-flash
QWEN_TIMEOUT=15
QWEN_LYRICS_TIMEOUT=90
QWEN_MAX_RETRIES=3
```

The project also supports the following optional enhancements:

- `QWEN_DASHSCOPE_API_KEY`: used by some DashScope/ASR paths
- `SONGFORMER_URL`: enables extra audio structure features
- `MUSIC_LYRICS_ASR_BACKEND`, `DASHSCOPE_*`: lyrics extraction settings
- `OSS_*`: fallback upload through OSS when the audio is too large

Configuration is defined in [`app/core/config.py`](/Users/sqy/Downloads/music_analyze_v2/app/core/config.py)

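## Input requirements

The input file must be `xlsx`.

At least one audio URL column is required. The script resolves the URL column in the following order:

- the explicitly passed `--url-column`
- `URL`
- `url`
- `cos访问地址`
- `cos_url`
- `audio_url`

If a row's URL is empty:

- no analysis is started
- the row is skipped
- the row counts as already processed when resuming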
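
The resolution order above can be sketched as follows. This is a minimal illustration and not the code in `batch_analyze_xlsx.py`: the names `resolve_url_column` and `is_empty_url` are hypothetical, and treating `NaN` as empty is an assumption about how blank Excel cells arrive.

```python
from typing import Iterable, Optional

# Fallback names copied from the list above.
URL_COLUMN_FALLBACKS = ["URL", "url", "cos访问地址", "cos_url", "audio_url"]


def resolve_url_column(columns: Iterable[str], explicit: Optional[str] = None) -> Optional[str]:
    """Return the first candidate column that exists, or None if none match.

    Hypothetical helper: the explicit --url-column is tried first, then the
    fallbacks in the documented order.
    """
    available = set(columns)
    candidates = ([explicit] if explicit else []) + URL_COLUMN_FALLBACKS
    for name in candidates:
        if name in available:
            return name
    return None


def is_empty_url(value: object) -> bool:
    """Treat None, NaN and whitespace-only strings as an empty URL (assumption)."""
    if value is None:
        return True
    if isinstance(value, float) and value != value:  # NaN is the only float unequal to itself
        return True
    return str(value).strip() == ""
```

A row for which `is_empty_url` is true would be skipped and recorded as completed, which matches the behavior described above.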

Metadata is optional but recommended. The script recognizes these fields first:

- `歌曲ID` / `song_id` / `id`
- `tmeid` / `tmeID` / `TMEID`
- `歌曲名` / `歌曲名称` / `title`
- `表演者` / `歌手` / `artist`
- `歌曲时长` / `duration`

By default, these columns are additionally passed to the model as metadata:

- `tmeID,歌曲名称,歌曲名,歌手,表演者,版本,词作者,曲作者`

This list can be overridden with `--metadata-columns`.
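
One way to picture the alias lookup and the `--metadata-columns` override is the sketch below. It assumes each row is a plain `dict` and that aliases are tried left to right as listed; `first_present`, `parse_metadata_columns`, and `build_metadata` are hypothetical names, not functions from the script.

```python
from typing import Any, Dict, List, Mapping, Optional

# Alias groups copied from the list above; the dictionary keys are illustrative.
FIELD_ALIASES: Dict[str, List[str]] = {
    "song_id": ["歌曲ID", "song_id", "id"],
    "tmeid": ["tmeid", "tmeID", "TMEID"],
    "title": ["歌曲名", "歌曲名称", "title"],
    "artist": ["表演者", "歌手", "artist"],
    "duration": ["歌曲时长", "duration"],
}
DEFAULT_METADATA_COLUMNS = "tmeID,歌曲名称,歌曲名,歌手,表演者,版本,词作者,曲作者"


def first_present(row: Mapping[str, Any], aliases: List[str]) -> Optional[Any]:
    """Return the value of the first alias that is present and non-blank."""
    for name in aliases:
        value = row.get(name)
        if value is not None and str(value).strip() != "":
            return value
    return None


def parse_metadata_columns(raw: str = DEFAULT_METADATA_COLUMNS) -> List[str]:
    """Split a comma-separated column list, dropping blanks and padding."""
    return [part.strip() for part in raw.split(",") if part.strip()]


def build_metadata(row: Mapping[str, Any], columns: List[str]) -> Dict[str, Any]:
    """Keep only the requested columns that exist in the row and are non-blank."""
    return {
        col: row[col]
        for col in columns
        if col in row and row[col] is not None and str(row[col]).strip() != ""
    }
```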

## Quick start

Regular batch run:

```bash
python pipeline/batch_analyze_xlsx.py \
  --input 待分析.xlsx \
  --output outputs/标签交付结果.xlsx \
  --url-column URL \
  --provider qwen \
  --workers 3
```

Extract lyrics:

```bash
python pipeline/batch_analyze_xlsx.py \
  --input 待分析.xlsx \
  --output outputs/标签交付结果.xlsx \
  --url-column URL \
  --provider qwen \
  --workers 3 \
  --extract-lyrics
```

Rerun from scratch, without reusing previous output or the checkpoint:

```bash
python pipeline/batch_analyze_xlsx.py \
  --input 待分析.xlsx \
  --output outputs/标签交付结果.xlsx \
  --provider qwen \
  --no-resume
```

## Command-line arguments

| Argument | Description | Current behavior |
|------|------|-------------|
| `--input` | Input Excel path | Required |
| `--output` | Output Excel path | Required |
| `--checkpoint` | Checkpoint file path | Defaults to `<output>.checkpoint.json` |
| `--url-column` | URL column name | Defaults to `URL`; falls back automatically if that column does not exist |
| `--provider` | Analysis provider | Accepts `qwen`/`doubao`; only `qwen` should be used for now |
| `--extract-lyrics` | Whether to extract lyrics | When enabled, the lyrics-aware analysis path is used |
| `--label-level` | Label level | `0` or `1` |
| `--metadata-columns` | Extra columns passed to the model | Comma-separated |
| `--workers` | Number of concurrent threads | Defaults to `3` |
| `--checkpoint-every` | Number of rows processed between saves | Defaults to `10` |
| `--no-resume` | Disable resuming | Off by default |
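
For orientation, the table above could be declared with `argparse` roughly as follows. This is a sketch under stated assumptions rather than the script's actual parser: the defaults for `--provider` (`qwen`) and `--label-level` (`0`) are guesses, and `<output>.checkpoint.json` is read here as the full output path with a suffix appended.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical reconstruction of the CLI described in the table."""
    parser = argparse.ArgumentParser(description="Batch audio tag analysis (sketch)")
    parser.add_argument("--input", required=True, help="input Excel path")
    parser.add_argument("--output", required=True, help="output Excel path")
    parser.add_argument("--checkpoint", default=None, help="checkpoint file path")
    parser.add_argument("--url-column", default="URL")
    # Assumption: qwen is the default, since it is the only working provider.
    parser.add_argument("--provider", choices=["qwen", "doubao"], default="qwen")
    parser.add_argument("--extract-lyrics", action="store_true")
    # Assumption: the default level is 0; the table only lists the allowed values.
    parser.add_argument("--label-level", type=int, choices=[0, 1], default=0)
    parser.add_argument(
        "--metadata-columns",
        default="tmeID,歌曲名称,歌曲名,歌手,表演者,版本,词作者,曲作者",
    )
    parser.add_argument("--workers", type=int, default=3)
    parser.add_argument("--checkpoint-every", type=int, default=10)
    parser.add_argument("--no-resume", action="store_true")
    return parser


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    args = build_parser().parse_args(argv)
    if args.checkpoint is None:
        # Assumption: the suffix is appended to the whole output path.
        args.checkpoint = f"{args.output}.checkpoint.json"
    return args
```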

## Output structure

The script writes a fixed delivery table, not a full write-back of "original input columns + analysis columns".

The current output columns are defined by `DEFAULT_OUTPUT_COLUMNS` in [`batch_analyze_xlsx.py`](/Users/sqy/Downloads/music_analyze_v2/pipeline/batch_analyze_xlsx.py):

- `tmeid`
- `歌曲ID`
- `歌曲名`
- `表演者`
- `歌曲时长`
- `表演者类型`
- `语种`
- `BPM速度`
- `情绪`
- `网络/抖音歌曲`
- `音乐风格`
- `配器`
- `场景`

Result field mapping rules:

- `表演者类型` <- `performer_type` or `vocal_texture`
- `语种` <- `language`
- `BPM速度` <- `bpm`
- `情绪` <- `emotion`
- `网络/抖音歌曲` <- `douyin_tags`
- `音乐风格` <- `music_style_tags`, otherwise falls back to `genre/sub_genre`
- `配器` <- `instrument_tags`
- `场景` <- `scene`

List-valued fields are joined into a single string separated by `、`.
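
The mapping rules can be read as a small pure function. The sketch below is illustrative only: `join_tags`, `first_non_empty`, and `to_delivery_fields` are hypothetical names, and joining the `genre/sub_genre` fallback with `/` is an assumption taken from the notation above, not confirmed behavior.

```python
from typing import Any, Dict

LIST_SEPARATOR = "、"


def join_tags(value: Any) -> str:
    """Turn a scalar or a list of tags into one display string."""
    if value is None:
        return ""
    if isinstance(value, (list, tuple)):
        return LIST_SEPARATOR.join(str(item) for item in value if str(item).strip())
    return str(value)


def first_non_empty(result: Dict[str, Any], *keys: str) -> str:
    """Return the first key whose rendered value is non-empty."""
    for key in keys:
        text = join_tags(result.get(key))
        if text:
            return text
    return ""


def to_delivery_fields(result: Dict[str, Any]) -> Dict[str, str]:
    """Map analyzer output keys to the fixed delivery columns."""
    style = join_tags(result.get("music_style_tags"))
    if not style:
        # Assumption: the fallback concatenates genre and sub_genre with "/".
        parts = [join_tags(result.get("genre")), join_tags(result.get("sub_genre"))]
        style = "/".join(part for part in parts if part)
    return {
        "表演者类型": first_non_empty(result, "performer_type", "vocal_texture"),
        "语种": join_tags(result.get("language")),
        "BPM速度": join_tags(result.get("bpm")),
        "情绪": join_tags(result.get("emotion")),
        "网络/抖音歌曲": join_tags(result.get("douyin_tags")),
        "音乐风格": style,
        "配器": join_tags(result.get("instrument_tags")),
        "场景": join_tags(result.get("scene")),
    }
```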
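
## Resuming an interrupted run

The current resume logic is more specific than the older README described. The actual behavior is:

- If the output file already exists and its row count matches this run's input:
  the previous output is reused directly by row number
- If the output file already exists but the row count differs:
  the script tries to reuse old results by `歌曲ID` or `tmeid`
- If a checkpoint exists:
  its completion state is merged, provided the output is aligned by index
- Rows with an empty URL are added to the completed set directly
- During processing, progress is flushed to disk every `--checkpoint-every` rows
- On `Ctrl+C`, the current progress is saved first, then the process is force-exited so it does not hang on worker threads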
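
The reuse rules above can be sketched as follows. This is a simplified illustration, not the script's implementation: rows are assumed to be plain dicts, the checkpoint layout (`{"completed": [...]}`) is an assumption, and the write-to-temp-then-rename step is a swapped-in technique for crash-safe saving that the README does not claim.

```python
import json
import os
import tempfile
from typing import Any, Dict, List, Set

ID_KEYS = ("歌曲ID", "tmeid")


def plan_reuse(input_rows: List[Dict[str, Any]], old_rows: List[Dict[str, Any]]) -> Dict[int, Dict[str, Any]]:
    """Map input row index -> old output row that can be reused."""
    if len(input_rows) == len(old_rows):
        # Same row count: reuse directly by row number.
        return {i: old for i, old in enumerate(old_rows)}
    reused: Dict[int, Dict[str, Any]] = {}
    for key in ID_KEYS:
        # Row counts differ: match by ID, trying each ID column in turn.
        index = {str(old[key]): old for old in old_rows if old.get(key) not in (None, "")}
        for i, row in enumerate(input_rows):
            value = row.get(key)
            if i not in reused and value not in (None, "") and str(value) in index:
                reused[i] = index[str(value)]
    return reused


def save_checkpoint(path: str, completed: Set[int]) -> None:
    """Write the completed set to a temp file, then rename it into place."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w", encoding="utf-8") as handle:
        json.dump({"completed": sorted(completed)}, handle)
    os.replace(tmp_path, path)


def load_checkpoint(path: str) -> Set[int]:
    """Return the completed row indices, or an empty set if no checkpoint exists."""
    if not os.path.exists(path):
        return set()
    with open(path, encoding="utf-8") as handle:
        return set(json.load(handle).get("completed", []))
```

Renaming a fully written temp file means an interrupted save leaves the previous checkpoint intact, which suits the `Ctrl+C` path described above.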

Default checkpoint file name:

```text
<output>.checkpoint.json
```

## Prompts and analysis path

The batch script does not read prompt files directly; it goes through the unified analysis entry point:

[`pipeline/batch_analyze_xlsx.py`](/Users/sqy/Downloads/music_analyze_v2/pipeline/batch_analyze_xlsx.py)
-> [`app/middleware/music_analyze/__init__.py`](/Users/sqy/Downloads/music_analyze_v2/app/middleware/music_analyze/__init__.py)
-> [`app/middleware/music_analyze/music_analyzer.py`](/Users/sqy/Downloads/music_analyze_v2/app/middleware/music_analyze/music_analyzer.py)
-> [`app/middleware/music_analyze/factory.py`](/Users/sqy/Downloads/music_analyze_v2/app/middleware/music_analyze/factory.py)
-> [`app/middleware/music_analyze/qwen_analyzer.py`](/Users/sqy/Downloads/music_analyze_v2/app/middleware/music_analyze/qwen_analyzer.py)
-> [`app/middleware/music_analyze/prompts.py`](/Users/sqy/Downloads/music_analyze_v2/app/middleware/music_analyze/prompts.py)

The prompt directory is currently fixed and contains:

- [`music_analyze_system_prompt.md`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode/music_analyze_system_prompt.md)
- [`music_analyze_system_prompt_part_a.md`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode/music_analyze_system_prompt_part_a.md)
- [`music_analyze_system_prompt_part_b.md`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode/music_analyze_system_prompt_part_b.md)
- [`music_analyze_user_prompt.md`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode/music_analyze_user_prompt.md)
- [`music_lyrics_only_prompt.md`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode/music_lyrics_only_prompt.md)

## Project structure

```text
music_analyze_v2/
├── app/
│   ├── core/
│   │   └── config.py
│   ├── middleware/
│   │   └── music_analyze/
│   │       ├── __init__.py
│   │       ├── base.py
│   │       ├── factory.py
│   │       ├── music_analyzer.py
│   │       ├── prompts.py
│   │       ├── qwen_analyzer.py
│   │       ├── doubao_analyzer.py
│   │       ├── audio_features.py
│   │       └── bpm_analyzer_tools.py
│   ├── prompts/
│   │   └── step2_music_decode/
│   └── utils/
├── pipeline/
│   └── batch_analyze_xlsx.py
├── outputs/
├── requirements.txt
├── .env
├── .env.example
└── README.md
```

## Dependencies

Base dependencies are listed in [`requirements.txt`](/Users/sqy/Downloads/music_analyze_v2/requirements.txt)

It currently lists explicitly:

- `openai`
- `requests`
- `httpx`
- `python-dotenv`
- `pydantic-settings`
- `numpy`
- `scipy`
- `librosa`
- `soundfile`
- `pandas`
- `openpyxl`

`dashscope` is still commented out in `requirements.txt`; if you want to run the lyrics path that depends on that SDK, install it yourself and verify the corresponding code branch.

## FAQ

### Why does it still fail when I pass `--provider doubao`?

Because the CLI still keeps the `doubao` option, but the analyzer factory only supports `qwen`. This reflects the current state of the code, not a usage mistake.

### Why doesn't the output keep all columns of the original Excel?

Because the script writes only `DEFAULT_OUTPUT_COLUMNS` when saving; this is fixed behavior in the code.

### Where do I change the prompts?

Edit the template files under [`app/prompts/step2_music_decode`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode).

### Can I still resume if the row count changed?

Partially. The script tries to match previous output by `歌曲ID` or `tmeid`.

### How do I rerun everything from scratch?

Pass `--no-resume` and delete the old output and the old checkpoint; that is the cleanest way.
"""Standalone audio analysis package."""
from .config import settings

__all__ = ["settings"]
"""Minimal settings for standalone audio analysis pipeline."""

from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Qwen
    QWEN_API_KEY: str | None = None
    QWEN_DASHSCOPE_API_KEY: str | None = None
    QWEN_BASE_URL: str | None = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    QWEN_MODEL: str | None = "qwen3-omni-flash"
    QWEN_TIMEOUT: float = 15.0
    QWEN_LYRICS_TIMEOUT: float = 90.0
    QWEN_MAX_RETRIES: int = 3
    MUSIC_ANALYZE_LIGHT_MODE: bool = True
    MUSIC_DOWNLOAD_DIR: str = "music"
    MUSIC_MAPPING_FILE: str = "music/music_file_mapping.csv"

    # Optional features
    SONGFORMER_URL: str | None = None

    # DashScope ASR
    DASHSCOPE_FUNASR_MODEL: str = "fun-asr"
    DASHSCOPE_BASE_HTTP_API_URL: str = "https://dashscope.aliyuncs.com/api/v1"
    DASHSCOPE_ASR_POLL_INTERVAL: float = 1.0
    DASHSCOPE_ASR_POLL_TIMEOUT: float = 120.0
    DASHSCOPE_ASR_SUBMIT_URL: str = (
        "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription"
    )
    DASHSCOPE_ASR_MODEL: str = "qwen3-asr-flash-filetrans"
    DASHSCOPE_TASK_STATUS_BASE_URL: str = "https://dashscope.aliyuncs.com/api/v1/tasks"

    # OSS
    OSS_ACCESS_KEY_ID: str | None = None
    OSS_ACCESS_KEY_SECRET: str | None = None
    OSS_ENDPOINT: str | None = None
    OSS_BUCKET_NAME: str | None = None
    OSS_ENDPOINT_INTERNAL: str | None = None


settings = Settings()
"""
Custom exception definitions.

All business exceptions should inherit from APIException so that the global
exception handler can process them uniformly and return a standard error response.
"""
from fastapi import HTTPException, status
from typing import Optional, Any


class APIException(HTTPException):
    """
    Base API exception.

    Base class for all business exceptions; the global exception handler can
    catch and process it uniformly.
    """

    def __init__(
        self,
        status_code: int = status.HTTP_400_BAD_REQUEST,
        detail: Optional[str] = None,
        error_code: Optional[str] = None,
        data: Any = None,
        headers: Optional[dict] = None,
    ):
        super().__init__(status_code=status_code, detail=detail, headers=headers)
        self.error_code = error_code or "UNKNOWN_ERROR"
        self.data = data


class UnauthorizedException(APIException):
    """Unauthorized: authentication failed."""

    def __init__(self, detail: str = "未授权", error_code: str = "UNAUTHORIZED"):
        super().__init__(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            error_code=error_code
        )


class ForbiddenException(APIException):
    """Forbidden: insufficient permissions."""

    def __init__(self, detail: str = "禁止访问", error_code: str = "FORBIDDEN"):
        super().__init__(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=detail,
            error_code=error_code
        )


class NotFoundException(APIException):
    """Resource not found."""

    def __init__(self, detail: str = "资源不存在", error_code: str = "NOT_FOUND"):
        super().__init__(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=detail,
            error_code=error_code
        )


class ConflictException(APIException):
    """Conflict: the resource already exists."""

    def __init__(self, detail: str = "资源已存在", error_code: str = "CONFLICT"):
        super().__init__(
            status_code=status.HTTP_409_CONFLICT,
            detail=detail,
            error_code=error_code
        )


class ValidationException(APIException):
    """Validation error: input validation failed."""

    def __init__(self, detail: str = "验证失败", error_code: str = "VALIDATION_ERROR"):
        super().__init__(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=detail,
            error_code=error_code
        )


class BusinessException(APIException):
    """Business error: a business rule check failed."""

    def __init__(
        self,
        detail: str = "业务操作失败",
        error_code: str = "BUSINESS_ERROR",
        status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
    ):
        super().__init__(
            status_code=status_code,
            detail=detail,
            error_code=error_code
        )


class InternalServerException(APIException):
    """Internal server error."""

    def __init__(
        self,
        detail: str = "内部服务器错误",
        error_code: str = "INTERNAL_SERVER_ERROR",
    ):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=detail,
            error_code=error_code
        )


class DatabaseException(APIException):
    """Database error."""

    def __init__(
        self,
        detail: str = "数据库操作失败",
        error_code: str = "DATABASE_ERROR",
    ):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=detail,
            error_code=error_code
        )


class ExternalServiceException(APIException):
    """External service error: a call to a third-party service failed."""

    def __init__(
        self,
        detail: str = "外部服务调用失败",
        error_code: str = "EXTERNAL_SERVICE_ERROR",
    ):
        super().__init__(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=detail,
            error_code=error_code
        )


class RateLimitException(APIException):
    """Rate limit error: too many requests."""

    def __init__(
        self,
        detail: str = "请求过于频繁,请稍后再试",
        error_code: str = "RATE_LIMIT_EXCEEDED",
    ):
        super().__init__(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=detail,
            error_code=error_code
        )
"""Middleware package."""
"""
Music analysis module.
Provides unified music tag analysis, with support for Tongyi Qianwen (Qwen)
and Volcano Engine Doubao.

Main features:
- Music style recognition (aligned with international music classification systems)
- Emotion recognition
- Vocal texture recognition
- Language recognition
- Rhythm intensity analysis (1-5, used to guide video editing)
- Climax point detection
- Visual concept generation (for MV creation)
- Lyrics recognition (optional)

Supported providers:
- qwen: Tongyi Qianwen (qwen3-omni-flash)
- doubao: Volcano Engine Doubao (doubao-seed-1-8-251228)
  (per the README, the factory currently instantiates only `qwen`)

Usage example:
    from app.middleware.music_analyze import analyze_music

    # Basic analysis
    result = analyze_music(
        metadata={"title": "稻香", "artist": "周杰伦"},
        music_url="https://example.com/music.mp3",
        provider="qwen",
    )

    # With lyrics recognition
    result = analyze_music(
        metadata={"title": "稻香"},
        music_url="https://example.com/music.mp3",
        provider="qwen",
        extract_lyrics=True,
    )
"""

# Main function exports
from .music_analyzer import (
    analyze_music,
    analyze_music_lyrics_only,
    analyze_music_with_qwen,
    analyze_music_with_doubao,
    get_available_providers,
)

# Class exports
from .base import AudioAnalyzer
from .qwen_analyzer import QwenAnalyzer
from .doubao_analyzer import DoubaoAnalyzer
from .factory import AnalyzerFactory

__version__ = "1.0.0"
"""
Audio feature extraction module.
Provides audio feature extraction plus rhythm-intensity and energy-level calculation.
"""

import os
import warnings
import numpy as np
import librosa
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass

from .bpm_analyzer_tools import RealtimeBPMAnalyzerTest

# Suppress librosa's audioread deprecation warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="librosa")


@dataclass
class AudioFeatures:
    """Audio feature data."""

    # Time-domain features
    rms_energy: np.ndarray  # RMS energy in dB relative to the peak (per frame)
    rms_times: np.ndarray  # Matching timestamps

    # Frequency-domain features
    spectral_centroid: np.ndarray  # Spectral centroid (brightness)
    spectral_rolloff: np.ndarray  # Spectral roll-off (low-frequency share)
    spectral_bandwidth: np.ndarray  # Spectral bandwidth

    # Rhythm features
    onset_strength: np.ndarray  # Onset strength
    tempo: float  # BPM

    # Summary information
    duration: float
    sr: int


def extract_audio_features(audio_path: str, hop_length: int = 512) -> AudioFeatures:
    """
    Extract audio features.

    Args:
        audio_path: Path to the audio file
        hop_length: Hop length in samples (default 512 samples ≈ 11.6 ms @ 44.1 kHz)

    Returns:
        AudioFeatures: The audio feature object
    """
    # Load the audio at its native sample rate
    y, sr = librosa.load(audio_path, sr=None, mono=True)
    duration = librosa.get_duration(y=y, sr=sr)

    # 1. RMS energy (time-domain loudness), converted to dB relative to the peak
    rms = librosa.feature.rms(y=y, hop_length=hop_length)[0]
    rms_db = librosa.amplitude_to_db(rms, ref=np.max)
    rms_times = librosa.frames_to_time(
        np.arange(len(rms)), sr=sr, hop_length=hop_length
    )

    # 2. Spectral features
    spectral_centroid = librosa.feature.spectral_centroid(
        y=y, sr=sr, hop_length=hop_length
    )[0]
    spectral_rolloff = librosa.feature.spectral_rolloff(
        y=y, sr=sr, hop_length=hop_length
    )[0]
    spectral_bandwidth = librosa.feature.spectral_bandwidth(
        y=y, sr=sr, hop_length=hop_length
    )[0]

    # 3. Rhythm features
    onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)

    # Use the unified BPM analysis entry point (with octave-error correction)
    bpm_analyzer = RealtimeBPMAnalyzerTest(verbose=False)
    bpm_result = bpm_analyzer.analyze_bpm(y=y, sr=sr)
    corrected_tempo = bpm_result.get('bpm', 120.0)

    return AudioFeatures(
        rms_energy=rms_db,
        rms_times=rms_times,
        spectral_centroid=spectral_centroid,
        spectral_rolloff=spectral_rolloff,
        spectral_bandwidth=spectral_bandwidth,
        onset_strength=onset_env,
        tempo=corrected_tempo,
        duration=duration,
        sr=int(sr),
    )


def calculate_rhythm_intensity(features: AudioFeatures) -> int:
    """
    Compute rhythm intensity (1-5) from audio features.

    The score combines:
    - BPM (speed)
    - Onset strength (rhythmic density)
    - Energy variation (dynamic range)

    Args:
        features: The audio feature object

    Returns:
        int: Rhythm intensity (1-5)
    """
    tempo = features.tempo
    onset = features.onset_strength
    rms = features.rms_energy

    # 1. BPM score (40-200 BPM mapped to 1-5)
    if tempo >= 160:
        tempo_score = 5
    elif tempo >= 130:
        tempo_score = 4
    elif tempo >= 100:
        tempo_score = 3
    elif tempo >= 70:
        tempo_score = 2
    else:
        tempo_score = 1

    # 2. Onset density score (mean relative to max; empty input counts as 0)
    onset_mean = float(np.mean(onset)) if len(onset) > 0 else 0.0
    onset_max = np.max(onset) if len(onset) > 0 else 1
    onset_density = onset_mean / onset_max if onset_max > 0 else 0

    if onset_density >= 0.5:
        density_score = 5
    elif onset_density >= 0.4:
        density_score = 4
    elif onset_density >= 0.3:
        density_score = 3
    elif onset_density >= 0.2:
        density_score = 2
    else:
        density_score = 1

    # 3. Energy dynamics score (standard deviation of the dB curve)
    rms_std = float(np.std(rms)) if len(rms) > 0 else 0.0
    if rms_std >= 15:
        dynamic_score = 5
    elif rms_std >= 12:
        dynamic_score = 4
    elif rms_std >= 9:
        dynamic_score = 3
    elif rms_std >= 6:
        dynamic_score = 2
    else:
        dynamic_score = 1

    # Weighted average (BPM 40%, density 35%, dynamics 25%)
    final_score = tempo_score * 0.4 + density_score * 0.35 + dynamic_score * 0.25

    return int(round(final_score))


def calculate_energy_level(
    features: AudioFeatures,
) -> Tuple[int, Dict[str, float]]:
    """
    Compute the energy level (1-5) along with a breakdown.

    Args:
        features: The audio feature object

    Returns:
        Tuple[int, Dict]: (energy level 1-5, dictionary of details)
    """
    # 1. Loudness score (based on RMS energy in dB)
    rms_db = features.rms_energy
    loudness_normalized = np.clip((rms_db + 60) / 10, 0, 5)
    loudness_score = float(np.percentile(loudness_normalized, 75))

    # 2. Brightness score (based on the spectral centroid)
    centroid = features.spectral_centroid
    centroid_normalized = np.clip(centroid / 4000, 0, 1)
    brightness_score = float(np.mean(centroid_normalized)) * 5

    # 3. Rhythm score (based on onset strength)
    onset = features.onset_strength
    onset_p90 = float(np.percentile(onset, 90)) if len(onset) > 0 else 0.0
    if onset_p90 > 0:
        onset_normalized = np.clip(onset / onset_p90, 0, 1)
        rhythm_score = float(np.mean(onset_normalized)) * 5
    else:
        # Silent or empty input: avoid dividing by zero
        rhythm_score = 0.0

    # 4. BPM factor
    tempo = features.tempo
    if tempo > 140:
        tempo_factor = 1.3
    elif tempo > 120:
        tempo_factor = 1.15
    elif tempo > 100:
        tempo_factor = 1.0
    elif tempo > 80:
        tempo_factor = 0.9
    else:
        tempo_factor = 0.8

    # Combined score
    weights = {"loudness": 0.40, "brightness": 0.25, "rhythm": 0.35}

    composite_score = (
        weights["loudness"] * loudness_score
        + weights["brightness"] * brightness_score
        + weights["rhythm"] * rhythm_score
    ) * tempo_factor

    # Map to levels 1-5
    if composite_score < 1.5:
        energy_level = 1
    elif composite_score < 2.5:
        energy_level = 2
    elif composite_score < 3.5:
        energy_level = 3
    elif composite_score < 4.5:
        energy_level = 4
    else:
        energy_level = 5

    details = {
        "loudness_score": round(loudness_score, 2),
        "brightness_score": round(brightness_score, 2),
        "rhythm_score": round(rhythm_score, 2),
        "tempo_factor": tempo_factor,
        "composite_score": round(composite_score, 2),
    }

    return energy_level, details


def energy_level_to_string(level: int) -> str:
    """
    Convert a numeric energy level to its string description.

    Args:
        level: Energy level (1-5)

    Returns:
        str: Energy density description (the labels are returned in Chinese)
    """
    mapping = {
        1: "舒缓",
        2: "柔和",
        3: "律动",
        4: "强烈",
        5: "爆发",
    }
    return mapping.get(level, "律动")


@dataclass
class BeatInfo:
    """Beat information."""
    beat_timestamps: List[float]  # All beat times
    downbeat_timestamps: List[float]  # Downbeat times (first beat of each bar)
    tempo: float  # BPM
    beat_intervals: List[float]  # Beat intervals (used to detect tempo changes)


@dataclass
class EmotionCurve:
    """Emotion curve data."""
    timestamps: List[float]  # Time points
    energy_values: List[float]  # Energy (0-1)
    valence_values: List[float]  # Valence (0-1, low = sad, high = cheerful)
    arousal_values: List[float]  # Arousal (0-1, low = calm, high = excited)
    smoothed_curve: List[float]  # Smoothed combined emotion curve


@dataclass
class SegmentEmotion:
    """Per-segment emotion data (aligned with songformer segments)."""
    start: float  # Segment start time
    end: float  # Segment end time
    label: str  # Segment label (intro/verse/chorus/bridge/outro)
    intensity: float  # Emotion intensity (0-1)
    energy: float  # Energy (0-1)
    valence: float  # Valence (0-1)
    arousal: float  # Arousal (0-1)
    trend: str  # Emotion trend (rising/falling/stable/peak)


@dataclass
class BeatDensityInfo:
    """Beat density information (used to plan shot durations)."""
    segment_label: str  # Segment label
    start: float  # Start time
    end: float  # End time
    beat_count: int  # Number of beats
    avg_interval: float  # Average interval (seconds)
    density_level: str  # sparse/normal/dense/very_dense
    recommended_shot_duration: str  # Recommended shot duration


@dataclass
class EnhancedClimaxInfo:
    """Enhanced climax information (with build-up, sustain and wind-down durations)."""
    time: float  # Climax time
    intensity: str  # strong/strongest
    buildup_start: float  # Build-up start time
    buildup_duration: float  # Build-up duration (seconds)
    climax_duration: float  # Climax sustain duration (seconds)
    winddown_duration: float  # Wind-down duration (seconds)


def extract_beat_timestamps(audio_path: str) -> BeatInfo:
    """
    Extract beat timestamps (cut points).

    Uses the BPM detector with octave-error correction.

    Args:
        audio_path: Path to the audio file

    Returns:
        BeatInfo: The beat information object
    """
    y, sr = librosa.load(audio_path, sr=22050, mono=True)

    # Use the unified BPM analysis entry point (octave correction + beat_times)
    bpm_analyzer = RealtimeBPMAnalyzerTest(verbose=False)
    bpm_result = bpm_analyzer.analyze_bpm(y=y, sr=sr)
    corrected_tempo = bpm_result.get('bpm', 120.0)

    # analyze_bpm has already subsampled beat_times when it halved the BPM
    beat_times = np.array(bpm_result.get('beat_times', []))

    # Downbeat detection (every 4th beat, assuming 4/4 time)
    downbeat_times = beat_times[::4].tolist() if len(beat_times) > 0 else []

    # Compute the beat intervals
    beat_intervals = np.diff(beat_times).tolist() if len(beat_times) > 1 else []

    return BeatInfo(
        beat_timestamps=beat_times.tolist(),
        downbeat_timestamps=downbeat_times,
        tempo=corrected_tempo,
        beat_intervals=beat_intervals,
    )


def extract_emotion_curve(
    audio_path: str,
    window_size: float = 2.0,  # Window size (seconds)
    hop_size: float = 0.5  # Hop size (seconds)
) -> EmotionCurve:
    """
    Extract the emotion curve.

    Emotion is inferred from audio features:
    - Energy: RMS energy -> emotion intensity
    - Valence: spectral centroid -> positive/negative emotion
      (major/minor key is not used by the current implementation)
    - Arousal: mean onset strength -> excited/calm

    Args:
        audio_path: Path to the audio file
        window_size: Sliding window size (seconds)
        hop_size: Sliding hop size (seconds)

    Returns:
        EmotionCurve: The emotion curve data object
    """
    y, sr = librosa.load(audio_path, sr=None, mono=True)

    timestamps: List[float] = []
    energy_values: List[float] = []
    valence_values: List[float] = []
    arousal_values: List[float] = []

    # Sliding-window analysis
    window_samples = int(window_size * sr)
    hop_samples = int(hop_size * sr)

    # "+ 1" keeps the final full window when the audio length aligns exactly
    for start_sample in range(0, len(y) - window_samples + 1, hop_samples):
        end_sample = start_sample + window_samples
        y_window = y[start_sample:end_sample]
        t = start_sample / sr

        timestamps.append(t)

        # 1. Energy: RMS energy, normalized
        rms = np.sqrt(np.mean(y_window ** 2))
        energy = min(rms / 0.1, 1.0)  # Normalize to 0-1
        energy_values.append(float(energy))

        # 2. Valence: based on the spectral centroid (high = bright = positive)
        centroid = librosa.feature.spectral_centroid(y=y_window, sr=sr)[0]
        valence = min(np.mean(centroid) / 4000, 1.0)
        valence_values.append(float(valence))

        # 3. Arousal: based on mean onset strength
        onset_env = librosa.onset.onset_strength(y=y_window, sr=sr)
        arousal = min(np.mean(onset_env) / 2.0, 1.0)
        arousal_values.append(float(arousal))

    # 4. Combined emotion curve (weighted average)
    smoothed: List[float] = []
    for i in range(len(timestamps)):
        # Weights: energy 40%, arousal 40%, valence 20%
        combined = (
            energy_values[i] * 0.4 +
            arousal_values[i] * 0.4 +
            valence_values[i] * 0.2
        )
        smoothed.append(combined)

    # Smoothing (3-point moving average)
    if len(smoothed) >= 3:
        smoothed = np.convolve(smoothed, np.ones(3)/3, mode='same').tolist()

    return EmotionCurve(
        timestamps=timestamps,
        energy_values=energy_values,
        valence_values=valence_values,
        arousal_values=arousal_values,
        smoothed_curve=smoothed,
    )


def aggregate_emotion_by_segments(
    emotion_curve: EmotionCurve,
    segments: List[Dict[str, Any]],
) -> List[SegmentEmotion]:
    """
    Aggregate the emotion curve by songformer segment structure.

    Args:
        emotion_curve: The raw emotion curve data
        segments: Segment list returned by songformer, in the form:
            [{"start": 0.0, "end": 30.5, "label": "intro"}, ...]

    Returns:
        List[SegmentEmotion]: Emotion data aggregated per segment
    """
    if not segments or not emotion_curve.timestamps:
        return []

    result: List[SegmentEmotion] = []
    timestamps = np.array(emotion_curve.timestamps)
    energy_values = np.array(emotion_curve.energy_values)
    valence_values = np.array(emotion_curve.valence_values)
    arousal_values = np.array(emotion_curve.arousal_values)
    smoothed_values = np.array(emotion_curve.smoothed_curve)

    for seg in segments:
        start = float(seg.get("start", 0))
        end = float(seg.get("end", 0))
        label = str(seg.get("label", "unknown"))

        # Find the indices of data points inside this segment
        mask = (timestamps >= start) & (timestamps < end)
        indices = np.where(mask)[0]

        if len(indices) == 0:
            # No data points fall inside this segment; use defaults
            result.append(SegmentEmotion(
                start=start,
                end=end,
                label=label,
                intensity=0.5,
                energy=0.5,
                valence=0.5,
                arousal=0.5,
                trend="stable",
            ))
            continue

        # Compute the averages for this segment
        seg_energy = float(np.mean(energy_values[indices]))
        seg_valence = float(np.mean(valence_values[indices]))
        seg_arousal = float(np.mean(arousal_values[indices]))
        seg_intensity = float(np.mean(smoothed_values[indices]))

        # Compute the emotion trend
        seg_smoothed = smoothed_values[indices]
        trend = _calculate_trend(seg_smoothed, seg_intensity)

        result.append(SegmentEmotion(
            start=start,
            end=end,
            label=label,
            intensity=round(seg_intensity, 3),
            energy=round(seg_energy, 3),
            valence=round(seg_valence, 3),
            arousal=round(seg_arousal, 3),
            trend=trend,
        ))

    return result


def _calculate_trend(values: np.ndarray, avg_intensity: float) -> str:
    """
    Compute the emotion trend.

    Args:
        values: Array of emotion values inside the segment
        avg_intensity: Average emotion intensity

    Returns:
        str: rising/falling/stable/peak
    """
    if len(values) < 3:
        return "stable"

    # Split the segment into a first half and a second half
    mid = len(values) // 2
    first_half_avg = float(np.mean(values[:mid]))
    second_half_avg = float(np.mean(values[mid:]))

    diff = second_half_avg - first_half_avg
    threshold = 0.05  # Absolute change of 0.05 on the 0-1 scale

    # Check for a peak (high average intensity with little change)
    if avg_intensity > 0.7 and abs(diff) < threshold:
        return "peak"

    if diff > threshold:
        return "rising"
    elif diff < -threshold:
        return "falling"
    else:
        return "stable"


def extract_segment_emotions(
    audio_path: str,
    segments: List[Dict[str, Any]],
) -> List[SegmentEmotion]:
    """
    One-stop extraction of emotion data aggregated per segment.

    Args:
        audio_path: Path to the audio file
        segments: Segment list returned by songformer

    Returns:
        List[SegmentEmotion]: Emotion data aggregated per segment
    """
    emotion_curve = extract_emotion_curve(audio_path)
    return aggregate_emotion_by_segments(emotion_curve, segments)


def calculate_beat_density_by_segments(
    beat_timestamps: List[float],
    segments: List[Dict[str, Any]],
    tempo: float = 120.0,
) -> List[BeatDensityInfo]:
    """
    Compute beat density per segment, used to guide shot-duration planning.

    Args:
        beat_timestamps: List of beat timestamps
        segments: Segment list returned by songformer, in the form:
            [{"start": 0.0, "end": 30.5, "label": "intro"}, ...]
        tempo: BPM (used as an extra signal for the density level)

    Returns:
        List[BeatDensityInfo]: Beat density information per segment
    """
    if not segments or not beat_timestamps:
        return []

    result: List[BeatDensityInfo] = []
    beat_array = np.array(beat_timestamps)

    for seg in segments:
        start = float(seg.get("start", 0))
        end = float(seg.get("end", 0))
        label = str(seg.get("label", "unknown"))

        # Find the beats inside this segment
        mask = (beat_array >= start) & (beat_array < end)
        segment_beats = beat_array[mask]
        beat_count = len(segment_beats)

        # Compute the average interval
        if beat_count >= 2:
            intervals = np.diff(segment_beats)
            avg_interval = float(np.mean(intervals))
        else:
            # Zero or one beat: estimate from the BPM; fall back to 0.5 s
            # (the interval at the default 120 BPM) if tempo is not positive
            avg_interval = 60.0 / tempo if tempo > 0 else 0.5

        # Decide the density level from the average interval and the BPM
        # A smaller interval means a higher density
        if avg_interval <= 0.3 or tempo >= 160:
            density_level = "very_dense"
            recommended_shot_duration = "2-4秒"
        elif avg_interval <= 0.45 or tempo >= 130:
            density_level = "dense"
            recommended_shot_duration = "3-5秒"
        elif avg_interval <= 0.6 or tempo >= 100:
            density_level = "normal"
            recommended_shot_duration = "4-6秒"
        else:
            density_level = "sparse"
            recommended_shot_duration = "6-10秒"

        result.append(BeatDensityInfo(
            segment_label=label,
            start=round(start, 2),
            end=round(end, 2),
            beat_count=beat_count,
            avg_interval=round(avg_interval, 3),
            density_level=density_level,
            recommended_shot_duration=recommended_shot_duration,
        ))

    return result


618 def enhance_climax_points(
619 climax_points: List[Dict[str, Any]],
620 segments: List[Dict[str, Any]],
621 music_duration: float,
622 ) -> List[EnhancedClimaxInfo]:
623 """
624 增强高潮点信息,添加铺垫/持续/缓冲时长指导
625
626 Args:
627 climax_points: 原始高潮点列表,格式为:
628 [{"time": 60.0, "intensity": "strong"}, ...]
629 segments: songformer 返回的段落列表
630 music_duration: 音乐总时长(秒)
631
632 Returns:
633 List[EnhancedClimaxInfo]: 增强后的高潮点信息
634 """
635 if not climax_points:
636 return []
637
638 result: List[EnhancedClimaxInfo] = []
639
640 # 按时间排序高潮点
641 sorted_climax = sorted(climax_points, key=lambda x: float(x.get("time", 0)))
642
643 for i, climax in enumerate(sorted_climax):
644 time = float(climax.get("time", 0))
645 intensity = str(climax.get("intensity", "strong"))
646
647 # 根据强度确定时长参数
648 if intensity == "strongest":
649 buildup_duration = 10.0 # 最强高潮:更长的铺垫
650 climax_duration = 20.0 # 更长的高潮持续
651 winddown_duration = 10.0 # 更长的缓冲
652 else:
653 buildup_duration = 5.0 # 普通高潮
654 climax_duration = 10.0
655 winddown_duration = 5.0
656
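657 # Compute the buildup start; it must not fall before 0 or before the end of the previous climax
658 buildup_start = max(0.0, time - buildup_duration)
659 buildup_duration = time - buildup_start  # keep the duration consistent when the start is clamped at 0
660 # If there is a previous climax point, make sure the two do not overlap
661 if i > 0:
662 prev_climax_time = float(sorted_climax[i - 1].get("time", 0))
663 prev_intensity = str(sorted_climax[i - 1].get("intensity", "strong"))
664 prev_winddown = 10.0 if prev_intensity == "strongest" else 5.0
665 prev_end = prev_climax_time + prev_winddown
666
667 if buildup_start < prev_end:
668 # Move the buildup start to avoid overlap, but never past the climax itself
669 buildup_start = min(prev_end, time)
670 buildup_duration = time - buildup_start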
671
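672 # Make sure the climax plus wind-down does not run past the end of the music
673 if time + climax_duration + winddown_duration > music_duration:
674 # Scale both proportionally to the time that is left; a climax at or
675 # past the end of the track leaves nothing, so both shrink to 0
676 remaining = max(0.0, music_duration - time)
677 ratio = remaining / (climax_duration + winddown_duration)
678 climax_duration = climax_duration * ratio
679 winddown_duration = winddown_duration * ratio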
680
681 result.append(EnhancedClimaxInfo(
682 time=round(time, 2),
683 intensity=intensity,
684 buildup_start=round(buildup_start, 2),
685 buildup_duration=round(buildup_duration, 2),
686 climax_duration=round(climax_duration, 2),
687 winddown_duration=round(winddown_duration, 2),
688 ))
689
690 return result
691
692
693 def format_beat_density_for_prompt(beat_density_list: List[BeatDensityInfo]) -> str:
694 """
695 将节拍密度信息格式化为提示词文本
696
697 Args:
698 beat_density_list: 节拍密度信息列表
699
700 Returns:
701 str: 格式化的文本
702 """
703 if not beat_density_list:
704 return "(无节拍密度数据)"
705
706 lines = []
707 for info in beat_density_list:
708 lines.append(
709 f"- [{info.segment_label}] {info.start:.1f}s-{info.end:.1f}s: "
710 f"节拍数={info.beat_count}, 平均间隔={info.avg_interval:.2f}s, "
711 f"密度={info.density_level}, 推荐分镜时长={info.recommended_shot_duration}"
712 )
713 return "\n".join(lines)
714
715
716 def format_enhanced_climax_for_prompt(enhanced_climax_list: List[EnhancedClimaxInfo]) -> str:
717 """
718 将增强高潮点信息格式化为提示词文本
719
720 Args:
721 enhanced_climax_list: 增强高潮点信息列表
722
723 Returns:
724 str: 格式化的文本
725 """
726 if not enhanced_climax_list:
727 return "(无高潮点数据)"
728
729 lines = []
730 for info in enhanced_climax_list:
731 lines.append(
732 f"- 高潮点 {info.time:.1f}s ({info.intensity}):\n"
733 f" · 铺垫阶段: {info.buildup_start:.1f}s - {info.time:.1f}s (约{info.buildup_duration:.1f}秒)\n"
734 f" · 高潮阶段: {info.time:.1f}s - {info.time + info.climax_duration:.1f}s (约{info.climax_duration:.1f}秒)\n"
735 f" · 缓冲阶段: {info.time + info.climax_duration:.1f}s - {info.time + info.climax_duration + info.winddown_duration:.1f}s (约{info.winddown_duration:.1f}秒)"
736 )
737 return "\n".join(lines)
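The windowing rules used by `enhance_climax_points` above can be illustrated with a small standalone sketch. It assumes the same duration table (10/20/10 seconds for `strongest`, 5/10/5 seconds otherwise). `ClimaxWindow` and `plan_climax_windows` are hypothetical names used only for this illustration, and clamping the buildup start so that it never passes the climax time is an assumption of the sketch, not a description of the pipeline.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ClimaxWindow:
    """Hypothetical container for one climax and its three phases."""
    time: float
    buildup_start: float
    climax_end: float
    winddown_end: float


def plan_climax_windows(
    points: List[Dict[str, Any]], music_duration: float
) -> List[ClimaxWindow]:
    """Sketch: derive buildup / climax / wind-down boundaries per climax point."""
    ordered = sorted(points, key=lambda p: float(p.get("time", 0)))
    windows: List[ClimaxWindow] = []

    for i, point in enumerate(ordered):
        t = float(point.get("time", 0))
        if point.get("intensity") == "strongest":
            buildup, climax, winddown = 10.0, 20.0, 10.0
        else:
            buildup, climax, winddown = 5.0, 10.0, 5.0

        # The buildup cannot start before the track does.
        buildup_start = max(0.0, t - buildup)

        # Nor before the previous climax has wound down (assumed clamp to t).
        if i > 0:
            prev = ordered[i - 1]
            prev_winddown = 10.0 if prev.get("intensity") == "strongest" else 5.0
            prev_end = float(prev.get("time", 0)) + prev_winddown
            buildup_start = min(max(buildup_start, prev_end), t)

        # Shrink climax and wind-down proportionally if the track ends early.
        remaining = max(0.0, music_duration - t)
        if remaining < climax + winddown:
            ratio = remaining / (climax + winddown)
            climax, winddown = climax * ratio, winddown * ratio

        windows.append(
            ClimaxWindow(t, buildup_start, t + climax, t + climax + winddown)
        )
    return windows


if __name__ == "__main__":
    demo = [
        {"time": 60.0, "intensity": "strong"},
        {"time": 68.0, "intensity": "strongest"},
    ]
    for window in plan_climax_windows(demo, music_duration=80.0):
        print(window)
```

In the demo, the second climax sits only 8 seconds after the first, so its buildup is pushed back to the end of the first wind-down, and its climax and wind-down are scaled to fit the 12 seconds left in the track.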
1 # -*- coding: utf-8 -*-
2 """
3 Abstract base class for music analyzers.
4 Defines the unified analyzer interface.
5 """
6
7 from abc import ABC, abstractmethod
8 from typing import Dict, Optional, Any, List, Set
9
10
11 # 字典定义:所有有效的字段值
12 VALID_GENRES: Set[str] = {
13 "流行",
14 "电子/舞曲",
15 "摇滚/金属",
16 "说唱",
17 "民谣/原声",
18 "国风",
19 "爵士/Soul",
20 "古典",
21 "轻音乐/Ambient",
22 "二次元/ACG",
23 "其它",
24 }
25
26 VALID_SUB_GENRES: Dict[str, Set[str]] = {
27 "流行": {"华语流行", "欧美流行", "日韩流行", "R&B", "抒情"},
28 "电子/舞曲": {"House", "Future Bass", "Dubstep", "Synthwave", "Trance", "Techno"},
29 "摇滚/金属": {"流行摇滚", "独立摇滚", "重金属", "朋克", "后摇"},
30 "说唱": {"Trap", "Old School", "Boombap", "Melodic Rap", "中文说唱"},
31 "民谣/原声": {"城市民谣", "校园民谣", "故事民谣", "乡村", "Indie Folk"},
32 "国风": {"古风", "戏腔", "新中式", "水墨风", "国潮"},
33 "爵士/Soul": {"传统爵士", "Smooth Jazz", "Fusion", "Neo-Soul", "Blues"},
34 "古典": {"管弦乐", "钢琴曲", "协奏曲", "室内乐", "歌剧"},
35 "轻音乐/Ambient": {"钢琴独奏", "Lo-fi", "冥想音乐", "氛围电子", "白噪音"},
36 "二次元/ACG": {"动画OST", "Vocaloid", "游戏音乐", "萌系", "燃系"},
37 "其它": {"世界音乐", "实验音乐", "儿歌", "戏曲", "网络热歌"},
38 }
39
40 VALID_LANGUAGES: Set[str] = {
41 "普通话",
42 "粤语",
43 "英语",
44 "韩语",
45 "闽南语",
46 "蒙语",
47 "俄语",
48 "藏语",
49 "其他",
50 }
51
52 LANGUAGE_MAPPING: Dict[str, str] = {
53 "国语": "普通话",
54 "中文": "普通话",
55 "汉语": "普通话",
56 "普通话": "普通话",
57 "广东话": "粤语",
58 "粤语": "粤语",
59 "英文": "英语",
60 "英语": "英语",
61 "韩文": "韩语",
62 "朝鲜语": "韩语",
63 "韩语": "韩语",
64 "闽南话": "闽南语",
65 "台语": "闽南语",
66 "闽南语": "闽南语",
67 "蒙语": "蒙语",
68 "蒙古语": "蒙语",
69 "俄文": "俄语",
70 "俄语": "俄语",
71 "藏文": "藏语",
72 "藏语": "藏语",
73 "其它": "其他",
74 "其他": "其他",
75 "地方语言": "其他",
76 "日语": "其他",
77 }
78
79 VALID_EMOTIONS: Set[str] = {
80 "喜庆",
81 "浪漫",
82 "雄壮",
83 "庄重",
84 "激情",
85 "快乐",
86 "励志",
87 "期待",
88 "甜蜜",
89 "感动",
90 "搞笑",
91 "祝福",
92 "温暖",
93 "宣泄",
94 "悲壮",
95 "愤怒",
96 "沉重",
97 "思念",
98 "紧张",
99 "恐怖",
100 "孤独",
101 "伤感",
102 "忧郁",
103 "蛊惑",
104 "恶搞",
105 "怀念",
106 "悬疑",
107 "佛系",
108 "舒缓",
109 "悠扬",
110 }
111
112 VALID_SCENES: Set[str] = {
113 "餐厅",
114 "汽车",
115 "跳舞",
116 "旅行",
117 "工作",
118 "校园",
119 "夜店",
120 "运动",
121 "休闲",
122 "live house",
123 "广场舞",
124 "抖音",
125 "婚礼",
126 "约会",
127 }
128
129 VALID_DOUYIN_TAGS: Set[str] = {
130 "草原",
131 "故乡",
132 "神曲",
133 "文艺",
134 "青春",
135 "治愈系",
136 "清新",
137 "奇幻",
138 }
139
140 VALID_MUSIC_STYLE_TAGS: Set[str] = {
141 "世界音乐",
142 "雷鬼",
143 "R&B/Soul",
144 "MC喊麦",
145 "另类音乐",
146 "民歌",
147 "戏曲",
148 "古风",
149 "古典音乐",
150 "HipHop",
151 "Rap",
152 "摇滚",
153 "DJ嗨曲",
154 "布鲁斯/蓝调",
155 "拉丁",
156 "舞曲",
157 "爵士",
158 "乡村",
159 "民谣",
160 "流行",
161 "轻音乐",
162 "国风",
163 "儿歌",
164 }
165
166 VALID_INSTRUMENT_TAGS: Set[str] = {
167 "二胡",
168 "竹笛",
169 "琵琶",
170 "音效",
171 "口琴",
172 "电子",
173 "木吉他",
174 "鼓组",
175 "弦乐",
176 "电吉他",
177 "古筝",
178 "钢琴",
179 }
180
181 VALID_AGES: Set[str] = {"少年", "青年", "中年", "老年", "全年龄段"}
182
183 VALID_RHYTHM_INTENSITIES: Set[str] = {"极慢", "慢", "中", "快", "极速"}
184
185 VALID_EMOTIONAL_INTENSITIES: Set[str] = {"平缓", "中等", "强烈"}
186
187 VALID_VOICE_TYPES: Set[str] = {"男声", "女声", "童声", "合唱", "无人声"}
188 VALID_PERFORMER_TYPES: Set[str] = {"男声", "女声", "童声", "合唱"}
189
190 # sub_genre 常见变体映射
191 SUB_GENRE_MAPPING: Dict[str, str] = {
192 "韩语流行": "日韩流行",
193 "韩国流行": "日韩流行",
194 "K-Pop": "日韩流行",
195 "K-pop": "日韩流行",
196 "Kpop": "日韩流行",
197 "韩流": "日韩流行",
198 "日语流行": "日韩流行",
199 "日本流行": "日韩流行",
200 "J-Pop": "日韩流行",
201 "J-pop": "日韩流行",
202 "Jpop": "日韩流行",
203 "中文流行": "华语流行",
204 "国语流行": "华语流行",
205 "中国流行": "华语流行",
206 "英语流行": "欧美流行",
207 "英文流行": "欧美流行",
208 "西方流行": "欧美流行",
209 "Pop": "欧美流行",
210 }
211
212
213 class AudioAnalyzer(ABC):
214 """音乐音频分析器抽象基类"""
215
216 @abstractmethod
217 def get_provider_name(self) -> str:
218 """获取提供商名称(如 qwen, doubao)"""
219 pass
220
221 @abstractmethod
222 def get_model_name(self) -> str:
223 """获取模型名称"""
224 pass
225
226 @abstractmethod
227 def analyze(
228 self,
229 metadata: Dict[str, Any],
230 music_url: str,
231 extract_lyrics: bool = False,
232 label_level: int = 0,
233 ) -> Optional[Dict[str, Any]]:
234 """
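235 Analyze a piece of music and return its tag results.
236
237 Args:
238 metadata: Music metadata dictionary
239 music_url: Music file URL (an audio URL or Base64-encoded audio)
240 extract_lyrics: Whether to transcribe lyrics
241 label_level: Tag level (0: first-level tags, 1: first- and second-level tags)
242
243 Returns:
244 A normalized result dictionary containing these fields:
245 - genre: Music style (first-level style, e.g. 流行, 摇滚)
246 - emotion: List of emotions
247 - emotional_intensity: Emotional intensity
248 - vocal_texture: Vocal texture
249 - vocal_description: Description of the vocal texture
250 - visual_concept: Visual concept
251 - language: Language of the song
252 - bpm: Beats per minute (optional)
253 - lyrics: List of lyrics (optional, only when extract_lyrics=True)
254 - _model: Name of the model used
255 - _token_info: Token usage information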
256 """
257 pass
258
259 def _parse_response(self, response_text: str) -> Optional[Dict[str, Any]]:
260 """
261 解析 LLM 返回的响应文本为 JSON
262
263 Args:
264 response_text: LLM 返回的原始文本
265
266 Returns:
267 解析后的字典,解析失败返回 None
268 """
269 import re
270 import json
271 import logging
272
273 logger = logging.getLogger(__name__)
274
275 if not response_text:
276 return None
277
278 # 打印原始响应用于调试
279 logger.info(f"[_parse_response] 原始响应文本:\n{response_text[:500]}...")
280
281 cleaned_text = response_text.strip()
282
283 # 移除 markdown 代码块标记
284 if cleaned_text.startswith("```json"):
285 cleaned_text = cleaned_text[7:]
286 elif cleaned_text.startswith("```"):
287 cleaned_text = cleaned_text[3:]
288
289 if cleaned_text.endswith("```"):
290 cleaned_text = cleaned_text[:-3]
291
292 cleaned_text = cleaned_text.strip()
293
294 # 提取 JSON 对象
295 try:
296 # 尝试直接解析
297 result = json.loads(cleaned_text)
298 if isinstance(result, dict):
299 logger.info(f"[_parse_response] 解析成功,字段: {list(result.keys())}")
300 elif isinstance(result, list):
301 logger.info(f"[_parse_response] 解析成功,列表长度: {len(result)}")
302 else:
303 logger.info(
304 f"[_parse_response] 解析成功,类型: {type(result).__name__}"
305 )
306 return result
307 except json.JSONDecodeError:
308 pass
309
310 # 尝试提取 {...} 中的内容
311 try:
312 match = re.search(r"\{.*\}", cleaned_text, re.DOTALL)
313 if match:
314 json_str = match.group()
315 result = json.loads(json_str)
316 if isinstance(result, dict):
317 logger.info(
318 f"[_parse_response] 正则提取解析成功,字段: {list(result.keys())}"
319 )
320 elif isinstance(result, list):
321 logger.info(
322 f"[_parse_response] 正则提取解析成功,列表长度: {len(result)}"
323 )
324 else:
325 logger.info(
326 "[_parse_response] 正则提取解析成功,类型: %s",
327 type(result).__name__,
328 )
329 return result
330 except (re.error, json.JSONDecodeError):
331 pass
332
333 # 尝试修复常见的 JSON 格式问题
334 try:
335 fixed_text = re.sub(r",(\s*})", r"\1", cleaned_text)
336 fixed_text = re.sub(r",(\s*])", r"\1", fixed_text)
337 result = json.loads(fixed_text)
338 if isinstance(result, dict):
339 logger.info(
340 f"[_parse_response] 修复后解析成功,字段: {list(result.keys())}"
341 )
342 elif isinstance(result, list):
343 logger.info(
344 f"[_parse_response] 修复后解析成功,列表长度: {len(result)}"
345 )
346 else:
347 logger.info(
348 "[_parse_response] 修复后解析成功,类型: %s",
349 type(result).__name__,
350 )
351 return result
352 except (re.error, json.JSONDecodeError):
353 pass
354
355 logger.warning("[_parse_response] all parsing strategies failed")
356 return None
357
358 def _normalize_result(
359 self,
360 raw_result: Dict[str, Any],
361 model_name: str,
362 token_info: Optional[Dict[str, int]] = None,
363 ) -> Dict[str, Any]:
364 """
365 标准化分析结果
366
367 Args:
368 raw_result: 原始解析结果
369 model_name: 使用的模型名称
370 token_info: Token 使用信息
371
372 Returns:
373 标准化后的结果字典
374 """
375 import logging
376
377 logger = logging.getLogger(__name__)
378
379 if not isinstance(raw_result, dict):
380 if (
381 isinstance(raw_result, list)
382 and raw_result
383 and isinstance(raw_result[0], dict)
384 ):
385 raw_result = raw_result[0]
386 else:
387 logger.warning(
388 f"[_normalize_result] 原始结果类型异常: {type(raw_result).__name__}"
389 )
390 return {"_model": model_name, "_raw": raw_result}
391
392 logger.info(f"[_normalize_result] 原始结果字段: {list(raw_result.keys())}")
393 logger.info(f"[_normalize_result] genre: {raw_result.get('genre')}")
394 logger.info(f"[_normalize_result] emotion: {raw_result.get('emotion')}")
395 logger.info(f"[_normalize_result] scene: {raw_result.get('scene')}")
396 logger.info(f"[_normalize_result] token_info 参数: {token_info}")
397
398 def _extract_style(raw_style) -> Optional[Dict[str, str]]:
399 """提取音乐风格为标准格式"""
400 if isinstance(raw_style, dict):
401 return {"zh": raw_style.get("zh", ""), "en": raw_style.get("en", "")}
402 elif isinstance(raw_style, str):
403 # 字符串格式,直接使用作为中文名,英文名留空
404 return {"zh": raw_style, "en": ""}
405 return None
406
407 def _extract_list_field(raw_value) -> list:
408 """提取列表字段"""
409 if isinstance(raw_value, list):
410 return [v for v in raw_value if v]
411 elif isinstance(raw_value, str):
412 import re
413
414 return [
415 v.strip()
416 for v in re.split(r"[,,、/|]+", raw_value)
417 if v and v.strip()
418 ]
419 return []
420
421 def _extract_single_field(raw_value) -> str:
422 """提取单值字段"""
423 if raw_value and isinstance(raw_value, str):
424 return raw_value
425 return ""
426
427 def _validate_and_map_sub_genre(sub_genre: str, genre: str) -> str:
428 """验证并映射 sub_genre 到有效值"""
429 if not sub_genre:
430 return ""
431
432 sub_genre = sub_genre.strip()
433
434 if sub_genre in SUB_GENRE_MAPPING:
435 mapped = SUB_GENRE_MAPPING[sub_genre]
436 logger.info(
437 f"[_validate_and_map_sub_genre] 映射 '{sub_genre}' -> '{mapped}'"
438 )
439 return mapped
440
441 if genre in VALID_SUB_GENRES:
442 if sub_genre in VALID_SUB_GENRES[genre]:
443 return sub_genre
444
445 for valid_subs in VALID_SUB_GENRES.values():
446 if sub_genre in valid_subs:
447 return sub_genre
448
449 logger.warning(
450 f"[_validate_and_map_sub_genre] 无法映射 sub_genre: '{sub_genre}' (genre: '{genre}')"
451 )
452 return sub_genre
453
454 def _validate_list_field(
455 values: List[str], valid_set: Set[str], field_name: str
456 ) -> List[str]:
457 """严格验证列表字段中的值:仅保留字典内标签"""
458 result = []
459 for v in values:
460 if v in valid_set:
461 result.append(v)
462 else:
463 logger.warning(
464 f"[_validate_list_field] {field_name} 值 '{v}' 不在字典中,已过滤"
465 )
466 return result
467
468 def _validate_language(raw_value: Any) -> str:
469 language = _extract_single_field(raw_value).strip()
470 if not language:
471 return ""
472 mapped = LANGUAGE_MAPPING.get(language, language)
473 if mapped in VALID_LANGUAGES:
474 return mapped
475 logger.warning(
476 f"[_normalize_result] language '{language}' 不在字典中,已归并为空"
477 )
478 return ""
479
480 result = {
481 "genre": "",
482 "sub_genre": "",
483 "emotion": [],
484 "voice_type": "",
485 "vocal_texture": "",
486 "vocal_description": "",
487 "visual_concept": "",
488 "language": "",
489 "scene": [],
490 "age": "",
491 "is_sinking": None,
492 "song_description": "",
493 "performer_type": "",
494 "music_style_tags": [],
495 "douyin_tags": [],
496 "instrument_tags": [],
497 }
498
499 # 音乐风格(一级风格和二级风格)
500 # 优先使用新格式 genre/sub_genre,兼容旧格式 music_style
501 raw_genre = raw_result.get("genre", "")
502 raw_sub_genre = raw_result.get("sub_genre", "")
503 raw_music_style = raw_result.get("music_style", [])
504
505 # 优先从 genre 字段获取一级风格
506 if isinstance(raw_genre, str) and raw_genre.strip():
507 result["genre"] = raw_genre.strip()
508 elif isinstance(raw_genre, dict):
509 result["genre"] = raw_genre.get("zh", "") or raw_genre.get("en", "")
510 # 兼容旧格式:从 music_style 数组提取
511 elif (
512 raw_music_style
513 and isinstance(raw_music_style, list)
514 and len(raw_music_style) > 0
515 ):
516 first_style = raw_music_style[0]
517 if isinstance(first_style, dict):
518 result["genre"] = first_style.get("zh", "") or first_style.get("en", "")
519 elif isinstance(first_style, str):
520 result["genre"] = first_style.strip()
521
522 # 优先从 sub_genre 字段获取二级风格
523 if isinstance(raw_sub_genre, str) and raw_sub_genre.strip():
524 result["sub_genre"] = raw_sub_genre.strip()
525 elif isinstance(raw_sub_genre, dict):
526 result["sub_genre"] = raw_sub_genre.get("zh", "") or raw_sub_genre.get(
527 "en", ""
528 )
529 # 兼容旧格式:从 music_style 数组第二个元素提取
530 elif (
531 raw_music_style
532 and isinstance(raw_music_style, list)
533 and len(raw_music_style) > 1
534 ):
535 second_style = raw_music_style[1]
536 if isinstance(second_style, dict):
537 result["sub_genre"] = second_style.get("zh", "") or second_style.get(
538 "en", ""
539 )
540 elif isinstance(second_style, str):
541 result["sub_genre"] = second_style.strip()
542
543 result["sub_genre"] = _validate_and_map_sub_genre(
544 result["sub_genre"], result["genre"]
545 )
546
547 # Emotion
548 raw_emotion = raw_result.get("emotion", [])
549 # Pass strings through unchanged: _extract_list_field splits delimited
550 # values such as "快乐,温暖", which wrapping them in a list would prevent
551 result["emotion"] = _validate_list_field(
552 _extract_list_field(raw_emotion), VALID_EMOTIONS, "emotion"
553 )
554
555 # 人声类型
556 raw_voice_type = raw_result.get("voice_type", "")
557 if raw_voice_type and isinstance(raw_voice_type, str):
558 voice_type = raw_voice_type.strip()
559 if voice_type in VALID_VOICE_TYPES:
560 result["voice_type"] = voice_type
561 else:
562 logger.warning(
563 f"[_normalize_result] voice_type '{voice_type}' 不在有效值中,保留原值"
564 )
565 result["voice_type"] = voice_type
566 else:
567 result["voice_type"] = ""
568
569 # 人声质感 (LLM返回的是vocal_type)
570 result["vocal_texture"] = _extract_single_field(
571 raw_result.get("vocal_type", "")
572 )
573
574 # 人声质感描述
575 result["vocal_description"] = raw_result.get("vocal_description", "")
576
577 # 聚音演唱者类型(优先 performer_type,回退 vocal_type)
578 raw_performer_type = raw_result.get("performer_type", raw_result.get("vocal_type", ""))
579 if isinstance(raw_performer_type, str):
580 performer_type = raw_performer_type.strip()
581 if performer_type in VALID_PERFORMER_TYPES:
582 result["performer_type"] = performer_type
583 elif performer_type in VALID_VOICE_TYPES:
584 result["performer_type"] = performer_type
585
586 # 聚音标签:音乐风格/网络抖音/配器
587 result["music_style_tags"] = _extract_list_field(
588 raw_result.get("music_style_tags", raw_result.get("music_style", []))
589 )
590 result["douyin_tags"] = _extract_list_field(
591 raw_result.get("douyin_tags", raw_result.get("network_douyin_tags", []))
592 )
593 result["instrument_tags"] = _extract_list_field(
594 raw_result.get("instrument_tags", raw_result.get("instruments", []))
595 )
596 result["music_style_tags"] = _validate_list_field(
597 result["music_style_tags"], VALID_MUSIC_STYLE_TAGS, "music_style_tags"
598 )
599 result["douyin_tags"] = _validate_list_field(
600 result["douyin_tags"], VALID_DOUYIN_TAGS, "douyin_tags"
601 )
602 result["instrument_tags"] = _validate_list_field(
603 result["instrument_tags"], VALID_INSTRUMENT_TAGS, "instrument_tags"
604 )
605
606 # 视觉概念
607 result["visual_concept"] = raw_result.get("visual_concept", "")
608
609 # 语种
610 result["language"] = _validate_language(raw_result.get("language", ""))
611
612 # Scene (multiple values allowed)
613 raw_scene = raw_result.get("scene", [])
614 # _extract_list_field also splits a delimited string such as "餐厅,约会"
615 scene_list = [
616 s.strip() for s in _extract_list_field(raw_scene) if isinstance(s, str) and s.strip()
617 ]
618 result["scene"] = _validate_list_field(scene_list, VALID_SCENES, "scene")
619
620 # 适合听众年龄段
621 raw_age = raw_result.get("age", "")
622 if raw_age and isinstance(raw_age, str):
623 result["age"] = raw_age.strip()
624
625 # 是否下沉
626 raw_is_sinking = raw_result.get("is_sinking")
627 if isinstance(raw_is_sinking, bool):
628 result["is_sinking"] = raw_is_sinking
629 elif isinstance(raw_is_sinking, str):
630 is_sinking_lower = raw_is_sinking.strip().lower()
631 if is_sinking_lower in ("是", "true", "1", "yes"):
632 result["is_sinking"] = True
633 elif is_sinking_lower in ("否", "false", "0", "no"):
634 result["is_sinking"] = False
635
636 # 歌曲描述
637 raw_song_desc = raw_result.get("song_description", "")
638 if raw_song_desc and isinstance(raw_song_desc, str):
639 result["song_description"] = raw_song_desc.strip()
640
641 # 情绪强度
642 raw_emotional_intensity = raw_result.get("emotional_intensity", "")
643 if raw_emotional_intensity and isinstance(raw_emotional_intensity, str):
644 result["emotional_intensity"] = raw_emotional_intensity.strip()
645
646 # 节奏强度
647 raw_rhythm_intensity = raw_result.get("rhythm_intensity", "")
648 if raw_rhythm_intensity and isinstance(raw_rhythm_intensity, str):
649 result["rhythm_intensity"] = raw_rhythm_intensity.strip()
650
651 # BPM 不从 LLM 结果中提取,统一由本地 bpm_analyzer_tools 提供
652
653 # 歌词(可选)
654 if "lyrics" in raw_result:
655 result["lyrics"] = raw_result["lyrics"]
656
657 # 添加模型信息
658 result["_model"] = model_name
659 if token_info:
660 result["_token_info"] = token_info
661 if "_token_info_parts" in raw_result and isinstance(
662 raw_result["_token_info_parts"], dict
663 ):
664 result["_token_info_parts"] = raw_result["_token_info_parts"]
665 if "_timing" in raw_result and isinstance(raw_result["_timing"], dict):
666 result["_timing"] = raw_result["_timing"]
667
668 return result
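The fallback order used by `_parse_response` above (strip Markdown fences, parse directly, extract the outermost `{...}`, then drop trailing commas) can be sketched as a standalone helper. This is a simplified illustration under stated assumptions, not the class method itself: `parse_llm_json` is a hypothetical name, logging is omitted, and the candidates are tried in a single loop.

```python
import json
import re
from typing import Any, Optional


def parse_llm_json(text: str) -> Optional[Any]:
    """Sketch: try progressively more forgiving ways to read JSON from LLM text."""
    if not text:
        return None

    # 1. Remove a leading ``` or ```json fence and a trailing ``` fence.
    cleaned = text.strip()
    cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
    cleaned = re.sub(r"\s*```$", "", cleaned).strip()

    # 2. Candidate strings, from strictest to most forgiving.
    candidates = [cleaned]

    # Outermost {...} span, in case the model wrapped the JSON in prose.
    match = re.search(r"\{.*\}", cleaned, re.DOTALL)
    if match:
        candidates.append(match.group())

    # Trailing commas before a closing brace or bracket are a common defect.
    candidates.append(re.sub(r",(\s*[}\]])", r"\1", cleaned))

    # 3. Return the first candidate that parses.
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue
    return None


if __name__ == "__main__":
    print(parse_llm_json('```json\n{"genre": "流行"}\n```'))
    print(parse_llm_json('Result: {"bpm": 120} done'))
    print(parse_llm_json('{"emotion": ["快乐", "温暖",], }'))
    print(parse_llm_json("no JSON here"))
```

As in the original method, the trailing-comma repair runs on the whole cleaned text, so input that needs both prose stripping and comma repair still fails and yields `None`.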
1 #!/usr/bin/env python3
2 """
3 Realtime BPM Analyzer - Python test program
4
5 A Python implementation based on realtime-bpm-analyzer (https://github.com/dlepaux/realtime-bpm-analyzer),
6 used to quickly test the BPM of audio files.
7
8 Features:
9 1. Fast BPM detection
10 2. Real-time feature extraction
11 3. Multi-algorithm fusion
12 4. Detailed result export
13
14 Usage:
15 python bpm_analyzer_test.py --file music.mp3
16 python bpm_analyzer_test.py --file music.mp3 --output result.json
17 python bpm_analyzer_test.py --file music.mp3 --verbose
18 python bpm_analyzer_test.py --dir /path/to/music/folder
19 """
20
21 import os
22 import sys
23 import json
24 import logging
25 import argparse
26 from pathlib import Path
27 from typing import Dict, List, Any, Optional, Tuple
28 from datetime import datetime
29 import numpy as np
30
31 # 导入音频处理库
32 try:
33 import librosa
34 import librosa.beat
35 import librosa.feature
36 import librosa.onset
37 except ImportError:
38 print("❌ librosa 库未安装,请运行: pip install librosa")
39 sys.exit(1)
40
41 from scipy.signal import find_peaks, correlate
42
43 # 配置日志
44 logging.basicConfig(
45 level=logging.INFO,
46 format='%(asctime)s - %(levelname)s - %(message)s'
47 )
48 logger = logging.getLogger(__name__)
49
50
51 class RealtimeBPMAnalyzerTest:
52 """Realtime BPM Analyzer - Python 版本"""
53
54 # BPM 范围(参考 realtime-bpm-analyzer)
55 BPM_MIN = 30.0
56 BPM_MAX = 200.0
57
58 # 置信度阈值
59 CONFIDENCE_THRESHOLD = 0.5
60
61 def __init__(self, verbose: bool = False):
62 """
63 初始化分析器
64
65 Args:
66 verbose: 是否显示详细信息
67 """
68 self.verbose = verbose
69 self.sr = 22050 # 采样率
70 self.hop_length = 512
71
72 if verbose:
73 logger.setLevel(logging.DEBUG)
74
75 logger.info("✓ Realtime BPM Analyzer Test 已初始化")
76
77 def print_header(self, title: str, width: int = 80):
78 """打印标题"""
79 print("\n" + "=" * width)
80 print(f" {title}")
81 print("=" * width)
82
83 def analyze_file(self, file_path: str) -> Dict[str, Any]:
84 """
85 分析单个音频文件
86
87 Args:
88 file_path: 音频文件路径
89
90 Returns:
91 分析结果字典
92 """
93 self.print_header("🎵 Realtime BPM Analyzer - 测试程序")
94
95 # 验证文件
96 if not os.path.exists(file_path):
97 logger.error(f"❌ 文件不存在: {file_path}")
98 return {'success': False, 'error': '文件不存在'}
99
100 file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
101 logger.info(f"📄 音频文件: {Path(file_path).name}")
102 logger.info(f"📊 文件大小: {file_size_mb:.2f} MB")
103 logger.info(f"📁 文件路径: {Path(file_path).absolute()}")
104
105 self.print_header("📊 分析过程", 80)
106
107 try:
108 # 加载音频
109 logger.info("🔄 加载音频文件...")
110 y, sr = librosa.load(file_path, sr=self.sr, mono=True)
111 duration = len(y) / sr
112 logger.info(f"✓ 音频加载成功,时长: {duration:.2f} 秒")
113
114 # 执行快速分析
115 logger.info("📈 快速 BPM 检测...")
116 fast_result = self._fast_bpm_detection(y, sr)
117
118 # 执行详细分析
119 logger.info("📊 详细 BPM 分析...")
120 detailed_result = self._detailed_bpm_analysis(y, sr)
121
122 # 融合结果
123 logger.info("🔀 融合分析结果...")
124 final_result = self._fuse_results(fast_result, detailed_result, y=y)
125
126 result = {
127 'success': True,
128 'file_path': str(Path(file_path).absolute()),
129 'file_name': Path(file_path).name,
130 'file_size_mb': round(file_size_mb, 2),
131 'duration_seconds': round(duration, 2),
132 'sample_rate': sr,
133 'timestamp': datetime.now().isoformat(),
134 'fast_detection': fast_result,
135 'detailed_analysis': detailed_result,
136 'final_result': final_result
137 }
138
139 self.print_header("📈 分析结果", 80)
140 self._display_results(result)
141
142 return result
143
144 except Exception as e:
145 logger.error(f"❌ 分析失败: {str(e)}")
146 if self.verbose:
147 import traceback
148 traceback.print_exc()
149 return {'success': False, 'error': str(e)}
150
151 def analyze_directory(self, dir_path: str) -> List[Dict[str, Any]]:
152 """
153 分析文件夹中的所有音频文件
154
155 Args:
156 dir_path: 文件夹路径
157
158 Returns:
159 分析结果列表
160 """
161 self.print_header("🎵 Realtime BPM Analyzer - 批量分析", 80)
162
163 if not os.path.isdir(dir_path):
164 logger.error(f"❌ 文件夹不存在: {dir_path}")
165 return []
166
167 # 查找所有音频文件
168 audio_extensions = ('.mp3', '.wav', '.flac', '.m4a', '.aac', '.ogg')
169 audio_files = []
170
171 for root, dirs, files in os.walk(dir_path):
172 for file in files:
173 if file.lower().endswith(audio_extensions):
174 audio_files.append(os.path.join(root, file))
175
176 logger.info(f"📂 找到 {len(audio_files)} 个音频文件")
177
178 results = []
179 for i, file_path in enumerate(audio_files, 1):
180 logger.info(f"\n[{i}/{len(audio_files)}] 正在分析...")
181 result = self.analyze_file(file_path)
182 results.append(result)
183
184 return results
185
186 def analyze_bpm(
187 self,
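188 file_path: Optional[str] = None,
189 y: Optional[np.ndarray] = None,
190 sr: Optional[int] = None,
191 ) -> Dict[str, Any]:
192 """
193 Unified BPM analysis entry point (for use by other modules).
194
195 Two calling styles are supported:
196 1. Pass file_path; the audio is loaded internally at sr=22050
197 2. Pass an already-loaded y and sr (avoids reloading); sr should match self.sr (22050), because _detect_beat_level_errors computes onsets with self.sr
198
199 Returns:
200 {
201 'bpm': float, # final BPM (after fusion and correction)
202 'original_bpm': float, # raw BPM from fast detection
203 'confidence': float,
204 'beat_times': list, # list of beat timestamps
205 }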
206 """
207 try:
208 if y is None and file_path is not None:
209 if not os.path.exists(file_path):
210 return {'bpm': 120.0, 'original_bpm': 120.0,
211 'confidence': 0.0, 'beat_times': []}
212 y, sr = librosa.load(file_path, sr=self.sr, mono=True)
213 elif y is None:
214 return {'bpm': 120.0, 'original_bpm': 120.0,
215 'confidence': 0.0, 'beat_times': []}
216
217 # 快速检测
218 fast_result = self._fast_bpm_detection(y, sr)
219
220 # 详细分析
221 detailed_result = self._detailed_bpm_analysis(y, sr)
222
223 # 融合
224 final_result = self._fuse_results(fast_result, detailed_result, y=y)
225
226 final_bpm = final_result.get('bpm', 120.0)
227 original_bpm = fast_result.get('original_bpm', final_bpm)
228
229 # Get beat_times: _fast_bpm_detection does not return them, so beat_track is run again here
230 _, beat_frames = librosa.beat.beat_track(
231 y=y, sr=sr, hop_length=self.hop_length
232 )
233 if isinstance(beat_frames, np.ndarray) and beat_frames.size > 0:
234 beat_times = librosa.frames_to_time(
235 beat_frames, sr=sr, hop_length=self.hop_length
236 ).tolist()
237 else:
238 beat_times = []
239
240 # 如果 BPM 被减半了,节拍时间点也每隔一个取一个
241 if final_bpm < original_bpm * 0.75:
242 beat_times = beat_times[::2]
243
244 return {
245 'bpm': final_bpm,
246 'original_bpm': original_bpm,
247 'confidence': final_result.get('confidence', 0.0),
248 'beat_times': beat_times,
249 }
250 except Exception as e:
251 logger.warning(f"analyze_bpm 失败: {e}")
252 return {'bpm': 120.0, 'original_bpm': 120.0,
253 'confidence': 0.0, 'beat_times': []}
254
255 def _fast_bpm_detection(self, y: np.ndarray, sr: int) -> Dict[str, Any]:
256 """Fast BPM detection (via librosa.beat.beat_track) plus smart beat-level correction."""
257 try:
258 # 获取 BPM 和节拍时间
259 tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr, hop_length=self.hop_length)
260
261 # 处理 tempo 可能是 ndarray 的情况
262 if isinstance(tempo, np.ndarray):
263 bpm = float(tempo[0]) if tempo.size > 0 else 120.0
264 else:
265 bpm = float(tempo)
266
267 # 处理 beat_frames 可能是 ndarray 的情况
268 if isinstance(beat_frames, np.ndarray) and beat_frames.size > 0:
269 beat_times = librosa.frames_to_time(beat_frames, sr=sr, hop_length=self.hop_length)
270 beat_times = beat_times.tolist() if isinstance(beat_times, np.ndarray) else list(beat_times)
271 else:
272 beat_times = []
273
274 # 智能节拍层级检测和纠正(传入音频数据用于onset分析)
275 corrected_bpm, correction_reason = self._detect_beat_level_errors(beat_times, bpm, y)
276
277 return {
278 'bpm': round(corrected_bpm, 1),
279 'original_bpm': round(bpm, 1),
280 'confidence': 0.85,
281 'method': 'librosa.beat.beat_track()',
282 'beat_count': len(beat_times),
283 'beat_level_correction': correction_reason if correction_reason != 'beat_level_ok' else None,
284 'duration_ms': 100
285 }
286 except Exception as e:
287 logger.warning(f"⚠️ Fast detection failed: {str(e)}")
288 return {
289 'bpm': 0,
290 'confidence': 0,
291 'method': 'librosa.beat.beat_track()',
292 'error': str(e)
293 }
294
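295 def _detect_beat_level_errors(self, beat_times: list, bpm: float, y: Optional[np.ndarray] = None) -> Tuple[float, str]:
296 """
297 Detect and correct beat-level errors (for example, tracking 8th notes instead of quarter notes).
298
299 Improved version that combines several features:
300 1. Alternating strength pattern (ratio)
301 2. Raw BPM range (values in 100-150 are more likely to need halving)
302 3. Spectral centroid (slow songs usually have a lower centroid)
303 4. Onset alignment score comparison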
304 """
305 if not beat_times or len(beat_times) < 2:
306 return bpm, "insufficient_beats"
307
308 beat_intervals = np.diff(beat_times)
309 mean_interval = np.mean(beat_intervals)
310 std_interval = np.std(beat_intervals)
311 coeff_variation = std_interval / mean_interval if mean_interval > 0 else 1.0
312
313 beat_count = len(beat_times)
314
315 if self.verbose:
316 logger.debug(f"Beat level analysis: {beat_count} beats, CV={coeff_variation:.3f}, BPM={bpm:.1f}")
317
318 # 条件1: 间隔非常规则 + BPM > 100 + beat count > 20 (降低阈值以支持短片段)
319 if not (coeff_variation < 0.15 and bpm > 100 and beat_count > 20):
320 return bpm, "beat_level_ok"
321
322 # 如果没有音频数据,使用保守策略
323 if y is None:
324 return bpm, "beat_level_ok"
325
326 halved_bpm = bpm / 2
327 if not (40 < halved_bpm < 160):
328 return bpm, "beat_level_ok"
329
330 # 计算onset strength
331 onset_env = librosa.onset.onset_strength(y=y, sr=self.sr, hop_length=self.hop_length)
332
333 # 获取每个beat位置的onset强度
334 beat_frames = librosa.time_to_frames(beat_times, sr=self.sr, hop_length=self.hop_length)
335 beat_strengths = []
336 window = 3
337
338 for frame in beat_frames:
339 if frame < len(onset_env):
340 start = max(0, frame - window)
341 end = min(len(onset_env), frame + window + 1)
342 beat_strengths.append(np.max(onset_env[start:end]))
343
344 if len(beat_strengths) < 10:
345 return bpm, "beat_level_ok"
346
347 beat_strengths = np.array(beat_strengths)
348
349 # 检测交替强度模式
350 odd_beats = beat_strengths[::2]
351 even_beats = beat_strengths[1::2]
352 mean_odd = np.mean(odd_beats)
353 mean_even = np.mean(even_beats)
354 strength_ratio = mean_odd / mean_even if mean_even > 0 else 1.0
355
356 # 计算谱质心 (spectral centroid) - 用于区分快歌和慢歌
357 spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=self.sr, hop_length=self.hop_length)
358 mean_centroid = np.mean(spectral_centroid)
359
360 if self.verbose:
361 logger.debug(f"Beat strength ratio={strength_ratio:.3f}, spectral_centroid={mean_centroid:.1f}")
362
363 # 综合判断逻辑
364 should_halve = False
365 reason = ""
366
367 # 规则1: 非常明显的交替模式 (ratio > 1.8 或 < 0.55)
368 if strength_ratio > 1.8 or strength_ratio < 0.55:
369 should_halve = True
370 reason = f"strong_alternating_pattern (ratio={strength_ratio:.2f})"
371
372 # 规则1b: BPM > 150 + 中等交替模式 → 减半
373 # 如"春娇与志明"(172.3 BPM, ratio=1.406, ref=85)
374 # Home - Headhunterz (152 BPM, ratio=1.098) 不会触发
375 elif bpm > 150 and (strength_ratio > 1.25 or strength_ratio < 0.8):
376 should_halve = True
377 reason = f"very_high_bpm_with_alternating (bpm={bpm:.1f}, ratio={strength_ratio:.2f})"
378
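379 # Rule 2: BPM in the 125-150 range with a strong alternating pattern (ratio > 1.25)
380 # High onset density (>= 3.0/s) plus high spectral centroid (>= 2200) indicates a genuinely fast song; leave it unchanged
381 # e.g. "爱在西元前" (129.2 BPM, centroid=2527, onset_density=3.8, ratio=1.29)
382 # Otherwise correct by centroid: centroid < 2200 uses bpm*2/3, centroid >= 2200 uses bpm/2
383 # e.g. "该死的爱情" (129.2 BPM, ratio=1.668, centroid=1986, ref=84) → 2/3 = 86.1
384 # e.g. "你要的全拿走" (136.0 BPM, ratio=1.485, centroid=2678, ref=76) → /2 = 68.0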
385 elif 125 <= bpm <= 150 and strength_ratio > 1.25:
386 onset_frames = librosa.onset.onset_detect(onset_envelope=onset_env, sr=self.sr, hop_length=self.hop_length)
387 duration = len(y) / self.sr
388 onset_density = len(onset_frames) / duration if duration > 0 else 0
389 if onset_density >= 3.0 and mean_centroid >= 2200:
                if self.verbose:
                    logger.debug(
                        f"Rule 2 skipped: high onset density ({onset_density:.1f}/s) + "
                        f"high spectral centroid ({mean_centroid:.0f}); treated as a fast song"
                    )
            else:
                # Pick the correction strategy from the spectral centroid:
                #   Low centroid (< 2200): slow song with a dark timbre; librosa locks onto
                #   3/2x the tempo, so correct with *2/3.
                #     e.g. "该死的爱情" (129.2, centroid=1986, ratio=1.67) -> 86.1 (ref=84)
                #   High centroid (>= 2200) + low onset density (< 3.0): brightly produced slow
                #   song; librosa locks onto 2x the tempo, so correct with /2.
                #     e.g. "你要的全拿走" (136.0, centroid=2678, density=2.17, ratio=1.49) -> 68.0 (ref=76)
                if mean_centroid >= 2200:
                    # Bright but rhythmically sparse -> simply halve.
                    should_halve = True
                    reason = f"rule2_bright_slow (bpm={bpm:.1f}, ratio={strength_ratio:.2f}, centroid={mean_centroid:.0f}, density={onset_density:.1f})"
                else:
                    # Dark timbre -> correct with 2/3.
                    two_thirds_bpm = round(bpm * 2 / 3, 1)
                    should_halve = False
                    logger.info(
                        f"🔧 Beat-level correction (2/3): {bpm:.1f} BPM → {two_thirds_bpm:.1f} BPM "
                        f"(ratio={strength_ratio:.2f}, centroid={mean_centroid:.0f})"
                    )
                    return two_thirds_bpm, f"rule2_two_thirds (bpm={bpm:.1f}, result={two_thirds_bpm:.1f}, ratio={strength_ratio:.2f})"

        elif 125 <= bpm <= 150 and strength_ratio < 0.8 and mean_centroid < 2200:
            should_halve = True
            reason = f"mid_bpm_low_ratio_low_centroid (bpm={bpm:.1f}, ratio={strength_ratio:.2f}, centroid={mean_centroid:.0f})"

        # Rule 2b: BPM > 130 with a low spectral centroid (< 1800) means slow-song
        # characteristics but a high detected BPM.
        # Catches songs like "嚣张": BPM=136 but centroid=1653.
        elif bpm > 130 and mean_centroid < 1800:
            should_halve = True
            reason = f"high_bpm_low_centroid (bpm={bpm:.1f}, centroid={mean_centroid:.0f})"

        # Rule 3: BPM in the 115-125 range needs stricter conditions.
        elif 115 <= bpm < 125:
            # Rule 3a: a very strong alternating pattern (ratio > 1.5 or < 0.65) is halved
            # regardless of the centroid. This catches songs like "想你的夜" that alternate
            # strongly but have a fairly high centroid.
            if strength_ratio > 1.5 or strength_ratio < 0.65:
                should_halve = True
                reason = f"strong_alternating_in_mid_bpm (bpm={bpm:.1f}, ratio={strength_ratio:.2f})"
            # Rule 3b: moderate alternating pattern + low spectral centroid (slow-song trait).
            elif mean_centroid < 2000 and (strength_ratio > 1.4 or strength_ratio < 0.7):
                should_halve = True
                reason = f"slow_song_detected (centroid={mean_centroid:.0f}, ratio={strength_ratio:.2f})"
            # Otherwise keep the value (likely a genuine mid-tempo song such as "有什么奇怪" or "中巴车").

        # Rule 3c: BPM in the 100-115 range (possibly a slow song detected at 2x,
        # e.g. "嘉禾望岗": 56 BPM -> 112 BPM). Decide with onset alignment.
        elif 100 <= bpm < 115:
            score_detected = self._compute_onset_alignment_score(onset_env, bpm)
            score_halved = self._compute_onset_alignment_score(onset_env, halved_bpm)

            if score_detected > 0 and score_halved > 0:
                alignment_ratio = score_halved / score_detected
                if self.verbose:
                    logger.debug(f"Onset alignment (100-115 BPM): detected={score_detected:.3f}, halved={score_halved:.3f}, ratio={alignment_ratio:.3f}")

                # If the halved BPM aligns better (ratio > 1.0), the true BPM is half.
                # The alternating pattern is checked as supporting evidence.
                if alignment_ratio > 1.0 and (strength_ratio > 1.2 or strength_ratio < 0.83):
                    should_halve = True
                    reason = f"slow_song_100_115_range (alignment_ratio={alignment_ratio:.3f}, strength_ratio={strength_ratio:.2f})"
                # Even without a clear alternating pattern, halve when alignment is clearly better.
                elif alignment_ratio > 1.08:
                    should_halve = True
                    reason = f"onset_alignment_strongly_favors_half (ratio={alignment_ratio:.3f})"

        # Rule 4: compare BPM against BPM/2 with onset alignment (only for BPM > 130).
        # If BPM/2 aligns clearly better, the detector locked onto the half-beat.
        # Restricted to BPM > 130 so mid-tempo songs such as "中巴车" (117.5 BPM) are not affected.
        if not should_halve and bpm > 130:
            score_detected = self._compute_onset_alignment_score(onset_env, bpm)
            score_halved = self._compute_onset_alignment_score(onset_env, halved_bpm)

            if score_detected > 0 and score_halved > 0:
                alignment_ratio = score_halved / score_detected
                if self.verbose:
                    logger.debug(f"Onset alignment: detected={score_detected:.3f}, halved={score_halved:.3f}, ratio={alignment_ratio:.3f}")

                # A high centroid (>= 2000) suggests fast/electronic music, so a higher
                # alignment ratio is required before halving. This avoids mis-correcting
                # tracks like "Home - Headhunterz" (152 BPM, centroid=2290, ratio=1.102).
                ratio_threshold = 1.15 if mean_centroid >= 2000 else 1.04
                if alignment_ratio > ratio_threshold and 40 < halved_bpm < 160:
                    should_halve = True
                    reason = f"onset_alignment_favors_half (ratio={alignment_ratio:.3f})"

        if should_halve:
            logger.info(f"🔧 Beat-level correction: {bpm:.1f} BPM → {halved_bpm:.1f} BPM ({reason})")
            return halved_bpm, reason

        return bpm, "beat_level_ok"
480
    def _compute_onset_alignment_score(self, onset_env: np.ndarray, bpm: float) -> float:
        """
        Score how well a candidate BPM aligns with the onset-strength envelope.

        Rationale: true beats should coincide with peaks in onset strength.
        A higher score means better alignment. The beat grid is anchored at
        frame 0; no phase search is performed.
        """
        if bpm <= 0:
            return 0.0

        frame_rate = self.sr / self.hop_length
        beat_interval_frames = int((60.0 / bpm) * frame_rate)

        if beat_interval_frames < 1 or beat_interval_frames > len(onset_env):
            return 0.0

        # Half-width of the search window around each sampled position (loop-invariant).
        window_size = max(1, beat_interval_frames // 8)

        # Sample the onset strength at every beat position.
        beat_strengths = []
        off_beat_strengths = []

        for i in range(0, len(onset_env) - beat_interval_frames, beat_interval_frames):
            # On-beat position: take the maximum within a small window.
            start = max(0, i - window_size)
            end = min(len(onset_env), i + window_size)
            beat_strengths.append(np.max(onset_env[start:end]))

            # Off-beat position: the midpoint between two beats.
            mid_point = i + beat_interval_frames // 2
            if mid_point < len(onset_env):
                start_off = max(0, mid_point - window_size)
                end_off = min(len(onset_env), mid_point + window_size)
                off_beat_strengths.append(np.max(onset_env[start_off:end_off]))

        if not beat_strengths or not off_beat_strengths:
            return 0.0

        # Score = mean on-beat strength / mean off-beat strength.
        # The higher the ratio, the more the onsets concentrate on the beat grid.
        mean_beat = float(np.mean(beat_strengths))
        mean_off_beat = float(np.mean(off_beat_strengths))

        if mean_off_beat < 1e-6:
            return mean_beat

        return mean_beat / mean_off_beat
525
    def _detailed_bpm_analysis(self, y: np.ndarray, sr: int) -> Dict[str, Any]:
        """Detailed BPM analysis based on tempogram autocorrelation."""
        try:
            # Compute the tempogram.
            tempogram = librosa.feature.tempogram(
                y=y, sr=sr, hop_length=self.hop_length
            )

            # Autocorrelate the flattened tempogram and normalise by the zero-lag value.
            # NOTE: the 2-D tempogram is flattened first, and the strongest peak index
            # is then interpreted as a beat period in frames.
            tempogram_flat = tempogram.flatten()
            acf = correlate(tempogram_flat, tempogram_flat, mode='full')
            acf = acf[len(acf)//2:]
            acf = acf / (acf[0] + 1e-8)

            # Find peaks, skipping lag 0.
            peaks, _ = find_peaks(acf[1:], height=0.2, distance=5)
            peaks = peaks + 1

            if len(peaks) > 0:
                frame_rate = sr / self.hop_length
                best_peak_idx = peaks[np.argmax(acf[peaks])]
                bpm = 60.0 * frame_rate / best_peak_idx
                confidence = float(np.max(acf[peaks]))
            else:
                # No usable peak: fall back to a low-confidence default.
                bpm = 120.0
                confidence = 0.3

            # Clamp to the supported BPM range.
            bpm = float(np.clip(bpm, self.BPM_MIN, self.BPM_MAX))

            return {
                'bpm': round(bpm, 1),
                'confidence': round(float(np.clip(confidence, 0, 1)), 2),
                'method': 'Tempogram Autocorrelation',
                'peaks_count': int(len(peaks))
            }
        except Exception as e:
            logger.warning(f"⚠️ Detailed analysis failed: {str(e)}")
            return {
                'bpm': 0,
                'confidence': 0,
                'method': 'Tempogram Autocorrelation',
                'error': str(e)
            }
575
    def _fuse_results(
        self,
        fast_result: Dict[str, Any],
        detailed_result: Dict[str, Any],
        y: np.ndarray = None,
    ) -> Dict[str, Any]:
        """Fuse the fast and detailed estimates, detecting and correcting octave (tempo-multiple) errors."""
        results = []

        if fast_result.get('bpm', 0) > 0:
            results.append({
                'bpm': fast_result['bpm'],
                'original_bpm': fast_result.get('original_bpm', fast_result['bpm']),
                'confidence': fast_result['confidence'],
                'method': fast_result['method'],
                'beat_level_correction': fast_result.get('beat_level_correction')
            })

        if detailed_result.get('bpm', 0) > 0:
            results.append({
                'bpm': detailed_result['bpm'],
                'confidence': detailed_result['confidence'],
                'method': detailed_result['method']
            })

        if not results:
            return {
                'bpm': 120.0,
                'confidence': 0.0,
                'note': 'Could not detect BPM; using the default value'
            }

        # If the fast detector already applied a beat-level correction, use that result directly.
        beat_level_correction = results[0].get('beat_level_correction')
        if beat_level_correction:
            original_bpm = results[0].get('original_bpm', results[0]['bpm'])
            corrected_bpm = results[0]['bpm']
            return {
                'bpm': corrected_bpm,
                'confidence': results[0]['confidence'],
                'primary_method': results[0]['method'],
                'supporting_methods': len(results) - 1,
                'all_candidates': results,
                'octave_correction': {
                    'from': original_bpm,
                    'to': corrected_bpm,
                    'reason': f'Beat-level correction: {original_bpm:.1f} → {corrected_bpm:.1f} ({beat_level_correction})'
                }
            }

        # Only one estimate is available.
        if len(results) == 1:
            best = results[0]
            return {
                'bpm': best['bpm'],
                'confidence': best['confidence'],
                'primary_method': best['method'],
                'supporting_methods': 0,
                'all_candidates': results,
                'octave_correction': None
            }

        # Both estimates are present: look for a tempo-multiple relationship.
        fast_bpm = results[0]['bpm']  # the fast (librosa.beat.tempo) estimate is usually more accurate
        detailed_bpm = results[1]['bpm'] if len(results) > 1 else None

        if detailed_bpm and fast_bpm > 0:
            ratio = max(fast_bpm, detailed_bpm) / min(fast_bpm, detailed_bpm)

            # Check for a tempo-multiple relationship (1/2, 1/3, 1/4, 2x, 3x, 4x, ...).
            octave_correction = None
            is_octave = False
            chosen_bpm = fast_bpm  # default to the fast estimate

            # Special case: detailed_bpm is very low (< 40) and fast_bpm is in 100-120.
            # The song may be a slow one detected at 2x, in which case detailed_bpm * 2
            # may be the right answer.
            #   e.g. "嘉禾望岗": true tempo 56 BPM, fast=112.3, detailed=30, and 30*2=60 is closer.
            # Mid-tempo or fast songs must not be mis-corrected
            #   (e.g. "中巴车带我回家": fast=117.5, detailed=30, ref=115),
            # so onset alignment is used as a check: correct only when the halved BPM
            # aligns clearly better than the fast BPM.
            if detailed_bpm < 40 and 100 <= fast_bpm <= 120 and y is not None:
                # Use the spectral centroid to judge whether this really is a slow song.
                spectral_centroid = librosa.feature.spectral_centroid(
                    y=y, sr=self.sr, hop_length=self.hop_length
                )
                mean_centroid = float(np.mean(spectral_centroid))

                doubled_detailed = detailed_bpm * 2
                # doubled_detailed must fall in a plausible slow-song range (50-70 BPM)
                # and the centroid must be low (< 2200) to confirm slow-song traits.
                if 50 <= doubled_detailed <= 70 and mean_centroid < 2200:
                    # fast_bpm must be approximately doubled_detailed * 2.
                    if abs(fast_bpm - doubled_detailed * 2) / fast_bpm < 0.1:
                        # Extra check: confirm with onset alignment that the halved BPM is better,
                        # to avoid mis-correcting songs like "中巴车带我回家" (fast=117.5, ref=115).
                        onset_env = librosa.onset.onset_strength(
                            y=y, sr=self.sr, hop_length=self.hop_length
                        )
                        score_fast = self._compute_onset_alignment_score(onset_env, fast_bpm)
                        score_halved = self._compute_onset_alignment_score(onset_env, doubled_detailed)
                        alignment_ratio = score_halved / score_fast if score_fast > 0 else 0

                        if self.verbose:
                            logger.debug(
                                f"Slow-song octave check: score_fast={score_fast:.3f}, "
                                f"score_halved={score_halved:.3f}, ratio={alignment_ratio:.3f}"
                            )

                        # Correct only when the halved BPM aligns clearly better.
                        #   "中巴车": alignment_ratio=1.042, not triggered (true BPM=115)
                        #   "嘉禾望岗": halved=56 is expected to align clearly better, so it triggers
                        if alignment_ratio > 1.08:
                            chosen_bpm = doubled_detailed
                            is_octave = True
                            octave_correction = {
                                'from': fast_bpm,
                                'to': doubled_detailed,
                                'reason': f'Slow-song octave correction: fast={fast_bpm:.1f} ≈ detailed×4={detailed_bpm:.1f}×4, using detailed×2={doubled_detailed:.1f} (alignment={alignment_ratio:.3f})'
                            }
                            logger.info(f"\n🔧 Slow-song octave correction: {fast_bpm:.1f} BPM → {doubled_detailed:.1f} BPM")
                            logger.info(f"   Reason: {octave_correction['reason']}")
                            return {
                                'bpm': chosen_bpm,
                                'confidence': results[1]['confidence'],
                                'primary_method': 'Tempogram + octave correction',
                                'supporting_methods': 1,
                                'all_candidates': results,
                                'octave_correction': octave_correction
                            }
                        else:
                            if self.verbose:
                                logger.debug(
                                    f"Slow-song octave correction skipped: fast BPM ({fast_bpm:.1f}) aligns better; keeping it"
                                )

            # Check common multiples: detailed_bpm should be close to fast_bpm * multiplier.
            for multiplier in [0.25, 0.33, 0.5, 1.0, 2.0, 3.0, 4.0]:
                expected_bpm = fast_bpm * multiplier
                # Is detailed_bpm within 10% of expected_bpm?
                if abs(detailed_bpm - expected_bpm) / expected_bpm < 0.1:
                    is_octave = True
                    if multiplier != 1.0:  # anything other than 1x is a tempo-multiple relationship
                        # Keep the fast estimate.
                        corrected_bpm = fast_bpm
                        octave_correction = {
                            'from': detailed_bpm,
                            'to': corrected_bpm,
                            'reason': f'Tempo-multiple relationship: {detailed_bpm:.1f} ≈ {fast_bpm:.1f} × {multiplier}, using the fast estimate'
                        }
                    break

            # When a tempo multiple is detected, use the fast estimate (usually more accurate).
            if is_octave and octave_correction:
                logger.info(f"\n🔧 Octave correction: {octave_correction['from']:.1f} BPM → {octave_correction['to']:.1f} BPM")
                logger.info(f"   Reason: {octave_correction['reason']}")
                return {
                    'bpm': fast_bpm,
                    'confidence': results[0]['confidence'],
                    'primary_method': results[0]['method'],
                    'supporting_methods': 1,
                    'all_candidates': results,
                    'octave_correction': octave_correction
                }

        # No tempo-multiple relationship: prefer the fast estimate, which this
        # analyzer treats as the reference and as usually more accurate than the detailed one.
        best = results[0]  # fast estimate

        return {
            'bpm': best['bpm'],
            'confidence': best['confidence'],
            'primary_method': best['method'],
            'supporting_methods': len(results) - 1,
            'all_candidates': results,
            'octave_correction': None
        }
751
752 def _display_results(self, result: Dict[str, Any]):
753 """显示分析结果"""
754 if not result['success']:
755 logger.error(f"❌ 分析失败: {result.get('error')}")
756 return
757
758 file_info = (
759 f"文件: {result['file_name']} "
760 f"({result['file_size_mb']} MB) "
761 f"时长: {result['duration_seconds']} 秒"
762 )
763 logger.info(file_info)
764
765 final = result['final_result']
766 logger.info(f"\n🎵 最终结果:")
767 logger.info(f" BPM: {final['bpm']}")
768 logger.info(f" 置信度: {final['confidence']:.0%}")
769 logger.info(f" 主要方法: {final['primary_method']}")
770 logger.info(f" 支持方法数: {final['supporting_methods']}")
771
772 # 显示倍频纠正信息
773 if final.get('octave_correction'):
774 correction = final['octave_correction']
775 logger.info(f"\n🔧 倍频纠正:")
776 logger.info(f" 原始检测: {correction['from']:.1f} BPM")
777 logger.info(f" 纠正后: {correction['to']:.1f} BPM")
778 logger.info(f" 原因: {correction['reason']}")
779
780 if self.verbose:
781 logger.debug(f"\n📊 快速检测: {result['fast_detection']['bpm']} BPM")
782 logger.debug(f"📊 详细分析: {result['detailed_analysis']['bpm']} BPM")
783
784 def export_results(
785 self,
786 results: Any,
787 output_path: str
788 ):
789 """导出结果为 JSON"""
790 try:
791 # 将 numpy 类型转换为 Python 原生类型
792 def convert_numpy(obj):
793 if isinstance(obj, np.ndarray):
794 return obj.tolist()
795 elif isinstance(obj, np.integer):
796 return int(obj)
797 elif isinstance(obj, np.floating):
798 return float(obj)
799 elif isinstance(obj, dict):
800 return {k: convert_numpy(v) for k, v in obj.items()}
801 elif isinstance(obj, (list, tuple)):
802 return [convert_numpy(v) for v in obj]
803 return obj
804
805 results_converted = convert_numpy(results)
806
807 with open(output_path, 'w', encoding='utf-8') as f:
808 json.dump(results_converted, f, ensure_ascii=False, indent=2)
809
810 logger.info(f"✓ 结果已导出到: {Path(output_path).absolute()}")
811
812 except Exception as e:
813 logger.error(f"❌ 导出失败: {str(e)}")
814
815
816 def main():
817 """主函数"""
818 parser = argparse.ArgumentParser(
819 description='Realtime BPM Analyzer - Python 测试程序',
820 formatter_class=argparse.RawDescriptionHelpFormatter,
821 epilog="""
822 示例用法:
823 # 分析单个文件
824 python bpm_analyzer_test.py --file music.mp3
825
826 # 分析并输出结果
827 python bpm_analyzer_test.py --file music.mp3 --output result.json
828
829 # 显示详细信息
830 python bpm_analyzer_test.py --file music.mp3 --verbose
831
832 # 批量分析文件夹
833 python bpm_analyzer_test.py --dir /path/to/music
834 """
835 )
836
837 parser.add_argument('--file', type=str, help='音频文件路径')
838 parser.add_argument('--dir', type=str, help='音频文件夹路径(批量分析)')
839 parser.add_argument('-o', '--output', type=str, help='输出 JSON 文件路径')
840 parser.add_argument('-v', '--verbose', action='store_true', help='显示详细信息')
841
842 args = parser.parse_args()
843
844 # 验证参数
845 if not args.file and not args.dir:
846 parser.print_help()
847 sys.exit(1)
848
849 # 初始化分析器
850 analyzer = RealtimeBPMAnalyzerTest(verbose=args.verbose)
851
852 # 执行分析
853 try:
854 if args.file:
855 result = analyzer.analyze_file(args.file)
856 results = result
857 else:
858 results_list = analyzer.analyze_directory(args.dir)
859 results = {
860 'success': True,
861 'total_files': len(results_list),
862 'results': results_list
863 }
864
865 # 导出结果
866 if args.output:
867 analyzer.export_results(results, args.output)
868 else:
869 # 默认输出文件名
870 if args.file:
871 default_output = f"bpm_result_{Path(args.file).stem}.json"
872 else:
873 default_output = "bpm_results.json"
874 analyzer.export_results(results, default_output)
875
876 print("\n" + "=" * 80)
877 print("✅ 分析完成!")
878 print("=" * 80 + "\n")
879
880 except Exception as e:
881 logger.error(f"❌ 执行失败: {str(e)}")
882 if args.verbose:
883 import traceback
884 traceback.print_exc()
885 sys.exit(1)
886
887
888 if __name__ == '__main__':
889 main()
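The onset-alignment score used by the beat-level rules above can be tried in isolation. The following is a minimal sketch, assuming the beat period is already given in frames; `alignment_score` and the synthetic envelope are illustrative names and data, not part of the analyzer. It shows why a detector that locked onto twice the true tempo yields a halved/detected ratio above the 1.08 threshold used in the rules.

```python
import numpy as np


def alignment_score(onset_env: np.ndarray, period: int) -> float:
    """Mean on-beat onset strength divided by mean off-beat strength.

    `period` is the beat interval in frames (hypothetical helper for illustration).
    """
    if period < 1 or period > len(onset_env):
        return 0.0
    half_window = max(1, period // 8)
    on_beat, off_beat = [], []
    for i in range(0, len(onset_env) - period, period):
        # Strongest onset in a small window around the beat position.
        on_beat.append(onset_env[max(0, i - half_window): i + half_window].max())
        # Strongest onset in a small window around the midpoint between beats.
        mid = i + period // 2
        off_beat.append(onset_env[max(0, mid - half_window): mid + half_window].max())
    if not on_beat:
        return 0.0
    mean_on = float(np.mean(on_beat))
    mean_off = float(np.mean(off_beat))
    return mean_on if mean_off < 1e-6 else mean_on / mean_off


# Synthetic envelope: one strong onset every 20 frames on a low noise floor.
env = np.full(400, 0.05)
env[::20] = 1.0

score_detected = alignment_score(env, 10)  # grid at twice the true tempo
score_halved = alignment_score(env, 20)    # grid at the true tempo
alignment_ratio = score_halved / score_detected
print(f"detected={score_detected:.2f} halved={score_halved:.2f} ratio={alignment_ratio:.2f}")
```

On the double-tempo grid every second "beat" lands on the noise floor, which drags the on-beat mean down, so the halved grid scores higher.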
1 # -*- coding: utf-8 -*-
2 """
3 火山引擎豆包音乐分析器实现
4 """
5
6 import os
7 import time
8 import logging
9 from typing import Dict, Any, Optional
10 from dotenv import load_dotenv
11 from pathlib import Path
12
13 import httpx
14
15 from .base import AudioAnalyzer
16 from .prompts import build_analyze_prompt, build_lyrics_prompt
17
18 _ROOT_DIR = Path(__file__).resolve().parents[2]
19 load_dotenv(_ROOT_DIR / ".env")
20
21 logger = logging.getLogger(__name__)
22
23
24 class DoubaoAnalyzer(AudioAnalyzer):
25 """火山引擎豆包音乐分析器"""
26
27 def __init__(
28 self,
29 api_key: Optional[str] = None,
30 base_url: Optional[str] = None,
31 model: Optional[str] = None,
32 timeout: float = 60.0,
33 max_retries: int = 3,
34 ):
35 """
36 初始化豆包分析器
37
38 Args:
39 api_key: API Key(默认从环境变量读取 DOUBAO_API_KEY 或 ARK_API_KEY)
40 base_url: API 基础URL(默认: https://ark.cn-beijing.volces.com/api/v3)
41 model: 模型名称(默认: doubao-seed-1-8-251228)
42 timeout: 超时时间(秒)
43 max_retries: 最大重试次数
44 """
45 self.api_key = api_key or os.getenv("DOUBAO_API_KEY", os.getenv("ARK_API_KEY"))
46 self.base_url = base_url or os.getenv(
47 "DOUBAO_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3"
48 )
49 self.model = model or os.getenv("DOUBAO_MODEL", "doubao-seed-1-8-251228")
50 self.timeout = timeout
51 self.max_retries = max_retries
52
53 self._client = None
54
55 def _get_client(self) -> httpx.Client:
56 """获取 HTTP 客户端"""
57 if self._client is None:
58 self._client = httpx.Client(
59 base_url=self.base_url,
60 timeout=self.timeout,
61 headers={
62 "Authorization": f"Bearer {self.api_key}",
63 "Content-Type": "application/json",
64 },
65 )
66 return self._client
67
68 def get_provider_name(self) -> str:
69 return "doubao"
70
71 def get_model_name(self) -> str:
72 return self.model
73
74 def analyze(
75 self,
76 metadata: Dict[str, Any],
77 music_url: str,
78 extract_lyrics: bool = False,
79 label_level: int = 0,
80 ) -> Optional[Dict[str, Any]]:
81 """
82 分析音乐
83
84 Args:
85 metadata: 音乐元数据
86 music_url: 音乐文件 URL
87 extract_lyrics: 是否识别歌词
88 label_level: 标签级别
89
90 Returns:
91 分析结果字典
92 """
93 client = self._get_client()
94
95 if extract_lyrics:
96 return self._analyze_with_lyrics(client, metadata, music_url, label_level)
97 else:
98 return self._analyze_basic(client, metadata, music_url, label_level)
99
100 def _analyze_basic(
101 self,
102 client: httpx.Client,
103 metadata: Dict[str, Any],
104 music_url: str,
105 label_level: int = 0,
106 ) -> Optional[Dict[str, Any]]:
107 """基础分析(不含歌词)"""
108 system_prompt, user_prompt = build_analyze_prompt(
109 metadata=metadata,
110 include_lyrics=False,
111 label_level=label_level,
112 )
113
114 # 打印提示词到日志
115 logger.info(f"[DoubaoAnalyzer] System Prompt:\n{system_prompt}")
116 logger.info(f"[DoubaoAnalyzer] User Prompt:\n{user_prompt}")
117
118 messages = self._build_messages(system_prompt, user_prompt, music_url)
119
120 response = self._call_with_retry(client, messages)
121
122 if response is None:
123 return None
124
125 result = self._parse_response(response.get("content", ""))
126 if result is None:
127 return None
128
129 return self._normalize_result(result, self.model, response.get("usage"))
130
131 def _analyze_with_lyrics(
132 self,
133 client: httpx.Client,
134 metadata: Dict[str, Any],
135 music_url: str,
136 label_level: int = 0,
137 ) -> Optional[Dict[str, Any]]:
138 """分析(含歌词识别,需要两次调用)"""
139 # 第一次调用:基本信息(不含歌词)
140 system_prompt, user_prompt = build_analyze_prompt(
141 metadata=metadata,
142 include_lyrics=False,
143 label_level=label_level,
144 )
145
146 # 打印提示词到日志
147 logger.info(f"[DoubaoAnalyzer] System Prompt (with lyrics):\n{system_prompt}")
148 logger.info(f"[DoubaoAnalyzer] User Prompt (with lyrics):\n{user_prompt}")
149
150 messages_basic = self._build_messages(system_prompt, user_prompt, music_url)
151 response_basic = self._call_with_retry(client, messages_basic)
152
153 if response_basic is None:
154 return None
155
156 result = self._parse_response(response_basic.get("content", ""))
157 if result is None:
158 return None
159
160 # 第二次调用:歌词识别
161 lyrics_prompt = build_lyrics_prompt()
162
163 # 打印歌词识别提示词到日志
164 logger.info(f"[DoubaoAnalyzer] Lyrics Prompt:\n{lyrics_prompt}")
165
166 messages_lyrics = self._build_messages(
167 "请识别这段音频中的歌词内容", lyrics_prompt, music_url
168 )
169 response_lyrics = self._call_with_retry(client, messages_lyrics)
170
171 lyrics_result = None
172 if response_lyrics:
173 lyrics_result = self._parse_response(response_lyrics.get("content", ""))
174 if lyrics_result and "lyrics" in lyrics_result:
175 result["lyrics"] = lyrics_result["lyrics"]
176
177 # 合并 token 使用信息
178 usage = response_basic.get("usage", {})
179 if response_lyrics and response_lyrics.get("usage"):
180 usage_lyrics = response_lyrics["usage"]
181 usage = {
182 "prompt_tokens": usage.get("prompt_tokens", 0)
183 + usage_lyrics.get("prompt_tokens", 0),
184 "completion_tokens": usage.get("completion_tokens", 0)
185 + usage_lyrics.get("completion_tokens", 0),
186 "total_tokens": usage.get("total_tokens", 0)
187 + usage_lyrics.get("total_tokens", 0),
188 }
189
190 return self._normalize_result(result, self.model, usage)
191
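    def _build_messages(
        self,
        system_prompt: str,
        user_prompt: str,
        music_url: str,
    ) -> list:
        """Build the chat messages payload."""
        messages = []
        # Include the system prompt so the instructions built by the prompt
        # builder actually reach the model.
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append(
            {
                "role": "user",
                "content": [
                    # The audio URL is passed as a "video_url" content part.
                    {"type": "video_url", "video_url": {"url": music_url}},
                    {"type": "text", "text": user_prompt},
                ],
            }
        )
        return messages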
208
    def _call_with_retry(
        self,
        client: httpx.Client,
        messages: list,
    ) -> Optional[Dict]:
        """Call the chat completions API, retrying on failure."""
        endpoint = "/chat/completions"
        data = {
            "model": self.model,
            "messages": messages,
            "temperature": 0.7,
            "max_tokens": 4000,
            "stream": False,
        }

        for attempt in range(1, self.max_retries + 1):
            try:
                print(f"  [Doubao] Calling model (attempt {attempt}/{self.max_retries})...")

                start_time = time.time()

                response = client.post(endpoint, json=data)
                response.raise_for_status()

                elapsed = time.time() - start_time
                print(f"  [Doubao] Response time: {elapsed:.2f}s")

                result = response.json()
                # Treat a missing or null content field as an empty string.
                content = (
                    result.get("choices", [{}])[0].get("message", {}).get("content", "")
                ) or ""
                usage = result.get("usage", {})

                print(f"  [Doubao] Response: {content[:100]}...")

                return {
                    "content": content,
                    "usage": {
                        "prompt_tokens": usage.get("prompt_tokens", 0),
                        "completion_tokens": usage.get("completion_tokens", 0),
                        "total_tokens": usage.get("total_tokens", 0),
                    },
                }

            except Exception as e:
                # HTTP errors (httpx.HTTPError) and all other failures share one retry policy.
                label = "HTTP error" if isinstance(e, httpx.HTTPError) else "Error"
                print(f"  [Doubao] {label} ({type(e).__name__}): {e}")

                if attempt < self.max_retries:
                    wait_time = attempt  # linear backoff: 1s, 2s, ...
                    print(f"  Waiting {wait_time}s before retrying...")
                    time.sleep(wait_time)
                else:
                    print("  Maximum number of retries reached")
                    return None

        return None
279
280
281 def test_doubao_audio_url_lyrics():
282 """
283 测试豆包是否支持通过音频URL解析音频歌词
284
285 此测试用例用于验证豆包模型是否能够:
286 1. 接收音频URL作为输入
287 2. 解析音频内容
288 3. 识别并返回歌词
289
290 使用方法:
291 python -c "from app.middleware.music_analyze.doubao_analyzer import test_doubao_audio_url_lyrics; test_doubao_audio_url_lyrics()"
292
293 或者直接在命令行运行:
294 python app/middleware/music_analyze/doubao_analyzer.py
295 """
296 import json
297
298 print("=" * 80)
299 print("测试豆包音频URL歌词解析功能")
300 print("=" * 80)
301
302 # 测试音频URL(使用一个公开可访问的音频文件)
303 # 注意:请替换为实际可访问的音频URL
304 test_audio_url = "https://hikoon-ai-test.oss-cn-hangzhou.aliyuncs.com/ai/cache/modelName/20260114/_s__e_1768376270519_rmab41.mp3"
305
306 print(f"\n测试音频URL: {test_audio_url}")
307 print("\n开始测试...")
308
309 try:
310 # 初始化分析器
311 analyzer = DoubaoAnalyzer()
312
313 # 测试元数据
314 metadata = {"title": "测试歌曲", "artist": "测试艺术家", "test": True}
315
316 print("\n1. 测试基础分析(不含歌词)...")
317 result_basic = analyzer.analyze(
318 metadata=metadata,
319 music_url=test_audio_url,
320 extract_lyrics=False,
321 label_level=0,
322 )
323
324 if result_basic:
325 print(" ✓ 基础分析成功")
326 print(f" - 曲风: {result_basic.get('genre', 'N/A')}")
327 print(f" - 语种: {result_basic.get('language', 'N/A')}")
328 print(f" - 情绪: {result_basic.get('emotion', 'N/A')}")
329 else:
330 print(" ✗ 基础分析失败")
331
332 print("\n2. 测试歌词识别(含歌词)...")
333 result_with_lyrics = analyzer.analyze(
334 metadata=metadata,
335 music_url=test_audio_url,
336 extract_lyrics=True,
337 label_level=0,
338 )
339
340 if result_with_lyrics:
341 print(" ✓ 歌词识别分析成功")
342 lyrics = result_with_lyrics.get("lyrics", [])
343 if lyrics:
344 print(f" ✓ 成功识别歌词,共 {len(lyrics)} 行")
345 print("\n 歌词预览(前5行):")
346 for i, line in enumerate(lyrics[:5], 1):
347 time_str = line.get("time", "N/A")
348 text = line.get("text", "")
349 print(f" [{i}] {time_str} - {text}")
350 if len(lyrics) > 5:
351 print(f" ... 还有 {len(lyrics) - 5} 行")
352 print("\n ✓ 测试通过:豆包支持音频URL解析歌词")
353 else:
354 print(" ⚠ 未识别到歌词(可能是纯音乐或无法识别)")
355 print("\n ! 测试结果:豆包支持音频URL解析,但未返回歌词")
356
357 # 输出完整结果
358 print("\n3. 完整分析结果:")
359 print(json.dumps(result_with_lyrics, ensure_ascii=False, indent=2))
360 else:
361 print(" ✗ 歌词识别分析失败")
362 print("\n ✗ 测试失败:豆包可能不支持音频URL解析")
363
364 print("\n" + "=" * 80)
365 print("测试完成")
366 print("=" * 80)
367
368 return result_with_lyrics
369
370 except Exception as e:
371 print(f"\n✗ 测试过程中发生错误: {e}")
372 import traceback
373
374 traceback.print_exc()
375 return None
376
377
378 if __name__ == "__main__":
379 test_doubao_audio_url_lyrics()
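The retry policy in `_call_with_retry` (wait `attempt` seconds after each failed attempt, give up after `max_retries`) can be exercised without any network access. This is a rough sketch under that assumption; `call_with_retry`, the injectable `sleep` parameter, and the `flaky` function are hypothetical names introduced here for illustration only.

```python
import time
from typing import Callable, List, Optional, TypeVar

T = TypeVar("T")


def call_with_retry(
    fn: Callable[[], T],
    max_retries: int = 3,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[T]:
    """Call fn(); on any exception wait `attempt` seconds and retry, up to max_retries attempts."""
    for attempt in range(1, max_retries + 1):
        try:
            return fn()
        except Exception as exc:
            print(f"attempt {attempt}/{max_retries} failed ({type(exc).__name__}): {exc}")
            if attempt == max_retries:
                return None
            sleep(attempt)  # linear backoff: 1s, 2s, ...
    return None


# Demo: a function that fails twice before succeeding, with sleeps recorded instead of executed.
calls = {"count": 0}
waits: List[float] = []


def flaky() -> str:
    calls["count"] += 1
    if calls["count"] < 3:
        raise RuntimeError("temporary failure")
    return "ok"


result = call_with_retry(flaky, max_retries=3, sleep=waits.append)
always_failing = call_with_retry(lambda: 1 / 0, max_retries=2, sleep=waits.append)
```

Passing the sleep function as a parameter keeps the backoff testable; production code would leave the default `time.sleep`.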
# -*- coding: utf-8 -*-
"""
Music analyzer factory.
"""

from typing import Dict
from .base import AudioAnalyzer
from .qwen_analyzer import QwenAnalyzer


class AnalyzerFactory:
    """Factory for music analyzers."""

    _analyzers: Dict[str, AudioAnalyzer] = {}

    @classmethod
    def get_analyzer(cls, provider: str = "qwen", **kwargs) -> AudioAnalyzer:
        """
        Return an analyzer instance.

        Args:
            provider: Provider name (only "qwen" is supported).
            **kwargs: Extra configuration (api_key, model, timeout, ...).

        Returns:
            An AudioAnalyzer instance.
        """
        # Instances are cached per provider and model only. A later call that
        # differs in other kwargs (api_key, timeout, ...) gets the cached instance.
        cache_key = f"{provider}_{kwargs.get('model', '')}"

        if cache_key in cls._analyzers:
            return cls._analyzers[cache_key]

        if provider == "qwen":
            analyzer = QwenAnalyzer(**kwargs)
        else:
            raise ValueError(f"Unknown provider: {provider}. Only 'qwen' is supported.")

        cls._analyzers[cache_key] = analyzer
        return analyzer

    @classmethod
    def get_default_analyzer(cls) -> AudioAnalyzer:
        """Return the default analyzer (provider read from the environment)."""
        import os

        provider = os.getenv("DEFAULT_MUSIC_ANALYZER", "qwen")
        return cls.get_analyzer(provider=provider)

    @classmethod
    def list_providers(cls) -> list:
        """List the available providers."""
        return ["qwen"]

    @classmethod
    def clear_cache(cls):
        """Clear the cached analyzer instances."""
        cls._analyzers.clear()
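The factory caches analyzers by provider and model only. If instances ever need to differ by other settings such as `timeout`, one option is a cache key built from every keyword argument. The sketch below illustrates that idea with a plain dict standing in for an analyzer; `make_cache_key` and `get_config` are hypothetical helpers, not part of `AnalyzerFactory`, and the approach assumes each kwarg value has a stable `repr`.

```python
from typing import Any, Dict, Tuple

CacheKey = Tuple[str, Tuple[Tuple[str, str], ...]]


def make_cache_key(provider: str, **kwargs: Any) -> CacheKey:
    """Build an order-independent key from the provider and all keyword arguments."""
    return provider, tuple(sorted((name, repr(value)) for name, value in kwargs.items()))


_cache: Dict[CacheKey, dict] = {}


def get_config(provider: str = "qwen", **kwargs: Any) -> dict:
    """Stand-in for an analyzer constructor: one cached object per distinct configuration."""
    key = make_cache_key(provider, **kwargs)
    if key not in _cache:
        _cache[key] = {"provider": provider, **kwargs}
    return _cache[key]


a = get_config("qwen", model="qwen3-omni-flash", timeout=15)
b = get_config("qwen", timeout=15, model="qwen3-omni-flash")  # same settings, different order
c = get_config("qwen", model="qwen3-omni-flash", timeout=90)  # different timeout
```

Here `a` and `b` resolve to the same cached object, while `c` gets its own entry because the timeout differs.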
# -*- coding: utf-8 -*-
"""
Unified entry point for music analysis.
Provides the simplified analyze_music() function.
"""

from typing import Dict, Any, Optional
import os

from .factory import AnalyzerFactory


def analyze_music(
    metadata: Dict[str, Any],
    music_url: str,
    provider: Optional[str] = None,
    extract_lyrics: bool = False,
    label_level: int = 0,
) -> Optional[Dict[str, Any]]:
    """
    Unified entry point for music analysis.

    Args:
        metadata: Music metadata dict (title, artist, ...).
        music_url: URL of the music file.
        provider: Provider name. Defaults to the DEFAULT_MUSIC_ANALYZER environment
            variable, then "qwen". AnalyzerFactory currently registers only "qwen".
        extract_lyrics: Whether to recognise lyrics.
        label_level: Label level (0: first-level labels, 1: first- and second-level labels).

    Returns:
        A result dict with the following fields:
        - genre: Music genre (first-level genre, e.g. 流行 (pop), 摇滚 (rock))
        - emotion: List of emotions
        - emotional_intensity: Emotional intensity
        - vocal_texture: Vocal texture
        - vocal_description: Description of the vocal texture
        - visual_concept: Visual concept
        - language: Language
        - bpm: Beats per minute (optional)
        - lyrics: List of lyric lines (optional)
        - _model: Name of the model used
        - _token_info: Token usage information

    Example:
        >>> result = analyze_music(
        ...     metadata={"title": "稻香", "artist": "周杰伦"},
        ...     music_url="https://example.com/music.mp3",
        ...     provider="qwen",
        ...     extract_lyrics=False,
        ... )
        >>> print(result["genre"])
        流行
    """
    if provider is None:
        provider = os.getenv("DEFAULT_MUSIC_ANALYZER", "qwen")

    analyzer = AnalyzerFactory.get_analyzer(provider=provider)

    return analyzer.analyze(
        metadata=metadata,
        music_url=music_url,
        extract_lyrics=extract_lyrics,
        label_level=label_level,
    )


def analyze_music_with_qwen(
    metadata: Dict[str, Any],
    music_url: str,
    extract_lyrics: bool = False,
    label_level: int = 0,
) -> Optional[Dict[str, Any]]:
    """Analyze music with Tongyi Qianwen (Qwen)."""
    return analyze_music(
        metadata=metadata,
        music_url=music_url,
        provider="qwen",
        extract_lyrics=extract_lyrics,
        label_level=label_level,
    )


def analyze_music_with_doubao(
    metadata: Dict[str, Any],
    music_url: str,
    extract_lyrics: bool = False,
    label_level: int = 0,
) -> Optional[Dict[str, Any]]:
    """
    Analyze music with Volcano Engine Doubao.

    AnalyzerFactory currently registers only "qwen", so this call raises
    ValueError until a "doubao" provider is registered there.
    """
    return analyze_music(
        metadata=metadata,
        music_url=music_url,
        provider="doubao",
        extract_lyrics=extract_lyrics,
        label_level=label_level,
    )


def analyze_music_lyrics_only(
    metadata: Dict[str, Any],
    music_url: str,
    provider: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Recognise lyrics only, without repeating the basic label analysis."""
    if provider is None:
        provider = os.getenv("DEFAULT_MUSIC_ANALYZER", "qwen")

    analyzer = AnalyzerFactory.get_analyzer(provider=provider)
    if hasattr(analyzer, "analyze_lyrics_only"):
        return analyzer.analyze_lyrics_only(metadata=metadata, music_url=music_url)

    # Fallback for providers that do not implement analyze_lyrics_only:
    # run the full analysis and keep only the lyrics-related fields.
    result = analyzer.analyze(
        metadata=metadata,
        music_url=music_url,
        extract_lyrics=True,
        label_level=0,
    )
    if isinstance(result, dict):
        lyrics = result.get("lyrics", [])
        return {
            "lyrics": lyrics if isinstance(lyrics, list) else [],
            "_model": result.get("_model"),
            "_token_info": result.get("_token_info"),
        }
    return None


def get_available_providers() -> list:
    """Return the list of available providers."""
    return AnalyzerFactory.list_providers()
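The fallback branch of `analyze_music_lyrics_only` reduces a full analysis result to the lyrics-related fields. A small standalone sketch of that reduction follows; `lyrics_only_payload` is a hypothetical helper name and the sample result dict is made up for illustration.

```python
from typing import Any, Dict, Optional


def lyrics_only_payload(result: Any) -> Optional[Dict[str, Any]]:
    """Keep only lyrics, model name and token info; return None for non-dict results."""
    if not isinstance(result, dict):
        return None
    lyrics = result.get("lyrics", [])
    return {
        # A malformed (non-list) lyrics value is replaced by an empty list.
        "lyrics": lyrics if isinstance(lyrics, list) else [],
        "_model": result.get("_model"),
        "_token_info": result.get("_token_info"),
    }


full_result = {
    "genre": "流行",
    "lyrics": [{"time": "00:01", "text": "la"}],
    "_model": "qwen3-omni-flash",
}
trimmed = lyrics_only_payload(full_result)
malformed = lyrics_only_payload({"lyrics": "not a list"})
failed = lyrics_only_payload(None)
```

Label fields such as `genre` are dropped, so callers of the lyrics-only path see the same shape whether or not the provider implements `analyze_lyrics_only`.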
# -*- coding: utf-8 -*-
"""
Prompt template builder for music analysis.
Prompts are read from external template files so they can be changed without code edits.
"""

import os
from pathlib import Path
from typing import Dict, Any, Optional


# Template file paths (moved to app/prompts/step2_music_decode)
PROMPTS_DIR = Path(__file__).parent.parent.parent / "prompts" / "step2_music_decode"
SYSTEM_PROMPT_FILE = PROMPTS_DIR / "music_analyze_system_prompt.md"
SYSTEM_PROMPT_PART_A_FILE = PROMPTS_DIR / "music_analyze_system_prompt_part_a.md"
SYSTEM_PROMPT_PART_B_FILE = PROMPTS_DIR / "music_analyze_system_prompt_part_b.md"
USER_PROMPT_FILE = PROMPTS_DIR / "music_analyze_user_prompt.md"
LYRICS_ONLY_PROMPT_FILE = PROMPTS_DIR / "music_lyrics_only_prompt.md"


def load_template(template_path: Path) -> str:
    """
    Load a template from a file.

    Args:
        template_path: Path to the template file.

    Returns:
        The template content as a string.

    Raises:
        FileNotFoundError: If the template file does not exist.
    """
    if not template_path.exists():
        raise FileNotFoundError(f"Template file not found: {template_path}")

    with open(template_path, "r", encoding="utf-8") as f:
        content = f.read()

    # Strip only the comment lines at the top of the file (lines starting with a single "#").
    # "##" headings and body text are kept; blank lines are always kept.
    lines = content.split("\n")
    filtered_lines = []
    in_header = True

    for line in lines:
        stripped = line.strip()
        # Keep blank lines.
        if not stripped:
            filtered_lines.append(line)
            continue

        # Skip single-"#" comment lines while still in the file header.
        if in_header and stripped.startswith("#") and not stripped.startswith("##"):
            continue

        # A "##" heading or body text ends the header.
        in_header = False
        filtered_lines.append(line)

    return "\n".join(filtered_lines)
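To see what the header stripping in `load_template` does to a template, the same filtering can be run on an in-memory string. This is a minimal sketch; `strip_header_comments` is a hypothetical name and the sample template text is invented for illustration.

```python
def strip_header_comments(content: str) -> str:
    """Drop single-'#' comment lines at the top; keep '##' headings, body text and blank lines."""
    filtered = []
    in_header = True
    for line in content.split("\n"):
        stripped = line.strip()
        if not stripped:
            filtered.append(line)  # blank lines are always kept
            continue
        if in_header and stripped.startswith("#") and not stripped.startswith("##"):
            continue  # header comment: skip
        in_header = False  # first heading or body line ends the header
        filtered.append(line)
    return "\n".join(filtered)


sample = "# editor note\n\n# another note\n## Task\nDescribe the song.\n# kept: not in header"
cleaned = strip_header_comments(sample)
print(cleaned)
```

Note that a blank line inside the header survives, so the cleaned template may start with an empty line, and a single-`#` line after the first heading is left untouched.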
59
60
61 class PromptBuilder:
62 """音乐分析提示词模板构建器"""
63
64 def __init__(self, label_level: int = 0):
65 """
66 初始化提示词构建器
67
68 Args:
69 label_level: 标签级别(0: 一级标签, 1: 一级+二级标签)
70 """
71 self.label_level = label_level
72
73 def build_system_prompt(self) -> str:
74 """构建系统提示词 - 直接加载静态模板"""
75 return load_template(SYSTEM_PROMPT_FILE)
76
77 def build_system_prompt_part_a(self) -> str:
78 """构建系统提示词A组"""
79 return load_template(SYSTEM_PROMPT_PART_A_FILE)
80
81 def build_system_prompt_part_b(self) -> str:
82 """构建系统提示词B组"""
83 return load_template(SYSTEM_PROMPT_PART_B_FILE)
84
85 def build_metadata_section(self, metadata: Optional[Dict[str, Any]] = None) -> str:
86 """构建元数据部分"""
87 if not metadata:
88 return ""
89
90 sections = ["## 音乐元数据"]
91 for key, value in metadata.items():
92 if key.startswith("_"):
93 continue
94 if value and str(value).strip():
95 sections.append(f"- {key}: {value}")
96 sections.append("")
97 return "\n".join(sections)
98
99 def build_output_format(
100 self,
101 include_lyrics: bool = False,
102 include_bpm: bool = True,
103 ) -> str:
104 """构建输出格式说明"""
105 if include_lyrics and include_bpm:
106 format_spec = """{
107 "genre": "",
108 "sub_genre": "",
109 "language": "",
110 "vocal_type": "",
111 "vocal_description": "",
112 "emotion": [""],
113 "scene": [""],
114 "age": "",
115 "rhythm_intensity": "",
116 "is_sinking": false,
117 "song_description": "",
118 "visual_concept": "",
119 "emotional_intensity": "",
120 "bpm": 0,
121 "lyrics": [{"time": "", "text": ""}]
122 }"""
123 elif include_bpm:
124 format_spec = """{
125 "genre": "",
126 "sub_genre": "",
127 "language": "",
128 "vocal_type": "",
129 "vocal_description": "",
130 "emotion": [""],
131 "scene": [""],
132 "age": "",
133 "rhythm_intensity": "",
134 "is_sinking": false,
135 "song_description": "",
136 "visual_concept": "",
137 "emotional_intensity": "",
138 "bpm": 0
139 }"""
140 elif include_lyrics:
141 format_spec = """{
142 "genre": "",
143 "sub_genre": "",
144 "language": "",
145 "vocal_type": "",
146 "vocal_description": "",
147 "emotion": [""],
148 "scene": [""],
149 "age": "",
150 "rhythm_intensity": "",
151 "is_sinking": false,
152 "song_description": "",
153 "visual_concept": "",
154 "emotional_intensity": "",
155 "lyrics": [{"time": "", "text": ""}]
156 }"""
157 else:
158 format_spec = """{
159 "genre": "",
160 "sub_genre": "",
161 "language": "",
162 "vocal_type": "",
163 "vocal_description": "",
164 "emotion": [""],
165 "scene": [""],
166 "age": "",
167 "rhythm_intensity": "",
168 "is_sinking": false,
169 "song_description": "",
170 "visual_concept": "",
171 "emotional_intensity": ""
172 }"""
173
174 return format_spec
175
176 def build_user_prompt(
177 self,
178 metadata: Optional[Dict[str, Any]] = None,
179 include_lyrics: bool = False,
180 include_bpm: bool = True,
181 ) -> str:
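182 """
183 Build the complete user prompt.
184 Loads the template file and substitutes its placeholders.
185
186 Args:
187 metadata: Music metadata dict (optional).
188 include_lyrics: Whether to transcribe lyrics (kept for compatibility with existing callers).
189 include_bpm: Whether to include BPM detection (kept for compatibility with existing callers).
190
191 Returns:
192 The complete user prompt.
193 """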
194 # Load the template
195 template = load_template(USER_PROMPT_FILE)
196
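197 # Build the replacement map; only the metadata section is substituted
198 # The output format is already defined in the system prompt, so it is not repeated here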
199 replacements = {
200 "{{METADATA_SECTION}}": self.build_metadata_section(metadata),
201 }
202
203 # Substitute the placeholders
204 result = template
205 for placeholder, value in replacements.items():
206 result = result.replace(placeholder, value)
207
208 return result
209
210 def build_lyrics_only_prompt(self) -> str:
211 """Build the lyrics-only prompt."""
212 return load_template(LYRICS_ONLY_PROMPT_FILE)
213
214
215 def build_analyze_prompt(
216 metadata: Optional[Dict[str, Any]] = None,
217 include_lyrics: bool = False,
218 label_level: int = 0,
219 ) -> tuple[str, str]:
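220 """
221 Build the complete analysis prompts.
222
223 Args:
224 metadata: Music metadata dict (optional).
225 include_lyrics: Whether to transcribe lyrics.
226 label_level: Label level (0: first-level labels only; 1: first- and second-level labels).
227
228 Returns:
229 A (system_prompt, user_prompt) tuple.
230 """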
231 builder = PromptBuilder(label_level=label_level)
232 system_prompt = builder.build_system_prompt()
233 user_prompt = builder.build_user_prompt(
234 metadata=metadata,
235 include_lyrics=include_lyrics,
236 include_bpm=True,
237 )
238 return system_prompt, user_prompt
239
240
241 def build_analyze_prompt_part_a(
242 metadata: Optional[Dict[str, Any]] = None,
243 include_lyrics: bool = False,
244 label_level: int = 0,
245 ) -> tuple[str, str]:
246 """
247 Build the part-A analysis prompts (labels and basic information).
248 """
249 builder = PromptBuilder(label_level=label_level)
250 system_prompt = builder.build_system_prompt_part_a()
251 user_prompt = builder.build_user_prompt(
252 metadata=metadata,
253 include_lyrics=include_lyrics,
254 include_bpm=True,
255 )
256 return system_prompt, user_prompt
257
258
259 def build_analyze_prompt_part_b(
260 metadata: Optional[Dict[str, Any]] = None,
261 include_lyrics: bool = False,
262 label_level: int = 0,
263 ) -> tuple[str, str]:
264 """
265 Build the part-B analysis prompts (rhythm and visual description).
266 """
267 builder = PromptBuilder(label_level=label_level)
268 system_prompt = builder.build_system_prompt_part_b()
269 user_prompt = builder.build_user_prompt(
270 metadata=metadata,
271 include_lyrics=include_lyrics,
272 include_bpm=True,
273 )
274 return system_prompt, user_prompt
275
276
277 def build_lyrics_prompt() -> str:
278 """Build the lyrics-only prompt."""
279 builder = PromptBuilder()
280 return builder.build_lyrics_only_prompt()
281
282
283 # Backward compatibility: keep the original builder function
284 def build_user_prompt(
285 metadata: Optional[Dict[str, Any]] = None,
286 include_lyrics: bool = False,
287 label_level: int = 0,
288 ) -> str:
289 """Build the user prompt (compatibility wrapper)."""
290 builder = PromptBuilder(label_level=label_level)
291 return builder.build_user_prompt(
292 metadata=metadata,
293 include_lyrics=include_lyrics,
294 include_bpm=True,
295 )
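The placeholder substitution performed by `PromptBuilder.build_user_prompt` can be tried in isolation with the sketch below. It is a minimal illustration, not the module itself: `build_metadata_section` copies the method body shown above, while `TEMPLATE` and `render_user_prompt` are hypothetical stand-ins, since the real template text is loaded from `USER_PROMPT_FILE` and is not shown here.

```python
from typing import Any, Dict, Optional


def build_metadata_section(metadata: Optional[Dict[str, Any]] = None) -> str:
    """Copy of the PromptBuilder.build_metadata_section logic, for illustration."""
    if not metadata:
        return ""
    sections = ["## 音乐元数据"]
    for key, value in metadata.items():
        # Keys starting with "_" are internal and never reach the prompt.
        if key.startswith("_"):
            continue
        # Falsy or blank values (None, "", 0, "  ") are skipped.
        if value and str(value).strip():
            sections.append(f"- {key}: {value}")
    sections.append("")
    return "\n".join(sections)


# Hypothetical template text; the real one lives in USER_PROMPT_FILE.
TEMPLATE = "Analyze this audio.\n\n{{METADATA_SECTION}}\nReturn JSON only."


def render_user_prompt(template: str, metadata: Optional[Dict[str, Any]] = None) -> str:
    """Hypothetical helper: substitute the single placeholder the builder uses."""
    return template.replace("{{METADATA_SECTION}}", build_metadata_section(metadata))


if __name__ == "__main__":
    print(render_user_prompt(TEMPLATE, {"title": "Song", "_row": 7, "artist": ""}))
```

Because the filter is a plain truthiness check, a legitimate value of `0` (a numeric field, for example) is dropped along with empty strings; callers that need such values in the prompt would have to pass them as non-empty strings.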
1 # -*- coding: utf-8 -*-
2 """
3 Qwen (Tongyi Qianwen) music analyzer implementation.
4 """
5
6 import os
7 import time
8 import tempfile
9 import subprocess
10 import threading
11 import hashlib
12 import csv
13 from datetime import datetime
14 from pathlib import Path
15 import requests
16 import logging
17 from typing import Dict, Any, Optional, Tuple, List
18 from concurrent.futures import ThreadPoolExecutor
19
20 from .base import AudioAnalyzer
21 from .prompts import (
22 build_analyze_prompt,
23 build_lyrics_prompt,
24 )
25 from .audio_features import (
26 extract_audio_features,
27 extract_beat_timestamps,
28 extract_emotion_curve,
29 aggregate_emotion_by_segments,
30 )
31
32 # Use the project-wide configuration
33 from app.core.config import settings
34
35 logger = logging.getLogger(__name__)
36
37 MUSIC_MAPPING_HEADERS = [
38 "song_id",
39 "audio_file_name",
40 "audio_file_path",
41 "source_url",
42 "updated_at",
43 ]
44
45 MUSIC_MAPPING_HEADER_ALIASES = {
46 "song_id": ("song_id", "歌曲ID"),
47 "audio_file_name": ("audio_file_name", "音频文件名"),
48 "audio_file_path": ("audio_file_path", "音频文件路径"),
49 "source_url": ("source_url", "原始URL"),
50 "updated_at": ("updated_at", "更新时间"),
51 }
52
53
54 class QwenAnalyzer(AudioAnalyzer):
55 """Qwen (Tongyi Qianwen) music analyzer."""
56
57 def __init__(
58 self,
59 api_key: Optional[str] = None,
60 base_url: Optional[str] = None,
61 model: Optional[str] = None,
62 max_retries: int = 3,
63 ):
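64 """
65 Initialize the Qwen analyzer.
66
67 Args:
68 api_key: API key (defaults to the QWEN_API_KEY environment variable).
69 base_url: API base URL (defaults to the value from the environment).
70 model: Model name (default: qwen3-omni-flash).
71 max_retries: Maximum number of retries.
72 Timeouts are not constructor arguments; they come from QWEN_TIMEOUT and QWEN_LYRICS_TIMEOUT in settings.
73 """
74 # Explicit arguments take precedence, then the project-wide settings
75 if api_key is None:
76 # Priority: QWEN_API_KEY -> QWEN_DASHSCOPE_API_KEY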
77 api_key = settings.QWEN_API_KEY or settings.QWEN_DASHSCOPE_API_KEY
78 self.api_key = api_key
79 self.base_url = (
80 base_url
81 or settings.QWEN_BASE_URL
82 or "https://dashscope.aliyuncs.com/compatible-mode/v1"
83 )
84 self.model = model or settings.QWEN_MODEL or "qwen3-omni-flash"
85 self.timeout = settings.QWEN_TIMEOUT or 15.0
86 self.lyrics_timeout = settings.QWEN_LYRICS_TIMEOUT or 90.0
87 self.max_retries = max_retries or settings.QWEN_MAX_RETRIES or 3
88
89 self._client = None
90 self._project_root = Path(__file__).resolve().parents[3]
91 self._music_dir = self._resolve_music_dir()
92 self._music_mapping_path = self._resolve_music_mapping_path()
93 self._mapping_lock = threading.Lock()
94 self._mapping_seen: set[tuple[str, str]] = self._load_existing_mapping_keys()
95
96 def _resolve_music_dir(self) -> Path:
97 raw_dir = str(getattr(settings, "MUSIC_DOWNLOAD_DIR", "music") or "music").strip()
98 path = Path(raw_dir)
99 if not path.is_absolute():
100 path = self._project_root / path
101 path.mkdir(parents=True, exist_ok=True)
102 return path
103
104 def _resolve_music_mapping_path(self) -> Path:
105 raw_file = str(
106 getattr(settings, "MUSIC_MAPPING_FILE", "music/music_file_mapping.csv")
107 or "music/music_file_mapping.csv"
108 ).strip()
109 path = Path(raw_file)
110 if not path.is_absolute():
111 path = self._project_root / path
112 path.parent.mkdir(parents=True, exist_ok=True)
113 return path
114
115 def _load_existing_mapping_keys(self) -> set[tuple[str, str]]:
116 if not self._music_mapping_path.exists():
117 return set()
118 seen: set[tuple[str, str]] = set()
119 try:
120 with open(self._music_mapping_path, "r", encoding="utf-8-sig", newline="") as f:
121 reader = csv.DictReader(f)
122 for row in reader:
123 song_id = self._get_mapping_value(row, "song_id")
124 file_path = self._get_mapping_value(row, "audio_file_path")
125 if file_path:
126 try:
127 file_path = str(Path(file_path).resolve())
128 except Exception:
129 pass
130 seen.add((song_id, file_path))
131 except Exception:
132 return set()
133 return seen
134
135 def _get_mapping_value(self, row: Dict[str, Any], field: str) -> str:
136 for alias in MUSIC_MAPPING_HEADER_ALIASES.get(field, (field,)):
137 value = row.get(alias)
138 if value is not None and str(value).strip():
139 return str(value).strip()
140 return ""
141
142 def _extract_song_id(self, metadata: Optional[Dict[str, Any]]) -> str:
143 if not metadata:
144 return ""
145 for key in ("歌曲ID", "song_id", "id", "track_id", "tmeid", "tmeID", "TMEID"):
146 value = metadata.get(key)
147 if value is not None and str(value).strip():
148 return str(value).strip()
149 return ""
150
151 def _sanitize_filename_part(self, value: str) -> str:
152 safe_chars = []
153 for ch in value:
154 if ch.isalnum() or ch in {"-", "_", "."}:
155 safe_chars.append(ch)
156 else:
157 safe_chars.append("_")
158 cleaned = "".join(safe_chars).strip("._")
159 return cleaned[:80] if cleaned else "unknown"
160
161 def _build_music_file_path(
162 self,
163 music_url: str,
164 ext: str,
165 metadata: Optional[Dict[str, Any]] = None,
166 ) -> Path:
167 song_id = self._extract_song_id(metadata)
168 song_part = self._sanitize_filename_part(song_id or "unknown")
169 url_hash = hashlib.md5(music_url.encode("utf-8")).hexdigest()[:12]
170 return self._music_dir / f"{song_part}_{url_hash}{ext}"
171
172 def _append_music_mapping(
173 self,
174 file_path: Path,
175 music_url: str,
176 metadata: Optional[Dict[str, Any]] = None,
177 ) -> None:
178 song_id = self._extract_song_id(metadata)
179 mapping_key = (song_id, str(file_path.resolve()))
180
181 with self._mapping_lock:
182 if mapping_key in self._mapping_seen:
183 return
184
185 write_header = not self._music_mapping_path.exists()
186 encoding = "utf-8-sig" if write_header else "utf-8"
187 with open(self._music_mapping_path, "a", encoding=encoding, newline="") as f:
188 writer = csv.DictWriter(
189 f,
190 fieldnames=MUSIC_MAPPING_HEADERS,
191 )
192 if write_header:
193 writer.writeheader()
194 writer.writerow(
195 {
196 "song_id": song_id,
197 "audio_file_name": file_path.name,
198 "audio_file_path": str(file_path.resolve()),
199 "source_url": music_url,
200 "updated_at": datetime.now().isoformat(timespec="seconds"),
201 }
202 )
203 self._mapping_seen.add(mapping_key)
204
205 def _is_persisted_music_file(self, file_path: str) -> bool:
206 try:
207 candidate = Path(file_path).resolve()
208 return candidate.parent == self._music_dir.resolve()
209 except Exception:
210 return False
211
212 def _get_client(self):
213 """Return the OpenAI-compatible client, creating it on first use."""
214 if self._client is None:
215 from openai import OpenAI
216
217 self._client = OpenAI(
218 api_key=self.api_key,
219 base_url=self.base_url,
220 timeout=self.timeout,
221 max_retries=0,
222 )
223 return self._client
224
225 def get_provider_name(self) -> str:
226 return "qwen"
227
228 def get_model_name(self) -> str:
229 return self.model
230
231 def _call_songformer(self, music_url: str) -> Optional[Dict]:
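232 """
233 Call the SongFormer service to get the song structure and climax points.
234
235 Args:
236 music_url: URL of the music file.
237
238 Returns:
239 The full data dict returned by SongFormer, or None if the service is not configured or the call fails.
240 """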
241 songformer_url = getattr(settings, "SONGFORMER_URL", None)
242 if not songformer_url:
243 print(" [Qwen] SongFormer URL 未配置,跳过高潮点分析")
244 return None
245
246 try:
247 print(f" [Qwen] 调用 SongFormer 服务...")
248 resp = requests.post(
249 songformer_url,
250 json={"url": music_url, "chorus_k": 3},
251 timeout=60,
252 )
253 resp.raise_for_status()
254 data = resp.json()
255 print(f" [Qwen] SongFormer 调用成功")
256 return data
257 except Exception as e:
258 print(f" [Qwen] SongFormer 调用失败: {e}")
259 return None
260
261 def _extract_climax_point(self, songformer_data: Optional[Dict]) -> str:
262 """
263 Extract the climax level from SongFormer data.
264
265 Args:
266 songformer_data: Data returned by SongFormer.
267
268 Returns:
269 str: "最强" (strongest), "强" (strong), or "".
270 """
271 if not songformer_data:
272 return ""
273
274 # First try the climax_points field (old format)
275 climax_points = songformer_data.get("climax_points", {})
276 if climax_points:
277 # Check for a strongest climax
278 if climax_points.get("strongest_climax"):
279 return "最强"
280 # Check for a strong climax
281 if climax_points.get("strong_climax"):
282 return "强"
283
284 # Otherwise fall back to the top_k_chorus field (new format)
285 top_k_chorus = songformer_data.get("top_k_chorus", [])
286 if isinstance(top_k_chorus, list) and len(top_k_chorus) > 0:
287 # Sort by score; the highest score is treated as the strongest climax
288 sorted_chorus = sorted(
289 [
290 c
291 for c in top_k_chorus
292 if isinstance(c, dict) and c.get("score") is not None
293 ],
294 key=lambda x: x.get("score", 0),
295 reverse=True,
296 )
297 if sorted_chorus:
298 # A top score above 7.0 counts as "最强" (strongest); otherwise "强" (strong)
299 highest_score = sorted_chorus[0].get("score", 0)
300 if highest_score > 7.0:
301 return "最强"
302 else:
303 return "强"
304
305 return ""
306
307 def _build_climax_points(self, songformer_data: Optional[Dict]) -> Dict[str, Any]:
308 """
309 Build the full climax_points object from SongFormer data.
310
311 Args:
312 songformer_data: Data returned by SongFormer.
313
314 Returns:
315 A dict containing strong_climax and strongest_climax.
316 """
317 if not songformer_data:
318 return {
319 "strong_climax": None,
320 "strongest_climax": None,
321 "analysis_time": 0.0,
322 }
323
324 # First try the climax_points field (old format)
325 climax_points = songformer_data.get("climax_points", {})
326 if climax_points and (
327 climax_points.get("strong_climax") or climax_points.get("strongest_climax")
328 ):
329 return {
330 "strong_climax": climax_points.get("strong_climax"),
331 "strongest_climax": climax_points.get("strongest_climax"),
332 "analysis_time": climax_points.get("analysis_time", 0.0),
333 }
334
335 # Otherwise build from the top_k_chorus field (new format)
336 top_k_chorus = songformer_data.get("top_k_chorus", [])
337 segments = songformer_data.get("segments", [])
338
339 if isinstance(top_k_chorus, list) and len(top_k_chorus) > 0:
340 # Sort by score
341 sorted_chorus = sorted(
342 [
343 c
344 for c in top_k_chorus
345 if isinstance(c, dict) and c.get("score") is not None
346 ],
347 key=lambda x: x.get("score", 0),
348 reverse=True,
349 )
350
351 if sorted_chorus:
352 # The highest score becomes strongest_climax
353 highest = sorted_chorus[0]
354 highest_score = highest.get("score", 0)
355
356 # Find the matching section label
357 start_time = highest.get("start", 0)
358 section_label = "chorus"
359 for seg in segments:
360 if isinstance(seg, dict):
361 seg_start = seg.get("start", 0)
362 seg_end = seg.get("end", 0)
363 if seg_start <= start_time < seg_end:
364 section_label = seg.get("label", "chorus")
365 break
366
367 strongest_climax = {
368 "time": start_time,
369 "intensity": "strongest",
370 "section_label": section_label,
371 "reason": f"Highest chorus score: {highest_score:.2f}",
372 }
373
374 # The second-highest score becomes strong_climax when present (the score gap is not checked)
375 strong_climax = None
376 if len(sorted_chorus) > 1:
377 second = sorted_chorus[1]
378 second_score = second.get("score", 0)
379 second_start = second.get("start", 0)
380
381 # Find the matching section label
382 second_section_label = "chorus"
383 for seg in segments:
384 if isinstance(seg, dict):
385 seg_start = seg.get("start", 0)
386 seg_end = seg.get("end", 0)
387 if seg_start <= second_start < seg_end:
388 second_section_label = seg.get("label", "chorus")
389 break
390
391 strong_climax = {
392 "time": second_start,
393 "intensity": "strong",
394 "section_label": second_section_label,
395 "reason": f"Second highest chorus score: {second_score:.2f}",
396 }
397
398 return {
399 "strong_climax": strong_climax,
400 "strongest_climax": strongest_climax,
401 "analysis_time": 0.0,
402 }
403
404 return {
405 "strong_climax": None,
406 "strongest_climax": None,
407 "analysis_time": 0.0,
408 }
409
410 def analyze(
411 self,
412 metadata: Dict[str, Any],
413 music_url: str,
414 extract_lyrics: bool = False,
415 label_level: int = 0,
416 ) -> Optional[Dict[str, Any]]:
417 """
418 Analyze a piece of music.
419
420 Args:
421 metadata: Music metadata.
422 music_url: URL of the music file.
423 extract_lyrics: Whether to transcribe lyrics.
424 label_level: Label level.
425
426 Returns:
427 The analysis result dict.
428 """
429 client = self._get_client()
430
431 light_mode = bool(getattr(settings, "MUSIC_ANALYZE_LIGHT_MODE", True))
432 songformer_data = None if light_mode else self._call_songformer(music_url)
433
434 # Download the audio and extract local features
435 local_features = {}
436 tmp_file_path = None
437 try:
438 if light_mode:
439 print(" [Qwen] 轻量模式: 仅提取 BPM")
440 tmp_file_path, _ = self._download_audio(music_url, metadata=metadata)
441 beat_info = extract_beat_timestamps(tmp_file_path)
442 local_features = {"bpm": round(beat_info.tempo)}
443 print(f" [Qwen] 本地特征: BPM={local_features.get('bpm')}")
444 else:
445 print(f" [Qwen] 下载音频并提取本地特征...")
446 tmp_file_path, _ = self._download_audio(music_url, metadata=metadata)
447
448 # Take the section structure from SongFormer for emotion aggregation
449 segments = songformer_data.get("segments") if songformer_data else None
450 local_features = self._extract_local_features(tmp_file_path, segments=segments)
451
452 # Extract the climax level from the SongFormer data
453 climax_point = self._extract_climax_point(songformer_data)
454 local_features["climax_point"] = climax_point
455
456 # Build the full climax_points object
457 climax_points = self._build_climax_points(songformer_data)
458 local_features["climax_points"] = climax_points
459
460 print(
461 f" [Qwen] 本地特征: BPM={local_features.get('bpm')}, "
462 f"段落情绪数={len(local_features.get('segment_emotions', []))}, "
463 f"高潮点={climax_point}"
464 )
465 except Exception as e:
466 print(f" [Qwen] 本地特征提取失败,将使用LLM估算值: {e}")
467 finally:
468 # Clean up the temporary file
469 if (
470 tmp_file_path
471 and os.path.exists(tmp_file_path)
472 and not self._is_persisted_music_file(tmp_file_path)
473 ):
474 try:
475 os.unlink(tmp_file_path)
476 except OSError:
477 pass
478
479 # Run the LLM analysis
480 if extract_lyrics:
481 result = self._analyze_with_lyrics(client, metadata, music_url, label_level)
482 else:
483 result = self._analyze_basic(client, metadata, music_url, label_level)
484
485 # Merge the local features into the result
486 if result and local_features:
487 # Locally extracted values override the LLM's
488 result.update(local_features)
489
490 return result
491
492 def _analyze_basic(
493 self,
494 client,
495 metadata: Dict[str, Any],
496 music_url: str,
497 label_level: int = 0,
498 ) -> Optional[Dict[str, Any]]:
499 """Basic analysis (no lyrics, single-pass label analysis)."""
500 # Extract the song ID for error tracing
501 song_id = self._extract_song_id(metadata)
502 print(f" [Qwen] 分析音频: 歌曲ID={song_id}")
503
504 system_prompt, user_prompt = build_analyze_prompt(
505 metadata=metadata,
506 include_lyrics=False,
507 label_level=label_level,
508 )
509
510 prompt = self._build_dashscope_prompt(system_prompt, user_prompt)
511 response = self._call_with_retry_dashscope(music_url, prompt, song_id=song_id, metadata=metadata)
512 if response is None:
513 return None
514
515 raw_content = response.get("content", "")
516 parsed = self._parse_response(raw_content)
517 if parsed is None:
518 return None
519 if isinstance(parsed, list):
520 if parsed and isinstance(parsed[0], dict):
521 parsed = parsed[0]
522 else:
523 return None
524 if not isinstance(parsed, dict):
525 return None
526
527 return self._normalize_result(parsed, self.model, response.get("usage"))
528
529 def _download_audio(
530 self, music_url: str, metadata: Optional[Dict[str, Any]] = None
531 ) -> Tuple[str, str]:
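532 """
533 Download the audio file into the music directory (named from the song ID and a URL hash; cached files are reused).
534
535 Args:
536 music_url: Audio URL.
537 metadata: Music metadata (used to extract the song ID for the mapping table).
538
539 Returns:
540 (local file path, file extension)
541 """
542 # Determine the file extension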
543 ext = ".mp3"
544 if "." in music_url:
545 url_ext = music_url.split(".")[-1].split("?")[0].lower()
546 if url_ext in ["mp3", "wav", "flac", "aac", "m4a", "ogg"]:
547 ext = f".{url_ext}"
548
549 target_path = self._build_music_file_path(music_url, ext, metadata=metadata)
550
551 if not target_path.exists():
552 response = requests.get(music_url, timeout=60)
553 response.raise_for_status()
554 with open(target_path, "wb") as f:
555 f.write(response.content)
556 print(f" [Qwen] 音频已保存: {target_path}")
557
558 self._append_music_mapping(target_path, music_url, metadata=metadata)
559 return str(target_path), ext
560
561 def _extract_local_features(
562 self,
563 audio_path: str,
564 segments: Optional[List[Dict[str, Any]]] = None,
565 ) -> Dict[str, Any]:
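566 """
567 Extract local audio features.
568
569 Args:
570 audio_path: Path to the local audio file.
571 segments: Section structure returned by SongFormer (optional), used to aggregate the emotion curve.
572
573 Returns:
574 A dict containing the BPM, beat timestamps, and emotion curve.
575 """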
576 try:
577 features = extract_audio_features(audio_path)
578
579 # Beat detection
580 beat_info = extract_beat_timestamps(audio_path)
581
582 # Emotion curve
583 emotion_curve = extract_emotion_curve(audio_path)
584
585 # beat_info.tempo is corrected for metrical level, so it is more accurate than features.tempo
586 result = {
587 "bpm": round(beat_info.tempo),
588 # Beat information
589 "beat_timestamps": beat_info.beat_timestamps,
590 "downbeat_timestamps": beat_info.downbeat_timestamps,
591 "beat_intervals": beat_info.beat_intervals,
592 }
593
594 # With a section structure, return emotion data aggregated per section
595 if segments:
596 segment_emotions = aggregate_emotion_by_segments(emotion_curve, segments)
597 result["segment_emotions"] = [
598 {
599 "start": se.start,
600 "end": se.end,
601 "label": se.label,
602 "intensity": se.intensity,
603 "energy": se.energy,
604 "valence": se.valence,
605 "arousal": se.arousal,
606 "trend": se.trend,
607 }
608 for se in segment_emotions
609 ]
610 else:
611 # Without a section structure, return the raw emotion curve
612 result["emotion_curve"] = {
613 "timestamps": emotion_curve.timestamps,
614 "energy_values": emotion_curve.energy_values,
615 "valence_values": emotion_curve.valence_values,
616 "arousal_values": emotion_curve.arousal_values,
617 "values": emotion_curve.smoothed_curve,
618 }
619
620 return result
621 except Exception as e:
622 print(f" [Qwen] 本地特征提取失败: {e}")
623 return {}
624
625 def _analyze_with_lyrics(
626 self,
627 client,
628 metadata: Dict[str, Any],
629 music_url: str,
630 label_level: int = 0,
631 ) -> Optional[Dict[str, Any]]:
632 """Analysis with lyrics (single-pass label analysis run concurrently with lyrics transcription)."""
633 # Extract the song ID for error tracing
634 song_id = self._extract_song_id(metadata)
635 print(f" [Qwen] 分析音频: 歌曲ID={song_id}")
636
637 system_prompt, user_prompt = build_analyze_prompt(
638 metadata=metadata,
639 include_lyrics=False,
640 label_level=label_level,
641 )
642
643 prompt = self._build_dashscope_prompt(system_prompt, user_prompt)
644
645 lyrics_prompt = build_lyrics_prompt()
646
647 messages_lyrics = self._build_messages(
648 "请识别这段音频中的歌词内容", lyrics_prompt, music_url
649 )
650
651 print(" [Qwen] 并发执行基础标签分析和歌词识别...")
652 start_time = time.time()
653
654 result_main: Optional[Dict[str, Any]] = None
655 usage_main: Optional[Dict[str, Any]] = None
656 response_lyrics = None
657 timing: Dict[str, float] = {}
658
659 def _timed_call_dashscope(prompt_text: str) -> tuple[Optional[Dict], float]:
660 call_start = time.time()
661 resp = self._call_with_retry_dashscope(music_url, prompt_text, song_id=song_id, metadata=metadata)
662 return resp, round(time.time() - call_start, 2)
663
664 futures = {}
665 with ThreadPoolExecutor(max_workers=2) as executor:
666 futures[executor.submit(_timed_call_dashscope, prompt)] = "main"
667 futures[executor.submit(self._timed_call_openai, client, messages_lyrics)] = "lyrics"
668
669 for future in futures:
670 part = futures[future]
671 response, part_elapsed = future.result()
672 if part == "lyrics":
673 timing["lyrics"] = part_elapsed
674 response_lyrics = response
675 continue
676 timing["analysis"] = part_elapsed
677 if response is None:
678 continue
679 raw_content = response.get("content", "")
680 parsed = self._parse_response(raw_content)
681 if parsed is None:
682 continue
683 if isinstance(parsed, list):
684 if parsed and isinstance(parsed[0], dict):
685 parsed = parsed[0]
686 else:
687 continue
688 if not isinstance(parsed, dict):
689 continue
690 result_main = parsed
691 usage_main = response.get("usage")
692
693 elapsed = time.time() - start_time
694 print(f" [Qwen] 并发调用完成,总耗时: {elapsed:.2f}s")
695
696 if result_main is None:
697 return None
698 if not isinstance(result_main, dict):
699 return None
700
701 result: Dict[str, Any] = dict(result_main)
702
703 # Handle the lyrics transcription result
704 if response_lyrics:
705 raw_lyrics = response_lyrics.get("content", "")
706 lyrics_result = self._parse_response(raw_lyrics)
707 if isinstance(lyrics_result, list):
708 if lyrics_result and isinstance(lyrics_result[0], dict):
709 lyrics_result = lyrics_result[0]
710 if isinstance(lyrics_result, dict) and "lyrics" in lyrics_result:
711 result["lyrics"] = lyrics_result["lyrics"]
712 result["_timing"] = timing
713
714 # Merge token usage information
715 usage: Dict[str, Any] = {}
716 if usage_main:
717 usage.update(usage_main)
718 if response_lyrics and response_lyrics.get("usage"):
719 usage_lyrics = response_lyrics["usage"]
720 usage = {
721 "prompt_tokens": usage.get("prompt_tokens", 0)
722 + usage_lyrics.get("prompt_tokens", 0),
723 "completion_tokens": usage.get("completion_tokens", 0)
724 + usage_lyrics.get("completion_tokens", 0),
725 "total_tokens": usage.get("total_tokens", 0)
726 + usage_lyrics.get("total_tokens", 0),
727 }
728
729 result["_token_info_parts"] = {
730 "main": usage_main,
731 "lyrics": response_lyrics.get("usage") if response_lyrics else None,
732 }
733
734 return self._normalize_result(result, self.model, usage)
735
736 def analyze_lyrics_only(
737 self,
738 metadata: Dict[str, Any],
739 music_url: str,
740 ) -> Optional[Dict[str, Any]]:
741 """Run lyrics transcription only, without the basic label analysis (asynchronous ASR task)."""
742 backend = (
743 str(
744 os.getenv("MUSIC_LYRICS_ASR_BACKEND")
745 or getattr(settings, "MUSIC_LYRICS_ASR_BACKEND", "funasr")
746 )
747 .strip()
748 .lower()
749 )
750
751 if backend == "whisper":
752 analyze_fn = self._analyze_lyrics_only_whisper
753 elif backend in {"omni", "qwen-omni", "qwen_omni"}:
754 # qwen-omni: at most 3 requests in a single pass, then fall back directly to funasr
755 omni_result = self._analyze_lyrics_only_qwen_omni(music_url)
756 if omni_result:
757 return omni_result
758 logger.warning(
759 "qwen-omni 歌词识别失败,降级到 funasr (lyrics_timeout=%ss)",
760 self.lyrics_timeout,
761 )
762
763 fallback_retry_count = 1
764 fallback_retry_delay_seconds = 2.0
765 for attempt in range(1, fallback_retry_count + 2):
766 fallback_result = self._analyze_lyrics_only_funasr(music_url)
767 if fallback_result:
768 logger.info(
769 "funasr 降级成功: attempt=%s/%s",
770 attempt,
771 fallback_retry_count + 1,
772 )
773 return fallback_result
774
775 if attempt <= fallback_retry_count:
776 logger.warning(
777 "funasr 降级失败,%s 秒后重试 (%s/%s)",
778 fallback_retry_delay_seconds,
779 attempt,
780 fallback_retry_count,
781 )
782 time.sleep(fallback_retry_delay_seconds)
783
784 logger.warning("funasr 降级失败,继续降级到 whisper")
785 whisper_result = self._analyze_lyrics_only_whisper(music_url)
786 if whisper_result:
787 logger.info("whisper 降级成功")
788 return whisper_result
789
790 logger.error("歌词识别降级链全部失败: qwen-omni -> funasr -> whisper")
791 return None
792 elif backend in {"fun", "funasr", "fun-asr"}:
793 analyze_fn = self._analyze_lyrics_only_funasr
794 else:
795 logger.error(
796 "不支持的歌词识别后端: %s,仅支持 whisper/funasr/qwen-omni",
797 backend,
798 )
799 return None
800
801 retry_count = 2
802 retry_delay_seconds = 2.0
803 for attempt in range(1, retry_count + 2):
804 result = analyze_fn(music_url)
805 if result:
806 return result
807 if attempt <= retry_count:
808 logger.warning(
809 "歌词识别失败,%s 秒后重试 (%d/%d): backend=%s",
810 retry_delay_seconds,
811 attempt,
812 retry_count,
813 backend,
814 )
815 time.sleep(retry_delay_seconds)
816 return None
817
818 def _analyze_lyrics_only_qwen_omni(self, music_url: str) -> Optional[Dict[str, Any]]:
819 """qwen-omni V2 lyrics transcription flow."""
820 client = self._get_client()
821 logger.info(
822 "开始 qwen-omni 歌词识别: timeout=%ss, max_retries=%s",
823 self.lyrics_timeout,
824 3,
825 )
826 lyrics_prompt = build_lyrics_prompt()
827 messages = self._build_messages(
828 "请识别这段音频中的歌词内容",
829 lyrics_prompt,
830 music_url,
831 )
832
833 response = self._call_with_retry(client, messages, max_retries=3)
834 if response is None:
835 return None
836
837 parsed = self._parse_response(response.get("content", ""))
838 payload: Any = parsed
839 if isinstance(parsed, dict):
840 payload = (
841 parsed.get("lyrics")
842 or parsed.get("lyric")
843 or parsed.get("歌词")
844 or parsed
845 )
846
847 lyrics = self._convert_qwen_omni_payload_to_lyrics(payload)
848
849 return {
850 "lyrics": lyrics,
851 "_model": self.model,
852 "_token_info": response.get("usage"),
853 "_transcription_url": None,
854 "_asr_task_id": None,
855 "_asr_backend": "qwen-omni",
856 }
857
858 def _convert_qwen_omni_payload_to_lyrics(self, payload: Any) -> List[Dict[str, Any]]:
859 """Normalize the lyric structure returned by qwen-omni into [{time, text}]."""
860 if payload is None:
861 return []
862
863 if isinstance(payload, str):
864 lines = [line.strip() for line in payload.splitlines() if line.strip()]
865 return [{"time": None, "text": line} for line in lines]
866
867 if isinstance(payload, dict):
868 candidate = (
869 payload.get("lyrics")
870 or payload.get("lines")
871 or payload.get("歌词")
872 or payload.get("lyric")
873 )
874 return self._convert_qwen_omni_payload_to_lyrics(candidate)
875
876 if isinstance(payload, list):
877 lyrics: List[Dict[str, Any]] = []
878 for item in payload:
879 if isinstance(item, str):
880 line = item.strip()
881 if line:
882 lyrics.append({"time": None, "text": line})
883 continue
884
885 if not isinstance(item, dict):
886 continue
887
888 text = item.get("text") or item.get("lyric") or item.get("歌词")
889 if not isinstance(text, str):
890 text = str(text) if text is not None else ""
891 text = text.strip()
892 if not text:
893 continue
894
895 time_str = item.get("time")
896 if not isinstance(time_str, str):
897 time_str = None
898 lyrics.append({"time": time_str, "text": text})
899 return lyrics
900
901 return []
902
903 def _analyze_lyrics_only_whisper(self, music_url: str) -> Optional[Dict[str, Any]]:
904 """whisper-1 lyrics transcription flow (91 API)."""
905 try:
906 from dotenv import load_dotenv
907
908 load_dotenv()
909 except Exception:
910 pass
911
912 api_key = (os.getenv("API_KEY_whisper") or os.getenv("91API_KEY") or "").strip()
913 if not api_key:
914 logger.error("whisper 调用失败: 缺少环境变量 API_KEY_whisper/91API_KEY")
915 return None
916
917 api_url = os.getenv(
918 "WHISPER_API_URL",
919 "https://xuedingmao.top/v1/audio/transcriptions",
920 ).strip()
921 headers = {"Authorization": f"Bearer {api_key}"}
922
923 tmp_file_path = None
924 upload_file_path = None
925 ext = ".mp3"
926 try:
927 tmp_file_path, ext = self._download_audio(music_url, metadata=None)
928 upload_file_path = tmp_file_path
929 upload_ext = ext
930 if ext.lower() == ".flac":
931 converted_wav = self._convert_audio_to_wav_for_whisper(tmp_file_path)
932 if converted_wav:
933 upload_file_path = converted_wav
934 upload_ext = ".wav"
935 logger.info("whisper 上传文件已从 flac 转换为 wav")
936
937 filename = f"audio{upload_ext}"
938 print(f"下载完成:{filename}")
939
940 content_type = "audio/wav" if upload_ext == ".wav" else "audio/mpeg"
941
942 with open(upload_file_path, "rb") as audio_file:
943 files = {
944 "file": (filename, audio_file, content_type),
945 }
946 data = {
947 "model": "whisper-1",
948 "response_format": "verbose_json",
949 "timestamp_granularities": ["segment"],
950 "prompt": "没有歌词的片段用...代替,时间戳需要精准与每句歌词进行对应,对于纯音乐直接输出‘纯音乐,禁止输出歌名,作词/作曲等元数据,仅输出歌词与时间戳’",
951 }
952 response = requests.post(
953 api_url,
954 headers=headers,
955 data=data,
956 files=files,
957 timeout=300,
958 )
959 if response.status_code >= 400:
960 logger.error(
961 "whisper API 返回错误: status=%s, body=%s",
962 response.status_code,
963 response.text,
964 )
965 response.raise_for_status()
966 payload = response.json()
967 except Exception as exc:
968 logger.exception("whisper API 调用失败: %s", exc)
969 return None
970 finally:
971 if (
972 tmp_file_path
973 and os.path.exists(tmp_file_path)
974 and not self._is_persisted_music_file(tmp_file_path)
975 ):
976 try:
977 os.unlink(tmp_file_path)
978 except Exception:
979 pass
980 if (
981 upload_file_path
982 and upload_file_path != tmp_file_path
983 and os.path.exists(upload_file_path)
984 ):
985 try:
986 os.unlink(upload_file_path)
987 except Exception:
988 pass
989
990 lyrics = self._convert_whisper_payload_to_lyrics(payload)
991 return {
992 "lyrics": lyrics,
993 "_model": "whisper-1",
994 "_token_info": None,
995 "_transcription_url": None,
996 "_asr_task_id": None,
997 "_asr_backend": "whisper",
998 }
999
1000 def _convert_whisper_payload_to_lyrics(
1001 self, payload: Any
1002 ) -> List[Dict[str, Any]]:
1003 """Convert a whisper API response into lyrics: [{time, text}]."""
1004 if not isinstance(payload, dict):
1005 return []
1006
1007 segments = payload.get("segments")
1008 if isinstance(segments, list):
1009 lyrics: List[Dict[str, Any]] = []
1010 for seg in segments:
1011 if not isinstance(seg, dict):
1012 continue
1013 text = seg.get("text")
1014 if not isinstance(text, str):
1015 continue
1016 text = text.strip()
1017 if not text:
1018 continue
1019
1020 start = seg.get("start")
1021 if not isinstance(start, (int, float)):
1022 # Some APIs return begin_time (in milliseconds) instead of start
1023 begin_time = seg.get("begin_time")
1024 if isinstance(begin_time, (int, float)):
1025 start = float(begin_time) / 1000.0
1026
1027 time_str = None
1028 if isinstance(start, (int, float)):
1029 try:
1030 time_str = self._format_asr_time_ms(float(start) * 1000)
1031 except (TypeError, ValueError, OverflowError):
1032 time_str = None
1033 lyrics.append({"time": time_str, "text": text})
1034 if lyrics:
1035 return lyrics
1036
1037 text = payload.get("text")
1038 if isinstance(text, str) and text.strip():
1039 return [{"time": None, "text": text.strip()}]
1040 return []
1041
1042 def _convert_audio_to_wav_for_whisper(self, source_audio_path: str) -> Optional[str]:
1043 """
1044 Convert the audio to 16 kHz mono 16-bit PCM WAV, which whisper handles more reliably.
1045 """
1046 try:
1047 with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as wav_tmp:
1048 wav_path = wav_tmp.name
1049
1050 cmd = [
1051 "ffmpeg",
1052 "-y",
1053 "-i",
1054 source_audio_path,
1055 "-acodec",
1056 "pcm_s16le",
1057 "-ac",
1058 "1",
1059 "-ar",
1060 "16000",
1061 wav_path,
1062 ]
1063 subprocess.run(cmd, check=True, capture_output=True, text=True)
1064 return wav_path
1065 except Exception as exc:
1066 logger.warning("Audio-to-WAV conversion failed; continuing with the original file: %s", exc)
1067 return None
1068
1069 def _analyze_lyrics_only_funasr(self, music_url: str) -> Optional[Dict[str, Any]]:
1070 """Asynchronous ASR flow using the fun-asr model through the DashScope SDK."""
1071 try:
1072 from http import HTTPStatus
1073 import dashscope
1074 from dashscope.audio.asr import Transcription
1075 except Exception as exc:
1076 logger.exception("Failed to import dashscope.audio.asr.Transcription: %s", exc)
1077 return None
1078
1079 api_key = self._get_dashscope_api_key()
1080 if not api_key:
1081 logger.error("funasr call failed: missing DashScope API key")
1082 return None
1083
1084 asr_model = getattr(settings, "DASHSCOPE_FUNASR_MODEL", "fun-asr")
1085 dashscope.base_http_api_url = getattr(
1086 settings,
1087 "DASHSCOPE_BASE_HTTP_API_URL",
1088 "https://dashscope.aliyuncs.com/api/v1",
1089 )
1090 dashscope.api_key = api_key
1091 poll_interval = float(getattr(settings, "DASHSCOPE_ASR_POLL_INTERVAL", 1.0))
1092 poll_timeout = float(getattr(settings, "DASHSCOPE_ASR_POLL_TIMEOUT", 120.0))
1093
1094 try:
1095 task_resp = Transcription.async_call(
1096 model=asr_model,
1097 file_urls=[music_url],
1098 )
1099 except Exception as exc:
1100 logger.exception("funasr async_call failed: %s", exc)
1101 return None
1102
1103 task_id = self._extract_task_id_from_asr_response(task_resp)
1104 latest_resp: Any = task_resp
1105
1106 deadline = time.time() + poll_timeout
1107 while time.time() < deadline:
1108 task_status = self._extract_task_status_from_asr_response(latest_resp)
1109 if task_status == "SUCCEEDED":
1110 break
1111 if task_status in {"FAILED", "CANCELED"}:
1112 logger.error(
1113 "funasr task failed: task_id=%s, status=%s",
1114 task_id,
1115 task_status,
1116 )
1117 return None
1118 try:
1119 latest_resp = Transcription.fetch(
1120 task=latest_resp,
1121 )
1122 except Exception as exc:
1123 logger.exception("funasr fetch failed: %s", exc)
1124 return None
1125 time.sleep(poll_interval)
1126 else:
1127 logger.error("funasr polling timed out: task_id=%s", task_id)
1128 return None
1129
1130 status_code = getattr(latest_resp, "status_code", None)
1131 if status_code is not None and status_code != HTTPStatus.OK:
1132 logger.error(
1133 "funasr returned a non-OK status: task_id=%s, status_code=%s",
1134 task_id,
1135 status_code,
1136 )
1137 return None
1138
1139 transcription_url = self._extract_transcription_url_from_asr_response(latest_resp)
1140 if not transcription_url:
1141 logger.error("funasr result is missing transcription_url: task_id=%s", task_id)
1142 return None
1143
1144 transcript_data = self._fetch_asr_transcription(transcription_url)
1145 if not transcript_data:
1146 return None
1147
1148 lyrics = self._convert_asr_transcription_to_lyrics(transcript_data)
1149 token_info = self._extract_usage_from_asr_response(latest_resp)
1150
1151 return {
1152 "lyrics": lyrics,
1153 "_model": asr_model,
1154 "_token_info": token_info,
1155 "_transcription_url": transcription_url,
1156 "_asr_task_id": task_id,
1157 "_asr_backend": "funasr",
1158 }
1159
1160 def _submit_asr_transcription_task(self, music_url: str) -> Optional[str]:
1161 """Submit an asynchronous DashScope ASR task and return its task_id."""
1162 api_key = self._get_dashscope_api_key()
1163 if not api_key:
1164 logger.error("Failed to submit ASR task: missing DashScope API key")
1165 return None
1166
1167 submit_url = getattr(
1168 settings,
1169 "DASHSCOPE_ASR_SUBMIT_URL",
1170 "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription",
1171 )
1172 asr_model = getattr(settings, "DASHSCOPE_ASR_MODEL", "qwen3-asr-flash-filetrans")
1173
1174 headers = {
1175 "Authorization": f"Bearer {api_key}",
1176 "Content-Type": "application/json",
1177 "X-DashScope-Async": "enable",
1178 }
1179 payload = {
1180 "model": asr_model,
1181 "input": {"file_url": music_url},
1182 "parameters": {
1183 "channel_id": [0],
1184 "enable_itn": False,
1185 "enable_words": False,
1186 },
1187 }
1188
1189 try:
1190 response = requests.post(
1191 submit_url,
1192 headers=headers,
1193 json=payload,
1194 timeout=self.timeout,
1195 )
1196 response.raise_for_status()
1197 data = response.json()
1198 except Exception as exc:
1199 logger.exception("Error while submitting ASR task: %s", exc)
1200 return None
1201
1202 output = data.get("output") if isinstance(data, dict) else None
1203 if not isinstance(output, dict):
1204 logger.error("Failed to submit ASR task: response has no output field")
1205 return None
1206
1207 task_id = output.get("task_id")
1208 if not isinstance(task_id, str) or not task_id.strip():
1209 logger.error("Failed to submit ASR task: response has no task_id")
1210 return None
1211 return task_id.strip()
1212
1213 def _poll_asr_task_result(self, task_id: str) -> Optional[Dict[str, Any]]:
1214 """Poll a DashScope task until it reaches a terminal state or the poll timeout expires."""
1215 api_key = self._get_dashscope_api_key()
1216 if not api_key:
1217 logger.error("Failed to poll ASR task: missing DashScope API key")
1218 return None
1219
1220 task_base_url = getattr(
1221 settings,
1222 "DASHSCOPE_TASK_STATUS_BASE_URL",
1223 "https://dashscope.aliyuncs.com/api/v1/tasks",
1224 ).rstrip("/")
1225 task_url = f"{task_base_url}/{task_id}"
1226
1227 headers = {
1228 "Authorization": f"Bearer {api_key}",
1229 "X-DashScope-Async": "enable",
1230 "Content-Type": "application/json",
1231 }
1232
1233 poll_interval = float(getattr(settings, "DASHSCOPE_ASR_POLL_INTERVAL", 1.0))
1234 poll_timeout = float(getattr(settings, "DASHSCOPE_ASR_POLL_TIMEOUT", 120.0))
1235 deadline = time.time() + poll_timeout
1236
1237 while time.time() < deadline:
1238 try:
1239 response = requests.get(task_url, headers=headers, timeout=self.timeout)
1240 response.raise_for_status()
1241 data = response.json()
1242 except Exception as exc:
1243 logger.exception("Error while polling ASR task: task_id=%s, error=%s", task_id, exc)
1244 return None
1245
1246 output = data.get("output") if isinstance(data, dict) else None
1247 task_status = output.get("task_status") if isinstance(output, dict) else None
1248 if task_status == "SUCCEEDED":
1249 return data
1250 if task_status in {"FAILED", "CANCELED"}:
1251 logger.error(
1252 "ASR task failed: task_id=%s, status=%s, data=%s",
1253 task_id,
1254 task_status,
1255 data,
1256 )
1257 return None
1258
1259 time.sleep(poll_interval)
1260
1261 logger.error("ASR task polling timed out: task_id=%s", task_id)
1262 return None
1263
1264 def _fetch_asr_transcription(self, transcription_url: str) -> Optional[Dict[str, Any]]:
1265 """Download the transcription result JSON that transcription_url points to."""
1266 try:
1267 response = requests.get(transcription_url, timeout=self.timeout)
1268 response.raise_for_status()
1269 data = response.json()
1270 return data if isinstance(data, dict) else None
1271 except Exception as exc:
1272 logger.exception("Failed to download the ASR transcription result: %s", exc)
1273 return None
1274
1275 def _convert_asr_transcription_to_lyrics(
1276 self, transcript_data: Dict[str, Any]
1277 ) -> List[Dict[str, Any]]:
1278 """Convert an ASR result into lyrics: [{time, text}]."""
1279 transcripts = transcript_data.get("transcripts")
1280 if not isinstance(transcripts, list):
1281 return []
1282
1283 lyrics: List[Dict[str, Any]] = []
1284 for transcript in transcripts:
1285 if not isinstance(transcript, dict):
1286 continue
1287
1288 sentences = transcript.get("sentences")
1289 if not isinstance(sentences, list):
1290 continue
1291
1292 for sentence in sentences:
1293 if not isinstance(sentence, dict):
1294 continue
1295
1296 text = sentence.get("text")
1297 if not isinstance(text, str):
1298 continue
1299 text = text.strip()
1300 if not text:
1301 continue
1302
1303 begin_time = sentence.get("begin_time")
1304 time_str = (
1305 self._format_asr_time_ms(begin_time)
1306 if isinstance(begin_time, (int, float))
1307 else None
1308 )
1309 lyrics.append(
1310 {
1311 "time": time_str,
1312 "text": text,
1313 }
1314 )
1315
1316 return lyrics
1317
1318 @staticmethod
1319 def _format_asr_time_ms(ms_value: float) -> str:
1320 """Convert milliseconds to an mm:ss.xxx string; negative values clamp to 00:00.000."""
1321 total_ms = int(max(0, ms_value))
1322 minutes = total_ms // 60000
1323 seconds = (total_ms % 60000) // 1000
1324 milliseconds = total_ms % 1000
1325 return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
1326
1327 def _get_dashscope_api_key(self) -> Optional[str]:
1328 """Return the DashScope API key used for ASR: the first non-empty candidate below."""
1329 return (
1330 self.api_key
1331 or settings.QWEN_DASHSCOPE_API_KEY
1332 or settings.QWEN_API_KEY
1333 or os.getenv("DASHSCOPE_API_KEY")
1334 or os.getenv("QWEN_DASHSCOPE_API_KEY")
1335 or os.getenv("QWEN_API_KEY")
1336 )
1337
1338 @staticmethod
1339 def _as_dict(response_obj: Any) -> Dict[str, Any]:
1340 """Convert an SDK response object to a dict on a best-effort basis."""
1341 if isinstance(response_obj, dict):
1342 return response_obj
1343 if response_obj is None:
1344 return {}
1345
1346 for attr in ("to_dict", "as_dict", "dict"):
1347 fn = getattr(response_obj, attr, None)
1348 if callable(fn):
1349 try:
1350 value = fn()
1351 if isinstance(value, dict):
1352 return value
1353 except Exception:
1354 pass
1355
1356 data: Dict[str, Any] = {}
1357 for key in ("request_id", "output", "usage"):
1358 val = getattr(response_obj, key, None)
1359 if val is not None:
1360 if key in ("output", "usage") and not isinstance(val, dict):
1361 nested = QwenAnalyzer._as_dict(val)
1362 data[key] = nested if nested else val
1363 else:
1364 data[key] = val
1365 return data
1366
1367 def _extract_task_id_from_asr_response(self, response_obj: Any) -> Optional[str]:
1368 data = self._as_dict(response_obj)
1369 output = data.get("output")
1370 if isinstance(output, dict):
1371 task_id = output.get("task_id")
1372 if isinstance(task_id, str) and task_id.strip():
1373 return task_id.strip()
1374 return None
1375
1376 def _extract_task_status_from_asr_response(self, response_obj: Any) -> Optional[str]:
1377 data = self._as_dict(response_obj)
1378 output = data.get("output")
1379 if isinstance(output, dict):
1380 task_status = output.get("task_status")
1381 if isinstance(task_status, str):
1382 return task_status
1383 return None
1384
1385 def _extract_transcription_url_from_asr_response(
1386 self, response_obj: Any
1387 ) -> Optional[str]:
1388 data = self._as_dict(response_obj)
1389 output = data.get("output")
1390 if not isinstance(output, dict):
1391 return None
1392
1393 # Handle output.results: [{transcription_url: ...}]
1394 results = output.get("results")
1395 if isinstance(results, list) and results:
1396 first = results[0]
1397 if isinstance(first, dict):
1398 transcription_url = first.get("transcription_url")
1399 if isinstance(transcription_url, str) and transcription_url.strip():
1400 return transcription_url.strip()
1401
1402 result = output.get("result")
1403 if not isinstance(result, dict):
1404 # Fallback: handle output.transcription_url
1405 transcription_url = output.get("transcription_url")
1406 if isinstance(transcription_url, str) and transcription_url.strip():
1407 return transcription_url.strip()
1408 return None
1409 transcription_url = result.get("transcription_url")
1410 if isinstance(transcription_url, str) and transcription_url.strip():
1411 return transcription_url.strip()
1412 return None
1413
1414 def _extract_usage_from_asr_response(
1415 self, response_obj: Any
1416 ) -> Optional[Dict[str, Any]]:
1417 data = self._as_dict(response_obj)
1418 usage = data.get("usage")
1419 return usage if isinstance(usage, dict) else None
1420
1421 def _build_messages(
1422 self,
1423 system_prompt: str,
1424 user_prompt: str,
1425 music_url: str,
1426 ) -> list:
1427 """Build the OpenAI-compatible message list."""
1428 messages = []
1429
1430 # Add the system prompt
1431 if system_prompt:
1432 messages.append(
1433 {
1434 "role": "system",
1435 "content": system_prompt,
1436 }
1437 )
1438
1439 # Add the user message (audio plus text)
1440 messages.append(
1441 {
1442 "role": "user",
1443 "content": [
1444 {
1445 "type": "input_audio",
1446 "input_audio": {"data": music_url, "format": "mp3"},
1447 },
1448 {"type": "text", "text": user_prompt},
1449 ],
1450 }
1451 )
1452
1453 return messages
1454
1455 def _build_dashscope_prompt(self, system_prompt: str, user_prompt: str) -> str:
1456 """Build the text prompt for a DashScope call."""
1457 if system_prompt and system_prompt.strip():
1458 return f"{system_prompt.strip()}\n\n{user_prompt}".strip()
1459 return user_prompt.strip()
1460
1461 def _timed_call_openai(
1462 self, client, messages: list
1463 ) -> tuple[Optional[Dict], float]:
1464 """Run the OpenAI-compatible call and return (response, elapsed seconds)."""
1465 call_start = time.time()
1466 resp = self._call_with_retry(client, messages)
1467 return resp, round(time.time() - call_start, 2)
1468
1469 def _call_with_retry_dashscope(
1470 self, music_url: str, prompt: str, timeout: Optional[float] = None, song_id: str = "", metadata: Optional[Dict[str, Any]] = None
1471 ) -> Optional[Dict]:
1472 """Multimodal call through the DashScope SDK, with retries; falls back to an OSS upload when the file is too large."""
1473 import dashscope
1474
1475 dashscope_key = (
1476 self.api_key
1477 or settings.QWEN_DASHSCOPE_API_KEY
1478 or os.getenv("QWEN_OMNI_API_KEY")
1479 or os.getenv("DASHSCOPE_API_KEY")
1480 )
1481 if not dashscope_key:
1482 print(" ⚠ No DashScope API key is set (for example the DASHSCOPE_API_KEY environment variable); configure one first")
1483 return None
1484
1485 messages = [
1486 {
1487 "role": "user",
1488 "content": [
1489 {"audio": music_url},
1490 {"text": prompt},
1491 ],
1492 }
1493 ]
1494
1495 timeout = timeout or self.timeout
1496
1497 for attempt in range(1, self.max_retries + 1):
1498 try:
1499 print(
1500 f" [{self.model}] Analyzing (DashScope attempt {attempt}/{self.max_retries}, timeout={timeout}s)..."
1501 )
1502 response = self._dashscope_call_with_hard_timeout(
1503 dashscope=dashscope,
1504 api_key=dashscope_key,
1505 model=self.model,
1506 messages=messages,
1507 timeout=timeout,
1508 )
1509
1510 if response.status_code != 200:
1511 error_msg = getattr(response, "message", "")
1512 error_code = getattr(response, "code", "")
1513 error_output = getattr(response, "output", {})
1514 print(
1515 f" ✗ [{self.model}] API call failed, status code: {response.status_code}"
1516 )
1517 if song_id:
1518 print(f" Song ID: {song_id}")
1519 if error_code:
1520 print(f" Error code: {error_code}")
1521 if error_msg:
1522 print(f" Error message: {error_msg}")
1523 if error_output:
1524 print(f" Response body: {error_output}")
1525
1526 # On a file-too-large error, fall back to uploading a mono copy to OSS
1527 if "file size is too large" in str(error_msg).lower() or "file size is too large" in str(error_output).lower():
1528 print(" [Qwen] File too large; falling back to an OSS upload...")
1529 try:
1530 temp_audio_path = self._download_audio_temp(music_url)
1531 if temp_audio_path:
1532 mono_path = self._convert_to_mono(temp_audio_path)
1533 oss_url = self._upload_audio_to_oss(mono_path)
1534 # Delete only the converted mono copy; when conversion fails mono_path is the original download, which is kept
1535 if mono_path != temp_audio_path: self._cleanup_temp_audio(mono_path)
1536 if oss_url:
1537 print(f" [Qwen] Retrying with the OSS URL: {oss_url[:60]}...")
1538 return self._call_with_retry_dashscope(oss_url, prompt, timeout=timeout, song_id=song_id, metadata=metadata)
1539 except Exception as e:
1540 print(f" [Qwen] OSS fallback failed: {e}")
1541 return None
1542
1543 if attempt < self.max_retries:
1544 time.sleep(attempt)
1545 continue
1546 return None
1547
1548 content = response.output.choices[0].message.content
1549 if isinstance(content, list):
1550 if content and isinstance(content[0], dict) and "text" in content[0]:
1551 result_text = content[0]["text"]
1552 else:
1553 result_text = ""
1554 else:
1555 result_text = content
1556
1557 usage = None
1558 resp_usage = getattr(response, "usage", None)
1559 if isinstance(resp_usage, dict):
1560 input_tokens = resp_usage.get(
1561 "input_tokens", resp_usage.get("prompt_tokens", 0)
1562 )
1563 output_tokens = resp_usage.get(
1564 "output_tokens", resp_usage.get("completion_tokens", 0)
1565 )
1566 total_tokens = resp_usage.get("total_tokens")
1567 usage = {
1568 "prompt_tokens": input_tokens or 0,
1569 "completion_tokens": output_tokens or 0,
1570 "total_tokens": total_tokens
1571 if total_tokens is not None
1572 else (input_tokens or 0) + (output_tokens or 0),
1573 }
1574 elif resp_usage is not None:
1575 input_tokens = getattr(resp_usage, "input_tokens", None)
1576 output_tokens = getattr(resp_usage, "output_tokens", None)
1577 total_tokens = getattr(resp_usage, "total_tokens", None)
1578 usage = {
1579 "prompt_tokens": input_tokens or 0,
1580 "completion_tokens": output_tokens or 0,
1581 "total_tokens": total_tokens
1582 if total_tokens is not None
1583 else (input_tokens or 0) + (output_tokens or 0),
1584 }
1585
1586 return {"content": result_text, "usage": usage}
1587
1588 except TimeoutError:
1589 print(f" ✗ [{self.model}] API call timed out (attempt {attempt}/{self.max_retries})")
1590 if attempt < self.max_retries:
1591 time.sleep(attempt)
1592 continue
1593 return None
1594 except Exception as e:
1595 print(f" ✗ [{self.model}] API call raised an exception: {e}")
1596 if attempt < self.max_retries:
1597 time.sleep(attempt)
1598 continue
1599 return None
1600
1601 return None
1602
1603 def _download_audio_temp(self, music_url: str) -> Optional[str]:
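1604 """
1605 Download the audio file to the system temp directory.
1606
1607 Args:
1608 music_url: Audio URL.
1609
1610 Returns:
1611 Path of the temp file, or None if the download fails.
1612 """
1613 try:
1614 # Pick the file extension (defaults to .mp3)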
1615 ext = ".mp3"
1616 if "." in music_url:
1617 url_ext = music_url.split(".")[-1].split("?")[0].lower()
1618 if url_ext in ["mp3", "wav", "flac", "aac", "m4a", "ogg"]:
1619 ext = f".{url_ext}"
1620
1621 # Download into the system temp directory; the filename is keyed by a hash of the URL
1622 temp_dir = tempfile.gettempdir()
1623 url_hash = hashlib.md5(music_url.encode("utf-8")).hexdigest()[:12]
1624 temp_path = os.path.join(temp_dir, f"qwen_audio_{url_hash}{ext}")
1625
1626 if not os.path.exists(temp_path):
1627 response = requests.get(music_url, timeout=60)
1628 response.raise_for_status()
1629 with open(temp_path, "wb") as f:
1630 f.write(response.content)
1631 print(f" [Qwen] Temp audio downloaded: {temp_path}")
1632 else:
1633 print(" [Qwen] Using the cached temp audio")
1634
1635 return temp_path
1636 except Exception as e:
1637 print(f" [Qwen] Temp audio download failed: {e}")
1638 return None
1639
1640 def _convert_to_mono(self, audio_path: str) -> str:
1641 """
1642 Convert the audio to mono.
1643
1644 Args:
1645 audio_path: Path of the source audio file.
1646
1647 Returns:
1648 Path of the converted file, or audio_path itself if conversion fails.
1649 """
1650 import time
1651 timestamp = int(time.time() * 1000)
1652 base_name = os.path.basename(audio_path)
1653 name_parts = base_name.rsplit(".", 1)
1654 if len(name_parts) == 2:
1655 mono_path = os.path.join(
1656 os.path.dirname(audio_path),
1657 f"{name_parts[0]}_mono_{timestamp}.{name_parts[1]}"
1658 )
1659 else:
1660 mono_path = f"{audio_path}_mono_{timestamp}"
1661
1662 try:
1663 cmd = [
1664 "ffmpeg",
1665 "-i", audio_path,
1666 "-ac", "1", # downmix to mono
1667 "-y",
1668 mono_path
1669 ]
1670 print(" [Qwen] Converting to mono: ffmpeg -i ... -ac 1")
1671 subprocess.run(cmd, capture_output=True, timeout=60, check=True)
1672 original_size = os.path.getsize(audio_path)
1673 mono_size = os.path.getsize(mono_path)
1674 ratio = (1 - mono_size / original_size) * 100
1675 print(f" [Qwen] Audio converted: {original_size/1024/1024:.1f}MB -> {mono_size/1024/1024:.1f}MB (size reduction: {ratio:.1f}%)")
1676 return mono_path
1677 except Exception as e:
1678 print(f" [Qwen] Audio conversion failed: {e}; using the original file")
1679 return audio_path
1680
1681 def _upload_audio_to_oss(self, audio_path: str) -> Optional[str]:
1682 """
1683 Upload the audio file to OSS.
1684
1685 Args:
1686 audio_path: Path of the audio file.
1687
1688 Returns:
1689 URL of the uploaded OSS object, or None if the upload fails.
1690 """
1691 try:
1692 from app.utils.oss_uploader import oss_uploader
1693
1694 success, result = oss_uploader.upload_file(audio_path)
1695 if not success:
1696 print(f" [Qwen] Audio upload to OSS failed: {result}")
1697 return None
1698
1699 oss_url = result
1700 print(f" [Qwen] Audio uploaded to OSS: {oss_url}")
1701 return oss_url
1702 except Exception as e:
1703 print(f" [Qwen] Audio upload to OSS failed: {e}")
1704 return None
1705
1706 def _cleanup_temp_audio(self, temp_path: str) -> None:
1707 """Delete a temporary audio file, ignoring filesystem errors."""
1708 if temp_path and os.path.exists(temp_path):
1709 try:
1710 os.unlink(temp_path)
1711 print(" [Qwen] Temp audio file removed")
1712 except OSError:
1713 pass
1714
1715 def _dashscope_call_with_hard_timeout(
1716 self,
1717 dashscope,
1718 api_key: str,
1719 model: str,
1720 messages: list,
1721 timeout: float,
1722 ):
1723 """
1724 In some DashScope SDK versions request_timeout may not take effect reliably.
1725 A thread-level hard timeout is added here so that a single call cannot block indefinitely.
1726 """
1727 box: Dict[str, Any] = {}
1728 done = threading.Event()
1729
1730 def _target() -> None:
1731 try:
1732 box["response"] = dashscope.MultiModalConversation.call(
1733 api_key=api_key,
1734 model=model,
1735 messages=messages,
1736 request_timeout=timeout,
1737 )
1738 except Exception as exc:
1739 box["error"] = exc
1740 finally:
1741 done.set()
1742
1743 worker = threading.Thread(target=_target, daemon=True)
1744 worker.start()
1745 hard_timeout = max(float(timeout), 1.0) + 3.0
1746 if not done.wait(hard_timeout):
1747 raise TimeoutError(f"DashScope hard timeout after {hard_timeout:.1f}s")
1748 if "error" in box:
1749 raise box["error"]
1750 return box.get("response")
1751
1752 def _call_with_retry(
1753 self,
1754 client,
1755 messages: list,
1756 timeout: Optional[float] = None,
1757 max_retries: Optional[int] = None,
1758 ) -> Optional[Dict]:
1759 """API call with retries (non-streaming)."""
1760 timeout = timeout or self.lyrics_timeout
1761 retries = max_retries or self.max_retries
1762
1763 for attempt in range(1, retries + 1):
1764 try:
1765 print(
1766 f" [Qwen] Calling the model (attempt {attempt}/{retries}, timeout={timeout}s)..."
1767 )
1768
1769 response = client.chat.completions.create(
1770 model=self.model,
1771 messages=messages,
1772 modalities=["text"],
1773 stream=False,
1774 timeout=timeout,
1775 extra_body={"enable_thinking": False},
1776 )
1777
1778 content = (
1779 response.choices[0].message.content if response.choices else ""
1780 )
1781 usage = {
1782 "prompt_tokens": response.usage.prompt_tokens
1783 if response.usage
1784 else 0,
1785 "completion_tokens": response.usage.completion_tokens
1786 if response.usage
1787 else 0,
1788 "total_tokens": response.usage.total_tokens
1789 if response.usage
1790 else 0,
1791 }
1792
1793 print(f" [Qwen] Response: {(content or '')[:100]}...")
1794
1795 return {"content": content, "usage": usage}
1796
1797 except Exception as e:
1798 error_type = type(e).__name__
1799 print(f" [Qwen] Error ({error_type}): {e}")
1800
1801 if attempt < retries:
1802 wait_time = attempt
1803 print(f" Waiting {wait_time} s before retrying...")
1804 time.sleep(wait_time)
1805 else:
1806 print(" Maximum number of retries reached")
1807 return None
1808
1809 return None
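The two retry loops above (`_call_with_retry` and `_call_with_retry_dashscope`) sleep for `attempt` seconds after each failed attempt, which amounts to a linear backoff. The sketch below isolates that pattern so it can be exercised without a network call. It is an illustration only: `call_with_linear_backoff`, its injectable `sleep` parameter, and the `flaky` demo function are hypothetical names that do not exist in the repository.

```python
import time
from typing import Callable, Dict, List, Optional, TypeVar

T = TypeVar("T")


def call_with_linear_backoff(
    fn: Callable[[], T],
    max_retries: int = 3,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[T]:
    """Call fn up to max_retries times, waiting `attempt` seconds after each failure.

    Returns fn's result, or None once every attempt has failed
    (the same "None means failure" convention the analyzer uses).
    """
    for attempt in range(1, max_retries + 1):
        try:
            return fn()
        except Exception:
            if attempt < max_retries:
                sleep(attempt)  # 1 s, then 2 s, ... (linear backoff)
    return None


# Demo: fail twice, then succeed on the third attempt.
calls: Dict[str, int] = {"n": 0}
waits: List[float] = []


def flaky() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"


# Passing waits.append as `sleep` records the delays instead of actually waiting.
result = call_with_linear_backoff(flaky, max_retries=3, sleep=waits.append)
```

Injecting `sleep` is a design choice for testability; the methods above call `time.sleep` directly.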
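1 # 聚音 Tag Recognition Assistant - System Role Definition
2
3 ## Role
4
5 You are a tagging assistant for music content.
6 Your task is to output standardized tag fields, strictly following the 聚音 tag dictionary, based on the supplied song information (lyrics, title, style description, audio features, and so on).
7
8 Output only the tag result: no explanation, no analysis, no extra text.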
9
10 ------
11
12 ## Output format
13
14 Output plain-text JSON only, with the following structure:
15
16 {
17 "performer_type": "",
18 "language": "",
19 "emotion": [],
20 "douyin_tags": [],
21 "music_style_tags": [],
22 "instrument_tags": [],
23 "scene": []
24 }
25
26 Do not output explanatory text, comments, or extra fields.
27
28 ------
29
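30 ## Global constraints
31
32 1. Every tag must be chosen from the dictionary below; do not invent tags.
33 2. Do not guess from stereotypes (for example, inferring emotion from genre alone).
34 3. Every tag must rest on explicit evidence:
35 - lyric content
36 - musical style characteristics
37 - instruments that are clearly present
38 - a clear usage scenario
39 4. For multi-select fields, choose only tags that are highly certain and central to the song; avoid over-tagging.
40 5. Note: every field must contain at least one tag and must not be left empty, except where a field section below explicitly allows "".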
41
42 ------
43
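44 # Field criteria
45
46 ## 1. Performer type: performer_type (single-select)
47
48 Labels the main vocal type; judge only from what is actually heard or explicitly described:
49
50 - 男声: predominantly a male voice
51 - 女声: predominantly a female voice
52 - 童声: clearly a child's voice
53 - 合唱: mainly sung by a group (not just backing harmonies)
54
55 Output "" when uncertain.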
56
57 ------
58
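59 ## 2. Emotion: emotion (multi-select)
60
61 Judge from the emotional expression of the song as a whole, not from individual words.
62
63 - 喜庆: clearly festive or celebratory
64 - 浪漫: strongly romantic atmosphere
65 - 雄壮: grand, epic, majestic
66 - 蛊惑: psychedelic, seductive, suggestive
67 - 宣泄: emotional outburst or release
68 - 悲壮: tragic yet powerful
69 - 愤怒: strongly confrontational or fierce
70 - 庄重: formal, solemn
71 - 激情: fervent, high-spirited
72 - 沉重: oppressive, heavy
73 - 快乐: light-hearted, happy
74 - 励志: striving, growth, self-motivation
75 - 思念: missing a person or the past
76 - 紧张: unresolved, anxious
77 - 恐怖: horror atmosphere
78 - 感动: warm and tear-jerking
79 - 恶搞: deliberately exaggerated mockery
80 - 搞笑: plainly humorous
81 - 期待: looking forward to the future
82 - 怀念: reminiscing about the past
83 - 甜蜜: the sweetness of being in love
84 - 孤独: loneliness, inner monologue
85 - 伤感: sad and downcast
86 - 悬疑: mysterious, a sense of the unknown
87 - 祝福: expressing good wishes
88 - 佛系: plain and easygoing
89 - 舒缓: slow, steady tempo
90 - 悠扬: flowing, graceful melody
91 - 温暖: gentle and healing
92 - 忧郁: gloomy in character
93
94 Avoid:
95
96 - choosing strongly opposed emotions together (such as 快乐 and 伤感)
97 - stacking near-synonymous tags (伤感 + 忧郁 + 孤独 must be clearly distinguishable)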
98
99 ------
100
101 ## 3. Language: language (single-select)
102
103 Choose the single main sung language from the tags below:
104
105 - 普通话
106 - 粤语
107 - 藏语
108 - 英语
109 - 韩语
110 - 闽南语
111 - 蒙语
112 - 俄语
113 - 其他
114
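115 Rules:
116
117 - Output exactly one language tag
118 - Judge by the language actually sung, not by the singer's nationality or the genre
119 - Output "" for instrumental tracks or when the language cannot be determined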
120
121 ------
122
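123 ## 4. Internet/Douyin songs: douyin_tags (multi-select allowed)
124
125 Select only when the song shows clear traits of online virality or one of the themes below:
126
127 - 草原: grassland culture, ethnic grassland elements
128 - 故乡: homesickness theme
129 - 神曲: earworm melody, strongly repetitive rhythm
130 - 文艺: niche or poetic expression
131 - 青春: campus or coming-of-age theme
132 - 治愈系: warm, healing style
133 - 清新: light, natural style
134 - 奇幻: fantasy or magical elements
135
136 Do not force a tag when the internet character is not obvious.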
137
138 ------
139
140 ## 5. Music style: music_style_tags (multi-select)
141
142 Judge from musical structure and stylistic features, not from the lyric theme.
143
144 - 世界音乐
145 - 雷鬼
146 - R&B/Soul
147 - MC喊麦
148 - 另类音乐
149 - 民歌
150 - 戏曲
151 - 古风
152 - 古典音乐
153 - HipHop
154 - Rap
155 - 摇滚
156 - DJ嗨曲
157 - 布鲁斯/蓝调
158 - 拉丁
159 - 舞曲
160 - 爵士
161 - 乡村
162 - 民谣
163 - 流行
164 - 轻音乐
165 - 国风
166 - 儿歌
167
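168 Rules:
169
170 - Choose only the core styles; do not stack similar styles
171 - Do not infer the overall style from the presence of a single instrument
172 - When no style stands out, “流行” alone is acceptable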
173
174 ------
175
176 ## 6. Instruments: instrument_tags (multi-select)
177
178 Select only instruments that are clearly identifiable:
179
180 - 二胡
181 - 竹笛
182 - 琵琶
183 - 音效
184 - 口琴
185 - 电子
186 - 木吉他
187 - 鼓组
188 - 弦乐
189 - 电吉他
190 - 古筝
191 - 钢琴
192
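193 Rules:
194
195 - The instrument must be clearly dominant or prominent
196 - Do not tag an instrument merely because conventional accompaniment would include it
197 - When unsure, do not guess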
198
199 ------
200
201 ## 7. Scene: scene (multi-select)
202
203 Judge from the song's usage scenario or evident atmosphere:
204
205 - 餐厅
206 - 汽车
207 - 跳舞
208 - 旅行
209 - 工作
210 - 校园
211 - 夜店
212 - 运动
213 - 休闲
214 - live house
215 - 广场舞
216 - 抖音
217 - 婚礼
218 - 约会
219
220 Rules:
221
222 - Tag a scene only when the song clearly fits it
223 - Avoid catch-all scenes (for example, tagging every slow song “休闲”)
224
225 ------
226
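227 ## Final execution requirements
228
229 - Output JSON only
230 - No explanation
231 - No supplementary notes
232 - Do not output the dictionary contents
233 - Do not output phrases such as "analysis follows"
234 - Do not add undefined fields
235
236 Strictly observe the allowed values and empty-value rules of each field.
237
238 ## Output format
239 Output exactly the following JSON structure; field names must not be changed: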
240
241 ```json
242 {
243 "performer_type": "",
244 "language": "",
245 "emotion": [],
246 "douyin_tags": [],
247 "music_style_tags": [],
248 "instrument_tags": [],
249 "scene": []
250 }
251 ```
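The output contract above can be checked mechanically before a reply is written back to the spreadsheet. The sketch below is one possible validator, not code from the repository: `validate_tags`, `SINGLE_SELECT`, and `MULTI_SELECT` are hypothetical names, and only the two single-select dictionaries (copied from the prompt above) are enforced; the multi-select fields are checked for type only.

```python
import json
from typing import Any, Dict, List, Set

# Allowed values copied from the prompt; "" is the documented "uncertain" value.
SINGLE_SELECT: Dict[str, Set[str]] = {
    "performer_type": {"男声", "女声", "童声", "合唱", ""},
    "language": {"普通话", "粤语", "藏语", "英语", "韩语", "闽南语", "蒙语", "俄语", "其他", ""},
}
MULTI_SELECT: List[str] = [
    "emotion", "douyin_tags", "music_style_tags", "instrument_tags", "scene",
]


def validate_tags(raw: str) -> List[str]:
    """Return the problems found in a model reply; an empty list means it passed."""
    try:
        data: Any = json.loads(raw)
    except json.JSONDecodeError as exc:
        return [f"not valid JSON: {exc.msg}"]
    if not isinstance(data, dict):
        return ["top-level value is not an object"]

    problems: List[str] = []
    expected = set(SINGLE_SELECT) | set(MULTI_SELECT)
    for key in sorted(expected - set(data)):
        problems.append(f"missing field: {key}")
    for key in sorted(set(data) - expected):
        problems.append(f"undefined field: {key}")
    for key, allowed in SINGLE_SELECT.items():
        if key in data:
            value = data[key]
            # Check the type first so unhashable values never reach the set lookup.
            if not isinstance(value, str) or value not in allowed:
                problems.append(f"{key}: value not in dictionary")
    for key in MULTI_SELECT:
        if key in data and not isinstance(data[key], list):
            problems.append(f"{key}: expected a list")
    return problems


good_reply = (
    '{"performer_type": "女声", "language": "普通话", "emotion": ["思念"], '
    '"douyin_tags": [], "music_style_tags": ["流行"], '
    '"instrument_tags": ["钢琴"], "scene": []}'
)
bad_reply = '{"performer_type": "乐队", "language": "普通话", "emotion": "思念", "notes": "x"}'

good_problems = validate_tags(good_reply)
bad_problems = validate_tags(bad_reply)
```

Extending the check to the multi-select dictionaries would follow the same shape: one set of allowed values per field and a membership test per list element.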
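1 # Metadata to analyze
2 {{METADATA_SECTION}}
3
4 # Task
5 Perform 聚音 tag recognition based on the audio content and output only the tag fields required by the system prompt.
6
7 # Constraints
8 - Tags must be based on what is actually heard; output an empty value for any tag that cannot be confirmed.
9 - Output strictly plain-text JSON; no Markdown formatting of any kind.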
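The `{{METADATA_SECTION}}` placeholder in this template has to be filled in before the prompt is sent. How the repository does that is not shown in this section; the sketch below assumes plain string replacement with one bullet per non-empty metadata field. `render_user_prompt`, the shortened `TEMPLATE`, and the metadata keys are hypothetical and used for illustration only.

```python
from typing import Mapping

# Shortened stand-in for the template above; only the placeholder matters here.
TEMPLATE = "# Metadata\n{{METADATA_SECTION}}\n\n# Task\nTag the audio."
PLACEHOLDER = "{{METADATA_SECTION}}"


def render_user_prompt(template: str, metadata: Mapping[str, str]) -> str:
    """Replace the metadata placeholder with one '- key: value' line per non-empty field."""
    lines = [f"- {key}: {value}" for key, value in metadata.items() if value]
    section = "\n".join(lines) if lines else "(none)"
    return template.replace(PLACEHOLDER, section)


prompt = render_user_prompt(TEMPLATE, {"title": "Example Song", "artist": ""})
empty_prompt = render_user_prompt(TEMPLATE, {})
```

Dropping empty fields keeps the model from reading a blank value as a meaningful signal; whether the pipeline does the same is an assumption.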
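1 # Lyrics recognition prompt template
2 # Recognizes lyrics only; no other music analysis
3 Recognize and transcribe the complete lyrics in the audio.
4
5 ## Core tasks
6 1. **Line by line**: output every lyric line in chronological order, one lyric line per record.
7 2. **Fields**: every record must contain `time` (format "mm:ss.xxx", or null if it cannot be determined) and `text` (the lyric content).
8 3. **Compress non-lexical syllables**: do not spell out filler syllables such as “啊/呜/哦/嗯/啦” one by one; abbreviate them with `...` (for example, transcribe “啊啊啊啊” as “啊...”).
9 4. **Completeness**: transcribe the whole song, including repeated sections.
10 5. **Silence and instrumental tracks**: for an instrumental track or one without lyrics, return only an empty array `[]`.
11 6. Transcribe the full lyrics of every section of the song, including lyrics that repeat across sections.
12
13 ## Output format
14 - Output strict JSON with no Markdown fence markers (such as ```json) and no explanatory text.
15 - Use this structure throughout: {"lyrics": [{"time": "00:00.000", "text": "lyric text"}]}
16
17 ## Quality control
18 - Where voices overlap or sing in chorus, follow the main melody.
19 - Never invent lyrics that are not in the audio.
20 - Do not return any other unrelated content.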
1 """
2 Alibaba Cloud OSS service.
3 Provides file upload, download, deletion, and related operations.
4 """
5 import oss2
6 from typing import Optional, Dict, Any, BinaryIO
7 from datetime import datetime, timedelta
8 from pathlib import Path
9 import hashlib
10 import mimetypes
11 import logging
12 import time
13 import json
14 from functools import wraps
15
16 from aliyunsdkcore.client import AcsClient
17 from aliyunsdksts import AssumeRoleRequest
18
19 from app.core.config import settings
20 from app.core.exceptions import (
21 BusinessException,
22 ValidationException,
23 ExternalServiceException,
24 NotFoundException
25 )
26
27 logger = logging.getLogger(__name__)
28
29
30 def retry_on_connection_error(max_retries: int = 3, delay: float = 1.0):
31 """
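32 Retry decorator: retries the wrapped call on connection errors.
33
34 Args:
35 max_retries: Maximum number of attempts (the first call counts as one).
36 delay: Base delay in seconds; the wait doubles after each failed attempt.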
37 """
38 def decorator(func):
39 @wraps(func)
40 def wrapper(*args, **kwargs):
41 last_exception = None
42 for attempt in range(max_retries):
43 try:
44 return func(*args, **kwargs)
45 except (oss2.exceptions.RequestError,
46 ConnectionError,
47 TimeoutError,
48 Exception) as e:
49 # Exception already covers every error type; only connection-related errors are retried
50 error_str = str(e)
51 if ('Connection' in error_str or
52 'timeout' in error_str.lower() or
53 'closed' in error_str.lower() or
54 'RemoteDisconnected' in error_str):
55 last_exception = e
56 if attempt < max_retries - 1:
57 wait_time = delay * (2 ** attempt) # exponential backoff
58 logger.warning(
59 f"[{func.__name__}] Connection error, retrying in {wait_time}s "
60 f"({attempt + 1}/{max_retries}): {error_str}"
61 )
62 time.sleep(wait_time)
63 continue
64 else:
65 logger.error(f"[{func.__name__}] Retries exhausted: {error_str}")
66 raise
67 else:
68 # Not a connection error: re-raise immediately
69 raise
70
71 if last_exception:
72 raise last_exception
73 return wrapper
74 return decorator
75
76
77 class OSSService:
78 """Alibaba Cloud OSS service class."""
79
80 def __init__(self):
81 """Initialize the OSS client."""
82 if not all([
83 settings.OSS_ACCESS_KEY_ID,
84 settings.OSS_ACCESS_KEY_SECRET,
85 settings.OSS_ENDPOINT,
86 settings.OSS_BUCKET_NAME
87 ]):
88 raise BusinessException("OSS configuration is incomplete; check the environment variables")
89
90 # Create the auth object
91 auth = oss2.Auth(
92 settings.OSS_ACCESS_KEY_ID,
93 settings.OSS_ACCESS_KEY_SECRET
94 )
95
96 # Choose the endpoint; use the internal one when it is enabled and a region is set
97 endpoint = settings.OSS_ENDPOINT
98 if settings.OSS_INTERNAL_ENDPOINT and settings.OSS_REGION:
99 endpoint = f"oss-{settings.OSS_REGION}-internal.aliyuncs.com"
100
101 # Create the Bucket object
102 self.bucket = oss2.Bucket(
103 auth,
104 endpoint,
105 settings.OSS_BUCKET_NAME
106 )
107
108 # Set the timeouts (on the bucket object)
109 self.bucket.timeout = (
110 settings.OSS_CONNECTION_TIMEOUT,
111 settings.OSS_READ_TIMEOUT
112 )
113
114 # Configure the retry parameters
115 self.max_retries = settings.OSS_REQUEST_RETRY_TIMES
116 self.retry_delay = settings.OSS_RETRY_DELAY
117
118 logger.info(
119 f"OSS service initialized: endpoint={endpoint}, "
120 f"timeout={settings.OSS_CONNECTION_TIMEOUT}/{settings.OSS_READ_TIMEOUT}s, "
121 f"retries={self.max_retries}"
122 )
123
124 @staticmethod
125 def _validate_file_extension(filename: str) -> bool:
126 """Validate the file extension."""
127 ext = Path(filename).suffix.lower()
128 if not ext:
129 return False
130 return ext in [e.lower() for e in settings.OSS_ALLOWED_EXTENSIONS]
131
132 @staticmethod
133 def _validate_file_size(file_size: int) -> bool:
134 """Validate the file size."""
135 return 0 < file_size <= settings.OSS_MAX_FILE_SIZE
136
137 @staticmethod
138 def _generate_object_key(
139 filename: str,
140 user_id: int,
141 prefix: Optional[str] = None
142 ) -> str:
143 """
144 Generate the OSS object key (path).
145 Format: {prefix}/{user_id}/{date}/{hash}_{filename}
146 """
147 # Split the filename into stem and extension
148 ext = Path(filename).suffix.lower()
149 name = Path(filename).stem
150
151 # Date directory
152 now = datetime.now()
153 date_path = now.strftime("%Y%m%d")
154
155 # Unique ID: first 8 hex chars of an MD5 over user_id + timestamp + name (no random component)
156 timestamp = now.strftime("%H%M%S%f")
157 unique_id = hashlib.md5(f"{user_id}{timestamp}{name}".encode()).hexdigest()[:8]
158
159 # Assemble the filename
160 new_filename = f"{unique_id}_{name}{ext}"
161
162 # Pick the prefix
163 base_prefix = prefix or settings.OSS_UPLOAD_PATH_PREFIX
164
165 # Assemble the full path (prepend the global prefix when one is set)
166 if settings.OSS_GLOBAL_PREFIX:
167 return f"{settings.OSS_GLOBAL_PREFIX}/{base_prefix}/{user_id}/{date_path}/{new_filename}"
168 return f"{base_prefix}/{user_id}/{date_path}/{new_filename}"
169
170 @staticmethod
171 def _generate_object_key_simple(
172 filename: str,
173 entity_type: str
174 ) -> str:
175 """
176 Generate the OSS object key (path); simplified variant used when re-hosting character/scene images.
177 Format: {entity_type}/{date}/{hash}{ext}
178 """
179 # Split the filename into stem and extension
180 ext = Path(filename).suffix.lower()
181 name = Path(filename).stem
182
183 # Date directory
184 now = datetime.now()
185 date_path = now.strftime("%Y%m%d")
186
187 # Unique ID: first 8 hex chars of an MD5 over entity_type + timestamp + name (no random component)
188 timestamp = now.strftime("%H%M%S%f")
189 unique_id = hashlib.md5(f"{entity_type}{timestamp}{name}".encode()).hexdigest()[:8]
190
191 # Assemble the filename
192 new_filename = f"{unique_id}{ext}"
193
194 # Assemble the full path: type/date/hash.ext
195 if settings.OSS_GLOBAL_PREFIX:
196 return f"{settings.OSS_GLOBAL_PREFIX}/{entity_type}/{date_path}/{new_filename}"
197 return f"{entity_type}/{date_path}/{new_filename}"
198
199 @staticmethod
200 def _get_content_type(filename: str) -> str:
201 """Return the file's MIME type."""
202 content_type, _ = mimetypes.guess_type(filename)
203 return content_type or 'application/octet-stream'
204
205 @staticmethod
206 def _get_extension_from_mime(mime_type: str) -> str:
207 """
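208 Return the file extension for a MIME type.
209
210 Args:
211 mime_type: MIME type (e.g. image/jpeg, video/mp4).
212
213 Returns:
214 File extension including the dot (e.g. .jpg, .mp4), or '' if the type is unknown.
215 """
216 # MIME type to extension mapping
217 mime_to_ext = {
218 # Images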
219 'image/jpeg': '.jpg',
220 'image/jpg': '.jpg',
221 'image/png': '.png',
222 'image/gif': '.gif',
223 'image/webp': '.webp',
224 'image/svg+xml': '.svg',
225 'image/bmp': '.bmp',
226 'image/x-icon': '.ico',
227 # 视频
228 'video/mp4': '.mp4',
229 'video/mpeg': '.mpeg',
230 'video/webm': '.webm',
231 'video/quicktime': '.mov',
232 'video/x-msvideo': '.avi',
233 'video/x-matroska': '.mkv',
234 # 音频
235 'audio/mpeg': '.mp3',
236 'audio/wav': '.wav',
237 'audio/ogg': '.ogg',
238 'audio/webm': '.weba',
239 # 其他
240 'application/pdf': '.pdf',
241 'text/plain': '.txt',
242 'application/json': '.json',
243 }
244
245 # 处理带参数的 MIME 类型(如 image/jpeg; charset=utf-8)
246 base_mime = mime_type.split(';')[0].strip().lower()
247
248 return mime_to_ext.get(base_mime, '')
249
250 def upload_file(
251 self,
252 file_data: BinaryIO,
253 filename: str,
254 user_id: int,
255 prefix: Optional[str] = None,
256 validate_extension: bool = True
257 ) -> Dict[str, Any]:
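258 """
259 Upload a file to OSS in a single request (suited to small files).
260
261 Args:
262 file_data: binary file stream
263 filename: original filename
264 user_id: user ID
265 prefix: path prefix (optional)
266 validate_extension: whether to validate the file extension
267
268 Returns:
269 Upload result info
270 """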
271 # 使用重试装饰器包装实际上传操作
272 @retry_on_connection_error(
273 max_retries=self.max_retries,
274 delay=self.retry_delay
275 )
276 def _do_upload():
277 # 验证文件扩展名
278 if validate_extension and not self._validate_file_extension(filename):
279 raise ValidationException(
280 f"不支持的文件类型,允许的类型: {', '.join(settings.OSS_ALLOWED_EXTENSIONS)}"
281 )
282
283 # 生成OSS对象键
284 object_key = self._generate_object_key(filename, user_id, prefix)
285
286 # 获取Content-Type
287 content_type = self._get_content_type(filename)
288
289 # 上传文件
290 result = self.bucket.put_object(
291 object_key,
292 file_data,
293 headers={'Content-Type': content_type}
294 )
295
296 # 构建文件URL
297 file_url = self._build_file_url(object_key)
298
299 return {
300 "object_key": object_key,
301 "filename": filename,
302 "url": file_url,
303 "content_type": content_type,
304 "etag": result.etag,
305 "uploaded_at": datetime.now().isoformat()
306 }
307
308 try:
309 return _do_upload()
310 except oss2.exceptions.OssError as e:
311 logger.error(f"OSS上传失败: {e.message}")
312 raise ExternalServiceException(f"OSS上传失败: {e.message}")
313 except ValidationException:
314 raise
315 except Exception as e:
316 logger.error(f"文件上传失败: {str(e)}", exc_info=True)
317 raise BusinessException(f"文件上传失败: {str(e)}")
318
319 def upload_file_with_size(
320 self,
321 file_data: BinaryIO,
322 filename: str,
323 file_size: int,
324 user_id: int,
325 prefix: Optional[str] = None
326 ) -> Dict[str, Any]:
327 """
328 根据文件大小智能选择上传方式(带文件大小验证)
329
330 参数:
331 file_data: 文件二进制流
332 filename: 原始文件名
333 file_size: 文件大小(字节)
334 user_id: 用户ID
335 prefix: 路径前缀(可选)
336
337 返回:
338 上传结果信息
339 """
340 # 验证文件大小
341 if not self._validate_file_size(file_size):
342 raise ValidationException(
343 f"文件大小超出限制,最大允许 {settings.OSS_MAX_FILE_SIZE / 1024 / 1024:.2f}MB"
344 )
345
346 # 根据文件大小选择上传方式
347 if file_size > settings.OSS_MULTIPART_THRESHOLD:
348 return self.multipart_upload(file_data, filename, file_size, user_id, prefix)
349 else:
350 result = self.upload_file(file_data, filename, user_id, prefix)
351 result["size"] = file_size
352 return result
353
354 def multipart_upload(
355 self,
356 file_data: BinaryIO,
357 filename: str,
358 file_size: int,
359 user_id: int,
360 prefix: Optional[str] = None,
361 part_size: Optional[int] = None,
362 ) -> Dict[str, Any]:
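363 """
364 Upload a file to OSS in parts (suited to large files).
365
366 Args:
367 file_data: binary file stream (must support seek)
368 filename: original filename
369 file_size: file size in bytes
370 user_id: user ID
371 prefix: path prefix (optional)
372 part_size: part size in bytes (optional; defaults to settings.OSS_PART_SIZE)
373
374 Returns: upload result info, including upload_id and part_count
375 """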
376 upload_id = None
377 object_key = None
378 effective_part_size = part_size or settings.OSS_PART_SIZE
379
380 @retry_on_connection_error(
381 max_retries=self.max_retries,
382 delay=self.retry_delay
383 )
384 def _do_multipart_upload():
385 nonlocal upload_id, object_key
386
387 # 验证文件扩展名
388 if not self._validate_file_extension(filename):
389 raise ValidationException(
390 f"不支持的文件类型,允许的类型: {', '.join(settings.OSS_ALLOWED_EXTENSIONS)}"
391 )
392
393 # 生成OSS对象键
394 object_key = self._generate_object_key(filename, user_id, prefix)
395
396 # 获取Content-Type
397 content_type = self._get_content_type(filename)
398
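399 # Initialize the multipart upload. If the outer retry decorator re-invokes this function, a new upload_id is created here and the earlier one is not aborted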
400 upload_id = self.bucket.init_multipart_upload(
401 object_key,
402 headers={'Content-Type': content_type}
403 ).upload_id
404
405 logger.info(f"[multipart_upload] 开始分片上传: {object_key}, upload_id={upload_id}")
406
407 # 计算分片数量
408 part_size = effective_part_size
409 part_count = (file_size + part_size - 1) // part_size
410
411 # 上传所有分片
412 parts = []
413 for part_number in range(1, part_count + 1):
414 # 读取分片数据
415 offset = (part_number - 1) * part_size
416 size = min(part_size, file_size - offset)
417 file_data.seek(offset)
418 part_data = file_data.read(size)
419
420 # 上传分片(每个分片也使用重试)
421 @retry_on_connection_error(
422 max_retries=self.max_retries,
423 delay=self.retry_delay
424 )
425 def _upload_part():
426 result = self.bucket.upload_part(
427 object_key,
428 upload_id,
429 part_number,
430 part_data
431 )
432 logger.debug(f"分片 {part_number}/{part_count} 上传成功")
433 return result
434
435 result = _upload_part()
436 parts.append(oss2.models.PartInfo(part_number, result.etag))
437
438 # 完成分片上传
439 @retry_on_connection_error(
440 max_retries=self.max_retries,
441 delay=self.retry_delay
442 )
443 def _complete_upload():
444 return self.bucket.complete_multipart_upload(
445 object_key,
446 upload_id,
447 parts
448 )
449
450 result = _complete_upload()
451
452 # 构建文件URL
453 file_url = self._build_file_url(object_key)
454
455 return {
456 "object_key": object_key,
457 "filename": filename,
458 "url": file_url,
459 "size": file_size,
460 "content_type": content_type,
461 "etag": result.etag,
462 "upload_id": upload_id,
463 "part_count": part_count,
464 "uploaded_at": datetime.now().isoformat()
465 }
466
467 try:
468 return _do_multipart_upload()
469 except oss2.exceptions.OssError as e:
470 logger.error(f"OSS分片上传失败: {e.message}")
471 # 上传失败,尝试取消分片上传
472 if upload_id and object_key:
473 try:
474 self.bucket.abort_multipart_upload(
475 object_key,
476 upload_id
477 )
478 logger.info(f"已取消分片上传: {object_key}, upload_id={upload_id}")
479 except Exception as cancel_error:
480 logger.warning(f"取消分片上传失败: {cancel_error}")
481 raise ExternalServiceException(f"OSS分片上传失败: {e.message}")
482 except (ValidationException, ExternalServiceException):
483 raise
484 except Exception as e:
485 logger.error(f"分片上传失败: {str(e)}", exc_info=True)
486 # 上传失败,尝试取消分片上传
487 if upload_id and object_key:
488 try:
489 self.bucket.abort_multipart_upload(
490 object_key,
491 upload_id
492 )
493 except Exception:
494 pass
495 raise BusinessException(f"分片上传失败: {str(e)}")
496
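The part arithmetic used by `multipart_upload` (ceiling division for the part count, then an offset and size per part) can be checked in isolation. The helper below is a sketch; `plan_parts` is a hypothetical name and is not part of the service.

```python
from typing import List, Tuple


def plan_parts(file_size: int, part_size: int) -> List[Tuple[int, int, int]]:
    """Return (part_number, offset, size) for every part, numbered from 1."""
    if file_size < 0 or part_size <= 0:
        raise ValueError("file_size must be >= 0 and part_size must be > 0")

    # Ceiling division: a trailing partial part still counts as one part.
    part_count = (file_size + part_size - 1) // part_size

    parts = []
    for part_number in range(1, part_count + 1):
        offset = (part_number - 1) * part_size
        size = min(part_size, file_size - offset)  # the last part may be shorter
        parts.append((part_number, offset, size))
    return parts
```

For a 10-byte file and 4-byte parts this yields three parts of 4, 4, and 2 bytes. A zero-byte file yields no parts at all, which is worth keeping in mind because the completion call is then made with an empty part list.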
497 def generate_presigned_url(
498 self,
499 filename: str,
500 user_id: int,
501 expires: Optional[int] = None,
502 prefix: Optional[str] = None
503 ) -> Dict[str, Any]:
504 """
505 生成预签名上传URL(前端直传)
506
507 参数:
508 filename: 文件名
509 user_id: 用户ID
510 expires: 过期时间(秒),默认使用配置值
511 prefix: 路径前缀(可选)
512
513 返回:
514 预签名URL信息
515 """
516 try:
517 # 验证文件扩展名
518 if not self._validate_file_extension(filename):
519 raise ValidationException(
520 f"不支持的文件类型,允许的类型: {', '.join(settings.OSS_ALLOWED_EXTENSIONS)}"
521 )
522
523 # 生成OSS对象键
524 object_key = self._generate_object_key(filename, user_id, prefix)
525
526 # 设置过期时间
527 expires = expires or settings.OSS_SIGNED_URL_EXPIRE
528
529 # 生成预签名URL
530 signed_url = self.bucket.sign_url(
531 'PUT',
532 object_key,
533 expires,
534 headers={'Content-Type': self._get_content_type(filename)}
535 )
536
537 # 构建文件URL(上传后的访问URL)
538 file_url = self._build_file_url(object_key)
539
540 return {
541 "upload_url": signed_url,
542 "object_key": object_key,
543 "file_url": file_url,
544 "expires_in": expires,
545 "expires_at": (datetime.now() + timedelta(seconds=expires)).isoformat(),
546 "method": "PUT",
547 "headers": {
548 "Content-Type": self._get_content_type(filename)
549 }
550 }
551
552 except oss2.exceptions.OssError as e:
553 raise ExternalServiceException(f"生成签名URL失败: {e.message}")
554 except ValidationException:
555 raise
556 except Exception as e:
557 raise BusinessException(f"生成签名URL失败: {str(e)}")
558
559 def generate_multipart_presigned_urls(
560 self,
561 filename: str,
562 file_size: int,
563 user_id: int,
564 expires: Optional[int] = None,
565 prefix: Optional[str] = None
566 ) -> Dict[str, Any]:
567 """
568 生成分片上传的预签名URL(前端分片直传)
569
570 参数:
571 filename: 文件名
572 file_size: 文件大小(字节)
573 user_id: 用户ID
574 expires: 过期时间(秒)
575 prefix: 路径前缀(可选)
576
577 返回:
578 分片上传信息和预签名URL列表
579 """
580 try:
581 # 验证文件大小
582 if not self._validate_file_size(file_size):
583 raise ValidationException(
584 f"文件大小超出限制,最大允许 {settings.OSS_MAX_FILE_SIZE / 1024 / 1024:.2f}MB"
585 )
586
587 # 验证文件扩展名
588 if not self._validate_file_extension(filename):
589 raise ValidationException(
590 f"不支持的文件类型,允许的类型: {', '.join(settings.OSS_ALLOWED_EXTENSIONS)}"
591 )
592
593 # 生成OSS对象键
594 object_key = self._generate_object_key(filename, user_id, prefix)
595
596 # 初始化分片上传
597 upload_id = self.bucket.init_multipart_upload(
598 object_key,
599 headers={'Content-Type': self._get_content_type(filename)}
600 ).upload_id
601
602 # 计算分片数量
603 part_size = settings.OSS_PART_SIZE
604 part_count = (file_size + part_size - 1) // part_size
605
606 # 设置过期时间
607 expires = expires or settings.OSS_SIGNED_URL_EXPIRE
608
609 # 为每个分片生成预签名URL
610 part_urls = []
611 for part_number in range(1, part_count + 1):
612 params = {
613 'uploadId': upload_id,
614 'partNumber': str(part_number)
615 }
616 signed_url = self.bucket.sign_url(
617 'PUT',
618 object_key,
619 expires,
620 params=params
621 )
622 part_urls.append({
623 "part_number": part_number,
624 "upload_url": signed_url
625 })
626
627 # 构建文件URL(完成后的访问URL)
628 file_url = self._build_file_url(object_key)
629
630 return {
631 "upload_id": upload_id,
632 "object_key": object_key,
633 "file_url": file_url,
634 "part_size": part_size,
635 "part_count": part_count,
636 "part_urls": part_urls,
637 "expires_in": expires,
638 "expires_at": (datetime.now() + timedelta(seconds=expires)).isoformat()
639 }
640
641 except oss2.exceptions.OssError as e:
642 raise ExternalServiceException(f"初始化分片上传失败: {e.message}")
643 except (ValidationException, ExternalServiceException):
644 raise
645 except Exception as e:
646 raise BusinessException(f"初始化分片上传失败: {str(e)}")
647
648 def complete_multipart_upload_by_client(
649 self,
650 object_key: str,
651 upload_id: str,
652 parts: list
653 ) -> Dict[str, Any]:
654 """
655 完成客户端分片上传
656
657 参数:
658 object_key: OSS对象键
659 upload_id: 上传ID
660 parts: 分片信息列表 [{"part_number": 1, "etag": "xxx"}, ...]
661
662 返回:
663 完成结果
664 """
665 try:
666 # 构建分片信息
667 part_info_list = [
668 oss2.models.PartInfo(part["part_number"], part["etag"])
669 for part in parts
670 ]
671
672 # 完成分片上传
673 result = self.bucket.complete_multipart_upload(
674 object_key,
675 upload_id,
676 part_info_list
677 )
678
679 # 构建文件URL
680 file_url = self._build_file_url(object_key)
681
682 return {
683 "object_key": object_key,
684 "url": file_url,
685 "etag": result.etag,
686 "completed_at": datetime.now().isoformat()
687 }
688
689 except oss2.exceptions.OssError as e:
690 raise ExternalServiceException(f"完成分片上传失败: {e.message}")
691 except Exception as e:
692 raise BusinessException(f"完成分片上传失败: {str(e)}")
693
694 def abort_multipart_upload(
695 self,
696 object_key: str,
697 upload_id: str
698 ) -> bool:
699 """
700 取消分片上传
701
702 参数:
703 object_key: OSS对象键
704 upload_id: 上传ID
705
706 返回:
707 是否成功
708 """
709 @retry_on_connection_error(
710 max_retries=self.max_retries,
711 delay=self.retry_delay
712 )
713 def _do_abort():
714 self.bucket.abort_multipart_upload(
715 object_key,
716 upload_id
717 )
718
719 try:
720 _do_abort()
721 return True
722 except oss2.exceptions.OssError as e:
723 raise ExternalServiceException(f"取消分片上传失败: {e.message}")
724 except Exception as e:
725 raise BusinessException(f"取消分片上传失败: {str(e)}")
726
727 def delete_file(self, object_key: str) -> bool:
728 """
729 删除OSS文件
730
731 参数:
732 object_key: OSS对象键
733
734 返回:
735 是否成功
736 """
737 try:
738 self.bucket.delete_object(object_key)
739 return True
740 except oss2.exceptions.OssError as e:
741 raise ExternalServiceException(f"删除文件失败: {e.message}")
742 except Exception as e:
743 raise BusinessException(f"删除文件失败: {str(e)}")
744
745 def delete_files_batch(self, object_keys: list) -> Dict[str, Any]:
746 """
747 批量删除OSS文件
748
749 参数:
750 object_keys: OSS对象键列表
751
752 返回:
753 删除结果
754 """
755 try:
756 result = self.bucket.batch_delete_objects(object_keys)
757 return {
758 "deleted_count": len(result.deleted_keys),
759 "deleted_keys": result.deleted_keys
760 }
761 except oss2.exceptions.OssError as e:
762 raise ExternalServiceException(f"批量删除文件失败: {e.message}")
763 except Exception as e:
764 raise BusinessException(f"批量删除文件失败: {str(e)}")
765
766 def upload_from_url(
767 self,
768 url: str,
769 entity_type: str,
770 filename: Optional[str] = None
771 ) -> Dict[str, Any]:
772 """
773 从URL下载文件并上传到OSS(用于转存外部生成的图片或视频)
774
775 参数:
776 url: 文件URL(图片或视频)
777 entity_type: 实体类型(character/scene),用于构建存储路径
778 filename: 自定义文件名(可选),如果未提供则从URL提取或根据Content-Type推断
779
780 返回:
781 上传结果信息
782 """
783 import requests
784 from urllib.parse import urlparse
785
786 @retry_on_connection_error(
787 max_retries=self.max_retries,
788 delay=self.retry_delay
789 )
790 def _do_upload():
791 # 下载文件
792 try:
793 response = requests.get(url, timeout=30)
794 response.raise_for_status()
795 except Exception as e:
796 logger.error(f"从URL下载文件失败: {url}, error: {e}")
797 raise ExternalServiceException(f"从URL下载文件失败: {str(e)}")
798
799 # 获取文件内容和Content-Type
800 file_data = response.content
801 response_content_type = response.headers.get('Content-Type', '')
802
803 # 确定文件名
804 final_filename = filename
805 if not final_filename:
806 # 尝试从URL提取文件名
807 parsed_url = urlparse(url)
808 path = parsed_url.path
809 final_filename = Path(path).name
810
811 # 如果URL中没有文件名或没有扩展名,根据Content-Type推断
812 if not final_filename or '.' not in final_filename:
813 ext = self._get_extension_from_mime(response_content_type)
814 if ext:
815 # Pick a filename prefix from the extension; audio such as .mp3 falls through to 'file'
816 type_prefix = 'video' if ext in ('.mp4', '.mpeg', '.webm', '.mov', '.avi', '.mkv') else 'file'
817 final_filename = f"{type_prefix}_{int(datetime.now().timestamp())}{ext}"
818 else:
819 # 如果无法推断,使用通用名称
820 final_filename = f"file_{int(datetime.now().timestamp())}.bin"
821
822 # 生成OSS对象键
823 object_key = self._generate_object_key_simple(final_filename, entity_type)
824
825 # Prefer the Content-Type from the response headers; fall back to guessing it from the filename
826 content_type = response_content_type or self._get_content_type(final_filename)
827
828 # 上传文件
829 result = self.bucket.put_object(
830 object_key,
831 file_data,
832 headers={'Content-Type': content_type}
833 )
834
835 # 构建文件URL
836 file_url = self._build_file_url(object_key)
837
838 logger.info(f"文件转存成功: {url} -> {file_url}")
839
840 return {
841 "object_key": object_key,
842 "filename": final_filename,
843 "url": file_url,
844 "content_type": content_type,
845 "size": len(file_data),
846 "etag": result.etag,
847 "uploaded_at": datetime.now().isoformat()
848 }
849
850 try:
851 return _do_upload()
852 except oss2.exceptions.OssError as e:
853 logger.error(f"OSS转存失败: {e.message}")
854 raise ExternalServiceException(f"OSS转存失败: {e.message}")
855 except (ExternalServiceException,):
856 raise
857 except Exception as e:
858 logger.error(f"文件转存失败: {str(e)}", exc_info=True)
859 raise BusinessException(f"文件转存失败: {str(e)}")
860
861 def upload_from_base64(
862 self,
863 base64_data_url: str,
864 entity_type: str,
865 filename: Optional[str] = None
866 ) -> Dict[str, Any]:
867 """
868 从 Base64 data URL 解码并上传到 OSS
869
870 参数:
871 base64_data_url: Base64 data URL (格式: data:image/png;base64,...)
872 entity_type: 实体类型(image/character/scene),用于构建存储路径
873 filename: 自定义文件名(可选),如果未提供则自动生成
874
875 返回:
876 上传结果信息
877 """
878 import base64
879 import re
880
881 @retry_on_connection_error(
882 max_retries=self.max_retries,
883 delay=self.retry_delay
884 )
885 def _do_upload():
886 # 解析 Base64 data URL
887 # 格式: data:image/png;base64,iVBORw0KGgo...
888 if not base64_data_url.startswith('data:'):
889 raise ValueError("无效的 Base64 data URL 格式")
890
891 # 提取 MIME 类型和 Base64 数据
892 match = re.match(r'data:([^;]+);base64,(.+)', base64_data_url)
893 if not match:
894 raise ValueError("无法解析 Base64 data URL")
895
896 mime_type = match.group(1) # 例如: image/png
897 base64_data = match.group(2)
898
899 # 根据 MIME 类型确定文件扩展名
900 ext_map = {
901 'image/png': '.png',
902 'image/jpeg': '.jpg',
903 'image/jpg': '.jpg',
904 'image/gif': '.gif',
905 'image/webp': '.webp',
906 }
907 ext = ext_map.get(mime_type.lower(), '.png')
908
909 # 确定文件名
910 if filename:
911 final_filename = filename
912 # 确保文件名有正确的扩展名
913 if not final_filename.endswith(ext):
914 final_filename = Path(filename).stem + ext
915 else:
916 # 生成默认文件名
917 final_filename = f"image_{int(datetime.now().timestamp())}{ext}"
918
919 # 解码 Base64 数据
920 try:
921 file_data = base64.b64decode(base64_data)
922 except Exception as e:
923 logger.error(f"Base64 解码失败: {e}")
924 raise ExternalServiceException(f"Base64 解码失败: {str(e)}")
925
926 # 生成 OSS 对象键
927 object_key = self._generate_object_key_simple(final_filename, entity_type)
928
929 # 上传文件
930 result = self.bucket.put_object(
931 object_key,
932 file_data,
933 headers={'Content-Type': mime_type}
934 )
935
936 # 构建文件 URL
937 file_url = self._build_file_url(object_key)
938
939 logger.info(f"Base64 图片上传成功: size={len(file_data)} -> {file_url}")
940
941 return {
942 "object_key": object_key,
943 "filename": final_filename,
944 "url": file_url,
945 "content_type": mime_type,
946 "size": len(file_data),
947 "etag": result.etag,
948 "uploaded_at": datetime.now().isoformat()
949 }
950
951 try:
952 return _do_upload()
953 except oss2.exceptions.OssError as e:
954 logger.error(f"OSS 上传失败: {e.message}")
955 raise ExternalServiceException(f"OSS 上传失败: {e.message}")
956 except (ValueError, ExternalServiceException):
957 raise
958 except Exception as e:
959 logger.error(f"Base64 图片上传失败: {str(e)}", exc_info=True)
960 raise BusinessException(f"Base64 图片上传失败: {str(e)}")
961
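The data URL parsing step in `upload_from_base64` can be tried on its own. The sketch below reuses the same regular expression; `parse_data_url` is a hypothetical helper name, and the behavior shown for extra parameters is a property of that regex, not a documented guarantee of the service.

```python
import base64
import re
from typing import Tuple

# Same pattern as in upload_from_base64: one MIME type, then ";base64,".
_DATA_URL_RE = re.compile(r"data:([^;]+);base64,(.+)")


def parse_data_url(data_url: str) -> Tuple[str, bytes]:
    """Split a base64 data URL into (mime_type, decoded_bytes)."""
    match = _DATA_URL_RE.match(data_url)
    if match is None:
        raise ValueError("not a base64 data URL")
    mime_type = match.group(1).lower()
    payload = base64.b64decode(match.group(2))
    return mime_type, payload
```

Note that a URL carrying an extra parameter before the encoding marker, such as `data:text/plain;charset=utf-8;base64,...`, does not match this pattern and is rejected.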
962 def get_file_info(self, object_key: str) -> Dict[str, Any]:
963 """
964 获取文件信息
965
966 参数:
967 object_key: OSS对象键
968
969 返回:
970 文件信息
971 """
972 try:
973 meta = self.bucket.get_object_meta(object_key)
974 return {
975 "object_key": object_key,
976 "size": meta.headers.get('Content-Length'),
977 "content_type": meta.headers.get('Content-Type'),
978 "etag": meta.headers.get('ETag'),
979 "last_modified": meta.headers.get('Last-Modified')
980 }
981 except oss2.exceptions.NoSuchKey:
982 raise NotFoundException(f"文件不存在: {object_key}")
983 except oss2.exceptions.OssError as e:
984 raise ExternalServiceException(f"获取文件信息失败: {e.message}")
985 except Exception as e:
986 raise BusinessException(f"获取文件信息失败: {str(e)}")
987
988 def file_exists(self, object_key: str) -> bool:
989 """
990 检查文件是否存在
991
992 参数:
993 object_key: OSS对象键
994
995 返回:
996 是否存在
997 """
998 try:
999 return self.bucket.object_exists(object_key)
1000 except Exception:
1001 return False
1002
1003 def generate_download_url(
1004 self,
1005 object_key: str,
1006 expires: Optional[int] = None,
1007 filename: Optional[str] = None
1008 ) -> str:
1009 """
1010 生成文件下载URL(临时访问)
1011
1012 参数:
1013 object_key: OSS对象键
1014 expires: 过期时间(秒)
1015 filename: 下载时的文件名(可选)
1016
1017 返回:
1018 下载URL
1019 """
1020 try:
1021 expires = expires or settings.OSS_SIGNED_URL_EXPIRE
1022
1023 params = {}
1024 if filename:
1025 params['response-content-disposition'] = f'attachment; filename="{filename}"'
1026
1027 return self.bucket.sign_url(
1028 'GET',
1029 object_key,
1030 expires,
1031 params=params
1032 )
1033 except oss2.exceptions.OssError as e:
1034 raise ExternalServiceException(f"生成下载URL失败: {e.message}")
1035 except Exception as e:
1036 raise BusinessException(f"生成下载URL失败: {str(e)}")
1037
1038 def _build_file_url(self, object_key: str) -> str:
1039 """
1040 构建文件访问URL
1041
1042 参数:
1043 object_key: OSS对象键
1044
1045 返回:
1046 文件URL
1047 """
1048 # 如果配置了CDN域名,使用CDN域名
1049 if settings.OSS_CDN_DOMAIN:
1050 # Strip a leading http:// or https:// prefix if present
1051 cdn_domain = settings.OSS_CDN_DOMAIN
1052 if cdn_domain.startswith("http://"):
1053 cdn_domain = cdn_domain[7:]
1054 elif cdn_domain.startswith("https://"):
1055 cdn_domain = cdn_domain[8:]
1056
1057 # 移除域名开头的 /
1058 if cdn_domain.startswith("/"):
1059 cdn_domain = cdn_domain[1:]
1060
1061 protocol = "https" if settings.OSS_USE_HTTPS else "http"
1062 return f"{protocol}://{cdn_domain}/{object_key}"
1063
1064 # 否则使用OSS域名
1065 # 移除 endpoint 中可能存在的协议前缀
1066 endpoint = settings.OSS_ENDPOINT
1067 if endpoint.startswith("http://"):
1068 endpoint = endpoint[7:]
1069 elif endpoint.startswith("https://"):
1070 endpoint = endpoint[8:]
1071
1072 protocol = "https" if settings.OSS_USE_HTTPS else "http"
1073 return f"{protocol}://{settings.OSS_BUCKET_NAME}.{endpoint}/{object_key}"
1074
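The URL rules in `_build_file_url` (CDN domain first, otherwise the bucket-qualified OSS endpoint, with any scheme stripped) can be expressed as a pure function. This is a sketch under assumptions: `build_file_url` and `strip_scheme` are hypothetical names, and the settings are passed in as arguments instead of being read from `settings`.

```python
from typing import Optional


def strip_scheme(host: str) -> str:
    """Drop a leading http:// or https:// if present."""
    for scheme in ("http://", "https://"):
        if host.startswith(scheme):
            return host[len(scheme):]
    return host


def build_file_url(
    object_key: str,
    bucket: str,
    endpoint: str,
    cdn_domain: Optional[str] = None,
    use_https: bool = True,
) -> str:
    """Return the public URL for an object key."""
    protocol = "https" if use_https else "http"
    if cdn_domain:
        # Unlike the method above, this strips every leading slash, not just one.
        host = strip_scheme(cdn_domain).lstrip("/")
        return f"{protocol}://{host}/{object_key}"
    return f"{protocol}://{bucket}.{strip_scheme(endpoint)}/{object_key}"
```

The scheme in the configured endpoint or CDN domain is always discarded; the protocol of the result depends only on `use_https`.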
1075 def generate_sts_credentials(
1076 self,
1077 user_id: int,
1078 duration_seconds: Optional[int] = None,
1079 policy: Optional[str] = None,
1080 role_session_name: Optional[str] = None
1081 ) -> Dict[str, Any]:
1082 """
1083 生成OSS STS临时凭证(用于前端直传)
1084
1085 参数:
1086 user_id: 用户ID(用于构建会话名称)
1087 duration_seconds: 凭证有效期(秒),默认使用配置值
1088 policy: 自定义策略(可选,JSON字符串)
1089 role_session_name: 角色会话名称(可选)
1090
1091 返回:
1092 STS临时凭证信息,包括:
1093 - access_key_id: 临时AccessKey ID
1094 - access_key_secret: 临时AccessKey Secret
1095 - security_token: 安全令牌
1096 - expiration: expiry time, converted to Beijing time (UTC+8)
1097 - region: 区域
1098 - bucket: Bucket名称
1099 - endpoint: OSS端点
1100 - upload_path_prefix: 上传路径前缀
1101 """
1102 if not settings.OSS_STS_ROLE_ARN:
1103 raise BusinessException(
1104 "STS Role ARN未配置,请检查环境变量 OSS_STS_ROLE_ARN"
1105 )
1106
1107 if not all([
1108 settings.OSS_ACCESS_KEY_ID,
1109 settings.OSS_ACCESS_KEY_SECRET,
1110 settings.OSS_ENDPOINT
1111 ]):
1112 raise BusinessException(
1113 "OSS配置不完整,请检查环境变量"
1114 )
1115
1116 try:
1117 # 确定区域ID
1118 # AcsClient 的第三个参数是 region_id,不是 endpoint
1119 # 格式: cn-hangzhou, cn-beijing 等
1120 if settings.OSS_REGION:
1121 region = settings.OSS_REGION
1122 else:
1123 # 从OSS端点提取区域
1124 # 格式: oss-cn-hangzhou.aliyuncs.com -> cn-hangzhou
1125 import re
1126 endpoint = settings.OSS_ENDPOINT
1127 if not endpoint:
1128 raise BusinessException("OSS_ENDPOINT 未配置")
1129 match = re.search(r'oss-(\w+)-(\w+)\.aliyuncs\.com', endpoint)
1130 if match:
1131 region = f"{match.group(1)}-{match.group(2)}"
1132 else:
1133 raise BusinessException(
1134 "无法从OSS端点确定区域,请配置 OSS_REGION"
1135 )
1136
1137 # 创建STS客户端(使用正确的 region_id 参数)
1138 client = AcsClient(
1139 settings.OSS_ACCESS_KEY_ID,
1140 settings.OSS_ACCESS_KEY_SECRET,
1141 region
1142 )
1143
1144 # 创建AssumeRole请求
1145 request = AssumeRoleRequest.AssumeRoleRequest()
1146 request.set_RoleArn(settings.OSS_STS_ROLE_ARN)
1147 request.set_RoleSessionName(
1148 role_session_name or f"aimv-frontend-upload-user-{user_id}"
1149 )
1150 request.set_DurationSeconds(
1151 duration_seconds or settings.OSS_STS_DURATION_SECONDS
1152 )
1153
1154 # 设置策略(如果提供了自定义策略)
1155 if policy:
1156 request.set_Policy(policy)
1157 elif settings.OSS_STS_POLICY:
1158 request.set_Policy(settings.OSS_STS_POLICY)
1159
1160 # 发送请求
1161 response = client.do_action_with_exception(request)
1162 result = json.loads(response.decode('utf-8'))
1163
1164 # 提取凭证信息
1165 credentials = result['Credentials']
1166
1167 # 将过期时间转换为北京时间(UTC+8)
1168 from app.schemas.common import convert_datetime_to_beijing
1169 expiration_utc = datetime.strptime(credentials['Expiration'], '%Y-%m-%dT%H:%M:%SZ')
1170 expiration_str = convert_datetime_to_beijing(expiration_utc)
1171
1172 # 构建上传路径前缀
1173 upload_path_prefix = f"{settings.OSS_UPLOAD_PATH_PREFIX}/{user_id}"
1174 if settings.OSS_GLOBAL_PREFIX:
1175 upload_path_prefix = f"{settings.OSS_GLOBAL_PREFIX}/{upload_path_prefix}"
1176
1177 return {
1178 "access_key_id": credentials['AccessKeyId'],
1179 "access_key_secret": credentials['AccessKeySecret'],
1180 "security_token": credentials['SecurityToken'],
1181 "expiration": expiration_str,
1182 "region": region,
1183 "bucket": settings.OSS_BUCKET_NAME,
1184 "endpoint": settings.OSS_ENDPOINT,
1185 "cdn_domain": settings.OSS_CDN_DOMAIN,
1186 "upload_path_prefix": upload_path_prefix,
1187 }
1188
1189 except Exception as e:
1190 logger.error(f"生成STS临时凭证失败: {str(e)}", exc_info=True)
1191 raise ExternalServiceException(
1192 f"生成STS临时凭证失败: {str(e)}"
1193 )
1194
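The fallback that derives a region ID from the OSS endpoint can be exercised separately to see which endpoint shapes it accepts. The sketch below uses the same pattern as `generate_sts_credentials`; `region_from_endpoint` is a hypothetical helper name.

```python
import re
from typing import Optional

# Same pattern as in generate_sts_credentials.
_REGION_RE = re.compile(r"oss-(\w+)-(\w+)\.aliyuncs\.com")


def region_from_endpoint(endpoint: str) -> Optional[str]:
    """Return a region ID such as 'cn-hangzhou', or None if the endpoint does not match."""
    match = _REGION_RE.search(endpoint)
    if match is None:
        return None
    return f"{match.group(1)}-{match.group(2)}"
```

The pattern expects exactly two hyphen-separated segments after `oss-`, so an internal endpoint such as `oss-cn-hangzhou-internal.aliyuncs.com` or a three-segment region such as `oss-ap-southeast-1.aliyuncs.com` is not recognized. In those cases setting `OSS_REGION` explicitly avoids the fallback.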
1195 def generate_user_upload_policy(self, user_id: int) -> str:
1196 """
1197 为用户生成上传策略(限制只能上传到指定用户目录)
1198
1199 参数:
1200 user_id: 用户ID
1201
1202 返回:
1203 策略JSON字符串
1204 """
1205 # 构建用户专属路径
1206 upload_path_prefix = f"{settings.OSS_UPLOAD_PATH_PREFIX}/{user_id}"
1207 if settings.OSS_GLOBAL_PREFIX:
1208 upload_path_prefix = f"{settings.OSS_GLOBAL_PREFIX}/{upload_path_prefix}"
1209
1210 # 构建策略
1211 policy = {
1212 "Version": "1",
1213 "Statement": [
1214 {
1215 "Effect": "Allow",
1216 "Action": [
1217 "oss:PutObject",
1218 "oss:InitiateMultipartUpload",
1219 "oss:UploadPart",
1220 "oss:CompleteMultipartUpload",
1221 "oss:AbortMultipartUpload"
1222 ],
1223 "Resource": [
1224 f"acs:oss:*:*:{settings.OSS_BUCKET_NAME}/{upload_path_prefix}/*"
1225 ]
1226 },
1227 {
1228 "Effect": "Allow",
1229 "Action": [
1230 "oss:ListObjects"
1231 ],
1232 "Resource": [
1233 f"acs:oss:*:*:{settings.OSS_BUCKET_NAME}",
1234 f"acs:oss:*:*:{settings.OSS_BUCKET_NAME}/{upload_path_prefix}"
1235 ],
1236 "Condition": {
1237 "StringLike": {
1238 "oss:prefix": [f"{upload_path_prefix}/*", f"{upload_path_prefix}"]
1239 }
1240 }
1241 }
1242 ]
1243 }
1244
1245 return json.dumps(policy)
1246
1247
1248 # 创建全局OSS服务实例(延迟初始化)
1249 _oss_service_instance = None
1250
1251
1252 def get_oss_service() -> OSSService:
1253 """获取OSS服务实例(单例模式)"""
1254 global _oss_service_instance
1255 if _oss_service_instance is None:
1256 _oss_service_instance = OSSService()
1257 return _oss_service_instance
1258
1259
1260 # Global OSS service instance (created eagerly at import time)
1261 oss_service = get_oss_service()
1 """
2 阿里云OSS文件上传模块
3 """
4 import os
5 import uuid
6 import logging
7 from datetime import datetime, timedelta
8
9 import oss2
10
11 from app.core.config import settings
12
13 logger = logging.getLogger(__name__)
14
15
16 class OSSUploader:
17 """阿里云OSS上传器"""
18
19 def __init__(self):
20 """初始化OSS客户端"""
21 self.access_key_id = settings.OSS_ACCESS_KEY_ID
22 self.access_key_secret = settings.OSS_ACCESS_KEY_SECRET
23 self.endpoint = settings.OSS_ENDPOINT
24 self.bucket_name = settings.OSS_BUCKET_NAME
25
26 if not all([
27 self.access_key_id,
28 self.access_key_secret,
29 self.endpoint,
30 self.bucket_name,
31 ]):
32 raise ValueError("OSS配置不完整,请检查 .env 中的 OSS_ACCESS_KEY_ID/OSS_ACCESS_KEY_SECRET/OSS_ENDPOINT/OSS_BUCKET_NAME")
33
34 logger.info(
35 "OSS配置: endpoint=%s, bucket=%s",
36 self.endpoint,
37 self.bucket_name,
38 )
39 # 创建认证对象
40 self.auth = oss2.Auth(self.access_key_id, self.access_key_secret)
41
42 # 默认使用公网 endpoint;非阿里云内网环境下访问 internal endpoint 容易失败。
43 self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name)
44
45 def upload_file(self, local_file_path, oss_object_name=None):
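46 """
47 Upload a local file to OSS.
48
49 Args:
50 local_file_path: path to the local file
51 oss_object_name: OSS object name; if omitted, a UUID plus the original file extension is used. The name is always placed under temp_ai/<YYYYMMDD>/
52
53 Returns:
54 tuple: (True, url) on success, or (False, error message) on failure
55 """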
56 try:
57 if not os.path.exists(local_file_path):
58 logger.error(f"本地文件不存在: {local_file_path}")
59 return False, "本地文件不存在"
60
61 if not oss_object_name:
62 _, ext = os.path.splitext(local_file_path)
63 oss_object_name = f"{uuid.uuid4()}{ext}"
64
65 # Always store the object under the dated temp_ai/<YYYYMMDD>/ directory, including caller-supplied names
66 date = datetime.now().strftime("%Y%m%d")
67 oss_object_name = f"temp_ai/{date}/{oss_object_name}"
68
69 # 上传文件
70 self.bucket.put_object_from_file(oss_object_name, local_file_path)
71
72 # 构建文件URL
73 file_url = f"https://{self.bucket_name}.{self.endpoint}/{oss_object_name}"
74
75 logger.info(f"文件上传成功: {local_file_path} -> {file_url}")
76 return True, file_url
77
78 except Exception as e:
79 logger.error(f"文件上传失败: {local_file_path}, 错误: {e}")
80 return False, str(e)
81
82 def upload_data(self, data, oss_object_name):
83 """
84 上传数据到OSS
85
86 Args:
87 data: 要上传的数据(字符串或字节)
88 oss_object_name: OSS对象名称
89
90 Returns:
91 dict: 包含上传结果的字典
92 """
93 try:
94 # 上传数据
95 result = self.bucket.put_object(oss_object_name, data)
96
97 # 构建文件URL
98 file_url = f"https://{self.bucket_name}.{self.endpoint}/{oss_object_name}"
99
100 return {
101 "success": True,
102 "oss_object_name": oss_object_name,
103 "file_url": file_url,
104 "etag": result.etag,
105 "size": len(data.encode("utf-8")) if isinstance(data, str) else len(data) if isinstance(data, bytes) else 0
106 }
107
108 except Exception as e:
109 return {"success": False, "error": str(e)}
110
111
112 def get_bucket():
113 """获取Bucket对象"""
114 if not all([
115 settings.OSS_ACCESS_KEY_ID,
116 settings.OSS_ACCESS_KEY_SECRET,
117 settings.OSS_ENDPOINT,
118 settings.OSS_BUCKET_NAME,
119 ]):
120 raise ValueError("OSS配置不完整,请检查 .env 中的 OSS_ACCESS_KEY_ID/OSS_ACCESS_KEY_SECRET/OSS_ENDPOINT/OSS_BUCKET_NAME")
121
122 auth = oss2.Auth(settings.OSS_ACCESS_KEY_ID, settings.OSS_ACCESS_KEY_SECRET)
123 bucket = oss2.Bucket(auth, settings.OSS_ENDPOINT, settings.OSS_BUCKET_NAME)
124 return bucket
125
126
127 def clean_expire_file():
128 """Daily cleanup task: delete date-named directories under temp_ai/ that are older than yesterday."""
129 print(f"\n[{datetime.now()}] 开始执行每日清理任务...")
130 ROOT_PREFIX = 'temp_ai/'
131 bucket = get_bucket()
132
133 # 1. 计算时间阈值
134 now = datetime.now()
135 yesterday_date = (now - timedelta(days=1)).date()
136 print(f"保留阈值: {yesterday_date} (即 {yesterday_date} 之前的数据将被删除)")
137
138 # 2. 遍历目录
139 try:
140 for obj in oss2.ObjectIterator(bucket, prefix=ROOT_PREFIX, delimiter='/'):
141 path = ""
142 is_directory = False
143
144 # --- [核心修改] 统一路径获取方式 ---
145
146 # 情况 A: 它是虚拟目录 (CommonPrefix)
147 if hasattr(obj, 'prefix'):
148 path = obj.prefix
149 is_directory = True
150
151 # 情况 B: 它是实际对象 (SimplifiedObjectInfo)
152 elif hasattr(obj, 'key'):
153 path = obj.key
154 # 如果 key 以 / 结尾,说明它是一个显式创建的文件夹对象
155 if path.endswith('/'):
156 is_directory = True
157 else:
158 is_directory = False # 这是一个普通文件
159
160 # --- 逻辑分流 ---
161
162 if not is_directory:
163 # 这是一个真正的文件(且不是文件夹对象),直接跳过
164 # print(f"[跳过] 散落文件: {path}")
165 continue
166
167 # 此时 path 必定是目录格式 (如 'temp_ai/20251229/')
168 # 下面开始正常的日期判断逻辑
169
170 # 防御性去空,防止路径即为 'temp_ai/' 本身
171 if path == ROOT_PREFIX:
172 continue
173
174 # Parse the directory name: strip the surrounding slashes, then take the last path segment
175 folder_name_raw = path.strip('/').split('/')[-1]
176
177 try:
178 folder_date_obj = datetime.strptime(folder_name_raw, "%Y%m%d").date()
179
180 if folder_date_obj < yesterday_date:
181 print(f"[删除] 发现过期目录: {path}")
182 # 注意:delete_objects_by_prefix 会删除该前缀下的所有文件
183 # 如果这个目录本身是个对象,也会被一并删除,无需特殊处理
184 delete_objects_by_prefix(bucket, path)
185 else:
186 # print(f"[跳过] 目录较新: {path}")
187 pass
188
189 except ValueError:
190 print(f"[跳过] 非日期命名目录: {path}")
191
192 except Exception as e:
193 import traceback
194 print(f"[严重错误] 任务执行失败: {e}")
195 traceback.print_exc()
196
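The date test inside `clean_expire_file` can be isolated so that the threshold is easy to verify. This is a sketch: `is_expired_date_dir` is a hypothetical helper, and `today` and `keep_days` are parameters introduced here to make the check reproducible.

```python
from datetime import date, datetime, timedelta
from typing import Optional


def is_expired_date_dir(path: str, today: date, keep_days: int = 1) -> Optional[bool]:
    """Classify a directory path such as 'temp_ai/20251229/'.

    Returns True if the directory is older than the threshold, False if it is
    recent enough to keep, and None if the last segment is not a YYYYMMDD date.
    """
    folder_name = path.strip("/").split("/")[-1]
    try:
        folder_date = datetime.strptime(folder_name, "%Y%m%d").date()
    except ValueError:
        return None
    threshold = today - timedelta(days=keep_days)
    return folder_date < threshold  # strictly older than the threshold
```

With the default of one day, yesterday's directory is kept and only directories dated before yesterday are reported as expired, which matches the strict `<` comparison in the task above.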
197
198 def delete_objects_by_prefix(bucket, prefix):
199 """Delete every object under the given prefix, in batches of up to 1000 keys."""
200 print(f" -> 正在清理目录: {prefix} ...")
201 batch_list = []
202 try:
203 for obj in oss2.ObjectIterator(bucket, prefix=prefix):
204 batch_list.append(obj.key)
205 if len(batch_list) >= 1000:
206 bucket.batch_delete_objects(batch_list)
207 batch_list = []
208
209 if batch_list:
210 bucket.batch_delete_objects(batch_list)
211 print(f" -> 目录 {prefix} 清理完毕。")
212 except Exception as e:
213 print(f" [错误] 删除过程出错: {e}")
214
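The batching pattern in `delete_objects_by_prefix` (collect keys, flush every 1000, flush the remainder) can be written as a small generator. This is a sketch; `chunk_keys` is a hypothetical name, and the default of 1000 mirrors the batch size used above.

```python
from typing import Iterable, Iterator, List


def chunk_keys(keys: Iterable[str], batch_size: int = 1000) -> Iterator[List[str]]:
    """Yield lists of at most batch_size keys, preserving order."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    batch: List[str] = []
    for key in keys:
        batch.append(key)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:  # flush the final, possibly shorter, batch
        yield batch
```

A caller could then loop with `for batch in chunk_keys(keys): bucket.batch_delete_objects(batch)`, which keeps the listing and the deletion logic separate.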
215
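216 # Create the module-level OSS uploader instance (raises ValueError at import time if the OSS settings are incomplete)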
217 oss_uploader = OSSUploader()
218
219 if __name__ == '__main__':
220 resp = oss_uploader.upload_file('想-dj-片段.mp3')
221 print(resp)
import os

from dashscope.common.constants import DASHSCOPE_API_KEY_ENV


ENV = 'test'
# ENV = 'local'


DEBUG = True

# Secrets are read from the environment instead of being committed to the
# repository. Apart from QWEN_API_KEY and DASHSCOPE_API_KEY, the environment
# variable names used below are assumptions.

### Database
# dev
DB_USER = 'root'
DB_PASSWORD = os.environ.get('DB_PASSWORD', '')
DB_HOST = 'rm-bp18h64ad9ak4d7h5do.mysql.rds.aliyuncs.com'
DB_DATABASE = 'music_partner'

# Redis
REDIS_HOST = '172.23.209.46'
REDIS_PORT = 6379
REDIS_PSW = os.environ.get('REDIS_PSW', '')
REDIS_DB = 0
# Xindou (新抖) key
NEW_RANK_KEY = os.environ.get('NEW_RANK_KEY', '')

BACK_BASE_URL = 'https://ai-test.hikoon.com/api/partner'

EMAIL_HOST = 'smtp.exmail.qq.com'
EMAIL_PORT = 465
EMAIL_HOST_USER = 'bigmusic@hikoon.com'
EMAIL_HOST_PASSWORD = os.environ.get('EMAIL_HOST_PASSWORD', '')
# Email recipient list
EMAIL_RECEIVERS = ['1774507011@qq.com', 'yangsheng@hikoon.com']


# Tag dictionary (tag key -> display label)
TAG_DICT = {
    "viral_song": "网络热歌",
    "sad_songs": "伤感老歌",
    "folk_songs": "民谣",
    "catchy_pop": "口水歌",
    "kids_songs": "洗脑儿歌",
    "tk_songs": "抖音热歌",
    "net_songs": "网络歌曲",

    "dj_remix": "DJ嗨曲",
    "Cheesy_EDM": "土嗨/慢摇",
    "car_music": "车载音乐",
    "shout_rap": "喊麦",
    "heavy_metal": "重金属/土摇DJ嗨曲",

    "mandarin_pop": "华语流行",
    "mainstream_pop": "主流Pop",
    "sweet_songs": "甜歌/校园",
    "hip_rock": "嘻哈说唱R&B摇滚",
    "child_songs": "主流儿歌",

    "international_pop": "国外流行",
    "jp_pop": "日韩流行",
    "west_pop": "欧美流行",
    "el_edm": "电音EDM",

    "chinese_style": "国风",
    "opera_vocal": "戏腔/古韵",
    "guochao_EDM": "国潮电子",
    "gufeng_music": "传统器乐古风",

    "soundtrack_instrumental": "影视/纯音",
    "ys_ost": "影视OST",
    "pur_music": "纯音乐",
    "no_lyric": "无词BGM",

    "other_music": "其他",
    "jazz_blue": "爵士/蓝调",
    "voice_book": "有声书",
    "lab_music": "实验音乐",
    "healing": "治愈",
    "melancholy": "伤感",
    "lonely": "孤独",
    "sweet": "甜蜜",
    "inspiring": "励志",
    "missing": "思念",
    "nostalgic": "怀旧",
    "angry": "愤怒",
    "relaxing": "放松",
    "catchy": "魔性洗脑",
    "heroic": "悲壮",
    "calm": "平静",
    "festive": "喜庆",
    "romantic": "浪漫",
    "majestic": "雄壮",
    "bewitching": "蛊惑",
    "cathartic": "宣泄",
    "solemn": "庄重",
    "passionate": "激情",
    "heavy": "沉重",
    "happy": "快乐",
    "tense": "紧张",
    "horror": "恐怖",
    "touching": "感动",
    "spoof": "恶搞",
    "funny": "搞笑",
    "expectation": "期待",
    "remembrance": "怀念",
    "mysterious": "悬疑",
    "blessing": "祝福",
    "zen": "佛系",
    "soothing": "舒缓",
    "melodious": "悠扬",
    "warm": "温暖",
    "depressed": "忧郁",
    "elderly": "老年",
    "middle_aged": "中年",
    "young_adult": "青年",
    "teenager": "少年",
    "life_scene": "生活场景",
    "sports": "运动",
    "driving": "开车",
    "travel": "旅行",
    "sleep": "睡前",
    "study": "学习",
    "cafe": "咖啡厅",
    "bar": "酒吧",
    "douyin": "抖音",
    "restaurant": "餐厅",
    "car_scene": "汽车",
    "dance": "跳舞",
    "work": "工作",
    "nightclub": "夜店",
    "leisure": "休闲",
    "live_house": "live house",
    "square_dance": "广场舞",
    "wedding": "婚礼",
    "dating": "约会",
    "festival_scene": "节日场景",
    "summer": "夏天",
    "winter": "冬天",
    "autumn": "秋天",
    "spring_festival": "春节",
    "christmas": "圣诞",
    "valentine": "情人节",
    "time_scene": "时间场景",
    "morning": "清晨",
    "afternoon": "午后",
    "evening": "夜晚",
    "midnight": "深夜",
    "regional_scene": "地域场景",
    "campus": "校园",
    "city": "城市",
    "grassland": "草原",
    "tibet": "西藏",
    "xinjiang": "新疆",
    "transition_style": "转场类",
    "card_point_switch": "卡点切换画面类",
    "reverse_suspense": "反转悬念类",
    "emotion_contrast": "情绪对比类",
    "mashup_collection": "混剪合集类",
    "emotional_resonance": "情感共鸣向剪辑",
    "scene_adaptation": "场景适配剪辑",
    "highlight_slice": "高光切片剪辑",
    "live_performance": "现场表演类",
    "singer_live": "歌手现场演唱",
    "talent_cover": "达人翻唱表演",
    "audience_interaction": "观众互动表演",
    "card_point_speed": "卡点、变速类",
    "multi_scene_fragment": "多场景碎片化卡点",
    "tech_effect_speed": "技术流特效变速",
    "lyric_concrete": "歌词具象化卡点",
    "loop_speed_brainwash": "循环变速洗脑",
    "ugc_co_creation": "UGC共创类",
    "jianying_template": "剪映模板",
    "ai_singing": "AI唱歌",
    "emotional_quotes": "情感语录类",
    "late_night_emo": "深夜emo类",
    "morning_inspiration": "清晨励志类",
    "memory_destiny": "回忆杀/宿命感类",
    "dynamic_lyrics_visual": "动态歌词可视化",
    "basic_lyrics_effect": "基础歌词动效",
    "creative_visual_enhance": "创意视觉强化",
    "adaptation": "改编",
    "special_effects_interaction": "特效互动类",
    "gesture_magic_effect": "手势魔法特效互动",
    "lip_sync_challenge": "对口型挑战",
    "douyin_effect_show": "抖音特效变装秀",
    # Listening / performance-driven
    "singing_montage": "演唱混剪",
    "live_singing": "现场演唱",

    # Visual impact
    "change_transition": "变装转场",
    "hand_dance": "手势舞",
    "addictive_dance": "魔性舞蹈",
    "landscape_account": "风景号",

    # Atmosphere / footage
    "cute_pets": "萌宠",
    "movie_anime_edit": "影视剧/动漫混剪",
    "chinese_classical": "古风",
    "mood_post": "图文心情",

    # Emotional resonance
    "animated_lyrics": "动态歌词",
    "storytelling": "故事演绎",
    "beauty_snaps": "颜值随拍"
}


# Model settings
BASE_MODEL = "/data/qufeng/models--MIT--ast-finetuned-audioset-10-10-0.4593/snapshots/f826b80d28226b62986cc218e5cec390b1096902"
MOE_DIR = "/data/qufeng/moe_outputs"
BASELINE_CHECKPOINT = "/data/qufeng/best_epoch_base.pt"
LABEL_MAPPING = "/data/qufeng/label_mapping.txt"
DEVICE = "cuda"  # options: cuda/mps/cpu; auto-selected when empty
ROUTER_CHECKPOINT = ""  # when empty, inferred from moe_dir/joint_train/joint/router_best.pt
EXPERTS_DIR = ""  # when empty, inferred from moe_dir/experts_train/experts

# Audio processing settings
CHUNK_SECONDS = 10.24  # length in seconds of each chunk used for inference
CROP_SECONDS = 204.8  # if the audio is longer than this, keep only the middle segment of this length, then chunk it
MAX_CHUNKS = 10  # maximum number of chunks per song used for inference
CHUNK_BATCH_SIZE = 8  # batch size for chunked inference
ROUTING_THRESHOLD = 0.6

API_CONFIG = {
    "api_key": os.environ.get("QWEN_API_KEY", ""),
    "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
    "model": "qwen3-omni-flash",
    "audio_mode": "auto",
    "timeout": 15,
    "lyrics_timeout": 60,
    "lyrics_retries": 2,
    "max_retries": 5,
    "retry_delay": 5
}
# API_CONFIG_91 = {
#     "api_key": os.environ.get("API_91_KEY", ""),
#     "base_url": "https://api.91aopusi.com/v1",
#     "model": "qwen3-omni-flash",
#     "audio_mode": "auto",
#     "timeout": 30,
#     "lyrics_timeout": 60,
#     "max_retries": 5,
#     "retry_delay": 5
# }

DASHSCOPE_API_KEY = os.environ.get(DASHSCOPE_API_KEY_ENV, '')

OSS_ACCESS_KEY_ID = os.environ.get('OSS_ACCESS_KEY_ID', '')
OSS_ACCESS_KEY_SECRET = os.environ.get('OSS_ACCESS_KEY_SECRET', '')
OSS_ENDPOINT = 'oss-cn-hangzhou.aliyuncs.com'
OSS_ENDPOINT_INTERNAL = 'oss-cn-hangzhou-internal.aliyuncs.com'
OSS_BUCKET_NAME = 'ai-sound-data-test'
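The audio settings above (`CHUNK_SECONDS`, `CROP_SECONDS`, `MAX_CHUNKS`) describe a crop-then-chunk scheme: audio longer than `CROP_SECONDS` is reduced to its middle segment, which is then cut into fixed-length chunks. The inference code that consumes these values is not part of this listing, so the following is only a sketch of how the time windows might be derived. `plan_chunks` is a hypothetical helper, and picking evenly spaced chunks when more than `MAX_CHUNKS` are available is an assumption, not confirmed behavior.

```python
CHUNK_SECONDS = 10.24
CROP_SECONDS = 204.8
MAX_CHUNKS = 10


def plan_chunks(duration: float) -> list[tuple[float, float]]:
    """Return (start, end) windows in seconds for one track (hypothetical helper)."""
    # Keep only the middle CROP_SECONDS when the track is longer than that
    span = min(duration, CROP_SECONDS)
    start = (duration - span) / 2

    # Number of whole chunks that fit; the epsilon guards against float rounding
    n_total = int(span / CHUNK_SECONDS + 1e-9)

    if n_total <= MAX_CHUNKS:
        picked = list(range(n_total))
    else:
        # Assumption: spread the selected chunks evenly across the cropped span
        step = n_total / MAX_CHUNKS
        picked = [int(i * step) for i in range(MAX_CHUNKS)]

    return [
        (start + i * CHUNK_SECONDS, start + (i + 1) * CHUNK_SECONDS)
        for i in picked
    ]


print(len(plan_chunks(300.0)))  # a 5-minute track is capped at MAX_CHUNKS windows
print(len(plan_chunks(30.0)))   # a 30-second track only fits two whole chunks
```

With these defaults the cropped span holds exactly 20 chunks of 10.24 s, so `MAX_CHUNKS = 10` means at most half of them take part in inference.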
import logging.handlers
import os
from config import DEBUG

log_dir = "./logs"
log_max_bytes = 1024 * 1024 * 10
log_backup_count = 5


def get_logger(name, level=None):
    if not level:
        level = logging.DEBUG if DEBUG else logging.INFO

    # Configure the logger
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # logging.getLogger returns the same object for a given name, so a second
    # call must not attach another handler (records would be written twice).
    if logger.handlers:
        return logger

    # Create the log directory if it does not exist
    os.makedirs(log_dir, exist_ok=True)

    # File handler with size-based rotation
    file_handler = logging.handlers.RotatingFileHandler(
        os.path.join(log_dir, f'{name}.log'),
        maxBytes=log_max_bytes,
        backupCount=log_backup_count,
        encoding='utf-8',
    )
    file_handler.setLevel(level)

    # Output format for the handler
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)

    # Attach the handler to the logger
    logger.addHandler(file_handler)
    return logger


# Module-level variable that caches the application logger
_app_logger = None


def get_app_logger():
    global _app_logger
    if _app_logger is None:
        _app_logger = get_logger("app")
    return _app_logger
# -*- coding: utf-8 -*-
"""Batch analyze audio URLs from an xlsx file and export results to xlsx."""

from __future__ import annotations

import argparse
import json
import math
import os
import sys
import traceback
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from pathlib import Path
from typing import Any

import pandas as pd

# Allow running directly via `python pipeline/batch_analyze_xlsx.py`
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from app.middleware.music_analyze import analyze_music


DEFAULT_OUTPUT_COLUMNS = [
    "tmeid",
    "歌曲ID",
    "歌曲名",
    "表演者",
    "歌曲时长",
    "表演者类型",
    "语种",
    "BPM速度",
    "情绪",
    "网络/抖音歌曲",
    "音乐风格",
    "配器",
    "场景",
]

ANALYZE_COLUMNS = [
    "表演者类型",
    "语种",
    "BPM速度",
    "情绪",
    "网络/抖音歌曲",
    "音乐风格",
    "配器",
    "场景",
]


def _is_blank(value: Any) -> bool:
    if value is None:
        return True
    if isinstance(value, float) and math.isnan(value):
        return True
    return str(value).strip() == ""


def _join_multi_value(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, str):
        return value.strip()
    if isinstance(value, list):
        parts = [str(v).strip() for v in value if str(v).strip()]
        return "、".join(parts)
    return str(value).strip()


def _pick_first_non_blank(row: pd.Series, candidates: list[str]) -> str:
    for col in candidates:
        if col in row.index and not _is_blank(row[col]):
            value = row[col]
            if isinstance(value, float) and value.is_integer():
                return str(int(value))
            return str(value).strip()
    return ""


def _normalize_key_value(value: Any) -> str:
    if _is_blank(value):
        return ""
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    return str(value).strip()


def _resolve_url_column(df: pd.DataFrame, requested_column: str) -> str:
    if requested_column in df.columns:
        return requested_column

    candidates = ["URL", "url", "cos访问地址", "cos_url", "audio_url"]
    for col in candidates:
        if col in df.columns:
            print(
                f"[run] url column `{requested_column}` not found, fallback to `{col}`"
            )
            return col

    raise ValueError(
        f"column `{requested_column}` not found, available={list(df.columns)}"
    )


def _is_row_completed(out_df: pd.DataFrame, idx: int) -> bool:
    for col in ANALYZE_COLUMNS:
        if col not in out_df.columns:
            continue
        value = out_df.at[idx, col]
        if not _is_blank(value):
            return True
    return False


def _resolve_checkpoint_path(output_path: Path, checkpoint_path: Path | None) -> Path:
    if checkpoint_path is not None:
        return checkpoint_path
    return output_path.with_suffix(output_path.suffix + ".checkpoint.json")


def _save_progress(
    out_df: pd.DataFrame,
    output_path: Path,
    checkpoint_path: Path,
    completed_indices: set[int],
) -> None:
    output_path.parent.mkdir(parents=True, exist_ok=True)

    tmp_output = output_path.with_suffix(output_path.suffix + ".tmp")
    out_df = out_df[DEFAULT_OUTPUT_COLUMNS]
    out_df.to_excel(tmp_output, index=False)
    tmp_output.replace(output_path)

    payload = {
        "completed_indices": sorted(completed_indices),
        "completed_count": len(completed_indices),
        "total": int(len(out_df)),
    }
    tmp_checkpoint = checkpoint_path.with_suffix(checkpoint_path.suffix + ".tmp")
    tmp_checkpoint.write_text(
        json.dumps(payload, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    tmp_checkpoint.replace(checkpoint_path)


def _load_checkpoint(checkpoint_path: Path) -> set[int]:
    if not checkpoint_path.exists():
        return set()
    try:
        payload = json.loads(checkpoint_path.read_text(encoding="utf-8"))
        values = payload.get("completed_indices", [])
        return {int(v) for v in values if isinstance(v, int) or str(v).isdigit()}
    except Exception:
        return set()


def _filter_checkpoint_indices(
    checkpoint_indices: set[int],
    out_df: pd.DataFrame,
    df: pd.DataFrame,
    url_column: str,
) -> set[int]:
    """
    Filter the indices loaded from the checkpoint:
    - keep rows that already have analysis results (avoid re-analysis)
    - keep rows whose URL is still blank (keep skipping them)
    - drop rows whose URL has since been filled in but that have no results
      (so they get analyzed on this run)
    """
    filtered: set[int] = set()
    for idx in checkpoint_indices:
        if idx < 0 or idx >= len(out_df):
            continue
        if _is_row_completed(out_df, idx):
            filtered.add(idx)
            continue
        url = df.at[idx, url_column] if url_column in df.columns else None
        if _is_blank(url):
            filtered.add(idx)
    return filtered


def _build_metadata(row: pd.Series, metadata_columns: list[str]) -> dict[str, Any]:
    metadata: dict[str, Any] = {}
    # Always pass the key fields through so downstream code can still build its mapping
    for col in ["歌曲ID", "song_id", "id"]:
        if col in row.index and not _is_blank(row[col]):
            metadata[col] = row[col]
            break
    for col in ["tmeid", "tmeID", "TMEID"]:
        if col in row.index and not _is_blank(row[col]):
            metadata["tmeid"] = row[col]
            break
    for col in metadata_columns:
        if col in row.index and not _is_blank(row[col]):
            metadata[col] = row[col]
    return metadata


def _normalize_result(result: dict[str, Any]) -> dict[str, Any]:
    return {
        "表演者类型": (
            str(result.get("performer_type") or result.get("vocal_texture") or "").strip()
        ),
        "语种": str(result.get("language") or "").strip(),
        "BPM速度": result.get("bpm"),
        "情绪": _join_multi_value(result.get("emotion", [])),
        "网络/抖音歌曲": _join_multi_value(result.get("douyin_tags", [])),
        "音乐风格": _join_multi_value(
            result.get("music_style_tags", [])
            or [v for v in [result.get("genre"), result.get("sub_genre")] if v]
        ),
        "配器": _join_multi_value(result.get("instrument_tags", [])),
        "场景": _join_multi_value(result.get("scene", [])),
    }


def _build_song_tmeid_maps(df: pd.DataFrame) -> tuple[dict[str, int], dict[str, int]]:
    song_id_map: dict[str, int] = {}
    tmeid_map: dict[str, int] = {}
    for idx, row in df.iterrows():
        song_id = _pick_first_non_blank(row, ["歌曲ID", "song_id", "id"])
        tmeid = _pick_first_non_blank(row, ["tmeid", "tmeID", "TMEID"])
        if song_id and song_id not in song_id_map:
            song_id_map[song_id] = int(idx)
        if tmeid and tmeid not in tmeid_map:
            tmeid_map[tmeid] = int(idx)
    return song_id_map, tmeid_map


def _resume_from_existing_by_keys(out_df: pd.DataFrame, existing: pd.DataFrame) -> set[int]:
    """When the input row count has changed, reuse old results by matching on 歌曲ID/tmeid."""
    completed_indices: set[int] = set()
    if existing.empty:
        return completed_indices

    old_song_map, old_tmeid_map = _build_song_tmeid_maps(existing)

    reused = 0
    reused_by_song = 0
    reused_by_tmeid = 0
    for idx in out_df.index:
        song_id = _normalize_key_value(out_df.at[idx, "歌曲ID"])
        tmeid = _normalize_key_value(out_df.at[idx, "tmeid"])

        old_idx = None
        if song_id and song_id in old_song_map:
            old_idx = old_song_map[song_id]
            reused_by_song += 1
        elif tmeid and tmeid in old_tmeid_map:
            old_idx = old_tmeid_map[tmeid]
            reused_by_tmeid += 1

        if old_idx is None:
            continue

        for col in DEFAULT_OUTPUT_COLUMNS:
            if col in existing.columns:
                out_df.at[idx, col] = existing.at[old_idx, col]

        if _is_row_completed(out_df, int(idx)):
            completed_indices.add(int(idx))
            reused += 1

    print(
        "[resume] row mismatch, reused by key: "
        f"song_id_match={reused_by_song}, tmeid_match={reused_by_tmeid}, "
        f"completed={reused}/{len(out_df)}"
    )
    return completed_indices


def _analyze_one(
    idx: int,
    row: pd.Series,
    url_column: str,
    provider: str,
    extract_lyrics: bool,
    label_level: int,
    metadata_columns: list[str],
) -> tuple[int, dict[str, Any]]:
    url = row.get(url_column)
    if _is_blank(url):
        return idx, {}

    try:
        metadata = _build_metadata(row, metadata_columns)
        result = analyze_music(
            metadata=metadata,
            music_url=str(url).strip(),
            provider=provider,
            extract_lyrics=extract_lyrics,
            label_level=label_level,
        )
        if not result:
            return idx, {}
        return idx, _normalize_result(result)
    except Exception as exc:
        print(f"[warn] row={idx} analyze failed: {type(exc).__name__}: {exc}")
        print(traceback.format_exc(limit=3))
        return idx, {}


def run_batch(
    input_path: Path,
    output_path: Path,
    checkpoint_path: Path | None,
    url_column: str,
    provider: str,
    extract_lyrics: bool,
    label_level: int,
    metadata_columns: list[str],
    workers: int,
    checkpoint_every: int,
    resume: bool,
) -> None:
    df = pd.read_excel(input_path)
    url_column = _resolve_url_column(df, url_column)
    checkpoint_path = _resolve_checkpoint_path(output_path, checkpoint_path)
    blank_url_indices = {int(idx) for idx, value in df[url_column].items() if _is_blank(value)}

    # Build the base columns of the output table from the input metadata
    out_df = pd.DataFrame(index=df.index)
    out_df["tmeid"] = [
        _pick_first_non_blank(row, ["tmeid", "tmeID", "TMEID"]) for _, row in df.iterrows()
    ]
    out_df["歌曲ID"] = [
        _pick_first_non_blank(row, ["歌曲ID", "song_id", "id"]) for _, row in df.iterrows()
    ]
    out_df["歌曲名"] = [
        _pick_first_non_blank(row, ["歌曲名", "歌曲名称", "title"]) for _, row in df.iterrows()
    ]
    out_df["表演者"] = [
        _pick_first_non_blank(row, ["表演者", "歌手", "artist"]) for _, row in df.iterrows()
    ]
    out_df["歌曲时长"] = [
        _pick_first_non_blank(row, ["歌曲时长", "duration"]) for _, row in df.iterrows()
    ]

    for col in DEFAULT_OUTPUT_COLUMNS:
        if col not in out_df.columns:
            out_df[col] = ""

    completed_indices: set[int] = set()
    output_aligned_by_index = False
    if resume:
        if output_path.exists():
            try:
                existing = pd.read_excel(output_path)
                if len(existing) == len(out_df):
                    output_aligned_by_index = True
                    for col in DEFAULT_OUTPUT_COLUMNS:
                        if col in existing.columns:
                            out_df[col] = existing[col]
                    for idx in out_df.index:
                        if _is_row_completed(out_df, idx):
                            completed_indices.add(int(idx))
                    print(
                        f"[resume] loaded existing output: {len(completed_indices)}/{len(out_df)} completed"
                    )
                else:
                    completed_indices |= _resume_from_existing_by_keys(out_df, existing)
            except Exception as exc:
                print(f"[resume] failed to read existing output: {type(exc).__name__}: {exc}")

        checkpoint_completed = _load_checkpoint(checkpoint_path)
        if checkpoint_completed:
            if output_aligned_by_index:
                checkpoint_completed = _filter_checkpoint_indices(
                    checkpoint_completed, out_df, df, url_column
                )
                before = len(completed_indices)
                completed_indices |= {idx for idx in checkpoint_completed if 0 <= idx < len(out_df)}
                if len(completed_indices) != before:
                    print(
                        f"[resume] loaded checkpoint: {len(completed_indices)}/{len(out_df)} completed"
                    )
            else:
                print("[resume] ignore checkpoint due to row mismatch with previous output")

    # Rows with a blank URL are skipped and never analyzed
    if blank_url_indices:
        completed_indices |= blank_url_indices
        print(f"[run] skip blank `{url_column}` rows: {len(blank_url_indices)}")

    pending_indices = [int(idx) for idx in out_df.index if int(idx) not in completed_indices]
    if not pending_indices:
        print("[resume] no pending rows, nothing to do")
        _save_progress(out_df, output_path, checkpoint_path, completed_indices)
        return

    print(
        f"[run] total={len(out_df)}, completed={len(completed_indices)}, pending={len(pending_indices)}"
    )

    workers = max(1, workers)
    checkpoint_every = max(1, checkpoint_every)
    processed_since_checkpoint = 0
    executor = ThreadPoolExecutor(max_workers=workers)
    futures = []
    try:
        for idx in pending_indices:
            row = df.iloc[idx]
            futures.append(
                executor.submit(
                    _analyze_one,
                    idx,
                    row,
                    url_column,
                    provider,
                    extract_lyrics,
                    label_level,
                    metadata_columns,
                )
            )

        pending_futures = set(futures)
        while pending_futures:
            # Wake up at least once per second so Ctrl+C is handled promptly
            done, pending_futures = wait(
                pending_futures,
                timeout=1.0,
                return_when=FIRST_COMPLETED,
            )
            if not done:
                continue

            for future in done:
                idx, result = future.result()
                for k, v in result.items():
                    out_df.at[idx, k] = v
                # Rows with an empty result stay pending and are retried on the next run
                if result:
                    completed_indices.add(int(idx))

                processed_since_checkpoint += 1
                if processed_since_checkpoint >= checkpoint_every:
                    _save_progress(out_df, output_path, checkpoint_path, completed_indices)
                    processed_since_checkpoint = 0
    except KeyboardInterrupt:
        print("[interrupt] received keyboard interrupt, saving checkpoint...")
        try:
            _save_progress(out_df, output_path, checkpoint_path, completed_indices)
        except Exception as exc:
            print(f"[interrupt] failed to save checkpoint: {type(exc).__name__}: {exc}")

        for future in futures:
            future.cancel()
        executor.shutdown(wait=False, cancel_futures=True)
        print("[interrupt] force exit to avoid blocking on running worker threads")
        os._exit(130)
    finally:
        try:
            executor.shutdown(wait=True, cancel_futures=False)
        except Exception:
            pass

    _save_progress(out_df, output_path, checkpoint_path, completed_indices)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Batch audio analysis from xlsx")
    parser.add_argument("--input", required=True, help="input xlsx path")
    parser.add_argument("--output", required=True, help="output xlsx path")
    parser.add_argument(
        "--checkpoint",
        default="",
        help="checkpoint json path (default: <output>.checkpoint.json)",
    )
    parser.add_argument("--url-column", default="URL", help="url column name")
    parser.add_argument("--provider", default="qwen", choices=["qwen", "doubao"])
    parser.add_argument("--extract-lyrics", action="store_true", help="enable lyrics extraction")
    parser.add_argument("--label-level", type=int, default=0, choices=[0, 1])
    parser.add_argument(
        "--metadata-columns",
        default="tmeID,歌曲名称,歌曲名,歌手,表演者,版本,词作者,曲作者",
        help="comma separated metadata columns",
    )
    parser.add_argument("--workers", type=int, default=3, help="parallel workers")
    parser.add_argument(
        "--checkpoint-every",
        type=int,
        default=10,
        help="save checkpoint every N processed rows",
    )
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="disable resume from existing output/checkpoint",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    metadata_columns = [c.strip() for c in args.metadata_columns.split(",") if c.strip()]

    run_batch(
        input_path=Path(args.input),
        output_path=Path(args.output),
        checkpoint_path=Path(args.checkpoint) if args.checkpoint.strip() else None,
        url_column=args.url_column,
        provider=args.provider,
        extract_lyrics=args.extract_lyrics,
        label_level=args.label_level,
        metadata_columns=metadata_columns,
        workers=args.workers,
        checkpoint_every=args.checkpoint_every,
        resume=not args.no_resume,
    )


if __name__ == "__main__":
    main()
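The core of `run_batch` is a bounded `wait(...)` loop that drains finished futures, marks only rows with a non-empty result as completed, and saves progress every `checkpoint_every` processed rows. The sketch below reproduces that control flow in isolation. `fake_analyze` is a hypothetical stand-in for `analyze_music` (odd rows "fail" by returning an empty result), and the `saves` list stands in for calls to `_save_progress`; neither is part of the project.

```python
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def fake_analyze(idx: int) -> tuple[int, dict]:
    # Hypothetical stand-in for analyze_music: odd rows return an empty result
    return idx, ({"语种": "zh"} if idx % 2 == 0 else {})


def run(indices, workers: int = 3, checkpoint_every: int = 2):
    completed: set[int] = set()
    saves: list[list[int]] = []  # stand-in for _save_progress calls
    processed = 0
    with ThreadPoolExecutor(max_workers=workers) as executor:
        pending = {executor.submit(fake_analyze, i) for i in indices}
        while pending:
            # The timeout keeps the main thread responsive while workers run
            done, pending = wait(pending, timeout=1.0, return_when=FIRST_COMPLETED)
            for future in done:
                idx, result = future.result()
                if result:
                    completed.add(idx)  # empty results stay pending for the next run
                processed += 1
                if processed >= checkpoint_every:
                    saves.append(sorted(completed))
                    processed = 0
    return completed, saves


completed, saves = run(range(6))
print(sorted(completed), len(saves))
```

Because failed rows are counted as processed but never added to `completed`, a later run with resume enabled picks them up again, which matches how the batch script treats rows whose analysis returned nothing.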
openai>=1.58.1
requests>=2.31.0
httpx>=0.28.1
python-dotenv>=1.0.1
pydantic-settings>=2.6.1

numpy>=1.24.0
scipy>=1.10.0
librosa>=0.10.2
soundfile>=0.12.1

pandas>=2.2.0
openpyxl>=3.1.2

# Optional: enable funasr backend in qwen_analyzer
# dashscope>=1.20.0