Source code for sktime_mcp.tools.list_estimators
"""
Discovery tools for sktime MCP.
Provides query_registry, list_estimators, and get_available_tags tools.
"""
import difflib
from typing import Any
from sktime_mcp.registry.interface import get_registry
[docs]
def query_registry_tool(
target: str = "estimators",
task: str | None = None,
tags: dict[str, Any] | None = None,
query: str | None = None,
limit: int = 50,
offset: int = 0,
) -> dict[str, Any]:
"""
Unified entry point to query the sktime registry for estimators, capability tags, or performance metrics.
Args:
target: What registry target to search. One of: "estimators", "tags", "metrics".
task: Filter by task type (e.g., "forecasting", "transformation", "classification",
"regression", "clustering", "splitting", "detection", "alignment",
"parameter_estimation", "network"). Applies to "estimators" and "metrics".
tags: Key-value pair of capability tag filters (e.g., {"capability:pred_int": True}).
Applies to "estimators".
query: Substring search over name, module, or docstring. Applies to "estimators".
limit: Maximum number of results to return (default: 50).
offset: Number of results to skip for pagination (default: 0).
Returns:
Dictionary with:
- success: bool
- results: List of matching components or tags
- count: Number of results in this page
- total: Total matching results
- offset: Current offset
- limit: Current limit
- has_more: True if more results exist
"""
registry = get_registry()
try:
# Validate target
valid_targets = ["estimators", "tags", "metrics"]
if target not in valid_targets:
return {
"success": False,
"error": f"Invalid target: '{target}'. Valid targets: {valid_targets}",
}
# Check pagination bounds
if offset < 0:
return {"success": False, "error": "offset must be a non-negative integer."}
if limit < 1:
return {"success": False, "error": "limit must be a positive integer."}
# Handle target: tags
if target == "tags":
all_tags = registry.get_available_tags()
# If query is provided, filter tags by name/description
if query:
q_lower = query.lower()
all_tags = [
t
for t in all_tags
if q_lower in t.get("tag", "").lower()
or q_lower in t.get("description", "").lower()
]
total = len(all_tags)
page = all_tags[offset : offset + limit]
return {
"success": True,
"results": page,
"count": len(page),
"total": total,
"offset": offset,
"limit": limit,
"has_more": (offset + limit) < total,
}
# Handle target: metrics or estimators
# Validate task if provided
if task is not None:
valid_tasks = registry.get_available_tasks()
if task not in valid_tasks:
suggestions = difflib.get_close_matches(task, valid_tasks, n=3, cutoff=0.6)
return {
"success": False,
"error": f"Invalid task: '{task}'. Valid options: {valid_tasks}."
+ (f" Did you mean: {suggestions}?" if suggestions else ""),
}
# Validate tag keys if provided
if tags is not None:
valid_tag_keys = {t["tag"] for t in registry.get_available_tags()}
invalid_keys = [k for k in tags if k not in valid_tag_keys]
if invalid_keys:
suggestions = {
k: difflib.get_close_matches(k, valid_tag_keys, n=1, cutoff=0.6)
for k in invalid_keys
}
return {
"success": False,
"error": f"Invalid tag key(s): {invalid_keys}. Use target='tags' to see valid keys.",
"suggestions": {k: v[0] if v else None for k, v in suggestions.items()},
}
# Fetch base list of components
if target == "metrics":
# Metrics are components with task == "metric"
components = registry.get_all_estimators(task="metric")
else:
# Estimators are everything except metrics
components = [e for e in registry.get_all_estimators() if e.task != "metric"]
# Apply task filter
if task:
components = [e for e in components if e.task == task]
# Apply tags filter
if tags:
components = registry._filter_by_tags(components, tags)
# Apply query search if provided
if query:
q_lower = query.lower()
filtered_components = []
for node in components:
name_lower = node.name.lower()
module_lower = node.module.lower()
doc_lower = node.docstring.lower() if node.docstring else ""
if q_lower in name_lower or q_lower in module_lower or q_lower in doc_lower:
filtered_components.append(node)
components = filtered_components
# Pagination
total = len(components)
page = components[offset : offset + limit]
results = [est.to_summary() for est in page]
return {
"success": True,
"results": results,
"count": len(results),
"total": total,
"offset": offset,
"limit": limit,
"has_more": (offset + limit) < total,
"target": target,
"task_filter": task,
"tag_filter": tags,
"query": query,
}
except Exception as e:
return {"success": False, "error": str(e)}
# --- Backward Compatibility Wrappers ---
[docs]
def list_estimators_tool(
    task: str | None = None,
    tags: dict[str, Any] | None = None,
    query: str | None = None,
    limit: int = 50,
    offset: int = 0,
) -> dict[str, Any]:
    """Deprecated: Use query_registry_tool with target='estimators' instead.

    Forwards all arguments to ``query_registry_tool`` with
    ``target="estimators"``. A failed response is returned unchanged; a
    successful one is reshaped into the legacy layout, where the result list
    is stored under ``"estimators"`` rather than ``"results"``.
    """
    outcome = query_registry_tool(
        target="estimators",
        task=task,
        tags=tags,
        query=query,
        limit=limit,
        offset=offset,
    )
    if outcome["success"]:
        # Legacy layout: the result list first, then the pagination fields
        # copied over as-is, then an echo of the caller's filters.
        legacy: dict[str, Any] = {"success": True, "estimators": outcome["results"]}
        for field in ("count", "total", "offset", "limit", "has_more"):
            legacy[field] = outcome[field]
        legacy["task_filter"] = task
        legacy["tag_filter"] = tags
        legacy["query"] = query
        return legacy
    return outcome
def get_available_tasks() -> dict[str, Any]:
    """Return the task types reported by the registry.

    Returns:
        Dictionary with ``success`` set to True and ``tasks`` holding
        whatever ``registry.get_available_tasks()`` returns.
    """
    task_names = get_registry().get_available_tasks()
    return {"success": True, "tasks": task_names}
[docs]
def get_available_tags() -> dict[str, Any]:
    """Deprecated: Use query_registry_tool with target='tags' instead.

    Returns:
        Dictionary with ``success`` set to True and ``tags`` holding
        whatever ``registry.get_available_tags()`` returns (unpaginated).
    """
    tag_listing = get_registry().get_available_tags()
    return {"success": True, "tags": tag_listing}