Coverage for ids_iforest/detect.py: 61% (270 statements)


1"""Live anomaly detection using tshark (Docker/Windows friendly). 

2 

3- Capt all_numeric_cols = numeric_cols + extended 

4 X = scaler.transform(df[all_numeric_cols].fillna(0.0).astype(float).values) 

5 scores = model.decision_function(X) 

6 now_str = _dt.datetime.utcnow().isoformat()s packets via tshark in "fields" mode (no pyshark). 

7- Aggregates packets into time-windowed flows. 

8- Scores flows using a trained Isolation Forest + scaler. 

9- Writes anomalies to alerts.jsonl (Grafana/Promtail) and alerts.csv. 

10""" 

from __future__ import annotations

import argparse
import csv
import datetime as _dt
import os
import shutil
import statistics as stats
import subprocess
import threading
import time
from typing import Optional, Tuple, Dict, Any, Iterable, List

import pandas as pd  # type: ignore

from .utils import (
    load_config,
    get_logger,
    load_model,
    load_thresholds,
)

__all__ = ["main"]

# ----------------------------
# Scoring & alert persistence
# ----------------------------
def _score_flows(
    model: Any,
    scaler: Any,
    df: pd.DataFrame,
    red_thr: float,
    yellow_thr: float,
) -> Iterable[Tuple[str, Dict[str, Any]]]:
    """Score each flow and yield (level, alert_dict) for anomalies."""
    if df.empty:
        return []  # type: ignore

    numeric_cols = [
        "bidirectional_packets",
        "bidirectional_bytes",
        "mean_packet_size",
        "std_packet_size",
        "flow_duration",
    ]
    extended = [
        "tcp_syn_count",
        "tcp_fin_count",
        "tcp_rst_count",
        "iat_mean",
        "iat_std",
        "bytes_per_packet",
        "packets_per_second",
    ]
    for col in extended:
        if col in df.columns:
            numeric_cols.append(col)

    X = scaler.transform(df[numeric_cols].fillna(0.0).astype(float).values)
    scores = model.decision_function(X)

    # Use timezone-aware datetime objects for UTC time (fixing deprecation warning)
    try:
        # Python 3.11+ has UTC constant
        now_str = _dt.datetime.now(_dt.UTC).isoformat()
    except AttributeError:
        # Fallback for older Python versions
        now_str = _dt.datetime.now(_dt.timezone.utc).isoformat()

    for idx, s in enumerate(scores):
        score_f = float(s)
        if score_f < yellow_thr:
            level = "RED" if score_f < red_thr else "YELLOW"
            row = df.iloc[idx]
            yield level, {
                "timestamp": now_str,
                "src_ip": row["src_ip"],
                "dst_ip": row["dst_ip"],
                "src_port": int(row["src_port"]),
                "dst_port": int(row["dst_port"]),
                "protocol": row["protocol"],
                "score": score_f,
                "level": level,
            }
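# Threshold sketch (comments only; the numeric values are hypothetical, the
# real ones come from load_thresholds or env overrides): decision_function
# scores are lower for more anomalous flows, so with red_thr=-0.15 and
# yellow_thr=-0.05:
#   score=-0.20 -> RED      (score < red_thr)
#   score=-0.10 -> YELLOW   (red_thr <= score < yellow_thr)
#   score=+0.02 -> no alert (score >= yellow_thr)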

def _write_alert_csv(
    alerts: Iterable[Tuple[str, Dict[str, Any]]], csv_path: str
) -> None:
    """Append alert rows to the CSV file (header on first write)."""
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)
    exists = os.path.exists(csv_path)
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "timestamp",
                "src_ip",
                "dst_ip",
                "src_port",
                "dst_port",
                "protocol",
                "score",
                "level",
            ],
        )
        if not exists:
            writer.writeheader()
        for _, alert in alerts:
            writer.writerow(alert)
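# Illustrative alerts.csv content (hypothetical values, shown only to document
# the column layout written above):
#   timestamp,src_ip,dst_ip,src_port,dst_port,protocol,score,level
#   2025-09-03T16:19:00+00:00,10.0.0.5,10.0.0.9,51515,443,tcp,-0.1873,RED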

def _process_dataframe(
    df: pd.DataFrame,
    model: Any,
    scaler: Any,
    red_thr: float,
    yellow_thr: float,
    logger: Any,
    csv_path: str,
) -> None:
    """Score flows in `df` and log any anomalies."""
    # Score all flows to compute distribution
    if not df.empty:
        numeric_cols = [
            "bidirectional_packets",
            "bidirectional_bytes",
            "mean_packet_size",
            "std_packet_size",
            "flow_duration",
            "tcp_syn_count",
            "tcp_fin_count",
            "tcp_rst_count",
            "iat_mean",
            "iat_std",
            "bytes_per_packet",
            "packets_per_second",
        ]
        X = scaler.transform(df[numeric_cols].fillna(0.0).astype(float).values)
        scores = model.decision_function(X)
        min_s, max_s = float(min(scores)), float(max(scores))
        mean_s = float(sum(scores) / len(scores))
        # Optionally compute percentiles with numpy or statistics.quantiles
        logger.info(
            f"Score stats: count={len(scores)} min={min_s:.4f} "
            f"mean={mean_s:.4f} max={max_s:.4f}"
        )

    alerts = list(_score_flows(model, scaler, df, red_thr, yellow_thr))
    logger.info(f"Scored {len(df)} flows, produced {len(alerts)} alerts")

    json_path = os.path.join(os.path.dirname(csv_path), "alerts.jsonl")
    if alerts:
        from .logging_utils import append_json_alert  # local import to avoid cycles
        for level, alert in alerts:
            append_json_alert(json_path, **alert)
            if level == "RED":
                logger.error(
                    f"ANOMALY RED {alert['src_ip']}:{alert['src_port']} -> "
                    f"{alert['dst_ip']}:{alert['dst_port']} {alert['protocol'].upper()} "
                    f"score={alert['score']:.4f}"
                )
            else:
                logger.warning(
                    f"Anomaly YELLOW {alert['src_ip']}:{alert['src_port']} -> "
                    f"{alert['dst_ip']}:{alert['dst_port']} {alert['protocol'].upper()} "
                    f"score={alert['score']:.4f}"
                )
        _write_alert_csv(alerts, csv_path)
        logger.info(f"Wrote {len(alerts)} alert(s) to {csv_path} and {json_path}")
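# Illustrative alerts.jsonl line, assuming append_json_alert (imported locally
# in _process_dataframe from .logging_utils) serialises its keyword arguments
# as one JSON object per line; all values below are hypothetical:
#   {"timestamp": "2025-09-03T16:19:00+00:00", "src_ip": "10.0.0.5",
#    "dst_ip": "10.0.0.9", "src_port": 51515, "dst_port": 443,
#    "protocol": "tcp", "score": -0.1873, "level": "RED"}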

# ------------------------------------------
# Live detection via tshark (robust choice)
# ------------------------------------------
def _flows_to_df(
    flows: Dict[Tuple[int, str, str, int, int, str], Dict[str, Any]],
    feature_set: str,
) -> pd.DataFrame:
    """Build a pandas DF with the columns the model expects."""
    rows: List[Dict[str, Any]] = []
    for st in flows.values():
        packets = int(st["packets"])
        bytes_ = int(st["bytes"])
        sizes = st["sizes"]
        dur = max(1e-6, float(st["last_ts"] - st["first_ts"]))
        mean_sz = float(sum(sizes) / packets) if packets > 0 else 0.0
        std_sz = float((stats.pstdev(sizes) if len(sizes) > 1 else 0.0)) if packets > 0 else 0.0
        iat = st["iat"]
        iat_mean = float(sum(iat) / len(iat)) if iat else 0.0
        iat_std = float(stats.pstdev(iat) if len(iat) > 1 else 0.0) if iat else 0.0
        bpp = float(bytes_ / packets) if packets > 0 else 0.0
        pps = float(packets / dur) if dur > 0 else float(packets)

        row = {
            "src_ip": st["src_ip"],
            "dst_ip": st["dst_ip"],
            "src_port": st["src_port"],
            "dst_port": st["dst_port"],
            "protocol": st["protocol"],
            "bidirectional_packets": packets,
            "bidirectional_bytes": bytes_,
            "mean_packet_size": mean_sz,
            "std_packet_size": std_sz,
            "flow_duration": dur,
        }
        if feature_set == "extended":
            row.update({
                "tcp_syn_count": int(st["tcp_syn"]),
                "tcp_fin_count": int(st["tcp_fin"]),
                "tcp_rst_count": int(st["tcp_rst"]),
                "iat_mean": iat_mean,
                "iat_std": iat_std,
                "bytes_per_packet": bpp,
                "packets_per_second": pps,
            })
        rows.append(row)
    return pd.DataFrame(rows)
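# Worked example (comments only): a flow with three packets of 60, 1500 and 40
# bytes seen at t = 0.00, 0.05 and 0.30 s yields
#   packets=3, bytes=1600, mean_packet_size≈533.3, flow_duration=0.30,
#   iat=[0.05, 0.25], iat_mean=0.15, bytes_per_packet≈533.3,
#   packets_per_second=10.0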

def _flag_to_int(v: str) -> int:
    """Coerce tshark boolean-like outputs ('True'/'False', '1'/'0', '') to 0/1."""
    if v is None:
        return 0
    s = str(v).strip().lower()
    if s in ("1", "true", "t", "yes", "y"):
        return 1
    if s in ("0", "false", "f", "no", "n", ""):
        return 0
    # Some tshark versions might emit numeric or hex-y values in odd cases;
    # treat any nonzero integer as True.
    try:
        return 1 if int(s) != 0 else 0
    except Exception:
        return 0
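# Examples (illustrative):
#   _flag_to_int("True") == 1, _flag_to_int("1") == 1, _flag_to_int("2") == 1
#   _flag_to_int("False") == 0, _flag_to_int("") == 0, _flag_to_int(None) == 0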

def detect_live(
    cfg: Dict[str, Any],
    model: Any,
    scaler: Any,
    red_thr: float,
    yellow_thr: float,
    logger: Any,
    alerts_csv: str,
) -> None:
    """Run live detection using tshark streaming."""
    # Interface selection: ENV takes precedence (Windows/Docker convenience)
    interface = os.getenv("IFACE") or os.getenv("IDS_IFACE") or cfg.get("iface", "eth0")
    bpf_filter = cfg.get("bpf_filter", "tcp or udp")
    window = int(cfg.get("window_seconds", 10))
    feature_set = cfg.get("feature_set", "extended")

    tshark_path = shutil.which("tshark")
    if not tshark_path:
        raise RuntimeError("tshark not found in PATH; cannot do live capture")

    # Ensure alerts.jsonl exists (Promtail tail target)
    json_path = os.path.join(os.path.dirname(alerts_csv), "alerts.jsonl")
    os.makedirs(os.path.dirname(json_path), exist_ok=True)
    open(json_path, "a").close()

    cmd = [
        tshark_path, "-n",
        "-i", interface,
        "-f", bpf_filter,
        "-l",
        "-T", "fields",
        "-E", "separator=,",
        "-E", "header=n",
        "-E", "quote=n",
        "-e", "frame.time_epoch",
        "-e", "ip.src", "-e", "ip.dst",
        "-e", "tcp.srcport", "-e", "tcp.dstport",
        "-e", "udp.srcport", "-e", "udp.dstport",
        "-e", "frame.len",
        "-e", "tcp.flags.syn", "-e", "tcp.flags.fin", "-e", "tcp.flags.reset",
    ]
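    # Illustrative output line from the command above for one TCP packet
    # (11 comma-separated values in the -e order; addresses, ports and flag
    # spellings are examples, not captured data):
    #   1756915200.123456,10.0.0.5,10.0.0.9,51515,443,,,60,True,False,False
    # i.e. frame.time_epoch, ip.src, ip.dst, tcp.srcport, tcp.dstport,
    # udp.srcport, udp.dstport, frame.len, tcp.flags.{syn,fin,reset}.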

    logger.info(
        f"Starting live capture on '{interface}' window={window}s filter=\"{bpf_filter}\"; "
        f"tshark cmd={' '.join(cmd)}"
    )

    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # show tshark errors in our logs
        text=True,
        bufsize=1,
        universal_newlines=True,
    )

    flows: Dict[Tuple[int, str, str, int, int, str], Dict[str, Any]] = {}
    base_ts: Optional[float] = None
    current_win: Optional[int] = None
    last_pkt_ts: Optional[float] = None
    lock = threading.Lock()
    stop = threading.Event()

    # ingest stats
    lines_total = 0
    lines_parsed = 0
    lines_skipped = 0

    def _flush(older_only: bool = True) -> None:
        """Flush buffered flows; if older_only, flush windows < current."""
        nonlocal flows, current_win
        with lock:
            if not flows:
                return
            if older_only and current_win is not None:
                done = {k: v for k, v in flows.items() if k[0] < current_win}
            else:
                done = dict(flows)
            if not done:
                return
            for k in list(done.keys()):
                flows.pop(k, None)
        df = _flows_to_df(done, feature_set)
        _process_dataframe(df, model, scaler, red_thr, yellow_thr, logger, alerts_csv)

    def _flusher_loop() -> None:
        idle_timeout = max(2, window)  # flush everything if idle ≥ window
        tick = max(1, window // 2)  # run 2x per window
        nonlocal lines_total, lines_parsed, lines_skipped
        last_log = time.time()
        while not stop.is_set():
            time.sleep(tick)
            with lock:
                has_flows = bool(flows)
                lp = last_pkt_ts
            if has_flows:
                _flush(older_only=True)
            now = time.time()
            if lp and (now - lp) >= idle_timeout:
                _flush(older_only=False)
            # periodic ingest stats
            if now - last_log >= max(5, window):
                logger.info(
                    f"Ingest stats: total_lines={lines_total} parsed={lines_parsed} "
                    f"skipped={lines_skipped} active_flows={len(flows)}"
                )
                last_log = now

    flusher = threading.Thread(target=_flusher_loop, daemon=True)
    flusher.start()

    try:
        assert proc.stdout is not None
        for raw_line in proc.stdout:
            line = raw_line.strip()
            if not line:
                continue
            # TShark status/errors (keep visible)
            if line.startswith("Capturing on"):
                logger.info(line)
                continue
            if line.startswith("tshark:"):
                logger.error(line)
                continue

            lines_total += 1
            parts = line.split(",")
            if len(parts) < 11:
                lines_skipped += 1
                continue

            try:
                ts = float(parts[0]) if parts[0] else None
                ip_src = parts[1] or None
                ip_dst = parts[2] or None
                tcp_sp = parts[3]
                tcp_dp = parts[4]
                udp_sp = parts[5]
                udp_dp = parts[6]
                frame_len = int(parts[7] or 0)
                syn = _flag_to_int(parts[8])
                fin = _flag_to_int(parts[9])
                rst = _flag_to_int(parts[10])
            except Exception:
                lines_skipped += 1
                continue

            if ts is None or ip_src is None or ip_dst is None:
                lines_skipped += 1
                continue

            # Determine protocol & ports
            proto = None
            sp = dp = 0
            if tcp_sp or tcp_dp:
                proto = "tcp"
                try:
                    sp = int(tcp_sp or 0)
                    dp = int(tcp_dp or 0)
                except Exception:
                    lines_skipped += 1
                    continue
            elif udp_sp or udp_dp:
                proto = "udp"
                try:
                    sp = int(udp_sp or 0)
                    dp = int(udp_dp or 0)
                except Exception:
                    lines_skipped += 1
                    continue
            else:
                lines_skipped += 1
                continue  # non-TCP/UDP (shouldn't happen due to BPF)

            with lock:
                if base_ts is None:
                    base_ts = ts
                win_idx = int((ts - base_ts) // window)
                current_win = win_idx
                last_pkt_ts = ts

                key = (win_idx, ip_src, ip_dst, sp, dp, proto)
                st = flows.get(key)
                if st is None:
                    st = flows[key] = {
                        "src_ip": ip_src, "dst_ip": ip_dst,
                        "src_port": sp, "dst_port": dp,
                        "protocol": proto,
                        "packets": 0, "bytes": 0, "sizes": [],
                        "tcp_syn": 0, "tcp_fin": 0, "tcp_rst": 0,
                        "iat": [], "first_ts": ts, "last_ts": ts, "_last_ts": None,
                    }
                st["packets"] += 1
                st["bytes"] += frame_len
                st["sizes"].append(frame_len)
                if proto == "tcp":
                    st["tcp_syn"] += 1 if syn else 0
                    st["tcp_fin"] += 1 if fin else 0
                    st["tcp_rst"] += 1 if rst else 0
                if st["_last_ts"] is not None:
                    st["iat"].append(max(0.0, ts - float(st["_last_ts"])))
                st["_last_ts"] = ts
                st["last_ts"] = ts
                lines_parsed += 1
    except KeyboardInterrupt:
        logger.info("Live detection interrupted by user")
    finally:
        try:
            stop.set()
            flusher.join(timeout=1.0)
        except Exception:
            pass
        try:
            proc.terminate()
        except Exception:
            pass
        try:
            proc.wait(timeout=1.0)
        except Exception:
            pass
        _flush(older_only=False)  # final flush
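# Windowing sketch (illustrative): with window=10 and base_ts=t0, a packet at
# t0+27.3s falls in win_idx=2; once a packet arrives in win_idx=3, the flusher
# thread scores and clears every flow whose window index is < current_win.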

# -------------------------
# CLI entry point
# -------------------------
def main() -> None:
    ap = argparse.ArgumentParser(
        description="Detect anomalies using a trained Isolation Forest model (live only)"
    )
    ap.add_argument("--config", default="config/config.yml", help="Path to configuration YAML file")
    ap.add_argument("--model", help="Explicit model filename to load (overrides latest)")
    ap.add_argument("--alerts-csv", default=None,
                    help="Path to alerts CSV; defaults to <logs_dir>/alerts.csv from config")
    args = ap.parse_args()

    cfg = load_config(args.config)
    logger = get_logger("detect", cfg["logs_dir"], "detect.log")
    model, scaler, _ = load_model(cfg["model_dir"], explicit_file=args.model)
    red_thr, yellow_thr = load_thresholds(cfg["model_dir"])

    # Optional env overrides for demo sensitivity
    if os.getenv("IDS_RED_THRESHOLD"):
        try:
            red_thr = float(os.getenv("IDS_RED_THRESHOLD", "").strip())
            logger.info(f"Overriding red_threshold via env: {red_thr}")
        except Exception:
            pass
    if os.getenv("IDS_YELLOW_THRESHOLD"):
        try:
            yellow_thr = float(os.getenv("IDS_YELLOW_THRESHOLD", "").strip())
            logger.info(f"Overriding yellow_threshold via env: {yellow_thr}")
        except Exception:
            pass
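    # Example (illustrative) demo-time override via environment, e.g. in a
    # docker-compose service or the shell; the values and module path are
    # hypothetical:
    #   IFACE=eth0 IDS_YELLOW_THRESHOLD=-0.02 IDS_RED_THRESHOLD=-0.10 \
    #       python -m ids_iforest.detect --config config/config.yml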

    alerts_csv = args.alerts_csv or os.path.join(cfg["logs_dir"], "alerts.csv")
    os.makedirs(os.path.dirname(alerts_csv), exist_ok=True)
    open(os.path.join(os.path.dirname(alerts_csv), "alerts.jsonl"), "a").close()

    detect_live(cfg, model, scaler, red_thr, yellow_thr, logger, alerts_csv)


if __name__ == "__main__":
    main()