Commit 7bf71620f01eb8ff3bc8ab5cdd8d9832a9780575 by 沈秋雨

Initial commit
# Required for qwen
QWEN_API_KEY=
QWEN_DASHSCOPE_API_KEY=
QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
QWEN_MODEL=qwen3-omni-flash
QWEN_TIMEOUT=15
QWEN_LYRICS_TIMEOUT=90
QWEN_MAX_RETRIES=3
MUSIC_ANALYZE_LIGHT_MODE=true
MUSIC_DOWNLOAD_DIR=music
MUSIC_MAPPING_FILE=music/music_file_mapping.csv
# Optional song structure service
SONGFORMER_URL=
# Optional ASR backend for lyrics_only path
MUSIC_LYRICS_ASR_BACKEND=funasr
DASHSCOPE_FUNASR_MODEL=fun-asr
DASHSCOPE_BASE_HTTP_API_URL=https://dashscope.aliyuncs.com/api/v1
DASHSCOPE_ASR_POLL_INTERVAL=1
DASHSCOPE_ASR_POLL_TIMEOUT=120
DASHSCOPE_ASR_SUBMIT_URL=https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription
DASHSCOPE_ASR_MODEL=qwen3-asr-flash-filetrans
DASHSCOPE_TASK_STATUS_BASE_URL=https://dashscope.aliyuncs.com/api/v1/tasks
.DS_Store
# Python cache
__pycache__/
*.py[cod]
*.so
.pytest_cache/
.mypy_cache/
# Virtual env
.venv/
venv/
# Local env
.env
# Logs
logs/
*.log
# Runtime outputs
outputs/
music/
*.checkpoint.json
# Local test/sample data
*.xlsx
*.xls
*.csv
# Keep env template and source files
!.env.example
# music_analyze_v2
This project is a standalone pipeline that batch-runs audio tag analysis from Excel files.
The actual main flow:
1. Read the input `xlsx`
2. Take audio addresses from the specified URL column
3. Pass selected metadata through to the music analyzer
4. Call `app.middleware.music_analyze.analyze_music(...)`
5. Assemble the results into fixed delivery columns and continuously write them back to the output `xlsx`
6. Support resuming via the existing output file and checkpoint
The current batch entry point is [`pipeline/batch_analyze_xlsx.py`](/Users/sqy/Downloads/music_analyze_v2/pipeline/batch_analyze_xlsx.py)
## Current Status
- Directly runnable main entry point: [`pipeline/batch_analyze_xlsx.py`](/Users/sqy/Downloads/music_analyze_v2/pipeline/batch_analyze_xlsx.py)
- Current default analysis chain: `QwenAnalyzer`
- Currently usable provider: `qwen`
- Prompt source: [`app/prompts/step2_music_decode`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode)
- Output format: fixed delivery columns; the original input columns are not all preserved
Notes:
- The CLI still keeps a `--provider doubao` option, but the current [`factory.py`](/Users/sqy/Downloads/music_analyze_v2/app/middleware/music_analyze/factory.py) only instantiates `qwen`, so passing `doubao` fails at runtime.
- Everything below describes the code's actual current behavior, not its historical plans.
## Installation
```bash
python3.10 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
cp .env.example .env
```
## Environment Variables
The minimal required configuration is usually:
```env
QWEN_API_KEY=your_api_key
QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
QWEN_MODEL=qwen3-omni-flash
QWEN_TIMEOUT=15
QWEN_LYRICS_TIMEOUT=90
QWEN_MAX_RETRIES=3
```
The project also supports these optional capabilities:
- `QWEN_DASHSCOPE_API_KEY`: used by some DashScope/ASR paths
- `SONGFORMER_URL`: enables extra song-structure features
- `MUSIC_LYRICS_ASR_BACKEND` / `DASHSCOPE_*`: lyrics-extraction settings
- `OSS_*`: fallback upload to OSS when the audio file is too large
Configuration definitions live in [`app/core/config.py`](/Users/sqy/Downloads/music_analyze_v2/app/core/config.py)
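As a quick sanity check, the settings can be imported directly. A minimal sketch, assuming the `app/core/config.py` shown later in this repository (it defines a module-level `settings` object):
```python
# Minimal sketch: verify that the required Qwen settings load from .env.
# Assumes app/core/config.py defines `settings` as shown later in this repo.
from app.core.config import settings

assert settings.QWEN_API_KEY, "QWEN_API_KEY must be set in .env"
print(settings.QWEN_MODEL, settings.QWEN_BASE_URL, settings.QWEN_TIMEOUT)
```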
## Input Requirements
The input file must be `xlsx`.
At least one audio-address column is required. The script resolves the URL column in the following order (see the sketch at the end of this section):
- the explicitly passed `--url-column`
- `URL`
- `url`
- `cos访问地址`
- `cos_url`
- `audio_url`
If a row's URL is empty:
- no analysis is started
- the row is skipped outright
- on resume it is treated as already processed
Metadata is optional but recommended. The script looks for these fields first:
- `歌曲ID` / `song_id` / `id`
- `tmeid` / `tmeID` / `TMEID`
- `歌曲名` / `歌曲名称` / `title`
- `表演者` / `歌手` / `artist`
- `歌曲时长` / `duration`
By default these columns are additionally passed through to the model as metadata:
- `tmeID,歌曲名称,歌曲名,歌手,表演者,版本,词作者,曲作者`
This can be overridden with `--metadata-columns`.
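The URL-column fallback above can be summarized as a small helper. A hypothetical sketch (`resolve_url_column` and `URL_COLUMN_CANDIDATES` are illustrative names, not the script's actual internals):
```python
# Hypothetical sketch of the URL-column fallback order described above;
# the authoritative logic lives in pipeline/batch_analyze_xlsx.py.
from typing import Optional

import pandas as pd

URL_COLUMN_CANDIDATES = ["URL", "url", "cos访问地址", "cos_url", "audio_url"]

def resolve_url_column(df: pd.DataFrame, explicit: Optional[str] = None) -> str:
    """Return the first matching URL column, preferring --url-column."""
    candidates = ([explicit] if explicit else []) + URL_COLUMN_CANDIDATES
    for name in candidates:
        if name in df.columns:
            return name
    raise ValueError(f"no URL column found; tried {candidates}")
```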
## Quick Start
Regular batch run:
```bash
python pipeline/batch_analyze_xlsx.py \
--input 待分析.xlsx \
--output outputs/标签交付结果.xlsx \
--url-column URL \
--provider qwen \
--workers 3
```
With lyrics extraction:
```bash
python pipeline/batch_analyze_xlsx.py \
--input 待分析.xlsx \
--output outputs/标签交付结果.xlsx \
--url-column URL \
--provider qwen \
--workers 3 \
--extract-lyrics
```
Re-run from scratch, without reusing the previous output or checkpoint:
```bash
python pipeline/batch_analyze_xlsx.py \
--input 待分析.xlsx \
--output outputs/标签交付结果.xlsx \
--provider qwen \
--no-resume
```
## Command-Line Arguments
| Argument | Description | Current actual behavior |
|------|------|-------------|
| `--input` | input Excel path | required |
| `--output` | output Excel path | required |
| `--checkpoint` | checkpoint file path | defaults to `<output>.checkpoint.json` |
| `--url-column` | URL column name | defaults to `URL`; falls back automatically if the column is missing |
| `--provider` | analysis provider | accepts `qwen`/`doubao`, but only `qwen` should actually be used today |
| `--extract-lyrics` | whether to extract lyrics | when enabled, takes the lyrics-aware analysis path |
| `--label-level` | label level | `0` or `1` |
| `--metadata-columns` | extra columns passed through to the model | comma-separated |
| `--workers` | number of worker threads | defaults to `3` |
| `--checkpoint-every` | save once every N processed rows | defaults to `10` |
| `--no-resume` | disable resuming | off by default |
## Output Structure
The script emits a fixed delivery table, not a full write-back of "original input columns + analysis columns".
The current output columns are defined by `DEFAULT_OUTPUT_COLUMNS` in [`batch_analyze_xlsx.py`](/Users/sqy/Downloads/music_analyze_v2/pipeline/batch_analyze_xlsx.py):
- `tmeid`
- `歌曲ID`
- `歌曲名`
- `表演者`
- `歌曲时长`
- `表演者类型`
- `语种`
- `BPM速度`
- `情绪`
- `网络/抖音歌曲`
- `音乐风格`
- `配器`
- `场景`
Result-field mapping rules:
- `表演者类型` <- `performer_type` or `vocal_texture`
- `语种` <- `language`
- `BPM速度` <- `bpm`
- `情绪` <- `emotion`
- `网络/抖音歌曲` <- `douyin_tags`
- `音乐风格` <- `music_style_tags`, falling back to `genre`/`sub_genre`
- `配器` <- `instrument_tags`
- `场景` <- `scene`
List-valued fields are joined into a `、`-separated string, as sketched below.
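A minimal sketch of this mapping, assuming the analyzer result dict uses the field names listed above (`to_delivery_row` is an illustrative name; the real implementation in `batch_analyze_xlsx.py` may differ):
```python
# Hypothetical sketch of the delivery-column mapping described above.
from typing import Any, Dict

def to_delivery_row(analysis: Dict[str, Any]) -> Dict[str, str]:
    def join(value: Any) -> str:
        # List-valued fields are joined into a 、-separated string.
        return "、".join(value) if isinstance(value, list) else str(value or "")

    # 音乐风格 falls back to genre/sub_genre when music_style_tags is empty.
    style = analysis.get("music_style_tags") or [
        v for v in (analysis.get("genre"), analysis.get("sub_genre")) if v
    ]
    return {
        "表演者类型": analysis.get("performer_type") or analysis.get("vocal_texture", ""),
        "语种": analysis.get("language", ""),
        "BPM速度": str(analysis.get("bpm") or ""),
        "情绪": join(analysis.get("emotion", [])),
        "网络/抖音歌曲": join(analysis.get("douyin_tags", [])),
        "音乐风格": join(style),
        "配器": join(analysis.get("instrument_tags", [])),
        "场景": join(analysis.get("scene", [])),
    }
```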
## Resuming
The current resume logic is more specific than older README wording; the actual behavior is:
- If the output file exists and its row count matches this run's input:
  the previous output is reused directly by row index
- If the output file exists but the row count differs:
  it tries to reuse old results by `歌曲ID` / `tmeid`
- If a checkpoint exists:
  its completion state is merged, provided the output aligns by index
- Rows with an empty URL go straight into the completed set
- While processing, progress is flushed every `--checkpoint-every` rows
- On `Ctrl+C`, current progress is saved first, then the process force-exits to avoid hung threads
A sketch of the reuse decision follows the block below. Default checkpoint filename:
```text
<output>.checkpoint.json
```
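A minimal sketch of the reuse decision, assuming the previous output and the current input are loaded as DataFrames (`plan_reuse` is an illustrative helper, not the script's actual function):
```python
# Hypothetical sketch of the resume logic described above.
from typing import Dict

import pandas as pd

def plan_reuse(old: pd.DataFrame, new: pd.DataFrame,
               id_cols=("歌曲ID", "tmeid")) -> Dict[int, int]:
    """Map new-row position -> old-row position for reusable results."""
    if len(old) == len(new):
        # Same row count: trust positional alignment.
        return {i: i for i in range(len(new))}
    # Row counts differ: fall back to ID-based matching.
    key = next((c for c in id_cols if c in old.columns and c in new.columns), None)
    if key is None:
        return {}
    old_by_id = {v: i for i, v in enumerate(old[key]) if pd.notna(v)}
    return {
        i: old_by_id[v]
        for i, v in enumerate(new[key])
        if pd.notna(v) and v in old_by_id
    }
```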
## Prompts and Analysis Chain
The batch script does not read prompt files directly; it goes through the unified analysis entry point:
[`pipeline/batch_analyze_xlsx.py`](/Users/sqy/Downloads/music_analyze_v2/pipeline/batch_analyze_xlsx.py)
-> [`app/middleware/music_analyze/__init__.py`](/Users/sqy/Downloads/music_analyze_v2/app/middleware/music_analyze/__init__.py)
-> [`app/middleware/music_analyze/music_analyzer.py`](/Users/sqy/Downloads/music_analyze_v2/app/middleware/music_analyze/music_analyzer.py)
-> [`app/middleware/music_analyze/factory.py`](/Users/sqy/Downloads/music_analyze_v2/app/middleware/music_analyze/factory.py)
-> [`app/middleware/music_analyze/qwen_analyzer.py`](/Users/sqy/Downloads/music_analyze_v2/app/middleware/music_analyze/qwen_analyzer.py)
-> [`app/middleware/music_analyze/prompts.py`](/Users/sqy/Downloads/music_analyze_v2/app/middleware/music_analyze/prompts.py)
The prompt directory is currently fixed to:
- [`music_analyze_system_prompt.md`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode/music_analyze_system_prompt.md)
- [`music_analyze_system_prompt_part_a.md`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode/music_analyze_system_prompt_part_a.md)
- [`music_analyze_system_prompt_part_b.md`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode/music_analyze_system_prompt_part_b.md)
- [`music_analyze_user_prompt.md`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode/music_analyze_user_prompt.md)
- [`music_lyrics_only_prompt.md`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode/music_lyrics_only_prompt.md)
## Project Structure
```text
music_analyze_v2/
├── app/
│ ├── core/
│ │ └── config.py
│ ├── middleware/
│ │ └── music_analyze/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── factory.py
│ │ ├── music_analyzer.py
│ │ ├── prompts.py
│ │ ├── qwen_analyzer.py
│ │ ├── doubao_analyzer.py
│ │ ├── audio_features.py
│ │ └── bpm_analyzer_tools.py
│ ├── prompts/
│ │ └── step2_music_decode/
│ └── utils/
├── pipeline/
│ └── batch_analyze_xlsx.py
├── outputs/
├── requirements.txt
├── .env
├── .env.example
└── README.md
```
## Dependencies
Base dependencies are listed in [`requirements.txt`](/Users/sqy/Downloads/music_analyze_v2/requirements.txt)
Currently declared explicitly:
- `openai`
- `requests`
- `httpx`
- `python-dotenv`
- `pydantic-settings`
- `numpy`
- `scipy`
- `librosa`
- `soundfile`
- `pandas`
- `openpyxl`
`dashscope` is still commented out in `requirements.txt`; if you want to run the lyrics paths that depend on that SDK, install it yourself and verify the corresponding code branch.
## FAQ
### Why does `--provider doubao` still fail?
The CLI still exposes the `doubao` option, but the analyzer factory only supports `qwen`. This is the current state of the code, not a usage mistake.
### Why doesn't the output keep all columns of the original Excel?
The script only writes `DEFAULT_OUTPUT_COLUMNS` when saving; that is fixed behavior in the code.
### Where do I edit the prompts?
Edit the template files under [`app/prompts/step2_music_decode`](/Users/sqy/Downloads/music_analyze_v2/app/prompts/step2_music_decode).
### Can I still resume after the row count changed?
Partially. The script tries to match previous output by `歌曲ID` / `tmeid`.
### How do I do a completely fresh run?
Use `--no-resume` and delete the old output and old checkpoint; that is the cleanest option.
"""Standalone audio analysis package."""
from .config import settings
__all__ = ["settings"]
"""Minimal settings for standalone audio analysis pipeline."""
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
# Qwen
QWEN_API_KEY: str | None = None
QWEN_DASHSCOPE_API_KEY: str | None = None
QWEN_BASE_URL: str | None = "https://dashscope.aliyuncs.com/compatible-mode/v1"
QWEN_MODEL: str | None = "qwen3-omni-flash"
QWEN_TIMEOUT: float = 15.0
QWEN_LYRICS_TIMEOUT: float = 90.0
QWEN_MAX_RETRIES: int = 3
MUSIC_ANALYZE_LIGHT_MODE: bool = True
MUSIC_DOWNLOAD_DIR: str = "music"
MUSIC_MAPPING_FILE: str = "music/music_file_mapping.csv"
# Optional features
SONGFORMER_URL: str | None = None
# DashScope ASR
DASHSCOPE_FUNASR_MODEL: str = "fun-asr"
DASHSCOPE_BASE_HTTP_API_URL: str = "https://dashscope.aliyuncs.com/api/v1"
DASHSCOPE_ASR_POLL_INTERVAL: float = 1.0
DASHSCOPE_ASR_POLL_TIMEOUT: float = 120.0
DASHSCOPE_ASR_SUBMIT_URL: str = (
"https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription"
)
DASHSCOPE_ASR_MODEL: str = "qwen3-asr-flash-filetrans"
DASHSCOPE_TASK_STATUS_BASE_URL: str = "https://dashscope.aliyuncs.com/api/v1/tasks"
# OSS
OSS_ACCESS_KEY_ID: str | None = None
OSS_ACCESS_KEY_SECRET: str | None = None
OSS_ENDPOINT: str | None = None
OSS_BUCKET_NAME: str | None = None
OSS_ENDPOINT_INTERNAL: str | None = None
settings = Settings()
"""
自定义异常定义
所有业务异常都应该继承自 APIException,
由全局异常处理器统一处理并返回标准格式的错误响应
"""
from fastapi import HTTPException, status
from typing import Optional, Any
class APIException(HTTPException):
"""
API基础异常
所有业务异常的基类,可以被全局异常处理器捕获和统一处理
"""
    def __init__(
        self,
        status_code: int = status.HTTP_400_BAD_REQUEST,
        detail: Optional[str] = None,
        error_code: Optional[str] = None,
        data: Any = None,
        headers: Optional[dict] = None,
    ):
super().__init__(status_code=status_code, detail=detail, headers=headers)
self.error_code = error_code or "UNKNOWN_ERROR"
self.data = data
class UnauthorizedException(APIException):
"""未授权异常 - 认证失败"""
def __init__(self, detail: str = "未授权", error_code: str = "UNAUTHORIZED"):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
error_code=error_code
)
class ForbiddenException(APIException):
"""禁止访问异常 - 权限不足"""
def __init__(self, detail: str = "禁止访问", error_code: str = "FORBIDDEN"):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
detail=detail,
error_code=error_code
)
class NotFoundException(APIException):
"""资源不存在异常"""
def __init__(self, detail: str = "资源不存在", error_code: str = "NOT_FOUND"):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
error_code=error_code
)
class ConflictException(APIException):
"""冲突异常 - 资源已存在"""
def __init__(self, detail: str = "资源已存在", error_code: str = "CONFLICT"):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail=detail,
error_code=error_code
)
class ValidationException(APIException):
"""验证异常 - 输入验证失败"""
def __init__(self, detail: str = "验证失败", error_code: str = "VALIDATION_ERROR"):
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=detail,
error_code=error_code
)
class BusinessException(APIException):
"""业务异常 - 业务规则验证失败"""
def __init__(
self,
detail: str = "业务操作失败",
error_code: str = "BUSINESS_ERROR",
status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
):
super().__init__(
status_code=status_code,
detail=detail,
error_code=error_code
)
class InternalServerException(APIException):
"""内部服务器异常"""
def __init__(
self,
detail: str = "内部服务器错误",
error_code: str = "INTERNAL_SERVER_ERROR",
):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail,
error_code=error_code
)
class DatabaseException(APIException):
"""数据库异常"""
def __init__(
self,
detail: str = "数据库操作失败",
error_code: str = "DATABASE_ERROR",
):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail,
error_code=error_code
)
class ExternalServiceException(APIException):
"""外部服务异常 - 调用第三方服务失败"""
def __init__(
self,
detail: str = "外部服务调用失败",
error_code: str = "EXTERNAL_SERVICE_ERROR",
):
super().__init__(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=detail,
error_code=error_code
)
class RateLimitException(APIException):
"""限流异常 - 请求过于频繁"""
def __init__(
self,
detail: str = "请求过于频繁,请稍后再试",
error_code: str = "RATE_LIMIT_EXCEEDED",
):
super().__init__(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=detail,
error_code=error_code
)
"""Middleware package."""
"""
音乐分析模块
提供统一的音乐标签分析功能,支持通义千问和火山引擎豆包
主要功能:
- 音乐风格识别(与国际音乐分类体系对齐)
- 情绪识别
- 人声质感识别
- 语种识别
- 节奏强度分析(1-5,用于指导视频剪辑)
- 高潮点识别
- 视觉概念生成(用于MV创作)
- 歌词识别(可选)
支持的提供商:
- qwen: 通义千问 (qwen3-omni-flash)
- doubao: 火山引擎豆包 (doubao-seed-1-8-251228)
使用示例:
from app.middleware.music_analyze import analyze_music
# 基本分析
result = analyze_music(
metadata={"title": "稻香", "artist": "周杰伦"},
music_url="https://example.com/music.mp3",
provider="qwen",
)
# 含歌词识别
result = analyze_music(
metadata={"title": "稻香"},
music_url="https://example.com/music.mp3",
provider="qwen",
extract_lyrics=True,
)
"""
# Main function exports
from .music_analyzer import (
analyze_music,
analyze_music_lyrics_only,
analyze_music_with_qwen,
analyze_music_with_doubao,
get_available_providers,
)
# Class exports
from .base import AudioAnalyzer
from .qwen_analyzer import QwenAnalyzer
from .doubao_analyzer import DoubaoAnalyzer
from .factory import AnalyzerFactory
__version__ = "1.0.0"
"""
音频特征提取模块
提供音频特征提取、节奏强度和能量级别计算功能
"""
import os
import warnings
import numpy as np
import librosa
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass
from .bpm_analyzer_tools import RealtimeBPMAnalyzerTest
# Suppress librosa's audioread deprecation warning
warnings.filterwarnings("ignore", category=FutureWarning, module="librosa")
@dataclass
class AudioFeatures:
"""音频特征数据"""
# 时域特征
rms_energy: np.ndarray # RMS 能量 (帧级别)
rms_times: np.ndarray # 对应的时间戳
# 频域特征
spectral_centroid: np.ndarray # 频谱质心 (亮度)
spectral_rolloff: np.ndarray # 频谱滚降 (低频占比)
spectral_bandwidth: np.ndarray # 频谱带宽
# 节奏特征
onset_strength: np.ndarray # onset 强度
tempo: float # BPM
# 统计信息
duration: float
sr: int
def extract_audio_features(audio_path: str, hop_length: int = 512) -> AudioFeatures:
"""
提取音频特征
Args:
audio_path: 音频文件路径
hop_length: 帧移长度 (默认 512 samples ≈ 11.6ms @ 44.1kHz)
Returns:
AudioFeatures: 音频特征对象
"""
    # Load the audio
y, sr = librosa.load(audio_path, sr=None, mono=True)
duration = librosa.get_duration(y=y, sr=sr)
    # 1. RMS energy (time-domain loudness)
rms = librosa.feature.rms(y=y, hop_length=hop_length)[0]
rms_db = librosa.amplitude_to_db(rms, ref=np.max)
rms_times = librosa.frames_to_time(
np.arange(len(rms)), sr=sr, hop_length=hop_length
)
    # 2. Spectral features
spectral_centroid = librosa.feature.spectral_centroid(
y=y, sr=sr, hop_length=hop_length
)[0]
spectral_rolloff = librosa.feature.spectral_rolloff(
y=y, sr=sr, hop_length=hop_length
)[0]
spectral_bandwidth = librosa.feature.spectral_bandwidth(
y=y, sr=sr, hop_length=hop_length
)[0]
    # 3. Rhythm features
onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
    # Use the unified BPM entry point (with octave-error correction)
bpm_analyzer = RealtimeBPMAnalyzerTest(verbose=False)
bpm_result = bpm_analyzer.analyze_bpm(y=y, sr=sr)
corrected_tempo = bpm_result.get('bpm', 120.0)
return AudioFeatures(
rms_energy=rms_db,
rms_times=rms_times,
spectral_centroid=spectral_centroid,
spectral_rolloff=spectral_rolloff,
spectral_bandwidth=spectral_bandwidth,
onset_strength=onset_env,
tempo=corrected_tempo,
duration=duration,
sr=int(sr),
)
def calculate_rhythm_intensity(features: AudioFeatures) -> int:
"""
根据音频特征计算节奏强度 (1-5)
基于以下因素综合计算:
- BPM (速度)
- Onset 强度 (节奏密度)
- 能量变化 (动态范围)
Args:
features: 音频特征对象
Returns:
int: 节奏强度 (1-5)
"""
tempo = features.tempo
onset = features.onset_strength
rms = features.rms_energy
    # 1. BPM score (40-200 BPM mapped to 1-5)
if tempo >= 160:
tempo_score = 5
elif tempo >= 130:
tempo_score = 4
elif tempo >= 100:
tempo_score = 3
elif tempo >= 70:
tempo_score = 2
else:
tempo_score = 1
    # 2. Onset density score
onset_mean = np.mean(onset)
onset_max = np.max(onset) if len(onset) > 0 else 1
onset_density = onset_mean / onset_max if onset_max > 0 else 0
if onset_density >= 0.5:
density_score = 5
elif onset_density >= 0.4:
density_score = 4
elif onset_density >= 0.3:
density_score = 3
elif onset_density >= 0.2:
density_score = 2
else:
density_score = 1
    # 3. Energy dynamics score
rms_std = np.std(rms)
if rms_std >= 15:
dynamic_score = 5
elif rms_std >= 12:
dynamic_score = 4
elif rms_std >= 9:
dynamic_score = 3
elif rms_std >= 6:
dynamic_score = 2
else:
dynamic_score = 1
    # Weighted average (BPM 40%, density 35%, dynamics 25%)
final_score = tempo_score * 0.4 + density_score * 0.35 + dynamic_score * 0.25
return int(round(final_score))
def calculate_energy_level(
features: AudioFeatures,
) -> Tuple[int, Dict[str, float]]:
"""
计算能量级别 (1-5) 和详细信息
Args:
features: 音频特征对象
Returns:
Tuple[int, Dict]: (能量级别 1-5, 详细信息字典)
"""
    # 1. Loudness score (based on RMS energy)
rms_db = features.rms_energy
loudness_normalized = np.clip((rms_db + 60) / 10, 0, 5)
loudness_score = float(np.percentile(loudness_normalized, 75))
    # 2. Brightness score (based on spectral centroid)
centroid = features.spectral_centroid
centroid_normalized = np.clip(centroid / 4000, 0, 1)
brightness_score = float(np.mean(centroid_normalized)) * 5
    # 3. Rhythm score (based on onset strength)
onset = features.onset_strength
onset_normalized = np.clip(onset / np.percentile(onset, 90), 0, 1)
rhythm_score = float(np.mean(onset_normalized)) * 5
    # 4. BPM factor
tempo = features.tempo
if tempo > 140:
tempo_factor = 1.3
elif tempo > 120:
tempo_factor = 1.15
elif tempo > 100:
tempo_factor = 1.0
elif tempo > 80:
tempo_factor = 0.9
else:
tempo_factor = 0.8
    # Composite score
weights = {"loudness": 0.40, "brightness": 0.25, "rhythm": 0.35}
composite_score = (
weights["loudness"] * loudness_score
+ weights["brightness"] * brightness_score
+ weights["rhythm"] * rhythm_score
) * tempo_factor
    # Map to levels 1-5
if composite_score < 1.5:
energy_level = 1
elif composite_score < 2.5:
energy_level = 2
elif composite_score < 3.5:
energy_level = 3
elif composite_score < 4.5:
energy_level = 4
else:
energy_level = 5
details = {
"loudness_score": round(loudness_score, 2),
"brightness_score": round(brightness_score, 2),
"rhythm_score": round(rhythm_score, 2),
"tempo_factor": tempo_factor,
"composite_score": round(composite_score, 2),
}
return energy_level, details
def energy_level_to_string(level: int) -> str:
"""
将能量级别数字转换为字符串描述
Args:
level: 能量级别 (1-5)
Returns:
str: 能量密度描述
"""
mapping = {
1: "舒缓",
2: "柔和",
3: "律动",
4: "强烈",
5: "爆发",
}
return mapping.get(level, "律动")
@dataclass
class BeatInfo:
"""节拍信息"""
beat_timestamps: List[float] # 所有节拍时间点
downbeat_timestamps: List[float] # 强拍时间点(每小节第一拍)
tempo: float # BPM
beat_intervals: List[float] # 节拍间隔(用于检测节奏变化)
@dataclass
class EmotionCurve:
"""情绪曲线数据"""
timestamps: List[float] # 时间点
energy_values: List[float] # 能量值 (0-1)
valence_values: List[float] # 情绪效价 (0-1, 低=悲伤, 高=欢快)
arousal_values: List[float] # 情绪唤醒度 (0-1, 低=平静, 高=激动)
smoothed_curve: List[float] # 平滑后的综合情绪曲线
@dataclass
class SegmentEmotion:
"""段落情绪数据(与 songformer 段落对齐)"""
start: float # 段落开始时间
end: float # 段落结束时间
label: str # 段落标签 (intro/verse/chorus/bridge/outro)
intensity: float # 情绪强度 (0-1)
energy: float # 能量值 (0-1)
valence: float # 效价值 (0-1)
arousal: float # 唤醒度 (0-1)
trend: str # 情绪趋势 (rising/falling/stable/peak)
@dataclass
class BeatDensityInfo:
"""节拍密度信息(用于分镜时长规划)"""
segment_label: str # 段落标签
start: float # 开始时间
end: float # 结束时间
beat_count: int # 节拍数
avg_interval: float # 平均间隔(秒)
density_level: str # sparse/normal/dense/very_dense
recommended_shot_duration: str # 推荐分镜时长
@dataclass
class EnhancedClimaxInfo:
"""增强高潮点信息(包含铺垫/持续/缓冲时长)"""
time: float # 高潮时间点
intensity: str # strong/strongest
buildup_start: float # 铺垫开始时间
buildup_duration: float # 铺垫时长(秒)
climax_duration: float # 高潮持续时长(秒)
winddown_duration: float # 缓冲时长(秒)
def extract_beat_timestamps(audio_path: str) -> BeatInfo:
"""
提取节拍时间戳(卡点)
使用智能 BPM 检测(带倍频纠正)
Args:
audio_path: 音频文件路径
Returns:
BeatInfo: 节拍信息对象
"""
y, sr = librosa.load(audio_path, sr=22050, mono=True)
    # Use the unified BPM entry point (octave correction + beat_times)
bpm_analyzer = RealtimeBPMAnalyzerTest(verbose=False)
bpm_result = bpm_analyzer.analyze_bpm(y=y, sr=sr)
corrected_tempo = bpm_result.get('bpm', 120.0)
    # beat_times have already been subsampled by analyze_bpm when the BPM was halved
beat_times = np.array(bpm_result.get('beat_times', []))
    # Downbeat detection (every 4th beat, assuming 4/4 time)
downbeat_times = beat_times[::4].tolist() if len(beat_times) > 0 else []
    # Compute beat intervals
beat_intervals = np.diff(beat_times).tolist() if len(beat_times) > 1 else []
return BeatInfo(
beat_timestamps=beat_times.tolist(),
downbeat_timestamps=downbeat_times,
tempo=corrected_tempo,
beat_intervals=beat_intervals,
)
def extract_emotion_curve(
audio_path: str,
    window_size: float = 2.0,  # window size (seconds)
    hop_size: float = 0.5  # hop size (seconds)
) -> EmotionCurve:
"""
提取情绪曲线
基于音频特征推断情绪:
- Energy (能量): RMS 能量 → 情绪强度
- Valence (效价): 频谱质心 + 大小调 → 正面/负面情绪
- Arousal (唤醒度): 节奏密度 + 能量变化 → 激动/平静
Args:
audio_path: 音频文件路径
window_size: 滑动窗口大小(秒)
hop_size: 滑动步长(秒)
Returns:
EmotionCurve: 情绪曲线数据对象
"""
y, sr = librosa.load(audio_path, sr=None, mono=True)
timestamps: List[float] = []
energy_values: List[float] = []
valence_values: List[float] = []
arousal_values: List[float] = []
    # Sliding-window analysis
window_samples = int(window_size * sr)
hop_samples = int(hop_size * sr)
for start_sample in range(0, len(y) - window_samples, hop_samples):
end_sample = start_sample + window_samples
y_window = y[start_sample:end_sample]
t = start_sample / sr
timestamps.append(t)
        # 1. Energy: normalized RMS energy
rms = np.sqrt(np.mean(y_window ** 2))
        energy = min(rms / 0.1, 1.0)  # normalize to 0-1
energy_values.append(float(energy))
        # 2. Valence: based on spectral centroid (high = bright = positive)
centroid = librosa.feature.spectral_centroid(y=y_window, sr=sr)[0]
valence = min(np.mean(centroid) / 4000, 1.0)
valence_values.append(float(valence))
        # 3. Arousal: based on onset density and energy variation
onset_env = librosa.onset.onset_strength(y=y_window, sr=sr)
arousal = min(np.mean(onset_env) / 2.0, 1.0)
arousal_values.append(float(arousal))
    # 4. Composite emotion curve (weighted average)
smoothed: List[float] = []
for i in range(len(timestamps)):
        # Weights: energy 40%, arousal 40%, valence 20%
combined = (
energy_values[i] * 0.4 +
arousal_values[i] * 0.4 +
valence_values[i] * 0.2
)
smoothed.append(combined)
    # Smoothing (moving average)
if len(smoothed) >= 3:
smoothed = np.convolve(smoothed, np.ones(3)/3, mode='same').tolist()
return EmotionCurve(
timestamps=timestamps,
energy_values=energy_values,
valence_values=valence_values,
arousal_values=arousal_values,
smoothed_curve=smoothed,
)
def aggregate_emotion_by_segments(
emotion_curve: EmotionCurve,
segments: List[Dict[str, Any]],
) -> List[SegmentEmotion]:
"""
将情绪曲线按 songformer 段落结构聚合
Args:
emotion_curve: 原始情绪曲线数据
segments: songformer 返回的段落列表,格式为:
[{"start": 0.0, "end": 30.5, "label": "intro"}, ...]
Returns:
List[SegmentEmotion]: 按段落聚合的情绪数据
"""
if not segments or not emotion_curve.timestamps:
return []
result: List[SegmentEmotion] = []
timestamps = np.array(emotion_curve.timestamps)
energy_values = np.array(emotion_curve.energy_values)
valence_values = np.array(emotion_curve.valence_values)
arousal_values = np.array(emotion_curve.arousal_values)
smoothed_values = np.array(emotion_curve.smoothed_curve)
for seg in segments:
start = float(seg.get("start", 0))
end = float(seg.get("end", 0))
label = str(seg.get("label", "unknown"))
        # Find the indices of data points inside this segment
mask = (timestamps >= start) & (timestamps < end)
indices = np.where(mask)[0]
if len(indices) == 0:
            # No data points fall inside this segment; use defaults
result.append(SegmentEmotion(
start=start,
end=end,
label=label,
intensity=0.5,
energy=0.5,
valence=0.5,
arousal=0.5,
trend="stable",
))
continue
        # Compute segment averages
seg_energy = float(np.mean(energy_values[indices]))
seg_valence = float(np.mean(valence_values[indices]))
seg_arousal = float(np.mean(arousal_values[indices]))
seg_intensity = float(np.mean(smoothed_values[indices]))
        # Compute the emotional trend
seg_smoothed = smoothed_values[indices]
trend = _calculate_trend(seg_smoothed, seg_intensity)
result.append(SegmentEmotion(
start=start,
end=end,
label=label,
intensity=round(seg_intensity, 3),
energy=round(seg_energy, 3),
valence=round(seg_valence, 3),
arousal=round(seg_arousal, 3),
trend=trend,
))
return result
def _calculate_trend(values: np.ndarray, avg_intensity: float) -> str:
"""
计算情绪趋势
Args:
values: 该段落内的情绪值数组
avg_intensity: 平均情绪强度
Returns:
str: rising/falling/stable/peak
"""
if len(values) < 3:
return "stable"
    # Split the segment into first and second halves
mid = len(values) // 2
first_half_avg = float(np.mean(values[:mid]))
second_half_avg = float(np.mean(values[mid:]))
diff = second_half_avg - first_half_avg
    threshold = 0.05  # 5% change threshold
    # Check for a peak (high average intensity with little change)
if avg_intensity > 0.7 and abs(diff) < threshold:
return "peak"
if diff > threshold:
return "rising"
elif diff < -threshold:
return "falling"
else:
return "stable"
def extract_segment_emotions(
audio_path: str,
segments: List[Dict[str, Any]],
) -> List[SegmentEmotion]:
"""
一站式提取按段落聚合的情绪数据
Args:
audio_path: 音频文件路径
segments: songformer 返回的段落列表
Returns:
List[SegmentEmotion]: 按段落聚合的情绪数据
"""
emotion_curve = extract_emotion_curve(audio_path)
return aggregate_emotion_by_segments(emotion_curve, segments)
def calculate_beat_density_by_segments(
beat_timestamps: List[float],
segments: List[Dict[str, Any]],
tempo: float = 120.0,
) -> List[BeatDensityInfo]:
"""
按段落计算节拍密度,用于指导分镜时长规划
Args:
beat_timestamps: 节拍时间戳列表
segments: songformer 返回的段落列表,格式为:
[{"start": 0.0, "end": 30.5, "label": "intro"}, ...]
tempo: BPM(用于辅助判断密度级别)
Returns:
List[BeatDensityInfo]: 按段落的节拍密度信息
"""
if not segments or not beat_timestamps:
return []
result: List[BeatDensityInfo] = []
beat_array = np.array(beat_timestamps)
for seg in segments:
start = float(seg.get("start", 0))
end = float(seg.get("end", 0))
label = str(seg.get("label", "unknown"))
        # Find the beats inside this segment
mask = (beat_array >= start) & (beat_array < end)
segment_beats = beat_array[mask]
beat_count = len(segment_beats)
        # Compute the average interval
if beat_count >= 2:
intervals = np.diff(segment_beats)
avg_interval = float(np.mean(intervals))
elif beat_count == 1:
            # Only one beat: estimate from the BPM
avg_interval = 60.0 / tempo
else:
            # No beats: use the default
avg_interval = 60.0 / tempo
        # Decide the density level from the average interval and BPM
        # smaller interval = higher density
if avg_interval <= 0.3 or tempo >= 160:
density_level = "very_dense"
recommended_shot_duration = "2-4秒"
elif avg_interval <= 0.45 or tempo >= 130:
density_level = "dense"
recommended_shot_duration = "3-5秒"
elif avg_interval <= 0.6 or tempo >= 100:
density_level = "normal"
recommended_shot_duration = "4-6秒"
else:
density_level = "sparse"
recommended_shot_duration = "6-10秒"
result.append(BeatDensityInfo(
segment_label=label,
start=round(start, 2),
end=round(end, 2),
beat_count=beat_count,
avg_interval=round(avg_interval, 3),
density_level=density_level,
recommended_shot_duration=recommended_shot_duration,
))
return result
def enhance_climax_points(
climax_points: List[Dict[str, Any]],
segments: List[Dict[str, Any]],
music_duration: float,
) -> List[EnhancedClimaxInfo]:
"""
增强高潮点信息,添加铺垫/持续/缓冲时长指导
Args:
climax_points: 原始高潮点列表,格式为:
[{"time": 60.0, "intensity": "strong"}, ...]
segments: songformer 返回的段落列表
music_duration: 音乐总时长(秒)
Returns:
List[EnhancedClimaxInfo]: 增强后的高潮点信息
"""
if not climax_points:
return []
result: List[EnhancedClimaxInfo] = []
    # Sort the climax points by time
sorted_climax = sorted(climax_points, key=lambda x: float(x.get("time", 0)))
for i, climax in enumerate(sorted_climax):
time = float(climax.get("time", 0))
intensity = str(climax.get("intensity", "strong"))
        # Choose duration parameters by intensity
if intensity == "strongest":
            buildup_duration = 10.0  # strongest climax: longer buildup
            climax_duration = 20.0  # longer climax sustain
            winddown_duration = 10.0  # longer wind-down
else:
            buildup_duration = 5.0  # regular climax
climax_duration = 10.0
winddown_duration = 5.0
        # Compute the buildup start (not before 0 or the previous climax's end)
buildup_start = max(0, time - buildup_duration)
        # If there is a previous climax point, avoid overlap
if i > 0:
prev_climax_time = float(sorted_climax[i - 1].get("time", 0))
prev_intensity = str(sorted_climax[i - 1].get("intensity", "strong"))
prev_winddown = 10.0 if prev_intensity == "strongest" else 5.0
prev_end = prev_climax_time + prev_winddown
if buildup_start < prev_end:
                # Shift the buildup start to avoid overlap
buildup_start = prev_end
buildup_duration = time - buildup_start
        # Make sure sustain + wind-down do not run past the end of the music
if time + climax_duration + winddown_duration > music_duration:
            # Scale down proportionally
remaining = music_duration - time
if remaining > 0:
ratio = remaining / (climax_duration + winddown_duration)
climax_duration = climax_duration * ratio
winddown_duration = winddown_duration * ratio
result.append(EnhancedClimaxInfo(
time=round(time, 2),
intensity=intensity,
buildup_start=round(buildup_start, 2),
buildup_duration=round(buildup_duration, 2),
climax_duration=round(climax_duration, 2),
winddown_duration=round(winddown_duration, 2),
))
return result
def format_beat_density_for_prompt(beat_density_list: List[BeatDensityInfo]) -> str:
"""
将节拍密度信息格式化为提示词文本
Args:
beat_density_list: 节拍密度信息列表
Returns:
str: 格式化的文本
"""
if not beat_density_list:
return "(无节拍密度数据)"
lines = []
for info in beat_density_list:
lines.append(
f"- [{info.segment_label}] {info.start:.1f}s-{info.end:.1f}s: "
f"节拍数={info.beat_count}, 平均间隔={info.avg_interval:.2f}s, "
f"密度={info.density_level}, 推荐分镜时长={info.recommended_shot_duration}"
)
return "\n".join(lines)
def format_enhanced_climax_for_prompt(enhanced_climax_list: List[EnhancedClimaxInfo]) -> str:
"""
将增强高潮点信息格式化为提示词文本
Args:
enhanced_climax_list: 增强高潮点信息列表
Returns:
str: 格式化的文本
"""
if not enhanced_climax_list:
return "(无高潮点数据)"
lines = []
for info in enhanced_climax_list:
lines.append(
f"- 高潮点 {info.time:.1f}s ({info.intensity}):\n"
f" · 铺垫阶段: {info.buildup_start:.1f}s - {info.time:.1f}s (约{info.buildup_duration:.1f}秒)\n"
f" · 高潮阶段: {info.time:.1f}s - {info.time + info.climax_duration:.1f}s (约{info.climax_duration:.1f}秒)\n"
f" · 缓冲阶段: {info.time + info.climax_duration:.1f}s - {info.time + info.climax_duration + info.winddown_duration:.1f}s (约{info.winddown_duration:.1f}秒)"
)
return "\n".join(lines)
# -*- coding: utf-8 -*-
"""
音乐分析器抽象基类
定义统一的分析器接口
"""
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any, List, Set
# Dictionaries: all valid field values
VALID_GENRES: Set[str] = {
"流行",
"电子/舞曲",
"摇滚/金属",
"说唱",
"民谣/原声",
"国风",
"爵士/Soul",
"古典",
"轻音乐/Ambient",
"二次元/ACG",
"其它",
}
VALID_SUB_GENRES: Dict[str, Set[str]] = {
"流行": {"华语流行", "欧美流行", "日韩流行", "R&B", "抒情"},
"电子/舞曲": {"House", "Future Bass", "Dubstep", "Synthwave", "Trance", "Techno"},
"摇滚/金属": {"流行摇滚", "独立摇滚", "重金属", "朋克", "后摇"},
"说唱": {"Trap", "Old School", "Boombap", "Melodic Rap", "中文说唱"},
"民谣/原声": {"城市民谣", "校园民谣", "故事民谣", "乡村", "Indie Folk"},
"国风": {"古风", "戏腔", "新中式", "水墨风", "国潮"},
"爵士/Soul": {"传统爵士", "Smooth Jazz", "Fusion", "Neo-Soul", "Blues"},
"古典": {"管弦乐", "钢琴曲", "协奏曲", "室内乐", "歌剧"},
"轻音乐/Ambient": {"钢琴独奏", "Lo-fi", "冥想音乐", "氛围电子", "白噪音"},
"二次元/ACG": {"动画OST", "Vocaloid", "游戏音乐", "萌系", "燃系"},
"其它": {"世界音乐", "实验音乐", "儿歌", "戏曲", "网络热歌"},
}
VALID_LANGUAGES: Set[str] = {
"普通话",
"粤语",
"英语",
"韩语",
"闽南语",
"蒙语",
"俄语",
"藏语",
"其他",
}
LANGUAGE_MAPPING: Dict[str, str] = {
"国语": "普通话",
"中文": "普通话",
"汉语": "普通话",
"普通话": "普通话",
"广东话": "粤语",
"粤语": "粤语",
"英文": "英语",
"英语": "英语",
"韩文": "韩语",
"朝鲜语": "韩语",
"韩语": "韩语",
"闽南话": "闽南语",
"台语": "闽南语",
"闽南语": "闽南语",
"蒙语": "蒙语",
"蒙古语": "蒙语",
"俄文": "俄语",
"俄语": "俄语",
"藏文": "藏语",
"藏语": "藏语",
"其它": "其他",
"其他": "其他",
"地方语言": "其他",
"日语": "其他",
}
VALID_EMOTIONS: Set[str] = {
"喜庆",
"浪漫",
"雄壮",
"庄重",
"激情",
"快乐",
"励志",
"期待",
"甜蜜",
"感动",
"搞笑",
"祝福",
"温暖",
"宣泄",
"悲壮",
"愤怒",
"沉重",
"思念",
"紧张",
"恐怖",
"孤独",
"伤感",
"忧郁",
"蛊惑",
"恶搞",
"怀念",
"悬疑",
"佛系",
"舒缓",
"悠扬",
}
VALID_SCENES: Set[str] = {
"餐厅",
"汽车",
"跳舞",
"旅行",
"工作",
"校园",
"夜店",
"运动",
"休闲",
"live house",
"广场舞",
"抖音",
"婚礼",
"约会",
}
VALID_DOUYIN_TAGS: Set[str] = {
"草原",
"故乡",
"神曲",
"文艺",
"青春",
"治愈系",
"清新",
"奇幻",
}
VALID_MUSIC_STYLE_TAGS: Set[str] = {
"世界音乐",
"雷鬼",
"R&B/Soul",
"MC喊麦",
"另类音乐",
"民歌",
"戏曲",
"古风",
"古典音乐",
"HipHop",
"Rap",
"摇滚",
"DJ嗨曲",
"布鲁斯/蓝调",
"拉丁",
"舞曲",
"爵士",
"乡村",
"民谣",
"流行",
"轻音乐",
"国风",
"儿歌",
}
VALID_INSTRUMENT_TAGS: Set[str] = {
"二胡",
"竹笛",
"琵琶",
"音效",
"口琴",
"电子",
"木吉他",
"鼓组",
"弦乐",
"电吉他",
"古筝",
"钢琴",
}
VALID_AGES: Set[str] = {"少年", "青年", "中年", "老年", "全年龄段"}
VALID_RHYTHM_INTENSITIES: Set[str] = {"极慢", "慢", "中", "快", "极速"}
VALID_EMOTIONAL_INTENSITIES: Set[str] = {"平缓", "中等", "强烈"}
VALID_VOICE_TYPES: Set[str] = {"男声", "女声", "童声", "合唱", "无人声"}
VALID_PERFORMER_TYPES: Set[str] = {"男声", "女声", "童声", "合唱"}
# Common sub_genre variant mapping
SUB_GENRE_MAPPING: Dict[str, str] = {
"韩语流行": "日韩流行",
"韩国流行": "日韩流行",
"K-Pop": "日韩流行",
"K-pop": "日韩流行",
"Kpop": "日韩流行",
"韩流": "日韩流行",
"日语流行": "日韩流行",
"日本流行": "日韩流行",
"J-Pop": "日韩流行",
"J-pop": "日韩流行",
"Jpop": "日韩流行",
"中文流行": "华语流行",
"国语流行": "华语流行",
"中国流行": "华语流行",
"英语流行": "欧美流行",
"英文流行": "欧美流行",
"西方流行": "欧美流行",
"Pop": "欧美流行",
}
class AudioAnalyzer(ABC):
"""音乐音频分析器抽象基类"""
@abstractmethod
def get_provider_name(self) -> str:
"""获取提供商名称(如 qwen, doubao)"""
pass
@abstractmethod
def get_model_name(self) -> str:
"""获取模型名称"""
pass
@abstractmethod
def analyze(
self,
metadata: Dict[str, Any],
music_url: str,
extract_lyrics: bool = False,
label_level: int = 0,
) -> Optional[Dict[str, Any]]:
"""
分析音乐并返回标签结果
Args:
metadata: 音乐元数据字典
music_url: 音乐文件 URL(支持音频 URL 或 Base64 编码)
extract_lyrics: 是否识别歌词
label_level: 标签级别(0: 一级标签, 1: 一级+二级标签)
Returns:
标准化分析结果字典,包含以下字段:
- genre: 音乐风格(一级风格,如:流行、摇滚)
- emotion: 情绪列表
- emotional_intensity: 情绪强度
- vocal_texture: 人声质感
- vocal_description: 人声质感描述
- visual_concept: 视觉概念
- language: 语种
- bpm: 节拍数(可选)
- lyrics: 歌词列表(可选,仅当 extract_lyrics=True 时)
- _model: 使用的模型名称
- _token_info: Token 使用信息
"""
pass
def _parse_response(self, response_text: str) -> Optional[Dict[str, Any]]:
"""
解析 LLM 返回的响应文本为 JSON
Args:
response_text: LLM 返回的原始文本
Returns:
解析后的字典,解析失败返回 None
"""
import re
import json
import logging
logger = logging.getLogger(__name__)
if not response_text:
return None
        # Log the raw response for debugging
logger.info(f"[_parse_response] 原始响应文本:\n{response_text[:500]}...")
cleaned_text = response_text.strip()
        # Strip markdown code-fence markers
if cleaned_text.startswith("```json"):
cleaned_text = cleaned_text[7:]
elif cleaned_text.startswith("```"):
cleaned_text = cleaned_text[3:]
if cleaned_text.endswith("```"):
cleaned_text = cleaned_text[:-3]
cleaned_text = cleaned_text.strip()
        # Extract the JSON object
        try:
            # Try parsing directly
result = json.loads(cleaned_text)
if isinstance(result, dict):
logger.info(f"[_parse_response] 解析成功,字段: {list(result.keys())}")
elif isinstance(result, list):
logger.info(f"[_parse_response] 解析成功,列表长度: {len(result)}")
else:
logger.info(
f"[_parse_response] 解析成功,类型: {type(result).__name__}"
)
return result
except json.JSONDecodeError:
pass
        # Try extracting the content inside {...}
try:
match = re.search(r"\{.*\}", cleaned_text, re.DOTALL)
if match:
json_str = match.group()
result = json.loads(json_str)
if isinstance(result, dict):
logger.info(
f"[_parse_response] 正则提取解析成功,字段: {list(result.keys())}"
)
elif isinstance(result, list):
logger.info(
f"[_parse_response] 正则提取解析成功,列表长度: {len(result)}"
)
else:
logger.info(
"[_parse_response] 正则提取解析成功,类型: %s",
type(result).__name__,
)
return result
except (re.error, json.JSONDecodeError):
pass
        # Try fixing common JSON formatting problems (trailing commas)
try:
fixed_text = re.sub(r",(\s*})", r"\1", cleaned_text)
fixed_text = re.sub(r",(\s*])", r"\1", fixed_text)
result = json.loads(fixed_text)
if isinstance(result, dict):
logger.info(
f"[_parse_response] 修复后解析成功,字段: {list(result.keys())}"
)
elif isinstance(result, list):
logger.info(
f"[_parse_response] 修复后解析成功,列表长度: {len(result)}"
)
else:
logger.info(
"[_parse_response] 修复后解析成功,类型: %s",
type(result).__name__,
)
return result
except (re.error, json.JSONDecodeError):
pass
logger.warning(f"[_parse_response] 所有解析方法都失败")
return None
def _normalize_result(
self,
raw_result: Dict[str, Any],
model_name: str,
token_info: Optional[Dict[str, int]] = None,
) -> Dict[str, Any]:
"""
标准化分析结果
Args:
raw_result: 原始解析结果
model_name: 使用的模型名称
token_info: Token 使用信息
Returns:
标准化后的结果字典
"""
import logging
logger = logging.getLogger(__name__)
if not isinstance(raw_result, dict):
if (
isinstance(raw_result, list)
and raw_result
and isinstance(raw_result[0], dict)
):
raw_result = raw_result[0]
else:
logger.warning(
f"[_normalize_result] 原始结果类型异常: {type(raw_result).__name__}"
)
return {"_model": model_name, "_raw": raw_result}
logger.info(f"[_normalize_result] 原始结果字段: {list(raw_result.keys())}")
logger.info(f"[_normalize_result] genre: {raw_result.get('genre')}")
logger.info(f"[_normalize_result] emotion: {raw_result.get('emotion')}")
logger.info(f"[_normalize_result] scene: {raw_result.get('scene')}")
logger.info(f"[_normalize_result] token_info 参数: {token_info}")
def _extract_style(raw_style) -> Optional[Dict[str, str]]:
"""提取音乐风格为标准格式"""
if isinstance(raw_style, dict):
return {"zh": raw_style.get("zh", ""), "en": raw_style.get("en", "")}
elif isinstance(raw_style, str):
                # Plain string: use it as the Chinese name, leave the English name empty
return {"zh": raw_style, "en": ""}
return None
def _extract_list_field(raw_value) -> list:
"""提取列表字段"""
if isinstance(raw_value, list):
return [v for v in raw_value if v]
elif isinstance(raw_value, str):
import re
return [
v.strip()
for v in re.split(r"[,,、/|]+", raw_value)
if v and v.strip()
]
return []
def _extract_single_field(raw_value) -> str:
"""提取单值字段"""
if raw_value and isinstance(raw_value, str):
return raw_value
return ""
def _validate_and_map_sub_genre(sub_genre: str, genre: str) -> str:
"""验证并映射 sub_genre 到有效值"""
if not sub_genre:
return ""
sub_genre = sub_genre.strip()
if sub_genre in SUB_GENRE_MAPPING:
mapped = SUB_GENRE_MAPPING[sub_genre]
logger.info(
f"[_validate_and_map_sub_genre] 映射 '{sub_genre}' -> '{mapped}'"
)
return mapped
if genre in VALID_SUB_GENRES:
if sub_genre in VALID_SUB_GENRES[genre]:
return sub_genre
for valid_subs in VALID_SUB_GENRES.values():
if sub_genre in valid_subs:
return sub_genre
logger.warning(
f"[_validate_and_map_sub_genre] 无法映射 sub_genre: '{sub_genre}' (genre: '{genre}')"
)
return sub_genre
def _validate_list_field(
values: List[str], valid_set: Set[str], field_name: str
) -> List[str]:
"""严格验证列表字段中的值:仅保留字典内标签"""
result = []
for v in values:
if v in valid_set:
result.append(v)
else:
logger.warning(
f"[_validate_list_field] {field_name} 值 '{v}' 不在字典中,已过滤"
)
return result
def _validate_language(raw_value: Any) -> str:
language = _extract_single_field(raw_value).strip()
if not language:
return ""
mapped = LANGUAGE_MAPPING.get(language, language)
if mapped in VALID_LANGUAGES:
return mapped
logger.warning(
f"[_normalize_result] language '{language}' 不在字典中,已归并为空"
)
return ""
result = {
"genre": "",
"sub_genre": "",
"emotion": [],
"voice_type": "",
"vocal_texture": "",
"vocal_description": "",
"visual_concept": "",
"language": "",
"scene": [],
"age": "",
"is_sinking": None,
"song_description": "",
"performer_type": "",
"music_style_tags": [],
"douyin_tags": [],
"instrument_tags": [],
}
        # Music style (level-1 and level-2 styles)
        # Prefer the new genre/sub_genre format; stay compatible with the old music_style format
raw_genre = raw_result.get("genre", "")
raw_sub_genre = raw_result.get("sub_genre", "")
raw_music_style = raw_result.get("music_style", [])
        # Prefer the genre field for the level-1 style
if isinstance(raw_genre, str) and raw_genre.strip():
result["genre"] = raw_genre.strip()
elif isinstance(raw_genre, dict):
result["genre"] = raw_genre.get("zh", "") or raw_genre.get("en", "")
        # Old-format compatibility: extract from the music_style array
elif (
raw_music_style
and isinstance(raw_music_style, list)
and len(raw_music_style) > 0
):
first_style = raw_music_style[0]
if isinstance(first_style, dict):
result["genre"] = first_style.get("zh", "") or first_style.get("en", "")
elif isinstance(first_style, str):
result["genre"] = first_style.strip()
        # Prefer the sub_genre field for the level-2 style
if isinstance(raw_sub_genre, str) and raw_sub_genre.strip():
result["sub_genre"] = raw_sub_genre.strip()
elif isinstance(raw_sub_genre, dict):
result["sub_genre"] = raw_sub_genre.get("zh", "") or raw_sub_genre.get(
"en", ""
)
        # Old-format compatibility: take the second element of the music_style array
elif (
raw_music_style
and isinstance(raw_music_style, list)
and len(raw_music_style) > 1
):
second_style = raw_music_style[1]
if isinstance(second_style, dict):
result["sub_genre"] = second_style.get("zh", "") or second_style.get(
"en", ""
)
elif isinstance(second_style, str):
result["sub_genre"] = second_style.strip()
result["sub_genre"] = _validate_and_map_sub_genre(
result["sub_genre"], result["genre"]
)
        # Emotion
raw_emotion = raw_result.get("emotion", [])
if isinstance(raw_emotion, str):
raw_emotion = [raw_emotion]
result["emotion"] = _validate_list_field(
_extract_list_field(raw_emotion), VALID_EMOTIONS, "emotion"
)
        # Voice type
raw_voice_type = raw_result.get("voice_type", "")
if raw_voice_type and isinstance(raw_voice_type, str):
voice_type = raw_voice_type.strip()
if voice_type in VALID_VOICE_TYPES:
result["voice_type"] = voice_type
else:
logger.warning(
f"[_normalize_result] voice_type '{voice_type}' 不在有效值中,保留原值"
)
result["voice_type"] = voice_type
else:
result["voice_type"] = ""
        # Vocal texture (the LLM returns vocal_type)
result["vocal_texture"] = _extract_single_field(
raw_result.get("vocal_type", "")
)
        # Vocal texture description
result["vocal_description"] = raw_result.get("vocal_description", "")
        # 聚音 performer type (prefer performer_type, fall back to vocal_type)
raw_performer_type = raw_result.get("performer_type", raw_result.get("vocal_type", ""))
if isinstance(raw_performer_type, str):
performer_type = raw_performer_type.strip()
if performer_type in VALID_PERFORMER_TYPES:
result["performer_type"] = performer_type
elif performer_type in VALID_VOICE_TYPES:
result["performer_type"] = performer_type
        # 聚音 tags: music style / web-Douyin / instrumentation
result["music_style_tags"] = _extract_list_field(
raw_result.get("music_style_tags", raw_result.get("music_style", []))
)
result["douyin_tags"] = _extract_list_field(
raw_result.get("douyin_tags", raw_result.get("network_douyin_tags", []))
)
result["instrument_tags"] = _extract_list_field(
raw_result.get("instrument_tags", raw_result.get("instruments", []))
)
result["music_style_tags"] = _validate_list_field(
result["music_style_tags"], VALID_MUSIC_STYLE_TAGS, "music_style_tags"
)
result["douyin_tags"] = _validate_list_field(
result["douyin_tags"], VALID_DOUYIN_TAGS, "douyin_tags"
)
result["instrument_tags"] = _validate_list_field(
result["instrument_tags"], VALID_INSTRUMENT_TAGS, "instrument_tags"
)
        # Visual concept
result["visual_concept"] = raw_result.get("visual_concept", "")
        # Language
result["language"] = _validate_language(raw_result.get("language", ""))
        # Scene (multi-select)
raw_scene = raw_result.get("scene", [])
if isinstance(raw_scene, str):
raw_scene = [raw_scene]
if isinstance(raw_scene, list):
scene_list = [s.strip() for s in raw_scene if s and isinstance(s, str)]
result["scene"] = _validate_list_field(scene_list, VALID_SCENES, "scene")
        # Target listener age group
raw_age = raw_result.get("age", "")
if raw_age and isinstance(raw_age, str):
result["age"] = raw_age.strip()
        # Whether the song targets the "sinking" (下沉) market
raw_is_sinking = raw_result.get("is_sinking")
if isinstance(raw_is_sinking, bool):
result["is_sinking"] = raw_is_sinking
elif isinstance(raw_is_sinking, str):
is_sinking_lower = raw_is_sinking.strip().lower()
if is_sinking_lower in ("是", "true", "1", "yes"):
result["is_sinking"] = True
elif is_sinking_lower in ("否", "false", "0", "no"):
result["is_sinking"] = False
        # Song description
raw_song_desc = raw_result.get("song_description", "")
if raw_song_desc and isinstance(raw_song_desc, str):
result["song_description"] = raw_song_desc.strip()
        # Emotional intensity
raw_emotional_intensity = raw_result.get("emotional_intensity", "")
if raw_emotional_intensity and isinstance(raw_emotional_intensity, str):
result["emotional_intensity"] = raw_emotional_intensity.strip()
        # Rhythm intensity
raw_rhythm_intensity = raw_result.get("rhythm_intensity", "")
if raw_rhythm_intensity and isinstance(raw_rhythm_intensity, str):
result["rhythm_intensity"] = raw_rhythm_intensity.strip()
        # BPM is not taken from the LLM result; it is provided uniformly by the local bpm_analyzer_tools
        # Lyrics (optional)
if "lyrics" in raw_result:
result["lyrics"] = raw_result["lyrics"]
        # Attach model info
result["_model"] = model_name
if token_info:
result["_token_info"] = token_info
if "_token_info_parts" in raw_result and isinstance(
raw_result["_token_info_parts"], dict
):
result["_token_info_parts"] = raw_result["_token_info_parts"]
if "_timing" in raw_result and isinstance(raw_result["_timing"], dict):
result["_timing"] = raw_result["_timing"]
return result
#!/usr/bin/env python3
"""
Realtime BPM Analyzer - Python 测试程序
基于 realtime-bpm-analyzer (https://github.com/dlepaux/realtime-bpm-analyzer)
的 Python 实现,用于快速测试音频文件的 BPM。
功能:
1. 快速 BPM 识别
2. 实时特征提取
3. 多算法融合
4. 详细结果导出
使用方法:
python bpm_analyzer_test.py --file music.mp3
python bpm_analyzer_test.py --file music.mp3 --output result.json
python bpm_analyzer_test.py --file music.mp3 --verbose
python bpm_analyzer_test.py --dir /path/to/music/folder
"""
import os
import sys
import json
import logging
import argparse
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime
import numpy as np
# Import audio processing libraries
try:
import librosa
import librosa.beat
import librosa.feature
import librosa.onset
except ImportError:
print("❌ librosa 库未安装,请运行: pip install librosa")
sys.exit(1)
from scipy.signal import find_peaks, correlate
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class RealtimeBPMAnalyzerTest:
"""Realtime BPM Analyzer - Python 版本"""
    # BPM range (following realtime-bpm-analyzer)
BPM_MIN = 30.0
BPM_MAX = 200.0
    # Confidence threshold
CONFIDENCE_THRESHOLD = 0.5
def __init__(self, verbose: bool = False):
"""
初始化分析器
Args:
verbose: 是否显示详细信息
"""
self.verbose = verbose
        self.sr = 22050  # sample rate
self.hop_length = 512
if verbose:
logger.setLevel(logging.DEBUG)
logger.info("✓ Realtime BPM Analyzer Test 已初始化")
def print_header(self, title: str, width: int = 80):
"""打印标题"""
print("\n" + "=" * width)
print(f" {title}")
print("=" * width)
def analyze_file(self, file_path: str) -> Dict[str, Any]:
"""
分析单个音频文件
Args:
file_path: 音频文件路径
Returns:
分析结果字典
"""
self.print_header("🎵 Realtime BPM Analyzer - 测试程序")
        # Validate the file
if not os.path.exists(file_path):
logger.error(f"❌ 文件不存在: {file_path}")
return {'success': False, 'error': '文件不存在'}
file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
logger.info(f"📄 音频文件: {Path(file_path).name}")
logger.info(f"📊 文件大小: {file_size_mb:.2f} MB")
logger.info(f"📁 文件路径: {Path(file_path).absolute()}")
self.print_header("📊 分析过程", 80)
try:
            # Load the audio
logger.info("🔄 加载音频文件...")
y, sr = librosa.load(file_path, sr=self.sr, mono=True)
duration = len(y) / sr
logger.info(f"✓ 音频加载成功,时长: {duration:.2f} 秒")
            # Run the fast analysis
logger.info("📈 快速 BPM 检测...")
fast_result = self._fast_bpm_detection(y, sr)
            # Run the detailed analysis
logger.info("📊 详细 BPM 分析...")
detailed_result = self._detailed_bpm_analysis(y, sr)
            # Fuse the results
logger.info("🔀 融合分析结果...")
final_result = self._fuse_results(fast_result, detailed_result, y=y)
result = {
'success': True,
'file_path': str(Path(file_path).absolute()),
'file_name': Path(file_path).name,
'file_size_mb': round(file_size_mb, 2),
'duration_seconds': round(duration, 2),
'sample_rate': sr,
'timestamp': datetime.now().isoformat(),
'fast_detection': fast_result,
'detailed_analysis': detailed_result,
'final_result': final_result
}
self.print_header("📈 分析结果", 80)
self._display_results(result)
return result
except Exception as e:
logger.error(f"❌ 分析失败: {str(e)}")
if self.verbose:
import traceback
traceback.print_exc()
return {'success': False, 'error': str(e)}
def analyze_directory(self, dir_path: str) -> List[Dict[str, Any]]:
"""
分析文件夹中的所有音频文件
Args:
dir_path: 文件夹路径
Returns:
分析结果列表
"""
self.print_header("🎵 Realtime BPM Analyzer - 批量分析", 80)
if not os.path.isdir(dir_path):
logger.error(f"❌ 文件夹不存在: {dir_path}")
return []
        # Find all audio files
audio_extensions = ('.mp3', '.wav', '.flac', '.m4a', '.aac', '.ogg')
audio_files = []
for root, dirs, files in os.walk(dir_path):
for file in files:
if file.lower().endswith(audio_extensions):
audio_files.append(os.path.join(root, file))
logger.info(f"📂 找到 {len(audio_files)} 个音频文件")
results = []
for i, file_path in enumerate(audio_files, 1):
logger.info(f"\n[{i}/{len(audio_files)}] 正在分析...")
result = self.analyze_file(file_path)
results.append(result)
return results
def analyze_bpm(
self,
        file_path: Optional[str] = None,
        y: Optional[np.ndarray] = None,
        sr: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Unified BPM analysis entry point (for use by other modules).
        Two calling styles are supported:
        1. pass file_path; the audio is loaded internally at sr=22050
        2. pass pre-loaded y and sr (avoids loading the audio twice)
        Returns:
            {
                'bpm': float,  # final BPM (after fusion + correction)
                'original_bpm': float,  # raw BPM from fast detection
                'confidence': float,
                'beat_times': list,  # list of beat time points
            }
        """
try:
if y is None and file_path is not None:
if not os.path.exists(file_path):
return {'bpm': 120.0, 'original_bpm': 120.0,
'confidence': 0.0, 'beat_times': []}
y, sr = librosa.load(file_path, sr=self.sr, mono=True)
elif y is None:
return {'bpm': 120.0, 'original_bpm': 120.0,
'confidence': 0.0, 'beat_times': []}
            # Fast detection
fast_result = self._fast_bpm_detection(y, sr)
            # Detailed analysis
detailed_result = self._detailed_bpm_analysis(y, sr)
            # Fusion
final_result = self._fuse_results(fast_result, detailed_result, y=y)
final_bpm = final_result.get('bpm', 120.0)
original_bpm = fast_result.get('original_bpm', final_bpm)
            # Get beat_times from the same beat_track call used inside _fast_bpm_detection
_, beat_frames = librosa.beat.beat_track(
y=y, sr=sr, hop_length=self.hop_length
)
if isinstance(beat_frames, np.ndarray) and beat_frames.size > 0:
beat_times = librosa.frames_to_time(
beat_frames, sr=sr, hop_length=self.hop_length
).tolist()
else:
beat_times = []
            # If the BPM was halved, keep every other beat time as well
if final_bpm < original_bpm * 0.75:
beat_times = beat_times[::2]
return {
'bpm': final_bpm,
'original_bpm': original_bpm,
'confidence': final_result.get('confidence', 0.0),
'beat_times': beat_times,
}
except Exception as e:
logger.warning(f"analyze_bpm 失败: {e}")
return {'bpm': 120.0, 'original_bpm': 120.0,
'confidence': 0.0, 'beat_times': []}
def _fast_bpm_detection(self, y: np.ndarray, sr: int) -> Dict[str, Any]:
"""快速 BPM 检测(参考 librosa.beat.tempo)+ 智能节拍层级纠正"""
try:
            # Get the BPM and beat times
tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr, hop_length=self.hop_length)
            # tempo may be an ndarray
if isinstance(tempo, np.ndarray):
bpm = float(tempo[0]) if tempo.size > 0 else 120.0
else:
bpm = float(tempo)
            # beat_frames may be an ndarray
if isinstance(beat_frames, np.ndarray) and beat_frames.size > 0:
beat_times = librosa.frames_to_time(beat_frames, sr=sr, hop_length=self.hop_length)
beat_times = beat_times.tolist() if isinstance(beat_times, np.ndarray) else list(beat_times)
else:
beat_times = []
            # Smart beat-level detection and correction (audio data passed in for onset analysis)
corrected_bpm, correction_reason = self._detect_beat_level_errors(beat_times, bpm, y)
return {
'bpm': round(corrected_bpm, 1),
'original_bpm': round(bpm, 1),
'confidence': 0.85,
'method': 'librosa.beat.tempo()',
'beat_count': len(beat_times),
'beat_level_correction': correction_reason if correction_reason != 'beat_level_ok' else None,
'duration_ms': 100
}
except Exception as e:
logger.warning(f"⚠️ 快速检测失败: {str(e)}")
return {
'bpm': 0,
'confidence': 0,
'method': 'librosa.beat.tempo()',
'error': str(e)
}
def _detect_beat_level_errors(self, beat_times: list, bpm: float, y: np.ndarray = None) -> Tuple[float, str]:
"""
检测和纠正beat level错误(如检测到8th-note而非quarter-note)
改进版:组合多个特征来判断
1. 交替强度模式 (ratio)
2. 原始BPM范围 (100-150范围内更可能需要减半)
3. 谱质心分析 (慢歌通常谱质心较低)
4. Onset对齐分数比较
"""
if not beat_times or len(beat_times) < 2:
return bpm, "insufficient_beats"
beat_intervals = np.diff(beat_times)
mean_interval = np.mean(beat_intervals)
std_interval = np.std(beat_intervals)
coeff_variation = std_interval / mean_interval if mean_interval > 0 else 1.0
beat_count = len(beat_times)
if self.verbose:
logger.debug(f"Beat level analysis: {beat_count} beats, CV={coeff_variation:.3f}, BPM={bpm:.1f}")
        # Condition 1: very regular intervals + BPM > 100 + beat count > 20 (threshold lowered to support short clips)
if not (coeff_variation < 0.15 and bpm > 100 and beat_count > 20):
return bpm, "beat_level_ok"
# 如果没有音频数据,使用保守策略
if y is None:
return bpm, "beat_level_ok"
halved_bpm = bpm / 2
if not (40 < halved_bpm < 160):
return bpm, "beat_level_ok"
        # Compute onset strength
onset_env = librosa.onset.onset_strength(y=y, sr=self.sr, hop_length=self.hop_length)
        # Get the onset strength at each beat position
beat_frames = librosa.time_to_frames(beat_times, sr=self.sr, hop_length=self.hop_length)
beat_strengths = []
window = 3
for frame in beat_frames:
if frame < len(onset_env):
start = max(0, frame - window)
end = min(len(onset_env), frame + window + 1)
beat_strengths.append(np.max(onset_env[start:end]))
if len(beat_strengths) < 10:
return bpm, "beat_level_ok"
beat_strengths = np.array(beat_strengths)
        # Detect an alternating strength pattern
odd_beats = beat_strengths[::2]
even_beats = beat_strengths[1::2]
mean_odd = np.mean(odd_beats)
mean_even = np.mean(even_beats)
strength_ratio = mean_odd / mean_even if mean_even > 0 else 1.0
        # Compute the spectral centroid - distinguishes fast vs slow songs
spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=self.sr, hop_length=self.hop_length)
mean_centroid = np.mean(spectral_centroid)
if self.verbose:
logger.debug(f"Beat strength ratio={strength_ratio:.3f}, spectral_centroid={mean_centroid:.1f}")
        # Combined decision logic
should_halve = False
reason = ""
        # Rule 1: very obvious alternating pattern (ratio > 1.8 or < 0.55)
if strength_ratio > 1.8 or strength_ratio < 0.55:
should_halve = True
reason = f"strong_alternating_pattern (ratio={strength_ratio:.2f})"
        # Rule 1b: BPM > 150 + moderate alternating pattern -> halve
        # e.g. "春娇与志明" (172.3 BPM, ratio=1.406, ref=85)
        # Home - Headhunterz (152 BPM, ratio=1.098) does not trigger this
elif bpm > 150 and (strength_ratio > 1.25 or strength_ratio < 0.8):
should_halve = True
reason = f"very_high_bpm_with_alternating (bpm={bpm:.1f}, ratio={strength_ratio:.2f})"
        # Rule 2: BPM in 125-150 + strong alternating pattern (ratio > 1.25)
        # High onset density (>=3.0/s) + high centroid (>=2200) means a genuinely fast song; do not halve
        # e.g. "爱在西元前" (129.2 BPM, centroid=2527, onset_density=3.8, ratio=1.29)
        # Otherwise correct with bpm*2/3 (for songs with a 3:2 rhythmic relationship)
        # e.g. "该死的爱情" (129.2 BPM, ratio=1.668, centroid=1986, ref=84) -> 2/3=86.1
        # e.g. "你要的全拿走" (136.0 BPM, ratio=1.485, centroid=2678, ref=76) -> 2/3=90.7
elif 125 <= bpm <= 150 and strength_ratio > 1.25:
onset_frames = librosa.onset.onset_detect(onset_envelope=onset_env, sr=self.sr, hop_length=self.hop_length)
duration = len(y) / self.sr
onset_density = len(onset_frames) / duration if duration > 0 else 0
if onset_density >= 3.0 and mean_centroid >= 2200:
if self.verbose:
logger.debug(f"规则2跳过: 高onset密度({onset_density:.1f}/s) + 高谱质心({mean_centroid:.0f}),判定为快歌")
else:
                # Choose the correction strategy by spectral centroid:
                # low centroid (<2200): dull-timbre slow song, librosa locked onto 3/2x; correct with *2/3
                #   e.g. "该死的爱情" (129.2, centroid=1986, ratio=1.67) -> 86.1 (ref=84)
                # high centroid (>=2200) + low onset density (<3.0): brightly produced slow song, librosa locked onto 2x; correct with /2
                #   e.g. "你要的全拿走" (136.0, centroid=2678, density=2.17, ratio=1.49) -> 68.0 (ref=76)
if mean_centroid >= 2200:
                    # Bright but rhythmically sparse -> simply halve
should_halve = True
reason = f"rule2_bright_slow (bpm={bpm:.1f}, ratio={strength_ratio:.2f}, centroid={mean_centroid:.0f}, density={onset_density:.1f})"
else:
                    # Dull timbre -> correct with 2/3
two_thirds_bpm = round(bpm * 2 / 3, 1)
should_halve = False
logger.info(
f"🔧 节拍层级纠正(2/3): {bpm:.1f} BPM → {two_thirds_bpm:.1f} BPM "
f"(ratio={strength_ratio:.2f}, centroid={mean_centroid:.0f})"
)
return two_thirds_bpm, f"rule2_two_thirds (bpm={bpm:.1f}, result={two_thirds_bpm:.1f}, ratio={strength_ratio:.2f})"
elif 125 <= bpm <= 150 and strength_ratio < 0.8 and mean_centroid < 2200:
should_halve = True
reason = f"mid_bpm_low_ratio_low_centroid (bpm={bpm:.1f}, ratio={strength_ratio:.2f}, centroid={mean_centroid:.0f})"
        # Rule 2b: BPM > 130 + low spectral centroid (< 1800) indicates slow-song traits despite a high detected BPM
        # Catches songs like "嚣张": BPM=136 but centroid=1653
elif bpm > 130 and mean_centroid < 1800:
should_halve = True
reason = f"high_bpm_low_centroid (bpm={bpm:.1f}, centroid={mean_centroid:.0f})"
        # Rule 3: BPM in 115-125 needs stricter conditions
elif 115 <= bpm < 125:
            # Rule 3a: very strong alternating pattern (ratio > 1.5) -> halve regardless of centroid
            # This catches songs like "想你的夜" with strong alternation but a higher centroid
if strength_ratio > 1.5 or strength_ratio < 0.65:
should_halve = True
reason = f"strong_alternating_in_mid_bpm (bpm={bpm:.1f}, ratio={strength_ratio:.2f})"
            # Rule 3b: moderate alternating pattern + low centroid (slow-song traits)
elif mean_centroid < 2000 and (strength_ratio > 1.4 or strength_ratio < 0.7):
should_halve = True
reason = f"slow_song_detected (centroid={mean_centroid:.0f}, ratio={strength_ratio:.2f})"
            # Otherwise keep as-is (may be a genuinely mid-tempo song, e.g. 有什么奇怪, 中巴车)
        # Rule 3c: BPM in 100-115 (a slow song may be detected at 2x, e.g. 嘉禾望岗 56 BPM -> 112 BPM)
        # Decide using onset alignment
elif 100 <= bpm < 115:
score_detected = self._compute_onset_alignment_score(onset_env, bpm)
score_halved = self._compute_onset_alignment_score(onset_env, halved_bpm)
if score_detected > 0 and score_halved > 0:
alignment_ratio = score_halved / score_detected
if self.verbose:
logger.debug(f"Onset alignment (100-115 BPM): detected={score_detected:.3f}, halved={score_halved:.3f}, ratio={alignment_ratio:.3f}")
                # If the halved BPM aligns better (ratio > 1.0), the true BPM is half
                # Also check the alternating pattern as a secondary signal
if alignment_ratio > 1.0 and (strength_ratio > 1.2 or strength_ratio < 0.83):
should_halve = True
reason = f"slow_song_100_115_range (alignment_ratio={alignment_ratio:.3f}, strength_ratio={strength_ratio:.2f})"
                # Even without an obvious alternating pattern, halve if the alignment is clearly better
elif alignment_ratio > 1.08:
should_halve = True
reason = f"onset_alignment_strongly_favors_half (ratio={alignment_ratio:.3f})"
        # Rule 4: compare onset alignment for BPM vs BPM/2 (only for high BPM > 130)
        # If BPM/2 aligns clearly better, a half-beat was detected
        # Restricted to BPM > 130 to avoid hurting mid-tempo songs like "中巴车" (117.5 BPM)
if not should_halve and bpm > 130:
score_detected = self._compute_onset_alignment_score(onset_env, bpm)
score_halved = self._compute_onset_alignment_score(onset_env, halved_bpm)
if score_detected > 0 and score_halved > 0:
alignment_ratio = score_halved / score_detected
if self.verbose:
logger.debug(f"Onset alignment: detected={score_detected:.3f}, halved={score_halved:.3f}, ratio={alignment_ratio:.3f}")
                # A high centroid (>=2000) suggests fast/electronic music; require a higher alignment ratio before halving
                # Avoids hurting e.g. "Home - Headhunterz" (152 BPM, centroid=2290, ratio=1.102)
ratio_threshold = 1.15 if mean_centroid >= 2000 else 1.04
if alignment_ratio > ratio_threshold and 40 < halved_bpm < 160:
should_halve = True
reason = f"onset_alignment_favors_half (ratio={alignment_ratio:.3f})"
if should_halve:
logger.info(f"🔧 节拍层级纠正: {bpm:.1f} BPM → {halved_bpm:.1f} BPM ({reason})")
return halved_bpm, reason
return bpm, "beat_level_ok"
def _compute_onset_alignment_score(self, onset_env: np.ndarray, bpm: float) -> float:
"""
计算给定BPM与onset strength的对齐度分数
原理:真实的节拍应该对应onset strength的峰值
分数越高表示对齐度越好
"""
frame_rate = self.sr / self.hop_length
beat_interval_frames = int((60.0 / bpm) * frame_rate)
if beat_interval_frames < 1 or beat_interval_frames > len(onset_env):
return 0.0
# 在每个节拍位置采样onset strength
beat_strengths = []
off_beat_strengths = []
for i in range(0, len(onset_env) - beat_interval_frames, beat_interval_frames):
# 节拍位置(在一个小窗口内找最大值)
window_size = max(1, beat_interval_frames // 8)
start = max(0, i - window_size)
end = min(len(onset_env), i + window_size)
beat_strengths.append(np.max(onset_env[start:end]))
# 非节拍位置(节拍之间的中点)
mid_point = i + beat_interval_frames // 2
if mid_point < len(onset_env):
start_off = max(0, mid_point - window_size)
end_off = min(len(onset_env), mid_point + window_size)
off_beat_strengths.append(np.max(onset_env[start_off:end_off]))
if not beat_strengths or not off_beat_strengths:
return 0.0
# 分数 = 节拍位置平均强度 / 非节拍位置平均强度
# 比值越高,说明节拍位置的onset越明显
mean_beat = np.mean(beat_strengths)
mean_off_beat = np.mean(off_beat_strengths)
if mean_off_beat < 1e-6:
return mean_beat
score = mean_beat / mean_off_beat
return float(score)
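    # 数值示例(假设 sr=22050、hop_length=512,为 librosa 常见默认值,实际以实例属性为准):
    #   frame_rate = 22050 / 512 ≈ 43.07 帧/秒
    #   bpm=120 → beat_interval_frames = int((60/120) × 43.07) = 21,window_size = max(1, 21//8) = 2
    #   即在每个节拍位置 ±2 帧窗口内取 onset 峰值,在节拍中点(偏移 10 帧)取非节拍峰值,
    #   两组均值之比即对齐分数;比值越大,说明该 BPM 的节拍网格与 onset 峰越吻合。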
def _detailed_bpm_analysis(self, y: np.ndarray, sr: int) -> Dict[str, Any]:
"""详细 BPM 分析"""
try:
# 计算 onset strength
onset_env = librosa.onset.onset_strength(
y=y, sr=sr, hop_length=self.hop_length
)
# 计算 tempogram
tempogram = librosa.feature.tempogram(
y=y, sr=sr, hop_length=self.hop_length
)
# 计算自相关
tempogram_flat = tempogram.flatten()
acf = correlate(tempogram_flat, tempogram_flat, mode='full')
acf = acf[len(acf)//2:]
acf = acf / (acf[0] + 1e-8)
# 找峰值
peaks, properties = find_peaks(acf[1:], height=0.2, distance=5)
peaks = peaks + 1
if len(peaks) > 0:
frame_rate = sr / self.hop_length
best_peak_idx = peaks[np.argmax(acf[peaks])]
bpm = 60.0 * frame_rate / best_peak_idx
confidence = float(np.max(acf[peaks]))
else:
bpm = 120.0
confidence = 0.3
# 确保在合理范围内
bpm = np.clip(bpm, self.BPM_MIN, self.BPM_MAX)
return {
'bpm': round(bpm, 1),
'confidence': round(float(np.clip(confidence, 0, 1)), 2),
'method': 'Tempogram Autocorrelation',
'peaks_count': int(len(peaks))
}
except Exception as e:
logger.warning(f"⚠️ 详细分析失败: {str(e)}")
return {
'bpm': 0,
'confidence': 0,
'method': 'Tempogram Autocorrelation',
'error': str(e)
}
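    # 峰值滞后 → BPM 的换算示例(假设 sr=22050、hop_length=512):
    #   frame_rate = 22050 / 512 ≈ 43.07,若最佳自相关峰出现在滞后 21 帧处,
    #   则 bpm = 60 × 43.07 / 21 ≈ 123.0;滞后越大对应 BPM 越低,
    #   结果最终被 np.clip 限制在 [BPM_MIN, BPM_MAX] 区间内。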
def _fuse_results(
self,
fast_result: Dict[str, Any],
detailed_result: Dict[str, Any],
        y: np.ndarray | None = None,
) -> Dict[str, Any]:
"""融合快速和详细分析的结果,带倍频检测和纠正"""
results = []
if fast_result.get('bpm', 0) > 0:
results.append({
'bpm': fast_result['bpm'],
'original_bpm': fast_result.get('original_bpm', fast_result['bpm']),
'confidence': fast_result['confidence'],
'method': fast_result['method'],
'beat_level_correction': fast_result.get('beat_level_correction')
})
if detailed_result.get('bpm', 0) > 0:
results.append({
'bpm': detailed_result['bpm'],
'confidence': detailed_result['confidence'],
'method': detailed_result['method']
})
if not results:
return {
'bpm': 120.0,
'confidence': 0.0,
'note': '无法检测 BPM,使用默认值'
}
# 如果快速检测已经进行了beat level纠正,直接使用纠正后的结果
beat_level_correction = results[0].get('beat_level_correction') if results else None
if beat_level_correction:
original_bpm = results[0].get('original_bpm', results[0]['bpm'])
corrected_bpm = results[0]['bpm']
return {
'bpm': corrected_bpm,
'confidence': results[0]['confidence'],
'primary_method': results[0]['method'],
'supporting_methods': len(results) - 1,
'all_candidates': results,
'octave_correction': {
'from': original_bpm,
'to': corrected_bpm,
'reason': f'节拍层级纠正: {original_bpm:.1f} → {corrected_bpm:.1f} ({beat_level_correction})'
}
}
# 如果只有一个结果
if len(results) == 1:
best = results[0]
return {
'bpm': best['bpm'],
'confidence': best['confidence'],
'primary_method': best['method'],
'supporting_methods': 0,
'all_candidates': results,
'octave_correction': None
}
# 检测倍频关系
fast_bpm = results[0]['bpm'] # librosa.beat.tempo 通常更准确
detailed_bpm = results[1]['bpm'] if len(results) > 1 else None
if detailed_bpm and fast_bpm > 0:
ratio = max(fast_bpm, detailed_bpm) / min(fast_bpm, detailed_bpm)
# 检查是否是倍频关系(1/2, 1/3, 1/4, 2x, 3x, 4x 等)
octave_correction = None
is_octave = False
chosen_bpm = fast_bpm # 默认使用快速检测结果
# 特殊情况:当 detailed_bpm 很低(< 40)且 fast_bpm 在 100-120 范围时
# 可能是慢歌被检测为2倍,此时 detailed_bpm × 2 可能是正确答案
# 例如:嘉禾望岗 实际56 BPM,fast=112.3,detailed=30,30×2=60更接近
# 注意:需要排除中速/快歌被误纠正的情况(如 中巴车带我回家, fast=117.5, detailed=30, ref=115)
# 使用 onset alignment 来验证:如果 halved BPM 的对齐度明显优于 fast BPM,才执行纠正
if detailed_bpm < 40 and 100 <= fast_bpm <= 120 and y is not None:
# 计算谱质心来判断是否真的是慢歌
spectral_centroid = librosa.feature.spectral_centroid(
y=y, sr=self.sr, hop_length=self.hop_length
)
mean_centroid = float(np.mean(spectral_centroid))
doubled_detailed = detailed_bpm * 2
# 检查 doubled_detailed 是否在合理的慢歌范围内 (50-70 BPM)
# 且谱质心较低(< 2200),确认是慢歌特征
if 50 <= doubled_detailed <= 70 and mean_centroid < 2200:
# 检查 fast_bpm 是否约等于 doubled_detailed × 2
if abs(fast_bpm - doubled_detailed * 2) / fast_bpm < 0.1:
# 额外验证:用 onset alignment 确认 halved BPM 确实更好
# 避免误纠正如"中巴车带我回家"(fast=117.5, ref=115)
onset_env = librosa.onset.onset_strength(
y=y, sr=self.sr, hop_length=self.hop_length
)
score_fast = self._compute_onset_alignment_score(onset_env, fast_bpm)
score_halved = self._compute_onset_alignment_score(onset_env, doubled_detailed)
alignment_ratio = score_halved / score_fast if score_fast > 0 else 0
if self.verbose:
logger.debug(
f"慢歌倍频验证: score_fast={score_fast:.3f}, "
f"score_halved={score_halved:.3f}, ratio={alignment_ratio:.3f}"
)
# 只有当 halved BPM 的对齐度明显更好时才纠正
# 中巴车: alignment_ratio=1.042,不触发(实际BPM=115)
# 嘉禾望岗: halved=56 对齐度应该明显更好,会触发
if alignment_ratio > 1.08:
chosen_bpm = doubled_detailed
is_octave = True
octave_correction = {
'from': fast_bpm,
'to': doubled_detailed,
'reason': f'慢歌倍频纠正: fast={fast_bpm:.1f} ≈ detailed×4={detailed_bpm:.1f}×4,使用 detailed×2={doubled_detailed:.1f} (alignment={alignment_ratio:.3f})'
}
logger.info(f"\n🔧 慢歌倍频纠正: {fast_bpm:.1f} BPM → {doubled_detailed:.1f} BPM")
logger.info(f" 原因: {octave_correction['reason']}")
return {
'bpm': chosen_bpm,
'confidence': results[1]['confidence'],
'primary_method': 'Tempogram + 倍频纠正',
'supporting_methods': 1,
'all_candidates': results,
'octave_correction': octave_correction
}
else:
if self.verbose:
logger.debug(
f"慢歌倍频纠正跳过: fast BPM({fast_bpm:.1f})对齐度更好,保持原值"
)
# 检查常见倍频关系:detailed_bpm 应该 ≈ fast_bpm * multiplier
for multiplier in [0.25, 0.33, 0.5, 1.0, 2.0, 3.0, 4.0]:
expected_bpm = fast_bpm * multiplier
# 检查 detailed_bpm 是否接近 expected_bpm(10% 容差)
if abs(detailed_bpm - expected_bpm) / expected_bpm < 0.1:
is_octave = True
if multiplier != 1.0: # 非 1 倍关系表示倍频
# 使用快速检测的结果
corrected_bpm = fast_bpm
octave_correction = {
'from': detailed_bpm,
'to': corrected_bpm,
'reason': f'倍频关系检测: {detailed_bpm:.1f} ≈ {fast_bpm:.1f} × {multiplier},使用快速检测结果'
}
break
# 如果检测到倍频,使用快速检测结果(通常更准确)
if is_octave and octave_correction:
logger.info(f"\n🔧 倍频纠正: {octave_correction['from']:.1f} BPM → {octave_correction['to']:.1f} BPM")
logger.info(f" 原因: {octave_correction['reason']}")
return {
'bpm': fast_bpm,
'confidence': results[0]['confidence'],
'primary_method': results[0]['method'],
'supporting_methods': 1,
'all_candidates': results,
'octave_correction': octave_correction
}
# 如果没有倍频关系,优先使用快速检测(librosa.beat.tempo 是金标准)
# 快速检测通常比详细分析更准确
best = results[0] # 快速检测
return {
'bpm': best['bpm'],
'confidence': best['confidence'],
'primary_method': best['method'],
'supporting_methods': len(results) - 1,
'all_candidates': results,
'octave_correction': None
}
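    # 倍频关系判定的数值示例(示意):fast_bpm=136.0、detailed_bpm=68.0 时,
    # multiplier=0.5 → expected_bpm=68.0,|68.0-68.0|/68.0=0 < 0.1,判定为倍频关系,
    # 保留 fast_bpm=136.0 并记录 octave_correction;若所有 multiplier 都不落在 10% 容差内,
    # 则视为无倍频关系,直接返回快速检测结果。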
def _display_results(self, result: Dict[str, Any]):
"""显示分析结果"""
if not result['success']:
logger.error(f"❌ 分析失败: {result.get('error')}")
return
file_info = (
f"文件: {result['file_name']} "
f"({result['file_size_mb']} MB) "
f"时长: {result['duration_seconds']} 秒"
)
logger.info(file_info)
final = result['final_result']
logger.info(f"\n🎵 最终结果:")
logger.info(f" BPM: {final['bpm']}")
logger.info(f" 置信度: {final['confidence']:.0%}")
logger.info(f" 主要方法: {final['primary_method']}")
logger.info(f" 支持方法数: {final['supporting_methods']}")
# 显示倍频纠正信息
if final.get('octave_correction'):
correction = final['octave_correction']
logger.info(f"\n🔧 倍频纠正:")
logger.info(f" 原始检测: {correction['from']:.1f} BPM")
logger.info(f" 纠正后: {correction['to']:.1f} BPM")
logger.info(f" 原因: {correction['reason']}")
if self.verbose:
logger.debug(f"\n📊 快速检测: {result['fast_detection']['bpm']} BPM")
logger.debug(f"📊 详细分析: {result['detailed_analysis']['bpm']} BPM")
def export_results(
self,
results: Any,
output_path: str
):
"""导出结果为 JSON"""
try:
# 将 numpy 类型转换为 Python 原生类型
def convert_numpy(obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, dict):
return {k: convert_numpy(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return [convert_numpy(v) for v in obj]
return obj
results_converted = convert_numpy(results)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(results_converted, f, ensure_ascii=False, indent=2)
logger.info(f"✓ 结果已导出到: {Path(output_path).absolute()}")
except Exception as e:
logger.error(f"❌ 导出失败: {str(e)}")
def main():
"""主函数"""
parser = argparse.ArgumentParser(
description='Realtime BPM Analyzer - Python 测试程序',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例用法:
# 分析单个文件
python bpm_analyzer_test.py --file music.mp3
# 分析并输出结果
python bpm_analyzer_test.py --file music.mp3 --output result.json
# 显示详细信息
python bpm_analyzer_test.py --file music.mp3 --verbose
# 批量分析文件夹
python bpm_analyzer_test.py --dir /path/to/music
"""
)
parser.add_argument('--file', type=str, help='音频文件路径')
parser.add_argument('--dir', type=str, help='音频文件夹路径(批量分析)')
parser.add_argument('-o', '--output', type=str, help='输出 JSON 文件路径')
parser.add_argument('-v', '--verbose', action='store_true', help='显示详细信息')
args = parser.parse_args()
# 验证参数
if not args.file and not args.dir:
parser.print_help()
sys.exit(1)
# 初始化分析器
analyzer = RealtimeBPMAnalyzerTest(verbose=args.verbose)
# 执行分析
try:
if args.file:
result = analyzer.analyze_file(args.file)
results = result
else:
results_list = analyzer.analyze_directory(args.dir)
results = {
'success': True,
'total_files': len(results_list),
'results': results_list
}
# 导出结果
if args.output:
analyzer.export_results(results, args.output)
else:
# 默认输出文件名
if args.file:
default_output = f"bpm_result_{Path(args.file).stem}.json"
else:
default_output = "bpm_results.json"
analyzer.export_results(results, default_output)
print("\n" + "=" * 80)
print("✅ 分析完成!")
print("=" * 80 + "\n")
except Exception as e:
logger.error(f"❌ 执行失败: {str(e)}")
if args.verbose:
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == '__main__':
main()
# -*- coding: utf-8 -*-
"""
火山引擎豆包音乐分析器实现
"""
import os
import time
import logging
from typing import Dict, Any, Optional
from dotenv import load_dotenv
from pathlib import Path
import httpx
from .base import AudioAnalyzer
from .prompts import build_analyze_prompt, build_lyrics_prompt
_ROOT_DIR = Path(__file__).resolve().parents[2]
load_dotenv(_ROOT_DIR / ".env")
logger = logging.getLogger(__name__)
class DoubaoAnalyzer(AudioAnalyzer):
"""火山引擎豆包音乐分析器"""
def __init__(
self,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
model: Optional[str] = None,
timeout: float = 60.0,
max_retries: int = 3,
):
"""
初始化豆包分析器
Args:
api_key: API Key(默认从环境变量读取 DOUBAO_API_KEY 或 ARK_API_KEY)
base_url: API 基础URL(默认: https://ark.cn-beijing.volces.com/api/v3)
model: 模型名称(默认: doubao-seed-1-8-251228)
timeout: 超时时间(秒)
max_retries: 最大重试次数
"""
self.api_key = api_key or os.getenv("DOUBAO_API_KEY", os.getenv("ARK_API_KEY"))
self.base_url = base_url or os.getenv(
"DOUBAO_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3"
)
self.model = model or os.getenv("DOUBAO_MODEL", "doubao-seed-1-8-251228")
self.timeout = timeout
self.max_retries = max_retries
self._client = None
def _get_client(self) -> httpx.Client:
"""获取 HTTP 客户端"""
if self._client is None:
self._client = httpx.Client(
base_url=self.base_url,
timeout=self.timeout,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
return self._client
def get_provider_name(self) -> str:
return "doubao"
def get_model_name(self) -> str:
return self.model
def analyze(
self,
metadata: Dict[str, Any],
music_url: str,
extract_lyrics: bool = False,
label_level: int = 0,
) -> Optional[Dict[str, Any]]:
"""
分析音乐
Args:
metadata: 音乐元数据
music_url: 音乐文件 URL
extract_lyrics: 是否识别歌词
label_level: 标签级别
Returns:
分析结果字典
"""
client = self._get_client()
if extract_lyrics:
return self._analyze_with_lyrics(client, metadata, music_url, label_level)
else:
return self._analyze_basic(client, metadata, music_url, label_level)
def _analyze_basic(
self,
client: httpx.Client,
metadata: Dict[str, Any],
music_url: str,
label_level: int = 0,
) -> Optional[Dict[str, Any]]:
"""基础分析(不含歌词)"""
system_prompt, user_prompt = build_analyze_prompt(
metadata=metadata,
include_lyrics=False,
label_level=label_level,
)
# 打印提示词到日志
logger.info(f"[DoubaoAnalyzer] System Prompt:\n{system_prompt}")
logger.info(f"[DoubaoAnalyzer] User Prompt:\n{user_prompt}")
messages = self._build_messages(system_prompt, user_prompt, music_url)
response = self._call_with_retry(client, messages)
if response is None:
return None
result = self._parse_response(response.get("content", ""))
if result is None:
return None
return self._normalize_result(result, self.model, response.get("usage"))
def _analyze_with_lyrics(
self,
client: httpx.Client,
metadata: Dict[str, Any],
music_url: str,
label_level: int = 0,
) -> Optional[Dict[str, Any]]:
"""分析(含歌词识别,需要两次调用)"""
# 第一次调用:基本信息(不含歌词)
system_prompt, user_prompt = build_analyze_prompt(
metadata=metadata,
include_lyrics=False,
label_level=label_level,
)
# 打印提示词到日志
logger.info(f"[DoubaoAnalyzer] System Prompt (with lyrics):\n{system_prompt}")
logger.info(f"[DoubaoAnalyzer] User Prompt (with lyrics):\n{user_prompt}")
messages_basic = self._build_messages(system_prompt, user_prompt, music_url)
response_basic = self._call_with_retry(client, messages_basic)
if response_basic is None:
return None
result = self._parse_response(response_basic.get("content", ""))
if result is None:
return None
# 第二次调用:歌词识别
lyrics_prompt = build_lyrics_prompt()
# 打印歌词识别提示词到日志
logger.info(f"[DoubaoAnalyzer] Lyrics Prompt:\n{lyrics_prompt}")
messages_lyrics = self._build_messages(
"请识别这段音频中的歌词内容", lyrics_prompt, music_url
)
response_lyrics = self._call_with_retry(client, messages_lyrics)
lyrics_result = None
if response_lyrics:
lyrics_result = self._parse_response(response_lyrics.get("content", ""))
if lyrics_result and "lyrics" in lyrics_result:
result["lyrics"] = lyrics_result["lyrics"]
# 合并 token 使用信息
usage = response_basic.get("usage", {})
if response_lyrics and response_lyrics.get("usage"):
usage_lyrics = response_lyrics["usage"]
usage = {
"prompt_tokens": usage.get("prompt_tokens", 0)
+ usage_lyrics.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0)
+ usage_lyrics.get("completion_tokens", 0),
"total_tokens": usage.get("total_tokens", 0)
+ usage_lyrics.get("total_tokens", 0),
}
return self._normalize_result(result, self.model, usage)
    def _build_messages(
        self,
        system_prompt: str,
        user_prompt: str,
        music_url: str,
    ) -> list:
        """构建消息格式:system 提示词 + 携带音频 URL 的 user 消息"""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "video_url", "video_url": {"url": music_url}},
                    {"type": "text", "text": user_prompt},
                ],
            }
        )
        return messages
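    # 生成的消息结构示例(示意,URL 为占位):
    # [
    #   {"role": "system", "content": "<system_prompt>"},
    #   {"role": "user", "content": [
    #       {"type": "video_url", "video_url": {"url": "https://example.com/a.mp3"}},
    #       {"type": "text", "text": "<user_prompt>"},
    #   ]},
    # ]
    # 这里沿用源码中以 video_url 内容块携带音频 URL 的做法,文本与媒体放在同一条 user 消息中。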
def _call_with_retry(
self,
client: httpx.Client,
messages: list,
) -> Optional[Dict]:
"""带重试的 API 调用"""
endpoint = "/chat/completions"
data = {
"model": self.model,
"messages": messages,
"temperature": 0.7,
"max_tokens": 4000,
"stream": False,
}
for attempt in range(1, self.max_retries + 1):
try:
print(f" [Doubao] 调用模型 (尝试 {attempt}/{self.max_retries})...")
start_time = time.time()
response = client.post(endpoint, json=data)
response.raise_for_status()
end_time = time.time()
elapsed = end_time - start_time
print(f" [Doubao] 响应时间: {elapsed:.2f}s")
result = response.json()
content = (
result.get("choices", [{}])[0].get("message", {}).get("content", "")
)
usage = result.get("usage", {})
print(f" [Doubao] 响应: {content[:100]}...")
return {
"content": content,
"usage": {
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
},
}
            except Exception as e:
                # httpx.HTTPError 与其他异常的重试逻辑完全相同,合并为一个分支处理
                error_type = type(e).__name__
                label = "HTTP 错误" if isinstance(e, httpx.HTTPError) else "错误"
                print(f"  [Doubao] {label} ({error_type}): {e}")
                if attempt < self.max_retries:
                    wait_time = attempt
                    print(f"    等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
                else:
                    print("    已达到最大重试次数")
                    return None
return None
def test_doubao_audio_url_lyrics():
"""
测试豆包是否支持通过音频URL解析音频歌词
此测试用例用于验证豆包模型是否能够:
1. 接收音频URL作为输入
2. 解析音频内容
3. 识别并返回歌词
使用方法:
python -c "from app.middleware.music_analyze.doubao_analyzer import test_doubao_audio_url_lyrics; test_doubao_audio_url_lyrics()"
或者直接在命令行运行:
python app/middleware/music_analyze/doubao_analyzer.py
"""
import json
print("=" * 80)
print("测试豆包音频URL歌词解析功能")
print("=" * 80)
# 测试音频URL(使用一个公开可访问的音频文件)
# 注意:请替换为实际可访问的音频URL
test_audio_url = "https://hikoon-ai-test.oss-cn-hangzhou.aliyuncs.com/ai/cache/modelName/20260114/_s__e_1768376270519_rmab41.mp3"
print(f"\n测试音频URL: {test_audio_url}")
print("\n开始测试...")
try:
# 初始化分析器
analyzer = DoubaoAnalyzer()
# 测试元数据
metadata = {"title": "测试歌曲", "artist": "测试艺术家", "test": True}
print("\n1. 测试基础分析(不含歌词)...")
result_basic = analyzer.analyze(
metadata=metadata,
music_url=test_audio_url,
extract_lyrics=False,
label_level=0,
)
if result_basic:
print(" ✓ 基础分析成功")
print(f" - 曲风: {result_basic.get('genre', 'N/A')}")
print(f" - 语种: {result_basic.get('language', 'N/A')}")
print(f" - 情绪: {result_basic.get('emotion', 'N/A')}")
else:
print(" ✗ 基础分析失败")
print("\n2. 测试歌词识别(含歌词)...")
result_with_lyrics = analyzer.analyze(
metadata=metadata,
music_url=test_audio_url,
extract_lyrics=True,
label_level=0,
)
if result_with_lyrics:
print(" ✓ 歌词识别分析成功")
lyrics = result_with_lyrics.get("lyrics", [])
if lyrics:
print(f" ✓ 成功识别歌词,共 {len(lyrics)} 行")
print("\n 歌词预览(前5行):")
for i, line in enumerate(lyrics[:5], 1):
time_str = line.get("time", "N/A")
text = line.get("text", "")
print(f" [{i}] {time_str} - {text}")
if len(lyrics) > 5:
print(f" ... 还有 {len(lyrics) - 5} 行")
print("\n ✓ 测试通过:豆包支持音频URL解析歌词")
else:
print(" ⚠ 未识别到歌词(可能是纯音乐或无法识别)")
print("\n ! 测试结果:豆包支持音频URL解析,但未返回歌词")
# 输出完整结果
print("\n3. 完整分析结果:")
print(json.dumps(result_with_lyrics, ensure_ascii=False, indent=2))
else:
print(" ✗ 歌词识别分析失败")
print("\n ✗ 测试失败:豆包可能不支持音频URL解析")
print("\n" + "=" * 80)
print("测试完成")
print("=" * 80)
return result_with_lyrics
except Exception as e:
print(f"\n✗ 测试过程中发生错误: {e}")
import traceback
traceback.print_exc()
return None
if __name__ == "__main__":
test_doubao_audio_url_lyrics()
# -*- coding: utf-8 -*-
"""
音乐分析器工厂
"""
from typing import Dict, Any, Optional
from .base import AudioAnalyzer
from .qwen_analyzer import QwenAnalyzer
class AnalyzerFactory:
"""音乐分析器工厂"""
_analyzers: Dict[str, AudioAnalyzer] = {}
@classmethod
def get_analyzer(cls, provider: str = "qwen", **kwargs) -> AudioAnalyzer:
"""
获取分析器实例
Args:
provider: 提供商名称(仅支持 qwen)
**kwargs: 额外配置参数(如 api_key, model, timeout 等)
Returns:
AudioAnalyzer 实例
"""
key = f"{provider}"
cache_key = f"{provider}_{kwargs.get('model', '')}"
if cache_key in cls._analyzers:
return cls._analyzers[cache_key]
if provider == "qwen":
analyzer = QwenAnalyzer(**kwargs)
else:
raise ValueError(f"Unknown provider: {provider}. Only 'qwen' is supported.")
cls._analyzers[cache_key] = analyzer
return analyzer
@classmethod
def get_default_analyzer(cls) -> AudioAnalyzer:
"""获取默认分析器(从环境变量读取)"""
import os
provider = os.getenv("DEFAULT_MUSIC_ANALYZER", "qwen")
return cls.get_analyzer(provider=provider)
@classmethod
def list_providers(cls) -> list:
"""列出可用的提供商"""
return ["qwen"]
@classmethod
def clear_cache(cls):
"""清除缓存的分析器实例"""
cls._analyzers.clear()
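    # 用法示例(示意):
    #   analyzer = AnalyzerFactory.get_analyzer("qwen", model="qwen3-omni-flash")
    #   实例以 "qwen_qwen3-omni-flash" 作为 cache_key 缓存,相同 provider+model 的
    #   再次调用会复用同一实例;需要强制重建时先调用 AnalyzerFactory.clear_cache()。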
# -*- coding: utf-8 -*-
"""
音乐分析统一入口
提供简化的 analyze_music() 函数
"""
from typing import Dict, Any, Optional
import os
from .factory import AnalyzerFactory
def analyze_music(
metadata: Dict[str, Any],
music_url: str,
provider: str = None,
extract_lyrics: bool = False,
label_level: int = 0,
) -> Optional[Dict[str, Any]]:
"""
音乐分析统一入口函数
Args:
metadata: 音乐元数据字典(如 title, artist 等)
music_url: 音乐文件 URL
        provider: 提供商(当前工厂仅注册 qwen;传入 doubao 会抛出 ValueError),默认从环境变量读取
extract_lyrics: 是否识别歌词
label_level: 标签级别(0: 一级标签, 1: 一级+二级标签)
Returns:
分析结果字典,包含以下字段:
- genre: 音乐风格(一级风格,如:流行、摇滚)
- emotion: 情绪列表
- emotional_intensity: 情绪强度
- vocal_texture: 人声质感
- vocal_description: 人声质感描述
- visual_concept: 视觉概念
- language: 语种
- bpm: 节拍数(可选)
- lyrics: 歌词列表(可选)
- _model: 使用的模型名称
- _token_info: Token 使用信息
Example:
>>> result = analyze_music(
... metadata={"title": "稻香", "artist": "周杰伦"},
... music_url="https://example.com/music.mp3",
... provider="qwen",
... extract_lyrics=False,
... )
>>> print(result["genre"])
流行
"""
if provider is None:
provider = os.getenv("DEFAULT_MUSIC_ANALYZER", "qwen")
analyzer = AnalyzerFactory.get_analyzer(provider=provider)
return analyzer.analyze(
metadata=metadata,
music_url=music_url,
extract_lyrics=extract_lyrics,
label_level=label_level,
)
def analyze_music_with_qwen(
metadata: Dict[str, Any],
music_url: str,
extract_lyrics: bool = False,
label_level: int = 0,
) -> Optional[Dict[str, Any]]:
"""使用通义千问分析音乐"""
return analyze_music(
metadata=metadata,
music_url=music_url,
provider="qwen",
extract_lyrics=extract_lyrics,
label_level=label_level,
)
def analyze_music_with_doubao(
metadata: Dict[str, Any],
music_url: str,
extract_lyrics: bool = False,
label_level: int = 0,
) -> Optional[Dict[str, Any]]:
"""使用火山引擎豆包分析音乐"""
return analyze_music(
metadata=metadata,
music_url=music_url,
provider="doubao",
extract_lyrics=extract_lyrics,
label_level=label_level,
)
def analyze_music_lyrics_only(
metadata: Dict[str, Any],
music_url: str,
provider: str = None,
) -> Optional[Dict[str, Any]]:
"""仅识别歌词,避免重复做基础标签分析"""
if provider is None:
provider = os.getenv("DEFAULT_MUSIC_ANALYZER", "qwen")
analyzer = AnalyzerFactory.get_analyzer(provider=provider)
if hasattr(analyzer, "analyze_lyrics_only"):
return analyzer.analyze_lyrics_only(metadata=metadata, music_url=music_url)
# 兼容未实现 lyrics_only 的提供商
result = analyzer.analyze(
metadata=metadata,
music_url=music_url,
extract_lyrics=True,
label_level=0,
)
if isinstance(result, dict):
lyrics = result.get("lyrics", [])
return {
"lyrics": lyrics if isinstance(lyrics, list) else [],
"_model": result.get("_model"),
"_token_info": result.get("_token_info"),
}
return None
def get_available_providers() -> list:
"""获取可用的提供商列表"""
return AnalyzerFactory.list_providers()
# -*- coding: utf-8 -*-
"""
音乐分析提示词模板构建器
支持从外部模板文件读取提示词,便于动态修改
"""
import os
from pathlib import Path
from typing import Dict, Any, Optional
# 模板文件路径(已迁移到 app/prompts/step2_music_decode)
PROMPTS_DIR = Path(__file__).parent.parent.parent / "prompts" / "step2_music_decode"
SYSTEM_PROMPT_FILE = PROMPTS_DIR / "music_analyze_system_prompt.md"
SYSTEM_PROMPT_PART_A_FILE = PROMPTS_DIR / "music_analyze_system_prompt_part_a.md"
SYSTEM_PROMPT_PART_B_FILE = PROMPTS_DIR / "music_analyze_system_prompt_part_b.md"
USER_PROMPT_FILE = PROMPTS_DIR / "music_analyze_user_prompt.md"
LYRICS_ONLY_PROMPT_FILE = PROMPTS_DIR / "music_lyrics_only_prompt.md"
def load_template(template_path: Path) -> str:
"""
从文件加载模板
Args:
template_path: 模板文件路径
Returns:
模板内容字符串
"""
if not template_path.exists():
raise FileNotFoundError(f"模板文件不存在: {template_path}")
with open(template_path, "r", encoding="utf-8") as f:
content = f.read()
# 只移除文件顶部的 Markdown 注释(以 # 开头的注释行)
# 保留 ## 标题行和正文内容
lines = content.split("\n")
filtered_lines = []
in_header = True
for line in lines:
stripped = line.strip()
# 如果是空行,保留
if not stripped:
filtered_lines.append(line)
continue
# 如果在文件头部且是单行注释(# 但不是 ##),则跳过
if in_header and stripped.startswith("#") and not stripped.startswith("##"):
continue
# 遇到 ## 标题或正文内容,不再是头部
in_header = False
filtered_lines.append(line)
return "\n".join(filtered_lines)
class PromptBuilder:
"""音乐分析提示词模板构建器"""
def __init__(self, label_level: int = 0):
"""
初始化提示词构建器
Args:
label_level: 标签级别(0: 一级标签, 1: 一级+二级标签)
"""
self.label_level = label_level
def build_system_prompt(self) -> str:
"""构建系统提示词 - 直接加载静态模板"""
return load_template(SYSTEM_PROMPT_FILE)
def build_system_prompt_part_a(self) -> str:
"""构建系统提示词A组"""
return load_template(SYSTEM_PROMPT_PART_A_FILE)
def build_system_prompt_part_b(self) -> str:
"""构建系统提示词B组"""
return load_template(SYSTEM_PROMPT_PART_B_FILE)
def build_metadata_section(self, metadata: Optional[Dict[str, Any]] = None) -> str:
"""构建元数据部分"""
if not metadata:
return ""
sections = ["## 音乐元数据"]
for key, value in metadata.items():
if key.startswith("_"):
continue
if value and str(value).strip():
sections.append(f"- {key}: {value}")
sections.append("")
return "\n".join(sections)
def build_output_format(
self,
include_lyrics: bool = False,
include_bpm: bool = True,
) -> str:
"""构建输出格式说明"""
if include_lyrics and include_bpm:
format_spec = """{
"genre": "",
"sub_genre": "",
"language": "",
"vocal_type": "",
"vocal_description": "",
"emotion": [""],
"scene": [""],
"age": "",
"rhythm_intensity": "",
"is_sinking": false,
"song_description": "",
"visual_concept": "",
"emotional_intensity": "",
"bpm": 0,
"lyrics": [{"time": "", "text": ""}]
}"""
elif include_bpm:
format_spec = """{
"genre": "",
"sub_genre": "",
"language": "",
"vocal_type": "",
"vocal_description": "",
"emotion": [""],
"scene": [""],
"age": "",
"rhythm_intensity": "",
"is_sinking": false,
"song_description": "",
"visual_concept": "",
"emotional_intensity": "",
"bpm": 0
}"""
elif include_lyrics:
format_spec = """{
"genre": "",
"sub_genre": "",
"language": "",
"vocal_type": "",
"vocal_description": "",
"emotion": [""],
"scene": [""],
"age": "",
"rhythm_intensity": "",
"is_sinking": false,
"song_description": "",
"visual_concept": "",
"emotional_intensity": "",
"lyrics": [{"time": "", "text": ""}]
}"""
else:
format_spec = """{
"genre": "",
"sub_genre": "",
"language": "",
"vocal_type": "",
"vocal_description": "",
"emotion": [""],
"scene": [""],
"age": "",
"rhythm_intensity": "",
"is_sinking": false,
"song_description": "",
"visual_concept": "",
"emotional_intensity": ""
}"""
return format_spec
def build_user_prompt(
self,
metadata: Optional[Dict[str, Any]] = None,
include_lyrics: bool = False,
include_bpm: bool = True,
) -> str:
"""
构建完整的用户提示词
使用模板文件并替换占位符
Args:
metadata: 音乐元数据字典(可选)
include_lyrics: 是否识别歌词(保留参数以兼容现有调用)
include_bpm: 是否包含BPM识别(保留参数以兼容现有调用)
Returns:
完整的用户提示词
"""
# 加载模板
template = load_template(USER_PROMPT_FILE)
# 准备替换字典 - 只替换元数据部分
# 输出格式已在系统提示词中定义,不需要在用户提示词中重复
replacements = {
"{{METADATA_SECTION}}": self.build_metadata_section(metadata),
}
# 替换占位符
result = template
for placeholder, value in replacements.items():
result = result.replace(placeholder, value)
return result
def build_lyrics_only_prompt(self) -> str:
"""构建仅识别歌词的提示词"""
return load_template(LYRICS_ONLY_PROMPT_FILE)
def build_analyze_prompt(
metadata: Optional[Dict[str, Any]] = None,
include_lyrics: bool = False,
label_level: int = 0,
) -> tuple[str, str]:
"""
构建完整的分析提示词
Args:
metadata: 音乐元数据字典(可选)
include_lyrics: 是否识别歌词
label_level: 标签级别(0: 一级标签, 1: 一级+二级标签)
Returns:
(system_prompt, user_prompt) 元组
"""
builder = PromptBuilder(label_level=label_level)
system_prompt = builder.build_system_prompt()
user_prompt = builder.build_user_prompt(
metadata=metadata,
include_lyrics=include_lyrics,
include_bpm=True,
)
return system_prompt, user_prompt
def build_analyze_prompt_part_a(
metadata: Optional[Dict[str, Any]] = None,
include_lyrics: bool = False,
label_level: int = 0,
) -> tuple[str, str]:
"""
构建A组分析提示词(标签与基础信息)
"""
builder = PromptBuilder(label_level=label_level)
system_prompt = builder.build_system_prompt_part_a()
user_prompt = builder.build_user_prompt(
metadata=metadata,
include_lyrics=include_lyrics,
include_bpm=True,
)
return system_prompt, user_prompt
def build_analyze_prompt_part_b(
metadata: Optional[Dict[str, Any]] = None,
include_lyrics: bool = False,
label_level: int = 0,
) -> tuple[str, str]:
"""
构建B组分析提示词(节奏与视觉描述)
"""
builder = PromptBuilder(label_level=label_level)
system_prompt = builder.build_system_prompt_part_b()
user_prompt = builder.build_user_prompt(
metadata=metadata,
include_lyrics=include_lyrics,
include_bpm=True,
)
return system_prompt, user_prompt
def build_lyrics_prompt() -> str:
"""构建仅识别歌词的提示词"""
builder = PromptBuilder()
return builder.build_lyrics_only_prompt()
# 向后兼容:保留原有的构建函数
def build_user_prompt(
metadata: Optional[Dict[str, Any]] = None,
include_lyrics: bool = False,
label_level: int = 0,
) -> str:
"""构建用户提示词(兼容函数)"""
builder = PromptBuilder(label_level=label_level)
return builder.build_user_prompt(
metadata=metadata,
include_lyrics=include_lyrics,
include_bpm=True,
)
# -*- coding: utf-8 -*-
"""
通义千问音乐分析器实现
"""
import os
import time
import tempfile
import subprocess
import threading
import hashlib
import csv
from datetime import datetime
from pathlib import Path
import requests
import logging
from typing import Dict, Any, Optional, Tuple, List
from concurrent.futures import ThreadPoolExecutor
from .base import AudioAnalyzer
from .prompts import (
build_analyze_prompt,
build_lyrics_prompt,
)
from .audio_features import (
extract_audio_features,
extract_beat_timestamps,
extract_emotion_curve,
aggregate_emotion_by_segments,
)
# 使用项目统一的配置
from app.core.config import settings
logger = logging.getLogger(__name__)
MUSIC_MAPPING_HEADERS = [
"song_id",
"audio_file_name",
"audio_file_path",
"source_url",
"updated_at",
]
MUSIC_MAPPING_HEADER_ALIASES = {
"song_id": ("song_id", "歌曲ID"),
"audio_file_name": ("audio_file_name", "音频文件名"),
"audio_file_path": ("audio_file_path", "音频文件路径"),
"source_url": ("source_url", "原始URL"),
"updated_at": ("updated_at", "更新时间"),
}
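# 映射 CSV 的一行示例(示意,路径与哈希均为虚构):
#   song_id,audio_file_name,audio_file_path,source_url,updated_at
#   12345,12345_9f8e7d6c5b4a.mp3,/abs/music/12345_9f8e7d6c5b4a.mp3,https://example.com/a.mp3,2025-01-01T00:00:00
# 别名表用于兼容历史中文表头(如 "歌曲ID"),读取时按别名顺序取第一个非空值。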
class QwenAnalyzer(AudioAnalyzer):
"""通义千问音乐分析器"""
def __init__(
self,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
model: Optional[str] = None,
        max_retries: Optional[int] = None,
):
"""
初始化通义千问分析器
Args:
api_key: API Key(默认从环境变量读取 QWEN_API_KEY)
base_url: API 基础URL(默认从环境变量读取)
model: 模型名称(默认: qwen3-omni-flash)
timeout: 超时时间(秒)
max_retries: 最大重试次数
"""
# 优先使用传入的参数,其次使用项目统一的 settings
if api_key is None:
# 按优先级:QWEN_API_KEY -> QWEN_DASHSCOPE_API_KEY
api_key = settings.QWEN_API_KEY or settings.QWEN_DASHSCOPE_API_KEY
self.api_key = api_key
self.base_url = (
base_url
or settings.QWEN_BASE_URL
or "https://dashscope.aliyuncs.com/compatible-mode/v1"
)
self.model = model or settings.QWEN_MODEL or "qwen3-omni-flash"
self.timeout = settings.QWEN_TIMEOUT or 15.0
self.lyrics_timeout = settings.QWEN_LYRICS_TIMEOUT or 90.0
self.max_retries = max_retries or settings.QWEN_MAX_RETRIES or 3
self._client = None
self._project_root = Path(__file__).resolve().parents[3]
self._music_dir = self._resolve_music_dir()
self._music_mapping_path = self._resolve_music_mapping_path()
self._mapping_lock = threading.Lock()
self._mapping_seen: set[tuple[str, str]] = self._load_existing_mapping_keys()
def _resolve_music_dir(self) -> Path:
raw_dir = str(getattr(settings, "MUSIC_DOWNLOAD_DIR", "music") or "music").strip()
path = Path(raw_dir)
if not path.is_absolute():
path = self._project_root / path
path.mkdir(parents=True, exist_ok=True)
return path
def _resolve_music_mapping_path(self) -> Path:
raw_file = str(
getattr(settings, "MUSIC_MAPPING_FILE", "music/music_file_mapping.csv")
or "music/music_file_mapping.csv"
).strip()
path = Path(raw_file)
if not path.is_absolute():
path = self._project_root / path
path.parent.mkdir(parents=True, exist_ok=True)
return path
def _load_existing_mapping_keys(self) -> set[tuple[str, str]]:
if not self._music_mapping_path.exists():
return set()
seen: set[tuple[str, str]] = set()
try:
with open(self._music_mapping_path, "r", encoding="utf-8-sig", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
song_id = self._get_mapping_value(row, "song_id")
file_path = self._get_mapping_value(row, "audio_file_path")
if file_path:
try:
file_path = str(Path(file_path).resolve())
except Exception:
pass
seen.add((song_id, file_path))
except Exception:
return set()
return seen
def _get_mapping_value(self, row: Dict[str, Any], field: str) -> str:
for alias in MUSIC_MAPPING_HEADER_ALIASES.get(field, (field,)):
value = row.get(alias)
if value is not None and str(value).strip():
return str(value).strip()
return ""
def _extract_song_id(self, metadata: Optional[Dict[str, Any]]) -> str:
if not metadata:
return ""
for key in ("歌曲ID", "song_id", "id", "track_id", "tmeid", "tmeID", "TMEID"):
value = metadata.get(key)
if value is not None and str(value).strip():
return str(value).strip()
return ""
def _sanitize_filename_part(self, value: str) -> str:
safe_chars = []
for ch in value:
if ch.isalnum() or ch in {"-", "_", "."}:
safe_chars.append(ch)
else:
safe_chars.append("_")
cleaned = "".join(safe_chars).strip("._")
return cleaned[:80] if cleaned else "unknown"
def _build_music_file_path(
self,
music_url: str,
ext: str,
metadata: Optional[Dict[str, Any]] = None,
) -> Path:
song_id = self._extract_song_id(metadata)
song_part = self._sanitize_filename_part(song_id or "unknown")
url_hash = hashlib.md5(music_url.encode("utf-8")).hexdigest()[:12]
return self._music_dir / f"{song_part}_{url_hash}{ext}"
def _append_music_mapping(
self,
file_path: Path,
music_url: str,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
song_id = self._extract_song_id(metadata)
mapping_key = (song_id, str(file_path.resolve()))
with self._mapping_lock:
if mapping_key in self._mapping_seen:
return
write_header = not self._music_mapping_path.exists()
encoding = "utf-8-sig" if write_header else "utf-8"
with open(self._music_mapping_path, "a", encoding=encoding, newline="") as f:
writer = csv.DictWriter(
f,
fieldnames=MUSIC_MAPPING_HEADERS,
)
if write_header:
writer.writeheader()
writer.writerow(
{
"song_id": song_id,
"audio_file_name": file_path.name,
"audio_file_path": str(file_path.resolve()),
"source_url": music_url,
"updated_at": datetime.now().isoformat(timespec="seconds"),
}
)
self._mapping_seen.add(mapping_key)
def _is_persisted_music_file(self, file_path: str) -> bool:
try:
candidate = Path(file_path).resolve()
return candidate.parent == self._music_dir.resolve()
except Exception:
return False
def _get_client(self):
"""获取 OpenAI 兼容客户端"""
if self._client is None:
from openai import OpenAI
self._client = OpenAI(
api_key=self.api_key,
base_url=self.base_url,
timeout=self.timeout,
max_retries=0,
)
return self._client
def get_provider_name(self) -> str:
return "qwen"
def get_model_name(self) -> str:
return self.model
def _call_songformer(self, music_url: str) -> Optional[Dict]:
"""
调用 SongFormer 服务获取歌曲结构和高潮点
Args:
music_url: 音乐文件 URL
Returns:
SongFormer 返回的完整数据字典
"""
songformer_url = getattr(settings, "SONGFORMER_URL", None)
if not songformer_url:
print(" [Qwen] SongFormer URL 未配置,跳过高潮点分析")
return None
try:
print(f" [Qwen] 调用 SongFormer 服务...")
resp = requests.post(
songformer_url,
json={"url": music_url, "chorus_k": 3},
timeout=60,
)
resp.raise_for_status()
data = resp.json()
print(f" [Qwen] SongFormer 调用成功")
return data
except Exception as e:
print(f" [Qwen] SongFormer 调用失败: {e}")
return None
def _extract_climax_point(self, songformer_data: Optional[Dict]) -> str:
"""
从 SongFormer 数据中提取高潮点
Args:
songformer_data: SongFormer 返回的数据
Returns:
str: "最强", "强", 或 ""
"""
if not songformer_data:
return ""
# 首先尝试从 climax_points 字段获取(旧格式)
climax_points = songformer_data.get("climax_points", {})
if climax_points:
# 检查是否有最强高潮
if climax_points.get("strongest_climax"):
return "最强"
# 检查是否有强高潮
if climax_points.get("strong_climax"):
return "强"
# 从 top_k_chorus 字段获取(新格式)
top_k_chorus = songformer_data.get("top_k_chorus", [])
if isinstance(top_k_chorus, list) and len(top_k_chorus) > 0:
# 按 score 排序,取最高分作为最强高潮
sorted_chorus = sorted(
[
c
for c in top_k_chorus
if isinstance(c, dict) and c.get("score") is not None
],
key=lambda x: x.get("score", 0),
reverse=True,
)
if sorted_chorus:
# 最高分 > 7.0 认为是"最强",否则是"强"
highest_score = sorted_chorus[0].get("score", 0)
if highest_score > 7.0:
return "最强"
else:
return "强"
return ""
def _build_climax_points(self, songformer_data: Optional[Dict]) -> Dict[str, Any]:
"""
从 SongFormer 数据构建完整的 climax_points 对象
Args:
songformer_data: SongFormer 返回的数据
Returns:
包含 strong_climax 和 strongest_climax 的字典
"""
if not songformer_data:
return {
"strong_climax": None,
"strongest_climax": None,
"analysis_time": 0.0,
}
# 首先尝试从 climax_points 字段获取(旧格式)
climax_points = songformer_data.get("climax_points", {})
if climax_points and (
climax_points.get("strong_climax") or climax_points.get("strongest_climax")
):
return {
"strong_climax": climax_points.get("strong_climax"),
"strongest_climax": climax_points.get("strongest_climax"),
"analysis_time": climax_points.get("analysis_time", 0.0),
}
# 从 top_k_chorus 字段构建(新格式)
top_k_chorus = songformer_data.get("top_k_chorus", [])
segments = songformer_data.get("segments", [])
if isinstance(top_k_chorus, list) and len(top_k_chorus) > 0:
# 按 score 排序
sorted_chorus = sorted(
[
c
for c in top_k_chorus
if isinstance(c, dict) and c.get("score") is not None
],
key=lambda x: x.get("score", 0),
reverse=True,
)
if sorted_chorus:
# 最高分作为 strongest_climax
highest = sorted_chorus[0]
highest_score = highest.get("score", 0)
# 找到对应的段落标签
start_time = highest.get("start", 0)
section_label = "chorus"
for seg in segments:
if isinstance(seg, dict):
seg_start = seg.get("start", 0)
seg_end = seg.get("end", 0)
if seg_start <= start_time < seg_end:
section_label = seg.get("label", "chorus")
break
strongest_climax = {
"time": start_time,
"intensity": "strongest",
"section_label": section_label,
"reason": f"Highest chorus score: {highest_score:.2f}",
}
            # 第二高作为 strong_climax(只要存在即取,当前未做分数差距过滤)
strong_climax = None
if len(sorted_chorus) > 1:
second = sorted_chorus[1]
second_score = second.get("score", 0)
second_start = second.get("start", 0)
# 找到对应的段落标签
second_section_label = "chorus"
for seg in segments:
if isinstance(seg, dict):
seg_start = seg.get("start", 0)
seg_end = seg.get("end", 0)
if seg_start <= second_start < seg_end:
second_section_label = seg.get("label", "chorus")
break
strong_climax = {
"time": second_start,
"intensity": "strong",
"section_label": second_section_label,
"reason": f"Second highest chorus score: {second_score:.2f}",
}
return {
"strong_climax": strong_climax,
"strongest_climax": strongest_climax,
"analysis_time": 0.0,
}
return {
"strong_climax": None,
"strongest_climax": None,
"analysis_time": 0.0,
}
def analyze(
self,
metadata: Dict[str, Any],
music_url: str,
extract_lyrics: bool = False,
label_level: int = 0,
) -> Optional[Dict[str, Any]]:
"""
分析音乐
Args:
metadata: 音乐元数据
music_url: 音乐文件 URL
extract_lyrics: 是否识别歌词
label_level: 标签级别
Returns:
分析结果字典
"""
client = self._get_client()
light_mode = bool(getattr(settings, "MUSIC_ANALYZE_LIGHT_MODE", True))
songformer_data = None if light_mode else self._call_songformer(music_url)
# 下载音频并提取本地特征
local_features = {}
tmp_file_path = None
try:
if light_mode:
print(" [Qwen] 轻量模式: 仅提取 BPM")
tmp_file_path, _ = self._download_audio(music_url, metadata=metadata)
beat_info = extract_beat_timestamps(tmp_file_path)
local_features = {"bpm": round(beat_info.tempo)}
print(f" [Qwen] 本地特征: BPM={local_features.get('bpm')}")
else:
print(f" [Qwen] 下载音频并提取本地特征...")
tmp_file_path, _ = self._download_audio(music_url, metadata=metadata)
# 从 songformer 获取段落结构用于情绪聚合
segments = songformer_data.get("segments") if songformer_data else None
local_features = self._extract_local_features(tmp_file_path, segments=segments)
# 从 SongFormer 数据中提取高潮点
climax_point = self._extract_climax_point(songformer_data)
local_features["climax_point"] = climax_point
# 构建完整的 climax_points 对象
climax_points = self._build_climax_points(songformer_data)
local_features["climax_points"] = climax_points
print(
f" [Qwen] 本地特征: BPM={local_features.get('bpm')}, "
f"段落情绪数={len(local_features.get('segment_emotions', []))}, "
f"高潮点={climax_point}"
)
except Exception as e:
print(f" [Qwen] 本地特征提取失败,将使用LLM估算值: {e}")
finally:
# 清理临时文件
if (
tmp_file_path
and os.path.exists(tmp_file_path)
and not self._is_persisted_music_file(tmp_file_path)
):
try:
os.unlink(tmp_file_path)
                except Exception:
pass
# 执行LLM分析
if extract_lyrics:
result = self._analyze_with_lyrics(client, metadata, music_url, label_level)
else:
result = self._analyze_basic(client, metadata, music_url, label_level)
# 合并本地特征到结果中
if result and local_features:
# 使用本地提取的值覆盖
result.update(local_features)
return result
def _analyze_basic(
self,
client,
metadata: Dict[str, Any],
music_url: str,
label_level: int = 0,
) -> Optional[Dict[str, Any]]:
"""基础分析(不含歌词,单轮标签分析)"""
# 提取音频ID用于错误定位
song_id = self._extract_song_id(metadata)
print(f" [Qwen] 分析音频: 歌曲ID={song_id}")
system_prompt, user_prompt = build_analyze_prompt(
metadata=metadata,
include_lyrics=False,
label_level=label_level,
)
prompt = self._build_dashscope_prompt(system_prompt, user_prompt)
response = self._call_with_retry_dashscope(music_url, prompt, song_id=song_id, metadata=metadata)
if response is None:
return None
raw_content = response.get("content", "")
parsed = self._parse_response(raw_content)
if parsed is None:
return None
if isinstance(parsed, list):
if parsed and isinstance(parsed[0], dict):
parsed = parsed[0]
else:
return None
if not isinstance(parsed, dict):
return None
return self._normalize_result(parsed, self.model, response.get("usage"))
def _download_audio(
self, music_url: str, metadata: Optional[Dict[str, Any]] = None
) -> Tuple[str, str]:
"""
下载音频文件到 music 目录(按 URL+歌曲ID 命名并复用缓存)
Args:
music_url: 音频URL
metadata: 音乐元数据(用于提取歌曲ID生成映射表)
Returns:
(本地文件路径, 文件扩展名)
"""
# 确定文件扩展名
ext = ".mp3"
if "." in music_url:
url_ext = music_url.split(".")[-1].split("?")[0].lower()
if url_ext in ["mp3", "wav", "flac", "aac", "m4a", "ogg"]:
ext = f".{url_ext}"
target_path = self._build_music_file_path(music_url, ext, metadata=metadata)
if not target_path.exists():
response = requests.get(music_url, timeout=60)
response.raise_for_status()
with open(target_path, "wb") as f:
f.write(response.content)
print(f" [Qwen] 音频已保存: {target_path}")
self._append_music_mapping(target_path, music_url, metadata=metadata)
return str(target_path), ext
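    # 扩展名解析示例(示意):
    #   "https://example.com/song.flac?sig=abc" → ".flac"(问号后的签名参数被剥离)
    #   "https://example.com/stream"            → 无已知音频扩展名,回退默认 ".mp3"
    # 目标文件已存在时直接复用本地缓存,不重复下载。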
def _extract_local_features(
self,
audio_path: str,
segments: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
"""
提取本地音频特征
Args:
audio_path: 本地音频文件路径
segments: songformer 返回的段落结构(可选),用于聚合情绪曲线
Returns:
包含bpm、卡点时间戳、情绪曲线的字典
"""
try:
features = extract_audio_features(audio_path)
# 卡点检测
beat_info = extract_beat_timestamps(audio_path)
# 情绪曲线
emotion_curve = extract_emotion_curve(audio_path)
# beat_info.tempo 经过节拍层级纠正,比 features.tempo 更准确
result = {
"bpm": round(beat_info.tempo),
# 卡点信息
"beat_timestamps": beat_info.beat_timestamps,
"downbeat_timestamps": beat_info.downbeat_timestamps,
"beat_intervals": beat_info.beat_intervals,
}
# 如果有段落结构,返回按段落聚合的情绪数据
if segments:
segment_emotions = aggregate_emotion_by_segments(emotion_curve, segments)
result["segment_emotions"] = [
{
"start": se.start,
"end": se.end,
"label": se.label,
"intensity": se.intensity,
"energy": se.energy,
"valence": se.valence,
"arousal": se.arousal,
"trend": se.trend,
}
for se in segment_emotions
]
else:
# 没有段落结构时,返回原始情绪曲线
result["emotion_curve"] = {
"timestamps": emotion_curve.timestamps,
"energy_values": emotion_curve.energy_values,
"valence_values": emotion_curve.valence_values,
"arousal_values": emotion_curve.arousal_values,
"values": emotion_curve.smoothed_curve,
}
return result
except Exception as e:
print(f" [Qwen] 本地特征提取失败: {e}")
return {}
def _analyze_with_lyrics(
self,
client,
metadata: Dict[str, Any],
music_url: str,
label_level: int = 0,
) -> Optional[Dict[str, Any]]:
"""分析(含歌词识别,单轮标签分析 + 歌词并发)"""
# 提取音频ID用于错误定位
song_id = self._extract_song_id(metadata)
print(f" [Qwen] 分析音频: 歌曲ID={song_id}")
system_prompt, user_prompt = build_analyze_prompt(
metadata=metadata,
include_lyrics=False,
label_level=label_level,
)
prompt = self._build_dashscope_prompt(system_prompt, user_prompt)
lyrics_prompt = build_lyrics_prompt()
messages_lyrics = self._build_messages(
"请识别这段音频中的歌词内容", lyrics_prompt, music_url
)
print(" [Qwen] 并发执行基础标签分析和歌词识别...")
start_time = time.time()
result_main: Optional[Dict[str, Any]] = None
usage_main: Optional[Dict[str, Any]] = None
response_lyrics = None
timing: Dict[str, float] = {}
def _timed_call_dashscope(prompt_text: str) -> tuple[Optional[Dict], float]:
call_start = time.time()
resp = self._call_with_retry_dashscope(music_url, prompt_text, song_id=song_id, metadata=metadata)
return resp, round(time.time() - call_start, 2)
futures = {}
with ThreadPoolExecutor(max_workers=2) as executor:
futures[executor.submit(_timed_call_dashscope, prompt)] = "main"
futures[executor.submit(self._timed_call_openai, client, messages_lyrics)] = "lyrics"
for future in futures:
part = futures[future]
response, part_elapsed = future.result()
if part == "lyrics":
timing["lyrics"] = part_elapsed
response_lyrics = response
continue
timing["analysis"] = part_elapsed
if response is None:
continue
raw_content = response.get("content", "")
parsed = self._parse_response(raw_content)
if parsed is None:
continue
if isinstance(parsed, list):
if parsed and isinstance(parsed[0], dict):
parsed = parsed[0]
else:
continue
if not isinstance(parsed, dict):
continue
result_main = parsed
usage_main = response.get("usage")
elapsed = time.time() - start_time
print(f" [Qwen] 并发调用完成,总耗时: {elapsed:.2f}s")
if result_main is None:
return None
if not isinstance(result_main, dict):
return None
result: Dict[str, Any] = dict(result_main)
# 处理歌词识别结果
if response_lyrics:
raw_lyrics = response_lyrics.get("content", "")
lyrics_result = self._parse_response(raw_lyrics)
if isinstance(lyrics_result, list):
if lyrics_result and isinstance(lyrics_result[0], dict):
lyrics_result = lyrics_result[0]
if lyrics_result and "lyrics" in lyrics_result:
result["lyrics"] = lyrics_result["lyrics"]
result["_timing"] = timing
# 合并 token 使用信息
usage: Dict[str, Any] = {}
if usage_main:
usage.update(usage_main)
if response_lyrics and response_lyrics.get("usage"):
usage_lyrics = response_lyrics["usage"]
usage = {
"prompt_tokens": usage.get("prompt_tokens", 0)
+ usage_lyrics.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0)
+ usage_lyrics.get("completion_tokens", 0),
"total_tokens": usage.get("total_tokens", 0)
+ usage_lyrics.get("total_tokens", 0),
}
result["_token_info_parts"] = {
"main": usage_main,
"lyrics": response_lyrics.get("usage") if response_lyrics else None,
}
return self._normalize_result(result, self.model, usage)
def analyze_lyrics_only(
self,
metadata: Dict[str, Any],
music_url: str,
) -> Optional[Dict[str, Any]]:
"""仅执行歌词识别,不做基础标签分析(ASR异步任务)"""
backend = (
str(
os.getenv("MUSIC_LYRICS_ASR_BACKEND")
or getattr(settings, "MUSIC_LYRICS_ASR_BACKEND", "funasr")
)
.strip()
.lower()
)
if backend == "whisper":
analyze_fn = self._analyze_lyrics_only_whisper
elif backend in {"omni", "qwen-omni", "qwen_omni"}:
# qwen-omni: 单轮流程内最多3次请求,失败后直接降级 funasr
omni_result = self._analyze_lyrics_only_qwen_omni(music_url)
if omni_result:
return omni_result
logger.warning(
"qwen-omni 歌词识别失败,降级到 funasr (lyrics_timeout=%ss)",
self.lyrics_timeout,
)
fallback_retry_count = 1
fallback_retry_delay_seconds = 2.0
for attempt in range(1, fallback_retry_count + 2):
fallback_result = self._analyze_lyrics_only_funasr(music_url)
if fallback_result:
logger.info(
"funasr 降级成功: attempt=%s/%s",
attempt,
fallback_retry_count + 1,
)
return fallback_result
if attempt <= fallback_retry_count:
logger.warning(
"funasr 降级失败,%s 秒后重试 (%s/%s)",
fallback_retry_delay_seconds,
attempt,
fallback_retry_count,
)
time.sleep(fallback_retry_delay_seconds)
logger.warning("funasr 降级失败,继续降级到 whisper")
whisper_result = self._analyze_lyrics_only_whisper(music_url)
if whisper_result:
logger.info("whisper 降级成功")
return whisper_result
logger.error("歌词识别降级链全部失败: qwen-omni -> funasr -> whisper")
return None
elif backend in {"fun", "funasr", "fun-asr"}:
analyze_fn = self._analyze_lyrics_only_funasr
else:
logger.error(
"不支持的歌词识别后端: %s,仅支持 whisper/funasr/qwen-omni",
backend,
)
return None
retry_count = 2
retry_delay_seconds = 2.0
for attempt in range(1, retry_count + 2):
result = analyze_fn(music_url)
if result:
return result
if attempt <= retry_count:
logger.warning(
"歌词识别失败,%s 秒后重试 (%d/%d): backend=%s",
retry_delay_seconds,
attempt,
retry_count,
backend,
)
time.sleep(retry_delay_seconds)
return None
def _analyze_lyrics_only_qwen_omni(self, music_url: str) -> Optional[Dict[str, Any]]:
"""qwen-omni V2 版歌词识别流程"""
client = self._get_client()
logger.info(
"开始 qwen-omni 歌词识别: timeout=%ss, max_retries=%s",
self.lyrics_timeout,
3,
)
lyrics_prompt = build_lyrics_prompt()
messages = self._build_messages(
"请识别这段音频中的歌词内容",
lyrics_prompt,
music_url,
)
response = self._call_with_retry(client, messages, max_retries=3)
if response is None:
return None
parsed = self._parse_response(response.get("content", ""))
payload: Any = parsed
if isinstance(parsed, dict):
payload = (
parsed.get("lyrics")
or parsed.get("lyric")
or parsed.get("歌词")
or parsed
)
lyrics = self._convert_qwen_omni_payload_to_lyrics(payload)
return {
"lyrics": lyrics,
"_model": self.model,
"_token_info": response.get("usage"),
"_transcription_url": None,
"_asr_task_id": None,
"_asr_backend": "qwen-omni",
}
def _convert_qwen_omni_payload_to_lyrics(self, payload: Any) -> List[Dict[str, Any]]:
"""将 qwen-omni 返回的 lyric 结构统一为 [{time, text}]"""
if payload is None:
return []
if isinstance(payload, str):
lines = [line.strip() for line in payload.splitlines() if line.strip()]
return [{"time": None, "text": line} for line in lines]
if isinstance(payload, dict):
candidate = (
payload.get("lyrics")
or payload.get("lines")
or payload.get("歌词")
or payload.get("lyric")
)
return self._convert_qwen_omni_payload_to_lyrics(candidate)
if isinstance(payload, list):
lyrics: List[Dict[str, Any]] = []
for item in payload:
if isinstance(item, str):
line = item.strip()
if line:
lyrics.append({"time": None, "text": line})
continue
if not isinstance(item, dict):
continue
text = item.get("text") or item.get("lyric") or item.get("歌词")
if not isinstance(text, str):
text = str(text) if text is not None else ""
text = text.strip()
if not text:
continue
time_str = item.get("time")
if not isinstance(time_str, str):
time_str = None
lyrics.append({"time": time_str, "text": text})
return lyrics
return []
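    # 归一化示例(示意):
    #   "第一行\n第二行" → [{"time": None, "text": "第一行"}, {"time": None, "text": "第二行"}]
    #   {"lyrics": [{"time": "00:12", "text": "hello"}]} → [{"time": "00:12", "text": "hello"}]
    # 非字符串的 time(如数字)会被置为 None,空文本项直接丢弃。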
def _analyze_lyrics_only_whisper(self, music_url: str) -> Optional[Dict[str, Any]]:
"""whisper-1 版歌词识别流程(91 API)"""
try:
from dotenv import load_dotenv
load_dotenv()
except Exception:
pass
api_key = (os.getenv("API_KEY_whisper") or os.getenv("91API_KEY") or "").strip()
if not api_key:
logger.error("whisper 调用失败: 缺少环境变量 API_KEY_whisper/91API_KEY")
return None
api_url = os.getenv(
"WHISPER_API_URL",
"https://xuedingmao.top/v1/audio/transcriptions",
).strip()
headers = {"Authorization": f"Bearer {api_key}"}
tmp_file_path = None
upload_file_path = None
ext = ".mp3"
try:
tmp_file_path, ext = self._download_audio(music_url, metadata=None)
upload_file_path = tmp_file_path
upload_ext = ext
if ext.lower() == ".flac":
converted_wav = self._convert_audio_to_wav_for_whisper(tmp_file_path)
if converted_wav:
upload_file_path = converted_wav
upload_ext = ".wav"
logger.info("whisper 上传文件已从 flac 转换为 wav")
filename = f"audio{upload_ext}"
print(f"下载完成:{filename}")
content_type = "audio/wav" if upload_ext == ".wav" else "audio/mpeg"
with open(upload_file_path, "rb") as audio_file:
files = {
"file": (filename, audio_file, content_type),
}
data = {
"model": "whisper-1",
"response_format": "verbose_json",
"timestamp_granularities": ["segment"],
"prompt": "没有歌词的片段用...代替,时间戳需要精准与每句歌词进行对应,对于纯音乐直接输出‘纯音乐,禁止输出歌名,作词/作曲等元数据,仅输出歌词与时间戳’",
}
response = requests.post(
api_url,
headers=headers,
data=data,
files=files,
timeout=300,
)
if response.status_code >= 400:
logger.error(
"whisper API 返回错误: status=%s, body=%s",
response.status_code,
response.text,
)
response.raise_for_status()
payload = response.json()
except Exception as exc:
logger.exception("whisper API 调用失败: %s", exc)
return None
finally:
if (
tmp_file_path
and os.path.exists(tmp_file_path)
and not self._is_persisted_music_file(tmp_file_path)
):
try:
os.unlink(tmp_file_path)
except Exception:
pass
if (
upload_file_path
and upload_file_path != tmp_file_path
and os.path.exists(upload_file_path)
):
try:
os.unlink(upload_file_path)
except Exception:
pass
lyrics = self._convert_whisper_payload_to_lyrics(payload)
return {
"lyrics": lyrics,
"_model": "whisper-1",
"_token_info": None,
"_transcription_url": None,
"_asr_task_id": None,
"_asr_backend": "whisper",
}
def _convert_whisper_payload_to_lyrics(
self, payload: Any
) -> List[Dict[str, Any]]:
"""将 whisper 接口响应转换为 lyrics: [{time, text}]"""
if not isinstance(payload, dict):
return []
segments = payload.get("segments")
if isinstance(segments, list):
lyrics: List[Dict[str, Any]] = []
for seg in segments:
if not isinstance(seg, dict):
continue
text = seg.get("text")
if not isinstance(text, str):
continue
text = text.strip()
if not text:
continue
start = seg.get("start")
if not isinstance(start, (int, float)):
# 兼容部分接口返回 begin_time(毫秒)
begin_time = seg.get("begin_time")
if isinstance(begin_time, (int, float)):
start = float(begin_time) / 1000.0
time_str = None
if isinstance(start, (int, float)):
try:
time_str = self._format_asr_time_ms(float(start) * 1000)
except (TypeError, ValueError, OverflowError):
time_str = None
lyrics.append({"time": time_str, "text": text})
if lyrics:
return lyrics
text = payload.get("text")
if isinstance(text, str) and text.strip():
return [{"time": None, "text": text.strip()}]
return []
def _convert_audio_to_wav_for_whisper(self, source_audio_path: str) -> Optional[str]:
"""
将音频转换为 whisper 更稳定支持的 WAV 格式。
"""
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as wav_tmp:
wav_path = wav_tmp.name
cmd = [
"ffmpeg",
"-y",
"-i",
source_audio_path,
"-acodec",
"pcm_s16le",
"-ac",
"1",
"-ar",
"16000",
wav_path,
]
subprocess.run(cmd, check=True, capture_output=True, text=True)
return wav_path
except Exception as exc:
logger.warning("flac 转 wav 失败,将继续使用原文件: %s", exc)
return None
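    # 等价命令行(示意):ffmpeg -y -i source.flac -acodec pcm_s16le -ac 1 -ar 16000 out.wav
    # 即转为 16kHz 单声道 16-bit PCM WAV;转换失败时返回 None,由调用方回退使用原文件。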
def _analyze_lyrics_only_funasr(self, music_url: str) -> Optional[Dict[str, Any]]:
"""fun-asr SDK 版异步 ASR 流程"""
try:
from http import HTTPStatus
import dashscope
from dashscope.audio.asr import Transcription
except Exception as exc:
logger.exception("导入 dashscope.audio.asr.Transcription 失败: %s", exc)
return None
api_key = self._get_dashscope_api_key()
if not api_key:
logger.error("funasr 调用失败: 缺少 DashScope API Key")
return None
asr_model = getattr(settings, "DASHSCOPE_FUNASR_MODEL", "fun-asr")
dashscope.base_http_api_url = getattr(
settings,
"DASHSCOPE_BASE_HTTP_API_URL",
"https://dashscope.aliyuncs.com/api/v1",
)
dashscope.api_key = api_key
poll_interval = float(getattr(settings, "DASHSCOPE_ASR_POLL_INTERVAL", 1.0))
poll_timeout = float(getattr(settings, "DASHSCOPE_ASR_POLL_TIMEOUT", 120.0))
try:
task_resp = Transcription.async_call(
model=asr_model,
file_urls=[music_url],
)
except Exception as exc:
logger.exception("funasr async_call 失败: %s", exc)
return None
task_id = self._extract_task_id_from_asr_response(task_resp)
latest_resp: Any = task_resp
deadline = time.time() + poll_timeout
while time.time() < deadline:
task_status = self._extract_task_status_from_asr_response(latest_resp)
if task_status == "SUCCEEDED":
break
if task_status in {"FAILED", "CANCELED"}:
logger.error(
"funasr 任务失败: task_id=%s, status=%s",
task_id,
task_status,
)
return None
try:
latest_resp = Transcription.fetch(
task=latest_resp,
)
except Exception as exc:
logger.exception("funasr fetch 失败: %s", exc)
return None
time.sleep(poll_interval)
else:
logger.error("funasr 轮询超时: task_id=%s", task_id)
return None
status_code = getattr(latest_resp, "status_code", None)
if status_code is not None and status_code != HTTPStatus.OK:
logger.error(
"funasr 返回非OK状态: task_id=%s, status_code=%s",
task_id,
status_code,
)
return None
transcription_url = self._extract_transcription_url_from_asr_response(latest_resp)
if not transcription_url:
logger.error("funasr 结果缺少 transcription_url: task_id=%s", task_id)
return None
transcript_data = self._fetch_asr_transcription(transcription_url)
if not transcript_data:
return None
lyrics = self._convert_asr_transcription_to_lyrics(transcript_data)
token_info = self._extract_usage_from_asr_response(latest_resp)
return {
"lyrics": lyrics,
"_model": asr_model,
"_token_info": token_info,
"_transcription_url": transcription_url,
"_asr_task_id": task_id,
"_asr_backend": "funasr",
}
def _submit_asr_transcription_task(self, music_url: str) -> Optional[str]:
"""提交 DashScope 异步ASR任务,返回 task_id"""
api_key = self._get_dashscope_api_key()
if not api_key:
logger.error("提交ASR任务失败: 缺少 DashScope API Key")
return None
submit_url = getattr(
settings,
"DASHSCOPE_ASR_SUBMIT_URL",
"https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription",
)
asr_model = getattr(settings, "DASHSCOPE_ASR_MODEL", "qwen3-asr-flash-filetrans")
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"X-DashScope-Async": "enable",
}
payload = {
"model": asr_model,
"input": {"file_url": music_url},
"parameters": {
"channel_id": [0],
"enable_itn": False,
"enable_words": False,
},
}
try:
response = requests.post(
submit_url,
headers=headers,
json=payload,
timeout=self.timeout,
)
response.raise_for_status()
data = response.json()
except Exception as exc:
logger.exception("提交ASR任务异常: %s", exc)
return None
output = data.get("output") if isinstance(data, dict) else None
if not isinstance(output, dict):
logger.error("提交ASR任务失败: 缺少 output 字段")
return None
task_id = output.get("task_id")
if not isinstance(task_id, str) or not task_id.strip():
logger.error("提交ASR任务失败: 缺少 task_id")
return None
return task_id.strip()
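# Illustrative submit response this parser expects (shape follows the
# DashScope async-task convention; exact fields may vary by model/version --
# only output.task_id is actually required here):
#   {"request_id": "...", "output": {"task_id": "xxxx-...", "task_status": "PENDING"}}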
def _poll_asr_task_result(self, task_id: str) -> Optional[Dict[str, Any]]:
"""轮询 DashScope 任务直到结束"""
api_key = self._get_dashscope_api_key()
if not api_key:
logger.error("轮询ASR任务失败: 缺少 DashScope API Key")
return None
task_base_url = getattr(
settings,
"DASHSCOPE_TASK_STATUS_BASE_URL",
"https://dashscope.aliyuncs.com/api/v1/tasks",
).rstrip("/")
task_url = f"{task_base_url}/{task_id}"
headers = {
"Authorization": f"Bearer {api_key}",
"X-DashScope-Async": "enable",
"Content-Type": "application/json",
}
poll_interval = float(getattr(settings, "DASHSCOPE_ASR_POLL_INTERVAL", 1.0))
poll_timeout = float(getattr(settings, "DASHSCOPE_ASR_POLL_TIMEOUT", 120.0))
deadline = time.time() + poll_timeout
while time.time() < deadline:
try:
response = requests.get(task_url, headers=headers, timeout=self.timeout)
response.raise_for_status()
data = response.json()
except Exception as exc:
logger.exception("轮询ASR任务异常: task_id=%s, error=%s", task_id, exc)
return None
output = data.get("output") if isinstance(data, dict) else None
task_status = output.get("task_status") if isinstance(output, dict) else None
if task_status == "SUCCEEDED":
return data
if task_status in {"FAILED", "CANCELED"}:
logger.error(
"ASR任务失败: task_id=%s, status=%s, data=%s",
task_id,
task_status,
data,
)
return None
time.sleep(poll_interval)
logger.error("轮询ASR任务超时: task_id=%s", task_id)
return None
def _fetch_asr_transcription(self, transcription_url: str) -> Optional[Dict[str, Any]]:
"""下载 transcription_url 对应的转写结果JSON"""
try:
response = requests.get(transcription_url, timeout=self.timeout)
response.raise_for_status()
data = response.json()
return data if isinstance(data, dict) else None
except Exception as exc:
logger.exception("下载ASR转写结果失败: %s", exc)
return None
def _convert_asr_transcription_to_lyrics(
self, transcript_data: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""将ASR结果转换为 lyrics: [{time, text}]"""
transcripts = transcript_data.get("transcripts")
if not isinstance(transcripts, list):
return []
lyrics: List[Dict[str, Any]] = []
for transcript in transcripts:
if not isinstance(transcript, dict):
continue
sentences = transcript.get("sentences")
if not isinstance(sentences, list):
continue
for sentence in sentences:
if not isinstance(sentence, dict):
continue
text = sentence.get("text")
if not isinstance(text, str):
continue
text = text.strip()
if not text:
continue
begin_time = sentence.get("begin_time")
time_str = (
self._format_asr_time_ms(begin_time)
if isinstance(begin_time, (int, float))
else None
)
lyrics.append(
{
"time": time_str,
"text": text,
}
)
return lyrics
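# Illustrative input/output for the converter above (payload shape as assumed
# by this parser):
#   transcript_data = {"transcripts": [{"sentences": [
#       {"begin_time": 1230, "text": "第一句"},
#       {"begin_time": 5600, "text": "第二句"}]}]}
#   -> [{"time": "00:01.230", "text": "第一句"}, {"time": "00:05.600", "text": "第二句"}]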
@staticmethod
def _format_asr_time_ms(ms_value: float) -> str:
"""毫秒转 mm:ss.xxx"""
total_ms = int(max(0, ms_value))
minutes = total_ms // 60000
seconds = (total_ms % 60000) // 1000
milliseconds = total_ms % 1000
return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
def _get_dashscope_api_key(self) -> Optional[str]:
"""获取 DashScope API Key(ASR专用)"""
return (
self.api_key
or settings.QWEN_DASHSCOPE_API_KEY
or settings.QWEN_API_KEY
or os.getenv("DASHSCOPE_API_KEY")
or os.getenv("QWEN_DASHSCOPE_API_KEY")
or os.getenv("QWEN_API_KEY")
)
@staticmethod
def _as_dict(response_obj: Any) -> Dict[str, Any]:
"""尽可能将 SDK 响应对象转换为 dict"""
if isinstance(response_obj, dict):
return response_obj
if response_obj is None:
return {}
for attr in ("to_dict", "as_dict", "dict"):
fn = getattr(response_obj, attr, None)
if callable(fn):
try:
value = fn()
if isinstance(value, dict):
return value
except Exception:
pass
data: Dict[str, Any] = {}
for key in ("request_id", "output", "usage"):
val = getattr(response_obj, key, None)
if val is not None:
if key in ("output", "usage") and not isinstance(val, dict):
nested = QwenAnalyzer._as_dict(val)
data[key] = nested if nested else val
else:
data[key] = val
return data
def _extract_task_id_from_asr_response(self, response_obj: Any) -> Optional[str]:
data = self._as_dict(response_obj)
output = data.get("output")
if isinstance(output, dict):
task_id = output.get("task_id")
if isinstance(task_id, str) and task_id.strip():
return task_id.strip()
return None
def _extract_task_status_from_asr_response(self, response_obj: Any) -> Optional[str]:
data = self._as_dict(response_obj)
output = data.get("output")
if isinstance(output, dict):
task_status = output.get("task_status")
if isinstance(task_status, str):
return task_status
return None
def _extract_transcription_url_from_asr_response(
self, response_obj: Any
) -> Optional[str]:
data = self._as_dict(response_obj)
output = data.get("output")
if not isinstance(output, dict):
return None
# 兼容 output.results: [{transcription_url: ...}]
results = output.get("results")
if isinstance(results, list) and results:
first = results[0]
if isinstance(first, dict):
transcription_url = first.get("transcription_url")
if isinstance(transcription_url, str) and transcription_url.strip():
return transcription_url.strip()
result = output.get("result")
if not isinstance(result, dict):
# 兜底兼容 output.transcription_url
transcription_url = output.get("transcription_url")
if isinstance(transcription_url, str) and transcription_url.strip():
return transcription_url.strip()
return None
transcription_url = result.get("transcription_url")
if isinstance(transcription_url, str) and transcription_url.strip():
return transcription_url.strip()
return None
def _extract_usage_from_asr_response(
self, response_obj: Any
) -> Optional[Dict[str, Any]]:
data = self._as_dict(response_obj)
usage = data.get("usage")
return usage if isinstance(usage, dict) else None
def _build_messages(
self,
system_prompt: str,
user_prompt: str,
music_url: str,
) -> list:
"""构建消息格式"""
messages = []
# 添加系统提示词
if system_prompt:
messages.append(
{
"role": "system",
"content": system_prompt,
}
)
# 添加用户消息(包含音频和文本)
messages.append(
{
"role": "user",
"content": [
{
"type": "input_audio",
"input_audio": {"data": music_url, "format": "mp3"},
},
{"type": "text", "text": user_prompt},
],
}
)
return messages
def _build_dashscope_prompt(self, system_prompt: str, user_prompt: str) -> str:
"""构建 DashScope 调用的文本提示词"""
if system_prompt and system_prompt.strip():
return f"{system_prompt.strip()}\n\n{user_prompt}".strip()
return user_prompt.strip()
def _timed_call_openai(
self, client, messages: list
) -> tuple[Optional[Dict], float]:
"""为 OpenAI 兼容调用提供耗时统计"""
call_start = time.time()
resp = self._call_with_retry(client, messages)
return resp, round(time.time() - call_start, 2)
def _call_with_retry_dashscope(
self, music_url: str, prompt: str, timeout: Optional[float] = None, song_id: str = "", metadata: Optional[Dict[str, Any]] = None
) -> Optional[Dict]:
"""使用 DashScope SDK 进行多模态调用(带重试,自动降级到 base64)"""
import dashscope
dashscope_key = (
self.api_key
or settings.QWEN_DASHSCOPE_API_KEY
or os.getenv("QWEN_OMNI_API_KEY")
or os.getenv("DASHSCOPE_API_KEY")
)
if not dashscope_key:
print(" ⚠ 未配置 DashScope API Key(QWEN_API_KEY / QWEN_DASHSCOPE_API_KEY / QWEN_OMNI_API_KEY / DASHSCOPE_API_KEY 均为空),请先配置")
return None
messages = [
{
"role": "user",
"content": [
{"audio": music_url},
{"text": prompt},
],
}
]
timeout = timeout or self.timeout
for attempt in range(1, self.max_retries + 1):
try:
print(
f" [{self.model}] 正在分析 (DashScope 尝试 {attempt}/{self.max_retries}, timeout={timeout}s)..."
)
response = self._dashscope_call_with_hard_timeout(
dashscope=dashscope,
api_key=dashscope_key,
model=self.model,
messages=messages,
timeout=timeout,
)
if response.status_code != 200:
error_msg = getattr(response, "message", "")
error_code = getattr(response, "code", "")
error_output = getattr(response, "output", {})
print(
f" ✗ [{self.model}] API 调用失败,状态码: {response.status_code}"
)
if song_id:
print(f" 歌曲ID: {song_id}")
if error_code:
print(f" 错误代码: {error_code}")
if error_msg:
print(f" 错误信息: {error_msg}")
if error_output:
print(f" 响应内容: {error_output}")
# 检测文件过大错误,自动降级到 OSS 方式
if "file size is too large" in str(error_msg).lower() or "file size is too large" in str(error_output).lower():
print(f" [Qwen] 检测到文件过大,自动降级到 OSS 方式...")
try:
temp_audio_path = self._download_audio_temp(music_url)
if temp_audio_path:
mono_path = self._convert_to_mono(temp_audio_path)
oss_url = self._upload_audio_to_oss(mono_path)
# 只删除转换后的单声道文件,保留原始下载文件
self._cleanup_temp_audio(mono_path)
if oss_url:
print(f" [Qwen] 使用 OSS URL 重新请求: {oss_url[:60]}...")
return self._call_with_retry_dashscope(oss_url, prompt, timeout=timeout, song_id=song_id, metadata=metadata)
except Exception as e:
print(f" [Qwen] OSS 降级失败: {e}")
return None
if attempt < self.max_retries:
time.sleep(attempt)
continue
return None
content = response.output.choices[0].message.content
if isinstance(content, list):
# The SDK may split the answer into several content parts; concatenate
# every text part instead of reading only the first one.
result_text = "".join(
part.get("text") or ""
for part in content
if isinstance(part, dict)
)
else:
result_text = content
usage = None
resp_usage = getattr(response, "usage", None)
if isinstance(resp_usage, dict):
input_tokens = resp_usage.get(
"input_tokens", resp_usage.get("prompt_tokens", 0)
)
output_tokens = resp_usage.get(
"output_tokens", resp_usage.get("completion_tokens", 0)
)
total_tokens = resp_usage.get("total_tokens")
usage = {
"prompt_tokens": input_tokens or 0,
"completion_tokens": output_tokens or 0,
"total_tokens": total_tokens
if total_tokens is not None
else (input_tokens or 0) + (output_tokens or 0),
}
elif resp_usage is not None:
input_tokens = getattr(resp_usage, "input_tokens", None)
output_tokens = getattr(resp_usage, "output_tokens", None)
total_tokens = getattr(resp_usage, "total_tokens", None)
usage = {
"prompt_tokens": input_tokens or 0,
"completion_tokens": output_tokens or 0,
"total_tokens": total_tokens
if total_tokens is not None
else (input_tokens or 0) + (output_tokens or 0),
}
return {"content": result_text, "usage": usage}
except TimeoutError:
print(f" ✗ [{self.model}] API 调用超时 (尝试 {attempt}/{self.max_retries})")
if attempt < self.max_retries:
time.sleep(attempt)
continue
return None
except Exception as e:
print(f" ✗ [{self.model}] API 调用异常: {e}")
if attempt < self.max_retries:
time.sleep(attempt)
continue
return None
return None
def _download_audio_temp(self, music_url: str) -> Optional[str]:
"""
临时下载音频文件到系统临时目录
Args:
music_url: 音频URL
Returns:
临时文件路径,如果下载失败返回 None
"""
try:
# 确定文件扩展名
ext = ".mp3"
if "." in music_url:
url_ext = music_url.split(".")[-1].split("?")[0].lower()
if url_ext in ["mp3", "wav", "flac", "aac", "m4a", "ogg"]:
ext = f".{url_ext}"
# 下载到系统临时目录
temp_dir = tempfile.gettempdir()
url_hash = hashlib.md5(music_url.encode("utf-8")).hexdigest()[:12]
temp_path = os.path.join(temp_dir, f"qwen_audio_{url_hash}{ext}")
if not os.path.exists(temp_path):
response = requests.get(music_url, timeout=60)
response.raise_for_status()
with open(temp_path, "wb") as f:
f.write(response.content)
print(f" [Qwen] 临时音频已下载: {temp_path}")
else:
print(f" [Qwen] 使用缓存的临时音频")
return temp_path
except Exception as e:
print(f" [Qwen] 临时音频下载失败: {e}")
return None
def _convert_to_mono(self, audio_path: str) -> str:
"""
将音频转换为单声道
Args:
audio_path: 原始音频文件路径
Returns:
转换后的音频文件路径
"""
timestamp = int(time.time() * 1000)
base_name = os.path.basename(audio_path)
name_parts = base_name.rsplit(".", 1)
if len(name_parts) == 2:
mono_path = os.path.join(
os.path.dirname(audio_path),
f"{name_parts[0]}_mono_{timestamp}.{name_parts[1]}"
)
else:
mono_path = f"{audio_path}_mono_{timestamp}"
try:
cmd = [
"ffmpeg",
"-i", audio_path,
"-ac", "1", # 转为单声道
"-y",
mono_path
]
print(f" [Qwen] 转换为单声道: ffmpeg -i ... -ac 1")
subprocess.run(cmd, capture_output=True, timeout=60, check=True)
original_size = os.path.getsize(audio_path)
mono_size = os.path.getsize(mono_path)
ratio = (1 - mono_size / original_size) * 100
print(f" [Qwen] 音频已转换: {original_size/1024/1024:.1f}MB -> {mono_size/1024/1024:.1f}MB (压缩率: {ratio:.1f}%)")
return mono_path
except Exception as e:
print(f" [Qwen] 音频转换失败: {e},将使用原文件")
return audio_path
def _upload_audio_to_oss(self, audio_path: str) -> Optional[str]:
"""
将音频文件上传到 OSS
Args:
audio_path: 音频文件路径
Returns:
OSS 文件 URL,如果上传失败返回 None
"""
try:
from app.utils.oss_uploader import oss_uploader
success, result = oss_uploader.upload_file(audio_path)
if not success:
print(f" [Qwen] 音频上传到 OSS 失败: {result}")
return None
oss_url = result
print(f" [Qwen] 音频已上传到 OSS: {oss_url}")
return oss_url
except Exception as e:
print(f" [Qwen] 音频上传到 OSS 失败: {e}")
return None
def _cleanup_temp_audio(self, temp_path: str) -> None:
"""清理临时音频文件"""
if temp_path and os.path.exists(temp_path):
try:
os.unlink(temp_path)
print(f" [Qwen] 已清理临时音频文件")
except OSError:
pass
def _dashscope_call_with_hard_timeout(
self,
dashscope,
api_key: str,
model: str,
messages: list,
timeout: float,
):
"""
DashScope SDK 某些版本下 request_timeout 可能无法稳定生效。
这里增加线程级硬超时,避免单次调用无限阻塞。
"""
box: Dict[str, Any] = {}
done = threading.Event()
def _target() -> None:
try:
box["response"] = dashscope.MultiModalConversation.call(
api_key=api_key,
model=model,
messages=messages,
request_timeout=timeout,
)
except Exception as exc:
box["error"] = exc
finally:
done.set()
worker = threading.Thread(target=_target, daemon=True)
worker.start()
hard_timeout = max(float(timeout), 1.0) + 3.0
if not done.wait(hard_timeout):
raise TimeoutError(f"DashScope hard timeout after {hard_timeout:.1f}s")
if "error" in box:
raise box["error"]
return box.get("response")
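# Design note: the worker is a daemon thread, so a timed-out DashScope call is
# simply abandoned (it may still finish in the background) and its result
# discarded rather than cancelled. That trades a little wasted work for a
# guaranteed upper bound on caller latency.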
def _call_with_retry(
self,
client,
messages: list,
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
) -> Optional[Dict]:
"""带重试的 API 调用(非流式)"""
timeout = timeout or self.lyrics_timeout
retries = max_retries or self.max_retries
for attempt in range(1, retries + 1):
try:
print(
f" [Qwen] 调用模型 (尝试 {attempt}/{retries}, timeout={timeout}s)..."
)
response = client.chat.completions.create(
model=self.model,
messages=messages,
modalities=["text"],
stream=False,
timeout=timeout,
extra_body={"enable_thinking": False},
)
content = (
response.choices[0].message.content if response.choices else ""
)
usage = {
"prompt_tokens": response.usage.prompt_tokens
if response.usage
else 0,
"completion_tokens": response.usage.completion_tokens
if response.usage
else 0,
"total_tokens": response.usage.total_tokens
if response.usage
else 0,
}
print(f" [Qwen] 响应: {content[:100]}...")
return {"content": content, "usage": usage}
except Exception as e:
error_type = type(e).__name__
print(f" [Qwen] 错误 ({error_type}): {e}")
if attempt < retries:
wait_time = attempt
print(f" 等待 {wait_time} 秒后重试...")
time.sleep(wait_time)
else:
print(f" 已达到最大重试次数")
return None
return None
# 聚音标签识别助手 - 系统角色定义
## 角色定位
你是音乐内容标签标注助手。
你的任务是基于输入的歌曲信息(如歌词、标题、风格描述、音频特征等),严格按照「聚音标签字典」输出标准化标签字段。
只输出标签结果,不做解释,不做分析,不添加任何多余文本。
------
## 输出格式
仅输出 JSON 纯文本,结构如下:
{
"performer_type": "",
"language": "",
"emotion": [],
"douyin_tags": [],
"music_style_tags": [],
"instrument_tags": [],
"scene": []
}
禁止输出任何解释性文字、注释或额外字段。
------
## 全局约束规则
1. 所有标签必须严格从下方字典中选择,禁止自造词。
2. 不允许基于刻板印象猜测(如仅凭曲风推断情绪)。
3. 标签必须基于明确特征:
- 歌词内容
- 音乐风格特征
- 明确出现的配器
- 明确使用场景
4. 多选字段仅选择高度确定且核心表达的标签,避免过度打标。
5. 注意!除各字段判定标准中明确允许输出空值("" 或 [])的情况外,每个字段至少选择一个标签。
------
# 字段判定标准说明
## 一、演唱者类型 performer_type(单选)
用于标注主要人声类型,仅根据实际听感或明确描述判断:
- 男声:主要为男性声线
- 女声:主要为女性声线
- 童声:明显儿童声线
- 合唱:多人群体演唱为主(非简单和声)
不确定时输出 ""。
------
## 二、情绪 emotion(多选)
必须基于歌曲整体情绪表达判断,而非个别词语。
- 喜庆:节日、庆典氛围明显
- 浪漫:爱情氛围浓厚
- 雄壮:宏大、史诗、气势恢宏
- 蛊惑:迷幻、魅惑、暧昧
- 宣泄:情绪爆发、释放
- 悲壮:悲情但具有力量感
- 愤怒:强烈对抗或激烈表达
- 庄重:正式、肃穆
- 激情:热烈高昂
- 沉重:压抑、厚重
- 快乐:轻松开心
- 励志:奋斗、成长、自我激励
- 思念:想念某人或过往
- 紧张:悬而未决、焦虑
- 恐怖:惊悚氛围
- 感动:温情催泪
- 恶搞:刻意夸张调侃
- 搞笑:明显幽默表达
- 期待:盼望未来
- 怀念:回忆过去
- 甜蜜:恋爱甜感
- 孤独:孤单、自我独白
- 伤感:悲伤低落
- 悬疑:神秘未知感
- 祝福:祝愿表达
- 佛系:平淡随性
- 舒缓:节奏慢、平稳
- 悠扬:旋律流畅优美
- 温暖:柔和治愈
- 忧郁:带有阴郁气质
避免:
- 同时选择强烈对立情绪(如 快乐 与 伤感)
- 同类标签堆叠(如 伤感 + 忧郁 + 孤独 需明确区分)
------
## 三、语种 language(单选)
仅从下列标签中选择一个最主要的演唱语种:
- 普通话
- 粤语
- 藏语
- 英语
- 韩语
- 闽南语
- 蒙语
- 俄语
- 其他
规则:
- 只输出一个语种标签
- 依据实际演唱语言判断,不根据歌手国籍或曲风猜测
- 纯音乐或无法判断时输出 ""
------
## 四、网络/抖音歌曲 douyin_tags(可多选)
仅当歌曲具备明显网络传播特征或主题风格时选择:
- 草原:草原文化、民族草原元素
- 故乡:思乡主题
- 神曲:洗脑旋律、强节奏重复
- 文艺:小众表达、诗性表达
- 青春:校园或成长主题
- 治愈系:温暖疗愈风格
- 清新:轻快自然风格
- 奇幻:幻想、魔幻元素
非明显网络属性不要强行标注。
------
## 五、音乐风格 music_style_tags(多选)
必须根据音乐结构与风格特征判断,不根据歌词主题判断。
- 世界音乐
- 雷鬼
- R&B/Soul
- MC喊麦
- 另类音乐
- 民歌
- 戏曲
- 古风
- 古典音乐
- HipHop
- Rap
- 摇滚
- DJ嗨曲
- 布鲁斯/蓝调
- 拉丁
- 舞曲
- 爵士
- 乡村
- 民谣
- 流行
- 轻音乐
- 国风
- 儿歌
规则:
- 只选核心风格,不叠加相似风格
- 不因使用某个乐器就推断整体风格
- 无明显风格时可只选“流行”
------
## 六、配器 instrument_tags(多选)
仅在明确可识别时选择:
- 二胡
- 竹笛
- 琵琶
- 音效
- 口琴
- 电子
- 木吉他
- 鼓组
- 弦乐
- 电吉他
- 古筝
- 钢琴
规则:
- 必须为明显主导或突出配器
- 不因常规伴奏默认存在而标注
- 不确定不要猜
------
## 七、场景 scene(多选)
根据歌曲使用场景或明显氛围判断:
- 餐厅
- 汽车
- 跳舞
- 旅行
- 工作
- 校园
- 夜店
- 运动
- 休闲
- live house
- 广场舞
- 抖音
- 婚礼
- 约会
规则:
- 仅当歌曲明显适配该场景时标注
- 避免泛化场景(如所有慢歌都标“休闲”)
------
## 最终执行要求
- 只输出 JSON
- 不解释
- 不补充说明
- 不输出字典内容
- 不输出“分析如下”之类文字
- 不添加未定义字段
严格遵守字段范围与空值规则。
## 输出格式
必须严格输出以下 JSON 结构,字段名不能改:
```json
{
"performer_type": "",
"language": "",
"emotion": [],
"douyin_tags": [],
"music_style_tags": [],
"instrument_tags": [],
"scene": []
}
```
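示例(仅为格式示意,标签取值为假设,实际须以音频内容为准):
```json
{
"performer_type": "女声",
"language": "普通话",
"emotion": ["浪漫", "甜蜜"],
"douyin_tags": [],
"music_style_tags": ["流行"],
"instrument_tags": ["钢琴"],
"scene": ["约会"]
}
```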
# 待分析元数据
{{METADATA_SECTION}}
# 任务目标
请基于音频内容完成聚音标签识别,仅输出系统要求的标签字段。
# 约束提醒
- 必须基于实际听到的特征,无法确认的标签输出空值。
- 严格执行 JSON 纯文本输出,禁止任何 Markdown 格式。
# 歌词识别提示词模板
# 仅识别歌词内容,不包含其他音乐分析
请识别并转录音频中的完整歌词。
## 核心任务
1. **逐句识别**:按时间顺序输出每一句歌词,每句通过换行进行分隔。
2. **字段要求**:每条记录必须包含 `time` (格式 "mm:ss.xxx",无法确定则为 null) 和 `text` (歌词内容)。
3. **无语义音节压缩**:对于“啊/呜/哦/嗯/啦”等辅助音节,禁止逐字展示,统一使用 `...` 缩略(例:把“啊啊啊啊”识别为“啊...”)。
4. **完整性**:完整识别所有段落的歌词并转录全曲,包括不同段落之间重复出现的部分。
5. **静默与纯音乐**:若为纯音乐或无歌词,仅返回空数组 `[]`。
## 输出格式规范
- 严格输出 JSON,不得包含任何 Markdown 转义符(如 ```json)或解释性文字。
- 字段统一为: {"lyrics": [{"time": "00:00.000", "text": "内容"}]}
## 质量控制
- 遇到合唱/重叠时,以主旋律为主。
- 严禁自行脑补不存在的歌词。
- 不要返回任何其他无关内容
"""
阿里云OSS服务
提供文件上传、下载、删除等功能
"""
import oss2
from typing import Optional, Dict, Any, BinaryIO
from datetime import datetime, timedelta
from pathlib import Path
import hashlib
import mimetypes
import logging
import time
import json
from functools import wraps
from aliyunsdkcore.client import AcsClient
from aliyunsdksts.request.v20150401 import AssumeRoleRequest
from app.core.config import settings
from app.core.exceptions import (
BusinessException,
ValidationException,
ExternalServiceException,
NotFoundException
)
logger = logging.getLogger(__name__)
def retry_on_connection_error(max_retries: int = 3, delay: float = 1.0):
"""
重试装饰器 - 在连接错误时重试
Args:
max_retries: 最大重试次数
delay: 重试延迟(秒)
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
last_exception = None
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except Exception as e:
# 只重试连接相关的错误
error_str = str(e)
if ('Connection' in error_str or
'timeout' in error_str.lower() or
'closed' in error_str.lower() or
'RemoteDisconnected' in error_str):
last_exception = e
if attempt < max_retries - 1:
wait_time = delay * (2 ** attempt) # 指数退避
logger.warning(
f"[{func.__name__}] 连接错误,{wait_time}秒后重试 "
f"({attempt + 1}/{max_retries}): {error_str}"
)
time.sleep(wait_time)
continue
else:
logger.error(f"[{func.__name__}] 重试次数已用尽: {error_str}")
raise
else:
# 非连接错误直接抛出
raise
if last_exception:
raise last_exception
return wrapper
return decorator
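# Illustrative backoff: with max_retries=3 and delay=1.0, a connection-style
# error is retried after 1s and then 2s (delay * 2**attempt) before the third
# and final attempt fails for good; non-connection errors re-raise immediately.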
class OSSService:
"""阿里云OSS服务类"""
def __init__(self):
"""初始化OSS客户端"""
if not all([
settings.OSS_ACCESS_KEY_ID,
settings.OSS_ACCESS_KEY_SECRET,
settings.OSS_ENDPOINT,
settings.OSS_BUCKET_NAME
]):
raise BusinessException("OSS配置不完整,请检查环境变量")
# 创建认证对象
auth = oss2.Auth(
settings.OSS_ACCESS_KEY_ID,
settings.OSS_ACCESS_KEY_SECRET
)
# 确定Endpoint
endpoint = settings.OSS_ENDPOINT
if settings.OSS_INTERNAL_ENDPOINT and settings.OSS_REGION:
endpoint = f"oss-{settings.OSS_REGION}-internal.aliyuncs.com"
# 创建Bucket对象
self.bucket = oss2.Bucket(
auth,
endpoint,
settings.OSS_BUCKET_NAME
)
# 设置超时参数(在 bucket 对象上设置)
self.bucket.timeout = (
settings.OSS_CONNECTION_TIMEOUT,
settings.OSS_READ_TIMEOUT
)
# 配置重试参数
self.max_retries = settings.OSS_REQUEST_RETRY_TIMES
self.retry_delay = settings.OSS_RETRY_DELAY
logger.info(
f"OSS服务已初始化: endpoint={endpoint}, "
f"timeout={settings.OSS_CONNECTION_TIMEOUT}/{settings.OSS_READ_TIMEOUT}s, "
f"retries={self.max_retries}"
)
@staticmethod
def _validate_file_extension(filename: str) -> bool:
"""验证文件扩展名"""
ext = Path(filename).suffix.lower()
if not ext:
return False
return ext in [e.lower() for e in settings.OSS_ALLOWED_EXTENSIONS]
@staticmethod
def _validate_file_size(file_size: int) -> bool:
"""验证文件大小"""
return 0 < file_size <= settings.OSS_MAX_FILE_SIZE
@staticmethod
def _generate_object_key(
filename: str,
user_id: int,
prefix: Optional[str] = None
) -> str:
"""
生成OSS对象键(路径)
格式: {prefix}/{user_id}/{date}/{hash}_{filename}
"""
# 获取文件扩展名
ext = Path(filename).suffix.lower()
name = Path(filename).stem
# 生成时间戳目录
now = datetime.now()
date_path = now.strftime("%Y%m%d")
# 生成唯一标识(时间戳 + 4位随机数)
timestamp = now.strftime("%H%M%S%f")
unique_id = hashlib.md5(f"{user_id}{timestamp}{name}".encode()).hexdigest()[:8]
# 组合文件名
new_filename = f"{unique_id}_{name}{ext}"
# 确定前缀
base_prefix = prefix or settings.OSS_UPLOAD_PATH_PREFIX
# 组合完整路径(加上全局前缀)
if settings.OSS_GLOBAL_PREFIX:
return f"{settings.OSS_GLOBAL_PREFIX}/{base_prefix}/{user_id}/{date_path}/{new_filename}"
return f"{base_prefix}/{user_id}/{date_path}/{new_filename}"
@staticmethod
def _generate_object_key_simple(
filename: str,
entity_type: str
) -> str:
"""
生成OSS对象键(路径)- 简化版本,用于角色/场景图片转存
格式: {entity_type}/{date}/{filename}
"""
# 获取文件扩展名
ext = Path(filename).suffix.lower()
name = Path(filename).stem
# 生成时间戳目录
now = datetime.now()
date_path = now.strftime("%Y%m%d")
# 生成唯一标识(时间戳 + 随机数)
timestamp = now.strftime("%H%M%S%f")
unique_id = hashlib.md5(f"{entity_type}{timestamp}{name}".encode()).hexdigest()[:8]
# 组合文件名
new_filename = f"{unique_id}{ext}"
# 组合完整路径:类型/日期/文件名.后缀
if settings.OSS_GLOBAL_PREFIX:
return f"{settings.OSS_GLOBAL_PREFIX}/{entity_type}/{date_path}/{new_filename}"
return f"{entity_type}/{date_path}/{new_filename}"
@staticmethod
def _get_content_type(filename: str) -> str:
"""获取文件MIME类型"""
content_type, _ = mimetypes.guess_type(filename)
return content_type or 'application/octet-stream'
@staticmethod
def _get_extension_from_mime(mime_type: str) -> str:
"""
根据MIME类型获取文件扩展名
参数:
mime_type: MIME类型(如 image/jpeg, video/mp4)
返回:
文件扩展名(包含点号,如 .jpg, .mp4)
"""
# MIME类型到扩展名的映射
mime_to_ext = {
# 图片
'image/jpeg': '.jpg',
'image/jpg': '.jpg',
'image/png': '.png',
'image/gif': '.gif',
'image/webp': '.webp',
'image/svg+xml': '.svg',
'image/bmp': '.bmp',
'image/x-icon': '.ico',
# 视频
'video/mp4': '.mp4',
'video/mpeg': '.mpeg',
'video/webm': '.webm',
'video/quicktime': '.mov',
'video/x-msvideo': '.avi',
'video/x-matroska': '.mkv',
# 音频
'audio/mpeg': '.mp3',
'audio/wav': '.wav',
'audio/ogg': '.ogg',
'audio/webm': '.weba',
# 其他
'application/pdf': '.pdf',
'text/plain': '.txt',
'application/json': '.json',
}
# 处理带参数的 MIME 类型(如 image/jpeg; charset=utf-8)
base_mime = mime_type.split(';')[0].strip().lower()
return mime_to_ext.get(base_mime, '')
def upload_file(
self,
file_data: BinaryIO,
filename: str,
user_id: int,
prefix: Optional[str] = None,
validate_extension: bool = True
) -> Dict[str, Any]:
"""
直接上传文件到OSS(适合小文件)
参数:
file_data: 文件二进制流
filename: 原始文件名
user_id: 用户ID
prefix: 路径前缀(可选)
validate_extension: 是否验证文件扩展名
返回:
上传结果信息
"""
# 使用重试装饰器包装实际上传操作
@retry_on_connection_error(
max_retries=self.max_retries,
delay=self.retry_delay
)
def _do_upload():
# 验证文件扩展名
if validate_extension and not self._validate_file_extension(filename):
raise ValidationException(
f"不支持的文件类型,允许的类型: {', '.join(settings.OSS_ALLOWED_EXTENSIONS)}"
)
# 生成OSS对象键
object_key = self._generate_object_key(filename, user_id, prefix)
# 获取Content-Type
content_type = self._get_content_type(filename)
# 上传文件
result = self.bucket.put_object(
object_key,
file_data,
headers={'Content-Type': content_type}
)
# 构建文件URL
file_url = self._build_file_url(object_key)
return {
"object_key": object_key,
"filename": filename,
"url": file_url,
"content_type": content_type,
"etag": result.etag,
"uploaded_at": datetime.now().isoformat()
}
try:
return _do_upload()
except oss2.exceptions.OssError as e:
logger.error(f"OSS上传失败: {e.message}")
raise ExternalServiceException(f"OSS上传失败: {e.message}")
except ValidationException:
raise
except Exception as e:
logger.error(f"文件上传失败: {str(e)}", exc_info=True)
raise BusinessException(f"文件上传失败: {str(e)}")
def upload_file_with_size(
self,
file_data: BinaryIO,
filename: str,
file_size: int,
user_id: int,
prefix: Optional[str] = None
) -> Dict[str, Any]:
"""
根据文件大小智能选择上传方式(带文件大小验证)
参数:
file_data: 文件二进制流
filename: 原始文件名
file_size: 文件大小(字节)
user_id: 用户ID
prefix: 路径前缀(可选)
返回:
上传结果信息
"""
# 验证文件大小
if not self._validate_file_size(file_size):
raise ValidationException(
f"文件大小超出限制,最大允许 {settings.OSS_MAX_FILE_SIZE / 1024 / 1024:.2f}MB"
)
# 根据文件大小选择上传方式
if file_size > settings.OSS_MULTIPART_THRESHOLD:
return self.multipart_upload(file_data, filename, file_size, user_id, prefix)
else:
result = self.upload_file(file_data, filename, user_id, prefix)
result["size"] = file_size
return result
def multipart_upload(
self,
file_data: BinaryIO,
filename: str,
file_size: int,
user_id: int,
prefix: Optional[str] = None,
part_size: Optional[int] = None,
) -> Dict[str, Any]:
"""
分片上传文件到OSS(适合大文件)
参数:
file_data: 文件二进制流
filename: 原始文件名
file_size: 文件大小(字节)
user_id: 用户ID
prefix: 路径前缀(可选)
返回:
上传结果信息
"""
upload_id = None
object_key = None
effective_part_size = part_size or settings.OSS_PART_SIZE
@retry_on_connection_error(
max_retries=self.max_retries,
delay=self.retry_delay
)
def _do_multipart_upload():
nonlocal upload_id, object_key
# 验证文件扩展名
if not self._validate_file_extension(filename):
raise ValidationException(
f"不支持的文件类型,允许的类型: {', '.join(settings.OSS_ALLOWED_EXTENSIONS)}"
)
# 生成OSS对象键
object_key = self._generate_object_key(filename, user_id, prefix)
# 获取Content-Type
content_type = self._get_content_type(filename)
# 初始化分片上传
upload_id = self.bucket.init_multipart_upload(
object_key,
headers={'Content-Type': content_type}
).upload_id
logger.info(f"[multipart_upload] 开始分片上传: {object_key}, upload_id={upload_id}")
# 计算分片数量
part_size = effective_part_size
part_count = (file_size + part_size - 1) // part_size
# 上传所有分片
parts = []
for part_number in range(1, part_count + 1):
# 读取分片数据
offset = (part_number - 1) * part_size
size = min(part_size, file_size - offset)
file_data.seek(offset)
part_data = file_data.read(size)
# 上传分片(每个分片也使用重试)
@retry_on_connection_error(
max_retries=self.max_retries,
delay=self.retry_delay
)
def _upload_part():
result = self.bucket.upload_part(
object_key,
upload_id,
part_number,
part_data
)
logger.debug(f"分片 {part_number}/{part_count} 上传成功")
return result
result = _upload_part()
parts.append(oss2.models.PartInfo(part_number, result.etag))
# 完成分片上传
@retry_on_connection_error(
max_retries=self.max_retries,
delay=self.retry_delay
)
def _complete_upload():
return self.bucket.complete_multipart_upload(
object_key,
upload_id,
parts
)
result = _complete_upload()
# 构建文件URL
file_url = self._build_file_url(object_key)
return {
"object_key": object_key,
"filename": filename,
"url": file_url,
"size": file_size,
"content_type": content_type,
"etag": result.etag,
"upload_id": upload_id,
"part_count": part_count,
"uploaded_at": datetime.now().isoformat()
}
try:
return _do_multipart_upload()
except oss2.exceptions.OssError as e:
logger.error(f"OSS分片上传失败: {e.message}")
# 上传失败,尝试取消分片上传
if upload_id and object_key:
try:
self.bucket.abort_multipart_upload(
object_key,
upload_id
)
logger.info(f"已取消分片上传: {object_key}, upload_id={upload_id}")
except Exception as cancel_error:
logger.warning(f"取消分片上传失败: {cancel_error}")
raise ExternalServiceException(f"OSS分片上传失败: {e.message}")
except (ValidationException, ExternalServiceException):
raise
except Exception as e:
logger.error(f"分片上传失败: {str(e)}", exc_info=True)
# 上传失败,尝试取消分片上传
if upload_id and object_key:
try:
self.bucket.abort_multipart_upload(
object_key,
upload_id
)
except Exception:
pass
raise BusinessException(f"分片上传失败: {str(e)}")
def generate_presigned_url(
self,
filename: str,
user_id: int,
expires: Optional[int] = None,
prefix: Optional[str] = None
) -> Dict[str, Any]:
"""
生成预签名上传URL(前端直传)
参数:
filename: 文件名
user_id: 用户ID
expires: 过期时间(秒),默认使用配置值
prefix: 路径前缀(可选)
返回:
预签名URL信息
"""
try:
# 验证文件扩展名
if not self._validate_file_extension(filename):
raise ValidationException(
f"不支持的文件类型,允许的类型: {', '.join(settings.OSS_ALLOWED_EXTENSIONS)}"
)
# 生成OSS对象键
object_key = self._generate_object_key(filename, user_id, prefix)
# 设置过期时间
expires = expires or settings.OSS_SIGNED_URL_EXPIRE
# 生成预签名URL
signed_url = self.bucket.sign_url(
'PUT',
object_key,
expires,
headers={'Content-Type': self._get_content_type(filename)}
)
# 构建文件URL(上传后的访问URL)
file_url = self._build_file_url(object_key)
return {
"upload_url": signed_url,
"object_key": object_key,
"file_url": file_url,
"expires_in": expires,
"expires_at": (datetime.now() + timedelta(seconds=expires)).isoformat(),
"method": "PUT",
"headers": {
"Content-Type": self._get_content_type(filename)
}
}
except oss2.exceptions.OssError as e:
raise ExternalServiceException(f"生成签名URL失败: {e.message}")
except ValidationException:
raise
except Exception as e:
raise BusinessException(f"生成签名URL失败: {str(e)}")
def generate_multipart_presigned_urls(
self,
filename: str,
file_size: int,
user_id: int,
expires: Optional[int] = None,
prefix: Optional[str] = None
) -> Dict[str, Any]:
"""
生成分片上传的预签名URL(前端分片直传)
参数:
filename: 文件名
file_size: 文件大小(字节)
user_id: 用户ID
expires: 过期时间(秒)
prefix: 路径前缀(可选)
返回:
分片上传信息和预签名URL列表
"""
try:
# 验证文件大小
if not self._validate_file_size(file_size):
raise ValidationException(
f"文件大小超出限制,最大允许 {settings.OSS_MAX_FILE_SIZE / 1024 / 1024:.2f}MB"
)
# 验证文件扩展名
if not self._validate_file_extension(filename):
raise ValidationException(
f"不支持的文件类型,允许的类型: {', '.join(settings.OSS_ALLOWED_EXTENSIONS)}"
)
# 生成OSS对象键
object_key = self._generate_object_key(filename, user_id, prefix)
# 初始化分片上传
upload_id = self.bucket.init_multipart_upload(
object_key,
headers={'Content-Type': self._get_content_type(filename)}
).upload_id
# 计算分片数量
part_size = settings.OSS_PART_SIZE
part_count = (file_size + part_size - 1) // part_size
# 设置过期时间
expires = expires or settings.OSS_SIGNED_URL_EXPIRE
# 为每个分片生成预签名URL
part_urls = []
for part_number in range(1, part_count + 1):
params = {
'uploadId': upload_id,
'partNumber': str(part_number)
}
signed_url = self.bucket.sign_url(
'PUT',
object_key,
expires,
params=params
)
part_urls.append({
"part_number": part_number,
"upload_url": signed_url
})
# 构建文件URL(完成后的访问URL)
file_url = self._build_file_url(object_key)
return {
"upload_id": upload_id,
"object_key": object_key,
"file_url": file_url,
"part_size": part_size,
"part_count": part_count,
"part_urls": part_urls,
"expires_in": expires,
"expires_at": (datetime.now() + timedelta(seconds=expires)).isoformat()
}
except oss2.exceptions.OssError as e:
raise ExternalServiceException(f"初始化分片上传失败: {e.message}")
except (ValidationException, ExternalServiceException):
raise
except Exception as e:
raise BusinessException(f"初始化分片上传失败: {str(e)}")
def complete_multipart_upload_by_client(
self,
object_key: str,
upload_id: str,
parts: list
) -> Dict[str, Any]:
"""
完成客户端分片上传
参数:
object_key: OSS对象键
upload_id: 上传ID
parts: 分片信息列表 [{"part_number": 1, "etag": "xxx"}, ...]
返回:
完成结果
"""
try:
# 构建分片信息
part_info_list = [
oss2.models.PartInfo(part["part_number"], part["etag"])
for part in parts
]
# 完成分片上传
result = self.bucket.complete_multipart_upload(
object_key,
upload_id,
part_info_list
)
# 构建文件URL
file_url = self._build_file_url(object_key)
return {
"object_key": object_key,
"url": file_url,
"etag": result.etag,
"completed_at": datetime.now().isoformat()
}
except oss2.exceptions.OssError as e:
raise ExternalServiceException(f"完成分片上传失败: {e.message}")
except Exception as e:
raise BusinessException(f"完成分片上传失败: {str(e)}")
def abort_multipart_upload(
self,
object_key: str,
upload_id: str
) -> bool:
"""
取消分片上传
参数:
object_key: OSS对象键
upload_id: 上传ID
返回:
是否成功
"""
@retry_on_connection_error(
max_retries=self.max_retries,
delay=self.retry_delay
)
def _do_abort():
self.bucket.abort_multipart_upload(
object_key,
upload_id
)
try:
_do_abort()
return True
except oss2.exceptions.OssError as e:
raise ExternalServiceException(f"取消分片上传失败: {e.message}")
except Exception as e:
raise BusinessException(f"取消分片上传失败: {str(e)}")
def delete_file(self, object_key: str) -> bool:
"""
删除OSS文件
参数:
object_key: OSS对象键
返回:
是否成功
"""
try:
self.bucket.delete_object(object_key)
return True
except oss2.exceptions.OssError as e:
raise ExternalServiceException(f"删除文件失败: {e.message}")
except Exception as e:
raise BusinessException(f"删除文件失败: {str(e)}")
def delete_files_batch(self, object_keys: list) -> Dict[str, Any]:
"""
批量删除OSS文件
参数:
object_keys: OSS对象键列表
返回:
删除结果
"""
try:
result = self.bucket.batch_delete_objects(object_keys)
return {
"deleted_count": len(result.deleted_keys),
"deleted_keys": result.deleted_keys
}
except oss2.exceptions.OssError as e:
raise ExternalServiceException(f"批量删除文件失败: {e.message}")
except Exception as e:
raise BusinessException(f"批量删除文件失败: {str(e)}")
def upload_from_url(
self,
url: str,
entity_type: str,
filename: Optional[str] = None
) -> Dict[str, Any]:
"""
从URL下载文件并上传到OSS(用于转存外部生成的图片或视频)
参数:
url: 文件URL(图片或视频)
entity_type: 实体类型(character/scene),用于构建存储路径
filename: 自定义文件名(可选),如果未提供则从URL提取或根据Content-Type推断
返回:
上传结果信息
"""
import requests
from urllib.parse import urlparse
@retry_on_connection_error(
max_retries=self.max_retries,
delay=self.retry_delay
)
def _do_upload():
# 下载文件
try:
response = requests.get(url, timeout=30, stream=True)
response.raise_for_status()
except Exception as e:
logger.error(f"从URL下载文件失败: {url}, error: {e}")
raise ExternalServiceException(f"从URL下载文件失败: {str(e)}")
# 获取文件内容和Content-Type
file_data = response.content
response_content_type = response.headers.get('Content-Type', '')
# 确定文件名
final_filename = filename
if not final_filename:
# 尝试从URL提取文件名
parsed_url = urlparse(url)
path = parsed_url.path
final_filename = Path(path).name
# 如果URL中没有文件名或没有扩展名,根据Content-Type推断
if not final_filename or '.' not in final_filename:
ext = self._get_extension_from_mime(response_content_type)
if ext:
# Pick the filename prefix from the extension. Do not match on the ".mp"
# prefix: that would misclassify .mp3 audio as video.
type_prefix = 'video' if ext in ('.mp4', '.mpeg', '.webm', '.mov', '.avi', '.mkv') else 'file'
final_filename = f"{type_prefix}_{int(datetime.now().timestamp())}{ext}"
else:
# 如果无法推断,使用通用名称
final_filename = f"file_{int(datetime.now().timestamp())}.bin"
# 生成OSS对象键
object_key = self._generate_object_key_simple(final_filename, entity_type)
# 从文件名获取Content-Type(优先使用响应头的Content-Type)
content_type = response_content_type or self._get_content_type(final_filename)
# 上传文件
result = self.bucket.put_object(
object_key,
file_data,
headers={'Content-Type': content_type}
)
# 构建文件URL
file_url = self._build_file_url(object_key)
logger.info(f"文件转存成功: {url} -> {file_url}")
return {
"object_key": object_key,
"filename": final_filename,
"url": file_url,
"content_type": content_type,
"size": len(file_data),
"etag": result.etag,
"uploaded_at": datetime.now().isoformat()
}
try:
return _do_upload()
except oss2.exceptions.OssError as e:
logger.error(f"OSS转存失败: {e.message}")
raise ExternalServiceException(f"OSS转存失败: {e.message}")
except ExternalServiceException:
raise
except Exception as e:
logger.error(f"文件转存失败: {str(e)}", exc_info=True)
raise BusinessException(f"文件转存失败: {str(e)}")
def upload_from_base64(
self,
base64_data_url: str,
entity_type: str,
filename: Optional[str] = None
) -> Dict[str, Any]:
"""
从 Base64 data URL 解码并上传到 OSS
参数:
base64_data_url: Base64 data URL (格式: data:image/png;base64,...)
entity_type: 实体类型(image/character/scene),用于构建存储路径
filename: 自定义文件名(可选),如果未提供则自动生成
返回:
上传结果信息
"""
import base64
import re
@retry_on_connection_error(
max_retries=self.max_retries,
delay=self.retry_delay
)
def _do_upload():
# 解析 Base64 data URL
# 格式: data:image/png;base64,iVBORw0KGgo...
if not base64_data_url.startswith('data:'):
raise ValueError(f"无效的 Base64 data URL 格式")
# 提取 MIME 类型和 Base64 数据
match = re.match(r'data:([^;]+);base64,(.+)', base64_data_url)
if not match:
raise ValueError(f"无法解析 Base64 data URL")
mime_type = match.group(1) # 例如: image/png
base64_data = match.group(2)
# 根据 MIME 类型确定文件扩展名
ext_map = {
'image/png': '.png',
'image/jpeg': '.jpg',
'image/jpg': '.jpg',
'image/gif': '.gif',
'image/webp': '.webp',
}
ext = ext_map.get(mime_type.lower(), '.png')
# 确定文件名
if filename:
final_filename = filename
# 确保文件名有正确的扩展名
if not final_filename.endswith(ext):
final_filename = Path(filename).stem + ext
else:
# 生成默认文件名
final_filename = f"image_{int(datetime.now().timestamp())}{ext}"
# 解码 Base64 数据
try:
file_data = base64.b64decode(base64_data)
except Exception as e:
logger.error(f"Base64 解码失败: {e}")
raise ExternalServiceException(f"Base64 解码失败: {str(e)}")
# 生成 OSS 对象键
object_key = self._generate_object_key_simple(final_filename, entity_type)
# 上传文件
result = self.bucket.put_object(
object_key,
file_data,
headers={'Content-Type': mime_type}
)
# 构建文件 URL
file_url = self._build_file_url(object_key)
logger.info(f"Base64 图片上传成功: size={len(file_data)} -> {file_url}")
return {
"object_key": object_key,
"filename": final_filename,
"url": file_url,
"content_type": mime_type,
"size": len(file_data),
"etag": result.etag,
"uploaded_at": datetime.now().isoformat()
}
try:
return _do_upload()
except oss2.exceptions.OssError as e:
logger.error(f"OSS 上传失败: {e.message}")
raise ExternalServiceException(f"OSS 上传失败: {e.message}")
except (ValueError, ExternalServiceException):
raise
except Exception as e:
logger.error(f"Base64 图片上传失败: {str(e)}", exc_info=True)
raise BusinessException(f"Base64 图片上传失败: {str(e)}")
def get_file_info(self, object_key: str) -> Dict[str, Any]:
"""
获取文件信息
参数:
object_key: OSS对象键
返回:
文件信息
"""
try:
meta = self.bucket.get_object_meta(object_key)
return {
"object_key": object_key,
"size": meta.headers.get('Content-Length'),
"content_type": meta.headers.get('Content-Type'),
"etag": meta.headers.get('ETag'),
"last_modified": meta.headers.get('Last-Modified')
}
except oss2.exceptions.NoSuchKey:
raise NotFoundException(f"文件不存在: {object_key}")
except oss2.exceptions.OssError as e:
raise ExternalServiceException(f"获取文件信息失败: {e.message}")
except Exception as e:
raise BusinessException(f"获取文件信息失败: {str(e)}")
def file_exists(self, object_key: str) -> bool:
"""
检查文件是否存在
参数:
object_key: OSS对象键
返回:
是否存在
"""
try:
return self.bucket.object_exists(object_key)
except Exception:
return False
def generate_download_url(
self,
object_key: str,
expires: Optional[int] = None,
filename: Optional[str] = None
) -> str:
"""
生成文件下载URL(临时访问)
参数:
object_key: OSS对象键
expires: 过期时间(秒)
filename: 下载时的文件名(可选)
返回:
下载URL
"""
try:
expires = expires or settings.OSS_SIGNED_URL_EXPIRE
params = {}
if filename:
params['response-content-disposition'] = f'attachment; filename="{filename}"'
return self.bucket.sign_url(
'GET',
object_key,
expires,
params=params
)
except oss2.exceptions.OssError as e:
raise ExternalServiceException(f"生成下载URL失败: {e.message}")
except Exception as e:
raise BusinessException(f"生成下载URL失败: {str(e)}")
def _build_file_url(self, object_key: str) -> str:
"""
构建文件访问URL
参数:
object_key: OSS对象键
返回:
文件URL
"""
# 如果配置了CDN域名,使用CDN域名
if settings.OSS_CDN_DOMAIN:
# 移除可能存在的 http:// 或 https:// 前缀
cdn_domain = settings.OSS_CDN_DOMAIN
if cdn_domain.startswith("http://"):
cdn_domain = cdn_domain[7:]
elif cdn_domain.startswith("https://"):
cdn_domain = cdn_domain[8:]
# 移除域名开头的 /
if cdn_domain.startswith("/"):
cdn_domain = cdn_domain[1:]
protocol = "https" if settings.OSS_USE_HTTPS else "http"
return f"{protocol}://{cdn_domain}/{object_key}"
# 否则使用OSS域名
# 移除 endpoint 中可能存在的协议前缀
endpoint = settings.OSS_ENDPOINT
if endpoint.startswith("http://"):
endpoint = endpoint[7:]
elif endpoint.startswith("https://"):
endpoint = endpoint[8:]
protocol = "https" if settings.OSS_USE_HTTPS else "http"
return f"{protocol}://{settings.OSS_BUCKET_NAME}.{endpoint}/{object_key}"
def generate_sts_credentials(
self,
user_id: int,
duration_seconds: Optional[int] = None,
policy: Optional[str] = None,
role_session_name: Optional[str] = None
) -> Dict[str, Any]:
"""
生成OSS STS临时凭证(用于前端直传)
参数:
user_id: 用户ID(用于构建会话名称)
duration_seconds: 凭证有效期(秒),默认使用配置值
policy: 自定义策略(可选,JSON字符串)
role_session_name: 角色会话名称(可选)
返回:
STS临时凭证信息,包括:
- access_key_id: 临时AccessKey ID
- access_key_secret: 临时AccessKey Secret
- security_token: 安全令牌
- expiration: 过期时间(UTC)
- region: 区域
- bucket: Bucket名称
- endpoint: OSS端点
- upload_path_prefix: 上传路径前缀
"""
if not settings.OSS_STS_ROLE_ARN:
raise BusinessException(
"STS Role ARN未配置,请检查环境变量 OSS_STS_ROLE_ARN"
)
if not all([
settings.OSS_ACCESS_KEY_ID,
settings.OSS_ACCESS_KEY_SECRET,
settings.OSS_ENDPOINT
]):
raise BusinessException(
"OSS配置不完整,请检查环境变量"
)
try:
# 确定区域ID
# AcsClient 的第三个参数是 region_id,不是 endpoint
# 格式: cn-hangzhou, cn-beijing 等
if settings.OSS_REGION:
region = settings.OSS_REGION
else:
# 从OSS端点提取区域
# 格式: oss-cn-hangzhou.aliyuncs.com -> cn-hangzhou
import re
endpoint = settings.OSS_ENDPOINT
if not endpoint:
raise BusinessException("OSS_ENDPOINT 未配置")
match = re.search(r'oss-(\w+)-(\w+)\.aliyuncs\.com', endpoint)
if match:
region = f"{match.group(1)}-{match.group(2)}"
else:
raise BusinessException(
"无法从OSS端点确定区域,请配置 OSS_REGION"
)
# 创建STS客户端(使用正确的 region_id 参数)
client = AcsClient(
settings.OSS_ACCESS_KEY_ID,
settings.OSS_ACCESS_KEY_SECRET,
region
)
# 创建AssumeRole请求
request = AssumeRoleRequest.AssumeRoleRequest()
request.set_RoleArn(settings.OSS_STS_ROLE_ARN)
request.set_RoleSessionName(
role_session_name or f"aimv-frontend-upload-user-{user_id}"
)
request.set_DurationSeconds(
duration_seconds or settings.OSS_STS_DURATION_SECONDS
)
# 设置策略(如果提供了自定义策略)
if policy:
request.set_Policy(policy)
elif settings.OSS_STS_POLICY:
request.set_Policy(settings.OSS_STS_POLICY)
# 发送请求
response = client.do_action_with_exception(request)
result = json.loads(response.decode('utf-8'))
# 提取凭证信息
credentials = result['Credentials']
# 将过期时间转换为北京时间(UTC+8)
from app.schemas.common import convert_datetime_to_beijing
expiration_utc = datetime.strptime(credentials['Expiration'], '%Y-%m-%dT%H:%M:%SZ')
expiration_str = convert_datetime_to_beijing(expiration_utc)
# 构建上传路径前缀
upload_path_prefix = f"{settings.OSS_UPLOAD_PATH_PREFIX}/{user_id}"
if settings.OSS_GLOBAL_PREFIX:
upload_path_prefix = f"{settings.OSS_GLOBAL_PREFIX}/{upload_path_prefix}"
return {
"access_key_id": credentials['AccessKeyId'],
"access_key_secret": credentials['AccessKeySecret'],
"security_token": credentials['SecurityToken'],
"expiration": expiration_str,
"region": settings.OSS_REGION,
"bucket": settings.OSS_BUCKET_NAME,
"endpoint": settings.OSS_ENDPOINT,
"cdn_domain": settings.OSS_CDN_DOMAIN,
"upload_path_prefix": upload_path_prefix,
}
except Exception as e:
logger.error(f"生成STS临时凭证失败: {str(e)}", exc_info=True)
raise ExternalServiceException(
f"生成STS临时凭证失败: {str(e)}"
)
def generate_user_upload_policy(self, user_id: int) -> str:
"""
为用户生成上传策略(限制只能上传到指定用户目录)
参数:
user_id: 用户ID
返回:
策略JSON字符串
"""
# 构建用户专属路径
upload_path_prefix = f"{settings.OSS_UPLOAD_PATH_PREFIX}/{user_id}"
if settings.OSS_GLOBAL_PREFIX:
upload_path_prefix = f"{settings.OSS_GLOBAL_PREFIX}/{upload_path_prefix}"
# 构建策略
policy = {
"Version": "1",
"Statement": [
{
"Effect": "Allow",
"Action": [
"oss:PutObject",
"oss:InitiateMultipartUpload",
"oss:UploadPart",
"oss:CompleteMultipartUpload",
"oss:AbortMultipartUpload"
],
"Resource": [
f"acs:oss:*:*:{settings.OSS_BUCKET_NAME}/{upload_path_prefix}/*"
]
},
{
"Effect": "Allow",
"Action": [
"oss:ListObjects"
],
"Resource": [
f"acs:oss:*:*:{settings.OSS_BUCKET_NAME}",
f"acs:oss:*:*:{settings.OSS_BUCKET_NAME}/{upload_path_prefix}"
],
"Condition": {
"StringLike": {
"oss:prefix": [f"{upload_path_prefix}/*", f"{upload_path_prefix}"]
}
}
}
]
}
return json.dumps(policy)
# 创建全局OSS服务实例(延迟初始化)
_oss_service_instance = None
def get_oss_service() -> OSSService:
"""获取OSS服务实例(单例模式)"""
global _oss_service_instance
if _oss_service_instance is None:
_oss_service_instance = OSSService()
return _oss_service_instance
# Module-level OSS instance. NOTE: this is created eagerly at import time and
# raises if OSS is unconfigured; use get_oss_service() for truly lazy init.
oss_service = get_oss_service()
"""
阿里云OSS文件上传模块
"""
import os
import uuid
import logging
from datetime import datetime, timedelta
import oss2
from app.core.config import settings
logger = logging.getLogger(__name__)
class OSSUploader:
"""阿里云OSS上传器"""
def __init__(self):
"""初始化OSS客户端"""
self.access_key_id = settings.OSS_ACCESS_KEY_ID
self.access_key_secret = settings.OSS_ACCESS_KEY_SECRET
self.endpoint = settings.OSS_ENDPOINT
self.bucket_name = settings.OSS_BUCKET_NAME
if not all([
self.access_key_id,
self.access_key_secret,
self.endpoint,
self.bucket_name,
]):
raise ValueError("OSS配置不完整,请检查 .env 中的 OSS_ACCESS_KEY_ID/OSS_ACCESS_KEY_SECRET/OSS_ENDPOINT/OSS_BUCKET_NAME")
logger.info(
"OSS配置: endpoint=%s, bucket=%s",
self.endpoint,
self.bucket_name,
)
# 创建认证对象
self.auth = oss2.Auth(self.access_key_id, self.access_key_secret)
# 默认使用公网 endpoint;非阿里云内网环境下访问 internal endpoint 容易失败。
self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name)
def upload_file(self, local_file_path, oss_object_name=None):
"""
上传文件到OSS
Args:
local_file_path: 本地文件路径
oss_object_name: OSS对象名称,如果不指定则使用时间戳+原文件名
Returns:
tuple: (success: bool, url: str) 或 (success: bool, error: str)
"""
try:
if not os.path.exists(local_file_path):
logger.error(f"本地文件不存在: {local_file_path}")
return False, "本地文件不存在"
if not oss_object_name:
# No object name given: derive one from a UUID plus the source extension.
_, ext = os.path.splitext(local_file_path)
oss_object_name = f"{uuid.uuid4()}{ext}"
# Always file the object under temp_ai/{date}/ so the daily cleanup job can
# expire it by directory date.
date = datetime.now().strftime("%Y%m%d")
oss_object_name = f"temp_ai/{date}/{oss_object_name}"
# 上传文件
result = self.bucket.put_object_from_file(oss_object_name, local_file_path)
# 构建文件URL
file_url = f"https://{self.bucket_name}.{self.endpoint}/{oss_object_name}"
logger.info(f"文件上传成功: {local_file_path} -> {file_url}")
return True, file_url
except Exception as e:
logger.error(f"文件上传失败: {local_file_path}, 错误: {e}")
return False, str(e)
def upload_data(self, data, oss_object_name):
"""
上传数据到OSS
Args:
data: 要上传的数据(字符串或字节)
oss_object_name: OSS对象名称
Returns:
dict: 包含上传结果的字典
"""
try:
# 上传数据
result = self.bucket.put_object(oss_object_name, data)
# Build the file URL (virtual-hosted style, matching upload_file above).
file_url = f"https://{self.bucket_name}.{self.endpoint}/{oss_object_name}"
return {
"success": True,
"oss_object_name": oss_object_name,
"file_url": file_url,
"etag": result.etag,
"size": len(data) if isinstance(data, (str, bytes)) else 0
}
except Exception as e:
return {"success": False, "error": str(e)}
def get_bucket():
"""获取Bucket对象"""
if not all([
settings.OSS_ACCESS_KEY_ID,
settings.OSS_ACCESS_KEY_SECRET,
settings.OSS_ENDPOINT,
settings.OSS_BUCKET_NAME,
]):
raise ValueError("OSS配置不完整,请检查 .env 中的 OSS_ACCESS_KEY_ID/OSS_ACCESS_KEY_SECRET/OSS_ENDPOINT/OSS_BUCKET_NAME")
auth = oss2.Auth(settings.OSS_ACCESS_KEY_ID, settings.OSS_ACCESS_KEY_SECRET)
bucket = oss2.Bucket(auth, settings.OSS_ENDPOINT, settings.OSS_BUCKET_NAME)
return bucket
def clean_expire_file():
"""核心任务函数"""
print(f"\n[{datetime.now()}] 开始执行每日清理任务...")
ROOT_PREFIX = 'temp_ai/'
bucket = get_bucket()
# 1. 计算时间阈值
now = datetime.now()
yesterday_date = (now - timedelta(days=1)).date()
print(f"保留阈值: {yesterday_date} (即 {yesterday_date} 之前的数据将被删除)")
# 2. 遍历目录
try:
for obj in oss2.ObjectIterator(bucket, prefix=ROOT_PREFIX, delimiter='/'):
path = ""
is_directory = False
# --- [核心修改] 统一路径获取方式 ---
# 情况 A: 它是虚拟目录 (CommonPrefix)
if hasattr(obj, 'prefix'):
path = obj.prefix
is_directory = True
# 情况 B: 它是实际对象 (SimplifiedObjectInfo)
elif hasattr(obj, 'key'):
path = obj.key
# 如果 key 以 / 结尾,说明它是一个显式创建的文件夹对象
if path.endswith('/'):
is_directory = True
else:
is_directory = False # 这是一个普通文件
# --- 逻辑分流 ---
if not is_directory:
# 这是一个真正的文件(且不是文件夹对象),直接跳过
# print(f"[跳过] 散落文件: {path}")
continue
# 此时 path 必定是目录格式 (如 'temp_ai/20251229/')
# 下面开始正常的日期判断逻辑
# 防御性去空,防止路径即为 'temp_ai/' 本身
if path == ROOT_PREFIX:
continue
# Parse the directory name: strip the surrounding '/' and take the last
# segment (e.g. 'temp_ai/20251229/' -> '20251229').
folder_name_raw = path.strip('/').split('/')[-1]
try:
folder_date_obj = datetime.strptime(folder_name_raw, "%Y%m%d").date()
if folder_date_obj < yesterday_date:
print(f"[删除] 发现过期目录: {path}")
# 注意:delete_objects_by_prefix 会删除该前缀下的所有文件
# 如果这个目录本身是个对象,也会被一并删除,无需特殊处理
delete_objects_by_prefix(bucket, path)
else:
# print(f"[跳过] 目录较新: {path}")
pass
except ValueError:
print(f"[跳过] 非日期命名目录: {path}")
except Exception as e:
import traceback
print(f"[严重错误] 任务执行失败: {e}")
traceback.print_exc()
def delete_objects_by_prefix(bucket, prefix):
"""递归删除指定前缀下的所有文件"""
print(f" -> 正在清理目录: {prefix} ...")
batch_list = []
try:
for obj in oss2.ObjectIterator(bucket, prefix=prefix):
batch_list.append(obj.key)
if len(batch_list) >= 1000:
bucket.batch_delete_objects(batch_list)
batch_list = []
if batch_list:
bucket.batch_delete_objects(batch_list)
print(f" -> 目录 {prefix} 清理完毕。")
except Exception as e:
print(f" [错误] 删除过程出错: {e}")
# 创建OSS上传器实例
oss_uploader = OSSUploader()
if __name__ == '__main__':
resp = oss_uploader.upload_file('想-dj-片段.mp3')
print(resp)
from dashscope.common.constants import DASHSCOPE_API_KEY_ENV
ENV = 'test'
# ENV = 'local'
DEBUG = True
### 数据库
#dev
DB_USER = 'root'
DB_PASSWORD = 'Hikoon123!'
DB_HOST = 'rm-bp18h64ad9ak4d7h5do.mysql.rds.aliyuncs.com'
DB_DATABASE = 'music_partner'
#Redis
REDIS_HOST = '172.23.209.46'
REDIS_PORT = 6379
REDIS_PSW = '1bvvpAmKXFhDDJXb'
REDIS_DB = 0
#新抖key
NEW_RANK_KEY = 'vh1gbvynpyegg6gebhgepgvc6'
BACK_BASE_URL = 'https://ai-test.hikoon.com/api/partner'
EMAIL_HOST = 'smtp.exmail.qq.com'
EMAIL_PORT = 465
EMAIL_HOST_USER = 'bigmusic@hikoon.com'
EMAIL_HOST_PASSWORD = 'Music!123'
#邮件接收人列表
EMAIL_RECEIVERS = ['1774507011@qq.com','yangsheng@hikoon.com']
#标签字典
TAG_DICT = {
"viral_song": "网络热歌",
"sad_songs": "伤感老歌",
"folk_songs": "民谣",
"catchy_pop": "口水歌",
"kids_songs": "洗脑儿歌",
"tk_songs": "抖音热歌",
"net_songs": "网络歌曲",
"dj_remix": "DJ嗨曲",
"Cheesy_EDM": "土嗨/慢摇",
"car_music": "车载音乐",
"shout_rap": "喊麦",
"heavy_metal": "重金属/土摇DJ嗨曲",
"mandarin_pop": "华语流行",
"mainstream_pop": "主流Pop",
"sweet_songs": "甜歌/校园",
"hip_rock": "嘻哈说唱R&B摇滚",
"child_songs": "主流儿歌",
"international_pop": "国外流行",
"jp_pop": "日韩流行",
"west_pop": "欧美流行",
"el_edm": "电音EDM",
"chinese_style": "国风",
"opera_vocal": "戏腔/古韵",
"guochao_EDM": "国潮电子",
"gufeng_music": "传统器乐古风",
"soundtrack_instrumental": "影视/纯音",
"ys_ost": "影视OST",
"pur_music": "纯音乐",
"no_lyric": "无词BGM",
"other_music": "其他",
"jazz_blue": "爵士/蓝调",
"voice_book": "有声书",
"lab_music": "实验音乐",
"healing": "治愈",
"melancholy": "伤感",
"lonely": "孤独",
"sweet": "甜蜜",
"inspiring": "励志",
"missing": "思念",
"nostalgic": "怀旧",
"angry": "愤怒",
"relaxing": "放松",
"catchy": "魔性洗脑",
"heroic": "悲壮",
"calm": "平静",
"festive": "喜庆",
"romantic": "浪漫",
"majestic": "雄壮",
"bewitching": "蛊惑",
"cathartic": "宣泄",
"solemn": "庄重",
"passionate": "激情",
"heavy": "沉重",
"happy": "快乐",
"tense": "紧张",
"horror": "恐怖",
"touching": "感动",
"spoof": "恶搞",
"funny": "搞笑",
"expectation": "期待",
"remembrance": "怀念",
"mysterious": "悬疑",
"blessing": "祝福",
"zen": "佛系",
"soothing": "舒缓",
"melodious": "悠扬",
"warm": "温暖",
"depressed": "忧郁",
"elderly": "老年",
"middle_aged": "中年",
"young_adult": "青年",
"teenager": "少年",
"life_scene": "生活场景",
"sports": "运动",
"driving": "开车",
"travel": "旅行",
"sleep": "睡前",
"study": "学习",
"cafe": "咖啡厅",
"bar": "酒吧",
"douyin":"抖音",
"restaurant": "餐厅",
"car_scene": "汽车",
"dance": "跳舞",
"work": "工作",
"nightclub": "夜店",
"leisure": "休闲",
"live_house": "live house",
"square_dance": "广场舞",
"wedding": "婚礼",
"dating": "约会",
"festival_scene": "节日场景",
"summer": "夏天",
"winter": "冬天",
"autumn": "秋天",
"spring_festival": "春节",
"christmas": "圣诞",
"valentine": "情人节",
"time_scene": "时间场景",
"morning": "清晨",
"afternoon": "午后",
"evening": "夜晚",
"midnight": "深夜",
"regional_scene": "地域场景",
"campus": "校园",
"city": "城市",
"grassland": "草原",
"tibet": "西藏",
"xinjiang": "新疆",
"transition_style": "转场类",
"card_point_switch": "卡点切换画面类",
"reverse_suspense": "反转悬念类",
"emotion_contrast": "情绪对比类",
"mashup_collection": "混剪合集类",
"emotional_resonance": "情感共鸣向剪辑",
"scene_adaptation": "场景适配剪辑",
"highlight_slice": "高光切片剪辑",
"live_performance": "现场表演类",
"singer_live": "歌手现场演唱",
"talent_cover": "达人翻唱表演",
"audience_interaction": "观众互动表演",
"card_point_speed": "卡点、变速类",
"multi_scene_fragment": "多场景碎片化卡点",
"tech_effect_speed": "技术流特效变速",
"lyric_concrete": "歌词具象化卡点",
"loop_speed_brainwash": "循环变速洗脑",
"ugc_co_creation": "UGC共创类",
"jianying_template": "剪映模板",
"ai_singing": "AI唱歌",
"emotional_quotes": "情感语录类",
"late_night_emo": "深夜emo类",
"morning_inspiration": "清晨励志类",
"memory_destiny": "回忆杀/宿命感类",
"dynamic_lyrics_visual": "动态歌词可视化",
"basic_lyrics_effect": "基础歌词动效",
"creative_visual_enhance": "创意视觉强化",
"adaptation": "改编",
"special_effects_interaction": "特效互动类",
"gesture_magic_effect": "手势魔法特效互动",
"lip_sync_challenge": "对口型挑战",
"douyin_effect_show": "抖音特效变装秀",
# 听感演绎流
"singing_montage": "演唱混剪",
"live_singing": "现场演唱",
# 视觉冲击流
"change_transition": "变装转场",
"hand_dance": "手势舞",
"addictive_dance": "魔性舞蹈",
"landscape_account": "风景号",
# 氛围素材流
"cute_pets": "萌宠",
"movie_anime_edit": "影视剧/动漫混剪",
"chinese_classical": "古风",
"mood_post": "图文心情",
# 情感共鸣流
"animated_lyrics": "动态歌词",
"storytelling": "故事演绎",
"beauty_snaps": "颜值随拍"
}
# 模型相关配置
BASE_MODEL = "/data/qufeng/models--MIT--ast-finetuned-audioset-10-10-0.4593/snapshots/f826b80d28226b62986cc218e5cec390b1096902"
MOE_DIR = "/data/qufeng/moe_outputs"
BASELINE_CHECKPOINT = "/data/qufeng/best_epoch_base.pt"
LABEL_MAPPING = "/data/qufeng/label_mapping.txt"
DEVICE = "cuda" # 可选: cuda/mps/cpu,为空时自动选择
ROUTER_CHECKPOINT = "" # 为空时自动从 moe_dir/joint_train/joint/router_best.pt 推断
EXPERTS_DIR = "" # 为空时自动从 moe_dir/experts_train/experts 推断
# 音频处理配置
CHUNK_SECONDS = 10.24 # 按多少秒切块推理
CROP_SECONDS = 204.8 # 若音频超过该时长,则仅截取中间这段再切块
MAX_CHUNKS = 10 # 每首歌最多使用多少个切片参与推理
CHUNK_BATCH_SIZE = 8 # 切块推理的 batch size
ROUTING_THRESHOLD = 0.6
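# Worked numbers: CROP_SECONDS / CHUNK_SECONDS = 204.8 / 10.24 = 20 candidate
# chunks from the middle of a long track, of which at most MAX_CHUNKS = 10
# (i.e. 102.4s of audio) are actually fed to inference.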
API_CONFIG = {
"api_key": "sk-d9b4d3581bde47d887354f9160a509a2",
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"model": "qwen3-omni-flash",
"audio_mode": "auto",
"timeout": 15,
"lyrics_timeout": 60,
"lyrics_retries": 2,
"max_retries": 5,
"retry_delay": 5
}
# API_CONFIG_91 = {
# "api_key": "sk-E90VNVMyhfk2zDBDoToCXoipzGofD2SobwBqaCzbG3junlob",
# "base_url": "https://api.91aopusi.com/v1",
# "model": "qwen3-omni-flash",
# "audio_mode": "auto",
# "timeout": 30,
# "lyrics_timeout": 60,
# "max_retries": 5,
# "retry_delay": 5
# }
DASHSCOPE_API_KEY = 'sk-d9b4d3581bde47d887354f9160a509a2'
OSS_ACCESS_KEY_ID='LTAI4G7UvaW2e4UTCb3KCNjN'
OSS_ACCESS_KEY_SECRET='ow5hlVMmJAQY9o7nEAtMER6MFkPedm'
OSS_ENDPOINT='oss-cn-hangzhou.aliyuncs.com'
OSS_ENDPOINT_INTERNAL='oss-cn-hangzhou-internal.aliyuncs.com'
OSS_BUCKET_NAME='ai-sound-data-test'
import logging.handlers
import os
from config import DEBUG
log_dir = "./logs"
log_max_bytes = 1024 * 1024 * 10
log_backup_count = 5
def get_logger(name, level=None):
if not level:
level = logging.DEBUG if DEBUG else logging.INFO
# 配置日志
logger = logging.getLogger(name)
logger.setLevel(level)
# 检查日志目录是否存在,如果不存在则创建
if not os.path.exists(log_dir):
os.makedirs(log_dir)
# Guard against duplicate handlers when get_logger() is called more than once
# for the same name (otherwise every call would double the log output).
if not logger.handlers:
# 创建一个handler,用于写入日志文件
file_handler = logging.handlers.RotatingFileHandler(
os.path.join(log_dir, f"{name}.log"),
maxBytes=log_max_bytes,
backupCount=log_backup_count,
encoding="utf-8",
)
file_handler.setLevel(level)
# 定义handler的输出格式
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
# 给logger添加handler
logger.addHandler(file_handler)
return logger
# 定义一个模块级别的变量来存储日志记录器实例
_app_logger = None
def get_app_logger():
global _app_logger
if _app_logger is None:
_app_logger = get_logger("app")
return _app_logger
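# Illustrative usage: get_app_logger().info("task started") appends to
# ./logs/app.log, rotating at 10 MB with 5 backups kept.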
# -*- coding: utf-8 -*-
"""Batch analyze audio URLs from an xlsx file and export results to xlsx."""
from __future__ import annotations
import argparse
import json
import math
import os
import sys
import traceback
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from pathlib import Path
from typing import Any
import pandas as pd
# 允许直接 `python pipeline/batch_analyze_xlsx.py` 运行
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from app.middleware.music_analyze import analyze_music
DEFAULT_OUTPUT_COLUMNS = [
"tmeid",
"歌曲ID",
"歌曲名",
"表演者",
"歌曲时长",
"表演者类型",
"语种",
"BPM速度",
"情绪",
"网络/抖音歌曲",
"音乐风格",
"配器",
"场景",
]
ANALYZE_COLUMNS = [
"表演者类型",
"语种",
"BPM速度",
"情绪",
"网络/抖音歌曲",
"音乐风格",
"配器",
"场景",
]
def _is_blank(value: Any) -> bool:
if value is None:
return True
if isinstance(value, float) and math.isnan(value):
return True
return str(value).strip() == ""
def _join_multi_value(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value.strip()
if isinstance(value, list):
parts = [str(v).strip() for v in value if str(v).strip()]
return "、".join(parts)
return str(value).strip()
def _pick_first_non_blank(row: pd.Series, candidates: list[str]) -> str:
for col in candidates:
if col in row.index and not _is_blank(row[col]):
value = row[col]
if isinstance(value, float) and value.is_integer():
return str(int(value))
return str(value).strip()
return ""
def _normalize_key_value(value: Any) -> str:
if _is_blank(value):
return ""
if isinstance(value, float) and value.is_integer():
return str(int(value))
return str(value).strip()
def _resolve_url_column(df: pd.DataFrame, requested_column: str) -> str:
if requested_column in df.columns:
return requested_column
candidates = ["URL", "url", "cos访问地址", "cos_url", "audio_url"]
for col in candidates:
if col in df.columns:
print(
f"[run] url column `{requested_column}` not found, fallback to `{col}`"
)
return col
raise ValueError(
f"column `{requested_column}` not found, available={list(df.columns)}"
)
def _is_row_completed(out_df: pd.DataFrame, idx: int) -> bool:
for col in ANALYZE_COLUMNS:
if col not in out_df.columns:
continue
value = out_df.at[idx, col]
if not _is_blank(value):
return True
return False
def _resolve_checkpoint_path(output_path: Path, checkpoint_path: Path | None) -> Path:
if checkpoint_path is not None:
return checkpoint_path
return output_path.with_suffix(output_path.suffix + ".checkpoint.json")
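# Persist progress atomically: write to a .tmp sibling first, then replace
# the real output/checkpoint files so an interrupt cannot corrupt them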
def _save_progress(
out_df: pd.DataFrame,
output_path: Path,
checkpoint_path: Path,
completed_indices: set[int],
) -> None:
output_path.parent.mkdir(parents=True, exist_ok=True)
tmp_output = output_path.with_suffix(output_path.suffix + ".tmp")
out_df = out_df[DEFAULT_OUTPUT_COLUMNS]
out_df.to_excel(tmp_output, index=False)
tmp_output.replace(output_path)
payload = {
"completed_indices": sorted(completed_indices),
"completed_count": len(completed_indices),
"total": int(len(out_df)),
}
tmp_checkpoint = checkpoint_path.with_suffix(checkpoint_path.suffix + ".tmp")
tmp_checkpoint.write_text(
json.dumps(payload, ensure_ascii=False, indent=2),
encoding="utf-8",
)
tmp_checkpoint.replace(checkpoint_path)
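# Best-effort checkpoint load: a missing or unparseable file yields an
# empty set, so the run simply starts over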
def _load_checkpoint(checkpoint_path: Path) -> set[int]:
if not checkpoint_path.exists():
return set()
try:
payload = json.loads(checkpoint_path.read_text(encoding="utf-8"))
values = payload.get("completed_indices", [])
return {int(v) for v in values if isinstance(v, int) or str(v).isdigit()}
except Exception:
return set()
def _filter_checkpoint_indices(
checkpoint_indices: set[int],
out_df: pd.DataFrame,
df: pd.DataFrame,
url_column: str,
) -> set[int]:
"""
过滤 checkpoint 中的索引:
- 保留已存在分析结果的行(避免重复分析)
- 保留当前仍为空 URL 的行(继续跳过)
- 若 URL 已补齐且该行无分析结果,则不保留(允许后续补分析)
"""
filtered: set[int] = set()
for idx in checkpoint_indices:
if idx < 0 or idx >= len(out_df):
continue
if _is_row_completed(out_df, idx):
filtered.add(idx)
continue
url = df.at[idx, url_column] if url_column in df.columns else None
if _is_blank(url):
filtered.add(idx)
return filtered
def _build_metadata(row: pd.Series, metadata_columns: list[str]) -> dict[str, Any]:
metadata: dict[str, Any] = {}
    # Key fields are always passed through automatically, so an omission
    # does not prevent downstream from building its ID mapping
for col in ["歌曲ID", "song_id", "id"]:
if col in row.index and not _is_blank(row[col]):
metadata[col] = row[col]
break
for col in ["tmeid", "tmeID", "TMEID"]:
if col in row.index and not _is_blank(row[col]):
metadata["tmeid"] = row[col]
break
for col in metadata_columns:
if col in row.index and not _is_blank(row[col]):
metadata[col] = row[col]
return metadata
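# Map analyzer result keys onto the fixed delivery columns; list-valued
# fields are joined, and music style falls back to genre/sub_genre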
def _normalize_result(result: dict[str, Any]) -> dict[str, Any]:
return {
"表演者类型": (
str(result.get("performer_type") or result.get("vocal_texture") or "").strip()
),
"语种": str(result.get("language") or "").strip(),
"BPM速度": result.get("bpm"),
"情绪": _join_multi_value(result.get("emotion", [])),
"网络/抖音歌曲": _join_multi_value(result.get("douyin_tags", [])),
"音乐风格": _join_multi_value(
result.get("music_style_tags", [])
or [v for v in [result.get("genre"), result.get("sub_genre")] if v]
),
"配器": _join_multi_value(result.get("instrument_tags", [])),
"场景": _join_multi_value(result.get("scene", [])),
}
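# Index previous results by 歌曲ID and tmeid; on duplicate keys the first
# occurrence wins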
def _build_song_tmeid_maps(df: pd.DataFrame) -> tuple[dict[str, int], dict[str, int]]:
song_id_map: dict[str, int] = {}
tmeid_map: dict[str, int] = {}
for idx, row in df.iterrows():
song_id = _pick_first_non_blank(row, ["歌曲ID", "song_id", "id"])
tmeid = _pick_first_non_blank(row, ["tmeid", "tmeID", "TMEID"])
if song_id and song_id not in song_id_map:
song_id_map[song_id] = int(idx)
if tmeid and tmeid not in tmeid_map:
tmeid_map[tmeid] = int(idx)
return song_id_map, tmeid_map
def _resume_from_existing_by_keys(out_df: pd.DataFrame, existing: pd.DataFrame) -> set[int]:
"""当输入行数变化时,按 歌曲ID/tmeid 匹配复用旧结果。"""
completed_indices: set[int] = set()
if existing.empty:
return completed_indices
old_song_map, old_tmeid_map = _build_song_tmeid_maps(existing)
reused = 0
reused_by_song = 0
reused_by_tmeid = 0
for idx in out_df.index:
song_id = _normalize_key_value(out_df.at[idx, "歌曲ID"])
tmeid = _normalize_key_value(out_df.at[idx, "tmeid"])
old_idx = None
if song_id and song_id in old_song_map:
old_idx = old_song_map[song_id]
reused_by_song += 1
elif tmeid and tmeid in old_tmeid_map:
old_idx = old_tmeid_map[tmeid]
reused_by_tmeid += 1
if old_idx is None:
continue
for col in DEFAULT_OUTPUT_COLUMNS:
if col in existing.columns:
out_df.at[idx, col] = existing.at[old_idx, col]
if _is_row_completed(out_df, int(idx)):
completed_indices.add(int(idx))
reused += 1
print(
"[resume] row mismatch, reused by key: "
f"song_id_match={reused_by_song}, tmeid_match={reused_by_tmeid}, "
f"completed={reused}/{len(out_df)}"
)
return completed_indices
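# Worker body: returns (idx, {}) on a blank URL or any failure, leaving the
# row incomplete so it is retried on the next resumed run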
def _analyze_one(
idx: int,
row: pd.Series,
url_column: str,
provider: str,
extract_lyrics: bool,
label_level: int,
metadata_columns: list[str],
) -> tuple[int, dict[str, Any]]:
url = row.get(url_column)
if _is_blank(url):
return idx, {}
try:
metadata = _build_metadata(row, metadata_columns)
result = analyze_music(
metadata=metadata,
music_url=str(url).strip(),
provider=provider,
extract_lyrics=extract_lyrics,
label_level=label_level,
)
if not result:
return idx, {}
return idx, _normalize_result(result)
except Exception as exc:
print(f"[warn] row={idx} analyze failed: {type(exc).__name__}: {exc}")
print(traceback.format_exc(limit=3))
return idx, {}
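# Orchestrate the whole batch: load input, resume from any prior output and
# checkpoint, fan rows out to a thread pool, and save progress periodically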
def run_batch(
input_path: Path,
output_path: Path,
checkpoint_path: Path | None,
url_column: str,
provider: str,
extract_lyrics: bool,
label_level: int,
metadata_columns: list[str],
workers: int,
checkpoint_every: int,
resume: bool,
) -> None:
df = pd.read_excel(input_path)
url_column = _resolve_url_column(df, url_column)
checkpoint_path = _resolve_checkpoint_path(output_path, checkpoint_path)
blank_url_indices = {int(idx) for idx, value in df[url_column].items() if _is_blank(value)}
    # First build the base columns of the delivery table (from the input metadata)
out_df = pd.DataFrame(index=df.index)
out_df["tmeid"] = [
_pick_first_non_blank(row, ["tmeid", "tmeID", "TMEID"]) for _, row in df.iterrows()
]
out_df["歌曲ID"] = [
_pick_first_non_blank(row, ["歌曲ID", "song_id", "id"]) for _, row in df.iterrows()
]
out_df["歌曲名"] = [
_pick_first_non_blank(row, ["歌曲名", "歌曲名称", "title"]) for _, row in df.iterrows()
]
out_df["表演者"] = [
_pick_first_non_blank(row, ["表演者", "歌手", "artist"]) for _, row in df.iterrows()
]
out_df["歌曲时长"] = [
_pick_first_non_blank(row, ["歌曲时长", "duration"]) for _, row in df.iterrows()
]
for col in DEFAULT_OUTPUT_COLUMNS:
if col not in out_df.columns:
out_df[col] = ""
completed_indices: set[int] = set()
output_aligned_by_index = False
if resume:
if output_path.exists():
try:
existing = pd.read_excel(output_path)
if len(existing) == len(out_df):
output_aligned_by_index = True
for col in DEFAULT_OUTPUT_COLUMNS:
if col in existing.columns:
out_df[col] = existing[col]
for idx in out_df.index:
if _is_row_completed(out_df, idx):
completed_indices.add(int(idx))
print(
f"[resume] loaded existing output: {len(completed_indices)}/{len(out_df)} completed"
)
else:
completed_indices |= _resume_from_existing_by_keys(out_df, existing)
except Exception as exc:
print(f"[resume] failed to read existing output: {type(exc).__name__}: {exc}")
checkpoint_completed = _load_checkpoint(checkpoint_path)
if checkpoint_completed:
if output_aligned_by_index:
checkpoint_completed = _filter_checkpoint_indices(
checkpoint_completed, out_df, df, url_column
)
before = len(completed_indices)
completed_indices |= {idx for idx in checkpoint_completed if 0 <= idx < len(out_df)}
if len(completed_indices) != before:
print(
f"[resume] loaded checkpoint: {len(completed_indices)}/{len(out_df)} completed"
)
else:
print("[resume] ignore checkpoint due to row mismatch with previous output")
    # Rows with a blank URL are skipped outright and never analyzed
if blank_url_indices:
completed_indices |= blank_url_indices
print(f"[run] skip blank `{url_column}` rows: {len(blank_url_indices)}")
pending_indices = [int(idx) for idx in out_df.index if int(idx) not in completed_indices]
if not pending_indices:
print("[resume] no pending rows, nothing to do")
_save_progress(out_df, output_path, checkpoint_path, completed_indices)
return
print(
f"[run] total={len(out_df)}, completed={len(completed_indices)}, pending={len(pending_indices)}"
)
workers = max(1, workers)
checkpoint_every = max(1, checkpoint_every)
processed_since_checkpoint = 0
executor = ThreadPoolExecutor(max_workers=workers)
futures = []
try:
for idx in pending_indices:
row = df.iloc[idx]
futures.append(
executor.submit(
_analyze_one,
idx,
row,
url_column,
provider,
extract_lyrics,
label_level,
metadata_columns,
)
)
pending_futures = set(futures)
while pending_futures:
done, pending_futures = wait(
pending_futures,
timeout=1.0,
return_when=FIRST_COMPLETED,
)
if not done:
continue
for future in done:
idx, result = future.result()
for k, v in result.items():
out_df.at[idx, k] = v
if result:
completed_indices.add(int(idx))
processed_since_checkpoint += 1
if processed_since_checkpoint >= checkpoint_every:
_save_progress(out_df, output_path, checkpoint_path, completed_indices)
processed_since_checkpoint = 0
except KeyboardInterrupt:
print("[interrupt] received keyboard interrupt, saving checkpoint...")
try:
_save_progress(out_df, output_path, checkpoint_path, completed_indices)
except Exception as exc:
print(f"[interrupt] failed to save checkpoint: {type(exc).__name__}: {exc}")
for future in futures:
future.cancel()
executor.shutdown(wait=False, cancel_futures=True)
print("[interrupt] force exit to avoid blocking on running worker threads")
os._exit(130)
finally:
try:
executor.shutdown(wait=True, cancel_futures=False)
except Exception:
pass
_save_progress(out_df, output_path, checkpoint_path, completed_indices)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Batch audio analysis from xlsx")
parser.add_argument("--input", required=True, help="input xlsx path")
parser.add_argument("--output", required=True, help="output xlsx path")
parser.add_argument(
"--checkpoint",
default="",
help="checkpoint json path (default: <output>.checkpoint.json)",
)
parser.add_argument("--url-column", default="URL", help="url column name")
parser.add_argument("--provider", default="qwen", choices=["qwen", "doubao"])
parser.add_argument("--extract-lyrics", action="store_true", help="enable lyrics extraction")
parser.add_argument("--label-level", type=int, default=0, choices=[0, 1])
parser.add_argument(
"--metadata-columns",
default="tmeID,歌曲名称,歌曲名,歌手,表演者,版本,词作者,曲作者",
help="comma separated metadata columns",
)
parser.add_argument("--workers", type=int, default=3, help="parallel workers")
parser.add_argument(
"--checkpoint-every",
type=int,
default=10,
help="save checkpoint every N processed rows",
)
parser.add_argument(
"--no-resume",
action="store_true",
help="disable resume from existing output/checkpoint",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
metadata_columns = [c.strip() for c in args.metadata_columns.split(",") if c.strip()]
run_batch(
input_path=Path(args.input),
output_path=Path(args.output),
checkpoint_path=Path(args.checkpoint) if args.checkpoint.strip() else None,
url_column=args.url_column,
provider=args.provider,
extract_lyrics=args.extract_lyrics,
label_level=args.label_level,
metadata_columns=metadata_columns,
workers=args.workers,
checkpoint_every=args.checkpoint_every,
resume=not args.no_resume,
)
if __name__ == "__main__":
main()
openai>=1.58.1
requests>=2.31.0
httpx>=0.28.1
python-dotenv>=1.0.1
pydantic-settings>=2.6.1
numpy>=1.24.0
scipy>=1.10.0
librosa>=0.10.2
soundfile>=0.12.1
pandas>=2.2.0
openpyxl>=3.1.2
# Optional: enable funasr backend in qwen_analyzer
# dashscope>=1.20.0