from __future__ import annotations import json import threading import time from datetime import UTC, datetime from .db import get_conn from .scanner import HostResult class ScanState: def __init__(self): self._lock = threading.Lock() self.running = False self.current_scan_id: int | None = None self.subnet: str | None = None self.total_hosts = 0 self.processed_hosts = 0 self.saved_hosts = 0 self.current_host: str | None = None self.started_monotonic: float | None = None def start(self, scan_id: int, subnet: str) -> bool: with self._lock: if self.running: return False self.running = True self.current_scan_id = scan_id self.subnet = subnet self.total_hosts = 0 self.processed_hosts = 0 self.saved_hosts = 0 self.current_host = None self.started_monotonic = time.monotonic() return True def set_total_hosts(self, total_hosts: int) -> None: with self._lock: self.total_hosts = max(total_hosts, 0) def set_current_host(self, host_ip: str | None) -> None: with self._lock: self.current_host = host_ip def update_progress(self, processed_hosts: int, saved_hosts: int) -> None: with self._lock: self.processed_hosts = max(processed_hosts, 0) self.saved_hosts = max(saved_hosts, 0) def finish(self) -> None: with self._lock: self.running = False self.current_scan_id = None self.current_host = None self.started_monotonic = None def snapshot(self) -> dict: with self._lock: percent = 0 if self.total_hosts > 0: percent = int((self.processed_hosts / self.total_hosts) * 100) elapsed_seconds = 0 eta_seconds = None if self.running and self.started_monotonic is not None: elapsed_seconds = int(max(time.monotonic() - self.started_monotonic, 0)) if self.processed_hosts > 0 and self.total_hosts > self.processed_hosts: rate = self.processed_hosts / max(elapsed_seconds, 1) remaining = self.total_hosts - self.processed_hosts eta_seconds = int(remaining / max(rate, 1e-9)) return { "running": self.running, "scan_id": self.current_scan_id, "subnet": self.subnet, "total_hosts": self.total_hosts, "processed_hosts": self.processed_hosts, "saved_hosts": self.saved_hosts, "current_host": self.current_host, "percent": min(max(percent, 0), 100), "elapsed_seconds": elapsed_seconds, "eta_seconds": eta_seconds, } scan_state = ScanState() def now_iso() -> str: return datetime.now(UTC).isoformat() def create_scan(subnet: str) -> int: started = now_iso() with get_conn() as conn: cur = conn.execute( "INSERT INTO scans(subnet, started_at, status) VALUES (?, ?, ?)", (subnet, started, "running"), ) return int(cur.lastrowid) def complete_scan(scan_id: int, status: str, host_count: int, notes: str | None = None) -> None: with get_conn() as conn: conn.execute( """ UPDATE scans SET completed_at = ?, status = ?, host_count = ?, notes = ? WHERE id = ? """, (now_iso(), status, host_count, notes, scan_id), ) def upsert_host(scan_id: int, host: HostResult) -> int: timestamp = now_iso() with get_conn() as conn: row = conn.execute("SELECT id FROM devices WHERE ip = ?", (host.ip,)).fetchone() if row: device_id = int(row["id"]) conn.execute( """ UPDATE devices SET hostname = ?, mac = ?, vendor = ?, os_name = ?, last_seen = ?, is_active = 1, last_scan_id = ? WHERE id = ? """, ( host.hostname, host.mac, host.vendor, host.os_name, timestamp, scan_id, device_id, ), ) else: cur = conn.execute( """ INSERT INTO devices(ip, hostname, mac, vendor, os_name, first_seen, last_seen, is_active, last_scan_id) VALUES (?, ?, ?, ?, ?, ?, ?, 1, ?) """, ( host.ip, host.hostname, host.mac, host.vendor, host.os_name, timestamp, timestamp, scan_id, ), ) device_id = int(cur.lastrowid) seen_ports = {(p.port, p.protocol) for p in host.ports} existing_ports = conn.execute( "SELECT id, port, protocol FROM ports WHERE device_id = ?", (device_id,), ).fetchall() for p in host.ports: headers_json = json.dumps(p.headers) existing = conn.execute( "SELECT id FROM ports WHERE device_id = ? AND port = ? AND protocol = ?", (device_id, p.port, p.protocol), ).fetchone() if existing: conn.execute( """ UPDATE ports SET state = ?, service = ?, product = ?, version = ?, extra_info = ?, banner = ?, headers_json = ?, last_seen = ?, is_open = 1 WHERE id = ? """, ( p.state, p.service, p.product, p.version, p.extra_info, p.banner, headers_json, timestamp, int(existing["id"]), ), ) else: conn.execute( """ INSERT INTO ports(device_id, port, protocol, state, service, product, version, extra_info, banner, headers_json, first_seen, last_seen, is_open) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 1) """, ( device_id, p.port, p.protocol, p.state, p.service, p.product, p.version, p.extra_info, p.banner, headers_json, timestamp, timestamp, ), ) for rowp in existing_ports: pair = (int(rowp["port"]), rowp["protocol"]) if pair not in seen_ports: conn.execute( "UPDATE ports SET is_open = 0, state = 'closed', last_seen = ? WHERE id = ?", (timestamp, int(rowp["id"])), ) return device_id def mark_missing_devices(scan_id: int) -> None: with get_conn() as conn: conn.execute( """ UPDATE devices SET is_active = 0 WHERE last_scan_id IS NOT NULL AND last_scan_id <> ? """, (scan_id,), ) def fetch_devices() -> list[dict]: with get_conn() as conn: rows = conn.execute( """ SELECT id, ip, hostname, os_name, mac, vendor, is_active, first_seen, last_seen FROM devices ORDER BY is_active DESC, ip ASC """ ).fetchall() return [dict(row) for row in rows] def fetch_device(device_id: int) -> dict | None: with get_conn() as conn: device = conn.execute( """ SELECT id, ip, hostname, os_name, mac, vendor, is_active, first_seen, last_seen FROM devices WHERE id = ? """, (device_id,), ).fetchone() if not device: return None ports = conn.execute( """ SELECT id, port, protocol, state, service, product, version, extra_info, banner, headers_json, first_seen, last_seen, is_open FROM ports WHERE device_id = ? ORDER BY is_open DESC, port ASC """, (device_id,), ).fetchall() result = dict(device) parsed_ports = [] for row in ports: p = dict(row) p["headers"] = json.loads(p["headers_json"] or "{}") p.pop("headers_json", None) parsed_ports.append(p) result["ports"] = parsed_ports return result def fetch_scans(limit: int = 20) -> list[dict]: with get_conn() as conn: rows = conn.execute( """ SELECT id, subnet, started_at, completed_at, status, host_count, notes FROM scans ORDER BY id DESC LIMIT ? """, (limit,), ).fetchall() return [dict(row) for row in rows]