Commit
Β·
42dc9ff
1
Parent(s):
4f97a73
Update
Browse files
app.py
CHANGED
|
@@ -56,21 +56,24 @@ class Instance:
|
|
| 56 |
self.model_type = 'base'
|
| 57 |
self.loaded_model_list = {}
|
| 58 |
self.counter = Counter()
|
| 59 |
-
self.
|
| 60 |
self.loaded_model_list['base'], self.common_instances = ckpt_load_helper(
|
| 61 |
'gligen-generation-text-box',
|
| 62 |
is_inpaint=False, is_style=False, common_instances=None
|
| 63 |
)
|
| 64 |
self.capacity = capacity
|
| 65 |
|
| 66 |
-
def _log(self, batch_size, instruction, phrase_list):
|
|
|
|
|
|
|
| 67 |
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 68 |
-
print(
|
|
|
|
|
|
|
| 69 |
|
| 70 |
def get_model(self, model_type, batch_size, instruction, phrase_list):
|
| 71 |
if model_type in self.loaded_model_list:
|
| 72 |
-
self.
|
| 73 |
-
self._log(batch_size, instruction, phrase_list)
|
| 74 |
return self.loaded_model_list[model_type]
|
| 75 |
|
| 76 |
if self.capacity == len(self.loaded_model_list):
|
|
@@ -80,9 +83,8 @@ class Instance:
|
|
| 80 |
gc.collect()
|
| 81 |
torch.cuda.empty_cache()
|
| 82 |
|
| 83 |
-
self.counter[model_type] = 1
|
| 84 |
self.loaded_model_list[model_type] = self._get_model(model_type)
|
| 85 |
-
self._log(batch_size, instruction, phrase_list)
|
| 86 |
return self.loaded_model_list[model_type]
|
| 87 |
|
| 88 |
def _get_model(self, model_type):
|
|
@@ -299,7 +301,8 @@ def generate(task, language_instruction, grounding_texts, sketch_pad,
|
|
| 299 |
if len(boxes) != len(grounding_texts):
|
| 300 |
if len(boxes) < len(grounding_texts):
|
| 301 |
raise ValueError("""The number of boxes should be equal to the number of grounding objects.
|
| 302 |
-
Number of boxes drawn: {}, number of grounding tokens: {}
|
|
|
|
| 303 |
grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
|
| 304 |
|
| 305 |
boxes = (np.asarray(boxes) / 512).tolist()
|
|
|
|
| 56 |
self.model_type = 'base'
|
| 57 |
self.loaded_model_list = {}
|
| 58 |
self.counter = Counter()
|
| 59 |
+
self.global_counter = Counter()
|
| 60 |
self.loaded_model_list['base'], self.common_instances = ckpt_load_helper(
|
| 61 |
'gligen-generation-text-box',
|
| 62 |
is_inpaint=False, is_style=False, common_instances=None
|
| 63 |
)
|
| 64 |
self.capacity = capacity
|
| 65 |
|
| 66 |
+
def _log(self, model_type, batch_size, instruction, phrase_list):
|
| 67 |
+
self.counter[model_type] += 1
|
| 68 |
+
self.global_counter[model_type] += 1
|
| 69 |
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 70 |
+
print('[{}] Current: {}, All: {}. Samples: {}, prompt: {}, phrases: {}'.format(
|
| 71 |
+
current_time, dict(self.counter), dict(self.global_counter), batch_size, instruction, phrase_list
|
| 72 |
+
))
|
| 73 |
|
| 74 |
def get_model(self, model_type, batch_size, instruction, phrase_list):
|
| 75 |
if model_type in self.loaded_model_list:
|
| 76 |
+
self._log(model_type, batch_size, instruction, phrase_list)
|
|
|
|
| 77 |
return self.loaded_model_list[model_type]
|
| 78 |
|
| 79 |
if self.capacity == len(self.loaded_model_list):
|
|
|
|
| 83 |
gc.collect()
|
| 84 |
torch.cuda.empty_cache()
|
| 85 |
|
|
|
|
| 86 |
self.loaded_model_list[model_type] = self._get_model(model_type)
|
| 87 |
+
self._log(model_type, batch_size, instruction, phrase_list)
|
| 88 |
return self.loaded_model_list[model_type]
|
| 89 |
|
| 90 |
def _get_model(self, model_type):
|
|
|
|
| 301 |
if len(boxes) != len(grounding_texts):
|
| 302 |
if len(boxes) < len(grounding_texts):
|
| 303 |
raise ValueError("""The number of boxes should be equal to the number of grounding objects.
|
| 304 |
+
Number of boxes drawn: {}, number of grounding tokens: {}.
|
| 305 |
+
Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
|
| 306 |
grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
|
| 307 |
|
| 308 |
boxes = (np.asarray(boxes) / 512).tolist()
|