# GINE-0.5 / inference.py
# ISeeTheFuture's picture
# remove trash param
# bcea386
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 ISeeTheFuture
import os, json, argparse, warnings, joblib
import numpy as np
import pandas as pd
import torch
from typing import List
# === Features used at training time ===
# Column order is significant: main() reads columns 0/1/2 of every input row as
# latitude/longitude/altitude when reporting noisy and corrected positions.
# NOTE(review): presumably this must also match the column order the scalers
# and the model were fitted on -- confirm against the training code.
FEATURE_COLS: List[str] = [
    # Position
    "latitude", "longitude", "altitude",
    # Accelerometer axes
    "accelerometer_x", "accelerometer_y", "accelerometer_z",
    # Gyroscope axes
    "gyroscope_x", "gyroscope_y", "gyroscope_z",
    # Compass reading
    "compass",
]

# History window length for the model: the HIST_LEN rows preceding a row are
# used to predict that row, so at least HIST_LEN + 1 (= 51) rows are required
# to produce one output.
HIST_LEN_DEFAULT = 50

# === Default file locations (relative to the model repo root) ===
DEFAULT_WEIGHTS = "1753670088.7075965_lstm_corr.pth"
DEFAULT_SCALER_X = "scalers/1753670088.7075965_scaler_X.pkl"
DEFAULT_SCALER_Y = "scalers/1753670088.7075965_scaler_y.pkl"
DEFAULT_CONFIG = "config.json"
# === Model class ===
from model import GPSCorrectionLSTM # __init__(input_size, hidden_size=128, num_layers=2, dropout=0.3)
def load_config(cfg_path: str) -> dict:
    """Load model hyperparameters from a JSON config file.

    Args:
        cfg_path: Path to the JSON config file.

    Returns:
        The parsed JSON content, or an empty dict when ``cfg_path`` does not
        exist (callers then fall back to their own defaults).

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    if not os.path.exists(cfg_path):
        return {}
    # JSON is UTF-8 by specification; decode explicitly instead of relying on
    # the platform's locale encoding (which is not UTF-8 on every system).
    with open(cfg_path, "r", encoding="utf-8") as f:
        return json.load(f)
def build_model(input_size: int, cfg: dict) -> torch.nn.Module:
    """Create a GPSCorrectionLSTM in eval mode from the given hyperparameters.

    ``hidden_size`` (default 128), ``num_layers`` (default 2) and ``dropout``
    (default 0.3) are read from ``cfg``. If the model class rejects the
    ``dropout`` keyword with a TypeError, construction is retried without it.
    """
    hyper = {
        "hidden_size": int(cfg.get("hidden_size", 128)),
        "num_layers": int(cfg.get("num_layers", 2)),
    }
    dropout = float(cfg.get("dropout", 0.3))
    # output_size was 2 in training (res_lat, res_lon); keep flexible if your class needs it.
    try:
        net = GPSCorrectionLSTM(input_size, **hyper, dropout=dropout)
    except TypeError:
        # Fall back to a constructor signature without a dropout parameter.
        net = GPSCorrectionLSTM(input_size, **hyper)
    net.eval()
    return net
def load_scaler(path: str):
    """Return the Joblib-serialized scaler stored at ``path``, or ``None`` if absent.

    A missing file is not an error: a warning is issued and ``None`` is
    returned so that callers can run without scaling.

    NOTE: Joblib files are pickles; only load scalers from a trusted source.
    """
    if not os.path.exists(path):
        warnings.warn(f"[WARN] scaler not found: {path}. Proceeding without scaling.")
        return None
    return joblib.load(path)
def load_df_from_csv(path: str) -> pd.DataFrame:
    """Load a CSV, validate the feature columns, and order rows by timestamp.

    Args:
        path: Path to a CSV file containing every column in FEATURE_COLS and,
            optionally, a ``timestamp`` column.

    Returns:
        The loaded frame with a fresh 0..N-1 index, sorted by ``timestamp``
        when that column is present.

    Raises:
        ValueError: If any column listed in FEATURE_COLS is missing.
    """
    df = pd.read_csv(path)
    # Validate first so a malformed file fails with a clear message before
    # any further work is done on it.
    missing = [c for c in FEATURE_COLS if c not in df.columns]
    if missing:
        raise ValueError(f"CSV is missing columns: {missing}")
    if "timestamp" in df.columns:
        # Use a stable sort: rows that share a timestamp keep their file
        # order, so the rolling windows built from this frame are
        # deterministic (the default quicksort gives no such guarantee).
        df = df.sort_values("timestamp", kind="mergesort")
    return df.reset_index(drop=True)
def scale_window(X_win: np.ndarray, scaler_X):
    """Scale one (T, F) window with ``scaler_X``; return it untouched if no scaler is given."""
    if scaler_X is not None:
        n_steps, n_feats = X_win.shape
        as_rows = X_win.reshape(-1, n_feats)
        scaled = scaler_X.transform(as_rows)
        return scaled.reshape(n_steps, n_feats)
    return X_win
def inverse_y(y: np.ndarray, scaler_y):
    """Undo target scaling on a single prediction vector (e.g. shape (2,) or (3,)).

    The vector is passed to ``scaler_y.inverse_transform`` as a one-row matrix
    and flattened again; without a scaler it is returned as-is.
    """
    if scaler_y is not None:
        as_row = y.reshape(1, -1)
        restored = scaler_y.inverse_transform(as_row)
        return restored.reshape(-1)
    return y
def predict_next_residual(model: torch.nn.Module, X_win_tf: np.ndarray, device: str = "cpu") -> np.ndarray:
    """Run the model on one (HIST_LEN, F) window and return its output as a NumPy vector.

    The window is cast to float32, given a leading batch axis of size 1, and
    evaluated with gradients disabled; the batch axis is removed again before
    returning, e.g. [res_lat, res_lon(, res_alt?)] with shape (2,) or (3,).
    """
    batch = torch.from_numpy(X_win_tf.astype(np.float32))[None, ...].to(device)  # (1, T, F)
    with torch.no_grad():
        out = model(batch)
    return out.squeeze(0).detach().cpu().numpy()
def _read_input(args):
    """Load the input rows selected on the command line.

    Returns:
        A ``(rows, timestamps)`` pair. ``rows`` is a float64 array built from
        the JSON string, JSON file or CSV; ``timestamps`` is the CSV's
        ``timestamp`` column as an array, or ``None`` when unavailable.
    """
    if args.json:
        return np.asarray(json.loads(args.json), dtype=np.float64), None
    if args.json_file:
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(args.json_file, "r", encoding="utf-8") as f:
            return np.asarray(json.load(f), dtype=np.float64), None
    df = load_df_from_csv(args.csv)
    timestamps = df["timestamp"].to_numpy() if "timestamp" in df.columns else None
    return df[FEATURE_COLS].to_numpy(dtype=np.float64), timestamps


def _load_weights(model, weights_path, device):
    """Return an eval-mode model initialised from ``weights_path``.

    The file may hold either a state_dict (loaded into ``model``) or a full
    pickled ``torch.nn.Module`` (which then replaces ``model``).

    NOTE: torch.load unpickles the file; only load weights from a trusted source.
    """
    state = torch.load(weights_path, map_location=device)
    if isinstance(state, torch.nn.Module):
        model = state.to(device)
    else:
        # Let key/shape mismatches surface as-is instead of masking them
        # behind a failed attempt to treat the state_dict as a module.
        model.load_state_dict(state)
    model.eval()
    return model


def _json_timestamp(value):
    """Convert one timestamp cell to a JSON-serializable value.

    Numeric timestamps are emitted as floats; anything that cannot be
    converted (e.g. an ISO date string) is emitted as a string.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return str(value)


# python inference.py --csv samples/sample.csv
def main():
    """Command-line entry point: rolling next-step residual inference.

    Reads a [T, F] input (JSON string, JSON file or CSV), loads the model and
    scalers, and for every row i in [hist_len, T-1] predicts a residual from
    the hist_len rows preceding it. The residual is added to the input
    latitude/longitude (and altitude, if the model emits a third value) of
    row i, and all results are printed to stdout as one JSON document.

    Raises:
        ValueError: If the input is not 2-D or its feature dimension does not
            match FEATURE_COLS.
    """
    ap = argparse.ArgumentParser(
        description="Rolling inference for next-step GNSS residuals using an LSTM model. "
                    "Uses the last HIST_LEN rows to predict the next row. "
                    "If the CSV has N rows and N >= HIST_LEN+1, this script outputs corrected coordinates "
                    "for rows [HIST_LEN ... N-1] (i.e., 51st to last)."
    )
    src = ap.add_mutually_exclusive_group(required=True)
    src.add_argument("--json", type=str, help="JSON string of shape [T, F]")
    src.add_argument("--json-file", type=str, help="Path to a JSON file (shape [T, F])")
    src.add_argument("--csv", type=str, help="Path to a CSV with columns: " + ",".join(FEATURE_COLS))
    ap.add_argument("--weights", default=DEFAULT_WEIGHTS, help="Model weights (state_dict or full model object).")
    ap.add_argument("--scaler-x", default=DEFAULT_SCALER_X, help="Feature scaler (Joblib).")
    ap.add_argument("--scaler-y", default=DEFAULT_SCALER_Y, help="Target scaler (Joblib).")
    ap.add_argument("--config", default=DEFAULT_CONFIG, help="Model hyperparameters (config.json).")
    ap.add_argument("--hist-len", type=int, default=HIST_LEN_DEFAULT, help="History window length used by the model (default: 50).")
    args = ap.parse_args()

    H = int(args.hist_len)
    if H < 1:
        # A zero or negative window would feed empty/wrapped slices to the model.
        ap.error("--hist-len must be a positive integer.")

    # 1) Load input
    rows, timestamps = _read_input(args)
    if rows.ndim != 2:
        raise ValueError(f"Input must be a 2-D array of shape [T, F], got shape {rows.shape}.")
    T, F = rows.shape
    if F != len(FEATURE_COLS):
        raise ValueError(f"Input feature dimension must be {len(FEATURE_COLS)}, got {F}.")
    if T <= H:
        warnings.warn(
            f"Input has {T} rows but at least {H + 1} are required for hist_len={H}; "
            "no outputs will be produced."
        )

    # The model consumes float32, but the reported/corrected coordinates are
    # read from the float64 `rows`: float32 carries only ~7 significant
    # decimal digits, which is too coarse for coordinates in degrees.
    # NOTE(review): windows are still cast to float32 *before* scaling, as
    # before -- confirm this matches the training-time preprocessing.
    arr = rows.astype(np.float32)

    # 2) Build & load model and scalers
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = load_config(args.config)
    model = build_model(input_size=F, cfg=cfg).to(device)
    model = _load_weights(model, args.weights, device)
    scaler_X = load_scaler(args.scaler_x)
    scaler_y = load_scaler(args.scaler_y)

    results = []
    # Rolling inference for indices i = H .. T-1
    # Each step uses arr[i-H : i] as input, and adds residual to noisy GNSS at i.
    for i in range(H, T):
        X_win_tf = scale_window(arr[i - H:i, :], scaler_X)  # (H, F)
        y_pred = predict_next_residual(model, X_win_tf, device=device)  # (2,) or (3,)
        y_pred_deg = inverse_y(y_pred, scaler_y)
        res_lat = float(y_pred_deg[0])
        res_lon = float(y_pred_deg[1])

        # Noisy GNSS at step i (the "next" row after the window)
        noisy_lat = float(rows[i, 0])
        noisy_lon = float(rows[i, 1])
        out = {
            "index": int(i),  # 0-based row index in the input
            "noisy_next_lat_deg": noisy_lat,
            "noisy_next_lon_deg": noisy_lon,
            "pred_residual_lat_deg": res_lat,
            "pred_residual_lon_deg": res_lon,
            "corrected_next_lat_deg": noisy_lat + res_lat,
            "corrected_next_lon_deg": noisy_lon + res_lon,
        }

        # If model outputs altitude residual too
        if y_pred_deg.shape[0] >= 3:
            res_alt = float(y_pred_deg[2])
            noisy_alt = float(rows[i, 2])
            out.update({
                "noisy_next_alt_m": noisy_alt,
                "pred_residual_alt": res_alt,
                "corrected_next_alt_m": noisy_alt + res_alt,
            })

        if timestamps is not None:
            out["timestamp"] = _json_timestamp(timestamps[i])
        results.append(out)

    print(json.dumps({
        "history_len": H,
        "total_rows": T,
        "outputs": results
    }, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()