EVOLUTION-MANAGER
Edit File: ColorFIDBenchmarkArtistic.ipynb
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Color FID Benchmark (HQ)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
    "os.environ['OMP_NUM_THREADS']='1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import statistics\n",
    "from fastai import *\n",
    "from deoldify.visualize import *\n",
    "import cv2\n",
    "from fid.fid_score import *\n",
    "from fid.inception import *\n",
    "import imageio\n",
    "plt.style.use('dark_background')\n",
    "torch.backends.cudnn.benchmark=True\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=UserWarning, module=\"torch.nn.functional\")\n",
    "warnings.filterwarnings(\"ignore\", category=UserWarning, message='.*?retrieve source code for container of type.*?')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#NOTE: Data should come from here: 'https://datasets.figure-eight.com/figure_eight_datasets/open-images/test_challenge.zip'\n",
    "#NOTE: Minimum recommended number of samples is 10K.  Source: https://github.com/bioinf-jku/TTUR\n",
    "\n",
    "path = Path('data/ColorBenchmark')\n",
    "path_hr = path/'source'\n",
    "path_lr = path/'bandw'\n",
    "path_results = Path('./result_images/ColorBenchmarkFID/artistic')\n",
    "path_rendered = path_results/'rendered'\n",
    "\n",
    "#path = Path('data/DeOldifyColor')\n",
    "#path_hr = path\n",
    "#path_lr = path/'bandw'\n",
    "#path_results = Path('./result_images/ColorBenchmark/edge')\n",
    "#path_rendered = path_results/'rendered'\n",
    "\n",
    "#num_images = 2048\n",
    "#num_images = 15000\n",
    "num_images = 50000\n",
    "render_factor=35\n",
    "fid_batch_size = 4\n",
    "eval_size=299"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inception_model(dims:int):\n",
    "    \"\"\"Build an InceptionV3 feature extractor and move it to the GPU.\n",
    "\n",
    "    dims: feature dimensionality, used as the key into\n",
    "          InceptionV3.BLOCK_INDEX_BY_DIM to pick the output block.\n",
    "    Returns the model after calling .cuda() on it.\n",
    "    \"\"\"\n",
    "    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]\n",
    "    model = InceptionV3([block_idx])\n",
    "    model.cuda()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_before_images(fn,i):\n",
    "    \"\"\"Save a greyscale copy of `fn` under `path_lr`, mirroring its location relative to `path_hr`.\n",
    "\n",
    "    fn: path of a source image located inside `path_hr`.\n",
    "    i:  unused; accepted because this function is handed to `parallel`,\n",
    "        which is assumed to pass an item index as well (TODO confirm).\n",
    "    \"\"\"\n",
    "    dest = path_lr/fn.relative_to(path_hr)\n",
    "    dest.parent.mkdir(parents=True, exist_ok=True)\n",
    "    # Convert to 'LA' to drop the colour information, then back to 'RGB'\n",
    "    # so the saved file still has three channels.\n",
    "    with PIL.Image.open(fn) as src:\n",
    "        img = src.convert('LA').convert('RGB')\n",
    "    img.save(dest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def render_images(colorizer, source_dir:Path, filtered_dir:Path, target_dir:Path, render_factor:int, num_images:int)->[(Path, Path, Path)]:\n",
    "    \"\"\"Colorize up to `num_images` images from `source_dir` and save them under `filtered_dir`.\n",
    "\n",
    "    colorizer:     object exposing get_transformed_image(path=..., render_factor=...).\n",
    "    source_dir:    folder holding the black and white inputs.\n",
    "    filtered_dir:  folder the colorized images are written to, as\n",
    "                   <filtered_dir>/<input parent folder name>/<input file name>.\n",
    "    target_dir:    folder holding the original colour images; each input's path\n",
    "                   relative to `source_dir` is re-rooted here.\n",
    "    render_factor: forwarded to the colorizer.\n",
    "    num_images:    maximum number of inputs to process.\n",
    "\n",
    "    Returns a list of (rendered_path, bandw_path, target_path) tuples, one per\n",
    "    image that rendered successfully. Failures are printed and skipped.\n",
    "    \"\"\"\n",
    "    results = []\n",
    "    # Read from the directory the caller passed in rather than a notebook global.\n",
    "    bandw_list = ImageList.from_folder(source_dir)\n",
    "    bandw_list = bandw_list[:num_images]\n",
    "\n",
    "    if len(bandw_list.items) == 0: return results\n",
    "\n",
    "    output_root = Path(filtered_dir)\n",
    "\n",
    "    for bandw_path in progress_bar(bandw_list.items):\n",
    "        target_path = target_dir/bandw_path.relative_to(source_dir)\n",
    "\n",
    "        try:\n",
    "            result_image = colorizer.get_transformed_image(path=bandw_path, render_factor=render_factor)\n",
    "            result_path = output_root/bandw_path.parent.name/bandw_path.name\n",
    "            result_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "            result_image.save(result_path)\n",
    "            results.append((result_path, bandw_path, target_path))\n",
    "        except Exception as err:\n",
    "            # Best effort: one bad image should not abort the whole benchmark run.\n",
    "            print('Failed to render image. Skipping. Details: {0}'.format(err))\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_fid_score(render_results, bs:int, eval_size:int):\n",
    "    \"\"\"Compute the Frechet Inception Distance between rendered images and their targets.\n",
    "\n",
    "    render_results: (rendered_path, bandw_path, target_path) tuples, as returned by render_images.\n",
    "    bs:             batch size forwarded to calculate_activation_statistics.\n",
    "    eval_size:      currently unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns the value produced by calculate_frechet_distance.\n",
    "    Raises ValueError when `render_results` is empty.\n",
    "    \"\"\"\n",
    "    render_results = list(render_results)\n",
    "    # Fail early, before the Inception model is loaded onto the GPU.\n",
    "    if not render_results:\n",
    "        raise ValueError('No rendered images to score.')\n",
    "\n",
    "    dims = 2048\n",
    "    cuda = True\n",
    "    rendered_paths = [str(rendered_path) for rendered_path, _, _ in render_results]\n",
    "    target_paths = [str(target_path) for _, _, target_path in render_results]\n",
    "\n",
    "    model = inception_model(dims=dims)\n",
    "    rendered_m, rendered_s = calculate_activation_statistics(files=rendered_paths, model=model, batch_size=bs, dims=dims, cuda=cuda)\n",
    "    target_m, target_s = calculate_activation_statistics(files=target_paths, model=model, batch_size=bs, dims=dims, cuda=cuda)\n",
    "    fid_score = calculate_frechet_distance(rendered_m, rendered_s, target_m, target_s)\n",
    "    del model\n",
    "    return fid_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create black and white source images"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Only runs if the directory isn't already created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not path_lr.exists():\n",
    "    il = ImageList.from_folder(path_hr)\n",
    "    parallel(create_before_images, il.items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_results.parent.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Rendering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "colorizer = get_image_colorizer(artistic=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "render_results = render_images(colorizer=colorizer, source_dir=path_lr, target_dir=path_hr, filtered_dir=path_results, render_factor=render_factor, num_images=num_images)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Colorization Scoring"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fid_score = calculate_fid_score(render_results, bs=fid_batch_size, eval_size=eval_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('FID Score: ' + str(fid_score))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}