"""
Mistral Tokenizer Wrapper
Provides correct tokenization for Devstral using mistral-common library.
The Tekken tokenizer used by Devstral is incompatible with HuggingFace's
standard tokenization approach. This wrapper uses mistral-common to
produce correct token sequences for the model.
"""
import logging
from typing import List, Optional
logger = logging.getLogger(__name__)
class MistralTokenizerWrapper:
    """
    Wrapper around mistral-common's MistralTokenizer for Devstral.

    Uses encode_chat_completion() to produce the token IDs that the model
    actually expects, rather than HF's text-based approach, which
    produces corrupted tokens for Tekken-based models.
    """

    def __init__(self, model_name: str):
        """
        Initialize the Mistral tokenizer from the HuggingFace hub.

        Args:
            model_name: HuggingFace model path (e.g., "mistralai/Devstral-Small-2507")
        """
        try:
            from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

            self.tokenizer = MistralTokenizer.from_hf_hub(model_name)
            self._available = True
            logger.info(f"Loaded MistralTokenizer for {model_name}")
        except ImportError as e:
            logger.warning(f"mistral-common not available: {e}")
            self._available = False
            self.tokenizer = None
        except Exception as e:
            logger.error(f"Failed to load MistralTokenizer: {e}")
            self._available = False
            self.tokenizer = None

    @property
    def is_available(self) -> bool:
        """Check if the tokenizer was loaded successfully."""
        return self._available

    def encode_chat(
        self,
        system_prompt: str,
        user_prompt: str,
    ) -> List[int]:
        """
        Encode chat messages to token IDs using mistral-common.

        This produces the correct token sequence for Devstral, including
        proper handling of control tokens like [INST] and [/INST].

        Args:
            system_prompt: System message content
            user_prompt: User message content (e.g., "def quicksort(arr):")

        Returns:
            List of token IDs ready for model input
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")

        from mistral_common.protocol.instruct.messages import (
            SystemMessage,
            UserMessage,
        )
        from mistral_common.protocol.instruct.request import ChatCompletionRequest

        # Build the messages list; the system prompt is optional.
        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(UserMessage(content=user_prompt))

        # Encode using mistral-common's chat completion encoding.
        request = ChatCompletionRequest(messages=messages)
        tokenized = self.tokenizer.encode_chat_completion(request)
        logger.info(f"Encoded chat: {len(tokenized.tokens)} tokens")
        return tokenized.tokens
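
    # A minimal usage sketch for encode_chat (assumes mistral-common is
    # installed and the model name from the docstring example is reachable
    # on the HF hub; adjust to your deployment):
    #
    #   wrapper = MistralTokenizerWrapper("mistralai/Devstral-Small-2507")
    #   ids = wrapper.encode_chat(
    #       system_prompt="You are a code completion engine.",
    #       user_prompt="def quicksort(arr):",
    #   )
    #   # `ids` already includes the [INST]/[/INST] control tokens and can
    #   # be passed straight to the model as input_ids.
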
    def decode(self, token_ids: List[int]) -> str:
        """
        Decode token IDs back to text.

        Args:
            token_ids: List of token IDs to decode

        Returns:
            Decoded text string
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")
        return self.tokenizer.decode(token_ids)

    def decode_token(self, token_id: int) -> str:
        """
        Decode a single token ID to text.

        Args:
            token_id: Single token ID to decode

        Returns:
            Decoded text for this token
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")
        return self.tokenizer.decode([token_id])
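
    # Sketch: decoding generated IDs back to text one token at a time
    # (useful for streaming output). `generated_ids` is a hypothetical
    # list of token IDs produced by the model:
    #
    #   for token_id in generated_ids:
    #       print(wrapper.decode_token(token_id), end="", flush=True)

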
def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
    """
    Factory function to create a MistralTokenizerWrapper.

    Returns None if mistral-common is not available or loading fails.

    Args:
        model_name: HuggingFace model path

    Returns:
        MistralTokenizerWrapper instance or None
    """
    wrapper = MistralTokenizerWrapper(model_name)
    if wrapper.is_available:
        return wrapper
    return None
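

if __name__ == "__main__":
    # Smoke test: a minimal sketch, assuming mistral-common is installed
    # (`pip install mistral-common`) and the HF hub is reachable. The model
    # name reuses the docstring example; substitute your own as needed.
    logging.basicConfig(level=logging.INFO)
    tok = create_mistral_tokenizer("mistralai/Devstral-Small-2507")
    if tok is None:
        print("mistral-common unavailable or model failed to load")
    else:
        ids = tok.encode_chat(
            system_prompt="You are a helpful coding assistant.",
            user_prompt="def quicksort(arr):",
        )
        # Round-trip the IDs to confirm encode/decode agree.
        print(f"{len(ids)} tokens; round-trip: {tok.decode(ids)[:80]!r}")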