Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/google/adk/tools/_function_parameter_parse_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,13 @@ def _parse_schema_from_parameter(
),
func_name,
)
required_fields = [
field_name
for field_name, field_info in param.annotation.model_fields.items()
if field_info.is_required()
]
if required_fields:
schema.required = required_fields
_raise_if_schema_unsupported(variant, schema)
return schema
if inspect.isclass(param.annotation) and issubclass(
Expand Down
59 changes: 59 additions & 0 deletions tests/unittests/tools/test_build_function_declaration.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,65 @@ def simple_function(input: CustomInput) -> str:
)


def test_basemodel_required_fields():
class SearchRequest(BaseModel):
query: str
max_results: int
filter: str = ''

def search(request: SearchRequest) -> list:
return []

function_decl = _automatic_function_calling_util.build_function_declaration(
func=search
)

inner = function_decl.parameters.properties['request']
assert set(inner.required) == {'query', 'max_results'}
assert 'filter' not in (inner.required or [])


def test_basemodel_all_optional_fields_no_required():
class Config(BaseModel):
timeout: int = 30
retries: int = 3

def run(config: Config) -> str:
return ''

function_decl = _automatic_function_calling_util.build_function_declaration(
func=run
)

inner = function_decl.parameters.properties['config']
assert not inner.required


def test_nested_basemodel_required_fields():
class Inner(BaseModel):
x: int
y: int = 0

class Outer(BaseModel):
inner: Inner
label: str = ''

def process(data: Outer) -> str:
return ''

function_decl = _automatic_function_calling_util.build_function_declaration(
func=process
)

outer = function_decl.parameters.properties['data']
assert set(outer.required) == {'inner'}
assert 'label' not in (outer.required or [])

inner = outer.properties['inner']
assert set(inner.required) == {'x'}
assert 'y' not in (inner.required or [])


def test_toolcontext_ignored():
def simple_function(input_str: str, tool_context: ToolContext) -> str:
return {'result': input_str}
Expand Down