Merged
6 changes: 4 additions & 2 deletions apps/common/config/embedding_config.py
@@ -23,8 +23,10 @@ def get_model(_id, get_model):
     model_instance = ModelManage.cache.get(_id)
     if model_instance is None:
         with _lock:
-            model_instance = get_model(_id)
-            ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8)
+            model_instance = ModelManage.cache.get(_id)
+            if model_instance is None:
+                model_instance = get_model(_id)
+                ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8)
     else:
         if model_instance.is_cache_model():
             ModelManage.cache.touch(_id, timeout=60 * 60 * 8)
Contributor Author
This change applies double-checked locking to avoid loading the same model twice under concurrency: the cache is re-read after `_lock` is acquired, and only a thread that still sees a miss calls the loader. A few potential improvements remain:

  1. Avoid unnecessary work: retrieve and verify the cached instance before touching its TTL, rather than renewing it unconditionally on every call.

  2. Lock management: make sure `_lock` is initialized once (e.g. as a class attribute) and shared by all callers; a lock created per call provides no protection.

  3. Naming consistency: the parameter `get_model` shadows the function of the same name; a distinct name such as `model_loader` is clearer. Cache keys should likewise follow one consistent naming convention.

Here's a revised version of the code (`model_loader` is the renamed parameter; the 60-second timeout is illustrative):

import threading
from concurrent.futures import ThreadPoolExecutor, TimeoutError


class ModelManage:
    cache = ...  # existing cache backend
    _lock = threading.Lock()  # one shared lock for all callers


def get_model(_id, model_loader):
    model_instance = ModelManage.cache.get(_id)

    if model_instance is None:
        with ModelManage._lock:
            # Re-check after acquiring the lock: another thread may have
            # populated the cache while we were waiting.
            model_instance = ModelManage.cache.get(_id)
            if model_instance is None:
                # Bound the load time so a slow provider cannot hang the request.
                with ThreadPoolExecutor(max_workers=1) as executor:
                    future = executor.submit(model_loader, _id)
                    try:
                        model_instance = future.result(timeout=60)
                    except TimeoutError:
                        raise Exception(
                            f"Model retrieval took longer than expected for {_id}"
                        )
                ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8)
    elif model_instance.is_cache_model():
        # Renew the TTL only for instances that are eligible for caching.
        ModelManage.cache.touch(_id, timeout=60 * 60 * 8)

    return model_instance

Key Improvements:

  • Double-checked re-read: the cache is checked again inside the lock, so concurrent callers that miss simultaneously trigger only one load.
  • Timeout handling: Future.result(timeout=60) prevents a request from hanging indefinitely on a slow model provider.
  • No shadowing: renaming the parameter to model_loader keeps it distinct from the get_model function itself.

Note:

  • _lock must be a single shared lock (a class attribute, as above), not one created per call.
  • Consider catching and handling the actual exceptions model_loader may raise, not just the timeout.
  • Adjust is_cache_model() according to how your validation logic determines eligibility for TTL renewal.
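The timeout behaviour relied on above can be exercised in isolation: `Future.result(timeout=...)` raises `concurrent.futures.TimeoutError` when the worker has not finished in time. A small sketch, with a hypothetical `slow_loader` standing in for a slow model provider:

```python
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError

def slow_loader(_id):
    time.sleep(2)  # simulate a slow model provider
    return f"model-{_id}"

executor = ThreadPoolExecutor(max_workers=1)
future = executor.submit(slow_loader, "embedding")
try:
    # Give up long before the loader finishes.
    future.result(timeout=0.1)
    outcome = "loaded"
except TimeoutError:
    outcome = "timed out"
executor.shutdown(wait=True)
print(outcome)  # → timed out
```

Note that the timeout only stops the caller from waiting; the worker thread itself keeps running to completion, so a timed-out load still consumes a pool slot until it finishes.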
