import sys

import cv2
import numpy as np
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageFilter
from PIL import ImageOps


# ==========================================================
# ARGUMENTS
# ==========================================================
# Usage: <script> INPUT OUTPUT [MODE]   (MODE defaults to "2x")
if len(sys.argv) < 3:
    print(f"Usage: {sys.argv[0]} INPUT OUTPUT [MODE]")
    sys.exit(1)

input_path = sys.argv[1]
output_path = sys.argv[2]
scale = sys.argv[3] if len(sys.argv) > 3 else "2x"


# Load image with Pillow (used by the 2x / 4x / hd / default modes).
img = Image.open(input_path)

# cv2.imread (used by the other modes) applies the EXIF orientation tag, but
# Pillow does not, and Pillow's save() below drops EXIF entirely. Without this,
# camera/phone photos would be written out rotated by the Pillow-based modes.
img = ImageOps.exif_transpose(img)

# Normalise to 3-channel RGB (drops alpha / palette) so every enhancer and
# JPEG output works.
if img.mode != "RGB":
    img = img.convert("RGB")

# Size after orientation correction.
w, h = img.size

print("Scale:", scale)
print("Original:", w, h)


# ==========================================================
# SHARED OPENCV HELPERS
# ==========================================================


def _load_bgr(path, error_message=None):
    """Load *path* with OpenCV, or terminate the script with status 1.

    cv2.imread does not raise on a missing or unreadable file; it returns
    None, which would otherwise surface later as an obscure cv2.error.
    On failure, *error_message* (or a default naming *path*) is printed.
    """
    image = cv2.imread(path)
    if image is None:
        if error_message is None:
            error_message = f"[ERROR] Failed to load image: {path}"
        print(error_message)
        sys.exit(1)
    return image


def _save_bgr(path, image):
    """Write *image* with OpenCV, or terminate the script with status 1.

    cv2.imwrite reports some failures (e.g. an unwritable location) by
    returning False instead of raising, so the result must be checked.
    """
    if not cv2.imwrite(path, image):
        print(f"[ERROR] Failed to write output: {path}")
        sys.exit(1)


def _clahe_luminance(image, clip_limit):
    """Apply CLAHE (8x8 tiles) to the L channel of a BGR *image* via LAB."""
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    l_channel, a_channel, b_channel = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(8, 8))
    merged = cv2.merge((clahe.apply(l_channel), a_channel, b_channel))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2BGR)


def _boost_saturation(image, factor):
    """Return BGR *image* with its HSV saturation multiplied by *factor*.

    cv2.multiply saturates, so values above 255 clip instead of wrapping.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    hsv[:, :, 1] = cv2.multiply(hsv[:, :, 1], factor)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)


# ==========================================================
# 2X MODE
# ==========================================================

if scale == "2x":
    # Lanczos 2x upscale, then mild sharpening and tone adjustments.
    img = img.resize((w * 2, h * 2), Image.LANCZOS)
    img = img.filter(ImageFilter.SHARPEN)
    img = ImageEnhance.Contrast(img).enhance(1.2)
    img = ImageEnhance.Color(img).enhance(1.1)
    img = ImageEnhance.Brightness(img).enhance(1.05)
    img = ImageEnhance.Sharpness(img).enhance(1.3)
    img.save(output_path, quality=95, optimize=True)

# ==========================================================
# NOISE REMOVAL
# ==========================================================

elif scale == "noise":
    img = _load_bgr(input_path)

    # Step 1: pre-smoothing with a 5x5 Gaussian for full-body grain suppression.
    img = cv2.GaussianBlur(img, (5, 5), sigmaX=1.2)

    # Step 2: strong non-local-means denoising (h = luma, hColor = chroma).
    img = cv2.fastNlMeansDenoisingColored(
        img,
        None,
        h=18,
        hColor=15,
        templateWindowSize=9,
        searchWindowSize=27,
    )

    # Step 3: edge-aware bilateral pass.
    # The former d=0 / sigmaColor=0 / sigmaSpace=65 was broken: OpenCV clamps
    # sigmaColor<=0 to 1 (only pixels within ~2 levels were averaged, so it
    # barely smoothed) and derives d from sigmaSpace (~197 px window, which is
    # extremely slow). A bounded 9 px window with a moderate colour sigma gives
    # the intended edge-preserving smoothing; tune sigmaColor to taste.
    img = cv2.bilateralFilter(img, d=9, sigmaColor=30, sigmaSpace=30)

    # Step 4: gentle CLAHE contrast recovery on luminance only.
    img = _clahe_luminance(img, clip_limit=1.2)

    # Step 5: light sharpening. Coefficients sum to 1 so brightness is kept.
    # The centre was previously 4.8; 4.2 avoids re-sharpening fabric noise.
    sharpen_kernel = np.array([
        [0, -0.8, 0],
        [-0.8, 4.2, -0.8],
        [0, -0.8, 0],
    ], dtype=np.float32)
    img = cv2.filter2D(img, ddepth=-1, kernel=sharpen_kernel)

    _save_bgr(output_path, img)

# ==========================================================
# 4X MODE
# ==========================================================

elif scale == "4x":

    # NOTE(review): despite the "4x" name this scales by 2.5x (presumably to
    # bound output size / memory). Confirm before changing: callers may depend
    # on the current output dimensions.
    multiplier = 2.5

    new_w = int(w * multiplier)
    new_h = int(h * multiplier)

    img = img.resize((new_w, new_h), Image.LANCZOS)
    img = img.filter(ImageFilter.SHARPEN)
    img = ImageEnhance.Sharpness(img).enhance(2.0)
    img = ImageEnhance.Contrast(img).enhance(1.4)
    img = ImageEnhance.Color(img).enhance(1.2)
    img = ImageEnhance.Brightness(img).enhance(1.08)
    img = img.filter(ImageFilter.DETAIL)
    img.save(output_path, quality=95, optimize=True)

# ==========================================================
# FACE ENHANCE
# ==========================================================

elif scale == "face":
    img_cv = _load_bgr(input_path, "Image load failed")

    # Upscale slightly.
    img_cv = cv2.resize(
        img_cv, None, fx=1.8, fy=1.8, interpolation=cv2.INTER_CUBIC
    )

    # Light whole-image denoise: h=4, hColor=4, template 7, search 21.
    img_cv = cv2.fastNlMeansDenoisingColored(img_cv, None, 4, 4, 7, 21)

    # Subtle sharpening. The centre was 5.2, making the kernel sum 1.2, which
    # brightened every pixel by 20% and pushed highlights into clipping.
    # A centre of 5 keeps the same edge boost with unit gain (brightness kept).
    kernel = np.array([
        [0, -1, 0],
        [-1, 5, -1],
        [0, -1, 0],
    ], dtype=np.float32)
    img_cv = cv2.filter2D(img_cv, -1, kernel)

    # Improve contrast naturally (luminance only).
    img_cv = _clahe_luminance(img_cv, clip_limit=1.8)

    # Smooth very slightly: d=5, sigmaColor=30, sigmaSpace=30.
    img_cv = cv2.bilateralFilter(img_cv, 5, 30, 30)

    _save_bgr(output_path, img_cv)

# ==========================================================
# PORTRAIT
# ==========================================================

elif scale == "portrait":
    img_cv = _load_bgr(input_path)

    img_cv = cv2.resize(
        img_cv, None, fx=1.5, fy=1.5, interpolation=cv2.INTER_CUBIC
    )

    # Smooth slightly: d=7, sigmaColor=50, sigmaSpace=50.
    img_cv = cv2.bilateralFilter(img_cv, 7, 50, 50)

    # Sharpen lightly (kernel sums to 1, so brightness is preserved).
    kernel = np.array([
        [0, -1, 0],
        [-1, 5, -1],
        [0, -1, 0],
    ], dtype=np.float32)
    img_cv = cv2.filter2D(img_cv, -1, kernel)

    # Improve colours: +15% saturation.
    img_cv = _boost_saturation(img_cv, 1.15)

    _save_bgr(output_path, img_cv)

# ==========================================================
# BLUR FIX
# ==========================================================

elif scale == "blurfix":
    img_cv = _load_bgr(input_path)

    img_cv = cv2.resize(
        img_cv, None, fx=2, fy=2, interpolation=cv2.INTER_CUBIC
    )

    # Strong sharpen. The old kernel (centre 6, neighbours -1) summed to 2,
    # which doubled every pixel value and clipped most of the image to white.
    # These weights sum to 1 (brightness kept). The edge boost is 1.25x the
    # portrait kernel's.
    kernel = np.array([
        [0, -1.25, 0],
        [-1.25, 6, -1.25],
        [0, -1.25, 0],
    ], dtype=np.float32)
    img_cv = cv2.filter2D(img_cv, -1, kernel)

    # Denoise: h=5, hColor=5, template 7, search 21.
    img_cv = cv2.fastNlMeansDenoisingColored(img_cv, None, 5, 5, 7, 21)

    _save_bgr(output_path, img_cv)

# ==========================================================
# CARTOON
# ==========================================================

elif scale == "cartoon":
    img_cv = _load_bgr(input_path)

    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    gray = cv2.medianBlur(gray, 5)

    # Edge mask: 255 on flat areas, 0 on dark outlines (block 9, C=9).
    edges = cv2.adaptiveThreshold(
        gray,
        255,
        cv2.ADAPTIVE_THRESH_MEAN_C,
        cv2.THRESH_BINARY,
        9,
        9,
    )

    # Flatten colours strongly: d=9, sigmaColor=250, sigmaSpace=250.
    color = cv2.bilateralFilter(img_cv, 9, 250, 250)

    # Keep colour where the mask is set; outlines become black.
    cartoon = cv2.bitwise_and(color, color, mask=edges)

    _save_bgr(output_path, cartoon)

# ==========================================================
# ANIME
# ==========================================================

elif scale == "anime":
    img_cv = _load_bgr(input_path)

    # Smooth colours with two bilateral passes (d=9, sigma 75/75).
    for _ in range(2):
        img_cv = cv2.bilateralFilter(img_cv, 9, 75, 75)

    # Edge mask from the smoothed image (block 9, C=2).
    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    blur = cv2.medianBlur(gray, 5)
    edges = cv2.adaptiveThreshold(
        blur,
        255,
        cv2.ADAPTIVE_THRESH_MEAN_C,
        cv2.THRESH_BINARY,
        9,
        2,
    )

    anime = cv2.bitwise_and(img_cv, img_cv, mask=edges)

    # Boost saturation by 40%.
    anime = _boost_saturation(anime, 1.4)

    _save_bgr(output_path, anime)

# ==========================================================
# HD
# ==========================================================

elif scale == "hd":
    img = img.resize((w * 2, h * 2), Image.LANCZOS)
    img = ImageEnhance.Sharpness(img).enhance(2.2)
    img = ImageEnhance.Contrast(img).enhance(1.35)
    img = ImageEnhance.Color(img).enhance(1.25)
    img = img.filter(ImageFilter.DETAIL)
    img.save(output_path, quality=98, optimize=True)

# ==========================================================
# DEFAULT
# ==========================================================

else:
    # Unrecognised mode: write the loaded (RGB-converted) image back out.
    # NOTE(review): confirm a silent pass-through is preferred over an error.
    img.save(output_path)


print("DONE")