Files
NetTrak/app/service.py
2026-03-08 15:40:11 -05:00

297 lines
9.4 KiB
Python

from __future__ import annotations

import ipaddress
import json
import threading
import time
from datetime import UTC, datetime

from .db import get_conn
from .scanner import HostResult
class ScanState:
def __init__(self):
self._lock = threading.Lock()
self.running = False
self.current_scan_id: int | None = None
self.subnet: str | None = None
self.total_hosts = 0
self.processed_hosts = 0
self.saved_hosts = 0
self.current_host: str | None = None
self.started_monotonic: float | None = None
def start(self, scan_id: int, subnet: str) -> bool:
with self._lock:
if self.running:
return False
self.running = True
self.current_scan_id = scan_id
self.subnet = subnet
self.total_hosts = 0
self.processed_hosts = 0
self.saved_hosts = 0
self.current_host = None
self.started_monotonic = time.monotonic()
return True
def set_total_hosts(self, total_hosts: int) -> None:
with self._lock:
self.total_hosts = max(total_hosts, 0)
def set_current_host(self, host_ip: str | None) -> None:
with self._lock:
self.current_host = host_ip
def update_progress(self, processed_hosts: int, saved_hosts: int) -> None:
with self._lock:
self.processed_hosts = max(processed_hosts, 0)
self.saved_hosts = max(saved_hosts, 0)
def finish(self) -> None:
with self._lock:
self.running = False
self.current_scan_id = None
self.current_host = None
self.started_monotonic = None
def snapshot(self) -> dict:
with self._lock:
percent = 0
if self.total_hosts > 0:
percent = int((self.processed_hosts / self.total_hosts) * 100)
elapsed_seconds = 0
eta_seconds = None
if self.running and self.started_monotonic is not None:
elapsed_seconds = int(max(time.monotonic() - self.started_monotonic, 0))
if self.processed_hosts > 0 and self.total_hosts > self.processed_hosts:
rate = self.processed_hosts / max(elapsed_seconds, 1)
remaining = self.total_hosts - self.processed_hosts
eta_seconds = int(remaining / max(rate, 1e-9))
return {
"running": self.running,
"scan_id": self.current_scan_id,
"subnet": self.subnet,
"total_hosts": self.total_hosts,
"processed_hosts": self.processed_hosts,
"saved_hosts": self.saved_hosts,
"current_host": self.current_host,
"percent": min(max(percent, 0), 100),
"elapsed_seconds": elapsed_seconds,
"eta_seconds": eta_seconds,
}
# Module-level ScanState instance; every importer of this module shares it.
# Because ScanState.start() returns False while a scan is running, this
# singleton limits the process to one tracked scan at a time.
scan_state = ScanState()
def now_iso() -> str:
return datetime.now(UTC).isoformat()
def create_scan(subnet: str) -> int:
    """Insert a new scan row for *subnet* with status ``"running"``.

    The start timestamp is taken from :func:`now_iso`.

    Returns:
        The id of the newly inserted ``scans`` row.
    """
    params = (subnet, now_iso(), "running")
    with get_conn() as conn:
        cursor = conn.execute(
            "INSERT INTO scans(subnet, started_at, status) VALUES (?, ?, ?)",
            params,
        )
        new_id = int(cursor.lastrowid)
        return new_id
def complete_scan(scan_id: int, status: str, host_count: int, notes: str | None = None) -> None:
    """Close out scan *scan_id*.

    Stores the completion time (from :func:`now_iso`), the final *status*, the
    number of hosts found, and optional free-form *notes*.
    """
    with get_conn() as conn:
        finished_at = now_iso()
        values = (finished_at, status, host_count, notes, scan_id)
        conn.execute(
            """
            UPDATE scans
            SET completed_at = ?, status = ?, host_count = ?, notes = ?
            WHERE id = ?
            """,
            values,
        )
def upsert_host(scan_id: int, host: HostResult) -> int:
    """Insert or refresh the device row for *host* and reconcile its ports.

    The device is matched on ``host.ip``. Every port in ``host.ports`` is
    inserted or updated and flagged open. Stored ports for this device that
    are currently open but absent from ``host.ports`` are flagged closed.

    Args:
        scan_id: Scan that produced *host*; written to ``devices.last_scan_id``.
        host: Scanner result for a single host.

    Returns:
        The ``devices.id`` of the inserted or updated row.
    """
    timestamp = now_iso()
    with get_conn() as conn:
        device_id = _upsert_device_row(conn, scan_id, host, timestamp)
        _sync_port_rows(conn, device_id, host, timestamp)
    return device_id


def _upsert_device_row(conn, scan_id: int, host: HostResult, timestamp: str) -> int:
    """Insert or update the ``devices`` row keyed by ``host.ip``; return its id."""
    # NOTE(review): conn is presumably a sqlite3.Connection whose rows can be
    # indexed by column name (e.g. sqlite3.Row); confirm in .db.get_conn.
    row = conn.execute("SELECT id FROM devices WHERE ip = ?", (host.ip,)).fetchone()
    if row:
        device_id = int(row["id"])
        conn.execute(
            """
            UPDATE devices
            SET hostname = ?, mac = ?, vendor = ?, os_name = ?,
                last_seen = ?, is_active = 1, last_scan_id = ?
            WHERE id = ?
            """,
            (
                host.hostname,
                host.mac,
                host.vendor,
                host.os_name,
                timestamp,
                scan_id,
                device_id,
            ),
        )
        return device_id
    cur = conn.execute(
        """
        INSERT INTO devices(ip, hostname, mac, vendor, os_name, first_seen, last_seen, is_active, last_scan_id)
        VALUES (?, ?, ?, ?, ?, ?, ?, 1, ?)
        """,
        (
            host.ip,
            host.hostname,
            host.mac,
            host.vendor,
            host.os_name,
            timestamp,
            timestamp,
            scan_id,
        ),
    )
    return int(cur.lastrowid)


def _sync_port_rows(conn, device_id: int, host: HostResult, timestamp: str) -> None:
    """Upsert ``host.ports`` as open and close stored open ports not reported now."""
    # Load every stored port for the device once, rather than issuing one
    # SELECT per reported port.
    stored_rows = conn.execute(
        "SELECT id, port, protocol, is_open FROM ports WHERE device_id = ?",
        (device_id,),
    ).fetchall()
    port_ids: dict[tuple[int, str], int] = {}
    for row in stored_rows:
        # If duplicate (port, protocol) rows exist, update the first one returned.
        port_ids.setdefault((int(row["port"]), row["protocol"]), int(row["id"]))

    seen: set[tuple[int, str]] = set()
    for p in host.ports:
        key = (int(p.port), p.protocol)
        seen.add(key)
        headers_json = json.dumps(p.headers)
        port_id = port_ids.get(key)
        if port_id is not None:
            conn.execute(
                """
                UPDATE ports
                SET state = ?, service = ?, product = ?, version = ?, extra_info = ?, banner = ?,
                    headers_json = ?, last_seen = ?, is_open = 1
                WHERE id = ?
                """,
                (
                    p.state,
                    p.service,
                    p.product,
                    p.version,
                    p.extra_info,
                    p.banner,
                    headers_json,
                    timestamp,
                    port_id,
                ),
            )
        else:
            cur = conn.execute(
                """
                INSERT INTO ports(device_id, port, protocol, state, service, product, version, extra_info, banner,
                                  headers_json, first_seen, last_seen, is_open)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 1)
                """,
                (
                    device_id,
                    p.port,
                    p.protocol,
                    p.state,
                    p.service,
                    p.product,
                    p.version,
                    p.extra_info,
                    p.banner,
                    headers_json,
                    timestamp,
                    timestamp,
                ),
            )
            # Remember the new row so a repeated (port, protocol) later in
            # host.ports updates it instead of inserting a duplicate.
            port_ids[key] = int(cur.lastrowid)

    for row in stored_rows:
        if (int(row["port"]), row["protocol"]) in seen:
            continue
        # Only transition ports that are still open: last_seen then records
        # when the closure was first observed instead of advancing on every
        # later scan of a port that never reappears.
        if row["is_open"]:
            conn.execute(
                "UPDATE ports SET is_open = 0, state = 'closed', last_seen = ? WHERE id = ?",
                (timestamp, int(row["id"])),
            )
def mark_missing_devices(scan_id: int) -> None:
    """Set ``is_active = 0`` on devices whose ``last_scan_id`` is not *scan_id*.

    Devices with a NULL ``last_scan_id`` are left untouched.
    """
    statement = """
        UPDATE devices
        SET is_active = 0
        WHERE last_scan_id IS NOT NULL
          AND last_scan_id <> ?
    """
    with get_conn() as conn:
        conn.execute(statement, (scan_id,))
def fetch_devices() -> list[dict]:
with get_conn() as conn:
rows = conn.execute(
"""
SELECT id, ip, hostname, os_name, mac, vendor, is_active, first_seen, last_seen
FROM devices
ORDER BY is_active DESC, ip ASC
"""
).fetchall()
return [dict(row) for row in rows]
def fetch_device(device_id: int) -> dict | None:
    """Return one device together with its ports, or ``None`` if it does not exist.

    Each port dict has its raw ``headers_json`` column replaced by a decoded
    ``headers`` entry; an empty or NULL column decodes to ``{}``. Ports are
    listed open-first, then by ascending port number.
    """

    def decode_port(row) -> dict:
        entry = dict(row)
        raw_headers = entry.pop("headers_json")
        entry["headers"] = json.loads(raw_headers or "{}")
        return entry

    with get_conn() as conn:
        device_row = conn.execute(
            """
            SELECT id, ip, hostname, os_name, mac, vendor, is_active, first_seen, last_seen
            FROM devices
            WHERE id = ?
            """,
            (device_id,),
        ).fetchone()
        if not device_row:
            return None
        port_rows = conn.execute(
            """
            SELECT id, port, protocol, state, service, product, version, extra_info, banner, headers_json,
                   first_seen, last_seen, is_open
            FROM ports
            WHERE device_id = ?
            ORDER BY is_open DESC, port ASC
            """,
            (device_id,),
        ).fetchall()
        detail = dict(device_row)
        detail["ports"] = [decode_port(r) for r in port_rows]
        return detail
def fetch_scans(limit: int = 20) -> list[dict]:
    """Return up to *limit* scans as plain dicts, highest id (most recent) first."""
    query = """
        SELECT id, subnet, started_at, completed_at, status, host_count, notes
        FROM scans
        ORDER BY id DESC
        LIMIT ?
    """
    with get_conn() as conn:
        records = conn.execute(query, (limit,)).fetchall()
        return list(map(dict, records))