Files
miem_workers/app/mcp.py

181 lines
7.0 KiB
Python

import json
from fastapi import APIRouter, Depends, Request
from sqlalchemy import desc, or_, select
from sqlalchemy.orm import Session
from app.db import get_db
from app.models import CrawlRun, Employee
from app.services.admin_data import run_detail_payload
from app.version import BACKEND_VERSION
# Router for the MCP JSON-RPC endpoint; routes declared on it are served under the "/mcp" prefix.
router = APIRouter(prefix="/mcp")
def _profile_lookup_schema() -> dict:
    """Build the input schema shared by the tools that address a single employee.

    A fresh dict is returned on every call so the tool entries never share
    mutable state with one another.
    """
    return {
        "type": "object",
        "properties": {"profile_id_or_url": {"type": "string"}},
        "required": ["profile_id_or_url"],
    }


# Catalogue of tools returned as-is by the "tools/list" method.
TOOLS = [
    # --- employee search -------------------------------------------------
    {
        "name": "search_employees",
        "description": "Search MIEM employees by name or profile URL.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "status": {"type": "string", "enum": ["active", "dismissed"]},
                "limit": {"type": "integer", "default": 20},
            },
            "required": ["query"],
        },
    },
    # --- single-employee lookups -----------------------------------------
    {
        "name": "get_employee",
        "description": "Get one employee by profile id, profile key, or canonical URL.",
        "inputSchema": _profile_lookup_schema(),
    },
    {
        "name": "list_employee_publications",
        "description": "List publications parsed from an employee profile.",
        "inputSchema": _profile_lookup_schema(),
    },
    {
        "name": "list_employee_courses",
        "description": "List teaching courses parsed from an employee profile.",
        "inputSchema": _profile_lookup_schema(),
    },
    # --- crawl run inspection --------------------------------------------
    {
        "name": "get_crawl_status",
        "description": "Return the latest crawl run status.",
        "inputSchema": {"type": "object", "properties": {}},
    },
    {
        "name": "get_crawl_run_details",
        "description": "Return detailed employee changes and errors for one crawl run.",
        "inputSchema": {
            "type": "object",
            "properties": {"run_id": {"type": "integer"}},
            "required": ["run_id"],
        },
    },
]
def _jsonrpc_error(request_id: object, code: int, message: str) -> dict:
    """Build a JSON-RPC 2.0 error envelope addressed to *request_id*."""
    return {"jsonrpc": "2.0", "id": request_id, "error": {"code": code, "message": message}}


@router.post("")
async def mcp_http(
    request: Request,
    db: Session = Depends(get_db),
) -> dict:
    """Handle one JSON-RPC 2.0 message sent to the MCP endpoint.

    Supported methods are ``initialize``, ``tools/list`` and ``tools/call``.
    Every outcome, including failures, is returned as a JSON-RPC envelope
    rather than raised.

    Args:
        request: Incoming HTTP request whose body carries the JSON-RPC message.
        db: Database session injected through ``Depends(get_db)``.

    Returns:
        A dict with ``jsonrpc``, ``id`` and either ``result`` or ``error``.
        Error codes used: -32700 (body is not valid JSON), -32600 (body is not
        a JSON object), -32601 (unknown method), -32602 (``params`` or tool
        ``arguments`` is not an object) and -32000 (tool execution failed).
    """
    try:
        payload = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both derive from ValueError.
        return _jsonrpc_error(None, -32700, "Parse error")
    if not isinstance(payload, dict):
        # Batch arrays and scalar bodies are not supported by this endpoint.
        return _jsonrpc_error(None, -32600, "Invalid Request")

    method = payload.get("method")
    request_id = payload.get("id")
    params = payload.get("params") or {}
    if not isinstance(params, dict):
        return _jsonrpc_error(request_id, -32602, "Invalid params")

    if method == "initialize":
        result = {
            "protocolVersion": "2024-11-05",
            "serverInfo": {"name": "miem-employees", "version": BACKEND_VERSION},
            "capabilities": {"tools": {}},
        }
    elif method == "tools/list":
        result = {"tools": TOOLS}
    elif method == "tools/call":
        arguments = params.get("arguments") or {}
        if not isinstance(arguments, dict):
            return _jsonrpc_error(request_id, -32602, "Invalid params")
        try:
            result = _call_tool(db, params.get("name"), arguments)
        except Exception as exc:
            # Boundary handler: any tool failure is reported to the client as a
            # JSON-RPC error instead of surfacing as an HTTP 500.
            return _jsonrpc_error(request_id, -32000, str(exc))
    else:
        return _jsonrpc_error(request_id, -32601, "Method not found")
    return {"jsonrpc": "2.0", "id": request_id, "result": result}
def _required_argument(arguments: dict, key: str) -> object:
    """Return ``arguments[key]``, raising a descriptive ``ValueError`` when it is absent or null."""
    value = arguments.get(key)
    if value is None:
        raise ValueError(f"Missing required argument: {key}")
    return value


def _call_tool(db: Session, name: str, arguments: dict) -> dict:
    """Run the tool called *name* and wrap its output as an MCP tool result.

    Args:
        db: Database session used by the tool implementations.
        name: Tool name as listed in ``TOOLS``.
        arguments: Tool arguments supplied by the client.

    Returns:
        The tool result produced by ``_tool_response``.

    Raises:
        ValueError: If *name* is not a known tool, a required argument is
            missing, or ``run_id`` cannot be converted to an integer.
    """
    if name == "search_employees":
        return _tool_response(_search_employees(db, arguments))

    if name == "get_employee":
        # Validate before touching the database so the client gets a clear message.
        lookup = _required_argument(arguments, "profile_id_or_url")
        employee = _find_employee(db, lookup)
        return _tool_response(_employee_payload(employee) if employee else {"error": "not_found"})

    if name == "list_employee_publications":
        lookup = _required_argument(arguments, "profile_id_or_url")
        employee = _find_employee(db, lookup)
        return _tool_response(_collect_section_items(employee, "publications"))

    if name == "list_employee_courses":
        lookup = _required_argument(arguments, "profile_id_or_url")
        employee = _find_employee(db, lookup)
        return _tool_response(_collect_section_items(employee, "courses_by_year"))

    if name == "get_crawl_status":
        run = db.scalar(select(CrawlRun).order_by(desc(CrawlRun.started_at)).limit(1))
        return _tool_response(_run_payload(run) if run else {"status": "never_run"})

    if name == "get_crawl_run_details":
        raw_run_id = _required_argument(arguments, "run_id")
        try:
            run_id = int(raw_run_id)
        except (TypeError, ValueError):
            raise ValueError("Argument run_id must be an integer") from None
        run = db.get(CrawlRun, run_id)
        return _tool_response(run_detail_payload(db, run) if run else {"error": "not_found"})

    raise ValueError(f"Unknown tool: {name}")
def _search_employees(db: Session, arguments: dict) -> list[dict]:
    """Search employees by a case-insensitive substring of their name or URL.

    Args:
        db: Database session to query.
        arguments: Tool arguments. Recognised keys are ``query`` (substring to
            look for; an empty query applies no text filter), ``status``
            (exact status filter) and ``limit`` (defaults to 20).

    Returns:
        Employee summaries (without profile data), ordered by full name.
        The number of rows is always between 1 and 100 at most, whatever
        ``limit`` the client sent.
    """
    requested = int(arguments.get("limit") or 20)
    # Clamp on both sides: a negative LIMIT is an error on some backends and
    # means "no limit" on others.
    limit = max(1, min(requested, 100))

    stmt = select(Employee)
    status = arguments.get("status")
    if status:
        stmt = stmt.where(Employee.status == status)

    query = arguments.get("query") or ""
    if query:
        # Escape LIKE metacharacters so "%" and "_" in the query match literally
        # ("_" is common in URLs and would otherwise match any character).
        escaped = str(query).replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        pattern = f"%{escaped}%"
        stmt = stmt.where(
            or_(
                Employee.full_name.ilike(pattern, escape="\\"),
                Employee.canonical_url.ilike(pattern, escape="\\"),
            )
        )

    employees = db.scalars(stmt.order_by(Employee.full_name).limit(limit)).all()
    return [_employee_payload(employee, include_data=False) for employee in employees]
def _find_employee(db: Session, value: str) -> Employee | None:
    """Resolve one employee from a profile key, profile id, or canonical URL.

    An exact match on ``profile_key``, ``profile_id`` or ``canonical_url`` is
    preferred. Only when none exists does the lookup fall back to a
    case-insensitive substring match on ``canonical_url``.

    Args:
        db: Database session to query.
        value: Identifier supplied by the client; surrounding whitespace is ignored.

    Returns:
        The matching employee, or ``None`` when *value* is blank or nothing matches.
    """
    needle = "" if value is None else str(value).strip()
    if not needle:
        # A blank needle would turn the substring pattern into "%%", which
        # matches every row and would return an arbitrary employee.
        return None

    exact = db.scalar(
        select(Employee)
        .where(
            or_(
                Employee.profile_key == needle,
                Employee.profile_id == needle,
                Employee.canonical_url == needle,
            )
        )
        .limit(1)
    )
    if exact is not None:
        return exact

    # Escape LIKE metacharacters so "%" and "_" in the needle match literally.
    escaped = needle.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    return db.scalar(
        select(Employee)
        .where(Employee.canonical_url.ilike(f"%{escaped}%", escape="\\"))
        # Order the fallback so repeated calls pick the same row.
        .order_by(Employee.canonical_url)
        .limit(1)
    )
def _collect_section_items(employee: Employee | None, section_type: str) -> dict:
    """Gather the items of every profile section of the given type.

    Args:
        employee: Employee whose ``current_data`` is inspected, or ``None``
            when the lookup found nobody.
        section_type: ``"publications"`` or ``"courses_by_year"``. Any other
            value yields an empty item list.

    Returns:
        ``{"error": "not_found", "items": []}`` when *employee* is ``None``,
        mirroring the ``get_employee`` tool. Otherwise a dict with the
        employee summary under ``"employee"`` and the collected entries under
        ``"items"`` (empty when the employee has no profile data).
    """
    if employee is None:
        # Report an unknown employee explicitly; "items" is kept so clients
        # that only read that key keep working.
        return {"error": "not_found", "items": []}

    # Name of the list inside a section that holds its entries.
    item_key = {"publications": "publications", "courses_by_year": "courses"}.get(section_type)

    items = []
    data = employee.current_data or {}
    if item_key is not None:
        for section in data.get("sections") or []:
            if not isinstance(section, dict) or section.get("type") != section_type:
                continue
            items.extend(section.get(item_key) or [])
    return {"employee": _employee_payload(employee, include_data=False), "items": items}
def _employee_payload(employee: Employee, include_data: bool = True) -> dict:
    """Convert an employee record into a JSON-friendly dict.

    Args:
        employee: Record to serialize.
        include_data: When true, the employee's ``current_data`` is attached
            under the ``"data"`` key.

    Returns:
        Identity, status and timestamp fields; timestamps are ISO 8601
        strings, or ``None`` when unset.
    """
    seen_at = employee.last_seen_at
    gone_at = employee.dismissed_at
    summary = {
        "profile_key": employee.profile_key,
        "profile_id": employee.profile_id,
        "full_name": employee.full_name,
        "status": employee.status,
        "canonical_url": employee.canonical_url,
        "last_seen_at": seen_at.isoformat() if seen_at else None,
        "dismissed_at": gone_at.isoformat() if gone_at else None,
    }
    if not include_data:
        return summary
    # "data" goes last so the key order matches the summary-first layout.
    return {**summary, "data": employee.current_data}
def _run_payload(run: CrawlRun) -> dict:
    """Convert a crawl run record into a JSON-friendly dict.

    Args:
        run: Crawl run to serialize.

    Returns:
        Run id, status, source URL, start/finish timestamps (ISO 8601 strings
        or ``None`` when unset) and the four per-run counters.
    """
    began = run.started_at
    ended = run.finished_at

    # Plain attributes are copied verbatim; only the timestamps need formatting.
    payload = {field: getattr(run, field) for field in ("id", "status", "source_url")}
    payload["started_at"] = began.isoformat() if began else None
    payload["finished_at"] = ended.isoformat() if ended else None
    for counter in ("found_count", "parsed_count", "error_count", "dismissed_count"):
        payload[counter] = getattr(run, counter)
    return payload
def _tool_response(data: object) -> dict:
    """Wrap *data* as a tool result holding a single JSON text item.

    Args:
        data: Value to serialize with ``json.dumps``.

    Returns:
        ``{"content": [{"type": "text", "text": <json>}]}``.
    """
    # ensure_ascii=False keeps non-ASCII text readable; default=str stringifies
    # any value json cannot encode natively.
    serialized = json.dumps(data, default=str, ensure_ascii=False)
    text_item = {"type": "text", "text": serialized}
    return {"content": [text_item]}