You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
108 lines
3.8 KiB
Python
108 lines
3.8 KiB
Python
"""Top-level module for stub generation."""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import glob
|
|
import logging
|
|
import os.path
|
|
from types import ModuleType
|
|
|
|
import black
|
|
import capnp # type: ignore
|
|
import isort
|
|
from capnp_stub_generator.capnp_types import ModuleRegistryType
|
|
from capnp_stub_generator.helper import replace_capnp_suffix
|
|
from capnp_stub_generator.writer import Writer
|
|
|
|
# Disable the import hook of `capnp`; this module loads schemas explicitly via `capnp.SchemaParser`.
capnp.remove_import_hook()

logger = logging.getLogger(__name__)

# File suffixes appended to the output path for the two generated files.
PYI_SUFFIX: str = ".pyi"
PY_SUFFIX: str = ".py"

# Default line length handed to the formatters in `format_outputs`.
LINE_LENGTH: int = 120
|
|
|
|
|
|
def format_outputs(raw_input: str, is_pyi: bool, line_length: int = LINE_LENGTH) -> str:
    """Sort the imports of generated source with `isort`, then format it with `black`.

    Both tools receive the same line length, and `isort` runs with its "black" profile.

    Args:
        raw_input (str): The unformatted source text.
        is_pyi (bool): Forwarded to `black.Mode`; set it when the text is written to a `pyi` file.
        line_length (int): Line length passed to both tools. Defaults to `LINE_LENGTH`.

    Returns:
        str: The import-sorted and formatted source text.
    """
    # FIXME: Extract config from dev_policies
    isort_config = isort.Config(profile="black", line_length=line_length)
    with_sorted_imports = isort.code(raw_input, config=isort_config)

    black_mode = black.Mode(is_pyi=is_pyi, line_length=line_length)
    return black.format_str(with_sorted_imports, mode=black_mode)
|
|
|
|
|
|
def generate_stubs(module: ModuleType, module_registry: ModuleRegistryType, output_file_path: str):
    """Generate and write the `.pyi` and `.py` outputs for one module definition.

    A `Writer` is built for the module, all nested definitions are generated, and each of the
    writer's two dumps is formatted with `format_outputs` before being written next to each other
    as `<output_file_path>.pyi` and `<output_file_path>.py`.

    Args:
        module (ModuleType): The module to generate stubs for.
        module_registry (ModuleRegistryType): A registry of all detected modules.
        output_file_path (str): The path of the output files, without file extension.
    """
    stub_writer = Writer(module, module_registry)
    stub_writer.generate_all_nested()

    # Both dumps are produced up front, before anything is formatted or written.
    renderings = (
        (stub_writer.dumps_pyi(), PYI_SUFFIX, True),
        (stub_writer.dumps_py(), PY_SUFFIX, False),
    )

    for content, suffix, is_pyi in renderings:
        formatted = format_outputs(content, is_pyi)
        target_path = output_file_path + suffix

        with open(target_path, "w", encoding="utf8") as stub_file:
            stub_file.write(formatted)

    logger.info("Wrote stubs to '%s(%s/%s)'.", output_file_path, PYI_SUFFIX, PY_SUFFIX)
|
|
|
|
|
|
def _expand_patterns(root_directory: str, patterns: list[str], recursive: bool) -> set[str]:
    """Resolve glob patterns, taken relative to `root_directory`, to the set of matching paths."""
    matches: set[str] = set()
    for pattern in patterns:
        matches.update(glob.glob(os.path.join(root_directory, pattern), recursive=recursive))
    return matches


def run(args: argparse.Namespace, root_directory: str) -> None:
    """Run the stub generator on a set of paths that point to *.capnp schemas.

    The steps are executed in this order:

    1. Every path matched by the `clean` patterns is deleted.
    2. The `excludes` and `paths` patterns are expanded; excluded matches are dropped from the search result.
    3. All remaining schemas are loaded into a module registry, keyed by their schema node ID.
    4. `generate_stubs` is called for each registered module, writing the outputs next to the schema file.

    Args:
        args (argparse.Namespace): The arguments that were passed when calling the stub generator.
            The attributes `paths`, `excludes` and `clean` (lists of glob patterns, relative to
            `root_directory`) and `recursive` (forwarded to `glob.glob`) are read.
        root_directory (str): The directory, from which the generator is executed.
    """
    # Read all arguments up front, so that a malformed namespace fails before any file is deleted.
    paths: list[str] = args.paths
    excludes: list[str] = args.excludes
    clean: list[str] = args.clean
    recursive = args.recursive

    for cleanup_path in _expand_patterns(root_directory, clean, recursive):
        os.remove(cleanup_path)

    excluded_paths = _expand_patterns(root_directory, excludes, recursive)
    search_paths = _expand_patterns(root_directory, paths, recursive)

    # The `valid_paths` contain the automatically detected search paths, except for specifically excluded paths.
    # They are sorted, because set iteration order is not stable between runs; sorting keeps the
    # generation order (and the winner among schemas that share a node ID) reproducible.
    valid_paths = sorted(search_paths - excluded_paths)

    parser = capnp.SchemaParser()
    module_registry: ModuleRegistryType = {}

    # Load every schema before generating anything, as each `generate_stubs` call receives the full registry.
    for path in valid_paths:
        module = parser.load(path)
        module_registry[module.schema.node.id] = (path, module)

    for path, module in module_registry.values():
        output_directory = os.path.dirname(path)
        output_file_name = replace_capnp_suffix(os.path.basename(path))

        generate_stubs(module, module_registry, os.path.join(output_directory, output_file_name))
|