import { AxisScale } from "@visx/axis"
import { interpolate } from "d3-interpolate"
import { identity } from "lodash"

import { clamper } from "../../support/plotting"
import { Range } from "../../support/types"
import { BaseScaleConfig } from "@visx/scale/lib/types/BaseScaleConfig"
import { FREQ_DOMAIN } from "../../models/Equalizer"
import { ScaleContinuousNumeric } from "d3-scale"

/**
 * Linear deinterpolator: builds a function mapping `x` in `[a, b]` to the
 * parameter `t = (x - a) / (b - a)`.
 *
 * When the interval is degenerate (`b - a` is `0` or `NaN`) the returned
 * function ignores its input and yields that span (`0` or `NaN`) instead of
 * dividing by it.
 */
const deinterpolateLinear = (a: number, b: number) => {
  const start = Number(a)
  const span = b - start

  return span ? (x: number) => (x - start) / span : () => span
}

/**
 * Composes a two-point mapping from `domain` to `range`.
 *
 * `deinterpolate` turns a domain value into a parameter `t`, and
 * `reinterpolate` turns that `t` into a range value. If the domain is
 * descending, both intervals are flipped before the factories are called, so
 * the factories always receive the domain endpoints in ascending order.
 *
 * @param domain two-element input interval
 * @param range two-element output interval
 * @param deinterpolate factory for the domain -> t function
 * @param reinterpolate factory for the t -> range function
 * @returns a function mapping a domain value to its range value
 */
const bimap = (
  domain: Range,
  range: Range,
  deinterpolate: (a: number, b: number) => (x: number) => number,
  reinterpolate: (a: number, b: number) => (x: number) => number,
) => {
  const domainStart = domain[0]
  const domainEnd = domain[1]
  const rangeStart = range[0]
  const rangeEnd = range[1]
  const isDescending = domainEnd < domainStart

  const toParameter = isDescending
    ? deinterpolate(domainEnd, domainStart)
    : deinterpolate(domainStart, domainEnd)
  const fromParameter = isDescending
    ? reinterpolate(rangeEnd, rangeStart)
    : reinterpolate(rangeStart, rangeEnd)

  return (x: number) => fromParameter(toParameter(x))
}

// ROOT_OCTAVE is the smallest octave of 1 kHz that is still at least 1 Hz.
// We want to represent octaves above and below 1 kHz, so we start at 1000 and
// keep halving while the result stays >= 1:
//   1000 -> 500 -> 250 -> 125 -> ... -> 3.90625 -> 1.953125 (next would be 0.9765625)
// That is nine halvings, i.e. 1000 / 2^9 = 1.953125, which is exactly
// representable as a double. It is the smallest value findLowerOctave works in
// multiples of.
const ROOT_OCTAVE = 1000 / 2 ** 9

/**
 * Returns the largest octave boundary `ROOT_OCTAVE * 2^k` (integer `k >= 0`)
 * that is less than or equal to `value`.
 *
 * Values below `ROOT_OCTAVE` (and `NaN`) resolve to `ROOT_OCTAVE`, the lowest
 * boundary. Results are only meaningful for values below
 * `ROOT_OCTAVE * 2^32`, because `Math.clz32` only inspects the low 32 bits.
 */
const findLowerOctave = (value: number) => {
  const rootMultiples = Math.floor(value / ROOT_OCTAVE)

  // With fewer than one whole root octave, Math.clz32 reports 32 and the
  // exponent below would be -1. The previous `1 << -1` evaluated to -2^31 and
  // produced a huge negative "octave"; fall back to the lowest boundary.
  if (!(rootMultiples >= 1)) {
    return ROOT_OCTAVE
  }

  // 31 - clz32(n) is the index of the highest set bit, i.e. floor(log2(n)).
  // `2 **` is used instead of `1 <<` so an exponent of 31 stays positive.
  return 2 ** (31 - Math.clz32(rootMultiples)) * ROOT_OCTAVE
}

/**
 * deinterpolate(a, b)(x) maps a domain value x to its parameter t, where the
 * interval [a, b] corresponds to t in [0, 1].
 *
 * Octave boundaries (ROOT_OCTAVE * 2^k) are placed logarithmically, so every
 * octave takes the same share of the parameter axis. Values inside an octave
 * are placed linearly between its two boundaries. As a consequence the result
 * matches the pure log2 position only at octave boundaries; in particular
 * deinterpolate(a, b)(a) is exactly 0 only when `a` is itself a boundary.
 *
 * Inputs below ROOT_OCTAVE are treated as ROOT_OCTAVE.
 */
const deinterpolate = (a: number, b: number) => {
  const spanInOctaves = Math.log2(b / a)
  const parameterPerOctave = 1 / spanInOctaves

  return (value: number) => {
    const frequency = value < ROOT_OCTAVE ? ROOT_OCTAVE : value
    const octaveStart = findLowerOctave(frequency)

    // Logarithmic position of the octave's lower boundary.
    const octaveOffset = Math.log2(octaveStart / a) / spanInOctaves
    // Linear position inside the octave: 0 at its start, 1 at the next one.
    const fractionOfOctave = (frequency - octaveStart) / octaveStart

    return octaveOffset + fractionOfOctave * parameterPerOctave
  }
}

/**
 * reinterpolate(a, b)(t) is the inverse of deinterpolate(a, b): it maps a
 * parameter t (0 at a, 1 at b) back to a domain value.
 *
 * The pure-log value `a * 2^(t * octaves)` is used only to find which octave t
 * falls in. The result is then placed linearly between that octave's two
 * boundaries, mirroring the linear-within-octave layout of deinterpolate.
 *
 * Parameters that land below ROOT_OCTAVE return ROOT_OCTAVE, because
 * deinterpolate collapses every value below ROOT_OCTAVE onto it.
 */
const reinterpolate = (a: number, b: number) => {
  const octaveRange = Math.log2(b / a)
  // The forward mapping depends only on (a, b), so build it once instead of
  // twice per inverted value.
  const toParameter = deinterpolate(a, b)

  return (value: number) => {
    const representedLogValue = a * 2 ** (value * octaveRange)

    // Without this guard the octave lookup and the parameter difference below
    // both degenerate and the result becomes Infinity/NaN.
    if (representedLogValue < ROOT_OCTAVE) {
      return ROOT_OCTAVE
    }

    const lowerOctave = findLowerOctave(representedLogValue)
    const translationForLowerOctave = toParameter(lowerOctave)
    const translationForUpperOctave = toParameter(lowerOctave * 2)

    // An octave [L, 2L] has width L, so scale the fractional position by L.
    return (
      lowerOctave +
      (lowerOctave * (value - translationForLowerOctave)) /
        (translationForUpperOctave - translationForLowerOctave)
    )
  }
}

// Options accepted by audiogramScale: only the `domain`, `range`, `clamp` and
// `unknown` fields of visx's BaseScaleConfig are picked.
type AudiogramScaleConfig = Pick<
  BaseScaleConfig<number, Range, Range>,
  "domain" | "range" | "clamp" | "unknown"
>

// Type the scale is cast to when returned: a d3 continuous numeric scale that
// is also a visx AxisScale with numeric output.
type AudiogramScale = ScaleContinuousNumeric<number, number> & AxisScale<number>

/**
 * Creates a d3-style continuous scale for audiogram frequency axes: octaves
 * are spaced evenly and values inside an octave are placed linearly (see
 * `deinterpolate` / `reinterpolate`).
 *
 * The scale is immutable. `domain(x)`, `range(x)` and `clamp(x)` return a new
 * scale with that setting replaced instead of mutating this one; called
 * without arguments they return the current setting.
 *
 * @param domain input interval; defaults to `FREQ_DOMAIN`
 * @param range output interval; defaults to `[0, 1]`
 * @param clamp when true, inputs are limited to the domain before mapping and
 *   inverted values are limited to the domain after mapping
 * @param unknown value returned for `null`, `undefined` or non-numeric input
 */
const audiogramScale = ({
  domain = FREQ_DOMAIN,
  range = [0, 1],
  clamp = false,
  unknown = 0,
}: AudiogramScaleConfig) => {
  // NOTE(review): assumes `clamper(lo, hi)` returns `x => clamp x to [lo, hi]`
  // like d3-scale's internal helper; confirm against support/plotting.
  const clampFn = clamp
    ? clamper(domain[0], domain[Math.min(domain.length, range.length) - 1])
    : identity

  // Both mappings depend only on domain/range, so build them once rather than
  // on every call.
  const forward = bimap(domain, range, deinterpolate, interpolate)
  const backward = bimap(range, domain, deinterpolateLinear, reinterpolate)

  // The clamper is built from domain bounds, so it must be applied to the
  // input (domain space), not to the mapped output (range space).
  const scale = (x: any) =>
    x == null || isNaN(Number(x)) ? unknown : forward(clampFn(Number(x)))

  // Overrides are merged as objects. Spreading the rest-args array into the
  // config (as before) stored them under key "0" and silently dropped them.
  // `unknown` is forwarded as well so copies keep the configured fallback.
  const copy = (...overrides: any[]): AudiogramScale =>
    audiogramScale(
      Object.assign({ domain, range, clamp, unknown }, ...overrides),
    )

  // Builds a d3-style accessor: no arguments reads, any argument writes.
  const getOrSet =
    (get: () => any, set: (...args: any[]) => any) =>
    (...args: any[]) => {
      if (args.length === 0) {
        return get()
      }
      return set(...args)
    }

  scale.copy = copy

  // The inverted value is already in domain space, so clamping it afterwards
  // is correct here.
  scale.invert = (y: number) => clampFn(backward(y)) ?? unknown

  // Fixed audiogram frequencies; `count` only truncates the list (an
  // undefined count returns all of them). The domain is not consulted.
  scale.ticks = (count: number) => {
    return [
      125, 250, 500, 750, 1000, 1500, 2000, 3000, 4000, 6000, 8000, 12000,
    ].slice(0, count)
  }

  scale.domain = getOrSet(
    () => domain.slice() as Range,
    (nextDomain) => copy({ domain: nextDomain }),
  )

  scale.range = getOrSet(
    () => range.slice(),
    (nextRange) => copy({ range: nextRange }),
  )

  scale.clamp = getOrSet(
    () => clamp,
    (nextClamp) => copy({ clamp: nextClamp }),
  )

  // scale.round = notImplemented
  // scale.rangeRound = notImplemented
  // scale.padding = notImplemented
  // scale.rangeRound = notImplemented
  // scale.align = notImplemented
  // scale.bandwidth = notImplemented
  // scale.step = notImplemented
  // scale.interpolate = notImplemented
  // scale.unknown = notImplemented
  // scale.tickFormat = notImplemented
  // scale.nice = notImplemented

  return scale as AudiogramScale
}

export default audiogramScale
