1"""Train an Isolation Forest model on a flows CSV dataset. 

2 

3This module defines a ``train`` function that takes a CSV file of 

4network flow features, a configuration file and an output directory. 

5It normalises the numeric features with a ``StandardScaler`` and 

6trains an ``IsolationForest`` model. If labels are present, the 

7contamination parameter is calibrated by performing a simple grid 

8search over a handful of candidate contamination rates. Synthetic 

9outliers can optionally be injected into the training data to help the 

10model learn to assign lower scores to extreme points. The trained 

11model and scaler are persisted via joblib, and alert thresholds are 

12saved to ``thresholds.json`` in the output directory. 

13 

14This module exposes a ``main`` function so it can be used as a console 

15script (see ``pyproject.toml``) via ``ids-iforest-train``. 

16""" 

from __future__ import annotations

import argparse
import json
import os
from typing import List, Optional

import numpy as np
import pandas as pd
from sklearn.ensemble import IsolationForest
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from .utils import load_config, get_logger, ensure_dirs, save_model, get_git_hash

__all__ = ["train", "main"]


MINIMAL_COLS: List[str] = [
    "window",
    "src_ip",
    "dst_ip",
    "src_port",
    "dst_port",
    "protocol",
    "bidirectional_packets",
    "bidirectional_bytes",
    "mean_packet_size",
    "std_packet_size",
    "flow_duration",
]
EXTENDED_COLS: List[str] = MINIMAL_COLS + [
    "tcp_syn_count",
    "tcp_fin_count",
    "tcp_rst_count",
    "iat_mean",
    "iat_std",
    "bytes_per_packet",
    "packets_per_second",
]
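
# These column lists describe the expected flow CSV layout; model training
# itself consumes only the numeric subset returned by ``_columns_for`` below.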


def _select_numeric(df: pd.DataFrame, feature_set: str) -> pd.DataFrame:
    """Return only the numeric columns relevant for Isolation Forest training."""
    # Delegate to _columns_for so the numeric column list is defined once.
    return df[_columns_for(feature_set)].astype(float)


def inject_synthetic(
    df: pd.DataFrame,
    ratio: float = 0.02,
    rng_seed: int = 42,
) -> pd.DataFrame:
    """Inject synthetic outliers into the feature matrix.

    A small fraction of synthetic points is drawn far outside the
    empirical distribution of each feature (based on the 5th and 95th
    percentiles). These points encourage the Isolation Forest to assign
    low scores to extreme values. When ``ratio`` is zero or the
    DataFrame is empty, the input DataFrame is returned unchanged.
    """
    if ratio <= 0 or df.empty:
        return df
    rng = np.random.default_rng(rng_seed)
    n = max(1, int(len(df) * ratio))
    cols = df.columns
    q_hi = df.quantile(0.95)
    q_lo = df.quantile(0.05)
    synth = []
    for _ in range(n):
        row = {}
        for c in cols:
            span = max(1e-9, float(q_hi[c] - q_lo[c]))
            # Draw between the 95th percentile and five spans above it
            row[c] = float(q_hi[c] + 5.0 * span * rng.random())
        synth.append(row)
    df_s = pd.concat([df, pd.DataFrame(synth)], ignore_index=True)
    return df_s
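
# Worked illustration (numbers are hypothetical): with ratio=0.02 and a
# 1,000-row frame, max(1, int(1000 * 0.02)) == 20 synthetic rows are appended,
# and each synthetic cell falls in [q95, q95 + 5 * (q95 - q05)) for its column.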


def _columns_for(feature_set: str) -> List[str]:
    """Return the numeric feature columns for the given feature set."""
    cols = [
        "bidirectional_packets",
        "bidirectional_bytes",
        "mean_packet_size",
        "std_packet_size",
        "flow_duration",
    ]
    if feature_set == "extended":
        cols += [
            "tcp_syn_count",
            "tcp_fin_count",
            "tcp_rst_count",
            "iat_mean",
            "iat_std",
            "bytes_per_packet",
            "packets_per_second",
        ]
    return cols


def _calibrate_contamination(
    Xs: np.ndarray,
    y: Optional[np.ndarray],
    candidates: List[float],
    random_state: int = 42,
) -> float:
    """Select the best contamination parameter using a simple F1 grid search.

    If labels ``y`` are provided, the data is split into a training and
    a validation subset. ``y`` is expected to already carry a label of 1
    for any synthetic rows appended to ``Xs``. The contamination value
    with the highest F1 score on the validation set is returned. If no
    labels are provided, the first candidate is used.
    """
    if y is None:
        return candidates[0]
    # Stratify only if there is at least one anomaly label
    strat = y if y.sum() > 0 else None
    Xtr, Xval, _ytr, yval = train_test_split(
        Xs, y, test_size=0.2, random_state=random_state, stratify=strat
    )

    def fit_and_score(cont: float) -> float:
        m = IsolationForest(
            n_estimators=200,
            max_samples="auto",
            contamination=cont,
            random_state=random_state,
            n_jobs=-1,
        )
        m.fit(Xtr)
        # decision_function is negative for points the model deems anomalous
        s = m.decision_function(Xval)
        pred = (s < 0).astype(int)
        tp = int(((pred == 1) & (yval == 1)).sum())
        fp = int(((pred == 1) & (yval == 0)).sum())
        fn = int(((pred == 0) & (yval == 1)).sum())
        precision = tp / max(1, tp + fp)
        recall = tp / max(1, tp + fn)
        if (precision + recall) == 0:
            return 0.0
        return 2 * precision * recall / (precision + recall)

    best_c = candidates[0]
    best_f1 = -1.0
    for c in candidates:
        f1 = fit_and_score(c)
        if f1 > best_f1:
            best_f1 = f1
            best_c = c
    return best_c
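
# Note: IsolationForest.decision_function subtracts an offset derived from the
# contamination setting, so ``s < 0`` flags roughly that fraction of the
# training data as anomalous; the grid search simply picks the candidate whose
# induced cut-off maximises F1 on the held-out validation split.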


def train(csv_path: str, cfg_path: str, out_dir: str) -> str:
    """Train an Isolation Forest model and persist it to disk.

    Parameters
    ----------
    csv_path: str
        Path to the CSV file containing flows and (optionally) labels.
    cfg_path: str
        Path to a YAML configuration file (see ``load_config``).
    out_dir: str
        Directory where the model and threshold files should be written.

    Returns
    -------
    str
        The path to the written model file.
    """
    cfg = load_config(cfg_path)
    logger = get_logger("train", cfg["logs_dir"], "train.log")
    feature_set = cfg.get("feature_set", "extended")
    contamination_default = float(cfg.get("contamination", 0.02))

    df = pd.read_csv(csv_path)
    if df.empty:
        raise RuntimeError("Dataset is empty – nothing to train on")
    X = _select_numeric(df, feature_set)
    # Extract labels if present; assume 1 = anomaly, 0 = benign
    y: Optional[np.ndarray] = df["label"].values if "label" in df.columns else None
    scaler = StandardScaler()
    Xs = scaler.fit_transform(X.values)
    # Inject synthetic outliers into the scaled feature matrix
    Xs_df = pd.DataFrame(Xs, columns=X.columns)
    Xs_inj = inject_synthetic(Xs_df, ratio=0.02)
    # If labels are present, extend them so synthetic rows are labelled 1
    y_inj: Optional[np.ndarray] = None
    if y is not None:
        extra = len(Xs_inj) - len(Xs)
        if extra > 0:
            y_inj = np.concatenate([y, np.ones(extra, dtype=int)])
        else:
            y_inj = y
    # Calibrate the contamination parameter when labels are available
    candidates = [0.005, 0.01, 0.02, 0.05]
    best_cont = (
        _calibrate_contamination(Xs_inj.values, y_inj, candidates)
        if y_inj is not None
        else contamination_default
    )
    # Train the Isolation Forest
    model = IsolationForest(
        n_estimators=200,
        max_samples="auto",
        contamination=best_cont,
        random_state=42,
        n_jobs=-1,
    )
    model.fit(Xs_inj.values)
    # Persist model and scaler
    model_path, _latest_path = save_model(model, scaler, out_dir)
    # Compute thresholds: derive red_threshold from training scores; yellow is fixed at 0
    try:
        scores_train = model.decision_function(Xs)
        red_threshold: float
        if y is not None and (y == 0).sum() > 0:
            benign_scores = scores_train[y == 0]
            # Pick a threshold below which at most ~1% of benign flows fall (1st percentile)
            candidate = float(np.percentile(benign_scores, 1) - 1e-3)
            red_threshold = float(min(candidate, -0.02))
        else:
            red_threshold = float(np.min(scores_train) - 0.05)
    except Exception:
        red_threshold = -0.1
    yellow_threshold = 0.0
    thresholds_path = os.path.join(out_dir, "thresholds.json")
    try:
        with open(thresholds_path, "w", encoding="utf-8") as f_thr:
            json.dump(
                {"red_threshold": red_threshold, "yellow_threshold": yellow_threshold},
                f_thr,
                indent=2,
            )
    except Exception:
        # Do not fail training on threshold write errors
        pass
    # Model card metadata
    card = {
        "git_hash": get_git_hash(short=False),
        "feature_set": feature_set,
        # Record the contamination actually used (may differ from config default)
        "contamination": best_cont,
        "n_estimators": 200,
        "train_rows": int(len(df)),
        "columns_numeric": list(X.columns),
    }
    ensure_dirs(out_dir)
    try:
        with open(
            os.path.join(out_dir, f"model_card_{get_git_hash()}.json"),
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(card, f, indent=2)
    except Exception:
        pass
    logger.info(f"Model trained → {model_path} (contamination={best_cont})")
    return model_path
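
# Programmatic usage sketch (paths are illustrative):
#
#   model_path = train("data/flows.csv", "config/config.yml", "./models")
#
# Afterwards the output directory holds the persisted model and scaler, a
# model card JSON, and thresholds.json of the form
# {"red_threshold": -0.05, "yellow_threshold": 0.0} (values shown are made up).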


def main() -> None:
    """Console entry point for the training script."""
    ap = argparse.ArgumentParser(
        description="Train an Isolation Forest on a CSV of network flows"
    )
    ap.add_argument(
        "--csv",
        required=True,
        help="Path to CSV file with flow features and optional labels",
    )
    ap.add_argument(
        "--config", default="config/config.yml", help="Path to configuration YAML file"
    )
    ap.add_argument(
        "--out", default="./models", help="Output directory for the trained model"
    )
    args = ap.parse_args()
    ensure_dirs(args.out)
    train(args.csv, args.config, args.out)


if __name__ == "__main__":
    main()