#!/usr/bin/env python
"""
Merge multiple FASTQ files based on a filename pattern.

Sourced from https://github.com/lazappi/binf-scripts/blob/master/mergeFastqs.py
Altered so the FASTQ file names are given via the -f/--fastqs argument,
which defaults to a glob of '*.fastq.gz' in the current directory.
"""

import argparse
import logging
import os
import glob
import sys
import shutil
from collections import defaultdict

def get_args():
    """
    Get command line arguments.

    The defaults for the output directory (current working directory) and
    the input files (every '*.fastq.gz' in the current working directory)
    are evaluated when the parser is built.

    Returns:
        argparse.Namespace with the attributes `separator`, `outdir`,
        `fastqs` (list of filenames) and `pattern`.
    """

    parser = argparse.ArgumentParser(description="Merge FASTQ files based on name")
    parser.add_argument("-s", "--separator",
                        help="Filename separator. Default is '_'.",
                        default="_")
    parser.add_argument("-o", "--outdir",
                        help="Path to output directory for merged files.",
                        default=os.getcwd())
    parser.add_argument("-f", "--fastqs",
                        nargs="+",
                        help="FASTQ files to merge",
                        default=glob.glob('*.fastq.gz'))
    # Every fragment ends with a space so the concatenated help text does not
    # run words together. Sections marked 'M' are dropped from the merged
    # name, so the example keeps only SAMPLE and READ.
    parser.add_argument("-p", "--pattern",
                        help="Pattern used to decide which files to merge. " +
                        "Should have the format [CODE][SEP]...[CODE], " +
                        "where SEP is the separator and CODE is one of: " +
                        "K = Keep this section or " +
                        "M = Merge this section. " +
                        "For example if the filename structure was: " +
                        "'SAMPLE_READ_LANE_DATE.fastq', " +
                        "to merge on LANE and DATE the pattern " +
                        "would be 'K_K_M_M', which would produce a " +
                        "merged file named 'SAMPLE_READ.fastq'.",
                        required=True)
    args = parser.parse_args()

    return args


def setup_logging(outdir):
    """
    Configure the "mergeFQs" logger.

    Every message (DEBUG and up) is written to 'mergeFastqs.log' inside the
    output directory, which is created when it does not exist yet. Messages
    of level INFO and up are additionally sent to a console handler using a
    shorter format.

    Args:
        outdir: Output directory
    """

    merge_logger = logging.getLogger("mergeFQs")
    merge_logger.setLevel(logging.DEBUG)

    if not os.path.exists(outdir):
        os.makedirs(outdir)

    # File handler: full detail, including debug messages
    to_file = logging.FileHandler(os.path.join(outdir, "mergeFastqs.log"))
    to_file.setLevel(logging.DEBUG)
    to_file.setFormatter(
        logging.Formatter("[%(asctime)s] %(levelname)s %(name)s: %(message)s",
                          "%Y-%m-%d %H:%M:%S"))

    # Console handler: terse output, INFO and above only
    to_console = logging.StreamHandler()
    to_console.setLevel(logging.INFO)
    to_console.setFormatter(
        logging.Formatter("[%(asctime)s] %(message)s", "%H:%M:%S"))

    # Register the console handler first, then the file handler
    for handler in (to_console, to_file):
        merge_logger.addHandler(handler)


def merge_filename(filename, pat, sep):
    """
    Apply a merging pattern to a filename.

    The directory part of `filename` is discarded. The base name (everything
    before the first '.') is split on `sep`, and only the sections whose
    pattern code is "K" are kept. Everything after the first '.' is treated
    as the extension and re-attached unchanged; a name without any '.' gets
    no trailing dot.

    Args:
        filename: The filename to merge.
        pat: The merging pattern to use, one code per filename section.
        sep: String separating filename sections.

    Return:
        String containing the merged filename.

    Raises:
        SystemExit: If the number of filename sections differs from the
            number of pattern codes.
    """

    # Split on the first dot only; 'os.path.splitext' is not used because
    # the extension could have multiple parts (e.g. '.fastq.gz').
    filebase, dot, ext = os.path.basename(filename).partition(".")

    sections = filebase.split(sep)

    # Exit if filename does not match pattern
    if len(sections) != len(pat):
        sys.exit("File " + filename + " does not match pattern " + str(pat))

    kept = [section for code, section in zip(pat, sections) if code == "K"]

    # 'dot' is empty when the name had no extension, so nothing is appended.
    return sep.join(kept) + dot + ext


def group_filenames(filenames, pat, sep):
    """
    Collect files that share the same merged file name.

    Each filename is passed through `merge_filename`; files producing the
    same result end up in the same group, in the order they were given.

    Args:
        filenames: List of filename strings to group.
        pat: Merging pattern to use for grouping.
        sep: String separating filename sections.

    Returns:
        Dictionary (a `defaultdict(list)`) mapping each merged file name to
        the list of original filenames that belong to it.
    """

    by_merged_name = defaultdict(list)

    for path in filenames:
        by_merged_name[merge_filename(path, pat, sep)].append(path)

    return by_merged_name


def merge_files(groups, outdir):
    """
    Merge files that belong to the same filename group.

    For every group the member files are concatenated byte-for-byte, in
    list order, into a file named after the group inside the output
    directory. An existing file at that path is overwritten.

    Args:
        groups: Dictionary of filename groups from `group_filenames`.
        outdir: Output path for merged files.

    Raises:
        ValueError: If the merged file of a group would be written over one
            of that group's own input files.
    """

    logger = logging.getLogger("mergeFQs." + "merge")

    for groupname, filenames in groups.items():
        logger.info("Merging group %s with %d files...",
                    groupname, len(filenames))
        outpath = os.path.join(outdir, groupname)

        # Opening the output with "wb" truncates it immediately, so an input
        # that resolves to the same path would be emptied before it is read.
        # Refuse up front instead of destroying the input.
        resolved_out = os.path.normcase(os.path.realpath(outpath))
        for filename in filenames:
            if os.path.normcase(os.path.realpath(filename)) == resolved_out:
                raise ValueError("Merge file " + outpath + " would overwrite " +
                                 "input file " + filename)

        logger.info("Creating merge file %s...", outpath)
        with open(outpath, "wb") as outfile:
            for filename in filenames:
                logger.info("Adding %s...", filename)
                with open(filename, "rb") as fq_file:
                    shutil.copyfileobj(fq_file, outfile)


def main():
    """
    Run main code

    1. Get arguments
    2. Setup logging
    3. Group filenames
    4. Merge files

    Exits with status 1 when there are no input files to merge.
    """

    args = get_args()

    setup_logging(args.outdir)
    logger = logging.getLogger("mergeFQs." + __name__)

    logger.info(str(len(args.fastqs)) + " input files provided")

    # The default file list comes from a glob and may be empty; stop here
    # rather than failing with an IndexError on the example lookup below.
    if not args.fastqs:
        logger.error("No input FASTQ files to merge")
        sys.exit(1)

    logger.info("Filename pattern is " + args.pattern)
    pattern = args.pattern.split(args.separator)

    # Show how the first file would be renamed, as a sanity check for the user
    ex_file = args.fastqs[0]
    ex_merge = merge_filename(ex_file, pattern, args.separator)
    logger.info("Example merge: " + ex_file + " -> " +
                os.path.join(args.outdir, ex_merge))

    file_groups = group_filenames(args.fastqs, pattern, args.separator)
    logger.info(str(len(file_groups)) + " file groups found...")
    merge_files(file_groups, args.outdir)


# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
