orgoflu committed on
Commit
ab4dfe6
·
verified ·
1 Parent(s): 5b980de
Files changed (1) hide show
  1. app.py +13 -15
app.py CHANGED
@@ -1,26 +1,25 @@
1
  import re
2
- import math
3
  import gradio as gr
4
  import torch
5
- from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration
6
 
7
- # βœ… 곡개 KoBART λͺ¨λΈ
8
- MODEL_NAME = "gogamza/kobart-base-v2"
9
 
10
  tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)
11
- model = BartForConditionalGeneration.from_pretrained(MODEL_NAME)
12
 
13
  # CPU 동적 μ–‘μžν™” 적용
14
  try:
15
  model = torch.quantization.quantize_dynamic(
16
  model, {torch.nn.Linear}, dtype=torch.qint8
17
  )
18
- except Exception:
19
  pass
20
 
21
  model.eval()
22
 
23
- # ===== μœ ν‹Έ ν•¨μˆ˜ =====
24
  def normalize_text(text: str) -> str:
25
  return re.sub(r"\s+", " ", text).strip()
26
 
@@ -55,17 +54,16 @@ def chunk_by_tokens(sentences, max_tokens=900):
55
  chunks.append(" ".join(cur))
56
  return chunks
57
 
58
- # ===== μš”μ•½ ν•¨μˆ˜ =====
59
  def summarize_raw(text: str, min_len: int, max_len: int) -> str:
60
- inputs = tokenizer([text], max_length=1024, truncation=True, return_tensors="pt")
61
  with torch.no_grad():
62
  summary_ids = model.generate(
63
- inputs["input_ids"],
64
  num_beams=4,
65
  min_length=min_len,
66
  max_length=max_len,
67
- early_stopping=True,
68
- no_repeat_ngram_size=3
69
  )
70
  return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
71
 
@@ -77,8 +75,8 @@ def apply_style_prompt(text: str, mode: str, final: bool=False) -> str:
77
  else:
78
  inst = "λ‹€μŒ ν•œκ΅­μ–΄ ν…μŠ€νŠΈλ₯Ό bullet ν˜•νƒœλ‘œ ν•΅μ‹¬λ§Œ μš”μ•½ν•˜μ„Έμš”."
79
  if final:
80
- inst += " 이 μš”μ•½μ€ μ΅œμ’…λ³Έμž…λ‹ˆλ‹€."
81
- return f"{inst}\n\n[ν…μŠ€νŠΈ]\n{text}"
82
 
83
  def postprocess(summary: str, mode: str) -> str:
84
  s = summary.strip()
@@ -124,7 +122,7 @@ def ui_summarize(text, target_len, style):
124
  return summarize_long(text, int(target_len), mode)
125
 
126
  with gr.Blocks() as demo:
127
- gr.Markdown("## πŸ“ KoBART ν•œκ΅­μ–΄ μš”μ•½κΈ° (곡개 λͺ¨λΈ gogamza/kobart-base-v2)")
128
  with gr.Row():
129
  with gr.Column():
130
  input_text = gr.Textbox(label="원문 μž…λ ₯", lines=16)
 
1
  import re
 
2
  import gradio as gr
3
  import torch
4
+ from transformers import PreTrainedTokenizerFast, T5ForConditionalGeneration
5
 
6
+ # βœ… KoT5 μš”μ•½ λͺ¨λΈ
7
+ MODEL_NAME = "psyche/KoT5-summarization"
8
 
9
  tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)
10
+ model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
11
 
12
  # CPU 동적 μ–‘μžν™” 적용
13
  try:
14
  model = torch.quantization.quantize_dynamic(
15
  model, {torch.nn.Linear}, dtype=torch.qint8
16
  )
17
+ except:
18
  pass
19
 
20
  model.eval()
21
 
22
+ # ===== μœ ν‹Έ =====
23
  def normalize_text(text: str) -> str:
24
  return re.sub(r"\s+", " ", text).strip()
25
 
 
54
  chunks.append(" ".join(cur))
55
  return chunks
56
 
57
+ # ===== μš”μ•½ =====
58
  def summarize_raw(text: str, min_len: int, max_len: int) -> str:
59
+ input_ids = tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=1024)
60
  with torch.no_grad():
61
  summary_ids = model.generate(
62
+ input_ids,
63
  num_beams=4,
64
  min_length=min_len,
65
  max_length=max_len,
66
+ early_stopping=True
 
67
  )
68
  return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
69
 
 
75
  else:
76
  inst = "λ‹€μŒ ν•œκ΅­μ–΄ ν…μŠ€νŠΈλ₯Ό bullet ν˜•νƒœλ‘œ ν•΅μ‹¬λ§Œ μš”μ•½ν•˜μ„Έμš”."
77
  if final:
78
+ inst += " μ›λž˜ μˆœμ„œλ₯Ό μœ μ§€ν•˜λ©° λ¬Έμž₯ 연결을 μžμ—°μŠ€λŸ½κ²Œ ν•˜μ„Έμš”."
79
+ return f"{inst}\n\n{text}"
80
 
81
  def postprocess(summary: str, mode: str) -> str:
82
  s = summary.strip()
 
122
  return summarize_long(text, int(target_len), mode)
123
 
124
  with gr.Blocks() as demo:
125
+ gr.Markdown("## πŸ“ KoT5 ν•œκ΅­μ–΄ μš”μ•½κΈ° (κΈ΄ λ¬Έμ„œ μžλ™ λΆ„ν•  + μˆœμ„œ 보쑴)")
126
  with gr.Row():
127
  with gr.Column():
128
  input_text = gr.Textbox(label="원문 μž…λ ₯", lines=16)