From 29451ccee137a6fe0bcde6b004e690950ddf93d1 Mon Sep 17 00:00:00 2001 From: Anton Date: Thu, 14 May 2026 11:00:46 +0300 Subject: [PATCH] feat: add dataset checkpoint sync for MCP --- app/mcp.py | 41 ++++- app/models.py | 40 +++++ app/services/crawler.py | 3 + app/services/dataset_versions.py | 227 ++++++++++++++++++++++++++++ app/version.py | 6 +- migrations/004_dataset_versions.sql | 29 ++++ pyproject.toml | 2 +- tests/test_api_mcp.py | 130 +++++++++++++++- tests/test_dataset_versions.py | 88 +++++++++++ 9 files changed, 558 insertions(+), 8 deletions(-) create mode 100644 app/services/dataset_versions.py create mode 100644 migrations/004_dataset_versions.sql create mode 100644 tests/test_dataset_versions.py diff --git a/app/mcp.py b/app/mcp.py index 5924b1f..b9ebaa3 100644 --- a/app/mcp.py +++ b/app/mcp.py @@ -7,12 +7,31 @@ from sqlalchemy.orm import Session from app.db import get_db from app.models import CrawlRun, Employee from app.services.admin_data import run_detail_payload +from app.services.dataset_versions import service_info_payload, sync_employees_payload from app.version import BACKEND_VERSION router = APIRouter(prefix="/mcp") +PROTOCOL_VERSION = "2024-11-05" +SERVICE_NAME = "miem-employees" TOOLS = [ + { + "name": "get_service_info", + "description": "Return service metadata, supported tools, and current dataset version.", + "inputSchema": {"type": "object", "properties": {}}, + }, + { + "name": "sync_employees", + "description": "Synchronize employees by dataset hash. Returns a full snapshot or a delta from client_hash.", + "inputSchema": { + "type": "object", + "properties": { + "client_hash": {"type": "string"}, + "include_data": {"type": "boolean", "default": True}, + }, + }, + }, { "name": "search_employees", "description": "Search MIEM employees by name or profile URL.", @@ -71,8 +90,8 @@ async def mcp_http( try: if method == "initialize": result = { - "protocolVersion": "2024-11-05", - "serverInfo": {"name": "miem-employees", "version": BACKEND_VERSION}, + "protocolVersion": PROTOCOL_VERSION, + "serverInfo": {"name": SERVICE_NAME, "version": BACKEND_VERSION}, "capabilities": {"tools": {}}, } elif method == "tools/list": @@ -87,6 +106,24 @@ async def mcp_http( def _call_tool(db: Session, name: str, arguments: dict) -> dict: + if name == "get_service_info": + return _tool_response( + service_info_payload( + db, + tools=TOOLS, + service_name=SERVICE_NAME, + backend_version=BACKEND_VERSION, + protocol_version=PROTOCOL_VERSION, + ) + ) + if name == "sync_employees": + return _tool_response( + sync_employees_payload( + db, + client_hash=arguments.get("client_hash"), + include_data=bool(arguments.get("include_data", True)), + ) + ) if name == "search_employees": return _tool_response(_search_employees(db, arguments)) if name == "get_employee": diff --git a/app/models.py b/app/models.py index eeba774..a84b48b 100644 --- a/app/models.py +++ b/app/models.py @@ -76,6 +76,7 @@ class CrawlRun(Base): message: Mapped[str | None] = mapped_column(Text) employee_changes: Mapped[list["CrawlRunEmployeeChange"]] = relationship(back_populates="crawl_run") + dataset_versions: Mapped[list["DatasetVersion"]] = relationship(back_populates="crawl_run") class CrawlRunEmployeeChange(Base): @@ -134,3 +135,42 @@ class ParserSource(Base): source_url: Mapped[str] = mapped_column(Text, nullable=False) enabled: Mapped[bool] = mapped_column(default=True, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow, nullable=False) + + +class DatasetVersion(Base): + __tablename__ = "dataset_versions" + __table_args__ = ( + UniqueConstraint("hash", name="uq_dataset_versions_hash"), + Index("ix_dataset_versions_created_at", "created_at"), + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + hash: Mapped[str] = mapped_column(String(64), nullable=False) + previous_hash: Mapped[str | None] = mapped_column(String(64)) + crawl_run_id: Mapped[int | None] = mapped_column(ForeignKey("crawl_runs.id")) + employee_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False) + active_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False) + dismissed_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow, nullable=False) + + crawl_run: Mapped[CrawlRun | None] = relationship(back_populates="dataset_versions") + items: Mapped[list["DatasetVersionItem"]] = relationship(back_populates="dataset_version", cascade="all, delete-orphan") + + +class DatasetVersionItem(Base): + __tablename__ = "dataset_version_items" + __table_args__ = ( + UniqueConstraint("dataset_version_id", "profile_key", name="uq_dataset_version_items_version_profile"), + Index("ix_dataset_version_items_hash", "dataset_version_id"), + Index("ix_dataset_version_items_profile_key", "profile_key"), + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + dataset_version_id: Mapped[int] = mapped_column(ForeignKey("dataset_versions.id"), nullable=False) + profile_key: Mapped[str] = mapped_column(String(255), nullable=False) + employee_id: Mapped[int | None] = mapped_column(ForeignKey("employees.id")) + status: Mapped[str] = mapped_column(String(32), nullable=False) + checksum: Mapped[str] = mapped_column(String(64), nullable=False) + + dataset_version: Mapped[DatasetVersion] = relationship(back_populates="items") + employee: Mapped[Employee | None] = relationship() diff --git a/app/services/crawler.py b/app/services/crawler.py index 5048ab7..c081a3d 100644 --- a/app/services/crawler.py +++ b/app/services/crawler.py @@ -13,6 +13,7 @@ from app.models import CrawlError, CrawlRun, CrawlRunEmployeeChange, Employee, E from app.parser.collector import collect_profile_links from app.parser.profile import parse_person_profile from app.parser.profile_url import profile_key +from app.services.dataset_versions import get_or_create_current_version HEADERS = { "User-Agent": "Mozilla/5.0 (compatible; MIEMEmployeesBot/0.1.0; +https://miem.hse.ru/)" @@ -70,6 +71,7 @@ def run_crawl(db: Session, settings: Settings) -> CrawlRun: run.dismissed_count = _mark_dismissed(db, run, found_keys, session, settings.request_timeout) run.status = "completed" + get_or_create_current_version(db, crawl_run_id=run.id) except Exception as exc: run.status = "failed" run.message = str(exc) @@ -103,6 +105,7 @@ def refresh_employee(db: Session, employee: Employee, settings: Settings) -> Cra _upsert_employee(db, run, parsed) run.parsed_count = 1 run.status = "completed" + get_or_create_current_version(db, crawl_run_id=run.id) except Exception as exc: run.status = "failed" run.error_count = 1 diff --git a/app/services/dataset_versions.py b/app/services/dataset_versions.py new file mode 100644 index 0000000..74bc59f --- /dev/null +++ b/app/services/dataset_versions.py @@ -0,0 +1,227 @@ +import hashlib +import json +from dataclasses import dataclass + +from sqlalchemy import desc, select +from sqlalchemy.orm import Session + +from app.models import DatasetVersion, DatasetVersionItem, Employee + + +@dataclass(frozen=True) +class EmployeeMarker: + profile_key: str + employee_id: int | None + status: str + checksum: str + + +def get_or_create_current_version(db: Session, *, crawl_run_id: int | None = None) -> DatasetVersion: + employees = db.scalars(select(Employee).order_by(Employee.profile_key)).all() + markers = [_employee_marker(employee) for employee in employees] + dataset_hash = _dataset_hash(markers) + latest = get_latest_version(db) + if latest and latest.hash == dataset_hash: + return latest + + active_count = sum(1 for marker in markers if marker.status == "active") + dismissed_count = sum(1 for marker in markers if marker.status == "dismissed") + version = DatasetVersion( + hash=dataset_hash, + previous_hash=latest.hash if latest else None, + crawl_run_id=crawl_run_id, + employee_count=len(markers), + active_count=active_count, + dismissed_count=dismissed_count, + ) + db.add(version) + db.flush() + for marker in markers: + db.add( + DatasetVersionItem( + dataset_version_id=version.id, + profile_key=marker.profile_key, + employee_id=marker.employee_id, + status=marker.status, + checksum=marker.checksum, + ) + ) + db.flush() + return version + + +def get_latest_version(db: Session) -> DatasetVersion | None: + return db.scalar(select(DatasetVersion).order_by(desc(DatasetVersion.created_at), desc(DatasetVersion.id)).limit(1)) + + +def get_version_by_hash(db: Session, dataset_hash: str | None) -> DatasetVersion | None: + if not dataset_hash: + return None + return db.scalar(select(DatasetVersion).where(DatasetVersion.hash == dataset_hash).limit(1)) + + +def service_info_payload(db: Session, *, tools: list[dict], service_name: str, backend_version: str, protocol_version: str) -> dict: + version = get_or_create_current_version(db) + db.commit() + return { + "service_name": service_name, + "backend_version": backend_version, + "protocolVersion": protocol_version, + "tools": tools, + "dataset": _version_payload(version), + } + + +def sync_employees_payload(db: Session, *, client_hash: str | None = None, include_data: bool = True) -> dict: + current = get_or_create_current_version(db) + db.commit() + if not client_hash: + return _full_sync_payload(db, current, include_data=include_data, reason=None) + if client_hash == current.hash: + return { + "mode": "delta", + "from_hash": client_hash, + "to_hash": current.hash, + "dataset": _version_payload(current), + "changes": {"added": [], "updated": [], "dismissed": [], "removed": []}, + } + + previous = get_version_by_hash(db, client_hash) + if not previous: + return _full_sync_payload(db, current, include_data=include_data, reason="unknown_client_hash", from_hash=client_hash) + + return _delta_sync_payload(db, previous, current, include_data=include_data) + + +def _full_sync_payload( + db: Session, + current: DatasetVersion, + *, + include_data: bool, + reason: str | None, + from_hash: str | None = None, +) -> dict: + employees = db.scalars(select(Employee).order_by(Employee.profile_key)).all() + payload = { + "mode": "full", + "from_hash": from_hash, + "to_hash": current.hash, + "dataset": _version_payload(current), + "items": [_employee_payload(employee, include_data=include_data) for employee in employees], + } + if reason: + payload["reason"] = reason + return payload + + +def _delta_sync_payload(db: Session, previous: DatasetVersion, current: DatasetVersion, *, include_data: bool) -> dict: + previous_items = _items_by_profile_key(previous) + current_items = _items_by_profile_key(current) + employees = {employee.profile_key: employee for employee in db.scalars(select(Employee)).all()} + added = [] + updated = [] + dismissed = [] + removed = [] + + for profile_key, current_item in sorted(current_items.items()): + previous_item = previous_items.get(profile_key) + employee = employees.get(profile_key) + if not previous_item: + if employee: + added.append(_employee_payload(employee, include_data=include_data)) + continue + if previous_item.checksum == current_item.checksum and previous_item.status == current_item.status: + continue + if current_item.status == "dismissed": + dismissed.append(_tombstone(profile_key, current_item.status, employee)) + elif employee: + updated.append(_employee_payload(employee, include_data=include_data)) + + for profile_key, previous_item in sorted(previous_items.items()): + if profile_key not in current_items: + removed.append(_tombstone(profile_key, "removed", employees.get(profile_key), checksum=previous_item.checksum)) + + return { + "mode": "delta", + "from_hash": previous.hash, + "to_hash": current.hash, + "dataset": _version_payload(current), + "changes": { + "added": added, + "updated": updated, + "dismissed": dismissed, + "removed": removed, + }, + } + + +def _items_by_profile_key(version: DatasetVersion) -> dict[str, DatasetVersionItem]: + return {item.profile_key: item for item in version.items} + + +def _version_payload(version: DatasetVersion) -> dict: + return { + "hash": version.hash, + "previous_hash": version.previous_hash, + "created_at": version.created_at.isoformat() if version.created_at else None, + "crawl_run_id": version.crawl_run_id, + "employee_count": version.employee_count, + "active_count": version.active_count, + "dismissed_count": version.dismissed_count, + } + + +def _employee_marker(employee: Employee) -> EmployeeMarker: + return EmployeeMarker( + profile_key=employee.profile_key, + employee_id=employee.id, + status=employee.status, + checksum=employee.current_checksum or _payload_hash(employee.current_data or {}), + ) + + +def _dataset_hash(markers: list[EmployeeMarker]) -> str: + payload = [ + {"profile_key": marker.profile_key, "status": marker.status, "checksum": marker.checksum} + for marker in sorted(markers, key=lambda item: item.profile_key) + ] + return _payload_hash(payload) + + +def _payload_hash(value: object) -> str: + payload = json.dumps(value, ensure_ascii=False, sort_keys=True, separators=(",", ":"), default=str) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + +def _employee_payload(employee: Employee, *, include_data: bool) -> dict: + payload = { + "profile_key": employee.profile_key, + "profile_id": employee.profile_id, + "full_name": employee.full_name, + "status": employee.status, + "canonical_url": employee.canonical_url, + "last_seen_at": employee.last_seen_at.isoformat() if employee.last_seen_at else None, + "dismissed_at": employee.dismissed_at.isoformat() if employee.dismissed_at else None, + "checksum": employee.current_checksum or _payload_hash(employee.current_data or {}), + } + if include_data: + payload["data"] = employee.current_data + return payload + + +def _tombstone(profile_key: str, status: str, employee: Employee | None, *, checksum: str | None = None) -> dict: + payload = { + "profile_key": profile_key, + "status": status, + "checksum": checksum or (employee.current_checksum if employee else None), + } + if employee: + payload.update( + { + "profile_id": employee.profile_id, + "full_name": employee.full_name, + "canonical_url": employee.canonical_url, + "dismissed_at": employee.dismissed_at.isoformat() if employee.dismissed_at else None, + } + ) + return payload diff --git a/app/version.py b/app/version.py index 0c07749..82de748 100644 --- a/app/version.py +++ b/app/version.py @@ -1,3 +1,3 @@ -APP_VERSION = "0.4.7" -FRONTEND_VERSION = "0.4.7" -BACKEND_VERSION = "0.4.7" +APP_VERSION = "0.5.0" +FRONTEND_VERSION = "0.5.0" +BACKEND_VERSION = "0.5.0" diff --git a/migrations/004_dataset_versions.sql b/migrations/004_dataset_versions.sql new file mode 100644 index 0000000..aa6969a --- /dev/null +++ b/migrations/004_dataset_versions.sql @@ -0,0 +1,29 @@ +CREATE TABLE IF NOT EXISTS dataset_versions ( + id SERIAL PRIMARY KEY, + hash VARCHAR(64) NOT NULL UNIQUE, + previous_hash VARCHAR(64), + crawl_run_id INTEGER REFERENCES crawl_runs(id), + employee_count INTEGER NOT NULL DEFAULT 0, + active_count INTEGER NOT NULL DEFAULT 0, + dismissed_count INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS ix_dataset_versions_created_at + ON dataset_versions (created_at); + +CREATE TABLE IF NOT EXISTS dataset_version_items ( + id SERIAL PRIMARY KEY, + dataset_version_id INTEGER NOT NULL REFERENCES dataset_versions(id), + profile_key VARCHAR(255) NOT NULL, + employee_id INTEGER REFERENCES employees(id), + status VARCHAR(32) NOT NULL, + checksum VARCHAR(64) NOT NULL, + CONSTRAINT uq_dataset_version_items_version_profile UNIQUE (dataset_version_id, profile_key) +); + +CREATE INDEX IF NOT EXISTS ix_dataset_version_items_hash + ON dataset_version_items (dataset_version_id); + +CREATE INDEX IF NOT EXISTS ix_dataset_version_items_profile_key + ON dataset_version_items (profile_key); diff --git a/pyproject.toml b/pyproject.toml index 19e849e..9fb2306 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "miem-workers" -version = "0.4.7" +version = "0.5.0" description = "MIEM employees parser, admin API, and MCP server" requires-python = ">=3.11" dependencies = [ diff --git a/tests/test_api_mcp.py b/tests/test_api_mcp.py index 4146134..bd7c40e 100644 --- a/tests/test_api_mcp.py +++ b/tests/test_api_mcp.py @@ -1,3 +1,4 @@ +import json from datetime import datetime, timezone from types import SimpleNamespace @@ -19,7 +20,7 @@ def test_health_returns_versions(): response = client.get("/api/health") assert response.status_code == 200 - assert response.json()["backend_version"] == "0.4.7" + assert response.json()["backend_version"] == "0.5.0" def test_mcp_lists_tools_without_auth_and_ignores_auth_header(): @@ -50,7 +51,10 @@ def test_mcp_lists_tools_without_auth_and_ignores_auth_header(): assert without_auth.status_code == 200 assert with_auth.status_code == 200 - assert without_auth.json()["result"]["tools"][0]["name"] == "search_employees" + tool_names = {tool["name"] for tool in without_auth.json()["result"]["tools"]} + assert "search_employees" in tool_names + assert "get_service_info" in tool_names + assert "sync_employees" in tool_names assert any(tool["name"] == "get_crawl_run_details" for tool in without_auth.json()["result"]["tools"]) assert with_auth.json()["result"]["tools"] == without_auth.json()["result"]["tools"] @@ -108,6 +112,128 @@ def test_mcp_search_employees_returns_matching_employee(): app.dependency_overrides.clear() +def test_mcp_service_info_returns_tools_and_dataset_hash(): + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(engine) + Session = sessionmaker(bind=engine) + session = Session() + session.add( + Employee( + profile_key="staff:alpha", + profile_type="staff", + profile_id="alpha", + canonical_url="https://www.hse.ru/staff/alpha", + full_name="Alpha Person", + status="active", + current_checksum="a" * 64, + current_data={"sections": []}, + ) + ) + session.commit() + session.close() + + def override_db(): + db = Session() + try: + yield db + finally: + db.close() + + app.dependency_overrides[get_db] = override_db + client = TestClient(app) + + response = client.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": {"name": "get_service_info", "arguments": {}}}, + ) + + assert response.status_code == 200 + payload = json.loads(response.json()["result"]["content"][0]["text"]) + assert payload["service_name"] == "miem-employees" + assert payload["backend_version"] == "0.5.0" + assert payload["dataset"]["hash"] + assert any(tool["name"] == "sync_employees" for tool in payload["tools"]) + + app.dependency_overrides.clear() + + +def test_mcp_sync_employees_full_empty_and_unknown_hash_modes(): + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(engine) + Session = sessionmaker(bind=engine) + session = Session() + session.add( + Employee( + profile_key="staff:alpha", + profile_type="staff", + profile_id="alpha", + canonical_url="https://www.hse.ru/staff/alpha", + full_name="Alpha Person", + status="active", + current_checksum="a" * 64, + current_data={"sections": [{"type": "paragraphs"}]}, + ) + ) + session.commit() + session.close() + + def override_db(): + db = Session() + try: + yield db + finally: + db.close() + + app.dependency_overrides[get_db] = override_db + client = TestClient(app) + + full_response = client.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": {"name": "sync_employees", "arguments": {}}}, + ) + full_payload = json.loads(full_response.json()["result"]["content"][0]["text"]) + current_hash = full_payload["to_hash"] + + empty_response = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": {"name": "sync_employees", "arguments": {"client_hash": current_hash}}, + }, + ) + empty_payload = json.loads(empty_response.json()["result"]["content"][0]["text"]) + + unknown_response = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": {"name": "sync_employees", "arguments": {"client_hash": "missing"}}, + }, + ) + unknown_payload = json.loads(unknown_response.json()["result"]["content"][0]["text"]) + + assert full_payload["mode"] == "full" + assert full_payload["items"][0]["data"] == {"sections": [{"type": "paragraphs"}]} + assert empty_payload["mode"] == "delta" + assert empty_payload["changes"] == {"added": [], "updated": [], "dismissed": [], "removed": []} + assert unknown_payload["mode"] == "full" + assert unknown_payload["reason"] == "unknown_client_hash" + + app.dependency_overrides.clear() + + def test_mcp_get_crawl_run_details_returns_changes(): engine = create_engine( "sqlite:///:memory:", diff --git a/tests/test_dataset_versions.py b/tests/test_dataset_versions.py new file mode 100644 index 0000000..f927795 --- /dev/null +++ b/tests/test_dataset_versions.py @@ -0,0 +1,88 @@ +from datetime import datetime, timezone + +from app.models import Employee +from app.services.dataset_versions import get_or_create_current_version, sync_employees_payload + + +def _employee(profile_key: str, checksum: str, *, status: str = "active") -> Employee: + return Employee( + profile_key=profile_key, + profile_type=profile_key.split(":", 1)[0], + profile_id=profile_key.split(":", 1)[1], + canonical_url=f"https://www.hse.ru/{profile_key}", + full_name=profile_key, + status=status, + first_seen_at=datetime.now(timezone.utc), + last_seen_at=datetime.now(timezone.utc), + current_data={"profile_key": profile_key}, + current_checksum=checksum, + ) + + +def test_dataset_version_hash_is_stable_for_same_employee_state(db_session): + db_session.add(_employee("staff:alpha", "a" * 64)) + db_session.commit() + + first = get_or_create_current_version(db_session) + db_session.commit() + second = get_or_create_current_version(db_session) + + assert second.id == first.id + assert second.hash == first.hash + assert second.employee_count == 1 + + +def test_dataset_version_hash_changes_when_employee_checksum_changes(db_session): + employee = _employee("staff:alpha", "a" * 64) + db_session.add(employee) + db_session.commit() + first = get_or_create_current_version(db_session) + db_session.commit() + + employee.current_checksum = "b" * 64 + db_session.commit() + second = get_or_create_current_version(db_session) + + assert second.hash != first.hash + assert second.previous_hash == first.hash + + +def test_sync_employees_diff_spans_multiple_intermediate_versions(db_session): + alpha = _employee("staff:alpha", "a" * 64) + db_session.add(alpha) + db_session.commit() + first = get_or_create_current_version(db_session) + db_session.commit() + + beta = _employee("staff:beta", "b" * 64) + db_session.add(beta) + db_session.commit() + get_or_create_current_version(db_session) + db_session.commit() + + alpha.current_checksum = "c" * 64 + alpha.current_data = {"profile_key": "staff:alpha", "changed": True} + db_session.commit() + + payload = sync_employees_payload(db_session, client_hash=first.hash, include_data=False) + + assert payload["mode"] == "delta" + assert [item["profile_key"] for item in payload["changes"]["added"]] == ["staff:beta"] + assert [item["profile_key"] for item in payload["changes"]["updated"]] == ["staff:alpha"] + assert payload["changes"]["dismissed"] == [] + assert payload["changes"]["removed"] == [] + + +def test_sync_employees_reports_dismissed_as_tombstone(db_session): + alpha = _employee("staff:alpha", "a" * 64) + db_session.add(alpha) + db_session.commit() + first = get_or_create_current_version(db_session) + db_session.commit() + + alpha.status = "dismissed" + db_session.commit() + payload = sync_employees_payload(db_session, client_hash=first.hash, include_data=False) + + assert payload["changes"]["dismissed"][0]["profile_key"] == "staff:alpha" + assert payload["changes"]["dismissed"][0]["status"] == "dismissed"