/**
 * IPtool: Calculate CIDR blocks and IP ranges, and match IP addresses
 * against them.
 *
 * Charalampos Pournaris (charpour@gmail.com)
 */

#include <ctype.h>
#include <inttypes.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>

/* Function prototypes. All helpers are file-local (static). */
static int get_ip_range(const char *range, uint32_t *sip, uint32_t *eip);
static void get_cidr_address(const char *address, char *cidr, size_t len,
        int *cidr_bits);
static int get_network_bits(uint32_t sip, uint32_t eip);
static uint32_t get_ip(const char *ip_dotted);
static char *get_address(uint32_t ipa);

/* Debug-only helper; compiled out when NDEBUG is defined. */
#ifndef NDEBUG
static void print_bits(uint32_t ip);
#endif

/*
 * read_cidr_arg: Split a CIDR command-line argument into its dotted address
 * (written to cidr) and its prefix length. An argument without "/bits" is
 * treated as a single host (/32). Exits the program if the prefix length is
 * outside 0..32, because the caller uses it to build a shift count.
 */
static int read_cidr_arg(const char *arg, char *cidr, size_t len) {
    int bits = 32;

    get_cidr_address(arg, cidr, len, &bits);

    if (bits < 0 || bits > 32) {
        fprintf(stderr, "Invalid prefix length: %d\n", bits);
        exit(EXIT_FAILURE);
    }

    return bits;
}

/**
 * main: Command-line entry point. argv[1] selects the operation:
 *
 *   cidr   START END   Print the smallest CIDR block containing both
 *                      addresses (their common leading bit prefix).
 *   cmatch CIDR IP     Print "IP match" if IP lies inside CIDR.
 *   rmatch RANGE IP    Print "IP match" if START <= IP <= END, where RANGE
 *                      is written like "10.0.0.1-10.0.0.9".
 *   range  CIDR        Print the first and last address of CIDR.
 *
 * Nothing is printed when a match test fails. Exits with EXIT_FAILURE on bad
 * usage or unparsable input.
 */
int main(int argc, char *argv[]) {
    uint32_t ip_start, ip_end, temp, mask = 0;
    int bitcount;
    char cidr[17];

    if (argc < 3) {
        fprintf(stderr, "Usage: %s range | cidr | cmatch | rmatch [args]\n",
                argv[0]);
        exit(EXIT_FAILURE);
    }

    if (strcasecmp(argv[1], "cidr") == 0) {
        if (argc != 4) {
            fprintf(stderr, "Invalid usage\n");
            exit(EXIT_FAILURE);
        }

        ip_start = get_ip(argv[2]);
        ip_end = get_ip(argv[3]);

        /* get_ip() returns 0 on failure. NOTE(review): an invalid start
         * address is indistinguishable from 0.0.0.0 and is not rejected. */
        if (!ip_end) {
            fprintf(stderr, "Invalid IP's given, try again\n");
            exit(EXIT_FAILURE);
        }

        bitcount = get_network_bits(ip_start, ip_end);

#ifndef NDEBUG
        fprintf(stdout, "Bitcount: %d\n", bitcount);
        fprintf(stdout, "IP start: %" PRIu32 " IP end: %" PRIu32 "\n",
                ip_start, ip_end);

        print_bits(ip_start);
        print_bits(ip_end);
        print_bits(ip_start & ip_end);
        putchar('\n');
#endif

        /* A zero prefix is handled separately: shifting by 32 would be
         * undefined, and mask already defaults to 0 (0.0.0.0/0). */
        if (bitcount)
            mask = ~0x0U << (32 - bitcount);
#ifdef DMOD
        else
            mask = ~0x0U;
#endif

        fprintf(stdout, "%s/%d\n", get_address(ip_start & mask), bitcount);
    } else if (strcasecmp(argv[1], "cmatch") == 0) {
        if (argc != 4) {
            fprintf(stderr, "Invalid usage\n");
            exit(EXIT_FAILURE);
        }

        /* NOTE(review): an unparsable IP becomes 0 (0.0.0.0) here. */
        uint32_t mip = get_ip(argv[3]);

        bitcount = read_cidr_arg(argv[2], cidr, sizeof(cidr));
        temp = get_ip(cidr);

        /* bitcount is guaranteed to be in 0..32 by read_cidr_arg(). */
        if (bitcount)
            mask = ~0x0U << (32 - bitcount);
#ifdef DMOD
        else if (temp)
            mask = ~0x0U;
#endif

#ifndef NDEBUG
        printf("Bitcount: %d\n", bitcount);
        print_bits(mip);
        putchar('\n');
#endif

        if ((mip & mask) == (temp & mask))
            fprintf(stdout, "IP match\n");
    } else if (strcasecmp(argv[1], "rmatch") == 0) {
        if (argc != 4) {
            fprintf(stderr, "Invalid usage\n");
            exit(EXIT_FAILURE);
        }

        uint32_t mip = get_ip(argv[3]);

        if (get_ip_range(argv[2], &ip_start, &ip_end)) {
            fprintf(stderr, "Error occured while parsing the IP range\n");
            exit(EXIT_FAILURE);
        }

        if (mip >= ip_start && mip <= ip_end)
            fprintf(stdout, "IP match\n");
    } else if (strcasecmp(argv[1], "range") == 0) {
        bitcount = read_cidr_arg(argv[2], cidr, sizeof(cidr));
        temp = get_ip(cidr);

        /* bitcount is guaranteed to be in 0..32 by read_cidr_arg(). */
        if (bitcount)
            mask = ~0x0U << (32 - bitcount);
#ifdef DMOD
        else if (temp)
            mask = ~0x0U;
#endif

        /* Network address, then the same address with every host bit set. */
        ip_start = temp & mask;
        ip_end = ip_start | ~mask;

#ifndef NDEBUG
        print_bits(ip_start);
        print_bits(ip_end);
        putchar('\n');
#endif
        /* get_address() returns a static buffer, so print each result before
         * formatting the next one. */
        fprintf(stdout, "%s - ", get_address(ip_start));
        fprintf(stdout, "%s\n", get_address(ip_end));
    } else {
        fprintf(stderr, "Invalid option: '%s'\n", argv[1]);
        exit(EXIT_FAILURE);
    }

    return 0;
}

/****************************************************************************/

/**
 * get_ip_range: Split an IP range such as "10.0.0.1-10.0.0.20" into its start
 * and end addresses.
 *
 * After optional leading whitespace, the start address is the first run of
 * digits and dots. Any non-digit characters after it act as the separator.
 * The end address is the next run of digits and dots, and anything after it
 * is ignored.
 *
 * @param range  NUL-terminated range text; input longer than 511 characters
 *               is truncated before parsing.
 * @param sip    Receives the numeric start address (as returned by get_ip()).
 * @param eip    Receives the numeric end address (as returned by get_ip()).
 * @return 0 on success, or 1 in any of these cases: the start address is
 *         missing, no end address follows it, or the end address fails to
 *         parse (get_ip() returned 0).
 */

static int get_ip_range(const char *range, uint32_t *sip, uint32_t *eip) {
    char rtext[512];
    char *t, *s, *e;

    /* snprintf always NUL-terminates; strncpy does not when range is long. */
    snprintf(rtext, sizeof(rtext), "%s", range);

    t = rtext;

    /* ctype functions need a value representable as unsigned char. */
    while (*t && isspace((unsigned char) *t))
        t++;

    s = t;

    while (*t && (isdigit((unsigned char) *t) || *t == '.'))
        t++;

    /* Reject a missing start address, and a start with nothing after it. */
    if (t == s || *t == '\0')
        return 1;

    *t++ = '\0';
    *sip = get_ip(s);

    /* Skip the separator (e.g. "-", " - ", " to "). */
    while (*t && !isdigit((unsigned char) *t))
        t++;

    if (*t == '\0')
        return 1;

    e = t;

    while (*t && (isdigit((unsigned char) *t) || *t == '.'))
        t++;

    *t = '\0';
    *eip = get_ip(e);

    if (!*eip)
        return 1;

    return 0;
}

/****************************************************************************/

/**
 * get_cidr_address: Copy the address part of a CIDR string into cidr and
 * report its prefix length.
 *
 * Digits and dots are copied into cidr, and other characters before the '/'
 * are skipped. When a '/' is found, a short address is padded with ".0"
 * octets up to four (e.g. "10/8" -> "10.0.0.0"), and the number after the
 * '/' is stored in *cidr_bits. Copying stops early once cidr is full.
 *
 * @param address    NUL-terminated CIDR text such as "192.168.0.0/16".
 * @param cidr       Output buffer; NUL-terminated whenever len > 0.
 * @param len        Size of cidr in bytes.
 * @param cidr_bits  Receives the prefix length. It is 32 when no '/' part is
 *                   reached. The value is not range-checked, so callers must
 *                   reject anything outside 0..32.
 */

static void get_cidr_address(const char *address, char *cidr, size_t len,
        int *cidr_bits) {
    int dot_count = 0;
    size_t l = 0;
    const char *s;

    /* A bare address (no "/bits") describes a single host. Setting this up
     * front also means the caller never reads an uninitialized value. */
    *cidr_bits = 32;

    /* No room even for the terminator. */
    if (len == 0)
        return;

    for (s = address; *s; s++) {
        if (isdigit((unsigned char) *s) || *s == '.') {
            if (*s == '.')
                dot_count++;

            if (l == len - 1)
                break;

            cidr[l++] = *s;
        } else if (*s == '/') {
            long bits;

            /* Pad missing octets, keeping room for the terminating NUL. */
            while (dot_count < 3 && l + 2 < len) {
                cidr[l++] = '.';
                cidr[l++] = '0';
                dot_count++;
            }

            /* Unlike atoi, strtol has defined behaviour on overflow. */
            bits = strtol(s + 1, NULL, 10);
            if (bits > INT_MAX)
                bits = INT_MAX;
            else if (bits < INT_MIN)
                bits = INT_MIN;
            *cidr_bits = (int) bits;
            break;
        }
    }

    cidr[l] = '\0';
}

/****************************************************************************/

/**
 * get_network_bits: Count how many leading bits two addresses have in common.
 * That count is the prefix length of the smallest CIDR block containing both.
 *
 * @param sip  First address.
 * @param eip  Second address.
 * @return A value in 0..32, which is 32 when sip == eip.
 */

static int get_network_bits(uint32_t sip, uint32_t eip) {
    /* diff has a 1 in every position where the addresses disagree, so the
     * shared prefix is the run of leading zeros in diff. */
    const uint32_t diff = sip ^ eip;
    int prefix = 0;

    while (prefix < 32 && !((diff >> (31 - prefix)) & 0x1U))
        ++prefix;

    return prefix;
}

/****************************************************************************/

/**
 * print_bits: Debug helper that prints a 32-bit value to stdout. It writes
 * the decimal value, then all 32 binary digits (most significant bit first),
 * then a newline. It is only compiled when NDEBUG is not defined.
 */

#ifndef NDEBUG
static void print_bits(uint32_t ip) {
    int i;

    /* PRIu32 is the portable conversion for uint32_t; a bare %u assumes
     * uint32_t is unsigned int. */
    fprintf(stdout, "\n%" PRIu32 " in binary: ", ip);
    for (i = 31; i >= 0; i--)
        putchar(((ip >> i) & 0x1U) ? '1' : '0');
    putchar('\n');
}
#endif

/****************************************************************************/

/**
 * get_ip: Convert a dotted-quad IPv4 address ("a.b.c.d") to a 32-bit integer
 * whose most significant byte is the first octet.
 *
 * Each octet must be in 0..255 and at most three characters long. Characters
 * after the fourth octet are not checked.
 *
 * @param ip_dotted  NUL-terminated address text.
 * @return The numeric address, or 0 if ip_dotted cannot be parsed. Because
 *         "0.0.0.0" also yields 0, callers cannot tell the two apart.
 */

#define LEG_RANGE(n) ((n) >= 0 && (n) <= 255)

static uint32_t get_ip(const char *ip_dotted) {
    int a, b, c, d;

    /* The %3d widths stop sscanf from converting arbitrarily long digit runs;
     * an out-of-range %d conversion is undefined behaviour. */
    if (sscanf(ip_dotted, "%3d.%3d.%3d.%3d", &a, &b, &c, &d) != 4)
        return 0;

    if (!LEG_RANGE(a) || !LEG_RANGE(b) || !LEG_RANGE(c) || !LEG_RANGE(d))
        return 0;

    /* Shift as uint32_t: 255 << 24 overflows a signed int. */
    return ((uint32_t) a << 24) | ((uint32_t) b << 16) |
           ((uint32_t) c << 8) | (uint32_t) d;
}

/****************************************************************************/

/**
 * get_address: Format a 32-bit IPv4 address (first octet in the most
 * significant byte) as a dotted quad.
 *
 * @param ipa  Numeric address, as produced by get_ip().
 * @return Pointer to a static buffer that the next call overwrites, so use
 *         the result before calling again. Not thread-safe.
 */

static char *get_address(uint32_t ipa) {
    static char ip[sizeof("255.255.255.255")];

    /* Mask with an unsigned constant. Building the mask as ~(~0 << 8) would
     * left-shift a negative int, which is undefined behaviour. */
    snprintf(ip, sizeof(ip), "%u.%u.%u.%u",
             (unsigned) ((ipa >> 24) & 0xFFu),
             (unsigned) ((ipa >> 16) & 0xFFu),
             (unsigned) ((ipa >> 8) & 0xFFu),
             (unsigned) (ipa & 0xFFu));

    return ip;
}

/****************************************************************************/

