Skip to content
181 changes: 159 additions & 22 deletions Stable_Diffusion_KLMC2_Animation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "Ty3IOeXbLzvc"
"id": "Ty3IOeXbLzvc",
"tags": []
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -122,7 +123,8 @@
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "kelHR9VM1-hg"
"id": "kelHR9VM1-hg",
"tags": []
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -165,7 +167,8 @@
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "fJZtXShcPXx5"
"id": "fJZtXShcPXx5",
"tags": []
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -570,16 +573,17 @@
"def sample_mcmc_klmc2(\n",
" sd_model, \n",
" init_image,\n",
" height,\n",
" width,\n",
" n, \n",
" hvp_method='reverse', \n",
" prompts=None,\n",
" settings=None,\n",
" resume = False,\n",
" resume_from=-1,\n",
" img_init_steps=None,\n",
" stuff_to_plot=None,\n",
" height:int,\n",
" width:int,\n",
" n:int, \n",
" hvp_method:str='reverse', \n",
" prompts:list=None,\n",
" settings:ParameterGroup=None,\n",
" resume:bool = False,\n",
" resume_from:int=-1,\n",
" img_init_steps:int=None,\n",
" stuff_to_plot:list=None,\n",
" checkpoint_every:int=10,\n",
"):\n",
"\n",
" if stuff_to_plot is None:\n",
Expand Down Expand Up @@ -630,6 +634,16 @@
" state = read_klmc2_state(latest_frame=resume_from)\n",
" if state:\n",
" x, v, i_resume = state['x'], state['v'], state['i']\n",
" # to do: resumption of settings\n",
" settings_i = state['settings_i']\n",
" i = i_resume\n",
" settings[i]['h'] = settings_i['h']\n",
" settings[i]['gamma'] = settings_i['gamma']\n",
" settings[i]['alpha'] = settings_i['alpha']\n",
" settings[i]['tau'] = settings_i['tau']\n",
" settings[i]['g'] = settings_i['g']\n",
" settings[i]['sigma'] = settings_i['sigma']\n",
" settings[i]['steps'] = settings_i['steps']\n",
" \n",
" # to do: use multicond for init image\n",
" # we want this test after resumption if resuming\n",
Expand All @@ -652,7 +666,16 @@
" # fast-forward loop to resumption index\n",
" if resume and i < i_resume:\n",
" continue\n",
"\n",
" # if resume and (i == i_resume):\n",
" # # should these values be written into settings[i]?\n",
" # h = settings_i['h']\n",
" # gamma = settings_i['gamma']\n",
" # alpha = settings_i['alpha']\n",
" # tau = settings_i['tau']\n",
" # g = settings_i['g']\n",
" # sigma = settings_i['sigma']\n",
" # steps = settings_i['steps']\n",
" # else:\n",
" h = settings[i]['h']\n",
" gamma = settings[i]['gamma']\n",
" alpha = settings[i]['alpha']\n",
Expand Down Expand Up @@ -689,9 +712,10 @@
" extra_args,\n",
" )\n",
"\n",
" save_checkpoint = (i % 10) == 0\n",
" save_checkpoint = (i % checkpoint_every) == 0\n",
" if save_checkpoint:\n",
" ex.submit(write_klmc2_state, v=v, x=x, i=i)\n",
" settings_i = settings[i]\n",
" ex.submit(write_klmc2_state, v=v, x=x, i=i, settings_i=settings_i)\n",
" logger.debug(settings[i])\n",
"\n",
"\n",
Expand Down Expand Up @@ -888,9 +912,6 @@
"metadata": {
"cellView": "form",
"id": "yt3d1hww17ST",
"jupyter": {
"source_hidden": true
},
"tags": []
},
"outputs": [],
Expand Down Expand Up @@ -1197,6 +1218,7 @@
"# @markdown `fake` is very fast and low memory but inaccurate. `zero` (fallback to first order KLMC) is not recommended.</small>\n",
"hvp_method = 'fake' # @param [\"forward-functorch\", \"reverse\", \"fake\", \"zero\"]\n",
"\n",
"checkpoint_every = 10 # @param {type:\"number\"}\n",
"\n",
"###########################\n",
"\n",
Expand Down Expand Up @@ -1297,10 +1319,10 @@
"curved_settings = ParameterGroup({\n",
" 'g':SmoothCurve(g),\n",
" 'sigma':SmoothCurve(sigma),\n",
" 'h':SmoothCurve(h),\n",
" #'h':SmoothCurve(h),\n",
" \n",
" # more concise notation for flowers demo:\n",
" #'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3}, bounce=True),\n",
" 'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3}, bounce=True),\n",
" #'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3, 70:0.1, 90:0.1}, loop=True),\n",
"\n",
" 'gamma':SmoothCurve(gamma),\n",
Expand All @@ -1324,6 +1346,57 @@
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tthag9k67Uey",
"tags": []
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "srbY3kDa7Uey",
"tags": []
},
"outputs": [],
"source": [
"# load settings from disk\n",
"\n",
"load_settings_from_disk = True # @param {type:'boolean'}\n",
"load_prompts_from_disk = True # @param {type:'boolean'}\n",
"\n",
"\n",
"from numbers import Number\n",
"\n",
"if load_settings_from_disk:\n",
" with open(outdir / 'settings.yaml', 'r') as f:\n",
" curved_settings = keyframed.serialization.from_yaml(f.read())\n",
"\n",
"#curved_settings.plot()\n",
"\n",
"###########################\n",
"\n",
"if load_prompts_from_disk:\n",
" with open(outdir / 'prompts.yaml', 'r') as f:\n",
" prompts_cfg = OmegaConf.load(f)\n",
"\n",
" prompts = []\n",
" for p in prompts_cfg.prompts:\n",
" if isinstance(p.weight, Number):\n",
" weight_curve = SmoothCurve(p.weight)\n",
" else:\n",
" weight_curve = keyframed.serialization.from_yaml(OmegaConf.to_yaml(p.weight))\n",
" P = Prompt(text=p.prompt, weight_schedule=weight_curve)\n",
" prompts.append(P)\n",
" #P.weight.plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -1391,6 +1464,57 @@
" for p in Path('debug_frames').glob(f'*'):\n",
" p.unlink()\n",
"\n",
"\n",
"#############################################\n",
"\n",
"# save settings\n",
"\n",
"import keyframed.serialization\n",
"# verbose:\n",
"#txt = keyframed.serialization.to_yaml(curved_settings, simplify=True)\n",
"\n",
"# significantly less verbose:\n",
"simplified_settings = {}\n",
"simplified_settings__curves = {}\n",
"for param, curve in curved_settings.parameters.items():\n",
" kf0 = curve._data[0]\n",
" if (len(curve.keyframes) == 1) and (kf0.interpolation_method in ['eased_lerp', 'previous','linear']):\n",
" simplified_settings[param] = kf0.value\n",
" else:\n",
" d_ = curve.to_dict(simplify=True, for_yaml=True)\n",
" d_.pop('label')\n",
" simplified_settings__curves[param] = d_\n",
"simplified_settings.update(simplified_settings__curves) # move verbose stuff to the bottom\n",
"\n",
"txt = OmegaConf.to_yaml(OmegaConf.create({'parameters':simplified_settings}))\n",
"\n",
"with open(outdir / 'settings.yaml', 'w') as f:\n",
" f.write(txt)\n",
"\n",
"#########################\n",
"\n",
"# save prompts\n",
"\n",
"prompts_out = []\n",
"for prompt in prompts:\n",
" rec = {'prompt':prompt.text}\n",
" curve = prompt.weight\n",
" kf0 = curve._data[0]\n",
" if (len(curve.keyframes) == 1) and (kf0.interpolation_method in ['eased_lerp', 'previous','linear']):\n",
" rec['weight'] = kf0.value\n",
" else:\n",
" rec['weight'] = curve.to_dict(simplify=True, for_yaml=True)\n",
" # don't reorder prompts. order matters, esp first prompt.\n",
" prompts_out.append(rec)\n",
"\n",
"prompts_yaml = OmegaConf.to_yaml(OmegaConf.create({'prompts':prompts_out}))\n",
"\n",
"with open(outdir / 'prompts.yaml', 'w') as f:\n",
" f.write(prompts_yaml)\n",
"\n",
"##########################\n",
"\n",
"\n",
"sample_mcmc_klmc2(\n",
" sd_model=sd_model,\n",
" init_image=init_image,\n",
Expand All @@ -1404,15 +1528,28 @@
" resume_from=resume_from,\n",
" img_init_steps=img_init_steps,\n",
" stuff_to_plot=stuff_to_plot,\n",
" checkpoint_every=checkpoint_every,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!rm -rf frames/.ipynb_checkpoints"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "DjwY7XrooLX_"
"id": "DjwY7XrooLX_",
"tags": []
},
"outputs": [],
"source": [
Expand Down