import dspy
import random
import numpy as np

from dspy.teleprompt import BootstrapFewShot
from dspy.evaluate.evaluate import Evaluate


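# InferRules extends BootstrapFewShot: after bootstrapping few-shot demos, it
# induces natural-language rules from the training examples for each predictor,
# appends those rules to the predictor's instructions, and keeps whichever of
# the sampled candidate programs scores best on a held-out validation set.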
class InferRules(BootstrapFewShot):
    def __init__(self, num_candidates=10, num_rules=10, num_threads=8, teacher_settings=None, **kwargs):
        super().__init__(teacher_settings=teacher_settings, **kwargs)

        self.num_candidates = num_candidates
        self.num_rules = num_rules
        self.num_threads = num_threads
        self.rules_induction_program = RulesInductionProgram(num_rules, teacher_settings=teacher_settings)
        self.metric = kwargs.get("metric")
        self.max_errors = kwargs.get("max_errors", 10)

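    # Run BootstrapFewShot compilation first, then evaluate `num_candidates`
    # rule-augmented copies of the program on `valset` (by default the second
    # half of `trainset`) and return the highest-scoring copy.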
    def compile(self, student, *, teacher=None, trainset, valset=None):
        if valset is None:
            train_size = int(0.5 * len(trainset))
            trainset, valset = trainset[:train_size], trainset[train_size:]

        super().compile(student, teacher=teacher, trainset=trainset)

        original_program = self.student.deepcopy()
        all_predictors = [p for p in original_program.predictors() if hasattr(p, "signature")]
        instructions_list = [p.signature.instructions for p in all_predictors]

        best_score = -np.inf
        best_program = None

        for candidate_idx in range(self.num_candidates):
            candidate_program = original_program.deepcopy()
            candidate_predictors = [p for p in candidate_program.predictors() if hasattr(p, "signature")]

            for i, predictor in enumerate(candidate_predictors):
                predictor.signature.instructions = instructions_list[i]

            for i, predictor in enumerate(candidate_predictors):
                rules = self.induce_natural_language_rules(predictor, trainset)
                predictor.signature.instructions = instructions_list[i]
                self.update_program_instructions(predictor, rules)

            score = self.evaluate_program(candidate_program, valset)

            if score > best_score:
                best_score = score
                best_program = candidate_program

            print(f"Evaluated Candidate {candidate_idx+1} with score {score}. Current best score: {best_score}")

        print("Final best score:", best_score)

        return best_program

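    # Induce rules from this predictor's view of the training examples. If the
    # induction call fails (e.g. the prompt exceeds the context window), drop
    # the last example and retry; give up only when even a single example
    # doesn't fit.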
    def induce_natural_language_rules(self, predictor, trainset):
        demos = self.get_predictor_demos(trainset, predictor)
        signature = predictor.signature
        while True:
            examples_text = self.format_examples(demos, signature)
            try:
                return self.rules_induction_program(examples_text)
            except Exception as e:
                assert (
                    isinstance(e, ValueError)
                    or e.__class__.__name__ == "BadRequestError"
                    or "ContextWindowExceededError" in str(e)
                )
                if len(demos) > 1:
                    demos = demos[:-1]
                else:
                    raise RuntimeError(
                        "Failed to generate natural language rules since a single example couldn't fit in the model's context window."
                    ) from e

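    # Append the induced rules to the predictor's existing instructions.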
    def update_program_instructions(self, predictor, natural_language_rules):
        predictor.signature.instructions = (
            f"{predictor.signature.instructions}\n\n"
            f"Please adhere to the following rules when making your prediction:\n{natural_language_rules}"
        )

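    # Render the demos as plain text, listing each demo's input fields and
    # output fields as "key: value" lines.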
    def format_examples(self, demos, signature):
        examples_text = ""
        for demo in demos:
            input_fields = {k: v for k, v in demo.items() if k in signature.input_fields}
            output_fields = {k: v for k, v in demo.items() if k in signature.output_fields}
            input_text = "\n".join(f"{k}: {v}" for k, v in input_fields.items())
            output_text = "\n".join(f"{k}: {v}" for k, v in output_fields.items())
            examples_text += f"Input Fields:\n{input_text}\n\n=========\nOutput Fields:\n{output_text}\n\n"
        return examples_text

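    # Project each training example onto the input/output fields declared in
    # the predictor's signature.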
    def get_predictor_demos(self, trainset, predictor):
        # TODO: Consider how this handles "incomplete" demos.
        signature = predictor.signature
        return [
            {
                key: value
                for key, value in example.items()
                if key in signature.input_fields or key in signature.output_fields
            }
            for example in trainset
        ]

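    # Score the program with dspy's Evaluate; with return_all_scores=True the
    # call returns (average_score, per_example_scores), and only the average
    # is used here.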
    def evaluate_program(self, program, dataset):
        evaluate = Evaluate(
            devset=dataset,
            metric=self.metric,
            num_threads=self.num_threads,
            max_errors=self.max_errors,
            display_table=False,
            display_progress=True,
            return_all_scores=True,
        )
        score, _ = evaluate(program, metric=self.metric)
        return score


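# ChainOfThought module that turns a block of formatted examples into the
# requested number of natural-language rules.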
class RulesInductionProgram(dspy.Module):
    def __init__(self, num_rules, teacher_settings=None):
        super().__init__()

        class CustomRulesInduction(dspy.Signature):
            __doc__ = f"""Given a set of examples, extract a list of {num_rules} concise and non-redundant natural language rules that provide clear guidance for performing the task. All rules should be actionable for a well-specified scope of examples of this general kind of task."""
            examples_text = dspy.InputField(desc="Text containing examples")
            natural_language_rules = dspy.OutputField(desc="Induced natural language rules")

        self.rules_induction = dspy.ChainOfThought(CustomRulesInduction)
        self.teacher_settings = teacher_settings or {}
        self.rng = random.Random(0)

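    # Generate rules under the teacher settings, copying the active LM with a
    # randomly sampled temperature in [0.9, 1.0] so repeated calls can produce
    # different rule sets.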
    def forward(self, examples_text):
        with dspy.settings.context(**self.teacher_settings):
            lm = dspy.settings.lm.copy(temperature=self.rng.uniform(0.9, 1.0))
            with dspy.settings.context(lm=lm):
                rules = self.rules_induction(examples_text=examples_text).natural_language_rules

        return rules.strip()
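
# Example usage (a minimal sketch): `MyProgram`, `my_metric`, and `trainset`
# are assumed placeholders for a dspy.Module, a metric with the usual
# (example, prediction, trace=None) interface, and a list of dspy.Example
# objects; dspy.settings is assumed to already be configured with an LM.
#
#   optimizer = InferRules(metric=my_metric, num_candidates=5, num_rules=8)
#   optimized_program = optimizer.compile(MyProgram(), trainset=trainset)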