#include "./structs.h"
#include "./functions.h"
#include <R_ext/Arith.h>
#include <R_ext/Error.h>
#include <R_ext/Print.h>
#include <Rinternals.h>
#include <stddef.h>
#include <string.h>

#define STB_DS_IMPLEMENTATION
#include "lib/stb_ds.h"

static void get_type_name(char *buf, unsigned int type, int maxlen);

// Note for future development: R4.5.0 introduced Rf_isDataFrame. Switch all Rf_isFrame over to that once backwards compatibility is no longer a concern.

/**
 * Compare the column types of `target` against those of `type_template`,
 * optionally casting mismatched columns and logging what was found.
 *
 * @param type_template List (VECSXP) whose named elements define the expected type of each field.
 * @param target        List (VECSXP) whose named elements are checked. It is not modified; a duplicate is returned.
 * @param with_cast     Logical; element 0 enables casting of mismatched columns.
 * @param log_items     Integer; element 0 is a bit field: 1 = casts/mismatches, 2 = missing fields,
 *                      4 = excess fields, 8 = debug output.
 * @param call_path     Character; element 0 is printed as the "In ...:" header before the first log line.
 * @param ldns          Environment in which the `as_datetime` call is evaluated.
 * @return A duplicate of `target` with any casts applied.
 *
 * Raises an R error if either list argument is not a list, if a non-empty list has no names,
 * or if one of the option vectors is empty.
 */
SEXP type_check(SEXP type_template, SEXP target, SEXP with_cast, SEXP log_items, SEXP call_path, SEXP ldns) {
    _Bool log_started = 0; // Print out the calling function when we first log.
    // Validate the types on our dataframes.
    if (TYPEOF(type_template) != VECSXP || TYPEOF(target) != VECSXP) {
        Rf_error("Both template and target must be lists.");
    }
    // Element 0 of each option vector is read below, so an empty vector must be rejected first.
    if (Rf_xlength(with_cast) < 1 || Rf_xlength(log_items) < 1 || Rf_xlength(call_path) < 1) {
        Rf_error("with_cast, log_items and call_path must each have at least one element.");
    }
    // First, parse everything from the arguments.
    _Bool use_cast = LOGICAL(with_cast)[0];
    unsigned int log_flag = (unsigned int) INTEGER(log_items)[0]; // Log levels in a bit field: 1: casts, 2: missing, 4: excess, 8: debug
    const char *call_string = CHAR(STRING_ELT(call_path, 0));
    if (log_flag & 8) {
        if (!log_started) {
            log_started = 1;
            Rprintf("In %s:\n", call_string);
        }
        Rprintf("\tCast: %d; Log Level: %u; From: %s\n", use_cast, log_flag, call_string);
    }
    SEXP output = PROTECT(Rf_duplicate(target)); // Rf_duplicate makes a deep copy, so this should suffice for output.
    PROTECT(type_template);
    R_xlen_t n_col = Rf_xlength(target);
    R_xlen_t n_template_col = Rf_xlength(type_template);
    // Only look at the first column when there is one; a zero-column target has no rows to speak of.
    R_xlen_t n_target_row = n_col > 0 ? Rf_xlength(VECTOR_ELT(target, 0)) : 0;
    SEXP target_names = PROTECT(Rf_getAttrib(output, R_NamesSymbol));
    SEXP template_names = PROTECT(Rf_getAttrib(type_template, R_NamesSymbol));
    // Names are used as hash keys below. Check them before anything is allocated outside of R's
    // memory management, since Rf_error does not return and would leave that memory unreleased.
    if ((n_col > 0 && TYPEOF(target_names) != STRSXP) ||
        (n_template_col > 0 && TYPEOF(template_names) != STRSXP)) {
        Rf_error("Both template and target must have names.");
    }
    // Construct an stb_ds hash table of template names and types.
    // We use the integer value of the type that we've received and another flag for classes that warrant attention, currently only POSIXct.
    struct type_entry *template_dict = NULL;
    R_xlen_t *excess_arr = NULL;
    for (R_xlen_t i = 0; i < n_template_col; ++i) {
        struct type_container entry = {0};
        SEXP col = VECTOR_ELT(type_template, i);
        entry.primitive = TYPEOF(col);
        SEXP class_attr = PROTECT(Rf_getAttrib(col, R_ClassSymbol)); // Hold onto this; col doesn't matter.
        if (log_flag & 8) {
            char buf[20] = {0};
            get_type_name(buf, entry.primitive, 20);
            Rprintf("\tTemplate field %s class(es): %s ", CHAR(STRING_ELT(template_names, i)), buf);
        }
        if (class_attr != R_NilValue) {
            int n_class = Rf_length(class_attr);
            for (int j = 0; j < n_class; ++j) {
                const char *class_name = CHAR(STRING_ELT(class_attr, j));
                if (log_flag & 8) { Rprintf("%s ", class_name); }
                // Dates are our only special case: we flag if our input is a POSIXct.
                if (strcmp("POSIXct", class_name) == 0) {
                    entry.class_flag |= 1;
                }
            }
        }
        if (log_flag & 8) { Rprintf("\n"); }
        // The key points into template_names, which stays protected until we return.
        shput(template_dict, CHAR(STRING_ELT(template_names, i)), entry);
        UNPROTECT(1); // Drop the class attributes.
    }
    // Now iterate over the columns of the target and see what is to be done.
    // NOTE(review): Rf_eval and Rf_coerceVector below can raise an R error; if that happens the
    // stb_ds allocations (template_dict, excess_arr) are not released.
    for (R_xlen_t i = 0; i < n_col; ++i) {
        const char *curr = CHAR(STRING_ELT(target_names, i));
        if (log_flag & 8) {
            Rprintf("Handling key %s\n", curr);
        }
        R_xlen_t idx = shgeti(template_dict, curr); // R_xlen_t is a ptrdiff_t, as is the return of shgeti.
        if (idx == -1) { // Field not in template.
            if (log_flag & 4) {
                arrput(excess_arr, i);
            }
            continue;
        }
        // Check the contents of the match with guaranteed content.
        struct type_entry *match_entry = &template_dict[idx];
        match_entry->value.matched = 1; // Note that we've encountered this. We assume no repeated column names, but I think that's assured.
        SEXP col = VECTOR_ELT(output, i);
        unsigned int prim = TYPEOF(col);
        unsigned int curr_class = 0;
        SEXP class_attr = PROTECT(Rf_getAttrib(col, R_ClassSymbol));
        if (class_attr != R_NilValue) {
            int n_class = Rf_length(class_attr);
            for (int j = 0; j < n_class; ++j) {
                if (strcmp("POSIXct", CHAR(STRING_ELT(class_attr, j))) == 0) {
                    curr_class |= 1;
                }
            }
        }
        UNPROTECT(1); // Drop the class attributes.
        // Now check for mismatches.
        if (match_entry->value.primitive != prim || match_entry->value.class_flag != curr_class) {
            // Log if we have logging enabled.
            if (log_flag & 1) {
                if (!log_started) {
                    log_started = 1;
                    Rprintf("In %s:\n", call_string);
                }
                char buf1[20] = {0};
                char buf2[20] = {0};
                get_type_name(buf1, match_entry->value.primitive, 20);
                get_type_name(buf2, prim, 20);
                Rprintf("\tType mismatch when checking field: %s\tTemplate: %s\tTarget: %s\n", curr, buf1, buf2);
            }
            // Cast if we have casting enabled.
            if (use_cast) {
                // First handle the special case of dates - specifically, only casting _to_ date.
                if (match_entry->value.class_flag > curr_class) {
                    SEXP cast_col;
                    // Both branches leave exactly two objects on the protect stack,
                    // released together once the result is stored in `output`.
                    switch (prim) {
                        case INTSXP:
                        case REALSXP:
                        case STRSXP: {
                            // We try to make a call directly to Lubridate.
                            SEXP call = PROTECT(Rf_lang2(Rf_install("as_datetime"), col));
                            // We pass in Lubridate's environment, since R CMD check doesn't like globalenv.
                            cast_col = PROTECT(Rf_eval(call, ldns));
                            break;
                        }
                        default: {
                            // Any other type becomes an all-NA POSIXct column.
                            // cast_col must be protected: the allocations below may trigger a garbage collection.
                            cast_col = PROTECT(Rf_allocVector(REALSXP, n_target_row));
                            double *values = REAL(cast_col);
                            for (R_xlen_t j = 0; j < n_target_row; ++j) {
                                values[j] = NA_REAL;
                            }
                            // Now set the types.
                            SEXP date_attr = PROTECT(Rf_allocVector(STRSXP, 2));
                            SET_STRING_ELT(date_attr, 0, Rf_mkChar("POSIXct"));
                            SET_STRING_ELT(date_attr, 1, Rf_mkChar("POSIXt"));
                            Rf_setAttrib(cast_col, R_ClassSymbol, date_attr);
                            break;
                        }
                    }
                    SET_VECTOR_ELT(output, i, cast_col);
                    UNPROTECT(2); // cast_col plus either call or date_attr
                } else {
                    // Otherwise fall back to a normal cast to the specified primitive type.
                    SEXP cast_col = PROTECT(Rf_coerceVector(col, match_entry->value.primitive));
                    SET_VECTOR_ELT(output, i, cast_col);
                    UNPROTECT(1); // cast_col
                }
            }
        }
    }
    // Now, if we wish to check for unused entries from our template, iterate over the template
    // and check for any values in the dictionary where `matched` isn't set.
    if (log_flag & 2) { // This currently is only used for logging missing fields.
        _Bool any_missing = 0;
        for (R_xlen_t i = 0; i < n_template_col; ++i) {
            const char *curr = CHAR(STRING_ELT(template_names, i));
            R_xlen_t idx = shgeti(template_dict, curr);
            if (idx < 0) {
                // Rf_error does not return, so release the stb_ds allocations first.
                arrfree(excess_arr);
                shfree(template_dict);
                Rf_error("Something went very wrong with the dictionary.");
            }
            // Index by the looked-up slot, not the loop counter: repeated template names share
            // one entry, so the table can be shorter than the template.
            if (!template_dict[idx].value.matched) {
                if (!log_started) {
                    log_started = 1;
                    Rprintf("In %s:\n", call_string);
                }
                if (!any_missing) {
                    Rprintf("\tTarget is missing field(s):");
                    any_missing = 1;
                }
                Rprintf(" %s", curr);
            }
        }
        if (any_missing) {
            Rprintf("\n");
        }
    }
    if (log_flag & 4) { // Log excess fields in one line from their array.
        R_xlen_t n_excess = arrlen(excess_arr);
        if (n_excess > 0) {
            if (!log_started) {
                log_started = 1;
                Rprintf("In %s:\n", call_string);
            }
            Rprintf("\tTarget field(s) not in template:");
            for (R_xlen_t i = 0; i < n_excess; ++i) {
                Rprintf(" %s", CHAR(STRING_ELT(target_names, excess_arr[i])));
            }
            Rprintf("\n");
        }
    }
    // Cleanup and return.
    arrfree(excess_arr);
    shfree(template_dict); // Free the hash map.
    UNPROTECT(4); // output, template, target_names, template_names.
    return output;
}

// Get the string name of a type and write up to maxlen of it to the provided buffer.
static void get_type_name(char *buf, unsigned int type, int maxlen) {
    // Lookup table pairing each recognised SEXP type code with its printable name.
    static const struct {
        unsigned int code;
        const char *label;
    } known_types[] = {
        { REALSXP, "double" },
        { INTSXP,  "integer" },
        { CPLXSXP, "complex" },
        { LGLSXP,  "logical" },
        { STRSXP,  "string" },
        { VECSXP,  "list" },
        { LISTSXP, "pairlist" },
        { DOTSXP,  "..." },
        { NILSXP,  "NULL" },
        { SYMSXP,  "symbol" },
        { CLOSXP,  "function" },
        { ENVSXP,  "environment" },
    };
    const size_t n_known = sizeof known_types / sizeof known_types[0];

    // Anything not in the table is reported as "unknown".
    const char *label = "unknown";
    for (size_t k = 0; k < n_known; ++k) {
        if (known_types[k].code == type) {
            label = known_types[k].label;
            break;
        }
    }
    // snprintf truncates to maxlen bytes, terminator included.
    snprintf(buf, maxlen, "%s", label);
}
