2024-08-29 14:37:34 +00:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
2024-08-29 15:04:41 +00:00
|
|
|
import json
|
2024-09-16 11:30:07 +00:00
|
|
|
import re
|
2024-08-29 14:37:34 +00:00
|
|
|
import shutil
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
import click
|
|
|
|
|
|
|
|
|
|
# Resolve this file's location once so the derived directories are absolute
# even when `__file__` is a relative path (e.g. `python sync.py` on
# interpreters that do not absolutize `__main__.__file__`); without this,
# `.parent` chains on a bare filename collapse to "." and the paths end up
# relative to the current working directory.
_SCRIPT_PATH = Path(__file__).resolve()

# `assets` directory located three levels above this file.
ASSETS_DIR = _SCRIPT_PATH.parent.parent.parent / "assets"
# Directory two levels above this file; destination paths are built from it.
PYTHON_ROOT_DIR = _SCRIPT_PATH.parent.parent

# Models synced by default, i.e. when `--sync-unpublished-models` is not set.
PUBLISHED_MODELS_NAMES = [
    "standard_v3_0",
]
|
2024-08-29 14:37:34 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Entry point: copies the content types KB, copies the selected models, then
# regenerates the ContentTypeLabel source file.
# NOTE: documented with comments rather than a docstring on purpose; click
# would expose a docstring as the command's `--help` text.
@click.command()
@click.option("--sync-unpublished-models", is_flag=True)
def main(sync_unpublished_models: bool) -> None:
    import_content_type_kb()

    if not sync_unpublished_models:
        # Default: only the models listed as published.
        models_names_to_sync = PUBLISHED_MODELS_NAMES
    else:
        # Otherwise pick every entry under assets/models whose name
        # contains "_v2_".
        models_names_to_sync = [
            entry.name
            for entry in (ASSETS_DIR / "models").iterdir()
            if re.search("_v2_", entry.name)
        ]

    print(f"Syncing these models: {models_names_to_sync}")

    for model_name in models_names_to_sync:
        import_model(model_name)

    gen_content_type_label_source()
|
|
|
|
|
|
2024-08-29 14:37:34 +00:00
|
|
|
|
|
|
|
|
def import_content_type_kb() -> None:
    """Copy the minified content types KB from the assets directory into
    the `src/magika/config` directory, keeping the same file name."""
    kb_filename = "content_types_kb.min.json"
    copy(
        ASSETS_DIR / kb_filename,
        PYTHON_ROOT_DIR / "src" / "magika" / "config" / kb_filename,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def import_model(model_name: str) -> None:
    """Copy a model's ONNX file and minified config from the assets
    directory into `src/magika/models/<model_name>`.

    The destination directory is created if it does not exist yet.
    """
    src_model_dir = ASSETS_DIR / "models" / model_name
    dst_model_dir = PYTHON_ROOT_DIR / "src" / "magika" / "models" / model_name
    dst_model_dir.mkdir(parents=True, exist_ok=True)

    # Same two files, in the same order: the model first, then its config.
    for filename in ("model.onnx", "config.min.json"):
        copy(src_model_dir / filename, dst_model_dir / filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def copy(src_path: Path, dst_path: Path) -> None:
    """Copy `src_path` to `dst_path`, printing what is being copied.

    Missing parent directories of the destination are created first.
    """
    print(f"Copying {src_path} => {dst_path}")
    dst_dir = dst_path.parent
    dst_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy(src_path, dst_path)
|
|
|
|
|
|
|
|
|
|
|
2024-08-29 15:04:41 +00:00
|
|
|
CONTENT_TYPE_LABEL_SOURCE_PREFIX = """
|
|
|
|
|
# Copyright 2024 Google LLC
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from magika.types.strenum import StrEnum
|
|
|
|
|
|
|
|
|
|
# NOTE: DO NOT EDIT --- This file is automatically generated.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# This is the list of all possible content types we know about; however, models
|
|
|
|
|
# support a smaller subset of them. See model's config for details.
|
|
|
|
|
class ContentTypeLabel(StrEnum):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gen_content_type_label_source() -> None:
    """Generate `content_type_label.py` from the content types KB.

    Reads the keys of `content_types_kb.min.json` in the assets directory
    and writes a `ContentTypeLabel` enum source file with one member per
    key (sorted), followed by a `__repr__` method that returns `str(self)`.
    The output goes to `src/magika/types/content_type_label.py`, overwriting
    any existing file.

    Both files are handled as UTF-8 explicitly, so the result does not
    depend on the platform's default locale encoding.
    """
    kb_path = ASSETS_DIR / "content_types_kb.min.json"
    kb = json.loads(kb_path.read_text(encoding="utf-8"))

    content_type_label_path = (
        PYTHON_ROOT_DIR / "src" / "magika" / "types" / "content_type_label.py"
    )

    indent = " " * 4

    enum_body_lines = []
    for ct_label_str in sorted(kb):
        # A Python identifier cannot start with a digit, so such labels get
        # a leading underscore in the member name (the value is unchanged).
        member_prefix = "_" if ct_label_str[0].isdigit() else ""
        enum_body_lines.append(
            f'{indent}{member_prefix}{ct_label_str.upper()} = "{ct_label_str}"'
        )

    out = (
        CONTENT_TYPE_LABEL_SOURCE_PREFIX.strip()
        + "\n"
        + "\n".join(enum_body_lines)
        + "\n"
        # Blank line between the enum members and the method.
        + "\n"
        + f"{indent}def __repr__(self) -> str:\n"
        + f"{indent}{indent}return str(self)\n"
    )

    content_type_label_path.write_text(out, encoding="utf-8")
|
|
|
|
|
|
|
|
|
|
|
2024-08-29 14:37:34 +00:00
|
|
|
# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
|