#!/usr/bin/env python3
"""Sync the content-type KB and model assets into the Python package.

Copies files from the repository-level `assets/` directory into `src/magika/`
and regenerates `src/magika/types/content_type_label.py` from the KB.
"""

import json
import re
import shutil
from pathlib import Path
from typing import Dict, Iterable, List

import click

ASSETS_DIR = Path(__file__).parent.parent.parent / "assets"
PYTHON_ROOT_DIR = Path(__file__).parent.parent

# Derived locations, defined once so source and destination layouts are
# spelled out in a single place.
MODELS_DIR = ASSETS_DIR / "models"
KB_PATH = ASSETS_DIR / "content_types_kb.min.json"
PYTHON_PACKAGE_DIR = PYTHON_ROOT_DIR / "src" / "magika"

PUBLISHED_MODELS_NAMES = [
    "standard_v3_0",
]

_INDENT = " " * 4


@click.command()
@click.option("--sync-unpublished-models", is_flag=True)
def main(sync_unpublished_models: bool) -> None:
    """Copy the KB and selected models, then regenerate ContentTypeLabel.

    Without the flag, the models listed in PUBLISHED_MODELS_NAMES are copied.
    With --sync-unpublished-models, every directory under assets/models whose
    name contains "_v2_" is copied instead.
    """
    import_content_type_kb()

    models_names_to_sync: List[str]
    if sync_unpublished_models:
        # NOTE(review): "unpublished" is selected via the "_v2_" substring while
        # the published model is a v3 one -- confirm this filter is intended.
        # Only directories are considered (a stray file would make import_model
        # fail), and names are sorted since iterdir() order is filesystem-defined.
        models_names_to_sync = sorted(
            model_dir.name
            for model_dir in MODELS_DIR.iterdir()
            if model_dir.is_dir() and re.search("_v2_", model_dir.name)
        )
    else:
        # Copy so the module-level list can never be mutated through this name.
        models_names_to_sync = list(PUBLISHED_MODELS_NAMES)

    print(f"Syncing these models: {models_names_to_sync}")
    for model_name in models_names_to_sync:
        import_model(model_name)

    gen_content_type_label_source()


def import_content_type_kb() -> None:
    """Copy the content-type KB JSON into src/magika/config/."""
    copy(KB_PATH, PYTHON_PACKAGE_DIR / "config" / KB_PATH.name)


def import_model(model_name: str) -> None:
    """Copy a model's `model.onnx` and `config.min.json` into the package.

    Files are read from assets/models/<model_name>/ and written to
    src/magika/models/<model_name>/.

    Raises:
        OSError: (e.g. FileNotFoundError) if either source file is missing.
    """
    src_dir = MODELS_DIR / model_name
    dst_dir = PYTHON_PACKAGE_DIR / "models" / model_name
    for file_name in ("model.onnx", "config.min.json"):
        # copy() creates dst_dir on demand, so no separate mkdir is needed.
        copy(src_dir / file_name, dst_dir / file_name)


def copy(src_path: Path, dst_path: Path) -> None:
    """Util to copy files and log what is being copied.

    The destination's parent directories are created if they do not exist.
    """
    print(f"Copying {src_path} => {dst_path}")
    dst_path.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(src_path, dst_path)


CONTENT_TYPE_LABEL_SOURCE_PREFIX = """
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from magika.types.strenum import StrEnum

# NOTE: DO NOT EDIT --- This file is automatically generated.


# This is the list of all possible content types we know about; however, models
# support a smaller subset of them. See model's config for details.
class ContentTypeLabel(StrEnum):
"""


def _enum_member_name(label: str) -> str:
    """Map a content-type label to its ContentTypeLabel member name.

    The label is upper-cased; a label starting with a digit gets a leading
    underscore so the result is a valid identifier (e.g. "7z" -> "_7Z").

    Raises:
        ValueError: if the result is not a valid identifier or is a reserved
            enum `_sunder_` name; either would make the generated module fail
            at import time.
    """
    name = label.upper()
    if name[:1].isdigit():
        name = "_" + name
    if not name.isidentifier():
        raise ValueError(
            f"Content type {label!r} cannot be turned into an enum member name"
        )
    if name.startswith("_") and name.endswith("_"):
        raise ValueError(
            f"Content type {label!r} maps to reserved enum name {name!r}"
        )
    return name


def _render_content_type_label_source(labels: Iterable[str]) -> str:
    """Return the full source of content_type_label.py for `labels`.

    Members are emitted in sorted label order as `NAME = "label"` lines,
    followed by a `__repr__` that returns `str(self)`.

    Raises:
        ValueError: if a label is not representable (see _enum_member_name) or
            two labels map to the same member name (e.g. differ only in case).
    """
    label_by_name: Dict[str, str] = {}
    member_lines = []
    for label in sorted(labels):
        name = _enum_member_name(label)
        if name in label_by_name:
            raise ValueError(
                f"Content types {label_by_name[name]!r} and {label!r} both map "
                f"to enum member {name!r}"
            )
        label_by_name[name] = label
        # Safe inside a double-quoted literal: _enum_member_name guarantees the
        # label only contains identifier characters.
        member_lines.append(f'{_INDENT}{name} = "{label}"')

    return (
        CONTENT_TYPE_LABEL_SOURCE_PREFIX.strip()
        + "\n"
        + "\n".join(member_lines)
        + "\n"
        + "\n"
        + f"{_INDENT}def __repr__(self) -> str:\n"
        + f"{_INDENT * 2}return str(self)\n"
    )


def gen_content_type_label_source() -> None:
    """Regenerate src/magika/types/content_type_label.py from the KB's keys."""
    # The KB is JSON, which is UTF-8; do not depend on the locale encoding.
    kb = json.loads(KB_PATH.read_text(encoding="utf-8"))
    out = _render_content_type_label_source(kb.keys())

    content_type_label_path = PYTHON_PACKAGE_DIR / "types" / "content_type_label.py"
    print(f"Writing {content_type_label_path}")
    # newline="\n" keeps LF endings on every platform, so regenerating on
    # Windows does not rewrite every line of the checked-in file.
    with content_type_label_path.open("w", encoding="utf-8", newline="\n") as f:
        f.write(out)


if __name__ == "__main__":
    main()