Coverage for ids_iforest/utils.py: 19%

167 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-03 16:19 +0000

1"""Utility functions for the ids_iforest package (live-only workflow).""" 

2 

3from __future__ import annotations 

4 

5import os 

6import json 

7import yaml 

8import glob 

9import ipaddress 

10import logging 

11import subprocess 

12import sys 

13from dataclasses import dataclass 

14from typing import Dict, Any, Tuple, Optional, Iterable, List 

15 

16import numpy as np 

17import pandas as pd 

18import joblib 

19 

# colorama is an optional dependency. When it cannot be imported, provide
# stand-ins with the same names so the rest of the module can use
# Fore / Style / colorama_init unconditionally.
try:
    from colorama import Fore, Style, init as colorama_init  # type: ignore
except Exception:

    class _Dummy:
        """Placeholder for colorama's Fore/Style: any attribute reads as ""."""

        def __getattr__(self, name: str) -> str:
            # Called only for attributes that do not exist, i.e. all of them.
            return ""

    Fore = Style = _Dummy()  # type: ignore[assignment]

    def colorama_init(autoreset: bool = True) -> None:
        """No-op replacement for colorama's init(); accepts and ignores autoreset."""
        return None

29 

# Names exported by `from ids_iforest.utils import *`.
__all__ = [
    "load_config",
    "ensure_dirs",
    "get_logger",
    "get_git_hash",
    "save_model",
    "load_model",
    "load_thresholds",
    "canonical_5tuple",
    "LEVEL_COLOR",
]

41 

42 

# Substrings that container runtimes put into PID 1's cgroup path.
_CONTAINER_CGROUP_MARKERS = ("docker", "kubepods", "containerd", "lxc", "libpod")


def _is_writable_dir(target: str) -> bool:
    """Return True if *target* exists (or can be created) and accepts new files."""
    import tempfile

    try:
        os.makedirs(target, exist_ok=True)
        # A uniquely named probe file cannot clobber or delete a real file,
        # which a fixed probe name such as ".writetest" could.
        with tempfile.TemporaryFile(dir=target):
            pass
        return True
    except (OSError, ValueError):
        return False


def _running_in_container() -> bool:
    """Best-effort guess whether this process runs inside a container."""
    if os.path.exists("/.dockerenv") or os.path.exists("/run/.containerenv"):
        return True
    if os.getenv("KUBERNETES_SERVICE_HOST"):
        return True
    # /proc/1/cgroup exists on every Linux host, so its mere presence proves
    # nothing; only its content can hint at a container runtime.
    try:
        with open("/proc/1/cgroup", "r", encoding="utf-8", errors="replace") as fh:
            cgroup_text = fh.read()
    except OSError:
        return False
    return any(marker in cgroup_text for marker in _CONTAINER_CGROUP_MARKERS)


def _has_readable_model(model_dir: str) -> bool:
    """Return True if *model_dir* is a readable directory that already holds a model."""
    try:
        if not os.path.isdir(model_dir):
            return False
        latest = os.path.join(model_dir, "ids_iforest_latest.joblib")
        # Escape the directory part so characters like '[' in the path are
        # matched literally instead of being treated as glob syntax.
        pattern = os.path.join(glob.escape(model_dir), "ids_iforest_*.joblib")
        found = os.path.exists(latest) or bool(glob.glob(pattern))
        return found and os.access(model_dir, os.R_OK)
    except (OSError, ValueError):
        return False


def load_config(path: str) -> Dict[str, Any]:
    """Load YAML config and resolve paths; allow ENV overrides for Docker.

    Processing order:

    1. Parse the YAML file at *path* (an empty file yields an empty config).
    2. Fill in defaults for ``window_seconds``, ``bpf_filter``, ``feature_set``,
       ``model_dir``, ``logs_dir`` and ``iface``.
    3. Apply environment overrides (highest precedence): ``IDS_MODEL_DIR``,
       ``IDS_LOGS_DIR`` and ``IFACE`` (falling back to ``IDS_IFACE``).
    4. Make relative ``model_dir`` / ``logs_dir`` absolute, relative to the
       directory containing the config file.
    5. If a directory is not usable, switch to the first writable fallback
       (``/app/<name>`` when in a container, ``<cwd>/<name>``,
       ``~/.ids_iforest/<name>``), print a warning and record the switch under
       ``cfg["_path_fallbacks"][key]``. If no fallback is writable the
       configured value is left unchanged.

    A ``model_dir`` that is not writable is still accepted when it is readable
    and already contains a model file.

    Args:
        path: Path to the YAML configuration file.

    Returns:
        The configuration mapping with defaults, overrides and resolved paths.

    Raises:
        OSError: If *path* cannot be opened.
        ValueError: If the YAML document's top level is not a mapping.
        Errors raised by ``yaml.safe_load`` for malformed YAML propagate as-is.
    """
    with open(path, "r", encoding="utf-8") as f:
        cfg: Dict[str, Any] = yaml.safe_load(f) or {}
    if not isinstance(cfg, dict):
        # Fail with a clear message instead of an AttributeError further down.
        raise ValueError(
            f"Config file '{path}' must contain a YAML mapping at the top level, "
            f"got {type(cfg).__name__}"
        )

    # Defaults
    defaults: Dict[str, Any] = {
        "window_seconds": 5,
        "bpf_filter": "tcp or udp",
        "feature_set": "extended",
        "model_dir": "./models",
        "logs_dir": "./logs",
        "iface": "eth0",
    }
    for key, value in defaults.items():
        cfg.setdefault(key, value)

    base = os.path.dirname(os.path.abspath(path))

    # ENV overrides (highest precedence)
    env_overrides = {
        "model_dir": os.getenv("IDS_MODEL_DIR"),
        "logs_dir": os.getenv("IDS_LOGS_DIR"),
        "iface": os.getenv("IFACE") or os.getenv("IDS_IFACE"),
    }
    for key, value in env_overrides.items():
        if value:
            cfg[key] = value

    # Resolve relative paths against the config file's directory
    for key in ("model_dir", "logs_dir"):
        val = cfg.get(key)
        if isinstance(val, str) and not os.path.isabs(val):
            cfg[key] = os.path.abspath(os.path.join(base, val))

    in_container = _running_in_container()

    # Try to ensure directories are usable; fallbacks if needed
    for key, default_rel in (("model_dir", "models"), ("logs_dir", "logs")):
        cur = cfg.get(key)
        if not isinstance(cur, str):
            continue
        # model_dir may be read-only but acceptable if it already contains a model
        if key == "model_dir" and _has_readable_model(cur):
            continue
        if _is_writable_dir(cur):
            continue

        candidates: List[str] = []
        if in_container:
            candidates.append(f"/app/{default_rel}")
        candidates.append(os.path.abspath(os.path.join(os.getcwd(), default_rel)))
        candidates.append(
            os.path.join(os.path.expanduser("~"), ".ids_iforest", default_rel)
        )
        for cand in candidates:
            if _is_writable_dir(cand):
                fb = cfg.setdefault("_path_fallbacks", {})  # type: ignore[assignment]
                fb[key] = {"original": cur, "chosen": cand}
                print(f"[ids-iforest] WARNING: {key} '{cur}' not writable; using '{cand}'")
                cfg[key] = cand
                break

    return cfg

123 

124 

def ensure_dirs(*paths: str) -> None:
    """Create each directory in *paths*, including parents; existing ones are kept."""
    for directory in paths:
        os.makedirs(directory, exist_ok=True)

128 

129 

def get_logger(name: str, logs_dir: str, base_filename: str) -> logging.Logger:
    """Return an INFO-level logger that writes to a file and to stdout.

    The log directory is created if needed. Handlers are attached only the
    first time a given logger name is requested, so repeated calls do not
    duplicate output.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        logs_dir: Directory that will hold the log file.
        base_filename: File name of the log file inside *logs_dir*.

    Returns:
        The configured logger.
    """
    os.makedirs(logs_dir, exist_ok=True)
    log = logging.getLogger(name)
    log.setLevel(logging.INFO)
    if log.handlers:
        # Already configured by an earlier call.
        return log

    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
    file_handler = logging.FileHandler(
        os.path.join(logs_dir, base_filename), encoding="utf-8"
    )
    # File handler first, then stdout, both sharing one format.
    for handler in (file_handler, logging.StreamHandler(sys.stdout)):
        handler.setFormatter(formatter)
        log.addHandler(handler)
    return log

145 

146 

def get_git_hash(short: bool = True) -> str:
    """Return the current commit hash, or "unknown" if it cannot be determined.

    The ``CI_COMMIT_SHA`` environment variable takes precedence; otherwise
    ``git rev-parse HEAD`` is run in the current working directory.

    Args:
        short: When true, return only the first 8 characters of the hash.

    Returns:
        The (possibly shortened) hash, or ``"unknown"`` when git fails.
    """
    sha = os.getenv("CI_COMMIT_SHA")
    if not sha:
        try:
            completed = subprocess.run(
                ["git", "rev-parse", "HEAD"],
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
                check=True,
            )
            sha = completed.stdout.decode().strip()
        except Exception:
            # git missing, not a repository, or any other failure.
            return "unknown"
    return sha[:8] if short else sha

156 

157 

def save_model(model: Any, scaler: Any, model_dir: str) -> Tuple[str, str]:
    """Persist *model* and *scaler* and refresh the "latest" alias.

    The pair is dumped with joblib as ``{"model": ..., "scaler": ...}`` to
    ``ids_iforest_<git-hash>.joblib`` inside *model_dir*, then copied to
    ``ids_iforest_latest.joblib``.

    Updating the alias is best-effort: a failure is logged as a warning and
    does not raise. The copy goes to a temporary file first and is then moved
    into place with ``os.replace``, so an existing "latest" file is never
    removed unless its replacement is ready.

    Args:
        model: Trained model object to store.
        scaler: Scaler object stored alongside the model.
        model_dir: Directory to write into; created if missing.

    Returns:
        ``(model_path, latest_path)``. ``latest_path`` is returned even if
        refreshing it failed.
    """
    import shutil

    ensure_dirs(model_dir)
    git_hash = get_git_hash()
    model_path = os.path.join(model_dir, f"ids_iforest_{git_hash}.joblib")
    joblib.dump({"model": model, "scaler": scaler}, model_path)

    latest_path = os.path.join(model_dir, "ids_iforest_latest.joblib")
    # The ".tmp" suffix keeps a leftover temp file from matching the
    # "ids_iforest_*.joblib" pattern used to discover models.
    tmp_path = f"{latest_path}.{os.getpid()}.tmp"
    try:
        shutil.copyfile(model_path, tmp_path)
        os.replace(tmp_path, latest_path)
    except OSError as exc:
        logging.getLogger(__name__).warning(
            "Could not update %s from %s: %s", latest_path, model_path, exc
        )
        try:
            os.remove(tmp_path)
        except OSError:
            pass
    return model_path, latest_path

175 

176 

def load_model(model_dir: str, explicit_file: Optional[str] = None) -> Tuple[Any, Any, str]:
    """Load a saved model/scaler pair.

    Selection order:

    * *explicit_file*, if given (relative names are resolved against
      *model_dir*);
    * otherwise ``ids_iforest_latest.joblib`` in *model_dir*;
    * otherwise the most recently modified ``ids_iforest_*.joblib`` there.

    Args:
        model_dir: Directory that holds saved models.
        explicit_file: Optional specific file name or absolute path to load.

    Returns:
        ``(model, scaler, path)`` where *path* is the file that was loaded.

    Raises:
        FileNotFoundError: If the requested file, or any model at all, is missing.
        KeyError: If the loaded payload lacks the ``"model"`` or ``"scaler"`` key.
    """
    path: Optional[str] = None
    if explicit_file:
        path = explicit_file if os.path.isabs(explicit_file) else os.path.join(model_dir, explicit_file)
        if not os.path.exists(path):
            # Name the requested file; a generic "no model in dir" message
            # would hide which path was actually looked up.
            raise FileNotFoundError(
                f"Model file '{path}' not found. Train first with ids-iforest-train."
            )
    else:
        latest = os.path.join(model_dir, "ids_iforest_latest.joblib")
        if os.path.exists(latest):
            path = latest
        else:
            # Escape the directory so glob metacharacters in it match literally.
            pattern = os.path.join(glob.escape(model_dir), "ids_iforest_*.joblib")
            cands = sorted(glob.glob(pattern), key=os.path.getmtime)
            path = cands[-1] if cands else None
        if not path or not os.path.exists(path):
            raise FileNotFoundError(
                f"No model found in {model_dir}. Train first with ids-iforest-train."
            )
    # NOTE(security): joblib.load unpickles the file and can therefore execute
    # arbitrary code; only load model files from a trusted location.
    payload = joblib.load(path)
    return payload["model"], payload["scaler"], path

194 

195 

def load_thresholds(model_dir: str) -> Tuple[float, float]:
    """Read the red/yellow thresholds from ``thresholds.json`` in *model_dir*.

    Missing keys fall back individually to their defaults. If the file is
    missing, unreadable or malformed in any way, both defaults are returned.

    Args:
        model_dir: Directory expected to contain ``thresholds.json``.

    Returns:
        ``(red_threshold, yellow_threshold)``; defaults are ``(-0.25, 0.0)``.
    """
    default_red, default_yellow = -0.25, 0.0
    thresholds_file = os.path.join(model_dir, "thresholds.json")
    try:
        with open(thresholds_file, "r", encoding="utf-8") as handle:
            raw = json.load(handle)
        red_value = float(raw.get("red_threshold", default_red))
        yellow_value = float(raw.get("yellow_threshold", default_yellow))
    except Exception:
        # Best-effort: any problem means "use the defaults".
        return default_red, default_yellow
    return red_value, yellow_value

206 

207 

@dataclass(frozen=True)
class Endpoint:
    """Immutable, hashable (ip, port) pair identifying one side of a flow."""

    # IP address in textual form.
    ip: str
    # Port number.
    port: int

212 

213 

def _endpoint_order(ep: Endpoint) -> Tuple[int, bytes, int]:
    """Return a sort key for *ep*: (IP version, packed address bytes, port).

    Raises:
        ValueError: If ``ep.ip`` is not a valid IPv4 or IPv6 address.
    """
    parsed = ipaddress.ip_address(ep.ip)
    version, raw_bytes = parsed.version, parsed.packed
    return version, raw_bytes, ep.port

217 

218 

def canonical_5tuple(
    src_ip: str, src_port: int, dst_ip: str, dst_port: int, proto: str
) -> Tuple[Endpoint, Endpoint, str]:
    """Return a direction-independent key for a flow.

    The two endpoints are ordered by ``_endpoint_order`` so that both
    directions of the same conversation map to the same tuple. The protocol
    name is lower-cased.

    Args:
        src_ip: Source IP address.
        src_port: Source port (converted with ``int``).
        dst_ip: Destination IP address.
        dst_port: Destination port (converted with ``int``).
        proto: Protocol name.

    Returns:
        ``(lower_endpoint, higher_endpoint, proto_lowercase)``.
    """
    source = Endpoint(src_ip, int(src_port))
    destination = Endpoint(dst_ip, int(dst_port))
    if _endpoint_order(source) <= _endpoint_order(destination):
        first, second = source, destination
    else:
        first, second = destination, source
    return first, second, proto.lower()

226 

227 

# Colourised level labels for console output. Without colorama the colour
# codes are empty strings, so each label degrades to its plain name.
colorama_init(autoreset=True)
LEVEL_COLOR: Dict[str, str] = {
    level: getattr(Fore, level) + level + Style.RESET_ALL
    for level in ("GREEN", "YELLOW", "RED")
}