diff --git a/ratinabox/Agent.py b/ratinabox/Agent.py index 427da822..323f29fb 100644 --- a/ratinabox/Agent.py +++ b/ratinabox/Agent.py @@ -843,6 +843,7 @@ def animate_trajectory( t_end=None, fps=15, speed_up=5, #by default the animation is 5x faster than real time + hold=0, # in seconds progress_bar=False, autosave=None, **kwargs @@ -853,6 +854,7 @@ def animate_trajectory( t_end (_type_, optional): _description_. Defaults to None. fps: frames per second of end video speed_up: #times real speed animation should come out at + hold (float): time to hold the final frame for in seconds. Defaults to 0. progress_bar (bool): if True, a progress bar will be shown as the animation is created. Defaults to False. autosave (bool): whether to automatical try and save this. Defaults to None in which case looks for global constant ratinabox.autosave_plots kwargs: passed to trajectory plotting function (chuck anything you wish in here). A particularly useful kwarg is 'additional_plot_func': any function which takes a fig, ax and t as input. The animation wll be passed through this each time after plotting the trajectory, use it to modify your animations however you like @@ -898,9 +900,16 @@ def animate_(i, fig, ax, t_start, t_max, speed_up, dt, kwargs): ) frames = int((t_end - t_start) / (dt * speed_up)) + + if hold: + hold_n = int(hold * fps) + frames = np.concatenate([np.arange(frames), np.repeat(frames - 1, hold_n)]) + else: + frames = range(frames) + if progress_bar: from tqdm import tqdm - frames = tqdm(range(frames), position=0, leave=True) + frames = tqdm(frames, position=0, leave=True) from matplotlib import animation diff --git a/ratinabox/Neurons.py b/ratinabox/Neurons.py index b3f7c8a5..d1f9c737 100644 --- a/ratinabox/Neurons.py +++ b/ratinabox/Neurons.py @@ -698,6 +698,7 @@ def animate_rate_timeseries( chosen_neurons="all", fps=15, speed_up=1, + hold=0, # in seconds progress_bar=False, autosave=None, **kwargs, @@ -713,6 +714,7 @@ def animate_rate_timeseries( • chosen_neurons: Which neurons to plot. string "10" or 10 will plot ten of them, "all" will plot all of them, "12rand" will plot 12 random ones. A list like [1,4,5] will plot cells indexed 1, 4 and 5. Defaults to "all". • fps: frames per second of end video. Defaults to 15. • speed_up: #times real speed animation should come out at. Defaults to 1. + • hold: time to hold the final frame for in seconds. Defaults to 0. • progress_bar: if True, a progress bar will be shown as the animation is created. Default to False. @@ -757,9 +759,16 @@ def animate_(i, fig, ax, chosen_neurons, t_start, t_max, dt, speed_up): ) frames = int((t_end - t_start) / (dt * speed_up)) + + if hold: + hold_n = int(hold * fps) + frames = np.concatenate([np.arange(frames), np.repeat(frames - 1, hold_n)]) + else: + frames = range(frames) + if progress_bar: from tqdm import tqdm - frames = tqdm(range(frames), position=0, leave=True) + frames = tqdm(frames, position=0, leave=True) from matplotlib import animation