diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index 41015e2..8084616 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -85,7 +85,8 @@ "execution_count": null, "metadata": { "cellView": "form", - "id": "Ty3IOeXbLzvc" + "id": "Ty3IOeXbLzvc", + "tags": [] }, "outputs": [], "source": [ @@ -122,7 +123,8 @@ "execution_count": null, "metadata": { "cellView": "form", - "id": "kelHR9VM1-hg" + "id": "kelHR9VM1-hg", + "tags": [] }, "outputs": [], "source": [ @@ -165,7 +167,8 @@ "execution_count": null, "metadata": { "cellView": "form", - "id": "fJZtXShcPXx5" + "id": "fJZtXShcPXx5", + "tags": [] }, "outputs": [], "source": [ @@ -570,16 +573,17 @@ "def sample_mcmc_klmc2(\n", " sd_model, \n", " init_image,\n", - " height,\n", - " width,\n", - " n, \n", - " hvp_method='reverse', \n", - " prompts=None,\n", - " settings=None,\n", - " resume = False,\n", - " resume_from=-1,\n", - " img_init_steps=None,\n", - " stuff_to_plot=None,\n", + " height:int,\n", + " width:int,\n", + " n:int, \n", + " hvp_method:str='reverse', \n", + " prompts:list=None,\n", + " settings:ParameterGroup=None,\n", + " resume:bool = False,\n", + " resume_from:int=-1,\n", + " img_init_steps:int=None,\n", + " stuff_to_plot:list=None,\n", + " checkpoint_every:int=10,\n", "):\n", "\n", " if stuff_to_plot is None:\n", @@ -630,6 +634,16 @@ " state = read_klmc2_state(latest_frame=resume_from)\n", " if state:\n", " x, v, i_resume = state['x'], state['v'], state['i']\n", + " # to do: resumption of settings\n", + " settings_i = state['settings_i']\n", + " i = i_resume\n", + " settings[i]['h'] = settings_i['h']\n", + " settings[i]['gamma'] = settings_i['gamma']\n", + " settings[i]['alpha'] = settings_i['alpha']\n", + " settings[i]['tau'] = settings_i['tau']\n", + " settings[i]['g'] = settings_i['g']\n", + " settings[i]['sigma'] = settings_i['sigma']\n", + " settings[i]['steps'] = settings_i['steps']\n", " \n", " # to do: use multicond for init image\n", " # we want this test after resumption if resuming\n", @@ -652,7 +666,16 @@ " # fast-forward loop to resumption index\n", " if resume and i < i_resume:\n", " continue\n", - "\n", + " # if resume and (i == i_resume):\n", + " # # should these values be written into settings[i]?\n", + " # h = settings_i['h']\n", + " # gamma = settings_i['gamma']\n", + " # alpha = settings_i['alpha']\n", + " # tau = settings_i['tau']\n", + " # g = settings_i['g']\n", + " # sigma = settings_i['sigma']\n", + " # steps = settings_i['steps']\n", + " # else:\n", " h = settings[i]['h']\n", " gamma = settings[i]['gamma']\n", " alpha = settings[i]['alpha']\n", @@ -689,9 +712,10 @@ " extra_args,\n", " )\n", "\n", - " save_checkpoint = (i % 10) == 0\n", + " save_checkpoint = (i % checkpoint_every) == 0\n", " if save_checkpoint:\n", - " ex.submit(write_klmc2_state, v=v, x=x, i=i)\n", + " settings_i = settings[i]\n", + " ex.submit(write_klmc2_state, v=v, x=x, i=i, settings_i=settings_i)\n", " logger.debug(settings[i])\n", "\n", "\n", @@ -888,9 +912,6 @@ "metadata": { "cellView": "form", "id": "yt3d1hww17ST", - "jupyter": { - "source_hidden": true - }, "tags": [] }, "outputs": [], @@ -1197,6 +1218,7 @@ "# @markdown `fake` is very fast and low memory but inaccurate. `zero` (fallback to first order KLMC) is not recommended.\n", "hvp_method = 'fake' # @param [\"forward-functorch\", \"reverse\", \"fake\", \"zero\"]\n", "\n", + "checkpoint_every = 10 # @param {type:\"number\"}\n", "\n", "###########################\n", "\n", @@ -1297,10 +1319,10 @@ "curved_settings = ParameterGroup({\n", " 'g':SmoothCurve(g),\n", " 'sigma':SmoothCurve(sigma),\n", - " 'h':SmoothCurve(h),\n", + " #'h':SmoothCurve(h),\n", " \n", " # more concise notation for flowers demo:\n", - " #'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3}, bounce=True),\n", + " 'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3}, bounce=True),\n", " #'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3, 70:0.1, 90:0.1}, loop=True),\n", "\n", " 'gamma':SmoothCurve(gamma),\n", @@ -1324,6 +1346,57 @@ " plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tthag9k67Uey", + "tags": [] + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "srbY3kDa7Uey", + "tags": [] + }, + "outputs": [], + "source": [ + "# load settings from disk\n", + "\n", + "load_settings_from_disk = True # @param {type:'boolean'}\n", + "load_prompts_from_disk = True # @param {type:'boolean'}\n", + "\n", + "\n", + "from numbers import Number\n", + "\n", + "if load_settings_from_disk:\n", + " with open(outdir / 'settings.yaml', 'r') as f:\n", + " curved_settings = keyframed.serialization.from_yaml(f.read())\n", + "\n", + "#curved_settings.plot()\n", + "\n", + "###########################\n", + "\n", + "if load_prompts_from_disk:\n", + " with open(outdir / 'prompts.yaml', 'r') as f:\n", + " prompts_cfg = OmegaConf.load(f)\n", + "\n", + " prompts = []\n", + " for p in prompts_cfg.prompts:\n", + " if isinstance(p.weight, Number):\n", + " weight_curve = SmoothCurve(p.weight)\n", + " else:\n", + " weight_curve = keyframed.serialization.from_yaml(OmegaConf.to_yaml(p.weight))\n", + " P = Prompt(text=p.prompt, weight_schedule=weight_curve)\n", + " prompts.append(P)\n", + " #P.weight.plot()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1391,6 +1464,57 @@ " for p in Path('debug_frames').glob(f'*'):\n", " p.unlink()\n", "\n", + "\n", + "#############################################\n", + "\n", + "# save settings\n", + "\n", + "import keyframed.serialization\n", + "# verbose:\n", + "#txt = keyframed.serialization.to_yaml(curved_settings, simplify=True)\n", + "\n", + "# significantly less verbose:\n", + "simplified_settings = {}\n", + "simplified_settings__curves = {}\n", + "for param, curve in curved_settings.parameters.items():\n", + " kf0 = curve._data[0]\n", + " if (len(curve.keyframes) == 1) and (kf0.interpolation_method in ['eased_lerp', 'previous','linear']):\n", + " simplified_settings[param] = kf0.value\n", + " else:\n", + " d_ = curve.to_dict(simplify=True, for_yaml=True)\n", + " d_.pop('label')\n", + " simplified_settings__curves[param] = d_\n", + "simplified_settings.update(simplified_settings__curves) # move verbose stuff to the bottom\n", + "\n", + "txt = OmegaConf.to_yaml(OmegaConf.create({'parameters':simplified_settings}))\n", + "\n", + "with open(outdir / 'settings.yaml', 'w') as f:\n", + " f.write(txt)\n", + "\n", + "#########################\n", + "\n", + "# save prompts\n", + "\n", + "prompts_out = []\n", + "for prompt in prompts:\n", + " rec = {'prompt':prompt.text}\n", + " curve = prompt.weight\n", + " kf0 = curve._data[0]\n", + " if (len(curve.keyframes) == 1) and (kf0.interpolation_method in ['eased_lerp', 'previous','linear']):\n", + " rec['weight'] = kf0.value\n", + " else:\n", + " rec['weight'] = curve.to_dict(simplify=True, for_yaml=True)\n", + " # don't reorder prompts. order matters, esp first prompt.\n", + " prompts_out.append(rec)\n", + "\n", + "prompts_yaml = OmegaConf.to_yaml(OmegaConf.create({'prompts':prompts_out}))\n", + "\n", + "with open(outdir / 'prompts.yaml', 'w') as f:\n", + " f.write(prompts_yaml)\n", + "\n", + "##########################\n", + "\n", + "\n", "sample_mcmc_klmc2(\n", " sd_model=sd_model,\n", " init_image=init_image,\n", @@ -1404,15 +1528,28 @@ " resume_from=resume_from,\n", " img_init_steps=img_init_steps,\n", " stuff_to_plot=stuff_to_plot,\n", + " checkpoint_every=checkpoint_every,\n", ")\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!rm -rf frames/.ipynb_checkpoints" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", - "id": "DjwY7XrooLX_" + "id": "DjwY7XrooLX_", + "tags": [] }, "outputs": [], "source": [