Spaces:
Running
Running
| import json | |
| import os | |
| import uuid | |
| from IPython.core.display import display, HTML, Javascript | |
| from bertviz.util import format_special_chars, format_attention, num_layers | |
| print("UP TO DATE") | |
# require.js must be imported for Colab or JupyterLab.
_REQUIRE_JS_TAG = (
    '<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>'
)
# Values accepted for the ``html_action`` argument of ``head_view_mod``.
_HTML_ACTIONS = ('view', 'return', 'gradio')


def _attn_entry(name, attn, left_text, right_text):
    """Build one attention-view record; ``attn`` is converted to nested lists via ``tolist()``."""
    return {
        'name': name,
        'attn': attn.tolist(),
        'left_text': left_text,
        'right_text': right_text
    }


def _load_head_view_js(params):
    """Read ``head_view.js`` next to this module and substitute the JSON-encoded ``params``."""
    location = os.path.realpath(
        os.path.join(os.getcwd(), os.path.dirname(__file__)))
    # Context manager closes the handle; explicit encoding avoids locale-dependent decoding.
    with open(os.path.join(location, 'head_view.js'), encoding='utf-8') as js_file:
        return js_file.read().replace("PYTHON_PARAMS", json.dumps(params))


def head_view_mod(
        attention=None,
        tokens=None,
        sentence_b_start=None,
        prettify_tokens=True,
        layer=None,
        heads=None,
        encoder_attention=None,
        decoder_attention=None,
        cross_attention=None,
        encoder_tokens=None,
        decoder_tokens=None,
        include_layers=None,
        html_action='view',
        patest="something"
):
    """Render head view

        Args:
            For self-attention models:
                attention: list of ``torch.FloatTensor``(one for each layer) of shape
                    ``(batch_size(must be 1), num_heads, sequence_length, sequence_length)``
                tokens: list of tokens
                sentence_b_start: index of first wordpiece in sentence B if input text is sentence pair (optional).
                    Both resulting sentences must be non-empty.
            For encoder-decoder models:
                encoder_attention: list of ``torch.FloatTensor``(one for each layer) of shape
                    ``(batch_size(must be 1), num_heads, encoder_sequence_length, encoder_sequence_length)``
                decoder_attention: list of ``torch.FloatTensor``(one for each layer) of shape
                    ``(batch_size(must be 1), num_heads, decoder_sequence_length, decoder_sequence_length)``
                cross_attention: list of ``torch.FloatTensor``(one for each layer) of shape
                    ``(batch_size(must be 1), num_heads, decoder_sequence_length, encoder_sequence_length)``
                encoder_tokens: list of tokens for encoder input
                decoder_tokens: list of tokens for decoder input
            For all models:
                prettify_tokens: indicates whether to remove special characters in wordpieces, e.g. Ġ
                layer: index (zero-based) of initial selected layer in visualization. Defaults to layer 0.
                heads: Indices (zero-based) of initial selected heads in visualization. Defaults to all heads.
                include_layers: Indices (zero-based) of layers to include in visualization. Defaults to all layers.
                    Note: filtering layers may improve responsiveness of the visualization for long inputs.
                html_action: Specifies the action to be performed with the generated HTML object
                    - 'view' (default): Displays the generated HTML representation as a notebook cell output
                    - 'return' : Returns an HTML object containing the generated view for further processing or
                      custom visualization
                    - 'gradio' : Returns a dict with keys 'html1' (require.js loader as an HTML object),
                      'html2' (container markup as an HTML object) and 'params' (the parameter dict that would
                      be injected into head_view.js); head_view.js itself is not read in this mode
                patest: unused; kept only so existing callers that pass it keep working.

        Returns:
            None for 'view', an ``HTML`` object for 'return', a dict for 'gradio'.

        Raises:
            ValueError: if ``html_action`` is not a supported value, required tokens are missing,
                self-attention and encoder-decoder arguments are mixed, no attention argument is given,
                ``sentence_b_start`` leaves one sentence empty, ``layer`` is not in ``include_layers``,
                or the token count does not match the attention dimensions.
    """
    # Fail fast on a bad mode, before any tensor -> list conversion work is done.
    if html_action not in _HTML_ACTIONS:
        raise ValueError("'html_action' parameter must be 'view', 'return' or 'gradio'")

    attn_data = []
    if attention is not None:
        if tokens is None:
            raise ValueError("'tokens' is required")
        if encoder_attention is not None or decoder_attention is not None or cross_attention is not None \
                or encoder_tokens is not None or decoder_tokens is not None:
            raise ValueError("If you specify 'attention' you may not specify any encoder-decoder arguments. This"
                             " argument is only for self-attention models.")
        if sentence_b_start is not None:
            slice_a = slice(0, sentence_b_start)  # Positions corresponding to sentence A in input
            slice_b = slice(sentence_b_start, len(tokens))  # Position corresponding to sentence B in input
            # An empty side would otherwise surface later as an opaque IndexError while
            # checking the attention dimensions.
            if len(tokens[slice_a]) == 0 or len(tokens[slice_b]) == 0:
                raise ValueError(
                    f"'sentence_b_start' ({sentence_b_start}) must split 'tokens' into two non-empty sentences"
                )
        if include_layers is None:
            include_layers = list(range(num_layers(attention)))
        attention = format_attention(attention, include_layers)
        if sentence_b_start is None:
            attn_data.append(_attn_entry(None, attention, tokens, tokens))
        else:
            attn_data.append(_attn_entry('All', attention, tokens, tokens))
            attn_data.append(_attn_entry(
                'Sentence A -> Sentence A', attention[:, :, slice_a, slice_a], tokens[slice_a], tokens[slice_a]))
            attn_data.append(_attn_entry(
                'Sentence B -> Sentence B', attention[:, :, slice_b, slice_b], tokens[slice_b], tokens[slice_b]))
            attn_data.append(_attn_entry(
                'Sentence A -> Sentence B', attention[:, :, slice_a, slice_b], tokens[slice_a], tokens[slice_b]))
            attn_data.append(_attn_entry(
                'Sentence B -> Sentence A', attention[:, :, slice_b, slice_a], tokens[slice_b], tokens[slice_a]))
    elif encoder_attention is not None or decoder_attention is not None or cross_attention is not None:
        # include_layers defaults from the first attention argument that is present.
        if encoder_attention is not None:
            if encoder_tokens is None:
                raise ValueError("'encoder_tokens' required if 'encoder_attention' is not None")
            if include_layers is None:
                include_layers = list(range(num_layers(encoder_attention)))
            encoder_attention = format_attention(encoder_attention, include_layers)
            attn_data.append(_attn_entry('Encoder', encoder_attention, encoder_tokens, encoder_tokens))
        if decoder_attention is not None:
            if decoder_tokens is None:
                raise ValueError("'decoder_tokens' required if 'decoder_attention' is not None")
            if include_layers is None:
                include_layers = list(range(num_layers(decoder_attention)))
            decoder_attention = format_attention(decoder_attention, include_layers)
            attn_data.append(_attn_entry('Decoder', decoder_attention, decoder_tokens, decoder_tokens))
        if cross_attention is not None:
            if encoder_tokens is None:
                raise ValueError("'encoder_tokens' required if 'cross_attention' is not None")
            if decoder_tokens is None:
                raise ValueError("'decoder_tokens' required if 'cross_attention' is not None")
            if include_layers is None:
                include_layers = list(range(num_layers(cross_attention)))
            cross_attention = format_attention(cross_attention, include_layers)
            attn_data.append(_attn_entry('Cross', cross_attention, decoder_tokens, encoder_tokens))
    else:
        raise ValueError("You must specify at least one attention argument.")

    if layer is not None and layer not in include_layers:
        raise ValueError(f"Layer {layer} is not in include_layers: {include_layers}")

    # Generate unique div id to enable multiple visualizations in one notebook
    vis_id = 'bertviz-%s'%(uuid.uuid4().hex)

    # Compose html
    if len(attn_data) > 1:
        options = '\n'.join(
            f'<option value="{i}">{d["name"]}</option>'
            for i, d in enumerate(attn_data)
        )
        select_html = f'Attention: <select id="filter">{options}</select>'
    else:
        select_html = ""
    vis_html = f"""
        <div id="{vis_id}" style="font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;">
            <span style="user-select:none">
                Layer: <select id="layer"></select>
                {select_html}
            </span>
            <div id='vis'></div>
        </div>
    """

    # Check that each view's attention dimensions agree with its token lists.
    for d in attn_data:
        attn_seq_len_left = len(d['attn'][0][0])
        if attn_seq_len_left != len(d['left_text']):
            raise ValueError(
                f"Attention has {attn_seq_len_left} positions, while number of tokens is {len(d['left_text'])} "
                f"for tokens: {' '.join(d['left_text'])}"
            )
        attn_seq_len_right = len(d['attn'][0][0][0])
        if attn_seq_len_right != len(d['right_text']):
            raise ValueError(
                f"Attention has {attn_seq_len_right} positions, while number of tokens is {len(d['right_text'])} "
                f"for tokens: {' '.join(d['right_text'])}"
            )
        if prettify_tokens:
            d['left_text'] = format_special_chars(d['left_text'])
            d['right_text'] = format_special_chars(d['right_text'])

    # The 'test' entry is kept so the payload given to head_view.js and to 'gradio'
    # callers retains its existing shape.
    params = {
        'attention': attn_data,
        'default_filter': "0",
        'root_div_id': vis_id,
        'layer': layer,
        'heads': heads,
        'include_layers': include_layers,
        'test': 'test'
    }

    if html_action == 'gradio':
        # The caller assembles the page itself, so head_view.js is not read here.
        return {'html1': HTML(_REQUIRE_JS_TAG), 'html2': HTML(vis_html), 'params': params}

    if html_action == 'view':
        display(HTML(_REQUIRE_JS_TAG))
        display(HTML(vis_html))
        display(Javascript(_load_head_view_js(params)))
        return None

    # html_action == 'return': bundle loader, markup and script into a single HTML object.
    script = '\n<script type="text/javascript">\n' + _load_head_view_js(params) + '\n</script>\n'
    return HTML(_REQUIRE_JS_TAG + vis_html + script)