You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
193 lines
5.8 KiB
193 lines
5.8 KiB
2 years ago
|
#! /usr/bin/env python3
|
||
|
# -*- coding: utf-8 -*-
|
||
|
"""
|
||
|
Code generation script for class methods
|
||
|
to be exported as public API
|
||
|
"""
|
||
|
import argparse
|
||
|
import ast
|
||
|
import astor
|
||
|
import os
|
||
|
from pathlib import Path
|
||
|
import sys
|
||
|
|
||
|
from textwrap import indent
|
||
|
|
||
|
# Prepended to each scanned file's basename to name its generated companion
# module (process() builds PREFIX + basename, e.g. "_run.py" ->
# "_generated_run.py").
PREFIX = "_generated"

# Emitted at the top of every generated module: a do-not-edit banner plus the
# imports the generated wrappers use. (_NO_SEND / Instrument are presumably
# referenced by defaults or annotations copied from the original signatures —
# verify before pruning.)
HEADER = """# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument

# fmt: off
"""

# Emitted at the end of every generated module; closes the "# fmt: off"
# region opened at the end of HEADER.
FOOTER = """# fmt: on
"""

# Body of each generated wrapper. str.format placeholders, in order:
#   1. " await " for async methods, " " otherwise (the gap after "return")
#   2. attribute path under GLOBAL_RUN_CONTEXT (e.g. "runner.io_manager")
#   3. method name followed by its pass-through argument list
# The return/raise lines must be indented under try/except *inside* the
# string; the generator then indents the whole template one more level to
# make it the body of the emitted ``def``.
# NOTE(review): the try also covers the forwarded call itself, so an
# AttributeError raised inside the target method is reported as
# "must be called from async context" — confirm that is acceptable.
TEMPLATE = """locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
    return{}GLOBAL_RUN_CONTEXT.{}.{}
except AttributeError:
    raise RuntimeError("must be called from async context")
"""
|
||
|
|
||
|
|
||
|
def is_function(node):
    """Check whether an AST node is a function definition.

    Args:
      node: any ``ast.AST`` node.

    Returns:
      True for ``def`` (``ast.FunctionDef``) and ``async def``
      (``ast.AsyncFunctionDef``) nodes, False for every other node.
    """
    return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
|
||
|
|
||
|
|
||
|
def is_public(node):
    """Tell whether *node* is a (possibly async) function carrying a bare
    ``@_public`` decorator.

    Only the plain-name form ``@_public`` is recognised; attribute or call
    forms such as ``@x._public`` or ``@_public()`` are not.
    """
    if is_function(node):
        return any(
            isinstance(deco, ast.Name) and deco.id == "_public"
            for deco in node.decorator_list
        )
    return False
|
||
|
|
||
|
|
||
|
def get_public_methods(tree):
    """Yield every function node in *tree* that is marked ``@_public``.

    This is a generator rather than a list: nodes are visited in
    ``ast.walk`` order and each one accepted by ``is_public`` is produced
    lazily, so callers may mutate a yielded node before asking for the next.
    """
    yield from filter(is_public, ast.walk(tree))
|
||
|
|
||
|
|
||
|
def create_passthrough_args(funcdef):
    """Given a function definition, create a string that represents taking all
    the arguments from the function, and passing them through to another
    invocation of the same function.

    Positional-only and ordinary positional parameters are forwarded
    positionally, ``*args`` / ``**kwargs`` are re-splatted, and keyword-only
    parameters are forwarded as ``name=name``.

    Example input: ast.parse("def f(a, *, b): ...").body[0]
    Example output: "(a, b=b)"

    Args:
      funcdef: an ``ast.FunctionDef`` or ``ast.AsyncFunctionDef`` node.

    Returns:
      The parenthesised argument list as a string, e.g. ``"(a, *rest, b=b)"``.
    """
    args = funcdef.args
    # Positional-only parameters (PEP 570, Python 3.8+) come first and must
    # be passed positionally; getattr keeps this working on ASTs from older
    # Pythons that do not have the field.
    call_args = [arg.arg for arg in getattr(args, "posonlyargs", [])]
    call_args.extend(arg.arg for arg in args.args)
    if args.vararg:
        call_args.append("*" + args.vararg.arg)
    # Keyword-only parameters cannot be passed positionally.
    call_args.extend(arg.arg + "=" + arg.arg for arg in args.kwonlyargs)
    if args.kwarg:
        call_args.append("**" + args.kwarg.arg)
    return "({})".format(", ".join(call_args))
|
||
|
|
||
|
|
||
|
def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
    """Scan the given .py file for @_public decorators, and generate wrapper
    functions.

    Every ``@_public`` method found in *source_path* becomes a module-level
    function with the same signature minus ``self`` (keeping its docstring,
    if any), whose body is TEMPLATE: it forwards the call to
    ``GLOBAL_RUN_CONTEXT.<lookup_path>.<name>(...)``, awaiting it when the
    method is ``async def``.

    Args:
      source_path: the Python source file to scan.
      lookup_path: dotted attribute path under ``GLOBAL_RUN_CONTEXT`` that
        the wrappers dispatch through (e.g. ``"runner.io_manager"``).

    Returns:
      The text of the generated module: HEADER, one snippet per public
      method, then FOOTER, joined with ``"\\n\\n"``.

    Raises:
      ValueError: if a ``@_public`` function does not take ``self`` as its
        first (non-positional-only) positional parameter.
    """
    generated = [HEADER]
    source = astor.code_to_ast.parse_file(source_path)
    for method in get_public_methods(source):
        # Remove self from arguments. This is an explicit check rather than an
        # assert: under ``python -O`` an assert is stripped and the ``del``
        # below would silently drop the first *real* parameter instead.
        positional = method.args.args
        if not positional or positional[0].arg != "self":
            raise ValueError(
                "{}: @_public function {!r} must take 'self' as its first "
                "positional parameter".format(source_path, method.name)
            )
        del positional[0]

        # Remove decorators so the emitted wrapper is a plain function.
        method.decorator_list = []

        # Create pass-through arguments (after self has been dropped).
        new_args = create_passthrough_args(method)

        # Remove the method body, keeping only the docstring if there is one.
        if ast.get_docstring(method) is None:
            del method.body[:]
        else:
            # ast.get_docstring only recognises a docstring in body[0], so
            # that is the single entry to keep.
            del method.body[1:]

        # Render the signature (and docstring, if kept) back to source.
        func = astor.to_source(method, indent_with=" " * 4)

        # Create export function body
        template = TEMPLATE.format(
            " await " if isinstance(method, ast.AsyncFunctionDef) else " ",
            lookup_path,
            method.name + new_args,
        )

        # Indent the template one level so it becomes the body of ``func``.
        snippet = func + indent(template, " " * 4)

        # Append the snippet to the corresponding module
        generated.append(snippet)
    generated.append(FOOTER)
    return "\n\n".join(generated)
|
||
|
|
||
|
|
||
|
def matches_disk_files(new_files):
    """Report whether every generated file is already on disk, unchanged.

    Args:
      new_files: mapping of output path -> expected file text.

    Returns:
      True when each path exists and its UTF-8 contents equal the expected
      text (vacuously True for an empty mapping); False as soon as one path
      is missing or differs.
    """

    def _is_current(path, expected):
        # A missing file is simply "out of date", not an error.
        if not os.path.exists(path):
            return False
        with open(path, "r", encoding="utf-8") as fh:
            return fh.read() == expected

    # all() short-circuits in insertion order, like an explicit early return.
    return all(_is_current(path, text) for path, text in new_files.items())
|
||
|
|
||
|
|
||
|
def process(sources_and_lookups, *, do_test):
    """Generate the wrapper module for each (source file, lookup path) pair.

    Each output lives next to its source and is named PREFIX + the source's
    basename. Without ``do_test`` every output is (over)written. With
    ``do_test`` nothing is written: the outputs are only compared against
    the files on disk, and the process exits with status 1 if any is
    missing or different.
    """
    generated = {}
    for src, lookup in sources_and_lookups:
        print("Scanning:", src)
        text = gen_public_wrappers_source(src, lookup)
        folder, name = os.path.split(src)
        generated[os.path.join(folder, PREFIX + name)] = text

    if not do_test:
        for target, text in generated.items():
            with open(target, "w", encoding="utf-8") as out:
                out.write(text)
        print("Regenerated sources successfully.")
        return

    if matches_disk_files(generated):
        print("Generated sources are up to date.")
    else:
        print("Generated sources are outdated. Please regenerate.")
        sys.exit(1)
|
||
|
|
||
|
|
||
|
# This is in fact run in CI, but only in the formatting check job, which
# doesn't collect coverage.
def main():  # pragma: no cover
    """Command-line entry point: regenerate (or, with --test, verify) the
    public-API wrapper modules for the files under ``trio/_core``.

    Expects the current working directory to be the project root, which is
    checked by looking for a LICENSE file there.
    """
    parser = argparse.ArgumentParser(
        description="Generate python code for public api wrappers"
    )
    parser.add_argument(
        "--test", "-t", action="store_true", help="test if code is still up to date"
    )
    parsed_args = parser.parse_args()

    source_root = Path.cwd()
    # Double-check we found the right directory. This is a real check rather
    # than an assert so it still runs under ``python -O`` and tells the user
    # what went wrong (parser.error prints usage and exits with status 2).
    if not (source_root / "LICENSE").exists():
        parser.error(
            "must be run from the project root "
            "(no LICENSE file found in {})".format(source_root)
        )
    core = source_root / "trio/_core"
    # (source file to scan, attribute path under GLOBAL_RUN_CONTEXT)
    to_wrap = [
        (core / "_run.py", "runner"),
        (core / "_instrumentation.py", "runner.instruments"),
        (core / "_io_windows.py", "runner.io_manager"),
        (core / "_io_epoll.py", "runner.io_manager"),
        (core / "_io_kqueue.py", "runner.io_manager"),
    ]

    process(to_wrap, do_test=parsed_args.test)
|
||
|
|
||
|
|
||
|
# Run the generator only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":  # pragma: no cover
    main()
|