janmariakowalski committed
Commit dde831c · verified · 1 Parent(s): b6fbae1

Update app.py

Files changed (1): app.py (+49 -2)
app.py CHANGED
@@ -11,6 +11,7 @@ from typing import Dict, Tuple, Any
 import torch
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import numpy as np
+from huggingface_hub import HfApi
 
 try:
     from peft import PeftModel
@@ -20,9 +21,11 @@ except ImportError:
 
 # --- Configuration ---
 # Model path is set to sojka
-MODEL_PATH = os.getenv("MODEL_PATH", "AndromedaPL/sojka")
+MODEL_PATH = os.getenv("MODEL_PATH", "speakleash/sojka3")
 TOKENIZER_PATH = os.getenv("TOKENIZER_PATH", "sdadas/mmlw-roberta-base")
 
+REPO_ID = "speakleash/sojka-logs"
+
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 LABELS = ["self-harm", "hate", "vulgar", "sex", "crime"]
 MAX_SEQ_LENGTH = 512
@@ -43,6 +46,34 @@ THRESHOLDS = {
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
+# HfApi instance; hub logging is disabled when no token is available
+if HF_TOKEN:
+    api = HfApi()
+else:
+    api = None
+    logger.warning("HF_TOKEN environment variable not set. Logging to Hugging Face Hub will be disabled.")
+
+def log_prediction(log_data: dict):
+    if not api:
+        return
+
+    day = datetime.now().strftime("%Y-%m-%d")
+    timestamp = log_data.get('timestamp', datetime.now().timestamp())
+
+    try:
+        api.upload_file(
+            path_or_fileobj=json.dumps(log_data, indent=2, ensure_ascii=False).encode('utf-8'),
+            path_in_repo=f"predictions/{day}/{timestamp}.json",
+            repo_id=REPO_ID,
+            repo_type="dataset",
+            commit_message="log prediction",
+            token=HF_TOKEN,
+            run_as_future=True
+        )
+    except Exception as e:
+        logger.error(f"Failed to log prediction to hub: {e}")
+
+
 def load_model_and_tokenizer(model_path: str, tokenizer_path: str, device: str) -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
     """Load the trained model and tokenizer"""
     logger.info(f"Loading tokenizer from {tokenizer_path}")
@@ -136,12 +167,28 @@ def gradio_predict(text: str) -> Tuple[str, Dict[str, float]]:
         label: score for label, score in predictions.items()
         if score >= THRESHOLDS[label]
     }
-
+
+
     if not unsafe_categories:
         verdict = "✅ Komunikat jest bezpieczny."
+        verdict_label = "SAFE"
     else:
         highest_unsafe_category = max(unsafe_categories, key=unsafe_categories.get)
        verdict = f"⚠️ Wykryto potencjalnie szkodliwe treści:\n {highest_unsafe_category.upper()}"
+        verdict_label = "UNSAFE"
+
+    log_data = {
+        'text': text,
+        'predictions': predictions,
+        'thresholds': THRESHOLDS,
+        'sojka_verdict': verdict_label,
+        'herbert_result': {},
+        'timestamp': datetime.now().timestamp(),
+        'model_path': MODEL_PATH,
+        'herbert_enabled': False
+    }
+
+    log_prediction(log_data)
 
     return verdict, predictions
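
Note: the added code relies on HF_TOKEN, json, and datetime, none of which this commit introduces, so they are presumably already defined earlier in app.py. A minimal sketch of the assumed pre-existing definitions (the os.getenv call is inferred from the warning message, not shown in the diff):

import json
import os
from datetime import datetime

HF_TOKEN = os.getenv("HF_TOKEN")  # assumption: token read from the environment, matching the warning text

(The Polish verdict strings read "The message is safe." and "Potentially harmful content detected:".)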
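
With run_as_future=True, HfApi.upload_file schedules the commit on a background thread and immediately returns a concurrent.futures.Future, so gradio_predict never blocks on the network round-trip. A consequence is that the try/except around the call only catches errors raised while scheduling the upload; a failure during the upload itself is stored in the Future. A sketch of how the caller could surface those failures (the _log_upload_errors helper is hypothetical, not part of this commit):

def _log_upload_errors(future):
    # Called when the background upload finishes; Future.exception() is None on success.
    exc = future.exception()
    if exc is not None:
        logger.error(f"Background upload failed: {exc}")

future = api.upload_file(
    path_or_fileobj=b"{}",  # placeholder payload for illustration
    path_in_repo="predictions/example.json",
    repo_id=REPO_ID,
    repo_type="dataset",
    token=HF_TOKEN,
    run_as_future=True,
)
future.add_done_callback(_log_upload_errors)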