# libraries
library(brms)
library(tidybayes)
library(ggplot2)
library(patchwork)
library(ape)
library(dplyr)
library(RColorBrewer)

# Working directory ----
# The tree and CSV files below are read with relative paths, so the script
# must run from the project folder. The location can be overridden with the
# ALL_VIRUS_DIR environment variable (e.g. in .Renviron); otherwise the
# original author's path is used.
project_dir <- Sys.getenv(
  "ALL_VIRUS_DIR",
  unset = "C:/Users/san260/OneDrive - CSIRO/Desktop/R_codes/All_virus"
)
if (!dir.exists(project_dir)) {
  stop(
    "Project directory not found: ", project_dir,
    "\nSet ALL_VIRUS_DIR to the folder holding the tree and CSV files.",
    call. = FALSE
  )
}
setwd(project_dir)

# Phylogenies and their variance-covariance matrices ----
# Host tree: tip labels are used unchanged.
htree <- read.tree("Bat_tree.nwk")
vcv_host <- ape::vcv(htree)

# Read a virus tree and drop a single trailing underscore from each tip label
# (presumably so labels match the `virus` values in the residual CSVs --
# confirm against the data files).
read_virus_tree <- function(path) {
  tr <- read.tree(path)
  tr$tip.label <- sub("_$", "", tr$tip.label)
  tr
}

corona_tree <- read_virus_tree("BatCorona_Cophylo_231124_input.newick")
vcv_corona  <- ape::vcv(corona_tree)

para_tree <- read_virus_tree("Paramxyovirus_WGS_tree.newick")
vcv_para  <- ape::vcv(para_tree)

rhabdo_tree <- read_virus_tree("Rhabdovirus_phylogenetic_tree250425.newick")
vcv_rhabdo  <- ape::vcv(rhabdo_tree)

# Plot constants ----
# Common y-axis range, used as the default `y_limits` for every panel.
global_ylim <- c(0, 100)

# Colour per zone. Names are looked up by the zone values found in the data,
# so they must match those values exactly.
zone_palette <- c(
  Africa          = "#F4A582", # light orange
  Asia            = "#2A9D8F", # teal
  Europe          = "#0571B0", # blue
  Oceania         = "#E78AC3", # pink
  `North America` = "#A974B3"  # lavender
)

# Plots ----

#' Fit a zone-effect phylogenetic mixed model and plot posterior zone means
#'
#' Reads a residual table and fits a Gaussian brms model of `residual` on
#' `zone` + `cites`. The model has unstructured random intercepts for host and
#' virus plus phylogenetically structured intercepts (covariances `A` and `V`).
#' It returns a plot of the population-level posterior expected residual per
#' zone (`cites` held at its mean) drawn over the jittered raw residuals.
#'
#' @param file_path CSV with at least the columns `host`, `virus`, `zone`,
#'   `cites` and `residual`.
#' @param vcv_host Host covariance matrix. Its row names must cover the
#'   `host` values (required by brms `gr(cov = )`).
#' @param vcv_virus Virus covariance matrix. Its row names must cover the
#'   `virus` values.
#' @param virus_family Panel title (character string or plotmath expression).
#' @param iter Total iterations per chain; the first half is warmup.
#' @param thin Thinning interval passed to `brm()`.
#' @param y_limits Length-2 y-axis range; five evenly spaced breaks are drawn.
#' @param palette Named colour vector keyed by zone.
#' @return A ggplot object.
fit_and_plot_zone <- function(file_path, vcv_host, vcv_virus, virus_family,
                              iter = 4000, thin = 1, y_limits = global_ylim,
                              palette = zone_palette) {
  dat <- read.csv(file_path)
  # Duplicate the grouping columns so host and virus each enter the model
  # twice: once as a plain intercept, once with a phylogenetic covariance.
  dat$phylo <- dat$host
  dat$vphy <- dat$virus
  dat$zone <- as.factor(dat$zone)
  zone_levels <- levels(dat$zone)

  missing_zones <- setdiff(zone_levels, names(palette))
  if (length(missing_zones) > 0) {
    warning(
      "No colour in `palette` for zone(s): ",
      paste(missing_zones, collapse = ", "),
      call. = FALSE
    )
  }
  # Same length as zone_levels (NA for missing zones), as the manual scale
  # requires at least one value per level.
  zone_colors <- palette[zone_levels]

  # Fit Bayesian mixed model
  model <- brm(
    residual ~ zone + cites + (1 | host) + (1 | virus) +
      (1 | gr(phylo, cov = A)) + (1 | gr(vphy, cov = V)),
    data = dat,
    data2 = list(A = vcv_host, V = vcv_virus),
    family = gaussian(),
    chains = 4, iter = iter, warmup = floor(iter / 2), thin = thin,
    control = list(adapt_delta = 0.999, max_treedepth = 15),
    silent = TRUE
  )

  # One row per zone with `cites` at its mean. A factor keeps the data's
  # level order on the x-axis.
  newdat <- data.frame(
    zone = factor(zone_levels, levels = zone_levels),
    cites = mean(dat$cites, na.rm = TRUE)
  )
  # Posterior expected values at the population level (re_formula = ~0 drops
  # all group-level terms). The column is named `.value` for the plot below.
  pdraw <- add_epred_draws(
    newdat, model,
    value = ".value", ndraws = 100, seed = 42, re_formula = ~0
  )

  y_breaks <- seq(y_limits[1], y_limits[2], length.out = 5)

  ggplot(pdraw, aes(x = zone, y = .value, fill = zone)) +
    # Raw residuals: horizontal jitter only, so plotted y values are exact.
    # A fixed seed keeps the point layout reproducible.
    geom_point(
      data = dat, aes(x = zone, y = residual, color = zone),
      position = position_jitter(width = 0.2, height = 0, seed = 42),
      alpha = 0.3, show.legend = FALSE
    ) +
    stat_halfeye(.width = 0.95, alpha = 0.6, show.legend = FALSE) +
    scale_fill_manual(values = zone_colors, drop = FALSE) +
    scale_color_manual(values = zone_colors, drop = FALSE) +
    scale_y_continuous(limits = y_limits, breaks = y_breaks) +
    theme_minimal(base_size = 14) +
    theme(
      plot.title = element_text(face = "italic", size = 16, hjust = 0.5),
      axis.text.x = element_text(angle = 45, hjust = 1),
      axis.line = element_line(color = "black", linewidth = 0.5),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      axis.ticks = element_line(color = "black")
    ) +
    labs(title = virus_family, x = NULL, y = "PACo residual")
}

# Fit one model and build one panel per virus family ----
# Each call fits a brms model, so this section is slow.
# NOTE(review): the rhabdovirus input file is named differently
# ("..._with_log_cites") from the other two. Confirm its `cites` column is on
# the same scale as in the coronavirus and paramyxovirus files.
p_corona <- fit_and_plot_zone(
  file_path    = "merged_corona_residual.csv",
  vcv_host     = vcv_host,
  vcv_virus    = vcv_corona,
  virus_family = expression(italic("Coronaviridae")),
  iter         = 4000,
  thin         = 1
)

p_para <- fit_and_plot_zone(
  file_path    = "merged_para_residual.csv",
  vcv_host     = vcv_host,
  vcv_virus    = vcv_para,
  virus_family = expression(italic("Paramyxoviridae")),
  iter         = 4000,
  thin         = 1
)

p_rhabdo <- fit_and_plot_zone(
  file_path    = "residual_data_with_log_cites.csv",
  vcv_host     = vcv_host,
  vcv_virus    = vcv_rhabdo,
  virus_family = expression(italic("Rhabdoviridae")),
  iter         = 4000,
  thin         = 2
)

# Panel styling ----
# Only the left-most (coronavirus) panel keeps its y-axis title, tick labels
# and ticks. The other two panels use the same default y-limits, so their
# y-axis elements are removed.
hide_y_axis <- theme(
  axis.title.y = element_blank(),
  axis.text.y  = element_blank(),
  axis.ticks.y = element_blank()
)
p_para   <- p_para + hide_y_axis
p_rhabdo <- p_rhabdo + hide_y_axis

# Combine panels ----
# Three panels in one row. `&` binds more loosely than `+`, so the theme
# after it is applied to every panel of the assembled patchwork.
# NOTE(review): every layer in fit_and_plot_zone() sets show.legend = FALSE,
# so the legend settings here currently have no visible effect.
final_zone_plot <-
  (p_corona + p_para + p_rhabdo) +
  plot_layout(ncol = 3, guides = "collect") &
  theme(
    legend.position = "bottom",
    legend.title    = element_text(size = 12),
    legend.text     = element_text(size = 10),
    plot.margin     = margin(5, 5, 5, 5)
  )

print(final_zone_plot)
