Better docstring parsing and some fixes

This commit is contained in:
ParthSareen
2024-11-19 17:53:25 -08:00
parent 6d9c156080
commit c5c61a3b04
2 changed files with 47 additions and 9 deletions
+16 -8
View File
@@ -2,6 +2,7 @@ from __future__ import annotations
from collections import defaultdict
import inspect
from typing import Callable, Union
import re
import pydantic
from ollama._types import Tool
@@ -14,7 +15,7 @@ def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
key = hash(doc_string)
for line in doc_string.splitlines():
lowered_line = line.lower()
lowered_line = line.lower().strip()
if lowered_line.startswith('args:'):
key = 'args'
elif lowered_line.startswith('returns:') or lowered_line.startswith('yields:') or lowered_line.startswith('raises:'):
@@ -27,14 +28,21 @@ def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
last_key = None
for line in parsed_docstring['args'].splitlines():
line = line.strip()
if ':' in line and not line.lower().startswith('args:'):
# Split on first occurrence of '(' or ':' to separate arg name from description
split_char = '(' if '(' in line else ':'
arg_name, rest = line.split(split_char, 1)
if ':' in line:
# Split the line on either:
# 1. A parenthetical expression like (integer) - captured in group 1
# 2. A colon :
# Followed by optional whitespace. Only split on first occurrence.
parts = re.split(r'(?:\(([^)]*)\)|:)\s*', line, maxsplit=1)
arg_name = parts[0].strip()
last_key = arg_name
# Get the description - will be in parts[1] if parenthetical or parts[-1] if after colon
arg_description = parts[-1].strip()
if len(parts) > 2 and parts[1]: # Has parenthetical content
arg_description = parts[-1].split(':', 1)[-1].strip()
last_key = arg_name.strip()
# Get description after the colon
arg_description = rest.split(':', 1)[1].strip() if split_char == '(' else rest.strip()
parsed_docstring[last_key] = arg_description
elif last_key and line:
+31 -1
View File
@@ -147,7 +147,7 @@ def test_skewed_docstring_parsing():
"""
Add two numbers together.
Args:
x (integer):: The first number
x (integer): : The first number
@@ -238,3 +238,33 @@ def test_function_with_no_types():
tool = convert_function_to_tool(no_types).model_dump()
assert tool['function']['parameters']['properties']['a']['type'] == 'string'
assert tool['function']['parameters']['properties']['b']['type'] == 'string'
def test_function_with_parentheses():
    """Check argument descriptions that themselves contain parentheses and colons.

    Two docstring layouts are exercised: plain ``name: description`` entries and
    ``name(type) : description`` entries. In both, the ``(:thing)`` fragment
    inside the description must survive untouched in the generated tool schema.
    """

    def func_with_parentheses(a: int, b: int) -> int:
        """
        A function with parentheses.
        Args:
        a: First (:thing) number to add
        b: Second number to add
        Returns:
        int: The sum of a and b
        """
        pass

    def func_with_parentheses_and_args(a: int, b: int):
        """
        A function with parentheses and args.
        Args:
        a(integer) : First (:thing) number to add
        b(integer) :Second number to add
        """
        pass

    # Both fixtures are expected to produce exactly the same descriptions.
    expected_descriptions = {
        'a': 'First (:thing) number to add',
        'b': 'Second number to add',
    }

    for fixture in (func_with_parentheses, func_with_parentheses_and_args):
        tool = convert_function_to_tool(fixture).model_dump()
        properties = tool['function']['parameters']['properties']
        for arg_name, description in expected_descriptions.items():
            assert properties[arg_name]['description'] == description