# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-08-26 14:45
import logging
import math
import os
import sys
from abc import ABC, abstractmethod
from typing import Optional, List, Any, Dict

import numpy as np
import tensorflow as tf

import hanlp.utils
from hanlp_common.io import save_json, load_json
from hanlp.callbacks.fine_csv_logger import FineCSVLogger
from hanlp.common.component import Component
from hanlp.common.transform_tf import Transform
from hanlp.common.vocab_tf import VocabTF
from hanlp.metrics.chunking.iobes_tf import IOBES_F1_TF
from hanlp.optimizers.adamw import AdamWeightDecay
from hanlp.utils import io_util
from hanlp.utils.io_util import get_resource, tempdir_human
from hanlp.utils.log_util import init_logger, logger
from hanlp.utils.string_util import format_scores
from hanlp.utils.tf_util import format_metrics, size_of_dataset, summary_of_model, get_callback_by_class, NumpyEncoder
from hanlp.utils.time_util import Timer, now_datetime
from hanlp_common.reflection import str_to_type, classpath_of
from hanlp_common.structure import SerializableDict
from hanlp_common.util import merge_dict


class KerasComponent(Component, ABC):
    """Abstract base for components that pair a ``Transform`` with a ``tf.keras.Model``.

    The component owns three pieces of state that are persisted side by side in a
    save directory: ``config`` (``config.json``), the vocabularies held by the
    transform (``vocabs.json``) and the model weights (``model.h5``), plus a small
    ``meta`` dict (``meta.json``). Subclasses must implement :meth:`build_model`.
    """

    def __init__(self, transform: Transform) -> None:
        """Create a component around ``transform``.

        Parameters
        ----------
        transform
            The transform used to convert files/inputs into datasets. Its config is
            merged into, and then replaced by, this component's config.
        """
        super().__init__()
        self.meta = {
            'class_path': classpath_of(self),
            'hanlp_version': hanlp.version.__version__,
        }
        self.model: Optional[tf.keras.Model] = None
        self.config = SerializableDict()
        self.transform = transform
        # share config with transform for convenience, so we don't need to pass args around
        if self.transform.config:
            for k, v in self.transform.config.items():
                self.config[k] = v
        self.transform.config = self.config

    def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128, logger: logging.Logger = None,
                 callbacks: List[tf.keras.callbacks.Callback] = None, warm_up=True, verbose=True, **kwargs):
        """Evaluate the model on a file and optionally dump its predictions.

        Parameters
        ----------
        input_path
            Path (resolved through ``get_resource``) of the evaluation file.
        save_dir
            Directory for the log file and the prediction output. Required when
            ``output`` is truthy.
        output
            ``True`` to write predictions to ``<save_dir>/<name>.predict<ext>``, or a
            string to write them to that exact path.
        batch_size
            Batch size used to build the dataset and to derive the number of batches.
        logger
            Logger to report to. When omitted and ``save_dir`` is given, one is created.
            Note that this parameter shadows the module-level ``logger``.
        callbacks
            Keras callbacks forwarded to :meth:`evaluate_dataset`.
        warm_up
            Run a single ``predict_on_batch`` before timing starts.
        verbose
            Selects ``INFO`` (True) or ``WARN`` (False) for a logger created here.
        kwargs
            Forwarded to :meth:`evaluate_dataset`.

        Returns
        -------
        tuple
            ``(loss, score, speed)`` where ``speed`` is samples per second.

        Raises
        ------
        RuntimeError
            If ``output`` is truthy but neither ``bool`` nor ``str``.
        """
        input_path = get_resource(input_path)
        file_prefix, ext = os.path.splitext(input_path)
        name = os.path.basename(file_prefix)
        if not name:
            name = 'evaluate'
        if save_dir and not logger:
            logger = init_logger(name=name, root_dir=save_dir, level=logging.INFO if verbose else logging.WARN,
                                 mode='w')
        tst_data = self.transform.file_to_dataset(input_path, batch_size=batch_size)
        samples = self.num_samples_in(tst_data)
        num_batches = math.ceil(samples / batch_size)
        if warm_up:
            # One throw-away batch so that start-up cost is excluded from the timing below.
            for x, y in tst_data:
                self.model.predict_on_batch(x)
                break
        if output:
            assert save_dir, 'Must pass save_dir in order to output'
            if isinstance(output, bool):
                output = os.path.join(save_dir, name) + '.predict' + ext
            elif not isinstance(output, str):
                raise RuntimeError('output ({}) must be of type bool or str'.format(repr(output)))
        timer = Timer()
        eval_outputs = self.evaluate_dataset(tst_data, callbacks, output, num_batches, **kwargs)
        loss, score, output = eval_outputs[0], eval_outputs[1], eval_outputs[2]
        delta_time = timer.stop()
        speed = samples / delta_time.delta_seconds

        if logger:
            f1: IOBES_F1_TF = None
            for metric in self.model.metrics:
                if isinstance(metric, IOBES_F1_TF):
                    f1 = metric
                    break
            extra_report = ''
            if f1:
                # Only the textual report (third element) is needed here.
                _, _, extra_report = f1.state.result(full=True, verbose=False)
                extra_report = ' \n' + extra_report
            logger.info('Evaluation results for {} - '
                        'loss: {:.4f} - {} - speed: {:.2f} sample/sec{}'
                        .format(name + ext, loss,
                                format_scores(score) if isinstance(score, dict) else format_metrics(self.model.metrics),
                                speed, extra_report))
        if output:
            # `output` implies `save_dir`, so a logger exists at this point.
            logger.info('Saving output to {}'.format(output))
            with open(output, 'w', encoding='utf-8') as out:
                self.evaluate_output(tst_data, out, num_batches, self.model.metrics)

        return loss, score, speed

    def num_samples_in(self, dataset):
        """Return the number of samples in ``dataset`` as reported by ``size_of_dataset``."""
        return size_of_dataset(dataset)

    def evaluate_dataset(self, tst_data, callbacks, output, num_batches, **kwargs):
        """Run ``model.evaluate`` for ``num_batches`` steps and return ``(loss, score, output)``."""
        loss, score = self.model.evaluate(tst_data, callbacks=callbacks, steps=num_batches)
        return loss, score, output

    def evaluate_output(self, tst_data, out, num_batches, metrics: List[tf.keras.metrics.Metric]):
        """Predict every batch of ``tst_data``, update ``metrics`` and write the results to ``out``.

        Progress (``idx/num_batches`` plus the formatted metrics) is printed on a single
        console line that is rewritten after each batch.
        """
        # out.write('x\ty_true\ty_pred\n')
        for metric in metrics:
            metric.reset_states()
        for idx, batch in enumerate(tst_data):
            outputs = self.model.predict_on_batch(batch[0])
            for metric in metrics:
                metric(batch[1], outputs, outputs._keras_mask if hasattr(outputs, '_keras_mask') else None)
            self.evaluate_output_to_file(batch, outputs, out)
            print('\r{}/{} {}'.format(idx + 1, num_batches, format_metrics(metrics)), end='')
        print()

    def evaluate_output_to_file(self, batch, outputs, out):
        """Write one line-formatted record per sample of ``batch`` (input, gold, prediction) to ``out``."""
        for x, y_gold, y_pred in zip(self.transform.X_to_inputs(batch[0]),
                                     self.transform.Y_to_outputs(batch[1], gold=True),
                                     self.transform.Y_to_outputs(outputs, gold=False)):
            out.write(self.transform.input_truth_output_to_str(x, y_gold, y_pred))

    def _capture_config(self, config: Dict,
                        exclude=(
                                'trn_data', 'dev_data', 'save_dir', 'kwargs', 'self', 'logger', 'verbose',
                                'dev_batch_size', '__class__')):
        """
        Save arguments to config

        Parameters
        ----------
        config
            `locals()` of the calling method. It is updated in place with the content
            of its own ``kwargs`` entry, if any.
        exclude
            Keys that must not end up in ``self.config``.
        """
        if 'kwargs' in config:
            config.update(config['kwargs'])
        # Objects exposing `get_config` are stored in their serialized Keras form.
        config = dict(
            (key, tf.keras.utils.serialize_keras_object(value)) if hasattr(value, 'get_config') else (key, value) for
            key, value in config.items())
        for key in exclude:
            config.pop(key, None)
        self.config.update(config)

    def save_meta(self, save_dir, filename='meta.json', **kwargs):
        """Stamp ``meta`` with the creation time, merge ``kwargs`` into it and save it as JSON."""
        # This used to be an annotation (`:`) rather than an assignment, so the
        # timestamp was silently never stored.
        self.meta['create_time'] = now_datetime()
        self.meta.update(kwargs)
        save_json(self.meta, os.path.join(save_dir, filename))

    def load_meta(self, save_dir, filename='meta.json'):
        """Merge ``<save_dir>/<filename>`` into ``meta`` if that file exists."""
        save_dir = get_resource(save_dir)
        metapath = os.path.join(save_dir, filename)
        if os.path.isfile(metapath):
            self.meta.update(load_json(metapath))

    def save_config(self, save_dir, filename='config.json'):
        """Save ``config`` as JSON into ``save_dir``."""
        self.config.save_json(os.path.join(save_dir, filename))

    def load_config(self, save_dir, filename='config.json'):
        """Load ``config`` from the JSON file in ``save_dir``."""
        save_dir = get_resource(save_dir)
        self.config.load_json(os.path.join(save_dir, filename))

    def save_weights(self, save_dir, filename='model.h5'):
        """Save the model weights into ``save_dir``."""
        self.model.save_weights(os.path.join(save_dir, filename))

    def load_weights(self, save_dir, filename='model.h5', **kwargs):
        """Load the model weights from ``save_dir``. The model must already be built."""
        assert self.model.built or self.model.weights, 'You must call self.model.built() in build_model() ' \
                                                       'in order to load it'
        save_dir = get_resource(save_dir)
        self.model.load_weights(os.path.join(save_dir, filename))

    def save_vocabs(self, save_dir, filename='vocabs.json'):
        """Save every ``VocabTF`` attribute of the transform into a single JSON file."""
        vocabs = SerializableDict()
        for key, value in vars(self.transform).items():
            if isinstance(value, VocabTF):
                vocabs[key] = value.to_dict()
        vocabs.save_json(os.path.join(save_dir, filename))

    def load_vocabs(self, save_dir, filename='vocabs.json'):
        """Restore the vocabularies saved by :meth:`save_vocabs` as attributes of the transform."""
        save_dir = get_resource(save_dir)
        vocabs = SerializableDict()
        vocabs.load_json(os.path.join(save_dir, filename))
        for key, value in vocabs.items():
            vocab = VocabTF()
            vocab.copy_from(value)
            setattr(self.transform, key, vocab)

    def load_transform(self, save_dir) -> Transform:
        """
        Try to load transform only. This method might fail due to the fact it avoids building the model.
        If it does fail, then you have to use `load` which might be too heavy but that's the best we can do.
        :param save_dir: The path to load.
        """
        save_dir = get_resource(save_dir)
        self.load_config(save_dir)
        self.load_vocabs(save_dir)
        self.transform.build_config()
        self.transform.lock_vocabs()
        return self.transform

    def save(self, save_dir: str, **kwargs):
        """Save config, vocabularies and weights into ``save_dir``."""
        self.save_config(save_dir)
        self.save_vocabs(save_dir)
        self.save_weights(save_dir)

    def load(self, save_dir: str, logger=hanlp.utils.log_util.logger, **kwargs):
        """Load config and vocabularies, build the model, then load weights and meta.

        ``save_dir`` is recorded in ``meta['load_path']`` before it is resolved through
        ``get_resource``. ``kwargs`` override the loaded config when building.
        """
        self.meta['load_path'] = save_dir
        save_dir = get_resource(save_dir)
        self.load_config(save_dir)
        self.load_vocabs(save_dir)
        self.build(**merge_dict(self.config, training=False, logger=logger, **kwargs, overwrite=True, inplace=True))
        self.load_weights(save_dir, **kwargs)
        self.load_meta(save_dir)

    @property
    def input_shape(self) -> List:
        """The first element of the transform's ``output_shapes``."""
        return self.transform.output_shapes[0]

    def build(self, logger, **kwargs):
        """Build and compile the model together with its optimizer, loss and metrics.

        Parameters
        ----------
        logger
            Logger forwarded to :meth:`build_metrics`.
        kwargs
            ``training``, ``loss`` and ``metrics`` are read from here; everything else
            comes from ``self.config``.

        Returns
        -------
        tuple
            ``(model, optimizer, loss, metrics)``.
        """
        self.transform.build_config()
        self.model = self.build_model(**merge_dict(self.config, training=kwargs.get('training', None),
                                                   loss=kwargs.get('loss', None)))
        self.transform.lock_vocabs()
        optimizer = self.build_optimizer(**self.config)
        # build_loss requires a `loss` argument; supply None when the config has none.
        loss_kwargs = self.config if 'loss' in self.config else dict(list(self.config.items()) + [('loss', None)])
        loss = self.build_loss(**loss_kwargs)
        metrics = self.build_metrics(**merge_dict(self.config, metrics=kwargs.get('metrics', 'accuracy'),
                                                  logger=logger, overwrite=True))
        if isinstance(metrics, tf.keras.metrics.Metric):
            metrics = [metrics]
        if not self.model.built:
            sample_inputs = self.sample_data
            if sample_inputs is not None:
                # Calling the model on sample data builds it.
                self.model(sample_inputs)
            else:
                if len(self.transform.output_shapes[0]) == 1 and self.transform.output_shapes[0][0] is None:
                    x_shape = self.transform.output_shapes[0]
                else:
                    x_shape = list(self.transform.output_shapes[0])
                    for i, shape in enumerate(x_shape):
                        x_shape[i] = [None] + shape  # batch + X.shape
                self.model.build(input_shape=x_shape)
        self.compile_model(optimizer, loss, metrics)
        return self.model, optimizer, loss, metrics

    def compile_model(self, optimizer, loss, metrics):
        """Compile the model, retrying inside a custom object scope for ``AdamWeightDecay`` on ``ValueError``."""
        try:
            self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics, run_eagerly=self.config.run_eagerly)
        except ValueError:
            from keras.saving.object_registration import CustomObjectScope
            with CustomObjectScope({'adamweightdecay': AdamWeightDecay}):
                self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics,
                                   run_eagerly=self.config.run_eagerly)

    def build_optimizer(self, optimizer, **kwargs) -> tf.keras.optimizers.Optimizer:
        """Turn an optimizer spec (name, serialized dict or instance) into an optimizer.

        The serialized form of the resulting optimizer is written back to
        ``self.config.optimizer``.
        """
        if isinstance(optimizer, (str, dict)):
            custom_objects = {'AdamWeightDecay': AdamWeightDecay}
            try:
                optimizer = tf.keras.utils.deserialize_keras_object(optimizer, module_objects=vars(tf.keras.optimizers),
                                                                    custom_objects=custom_objects)
            except ValueError:
                # The retry drops the `decay` entry from the serialized config. That is
                # only possible for a dict spec; otherwise surface the original error
                # instead of masking it with a TypeError from indexing a string.
                if not isinstance(optimizer, dict) or 'config' not in optimizer:
                    raise
                optimizer['config'].pop('decay', None)
                optimizer = tf.keras.utils.deserialize_keras_object(optimizer, module_objects=vars(tf.keras.optimizers),
                                                                    custom_objects=custom_objects)
        self.config.optimizer = tf.keras.utils.serialize_keras_object(optimizer)
        return optimizer

    def build_loss(self, loss, **kwargs):
        """Turn a loss spec into a loss object.

        A falsy ``loss`` yields ``SparseCategoricalCrossentropy`` on logits; a ``str`` or
        ``dict`` is deserialized. ``tf.keras.losses.Loss`` instances are serialized back
        into ``self.config.loss``.
        """
        if not loss:
            loss = tf.keras.losses.SparseCategoricalCrossentropy(
                reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
                from_logits=True)
        elif isinstance(loss, (str, dict)):
            loss = tf.keras.utils.deserialize_keras_object(loss, module_objects=vars(tf.keras.losses))
        if isinstance(loss, tf.keras.losses.Loss):
            self.config.loss = tf.keras.utils.serialize_keras_object(loss)
        return loss

    def build_transform(self, **kwargs):
        """Return the transform to train with. The default is the existing one."""
        return self.transform

    def build_vocab(self, trn_data, logger):
        """Fit the transform on ``trn_data``, log a vocabulary summary and return the result of ``fit``."""
        train_examples = self.transform.fit(trn_data, **self.config)
        self.transform.summarize_vocabs(logger)
        return train_examples

    def build_metrics(self, metrics, logger: logging.Logger, **kwargs):
        """Return the metrics to compile with. The default is a single sparse categorical accuracy."""
        metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
        return [metric]

    @abstractmethod
    def build_model(self, **kwargs) -> tf.keras.Model:
        """Create the ``tf.keras.Model``. Must be implemented by subclasses."""
        pass

    def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, run_eagerly=False, logger=None, verbose=True,
            finetune: str = None, **kwargs):
        """Train the model.

        Parameters
        ----------
        trn_data
            Training data, passed to the transform.
        dev_data
            Validation data, passed to the transform.
        save_dir
            Directory for config, vocabs, meta, weights, logs and history. A temporary
            directory is used when falsy.
        batch_size
            Batch size for both the training and the validation set.
        epochs
            Number of epochs.
        run_eagerly
            Captured into the config and used when compiling the model.
        logger
            Logger to report to; created inside ``save_dir`` when omitted.
        verbose
            Selects ``INFO`` (True) or ``WARN`` (False) for a logger created here.
        finetune
            Path of a weights file, or of a directory containing ``model.h5``, to
            initialize from (loaded by name, skipping mismatches).
        kwargs
            Extra hyperparameters captured into the config.

        Returns
        -------
        The history produced by :meth:`train_loop`, or the ``History`` callback found in
        the callback list (possibly ``None``) when training was interrupted.
        """
        # Must stay the first statement: locals() may only contain the arguments here.
        self._capture_config(locals())
        self.transform = self.build_transform(**self.config)
        if not save_dir:
            save_dir = tempdir_human()
        if not logger:
            logger = init_logger(name='train', root_dir=save_dir, level=logging.INFO if verbose else logging.WARN)
        # Emitted only after the logger is guaranteed to exist; `logger` defaults to None.
        if sys.version_info >= (3, 10):
            logger.warning(f'Training with TensorFlow {tf.__version__} has not been tested on Python '
                           f'{sys.version_info.major}.{sys.version_info.minor}. Please downgrade to '
                           f'Python<=3.9 in case any compatibility issues arise.')
        logger.info('Hyperparameter:\n' + self.config.to_json())
        num_examples = self.build_vocab(trn_data, logger)
        # assert num_examples, 'You forgot to return the number of training examples in your build_vocab'
        logger.info('Building...')
        train_steps_per_epoch = math.ceil(num_examples / batch_size) if num_examples else None
        self.config.train_steps = train_steps_per_epoch * epochs if num_examples else None
        model, optimizer, loss, metrics = self.build(**merge_dict(self.config, logger=logger, training=True))
        logger.info('Model built:\n' + summary_of_model(self.model))
        if finetune:
            finetune = get_resource(finetune)
            if os.path.isdir(finetune):
                finetune = os.path.join(finetune, 'model.h5')
            model.load_weights(finetune, by_name=True, skip_mismatch=True)
            logger.info(f'Loaded pretrained weights from {finetune} for finetuning')
        self.save_config(save_dir)
        self.save_vocabs(save_dir)
        self.save_meta(save_dir)
        trn_data = self.build_train_dataset(trn_data, batch_size, num_examples)
        dev_data = self.build_valid_dataset(dev_data, batch_size)
        callbacks = self.build_callbacks(save_dir, **merge_dict(self.config, overwrite=True, logger=logger))
        # need to know #batches, otherwise progbar crashes
        dev_steps = math.ceil(self.num_samples_in(dev_data) / batch_size)
        checkpoint = get_callback_by_class(callbacks, tf.keras.callbacks.ModelCheckpoint)
        timer = Timer()
        try:
            history = self.train_loop(**merge_dict(self.config, trn_data=trn_data, dev_data=dev_data, epochs=epochs,
                                                   num_examples=num_examples,
                                                   train_steps_per_epoch=train_steps_per_epoch, dev_steps=dev_steps,
                                                   callbacks=callbacks, logger=logger, model=model,
                                                   optimizer=optimizer, loss=loss, metrics=metrics, overwrite=True))
        except KeyboardInterrupt:
            print()
            # An infinite (or missing) `best` means the checkpoint callback has not saved
            # anything yet, so persist the current weights ourselves.
            if not checkpoint or checkpoint.best is None or np.isinf(checkpoint.best):
                self.save_weights(save_dir)
                logger.info('Aborted with model saved')
            else:
                logger.info(f'Aborted with model saved with best {checkpoint.monitor} = {checkpoint.best:.4f}')
            # noinspection PyTypeChecker
            history = get_callback_by_class(callbacks, tf.keras.callbacks.History)
        delta_time = timer.stop()
        best_epoch_ago = 0
        monitor_history = None
        if history and hasattr(history, 'epoch'):
            trained_epoch = len(history.epoch)
            logger.info('Trained {} epochs in {}, each epoch takes {}'.
                        format(trained_epoch, delta_time, delta_time / trained_epoch if trained_epoch else delta_time))
            save_json(history.history, io_util.path_join(save_dir, 'history.json'), cls=NumpyEncoder)
            if checkpoint:
                # Subclasses may build a callback list without a ModelCheckpoint.
                monitor_history = history.history.get(checkpoint.monitor, None)
            if monitor_history:
                best_epoch_ago = len(monitor_history) - monitor_history.index(checkpoint.best)
            if checkpoint and monitor_history and checkpoint.best != monitor_history[-1]:
                logger.info(f'Restored the best model saved with best '
                            f'{checkpoint.monitor} = {checkpoint.best:.4f} '
                            f'saved {best_epoch_ago} epochs ago')
                self.load_weights(save_dir)  # restore best model
        return history

    def train_loop(self, trn_data, dev_data, epochs, num_examples, train_steps_per_epoch, dev_steps, model, optimizer,
                   loss, metrics, callbacks, logger, **kwargs):
        """Run ``self.model.fit`` with the prepared datasets and callbacks and return its history."""
        history = self.model.fit(trn_data, epochs=epochs, steps_per_epoch=train_steps_per_epoch,
                                 validation_data=dev_data,
                                 callbacks=callbacks,
                                 validation_steps=dev_steps,
                                 )  # type:tf.keras.callbacks.History
        return history

    def build_valid_dataset(self, dev_data, batch_size):
        """Build the unshuffled validation dataset from ``dev_data``."""
        dev_data = self.transform.file_to_dataset(dev_data, batch_size=batch_size, shuffle=False)
        return dev_data

    def build_train_dataset(self, trn_data, batch_size, num_examples):
        """Build the shuffled training dataset; it repeats (``repeat=-1``) when ``train_steps`` is set."""
        trn_data = self.transform.file_to_dataset(trn_data, batch_size=batch_size, shuffle=True,
                                                  repeat=-1 if self.config.train_steps else None)
        return trn_data

    def build_callbacks(self, save_dir, logger, **kwargs):
        """Create the training callbacks.

        Always returns a best-only ``ModelCheckpoint`` on ``val_<metric>``, a
        ``TensorBoard`` callback and a ``FineCSVLogger``. Depending on the config keys
        ``lr_decay_per_epoch``, ``anneal_factor`` and ``early_stopping_patience`` a
        ``LearningRateScheduler``, ``ReduceLROnPlateau`` and ``EarlyStopping`` are added.
        """
        metrics = kwargs.get('metrics', 'accuracy')
        if isinstance(metrics, (list, tuple)):
            # When several metrics are given, the last one is monitored.
            metrics = metrics[-1]
        monitor = f'val_{metrics}'
        checkpoint = tf.keras.callbacks.ModelCheckpoint(
            os.path.join(save_dir, 'model.h5'),
            # verbose=1,
            monitor=monitor, save_best_only=True,
            mode='max',
            save_weights_only=True)
        logger.debug(f'Monitor {checkpoint.monitor} for checkpoint')
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=io_util.makedirs(io_util.path_join(save_dir, 'logs')))
        csv_logger = FineCSVLogger(os.path.join(save_dir, 'train.log'), separator=' | ', append=True)
        callbacks = [checkpoint, tensorboard_callback, csv_logger]
        lr_decay_per_epoch = self.config.get('lr_decay_per_epoch', None)
        if lr_decay_per_epoch:
            learning_rate = self.model.optimizer.get_config().get('learning_rate', None)
            if not learning_rate:
                logger.warning('Learning rate decay not supported for optimizer={}'.format(repr(self.model.optimizer)))
            else:
                logger.debug(f'Created LearningRateScheduler with lr_decay_per_epoch={lr_decay_per_epoch}')
                callbacks.append(tf.keras.callbacks.LearningRateScheduler(
                    lambda epoch: learning_rate / (1 + lr_decay_per_epoch * epoch)))
        anneal_factor = self.config.get('anneal_factor', None)
        if anneal_factor:
            callbacks.append(tf.keras.callbacks.ReduceLROnPlateau(factor=anneal_factor,
                                                                  patience=self.config.get('anneal_patience', 10)))
        early_stopping_patience = self.config.get('early_stopping_patience', None)
        if early_stopping_patience:
            callbacks.append(tf.keras.callbacks.EarlyStopping(monitor=monitor, mode='max',
                                                              verbose=1,
                                                              patience=early_stopping_patience))
        return callbacks

    def on_train_begin(self):
        """
        Callback before the training starts
        """
        pass

    def predict(self, data: Any, batch_size=None, **kwargs):
        """Predict on raw input.

        Parameters
        ----------
        data
            Raw input accepted by the transform's ``input_to_inputs``. Falsy input
            yields ``[]``.
        batch_size
            Batch size; defaults to ``self.config.batch_size``.
        kwargs
            ``gold`` is forwarded to the dataset builder; all of them are forwarded to
            :meth:`predict_batch`.

        Returns
        -------
        The list of outputs, or only its first element when the transform reports the
        input as flat.
        """
        assert self.model, 'Please call fit or load before predict'
        if not data:
            return []
        data, flat = self.transform.input_to_inputs(data)

        if not batch_size:
            batch_size = self.config.batch_size

        dataset = self.transform.inputs_to_dataset(data, batch_size=batch_size, gold=kwargs.get('gold', False))

        results = []
        num_samples = 0
        data_is_list = isinstance(data, list)
        for idx, batch in enumerate(dataset):
            # The leading dimension of the last batch element gives the number of samples.
            samples_in_batch = tf.shape(batch[-1] if isinstance(batch[-1], tf.Tensor) else batch[-1][0])[0]
            if data_is_list:
                inputs = data[num_samples:num_samples + samples_in_batch]
            else:
                inputs = None  # if data is a generator, it's usually one-time, not able to transform into a list
            results.extend(self.predict_batch(batch, inputs=inputs, **kwargs))
            num_samples += samples_in_batch
        self.transform.cleanup()

        if flat:
            return results[0]
        return results

    def predict_batch(self, batch, inputs=None, **kwargs):
        """Yield the decoded outputs for one batch."""
        X = batch[0]
        Y = self.model.predict_on_batch(X)
        for output in self.transform.Y_to_outputs(Y, X=X, inputs=inputs, batch=batch, **kwargs):
            yield output

    @property
    def sample_data(self):
        """Sample inputs used to build the model in :meth:`build`; ``None`` to build from shapes instead."""
        return None

    @staticmethod
    def from_meta(meta: dict, **kwargs):
        """Instantiate the class named by ``meta['class_path']`` and load it from ``meta['load_path']``.

        Parameters
        ----------
        meta
            Meta dict containing ``class_path`` and ``load_path``.
        kwargs
            Currently unused.

        Returns
        -------
        KerasComponent

        """
        cls = str_to_type(meta['class_path'])
        obj: KerasComponent = cls()
        assert 'load_path' in meta, f'{meta} doesn\'t contain load_path field'
        obj.load(meta['load_path'])
        return obj

    def export_model_for_serving(self, export_dir=None, version=1, overwrite=False, show_hint=False):
        """Export the model as a SavedModel into ``<export_dir>/<version>``.

        Parameters
        ----------
        export_dir
            Target directory; defaults to the resolved ``meta['load_path']``.
        version
            Name of the version sub-directory.
        overwrite
            Export even if the version directory already exists.
        show_hint
            Log a ``tensorflow_model_server`` command line after exporting.

        Returns
        -------
        The export directory.
        """
        assert self.model, 'You have to fit or load a model before exporting it'
        if not export_dir:
            assert 'load_path' in self.meta, 'When not specifying save_dir, load_path has to present'
            export_dir = get_resource(self.meta['load_path'])
        model_path = os.path.join(export_dir, str(version))
        if os.path.isdir(model_path) and not overwrite:
            logger.info(f'{model_path} exists, skip since overwrite = {overwrite}')
            return export_dir
        logger.info(f'Exporting to {export_dir} ...')
        tf.saved_model.save(self.model, model_path)
        logger.info(f'Successfully exported model to {export_dir}')
        if show_hint:
            logger.info(f'You can serve it through \n'
                        f'tensorflow_model_server --model_name={os.path.splitext(os.path.basename(self.meta["load_path"]))[0]} '
                        f'--model_base_path={export_dir} --rest_api_port=8888')
        return export_dir

    def serve(self, export_dir=None, grpc_port=8500, rest_api_port=0, overwrite=False, dry_run=False):
        """Export the model and start ``tensorflow_model_server`` in the background.

        With ``dry_run=True`` the command is only logged; the model is kept and no
        server is started. Otherwise ``self.model`` is deleted before serving.
        """
        export_dir = self.export_model_for_serving(export_dir, show_hint=False, overwrite=overwrite)
        if not dry_run:
            del self.model  # free memory
        logger.info('The inputs of exported model is shown below.')
        # NOTE(review): both commands below are shell strings built from `export_dir`
        # and `meta['load_path']`; only call this with trusted paths.
        os.system(f'saved_model_cli show --all --dir {export_dir}/1')
        cmd = f'nohup tensorflow_model_server --model_name={os.path.splitext(os.path.basename(self.meta["load_path"]))[0]} ' \
              f'--model_base_path={export_dir} --port={grpc_port} --rest_api_port={rest_api_port} ' \
              f'>serve.log 2>&1 &'
        logger.info(f'Running ...\n{cmd}')
        if not dry_run:
            os.system(cmd)