Skip to content

Commit 0bf4c50

Browse files
Infer optimizer (#7261)
* add Infer optimizer * Update optimizer, and rename to InferRules --------- Co-authored-by: Omar Khattab <[email protected]>
1 parent 78b50f1 commit 0bf4c50

File tree

2 files changed

+140
-0
lines changed

2 files changed

+140
-0
lines changed

dspy/teleprompt/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# from .mipro_optimizer import MIPRO
1010
from dspy.teleprompt.mipro_optimizer_v2 import MIPROv2
1111
from dspy.teleprompt.random_search import BootstrapFewShotWithRandomSearch
12+
from dspy.teleprompt.infer_rules import InferRules
1213

1314
# from .signature_opt import SignatureOptimizer
1415
# from .signature_opt_bayesian import BayesianSignatureOptimizer

dspy/teleprompt/infer_rules.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import dspy
2+
import random
3+
import numpy as np
4+
5+
from dspy.teleprompt import BootstrapFewShot
6+
from dspy.evaluate.evaluate import Evaluate
7+
8+
9+
class InferRules(BootstrapFewShot):
10+
def __init__(self, num_candidates=10, num_rules=10, num_threads=8, teacher_settings=None, **kwargs):
11+
super().__init__(teacher_settings=teacher_settings, **kwargs)
12+
13+
self.num_candidates = num_candidates
14+
self.num_rules = num_rules
15+
self.num_threads = num_threads
16+
self.rules_induction_program = RulesInductionProgram(num_rules, teacher_settings=teacher_settings)
17+
self.metric = kwargs.get("metric")
18+
self.max_errors = kwargs.get("max_errors", 10)
19+
20+
def compile(self, student, *, teacher=None, trainset, valset=None):
21+
if valset is None:
22+
train_size = int(0.5 * len(trainset))
23+
trainset, valset = trainset[:train_size], trainset[train_size:]
24+
25+
super().compile(student, teacher=teacher, trainset=trainset)
26+
27+
original_program = self.student.deepcopy()
28+
all_predictors = [p for p in original_program.predictors() if hasattr(p, "signature")]
29+
instructions_list = [p.signature.instructions for p in all_predictors]
30+
31+
best_score = -np.inf
32+
best_program = None
33+
34+
for candidate_idx in range(self.num_candidates):
35+
candidate_program = original_program.deepcopy()
36+
candidate_predictors = [p for p in candidate_program.predictors() if hasattr(p, "signature")]
37+
38+
for i, predictor in enumerate(candidate_predictors):
39+
predictor.signature.instructions = instructions_list[i]
40+
41+
for i, predictor in enumerate(candidate_predictors):
42+
rules = self.induce_natural_language_rules(predictor, trainset)
43+
predictor.signature.instructions = instructions_list[i]
44+
self.update_program_instructions(predictor, rules)
45+
46+
score = self.evaluate_program(candidate_program, valset)
47+
48+
if score > best_score:
49+
best_score = score
50+
best_program = candidate_program
51+
52+
print(f"Evaluated Candidate {candidate_idx+1} with score {score}. Current best score: {best_score}")
53+
54+
print("Final best score:", best_score)
55+
56+
return best_program
57+
58+
def induce_natural_language_rules(self, predictor, trainset):
59+
demos = self.get_predictor_demos(trainset, predictor)
60+
signature = predictor.signature
61+
while True:
62+
examples_text = self.format_examples(demos, signature)
63+
try:
64+
return self.rules_induction_program(examples_text)
65+
except Exception as e:
66+
assert (
67+
isinstance(e, ValueError)
68+
or e.__class__.__name__ == "BadRequestError"
69+
or "ContextWindowExceededError" in str(e)
70+
)
71+
if len(demos) > 1:
72+
demos = demos[:-1]
73+
else:
74+
raise RuntimeError(
75+
"Failed to generate natural language rules since a single example couldn't fit in the model's context window."
76+
) from e
77+
78+
def update_program_instructions(self, predictor, natural_language_rules):
79+
predictor.signature.instructions = (
80+
f"{predictor.signature.instructions}\n\n"
81+
f"Please adhere to the following rules when making your prediction:\n{natural_language_rules}"
82+
)
83+
84+
def format_examples(self, demos, signature):
85+
examples_text = ""
86+
for demo in demos:
87+
input_fields = {k: v for k, v in demo.items() if k in signature.input_fields}
88+
output_fields = {k: v for k, v in demo.items() if k in signature.output_fields}
89+
input_text = "\n".join(f"{k}: {v}" for k, v in input_fields.items())
90+
output_text = "\n".join(f"{k}: {v}" for k, v in output_fields.items())
91+
examples_text += f"Input Fields:\n{input_text}\n\n=========\nOutput Fields:\n{output_text}\n\n"
92+
return examples_text
93+
94+
def get_predictor_demos(self, trainset, predictor):
95+
# TODO: Consider how this handled "incomplete" demos.
96+
signature = predictor.signature
97+
return [
98+
{
99+
key: value
100+
for key, value in example.items()
101+
if key in signature.input_fields or key in signature.output_fields
102+
}
103+
for example in trainset
104+
]
105+
106+
def evaluate_program(self, program, dataset):
107+
evaluate = Evaluate(
108+
devset=dataset,
109+
metric=self.metric,
110+
num_threads=self.num_threads,
111+
max_errors=self.max_errors,
112+
display_table=False,
113+
display_progress=True,
114+
return_all_scores=True,
115+
)
116+
score, _ = evaluate(program, metric=self.metric)
117+
return score
118+
119+
120+
class RulesInductionProgram(dspy.Module):
121+
def __init__(self, num_rules, teacher_settings=None):
122+
super().__init__()
123+
124+
class CustomRulesInduction(dspy.Signature):
125+
__doc__ = f"""Given a set of examples, extract a list of {num_rules} concise and non-redundant natural language rules that provide clear guidance for performing the task. All rules should be actionable for a well-specified scope of examples of this general kind of task."""
126+
examples_text = dspy.InputField(desc="Text containing examples")
127+
natural_language_rules = dspy.OutputField(desc="Induced natural language rules")
128+
129+
self.rules_induction = dspy.ChainOfThought(CustomRulesInduction)
130+
self.teacher_settings = teacher_settings or {}
131+
self.rng = random.Random(0)
132+
133+
def forward(self, examples_text):
134+
with dspy.settings.context(**self.teacher_settings):
135+
lm = dspy.settings.lm.copy(temperature=self.rng.uniform(0.9, 1.0))
136+
with dspy.settings.context(lm=lm):
137+
rules = self.rules_induction(examples_text=examples_text).natural_language_rules
138+
139+
return rules.strip()

0 commit comments

Comments
 (0)