# api/backend/mistral_tokenizer.py
"""
Mistral Tokenizer Wrapper
Provides correct tokenization for Devstral using mistral-common library.
The Tekken tokenizer used by Devstral is incompatible with HuggingFace's
standard tokenization approach. This wrapper uses mistral-common to
produce correct token sequences for the model.
"""
import logging
from typing import List, Optional

logger = logging.getLogger(__name__)

class MistralTokenizerWrapper:
    """
    Wrapper around mistral-common's MistralTokenizer for Devstral.

    Uses encode_chat_completion() to produce the token IDs the model
    actually expects, rather than HF's text-based approach, which
    produces corrupted tokens for Tekken-based models.
    """

    def __init__(self, model_name: str):
        """
        Initialize the Mistral tokenizer from the HuggingFace hub.

        Args:
            model_name: HuggingFace model path (e.g., "mistralai/Devstral-Small-2507")
        """
        try:
            from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

            self.tokenizer = MistralTokenizer.from_hf_hub(model_name)
            self._available = True
            logger.info(f"Loaded MistralTokenizer for {model_name}")
        except ImportError as e:
            logger.warning(f"mistral-common not available: {e}")
            self._available = False
            self.tokenizer = None
        except Exception as e:
            logger.error(f"Failed to load MistralTokenizer: {e}")
            self._available = False
            self.tokenizer = None

    @property
    def is_available(self) -> bool:
        """Check if the tokenizer was loaded successfully."""
        return self._available

    def encode_chat(
        self,
        system_prompt: str,
        user_prompt: str
    ) -> List[int]:
        """
        Encode chat messages to token IDs using mistral-common.

        This produces the correct token sequence for Devstral, including
        proper handling of control tokens like [INST] and [/INST].

        Args:
            system_prompt: System message content
            user_prompt: User message content (e.g., "def quicksort(arr):")

        Returns:
            List of token IDs ready for model input
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")

        from mistral_common.protocol.instruct.messages import (
            SystemMessage, UserMessage
        )
        from mistral_common.protocol.instruct.request import ChatCompletionRequest

        # Build the messages list; the system message is optional
        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(UserMessage(content=user_prompt))

        # Encode using mistral-common's chat completion encoding
        request = ChatCompletionRequest(messages=messages)
        tokenized = self.tokenizer.encode_chat_completion(request)
        logger.info(f"Encoded chat: {len(tokenized.tokens)} tokens")
        return tokenized.tokens

    def decode(self, token_ids: List[int]) -> str:
        """
        Decode token IDs back to text.

        Args:
            token_ids: List of token IDs to decode

        Returns:
            Decoded text string
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")
        return self.tokenizer.decode(token_ids)

    def decode_token(self, token_id: int) -> str:
        """
        Decode a single token ID to text.

        Args:
            token_id: Single token ID to decode

        Returns:
            Decoded text for this token
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")
        return self.tokenizer.decode([token_id])

def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
    """
    Factory function to create a MistralTokenizerWrapper.

    Returns None if mistral-common is not available or loading fails.

    Args:
        model_name: HuggingFace model path

    Returns:
        MistralTokenizerWrapper instance or None
    """
    wrapper = MistralTokenizerWrapper(model_name)
    if wrapper.is_available:
        return wrapper
    return None
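
# Usage sketch (illustrative only, not part of the original module): shows the
# factory plus an encode/decode round trip. The model name is the example from
# the __init__ docstring; loading it downloads tokenizer files from the hub,
# so it needs network access and possibly HuggingFace authentication.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    tokenizer = create_mistral_tokenizer("mistralai/Devstral-Small-2507")
    if tokenizer is None:
        print("mistral-common unavailable or tokenizer failed to load")
    else:
        tokens = tokenizer.encode_chat(
            system_prompt="You are a helpful coding assistant.",
            user_prompt="def quicksort(arr):",
        )
        print(f"{len(tokens)} tokens; first 10: {tokens[:10]}")
        print(f"Round trip: {tokenizer.decode(tokens)!r}")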