j-silv commited on
Commit
b4ab4e0
·
1 Parent(s): 92bc2bf

Clean up style

Browse files
Files changed (2) hide show
  1. autohdl/llm.py +3 -5
  2. streamlit_app.py +26 -32
autohdl/llm.py CHANGED
@@ -32,12 +32,10 @@ class LLM:
32
  ]
33
 
34
  input_text=self.hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
35
- # print(input_text)
36
 
37
- inputs = self.hf_tokenizer.encode(input_text, return_tensors="pt").to(self.device)
38
 
39
- outputs = self.hf_model.generate(inputs)
40
 
41
- # print(self.hf_tokenizer.decode(outputs[0]))
42
- return self.hf_tokenizer.decode(outputs[0])
43
 
 
32
  ]
33
 
34
  input_text=self.hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
35
 
36
+ inputs = self.hf_tokenizer(input_text, return_tensors="pt").to(self.device)
37
 
38
+ outputs = self.hf_model.generate(**inputs, max_new_tokens=50)
39
 
40
+ return self.hf_tokenizer.decode(outputs[0, inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
 
41
 
streamlit_app.py CHANGED
@@ -2,28 +2,19 @@ import streamlit as st
2
  from autohdl.data import data
3
  from autohdl.llm import system_prompt, LLM
4
  import random
5
- from code_editor import code_editor
6
 
7
  """
8
  # AutoHDL
9
  ### AI agent which generates Verilog code
10
  """
11
- def code_input(title, message, edit_disabled=False):
12
  st.text(title)
 
 
 
 
 
13
 
14
- ace_props = {"style": {"borderRadius": "0px 0px 8px 8px"}}
15
-
16
- code_editor(message,
17
- height = 10,
18
- lang="text",
19
- theme="default",
20
- shortcuts="vscode",
21
- focus=False,
22
- props=ace_props,
23
- response_mode="debounce",
24
- options={"wrap": True})
25
-
26
-
27
  def random_sample_btn(stop):
28
  """Generate a new sample"""
29
 
@@ -46,6 +37,7 @@ def load_model():
46
  def server():
47
  ds = data(small_dataset=True)
48
 
 
49
  model = load_model()
50
 
51
  if 'idx' not in st.session_state:
@@ -57,29 +49,31 @@ def server():
57
  idx = st.session_state['idx']
58
 
59
  summary = "high_level_global_summary"
 
60
 
61
- row1 = st.container(border=True, height=400)
62
- row2 = st.container(border=True, height=400)
63
 
64
- with row1:
65
- col1, col2 = st.columns(2)
 
66
 
67
- with col1:
68
- code_input("System prompt", system_prompt)
69
 
70
- with col2:
71
- description_prompt = ds['description'][idx][summary]
72
- code_input("User prompt", description_prompt)
73
-
74
- with row2:
75
- col1, col2 = st.columns(2)
76
 
77
- with col1:
78
- code_input("Expected response", ds['code'][idx])
79
-
 
 
 
 
80
 
81
- with col2:
82
- code_input("LLM response", st.session_state['response'], edit_disabled=True)
83
 
84
  with st.container(horizontal=True):
85
  st.button("Random sample", on_click=random_sample_btn, args=[ds.num_rows])
 
2
  from autohdl.data import data
3
  from autohdl.llm import system_prompt, LLM
4
  import random
 
5
 
6
  """
7
  # AutoHDL
8
  ### AI agent which generates Verilog code
9
  """
10
+ def text_cell(title, message, edit_disabled=False, height="content"):
11
  st.text(title)
12
+ st.text_area(title,
13
+ message.strip(),
14
+ label_visibility="collapsed",
15
+ height=height,
16
+ disabled=edit_disabled)
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def random_sample_btn(stop):
19
  """Generate a new sample"""
20
 
 
37
  def server():
38
  ds = data(small_dataset=True)
39
 
40
+
41
  model = load_model()
42
 
43
  if 'idx' not in st.session_state:
 
49
  idx = st.session_state['idx']
50
 
51
  summary = "high_level_global_summary"
52
+ description_prompt = ds['description'][idx][summary]
53
 
54
+ CELL_HEIGHT = 300
 
55
 
56
+ # Prompt cells
57
+ with st.container(border=True, height=CELL_HEIGHT+100):
58
+ left, right = st.columns(2)
59
 
60
+ with left:
61
+ text_cell("System prompt", system_prompt, height=CELL_HEIGHT)
62
 
63
+ with right:
64
+ text_cell("User prompt", description_prompt, height=CELL_HEIGHT)
65
+
66
+ # Response cells
67
+ with st.container(border=True, height=CELL_HEIGHT+100):
 
68
 
69
+ left, right = st.columns(2)
70
+
71
+ with left:
72
+ text_cell("Expected response", ds['code'][idx], height=CELL_HEIGHT)
73
+
74
+ with right:
75
+ text_cell("LLM response", st.session_state['response'], edit_disabled=True, height=CELL_HEIGHT)
76
 
 
 
77
 
78
  with st.container(horizontal=True):
79
  st.button("Random sample", on_click=random_sample_btn, args=[ds.num_rows])