import { type Client, type Session, type Socket } from '@heroiclabs/nakama-js';
import { WebSocketAdapterPb } from '@heroiclabs/nakama-js-protobuf';
import { mainLogger } from '@nord-beaver/core/utils/logger';
import { Signal } from '@nord-beaver/core/utils/signal';
import { retry } from '@nord-beaver/core/utils/utils';
import { type DependencyContainer } from 'game/utils/dependencyContainer';

/**
 * Names of the `Socket` callback properties (`ondisconnect`, `onerror`,
 * `onmatchdata`, ...), i.e. every key of `Socket` that starts with `on`.
 */
type SocketEventHandlers = Extract<keyof Socket, `on${string}`>;

/** Optional overrides for the socket connection retry policy. */
export interface SocketNakamaConfig {
    /** Forwarded to `retry` as its `count` option when connecting with retry (service default: 5). */
    socketRetryCount?: number;
    /** Forwarded to `retry` as its `timeout` option when connecting with retry (service default: 100). */
    socketRetryTimeoutMs?: number;
}

const logger = mainLogger.getLogger('Nakama', '#2c92ff').getLogger('Socket');

/**
 * Creates, connects and tears down Nakama realtime sockets, and re-exposes each
 * socket's `on*` callbacks as `Signal`s so that several listeners can subscribe
 * to the same socket event.
 */
export class SocketNakamaService {
    /**
     * `on*` callbacks that are always reset to a no-op when a socket is cleaned up
     * (events registered through `getSocketEventSignal` are reset as well).
     * The explicit element type makes the compiler reject names `Socket` lacks.
     */
    private static readonly socketEvents: readonly SocketEventHandlers[] = [
        'ondisconnect',
        'onerror',
        'onnotification',
        'onmatchdata',
        'onmatchpresence',
        'onmatchmakerticket',
        'onmatchmakermatched',
        'onparty',
        'onpartyclose',
        'onpartydata',
        'onpartyjoinrequest',
        'onpartyleader',
        'onpartypresence',
        'onpartymatchmakerticket',
        'onstatuspresence',
        'onstreampresence',
        'onstreamdata',
        'onheartbeattimeout',
        'onchannelmessage',
        'onchannelpresence',
    ];

    /** Assigned to a socket callback once its signal has been released. */
    private readonly noop = () => {};

    /** Forwarded to `retry` as its `count` option; overridable via `config`. */
    private socketRetryCount = 5;
    /** Forwarded to `retry` as its `timeout` option; overridable via `config`. */
    private socketRetryTimeoutMs = 100;

    /**
     * Per-socket registry of event signals. A WeakMap, so a socket that is dropped
     * without ever being passed to `disconnect` does not stay reachable from here.
     */
    private readonly socketSignalMap = new WeakMap<Socket, Map<SocketEventHandlers, Signal<Parameters<Socket[SocketEventHandlers]>>>>();

    constructor(
        _dependencyContainer: DependencyContainer,
    ) {
        logger.info('initialized');
    }

    /**
     * Overrides the retry policy used by `connect`. Fields left `undefined` keep
     * their current value.
     */
    config(config: SocketNakamaConfig) {
        if (config.socketRetryCount !== undefined) {
            this.socketRetryCount = config.socketRetryCount;
        }

        if (config.socketRetryTimeoutMs !== undefined) {
            this.socketRetryTimeoutMs = config.socketRetryTimeoutMs;
        }
    }

    /**
     * Creates a socket for `client` and connects it with `session`.
     *
     * @param client Nakama client used to create the socket.
     * @param session Session passed to `socket.connect`.
     * @param withRetry When true (default), failed attempts are retried through
     *   `retry` using the configured count/timeout. Every attempt builds a fresh
     *   socket: a failed attempt is torn down by `cleanupSocket`, which also
     *   strips its error/disconnect forwarders, so reusing it would leave the
     *   final socket without them.
     * @returns The connected socket together with its protobuf adapter.
     * @throws The creation/connection error when the (last) attempt fails.
     */
    async connect(client: Client, session: Session, withRetry = true): Promise<readonly [Socket, WebSocketAdapterPb]> {
        const createAndConnect = async () => {
            const socketPair = this.createSocket(client);
            await this.connectSocket(socketPair, session);

            return socketPair;
        };

        if (!withRetry) {
            return createAndConnect();
        }

        return this.retryConnection(createAndConnect, 'connect');
    }

    /**
     * Detaches every signal/callback registered for `socket`, calls
     * `socket.disconnect(false)` and closes `socketAdapter` if it is still open.
     */
    disconnect(socket: Socket, socketAdapter?: WebSocketAdapterPb) {
        // cleanupSocket already disconnects the socket; calling it here too was redundant.
        this.cleanupSocket(socket, socketAdapter);
    }

    /**
     * Returns the signal for `event` on `socket`, creating it on first use.
     *
     * On creation, `socket[event]` is replaced with a forwarder that dispatches the
     * callback's arguments as a tuple, so listeners receive e.g. `[event]`. Any
     * handler previously assigned directly to `socket[event]` is overwritten.
     */
    getSocketEventSignal<K extends SocketEventHandlers>(socket: Socket, event: K): Signal<Parameters<Socket[K]>> {
        let signalMap = this.socketSignalMap.get(socket);
        if (!signalMap) {
            signalMap = new Map();
            this.socketSignalMap.set(socket, signalMap);
        }

        const existing = signalMap.get(event) as Signal<Parameters<Socket[K]>> | undefined;
        if (existing) {
            return existing;
        }

        const signal = new Signal<Parameters<Socket[K]>>();
        signalMap.set(event, signal);

        const forward = ((...args: Parameters<Socket[K]>) => {
            // `args` is a fresh array per call, so it can be dispatched as-is.
            signal.dispatch(args);
        }) as Socket[K];

        (socket[event] as Socket[K]) = forward;

        return signal;
    }

    /**
     * Removes every listener of `event` on `socket`, forgets its signal and resets
     * `socket[event]` to a no-op. Does nothing if no signal was ever created for
     * `socket`.
     */
    unsubscribeSocketEvent<K extends SocketEventHandlers>(socket: Socket, event: K): void {
        const signalMap = this.socketSignalMap.get(socket);
        if (!signalMap) {
            return;
        }

        const signal = signalMap.get(event);
        if (signal) {
            signal.offAll();
            signalMap.delete(event);
        }

        (socket[event] as Socket[SocketEventHandlers]) = this.noop;
    }

    /** Logs why a socket closed, using the WebSocket close code when available. */
    private onSocketDisconnect(event: CloseEvent | Event): void {
        // `CloseEvent` is a DOM global; check that it exists before `instanceof`, so
        // the handler logs instead of throwing a ReferenceError where it is missing.
        if (typeof CloseEvent === 'undefined' || !(event instanceof CloseEvent)) {
            logger.error('Socket disconnected with unknown event', event);
            return;
        }

        switch (event.code) {
            case 1006: // RFC 6455: abnormal closure (no close frame received)
                logger.error(`Socket shutdown ${event.code}`, event);
                break;
            case 1005: // RFC 6455: no status code was present in the close frame
                logger.error(`Socket disconnected ${event.code}`, event);
                break;
            default:
                logger.error(`Socket closed ${event.code}`, event);
                break;
        }
    }

    /** Logs a socket error; `Error` instances are reduced to their message. */
    private onSocketError(error: unknown): void {
        const message = error instanceof Error ? error.message : error;
        logger.error('Socket error occurred', { error: message });
    }

    /**
     * Releases everything this service attached to `socket`, disconnects it
     * without requesting the disconnect event, and closes the adapter if open.
     */
    private cleanupSocket(socket: Socket, socketAdapter?: WebSocketAdapterPb): void {
        this.clearSocketEventHandlers(socket);
        socket.disconnect(false);

        if (socketAdapter?.isOpen()) {
            socketAdapter.close();
        }

        logger.info('Socket cleaned up');
    }

    /**
     * Unsubscribes the known socket events plus any other event registered through
     * `getSocketEventSignal`, so no forwarder is left pointing at a released
     * signal, then drops the socket's registry entry.
     */
    private clearSocketEventHandlers(socket: Socket): void {
        const events = new Set<SocketEventHandlers>(SocketNakamaService.socketEvents);
        this.socketSignalMap.get(socket)?.forEach((_signal, event) => events.add(event));

        events.forEach(event => this.unsubscribeSocketEvent(socket, event));

        this.socketSignalMap.delete(socket);
    }

    /**
     * Connects `socket` with `session`. On failure the socket is cleaned up (its
     * handlers are detached and it is disconnected) and the error is rethrown.
     */
    private async connectSocket([socket, socketAdapter]: Readonly<[Socket, WebSocketAdapterPb]>, session: Session) {
        try {
            return await socket.connect(session, true);
        } catch (error) {
            const message = error instanceof Error ? error.message : error;
            logger.error('Failed to connect socket', { error: message });
            this.cleanupSocket(socket, socketAdapter);

            throw error;
        }
    }

    /**
     * Creates a socket backed by a new protobuf adapter and subscribes the error
     * and disconnect loggers to it.
     */
    private createSocket(client: Client) {
        try {
            const socketAdapter = new WebSocketAdapterPb();
            const socket = client.createSocket(client.useSSL, false, socketAdapter);

            // Signals dispatch the callback arguments as a tuple; unpack the first one.
            const errorSignal = this.getSocketEventSignal(socket, 'onerror');
            errorSignal.on(([error]) => this.onSocketError(error), socket);

            const disconnectSignal = this.getSocketEventSignal(socket, 'ondisconnect');
            disconnectSignal.on(([event]) => this.onSocketDisconnect(event), socket);

            return [socket, socketAdapter] as const;
        } catch (error) {
            const message = error instanceof Error ? error.message : error;
            logger.error('Failed to create socket', { error: message });

            throw error;
        }
    }

    /**
     * Runs `action` through `retry` with the configured count/timeout, logging
     * each failed attempt and each retry under `methodName`.
     */
    private async retryConnection<T>(
        action: () => Promise<T> | T,
        methodName: string,
    ) {
        let attempt = 0;

        // `retry` already returns a promise; no need to wrap it in `new Promise`.
        return retry({
            work: async () => {
                try {
                    return await action();
                } catch (error) {
                    const message = error instanceof Error ? error.message : error;
                    logger.error(`${methodName} attempt failed`, { error: message });
                    throw error;
                }
            },
            beforeRetry: () => logger.warn(`Retrying ${methodName}, attempt ${++attempt}`),
            count: this.socketRetryCount,
            timeout: this.socketRetryTimeoutMs,
        });
    }
}
