
Commit 0008226

client/types: add logprobs support (#601)

1 parent 9ddd5f0 commit 0008226

7 files changed, +193 −2 lines

examples/chat-logprobs.py

Lines changed: 31 additions & 0 deletions (new file)

from typing import Iterable

import ollama


def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
  print(f'\n{label}:')
  for entry in logprobs:
    token = entry.get('token', '')
    logprob = entry.get('logprob')
    print(f'  token={token!r:<12} logprob={logprob:.3f}')
    for alt in entry.get('top_logprobs', []):
      if alt['token'] != token:
        print(f'    alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')


messages = [
  {
    'role': 'user',
    'content': 'hi! be concise.',
  },
]

response = ollama.chat(
  model='gemma3',
  messages=messages,
  logprobs=True,
  top_logprobs=3,
)
print('Chat response:', response['message']['content'])
print_logprobs(response.get('logprobs', []), 'chat logprobs')
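The same two flags also appear on the streaming overloads of chat and generate (see ollama/_client.py below). A minimal streaming sketch, assuming each streamed chunk can carry its own logprobs entries for the tokens in that chunk — an assumption, since this commit's examples and tests only exercise stream=False:

# Hypothetical streaming variant of the example above (not part of this commit).
for chunk in ollama.chat(model='gemma3', messages=messages, stream=True, logprobs=True, top_logprobs=3):
  print(chunk['message']['content'], end='', flush=True)
  # When present, logprobs would cover only the tokens emitted in this chunk.
  for entry in chunk.get('logprobs') or []:
    print(f"\n  {entry['token']!r}: {entry['logprob']:.3f}")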

examples/chat-with-history.py

Lines changed: 2 additions & 1 deletion

The single-quoted assistant message becomes a triple-quoted string split across two lines, likely to satisfy the stricter line-length limit introduced in pyproject.toml below.

@@ -15,7 +15,8 @@
   },
   {
     'role': 'assistant',
-    'content': 'The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.',
+    'content': """The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures
+rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.""",
   },
 ]

examples/generate-logprobs.py

Lines changed: 24 additions & 0 deletions (new file)

from typing import Iterable

import ollama


def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
  print(f'\n{label}:')
  for entry in logprobs:
    token = entry.get('token', '')
    logprob = entry.get('logprob')
    print(f'  token={token!r:<12} logprob={logprob:.3f}')
    for alt in entry.get('top_logprobs', []):
      if alt['token'] != token:
        print(f'    alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')


response = ollama.generate(
  model='gemma3',
  prompt='hi! be concise.',
  logprobs=True,
  top_logprobs=3,
)
print('Generate response:', response['response'])
print_logprobs(response.get('logprobs', []), 'generate logprobs')
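The values in logprob fields are natural logarithms of token probabilities, so math.exp recovers the probability itself. A small helper sketch (math is the Python standard library; the field name follows the TokenLogprob type added in ollama/_types.py below):

import math


def as_percent(logprob: float) -> str:
  # exp maps a natural-log probability back into [0, 1].
  return f'{math.exp(logprob) * 100:.1f}%'


print(as_percent(-0.1))  # -> '90.5%'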

ollama/_client.py

Lines changed: 32 additions & 0 deletions

@@ -200,6 +200,8 @@ def generate(
     context: Optional[Sequence[int]] = None,
     stream: Literal[False] = False,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -219,6 +221,8 @@ def generate(
     context: Optional[Sequence[int]] = None,
     stream: Literal[True] = True,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -237,6 +241,8 @@ def generate(
     context: Optional[Sequence[int]] = None,
     stream: bool = False,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: Optional[bool] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -266,6 +272,8 @@ def generate(
         context=context,
         stream=stream,
         think=think,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
         raw=raw,
         format=format,
         images=list(_copy_images(images)) if images else None,
@@ -284,6 +292,8 @@ def chat(
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[False] = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -298,6 +308,8 @@ def chat(
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[True] = True,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -311,6 +323,8 @@ def chat(
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: bool = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -358,6 +372,8 @@ def add_two_numbers(a: int, b: int) -> int:
         tools=list(_copy_tools(tools)),
         stream=stream,
         think=think,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
         format=format,
         options=options,
         keep_alive=keep_alive,
@@ -802,6 +818,8 @@ async def generate(
     context: Optional[Sequence[int]] = None,
     stream: Literal[False] = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -821,6 +839,8 @@ async def generate(
     context: Optional[Sequence[int]] = None,
     stream: Literal[True] = True,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -839,6 +859,8 @@ async def generate(
     context: Optional[Sequence[int]] = None,
     stream: bool = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: Optional[bool] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -867,6 +889,8 @@ async def generate(
         context=context,
         stream=stream,
         think=think,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
         raw=raw,
         format=format,
         images=list(_copy_images(images)) if images else None,
@@ -885,6 +909,8 @@ async def chat(
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[False] = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -899,6 +925,8 @@ async def chat(
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[True] = True,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -912,6 +940,8 @@ async def chat(
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: bool = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -960,6 +990,8 @@ def add_two_numbers(a: int, b: int) -> int:
         tools=list(_copy_tools(tools)),
         stream=stream,
         think=think,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
         format=format,
         options=options,
         keep_alive=keep_alive,
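Both Client and AsyncClient gain the same two keyword arguments, as the async hunks above show. A minimal async sketch, assuming a local Ollama server and the same gemma3 model the examples use:

import asyncio

import ollama


async def main() -> None:
  client = ollama.AsyncClient()
  response = await client.chat(
    model='gemma3',
    messages=[{'role': 'user', 'content': 'hi! be concise.'}],
    logprobs=True,
    top_logprobs=3,
  )
  # Same shape as the sync response: one Logprob entry per generated token.
  for entry in response.get('logprobs') or []:
    print(entry['token'], entry['logprob'])


asyncio.run(main())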

ollama/_types.py

Lines changed: 31 additions & 0 deletions

@@ -210,6 +210,12 @@ class GenerateRequest(BaseGenerateRequest):
   think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
   'Enable thinking mode (for thinking models).'
 
+  logprobs: Optional[bool] = None
+  'Return log probabilities for generated tokens.'
+
+  top_logprobs: Optional[int] = None
+  'Number of alternative tokens and log probabilities to include per position (0-20).'
+
 
 class BaseGenerateResponse(SubscriptableBaseModel):
   model: Optional[str] = None
@@ -243,6 +249,19 @@ class BaseGenerateResponse(SubscriptableBaseModel):
   'Duration of evaluating inference in nanoseconds.'
 
 
+class TokenLogprob(SubscriptableBaseModel):
+  token: str
+  'Token text.'
+
+  logprob: float
+  'Log probability for the token.'
+
+
+class Logprob(TokenLogprob):
+  top_logprobs: Optional[Sequence[TokenLogprob]] = None
+  'Most likely tokens and their log probabilities.'
+
+
 class GenerateResponse(BaseGenerateResponse):
   """
   Response returned by generate requests.
@@ -257,6 +276,9 @@ class GenerateResponse(BaseGenerateResponse):
   context: Optional[Sequence[int]] = None
   'Tokenized history up to the point of the response.'
 
+  logprobs: Optional[Sequence[Logprob]] = None
+  'Log probabilities for generated tokens.'
+
 
 class Message(SubscriptableBaseModel):
   """
@@ -360,6 +382,12 @@ def serialize_model(self, nxt):
   think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
   'Enable thinking mode (for thinking models).'
 
+  logprobs: Optional[bool] = None
+  'Return log probabilities for generated tokens.'
+
+  top_logprobs: Optional[int] = None
+  'Number of alternative tokens and log probabilities to include per position (0-20).'
+
 
 class ChatResponse(BaseGenerateResponse):
   """
@@ -369,6 +397,9 @@ class ChatResponse(BaseGenerateResponse):
   message: Message
   'Response message.'
 
+  logprobs: Optional[Sequence[Logprob]] = None
+  'Log probabilities for generated tokens if requested.'
+
 
 class EmbedRequest(BaseRequest):
   input: Union[str, Sequence[str]]
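Since Logprob extends TokenLogprob and both derive from SubscriptableBaseModel, entries can be read by attribute or by key, which is what the tests below rely on. A short illustrative sketch (values are made up):

from ollama._types import Logprob, TokenLogprob

entry = Logprob(
  token='Hello',
  logprob=-0.1,
  top_logprobs=[TokenLogprob(token='Hi', logprob=-1.0)],
)
# Attribute and subscript access are interchangeable on subscriptable models.
assert entry.token == entry['token'] == 'Hello'
assert entry.top_logprobs[0]['logprob'] == -1.0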

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ dependencies = [ 'ruff>=0.9.1' ]
 config-path = 'none'
 
 [tool.ruff]
-line-length = 999
+line-length = 320
 indent-width = 2
 
 [tool.ruff.format]

tests/test_client.py

Lines changed: 72 additions & 0 deletions

@@ -61,6 +61,44 @@ def test_client_chat(httpserver: HTTPServer):
   assert response['message']['content'] == "I don't know."
 
 
+def test_client_chat_with_logprobs(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/chat',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'messages': [{'role': 'user', 'content': 'Hi'}],
+      'tools': [],
+      'stream': False,
+      'logprobs': True,
+      'top_logprobs': 3,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'message': {
+        'role': 'assistant',
+        'content': 'Hello',
+      },
+      'logprobs': [
+        {
+          'token': 'Hello',
+          'logprob': -0.1,
+          'top_logprobs': [
+            {'token': 'Hello', 'logprob': -0.1},
+            {'token': 'Hi', 'logprob': -1.0},
+          ],
+        }
+      ],
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Hi'}], logprobs=True, top_logprobs=3)
+  assert response['logprobs'][0]['token'] == 'Hello'
+  assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
+
+
 def test_client_chat_stream(httpserver: HTTPServer):
   def stream_handler(_: Request):
     def generate():
@@ -294,6 +332,40 @@ def test_client_generate(httpserver: HTTPServer):
   assert response['response'] == 'Because it is.'
 
 
+def test_client_generate_with_logprobs(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'prompt': 'Why',
+      'stream': False,
+      'logprobs': True,
+      'top_logprobs': 2,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'response': 'Hello',
+      'logprobs': [
+        {
+          'token': 'Hello',
+          'logprob': -0.2,
+          'top_logprobs': [
+            {'token': 'Hello', 'logprob': -0.2},
+            {'token': 'Hi', 'logprob': -1.5},
+          ],
+        }
+      ],
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.generate('dummy', 'Why', logprobs=True, top_logprobs=2)
+  assert response['logprobs'][0]['token'] == 'Hello'
+  assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
+
+
 def test_client_generate_with_image_type(httpserver: HTTPServer):
   httpserver.expect_ordered_request(
     '/api/generate',
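The new tests exercise the sync Client only. A sketch of an equivalent async test, assuming the pytest-asyncio setup and AsyncClient import already used elsewhere in tests/test_client.py, with a trimmed version of the same mocked payload:

@pytest.mark.asyncio
async def test_async_client_chat_with_logprobs(httpserver: HTTPServer):
  httpserver.expect_ordered_request('/api/chat', method='POST').respond_with_json(
    {
      'model': 'dummy',
      'message': {'role': 'assistant', 'content': 'Hello'},
      'logprobs': [{'token': 'Hello', 'logprob': -0.1, 'top_logprobs': [{'token': 'Hello', 'logprob': -0.1}]}],
    }
  )

  client = AsyncClient(httpserver.url_for('/'))
  response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Hi'}], logprobs=True, top_logprobs=3)
  assert response['logprobs'][0]['token'] == 'Hello'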
