Centralize logic of postprocess memory to util #886
Conversation
Walkthrough

This change introduces a centralized function for postprocessing memories, enhancing code maintainability and readability. It includes operations like memory status checking, audio processing, file uploading, speech recognition handling, and result storing. The dependencies are updated with `opuslib==3.0.1`.
Files selected (4)
- backend/requirements.txt (1)
- backend/routers/postprocessing.py (2)
- backend/routers/transcribe.py (2)
- backend/utils/memories/postprocess_memory.py (1)
Files not summarized due to errors (1)
- backend/routers/postprocessing.py (diff tokens exceeds limit)
Files skipped from review due to trivial changes (1)
- backend/requirements.txt
Review comments generated (9)
- Review: 8
- LGTM: 1
Tips
Chat with Entelligence.AI Bot (@Entelligence.AI)
- Reply on review comments left by this bot to ask follow-up questions. A review comment is a comment on a diff or a file.
- Invite the bot into a review comment chain by tagging @Entelligence.AI in a reply.
Code suggestions
- The bot may make code suggestions, but please review them carefully before committing since the line number ranges may be misaligned.
- You can edit the comment made by the bot and manually tweak the suggestion if it is slightly off.
Pausing incremental reviews
- Add `@Entelligence.AI: ignore` anywhere in the PR description to pause further reviews from the bot.
```diff
@@ -15,7 +15,7 @@
 from models.memory import Memory, TranscriptSegment
 from models.message_event import NewMemoryCreated, MessageEvent, NewProcessingMemoryCreated
 from models.processing_memory import ProcessingMemory
-from routers.postprocessing import postprocess_memory_util
+from utils.memories.postprocess_memory import postprocess_memory as postprocess_memory_util
```
The import statement has been updated to reflect the new location of the `postprocess_memory` function. This change seems fine as long as the function `postprocess_memory` is indeed moved to `utils.memories.postprocess_memory`. Please ensure that this file and function exist at the specified location.
- from routers.postprocessing import postprocess_memory_util
+ from utils.memories.postprocess_memory import postprocess_memory as postprocess_memory_util
```python
def postprocess_memory(memory_id: str, file_path: str, uid: str, emotional_feedback: bool, streaming_model: str):
    memory_data = _get_memory_by_id(uid, memory_id)
    if not memory_data:
        return (404, "Memory not found")

    memory = Memory(**memory_data)
    if memory.discarded:
        print('postprocess_memory: Memory is discarded')
        return (400, "Memory is discarded")

    if memory.postprocessing is not None and memory.postprocessing.status != PostProcessingStatus.not_started:
        print(f'postprocess_memory: Memory can\'t be post-processed again {memory.postprocessing.status}')
        return (400, "Memory can't be post-processed again")

    aseg = AudioSegment.from_wav(file_path)
    if aseg.duration_seconds < 10:  # TODO: validate duration more accurately, segment.last.end - segment.first.start - 10
        # TODO: fix app, sometimes audio uploaded is wrong, is too short.
        print('postprocess_memory: Audio duration is too short, seems wrong.')
        memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.canceled)
        return (500, "Audio duration is too short, seems wrong.")

    memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.in_progress)

    try:
        # Calling VAD to avoid processing empty parts and getting hallucinations from whisper.
        vad_segments = vad_is_empty(file_path, return_segments=True)
        if vad_segments:
            start = vad_segments[0]['start']
            end = vad_segments[-1]['end']
            aseg = AudioSegment.from_wav(file_path)
            aseg = aseg[max(0, (start - 1) * 1000):min((end + 1) * 1000, aseg.duration_seconds * 1000)]
            aseg.export(file_path, format="wav")
    except Exception as e:
        print(e)

    try:
        aseg = AudioSegment.from_wav(file_path)
        signed_url = upload_postprocessing_audio(file_path)
        threading.Thread(target=_delete_postprocessing_audio, args=(file_path,)).start()

        if aseg.frame_rate == 16000 and get_user_store_recording_permission(uid):
            upload_memory_recording(file_path, uid, memory_id)

        speakers_count = len(set([segment.speaker for segment in memory.transcript_segments]))
        words = fal_whisperx(signed_url, speakers_count)
        fal_segments = fal_postprocessing(words, aseg.duration_seconds)

        # if new transcript is 90% shorter than the original, cancel post-processing, smth wrong with audio or FAL
        count = len(''.join([segment.text.strip() for segment in memory.transcript_segments]))
        new_count = len(''.join([segment.text.strip() for segment in fal_segments]))
        print('Prev characters count:', count, 'New characters count:', new_count)

        fal_failed = not fal_segments or new_count < (count * 0.85)

        if fal_failed:
            _handle_segment_embedding_matching(uid, file_path, memory.transcript_segments, aseg)
        else:
            _handle_segment_embedding_matching(uid, file_path, fal_segments, aseg)

        # Store both models results.
        memories_db.store_model_segments_result(uid, memory.id, streaming_model, memory.transcript_segments)
        memories_db.store_model_segments_result(uid, memory.id, 'fal_whisperx', fal_segments)

        if not fal_failed:
            memory.transcript_segments = fal_segments

        memories_db.upsert_memory(uid, memory.dict())  # Store transcript segments at least if smth fails later
        if fal_failed:
            # TODO: FAL fails too much and is fucking expensive. Remove it.
            fail_reason = 'FAL empty segments' if not fal_segments else f'FAL transcript too short ({new_count} vs {count})'
            memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=fail_reason)
            memory.postprocessing = MemoryPostProcessing(
                status=PostProcessingStatus.failed, model=PostProcessingModel.fal_whisperx)
            # TODO: consider doing process_memory, if any segment still matched to user or people
            return (200, memory)

        # Reprocess memory with improved transcription
        result: Memory = process_memory(uid, memory.language, memory, force_process=True)

        # Process users emotion, async
        if emotional_feedback:
            asyncio.run(_process_user_emotion(uid, memory.language, memory, [signed_url]))
    except Exception as e:
        print(e)
        memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=str(e))
        return (500, str(e))

    memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.completed)
    result.postprocessing = MemoryPostProcessing(
        status=PostProcessingStatus.completed, model=PostProcessingModel.fal_whisperx)

    return (200, result)
```
The `postprocess_memory` function is quite long and does a lot of different things, which makes it hard to understand and maintain. It would be better to split this function into smaller ones, each with a single responsibility. For example, you could have separate functions for audio processing, speech recognition, and result storing.
- def postprocess_memory(memory_id: str, file_path: str, uid: str, emotional_feedback: bool, streaming_model: str):
+ def postprocess_memory(memory_id: str, file_path: str, uid: str, emotional_feedback: bool, streaming_model: str):
+ memory = get_memory(memory_id, uid)
+ if not validate_memory(memory):
+ return (400, "Invalid memory")
+
+ audio_segment = process_audio(file_path)
+ if not audio_segment:
+ return (500, "Audio processing failed")
+
+ transcript_segments = transcribe_audio(audio_segment, uid, memory_id)
+ if not transcript_segments:
+ return (500, "Speech recognition failed")
+
+ store_results(uid, memory_id, streaming_model, transcript_segments)
+ return (200, "Success")
This way, each function can be tested and maintained independently, making the code more modular and easier to understand.
```python
print(e)
memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=str(e))
```
Error handling in this function is done by printing the exception and returning an error status. This might not be sufficient for debugging purposes as the stack trace is lost. Consider logging the full exception using a logging library, which can provide more context when errors occur.
- print(e)
+ import logging
+ logging.exception("An error occurred during postprocessing.")
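To illustrate the difference, here is a small runnable sketch of the suggestion: `logger.exception` records the full stack trace, which `print(e)` discards. The logger name `postprocessing` and the `transcode` stand-in are illustrative, not taken from the codebase.

```python
import io
import logging

# Capture log output in memory so the example is self-contained.
log_stream = io.StringIO()
handler = logging.StreamHandler(log_stream)
logger = logging.getLogger("postprocessing")
logger.addHandler(handler)

def transcode():  # hypothetical stand-in for a failing postprocessing step
    raise ValueError("bad audio")

try:
    transcode()
except Exception:
    # logger.exception logs at ERROR level and appends the traceback automatically
    logger.exception("An error occurred during postprocessing.")

print(log_stream.getvalue())
```

The captured output contains both the message and the `ValueError` traceback, so the failure point survives into the logs.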
```python
time.sleep(300)  # 5 min
delete_postprocessing_audio(file_path)
os.remove(file_path)
```
The `_delete_postprocessing_audio` function uses `time.sleep(300)` to delay the deletion of the audio file. This could potentially block the execution of other tasks in the same thread. Consider using a non-blocking way to schedule this task, such as using `asyncio.sleep` in an asynchronous function or scheduling the deletion in a separate thread.
- time.sleep(300) # 5 min
- delete_postprocessing_audio(file_path)
- os.remove(file_path)
+ async def _delete_postprocessing_audio(file_path):
+ await asyncio.sleep(300) # 5 min
+ delete_postprocessing_audio(file_path)
+ os.remove(file_path)
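As a thread-based alternative to the async sketch above, `threading.Timer` runs the delayed deletion on its own timer thread, so neither the caller nor an event loop is blocked. The function name and default delay below are illustrative:

```python
import os
import tempfile
import threading

def schedule_audio_cleanup(file_path: str, delay_seconds: float = 300.0) -> threading.Timer:
    """Delete file_path after delay_seconds without blocking the caller."""
    def _delete():
        if os.path.exists(file_path):
            os.remove(file_path)
    timer = threading.Timer(delay_seconds, _delete)
    timer.daemon = True  # don't keep the process alive just for cleanup
    timer.start()
    return timer

# Usage with a short delay for demonstration:
fd, path = tempfile.mkstemp(suffix=".wav")
os.close(fd)
timer = schedule_audio_cleanup(path, delay_seconds=0.1)
timer.join()  # wait for the demo; production code would not join
```

Returning the `Timer` also gives the caller a `cancel()` handle, which the current fire-and-forget thread lacks.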
```python
async def _process_user_emotion(uid: str, language_code: str, memory: Memory, urls: [str]):
    if not any(segment.is_user for segment in memory.transcript_segments):
        print(f"_process_user_emotion skipped for {memory.id}")
        return

    process_user_emotion(uid, language_code, memory, urls)
```
The `_process_user_emotion` function is marked as `async` but it doesn't use any `await` expressions inside. If the `process_user_emotion` function is not asynchronous, there's no need to mark `_process_user_emotion` as `async`. If `process_user_emotion` is indeed asynchronous and requires waiting, you should use `await` before calling it.
- process_user_emotion(uid, language_code, memory, urls)
+ await process_user_emotion(uid, language_code, memory, urls)
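A small demonstration of why the `await` matters: calling an async function without `await` only creates a coroutine object and never runs its body. The stub below stands in for `process_user_emotion` under the assumption that it is async.

```python
import asyncio

async def process_user_emotion_stub(uid: str) -> str:  # hypothetical stand-in
    return f"processed {uid}"

async def without_await(uid: str):
    return process_user_emotion_stub(uid)  # coroutine created, body never executed

async def with_await(uid: str):
    return await process_user_emotion_stub(uid)

pending = asyncio.run(without_await("u1"))
print(asyncio.iscoroutine(pending))  # True: the work never happened
pending.close()  # silence the "coroutine was never awaited" RuntimeWarning

print(asyncio.run(with_await("u1")))  # processed u1
```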
```python
def _handle_segment_embedding_matching(uid: str, file_path: str, segments: List[TranscriptSegment], aseg: AudioSegment):
    if aseg.frame_rate == 16000:
        matches = get_speech_profile_matching_predictions(uid, file_path, [s.dict() for s in segments])
        for i, segment in enumerate(segments):
            segment.is_user = matches[i]['is_user']
            segment.person_id = matches[i].get('person_id')
```
In the `_handle_segment_embedding_matching` function, you're directly modifying the `segments` list that is passed as an argument. This can lead to unexpected side effects if the caller doesn't expect the list to be modified. Consider creating a new list with the updated segments instead of modifying the original one.
- for i, segment in enumerate(segments):
- segment.is_user = matches[i]['is_user']
- segment.person_id = matches[i].get('person_id')
+ updated_segments = []
+ for i, segment in enumerate(segments):
+ updated_segment = segment.copy()
+ updated_segment.is_user = matches[i]['is_user']
+ updated_segment.person_id = matches[i].get('person_id')
+ updated_segments.append(updated_segment)
+ return updated_segments
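A runnable sketch of the non-mutating variant, using a hypothetical frozen dataclass as a stand-in for `TranscriptSegment` (for a real pydantic model, `.copy(update=...)` would play the role of `replace()`):

```python
from dataclasses import dataclass, replace
from typing import List, Optional

@dataclass(frozen=True)
class Segment:  # illustrative stand-in for TranscriptSegment
    text: str
    is_user: bool = False
    person_id: Optional[str] = None

def apply_matches(segments: List[Segment], matches: List[dict]) -> List[Segment]:
    # Build updated copies; the caller's list and its elements stay untouched.
    return [
        replace(seg, is_user=m['is_user'], person_id=m.get('person_id'))
        for seg, m in zip(segments, matches)
    ]

original = [Segment("hello"), Segment("world")]
updated = apply_matches(original, [{'is_user': True, 'person_id': 'p1'}, {'is_user': False}])
print(updated[0].is_user, original[0].is_user)  # True False
```

The caller then assigns the returned list explicitly, which makes the data flow visible at the call site.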
```python
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from pydub import AudioSegment

import database.memories as memories_db
from database.users import get_user_store_recording_permission
from models.memory import *
from routers.memories import _get_memory_by_id
from utils.memories.process_memory import process_memory, process_user_emotion
from utils.memories.postprocess_memory import postprocess_memory as postprocess_memory_util
from utils.other import endpoints as auth
```
The import statements have been significantly reduced, which suggests that the logic has been moved to the `postprocess_memory_util` function. This is a good step towards modularity and maintainability. However, it's important to ensure that all necessary dependencies are imported in the new location of the code.
- import asyncio
- import os
- import threading
- import time
- from pydub import AudioSegment
- import database.memories as memories_db
- from database.users import get_user_store_recording_permission
- from routers.memories import _get_memory_by_id
- from utils.memories.process_memory import process_memory, process_user_emotion
- from utils.other.storage import upload_postprocessing_audio, \
- delete_postprocessing_audio, upload_memory_recording
- from utils.stt.pre_recorded import fal_whisperx, fal_postprocessing
- from utils.stt.speech_profile import get_speech_profile_matching_predictions
- from utils.stt.vad import vad_is_empty
+ from utils.memories.postprocess_memory import postprocess_memory as postprocess_memory_util
```python
# Save file
file_path = f"_temp/{memory_id}_{file.filename}"
with open(file_path, 'wb') as f:
    f.write(file.file.read())

aseg = AudioSegment.from_wav(file_path)
if aseg.duration_seconds < 10:  # TODO: validate duration more accurately, segment.last.end - segment.first.start - 10
    # TODO: fix app, sometimes audio uploaded is wrong, is too short.
    print('postprocess_memory: Audio duration is too short, seems wrong.')
    memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.canceled)
    raise HTTPException(status_code=500, detail="Audio duration is too short, seems wrong.")

memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.in_progress)

try:
    # Calling VAD to avoid processing empty parts and getting hallucinations from whisper.
    vad_segments = vad_is_empty(file_path, return_segments=True)
    if vad_segments:
        start = vad_segments[0]['start']
        end = vad_segments[-1]['end']
        aseg = AudioSegment.from_wav(file_path)
        aseg = aseg[max(0, (start - 1) * 1000):min((end + 1) * 1000, aseg.duration_seconds * 1000)]
        aseg.export(file_path, format="wav")
except Exception as e:
    print(e)

try:
    aseg = AudioSegment.from_wav(file_path)
    signed_url = upload_postprocessing_audio(file_path)
    threading.Thread(target=_delete_postprocessing_audio, args=(file_path,)).start()

    if aseg.frame_rate == 16000 and get_user_store_recording_permission(uid):
        upload_memory_recording(file_path, uid, memory_id)

    speakers_count = len(set([segment.speaker for segment in memory.transcript_segments]))
    words = fal_whisperx(signed_url, speakers_count)
    fal_segments = fal_postprocessing(words, aseg.duration_seconds)

    # if new transcript is 90% shorter than the original, cancel post-processing, smth wrong with audio or FAL
    count = len(''.join([segment.text.strip() for segment in memory.transcript_segments]))
    new_count = len(''.join([segment.text.strip() for segment in fal_segments]))
    print('Prev characters count:', count, 'New characters count:', new_count)

    fal_failed = not fal_segments or new_count < (count * 0.85)

    if fal_failed:
        _handle_segment_embedding_matching(uid, file_path, memory.transcript_segments, aseg)
    else:
        _handle_segment_embedding_matching(uid, file_path, fal_segments, aseg)

    # Store both models results.
    memories_db.store_model_segments_result(uid, memory.id, 'deepgram_streaming', memory.transcript_segments)
    memories_db.store_model_segments_result(uid, memory.id, 'fal_whisperx', fal_segments)

    if not fal_failed:
        memory.transcript_segments = fal_segments

    memories_db.upsert_memory(uid, memory.dict())  # Store transcript segments at least if smth fails later
    if fal_failed:
        # TODO: FAL fails too much and is fucking expensive. Remove it.
        fail_reason = 'FAL empty segments' if not fal_segments else f'FAL transcript too short ({new_count} vs {count})'
        memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=fail_reason)
        memory.postprocessing = MemoryPostProcessing(
            status=PostProcessingStatus.failed, model=PostProcessingModel.fal_whisperx, fail_reason=fail_reason,
        )
        # TODO: consider doing process_memory, if any segment still matched to user or people
        return memory

    # Reprocess memory with improved transcription
    result: Memory = process_memory(uid, memory.language, memory, force_process=True)

    # Process users emotion, async
    if emotional_feedback:
        asyncio.run(_process_user_emotion(uid, memory.language, memory, [signed_url]))
except Exception as e:
    print(e)
    memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=str(e))
    raise HTTPException(status_code=500, detail=str(e))

# Process
status_code, result = postprocess_memory_util(memory_id=memory_id, uid=uid, file_path=file_path, emotional_feedback=emotional_feedback, streaming_model="deepgram_streaming")
if status_code != 200:
    raise HTTPException(status_code=status_code, detail=result)

memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.completed)
result.postprocessing = MemoryPostProcessing(
    status=PostProcessingStatus.completed, model=PostProcessingModel.fal_whisperx)
return result
```
The changes here seem to simplify the function by moving most of the logic to `postprocess_memory_util`. This makes the function easier to read and understand. However, it's crucial to ensure that the `postprocess_memory_util` function handles all the edge cases and error conditions that were previously handled here.
- memory_data = _get_memory_by_id(uid, memory_id)
- memory = Memory(**memory_data)
- if memory.discarded:
- print('postprocess_memory: Memory is discarded')
- raise HTTPException(status_code=400, detail="Memory is discarded")
-
- if memory.postprocessing is not None and memory.postprocessing.status != PostProcessingStatus.not_started:
- print(f'postprocess_memory: Memory can\'t be post-processed again {memory.postprocessing.status}')
- raise HTTPException(status_code=400, detail="Memory can't be post-processed again")
-
- file_path = f"_temp/{memory_id}_{file.filename}"
- with open(file_path, 'wb') as f:
- f.write(file.file.read())
-
- aseg = AudioSegment.from_wav(file_path)
- if aseg.duration_seconds < 10: # TODO: validate duration more accurately, segment.last.end - segment.first.start - 10
- # TODO: fix app, sometimes audio uploaded is wrong, is too short.
- print('postprocess_memory: Audio duration is too short, seems wrong.')
- memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.canceled)
- raise HTTPException(status_code=500, detail="Audio duration is too short, seems wrong.")
-
- memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.in_progress)
-
- try:
- # Calling VAD to avoid processing empty parts and getting hallucinations from whisper.
- vad_segments = vad_is_empty(file_path, return_segments=True)
- if vad_segments:
- start = vad_segments[0]['start']
- end = vad_segments[-1]['end']
- aseg = AudioSegment.from_wav(file_path)
- aseg = aseg[max(0, (start - 1) * 1000):min((end + 1) * 1000, aseg.duration_seconds * 1000)]
- aseg.export(file_path, format="wav")
- except Exception as e:
- print(e)
-
- try:
- aseg = AudioSegment.from_wav(file_path)
- signed_url = upload_postprocessing_audio(file_path)
- threading.Thread(target=_delete_postprocessing_audio, args=(file_path,)).start()
-
- if aseg.frame_rate == 16000 and get_user_store_recording_permission(uid):
- upload_memory_recording(file_path, uid, memory_id)
-
- speakers_count = len(set([segment.speaker for segment in memory.transcript_segments]))
- words = fal_whisperx(signed_url, speakers_count)
- fal_segments = fal_postprocessing(words, aseg.duration_seconds)
-
- # if new transcript is 90% shorter than the original, cancel post-processing, smth wrong with audio or FAL
- count = len(''.join([segment.text.strip() for segment in memory.transcript_segments]))
- new_count = len(''.join([segment.text.strip() for segment in fal_segments]))
- print('Prev characters count:', count, 'New characters count:', new_count)
-
- fal_failed = not fal_segments or new_count < (count * 0.85)
-
- if fal_failed:
- _handle_segment_embedding_matching(uid, file_path, memory.transcript_segments, aseg)
- else:
- _handle_segment_embedding_matching(uid, file_path, fal_segments, aseg)
-
- # Store both models results.
- memories_db.store_model_segments_result(uid, memory.id, 'deepgram_streaming', memory.transcript_segments)
- memories_db.store_model_segments_result(uid, memory.id, 'fal_whisperx', fal_segments)
-
- if not fal_failed:
- memory.transcript_segments = fal_segments
-
- memories_db.upsert_memory(uid, memory.dict()) # Store transcript segments at least if smth fails later
- if fal_failed:
- # TODO: FAL fails too much and is fucking expensive. Remove it.
- fail_reason = 'FAL empty segments' if not fal_segments else f'FAL transcript too short ({new_count} vs {count})'
- memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=fail_reason)
- memory.postprocessing = MemoryPostProcessing(
- status=PostProcessingStatus.failed, model=PostProcessingModel.fal_whisperx, fail_reason=fail_reason,
- )
- # TODO: consider doing process_memory, if any segment still matched to user or people
- return memory
-
- # Reprocess memory with improved transcription
- result: Memory = process_memory(uid, memory.language, memory, force_process=True)
-
- # Process users emotion, async
- if emotional_feedback:
- asyncio.run(_process_user_emotion(uid, memory.language, memory, [signed_url]))
- except Exception as e:
- print(e)
- memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.failed, fail_reason=str(e))
- raise HTTPException(status_code=500, detail=str(e))
-
- memories_db.set_postprocessing_status(uid, memory.id, PostProcessingStatus.completed)
- result.postprocessing = MemoryPostProcessing(
- status=PostProcessingStatus.completed, model=PostProcessingModel.fal_whisperx)
+ # Save file
+ file_path = f"_temp/{memory_id}_{file.filename}"
+ with open(file_path, 'wb') as f:
+ f.write(file.file.read())
+
+ # Process
+ status_code, result = postprocess_memory_util(memory_id=memory_id, uid=uid, file_path=file_path, emotional_feedback=emotional_feedback, streaming_model="deepgram_streaming")
+ if status_code != 200:
+ raise HTTPException(status_code=status_code, detail=result)
+
+ return result
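The `(status_code, payload)` contract between the util and the router can be exercised in isolation. The stub util and the `HTTPError` class below are illustrative stand-ins, not the project's code (FastAPI's `HTTPException` plays the `HTTPError` role in the real router):

```python
from typing import Tuple, Union

def postprocess_memory_util_stub(memory_id: str) -> Tuple[int, Union[str, dict]]:
    # Hypothetical util: returns a status code plus either an error detail or a result.
    if not memory_id:
        return (404, "Memory not found")
    return (200, {"id": memory_id, "status": "completed"})

class HTTPError(Exception):  # stand-in for fastapi.HTTPException
    def __init__(self, status_code: int, detail: str):
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail

def postprocess_endpoint(memory_id: str) -> dict:
    # Router-side pattern: map any non-200 status onto an HTTP error.
    status_code, result = postprocess_memory_util_stub(memory_id)
    if status_code != 200:
        raise HTTPError(status_code, result)
    return result

print(postprocess_endpoint("abc"))  # {'id': 'abc', 'status': 'completed'}
```

Keeping the util HTTP-agnostic like this lets non-HTTP callers (such as the transcribe flow) consume the same tuple without importing FastAPI.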
Summary by Entelligence.AI

- Introduced a centralized `postprocess_memory` function to streamline the memory postprocessing workflow, including memory status checks, audio processing, file uploads, speech recognition, and result storage.
- Moved the postprocessing logic into `postprocess_memory.py`, improving code maintainability.
- Added `opuslib==3.0.1` to the project dependencies to support audio processing tasks.
- Updated `_get_memory_by_id`, `_delete_postprocessing_audio`, and `_handle_segment_embedding_matching` to work seamlessly with the new postprocessing flow.