177 lines
6.8 KiB
Python
177 lines
6.8 KiB
Python
import json
import logging

from fastapi import APIRouter, Depends, Request, Response
from sqlalchemy import desc, or_, select
from sqlalchemy.orm import Session

from app.config import Settings, get_settings
from app.db import get_db
from app.models import CrawlRun, Employee
from app.security import mcp_protected_resource_metadata, require_mcp_auth
|
|
|
|
# Router carrying the JSON-RPC MCP endpoint; every route on it lives under /mcp.
router: APIRouter = APIRouter(prefix="/mcp")

# Prefix-less router for the OAuth metadata document, kept apart from
# ``router`` so its well-known path is not nested under /mcp.
metadata_router: APIRouter = APIRouter()
|
|
|
|
|
|
# MCP tool descriptors returned verbatim by ``tools/list``. Each ``inputSchema``
# is a JSON Schema object describing the ``arguments`` accepted by ``tools/call``.
TOOLS = [
    {
        "name": "search_employees",
        "description": "Search MIEM employees by name or profile URL.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Case-insensitive substring of the full name or profile URL.",
                },
                "status": {"type": "string", "enum": ["active", "dismissed"]},
                # Bounds mirror the server-side clamp so clients can see them.
                "limit": {"type": "integer", "default": 20, "minimum": 1, "maximum": 100},
            },
            "required": ["query"],
        },
    },
    {
        "name": "get_employee",
        "description": "Get one employee by profile id, profile key, or canonical URL.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "profile_id_or_url": {
                    "type": "string",
                    "description": "Profile id, profile key, or (part of) the canonical profile URL.",
                },
            },
            "required": ["profile_id_or_url"],
        },
    },
    {
        "name": "list_employee_publications",
        "description": "List publications parsed from an employee profile.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "profile_id_or_url": {
                    "type": "string",
                    "description": "Profile id, profile key, or (part of) the canonical profile URL.",
                },
            },
            "required": ["profile_id_or_url"],
        },
    },
    {
        "name": "list_employee_courses",
        "description": "List teaching courses parsed from an employee profile.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "profile_id_or_url": {
                    "type": "string",
                    "description": "Profile id, profile key, or (part of) the canonical profile URL.",
                },
            },
            "required": ["profile_id_or_url"],
        },
    },
    {
        "name": "get_crawl_status",
        "description": "Return the latest crawl run status.",
        "inputSchema": {"type": "object", "properties": {}},
    },
]
|
|
|
|
|
|
_logger = logging.getLogger(__name__)

# JSON-RPC 2.0 error codes used by this endpoint.
_PARSE_ERROR = -32700
_INVALID_REQUEST = -32600
_METHOD_NOT_FOUND = -32601
_INVALID_PARAMS = -32602
_SERVER_ERROR = -32000


def _rpc_error(request_id: object, code: int, message: str) -> dict:
    """Build a JSON-RPC 2.0 error response envelope."""
    return {"jsonrpc": "2.0", "id": request_id, "error": {"code": code, "message": message}}


@router.post("", response_model=None)
async def mcp_http(
    request: Request,
    db: Session = Depends(get_db),
    settings: Settings = Depends(get_settings),
) -> dict | Response:
    """Handle one JSON-RPC 2.0 message posted to the MCP endpoint.

    Authentication is checked via ``require_mcp_auth`` before the body is
    read. Supported methods are ``initialize``, ``ping``, ``tools/list`` and
    ``tools/call``. Any ``notifications/*`` message is acknowledged with an
    empty 202 response, because JSON-RPC notifications must not be answered.

    Returns a JSON-RPC response dict. Protocol problems are reported as
    JSON-RPC errors:
    - -32700 for an unparsable body;
    - -32600 for a non-object payload;
    - -32602 for non-object params;
    - -32601 for an unknown method;
    - -32000 for failures while executing a method.
    """
    require_mcp_auth(request, settings)

    try:
        payload = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError both subclass ValueError
        return _rpc_error(None, _PARSE_ERROR, "Parse error")
    if not isinstance(payload, dict):
        # Batches (JSON arrays) and scalars are not supported by this endpoint.
        return _rpc_error(None, _INVALID_REQUEST, "Invalid Request")

    method = payload.get("method")
    request_id = payload.get("id")
    params = payload.get("params") or {}

    if isinstance(method, str) and method.startswith("notifications/"):
        # e.g. notifications/initialized: accept without a response body.
        return Response(status_code=202)
    if not isinstance(params, dict):
        return _rpc_error(request_id, _INVALID_PARAMS, "Invalid params")

    try:
        if method == "initialize":
            result = {
                "protocolVersion": "2024-11-05",
                "serverInfo": {"name": "miem-employees", "version": "0.1.0"},
                "capabilities": {"tools": {}},
            }
        elif method == "ping":
            result = {}
        elif method == "tools/list":
            result = {"tools": TOOLS}
        elif method == "tools/call":
            result = _call_tool(db, params.get("name"), params.get("arguments") or {})
        else:
            return _rpc_error(request_id, _METHOD_NOT_FOUND, "Method not found")
    except Exception as exc:
        # Top-level boundary: report to the client, but keep the traceback in logs.
        _logger.exception("MCP method %r failed", method)
        return _rpc_error(request_id, _SERVER_ERROR, str(exc))

    return {"jsonrpc": "2.0", "id": request_id, "result": result}
|
|
|
|
|
|
# Tools that list the items of one typed profile section, mapped to that
# section's ``type`` value.
_SECTION_TOOLS = {
    "list_employee_publications": "publications",
    "list_employee_courses": "courses_by_year",
}


def _require_text_argument(arguments: dict, key: str) -> str:
    """Return ``arguments[key]`` as a string.

    Raises ``ValueError`` with a readable message when the argument is
    missing or blank, instead of a bare ``KeyError``.
    """
    value = arguments.get(key)
    if value is None or not str(value).strip():
        raise ValueError(f"Missing required argument: {key}")
    return str(value)


def _call_tool(db: Session, name: str, arguments: dict) -> dict:
    """Execute the MCP tool *name* with *arguments* and wrap its output.

    Returns an MCP tool result (``_tool_response``) whose text is JSON. For
    the employee tools, an unknown employee yields an ``"error": "not_found"``
    payload. The list tools also keep an empty ``"items"`` list for
    backward compatibility.

    Raises ``ValueError`` for an unknown tool, non-object arguments, or a
    missing required argument.
    """
    if not isinstance(arguments, dict):
        raise ValueError("Tool arguments must be a JSON object")

    if name == "search_employees":
        return _tool_response(_search_employees(db, arguments))

    if name == "get_crawl_status":
        run = db.scalar(select(CrawlRun).order_by(desc(CrawlRun.started_at)).limit(1))
        return _tool_response(_run_payload(run) if run else {"status": "never_run"})

    section_type = _SECTION_TOOLS.get(name) if isinstance(name, str) else None
    if name == "get_employee" or section_type is not None:
        employee = _find_employee(db, _require_text_argument(arguments, "profile_id_or_url"))
        if name == "get_employee":
            return _tool_response(_employee_payload(employee) if employee is not None else {"error": "not_found"})
        if employee is None:
            # Distinguish "no such employee" from "employee with no items".
            return _tool_response({"error": "not_found", "items": []})
        return _tool_response(_collect_section_items(employee, section_type))

    raise ValueError(f"Unknown tool: {name}")
|
|
|
|
|
|
def _search_employees(db: Session, arguments: dict) -> list[dict]:
    """Search employees whose full name or canonical URL contains ``query``.

    Arguments read from *arguments*:
    - ``query``: case-insensitive substring. An empty or missing query matches
      everyone.
    - ``status``: optional exact match on ``Employee.status``.
    - ``limit``: result cap, clamped to 1..100. A missing or zero value means 20.

    Returns summaries (without ``current_data``) ordered by full name.
    Raises ``ValueError`` if ``limit`` is not an integer.
    """
    query = arguments.get("query", "")
    # Never let a negative value reach SQL: SQLite treats a negative LIMIT as
    # "no limit", and PostgreSQL rejects it.
    limit = max(1, min(int(arguments.get("limit") or 20), 100))

    stmt = select(Employee)
    if arguments.get("status"):
        stmt = stmt.where(Employee.status == arguments["status"])
    if query:
        # Escape LIKE wildcards ("/" is the escape char) so "%" and "_" typed by
        # the user are matched literally rather than as patterns.
        escaped = str(query).replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"%{escaped}%"
        stmt = stmt.where(
            or_(
                Employee.full_name.ilike(pattern, escape="/"),
                Employee.canonical_url.ilike(pattern, escape="/"),
            )
        )

    employees = db.scalars(stmt.order_by(Employee.full_name).limit(limit)).all()
    return [_employee_payload(employee, include_data=False) for employee in employees]
|
|
|
|
|
|
def _find_employee(db: Session, value: str) -> Employee | None:
    """Resolve an employee from a profile key, profile id, or profile URL.

    Exact matches on ``profile_key``, ``profile_id`` or ``canonical_url`` are
    tried first. Only when none exists does the lookup fall back to a
    case-insensitive substring match on ``canonical_url``. The fallback is
    ordered by URL so the chosen row is deterministic.

    Returns ``None`` for blank input, which would otherwise build a ``%%``
    pattern matching every URL. Also returns ``None`` when nothing matches.
    """
    needle = "" if value is None else str(value).strip()
    if not needle:
        return None

    exact = select(Employee).where(
        or_(
            Employee.profile_key == needle,
            Employee.profile_id == needle,
            Employee.canonical_url == needle,
        )
    )
    employee = db.scalar(exact.limit(1))
    if employee is not None:
        return employee

    # Escape LIKE wildcards ("/" is the escape char) so the input is matched
    # literally inside the URL.
    escaped = needle.replace("/", "//").replace("%", "/%").replace("_", "/_")
    fuzzy = (
        select(Employee)
        .where(Employee.canonical_url.ilike(f"%{escaped}%", escape="/"))
        .order_by(Employee.canonical_url)
        .limit(1)
    )
    return db.scalar(fuzzy)
|
|
|
|
|
|
def _collect_section_items(employee: Employee | None, section_type: str) -> dict:
    """Gather the items of every profile section whose ``type`` is *section_type*.

    Reads ``employee.current_data["sections"]``:
    - ``"publications"`` sections contribute their ``"publications"`` list;
    - ``"courses_by_year"`` sections contribute their ``"courses"`` list;
    - any other *section_type* yields no items.

    Returns ``{"items": []}`` when *employee* is ``None``. Otherwise returns
    ``{"employee": <summary>, "items": [...]}``, including when the employee
    has no parsed data yet, so a found employee always carries its summary.
    """
    if employee is None:
        return {"items": []}

    # Which key inside a section holds its list of items.
    items_key = {"publications": "publications", "courses_by_year": "courses"}.get(section_type)
    data = employee.current_data if isinstance(employee.current_data, dict) else {}

    items: list = []
    if items_key is not None:
        for section in data.get("sections") or []:
            # Skip malformed (non-dict) entries rather than failing the whole call.
            if isinstance(section, dict) and section.get("type") == section_type:
                items.extend(section.get(items_key) or [])

    return {"employee": _employee_payload(employee, include_data=False), "items": items}
|
|
|
|
|
|
def _employee_payload(employee: Employee, include_data: bool = True) -> dict:
    """Serialize an employee row into a JSON-friendly dict.

    ``last_seen_at`` and ``dismissed_at`` are rendered with ``isoformat()``,
    or ``None`` when unset. The raw ``current_data`` document is attached
    under ``"data"`` only when *include_data* is true.
    """

    def as_iso(moment):
        return moment.isoformat() if moment else None

    summary = dict(
        profile_key=employee.profile_key,
        profile_id=employee.profile_id,
        full_name=employee.full_name,
        status=employee.status,
        canonical_url=employee.canonical_url,
        last_seen_at=as_iso(employee.last_seen_at),
        dismissed_at=as_iso(employee.dismissed_at),
    )
    if not include_data:
        return summary
    summary["data"] = employee.current_data
    return summary
|
|
|
|
|
|
def _run_payload(run: CrawlRun) -> dict:
    """Serialize a crawl run into a JSON-friendly dict.

    ``started_at`` and ``finished_at`` are rendered with ``isoformat()``, or
    ``None`` when unset. All other fields are copied unchanged.
    """

    def as_iso(moment):
        return moment.isoformat() if moment else None

    # Key order matches the serialized output: identity, timestamps, counters.
    result = {field: getattr(run, field) for field in ("id", "status", "source_url")}
    result["started_at"] = as_iso(run.started_at)
    result["finished_at"] = as_iso(run.finished_at)
    result.update(
        {
            field: getattr(run, field)
            for field in ("found_count", "parsed_count", "error_count", "dismissed_count")
        }
    )
    return result
|
|
|
|
|
|
def _tool_response(data: object) -> dict:
    """Wrap *data* as an MCP tool result holding a single text content item.

    The text is the JSON encoding of *data*. Non-ASCII characters are kept
    as-is, and any value ``json`` cannot encode natively is converted with
    ``str``.
    """
    serialized = json.dumps(data, default=str, ensure_ascii=False)
    text_item = {"type": "text", "text": serialized}
    return {"content": [text_item]}
|
|
|
|
|
|
@metadata_router.get("/.well-known/oauth-protected-resource")
def oauth_protected_resource(settings: Settings = Depends(get_settings)) -> dict:
    """Serve the OAuth protected-resource metadata document (RFC 9728 well-known path).

    The document is built entirely by ``mcp_protected_resource_metadata``
    from the application settings.
    """
    metadata = mcp_protected_resource_metadata(settings)
    return metadata
|