Commit b937782

Merge pull request #53 from meta-llama/pdo
Add Prompt Duel Optimizer (PDO) - Label-Free Prompt Optimization
2 parents 670efa1 + 72eb96c commit b937782

File tree

14 files changed: +3800 -9 lines changed

README.md

Lines changed: 18 additions & 4 deletions
@@ -1,5 +1,17 @@
 <h1 align="center"> Prompt Ops </h1>
 
+### 🎉 New: Prompt Duel Optimizer (PDO) Published!
+
+We've published a new paper on **PDO (Prompt Duel Optimizer)** - an efficient label-free prompt optimization method using dueling bandits and Thompson sampling. PDO achieves state-of-the-art results on BIG-bench Hard and MS MARCO benchmarks.
+
+📄 **Read the paper:** [LLM Prompt Duel Optimizer: Efficient Label-Free Prompt Optimization](https://www.arxiv.org/abs/2510.13907) (arXiv:2510.13907)
+
+🧪 **Try it yourself:** Check out the [Web of Lies use case](use-cases/web-of-lies-pdo/) demonstrating PDO on logical reasoning tasks
+
+**Star this repo** and follow along - we'll be publishing a detailed tutorial notebook soon!
+
+---
+
 ## What is prompt-ops?
 <p align="center">
 <a href="https://pypi.org/project/prompt-ops/"><img src="https://img.shields.io/pypi/v/prompt-ops.svg" /></a>
@@ -85,19 +97,21 @@ These results were measured on the [HotpotQA multi-hop reasoning benchmark](http
 
 ### Step 1: Installation
 
+> **Note:** We recommend installing from source as we are currently transitioning package names on PyPI. This ensures you get the latest stable version without any naming conflicts.
+
 ```bash
 # Create a virtual environment
 conda create -n prompt-ops python=3.10
 conda activate prompt-ops
 
-# Install from PyPI
-pip install prompt-ops
-
-# OR install from source
+# Recommended: Install from source
 git clone https://github.com/meta-llama/prompt-ops.git
 cd prompt-ops
 pip install -e .
 
+# Alternative: Install from PyPI (may have naming transition issues, still on version 0.0.7)
+# pip install llama-prompt-ops
+
 ```
 
 ### Step 2: Create a sample project

src/prompt_ops/core/model.py

Lines changed: 227 additions & 2 deletions
@@ -31,6 +31,13 @@
 except ImportError:
     TEXTGRAD_AVAILABLE = False
 
+try:
+    import litellm
+
+    LITELLM_AVAILABLE = True
+except ImportError:
+    LITELLM_AVAILABLE = False
+
 
 class ModelAdapter(ABC):
     """
@@ -81,6 +88,36 @@ def generate_with_chat_format(
         """
         pass
 
+    def generate_batch(
+        self, prompts: List[str], max_threads: int = 1, **kwargs
+    ) -> List[str]:
+        """
+        Generate responses for multiple prompts, optionally in parallel.
+
+        This method is useful for optimizers that need to evaluate multiple
+        candidates simultaneously (e.g., PDO duels, batch evaluation).
+
+        Args:
+            prompts: List of input prompts
+            max_threads: Maximum number of threads for parallel execution
+            **kwargs: Generation parameters (temperature, max_tokens, etc.)
+
+        Returns:
+            List of generated responses in same order as input prompts
+        """
+        if max_threads <= 1:
+            # Sequential execution
+            return [self.generate(prompt, **kwargs) for prompt in prompts]
+
+        # Parallel execution
+        import concurrent.futures
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_threads) as executor:
+            futures = [
+                executor.submit(self.generate, prompt, **kwargs) for prompt in prompts
+            ]
+            return [future.result() for future in futures]
+
 
 class DSPyModelAdapter(ModelAdapter):
     """
@@ -334,6 +371,186 @@ def generate_with_chat_format(
         return response.text
 
 
+class LiteLLMModelAdapter(ModelAdapter):
+    """
+    Lightweight adapter using LiteLLM for simple text generation.
+
+    Provides a clean "prompt in, string out" interface without
+    framework overhead. Ideal for optimization strategies that
+    don't need DSPy's advanced features.
+    """
+
+    def __init__(
+        self,
+        model_name: str = None,
+        api_base: str = None,
+        api_key: str = None,
+        max_tokens: int = 4096,
+        temperature: float = 0.0,
+        **kwargs,
+    ):
+        """
+        Initialize the LiteLLM model adapter with configuration parameters.
+
+        Args:
+            model_name: The model identifier (e.g., "openrouter/meta-llama/llama-3.3-70b-instruct")
+            api_base: The API base URL
+            api_key: The API key
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            **kwargs: Additional arguments to pass to litellm.completion
+        """
+        if not LITELLM_AVAILABLE:
+            raise ImportError(
+                "LiteLLM is not installed. Install it with `pip install litellm`"
+            )
+
+        # Store configuration
+        self.model_name = model_name
+        self.api_base = api_base
+        self.api_key = api_key
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.kwargs = kwargs
+
+        # Set up environment variables for LiteLLM if needed
+        if api_key:
+            self._setup_api_key(model_name, api_key)
+        if api_base:
+            self._setup_api_base(model_name, api_base)
+
+    def _setup_api_key(self, model_name: str, api_key: str):
+        """Set appropriate environment variable based on model provider."""
+        model_lower = model_name.lower() if model_name else ""
+
+        if "openai" in model_lower and "openrouter" not in model_lower:
+            os.environ["OPENAI_API_KEY"] = api_key
+        elif "anthropic" in model_lower or "claude" in model_lower:
+            os.environ["ANTHROPIC_API_KEY"] = api_key
+        elif "openrouter" in model_lower:
+            os.environ["OPENROUTER_API_KEY"] = api_key
+        elif "together" in model_lower:
+            os.environ["TOGETHER_API_KEY"] = api_key
+        # Add more providers as needed
+
+    def _setup_api_base(self, model_name: str, api_base: str):
+        """Set appropriate base URL environment variable."""
+        model_lower = model_name.lower() if model_name else ""
+
+        if "openai" in model_lower and "openrouter" not in model_lower:
+            os.environ["OPENAI_API_BASE"] = api_base
+        elif "openrouter" in model_lower:
+            os.environ["OPENROUTER_API_BASE"] = api_base
+        # Add more as needed
+
+    def generate(
+        self, prompt: str, temperature: float = None, max_tokens: int = None, **kwargs
+    ) -> str:
+        """
+        Generate text from a prompt using LiteLLM.
+
+        Args:
+            prompt: The input prompt text
+            temperature: Override the default temperature
+            max_tokens: Override the default max tokens
+            **kwargs: Additional generation parameters
+
+        Returns:
+            The generated text response
+        """
+        # Use override values or defaults
+        temp = temperature if temperature is not None else self.temperature
+        tokens = max_tokens if max_tokens is not None else self.max_tokens
+
+        # Prepare LiteLLM call
+        messages = [{"role": "user", "content": prompt}]
+
+        # Filter out DSPy-specific parameters that LiteLLM doesn't understand
+        filtered_kwargs = {
+            k: v
+            for k, v in self.kwargs.items()
+            if k not in ["cache", "model"]  # Remove DSPy-specific params
+        }
+
+        # Prepare kwargs for litellm
+        litellm_kwargs = {
+            "model": self.model_name,
+            "messages": messages,
+            "temperature": temp,
+            "max_tokens": tokens,
+            **filtered_kwargs,
+            **kwargs,
+        }
+
+        # Add API base if specified
+        if self.api_base:
+            litellm_kwargs["api_base"] = self.api_base
+
+        try:
+            response = litellm.completion(**litellm_kwargs)
+
+            # Extract text from response
+            return response.choices[0].message.content
+
+        except Exception as e:
+            # Convert to our standard error types if needed
+            raise e
+
+    def generate_with_chat_format(
+        self,
+        messages: List[Dict[str, str]],
+        temperature: float = None,
+        max_tokens: int = None,
+        **kwargs,
+    ) -> str:
+        """
+        Generate text using a chat format with multiple messages.
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content' keys
+            temperature: Override the default temperature
+            max_tokens: Override the default max tokens
+            **kwargs: Additional generation parameters
+
+        Returns:
+            The generated text response
+        """
+        # Use override values or defaults
+        temp = temperature if temperature is not None else self.temperature
+        tokens = max_tokens if max_tokens is not None else self.max_tokens
+
+        # Filter out DSPy-specific parameters that LiteLLM doesn't understand
+        filtered_kwargs = {
+            k: v
+            for k, v in self.kwargs.items()
+            if k not in ["cache", "model"]  # Remove DSPy-specific params
+        }
+
+        # Prepare kwargs for litellm
+        litellm_kwargs = {
+            "model": self.model_name,
+            "messages": messages,
+            "temperature": temp,
+            "max_tokens": tokens,
+            **filtered_kwargs,
+            **kwargs,
+        }
+
+        # Add API base if specified
+        if self.api_base:
+            litellm_kwargs["api_base"] = self.api_base
+
+        try:
+            response = litellm.completion(**litellm_kwargs)
+
+            # Extract text from response
+            return response.choices[0].message.content
+
+        except Exception as e:
+            # Convert to our standard error types if needed
+            raise e
+
+
 def setup_model(model_name=None, adapter_type="dspy", **kwargs):
     """
     Set up a model adapter using the specified adapter type.
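
The adapter above is constructed directly with a model identifier plus optional `api_key`/`api_base`, and exposes the same `generate` and `generate_with_chat_format` interface as the other adapters. A small sketch, assuming `litellm` is installed and the package is importable as `prompt_ops.core.model`; the concrete model id, key placeholder, and messages are illustrative.

```python
# Sketch of direct LiteLLMModelAdapter use; values are placeholders, while the
# constructor arguments and methods are the ones defined in the diff above.
from prompt_ops.core.model import LiteLLMModelAdapter

adapter = LiteLLMModelAdapter(
    model_name="openrouter/meta-llama/llama-3.3-70b-instruct",
    api_key="<your-openrouter-key>",  # routed to OPENROUTER_API_KEY by _setup_api_key
    max_tokens=512,
    temperature=0.0,
)

# Plain "prompt in, string out" call
print(adapter.generate("List two properties of dueling bandits."))

# Chat-format call with a system message and a per-call max_tokens override
print(
    adapter.generate_with_chat_format(
        [
            {"role": "system", "content": "You are a terse assistant."},
            {"role": "user", "content": "Explain Thompson sampling in one sentence."},
        ],
        max_tokens=128,
    )
)
```
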
@@ -343,7 +560,7 @@ def setup_model(model_name=None, adapter_type="dspy", **kwargs):
 
     Args:
         model_name: The model identifier (e.g., "openai/gpt-4o-mini", "anthropic/claude-3-opus-20240229")
-        adapter_type: The adapter type to use ("dspy" or "textgrad")
+        adapter_type: The adapter type to use ("dspy", "textgrad", or "litellm")
         **kwargs: Additional adapter-specific configuration options
 
     Returns:
@@ -394,6 +611,14 @@ def setup_model(model_name=None, adapter_type="dspy", **kwargs):
         logger.progress(
             f" Using model with TextGrad: {kwargs.get('model_name', 'custom configuration')}"
         )
+    elif adapter_type.lower() == "litellm":
+        # For LiteLLM, use model_name directly
+        if model_name and "model_name" not in kwargs:
+            kwargs["model_name"] = model_name
+        adapter = LiteLLMModelAdapter(**kwargs)
+        logger.progress(
+            f" Using model with LiteLLM: {kwargs.get('model_name', 'custom configuration')}"
+        )
     else:
         raise ValueError(f"Unsupported adapter type: {adapter_type}")
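
With this branch, passing `adapter_type="litellm"` to `setup_model` forwards `model_name` (and any extra kwargs such as `api_key` or `temperature`) into `LiteLLMModelAdapter`. A brief sketch; only `setup_model` and the `"litellm"` adapter type come from the diff, the concrete values are placeholders.

```python
# Sketch: selecting the LiteLLM adapter via setup_model.
from prompt_ops.core.model import setup_model

adapter = setup_model(
    model_name="openrouter/meta-llama/llama-3.3-70b-instruct",
    adapter_type="litellm",
    api_key="<your-api-key>",  # forwarded to the provider-specific env var
    temperature=0.2,
)
print(adapter.generate("Say hello in one word."))
```
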

@@ -405,7 +630,7 @@ def get_model_adapter(adapter_type, **kwargs):
     Get a model adapter instance by type.
 
     Args:
-        adapter_type: The adapter type ("dspy" or "textgrad")
+        adapter_type: The adapter type ("dspy", "textgrad", or "litellm")
        **kwargs: Configuration parameters for the adapter
 
     Returns:
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+"""
+PDO (Prompt Duel Optimizer) module.
+
+This module provides a dueling bandit optimization approach for prompt optimization
+using Thompson sampling, multiple ranking systems, and reflective prompt evolution.
+"""
+
+from .optimization_engine import PDOEngine
+from .ranking_systems import (
+    TrueSkillFromCounts,
+    avg_winrate_ranking,
+    borda_ranking,
+    copeland_ranking,
+    elo_ranking,
+)
+from .thompson_sampling import sample_duel_pair
+
+# Legacy imports removed; mutation now handled inside optimization engine.
+
+__all__ = [
+    "PDOEngine",
+    "sample_duel_pair",
+    "copeland_ranking",
+    "borda_ranking",
+    "avg_winrate_ranking",
+    "elo_ranking",
+    "TrueSkillFromCounts",
+]
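
The module docstring above summarizes the approach: candidate prompts are compared in pairwise duels, Thompson sampling chooses which pair to duel next, and ranking systems (Copeland, Borda, Elo, TrueSkill, average win-rate) aggregate the outcomes. The repository's own `sample_duel_pair` and `PDOEngine` signatures are not shown in this diff, so the sketch below is a generic, self-contained Beta-Bernoulli illustration of the idea, not the package's implementation.

```python
# Generic illustration of Thompson-sampling duel selection and Copeland ranking.
# This is NOT the prompt-ops implementation; all names and logic here are standalone.
import random
from collections import defaultdict

def sample_next_duel(wins, losses, n_candidates):
    """Pick the two candidates with the highest sampled win-rate (Beta posterior)."""
    scores = [
        random.betavariate(1 + wins[i], 1 + losses[i]) for i in range(n_candidates)
    ]
    ranked = sorted(range(n_candidates), key=lambda i: scores[i], reverse=True)
    return ranked[0], ranked[1]

def copeland_scores(duel_wins):
    """Copeland score: how many opponents a candidate beats in head-to-head counts."""
    candidates = set(i for pair in duel_wins for i in pair)
    scores = defaultdict(int)
    for a in candidates:
        for b in candidates:
            if a != b and duel_wins.get((a, b), 0) > duel_wins.get((b, a), 0):
                scores[a] += 1
    return dict(scores)

# Toy loop: in PDO an LLM judge would pick the duel winner label-free;
# here the outcome is random just to exercise the bookkeeping.
n = 4
wins, losses = [0] * n, [0] * n
duel_wins = defaultdict(int)
for _ in range(50):
    a, b = sample_next_duel(wins, losses, n)
    winner, loser = (a, b) if random.random() < 0.5 else (b, a)
    wins[winner] += 1
    losses[loser] += 1
    duel_wins[(winner, loser)] += 1
print(copeland_scores(duel_wins))
```
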

0 commit comments