# these are the necessary packages

library(DESeq2)
library(tximport)
library(readr)
library(tidyverse)
library(ensembldb)
library(RMariaDB)
library(AnnotationHub)
library(EnsDb.Mmusculus.v79)
library(apeglm)
library(tidyverse)
library(ggplot2)
library(ggsignif)
library(ggrepel)
library(ggvenn)
library(gplots)
library(RColorBrewer)
library(gghighlight)
library(EDASeq)
library(DEGreport)
library(RNAseqQC)
library(clusterProfiler)
library(EnhancedVolcano)
library(thematic)


# enable thematic to control plot themes at high level
thematic_on(bg = '#222222', fg = 'white', font = font_spec(scale = 2))

# create table for the samples of the RNAseq
# one entry per read-1 fastq file; the dots of the extension are escaped and
# the pattern is anchored so only names ending in "_R1.fastq.gz" are picked up
names <- list.files('Fastqs/', pattern = '._R1\\.fastq\\.gz$', full.names = FALSE)
# strip the literal extension (fixed string, not a regular expression)
names <- str_remove(names, fixed('.fastq.gz'))

# create a data frame for all the information shown in the sample file names
# this is the sample information data which will be combined with the
# count data on each gene for each sample
# the file name is split on "_" into seven fields; the seventh field is
# replaced by the full sample name
samples <- data.frame(do.call(rbind, strsplit(names, "_")))
colnames(samples) <- c('Type', 'Sex', 'Drug', 'Cycle',
                       'Mouse', 'SampleNum', 'SampleName')
samples <- dplyr::select(samples, -7)
samples$SampleName <- names

# save samples as csv
write.csv(samples, file = 'SampleInformation.csv')

# getting all the quantification files from Salmon
# file paths must be in your working directory
# for me there is a folder in the working directory call quants
# which is the output from Salmon (see the salmon_quant file for an example)
files <- file.path('quants', str_c(names, "_quant"), 'quant.sf')
# name each quant file after its sample so tximport labels the columns
names(files) <- samples$SampleName

# Transcript -> gene lookup built from the Ensembl mouse annotation
# (release 111); gene symbols are then looked up in EnsDb.Mmusculus.v79.
txdb <- makeTxDbFromEnsembl(organism = "Mus musculus", release = 111)
k <- keys(txdb, keytype = "TXNAME")
tx2gene <- ensembldb::select(txdb, keys = k, columns = "GENEID",
                             keytype = "TXNAME")
mm.gene_symbols <- ensembldb::select(
  EnsDb.Mmusculus.v79,
  keys = tx2gene$GENEID,
  keytype = "GENEID",
  columns = c("SYMBOL", "GENEID")
)

# Import the Salmon quantifications and summarise them to gene level using
# the transcript -> gene table; transcript version suffixes are ignored.
txi <- tximport(files, type = 'salmon', tx2gene = tx2gene,
                ignoreTxVersion = TRUE)
names(txi)
head(txi$counts)

# Persist the gene-level counts and lengths.
# Together with the sample information table these allow dds to be rebuilt.
write.csv(txi$counts, file = 'CountData.csv')
write.csv(txi$length, file = 'LengthData.csv')

# create DESeqDataSet to input into differential expression analysis
# using DESeq2
# NOTE(review): the original script called lfcShrink() on `dds` without ever
# constructing it at this point. The object is built here from the tximport
# result and the sample table with design ~Type, which is what the
# coefficient name "Type_IP_vs_IN" below requires -- confirm this matches
# the analysis that was actually run.
dds <- DESeqDataSetFromTximport(txi, colData = samples, design = ~Type)
# IN must be the reference level for the "Type_IP_vs_IN" coefficient to exist
dds$Type <- relevel(dds$Type, ref = "IN")
dds <- DESeq(dds)

# perform shrinkage of effect sizes
# improves the accuracy of the padj values
resLFC <- lfcShrink(dds, coef="Type_IP_vs_IN", type = "apeglm")
resLFC

# adding gene symbols from mm.gene_symbols to the resLFC
resLFC <- tibble::rownames_to_column(as.data.frame(resLFC), "GENEID")
resLFC <- left_join(resLFC, mm.gene_symbols, by = "GENEID")

# adding transcripts per kilobase million to the resLFC
# lengths are looked up by gene ID rather than by row position, because the
# join above is not guaranteed to keep resLFC row-aligned with txi$length
length_means <- DataFrame(rowMeans(txi$length))
rpk <- resLFC$baseMean /
  (length_means$rowMeans.txi.length.[match(resLFC$GENEID,
                                           rownames(length_means))] / 1000)
# genes without a length are ignored when computing the per-million scaling
scalingFactor <- sum(rpk, na.rm = TRUE)/1e6
resLFC$TPM <- rpk/scalingFactor



# plotting DESeq results
# results for IP vs IN (all samples)
# select microglia genes
# ('Selplg' was previously misspelled 'Selgpl' and therefore never matched
# a SYMBOL, so it was silently left unlabelled)

microglia_genes <-c('Tmem119', 'C1qc', 'Csf1r', 'Selplg',
                    'Olfml3', 'Ctss', 'Golm1', 'Hexb', 'Aif1', 'C5ar1',
                    'P2ry12', 'P2ry13')


# all genes with log2fc > |1.5| and padj < 0.05
# labels selected microglia genes
# x axis: log2 of the TPM column of resLFC; y axis: shrunken log2 fold change
myPalette <- colorRampPalette((brewer.pal(11, "Spectral")))
g <- ggplot(resLFC,
            aes(x = log2(TPM), y = log2FoldChange, color = padj,
                   label = SYMBOL)) +
  geom_point() +
  gghighlight(abs(log2FoldChange) > 1.5 &
                padj < 0.05,
              unhighlighted_params = list(color = NULL, alpha = 1)) +
  scale_color_gradientn(colors = myPalette(8))

# adding the gene annotations (only rows whose SYMBOL is in microglia_genes)
g <- g +
  geom_text_repel(data = subset(resLFC, SYMBOL %in% microglia_genes),
                  nudge_x = 5,
                  nudge_y = 2,
                  hjust = 0,
                  max.overlaps = Inf)

# how to save a high resolution plot
# note need to set path below
ggsave(filename = "Some/file/path/here/ImageName.png",
              plot = g,
              dpi = 300, width = 7, height = 6)

# volcano plot with labels
# only the genes in labeled_genes are annotated
# ('Selplg' was previously misspelled 'Selgpl', so it could never be
# selected for labelling)

labeled_genes <-c('Tmem119', 'C1qc', 'Csf1r', 'Selplg',
                    'Olfml3', 'Ctss', 'Golm1', 'Hexb', 'Aif1', 'C5ar1',
                    'P2ry12', 'P2ry13')

# padj is plotted on the y axis; cutoffs are padj 0.01 and |log2FC| 0.6
g <- EnhancedVolcano(resLFC,
                lab = resLFC$SYMBOL,
                x = 'log2FoldChange',
                y = 'padj',
                selectLab = labeled_genes,
                ylab = bquote(~-Log[10]~ 'padj'),
                xlab = bquote(~Log[2]~ 'fold change'),
                pCutoff = 0.01,
                FCcutoff = 0.6,
                pointSize = 2,
                labSize = 6.0,
                labCol = 'white',
                labFace = 'bold',
                boxedLabels = TRUE,
                colAlpha = 0.5,
                legendPosition = 'right',
                legendLabSize = 14,
                legendIconSize = 4.0,
                drawConnectors = TRUE,
                widthConnectors = 1.0,
                colConnectors = 'black')

ggsave("C:/Users/David.SIBCR-1081/Desktop/Figures/IP_vs_IN_volcano_white.png",
       plot = g,
       dpi = 300)

# making a pca using all the samples except the controls and
# undetermined sample (S0, last column of txi$counts)
# The count matrix is subset directly (instead of going through a tibble) so
# the gene IDs are kept as row names. Matching is case-insensitive, like
# tidyselect's contains().
sample_pca <- txi$counts[, grepl('Five|One', colnames(txi$counts),
                                 ignore.case = TRUE), drop = FALSE]
# Salmon estimates are fractional; round to the nearest integer rather than
# truncating toward zero
sample_pca <- round(sample_pca)
mode(sample_pca) <- 'integer'

# sample metadata for the same samples, kept as a data frame so the columns
# are not all coerced to character
samples.pca <- samples %>%
  dplyr::filter(Drug %in% c('F', 'S'))
rownames(samples.pca) <- samples.pca$SampleName
# put the count columns in the same order as the metadata rows
sample_pca <- sample_pca[, match(rownames(samples.pca), colnames(sample_pca))]

dds <- DESeqDataSetFromMatrix(countData = sample_pca, colData = samples.pca,
                              design = ~Type)
dds <- DESeq(dds)
vsd <- vst(dds)
plot_pca(vsd, color_by = 'Type')

# we can also look at a euclidean distance mapping of all the samples
plot_sample_clustering(vsd, anno_vars = c("Type"))

# Gene-level counts for the IP samples: columns whose name starts with "IP",
# with the gene IDs moved from the row names into a GENEID column.
counts <- txi$counts %>%
  as.data.frame() %>%
  dplyr::select(starts_with("IP")) %>%
  rownames_to_column(var = "GENEID")


# Same table for the input (IN) samples; these are used to correct the IP
# counts for contamination.
inputs <- txi$counts %>%
  as.data.frame() %>%
  dplyr::select(starts_with("IN")) %>%
  rownames_to_column(var = "GENEID")


# Order the sample columns of both tables by sample name, GENEID first.
names.IN <- str_sort(grep('IN', names, value = TRUE))
names.IP <- str_sort(grep('IP', names, value = TRUE))
inputs <- inputs[, c("GENEID", names.IN)]
counts <- counts[, c("GENEID", names.IP)]



# calculate percent contamination of each gene in IP data
# from counts in the input data

# we don't have all paired IP and IN samples so we will use the select paired
# samples for this calculation
# NOTE(review): `paired` is not referenced by the code in this section
paired <- c("4359", "4363", "4694", "4701", "4704", "4837", "4932", "4933",
            "4934", "4936", "4937", "4939", "4941", "4973", "4976", "4993",
            "4999", "5004", "5157", "5189", "5568", "5609")

# below is the experimental RNA yield in ng from each IP sample above
ribotag_rna <- c(196.68, 120.56, 125.4, 126.28, 154.88, 127.6, 163.68, 96.8,
                 162.8, 118.8, 92.4, 1.18, 1.29, 154, 105.6, 1.32, 1.50,
                 119.68, 242)

# negative controls below were treated as ribotag animals in extracting IP
# rna but did not express HA tagged ribosomes leading to a very low
# RNA output; these represent average contamination of samples from the IN
# portion of the RNA in ng
negative_controls <- c(1.21, 1.19)

# approximately 20% of the input RNA can be argued to be contaminating
# each IP sample
# (mean over samples of: mean negative-control yield / sample yield)
est_contamination <- mean(mean(negative_controls)/ribotag_rna)

# determine counts of each gene from IN to subtract from IP
# Each entry pairs the IP columns of `counts` with the IN columns of `inputs`
# for one experimental group. Column 1 of both tables is GENEID. IP columns
# 2 and 19 and IN columns 2 and 17 are not part of any group and are left
# uncorrected.
contamination_groups <- list(
  list(ip = 3:6,   input = 3:5),    # Female Fent Five
  list(ip = 7:11,  input = 6:8),    # Female Fent One
  list(ip = 12:13, input = 9:13),   # Female Sal Five
  list(ip = 14:18, input = 14:16),  # Female Sal One
  list(ip = 20:22, input = 18:19),  # Male Fent Five
  list(ip = 23:26, input = 20:23),  # Male Fent One
  list(ip = 27:28, input = 24:27),  # Male Sal Five
  list(ip = 29:33, input = 28:32)   # Male Sal One
)

# The previous `for (col in counts[, a:b]) col = ...` loops only modified a
# copy of each column, so nothing was ever subtracted from `counts`. The
# corrected columns are now written back into the data frame.
for (grp in contamination_groups) {
  background <- est_contamination * rowMeans(inputs[, grp$input])
  counts[grp$ip] <- lapply(counts[grp$ip], function(col) col - background)
}

# make all counts that are negative equal to 0
# and make all counts integers (counts must be integers)
counts[2:33] <- lapply(counts[2:33], function(col) round(pmax(col, 0)))

# save counts as the cleaned IP data
# row names are not written: the code that reads this file back addresses
# columns by position with GENEID expected to be the first column
write.csv(counts, 'IP_Counts_Minus_IN.csv', na = 'NA', row.names = FALSE)

# Convert the IP counts to transcripts per kilobase million:
# divide by the mean gene length, then scale each sample column so that it
# sums to one million.
lengths <- rowMeans(txi$length)
counts.tpm <- counts[, 2:33] / lengths
counts.tpm[] <- lapply(counts.tpm, function(sample_col) {
  sample_col / (sum(sample_col) / 1e6)
})

# Attach gene IDs and (upper-cased) gene symbols.
counts.tpm$GENEID <- counts$GENEID
counts.tpm <- counts.tpm %>% left_join(mm.gene_symbols)
counts.tpm$SYMBOL <- toupper(counts.tpm$SYMBOL)

# Cell type marker table (top 1000 markers, McKenzie AT 2018);
# the first two lines of the file are skipped.
mckenzie <- read.csv('top1000_gene_markers_McKenzieAT_2018.csv', skip = 2)

# Add the cell type annotation by matching SYMBOL to the marker table's
# gene column (columns 7 and 8 of the marker table are used).
counts.tpm <- counts.tpm %>%
  left_join(mckenzie[, 7:8], join_by("SYMBOL" == "gene"))

# Long format: one row per gene and sample, then add the sample metadata.
counts.long <- pivot_longer(counts.tpm, cols = 1:32,
                            names_to = "SampleName", values_to = "Count")

counts.long <- counts.long %>% inner_join(samples, by = "SampleName")

# Group = "<Drug> <Cycle>"
counts.long$Group <- paste(counts.long$Drug, counts.long$Cycle)


#+++++++++++++++++++++++++
# Function to calculate the mean and the SEM
# for each group
#+++++++++++++++++++++++++
# data : a data frame
# varname : the name of a column containing the variable
# to be summarized
# groupnames : vector of column names to be used as
# grouping variables
#
# Returns a data frame with one row per group: the grouping columns, the
# group mean (stored under the name given in `varname`) and `se`, the
# standard error of the mean. Missing values are ignored for the mean, the
# standard deviation and the number of observations.
#
# plyr is called through its namespace instead of being attached with
# require(): attaching plyr after dplyr masks dplyr verbs such as
# summarise(), rename() and count() for the rest of the script.
data_summary <- function(data, varname, groupnames){
  summary_func <- function(x, col){
    values <- x[[col]]
    n_obs <- sum(!is.na(values))  # only observed values enter the SEM
    c(mean = mean(values, na.rm=TRUE),
      se = sd(values, na.rm=TRUE)/sqrt(n_obs))
  }
  data_sum <- plyr::ddply(data, groupnames, .fun=summary_func,
                          varname)
  # report the mean under the name of the summarized variable
  names(data_sum)[names(data_sum) == "mean"] <- varname
  data_sum
}

# Mean IP expression (+/- SEM) of the cell type marker genes,
# one point per cell type and group.

g <- ggplot(data_summary(counts.long, 'Count', c('Group', 'Celltype')),
            aes(x = Celltype, y = Count, color = Group,
                ymin = Count - se, ymax = Count + se)) +
  geom_point(position = position_dodge(width = 0.3), size = 3) +
  geom_errorbar(position = position_dodge(width = 0.3)) +
  scale_color_manual(values = c('tan', 'red', 'salmon', 'turquoise', 'cyan')) +
  labs(x = '', y = 'Mean count of cell type genes')

ggsave(filename = 'C:/Users/David.SIBCR-1081/Desktop/Figures/IPGenesByCellType.png',
       plot = g,
       dpi = 300)


# PCA of IP versus IN samples.
# The non-ribotag control samples and the undetermined column are dropped
# first: columns 1, 16, 32, 49 and 64 of the count matrix.
txi2 <- txi$counts[, -c(1, 16, 32, 49, 64)]
mode(txi2) <- 'integer'

# Sample information for the entire data set, as saved earlier.
samples <- read.csv('SampleInformation.csv')

# New DESeq2 object; the same positions are removed from the sample table
# so rows of colData line up with the columns of the counts.
dds <- DESeqDataSetFromMatrix(
  countData = txi2,
  colData = samples[-c(1, 16, 32, 49, 64), ],
  design = ~Type
)

# Gene filter: min_count = 5, min_rep = 2.
dds <- filter_genes(dds, min_count = 5, min_rep = 2)

# Variance stabilizing transformation (a type of normalization).
vsd <- vst(dds)

# note that the IP and IN samples are clearly separated along PC1
plot_pca(vsd, color_by = 'Type')

# exploring differences between gene expression in IP samples
# by group

# counts matrix from saved csv file
# columns 2 and 19 of the csv are dropped by position, then GENEID becomes
# the row names
# NOTE(review): this requires GENEID to still be present after the positional
# drop, i.e. the csv must not carry an extra leading row-name column -- verify
# against how 'IP_Counts_Minus_IN.csv' was written
counts <- read.csv('IP_Counts_Minus_IN.csv')
counts <- counts %>% dplyr::select(-2, -19) %>%
  column_to_rownames("GENEID") %>%
  as.matrix()

# samples matrix from csv file
# samples are filtered based on Type == IP
# SampleName is used as the row names
samples <- read.csv('SampleInformation.csv')
samples <- data.frame(samples[samples$Type == 'IP', ], row.names = "SampleName")


# is there a sex difference in the IP samples?
# rows 1 and 18 of the IP sample table are dropped before building the object
dds <- DESeqDataSetFromMatrix(countData = counts, colData = samples[c(-1, -18), ],
                              design = ~Sex + Drug + Cycle)
# NOTE(review): the three relevel() results below are not assigned back to
# dds, so they do not change the reference levels used by DESeq(). The
# coefficient names used further down ("Sex_M_vs_F", "Drug_Sal_vs_Fent")
# correspond to the default references F and Fent, not M and Sal -- decide
# which references are intended before assigning these.
relevel(dds$Sex, ref = 'M')
relevel(dds$Drug, ref = 'Sal')
relevel(dds$Cycle, ref = 'One')

dds <- DESeq(dds)

# very few genes are different due to Sex in the IP samples
# (this contrast uses the 'normal' shrinkage estimator, the others 'apeglm')
res.Sex <- lfcShrink(dds, coef = "Sex_M_vs_F", type = 'normal')
res.Sex <- tibble::rownames_to_column(as.data.frame(res.Sex), "GENEID")
res.Sex <- left_join(res.Sex, mm.gene_symbols, by = "GENEID")

# differences due to drug alone?
res.Drug <- lfcShrink(dds, coef = "Drug_Sal_vs_Fent", type = 'apeglm')
plotMA(res.Drug)
# only 1 gene significant
subset(res.Drug, padj < 0.1)

# differences due to Cycle alone?
# coef = 4 selects the fourth coefficient by position
# NOTE(review): presumably the Cycle term -- confirm with resultsNames(dds)
res.Cycle <- lfcShrink(dds, coef = 4, type = 'apeglm')
subset(res.Cycle, padj < 0.1)
plotMA(res.Cycle)



###########################################################################
# RNA Quality control plotting
###########################################################################
# see https://cran.r-project.org/web/packages/RNAseqQC/vignettes/introduction.html


# Counts matrix for the IN-corrected IP samples.
counts.IP <- read.csv('IP_Counts_Minus_IN.csv')
# Drop the ctrl IP samples (columns 2 and 19) and use GENEID as row names.
counts.IP <- as.matrix(column_to_rownames(counts.IP[, -c(2, 19)], "GENEID"))

# Sample metadata for the IP samples.
samples.IP <- read.csv('IP_SampleInformation.csv')

# Build the DESeq2 dataset object.
dds <- make_dds(
  counts = counts.IP,
  metadata = samples.IP,
  ah_record = "AH116340",
  design = ~Group
)

# All comparisons are made relative to the Fent One group.
dds$Group <- relevel(dds$Group, ref = "Fent One")
dds <- DESeq(dds)

# Gene filter: min_count = 5, min_rep = 2.
dds <- filter_genes(dds, min_count = 5, min_rep = 2)

# Library size - all samples have counts with the same order of magnitude.
plot_total_counts(dds)

# Library complexity - one curve per sample. The curves should generally
# overlap, which suggests that the same fraction of counts is taken up by
# the same genes.
plot_library_complexity(dds)

vsd <- vst(dds)

# Sample clustering via euclidean distance or pearson correlation
# suggests that IP_M_F_Five_4267 may be an outlier.
plot_sample_clustering(vsd, distance = 'euclidean',
                       anno_vars = c("Sex", "Group"))

# Variability of each sample compared to the grouping variable:
# each sample's gene expression is compared against the median
# gene expression of all the samples in its group.
ma_plots <- plot_sample_MAs(vsd, group = "Group")
# Only the first six MA plots are shown.
cowplot::plot_grid(plotlist = ma_plots[seq_len(6)], ncol = 2)


# PCA coloured by group, shaped by sex.
plot_pca(vsd, PC_x = 1, PC_y = 2, color_by = "Group", shape_by = "Sex")

# Which genes contribute the most to PC1?
pca_res <- plot_pca(vsd, show_plot = FALSE)
plot_loadings(pca_res, PC = 1, annotate_top_n = 10)

# ... and to PC2?
plot_loadings(pca_res, PC = 2, annotate_top_n = 10)

# running DESeq2
# the model is refit with Sex added to the design and Sal One as reference
dds$Group <- as.factor(dds$Group)
# the Group levels of this data set are written with a space ("Fent One",
# "Sal One", "Sal Five" are used elsewhere for the same metadata), so the
# reference must be 'Sal One'; relevel() fails on a level that does not exist
dds$Group <- relevel(dds$Group, ref = 'Sal One')
dds$Sex <- as.factor(dds$Sex)
design(dds) <- ~ Group + Sex

dds <- DESeq(dds)

# dispersion estimates for each gene versus mean expression of said gene
# the data should generally follow the fitted curve in red
# the dispersion is a measure of the variability in the counts data
# for genes with high counts, the dispersion becomes equal to the coefficient of
# variation for the mean of the count
plotDispEsts(dds)

################################################################################
# functions of note
################################################################################

# perform TPM conversion for a list of gene counts
# must have matching data for ensembl id for each gene
#
# df     : data frame with columns GENEID and baseMean
# tx_obj : list with a `length` matrix whose row names are gene IDs
#          (defaults to the global tximport object `txi`)
#
# Returns a numeric vector with one TPM value per row of `df`, in the order
# of `df` and named by GENEID. Each baseMean is paired with the mean length
# of the same gene ID (not with whatever length happens to sit at the same
# position); genes without a length in `tx_obj` get NA and are left out of
# the per-million scaling.
toTPM <- function(df, tx_obj = txi) {
  avg_lengths <- rowMeans(tx_obj$length)
  gene_list <- df$baseMean
  names(gene_list) <- df$GENEID
  # look up each gene's length by ID; unmatched IDs yield NA
  matched_lengths <- avg_lengths[match(names(gene_list), names(avg_lengths))]
  rpk <- gene_list / (unname(matched_lengths) / 1000)
  scalingFactor <- sum(rpk, na.rm = TRUE) / 1e6
  rpk / scalingFactor
}


# convert a DESeqResults object to a data frame
#
# deseqres : a DESeqResults object (anything data.frame() accepts whose row
#            names are gene IDs)
# gene2sym : lookup table with a GENEID column and the annotation columns to
#            add (defaults to the global `mm.gene_symbols`)
#
# Returns a data frame with the row names moved to a GENEID column and the
# columns of `gene2sym` joined on by GENEID. The `gene2sym` argument is now
# actually used; previously the global table was always joined.
reslfc2df <- function(deseqres, gene2sym = mm.gene_symbols) {
  data.frame(deseqres) %>%
    rownames_to_column("GENEID") %>%
    left_join(gene2sym, by = "GENEID")
}

###############################################################################
# plots of log2foldchange versus transcripts per kilobase million
###############################################################################

# mean length per gene, keyed by gene ID
# `counts` has been turned into a matrix by this point, so `counts$GENEID`
# is not available; the gene IDs are taken from the names of `lengths`
length.means <- data.frame('GENEID' = names(lengths),
                           'meanLength' = unname(lengths))

# gene module key value pairs from WGCNA
# gives the module for each gene
genes2colors <- read.csv('WGCNA_genes2colors.csv')

# getting results for Sal Five v Sal One
dds <- DESeqDataSetFromMatrix(counts.IP, samples.IP, design = ~ Group + Sex)
dds <- filter_genes(dds, min_count = 5, min_rep = 2)
dds$Group <- relevel(dds$Group, ref = 'Sal One')
dds <- DESeq(dds)

# Saline Five vs Saline One
res.SalFive_vs_SalOne <- results(dds,
                                contrast = c('Group', 'Sal Five', 'Sal One'))
res.SalFive_vs_SalOne <- as.data.frame(res.SalFive_vs_SalOne) %>%
  rownames_to_column('GENEID') %>%
  left_join(length.means, by = 'GENEID')

# TPM from the base mean: reads per kilobase, scaled to sum to one million
# (genes without a length are ignored in the scaling)
res.SalFive_vs_SalOne$rpk <- res.SalFive_vs_SalOne$baseMean /
  (res.SalFive_vs_SalOne$meanLength/1000)
scalingFactor <- sum(res.SalFive_vs_SalOne$rpk, na.rm = TRUE)/1e6
res.SalFive_vs_SalOne$TPM <- res.SalFive_vs_SalOne$rpk/scalingFactor
res.SalFive_vs_SalOne <- left_join(res.SalFive_vs_SalOne,
                                   mm.gene_symbols, by = 'GENEID')

# number of genes meeting the following two significance cut off limits
# padj < 0.01
# abs val of log2foldchange > 0.58 or ~ 1.5 times original base mean
# (43 genes were reported by the earlier version of this script)
# The test is evaluated directly on padj and log2FoldChange; na.omit() on the
# whole table also discarded significant genes that merely lack a SYMBOL.
sum(res.SalFive_vs_SalOne$padj < 0.01 &
      abs(res.SalFive_vs_SalOne$log2FoldChange) > 0.58, na.rm = TRUE)

# plotting the values (eliminates the two highest points on the y axis)
ggplot(data = res.SalFive_vs_SalOne,
       aes(x = log2(TPM), y = log2FoldChange, color = padj)) +
  geom_point(size = 3, show.legend = FALSE) +
  gghighlight(abs(log2FoldChange) > 0.58 &
                padj < 0.01,
              unhighlighted_params = list(color = NULL, alpha = 0.5)) +
  scale_color_gradientn(colors = myPalette(8)) +
  ylim(c(-8, 10))

# Fentanyl Five vs Saline Five
dds <- DESeqDataSetFromMatrix(counts.IP, samples.IP, design = ~Group)
dds <- filter_genes(dds, min_count = 5, min_rep = 2)
dds$Group <- relevel(dds$Group, ref = 'Sal One')
dds <- DESeq(dds)
res.FentFive_vs_SalineFive <- results(dds,
                                   contrast = c('Group', 'Fent Five', 'Sal Five'))

res.FentFive_vs_SalineFive <- as.data.frame(res.FentFive_vs_SalineFive) %>%
  rownames_to_column('GENEID') %>%
  left_join(length.means, by = 'GENEID')
# TPM from the base mean: reads per kilobase, scaled to sum to one million
# (genes without a length are ignored in the scaling)
res.FentFive_vs_SalineFive$rpk <- res.FentFive_vs_SalineFive$baseMean /
  (res.FentFive_vs_SalineFive$meanLength/1000)
scalingFactor <- sum(res.FentFive_vs_SalineFive$rpk, na.rm = TRUE)/1e6
res.FentFive_vs_SalineFive$TPM <- res.FentFive_vs_SalineFive$rpk/scalingFactor
res.FentFive_vs_SalineFive <- left_join(res.FentFive_vs_SalineFive,
                                     mm.gene_symbols, by = 'GENEID')
# number of genes meeting sig (1 gene was reported by the earlier version)
# padj < 0.01
# abs val of log2foldchange > 0.58 or ~ 1.5 times original base mean
# The fold-change cut-off is 0.58 here as well, matching the comment above
# and the highlight in the plot below (it used to be 0.6 in this one place).
# Counting on the two columns directly avoids dropping genes that only lack
# a SYMBOL.
sum(res.FentFive_vs_SalineFive$padj < 0.01 &
      abs(res.FentFive_vs_SalineFive$log2FoldChange) > 0.58, na.rm = TRUE)

ggplot(data = res.FentFive_vs_SalineFive,
       aes(x = log2(TPM), y = log2FoldChange, color = padj)) +
  geom_point(size = 3, show.legend = FALSE) +
  gghighlight(abs(log2FoldChange) > 0.58 &
                padj < 0.01,
              unhighlighted_params = list(color = NULL, alpha = 0.5)) +
  scale_color_gradientn(colors = myPalette(8)) +
  ylim(c(-8, 10))

# saves the plot drawn just above (ggsave defaults to the last plot)
ggsave(filename = 'C:/Users/David.SIBCR-1081/Desktop/Figures/IP_FentFive_vs_SalFive.png',
       dpi = 300)

# Fentanyl Five vs Fentanyl One

dds <- DESeqDataSetFromMatrix(counts.IP, samples.IP, design = ~Group)
dds <- filter_genes(dds, min_count = 5, min_rep = 2)
dds$Group <- relevel(dds$Group, ref = 'Sal One')
dds <- DESeq(dds)
res.FentFive_vs_FentOne <- results(dds,
                                 contrast = c('Group', 'Fent Five', 'Fent One'))

# add mean gene length and the WGCNA module of each gene
# (the gene lengths are joined once; the second, redundant join was removed)
res.FentFive_vs_FentOne <- data.frame(res.FentFive_vs_FentOne) %>%
  rownames_to_column("GENEID") %>%
  left_join(length.means, by = 'GENEID') %>%
  left_join(genes2colors)

# TPM from the base mean: reads per kilobase, scaled to sum to one million
# (genes without a length are ignored in the scaling)
res.FentFive_vs_FentOne$rpk <- res.FentFive_vs_FentOne$baseMean /
  (res.FentFive_vs_FentOne$meanLength/1000)
scalingFactor <- sum(res.FentFive_vs_FentOne$rpk, na.rm = TRUE)/1e6
res.FentFive_vs_FentOne$TPM <- res.FentFive_vs_FentOne$rpk/scalingFactor
res.FentFive_vs_FentOne <- left_join(res.FentFive_vs_FentOne,
                                   mm.gene_symbols, by = 'GENEID')
# number of genes meeting sig (5613 were reported by the earlier version)
# padj < 0.01
# abs val of log2foldchange > 0.58 or ~ 1.5 times original base mean
# Counting on the two columns directly: na.omit() on the whole table also
# removed significant genes without a WGCNA module or SYMBOL.
sum(res.FentFive_vs_FentOne$padj < 0.01 &
      abs(res.FentFive_vs_FentOne$log2FoldChange) > 0.58, na.rm = TRUE)


ggplot(data = res.FentFive_vs_FentOne,
       aes(x = log2(TPM), y = log2FoldChange, color = padj)) +
  geom_point(size = 3, show.legend = FALSE) +
  gghighlight(abs(log2FoldChange) > 0.58 &
                padj < 0.01,
              unhighlighted_params = list(color = NULL, alpha = 0.5)) +
  scale_color_gradientn(colors = myPalette(8)) +
  ylim(c(-8, 10))

# saves the plot drawn just above (ggsave defaults to the last plot)
ggsave(filename = 'C:/Users/David.SIBCR-1081/Desktop/Figures/IP_FentFive_vs_FentOne.png',
       dpi = 300)

# Fentanyl One vs Saline One
# (uses the dds object fitted for the previous comparison)

res.FentOne_vs_SalOne <- results(dds,
                                 contrast = c('Group', 'Fent One', 'Sal One')) %>%
  as.data.frame() %>%
  rownames_to_column('GENEID') %>%
  left_join(length.means)

# TPM from the base mean: reads per kilobase, scaled to sum to one million
res.FentOne_vs_SalOne$rpk <- with(res.FentOne_vs_SalOne,
                                  baseMean / (meanLength/1000))
scalingFactor <- sum(res.FentOne_vs_SalOne$rpk)/1e6
res.FentOne_vs_SalOne$TPM <- res.FentOne_vs_SalOne$rpk/scalingFactor

# only 1 gene with padj < 0.01 and abs(log2fc) > 0.58
na.omit(res.FentOne_vs_SalOne[with(res.FentOne_vs_SalOne,
                                   padj < 0.01 & abs(log2FoldChange) > 0.58), ])

ggplot(res.FentOne_vs_SalOne,
       aes(x = log2(TPM), y = log2FoldChange, color = padj)) +
  geom_point(size = 3, show.legend = FALSE) +
  gghighlight(abs(log2FoldChange) > 0.58 & padj < 0.01,
              unhighlighted_params = list(color = NULL, alpha = 0.5)) +
  scale_color_gradientn(colors = myPalette(8)) +
  ylim(c(-8, 10))

# writes the plot drawn just above (ggsave defaults to the last plot)
ggsave(filename = 'C:/Users/David.SIBCR-1081/Desktop/Figures/IP_FentOne_vs_SalOne.png',
       dpi = 300)

# enrichment of IP genes
# ratio of IP counts to IN counts

enrichment <- data.frame(counts.IP) %>%
  rownames_to_column('GENEID') %>%
  left_join(mm.gene_symbols, by = 'GENEID')

enrichment$SYMBOL <- toupper(enrichment$SYMBOL)

# Each entry pairs the IP sample columns of `enrichment` with the IN columns
# of `inputs` whose row mean is used as the denominator. The groups follow
# the same order as the contamination-correction step.
enrichment_groups <- list(
  list(ip = 2:5,   input = 3:5),
  list(ip = 6:10,  input = 6:8),
  list(ip = 11:12, input = 9:13),
  list(ip = 13:17, input = 14:16),
  list(ip = 18:20, input = 18:19),
  list(ip = 21:24, input = 20:23),
  list(ip = 25:26, input = 24:27),
  list(ip = 27:31, input = 28:32)
)

# rows of `inputs` are matched to `enrichment` by gene ID so every IP count
# is divided by the input mean of the same gene
input_rows <- match(enrichment$GENEID, inputs$GENEID)

# The previous `for (col in enrichment[, a:b]) col = ...` loops only changed
# a copy of each column, so the table still held raw counts. The ratios are
# now written back into the data frame.
# NOTE(review): genes with an input mean of 0 produce Inf/NaN ratios
for (grp in enrichment_groups) {
  input_mean <- rowMeans(inputs[input_rows, grp$input])
  enrichment[grp$ip] <- lapply(enrichment[grp$ip],
                               function(col) col / input_mean)
}

# split the sample names into their fields after removing the "IP_" prefix
names(enrichment) <- gsub(pattern = "IP_", replacement = "", x = names(enrichment))
enrichment.long <- pivot_longer(enrichment, cols = 2:31,
                                names_to = c("Sex", "Treatment", "Cycle", "Mouse"),
                                names_sep = "_")

enrichment.long$Treatment[enrichment.long$Treatment == "F"] <- "Fent"
enrichment.long$Treatment[enrichment.long$Treatment == "S"] <- "Sal"

# Group = "<Treatment>.<Cycle>"
enrichment.long$Group <- paste0(enrichment.long$Treatment, ".", enrichment.long$Cycle)
# rename column 7 of the marker table so it joins on SYMBOL
names(mckenzie)[7] <- 'SYMBOL'
enrichment.long <- left_join(enrichment.long, mckenzie[, 7:8])

# Mean enrichment (+/- SEM) per cell type and group.
ggplot(data_summary(enrichment.long, 'value', c('Group', 'Celltype')),
       aes(x = Celltype, y = value, color = Group,
           ymin = value - se, ymax = value + se)) +
  geom_point(position = position_dodge(width = 0.3), size = 3) +
  geom_errorbar(position = position_dodge(width = 0.3)) +
  scale_color_manual(values = c('tan', 'red', 'salmon', 'turquoise', 'cyan'))


# Same data as box plots (outliers hidden) ...
ggplot(enrichment.long, aes(x = Celltype, y = value, fill = Group)) +
  geom_boxplot(outliers = FALSE, position = position_dodge()) +
  scale_fill_manual(values = c('tan', 'red', 'salmon', 'turquoise', 'cyan'))

# ... and as violins.
ggplot(enrichment.long, aes(x = Celltype, y = value, fill = Group)) +
  geom_violin(stat = 'identity') +
  scale_fill_manual(values = c('tan', 'red', 'salmon', 'turquoise', 'cyan'))


# Per gene symbol and group: mean, SEM, min, max and median enrichment.
enrichment.summary <- dplyr::summarize(
  dplyr::group_by(enrichment.long, SYMBOL, Group),
  mean.enrichment = mean(value, na.rm = TRUE),
  sem = (sd(value, na.rm = TRUE) / sqrt(n())),
  min = min(value, na.rm = TRUE),
  max = max(value, na.rm = TRUE),
  med = median(value, na.rm = TRUE)
)

# Drop the rows whose SYMBOL is the empty string.
has_symbol <- !(enrichment.summary$SYMBOL %in% "")
enrichment.summary <- enrichment.summary[has_symbol, ]

# Per-group bar chart of one gene's values in counts.long.
#
# gene_name : gene symbol; it is matched against counts.long$SYMBOL after
#             upper-casing
#
# Returns a ggplot object: bars show the group mean, the error bars come
# from the 'summary' stat with its default summary function, and the
# individual samples are drawn as jittered points.

bars <- function(gene_name) {
  group_colors <- c('tan', 'red', 'salmon', 'turquoise', 'cyan')
  gene_label <- paste0(gene_name)
  gene_rows <- counts.long[counts.long$SYMBOL == toupper(gene_label), ]

  ggplot(gene_rows, aes(x = Group, y = Count, fill = Group)) +
    geom_bar(stat = 'summary', fun = mean, alpha = 0.6) +
    geom_point(aes(color = Group), stat = 'identity',
               position = position_jitter(width = 0.1, height = 0.1),
               size = 4) +
    geom_errorbar(stat = 'summary',
                  width = 0.4, alpha = 0.9, linewidth = 1.3) +
    scale_fill_manual(values = group_colors) +
    scale_color_manual(values = group_colors) +
    ggtitle(gene_label) +
    xlab('') +
    ylab('')
}

bars("P2ry12")


################################################################################
# plotting enrichment related to genes of interest from various modules
# derived from WGCNA
enrichment <- read.csv('enrichment_IP.csv')

# mean enrichment per gene across the sample columns 2 to 31
enrichment.means <- rowMeans(enrichment[, 2:31])

enrichment2 <- data.frame(ratio = enrichment.means, SYMBOL = enrichment$SYMBOL)

sig <- read.csv('significant_from_WGCNA.csv')

# symbols are upper-cased so they match the enrichment table
sig$SYMBOL <- toupper(sig$SYMBOL)
enrichment2 <- left_join(enrichment2, sig[, 1:13])

# keep only complete rows
enrichment2 <- drop_na(enrichment2)

# mean ratio and SEM per module
# dplyr::summarise is called explicitly so the grouped summary still works
# when another attached package (e.g. plyr) masks summarise(); se is computed
# before `ratio` is overwritten by its mean
en.summary <- enrichment2 %>%
  dplyr::group_by(mergedlabels) %>%
  dplyr::summarise(
    se = sd(ratio, na.rm = TRUE)/sqrt(n()),
    ratio = mean(ratio))

# each module is filled with the colour it is named after; the values are
# named so the colour/module pairing does not depend on ordering
module_colors <- setNames(as.character(en.summary$mergedlabels),
                          en.summary$mergedlabels)

ggplot(en.summary) +
  geom_bar(data = en.summary,
           aes(x = mergedlabels, y = ratio, fill = mergedlabels),
           alpha = 0.7,
           stat = 'identity', position = position_dodge2()) +
  geom_errorbar(data = en.summary,
                aes(x = mergedlabels, ymin = ratio - se,
                    ymax = ratio + se),
                position = position_dodge2(),
                linewidth = 1) +
  scale_fill_manual(values = module_colors)
################################################################################
# volcano plots

################################################################################

# classify every gene as UP / DOWN / NO using |log2FC| > 0.6 and padj < 0.01
# (rows with a missing padj or fold change stay "NO")
res.FentFive_vs_FentOne$diffexpr <- "NO"
res.FentFive_vs_FentOne$diffexpr[res.FentFive_vs_FentOne$log2FoldChange > 0.6 &
                                    res.FentFive_vs_FentOne$padj < 0.01] <- "UP"
res.FentFive_vs_FentOne$diffexpr[res.FentFive_vs_FentOne$log2FoldChange < -0.6 &
                                   res.FentFive_vs_FentOne$padj < 0.01] <- "DOWN"

# label column: SYMBOL for genes whose symbol is among the 30 smallest padj
res.FentFive_vs_FentOne$delabel <- ifelse(res.FentFive_vs_FentOne$SYMBOL %in%
  head(res.FentFive_vs_FentOne[order(res.FentFive_vs_FentOne$padj), "SYMBOL"], 30),
  res.FentFive_vs_FentOne$SYMBOL, NA)

# the colours are named by category so that each category keeps its colour
# even when one of them (e.g. "DOWN") does not occur in the data
ggplot(data = res.FentFive_vs_FentOne,
       aes(x = log2FoldChange, y = -log10(padj), col = diffexpr)) +
  geom_vline(xintercept = c(-0.6, 0.6), col = "gray", linetype = 'dashed') +
  geom_hline(yintercept = -log10(0.01), col = "gray", linetype = 'dashed') +
  geom_point(show.legend =  FALSE) +
  scale_color_manual(values = c("DOWN" = "#00AFBB", "NO" = "grey",
                                "UP" = "#bb0c00"),
                     expand = TRUE,
                     labels = NULL) +
  coord_cartesian(ylim = c(0, 15), xlim = c(-5, 5))


################################################################################
# input samples analysis
################################################################################

# remove control samples
inputs <- inputs[, c(-2, -17)] %>% column_to_rownames("GENEID") %>%
  as.matrix()
mode(inputs) <- 'integer'

# metedata from the samples  
samples.IN <- read.csv('IN_SampleInformation.csv')

# create DESeq2 dataset object

dds <- make_dds(counts = inputs, metadata = samples.IN,
                ah_record = "AH116340", design = ~Group)

# make comparisons to Sal One group
dds$Group <- relevel(dds$Group, ref = "Sal One")
dds <- DESeq(dds)

# remove genes with counts < 5 and at least two replicates
dds <- filter_genes(dds, min_count = 5, min_rep = 2)

# library size: total counts per sample should all be of the same order of
# magnitude
dds %>% plot_total_counts()

# library complexity: one curve per sample; the curves should largely
# overlap, which suggests that the same fraction of counts is taken up by
# the same genes in every sample
dds %>% plot_library_complexity()

# variance-stabilised counts used by the sample-level plots below
vsd <- dds %>% vst()

# sample clustering via euclidean distance or pearson correlation
# the original analysis notes that IP_M_F_Five_4267 may be an outlier
# NOTE(review): that is an IP sample name inside the input-sample section -
# confirm the note applies to this data set
vsd %>% plot_sample_clustering(
  distance = "euclidean",
  anno_vars = c("Sex", "Group")
)

# per-sample MA plots: each sample's gene expression is compared against the
# median gene expression of all samples in its Group
ma_plots <- vsd %>% plot_sample_MAs(group = "Group")
# show only the first six MA plots, two per row
cowplot::plot_grid(plotlist = ma_plots[seq_len(6)], ncol = 2)


# pca on the first two components, coloured by Group and shaped by Sex
vsd %>% plot_pca(
  color_by = "Group",
  shape_by = "Sex",
  PC_x = 1,
  PC_y = 2
)

# mean transcript length per gene, averaged over the columns of txi$length
length.means <- data.frame('meanLength' = rowMeans(txi$length)) %>%
  rownames_to_column('GENEID')

# Extract one Group contrast from a fitted DESeq2 object and annotate it.
#
# dds          fitted DESeqDataSet whose design contains the Group variable
# numerator    Group level in the numerator of the fold change
# denominator  Group level in the denominator of the fold change
# gene_lengths data frame with columns GENEID and meanLength
# gene_symbols data frame keyed by GENEID holding the gene symbols
#
# Returns a data frame with one row per gene: the DESeq2 result columns,
# meanLength, rpk (baseMean per kilobase of mean length), TPM (rpk rescaled
# so that the column sums to one million) and the gene symbol columns.
annotate_in_results <- function(dds, numerator, denominator,
                                gene_lengths = length.means,
                                gene_symbols = mm.gene_symbols) {
  res <- results(dds, contrast = c("Group", numerator, denominator))
  res <- data.frame(res) %>%
    rownames_to_column("GENEID")

  # explicit key: no "Joining with `by = ...`" message and no accidental
  # join on an unrelated shared column
  res <- left_join(res, gene_lengths, by = "GENEID")

  res$rpk <- res$baseMean / (res$meanLength / 1000)
  # na.rm: a gene without a length would otherwise make the scaling factor,
  # and with it every TPM value, NA
  scaling_factor <- sum(res$rpk, na.rm = TRUE) / 1e6
  res$TPM <- res$rpk / scaling_factor

  left_join(res, gene_symbols, by = "GENEID")
}

# results from DESeq for inputs sal five vs saline one
resIN.SalFive_vs_SalOne <- annotate_in_results(dds, "Sal Five", "Sal One")

# inputs fent five vs sal five
resIN.FentFive_vs_SalFive <- annotate_in_results(dds, "Fent Five", "Sal Five")

# inputs fent five vs fent one
resIN.FentFive_vs_FentOne <- annotate_in_results(dds, "Fent Five", "Fent One")

# inputs fent one vs sal one
resIN.FentOne_vs_SalOne <- annotate_in_results(dds, "Fent One", "Sal One")


# compare overlap in genes that are significant in comparisons between IN
# and the IP samples

# Gene IDs of the rows in a results table that pass the significance
# requirements.
#
# res          results data frame with GENEID, log2FoldChange and padj columns
# lfc_cutoff   minimum absolute log2 fold change (exclusive)
# padj_cutoff  maximum adjusted p-value (exclusive)
#
# Returns a vector of GENEIDs. which() drops rows whose padj or
# log2FoldChange is NA, so no NA handling is needed afterwards, and the ID
# column is taken by name rather than by position.
sig_gene_ids <- function(res, lfc_cutoff = 0.6, padj_cutoff = 0.01) {
  keep <- which(abs(res$log2FoldChange) > lfc_cutoff &
                  res$padj < padj_cutoff)
  res$GENEID[keep]
}

# example with IP comparisons of FentFive_vs_FentOne and SalFive_vs_SalOne
# get genes in res.FentFive_vs_FentOne matching sig requirements
left <- sig_gene_ids(res.FentFive_vs_FentOne)

# rows of res.SalFive_vs_SalOne whose gene is significant in both comparisons
temp <- res.SalFive_vs_SalOne %>%
  dplyr::filter(GENEID %in% left,
                GENEID %in% sig_gene_ids(res.SalFive_vs_SalOne))

# list all the genes shared between the two data sets
temp$SYMBOL


# making a venn diagram of the significant gene sets of four comparisons
x <- list(
  'IP Fent Five vs Fent One' = sig_gene_ids(res.FentFive_vs_FentOne),
  'IP Sal Five vs Sal One' = sig_gene_ids(res.SalFive_vs_SalOne),
  'IN Sal Five vs Sal One' = sig_gene_ids(resIN.SalFive_vs_SalOne),
  'IN Fent Five vs Fent One' = sig_gene_ids(resIN.FentFive_vs_FentOne)
)


ggvenn(x, fill_color = c("#0073C2FF", "#EFC000FF", "#868686FF", "#CD534CFF"))


