import transformers
from tokenizers.implementations import SentencePieceUnigramTokenizer, BaseTokenizer
from tokenizers.processors import TemplateProcessing
from tokenizers.models import Unigram, BPE
from tokenizers import decoders
from tokenizers import Tokenizer, Regex
from tokenizers.normalizers import (
    StripAccents,
    NFKD,
    Lowercase,
    Sequence,
    BertNormalizer,
    Precompiled,
    Replace,
)
from tokenizers.pre_tokenizers import (
    Digits,
    WhitespaceSplit,
    Metaspace,
    Sequence as PSequence,
)
import json
import unicodedata
import sys
import os
import datetime
import argparse

sys.path.append(".")

from spm_parity_check import check_details
from sentencepiece_extractor import SentencePieceExtractor


def check_number_comma(piece: str) -> bool:
    return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()
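
# Illustration only (hypothetical inputs): the flagged pieces are those ending
# in a digit followed by a comma; the Albert/XLNet converters below penalize
# their score when this function returns False.
#   check_number_comma("1,")      # -> False (digit + trailing comma)
#   check_number_comma("hello,")  # -> True
#   check_number_comma(",")       # -> True  (too short to match)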


def get_proto(filename: str):
    try:
        import sys

        sys.path.append(".")

        import sentencepiece_model_pb2 as model
    except Exception:
        raise Exception(
            "You don't seem to have the required protobuf file. In order to use this function you need to run "
            "`pip install protobuf` and "
            "`wget https://raw.githubusercontent.com/google/sentencepiece/master/python/sentencepiece_model_pb2.py` "
            "so we can read the internals of your spm_file. `pip install sentencepiece` is not required."
        )

    m = model.ModelProto()
    m.ParseFromString(open(filename, "rb").read())
    return m
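
# The parsed ModelProto carries everything the converters below rely on:
# `proto.pieces` holds the (piece, score) vocabulary, `proto.trainer_spec`
# exposes unk_id/unk_piece and the model_type, and
# `proto.normalizer_spec.precompiled_charsmap` is the serialized normalization
# table fed to `Precompiled`.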


class Converter:
    def __init__(self, original_tokenizer):
        self.original_tokenizer = original_tokenizer

    def converted(self) -> Tokenizer:
        raise NotImplementedError()


class SpmConverter(Converter):
    def __init__(self, *args):
        super().__init__(*args)
        self.proto = get_proto(self.original_tokenizer.vocab_file)

    def vocab(self, proto):
        return [(piece.piece, piece.score) for piece in proto.pieces]

    def unk_id(self, proto):
        return proto.trainer_spec.unk_id

    def tokenizer(self, proto):
        model_type = proto.trainer_spec.model_type
        vocab = self.vocab(proto)
        unk_id = self.unk_id(proto)
        if model_type == 1:
            # model_type 1 is sentencepiece's UNIGRAM
            tokenizer = Tokenizer(Unigram(vocab, unk_id))
        elif model_type == 2:
            # model_type 2 is sentencepiece's BPE: the merges are not stored in
            # the proto, so they are re-extracted from the vocab file
            vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
            tokenizer = Tokenizer(BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True))
        else:
            raise Exception(
                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
            )
        return tokenizer

    def normalizer(self, proto):
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        # Collapse runs of spaces, matching sentencepiece's default whitespace handling
        return Sequence([Precompiled(precompiled_charsmap), Replace(Regex(" {2,}"), " ")])

    def post_processor(self, tokenizer):
        return None

    def converted(self):
        tokenizer = self.tokenizer(self.proto)

        # Assemble the tokenizer pipeline
        tokenizer.normalizer = self.normalizer(self.proto)

        replacement = "▁"
        prepend_scheme = "always"
        tokenizer.pre_tokenizer = Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
        tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

        post_processor = self.post_processor(tokenizer)
        if post_processor:
            tokenizer.post_processor = post_processor

        # TODO: what parameters should we give?
        parameters = {}

        return BaseTokenizer(tokenizer, parameters)
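
# Minimal usage sketch (assuming the `albert-base-v2` checkpoint is
# available): any concrete SpmConverter subclass turns a slow tokenizer into a
# fast one in two calls.
#
#   slow = transformers.AutoTokenizer.from_pretrained("albert-base-v2")
#   fast = AlbertConverter(slow).converted()
#   print(fast.encode("Hello world").ids)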


class AlbertConverter(SpmConverter):
    def vocab(self, proto):
        return [
            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto):
        normalizers = [Replace("``", '"'), Replace("''", '"')]
        if not self.original_tokenizer.keep_accents:
            normalizers.append(NFKD())
            normalizers.append(StripAccents())
        if self.original_tokenizer.do_lower_case:
            normalizers.append(Lowercase())
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        normalizers.append(Precompiled(precompiled_charsmap))
        normalizers.append(Replace(Regex(" {2,}"), " "))
        return Sequence(normalizers)

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["[CLS]", "$0", "[SEP]"],
            seq_b=["$1", "[SEP]"],
            special_tokens=[
                ("[CLS]", tokenizer.get_vocab()["[CLS]"]),
                ("[SEP]", tokenizer.get_vocab()["[SEP]"]),
            ],
        )


class CamembertConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>NOTUSED", 0.0),
            ("<pad>", 0.0),
            ("</s>NOTUSED", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces]
        return vocab

    def unk_id(self, proto):
        # See vocab above: "<unk>" was prepended at position 3
        return 3

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["<s>", "$0", "</s>"],
            seq_b=["$1", "</s>"],
            special_tokens=[
                ("<s>", tokenizer.get_vocab()["<s>"]),
                ("</s>", tokenizer.get_vocab()["</s>"]),
            ],
        )


class MBartConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        vocab += [
            ("ar_AR", 0.0),
            ("cs_CZ", 0.0),
            ("de_DE", 0.0),
            ("en_XX", 0.0),
            ("es_XX", 0.0),
            ("et_EE", 0.0),
            ("fi_FI", 0.0),
            ("fr_XX", 0.0),
            ("gu_IN", 0.0),
            ("hi_IN", 0.0),
            ("it_IT", 0.0),
            ("ja_XX", 0.0),
            ("kk_KZ", 0.0),
            ("ko_KR", 0.0),
            ("lt_LT", 0.0),
            ("lv_LV", 0.0),
            ("my_MM", 0.0),
            ("ne_NP", 0.0),
            ("nl_XX", 0.0),
            ("ro_RO", 0.0),
            ("ru_RU", 0.0),
            ("si_LK", 0.0),
            ("tr_TR", 0.0),
            ("vi_VN", 0.0),
            ("zh_CN", 0.0),
        ]
        return vocab

    def unk_id(self, proto):
        return 3

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["$0", "</s>", "en_XX"],
            seq_b=["$1", "</s>"],
            special_tokens=[
                ("en_XX", tokenizer.get_vocab()["en_XX"]),
                ("</s>", tokenizer.get_vocab()["</s>"]),
            ],
        )


class XLMRobertaConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        unk_id = 3
        return unk_id

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["<s>", "$0", "</s>"],
            seq_b=["$1", "</s>"],
            special_tokens=[
                ("<s>", tokenizer.get_vocab()["<s>"]),
                ("</s>", tokenizer.get_vocab()["</s>"]),
            ],
        )


class XLNetConverter(SpmConverter):
    def vocab(self, proto):
        return [
            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto):
        normalizers = [Replace("``", '"'), Replace("''", '"')]
        if not self.original_tokenizer.keep_accents:
            normalizers.append(NFKD())
            normalizers.append(StripAccents())
        if self.original_tokenizer.do_lower_case:
            normalizers.append(Lowercase())
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        normalizers.append(Precompiled(precompiled_charsmap))
        normalizers.append(Replace(Regex(" {2,}"), " "))
        return Sequence(normalizers)

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["$0", "<sep>", "<cls>"],
            seq_b=["$1", "<sep>"],
            special_tokens=[
                ("<sep>", tokenizer.get_vocab()["<sep>"]),
                ("<cls>", tokenizer.get_vocab()["<cls>"]),
            ],
        )


class ReformerConverter(SpmConverter):
    pass


class PegasusConverter(SpmConverter):
    offset = 103

    def vocab(self, proto):
        vocab = [
            (self.original_tokenizer.pad_token, 0),
            (self.original_tokenizer.eos_token, 0),
        ]
        vocab += [(f"unk_{i}", -100) for i in range(2, 2 + self.offset)]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
        return vocab

    def unk_id(self, proto):
        return proto.trainer_spec.unk_id + self.offset

    def post_processor(self, tokenizer):
        eos = self.original_tokenizer.eos_token
        return TemplateProcessing(
            seq_a=["$0", eos],
            seq_b=["$1", eos],
            special_tokens=[(eos, tokenizer.get_vocab()[eos])],
        )


class T5Converter(SpmConverter):
    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["$0", "</s>"],
            seq_b=["$1", "</s>"],
            special_tokens=[("</s>", tokenizer.get_vocab()["</s>"])],
        )


CONVERTERS = {
    "AlbertTokenizer": AlbertConverter,
    "CamembertTokenizer": CamembertConverter,
    "XLMRobertaTokenizer": XLMRobertaConverter,
    "MBartTokenizer": MBartConverter,
    "XLNetTokenizer": XLNetConverter,
    "ReformerTokenizer": ReformerConverter,
    "PegasusTokenizer": PegasusConverter,
    "T5Tokenizer": T5Converter,
}
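
# Dispatch sketch, mirroring what `check` does below: the slow tokenizer's
# class name selects the converter.
#
#   slow = transformers.AutoTokenizer.from_pretrained("t5-small")
#   fast = CONVERTERS[slow.__class__.__name__](slow).converted()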


def check(pretrained, filename):
    transformer_tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
    converter_class = CONVERTERS[transformer_tokenizer.__class__.__name__]
    tokenizer = converter_class(transformer_tokenizer).converted()

    now = datetime.datetime.now
    trans_total_time = datetime.timedelta(seconds=0)
    tok_total_time = datetime.timedelta(seconds=0)
    with open(filename, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()

            start = now()
            ids = transformer_tokenizer.encode(line)
            trans = now()
            tok_ids = tokenizer.encode(line).ids
            tok = now()

            trans_total_time += trans - start
            tok_total_time += tok - trans

            if ids != tok_ids:
                if check_details(line, ids, tok_ids, transformer_tokenizer, tokenizer):
                    continue
            assert ids == tok_ids, f"Error in line {i}: {line} {ids} != {tok_ids}"

    tokenizer.save(f"{pretrained.replace('/', '-')}.json")
    return ("OK", trans_total_time / tok_total_time)


def main():
    pretraineds = [
        "albert-base-v1",
        "albert-large-v1",
        "albert-xlarge-v1",
        "albert-xxlarge-v1",
        "albert-base-v2",
        "albert-large-v2",
        "albert-xlarge-v2",
        "albert-xxlarge-v2",
        "camembert-base",
        "xlm-roberta-base",
        "xlm-roberta-large",
        "xlm-roberta-large-finetuned-conll02-dutch",
        "xlm-roberta-large-finetuned-conll02-spanish",
        "xlm-roberta-large-finetuned-conll03-english",
        "xlm-roberta-large-finetuned-conll03-german",
        "facebook/mbart-large-en-ro",
        "facebook/mbart-large-cc25",
        "xlnet-base-cased",
        "xlnet-large-cased",
        "google/reformer-crime-and-punishment",
        "t5-small",
        "google/pegasus-large",
    ]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--filename",
        required=True,
        type=str,
        help="The filename that we are going to encode in both versions to check that conversion worked",
    )
    parser.add_argument(
        "--models",
        type=lambda s: s.split(","),
        default=pretraineds,
        help=f"The pretrained tokenizers you want to test against (default: {pretraineds})",
    )
    args = parser.parse_args()

    print(args.filename)

    model_len = 50
    status_len = 6
    speedup_len = 8
    print(f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|")
    print(f"|{'-' * model_len}|{'-' * status_len}|{'-' * speedup_len}|")
    for pretrained in args.models:
        status, speedup = check(pretrained, args.filename)
        print(f"|{pretrained:<{model_len}}|{status:^{status_len}}|{speedup:^{speedup_len - 1}.2f}x|")


if __name__ == "__main__":
    main()