# Source code for llm_batch_helper.providers

import asyncio
import os
from typing import Any, Dict, List, Optional, Tuple, Union
from datetime import datetime
import warnings

import httpx
import openai
import google.generativeai as genai
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential, before_sleep_log
from tqdm.asyncio import tqdm_asyncio

from .cache import LLMCache
from .config import LLMConfig
from .input_handlers import get_prompts


def _run_async_function(async_func, *args, **kwargs):
    """
    Run an async function in a way that works in both regular Python and Jupyter notebooks.
    
    This handles the event loop management properly for different environments.
    """
    try:
        # Try to get the current event loop
        loop = asyncio.get_running_loop()
        # If we're in a running loop (like Jupyter), we need to use nest_asyncio
        try:
            import nest_asyncio
            nest_asyncio.apply()
            return asyncio.run(async_func(*args, **kwargs))
        except ImportError:
            # If nest_asyncio is not available, try to run in the current loop
            # This is a fallback that might work in some cases
            import concurrent.futures
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future = executor.submit(asyncio.run, async_func(*args, **kwargs))
                return future.result()
    except RuntimeError:
        # No event loop running, we can use asyncio.run directly
        return asyncio.run(async_func(*args, **kwargs))


def log_retry_attempt(retry_state):
    """Custom logging function for retry attempts.

    Used as the ``before_sleep`` callback of the ``@retry`` decorators in this
    module: prints which attempt just failed, the exception type, an HTTP
    status code when one can be found on the exception, a truncated error
    message, and how long the retry will wait.

    Args:
        retry_state: Retry call state object. This function reads
            ``attempt_number``, ``outcome.exception()`` and
            ``next_action.sleep`` from it, plus (optionally)
            ``retry_object.stop.max_attempt_number``.

    Returns:
        None. Output goes to stdout via ``print``.
    """
    attempt_number = retry_state.attempt_number
    exception = retry_state.outcome.exception()
    wait_time = retry_state.next_action.sleep if retry_state.next_action else 0

    # Report the real attempt limit instead of a hard-coded number so the
    # message stays correct if a decorator uses a different stop policy.
    # Falls back to 5 (the limit used throughout this module) when the stop
    # policy does not expose ``max_attempt_number``.
    stop_policy = getattr(getattr(retry_state, "retry_object", None), "stop", None)
    max_attempts = getattr(stop_policy, "max_attempt_number", 5)

    error_type = type(exception).__name__
    error_msg = str(exception)

    # Extract status code if available
    status_code = "unknown"
    if hasattr(exception, 'status_code'):
        status_code = exception.status_code
    elif hasattr(exception, 'response') and hasattr(exception.response, 'status_code'):
        status_code = exception.response.status_code

    print(f"🔄 [{datetime.now().strftime('%H:%M:%S')}] Retry attempt {attempt_number}/{max_attempts}:")
    print(f" Error: {error_type} (status: {status_code})")
    print(f" Message: {error_msg[:100]}{'...' if len(error_msg) > 100 else ''}")
    print(f" Waiting {wait_time:.1f}s before next attempt...")
    print()
# Timeout (seconds) applied to every httpx client created by the provider calls.
_HTTP_TIMEOUT_SECONDS = 1000.0

# Keys of ``config.kwargs`` that are forwarded to Gemini's GenerationConfig;
# any other extra keyword is ignored for that provider.
_GEMINI_GENERATION_KWARGS = ("top_p", "top_k", "candidate_count")


def _build_chat_messages(prompt: str, config: LLMConfig) -> List[Dict[str, str]]:
    """Build the system + user message list used by the chat-completions APIs."""
    return [
        {"role": "system", "content": config.system_instruction},
        {"role": "user", "content": prompt},
    ]


async def _post_chat_completion(
    url: str, api_key: str, prompt: str, config: LLMConfig
) -> Dict[str, Union[str, Dict]]:
    """POST one chat-completions request with httpx and normalise the reply.

    Shared by the Together and OpenRouter providers, which send the same
    payload shape and read the same response fields.

    Args:
        url: Full chat-completions endpoint URL.
        api_key: Bearer token sent in the ``Authorization`` header.
        prompt: User prompt text.
        config: LLM configuration (model name, temperature, token limit,
            system instruction and extra ``kwargs`` merged into the payload).

    Returns:
        Dict with ``response_text`` and ``usage_details`` (token counts,
        defaulting to 0 when the response does not report them).

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
        httpx.RequestError: On transport-level failures.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": config.model_name,
        "messages": _build_chat_messages(prompt, config),
        "temperature": config.temperature,
        "max_tokens": config.max_completion_tokens,
        **config.kwargs,
    }

    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT_SECONDS) as client:
        response = await client.post(url, json=payload, headers=headers)
        response.raise_for_status()
        response_data = response.json()

    # ``or {}`` also covers an explicit ``"usage": null`` in the response body.
    usage = response_data.get("usage") or {}
    usage_details = {
        "prompt_token_count": usage.get("prompt_tokens", 0),
        "completion_token_count": usage.get("completion_tokens", 0),
        "total_token_count": usage.get("total_tokens", 0),
    }
    return {
        "response_text": response_data["choices"][0]["message"]["content"],
        "usage_details": usage_details,
    }


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (
            ConnectionError,
            TimeoutError,
            openai.APITimeoutError,
            openai.APIConnectionError,
            openai.RateLimitError,
            openai.APIError,
        )
    ),
    before_sleep=log_retry_attempt,
    reraise=True,
)
async def _get_openai_response_direct(
    prompt: str, config: LLMConfig
) -> Dict[str, Union[str, Dict]]:
    """Send one prompt to the OpenAI chat-completions API.

    Retried up to 5 times with exponential backoff on connection, timeout,
    rate-limit and OpenAI API errors. A missing API key raises ``ValueError``,
    which is not in the retry list and therefore fails immediately.

    Args:
        prompt: User prompt text.
        config: LLM configuration.

    Returns:
        Dict with ``response_text`` and ``usage_details``.

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is not set.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable not set")

    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT_SECONDS) as client:
        aclient = openai.AsyncOpenAI(api_key=api_key, http_client=client)
        response = await aclient.chat.completions.create(
            model=config.model_name,
            messages=_build_chat_messages(prompt, config),
            temperature=config.temperature,
            max_completion_tokens=config.max_completion_tokens,
            **config.kwargs,
        )

        # ``usage`` may be absent on the response object; report zeros rather
        # than discarding an otherwise valid completion.
        usage = response.usage
        usage_details = {
            "prompt_token_count": getattr(usage, "prompt_tokens", 0),
            "completion_token_count": getattr(usage, "completion_tokens", 0),
            "total_token_count": getattr(usage, "total_tokens", 0),
        }
        return {
            "response_text": response.choices[0].message.content,
            "usage_details": usage_details,
        }


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (
            ConnectionError,
            TimeoutError,
            httpx.HTTPStatusError,
            httpx.RequestError,
        )
    ),
    # Log retries like every other provider in this module does.
    before_sleep=log_retry_attempt,
    reraise=True,
)
async def _get_together_response_direct(
    prompt: str, config: LLMConfig
) -> Dict[str, Union[str, Dict]]:
    """Send one prompt to the Together chat-completions endpoint.

    Retried up to 5 times with exponential backoff on connection, timeout and
    httpx status/request errors.

    Args:
        prompt: User prompt text.
        config: LLM configuration.

    Returns:
        Dict with ``response_text`` and ``usage_details``.

    Raises:
        ValueError: If ``TOGETHER_API_KEY`` is not set (not retried).
    """
    api_key = os.environ.get("TOGETHER_API_KEY")
    if not api_key:
        raise ValueError("TOGETHER_API_KEY environment variable not set")

    return await _post_chat_completion(
        "https://api.together.xyz/chat/completions", api_key, prompt, config
    )


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (
            ConnectionError,
            TimeoutError,
            httpx.HTTPStatusError,
            httpx.RequestError,
        )
    ),
    before_sleep=log_retry_attempt,
    reraise=True,
)
async def _get_openrouter_response_direct(
    prompt: str, config: LLMConfig
) -> Dict[str, Union[str, Dict]]:
    """Send one prompt to the OpenRouter chat-completions endpoint.

    Retried up to 5 times with exponential backoff on connection, timeout and
    httpx status/request errors.

    Args:
        prompt: User prompt text.
        config: LLM configuration.

    Returns:
        Dict with ``response_text`` and ``usage_details``.

    Raises:
        ValueError: If ``OPENROUTER_API_KEY`` is not set (not retried).
    """
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        raise ValueError("OPENROUTER_API_KEY environment variable not set")

    return await _post_chat_completion(
        "https://openrouter.ai/api/v1/chat/completions", api_key, prompt, config
    )


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    # Gemini SDK may raise various exceptions, so every Exception is retried.
    retry=retry_if_exception_type(Exception),
    before_sleep=log_retry_attempt,
    reraise=True,
)
async def _generate_gemini_content(
    prompt: str, config: LLMConfig, api_key: str
) -> Dict[str, Union[str, Dict]]:
    """Call Gemini once (with retries) using an already validated API key."""
    # Configure the Gemini client
    genai.configure(api_key=api_key)

    # Create the model
    model = genai.GenerativeModel(config.model_name)

    # Prepare the prompt with system instruction if provided
    full_prompt = prompt
    if config.system_instruction and config.system_instruction.strip():
        full_prompt = f"{config.system_instruction}\n\n{prompt}"

    try:
        # generate_content is a blocking call; run it in a worker thread so
        # the event loop stays responsive.
        response = await asyncio.to_thread(
            model.generate_content,
            full_prompt,
            generation_config=genai.types.GenerationConfig(
                temperature=config.temperature,
                max_output_tokens=config.max_completion_tokens,
                **{
                    k: v
                    for k, v in config.kwargs.items()
                    if k in _GEMINI_GENERATION_KWARGS
                },
            ),
        )

        # Extract usage information if available
        usage_meta = getattr(response, "usage_metadata", None)
        usage_details = {
            "prompt_token_count": getattr(usage_meta, "prompt_token_count", 0),
            "completion_token_count": getattr(usage_meta, "candidates_token_count", 0),
            "total_token_count": getattr(usage_meta, "total_token_count", 0),
        }

        return {
            "response_text": response.text,
            "usage_details": usage_details,
        }
    except Exception as e:
        # Handle potential safety blocks or other Gemini-specific errors
        if hasattr(e, 'message') and 'block' in str(e).lower():
            return {
                "response_text": "[Content blocked by safety filters]",
                "usage_details": {
                    "prompt_token_count": 0,
                    "completion_token_count": 0,
                    "total_token_count": 0,
                },
                "blocked": True,
            }
        # Bare raise keeps the original traceback intact.
        raise


async def _get_gemini_response_direct(
    prompt: str, config: LLMConfig
) -> Dict[str, Union[str, Dict]]:
    """Send one prompt to Gemini.

    The API key is validated here, outside the retried helper: the Gemini
    retry policy retries every ``Exception``, so checking the key inside it
    would repeat a pure configuration error 5 times with backoff.

    Args:
        prompt: User prompt text.
        config: LLM configuration. Only ``top_p``, ``top_k`` and
            ``candidate_count`` from ``config.kwargs`` are forwarded.

    Returns:
        Dict with ``response_text`` and ``usage_details``; when the error
        looks like a safety block, a placeholder text plus ``"blocked": True``.

    Raises:
        ValueError: If neither ``GEMINI_API_KEY`` nor ``GOOGLE_API_KEY`` is set.
    """
    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY or GOOGLE_API_KEY environment variable not set")

    return await _generate_gemini_content(prompt, config, api_key)
async def get_llm_response_with_internal_retry(
    prompt_id: str,
    prompt: str,
    config: LLMConfig,
    provider: str,
    cache: Optional[LLMCache] = None,
    force: bool = False,
) -> Dict[str, Union[str, Dict]]:
    """Fetch one LLM response, consulting the cache first.

    Args:
        prompt_id: Identifier used as the cache key.
        prompt: Prompt text sent to the provider.
        config: LLM configuration passed through to the provider call.
        provider: Provider name, matched case-insensitively against
            "openai", "together", "openrouter" and "gemini".
        cache: Optional cache; read unless ``force`` is set, written after a
            successful call.
        force: If True, skip the cache lookup.

    Returns:
        The provider response dict (or the cached one). Any exception,
        including an unsupported provider name, is converted into a dict with
        ``error`` and ``provider`` keys instead of being raised.
    """
    # Serve from cache when allowed.
    if cache and not force:
        hit = cache.get_cached_response(prompt_id)
        if hit:
            return hit["llm_response"]

    try:
        # Provider name -> coroutine function performing the (retried) call.
        handlers = {
            "openai": _get_openai_response_direct,
            "together": _get_together_response_direct,
            "openrouter": _get_openrouter_response_direct,
            "gemini": _get_gemini_response_direct,
        }
        handler = handlers.get(provider.lower())
        if handler is None:
            raise ValueError(f"Unsupported provider: {provider}")

        result = await handler(prompt, config)

        # Persist only responses that do not carry an error marker.
        if cache and "error" not in result:
            cache.save_response(prompt_id, prompt, result)
        return result
    except Exception as exc:
        return {
            "error": f"LLM API call failed after internal retries: {exc!s}",
            "provider": provider,
        }
async def process_prompts_batch_async(
    prompts: Optional[List[Union[str, Tuple[str, str], Dict[str, Any]]]] = None,
    input_dir: Optional[str] = None,
    config: Optional[LLMConfig] = None,
    provider: str = "openai",
    desc: str = "Processing prompts",
    cache_dir: Optional[str] = None,
    force: bool = False,
) -> Dict[str, Dict[str, Union[str, Dict]]]:
    """Process a batch of prompts through the LLM.

    Args:
        prompts: Optional list of prompts in any supported format
            (string, tuple, or dict)
        input_dir: Optional path to directory containing prompt files
        config: LLM configuration (required, despite the ``None`` default)
        provider: LLM provider to use ("openai", "together", "openrouter",
            or "gemini")
        desc: Description for progress bar
        cache_dir: Optional directory for caching responses
        force: If True, force regeneration even if cached response exists

    Returns:
        Dict mapping prompt IDs to their responses, ordered by input sequence

    Raises:
        ValueError: If neither or both of ``prompts`` and ``input_dir`` are
            given, or if ``config`` is None.

    Note:
        Either prompts or input_dir must be provided, but not both.
        Results are returned in the same order as the input prompts.
    """
    if prompts is None and input_dir is None:
        raise ValueError("Either prompts or input_dir must be provided")
    if prompts is not None and input_dir is not None:
        raise ValueError("Cannot specify both prompts and input_dir")
    # ``config`` defaults to None only to keep the keyword-friendly signature;
    # fail early with a clear message instead of an AttributeError below.
    if config is None:
        raise ValueError("config must be provided")

    # Get prompts from either source
    prompts = get_prompts(input_dir if input_dir is not None else prompts)

    # Create semaphore for concurrent requests
    semaphore = asyncio.Semaphore(config.max_concurrent_requests)

    # Keep track of original order for sorting results
    prompt_order = {prompt_id: idx for idx, (prompt_id, _) in enumerate(prompts)}

    tasks = [
        _process_single_prompt_attempt_with_verification(
            prompt_id, prompt_text, config, provider, semaphore, cache_dir, force
        )
        for prompt_id, prompt_text in prompts
    ]

    # Collect results as they finish (completion order, not input order).
    results = {}
    for future in tqdm_asyncio(asyncio.as_completed(tasks), total=len(tasks), desc=desc):
        prompt_id, response_data = await future
        results[prompt_id] = response_data

    # Re-key in original input order: dicts preserve insertion order, so
    # inserting by ascending input index restores the caller's sequence
    # regardless of completion order.
    return {
        prompt_id: results[prompt_id]
        for prompt_id in sorted(results, key=prompt_order.__getitem__)
    }
def process_prompts_batch(
    prompts: Optional[List[Union[str, Tuple[str, str], Dict[str, Any]]]] = None,
    input_dir: Optional[str] = None,
    config: LLMConfig = None,
    provider: str = "openai",
    desc: str = "Processing prompts",
    cache_dir: Optional[str] = None,
    force: bool = False,
) -> Dict[str, Dict[str, Union[str, Dict]]]:
    """
    Process a batch of prompts through the LLM (synchronous version).

    Blocking counterpart of ``process_prompts_batch_async``: it forwards every
    argument unchanged and runs the coroutine to completion, so it can be used
    from regular scripts and Jupyter notebooks without async/await syntax.

    Args:
        prompts: Optional list of prompts in any supported format
            (string, tuple, or dict)
        input_dir: Optional path to directory containing prompt files
        config: LLM configuration
        provider: LLM provider to use ("openai", "together", "openrouter",
            or "gemini")
        desc: Description for progress bar
        cache_dir: Optional directory for caching responses
        force: If True, force regeneration even if cached response exists

    Returns:
        Dict mapping prompt IDs to their responses, ordered by input sequence

    Note:
        Either prompts or input_dir must be provided, but not both.
        Results are returned in the same order as the input prompts.

    Example:
        >>> from llm_batch_helper import LLMConfig, process_prompts_batch
        >>> config = LLMConfig(model_name="gpt-4o-mini")
        >>> results = process_prompts_batch(
        ...     prompts=["What is 2+2?", "What is the capital of France?"],
        ...     config=config,
        ...     provider="openai"
        ... )
        >>> # Results will be in the same order as input prompts
    """
    # Gather the arguments once, then hand them to the async implementation.
    forwarded = {
        "prompts": prompts,
        "input_dir": input_dir,
        "config": config,
        "provider": provider,
        "desc": desc,
        "cache_dir": cache_dir,
        "force": force,
    }
    return _run_async_function(process_prompts_batch_async, **forwarded)
async def _process_single_prompt_attempt_with_verification(
    prompt_id: str,
    prompt_text: str,
    config: LLMConfig,
    provider: str,
    semaphore: asyncio.Semaphore,
    cache_dir: Optional[str] = None,
    force: bool = False,
) -> Tuple[str, Dict[str, Any]]:
    """Process a single prompt with verification and caching.

    Holds ``semaphore`` for the whole operation. A cached response is returned
    (tagged ``"from_cache": True``) when it exists and passes the optional
    verification callback; otherwise the provider is called up to
    ``config.max_retries`` times.

    Args:
        prompt_id: Identifier of the prompt; also the cache key.
        prompt_text: Prompt text sent to the provider.
        config: LLM configuration; ``max_retries``, ``verification_callback``
            and ``verification_callback_args`` are read here.
        provider: Provider name passed to the LLM call.
        semaphore: Limits how many prompts are processed concurrently.
        cache_dir: Optional cache directory; enables cache read and write.
        force: If True, skip the cache lookup (the result is still saved).

    Returns:
        Tuple ``(prompt_id, response)``. ``response`` is the LLM response dict
        on success, or a dict with an ``error`` key when every attempt failed.
    """
    async with semaphore:
        # Check cache first if cache_dir is provided
        if cache_dir and not force:
            # A broken cache entry or a raising verification callback must not
            # escape this coroutine: the batch loop awaits every task without
            # a guard, so one bad entry would abort all prompts. Treat any
            # failure here as a cache miss and regenerate instead.
            try:
                cache = LLMCache(cache_dir)
                cached_response = cache.get_cached_response(prompt_id)
                if cached_response is not None:
                    cached_response_data = cached_response["llm_response"]

                    # If no verification callback, use cached response directly
                    if config.verification_callback is None:
                        return prompt_id, {**cached_response_data, "from_cache": True}

                    # Verify response if callback provided
                    verified = await asyncio.to_thread(
                        config.verification_callback,
                        prompt_id,
                        cached_response_data,
                        prompt_text,
                        **config.verification_callback_args,
                    )
                    if verified:
                        return prompt_id, {**cached_response_data, "from_cache": True}
            except Exception as e:
                print(
                    f"⚠️ [{datetime.now().strftime('%H:%M:%S')}] Cache lookup failed "
                    f"for prompt: {prompt_id} ({e!s}); regenerating response"
                )

        # Process the prompt
        last_exception_details = None
        for attempt in range(config.max_retries):
            if attempt > 0:
                print(f"🔁 [{datetime.now().strftime('%H:%M:%S')}] Application-level retry {attempt+1}/{config.max_retries} for prompt: {prompt_id}")

            try:
                # Get LLM response
                llm_response_data = await get_llm_response_with_internal_retry(
                    prompt_id, prompt_text, config, provider
                )

                if "error" in llm_response_data:
                    print(f"❌ [{datetime.now().strftime('%H:%M:%S')}] API call failed on attempt {attempt+1}: {llm_response_data.get('error', 'Unknown error')}")
                    last_exception_details = llm_response_data
                    if attempt < config.max_retries - 1:
                        # Exponential backoff between application retries,
                        # capped at 30 seconds.
                        wait_time = min(2 * 2**attempt, 30)
                        print(f" Waiting {wait_time}s before next application retry...")
                        await asyncio.sleep(wait_time)
                    continue

                # Verify response if callback provided
                if config.verification_callback:
                    verified = await asyncio.to_thread(
                        config.verification_callback,
                        prompt_id,
                        llm_response_data,
                        prompt_text,
                        **config.verification_callback_args,
                    )
                    if not verified:
                        last_exception_details = {
                            "error": f"Verification failed on attempt {attempt + 1}",
                            "prompt_id": prompt_id,
                            "llm_response_data": llm_response_data,
                        }
                        if attempt == config.max_retries - 1:
                            return prompt_id, last_exception_details
                        continue

                # Save to cache if cache_dir provided
                if cache_dir:
                    cache = LLMCache(cache_dir)
                    cache.save_response(prompt_id, prompt_text, llm_response_data)

                return prompt_id, llm_response_data

            except Exception as e:
                last_exception_details = {
                    "error": f"Unexpected error: {e!s}",
                    "prompt_id": prompt_id,
                }
                if attempt == config.max_retries - 1:
                    return prompt_id, last_exception_details
                # No backoff on this path: the next attempt starts immediately.
                continue

        return prompt_id, last_exception_details or {
            "error": f"Exhausted all {config.max_retries} retries for {prompt_id}"
        }