Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 103 additions & 52 deletions reloading/reloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,23 @@
from functools import partial, update_wrapper


# have to make our own partial in case someone wants to use reloading as a iterator without any arguments
# they would get a partial back because a call without a iterator argument is assumed to be a decorator.
# getting a "TypeError: 'functools.partial' object is not iterable"
# which is not really descriptive.
# hence we overwrite the iter to make sure that the error makes sense.
class no_iter_partial(partial):
class NoIterPartial(partial):
"""
have to make our own partial in case someone wants to use reloading as a iterator without any arguments
they would get a partial back because a call without a iterator argument is assumed to be a decorator.
getting a "TypeError: 'functools.partial' object is not iterable"
which is not really descriptive.
hence we overwrite the iter to make sure that the error makes sense.
"""
def __iter__(self):
raise TypeError("Nothing to iterate over. Please pass an iterable to reloading.")
raise TypeError(
"Nothing to iterate over. Please pass an iterable to reloading."
)


def reloading(fn_or_seq=None, every=1, forever=None):
"""Wraps a loop iterator or decorates a function to reload the source code
"""
Wraps a loop iterator or decorates a function to reload the source code
before every loop iteration or function invocation.

When wrapped around the outermost iterator in a `for` loop, e.g.
Expand All @@ -37,29 +42,39 @@ def reloading(fn_or_seq=None, every=1, forever=None):
every (int, Optional): After how many iterations/invocations to reload
forever (bool, Optional): Pass `forever=true` instead of an iterator to
create an endless loop

"""
if fn_or_seq:
if isinstance(fn_or_seq, types.FunctionType):
return _reloading_function(fn_or_seq, every=every)
return _reloading_loop(fn_or_seq, every=every)
if forever and fn_or_seq is not None:
raise ValueError(
"Cannot use `forever=True` and pass an iterator at the same time"
)
if forever:
return _reloading_loop(iter(int, 1), every=every)

if fn_or_seq:
if isinstance(fn_or_seq, types.FunctionType):
return _reloading_function(fn_or_seq, every=every)
if hasattr(fn_or_seq, "__iter__"):
return _reloading_loop(fn_or_seq, every=every)
raise TypeError(
f"{reloading.__name__} expected function or iterable, got {type(fn_or_seq)}"
)
# return this function with the keyword arguments partialed in,
# so that the return value can be used as a decorator
decorator = update_wrapper(no_iter_partial(reloading, every=every), reloading)
decorator = update_wrapper(NoIterPartial(reloading, every=every), reloading)
return decorator


def unique_name(used):
# get the longest element of the used names and append a "0"
"""
Get the longest element of the used names and append a "0"
"""
return max(used, key=len) + "0"


def format_itervars(ast_node):
"""Formats an `ast_node` of loop iteration variables as string, e.g. 'a, b'"""

"""
Formats an `ast_node` of loop iteration variables as string, e.g. 'a, b'
"""
# handle the case that there only is a single loop var
if isinstance(ast_node, ast.Name):
return ast_node.id
Expand All @@ -78,7 +93,7 @@ def format_itervars(ast_node):
def load_file(path):
src = ""
# while loop here since while saving, the file may sometimes be empty.
while (src == ""):
while src == "":
with open(path, "r") as f:
src = f.read()
return src + "\n"
Expand All @@ -96,19 +111,21 @@ def parse_file_until_successful(path):


def isolate_loop_body_and_get_itervars(tree, lineno, loop_id):
"""Modifies tree inplace as unclear how to create ast.Module.
Returns itervars"""
"""
Modifies tree inplace as unclear how to create ast.Module.
Returns itervars
"""
candidate_nodes = []
for node in ast.walk(tree):
if (
isinstance(node, ast.For)
and isinstance(node.iter, ast.Call)
and node.iter.func.id == "reloading"
and (
(loop_id is not None and loop_id == get_loop_id(node))
or getattr(node, "lineno", None) == lineno
)
):
(loop_id is not None and loop_id == get_loop_id(node))
or getattr(node, "lineno", None) == lineno
)
):
candidate_nodes.append(node)

if len(candidate_nodes) > 1:
Expand All @@ -127,7 +144,8 @@ def isolate_loop_body_and_get_itervars(tree, lineno, loop_id):


def get_loop_id(ast_node):
"""Generates a unique identifier for an `ast_node` of type ast.For to find the loop in the changed source file
"""
Generates a unique identifier for an `ast_node` of type ast.For to find the loop in the changed source file
"""
return ast.dump(ast_node.target) + "__" + ast.dump(ast_node.iter)

Expand All @@ -137,18 +155,39 @@ def get_loop_code(loop_frame_info, loop_id):
while True:
tree = parse_file_until_successful(fpath)
try:
itervars, found_loop_id = isolate_loop_body_and_get_itervars(tree, lineno=loop_frame_info[2], loop_id=loop_id)
return compile(tree, filename="", mode="exec"), format_itervars(itervars), found_loop_id
itervars, found_loop_id = isolate_loop_body_and_get_itervars(
tree, lineno=loop_frame_info[2], loop_id=loop_id
)
return (
compile(tree, filename="", mode="exec"),
format_itervars(itervars),
found_loop_id,
)
except LookupError:
handle_exception(fpath)


def handle_exception(fpath):
exc = traceback.format_exc()
exc = exc.replace('File "<string>"', 'File "{}"'.format(fpath))
exc = exc.replace('File "<string>"', f'File "{fpath}"')
sys.stderr.write(exc + "\n")
print("Edit {} and press return to continue".format(fpath))
sys.stdin.readline()

if sys.stdin.isatty():
print(
f"An error occurred. Please edit the file '{fpath}' to fix the issue and press return to continue or Ctrl+C to exit."
)
try:
sys.stdin.readline()
except KeyboardInterrupt:
print("\nExiting...")
sys.exit(1)
else:
# get error line number
line_number = int(exc.split(", line ")[-1].split(",")[0])
print(line_number)
raise Exception(
f"An error occurred. Please fix the issue in the file '{fpath}' and run the script again."
)


def _reloading_loop(seq, every=1):
Expand All @@ -158,19 +197,21 @@ def _reloading_loop(seq, every=1):
caller_globals = loop_frame_info[0].f_globals
caller_locals = loop_frame_info[0].f_locals

# create a unique name in the caller namespace that we can safely write
# the values of the iteration variables into
unique = unique_name(chain(caller_locals.keys(), caller_globals.keys()))
loop_id = None

for i, itervar_values in enumerate(seq):
if i % every == 0:
compiled_body, itervars, loop_id = get_loop_code(loop_frame_info, loop_id=loop_id)
compiled_body, itervars, loop_id = get_loop_code(
loop_frame_info, loop_id=loop_id
)

caller_locals[unique] = itervar_values
exec(itervars + " = " + unique, caller_globals, caller_locals)
print(itervars)
try:
# run main loop body
# print(f"{caller_locals.keys()}")
exec(compiled_body, caller_globals, caller_locals)
except Exception:
handle_exception(fpath)
Expand All @@ -191,32 +232,36 @@ def get_decorator_name_or_none(dec_node):

def strip_reloading_decorator(func):
"""Remove the 'reloading' decorator and all decorators before it"""
decorator_names = [get_decorator_name(dec) for dec in func.decorator_list]
decorator_names = [get_decorator_name_or_none(dec) for dec in func.decorator_list]
reloading_idx = decorator_names.index("reloading")
func.decorator_list = func.decorator_list[reloading_idx + 1:]
func.decorator_list = func.decorator_list[reloading_idx + 1 :]


def isolate_function_def(funcname, tree):
def isolate_function_def(qualname, fn, tree):
"""Strip everything but the function definition from the ast in-place.
Also strips the reloading decorator from the function definition"""
length = len(qualname.split("."))
funcname = qualname.split(".")[-1]
classname = qualname.split(".")[length - 2] if length > 1 else None

found = False
for node in ast.walk(tree):
if (
isinstance(node, ast.FunctionDef)
and node.name == funcname
and "reloading" in [
get_decorator_name_or_none(dec)
for dec in node.decorator_list
]
):
strip_reloading_decorator(node)
tree.body = [ node ]
return True
return False
if isinstance(node, ast.ClassDef) and node.name == classname:
for subnode in node.body:
if isinstance(subnode, ast.FunctionDef) and subnode.name == funcname:
if "reloading" in [
get_decorator_name_or_none(dec)
for dec in subnode.decorator_list
]:
strip_reloading_decorator(subnode)
tree.body = [subnode]
found = True
return found


def get_function_def_code(fpath, fn):
tree = parse_file_until_successful(fpath)
found = isolate_function_def(fn.__name__, tree)
found = isolate_function_def(fn.__qualname__, fn, tree)
if not found:
return None
compiled = compile(tree, filename="", mode="exec")
Expand All @@ -243,21 +288,27 @@ def _reloading_function(fn, every=1):

# crutch to use dict as python2 doesn't support nonlocal
state = {
"func": None,
"func": fn,
"reloads": 0,
}

def wrapped(*args, **kwargs):
if state["reloads"] % every == 0:
state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn) or state["func"]
state["func"] = (
get_reloaded_function(caller_globals, caller_locals, fpath, fn)
or state["func"]
)
state["reloads"] += 1
while True:
try:
result = state["func"](*args, **kwargs)
return result
except Exception:
handle_exception(fpath)
state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn) or state["func"]
state["func"] = (
get_reloaded_function(caller_globals, caller_locals, fpath, fn)
or state["func"]
)

caller_locals[fn.__name__] = wrapped
return wrapped