Coverage for ids_iforest/detect.py: 61%
270 statements
coverage.py v7.10.6, created at 2025-09-03 16:19 +0000

"""Live anomaly detection using tshark (Docker/Windows friendly).

- Captures packets via tshark in "fields" mode (no pyshark).
- Aggregates packets into time-windowed flows.
- Scores flows using a trained Isolation Forest + scaler.
- Writes anomalies to alerts.jsonl (Grafana/Promtail) and alerts.csv.
"""
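
# Typical invocation (a sketch; the exact entry point depends on how the package is
# installed and launched, e.g. inside the capture container):
#   python -m ids_iforest.detect --config config/config.yml --alerts-csv logs/alerts.csv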

from __future__ import annotations

import argparse
import csv
import datetime as _dt
import os
import shutil
import statistics as stats
import subprocess
import threading
import time
from typing import Optional, Tuple, Dict, Any, Iterable, List

import pandas as pd  # type: ignore

from .utils import (
    load_config,
    get_logger,
    load_model,
    load_thresholds,
)

__all__ = ["main"]


# ----------------------------
# Scoring & alert persistence
# ----------------------------
def _score_flows(
    model: Any,
    scaler: Any,
    df: pd.DataFrame,
    red_thr: float,
    yellow_thr: float,
) -> Iterable[Tuple[str, Dict[str, Any]]]:
    """Score each flow and yield (level, alert_dict) for anomalies."""
    if df.empty:
        return

    numeric_cols = [
        "bidirectional_packets",
        "bidirectional_bytes",
        "mean_packet_size",
        "std_packet_size",
        "flow_duration",
    ]
    extended = [
        "tcp_syn_count",
        "tcp_fin_count",
        "tcp_rst_count",
        "iat_mean",
        "iat_std",
        "bytes_per_packet",
        "packets_per_second",
    ]
    for col in extended:
        if col in df.columns:
            numeric_cols.append(col)

    X = scaler.transform(df[numeric_cols].fillna(0.0).astype(float).values)
    scores = model.decision_function(X)

    # Use timezone-aware datetime objects for UTC time (fixing deprecation warning)
    try:
        # Python 3.11+ has the UTC constant
        now_str = _dt.datetime.now(_dt.UTC).isoformat()
    except AttributeError:
        # Fallback for older Python versions
        now_str = _dt.datetime.now(_dt.timezone.utc).isoformat()

    for idx, s in enumerate(scores):
        score_f = float(s)
        if score_f < yellow_thr:
            level = "RED" if score_f < red_thr else "YELLOW"
            row = df.iloc[idx]
            yield level, {
                "timestamp": now_str,
                "src_ip": row["src_ip"],
                "dst_ip": row["dst_ip"],
                "src_port": int(row["src_port"]),
                "dst_port": int(row["dst_port"]),
                "protocol": row["protocol"],
                "score": score_f,
                "level": level,
            }
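
# Note: sklearn's IsolationForest.decision_function returns lower (more negative)
# scores for more anomalous samples, so the thresholds are ordered red_thr < yellow_thr
# and anything below yellow_thr is flagged. A minimal sketch of calling _score_flows
# directly, assuming `model`, `scaler`, `red_thr`, `yellow_thr` come from load_model /
# load_thresholds and the model was trained on the basic feature set:
#
#   df = pd.DataFrame([{
#       "src_ip": "10.0.0.5", "dst_ip": "10.0.0.9",
#       "src_port": 44321, "dst_port": 80, "protocol": "tcp",
#       "bidirectional_packets": 3, "bidirectional_bytes": 180,
#       "mean_packet_size": 60.0, "std_packet_size": 0.0, "flow_duration": 0.5,
#   }])
#   for level, alert in _score_flows(model, scaler, df, red_thr, yellow_thr):
#       print(level, alert["score"])
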
def _write_alert_csv(
    alerts: Iterable[Tuple[str, Dict[str, Any]]], csv_path: str
) -> None:
    """Append alert rows to the CSV file (header on first write)."""
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)
    exists = os.path.exists(csv_path)
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "timestamp",
                "src_ip",
                "dst_ip",
                "src_port",
                "dst_port",
                "protocol",
                "score",
                "level",
            ],
        )
        if not exists:
            writer.writeheader()
        for _, alert in alerts:
            writer.writerow(alert)


def _process_dataframe(
    df: pd.DataFrame,
    model: Any,
    scaler: Any,
    red_thr: float,
    yellow_thr: float,
    logger: Any,
    csv_path: str,
) -> None:
    """Score flows in `df` and log any anomalies."""
    # Score all flows to log the score distribution for this window
    if not df.empty:
        numeric_cols = [
            "bidirectional_packets",
            "bidirectional_bytes",
            "mean_packet_size",
            "std_packet_size",
            "flow_duration",
            "tcp_syn_count",
            "tcp_fin_count",
            "tcp_rst_count",
            "iat_mean",
            "iat_std",
            "bytes_per_packet",
            "packets_per_second",
        ]
        # Only use columns actually present; _flows_to_df adds the extended
        # features only when feature_set == "extended".
        present_cols = [c for c in numeric_cols if c in df.columns]
        X = scaler.transform(df[present_cols].fillna(0.0).astype(float).values)
        scores = model.decision_function(X)
        min_s, max_s = float(min(scores)), float(max(scores))
        mean_s = float(sum(scores) / len(scores))
        # Optionally compute percentiles with numpy or statistics.quantiles
        logger.info(
            f"Score stats: count={len(scores)} min={min_s:.4f} "
            f"mean={mean_s:.4f} max={max_s:.4f}"
        )

    alerts = list(_score_flows(model, scaler, df, red_thr, yellow_thr))
    logger.info(f"Scored {len(df)} flows, produced {len(alerts)} alerts")

    json_path = os.path.join(os.path.dirname(csv_path), "alerts.jsonl")
    if alerts:
        from .logging_utils import append_json_alert  # local import to avoid cycles

        for level, alert in alerts:
            append_json_alert(json_path, **alert)
            if level == "RED":
                logger.error(
                    f"ANOMALY RED {alert['src_ip']}:{alert['src_port']} -> "
                    f"{alert['dst_ip']}:{alert['dst_port']} {alert['protocol'].upper()} "
                    f"score={alert['score']:.4f}"
                )
            else:
                logger.warning(
                    f"Anomaly YELLOW {alert['src_ip']}:{alert['src_port']} -> "
                    f"{alert['dst_ip']}:{alert['dst_port']} {alert['protocol'].upper()} "
                    f"score={alert['score']:.4f}"
                )
        _write_alert_csv(alerts, csv_path)
        logger.info(f"Wrote {len(alerts)} alert(s) to {csv_path} and {json_path}")
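
# Each alert is also appended to alerts.jsonl via append_json_alert; assuming that
# helper writes one JSON object per line, a Promtail-tailed entry would look roughly
# like this (values illustrative):
#   {"timestamp": "2025-09-03T16:19:00+00:00", "src_ip": "10.0.0.5", "dst_ip": "10.0.0.9",
#    "src_port": 44321, "dst_port": 80, "protocol": "tcp", "score": -0.1234, "level": "RED"}
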
# ------------------------------------------
# Live detection via tshark (robust choice)
# ------------------------------------------
def _flows_to_df(
    flows: Dict[Tuple[int, str, str, int, int, str], Dict[str, Any]],
    feature_set: str,
) -> pd.DataFrame:
    """Build a pandas DataFrame with the columns the model expects."""
    rows: List[Dict[str, Any]] = []
    for st in flows.values():
        packets = int(st["packets"])
        bytes_ = int(st["bytes"])
        sizes = st["sizes"]
        dur = max(1e-6, float(st["last_ts"] - st["first_ts"]))
        mean_sz = float(sum(sizes) / packets) if packets > 0 else 0.0
        std_sz = float(stats.pstdev(sizes) if len(sizes) > 1 else 0.0) if packets > 0 else 0.0
        iat = st["iat"]
        iat_mean = float(sum(iat) / len(iat)) if iat else 0.0
        iat_std = float(stats.pstdev(iat) if len(iat) > 1 else 0.0) if iat else 0.0
        bpp = float(bytes_ / packets) if packets > 0 else 0.0
        pps = float(packets / dur) if dur > 0 else float(packets)

        row = {
            "src_ip": st["src_ip"],
            "dst_ip": st["dst_ip"],
            "src_port": st["src_port"],
            "dst_port": st["dst_port"],
            "protocol": st["protocol"],
            "bidirectional_packets": packets,
            "bidirectional_bytes": bytes_,
            "mean_packet_size": mean_sz,
            "std_packet_size": std_sz,
            "flow_duration": dur,
        }
        if feature_set == "extended":
            row.update({
                "tcp_syn_count": int(st["tcp_syn"]),
                "tcp_fin_count": int(st["tcp_fin"]),
                "tcp_rst_count": int(st["tcp_rst"]),
                "iat_mean": iat_mean,
                "iat_std": iat_std,
                "bytes_per_packet": bpp,
                "packets_per_second": pps,
            })
        rows.append(row)
    return pd.DataFrame(rows)


def _flag_to_int(v: Optional[str]) -> int:
    """Coerce tshark boolean-like outputs ('True'/'False', '1'/'0', '') to 0/1."""
    if v is None:
        return 0
    s = str(v).strip().lower()
    if s in ("1", "true", "t", "yes", "y"):
        return 1
    if s in ("0", "false", "f", "no", "n", ""):
        return 0
    # Some tshark versions might emit numeric or hex-y values in odd cases;
    # treat any nonzero integer as True.
    try:
        return 1 if int(s) != 0 else 0
    except Exception:
        return 0
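
# Examples: _flag_to_int("True") == 1, _flag_to_int("0") == 0, _flag_to_int("") == 0,
# and unparseable values fall back to 0.
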
def detect_live(
    cfg: Dict[str, Any],
    model: Any,
    scaler: Any,
    red_thr: float,
    yellow_thr: float,
    logger: Any,
    alerts_csv: str,
) -> None:
    """Run live detection using tshark streaming."""
    # Interface selection: ENV takes precedence (Windows/Docker convenience)
    interface = os.getenv("IFACE") or os.getenv("IDS_IFACE") or cfg.get("iface", "eth0")
    bpf_filter = cfg.get("bpf_filter", "tcp or udp")
    window = int(cfg.get("window_seconds", 10))
    feature_set = cfg.get("feature_set", "extended")

    tshark_path = shutil.which("tshark")
    if not tshark_path:
        raise RuntimeError("tshark not found in PATH; cannot do live capture")

    # Ensure alerts.jsonl exists (Promtail tail target)
    json_path = os.path.join(os.path.dirname(alerts_csv), "alerts.jsonl")
    os.makedirs(os.path.dirname(json_path), exist_ok=True)
    open(json_path, "a").close()

    cmd = [
        tshark_path, "-n",
        "-i", interface,
        "-f", bpf_filter,
        "-l",
        "-T", "fields",
        "-E", "separator=,",
        "-E", "header=n",
        "-E", "quote=n",
        "-e", "frame.time_epoch",
        "-e", "ip.src", "-e", "ip.dst",
        "-e", "tcp.srcport", "-e", "tcp.dstport",
        "-e", "udp.srcport", "-e", "udp.dstport",
        "-e", "frame.len",
        "-e", "tcp.flags.syn", "-e", "tcp.flags.fin", "-e", "tcp.flags.reset",
    ]
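
    # The -e field order above defines the CSV layout parsed in the loop below:
    #   parts[0]=frame.time_epoch, parts[1]=ip.src, parts[2]=ip.dst,
    #   parts[3]=tcp.srcport, parts[4]=tcp.dstport, parts[5]=udp.srcport,
    #   parts[6]=udp.dstport, parts[7]=frame.len, parts[8]=tcp.flags.syn,
    #   parts[9]=tcp.flags.fin, parts[10]=tcp.flags.reset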

    logger.info(
        f"Starting live capture on '{interface}' window={window}s filter=\"{bpf_filter}\"; "
        f"tshark cmd={' '.join(cmd)}"
    )

    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # show tshark errors in our logs
        text=True,
        bufsize=1,
        universal_newlines=True,
    )

    flows: Dict[Tuple[int, str, str, int, int, str], Dict[str, Any]] = {}
    base_ts: Optional[float] = None
    current_win: Optional[int] = None
    last_pkt_ts: Optional[float] = None
    lock = threading.Lock()
    stop = threading.Event()

    # ingest stats
    lines_total = 0
    lines_parsed = 0
    lines_skipped = 0

    def _flush(older_only: bool = True) -> None:
        """Flush buffered flows; if older_only, flush windows < current."""
        nonlocal flows, current_win
        with lock:
            if not flows:
                return
            if older_only and current_win is not None:
                done = {k: v for k, v in flows.items() if k[0] < current_win}
            else:
                done = dict(flows)
            if not done:
                return
            for k in list(done.keys()):
                flows.pop(k, None)
        df = _flows_to_df(done, feature_set)
        _process_dataframe(df, model, scaler, red_thr, yellow_thr, logger, alerts_csv)

    def _flusher_loop() -> None:
        idle_timeout = max(2, window)  # flush everything if idle ≥ window
        tick = max(1, window // 2)     # run 2x per window
        nonlocal lines_total, lines_parsed, lines_skipped
        last_log = time.time()
        while not stop.is_set():
            time.sleep(tick)
            with lock:
                has_flows = bool(flows)
                lp = last_pkt_ts
            if has_flows:
                _flush(older_only=True)
            now = time.time()
            if lp and (now - lp) >= idle_timeout:
                _flush(older_only=False)
            # periodic ingest stats
            if now - last_log >= max(5, window):
                logger.info(
                    f"Ingest stats: total_lines={lines_total} parsed={lines_parsed} "
                    f"skipped={lines_skipped} active_flows={len(flows)}"
                )
                last_log = now

    flusher = threading.Thread(target=_flusher_loop, daemon=True)
    flusher.start()
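
    # Layout of the capture path (descriptive note): the loop below is the producer,
    # parsing tshark's CSV lines and updating `flows` under `lock`, while the daemon
    # flusher thread drains completed windows roughly twice per window. Packets are
    # bucketed by win_idx = int((ts - base_ts) // window); e.g. with window=10s,
    # a packet arriving 23.4s after base_ts lands in window index 2.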

    try:
        assert proc.stdout is not None
        for raw_line in proc.stdout:
            line = raw_line.strip()
            if not line:
                continue
            # TShark status/errors (keep visible)
            if line.startswith("Capturing on"):
                logger.info(line)
                continue
            if line.startswith("tshark:"):
                logger.error(line)
                continue

            lines_total += 1
            parts = line.split(",")
            if len(parts) < 11:
                lines_skipped += 1
                continue

            try:
                ts = float(parts[0]) if parts[0] else None
                ip_src = parts[1] or None
                ip_dst = parts[2] or None
                tcp_sp = parts[3]
                tcp_dp = parts[4]
                udp_sp = parts[5]
                udp_dp = parts[6]
                frame_len = int(parts[7] or 0)
                syn = _flag_to_int(parts[8])
                fin = _flag_to_int(parts[9])
                rst = _flag_to_int(parts[10])
            except Exception:
                lines_skipped += 1
                continue

            if ts is None or ip_src is None or ip_dst is None:
                lines_skipped += 1
                continue

            # Determine protocol & ports
            proto = None
            sp = dp = 0
            if tcp_sp or tcp_dp:
                proto = "tcp"
                try:
                    sp = int(tcp_sp or 0)
                    dp = int(tcp_dp or 0)
                except Exception:
                    lines_skipped += 1
                    continue
            elif udp_sp or udp_dp:
                proto = "udp"
                try:
                    sp = int(udp_sp or 0)
                    dp = int(udp_dp or 0)
                except Exception:
                    lines_skipped += 1
                    continue
            else:
                lines_skipped += 1
                continue  # non-TCP/UDP (shouldn't happen due to BPF)

            with lock:
                if base_ts is None:
                    base_ts = ts
                win_idx = int((ts - base_ts) // window)
                current_win = win_idx
                last_pkt_ts = ts

                key = (win_idx, ip_src, ip_dst, sp, dp, proto)
                st = flows.get(key)
                if st is None:
                    st = flows[key] = {
                        "src_ip": ip_src, "dst_ip": ip_dst,
                        "src_port": sp, "dst_port": dp,
                        "protocol": proto,
                        "packets": 0, "bytes": 0, "sizes": [],
                        "tcp_syn": 0, "tcp_fin": 0, "tcp_rst": 0,
                        "iat": [], "first_ts": ts, "last_ts": ts, "_last_ts": None,
                    }
                st["packets"] += 1
                st["bytes"] += frame_len
                st["sizes"].append(frame_len)
                if proto == "tcp":
                    st["tcp_syn"] += 1 if syn else 0
                    st["tcp_fin"] += 1 if fin else 0
                    st["tcp_rst"] += 1 if rst else 0
                if st["_last_ts"] is not None:
                    st["iat"].append(max(0.0, ts - float(st["_last_ts"])))
                st["_last_ts"] = ts
                st["last_ts"] = ts
                lines_parsed += 1
    except KeyboardInterrupt:
        logger.info("Live detection interrupted by user")
    finally:
        try:
            stop.set()
            flusher.join(timeout=1.0)
        except Exception:
            pass
        try:
            proc.terminate()
        except Exception:
            pass
        try:
            proc.wait(timeout=1.0)
        except Exception:
            pass
        _flush(older_only=False)  # final flush


# -------------------------
# CLI entry point
# -------------------------
def main() -> None:
    ap = argparse.ArgumentParser(
        description="Detect anomalies using a trained Isolation Forest model (live only)"
    )
    ap.add_argument("--config", default="config/config.yml", help="Path to configuration YAML file")
    ap.add_argument("--model", help="Explicit model filename to load (overrides latest)")
    ap.add_argument("--alerts-csv", default=None,
                    help="Path to alerts CSV; defaults to <logs_dir>/alerts.csv from config")
    args = ap.parse_args()

    cfg = load_config(args.config)
    logger = get_logger("detect", cfg["logs_dir"], "detect.log")
    model, scaler, _ = load_model(cfg["model_dir"], explicit_file=args.model)
    red_thr, yellow_thr = load_thresholds(cfg["model_dir"])

    # Optional env overrides for demo sensitivity
    if os.getenv("IDS_RED_THRESHOLD"):
        try:
            red_thr = float(os.getenv("IDS_RED_THRESHOLD", "").strip())
            logger.info(f"Overriding red_threshold via env: {red_thr}")
        except Exception:
            pass
    if os.getenv("IDS_YELLOW_THRESHOLD"):
        try:
            yellow_thr = float(os.getenv("IDS_YELLOW_THRESHOLD", "").strip())
            logger.info(f"Overriding yellow_threshold via env: {yellow_thr}")
        except Exception:
            pass
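
    # Example (illustrative values only): tightening both thresholds for a demo run,
    # assuming the shell or container exports them before launch:
    #   IDS_RED_THRESHOLD=-0.15 IDS_YELLOW_THRESHOLD=-0.05 python -m ids_iforest.detect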

    alerts_csv = args.alerts_csv or os.path.join(cfg["logs_dir"], "alerts.csv")
    os.makedirs(os.path.dirname(alerts_csv), exist_ok=True)
    open(os.path.join(os.path.dirname(alerts_csv), "alerts.jsonl"), "a").close()

    detect_live(cfg, model, scaler, red_thr, yellow_thr, logger, alerts_csv)


if __name__ == "__main__":
    main()