Skip to content

Commit

Permalink
修改--NLP一律使用
Browse files Browse the repository at this point in the history
修改--先使用DeepL API再使用Gemini API
修改--blockReason: SAFETY時直接調用DeepL結果
  • Loading branch information
Arcelibs committed Dec 27, 2023
1 parent 35a0104 commit 76fe99d
Showing 1 changed file with 39 additions and 51 deletions.
90 changes: 39 additions & 51 deletions Local-Windows/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


# 初始化 Faster Whisper 模型
model = WhisperModel("large-v2", device="cuda", compute_type="int8") # 使用适合您硬件的模型大小和设备
model = WhisperModel("large-v2", device="cuda", compute_type="int8")

#從api_key.txt取得api
def get_api_key_from_file(file_path='api_key.txt'):
Expand Down Expand Up @@ -55,58 +55,49 @@ def call_deepl_api(input_text, deepL_auth_key, target_lang='ZH'):

# 呼叫Gemini的函數
def call_gemini_api(input_text):
    """Translate ``input_text`` with DeepL first, then refine it with Gemini.

    The text is sent to DeepL with ``target_lang='ZH'`` (Simplified Chinese,
    per the original comments). That intermediate result is then sent to the
    Gemini ``generateContent`` endpoint with a prompt asking for Traditional
    Chinese. When Gemini refuses for safety reasons, the DeepL result is
    returned as-is instead of failing.

    Args:
        input_text: The text to translate.

    Returns:
        The text of Gemini's first candidate; the DeepL translation when
        Gemini reports a ``SAFETY`` block; or ``None`` when an API key is
        missing, DeepL fails, the HTTP request fails, or the response holds
        no usable candidate.
    """
    # Step 1: DeepL. Without its key or its result there is nothing to refine.
    deepl_auth_key = get_deepl_api_key_from_file()
    if not deepl_auth_key:
        print("无法获取 DeepL API KEY。")
        return None

    simplified_chinese_text = call_deepl_api(input_text, deepl_auth_key, target_lang='ZH')
    if not simplified_chinese_text:
        print("DeepL 翻译失败。")
        return None

    # Step 2: Gemini, using the key read from file.
    api_key = get_api_key_from_file()
    if not api_key:
        print("Gemini API KEY 无效")
        return None

    url = "https://palm-proxy.arcelibs.com/v1beta/models/gemini-pro:generateContent?key={}".format(api_key)
    headers = {'Content-Type': 'application/json'}

    formatted_input = f"請你將下列語句翻譯成繁體中文,忽略任何語句與文法問題:\n{simplified_chinese_text}"
    data = json.dumps({"contents": [{"parts": [{"text": formatted_input}]}]})

    # A timeout keeps a stalled proxy from hanging the caller indefinitely;
    # network failures are reported the same way as HTTP errors (None).
    try:
        response = requests.post(url, headers=headers, data=data, timeout=60)
    except requests.RequestException as exc:
        print(f"错误: {exc}")
        return None

    # Check the status before parsing: an error page is not guaranteed to be
    # JSON, and parsing it first would hide the status code behind a
    # decoding exception.
    if response.status_code != 200:
        print(f"错误: {response.status_code}")
        return None

    try:
        response_data = response.json()
    except ValueError:
        print("错误: 无法解析 Gemini 响应")
        return None

    if not isinstance(response_data, dict):
        print("KeyError: 'candidates' not found in response.")
        return None

    # The Gemini API reports a blocked prompt under
    # promptFeedback.blockReason; the top-level key is still honoured in case
    # the proxy flattens the response.
    prompt_feedback = response_data.get('promptFeedback')
    if not isinstance(prompt_feedback, dict):
        prompt_feedback = {}
    block_reasons = (response_data.get('blockReason'), prompt_feedback.get('blockReason'))
    if 'SAFETY' in block_reasons:
        # Blocked for safety: fall back to the DeepL translation.
        return simplified_chinese_text

    candidates = response_data.get('candidates')
    if not candidates:
        print("KeyError: 'candidates' not found in response.")
        return None

    first_candidate = candidates[0]
    try:
        return first_candidate['content']['parts'][0]['text']
    except (KeyError, IndexError, TypeError):
        # A candidate stopped for safety carries no text; treat it like a
        # blocked prompt and keep the DeepL translation.
        if isinstance(first_candidate, dict) and first_candidate.get('finishReason') == 'SAFETY':
            return simplified_chinese_text
        print("错误: Gemini 响应中没有可用的翻译文本")
        return None


# 使用 yt-dlp 獲取 YouTube 直播媒體位置
Expand Down Expand Up @@ -161,7 +152,7 @@ def save_transcription(file_path, text, is_api_response=False):
file.write(text)

# 新增語言檢測機制並加載模型
def detect_and_split_text(text, max_length):
def detect_and_split_text(text, max_length=500):
language = detect(text)
nlp = None
if language == "en":
Expand All @@ -171,7 +162,7 @@ def detect_and_split_text(text, max_length):
elif language in ["zh-cn", "zh-tw"]:
nlp = spacy.load("zh_core_web_sm")
else:
# 無法辨識?就用標點符號分段
# 如果无法识别语言,则根据标点符号分段
return split_text_on_punctuation(text, max_length)

return split_text_natural(nlp, text, max_length)
Expand All @@ -185,7 +176,7 @@ def split_text_natural(nlp, text, max_length):
if len(current_segment + sent.text) <= max_length:
current_segment += sent.text + " "
else:
if current_segment: # 空字串處理
if current_segment: # 非空字符串
segments.append(current_segment.strip())
current_segment = sent.text + " "
if current_segment:
Expand All @@ -194,10 +185,8 @@ def split_text_natural(nlp, text, max_length):

# 不是中英日文? 就用標點符號來分類
def split_text_on_punctuation(text, max_length):
# 定義分類的標點符號
punctuations = ".!?\n"
pattern = f"[{re.escape(punctuations)}]"

segments = []
current_segment = ""
for word in re.split(pattern, text):
Expand All @@ -209,7 +198,6 @@ def split_text_on_punctuation(text, max_length):
current_segment = word
if current_segment:
segments.append(current_segment.strip())

return segments

# 主流程
Expand Down Expand Up @@ -239,7 +227,7 @@ def main(segment_duration, total_duration):
time.sleep(segment_duration)

# 配置參數
SEGMENT_DURATION = 10 # 每一段錄製的長度,單位是秒
SEGMENT_DURATION = 15 # 每一段錄製的長度,單位是秒
TOTAL_DURATION = 6000 # 總錄製時間,單位是秒

if __name__ == "__main__":
Expand Down

0 comments on commit 76fe99d

Please sign in to comment.