//@ts-ignore
import * as dcmjs from 'dcmjs';
import * as mathjs from 'mathjs';
import * as jpeg from 'jpeg-lossless-decoder-js';
import { getK, getCubePatient, getCubeImage, getTransformation } from './coords';
import { WindowLevel, PresetSoftTissue } from '../rtviewer-core/window-level';
import { Plane, Direction, Axis } from '../rtviewer-core/view';
import { BoundingBox } from '../math/bounding-box';
import { getDistance } from "../rtviewer-core/mouse-tools/utils";
import {axisAlignedOrientation} from "../util";
import _ from 'lodash';
import { Scan } from '../store/scans';

/** Imaging modality of a scan, as read from the DICOM Modality tag. */
export enum Modality {
    CT = 'CT',
    MR = 'MR',
    Unknown = 'UNKNOWN',
}

/**
 * Patient position of a scan. Only head-first supine (HFS) and head-first prone (HFP)
 * are distinguished; any other DICOM value maps to `Unsupported`.
 */
export enum PatientPosition {
    HeadFirstSupine = 'HFS',
    HeadFirstProne = 'HFP',
    Unsupported = 'unsupported',
}

// TODO: move entire DICOM metadata into a helper class at some point
// (or find better typings for dcmjs)

/**
 * DICOM Transfer Syntax UID for "JPEG Lossless, Non-Hierarchical, First-Order Prediction".
 * ImageSlice compares the file meta TransferSyntaxUID against this value to decide whether the
 * PixelData must be run through the lossless JPEG decoder; this is the only JPEG syntax it handles.
 * See also: https://www.dicomlibrary.com/dicom/transfer-syntax/
 */
const TRANSFER_SYNTAX_JPEG_LOSSLESS_NONHIERARCHICAL_FIRST_ORDER_PREDICTION = '1.2.840.10008.1.2.4.70';

/** A single DICOM image slice: selected metadata tags plus its decoded, rescaled pixel values. */
export class ImageSlice {
    public patientName: string;
    public patientId: string;
    public seriesInstanceUID: string;
    public sopInstanceUID: string;
    public pixelSpacing: any; // J, I
    public imagePosition: any;
    public imageOrientation: any;
    public patientPosition: PatientPosition;
    public modality: Modality;
    public seriesDescription: string;
    public protocolName: string;
    public studyDescription: string;
    public seriesDate: string;
    public studyDate: string;
    public bodyPartExamined?: string;
    public manufacturer: string;
    public manufacturerModelName: string;
    public rows: number;
    public cols: number;
    public data: Int16Array | null; // Set to null after it has been used to construct an Image

    /**
     * Copies metadata from a naturalized dcmjs dataset and decodes its PixelData.
     *
     * Each stored 16-bit word is masked to its BitsStored low bits and interpreted according to
     * PixelRepresentation (0 = unsigned, 1 = two's complement). When RescaleSlope is present the
     * value is then rescaled as `value * slope + intercept`. Results are written into an
     * Int16Array, so they are truncated toward zero and wrapped to the 16-bit signed range.
     *
     * @param arrayBuffer  raw DICOM file contents (not read here; pixels come from `imageDataset.PixelData`)
     * @param imageDataset naturalized dcmjs dataset (see readImageDataset); `_meta` holds the file meta group
     * @throws Error when geometry tags are missing or the pixel bit layout is unsupported
     */
    constructor(arrayBuffer: ArrayBuffer, imageDataset: any) {
        if (!imageDataset.PixelSpacing || !imageDataset.ImagePositionPatient || !imageDataset.ImageOrientationPatient) {
            throw new Error("Not a valid DICOM image");
        }
        this.patientName = imageDataset.PatientName;
        this.patientId = imageDataset.PatientID;
        this.seriesInstanceUID = imageDataset.SeriesInstanceUID.substring(0,64);

        this.sopInstanceUID = imageDataset.SOPInstanceUID.substring(0,64);
        this.pixelSpacing = imageDataset.PixelSpacing; // J, I
        this.imagePosition = imageDataset.ImagePositionPatient;
        this.imageOrientation = imageDataset.ImageOrientationPatient;
        this.patientPosition = parsePatientPosition(imageDataset.PatientPosition);
        this.modality = parseModality(imageDataset.Modality);
        this.seriesDescription = imageDataset.SeriesDescription;
        this.protocolName = imageDataset.ProtocolName;
        this.studyDescription = imageDataset.StudyDescription;
        this.seriesDate = imageDataset.SeriesDate;
        this.studyDate = imageDataset.StudyDate;
        this.bodyPartExamined = imageDataset.BodyPartExamined;
        this.manufacturer = imageDataset.Manufacturer;
        this.manufacturerModelName = imageDataset.ManufacturerModelName;
        this.rows = imageDataset.Rows;
        this.cols = imageDataset.Columns;
        const intercept = parseFloat(imageDataset.RescaleIntercept);
        const slope = parseFloat(imageDataset.RescaleSlope);

        if (imageDataset.BitsAllocated !== 16) throw new Error("Unsupported BitsAllocated value " + imageDataset.BitsAllocated);
        const bitsStored: number = imageDataset.BitsStored;
        const highBit = imageDataset.HighBit;
        if (highBit !== bitsStored - 1) throw new Error("Unsupported bit format. Highbit doesn't equal BitsStored - 1");
        if (!(bitsStored >= 1 && bitsStored <= 16)) throw new Error("Unsupported BitsStored value " + bitsStored);

        // dcmjs 0.19.6 has a bug where naturalized array buffers (such as PixelData) are unnecessarily put inside an array -- the following line
        // works with both the original (and hopefully future fixed case) where array buffers are NOT inside an array, and the current 0.19.6
        // case where they are inside an array (which has a 'length' prop)
        // TODO: revert this workaround once dcmjs has been patched
        const arrayBuf = _.has(imageDataset.PixelData, 'length') ? imageDataset.PixelData[0] : imageDataset.PixelData;

        // Stored words are read as raw unsigned 16-bit values; signedness is applied after masking below.
        let stored: Uint16Array;

        // special handling for JPEG files, although we currently only support one specific jpeg syntax
        // see also: https://www.dicomlibrary.com/dicom/transfer-syntax/
        if (_.get(imageDataset._meta, 'TransferSyntaxUID.Value[0]') === TRANSFER_SYNTAX_JPEG_LOSSLESS_NONHIERARCHICAL_FIRST_ORDER_PREDICTION) {
            const jpegDecoder = new jpeg.lossless.Decoder();
            const jpegOutput: ArrayBuffer = jpegDecoder.decompress(arrayBuf);
            stored = new Uint16Array(jpegOutput);
        }
        else {
            stored = new Uint16Array(arrayBuf, 0, Math.floor(arrayBuf.byteLength / 2));
        }

        // Bits above HighBit may carry unrelated data, so only the low BitsStored bits are kept.
        const mask = (1 << bitsStored) - 1;
        // PixelRepresentation 1 = two's complement within BitsStored bits; anything else is treated as unsigned.
        const isSigned = imageDataset.PixelRepresentation === 1;
        const signBit = 1 << (bitsStored - 1);
        const valueRange = 1 << bitsStored;
        // As before, a missing (NaN) or zero slope means "no rescale". A missing intercept with a
        // present slope is treated as 0 rather than turning every pixel into NaN (stored as 0).
        const applyRescale = !!slope;
        const offset = Number.isFinite(intercept) ? intercept : 0;

        const length = stored.length;
        const data = new Int16Array(length);
        for (let i = 0; i < length; ++i) {
            let value = stored[i] & mask;
            if (isSigned && (value & signBit) !== 0) {
                value -= valueRange; // sign-extend from BitsStored bits
            }
            data[i] = applyRescale ? value * slope + offset : value;
        }
        this.data = data;
    }

}

/** Image object for RTViewer (the actual viewer) */
export class Image {
    public imShape: number[]
    public T: number[][];  // transformation matrix to go between image/patient coordinates
    public cubePatient: number[][];  // scan corner points (cube) in patient coordinates
    public cubeImage: number[][];  // scan corner points (cube) in image coordinates

    public iIdx: number;
    public jIdx: number;
    public kIdx: number;  // dimension along which slices are stacked

    public scanId: string;
    public modality: Modality;
    public patientPosition: PatientPosition;
    public seriesDescription: string;
    public protocolName: string;
    public studyDescription: string;
    public bodyPartExamined?: string;

    // IJK are indices in image/scan coordinate system
    public iPixels: number;
    public jPixels: number;
    public kPixels: number;

    public iSpacing: number;
    public jSpacing: number;
    public kSpacing: number;

    // scaled image/scan shape in MM
    public iSizeMM: number;
    public jSizeMM: number;
    public kSizeMM: number;

    public iMin: number;
    public jMin: number;
    public kMin: number;
    public iMax: number;
    public jMax: number;
    public kMax: number;

    // MM coordinates of the corners in image/scan coordinate
    public topLeft: number[];
    public bottomRight: number[];

    public orientationMatrix: number[];

    public data: Float32Array;  // image/scan data
    public sliceIds: string[];  // SOPInstanceUIDs of slices. This array is in slice order.
    public defaultWindowLevel: WindowLevel;
    public dicomTags: any;

    /**
     * Builds a volume from the slices of one series.
     *
     * @param slices    decoded slices; they are ordered with sortSlices, and each slice's `data`
     *                  is set to null once its pixels have been copied into `this.data`
     * @param dicomTags dataset stored unchanged on `dicomTags`
     * @param scanId    identifier stored on `scanId`
     * @throws Error when `slices` is empty or a slice no longer holds pixel data
     */
    constructor(slices: ImageSlice[], dicomTags: any, scanId: string) {
        if (slices.length === 0) {
            throw new Error('Cannot construct an Image from zero slices');
        }
        slices = sortSlices(slices);
        const slice = slices[0];

        this.imShape = [slices.length, slice.rows, slice.cols];
        const topLeftFirstSlice = slices[0].imagePosition;
        const topLeftLastSlice = slices[slices.length - 1].imagePosition;

        this.T = getTransformation(slice.imageOrientation, [parseFloat(slice.pixelSpacing[0]), parseFloat(slice.pixelSpacing[1])], topLeftFirstSlice, topLeftLastSlice, slices.length);
        this.cubeImage = getCubeImage(this.imShape);
        this.cubePatient = getCubePatient(this.imShape, this.T);
        console.log('Image coordinates (IJK) at scan corners: ' + this.cubeImage);
        console.log('Patient coordinates (XYZ) at scan corners: ' + this.cubePatient);
        this.kIdx = getK(topLeftFirstSlice, topLeftLastSlice);
        if (this.kIdx === 0) {
            this.iIdx = 1;
            this.jIdx = 2;
        }
        else if (this.kIdx === 1) {
            this.iIdx = 0;
            this.jIdx = 2;
        }
        else {  //  if (this.kIdx == 2)
            this.iIdx = 0;
            this.jIdx = 1;
        }

        this.scanId = scanId;
        this.modality = parseModality(slice.modality);
        this.patientPosition = slice.patientPosition;
        this.seriesDescription = slice.seriesDescription;
        this.protocolName = slice.protocolName;
        this.studyDescription = slice.studyDescription;
        this.bodyPartExamined = slice.bodyPartExamined;
        // I runs along columns, J along rows (PixelSpacing is [row spacing, column spacing]).
        this.iPixels = slice.cols;
        this.jPixels = slice.rows;
        this.kPixels = slices.length;

        this.iSpacing = parseFloat(slice.pixelSpacing[1]);
        this.jSpacing = parseFloat(slice.pixelSpacing[0]);
        this.kSpacing = getMedianKSpacing(slices);

        // Extent measured between the centres of the first and last pixel along each axis.
        this.iSizeMM = (this.iPixels - 1 ) * this.iSpacing;
        this.jSizeMM = (this.jPixels - 1 ) * this.jSpacing;
        this.kSizeMM = (this.kPixels - 1 ) * this.kSpacing;

        this.topLeft = [0.0, 0.0, 0.0];

        // NOTE: the direction derived from ImagePositionPatient below is immediately overridden
        // with [0, 0, 1], and orientationMatrix itself is replaced by axisAlignedOrientation at the
        // end of the constructor; the value computed here only appears in the log message.
        let zVector = cross(slice.imageOrientation.slice(0, 3), slice.imageOrientation.slice(3, 6));
        if(slices.length > 1) { // Calculate real z-vector based on the image position attribute
            const i = Math.floor( (slices.length - 1) / 2);
            const s1 = slices[i];
            const s2 = slices[i+1];
            const diff = [
                s2.imagePosition[0] - s1.imagePosition[0],
                s2.imagePosition[1] - s1.imagePosition[1],
                s2.imagePosition[2] - s1.imagePosition[2]
            ];
            const length = getDistance(s2.imagePosition, s1.imagePosition);
            zVector = [
                diff[this.iIdx] / length,
                diff[this.jIdx] / length,
                diff[this.kIdx] / length
            ];
            zVector = [0., 0., 1.0];
        }
        this.orientationMatrix = slice.imageOrientation.concat(zVector);
        this.bottomRight = [this.iSizeMM, this.jSizeMM, this.kSizeMM];

        this.iMin = 0;
        this.jMin = 0;
        this.kMin = 0;
        this.iMax = this.iPixels;
        this.jMax = this.jPixels;
        this.kMax = this.kPixels;

        this.data = new Float32Array(this.iPixels * this.jPixels * this.kPixels);
        this.sliceIds = slices.map(s => s.sopInstanceUID);
        const planeSize = this.iPixels * this.jPixels;
        let minValue = 999999;
        let maxValue = -999999;
        for (let k = 0; k < slices.length; ++k) {
            const current = slices[k];
            const sliceData = current.data;
            if (sliceData === null || sliceData === undefined) { throw new Error("Could not get image data!"); }
            const planeOffset = k * planeSize;
            for (let p = 0; p < planeSize; ++p) {
                const value = sliceData[p];
                this.data[planeOffset + p] = value;
                minValue = Math.min(minValue, value);
                maxValue = Math.max(maxValue, value);
            }
            // Release the per-slice copy; the volume now owns the pixels.
            current.data = null;
        }
        const ww = maxValue - minValue;
        const wc = (minValue + maxValue) / 2;
        this.defaultWindowLevel = (this.modality === Modality.CT) ? PresetSoftTissue : new WindowLevel(ww, wc);
        this.dicomTags = dicomTags;
        console.log("orientation: " + this.orientationMatrix + ' : ' + this.kIdx);
        this.orientationMatrix = axisAlignedOrientation;  // the orientation matrix for IJK coordinates
    }

    /** Voxel value at integer IJK index (no bounds checking). */
    getValue(i: number, j: number, k: number): number {
        return this.data[(k * this.iPixels * this.jPixels) + (j * this.iPixels) + i]
    }

    getDirections() {
        const o = this.orientationMatrix;
        return {
            [Plane.Transversal]: [new Direction(Axis.X, o[0] < 0), new Direction(Axis.Y, o[4] < 0), new Direction(Axis.Z, false)],
            [Plane.Coronal]:     [new Direction(Axis.X, o[0] < 0), new Direction(Axis.Z, true ), new Direction(Axis.Y, o[4] < 0)],
            [Plane.Sagittal]:    [new Direction(Axis.Y, o[4] > 0 ), new Direction(Axis.Z, true ), new Direction(Axis.X, o[0] < 0)],
        };
    }

    getDirectionLetters() {
        // old implementation:
        // const o = this.orientationMatrix;
        // const xInv = o[0] < 0;
        // const yInv = o[4] < 0;

        // use patient position to get axis labels
        // in unsupported cases default to HFS labels
        const xInv = this.patientPosition === PatientPosition.HeadFirstProne;
        const yInv = this.patientPosition === PatientPosition.HeadFirstProne;

        return {
            // XYZ labels will be wrong for some MR scans but are being used as annotators have difficulty understanding two coordinate systems.
            [Plane.Transversal]: [yInv ? "P" : "A", xInv ? "R" : "L", yInv ? "A" : "P", xInv ? "L" : "R"],
            [Plane.Coronal]:     ["H", xInv ? "R" : "L", "F", xInv ? "L" : "R"],
            [Plane.Sagittal]:    ["H", yInv ? "P" : "A", "F", yInv ? "A" : "P"],

            // these labels are correct
            // [Plane.Transversal]: ["I", "J"], // [yInv ? "P" : "A", xInv ? "R" : "L", yInv ? "A" : "P", xInv ? "L" : "R"],
            // [Plane.Coronal]:     ["I", "K"],  // ["H", xInv ? "R" : "L", "F", xInv ? "L" : "R"],
            // [Plane.Sagittal]:    ["J", "K"],  // ["H", yInv ? "P" : "A", "F", yInv ? "A" : "P"],
        }
    }

    /**
     * Whether a point (in MM, image coordinates) lies inside the box spanned by topLeft and
     * bottomRight, widened by a small tolerance on every side.
     */
    isPointIn(ptMm: number[]): boolean {
        const tol = 0.001  // todo: find out why slight expansion is needed for some scans (nccs-prostate-101-0)?
        const isIn = (val: number, limit1: number, limit2: number) => {
            // Order the limits first so the tolerance always widens the interval, even when limit1 > limit2.
            const min = Math.min(limit1, limit2) - tol;
            const max = Math.max(limit1, limit2) + tol;
            return val >= min && val <= max;
        };

        return isIn(ptMm[0], this.topLeft[0], this.bottomRight[0])
            && isIn(ptMm[1], this.topLeft[1], this.bottomRight[1])
            && isIn(ptMm[2], this.topLeft[2], this.bottomRight[2]);
    }

    getSliceIdsForArea(bb: BoundingBox) {
        const result: string[] = [];
        this.sliceIds.forEach((sliceId, indx) => {
            const zPos = this.kMin + indx * this.kSpacing;
            if(zPos >= bb.minK && zPos <= bb.maxK) {
                result.push(sliceId);
            }
        });
        return result;
    }

    // Bounding box AROUND all pixels, not by corner pixel centers.
    getRealBoundingBox() {
        const bb = new BoundingBox();
        bb.resetToImageDimensions(this);
        return bb;
    }

    /**
     * Builds an Image from every slice of a scan. The DICOM tags of the first slice (in object
     * key order) are re-read and stored on `dicomTags`.
     *
     * @throws Error when the scan has no slices
     */
    static generateFromScan(scan: Scan): Image {
        if (Object.keys(scan.slices).length === 0) {
            throw new Error('Cannot generate an RTViewer image from a scan with no slices!');
        }

        const firstKey = Object.keys(scan.slices)[0];
        const ab = scan.slices[firstKey].arrayBuffer;
        const dicomTags = readImageDataset(ab);
        const slices = Object.values(scan.slices).map(s => s.imageSlice);
        const img = new Image(slices, dicomTags, scan.scanId);

        return img;
    }

}

/**
 * Parses a DICOM file with dcmjs and returns its naturalized dataset, with the file meta group
 * attached as `_meta`. Returns null when any of PixelSpacing, ImagePositionPatient or
 * ImageOrientationPatient is missing, i.e. the file is not treated as an image slice.
 */
export function readImageDataset(arrayBuffer: ArrayBuffer): any {
    const { DicomMessage, DicomMetaDictionary } = dcmjs.data;
    const message = DicomMessage.readFile(arrayBuffer);
    const dataset = DicomMetaDictionary.naturalizeDataset(message.dict);
    dataset._meta = DicomMetaDictionary.namifyDataset(message.meta);

    const hasGeometry = dataset.PixelSpacing && dataset.ImagePositionPatient && dataset.ImageOrientationPatient;
    return hasGeometry ? dataset : null;
}

/** Decodes a DICOM file into an ImageSlice, or returns null when it is not an image slice. */
export function dicomToImageSlice(arrayBuffer: ArrayBuffer): ImageSlice | null {
    const dataset = readImageDataset(arrayBuffer);
    if (!dataset) {
        return null;
    }
    return new ImageSlice(arrayBuffer, dataset);
}


/**
 * Orders slices by increasing position along the image normal (row direction × column direction
 * of ImageOrientationPatient). Positions are measured relative to the first slice of the input.
 *
 * Returns a new array; neither the input array nor the slice objects are modified.
 */
function sortSlices(slices: ImageSlice[]): ImageSlice[] {
    if (slices.length === 0) {
        return [];
    }
    const refPos = slices[0].imagePosition;

    const keyed = slices.map((slice) => {
        const o = slice.imageOrientation; // [row cosines (3), column cosines (3)]
        const p = slice.imagePosition;
        // n = columnCosines × rowCosines, i.e. the negated image normal (same operand order as before).
        const nx = o[4] * o[2] - o[5] * o[1];
        const ny = o[5] * o[0] - o[3] * o[2];
        const nz = o[3] * o[1] - o[4] * o[0];
        const delta = (p[0] - refPos[0]) * nx + (p[1] - refPos[1]) * ny + (p[2] - refPos[2]) * nz;
        return { slice, delta };
    });

    // Descending delta along the negated normal == ascending position along the image normal.
    keyed.sort((a, b) => (a.delta < b.delta ? 1 : a.delta > b.delta ? -1 : 0));
    return keyed.map((entry) => entry.slice);
}

/**
 * Median Euclidean distance between the ImagePositionPatient values of consecutive slices.
 * Alerts the user when the spacings, rounded to one decimal, are not all equal.
 *
 * @param sortedSlices slices already in stacking order (see sortSlices)
 * @returns the median spacing (upper median for an even count); 1 when there are fewer than two slices
 */
function getMedianKSpacing(sortedSlices: ImageSlice[]): number {
    if (sortedSlices.length < 2) return 1;
    const kSpacings: number[] = [];
    for (let i = 1; i < sortedSlices.length; ++i) {
        const curr = sortedSlices[i].imagePosition;
        const prev = sortedSlices[i - 1].imagePosition;
        kSpacings.push(Math.hypot(curr[0] - prev[0], curr[1] - prev[1], curr[2] - prev[2]));
    }
    // Numeric comparator: the default sort compares string forms, which puts e.g. 10 before 9.
    kSpacings.sort((a, b) => a - b);
    const uniqueKSpacingsRounded = [...new Set(kSpacings.map(s => Math.round(s * 10) / 10))];
    if( uniqueKSpacingsRounded.length > 1 ) alert("Warning! Image K position interval (" + uniqueKSpacingsRounded + " ) not constant. Image will not be shown correctly!");
    return kSpacings[ Math.floor(kSpacings.length/2) ];
}

/**
 * Maps a DICOM Modality value to the Modality enum (case-insensitive, surrounding spaces ignored).
 *
 * @throws Error (after logging and alerting) for anything other than CT or MR, including a missing value
 */
function parseModality(modality: string): Modality {
    // Guard against a missing tag so the user sees the intended message instead of a TypeError.
    const normalized = typeof modality === 'string' ? modality.trim().toLowerCase() : '';
    if(normalized === 'ct') return Modality.CT;
    if(normalized === 'mr') return Modality.MR;
    const error = "Unknown modality: " + modality;
    console.log(error);
    alert(error);
    throw new Error(error);
}

/** Maps a DICOM PatientPosition value to the enum; anything but 'HFS'/'HFP' (or no value) is Unsupported. */
function parsePatientPosition(dicomValue?: string): PatientPosition {
    switch (dicomValue) {
        case 'HFS':
            return PatientPosition.HeadFirstSupine;
        case 'HFP':
            return PatientPosition.HeadFirstProne;
        default:
            return PatientPosition.Unsupported;
    }
}


// // For generating dummy test images 
// export function getRightHalf(pd: PixelData): PixelData {
//     let shape = [pd.shape[0], pd.shape[1] / 2, pd.shape[2]];
//     return subArray(pd, shape);
// }

// export function getAnteriorHalf(pd: PixelData): PixelData {
//     let shape = [pd.shape[0], pd.shape[1], pd.shape[2] / 2];
//     return subArray(pd, shape);
// }

// export function getFeetHalf(pd: PixelData): PixelData {
//     let shape = [pd.shape[0]/2, pd.shape[1], pd.shape[2]];
//     return subArray(pd, shape);
// }
// function subArray(pd: PixelData, newShape: number[]): PixelData {
//     let newx = newShape[1];
//     let newy = newShape[2];
//     let newz = newShape[0];

//     let oldx = pd.shape[1];
//     let oldy = pd.shape[2];

//     //let data = new Int16Array(newz * newx * newy);
//     let data = new Float32Array(newz * newx * newy);
//     for(let z = 0; z < newz; ++z) {
//         for(let x = 0; x < newx; ++x) {
//             for( let y = 0; y < newy; ++y) {
//                 data[z*newx*newy + y*newx + x] = pd.data[z*oldx*oldy + y*oldx + x]; 
//             }
//         }
//     }
//     return new PixelData(data, newShape, pd.spacing, pd.position);
// }

/** Cross product a × b of two 3-component vectors. */
export function cross(a: number[], b: number[]) {
    const [ax, ay, az] = a;
    const [bx, by, bz] = b;
    return [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx];
}