mirror of
https://github.com/ollama/ollama-python.git
synced 2026-01-14 06:07:17 +08:00
fix: update async stream tests
This commit is contained in:
parent
38f68e251d
commit
21aad8447c
@ -81,9 +81,11 @@ def test_client_chat_stream(httpserver: HTTPServer):
|
|||||||
|
|
||||||
client = Client(httpserver.url_for('/'))
|
client = Client(httpserver.url_for('/'))
|
||||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
|
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
|
||||||
|
|
||||||
|
it = iter(['I ', "don't ", 'know.'])
|
||||||
for part in response:
|
for part in response:
|
||||||
assert part['message']['role'] in 'assistant'
|
assert part['message']['role'] in 'assistant'
|
||||||
assert part['message']['content'] in ['I ', "don't ", 'know.']
|
assert part['message']['content'] == next(it)
|
||||||
|
|
||||||
|
|
||||||
def test_client_chat_images(httpserver: HTTPServer):
|
def test_client_chat_images(httpserver: HTTPServer):
|
||||||
@ -187,9 +189,11 @@ def test_client_generate_stream(httpserver: HTTPServer):
|
|||||||
|
|
||||||
client = Client(httpserver.url_for('/'))
|
client = Client(httpserver.url_for('/'))
|
||||||
response = client.generate('dummy', 'Why is the sky blue?', stream=True)
|
response = client.generate('dummy', 'Why is the sky blue?', stream=True)
|
||||||
|
|
||||||
|
it = iter(['Because ', 'it ', 'is.'])
|
||||||
for part in response:
|
for part in response:
|
||||||
assert part['model'] == 'dummy'
|
assert part['model'] == 'dummy'
|
||||||
assert part['response'] in ['Because ', 'it ', 'is.']
|
assert part['response'] == next(it)
|
||||||
|
|
||||||
|
|
||||||
def test_client_generate_images(httpserver: HTTPServer):
|
def test_client_generate_images(httpserver: HTTPServer):
|
||||||
@ -458,6 +462,24 @@ async def test_async_client_chat(httpserver: HTTPServer):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_async_client_chat_stream(httpserver: HTTPServer):
|
async def test_async_client_chat_stream(httpserver: HTTPServer):
|
||||||
|
def stream_handler(_: Request):
|
||||||
|
def generate():
|
||||||
|
for message in ['I ', "don't ", 'know.']:
|
||||||
|
yield (
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
'model': 'dummy',
|
||||||
|
'message': {
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ '\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response(generate())
|
||||||
|
|
||||||
httpserver.expect_ordered_request(
|
httpserver.expect_ordered_request(
|
||||||
'/api/chat',
|
'/api/chat',
|
||||||
method='POST',
|
method='POST',
|
||||||
@ -468,11 +490,15 @@ async def test_async_client_chat_stream(httpserver: HTTPServer):
|
|||||||
'format': '',
|
'format': '',
|
||||||
'options': {},
|
'options': {},
|
||||||
},
|
},
|
||||||
).respond_with_json({})
|
).respond_with_handler(stream_handler)
|
||||||
|
|
||||||
client = AsyncClient(httpserver.url_for('/'))
|
client = AsyncClient(httpserver.url_for('/'))
|
||||||
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
|
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
|
||||||
assert isinstance(response, types.AsyncGeneratorType)
|
|
||||||
|
it = iter(['I ', "don't ", 'know.'])
|
||||||
|
async for part in response:
|
||||||
|
assert part['message']['role'] == 'assistant'
|
||||||
|
assert part['message']['content'] == next(it)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -529,6 +555,21 @@ async def test_async_client_generate(httpserver: HTTPServer):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_async_client_generate_stream(httpserver: HTTPServer):
|
async def test_async_client_generate_stream(httpserver: HTTPServer):
|
||||||
|
def stream_handler(_: Request):
|
||||||
|
def generate():
|
||||||
|
for message in ['Because ', 'it ', 'is.']:
|
||||||
|
yield (
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
'model': 'dummy',
|
||||||
|
'response': message,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ '\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response(generate())
|
||||||
|
|
||||||
httpserver.expect_ordered_request(
|
httpserver.expect_ordered_request(
|
||||||
'/api/generate',
|
'/api/generate',
|
||||||
method='POST',
|
method='POST',
|
||||||
@ -544,11 +585,15 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
|
|||||||
'format': '',
|
'format': '',
|
||||||
'options': {},
|
'options': {},
|
||||||
},
|
},
|
||||||
).respond_with_json({})
|
).respond_with_handler(stream_handler)
|
||||||
|
|
||||||
client = AsyncClient(httpserver.url_for('/'))
|
client = AsyncClient(httpserver.url_for('/'))
|
||||||
response = await client.generate('dummy', 'Why is the sky blue?', stream=True)
|
response = await client.generate('dummy', 'Why is the sky blue?', stream=True)
|
||||||
assert isinstance(response, types.AsyncGeneratorType)
|
|
||||||
|
it = iter(['Because ', 'it ', 'is.'])
|
||||||
|
async for part in response:
|
||||||
|
assert part['model'] == 'dummy'
|
||||||
|
assert part['response'] == next(it)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user