 except ImportError:
     TEXTGRAD_AVAILABLE = False
 
+try:
+    import litellm
+
+    LITELLM_AVAILABLE = True
+except ImportError:
+    LITELLM_AVAILABLE = False
+
 
 class ModelAdapter(ABC):
     """
@@ -81,6 +88,36 @@ def generate_with_chat_format(
         """
         pass
 
+    def generate_batch(
+        self, prompts: List[str], max_threads: int = 1, **kwargs
+    ) -> List[str]:
+        """
+        Generate responses for multiple prompts, optionally in parallel.
+
+        This method is useful for optimizers that need to evaluate multiple
+        candidates simultaneously (e.g., PDO duels, batch evaluation).
+
+        Args:
+            prompts: List of input prompts
+            max_threads: Maximum number of threads for parallel execution
+            **kwargs: Generation parameters (temperature, max_tokens, etc.)
+
+        Returns:
+            List of generated responses, in the same order as the input prompts
+        """
+        if max_threads <= 1:
+            # Sequential execution
+            return [self.generate(prompt, **kwargs) for prompt in prompts]
+
+        # Parallel execution; prompt order is preserved because results are
+        # collected from the futures list in submission order
+        import concurrent.futures
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_threads) as executor:
+            futures = [
+                executor.submit(self.generate, prompt, **kwargs) for prompt in prompts
+            ]
+            return [future.result() for future in futures]
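+
+    # Usage sketch (hypothetical model name; illustrative only, not part of
+    # this change):
+    #
+    #     adapter = LiteLLMModelAdapter(model_name="openai/gpt-4o-mini")
+    #     answers = adapter.generate_batch(
+    #         ["What is 2+2?", "Name a prime number."], max_threads=2
+    #     )
+    #     assert len(answers) == 2  # answers[i] pairs with prompts[i]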
+
 
 class DSPyModelAdapter(ModelAdapter):
     """
@@ -334,6 +371,186 @@ def generate_with_chat_format(
         return response.text
 
 
+class LiteLLMModelAdapter(ModelAdapter):
+    """
+    Lightweight adapter using LiteLLM for simple text generation.
+
+    Provides a clean "prompt in, string out" interface without
+    framework overhead. Ideal for optimization strategies that
+    don't need DSPy's advanced features.
+    """
+
+    def __init__(
+        self,
+        model_name: str = None,
+        api_base: str = None,
+        api_key: str = None,
+        max_tokens: int = 4096,
+        temperature: float = 0.0,
+        **kwargs,
+    ):
+        """
+        Initialize the LiteLLM model adapter with configuration parameters.
+
+        Args:
+            model_name: The model identifier (e.g., "openrouter/meta-llama/llama-3.3-70b-instruct")
+            api_base: The API base URL
+            api_key: The API key
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            **kwargs: Additional arguments to pass to litellm.completion
+        """
+        if not LITELLM_AVAILABLE:
+            raise ImportError(
+                "LiteLLM is not installed. Install it with `pip install litellm`"
+            )
+
+        # Store configuration
+        self.model_name = model_name
+        self.api_base = api_base
+        self.api_key = api_key
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.kwargs = kwargs
+
+        # Set up environment variables for LiteLLM if needed
+        if api_key:
+            self._setup_api_key(model_name, api_key)
+        if api_base:
+            self._setup_api_base(model_name, api_base)
+
+    def _setup_api_key(self, model_name: str, api_key: str):
+        """Set the appropriate environment variable based on the model provider."""
+        model_lower = model_name.lower() if model_name else ""
+
+        if "openai" in model_lower and "openrouter" not in model_lower:
+            os.environ["OPENAI_API_KEY"] = api_key
+        elif "anthropic" in model_lower or "claude" in model_lower:
+            os.environ["ANTHROPIC_API_KEY"] = api_key
+        elif "openrouter" in model_lower:
+            os.environ["OPENROUTER_API_KEY"] = api_key
+        elif "together" in model_lower:
+            os.environ["TOGETHER_API_KEY"] = api_key
+        # Add more providers as needed
+
+    def _setup_api_base(self, model_name: str, api_base: str):
+        """Set the appropriate base URL environment variable."""
+        model_lower = model_name.lower() if model_name else ""
+
+        if "openai" in model_lower and "openrouter" not in model_lower:
+            os.environ["OPENAI_API_BASE"] = api_base
+        elif "openrouter" in model_lower:
+            os.environ["OPENROUTER_API_BASE"] = api_base
+        # Add more as needed
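+
+    # For instance (a sketch; the key value is hypothetical), an
+    # OpenRouter-style model name is routed to OPENROUTER_API_KEY
+    # by the dispatch above:
+    #
+    #     LiteLLMModelAdapter(
+    #         model_name="openrouter/meta-llama/llama-3.3-70b-instruct",
+    #         api_key="sk-...",
+    #     )
+    #     # -> os.environ["OPENROUTER_API_KEY"] == "sk-..."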
+
+    def generate(
+        self, prompt: str, temperature: float = None, max_tokens: int = None, **kwargs
+    ) -> str:
+        """
+        Generate text from a prompt using LiteLLM.
+
+        Args:
+            prompt: The input prompt text
+            temperature: Override the default temperature
+            max_tokens: Override the default max tokens
+            **kwargs: Additional generation parameters
+
+        Returns:
+            The generated text response
+        """
+        # Use override values or fall back to the instance defaults
+        temp = temperature if temperature is not None else self.temperature
+        tokens = max_tokens if max_tokens is not None else self.max_tokens
+
+        # Wrap the plain prompt as a single-turn chat message
+        messages = [{"role": "user", "content": prompt}]
+
+        # Filter out DSPy-specific parameters that LiteLLM doesn't understand
+        filtered_kwargs = {
+            k: v
+            for k, v in self.kwargs.items()
+            if k not in ["cache", "model"]  # Remove DSPy-specific params
+        }
+
+        # Prepare kwargs for litellm
+        litellm_kwargs = {
+            "model": self.model_name,
+            "messages": messages,
+            "temperature": temp,
+            "max_tokens": tokens,
+            **filtered_kwargs,
+            **kwargs,
+        }
+
+        # Add API base if specified
+        if self.api_base:
+            litellm_kwargs["api_base"] = self.api_base
+
+        try:
+            response = litellm.completion(**litellm_kwargs)
+
+            # Extract the generated text from the first choice
+            return response.choices[0].message.content
+
+        except Exception:
+            # Re-raise unchanged; convert to standard error types here if needed
+            raise
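+
+    # Example call (a minimal sketch; assumes the provider API key is already
+    # configured and the model name is hypothetical):
+    #
+    #     adapter = LiteLLMModelAdapter(model_name="openai/gpt-4o-mini")
+    #     text = adapter.generate("Summarize RLHF in one sentence.", temperature=0.7)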
+
+    def generate_with_chat_format(
+        self,
+        messages: List[Dict[str, str]],
+        temperature: float = None,
+        max_tokens: int = None,
+        **kwargs,
+    ) -> str:
+        """
+        Generate text using a chat format with multiple messages.
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content' keys
+            temperature: Override the default temperature
+            max_tokens: Override the default max tokens
+            **kwargs: Additional generation parameters
+
+        Returns:
+            The generated text response
+        """
+        # Use override values or fall back to the instance defaults
+        temp = temperature if temperature is not None else self.temperature
+        tokens = max_tokens if max_tokens is not None else self.max_tokens
+
+        # Filter out DSPy-specific parameters that LiteLLM doesn't understand
+        filtered_kwargs = {
+            k: v
+            for k, v in self.kwargs.items()
+            if k not in ["cache", "model"]  # Remove DSPy-specific params
+        }
+
+        # Prepare kwargs for litellm
+        litellm_kwargs = {
+            "model": self.model_name,
+            "messages": messages,
+            "temperature": temp,
+            "max_tokens": tokens,
+            **filtered_kwargs,
+            **kwargs,
+        }
+
+        # Add API base if specified
+        if self.api_base:
+            litellm_kwargs["api_base"] = self.api_base
+
+        try:
+            response = litellm.completion(**litellm_kwargs)
+
+            # Extract the generated text from the first choice
+            return response.choices[0].message.content
+
+        except Exception:
+            # Re-raise unchanged; convert to standard error types here if needed
+            raise
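+
+    # Example chat-format call (sketch; messages follow the usual
+    # role/content schema shown in the docstring above):
+    #
+    #     reply = adapter.generate_with_chat_format([
+    #         {"role": "system", "content": "You are terse."},
+    #         {"role": "user", "content": "Define entropy."},
+    #     ])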
+
+
 def setup_model(model_name=None, adapter_type="dspy", **kwargs):
     """
     Set up a model adapter using the specified adapter type.
@@ -343,7 +560,7 @@ def setup_model(model_name=None, adapter_type="dspy", **kwargs):
 
     Args:
         model_name: The model identifier (e.g., "openai/gpt-4o-mini", "anthropic/claude-3-opus-20240229")
-        adapter_type: The adapter type to use ("dspy" or "textgrad")
+        adapter_type: The adapter type to use ("dspy", "textgrad", or "litellm")
         **kwargs: Additional adapter-specific configuration options
 
     Returns:
@@ -394,6 +611,14 @@ def setup_model(model_name=None, adapter_type="dspy", **kwargs):
         logger.progress(
             f" Using model with TextGrad: {kwargs.get('model_name', 'custom configuration')}"
         )
+    elif adapter_type.lower() == "litellm":
+        # For LiteLLM, use model_name directly
+        if model_name and "model_name" not in kwargs:
+            kwargs["model_name"] = model_name
+        adapter = LiteLLMModelAdapter(**kwargs)
+        logger.progress(
+            f" Using model with LiteLLM: {kwargs.get('model_name', 'custom configuration')}"
+        )
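+        # e.g. (sketch; model name is illustrative):
+        #
+        #     setup_model(
+        #         "openrouter/meta-llama/llama-3.3-70b-instruct",
+        #         adapter_type="litellm",
+        #         temperature=0.2,
+        #     )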
     else:
         raise ValueError(f"Unsupported adapter type: {adapter_type}")
 
@@ -405,7 +630,7 @@ def get_model_adapter(adapter_type, **kwargs):
     Get a model adapter instance by type.
 
     Args:
-        adapter_type: The adapter type ("dspy" or "textgrad")
+        adapter_type: The adapter type ("dspy", "textgrad", or "litellm")
         **kwargs: Configuration parameters for the adapter
 
     Returns: