Coverage for ids_iforest/utils.py: 19%
167 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-03 16:19 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-03 16:19 +0000
1"""Utility functions for the ids_iforest package (live-only workflow)."""
3from __future__ import annotations
5import os
6import json
7import yaml
8import glob
9import ipaddress
10import logging
11import subprocess
12import sys
13from dataclasses import dataclass
14from typing import Dict, Any, Tuple, Optional, Iterable, List
16import numpy as np
17import pandas as pd
18import joblib
try:
    from colorama import Fore, Style, init as colorama_init  # type: ignore
except Exception:
    # colorama is optional. When it cannot be imported, install no-op
    # stand-ins so the rest of the module can still write
    # ``Fore.RED + text + Style.RESET_ALL`` and get plain text back.

    class _Dummy:
        """Placeholder for ``Fore``/``Style``: every attribute is ``""``."""

        def __getattr__(self, name: str) -> str:
            return ""

    def colorama_init(autoreset: bool = True) -> None:
        """Do nothing; accepts the ``autoreset`` keyword this module passes."""
        return None

    Fore = Style = _Dummy()  # type: ignore[assignment]
# Names exported by ``from ids_iforest.utils import *``.
__all__ = [
    # configuration / filesystem
    "load_config",
    "ensure_dirs",
    # logging / provenance
    "get_logger",
    "get_git_hash",
    # model persistence
    "save_model",
    "load_model",
    "load_thresholds",
    # flow helpers and console output
    "canonical_5tuple",
    "LEVEL_COLOR",
]
# Substrings that container runtimes commonly leave in PID 1's cgroup paths.
_CONTAINER_CGROUP_MARKERS = ("docker", "kubepods", "containerd", "libpod", "lxc")


def _running_in_container() -> bool:
    """Return True when the process appears to run inside a container.

    Heuristic only. The mere existence of ``/proc/1/cgroup`` is not a signal,
    because that file is present on every Linux host; its *content* is
    inspected instead, alongside the usual runtime marker files.
    """
    if os.path.exists("/.dockerenv") or os.path.exists("/run/.containerenv"):
        return True
    if os.getenv("KUBERNETES_SERVICE_HOST"):
        return True
    try:
        with open("/proc/1/cgroup", "r", encoding="utf-8", errors="replace") as fh:
            cgroup_text = fh.read()
    except OSError:
        # Not Linux, or /proc is not readable: assume a plain host.
        return False
    return any(marker in cgroup_text for marker in _CONTAINER_CGROUP_MARKERS)


def _dir_is_writable(target: str) -> bool:
    """Create *target* if needed and probe it by writing and removing a file."""
    try:
        os.makedirs(target, exist_ok=True)
        probe = os.path.join(target, ".writetest")
        with open(probe, "w", encoding="utf-8") as tf:
            tf.write("ok")
        os.remove(probe)
    except Exception:
        return False
    return True


def _has_readable_model(model_dir: str) -> bool:
    """Return True when *model_dir* is a readable directory holding a model file."""
    try:
        if not os.path.isdir(model_dir):
            return False
        latest = os.path.join(model_dir, "ids_iforest_latest.joblib")
        # Escape the directory part so characters such as "[" in the path are
        # matched literally instead of being read as glob syntax.
        pattern = os.path.join(glob.escape(model_dir), "ids_iforest_*.joblib")
        found = os.path.exists(latest) or bool(glob.glob(pattern))
        return found and os.access(model_dir, os.R_OK)
    except Exception:
        return False


def load_config(path: str) -> Dict[str, Any]:
    """Load the YAML config at *path*, apply defaults and resolve directories.

    Steps, in order:

    1. Parse the YAML file (an empty file yields an empty mapping).
    2. Fill in defaults for ``window_seconds``, ``bpf_filter``,
       ``feature_set``, ``model_dir``, ``logs_dir`` and ``iface``.
    3. Apply environment overrides, which win over the file:
       ``IDS_MODEL_DIR``, ``IDS_LOGS_DIR`` and ``IFACE`` (or ``IDS_IFACE``
       when ``IFACE`` is unset or empty).
    4. Make relative ``model_dir``/``logs_dir`` absolute, relative to the
       directory containing the config file.
    5. If a directory is not writable, fall back to the first writable one of
       ``/app/<name>`` (containers only), ``<cwd>/<name>`` and
       ``~/.ids_iforest/<name>``. A ``model_dir`` that already holds a
       readable model is accepted even when read-only. Each substitution is
       recorded under ``cfg["_path_fallbacks"]`` and announced on stdout.

    Note that probing for writability creates the probed directories.

    Args:
        path: Location of the YAML configuration file.

    Returns:
        The configuration mapping with defaults and resolved paths.

    Raises:
        OSError: If the config file cannot be opened.
        ValueError: If the YAML document's top level is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as f:
        cfg: Dict[str, Any] = yaml.safe_load(f) or {}
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config file {path!r} must contain a YAML mapping at top level, "
            f"got {type(cfg).__name__}"
        )

    # Defaults
    cfg.setdefault("window_seconds", 5)
    cfg.setdefault("bpf_filter", "tcp or udp")
    cfg.setdefault("feature_set", "extended")
    cfg.setdefault("model_dir", "./models")
    cfg.setdefault("logs_dir", "./logs")
    cfg.setdefault("iface", "eth0")

    base = os.path.dirname(os.path.abspath(path))

    # ENV overrides (highest precedence)
    env_model = os.getenv("IDS_MODEL_DIR")
    env_logs = os.getenv("IDS_LOGS_DIR")
    env_iface = os.getenv("IFACE") or os.getenv("IDS_IFACE")
    if env_model:
        cfg["model_dir"] = env_model
    if env_logs:
        cfg["logs_dir"] = env_logs
    if env_iface:
        cfg["iface"] = env_iface

    # Resolve relative paths against the config file's directory
    for key in ("model_dir", "logs_dir"):
        val = cfg.get(key)
        if isinstance(val, str) and not os.path.isabs(val):
            cfg[key] = os.path.abspath(os.path.join(base, val))

    in_container = _running_in_container()

    # Make sure the directories are usable; pick a fallback when they are not
    for key, default_rel in (("model_dir", "models"), ("logs_dir", "logs")):
        cur = cfg.get(key)
        if not isinstance(cur, str):
            continue
        # model_dir may be read-only but acceptable if it already contains a model
        if key == "model_dir" and _has_readable_model(cur):
            continue
        if _dir_is_writable(cur):
            continue

        candidates: List[str] = []
        if in_container:
            candidates.append(f"/app/{default_rel}")
        candidates.append(os.path.abspath(os.path.join(os.getcwd(), default_rel)))
        home = os.path.expanduser("~")
        candidates.append(os.path.join(home, ".ids_iforest", default_rel))

        for cand in candidates:
            if _dir_is_writable(cand):
                fb = cfg.setdefault("_path_fallbacks", {})  # type: ignore[assignment]
                fb[key] = {"original": cur, "chosen": cand}
                print(f"[ids-iforest] WARNING: {key} '{cur}' not writable; using '{cand}'")
                cfg[key] = cand
                break

    return cfg
def ensure_dirs(*paths: str) -> None:
    """Create each directory in *paths*, parents included.

    Directories that already exist are left untouched.
    """
    for directory in paths:
        os.makedirs(directory, exist_ok=True)
def get_logger(name: str, logs_dir: str, base_filename: str) -> logging.Logger:
    """Return the logger *name*, writing to a file and to stdout.

    The log file is ``<logs_dir>/<base_filename>``; *logs_dir* is created when
    missing. The level is set to INFO on every call, but handlers are only
    attached the first time, while the logger has none.
    """
    ensure_dirs(logs_dir)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    if logger.handlers:
        # Already configured: attaching again would duplicate every record.
        return logger

    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
    sinks = (
        logging.FileHandler(os.path.join(logs_dir, base_filename), encoding="utf-8"),
        logging.StreamHandler(sys.stdout),
    )
    for sink in sinks:
        sink.setFormatter(formatter)
        logger.addHandler(sink)
    return logger
def get_git_hash(short: bool = True) -> str:
    """Return the current commit hash, or ``"unknown"`` if it cannot be found.

    The ``CI_COMMIT_SHA`` environment variable is preferred; otherwise
    ``git rev-parse HEAD`` is run. With ``short=True`` the hash is cut to its
    first 8 characters.
    """
    sha = os.getenv("CI_COMMIT_SHA")
    if not sha:
        try:
            raw = subprocess.check_output(
                ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL
            )
            sha = raw.decode().strip()
        except Exception:
            # git missing, not a repository, undecodable output, ...
            return "unknown"
    if short:
        return sha[:8]
    return sha
def save_model(model: Any, scaler: Any, model_dir: str) -> Tuple[str, str]:
    """Persist *model* and *scaler* and refresh the ``latest`` alias.

    The pair is dumped as ``{"model": ..., "scaler": ...}`` to
    ``<model_dir>/ids_iforest_<git hash>.joblib``. A copy is then published
    as ``<model_dir>/ids_iforest_latest.joblib``.

    The alias is written to a temporary sibling file and moved into place
    with ``os.replace``, so a concurrent reader never sees a missing or
    half-written ``latest`` file. Updating the alias stays best-effort: on
    failure a warning is logged and any stale alias is removed, so it cannot
    shadow the newly saved model.

    Args:
        model: Fitted model object to persist.
        scaler: Scaler object stored next to the model.
        model_dir: Target directory; created when missing.

    Returns:
        ``(model_path, latest_path)``. ``latest_path`` is returned even when
        the alias could not be written.
    """
    import shutil

    ensure_dirs(model_dir)
    git_hash = get_git_hash()
    model_path = os.path.join(model_dir, f"ids_iforest_{git_hash}.joblib")
    joblib.dump({"model": model, "scaler": scaler}, model_path)

    latest_path = os.path.join(model_dir, "ids_iforest_latest.joblib")
    # The suffix keeps the temp file outside the "ids_iforest_*.joblib" glob
    # used when looking for models.
    tmp_path = f"{latest_path}.tmp.{os.getpid()}"
    try:
        shutil.copyfile(model_path, tmp_path)
        os.replace(tmp_path, latest_path)
    except OSError as exc:
        logging.getLogger(__name__).warning(
            "Could not update %s from %s: %s", latest_path, model_path, exc
        )
        # Drop leftovers, including an outdated alias that would otherwise
        # point at an older model.
        for leftover in (tmp_path, latest_path):
            try:
                os.remove(leftover)
            except OSError:
                pass
    return model_path, latest_path
def load_model(model_dir: str, explicit_file: Optional[str] = None) -> Tuple[Any, Any, str]:
    """Load a persisted model/scaler pair from *model_dir*.

    Lookup order:

    1. *explicit_file*, when given (absolute, or relative to *model_dir*).
    2. ``<model_dir>/ids_iforest_latest.joblib``.
    3. The most recently modified ``ids_iforest_*.joblib`` in *model_dir*.

    Args:
        model_dir: Directory holding the model files.
        explicit_file: Optional specific file to load instead of searching.

    Returns:
        ``(model, scaler, path)`` where *path* is the file that was loaded.

    Raises:
        FileNotFoundError: If no matching model file exists.
    """
    path: Optional[str] = None
    if explicit_file:
        path = explicit_file if os.path.isabs(explicit_file) else os.path.join(model_dir, explicit_file)
    else:
        latest = os.path.join(model_dir, "ids_iforest_latest.joblib")
        if os.path.exists(latest):
            path = latest
        else:
            # Escape the directory part: without it a model_dir containing
            # "[", "]", "*" or "?" is read as glob syntax and existing
            # models are not found.
            pattern = os.path.join(glob.escape(model_dir), "ids_iforest_*.joblib")
            cands = sorted(glob.glob(pattern), key=os.path.getmtime)
            path = cands[-1] if cands else None
    if not path or not os.path.exists(path):
        raise FileNotFoundError(
            f"No model found in {model_dir}. Train first with ids-iforest-train."
        )
    # NOTE(security): joblib files are pickle-based, so loading one can run
    # arbitrary code. Only load model files from trusted locations.
    payload = joblib.load(path)
    return payload["model"], payload["scaler"], path
def load_thresholds(model_dir: str) -> Tuple[float, float]:
    """Read ``(red, yellow)`` thresholds from ``<model_dir>/thresholds.json``.

    A key missing from the file takes its default (``-0.25`` for red, ``0.0``
    for yellow). If the file cannot be read or parsed, or a value cannot be
    converted to ``float``, both defaults are returned.
    """
    fallback = (-0.25, 0.0)
    try:
        with open(os.path.join(model_dir, "thresholds.json"), "r", encoding="utf-8") as fh:
            stored = json.load(fh)
        return (
            float(stored.get("red_threshold", fallback[0])),
            float(stored.get("yellow_threshold", fallback[1])),
        )
    except Exception:
        return fallback
@dataclass(frozen=True)
class Endpoint:
    """One side of a flow: an IP address paired with a port.

    The dataclass is frozen, so instances are immutable and hashable.
    """

    # IP address in textual form.
    ip: str
    # Port number.
    port: int
214def _endpoint_order(ep: Endpoint) -> Tuple[int, bytes, int]:
215 ip_obj = ipaddress.ip_address(ep.ip)
216 return (ip_obj.version, ip_obj.packed, ep.port)
def canonical_5tuple(
    src_ip: str, src_port: int, dst_ip: str, dst_port: int, proto: str
) -> Tuple[Endpoint, Endpoint, str]:
    """Return a direction-independent form of a flow's 5-tuple.

    The two endpoints are ordered by ``_endpoint_order``, so both directions
    of one conversation give the same result. Ports are coerced with
    ``int()`` and the protocol name is lower-cased.

    Returns:
        ``(first_endpoint, second_endpoint, protocol)``.
    """
    source = Endpoint(src_ip, int(src_port))
    destination = Endpoint(dst_ip, int(dst_port))
    # sorted() is stable, so equal endpoints keep the (source, destination) order.
    first, second = sorted((source, destination), key=_endpoint_order)
    return first, second, proto.lower()
# Console labels for each alert level, wrapped in the matching colour codes.
# When colorama is absent the stand-ins yield empty strings, so the labels
# degrade to plain text.
colorama_init(autoreset=True)
LEVEL_COLOR: Dict[str, str] = {
    level: getattr(Fore, level) + level + Style.RESET_ALL
    for level in ("GREEN", "YELLOW", "RED")
}