import { _entries, _isEmpty, getObjectKeyByValue } from 'common/Utils';
import cornerstone from 'cornerstone-core';
import cornerstoneTools from 'cornerstone-tools';
import {
    PointMeasurementFactory,
    RectMeasurementFactory,
    strategyClickTypeMap,
} from '../CustomTools/SmartPredictionTools/entities';
import { LabelMap2D } from '../DicomViewerHelper/interface';
import { PredictionAPI } from './PredictionAPI';
import { NewPredictionArgs, StateFactory } from './entities';

const segmentationUtils = cornerstoneTools.importInternal('util/segmentationUtils');

/**
 * Drives an interactive SAM prediction session on a cornerstone viewport element.
 *
 * It collects point/rect prompts and runs segmentation through the inherited
 * `getEmbedding`/`segment` methods. The predicted mask is merged on top of the user's
 * labelmap, and prediction-level undo/redo is supported.
 */
export class SamPredictionHelper extends PredictionAPI {
    /** Private copy of the committed labelmap pixels that predictions are merged on top of. */
    public userMask: Float32Array;
    /** Viewport element the session reads from and renders into. */
    public element: HTMLElement;

    /** Listeners invoked with no arguments by `changeSignal()` after each state change. */
    public subscribers = new Map<string, Function>();

    constructor() {
        super();
        // Bound so these can be handed out as detached callbacks.
        this.getPrediction = this.getPrediction.bind(this);
        this.initSession = this.initSession.bind(this);
        this.getEmbedding = this.getEmbedding.bind(this);
    }

    /**
     * Starts a new session on `element`.
     *
     * Clears prediction state and starts loading the model (fire-and-forget; the outcome is
     * only logged). Then snapshots the element's current labelmap as the user mask.
     * Errors are logged, not thrown.
     */
    initSession(element: HTMLElement) {
        try {
            this.element = element;
            this.clearState();

            this.initModel()
                .then(() => console.log('ONNX model loaded successfully'))
                .catch(console.error);

            this.userMask = structuredClone(this.getLabelMap2D(element)?.pixelData);
        } catch (error) {
            console.error(error);
        }
    }

    /** Offset subtracted from the user-mask index when reading the prediction in `mergePredictionToUserMask()`. */
    defaultShift = 0;

    /**
     * Runs one prediction step for a new point and/or rect prompt.
     *
     * If the image changed since the last call, the pending prediction is first committed to
     * the previous image. The current state is pushed to undo history (clearing redo) and the
     * prompt is recorded. The image embedding is then fetched and segmented, and user mask +
     * prediction is written into the element's current labelmap. Errors are logged, not thrown.
     */
    async getPrediction({ element, image, point, rect }: NewPredictionArgs) {
        try {
            // Switching images: commit the pending prediction to the image it was made on.
            if (this.image && this.image.imageId !== image?.imageId) this.applyPredictionMask(this.image.imageId);
            this.image = image;
            this.element = element;

            if (!this.userMask) this.userMask = structuredClone(this.getLabelMap2D(element)?.pixelData);

            // Snapshot taken before the new prompt is added; any new step invalidates redo.
            this.historyState.push(structuredClone(this.state));
            if (this.nextState.length) this.nextState = [];

            if (point) this.clicks.push(point);
            if (rect) this.rect = rect;

            const { image_id } = this.getParams(this.image.imageId);

            const embedding = await this.getEmbedding({ image_id });

            await this.segment(embedding);

            const mask = this.mergePredictionToUserMask();

            this.updateLabelMap2D(mask, element);

            this.changeSignal();
        } catch (error) {
            console.error(error);
        }
    }

    /**
     * Steps back to the previous prediction state and re-renders the merged mask and the
     * prompt annotations. Does nothing when there is no current prediction result.
     *
     * Falls back to `reset()` when history is exhausted or the restored state has no result.
     * NOTE(review): `reset()` clears the redo stack and does not call `changeSignal()`, so
     * undoing the first prediction cannot be redone and subscribers are not notified. Confirm
     * whether that is intended.
     */
    undo() {
        if (!this.predictionResult) return;
        this.nextState.push(structuredClone(this.state));
        if (!this.historyState.length) return this.reset();
        this.state = this.historyState.pop();
        if (!this.state.result) return this.reset();
        const mask = this.mergePredictionToUserMask();
        this.updateLabelMap2D(mask, this.element);
        this.resetToolStates();
        this.changeSignal();
    }

    /** Re-applies the most recently undone state. Does nothing when the redo stack is empty. */
    redo() {
        if (!this.nextState.length) return;
        this.historyState.push(structuredClone(this.state));
        this.state = this.nextState.pop();
        const mask = this.mergePredictionToUserMask();
        this.updateLabelMap2D(mask, this.element);
        this.resetToolStates();
        this.changeSignal();
    }

    /** Invokes every registered subscriber with no arguments. */
    changeSignal() {
        this.subscribers.forEach(subscriber => subscriber());
    }

    /** Rebuilds the SAM point/rect annotations on the element from the current state and re-renders. */
    resetToolStates() {
        cornerstoneTools.clearToolState(this.element, 'SAMPointPrediction');
        cornerstoneTools.clearToolState(this.element, 'SAMRectPrediction');

        this.state.clicks.forEach(point => {
            cornerstoneTools.addToolState(
                this.element,
                'SAMPointPrediction',
                PointMeasurementFactory(point, getObjectKeyByValue(strategyClickTypeMap, point.clickType))
            );
        });

        if (this.state.rect) {
            const RectState = RectMeasurementFactory(this.state.rect, 0);
            cornerstoneTools.addToolState(this.element, 'SAMRectPrediction', RectState);
        }
        cornerstone.updateImage(this.element);
    }

    /**
     * Returns a new mask in which every pixel with a positive prediction value is set to the
     * active segment index, and every other pixel keeps its user-mask value.
     * If either the user mask or the prediction is missing, returns an all-zero mask sized to
     * the current image.
     */
    mergePredictionToUserMask() {
        if (!this.userMask || !this.predictionResult) return new Float32Array(this.image?.width * this.image?.height);
        const activeSegmentIndex = this.getActiveSegmentIndex(this.element);

        return this.userMask.map((value, index) =>
            this.predictionResult[index - this.defaultShift] > 0 ? activeSegmentIndex : value
        );
    }

    /**
     * Commits the pending prediction and clears prediction state.
     *
     * Writes user mask + prediction into the labelmap of `imageId`, or of the current image
     * when omitted. Pushes the change onto the segmentation module's undo stack, then
     * refreshes `userMask`. No-op when there is no user mask or no prediction result.
     *
     * @param imageId Image the prediction was made on, when it is not the one being displayed.
     */
    applyPredictionMask(imageId?: string) {
        if (!this.userMask || !this.predictionResult) return;
        // userMask is only ever replaced (never mutated in place), so keeping the reference is enough.
        const previousPixelData = this.userMask;

        const mask = this.mergePredictionToUserMask();

        this.updateLabelMap2D(mask, this.element, imageId);

        // With an explicit imageId the element presumably already shows another image, so the
        // baseline is re-snapshotted from its current labelmap. Otherwise the merged mask is the new baseline.
        this.userMask = imageId ? structuredClone(this.getLabelMap2D(this.element)?.pixelData) : structuredClone(mask);

        // Record exactly what was written to the target image (previous user mask -> merged mask).
        // Diffing against the refreshed userMask would, when imageId is given, compare against a
        // different image's labelmap and push a corrupt undo/redo operation for `imageId`.
        const operation = {
            imageIdIndex: this.findImageIndex(this.element, imageId),
            diff: segmentationUtils.getDiffBetweenPixelData(previousPixelData, mask),
        };

        this.segmentationModule.setters.pushState(this.element, [operation]);
        segmentationUtils.triggerLabelmapModifiedEvent(this.element);

        this.clearState();
    }

    /**
     * Restores the labelmap to the user mask and clears prediction state. An all-zero mask
     * sized to the current image is used when there is no user mask. Errors are logged.
     */
    reset() {
        try {
            const mask = this.userMask || new Float32Array(this.image?.width * this.image?.height);

            this.updateLabelMap2D(mask, this.element);
            this.clearState();
        } catch (error) {
            console.error(error);
        }
    }

    /**
     * Resets the prediction state and the undo/redo stacks. Once an element is set, it also
     * empties the SAM point/rect annotations for every image in the global tool state manager.
     */
    clearState() {
        this.state = StateFactory();
        this.historyState = [];
        this.nextState = [];

        if (!this.element) return;

        const toolStateManager = cornerstoneTools.globalImageIdSpecificToolStateManager;

        _entries(toolStateManager.toolState).forEach(([_imageId, toolState]) => {
            if (toolState.SAMPointPrediction) toolState.SAMPointPrediction.data = [];
            if (toolState.SAMRectPrediction) toolState.SAMRectPrediction.data = [];
        });
    }

    /**
     * Returns the labelmap2D for `imageId`, taken from the element's active labelmap3D.
     * Without `imageId`, returns the element's current labelmap2D.
     * Returns undefined when no active labelmap, labelmap3D or matching slice exists.
     */
    getLabelMap2D(element: HTMLElement, imageId?: string): LabelMap2D {
        const { getters } = this.segmentationModule;

        if (imageId) {
            const activeLabelMapIndex = getters.activeLabelmapIndex(element);
            // The index is a number and 0 is valid, so test for absence explicitly. An
            // emptiness helper (lodash-style isEmpty treats every number as empty) would bail here.
            if (activeLabelMapIndex == null) return;
            const imageIndex = this.findImageIndex(element, imageId);

            const labelmap3D = getters.labelmap3D(element, activeLabelMapIndex);
            return labelmap3D?.labelmaps2D?.[imageIndex];
        }
        return getters.labelmap2D(element)?.labelmap2D;
    }

    /**
     * Returns the index of `imageId` in the element's stack (-1 if absent).
     * Without `imageId`, returns the stack's current image index.
     * Assumes the element has 'stack' tool state.
     */
    findImageIndex(element: HTMLElement, imageId?: string) {
        const stack = cornerstoneTools.getToolState(element, 'stack');

        if (imageId) return stack.data[0].imageIds.findIndex((_id: string) => _id === imageId);

        return stack.data[0].currentImageIdIndex;
    }

    /**
     * Overwrites the target labelmap2D pixels with `mask`, with each index offset by `shift`.
     * Positive values are copied; anything else (0, negative, NaN) clears the pixel.
     * Afterwards it refreshes segment metadata and re-renders. No-op when the mask, the
     * element or the labelmap is missing.
     */
    updateLabelMap2D(mask: Float32Array, element: HTMLElement, imageId?: string, shift: number = 0) {
        if (!mask || !element) return;
        const labelmap2D = this.getLabelMap2D(element, imageId);
        if (!labelmap2D?.pixelData) return;

        const { pixelData } = labelmap2D;
        mask.forEach((value, index) => {
            pixelData[index + shift] = value > 0 ? value : 0;
        });

        this.segmentationModule.setters.updateSegmentsOnLabelmap2D(labelmap2D);
        cornerstone.updateImage(element);
    }

    /** Active segment index of the element's active labelmap. */
    getActiveSegmentIndex(element: HTMLElement) {
        const { getters } = this.segmentationModule;
        return getters.activeSegmentIndex(element);
    }

    /** The cornerstone-tools 'segmentation' module. */
    get segmentationModule() {
        return cornerstoneTools.getModule('segmentation');
    }
}

export * from './EmbeddingAPI';
export * from './entities';

/** Single shared SamPredictionHelper, created once when this module is first evaluated. */
const sharedSamPredictionHelper = new SamPredictionHelper();
export { sharedSamPredictionHelper as samPredictionHelper };
