|
19 | 19 | "source": [ |
20 | 20 | "from typing import Any, Dict\n", |
21 | 21 | "import json\n", |
| 22 | + "import re\n", |
22 | 23 | "\n", |
23 | 24 | "def parse_json(input_string: str):\n", |
24 | 25 | " \"\"\"\n", |
|
88 | 89 | }, |
89 | 90 | { |
90 | 91 | "cell_type": "code", |
91 | | - "execution_count": null, |
| 92 | + "execution_count": 4, |
92 | 93 | "id": "6a447f4a-5aac-4f85-8393-6f9bade1ce30", |
93 | 94 | "metadata": {}, |
94 | 95 | "outputs": [], |
95 | 96 | "source": [ |
96 | 97 | "import yaml\n", |
97 | 98 | "\n", |
98 | | - "with open('facility_v2_train.json') as stream:\n", |
| 99 | + "with open('dataset.json') as stream:\n", |
99 | 100 | " dataset = json.load(stream)\n", |
100 | 101 | "\n", |
101 | 102 | "with open('facility_prompt.yaml') as stream:\n", |
|
104 | 105 | }, |
105 | 106 | { |
106 | 107 | "cell_type": "code", |
107 | | - "execution_count": null, |
| 108 | + "execution_count": 5, |
108 | 109 | "id": "ba35209a-c778-4d9b-a575-a27bb9078caf", |
109 | 110 | "metadata": {}, |
110 | | - "outputs": [], |
| 111 | + "outputs": [ |
| 112 | + { |
| 113 | + "data": { |
| 114 | + "text/plain": [ |
| 115 | + "60" |
| 116 | + ] |
| 117 | + }, |
| 118 | + "execution_count": 5, |
| 119 | + "metadata": {}, |
| 120 | + "output_type": "execute_result" |
| 121 | + } |
| 122 | + ], |
111 | 123 | "source": [ |
112 | 124 | "dataset_test = dataset[int(len(dataset)*0.7):]\n", |
113 | 125 | "len(dataset_test)" |
114 | 126 | ] |
115 | 127 | }, |
116 | | - { |
117 | | - "cell_type": "code", |
118 | | - "execution_count": null, |
119 | | - "id": "07f83eb5-a957-4f77-b6c9-fdd943f58cc2", |
120 | | - "metadata": {}, |
121 | | - "outputs": [], |
122 | | - "source": [ |
123 | | - "# from openai import OpenAI\n", |
124 | | - "from gen_ai_hub.proxy.native.openai import OpenAI" |
125 | | - ] |
126 | | - }, |
127 | | - { |
128 | | - "cell_type": "code", |
129 | | - "execution_count": null, |
130 | | - "id": "3f5e7904-2ca8-4784-9937-6fa824b3d109", |
131 | | - "metadata": {}, |
132 | | - "outputs": [], |
133 | | - "source": [ |
134 | | - "client = OpenAI()" |
135 | | - ] |
136 | | - }, |
137 | 128 | { |
138 | 129 | "cell_type": "code", |
139 | 130 | "execution_count": null, |
140 | 131 | "id": "e68c6fd3-b191-47e4-9d18-9fa1596ecb50", |
141 | 132 | "metadata": {}, |
142 | | - "outputs": [], |
| 133 | + "outputs": [ |
| 134 | + { |
| 135 | + "name": "stderr", |
| 136 | + "output_type": "stream", |
| 137 | + "text": [ |
| 138 | + "/Users/justinai/anaconda3/envs/sap-prompt-opt/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", |
| 139 | + " from .autonotebook import tqdm as notebook_tqdm\n", |
| 140 | + "Processing batches: 92%|█████████▏| 11/12 [00:58<00:05, 5.46s/it]" |
| 141 | + ] |
| 142 | + } |
| 143 | + ], |
143 | 144 | "source": [ |
| 145 | + "import asyncio\n", |
144 | 146 | "from tqdm.auto import tqdm\n", |
| 147 | + "from openai import AsyncOpenAI\n", |
| 148 | + "import os\n", |
145 | 149 | "\n", |
146 | | - "result = []\n", |
| 150 | + "# Configure OpenRouter client\n", |
| 151 | + "client = AsyncOpenAI(\n", |
| 152 | + " base_url=\"https://openrouter.ai/api/v1\",\n", |
| 153 | + " api_key=os.getenv(\"OPENROUTER_API_KEY\"), # Make sure to set this environment variable\n", |
| 154 | + ")\n", |
147 | 155 | "\n", |
148 | | - "for entry in tqdm(dataset_test):\n", |
149 | | - " output = client.chat.completions.create(\n", |
150 | | - " model=\"gpt-4o\",\n", |
151 | | - " messages=[\n", |
152 | | - " {\"role\": \"system\", \"content\": prompt[\"system\"]},\n", |
153 | | - " {\"role\": \"user\", \"content\": prompt[\"user\"].format(**entry[\"fields\"])},\n", |
154 | | - " ],\n", |
155 | | - " temperature=0.\n", |
156 | | - " )\n", |
157 | | - " prediction = output.choices[0].message.content\n", |
158 | | - " result.append(evaluate(entry[\"answer\"], prediction))\n", |
| 156 | + "async def process_entry(entry):\n", |
| 157 | + " \"\"\"Process a single entry with OpenRouter\"\"\"\n", |
| 158 | + " try:\n", |
| 159 | + " output = await client.chat.completions.create(\n", |
| 160 | + " model=\"meta-llama/llama-3.3-70b-instruct\",\n", |
| 161 | + " messages=[\n", |
| 162 | + " {\"role\": \"system\", \"content\": prompt[\"system\"]},\n", |
| 163 | + " {\"role\": \"user\", \"content\": prompt[\"user\"].format(**entry[\"fields\"])},\n", |
| 164 | + " ],\n", |
| 165 | + " temperature=0.\n", |
| 166 | + " )\n", |
| 167 | + " prediction = output.choices[0].message.content\n", |
| 168 | + " return evaluate(entry[\"answer\"], prediction)\n", |
| 169 | + " except Exception as e:\n", |
| 170 | + " print(f\"Error processing entry: {e}\")\n", |
| 171 | + " return {\"error\": str(e)}\n", |
| 172 | + "\n", |
| 173 | + "async def process_batch(entries, batch_size=10):\n", |
| 174 | + " \"\"\"Process entries in batches to avoid rate limits\"\"\"\n", |
| 175 | + " results = []\n", |
| 176 | + " \n", |
| 177 | + " for i in tqdm(range(0, len(entries), batch_size), desc=\"Processing batches\"):\n", |
| 178 | + " batch = entries[i:i + batch_size]\n", |
| 179 | + " batch_results = await asyncio.gather(*[process_entry(entry) for entry in batch])\n", |
| 180 | + " results.extend(batch_results)\n", |
| 181 | + " \n", |
| 182 | + " # Optional: Add a small delay between batches to be respectful to the API\n", |
| 183 | + " if i + batch_size < len(entries):\n", |
| 184 | + " await asyncio.sleep(0.1)\n", |
| 185 | + " \n", |
| 186 | + " return results\n", |
| 187 | + "\n", |
| 188 | + "# Run the batch processing\n", |
| 189 | + "result = await process_batch(dataset_test, batch_size=24) # Adjust batch_size as needed\n", |
159 | 190 | "\n", |
160 | 191 | " " |
161 | 192 | ] |
|
174 | 205 | { |
175 | 206 | "cell_type": "code", |
176 | 207 | "execution_count": null, |
177 | | - "id": "baf28a90-19a6-44c7-af1d-125b01cf21fa", |
| 208 | + "id": "f751aad4-c534-4f9c-a3c2-fffc06a1c485", |
178 | 209 | "metadata": {}, |
179 | 210 | "outputs": [], |
180 | | - "source": [ |
181 | | - "# gpt-4o -> {'is_valid_json': 0.967, 'correct_categories': 0.895, 'correct_sentiment': 0.517, 'correct_urgency': 0.767, 'total': 0.726}" |
182 | | - ] |
| 211 | + "source": [] |
183 | 212 | }, |
184 | 213 | { |
185 | 214 | "cell_type": "code", |
186 | 215 | "execution_count": null, |
187 | | - "id": "f751aad4-c534-4f9c-a3c2-fffc06a1c485", |
| 216 | + "id": "3336a411", |
188 | 217 | "metadata": {}, |
189 | 218 | "outputs": [], |
190 | 219 | "source": [] |
191 | 220 | } |
192 | 221 | ], |
193 | 222 | "metadata": { |
194 | 223 | "kernelspec": { |
195 | | - "display_name": "Python 3 (ipykernel)", |
| 224 | + "display_name": "sap-prompt-opt", |
196 | 225 | "language": "python", |
197 | 226 | "name": "python3" |
198 | 227 | }, |
|
206 | 235 | "name": "python", |
207 | 236 | "nbconvert_exporter": "python", |
208 | 237 | "pygments_lexer": "ipython3", |
209 | | - "version": "3.11.11" |
| 238 | + "version": "3.9.21" |
210 | 239 | } |
211 | 240 | }, |
212 | 241 | "nbformat": 4, |
|
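For reference, a minimal standalone sketch of the batched async-request pattern the new cell adopts. It assumes the OpenAI Python SDK (>= 1.x), an `OPENROUTER_API_KEY` environment variable, and `tqdm`; the model id is taken from the diff, while the message list, function names, and batch/pause values are illustrative placeholders, not part of the notebook.

```python
import asyncio
import os

from openai import AsyncOpenAI  # OpenAI SDK >= 1.x
from tqdm.auto import tqdm

# OpenRouter exposes an OpenAI-compatible endpoint, so the stock client works
# once base_url and the API key are overridden.
client = AsyncOpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=os.getenv("OPENROUTER_API_KEY"),
)

async def complete(user_message: str) -> str:
    """Send one chat completion request; errors come back as strings."""
    try:
        response = await client.chat.completions.create(
            model="meta-llama/llama-3.3-70b-instruct",
            messages=[{"role": "user", "content": user_message}],
            temperature=0.0,
        )
        return response.choices[0].message.content
    except Exception as exc:  # keep one failed request from aborting the batch
        return f"ERROR: {exc}"

async def run_batched(messages, batch_size=10, pause=0.1):
    """Fire requests in fixed-size batches to stay under provider rate limits."""
    results = []
    for start in tqdm(range(0, len(messages), batch_size), desc="Batches"):
        chunk = messages[start:start + batch_size]
        results.extend(await asyncio.gather(*(complete(m) for m in chunk)))
        if start + batch_size < len(messages):
            await asyncio.sleep(pause)  # brief pause between batches
    return results

# In a notebook, top-level await works:  results = await run_batched(msgs)
# In a plain script:                     results = asyncio.run(run_batched(msgs))
```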