import os, json, argparse, warnings, joblib

import numpy as np
import pandas as pd
import torch
from typing import List

FEATURE_COLS: List[str] = [
    "latitude", "longitude", "altitude",
    "accelerometer_x", "accelerometer_y", "accelerometer_z",
    "gyroscope_x", "gyroscope_y", "gyroscope_z",
    "compass",
]

HIST_LEN_DEFAULT = 50

DEFAULT_WEIGHTS = "1753670088.7075965_lstm_corr.pth"
DEFAULT_SCALER_X = "scalers/1753670088.7075965_scaler_X.pkl"
DEFAULT_SCALER_Y = "scalers/1753670088.7075965_scaler_y.pkl"
DEFAULT_CONFIG = "config.json"
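
# Note: the column order of FEATURE_COLS is assumed to match the order used when the
# scalers and the LSTM were trained; the first three columns (latitude, longitude,
# altitude) are also the coordinates that get corrected in the loop below.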

from model import GPSCorrectionLSTM


def load_config(cfg_path: str) -> dict:
    """Read model hyperparameters from a JSON config file; return {} if it is missing."""
    if os.path.exists(cfg_path):
        with open(cfg_path, "r") as f:
            return json.load(f)
    return {}


def build_model(input_size: int, cfg: dict) -> torch.nn.Module:
    """Instantiate the model with hyperparameters from config.json if available."""
    hidden_size = int(cfg.get("hidden_size", 128))
    num_layers = int(cfg.get("num_layers", 2))
    dropout = float(cfg.get("dropout", 0.3))

    try:
        model = GPSCorrectionLSTM(input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout)
    except TypeError:
        # Fall back for model constructors that do not accept a dropout argument.
        model = GPSCorrectionLSTM(input_size, hidden_size=hidden_size, num_layers=num_layers)
    model.eval()
    return model
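
# Illustrative config.json contents (hypothetical values; the keys are the ones read in
# build_model above):
#   {"hidden_size": 128, "num_layers": 2, "dropout": 0.3}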


def load_scaler(path: str):
    """Load a Joblib scaler if present; otherwise continue without scaling."""
    if os.path.exists(path):
        return joblib.load(path)
    warnings.warn(f"[WARN] scaler not found: {path}. Proceeding without scaling.")
    return None
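
# If a scaler file is missing, raw feature values are fed straight to the model, which is
# only sensible if the network was trained on unscaled data.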


def load_df_from_csv(path: str) -> pd.DataFrame:
    """Load CSV, sort by timestamp if present, and validate feature columns."""
    df = pd.read_csv(path)
    if "timestamp" in df.columns:
        df = df.sort_values("timestamp")
    missing = [c for c in FEATURE_COLS if c not in df.columns]
    if missing:
        raise ValueError(f"CSV is missing columns: {missing}")
    return df.reset_index(drop=True)


def scale_window(X_win: np.ndarray, scaler_X):
    """Apply the feature scaler to a single (T, F) window if provided."""
    if scaler_X is None:
        return X_win
    T, F = X_win.shape
    return scaler_X.transform(X_win.reshape(-1, F)).reshape(T, F)


def inverse_y(y: np.ndarray, scaler_y):
    """Inverse-transform a single (2,) or (3,) prediction if a target scaler is provided."""
    if scaler_y is None:
        return y
    return scaler_y.inverse_transform(y.reshape(1, -1)).reshape(-1)
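
# Downstream, the inverse-transformed residuals are interpreted as degrees of latitude and
# longitude (and meters for the optional altitude term), matching the output key names.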


def predict_next_residual(model: torch.nn.Module, X_win_tf: np.ndarray, device: str = "cpu") -> np.ndarray:
    """Predict the next-step residual [res_lat, res_lon(, res_alt)] from a (HIST_LEN, F) window."""
    # unsqueeze(0) adds the batch dimension: the model receives a (1, HIST_LEN, F) tensor.
    x = torch.from_numpy(X_win_tf.astype(np.float32)).unsqueeze(0).to(device)
    with torch.no_grad():
        y = model(x).squeeze(0).detach().cpu().numpy()
    return y
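
# Shape flow for one prediction step (illustrative, with hist_len = 50 and the 10 features
# above): the (50, 10) window becomes a (1, 50, 10) batch, and the model returns a flat
# residual vector of length 2 (lat, lon) or 3 (lat, lon, alt).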


def main():
    ap = argparse.ArgumentParser(
        description="Rolling inference for next-step GNSS residuals using an LSTM model. "
                    "Uses the last HIST_LEN rows to predict the next row. "
                    "If the CSV has N rows and N >= HIST_LEN+1, this script outputs corrected "
                    "coordinates for rows [HIST_LEN ... N-1] (with the default history length "
                    "of 50, that is the 51st row through the last)."
    )
    src = ap.add_mutually_exclusive_group(required=True)
    src.add_argument("--json", type=str, help="JSON string of shape [T, F]")
    src.add_argument("--json-file", type=str, help="Path to a JSON file (shape [T, F])")
    src.add_argument("--csv", type=str, help="Path to a CSV with columns: " + ",".join(FEATURE_COLS))

    ap.add_argument("--weights", default=DEFAULT_WEIGHTS, help="Model weights (state_dict or full model object).")
    ap.add_argument("--scaler-x", default=DEFAULT_SCALER_X, help="Feature scaler (Joblib).")
    ap.add_argument("--scaler-y", default=DEFAULT_SCALER_Y, help="Target scaler (Joblib).")
    ap.add_argument("--config", default=DEFAULT_CONFIG, help="Model hyperparameters (config.json).")
    ap.add_argument("--hist-len", type=int, default=HIST_LEN_DEFAULT, help="History window length used by the model (default: 50).")

    args = ap.parse_args()
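
    # Illustrative --json payload (hypothetical values): a [T, F] nested list in FEATURE_COLS
    # order. At least hist_len + 1 rows are needed before any corrected row is produced.
    #   [[52.5200, 13.4050, 34.0, 0.01, -0.02, 9.81, 0.001, 0.002, -0.001, 180.0],
    #    ...]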

    if args.json:
        arr = np.asarray(json.loads(args.json), dtype=np.float32)
        timestamps = None
    elif args.json_file:
        with open(args.json_file, "r") as f:
            arr = np.asarray(json.load(f), dtype=np.float32)
        timestamps = None
    else:
        df = load_df_from_csv(args.csv)
        arr = df[FEATURE_COLS].to_numpy(dtype=np.float32)
        timestamps = df["timestamp"].to_numpy() if "timestamp" in df.columns else None

    T, F = arr.shape
    H = int(args.hist_len)
    if F != len(FEATURE_COLS):
        raise ValueError(f"Input feature dimension must be {len(FEATURE_COLS)}, got {F}.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = load_config(args.config)
    model = build_model(input_size=F, cfg=cfg).to(device)

    # The weights file may hold either a plain state_dict or a fully pickled model object.
    state = torch.load(args.weights, map_location=device)
    if isinstance(state, dict):
        model.load_state_dict(state)
    else:
        model = state.to(device)
    model.eval()

    scaler_X = load_scaler(args.scaler_x)
    scaler_y = load_scaler(args.scaler_y)

    results = []
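
    # Rolling window: rows [i - H, i) form the model input, and the predicted residual is
    # added to the noisy GNSS fix at row i to obtain the corrected coordinate.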
    for i in range(H, T):
        X_win = arr[i - H : i, :]
        X_win_tf = scale_window(X_win, scaler_X)
        y_pred = predict_next_residual(model, X_win_tf, device=device)
        y_pred_deg = inverse_y(y_pred, scaler_y)

        res_lat = float(y_pred_deg[0])
        res_lon = float(y_pred_deg[1])

        noisy_lat = float(arr[i, 0])
        noisy_lon = float(arr[i, 1])

        out = {
            "index": int(i),
            "noisy_next_lat_deg": noisy_lat,
            "noisy_next_lon_deg": noisy_lon,
            "pred_residual_lat_deg": res_lat,
            "pred_residual_lon_deg": res_lon,
            "corrected_next_lat_deg": noisy_lat + res_lat,
            "corrected_next_lon_deg": noisy_lon + res_lon,
        }

        if y_pred_deg.shape[0] >= 3:
            res_alt = float(y_pred_deg[2])
            noisy_alt = float(arr[i, 2])
            out.update({
                "noisy_next_alt_m": noisy_alt,
                "pred_residual_alt": res_alt,
                "corrected_next_alt_m": noisy_alt + res_alt,
            })
        if timestamps is not None:
            out["timestamp"] = float(timestamps[i])
        results.append(out)
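
    # The emitted JSON has the shape (values illustrative):
    #   {"history_len": 50, "total_rows": 1200, "outputs": [{"index": 50, ...}, ...]}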
    print(json.dumps({
        "history_len": H,
        "total_rows": T,
        "outputs": results,
    }, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
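

# Example invocations (assuming this script is saved as, say, infer.py, and the default
# weight/scaler paths above exist relative to the working directory):
#   python infer.py --csv drive_log.csv
#   python infer.py --json-file window.json --hist-len 50 --weights my_lstm.pth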