Instructions to use Someshfengde/SnakeCLEF2024 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- timm
How to use Someshfengde/SnakeCLEF2024 with timm:
```python
import timm
model = timm.create_model("hf_hub:Someshfengde/SnakeCLEF2024", pretrained=True)
```
- Notebooks
- Google Colab
- Kaggle
| #%% | |
| import pandas as pd | |
| import numpy as np | |
| import os | |
| from tqdm import tqdm | |
| import timm | |
| import torchvision.transforms as T | |
| # from albumentations.pytorch import ToTensorV2 | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| import json | |
| # from transformers import AutoImageProcessor | |
| # from create_model import HieraForImageClassification | |
| #%% | |
| # %% | |
# Square input resolution fed to the model (used by the Resize transform).
SZ = 224


def _load_json(path):
    """Parse the JSON file at ``path`` as UTF-8, closing the handle afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Maps the stringified predicted index (str(argmax)) to a string species label.
LABELS = _load_json("./labels_class_map_rev.json")
# Maps a string species label to the class id written to the submission.
ORIGINAL_LABELS = _load_json("./original_mapping.json")
def is_gpu_available() -> bool:
    """Return True if PyTorch can use a CUDA device.

    Thin wrapper around ``torch.cuda.is_available()``; it does not inspect
    which Python packages (e.g. ONNX runtimes) are installed.
    """
    return torch.cuda.is_available()
| # VALID_AUG = A.Compose([ | |
| # A.SmallestMaxSize(max_size=SZ + 16, p=1.0), | |
| # A.CenterCrop(height=SZ, width=SZ, p=1.0), | |
| # A.Normalize(), | |
| # ToTensorV2(), | |
| # ]) | |
def get_corn_model(model_name, pretrained=True, **kwargs):
    """Build a timm backbone followed by a small classification head.

    The head takes the backbone's ``num_classes`` outputs and maps them to
    ``len(LABELS)`` classes via dropout and two linear layers.

    The layer order (backbone, dropout, linear, linear) fixes the
    ``nn.Sequential`` indices ``0``-``3`` that appear in state-dict keys, so it
    must match the checkpoint that will be loaded into the result.
    """
    backbone = timm.create_model(model_name, pretrained=pretrained, **kwargs)
    width = backbone.num_classes
    head = [
        nn.Dropout(0.15),
        nn.Linear(width, width * 2),
        # NOTE(review): there is no activation between the two Linear layers,
        # so together they form a single affine map; kept for checkpoint
        # compatibility.
        nn.Linear(width * 2, len(LABELS)),
    ]
    return nn.Sequential(backbone, *head)
class PytorchWorker:
    """Load the fine-tuned classifier once and run single-image inference."""

    # Architecture passed to ``get_corn_model``; must match the checkpoint.
    DEFAULT_MODEL_NAME = "vit_base_patch16_224"
    # State-dict checkpoint loaded by default.
    DEFAULT_CHECKPOINT = (
        "./NB_EXP_V2_008/"
        "vit_base_patch16_224_224_bs32_ep16_lr6e05_wd0.05_mixup_cutmix_CV_0.pth"
    )

    def __init__(self, model_path=DEFAULT_CHECKPOINT, model_name=DEFAULT_MODEL_NAME):
        """Pick a device, build the preprocessing pipeline and load the model.

        :param model_path: Path of the state-dict checkpoint to load.
        :param model_name: timm architecture name the checkpoint was trained with.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Resize to SZ x SZ, convert to a tensor, and map [0, 1] to [-1, 1].
        self.transforms = T.Compose([
            T.Resize((SZ, SZ)),
            T.ToTensor(),
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        self.model = self._load_model(model_path, model_name)

    def _load_model(self, model_path, model_name):
        """Build the architecture, load weights, move to device, set eval mode."""
        print("Setting up Pytorch Model")
        print(f"Using device: {self.device}")
        model = get_corn_model(model_name, pretrained=False)
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        model_ckpt = torch.load(model_path, map_location=self.device)
        model.load_state_dict(model_ckpt)
        # eval() disables dropout (the head's nn.Dropout(0.15) and any inside
        # the backbone) so repeated predictions on the same image agree.
        return model.to(self.device).eval()

    def predict_image(self, image) -> list:
        """Run the PyTorch model on a single image.

        :param image: Any input accepted by ``self.transforms``; the caller in
            this file passes an RGB ``PIL.Image``.
        :return: Raw logits (no softmax) as a nested list whose outer
            dimension is the batch of size 1.
        """
        batch = self.transforms(image).unsqueeze(0).to(self.device)
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            logits = self.model(batch)
        return logits.tolist()
def _predict_class_id(model, image_path):
    """Return the original class id the model predicts for one image file."""
    # The context manager closes the file handle; convert() returns a new
    # in-memory image, so it stays usable after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    logits = model.predict_image(image)
    # argmax over the (flattened) logits gives the predicted index; LABELS is
    # keyed by that index as a string.
    string_label_dup = LABELS.get(str(np.argmax(logits)), 'Acanthophis antarcticus')
    return ORIGINAL_LABELS.get(string_label_dup, 1)


def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
    """Predict a class id for every test image and write one row per observation.

    :param test_metadata: DataFrame with ``filename`` and ``observation_id``
        columns. A ``class_id`` column is added to it in place.
    :param model_path: Currently unused; ``PytorchWorker`` picks its checkpoint.
    :param model_name: Currently unused; ``PytorchWorker`` picks its architecture.
    :param output_csv_path: Destination CSV with ``observation_id,class_id`` columns.
    :param images_root_path: Directory that the ``filename`` values are relative to.
    """
    model = PytorchWorker()
    predictions = [
        _predict_class_id(model, os.path.join(images_root_path, filename))
        for filename in tqdm(test_metadata["filename"], total=len(test_metadata))
    ]
    # Print once at the end (printing inside the loop re-dumped the whole list
    # on every image).
    print(predictions)
    test_metadata["class_id"] = predictions
    # An observation may have several images; keep the first image's prediction.
    user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=False)
| #%% | |
if __name__ == "__main__":
    import zipfile

    # Unpack the private test images so make_submission can read them.
    archive_path = "/tmp/data/private_testset.zip"
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall("/tmp/data")

    # NOTE(review): both values are forwarded to make_submission, which does
    # not use them; PytorchWorker loads its own ViT checkpoint instead.
    MODEL_PATH = "pytorch_model.bin"
    MODEL_NAME = "swinv2_tiny_window16_256.ms_in1k"

    metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
    # For a local smoke test, build test_metadata by hand instead, e.g.
    # pd.DataFrame({"filename": [...], "observation_id": [...]}).
    make_submission(
        test_metadata=pd.read_csv(metadata_file_path),
        model_path=MODEL_PATH,
        model_name=MODEL_NAME,
    )
| # #%% | |
| # import requests | |
| # image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) | |
| # # %% | |
| # image = VALID_AUG(image=np.array(image))['image'] | |
| # # %% | |
| # model= PytorchWorker() | |
| # # %% | |
| # output = model.predict_image(image.unsqueeze(dim =0 )) | |
| # # %% | |
| # output | |
| # # %% | |
| # import numpy as np | |
| # np.argmax(output) | |
| # %% | |
| # df = pd.DataFrame() | |
| # df["filename"] = ['sample.png'] | |
| # # %% | |
| # make_submission( | |
| # test_metadata=df, | |
| # model_path="MODEL_PATH", | |
| # model_name="MODEL_NAME" | |
| # ) | |
| # %% | |
| # %% | |