-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainer.py
More file actions
177 lines (152 loc) · 5.53 KB
/
Copy pathtrainer.py
File metadata and controls
177 lines (152 loc) · 5.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
File: trainer.py
Author: Dimitrios Kafetzis (kafetzis@aueb.gr)
Created: May 2025
License: MIT
Description:
Handles the training of all model architectures for DDoS detection.
Creates, trains, and saves various models including:
- Linear model
- Threshold detector
- Shallow DNN
- Deep Neural Network
- LSTM
- GRU
- Transformer
Each model is trained on the same dataset for fair comparison.
Usage:
$ python trainer.py
The script can also be imported:
from trainer import train_and_save
Example:
train_and_save(create_model_fn, "model_name", X_train, y_train, timestamp)
"""
import time
import json
import inspect
import numpy as np
import tensorflow as tf
import pickle
import os
from evaluation_config import HISTORY_DIR
from pathlib import Path
from data.loader import prepare_combined_dataset
from models.linear_regressor import create_linear_model
from models.threshold_detector import ThresholdDetector
from models.shallow_dnn import create_shallow_dnn_model
from models.dnn import create_dnn_model
from models.lstm import create_lstm_model
from models.gru import create_gru_model
from models.transformer import create_transformer_model
import logging
# Configure root logging once for the script: INFO level, timestamped lines.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def train_and_save(model_fn, name, X, y, timestamp, mean=None, std=None, **fit_kwargs):
    """
    Build a model from ``model_fn``, train it on ``(X, y)`` and save it.

    The factory's signature decides the input layout:

    * If it accepts an ``input_dim`` parameter, it is called with
      ``input_dim=X.shape[1]`` and ``X`` is used unchanged.
    * Otherwise it is called with ``input_shape=(1, X.shape[1])``, and ``X``
      is reshaped to ``(-1, 1, X.shape[1])``. Each row becomes a one-step
      sequence for recurrent/sequence models.

    Args:
        model_fn: Model factory taking either ``input_dim`` or ``input_shape``.
            It must return an object with Keras-style ``fit`` and ``save``
            methods.
        name: Model name; used in the output directory and history file names.
        X: Training features. Presumably a 2-D (samples, features) array;
            TODO confirm against ``prepare_combined_dataset``.
        y: Training targets, passed straight to ``model.fit``.
        timestamp: Suffix for the output directory and history file.
        mean: Optional normalisation mean, saved as ``X_mean.npy``.
        std: Optional normalisation std, saved as ``X_std.npy``.
            Neither file is written unless both ``mean`` and ``std`` are
            given.
        **fit_kwargs: Forwarded to ``model.fit``. A ``TerminateOnNaN``
            callback is added to a *copy* of ``callbacks``, so the caller's
            list is never modified.

    Returns:
        ``pathlib.Path`` of the ``saved_models/<name>_<timestamp>/`` directory
        the model was saved to.
    """
    sig = inspect.signature(model_fn)
    # Dense factories take ``input_dim``; anything else is assumed to take
    # ``input_shape=(timesteps, features)``.
    if "input_dim" in sig.parameters:
        model = model_fn(input_dim=X.shape[1])
        X_train = X
    else:
        timesteps = 1
        features = X.shape[1]
        model = model_fn(input_shape=(timesteps, features))
        # np.asarray lets array-likes without .reshape work too; it is a
        # no-op for ndarrays.
        X_train = np.asarray(X).reshape((-1, timesteps, features))

    # Work on a copy of the callback list. Appending to the caller's list
    # would make a shared fit-args dict gain one more TerminateOnNaN per
    # model trained. ``or []`` also tolerates an explicit callbacks=None.
    callbacks = list(fit_kwargs.get("callbacks") or [])
    callbacks.append(tf.keras.callbacks.TerminateOnNaN())
    fit_kwargs["callbacks"] = callbacks

    outdir = Path("saved_models") / f"{name}_{timestamp}"
    outdir.mkdir(parents=True, exist_ok=True)
    # Store normalisation statistics next to the model so inference can
    # reproduce the preprocessing.
    if mean is not None and std is not None:
        np.save(str(outdir / "X_mean.npy"), mean)
        np.save(str(outdir / "X_std.npy"), std)

    logger.info(f"Training {name} on {len(X_train)} samples…")
    history = model.fit(X_train, y, **fit_kwargs)

    # Ensure the history directory exists before writing. Otherwise a
    # missing directory would throw away a completed training run.
    history_dir = Path(HISTORY_DIR)
    history_dir.mkdir(parents=True, exist_ok=True)
    history_path = history_dir / f"{name}_{timestamp}_history.pkl"
    with open(history_path, "wb") as f:
        pickle.dump(history.history, f)

    logger.info(f"Saving {name} to {outdir}")
    model.save(str(outdir), save_format="tf")
    return outdir
def _default_fit_args():
    """Return a fresh dict of ``model.fit`` keyword arguments.

    A new dict, callback list and callback instances are built on every call.
    No model therefore shares callback objects, or callbacks appended
    downstream, with another model's training run.
    """
    return dict(
        epochs=20,
        batch_size=64,
        validation_split=0.1,
        callbacks=[
            tf.keras.callbacks.EarlyStopping(
                patience=5,
                restore_best_weights=True,
                monitor='val_loss',
            ),
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=2,
            ),
        ],
        verbose=2,
    )


def _save_threshold_detector(X_train, y_train, ts, mean, std):
    """Calibrate a ThresholdDetector on the training set and save it.

    The detector is calibrated rather than trained. Its thresholds are written
    to ``saved_models/threshold_detector_<ts>/thresholds.json``, together with
    the normalisation statistics, so every saved model directory has the same
    layout.

    Returns:
        ``pathlib.Path`` of the output directory.
    """
    thr = ThresholdDetector(percentile=99.5)
    thr.calibrate(X_train, y_train, feature_indices=list(range(X_train.shape[1])))
    outdir = Path("saved_models") / f"threshold_detector_{ts}"
    outdir.mkdir(parents=True, exist_ok=True)
    with open(outdir / "thresholds.json", "w", encoding="utf-8") as f:
        json.dump(thr.thresholds, f)
    np.save(str(outdir / "X_mean.npy"), mean)
    np.save(str(outdir / "X_std.npy"), std)
    return outdir


def main():
    """Train every model family on the combined dataset and write a manifest.

    Models are saved under ``saved_models/`` (relative to the working
    directory). ``saved_models/manifest.json`` maps each model key to its
    output directory.
    """
    ts = int(time.time())
    project_root = Path(__file__).parent.resolve()
    data_dir = project_root / "data" / "nsl_kdd_dataset"
    hard_csv = data_dir / "NSL-KDD-Hard.csv"

    # 1) Prepare datasets. The easy/hard evaluation splits and the feature
    #    column names are not needed for training.
    (X_train, y_train), _easy, _hard, _feat_cols, (mean, std) = \
        prepare_combined_dataset(str(data_dir), str(hard_csv))

    # 2) Train each model family. Fit arguments are rebuilt per model so
    #    callbacks are never shared between training runs.
    #    Insertion order here is the order of the manifest.
    saved = {}
    saved['linear_model'] = train_and_save(
        create_linear_model, "linear_model",
        X_train, y_train, ts, mean, std, **_default_fit_args()
    )
    saved['threshold_detector'] = _save_threshold_detector(
        X_train, y_train, ts, mean, std
    )
    keras_models = (
        ("shallow_dnn", create_shallow_dnn_model),
        ("dnn", create_dnn_model),
        ("lstm", create_lstm_model),
        ("gru", create_gru_model),
        ("transformer", create_transformer_model),
    )
    for name, factory in keras_models:
        saved[name] = train_and_save(
            factory, name,
            X_train, y_train, ts, mean, std, **_default_fit_args()
        )

    # 3) Write out the manifest (paths as strings for JSON).
    manifest_path = Path("saved_models") / "manifest.json"
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump({k: str(v) for k, v in saved.items()}, f, indent=2)
    logger.info("All models retrained and saved.")
if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, so that
    # importing train_and_save from this module does not start training.
    main()