55import tempfile
66import time
77from pathlib import Path
8- from typing import Any , Callable , List , Optional
8+ from typing import Any , Callable , List , Optional , Union
99
1010from ...config import MODEL_PATH
1111from ..utils .logger import setup_logger
@@ -25,16 +25,20 @@ class WhisperCppASR(BaseASR):
2525
2626 def __init__ (
2727 self ,
28- audio_path ,
28+ audio_input : Union [ str , bytes ] ,
2929 language = "en" ,
3030 whisper_cpp_path = None ,
3131 whisper_model = None ,
3232 use_cache : bool = False ,
3333 need_word_time_stamp : bool = False ,
3434 ):
35- super ().__init__ (audio_path , use_cache )
36- assert os .path .exists (audio_path ), f"Audio file not found: { audio_path } "
37- assert audio_path .endswith (".wav" ), f"Audio must be WAV format: { audio_path } "
35+ super ().__init__ (audio_input , use_cache )
36+
37+ if isinstance (audio_input , str ):
38+ assert os .path .exists (audio_input ), f"Audio file not found: { audio_input } "
39+ assert audio_input .endswith (
40+ ".wav"
41+ ), f"Audio must be WAV format: { audio_input } "
3842
3943 # Auto-detect whisper executable if not provided
4044 if whisper_cpp_path is None :
@@ -116,13 +120,13 @@ def _default_callback(_progress: int, _message: str) -> None:
116120
117121 with tempfile .TemporaryDirectory () as temp_path :
118122 temp_dir = Path (temp_path )
119- wav_path = temp_dir / "audio .wav"
123+ wav_path = temp_dir / "whisper_cpp_audio .wav"
120124 output_path = wav_path .with_suffix (".srt" )
121125
122126 try :
123127 # 复制音频文件
124- if isinstance (self .audio_path , str ):
125- shutil .copy2 (self .audio_path , wav_path )
128+ if isinstance (self .audio_input , str ):
129+ shutil .copy2 (self .audio_input , wav_path )
126130 else :
127131 if self .file_binary :
128132 wav_path .write_bytes (self .file_binary )
@@ -136,10 +140,7 @@ def _default_callback(_progress: int, _message: str) -> None:
136140 logger .info ("Whisper.cpp command: %s" , " " .join (whisper_params ))
137141
138142 # Get audio duration
139- if isinstance (self .audio_path , str ):
140- total_duration = self .get_audio_duration (self .audio_path )
141- else :
142- total_duration = 600
143+ total_duration = self .audio_duration
143144 logger .info ("Audio duration: %d seconds" , total_duration )
144145
145146 # Start process
@@ -272,7 +273,7 @@ def detect_whisper_executable() -> str:
272273if __name__ == "__main__" :
273274 # 简短示例
274275 asr = WhisperCppASR (
275- audio_path = "audio.mp3" ,
276+ audio_input = "audio.mp3" ,
276277 whisper_model = "tiny" ,
277278 whisper_cpp_path = "bin/whisper-cpp.exe" ,
278279 language = "en" ,
0 commit comments