"""Scan run-state tracking and persistence helpers for scans, devices and ports.

Every database function opens its own connection via ``get_conn`` from
``.db``.  NOTE(review): this module relies on ``get_conn()`` being a context
manager that yields a connection whose rows support access by column name and
that persists writes when the ``with`` block exits -- confirm in ``.db``.
"""

from __future__ import annotations

import json
import threading
from datetime import UTC, datetime

from .db import get_conn
from .scanner import HostResult


class ScanState:
    """Lock-guarded record of whether a scan is running and which one."""

    def __init__(self):
        self._lock = threading.Lock()
        self.running = False
        self.current_scan_id: int | None = None

    def start(self, scan_id: int) -> bool:
        """Mark *scan_id* as the running scan.

        Returns ``False`` and leaves the state untouched when a scan is
        already marked as running; returns ``True`` otherwise.
        """
        with self._lock:
            if self.running:
                return False
            self.running = True
            self.current_scan_id = scan_id
            return True

    def finish(self) -> None:
        """Clear the running flag and the current scan id."""
        with self._lock:
            self.running = False
            self.current_scan_id = None


# Module-level instance shared by importers of this module.
scan_state = ScanState()


def now_iso() -> str:
    """Return the current UTC time as an ISO 8601 string."""
    return datetime.now(UTC).isoformat()


def create_scan(subnet: str) -> int:
    """Insert a ``scans`` row with status ``"running"`` and return its id."""
    started = now_iso()
    with get_conn() as conn:
        cur = conn.execute(
            "INSERT INTO scans(subnet, started_at, status) VALUES (?, ?, ?)",
            (subnet, started, "running"),
        )
        return int(cur.lastrowid)


def complete_scan(
    scan_id: int,
    status: str,
    host_count: int,
    notes: str | None = None,
) -> None:
    """Stamp the scan with a completion time, final status, host count and notes."""
    with get_conn() as conn:
        conn.execute(
            """
            UPDATE scans
            SET completed_at = ?, status = ?, host_count = ?, notes = ?
            WHERE id = ?
            """,
            (now_iso(), status, host_count, notes, scan_id),
        )


def _upsert_device(conn, scan_id: int, host: HostResult, timestamp: str) -> int:
    """Update the device row matching ``host.ip`` or insert one; return its id."""
    row = conn.execute(
        "SELECT id FROM devices WHERE ip = ?", (host.ip,)
    ).fetchone()
    if row:
        device_id = int(row["id"])
        conn.execute(
            """
            UPDATE devices
            SET hostname = ?, mac = ?, vendor = ?, os_name = ?, last_seen = ?,
                is_active = 1, last_scan_id = ?
            WHERE id = ?
            """,
            (
                host.hostname,
                host.mac,
                host.vendor,
                host.os_name,
                timestamp,
                scan_id,
                device_id,
            ),
        )
        return device_id

    cur = conn.execute(
        """
        INSERT INTO devices(ip, hostname, mac, vendor, os_name, first_seen,
                            last_seen, is_active, last_scan_id)
        VALUES (?, ?, ?, ?, ?, ?, ?, 1, ?)
        """,
        (
            host.ip,
            host.hostname,
            host.mac,
            host.vendor,
            host.os_name,
            timestamp,
            timestamp,
            scan_id,
        ),
    )
    return int(cur.lastrowid)


def _sync_ports(conn, device_id: int, host: HostResult, timestamp: str) -> None:
    """Upsert every port in ``host.ports`` and close stored ports not reported."""
    seen_ports = {(p.port, p.protocol) for p in host.ports}
    existing_ports = conn.execute(
        "SELECT id, port, protocol FROM ports WHERE device_id = ?",
        (device_id,),
    ).fetchall()

    # One in-memory index replaces a SELECT per scanned port.  setdefault keeps
    # the first row returned for a key, should duplicates already be stored.
    port_ids: dict[tuple, int] = {}
    for stored in existing_ports:
        port_ids.setdefault(
            (int(stored["port"]), stored["protocol"]), int(stored["id"])
        )

    for p in host.ports:
        headers_json = json.dumps(p.headers)
        key = (p.port, p.protocol)
        port_id = port_ids.get(key)
        if port_id is not None:
            conn.execute(
                """
                UPDATE ports
                SET state = ?, service = ?, product = ?, version = ?,
                    extra_info = ?, banner = ?, headers_json = ?,
                    last_seen = ?, is_open = 1
                WHERE id = ?
                """,
                (
                    p.state,
                    p.service,
                    p.product,
                    p.version,
                    p.extra_info,
                    p.banner,
                    headers_json,
                    timestamp,
                    port_id,
                ),
            )
        else:
            cur = conn.execute(
                """
                INSERT INTO ports(device_id, port, protocol, state, service,
                                  product, version, extra_info, banner,
                                  headers_json, first_seen, last_seen, is_open)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 1)
                """,
                (
                    device_id,
                    p.port,
                    p.protocol,
                    p.state,
                    p.service,
                    p.product,
                    p.version,
                    p.extra_info,
                    p.banner,
                    headers_json,
                    timestamp,
                    timestamp,
                ),
            )
            # Remember the new row so a port repeated within host.ports is
            # updated on its second occurrence instead of inserted twice.
            port_ids[key] = int(cur.lastrowid)

    # Only rows that existed before this call can have disappeared.
    for stored in existing_ports:
        pair = (int(stored["port"]), stored["protocol"])
        if pair not in seen_ports:
            conn.execute(
                "UPDATE ports SET is_open = 0, state = 'closed', last_seen = ? WHERE id = ?",
                (timestamp, int(stored["id"])),
            )


def upsert_host(scan_id: int, host: HostResult) -> int:
    """Record one scanned host and its ports; return the device id.

    The device is matched on ``host.ip``: an existing row is refreshed and
    marked active, otherwise a new row is inserted.  Each port in
    ``host.ports`` is then updated or inserted as open, and any port already
    stored for the device but absent from ``host.ports`` is marked closed.
    All statements run on a single connection and share one timestamp.
    """
    timestamp = now_iso()
    with get_conn() as conn:
        device_id = _upsert_device(conn, scan_id, host, timestamp)
        _sync_ports(conn, device_id, host, timestamp)
        return device_id


def mark_missing_devices(scan_id: int) -> None:
    """Mark inactive every device whose last scan id is set and differs from *scan_id*.

    Devices whose ``last_scan_id`` is NULL are left untouched.
    """
    with get_conn() as conn:
        conn.execute(
            """
            UPDATE devices
            SET is_active = 0
            WHERE last_scan_id IS NOT NULL AND last_scan_id <> ?
            """,
            (scan_id,),
        )


def fetch_devices() -> list[dict]:
    """Return all devices as dicts, active ones first, then ordered by ip."""
    with get_conn() as conn:
        rows = conn.execute(
            """
            SELECT id, ip, hostname, os_name, mac, vendor, is_active,
                   first_seen, last_seen
            FROM devices
            ORDER BY is_active DESC, ip ASC
            """
        ).fetchall()
        return [dict(row) for row in rows]


def fetch_device(device_id: int) -> dict | None:
    """Return one device with its ports, or ``None`` when the id is unknown.

    The result carries a ``"ports"`` list ordered open-first, then by port
    number.  In each port dict the stored ``headers_json`` text is replaced by
    a decoded ``"headers"`` value; an empty or NULL column decodes to ``{}``.
    """
    with get_conn() as conn:
        device = conn.execute(
            """
            SELECT id, ip, hostname, os_name, mac, vendor, is_active,
                   first_seen, last_seen
            FROM devices
            WHERE id = ?
            """,
            (device_id,),
        ).fetchone()
        if not device:
            return None

        ports = conn.execute(
            """
            SELECT id, port, protocol, state, service, product, version,
                   extra_info, banner, headers_json, first_seen, last_seen,
                   is_open
            FROM ports
            WHERE device_id = ?
            ORDER BY is_open DESC, port ASC
            """,
            (device_id,),
        ).fetchall()

        result = dict(device)
        parsed_ports = []
        for row in ports:
            port = dict(row)
            port["headers"] = json.loads(port["headers_json"] or "{}")
            port.pop("headers_json", None)
            parsed_ports.append(port)
        result["ports"] = parsed_ports
        return result


def fetch_scans(limit: int = 20) -> list[dict]:
    """Return up to *limit* scans as dicts, highest id first."""
    with get_conn() as conn:
        rows = conn.execute(
            """
            SELECT id, subnet, started_at, completed_at, status, host_count,
                   notes
            FROM scans
            ORDER BY id DESC
            LIMIT ?
            """,
            (limit,),
        ).fetchall()
        return [dict(row) for row in rows]