## ----setup, include = FALSE---------------------------------------------------
# Global chunk options: collapsed "#>"-prefixed output, a fixed centred
# figure size, and warnings hidden from the rendered document.
knitr::opts_chunk$set(
    comment = "#>",
    collapse = TRUE,
    warning = FALSE,
    fig.align = "center",
    fig.width = 7,
    fig.height = 5.5
)
# Fixed seed so the randomized initializations below are reproducible.
set.seed(42)
# No box around plot regions.
par(bty = "n")
library(yaap)

## ----quick-recommendation, eval = FALSE---------------------------------------
# library(yaap)
# 
# toy <- read.csv(system.file("extdata", "toy.csv", package = "yaap"))
# 
# fit <- run_aa(toy, K = 3, nrep = 5,  init = "furthest_sum")
# fit <- run_aa(toy, K = 3, nrep = 5,  init = "kmeans_pp")
# fit <- run_aa(toy, K = 3, nrep = 20, init = "random")
# fit <- run_aa(toy, K = 3, nrep = 3,  init = "aa_pp")

## ----aa-init-basic------------------------------------------------------------
# Load the toy data set shipped with the package as a matrix.
toy <- read.csv(system.file("extdata", "toy.csv", package = "yaap"))
K <- 3
X <- as.matrix(toy) # 250 × 2

# One furthest-sum initialization; show the structure of the result.
init <- aa_init(X, K = K, method = "furthest_sum")
str(init)

## ----initialization-helpers, include = FALSE----------------------------------
# Overlay a set of archetypes on the current plot.
#
# A      : matrix with one archetype per row (assumed 2-D coordinates).
# border : colour used for the outline, the segment and the markers.
# fill   : optional fill colour for the hull polygon.
#
# With three or more rows the convex hull of the rows is drawn as a polygon;
# with exactly two rows a single segment joins them. Every row is also
# marked with a point symbol (`pch`, `cex`).
draw_simplex <- function(A, border, fill = NULL, lty = 1, lwd = 2,
                         pch = 17, cex = 1.1) {
    n_vertices <- nrow(A)
    if (n_vertices == 2) {
        lines(A, col = border, lwd = lwd, lty = lty)
    } else if (n_vertices >= 3) {
        outline <- A[chull(A), , drop = FALSE]
        polygon(outline, border = border, col = fill, lwd = lwd, lty = lty)
    }
    points(A, pch = pch, cex = cex, col = border)
}

# Score how much each point's selection count deviates from a uniform null.
#
# `counts` are tallies over N = sum(counts) / K initialization runs, each of
# which picks K of the n points. Under the null every count is therefore
# Binomial(N, K / n) with mean sum(counts) / n.
#
# Returns a list with
#   log_fc       - log2((count + pseudocount) / (expected + pseudocount));
#   p_value      - two-sided binomial p-value (twice the smaller tail,
#                  capped at 1);
#   signed_log_p - -log10(p_value), positive when the count is above the
#                  expectation, negative when below, 0 when equal.
selection_deviation_score <- function(counts, K, n, pseudocount = 0.5) {
    n_total <- sum(counts)
    n_trials <- n_total / K
    mu <- n_total / n
    null_prob <- K / n

    fold <- log2((counts + pseudocount) / (mu + pseudocount))

    # P(X <= count) and P(X >= count) under the binomial null.
    tail_low <- pbinom(counts, size = n_trials, prob = null_prob)
    tail_high <- pbinom(counts - 1, size = n_trials, prob = null_prob,
        lower.tail = FALSE
    )
    two_sided <- pmin(1, 2 * pmin(tail_low, tail_high))

    # Floor at the smallest positive double so -log10 never yields Inf.
    strength <- -log10(pmax(two_sided, .Machine$double.xmin))
    direction <- sign(counts - mu)

    list(
        log_fc = fold,
        p_value = two_sided,
        signed_log_p = direction * strength
    )
}

# Scatter-plot X, filling each point by how often it was chosen during
# initialization relative to a uniform null (via selection_deviation_score):
# blue = chosen less often than expected, white = as expected,
# red = more often.
#
# color_max : |signed -log10 p| at which the colour scale saturates;
#             defaults to the largest observed magnitude.
#
# Points are drawn in increasing order of |score|, so the most extreme
# points end up on top.
plot_selected_density <- function(X, counts, main, K, color_max = NULL,
                                  pseudocount = 0.5) {
    scores <- selection_deviation_score(counts, K, nrow(X), pseudocount)
    z <- scores$signed_log_p

    limit <- if (is.null(color_max)) max(abs(z)) else color_max
    if (limit > 0) {
        # Map [-limit, limit] onto [0, 1] and clamp anything beyond.
        position <- pmax(0, pmin(1, (z + limit) / (2 * limit)))
    } else {
        position <- rep(0.5, length(z))
    }

    ramp <- colorRamp(c("blue", "white", "red"))
    channels <- ramp(position) / 255
    fills <- rgb(channels[, 1], channels[, 2], channels[, 3])

    draw_order <- order(abs(z))
    plot(
        X[draw_order, ],
        pch = 21, cex = 1, col = "gray50", bg = fills[draw_order],
        main = main, xlab = "", ylab = "", axes = FALSE, asp = 1
    )
}

# Plot the columns `view` of the data X as small filled dots (colour `col`),
# then mark the rows of A on the same projection (default: black crosses).
# Axes and axis labels are suppressed and the aspect ratio is fixed at 1.
plot_data_with_points <- function(X, A, main, col = "grey75", view = c(1, 2),
                                  pch = 4, point_col = "black") {
    data_view <- X[, view, drop = FALSE]
    plot(data_view, pch = 16, cex = 0.45, col = col, main = main,
         xlab = "", ylab = "", axes = FALSE, asp = 1)

    marker_view <- A[, view, drop = FALSE]
    points(marker_view, pch = pch, cex = 1.25, lwd = 1.6, col = point_col)
}

# Simulate two noisy concentric rings in 2-D.
#
# The first n %/% 2 points sit on a circle of radius 0.4 and the rest on
# radius 1.0, each at a uniform random angle. Gaussian noise with
# sd = `noise` is added to every coordinate. The result bundles the data
# matrix (columns "x", "y") with plotting metadata (view, per-point ring
# colour, title) and the kernel settings used for this shape.
make_concentric_circles <- function(n = 180, noise = 0.03) {
    half <- n %/% 2
    ring_sizes <- c(half, n - half)

    angle <- c(runif(ring_sizes[1], 0, 2 * pi),
               runif(ring_sizes[2], 0, 2 * pi))
    r <- rep(c(0.4, 1.0), times = ring_sizes)

    pts <- cbind(r * cos(angle), r * sin(angle))
    pts <- pts + rnorm(length(pts), sd = noise)
    dimnames(pts) <- list(NULL, c("x", "y"))

    list(
        X = pts,
        view = c(1, 2),
        col = rep(c("#0072B2", "#D55E00"), times = ring_sizes),
        main = "Concentric circles",
        kernel = "laplace",
        kernel_args = list(sigma = 0.25)
    )
}

# Simulate a noisy 3-D Swiss roll.
#
# The roll parameter t is uniform on [1.5*pi, 4.5*pi] and the height is
# uniform on [-1, 1]. Columns are (t cos t / 7, height, t sin t / 7) plus
# Gaussian noise with sd = `noise`. Each point is coloured by its position
# along t on a 100-step blue-green-orange ramp. The x/z columns
# (view = c(1, 3)) show the spiral.
make_swiss_roll <- function(n = 180, noise = 0.025) {
    tt <- runif(n, 1.5 * pi, 4.5 * pi)
    height <- runif(n, -1, 1)

    roll <- cbind(tt * cos(tt) / 7, height, tt * sin(tt) / 7)
    roll <- roll + rnorm(length(roll), sd = noise)
    dimnames(roll) <- list(NULL, c("x", "height", "z"))

    ramp <- colorRampPalette(c("#0072B2", "#009E73", "#D55E00"))(100)
    span <- diff(range(tt))
    # The smallest t maps to colour 1 and the largest to colour 100.
    bin <- ceiling(99 * (tt - min(tt)) / span) + 1
    bin <- pmax(1, pmin(100, bin))

    list(
        X = roll,
        view = c(1, 3),
        col = ramp[bin],
        main = "Swiss roll",
        kernel = "rbf",
        kernel_args = list(sigma = 0.45)
    )
}

## ----precomputed-init-matrix--------------------------------------------------
K <- 3
X <- as.matrix(toy)

# Use the first K rows of X as a fixed starting archetype matrix.
A0 <- X[1:K, , drop = FALSE]

tol <- 0.01
eps <- 0
fit_matrix <- run_aa(X, K = K, init = A0, scale = FALSE, tol = tol, eps = eps)

## ----custom-init-function-----------------------------------------------------
# Inputs for the custom-initializer example.
K <- 3
tol <- 0.01
X <- as.matrix(toy)

# Example custom initializer: take the first K rows of X as archetypes.
#
# Returns list(A, B), where
#   A is the K x ncol(X) matrix of starting archetypes (rows 1..K of X);
#   B is the K x nrow(X) weight matrix whose row k is the indicator of data
#     point k, so that B %*% X reproduces A.
# Any extra arguments passed by the caller are accepted and ignored.
first_k_init <- function(X, K, ...) {
    # seq_len() gives an empty index for K = 0; 1:K would give c(1, 0) and
    # leave A with one row while B has none.
    idx <- seq_len(K)
    A <- X[idx, , drop = FALSE]
    B <- diag(1, nrow = K, ncol = nrow(X))
    list(A = A, B = B)
}

# Pass the initializer function itself; run_aa calls it to get A and B.
fit_custom <- run_aa(
    X,
    K     = K,
    init  = first_k_init,
    scale = FALSE,
    tol   = tol
)

## ----custom-init-fit-simplex, eval = FALSE------------------------------------
# X <- as.matrix(toy)
# K <- 3
# 
# A_user <- some_initializer(X, K = K)
# B <- fit_simplex(A_user, X)
# A_init <- B %*% X
# init <- list(A = A_init, B = B)

## ----run-aa-init-short, eval = FALSE------------------------------------------
# X <- as.matrix(toy)
# K <- 3
# tol <- 0.01
# 
# fit <- run_aa(X, K = K, init = "furthest_sum", tol = tol)
# fit <- run_aa(X, K = K, init = "aa_pp", tol = tol)
# fit <- run_aa(X,
#     K = K, init = "aa_pp",
#     init_args = list(batch_size = 200, batch_type = "uniform"),
#     tol = tol
# )

## ----init-selection-density---------------------------------------------------
X <- as.matrix(toy)
K <- 3

method_labels <- c(
    random         = "Random",
    kmeans_pp      = "k-means++",
    aa_pp          = "AA++",
    furthest_first = "Furthest First",
    furthest_sum   = "Furthest Sum",
    hull_outmost   = "Hull-outmost"
)
# Number of initialization runs per data point. Every run picks K points, so
# under a uniform null each point is expected to be selected
# N_runs * K / nrow(X) = K * N_expected times (30 for K = 3).
N_expected <- 10
N_runs <- nrow(X) * N_expected
# |signed -log10 p| at which the red/blue colour scale saturates.
color_cutoff <- -log10(1 / nrow(X))

# For each method, count how often every point carries positive weight in
# the initial B matrix across N_runs independent initializations.
selected_counts <- list()
for (method in names(method_labels)) {
    counts <- rep(0, nrow(X))
    for (run in seq_len(N_runs)) {
        if (method == "hull_outmost") {
            init_obj <- aa_init(X, K = K, method = method,
                                hull_method = "partitioned")
        } else {
            init_obj <- aa_init(X, K = K, method = method)
        }
        selected <- which(init_obj$B > 0, arr.ind = TRUE)[, 2]
        # tabulate() counts repeated indices. `counts[selected] + 1` adds
        # only 1 for a duplicated index, which would make sum(counts) / K
        # (the run count used by selection_deviation_score) non-integer.
        counts <- counts + tabulate(selected, nbins = nrow(X))
    }
    selected_counts[[method]] <- counts
}

## ----init-selection-density-plot, fig.width = 8, fig.height = 7, out.width="100%", results='hold', echo=FALSE----

# One panel per initialization method, all on the same colour scale.
op <- par(mfrow = c(2, 3), mar = c(2, 2, 3, 0.5), bty = "n")
for (m in names(method_labels)) {
    plot_selected_density(
        X,
        counts    = selected_counts[[m]],
        main      = method_labels[m],
        K         = K,
        color_max = color_cutoff
    )
}
par(op)

## ----increasing-k-------------------------------------------------------------
X <- as.matrix(toy)
k_grid <- c(3, 7, 15)
k_methods <- c("random", "kmeans_pp", "furthest_sum")

# One initialization per (method, K) pair, stored as
# k_inits[[method]][["K = k"]]; methods are run in order, K innermost.
k_inits <- setNames(
    lapply(k_methods, function(method) {
        setNames(
            lapply(k_grid, function(k) aa_init(X, K = k, method = method)),
            paste("K =", k_grid)
        )
    }),
    k_methods
)

## ----increasing-k-plot, fig.width = 8.5, fig.height = 6, out.width="100%", results='hold', echo=FALSE----
# Grid: one row per method, one column per K; each panel overlays the
# initial simplex on the data.
op <- par(
    mfrow = c(length(k_methods), length(k_grid)),
    mar = c(0.5, 0.5, 2.2, 0.2), oma = c(0, 0, 0, 0), bty = "n"
)
simplex_col <- "firebrick3"
for (method in k_methods) {
    panels <- k_inits[[method]]
    for (k_label in names(panels)) {
        plot(X, pch = 16, cex = 0.42, col = "grey75",
             main = sprintf("%s\n%s", method_labels[method], k_label),
             xlab = "", ylab = "", axes = FALSE, asp = 1)
        draw_simplex(panels[[k_label]]$A,
                     border = simplex_col,
                     fill = adjustcolor(simplex_col, 0.05),
                     cex = 0.95)
    }
}
par(op)

## ----init-before-after-fit----------------------------------------------------
X <- as.matrix(toy)
K <- 3
focused_methods <- c("random", "kmeans_pp", "furthest_sum")

# Fit once per initialization method (at most 60 iterations each).
fits <- setNames(
    lapply(focused_methods, function(init_method) {
        run_aa(
            x        = X,
            K        = K,
            init     = init_method,
            scale    = FALSE,
            max_iter = 60,
            tol      = 0.01
        )
    }),
    focused_methods
)

## ----init-before-after-plot, echo=FALSE, fig.width = 8.5, fig.height = 6, out.width="100%", results='hold'----
# Pool every loss trace so all loss panels share one y-axis range.
loss_values <- unlist(lapply(focused_methods, function(m) fits[[m]]$loss$loss))
loss_ylim <- range(loss_values)
# Save `mar` together with `mfrow`/`bty`: both panel rows below set their own
# margins, and par(op) must undo those changes too.
op <- par(mfrow = c(2, 3), bty = "n", mar = c(0.8, 0.6, 3, 0.5))
# Top row: fitted archetypes, with the initial simplex dashed on top. The
# step count is nrow(fit$loss) - 1 (assumes the first row holds the initial
# loss; confirm against run_aa).
for (method in focused_methods) {
    fit <- fits[[method]]
    plot(fit,
        what = "coordinates",
        show_anames = FALSE,
        main = sprintf("%s\n%d steps", method_labels[method], nrow(fit$loss) - 1),
        args.data.scatter = list(pch = 16, cex = 0.45, col = "grey75"),
        pch = 19,
        col = "steelblue4",
        axes = FALSE,
        asp = 1
    )
    draw_simplex(
        fit$init,
        border = adjustcolor("firebrick3", 0.75),
        fill = adjustcolor("firebrick3", 0.05),
        lty = 2
    )
}
# Bottom row: loss trajectories on a common y-axis.
par(mar = c(2.8, 4.2, 3, 0.5))
for (method in focused_methods) {
    plot(fits[[method]],
        what = "loss",
        lwd = 1.8,
        col = "steelblue4",
        ylim = loss_ylim,
        main = "Loss",
        ylab = "",
        las = 1
    )
}
par(op)

## ----nonlinear-init-----------------------------------------------------------
K <- 8
nonlinear_data <- list(
    circles    = make_concentric_circles(),
    swiss_roll = make_swiss_roll()
)
nonlinear_methods <- c("random", "kmeans_pp", "furthest_sum")

# Euclidean initializations for every (shape, method) pair, stored as
# nonlinear_inits[[shape]][[method]].
nonlinear_inits <- lapply(nonlinear_data, function(shape) {
    setNames(
        lapply(nonlinear_methods, function(m) aa_init(shape$X, K = K, method = m)),
        nonlinear_methods
    )
})

## ----nonlinear-init-plot, fig.width = 8.5, fig.height = 6, out.width="100%", results='hold', echo=FALSE----
# One row per data shape, one column per method: the data in its own
# colours with the initial archetypes marked on top.
op <- par(
    mfrow = c(length(nonlinear_data), length(nonlinear_methods)),
    mar = c(2, 2, 3, 0.5), bty = "n"
)
for (data_shape in names(nonlinear_data)) {
    shape <- nonlinear_data[[data_shape]]
    faded <- adjustcolor(shape$col, 0.62)
    for (method in nonlinear_methods) {
        plot_data_with_points(
            shape$X,
            nonlinear_inits[[data_shape]][[method]]$A,
            main = sprintf("%s\n%s", shape$main, method_labels[method]),
            col = faded,
            view = shape$view
        )
    }
}
par(op)

## ----kernel-user-code, eval=FALSE---------------------------------------------
# fit_kernel <- archetypes_kernel_pgd(
#     x      = X,
#     K      = K,
#     kernel = "laplace",
#     init   = "furthest_sum",
#     tol    = 0.01
# )

## ----nonlinear-kernel-setup, include = FALSE, warning=FALSE-------------------
kernel_methods <- c("furthest_sum", "kmeans_pp")

# For each shape and method, pair the Euclidean initial archetypes from
# above with the kernel solver's initial archetypes (kernel_fit$init %*% X).
# Only the starting point is of interest, so the solver is capped at one
# iteration (tol_r2 = 1 presumably lets it stop early; confirm).
kernel_inits <- lapply(setNames(nm = names(nonlinear_data)), function(data_shape) {
    data_obj <- nonlinear_data[[data_shape]]
    pairs <- lapply(kernel_methods, function(method) {
        kernel_fit <- archetypes_kernel_pgd(
            data_obj$X,
            K           = K,
            init        = method,
            kernel      = data_obj$kernel,
            kernel_args = data_obj$kernel_args,
            max_iter    = 1,
            tol         = 0.01,
            tol_r2      = 1
        )
        list(
            euclidean = nonlinear_inits[[data_shape]][[method]]$A,
            kernel    = kernel_fit$init %*% data_obj$X
        )
    })
    setNames(pairs, kernel_methods)
})

## ----nonlinear-kernel-plot, fig.width = 8.5, fig.height = 6, out.width="100%", results='hold', echo=FALSE----
# Compare Euclidean (black crosses) and kernel (orange circles) initial
# archetypes for each shape and method.
op <- par(
    mfrow = c(length(nonlinear_data), length(kernel_methods)),
    mar = c(2, 2, 3, 0.5), bty = "n"
)
for (data_shape in names(nonlinear_data)) {
    shape <- nonlinear_data[[data_shape]]
    v <- shape$view
    for (method in kernel_methods) {
        pair <- kernel_inits[[data_shape]][[method]]
        plot(shape$X[, v, drop = FALSE],
             pch = 16, cex = 0.45, col = adjustcolor(shape$col, 0.55),
             main = sprintf("%s\n%s", shape$main, method_labels[method]),
             xlab = "", ylab = "", axes = FALSE, asp = 1)
        points(pair$euclidean[, v, drop = FALSE],
               pch = 4, cex = 1.2, lwd = 1.5, col = "black")
        points(pair$kernel[, v, drop = FALSE],
               pch = 1, cex = 1.25, lwd = 1.6, col = "#D55E00")
    }
}
# Overlay a full-device, margin-free panel that holds the shared legend.
par(fig = c(0, 1, 0, 1), mar = c(0, 0, 0, 0), new = TRUE)
plot.new()
legend("center",
       legend = c("Euclidean init", "Kernel init"),
       pch = c(4, 1), col = c("black", "#D55E00"),
       pt.cex = c(2.1, 2.2), bty = "n", cex = 1.2)
par(op)

## ----uniform-ex, eval = FALSE-------------------------------------------------
# X <- as.matrix(toy)
# K <- 3
# init <- aa_init(X, K = K, method = "random")

## ----dirichlet-vs-random-plot, fig.height = 3, fig.width=7, results='hold', echo=FALSE----
X <- as.matrix(toy)
K <- 3
# Initialization variants to compare; `alpha` is only passed for "dirichlet".
dir_variants <- list(
    `random` = list(method = "random"),
    `dirichlet alpha = 1` = list(method = "dirichlet", alpha = 1),
    `dirichlet alpha = 0.01` = list(method = "dirichlet", alpha = 0.01)
)
n_dir_runs <- 3

# dir_inits[[variant]][[run]] holds one initialization per run.
dir_inits <- list()
for (variant in names(dir_variants)) {
    variant_args <- dir_variants[[variant]]
    dir_inits[[variant]] <- list()
    for (run in seq_len(n_dir_runs)) {
        if (variant_args$method == "dirichlet") {
            dir_inits[[variant]][[run]] <- aa_init(
                X,
                K = K,
                method = variant_args$method,
                alpha = variant_args$alpha
            )
        } else {
            dir_inits[[variant]][[run]] <- aa_init(
                X,
                K = K,
                method = variant_args$method
            )
        }
    }
}
op <- par(mfrow = c(1, length(dir_variants)), mar = c(2, 2, 2.5, 0.5), bty = "n")
# One colour per run, recycled so that increasing n_dir_runs never yields
# NA colours.
run_cols <- adjustcolor(
    rep_len(c("firebrick3", "darkorange3", "purple4"), n_dir_runs), 0.7
)

for (nm in names(dir_inits)) {
    # Axis limits are padded by one unit on each side of the data range.
    plot(
        X,
        pch = 16, cex = 0.45, col = "grey75",
        main = nm, xlab = "", ylab = "",
        xlim = range(X[, 1]) + c(-1, 1),
        ylim = range(X[, 2]) + c(-1, 1),
        axes = FALSE
    )
    for (run_ix in seq_along(dir_inits[[nm]])) {
        draw_simplex(
            dir_inits[[nm]][[run_ix]]$A,
            border = run_cols[run_ix],
            fill   = adjustcolor(run_cols[run_ix], 0.06),
            lty    = run_ix,
            cex    = 1.1
        )
    }
    # The run legend is shown on the first panel only.
    if (nm == names(dir_inits)[1]) {
        legend(
            "topleft",
            legend = paste("Run", seq_len(n_dir_runs)),
            col = run_cols, lty = seq_len(n_dir_runs), lwd = 2, pch = 17,
            bty = "n", cex = 0.7
        )
    }
}
par(op)

## ----furthest-first-ex, eval = FALSE------------------------------------------
# X <- as.matrix(toy)
# K <- 3
# init <- aa_init(X, K = K, method = "furthest_first")

## ----kmeans-pp-ex, eval = FALSE-----------------------------------------------
# X <- as.matrix(toy)
# K <- 3
# init <- aa_init(X, K = K, method = "kmeans_pp")

## ----furthest-sum-ex, eval = FALSE--------------------------------------------
# X <- as.matrix(toy)
# K <- 3
# # Default refinement
# init <- aa_init(X, K = K, method = "furthest_sum")
# 
# # More aggressive refinement
# init <- aa_init(X, K = K, method = "furthest_sum", refinement_steps = 30)
# 
# # No refinement
# init <- aa_init(X, K = K, method = "furthest_sum", refinement_steps = 0)

## ----coreset-ex, eval = FALSE-------------------------------------------------
# X <- as.matrix(toy)
# K <- 3
# batch_size <- 60
# 
# # batch_size must be at least K
# init <- aa_init(X, K = K, method = "furthest_sum", batch_size = batch_size)

## ----aa-pp-ex, eval = FALSE---------------------------------------------------
# X <- as.matrix(toy)
# K <- 3
# init <- aa_init(X, K = K, method = "aa_pp")

## ----aa-pp-mc-ex, eval = FALSE------------------------------------------------
# X <- as.matrix(toy)
# K <- 3
# batch_size <- 100
# batch_type <- "uniform"
# 
# # Larger batch_size -> closer approximation to exact AA++
# init <- aa_init(X,
#     K = K,
#     method = "aa_pp",
#     batch_size = batch_size,
#     batch_type = batch_type
# )

## ----hull-outmost-ex, eval = FALSE--------------------------------------------
# X <- as.matrix(toy)
# K <- 3
# # Full hull — exact but requires 'geometry' for D > 2
# init <- aa_init(X, K = K, method = "hull_outmost", hull_method = "full")
# 
# # Projected hull — works in any dimension without extra packages
# init <- aa_init(X,
#     K = K, method = "hull_outmost",
#     hull_method = "projected", projected_dim = 2
# )
# 
# # Partitioned hull — fastest, suitable for large n
# init <- aa_init(X,
#     K = K, method = "hull_outmost",
#     hull_method = "partitioned", n_partitions = 15
# )

## ----session-info, echo = FALSE-----------------------------------------------
# Report the R session (R version, platform, attached packages) for reproducibility.
utils::sessionInfo()

