Upload app.py
Browse files
app.py
CHANGED
|
@@ -7,21 +7,22 @@ import torch
|
|
| 7 |
from models import AudioClassifier
|
| 8 |
from utils import logger
|
| 9 |
|
|
|
|
|
|
|
| 10 |
|
| 11 |
ckpt_dir = Path("ckpt/")
|
| 12 |
config_path = ckpt_dir / "config.json"
|
| 13 |
assert config_path.exists(), f"config.json not found in {ckpt_dir}"
|
| 14 |
config = json.loads((ckpt_dir / "config.json").read_text())
|
| 15 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
-
model = AudioClassifier(device=device, **config["model"]).to(device)
|
| 17 |
|
|
|
|
| 18 |
# Latest checkpoint
|
| 19 |
if (ckpt_dir / "model_final.pth").exists():
|
| 20 |
ckpt = ckpt_dir / "model_final.pth"
|
| 21 |
else:
|
| 22 |
ckpt = sorted(ckpt_dir.glob("*.pth"))[-1]
|
| 23 |
logger.info(f"Loading {ckpt}...")
|
| 24 |
-
model.load_state_dict(torch.load(ckpt))
|
| 25 |
|
| 26 |
|
| 27 |
def classify_audio(audio_file: str):
|
|
|
|
| 7 |
from models import AudioClassifier
|
| 8 |
from utils import logger
|
| 9 |
|
| 10 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
+
logger.info(f"Device: {device}")
|
| 12 |
|
| 13 |
ckpt_dir = Path("ckpt/")
|
| 14 |
config_path = ckpt_dir / "config.json"
|
| 15 |
assert config_path.exists(), f"config.json not found in {ckpt_dir}"
|
| 16 |
config = json.loads((ckpt_dir / "config.json").read_text())
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
model = AudioClassifier(device=device, **config["model"]).to(device)
|
| 19 |
# Latest checkpoint
|
| 20 |
if (ckpt_dir / "model_final.pth").exists():
|
| 21 |
ckpt = ckpt_dir / "model_final.pth"
|
| 22 |
else:
|
| 23 |
ckpt = sorted(ckpt_dir.glob("*.pth"))[-1]
|
| 24 |
logger.info(f"Loading {ckpt}...")
|
| 25 |
+
model.load_state_dict(torch.load(ckpt, map_location=device))
|
| 26 |
|
| 27 |
|
| 28 |
def classify_audio(audio_file: str):
|