import { HttpClient, HttpErrorResponse } from '@angular/common/http';
import { Injectable } from '@angular/core';
import { Cartesian3 } from '@datumate/angular-cesium';
import { booleanPointInPolygon } from '@turf/turf';
import { chunk } from 'lodash';
import { lastValueFrom, of, Subject, throwError } from 'rxjs';
import { delay, retry, timeout } from 'rxjs/operators';

import { environment } from '../../../environments/environment';
import { WebSocketMessage, WebsocketService } from '../../shared/services/websocket.service';
import { generateUniqueId, isDefined, sleep } from '../../shared/utils/general';
import { Cartographic, GeoUtils } from '../../shared/utils/geo';
import { gzip, GZIP_HEADERS } from '../../shared/utils/gzip-utils';
import { roundTo } from '../../shared/utils/math';
import { DetailedSiteQuery } from '../state/detailed-site.query';
import { Terrain } from './terrain-provider.service';
import { TerrainProviderEdit } from './terrain-provider-edit';

/** Position count from which sampling (for terrains with a URL) goes through the websocket instead of direct sampling. */
const SAMPLING_TYPE_POSITIONS_COUNT_THRESHOLD = 50_000;
/** Maximum number of positions per chunk for direct / lambda sampling. */
const S3_BUCKET_ACCESS_CHUNK_SIZE = 5_500;
/** Maximum number of positions per websocket sampling request. */
const WEBSOCKET_MAX_CHUNK_SIZE = 4_500;
/** How long to wait for all websocket sampling results before failing (3 minutes). */
const WEBSOCKET_TIMEOUT_MS = 180_000;
/** Retry budget used by the sampling methods. */
const SAMPLING_RETRY_COUNT = 5;

/** Signed-URL query parameters sent along with every sampling request (filled from the viewer credentials). */
interface SignedUrlParams {
  Policy: string;
  Signature: string;
  'Key-Pair-Id': string;
}

/** Body of a terrain sampling request, as sent to the lambda endpoint and (with a part id) over the websocket. */
interface SamplingRequest {
  /** URL of the terrain to sample. */
  url: string;
  urlParams: SignedUrlParams;
  /** Positions to sample, as `[longitude, latitude]` pairs in degrees. */
  positions: number[][];
  modelEditPolygons?: number[][][];
}

/** Sampling result: one elevation per requested position, in request order. */
interface SamplingResponse {
  elevations: number[];
}

/** Websocket request payload; `partId` identifies the chunk so its response can be matched back to it. */
interface WebSocketSamplingRequestData extends SamplingRequest {
  partId: string;
}

/** Websocket response payload, carrying the `partId` of the request it answers. */
interface WebSocketSamplingResponseData extends SamplingResponse {
  partId: string;
}

/** Bookkeeping for one in-flight websocket sampling request. */
interface WebSocketRequestMetadata {
  /** The message that was sent. */
  request: WebSocketMessage<WebSocketSamplingRequestData>;
  /** Index of the position chunk this request samples. */
  index: number;
  /** How many times this chunk has been re-sent so far. */
  retries: number;
  /** Whether a response has been received for this request. */
  isDone: boolean;
}

/** A position whose height is sampled at `linkedPosition` instead of at the position itself. */
export interface LinkedCartesian3 extends Cartesian3 {
  linkedPosition: Cartesian3;
}

/** Error thrown by the terrain sampling code when terrain heights could not be sampled. */
export class SamplingError extends Error {
  constructor() {
    super('Calculation Error: Failed while sampling terrain');
    // Give the error a recognisable name in logs / stack traces instead of the generic "Error"
    this.name = 'SamplingError';
    // Keep `instanceof SamplingError` working even when compiled to a down-level (ES5) target
    Object.setPrototypeOf(this, new.target.prototype);
  }
}

@Injectable({
  providedIn: 'root'
})
export class TerrainSamplingService {
  constructor(private http: HttpClient, private siteQuery: DetailedSiteQuery, private websocketService: WebsocketService) {}

  /**
   * Samples terrain heights for the given positions.
   *
   * - Without a terrain (or with a `FLAT` terrain) every position gets height 0.
   * - Positions carrying a `linkedPosition` (LinkedCartesian3) are sampled at their linked position, but the returned
   *   longitude/latitude are those of the original position.
   * - Positions outside `terrain.bounds` are not sampled and end up with an undefined height.
   * - Small requests (or terrains without a URL) are sampled directly; large requests are sampled over the websocket,
   *   falling back to the lambda endpoint and finally to direct sampling.
   *
   * @param positions Positions to sample.
   * @param terrain Terrain to sample; a missing terrain is treated as flat.
   * @param withTerrainHeightOffset Whether to add the site's terrain height offset (applied to `TASK` terrains only).
   * @returns One cartographic (degrees) per input position, in input order.
   * @throws Rejects when the last sampling method tried fails (direct sampling rejects with SamplingError).
   */
  public async sampleTerrain(positions: Cartesian3[], terrain: Terrain, withTerrainHeightOffset = true): Promise<Cartographic[]> {
    let results: Cartographic[];
    if (!terrain || terrain.type === 'FLAT') {
      results = positions.map(p => {
        const cartographic = GeoUtils.cartesian3ToDeg(p);
        cartographic.height = 0;
        return cartographic;
      });
    } else {
      // Switch LinkedCartesian3 positions to their linked position to sample their height,
      // remembering the original position (by index) so it can be reported in the results
      const linkedPositions = new Map<number, Cartesian3>();
      const positionsToSample: Cartesian3[] = positions.map((position, index) => {
        if ('linkedPosition' in position) {
          const positionWithLink = position as LinkedCartesian3;
          linkedPositions.set(index, positionWithLink);
          return positionWithLink.linkedPosition;
        }

        return position;
      });

      // Don't sample positions outside terrain bounds
      let positionsInBounds = positionsToSample;
      if (terrain.bounds) {
        positionsInBounds = positionsToSample.filter(p => {
          const coordinate = GeoUtils.cartesian3ToDegArray(p);
          return booleanPointInPolygon(coordinate, terrain.bounds);
        });
      }

      let elevationMapping: Map<string, number> | undefined;
      if (!terrain.url || positions.length < SAMPLING_TYPE_POSITIONS_COUNT_THRESHOLD) {
        elevationMapping = await this.httpSampling('directly')(terrain, positionsInBounds);
      } else {
        // Sample using websockets if too many positions to sample directly
        try {
          elevationMapping = await this.webSocketSampling(terrain, positionsInBounds);
        } catch {
          // Fallback to use lambda, and if that doesn't work sample directly
          try {
            elevationMapping = await this.httpSampling('lambda')(terrain, positionsInBounds);
          } catch {
            elevationMapping = await this.httpSampling('directly')(terrain, positionsInBounds);
          }
        }
      }

      results = positionsToSample.map((p, i) => {
        // Report linked positions at their original location, with the height sampled at the linked position.
        // `linkedPositions` is a Map, so it must be queried with `get` - `i in map` / `map[i]` never see its entries.
        const position = linkedPositions.get(i) ?? p;
        const positionDeg = GeoUtils.cartesian3ToDeg(position);
        positionDeg.height = elevationMapping?.get(this.positionKey(p));
        return positionDeg;
      });
    }

    // Add site offset if needed (`terrain` may be missing here - it was treated as flat above)
    const siteOffset = this.siteQuery.getTerrainHeightOffset();
    if (withTerrainHeightOffset && terrain?.type === 'TASK' && siteOffset) {
      results.forEach(p => {
        if (isDefined(p.height)) {
          p.height += siteOffset;
        }
      });
    }

    return results;
  }

  /**
   * Samples elevations through the websocket `sampleTerrain` action.
   *
   * The positions are split into tile-grouped chunks (at most WEBSOCKET_MAX_CHUNK_SIZE positions each) and every chunk is
   * sent as its own request, identified by a unique `partId`. A response with missing (or a wrong number of) elevations
   * is re-sent as a new request, up to SAMPLING_RETRY_COUNT times per chunk.
   *
   * @returns Elevations keyed by `positionKey`, or `undefined` when `positions` is not defined.
   * @throws SamplingError on a send error, an error response, exhausted retries, a socket error, or when not every chunk
   *   has been answered within WEBSOCKET_TIMEOUT_MS.
   */
  private webSocketSampling = async (terrain: Terrain, positions: Cartesian3[]): Promise<Map<string, number> | undefined> => {
    if (!isDefined(positions)) {
      return;
    }

    const elevationMapping = new Map<string, number>();
    if (positions.length === 0) {
      return elevationMapping;
    }

    const positionChunks = this.chunkPositionsByTile(positions, terrain, WEBSOCKET_MAX_CHUNK_SIZE);
    const requestMapping = new Map<string, WebSocketRequestMetadata>();

    // Created before anything is sent or received, so every callback below can safely report into it
    const webSocketDoneSubject = new Subject<Map<string, number>>();
    let isSettled = false;
    const fail = () => {
      isSettled = true;
      webSocketDoneSubject.error(new SamplingError());
    };

    const sendRequest = (requestPositions: number[][], index: number, retries: number) => {
      const request: WebSocketMessage<WebSocketSamplingRequestData> = {
        action: 'sampleTerrain',
        data: { partId: generateUniqueId(), ...this.createSamplingRequest(terrain, requestPositions) }
      };
      requestMapping.set(request.data.partId, { request, index, retries, isDone: false });

      this.websocketService.sendMessage(request).subscribe({
        error: error => {
          console.error('Error sending sampling message', request, error);
          fail();
        }
      });
    };

    // Start listening before sending anything so no response can be missed
    const webSocketSub = this.websocketService.observe<WebSocketSamplingResponseData>('sampleTerrain').subscribe({
      next: (samplingResponse: WebSocketMessage<WebSocketSamplingResponseData>) => {
        if (isSettled) {
          return;
        }

        if (isDefined(samplingResponse?.errorMessage)) {
          console.error('Error while terrain sampling:', samplingResponse);
          fail();
          return;
        }

        const responseData = samplingResponse?.data;
        const requestMetadata = responseData?.partId != null ? requestMapping.get(responseData.partId) : undefined;
        if (!requestMetadata || requestMetadata.isDone) {
          // Skip responses to other requests, and duplicates of already handled ones
          return;
        }

        requestMetadata.isDone = true;

        const chunkPositions = positionChunks[requestMetadata.index];
        const elevations = responseData?.elevations;
        const hasAllElevations =
          Array.isArray(elevations) &&
          elevations.length === chunkPositions.length &&
          elevations.every(height => isDefined(height));
        if (!hasAllElevations) {
          console.error('Error while terrain sampling: Some requests returned no elevations');

          if (requestMetadata.retries < SAMPLING_RETRY_COUNT) {
            // Retry the same positions as a new request (with a new partId)
            sendRequest(requestMetadata.request.data.positions, requestMetadata.index, requestMetadata.retries + 1);
          } else {
            // Done retrying - raise error
            fail();
          }

          return;
        }

        chunkPositions.forEach((p, i) => {
          elevationMapping.set(this.positionKey(p), elevations[i]);
        });

        // Check if all requests are done - finish sampling
        if ([...requestMapping.values()].every(metadata => metadata.isDone)) {
          isSettled = true;
          webSocketDoneSubject.next(elevationMapping);
          webSocketDoneSubject.complete();
        }
      },
      error: error => {
        console.error('Error while terrain sampling', error);
        fail();
      }
    });

    positionChunks.forEach((positionsChunk, index) => {
      // Stop sending once sampling has already failed (e.g. a send error reported synchronously)
      if (!isSettled) {
        sendRequest(this.toRequestPositions(positionsChunk), index, 0);
      }
    });

    try {
      return await lastValueFrom(
        webSocketDoneSubject.pipe(
          timeout({
            first: WEBSOCKET_TIMEOUT_MS,
            with: () => {
              console.error(`Reached timeout of ${WEBSOCKET_TIMEOUT_MS} ms waiting for sampling results`);
              return throwError(() => new SamplingError());
            }
          })
        )
      );
    } finally {
      // Stop listening on every exit path (success, error or timeout) so late responses can't trigger more retries
      isSettled = true;
      webSocketSub.unsubscribe();
    }
  };

  /**
   * Returns a sampler that chunks positions by tile (at most S3_BUCKET_ACCESS_CHUNK_SIZE per chunk) and samples all
   * chunks in parallel, either directly through Cesium or through the lambda endpoint.
   *
   * The returned function resolves to elevations keyed by `positionKey`, or `undefined` when `positions` is not defined,
   * and rejects when any chunk fails.
   */
  private httpSampling =
    (method: 'lambda' | 'directly') =>
    async (terrain: Terrain, positions: Cartesian3[]): Promise<Map<string, number> | undefined> => {
      if (!isDefined(positions)) {
        return;
      }

      const elevationMapping = new Map<string, number>();
      if (positions.length === 0) {
        return elevationMapping;
      }

      const samplingMethod = method === 'directly' ? this.sampleTerrainDirectly : this.sampleTerrainByLambda;
      const positionChunks = this.chunkPositionsByTile(positions, terrain, S3_BUCKET_ACCESS_CHUNK_SIZE);
      await Promise.all(
        positionChunks.map(async positionsChunk => {
          const elevationResults = await samplingMethod(positionsChunk, terrain);
          positionsChunk.forEach((p, i) => {
            elevationMapping.set(this.positionKey(p), elevationResults[i]);
          });
        })
      );

      return elevationMapping;
    };

  /**
   * Samples the positions with Cesium's `sampleTerrainMostDetailed` against `terrain.provider`.
   * Positions that come back without a height are re-sampled up to SAMPLING_RETRY_COUNT times.
   *
   * @returns Heights in the same order as `positions`.
   * @throws SamplingError if some positions still have no height after all retries.
   */
  private sampleTerrainDirectly = async (positions: Cartesian3[], terrain: Terrain) => {
    const positionsCartographicRad: Cartographic[] = positions.map(p => Cesium.Cartographic.fromCartesian(p));

    // Don't access S3 more than the cap so there won't be any errors
    for (const cartographicChunk of chunk(positionsCartographicRad, S3_BUCKET_ACCESS_CHUNK_SIZE)) {
      await Cesium.sampleTerrainMostDetailed(terrain.provider, cartographicChunk);
    }

    // Retry samples with errors
    for (let attempt = 1; attempt <= SAMPLING_RETRY_COUNT; attempt++) {
      const failedList = positionsCartographicRad.filter(p => !isDefined(p.height));
      if (failedList.length === 0) {
        break;
      }

      // Wait before retrying, with a quadratically growing delay (100 ms, 400 ms, 900 ms, ...)
      await sleep(attempt ** 2 * 100);

      await Cesium.sampleTerrainMostDetailed(terrain.provider, failedList);
    }

    // If we still have errors after all retries - reject
    if (positionsCartographicRad.some(p => !isDefined(p.height))) {
      console.error('Error while terrain sampling: Some requests returned no elevations');
      throw new SamplingError();
    }

    return positionsCartographicRad.map(p => p.height);
  };

  /**
   * Samples the positions through the lambda `/sampleterrain` endpoint (gzipped request body).
   * HTTP 504 (lambda timeout) responses are retried up to SAMPLING_RETRY_COUNT times; other HTTP errors fail immediately.
   *
   * @returns Elevations in the same order as `positions`.
   * @throws SamplingError when the request fails or the response is missing elevations for some positions.
   */
  private sampleTerrainByLambda = async (positions: Cartesian3[], terrain: Terrain) => {
    const request = this.createSamplingRequest(terrain, this.toRequestPositions(positions));

    let samplingResult: SamplingResponse;
    try {
      const gzippedRequest = await gzip(request);
      samplingResult = await lastValueFrom(
        this.http
          .post<SamplingResponse>(`${environment.lambdaGatewayUrl}/sampleterrain`, gzippedRequest, {
            headers: { ...GZIP_HEADERS }
          })
          .pipe(
            retry({
              count: SAMPLING_RETRY_COUNT,
              delay: (err: HttpErrorResponse, retryCount: number) => {
                // Only retry lambda timeout errors, with a quadratically growing delay
                if (err.status === 504) {
                  console.warn('Lambda error', err);
                  return of(true).pipe(delay(retryCount ** 2 * 100));
                }
                throw err;
              }
            })
          )
      );
    } catch (e) {
      console.error('Error sampling using lambda', e);
      throw new SamplingError();
    }

    if (!samplingResult?.elevations) {
      console.error(`Error while terrain sampling: failed after ${SAMPLING_RETRY_COUNT} retries`);
      throw new SamplingError();
    }
    // A short elevations array would otherwise silently leave trailing positions without a height
    if (samplingResult.elevations.length !== positions.length || samplingResult.elevations.some(height => !isDefined(height))) {
      console.error('Error while terrain sampling: Some requests returned no elevations');
      throw new SamplingError();
    }

    return samplingResult.elevations;
  };

  /**
   * Builds the sampling request body shared by the websocket and lambda sampling methods, using the current viewer
   * credentials for the signed terrain URL and, for a TerrainProviderEdit provider, its model edit polygons.
   */
  private createSamplingRequest(terrain: Terrain, requestPositions: number[][]): SamplingRequest {
    const credentials = this.siteQuery.getViewerCredentials();
    let modelEditPolygons: number[][][] | undefined;
    if (terrain.provider instanceof TerrainProviderEdit) {
      modelEditPolygons = (terrain.provider as TerrainProviderEdit).modelEditPolygons;
    }

    return {
      url: terrain.url,
      urlParams: {
        Policy: credentials.policy,
        Signature: credentials.signature,
        'Key-Pair-Id': credentials.keyPairId
      },
      positions: requestPositions,
      modelEditPolygons
    };
  }

  /** Converts positions to `[longitude, latitude]` pairs in degrees, rounded to 7 decimal places. */
  private toRequestPositions(positions: Cartesian3[]): number[][] {
    return positions.map(p => {
      const pDeg = GeoUtils.cartesian3ToDeg(p);
      return [roundTo(pDeg.longitude, 7), roundTo(pDeg.latitude, 7)];
    });
  }

  /** Key used to look up the sampled elevation of a position: its cartesian x and y joined with a comma. */
  private positionKey(position: Cartesian3): string {
    return [position.x, position.y].join(',');
  }

  /**
   * Splits positions into chunks of at most `maxChunkSize` positions, keeping positions of the same tile column together
   * when the terrain provider exposes a tiling scheme and availability (plain chunking otherwise).
   */
  private chunkPositionsByTile(positions: Cartesian3[], terrain: Terrain, maxChunkSize: number): Cartesian3[][] {
    if (positions.length <= maxChunkSize) {
      return [positions];
    }

    const { tilingScheme, availability } = terrain.provider;
    if (!tilingScheme || !availability) {
      return chunk(positions, maxChunkSize);
    }

    // Group positions by the tile they're in
    const tileMapping: Record<string, Cartesian3[]> = {};
    positions.forEach(p => {
      const positionCartographic = Cesium.Cartographic.fromCartesian(p);
      const level = availability.computeMaximumLevelAtPosition(positionCartographic);
      // Cesium returns undefined here for positions outside the tiling scheme's rectangle
      const tileXY = tilingScheme.positionToTileXY(positionCartographic, level);

      // We're not using the tile's y value for the distinct ID for the tile since our S3 prefixes don't take them into account
      const tileId = [level, tileXY?.x].join(',');

      if (!(tileId in tileMapping)) {
        tileMapping[tileId] = [];
      }
      tileMapping[tileId].push(p);
    });

    return this.splitChunksEvenlyBySize(tileMapping, maxChunkSize);
  }

  /**
   * Packs per-tile position groups into chunks of at most `maxChunkSize` positions.
   * Groups are processed from largest to smallest: a group is appended to the last chunk when it fits, otherwise it
   * starts a new chunk. Groups larger than `maxChunkSize` are first cut into full chunks (inserted before the last,
   * partially filled chunk) and their remainder is then packed like any other group.
   */
  private splitChunksEvenlyBySize(tileMapping: Record<string, Cartesian3[]>, maxChunkSize: number): Cartesian3[][] {
    const result: Cartesian3[][] = [];

    const positionsByTile = Object.values(tileMapping).sort((list1, list2) => list2.length - list1.length);
    positionsByTile.forEach(tilePositions => {
      let remaining = tilePositions;

      // Positions from a single tile exceed the max chunk size - emit the full chunks (before the last chunk, so it stays
      // open for merging) and continue with the tile's last partial chunk
      if (remaining.length > maxChunkSize) {
        const singleTileChunks = chunk(remaining, maxChunkSize);
        remaining = singleTileChunks.at(-1);
        result.splice(result.length - 1, 0, ...singleTileChunks.slice(0, -1));
      }

      const lastChunk = result.at(-1);
      if (lastChunk && lastChunk.length + remaining.length <= maxChunkSize) {
        lastChunk.push(...remaining);
      } else {
        result.push(remaining);
      }
    });

    return result;
  }
}
